Example 09: Least Squares

This example demonstrates solving overdetermined (\(m > n\)) and underdetermined (\(m < n\)) linear systems using least squares.

Key Concepts

  1. Overdetermined Systems: Finding \(X\) that minimizes \(\|b_j - A x_j\|_2\) for each column \(j\) of \(B\) (equivalently, \(\|AX - B\|_F\)).

  2. Underdetermined Systems: Among all \(X\) that satisfy \(AX = B\), finding the one of minimum norm.

  3. Simplified API: Using slate::least_squares_solve which handles both cases automatically.

  4. Traditional API: Using slate::gels.

C++ Example

Overdetermined Least Squares (Lines 26-46)

slate::Matrix<scalar_type> A( m, n, nb, ... );           // m-by-n with m > n (tall)
slate::Matrix<scalar_type> BX( max_mn, nrhs, nb, ... );  // max(m, n) = m rows

// BX contains B on input
auto B = BX; // View of top m rows (all of BX, since max_mn == m)
auto X = BX.slice( 0, n-1, 0, nrhs-1 ); // View of the top n rows, where the solution X is written

slate::least_squares_solve( A, BX );  // overwrites A with QR factors and BX with the solution

For overdetermined systems (\(m \ge n\)):

  • A is m by n.

  • The right-hand-side matrix BX must be large enough to hold both the input B (m rows) and the result X (n rows). The solution overwrites B in place and occupies the top n rows of BX; since m >= n, max(m, n) = m rows suffice.

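  • least_squares_solve (like gels) overwrites A with its QR factors and BX with the solution, which occupies the top n rows of BX. Because the inputs are overwritten, the slate::gels call that follows it in the listing does not re-solve the original system; it is shown only to illustrate the traditional API.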

Underdetermined Least Squares (Lines 59-82)

// solve A^H X = B; A is tall (m > n), so A^H is wide and X is the minimum-norm solution
auto AH = conj_transpose( A );          // n-by-m conjugate transpose of A
slate::least_squares_solve( AH, BX );   // B in the top n rows of BX on input; X fills all m rows on output

For underdetermined systems (\(m < n\)), \(A X = B\) generally has infinitely many solutions, and least squares returns the one of minimum norm. SLATE’s gels works on a tall stored matrix (m >= n), so a wide system is not stored directly. Instead, it is expressed as the conjugate transpose of a tall matrix: the example stores a tall matrix A and forms the wide system with conj_transpose(A), i.e., \(A^H\).

  • The example solves \(A^H X = B\) where A is tall (m > n). \(A^H\) is then n by m, with fewer rows (equations) than columns (unknowns), so this is a genuinely underdetermined system and X is its minimum-norm solution.

  • BX must be max(m, n) by nrhs. On input, B occupies only the top n rows of BX; on output, the solution X (m rows) fills all of BX, so BX is sized for the larger of the two.

  1// ex09_least_squares.cc
  2// Solve over- and under-determined AX = B
  3
  4/// !!!   Lines between `//---------- begin label`          !!!
  5/// !!!             and `//---------- end label`            !!!
  6/// !!!   are included in the SLATE Users' Guide.           !!!
  7
  8#include <slate/slate.hh>
  9
 10#include "util.hh"
 11
 12int mpi_size{ 0 };  // number of ranks in MPI_COMM_WORLD; set in main
 13int mpi_rank{ 0 };  // this process's rank in MPI_COMM_WORLD; set in main
 14int grid_p{ 0 };    // p of the p-by-q process grid; set by grid_size in main
 15int grid_q{ 0 };    // q of the p-by-q process grid; set by grid_size in main
 16
 17//------------------------------------------------------------------------------
 18template <typename scalar_type>
 19void test_gels_overdetermined()  // tall case: least-squares X for AX = B, A m-by-n with m > n
 20{
 21    print_func( mpi_rank );
 22
 23    int64_t m=2000, n=1000, nrhs=100, nb=256;  // m > n: A is tall
 24
 25    //---------- begin over1
 26    int64_t max_mn = std::max( m, n );
 27    slate::Matrix<scalar_type> A( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
 28    slate::Matrix<scalar_type> BX( max_mn, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
 29    // ...
 30    //---------- end over1
 31    // Allocate tile storage (presumably only the tiles owned by this rank).
 32    A.insertLocalTiles();
 33    BX.insertLocalTiles();
 34    //---------- begin over2
 35    auto B = BX;  // == BX.slice( 0, m-1, 0, nrhs-1 );
 36    auto X = BX.slice( 0, n-1, 0, nrhs-1 );
 37    //---------- end over2
 38    random_matrix( A );
 39    random_matrix( B );
 40    // B is a view of BX, so BX now holds the right-hand sides.
 41    //---------- begin over3
 42
 43    // solve AX = B in the least-squares sense; solution in X (top n rows of BX)
 44    slate::least_squares_solve( A, BX );  // simplified API
 45    // Alternative: note A now holds QR factors and BX the solution from the call above.
 46    slate::gels( A, BX );                 // traditional API
 47    //---------- end over3
 48}
 49
 50//------------------------------------------------------------------------------
 51template <typename scalar_type>
 52void test_gels_underdetermined()  // wide case: minimum-norm X for A^H X = B, with A tall
 53{
 54    print_func( mpi_rank );
 55
 56    int64_t m=2000, n=1000, nrhs=100, nb=256;  // A is tall (m > n), so A^H is n-by-m (wide)
 57
 58    //---------- begin under1
 59    int64_t max_mn = std::max( m, n );
 60    slate::Matrix<scalar_type> A( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
 61    slate::Matrix<scalar_type> BX( max_mn, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );
 62    // ...
 63    //---------- end under1
 64    // Allocate tile storage (presumably only the tiles owned by this rank).
 65    A.insertLocalTiles();
 66    BX.insertLocalTiles();
 67
 68    //---------- begin under2
 69    auto B = BX.slice( 0, n-1, 0, nrhs-1 );
 70    auto X = BX;  // == BX.slice( 0, m-1, 0, nrhs-1 );
 71    //---------- end under2
 72    // B is the top n rows of BX; the solution X will fill all m rows.
 73    random_matrix( A );
 74    random_matrix( B );
 75
 76    //---------- begin under3
 77
 78    // solve A^H X = B (fewer equations than unknowns: minimum-norm solution), solution in X
 79    auto AH = conj_transpose( A );
 80    slate::least_squares_solve( AH, BX );  // simplified API
 81    // Alternative: note AH and BX were already overwritten by the call above.
 82    slate::gels( AH, BX );                 // traditional API
 83    //---------- end under3
 84}
 85
 86//------------------------------------------------------------------------------
 87int main( int argc, char** argv )  // runs both examples for each precision selected on the command line
 88{
 89    try {
 90        // Parse command line to set types for s, d, c, z precisions.
 91        bool types[ 4 ];
 92        parse_args( argc, argv, types );
 93        // Initialize MPI, requesting full multi-threaded support.
 94        int provided = 0;
 95        slate_mpi_call(
 96            MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) );
 97        assert( provided == MPI_THREAD_MULTIPLE );  // NOTE: unchecked if NDEBUG is defined
 98
 99        slate_mpi_call(
100            MPI_Comm_size( MPI_COMM_WORLD, &mpi_size ) );
101
102        slate_mpi_call(
103            MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank ) );
104
105        // Determine p-by-q grid for this MPI size.
106        grid_size( mpi_size, &grid_p, &grid_q );
107        if (mpi_rank == 0) {
108            printf( "mpi_size %d, grid_p %d, grid_q %d\n",
109                    mpi_size, grid_p, grid_q );
110        }
111
112        // Seed per rank so random_matrix is different on different ranks.
113        srand( 100 * mpi_rank );
114        // Precisions in order: s (float), d (double), c, z (complex).
115        if (types[ 0 ]) {
116            test_gels_overdetermined < float >();
117            test_gels_underdetermined< float >();
118        }
119        if (mpi_rank == 0)
120            printf( "\n" );
121
122        if (types[ 1 ]) {
123            test_gels_overdetermined < double >();
124            test_gels_underdetermined< double >();
125        }
126        if (mpi_rank == 0)
127            printf( "\n" );
128
129        if (types[ 2 ]) {
130            test_gels_overdetermined < std::complex<float> >();
131            test_gels_underdetermined< std::complex<float> >();
132        }
133        if (mpi_rank == 0)
134            printf( "\n" );
135
136        if (types[ 3 ]) {
137            test_gels_overdetermined < std::complex<double> >();
138            test_gels_underdetermined< std::complex<double> >();
139        }
140
141        slate_mpi_call(
142            MPI_Finalize() );
143    }
144    catch (std::exception const& ex) {  // report and exit nonzero; MPI is not finalized on this path
145        fprintf( stderr, "%s\n", ex.what() );
146        return 1;
147    }
148    return 0;
149}