Example 05: BLAS Operations
This example demonstrates how to perform Basic Linear Algebra Subprograms (BLAS) operations in SLATE.
Key Concepts
- Matrix Multiplication: Using slate::multiply (gemm, hemm, symm) for matrix products.
- Rank Updates: Performing rank-k (herk, syrk) and rank-2k (her2k, syr2k) updates.
- Triangular Operations: Triangular matrix multiplication (trmm) and solving triangular systems (trsm).
- Simplified vs Traditional API: Comparing the descriptive multiply API with the traditional BLAS-named API.
C++ Example
General Matrix Multiplication (GEMM) (Lines 36-40)
// C = alpha A B + beta C
slate::multiply( alpha, A, B, beta, C ); // simplified API
slate::gemm( alpha, A, B, beta, C ); // traditional API
Here we perform the standard operation \(C = \alpha AB + \beta C\), where:
- A is an m-by-k matrix,
- B is a k-by-n matrix,
- C is an m-by-n matrix.
SLATE provides both a descriptive multiply routine and the traditional BLAS-named gemm. They are equivalent.
GPU Execution with Options (Lines 43-52)
if (blas::get_device_count() > 0) {
slate::Options opts = {
{ slate::Option::Lookahead, 2 },
{ slate::Option::Target, slate::Target::Devices },
};
slate::multiply( alpha, A, B, beta, C, opts );
}
Most SLATE routines accept an Options map as the final argument. Here we:
- Set Target::Devices to offload computation to GPUs.
- Set Lookahead to 2 to overlap communication and computation.
Transposed Multiplication (Lines 77-83)
auto AT = transpose( A );
auto BH = conj_transpose( B );
slate::multiply( alpha, AT, BH, beta, C );
To compute \(C = \alpha A^T B^H + \beta C\), we simply create transposed views AT and BH and pass them to the multiply function. SLATE detects the transposition flags on the views and handles the logic internally.
Symmetric/Hermitian Multiplication (SYMM/HEMM) (Lines 97-118)
slate::multiply( alpha, A, B, beta, C ); // simplified
slate::symm( slate::Side::Left, alpha, A, B, beta, C ); // traditional
When A is a SymmetricMatrix (or HermitianMatrix), multiply automatically dispatches to the efficient symmetric/Hermitian algorithm (symm/hemm).
- Side::Left means \(C = \alpha A B + \beta C\).
- Side::Right means \(C = \alpha B A + \beta C\) (demonstrated in lines 141-147).
Rank-K Updates (SYRK/HERK) (Lines 230-241)
slate::rank_k_update( alpha, A, beta, C );
slate::syrk( alpha, A, beta, C );
Computes \(C = \alpha A A^T + \beta C\), where C is symmetric; only the designated triangle of C (Lower or Upper) is updated. When C is a HermitianMatrix, rank_k_update dispatches to herk, which computes \(C = \alpha A A^H + \beta C\) and requires real-valued alpha and beta.
Triangular Operations (TRMM/TRSM) (Lines 299-310)
// B = alpha A B
slate::triangular_multiply( alpha, A, B ); // trmm
// B = alpha A^{-1} B (Solve AX = B)
slate::triangular_solve( alpha, A, B ); // trsm
For triangular matrices, we can multiply (trmm) or solve (trsm). The simplified API names make the intent clear (“multiply” vs “solve”). Note that in the Side::Right variants, the simplified API reverses the operand order to (B, A), mirroring the mathematical expression \(B = \alpha B A\) (see lines 333-345).
1// ex05_blas.cc
2// BLAS routines
3
4/// !!! Lines between `//---------- begin label` !!!
5/// !!! and `//---------- end label` !!!
6/// !!! are included in the SLATE Users' Guide. !!!
7
8#include <slate/slate.hh>
9
10#include "util.hh"
11
12int mpi_size = 0;
13int mpi_rank = 0;
14int grid_p = 0;
15int grid_q = 0;
16
17//------------------------------------------------------------------------------
18template <typename scalar_type>
19void test_gemm()
20{
21 print_func( mpi_rank );
22
23 scalar_type alpha = 2.0, beta = 1.0;
24 int64_t m=2000, n=1000, k=500, nb=256;
25
26 slate::Matrix<scalar_type> A( m, k, nb, grid_p, grid_q, MPI_COMM_WORLD );
27 slate::Matrix<scalar_type> B( k, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
28 slate::Matrix<scalar_type> C( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
29 A.insertLocalTiles();
30 B.insertLocalTiles();
31 C.insertLocalTiles();
32 random_matrix( A );
33 random_matrix( B );
34 random_matrix( C );
35
36 //---------- begin gemm
37 // C = alpha A B + beta C, where A, B, C are all general matrices.
38 slate::multiply( alpha, A, B, beta, C ); // simplified API
39 slate::gemm( alpha, A, B, beta, C ); // traditional API
40 //---------- end gemm
41
42 //--------------------
43 if (blas::get_device_count() > 0) {
44 //---------- begin gemm_opts
45 // Execute on GPU devices with lookahead of 2.
46 slate::Options opts = {
47 { slate::Option::Lookahead, 2 },
48 { slate::Option::Target, slate::Target::Devices },
49 };
50 slate::multiply( alpha, A, B, beta, C, opts );
51 //---------- end gemm_opts
52 }
53}
54
55//------------------------------------------------------------------------------
56template <typename scalar_type>
57void test_gemm_trans()
58{
59 print_func( mpi_rank );
60
61 scalar_type alpha = 2.0, beta = 1.0;
62 int64_t m=2000, n=1000, k=500, nb=256;
63
64 // Dimensions of A, B are backwards from A, B in test_gemm().
65 slate::Matrix<scalar_type> A( k, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
66 slate::Matrix<scalar_type> B( n, k, nb, grid_p, grid_q, MPI_COMM_WORLD );
67 slate::Matrix<scalar_type> C( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
68 A.insertLocalTiles();
69 B.insertLocalTiles();
70 C.insertLocalTiles();
71 random_matrix( A );
72 random_matrix( B );
73 random_matrix( C );
74
75 //---------- begin gemm_trans
76
77 // Matrices can be transposed or conjugate-transposed beforehand.
78 // C = alpha A^T B^H + beta C
79 auto AT = transpose( A );
80 auto BH = conj_transpose( B );
81 slate::multiply( alpha, AT, BH, beta, C ); // simplified API
82 slate::gemm( alpha, AT, BH, beta, C ); // traditional API
83 //---------- end gemm_trans
84
85 // todo: support rvalues:
86 // slate::gemm( alpha, transpose( A ), conj_transpose( B ), beta, C );
87 // or
88 // slate::gemm( alpha, transpose( A ), conj_transpose( B ), beta, std::move( C ) );
89}
90
91//------------------------------------------------------------------------------
92template <typename scalar_type>
93void test_symm_left()
94{
95 print_func( mpi_rank );
96
97 scalar_type alpha = 2.0, beta = 1.0;
98 int64_t m=2000, n=1000, nb=256;
99
100 // A is m-by-m, B and C are m-by-n.
101 slate::SymmetricMatrix<scalar_type>
102 A( slate::Uplo::Lower, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
103 slate::Matrix<scalar_type> B( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
104 slate::Matrix<scalar_type> C( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
105 A.insertLocalTiles();
106 B.insertLocalTiles();
107 C.insertLocalTiles();
108 random_matrix( A );
109 random_matrix( B );
110 random_matrix( C );
111
112 //---------- begin symm_left
113
114 // C = alpha A B + beta C, where A is symmetric, on left side
115 slate::multiply( alpha, A, B, beta, C ); // simplified API
116 slate::symm( slate::Side::Left, alpha, A, B, beta, C ); // traditional API
117 //---------- end symm_left
118}
119
120//------------------------------------------------------------------------------
121template <typename scalar_type>
122void test_symm_right()
123{
124 print_func( mpi_rank );
125
126 scalar_type alpha = 2.0, beta = 1.0;
127 int64_t m=2000, n=1000, nb=256;
128
129 // A is m-by-m, B and C are n-by-m (reverse of left case above).
130 slate::SymmetricMatrix<scalar_type>
131 A( slate::Uplo::Lower, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
132 slate::Matrix<scalar_type> B( n, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
133 slate::Matrix<scalar_type> C( n, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
134 A.insertLocalTiles();
135 B.insertLocalTiles();
136 C.insertLocalTiles();
137 random_matrix( A );
138 random_matrix( B );
139 random_matrix( C );
140
141 //---------- begin symm_right
142
143 // C = alpha B A + beta C, where A is symmetric, on right side
144 // Note B, A order reversed in multiply compared to symm.
145 slate::multiply( alpha, B, A, beta, C ); // simplified API
146 slate::symm( slate::Side::Right, alpha, A, B, beta, C ); // traditional API
147 //---------- end symm_right
148}
149
150//------------------------------------------------------------------------------
151template <typename scalar_type>
152void test_hemm_left()
153{
154 print_func( mpi_rank );
155
156 scalar_type alpha = 2.0, beta = 1.0;
157 int64_t m=2000, n=1000, nb=256;
158
159 // A is m-by-m, B and C are m-by-n.
160 slate::HermitianMatrix<scalar_type>
161 A( slate::Uplo::Lower, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
162 slate::Matrix<scalar_type> B( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
163 slate::Matrix<scalar_type> C( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
164 A.insertLocalTiles();
165 B.insertLocalTiles();
166 C.insertLocalTiles();
167 random_matrix( A );
168 random_matrix( B );
169 random_matrix( C );
170
171 //---------- begin hemm_left
172
173 // C = alpha A B + beta C, where A is Hermitian, on left side
174 slate::multiply( alpha, A, B, beta, C ); // simplified API
175 slate::hemm( slate::Side::Left, alpha, A, B, beta, C ); // traditional API
176 //---------- end hemm_left
177}
178
179//------------------------------------------------------------------------------
180template <typename scalar_type>
181void test_hemm_right()
182{
183 print_func( mpi_rank );
184
185 scalar_type alpha = 2.0, beta = 1.0;
186 int64_t m=2000, n=1000, nb=256;
187
188 // A is m-by-m, B and C are n-by-m (reverse of left case above).
189 slate::HermitianMatrix<scalar_type>
190 A( slate::Uplo::Lower, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
191 slate::Matrix<scalar_type> B( n, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
192 slate::Matrix<scalar_type> C( n, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
193 A.insertLocalTiles();
194 B.insertLocalTiles();
195 C.insertLocalTiles();
196 random_matrix( A );
197 random_matrix( B );
198 random_matrix( C );
199
200 //---------- begin hemm_right
201
202 // C = alpha B A + beta C, where A is Hermitian, on right side
203 // Note B, A order reversed in multiply compared to hemm.
204 slate::multiply( alpha, B, A, beta, C ); // simplified API
205 slate::hemm( slate::Side::Right, alpha, A, B, beta, C ); // traditional API
206 //---------- end hemm_right
207}
208
209//------------------------------------------------------------------------------
210template <typename scalar_type>
211void test_syrk_syr2k()
212{
213 print_func( mpi_rank );
214
215 scalar_type alpha = 2.0, beta = 1.0;
216 int64_t n=1000, k=500, nb=256;
217
218 slate::Matrix<scalar_type> A( n, k, nb, grid_p, grid_q, MPI_COMM_WORLD );
219 slate::Matrix<scalar_type> B( n, k, nb, grid_p, grid_q, MPI_COMM_WORLD );
220 slate::SymmetricMatrix<scalar_type>
221 C( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
222 A.insertLocalTiles();
223 B.insertLocalTiles();
224 C.insertLocalTiles();
225 random_matrix( A );
226 random_matrix( B );
227 random_matrix( C );
228
229 //---------- begin syrk
230
231 // C = alpha A A^T + beta C, where C is symmetric
232 slate::rank_k_update( alpha, A, beta, C ); // simplified API
233 slate::syrk( alpha, A, beta, C ); // traditional API
234 //---------- end syrk
235
236 //---------- begin syr2k
237
238 // C = alpha A B^T + alpha B A^T + beta C, where C is symmetric
239 slate::rank_2k_update( alpha, A, B, beta, C ); // simplified API
240 slate::syr2k( alpha, A, B, beta, C ); // traditional API
241 //---------- end syr2k
242}
243
244//------------------------------------------------------------------------------
245template <typename scalar_type>
246void test_herk_her2k()
247{
248 print_func( mpi_rank );
249
250 scalar_type alpha = 2.0;
251 blas::real_type<scalar_type> alpha_real = 2.0, beta = 1.0;
252 int64_t n=1000, k=500, nb=256;
253
254 slate::Matrix<scalar_type> A( n, k, nb, grid_p, grid_q, MPI_COMM_WORLD );
255 slate::Matrix<scalar_type> B( n, k, nb, grid_p, grid_q, MPI_COMM_WORLD );
256 slate::HermitianMatrix<scalar_type>
257 C( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
258 A.insertLocalTiles();
259 B.insertLocalTiles();
260 C.insertLocalTiles();
261 random_matrix( A );
262 random_matrix( B );
263 random_matrix( C );
264
265 //---------- begin herk
266
267 // C = alpha A A^H + beta C, where C is Hermitian
268 slate::rank_k_update( alpha_real, A, beta, C ); // simplified API
269 slate::herk( alpha_real, A, beta, C ); // traditional API
270 //---------- end herk
271
272 //---------- begin her2k
273
274 // C = alpha A B^H + conj(alpha) B A^H + beta C, where C is Hermitian
275 slate::rank_2k_update( alpha, A, B, beta, C ); // simplified API
276 slate::her2k( alpha, A, B, beta, C ); // traditional API
277 //---------- end her2k
278}
279
280//------------------------------------------------------------------------------
281template <typename scalar_type>
282void test_trmm_trsm_left()
283{
284 print_func( mpi_rank );
285
286 scalar_type alpha = 2.0;
287 int64_t m=2000, n=1000, nb=256;
288
289 // A is m-by-m, B is m-by-n
290 slate::TriangularMatrix<scalar_type>
291 A( slate::Uplo::Lower, slate::Diag::NonUnit, m, nb,
292 grid_p, grid_q, MPI_COMM_WORLD );
293 slate::Matrix<scalar_type> B( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
294 A.insertLocalTiles();
295 B.insertLocalTiles();
296 random_matrix( A );
297 random_matrix( B );
298
299 //---------- begin trmm_left
300
301 //----- left
302 // B = alpha A B, where A is triangular, on left side
303 slate::triangular_multiply( alpha, A, B ); // simplified API
304 slate::trmm( slate::Side::Left, alpha, A, B ); // traditional API
305
306 // Solve AX = B, where A is triangular, on left side; X overwrites B.
307 // That is, B = alpha A^{-1} B.
308 slate::triangular_solve( alpha, A, B ); // simplified API
309 slate::trsm( slate::Side::Left, alpha, A, B ); // traditional API
310 //---------- end trmm_left
311}
312
313//------------------------------------------------------------------------------
314template <typename scalar_type>
315void test_trmm_trsm_right()
316{
317 print_func( mpi_rank );
318
319 scalar_type alpha = 2.0;
320 int64_t m=2000, n=1000, nb=256;
321
322 // A is m-by-m, B is n-by-m (reverse of left case above).
323 slate::TriangularMatrix<scalar_type>
324 A( slate::Uplo::Lower, slate::Diag::NonUnit, m, nb,
325 grid_p, grid_q, MPI_COMM_WORLD );
326 slate::Matrix<scalar_type> B( n, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
327 A.insertLocalTiles();
328 B.insertLocalTiles();
329 random_matrix( A );
330 random_matrix( B );
331
332 //---------- begin trmm_right
333
334 //----- right
335 // B = alpha B A, where A is triangular, on right side
336 // Note B, A order reversed in multiply compared to trmm.
337 slate::triangular_multiply( alpha, B, A ); // simplified API
338 slate::trmm( slate::Side::Right, alpha, A, B ); // traditional API
339
340 // Solve XA = B, where A is triangular, on right side; X overwrites B.
341 // That is, B = alpha B A^{-1}.
342 // Note B, A order reversed in solve compared to trsm.
343 slate::triangular_solve( alpha, B, A ); // simplified API
344 slate::trsm( slate::Side::Right, alpha, A, B ); // traditional API
345 //---------- end trmm_right
346}
347
348//------------------------------------------------------------------------------
349template <typename scalar_type>
350void test_all()
351{
352 test_gemm < scalar_type >();
353 test_gemm_trans< scalar_type >();
354 test_symm_left < scalar_type >();
355 test_symm_right< scalar_type >();
356 test_hemm_left < scalar_type >();
357 test_hemm_right< scalar_type >();
358 test_syrk_syr2k< scalar_type >();
359 test_herk_her2k< scalar_type >();
360 test_trmm_trsm_left < scalar_type >();
361 test_trmm_trsm_right< scalar_type >();
362}
363
364//------------------------------------------------------------------------------
365int main( int argc, char** argv )
366{
367 try {
368 // Parse command line to set types for s, d, c, z precisions.
369 bool types[ 4 ];
370 parse_args( argc, argv, types );
371
372 int provided = 0;
373 slate_mpi_call(
374 MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) );
375 assert( provided == MPI_THREAD_MULTIPLE );
376
377 slate_mpi_call(
378 MPI_Comm_size( MPI_COMM_WORLD, &mpi_size ) );
379
380 slate_mpi_call(
381 MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank ) );
382
383 // Determine p-by-q grid for this MPI size.
384 grid_size( mpi_size, &grid_p, &grid_q );
385 if (mpi_rank == 0) {
386 printf( "mpi_size %d, grid_p %d, grid_q %d\n",
387 mpi_size, grid_p, grid_q );
388 }
389
390 // so random_matrix is different on different ranks.
391 srand( 100 * mpi_rank );
392
393 if (types[ 0 ]) {
394 test_all< float >();
395 }
396 if (mpi_rank == 0)
397 printf( "\n" );
398
399 if (types[ 1 ]) {
400 test_all< double >();
401 }
402 if (mpi_rank == 0)
403 printf( "\n" );
404
405 if (types[ 2 ]) {
406 test_all< std::complex<float> >();
407 }
408 if (mpi_rank == 0)
409 printf( "\n" );
410
411 if (types[ 3 ]) {
412 test_all< std::complex<double> >();
413 }
414
415 slate_mpi_call(
416 MPI_Finalize() );
417 }
418 catch (std::exception const& ex) {
419 fprintf( stderr, "%s", ex.what() );
420 return 1;
421 }
422 return 0;
423}
C API Example
1// ex05_blas.c
2// BLAS routines
3
4#include <slate/c_api/slate.h>
5#include <mpi.h>
6
7#include "util.h"
8
9int mpi_size = 0;
10int mpi_rank = 0;
11int grid_p = 0;
12int grid_q = 0;
13
14//------------------------------------------------------------------------------
15void test_gemm_r32()
16{
17 print_func( mpi_rank );
18
19 double alpha = 2.0, beta = 1.0;
20 int64_t m=2000, n=1000, k=500, nb=256;
21
22 slate_Matrix_r32 A = slate_Matrix_create_r32(
23 m, k, nb, grid_p, grid_q, MPI_COMM_WORLD );
24 slate_Matrix_r32 B = slate_Matrix_create_r32(
25 k, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
26 slate_Matrix_r32 C = slate_Matrix_create_r32(
27 m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
28 slate_Matrix_insertLocalTiles_r32( A );
29 slate_Matrix_insertLocalTiles_r32( B );
30 slate_Matrix_insertLocalTiles_r32( C );
31 random_Matrix_r32( A );
32 random_Matrix_r32( B );
33 random_Matrix_r32( C );
34
35 // C = alpha A B + beta C, where A, B, C are all general matrices.
36 slate_multiply_r32( alpha, A, B, beta, C, NULL );
37
38 if (slate_Matrix_num_devices_r32( C ) > 0) {
39 // Execute on GPU devices with lookahead of 2.
40 slate_Options opts = slate_Options_create();
41 slate_Options_set_Target( opts, slate_Target_Devices );
42 slate_Options_set_Lookahead( opts, 2 );
43
44 slate_multiply_r32( alpha, A, B, beta, C, opts );
45
46 slate_Options_destroy( opts );
47 }
48
49 slate_Matrix_destroy_r32( A );
50 slate_Matrix_destroy_r32( B );
51 slate_Matrix_destroy_r32( C );
52}
53
54//------------------------------------------------------------------------------
55void test_gemm_r64()
56{
57 print_func( mpi_rank );
58
59 double alpha = 2.0, beta = 1.0;
60 int64_t m=2000, n=1000, k=500, nb=256;
61
62 slate_Matrix_r64 A = slate_Matrix_create_r64(
63 m, k, nb, grid_p, grid_q, MPI_COMM_WORLD );
64 slate_Matrix_r64 B = slate_Matrix_create_r64(
65 k, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
66 slate_Matrix_r64 C = slate_Matrix_create_r64(
67 m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
68 slate_Matrix_insertLocalTiles_r64( A );
69 slate_Matrix_insertLocalTiles_r64( B );
70 slate_Matrix_insertLocalTiles_r64( C );
71 random_Matrix_r64( A );
72 random_Matrix_r64( B );
73 random_Matrix_r64( C );
74
75 // C = alpha A B + beta C, where A, B, C are all general matrices.
76 slate_multiply_r64( alpha, A, B, beta, C, NULL );
77
78 if (slate_Matrix_num_devices_r64( C ) > 0) {
79 // Execute on GPU devices with lookahead of 2.
80 slate_Options opts = slate_Options_create();
81 slate_Options_set_Target( opts, slate_Target_Devices );
82 slate_Options_set_Lookahead( opts, 2 );
83
84 slate_multiply_r64( alpha, A, B, beta, C, opts );
85
86 slate_Options_destroy( opts );
87 }
88
89 slate_Matrix_destroy_r64( A );
90 slate_Matrix_destroy_r64( B );
91 slate_Matrix_destroy_r64( C );
92}
93
94//------------------------------------------------------------------------------
95void test_gemm_c32()
96{
97 print_func( mpi_rank );
98
99 double alpha = 2.0, beta = 1.0;
100 int64_t m=2000, n=1000, k=500, nb=256;
101
102 slate_Matrix_c32 A = slate_Matrix_create_c32(
103 m, k, nb, grid_p, grid_q, MPI_COMM_WORLD );
104 slate_Matrix_c32 B = slate_Matrix_create_c32(
105 k, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
106 slate_Matrix_c32 C = slate_Matrix_create_c32(
107 m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
108 slate_Matrix_insertLocalTiles_c32( A );
109 slate_Matrix_insertLocalTiles_c32( B );
110 slate_Matrix_insertLocalTiles_c32( C );
111 random_Matrix_c32( A );
112 random_Matrix_c32( B );
113 random_Matrix_c32( C );
114
115 // C = alpha A B + beta C, where A, B, C are all general matrices.
116 slate_multiply_c32( alpha, A, B, beta, C, NULL );
117
118 if (slate_Matrix_num_devices_c32( C ) > 0) {
119 // Execute on GPU devices with lookahead of 2.
120 slate_Options opts = slate_Options_create();
121 slate_Options_set_Target( opts, slate_Target_Devices );
122 slate_Options_set_Lookahead( opts, 2 );
123
124 slate_multiply_c32( alpha, A, B, beta, C, opts );
125
126 slate_Options_destroy( opts );
127 }
128
129 slate_Matrix_destroy_c32( A );
130 slate_Matrix_destroy_c32( B );
131 slate_Matrix_destroy_c32( C );
132}
133
134//------------------------------------------------------------------------------
135void test_gemm_c64()
136{
137 print_func( mpi_rank );
138
139 double alpha = 2.0, beta = 1.0;
140 int64_t m=2000, n=1000, k=500, nb=256;
141
142 slate_Matrix_c64 A = slate_Matrix_create_c64(
143 m, k, nb, grid_p, grid_q, MPI_COMM_WORLD );
144 slate_Matrix_c64 B = slate_Matrix_create_c64(
145 k, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
146 slate_Matrix_c64 C = slate_Matrix_create_c64(
147 m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
148 slate_Matrix_insertLocalTiles_c64( A );
149 slate_Matrix_insertLocalTiles_c64( B );
150 slate_Matrix_insertLocalTiles_c64( C );
151 random_Matrix_c64( A );
152 random_Matrix_c64( B );
153 random_Matrix_c64( C );
154
155 // C = alpha A B + beta C, where A, B, C are all general matrices.
156 slate_multiply_c64( alpha, A, B, beta, C, NULL );
157
158 if (slate_Matrix_num_devices_c64( C ) > 0) {
159 // Execute on GPU devices with lookahead of 2.
160 slate_Options opts = slate_Options_create();
161 slate_Options_set_Target( opts, slate_Target_Devices );
162 slate_Options_set_Lookahead( opts, 2 );
163
164 slate_multiply_c64( alpha, A, B, beta, C, opts );
165
166 slate_Options_destroy( opts );
167 }
168
169 slate_Matrix_destroy_c64( A );
170 slate_Matrix_destroy_c64( B );
171 slate_Matrix_destroy_c64( C );
172}
173
174//------------------------------------------------------------------------------
175void test_gemm_trans_r32()
176{
177 print_func( mpi_rank );
178
179 double alpha = 2.0, beta = 1.0;
180 int64_t m=2000, n=1000, k=500, nb=256;
181
182 slate_Matrix_r32 A = slate_Matrix_create_r32(
183 k, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
184 slate_Matrix_r32 B = slate_Matrix_create_r32(
185 n, k, nb, grid_p, grid_q, MPI_COMM_WORLD );
186 slate_Matrix_r32 C = slate_Matrix_create_r32(
187 m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
188 slate_Matrix_insertLocalTiles_r32( A );
189 slate_Matrix_insertLocalTiles_r32( B );
190 slate_Matrix_insertLocalTiles_r32( C );
191 random_Matrix_r32( A );
192 random_Matrix_r32( B );
193 random_Matrix_r32( C );
194
195 // Matrices can be transposed or conjugate-transposed beforehand.
196 // C = alpha A^T B^H + beta C
197 slate_Matrix_transpose_in_place_r32( A );
198 slate_Matrix_conj_transpose_in_place_r32( B );
199 slate_multiply_r32( alpha, A, B, beta, C, NULL ); // simplified API
200
201 slate_Matrix_destroy_r32( A );
202 slate_Matrix_destroy_r32( B );
203 slate_Matrix_destroy_r32( C );
204}
205
206//------------------------------------------------------------------------------
207void test_gemm_trans_r64()
208{
209 print_func( mpi_rank );
210
211 double alpha = 2.0, beta = 1.0;
212 int64_t m=2000, n=1000, k=500, nb=256;
213
214 slate_Matrix_r64 A = slate_Matrix_create_r64(
215 k, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
216 slate_Matrix_r64 B = slate_Matrix_create_r64(
217 n, k, nb, grid_p, grid_q, MPI_COMM_WORLD );
218 slate_Matrix_r64 C = slate_Matrix_create_r64(
219 m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
220 slate_Matrix_insertLocalTiles_r64( A );
221 slate_Matrix_insertLocalTiles_r64( B );
222 slate_Matrix_insertLocalTiles_r64( C );
223 random_Matrix_r64( A );
224 random_Matrix_r64( B );
225 random_Matrix_r64( C );
226
227 // Matrices can be transposed or conjugate-transposed beforehand.
228 // C = alpha A^T B^H + beta C
229 slate_Matrix_transpose_in_place_r64( A );
230 slate_Matrix_conj_transpose_in_place_r64( B );
231 slate_multiply_r64( alpha, A, B, beta, C, NULL ); // simplified API
232
233 slate_Matrix_destroy_r64( A );
234 slate_Matrix_destroy_r64( B );
235 slate_Matrix_destroy_r64( C );
236}
237
238//------------------------------------------------------------------------------
239void test_gemm_trans_c32()
240{
241 print_func( mpi_rank );
242
243 double alpha = 2.0, beta = 1.0;
244 int64_t m=2000, n=1000, k=500, nb=256;
245
246 slate_Matrix_c32 A = slate_Matrix_create_c32(
247 k, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
248 slate_Matrix_c32 B = slate_Matrix_create_c32(
249 n, k, nb, grid_p, grid_q, MPI_COMM_WORLD );
250 slate_Matrix_c32 C = slate_Matrix_create_c32(
251 m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
252 slate_Matrix_insertLocalTiles_c32( A );
253 slate_Matrix_insertLocalTiles_c32( B );
254 slate_Matrix_insertLocalTiles_c32( C );
255 random_Matrix_c32( A );
256 random_Matrix_c32( B );
257 random_Matrix_c32( C );
258
259 // Matrices can be transposed or conjugate-transposed beforehand.
260 // C = alpha A^T B^H + beta C
261 slate_Matrix_transpose_in_place_c32( A );
262 slate_Matrix_conj_transpose_in_place_c32( B );
263 slate_multiply_c32( alpha, A, B, beta, C, NULL ); // simplified API
264
265 slate_Matrix_destroy_c32( A );
266 slate_Matrix_destroy_c32( B );
267 slate_Matrix_destroy_c32( C );
268}
269
270//------------------------------------------------------------------------------
271void test_gemm_trans_c64()
272{
273 print_func( mpi_rank );
274
275 double alpha = 2.0, beta = 1.0;
276 int64_t m=2000, n=1000, k=500, nb=256;
277
278 slate_Matrix_c64 A = slate_Matrix_create_c64(
279 k, m, nb, grid_p, grid_q, MPI_COMM_WORLD );
280 slate_Matrix_c64 B = slate_Matrix_create_c64(
281 n, k, nb, grid_p, grid_q, MPI_COMM_WORLD );
282 slate_Matrix_c64 C = slate_Matrix_create_c64(
283 m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
284 slate_Matrix_insertLocalTiles_c64( A );
285 slate_Matrix_insertLocalTiles_c64( B );
286 slate_Matrix_insertLocalTiles_c64( C );
287 random_Matrix_c64( A );
288 random_Matrix_c64( B );
289 random_Matrix_c64( C );
290
291 // Matrices can be transposed or conjugate-transposed beforehand.
292 // C = alpha A^T B^H + beta C
293 slate_Matrix_transpose_in_place_c64( A );
294 slate_Matrix_conj_transpose_in_place_c64( B );
295 slate_multiply_c64( alpha, A, B, beta, C, NULL ); // simplified API
296
297 slate_Matrix_destroy_c64( A );
298 slate_Matrix_destroy_c64( B );
299 slate_Matrix_destroy_c64( C );
300}
301
302//------------------------------------------------------------------------------
303int main( int argc, char** argv )
304{
305 // Parse command line to set types for s, d, c, z precisions.
306 bool types[ 4 ];
307 parse_args( argc, argv, types );
308
309 int provided = 0;
310 MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided );
311 assert( provided == MPI_THREAD_MULTIPLE );
312
313 MPI_Comm_size( MPI_COMM_WORLD, &mpi_size );
314 MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank );
315
316 // Determine p-by-q grid for this MPI size.
317 grid_size( mpi_size, &grid_p, &grid_q );
318 if (mpi_rank == 0) {
319 printf( "mpi_size %d, grid_p %d, grid_q %d\n",
320 mpi_size, grid_p, grid_q );
321 }
322
323 // so random_matrix is different on different ranks.
324 srand( 100 * mpi_rank );
325
326 if (types[ 0 ]) {
327 test_gemm_r32();
328 test_gemm_trans_r32();
329 if (mpi_rank == 0)
330 printf( "\n" );
331 }
332
333 if (types[ 1 ]) {
334 test_gemm_r64();
335 test_gemm_trans_r64();
336 if (mpi_rank == 0)
337 printf( "\n" );
338 }
339
340 if (types[ 2 ]) {
341 test_gemm_c32();
342 test_gemm_trans_c32();
343 if (mpi_rank == 0)
344 printf( "\n" );
345 }
346
347 if (types[ 3 ]) {
348 test_gemm_c64();
349 test_gemm_trans_c64();
350 }
351
352 MPI_Finalize();
353
354 return 0;
355}
Fortran API Example
1! ex05_blas.f90
2! BLAS routines
3program ex05_blas
4 use, intrinsic :: iso_fortran_env
5 use slate
6 use mpi
7 use util
8 implicit none
9
10 !! Variables
11 logical :: types(4)
12 integer(kind=c_int) :: p_grid, q_grid
13
14 integer(kind=c_int) :: provided, ierr
15 integer(kind=c_int) :: mpi_rank, mpi_size
16
17 !! Get requested types
18 call parse_args( types );
19
20 !! MPI
21 call MPI_Init_thread( MPI_THREAD_MULTIPLE, provided, ierr )
22 if ((ierr .ne. 0) .or. (provided .ne. MPI_THREAD_MULTIPLE)) then
23 print *, "Error: MPI_Init_thread"
24 return
25 end if
26 call MPI_Comm_size( MPI_COMM_WORLD, mpi_size, ierr )
27 if (ierr .ne. 0) then
28 print *, "Error: MPI_Comm_size"
29 return
30 end if
31 call MPI_Comm_rank( MPI_COMM_WORLD, mpi_rank, ierr )
32 if (ierr .ne. 0) then
33 print *, "Error: MPI_Comm_rank"
34 return
35 end if
36
37 call grid_size( mpi_size, p_grid, q_grid )
38
39 call srand( 100 * mpi_rank )
40
41 if (types(1)) then
42 call test_gemm_r32()
43 call test_gemm_trans_r32()
44
45 if (mpi_rank == 0) then
46 print *
47 end if
48 end if
49 if (types(2)) then
50 call test_gemm_r64()
51 call test_gemm_trans_r64()
52
53 if (mpi_rank == 0) then
54 print *
55 end if
56 end if
57 if (types(3)) then
58 call test_gemm_c32()
59 call test_gemm_trans_c32()
60
61 if (mpi_rank == 0) then
62 print *
63 end if
64 end if
65 if (types(4)) then
66 call test_gemm_c64()
67 call test_gemm_trans_c64()
68
69 if (mpi_rank == 0) then
70 print *
71 end if
72 end if
73
74 call MPI_Finalize( ierr )
75 if (ierr .ne. 0) then
76 print *, "Error: MPI_Finalize"
77 return
78 end if
79
80contains
81
82 subroutine test_gemm_r32()
83 !! Constants
84 integer(kind=c_int64_t), parameter :: m = 2000
85 integer(kind=c_int64_t), parameter :: n = 1000
86 integer(kind=c_int64_t), parameter :: k = 500
87 integer(kind=c_int64_t), parameter :: nb = 256
88
89 real(kind=c_float), parameter :: alpha = 2.0
90 real(kind=c_float), parameter :: beta = 1.0
91
92 !! Variables
93 integer(kind=c_int64_t) :: i
94 type(c_ptr) :: A, B, C, opts
95
96 !! Example
97 call print_func( mpi_rank, 'test_gemm_r32' )
98
99 A = slate_Matrix_create_r32( m, k, nb, p_grid, q_grid, MPI_COMM_WORLD )
100 B = slate_Matrix_create_r32( k, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
101 C = slate_Matrix_create_r32( m, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
102 call slate_Matrix_insertLocalTiles_r32( A )
103 call slate_Matrix_insertLocalTiles_r32( B )
104 call slate_Matrix_insertLocalTiles_r32( C )
105 call random_Matrix_r32( A )
106 call random_Matrix_r32( B )
107 call random_Matrix_r32( C )
108
109 ! C = alpha A B + beta C
110 call slate_multiply_r32( alpha, A, B, beta, C, c_null_ptr )
111
112 if (slate_Matrix_num_devices_r32( C ) > 0) then
113 opts = slate_Options_create()
114 call slate_Options_set_Target( opts, slate_Target_Devices );
115 call slate_Options_set_Lookahead( opts, 2_int64 )
116
117 call slate_multiply_r32( alpha, A, B, beta, C, opts )
118
119 call slate_Options_destroy( opts )
120 endif
121
122
123 call slate_Matrix_destroy_r32( A )
124 call slate_Matrix_destroy_r32( B )
125 call slate_Matrix_destroy_r32( C )
126
127 end subroutine test_gemm_r32
128
129 subroutine test_gemm_r64()
130 !! Constants
131 integer(kind=c_int64_t), parameter :: m = 2000
132 integer(kind=c_int64_t), parameter :: n = 1000
133 integer(kind=c_int64_t), parameter :: k = 500
134 integer(kind=c_int64_t), parameter :: nb = 256
135
136 real(kind=c_double), parameter :: alpha = 2.0
137 real(kind=c_double), parameter :: beta = 1.0
138
139 !! Variables
140 integer(kind=c_int64_t) :: i
141 type(c_ptr) :: A, B, C, opts
142
143 !! Example
144 call print_func( mpi_rank, 'test_gemm_r64' )
145
146 A = slate_Matrix_create_r64( m, k, nb, p_grid, q_grid, MPI_COMM_WORLD )
147 B = slate_Matrix_create_r64( k, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
148 C = slate_Matrix_create_r64( m, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
149 call slate_Matrix_insertLocalTiles_r64( A )
150 call slate_Matrix_insertLocalTiles_r64( B )
151 call slate_Matrix_insertLocalTiles_r64( C )
152 call random_Matrix_r64( A )
153 call random_Matrix_r64( B )
154 call random_Matrix_r64( C )
155
156 ! C = alpha A B + beta C
157 call slate_multiply_r64( alpha, A, B, beta, C, c_null_ptr )
158
159 if (slate_Matrix_num_devices_r64( C ) > 0) then
160 opts = slate_Options_create()
161	        call slate_Options_set_Target( opts, slate_Target_Devices )
162	        call slate_Options_set_Lookahead( opts, 2_c_int64_t )
163
164 call slate_multiply_r64( alpha, A, B, beta, C, opts )
165
166 call slate_Options_destroy( opts )
167 endif
168
169
170 call slate_Matrix_destroy_r64( A )
171 call slate_Matrix_destroy_r64( B )
172 call slate_Matrix_destroy_r64( C )
173
174 end subroutine test_gemm_r64
175
176 subroutine test_gemm_c32()
177 !! Constants
178 integer(kind=c_int64_t), parameter :: m = 2000
179 integer(kind=c_int64_t), parameter :: n = 1000
180 integer(kind=c_int64_t), parameter :: k = 500
181 integer(kind=c_int64_t), parameter :: nb = 256
182
183 complex(kind=c_float), parameter :: alpha = 2.0
184 complex(kind=c_float), parameter :: beta = 1.0
185
186 !! Variables
187 integer(kind=c_int64_t) :: i
188 type(c_ptr) :: A, B, C, opts
189
190 !! Example
191 call print_func( mpi_rank, 'test_gemm_c32' )
192
193 A = slate_Matrix_create_c32( m, k, nb, p_grid, q_grid, MPI_COMM_WORLD )
194 B = slate_Matrix_create_c32( k, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
195 C = slate_Matrix_create_c32( m, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
196 call slate_Matrix_insertLocalTiles_c32( A )
197 call slate_Matrix_insertLocalTiles_c32( B )
198 call slate_Matrix_insertLocalTiles_c32( C )
199 call random_Matrix_c32( A )
200 call random_Matrix_c32( B )
201 call random_Matrix_c32( C )
202
203 ! C = alpha A B + beta C
204 call slate_multiply_c32( alpha, A, B, beta, C, c_null_ptr )
205
206 if (slate_Matrix_num_devices_c32( C ) > 0) then
207 opts = slate_Options_create()
208	        call slate_Options_set_Target( opts, slate_Target_Devices )
209	        call slate_Options_set_Lookahead( opts, 2_c_int64_t )
210
211 call slate_multiply_c32( alpha, A, B, beta, C, opts )
212
213 call slate_Options_destroy( opts )
214 endif
215
216
217 call slate_Matrix_destroy_c32( A )
218 call slate_Matrix_destroy_c32( B )
219 call slate_Matrix_destroy_c32( C )
220
221 end subroutine test_gemm_c32
222
223 subroutine test_gemm_c64()
224 !! Constants
225 integer(kind=c_int64_t), parameter :: m = 2000
226 integer(kind=c_int64_t), parameter :: n = 1000
227 integer(kind=c_int64_t), parameter :: k = 500
228 integer(kind=c_int64_t), parameter :: nb = 256
229
230 complex(kind=c_double), parameter :: alpha = 2.0
231 complex(kind=c_double), parameter :: beta = 1.0
232
233 !! Variables
234 integer(kind=c_int64_t) :: i
235 type(c_ptr) :: A, B, C, opts
236
237 !! Example
238 call print_func( mpi_rank, 'test_gemm_c64' )
239
240 A = slate_Matrix_create_c64( m, k, nb, p_grid, q_grid, MPI_COMM_WORLD )
241 B = slate_Matrix_create_c64( k, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
242 C = slate_Matrix_create_c64( m, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
243 call slate_Matrix_insertLocalTiles_c64( A )
244 call slate_Matrix_insertLocalTiles_c64( B )
245 call slate_Matrix_insertLocalTiles_c64( C )
246 call random_Matrix_c64( A )
247 call random_Matrix_c64( B )
248 call random_Matrix_c64( C )
249
250 ! C = alpha A B + beta C
251 call slate_multiply_c64( alpha, A, B, beta, C, c_null_ptr )
252
253 if (slate_Matrix_num_devices_c64( C ) > 0) then
254 opts = slate_Options_create()
255	        call slate_Options_set_Target( opts, slate_Target_Devices )
256	        call slate_Options_set_Lookahead( opts, 2_c_int64_t )
257
258 call slate_multiply_c64( alpha, A, B, beta, C, opts )
259
260 call slate_Options_destroy( opts )
261 endif
262
263
264 call slate_Matrix_destroy_c64( A )
265 call slate_Matrix_destroy_c64( B )
266 call slate_Matrix_destroy_c64( C )
267
268 end subroutine test_gemm_c64
269
270 subroutine test_gemm_trans_r32()
271 !! Constants
272 integer(kind=c_int64_t), parameter :: m = 2000
273 integer(kind=c_int64_t), parameter :: n = 1000
274 integer(kind=c_int64_t), parameter :: k = 500
275 integer(kind=c_int64_t), parameter :: nb = 256
276
277 real(kind=c_float), parameter :: alpha = 2.0
278 real(kind=c_float), parameter :: beta = 1.0
279
280 !! Variables
281 integer(kind=c_int64_t) :: i
282 type(c_ptr) :: A, B, C, opts
283
284 !! Example
285 call print_func( mpi_rank, 'test_gemm_trans_r32' )
286
287 A = slate_Matrix_create_r32( k, m, nb, p_grid, q_grid, MPI_COMM_WORLD )
288 B = slate_Matrix_create_r32( n, k, nb, p_grid, q_grid, MPI_COMM_WORLD )
289 C = slate_Matrix_create_r32( m, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
290 call slate_Matrix_insertLocalTiles_r32( A )
291 call slate_Matrix_insertLocalTiles_r32( B )
292 call slate_Matrix_insertLocalTiles_r32( C )
293 call random_Matrix_r32( A )
294 call random_Matrix_r32( B )
295 call random_Matrix_r32( C )
296
297 ! Matrices can be transposed or conjugate-transposed beforehand
298	    ! C = alpha A^T B^H + beta C
299	    call slate_Matrix_transpose_in_place_r32( A )
300	    call slate_Matrix_conj_transpose_in_place_r32( B )
301 call slate_multiply_r32( alpha, A, B, beta, C, c_null_ptr )
302
303 call slate_Matrix_destroy_r32( A )
304 call slate_Matrix_destroy_r32( B )
305 call slate_Matrix_destroy_r32( C )
306
307 end subroutine test_gemm_trans_r32
308
309 subroutine test_gemm_trans_r64()
310 !! Constants
311 integer(kind=c_int64_t), parameter :: m = 2000
312 integer(kind=c_int64_t), parameter :: n = 1000
313 integer(kind=c_int64_t), parameter :: k = 500
314 integer(kind=c_int64_t), parameter :: nb = 256
315
316 real(kind=c_double), parameter :: alpha = 2.0
317 real(kind=c_double), parameter :: beta = 1.0
318
319 !! Variables
320 integer(kind=c_int64_t) :: i
321 type(c_ptr) :: A, B, C, opts
322
323 !! Example
324 call print_func( mpi_rank, 'test_gemm_trans_r64' )
325
326 A = slate_Matrix_create_r64( k, m, nb, p_grid, q_grid, MPI_COMM_WORLD )
327 B = slate_Matrix_create_r64( n, k, nb, p_grid, q_grid, MPI_COMM_WORLD )
328 C = slate_Matrix_create_r64( m, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
329 call slate_Matrix_insertLocalTiles_r64( A )
330 call slate_Matrix_insertLocalTiles_r64( B )
331 call slate_Matrix_insertLocalTiles_r64( C )
332 call random_Matrix_r64( A )
333 call random_Matrix_r64( B )
334 call random_Matrix_r64( C )
335
336 ! Matrices can be transposed or conjugate-transposed beforehand
337	    ! C = alpha A^T B^H + beta C
338	    call slate_Matrix_transpose_in_place_r64( A )
339	    call slate_Matrix_conj_transpose_in_place_r64( B )
340 call slate_multiply_r64( alpha, A, B, beta, C, c_null_ptr )
341
342 call slate_Matrix_destroy_r64( A )
343 call slate_Matrix_destroy_r64( B )
344 call slate_Matrix_destroy_r64( C )
345
346 end subroutine test_gemm_trans_r64
347
348 subroutine test_gemm_trans_c32()
349 !! Constants
350 integer(kind=c_int64_t), parameter :: m = 2000
351 integer(kind=c_int64_t), parameter :: n = 1000
352 integer(kind=c_int64_t), parameter :: k = 500
353 integer(kind=c_int64_t), parameter :: nb = 256
354
355 complex(kind=c_float), parameter :: alpha = 2.0
356 complex(kind=c_float), parameter :: beta = 1.0
357
358 !! Variables
359 integer(kind=c_int64_t) :: i
360 type(c_ptr) :: A, B, C, opts
361
362 !! Example
363 call print_func( mpi_rank, 'test_gemm_trans_c32' )
364
365 A = slate_Matrix_create_c32( k, m, nb, p_grid, q_grid, MPI_COMM_WORLD )
366 B = slate_Matrix_create_c32( n, k, nb, p_grid, q_grid, MPI_COMM_WORLD )
367 C = slate_Matrix_create_c32( m, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
368 call slate_Matrix_insertLocalTiles_c32( A )
369 call slate_Matrix_insertLocalTiles_c32( B )
370 call slate_Matrix_insertLocalTiles_c32( C )
371 call random_Matrix_c32( A )
372 call random_Matrix_c32( B )
373 call random_Matrix_c32( C )
374
375 ! Matrices can be transposed or conjugate-transposed beforehand
376	    ! C = alpha A^T B^H + beta C
377	    call slate_Matrix_transpose_in_place_c32( A )
378	    call slate_Matrix_conj_transpose_in_place_c32( B )
379 call slate_multiply_c32( alpha, A, B, beta, C, c_null_ptr )
380
381 call slate_Matrix_destroy_c32( A )
382 call slate_Matrix_destroy_c32( B )
383 call slate_Matrix_destroy_c32( C )
384
385 end subroutine test_gemm_trans_c32
386
387 subroutine test_gemm_trans_c64()
388 !! Constants
389 integer(kind=c_int64_t), parameter :: m = 2000
390 integer(kind=c_int64_t), parameter :: n = 1000
391 integer(kind=c_int64_t), parameter :: k = 500
392 integer(kind=c_int64_t), parameter :: nb = 256
393
394 complex(kind=c_double), parameter :: alpha = 2.0
395 complex(kind=c_double), parameter :: beta = 1.0
396
397 !! Variables
398 integer(kind=c_int64_t) :: i
399 type(c_ptr) :: A, B, C, opts
400
401 !! Example
402 call print_func( mpi_rank, 'test_gemm_trans_c64' )
403
404 A = slate_Matrix_create_c64( k, m, nb, p_grid, q_grid, MPI_COMM_WORLD )
405 B = slate_Matrix_create_c64( n, k, nb, p_grid, q_grid, MPI_COMM_WORLD )
406 C = slate_Matrix_create_c64( m, n, nb, p_grid, q_grid, MPI_COMM_WORLD )
407 call slate_Matrix_insertLocalTiles_c64( A )
408 call slate_Matrix_insertLocalTiles_c64( B )
409 call slate_Matrix_insertLocalTiles_c64( C )
410 call random_Matrix_c64( A )
411 call random_Matrix_c64( B )
412 call random_Matrix_c64( C )
413
414 ! Matrices can be transposed or conjugate-transposed beforehand
415	    ! C = alpha A^T B^H + beta C
416	    call slate_Matrix_transpose_in_place_c64( A )
417	    call slate_Matrix_conj_transpose_in_place_c64( B )
418 call slate_multiply_c64( alpha, A, B, beta, C, c_null_ptr )
419
420 call slate_Matrix_destroy_c64( A )
421 call slate_Matrix_destroy_c64( B )
422 call slate_Matrix_destroy_c64( C )
423
424 end subroutine test_gemm_trans_c64
425
426end program ex05_blas