Commit 2a3b496a authored by leufen1

sampling_inputs is replaced by the possibility to use a tuple for parameter sampling for (inputs, targets). This also applies to interpolation_{method,limit}. Data handlers using mixed sampling types can handle these tuples; for all other data handlers a single value is still required.
parent bcbf1287
3 merge requests: !226 Develop, !225 Resolve "release v1.2.0", !208 Resolve "mixed sampling decouple interpolation"
Pipeline #54795 passed
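
For orientation before the diff, here is a minimal standalone sketch of the kwargs convention described in the commit message. The helper name to_inputs_targets_tuple and all concrete values are made up for illustration; the helper simply mirrors the new update_kwargs method shown in the diff below.

    from typing import Any

    def to_inputs_targets_tuple(parameter_name: str, default: Any, kwargs: dict) -> None:
        # Mirrors the new update_kwargs method: a scalar is duplicated into an
        # (inputs, targets) tuple, an existing 2-element tuple is kept as it is.
        parameter = kwargs.get(parameter_name, default)
        if not isinstance(parameter, tuple):
            parameter = (parameter, parameter)
        assert len(parameter) == 2  # (inputs, targets)
        kwargs[parameter_name] = parameter

    kwargs = {"sampling": ("hourly", "daily"), "interpolation_limit": 1}
    to_inputs_targets_tuple("sampling", "daily", kwargs)               # tuple is kept
    to_inputs_targets_tuple("interpolation_limit", 0, kwargs)          # scalar is duplicated
    to_inputs_targets_tuple("interpolation_method", "linear", kwargs)  # missing -> default for both
    print(kwargs)
    # {'sampling': ('hourly', 'daily'), 'interpolation_limit': (1, 1),
    #  'interpolation_method': ('linear', 'linear')}
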
@@ -6,11 +6,12 @@ from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleS
 from mlair.data_handler import DefaultDataHandler
 from mlair import helpers
 from mlair.helpers import remove_items
-from mlair.configuration.defaults import DEFAULT_SAMPLING
+from mlair.configuration.defaults import DEFAULT_SAMPLING, DEFAULT_INTERPOLATION_LIMIT, DEFAULT_INTERPOLATION_METHOD
 import inspect
 from typing import Callable
 import datetime as dt
+from typing import Any
 import numpy as np
 import pandas as pd
@@ -20,11 +21,39 @@ import xarray as xr
 class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
     _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"])
 
-    def __init__(self, *args, sampling_inputs, **kwargs):
-        sampling = (sampling_inputs, kwargs.get("sampling", DEFAULT_SAMPLING))
-        kwargs.update({"sampling": sampling})
+    def __init__(self, *args, **kwargs):
+        """
+        This data handler requires the kwargs sampling, interpolation_limit, and interpolation_method to be a
+        2-dim tuple for input and target data. If one of these kwargs is a single value, it is applied to both
+        inputs and targets. If it is a 2-dim tuple, the first element is applied to the inputs and the second to
+        the targets. If it is not provided at all, the same default value is used for inputs and targets.
+        """
+        self.update_kwargs("sampling", DEFAULT_SAMPLING, kwargs)
+        self.update_kwargs("interpolation_limit", DEFAULT_INTERPOLATION_LIMIT, kwargs)
+        self.update_kwargs("interpolation_method", DEFAULT_INTERPOLATION_METHOD, kwargs)
         super().__init__(*args, **kwargs)
 
+    @staticmethod
+    def update_kwargs(parameter_name: str, default: Any, kwargs: dict):
+        """
+        Update a single element of kwargs in place to be usable for inputs and targets.
+
+        The updated value in the kwargs dictionary is a tuple consisting of the value applicable to the inputs as
+        first element and the targets' value as second element: (<value_input>, <value_target>). If the value for
+        the given parameter_name is already a tuple, it is checked to have exactly two entries. If parameter_name
+        is not included in kwargs, the given default value is applied to both elements of the updated tuple.
+
+        :param parameter_name: name of the parameter that should be transformed to 2-dim
+        :param default: the default value to fill in if the parameter is not in kwargs
+        :param kwargs: the kwargs dictionary containing parameters
+        """
+        parameter = kwargs.get(parameter_name, default)
+        if not isinstance(parameter, tuple):
+            parameter = (parameter, parameter)
+        assert len(parameter) == 2  # (inputs, targets)
+        kwargs.update({parameter_name: parameter})
+
     def setup_samples(self):
         """
         Setup samples. This method prepares and creates samples X, and labels Y.
@@ -41,8 +70,8 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
         data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind],
                                          self.station_type, self.network, self.store_data_locally, self.data_origin,
                                          self.start, self.end)
-        data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
-                                limit=self.interpolation_limit)
+        data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind],
+                                limit=self.interpolation_limit[ind])
         return data
 
     def set_inputs_and_targets(self):
@@ -119,7 +148,7 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
                                          self.station_type, self.network, self.store_data_locally, self.data_origin,
                                          start, end)
         data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
-                                limit=self.interpolation_limit)
+                                limit=self.interpolation_limit[ind])
         return data
...
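
To illustrate the indexed access introduced above (self.sampling[ind], self.interpolation_method[ind], self.interpolation_limit[ind]), here is a reduced, self-contained sketch. The resampling and the pandas interpolate call are simplified stand-ins for the handler's load_data/interpolate machinery, and all concrete values are invented.

    import pandas as pd

    # After update_kwargs has run in __init__, each of these attributes is a
    # 2-element tuple: index 0 belongs to the inputs, index 1 to the targets.
    sampling = ("hourly", "daily")
    interpolation_method = ("linear", "linear")
    interpolation_limit = (1, 0)

    def load_and_interpolate_sketch(ind: int, raw: pd.Series) -> pd.Series:
        # ind = 0 -> inputs, ind = 1 -> targets (same indexing as in the diff above)
        freq = {"hourly": "h", "daily": "d"}[sampling[ind]]
        data = raw.resample(freq).mean()
        # stand-in for the handler's self.interpolate call
        return data.interpolate(method=interpolation_method[ind],
                                limit=interpolation_limit[ind] or None)

    index = pd.date_range("2021-01-01", periods=48, freq="h")
    series = pd.Series(range(48), index=index, dtype=float)
    inputs = load_and_interpolate_sketch(0, series)   # hourly inputs
    targets = load_and_interpolate_sketch(1, series)  # daily targets
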
@@ -225,8 +225,8 @@ class ExperimentSetup(RunEnvironment):
                  extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None,
                  number_of_bootstraps=None,
                  create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None,
-                 hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, sampling_inputs=None,
-                 sampling_outputs=None, data_origin: Dict = None, **kwargs):
+                 hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None,
+                 data_origin: Dict = None, **kwargs):
 
         # create run framework
         super().__init__()
@@ -294,7 +294,6 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("window_history_size", window_history_size, default=DEFAULT_WINDOW_HISTORY_SIZE)
         self._set_param("overwrite_local_data", overwrite_local_data, default=DEFAULT_OVERWRITE_LOCAL_DATA,
                         scope="preprocessing")
-        self._set_param("sampling_inputs", sampling_inputs, default=sampling)
         self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION)
         self._set_param("transformation", None, scope="preprocessing")
         self._set_param("data_handler", data_handler, default=DefaultDataHandler)
...
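
On the caller side, the removed sampling_inputs and sampling_outputs arguments are replaced by the tuple-valued kwargs themselves. A hedged sketch of the resulting configuration style follows; only the parameter names sampling, interpolation_method and interpolation_limit come from the diff, the concrete values and the surrounding dicts are illustrative.

    # Mixed-sampling data handler: tuples are read as (inputs, targets).
    mixed_sampling_kwargs = dict(
        sampling=("hourly", "daily"),
        interpolation_method="linear",   # a single value is applied to inputs and targets alike
        interpolation_limit=(1, 0),
    )

    # Any other (single-sampling) data handler: a single value is still required.
    default_handler_kwargs = dict(
        sampling="daily",
        interpolation_method="linear",
        interpolation_limit=1,
    )
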