Example 08: Linear Systems (Indefinite)

This example demonstrates solving symmetric/Hermitian indefinite linear systems using Aasen’s algorithm.

Key Concepts

  1. Indefinite Solve: Using slate::indefinite_solve (hesv/sysv) for symmetric/Hermitian matrices that are not necessarily positive definite.

  2. Aasen’s Algorithm: A factorization method \(A = L T L^H\) where \(T\) is tridiagonal.

  3. Explicit Factorization: Using indefinite_factor (hetrf) and indefinite_solve_using_factor (hetrs).

  4. Workspace: Allocating necessary workspace matrices (BandMatrix T, Matrix H).

C++ Example

Indefinite Solve (Lines 41-50)

slate::indefinite_solve( A, B );  // simplified API

// traditional API with workspace setup
slate::Matrix<scalar_type>     H( ... );
slate::BandMatrix<scalar_type> T( ... );
slate::Pivots pivots, pivots2;
slate::hesv( A, pivots, T, pivots2, H, B );

Solves \(AX = B\) where A is symmetric/Hermitian but not necessarily positive definite.

  • The simplified API (indefinite_solve) automatically handles the allocation of the auxiliary workspaces T and H and pivot vectors.

  • The traditional API (hesv for Hermitian, sysv for Symmetric) requires you to pre-allocate:

    • T: A band matrix to store the tridiagonal factor.

    • H: A matrix for internal workspace.

    • pivots, pivots2: Vectors to store pivot information.

Explicit Factorization (Lines 78-83)

slate::indefinite_factor( A, pivots, T, pivots2, H );
slate::indefinite_solve_using_factor( A, pivots, T, pivots2, B );

Separates the factorization (Aasen’s algorithm) from the solve.

  • indefinite_factor (hetrf): Computes the \(LTL^H\) factorization.

  • indefinite_solve_using_factor (hetrs): Solves the system using the factors.

  • Even with the simplified names (indefinite_factor, indefinite_solve_using_factor), you allocate and pass the workspaces T and H and the pivot vectors yourself, because the factors must be kept between the two calls.

  1// ex08_linear_system_indefinite.cc
  2// Solve AX = B using Aasen's symmetric indefinite factorization
  3
  4/// !!!   Lines between `//---------- begin label`          !!!
  5/// !!!             and `//---------- end label`            !!!
  6/// !!!   are included in the SLATE Users' Guide.           !!!
  7
  8#include <slate/slate.hh>
  9
 10#include "util.hh"
 11
 12int mpi_size = 0;  // number of ranks in MPI_COMM_WORLD; set in main() via MPI_Comm_size
 13int mpi_rank = 0;  // this process's rank in MPI_COMM_WORLD; set in main() via MPI_Comm_rank
 14int grid_p = 0;    // process-grid dimension p; set in main() by grid_size()
 15int grid_q = 0;    // process-grid dimension q; set in main() by grid_size()
 16
 17//------------------------------------------------------------------------------
 18template <typename scalar_type>
 19void test_hesv()  // Demo: solve A X = B with a Hermitian A via indefinite_solve, then via hesv.
 20{
 21    print_func( mpi_rank );
 22    // scalar_type is the matrix element type; main() instantiates float, double, and their std::complex forms.
 23    // note: currently requires n divisible by nb.
 24    int64_t n=1000, nrhs=100, nb=100;  // A is n-by-n, B is n-by-nrhs; nb is presumably the tile (block) size -- confirm
 25
 26    //---------- begin solve1
 27    slate::HermitianMatrix<scalar_type>
 28        A( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );  // lower triangle, grid_p x grid_q process grid
 29    slate::Matrix<scalar_type> B( n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );  // right-hand sides
 30    // ...
 31    //---------- end solve1
 32    // Populate A and B before solving; random_matrix is not defined in this file (presumably from util.hh).
 33    A.insertLocalTiles();
 34    B.insertLocalTiles();
 35    random_matrix( A );
 36    random_matrix( B );
 37    // NOTE(review): the same A and B go to both solvers below; if indefinite_solve overwrites them in place, hesv sees the overwritten data -- fine for an API demo, confirm.
 38    //---------- begin solve2
 39
 40    // simplified API: the workspaces T, H and the pivot vectors are allocated internally
 41    slate::indefinite_solve( A, B );
 42
 43    // traditional API: the caller pre-allocates T, H, and the pivot vectors
 44    // workspaces
 45    // todo: drop H (internal workspace)
 46    slate::Matrix<scalar_type>     H( n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );  // n-by-n internal workspace
 47    slate::BandMatrix<scalar_type> T( n, n, nb, nb, nb, grid_p, grid_q, MPI_COMM_WORLD );  // band matrix for the T factor of A = L T L^H
 48    slate::Pivots pivots, pivots2;  // pivot information stored by the solver
 49
 50    slate::hesv( A, pivots, T, pivots2, H, B );
 51    //---------- end solve2
 52}
 53
 54//------------------------------------------------------------------------------
 55template <typename scalar_type>
 56void test_hetrf()  // Demo: factor a Hermitian A, then solve A X = B from the factors, using both API spellings.
 57{
 58    print_func( mpi_rank );
 59    // scalar_type is the matrix element type; main() instantiates float, double, and their std::complex forms.
 60    // note: currently requires n divisible by nb.
 61    int64_t n=1000, nrhs=100, nb=100;  // A is n-by-n, B is n-by-nrhs; nb is presumably the tile (block) size -- confirm
 62
 63    slate::HermitianMatrix<scalar_type>
 64        A( slate::Uplo::Lower, n, nb, grid_p, grid_q, MPI_COMM_WORLD );  // lower triangle, grid_p x grid_q process grid
 65    slate::Matrix<scalar_type> B( n, nrhs, nb, grid_p, grid_q, MPI_COMM_WORLD );  // right-hand sides
 66    A.insertLocalTiles();
 67    B.insertLocalTiles();
 68    random_matrix( A );  // not defined in this file; presumably a util.hh helper that fills the matrix
 69    random_matrix( B );
 70
 71    // workspaces
 72    // todo: drop H (internal workspace)
 73    slate::Matrix<scalar_type>     H( n, n, nb, grid_p, grid_q, MPI_COMM_WORLD );  // n-by-n internal workspace
 74    slate::BandMatrix<scalar_type> T( n, n, nb, nb, nb, grid_p, grid_q, MPI_COMM_WORLD );  // band matrix for the T factor of A = L T L^H
 75    slate::Pivots pivots, pivots2;  // pivot information stored by the factorization
 76
 77    // simplified API: H is passed only to the factorization; the solve takes A, pivots, T, pivots2
 78    slate::indefinite_factor( A, pivots, T, pivots2, H );
 79    slate::indefinite_solve_using_factor( A, pivots, T, pivots2, B );
 80    // NOTE(review): A, T, the pivots, and B were already used by the two calls above; if those update in place, the calls below operate on the overwritten data -- fine for an API demo, confirm.
 81    // traditional API: same argument lists under the LAPACK-style names
 82    slate::hetrf( A, pivots, T, pivots2, H );  // factor
 83    slate::hetrs( A, pivots, T, pivots2, B );  // solve
 84}
 85
 86//------------------------------------------------------------------------------
 87/// Program entry point: starts MPI, chooses the process grid, and runs the
 88/// hesv and hetrf demos for each precision selected on the command line.
 89///
 90/// @param argc  Argument count, forwarded to parse_args and MPI_Init_thread.
 91/// @param argv  Argument vector, forwarded to parse_args and MPI_Init_thread.
 92/// @return 0 on success; 1 if a std::exception is caught. If MPI does not
 93///         provide MPI_THREAD_MULTIPLE, the job is aborted via MPI_Abort
 94///         with error code 1.
 95int main( int argc, char** argv )
 96{
 97    try {
 98        // Parse command line to set types for s, d, c, z precisions.
 99        bool types[ 4 ];
100        parse_args( argc, argv, types );
101
102        int provided = 0;
103        slate_mpi_call(
104            MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) );
105
106        // Checked at run time instead of with assert(): assert() is compiled
107        // out under NDEBUG, which would let the demos run with less thread
108        // support than was requested.
109        if (provided != MPI_THREAD_MULTIPLE) {
110            fprintf( stderr, "error: MPI_THREAD_MULTIPLE is not available\n" );
111            MPI_Abort( MPI_COMM_WORLD, 1 );
112            return 1;  // not normally reached; MPI_Abort terminates the job
113        }
114
115        slate_mpi_call(
116            MPI_Comm_size( MPI_COMM_WORLD, &mpi_size ) );
117
118        slate_mpi_call(
119            MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank ) );
120
121        // Determine p-by-q grid for this MPI size.
122        grid_size( mpi_size, &grid_p, &grid_q );
123        if (mpi_rank == 0) {
124            printf( "mpi_size %d, grid_p %d, grid_q %d\n",
125                    mpi_size, grid_p, grid_q );
126        }
127
128        // so random_matrix is different on different ranks.
129        srand( 100 * mpi_rank );
130
131        // types[ 0..3 ] select float, double, complex<float>, complex<double>.
132        // Rank 0 prints the separating blank line whether or not the
133        // preceding precision was run.
134        if (types[ 0 ]) {
135            test_hesv < float >();
136            test_hetrf< float >();
137        }
138        if (mpi_rank == 0)
139            printf( "\n" );
140
141        if (types[ 1 ]) {
142            test_hesv < double >();
143            test_hetrf< double >();
144        }
145        if (mpi_rank == 0)
146            printf( "\n" );
147
148        if (types[ 2 ]) {
149            test_hesv < std::complex<float> >();
150            test_hetrf< std::complex<float> >();
151        }
152        if (mpi_rank == 0)
153            printf( "\n" );
154
155        if (types[ 3 ]) {
156            test_hesv < std::complex<double> >();
157            test_hetrf< std::complex<double> >();
158        }
159
160        slate_mpi_call(
161            MPI_Finalize() );
162    }
163    catch (std::exception const& ex) {
164        // Newline-terminate the message so it does not run into whatever is
165        // printed next on the terminal.
166        // NOTE(review): MPI_Finalize is not called on this path -- confirm
167        // that is acceptable for the example.
168        fprintf( stderr, "%s\n", ex.what() );
169        return 1;
170    }
171    return 0;
172}