Commit 6c913d92 authored by leufen1

Merge branch 'develop' into lukas_issue220_refac_history-for-mixed-sampling-data-handler

# Conflicts:
#	run_mixed_sampling.py
parents dd024db9 bb67bc86
3 merge requests: !226 Develop, !225 Resolve "release v1.2.0", !207 Resolve "REFAC: history for mixed sampling data handler"
Pipeline #54876 passed
......@@ -49,6 +49,8 @@ Thumbs.db
##############################
run_*_develgpus.bash
run_*_gpus.bash
run_*_batch.bash
activate_env.sh
# don't check data and plot folder #
####################################
......
......@@ -6,11 +6,12 @@ from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleS
from mlair.data_handler import DefaultDataHandler
from mlair import helpers
from mlair.helpers import remove_items
from mlair.configuration.defaults import DEFAULT_SAMPLING
from mlair.configuration.defaults import DEFAULT_SAMPLING, DEFAULT_INTERPOLATION_LIMIT, DEFAULT_INTERPOLATION_METHOD
import inspect
from typing import Callable
import datetime as dt
from typing import Any
import numpy as np
import pandas as pd
......@@ -20,11 +21,39 @@ import xarray as xr
class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
_requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"])
def __init__(self, *args, sampling_inputs, **kwargs):
sampling = (sampling_inputs, kwargs.get("sampling", DEFAULT_SAMPLING))
kwargs.update({"sampling": sampling})
def __init__(self, *args, **kwargs):
"""
This data handler requires the kwargs sampling, interpolation_limit, and interpolation_method to be given as
2-element tuples holding the values for input and target data. If one of these kwargs is given as a single
value, that value is applied to both inputs and targets. If it is given as a 2-element tuple, the first
element is applied to the inputs and the second to the targets. If one of these kwargs is not provided, it is
filled with the same default value for inputs and targets.
"""
self.update_kwargs("sampling", DEFAULT_SAMPLING, kwargs)
self.update_kwargs("interpolation_limit", DEFAULT_INTERPOLATION_LIMIT, kwargs)
self.update_kwargs("interpolation_method", DEFAULT_INTERPOLATION_METHOD, kwargs)
super().__init__(*args, **kwargs)
@staticmethod
def update_kwargs(parameter_name: str, default: Any, kwargs: dict):
"""
Update a single element of kwargs in place so that it can be used for both inputs and targets.
The updated value in the kwargs dictionary is a tuple whose first element applies to the inputs and whose
second element applies to the targets: (<value_input>, <value_target>). If the value for the given
parameter_name is already a tuple, it is checked to contain exactly two entries. If parameter_name is not
included in kwargs, the given default value is used for both elements of the updated tuple.
:param parameter_name: name of the parameter that should be transformed into a 2-element tuple
:param default: the default value to use if the parameter is not in kwargs
:param kwargs: the kwargs dictionary containing the parameters
"""
parameter = kwargs.get(parameter_name, default)
if not isinstance(parameter, tuple):
parameter = (parameter, parameter)
assert len(parameter) == 2 # (inputs, targets)
kwargs.update({parameter_name: parameter})
def setup_samples(self):
"""
Setup samples. This method prepares and creates samples X and labels Y.
......@@ -41,8 +70,8 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind],
self.station_type, self.network, self.store_data_locally, self.data_origin,
self.start, self.end)
data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
limit=self.interpolation_limit)
data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind],
limit=self.interpolation_limit[ind])
return data
def set_inputs_and_targets(self):
......@@ -119,7 +148,7 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
self.station_type, self.network, self.store_data_locally, self.data_origin,
start, end)
data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
limit=self.interpolation_limit)
limit=self.interpolation_limit[ind])
return data
......
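As an aside, the kwargs normalization introduced in this file can be reproduced in isolation. The following is a minimal sketch, not part of the diff; it reuses the import paths seen elsewhere in this merge request and picks arbitrary default values for demonstration:

from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSamplingSingleStation
from mlair.configuration.defaults import DEFAULT_INTERPOLATION_METHOD

# kwargs as a run script might pass them: sampling already a 2-tuple, limit a single value
kwargs = {"sampling": ("hourly", "daily"), "interpolation_limit": 3}

# update_kwargs is a staticmethod, so it can be called on the class for demonstration
DataHandlerMixedSamplingSingleStation.update_kwargs("sampling", "daily", kwargs)
DataHandlerMixedSamplingSingleStation.update_kwargs("interpolation_limit", 1, kwargs)
DataHandlerMixedSamplingSingleStation.update_kwargs("interpolation_method", DEFAULT_INTERPOLATION_METHOD, kwargs)

# resulting kwargs, index 0 for inputs and index 1 for targets:
#   sampling             -> ("hourly", "daily")   (existing tuple kept)
#   interpolation_limit  -> (3, 3)                (single value broadcast)
#   interpolation_method -> (DEFAULT_INTERPOLATION_METHOD, DEFAULT_INTERPOLATION_METHOD)  (default filled in)

load_and_interpolate then selects the entry for the current data stream via self.sampling[ind], self.interpolation_method[ind], and self.interpolation_limit[ind], with index 0 for inputs and index 1 for targets.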
......@@ -225,8 +225,8 @@ class ExperimentSetup(RunEnvironment):
extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None,
number_of_bootstraps=None,
create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None,
hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, sampling_inputs=None,
sampling_outputs=None, data_origin: Dict = None, **kwargs):
hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None,
data_origin: Dict = None, **kwargs):
# create run framework
super().__init__()
......@@ -294,7 +294,6 @@ class ExperimentSetup(RunEnvironment):
self._set_param("window_history_size", window_history_size, default=DEFAULT_WINDOW_HISTORY_SIZE)
self._set_param("overwrite_local_data", overwrite_local_data, default=DEFAULT_OVERWRITE_LOCAL_DATA,
scope="preprocessing")
self._set_param("sampling_inputs", sampling_inputs, default=sampling)
self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION)
self._set_param("transformation", None, scope="preprocessing")
self._set_param("data_handler", data_handler, default=DefaultDataHandler)
......
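The removal of sampling_inputs above changes the caller-facing interface: instead of a dedicated keyword for the input sampling, the (inputs, targets) resolution is now passed as a single tuple, as the updated run script below shows. A hedged sketch of the new call style (station list and values illustrative; a real run needs further kwargs such as statistics_per_var):

from mlair.workflows import DefaultWorkflow
from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling

# formerly: sampling="daily", sampling_inputs="hourly"; now one (inputs, targets) tuple
workflow = DefaultWorkflow(stations=["DEBW107"],
                           data_handler=DataHandlerMixedSampling,
                           sampling=("hourly", "daily"),
                           interpolation_limit=(3, 1))
# workflow.run()  # left commented out: running requires station data to be available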
......@@ -7,23 +7,34 @@ from mlair.workflows import DefaultWorkflow
from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, DataHandlerMixedSamplingWithFilter, \
DataHandlerSeparationOfScales
stats = {'o3': 'dma8eu', 'no': 'dma8eu', 'no2': 'dma8eu',
'relhum': 'average_values', 'u': 'average_values', 'v': 'average_values',
'cloudcover': 'average_values', 'pblheight': 'maximum',
'temp': 'maximum'}
data_origin = {'o3': '', 'no': '', 'no2': '',
'relhum': 'REA', 'u': 'REA', 'v': 'REA',
'cloudcover': 'REA', 'pblheight': 'REA',
'temp': 'REA'}
def main(parser_args):
args = dict(sampling="daily",
sampling_inputs="hourly",
window_history_size=24,
args = dict(stations=["DEBW107", "DEBW013"],
network="UBA",
evaluate_bootstraps=False, plot_list=[],
data_origin=data_origin, data_handler=DataHandlerMixedSampling,
interpolation_limit=(3, 1), overwrite_local_data=False,
sampling=("hourly", "daily"),
statistics_per_var=stats,
create_new_model=False, train_model=False, epochs=1,
window_history_size=48,
window_history_offset=17,
**parser_args.__dict__,
data_handler=DataHandlerMixedSampling,
kz_filter_length=[100 * 24, 15 * 24],
kz_filter_iter=[4, 5],
start="2006-01-01",
train_start="2006-01-01",
end="2011-12-31",
test_end="2011-12-31",
stations=["DEBW107", "DEBW013"],
epochs=1,
network="UBA",
**parser_args.__dict__,
)
workflow = DefaultWorkflow(**args)
workflow.run()
......
__author__ = 'Lukas Leufen'
__date__ = '2020-12-10'
from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, \
DataHandlerMixedSamplingSingleStation, DataHandlerMixedSamplingWithFilter, \
DataHandlerMixedSamplingWithFilterSingleStation, DataHandlerSeparationOfScales, \
DataHandlerSeparationOfScalesSingleStation
from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleStation
from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
from mlair.helpers import remove_items
from mlair.configuration.defaults import DEFAULT_INTERPOLATION_METHOD
import pytest
import mock
class TestDataHandlerMixedSampling:
def test_data_handler(self):
obj = object.__new__(DataHandlerMixedSampling)
assert obj.data_handler.__qualname__ == DataHandlerMixedSamplingSingleStation.__qualname__
def test_data_handler_transformation(self):
obj = object.__new__(DataHandlerMixedSampling)
assert obj.data_handler_transformation.__qualname__ == DataHandlerMixedSamplingSingleStation.__qualname__
def test_requirements(self):
obj = object.__new__(DataHandlerMixedSampling)
req = object.__new__(DataHandlerSingleStation)
assert sorted(obj._requirements) == sorted(remove_items(req.requirements(), "station"))
class TestDataHandlerMixedSamplingSingleStation:
def test_requirements(self):
obj = object.__new__(DataHandlerMixedSamplingSingleStation)
req = object.__new__(DataHandlerSingleStation)
assert sorted(obj._requirements) == sorted(remove_items(req.requirements(), "station"))
@mock.patch("mlair.data_handler.data_handler_mixed_sampling.DataHandlerMixedSamplingSingleStation.setup_samples")
def test_init(self, mock_super_init):
obj = DataHandlerMixedSamplingSingleStation("first_arg", "second", {}, test=23, sampling="hourly",
interpolation_limit=(1, 10))
assert obj.sampling == ("hourly", "hourly")
assert obj.interpolation_limit == (1, 10)
assert obj.interpolation_method == (DEFAULT_INTERPOLATION_METHOD, DEFAULT_INTERPOLATION_METHOD)
@pytest.fixture
def kwargs_dict(self):
return {"test1": 2, "param_2": "string", "another": (10, 2)}
def test_update_kwargs_single_to_tuple(self, kwargs_dict):
obj = object.__new__(DataHandlerMixedSamplingSingleStation)
obj.update_kwargs("test1", "23", kwargs_dict)
assert kwargs_dict["test1"] == (2, 2)
obj.update_kwargs("param_2", "23", kwargs_dict)
assert kwargs_dict["param_2"] == ("string", "string")
def test_update_kwargs_tuple(self, kwargs_dict):
obj = object.__new__(DataHandlerMixedSamplingSingleStation)
obj.update_kwargs("another", "23", kwargs_dict)
assert kwargs_dict["another"] == (10, 2)
def test_update_kwargs_default(self, kwargs_dict):
obj = object.__new__(DataHandlerMixedSamplingSingleStation)
obj.update_kwargs("not_existing", "23", kwargs_dict)
assert kwargs_dict["not_existing"] == ("23", "23")
obj.update_kwargs("also_new", (4, 2), kwargs_dict)
assert kwargs_dict["also_new"] == (4, 2)
def test_update_kwargs_assert_failure(self, kwargs_dict):
obj = object.__new__(DataHandlerMixedSamplingSingleStation)
with pytest.raises(AssertionError):
obj.update_kwargs("error_too_long", (1, 2, 3), kwargs_dict)
def test_setup_samples(self):
pass
def test_load_and_interpolate(self):
pass
def test_set_inputs_and_targets(self):
pass
def test_setup_data_path(self):
pass
class TestDataHandlerMixedSamplingWithFilter:
def test_data_handler(self):
obj = object.__new__(DataHandlerMixedSamplingWithFilter)
assert obj.data_handler.__qualname__ == DataHandlerMixedSamplingWithFilterSingleStation.__qualname__
def test_data_handler_transformation(self):
obj = object.__new__(DataHandlerMixedSamplingWithFilter)
assert obj.data_handler_transformation.__qualname__ == DataHandlerMixedSamplingWithFilterSingleStation.__qualname__
def test_requirements(self):
obj = object.__new__(DataHandlerMixedSamplingWithFilter)
req1 = object.__new__(DataHandlerMixedSamplingSingleStation)
req2 = object.__new__(DataHandlerKzFilterSingleStation)
req = list(set(req1.requirements() + req2.requirements()))
assert sorted(obj._requirements) == sorted(remove_items(req, "station"))
class TestDataHandlerMixedSamplingWithFilterSingleStation:
pass
class TestDataHandlerSeparationOfScales:
def test_data_handler(self):
obj = object.__new__(DataHandlerSeparationOfScales)
assert obj.data_handler.__qualname__ == DataHandlerSeparationOfScalesSingleStation.__qualname__
def test_data_handler_transformation(self):
obj = object.__new__(DataHandlerSeparationOfScales)
assert obj.data_handler_transformation.__qualname__ == DataHandlerSeparationOfScalesSingleStation.__qualname__
def test_requirements(self):
obj = object.__new__(DataHandlerMixedSamplingWithFilter)
req1 = object.__new__(DataHandlerMixedSamplingSingleStation)
req2 = object.__new__(DataHandlerKzFilterSingleStation)
req = list(set(req1.requirements() + req2.requirements()))
assert sorted(obj._requirements) == sorted(remove_items(req, "station"))
class TestDataHandlerSeparationOfScalesSingleStation:
pass