diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index e6381f6c9718608ac50dd93331caf94a06b638d9..4fe6427ddb49f8fe175d90075d85ddd51229b93b 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -8,12 +8,10 @@ import logging
 import os
 import sys
 import traceback
-import pathos
-import multiprocess.context as ctx
+import multiprocessing
 import psutil
 from typing import Dict, Tuple, Union, List, Callable
 
-import tensorflow.keras as keras
 import numpy as np
 import pandas as pd
 import xarray as xr
@@ -698,15 +696,6 @@ class PostProcessing(RunEnvironment):
         logging.info(f"start train_ols_model on train data")
         self.ols_model = OrdinaryLeastSquaredModel(self.train_data)
 
-    @staticmethod
-    def _get_ctx_context():
-        _default_context = ctx._default_context
-        if _default_context._actual_context is None:
-            _actual_context = _default_context._default_context._name
-        else:
-            _actual_context = ctx._default_context._actual_context._name
-        return _actual_context
-
     @TimeTrackingWrapper
     def make_prediction(self, subset):
         """
@@ -726,43 +715,33 @@ class PostProcessing(RunEnvironment):
         max_process = self.data_store.get("max_number_multiprocessing")
         n_process = min([psutil.cpu_count(logical=False), len(subset), max_process])  # use only physical cpus
         if n_process > 1 and use_multiprocessing is True:  # parallel solution
-            for i, data in enumerate(subset):
-                f_proc_make_prediction(data, frequency, time_dimension, self.forecast_indicator,
-                    self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim,
-                    self.forecast_path, subset_type, self.model, self.window_lead_time, self.ols_model,
-                                       use_multiprocessing=True)
-                logging.info(f"...finished: {data} ({int((i + 1.) / len(subset) * 100)}%)")
-            # logging.info("use parallel make prediction approach")
-            # _actual_context = self._get_ctx_context()
-            # ctx._force_start_method('spawn')
-            # pool = pathos.multiprocessing.ProcessingPool(n_process)
-            # logging.info(f"running {getattr(pool, 'ncpus')} processes in parallel")
-            # output = [
-            #     pool.apipe(f_proc_make_prediction, data, frequency, time_dimension, self.forecast_indicator,
-            #                self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim,
-            #                self.forecast_path, subset_type, self.model.model, self.window_lead_time, self.ols_model)
-            #     for data in subset]
-            # for i, p in enumerate(output):
-            #     p.get()
-            #     logging.info(f"...finished: {subset[i]} ({int((i + 1.) / len(output) * 100)}%)")
-            # pool.close()
-            # pool.join()
-            # pool.clear()
-            # ctx._force_start_method(_actual_context)
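+            # two-stage parallel approach: the workers only load the data, the keras prediction runs
+            # in the parent process (the model is not passed to the workers), and the forecast
+            # assembly is handed back to the pool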
+            pool = multiprocessing.Pool(n_process)
+            logging.info(f"running {n_process} processes in parallel")
+            output = []
+            output_pre = [pool.apply_async(f_proc_load_data, args=(data, )) for data in subset]
+            for p in output_pre:
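+                # fetch the loaded data in submission order and predict in the parent process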
+                input_data, target_data, data = p.get()
+                nn_pred = self.model.predict(input_data)
+                output.append(pool.apply_async(
+                    f_proc_make_prediction,
+                    args=(data, input_data, target_data, nn_pred, frequency, time_dimension, self.forecast_indicator,
+                          self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim,
+                          self.forecast_path, subset_type, self.window_lead_time, self.ols_model)))
+            for i, p in enumerate(output):
+                p.get()
+                logging.info(f"...finished: {subset[i]} ({int((i + 1.) / len(output) * 100)}%)")
+            pool.close()
+            pool.join()
         else:  # serial solution
             logging.info("use serial make prediction approach")
             for i, data in enumerate(subset):
-                f_proc_make_prediction(data, frequency, time_dimension, self.forecast_indicator,
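+                # serial: load, predict, and store the forecasts for one data object after another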
+                input_data, target_data = data.get_data(as_numpy=(True, False))
+                nn_pred = self.model.predict(input_data)
+                f_proc_make_prediction(data, input_data, target_data, nn_pred, frequency, time_dimension, self.forecast_indicator,
                     self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim,
-                    self.forecast_path, subset_type, self.model, self.window_lead_time, self.ols_model,
-                                       use_multiprocessing=True)
+                    self.forecast_path, subset_type, self.window_lead_time, self.ols_model)
                 logging.info(f"...finished: {data} ({int((i + 1.) / len(subset) * 100)}%)")
 
-        # for i, data in enumerate(subset):
-        #     f_proc_make_prediction(data, frequency, time_dimension, self.forecast_indicator, self.observation_indicator, window_dim,
-        #                            self.ahead_dim, self.index_dim, self.model_type_dim, self.forecast_path, subset_type, model,
-        #                            self.window_lead_time, self.ols_model)
-
         #
         #
         # for i, data in enumerate(subset):
@@ -1171,8 +1150,8 @@ class PostProcessing(RunEnvironment):
 
 class MakePrediction:
 
-    def __init__(self, model, window_lead_time, ols_model):
-        self.model = model
+    def __init__(self, nn_pred, window_lead_time, ols_model):
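+        # the raw model output (nn_pred) is passed in instead of the model itself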
+        self.nn_pred = nn_pred
         self.window_lead_time = window_lead_time
         self.ols_model = ols_model  # must be copied maybe
 
@@ -1198,7 +1177,7 @@ class MakePrediction:
 
         :return: filled data array with nn predictions
         """
-        tmp_nn = self.model.predict(input_data, use_multiprocessing=use_multiprocessing)
+        tmp_nn = self.nn_pred
         if isinstance(tmp_nn, list):
             nn_prediction.values = tmp_nn[-1]
         elif tmp_nn.ndim == 3:
@@ -1324,21 +1303,17 @@ class MakePrediction:
         return res
 
 
-def f_proc_make_prediction(data, frequency, time_dimension, forecast_indicator, observation_indicator, window_dim,
-                           ahead_dim, index_dim, model_type_dim, forecast_path, subset_type, model, window_lead_time,
-                           ols_model, use_multiprocessing=False, custom_objects=None):
-    # import tensorflow.keras as keras
-    # if not (hasattr(model, "model") or isinstance(model, keras.Model)):
-    #     print(f"{data} load model")
-    #     model = keras.models.load_model(model, custom_objects=custom_objects)
-    #     model.make_predict_function()
-    #     print(f"{data} done")
+def f_proc_load_data(data):
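+    # worker task: load the inputs as numpy and the targets in their original form, and return the data object as well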
+    input_data, target_data = data.get_data(as_numpy=(True, False))
+    return input_data, target_data, data
+
 
-    prediction_maker = MakePrediction(model, window_lead_time, ols_model)
+def f_proc_make_prediction(data, input_data, target_data, nn_pred, frequency, time_dimension, forecast_indicator,
+                           observation_indicator, window_dim, ahead_dim, index_dim, model_type_dim, forecast_path,
+                           subset_type, window_lead_time, ols_model):
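+    # nn_pred is already computed in the calling process; this function only assembles and stores the forecasts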
+    prediction_maker = MakePrediction(nn_pred, window_lead_time, ols_model)
 
-    input_data, target_data = data.get_data(as_numpy=(True, False))
-    # input_data = data.get_X()
-    # target_data = data.get_Y(as_numpy=False)
     observation_data = data.get_observation()
 
     # get scaling parameters
@@ -1351,8 +1326,7 @@ def f_proc_make_prediction(data, frequency, time_dimension, forecast_indicator,
                                                                                                       count=4)
 
         # nn forecast
-        nn_pred = prediction_maker._create_nn_forecast(input_data, nn_pred, transformation_func, normalised,
-                                                       use_multiprocessing=use_multiprocessing)
+        nn_pred = prediction_maker._create_nn_forecast(input_data, nn_pred, transformation_func, normalised)
 
         # persistence
         persi_pred = prediction_maker._create_persistence_forecast(observation_data, persi_pred,