diff --git a/mlair/helpers/filter.py b/mlair/helpers/filter.py index 7b9a175245a911e7d69bb90544ae025022488eeb..75c94b36e993c151d2dcc12988857e99bd63ab7e 100644 --- a/mlair/helpers/filter.py +++ b/mlair/helpers/filter.py @@ -531,12 +531,20 @@ class ClimateFIRFilter: extend_length_history = length if minimum_length is None else minimum_length + int((length + 1) / 2) extend_length_future = int((length + 1) / 2) + 1 + # collect some data for visualization + plot_pos = np.array([0.25, 1.5, 2.75, 4]) * 365 * fs + plot_dates = [data.isel({time_dim: int(pos)}).coords[time_dim].values for pos in plot_pos if + pos < len(data.coords[time_dim])] + coll = [] for var in reversed(data.coords[var_dim].values): # self._tmp_analysis(data, apriori, var, var_dim, length, time_dim, new_dim, h) logging.info(f"{data.coords['Stations'].values[0]} ({var}): sel data") + # empty plot data collection + plot_data = [] + _start = pd.to_datetime(data.coords[time_dim].min().values).year _end = pd.to_datetime(data.coords[time_dim].max().values).year filt_coll = [] @@ -550,19 +558,19 @@ class ClimateFIRFilter: continue # combine historical data / observation [t0-length,t0] and climatological statistics [t0+1,t0+length] - logging.info(f"{data.coords['Stations'].values[0]} ({var}): history") + # logging.info(f"{data.coords['Stations'].values[0]} ({var}): history") if new_dim not in d.coords: history = self._shift_data(d, range(int(-extend_length_history), 1), time_dim, var_dim, new_dim) else: history = d.sel({new_dim: slice(int(-extend_length_history), 0)}) - logging.info(f"{data.coords['Stations'].values[0]} ({var}): future") - diff = (a - history.sel(window=slice(-24, 1)).mean(new_dim)) + # logging.info(f"{data.coords['Stations'].values[0]} ({var}): future") + # diff = (a - history.sel(window=slice(-24, 1)).mean(new_dim)) if new_dim not in a.coords: future = self._shift_data(a, range(1, extend_length_future), time_dim, var_dim, new_dim) # future = self._shift_data(a, range(1, int((length - 1) / 2) + 1), time_dim, var_dim, new_dim) - diff else: future = a.sel({new_dim: slice(1, extend_length_future)}) - logging.info(f"{data.coords['Stations'].values[0]} ({var}): concat to filter input") + # logging.info(f"{data.coords['Stations'].values[0]} ({var}): concat to filter input") filter_input_data = xr.concat([history.dropna(time_dim), future], dim=new_dim, join="left") try: @@ -595,20 +603,31 @@ class ClimateFIRFilter: else: filt_coll.append(filt.sel({new_dim: slice(-minimum_length, 0)}, drop=True)) - # plot - # ToDo: enable plotting again - # if self.plot_path is not None: - # for i, time_pos in enumerate([0.25, 1.5, 2.75, 4]): # [0.25, 1.5, 2.75, 4] x 365 days - # try: - # pos = int(time_pos * 365 * fs) - # filter_example = filter_input_data.isel({time_dim: pos}) - # t0 = filter_example.coords[time_dim].values - # t_slice = filter_input_data.isel( - # {time_dim: slice(pos - int((length - 1) / 2), pos + int((length - 1) / 2) + 1)}).coords[ - # time_dim].values - # # self.plot(d, filter_example, var_dim, time_dim, t_slice, t0, f"{plot_index}_{i}") - # except IndexError: - # pass + # visualization + # ToDo: move this code part into a separate plot method that is called on the fly, not afterwards + # just leave a call self.plot(*args) here! + for viz_date in set(plot_dates).intersection(filt.coords[time_dim].values): + td_type = {"1d": "D", "1H": "h"}.get(sampling) + t_minus = viz_date + np.timedelta64(int(-extend_length_history), td_type) + t_plus = viz_date + np.timedelta64(int(extend_length_future), td_type) + + tmp_filter_data = self._shift_data(d.sel({time_dim: slice(t_minus, t_plus)}), + range(int(-extend_length_history), int(extend_length_future)), + time_dim, var_dim, new_dim) + tmp_filt_nc = xr.apply_ufunc(fir_filter_convolve_vectorized, + tmp_filter_data.sel({time_dim: viz_date}), + input_core_dims=[[new_dim]], + output_core_dims=[[new_dim]], + vectorize=True, + kwargs={"h": h}, + output_dtypes=[d.dtype]) + + valid_range = range(int((length + 1) / 2) if minimum_length is None else minimum_length, 1) + plot_data.append({"t0": viz_date, + "filt": filt.sel({time_dim: viz_date}), + "filter_input": filter_input_data.sel({time_dim: viz_date}), + "filt_nc": tmp_filt_nc, + "valid_range": valid_range}) # select only values at tmp dimension 0 at each point in time # coll.append(filt.sel({new_dim: 0}, drop=True)) @@ -616,6 +635,27 @@ class ClimateFIRFilter: coll.append(xr.concat(filt_coll, time_dim)) gc.collect() + # plot + # ToDo: enable plotting again + if self.plot_path is not None: + for i, viz_data in enumerate(plot_data): + self.plot_new(viz_data, data.sel({var_dim: [var]}), var_dim, time_dim, new_dim, f"{plot_index}_{i}", + sampling) + + # for i, time_pos in enumerate([0.25, 1.5, 2.75, 4]): # [0.25, 1.5, 2.75, 4] x 365 days + # try: + # + # plot_data = coll[-1] + # pos = int(time_pos * 365 * fs) + # filter_example = plot_data.isel({time_dim: pos}) + # t0 = filter_example.coords[time_dim].values + # + # slice_tmp = slice(pos - abs(plot_data.coords[new_dim].values.min()), pos + abs(plot_data.coords[new_dim].values.min())) + # t_slice = plot_data.isel({time_dim: slice_tmp}).coords[time_dim].values + # self.plot(data.sel({var_dim: [var]}), filter_example, var_dim, time_dim, t_slice, t0, f"{plot_index}_{i}") + # except IndexError: + # pass + logging.info(f"{data.coords['Stations'].values[0]}: concat all variables") res = xr.concat(coll, var_dim) # create result array with same shape like input data, gabs are filled by nans @@ -665,6 +705,61 @@ class ClimateFIRFilter: res.name = index_name return res + def plot_new(self, viz_data, orig_data, var_dim, time_dim, new_dim, plot_index, sampling): + try: + td_type = {"1d": "D", "1H": "h"}.get(sampling) + filter_example = viz_data["filt"] + filter_input = viz_data["filter_input"] + filter_nc = viz_data["filt_nc"] + valid_range = viz_data["valid_range"] + t0 = viz_data["t0"] + t_minus = t0 + np.timedelta64(filter_input.coords[new_dim].values.min(), td_type) + t_plus = t0 + np.timedelta64(filter_input.coords[new_dim].values.max(), td_type) + t_slice = slice(t_minus, t_plus) + data = orig_data.sel({time_dim: t_slice}) + plot_folder = os.path.join(os.path.abspath(self.plot_path), "climFIR") + if not os.path.exists(plot_folder): + os.makedirs(plot_folder) + + for var in data.coords[var_dim]: + time_axis = data.sel({var_dim: var, time_dim: t_slice}).coords[time_dim].values + rc_params = {'axes.labelsize': 'large', + 'xtick.labelsize': 'large', + 'ytick.labelsize': 'large', + 'legend.fontsize': 'large', + 'axes.titlesize': 'large', + } + plt.rcParams.update(rc_params) + fig, ax = plt.subplots() + + ax.axvspan(t0 + np.timedelta64(-valid_range.start, td_type), + t0 + np.timedelta64(valid_range.stop - 1, td_type), color="whitesmoke", label="valid area") + + ax.axvline(t0, color="lightgrey", lw=6, label="time of interest ($t_0$)") + ax.plot(time_axis, data.sel({var_dim: var, time_dim: t_slice}).values.flatten(), + color="darkgrey", linestyle="dashed", label="original") + d_tmp = filter_input.sel( + {var_dim: var, new_dim: slice(0, filter_input.coords[new_dim].values.max())}).values.flatten() + # ax.plot(time_axis[len(time_axis) - len(d_tmp):], d_tmp, color="darkgrey", linestyle=(0 ,(1, 1)), label="filter input") + ax.plot(time_axis[len(time_axis) - len(d_tmp):], d_tmp, color="darkgrey", linestyle="solid", + label="estimated future") + # data.sel({var_dim: var, time_dim: time_dim_slice}).plot() + # tmp_comb.sel({var_dim: var}).plot() + # d_filt = filter_example.sel({var_dim: var}).values.flatten() + ax.plot(time_axis, filter_example.sel({var_dim: var}).values.flatten(), + color="black", linestyle="solid", label="filter response", linewidth=2) + ax.plot(time_axis, filter_nc.sel({var_dim: var}).values.flatten(), + color="black", linestyle="dashed", label="ideal filter response", linewidth=2) + plt.title(f"Input of ClimFilter ({str(var.values)})") + plt.legend() + fig.autofmt_xdate() + plt.tight_layout() + plot_name = os.path.join(plot_folder, f"climFIR_{self.plot_name}_{str(var.values)}_{plot_index}.pdf") + plt.savefig(plot_name, dpi=300) + plt.close('all') + except: + pass + def plot(self, data, tmp_comb, var_dim, time_dim, time_dim_slice, t0, plot_index): try: plot_folder = os.path.join(os.path.abspath(self.plot_path), "climFIR") @@ -683,7 +778,8 @@ class ClimateFIRFilter: ax.axvline(t0, color="lightgrey", lw=6, label="time of interest ($t_0$)") ax.plot(time_axis, data.sel({var_dim: var, time_dim: time_dim_slice}).values.flatten(), color="darkgrey", linestyle="--", label="original") - ax.plot(time_axis, tmp_comb.sel({var_dim: var}).values.flatten(), color="black", label="filter input") + d_filt = tmp_comb.sel({var_dim: var}).values.flatten() + ax.plot(time_axis[:len(d_filt)], d_filt, color="black", label="filter input") # data.sel({var_dim: var, time_dim: time_dim_slice}).plot() # tmp_comb.sel({var_dim: var}).plot() plt.title(f"Input of ClimFilter ({str(var.values)})")