diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index e6381f6c9718608ac50dd93331caf94a06b638d9..4fe6427ddb49f8fe175d90075d85ddd51229b93b 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -8,12 +8,10 @@ import logging import os import sys import traceback -import pathos -import multiprocess.context as ctx +import multiprocessing import psutil from typing import Dict, Tuple, Union, List, Callable -import tensorflow.keras as keras import numpy as np import pandas as pd import xarray as xr @@ -698,15 +696,6 @@ class PostProcessing(RunEnvironment): logging.info(f"start train_ols_model on train data") self.ols_model = OrdinaryLeastSquaredModel(self.train_data) - @staticmethod - def _get_ctx_context(): - _default_context = ctx._default_context - if _default_context._actual_context is None: - _actual_context = _default_context._default_context._name - else: - _actual_context = ctx._default_context._actual_context._name - return _actual_context - @TimeTrackingWrapper def make_prediction(self, subset): """ @@ -726,43 +715,33 @@ class PostProcessing(RunEnvironment): max_process = self.data_store.get("max_number_multiprocessing") n_process = min([psutil.cpu_count(logical=False), len(subset), max_process]) # use only physical cpus if n_process > 1 and use_multiprocessing is True: # parallel solution - for i, data in enumerate(subset): - f_proc_make_prediction(data, frequency, time_dimension, self.forecast_indicator, - self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim, - self.forecast_path, subset_type, self.model, self.window_lead_time, self.ols_model, - use_multiprocessing=True) - logging.info(f"...finished: {data} ({int((i + 1.) / len(subset) * 100)}%)") - # logging.info("use parallel make prediction approach") - # _actual_context = self._get_ctx_context() - # ctx._force_start_method('spawn') - # pool = pathos.multiprocessing.ProcessingPool(n_process) - # logging.info(f"running {getattr(pool, 'ncpus')} processes in parallel") - # output = [ - # pool.apipe(f_proc_make_prediction, data, frequency, time_dimension, self.forecast_indicator, - # self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim, - # self.forecast_path, subset_type, self.model.model, self.window_lead_time, self.ols_model) - # for data in subset] - # for i, p in enumerate(output): - # p.get() - # logging.info(f"...finished: {subset[i]} ({int((i + 1.) / len(output) * 100)}%)") - # pool.close() - # pool.join() - # pool.clear() - # ctx._force_start_method(_actual_context) + pool = multiprocessing.Pool(n_process) + logging.info(f"running {getattr(pool, '_processes')} processes in parallel") + output = [] + output_pre = [pool.apply_async(f_proc_load_data, args=(data, )) for data in subset] + for p in output_pre: + input_data, target_data, data = p.get() + nn_pred = self.model.predict(input_data) + output.append(pool.apply_async( + f_proc_make_prediction, + args=(data, input_data, target_data, nn_pred, frequency, time_dimension, self.forecast_indicator, + self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim, + self.forecast_path, subset_type, self.window_lead_time, self.ols_model))) + for i, p in enumerate(output): + p.get() + logging.info(f"...finished: {subset[i]} ({int((i + 1.) / len(output) * 100)}%)") + pool.close() + pool.join() else: # serial solution logging.info("use serial make prediction approach") for i, data in enumerate(subset): - f_proc_make_prediction(data, frequency, time_dimension, self.forecast_indicator, + input_data, target_data = data.get_data(as_numpy=(True, False)) + nn_pred = self.model.predict(input_data) + f_proc_make_prediction(data, input_data, target_data, nn_pred, frequency, time_dimension, self.forecast_indicator, self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim, - self.forecast_path, subset_type, self.model, self.window_lead_time, self.ols_model, - use_multiprocessing=True) + self.forecast_path, subset_type, self.window_lead_time, self.ols_model) logging.info(f"...finished: {data} ({int((i + 1.) / len(subset) * 100)}%)") - # for i, data in enumerate(subset): - # f_proc_make_prediction(data, frequency, time_dimension, self.forecast_indicator, self.observation_indicator, window_dim, - # self.ahead_dim, self.index_dim, self.model_type_dim, self.forecast_path, subset_type, model, - # self.window_lead_time, self.ols_model) - # # # for i, data in enumerate(subset): @@ -1171,8 +1150,8 @@ class PostProcessing(RunEnvironment): class MakePrediction: - def __init__(self, model, window_lead_time, ols_model): - self.model = model + def __init__(self, nn_pred, window_lead_time, ols_model): + self.nn_pred = nn_pred self.window_lead_time = window_lead_time self.ols_model = ols_model # must be copied maybe @@ -1198,7 +1177,7 @@ class MakePrediction: :return: filled data array with nn predictions """ - tmp_nn = self.model.predict(input_data, use_multiprocessing=use_multiprocessing) + tmp_nn = self.nn_pred if isinstance(tmp_nn, list): nn_prediction.values = tmp_nn[-1] elif tmp_nn.ndim == 3: @@ -1324,21 +1303,17 @@ class MakePrediction: return res -def f_proc_make_prediction(data, frequency, time_dimension, forecast_indicator, observation_indicator, window_dim, - ahead_dim, index_dim, model_type_dim, forecast_path, subset_type, model, window_lead_time, - ols_model, use_multiprocessing=False, custom_objects=None): - # import tensorflow.keras as keras - # if not (hasattr(model, "model") or isinstance(model, keras.Model)): - # print(f"{data} load model") - # model = keras.models.load_model(model, custom_objects=custom_objects) - # model.make_predict_function() - # print(f"{data} done") +def f_proc_load_data(data): + input_data, target_data = data.get_data(as_numpy=(True, False)) + return input_data, target_data, data + - prediction_maker = MakePrediction(model, window_lead_time, ols_model) +def f_proc_make_prediction(data, input_data, target_data, nn_pred, frequency, time_dimension, forecast_indicator, observation_indicator, window_dim, + ahead_dim, index_dim, model_type_dim, forecast_path, subset_type, window_lead_time, + ols_model): + + prediction_maker = MakePrediction(nn_pred, window_lead_time, ols_model) - input_data, target_data = data.get_data(as_numpy=(True, False)) - # input_data = data.get_X() - # target_data = data.get_Y(as_numpy=False) observation_data = data.get_observation() # get scaling parameters @@ -1351,8 +1326,7 @@ def f_proc_make_prediction(data, frequency, time_dimension, forecast_indicator, count=4) # nn forecast - nn_pred = prediction_maker._create_nn_forecast(input_data, nn_pred, transformation_func, normalised, - use_multiprocessing=use_multiprocessing) + nn_pred = prediction_maker._create_nn_forecast(input_data, nn_pred, transformation_func, normalised) # persistence persi_pred = prediction_maker._create_persistence_forecast(observation_data, persi_pred,