Example 08: Linear Systems (Indefinite)
This example demonstrates solving symmetric/Hermitian indefinite linear systems using Aasen’s algorithm.
Key Concepts
Indefinite Solve: Using slate::indefinite_solve (hesv/sysv) for symmetric/Hermitian matrices that are not necessarily positive definite.
Aasen’s Algorithm: A factorization method \(A = L T L^H\) where \(T\) is tridiagonal.
Explicit Factorization: Using indefinite_factor (hetrf) and indefinite_solve_using_factor (hetrs).
Workspace: Allocating necessary workspace matrices (BandMatrix T, Matrix H).
C++ Example
Indefinite Solve (Lines 41-50)
slate::indefinite_solve( A, B ); // simplified API
// traditional API with workspace setup
slate::Matrix<scalar_type> H( ... );
slate::BandMatrix<scalar_type> T( ... );
slate::Pivots pivots, pivots2;
slate::hesv( A, pivots, T, pivots2, H, B );
Solves \(AX = B\) where \(A\) is symmetric/Hermitian but not necessarily positive definite.
The simplified API (indefinite_solve) automatically handles the allocation of the auxiliary workspaces T and H and pivot vectors.
The traditional API (hesv for Hermitian, sysv for Symmetric) requires you to pre-allocate:
T: A band matrix to store the tridiagonal factor.
H: A matrix for internal workspace.
pivots, pivots2: Vectors to store pivot information.
Explicit Factorization (Lines 78-83)
slate::indefinite_factor( A, pivots, T, pivots2, H );
slate::indefinite_solve_using_factor( A, pivots, T, pivots2, B );
Separates the factorization (Aasen’s algorithm) from the solve.
indefinite_factor (hetrf): Computes the \(LTL^H\) factorization.
indefinite_solve_using_factor (hetrs): Solves the system using the factors.
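The workspaces T and H (and both pivot vectors) must be allocated and managed explicitly here, even when using the simplified API wrappers, because the factors have to be kept between the factor call and the solve call.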
1// ex08_linear_system_indefinite.cc
2// Solve AX = B using Aasen's symmetric indefinite factorization
3
4/// !!! Lines between `//---------- begin label` !!!
5/// !!! and `//---------- end label` !!!
6/// !!! are included in the SLATE Users' Guide. !!!
7
8#include <slate/slate.hh>
9
10#include "util.hh"
11
12int mpi_size = 0;  // size of MPI_COMM_WORLD; set in main() by MPI_Comm_size
13int mpi_rank = 0;  // this process's rank in MPI_COMM_WORLD; set in main() by MPI_Comm_rank
14int grid_p = 0;  // p dimension of the p-by-q process grid; set in main() by grid_size
15int grid_q = 0;  // q dimension of the p-by-q process grid; set in main() by grid_size
16
17//------------------------------------------------------------------------------
/// Demonstrates solving AX = B for a Hermitian indefinite A in one call,
/// first with the simplified API (indefinite_solve) and then with the
/// traditional API (hesv), which takes caller-allocated workspaces.
/// Reads the globals mpi_rank, grid_p, and grid_q.
///
/// @tparam scalar_type  Element type of A and B; main() instantiates this
///                      with float, double, and their std::complex forms.
///
/// NOTE(review): hesv is called on the same A and B that were already passed
/// to indefinite_solve. Presumably both were overwritten by the first call,
/// so the second call only illustrates the interface -- confirm.
18template <typename scalar_type>
19void test_hesv()
20{
21 print_func( mpi_rank );
22
23 // note: currently requires n divisible by nb.
24 int64_t n=1000, nrhs=100, nb=100;  // A is n-by-n, B is n-by-nrhs; nb is presumably the tile size -- confirm
25
26 //---------- begin solve1
27 slate::HermitianMatrix<scalar_type>
28 A( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
29 slate::Matrix<scalar_type> B( n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
30 // ...
31 //---------- end solve1
32
33 A.insertLocalTiles();
34 B.insertLocalTiles();
35 random_matrix( A );
36 random_matrix( B );
37
38 //---------- begin solve2
39
40 // simplified API: T, H, and the pivot vectors are allocated automatically
41 slate::indefinite_solve( A, B );
42
43 // traditional API: the caller pre-allocates T, H, and the pivot vectors
44 // workspaces: H is internal workspace; T is the band matrix that stores the tridiagonal factor
45 // todo: drop H (internal workspace)
46 slate::Matrix<scalar_type> H( n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
47 slate::BandMatrix<scalar_type> T( n, n, nb, nb, nb, grid_p, grid_q, MPI_COMM_WORLD );
48 slate::Pivots pivots, pivots2;  // receive the pivot information
49
50 slate::hesv( A, pivots, T, pivots2, H, B );
51 //---------- end solve2
52}
53
54//------------------------------------------------------------------------------
/// Demonstrates splitting the indefinite solve into an explicit factorization
/// followed by a solve that uses the factors, first with the simplified API
/// (indefinite_factor, indefinite_solve_using_factor) and then with the
/// traditional API (hetrf, hetrs).
/// Reads the globals mpi_rank, grid_p, and grid_q.
///
/// @tparam scalar_type  Element type of A and B; main() instantiates this
///                      with float, double, and their std::complex forms.
///
/// NOTE(review): hetrf is called on the A that indefinite_factor already
/// processed, and hetrs on the B that was already solved. Presumably the
/// second pair only illustrates the interface -- confirm.
55template <typename scalar_type>
56void test_hetrf()
57{
58 print_func( mpi_rank );
59
60 // note: currently requires n divisible by nb.
61 int64_t n=1000, nrhs=100, nb=100;  // A is n-by-n, B is n-by-nrhs; nb is presumably the tile size -- confirm
62
63 slate::HermitianMatrix<scalar_type>
64 A( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
65 slate::Matrix<scalar_type> B( n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
66 A.insertLocalTiles();
67 B.insertLocalTiles();
68 random_matrix( A );
69 random_matrix( B );
70
71 // workspaces: H is internal workspace; T is the band matrix that stores the tridiagonal factor
72 // todo: drop H (internal workspace)
73 slate::Matrix<scalar_type> H( n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
74 slate::BandMatrix<scalar_type> T( n, n, nb, nb, nb, grid_p, grid_q, MPI_COMM_WORLD );
75 slate::Pivots pivots, pivots2;  // receive the pivot information
76
77 // simplified API: compute the L T L^H factorization, then solve using the factors
78 slate::indefinite_factor( A, pivots, T, pivots2, H );
79 slate::indefinite_solve_using_factor( A, pivots, T, pivots2, B );
80
81 // traditional API: the same two steps under their traditional names
82 slate::hetrf( A, pivots, T, pivots2, H ); // factor
83 slate::hetrs( A, pivots, T, pivots2, B ); // solve
84}
85
86//------------------------------------------------------------------------------
/// Program entry point: parses the precision flags, initializes MPI
/// requesting MPI_THREAD_MULTIPLE, records the communicator size and rank,
/// picks the p-by-q process grid, seeds rand() per rank, and runs test_hesv
/// and test_hetrf for every selected precision before finalizing MPI.
///
/// @param argc  Argument count, forwarded to parse_args and MPI_Init_thread.
/// @param argv  Argument vector, forwarded to parse_args and MPI_Init_thread.
/// @return 0 on success; 1 if a std::exception is caught, after writing its
///         what() text to stderr.
///
/// NOTE(review): the catch path returns without calling MPI_Finalize, and
/// the "%s" format writes the message without a trailing newline.
87int main( int argc, char** argv )
88{
89 try {
90 // Parse command line to set types for s, d, c, z precisions.
91 bool types[ 4 ];  // indexed 0..3 in the order s, d, c, z
92 parse_args( argc, argv, types );
93
94 int provided = 0;  // receives the thread-support level MPI actually grants
95 slate_mpi_call(
96 MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) );
97 assert( provided == MPI_THREAD_MULTIPLE );  // check is compiled out when NDEBUG is defined
98
99 slate_mpi_call(
100 MPI_Comm_size( MPI_COMM_WORLD, &mpi_size ) );
101
102 slate_mpi_call(
103 MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank ) );
104
105 // Determine p-by-q grid for this MPI size.
106 grid_size( mpi_size, &grid_p, &grid_q );
107 if (mpi_rank == 0) {
108 printf( "mpi_size %d, grid_p %d, grid_q %d\n",
109 mpi_size, grid_p, grid_q );
110 }
111
112 // so random_matrix is different on different ranks.
113 srand( 100 * mpi_rank );
114
115 if (types[ 0 ]) {  // s: float
116 test_hesv < float >();
117 test_hetrf< float >();
118 }
119 if (mpi_rank == 0)  // separator line; printed on rank 0 whether or not the precision above ran
120 printf( "\n" );
121
122 if (types[ 1 ]) {  // d: double
123 test_hesv < double >();
124 test_hetrf< double >();
125 }
126 if (mpi_rank == 0)
127 printf( "\n" );
128
129 if (types[ 2 ]) {  // c: std::complex<float>
130 test_hesv < std::complex<float> >();
131 test_hetrf< std::complex<float> >();
132 }
133 if (mpi_rank == 0)
134 printf( "\n" );
135
136 if (types[ 3 ]) {  // z: std::complex<double>
137 test_hesv < std::complex<double> >();
138 test_hetrf< std::complex<double> >();
139 }
140
141 slate_mpi_call(
142 MPI_Finalize() );
143 }
144 catch (std::exception const& ex) {
145 fprintf( stderr, "%s", ex.what() );
146 return 1;  // MPI_Finalize is not reached on this path
147 }
148 return 0;
149}