diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py index 8205ae6c28f3683b1052c292e5d063d8bca555dc..ae2e6a1a303076c4da1e7b00ae6653336a633364 100644 --- a/mlair/data_handler/data_handler_mixed_sampling.py +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -12,6 +12,7 @@ from mlair.helpers import remove_items from mlair.configuration.defaults import DEFAULT_SAMPLING, DEFAULT_INTERPOLATION_LIMIT, DEFAULT_INTERPOLATION_METHOD from mlair.helpers.filter import filter_width_kzf +import copy import inspect from typing import Callable import datetime as dt @@ -140,13 +141,6 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]: start, end = self.update_start_end(ind) - # if ind == 0: # for inputs - # estimated_filter_width = self.estimate_filter_width() - # start = self._add_time_delta(self.start, -estimated_filter_width) - # end = self._add_time_delta(self.end, estimated_filter_width) - # else: # target - # start, end = self.start, self.end - vars = [self.variables, self.target_var] stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars[ind]) @@ -264,8 +258,83 @@ class DataHandlerMixedSamplingWithClimateFirFilter(DataHandlerClimateFirFilter): data_handler = DataHandlerMixedSamplingWithClimateFirFilterSingleStation data_handler_transformation = DataHandlerMixedSamplingWithClimateFirFilterSingleStation - _requirements = data_handler.requirements() + data_handler_unfiltered = DataHandlerMixedSamplingSingleStation + _requirements = list(set(data_handler.requirements() + data_handler_unfiltered.requirements())) + DEFAULT_FILTER_ADD_UNFILTERED = False + + def __init__(self, *args, data_handler_class_unfiltered: data_handler_unfiltered = None, + filter_add_unfiltered: bool = DEFAULT_FILTER_ADD_UNFILTERED, **kwargs): + self.dh_unfiltered = data_handler_class_unfiltered + self.filter_add_unfiltered = filter_add_unfiltered + super().__init__(*args, **kwargs) + @classmethod + def own_args(cls, *args): + """Return all arguments (including kwonlyargs).""" + super_own_args = DataHandlerClimateFirFilter.own_args(*args) + arg_spec = inspect.getfullargspec(cls) + list_of_args = arg_spec.args + arg_spec.kwonlyargs + super_own_args + return remove_items(list_of_args, ["self"] + list(args)) + + def _create_collection(self): + if self.filter_add_unfiltered is True and self.dh_unfiltered is not None: + return [self.id_class, self.dh_unfiltered] + else: + return super()._create_collection() + + @classmethod + def build(cls, station: str, **kwargs): + sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler.requirements() if k in kwargs} + filter_add_unfiltered = kwargs.get("filter_add_unfiltered", False) + sp_keys = cls.build_update_kwargs(sp_keys, dh_type="filtered") + sp = cls.data_handler(station, **sp_keys) + if filter_add_unfiltered is True: + sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_unfiltered.requirements() if k in kwargs} + sp_keys = cls.build_update_kwargs(sp_keys, dh_type="unfiltered") + sp_unfiltered = cls.data_handler_unfiltered(station, **sp_keys) + else: + sp_unfiltered = None + dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs} + return cls(sp, data_handler_class_unfiltered=sp_unfiltered, **dp_args) + + @classmethod + def build_update_kwargs(cls, kwargs_dict, dh_type="filtered"): + if "transformation" in kwargs_dict: + trafo_opts = kwargs_dict.get("transformation") + if isinstance(trafo_opts, dict): + kwargs_dict["transformation"] = trafo_opts.get(dh_type) + return kwargs_dict + + @classmethod + def transformation(cls, set_stations, tmp_path=None, **kwargs): + + sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs} + if "transformation" not in sp_keys.keys(): + return + + transformation_filtered = super().transformation(set_stations, tmp_path=tmp_path, + dh_transformation=cls.data_handler_transformation, **kwargs) + if kwargs.get("filter_add_unfiltered", False) is False: + return transformation_filtered + else: + transformation_unfiltered = super().transformation(set_stations, tmp_path=tmp_path, + dh_transformation=cls.data_handler_unfiltered, **kwargs) + return {"filtered": transformation_filtered, "unfiltered": transformation_unfiltered} + + def get_X_original(self): + if self.use_filter_branches is True: + X = [] + for data in self._collection: + if hasattr(data, "filter_dim"): + X_total = data.get_X() + filter_dim = data.filter_dim + for filter_name in data.filter_dim_order: + X.append(X_total.sel({filter_dim: filter_name}, drop=True)) + else: + X.append(data.get_X()) + return X + else: + return super().get_X_original() class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithKzFilterSingleStation): """ diff --git a/mlair/data_handler/data_handler_with_filter.py b/mlair/data_handler/data_handler_with_filter.py index 53cfa4cad4ade0f5ed988a8598fb8f4fe70a1779..4707fd580562a68fd6b2dc0843551905e70d7e50 100644 --- a/mlair/data_handler/data_handler_with_filter.py +++ b/mlair/data_handler/data_handler_with_filter.py @@ -127,32 +127,16 @@ class DataHandlerFilter(DefaultDataHandler): list_of_args = arg_spec.args + arg_spec.kwonlyargs + super_own_args return remove_items(list_of_args, ["self"] + list(args)) - def get_X_original(self): - if self.use_filter_branches is True: - X = [] - for data in self._collection: - X_total = data.get_X() - filter_dim = data.filter_dim - for filter_name in data.filter_dim_order: - X.append(X_total.sel({filter_dim: filter_name}, drop=True)) - return X - else: - return super().get_X_original() - class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): """Data handler for a single station to be used by a superior data handler. Inputs are FIR filtered.""" _requirements = remove_items(DataHandlerFilterSingleStation.requirements(), "station") - _hash = DataHandlerFilterSingleStation._hash + ["filter_cutoff_period", "filter_order", "filter_window_type", - "_add_unfiltered"] + _hash = DataHandlerFilterSingleStation._hash + ["filter_cutoff_period", "filter_order", "filter_window_type"] DEFAULT_WINDOW_TYPE = ("kaiser", 5) - DEFAULT_ADD_UNFILTERED = False - def __init__(self, *args, filter_cutoff_period, filter_order, filter_window_type=DEFAULT_WINDOW_TYPE, - filter_add_unfiltered=DEFAULT_ADD_UNFILTERED, **kwargs): - # self._check_sampling(**kwargs) + def __init__(self, *args, filter_cutoff_period, filter_order, filter_window_type=DEFAULT_WINDOW_TYPE, **kwargs): # self.original_data = None # ToDo: implement here something to store unfiltered data self.fs = self._get_fs(**kwargs) if filter_window_type == "kzf": @@ -162,7 +146,6 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): assert len(self.filter_cutoff_period) == (len(filter_order) - len(removed_index)) self.filter_order = self._prepare_filter_order(filter_order, removed_index, self.fs) self.filter_window_type = filter_window_type - self._add_unfiltered = filter_add_unfiltered self.unfiltered_name = "unfiltered" super().__init__(*args, **kwargs) @@ -225,8 +208,6 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): self.filter_window_type, self.target_dim) self.fir_coeff = fir.filter_coefficients() fir_data = fir.filtered_data() - if self._add_unfiltered is True: - fir_data.append(self.input_data) self.input_data = xr.concat(fir_data, pd.Index(self.create_filter_index(), name=self.filter_dim)) # this is just a code snippet to check the results of the kz filter # import matplotlib @@ -251,8 +232,6 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): else: index.append(f"band{band_num}") band_num += 1 - if self._add_unfiltered: - index.append(self.unfiltered_name) self.filter_dim_order = index return pd.Index(index, name=self.filter_dim) @@ -263,20 +242,6 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): _data, _meta, _input_data, _target_data, self.fir_coeff, self.filter_dim_order = lazy_data super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data)) - def setup_transformation(self, transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]: - """ - Adjust setup of transformation because filtered data will have negative values which is not compatible with - the log transformation. Therefore, replace all log transformation methods by a default standardization. This is - only applied on input side. - """ - transformation = DataHandlerSingleStation.setup_transformation(self, transformation) - if transformation[0] is not None: - unfiltered_option = lambda x: f"{x}/standardise" if self._add_unfiltered is True else "standardise" - for k, v in transformation[0].items(): - if v["method"] in ["log", "min_max"]: - transformation[0][k]["method"] = unfiltered_option(v["method"]) - return transformation - def transform(self, data_in, dim: Union[str, int] = 0, inverse: bool = False, opts=None, transformation_dim=None): """ @@ -360,7 +325,6 @@ class DataHandlerFirFilter(DataHandlerFilter): data_handler = DataHandlerFirFilterSingleStation data_handler_transformation = DataHandlerFirFilterSingleStation - _requirements = data_handler.requirements() class DataHandlerKzFilterSingleStation(DataHandlerFilterSingleStation): @@ -489,13 +453,6 @@ class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation input_data = xr.concat(climate_filter_data, pd.Index(self.create_filter_index(add_unfiltered_index=False), name=self.filter_dim)) - # add unfiltered raw data - if self._add_unfiltered is True: - data_raw = self.shift(self.input_data, self.time_dim, -self.window_history_size) - filter_dim = self.create_filter_index(add_unfiltered_index=True)[-1] - data_raw = data_raw.expand_dims({self.filter_dim: [filter_dim]}, -1) - input_data = xr.concat([input_data, data_raw], self.filter_dim) - self.input_data = input_data # this is just a code snippet to check the results of the filter @@ -516,8 +473,6 @@ class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation f = lambda x: int(np.round(x)) if x >= 10 else np.round(x, 1) index = list(map(f, index.tolist())) index = list(map(lambda x: str(x) + "d", index)) + ["res"] - if self._add_unfiltered and add_unfiltered_index: - index.append(self.unfiltered_name) self.filter_dim_order = index return pd.Index(index, name=self.filter_dim) @@ -586,11 +541,3 @@ class DataHandlerClimateFirFilter(DataHandlerFilter): _requirements = data_handler.requirements() _store_attributes = data_handler.store_attributes() - # def get_X_original(self): - # X = [] - # for data in self._collection: - # X_total = data.get_X() - # filter_dim = data.filter_dim - # for filter in data.filter_dim_order: - # X.append(X_total.sel({filter_dim: filter}, drop=True)) - # return X