Example 09: Least Squares
This example demonstrates solving overdetermined (\(m > n\)) and underdetermined (\(m < n\)) linear systems using least squares.
Key Concepts
Overdetermined Systems: Finding \(x\) that minimizes \(\|Ax - B\|_2\).
Underdetermined Systems: Finding the minimum norm solution \(x\) that satisfies \(Ax = B\).
Simplified API: Using slate::least_squares_solve, which handles both cases automatically.
Traditional API: Using slate::gels.
C++ Example
Overdetermined Least Squares (Lines 27-46)
slate::Matrix<scalar_type> A( m, n, nb, ... );
slate::Matrix<scalar_type> BX( max_mn, nrhs, nb, ... );
// BX contains B on input
auto B = BX; // View of top m rows
auto X = BX.slice( 0, n-1, 0, nrhs-1 ); // View where X will be
slate::least_squares_solve( A, BX );
For overdetermined systems (\(m \ge n\)):
A is m by n.
The RHS matrix BX must be large enough to hold both the input B (m rows) and the solution X (n rows). The solver overwrites B in place with X, so the solution ends up in the top n rows of BX. Since m ≥ n, max(m, n) = m rows suffice.
least_squares_solve (gels) overwrites A with its QR factors and BX with the solution; the solution occupies the top n rows of BX (the view X).
Underdetermined Least Squares (Lines 59-82)
// solve A^H X = B
auto AH = conj_transpose( A );
slate::least_squares_solve( AH, BX );
For underdetermined systems (fewer equations than unknowns, i.e., a fat coefficient matrix), we seek the minimum-norm solution of \(A x = B\). SLATE’s gels routine expects an m-by-n matrix with m ≥ n. To solve a system whose coefficient matrix is fat, we therefore store its transpose as a tall matrix A and pass the conjugate-transposed view \(A^H\) (which is fat) to the solver.
The example therefore solves \(A^H X = B\) with A tall (m = 2000 > n = 1000). Since \(A^H\) is n by m, this is a genuinely underdetermined system of n equations in m unknowns, and the solver returns its minimum-norm solution.
BX must be max(m, n) by nrhs. Here the input B occupies only the top n rows of BX, while the solution X has m rows and fills all of BX, so BX must be sized by the larger dimension.
1// ex09_least_squares.cc
2// Solve over- and under-determined AX = B
3
4/// !!! Lines between `//---------- begin label` !!!
5/// !!! and `//---------- end label` !!!
6/// !!! are included in the SLATE Users' Guide. !!!
7
8#include <slate/slate.hh>
9
10#include "util.hh"
11
// MPI layout shared by every test routine. All four start at zero and are
// filled in by main() once MPI has been initialized.
int mpi_size{ 0 };  // number of ranks in MPI_COMM_WORLD
int mpi_rank{ 0 };  // this process's rank in MPI_COMM_WORLD
int grid_p{ 0 };    // rows of the p-by-q process grid
int grid_q{ 0 };    // columns of the p-by-q process grid
16
17//------------------------------------------------------------------------------
18template <typename scalar_type>
19void test_gels_overdetermined()
20{
21 print_func( mpi_rank );
22
23 int64_t m=2000, n=1000, nrhs=100, nb=256;
24
25 //---------- begin over1
26 int64_t max_mn = std::max( m, n );
27 slate::Matrix<scalar_type> A( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
28 slate::Matrix<scalar_type> BX( max_mn, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
29 // ...
30 //---------- end over1
31
32 A.insertLocalTiles();
33 BX.insertLocalTiles();
34 //---------- begin over2
35 auto B = BX; // == BX.slice( 0, m-1, 0, nrhs-1 );
36 auto X = BX.slice( 0, n-1, 0, nrhs-1 );
37 //---------- end over2
38 random_matrix( A );
39 random_matrix( B );
40
41 //---------- begin over3
42
43 // solve AX = B, solution in X
44 slate::least_squares_solve( A, BX ); // simplified API
45
46 slate::gels( A, BX ); // traditional API
47 //---------- end over3
48}
49
50//------------------------------------------------------------------------------
51template <typename scalar_type>
52void test_gels_underdetermined()
53{
54 print_func( mpi_rank );
55
56 int64_t m=2000, n=1000, nrhs=100, nb=256;
57
58 //---------- begin under1
59 int64_t max_mn = std::max( m, n );
60 slate::Matrix<scalar_type> A( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
61 slate::Matrix<scalar_type> BX( max_mn, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
62 // ...
63 //---------- end under1
64
65 A.insertLocalTiles();
66 BX.insertLocalTiles();
67
68 //---------- begin under2
69 auto B = BX.slice( 0, n-1, 0, nrhs-1 );
70 auto X = BX; // == BX.slice( 0, m-1, 0, nrhs-1 );
71 //---------- end under2
72
73 random_matrix( A );
74 random_matrix( B );
75
76 //---------- begin under3
77
78 // solve A^H X = B, solution in X
79 auto AH = conj_transpose( A );
80 slate::least_squares_solve( AH, BX ); // simplified API
81
82 slate::gels( AH, BX ); // traditional API
83 //---------- end under3
84}
85
86//------------------------------------------------------------------------------
87int main( int argc, char** argv )
88{
89 try {
90 // Parse command line to set types for s, d, c, z precisions.
91 bool types[ 4 ];
92 parse_args( argc, argv, types );
93
94 int provided = 0;
95 slate_mpi_call(
96 MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) );
97 assert( provided == MPI_THREAD_MULTIPLE );
98
99 slate_mpi_call(
100 MPI_Comm_size( MPI_COMM_WORLD, &mpi_size ) );
101
102 slate_mpi_call(
103 MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank ) );
104
105 // Determine p-by-q grid for this MPI size.
106 grid_size( mpi_size, &grid_p, &grid_q );
107 if (mpi_rank == 0) {
108 printf( "mpi_size %d, grid_p %d, grid_q %d\n",
109 mpi_size, grid_p, grid_q );
110 }
111
112 // so random_matrix is different on different ranks.
113 srand( 100 * mpi_rank );
114
115 if (types[ 0 ]) {
116 test_gels_overdetermined < float >();
117 test_gels_underdetermined< float >();
118 }
119 if (mpi_rank == 0)
120 printf( "\n" );
121
122 if (types[ 1 ]) {
123 test_gels_overdetermined < double >();
124 test_gels_underdetermined< double >();
125 }
126 if (mpi_rank == 0)
127 printf( "\n" );
128
129 if (types[ 2 ]) {
130 test_gels_overdetermined < std::complex<float> >();
131 test_gels_underdetermined< std::complex<float> >();
132 }
133 if (mpi_rank == 0)
134 printf( "\n" );
135
136 if (types[ 3 ]) {
137 test_gels_overdetermined < std::complex<double> >();
138 test_gels_underdetermined< std::complex<double> >();
139 }
140
141 slate_mpi_call(
142 MPI_Finalize() );
143 }
144 catch (std::exception const& ex) {
145 fprintf( stderr, "%s", ex.what() );
146 return 1;
147 }
148 return 0;
149}