diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 4fe6427ddb49f8fe175d90075d85ddd51229b93b..a5e70ba41ace8829422858b04db44390ab146e8d 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -717,26 +717,53 @@ class PostProcessing(RunEnvironment):
         if n_process > 1 and use_multiprocessing is True:  # parallel solution
             pool = multiprocessing.Pool(n_process)
             logging.info(f"running {getattr(pool, '_processes')} processes in parallel")
+
+
             output = []
-            output_pre = [pool.apply_async(f_proc_load_data, args=(data, )) for data in subset]
-            for p in output_pre:
-                input_data, target_data, data = p.get()
-                nn_pred = self.model.predict(input_data)
-                output.append(pool.apply_async(
-                    f_proc_make_prediction,
-                    args=(data, input_data, target_data, nn_pred, frequency, time_dimension, self.forecast_indicator,
-                          self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim,
-                          self.forecast_path, subset_type, self.window_lead_time, self.ols_model)))
-            for i, p in enumerate(output):
-                p.get()
-                logging.info(f"...finished: {subset[i]} ({int((i + 1.) / len(output) * 100)}%)")
+            output_pre = []
+            pos = 0
+            for i, data in enumerate(subset):
+                output_pre.append(pool.apply_async(f_proc_load_data, args=(data, )))
+                if (i + 1) % (2 * n_process) == 0 or (i + 1) == len(subset):
+                    for p in output_pre:
+                        input_data, target_data, data = p.get()
+                        nn_pred = self.model.predict(input_data, batch_size=512)
+                        output.append(pool.apply_async(
+                            f_proc_make_prediction,
+                            args=(data, input_data, target_data, nn_pred, frequency, time_dimension, self.forecast_indicator,
+                                  self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim,
+                                  self.forecast_path, subset_type, self.window_lead_time, self.ols_model)))
+                    for p in output:
+                        p.get()
+                        logging.info(f"...finished: {subset[pos]} ({int((pos + 1.) / len(output) * 100)}%)")
+                        pos += 1
+                    output, output_pre = [], []
+            assert len(output) == 0
+            assert len(output_pre) == 0
             pool.close()
             pool.join()
+
+
+
+            # output_pre = [pool.apply_async(f_proc_load_data, args=(data, )) for data in subset]
+            # for p in output_pre:
+            #     input_data, target_data, data = p.get()
+            #     nn_pred = self.model.predict(input_data)
+            #     output.append(pool.apply_async(
+            #         f_proc_make_prediction,
+            #         args=(data, input_data, target_data, nn_pred, frequency, time_dimension, self.forecast_indicator,
+            #               self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim,
+            #               self.forecast_path, subset_type, self.window_lead_time, self.ols_model)))
+            # for i, p in enumerate(output):
+            #     p.get()
+            #     logging.info(f"...finished: {subset[i]} ({int((i + 1.) / len(output) * 100)}%)")
+            # pool.close()
+            # pool.join()
         else:  # serial solution
             logging.info("use serial make prediction approach")
             for i, data in enumerate(subset):
                 input_data, target_data = data.get_data(as_numpy=(True, False))
-                nn_pred = self.model.predict(input_data)
+                nn_pred = self.model.predict(input_data, batch_size=512)
                 f_proc_make_prediction(data, input_data, target_data, nn_pred, frequency, time_dimension, self.forecast_indicator,
                     self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim,
                     self.forecast_path, subset_type, self.window_lead_time, self.ols_model)