Example 07: Linear Systems (Cholesky)

This example demonstrates solving symmetric/Hermitian positive definite linear systems using Cholesky factorization.

Key Concepts

  1. Cholesky Solve: Using slate::chol_solve (posv) for a one-step solution of \(AX=B\) where \(A\) is symmetric/Hermitian positive definite.

  2. Explicit Factorization: Separating factorization (chol_factor/potrf) and solve (chol_solve_using_factor/potrs).

  3. Matrix Inversion: Computing \(A^{-1}\) using chol_inverse_using_factor (potri).

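  4. Mixed Precision: Factoring in lower precision and recovering full accuracy by iterative refinement (posv_mixed) or its GMRES-based variant (posv_mixed_gmres).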

  5. Condition Number: Estimating the reciprocal condition number (in the 1-norm) of a Hermitian positive definite matrix from its Cholesky factor.

C++ Example

Cholesky Solve (Lines 38-40)

slate::chol_solve( A, B );  // simplified API
slate::posv( A, B );        // traditional API

Solves \(AX=B\) for a symmetric/Hermitian positive definite matrix \(A\) and one or more right-hand sides (the columns of \(B\)).

  • Requires A to be defined as HermitianMatrix or SymmetricMatrix.

  • A is overwritten by the Cholesky factor \(L\) (if Uplo::Lower) or \(U\) (if Uplo::Upper).

  • B is overwritten by the solution.

  • Cholesky factorization needs about \(n^3/3\) floating-point operations, versus \(2n^3/3\) for LU, so it is roughly twice as fast for matrices where it applies.

Mixed Precision (Lines 80-81)

slate::posv_mixed( A, B, X, iters );

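Similar to the LU case, these routines factor A in lower precision (e.g., single when the data is double) and iteratively refine the solution X to the working precision. The example also calls posv_mixed_gmres (line 81), a GMRES-based variant, with a single right-hand side. iters receives the iteration count reported by the solver. Both routines require A to be positive definite.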

Explicit Factorization (Lines 106-111)

slate::chol_factor( A );
slate::chol_solve_using_factor( A, B );
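  1. chol_factor (potrf): Computes the Cholesky factorization \(A = LL^H\) (or \(A = U^H U\) when A is stored with Uplo::Upper), overwriting A with the factor.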

  2. chol_solve_using_factor (potrs): Solves \(AX=B\) using the factor computed by chol_factor; B is overwritten by the solution.

Inversion (Lines 134-139)

slate::chol_factor( A );
slate::chol_inverse_using_factor( A );

Computes \(A^{-1}\) for a positive definite matrix.

  1. Factorize.

  2. Call chol_inverse_using_factor (potri). A is overwritten by the inverse.

Condition Number (Lines 165-171)

real_t A_norm = slate::norm( slate::Norm::One, A );
slate::chol_factor( A );
real_t rcond = slate::chol_rcondest_using_factor( slate::Norm::One, A, A_norm );

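Standard condition number estimation flow: Norm -> Factor -> Estimate. The norm must be computed first, because chol_factor overwrites A with its Cholesky factor. chol_rcondest_using_factor returns an estimate of the reciprocal condition number \(1/(\|A\|_1 \, \|A^{-1}\|_1)\); the condition number is its reciprocal.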

  1// ex07_linear_system_cholesky.cc
  2// Solve AX = B using Cholesky factorization
  3
  4/// !!!   Lines between `//---------- begin label`          !!!
  5/// !!!             and `//---------- end label`            !!!
  6/// !!!   are included in the SLATE Users' Guide.           !!!
  7
  8#include <slate/slate.hh>
  9
 10#include "util.hh"
 11
 12int mpi_size { 0 };  // number of MPI processes; set in main()
 13int mpi_rank { 0 };  // rank of this process in MPI_COMM_WORLD; set in main()
 14int grid_p { 0 };    // rows (p) of the p-by-q process grid; set by grid_size()
 15int grid_q { 0 };    // columns (q) of the p-by-q process grid; set by grid_size()
 16
 17//------------------------------------------------------------------------------
 18template <typename scalar_type>
 19void test_cholesky()  // Solve A X = B in one step, once with each API.
 20{
 21    print_func( mpi_rank );
 22
 23    int64_t n=1000, nrhs=100, nb=256;
 24
 25    //---------- begin solve1
 26    slate::HermitianMatrix<scalar_type>
 27        A( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
 28    slate::Matrix<scalar_type> B( n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
 29    // ...
 30    //---------- end solve1
 31
 32    A.insertLocalTiles();
 33    B.insertLocalTiles();
 34    random_matrix_diag_dominant( A );
 35    random_matrix( B );
 36
 37    //---------- begin solve2
 38    slate::chol_solve( A, B );  // simplified API
 39    random_matrix_diag_dominant( A );  random_matrix( B );  // A, B overwritten above; regenerate
 40    slate::posv( A, B );        // traditional API
 41    //---------- end solve2
 42}
 43
 44//------------------------------------------------------------------------------
 45template <typename scalar_type>
 46void test_cholesky_mixed()  // Mixed-precision solves of A X = B (refinement and GMRES).
 47{
 48    print_func( mpi_rank );
 49
 50    int64_t n=1000, nrhs=100, nb=256;
 51    scalar_type zero = 0;
 52
 53    //---------- begin mixed1
 54    // mixed precision: factor in single, iterative refinement to double
 55    slate::HermitianMatrix<scalar_type>
 56        A( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
 57    slate::Matrix<scalar_type> B( n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
 58    slate::Matrix<scalar_type> X( n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
 59    slate::Matrix<scalar_type> B1( n, 1,   nb, grid_p, grid_q, MPI_COMM_WORLD );
 60    slate::Matrix<scalar_type> X1( n, 1,   nb, grid_p, grid_q, MPI_COMM_WORLD );
 61    int iters = 0, iters_gmres = 0;  // iteration counts from posv_mixed / posv_mixed_gmres
 62    //---------- end mixed1
 63
 64    A.insertLocalTiles();
 65    B.insertLocalTiles();
 66    X.insertLocalTiles();
 67    B1.insertLocalTiles();
 68    X1.insertLocalTiles();
 69    random_matrix_diag_dominant( A );
 70    random_matrix( B );
 71    random_matrix( B1 );
 72    slate::set( zero, X );
 73    slate::set( zero, X1 );
 74
 75    //---------- begin mixed2
 76
 77    // todo: simplified API
 78
 79    // traditional API
 80    slate::posv_mixed( A, B, X, iters );
 81    slate::posv_mixed_gmres( A, B1, X1, iters_gmres );  // only one RHS
 82    //---------- end mixed2
 83
 84    if (mpi_rank == 0) {
 85        printf( "rank %d: iters %d, iters_gmres %d\n", mpi_rank, iters, iters_gmres );
 86    }
 87}
 88
 89//------------------------------------------------------------------------------
 90template <typename scalar_type>
 91void test_cholesky_factor()  // Factor A, then solve A X = B with the factor, via each API.
 92{
 93    print_func( mpi_rank );
 94
 95    int64_t n=1000, nrhs=100, nb=256;
 96
 97    slate::HermitianMatrix<scalar_type>
 98        A( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
 99    slate::Matrix<scalar_type> B( n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
100    A.insertLocalTiles();
101    B.insertLocalTiles();
102    random_matrix_diag_dominant( A );
103    random_matrix( B );
104
105    // simplified API
106    slate::chol_factor( A );
107    slate::chol_solve_using_factor( A, B );
108    random_matrix_diag_dominant( A );  random_matrix( B );  // A, B overwritten above; regenerate
109    // traditional API
110    slate::potrf( A );     // factor
111    slate::potrs( A, B );  // solve
112}
113
114//------------------------------------------------------------------------------
115template <typename scalar_type>
116void test_cholesky_inverse()  // Invert A in place from its Cholesky factor, via each API.
117{
118    print_func( mpi_rank );
119
120    int64_t n=1000, nb=256;
121
122    //---------- begin inverse1
123    slate::HermitianMatrix<scalar_type>
124        A( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
125    // ...
126    //---------- end inverse1
127
128    A.insertLocalTiles();
129    random_matrix_diag_dominant( A );
130
131    //---------- begin inverse2
132
133    // simplified API
134    slate::chol_factor( A );
135    slate::chol_inverse_using_factor( A );
136    random_matrix_diag_dominant( A );  // A now holds A^{-1}; regenerate
137    // traditional API
138    slate::potrf( A );  // factor
139    slate::potri( A );  // inverse
140    //---------- end inverse2
141}
142
143//------------------------------------------------------------------------------
144template <typename scalar_type>
145void test_cholesky_cond()  // Estimate the 1-norm condition number of A from its Cholesky factor.
146{
147    using real_t = blas::real_type<scalar_type>;
148
149    print_func( mpi_rank );
150
151    int64_t n=1000, nb=256;
152
153    //---------- begin cond1
154    slate::HermitianMatrix<scalar_type>
155        A( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
156    // ...
157    //---------- end cond1
158
159    A.insertLocalTiles();
160    random_matrix_diag_dominant( A );
161
162    //---------- begin cond2
163
164    // Compute A_norm before factoring; chol_factor overwrites A with its factor.
165    real_t A_norm = slate::norm( slate::Norm::One, A );
166
167    // Factor using chol_factor or chol_solve.
168    slate::chol_factor( A );
169
170    // reciprocal condition number, 1 / (||A|| * ||A^{-1}||)
171    real_t A_rcond = slate::chol_rcondest_using_factor( slate::Norm::One, A, A_norm );
172    real_t A_cond = 1. / A_rcond;
173    //---------- end cond2
174
175    if (mpi_rank == 0) {
176        printf( "rank %d: norm %.2e, rcond %.2e, cond %.2e\n",
177                mpi_rank, A_norm, A_rcond, A_cond );
178    }
179}
180
181//------------------------------------------------------------------------------
182int main( int argc, char** argv )  // Runs the examples for the precisions chosen on the command line.
183{
184    try {
185        // Parse command line to set types for s, d, c, z precisions.
186        bool types[ 4 ] = { false, false, false, false };
187        parse_args( argc, argv, types );
188
189        int provided = 0;
190        slate_mpi_call(
191            MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) );
192        assert( provided == MPI_THREAD_MULTIPLE );  // NOTE: compiled out if NDEBUG is defined
193
194        slate_mpi_call(
195            MPI_Comm_size( MPI_COMM_WORLD, &mpi_size ) );
196
197        slate_mpi_call(
198            MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank ) );
199
200        // Determine p-by-q grid for this MPI size.
201        grid_size( mpi_size, &grid_p, &grid_q );
202        if (mpi_rank == 0) {
203            printf( "mpi_size %d, grid_p %d, grid_q %d\n",
204                    mpi_size, grid_p, grid_q );
205        }
206
207        // so random_matrix is different on different ranks.
208        srand( 100 * mpi_rank );
209
210        if (types[ 0 ]) {
211            test_cholesky< float >();
212            test_cholesky_factor< float >();
213            test_cholesky_inverse< float >();
214            test_cholesky_cond< float >();
215        }
216        if (mpi_rank == 0)
217            printf( "\n" );
218
219        if (types[ 1 ]) {
220            test_cholesky< double >();
221            test_cholesky_factor< double >();
222            test_cholesky_inverse< double >();
223            test_cholesky_mixed< double >();  // mixed precision: double-precision types only
224            test_cholesky_cond< double >();
225        }
226        if (mpi_rank == 0)
227            printf( "\n" );
228
229        if (types[ 2 ]) {
230            test_cholesky< std::complex<float> >();
231            test_cholesky_factor< std::complex<float> >();
232            test_cholesky_inverse< std::complex<float> >();
233            test_cholesky_cond< std::complex<float> >();
234        }
235        if (mpi_rank == 0)
236            printf( "\n" );
237
238        if (types[ 3 ]) {
239            test_cholesky< std::complex<double> >();
240            test_cholesky_factor< std::complex<double> >();
241            test_cholesky_inverse< std::complex<double> >();
242            test_cholesky_mixed< std::complex<double> >();
243            test_cholesky_cond< std::complex<double> >();
244        }
245
246        slate_mpi_call(
247            MPI_Finalize() );
248    }
249    catch (std::exception const& ex) {
250        fprintf( stderr, "%s\n", ex.what() );
251        return 1;
252    }
253    return 0;
254}