Example 03: Submatrices and Slicing

This example demonstrates how to work with submatrices in SLATE.

Key Concepts

  1. Tile Indexing (sub): Creating a submatrix view using tile indices (block coordinates). This is the most efficient way to reference submatrices in SLATE.

  2. Element Indexing (slice): Creating a submatrix view using global element indices (row/column coordinates). Slice boundaries may fall anywhere, including in the middle of a tile, but many SLATE operations expect views whose boundaries align with tile (block) boundaries.

C++ Example

Tile-based Submatrices (Lines 36-39)

// view of A( i1 : i2, j1 : j2 ) as tile indices, inclusive
auto B = A.sub( i1, i2, j1, j2 );

The sub method creates a view into the matrix using block (tile) coordinates.

  • i1, i2: Start and end block row indices (inclusive).

  • j1, j2: Start and end block column indices (inclusive).

  • If A has tiles of size nb, sub(1, 1, …) starts at global row nb.

  • This operation is very fast: it only adjusts internal offsets and dimensions. It creates a shallow copy, so B shares tile data with A and changes made through B are visible in A.

Common `sub` Use Cases (Lines 43-73)

  • B = A: Assigning a matrix to another creates a shallow copy view of the entire matrix.

  • B = A.sub(0, mt-1, 0, nt-1): Explicitly selecting the whole matrix range.

  • B = A.sub(0, mt-1, 0, 0): Selecting the first block column.

  • B = A.sub(0, 0, 0, nt-1): Selecting the first block row.

Element-based Slicing (Lines 77-80)

// view of A( row1 : row2, col1 : col2 ), inclusive
B = A.slice( row1, row2, col1, col2 );

The slice method creates a view using global element indices (0-based row/column indices).

  • row1, row2: Start and end row indices (inclusive).

  • col1, col2: Start and end column indices (inclusive).

  • Important: Slicing allows for arbitrary boundaries. However, many SLATE algorithms require matrix views to be aligned with tile boundaries. If you slice in the middle of a tile, you may be restricted in which operations you can perform on that view.

Common `slice` Use Cases (Lines 84-106)

  • B = A.slice(0, m-1, 0, n-1): Slice covering the entire matrix, equivalent to B = A.

  • B = A.slice(0, m-1, 0, 0): Slice of the first column (single vector).

  • B = A.slice(0, 0, 0, n-1): Slice of the first row.

  1// ex03_submatrix.cc
  2// A.sub and A.slice
  3
  4/// !!!   Lines between `//---------- begin label`          !!!
  5/// !!!             and `//---------- end label`            !!!
  6/// !!!   are included in the SLATE Users' Guide.           !!!
  7
  8#include <slate/slate.hh>
  9
 10#include "util.hh"
 11
 12int mpi_size = 0;  // number of MPI ranks; set in main via MPI_Comm_size
 13int mpi_rank = 0;  // this process's rank; set in main via MPI_Comm_rank
 14int grid_p = 0;    // rows of the p-by-q process grid; set in main by grid_size
 15int grid_q = 0;    // cols of the p-by-q process grid; set in main by grid_size
 16
 17//------------------------------------------------------------------------------
 18template <typename scalar_type>
 19void test_submatrix()  // print mt, nt, m, n of A and of sub/slice views of A
 20{
 21    using llong = long long;  // int64_t may not be long long; cast for %lld
 22
 23    print_func( mpi_rank );
 24
 25    int64_t m=2000, n=1000, nb=256;  // m, n not multiples of nb: last tiles partial
 26    int64_t i1=1, i2=3, j1=2, j2=3;  // tile indices for sub, inclusive
 27    int64_t row1=100, row2=300, col1=200, col2=400;  // element indices, inclusive
 28
 29    slate::Matrix<scalar_type>
 30        A( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
 31    printf( "rank %d: A mt=%lld, nt=%lld, m=%lld, n=%lld\n",
 32            mpi_rank, llong(A.mt()), llong(A.nt()), llong(A.m()), llong(A.n()) );
 33
 34    //---------------------------------------- sub-matrix
 35
 36    //---------- begin sub1
 37    // view of A( i1 : i2, j1 : j2 ) as tile indices, inclusive
 38    auto B = A.sub( i1, i2, j1, j2 );
 39    //---------- end sub1
 40    printf( "rank %d: B mt=%lld, nt=%lld, m=%lld, n=%lld\n",
 41            mpi_rank, llong(B.mt()), llong(B.nt()), llong(B.m()), llong(B.n()) );
 42
 43    //---------- begin sub2
 44
 45    // view of all of A
 46    B = A;
 47    //---------- end sub2
 48    printf( "rank %d: B mt=%lld, nt=%lld, m=%lld, n=%lld\n",
 49            mpi_rank, llong(B.mt()), llong(B.nt()), llong(B.m()), llong(B.n()) );
 50
 51    //---------- begin sub3
 52
 53    // same, view of all of A
 54    B = A.sub( 0, A.mt()-1, 0, A.nt()-1 );
 55    //---------- end sub3
 56    printf( "rank %d: B mt=%lld, nt=%lld, m=%lld, n=%lld\n",
 57            mpi_rank, llong(B.mt()), llong(B.nt()), llong(B.m()), llong(B.n()) );
 58
 59    //---------- begin sub4
 60
 61    // view of first block-column, A[ 0:mt-1, 0:0 ] as tile indices
 62    B = A.sub( 0, A.mt()-1, 0, 0 );
 63    //---------- end sub4
 64    printf( "rank %d: B mt=%lld, nt=%lld, m=%lld, n=%lld\n",
 65            mpi_rank, llong(B.mt()), llong(B.nt()), llong(B.m()), llong(B.n()) );
 66
 67    //---------- begin sub5
 68
 69    // view of first block-row, A[ 0:0, 0:nt-1 ] as tile indices
 70    B = A.sub( 0, 0, 0, A.nt()-1 );
 71    //---------- end sub5
 72    printf( "rank %d: B mt=%lld, nt=%lld, m=%lld, n=%lld\n",
 73            mpi_rank, llong(B.mt()), llong(B.nt()), llong(B.m()), llong(B.n()) );
 74
 75    //---------------------------------------- slicing
 76
 77    //---------- begin slice1
 78    // view of A( row1 : row2, col1 : col2 ), inclusive
 79    B = A.slice( row1, row2, col1, col2 );
 80    //---------- end slice1
 81    printf( "rank %d: B mt=%lld, nt=%lld, m=%lld, n=%lld\n",
 82            mpi_rank, llong(B.mt()), llong(B.nt()), llong(B.m()), llong(B.n()) );
 83
 84    //---------- begin slice2
 85
 86    // view of all of A
 87    B = A.slice( 0, A.m()-1, 0, A.n()-1 );
 88    //---------- end slice2
 89    printf( "rank %d: B mt=%lld, nt=%lld, m=%lld, n=%lld\n",
 90            mpi_rank, llong(B.mt()), llong(B.nt()), llong(B.m()), llong(B.n()) );
 91
 92    //---------- begin slice3
 93
 94    // view of first column, A( 0:m-1, 0:0 ) as element indices
 95    B = A.slice( 0, A.m()-1, 0, 0 );
 96    //---------- end slice3
 97    printf( "rank %d: B mt=%lld, nt=%lld, m=%lld, n=%lld\n",
 98            mpi_rank, llong(B.mt()), llong(B.nt()), llong(B.m()), llong(B.n()) );
 99
100    //---------- begin slice4
101
102    // view of first row, A( 0:0, 0:n-1 ) as element indices
103    B = A.slice( 0, 0, 0, A.n()-1 );
104    //---------- end slice4
105    printf( "rank %d: B mt=%lld, nt=%lld, m=%lld, n=%lld\n",
106            mpi_rank, llong(B.mt()), llong(B.nt()), llong(B.m()), llong(B.n()) );
107}
108
109//------------------------------------------------------------------------------
110// Example driver: initializes MPI, sets up a p-by-q process grid, and runs
111// test_submatrix once for each precision (s, d, c, z) selected on the
112// command line. Returns 0 on success, 1 if an exception is caught.
113int main( int argc, char** argv )
114{
115    try {
116        // Parse command line to set types for s, d, c, z precisions.
117        bool types[ 4 ];
118        parse_args( argc, argv, types );
119
120        // Require MPI_THREAD_MULTIPLE. Check at run time rather than with
121        // assert(), which is compiled out when NDEBUG is defined.
122        int provided = 0;
123        slate_mpi_call(
124            MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) );
125        if (provided != MPI_THREAD_MULTIPLE) {
126            fprintf( stderr, "MPI_THREAD_MULTIPLE is not supported\n" );
127            MPI_Abort( MPI_COMM_WORLD, 1 );
128        }
129
130        slate_mpi_call(
131            MPI_Comm_size( MPI_COMM_WORLD, &mpi_size ) );
132
133        slate_mpi_call(
134            MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank ) );
135
136        // Determine p-by-q grid for this MPI size.
137        grid_size( mpi_size, &grid_p, &grid_q );
138        if (mpi_rank == 0) {
139            printf( "mpi_size %d, grid_p %d, grid_q %d\n",
140                    mpi_size, grid_p, grid_q );
141        }
142
143        // Seed rand() differently on each rank (test_submatrix uses no rand()).
144        srand( 100 * mpi_rank );
145
146        // Run the example for every precision requested on the command line.
147        if (types[ 0 ]) {
148            test_submatrix< float >();
149        }
150        if (types[ 1 ]) {
151            test_submatrix< double >();
152        }
153        if (types[ 2 ]) {
154            test_submatrix< std::complex<float> >();
155        }
156        if (types[ 3 ]) {
157            test_submatrix< std::complex<double> >();
158        }
159
160        slate_mpi_call(
161            MPI_Finalize() );
162    }
163    catch (std::exception const& ex) {
164        fprintf( stderr, "%s\n", ex.what() );
165        return 1;
166    }
167    return 0;
168}