diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py index f51d9c49de48abe21fd98316dca985d9612345a9..95b482df3c2dd6d04d8b3029a2f9091d551d2829 100644 --- a/mlair/plotting/data_insight_plotting.py +++ b/mlair/plotting/data_insight_plotting.py @@ -3,6 +3,7 @@ __author__ = "Lukas Leufen, Felix Kleinert" __date__ = '2021-04-13' from typing import List, Dict +import dill import os import logging import multiprocessing @@ -862,3 +863,218 @@ def f_proc_hist(data, variables, n_bins, variables_dim): # pragma: no cover res[var], bin_edges[var] = np.histogram(d.values, n_bins) interval_width[var] = bin_edges[var][1] - bin_edges[var][0] return res, interval_width, bin_edges + + +class PlotClimateFirFilter(AbstractPlotClass): + """ + Plot climate FIR filter components. + + * Creates a separate folder climFIR inside the given plot directory. + * For each station up to 4 examples are shown (1 for each season). + * Each filtered component and its residuum is drawn in a separate plot. + * A filter component plot includes the climate FIR input, the filter response, the true non-causal (ideal) filter + input, and the corresponding ideal response (containing information about future) + * A filter residuum plot include the climate FIR residuum and the ideal filter residuum. + """ + + def __init__(self, plot_folder, plot_data, sampling, name): + + from mlair.helpers.filter import fir_filter_convolve + + # adjust default plot parameters + rc_params = { + 'axes.labelsize': 'large', + 'xtick.labelsize': 'large', + 'ytick.labelsize': 'large', + 'legend.fontsize': 'medium', + 'axes.titlesize': 'large'} + if plot_folder is None: + return + + self.style_dict = { + "original": {"color": "darkgrey", "linestyle": "dashed", "label": "original"}, + "apriori": {"color": "darkgrey", "linestyle": "solid", "label": "estimated future"}, + "clim": {"color": "black", "linestyle": "solid", "label": "clim filter", "linewidth": 2}, + "ideal": {"color": "black", "linestyle": "dashed", "label": "ideal filter", "linewidth": 2}, + "valid_area": {"color": "whitesmoke", "label": "valid area"}, + "t0": {"color": "lightgrey", "lw": 6, "label": "$t_0$"} + } + + plot_folder = os.path.join(os.path.abspath(plot_folder), "climFIR") + self.fir_filter_convolve = fir_filter_convolve + super().__init__(plot_folder, plot_name=None, rc_params=rc_params) + plot_dict, new_dim = self._prepare_data(plot_data) + self._name = name + self._plot(plot_dict, sampling, new_dim) + self._store_plot_data(plot_data) + + def _prepare_data(self, data): + """Restructure plot data.""" + plot_dict = {} + new_dim = None + for i, o in enumerate(range(len(data))): + plot_data = data[i] + for p_d in plot_data: + var = p_d.get("var") + t0 = p_d.get("t0") + filter_input = p_d.get("filter_input") + filter_input_nc = p_d.get("filter_input_nc") + valid_range = p_d.get("valid_range") + time_range = p_d.get("time_range") + if new_dim is None: + new_dim = p_d.get("new_dim") + else: + assert new_dim == p_d.get("new_dim") + h = p_d.get("h") + plot_dict_var = plot_dict.get(var, {}) + plot_dict_t0 = plot_dict_var.get(t0, {}) + plot_dict_order = {"filter_input": filter_input, + "filter_input_nc": filter_input_nc, + "valid_range": valid_range, + "time_range": time_range, + "order": len(h), "h": h} + plot_dict_t0[i] = plot_dict_order + plot_dict_var[t0] = plot_dict_t0 + plot_dict[var] = plot_dict_var + return plot_dict, new_dim + + def _plot(self, plot_dict, sampling, new_dim="window"): + td_type = {"1d": "D", "1H": "h"}.get(sampling) + for var, viz_date_dict in plot_dict.items(): + for it0, t0 in enumerate(viz_date_dict.keys()): + viz_data = viz_date_dict[t0] + residuum_true = None + for ifilter in sorted(viz_data.keys()): + data = viz_data[ifilter] + filter_input = data["filter_input"] + filter_input_nc = data["filter_input_nc"] if residuum_true is None else residuum_true.sel( + {new_dim: filter_input.coords[new_dim]}) + valid_range = data["valid_range"] + time_axis = data["time_range"] + filter_order = data["order"] + h = data["h"] + fig, ax = plt.subplots() + + # plot backgrounds + self._plot_valid_area(ax, t0, valid_range, td_type) + self._plot_t0(ax, t0) + + # original data + self._plot_original_data(ax, time_axis, filter_input_nc) + + # clim apriori + self._plot_apriori(ax, time_axis, filter_input, new_dim, ifilter) + + # clim filter response + residuum_estimated = self._plot_clim_filter(ax, time_axis, filter_input, new_dim, h, + output_dtypes=filter_input.dtype) + + # ideal filter response + residuum_true = self._plot_ideal_filter(ax, time_axis, filter_input_nc, new_dim, h, + output_dtypes=filter_input.dtype) + + # set title, legend, and save plot + xlims = self._set_xlim(ax, t0, filter_order, valid_range, td_type, time_axis) + + plt.title(f"Input of ClimFilter ({str(var)})") + plt.legend() + fig.autofmt_xdate() + plt.tight_layout() + self.plot_name = f"climFIR_{self._name}_{str(var)}_{it0}_{ifilter}" + self._save() + + # plot residuum + fig, ax = plt.subplots() + self._plot_valid_area(ax, t0, valid_range, td_type) + self._plot_t0(ax, t0) + self._plot_series(ax, time_axis, residuum_true.values.flatten(), style="ideal") + self._plot_series(ax, time_axis, residuum_estimated.values.flatten(), style="clim") + ax.set_xlim(xlims) + plt.title(f"Residuum of ClimFilter ({str(var)})") + plt.legend(loc="upper left") + fig.autofmt_xdate() + plt.tight_layout() + + self.plot_name = f"climFIR_{self._name}_{str(var)}_{it0}_{ifilter}_residuum" + self._save() + + def _set_xlim(self, ax, t0, order, valid_range, td_type, time_axis): + """ + Set xlims + + Use order and valid_range to find a good zoom in that hides edges of filter values that are effected by reduced + filter order. Limits are returned to be usable for other plots. + """ + t_minus_delta = max(1.5 * valid_range.start, 0.3 * order) + t_plus_delta = max(0.5 * valid_range.start, 0.3 * order) + t_minus = t0 + np.timedelta64(-int(t_minus_delta), td_type) + t_plus = t0 + np.timedelta64(int(t_plus_delta), td_type) + ax_start = max(t_minus, time_axis[0]) + ax_end = min(t_plus, time_axis[-1]) + ax.set_xlim((ax_start, ax_end)) + return ax_start, ax_end + + def _plot_valid_area(self, ax, t0, valid_range, td_type): + ax.axvspan(t0 + np.timedelta64(-valid_range.start, td_type), + t0 + np.timedelta64(valid_range.stop - 1, td_type), **self.style_dict["valid_area"]) + + def _plot_t0(self, ax, t0): + ax.axvline(t0, **self.style_dict["t0"]) + + def _plot_series(self, ax, time_axis, data, style): + ax.plot(time_axis, data, **self.style_dict[style]) + + def _plot_original_data(self, ax, time_axis, data): + # original data + filter_input_nc = data + self._plot_series(ax, time_axis, filter_input_nc.values.flatten(), style="original") + # self._plot_series(ax, time_axis, filter_input_nc.values.flatten(), color="darkgrey", linestyle="dashed", + # label="original") + + def _plot_apriori(self, ax, time_axis, data, new_dim, ifilter): + # clim apriori + filter_input = data + if ifilter == 0: + d_tmp = filter_input.sel( + {new_dim: slice(0, filter_input.coords[new_dim].values.max())}).values.flatten() + else: + d_tmp = filter_input.values.flatten() + self._plot_series(ax, time_axis[len(time_axis) - len(d_tmp):], d_tmp, style="apriori") + # self._plot_series(ax, time_axis[len(time_axis) - len(d_tmp):], d_tmp, color="darkgrey", linestyle="solid", + # label="estimated future") + + def _plot_clim_filter(self, ax, time_axis, data, new_dim, h, output_dtypes): + filter_input = data + # clim filter response + filt = xr.apply_ufunc(self.fir_filter_convolve, filter_input, + input_core_dims=[[new_dim]], + output_core_dims=[[new_dim]], + vectorize=True, + kwargs={"h": h}, + output_dtypes=[output_dtypes]) + self._plot_series(ax, time_axis, filt.values.flatten(), style="clim") + # self._plot_series(ax, time_axis, filt.values.flatten(), color="black", linestyle="solid", + # label="clim filter response", linewidth=2) + residuum_estimated = filter_input - filt + return residuum_estimated + + def _plot_ideal_filter(self, ax, time_axis, data, new_dim, h, output_dtypes): + filter_input_nc = data + # ideal filter response + filt = xr.apply_ufunc(self.fir_filter_convolve, filter_input_nc, + input_core_dims=[[new_dim]], + output_core_dims=[[new_dim]], + vectorize=True, + kwargs={"h": h}, + output_dtypes=[output_dtypes]) + self._plot_series(ax, time_axis, filt.values.flatten(), style="ideal") + # self._plot_series(ax, time_axis, filt.values.flatten(), color="black", linestyle="dashed", + # label="ideal filter response", linewidth=2) + residuum_true = filter_input_nc - filt + return residuum_true + + def _store_plot_data(self, data): + """Store plot data. Could be loaded in a notebook to redraw.""" + file = os.path.join(self.plot_folder, "plot_data.pickle") + with open(file, "wb") as f: + dill.dump(data, f)