From 8a9c5ce5ed60cd077a6ad4d380c2f68f9edf20c2 Mon Sep 17 00:00:00 2001 From: Felix Kleinert <f.kleinert@fz-juelich.de> Date: Thu, 25 Mar 2021 15:41:54 +0100 Subject: [PATCH] add aggegation_dim to ensure that transformation parameters are calculated on multiple dimensions --- mlair/data_handler/data_handler_single_station.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py index e9db27a9..c5b3ae8d 100644 --- a/mlair/data_handler/data_handler_single_station.py +++ b/mlair/data_handler/data_handler_single_station.py @@ -18,7 +18,7 @@ import xarray as xr from mlair.configuration import check_path_and_create from mlair import helpers -from mlair.helpers import join, statistics, TimeTrackingWrapper +from mlair.helpers import join, statistics, TimeTrackingWrapper, to_list from mlair.data_handler.abstract_data_handler import AbstractDataHandler # define a more general date type for type hinting @@ -61,7 +61,7 @@ class DataHandlerSingleStation(AbstractDataHandler): interpolation_method: Union[str, Tuple[str]] = DEFAULT_INTERPOLATION_METHOD, overwrite_local_data: bool = False, transformation=None, store_data_locally: bool = True, min_length: int = 0, start=None, end=None, variables=None, data_origin: Dict = None, - lazy_preprocessing: bool = False, **kwargs): + lazy_preprocessing: bool = False, aggregation_dim=None, **kwargs): super().__init__() self.station = helpers.to_list(station) self.path = self.setup_data_path(data_path, sampling) @@ -83,6 +83,7 @@ class DataHandlerSingleStation(AbstractDataHandler): self.target_var = target_var self.time_dim = time_dim self.iter_dim = iter_dim + self.aggregation_dim = time_dim if aggregation_dim is None else set(to_list(aggregation_dim)+to_list(time_dim)) self.window_dim = window_dim self.window_history_size = window_history_size self.window_history_offset = window_history_offset @@ -154,10 +155,10 @@ class DataHandlerSingleStation(AbstractDataHandler): def call_transform(self, inverse=False): opts_input = self._transformation[0] - self.input_data, opts_input = self.transform(self.input_data, dim=self.time_dim, inverse=inverse, + self.input_data, opts_input = self.transform(self.input_data, dim=self.aggregation_dim, inverse=inverse, opts=opts_input, transformation_dim=self.target_dim) opts_target = self._transformation[1] - self.target_data, opts_target = self.transform(self.target_data, dim=self.time_dim, inverse=inverse, + self.target_data, opts_target = self.transform(self.target_data, dim=self.aggregation_dim, inverse=inverse, opts=opts_target, transformation_dim=self.target_dim) self._transformation = (opts_input, opts_target) -- GitLab