Skip to content
Snippets Groups Projects
Commit a4966e98 authored by Xinzhe Wu's avatar Xinzhe Wu
Browse files

updated updateQR

parent a3258461
Branches master updateQR
No related tags found
No related merge requests found
......@@ -34,6 +34,7 @@
./build
build
build2
build3
CMakeLists.txt.user
CMakeCache.txt
CMakeFiles
......
......@@ -158,6 +158,7 @@ ADD_SUBDIRECTORY(HEEVD)
ADD_SUBDIRECTORY(updateQR)
ADD_SUBDIRECTORY(TSQR)
ADD_SUBDIRECTORY(GEMM)
include(Dart)
include(CPack)
......
# Build the MKL GEMM benchmark only when MKL_MT evaluates to true.
if(MKL_MT)
# gemm_mkl.exe is compiled from the single source file gemm_mkl.cpp.
add_executable(gemm_mkl.exe gemm_mkl.cpp)
# Link against mkl_common_mt and whatever ${OPENMP_FLAGS} expands to;
# both are defined outside this snippet.
target_link_libraries(gemm_mkl.exe PRIVATE mkl_common_mt ${OPENMP_FLAGS})
endif()
#include <stdio.h>
#include <time.h>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <exception>
#include <iomanip>
#include <iostream>
#include <vector>

#include <boost/foreach.hpp>
#include <boost/program_options.hpp>
#include <boost/tokenizer.hpp>

#include <mkl.h>
#include <omp.h>
int main(int argc, char** argv) {
using namespace boost::program_options;
options_description desc_cmdline;
desc_cmdline.add_options()
("size,s", value<int>()->default_value(-1), "Matrix size (if > 0, overrides m, n, k).")
("m", value<int>()->default_value(1024), "Size m.")
("n", value<int>()->default_value(1024), "Size n.")
("k", value<int>()->default_value(512), "Size k.")
("omp:threads", value<int>()->default_value(1), "OpenMP thread number.")
("repeat", value<int>()->default_value(1), "repeat times.");
variables_map vm;
store(parse_command_line(argc, argv, desc_cmdline),vm);
int m = vm["m"].as<int>();
int n = vm["n"].as<int>();
int k = vm["k"].as<int>();
int s = vm["size"].as<int>();
int repeat = vm["repeat"].as<int>();
if (s > 0)
m = n = k = s;
int num_threads = vm["omp:threads"].as<int>();
omp_set_num_threads(num_threads);
const double alpha = 1.0;
const double beta = 1.0;
std::vector<double> A(m * k);
std::vector<double> B(k * n);
std::vector<double> C(m * n);
std::generate(A.begin(), A.end(), [ttt = 1] () mutable { return ttt++; });
std::generate(B.begin(), B.end(), [ttt = 2] () mutable { return ttt++; });
std::fill(C.begin(), C.end(), (double)0.0);
double flops = 2.0 * m * n * k / 1e9;
std::chrono::high_resolution_clock::time_point start, end;
std::chrono::duration<double> elapsed;
start = std::chrono::high_resolution_clock::now();
for(int i = 0; i < repeat; i++){
dgemm("N", "N", &m, &n, &k, &alpha, A.data(), &m, B.data(), &k, &beta, C.data(), &m);
}
end = std::chrono::high_resolution_clock::now();
elapsed = std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
std::cout << m << "," << n << "," << k << "," << "GEMM (mkl)," << num_threads << "," << 1 << "," << elapsed.count() << "," << repeat * flops / elapsed.count() << std::endl;
return 0;
}
......@@ -281,40 +281,57 @@ int main(int argc, char** argv) {
std::vector<double> tau(nevx);
std::vector<double> Vtmp;
std::vector<double> Vtmp(m * nevx);
std::chrono::high_resolution_clock::time_point t1, t2, t3, t4;
std::chrono::high_resolution_clock::time_point t1, t2;
std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now();
for(int i = 0; i < num_iter; i++){
//monitor of filter V<-V
for(int j = m * fixed; j < V.size(); j++){
V[j] += (double)rand()/RAND_MAX*2.0-1.0;
}
if(printtime){
t1 = std::chrono::high_resolution_clock::now();
}
geqrf(m, nevx, V.data(), m, tau.data());
std::memcpy(V.data(), Vtmp.data(), m * fixed * sizeof(double));
t2 = std::chrono::high_resolution_clock::now();
mqr((char *)"L", (char *)"T", m, nevx - fixed, fixed, V.data(), m, tau.data(), V.data() + m * fixed, m);
std::chrono::duration<double> duration_1 = std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1);
//QR
geqrf(m - fixed, nevx - fixed, V.data() + fixed * m + fixed, m, tau.data() + fixed);
//
for(int j = 0; j < V.size(); j++){
Vtmp = V;
V[j] = j + 2.1;
if(printtime){
t2 = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> time = std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1);
std::cout << "updateQR," << i+1 << "," << time.count() << std::endl;
}
t3 = std::chrono::high_resolution_clock::now();
gqr(m, nevx, nevx, V.data(), m, tau.data());
geqrf_c(m, nevx, V.data(), m, tau.data());
//reduce to active space and compute residual
t4 = std::chrono::high_resolution_clock::now();
//deflation and locking: here a "locked" number of eigenpairs are approximated
std::memcpy(Y.data() + fixed * m, V.data() + fixed * m, m * locked * sizeof(double));
std::chrono::duration<double> duration_2 = std::chrono::duration_cast<std::chrono::duration<double>>(t4 - t3);
fixed += locked;
std::cout << "MKL GEQRF = " << duration_1.count() << "s" << ", My GEQRF = " << duration_2.count() << "s." << std::endl;
std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now();
}
std::chrono::high_resolution_clock::time_point end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> time_span = std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
if(debug == "yes")
showMatrix(V, m, nevx, "V2");
......@@ -322,7 +339,7 @@ int main(int argc, char** argv) {
if(debug == "yes")
showMatrix(Y, m, nev, "Y2");
// std::cout << "UpdateQR," << m << "," << nev << "," << nex << "," << locked << "," << time_span.count() << "," << num_threads << std::endl;
std::cout << "UpdateQR," << m << "," << nev << "," << nex << "," << locked << "," << time_span.count() << "," << num_threads << std::endl;
return 0;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment