From 8a9c5ce5ed60cd077a6ad4d380c2f68f9edf20c2 Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Thu, 25 Mar 2021 15:41:54 +0100
Subject: [PATCH] add aggegation_dim to ensure that transformation parameters
 are calculated on multiple dimensions

---
 mlair/data_handler/data_handler_single_station.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py
index e9db27a9..c5b3ae8d 100644
--- a/mlair/data_handler/data_handler_single_station.py
+++ b/mlair/data_handler/data_handler_single_station.py
@@ -18,7 +18,7 @@ import xarray as xr
 
 from mlair.configuration import check_path_and_create
 from mlair import helpers
-from mlair.helpers import join, statistics, TimeTrackingWrapper
+from mlair.helpers import join, statistics, TimeTrackingWrapper, to_list
 from mlair.data_handler.abstract_data_handler import AbstractDataHandler
 
 # define a more general date type for type hinting
@@ -61,7 +61,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
                  interpolation_method: Union[str, Tuple[str]] = DEFAULT_INTERPOLATION_METHOD,
                  overwrite_local_data: bool = False, transformation=None, store_data_locally: bool = True,
                  min_length: int = 0, start=None, end=None, variables=None, data_origin: Dict = None,
-                 lazy_preprocessing: bool = False, **kwargs):
+                 lazy_preprocessing: bool = False, aggregation_dim=None, **kwargs):
         super().__init__()
         self.station = helpers.to_list(station)
         self.path = self.setup_data_path(data_path, sampling)
@@ -83,6 +83,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
         self.target_var = target_var
         self.time_dim = time_dim
         self.iter_dim = iter_dim
+        self.aggregation_dim = time_dim if aggregation_dim is None else set(to_list(aggregation_dim)+to_list(time_dim))
         self.window_dim = window_dim
         self.window_history_size = window_history_size
         self.window_history_offset = window_history_offset
@@ -154,10 +155,10 @@ class DataHandlerSingleStation(AbstractDataHandler):
 
     def call_transform(self, inverse=False):
         opts_input = self._transformation[0]
-        self.input_data, opts_input = self.transform(self.input_data, dim=self.time_dim, inverse=inverse,
+        self.input_data, opts_input = self.transform(self.input_data, dim=self.aggregation_dim, inverse=inverse,
                                                      opts=opts_input, transformation_dim=self.target_dim)
         opts_target = self._transformation[1]
-        self.target_data, opts_target = self.transform(self.target_data, dim=self.time_dim, inverse=inverse,
+        self.target_data, opts_target = self.transform(self.target_data, dim=self.aggregation_dim, inverse=inverse,
                                                        opts=opts_target, transformation_dim=self.target_dim)
         self._transformation = (opts_input, opts_target)
 
-- 
GitLab