Commit bb67bc86 authored by lukas leufen
Merge branch 'lukas_issue229_feat_mixed-sampling-decouple-interpolation' into 'develop', /close #229

Resolve "mixed sampling decouple interpolation"

See merge request toar/mlair!208
parents bcbf1287 ca24bfc1
Pipeline #54846 passed
@@ -6,11 +6,12 @@ from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleStation
 from mlair.data_handler import DefaultDataHandler
 from mlair import helpers
 from mlair.helpers import remove_items
-from mlair.configuration.defaults import DEFAULT_SAMPLING
+from mlair.configuration.defaults import DEFAULT_SAMPLING, DEFAULT_INTERPOLATION_LIMIT, DEFAULT_INTERPOLATION_METHOD
 import inspect
 from typing import Callable
 import datetime as dt
+from typing import Any

 import numpy as np
 import pandas as pd
@@ -20,11 +21,39 @@ import xarray as xr

 class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
     _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"])

-    def __init__(self, *args, sampling_inputs, **kwargs):
-        sampling = (sampling_inputs, kwargs.get("sampling", DEFAULT_SAMPLING))
-        kwargs.update({"sampling": sampling})
+    def __init__(self, *args, **kwargs):
+        """
+        This data handler requires the kwargs sampling, interpolation_limit, and interpolation_method to be
+        2-element tuples covering input and target data. If one of these kwargs is given as a single value, this
+        value is applied to both inputs and targets. If one of these kwargs is a 2-element tuple, the first element
+        is applied to the inputs and the second to the targets. If one of these kwargs is not provided, it is filled
+        up with the same default value for inputs and targets.
+        """
+        self.update_kwargs("sampling", DEFAULT_SAMPLING, kwargs)
+        self.update_kwargs("interpolation_limit", DEFAULT_INTERPOLATION_LIMIT, kwargs)
+        self.update_kwargs("interpolation_method", DEFAULT_INTERPOLATION_METHOD, kwargs)
         super().__init__(*args, **kwargs)

+    @staticmethod
+    def update_kwargs(parameter_name: str, default: Any, kwargs: dict):
+        """
+        Update a single element of kwargs in place to be usable for inputs and targets.
+
+        The updated value in the kwargs dictionary is a tuple consisting of the value applicable to the inputs as
+        first element and the targets' value as second element: (<value_input>, <value_target>). If the value for the
+        given parameter_name is already a tuple, it is checked to have exactly two entries. If the parameter_name is
+        not included in kwargs, the given default value is used and applied to both elements of the updated tuple.
+
+        :param parameter_name: name of the parameter that should be transformed to 2-dim
+        :param default: the default value to fill if parameter is not in kwargs
+        :param kwargs: the kwargs dictionary containing parameters
+        """
+        parameter = kwargs.get(parameter_name, default)
+        if not isinstance(parameter, tuple):
+            parameter = (parameter, parameter)
+        assert len(parameter) == 2  # (inputs, targets)
+        kwargs.update({parameter_name: parameter})
+
     def setup_samples(self):
         """
         Setup samples. This method prepares and creates samples X, and labels Y.
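For illustration only (not part of the commit), here is a minimal, self-contained sketch of the normalization rule described in the docstrings above. The normalize function and the literal default values are placeholders for this example, not MLAir API:

# Sketch only: mirrors the (inputs, targets) normalization applied by update_kwargs.
def normalize(parameter_name, default, kwargs):
    parameter = kwargs.get(parameter_name, default)
    if not isinstance(parameter, tuple):
        parameter = (parameter, parameter)  # single value -> applied to inputs and targets
    assert len(parameter) == 2  # (inputs, targets)
    kwargs.update({parameter_name: parameter})

kwargs = {"sampling": ("hourly", "daily"), "interpolation_limit": 1}
normalize("sampling", "daily", kwargs)                # already a 2-tuple -> kept as is
normalize("interpolation_limit", 1, kwargs)           # scalar -> duplicated for both branches
normalize("interpolation_method", "linear", kwargs)   # missing -> placeholder default used for both
print(kwargs)
# {'sampling': ('hourly', 'daily'), 'interpolation_limit': (1, 1), 'interpolation_method': ('linear', 'linear')}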
@@ -41,8 +70,8 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
         data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind],
                                          self.station_type, self.network, self.store_data_locally, self.data_origin,
                                          self.start, self.end)
-        data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
-                                limit=self.interpolation_limit)
+        data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind],
+                                limit=self.interpolation_limit[ind])
         return data

     def set_inputs_and_targets(self):
@@ -119,7 +148,7 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSingleStation
                                          self.station_type, self.network, self.store_data_locally, self.data_origin,
                                          start, end)
         data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
-                                limit=self.interpolation_limit)
+                                limit=self.interpolation_limit[ind])
         return data
...
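As a hedged sketch of how the normalized 2-tuples are consumed per branch, assuming index 0 addresses the inputs and index 1 the targets (matching the (<value_input>, <value_target>) convention above). interpolate_branch is a stand-in for this example, not MLAir's interpolate method:

# Sketch only: per-branch selection of interpolation settings, mirroring
# interpolation_method[ind] / interpolation_limit[ind] from the diff above.
import pandas as pd

def interpolate_branch(series: pd.Series, ind: int,
                       interpolation_method=("linear", "linear"),
                       interpolation_limit=(3, 1)) -> pd.Series:
    # ind = 0 -> inputs, ind = 1 -> targets (assumed convention for this sketch)
    return series.interpolate(method=interpolation_method[ind],
                              limit=interpolation_limit[ind])

s = pd.Series([1.0, None, None, None, 5.0])
print(interpolate_branch(s, ind=0))  # up to 3 consecutive gaps filled (input branch)
print(interpolate_branch(s, ind=1))  # only 1 consecutive gap filled (target branch)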
@@ -225,8 +225,8 @@ class ExperimentSetup(RunEnvironment):
                  extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None,
                  number_of_bootstraps=None,
                  create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None,
-                 hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, sampling_inputs=None,
-                 sampling_outputs=None, data_origin: Dict = None, **kwargs):
+                 hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None,
+                 data_origin: Dict = None, **kwargs):

         # create run framework
         super().__init__()
@@ -294,7 +294,6 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("window_history_size", window_history_size, default=DEFAULT_WINDOW_HISTORY_SIZE)
         self._set_param("overwrite_local_data", overwrite_local_data, default=DEFAULT_OVERWRITE_LOCAL_DATA,
                         scope="preprocessing")
-        self._set_param("sampling_inputs", sampling_inputs, default=sampling)
         self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION)
         self._set_param("transformation", None, scope="preprocessing")
         self._set_param("data_handler", data_handler, default=DefaultDataHandler)
...
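With sampling_inputs removed from ExperimentSetup, per-branch settings now travel through the generic kwargs to the data handler. A minimal sketch of the new call style, based on the values in the example script in the next hunk; additional arguments from that script (e.g. statistics_per_var) may be needed for an actual run:

# Sketch only: per-branch settings are passed as 2-tuples (inputs, targets)
# in the workflow kwargs instead of the removed sampling_inputs argument.
from mlair.workflows import DefaultWorkflow
from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling

workflow = DefaultWorkflow(stations=["DEBW107", "DEBW013"],
                           data_handler=DataHandlerMixedSampling,
                           sampling=("hourly", "daily"),   # hourly inputs, daily targets
                           interpolation_limit=(3, 1))     # fill up to 3 gaps in inputs, 1 in targets
workflow.run()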
@@ -7,22 +7,33 @@ from mlair.workflows import DefaultWorkflow
 from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, DataHandlerMixedSamplingWithFilter, \
     DataHandlerSeparationOfScales

+stats = {'o3': 'dma8eu', 'no': 'dma8eu', 'no2': 'dma8eu',
+         'relhum': 'average_values', 'u': 'average_values', 'v': 'average_values',
+         'cloudcover': 'average_values', 'pblheight': 'maximum',
+         'temp': 'maximum'}
+
+data_origin = {'o3': '', 'no': '', 'no2': '',
+               'relhum': 'REA', 'u': 'REA', 'v': 'REA',
+               'cloudcover': 'REA', 'pblheight': 'REA',
+               'temp': 'REA'}
+

 def main(parser_args):
-    args = dict(sampling="daily",
-                sampling_inputs="hourly",
-                window_history_size=24,
-                **parser_args.__dict__,
-                data_handler=DataHandlerSeparationOfScales,
+    args = dict(stations=["DEBW107", "DEBW013"],
+                network="UBA",
+                evaluate_bootstraps=False, plot_list=[],
+                data_origin=data_origin, data_handler=DataHandlerMixedSampling,
+                interpolation_limit=(3, 1), overwrite_local_data=False,
+                sampling=("hourly", "daily"),
+                statistics_per_var=stats,
+                create_new_model=False, train_model=False, epochs=1,
+                window_history_size=48,
                 kz_filter_length=[100 * 24, 15 * 24],
                 kz_filter_iter=[4, 5],
                 start="2006-01-01",
                 train_start="2006-01-01",
                 end="2011-12-31",
                 test_end="2011-12-31",
-                stations=["DEBW107", "DEBW013"],
-                epochs=100,
-                network="UBA",
+                **parser_args.__dict__,
                 )
     workflow = DefaultWorkflow(**args)
     workflow.run()
...
__author__ = 'Lukas Leufen'
__date__ = '2020-12-10'

from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, \
    DataHandlerMixedSamplingSingleStation, DataHandlerMixedSamplingWithFilter, \
    DataHandlerMixedSamplingWithFilterSingleStation, DataHandlerSeparationOfScales, \
    DataHandlerSeparationOfScalesSingleStation
from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleStation
from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
from mlair.helpers import remove_items
from mlair.configuration.defaults import DEFAULT_INTERPOLATION_METHOD

import pytest
import mock


class TestDataHandlerMixedSampling:

    def test_data_handler(self):
        obj = object.__new__(DataHandlerMixedSampling)
        assert obj.data_handler.__qualname__ == DataHandlerMixedSamplingSingleStation.__qualname__

    def test_data_handler_transformation(self):
        obj = object.__new__(DataHandlerMixedSampling)
        assert obj.data_handler_transformation.__qualname__ == DataHandlerMixedSamplingSingleStation.__qualname__

    def test_requirements(self):
        obj = object.__new__(DataHandlerMixedSampling)
        req = object.__new__(DataHandlerSingleStation)
        assert sorted(obj._requirements) == sorted(remove_items(req.requirements(), "station"))


class TestDataHandlerMixedSamplingSingleStation:

    def test_requirements(self):
        obj = object.__new__(DataHandlerMixedSamplingSingleStation)
        req = object.__new__(DataHandlerSingleStation)
        assert sorted(obj._requirements) == sorted(remove_items(req.requirements(), "station"))

    @mock.patch("mlair.data_handler.data_handler_mixed_sampling.DataHandlerMixedSamplingSingleStation.setup_samples")
    def test_init(self, mock_super_init):
        obj = DataHandlerMixedSamplingSingleStation("first_arg", "second", {}, test=23, sampling="hourly",
                                                    interpolation_limit=(1, 10))
        assert obj.sampling == ("hourly", "hourly")
        assert obj.interpolation_limit == (1, 10)
        assert obj.interpolation_method == (DEFAULT_INTERPOLATION_METHOD, DEFAULT_INTERPOLATION_METHOD)

    @pytest.fixture
    def kwargs_dict(self):
        return {"test1": 2, "param_2": "string", "another": (10, 2)}

    def test_update_kwargs_single_to_tuple(self, kwargs_dict):
        obj = object.__new__(DataHandlerMixedSamplingSingleStation)
        obj.update_kwargs("test1", "23", kwargs_dict)
        assert kwargs_dict["test1"] == (2, 2)
        obj.update_kwargs("param_2", "23", kwargs_dict)
        assert kwargs_dict["param_2"] == ("string", "string")

    def test_update_kwargs_tuple(self, kwargs_dict):
        obj = object.__new__(DataHandlerMixedSamplingSingleStation)
        obj.update_kwargs("another", "23", kwargs_dict)
        assert kwargs_dict["another"] == (10, 2)

    def test_update_kwargs_default(self, kwargs_dict):
        obj = object.__new__(DataHandlerMixedSamplingSingleStation)
        obj.update_kwargs("not_existing", "23", kwargs_dict)
        assert kwargs_dict["not_existing"] == ("23", "23")
        obj.update_kwargs("also_new", (4, 2), kwargs_dict)
        assert kwargs_dict["also_new"] == (4, 2)

    def test_update_kwargs_assert_failure(self, kwargs_dict):
        obj = object.__new__(DataHandlerMixedSamplingSingleStation)
        with pytest.raises(AssertionError):
            obj.update_kwargs("error_too_long", (1, 2, 3), kwargs_dict)

    def test_setup_samples(self):
        pass

    def test_load_and_interpolate(self):
        pass

    def test_set_inputs_and_targets(self):
        pass

    def test_setup_data_path(self):
        pass


class TestDataHandlerMixedSamplingWithFilter:

    def test_data_handler(self):
        obj = object.__new__(DataHandlerMixedSamplingWithFilter)
        assert obj.data_handler.__qualname__ == DataHandlerMixedSamplingWithFilterSingleStation.__qualname__

    def test_data_handler_transformation(self):
        obj = object.__new__(DataHandlerMixedSamplingWithFilter)
        assert obj.data_handler_transformation.__qualname__ == DataHandlerMixedSamplingWithFilterSingleStation.__qualname__

    def test_requirements(self):
        obj = object.__new__(DataHandlerMixedSamplingWithFilter)
        req1 = object.__new__(DataHandlerMixedSamplingSingleStation)
        req2 = object.__new__(DataHandlerKzFilterSingleStation)
        req = list(set(req1.requirements() + req2.requirements()))
        assert sorted(obj._requirements) == sorted(remove_items(req, "station"))


class TestDataHandlerMixedSamplingWithFilterSingleStation:
    pass


class TestDataHandlerSeparationOfScales:

    def test_data_handler(self):
        obj = object.__new__(DataHandlerSeparationOfScales)
        assert obj.data_handler.__qualname__ == DataHandlerSeparationOfScalesSingleStation.__qualname__

    def test_data_handler_transformation(self):
        obj = object.__new__(DataHandlerSeparationOfScales)
        assert obj.data_handler_transformation.__qualname__ == DataHandlerSeparationOfScalesSingleStation.__qualname__

    def test_requirements(self):
        obj = object.__new__(DataHandlerMixedSamplingWithFilter)
        req1 = object.__new__(DataHandlerMixedSamplingSingleStation)
        req2 = object.__new__(DataHandlerKzFilterSingleStation)
        req = list(set(req1.requirements() + req2.requirements()))
        assert sorted(obj._requirements) == sorted(remove_items(req, "station"))


class TestDataHandlerSeparationOfScalesSingleStation:
    pass