#include <stdio.h>

#include <chrono>
#include <climits>
#include <cstdlib>
#include <iostream>

#include <omp.h>

#ifdef USE_MPI
#include <mpi.h>
#endif

using namespace std::chrono;

/*
 * Opens a parallel team whose size is the number of processors in OpenMP
 * place `socket_num` (close binding) and has every thread print its ID.
 *
 * omp_get_place_num_procs() returns 0 for a place number outside
 * [0, omp_get_num_places()), and num_threads() requires a positive value,
 * so that case is reported on stderr and no team is spawned.
 *
 * NOTE(review): proc_bind(close) binds relative to the *calling* thread's
 * place partition, so this assumes the caller already runs on place
 * `socket_num` — TODO confirm at call sites.
 */
void socket_init(int socket_num)
{
   const int n_procs = omp_get_place_num_procs(socket_num);

   if (n_procs < 1) {
      fprintf(stderr,
              "socket_init: place %d has no processors (invalid place or OMP_PLACES unset?)\n",
              socket_num);
      return;
   }

   #pragma omp parallel num_threads(n_procs) proc_bind(close)
   {
      printf("Reporting in from socket %d, thread ID: %d\n",
             socket_num, omp_get_thread_num());
   }
}

/*
 * Demonstrates a "one active domain" layout: place 0 opens a parallel team
 * sized to its processor count and each thread prints a message; every other
 * place only prints once from the calling thread.
 *
 * omp_get_place_num_procs() returns 0 for an invalid place (or when no places
 * are defined), which num_threads() does not accept, so that case is reported
 * on stderr instead of spawning a team.
 */
void numa_in_operations(int socket_num){

   if (socket_num != 0) {
      printf("The other sockets do nothing\n");
      return;
   }

   // Only the working socket needs its processor count.
   const int n_procs = omp_get_place_num_procs(socket_num);

   if (n_procs < 1) {
      fprintf(stderr,
              "numa_in_operations: place %d has no processors (OMP_PLACES unset?)\n",
              socket_num);
      return;
   }

   #pragma omp parallel num_threads(n_procs)
   {
      printf("The first socket does the computation in parallel\n");
   }
}

// Parses `text` as a strictly positive int. Returns false (leaving *out
// untouched) for empty text, trailing non-digit characters, or out-of-range
// values.
static bool parse_positive_int(const char* text, int* out)
{
   if (text == nullptr || *text == '\0')
      return false;

   char* end = nullptr;
   const long value = std::strtol(text, &end, 10);

   // strtol saturates to LONG_MIN/LONG_MAX on overflow; the range check
   // below rejects both.
   if (*end != '\0' || value <= 0 || value > INT_MAX)
      return false;

   *out = static_cast<int>(value);
   return true;
}

// Runs the timed single-place reduction and prints the CSV result on rank 0.
// Returns 0 on success, 1 if the OpenMP place configuration cannot support
// the requested thread count.
static int run_single_domain_sum(int rank, int num_thread)
{
   // omp_set_nested is deprecated since OpenMP 5.0 but still enables nesting
   // on older runtimes; max_active_levels(2) covers newer ones.
   omp_set_nested(1);
   omp_set_max_active_levels(2);

   const int n_sockets = omp_get_num_places();

   // With no places defined, num_threads(0) would be non-conforming and the
   // division below would be by zero.
   if (n_sockets < 1) {
      if (rank == 0)
         fprintf(stderr,
                 "Error: no OpenMP places defined; set OMP_PLACES (e.g. OMP_PLACES=sockets)\n");
      return 1;
   }

   const int thread_per_socket = num_thread / n_sockets;

   // num_threads() in the nested region requires a positive value.
   if (thread_per_socket < 1) {
      if (rank == 0)
         fprintf(stderr,
                 "Error: num_threads (%d) must be at least the number of places (%d)\n",
                 num_thread, n_sockets);
      return 1;
   }

   const int size = 100000000;

   // Not value-initialized, so the serial fill loop below is the first write
   // to every element (relevant if the OS uses first-touch page placement).
   double *b = new double[size];

   for (int i = 0; i < size; i++) {
      b[i] = i + 1;
   }

   double sum = 0;

   const auto t1 = std::chrono::high_resolution_clock::now();

   // One outer thread per place (spread); only the thread bound to place 0
   // runs the nested reduction, the others fall through and idle.
   #pragma omp parallel num_threads(n_sockets) shared(sum) proc_bind(spread)
   {
      if (omp_get_place_num() == 0) {
         #pragma omp parallel for reduction(+:sum) num_threads(thread_per_socket)
         for (int i = 0; i < size; i++) {
            sum += b[i];
         }
      }
   }

   const auto t2 = std::chrono::high_resolution_clock::now();

   const auto t = std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1);

   if (rank == 0)
      std::cout << "OMP (1 DOMAIN)," << size << "," << sum << "," << num_thread << "," << n_sockets << "," << t.count() << std::endl;

   delete [] b;
   return 0;
}

/*
 * Single-domain OpenMP reduction benchmark.
 *
 * Usage: <program> <num_threads>
 *
 * Rank 0 (always 0 without USE_MPI) prints one CSV line:
 *   label, array size, sum, num_threads, number of places, elapsed seconds.
 * Returns 1 (after MPI_Finalize when built with MPI) on invalid arguments or
 * an unusable place configuration.
 */
int main(int argc, char** argv)
{
   int rank = 0;

#ifdef USE_MPI
   MPI_Init(&argc, &argv);

   MPI_Comm_rank(MPI_COMM_WORLD, &rank);
#endif

   int status = 0;
   int num_thread = 0;

   if (argc < 2 || !parse_positive_int(argv[1], &num_thread)) {
      if (rank == 0)
         fprintf(stderr,
                 "Usage: %s <num_threads>  (num_threads must be a positive integer)\n",
                 argc > 0 ? argv[0] : "<program>");
      status = 1;
   } else {
      status = run_single_domain_sum(rank, num_thread);
   }

#ifdef USE_MPI
   // Finalize on every path, including argument errors.
   MPI_Finalize();
#endif

   return status;
}