From ab2f51cb7171a201c8aec9d8236765d6cee6adcf Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Thu, 2 Dec 2021 11:23:13 +0100
Subject: [PATCH] data handlers do not longer require to state dependencies of
 super classes (only of sub data handlers like single stations), single
 station data handlers should have now empty _requirements as all dependencies
 are covered by own_args

---
 .../data_handler_mixed_sampling.py            | 18 +-------
 .../data_handler/data_handler_with_filter.py  | 10 ++---
 mlair/run_modules/post_processing.py          |  2 +-
 .../test_data_handler_mixed_sampling.py       | 30 +++++++------
 .../test_data_handler_with_filter.py          | 42 +++++++++----------
 .../test_default_data_handler.py              |  2 +
 ...est_default_data_handler_single_station.py |  1 +
 7 files changed, 42 insertions(+), 63 deletions(-)

diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py
index c6b612a7..91f7ec01 100644
--- a/mlair/data_handler/data_handler_mixed_sampling.py
+++ b/mlair/data_handler/data_handler_mixed_sampling.py
@@ -21,8 +21,6 @@ import xarray as xr
 
 class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
 
-    _requirements = DataHandlerSingleStation.requirements()
-
     def __init__(self, *args, **kwargs):
         """
         This data handler requires the kwargs sampling, interpolation_limit, and interpolation_method to be a 2D tuple
@@ -97,9 +95,6 @@ class DataHandlerMixedSampling(DefaultDataHandler):
 
 class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSingleStation,
                                                       DataHandlerFilterSingleStation):
-    _requirements1 = DataHandlerFilterSingleStation.requirements()
-    _requirements2 = DataHandlerMixedSamplingSingleStation.requirements()
-    _requirements = list(set(_requirements1 + _requirements2))
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -168,10 +163,6 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
 class DataHandlerMixedSamplingWithFirFilterSingleStation(DataHandlerMixedSamplingWithFilterSingleStation,
                                                          DataHandlerFirFilterSingleStation):
 
-    _requirements1 = DataHandlerFirFilterSingleStation.requirements()
-    _requirements2 = DataHandlerMixedSamplingWithFilterSingleStation.requirements()
-    _requirements = list(set(_requirements1 + _requirements2))
-
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
@@ -211,14 +202,11 @@ class DataHandlerMixedSamplingWithFirFilter(DataHandlerFirFilter):
 
     data_handler = DataHandlerMixedSamplingWithFirFilterSingleStation
     data_handler_transformation = DataHandlerMixedSamplingWithFirFilterSingleStation
-    _requirements = list(set(data_handler.requirements() + DataHandlerFirFilter.requirements()))
+    _requirements = data_handler.requirements()
 
 
 class DataHandlerMixedSamplingWithClimateFirFilterSingleStation(DataHandlerClimateFirFilterSingleStation,
                                                                 DataHandlerMixedSamplingWithFirFilterSingleStation):
-    _requirements1 = DataHandlerClimateFirFilterSingleStation.requirements()
-    _requirements2 = DataHandlerMixedSamplingWithFirFilterSingleStation.requirements()
-    _requirements = list(set(_requirements1 + _requirements2))
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -235,9 +223,7 @@ class DataHandlerMixedSamplingWithClimateFirFilter(DataHandlerClimateFirFilter):
     data_handler = DataHandlerMixedSamplingWithClimateFirFilterSingleStation
     data_handler_transformation = DataHandlerMixedSamplingWithClimateFirFilterSingleStation
     data_handler_unfiltered = DataHandlerMixedSamplingSingleStation
-    _requirements1 = data_handler.requirements() + data_handler_unfiltered.requirements()
-    _requirements2 = DataHandlerClimateFirFilter.requirements()
-    _requirements = list(set(_requirements1 + _requirements2))
+    _requirements = list(set(data_handler.requirements() + data_handler_unfiltered.requirements()))
     DEFAULT_FILTER_ADD_UNFILTERED = False
 
     def __init__(self, *args, data_handler_class_unfiltered: data_handler_unfiltered = None,
diff --git a/mlair/data_handler/data_handler_with_filter.py b/mlair/data_handler/data_handler_with_filter.py
index 1925015a..a522f53b 100644
--- a/mlair/data_handler/data_handler_with_filter.py
+++ b/mlair/data_handler/data_handler_with_filter.py
@@ -38,7 +38,6 @@ str_or_list = Union[str, List[str]]
 class DataHandlerFilterSingleStation(DataHandlerSingleStation):
     """General data handler for a single station to be used by a superior data handler."""
 
-    _requirements = DataHandlerSingleStation.requirements()
     _hash = DataHandlerSingleStation._hash + ["filter_dim"]
 
     DEFAULT_FILTER_DIM = "filter"
@@ -111,7 +110,7 @@ class DataHandlerFilter(DefaultDataHandler):
 
     data_handler = DataHandlerFilterSingleStation
     data_handler_transformation = DataHandlerFilterSingleStation
-    _requirements = list(set(data_handler.requirements() + DefaultDataHandler.requirements()))
+    _requirements = data_handler.requirements()
 
     def __init__(self, *args, use_filter_branches=False, **kwargs):
         self.use_filter_branches = use_filter_branches
@@ -121,7 +120,6 @@ class DataHandlerFilter(DefaultDataHandler):
 class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation):
     """Data handler for a single station to be used by a superior data handler. Inputs are FIR filtered."""
 
-    _requirements = DataHandlerFilterSingleStation.requirements()
     _hash = DataHandlerFilterSingleStation._hash + ["filter_cutoff_period", "filter_order", "filter_window_type"]
 
     DEFAULT_WINDOW_TYPE = ("kaiser", 5)
@@ -306,7 +304,7 @@ class DataHandlerFirFilter(DataHandlerFilter):
 
     data_handler = DataHandlerFirFilterSingleStation
     data_handler_transformation = DataHandlerFirFilterSingleStation
-    _requirements = list(set(data_handler.requirements() + DataHandlerFilter.requirements()))
+    _requirements = data_handler.requirements()
 
 
 class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation):
@@ -327,8 +325,6 @@ class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation
     :param apriori_diurnal: use diurnal anomalies of each hour as addition to the apriori information type chosen by
         parameter apriori_type. This is only applicable for hourly resolution data.
     """
-
-    _requirements = DataHandlerFirFilterSingleStation.requirements()
     _hash = DataHandlerFirFilterSingleStation._hash + ["apriori_type", "apriori_sel_opts", "apriori_diurnal",
                                                        "extend_length_opts"]
     _store_attributes = DataHandlerFirFilterSingleStation.store_attributes() + ["apriori"]
@@ -471,6 +467,6 @@ class DataHandlerClimateFirFilter(DataHandlerFilter):
 
     data_handler = DataHandlerClimateFirFilterSingleStation
     data_handler_transformation = DataHandlerClimateFirFilterSingleStation
-    _requirements = list(set(data_handler.requirements() + DataHandlerFilter.requirements()))
+    _requirements = data_handler.requirements()
     _store_attributes = data_handler.store_attributes()
 
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 7f2b3b59..c67645fb 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -630,6 +630,7 @@ class PostProcessing(RunEnvironment):
         logging.info(f"start make_prediction for {subset_type}")
         time_dimension = self.data_store.get("time_dim")
         window_dim = self.data_store.get("window_dim")
+        path = self.data_store.get("forecast_path")
         subset_type = subset.name
         for i, data in enumerate(subset):
             input_data = data.get_X()
@@ -669,7 +670,6 @@ class PostProcessing(RunEnvironment):
                                                               **prediction_dict)
 
                 # save all forecasts locally
-                path = self.data_store.get("forecast_path")
                 prefix = "forecasts_norm" if normalised is True else "forecasts"
                 file = os.path.join(path, f"{prefix}_{str(data)}_{subset_type}.nc")
                 all_predictions.to_netcdf(file)
diff --git a/test/test_data_handler/test_data_handler_mixed_sampling.py b/test/test_data_handler/test_data_handler_mixed_sampling.py
index 6a4f5da2..0515278a 100644
--- a/test/test_data_handler/test_data_handler_mixed_sampling.py
+++ b/test/test_data_handler/test_data_handler_mixed_sampling.py
@@ -40,9 +40,8 @@ class TestDataHandlerMixedSampling:
 class TestDataHandlerMixedSamplingSingleStation:
 
     def test_requirements(self):
-        reqs = get_all_args(DataHandlerMixedSamplingSingleStation)
+        reqs = get_all_args(DataHandlerSingleStation)
         obj = object.__new__(DataHandlerMixedSamplingSingleStation)
-        assert reqs == ["self"]
         assert sorted(obj.own_args()) == reqs
         reqs = get_all_args(DataHandlerSingleStation, remove="self")
         assert sorted(obj.requirements()) == reqs
@@ -100,49 +99,48 @@ class TestDataHandlerMixedSamplingWithFilterSingleStation:
 
     def test_requirements(self):
 
-        reqs = get_all_args(DataHandlerMixedSamplingWithFilterSingleStation)
+        reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation)
         obj = object.__new__(DataHandlerMixedSamplingWithFilterSingleStation)
-        assert reqs == ["self"]
         assert sorted(obj.own_args()) == reqs
         reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation, remove="self")
-        assert sorted(obj._requirements) == reqs
+        assert sorted(obj._requirements) == []
         assert sorted(obj.requirements()) == reqs
 
 
 class TestDataHandlerMixedSamplingWithFirFilter:
 
     def test_requirements(self):
-        reqs = get_all_args(DataHandlerFilter)
+        reqs = get_all_args(DataHandlerFilter, DefaultDataHandler)
         obj = object.__new__(DataHandlerMixedSamplingWithFirFilter)
         assert sorted(obj.own_args()) == reqs
+        reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation, DataHandlerFirFilterSingleStation,
+                            remove=["self"])
+        assert sorted(obj._requirements) == reqs
         reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation, DataHandlerFilter,
                             DataHandlerFirFilterSingleStation, DefaultDataHandler, remove=["self", "id_class"])
-        assert sorted(obj._requirements) == reqs
         assert sorted(obj.requirements()) == reqs
 
 
 class TestDataHandlerMixedSamplingWithFirFilterSingleStation:
 
     def test_requirements(self):
-        reqs = get_all_args(DataHandlerMixedSamplingWithFirFilterSingleStation)
+        reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation)
         obj = object.__new__(DataHandlerMixedSamplingWithFirFilterSingleStation)
-        assert reqs == ["self"]
         assert sorted(obj.own_args()) == reqs
+        assert sorted(obj._requirements) == []
         reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerFirFilterSingleStation, DataHandlerSingleStation,
                             remove="self")
-        assert sorted(obj._requirements) == reqs
         assert sorted(obj.requirements()) == reqs
 
 
 class TestDataHandlerMixedSamplingWithClimateFirFilter:
 
     def test_requirements(self):
-        reqs = get_all_args(DataHandlerMixedSamplingWithClimateFirFilter)
+        reqs = get_all_args(DataHandlerMixedSamplingWithClimateFirFilter, DataHandlerFilter, DefaultDataHandler)
         obj = object.__new__(DataHandlerMixedSamplingWithClimateFirFilter)
         assert sorted(obj.own_args()) == reqs
         reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerClimateFirFilterSingleStation,
-                            DataHandlerSingleStation, DataHandlerFilter, DataHandlerFirFilterSingleStation,
-                            DefaultDataHandler, remove=["self", "id_class"])
+                            DataHandlerSingleStation, DataHandlerFirFilterSingleStation, remove=["self"])
         assert sorted(obj._requirements) == reqs
         reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerClimateFirFilterSingleStation,
                             DataHandlerSingleStation, DataHandlerFilter, DataHandlerMixedSamplingWithClimateFirFilter,
@@ -153,13 +151,13 @@ class TestDataHandlerMixedSamplingWithClimateFirFilter:
 class TestDataHandlerMixedSamplingWithClimateFirFilterSingleStation:
 
     def test_requirements(self):
-        reqs = get_all_args(DataHandlerMixedSamplingWithClimateFirFilterSingleStation)
+        reqs = get_all_args(DataHandlerClimateFirFilterSingleStation, DataHandlerFirFilterSingleStation,
+                            DataHandlerFilterSingleStation, DataHandlerSingleStation)
         obj = object.__new__(DataHandlerMixedSamplingWithClimateFirFilterSingleStation)
-        assert reqs == ["self"]
         assert sorted(obj.own_args()) == reqs
+        assert sorted(obj._requirements) == []
         reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerFirFilterSingleStation, DataHandlerSingleStation,
                             DataHandlerClimateFirFilterSingleStation, remove="self")
-        assert sorted(obj._requirements) == reqs
         assert sorted(obj.requirements()) == reqs
 
 
diff --git a/test/test_data_handler/test_data_handler_with_filter.py b/test/test_data_handler/test_data_handler_with_filter.py
index faa1076c..b83effd9 100644
--- a/test/test_data_handler/test_data_handler_with_filter.py
+++ b/test/test_data_handler/test_data_handler_with_filter.py
@@ -11,11 +11,10 @@ from mlair.helpers.testing import get_all_args
 class TestDataHandlerFilter:
 
     def test_requirements(self):
-        reqs = get_all_args(DataHandlerFilter)
+        reqs = get_all_args(DataHandlerFilter, DefaultDataHandler)
         obj = object.__new__(DataHandlerFilter)
         assert sorted(obj.own_args()) == reqs
-        reqs = get_all_args(DataHandlerSingleStation, DataHandlerFilterSingleStation, DefaultDataHandler,
-                            remove=["self", "id_class"])
+        reqs = get_all_args(DataHandlerSingleStation, DataHandlerFilterSingleStation, remove=["self"])
         assert sorted(obj._requirements) == reqs
         reqs = get_all_args(DataHandlerSingleStation, DataHandlerFilterSingleStation, DefaultDataHandler,
                             DataHandlerFilter, remove=["self", "id_class"])
@@ -25,25 +24,25 @@ class TestDataHandlerFilter:
 class TestDataHandlerFilterSingleStation:
 
     def test_requirements(self):
-        reqs = get_all_args(DataHandlerFilterSingleStation)
+        reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation)
         obj = object.__new__(DataHandlerFilterSingleStation)
-        assert sorted(reqs) == sorted(["filter_dim", "self"])
         assert sorted(obj.own_args()) == reqs
-        reqs = get_all_args(DataHandlerSingleStation, remove="self")
-        assert sorted(obj._requirements) == reqs
-        reqs = get_all_args(DataHandlerSingleStation, DataHandlerFilterSingleStation, remove="self")
+        assert sorted(obj._requirements) == []
+        reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation, remove="self")
         assert sorted(obj.requirements()) == reqs
 
 
 class TestDataHandlerFirFilter:
 
     def test_requirements(self):
-        reqs = get_all_args(DataHandlerFilter)
+        reqs = get_all_args(DataHandlerFilter, DefaultDataHandler)
         obj = object.__new__(DataHandlerFirFilter)
         assert sorted(obj.own_args()) == reqs
         reqs = get_all_args(DataHandlerSingleStation, DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation,
-                            DataHandlerFilter, DefaultDataHandler, remove=["self", "id_class"])
+                            remove=["self"])
         assert sorted(obj._requirements) == reqs
+        reqs = get_all_args(DataHandlerSingleStation, DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation,
+                            DataHandlerFilter, DefaultDataHandler, remove=["self", "id_class"])
         assert sorted(obj.requirements()) == reqs
 
 
@@ -51,13 +50,11 @@ class TestDataHandlerFirFilterSingleStation:
 
     def test_requirements(self):
 
-        reqs = get_all_args(DataHandlerFirFilterSingleStation)
+        reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation)
         obj = object.__new__(DataHandlerFirFilterSingleStation)
-        assert sorted(reqs) == sorted(["filter_cutoff_period", "filter_order", "filter_window_type", "self"])
         assert sorted(obj.own_args()) == reqs
-        reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation, remove="self")
-        assert sorted(obj._requirements) == reqs
-        reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation, DataHandlerFirFilterSingleStation,
+        assert sorted(obj._requirements) == []
+        reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation,
                             remove="self")
         assert sorted(obj.requirements()) == reqs
 
@@ -65,27 +62,26 @@ class TestDataHandlerFirFilterSingleStation:
 class TestDataHandlerClimateFirFilter:
 
     def test_requirements(self):
-        reqs = get_all_args(DataHandlerFilter)
+        reqs = get_all_args(DataHandlerFilter, DefaultDataHandler)
         obj = object.__new__(DataHandlerClimateFirFilter)
         assert sorted(obj.own_args()) == reqs
+        reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation,
+                            DataHandlerClimateFirFilterSingleStation, remove="self")
+        assert sorted(obj._requirements) == reqs
         reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation,
                             DataHandlerClimateFirFilterSingleStation, DefaultDataHandler, DataHandlerFilter,
                             remove=["self", "id_class"])
-        assert sorted(obj._requirements) == reqs
         assert sorted(obj.requirements()) == reqs
 
 
 class TestDataHandlerClimateFirFilterSingleStation:
 
     def test_requirements(self):
-        reqs = get_all_args(DataHandlerClimateFirFilterSingleStation)
+        reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation,
+                            DataHandlerClimateFirFilterSingleStation)
         obj = object.__new__(DataHandlerClimateFirFilterSingleStation)
-        assert sorted(reqs) == sorted(["apriori", "apriori_type", "apriori_diurnal", "apriori_sel_opts", "plot_path",
-                                       "name_affix", "self"])
         assert sorted(obj.own_args()) == reqs
-        reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation,
-                            remove="self")
-        assert sorted(obj._requirements) == reqs
+        assert sorted(obj._requirements) == []
         reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation,
                             DataHandlerClimateFirFilterSingleStation, remove="self")
         assert sorted(obj.requirements()) == reqs
diff --git a/test/test_data_handler/test_default_data_handler.py b/test/test_data_handler/test_default_data_handler.py
index 8f79083b..1e0a5db3 100644
--- a/test/test_data_handler/test_default_data_handler.py
+++ b/test/test_data_handler/test_default_data_handler.py
@@ -10,6 +10,8 @@ class TestDefaultDataHandler:
         reqs = get_all_args(DefaultDataHandler)
         obj = object.__new__(DefaultDataHandler)
         assert sorted(obj.own_args()) == reqs
+        reqs = get_all_args(DataHandlerSingleStation, remove="self")
+        assert sorted(obj._requirements) == reqs
         reqs = get_all_args(DefaultDataHandler, DataHandlerSingleStation, remove=["self", "id_class"])
         assert sorted(obj.requirements()) == reqs
         reqs = get_all_args(DefaultDataHandler, DataHandlerSingleStation, remove=["self", "id_class", "station"])
diff --git a/test/test_data_handler/test_default_data_handler_single_station.py b/test/test_data_handler/test_default_data_handler_single_station.py
index e22ab9c2..fea8f9cb 100644
--- a/test/test_data_handler/test_default_data_handler_single_station.py
+++ b/test/test_data_handler/test_default_data_handler_single_station.py
@@ -11,4 +11,5 @@ class TestDataHandlerSingleStation:
         reqs = get_all_args(DataHandlerSingleStation)
         obj = object.__new__(DataHandlerSingleStation)
         assert sorted(obj.own_args()) == reqs
+        assert obj._requirements == []
         assert sorted(obj.requirements()) == remove_items(reqs, "self")
-- 
GitLab