diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 8ad3e1e7ff583bd511d6311f2ab9de886f440fc9..a02ad89c910b8d898778a748db0e68b3f75fd5f1 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -53,6 +53,7 @@ class DefaultDataHandler(AbstractDataHandler):
         self._Y = None
         self._X_extreme = None
         self._Y_extreme = None
+        self._data_intersection = None
         self._use_multiprocessing = use_multiprocessing
         self._max_number_multiprocessing = max_number_multiprocessing
         _name_affix = str(f"{str(self.id_class)}_{name_affix}" if name_affix is not None else id(self))
@@ -172,10 +173,15 @@ class DefaultDataHandler(AbstractDataHandler):
         else:
             X = list(map(lambda x: x.sel({dim: intersect}), X_original))
             Y = Y_original.sel({dim: intersect})
+        self._data_intersection = intersect
         self._X, self._Y = X, Y
 
     def get_observation(self):
-        return self.id_class.observation.copy().squeeze()
+        dim = self.time_dim
+        if self._data_intersection is not None:
+            return self.id_class.observation.sel({dim: self._data_intersection}).copy().squeeze()
+        else:
+            return self.id_class.observation.copy().squeeze()
 
     def apply_transformation(self, data, base="target", dim=0, inverse=False):
         return self.id_class.apply_transformation(data, dim=dim, base=base, inverse=inverse)
@@ -248,7 +254,7 @@ class DefaultDataHandler(AbstractDataHandler):
             d.coords[dim] = d.coords[dim].values + np.timedelta64(*timedelta)
 
     @classmethod
-    def transformation(cls, set_stations, tmp_path=None, **kwargs):
+    def transformation(cls, set_stations, tmp_path=None, dh_transformation=None, **kwargs):
         """
 
         ### supported transformation methods
@@ -278,31 +284,14 @@ class DefaultDataHandler(AbstractDataHandler):
 
         If min and max are not None, the default data handler expects this parameters to match the data and
         applies this values to the data. Make sure that all dimensions and/or coordinates are in agreement.
         """
-        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
+        if dh_transformation is None:
+            dh_transformation = cls.data_handler_transformation
+        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in dh_transformation.requirements() if k in kwargs}
         if "transformation" not in sp_keys.keys():
             return
         transformation_dict = ({}, {})
 
-        def _inner():
-            """Inner method that is performed in both serial and parallel approach."""
-            if dh is not None:
-                for i, transformation in enumerate(dh._transformation):
-                    for var in transformation.keys():
-                        if var not in transformation_dict[i].keys():
-                            transformation_dict[i][var] = {}
-                        opts = transformation[var]
-                        if not transformation_dict[i][var].get("method", opts["method"]) == opts["method"]:
-                            # data handlers with filters are allowed to change transformation method to standardise
-                            assert hasattr(dh, "filter_dim") and opts["method"] == "standardise"
-                        transformation_dict[i][var]["method"] = opts["method"]
-                        for k in ["mean", "std", "min", "max"]:
-                            old = transformation_dict[i][var].get(k, None)
-                            new = opts.get(k)
-                            transformation_dict[i][var][k] = new if old is None else old.combine_first(new)
-                        if "feature_range" in opts.keys():
-                            transformation_dict[i][var]["feature_range"] = opts.get("feature_range", None)
-
         max_process = kwargs.get("max_number_multiprocessing", 16)
         n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process])  # use only physical cpus
         if n_process > 1 and kwargs.get("use_multiprocessing", True) is True:  # parallel solution
@@ -311,24 +300,29 @@ class DefaultDataHandler(AbstractDataHandler):
             logging.info(f"running {getattr(pool, '_processes')} processes in parallel")
             sp_keys.update({"tmp_path": tmp_path, "return_strategy": "reference"})
             output = [
-                pool.apply_async(f_proc, args=(cls.data_handler_transformation, station), kwds=sp_keys)
+                pool.apply_async(f_proc, args=(dh_transformation, station), kwds=sp_keys)
                 for station in set_stations]
             for p in output:
                 _res_file, s = p.get()
                 with open(_res_file, "rb") as f:
                     dh = dill.load(f)
                 os.remove(_res_file)
-                _inner()
+                transformation_dict = cls.update_transformation_dict(dh, transformation_dict)
             pool.close()
         else:  # serial solution
             logging.info("use serial transformation approach")
             sp_keys.update({"return_strategy": "result"})
             for station in set_stations:
-                dh, s = f_proc(cls.data_handler_transformation, station, **sp_keys)
-                _inner()
+                dh, s = f_proc(dh_transformation, station, **sp_keys)
+                transformation_dict = cls.update_transformation_dict(dh, transformation_dict)
 
         # aggregate all information
         iter_dim = sp_keys.get("iter_dim", cls.DEFAULT_ITER_DIM)
+        transformation_dict = cls.aggregate_transformation(transformation_dict, iter_dim)
+        return transformation_dict
+
+    @classmethod
+    def aggregate_transformation(cls, transformation_dict, iter_dim):
         pop_list = []
         for i, transformation in enumerate(transformation_dict):
             for k in transformation.keys():
@@ -349,6 +343,27 @@ class DefaultDataHandler(AbstractDataHandler):
                 transformation_dict[i].pop(k)
         return transformation_dict
 
+    @classmethod
+    def update_transformation_dict(cls, dh, transformation_dict):
+        """Inner method that is performed in both serial and parallel approach."""
+        if dh is not None:
+            for i, transformation in enumerate(dh._transformation):
+                for var in transformation.keys():
+                    if var not in transformation_dict[i].keys():
+                        transformation_dict[i][var] = {}
+                    opts = transformation[var]
+                    if not transformation_dict[i][var].get("method", opts["method"]) == opts["method"]:
+                        # data handlers with filters are allowed to change transformation method to standardise
+                        assert hasattr(dh, "filter_dim") and opts["method"] == "standardise"
+                    transformation_dict[i][var]["method"] = opts["method"]
+                    for k in ["mean", "std", "min", "max"]:
+                        old = transformation_dict[i][var].get(k, None)
+                        new = opts.get(k)
+                        transformation_dict[i][var][k] = new if old is None else old.combine_first(new)
+                    if "feature_range" in opts.keys():
+                        transformation_dict[i][var]["feature_range"] = opts.get("feature_range", None)
+        return transformation_dict
+
     def get_coordinates(self):
         return self.id_class.get_coordinates()
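The key move in this diff is promoting the former `_inner` closure to the classmethod `update_transformation_dict`, making the accumulator explicit: `transformation_dict` is passed in and returned instead of being captured, so the same merge code runs in both the serial and the multiprocessing branch. The core merge rule can be exercised in isolation; the sketch below is illustrative only, with made-up station IDs and statistic values, and assumes nothing beyond xarray's `DataArray.combine_first`, which keeps entries already collected and fills gaps from each newly processed station.

```python
# Standalone sketch of the per-statistic merge rule used in
# update_transformation_dict; station IDs/values here are invented.
import xarray as xr

def merge_stat(old, new):
    # mirrors: new if old is None else old.combine_first(new)
    return new if old is None else old.combine_first(new)

# per-station means as they might arrive from f_proc, indexed along the
# iteration dimension ("Stations")
stat_a = xr.DataArray([0.5], dims="Stations", coords={"Stations": ["DEBW107"]})
stat_b = xr.DataArray([0.7], dims="Stations", coords={"Stations": ["DEBW013"]})

merged = merge_stat(None, stat_a)    # first result initialises the entry
merged = merge_stat(merged, stat_b)  # later results only fill missing stations
print(merged.sortby("Stations"))     # one array holding both stations' values
```

Because `combine_first` gives precedence to the values already accumulated, the merge is order-insensitive for disjoint station sets, which is what makes the parallel branch (results arriving via `pool.apply_async`) equivalent to the serial loop.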