Example 05: BLAS Operations

This example demonstrates how to perform Basic Linear Algebra Subprograms (BLAS) operations in SLATE.

Key Concepts

  1. Matrix Multiplication: Using slate::multiply (gemm, hemm, symm) for matrix products.

  2. Rank Updates: Performing rank-k (herk, syrk) and rank-2k (her2k, syr2k) updates.

  3. Triangular Operations: Triangular matrix multiplication (trmm) and solving triangular systems (trsm).

  4. Simplified vs Traditional API: Comparing the descriptive multiply API with the traditional BLAS-named API.

C++ Example

General Matrix Multiplication (GEMM) (Lines 36-40)

// C = alpha A B + beta C
slate::multiply( alpha, A, B, beta, C );  // simplified API
slate::gemm( alpha, A, B, beta, C );      // traditional API

Here we perform the standard operation \(C = \alpha AB + \beta C\).

  • A is an m by k matrix.

  • B is a k by n matrix.

  • C is an m by n matrix.

SLATE provides both a descriptive multiply routine and the traditional BLAS-named gemm. They are equivalent.
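To make the dimension bookkeeping concrete, here is a minimal serial reference for the same operation in plain C++ (no SLATE dependency; the name `gemm_ref` and the row-major layout are illustrative assumptions, not part of SLATE's API). SLATE performs the same arithmetic, but tiled, in parallel, and distributed across MPI ranks.

```cpp
#include <cassert>
#include <vector>

// Naive serial reference for C = alpha A B + beta C (row-major storage).
// A is m x k, B is k x n, C is m x n.
void gemm_ref( int m, int n, int k, double alpha,
               std::vector<double> const& A,
               std::vector<double> const& B,
               double beta,
               std::vector<double>& C )
{
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            double sum = 0.0;
            for (int l = 0; l < k; ++l)
                sum += A[ i*k + l ] * B[ l*n + j ];
            C[ i*n + j ] = alpha*sum + beta*C[ i*n + j ];
        }
    }
}
```

For example, with A = [[1,2],[3,4]], B = I, alpha = 2, beta = 1, and C initialized to all ones, the result is 2A + C. This is only a checking aid; it says nothing about how SLATE schedules the computation.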

GPU Execution with Options (Lines 43-52)

if (blas::get_device_count() > 0) {
    slate::Options opts = {
        { slate::Option::Lookahead, 2 },
        { slate::Option::Target, slate::Target::Devices },
    };
    slate::multiply( alpha, A, B, beta, C, opts );
}

Most SLATE routines accept an Options map as the final argument. Here we:

  • Set Target::Devices to offload computation to GPUs.

  • Set Lookahead to 2 to overlap communication and computation.

Transposed Multiplication (Lines 77-83)

auto AT = transpose( A );
auto BH = conj_transpose( B );
slate::multiply( alpha, AT, BH, beta, C );

To compute \(C = \alpha A^T B^H + \beta C\), we create transposed views AT and BH and pass them to the multiply function. transpose and conj_transpose return lightweight views that share the original data and merely set a transposition flag; no data is copied or moved. SLATE reads the flag on each view and handles the logic internally.
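As a sanity check on the shapes, here is a serial sketch of \(C = \alpha A^T B^H + \beta C\) in plain C++ (illustrative only; `gemm_tc_ref` is a made-up name). It mirrors the example's allocation, where A is stored k-by-m and B is stored n-by-k, and the transposition is applied logically through the indexing rather than by copying.

```cpp
#include <cassert>
#include <complex>
#include <vector>

using cx = std::complex<double>;

// Reference for C = alpha A^T B^H + beta C (row-major storage).
// A is stored k x m and used as A^T (m x k);
// B is stored n x k and used as B^H (k x n); C is m x n.
void gemm_tc_ref( int m, int n, int k, cx alpha,
                  std::vector<cx> const& A,
                  std::vector<cx> const& B,
                  cx beta, std::vector<cx>& C )
{
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            cx sum = 0.0;
            for (int l = 0; l < k; ++l)
                sum += A[ l*m + i ] * std::conj( B[ j*k + l ] );
            C[ i*n + j ] = alpha*sum + beta*C[ i*n + j ];
        }
    }
}
```

The indexing `A[ l*m + i ]` reads A(l,i) where A^T(i,l) is needed, which is exactly the "logical transpose" that the SLATE views encode via their op flag.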

Symmetric/Hermitian Multiplication (SYMM/HEMM) (Lines 97-118)

slate::multiply( alpha, A, B, beta, C );                  // simplified
slate::symm( slate::Side::Left, alpha, A, B, beta, C );   // traditional

When A is a SymmetricMatrix (or HermitianMatrix), multiply automatically dispatches to the efficient symmetric/Hermitian algorithm (symm/hemm).

  • Side::Left means \(C = \alpha A B + \beta C\).

  • Side::Right means \(C = \alpha B A + \beta C\) (demonstrated in lines 141-147).
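The key property of symm/hemm is that only one triangle of A is stored and referenced; entries across the diagonal are obtained by symmetry. A serial sketch of the Side::Left, Uplo::Lower case (plain C++, not SLATE code; `symm_ref` is a hypothetical helper):

```cpp
#include <cassert>
#include <vector>

// Reference for C = alpha A B + beta C with A symmetric (Side::Left),
// where only the lower triangle of A is valid (Uplo::Lower).
// A is m x m, B and C are m x n, all row-major.
void symm_ref( int m, int n, double alpha,
               std::vector<double> const& A,
               std::vector<double> const& B,
               double beta, std::vector<double>& C )
{
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            double sum = 0.0;
            for (int l = 0; l < m; ++l) {
                // Read A(i,l); if it lies in the upper triangle,
                // fetch it from the stored lower triangle via symmetry.
                double a = (l <= i) ? A[ i*m + l ] : A[ l*m + i ];
                sum += a * B[ l*n + j ];
            }
            C[ i*n + j ] = alpha*sum + beta*C[ i*n + j ];
        }
    }
}
```

Because the upper triangle is never read, its entries may hold garbage, which is why SLATE can store a SymmetricMatrix with tiles for only one triangle.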

Rank-k and Rank-2k Updates (SYRK/SYR2K, HERK/HER2K) (Lines 230-241, 265-277)

slate::rank_k_update( alpha, A, beta, C );
slate::syrk( alpha, A, beta, C );

Computes \(C = \alpha A A^T + \beta C\) (or \(C = \alpha A A^H + \beta C\) for herk), where C is symmetric or Hermitian. Only the designated triangle of C (Lower or Upper) is referenced and updated. The rank-2k variant rank_2k_update (syr2k, her2k) computes \(C = \alpha A B^T + \alpha B A^T + \beta C\). Note that herk requires real-valued alpha and beta to keep C Hermitian, as shown with alpha_real in the full listing.
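A serial sketch of the lower-triangle syrk update clarifies what "only the designated triangle is updated" means (plain C++, illustrative; `syrk_ref` is a made-up name):

```cpp
#include <cassert>
#include <vector>

// Reference for the rank-k update C = alpha A A^T + beta C with
// Uplo::Lower: only entries C(i,j) with i >= j are touched.
// A is n x k, C is n x n, both row-major.
void syrk_ref( int n, int k, double alpha,
               std::vector<double> const& A,
               double beta, std::vector<double>& C )
{
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j <= i; ++j) {
            double sum = 0.0;
            for (int l = 0; l < k; ++l)
                sum += A[ i*k + l ] * A[ j*k + l ];
            C[ i*n + j ] = alpha*sum + beta*C[ i*n + j ];
        }
    }
}
```

The upper triangle of C is left untouched, since a symmetric result is fully determined by one triangle.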

Triangular Operations (TRMM/TRSM) (Lines 299-310)

// B = alpha A B
slate::triangular_multiply( alpha, A, B );       // trmm

// B = alpha A^{-1} B (solves A X = alpha B)
slate::triangular_solve( alpha, A, B );          // trsm

For triangular matrices, we can multiply (trmm) or solve (trsm). The simplified API names make the intent clear (“multiply” vs “solve”).
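For the left, lower, non-unit case, triangular_solve amounts to forward substitution, one column of B at a time. A serial sketch (plain C++, illustrative only; `trsm_ref` is a hypothetical name):

```cpp
#include <cassert>
#include <vector>

// Reference forward substitution for Side::Left, Uplo::Lower,
// Diag::NonUnit: solve A X = alpha B, overwriting B with X.
// A is m x m lower triangular, B is m x n, both row-major.
void trsm_ref( int m, int n, double alpha,
               std::vector<double> const& A,
               std::vector<double>& B )
{
    for (int j = 0; j < n; ++j) {
        for (int i = 0; i < m; ++i) {
            double sum = alpha * B[ i*n + j ];
            // Rows l < i of B already hold the solution X.
            for (int l = 0; l < i; ++l)
                sum -= A[ i*m + l ] * B[ l*n + j ];
            B[ i*n + j ] = sum / A[ i*m + i ];
        }
    }
}
```

No inverse of A is ever formed; the substitution is both cheaper and numerically preferable, which is why the BLAS (and SLATE) expose a solve rather than an explicit triangular inverse.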

  1// ex05_blas.cc
  2// BLAS routines
  3
  4/// !!!   Lines between `//---------- begin label`          !!!
  5/// !!!             and `//---------- end label`            !!!
  6/// !!!   are included in the SLATE Users' Guide.           !!!
  7
  8#include <slate/slate.hh>
  9
 10#include "util.hh"
 11
 12int mpi_size = 0;
 13int mpi_rank = 0;
 14int grid_p = 0;
 15int grid_q = 0;
 16
 17//------------------------------------------------------------------------------
 18template <typename scalar_type>
 19void test_gemm()
 20{
 21    print_func( mpi_rank );
 22
 23    scalar_type alpha = 2.0, beta = 1.0;
 24    int64_t m=2000, n=1000, k=500, nb=256;
 25
 26    slate::Matrix<scalar_type> A( m, k, nb, grid_p, grid_q, MPI_COMM_WORLD );
 27    slate::Matrix<scalar_type> B( k, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
 28    slate::Matrix<scalar_type> C( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
 29    A.insertLocalTiles();
 30    B.insertLocalTiles();
 31    C.insertLocalTiles();
 32    random_matrix( A );
 33    random_matrix( B );
 34    random_matrix( C );
 35
 36    //---------- begin gemm
 37    // C = alpha A B + beta C, where A, B, C are all general matrices.
 38    slate::multiply( alpha, A, B, beta, C );  // simplified API
 39    slate::gemm( alpha, A, B, beta, C );      // traditional API
 40    //---------- end gemm
 41
 42    //--------------------
 43    if (blas::get_device_count() > 0) {
 44        //---------- begin gemm_opts
 45        // Execute on GPU devices with lookahead of 2.
 46        slate::Options opts = {
 47            { slate::Option::Lookahead, 2 },
 48            { slate::Option::Target, slate::Target::Devices },
 49        };
 50        slate::multiply( alpha, A, B, beta, C, opts );
 51        //---------- end gemm_opts
 52    }
 53}
 54
 55//------------------------------------------------------------------------------
 56template <typename scalar_type>
 57void test_gemm_trans()
 58{
 59    print_func( mpi_rank );
 60
 61    scalar_type alpha = 2.0, beta = 1.0;
 62    int64_t m=2000, n=1000, k=500, nb=256;
 63
 64    // Dimensions of A, B are backwards from A, B in test_gemm().
 65    slate::Matrix<scalar_type> A( k, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
 66    slate::Matrix<scalar_type> B( n, k, nb, grid_p, grid_q, MPI_COMM_WORLD );
 67    slate::Matrix<scalar_type> C( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
 68    A.insertLocalTiles();
 69    B.insertLocalTiles();
 70    C.insertLocalTiles();
 71    random_matrix( A );
 72    random_matrix( B );
 73    random_matrix( C );
 74
 75    //---------- begin gemm_trans
 76
 77    // Matrices can be transposed or conjugate-transposed beforehand.
 78    // C = alpha A^T B^H + beta C
 79    auto AT = transpose( A );
 80    auto BH = conj_transpose( B );
 81    slate::multiply( alpha, AT, BH, beta, C );  // simplified API
 82    slate::gemm( alpha, AT, BH, beta, C );      // traditional API
 83    //---------- end gemm_trans
 84
 85    // todo: support rvalues:
 86    // slate::gemm( alpha, transpose( A ), conj_transpose( B ), beta, C );
 87    // or
 88    // slate::gemm( alpha, transpose( A ), conj_transpose( B ), beta, std::move( C ) );
 89}
 90
 91//------------------------------------------------------------------------------
 92template <typename scalar_type>
 93void test_symm_left()
 94{
 95    print_func( mpi_rank );
 96
 97    scalar_type alpha = 2.0, beta = 1.0;
 98    int64_t m=2000, n=1000, nb=256;
 99
100    // A is m-by-m, B and C are m-by-n.
101    slate::SymmetricMatrix<scalar_type>
102        A( slate::Uplo::Lower, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
103    slate::Matrix<scalar_type> B( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
104    slate::Matrix<scalar_type> C( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
105    A.insertLocalTiles();
106    B.insertLocalTiles();
107    C.insertLocalTiles();
108    random_matrix( A );
109    random_matrix( B );
110    random_matrix( C );
111
112    //---------- begin symm_left
113
114    // C = alpha A B + beta C, where A is symmetric, on left side
115    slate::multiply( alpha, A, B, beta, C );                  // simplified API
116    slate::symm( slate::Side::Left, alpha, A, B, beta, C );   // traditional API
117    //---------- end symm_left
118}
119
120//------------------------------------------------------------------------------
121template <typename scalar_type>
122void test_symm_right()
123{
124    print_func( mpi_rank );
125
126    scalar_type alpha = 2.0, beta = 1.0;
127    int64_t m=2000, n=1000, nb=256;
128
129    // A is m-by-m, B and C are n-by-m (reverse of left case above).
130    slate::SymmetricMatrix<scalar_type>
131        A( slate::Uplo::Lower, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
132    slate::Matrix<scalar_type> B( n, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
133    slate::Matrix<scalar_type> C( n, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
134    A.insertLocalTiles();
135    B.insertLocalTiles();
136    C.insertLocalTiles();
137    random_matrix( A );
138    random_matrix( B );
139    random_matrix( C );
140
141    //---------- begin symm_right
142
143    // C = alpha B A + beta C, where A is symmetric, on right side
144    // Note B, A order reversed in multiply compared to symm.
145    slate::multiply( alpha, B, A, beta, C );                  // simplified API
146    slate::symm( slate::Side::Right, alpha, A, B, beta, C );  // traditional API
147    //---------- end symm_right
148}
149
150//------------------------------------------------------------------------------
151template <typename scalar_type>
152void test_hemm_left()
153{
154    print_func( mpi_rank );
155
156    scalar_type alpha = 2.0, beta = 1.0;
157    int64_t m=2000, n=1000, nb=256;
158
159    // A is m-by-m, B and C are m-by-n.
160    slate::HermitianMatrix<scalar_type>
161        A( slate::Uplo::Lower, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
162    slate::Matrix<scalar_type> B( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
163    slate::Matrix<scalar_type> C( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
164    A.insertLocalTiles();
165    B.insertLocalTiles();
166    C.insertLocalTiles();
167    random_matrix( A );
168    random_matrix( B );
169    random_matrix( C );
170
171    //---------- begin hemm_left
172
173    // C = alpha A B + beta C, where A is Hermitian, on left side
174    slate::multiply( alpha, A, B, beta, C );                  // simplified API
175    slate::hemm( slate::Side::Left, alpha, A, B, beta, C );   // traditional API
176    //---------- end hemm_left
177}
178
179//------------------------------------------------------------------------------
180template <typename scalar_type>
181void test_hemm_right()
182{
183    print_func( mpi_rank );
184
185    scalar_type alpha = 2.0, beta = 1.0;
186    int64_t m=2000, n=1000, nb=256;
187
188    // A is m-by-m, B and C are n-by-m (reverse of left case above).
189    slate::HermitianMatrix<scalar_type>
190        A( slate::Uplo::Lower, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
191    slate::Matrix<scalar_type> B( n, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
192    slate::Matrix<scalar_type> C( n, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
193    A.insertLocalTiles();
194    B.insertLocalTiles();
195    C.insertLocalTiles();
196    random_matrix( A );
197    random_matrix( B );
198    random_matrix( C );
199
200    //---------- begin hemm_right
201
202    // C = alpha B A + beta C, where A is Hermitian, on right side
203    // Note B, A order reversed in multiply compared to hemm.
204    slate::multiply( alpha, B, A, beta, C );                  // simplified API
205    slate::hemm( slate::Side::Right, alpha, A, B, beta, C );  // traditional API
206    //---------- end hemm_right
207}
208
209//------------------------------------------------------------------------------
210template <typename scalar_type>
211void test_syrk_syr2k()
212{
213    print_func( mpi_rank );
214
215    scalar_type alpha = 2.0, beta = 1.0;
216    int64_t n=1000, k=500, nb=256;
217
218    slate::Matrix<scalar_type> A( n, k, nb, grid_p, grid_q, MPI_COMM_WORLD );
219    slate::Matrix<scalar_type> B( n, k, nb, grid_p, grid_q, MPI_COMM_WORLD );
220    slate::SymmetricMatrix<scalar_type>
221        C( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
222    A.insertLocalTiles();
223    B.insertLocalTiles();
224    C.insertLocalTiles();
225    random_matrix( A );
226    random_matrix( B );
227    random_matrix( C );
228
229    //---------- begin syrk
230
231    // C = alpha A A^T + beta C, where C is symmetric
232    slate::rank_k_update( alpha, A, beta, C );      // simplified API
233    slate::syrk( alpha, A, beta, C );               // traditional API
234    //---------- end syrk
235
236    //---------- begin syr2k
237
238    // C = alpha A B^T + alpha B A^T + beta C, where C is symmetric
239    slate::rank_2k_update( alpha, A, B, beta, C );  // simplified API
240    slate::syr2k( alpha, A, B, beta, C );           // traditional API
241    //---------- end syr2k
242}
243
244//------------------------------------------------------------------------------
245template <typename scalar_type>
246void test_herk_her2k()
247{
248    print_func( mpi_rank );
249
250    scalar_type alpha = 2.0;
251    blas::real_type<scalar_type> alpha_real = 2.0, beta = 1.0;
252    int64_t n=1000, k=500, nb=256;
253
254    slate::Matrix<scalar_type> A( n, k, nb, grid_p, grid_q, MPI_COMM_WORLD );
255    slate::Matrix<scalar_type> B( n, k, nb, grid_p, grid_q, MPI_COMM_WORLD );
256    slate::HermitianMatrix<scalar_type>
257        C( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
258    A.insertLocalTiles();
259    B.insertLocalTiles();
260    C.insertLocalTiles();
261    random_matrix( A );
262    random_matrix( B );
263    random_matrix( C );
264
265    //---------- begin herk
266
267    // C = alpha A A^H + beta C, where C is Hermitian
268    slate::rank_k_update( alpha_real, A, beta, C );      // simplified API
269    slate::herk( alpha_real, A, beta, C );               // traditional API
270    //---------- end herk
271
272    //---------- begin her2k
273
274    // C = alpha A B^H + conj(alpha) B A^H + beta C, where C is Hermitian
275    slate::rank_2k_update( alpha, A, B, beta, C );  // simplified API
276    slate::her2k( alpha, A, B, beta, C );           // traditional API
277    //---------- end her2k
278}
279
280//------------------------------------------------------------------------------
281template <typename scalar_type>
282void test_trmm_trsm_left()
283{
284    print_func( mpi_rank );
285
286    scalar_type alpha = 2.0;
287    int64_t m=2000, n=1000, nb=256;
288
289    // A is m-by-m, B is m-by-n
290    slate::TriangularMatrix<scalar_type>
291        A( slate::Uplo::Lower, slate::Diag::NonUnit, m, nb,
292           grid_p, grid_q, MPI_COMM_WORLD );
293    slate::Matrix<scalar_type> B( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
294    A.insertLocalTiles();
295    B.insertLocalTiles();
296    random_matrix( A );
297    random_matrix( B );
298
299    //---------- begin trmm_left
300
301    //----- left
302    // B = alpha A B, where A is triangular, on left side
303    slate::triangular_multiply( alpha, A, B );       // simplified API
304    slate::trmm( slate::Side::Left, alpha, A, B );   // traditional API
305
 306    // Solve AX = alpha B, where A is triangular, on left side; X overwrites B.
307    // That is, B = alpha A^{-1} B.
308    slate::triangular_solve( alpha, A, B );          // simplified API
309    slate::trsm( slate::Side::Left, alpha, A, B );   // traditional API
310    //---------- end trmm_left
311}
312
313//------------------------------------------------------------------------------
314template <typename scalar_type>
315void test_trmm_trsm_right()
316{
317    print_func( mpi_rank );
318
319    scalar_type alpha = 2.0;
320    int64_t m=2000, n=1000, nb=256;
321
322    // A is m-by-m, B is n-by-m (reverse of left case above).
323    slate::TriangularMatrix<scalar_type>
324        A( slate::Uplo::Lower, slate::Diag::NonUnit, m, nb,
325           grid_p, grid_q, MPI_COMM_WORLD );
326    slate::Matrix<scalar_type> B( n, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
327    A.insertLocalTiles();
328    B.insertLocalTiles();
329    random_matrix( A );
330    random_matrix( B );
331
332    //---------- begin trmm_right
333
334    //----- right
335    // B = alpha B A, where A is triangular, on right side
336    // Note B, A order reversed in multiply compared to trmm.
337    slate::triangular_multiply( alpha, B, A );       // simplified API
338    slate::trmm( slate::Side::Right, alpha, A, B );  // traditional API
339
 340    // Solve XA = alpha B, where A is triangular, on right side; X overwrites B.
341    // That is, B = alpha B A^{-1}.
342    // Note B, A order reversed in solve compared to trsm.
343    slate::triangular_solve( alpha, B, A );          // simplified API
344    slate::trsm( slate::Side::Right, alpha, A, B );  // traditional API
345    //---------- end trmm_right
346}
347
348//------------------------------------------------------------------------------
349template <typename scalar_type>
350void test_all()
351{
352    test_gemm      < scalar_type >();
353    test_gemm_trans< scalar_type >();
354    test_symm_left < scalar_type >();
355    test_symm_right< scalar_type >();
356    test_hemm_left < scalar_type >();
357    test_hemm_right< scalar_type >();
358    test_syrk_syr2k< scalar_type >();
359    test_herk_her2k< scalar_type >();
360    test_trmm_trsm_left < scalar_type >();
361    test_trmm_trsm_right< scalar_type >();
362}
363
364//------------------------------------------------------------------------------
365int main( int argc, char** argv )
366{
367    try {
368        // Parse command line to set types for s, d, c, z precisions.
369        bool types[ 4 ];
370        parse_args( argc, argv, types );
371
372        int provided = 0;
373        slate_mpi_call(
374            MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) );
375        assert( provided == MPI_THREAD_MULTIPLE );
376
377        slate_mpi_call(
378            MPI_Comm_size( MPI_COMM_WORLD, &mpi_size ) );
379
380        slate_mpi_call(
381            MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank ) );
382
383        // Determine p-by-q grid for this MPI size.
384        grid_size( mpi_size, &grid_p, &grid_q );
385        if (mpi_rank == 0) {
386            printf( "mpi_size %d, grid_p %d, grid_q %d\n",
387                    mpi_size, grid_p, grid_q );
388        }
389
390        // so random_matrix is different on different ranks.
391        srand( 100 * mpi_rank );
392
393        if (types[ 0 ]) {
394            test_all< float >();
395        }
396        if (mpi_rank == 0)
397            printf( "\n" );
398
399        if (types[ 1 ]) {
400            test_all< double >();
401        }
402        if (mpi_rank == 0)
403            printf( "\n" );
404
405        if (types[ 2 ]) {
406            test_all< std::complex<float> >();
407        }
408        if (mpi_rank == 0)
409            printf( "\n" );
410
411        if (types[ 3 ]) {
412            test_all< std::complex<double> >();
413        }
414
415        slate_mpi_call(
416            MPI_Finalize() );
417    }
418    catch (std::exception const& ex) {
419        fprintf( stderr, "%s", ex.what() );
420        return 1;
421    }
422    return 0;
423}

C API Example

  1// ex05_blas.c
  2// BLAS routines
  3
  4#include <slate/c_api/slate.h>
  5#include <mpi.h>
  6
  7#include "util.h"
  8
  9int mpi_size = 0;
 10int mpi_rank = 0;
 11int grid_p = 0;
 12int grid_q = 0;
 13
 14//------------------------------------------------------------------------------
 15void test_gemm_r32()
 16{
 17    print_func( mpi_rank );
 18
 19    double alpha = 2.0, beta = 1.0;
 20    int64_t m=2000, n=1000, k=500, nb=256;
 21
 22    slate_Matrix_r32 A = slate_Matrix_create_r32(
 23        m, k,    nb, grid_p, grid_q, MPI_COMM_WORLD );
 24    slate_Matrix_r32 B = slate_Matrix_create_r32(
 25        k, n,    nb, grid_p, grid_q, MPI_COMM_WORLD );
 26    slate_Matrix_r32 C = slate_Matrix_create_r32(
 27        m, n,    nb, grid_p, grid_q, MPI_COMM_WORLD );
 28    slate_Matrix_insertLocalTiles_r32( A );
 29    slate_Matrix_insertLocalTiles_r32( B );
 30    slate_Matrix_insertLocalTiles_r32( C );
 31    random_Matrix_r32( A );
 32    random_Matrix_r32( B );
 33    random_Matrix_r32( C );
 34
 35    // C = alpha A B + beta C, where A, B, C are all general matrices.
 36    slate_multiply_r32( alpha, A, B, beta, C, NULL );
 37
 38    if (slate_Matrix_num_devices_r32( C ) > 0) {
 39        // Execute on GPU devices with lookahead of 2.
 40        slate_Options opts = slate_Options_create();
 41        slate_Options_set_Target( opts, slate_Target_Devices );
 42        slate_Options_set_Lookahead( opts, 2 );
 43
 44        slate_multiply_r32( alpha, A, B, beta, C, opts );
 45
 46        slate_Options_destroy( opts );
 47    }
 48
 49    slate_Matrix_destroy_r32( A );
 50    slate_Matrix_destroy_r32( B );
 51    slate_Matrix_destroy_r32( C );
 52}
 53
 54//------------------------------------------------------------------------------
 55void test_gemm_r64()
 56{
 57    print_func( mpi_rank );
 58
 59    double alpha = 2.0, beta = 1.0;
 60    int64_t m=2000, n=1000, k=500, nb=256;
 61
 62    slate_Matrix_r64 A = slate_Matrix_create_r64(
 63        m, k,    nb, grid_p, grid_q, MPI_COMM_WORLD );
 64    slate_Matrix_r64 B = slate_Matrix_create_r64(
 65        k, n,    nb, grid_p, grid_q, MPI_COMM_WORLD );
 66    slate_Matrix_r64 C = slate_Matrix_create_r64(
 67        m, n,    nb, grid_p, grid_q, MPI_COMM_WORLD );
 68    slate_Matrix_insertLocalTiles_r64( A );
 69    slate_Matrix_insertLocalTiles_r64( B );
 70    slate_Matrix_insertLocalTiles_r64( C );
 71    random_Matrix_r64( A );
 72    random_Matrix_r64( B );
 73    random_Matrix_r64( C );
 74
 75    // C = alpha A B + beta C, where A, B, C are all general matrices.
 76    slate_multiply_r64( alpha, A, B, beta, C, NULL );
 77
 78    if (slate_Matrix_num_devices_r64( C ) > 0) {
 79        // Execute on GPU devices with lookahead of 2.
 80        slate_Options opts = slate_Options_create();
 81        slate_Options_set_Target( opts, slate_Target_Devices );
 82        slate_Options_set_Lookahead( opts, 2 );
 83
 84        slate_multiply_r64( alpha, A, B, beta, C, opts );
 85
 86        slate_Options_destroy( opts );
 87    }
 88
 89    slate_Matrix_destroy_r64( A );
 90    slate_Matrix_destroy_r64( B );
 91    slate_Matrix_destroy_r64( C );
 92}
 93
 94//------------------------------------------------------------------------------
 95void test_gemm_c32()
 96{
 97    print_func( mpi_rank );
 98
 99    double alpha = 2.0, beta = 1.0;
100    int64_t m=2000, n=1000, k=500, nb=256;
101
102    slate_Matrix_c32 A = slate_Matrix_create_c32(
103        m, k,    nb, grid_p, grid_q, MPI_COMM_WORLD );
104    slate_Matrix_c32 B = slate_Matrix_create_c32(
105        k, n,    nb, grid_p, grid_q, MPI_COMM_WORLD );
106    slate_Matrix_c32 C = slate_Matrix_create_c32(
107        m, n,    nb, grid_p, grid_q, MPI_COMM_WORLD );
108    slate_Matrix_insertLocalTiles_c32( A );
109    slate_Matrix_insertLocalTiles_c32( B );
110    slate_Matrix_insertLocalTiles_c32( C );
111    random_Matrix_c32( A );
112    random_Matrix_c32( B );
113    random_Matrix_c32( C );
114
115    // C = alpha A B + beta C, where A, B, C are all general matrices.
116    slate_multiply_c32( alpha, A, B, beta, C, NULL );
117
118    if (slate_Matrix_num_devices_c32( C ) > 0) {
119        // Execute on GPU devices with lookahead of 2.
120        slate_Options opts = slate_Options_create();
121        slate_Options_set_Target( opts, slate_Target_Devices );
122        slate_Options_set_Lookahead( opts, 2 );
123
124        slate_multiply_c32( alpha, A, B, beta, C, opts );
125
126        slate_Options_destroy( opts );
127    }
128
129    slate_Matrix_destroy_c32( A );
130    slate_Matrix_destroy_c32( B );
131    slate_Matrix_destroy_c32( C );
132}
133
134//------------------------------------------------------------------------------
135void test_gemm_c64()
136{
137    print_func( mpi_rank );
138
139    double alpha = 2.0, beta = 1.0;
140    int64_t m=2000, n=1000, k=500, nb=256;
141
142    slate_Matrix_c64 A = slate_Matrix_create_c64(
143        m, k,    nb, grid_p, grid_q, MPI_COMM_WORLD );
144    slate_Matrix_c64 B = slate_Matrix_create_c64(
145        k, n,    nb, grid_p, grid_q, MPI_COMM_WORLD );
146    slate_Matrix_c64 C = slate_Matrix_create_c64(
147        m, n,    nb, grid_p, grid_q, MPI_COMM_WORLD );
148    slate_Matrix_insertLocalTiles_c64( A );
149    slate_Matrix_insertLocalTiles_c64( B );
150    slate_Matrix_insertLocalTiles_c64( C );
151    random_Matrix_c64( A );
152    random_Matrix_c64( B );
153    random_Matrix_c64( C );
154
155    // C = alpha A B + beta C, where A, B, C are all general matrices.
156    slate_multiply_c64( alpha, A, B, beta, C, NULL );
157
158    if (slate_Matrix_num_devices_c64( C ) > 0) {
159        // Execute on GPU devices with lookahead of 2.
160        slate_Options opts = slate_Options_create();
161        slate_Options_set_Target( opts, slate_Target_Devices );
162        slate_Options_set_Lookahead( opts, 2 );
163
164        slate_multiply_c64( alpha, A, B, beta, C, opts );
165
166        slate_Options_destroy( opts );
167    }
168
169    slate_Matrix_destroy_c64( A );
170    slate_Matrix_destroy_c64( B );
171    slate_Matrix_destroy_c64( C );
172}
173
174//------------------------------------------------------------------------------
175void test_gemm_trans_r32()
176{
177    print_func( mpi_rank );
178
179    double alpha = 2.0, beta = 1.0;
180    int64_t m=2000, n=1000, k=500, nb=256;
181
182    slate_Matrix_r32 A = slate_Matrix_create_r32(
183        k, m,    nb, grid_p, grid_q, MPI_COMM_WORLD );
184    slate_Matrix_r32 B = slate_Matrix_create_r32(
185        n, k,    nb, grid_p, grid_q, MPI_COMM_WORLD );
186    slate_Matrix_r32 C = slate_Matrix_create_r32(
187        m, n,    nb, grid_p, grid_q, MPI_COMM_WORLD );
188    slate_Matrix_insertLocalTiles_r32( A );
189    slate_Matrix_insertLocalTiles_r32( B );
190    slate_Matrix_insertLocalTiles_r32( C );
191    random_Matrix_r32( A );
192    random_Matrix_r32( B );
193    random_Matrix_r32( C );
194
195    // Matrices can be transposed or conjugate-transposed beforehand.
196    // C = alpha A^T B^H + beta C
197    slate_Matrix_transpose_in_place_r32( A );
198    slate_Matrix_conj_transpose_in_place_r32( B );
199    slate_multiply_r32( alpha, A, B, beta, C, NULL );  // simplified API
200
201    slate_Matrix_destroy_r32( A );
202    slate_Matrix_destroy_r32( B );
203    slate_Matrix_destroy_r32( C );
204}
205
206//------------------------------------------------------------------------------
207void test_gemm_trans_r64()
208{
209    print_func( mpi_rank );
210
211    double alpha = 2.0, beta = 1.0;
212    int64_t m=2000, n=1000, k=500, nb=256;
213
214    slate_Matrix_r64 A = slate_Matrix_create_r64(
215        k, m,    nb, grid_p, grid_q, MPI_COMM_WORLD );
216    slate_Matrix_r64 B = slate_Matrix_create_r64(
217        n, k,    nb, grid_p, grid_q, MPI_COMM_WORLD );
218    slate_Matrix_r64 C = slate_Matrix_create_r64(
219        m, n,    nb, grid_p, grid_q, MPI_COMM_WORLD );
220    slate_Matrix_insertLocalTiles_r64( A );
221    slate_Matrix_insertLocalTiles_r64( B );
222    slate_Matrix_insertLocalTiles_r64( C );
223    random_Matrix_r64( A );
224    random_Matrix_r64( B );
225    random_Matrix_r64( C );
226
227    // Matrices can be transposed or conjugate-transposed beforehand.
228    // C = alpha A^T B^H + beta C
229    slate_Matrix_transpose_in_place_r64( A );
230    slate_Matrix_conj_transpose_in_place_r64( B );
231    slate_multiply_r64( alpha, A, B, beta, C, NULL );  // simplified API
232
233    slate_Matrix_destroy_r64( A );
234    slate_Matrix_destroy_r64( B );
235    slate_Matrix_destroy_r64( C );
236}
237
238//------------------------------------------------------------------------------
239void test_gemm_trans_c32()
240{
241    print_func( mpi_rank );
242
243    double alpha = 2.0, beta = 1.0;
244    int64_t m=2000, n=1000, k=500, nb=256;
245
246    slate_Matrix_c32 A = slate_Matrix_create_c32(
247        k, m,    nb, grid_p, grid_q, MPI_COMM_WORLD );
248    slate_Matrix_c32 B = slate_Matrix_create_c32(
249        n, k,    nb, grid_p, grid_q, MPI_COMM_WORLD );
250    slate_Matrix_c32 C = slate_Matrix_create_c32(
251        m, n,    nb, grid_p, grid_q, MPI_COMM_WORLD );
252    slate_Matrix_insertLocalTiles_c32( A );
253    slate_Matrix_insertLocalTiles_c32( B );
254    slate_Matrix_insertLocalTiles_c32( C );
255    random_Matrix_c32( A );
256    random_Matrix_c32( B );
257    random_Matrix_c32( C );
258
259    // Matrices can be transposed or conjugate-transposed beforehand.
260    // C = alpha A^T B^H + beta C
261    slate_Matrix_transpose_in_place_c32( A );
262    slate_Matrix_conj_transpose_in_place_c32( B );
263    slate_multiply_c32( alpha, A, B, beta, C, NULL );  // simplified API
264
265    slate_Matrix_destroy_c32( A );
266    slate_Matrix_destroy_c32( B );
267    slate_Matrix_destroy_c32( C );
268}
269
270//------------------------------------------------------------------------------
271void test_gemm_trans_c64()
272{
273    print_func( mpi_rank );
274
275    double alpha = 2.0, beta = 1.0;
276    int64_t m=2000, n=1000, k=500, nb=256;
277
278    slate_Matrix_c64 A = slate_Matrix_create_c64(
279        k, m,    nb, grid_p, grid_q, MPI_COMM_WORLD );
280    slate_Matrix_c64 B = slate_Matrix_create_c64(
281        n, k,    nb, grid_p, grid_q, MPI_COMM_WORLD );
282    slate_Matrix_c64 C = slate_Matrix_create_c64(
283        m, n,    nb, grid_p, grid_q, MPI_COMM_WORLD );
284    slate_Matrix_insertLocalTiles_c64( A );
285    slate_Matrix_insertLocalTiles_c64( B );
286    slate_Matrix_insertLocalTiles_c64( C );
287    random_Matrix_c64( A );
288    random_Matrix_c64( B );
289    random_Matrix_c64( C );
290
291    // Matrices can be transposed or conjugate-transposed beforehand.
292    // C = alpha A^T B^H + beta C
293    slate_Matrix_transpose_in_place_c64( A );
294    slate_Matrix_conj_transpose_in_place_c64( B );
295    slate_multiply_c64( alpha, A, B, beta, C, NULL );  // simplified API
296
297    slate_Matrix_destroy_c64( A );
298    slate_Matrix_destroy_c64( B );
299    slate_Matrix_destroy_c64( C );
300}
301
302//------------------------------------------------------------------------------
303int main( int argc, char** argv )
304{
305    // Parse command line to set types for s, d, c, z precisions.
306    bool types[ 4 ];
307    parse_args( argc, argv, types );
308
309    int provided = 0;
310    MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided );
311    assert( provided == MPI_THREAD_MULTIPLE );
312
313    MPI_Comm_size( MPI_COMM_WORLD, &mpi_size );
314    MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank );
315
316    // Determine p-by-q grid for this MPI size.
317    grid_size( mpi_size, &grid_p, &grid_q );
318    if (mpi_rank == 0) {
319        printf( "mpi_size %d, grid_p %d, grid_q %d\n",
320                mpi_size, grid_p, grid_q );
321    }
322
323    // so random_matrix is different on different ranks.
324    srand( 100 * mpi_rank );
325
326    if (types[ 0 ]) {
327        test_gemm_r32();
328        test_gemm_trans_r32();
329        if (mpi_rank == 0)
330            printf( "\n" );
331    }
332
333    if (types[ 1 ]) {
334        test_gemm_r64();
335        test_gemm_trans_r64();
336        if (mpi_rank == 0)
337            printf( "\n" );
338    }
339
340    if (types[ 2 ]) {
341        test_gemm_c32();
342        test_gemm_trans_c32();
343        if (mpi_rank == 0)
344            printf( "\n" );
345    }
346
347    if (types[ 3 ]) {
348        test_gemm_c64();
349        test_gemm_trans_c64();
350    }
351
352    MPI_Finalize();
353
354    return 0;
355}

Fortran API Example

  1! ex05_blas.f90
  2! BLAS routines
  3program ex05_blas
  4    use, intrinsic :: iso_fortran_env
  5    use slate
  6    use mpi
  7    use util
  8    implicit none
  9
 10    !! Variables
 11    logical                            :: types(4)
 12    integer(kind=c_int)                :: p_grid, q_grid
 13
 14    integer(kind=c_int)                :: provided, ierr
 15    integer(kind=c_int)                :: mpi_rank, mpi_size
 16
 17    !! Get requested types
 18    call parse_args( types )
 19
 20    !! MPI
 21    call MPI_Init_thread( MPI_THREAD_MULTIPLE, provided, ierr )
 22    if ((ierr .ne. 0) .or. (provided .ne. MPI_THREAD_MULTIPLE)) then
 23        print *, "Error: MPI_Init_thread"
 24        return
 25    end if
 26    call MPI_Comm_size( MPI_COMM_WORLD, mpi_size, ierr )
 27    if (ierr .ne. 0) then
 28        print *, "Error: MPI_Comm_size"
 29        return
 30    end if
 31    call MPI_Comm_rank( MPI_COMM_WORLD, mpi_rank, ierr )
 32    if (ierr .ne. 0) then
 33        print *, "Error: MPI_Comm_rank"
 34        return
 35    end if
 36
 37    call grid_size( mpi_size, p_grid, q_grid )
 38
 39    call srand( 100 * mpi_rank )
 40
 41    if (types(1)) then
 42        call test_gemm_r32()
 43        call test_gemm_trans_r32()
 44
 45        if (mpi_rank == 0) then
 46            print *
 47        end if
 48    end if
 49    if (types(2)) then
 50        call test_gemm_r64()
 51        call test_gemm_trans_r64()
 52
 53        if (mpi_rank == 0) then
 54            print *
 55        end if
 56    end if
 57    if (types(3)) then
 58        call test_gemm_c32()
 59        call test_gemm_trans_c32()
 60
 61        if (mpi_rank == 0) then
 62            print *
 63        end if
 64    end if
 65    if (types(4)) then
 66        call test_gemm_c64()
 67        call test_gemm_trans_c64()
 68
 69        if (mpi_rank == 0) then
 70            print *
 71        end if
 72    end if
 73
 74    call MPI_Finalize( ierr )
 75    if (ierr .ne. 0) then
 76        print *, "Error: MPI_Finalize"
 77        error stop
 78    end if
 79
 80contains
 81
 82    subroutine test_gemm_r32()
 83        !! Constants
 84        integer(kind=c_int64_t), parameter :: m  = 2000
 85        integer(kind=c_int64_t), parameter :: n  = 1000
 86        integer(kind=c_int64_t), parameter :: k  = 500
 87        integer(kind=c_int64_t), parameter :: nb = 256
 88
 89        real(kind=c_float),      parameter :: alpha = 2.0
 90        real(kind=c_float),      parameter :: beta  = 1.0
 91
 92        !! Variables
 93        integer(kind=c_int64_t)            :: i
 94        type(c_ptr)                        :: A, B, C, opts
 95
 96        !! Example
 97        call print_func( mpi_rank, 'test_gemm_r32' )
 98
 99        A = slate_Matrix_create_r32( m, k, nb, p_grid, q_grid, MPI_COMM_WORLD )
100        B = slate_Matrix_create_r32( k, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
101        C = slate_Matrix_create_r32( m, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
102        call slate_Matrix_insertLocalTiles_r32( A )
103        call slate_Matrix_insertLocalTiles_r32( B )
104        call slate_Matrix_insertLocalTiles_r32( C )
105        call random_Matrix_r32( A )
106        call random_Matrix_r32( B )
107        call random_Matrix_r32( C )
108
109        ! C = alpha A B + beta C
110        call slate_multiply_r32( alpha, A, B, beta, C, c_null_ptr )
111
112        if (slate_Matrix_num_devices_r32( C ) > 0) then
113            opts = slate_Options_create()
114            call slate_Options_set_Target( opts, slate_Target_Devices )
115            call slate_Options_set_Lookahead( opts, 2_int64 )
116
117            call slate_multiply_r32( alpha, A, B, beta, C, opts )
118
119            call slate_Options_destroy( opts )
120        end if
121
122
123        call slate_Matrix_destroy_r32( A )
124        call slate_Matrix_destroy_r32( B )
125        call slate_Matrix_destroy_r32( C )
126
127    end subroutine test_gemm_r32
128
129    subroutine test_gemm_r64()
130        !! Constants
131        integer(kind=c_int64_t), parameter :: m  = 2000
132        integer(kind=c_int64_t), parameter :: n  = 1000
133        integer(kind=c_int64_t), parameter :: k  = 500
134        integer(kind=c_int64_t), parameter :: nb = 256
135
136        real(kind=c_double),     parameter :: alpha = 2.0
137        real(kind=c_double),     parameter :: beta  = 1.0
138
139        !! Variables
140        integer(kind=c_int64_t)            :: i
141        type(c_ptr)                        :: A, B, C, opts
142
143        !! Example
144        call print_func( mpi_rank, 'test_gemm_r64' )
145
146        A = slate_Matrix_create_r64( m, k, nb, p_grid, q_grid, MPI_COMM_WORLD )
147        B = slate_Matrix_create_r64( k, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
148        C = slate_Matrix_create_r64( m, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
149        call slate_Matrix_insertLocalTiles_r64( A )
150        call slate_Matrix_insertLocalTiles_r64( B )
151        call slate_Matrix_insertLocalTiles_r64( C )
152        call random_Matrix_r64( A )
153        call random_Matrix_r64( B )
154        call random_Matrix_r64( C )
155
156        ! C = alpha A B + beta C
157        call slate_multiply_r64( alpha, A, B, beta, C, c_null_ptr )
158
159        if (slate_Matrix_num_devices_r64( C ) > 0) then
160            opts = slate_Options_create()
161            call slate_Options_set_Target( opts, slate_Target_Devices )
162            call slate_Options_set_Lookahead( opts, 2_int64 )
163
164            call slate_multiply_r64( alpha, A, B, beta, C, opts )
165
166            call slate_Options_destroy( opts )
167        end if
168
169
170        call slate_Matrix_destroy_r64( A )
171        call slate_Matrix_destroy_r64( B )
172        call slate_Matrix_destroy_r64( C )
173
174    end subroutine test_gemm_r64
175
176    subroutine test_gemm_c32()
177        !! Constants
178        integer(kind=c_int64_t), parameter :: m  = 2000
179        integer(kind=c_int64_t), parameter :: n  = 1000
180        integer(kind=c_int64_t), parameter :: k  = 500
181        integer(kind=c_int64_t), parameter :: nb = 256
182
183        complex(kind=c_float),   parameter :: alpha = 2.0
184        complex(kind=c_float),   parameter :: beta  = 1.0
185
186        !! Variables
187        integer(kind=c_int64_t)            :: i
188        type(c_ptr)                        :: A, B, C, opts
189
190        !! Example
191        call print_func( mpi_rank, 'test_gemm_c32' )
192
193        A = slate_Matrix_create_c32( m, k, nb, p_grid, q_grid, MPI_COMM_WORLD )
194        B = slate_Matrix_create_c32( k, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
195        C = slate_Matrix_create_c32( m, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
196        call slate_Matrix_insertLocalTiles_c32( A )
197        call slate_Matrix_insertLocalTiles_c32( B )
198        call slate_Matrix_insertLocalTiles_c32( C )
199        call random_Matrix_c32( A )
200        call random_Matrix_c32( B )
201        call random_Matrix_c32( C )
202
203        ! C = alpha A B + beta C
204        call slate_multiply_c32( alpha, A, B, beta, C, c_null_ptr )
205
206        if (slate_Matrix_num_devices_c32( C ) > 0) then
207            opts = slate_Options_create()
208            call slate_Options_set_Target( opts, slate_Target_Devices )
209            call slate_Options_set_Lookahead( opts, 2_int64 )
210
211            call slate_multiply_c32( alpha, A, B, beta, C, opts )
212
213            call slate_Options_destroy( opts )
214        end if
215
216
217        call slate_Matrix_destroy_c32( A )
218        call slate_Matrix_destroy_c32( B )
219        call slate_Matrix_destroy_c32( C )
220
221    end subroutine test_gemm_c32
222
223    subroutine test_gemm_c64()
224        !! Constants
225        integer(kind=c_int64_t), parameter :: m  = 2000
226        integer(kind=c_int64_t), parameter :: n  = 1000
227        integer(kind=c_int64_t), parameter :: k  = 500
228        integer(kind=c_int64_t), parameter :: nb = 256
229
230        complex(kind=c_double),  parameter :: alpha = 2.0
231        complex(kind=c_double),  parameter :: beta  = 1.0
232
233        !! Variables
234        integer(kind=c_int64_t)            :: i
235        type(c_ptr)                        :: A, B, C, opts
236
237        !! Example
238        call print_func( mpi_rank, 'test_gemm_c64' )
239
240        A = slate_Matrix_create_c64( m, k, nb, p_grid, q_grid, MPI_COMM_WORLD )
241        B = slate_Matrix_create_c64( k, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
242        C = slate_Matrix_create_c64( m, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
243        call slate_Matrix_insertLocalTiles_c64( A )
244        call slate_Matrix_insertLocalTiles_c64( B )
245        call slate_Matrix_insertLocalTiles_c64( C )
246        call random_Matrix_c64( A )
247        call random_Matrix_c64( B )
248        call random_Matrix_c64( C )
249
250        ! C = alpha A B + beta C
251        call slate_multiply_c64( alpha, A, B, beta, C, c_null_ptr )
252
253        if (slate_Matrix_num_devices_c64( C ) > 0) then
254            opts = slate_Options_create()
255            call slate_Options_set_Target( opts, slate_Target_Devices )
256            call slate_Options_set_Lookahead( opts, 2_int64 )
257
258            call slate_multiply_c64( alpha, A, B, beta, C, opts )
259
260            call slate_Options_destroy( opts )
261        end if
262
263
264        call slate_Matrix_destroy_c64( A )
265        call slate_Matrix_destroy_c64( B )
266        call slate_Matrix_destroy_c64( C )
267
268    end subroutine test_gemm_c64
269
270    subroutine test_gemm_trans_r32()
271        !! Constants
272        integer(kind=c_int64_t), parameter :: m  = 2000
273        integer(kind=c_int64_t), parameter :: n  = 1000
274        integer(kind=c_int64_t), parameter :: k  = 500
275        integer(kind=c_int64_t), parameter :: nb = 256
276
277        real(kind=c_float),      parameter :: alpha = 2.0
278        real(kind=c_float),      parameter :: beta  = 1.0
279
280        !! Variables
281        integer(kind=c_int64_t)            :: i
282        type(c_ptr)                        :: A, B, C, opts
283
284        !! Example
285        call print_func( mpi_rank, 'test_gemm_trans_r32' )
286
287        A = slate_Matrix_create_r32( k, m, nb, p_grid, q_grid, MPI_COMM_WORLD )
288        B = slate_Matrix_create_r32( n, k, nb, p_grid, q_grid, MPI_COMM_WORLD )
289        C = slate_Matrix_create_r32( m, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
290        call slate_Matrix_insertLocalTiles_r32( A )
291        call slate_Matrix_insertLocalTiles_r32( B )
292        call slate_Matrix_insertLocalTiles_r32( C )
293        call random_Matrix_r32( A )
294        call random_Matrix_r32( B )
295        call random_Matrix_r32( C )
296
297        ! Matrices can be transposed or conjugate-transposed beforehand
298        ! C = alpha AT BH + beta C
299        call slate_Matrix_transpose_in_place_r32( A )
300        call slate_Matrix_conj_transpose_in_place_r32( B )
301        call slate_multiply_r32( alpha, A, B, beta, C, c_null_ptr )
302
303        call slate_Matrix_destroy_r32( A )
304        call slate_Matrix_destroy_r32( B )
305        call slate_Matrix_destroy_r32( C )
306
307    end subroutine test_gemm_trans_r32
308
309    subroutine test_gemm_trans_r64()
310        !! Constants
311        integer(kind=c_int64_t), parameter :: m  = 2000
312        integer(kind=c_int64_t), parameter :: n  = 1000
313        integer(kind=c_int64_t), parameter :: k  = 500
314        integer(kind=c_int64_t), parameter :: nb = 256
315
316        real(kind=c_double),     parameter :: alpha = 2.0
317        real(kind=c_double),     parameter :: beta  = 1.0
318
319        !! Variables
320        integer(kind=c_int64_t)            :: i
321        type(c_ptr)                        :: A, B, C, opts
322
323        !! Example
324        call print_func( mpi_rank, 'test_gemm_trans_r64' )
325
326        A = slate_Matrix_create_r64( k, m, nb, p_grid, q_grid, MPI_COMM_WORLD )
327        B = slate_Matrix_create_r64( n, k, nb, p_grid, q_grid, MPI_COMM_WORLD )
328        C = slate_Matrix_create_r64( m, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
329        call slate_Matrix_insertLocalTiles_r64( A )
330        call slate_Matrix_insertLocalTiles_r64( B )
331        call slate_Matrix_insertLocalTiles_r64( C )
332        call random_Matrix_r64( A )
333        call random_Matrix_r64( B )
334        call random_Matrix_r64( C )
335
336        ! Matrices can be transposed or conjugate-transposed beforehand
337        ! C = alpha AT BH + beta C
338        call slate_Matrix_transpose_in_place_r64( A )
339        call slate_Matrix_conj_transpose_in_place_r64( B )
340        call slate_multiply_r64( alpha, A, B, beta, C, c_null_ptr )
341
342        call slate_Matrix_destroy_r64( A )
343        call slate_Matrix_destroy_r64( B )
344        call slate_Matrix_destroy_r64( C )
345
346    end subroutine test_gemm_trans_r64
347
348    subroutine test_gemm_trans_c32()
349        !! Constants
350        integer(kind=c_int64_t), parameter :: m  = 2000
351        integer(kind=c_int64_t), parameter :: n  = 1000
352        integer(kind=c_int64_t), parameter :: k  = 500
353        integer(kind=c_int64_t), parameter :: nb = 256
354
355        complex(kind=c_float),   parameter :: alpha = 2.0
356        complex(kind=c_float),   parameter :: beta  = 1.0
357
358        !! Variables
359        integer(kind=c_int64_t)            :: i
360        type(c_ptr)                        :: A, B, C, opts
361
362        !! Example
363        call print_func( mpi_rank, 'test_gemm_trans_c32' )
364
365        A = slate_Matrix_create_c32( k, m, nb, p_grid, q_grid, MPI_COMM_WORLD )
366        B = slate_Matrix_create_c32( n, k, nb, p_grid, q_grid, MPI_COMM_WORLD )
367        C = slate_Matrix_create_c32( m, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
368        call slate_Matrix_insertLocalTiles_c32( A )
369        call slate_Matrix_insertLocalTiles_c32( B )
370        call slate_Matrix_insertLocalTiles_c32( C )
371        call random_Matrix_c32( A )
372        call random_Matrix_c32( B )
373        call random_Matrix_c32( C )
374
375        ! Matrices can be transposed or conjugate-transposed beforehand
376        ! C = alpha AT BH + beta C
377        call slate_Matrix_transpose_in_place_c32( A )
378        call slate_Matrix_conj_transpose_in_place_c32( B )
379        call slate_multiply_c32( alpha, A, B, beta, C, c_null_ptr )
380
381        call slate_Matrix_destroy_c32( A )
382        call slate_Matrix_destroy_c32( B )
383        call slate_Matrix_destroy_c32( C )
384
385    end subroutine test_gemm_trans_c32
386
387    subroutine test_gemm_trans_c64()
388        !! Constants
389        integer(kind=c_int64_t), parameter :: m  = 2000
390        integer(kind=c_int64_t), parameter :: n  = 1000
391        integer(kind=c_int64_t), parameter :: k  = 500
392        integer(kind=c_int64_t), parameter :: nb = 256
393
394        complex(kind=c_double),  parameter :: alpha = 2.0
395        complex(kind=c_double),  parameter :: beta  = 1.0
396
397        !! Variables
398        integer(kind=c_int64_t)            :: i
399        type(c_ptr)                        :: A, B, C, opts
400
401        !! Example
402        call print_func( mpi_rank, 'test_gemm_trans_c64' )
403
404        A = slate_Matrix_create_c64( k, m, nb, p_grid, q_grid, MPI_COMM_WORLD )
405        B = slate_Matrix_create_c64( n, k, nb, p_grid, q_grid, MPI_COMM_WORLD )
406        C = slate_Matrix_create_c64( m, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
407        call slate_Matrix_insertLocalTiles_c64( A )
408        call slate_Matrix_insertLocalTiles_c64( B )
409        call slate_Matrix_insertLocalTiles_c64( C )
410        call random_Matrix_c64( A )
411        call random_Matrix_c64( B )
412        call random_Matrix_c64( C )
413
414        ! Matrices can be transposed or conjugate-transposed beforehand
415        ! C = alpha AT BH + beta C
416        call slate_Matrix_transpose_in_place_c64( A )
417        call slate_Matrix_conj_transpose_in_place_c64( B )
418        call slate_multiply_c64( alpha, A, B, beta, C, c_null_ptr )
419
420        call slate_Matrix_destroy_c64( A )
421        call slate_Matrix_destroy_c64( B )
422        call slate_Matrix_destroy_c64( C )
423
424    end subroutine test_gemm_trans_c64
425
426end program ex05_blas