Example 12: Generalized Hermitian Eigenvalues
This example demonstrates solving generalized Hermitian eigenvalue problems.
Key Concepts
Problem Types:
* Type 1: \(Ax = \lambda Bx\)
* Type 2: \(ABx = \lambda x\)
* Type 3: \(BAx = \lambda x\)
Positive Definite B: The matrix B must be Hermitian positive definite.
API Usage: Using
slate::eig_vals, slate::eig, and slate::hegv with the type parameter.
C++ Example
Problem Setup (Lines 32-37)
slate::HermitianMatrix<scalar_type> A( ... ), B( ... );
slate::Matrix<scalar_type> Z( ... );
std::vector<real_t> Lambda( n );
A: Hermitian matrix.
B: Hermitian positive definite matrix.
Z: Eigenvectors.
Lambda: Eigenvalues.
Type 1: Ax = lambda Bx (Lines 50-64)
slate::eig_vals( 1, A, B, Lambda );
slate::eig( 1, A, B, Lambda );
slate::hegv( 1, A, B, Lambda );
Solves \(Ax = \lambda Bx\).
A is overwritten.
B is overwritten by its Cholesky factor.
Type 2: ABx = lambda x (Lines 74, 101, 122)
slate::eig( 2, A, B, Lambda, Z );
Solves \(ABx = \lambda x\).
Type 3: BAx = lambda x (Lines 84, 108, 129)
slate::eig( 3, A, B, Lambda, Z );
Solves \(BAx = \lambda x\).
Note that for all types, B must be positive definite because the algorithms internally perform a Cholesky factorization of B to transform the generalized problem into a standard eigenvalue problem.
1// ex12_generalized_hermitian_eig.cc
2// Solve generalized Hermitian eigenvalues, types:
3// 1. A = B Z Lambda Z^H
4// 2. A B = Z Lambda Z^H
5// 3. B A = Z Lambda Z^H
6// where B is Hermitian positive definite.
7
8/// !!! Lines between `//---------- begin label` !!!
9/// !!! and `//---------- end label` !!!
10/// !!! are included in the SLATE Users' Guide. !!!
11
12#include <slate/slate.hh>
13
14#include "util.hh"
15
16int mpi_size = 0; // number of MPI ranks; set in main() by MPI_Comm_size
17int mpi_rank = 0; // this process's rank; set in main() by MPI_Comm_rank
18int grid_p = 0; // p of the p-by-q process grid; set in main() by grid_size_square
19int grid_q = 0; // q of the p-by-q process grid; set in main() by grid_size_square
20
21//------------------------------------------------------------------------------
22template <typename scalar_type>
23void test_hermitian_eig()
24{
25 using real_t = blas::real_type<scalar_type>;
26 // Runs every generalized Hermitian eigensolver call of this example (types 1, 2, 3; with and without eigenvectors Z) for one scalar_type.
27 print_func( mpi_rank );
28 // n and nb are passed to the A, B, and Z constructors below; n is also the length of Lambda.
29 int64_t n=1000, nb=256;
30
31 //---------- begin eig1
32 slate::HermitianMatrix<scalar_type>
33 A( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD ),
34 B( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
35 slate::Matrix<scalar_type>
36 Z( n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
37 std::vector<real_t> Lambda( n );
38 // ...
39 //---------- end eig1
40 // Insert each matrix's local tiles, then fill A and B before the first solve.
41 A.insertLocalTiles();
42 B.insertLocalTiles();
43 Z.insertLocalTiles();
44 random_matrix( A );
45 random_matrix_diag_dominant( B );
46 // NOTE(review): random_matrix_diag_dominant is presumably what makes B positive definite, as the solvers require -- confirm in util.hh.
47 //----------------------------------------
48 //---------- begin eig2
49 // Type 1: A = B Z Lambda Z^H, eigenvalues only
50 slate::eig_vals( 1, A, B, Lambda ); // simplified API, or
51 //---------- end eig2
52 // Refill A and B before the next call; the notes above this listing state the solver overwrites both. The same refill precedes every later call.
53 random_matrix( A );
54 random_matrix_diag_dominant( B );
55
56 //---------- begin eig3
57 slate::eig( 1, A, B, Lambda ); // simplified API
58 //---------- end eig3
59
60 random_matrix( A );
61 random_matrix_diag_dominant( B );
62
63 //---------- begin eig4
64 slate::hegv( 1, A, B, Lambda ); // traditional API
65 //---------- end eig4
66
67 random_matrix( A );
68 random_matrix_diag_dominant( B );
69
70 //----------------------------------------
71 //---------- begin eig5
72
73 // Type 2: A B = Z Lambda Z^H, eigenvalues only
74 slate::eig_vals( 2, A, B, Lambda ); // simplified API
75 //---------- end eig5
76
77 random_matrix( A );
78 random_matrix_diag_dominant( B );
79
80 //----------------------------------------
81 //---------- begin eig6
82
83 // Type 3: B A = Z Lambda Z^H, eigenvalues only
84 slate::eig_vals( 3, A, B, Lambda ); // simplified API
85 //---------- end eig6
86
87 random_matrix( A );
88 random_matrix_diag_dominant( B );
89
90 //----------------------------------------
91 //---------- begin eig7
92
93 // Types 1, 2, and 3, with eigenvectors
94 slate::eig( 1, A, B, Lambda, Z ); // simplified API
95 //---------- end eig7
96
97 random_matrix( A );
98 random_matrix_diag_dominant( B );
99
100 //---------- begin eig8
101 slate::eig( 2, A, B, Lambda, Z ); // simplified API
102 //---------- end eig8
103
104 random_matrix( A );
105 random_matrix_diag_dominant( B );
106
107 //---------- begin eig9
108 slate::eig( 3, A, B, Lambda, Z ); // simplified API
109 //---------- end eig9
110
111 random_matrix( A );
112 random_matrix_diag_dominant( B );
113 // The remaining calls repeat types 1, 2, and 3 with eigenvectors through the traditional hegv entry point.
114 //---------- begin eig10
115 slate::hegv( 1, A, B, Lambda, Z ); // traditional API
116 //---------- end eig10
117
118 random_matrix( A );
119 random_matrix_diag_dominant( B );
120
121 //---------- begin eig11
122 slate::hegv( 2, A, B, Lambda, Z ); // traditional API
123 //---------- end eig11
124
125 random_matrix( A );
126 random_matrix_diag_dominant( B );
127
128 //---------- begin eig12
129 slate::hegv( 3, A, B, Lambda, Z ); // traditional API
130 //---------- end eig12
131}
132
133//------------------------------------------------------------------------------
134int main( int argc, char** argv ) // Sets up MPI and the process grid, then runs the example per selected precision; returns 1 if a std::exception is caught, else 0.
135{
136 try {
137 // Parse command line to set types for s, d, c, z precisions (types[0..3] select float, double, complex<float>, complex<double> below).
138 bool types[ 4 ];
139 parse_args( argc, argv, types );
140 // Request MPI_THREAD_MULTIPLE; `provided` receives the thread level actually granted.
141 int provided = 0;
142 slate_mpi_call(
143 MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) );
144 assert( provided == MPI_THREAD_MULTIPLE ); // NOTE(review): assert is compiled out when NDEBUG is defined, leaving `provided` unchecked.
145
146 slate_mpi_call(
147 MPI_Comm_size( MPI_COMM_WORLD, &mpi_size ) );
148
149 slate_mpi_call(
150 MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank ) );
151
152 // Determine p-by-q grid for this MPI size.
153 // Hermitian eig requires square MPI grid.
154 grid_size_square( mpi_size, &grid_p, &grid_q );
155 if (mpi_rank == 0) {
156 printf( "mpi_size %d, grid_p %d, grid_q %d\n",
157 mpi_size, grid_p, grid_q );
158 }
159
160 // Seed differs per rank so random_matrix is different on different ranks.
161 srand( 100 * mpi_rank );
162 // Run the example once for each precision selected on the command line.
163 if (types[ 0 ]) {
164 test_hermitian_eig< float >();
165 }
166
167 if (types[ 1 ]) {
168 test_hermitian_eig< double >();
169 }
170
171 if (types[ 2 ]) {
172 test_hermitian_eig< std::complex<float> >();
173 }
174
175 if (types[ 3 ]) {
176 test_hermitian_eig< std::complex<double> >();
177 }
178 // NOTE(review): MPI_Finalize is inside the try block, so it is skipped when an exception is thrown above.
179 slate_mpi_call(
180 MPI_Finalize() );
181 }
182 catch (std::exception const& ex) {
183 fprintf( stderr, "%s", ex.what() ); // NOTE(review): no trailing newline is printed after the message.
184 return 1;
185 }
186 return 0;
187}