Example 02: Matrix Type Conversion

This example demonstrates how to create different matrix views (Trapezoid, Triangular, Symmetric, Hermitian) from a general Matrix.

Key Concepts

  1. Shallow Copies: Creating views of existing data without copying the elements.

  2. Matrix Types:

       • TrapezoidMatrix: Lower or upper trapezoid (may be rectangular).

       • TriangularMatrix: Square lower or upper triangle.

       • SymmetricMatrix: Symmetric matrix (where \(A_{ji} = A_{ij}\)).

       • HermitianMatrix: Hermitian matrix (where \(A_{ji} = \bar{A}_{ij}\)).

  3. Slicing: Taking a square sub-matrix view of a general matrix so that it meets the squareness requirement of the triangular, symmetric, and Hermitian types.
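The shallow-copy idea can be illustrated without SLATE. The sketch below uses hypothetical Storage and View types, which are not SLATE's classes and are not meant to model its internals. A view holds a shared pointer to the element storage and a structure tag, so creating a "lower" view copies no elements.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

// Hypothetical stand-ins for illustration; these are NOT SLATE's classes.
// Storage owns a column-major m x n array of doubles.
struct Storage {
    int64_t m, n;
    std::vector<double> data;
    Storage(int64_t m_, int64_t n_) : m(m_), n(n_), data(m_ * n_, 0.0) {}
};

enum class Uplo { General, Lower, Upper };

// A view is a shared pointer to the storage plus a structure tag.
// Copying a view copies the pointer and the tag, never the elements.
struct View {
    std::shared_ptr<Storage> store;
    Uplo uplo = Uplo::General;

    double& at(int64_t i, int64_t j) {
        assert(0 <= i && i < store->m && 0 <= j && j < store->n);
        return store->data[i + j * store->m];  // column-major element (i, j)
    }
};

// Reinterpret a view as lower triangular/trapezoidal: a shallow copy with a new tag.
View as_lower(const View& A) {
    View L = A;            // copies the pointer, not the data
    L.uplo = Uplo::Lower;
    return L;
}
```

When you write through L.at(2, 0) after `View L = as_lower(A);`, the change is visible through A.at(2, 0), because both views point at the same Storage object.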

C++ Example

General Matrix Creation (Lines 26-29)

slate::Matrix<scalar_type>
    A( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );

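We start with a standard, general m by n matrix A, laid out in nb by nb tiles that are distributed over a grid_p by grid_q grid of MPI processes in MPI_COMM_WORLD. Every view created below shares A's tiles instead of copying them.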
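The tile arithmetic behind this constructor is worth working through once. Here is a minimal sketch using the sizes from the listing (m = 2000, n = 1000, nb = 256). The helper names are hypothetical. The owner formula assumes a 2D block-cyclic distribution with ranks numbered column-major across the p by q grid, which is an assumption here, not something the example states.

```cpp
#include <cassert>
#include <cstdint>

// Number of nb-wide tiles needed to cover n rows or columns (ceiling division).
int64_t num_tiles(int64_t n, int64_t nb) { return (n + nb - 1) / nb; }

// Size of tile k along a dimension of length n; only the last tile
// can be smaller than nb.
int64_t tile_size(int64_t k, int64_t n, int64_t nb) {
    int64_t rest = n - k * nb;
    return rest < nb ? rest : nb;
}

// Rank owning tile (i, j) under a 2D block-cyclic distribution on a p x q
// grid, ASSUMING ranks are numbered column-major across the grid.
int tile_owner(int64_t i, int64_t j, int p, int q) {
    return int(i % p + (j % q) * p);
}
```

With these sizes, A has 8 tile rows and 4 tile columns. The last tile row holds 2000 - 7·256 = 208 rows, and the last tile column holds 1000 - 3·256 = 232 columns.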

Trapezoid View (Lines 31-34)

slate::TrapezoidMatrix<scalar_type>
    Lz( slate::Uplo::Lower, slate::Diag::Unit, A );

We create a TrapezoidMatrix named Lz from A.

  • This is a shallow copy. Lz points to the same data tiles as A.

  • Uplo::Lower specifies we are interested in the lower trapezoidal part.

  • Diag::Unit specifies that the diagonal elements are implicitly 1; the stored diagonal entries are neither read nor modified.
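The meaning of Uplo and Diag for a trapezoid view can be modeled as a rule that decides which value the view "sees" at (i, j). The sketch below is a conceptual model, not SLATE's element-access API. The function trapezoid_at and its local Uplo/Diag enums are hypothetical. Entries outside the selected trapezoid read as zero, and under Diag::Unit the diagonal reads as 1 without touching storage.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical enums mirroring the roles of slate::Uplo and slate::Diag.
enum class Uplo { Lower, Upper };
enum class Diag { Unit, NonUnit };

// Value that a trapezoid view of the column-major matrix `a` (m rows)
// sees at (i, j). Works for rectangular matrices: in a lower trapezoid
// with m > n, the rows below the square part are fully included.
double trapezoid_at(const std::vector<double>& a, int64_t m,
                    Uplo uplo, Diag diag, int64_t i, int64_t j) {
    if (i == j && diag == Diag::Unit)
        return 1.0;                        // implicit unit diagonal; storage ignored
    bool inside = (uplo == Uplo::Lower) ? (i >= j) : (i <= j);
    return inside ? a[i + j * m] : 0.0;    // outside the trapezoid reads as zero
}
```

For example, on a 3 by 2 matrix whose stored (1, 1) entry is 9, the lower unit view reads 1 there, while a Diag::NonUnit view reads 9.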

Slicing for Square Requirements (Lines 36-39)

int64_t min_mn = std::min( m, n );
auto A_square = A.slice( 0, min_mn-1, 0, min_mn-1 );

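Triangular, Symmetric, and Hermitian matrices must be square. If A is rectangular (m != n), it cannot be converted directly to these types. Instead, we use slice to create a square view A_square of the leading min_mn by min_mn block of A. The row and column bounds passed to slice are inclusive, which is why the last index is min_mn-1. With m = 2000 and n = 1000, A_square is 1000 by 1000. Like Lz, A_square is a view that shares A's tiles.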
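The inclusive-bounds convention makes the size arithmetic of a slice easy to get wrong by one. Here is a small sketch of that arithmetic. The Dims struct and both helpers are hypothetical and exist only for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical helper type: dimensions of a sub-matrix.
struct Dims { int64_t rows, cols; };

// Size of the block selected by INCLUSIVE bounds i1..i2, j1..j2,
// the convention used by A.slice( i1, i2, j1, j2 ) in the example.
Dims slice_dims(int64_t i1, int64_t i2, int64_t j1, int64_t j2) {
    return { i2 - i1 + 1, j2 - j1 + 1 };
}

// Leading square block of an m x n matrix: rows and columns 0 .. min(m, n)-1.
Dims leading_square(int64_t m, int64_t n) {
    int64_t k = std::min(m, n);
    return slice_dims(0, k - 1, 0, k - 1);
}
```

For the example's 2000 by 1000 matrix, leading_square(2000, 1000) gives a 1000 by 1000 block. Passing min_mn instead of min_mn-1 as the upper bound would request one row and one column too many.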

Triangular Views (Lines 41-48)

slate::TriangularMatrix<scalar_type>
    L( slate::Uplo::Lower, slate::Diag::Unit, A_square );

slate::TriangularMatrix<scalar_type>
    U( slate::Uplo::Upper, slate::Diag::NonUnit, A_square );

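Here we create lower (L) and upper (U) triangular views of the same A_square.

  • L references only the strictly lower triangle of A_square. Because of Diag::Unit, its diagonal is taken to be 1 and the stored diagonal is ignored.

  • U references the upper triangle, including the diagonal (Diag::NonUnit).

  • Because L never reads the diagonal, L and U can share one square array without conflict. LU factorization returns its unit lower factor L and its upper factor U in exactly this packed layout.

  • Triangular views are the operands of routines such as triangular solve (TRSM) and triangular multiply (TRMM), and they describe the factors produced by LU and Cholesky factorization.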
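The following sketch shows why the Unit/NonUnit distinction matters when L and U share storage. It solves (L U) x = b with a unit lower L and a non-unit upper U packed into one column-major array, much as the views L and U share A_square. It is an unblocked, unpivoted, single-process sketch with a hypothetical function name. It is not SLATE's trsm, which works on distributed tiles.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Solve (L U) x = b, where L (unit lower) and U (non-unit upper) are packed
// into ONE column-major n x n array `a`. Hypothetical helper; no pivoting.
std::vector<double> lu_packed_solve(const std::vector<double>& a, int64_t n,
                                    std::vector<double> b) {
    auto A = [&](int64_t i, int64_t j) { return a[i + j * n]; };
    // Forward substitution with L (Diag::Unit): the stored diagonal,
    // which belongs to U, is never read.
    for (int64_t i = 0; i < n; ++i)
        for (int64_t j = 0; j < i; ++j)
            b[i] -= A(i, j) * b[j];
    // Back substitution with U (Diag::NonUnit): divide by the diagonal.
    for (int64_t i = n - 1; i >= 0; --i) {
        for (int64_t j = i + 1; j < n; ++j)
            b[i] -= A(i, j) * b[j];
        b[i] /= A(i, i);
    }
    return b;  // b now holds x
}
```

Consider L = [[1, 0], [2, 1]] and U = [[2, 1], [0, 3]]. Packed column-major, they form a = {2, 2, 1, 3}, and L U = [[2, 1], [4, 5]]. The right-hand side b = {3, 9} then yields x = {1, 1}.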

Symmetric and Hermitian Views (Lines 50-57)

slate::SymmetricMatrix<scalar_type>
    S( slate::Uplo::Upper, A_square );

slate::HermitianMatrix<scalar_type>
    H( slate::Uplo::Upper, A_square );
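
  • S represents a symmetric matrix where \(A_{ji} = A_{ij}\). Only the upper triangle of A_square is referenced; the lower triangle is implied by symmetry.

  • H represents a Hermitian matrix where \(A_{ji} = \bar{A}_{ij}\), again referencing only the upper triangle. For real scalar_type (float, double) the conjugate has no effect, so symmetric and Hermitian mean the same thing.

  • These types let solvers exploit the symmetry. Cholesky (for Hermitian positive definite matrices) and \(LDL^T\)-type factorizations (for indefinite ones) work on only one triangle, and Cholesky needs about \(n^3/3\) flops, compared with \(2n^3/3\) for LU.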
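Here is a minimal sketch of how a view that references only the upper triangle can still produce every element. The hypothetical functions hermitian_at and symmetric_at read a column-major n by n array and never touch entries below the diagonal. Following the reference BLAS convention for Hermitian routines, the Hermitian version assumes a real diagonal and drops any stored imaginary part there. That choice is an assumption of this sketch.

```cpp
#include <cassert>
#include <complex>
#include <cstdint>
#include <vector>

using cplx = std::complex<double>;

// Element (i, j) of a Hermitian matrix stored in column-major `a`, reading
// ONLY the upper triangle (the analogue of Uplo::Upper).
cplx hermitian_at(const std::vector<cplx>& a, int64_t n, int64_t i, int64_t j) {
    if (i == j)
        return cplx(a[i + i * n].real(), 0.0);  // diagonal assumed real
    if (i < j)
        return a[i + j * n];                    // stored upper entry
    return std::conj(a[j + i * n]);             // mirrored and conjugated
}

// Same for a symmetric matrix: A(i, j) = A(j, i), with no conjugation.
cplx symmetric_at(const std::vector<cplx>& a, int64_t n, int64_t i, int64_t j) {
    return (i <= j) ? a[i + j * n] : a[j + i * n];
}
```

Suppose the stored lower entry holds garbage, such as (99, 99). Neither function reads it. The Hermitian view still returns (2, -3) at (1, 0) when the upper entry (0, 1) is (2, 3).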

  1// ex02_conversion.cc
  2// conversion between matrix types
  3
  4/// !!!   Lines between `//---------- begin label`          !!!
  5/// !!!             and `//---------- end label`            !!!
  6/// !!!   are included in the SLATE Users' Guide.           !!!
  7
  8#include <slate/slate.hh>
  9
 10#include "util.hh"
 11
 12int mpi_size = 0;
 13int mpi_rank = 0;
 14int grid_p = 0;
 15int grid_q = 0;
 16
 17//------------------------------------------------------------------------------
 18template <typename scalar_type>
 19void test_conversion()
 20{
 21    print_func( mpi_rank );
 22
 23    int64_t m=2000, n=1000, nb=256;
 24
 25    //---------- begin convert
 26    // A is defined to be a general m x n matrix of type scalar_type
 27    // (float, std::complex<float>, double, std::complex<double>, etc.).
 28    slate::Matrix<scalar_type>
 29        A( m, n, nb, grid_p, grid_q, MPI_COMM_WORLD );
 30
 31    // Lz is a trapezoid matrix view of the lower trapezoid of A,
 32    // assuming Unit diagonal.
 33    slate::TrapezoidMatrix<scalar_type>
 34        Lz( slate::Uplo::Lower, slate::Diag::Unit, A );
 35
 36    // Triangular, symmetric, and Hermitian matrices must be square --
 37    // take square slice if needed.
 38    int64_t min_mn = std::min( m, n );
 39    auto A_square = A.slice( 0, min_mn-1, 0, min_mn-1 );
 40
 41    // L is a triangular matrix view of the lower triangle of A,
 42    // assuming Unit diagonal.
 43    slate::TriangularMatrix<scalar_type>
 44        L( slate::Uplo::Lower, slate::Diag::Unit, A_square );
 45
 46    // U is a triangular matrix view of the upper triangle of A.
 47    slate::TriangularMatrix<scalar_type>
 48        U( slate::Uplo::Upper, slate::Diag::NonUnit, A_square );
 49
 50    // S is a symmetric matrix view of the upper triangle of A.
 51    slate::SymmetricMatrix<scalar_type>
 52        S( slate::Uplo::Upper, A_square );
 53
 54    // H is a Hermitian matrix view of the upper triangle of A.
 55    slate::HermitianMatrix<scalar_type>
 56        H( slate::Uplo::Upper, A_square );
 57    //---------- end convert
 58}
 59
 60//------------------------------------------------------------------------------
 61int main( int argc, char** argv )
 62{
 63    try {
 64        // Parse command line to set types for s, d, c, z precisions.
 65        bool types[ 4 ];
 66        parse_args( argc, argv, types );
 67
 68        int provided = 0;
 69        slate_mpi_call(
 70            MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided ) );
 71        assert( provided == MPI_THREAD_MULTIPLE );
 72
 73        slate_mpi_call(
 74            MPI_Comm_size( MPI_COMM_WORLD, &mpi_size ) );
 75
 76        slate_mpi_call(
 77            MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank ) );
 78
 79        // Determine p-by-q grid for this MPI size.
 80        grid_size( mpi_size, &grid_p, &grid_q );
 81        if (mpi_rank == 0) {
 82            printf( "mpi_size %d, grid_p %d, grid_q %d\n",
 83                    mpi_size, grid_p, grid_q );
 84        }
 85
 86        // so random_matrix is different on different ranks.
 87        srand( 100 * mpi_rank );
 88
 89        if (types[ 0 ]) {
 90            test_conversion< float >();
 91        }
 92
 93        if (types[ 1 ]) {
 94            test_conversion< double >();
 95        }
 96
 97        if (types[ 2 ]) {
 98            test_conversion< std::complex<float> >();
 99        }
100
101        if (types[ 3 ]) {
102            test_conversion< std::complex<double> >();
103        }
104
105        slate_mpi_call(
106            MPI_Finalize() );
107    }
108    catch (std::exception const& ex) {
109        fprintf( stderr, "%s", ex.what() );
110        return 1;
111    }
112    return 0;
113}