Example 06: Linear Systems (LU)
This example demonstrates solving linear systems \(Ax=B\) using LU factorization.
Key Concepts
- Simple Solve: using `slate::lu_solve` (`gesv`) for a one-step solution.
- Explicit Factorization: separating factorization (`lu_factor`/`getrf`) from solve (`lu_solve_using_factor`/`getrs`).
- Matrix Inversion: computing \(A^{-1}\) using `lu_inverse_using_factor` (`getri`).
- Mixed Precision: using iterative refinement to solve systems with a lower-precision factorization.
- Condition Number: estimating the condition number of the matrix.
C++ Example
Standard LU Solve (Lines 38-41)
slate::lu_solve( A, B ); // simplified API
slate::gesv( A, pivots, B ); // traditional API
The simplest way to solve \(Ax=B\).
A is overwritten by its LU factors.
B is overwritten by the solution \(X\).
pivots (in the traditional API) stores the pivot indices found during factorization. lu_solve manages this internally if you don’t need the pivots later.
Mixed Precision Iterative Refinement (Lines 82-83)
slate::gesv_mixed( A, pivots, B, X, iters );
Mixed precision solvers can provide a significant speedup by doing the expensive factorization in lower precision (e.g., float) and then refining the solution to high precision (e.g., double) using the original matrix.
A, B, X are high precision (e.g., double).
The internal factorization happens in low precision (e.g., float).
iters returns the number of refinement iterations performed.
Explicit Factorization and Solve (Lines 113-118)
slate::lu_factor( A, pivots );
slate::lu_solve_using_factor( A, pivots, B );
Sometimes you need to solve for multiple right-hand sides that arrive at different times, or you want to reuse the factors.
lu_factor (getrf): Computes \(PA = LU\).
lu_solve_using_factor (getrs): Solves \(Ax=B\) using the pre-computed factors and pivots.
Matrix Inversion (Lines 142-147)
slate::lu_factor( A, pivots );
slate::lu_inverse_using_factor( A, pivots );
Computes the inverse of a matrix in-place.
Factorize the matrix.
Call lu_inverse_using_factor (getri). A is overwritten by \(A^{-1}\).
Condition Number Estimation (Lines 173-179)
real_t A_norm = slate::norm( slate::Norm::One, A );
slate::lu_factor( A, pivots );
real_t rcond = slate::lu_rcondest_using_factor( slate::Norm::One, A, A_norm );
Estimates the reciprocal condition number \(1/\kappa(A)\).
Compute the norm of the original matrix before factorization.
Factorize the matrix.
Call lu_rcondest_using_factor. This estimates \(\|A^{-1}\|\) cheaply using the factors and combines it with the provided \(\|A\|\).
1// ex06_linear_system_lu.cc
2// Solve AX = B using LU factorization
3
4/// !!! Lines between `//---------- begin label` !!!
5/// !!! and `//---------- end label` !!!
6/// !!! are included in the SLATE Users' Guide. !!!
7
8#include <slate/slate.hh>
9
10#include "util.hh"
11
// Process-grid globals, initialized in main() after MPI_Init_thread.
int mpi_size = 0;  // total number of MPI ranks
int mpi_rank = 0;  // rank of this process
int grid_p = 0;    // rows in the p-by-q process grid
int grid_q = 0;    // columns in the p-by-q process grid
16
17//------------------------------------------------------------------------------
18template <typename scalar_type>
19void test_lu()
20{
21 print_func( mpi_rank );
22
23 int64_t n=1000, nrhs=100, nb=256;
24
25 //---------- begin solve1
26 slate::Matrix<scalar_type> A( n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
27 slate::Matrix<scalar_type> B( n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
28 // ...
29 //---------- end solve1
30
31 A.insertLocalTiles();
32 B.insertLocalTiles();
33 random_matrix( A );
34 random_matrix( B );
35
36 //---------- begin solve2
37
38 slate::lu_solve( A, B ); // simplified API
39
40 slate::Pivots pivots;
41 slate::gesv( A, pivots, B ); // traditional API
42 //---------- end solve2
43}
44
45//------------------------------------------------------------------------------
46template <typename scalar_type>
47void test_lu_mixed()
48{
49 print_func( mpi_rank );
50
51 int64_t n=1000, nrhs=100, nb=256;
52 scalar_type zero = 0;
53
54 //---------- begin mixed1
55 // mixed precision: factor in single, iterative refinement to double
56 slate::Matrix<scalar_type> A( n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
57 slate::Matrix<scalar_type> B( n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
58 slate::Matrix<scalar_type> X( n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
59 slate::Matrix<scalar_type> B1( n, 1, nb, grid_p, grid_q, MPI_COMM_WORLD );
60 slate::Matrix<scalar_type> X1( n, 1, nb, grid_p, grid_q, MPI_COMM_WORLD );
61 int iters = 0;
62 // ...
63 //---------- end mixed1
64
65 A.insertLocalTiles();
66 B.insertLocalTiles();
67 X.insertLocalTiles();
68 B1.insertLocalTiles();
69 X1.insertLocalTiles();
70 random_matrix( A );
71 random_matrix( B );
72 random_matrix( B1 );
73 set( zero, X );
74 set( zero, X1 );
75 slate::Pivots pivots;
76
77 //---------- begin mixed2
78
79 // todo: simplified API
80
81 // traditional API
82 slate::gesv_mixed( A, pivots, B, X, iters );
83 slate::gesv_mixed_gmres( A, pivots, B1, X1, iters ); // only one RHS
84 //---------- end mixed2
85
86 if (mpi_rank == 0) {
87 printf( "rank %d: iters %d\n", mpi_rank, iters );
88 }
89}
90
91//------------------------------------------------------------------------------
92template <typename scalar_type>
93void test_lu_factor()
94{
95 print_func( mpi_rank );
96
97 int64_t n=1000, nrhs=100, nb=256;
98
99 //---------- begin factor1
100 slate::Matrix<scalar_type> A( n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
101 slate::Matrix<scalar_type> B( n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
102 slate::Pivots pivots;
103 // ...
104 //---------- end factor1
105
106 A.insertLocalTiles();
107 B.insertLocalTiles();
108 random_matrix( A );
109 random_matrix( B );
110
111 //---------- begin factor2
112 // simplified API
113 slate::lu_factor( A, pivots );
114 slate::lu_solve_using_factor( A, pivots, B );
115
116 // traditional API
117 slate::getrf( A, pivots ); // factor
118 slate::getrs( A, pivots, B ); // solve
119 //---------- end factor2
120}
121
122//------------------------------------------------------------------------------
123template <typename scalar_type>
124void test_lu_inverse()
125{
126 print_func( mpi_rank );
127
128 int64_t n=1000, nb=256;
129
130 //---------- begin inverse1
131 slate::Matrix<scalar_type> A( n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
132 slate::Pivots pivots;
133 // ...
134 //---------- end inverse1
135
136 A.insertLocalTiles();
137 random_matrix( A );
138
139 //---------- begin inverse2
140
141 // simplified API
142 slate::lu_factor( A, pivots );
143 slate::lu_inverse_using_factor( A, pivots );
144
145 // traditional API
146 slate::getrf( A, pivots ); // factor
147 slate::getri( A, pivots ); // inverse
148 //---------- end inverse2
149}
150
151//------------------------------------------------------------------------------
152template <typename scalar_type>
153void test_lu_cond()
154{
155 using real_t = blas::real_type<scalar_type>;
156
157 print_func( mpi_rank );
158
159 int64_t n=1000, nrhs=100, nb=256;
160
161 //---------- begin cond1
162 slate::Matrix<scalar_type> A( n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
163 slate::Pivots pivots;
164 // ...
165 //---------- end cond1
166
167 A.insertLocalTiles();
168 random_matrix( A );
169
170 //---------- begin cond2
171
172 // Compute A_norm before factoring.
173 real_t A_norm = slate::norm( slate::Norm::One, A );
174
175 // Factor using lu_factor or lu_solve.
176 slate::lu_factor( A, pivots );
177
178 // reciprocal condition number, 1 / (||A|| * ||A^{-1}||)
179 real_t A_rcond = slate::lu_rcondest_using_factor( slate::Norm::One, A, A_norm );
180 real_t A_cond = 1. / A_rcond;
181 //---------- end cond2
182
183 if (mpi_rank == 0) {
184 printf( "rank %d: norm %.2e, rcond %.2e, cond %.2e\n",
185 mpi_rank, A_norm, A_rcond, 1 / A_rcond );
186 }
187}
188
189//------------------------------------------------------------------------------
190int main( int argc, char** argv )
191{
192 try {
193 // Parse command line to set types for s, d, c, z precisions.
194 bool types[ 4 ];
195 parse_args( argc, argv, types );
196
197 int provided = 0;
198 slate_mpi_call(
199 MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) );
200 assert( provided == MPI_THREAD_MULTIPLE );
201
202 slate_mpi_call(
203 MPI_Comm_size( MPI_COMM_WORLD, &mpi_size ) );
204
205 slate_mpi_call(
206 MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank ) );
207
208 // Determine p-by-q grid for this MPI size.
209 grid_size( mpi_size, &grid_p, &grid_q );
210 if (mpi_rank == 0) {
211 printf( "mpi_size %d, grid_p %d, grid_q %d\n",
212 mpi_size, grid_p, grid_q );
213 }
214
215 // so random_matrix is different on different ranks.
216 srand( 100 * mpi_rank );
217
218 if (types[ 0 ]) {
219 test_lu< float >();
220 test_lu_factor< float >();
221 test_lu_inverse< float >();
222 test_lu_cond< float >();
223 }
224 if (mpi_rank == 0)
225 printf( "\n" );
226
227 if (types[ 1 ]) {
228 test_lu< double >();
229 test_lu_factor< double >();
230 test_lu_inverse< double >();
231 test_lu_mixed< double >();
232 test_lu_cond< double >();
233 }
234 if (mpi_rank == 0)
235 printf( "\n" );
236
237 if (types[ 2 ]) {
238 test_lu< std::complex<float> >();
239 test_lu_factor< std::complex<float> >();
240 test_lu_inverse< std::complex<float> >();
241 test_lu_cond< std::complex<float> >();
242 }
243 if (mpi_rank == 0)
244 printf( "\n" );
245
246 if (types[ 3 ]) {
247 test_lu< std::complex<double> >();
248 test_lu_factor< std::complex<double> >();
249 test_lu_inverse< std::complex<double> >();
250 test_lu_mixed< std::complex<double> >();
251 test_lu_cond< std::complex<double> >();
252 }
253
254 slate_mpi_call(
255 MPI_Finalize() );
256 }
257 catch (std::exception const& ex) {
258 fprintf( stderr, "%s", ex.what() );
259 return 1;
260 }
261 return 0;
262}
C API Example
1// slate06_linear_system_lu.c
2// Solve AX = B using LU factorization
3
4#include <slate/c_api/slate.h>
5#include <mpi.h>
6
7#include "util.h"
8
// Process-grid globals, initialized in main() after MPI_Init_thread.
int mpi_size = 0;  // total number of MPI ranks
int mpi_rank = 0;  // rank of this process
int grid_p = 0;    // rows in the p-by-q process grid
int grid_q = 0;    // columns in the p-by-q process grid
13
14//------------------------------------------------------------------------------
15void test_lu_r32()
16{
17 print_func( mpi_rank );
18
19 int64_t n=1000, nrhs=100, nb=256;
20 assert( mpi_size == grid_p*grid_q );
21 slate_Matrix_r32 A = slate_Matrix_create_r32(
22 n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
23 slate_Matrix_r32 B = slate_Matrix_create_r32(
24 n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
25 slate_Matrix_insertLocalTiles_r32( A );
26 slate_Matrix_insertLocalTiles_r32( B );
27 random_Matrix_r32( A );
28 random_Matrix_r32( B );
29
30 slate_lu_solve_r32( A, B, NULL );
31
32 slate_Matrix_destroy_r32( A );
33 slate_Matrix_destroy_r32( B );
34}
35
36//------------------------------------------------------------------------------
37void test_lu_r64()
38{
39 print_func( mpi_rank );
40
41 int64_t n=1000, nrhs=100, nb=256;
42 assert( mpi_size == grid_p*grid_q );
43 slate_Matrix_r64 A = slate_Matrix_create_r64(
44 n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
45 slate_Matrix_r64 B = slate_Matrix_create_r64(
46 n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
47 slate_Matrix_insertLocalTiles_r64( A );
48 slate_Matrix_insertLocalTiles_r64( B );
49 random_Matrix_r64( A );
50 random_Matrix_r64( B );
51
52 slate_lu_solve_r64( A, B, NULL );
53
54 slate_Matrix_destroy_r64( A );
55 slate_Matrix_destroy_r64( B );
56}
57
58//------------------------------------------------------------------------------
59void test_lu_c32()
60{
61 print_func( mpi_rank );
62
63 int64_t n=1000, nrhs=100, nb=256;
64 assert( mpi_size == grid_p*grid_q );
65 slate_Matrix_c32 A = slate_Matrix_create_c32(
66 n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
67 slate_Matrix_c32 B = slate_Matrix_create_c32(
68 n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
69 slate_Matrix_insertLocalTiles_c32( A );
70 slate_Matrix_insertLocalTiles_c32( B );
71 random_Matrix_c32( A );
72 random_Matrix_c32( B );
73
74 slate_lu_solve_c32( A, B, NULL );
75
76 slate_Matrix_destroy_c32( A );
77 slate_Matrix_destroy_c32( B );
78}
79
80//------------------------------------------------------------------------------
81void test_lu_c64()
82{
83 print_func( mpi_rank );
84
85 int64_t n=1000, nrhs=100, nb=256;
86 assert( mpi_size == grid_p*grid_q );
87 slate_Matrix_c64 A = slate_Matrix_create_c64(
88 n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
89 slate_Matrix_c64 B = slate_Matrix_create_c64(
90 n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
91 slate_Matrix_insertLocalTiles_c64( A );
92 slate_Matrix_insertLocalTiles_c64( B );
93 random_Matrix_c64( A );
94 random_Matrix_c64( B );
95
96 slate_lu_solve_c64( A, B, NULL );
97
98 slate_Matrix_destroy_c64( A );
99 slate_Matrix_destroy_c64( B );
100}
101
102//------------------------------------------------------------------------------
103void test_lu_inverse_r32()
104{
105 print_func( mpi_rank );
106
107 int64_t n=1000, nb=256;
108 assert( mpi_size == grid_p*grid_q );
109 slate_Matrix_r32 A = slate_Matrix_create_r32(
110 n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
111 slate_Matrix_insertLocalTiles_r32( A );
112 random_Matrix_r32( A );
113 slate_Pivots pivots = slate_Pivots_create();
114
115 slate_lu_factor_r32( A, pivots, NULL );
116 slate_lu_inverse_using_factor_r32( A, pivots, NULL );
117
118 slate_Matrix_destroy_r32( A );
119 slate_Pivots_destroy( pivots );
120}
121
122//------------------------------------------------------------------------------
123void test_lu_inverse_r64()
124{
125 print_func( mpi_rank );
126
127 int64_t n=1000, nb=256;
128 assert( mpi_size == grid_p*grid_q );
129 slate_Matrix_r64 A = slate_Matrix_create_r64(
130 n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
131 slate_Matrix_insertLocalTiles_r64( A );
132 random_Matrix_r64( A );
133 slate_Pivots pivots = slate_Pivots_create();
134
135 slate_lu_factor_r64( A, pivots, NULL );
136 slate_lu_inverse_using_factor_r64( A, pivots, NULL );
137
138 slate_Matrix_destroy_r64( A );
139 slate_Pivots_destroy( pivots );
140}
141
142//------------------------------------------------------------------------------
143void test_lu_inverse_c32()
144{
145 print_func( mpi_rank );
146
147 int64_t n=1000, nb=256;
148 assert( mpi_size == grid_p*grid_q );
149 slate_Matrix_c32 A = slate_Matrix_create_c32(
150 n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
151 slate_Matrix_insertLocalTiles_c32( A );
152 random_Matrix_c32( A );
153 slate_Pivots pivots = slate_Pivots_create();
154
155 slate_lu_factor_c32( A, pivots, NULL );
156 slate_lu_inverse_using_factor_c32( A, pivots, NULL );
157
158 slate_Matrix_destroy_c32( A );
159 slate_Pivots_destroy( pivots );
160}
161
162//------------------------------------------------------------------------------
163void test_lu_inverse_c64()
164{
165 print_func( mpi_rank );
166
167 int64_t n=1000, nb=256;
168 assert( mpi_size == grid_p*grid_q );
169 slate_Matrix_c64 A = slate_Matrix_create_c64(
170 n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
171 slate_Matrix_insertLocalTiles_c64( A );
172 random_Matrix_c64( A );
173 slate_Pivots pivots = slate_Pivots_create();
174
175 slate_lu_factor_c64( A, pivots, NULL );
176 slate_lu_inverse_using_factor_c64( A, pivots, NULL );
177
178 slate_Matrix_destroy_c64( A );
179 slate_Pivots_destroy( pivots );
180}
181
182//------------------------------------------------------------------------------
183int main( int argc, char** argv )
184{
185 // Parse command line to set types for s, d, c, z precisions.
186 bool types[ 4 ];
187 parse_args( argc, argv, types );
188
189 int provided = 0;
190 MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided );
191 assert( provided == MPI_THREAD_MULTIPLE );
192
193 MPI_Comm_size( MPI_COMM_WORLD, &mpi_size );
194 MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank );
195
196 // Determine p-by-q grid for this MPI size.
197 grid_size( mpi_size, &grid_p, &grid_q );
198 if (mpi_rank == 0) {
199 printf( "mpi_size %d, grid_p %d, grid_q %d\n",
200 mpi_size, grid_p, grid_q );
201 }
202
203 // so random_matrix is different on different ranks.
204 srand( 100 * mpi_rank );
205
206 if (types[ 0 ]) {
207 test_lu_r32();
208 test_lu_inverse_r32();
209 }
210 if (mpi_rank == 0)
211 printf( "\n" );
212
213 if (types[ 1 ]) {
214 test_lu_r64();
215 test_lu_inverse_r64();
216 }
217 if (mpi_rank == 0)
218 printf( "\n" );
219
220 if (types[ 2 ]) {
221 test_lu_c32();
222 test_lu_inverse_c32();
223 }
224 if (mpi_rank == 0)
225 printf( "\n" );
226
227 if (types[ 3 ]) {
228 test_lu_c64();
229 test_lu_inverse_c64();
230 }
231
232 MPI_Finalize();
233
234 return 0;
235}