From 4cce41afbba150247095ad4be07c3c60a3caabe9 Mon Sep 17 00:00:00 2001
From: Rene Halver <r.halver@fz-juelich.de>
Date: Wed, 24 Feb 2021 12:09:02 +0100
Subject: [PATCH] Revert "updated ForceBased method and VTK output for VTK >=
 9.0"

This reverts commit 9c72abf605c88b074ea212e5dead60d6bd3e58e1.
---
 CMakeLists.txt             |   5 -
 example/ALL_test.cpp       | 485 +++++++++++++----------
 include/ALL.hpp            |  54 +--
 include/ALL_ForceBased.hpp | 769 +++++++++++++++++++++++++------------
 4 files changed, 820 insertions(+), 493 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index beee3f9..de77a50 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -76,13 +76,10 @@ if(CM_ALL_VTK_OUTPUT)
     if(NOT VTK_FOUND)
         message(FATAL_ERROR "VTK not found, help CMake to find it by setting VTK_LIBRARY and VTK_INCLUDE_DIR")
     else()
-        message("VTK found in ${VTK_DIR}")
         if (VTK_MAJOR_VERSION GREATER_EQUAL 9)
             find_package(VTK REQUIRED COMPONENTS
                 CommonCore
                 CommonDataModel
-                FiltersHybrid
-                FiltersModeling
                 FiltersProgrammable
                 IOParallelXML
                 IOXML
@@ -92,8 +89,6 @@ if(CM_ALL_VTK_OUTPUT)
             find_package(VTK REQUIRED COMPONENTS
                 vtkCommonCore
                 vtkCommonDataModel
-                vtkFiltersHybrid
-                vtkFiltersModeling
                 vtkFiltersProgrammable
                 vtkIOParallelXML
                 vtkIOXML
diff --git a/example/ALL_test.cpp b/example/ALL_test.cpp
index d21fc51..5827b7d 100644
--- a/example/ALL_test.cpp
+++ b/example/ALL_test.cpp
@@ -53,12 +53,6 @@
 #include <vtkVersion.h>
 #include <vtkXMLPPolyDataWriter.h>
 #include <vtkXMLPolyDataWriter.h>
-
-// for FORCEBASED test
-#include <vtkDelaunay3D.h>
-#include <vtkNew.h>
-#include <vtkDataSetSurfaceFilter.h>
-#include <vtkSelectEnclosedPoints.h>
 #endif
 
 #define BOX_SIZE 300.0
@@ -66,7 +60,7 @@
 #define N_GENERATE 1000
 #define SEED 123789456u
 #define N_LOOP 500
-#define OUTPUT_INTV 10
+#define OUTPUT_INTV 1
 #define MAX_NEIG 1024
 #define ALL_HISTOGRAM_DEFAULT_WIDTH 1.0
 
@@ -128,8 +122,8 @@ void generate_points(std::vector<ALL::Point<double>> &points,
  *   long + 4*n_point double        : point data
  ****************************************************************/
 
-void read_points(std::vector<ALL::Point<double>> &points,
-                 int &n_points, char *filename,
+void read_points(std::vector<ALL::Point<double>> &points, std::vector<double> l,
+                 std::vector<double> u, int &n_points, char *filename,
                  int dimension, int rank, MPI_Comm comm) {
   MPI_File file;
   MPI_Barrier(comm);
@@ -139,9 +133,9 @@ void read_points(std::vector<ALL::Point<double>> &points,
 
   n_points = 0;
 
-  int nRanks;
+  int n_ranks;
   long offset;
-  MPI_Comm_size(comm, &nRanks);
+  MPI_Comm_size(comm, &n_ranks);
 
   if (err) {
     if (rank == 0)
@@ -154,7 +148,7 @@ void read_points(std::vector<ALL::Point<double>> &points,
                    MPI_LONG, MPI_STATUS_IGNORE);
   // read number of points from file
   MPI_File_read_at(file,
-                   (MPI_Offset)(nRanks * sizeof(long) + rank * sizeof(int)),
+                   (MPI_Offset)(n_ranks * sizeof(long) + rank * sizeof(int)),
                    &n_points, 1, MPI_INT, MPI_STATUS_IGNORE);
 
   double values[4];
@@ -163,11 +157,11 @@ void read_points(std::vector<ALL::Point<double>> &points,
   for (int i = 0; i < n_points; ++i) {
     MPI_File_read_at(file,
                      (MPI_Offset)((offset + i) * block_size +
-                                  nRanks * (sizeof(int) + sizeof(long))),
+                                  n_ranks * (sizeof(int) + sizeof(long))),
                      &ID, 1, MPI_DOUBLE, MPI_STATUS_IGNORE);
     MPI_File_read_at(file,
                      (MPI_Offset)((offset + i) * block_size +
-                                  nRanks * (sizeof(int) + sizeof(long)) +
+                                  n_ranks * (sizeof(int) + sizeof(long)) +
                                   sizeof(long)),
                      values, 4, MPI_DOUBLE, MPI_STATUS_IGNORE);
     ALL::Point<double> p(dimension, &values[0], values[3], ID);
@@ -216,10 +210,10 @@ void read_points(std::vector<ALL::Point<double>> &points,
 // function to create a VTK output of the points in the system
 void print_points(std::vector<ALL::Point<double>> plist, int step,
                   ALL::LB_t method, MPI_Comm comm) {
-  int rank, nRanks;
+  int rank, n_ranks;
   static bool vtk_init = false;
   MPI_Comm_rank(comm, &rank);
-  MPI_Comm_size(comm, &nRanks);
+  MPI_Comm_size(comm, &n_ranks);
   vtkMPIController *controller;
 
   // seperate init step required, since VORONOI does not
@@ -278,7 +272,7 @@ void print_points(std::vector<ALL::Point<double>> plist, int step,
 
   auto parallel_writer = vtkSmartPointer<vtkXMLPPolyDataWriter>::New();
   parallel_writer->SetFileName(ss_para.str().c_str());
-  parallel_writer->SetNumberOfPieces(nRanks);
+  parallel_writer->SetNumberOfPieces(n_ranks);
   parallel_writer->SetStartPiece(rank);
   parallel_writer->SetEndPiece(rank);
   parallel_writer->SetInputData(polydata);
@@ -382,7 +376,7 @@ int main(int argc, char **argv) {
 
     // setup of cartesian communicator
     int localRank;
-    int nRanks;
+    int n_ranks;
     MPI_Comm cart_comm;
     int local_coords[sys_dim];
     int periodicity[sys_dim];
@@ -400,11 +394,11 @@ int main(int argc, char **argv) {
     }
 
     // get number of total ranks
-    MPI_Comm_size(MPI_COMM_WORLD, &nRanks);
+    MPI_Comm_size(MPI_COMM_WORLD, &n_ranks);
 
     if (global_dim[0] == 0) {
       // get distribution into number of dimensions
-      MPI_Dims_create(nRanks, sys_dim, global_dim);
+      MPI_Dims_create(n_ranks, sys_dim, global_dim);
     }
 
     // create cartesian MPI communicator
@@ -545,7 +539,7 @@ int main(int argc, char **argv) {
       generate_points(points, l, u, coords, global_dim, sys_dim, n_points,
                       localRank);
     } else {
-      read_points(points, n_points, filename, sys_dim, localRank,
+      read_points(points, l, u, n_points, filename, sys_dim, localRank,
                   cart_comm);
     }
     double *transfer;
@@ -587,7 +581,7 @@ int main(int argc, char **argv) {
     }
     double n_total;
     MPI_Allreduce(&n_local, &n_total, 1, MPI_DOUBLE, MPI_SUM, cart_comm);
-    double avg_work = (double)n_total / (double)nRanks;
+    double avg_work = (double)n_total / (double)n_ranks;
     double n_min, n_max;
     MPI_Allreduce(&n_local, &n_min, 1, MPI_DOUBLE, MPI_MIN, cart_comm);
     MPI_Allreduce(&n_local, &n_max, 1, MPI_DOUBLE, MPI_MAX, cart_comm);
@@ -601,10 +595,9 @@ int main(int argc, char **argv) {
     // get starting number of particles
     MPI_Allreduce(&n_local, &total_points, 1, MPI_DOUBLE, MPI_SUM, cart_comm);
 
-
     // output of borders / contents
-    if (nRanks < 216) {
-      for (int i = 0; i < nRanks; ++i) {
+    if (n_ranks < 216) {
+      for (int i = 0; i < n_ranks; ++i) {
         if (localRank == i) {
           std::ofstream of;
           of.open("domain_data.dat", std::ios::out | std::ios::app);
@@ -617,7 +610,7 @@ int main(int argc, char **argv) {
           of << " " << n_local << " ";
 
           of << std::endl;
-          if (i == nRanks - 1)
+          if (i == n_ranks - 1)
             of << std::endl;
           of.close();
           MPI_Barrier(cart_comm);
@@ -630,7 +623,7 @@ int main(int argc, char **argv) {
     if (chosen_method == ALL::LB_t::VORONOI) {
       // one-time particle output to voronoi/particles.pov
 
-      for (int i = 0; i < nRanks; ++i) {
+      for (int i = 0; i < n_ranks; ++i) {
         if (localRank == i) {
           std::ofstream of;
           if (i != 0)
@@ -648,6 +641,7 @@ int main(int argc, char **argv) {
         MPI_Barrier(cart_comm);
       }
 
+      int minpoints = 0;
       double cow_sys[sys_dim + 1];
       double target_point[sys_dim + 1];
 
@@ -671,11 +665,11 @@ int main(int argc, char **argv) {
       // experimental: as a first step try to find more optimal start points
       for (int i_preloop = 0; i_preloop < ALL_VORONOI_PREP_STEPS; ++i_preloop) {
         // get neighbor information
-        int nNeighbors = nRanks;
-        std::vector<double> neighbor_vertices(sys_dim * nRanks);
-        std::vector<int> neighbors(nRanks);
+        int nNeighbors = n_ranks;
+        std::vector<double> neighbor_vertices(sys_dim * n_ranks);
+        std::vector<int> neighbors(n_ranks);
 
-        for (int n = 0; n < nRanks; ++n)
+        for (int n = 0; n < n_ranks; ++n)
           neighbors.at(n) = n;
 
         double local_vertex[sys_dim];
@@ -858,7 +852,7 @@ int main(int argc, char **argv) {
          */
 
         // output of borders / contents
-        for (int i = 0; i < nRanks; ++i) {
+        for (int i = 0; i < n_ranks; ++i) {
           if (localRank == i) {
             std::ofstream of;
             if (!weighted_points)
@@ -870,7 +864,7 @@ int main(int argc, char **argv) {
             of << " " << vertices.at(0)[0] << " " << vertices.at(0)[1] << " "
                << vertices.at(0)[2] << " " << n_points << std::endl;
 
-            if (i == nRanks - 1)
+            if (i == n_ranks - 1)
               of << std::endl;
             of.close();
             MPI_Barrier(cart_comm);
@@ -894,7 +888,7 @@ int main(int argc, char **argv) {
         gamma *= 2.0;
         limit_efficiency /= 2.0;
       }
-      if (localRank == 0 && (i_loop % OUTPUT_INTV == 0))
+      if (localRank == 0)
         std::cout << "loop " << i_loop << ": " << std::endl;
       std::vector<double> work;
       std::vector<int> n_bins(3, -1);
@@ -1019,32 +1013,6 @@ int main(int argc, char **argv) {
         }
 #endif
       }
-
-      // computing center of points for ForceBased method
-      if (chosen_method == ALL::LB_t::FORCEBASED)
-      {
-        ALL::Point<double> cow(3);
-        for (int i = 0; i < 3; ++i)
-          cow[i] = 0.0;
-        if (points.size() > 0)
-        {
-          for (auto p : points)
-          {
-            cow = cow + p;
-          }
-          cow = cow * (1.0 / (double)points.size());
-        }
-        else
-        {
-          for (auto v : vertices)
-          {
-            cow = cow + v;
-          }
-          cow = cow * (1.0 / (double)vertices.size());
-        }
-        lb_obj.setMethodData(&cow);
-      }
-
 #ifdef ALL_DEBUG_ENABLED
       MPI_Barrier(cart_comm);
       if (localRank == 0)
@@ -1101,11 +1069,9 @@ int main(int argc, char **argv) {
       int n_points_global = 0;
       MPI_Reduce(&n_points, &n_points_global, 1, MPI_INT, MPI_SUM, 0,
                  MPI_COMM_WORLD);
-#ifdef ALL_DEBUG_ENABLED      
       if (localRank == 0)
         std::cout << "number of particles in step " << i_loop << ": "
                   << n_points_global << std::endl;
-#endif      
 #ifdef ALL_DEBUG_ENABLED
       MPI_Barrier(cart_comm);
       if (localRank == 0)
@@ -1188,6 +1154,7 @@ int main(int argc, char **argv) {
               P--;
             }
           }
+          MPI_Status status;
 
           MPI_Request sreq_r, rreq_r;
           MPI_Request sreq_l, rreq_l;
@@ -1552,59 +1519,224 @@ int main(int argc, char **argv) {
         }
 
 #ifdef ALL_VTK_OUTPUT
+#ifdef ALL_VTK_FORCE_SORT
+        // creating an unstructured grid for local domain and neighbors
+        auto vtkpoints = vtkSmartPointer<vtkPoints>::New();
+        auto unstructuredGrid = vtkSmartPointer<vtkForceBasedGrid>::New();
+        for (int i = 0; i < 27; ++i) {
+          for (int v = 0; v < 8; ++v) {
+            vtkpoints->InsertNextPoint(comm_vertices[i * 24 + v * 3],
+                                       comm_vertices[i * 24 + v * 3 + 1],
+                                       comm_vertices[i * 24 + v * 3 + 2]);
+          }
+        }
+        unstructuredGrid->SetPoints(vtkpoints);
 
-        // VTK data sets describing the convex hulls of each domain
-        vtkSmartPointer<vtkSelectEnclosedPoints> dataSets[27];
+        auto work = vtkSmartPointer<vtkFloatArray>::New();
+        work->SetNumberOfComponents(1);
+        work->SetNumberOfTuples(27);
+        work->SetName("Cell");
 
-        // vtkPoints list containing all points
-        vtkSmartPointer<vtkPoints> allPoints = vtkSmartPointer<vtkPoints>::New();
+        for (int n = 0; n < 27; ++n) {
+          // define grid points, i.e. vertices of local domain
+          vtkIdType pointIds[8] = {8 * n + 0, 8 * n + 1, 8 * n + 2, 8 * n + 3,
+                                   8 * n + 4, 8 * n + 5, 8 * n + 6, 8 * n + 7};
+
+          auto faces = vtkSmartPointer<vtkCellArray>::New();
+          // setup faces of polyhedron
+          vtkIdType f0[3] = {8 * n + 0, 8 * n + 2, 8 * n + 1};
+          vtkIdType f1[3] = {8 * n + 1, 8 * n + 2, 8 * n + 3};
+
+          vtkIdType f2[3] = {8 * n + 0, 8 * n + 4, 8 * n + 2};
+          vtkIdType f3[3] = {8 * n + 2, 8 * n + 4, 8 * n + 6};
+
+          vtkIdType f4[3] = {8 * n + 2, 8 * n + 6, 8 * n + 3};
+          vtkIdType f5[3] = {8 * n + 3, 8 * n + 6, 8 * n + 7};
+
+          vtkIdType f6[3] = {8 * n + 1, 8 * n + 5, 8 * n + 3};
+          vtkIdType f7[3] = {8 * n + 3, 8 * n + 5, 8 * n + 7};
+
+          vtkIdType f8[3] = {8 * n + 0, 8 * n + 4, 8 * n + 1};
+          vtkIdType f9[3] = {8 * n + 1, 8 * n + 4, 8 * n + 5};
+
+          vtkIdType fa[3] = {8 * n + 4, 8 * n + 6, 8 * n + 5};
+          vtkIdType fb[3] = {8 * n + 5, 8 * n + 6, 8 * n + 7};
+
+          faces->InsertNextCell(3, f0);
+          faces->InsertNextCell(3, f1);
+          faces->InsertNextCell(3, f2);
+          faces->InsertNextCell(3, f3);
+          faces->InsertNextCell(3, f4);
+          faces->InsertNextCell(3, f5);
+          faces->InsertNextCell(3, f6);
+          faces->InsertNextCell(3, f7);
+          faces->InsertNextCell(3, f8);
+          faces->InsertNextCell(3, f9);
+          faces->InsertNextCell(3, fa);
+          faces->InsertNextCell(3, fb);
+
+          unstructuredGrid->InsertNextCell(VTK_POLYHEDRON, 8, pointIds, 12,
+                                           faces->GetPointer());
+          work->SetValue(n, (double)n);
+        }
+        unstructuredGrid->GetCellData()->AddArray(work);
 
-        // fill list
-        for (auto P = points.begin(); P != points.end(); ++P)
-        {
-          ALL::Point<double> p = *P;
-          double pcoords[3];
-          for (int d = 0; d < 3; ++d)
-            pcoords[d] = p[d];
-          allPoints->InsertNextPoint(pcoords);
+        /* Debug output: print local cell and neighbors */
+        /*
+           if (localRank == 26)
+           {
+           auto writer = vtkSmartPointer<vtkXMLForceBasedGridWriter>::New();
+           writer->SetInputData(unstructuredGrid);
+           writer->SetFileName("test.vtu");
+           writer->SetDataModeToAscii();
+        //writer->SetDataModeToBinary();
+        writer->Write();
         }
+         */
 
-        vtkSmartPointer<vtkPolyData> allData = vtkSmartPointer<vtkPolyData>::New();
-        allData->SetPoints(allPoints);
+        MPI_Allreduce(&n_points, &check_np, 1, MPI_INT, MPI_SUM, cart_comm);
 
-        // create each data set
-        for (int i = 0; i < 27; ++i)
-        {
-          // create points structure
-          vtkSmartPointer<vtkPoints> domVertices = vtkSmartPointer<vtkPoints>::New();
-          // fill point structure with data
-          for (int j = 0; j < 8; ++j)
+        auto locator = vtkSmartPointer<vtkCellLocator>::New();
+        locator->SetDataSet(unstructuredGrid);
+        locator->BuildLocator();
+#else
+        auto in_triangle = [=](double nv[27 * 8 * 3], int n, int A, int B,
+                               int C, double x, double y, double z) {
+          ALL::Point<double> a(3, nv + n * 24 + A * 8);
+          ALL::Point<double> b(3, nv + n * 24 + B * 8);
+          ALL::Point<double> c(3, nv + n * 24 + C * 8 + 0);
+          ALL::Point<double> p(3);
+          p[0] = x;
+          p[1] = y;
+          p[2] = z;
+          return false;
+          // return p.same_side_line(a, b, c) &&
+          //       p.same_side_line(b, c, a) &&
+          //       p.same_side_line(c, a, b);
+        };
+
+        // check if (x,y,z) is on the same side as A in
+        // regard to the plane spanned by (B,C,D)
+        auto same_side_plane = [=](double nv[27 * 8 * 3], int n, int A, int B,
+                                   int C, int D, double x, double y, double z) {
+          // compute difference vectors required:
+          // C-B
+          double CBx = nv[n * 24 + C * 8 + 0] - nv[n * 24 + B * 8 + 0];
+          double CBy = nv[n * 24 + C * 8 + 1] - nv[n * 24 + B * 8 + 1];
+          double CBz = nv[n * 24 + C * 8 + 2] - nv[n * 24 + B * 8 + 2];
+          // D-B
+          double DBx = nv[n * 24 + D * 8 + 0] - nv[n * 24 + B * 8 + 0];
+          double DBy = nv[n * 24 + D * 8 + 1] - nv[n * 24 + B * 8 + 1];
+          double DBz = nv[n * 24 + D * 8 + 2] - nv[n * 24 + B * 8 + 2];
+          // A-B
+          double ABx = nv[n * 24 + A * 8 + 0] - nv[n * 24 + B * 8 + 0];
+          double ABy = nv[n * 24 + A * 8 + 1] - nv[n * 24 + B * 8 + 1];
+          double ABz = nv[n * 24 + A * 8 + 2] - nv[n * 24 + B * 8 + 2];
+          // P-B
+          double PBx = x - nv[n * 24 + B * 8 + 0];
+          double PBy = y - nv[n * 24 + B * 8 + 1];
+          double PBz = z - nv[n * 24 + B * 8 + 2];
+
+          // compute normal vector of plane
+          // n = (C-B) x (D-B)
+          double nx = CBy * DBz - CBz * DBy;
+          double ny = CBz * DBx - CBx * DBz;
+          double nz = CBx * DBy - CBy * DBx;
+
+          // compute dot product <A-B,n>
+          double ABn = ABx * nx + ABy * ny + ABz * nz;
+
+          // compute dot product <P-B,n>
+          double PBn = PBx * nx + PBy * ny + PBz * nz;
+
+          /*
+          if (
+                  x <= TRACK_X+0.01 && x >= TRACK_X &&
+                  y <= TRACK_Y+0.01 && y >= TRACK_Y &&
+                  z <= TRACK_Z+0.01 && z >= TRACK_Z
+             )
           {
-            domVertices->InsertNextPoint(comm_vertices + 24 * i + 3 * j);
+              std::cerr << "Tracked Particle: "
+                        << localRank << " ( "
+                        << x << " , "
+                        << y << " , "
+                        << z << " ) "
+                        << n << ", "
+                        << A << ", "
+                        << B << ", "
+                        << C << ", "
+                        << D << ") "
+                        << ( ( std::abs(PBn) < 1e-12 ) && in_triangle(nv, n, B,
+          C, D, x, y, z) ) << " "
+                        << ( ( ABn * PBn ) > 0 )
+                        << std::endl;
           }
-          vtkSmartPointer<vtkPolyData> domSet = vtkSmartPointer<vtkPolyData>::New();
-          domSet->SetPoints(domVertices);
-          // create a Delaunay grid of the points
-          vtkNew<vtkDelaunay3D> delaunay;
-          delaunay->SetInputData(domSet);
-          delaunay->Update();
-          // extract the surface
-          vtkSmartPointer<vtkDataSetSurfaceFilter> surfaceFilter = vtkSmartPointer<vtkDataSetSurfaceFilter>::New();
-          surfaceFilter->SetInputConnection(delaunay->GetOutputPort());
-          vtkDataSetSurfaceFilter::SetGlobalWarningDisplay(0);
-          surfaceFilter->Update();
-          // create vtkPolyData
-          vtkSmartPointer<vtkPolyData> surfaceData = vtkSmartPointer<vtkPolyData>::New();
-          surfaceData = surfaceFilter->GetOutput();
-
-          dataSets[i] = vtkSmartPointer<vtkSelectEnclosedPoints>::New();
-          dataSets[i]->SetInputData(allData);
-          dataSets[i]->SetSurfaceData(surfaceData);
-          dataSets[i]->Update();
-        }
+          */
 
+          if (std::abs(PBn) < 1e-12) {
+            return in_triangle(nv, n, B, C, D, x, y, z);
+          } else
+            // check if sign matches
+            return ((ABn * PBn) > 0);
+        };
+
+        // check if (x,y,z) is in the tetrahedron described
+        // by the points within the array t
+        auto in_tetraeder = [=](double nv[27 * 8 * 3], int n, int t[4],
+                                double x, double y, double z) {
+          return same_side_plane(nv, n, t[0], t[1], t[2], t[3], x, y, z) &&
+                 same_side_plane(nv, n, t[3], t[0], t[1], t[2], x, y, z) &&
+                 same_side_plane(nv, n, t[2], t[3], t[0], t[1], x, y, z) &&
+                 same_side_plane(nv, n, t[1], t[2], t[3], t[0], x, y, z);
+        };
+
+        auto containing_neighbor = [=](double nv[27 * 8 * 3], bool e[27],
+                                       double x, double y, double z) {
+          // definition of tetrahedra
+          /*
+          int tetrahedra[8][4] =
+          {
+              { 0, 1, 2, 4 },
+              { 1, 0, 3, 5 },
+              { 2, 3, 0, 6 },
+              { 3, 2, 1, 7 },
+              { 4, 5, 6, 0 },
+              { 5, 4, 7, 1 },
+              { 6, 7, 4, 2 },
+              { 7, 6, 5, 3 }
+          };
+          */
+          int tetrahedra[5][4] = {{1, 2, 4, 7},
+                                  {0, 1, 2, 4},
+                                  {1, 2, 3, 7},
+                                  {1, 4, 5, 7},
+                                  {2, 4, 6, 7}};
+
+          // check if the particle is in the local
+          // tetrahedron, if so keep it
+          // (needed to decide what about particles
+          //  on surfaces)
+          for (int t = 0; t < 5; ++t) {
+            if (in_tetraeder(nv, 13, tetrahedra[t], x, y, z))
+              return 13;
+          }
+
+          // loop over neighbors
+          for (int n = 0; n < 27; ++n) {
+            if (!e[n] || n == 13)
+              continue;
+            // loop over tetrahedra
+            for (int t = 0; t < 5; ++t) {
+              if (in_tetraeder(nv, n, tetrahedra[t], x, y, z))
+                return n;
+            }
+          }
+
+          return -1;
+        };
+
+#endif
 
-        // setup counters for send / recv actions
         int n_send[27];
         int n_recv[27];
         for (int i = 0; i < 27; ++i) {
@@ -1612,78 +1744,51 @@ int main(int argc, char **argv) {
           n_recv[i] = 0;
         }
 
-        std::vector<int> sendTarget(allPoints->GetNumberOfPoints(), -1);
-        
-        // check if point stays in local domain
-        for(int j = 0; j < allPoints->GetNumberOfPoints(); ++j)
-        {
-          if (dataSets[13]->IsInside(j))
-          {
-            n_send[13]++;
-            sendTarget.at(j) = 13;
-          }
-        }
+        int check_np = 0;
+        int check_np_new = 0;
 
-        // iterate over neighbors to fill send buffers
-        for(int i = 0; i < 27; ++i)
-        {
-          if (i == 13) continue;
-          for(int j = 0; j < allPoints->GetNumberOfPoints(); ++j)
-          {
-            if (sendTarget.at(j) != -1) continue;
-            if (dataSets[i]->IsInside(j) )
-            {
-              sendTarget.at(j) = i;
-              if (n_send[i] == max_particles) {
-                throw ALL::InvalidArgumentException(
-                    __FILE__, __func__, __LINE__,
-                    "Trying to send more particles than \
-                                              buffer size allows!");
-              }
-              for (int d = 0; d < 3; ++d) {
-                transfer[i * (sys_dim + 1) * max_particles +
-                         n_send[i] * (sys_dim + 1) + d] = points.at(j)[d];
-              }
-              transfer[i * (sys_dim + 1) * max_particles +
-                       n_send[i] * (sys_dim + 1) + sys_dim] = points.at(j).getWeight();
-              n_send[i]++;
-            }
-          }
-        }
-
-        for(int i = allPoints->GetNumberOfPoints() - 1; i >= 0; --i)
-        {
-          if (sendTarget.at(i) < 0)
-          {
-            sendTarget.at(i) = 13;
-          }
-          if (sendTarget.at(i) != 13)
-          {
-            points.erase(points.begin() + i);
-            n_points--;
-          }
-        }
+        for (auto P = points.begin(); P != points.end(); ++P) {
+          ALL::Point<double> p = *P;
+          double pcoords[3];
+          double pccoords[3];
+          int subId;
+          double weights[4];
+          for (int d = 0; d < 3; ++d)
+            pcoords[d] = p[d];
+#ifdef ALL_VTK_FORCE_SORT
+          vtkIdType cellId = locator->FindCell(pcoords);
+#else
+          int cellId = containing_neighbor(comm_vertices, exists, pcoords[0],
+                                           pcoords[1], pcoords[2]);
+#endif
+          /*
+                                              vtkIdType cellId =
+             unstructuredGrid->FindCell(pcoords, NULL, 0, 1e-6, subId, pccoords,
+                                              weights);
+          */
 
-        /*
-        for (int j = 0; j < nRanks; ++j)
-        {
-          if (localRank == j)
-          {
-            int cnt = 0;
-            for (int i = 0; i < 27; ++i)
-            {
-              std::cout << n_send[i] << " ";
-              cnt += n_send[i];
+          // if the particle is in a valid neighboring cell
+          if (cellId >= 0 && cellId != 13 && cellId <= 26) {
+            if (n_send[cellId] == max_particles) {
+              throw ALL::InvalidArgumentException(
+                  __FILE__, __func__, __LINE__,
+                  "Trying to send more particles than \
+                                            buffer size allows!");
             }
-            MPI_Barrier(cart_comm);
+            for (int d = 0; d < 3; ++d) {
+              transfer[cellId * (sys_dim + 1) * max_particles +
+                       n_send[cellId] * (sys_dim + 1) + d] = p[d];
+            }
+            transfer[cellId * (sys_dim + 1) * max_particles +
+                     n_send[cellId] * (sys_dim + 1) + sys_dim] = p.getWeight();
+            n_send[cellId]++;
+            points.erase(P);
+            --P;
+            --n_points;
           }
-          else
-            MPI_Barrier(cart_comm);
         }
-        */
 
         for (int n = 0; n < 27; ++n) {
-          if (n == 13) continue;
           MPI_Isend(&n_send[n], 1, MPI_INT, neighbors.at(n), 1020, cart_comm,
                     &request[n]);
           MPI_Irecv(&n_recv[n], 1, MPI_INT, neighbors.at(n), 1020, cart_comm,
@@ -1692,7 +1797,6 @@ int main(int argc, char **argv) {
         MPI_Waitall(54, request, status);
 
         for (int n = 0; n < 27; ++n) {
-          if (n == 13) continue;
           MPI_Isend(&transfer[n * (sys_dim + 1) * max_particles],
                     (sys_dim + 1) * n_send[n], MPI_DOUBLE, neighbors.at(n),
                     1030, cart_comm, &request[n]);
@@ -1703,7 +1807,6 @@ int main(int argc, char **argv) {
         MPI_Waitall(54, request, status);
 
         for (int n = 0; n < 27; ++n) {
-          if (n == 13) continue;
           if (exists[n]) {
             for (int i = 0; i < n_recv[n]; ++i) {
               ALL::Point<double> p(
@@ -1713,17 +1816,10 @@ int main(int argc, char **argv) {
                        sys_dim]);
               points.push_back(p);
               ++n_points;
-              if (p[0] < 0.5 && p[1] < 0.5 && p[2] < 0.5)
-              {
-                std::cout << "Particle put to zero: " << localRank << " " << neighbors.at(n) << " " << i << " " << p << std::endl;
-              }
             }
           }
         }
 
-        //MPI_Barrier(cart_comm);
-        //MPI_Abort(cart_comm,-999);
-
 #else
         if (localRank == 0)
           std::cout << "Currently no FORCEBASED test without VTK!"
@@ -2086,10 +2182,8 @@ int main(int argc, char **argv) {
 #ifdef ALL_VTK_OUTPUT
           // if (localRank == 0)
           //  std::cout << "creating vtk outlines output" << std::endl;
-          /*
           if (chosen_method != ALL::LB_t::FORCEBASED)
             lb_obj.printVTKoutlines(output_step);
-          */
           // if (localRank == 0)
           //  std::cout << "creating vtk vertices output" << std::endl;
           if (chosen_method == ALL::LB_t::FORCEBASED)
@@ -2123,7 +2217,7 @@ int main(int argc, char **argv) {
                     cart_comm);
 
       MPI_Allreduce(&n_local, &n_total, 1, MPI_DOUBLE, MPI_SUM, cart_comm);
-      avg_work = n_total / (double)nRanks;
+      avg_work = n_total / (double)n_ranks;
       MPI_Allreduce(&n_local, &n_min, 1, MPI_DOUBLE, MPI_MIN, cart_comm);
       MPI_Allreduce(&n_local, &n_max, 1, MPI_DOUBLE, MPI_MAX, cart_comm);
       d_min = n_min / avg_work;
@@ -2170,7 +2264,7 @@ int main(int argc, char **argv) {
       }
 
       // output of borders / contents
-      for (int i = 0; i < nRanks; ++i) {
+      for (int i = 0; i < n_ranks; ++i) {
         if (localRank == i) {
           std::ofstream of;
           if (!weighted_points)
@@ -2200,7 +2294,7 @@ int main(int argc, char **argv) {
           of << n_local << " ";
 
           of << std::endl;
-          if (i == nRanks - 1)
+          if (i == n_ranks - 1)
             of << std::endl;
           of.close();
           MPI_Barrier(cart_comm);
@@ -2240,7 +2334,7 @@ int main(int argc, char **argv) {
                              2.0 * width[1] * width[2] +
                              2.0 * width[0] * width[2];
 
-            for (int r = 0; r < nRanks; ++r) {
+            for (int r = 0; r < n_ranks; ++r) {
               if (r == localRank) {
                 std::ofstream of;
                 of.open("domain_data.dat", std::ios::out | std::ios::app);
@@ -2252,6 +2346,8 @@ int main(int argc, char **argv) {
               MPI_Barrier(cart_comm);
             }
           }
+          if (chosen_method == ALL::LB_t::FORCEBASED)
+            lb_obj.printVTKvertices(output_step);
 #endif
           output_step++;
 #ifdef ALL_VORONOI_ACTIVE
@@ -2266,12 +2362,12 @@ int main(int argc, char **argv) {
 
     // create binary output (e.g. for HimeLB)
     // format:
-    // nRanks integers (number of particles per domain)
-    // nRanks integers (offset in file)
+    // n_ranks integers (number of particles per domain)
+    // n_ranks integers (offset in file)
     // <n_particles[0]> * (long + 3*double) double values (particle positions)
 
     // open file
-    int known_unused err;
+    int err;
     MPI_File outfile;
     err =
         MPI_File_open(MPI_COMM_WORLD, "particles_output.bin",
@@ -2291,20 +2387,18 @@ int main(int argc, char **argv) {
     long block_size = sizeof(long); // + 3 * sizeof(double);
 
     offset *= block_size;
-    offset += nRanks * (sizeof(int) + sizeof(long));
+    offset += n_ranks * (sizeof(int) + sizeof(long));
 
     MPI_File_write_at(
-        outfile, (MPI_Offset)(nRanks * sizeof(int) + localRank * sizeof(long)),
+        outfile, (MPI_Offset)(n_ranks * sizeof(int) + localRank * sizeof(long)),
         &offset, 1, MPI_LONG, MPI_STATUS_IGNORE);
 
     for (int n = 0; n < n_points; ++n) {
       long id = points.at(n).get_id();
-      /*
       double loc[3];
       loc[0] = points.at(n)[0];
       loc[1] = points.at(n)[1];
       loc[2] = points.at(n)[2];
-      */
       MPI_File_write_at(outfile, (MPI_Offset)((offset + n * block_size)), &id,
                         1, MPI_LONG, MPI_STATUS_IGNORE);
       /*
@@ -2321,9 +2415,9 @@ int main(int argc, char **argv) {
 
     /*
     // output of borders / contents
-    if (nRanks < 216)
+    if (n_ranks < 216)
     {
-    for (int i = 0; i < nRanks; ++i)
+    for (int i = 0; i < n_ranks; ++i)
     {
     if (localRank == i)
     {
@@ -2344,7 +2438,7 @@ int main(int argc, char **argv) {
     of << " " << n_local << " ";
 
     of << std::endl;
-    if (i == nRanks - 1) of << std::endl;
+    if (i == n_ranks - 1) of << std::endl;
     of.close();
     MPI_Barrier(cart_comm);
     }
@@ -2366,7 +2460,6 @@ int main(int argc, char **argv) {
     MPI_Finalize();
     return EXIT_SUCCESS;
   } catch (ALL::CustomException &e) {
-    std::cout << "Error caught: " << std::endl;
     std::cout << e.what() << std::endl;
     return EXIT_FAILURE;
   }
diff --git a/include/ALL.hpp b/include/ALL.hpp
index b6dbcc2..f0c685e 100644
--- a/include/ALL.hpp
+++ b/include/ALL.hpp
@@ -66,9 +66,6 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #include <vtkVoxel.h>
 #include <vtkXMLPUnstructuredGridWriter.h>
 #include <vtkXMLUnstructuredGridWriter.h>
-#ifdef VTK_CELL_ARRAY_V2
-#include <vtkNew.h>
-#endif
 #endif
 
 namespace ALL {
@@ -545,19 +542,14 @@ public:
     // define rank array (length = 4, x,y,z, rank)
     int rank = 0;
     MPI_Comm_rank(communicator, &rank);
-    auto coords = vtkSmartPointer<vtkFloatArray>::New();
-    coords->SetNumberOfComponents(3);
-    coords->SetNumberOfTuples(1);
-    coords->SetName("coords");
-    coords->SetValue(0, procGridLoc.at(0));
-    coords->SetValue(1, procGridLoc.at(1));
-    coords->SetValue(2, procGridLoc.at(2));
-
     auto rnk = vtkSmartPointer<vtkFloatArray>::New();
-    rnk->SetNumberOfComponents(1);
+    rnk->SetNumberOfComponents(4);
     rnk->SetNumberOfTuples(1);
-    rnk->SetName("MPI rank");
-    rnk->SetValue(0, rank);
+    rnk->SetName("rank");
+    rnk->SetValue(0, procGridLoc.at(0));
+    rnk->SetValue(1, procGridLoc.at(1));
+    rnk->SetValue(2, procGridLoc.at(2));
+    rnk->SetValue(3, rank);
 
     // define tag array (length = 1)
     auto tag = vtkSmartPointer<vtkIntArray>::New();
@@ -613,7 +605,6 @@ public:
     unstructuredGrid->GetCellData()->AddArray(work);
     unstructuredGrid->GetCellData()->AddArray(expanse);
     unstructuredGrid->GetCellData()->AddArray(rnk);
-    unstructuredGrid->GetCellData()->AddArray(coords);
     unstructuredGrid->GetCellData()->AddArray(tag);
 
     createDirectory("vtk_outline");
@@ -671,14 +662,11 @@ public:
     local_vertices[nVertices * balancer->getDimension()] =
         (T)balancer->getWork().at(0);
 
-    /*
     T *global_vertices;
     if (local_rank == 0) {
       global_vertices =
           new T[n_ranks * (nVertices * balancer->getDimension() + 1)];
     }
-    */
-    T global_vertices[n_ranks * (nVertices * balancer->getDimension() + 1)];
 
     // collect all works and vertices on a single process
     MPI_Gather(local_vertices, nVertices * balancer->getDimension() + 1,
@@ -688,12 +676,7 @@ public:
 
     if (local_rank == 0) {
       auto vtkpoints = vtkSmartPointer<vtkPoints>::New();
-#ifdef VTK_CELL_ARRAY_V2
-      vtkNew<vtkUnstructuredGrid> unstructuredGrid;
-      unstructuredGrid->Allocate(n_ranks + 10);
-#else
       auto unstructuredGrid = vtkSmartPointer<vtkUnstructuredGrid>::New();
-#endif
 
       // enter vertices into unstructured grid
       for (int i = 0; i < n_ranks; ++i) {
@@ -719,29 +702,7 @@ public:
       cell->SetNumberOfTuples(n_ranks);
       cell->SetName("cell id");
 
-
       for (int n = 0; n < n_ranks; ++n) {
-
-#ifdef VTK_CELL_ARRAY_V2
-        // define grid points, i.e. vertices of local domain
-        vtkIdType pointIds[8] = {8 * n + 0, 8 * n + 1, 8 * n + 2, 8 * n + 3,
-                                 8 * n + 4, 8 * n + 5, 8 * n + 6, 8 * n + 7};
-      
-        vtkIdType faces[48] = { 3, 8 * n + 0, 8 * n + 2, 8 * n + 1,
-                                3, 8 * n + 1, 8 * n + 2, 8 * n + 3, 
-                                3, 8 * n + 0, 8 * n + 4, 8 * n + 2, 
-                                3, 8 * n + 2, 8 * n + 4, 8 * n + 6, 
-                                3, 8 * n + 2, 8 * n + 6, 8 * n + 3, 
-                                3, 8 * n + 3, 8 * n + 6, 8 * n + 7, 
-                                3, 8 * n + 1, 8 * n + 5, 8 * n + 3, 
-                                3, 8 * n + 3, 8 * n + 5, 8 * n + 7, 
-                                3, 8 * n + 0, 8 * n + 4, 8 * n + 1, 
-                                3, 8 * n + 1, 8 * n + 4, 8 * n + 5, 
-                                3, 8 * n + 4, 8 * n + 6, 8 * n + 5, 
-                                3, 8 * n + 5, 8 * n + 6, 8 * n + 7};
-
-        unstructuredGrid->InsertNextCell(VTK_POLYHEDRON, 8, pointIds, 12, faces);
-#else
         // define grid points, i.e. vertices of local domain
         vtkIdType pointIds[8] = {8 * n + 0, 8 * n + 1, 8 * n + 2, 8 * n + 3,
                                  8 * n + 4, 8 * n + 5, 8 * n + 6, 8 * n + 7};
@@ -781,7 +742,6 @@ public:
 
         unstructuredGrid->InsertNextCell(VTK_POLYHEDRON, 8, pointIds, 12,
                                          faces->GetPointer());
-#endif
         work->SetValue(
             n,
             global_vertices[n * (nVertices * balancer->getDimension() + 1) +
@@ -802,7 +762,7 @@ public:
       // writer->SetDataModeToBinary();
       writer->Write();
 
-      //delete[] global_vertices;
+      delete[] global_vertices;
     }
   }
 #endif
diff --git a/include/ALL_ForceBased.hpp b/include/ALL_ForceBased.hpp
index 890b092..6ebbc3b 100644
--- a/include/ALL_ForceBased.hpp
+++ b/include/ALL_ForceBased.hpp
@@ -59,8 +59,6 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
                 gamma:  correction factor to control the
                         speed of the vertex shift
 
-    Information:
-        Currently only implemented for 3D systems
         */
 
 #include "ALL_CustomExceptions.hpp"
@@ -116,7 +114,7 @@ public:
 
   /// method to set specific data structures (unused for tensor grid method)
   /// @param[in] data pointer to the data structure
-  virtual void setAdditionalData(known_unused const void *data) override;
+  virtual void setAdditionalData(known_unused const void *data) override {}
 
 private:
   // type for MPI communication
@@ -140,11 +138,11 @@ private:
   // number of vertices (since it is not equal for all domains)
   int n_vertices;
 
+  // main dimension (correction in staggered style)
+  int main_dim;
+
   // secondary dimensions (2D force shift)
   int sec_dim[2];
-
-  // center of particles
-  Point<T> cop;
 };
 
 template <class T, class W> ForceBased_LB<T, W>::~ForceBased_LB() {}
@@ -155,16 +153,9 @@ std::vector<int> &ForceBased_LB<T, W>::getNeighbors() {
   return neighbors;
 }
 
-template <class T, class W> void ForceBased_LB<T, W>::setAdditionalData(void const *data)
-{
-    cop.setDimension(3);
-    cop[0] = (*((Point<T>*)data))[0];
-    cop[1] = (*((Point<T>*)data))[1];
-    cop[2] = (*((Point<T>*)data))[2];
-}
-
 template <class T, class W> void ForceBased_LB<T, W>::setup() {
   n_vertices = this->vertices.size();
+  vertex_neighbors.resize(n_vertices * 8);
 
   // determine correct MPI data type for template T
   if (std::is_same<T, double>::value)
@@ -217,19 +208,46 @@ template <class T, class W> void ForceBased_LB<T, W>::setup() {
   // get the local rank from the MPI communicator
   MPI_Cart_rank(this->globalComm, this->local_coords.data(), &this->localRank);
 
-  // find list of neighbors (does not change)
-  neighbors.resize(27);
-  int n = 0;
-  for (int z = -1; z <= 1; ++z)
-     for (int y = -1; y <= 1; ++y)
-        for (int x = -1; x <= 1; ++x)
-        {
-            std::vector<int> coords({this->local_coords.at(0)+x,
-                                     this->local_coords.at(1)+y,
-                                     this->local_coords.at(2)+z }); 
-            MPI_Cart_rank(this->globalComm, coords.data(), &neighbors.at(n));
-            ++n;
-        } 
+  // groups required for new communicators
+  MPI_Group known_unused  groups[n_vertices];
+  // arrays of processes belonging to group
+  int known_unused processes[n_vertices][n_vertices];
+
+  // shifted local coordinates to find neighboring processes
+  int known_unused shifted_coords[this->dimension];
+
+  // get main and secondary dimensions
+  if (this->global_dims.at(2) >= this->global_dims.at(1)) {
+    if (this->global_dims.at(2) >= this->global_dims.at(0)) {
+      main_dim = 2;
+      sec_dim[0] = 0;
+      sec_dim[1] = 1;
+    } else {
+      main_dim = 0;
+      sec_dim[0] = 1;
+      sec_dim[1] = 2;
+    }
+  } else {
+    if (this->global_dims.at(1) >= this->global_dims.at(0)) {
+      main_dim = 1;
+      sec_dim[0] = 0;
+      sec_dim[1] = 2;
+    } else {
+      main_dim = 0;
+      sec_dim[0] = 1;
+      sec_dim[1] = 2;
+    }
+  }
+
+  if (this->localRank == 0)
+    std::cout << "DEBUG: main_dim: " << main_dim << std::endl;
+
+  // create main communicator
+  MPI_Comm_split(this->globalComm, this->local_coords.at(main_dim),
+                 this->local_coords.at(sec_dim[0]) +
+                     this->local_coords.at(sec_dim[1]) *
+                         this->global_dims.at(sec_dim[0]),
+                 &main_communicator);
 
   /*
 
@@ -255,7 +273,108 @@ template <class T, class W> void ForceBased_LB<T, W>::setup() {
       // 0 - - - 1
 
   */
-  
+
+#ifdef ALL_DEBUG_ENABLED
+  MPI_Barrier(this->globalComm);
+  if (this->localRank == 0)
+    std::cout << "ALL::ForceBased_LB<T,W>::setup() preparing communicators..."
+              << std::endl;
+#endif
+  std::vector<int> dim_vert(this->global_dims);
+
+  MPI_Comm known_unused tmp_comm;
+  int known_unused own_vertex;
+
+#ifdef ALL_DEBUG_ENABLED
+  MPI_Barrier(this->globalComm);
+  if (this->localRank == 0)
+    std::cout << "ALL::ForceBased_LB<T,W>::setup() computing communicators..."
+              << std::endl;
+  std::cout << "DEBUG: "
+            << " rank: " << this->localRank << " dim_vert: " << dim_vert.at(0)
+            << " " << dim_vert.at(1) << " " << dim_vert.at(2)
+            << " size(local_coords): " << this->local_coords.size() << " "
+            << " size(global_dims):  " << this->global_dims.size() << std::endl;
+#endif
+  for (int iz = 0; iz < dim_vert.at(2); ++iz) {
+    for (int iy = 0; iy < dim_vert.at(1); ++iy) {
+      for (int ix = 0; ix < dim_vert.at(0); ++ix) {
+        bool affected[8];
+        for (auto &a : affected)
+          a = false;
+        int v_neighbors[8];
+        for (auto &vn : vertex_neighbors)
+          vn = -1;
+        if (ix == ((this->local_coords.at(0) + 0) % dim_vert.at(0)) &&
+            iy == ((this->local_coords.at(1) + 0) % dim_vert.at(1)) &&
+            iz == ((this->local_coords.at(2) + 0) % dim_vert.at(2))) {
+          affected[0] = true;
+          v_neighbors[0] = this->localRank;
+        }
+        if (ix == ((this->local_coords.at(0) + 1) % dim_vert.at(0)) &&
+            iy == ((this->local_coords.at(1) + 0) % dim_vert.at(1)) &&
+            iz == ((this->local_coords.at(2) + 0) % dim_vert.at(2))) {
+          affected[1] = true;
+          v_neighbors[1] = this->localRank;
+        }
+        if (ix == ((this->local_coords.at(0) + 0) % dim_vert.at(0)) &&
+            iy == ((this->local_coords.at(1) + 1) % dim_vert.at(1)) &&
+            iz == ((this->local_coords.at(2) + 0) % dim_vert.at(2))) {
+          affected[2] = true;
+          v_neighbors[2] = this->localRank;
+        }
+        if (ix == ((this->local_coords.at(0) + 1) % dim_vert.at(0)) &&
+            iy == ((this->local_coords.at(1) + 1) % dim_vert.at(1)) &&
+            iz == ((this->local_coords.at(2) + 0) % dim_vert.at(2))) {
+          affected[3] = true;
+          v_neighbors[3] = this->localRank;
+        }
+        if (ix == ((this->local_coords.at(0) + 0) % dim_vert.at(0)) &&
+            iy == ((this->local_coords.at(1) + 0) % dim_vert.at(1)) &&
+            iz == ((this->local_coords.at(2) + 1) % dim_vert.at(2))) {
+          affected[4] = true;
+          v_neighbors[4] = this->localRank;
+        }
+        if (ix == ((this->local_coords.at(0) + 1) % dim_vert.at(0)) &&
+            iy == ((this->local_coords.at(1) + 0) % dim_vert.at(1)) &&
+            iz == ((this->local_coords.at(2) + 1) % dim_vert.at(2))) {
+          affected[5] = true;
+          v_neighbors[5] = this->localRank;
+        }
+        if (ix == ((this->local_coords.at(0) + 0) % dim_vert.at(0)) &&
+            iy == ((this->local_coords.at(1) + 1) % dim_vert.at(1)) &&
+            iz == ((this->local_coords.at(2) + 1) % dim_vert.at(2))) {
+          affected[6] = true;
+          v_neighbors[6] = this->localRank;
+        }
+        if (ix == ((this->local_coords.at(0) + 1) % dim_vert.at(0)) &&
+            iy == ((this->local_coords.at(1) + 1) % dim_vert.at(1)) &&
+            iz == ((this->local_coords.at(2) + 1) % dim_vert.at(2))) {
+          affected[7] = true;
+          v_neighbors[7] = this->localRank;
+        }
+
+        MPI_Allreduce(MPI_IN_PLACE, v_neighbors, 8, MPI_INT, MPI_MAX,
+                      this->globalComm);
+
+        for (int v = 0; v < n_vertices; ++v) {
+          if (affected[v]) {
+            for (int n = 0; n < 8; ++n) {
+              vertex_neighbors.at(8 * v + n) = v_neighbors[n];
+            }
+          }
+        }
+      }
+    }
+// TODO: check whether the vertices are correct
+#ifdef ALL_DEBUG_ENABLED
+    MPI_Barrier(this->globalComm);
+    if (this->localRank == 0)
+      std::cout << "ALL::ForceBased_LB<T,W>::setup() finished computing "
+                   "communicators..."
+                << std::endl;
+#endif
+  }
 }
 
 // TODO: periodic boundary conditions (would require size of the system)
@@ -263,250 +382,410 @@ template <class T, class W> void ForceBased_LB<T, W>::setup() {
 
 template <class T, class W>
 void ForceBased_LB<T, W>::balance(int /*step*/) {
-   
-    // store current vertices
-    this->prevVertices = this->vertices;
 
-    // compute geometric center of domain
-    Point<T> geoCenter({0.0,0.0,0.0});
+  this->prevVertices = this->vertices;
 
-    for (auto v : this->vertices)
-    {
-        geoCenter = geoCenter + v;
-    }
-    geoCenter = geoCenter * (1.0 / (T)n_vertices); 
-
-    std::vector<T> locInfo(   { cop[0],
-                                cop[1],
-                                cop[2],
-                                (T)this->work.at(0)});
-    //std::vector<T> locInfo(   { geoCenter[0],
-    //                            geoCenter[1],
-    //                            geoCenter[2],
-    //                            (T)this->work.at(0)});
-    
-    // list of neighbors information is provided to and the shift is
-    // received from
-    std::vector<int> infoNeig({0,1,3,4,9,10,12,13});
-
-    // list of neighbors information is received from to compute
-    // local shift and shift is send to (mirrored to list above)
-    std::vector<int> shiftNeig({26,25,23,22,17,16,14,13});
-                                 
-    // vector to store data required to store necessary data  
-    std::vector<T> shiftData(32);
-
-    // vector to store shifts for each of the vertices
-    std::vector<T> shifts(24);
-
-    // vector to store MPI requests
-    std::vector<int> recvReq(8);
-    std::vector<int> sendReq(8);
-
-    // vector to store MPI status objects
-    std::vector<MPI_Status> recvStat(8);
-    std::vector<MPI_Status> sendStat(8);
-
-    // exchange information
-    for (int n = 0; n < 8; ++n)
-    {
-        // receives
-        MPI_Irecv(  shiftData.data()+4*n, 
-                    4, 
-                    MPIDataTypeT, 
-                    neighbors.at(shiftNeig.at(n)), 
-                    1012,
-                    this->globalComm,
-                    recvReq.data()+n );
-
-        // sends
-        MPI_Isend(  locInfo.data(), 
-                    4, 
-                    MPIDataTypeT, 
-                    neighbors.at(infoNeig.at(n)), 
-                    1012,
-                    this->globalComm,
-                    sendReq.data()+n );
-
-    }
-
-    MPI_Waitall(8,sendReq.data(),sendStat.data());
-    MPI_Waitall(8,recvReq.data(),recvStat.data());
-
-    // compute shift
-    Point<T> shift(3);
-    shift[0] = shift[1] = shift[2] = 0.0;
+  // geometrical center and work of each neighbor for each vertex
+  // work is cast to vertex data type, since in the end a shift
+  // of the vertex is computed
+  T vertex_info[n_vertices * n_vertices * (this->dimension + 2)];
+  T center[this->dimension];
 
-    auto cmp = [=](std::pair<T,int> a, std::pair<T,int> b)
-    {
-        return a.first > b.first;
-    };
+  // per-vertex local data: vector to geometric center, work, max. movement
+  T known_unused local_info[(this->dimension + 2) * n_vertices];
 
-/*    
+  for (int i = 0; i < n_vertices * n_vertices * (this->dimension + 2); ++i)
+    vertex_info[i] = -1;
 
-    // compute global influence on the shift
-    int nRanks;
-    MPI_Comm_size(this->globalComm, &nRanks);
-    std::vector<T[4]> globalData(nRanks);
+  for (int d = 0; d < (this->dimension + 2) * n_vertices; ++d)
+    local_info[d] = (T)0;
 
-    Point<T> globalTarget(3);
+  for (int d = 0; d < this->dimension; ++d)
+    center[d] = (T)0;
 
-    MPI_Allgather(locInfo.data(),4, MPIDataTypeT, globalData.data(), 4, MPIDataTypeT, this->globalComm);
-    std::vector<std::pair<T,int>> globalTargets(nRanks);
-
-    for (int i = 0; i < nRanks; ++i)
-    {
-        globalTargets.at(i).first = globalData.at(i)[3];
-        globalTargets.at(i).second = i;
-        globalTarget = globalTarget + Point<T>(3, globalData.at(i));
+  // compute geometric center
+  for (int v = 0; v < n_vertices; ++v) {
+    for (int d = 0; d < this->dimension; ++d) {
+      center[d] += ((this->prevVertices.at(v))[d] / (T)n_vertices);
     }
-    globalTarget = 1.0 / (T) nRanks * globalTarget;
-
-    std::sort(globalTargets.begin(), globalTargets.end(), cmp); 
-
-    int n = 0;
-    while (std::pow(2,n) <= 0.125 * nRanks)
-    {
-        for (int j = (int)std::pow(2,n)-1; j < (int)std::pow(2,n); ++j)
-            for (int i = 0; i < 3; ++i)
-                shift[i] += (0.25 / ( this->gamma * (T)std::pow(2,n) ) ) * (globalData.at(globalTargets.at(j).second)[i] - this->vertices.at(7)[i]) / 8.0;
-        ++n;
+    local_info[v * (this->dimension + 2) + this->dimension] =
+        (T)this->work.at(0);
+  }
+  for (int v = 0; v < n_vertices; ++v) {
+    for (int d = 0; d < this->dimension; ++d) {
+      // vector pointing to center
+      local_info[v * (this->dimension + 2) + d] =
+          center[d] - (this->prevVertices.at(v))[d];
     }
+  }
 
-    for (int i = 0; i < 3; ++i)
-        shift[i] += (0.25 / (this->gamma * nRanks)) * (globalTarget[i] - this->vertices.at(7)[i]);
-
-*/  
-
-    // compute the local influence on the shift
-
-    // compute average work around vertex
-    T workAvg;
-    workAvg = 0.0;
-    for (int n = 0; n < 8; ++n)
-        workAvg += shiftData.at(4*n+3) / 8.0;
+  // compute for each vertex the maximum movement
+  switch (main_dim) {
+  case 0:
+    local_info[0 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(0).d(this->prevVertices.at(2)),
+                       this->prevVertices.at(0).d(this->prevVertices.at(4)));
+    local_info[1 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(1).d(this->prevVertices.at(3)),
+                       this->prevVertices.at(1).d(this->prevVertices.at(5)));
+    local_info[2 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(2).d(this->prevVertices.at(0)),
+                       this->prevVertices.at(2).d(this->prevVertices.at(6)));
+    local_info[3 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(3).d(this->prevVertices.at(1)),
+                       this->prevVertices.at(3).d(this->prevVertices.at(7)));
+    local_info[4 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(4).d(this->prevVertices.at(0)),
+                       this->prevVertices.at(4).d(this->prevVertices.at(6)));
+    local_info[5 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(5).d(this->prevVertices.at(1)),
+                       this->prevVertices.at(5).d(this->prevVertices.at(7)));
+    local_info[6 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(6).d(this->prevVertices.at(2)),
+                       this->prevVertices.at(6).d(this->prevVertices.at(4)));
+    local_info[7 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(7).d(this->prevVertices.at(3)),
+                       this->prevVertices.at(7).d(this->prevVertices.at(5)));
+    break;
+  case 1:
+    local_info[0 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(0).d(this->prevVertices.at(1)),
+                       this->prevVertices.at(0).d(this->prevVertices.at(4)));
+    local_info[1 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(1).d(this->prevVertices.at(0)),
+                       this->prevVertices.at(1).d(this->prevVertices.at(5)));
+    local_info[2 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(2).d(this->prevVertices.at(3)),
+                       this->prevVertices.at(2).d(this->prevVertices.at(6)));
+    local_info[3 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(3).d(this->prevVertices.at(2)),
+                       this->prevVertices.at(3).d(this->prevVertices.at(7)));
+    local_info[4 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(4).d(this->prevVertices.at(0)),
+                       this->prevVertices.at(4).d(this->prevVertices.at(5)));
+    local_info[5 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(5).d(this->prevVertices.at(1)),
+                       this->prevVertices.at(5).d(this->prevVertices.at(4)));
+    local_info[6 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(6).d(this->prevVertices.at(2)),
+                       this->prevVertices.at(6).d(this->prevVertices.at(7)));
+    local_info[7 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(7).d(this->prevVertices.at(3)),
+                       this->prevVertices.at(7).d(this->prevVertices.at(6)));
+    break;
+  case 2:
+    local_info[0 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(0).d(this->prevVertices.at(1)),
+                       this->prevVertices.at(0).d(this->prevVertices.at(2)));
+    local_info[1 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(1).d(this->prevVertices.at(0)),
+                       this->prevVertices.at(1).d(this->prevVertices.at(3)));
+    local_info[2 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(2).d(this->prevVertices.at(0)),
+                       this->prevVertices.at(2).d(this->prevVertices.at(3)));
+    local_info[3 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(3).d(this->prevVertices.at(1)),
+                       this->prevVertices.at(3).d(this->prevVertices.at(2)));
+    local_info[4 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(4).d(this->prevVertices.at(5)),
+                       this->prevVertices.at(4).d(this->prevVertices.at(6)));
+    local_info[5 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(5).d(this->prevVertices.at(4)),
+                       this->prevVertices.at(5).d(this->prevVertices.at(7)));
+    local_info[6 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(6).d(this->prevVertices.at(4)),
+                       this->prevVertices.at(6).d(this->prevVertices.at(7)));
+    local_info[7 * (this->dimension + 2) + this->dimension + 1] =
+        0.5 * std::min(this->prevVertices.at(7).d(this->prevVertices.at(5)),
+                       this->prevVertices.at(7).d(this->prevVertices.at(6)));
+    break;
+  default:
+    throw InternalErrorException(
+        __FILE__, __func__, __LINE__,
+        "Invalid main dimension provided (numerical value not 0, 1 or 2).");
+    break;
+  }
+  // exchange information with all vertex neighbors
+  MPI_Request known_unused request[n_vertices];
+  MPI_Status known_unused status[n_vertices];
 
-    // sort work from highest to lowest
-    std::vector<std::pair<T,int>> localTargets(8);
+  // compute new position for vertex 7 (if not periodic)
+  T known_unused total_work = (T)0;
+  T known_unused shift_vectors[this->dimension * n_vertices];
 
-    for (int i = 0; i < 8; ++i)
-    {
-        localTargets.at(i).first = shiftData.at(4*i+3);
-        localTargets.at(i).second = i;
+  for (int v = 0; v < n_vertices; ++v) {
+    for (int d = 0; d < this->dimension; ++d) {
+      shift_vectors[v * this->dimension + d] = (T)0;
     }
+  }
 
-    std::sort(localTargets.begin(), localTargets.end(), cmp);
-
-    std::vector<T> diffFactor(8);
-
-
-    int nInfl = 0;
-    for (int i = 0; i < 8; ++i)
-        if (localTargets.at(i).first - workAvg > 1e-12)
-            ++nInfl;
-
-    for (int i = 0; i < 8; ++i)
-    {
-        
-        //if (shiftData.at(4*i+3) > workAvg)
-        //    diffFactor.at(i) = (shiftData.at(4*i+3) - workAvg)/(localTargets.at(0).first - workAvg);
-        //else
-        //    diffFactor.at(i) = 0.05;
-        
-        if (localTargets.at(i).first - workAvg > 1e-12)
-            diffFactor.at(i) = 1.0 / (double)nInfl;
-        else
-            diffFactor.at(i) = 0.0;
+  // TODO:
+  // 1.) collect work in main_dim communicators
+  // 2.) exchange work and extension in main_dim to neighbors in main_dim
+  // 3.) correct extension in main_dim
+  // 4.) do force-based correction in secondary dimensions
+  // 5.) find new neighbors in main_dim (hardest part, probably cell based)
+
+  int main_up;
+  int main_down;
+
+  // as each process computes the correction of the upper layer of vertices
+  // the source is 'main_up' and the target 'main_down'
+  MPI_Cart_shift(this->globalComm, main_dim, 1, &main_down, &main_up);
+
+  W local_work = this->work.at(0);
+  W remote_work;
+
+  // reduce work from local layer
+  MPI_Allreduce(MPI_IN_PLACE, &local_work, 1, MPIDataTypeW, MPI_SUM,
+                main_communicator);
+
+  // exchange local work with neighbor in main direction
+  MPI_Status state;
+  MPI_Sendrecv(&local_work, 1, MPIDataTypeW, main_down, 1010, &remote_work, 1,
+               MPIDataTypeW, main_up, 1010, this->globalComm, &state);
+
+  // transfer width of domains as well
+  T remote_width;
+  T local_width;
+  switch (main_dim) {
+  case 0:
+    local_width = this->prevVertices.at(1)[0] - this->prevVertices.at(0)[0];
+    break;
+  case 1:
+    local_width = this->prevVertices.at(2)[1] - this->prevVertices.at(0)[1];
+    break;
+  case 2:
+    local_width = this->prevVertices.at(4)[2] - this->prevVertices.at(0)[2];
+    break;
+  default:
+    throw InternalErrorException(
+        __FILE__, __func__, __LINE__,
+        "Invalid main dimension provided (numerical value not 0, 1 or 2).");
+    break;
+  }
+  MPI_Sendrecv(&local_width, 1, MPIDataTypeT, main_down, 1010, &remote_width, 1,
+               MPIDataTypeT, main_up, 1010, this->globalComm, &state);
+  T total_width = local_width + remote_width;
+  T max_main_shift = std::min(0.49 * std::min(local_width, remote_width),
+                              std::min(local_width, remote_width) - 1.0);
+
+  // compute shift
+  T local_shift = (remote_work - local_work) / (remote_work + local_work) /
+                  this->gamma / 2.0 * total_width;
+  local_shift = (std::abs(local_shift) > max_main_shift)
+                    ? std::abs(local_shift) * max_main_shift / local_shift
+                    : local_shift;
+  T remote_shift = 0;
+
+  // TODO: fix -> send to upper neighbor, receive from lower neighbor!!
+  //              sendrecv not correct!
+
+  // transfer shift back
+  MPI_Sendrecv(&local_shift, 1, MPIDataTypeT, main_up, 2020, &remote_shift, 1,
+               MPIDataTypeT, main_down, 2020, this->globalComm, &state);
+
+  // apply shift in main direction
+  switch (main_dim) {
+  case 0:
+    if (this->local_coords.at(0) > 0) {
+      this->vertices.at(0)[0] = this->prevVertices.at(0)[0] + remote_shift;
+      this->vertices.at(2)[0] = this->prevVertices.at(2)[0] + remote_shift;
+      this->vertices.at(4)[0] = this->prevVertices.at(4)[0] + remote_shift;
+      this->vertices.at(6)[0] = this->prevVertices.at(6)[0] + remote_shift;
     }
-
-  
-    if (this->localRank == 0)
-    {
-        std::cout << "localTargets: ";
-        for (int i = 0; i < 8; ++i)
-            std::cout << "[ " << localTargets.at(i).first << " | " << localTargets.at(i).second << " | " << diffFactor.at(localTargets.at(i).second) << " ] ";
-        std::cout << std::endl;
+    if (this->local_coords.at(0) < this->global_dims.at(0) - 1) {
+      this->vertices.at(1)[0] = this->prevVertices.at(1)[0] + local_shift;
+      this->vertices.at(3)[0] = this->prevVertices.at(3)[0] + local_shift;
+      this->vertices.at(5)[0] = this->prevVertices.at(5)[0] + local_shift;
+      this->vertices.at(7)[0] = this->prevVertices.at(7)[0] + local_shift;
     }
-    
-
-    for (int n = 0; n < 8; ++n)
-    {
-        if (std::abs(shiftData.at(4*n+3) + locInfo.at(3)) > 1e-12)
-        {
-            // compute direction vector from vertex to geometric
-            // center of neighboring domain
-            Point<T> domCenter(3, shiftData.data()+4*n);
-
-            // normalize vector to insure that the domains do not become
-            // too deformed after shift
-            Point<T> direction(3);
-            direction = (domCenter - this->vertices.at(7)) * 0.5;
-
-            T scalContrib =  ( ( shiftData.at(4*n+3) - workAvg )
-                                         * diffFactor.at(n) 
-                                         / workAvg 
-                                         / this->gamma );
-
-            //if (scalContrib < 1e-12) scalContrib = 0.0;
-
-            // compute contribution to shift
-            Point<T> contrib = direction * scalContrib;
-                
-            
-            // add contribution to shift
-            shift = shift + contrib;
-        }
+    break;
+  case 1:
+    if (this->local_coords.at(1) > 0) {
+      this->vertices.at(0)[1] = this->prevVertices.at(0)[1] + remote_shift;
+      this->vertices.at(1)[1] = this->prevVertices.at(1)[1] + remote_shift;
+      this->vertices.at(4)[1] = this->prevVertices.at(4)[1] + remote_shift;
+      this->vertices.at(5)[1] = this->prevVertices.at(5)[1] + remote_shift;
+    }
+    if (this->local_coords.at(1) < this->global_dims.at(1) - 1) {
+      this->vertices.at(2)[1] = this->prevVertices.at(2)[1] + local_shift;
+      this->vertices.at(3)[1] = this->prevVertices.at(3)[1] + local_shift;
+      this->vertices.at(6)[1] = this->prevVertices.at(6)[1] + local_shift;
+      this->vertices.at(7)[1] = this->prevVertices.at(7)[1] + local_shift;
     }
+    break;
+  case 2:
+    if (this->local_coords.at(2) > 0) {
+      this->vertices.at(0)[2] = this->prevVertices.at(0)[2] + remote_shift;
+      this->vertices.at(1)[2] = this->prevVertices.at(1)[2] + remote_shift;
+      this->vertices.at(2)[2] = this->prevVertices.at(2)[2] + remote_shift;
+      this->vertices.at(3)[2] = this->prevVertices.at(3)[2] + remote_shift;
+    }
+    if (this->local_coords.at(2) < this->global_dims.at(2) - 1) {
+      this->vertices.at(4)[2] = this->prevVertices.at(4)[2] + local_shift;
+      this->vertices.at(5)[2] = this->prevVertices.at(5)[2] + local_shift;
+      this->vertices.at(6)[2] = this->prevVertices.at(6)[2] + local_shift;
+      this->vertices.at(7)[2] = this->prevVertices.at(7)[2] + local_shift;
+    }
+    break;
+  default:
+    throw InternalErrorException(
+        __FILE__, __func__, __LINE__,
+        "Invalid main dimension provided (numerical value not 0, 1 or 2).");
+    break;
+  }
 
+  if (this->localRank <= 1) {
+    std::cout << "DEBUG (main-dim shift): " << this->localRank << " "
+              << remote_work << " " << local_work << " " << total_width << " "
+              << local_shift << " " << remote_shift
+              << " 0: " << this->prevVertices.at(0)[2]
+              << " 0: " << this->vertices.at(0)[2]
+              << " 4: " << this->prevVertices.at(4)[2]
+              << " 4: " << this->vertices.at(4)[2] << " " << std::endl;
+  }
 
+  int last_vertex = (n_vertices - 1) * n_vertices * (this->dimension + 2);
+  int vertex_offset = this->dimension + 2;
+
+  // compute max shift
+  T max_shift = std::max(
+      0.49 * vertex_info[last_vertex + 0 * vertex_offset + this->dimension + 1],
+      1e-6);
+  for (int v = 1; v < n_vertices; ++v) {
+    max_shift = std::min(
+        max_shift,
+        0.49 *
+            vertex_info[last_vertex + v * vertex_offset + this->dimension + 1]);
+  }
 
-    // prevent shifts outside of system box and normalize (with number of
-    // neighboring domains for the vertex)
-    for (int d = 0; d < 3; ++d)
-    {
-        if (this->local_coords.at(d) == this->global_dims.at(d) - 1)
-            shift[d] = 0.0;
-        else
-            shift[d] = shift[d];
+  // average work of neighboring processes for last vertex
+  T avg_work = 0.0;
+  for (int v = 0; v < n_vertices; ++v) {
+    avg_work += vertex_info[last_vertex + v * vertex_offset + this->dimension];
+  }
+  avg_work /= (T)n_vertices;
+
+  // compute shift vector
+  std::vector<T> vertex_shift(n_vertices * this->dimension, (T)0.0);
+  for (auto &sv : vertex_shift)
+    sv = (T)0.0;
+  int shift_offset = (n_vertices - 1) * this->dimension;
+
+  for (int v = 0; v < n_vertices; ++v) {
+    for (int d = 0; d < this->dimension; ++d) {
+      if (this->localRank == 0)
+        std::cout
+            << "DEBUG (shift x): " << v << ", " << d << " "
+            << (vertex_info[last_vertex + v * vertex_offset + this->dimension] -
+                avg_work) *
+                   ((vertex_info[last_vertex + v * vertex_offset +
+                                 this->dimension] > avg_work)
+                        ? 1.0
+                        : -1.0) /
+                   this->gamma *
+                   (vertex_info[last_vertex + v * vertex_offset + d] -
+                    this->prevVertices.at(v)[d])
+            << " => " << vertex_shift.at(shift_offset + d) <<
+
+            " max: " << max_shift
+            << " "
+               " avg_work: "
+            << avg_work
+            << " "
+               " work: "
+            << vertex_info[last_vertex + v * vertex_offset + this->dimension]
+            << " "
+               " 1/0: "
+            << ((vertex_info[last_vertex + v * vertex_offset +
+                             this->dimension] > avg_work)
+                    ? 1.0
+                    : -1.0)
+
+            << std::endl;
+      vertex_shift.at(shift_offset + d) +=
+          (vertex_info[last_vertex + v * vertex_offset + this->dimension] -
+           avg_work) *
+          ((vertex_info[last_vertex + v * vertex_offset + this->dimension] >
+            avg_work)
+               ? 1.0
+               : -1.0) /
+          this->gamma *
+          (vertex_info[last_vertex + v * vertex_offset + d] -
+           this->prevVertices.at(v)[d]);
     }
+  }
 
-    std::vector<T> commShift(
-                                {shift[0],
-                                 shift[1],
-                                 shift[2]}
-                            );
-
-    // send computed shift to neighbors
-    for (int n = 0; n < 8; ++n)
-    {
-        // receives
-        MPI_Irecv(  shifts.data()+3*n, 
-                    3, 
-                    MPIDataTypeT, 
-                    neighbors.at(infoNeig.at(n)), 
-                    2012,
-                    this->globalComm,
-                    recvReq.data()+n );
-
-        // sends
-        MPI_Isend(  commShift.data(), 
-                    3, 
-                    MPIDataTypeT, 
-                    neighbors.at(shiftNeig.at(n)), 
-                    2012,
-                    this->globalComm,
-                    sendReq.data()+n );
+  Point<T> shift_vector(3);
+  shift_vector[0] = vertex_shift.at(shift_offset);
+  shift_vector[1] = vertex_shift.at(shift_offset + 1);
+  shift_vector[2] = vertex_shift.at(shift_offset + 2);
+  T shift_length = shift_vector.norm();
+  if (this->localRank == 0)
+    std::cout << "DEBUG (shift length): " << shift_length
+              << " shift vector: " << shift_vector[0] << " " << shift_vector[1]
+              << " " << shift_vector[2] << " " << std::endl;
+
+  // apply correct length, if the shift vector is too large
+  if (shift_length > max_shift) {
+    for (int d = 0; d < this->dimension; ++d) {
+      vertex_shift.at(shift_offset + d) *= max_shift / shift_length;
+    }
+  }
 
+  if (this->localRank == 0) {
+    for (int v = 0; v < n_vertices; ++v) {
+      std::cout << "DEBUG shift vector vertex (before): " << v << ": ";
+      for (int d = 0; d < this->dimension; ++d) {
+        std::cout << vertex_shift.at(v * this->dimension + d) << " ";
+      }
+      std::cout << std::endl;
     }
+  }
 
-    MPI_Waitall(8,sendReq.data(),sendStat.data());
-    MPI_Waitall(8,recvReq.data(),recvStat.data());
+  if (this->localRank == 0 || this->localRank == 1) {
+    for (int v = 0; v < n_vertices; ++v) {
+      std::cout << "DEBUG shift vector vertex " << v << ": ";
+      for (int d = 0; d < this->dimension; ++d) {
+        std::cout << vertex_shift.at(v * this->dimension + d) << " ";
+      }
+      std::cout << std::endl;
+    }
+  }
 
-    for (int n = 0; n < 8; ++n)
-    {
-        Point<T> shiftVec(3, shifts.data()+3*n);
-        this->vertices.at(n) = this->vertices.at(n) + shiftVec;
+  bool shift[3];
+  for (int z = 0; z < 2; ++z) {
+    if ((z == 0 && this->local_coords.at(2) == 0) ||
+        (z == 1 && this->local_coords.at(2) == this->global_dims.at(2) - 1))
+      shift[2] = false;
+    else
+      shift[2] = true;
+    for (int y = 0; y < 2; ++y) {
+      if ((y == 0 && this->local_coords.at(1) == 0) ||
+          (y == 1 && this->local_coords.at(1) == this->global_dims.at(1) - 1))
+        shift[1] = false;
+      else
+        shift[1] = true;
+      for (int x = 0; x < 2; ++x) {
+        if ((x == 0 && this->local_coords.at(0) == 0) ||
+            (x == 1 && this->local_coords.at(0) == this->global_dims.at(0) - 1))
+          shift[0] = false;
+        else
+          shift[0] = true;
+        for (int d = 0; d < this->dimension; ++d) {
+          if (d == main_dim)
+            continue;
+          int v = 4 * z + 2 * y + x;
+          if (shift[d])
+            this->vertices.at(v)[d] = this->prevVertices.at(v)[d] +
+                                      vertex_shift.at(v * this->dimension + d);
+          else
+            this->vertices.at(v)[d] = this->prevVertices.at(v)[d];
+        }
+      }
     }
+  }
 }
 
 }//namespace ALL
-- 
GitLab