From c8080c7dd246badaa4e53364ea2bb327b12b0175 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Thu, 15 Jun 2023 13:02:23 +0200
Subject: [PATCH] IFS data loader works with filter approach

---
 .../data_handler_single_station.py            |  2 +-
 .../data_handler/data_handler_with_filter.py  |  7 +---
 mlair/helpers/filter.py                       | 34 +++++++++++++++----
 mlair/plotting/data_insight_plotting.py       | 15 ++++++++
 4 files changed, 44 insertions(+), 14 deletions(-)

diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py
index 552123f3..e60456c2 100644
--- a/mlair/data_handler/data_handler_single_station.py
+++ b/mlair/data_handler/data_handler_single_station.py
@@ -313,7 +313,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
         self.target_data = targets
 
     def make_samples(self):
-        self.make_history_window(self.target_dim, self.window_history_size, self.time_dim)  #todo stopped here
+        self.make_history_window(self.target_dim, self.window_history_size, self.time_dim)
         self.make_labels(self.target_dim, self.target_var, self.time_dim, self.window_lead_time)
         self.make_observation(self.target_dim, self.target_var, self.time_dim)
         self.remove_nan(self.time_dim)
diff --git a/mlair/data_handler/data_handler_with_filter.py b/mlair/data_handler/data_handler_with_filter.py
index 4ec25a83..21c45991 100644
--- a/mlair/data_handler/data_handler_with_filter.py
+++ b/mlair/data_handler/data_handler_with_filter.py
@@ -460,12 +460,7 @@ class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation
         :param window: this parameter is not used in the inherited method
         :param dim_name_of_shift: Dimension along shift will be applied
         """
-        data = self.input_data # TODO has to be refactored as window_history_offset must be included in filter calc
-        sampling = {"daily": "D", "hourly": "h"}.get(to_list(self.sampling)[0])
-        data.coords[dim_name_of_shift] = data.coords[dim_name_of_shift] - np.timedelta64(self.window_history_offset,
-                                                                                         sampling)
-        data.coords[self.window_dim] = data.coords[self.window_dim] + self.window_history_offset
-        self.history = data
+        self.history = self.input_data
         # from matplotlib import pyplot as plt
         # d = self.load_and_interpolate(0)
         # data.sel(datetime="2007-07-07 00:00").sum("filter").plot()
diff --git a/mlair/helpers/filter.py b/mlair/helpers/filter.py
index 370c8b75..bc115988 100644
--- a/mlair/helpers/filter.py
+++ b/mlair/helpers/filter.py
@@ -594,7 +594,6 @@ class ClimateFIRFilter(FIRFilter):
                 forecasts_tmp.coords[time_dim] = forecasts_tmp.coords[time_dim] + delta
                 forecasts_tmp.coords[new_dim] = forecasts_tmp.coords[new_dim] + offset - lead_time
                 history = history.combine_first(forecasts_tmp)
-                print(lead_time)
                 # history.plot()
                 if lead_time >= forecast_delta - 1:  # stop when all gaps are filled
                     break
@@ -651,10 +650,12 @@ class ClimateFIRFilter(FIRFilter):
 
     def create_visualization(self, filtered, data, filter_input_data, plot_dates, time_dim, new_dim, sampling,
                              extend_length_history, extend_length_future, minimum_length, h,
-                             variable_name, extend_length_opts=None, extend_end=None):  # pragma: no cover
+                             variable_name, extend_length_opts=None, extend_end=None, offset=None, forecast=None):  # pragma: no cover
+
         plot_data = []
         extend_end = 0 if extend_end is None else extend_end
         extend_length_opts = 0 if extend_length_opts is None else extend_length_opts
+        offset = 0 if offset is None else offset
         for t0 in set(plot_dates).intersection(filtered.coords[time_dim].values):
             try:
                 td_type = {"1d": "D", "1H": "h"}.get(sampling)
@@ -668,14 +669,32 @@ class ClimateFIRFilter(FIRFilter):
                                                        new_dim).sel({time_dim: t0})
                 else:
                     tmp_filter_data = None
+                filter_input = filter_input_data.sel({time_dim: t0, new_dim: slice(None, extend_length_future)})
                 valid_start = int(filtered[new_dim].min()) + int((len(h) + 1) / 2)
                 valid_end = min(extend_length_opts + extend_end + 1, int(filtered[new_dim].max()) - int((len(h) + 1) / 2))
                 valid_range = range(valid_start, valid_end)
-                plot_data.append({"t0": t0,
+                # if forecast is not None:
+                #     forecast_deltas = (t0 - forecast.coords[time_dim]) / np.timedelta64(1, "h") + 12
+                #     minimal_forecast_delta = int(forecast_deltas[forecast_deltas >= 0][-1])
+                #     init_time = forecast.coords[time_dim][forecast_deltas >= 0][-1]
+                #     forecast_end = min(extend_end, extend_length_opts)
+                #     f = forecast.sel({time_dim: init_time, new_dim: slice(None, forecast_end)})
+                #     f.coords[time_dim] = t0
+                #     f.coords[new_dim] = f.coords[new_dim] - minimal_forecast_delta + offset
+                # else:
+                #     f = None
+                def correct_data_for_offset(d):  # shift a copy of d: time by +offset steps, new_dim by -offset
+                    if d is not None:  # tmp_filter_data may be None (else branch above); pass it through
+                        d = d.copy(deep=True)
+                        d.coords[time_dim] = d.coords[time_dim] + np.timedelta64(int(offset), td_type)
+                        d.coords[new_dim] = d.coords[new_dim] - offset
+                    return d
+                plot_data.append({"t0": t0 + np.timedelta64(int(offset), td_type),
                                   "var": variable_name,
-                                  "filter_input": filter_input_data.sel({time_dim: t0}),
-                                  "filter_input_nc": tmp_filter_data,
-                                  "valid_range": valid_range,
+                                  "filter_input": correct_data_for_offset(filter_input),
+                                  "filter_input_nc": correct_data_for_offset(tmp_filter_data),
+                                  "valid_range": range(valid_range.start - offset, valid_range.stop - offset),
+                                  # "forecast": f,
                                   "time_range": data.sel(
                                       {time_dim: slice(t_minus, t_plus - np.timedelta64(1, td_type))}).coords[
                                       time_dim].values,
@@ -833,9 +852,10 @@ class ClimateFIRFilter(FIRFilter):
                 filt_input_coll.append(trimmed)
 
                 # visualization
+                plot_dates = [filt.coords[time_dim][49].values]
                 plot_data.extend(self.create_visualization(filt, d, filter_input_data, plot_dates, time_dim, new_dim,
                                                            sampling, extend_length_history, extend_length_future,
-                                                           minimum_length, h, var, extend_opts, extend_end))
+                                                           minimum_length, h, var, extend_opts, extend_end, offset, f))
 
             # collect all filter results
             coll.append(xr.concat(filt_coll, time_dim))
diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py
index ee5aea4a..76e7aadc 100644
--- a/mlair/plotting/data_insight_plotting.py
+++ b/mlair/plotting/data_insight_plotting.py
@@ -1093,6 +1093,9 @@ class PlotClimateFirFilter(AbstractPlotClass):  # pragma: no cover
                         # clim apriori
                         self._plot_apriori(ax, time_axis, filter_input, new_dim, ifilter, offset=1)
 
+                        # get ax lims
+                        ylims = ax.get_ylim()
+
                         # clim filter response
                         residuum_estimated = self._plot_clim_filter(ax, time_axis, filter_input, new_dim, h,
                                                                     output_dtypes=filter_input.dtype)
@@ -1103,6 +1106,7 @@ class PlotClimateFirFilter(AbstractPlotClass):  # pragma: no cover
 
                         # set title, legend, and save plot
                         xlims = self._set_xlim(ax, t0, filter_order, valid_range, td_type, time_axis)
+                        ax.set_ylim(ylims)
 
                         plt.title(f"Input of ClimFilter ({str(var)})")
                         plt.legend()
@@ -1118,6 +1122,7 @@ class PlotClimateFirFilter(AbstractPlotClass):  # pragma: no cover
                         self._plot_series(ax, time_axis, residuum_true.values.flatten(), style="ideal")
                         self._plot_series(ax, time_axis, residuum_estimated.values.flatten(), style="clim")
                         ax.set_xlim(xlims)
+                        self._set_ylim_by_valid_range(ax, residuum_true, residuum_estimated, new_dim, valid_range)
                         plt.title(f"Residuum of ClimFilter ({str(var)})")
                         plt.legend(loc="upper left")
                         fig.autofmt_xdate()
@@ -1129,6 +1134,16 @@ class PlotClimateFirFilter(AbstractPlotClass):  # pragma: no cover
                     logging.info(f"Could not create plot because of:\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")
                     pass
 
+    @staticmethod
+    def _set_ylim_by_valid_range(ax, a, b, dim, valid_range):
+        """Set y-limits of ax to the min/max of a and b inside valid_range, each pushed outwards by 10 %."""
+        a_valid, b_valid = a.sel({dim: valid_range}), b.sel({dim: valid_range})
+        upper = max(a_valid.max(), b_valid.max())
+        lower = min(a_valid.min(), b_valid.min())
+        upper = (1.1 if upper > 0 else 0.9) * upper
+        lower = (0.9 if lower > 0 else 1.1) * lower
+        ax.set_ylim((lower, upper))
+
     def _set_xlim(self, ax, t0, order, valid_range, td_type, time_axis):
         """
         Set xlims
-- 
GitLab