diff --git a/mlair/data_handler/data_handler_kz_filter.py b/mlair/data_handler/data_handler_kz_filter.py
index f2ff23be83c04ff8acbac116329288136ad979ed..a5a9de6701a03bedad95f54fc22fbd97ee041c86 100644
--- a/mlair/data_handler/data_handler_kz_filter.py
+++ b/mlair/data_handler/data_handler_kz_filter.py
@@ -32,7 +32,6 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
         self.kz_filter_iter = kz_filter_iter
         self.cutoff_period = None
         self.cutoff_period_days = None
-        self.data_target: xr.DataArray = None
         super().__init__(*args, **kwargs)
 
     def setup_samples(self):
@@ -41,26 +40,25 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
         """
         self.load_data()
         self.interpolate(dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit)
+        self.set_inputs_and_targets()
         import matplotlib
         matplotlib.use("TkAgg")
         import matplotlib.pyplot as plt
-        # self.original_data = self.data  # ToDo: implement here something to store unfiltered data
         self.apply_kz_filter()
         # self.data.sel(filter="74d", variables="temp", Stations="DEBW107").plot()
         # self.data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter")
-        if self.transformation is not None:
+        if self.do_transformation is True:
             self.call_transform()
-        self.make_samples()  # ToDo: target samples are still coming from filtered data
+        self.make_samples()
 
     @TimeTrackingWrapper
     def apply_kz_filter(self):
         """Apply kolmogorov zurbenko filter only on inputs."""
-        self.data_target = self.data.sel({self.target_dim: [self.target_var]})
-        kz = KZFilter(self.data, wl=self.kz_filter_length, itr=self.kz_filter_iter, filter_dim="datetime")
+        kz = KZFilter(self.input_data.data, wl=self.kz_filter_length, itr=self.kz_filter_iter, filter_dim="datetime")
         filtered_data: List[xr.DataArray] = kz.run()
         self.cutoff_period = kz.period_null()
         self.cutoff_period_days = kz.period_null_days()
-        self.data = xr.concat(filtered_data, pd.Index(self.create_filter_index(), name="filter"))
+        self.input_data.data = xr.concat(filtered_data, pd.Index(self.create_filter_index(), name="filter"))
 
     def create_filter_index(self) -> pd.Index:
         """
@@ -75,36 +73,6 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
         index = list(map(lambda x: str(x) + "d", index)) + ["res"]
         return pd.Index(index, name="filter")
 
-    def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str,
-                    window: int) -> None:
-        """
-        Create a xr.DataArray containing labels.
-
-        Labels are defined as the consecutive target values (t+1, ...t+n) following the current time step t. Set label
-        attribute.
-
-        :param dim_name_of_target: Name of dimension which contains the target variable
-        :param target_var: Name of target variable in 'dimension'
-        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
-        :param window: lead time of label
-        """
-        window = abs(window)
-        data = self.data_target.sel({dim_name_of_target: target_var})
-        self.label = self.shift(data, dim_name_of_shift, window)
-
-    def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None:
-        """
-        Create a xr.DataArray containing observations.
-
-        Observations are defined as value of the current time step t. Set observation attribute.
-
-        :param dim_name_of_target: Name of dimension which contains the observation variable
-        :param target_var: Name of observation variable(s) in 'dimension'
-        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
-        """
-        data = self.data_target.sel({dim_name_of_target: target_var})
-        self.observation = self.shift(data, dim_name_of_shift, 0)
-
     def get_transposed_history(self) -> xr.DataArray:
         """Return history.
 
diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py
index 60ee28f42acb2a5b29c64dfbec8a1b359f56bf77..65144acfa184578b938840afad862f77b728eadb 100644
--- a/mlair/data_handler/data_handler_single_station.py
+++ b/mlair/data_handler/data_handler_single_station.py
@@ -144,7 +144,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
         self.load_data()
         self.interpolate(dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit)
         self.set_inputs_and_targets()
-        if self.do_transformation:
+        if self.do_transformation is True:
             self.call_transform()
         self.make_samples()
 