diff --git a/.gitignore b/.gitignore index 01ca296747666fee411aedca6fbb15a554a0bb51..f5e425f752a1de0de0c68036a54e0d19450320bb 100644 --- a/.gitignore +++ b/.gitignore @@ -49,6 +49,8 @@ Thumbs.db ############################## run_*_develgpus.bash run_*_gpus.bash +run_*_batch.bash +activate_env.sh # don't check data and plot folder # #################################### @@ -84,4 +86,4 @@ report.html # ignore locally build documentation # ###################################### -/docs/_build \ No newline at end of file +/docs/_build diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py index c16dfb4344b6d083876081b333f855a9eac99c6b..6d1e66bc1cee53342ad76625887e44ff36cd7562 100644 --- a/mlair/data_handler/data_handler_mixed_sampling.py +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -6,11 +6,12 @@ from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleS from mlair.data_handler import DefaultDataHandler from mlair import helpers from mlair.helpers import remove_items -from mlair.configuration.defaults import DEFAULT_SAMPLING +from mlair.configuration.defaults import DEFAULT_SAMPLING, DEFAULT_INTERPOLATION_LIMIT, DEFAULT_INTERPOLATION_METHOD import inspect from typing import Callable import datetime as dt +from typing import Any import numpy as np import pandas as pd @@ -20,11 +21,39 @@ import xarray as xr class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation): _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"]) - def __init__(self, *args, sampling_inputs, **kwargs): - sampling = (sampling_inputs, kwargs.get("sampling", DEFAULT_SAMPLING)) - kwargs.update({"sampling": sampling}) + def __init__(self, *args, **kwargs): + """ + This data handler requires the kwargs sampling, interpolation_limit, and interpolation_method to be a 2D tuple + for input and target data. If one of these kwargs is only a single argument, it will be applied to inputs and + targets with this value. If one of these kwargs is a 2-dim tuple, the first element is applied to inputs and the + second to targets respectively. If one of these kwargs is not provided, it is filled up with the same default + value for inputs and targets. + """ + self.update_kwargs("sampling", DEFAULT_SAMPLING, kwargs) + self.update_kwargs("interpolation_limit", DEFAULT_INTERPOLATION_LIMIT, kwargs) + self.update_kwargs("interpolation_method", DEFAULT_INTERPOLATION_METHOD, kwargs) super().__init__(*args, **kwargs) + @staticmethod + def update_kwargs(parameter_name: str, default: Any, kwargs: dict): + """ + Update a single element of kwargs inplace to be usable for inputs and targets. + + The updated value in the kwargs dictionary is a tuple consisting on the value applicable to the inputs as first + element and the target's value as second element: (<value_input>, <value_target>). If the value for the given + parameter_name is already a tuple, it is checked to have exact two entries. If the paramter_name is not + included in kwargs, the given default value is used and applied to both elements of the update tuple. + + :param parameter_name: name of the parameter that should be transformed to 2-dim + :param default: the default value to fill if parameter is not in kwargs + :param kwargs: the kwargs dictionary containing parameters + """ + parameter = kwargs.get(parameter_name, default) + if not isinstance(parameter, tuple): + parameter = (parameter, parameter) + assert len(parameter) == 2 # (inputs, targets) + kwargs.update({parameter_name: parameter}) + def setup_samples(self): """ Setup samples. This method prepares and creates samples X, and labels Y. @@ -41,8 +70,8 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation): data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind], self.station_type, self.network, self.store_data_locally, self.data_origin, self.start, self.end) - data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, - limit=self.interpolation_limit) + data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind], + limit=self.interpolation_limit[ind]) return data def set_inputs_and_targets(self): @@ -119,7 +148,7 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi self.station_type, self.network, self.store_data_locally, self.data_origin, start, end) data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, - limit=self.interpolation_limit) + limit=self.interpolation_limit[ind]) return data diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index 9a9253eda522c39f348dd96700ed38730e87f9a8..54d2307718bf083cfbfb8296682c9c545157eb72 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -225,8 +225,8 @@ class ExperimentSetup(RunEnvironment): extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None, number_of_bootstraps=None, create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None, - hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, sampling_inputs=None, - sampling_outputs=None, data_origin: Dict = None, **kwargs): + hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, + data_origin: Dict = None, **kwargs): # create run framework super().__init__() @@ -294,7 +294,6 @@ class ExperimentSetup(RunEnvironment): self._set_param("window_history_size", window_history_size, default=DEFAULT_WINDOW_HISTORY_SIZE) self._set_param("overwrite_local_data", overwrite_local_data, default=DEFAULT_OVERWRITE_LOCAL_DATA, scope="preprocessing") - self._set_param("sampling_inputs", sampling_inputs, default=sampling) self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION) self._set_param("transformation", None, scope="preprocessing") self._set_param("data_handler", data_handler, default=DefaultDataHandler) diff --git a/run_mixed_sampling.py b/run_mixed_sampling.py index dbc94ef9a50318a18645fce235001e80af104863..04683a17ede641a5370aaeef741d2f4546f966b7 100644 --- a/run_mixed_sampling.py +++ b/run_mixed_sampling.py @@ -7,23 +7,34 @@ from mlair.workflows import DefaultWorkflow from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, DataHandlerMixedSamplingWithFilter, \ DataHandlerSeparationOfScales +stats = {'o3': 'dma8eu', 'no': 'dma8eu', 'no2': 'dma8eu', + 'relhum': 'average_values', 'u': 'average_values', 'v': 'average_values', + 'cloudcover': 'average_values', 'pblheight': 'maximum', + 'temp': 'maximum'} +data_origin = {'o3': '', 'no': '', 'no2': '', + 'relhum': 'REA', 'u': 'REA', 'v': 'REA', + 'cloudcover': 'REA', 'pblheight': 'REA', + 'temp': 'REA'} + def main(parser_args): - args = dict(sampling="daily", - sampling_inputs="hourly", - window_history_size=24, + args = dict(stations=["DEBW107", "DEBW013"], + network="UBA", + evaluate_bootstraps=False, plot_list=[], + data_origin=data_origin, data_handler=DataHandlerMixedSampling, + interpolation_limit=(3, 1), overwrite_local_data=False, + sampling=("hourly", "daily"), + statistics_per_var=stats, + create_new_model=False, train_model=False, epochs=1, + window_history_size=48, window_history_offset=17, - **parser_args.__dict__, - data_handler=DataHandlerMixedSampling, kz_filter_length=[100 * 24, 15 * 24], kz_filter_iter=[4, 5], start="2006-01-01", train_start="2006-01-01", end="2011-12-31", test_end="2011-12-31", - stations=["DEBW107", "DEBW013"], - epochs=1, - network="UBA", + **parser_args.__dict__, ) workflow = DefaultWorkflow(**args) workflow.run() diff --git a/test/test_data_handler/test_data_handler_mixed_sampling.py b/test/test_data_handler/test_data_handler_mixed_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..d2f9ce00224a61815c89e44b7c37a667d239b2f5 --- /dev/null +++ b/test/test_data_handler/test_data_handler_mixed_sampling.py @@ -0,0 +1,130 @@ +__author__ = 'Lukas Leufen' +__date__ = '2020-12-10' + +from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, \ + DataHandlerMixedSamplingSingleStation, DataHandlerMixedSamplingWithFilter, \ + DataHandlerMixedSamplingWithFilterSingleStation, DataHandlerSeparationOfScales, \ + DataHandlerSeparationOfScalesSingleStation +from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleStation +from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation +from mlair.helpers import remove_items +from mlair.configuration.defaults import DEFAULT_INTERPOLATION_METHOD + +import pytest +import mock + + +class TestDataHandlerMixedSampling: + + def test_data_handler(self): + obj = object.__new__(DataHandlerMixedSampling) + assert obj.data_handler.__qualname__ == DataHandlerMixedSamplingSingleStation.__qualname__ + + def test_data_handler_transformation(self): + obj = object.__new__(DataHandlerMixedSampling) + assert obj.data_handler_transformation.__qualname__ == DataHandlerMixedSamplingSingleStation.__qualname__ + + def test_requirements(self): + obj = object.__new__(DataHandlerMixedSampling) + req = object.__new__(DataHandlerSingleStation) + assert sorted(obj._requirements) == sorted(remove_items(req.requirements(), "station")) + + +class TestDataHandlerMixedSamplingSingleStation: + + def test_requirements(self): + obj = object.__new__(DataHandlerMixedSamplingSingleStation) + req = object.__new__(DataHandlerSingleStation) + assert sorted(obj._requirements) == sorted(remove_items(req.requirements(), "station")) + + @mock.patch("mlair.data_handler.data_handler_mixed_sampling.DataHandlerMixedSamplingSingleStation.setup_samples") + def test_init(self, mock_super_init): + obj = DataHandlerMixedSamplingSingleStation("first_arg", "second", {}, test=23, sampling="hourly", + interpolation_limit=(1, 10)) + assert obj.sampling == ("hourly", "hourly") + assert obj.interpolation_limit == (1, 10) + assert obj.interpolation_method == (DEFAULT_INTERPOLATION_METHOD, DEFAULT_INTERPOLATION_METHOD) + + @pytest.fixture + def kwargs_dict(self): + return {"test1": 2, "param_2": "string", "another": (10, 2)} + + def test_update_kwargs_single_to_tuple(self, kwargs_dict): + obj = object.__new__(DataHandlerMixedSamplingSingleStation) + obj.update_kwargs("test1", "23", kwargs_dict) + assert kwargs_dict["test1"] == (2, 2) + obj.update_kwargs("param_2", "23", kwargs_dict) + assert kwargs_dict["param_2"] == ("string", "string") + + def test_update_kwargs_tuple(self, kwargs_dict): + obj = object.__new__(DataHandlerMixedSamplingSingleStation) + obj.update_kwargs("another", "23", kwargs_dict) + assert kwargs_dict["another"] == (10, 2) + + def test_update_kwargs_default(self, kwargs_dict): + obj = object.__new__(DataHandlerMixedSamplingSingleStation) + obj.update_kwargs("not_existing", "23", kwargs_dict) + assert kwargs_dict["not_existing"] == ("23", "23") + obj.update_kwargs("also_new", (4, 2), kwargs_dict) + assert kwargs_dict["also_new"] == (4, 2) + + def test_update_kwargs_assert_failure(self, kwargs_dict): + obj = object.__new__(DataHandlerMixedSamplingSingleStation) + with pytest.raises(AssertionError): + obj.update_kwargs("error_too_long", (1, 2, 3), kwargs_dict) + + def test_setup_samples(self): + pass + + def test_load_and_interpolate(self): + pass + + def test_set_inputs_and_targets(self): + pass + + def test_setup_data_path(self): + pass + + +class TestDataHandlerMixedSamplingWithFilter: + + def test_data_handler(self): + obj = object.__new__(DataHandlerMixedSamplingWithFilter) + assert obj.data_handler.__qualname__ == DataHandlerMixedSamplingWithFilterSingleStation.__qualname__ + + def test_data_handler_transformation(self): + obj = object.__new__(DataHandlerMixedSamplingWithFilter) + assert obj.data_handler_transformation.__qualname__ == DataHandlerMixedSamplingWithFilterSingleStation.__qualname__ + + def test_requirements(self): + obj = object.__new__(DataHandlerMixedSamplingWithFilter) + req1 = object.__new__(DataHandlerMixedSamplingSingleStation) + req2 = object.__new__(DataHandlerKzFilterSingleStation) + req = list(set(req1.requirements() + req2.requirements())) + assert sorted(obj._requirements) == sorted(remove_items(req, "station")) + + +class TestDataHandlerMixedSamplingWithFilterSingleStation: + pass + + +class TestDataHandlerSeparationOfScales: + + def test_data_handler(self): + obj = object.__new__(DataHandlerSeparationOfScales) + assert obj.data_handler.__qualname__ == DataHandlerSeparationOfScalesSingleStation.__qualname__ + + def test_data_handler_transformation(self): + obj = object.__new__(DataHandlerSeparationOfScales) + assert obj.data_handler_transformation.__qualname__ == DataHandlerSeparationOfScalesSingleStation.__qualname__ + + def test_requirements(self): + obj = object.__new__(DataHandlerMixedSamplingWithFilter) + req1 = object.__new__(DataHandlerMixedSamplingSingleStation) + req2 = object.__new__(DataHandlerKzFilterSingleStation) + req = list(set(req1.requirements() + req2.requirements())) + assert sorted(obj._requirements) == sorted(remove_items(req, "station")) + + +class TestDataHandlerSeparationOfScalesSingleStation: + pass