From 82c4ba1e2844b1e9b23bcc23654d8df8ca2fae29 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Mon, 17 May 2021 18:25:29 +0200
Subject: [PATCH] new data handler for mixed sampling and climate FIR filter;
 fine tuning is required for the new parameter apriori_diurnal

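The new apriori_diurnal switch adds the mean anomaly of each hour of the day on
top of the monthly climatology used as apriori information for the climate FIR
filter. The snippet below is only a standalone xarray sketch of that idea with
toy data invented for illustration; it is not the mlair API itself, the actual
implementation lives in ClimateFIRFilter.create_hourly_mean.

    import numpy as np
    import pandas as pd
    import xarray as xr

    # toy hourly series with an artificial diurnal cycle (values invented for illustration)
    time = pd.date_range("2010-01-01", periods=24 * 365, freq="1H")
    data = xr.DataArray(np.sin(2 * np.pi * time.hour / 24) + 0.1 * np.random.randn(len(time)),
                        dims="datetime", coords={"datetime": time})

    # mean of each hour of the day and its anomaly w.r.t. the overall mean;
    # this anomaly is what apriori_diurnal adds on top of the monthly apriori
    hourly_mean = data.groupby("datetime.hour").mean()
    hourly_anomaly = hourly_mean - hourly_mean.mean("hour")

    # broadcast the 24 anomaly values back onto the original hourly time axis
    diurnal_apriori = hourly_anomaly.sel(hour=data["datetime.hour"])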
---
 .../data_handler_mixed_sampling.py            | 39 +++++++-
 .../data_handler/data_handler_with_filter.py  | 11 ++-
 mlair/helpers/filter.py                       | 99 ++++++++++++++-----
 3 files changed, 120 insertions(+), 29 deletions(-)

diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py
index 718a8f3e..565a50df 100644
--- a/mlair/data_handler/data_handler_mixed_sampling.py
+++ b/mlair/data_handler/data_handler_mixed_sampling.py
@@ -3,7 +3,7 @@ __date__ = '2020-11-05'
 
 from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
 from mlair.data_handler.data_handler_with_filter import DataHandlerKzFilterSingleStation, \
-    DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation
+    DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerClimateFirFilterSingleStation
 from mlair.data_handler import DefaultDataHandler
 from mlair import helpers
 from mlair.helpers import remove_items
@@ -221,6 +221,43 @@ class DataHandlerMixedSamplingWithFirFilter(DefaultDataHandler):
     _requirements = data_handler.requirements()
 
 
+class DataHandlerMixedSamplingWithClimateFirFilterSingleStation(DataHandlerMixedSamplingWithFilterSingleStation,
+                                                                DataHandlerClimateFirFilterSingleStation):
+    _requirements1 = DataHandlerClimateFirFilterSingleStation.requirements()
+    _requirements2 = DataHandlerMixedSamplingWithFilterSingleStation.requirements()
+    _requirements = list(set(_requirements1 + _requirements2))
+
+    def estimate_filter_width(self):
+        """Filter width is determined by the filter with the highest order."""
+        return max(self.filter_order)
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def _extract_lazy(self, lazy_data):
+        _data, _meta, _input_data, _target_data, self.climate_filter_coeff, self.apriori, self.all_apriori = lazy_data
+        DataHandlerSingleStation._extract_lazy(self, (_data, _meta, _input_data, _target_data))
+
+    @staticmethod
+    def _get_fs(**kwargs):
+        """Return frequency in 1/day (not Hz)"""
+        sampling = kwargs.get("sampling")[0]
+        if sampling == "daily":
+            return 1
+        elif sampling == "hourly":
+            return 24
+        else:
+            raise ValueError(f"Unknown sampling rate {sampling}. Only daily and hourly resolution is supported.")
+
+
+class DataHandlerMixedSamplingWithClimateFirFilter(DefaultDataHandler):
+    """Data handler using mixed sampling for input and target. Inputs are temporal filtered."""
+
+    data_handler = DataHandlerMixedSamplingWithClimateFirFilterSingleStation
+    data_handler_transformation = DataHandlerMixedSamplingWithClimateFirFilterSingleStation
+    _requirements = data_handler.requirements()
+
+
 class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithKzFilterSingleStation):
     """
     Data handler using mixed sampling for input and target. Inputs are temporally filtered and depending on the
diff --git a/mlair/data_handler/data_handler_with_filter.py b/mlair/data_handler/data_handler_with_filter.py
index 7be76082..097c0da7 100644
--- a/mlair/data_handler/data_handler_with_filter.py
+++ b/mlair/data_handler/data_handler_with_filter.py
@@ -295,16 +295,20 @@ class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation
     :param apriori_type: set type of information that is provided to the clim filter. For the first low pass always a
         calculated or given statistic is used. For residuum prediction a constant value of zero is assumed if
         apriori_type is None or `zeros`, and a climatology of the residuum is used for `residuum_stats`.
+    :param apriori_diurnal: use diurnal anomalies of each hour in addition to the apriori information type chosen by
+        the parameter apriori_type. This is only applicable for data in hourly resolution.
     """
 
     _requirements = remove_items(DataHandlerFirFilterSingleStation.requirements(), "station")
-    _hash = DataHandlerFirFilterSingleStation._hash + ["apriori_type", "apriori_sel_opts"]
+    _hash = DataHandlerFirFilterSingleStation._hash + ["apriori_type", "apriori_sel_opts", "apriori_diurnal"]
     _store_attributes = DataHandlerFirFilterSingleStation.store_attributes() + ["apriori"]
 
-    def __init__(self, *args, apriori=None, apriori_type=None, apriori_sel_opts=None, plot_path=None, **kwargs):
+    def __init__(self, *args, apriori=None, apriori_type=None, apriori_diurnal=False, apriori_sel_opts=None,
+                 plot_path=None, **kwargs):
         self.apriori_type = apriori_type
         self.climate_filter_coeff = None  # coefficients of the used FIR filter
         self.apriori = apriori  # exogenous apriori information or None to calculate from data (endogenous)
+        self.apriori_diurnal = apriori_diurnal
         self.all_apriori = None  # collection of all apriori information
         self.apriori_sel_opts = apriori_sel_opts  # ensure to separate exogenous and endogenous information
         self.plot_path = plot_path  # use this path to create insight plots
@@ -317,7 +321,8 @@ class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation
         climate_filter = ClimateFIRFilter(self.input_data, self.fs, self.filter_order, self.filter_cutoff_freq,
                                           self.filter_window_type, time_dim=self.time_dim, var_dim=self.target_dim,
                                           apriori_type=self.apriori_type, apriori=self.apriori,
-                                          sel_opts=self.apriori_sel_opts, plot_path=self.plot_path, plot_name=str(self))
+                                          apriori_diurnal=self.apriori_diurnal, sel_opts=self.apriori_sel_opts,
+                                          plot_path=self.plot_path, plot_name=str(self))
         self.climate_filter_coeff = climate_filter.filter_coefficients
 
         # store apriori information: store all if residuum_stat method was used, otherwise just store initial apriori
diff --git a/mlair/helpers/filter.py b/mlair/helpers/filter.py
index 84b97295..b26b616f 100644
--- a/mlair/helpers/filter.py
+++ b/mlair/helpers/filter.py
@@ -57,7 +57,8 @@ class FIRFilter:
 class ClimateFIRFilter:
 
     def __init__(self, data, fs, order, cutoff, window, time_dim, var_dim, apriori=None, apriori_type=None,
-                 sel_opts=None, plot_path=None, plot_name=None, vectorized=True, padlen_factor=0.8):
+                 apriori_diurnal=False, sel_opts=None, plot_path=None, plot_name=None, vectorized=True,
+                 padlen_factor=0.8):
         """
         :param data: data to filter
         :param fs: sampling frequency in 1/days -> 1d: fs=1 -> 1H: fs=24
@@ -71,6 +72,8 @@ class ClimateFIRFilter:
         :param apriori_type: type of apriori information to use. Climatology will be used always for first low pass. For
             the residuum either the value zero is used (apriori_type is None or "zeros") or a climatology on the
             residua is used ("residuum_stats").
+        :param apriori_diurnal: Use the diurnal cycle as additional apriori information (only applicable for data in
+            hourly resolution). The mean anomaly of each hour is added to the information chosen by apriori_type.
         """
         self.plot_path = plot_path
         self.plot_name = plot_name
@@ -78,8 +81,14 @@ class ClimateFIRFilter:
         h = []
         sel_opts = sel_opts if isinstance(sel_opts, dict) else {time_dim: sel_opts}
         sampling = {1: "1d", 24: "1H"}.get(int(fs))
+        if apriori_diurnal is True and sampling == "1H":
+            diurnal_anomalies = self.create_hourly_mean(data, sel_opts=sel_opts, sampling=sampling, time_dim=time_dim,
+                                                        as_anomaly=True)
+        else:
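+            # neutral element: the "+ diurnal_anomalies" term below then has no effect on the monthly apriori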
+            diurnal_anomalies = 0
         if apriori is None:
-            apriori = self.create_monthly_mean(data, sel_opts=sel_opts, sampling=sampling, time_dim=time_dim)
+            apriori = self.create_monthly_mean(data, sel_opts=sel_opts, sampling=sampling,
+                                               time_dim=time_dim) + diurnal_anomalies
         apriori_list = to_list(apriori)
         input_data = data.__deepcopy__()
         for i in range(len(order)):
@@ -97,11 +106,16 @@ class ClimateFIRFilter:
 
             # create new apriori information for next iteration if no further apriori is provided
             if len(apriori_list) <= i + 1:
+                if apriori_diurnal is True and sampling == "1H":
+                    diurnal_anomalies = self.create_hourly_mean(input_data, sel_opts=sel_opts, sampling=sampling,
+                                                                time_dim=time_dim, as_anomaly=True)
+                else:
+                    diurnal_anomalies = 0
                 if apriori_type is None or apriori_type == "zeros":  # zero version
-                    apriori_list.append(xr.zeros_like(apriori_list[i]))
+                    apriori_list.append(xr.zeros_like(apriori_list[i]) + diurnal_anomalies)
                 elif apriori_type == "residuum_stats":  # calculate monthly statistic on residuum
                     apriori_list.append(-self.create_monthly_mean(input_data, sel_opts=sel_opts, sampling=sampling,
-                                                                  time_dim=time_dim))
+                                                                  time_dim=time_dim) + diurnal_anomalies)
                 else:
                     raise ValueError(f"Cannot handle unkown apriori type: {apriori_type}. Please choose from None, "
                                      f"`zeros` or `residuum_stats`.")
@@ -141,20 +155,49 @@ class ClimateFIRFilter:
         # create monthly mean and replace entries in unity array
         monthly_mean = data.groupby(f"{time_dim}.month").mean()
         for month in monthly_mean.month.values:
-            loc = (monthly[f"{time_dim}.month"] == month)
-            monthly.loc[{time_dim: loc}] = monthly_mean.sel(month=month)
-
+            monthly = xr.where((monthly[f"{time_dim}.month"] == month),
+                               monthly_mean.sel(month=month, drop=True),
+                               monthly)
+        # transform monthly information into original sampling rate
+        return monthly.resample({time_dim: sampling}).interpolate()
+
+        # for month in monthly_mean.month.values:
+        #     loc = (monthly[f"{time_dim}.month"] == month)
+        #     monthly.loc[{time_dim: loc}] = monthly_mean.sel(month=month, drop=True)
         # aggregate monthly information (shift by half month, because resample base is last day)
-        return monthly.resample({time_dim: "1m"}).max().resample({time_dim: sampling}).interpolate()
+        # return monthly.resample({time_dim: "1m"}).max().resample({time_dim: sampling}).interpolate()
+
+    @staticmethod
+    def create_hourly_mean(data, sel_opts=None, sampling="1H", time_dim="datetime", as_anomaly=True):
+        """Calculate hourly statistics. Either the absolute value or the anomaly (as_anomaly=True)."""
+        # can only be used for hourly sampling rate
+        assert sampling == "1H"
+
+        # create unity xarray in hourly resolution
+        hourly = xr.ones_like(data)
+
+        # apply selection if given (only use subset for hourly means)
+        if sel_opts is not None:
+            data = data.sel(**sel_opts)
+
+        # create mean for each hour and replace entries in unity array, calculate anomaly if enabled
+        hourly_mean = data.groupby(f"{time_dim}.hour").mean()
+        if as_anomaly is True:
+            hourly_mean = hourly_mean - hourly_mean.mean("hour")
+        for hour in hourly_mean.hour.values:
+            loc = (hourly[f"{time_dim}.hour"] == hour)
+            hourly.loc[{f"{time_dim}": loc}] = hourly_mean.sel(hour=hour)
+        return hourly
 
     @staticmethod
-    def extend_apriori(data, apriori, time_dim):
+    def extend_apriori(data, apriori, time_dim, sampling="1d"):
         """
         Extend time range of apriori information.
 
         This method will fail if apriori is available for a shorter period than the gap to fill.
         """
         dates = data.coords[time_dim].values
+        td_type = {"1d": "D", "1H": "h"}.get(sampling)
 
         # apriori starts after data
         if dates[0] < apriori.coords[time_dim].values[0]:
@@ -164,8 +207,8 @@ class ClimateFIRFilter:
             coords = apriori.coords
 
             # create new time axis
-            start = coords[time_dim][0].values.astype("datetime64[D]") - np.timedelta64(extend_range, "D")
-            end = coords[time_dim][0].values.astype("datetime64[D]")
+            start = coords[time_dim][0].values.astype("datetime64[%s]" % td_type) - np.timedelta64(extend_range, "D")
+            end = coords[time_dim][0].values.astype("datetime64[%s]" % td_type)
             new_time_axis = np.arange(start, end).astype("datetime64[ns]")
 
             # extract old values to use with new axis
@@ -185,13 +228,16 @@ class ClimateFIRFilter:
             coords = apriori.coords
 
             # create new time axis
-            start = coords[time_dim][-1].values.astype("datetime64[D]")
-            end = coords[time_dim][-1].values.astype("datetime64[D]") + np.timedelta64(extend_range, "D")
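+            # extend_range is counted in days; factor converts it into the unit of td_type ("h" steps for hourly data)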
+            factor = 1 if td_type == "D" else 24
+            start = coords[time_dim][-1].values.astype("datetime64[%s]" % td_type)
+            end = coords[time_dim][-1].values.astype("datetime64[%s]" % td_type) + np.timedelta64(extend_range * factor,
+                                                                                                  td_type)
             new_time_axis = np.arange(start, end).astype("datetime64[ns]")
 
             # extract old values to use with new axis
-            start = coords[time_dim][-1].values.astype("datetime64[D]") - np.timedelta64(extend_range - 1, "D")
-            end = coords[time_dim][-1].values.astype("datetime64[D]")
+            start = coords[time_dim][-1].values.astype("datetime64[%s]" % td_type) - np.timedelta64(
+                extend_range * factor - 1, td_type)
+            end = coords[time_dim][-1].values.astype("datetime64[%s]" % td_type)
             new_values = apriori.sel({time_dim: slice(start, end)})
             new_values.coords[time_dim] = new_time_axis
 
@@ -207,7 +253,7 @@ class ClimateFIRFilter:
         # calculate apriori information from data if not given and extend its range if not sufficient long enough
         if apriori is None:
             apriori = self.create_monthly_mean(data, sel_opts=sel_opts, sampling=sampling, time_dim=time_dim)
-        apriori = self.extend_apriori(data, apriori, time_dim)
+        apriori = self.extend_apriori(data, apriori, time_dim, sampling)
 
         # calculate FIR filter coefficients
         h = signal.firwin(order, cutoff_high, pass_zero="lowpass", fs=fs, window=window)
@@ -248,7 +294,7 @@ class ClimateFIRFilter:
         # calculate apriori information from data if not given and extend its range if not sufficient long enough
         if apriori is None:
             apriori = self.create_monthly_mean(data, sel_opts=sel_opts, sampling=sampling, time_dim=time_dim)
-        apriori = self.extend_apriori(data, apriori, time_dim)
+        apriori = self.extend_apriori(data, apriori, time_dim, sampling)
 
         # calculate FIR filter coefficients
         h = signal.firwin(order, cutoff_high, pass_zero="lowpass", fs=fs, window=window)
@@ -258,12 +304,14 @@ class ClimateFIRFilter:
         new_dim = self._create_tmp_dimension(data)
 
-        # combine historical data / observation [t0-length,t0] and climatological statistics [t0+1,t0+length]
+        # combine historical data / observation [t0-(length-1)/2, t0] and climatological statistics [t0+1, t0+(length-1)/2]
-        history = self._shift_data(data, range(-length, 1), time_dim, var_dim, new_dim)
-        future = self._shift_data(apriori, range(1, length + 1), time_dim, var_dim, new_dim)
+        history = self._shift_data(data, range(int(-(length - 1) / 2), 1), time_dim, var_dim, new_dim)
+        future = self._shift_data(apriori, range(1, int((length - 1) / 2) + 1), time_dim, var_dim, new_dim)
         filter_input_data = xr.concat([history.dropna(time_dim), future], dim=new_dim, join="left")
         # filter_input_data = history.combine_first(future)
+        # history.sel(datetime=slice("2010-11-01", "2011-04-01"),variables="o3").plot()
+        # filter_input_data.sel(datetime=slice("2009-11-01", "2011-04-01"),variables="temp").plot()
 
-        time_axis = filter_input_data.coords["datetime"]
+        time_axis = filter_input_data.coords[time_dim]
         # apply vectorized fir filter along the tmp dimension
         kwargs = {"fs": fs, "cutoff_high": cutoff_high, "order": order,
                   "causal": False, "padlen": int(min(padlen_factor, 1) * length), "h": h}
@@ -275,19 +323,20 @@ class ClimateFIRFilter:
         #                           kwargs=kwargs)
         with TimeTracking(name="convolve"):
             slicer = slice(int(-(length - 1) / 2), int((length - 1) / 2))
-            filt = xr.apply_ufunc(fir_filter_convolve_vectorized, filter_input_data.sel(window=slicer),
-                                  input_core_dims=[["window"]],
-                                  output_core_dims=[["window"]],
+            filt = xr.apply_ufunc(fir_filter_convolve_vectorized, filter_input_data.sel({new_dim: slicer}),
+                                  input_core_dims=[[new_dim]],
+                                  output_core_dims=[[new_dim]],
                                   vectorize=True,
                                   kwargs={"h": h})
 
         # plot
         if self.plot_path is not None:
             try:
-                pos = 720
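+                # pick an example sample roughly 720 days after the series start (fs steps per day)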
+                pos = 720 * fs
                 filter_example = filter_input_data.isel({time_dim: pos})
                 t0 = filter_example.coords[time_dim].values
-                t_slice = filter_input_data.isel({time_dim: slice(pos - length, pos + length + 1)}).coords[
+                t_slice = filter_input_data.isel(
+                    {time_dim: slice(pos - int((length - 1) / 2), pos + int((length - 1) / 2) + 1)}).coords[
                     time_dim].values
                 self.plot(data, filter_example, var_dim, time_dim, t_slice, t0, plot_index)
             except IndexError:
-- 
GitLab