diff --git a/mlair/data_handler/data_handler_with_filter.py b/mlair/data_handler/data_handler_with_filter.py index 9b7b27f8ce2b5230f24095ed9253860f4e6ee082..07fdc41fc4dae49bd44a071dd2228c4bff860b04 100644 --- a/mlair/data_handler/data_handler_with_filter.py +++ b/mlair/data_handler/data_handler_with_filter.py @@ -124,7 +124,7 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): DEFAULT_WINDOW_TYPE = ("kaiser", 5) - def __init__(self, *args, filter_cutoff_period, filter_order, filter_window_type=DEFAULT_WINDOW_TYPE, **kwargs): + def __init__(self, *args, filter_cutoff_period, filter_order, filter_window_type=DEFAULT_WINDOW_TYPE, plot_path=None, **kwargs): # self.original_data = None # ToDo: implement here something to store unfiltered data self.fs = self._get_fs(**kwargs) if filter_window_type == "kzf": @@ -135,6 +135,7 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): self.filter_order = self._prepare_filter_order(filter_order, removed_index, self.fs) self.filter_window_type = filter_window_type self.unfiltered_name = "unfiltered" + self.plot_path = plot_path # use this path to create insight plots super().__init__(*args, **kwargs) @staticmethod @@ -189,7 +190,8 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): def apply_filter(self): """Apply FIR filter only on inputs.""" fir = FIRFilter(self.input_data.astype("float32"), self.fs, self.filter_order, self.filter_cutoff_freq, - self.filter_window_type, self.target_dim, self.time_dim, station_name=self.station) + self.filter_window_type, self.target_dim, self.time_dim, station_name=self.station[0], + minimum_length=self.window_history_size, offset=self.window_history_offset, plot_path=self.plot_path) self.fir_coeff = fir.filter_coefficients filter_data = fir.filtered_data self.input_data = xr.concat(filter_data, pd.Index(self.create_filter_index(), name=self.filter_dim)) @@ -330,14 +332,13 @@ class 
DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation _store_attributes = DataHandlerFirFilterSingleStation.store_attributes() + ["apriori"] def __init__(self, *args, apriori=None, apriori_type=None, apriori_diurnal=False, apriori_sel_opts=None, - plot_path=None, name_affix=None, extend_length_opts=None, **kwargs): + name_affix=None, extend_length_opts=None, **kwargs): self.apriori_type = apriori_type self.climate_filter_coeff = None # coefficents of the used FIR filter self.apriori = apriori # exogenous apriori information or None to calculate from data (endogenous) self.apriori_diurnal = apriori_diurnal self.all_apriori = None # collection of all apriori information self.apriori_sel_opts = apriori_sel_opts # ensure to separate exogenous and endogenous information - self.plot_path = plot_path # use this path to create insight plots self.plot_name_affix = name_affix self.extend_length_opts = extend_length_opts if extend_length_opts is not None else {} super().__init__(*args, **kwargs) @@ -347,15 +348,14 @@ class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation """Apply FIR filter only on inputs.""" self.apriori = self.apriori.get(str(self)) if isinstance(self.apriori, dict) else self.apriori logging.info(f"{self.station}: call ClimateFIRFilter") - plot_name = str(self) # if self.plot_name_affix is None else f"{str(self)}_{self.plot_name_affix}" climate_filter = ClimateFIRFilter(self.input_data.astype("float32"), self.fs, self.filter_order, self.filter_cutoff_freq, self.filter_window_type, time_dim=self.time_dim, var_dim=self.target_dim, apriori_type=self.apriori_type, apriori=self.apriori, apriori_diurnal=self.apriori_diurnal, sel_opts=self.apriori_sel_opts, - plot_path=self.plot_path, plot_name=plot_name, + plot_path=self.plot_path, minimum_length=self.window_history_size, new_dim=self.window_dim, - station_name=self.station, extend_length_opts=self.extend_length_opts) + station_name=self.station[0], 
extend_length_opts=self.extend_length_opts) self.climate_filter_coeff = climate_filter.filter_coefficients # store apriori information: store all if residuum_stat method was used, otherwise just store initial apriori diff --git a/mlair/helpers/filter.py b/mlair/helpers/filter.py index 0adf173c6cda6cfb913688b308f024b722503ee3..488fdfd30d684516782b67dfc5a417fefee15a6a 100644 --- a/mlair/helpers/filter.py +++ b/mlair/helpers/filter.py @@ -17,8 +17,9 @@ from mlair.helpers import to_list, TimeTrackingWrapper, TimeTracking class FIRFilter: + from mlair.plotting.data_insight_plotting import PlotFirFilter - def __init__(self, data, fs, order, cutoff, window, var_dim, time_dim, station_name=None): + def __init__(self, data, fs, order, cutoff, window, var_dim, time_dim, station_name=None, minimum_length=0, offset=0, plot_path=None): self._filtered = [] self._h = [] self.data = data @@ -29,6 +30,9 @@ class FIRFilter: self.var_dim = var_dim self.time_dim = time_dim self.station_name = station_name + self.minimum_length = minimum_length + self.offset = offset + self.plot_path = plot_path self.run() def run(self): @@ -36,17 +40,24 @@ class FIRFilter: filtered = [] h = [] input_data = self.data.__deepcopy__() - for i in range(len(self.order)): - # fi, hi = fir_filter(input_data, self.fs, order=self.order[i], cutoff_low=self.cutoff[i][0], - # cutoff_high=self.cutoff[i][1], window=self.window, dim=self.var_dim, h=None, - # causal=True, padlen=None) - # filtered.append(fi) - # h.append(hi) + # collect some data for visualization + plot_pos = np.array([0.25, 1.5, 2.75, 4]) * 365 * self.fs + plot_dates = [input_data.isel({self.time_dim: int(pos)}).coords[self.time_dim].values for pos in plot_pos if + pos < len(input_data.coords[self.time_dim])] + plot_data = [] + + for i in range(len(self.order)): + # apply filter fi, hi = self.fir_filter(input_data, self.fs, self.cutoff[i], self.order[i], time_dim=self.time_dim, var_dim=self.var_dim, window=self.window, 
def create_visualization(self, filtered, filter_input_data, plot_dates, time_dim, sampling,
                         h, minimum_length, order, i, offset, var_dim):  # pragma: no cover
    """
    Collect the data windows around each visualization date for the FIR insight plots.

    For every date of ``plot_dates`` that also exists on the time axis of ``filtered``, the filter input and the
    filter response are cut to a window around this date (t0). The window reaches
    ``minimum_length (+ order[i + 1] if this is not the last filter stage) + int((len(h) + 1) / 2)`` steps into
    the past and ``int((len(h) + 1) / 2) + offset`` steps into the future.

    The window size is calculated once per call. It therefore is the same for all visualization dates of a
    filter stage and does not depend on the number or the order of the dates.

    :param filtered: response of the current filter stage, must provide ``coords`` and ``sel``
    :param filter_input_data: data that was fed into the current filter stage, must provide ``sel``
    :param plot_dates: candidate dates (t0) to visualize
    :param time_dim: name of the time dimension
    :param sampling: number of samples per day, 1 (daily) and 24 (hourly) are mapped to a time delta unit
    :param h: filter coefficients of the current stage, only the length is evaluated here
    :param minimum_length: minimum number of history steps to show, None is treated as 0
    :param order: filter orders of all stages
    :param i: index of the current filter stage
    :param offset: additional number of steps shown after t0
    :param var_dim: name of the variable dimension, passed through to the plot data
    :return: list with one dictionary per visualization date (keys: t0, filter_input, filtered, h, time_dim,
        var_dim); empty if no date matches or the window could not be determined
    """
    plot_data = []
    # sort to get a reproducible order, a plain set has no defined iteration order
    viz_dates = sorted(set(plot_dates).intersection(filtered.coords[time_dim].values))
    if not viz_dates:
        return plot_data

    # window geometry only depends on the filter stage: calculate it once and never modify minimum_length in place
    try:
        history = minimum_length or 0
        if i < len(order) - 1:
            history += order[i + 1]
        td_type = {1: "D", 24: "h"}.get(sampling)
        half_length = int((len(h) + 1) / 2)
        t_back = np.timedelta64(int(-(history + half_length)), td_type)
        t_forward = np.timedelta64(int(half_length + 1 + offset), td_type)
        t_step = np.timedelta64(1, td_type)
    except Exception as e:
        # visualization is best effort only and must not break the filter calculation
        logging.info(f"Could not prepare visualization of filter stage {i} due to following reason:\n{e}")
        return plot_data

    for viz_date in viz_dates:
        try:
            time_slice = slice(viz_date + t_back, viz_date + t_forward - t_step)
            plot_data.append({"t0": viz_date, "filter_input": filter_input_data.sel({time_dim: time_slice}),
                              "filtered": filtered.sel({time_dim: time_slice}), "h": h, "time_dim": time_dim,
                              "var_dim": var_dim})
        except Exception as e:
            # skip only this date but leave a trace instead of silently swallowing the error
            logging.info(f"Could not collect visualization data for {viz_date} due to following reason:\n{e}")
    return plot_data
apriori_diurnal=False, sel_opts=None, plot_path=None, minimum_length=None, new_dim=None, station_name=None, extend_length_opts: Union[dict, int] = 0): """ :param data: data to filter @@ -159,10 +199,9 @@ class ClimateFIRFilter(FIRFilter): self.minimum_length = minimum_length self.new_dim = new_dim self.plot_path = plot_path - self.plot_name = plot_name # ToDo: is there a difference between station_name and plot_name??? self.plot_data = [] self.extend_length_opts = extend_length_opts - super().__init__(data, fs, order, cutoff, window, var_dim, time_dim, station_name) + super().__init__(data, fs, order, cutoff, window, var_dim, time_dim, station_name=station_name) def run(self): filtered = [] @@ -170,17 +209,17 @@ class ClimateFIRFilter(FIRFilter): if self.sel_opts is not None: self.sel_opts = self.sel_opts if isinstance(self.sel_opts, dict) else {self.time_dim: self.sel_opts} sampling = {1: "1d", 24: "1H"}.get(int(self.fs)) - logging.debug(f"{self.plot_name}: create diurnal_anomalies") + logging.debug(f"{self.station_name}: create diurnal_anomalies") if self.apriori_diurnal is True and sampling == "1H": diurnal_anomalies = self.create_seasonal_hourly_mean(self.data, self.time_dim, sel_opts=self.sel_opts, sampling=sampling, as_anomaly=True) else: diurnal_anomalies = 0 - logging.debug(f"{self.plot_name}: create monthly apriori") + logging.debug(f"{self.station_name}: create monthly apriori") if self._apriori is None: self._apriori = self.create_monthly_mean(self.data, self.time_dim, sel_opts=self.sel_opts, sampling=sampling) + diurnal_anomalies - logging.debug(f"{self.plot_name}: apriori shape = {self._apriori.shape}") + logging.debug(f"{self.station_name}: apriori shape = {self._apriori.shape}") apriori_list = to_list(self._apriori) input_data = self.data.__deepcopy__() @@ -191,7 +230,7 @@ class ClimateFIRFilter(FIRFilter): new_dim = self._create_tmp_dimension(input_data) if self.new_dim is None else self.new_dim for i in range(len(self.order)): - 
logging.info(f"{self.plot_name}: start filter for order {self.order[i]}") + logging.info(f"{self.station_name}: start filter for order {self.order[i]}") # calculate climatological filter _minimum_length = self._minimum_length(self.order, self.minimum_length, i, self.window) fi, hi, apriori, plot_data = self.clim_filter(input_data, self.fs, self.cutoff[i], self.order[i], @@ -202,7 +241,7 @@ class ClimateFIRFilter(FIRFilter): plot_dates=plot_dates, station_name=self.station_name, extend_length_opts=self.extend_length_opts) - logging.info(f"{self.plot_name}: finished clim_filter calculation") + logging.info(f"{self.station_name}: finished clim_filter calculation") if self.minimum_length is None: filtered.append(fi) else: @@ -213,7 +252,7 @@ class ClimateFIRFilter(FIRFilter): plot_dates = {e["t0"] for e in plot_data} # calculate residuum - logging.info(f"{self.plot_name}: calculate residuum") + logging.info(f"{self.station_name}: calculate residuum") coord_range = range(fi.coords[new_dim].values.min(), fi.coords[new_dim].values.max() + 1) if new_dim in input_data.coords: input_data = input_data.sel({new_dim: coord_range}) - fi @@ -222,14 +261,14 @@ class ClimateFIRFilter(FIRFilter): # create new apriori information for next iteration if no further apriori is provided if len(apriori_list) <= i + 1: - logging.info(f"{self.plot_name}: create diurnal_anomalies") + logging.info(f"{self.station_name}: create diurnal_anomalies") if self.apriori_diurnal is True and sampling == "1H": diurnal_anomalies = self.create_seasonal_hourly_mean(input_data.sel({new_dim: 0}, drop=True), self.time_dim, sel_opts=self.sel_opts, sampling=sampling, as_anomaly=True) else: diurnal_anomalies = 0 - logging.info(f"{self.plot_name}: create monthly apriori") + logging.info(f"{self.station_name}: create monthly apriori") if self.apriori_type is None or self.apriori_type == "zeros": # zero version apriori_list.append(xr.zeros_like(apriori_list[i]) + diurnal_anomalies) elif self.apriori_type == 
class PlotFirFilter(AbstractPlotClass):  # pragma: no cover
    """
    Plot FIR filter components.

    * Creates a separate folder FIR inside the given plot directory.
    * For each station up to 4 examples are shown (1 for each season).
    * Each filtered component and its residuum is drawn in a separate plot.
    * A filter component plot includes the FIR input and the filter response.
    * A filter residuum plot includes the FIR residuum.
    """

    def __init__(self, plot_folder, plot_data, name):
        """
        Create and store all FIR filter plots of a single station.

        :param plot_folder: base directory for all plots, nothing is plotted if this is None
        :param plot_data: list of dictionaries, each is read with the keys t0, filter_input, filtered, var_dim and
            time_dim
        :param name: name used in the log message and as part of each plot file name
        """
        logging.info(f"start PlotFirFilter for ({name})")

        # adjust default plot parameters
        rc_params = {
            'axes.labelsize': 'large',
            'xtick.labelsize': 'large',
            'ytick.labelsize': 'large',
            'legend.fontsize': 'medium',
            'axes.titlesize': 'large'}
        if plot_folder is None:
            return

        self.style_dict = {
            "original": {"color": "darkgrey", "linestyle": "dashed", "label": "original"},
            "apriori": {"color": "darkgrey", "linestyle": "solid", "label": "estimated future"},
            "clim": {"color": "black", "linestyle": "solid", "label": "clim filter", "linewidth": 2},
            "FIR": {"color": "black", "linestyle": "dashed", "label": "ideal filter", "linewidth": 2},
            "valid_area": {"color": "whitesmoke", "label": "valid area"},
            "t0": {"color": "lightgrey", "lw": 6, "label": "$t_0$"}
        }

        plot_folder = os.path.join(os.path.abspath(plot_folder), "FIR")
        super().__init__(plot_folder, plot_name=None, rc_params=rc_params)
        plot_dict = self._prepare_data(plot_data)
        self._name = name
        self._plot(plot_dict)
        self._store_plot_data(plot_data)

    def _prepare_data(self, data):
        """
        Restructure plot data.

        :param data: list of plot data dictionaries (see constructor)
        :return: nested dictionary ``{variable: {t0: {position in data: {filter_input, filtered, time_dim}}}}``
        """
        plot_dict = {}
        for i, plot_data in enumerate(data):
            t0 = plot_data.get("t0")
            filter_input = plot_data.get("filter_input")
            filtered = plot_data.get("filtered")
            var_dim = plot_data.get("var_dim")
            time_dim = plot_data.get("time_dim")
            for var in filtered.coords[var_dim].values:
                # the position in data is used as key to keep the original order of the entries for each t0
                plot_dict.setdefault(var, {}).setdefault(t0, {})[i] = {
                    "filter_input": filter_input.sel({var_dim: var}, drop=True),
                    "filtered": filtered.sel({var_dim: var}, drop=True),
                    "time_dim": time_dim}
        return plot_dict

    def _plot(self, plot_dict):
        """
        Draw one component plot and one residuum plot for each variable, t0 and entry of the prepared plot data.

        Plotting is best effort: if anything fails for a t0, the reason is logged and the next t0 is processed.

        :param plot_dict: nested dictionary as returned by `_prepare_data`
        """
        for var, viz_date_dict in plot_dict.items():
            for it0, (t0, viz_data) in enumerate(viz_date_dict.items()):
                fig = None
                try:
                    for ifilter in sorted(viz_data):
                        data = viz_data[ifilter]
                        filter_input = data["filter_input"]
                        filtered = data["filtered"]
                        time_dim = data["time_dim"]
                        time_axis = filtered.coords[time_dim].values
                        fig, ax = plt.subplots()

                        # plot backgrounds
                        self._plot_t0(ax, t0)

                        # original data
                        self._plot_data(ax, time_axis, filter_input, style="original")

                        # filter response
                        self._plot_data(ax, time_axis, filtered, style="FIR")

                        # set title, legend, and save plot
                        ax.set_xlim((time_axis[0], time_axis[-1]))

                        plt.title(f"Input of Filter ({str(var)})")
                        plt.legend()
                        fig.autofmt_xdate()
                        plt.tight_layout()
                        self.plot_name = f"FIR_{self._name}_{str(var)}_{it0}_{ifilter}"
                        self._save()

                        # plot residuum
                        fig, ax = plt.subplots()
                        self._plot_t0(ax, t0)
                        self._plot_data(ax, time_axis, filter_input - filtered, style="FIR")
                        ax.set_xlim((time_axis[0], time_axis[-1]))
                        plt.title(f"Residuum of Filter ({str(var)})")
                        plt.legend(loc="upper left")
                        fig.autofmt_xdate()
                        plt.tight_layout()

                        self.plot_name = f"FIR_{self._name}_{str(var)}_{it0}_{ifilter}_residuum"
                        self._save()
                except Exception as e:
                    logging.info(f"Could not create plot because of:\n{type(e)}\n{e}")
                    # do not leave a half-drawn figure open; NOTE(review): closing is harmless if _save already
                    # closed the figure, but confirm against AbstractPlotClass._save
                    if fig is not None:
                        plt.close(fig)

    def _plot_t0(self, ax, t0):
        """Mark t0 with a vertical line."""
        ax.axvline(t0, **self.style_dict["t0"])

    def _plot_series(self, ax, time_axis, data, style):
        """Draw data along the time axis using the given entry of the style dictionary."""
        ax.plot(time_axis, data, **self.style_dict[style])

    def _plot_data(self, ax, time_axis, data, style="original"):
        """Flatten the values of data and draw them as a single series."""
        self._plot_series(ax, time_axis, data.values.flatten(), style=style)

    def _store_plot_data(self, data):
        """Store plot data. Could be loaded in a notebook to redraw."""
        file = os.path.join(self.plot_folder, "plot_data.pickle")
        with open(file, "wb") as f:
            dill.dump(data, f)