Example 15: Setting Matrix Elements

This example demonstrates how to set matrix elements with a user-supplied lambda function that computes each entry from its row and column indices.

Key Concepts

  1. Functional Initialization: slate::set calls a user-supplied function f(i, j) that returns the value of \(A_{ij}\).

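  2. Parallel Execution: The lambda function is executed in parallel across tiles, so it may be called concurrently and in any order; it should not rely on call order or unsynchronized shared state.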

  3. Use Cases:
     * Random initialization (non-deterministic unless care is taken with seeds and call order).
     * Coordinate-based initialization (e.g., \(A_{ij} = i + j\)).
     * Stencil generation (e.g., a discrete Laplacian).

C++ Example

Setting Random Values (Lines 19-48)

using entry_type = std::function< scalar_type (int64_t, int64_t) >;
entry_type entry = [random_max]( int64_t i, int64_t j ) { ... };
slate::set( entry, A );

slate::set iterates over the matrix A and, for every element (i, j), calls the provided function entry(i, j) to determine its value.

Note: Since tiles are processed in parallel, a shared, stateful random number generator (such as random() used here, or rand()) is called in an unpredictable order. The exact pattern of values is therefore generally not reproducible across runs or ranks, even with a fixed seed. For reproducible random matrices, compute each value from (i, j) and a seed, as SLATE's matgen library does.

Coordinate-based Initialization (Lines 98-112)

entry_type entry = []( int64_t i, int64_t j ) {
    // Return value based on i and j
    return i + 1 + (j + 1)/1000.;
};
slate::set( entry, A );

This is useful for generating deterministic test matrices where the value depends on the position.

Stencil Initialization (Lines 128-143)

entry_type entry = [n]( int64_t i, int64_t j ) {
    if (i == j) return -3.0; // Diagonal
    else if (...) return 0.5; // Neighbors
    // ...
};

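This pattern can initialize matrices with a sparse or structured pattern (stored here as a dense SLATE matrix), such as those arising from finite-difference discretizations. An example is a 9-point Laplacian stencil on an n-by-n grid, which gives an \(n^2 \times n^2\) matrix. The lambda encodes the connectivity: which grid points are neighbors, and with what weight.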

  1// ex15_set_matrix.cc
  2// Set matrix entries
  3
  4/// !!!   Lines between `//---------- begin label`          !!!
  5/// !!!             and `//---------- end label`            !!!
  6/// !!!   are included in the SLATE Users' Guide.           !!!
  7
  8#include <slate/slate.hh>
  9
 10#include "util.hh"
 11
 12int mpi_size = 0;
 13int mpi_rank = 0;
 14int grid_p = 1;
 15int grid_q = 1;
 16
 17//------------------------------------------------------------------------------
 18template <typename scalar_type>
 19void test_set_rand()
 20{
 21    using real_t = blas::real_type<scalar_type>;
 22
 23    print_func( mpi_rank );
 24
 25    int64_t m=20, n=20, nb=8;
 26    slate::Matrix<scalar_type> A( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
 27    A.insertLocalTiles();
 28
 29    using entry_type = std::function< scalar_type (int64_t, int64_t) >;
 30
 31    const real_t random_max = INT32_MAX;  // 2^31 - 1
 32
 33    // Lambda to set entry A_ij.
 34    // This is non-deterministic since tiles are set in parallel!
 35    // SLATE's matgen library has a deterministic parallel random matrix generator.
 36    entry_type entry = [random_max]( int64_t i, int64_t j )
 37    {
 38        if constexpr (blas::is_complex<scalar_type>::value) {
 39            return blas::make_scalar<scalar_type>( random() / random_max,
 40                                                   random() / random_max );
 41        }
 42        else {
 43            return random() / random_max;
 44        };
 45    };
 46
 47    slate::set( entry, A );
 48    slate::print( "A", A );
 49}
 50
 51//------------------------------------------------------------------------------
 52template <typename scalar_type>
 53void test_set_rand_hermitian()
 54{
 55    using real_t = blas::real_type<scalar_type>;
 56
 57    print_func( mpi_rank );
 58
 59    int64_t n=20, nb=8;
 60    slate::HermitianMatrix<scalar_type>
 61        A( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
 62    A.insertLocalTiles();
 63
 64    using entry_type = std::function< scalar_type (int64_t, int64_t) >;
 65
 66    const real_t random_max = INT32_MAX;  // 2^31 - 1
 67
 68    // Lambda to set entry A_ij.
 69    // This is non-deterministic since tiles are set in parallel!
 70    // SLATE's matgen library has a deterministic parallel random matrix generator.
 71    entry_type entry = [random_max]( int64_t i, int64_t j )
 72    {
 73        if constexpr (blas::is_complex<scalar_type>::value) {
 74            return blas::make_scalar<scalar_type>( random() / random_max,
 75                                                   random() / random_max );
 76        }
 77        else {
 78            return random() / random_max;
 79        };
 80    };
 81
 82    slate::set( entry, A );
 83    slate::print( "A", A );
 84}
 85
 86//------------------------------------------------------------------------------
 87template <typename scalar_type>
 88void test_set_ij()
 89{
 90    print_func( mpi_rank );
 91
 92    int64_t m=20, n=20, nb=8;
 93    slate::Matrix<scalar_type> A( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
 94    A.insertLocalTiles();
 95
 96    using entry_type = std::function< scalar_type (int64_t, int64_t) >;
 97
 98    // Lambda to set entry A_ij.
 99    entry_type entry = []( int64_t i, int64_t j )
100    {
101        if constexpr (blas::is_complex<scalar_type>::value) {
102            // In complex, real part is i, imag part is j.
103            return blas::make_scalar<scalar_type>( i + 1, j + 1 );
104        }
105        else {
106            // In real, integer part is i, fraction part is j.
107            return i + 1 + (j + 1)/1000.;
108        }
109    };
110
111    slate::set( entry, A );
112    slate::print( "A", A );
113}
114
115//------------------------------------------------------------------------------
116template <typename scalar_type>
117void test_set_stencil()
118{
119    print_func( mpi_rank );
120
121    int64_t n=5, n2=n*n, nb=8;
122    slate::Matrix<scalar_type> A( n2, n2, nb, grid_p, grid_q, MPI_COMM_WORLD );
123    A.insertLocalTiles();
124
125    using entry_type = std::function< scalar_type (int64_t, int64_t) >;
126
127    // Lambda for 9-point Laplacian stencil in 2D.
128    entry_type entry = [n]( int64_t i, int64_t j )
129    {
130        if (i == j)
131            return -3.0;
132        else if (i == j-1 || i == j+1
133              || i == j-n || i == j+n)
134            return 0.5;
135        else if (i == j-n-1 || i == j-n+1
136              || i == j+n-1 || i == j+n+1)
137            return 0.25;
138        else
139            return 0.0;
140    };
141
142    slate::set( entry, A );
143    slate::print( "A", A );
144}
145
146//------------------------------------------------------------------------------
147int main( int argc, char** argv )
148{
149    try {
150        // Parse command line to set types for s, d, c, z precisions.
151        bool types[ 4 ];
152        parse_args( argc, argv, types );
153
154        int provided = 0;
155        slate_mpi_call(
156            MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) );
157        assert( provided == MPI_THREAD_MULTIPLE );
158
159        slate_mpi_call(
160            MPI_Comm_size( MPI_COMM_WORLD, &mpi_size ) );
161
162        slate_mpi_call(
163            MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank ) );
164
165        unsigned t = time( nullptr );
166        printf( "srandom( %u )\n", t );
167        srandom( t );
168
169        if (types[ 0 ]) {
170            test_set_rand<float>();
171            test_set_rand_hermitian<float>();
172            test_set_ij<float>();
173            test_set_stencil<float>();
174        }
175        if (types[ 1 ]) {
176            test_set_rand<double>();
177            test_set_rand_hermitian<double>();
178            test_set_ij<double>();
179            test_set_stencil<double>();
180        }
181        if (types[ 2 ]) {
182            test_set_rand< std::complex<float> >();
183            test_set_rand_hermitian< std::complex<float> >();
184            test_set_ij< std::complex<float> >();
185            test_set_stencil< std::complex<float> >();
186        }
187        if (types[ 3 ]) {
188            test_set_rand< std::complex<double> >();
189            test_set_rand_hermitian< std::complex<double> >();
190            test_set_ij< std::complex<double> >();
191            test_set_stencil< std::complex<double> >();
192        }
193
194        slate_mpi_call(
195            MPI_Finalize() );
196    }
197    catch (std::exception const& ex) {
198        fprintf( stderr, "%s", ex.what() );
199        return 1;
200    }
201    return 0;
202}