diff --git a/mlair/data_handler/data_handler_with_filter.py b/mlair/data_handler/data_handler_with_filter.py index 6a37a44713a963f5cc914aa1754c0cde24071be5..0757e528169268860179613bce2645a505c4352f 100644 --- a/mlair/data_handler/data_handler_with_filter.py +++ b/mlair/data_handler/data_handler_with_filter.py @@ -35,40 +35,24 @@ str_or_list = Union[str, List[str]] # y, h = fir_filter(station_data.values.flatten(), fs, order[0], cutoff_low = cutoff[0][0], cutoff_high = cutoff[0][1], window=window) # filtered = xr.ones_like(station_data) * y.reshape(station_data.values.shape) -class DataHandlerFirFilterSingleStation(DataHandlerSingleStation): - """Data handler for a single station to be used by a superior data handler. Inputs are FIR filtered.""" - _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"]) - _hash = DataHandlerSingleStation._hash + ["filter_cutoff_period", "filter_order", "filter_window_type", - "filter_dim", "filter_add_unfiltered"] +class DataHandlerFilterSingleStation(DataHandlerSingleStation): + """General data handler for a single station to be used by a superior data handler.""" + + # _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"]) + _requirements = DataHandlerSingleStation.requirements() + _hash = DataHandlerSingleStation._hash + ["filter_dim"] DEFAULT_FILTER_DIM = "filter" - DEFAULT_WINDOW_TYPE = ("kaiser", 5) - DEFAULT_ADD_UNFILTERED = False - def __init__(self, *args, filter_cutoff_period, filter_order, filter_window_type=DEFAULT_WINDOW_TYPE, - filter_dim=DEFAULT_FILTER_DIM, filter_add_unfiltered=DEFAULT_ADD_UNFILTERED, **kwargs): - # self._check_sampling(**kwargs) + def __init__(self, *args, filter_dim=DEFAULT_FILTER_DIM, **kwargs): # self.original_data = None # ToDo: implement here something to store unfiltered data - self.filter_cutoff_period = (lambda x: [x] if isinstance(x, tuple) else to_list(x))(filter_cutoff_period) - 
self.filter_cutoff_freq = self._period_to_freq(self.filter_cutoff_period) - assert len(self.filter_cutoff_period) == len(filter_order) - self.filter_order = filter_order - self.filter_window_type = filter_window_type self.filter_dim = filter_dim - self._add_unfiltered = filter_add_unfiltered - self.fs = self._get_fs(**kwargs) - super().__init__(*args, **kwargs) - @staticmethod - def _period_to_freq(cutoff_p): - return list(map(lambda x: (1. / x[0] if x[0] is not None else None, 1. / x[1] if x[1] is not None else None), - cutoff_p)) - def setup_transformation(self, transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]: """ - Adjust setup of transformation because kfz filtered data will have negative values which is not compatible with + Adjust setup of transformation because filtered data will have negative values which is not compatible with the log transformation. Therefore, replace all log transformation methods by a default standardization. This is only applied on input side. """ @@ -79,17 +63,6 @@ class DataHandlerFirFilterSingleStation(DataHandlerSingleStation): transformation[0][k]["method"] = "standardise" return transformation - @staticmethod - def _get_fs(**kwargs): - """Return frequency in 1/day (not Hz)""" - sampling = kwargs.get("sampling") - if sampling == "daily": - return 1 - elif sampling == "hourly": - return 24 - else: - raise ValueError(f"Unknown sampling rate {sampling}. Only daily and hourly resolution is supported.") - def _check_sampling(self, **kwargs): assert kwargs.get("sampling") == "hourly" # This data handler requires hourly data resolution, does it? 
@@ -99,7 +72,7 @@ class DataHandlerFirFilterSingleStation(DataHandlerSingleStation): self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit) self.set_inputs_and_targets() - self.apply_fir_filter() + self.apply_filter() # this is just a code snippet to check the results of the kz filter # import matplotlib # matplotlib.use("TkAgg") @@ -107,8 +80,72 @@ class DataHandlerFirFilterSingleStation(DataHandlerSingleStation): # self.input_data.sel(filter="low", variables="temp", Stations="DEBW107").plot() # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter") + def apply_filter(self): + raise NotImplementedError + + def create_filter_index(self) -> pd.Index: + """Create name for filter dimension.""" + raise NotImplementedError + + def get_transposed_history(self) -> xr.DataArray: + """Return history. + + :return: history with dimensions datetime, window, Stations, variables, filter. + """ + return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim, + self.filter_dim).copy() + + def _create_lazy_data(self): + return [self._data, self.meta, self.input_data, self.target_data, self.cutoff_period, self.cutoff_period_days] + + def _extract_lazy(self, lazy_data): + _data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data + f_prep = partial(self._slice_prep, start=self.start, end=self.end) + self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data])) + + +class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): + """Data handler for a single station to be used by a superior data handler. 
Inputs are FIR filtered.""" + + # _requirements = remove_items(inspect.getfullargspec(DataHandlerFilterSingleStation).args, ["self", "station"]) + _requirements = DataHandlerFilterSingleStation.requirements() + _hash = DataHandlerFilterSingleStation._hash + ["filter_cutoff_period", "filter_order", "filter_window_type", + "filter_add_unfiltered"] + + DEFAULT_WINDOW_TYPE = ("kaiser", 5) + DEFAULT_ADD_UNFILTERED = False + + def __init__(self, *args, filter_cutoff_period, filter_order, filter_window_type=DEFAULT_WINDOW_TYPE, + filter_add_unfiltered=DEFAULT_ADD_UNFILTERED, **kwargs): + # self._check_sampling(**kwargs) + # self.original_data = None # ToDo: implement here something to store unfiltered data + self.filter_cutoff_period = (lambda x: [x] if isinstance(x, tuple) else to_list(x))(filter_cutoff_period) + self.filter_cutoff_freq = self._period_to_freq(self.filter_cutoff_period) + assert len(self.filter_cutoff_period) == len(filter_order) + self.filter_order = filter_order + self.filter_window_type = filter_window_type + self._add_unfiltered = filter_add_unfiltered + self.fs = self._get_fs(**kwargs) + super().__init__(*args, **kwargs) + + @staticmethod + def _period_to_freq(cutoff_p): + return list(map(lambda x: (1. / x[0] if x[0] is not None else None, 1. / x[1] if x[1] is not None else None), + cutoff_p)) + + @staticmethod + def _get_fs(**kwargs): + """Return frequency in 1/day (not Hz)""" + sampling = kwargs.get("sampling") + if sampling == "daily": + return 1 + elif sampling == "hourly": + return 24 + else: + raise ValueError(f"Unknown sampling rate {sampling}. 
Only daily and hourly resolution is supported.") + @TimeTrackingWrapper - def apply_fir_filter(self): + def apply_filter(self): """Apply FIR filter only on inputs.""" fir = FIRFilter(self.input_data, self.fs, self.filter_order, self.filter_cutoff_freq, self.filter_window_type, self.target_dim) @@ -117,6 +154,12 @@ class DataHandlerFirFilterSingleStation(DataHandlerSingleStation): if self._add_unfiltered is True: fir_data.append(self.input_data) self.input_data = xr.concat(fir_data, pd.Index(self.create_filter_index(), name=self.filter_dim)) + # this is just a code snippet to check the results of the FIR filter + # import matplotlib + # matplotlib.use("TkAgg") + # import matplotlib.pyplot as plt + # self.input_data.sel(filter="low", variables="temp", Stations="DEBW107").plot() + # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter") def create_filter_index(self) -> pd.Index: """ @@ -138,22 +181,6 @@ class DataHandlerFirFilterSingleStation(DataHandlerSingleStation): index.append("unfiltered") return pd.Index(index, name=self.filter_dim) - def get_transposed_history(self) -> xr.DataArray: - """Return history. - - :return: history with dimensions datetime, window, Stations, variables, filter.
- """ - return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim, - self.filter_dim).copy() - - def _create_lazy_data(self): - return [self._data, self.meta, self.input_data, self.target_data, self.cutoff_period, self.cutoff_period_days] - - def _extract_lazy(self, lazy_data): - _data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data - f_prep = partial(self._slice_prep, start=self.start, end=self.end) - self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data])) - class DataHandlerFirFilter(DefaultDataHandler): """Data handler using FIR filtered data.""" @@ -163,62 +190,35 @@ class DataHandlerFirFilter(DefaultDataHandler): _requirements = data_handler.requirements() -class DataHandlerKzFilterSingleStation(DataHandlerSingleStation): +class DataHandlerKzFilterSingleStation(DataHandlerFilterSingleStation): """Data handler for a single station to be used by a superior data handler. 
Inputs are kz filtered.""" - _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"]) - _hash = DataHandlerSingleStation._hash + ["kz_filter_length", "kz_filter_iter", "filter_dim"] + _requirements = remove_items(inspect.getfullargspec(DataHandlerFilterSingleStation).args, ["self", "station"]) + _hash = DataHandlerFilterSingleStation._hash + ["kz_filter_length", "kz_filter_iter"] - DEFAULT_FILTER_DIM = "filter" - - def __init__(self, *args, kz_filter_length, kz_filter_iter, filter_dim=DEFAULT_FILTER_DIM, **kwargs): + def __init__(self, *args, kz_filter_length, kz_filter_iter, **kwargs): self._check_sampling(**kwargs) # self.original_data = None # ToDo: implement here something to store unfiltered data self.kz_filter_length = to_list(kz_filter_length) self.kz_filter_iter = to_list(kz_filter_iter) - self.filter_dim = filter_dim self.cutoff_period = None self.cutoff_period_days = None super().__init__(*args, **kwargs) - def setup_transformation(self, transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]: - """ - Adjust setup of transformation because kfz filtered data will have negative values which is not compatible with - the log transformation. Therefore, replace all log transformation methods by a default standardization. This is - only applied on input side. 
- """ - transformation = super(__class__, self).setup_transformation(transformation) - if transformation[0] is not None: - for k, v in transformation[0].items(): - if v["method"] == "log": - transformation[0][k]["method"] = "standardise" - return transformation - - def _check_sampling(self, **kwargs): - assert kwargs.get("sampling") == "hourly" # This data handler requires hourly data resolution - - def make_input_target(self): - data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling, - self.station_type, self.network, self.store_data_locally, self.data_origin) - self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, - limit=self.interpolation_limit) - self.set_inputs_and_targets() - self.apply_kz_filter() - # this is just a code snippet to check the results of the kz filter - # import matplotlib - # matplotlib.use("TkAgg") - # import matplotlib.pyplot as plt - # self.input_data.sel(filter="74d", variables="temp", Stations="DEBW107").plot() - # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter") - @TimeTrackingWrapper - def apply_kz_filter(self): + def apply_filter(self): """Apply kolmogorov zurbenko filter only on inputs.""" kz = KZFilter(self.input_data, wl=self.kz_filter_length, itr=self.kz_filter_iter, filter_dim=self.time_dim) filtered_data: List[xr.DataArray] = kz.run() self.cutoff_period = kz.period_null() self.cutoff_period_days = kz.period_null_days() self.input_data = xr.concat(filtered_data, pd.Index(self.create_filter_index(), name=self.filter_dim)) + # this is just a code snippet to check the results of the kz filter + # import matplotlib + # matplotlib.use("TkAgg") + # import matplotlib.pyplot as plt + # self.input_data.sel(filter="74d", variables="temp", Stations="DEBW107").plot() + # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter") def create_filter_index(self) -> pd.Index: """ @@ -233,22 +233,6 @@ class 
DataHandlerKzFilterSingleStation(DataHandlerSingleStation): index = list(map(lambda x: str(x) + "d", index)) + ["res"] return pd.Index(index, name=self.filter_dim) - def get_transposed_history(self) -> xr.DataArray: - """Return history. - - :return: history with dimensions datetime, window, Stations, variables, filter. - """ - return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim, - self.filter_dim).copy() - - def _create_lazy_data(self): - return [self._data, self.meta, self.input_data, self.target_data, self.cutoff_period, self.cutoff_period_days] - - def _extract_lazy(self, lazy_data): - _data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data - f_prep = partial(self._slice_prep, start=self.start, end=self.end) - self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data])) - class DataHandlerKzFilter(DefaultDataHandler): """Data handler using kz filtered data."""