diff --git a/mlair/data_handler/data_handler_kz_filter.py b/mlair/data_handler/data_handler_kz_filter.py
index 1f2c63e58c7eaab645f074ac953d2f05d8ba09fd..539712b39e51c32203e1c55e28ce2eff24069479 100644
--- a/mlair/data_handler/data_handler_kz_filter.py
+++ b/mlair/data_handler/data_handler_kz_filter.py
@@ -7,7 +7,7 @@ import inspect
 import numpy as np
 import pandas as pd
 import xarray as xr
-from typing import List, Union
+from typing import List, Union, Tuple, Optional
 from functools import partial
 
 from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
@@ -37,6 +37,19 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
         self.cutoff_period_days = None
         super().__init__(*args, **kwargs)
 
+    def setup_transformation(self, transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]:
+        """
+        Adjust setup of transformation because kz-filtered data will have negative values, which are not compatible
+        with the log transformation. Therefore, replace all log transformation methods by a default standardisation.
+        This is only applied on the input side.
+        """
+        transformation = super(__class__, self).setup_transformation(transformation)
+        if transformation[0] is not None:
+            for k, v in transformation[0].items():
+                if v["method"] == "log":
+                    transformation[0][k]["method"] = "standardise"
+        return transformation
+
     def _check_sampling(self, **kwargs):
         assert kwargs.get("sampling") == "hourly"  # This data handler requires hourly data resolution
 
diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py
index 86e6f856b7bf061287261ae711063d71ed7c8963..75e9e64506231f32406934b67e65454d87a43f61 100644
--- a/mlair/data_handler/data_handler_mixed_sampling.py
+++ b/mlair/data_handler/data_handler_mixed_sampling.py
@@ -158,7 +158,7 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
     def _extract_lazy(self, lazy_data):
         _data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data
         start_inp, end_inp = self.update_start_end(0)
-        self._data = list(map(self._slice_prep, _data, [start_inp, self.start], [end_inp, self.end]))
+        self._data = list(map(lambda x: self._slice_prep(_data[x], *self.update_start_end(x)), [0, 1]))
         self.input_data = self._slice_prep(_input_data, start_inp, end_inp)
         self.target_data = self._slice_prep(_target_data, self.start, self.end)
 
diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py
index 0497bee0ae6b6a72301181ef5453dd40f479e5af..0c83e625fd0e1aa4aafdcf204ed00b813868f4f2 100644
--- a/mlair/data_handler/data_handler_single_station.py
+++ b/mlair/data_handler/data_handler_single_station.py
@@ -252,7 +252,9 @@ class DataHandlerSingleStation(AbstractDataHandler):
             with open(filename, "rb") as pickle_file:
                 lazy_data = dill.load(pickle_file)
             self._extract_lazy(lazy_data)
+            logging.debug(f"{self.station[0]}: used lazy data")
         except FileNotFoundError:
+            logging.debug(f"{self.station[0]}: could not use lazy data")
             self.make_input_target()
 
     def _extract_lazy(self, lazy_data):
@@ -594,8 +596,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
         """
         return data.loc[{coord: slice(str(start), str(end))}]
 
-    @staticmethod
-    def setup_transformation(transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]:
+    def setup_transformation(self, transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]:
         """
         Set up transformation by extracting all relevant information.
 
diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 07a866aec1efd43de42f918844abeb7c3bbc9524..5eb6fd026e4dead07ab1a3115640a0d853708313 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -273,7 +273,9 @@ class DefaultDataHandler(AbstractDataHandler):
                 if var not in transformation_dict[i].keys():
                     transformation_dict[i][var] = {}
                 opts = transformation[var]
-                assert transformation_dict[i][var].get("method", opts["method"]) == opts["method"]
+                if not transformation_dict[i][var].get("method", opts["method"]) == opts["method"]:
+                    # data handlers with filters are allowed to change transformation method to standardise
+                    assert hasattr(dh, "filter_dim") and opts["method"] == "standardise"
                 transformation_dict[i][var]["method"] = opts["method"]
                 for k in ["mean", "std", "min", "max"]:
                     old = transformation_dict[i][var].get(k, None)
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 3b9b563426a80816f7cf1ea9e114a8395d9fbba0..73aebb008ebf1f61eb2878293fc160cf549d19cb 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -306,7 +306,7 @@ class PostProcessing(RunEnvironment):
         try:
             if ("filter" in self.test_data[0].get_X(as_numpy=False)[0].coords) and (
                     "PlotSeparationOfScales" in plot_list):
-                filter_dim = self.data_store.get("filter_dim", None)
+                filter_dim = self.data_store.get_default("filter_dim", None)
                 PlotSeparationOfScales(self.test_data, plot_folder=self.plot_path, time_dim=time_dim,
                                        window_dim=window_dim, target_dim=target_dim, **{"filter_dim": filter_dim})
         except Exception as e:
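
For illustration, here is a minimal standalone sketch of the override idea from `DataHandlerKzFilterSingleStation.setup_transformation` above: on the input side of a `(input_transformation, target_transformation)` tuple of per-variable option dicts, every `"log"` method is swapped for `"standardise"` because filtered data can contain negative values. This is not MLAir code; the function name and example variable names are hypothetical.

```python
from typing import Optional, Tuple


def replace_log_with_standardise(
        transformation: Tuple[Optional[dict], Optional[dict]]
) -> Tuple[Optional[dict], Optional[dict]]:
    """Hypothetical helper mirroring the override in the diff above.

    Only the input-side dict (index 0) is touched; "log" is replaced by
    "standardise" since kz-filtered values can be negative.
    """
    inputs, targets = transformation
    if inputs is not None:
        for var, opts in inputs.items():
            if opts.get("method") == "log":
                opts["method"] = "standardise"
    return inputs, targets


# usage with hypothetical variable names
trafo = ({"o3": {"method": "log"}, "temp": {"method": "standardise"}},
         {"o3": {"method": "standardise"}})
print(replace_log_with_standardise(trafo)[0]["o3"]["method"])  # -> standardise
```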