Example 07: Linear Systems (Cholesky)
This example demonstrates solving symmetric/Hermitian positive definite linear systems using Cholesky factorization.
Key Concepts
Cholesky Solve: Using slate::chol_solve (posv) for a one-step solution of \(AX=B\) where \(A\) is positive definite.
Explicit Factorization: Separating factorization (chol_factor/potrf) and solve (chol_solve_using_factor/potrs).
Matrix Inversion: Computing \(A^{-1}\) using chol_inverse_using_factor (potri).
Mixed Precision: Using iterative refinement (posv_mixed, and its GMRES variant posv_mixed_gmres).
Condition Number: Estimating the condition number of a Hermitian positive definite matrix.
C++ Example
Cholesky Solve (Lines 38-40)
slate::chol_solve( A, B ); // simplified API
slate::posv( A, B ); // traditional API
Solves \(AX=B\) for a symmetric/Hermitian positive definite \(A\).
Requires A to be defined as HermitianMatrix or SymmetricMatrix.
A is overwritten by the Cholesky factor \(L\) (if Uplo::Lower) or \(U\) (if Uplo::Upper).
B is overwritten by the solution.
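Cholesky is roughly twice as fast as LU factorization for applicable matrices (about \(n^3/3\) versus \(2n^3/3\) floating-point operations). The two calls are alternative spellings of the same operation. Because each call overwrites A and B, the example's posv call operates on the factor and solution left by chol_solve.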
Mixed Precision (Lines 80-81)
slate::posv_mixed( A, B, X, iters );
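Similar to the LU case, this routine factors A in lower precision and iteratively refines the solution X to the working (higher) precision; A must be positive definite. The example also calls posv_mixed_gmres, a GMRES-based variant that accepts only a single right-hand side.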
Explicit Factorization (Lines 106-111)
slate::chol_factor( A );
slate::chol_solve_using_factor( A, B );
chol_factor (potrf): Computes \(A = LL^H\) for Uplo::Lower (or \(A = U^H U\) for Uplo::Upper), overwriting A with the factor.
chol_solve_using_factor (potrs): Solves using the factors.
Inversion (Lines 134-139)
slate::chol_factor( A );
slate::chol_inverse_using_factor( A );
Computes \(A^{-1}\) for a positive definite matrix.
First, factor A with chol_factor (potrf).
Then call chol_inverse_using_factor (potri); A is overwritten by the inverse.
Condition Number (Lines 165-171)
real_t A_norm = slate::norm( slate::Norm::One, A );
slate::chol_factor( A );
real_t rcond = slate::chol_rcondest_using_factor( slate::Norm::One, A, A_norm );
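Standard condition number estimation flow: Norm -> Factor -> Estimate. The norm must be computed before factoring, because chol_factor overwrites A with its Cholesky factor.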
// ex07_linear_system_cholesky.cc
// Solve AX = B using Cholesky factorization

/// !!! Lines between `//---------- begin label` !!!
/// !!! and `//---------- end label` !!!
/// !!! are included in the SLATE Users' Guide. !!!

8#include <slate/slate.hh>
9
10#include "util.hh"
11
// MPI and process-grid state shared by every test routine; assigned in main().
int mpi_size { 0 };  // number of ranks in MPI_COMM_WORLD
int mpi_rank { 0 };  // rank of this process in MPI_COMM_WORLD
int grid_p   { 0 };  // rows of the p-by-q process grid
int grid_q   { 0 };  // columns of the p-by-q process grid

17//------------------------------------------------------------------------------
18template <typename scalar_type>
19void test_cholesky()
20{
21 print_func( mpi_rank );
22
23 int64_t n=1000, nrhs=100, nb=256;
24
25 //---------- begin solve1
26 slate::HermitianMatrix<scalar_type>
27 A( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
28 slate::Matrix<scalar_type> B( n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
29 // ...
30 //---------- end solve1
31
32 A.insertLocalTiles();
33 B.insertLocalTiles();
34 random_matrix_diag_dominant( A );
35 random_matrix( B );
36
37 //---------- begin solve2
38 slate::chol_solve( A, B ); // simplified API
39
40 slate::posv( A, B ); // traditional API
41 //---------- end solve2
42}
43
44//------------------------------------------------------------------------------
45template <typename scalar_type>
46void test_cholesky_mixed()
47{
48 print_func( mpi_rank );
49
50 int64_t n=1000, nrhs=100, nb=256;
51 scalar_type zero = 0;
52
53 //---------- begin mixed1
54 // mixed precision: factor in single, iterative refinement to double
55 slate::HermitianMatrix<scalar_type>
56 A( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
57 slate::Matrix<scalar_type> B( n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
58 slate::Matrix<scalar_type> X( n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
59 slate::Matrix<scalar_type> B1( n, 1, nb, grid_p, grid_q, MPI_COMM_WORLD );
60 slate::Matrix<scalar_type> X1( n, 1, nb, grid_p, grid_q, MPI_COMM_WORLD );
61 int iters = 0;
62 //---------- end mixed1
63
64 A.insertLocalTiles();
65 B.insertLocalTiles();
66 X.insertLocalTiles();
67 B1.insertLocalTiles();
68 X1.insertLocalTiles();
69 random_matrix_diag_dominant( A );
70 random_matrix( B );
71 random_matrix( B1 );
72 slate::set( zero, X );
73 slate::set( zero, X1 );
74
75 //---------- begin mixed2
76
77 // todo: simplified API
78
79 // traditional API
80 slate::posv_mixed( A, B, X, iters );
81 slate::posv_mixed_gmres( A, B1, X1, iters ); // only one RHS
82 //---------- end mixed2
83
84 if (mpi_rank == 0) {
85 printf( "rank %d: iters %d\n", mpi_rank, iters );
86 }
87}
88
89//------------------------------------------------------------------------------
90template <typename scalar_type>
91void test_cholesky_factor()
92{
93 print_func( mpi_rank );
94
95 int64_t n=1000, nrhs=100, nb=256;
96
97 slate::HermitianMatrix<scalar_type>
98 A( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
99 slate::Matrix<scalar_type> B( n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
100 A.insertLocalTiles();
101 B.insertLocalTiles();
102 random_matrix_diag_dominant( A );
103 random_matrix( B );
104
105 // simplified API
106 slate::chol_factor( A );
107 slate::chol_solve_using_factor( A, B );
108
109 // traditional API
110 slate::potrf( A ); // factor
111 slate::potrs( A, B ); // solve
112}
113
114//------------------------------------------------------------------------------
115template <typename scalar_type>
116void test_cholesky_inverse()
117{
118 print_func( mpi_rank );
119
120 int64_t n=1000, nb=256;
121
122 //---------- begin inverse1
123 slate::HermitianMatrix<scalar_type>
124 A( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
125 // ...
126 //---------- end inverse1
127
128 A.insertLocalTiles();
129 random_matrix_diag_dominant( A );
130
131 //---------- begin inverse2
132
133 // simplified API
134 slate::chol_factor( A );
135 slate::chol_inverse_using_factor( A );
136
137 // traditional API
138 slate::potrf( A ); // factor
139 slate::potri( A ); // inverse
140 //---------- end inverse2
141}
142
143//------------------------------------------------------------------------------
144template <typename scalar_type>
145void test_cholesky_cond()
146{
147 using real_t = blas::real_type<scalar_type>;
148
149 print_func( mpi_rank );
150
151 int64_t n=1000, nrhs=100, nb=256;
152
153 //---------- begin cond1
154 slate::HermitianMatrix<scalar_type>
155 A( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
156 // ...
157 //---------- end cond1
158
159 A.insertLocalTiles();
160 random_matrix_diag_dominant( A );
161
162 //---------- begin cond2
163
164 // Compute A_norm before factoring.
165 real_t A_norm = slate::norm( slate::Norm::One, A );
166
167 // Factor using chol_factor or chol_solve.
168 slate::chol_factor( A );
169
170 // reciprocal condition number, 1 / (||A|| * ||A^{-1}||)
171 real_t A_rcond = slate::chol_rcondest_using_factor( slate::Norm::One, A, A_norm );
172 real_t A_cond = 1. / A_rcond;
173 //---------- end cond2
174
175 if (mpi_rank == 0) {
176 printf( "rank %d: norm %.2e, rcond %.2e, cond %.2e\n",
177 mpi_rank, A_norm, A_rcond, 1 / A_rcond );
178 }
179}
180
181//------------------------------------------------------------------------------
182int main( int argc, char** argv )
183{
184 try {
185 // Parse command line to set types for s, d, c, z precisions.
186 bool types[ 4 ];
187 parse_args( argc, argv, types );
188
189 int provided = 0;
190 slate_mpi_call(
191 MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) );
192 assert( provided == MPI_THREAD_MULTIPLE );
193
194 slate_mpi_call(
195 MPI_Comm_size( MPI_COMM_WORLD, &mpi_size ) );
196
197 slate_mpi_call(
198 MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank ) );
199
200 // Determine p-by-q grid for this MPI size.
201 grid_size( mpi_size, &grid_p, &grid_q );
202 if (mpi_rank == 0) {
203 printf( "mpi_size %d, grid_p %d, grid_q %d\n",
204 mpi_size, grid_p, grid_q );
205 }
206
207 // so random_matrix is different on different ranks.
208 srand( 100 * mpi_rank );
209
210 if (types[ 0 ]) {
211 test_cholesky< float >();
212 test_cholesky_factor< float >();
213 test_cholesky_inverse< float >();
214 test_cholesky_cond< float >();
215 }
216 if (mpi_rank == 0)
217 printf( "\n" );
218
219 if (types[ 1 ]) {
220 test_cholesky< double >();
221 test_cholesky_factor< double >();
222 test_cholesky_inverse< double >();
223 test_cholesky_mixed< double >();
224 test_cholesky_cond< double >();
225 }
226 if (mpi_rank == 0)
227 printf( "\n" );
228
229 if (types[ 2 ]) {
230 test_cholesky< std::complex<float> >();
231 test_cholesky_factor< std::complex<float> >();
232 test_cholesky_inverse< std::complex<float> >();
233 test_cholesky_cond< std::complex<float> >();
234 }
235 if (mpi_rank == 0)
236 printf( "\n" );
237
238 if (types[ 3 ]) {
239 test_cholesky< std::complex<double> >();
240 test_cholesky_factor< std::complex<double> >();
241 test_cholesky_inverse< std::complex<double> >();
242 test_cholesky_mixed< std::complex<double> >();
243 test_cholesky_cond< std::complex<double> >();
244 }
245
246 slate_mpi_call(
247 MPI_Finalize() );
248 }
249 catch (std::exception const& ex) {
250 fprintf( stderr, "%s", ex.what() );
251 return 1;
252 }
253 return 0;
254}