diff --git a/mlair/data_handler/data_handler_kz_filter.py b/mlair/data_handler/data_handler_kz_filter.py index adc5ee0e72694baed6ec0ab0c0bf9259126af292..a4f71582ccb842ba45690fcf6db054be44f0bdbd 100644 --- a/mlair/data_handler/data_handler_kz_filter.py +++ b/mlair/data_handler/data_handler_kz_filter.py @@ -49,8 +49,8 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation): # import matplotlib # matplotlib.use("TkAgg") # import matplotlib.pyplot as plt - # self.input_data.data.sel(filter="74d", variables="temp", Stations="DEBW107").plot() - # self.input_data.data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter") + # self.input_data.sel(filter="74d", variables="temp", Stations="DEBW107").plot() + # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter") if self.do_transformation is True: self.call_transform() self.make_samples() @@ -58,11 +58,11 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation): @TimeTrackingWrapper def apply_kz_filter(self): """Apply kolmogorov zurbenko filter only on inputs.""" - kz = KZFilter(self.input_data.data, wl=self.kz_filter_length, itr=self.kz_filter_iter, filter_dim="datetime") + kz = KZFilter(self.input_data, wl=self.kz_filter_length, itr=self.kz_filter_iter, filter_dim="datetime") filtered_data: List[xr.DataArray] = kz.run() self.cutoff_period = kz.period_null() self.cutoff_period_days = kz.period_null_days() - self.input_data.data = xr.concat(filtered_data, pd.Index(self.create_filter_index(), name="filter")) + self.input_data = xr.concat(filtered_data, pd.Index(self.create_filter_index(), name="filter")) def create_filter_index(self) -> pd.Index: """ diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py index 832a643f2af7c6c2f0510fa1c2cf0353c516f67f..888554c1fd04cf6efbf22e3732fafc7e70760197 100644 --- a/mlair/data_handler/data_handler_single_station.py +++ b/mlair/data_handler/data_handler_single_station.py @@ -141,15 +141,14 @@ class DataHandlerSingleStation(AbstractDataHandler): def call_transform(self, inverse=False): opts_input = self._transformation[0] - self.input_data, opts_input = self.transform_new(self.input_data, dim=self.time_dim, inverse=inverse, - opts=opts_input) + self.input_data, opts_input = self.transform(self.input_data, dim=self.time_dim, inverse=inverse, + opts=opts_input) opts_target = self._transformation[1] - self.target_data, opts_target = self.transform_new(self.target_data, dim=self.time_dim, inverse=inverse, - opts=opts_target) + self.target_data, opts_target = self.transform(self.target_data, dim=self.time_dim, inverse=inverse, + opts=opts_target) self._transformation = (opts_input, opts_target) - def transform_new(self, data_in, dim: Union[str, int] = 0, - inverse: bool = False, opts=None): + def transform(self, data_in, dim: Union[str, int] = 0, inverse: bool = False, opts=None): """ Transform data according to given transformation settings. @@ -214,7 +213,7 @@ class DataHandlerSingleStation(AbstractDataHandler): transformed_values.append(values) return xr.concat(transformed_values, dim="variables"), opts_updated # ToDo: replace hardcoded variables dim else: - self.inverse_transform(data_in) # ToDo: add return statement + self.inverse_transform(data_in, opts) # ToDo: add return statement @TimeTrackingWrapper def setup_samples(self): @@ -574,74 +573,8 @@ class DataHandlerSingleStation(AbstractDataHandler): else: raise NotImplementedError("Cannot handle this.") - def transform(self, data_class, dim: Union[str, int] = 0, transform_method: str = 'standardise', - inverse: bool = False, mean=None, - std=None, min=None, max=None, opts=None) -> None: - """ - Transform data according to given transformation settings. - - This function transforms a xarray.dataarray (along dim) or pandas.DataFrame (along axis) either with mean=0 - and std=1 (`method=standardise`) or centers the data with mean=0 and no change in data scale - (`method=centre`). Furthermore, this sets an internal instance attribute for later inverse transformation. This - method will raise an AssertionError if an internal transform method was already set ('inverse=False') or if the - internal transform method, internal mean and internal standard deviation weren't set ('inverse=True'). - - :param string/int dim: This param is not used for inverse transformation. - | for xarray.DataArray as string: name of dimension which should be standardised - | for pandas.DataFrame as int: axis of dimension which should be standardised - :param method: Choose the transformation method from 'standardise' and 'centre'. 'normalise' is not implemented - yet. This param is not used for inverse transformation. - :param inverse: Switch between transformation and inverse transformation. - :param mean: Used for transformation (if required by 'method') based on external data. If 'None' the mean is - calculated over the data in this class instance. - :param std: Used for transformation (if required by 'method') based on external data. If 'None' the std is - calculated over the data in this class instance. - :param min: Used for transformation (if required by 'method') based on external data. If 'None' min_val is - extracted from the data in this class instance. - :param max: Used for transformation (if required by 'method') based on external data. If 'None' max_val is - extracted from the data in this class instance. - - :return: xarray.DataArrays or pandas.DataFrames: - #. mean: Mean of data - #. std: Standard deviation of data - #. data: Standardised data - """ - - def f(data): - if transform_method == 'standardise': - return statistics.standardise(data, dim) - elif transform_method == 'centre': - return statistics.centre(data, dim) - elif transform_method == 'normalise': - # use min/max of data or given min/max - raise NotImplementedError - else: - raise NotImplementedError - - def f_apply(data): - if transform_method == "standardise": - return mean, std, statistics.standardise_apply(data, mean, std) - elif transform_method == "centre": - return mean, None, statistics.centre_apply(data, mean) - else: - raise NotImplementedError - - if not inverse: - if data_class._method is not None: - raise AssertionError(f"Internal _method is already set. Therefore, data was already transformed with " - f"{data_class._method}. Please perform inverse transformation of data first.") - # apply transformation on local data instance (f) if mean is None, else apply by using mean (and std) from - # external data. - data_class.mean, data_class.std, data_class.data = locals()["f" if mean is None else "f_apply"]( - data_class.data) - - # set transform method to find correct method for inverse transformation. - data_class._method = transform_method - else: - self.inverse_transform(data_class) - @staticmethod - def check_inverse_transform_params(mean: data_or_none, std: data_or_none, method: str) -> None: + def check_inverse_transform_params(method: str, mean: data_or_none, std: data_or_none) -> None: """ Support inverse_transformation method. @@ -660,7 +593,7 @@ class DataHandlerSingleStation(AbstractDataHandler): if len(msg) > 0: raise AttributeError(f"Inverse transform {method} can not be executed because following is None: {msg}") - def inverse_transform(self, data_class) -> None: + def inverse_transform(self, data_in, opts) -> xr.DataArray: """ Perform inverse transformation. @@ -670,24 +603,30 @@ class DataHandlerSingleStation(AbstractDataHandler): current data is not transformed. """ - def f_inverse(data, mean, std, method_inverse): + def f_inverse(data, method_inverse, mean, std): if method_inverse == 'standardise': - return statistics.standardise_inverse(data, mean, std), None, None + return statistics.standardise_inverse(data, mean, std) elif method_inverse == 'centre': - return statistics.centre_inverse(data, mean), None, None + return statistics.centre_inverse(data, mean) elif method_inverse == 'normalise': raise NotImplementedError else: raise NotImplementedError - if data_class.transform_method is None: - raise AssertionError("Inverse transformation method is not set. Data cannot be inverse transformed.") - self.check_inverse_transform_params(data_class.mean, data_class.std, data_class._method) - data_class.data, data_class.mean, data_class.std = f_inverse(data_class.data, data_class.mean, data_class.std, - data_class._method) - data_class.transform_method = None - # update X and Y - self.make_samples() + transformed_values = [] + for var in data_in.variables.values: + data_var = data_in.sel(variables=[var]) # ToDo: replace hardcoded variables dim + var_opts = opts.get(var, {}) + _method = var_opts.get("method", None) + if _method is None: + raise AssertionError("Inverse transformation method is not set. Data cannot be inverse transformed.") + _mean = var_opts.get("mean", None) + _std = var_opts.get("std", None) + self.check_inverse_transform_params(_method, _mean, _std) + + values = f_inverse(data_var, _method, _mean, _std) + transformed_values.append(values) + return xr.concat(transformed_values, dim="variables") # ToDo: replace hardcoded variables dim def get_transformation_targets(self) -> Dict: """ @@ -701,6 +640,10 @@ class DataHandlerSingleStation(AbstractDataHandler): """ return copy.deepcopy(self._transformation[1]) + def apply_transformation(self, data, transformation_opts, dim=0, inverse=False): + + return self.transform(data, dim=dim, opts=transformation_opts, inverse=inverse) + if __name__ == "__main__": # dp = AbstractDataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py index b1226be642a52ff630eae3bcaf50309f5164db1e..4b7ec3282d6214179921e8e9f763c63d3b403f71 100644 --- a/mlair/data_handler/default_data_handler.py +++ b/mlair/data_handler/default_data_handler.py @@ -148,6 +148,9 @@ class DefaultDataHandler(AbstractDataHandler): def get_transformation_Y(self): return self.id_class.get_transformation_targets() + def apply_transformation(self, data, transformation_opts, dim=0, inverse=False): + return self.id_class.transform(data, dim=dim, opts=transformation_opts, inverse=inverse) + def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False, timedelta: Tuple[int, str] = (1, 'm'), dim="datetime"): """