Example 14: ScaLAPACK Compatibility
This example demonstrates SLATE’s ScaLAPACK compatibility layer.
Key Concepts
- ScaLAPACK Interception: SLATE can intercept standard ScaLAPACK calls (like pdgemm) and execute them using SLATE algorithms.
- Legacy Code Support: Allows existing ScaLAPACK applications to benefit from SLATE performance without code changes, just by relinking (see the illustrative prototype below).
- BLACS Initialization: The example sets up the BLACS grid and ScaLAPACK descriptors as usual.
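The interception works at link time: SLATE's compatibility library exports the same entry points as the reference PBLAS/ScaLAPACK, so the linker resolves the call to whichever implementation it encounters first. As a rough, hypothetical illustration (the exact symbol name and mangling depend on the Fortran calling convention and on the headers your application uses), the double-precision entry point is an ordinary C-callable prototype:

// Illustrative Fortran-convention prototype for the PBLAS pdgemm entry point.
// An application compiled against a declaration like this can be linked either
// with the reference ScaLAPACK or with SLATE's scalapack_api library; the
// linker uses whichever implementation appears first on the link line, which
// is what makes the interception transparent to the application source.
extern "C" void pdgemm_(
    const char* transA, const char* transB,
    const int* m, const int* n, const int* k,
    const double* alpha,
    const double* A, const int* iA, const int* jA, const int* descA,
    const double* B, const int* iB, const int* jB, const int* descB,
    const double* beta,
    double*       C, const int* iC, const int* jC, const int* descC );

In practice this means building the application as usual and placing the SLATE compatibility library (e.g., -lslate_scalapack_api -lslate) before the reference ScaLAPACK on the link line; the exact library names and ordering depend on the installation.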
C++ Example
BLACS Initialization (Lines 45-52)
Cblacs_pinfo( &iam, &nprocs );
Cblacs_get( -1, 0, &ictxt );
Cblacs_gridinit( &ictxt, "Col", grid_p, grid_q );
This is the standard setup for any ScaLAPACK program: it creates the BLACS process grid (here with column-major process ordering, "Col") that all subsequent descriptors and PBLAS calls refer to.
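The example relies on MPI_Finalize for cleanup and never tears the grid down explicitly. In a longer-lived application the BLACS context is usually released once the distributed work is done; a minimal teardown (not part of this example) would look like:

// Release the context created by Cblacs_gridinit, then shut down BLACS.
// The nonzero argument to Cblacs_exit means "keep MPI alive", so
// MPI_Finalize can still be called afterwards.
Cblacs_gridexit( ictxt );
Cblacs_exit( 1 );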
ScaLAPACK Descriptors (Lines 55-82)
int mlocA = numroc( ... );
descinit( descA, ... );
numroc computes how many rows and columns of the global matrix land on this process under the 2D block-cyclic distribution; the code allocates the local memory (mloc * nloc) accordingly and initializes the array descriptor descA, which records the distributed layout (global dimensions, block size, process grid, and local leading dimension). This is standard ScaLAPACK boilerplate.
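To make the local sizing concrete, here is a simplified, self-contained re-implementation of the block-cyclic arithmetic that numroc performs, written for illustration only and assuming the source process is 0 (the example itself calls the real Fortran numroc):

#include <cstdio>

// Number of rows (or columns) of a global dimension n, distributed in
// blocks of size nb over nprocs processes, that land on process iproc.
int local_size( int n, int nb, int iproc, int nprocs )
{
    int nblocks = n / nb;                   // number of full blocks
    int size    = (nblocks / nprocs) * nb;  // full blocks every process owns
    int extra   = nblocks % nprocs;         // processes owning one extra full block
    if (iproc < extra)
        size += nb;                         // one extra full block
    else if (iproc == extra)
        size += n % nb;                     // the partial remainder block
    return size;
}

int main()
{
    // For m = 15, nb = 4 on a grid with 2 process rows:
    // row 0 owns blocks 0 and 2 (4 + 4 = 8 rows),
    // row 1 owns blocks 1 and 3 (4 + 3 = 7 rows).
    printf( "%d %d\n", local_size( 15, 4, 0, 2 ),
                       local_size( 15, 4, 1, 2 ) );  // prints: 8 7
    return 0;
}

With the example's m = 15 and nb = 4, a grid with two process rows therefore gives mlocA = 8 on the first row of processes and 7 on the second.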
PBLAS Call (Lines 88-111)
psgemm( ... ); // float
pdgemm( ... ); // double
pcgemm( ... ); // complex<float>
pzgemm( ... ); // complex<double>
The code calls the standard PBLAS functions (p[sdcz]gemm).
- Crucial Point: If this program is linked against the SLATE ScaLAPACK API library (-lslate_scalapack_api), these calls will be intercepted by SLATE.
- SLATE converts the ScaLAPACK descriptors to SLATE Matrix objects internally, executes the operation using SLATE's engine (potentially on GPUs), and then ensures the result is consistent with ScaLAPACK expectations; a conceptual sketch of this conversion follows below.
- This allows drop-in acceleration for legacy codes.
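Conceptually, the interception layer does something like the following for pdgemm. This is a simplified, hypothetical sketch rather than SLATE's actual source: it ignores transposes and submatrix offsets (ia, ja), and the fromScaLAPACK factory shown here is one of several overloads whose exact signature varies between SLATE versions.

#include <mpi.h>
#include <slate/slate.hh>

// Sketch of the interception: wrap the existing ScaLAPACK-distributed
// arrays as SLATE matrices (SLATE's tiles alias the block-cyclic local
// storage, so no copy is made), run SLATE's gemm, and return with the
// local pieces of C updated in place, just as ScaLAPACK would leave them.
template <typename scalar_t>
void intercepted_pgemm_sketch(
    int m, int n, int k, scalar_t alpha, scalar_t beta,
    scalar_t* dataA, int lldA,
    scalar_t* dataB, int lldB,
    scalar_t* dataC, int lldC,
    int nb, int grid_p, int grid_q, MPI_Comm comm )
{
    auto A = slate::Matrix<scalar_t>::fromScaLAPACK(
                 m, k, dataA, lldA, nb, grid_p, grid_q, comm );
    auto B = slate::Matrix<scalar_t>::fromScaLAPACK(
                 k, n, dataB, lldB, nb, grid_p, grid_q, comm );
    auto C = slate::Matrix<scalar_t>::fromScaLAPACK(
                 m, n, dataC, lldC, nb, grid_p, grid_q, comm );

    // C = alpha A B + beta C, executed by SLATE (on GPUs if so configured).
    slate::gemm( alpha, A, B, beta, C );
}

The real layer also chooses an execution target (host or devices) through SLATE's runtime options, so whether the operation runs on GPUs is not visible in the application code. The complete example, ex14_scalapack_gemm.cc, follows.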
  1  // ex14_scalapack_gemm.cc
  2  // SLATE intercepts ScaLAPACK calls.
  3
  4  /// !!! Lines between `//---------- begin label` !!!
  5  /// !!! and `//---------- end label` !!!
  6  /// !!! are included in the SLATE Users' Guide. !!!
  7
  8  #include <mpi.h>
  9
 10  #include "util.hh"
 11  #include "scalapack.h"
 12
 13  int mpi_size = 0;
 14  int mpi_rank = 0;
 15  int grid_p = 0;
 16  int grid_q = 0;
 17
 18  //------------------------------------------------------------------------------
 19  // We don't include slate.hh here, so define a simple slate_mpi_call.
 20  void slate_mpi_call_( int err, const char* file, int line )
 21  {
 22      if (err != 0) {
 23          char msg[ 80 ];
 24          snprintf( msg, sizeof(msg), "MPI error %d at %s:%d", err, file, line );
 25          throw std::runtime_error( msg );
 26      }
 27  }
 28
 29  #define slate_mpi_call( err ) \
 30      slate_mpi_call_( err, __FILE__, __LINE__ )
 31
 32  //------------------------------------------------------------------------------
 33  template <typename scalar_type>
 34  void test_pgemm()
 35  {
 36      print_func( mpi_rank );
 37
 38      // constants
 39      int izero = 0, ione = 1;
 40
 41      // problem size and distribution
 42      int m = 15, n = 18, k = 13, nb = 4;
 43
 44      // initialize BLACS communication
 45      int p_, q_, nprocs, ictxt, iam, myrow, mycol, info;
 46      Cblacs_pinfo( &iam, &nprocs );
 47      assert( grid_p * grid_q <= nprocs );
 48      Cblacs_get( -1, 0, &ictxt );
 49      Cblacs_gridinit( &ictxt, "Col", grid_p, grid_q );
 50      Cblacs_gridinfo( ictxt, &p_, &q_, &myrow, &mycol );
 51      assert( p_ == grid_p );
 52      assert( q_ == grid_q );
 53
 54      // matrix A: get local size, allocate, create descriptor, initialize
 55      int mlocA = numroc( &m, &nb, &myrow, &izero, &grid_p );
 56      int nlocA = numroc( &k, &nb, &mycol, &izero, &grid_q );
 57      int lldA = mlocA;
 58      int descA[9];
 59      descinit( descA, &m, &k, &nb, &nb, &izero, &izero, &ictxt, &lldA, &info );
 60      assert( info == 0 );
 61      std::vector<scalar_type> dataA( lldA * nlocA );
 62      random_matrix( mlocA, nlocA, &dataA[0], lldA );
 63
 64      // matrix B: get local size, allocate, create descriptor, initialize
 65      int mlocB = numroc( &k, &nb, &myrow, &izero, &grid_p );
 66      int nlocB = numroc( &n, &nb, &mycol, &izero, &grid_q );
 67      int lldB = mlocB;
 68      int descB[9];
 69      descinit( descB, &k, &n, &nb, &nb, &izero, &izero, &ictxt, &lldB, &info );
 70      assert( info == 0 );
 71      std::vector<scalar_type> dataB( lldB * nlocB );
 72      random_matrix( mlocB, nlocB, &dataB[0], lldB );
 73
 74      // matrix C: get local size, allocate, create descriptor, initialize
 75      int mlocC = numroc( &m, &nb, &myrow, &izero, &grid_p );
 76      int nlocC = numroc( &n, &nb, &mycol, &izero, &grid_q );
 77      int lldC = mlocC;
 78      int descC[9];
 79      descinit( descC, &m, &n, &nb, &nb, &izero, &izero, &ictxt, &lldC, &info );
 80      assert( info == 0 );
 81      std::vector<scalar_type> dataC( lldC * nlocC );
 82      random_matrix( mlocC, nlocC, &dataC[0], lldC );
 83
 84      scalar_type alpha = 2.7183;
 85      scalar_type beta = 3.1415;
 86
 87      // gemm: C = alpha A B + beta C
 88      if constexpr (std::is_same< scalar_type, float >::value) {
 89          psgemm( "notrans", "notrans", &m, &n, &k,
 90                  &alpha, &dataA[0], &ione, &ione, descA,
 91                          &dataB[0], &ione, &ione, descB,
 92                  &beta,  &dataC[0], &ione, &ione, descC );
 93      }
 94      else if constexpr (std::is_same< scalar_type, double >::value) {
 95          pdgemm( "notrans", "notrans", &m, &n, &k,
 96                  &alpha, &dataA[0], &ione, &ione, descA,
 97                          &dataB[0], &ione, &ione, descB,
 98                  &beta,  &dataC[0], &ione, &ione, descC );
 99      }
100      else if constexpr (std::is_same< scalar_type, std::complex<float> >::value) {
101          pcgemm( "notrans", "notrans", &m, &n, &k,
102                  &alpha, &dataA[0], &ione, &ione, descA,
103                          &dataB[0], &ione, &ione, descB,
104                  &beta,  &dataC[0], &ione, &ione, descC );
105      }
106      else if constexpr (std::is_same< scalar_type, std::complex<double> >::value) {
107          pzgemm( "notrans", "notrans", &m, &n, &k,
108                  &alpha, &dataA[0], &ione, &ione, descA,
109                          &dataB[0], &ione, &ione, descB,
110                  &beta,  &dataC[0], &ione, &ione, descC );
111      }
112  }
113
114  //------------------------------------------------------------------------------
115  int main( int argc, char** argv )
116  {
117      try {
118          // Parse command line to set types for s, d, c, z precisions.
119          bool types[ 4 ];
120          parse_args( argc, argv, types );
121
122          int provided = 0;
123          slate_mpi_call(
124              MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) );
125          assert( provided == MPI_THREAD_MULTIPLE );
126
127          slate_mpi_call(
128              MPI_Comm_size( MPI_COMM_WORLD, &mpi_size ) );
129
130          slate_mpi_call(
131              MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank ) );
132
133          // Determine p-by-q grid for this MPI size.
134          grid_size( mpi_size, &grid_p, &grid_q );
135          if (mpi_rank == 0) {
136              printf( "mpi_size %d, grid_p %d, grid_q %d\n",
137                      mpi_size, grid_p, grid_q );
138          }
139
140          // so random_matrix is different on different ranks.
141          srand( 100 * mpi_rank );
142
143          if (types[ 0 ]) {
144              test_pgemm< float >();
145          }
146
147          if (types[ 1 ]) {
148              test_pgemm< double >();
149          }
150
151          if (types[ 2 ]) {
152              test_pgemm< std::complex<float> >();
153          }
154
155          if (types[ 3 ]) {
156              test_pgemm< std::complex<double> >();
157          }
158
159          slate_mpi_call(
160              MPI_Finalize() );
161      }
162      catch (std::exception const& ex) {
163          fprintf( stderr, "%s", ex.what() );
164          return 1;
165      }
166      return 0;
167  }