Example 06: Linear Systems (LU)

This example demonstrates solving linear systems \(AX=B\) using LU factorization, where \(B\) may hold one or more right-hand sides.

Key Concepts

  1. Simple Solve: Using slate::lu_solve (gesv) for a one-step solution.

  2. Explicit Factorization: Separating factorization (lu_factor/getrf) and solve (lu_solve_using_factor/getrs).

  3. Matrix Inversion: Computing \(A^{-1}\) using lu_inverse_using_factor (getri).

  4. Mixed Precision: Using iterative refinement to solve systems with lower-precision factorization.

  5. Condition Number: Estimating the condition number of the matrix.

C++ Example

Standard LU Solve (Lines 38-41)

slate::lu_solve( A, B );        // simplified API
slate::gesv( A, pivots, B );    // traditional API

The simplest way to solve \(AX=B\).

  • A is overwritten by its LU factors.

  • B is overwritten by the solution \(X\).

  • pivots (in the traditional API) stores the pivot indices found during factorization. lu_solve manages this internally if you don’t need the pivots later.

Mixed Precision Iterative Refinement (Lines 82-83)

slate::gesv_mixed( A, pivots, B, X, iters );

Mixed precision solvers can provide a significant speedup by doing the expensive factorization in lower precision (e.g., float) and then refining the solution to high precision (e.g., double) using the original matrix.

  • A, B, X are high precision (e.g., double).

  • The internal factorization happens in low precision (e.g., float).

  • iters returns the number of refinement iterations performed.

Explicit Factorization and Solve (Lines 113-118)

slate::lu_factor( A, pivots );
slate::lu_solve_using_factor( A, pivots, B );

Sometimes you need to solve for multiple right-hand sides that arrive at different times, or you want to reuse the factors.

  1. lu_factor (getrf): Computes \(PA = LU\).

  2. lu_solve_using_factor (getrs): Solves \(AX=B\) using the pre-computed factors and pivots.

Matrix Inversion (Lines 142-147)

slate::lu_factor( A, pivots );
slate::lu_inverse_using_factor( A, pivots );

Computes the inverse of a matrix in-place.

  1. Factorize the matrix.

  2. Call lu_inverse_using_factor (getri). A is overwritten by \(A^{-1}\).

Condition Number Estimation (Lines 173-179)

real_t A_norm = slate::norm( slate::Norm::One, A );
slate::lu_factor( A, pivots );
real_t A_rcond = slate::lu_rcondest_using_factor( slate::Norm::One, A, A_norm );

Estimates the reciprocal condition number \(1/\kappa(A)\).

  1. Compute the norm of the original matrix before factorization.

  2. Factorize the matrix.

  3. Call lu_rcondest_using_factor. This estimates \(\|A^{-1}\|\) cheaply using the factors and combines it with the provided \(\|A\|\) to form the estimate \(1/(\|A\| \, \|A^{-1}\|)\).

  1// ex06_linear_system_lu.cc
  2// Solve AX = B using LU factorization
  3
  4/// !!!   Lines between `//---------- begin label`          !!!
  5/// !!!             and `//---------- end label`            !!!
  6/// !!!   are included in the SLATE Users' Guide.           !!!
  7
  8#include <slate/slate.hh>
  9
 10#include "util.hh"
 11
 12int mpi_size = 0;
 13int mpi_rank = 0;
 14int grid_p = 0;
 15int grid_q = 0;
 16
 17//------------------------------------------------------------------------------
 18template <typename scalar_type>
 19void test_lu()
 20{
 21    print_func( mpi_rank );
 22
 23    int64_t n=1000, nrhs=100, nb=256;
 24
 25    //---------- begin solve1
 26    slate::Matrix<scalar_type> A( n, n,    nb, grid_p, grid_q, MPI_COMM_WORLD );
 27    slate::Matrix<scalar_type> B( n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
 28    // ...
 29    //---------- end solve1
 30
 31    A.insertLocalTiles();
 32    B.insertLocalTiles();
 33    random_matrix( A );
 34    random_matrix( B );
 35
 36    //---------- begin solve2
 37
 38    slate::lu_solve( A, B );        // simplified API
 39
 40    slate::Pivots pivots;
 41    slate::gesv( A, pivots, B );    // traditional API
 42    //---------- end solve2
 43}
 44
 45//------------------------------------------------------------------------------
 46template <typename scalar_type>
 47void test_lu_mixed()
 48{
 49    print_func( mpi_rank );
 50
 51    int64_t n=1000, nrhs=100, nb=256;
 52    scalar_type zero = 0;
 53
 54    //---------- begin mixed1
 55    // mixed precision: factor in single, iterative refinement to double
 56    slate::Matrix<scalar_type> A( n, n,    nb, grid_p, grid_q, MPI_COMM_WORLD );
 57    slate::Matrix<scalar_type> B( n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
 58    slate::Matrix<scalar_type> X( n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
 59    slate::Matrix<scalar_type> B1( n, 1,   nb, grid_p, grid_q, MPI_COMM_WORLD );
 60    slate::Matrix<scalar_type> X1( n, 1,   nb, grid_p, grid_q, MPI_COMM_WORLD );
 61    int iters = 0;
 62    // ...
 63    //---------- end mixed1
 64
 65    A.insertLocalTiles();
 66    B.insertLocalTiles();
 67    X.insertLocalTiles();
 68    B1.insertLocalTiles();
 69    X1.insertLocalTiles();
 70    random_matrix( A );
 71    random_matrix( B );
 72    random_matrix( B1 );
 73    set( zero, X );
 74    set( zero, X1 );
 75    slate::Pivots pivots;
 76
 77    //---------- begin mixed2
 78
 79    // todo: simplified API
 80
 81    // traditional API
 82    slate::gesv_mixed( A, pivots, B, X, iters );
 83    slate::gesv_mixed_gmres( A, pivots, B1, X1, iters );  // only one RHS
 84    //---------- end mixed2
 85
 86    if (mpi_rank == 0) {
 87        printf( "rank %d: iters %d\n", mpi_rank, iters );
 88    }
 89}
 90
 91//------------------------------------------------------------------------------
 92template <typename scalar_type>
 93void test_lu_factor()
 94{
 95    print_func( mpi_rank );
 96
 97    int64_t n=1000, nrhs=100, nb=256;
 98
 99    //---------- begin factor1
100    slate::Matrix<scalar_type> A( n, n,    nb, grid_p, grid_q, MPI_COMM_WORLD );
101    slate::Matrix<scalar_type> B( n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
102    slate::Pivots pivots;
103    // ...
104    //---------- end factor1
105
106    A.insertLocalTiles();
107    B.insertLocalTiles();
108    random_matrix( A );
109    random_matrix( B );
110
111    //---------- begin factor2
112    // simplified API
113    slate::lu_factor( A, pivots );
114    slate::lu_solve_using_factor( A, pivots, B );
115
116    // traditional API
117    slate::getrf( A, pivots );     // factor
118    slate::getrs( A, pivots, B );  // solve
119    //---------- end factor2
120}
121
122//------------------------------------------------------------------------------
123template <typename scalar_type>
124void test_lu_inverse()
125{
126    print_func( mpi_rank );
127
128    int64_t n=1000, nb=256;
129
130    //---------- begin inverse1
131    slate::Matrix<scalar_type> A( n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
132    slate::Pivots pivots;
133    // ...
134    //---------- end inverse1
135
136    A.insertLocalTiles();
137    random_matrix( A );
138
139    //---------- begin inverse2
140
141    // simplified API
142    slate::lu_factor( A, pivots );
143    slate::lu_inverse_using_factor( A, pivots );
144
145    // traditional API
146    slate::getrf( A, pivots );  // factor
147    slate::getri( A, pivots );  // inverse
148    //---------- end inverse2
149}
150
151//------------------------------------------------------------------------------
152template <typename scalar_type>
153void test_lu_cond()
154{
155    using real_t = blas::real_type<scalar_type>;
156
157    print_func( mpi_rank );
158
159    int64_t n=1000, nrhs=100, nb=256;
160
161    //---------- begin cond1
162    slate::Matrix<scalar_type> A( n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
163    slate::Pivots pivots;
164    // ...
165    //---------- end cond1
166
167    A.insertLocalTiles();
168    random_matrix( A );
169
170    //---------- begin cond2
171
172    // Compute A_norm before factoring.
173    real_t A_norm = slate::norm( slate::Norm::One, A );
174
175    // Factor using lu_factor or lu_solve.
176    slate::lu_factor( A, pivots );
177
178    // reciprocal condition number, 1 / (||A|| * ||A^{-1}||)
179    real_t A_rcond = slate::lu_rcondest_using_factor( slate::Norm::One, A, A_norm );
180    real_t A_cond = 1. / A_rcond;
181    //---------- end cond2
182
183    if (mpi_rank == 0) {
184        printf( "rank %d: norm %.2e, rcond %.2e, cond %.2e\n",
185                mpi_rank, A_norm, A_rcond, 1 / A_rcond );
186    }
187}
188
189//------------------------------------------------------------------------------
190int main( int argc, char** argv )
191{
192    try {
193        // Parse command line to set types for s, d, c, z precisions.
194        bool types[ 4 ];
195        parse_args( argc, argv, types );
196
197        int provided = 0;
198        slate_mpi_call(
199            MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) );
200        assert( provided == MPI_THREAD_MULTIPLE );
201
202        slate_mpi_call(
203            MPI_Comm_size( MPI_COMM_WORLD, &mpi_size ) );
204
205        slate_mpi_call(
206            MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank ) );
207
208        // Determine p-by-q grid for this MPI size.
209        grid_size( mpi_size, &grid_p, &grid_q );
210        if (mpi_rank == 0) {
211            printf( "mpi_size %d, grid_p %d, grid_q %d\n",
212                    mpi_size, grid_p, grid_q );
213        }
214
215        // so random_matrix is different on different ranks.
216        srand( 100 * mpi_rank );
217
218        if (types[ 0 ]) {
219            test_lu< float >();
220            test_lu_factor< float >();
221            test_lu_inverse< float >();
222            test_lu_cond< float >();
223        }
224        if (mpi_rank == 0)
225            printf( "\n" );
226
227        if (types[ 1 ]) {
228            test_lu< double >();
229            test_lu_factor< double >();
230            test_lu_inverse< double >();
231            test_lu_mixed< double >();
232            test_lu_cond< double >();
233        }
234        if (mpi_rank == 0)
235            printf( "\n" );
236
237        if (types[ 2 ]) {
238            test_lu< std::complex<float> >();
239            test_lu_factor< std::complex<float> >();
240            test_lu_inverse< std::complex<float> >();
241            test_lu_cond< std::complex<float> >();
242        }
243        if (mpi_rank == 0)
244            printf( "\n" );
245
246        if (types[ 3 ]) {
247            test_lu< std::complex<double> >();
248            test_lu_factor< std::complex<double> >();
249            test_lu_inverse< std::complex<double> >();
250            test_lu_mixed< std::complex<double> >();
251            test_lu_cond< std::complex<double> >();
252        }
253
254        slate_mpi_call(
255            MPI_Finalize() );
256    }
257    catch (std::exception const& ex) {
258        fprintf( stderr, "%s", ex.what() );
259        return 1;
260    }
261    return 0;
262}

C API Example

  1// slate06_linear_system_lu.c
  2// Solve AX = B using LU factorization
  3
  4#include <slate/c_api/slate.h>
  5#include <mpi.h>
  6
  7#include "util.h"
  8
  9int mpi_size = 0;
 10int mpi_rank = 0;
 11int grid_p = 0;
 12int grid_q = 0;
 13
 14//------------------------------------------------------------------------------
 15void test_lu_r32()
 16{
 17    print_func( mpi_rank );
 18
 19    int64_t n=1000, nrhs=100, nb=256;
 20    assert( mpi_size == grid_p*grid_q );
 21    slate_Matrix_r32 A = slate_Matrix_create_r32(
 22        n, n,    nb, grid_p, grid_q, MPI_COMM_WORLD );
 23    slate_Matrix_r32 B = slate_Matrix_create_r32(
 24        n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
 25    slate_Matrix_insertLocalTiles_r32( A );
 26    slate_Matrix_insertLocalTiles_r32( B );
 27    random_Matrix_r32( A );
 28    random_Matrix_r32( B );
 29
 30    slate_lu_solve_r32( A, B, NULL );
 31
 32    slate_Matrix_destroy_r32( A );
 33    slate_Matrix_destroy_r32( B );
 34}
 35
 36//------------------------------------------------------------------------------
 37void test_lu_r64()
 38{
 39    print_func( mpi_rank );
 40
 41    int64_t n=1000, nrhs=100, nb=256;
 42    assert( mpi_size == grid_p*grid_q );
 43    slate_Matrix_r64 A = slate_Matrix_create_r64(
 44        n, n,    nb, grid_p, grid_q, MPI_COMM_WORLD );
 45    slate_Matrix_r64 B = slate_Matrix_create_r64(
 46        n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
 47    slate_Matrix_insertLocalTiles_r64( A );
 48    slate_Matrix_insertLocalTiles_r64( B );
 49    random_Matrix_r64( A );
 50    random_Matrix_r64( B );
 51
 52    slate_lu_solve_r64( A, B, NULL );
 53
 54    slate_Matrix_destroy_r64( A );
 55    slate_Matrix_destroy_r64( B );
 56}
 57
 58//------------------------------------------------------------------------------
 59void test_lu_c32()
 60{
 61    print_func( mpi_rank );
 62
 63    int64_t n=1000, nrhs=100, nb=256;
 64    assert( mpi_size == grid_p*grid_q );
 65    slate_Matrix_c32 A = slate_Matrix_create_c32(
 66        n, n,    nb, grid_p, grid_q, MPI_COMM_WORLD );
 67    slate_Matrix_c32 B = slate_Matrix_create_c32(
 68        n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
 69    slate_Matrix_insertLocalTiles_c32( A );
 70    slate_Matrix_insertLocalTiles_c32( B );
 71    random_Matrix_c32( A );
 72    random_Matrix_c32( B );
 73
 74    slate_lu_solve_c32( A, B, NULL );
 75
 76    slate_Matrix_destroy_c32( A );
 77    slate_Matrix_destroy_c32( B );
 78}
 79
 80//------------------------------------------------------------------------------
 81void test_lu_c64()
 82{
 83    print_func( mpi_rank );
 84
 85    int64_t n=1000, nrhs=100, nb=256;
 86    assert( mpi_size == grid_p*grid_q );
 87    slate_Matrix_c64 A = slate_Matrix_create_c64(
 88        n, n,    nb, grid_p, grid_q, MPI_COMM_WORLD );
 89    slate_Matrix_c64 B = slate_Matrix_create_c64(
 90        n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
 91    slate_Matrix_insertLocalTiles_c64( A );
 92    slate_Matrix_insertLocalTiles_c64( B );
 93    random_Matrix_c64( A );
 94    random_Matrix_c64( B );
 95
 96    slate_lu_solve_c64( A, B, NULL );
 97
 98    slate_Matrix_destroy_c64( A );
 99    slate_Matrix_destroy_c64( B );
100}
101
102//------------------------------------------------------------------------------
103void test_lu_inverse_r32()
104{
105    print_func( mpi_rank );
106
107    int64_t n=1000, nb=256;
108    assert( mpi_size == grid_p*grid_q );
109    slate_Matrix_r32 A = slate_Matrix_create_r32(
110        n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
111    slate_Matrix_insertLocalTiles_r32( A );
112    random_Matrix_r32( A );
113    slate_Pivots pivots = slate_Pivots_create();
114
115    slate_lu_factor_r32( A, pivots, NULL );
116    slate_lu_inverse_using_factor_r32( A, pivots, NULL );
117
118    slate_Matrix_destroy_r32( A );
119    slate_Pivots_destroy( pivots );
120}
121
122//------------------------------------------------------------------------------
123void test_lu_inverse_r64()
124{
125    print_func( mpi_rank );
126
127    int64_t n=1000, nb=256;
128    assert( mpi_size == grid_p*grid_q );
129    slate_Matrix_r64 A = slate_Matrix_create_r64(
130        n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
131    slate_Matrix_insertLocalTiles_r64( A );
132    random_Matrix_r64( A );
133    slate_Pivots pivots = slate_Pivots_create();
134
135    slate_lu_factor_r64( A, pivots, NULL );
136    slate_lu_inverse_using_factor_r64( A, pivots, NULL );
137
138    slate_Matrix_destroy_r64( A );
139    slate_Pivots_destroy( pivots );
140}
141
142//------------------------------------------------------------------------------
143void test_lu_inverse_c32()
144{
145    print_func( mpi_rank );
146
147    int64_t n=1000, nb=256;
148    assert( mpi_size == grid_p*grid_q );
149    slate_Matrix_c32 A = slate_Matrix_create_c32(
150        n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
151    slate_Matrix_insertLocalTiles_c32( A );
152    random_Matrix_c32( A );
153    slate_Pivots pivots = slate_Pivots_create();
154
155    slate_lu_factor_c32( A, pivots, NULL );
156    slate_lu_inverse_using_factor_c32( A, pivots, NULL );
157
158    slate_Matrix_destroy_c32( A );
159    slate_Pivots_destroy( pivots );
160}
161
162//------------------------------------------------------------------------------
163void test_lu_inverse_c64()
164{
165    print_func( mpi_rank );
166
167    int64_t n=1000, nb=256;
168    assert( mpi_size == grid_p*grid_q );
169    slate_Matrix_c64 A = slate_Matrix_create_c64(
170        n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
171    slate_Matrix_insertLocalTiles_c64( A );
172    random_Matrix_c64( A );
173    slate_Pivots pivots = slate_Pivots_create();
174
175    slate_lu_factor_c64( A, pivots, NULL );
176    slate_lu_inverse_using_factor_c64( A, pivots, NULL );
177
178    slate_Matrix_destroy_c64( A );
179    slate_Pivots_destroy( pivots );
180}
181
182//------------------------------------------------------------------------------
183int main( int argc, char** argv )
184{
185    // Parse command line to set types for s, d, c, z precisions.
186    bool types[ 4 ];
187    parse_args( argc, argv, types );
188
189    int provided = 0;
190    MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided );
191    assert( provided == MPI_THREAD_MULTIPLE );
192
193    MPI_Comm_size( MPI_COMM_WORLD, &mpi_size );
194    MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank );
195
196    // Determine p-by-q grid for this MPI size.
197    grid_size( mpi_size, &grid_p, &grid_q );
198    if (mpi_rank == 0) {
199        printf( "mpi_size %d, grid_p %d, grid_q %d\n",
200                mpi_size, grid_p, grid_q );
201    }
202
203    // so random_matrix is different on different ranks.
204    srand( 100 * mpi_rank );
205
206    if (types[ 0 ]) {
207        test_lu_r32();
208        test_lu_inverse_r32();
209    }
210    if (mpi_rank == 0)
211        printf( "\n" );
212
213    if (types[ 1 ]) {
214        test_lu_r64();
215        test_lu_inverse_r64();
216    }
217    if (mpi_rank == 0)
218        printf( "\n" );
219
220    if (types[ 2 ]) {
221        test_lu_c32();
222        test_lu_inverse_c32();
223    }
224    if (mpi_rank == 0)
225        printf( "\n" );
226
227    if (types[ 3 ]) {
228        test_lu_c64();
229        test_lu_inverse_c64();
230    }
231
232    MPI_Finalize();
233
234    return 0;
235}