Skip to content
Snippets Groups Projects
Select Git revision
  • a4966e98788fc90ffbffef74874d63ed9a0d7f96
  • master default protected
  • eigensolver
  • jw-booster
  • updateQR
5 results

gemm_mkl.cpp

Blame
  • user avatar
    Xinzhe Wu authored
    a4966e98
    History
    gemm_mkl.cpp 2.24 KiB
    #include <stdio.h>
    #include <time.h>

    #include <algorithm>
    #include <chrono>
    #include <cmath>
    #include <cstddef>
    #include <cstdlib>
    #include <exception>
    #include <iomanip>
    #include <iostream>
    #include <numeric>
    #include <vector>

    #include <boost/foreach.hpp>
    #include <boost/program_options.hpp>
    #include <boost/tokenizer.hpp>

    #include <mkl.h>

    #include <omp.h>
    
    int main(int argc, char** argv) {
    
        using namespace boost::program_options;
    
        options_description desc_cmdline;
    
        desc_cmdline.add_options()
            ("size,s", value<int>()->default_value(-1), "Matrix size (if > 0, overrides m, n, k).")
            ("m", value<int>()->default_value(1024), "Size m.")
            ("n", value<int>()->default_value(1024), "Size n.")
            ("k", value<int>()->default_value(512), "Size k.")
            ("omp:threads", value<int>()->default_value(1), "OpenMP thread number.")
            ("repeat", value<int>()->default_value(1), "repeat times.");
    
        variables_map vm;
    
        store(parse_command_line(argc, argv, desc_cmdline),vm);
    
        int m = vm["m"].as<int>();
        int n = vm["n"].as<int>();
        int k = vm["k"].as<int>();
        int s = vm["size"].as<int>();
        int repeat = vm["repeat"].as<int>();
    
        if (s > 0)
            m = n = k = s;
    
        int num_threads = vm["omp:threads"].as<int>();
        omp_set_num_threads(num_threads);
    
        const double alpha = 1.0;
    
        const double beta = 1.0;
    
        std::vector<double> A(m * k);
        std::vector<double> B(k * n);
        std::vector<double> C(m * n);
    
        std::generate(A.begin(), A.end(), [ttt = 1] () mutable { return ttt++; });
        std::generate(B.begin(), B.end(), [ttt = 2] () mutable { return ttt++; });
        std::fill(C.begin(), C.end(), (double)0.0);
    
        double flops = 2.0 * m * n * k / 1e9;
    
        std::chrono::high_resolution_clock::time_point start, end;
    
        std::chrono::duration<double> elapsed;
    
        start = std::chrono::high_resolution_clock::now();
    
        for(int i = 0; i < repeat; i++){
    
            dgemm("N", "N", &m, &n, &k, &alpha, A.data(), &m, B.data(), &k, &beta, C.data(), &m);
    
        }
    
        end = std::chrono::high_resolution_clock::now();
    
        elapsed = std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
    
        std::cout << m << "," << n << "," << k << "," << "GEMM (mkl)," << num_threads << "," << 1 << "," << elapsed.count() << "," <<  repeat * flops / elapsed.count() << std::endl;
    
        return 0;
    }