diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index bd2012c3fe9f53e7e07bfb4bfc4cde096c2dc891..2484d5ddeac883243f1920c9904f07a7c93bb2cb 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -1234,6 +1234,72 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover return ax +class PlotTimeEvolutionMetric(AbstractPlotClass): + + # mse_pandas = block_mse_per_station.sel(ahead=1, type="nn").to_pandas() + # + # mse_pandas.columns = mse_pandas.columns.strftime("%b %Y") + # years = mse_pandas.columns.strftime("%Y").to_list() + # months = mse_pandas.columns.strftime("%b").to_list() + # + # fig, ax = plt.subplots() + # sns.heatmap(mse_pandas, linewidths=1, square=True, cmap="coolwarm", ax=ax) + # locs = ax.get_xticks(minor=False).tolist() + # # ax.xaxis.set_major_locator(NullLocator()) + # # ax.xaxis.set_minor_locator(FixedLocator(locs)) + # + # ax.set_xticks(locs, minor=True) + # ax.set_xticklabels([m[0] for m in months], minor=True, rotation=0) + # + # locs_major = [] + # labels_major = [] + # for l, major, minor in zip(locs, years, months): + # if minor == "Jan": + # locs_major.append(l+0.001) + # labels_major.append(major) + # + # ax.set_xticks(locs_major) + # ax.set_xticklabels(labels_major, minor=False, rotation=0) + # ax.tick_params(axis="x", which="major", pad=15) + + + def __init__(self, data: xr.DataArray, ahead_dim="ahead", model_type_dim="type"): + + data = data.sel(ahead=1, type="nn").to_pandas() + years = data.columns.strftime("%Y").to_list() + months = data.columns.strftime("%b").to_list() + data.columns = data.columns.strftime("%b %Y") + self._plot(data) + + pass + + def prepare_data(self, data, ahead_dim, model_type_dim): + pass + + def _set_ticks(self, ax, years, months): + locs = ax.get_xticks(minor=False).tolist() + ax.set_xticks(locs, minor=True) + ax.set_xticklabels([m[0] for m in months], minor=True, rotation=0) + locs_major = [] + labels_major = [] + for l, major, minor in zip(locs, years, months): + if minor == "Jan": + locs_major.append(l + 0.001) + labels_major.append(major) + if len(locs_major) == 0: # in case there is less than a year and no Jan included + locs_major = locs[0] + 0.001 + labels_major = years[0] + ax.set_xticks(locs_major) + ax.set_xticklabels(labels_major, minor=False, rotation=0) + ax.tick_params(axis="x", which="major", pad=15) + + def _plot(self, data, years, months): + fig, ax = plt.subplots() + sns.heatmap(data, linewidths=1, square=True, cmap="coolwarm", ax=ax, robust=True) + # or cmap="Spectral_r", cmap="RdYlBu_r", cmap="coolwarm", + self._set_ticks(ax, years, months) + + if __name__ == "__main__": stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] path = "../../testrun_network/forecasts" diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 8a5aa98b22d3f3c2cc4e1f32b8f816a14146b716..ca0a1143e80cf9b47ac6795add62d58c6dc9f69c 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -17,13 +17,13 @@ import xarray as xr from mlair.configuration import path_config from mlair.data_handler import Bootstraps, KerasIterator -from mlair.helpers.datastore import NameNotFoundInDataStore +from mlair.helpers.datastore import NameNotFoundInDataStore, NameNotFoundInScope from mlair.helpers import TimeTracking, TimeTrackingWrapper, statistics, extract_value, remove_items, to_list, tables from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel from mlair.model_modules import AbstractModelClass from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \ PlotCompetitiveSkillScore, PlotTimeSeries, PlotFeatureImportanceSkillScore, PlotConditionalQuantiles, \ - PlotSeparationOfScales, PlotSampleUncertaintyFromBootstrap + PlotSeparationOfScales, PlotSampleUncertaintyFromBootstrap, PlotTimeEvolutionMetric from mlair.plotting.data_insight_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram, \ PlotPeriodogram, PlotDataHistogram from mlair.run_modules.run_environment import RunEnvironment @@ -480,7 +480,7 @@ class PostProcessing(RunEnvironment): """ try: # is only available if a model was trained in training stage model = self.data_store.get("model") - except NameNotFoundInDataStore: + except (NameNotFoundInDataStore, NameNotFoundInScope): logging.info("No model was saved in data store. Try to load model from experiment path.") model_name = self.data_store.get("model_name", "model") model: AbstractModelClass = self.data_store.get("model", "model") @@ -652,7 +652,16 @@ class PostProcessing(RunEnvironment): except Exception as e: logging.error(f"Could not create plot PlotPeriodogram due to the following error: {e}" f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") - + + try: + if "PlotTimeEvolutionMetric" in plot_list: + PlotTimeEvolutionMetric(self.block_mse_per_station, model_type_dim=self.model_type_dim, ahead_dim=self.ahead_dim) + #plot_folder=self.plot_path, + + except Exception as e: + logging.error(f"Could not create plot PlotTimeEvolutionMetric due to the following error: {e}" + f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") + @TimeTrackingWrapper def calculate_test_score(self): """Evaluate test score of model and save locally."""