diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 584151e36fd0c9621d089e88b8ad61cffa0c5925..799dcc45f66cd0519d172a62d15f911823907816 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -11,6 +11,7 @@ import pickle
 import shutil
 from functools import reduce
 from typing import Tuple, Union, List
+import multiprocessing
 
 import numpy as np
 import xarray as xr
@@ -251,14 +252,30 @@ class DefaultDataHandler(AbstractDataHandler):
             return
         means = [None, None]
         stds = [None, None]
-        for station in set_stations:
-            try:
-                sp = cls.data_handler_transformation(station, **sp_keys)
-                for i, data in enumerate([sp.input_data, sp.target_data]):
-                    means[i] = data.mean.copy(deep=True) if means[i] is None else means[i].combine_first(data.mean)
-                    stds[i] = data.std.copy(deep=True) if stds[i] is None else stds[i].combine_first(data.std)
-            except (AttributeError, EmptyQueryResult):
-                continue
+
+        if multiprocessing.cpu_count() > 1:  # parallel solution
+            logging.info("use parallel transformation approach")
+            pool = multiprocessing.Pool()
+            output = [
+                pool.apply_async(f_proc, args=(cls.data_handler_transformation, station), kwds=sp_keys)
+                for station in set_stations]
+            for p in output:
+                dh, s = p.get()
+                if dh is not None:
+                    for i, data in enumerate([dh.input_data, dh.target_data]):
+                        means[i] = data.mean.copy(deep=True) if means[i] is None else means[i].combine_first(data.mean)
+                        stds[i] = data.std.copy(deep=True) if stds[i] is None else stds[i].combine_first(data.std)
+            pool.close()  # release worker processes once all results are collected
+            pool.join()
+        else:  # serial solution
+            logging.info("use serial transformation approach")
+            for station in set_stations:
+                dh, s = f_proc(cls.data_handler_transformation, station, **sp_keys)
+                if dh is not None:
+                    for i, data in enumerate([dh.input_data, dh.target_data]):
+                        means[i] = data.mean.copy(deep=True) if means[i] is None else means[i].combine_first(data.mean)
+                        stds[i] = data.std.copy(deep=True) if stds[i] is None else stds[i].combine_first(data.std)
+
         if means[0] is None:
             return None
         transformation_class.inputs.mean = means[0].mean("Stations")
@@ -268,4 +283,18 @@ class DefaultDataHandler(AbstractDataHandler):
         return transformation_class
 
     def get_coordinates(self):
-        return self.id_class.get_coordinates()
\ No newline at end of file
+        return self.id_class.get_coordinates()
+
+
+def f_proc(data_handler, station, **sp_keys):
+    """
+    Try to create a data handler for the given arguments. If the build fails, this station does not fulfil all
+    requirements and f_proc returns None as an indication. On a successful build, f_proc returns the built data
+    handler and the station that was used. This function must be implemented globally to work with multiprocessing.
+    """
+    try:
+        res = data_handler(station, **sp_keys)
+    except (AttributeError, EmptyQueryResult, KeyError, ValueError) as e:
+        logging.info(f"remove station {station} because it raised an error: {e}")
+        res = None
+    return res, station
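
Note on the pattern introduced above: f_proc must be defined at module level so that multiprocessing can pickle it when dispatching work to the pool, and returning the station alongside the result lets the caller identify each outcome even though workers finish out of order. Below is a minimal, self-contained sketch of the same submit/collect pattern with its serial fallback; it is not part of the patch, and build, scale, and the station names are hypothetical stand-ins for cls.data_handler_transformation and its keyword arguments.

    # Minimal sketch of the parallel submit/collect pattern and serial fallback.
    import multiprocessing


    def build(station, scale=1):
        # Hypothetical stand-in for a data handler constructor: raises for
        # stations that cannot be built, mirroring the failure path in f_proc.
        if station == "bad":
            raise ValueError("no data")
        return len(station) * scale


    def f_proc(builder, station, **kwargs):
        # Must live at module level so multiprocessing can pickle it.
        try:
            res = builder(station, **kwargs)
        except (KeyError, ValueError):
            res = None
        return res, station


    if __name__ == "__main__":
        stations = ["DEBW107", "bad", "DEBY081"]
        if multiprocessing.cpu_count() > 1:  # parallel path
            with multiprocessing.Pool() as pool:
                output = [pool.apply_async(f_proc, args=(build, s), kwds={"scale": 2})
                          for s in stations]
                results = [p.get() for p in output]
        else:  # serial fallback
            results = [f_proc(build, s, scale=2) for s in stations]
        # Failed stations come back as (None, station) and are skipped.
        print([s for res, s in results if res is not None])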