diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 4fe6427ddb49f8fe175d90075d85ddd51229b93b..a5e70ba41ace8829422858b04db44390ab146e8d 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -717,26 +717,35 @@ class PostProcessing(RunEnvironment):
         if n_process > 1 and use_multiprocessing is True:  # parallel solution
             pool = multiprocessing.Pool(n_process)
             logging.info(f"running {getattr(pool, '_processes')} processes in parallel")
             output = []
-            output_pre = [pool.apply_async(f_proc_load_data, args=(data, )) for data in subset]
-            for p in output_pre:
-                input_data, target_data, data = p.get()
-                nn_pred = self.model.predict(input_data)
-                output.append(pool.apply_async(
-                    f_proc_make_prediction,
-                    args=(data, input_data, target_data, nn_pred, frequency, time_dimension, self.forecast_indicator,
-                          self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim,
-                          self.forecast_path, subset_type, self.window_lead_time, self.ols_model)))
-            for i, p in enumerate(output):
-                p.get()
-                logging.info(f"...finished: {subset[i]} ({int((i + 1.) / len(output) * 100)}%)")
+            output_pre = []
+            pos = 0
+            for i, data in enumerate(subset):
+                output_pre.append(pool.apply_async(f_proc_load_data, args=(data, )))
+                # flush every 2 * n_process stations (and at the end of subset) to limit loaded data held in memory
+                if (i + 1) % (2 * n_process) == 0 or (i + 1) == len(subset):
+                    for p in output_pre:
+                        input_data, target_data, data = p.get()
+                        nn_pred = self.model.predict(input_data, batch_size=512)
+                        output.append(pool.apply_async(
+                            f_proc_make_prediction,
+                            args=(data, input_data, target_data, nn_pred, frequency, time_dimension, self.forecast_indicator,
+                                  self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim,
+                                  self.forecast_path, subset_type, self.window_lead_time, self.ols_model)))
+                    for p in output:
+                        p.get()
+                        logging.info(f"...finished: {subset[pos]} ({int((pos + 1.) / len(subset) * 100)}%)")
+                        pos += 1
+                    output, output_pre = [], []
+            assert len(output) == 0
+            assert len(output_pre) == 0
             pool.close()
             pool.join()
         else:  # serial solution
             logging.info("use serial make prediction approach")
             for i, data in enumerate(subset):
                 input_data, target_data = data.get_data(as_numpy=(True, False))
-                nn_pred = self.model.predict(input_data)
+                nn_pred = self.model.predict(input_data, batch_size=512)
                 f_proc_make_prediction(data, input_data, target_data, nn_pred, frequency, time_dimension, self.forecast_indicator,
                                        self.observation_indicator, window_dim, self.ahead_dim, self.index_dim, self.model_type_dim,
                                        self.forecast_path, subset_type, self.window_lead_time, self.ols_model)
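
For context, the change replaces the single `apply_async` fan-out (which kept every loaded station in memory at once) with a bounded pipeline: load jobs are submitted asynchronously, and every `2 * n_process` stations the pending loads are collected, fed through `model.predict`, and handed back to the pool for forecast assembly before the buffers are cleared. The standalone sketch below illustrates only that chunked-flushing idea; `load_station`, `fake_predict`, `assemble_forecast`, and `make_predictions_chunked` are hypothetical stand-ins for `f_proc_load_data`, `self.model.predict`, and `f_proc_make_prediction`, not MLAir's actual API.

```python
import multiprocessing


def load_station(station):
    # stand-in for f_proc_load_data: simulate loading the input data of one station
    return station, [0.0, 1.0, 2.0, 3.0]


def assemble_forecast(station, prediction):
    # stand-in for f_proc_make_prediction: simulate writing the forecast of one station
    return f"{station}: {prediction}"


def fake_predict(inputs):
    # stand-in for self.model.predict(...): a trivial "model"
    return [x + 1.0 for x in inputs]


def make_predictions_chunked(stations, n_process):
    """Interleave asynchronous loading and post-processing, flushing every
    2 * n_process stations so only a bounded number of loaded arrays is in memory."""
    pool = multiprocessing.Pool(n_process)
    pending_loads, pending_forecasts, finished = [], [], []
    for i, station in enumerate(stations):
        pending_loads.append(pool.apply_async(load_station, args=(station,)))
        if (i + 1) % (2 * n_process) == 0 or (i + 1) == len(stations):
            # 1) collect the loaded chunk and run the prediction in the main process
            for job in pending_loads:
                name, inputs = job.get()
                prediction = fake_predict(inputs)
                # 2) hand the prediction back to the pool for post-processing
                pending_forecasts.append(pool.apply_async(assemble_forecast, args=(name, prediction)))
            # 3) wait for the chunk to finish before loading more stations
            finished.extend(job.get() for job in pending_forecasts)
            pending_loads, pending_forecasts = [], []
    pool.close()
    pool.join()
    return finished


if __name__ == "__main__":
    for line in make_predictions_chunked([f"DE{k:03d}" for k in range(10)], n_process=3):
        print(line)
```

Flushing at `2 * n_process` rather than after every station keeps the worker processes busy while bounding peak memory to roughly two chunks of loaded input data.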