From 98944e1779ef8e776e3ba2a3669429b23ee47135 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Fri, 14 Jan 2022 15:57:12 +0100
Subject: [PATCH] can now use external plot dates

---
 .../data_handler_mixed_sampling.py            | 50 ++++++++++---------
 .../data_handler/data_handler_with_filter.py  |  9 ++--
 mlair/helpers/filter.py                       | 32 ++++++------
 mlair/plotting/data_insight_plotting.py       |  4 +-
 4 files changed, 53 insertions(+), 42 deletions(-)

diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py
index 5466b6da..d9add56b 100644
--- a/mlair/data_handler/data_handler_mixed_sampling.py
+++ b/mlair/data_handler/data_handler_mixed_sampling.py
@@ -413,35 +413,39 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi
             return
 
         chem_vars, meteo_vars = cls._split_chem_and_meteo_variables(**kwargs)
-
+        transformation_chem, transformation_meteo = None, None
         # chem transformation
-        kwargs_chem = copy.deepcopy(kwargs)
-        cls.prepare_build(kwargs_chem, chem_vars, "chem")
-        dh_transformation = (cls.data_handler_climate_fir, cls.data_handler_unfiltered)
-        transformation_chem = super().transformation(set_stations, tmp_path=tmp_path,
-                                                     dh_transformation=dh_transformation, **kwargs_chem)
+        if len(chem_vars) > 0:
+            kwargs_chem = copy.deepcopy(kwargs)
+            cls.prepare_build(kwargs_chem, chem_vars, "chem")
+            dh_transformation = (cls.data_handler_climate_fir, cls.data_handler_unfiltered)
+            transformation_chem = super().transformation(set_stations, tmp_path=tmp_path,
+                                                         dh_transformation=dh_transformation, **kwargs_chem)
 
         # meteo transformation
-        kwargs_meteo = copy.deepcopy(kwargs)
-        cls.prepare_build(kwargs_meteo, meteo_vars, "meteo")
-        dh_transformation = (cls.data_handler_fir[cls.data_handler_fir_pos or 0], cls.data_handler_unfiltered)
-        transformation_meteo = super().transformation(set_stations, tmp_path=tmp_path,
-                                                      dh_transformation=dh_transformation, **kwargs_meteo)
+        if len(meteo_vars) > 0:
+            kwargs_meteo = copy.deepcopy(kwargs)
+            cls.prepare_build(kwargs_meteo, meteo_vars, "meteo")
+            dh_transformation = (cls.data_handler_fir[cls.data_handler_fir_pos or 0], cls.data_handler_unfiltered)
+            transformation_meteo = super().transformation(set_stations, tmp_path=tmp_path,
+                                                          dh_transformation=dh_transformation, **kwargs_meteo)
 
         # combine all transformations
         transformation_res = {}
-        if isinstance(transformation_chem, dict):
-            if len(transformation_chem) > 0:
-                transformation_res["filtered_chem"] = transformation_chem.pop("filtered")
-                transformation_res["unfiltered_chem"] = transformation_chem.pop("unfiltered")
-        else:  # if no unfiltered chem branch
-            transformation_res["filtered_chem"] = transformation_chem
-        if isinstance(transformation_meteo, dict):
-            if len(transformation_meteo) > 0:
-                transformation_res["filtered_meteo"] = transformation_meteo.pop("filtered")
-                transformation_res["unfiltered_meteo"] = transformation_meteo.pop("unfiltered")
-        else:  # if no unfiltered meteo branch
-            transformation_res["filtered_meteo"] = transformation_meteo
+        if transformation_chem is not None:
+            if isinstance(transformation_chem, dict):
+                if len(transformation_chem) > 0:
+                    transformation_res["filtered_chem"] = transformation_chem.pop("filtered")
+                    transformation_res["unfiltered_chem"] = transformation_chem.pop("unfiltered")
+            else:  # if no unfiltered chem branch
+                transformation_res["filtered_chem"] = transformation_chem
+        if transformation_meteo is not None:
+            if isinstance(transformation_meteo, dict):
+                if len(transformation_meteo) > 0:
+                    transformation_res["filtered_meteo"] = transformation_meteo.pop("filtered")
+                    transformation_res["unfiltered_meteo"] = transformation_meteo.pop("unfiltered")
+            else:  # if no unfiltered meteo branch
+                transformation_res["filtered_meteo"] = transformation_meteo
         return transformation_res if len(transformation_res) > 0 else None
 
     def get_X_original(self):
diff --git a/mlair/data_handler/data_handler_with_filter.py b/mlair/data_handler/data_handler_with_filter.py
index 761c7c8d..73afc778 100644
--- a/mlair/data_handler/data_handler_with_filter.py
+++ b/mlair/data_handler/data_handler_with_filter.py
@@ -124,7 +124,8 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation):
 
     DEFAULT_WINDOW_TYPE = ("kaiser", 5)
 
-    def __init__(self, *args, filter_cutoff_period, filter_order, filter_window_type=DEFAULT_WINDOW_TYPE, plot_path=None, **kwargs):
+    def __init__(self, *args, filter_cutoff_period, filter_order, filter_window_type=DEFAULT_WINDOW_TYPE,
+                 plot_path=None, filter_plot_dates=None, **kwargs):
         # self.original_data = None  # ToDo: implement here something to store unfiltered data
         self.fs = self._get_fs(**kwargs)
         if filter_window_type == "kzf":
@@ -136,6 +137,7 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation):
         self.filter_window_type = filter_window_type
         self.unfiltered_name = "unfiltered"
         self.plot_path = plot_path  # use this path to create insight plots
+        self.plot_dates = filter_plot_dates
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -191,7 +193,8 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation):
         """Apply FIR filter only on inputs."""
         fir = FIRFilter(self.input_data.astype("float32"), self.fs, self.filter_order, self.filter_cutoff_freq,
                         self.filter_window_type, self.target_dim, self.time_dim, display_name=self.station[0],
-                        minimum_length=self.window_history_size, offset=self.window_history_offset, plot_path=self.plot_path)
+                        minimum_length=self.window_history_size, offset=self.window_history_offset,
+                        plot_path=self.plot_path, plot_dates=self.plot_dates)
         self.fir_coeff = fir.filter_coefficients
         filter_data = fir.filtered_data
         self.input_data = xr.concat(filter_data, pd.Index(self.create_filter_index(), name=self.filter_dim))
@@ -365,7 +368,7 @@ class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation
                                           plot_path=self.plot_path,
                                           minimum_length=self.window_history_size, new_dim=self.window_dim,
                                           display_name=self.station[0], extend_length_opts=self.extend_length_opts,
-                                          offset=self.window_history_end)
+                                          offset=self.window_history_end, plot_dates=self.plot_dates)
         self.climate_filter_coeff = climate_filter.filter_coefficients
 
         # store apriori information: store all if residuum_stat method was used, otherwise just store initial apriori
diff --git a/mlair/helpers/filter.py b/mlair/helpers/filter.py
index 6c4247f3..10c9dd8f 100644
--- a/mlair/helpers/filter.py
+++ b/mlair/helpers/filter.py
@@ -17,7 +17,8 @@ from mlair.helpers import to_list, TimeTrackingWrapper, TimeTracking
 class FIRFilter:
     from mlair.plotting.data_insight_plotting import PlotFirFilter
 
-    def __init__(self, data, fs, order, cutoff, window, var_dim, time_dim, display_name=None, minimum_length=None, offset=0, plot_path=None):
+    def __init__(self, data, fs, order, cutoff, window, var_dim, time_dim, display_name=None, minimum_length=None,
+                 offset=0, plot_path=None, plot_dates=None):
         self._filtered = []
         self._h = []
         self.data = data
@@ -31,6 +32,7 @@ class FIRFilter:
         self.minimum_length = minimum_length
         self.offset = offset
         self.plot_path = plot_path
+        self.plot_dates = plot_dates
         self.run()
 
     def run(self):
@@ -40,9 +42,10 @@ class FIRFilter:
         input_data = self.data.__deepcopy__()
 
         # collect some data for visualization
-        plot_pos = np.array([0.25, 1.5, 2.75, 4]) * 365 * self.fs
-        plot_dates = [input_data.isel({self.time_dim: int(pos)}).coords[self.time_dim].values for pos in plot_pos if
-                      pos < len(input_data.coords[self.time_dim])]
+        if self.plot_dates is None:
+            plot_pos = np.array([0.25, 1.5, 2.75, 4]) * 365 * self.fs
+            self.plot_dates = [input_data.isel({self.time_dim: int(pos)}).coords[self.time_dim].values
+                               for pos in plot_pos if pos < len(input_data.coords[self.time_dim])]
         plot_data = []
 
         for i in range(len(self.order)):
@@ -53,7 +56,7 @@ class FIRFilter:
             h.append(hi)
 
             # visualization
-            plot_data.append(self.create_visualization(fi, input_data, plot_dates, self.time_dim, self.fs, hi,
+            plot_data.append(self.create_visualization(fi, input_data, self.plot_dates, self.time_dim, self.fs, hi,
                                                        self.minimum_length, self.order, i, self.offset, self.var_dim))
             # calculate residuum
             input_data = input_data - fi
@@ -167,7 +170,7 @@ class ClimateFIRFilter(FIRFilter):
     def __init__(self, data, fs, order, cutoff, window, time_dim, var_dim, apriori=None, apriori_type=None,
                  apriori_diurnal=False, sel_opts=None, plot_path=None,
                  minimum_length=None, new_dim=None, display_name=None, extend_length_opts: int = 0,
-                 offset: Union[dict, int] = 0):
+                 offset: Union[dict, int] = 0, plot_dates=None):
         """
         :param data: data to filter
         :param fs: sampling frequency in 1/days -> 1d: fs=1 -> 1H: fs=24
@@ -204,7 +207,7 @@ class ClimateFIRFilter(FIRFilter):
         self.plot_data = []
         self.extend_length_opts = extend_length_opts
         super().__init__(data, fs, order, cutoff, window, var_dim, time_dim, display_name=display_name,
-                         minimum_length=minimum_length, plot_path=plot_path, offset=offset)
+                         minimum_length=minimum_length, plot_path=plot_path, offset=offset, plot_dates=plot_dates)
 
     def run(self):
         filtered = []
@@ -226,8 +229,8 @@ class ClimateFIRFilter(FIRFilter):
         apriori_list = to_list(self._apriori)
         input_data = self.data.__deepcopy__()
 
-        # for viz
-        plot_dates = None
+        # for visualization
+        plot_dates = self.plot_dates
 
         # create tmp dimension to apply filter, search for unused name
         new_dim = self._create_tmp_dimension(input_data) if self.new_dim is None else self.new_dim
@@ -253,7 +256,7 @@ class ClimateFIRFilter(FIRFilter):
             h.append(hi)
             gc.collect()
             self.plot_data.append(plot_data)
-            plot_dates = {e["viz_date"] for e in plot_data}
+            plot_dates = {e["t0"] for e in plot_data}
 
             # calculate residuum
             logging.info(f"{self.display_name}: calculate residuum")
@@ -651,10 +654,11 @@ class ClimateFIRFilter(FIRFilter):
         :param extend_length_future: number to use in "future"
         :returns: trimmed data
         """
-        if minimum_length is None:
-            return data.sel({dim: slice(-extend_length_history, extend_length_future)}, drop=True)
-        else:
-            return data.sel({dim: slice(-minimum_length + offset, extend_length_future)}, drop=True)
+        # if minimum_length is None:
+        #     return data.sel({dim: slice(-extend_length_history, extend_length_future)}, drop=True)
+        # else:
+        #     return data.sel({dim: slice(-minimum_length + offset, extend_length_future)}, drop=True)
+        return data.sel({dim: slice(-extend_length_history, extend_length_future)}, drop=True)
 
     @staticmethod
     def _create_full_filter_result_array(template_array: xr.DataArray, result_array: xr.DataArray, new_dim: str,
diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py
index 526ee867..b12565bf 100644
--- a/mlair/plotting/data_insight_plotting.py
+++ b/mlair/plotting/data_insight_plotting.py
@@ -1072,8 +1072,8 @@ class PlotClimateFirFilter(AbstractPlotClass):  # pragma: no cover
         """
         # t_minus_delta = -(valid_range.start - 0.5 * order)
         # t_plus_delta = valid_range.stop + 0.5 * order
-        t_minus_delta = max(min(2 * (valid_range.stop - valid_range.start), 0.5 * order), -valid_range.start)
-        t_plus_delta = max(min(2 * (valid_range.stop - valid_range.start), 0.5 * order), valid_range.stop)
+        t_minus_delta = max(min(2 * (valid_range.stop - valid_range.start), 0.5 * order), (-valid_range.start + 0.3 * order))
+        t_plus_delta = max(min(2 * (valid_range.stop - valid_range.start), 0.5 * order), valid_range.stop + 0.3 * order)
         t_minus = t0 + np.timedelta64(-int(t_minus_delta), td_type)
         t_plus = t0 + np.timedelta64(int(t_plus_delta), td_type)
         ax_start = max(t_minus, time_axis[0])
-- 
GitLab