Example 15: Setting Matrix Elements
This example demonstrates advanced ways to set matrix elements using lambda functions.
Key Concepts
Functional Initialization: Using slate::set with a lambda function f(i, j) to set \(A_{ij}\).
Parallel Execution: The lambda function is executed in parallel across tiles.
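Use Cases:
* Random initialization (non-deterministic when a shared, stateful generator is called while tiles are processed in parallel).
* Coordinate-based initialization, where each value is a function of its position (e.g., \(A_{ij} = i + j\)).
* Stencil generation (e.g., a 2D Laplacian).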
C++ Example
Setting Random Values (Lines 19-48)
using entry_type = std::function< scalar_type (int64_t, int64_t) >;
entry_type entry = [random_max]( int64_t i, int64_t j ) { ... };
slate::set( entry, A );
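slate::set iterates over the matrix A and, for every element (i, j), calls the provided function entry(i, j) to determine its value.

- Note: Because tiles are processed in parallel, calling a global, stateful random number generator (such as random()) inside the lambda makes the resulting values depend on the order in which tiles happen to be processed. The pattern can therefore differ between runs and between ranks. Adding a lock or using thread-local generators does not remove this order dependence. For reproducible random matrices, use a generator whose output depends only on (i, j); SLATE's matgen library provides a deterministic parallel random matrix generator.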
Coordinate-based Initialization (Lines 98-112)
entry_type entry = []( int64_t i, int64_t j ) {
// Return value based on i and j
return i + 1 + (j + 1)/1000.;
};
slate::set( entry, A );
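This produces a deterministic test matrix in which every value encodes its own position:
- For real types the value is (i + 1) + (j + 1)/1000.
- For complex types the real part is i + 1 and the imaginary part is j + 1.

Printing such a matrix makes it easy to check indexing and the data distribution.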
Stencil Initialization (Lines 128-143)
entry_type entry = [n]( int64_t i, int64_t j ) {
if (i == j) return -3.0; // Diagonal
else if (...) return 0.5; // Neighbors
// ...
};
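This pattern initializes matrices with a known sparsity structure (stored here as a dense matrix), such as those arising from finite-difference discretizations. In this example, a 9-point Laplacian stencil on an n-by-n grid of unknowns yields an \(n^2 \times n^2\) matrix. The lambda encodes the grid connectivity, i.e., which unknowns are neighbors and with what weight.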
1// ex15_set_matrix.cc
2// Set matrix entries
3
4/// !!! Lines between `//---------- begin label` !!!
5/// !!! and `//---------- end label` !!!
6/// !!! are included in the SLATE Users' Guide. !!!
7
8#include <slate/slate.hh>
9
10#include "util.hh"
11
// Size of MPI_COMM_WORLD and this process's rank in it; both are set in main().
int mpi_size{ 0 };
int mpi_rank{ 0 };

// Dimensions of the p-by-q process grid passed to every matrix constructor
// in this example.
int grid_p{ 1 };
int grid_q{ 1 };
16
17//------------------------------------------------------------------------------
18template <typename scalar_type>
19void test_set_rand()
20{
21 using real_t = blas::real_type<scalar_type>;
22
23 print_func( mpi_rank );
24
25 int64_t m=20, n=20, nb=8;
26 slate::Matrix<scalar_type> A( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
27 A.insertLocalTiles();
28
29 using entry_type = std::function< scalar_type (int64_t, int64_t) >;
30
31 const real_t random_max = INT32_MAX; // 2^31 - 1
32
33 // Lambda to set entry A_ij.
34 // This is non-deterministic since tiles are set in parallel!
35 // SLATE's matgen library has a deterministic parallel random matrix generator.
36 entry_type entry = [random_max]( int64_t i, int64_t j )
37 {
38 if constexpr (blas::is_complex<scalar_type>::value) {
39 return blas::make_scalar<scalar_type>( random() / random_max,
40 random() / random_max );
41 }
42 else {
43 return random() / random_max;
44 };
45 };
46
47 slate::set( entry, A );
48 slate::print( "A", A );
49}
50
51//------------------------------------------------------------------------------
52template <typename scalar_type>
53void test_set_rand_hermitian()
54{
55 using real_t = blas::real_type<scalar_type>;
56
57 print_func( mpi_rank );
58
59 int64_t n=20, nb=8;
60 slate::HermitianMatrix<scalar_type>
61 A( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
62 A.insertLocalTiles();
63
64 using entry_type = std::function< scalar_type (int64_t, int64_t) >;
65
66 const real_t random_max = INT32_MAX; // 2^31 - 1
67
68 // Lambda to set entry A_ij.
69 // This is non-deterministic since tiles are set in parallel!
70 // SLATE's matgen library has a deterministic parallel random matrix generator.
71 entry_type entry = [random_max]( int64_t i, int64_t j )
72 {
73 if constexpr (blas::is_complex<scalar_type>::value) {
74 return blas::make_scalar<scalar_type>( random() / random_max,
75 random() / random_max );
76 }
77 else {
78 return random() / random_max;
79 };
80 };
81
82 slate::set( entry, A );
83 slate::print( "A", A );
84}
85
86//------------------------------------------------------------------------------
87template <typename scalar_type>
88void test_set_ij()
89{
90 print_func( mpi_rank );
91
92 int64_t m=20, n=20, nb=8;
93 slate::Matrix<scalar_type> A( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
94 A.insertLocalTiles();
95
96 using entry_type = std::function< scalar_type (int64_t, int64_t) >;
97
98 // Lambda to set entry A_ij.
99 entry_type entry = []( int64_t i, int64_t j )
100 {
101 if constexpr (blas::is_complex<scalar_type>::value) {
102 // In complex, real part is i, imag part is j.
103 return blas::make_scalar<scalar_type>( i + 1, j + 1 );
104 }
105 else {
106 // In real, integer part is i, fraction part is j.
107 return i + 1 + (j + 1)/1000.;
108 }
109 };
110
111 slate::set( entry, A );
112 slate::print( "A", A );
113}
114
115//------------------------------------------------------------------------------
116template <typename scalar_type>
117void test_set_stencil()
118{
119 print_func( mpi_rank );
120
121 int64_t n=5, n2=n*n, nb=8;
122 slate::Matrix<scalar_type> A( n2, n2, nb, grid_p, grid_q, MPI_COMM_WORLD );
123 A.insertLocalTiles();
124
125 using entry_type = std::function< scalar_type (int64_t, int64_t) >;
126
127 // Lambda for 9-point Laplacian stencil in 2D.
128 entry_type entry = [n]( int64_t i, int64_t j )
129 {
130 if (i == j)
131 return -3.0;
132 else if (i == j-1 || i == j+1
133 || i == j-n || i == j+n)
134 return 0.5;
135 else if (i == j-n-1 || i == j-n+1
136 || i == j+n-1 || i == j+n+1)
137 return 0.25;
138 else
139 return 0.0;
140 };
141
142 slate::set( entry, A );
143 slate::print( "A", A );
144}
145
146//------------------------------------------------------------------------------
147int main( int argc, char** argv )
148{
149 try {
150 // Parse command line to set types for s, d, c, z precisions.
151 bool types[ 4 ];
152 parse_args( argc, argv, types );
153
154 int provided = 0;
155 slate_mpi_call(
156 MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) );
157 assert( provided == MPI_THREAD_MULTIPLE );
158
159 slate_mpi_call(
160 MPI_Comm_size( MPI_COMM_WORLD, &mpi_size ) );
161
162 slate_mpi_call(
163 MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank ) );
164
165 unsigned t = time( nullptr );
166 printf( "srandom( %u )\n", t );
167 srandom( t );
168
169 if (types[ 0 ]) {
170 test_set_rand<float>();
171 test_set_rand_hermitian<float>();
172 test_set_ij<float>();
173 test_set_stencil<float>();
174 }
175 if (types[ 1 ]) {
176 test_set_rand<double>();
177 test_set_rand_hermitian<double>();
178 test_set_ij<double>();
179 test_set_stencil<double>();
180 }
181 if (types[ 2 ]) {
182 test_set_rand< std::complex<float> >();
183 test_set_rand_hermitian< std::complex<float> >();
184 test_set_ij< std::complex<float> >();
185 test_set_stencil< std::complex<float> >();
186 }
187 if (types[ 3 ]) {
188 test_set_rand< std::complex<double> >();
189 test_set_rand_hermitian< std::complex<double> >();
190 test_set_ij< std::complex<double> >();
191 test_set_stencil< std::complex<double> >();
192 }
193
194 slate_mpi_call(
195 MPI_Finalize() );
196 }
197 catch (std::exception const& ex) {
198 fprintf( stderr, "%s", ex.what() );
199 return 1;
200 }
201 return 0;
202}