From 2a3b496ac9bcbb4c2f60ca5f66f76858664db802 Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Thu, 10 Dec 2020 11:09:40 +0100 Subject: [PATCH] sampling_inputs is replaced by the possibility to use a tuple for parameter sampling for (inputs, targets). This applies also to interpolation_{method,limit}. Data handlers using mixed sampling types can handle this tuples, for other data handlers it is still required to use a single value. --- .../data_handler_mixed_sampling.py | 43 ++++++++++++++++--- mlair/run_modules/experiment_setup.py | 5 +-- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py index c16dfb43..6d1e66bc 100644 --- a/mlair/data_handler/data_handler_mixed_sampling.py +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -6,11 +6,12 @@ from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleS from mlair.data_handler import DefaultDataHandler from mlair import helpers from mlair.helpers import remove_items -from mlair.configuration.defaults import DEFAULT_SAMPLING +from mlair.configuration.defaults import DEFAULT_SAMPLING, DEFAULT_INTERPOLATION_LIMIT, DEFAULT_INTERPOLATION_METHOD import inspect from typing import Callable import datetime as dt +from typing import Any import numpy as np import pandas as pd @@ -20,11 +21,39 @@ import xarray as xr class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation): _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"]) - def __init__(self, *args, sampling_inputs, **kwargs): - sampling = (sampling_inputs, kwargs.get("sampling", DEFAULT_SAMPLING)) - kwargs.update({"sampling": sampling}) + def __init__(self, *args, **kwargs): + """ + This data handler requires the kwargs sampling, interpolation_limit, and interpolation_method to be a 2D tuple + for input and target data. If one of these kwargs is only a single argument, it will be applied to inputs and + targets with this value. If one of these kwargs is a 2-dim tuple, the first element is applied to inputs and the + second to targets respectively. If one of these kwargs is not provided, it is filled up with the same default + value for inputs and targets. + """ + self.update_kwargs("sampling", DEFAULT_SAMPLING, kwargs) + self.update_kwargs("interpolation_limit", DEFAULT_INTERPOLATION_LIMIT, kwargs) + self.update_kwargs("interpolation_method", DEFAULT_INTERPOLATION_METHOD, kwargs) super().__init__(*args, **kwargs) + @staticmethod + def update_kwargs(parameter_name: str, default: Any, kwargs: dict): + """ + Update a single element of kwargs inplace to be usable for inputs and targets. + + The updated value in the kwargs dictionary is a tuple consisting on the value applicable to the inputs as first + element and the target's value as second element: (<value_input>, <value_target>). If the value for the given + parameter_name is already a tuple, it is checked to have exact two entries. If the paramter_name is not + included in kwargs, the given default value is used and applied to both elements of the update tuple. + + :param parameter_name: name of the parameter that should be transformed to 2-dim + :param default: the default value to fill if parameter is not in kwargs + :param kwargs: the kwargs dictionary containing parameters + """ + parameter = kwargs.get(parameter_name, default) + if not isinstance(parameter, tuple): + parameter = (parameter, parameter) + assert len(parameter) == 2 # (inputs, targets) + kwargs.update({parameter_name: parameter}) + def setup_samples(self): """ Setup samples. This method prepares and creates samples X, and labels Y. @@ -41,8 +70,8 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation): data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind], self.station_type, self.network, self.store_data_locally, self.data_origin, self.start, self.end) - data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, - limit=self.interpolation_limit) + data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind], + limit=self.interpolation_limit[ind]) return data def set_inputs_and_targets(self): @@ -119,7 +148,7 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi self.station_type, self.network, self.store_data_locally, self.data_origin, start, end) data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, - limit=self.interpolation_limit) + limit=self.interpolation_limit[ind]) return data diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index 9a9253ed..54d23077 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -225,8 +225,8 @@ class ExperimentSetup(RunEnvironment): extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None, number_of_bootstraps=None, create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None, - hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, sampling_inputs=None, - sampling_outputs=None, data_origin: Dict = None, **kwargs): + hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, + data_origin: Dict = None, **kwargs): # create run framework super().__init__() @@ -294,7 +294,6 @@ class ExperimentSetup(RunEnvironment): self._set_param("window_history_size", window_history_size, default=DEFAULT_WINDOW_HISTORY_SIZE) self._set_param("overwrite_local_data", overwrite_local_data, default=DEFAULT_OVERWRITE_LOCAL_DATA, scope="preprocessing") - self._set_param("sampling_inputs", sampling_inputs, default=sampling) self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION) self._set_param("transformation", None, scope="preprocessing") self._set_param("data_handler", data_handler, default=DefaultDataHandler) -- GitLab