Skip to content
Snippets Groups Projects
Select Git revision
  • master
  • eigensolver
  • jw-booster
  • updateQR
4 results

pxgeqrf.cpp

Blame
  • pxgeqrf.cpp 8.85 KiB
    #include <mpi.h>
    #include <stdio.h>
    #include <iostream>
    #include <vector>
    #include <cmath>
    #include <time.h>
    #include <iomanip>
    #include <chrono>
    #include <cstdlib>
    #include <boost/program_options.hpp>
    
    #include <mkl_scalapack.h>
    #include <mkl_blacs.h>
    #include <mkl_pblas.h>
    
    #include <mpi.h>
    
    #include "./impl/tsqr_mpi.hpp"
    
    #if defined(USE_MULTITHREADS)
    #include <omp.h>
    #endif
    
    // Larger of two values. This is a function-like macro, so each argument may
    // be evaluated twice; pass only side-effect-free expressions.
    #define MAX(x, y) (((x) > (y)) ? (x) : (y))

    // Named scalar constants. The BLACS/ScaLAPACK routines called in this file
    // take their scalar arguments by address, so literals need addressable storage.
    const double zero   = 0.0;
    const double one    = 1.0;
    const double two    = 2.0;
    const double negone = -1.0;

    const MKL_INT i_zero   = 0;
    const MKL_INT i_one    = 1;
    const MKL_INT i_four   = 4;
    const MKL_INT i_negone = -1;

    const char trans = 'N';

    // Nine-entry integer array used as a ScaLAPACK array descriptor (see descinit).
    using MDESC = MKL_INT[9];
    
    
    void ScatterMatrix(int ctxt, int M, int N, int Mb, int Nb, int nrows, int ncols, double *A_glob, double *A_loc, int mpiroot){
    
        MKL_INT iam, nprocs;
        MKL_INT procrows, proccols;
        MKL_INT myrow, mycol;
        MPI_Datatype type1, type2;
    
        int TAG = 3;
        MPI_Status status;
    
        blacs_pinfo( &iam, &nprocs );
    
        int sendr = 0, sendc = 0, recvr = 0, recvc = 0;
    
        blacs_gridinfo( &ctxt, &procrows, &proccols, &myrow, &mycol);
    
        for (int r = 0; r < M; r += Mb, sendr = (sendr + 1) % procrows) {
    	
            sendc = 0;
    
            int nr = Mb;
    
            if (M-r < Mb){
    
                nr = M-r;
    
    	}
    
    	for (int c = 0; c < N; c += Nb, sendc = (sendc + 1) % proccols) {
    
                int nc = Nb;
    
                if (N-c < Nb){
                
    	        nc = N-c;
    	    }
    
    	    MPI_Type_vector(nc, nr, M, MPI_DOUBLE,&type1);
                MPI_Type_commit(&type1);
                MPI_Type_vector(nc, nr, nrows, MPI_DOUBLE,&type2);
                MPI_Type_commit(&type2);
    
    	    if(myrow == 0 && mycol == 0){
    
    		MPI_Send(A_glob + M * c + r, 1, type1, sendr + sendc * procrows, TAG, MPI_COMM_WORLD);
                }
    
                if (myrow == sendr && mycol == sendc) {
    
    		MPI_Recv(A_loc + nrows * recvc + recvr, 1, type2, 0, TAG, MPI_COMM_WORLD, &status);
    
                    recvc = (recvc+nc)%ncols;
    
                }
    
            }      
       
            if (myrow == sendr){
    
               recvr = (recvr+nr)%nrows;
    
            }
        }
    }
    
    void GatherMatrix(int ctxt, int M, int N, int Mb, int Nb, int nrows, int ncols, double *A_loc, double *A_glob, int mpiroot){
    
        MKL_INT iam, nprocs;
        MKL_INT procrows, proccols;
        MKL_INT myrow, mycol;
        MPI_Datatype type1, type2;
    
        int TAG = 3;
        MPI_Status status;
    
        blacs_pinfo( &iam, &nprocs );
    
        int sendr = 0, sendc = 0, recvr = 0, recvc = 0;
    
        blacs_gridinfo( &ctxt, &procrows, &proccols, &myrow, &mycol);
    
        for (int r = 0; r < M; r += Mb, sendr = (sendr + 1) % procrows) {
    
            sendc = 0;
    
            int nr = Mb;
    
            if (M - r < Mb){
            
                nr = M - r;
    	
    	}
            
            for (int c = 0; c < N; c += Nb, sendc = (sendc + 1) % proccols){
    
    	    int nc = Nb;
    
                if (N-c < Nb){
    
                    nc = N - c;
    
    	    }
    
                MPI_Type_vector(nc, nr, M, MPI_DOUBLE,&type1);
                MPI_Type_commit(&type1);
                MPI_Type_vector(nc, nr, nrows, MPI_DOUBLE,&type2);
                MPI_Type_commit(&type2);
    
                if (myrow == sendr && mycol == sendc) {
    
                    MPI_Send(A_loc + nrows * recvc + recvr, 1, type2, 0, TAG, MPI_COMM_WORLD);
    
                    recvc = (recvc+nc)%ncols;
                }
    
                if (myrow == 0 && mycol == 0) {
    
                    MPI_Recv(A_glob + M * c + r, 1, type1, sendr + sendc * procrows, TAG, MPI_COMM_WORLD, &status);
                }
    
            }
    
    	if (myrow == sendr){
             
    	   recvr = (recvr+nr)%nrows;
    
    	}
        }
    }
    
    int main(int argc, char** argv) {
    
        using namespace boost::program_options;
    
        options_description desc_commandline;
    
        desc_commandline.add_options()
            ("m", value<int>()->default_value(1024), "Size m.")
            ("n", value<int>()->default_value(32), "Size n.")
    //
            ("mb", value<int>()->default_value(32), "Size mb.")
    //
            ("nb", value<int>()->default_value(32), "Size nb.")
            ("rpoc", value<int>()->default_value(1), "proc nb in row dimension.")
            ("debug", value<std::string>()->default_value("no"), "(debug) => print matrices")
            ("explicit-Q", bool_switch()->default_value(false), "Explicit generation of Q")
            ("validate", bool_switch()->default_value(false), "validation of the implementation")
            ("omp:threads", value<int>()->default_value(1), "OpenMP thread number.")
            ("repetitions,r", value<int>()->default_value(1), "Number of repetitions.");
    
        variables_map vm;
    
        store(parse_command_line(argc, argv, desc_commandline), vm);
    
        int m = vm["m"].as<int>();
        int n = vm["n"].as<int>();
        int nb = vm["nb"].as<int>();
    
    //
        int mb = vm["mb"].as<int>();
        mb = nb;
    //
        int rpoc_init = vm["rpoc"].as<int>();
        int rep = vm["repetitions"].as<int>();
        bool explicitQ=vm["explicit-Q"].as<bool>();
        bool validate = vm["validate"].as<bool>();
    
        std::string out = vm["debug"].as<std::string>();
    
        bool debug = false;
    
        if(out == "yes"){
            debug = true;
        }
    
        int num_threads = vm["omp:threads"].as<int>();
    
    #if defined(USE_MULTITHREADS)
        omp_set_num_threads(num_threads);
    #endif
    
    
        MDESC   descA;
        MKL_INT iam, nprocs, ictxt, myrow, mycol, nprow, npcol;
        MKL_INT i, j, info;
    
        blacs_pinfo( &iam, &nprocs );
    
        nprow = (MKL_INT) nprocs;
        npcol = (MKL_INT) (nprocs / nprow);
    
        blacs_get( &i_negone, &i_zero, &ictxt );
        blacs_gridinit( &ictxt, "R", &nprow, &npcol );
        blacs_gridinfo( &ictxt, &nprow, &npcol, &myrow, &mycol);
    
        std::vector<double> A_Dist(m * n);
    
        if(iam == 0){
            for(int i = 0; i < m * n; i++){
    //	    A_Dist[i] = (double)rand()/RAND_MAX*10.-5.;
                A_Dist[i] = (double)rand()/RAND_MAX*5.-10.;
    
            }
        }
    
        if(debug && iam == 0)
           showMatrix(A_Dist.data(), m, n);
    
        MKL_INT m_A = numroc( &m, &mb, &myrow, &i_zero, &nprow );
    
        MKL_INT a_lld = MAX( m_A, 1 );
    
        MKL_INT n_A = numroc( &n, &nb, &mycol, &i_zero, &npcol );
    
        std::vector<double> A(m_A * n_A);
    
        std::fill(A.begin(), A.end(), (double)1.0);
    
        ScatterMatrix(ictxt, m, n, mb, nb, m_A, n_A, A_Dist.data(), A.data(), 0);
    
        std::vector<double> Q(m * n);
    
        std::vector<double> V(m * n);
    
        std::vector<double> R(n * n);
        
        descinit( descA, &m, &n, &mb, &nb, &i_zero, &i_zero, &ictxt, &a_lld, &info );
    
        std::vector<double> tau(n_A);
    
        double* work; double numwork; int lwork;
    
        lwork = -1;
    
        std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now();
    
        pdgeqrf(&m, &n, A.data(), &i_one, &i_one, descA, tau.data(), &numwork, &lwork, &info);
    
        lwork = (int)numwork;
    
        auto wptr = std::unique_ptr<double[]> {
            new double[ lwork ]
        };
    
        work = wptr.get();
    
        pdgeqrf(&m, &n, A.data(), &i_one, &i_one, descA, tau.data(), work, &lwork, &info);
    
        if(debug || validate){
    
            GatherMatrix(ictxt, m, n, mb, nb, m_A, n_A, A.data(), V.data(), 0);
            
            ConstructR(n, R.data(), V.data(), m);
    	
        }
    
        if(explicitQ){
    
    	lwork = -1;
    
            pdorgqr(&m, &n, &n, A.data(), &i_one, &i_one, descA, tau.data(), &numwork, &lwork, &info);
    
            lwork = (int)numwork;
    
            wptr = std::unique_ptr<double[]> {
                new double[ lwork ]
            };
    
            work = wptr.get();
    
            pdorgqr(&m, &n, &n, A.data(), &i_one, &i_one, descA, tau.data(), work, &lwork, &info);
    
            if(debug || validate){
            
                GatherMatrix(ictxt, m, n, mb, nb, m_A, n_A, A.data(), Q.data(), 0);
    
            }
        }
    
        std::chrono::high_resolution_clock::time_point end = std::chrono::high_resolution_clock::now();
    
        std::chrono::duration<double> elapsed = std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
    
    
        if(explicitQ){
    
            if(debug && iam == 0){
    
                showMatrix(Q.data(), m, n);
    
            }
    
        } else if(debug && iam == 0){
    
            showMatrix(R.data(), n, n);
    
        }
    
        if(explicitQ && validate && iam == 0){
    
            double relOrthogError, relError;
    
            auto valid = validation(m, n, Q.data(), m, R.data(), n, A_Dist.data(), m);
    
            relOrthogError = std::get<0>(valid);
    
            relError = std::get<1>(valid);
    
            std::cout << "----------------------" << std::endl;
    
            std::cout << "||Q^H Q - I||_oo / (eps Max(m,n)) =: " << relOrthogError << ". ";
    
            if(relOrthogError > 10.0){
    
                std::cout << "Unacceptably large relative orthogonality error" ;
    
            }
    
            std::cout << std::endl;
    
            std::cout << "||A - QR||_oo / (eps Max(m,n) ||A||_1) =: "<< relError << ". ";
    
            if(relError > 10.0){
    
                std::cout << "Unacceptably large relative error" ;
    
            }
    
            std::cout << std::endl;
    
            std::cout << "----------------------" << std::endl;
    
        }
    
        if(explicitQ){
            if(iam == 0) std::cout << "PxGEQRF ExplicitQ," << m << "," << n << "," << nb << "," << nprocs << "," << num_threads << "," << elapsed.count() << std::endl;
        }else{
            if(iam == 0) std::cout << "PxGEQRF," << m << "," << n << "," << nb << "," << nprocs << "," << num_threads << "," << elapsed.count() << std::endl;
        }
    
    
    
        return 0;
    
    }