diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 2be09830438310fcd4ca37409fdd109314b5a202..e6381f6c9718608ac50dd93331caf94a06b638d9 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -726,29 +726,36 @@ class PostProcessing(RunEnvironment):
         max_process = self.data_store.get("max_number_multiprocessing")
         n_process = min([psutil.cpu_count(logical=False), len(subset), max_process])  # use only physical cpus
         if n_process > 1 and use_multiprocessing is True:  # parallel solution
-            logging.info("use parallel make prediction approach")
-            _actual_context = self._get_ctx_context()
-            ctx._force_start_method('spawn')
-            pool = pathos.multiprocessing.ProcessingPool(n_process)
-            logging.info(f"running {getattr(pool, 'ncpus')} processes in parallel")
-            output = [
-                pool.apipe(f_proc_make_prediction, data, frequency, time_dimension, self.forecast_indicator,
-                           self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim,
-                           self.forecast_path, subset_type, self.model.model, self.window_lead_time, self.ols_model)
-                for data in subset]
-            for i, p in enumerate(output):
-                p.get()
-                logging.info(f"...finished: {subset[i]} ({int((i + 1.) / len(output) * 100)}%)")
-            pool.close()
-            pool.join()
-            pool.clear()
-            ctx._force_start_method(_actual_context)
+            for i, data in enumerate(subset):
+                f_proc_make_prediction(data, frequency, time_dimension, self.forecast_indicator,
+                    self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim,
+                    self.forecast_path, subset_type, self.model, self.window_lead_time, self.ols_model,
+                                       use_multiprocessing=True)
+                logging.info(f"...finished: {data} ({int((i + 1.) / len(subset) * 100)}%)")
+            # logging.info("use parallel make prediction approach")
+            # _actual_context = self._get_ctx_context()
+            # ctx._force_start_method('spawn')
+            # pool = pathos.multiprocessing.ProcessingPool(n_process)
+            # logging.info(f"running {getattr(pool, 'ncpus')} processes in parallel")
+            # output = [
+            #     pool.apipe(f_proc_make_prediction, data, frequency, time_dimension, self.forecast_indicator,
+            #                self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim,
+            #                self.forecast_path, subset_type, self.model.model, self.window_lead_time, self.ols_model)
+            #     for data in subset]
+            # for i, p in enumerate(output):
+            #     p.get()
+            #     logging.info(f"...finished: {subset[i]} ({int((i + 1.) / len(output) * 100)}%)")
+            # pool.close()
+            # pool.join()
+            # pool.clear()
+            # ctx._force_start_method(_actual_context)
         else:  # serial solution
             logging.info("use serial make prediction approach")
             for i, data in enumerate(subset):
                 f_proc_make_prediction(data, frequency, time_dimension, self.forecast_indicator,
                     self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim,
-                    self.forecast_path, subset_type, self.model, self.window_lead_time, self.ols_model)
+                    self.forecast_path, subset_type, self.model, self.window_lead_time, self.ols_model,
+                                       use_multiprocessing=False)
                 logging.info(f"...finished: {data} ({int((i + 1.) / len(subset) * 100)}%)")
 
         # for i, data in enumerate(subset):
@@ -1176,7 +1183,7 @@ class MakePrediction:
         return [target_data.copy() for _ in range(count)]
 
     def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, transformation_func: Callable,
-                            normalised: bool) -> xr.DataArray:
+                            normalised: bool, use_multiprocessing: bool = False) -> xr.DataArray:
         """
         Create NN forecast for given input data.
 
@@ -1191,7 +1198,7 @@ class MakePrediction:
 
         :return: filled data array with nn predictions
         """
-        tmp_nn = self.model.predict(input_data)
+        tmp_nn = self.model.predict(input_data, use_multiprocessing=use_multiprocessing)
         if isinstance(tmp_nn, list):
             nn_prediction.values = tmp_nn[-1]
         elif tmp_nn.ndim == 3:
@@ -1319,7 +1326,7 @@ class MakePrediction:
 
 def f_proc_make_prediction(data, frequency, time_dimension, forecast_indicator, observation_indicator, window_dim,
                            ahead_dim, index_dim, model_type_dim, forecast_path, subset_type, model, window_lead_time,
-                           ols_model, custom_objects=None):
+                           ols_model, use_multiprocessing=False, custom_objects=None):
     # import tensorflow.keras as keras
     # if not (hasattr(model, "model") or isinstance(model, keras.Model)):
     #     print(f"{data} load model")
@@ -1344,7 +1351,8 @@ def f_proc_make_prediction(data, frequency, time_dimension, forecast_indicator,
                                                                                                       count=4)
 
         # nn forecast
-        nn_pred = prediction_maker._create_nn_forecast(input_data, nn_pred, transformation_func, normalised)
+        nn_pred = prediction_maker._create_nn_forecast(input_data, nn_pred, transformation_func, normalised,
+                                                       use_multiprocessing=use_multiprocessing)
 
         # persistence
         persi_pred = prediction_maker._create_persistence_forecast(observation_data, persi_pred,