Example 03: Submatrices and Slicing
This example demonstrates how to work with submatrices in SLATE.
Key Concepts
Tile Indexing (sub): Creating a submatrix view using tile indices (block coordinates). This is the most efficient way to reference submatrices in SLATE.
Element Indexing (slice): Creating a submatrix view using global element indices (row/column coordinates). A slice may start and end anywhere, but many SLATE operations require a view whose boundaries align with tile (block) boundaries.
C++ Example
Tile-based Submatrices (Lines 36-39)
// view of A( i1 : i2, j1 : j2 ) as tile indices, inclusive
auto B = A.sub( i1, i2, j1, j2 );
The sub method creates a view into the matrix using block (tile) coordinates.
i1, i2: Start and end block row indices (inclusive).
j1, j2: Start and end block column indices (inclusive).
If A has uniform tiles of size nb, block row i starts at global row i·nb; for example, A.sub(1, 1, …) starts at global row nb.
This operation is very fast: it creates a shallow-copy view that only adjusts internal offsets and dimensions, so no matrix data is copied.
Common `sub` Use Cases (Lines 43-73)
B = A: Assigning a matrix to another creates a shallow copy view of the entire matrix.
B = A.sub(0, mt-1, 0, nt-1): Explicitly selecting the whole matrix range.
B = A.sub(0, mt-1, 0, 0): Selecting the first block column.
B = A.sub(0, 0, 0, nt-1): Selecting the first block row.
Element-based Slicing (Lines 77-80)
// view of A( row1 : row2, col1 : col2 ), inclusive
B = A.slice( row1, row2, col1, col2 );
The slice method creates a view using global element indices (0-based row/column indices).
row1, row2: Start and end row indices (inclusive).
col1, col2: Start and end column indices (inclusive).
Important: Slicing allows for arbitrary boundaries. However, many SLATE algorithms require matrix views to be aligned with tile boundaries. If you slice in the middle of a tile, you may be restricted in which operations you can perform on that view.
Common `slice` Use Cases (Lines 84-106)
B = A.slice(0, m-1, 0, n-1): Slice covering the entire matrix.
B = A.slice(0, m-1, 0, 0): Slice of the first column (single vector).
B = A.slice(0, 0, 0, n-1): Slice of the first row.
// ex03_submatrix.cc
// A.sub and A.slice

/// !!! Lines between `//---------- begin label` !!!
/// !!! and `//---------- end label` !!!
/// !!! are included in the SLATE Users' Guide. !!!

#include <complex>

#include <slate/slate.hh>

#include "util.hh"

// Global MPI state: main() fills these in, test_submatrix() reads them.
int mpi_size{ 0 };  // number of ranks in MPI_COMM_WORLD
int mpi_rank{ 0 };  // this process's rank in MPI_COMM_WORLD
int grid_p{ 0 };    // rows of the p-by-q process grid
int grid_q{ 0 };    // columns of the p-by-q process grid

//------------------------------------------------------------------------------
18template <typename scalar_type>
19void test_submatrix()
20{
21 using llong = long long;
22
23 print_func( mpi_rank );
24
25 int64_t m=2000, n=1000, nb=256;
26 int64_t i1=1, i2=3, j1=2, j2=3;
27 int64_t row1=100, row2=300, col1=200, col2=400;
28
29 slate::Matrix<scalar_type>
30 A( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
31 printf( "rank %d: A mt=%lld, nt=%lld, m=%lld, n=%lld\n",
32 mpi_rank, llong(A.mt()), llong(A.nt()), llong(A.m()), llong(A.n()) );
33
34 //---------------------------------------- sub-matrix
35
36 //---------- begin sub1
37 // view of A( i1 : i2, j1 : j2 ) as tile indices, inclusive
38 auto B = A.sub( i1, i2, j1, j2 );
39 //---------- end sub1
40 printf( "rank %d: B mt=%lld, nt=%lld, m=%lld, n=%lld\n",
41 mpi_rank, llong(B.mt()), llong(B.nt()), llong(B.m()), llong(B.n()) );
42
43 //---------- begin sub2
44
45 // view of all of A
46 B = A;
47 //---------- end sub2
48 printf( "rank %d: B mt=%lld, nt=%lld, m=%lld, n=%lld\n",
49 mpi_rank, llong(B.mt()), llong(B.nt()), llong(B.m()), llong(B.n()) );
50
51 //---------- begin sub3
52
53 // same, view of all of A
54 B = A.sub( 0, A.mt()-1, 0, A.nt()-1 );
55 //---------- end sub3
56 printf( "rank %d: B mt=%lld, nt=%lld, m=%lld, n=%lld\n",
57 mpi_rank, llong(B.mt()), llong(B.nt()), llong(B.m()), llong(B.n()) );
58
59 //---------- begin sub4
60
61 // view of first block-column, A[ 0:mt-1, 0:0 ] as tile indices
62 B = A.sub( 0, A.mt()-1, 0, 0 );
63 //---------- end sub4
64 printf( "rank %d: B mt=%lld, nt=%lld, m=%lld, n=%lld\n",
65 mpi_rank, llong(B.mt()), llong(B.nt()), llong(B.m()), llong(B.n()) );
66
67 //---------- begin sub5
68
69 // view of first block-row, A[ 0:0, 0:nt-1 ] as tile indices
70 B = A.sub( 0, 0, 0, A.nt()-1 );
71 //---------- end sub5
72 printf( "rank %d: B mt=%lld, nt=%lld, m=%lld, n=%lld\n",
73 mpi_rank, llong(B.mt()), llong(B.nt()), llong(B.m()), llong(B.n()) );
74
75 //---------------------------------------- slicing
76
77 //---------- begin slice1
78 // view of A( row1 : row2, col1 : col2 ), inclusive
79 B = A.slice( row1, row2, col1, col2 );
80 //---------- end slice1
81 printf( "rank %d: B mt=%lld, nt=%lld, m=%lld, n=%lld\n",
82 mpi_rank, llong(B.mt()), llong(B.nt()), llong(B.m()), llong(B.n()) );
83
84 //---------- begin slice2
85
86 // view of all of A
87 B = A.slice( 0, A.m()-1, 0, A.n()-1 );
88 //---------- end slice2
89 printf( "rank %d: B mt=%lld, nt=%lld, m=%lld, n=%lld\n",
90 mpi_rank, llong(B.mt()), llong(B.nt()), llong(B.m()), llong(B.n()) );
91
92 //---------- begin slice3
93
94 // view of first column, A[ 0:m-1, 0:0 ]
95 B = A.slice( 0, A.m()-1, 0, 0 );
96 //---------- end slice3
97 printf( "rank %d: B mt=%lld, nt=%lld, m=%lld, n=%lld\n",
98 mpi_rank, llong(B.mt()), llong(B.nt()), llong(B.m()), llong(B.n()) );
99
100 //---------- begin slice4
101
102 // view of first row, A[ 0:0, 0:n-1 ]
103 B = A.slice( 0, 0, 0, A.n()-1 );
104 //---------- end slice4
105 printf( "rank %d: B mt=%lld, nt=%lld, m=%lld, n=%lld\n",
106 mpi_rank, llong(B.mt()), llong(B.nt()), llong(B.m()), llong(B.n()) );
107}
108
109//------------------------------------------------------------------------------
110int main( int argc, char** argv )
111{
112 try {
113 // Parse command line to set types for s, d, c, z precisions.
114 bool types[ 4 ];
115 parse_args( argc, argv, types );
116
117 int provided = 0;
118 slate_mpi_call(
119 MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) );
120 assert( provided == MPI_THREAD_MULTIPLE );
121
122 slate_mpi_call(
123 MPI_Comm_size( MPI_COMM_WORLD, &mpi_size ) );
124
125 slate_mpi_call(
126 MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank ) );
127
128 // Determine p-by-q grid for this MPI size.
129 grid_size( mpi_size, &grid_p, &grid_q );
130 if (mpi_rank == 0) {
131 printf( "mpi_size %d, grid_p %d, grid_q %d\n",
132 mpi_size, grid_p, grid_q );
133 }
134
135 // so random_matrix is different on different ranks.
136 srand( 100 * mpi_rank );
137
138 if (types[ 0 ]) {
139 test_submatrix< float >();
140 }
141
142 slate_mpi_call(
143 MPI_Finalize() );
144 }
145 catch (std::exception const& ex) {
146 fprintf( stderr, "%s", ex.what() );
147 return 1;
148 }
149 return 0;
150}