From 43f7d659923ab16787857587dabb3936323d534a Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Thu, 20 Jan 2022 16:38:20 +0100
Subject: [PATCH] use class variables for chem and meteo names, is now able to
 distinguish for all parameters of the data handler if provided as dict with
 respective keys, only use test stations for block mse calculation

---
 .../data_handler_mixed_sampling.py            | 48 ++++++++++++-------
 mlair/run_modules/post_processing.py          |  2 +-
 2 files changed, 31 insertions(+), 19 deletions(-)

diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py
index 8633c647..908ae409 100644
--- a/mlair/data_handler/data_handler_mixed_sampling.py
+++ b/mlair/data_handler/data_handler_mixed_sampling.py
@@ -244,11 +244,11 @@ class DataHandlerMixedSamplingWithClimateFirFilter(DataHandlerClimateFirFilter):
     def build(cls, station: str, **kwargs):
         sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler.requirements() if k in kwargs}
         filter_add_unfiltered = kwargs.get("filter_add_unfiltered", False)
-        sp_keys = cls.build_update_kwargs(sp_keys, dh_type="filtered")
+        sp_keys = cls.build_update_transformation(sp_keys, dh_type="filtered")
         sp = cls.data_handler(station, **sp_keys)
         if filter_add_unfiltered is True:
             sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_unfiltered.requirements() if k in kwargs}
-            sp_keys = cls.build_update_kwargs(sp_keys, dh_type="unfiltered")
+            sp_keys = cls.build_update_transformation(sp_keys, dh_type="unfiltered")
             sp_unfiltered = cls.data_handler_unfiltered(station, **sp_keys)
         else:
             sp_unfiltered = None
@@ -256,7 +256,7 @@ class DataHandlerMixedSamplingWithClimateFirFilter(DataHandlerClimateFirFilter):
         return cls(sp, data_handler_class_unfiltered=sp_unfiltered, **dp_args)
 
     @classmethod
-    def build_update_kwargs(cls, kwargs_dict, dh_type="filtered"):
+    def build_update_transformation(cls, kwargs_dict, dh_type="filtered"):
         if "transformation" in kwargs_dict:
             trafo_opts = kwargs_dict.get("transformation")
             if isinstance(trafo_opts, dict):
@@ -313,6 +313,8 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi
     data_handler_unfiltered = DataHandlerMixedSamplingSingleStation
     _requirements = list(set(data_handler_climate_fir.requirements() + data_handler_fir[0].requirements() +
                              data_handler_fir[1].requirements() + data_handler_unfiltered.requirements()))
+    chem_indicator = "chem"
+    meteo_indicator = "meteo"
 
     def __init__(self, data_handler_class_chem, data_handler_class_meteo, data_handler_class_chem_unfiltered,
                  data_handler_class_meteo_unfiltered, chem_vars, meteo_vars, *args, **kwargs):
@@ -351,32 +353,32 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi
 
         if len(chem_vars) > 0:
             sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_climate_fir.requirements() if k in kwargs}
-            sp_keys = cls.build_update_kwargs(sp_keys, dh_type="filtered_chem")
+            sp_keys = cls.build_update_transformation(sp_keys, dh_type="filtered_chem")
 
-            cls.prepare_build(sp_keys, chem_vars, "chem")
+            cls.prepare_build(sp_keys, chem_vars, cls.chem_indicator)
             sp_chem = cls.data_handler_climate_fir(station, **sp_keys)
             if filter_add_unfiltered is True:
                 sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_unfiltered.requirements() if k in kwargs}
-                sp_keys = cls.build_update_kwargs(sp_keys, dh_type="unfiltered_chem")
-                cls.prepare_build(sp_keys, chem_vars, "chem")
+                sp_keys = cls.build_update_transformation(sp_keys, dh_type="unfiltered_chem")
+                cls.prepare_build(sp_keys, chem_vars, cls.chem_indicator)
                 sp_chem_unfiltered = cls.data_handler_unfiltered(station, **sp_keys)
         if len(meteo_vars) > 0:
             if cls.data_handler_fir_pos is None:
                 if "extend_length_opts" in kwargs:
-                    if isinstance(kwargs["extend_length_opts"], dict) and "meteo" not in kwargs["extend_length_opts"].keys():
+                    if isinstance(kwargs["extend_length_opts"], dict) and cls.meteo_indicator not in kwargs["extend_length_opts"].keys():
                         cls.data_handler_fir_pos = 0  # use faster fir version without climate estimate
                     else:
                         cls.data_handler_fir_pos = 1  # use slower fir version with climate estimate
                 else:
                     cls.data_handler_fir_pos = 0  # use faster fir version without climate estimate
             sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_fir[cls.data_handler_fir_pos].requirements() if k in kwargs}
-            sp_keys = cls.build_update_kwargs(sp_keys, dh_type="filtered_meteo")
-            cls.prepare_build(sp_keys, meteo_vars, "meteo")
+            sp_keys = cls.build_update_transformation(sp_keys, dh_type="filtered_meteo")
+            cls.prepare_build(sp_keys, meteo_vars, cls.meteo_indicator)
             sp_meteo = cls.data_handler_fir[cls.data_handler_fir_pos](station, **sp_keys)
             if filter_add_unfiltered is True:
                 sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_unfiltered.requirements() if k in kwargs}
-                sp_keys = cls.build_update_kwargs(sp_keys, dh_type="unfiltered_meteo")
-                cls.prepare_build(sp_keys, meteo_vars, "meteo")
+                sp_keys = cls.build_update_transformation(sp_keys, dh_type="unfiltered_meteo")
+                cls.prepare_build(sp_keys, meteo_vars, cls.meteo_indicator)
                 sp_meteo_unfiltered = cls.data_handler_unfiltered(station, **sp_keys)
 
         dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs}
@@ -385,10 +387,20 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi
     @classmethod
     def prepare_build(cls, kwargs, var_list, var_type):
         kwargs.update({"variables": var_list})
-        cls.adjust_window_opts(var_type, "window_history_size", kwargs)
-        cls.adjust_window_opts(var_type, "window_history_offset", kwargs)
-        cls.adjust_window_opts(var_type, "window_history_end", kwargs)
-        cls.adjust_window_opts(var_type, "extend_length_opts", kwargs)
+        for k in list(kwargs.keys()):
+            v = kwargs[k]
+            if isinstance(v, dict):
+                if len(set(v.keys()).intersection({cls.chem_indicator, cls.meteo_indicator})) > 0:
+                    try:
+                        new_v = kwargs.pop(k)
+                        kwargs[k] = new_v[var_type]
+                    except KeyError:
+                        pass
+        #
+        # cls.adjust_window_opts(var_type, "window_history_size", kwargs)
+        # cls.adjust_window_opts(var_type, "window_history_offset", kwargs)
+        # cls.adjust_window_opts(var_type, "window_history_end", kwargs)
+        # cls.adjust_window_opts(var_type, "extend_length_opts", kwargs)
 
     @staticmethod
     def adjust_window_opts(key: str, parameter_name: str, kwargs: dict):
@@ -420,7 +432,7 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi
         # chem transformation
         if len(chem_vars) > 0:
             kwargs_chem = copy.deepcopy(kwargs)
-            cls.prepare_build(kwargs_chem, chem_vars, "chem")
+            cls.prepare_build(kwargs_chem, chem_vars, cls.chem_indicator)
             dh_transformation = (cls.data_handler_climate_fir, cls.data_handler_unfiltered)
             transformation_chem = super().transformation(set_stations, tmp_path=tmp_path,
                                                          dh_transformation=dh_transformation, **kwargs_chem)
@@ -428,7 +440,7 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi
         # meteo transformation
         if len(meteo_vars) > 0:
             kwargs_meteo = copy.deepcopy(kwargs)
-            cls.prepare_build(kwargs_meteo, meteo_vars, "meteo")
+            cls.prepare_build(kwargs_meteo, meteo_vars, cls.meteo_indicator)
             dh_transformation = (cls.data_handler_fir[cls.data_handler_fir_pos or 0], cls.data_handler_unfiltered)
             transformation_meteo = super().transformation(set_stations, tmp_path=tmp_path,
                                                           dh_transformation=dh_transformation, **kwargs_meteo)
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index dfcc9edb..5f2873a5 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -184,7 +184,7 @@ class PostProcessing(RunEnvironment):
         against the number of observations and diversity ot stations.
         """
         path = self.data_store.get("forecast_path")
-        all_stations = self.data_store.get("stations")
+        all_stations = self.data_store.get("stations", "test")
         start = self.data_store.get("start", "test")
         end = self.data_store.get("end", "test")
         index_dim = self.index_dim
-- 
GitLab