From ab2f51cb7171a201c8aec9d8236765d6cee6adcf Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Thu, 2 Dec 2021 11:23:13 +0100 Subject: [PATCH] data handlers do not longer require to state dependencies of super classes (only of sub data handlers like single stations), single station data handlers should have now empty _requirements as all dependencies are covered by own_args --- .../data_handler_mixed_sampling.py | 18 +------- .../data_handler/data_handler_with_filter.py | 10 ++--- mlair/run_modules/post_processing.py | 2 +- .../test_data_handler_mixed_sampling.py | 30 +++++++------ .../test_data_handler_with_filter.py | 42 +++++++++---------- .../test_default_data_handler.py | 2 + ...est_default_data_handler_single_station.py | 1 + 7 files changed, 42 insertions(+), 63 deletions(-) diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py index c6b612a7..91f7ec01 100644 --- a/mlair/data_handler/data_handler_mixed_sampling.py +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -21,8 +21,6 @@ import xarray as xr class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation): - _requirements = DataHandlerSingleStation.requirements() - def __init__(self, *args, **kwargs): """ This data handler requires the kwargs sampling, interpolation_limit, and interpolation_method to be a 2D tuple @@ -97,9 +95,6 @@ class DataHandlerMixedSampling(DefaultDataHandler): class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSingleStation, DataHandlerFilterSingleStation): - _requirements1 = DataHandlerFilterSingleStation.requirements() - _requirements2 = DataHandlerMixedSamplingSingleStation.requirements() - _requirements = list(set(_requirements1 + _requirements2)) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -168,10 +163,6 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi class DataHandlerMixedSamplingWithFirFilterSingleStation(DataHandlerMixedSamplingWithFilterSingleStation, DataHandlerFirFilterSingleStation): - _requirements1 = DataHandlerFirFilterSingleStation.requirements() - _requirements2 = DataHandlerMixedSamplingWithFilterSingleStation.requirements() - _requirements = list(set(_requirements1 + _requirements2)) - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -211,14 +202,11 @@ class DataHandlerMixedSamplingWithFirFilter(DataHandlerFirFilter): data_handler = DataHandlerMixedSamplingWithFirFilterSingleStation data_handler_transformation = DataHandlerMixedSamplingWithFirFilterSingleStation - _requirements = list(set(data_handler.requirements() + DataHandlerFirFilter.requirements())) + _requirements = data_handler.requirements() class DataHandlerMixedSamplingWithClimateFirFilterSingleStation(DataHandlerClimateFirFilterSingleStation, DataHandlerMixedSamplingWithFirFilterSingleStation): - _requirements1 = DataHandlerClimateFirFilterSingleStation.requirements() - _requirements2 = DataHandlerMixedSamplingWithFirFilterSingleStation.requirements() - _requirements = list(set(_requirements1 + _requirements2)) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -235,9 +223,7 @@ class DataHandlerMixedSamplingWithClimateFirFilter(DataHandlerClimateFirFilter): data_handler = DataHandlerMixedSamplingWithClimateFirFilterSingleStation data_handler_transformation = DataHandlerMixedSamplingWithClimateFirFilterSingleStation data_handler_unfiltered = DataHandlerMixedSamplingSingleStation - _requirements1 = data_handler.requirements() + data_handler_unfiltered.requirements() - _requirements2 = DataHandlerClimateFirFilter.requirements() - _requirements = list(set(_requirements1 + _requirements2)) + _requirements = list(set(data_handler.requirements() + data_handler_unfiltered.requirements())) DEFAULT_FILTER_ADD_UNFILTERED = False def __init__(self, *args, data_handler_class_unfiltered: data_handler_unfiltered = None, diff --git a/mlair/data_handler/data_handler_with_filter.py b/mlair/data_handler/data_handler_with_filter.py index 1925015a..a522f53b 100644 --- a/mlair/data_handler/data_handler_with_filter.py +++ b/mlair/data_handler/data_handler_with_filter.py @@ -38,7 +38,6 @@ str_or_list = Union[str, List[str]] class DataHandlerFilterSingleStation(DataHandlerSingleStation): """General data handler for a single station to be used by a superior data handler.""" - _requirements = DataHandlerSingleStation.requirements() _hash = DataHandlerSingleStation._hash + ["filter_dim"] DEFAULT_FILTER_DIM = "filter" @@ -111,7 +110,7 @@ class DataHandlerFilter(DefaultDataHandler): data_handler = DataHandlerFilterSingleStation data_handler_transformation = DataHandlerFilterSingleStation - _requirements = list(set(data_handler.requirements() + DefaultDataHandler.requirements())) + _requirements = data_handler.requirements() def __init__(self, *args, use_filter_branches=False, **kwargs): self.use_filter_branches = use_filter_branches @@ -121,7 +120,6 @@ class DataHandlerFilter(DefaultDataHandler): class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): """Data handler for a single station to be used by a superior data handler. Inputs are FIR filtered.""" - _requirements = DataHandlerFilterSingleStation.requirements() _hash = DataHandlerFilterSingleStation._hash + ["filter_cutoff_period", "filter_order", "filter_window_type"] DEFAULT_WINDOW_TYPE = ("kaiser", 5) @@ -306,7 +304,7 @@ class DataHandlerFirFilter(DataHandlerFilter): data_handler = DataHandlerFirFilterSingleStation data_handler_transformation = DataHandlerFirFilterSingleStation - _requirements = list(set(data_handler.requirements() + DataHandlerFilter.requirements())) + _requirements = data_handler.requirements() class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation): @@ -327,8 +325,6 @@ class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation :param apriori_diurnal: use diurnal anomalies of each hour as addition to the apriori information type chosen by parameter apriori_type. This is only applicable for hourly resolution data. """ - - _requirements = DataHandlerFirFilterSingleStation.requirements() _hash = DataHandlerFirFilterSingleStation._hash + ["apriori_type", "apriori_sel_opts", "apriori_diurnal", "extend_length_opts"] _store_attributes = DataHandlerFirFilterSingleStation.store_attributes() + ["apriori"] @@ -471,6 +467,6 @@ class DataHandlerClimateFirFilter(DataHandlerFilter): data_handler = DataHandlerClimateFirFilterSingleStation data_handler_transformation = DataHandlerClimateFirFilterSingleStation - _requirements = list(set(data_handler.requirements() + DataHandlerFilter.requirements())) + _requirements = data_handler.requirements() _store_attributes = data_handler.store_attributes() diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 7f2b3b59..c67645fb 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -630,6 +630,7 @@ class PostProcessing(RunEnvironment): logging.info(f"start make_prediction for {subset_type}") time_dimension = self.data_store.get("time_dim") window_dim = self.data_store.get("window_dim") + path = self.data_store.get("forecast_path") subset_type = subset.name for i, data in enumerate(subset): input_data = data.get_X() @@ -669,7 +670,6 @@ class PostProcessing(RunEnvironment): **prediction_dict) # save all forecasts locally - path = self.data_store.get("forecast_path") prefix = "forecasts_norm" if normalised is True else "forecasts" file = os.path.join(path, f"{prefix}_{str(data)}_{subset_type}.nc") all_predictions.to_netcdf(file) diff --git a/test/test_data_handler/test_data_handler_mixed_sampling.py b/test/test_data_handler/test_data_handler_mixed_sampling.py index 6a4f5da2..0515278a 100644 --- a/test/test_data_handler/test_data_handler_mixed_sampling.py +++ b/test/test_data_handler/test_data_handler_mixed_sampling.py @@ -40,9 +40,8 @@ class TestDataHandlerMixedSampling: class TestDataHandlerMixedSamplingSingleStation: def test_requirements(self): - reqs = get_all_args(DataHandlerMixedSamplingSingleStation) + reqs = get_all_args(DataHandlerSingleStation) obj = object.__new__(DataHandlerMixedSamplingSingleStation) - assert reqs == ["self"] assert sorted(obj.own_args()) == reqs reqs = get_all_args(DataHandlerSingleStation, remove="self") assert sorted(obj.requirements()) == reqs @@ -100,49 +99,48 @@ class TestDataHandlerMixedSamplingWithFilterSingleStation: def test_requirements(self): - reqs = get_all_args(DataHandlerMixedSamplingWithFilterSingleStation) + reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation) obj = object.__new__(DataHandlerMixedSamplingWithFilterSingleStation) - assert reqs == ["self"] assert sorted(obj.own_args()) == reqs reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation, remove="self") - assert sorted(obj._requirements) == reqs + assert sorted(obj._requirements) == [] assert sorted(obj.requirements()) == reqs class TestDataHandlerMixedSamplingWithFirFilter: def test_requirements(self): - reqs = get_all_args(DataHandlerFilter) + reqs = get_all_args(DataHandlerFilter, DefaultDataHandler) obj = object.__new__(DataHandlerMixedSamplingWithFirFilter) assert sorted(obj.own_args()) == reqs + reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation, DataHandlerFirFilterSingleStation, + remove=["self"]) + assert sorted(obj._requirements) == reqs reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation, DataHandlerFilter, DataHandlerFirFilterSingleStation, DefaultDataHandler, remove=["self", "id_class"]) - assert sorted(obj._requirements) == reqs assert sorted(obj.requirements()) == reqs class TestDataHandlerMixedSamplingWithFirFilterSingleStation: def test_requirements(self): - reqs = get_all_args(DataHandlerMixedSamplingWithFirFilterSingleStation) + reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation) obj = object.__new__(DataHandlerMixedSamplingWithFirFilterSingleStation) - assert reqs == ["self"] assert sorted(obj.own_args()) == reqs + assert sorted(obj._requirements) == [] reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerFirFilterSingleStation, DataHandlerSingleStation, remove="self") - assert sorted(obj._requirements) == reqs assert sorted(obj.requirements()) == reqs class TestDataHandlerMixedSamplingWithClimateFirFilter: def test_requirements(self): - reqs = get_all_args(DataHandlerMixedSamplingWithClimateFirFilter) + reqs = get_all_args(DataHandlerMixedSamplingWithClimateFirFilter, DataHandlerFilter, DefaultDataHandler) obj = object.__new__(DataHandlerMixedSamplingWithClimateFirFilter) assert sorted(obj.own_args()) == reqs reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerClimateFirFilterSingleStation, - DataHandlerSingleStation, DataHandlerFilter, DataHandlerFirFilterSingleStation, - DefaultDataHandler, remove=["self", "id_class"]) + DataHandlerSingleStation, DataHandlerFirFilterSingleStation, remove=["self"]) assert sorted(obj._requirements) == reqs reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerClimateFirFilterSingleStation, DataHandlerSingleStation, DataHandlerFilter, DataHandlerMixedSamplingWithClimateFirFilter, @@ -153,13 +151,13 @@ class TestDataHandlerMixedSamplingWithClimateFirFilter: class TestDataHandlerMixedSamplingWithClimateFirFilterSingleStation: def test_requirements(self): - reqs = get_all_args(DataHandlerMixedSamplingWithClimateFirFilterSingleStation) + reqs = get_all_args(DataHandlerClimateFirFilterSingleStation, DataHandlerFirFilterSingleStation, + DataHandlerFilterSingleStation, DataHandlerSingleStation) obj = object.__new__(DataHandlerMixedSamplingWithClimateFirFilterSingleStation) - assert reqs == ["self"] assert sorted(obj.own_args()) == reqs + assert sorted(obj._requirements) == [] reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerFirFilterSingleStation, DataHandlerSingleStation, DataHandlerClimateFirFilterSingleStation, remove="self") - assert sorted(obj._requirements) == reqs assert sorted(obj.requirements()) == reqs diff --git a/test/test_data_handler/test_data_handler_with_filter.py b/test/test_data_handler/test_data_handler_with_filter.py index faa1076c..b83effd9 100644 --- a/test/test_data_handler/test_data_handler_with_filter.py +++ b/test/test_data_handler/test_data_handler_with_filter.py @@ -11,11 +11,10 @@ from mlair.helpers.testing import get_all_args class TestDataHandlerFilter: def test_requirements(self): - reqs = get_all_args(DataHandlerFilter) + reqs = get_all_args(DataHandlerFilter, DefaultDataHandler) obj = object.__new__(DataHandlerFilter) assert sorted(obj.own_args()) == reqs - reqs = get_all_args(DataHandlerSingleStation, DataHandlerFilterSingleStation, DefaultDataHandler, - remove=["self", "id_class"]) + reqs = get_all_args(DataHandlerSingleStation, DataHandlerFilterSingleStation, remove=["self"]) assert sorted(obj._requirements) == reqs reqs = get_all_args(DataHandlerSingleStation, DataHandlerFilterSingleStation, DefaultDataHandler, DataHandlerFilter, remove=["self", "id_class"]) @@ -25,25 +24,25 @@ class TestDataHandlerFilter: class TestDataHandlerFilterSingleStation: def test_requirements(self): - reqs = get_all_args(DataHandlerFilterSingleStation) + reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation) obj = object.__new__(DataHandlerFilterSingleStation) - assert sorted(reqs) == sorted(["filter_dim", "self"]) assert sorted(obj.own_args()) == reqs - reqs = get_all_args(DataHandlerSingleStation, remove="self") - assert sorted(obj._requirements) == reqs - reqs = get_all_args(DataHandlerSingleStation, DataHandlerFilterSingleStation, remove="self") + assert sorted(obj._requirements) == [] + reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation, remove="self") assert sorted(obj.requirements()) == reqs class TestDataHandlerFirFilter: def test_requirements(self): - reqs = get_all_args(DataHandlerFilter) + reqs = get_all_args(DataHandlerFilter, DefaultDataHandler) obj = object.__new__(DataHandlerFirFilter) assert sorted(obj.own_args()) == reqs reqs = get_all_args(DataHandlerSingleStation, DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, - DataHandlerFilter, DefaultDataHandler, remove=["self", "id_class"]) + remove=["self"]) assert sorted(obj._requirements) == reqs + reqs = get_all_args(DataHandlerSingleStation, DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, + DataHandlerFilter, DefaultDataHandler, remove=["self", "id_class"]) assert sorted(obj.requirements()) == reqs @@ -51,13 +50,11 @@ class TestDataHandlerFirFilterSingleStation: def test_requirements(self): - reqs = get_all_args(DataHandlerFirFilterSingleStation) + reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation) obj = object.__new__(DataHandlerFirFilterSingleStation) - assert sorted(reqs) == sorted(["filter_cutoff_period", "filter_order", "filter_window_type", "self"]) assert sorted(obj.own_args()) == reqs - reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation, remove="self") - assert sorted(obj._requirements) == reqs - reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation, DataHandlerFirFilterSingleStation, + assert sorted(obj._requirements) == [] + reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation, remove="self") assert sorted(obj.requirements()) == reqs @@ -65,27 +62,26 @@ class TestDataHandlerFirFilterSingleStation: class TestDataHandlerClimateFirFilter: def test_requirements(self): - reqs = get_all_args(DataHandlerFilter) + reqs = get_all_args(DataHandlerFilter, DefaultDataHandler) obj = object.__new__(DataHandlerClimateFirFilter) assert sorted(obj.own_args()) == reqs + reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation, + DataHandlerClimateFirFilterSingleStation, remove="self") + assert sorted(obj._requirements) == reqs reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation, DataHandlerClimateFirFilterSingleStation, DefaultDataHandler, DataHandlerFilter, remove=["self", "id_class"]) - assert sorted(obj._requirements) == reqs assert sorted(obj.requirements()) == reqs class TestDataHandlerClimateFirFilterSingleStation: def test_requirements(self): - reqs = get_all_args(DataHandlerClimateFirFilterSingleStation) + reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation, + DataHandlerClimateFirFilterSingleStation) obj = object.__new__(DataHandlerClimateFirFilterSingleStation) - assert sorted(reqs) == sorted(["apriori", "apriori_type", "apriori_diurnal", "apriori_sel_opts", "plot_path", - "name_affix", "self"]) assert sorted(obj.own_args()) == reqs - reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation, - remove="self") - assert sorted(obj._requirements) == reqs + assert sorted(obj._requirements) == [] reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation, DataHandlerClimateFirFilterSingleStation, remove="self") assert sorted(obj.requirements()) == reqs diff --git a/test/test_data_handler/test_default_data_handler.py b/test/test_data_handler/test_default_data_handler.py index 8f79083b..1e0a5db3 100644 --- a/test/test_data_handler/test_default_data_handler.py +++ b/test/test_data_handler/test_default_data_handler.py @@ -10,6 +10,8 @@ class TestDefaultDataHandler: reqs = get_all_args(DefaultDataHandler) obj = object.__new__(DefaultDataHandler) assert sorted(obj.own_args()) == reqs + reqs = get_all_args(DataHandlerSingleStation, remove="self") + assert sorted(obj._requirements) == reqs reqs = get_all_args(DefaultDataHandler, DataHandlerSingleStation, remove=["self", "id_class"]) assert sorted(obj.requirements()) == reqs reqs = get_all_args(DefaultDataHandler, DataHandlerSingleStation, remove=["self", "id_class", "station"]) diff --git a/test/test_data_handler/test_default_data_handler_single_station.py b/test/test_data_handler/test_default_data_handler_single_station.py index e22ab9c2..fea8f9cb 100644 --- a/test/test_data_handler/test_default_data_handler_single_station.py +++ b/test/test_data_handler/test_default_data_handler_single_station.py @@ -11,4 +11,5 @@ class TestDataHandlerSingleStation: reqs = get_all_args(DataHandlerSingleStation) obj = object.__new__(DataHandlerSingleStation) assert sorted(obj.own_args()) == reqs + assert obj._requirements == [] assert sorted(obj.requirements()) == remove_items(reqs, "self") -- GitLab