From 2a3b496ac9bcbb4c2f60ca5f66f76858664db802 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Thu, 10 Dec 2020 11:09:40 +0100
Subject: [PATCH] sampling_inputs is replaced by the possibility to use a tuple
 for parameter sampling for (inputs, targets). This applies also to
 interpolation_{method,limit}. Data handlers using mixed sampling types can
 handle this tuples, for other data handlers it is still required to use a
 single value.

---
 .../data_handler_mixed_sampling.py            | 43 ++++++++++++++++---
 mlair/run_modules/experiment_setup.py         |  5 +--
 2 files changed, 38 insertions(+), 10 deletions(-)

diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py
index c16dfb43..6d1e66bc 100644
--- a/mlair/data_handler/data_handler_mixed_sampling.py
+++ b/mlair/data_handler/data_handler_mixed_sampling.py
@@ -6,11 +6,12 @@ from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleS
 from mlair.data_handler import DefaultDataHandler
 from mlair import helpers
 from mlair.helpers import remove_items
-from mlair.configuration.defaults import DEFAULT_SAMPLING
+from mlair.configuration.defaults import DEFAULT_SAMPLING, DEFAULT_INTERPOLATION_LIMIT, DEFAULT_INTERPOLATION_METHOD
 
 import inspect
 from typing import Callable
 import datetime as dt
+from typing import Any
 
 import numpy as np
 import pandas as pd
@@ -20,11 +21,39 @@ import xarray as xr
 class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
     _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"])
 
-    def __init__(self, *args, sampling_inputs, **kwargs):
-        sampling = (sampling_inputs, kwargs.get("sampling", DEFAULT_SAMPLING))
-        kwargs.update({"sampling": sampling})
+    def __init__(self, *args, **kwargs):
+        """
+        This data handler requires the kwargs sampling, interpolation_limit, and interpolation_method to be a 2D tuple
+        for input and target data. If one of these kwargs is only a single argument, it will be applied to inputs and
+        targets with this value. If one of these kwargs is a 2-dim tuple, the first element is applied to inputs and the
+        second to targets respectively. If one of these kwargs is not provided, it is filled up with the same default
+        value for inputs and targets.
+        """
+        self.update_kwargs("sampling", DEFAULT_SAMPLING, kwargs)
+        self.update_kwargs("interpolation_limit", DEFAULT_INTERPOLATION_LIMIT, kwargs)
+        self.update_kwargs("interpolation_method", DEFAULT_INTERPOLATION_METHOD, kwargs)
         super().__init__(*args, **kwargs)
 
+    @staticmethod
+    def update_kwargs(parameter_name: str, default: Any, kwargs: dict):
+        """
+        Update a single element of kwargs inplace to be usable for inputs and targets.
+
+        The updated value in the kwargs dictionary is a tuple consisting on the value applicable to the inputs as first
+        element and the target's value as second element: (<value_input>, <value_target>). If the value for the given
+        parameter_name is already a tuple, it is checked to have exact two entries. If the paramter_name is not
+        included in kwargs, the given default value is used and applied to both elements of the update tuple.
+
+        :param parameter_name: name of the parameter that should be transformed to 2-dim
+        :param default: the default value to fill if parameter is not in kwargs
+        :param kwargs: the kwargs dictionary containing parameters
+        """
+        parameter = kwargs.get(parameter_name, default)
+        if not isinstance(parameter, tuple):
+            parameter = (parameter, parameter)
+        assert len(parameter) == 2  # (inputs, targets)
+        kwargs.update({parameter_name: parameter})
+
     def setup_samples(self):
         """
         Setup samples. This method prepares and creates samples X, and labels Y.
@@ -41,8 +70,8 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
         data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind],
                                          self.station_type, self.network, self.store_data_locally, self.data_origin,
                                          self.start, self.end)
-        data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
-                                limit=self.interpolation_limit)
+        data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind],
+                                limit=self.interpolation_limit[ind])
         return data
 
     def set_inputs_and_targets(self):
@@ -119,7 +148,7 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
                                          self.station_type, self.network, self.store_data_locally, self.data_origin,
                                          start, end)
         data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
-                                limit=self.interpolation_limit)
+                                limit=self.interpolation_limit[ind])
         return data
 
 
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index 9a9253ed..54d23077 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -225,8 +225,8 @@ class ExperimentSetup(RunEnvironment):
                  extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None,
                  number_of_bootstraps=None,
                  create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None,
-                 hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, sampling_inputs=None,
-                 sampling_outputs=None, data_origin: Dict = None, **kwargs):
+                 hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None,
+                 data_origin: Dict = None, **kwargs):
 
         # create run framework
         super().__init__()
@@ -294,7 +294,6 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("window_history_size", window_history_size, default=DEFAULT_WINDOW_HISTORY_SIZE)
         self._set_param("overwrite_local_data", overwrite_local_data, default=DEFAULT_OVERWRITE_LOCAL_DATA,
                         scope="preprocessing")
-        self._set_param("sampling_inputs", sampling_inputs, default=sampling)
         self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION)
         self._set_param("transformation", None, scope="preprocessing")
         self._set_param("data_handler", data_handler, default=DefaultDataHandler)
-- 
GitLab