diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 2be09830438310fcd4ca37409fdd109314b5a202..e6381f6c9718608ac50dd93331caf94a06b638d9 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -726,29 +726,36 @@ class PostProcessing(RunEnvironment): max_process = self.data_store.get("max_number_multiprocessing") n_process = min([psutil.cpu_count(logical=False), len(subset), max_process]) # use only physical cpus if n_process > 1 and use_multiprocessing is True: # parallel solution - logging.info("use parallel make prediction approach") - _actual_context = self._get_ctx_context() - ctx._force_start_method('spawn') - pool = pathos.multiprocessing.ProcessingPool(n_process) - logging.info(f"running {getattr(pool, 'ncpus')} processes in parallel") - output = [ - pool.apipe(f_proc_make_prediction, data, frequency, time_dimension, self.forecast_indicator, - self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim, - self.forecast_path, subset_type, self.model.model, self.window_lead_time, self.ols_model) - for data in subset] - for i, p in enumerate(output): - p.get() - logging.info(f"...finished: {subset[i]} ({int((i + 1.) / len(output) * 100)}%)") - pool.close() - pool.join() - pool.clear() - ctx._force_start_method(_actual_context) + for i, data in enumerate(subset): + f_proc_make_prediction(data, frequency, time_dimension, self.forecast_indicator, + self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim, + self.forecast_path, subset_type, self.model, self.window_lead_time, self.ols_model, + use_multiprocessing=True) + logging.info(f"...finished: {data} ({int((i + 1.) / len(subset) * 100)}%)") + # logging.info("use parallel make prediction approach") + # _actual_context = self._get_ctx_context() + # ctx._force_start_method('spawn') + # pool = pathos.multiprocessing.ProcessingPool(n_process) + # logging.info(f"running {getattr(pool, 'ncpus')} processes in parallel") + # output = [ + # pool.apipe(f_proc_make_prediction, data, frequency, time_dimension, self.forecast_indicator, + # self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim, + # self.forecast_path, subset_type, self.model.model, self.window_lead_time, self.ols_model) + # for data in subset] + # for i, p in enumerate(output): + # p.get() + # logging.info(f"...finished: {subset[i]} ({int((i + 1.) / len(output) * 100)}%)") + # pool.close() + # pool.join() + # pool.clear() + # ctx._force_start_method(_actual_context) else: # serial solution logging.info("use serial make prediction approach") for i, data in enumerate(subset): f_proc_make_prediction(data, frequency, time_dimension, self.forecast_indicator, self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim, - self.forecast_path, subset_type, self.model, self.window_lead_time, self.ols_model) + self.forecast_path, subset_type, self.model, self.window_lead_time, self.ols_model, + use_multiprocessing=True) logging.info(f"...finished: {data} ({int((i + 1.) / len(subset) * 100)}%)") # for i, data in enumerate(subset): @@ -1176,7 +1183,7 @@ class MakePrediction: return [target_data.copy() for _ in range(count)] def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, transformation_func: Callable, - normalised: bool) -> xr.DataArray: + normalised: bool, use_multiprocessing: bool = False) -> xr.DataArray: """ Create NN forecast for given input data. @@ -1191,7 +1198,7 @@ class MakePrediction: :return: filled data array with nn predictions """ - tmp_nn = self.model.predict(input_data) + tmp_nn = self.model.predict(input_data, use_multiprocessing=use_multiprocessing) if isinstance(tmp_nn, list): nn_prediction.values = tmp_nn[-1] elif tmp_nn.ndim == 3: @@ -1319,7 +1326,7 @@ class MakePrediction: def f_proc_make_prediction(data, frequency, time_dimension, forecast_indicator, observation_indicator, window_dim, ahead_dim, index_dim, model_type_dim, forecast_path, subset_type, model, window_lead_time, - ols_model, custom_objects=None): + ols_model, use_multiprocessing=False, custom_objects=None): # import tensorflow.keras as keras # if not (hasattr(model, "model") or isinstance(model, keras.Model)): # print(f"{data} load model") @@ -1344,7 +1351,8 @@ def f_proc_make_prediction(data, frequency, time_dimension, forecast_indicator, count=4) # nn forecast - nn_pred = prediction_maker._create_nn_forecast(input_data, nn_pred, transformation_func, normalised) + nn_pred = prediction_maker._create_nn_forecast(input_data, nn_pred, transformation_func, normalised, + use_multiprocessing=use_multiprocessing) # persistence persi_pred = prediction_maker._create_persistence_forecast(observation_data, persi_pred,