Example 14: ScaLAPACK Compatibility

This example demonstrates SLATE’s ScaLAPACK compatibility layer.

Key Concepts

  1. ScaLAPACK Interception: SLATE can intercept standard ScaLAPACK calls (like pdgemm) and execute them using SLATE algorithms.

  2. Legacy Code Support: Allows existing ScaLAPACK applications to benefit from SLATE performance without code changes (just linking).

  3. BLACS Initialization: The example sets up the BLACS grid and ScaLAPACK descriptors as usual.

C++ Example

BLACS Initialization (Lines 45-52)

Cblacs_pinfo( &iam, &nprocs );
Cblacs_get( -1, 0, &ictxt );
Cblacs_gridinit( &ictxt, "Col", grid_p, grid_q );

Standard setup for any ScaLAPACK program. This initializes the process grid.

ScaLAPACK Descriptors (Lines 55-82)

int mlocA = numroc( ... );
descinit( descA, ... );

Allocates local memory (mlocA * nlocA elements, with leading dimension lldA = mlocA) and initializes the array descriptor descA, which describes the distributed matrix layout (global dimensions, block size, process grid). This is standard ScaLAPACK boilerplate, repeated for B and C.

PBLAS Call (Lines 88-111)

psgemm( ... ); // float
pdgemm( ... ); // double
pcgemm( ... ); // complex<float>
pzgemm( ... ); // complex<double>

The code calls the standard PBLAS functions (p[sdcz]gemm). - Crucial Point: If this program is linked against the SLATE ScaLAPACK API library (-lslate_scalapack_api), these calls will be intercepted by SLATE. - SLATE converts the ScaLAPACK descriptors to SLATE Matrix objects internally, executes the operation using SLATE’s engine (potentially on GPUs), and then ensures the result is consistent with ScaLAPACK expectations. - This allows drop-in acceleration for legacy codes.

  1// ex14_scalapack_gemm.cc
  2// SLATE intercepts ScaLAPACK calls.
  3
  4/// !!!   Lines between `//---------- begin label`          !!!
  5/// !!!             and `//---------- end label`            !!!
  6/// !!!   are included in the SLATE Users' Guide.           !!!
  7
  8#include <mpi.h>
  9
 10#include "util.hh"
 11#include "scalapack.h"
 12
 13int mpi_size = 0;
 14int mpi_rank = 0;
 15int grid_p = 0;
 16int grid_q = 0;
 17
 18//------------------------------------------------------------------------------
 19// We don't include slate.hh here, so define a simple slate_mpi_call.
 20void slate_mpi_call_( int err, const char* file, int line )
 21{
 22    if (err != 0) {
 23        char msg[ 80 ];
 24        snprintf( msg, sizeof(msg), "MPI error %d at %s:%d", err, file, line );
 25        throw std::runtime_error( msg );
 26    }
 27}
 28
 29#define slate_mpi_call( err ) \
 30        slate_mpi_call_( err, __FILE__, __LINE__ )
 31
 32//------------------------------------------------------------------------------
 33template <typename scalar_type>
 34void test_pgemm()
 35{
 36    print_func( mpi_rank );
 37
 38    // constants
 39    int izero = 0, ione = 1;
 40
 41    // problem size and distribution
 42    int m = 15, n = 18, k = 13, nb = 4;
 43
 44    // initialize BLACS communication
 45    int p_, q_, nprocs, ictxt, iam, myrow, mycol, info;
 46    Cblacs_pinfo( &iam, &nprocs );
 47    assert( grid_p * grid_q <= nprocs );
 48    Cblacs_get( -1, 0, &ictxt );
 49    Cblacs_gridinit( &ictxt, "Col", grid_p, grid_q );
 50    Cblacs_gridinfo( ictxt, &p_, &q_, &myrow, &mycol );
 51    assert( p_ == grid_p );
 52    assert( q_ == grid_q );
 53
 54    // matrix A: get local size, allocate, create descriptor, initialize
 55    int mlocA = numroc( &m, &nb, &myrow, &izero, &grid_p );
 56    int nlocA = numroc( &k, &nb, &mycol, &izero, &grid_q );
 57    int lldA  = mlocA;
 58    int descA[9];
 59    descinit( descA, &m, &k, &nb, &nb, &izero, &izero, &ictxt, &lldA, &info );
 60    assert( info == 0 );
 61    std::vector<scalar_type> dataA( lldA * nlocA );
 62    random_matrix( mlocA, nlocA, &dataA[0], lldA );
 63
 64    // matrix B: get local size, allocate, create descriptor, initialize
 65    int mlocB = numroc( &k, &nb, &myrow, &izero, &grid_p );
 66    int nlocB = numroc( &n, &nb, &mycol, &izero, &grid_q );
 67    int lldB  = mlocB;
 68    int descB[9];
 69    descinit( descB, &k, &n, &nb, &nb, &izero, &izero, &ictxt, &lldB, &info );
 70    assert( info == 0 );
 71    std::vector<scalar_type> dataB( lldB * nlocB );
 72    random_matrix( mlocB, nlocB, &dataB[0], lldB );
 73
 74    // matrix C: get local size, allocate, create descriptor, initialize
 75    int mlocC = numroc( &m, &nb, &myrow, &izero, &grid_p );
 76    int nlocC = numroc( &n, &nb, &mycol, &izero, &grid_q );
 77    int lldC  = mlocC;
 78    int descC[9];
 79    descinit( descC, &m, &n, &nb, &nb, &izero, &izero, &ictxt, &lldC, &info );
 80    assert( info == 0 );
 81    std::vector<scalar_type> dataC( lldC * nlocC );
 82    random_matrix( mlocC, nlocC, &dataC[0], lldC );
 83
 84    scalar_type alpha = 2.7183;
 85    scalar_type beta  = 3.1415;
 86
 87    // gemm: C = alpha A B + beta C
 88    if constexpr (std::is_same< scalar_type, float >::value) {
 89        psgemm( "notrans", "notrans", &m, &n, &k,
 90                &alpha, &dataA[0], &ione, &ione, descA,
 91                        &dataB[0], &ione, &ione, descB,
 92                &beta,  &dataC[0], &ione, &ione, descC );
 93    }
 94    else if constexpr (std::is_same< scalar_type, double >::value) {
 95        pdgemm( "notrans", "notrans", &m, &n, &k,
 96                &alpha, &dataA[0], &ione, &ione, descA,
 97                        &dataB[0], &ione, &ione, descB,
 98                &beta,  &dataC[0], &ione, &ione, descC );
 99    }
100    else if constexpr (std::is_same< scalar_type, std::complex<float> >::value) {
101        pcgemm( "notrans", "notrans", &m, &n, &k,
102                &alpha, &dataA[0], &ione, &ione, descA,
103                        &dataB[0], &ione, &ione, descB,
104                &beta,  &dataC[0], &ione, &ione, descC );
105    }
106    else if constexpr (std::is_same< scalar_type, std::complex<double> >::value) {
107        pzgemm( "notrans", "notrans", &m, &n, &k,
108                &alpha, &dataA[0], &ione, &ione, descA,
109                        &dataB[0], &ione, &ione, descB,
110                &beta,  &dataC[0], &ione, &ione, descC );
111    }
112}
113
114//------------------------------------------------------------------------------
115int main( int argc, char** argv )
116{
117    try {
118        // Parse command line to set types for s, d, c, z precisions.
119        bool types[ 4 ];
120        parse_args( argc, argv, types );
121
122        int provided = 0;
123        slate_mpi_call(
124            MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) );
125        assert( provided == MPI_THREAD_MULTIPLE );
126
127        slate_mpi_call(
128            MPI_Comm_size( MPI_COMM_WORLD, &mpi_size ) );
129
130        slate_mpi_call(
131            MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank ) );
132
133        // Determine p-by-q grid for this MPI size.
134        grid_size( mpi_size, &grid_p, &grid_q );
135        if (mpi_rank == 0) {
136            printf( "mpi_size %d, grid_p %d, grid_q %d\n",
137                    mpi_size, grid_p, grid_q );
138        }
139
140        // so random_matrix is different on different ranks.
141        srand( 100 * mpi_rank );
142
143        if (types[ 0 ]) {
144            test_pgemm< float >();
145        }
146
147        if (types[ 1 ]) {
148            test_pgemm< double >();
149        }
150
151        if (types[ 2 ]) {
152            test_pgemm< std::complex<float> >();
153        }
154
155        if (types[ 3 ]) {
156            test_pgemm< std::complex<double> >();
157        }
158
159        slate_mpi_call(
160            MPI_Finalize() );
161    }
162    catch (std::exception const& ex) {
163        fprintf( stderr, "%s", ex.what() );
164        return 1;
165    }
166    return 0;
167}