diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index 2484d5ddeac883243f1920c9904f07a7c93bb2cb..77c12c43cf20b3176d9c0e5f07feb186eb368150 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -25,6 +25,7 @@ from mlair.helpers import TimeTrackingWrapper from mlair.plotting.abstract_plot_class import AbstractPlotClass from mlair.helpers.statistics import mann_whitney_u_test, represent_p_values_as_asteriks + logging.getLogger('matplotlib').setLevel(logging.WARNING) @@ -1234,50 +1235,31 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover return ax +@TimeTrackingWrapper class PlotTimeEvolutionMetric(AbstractPlotClass): - # mse_pandas = block_mse_per_station.sel(ahead=1, type="nn").to_pandas() - # - # mse_pandas.columns = mse_pandas.columns.strftime("%b %Y") - # years = mse_pandas.columns.strftime("%Y").to_list() - # months = mse_pandas.columns.strftime("%b").to_list() - # - # fig, ax = plt.subplots() - # sns.heatmap(mse_pandas, linewidths=1, square=True, cmap="coolwarm", ax=ax) - # locs = ax.get_xticks(minor=False).tolist() - # # ax.xaxis.set_major_locator(NullLocator()) - # # ax.xaxis.set_minor_locator(FixedLocator(locs)) - # - # ax.set_xticks(locs, minor=True) - # ax.set_xticklabels([m[0] for m in months], minor=True, rotation=0) - # - # locs_major = [] - # labels_major = [] - # for l, major, minor in zip(locs, years, months): - # if minor == "Jan": - # locs_major.append(l+0.001) - # labels_major.append(major) - # - # ax.set_xticks(locs_major) - # ax.set_xticklabels(labels_major, minor=False, rotation=0) - # ax.tick_params(axis="x", which="major", pad=15) - - - def __init__(self, data: xr.DataArray, ahead_dim="ahead", model_type_dim="type"): - - data = data.sel(ahead=1, type="nn").to_pandas() - years = data.columns.strftime("%Y").to_list() - months = data.columns.strftime("%b").to_list() - data.columns = data.columns.strftime("%b %Y") - self._plot(data) - - pass - - def prepare_data(self, data, ahead_dim, model_type_dim): - pass - - def _set_ticks(self, ax, years, months): - locs = ax.get_xticks(minor=False).tolist() + def __init__(self, data: xr.DataArray, ahead_dim="ahead", model_type_dim="type", plot_folder=".", + error_measure: str = "mse", error_unit: str = None,): + super().__init__(plot_folder, "time_evolution_mse") + self.title = error_measure + f" (in {error_unit})" if error_unit is not None else "" + plot_name = self.plot_name + vmin = int(data.quantile(0.05)) + vmax = int(data.quantile(0.95)) + for t in data[model_type_dim]: + # note: could be expanded to create plot per ahead step + print(data.sel({model_type_dim: t}).mean((ahead_dim, "index"))) + plot_data = data.sel({model_type_dim: t}).mean(ahead_dim).to_pandas() + years = plot_data.columns.strftime("%Y").to_list() + months = plot_data.columns.strftime("%b").to_list() + plot_data.columns = plot_data.columns.strftime("%b %Y") + self.plot_name = f"{plot_name}_{t.values}" + self._plot(plot_data, years, months, vmin, vmax) + + @staticmethod + def _set_ticks(ax, years, months): + from matplotlib.ticker import IndexLocator + ax.xaxis.set_major_locator(IndexLocator(1, 0.5)) + locs = ax.get_xticks(minor=False).tolist()[:len(months)] ax.set_xticks(locs, minor=True) ax.set_xticklabels([m[0] for m in months], minor=True, rotation=0) locs_major = [] @@ -1293,11 +1275,17 @@ class PlotTimeEvolutionMetric(AbstractPlotClass): ax.set_xticklabels(labels_major, minor=False, rotation=0) ax.tick_params(axis="x", which="major", pad=15) - def _plot(self, data, years, months): - fig, ax = plt.subplots() - sns.heatmap(data, linewidths=1, square=True, cmap="coolwarm", ax=ax, robust=True) + def _plot(self, data, years, months, vmin=None, vmax=None): + fig, ax = plt.subplots(figsize=(data.shape[1] / 5, data.shape[0] / 2.8)) + sns.heatmap(data, linewidths=1, cmap="coolwarm", ax=ax, vmin=vmin, vmax=vmax, cbar_kws={"aspect": 10}) # or cmap="Spectral_r", cmap="RdYlBu_r", cmap="coolwarm", + # square=True self._set_ticks(ax, years, months) + ax.set_xlabel(None) + ax.set_ylabel(None) + plt.title(self.title) + plt.tight_layout() + self._save() if __name__ == "__main__": diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index ca0a1143e80cf9b47ac6795add62d58c6dc9f69c..ce08fbdaa8a2ef89108dfa37d0af51e714e5712b 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -655,9 +655,9 @@ class PostProcessing(RunEnvironment): try: if "PlotTimeEvolutionMetric" in plot_list: - PlotTimeEvolutionMetric(self.block_mse_per_station, model_type_dim=self.model_type_dim, ahead_dim=self.ahead_dim) - #plot_folder=self.plot_path, - + PlotTimeEvolutionMetric(self.block_mse_per_station, plot_folder=self.plot_path, + model_type_dim=self.model_type_dim, ahead_dim=self.ahead_dim, + error_measure="Mean Squared Error", error_unit=r"ppb$^2$") except Exception as e: logging.error(f"Could not create plot PlotTimeEvolutionMetric due to the following error: {e}" f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")