From f9b7105f9fbd534cc659bef7039b270e1d37763f Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Wed, 10 Feb 2021 16:41:10 +0100
Subject: [PATCH] parallel transformation is back online, tests will still fail

---
 mlair/data_handler/default_data_handler.py | 102 +++++++++++++++++++++
 1 file changed, 102 insertions(+)

diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 070da625..a7ce2dc0 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -248,6 +248,108 @@ class DefaultDataHandler(AbstractDataHandler):
             return
         if isinstance(transformation_dict, dict):  # tuple for (input, target) transformation
             transformation_dict = copy.deepcopy(transformation_dict), copy.deepcopy(transformation_dict)
+
+        def _inner():
+            """Inner method that is performed in both the serial and the parallel approach."""
+            if dh is not None:
+                for i, transformation in enumerate(dh._transformation):
+                    for var in transformation.keys():
+                        if var not in transformation_dict[i].keys():
+                            transformation_dict[i][var] = {}
+                        opts = transformation[var]
+                        assert transformation_dict[i][var].get("method", opts["method"]) == opts["method"]
+                        transformation_dict[i][var]["method"] = opts["method"]
+                        for k in ["mean", "std"]:
+                            old = transformation_dict[i][var].get(k, None)
+                            new = opts.get(k)
+                            transformation_dict[i][var][k] = new if old is None else old.combine_first(new)
+
+        if multiprocessing.cpu_count() > 1:  # parallel solution
+            logging.info("use parallel transformation approach")
+            pool = multiprocessing.Pool()
+            logging.info(f"running {getattr(pool, '_processes')} processes in parallel")
+            output = [
+                pool.apply_async(f_proc, args=(cls.data_handler_transformation, station), kwds=sp_keys)
+                for station in set_stations]
+            for p in output:
+                dh, s = p.get()
+                _inner()
+        else:  # serial solution
+            logging.info("use serial transformation approach")
+            for station in set_stations:
+                dh, s = f_proc(cls.data_handler_transformation, station, **sp_keys)
+                _inner()
+
+        # aggregate all information
+        pop_list = []
+        for i, transformation in enumerate(transformation_dict):
+            for k in transformation.keys():
+                try:
+                    if transformation[k]["mean"] is not None:
+                        transformation_dict[i][k]["mean"] = transformation[k]["mean"].mean("Stations")
+                    if transformation[k]["std"] is not None:
+                        transformation_dict[i][k]["std"] = transformation[k]["std"].mean("Stations")
+                except KeyError:
+                    pop_list.append((i, k))
+        for (i, k) in pop_list:
+            transformation_dict[i].pop(k)
+        return transformation_dict
+
+        # if multiprocessing.cpu_count() > 1:  # parallel solution
+        #     logging.info("use parallel transformation approach")
+        #     pool = multiprocessing.Pool()
+        #     logging.info(f"running {getattr(pool, '_processes')} processes in parallel")
+        #     output = [
+        #         pool.apply_async(f_proc, args=(cls.data_handler_transformation, station), kwds=sp_keys)
+        #         for station in set_stations]
+        #     for p in output:
+        #         dh, s = p.get()
+        #         if dh is not None:
+        #             for i, data in enumerate([dh.input_data, dh.target_data]):
+        #                 means[i] = data.mean.copy(deep=True) if means[i] is None else means[i].combine_first(data.mean)
+        #                 stds[i] = data.std.copy(deep=True) if stds[i] is None else stds[i].combine_first(data.std)
+        # else:  # serial solution
+        #     logging.info("use serial transformation approach")
+        #     for station in set_stations:
+        #         dh, s = f_proc(cls.data_handler_transformation, station, **sp_keys)
+        #         if dh is not None:
+        #             for i, data in enumerate([dh.input_data, dh.target_data]):
+        #                 means[i] = data.mean.copy(deep=True) if means[i] is None else means[i].combine_first(data.mean)
+        #                 stds[i] = data.std.copy(deep=True) if stds[i] is None else stds[i].combine_first(data.std)
+
+    @classmethod
+    def transformation_old(cls, set_stations, **kwargs):
+        """
+        ### supported transformation methods
+
+        Currently supported methods are:
+
+        * standardise (default, if method is not given)
+        * centre
+
+        ### mean and std estimation
+
+        Mean and std (depending on method) are estimated. For each station, mean and std are calculated and afterwards
+        aggregated using the mean value over all station-wise metrics. This method is not exactly accurate, especially
+        regarding the std calculation, but it is much faster. Furthermore, it is a weighted mean, weighted by the
+        time series length / number of data points - a longer time series has more influence on the transformation
+        settings than a short time series. The estimation of the std is less accurate, because the unweighted mean of
+        all stds is not equal to the true std, but the mean of all station-wise stds is still a decent estimate.
+        Finally, the real accuracy of mean and std is less important, because it is "just" a transformation / scaling.
+
+        ### mean and std given
+
+        If mean and std are not None, the default data handler expects these parameters to match the data and applies
+        these values to the data. Make sure that all dimensions and/or coordinates are in agreement.
+        """
+
+        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
+        transformation_dict = sp_keys.get("transformation", None)
+        if transformation_dict is None:
+            return
+        if isinstance(transformation_dict, dict):  # tuple for (input, target) transformation
+            transformation_dict = copy.deepcopy(transformation_dict), copy.deepcopy(transformation_dict)
+
         for station in set_stations:
             dh, s = f_proc(cls.data_handler_transformation, station, **sp_keys)
             if dh is not None:
-- 
GitLab