diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index 299febfe7db6f3237b3390ec6c0563506d61a98c..c7647ef5bf5b5b6c46eae9318c0fd99b294292c6 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -25,6 +25,7 @@ from mlair.helpers import TimeTrackingWrapper from mlair.plotting.abstract_plot_class import AbstractPlotClass from mlair.helpers.statistics import mann_whitney_u_test, represent_p_values_as_asteriks + logging.getLogger('matplotlib').setLevel(logging.WARNING) @@ -1239,6 +1240,85 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover return ax +@TimeTrackingWrapper +class PlotTimeEvolutionMetric(AbstractPlotClass): + + def __init__(self, data: xr.DataArray, ahead_dim="ahead", model_type_dim="type", plot_folder=".", + error_measure: str = "mse", error_unit: str = None, model_name: str = "NN", + model_indicator: str = "nn", time_dim="index"): + super().__init__(plot_folder, "time_evolution_mse") + self.title = error_measure + f" (in {error_unit})" if error_unit is not None else "" + plot_name = self.plot_name + vmin = int(data.quantile(0.05)) + vmax = int(data.quantile(0.95)) + data = self._prepare_data(data, time_dim, model_type_dim, model_indicator, model_name) + + for t in data[model_type_dim]: + # note: could be expanded to create plot per ahead step + plot_data = data.sel({model_type_dim: t}).mean(ahead_dim).to_pandas() + years = plot_data.columns.strftime("%Y").to_list() + months = plot_data.columns.strftime("%b").to_list() + plot_data.columns = plot_data.columns.strftime("%b %Y") + self.plot_name = f"{plot_name}_{t.values}" + self._plot(plot_data, years, months, vmin, vmax, str(t.values)) + + @staticmethod + def _find_nan_edge(data, time_dim): + coll = [] + for i in data: + if bool(i) is False: + break + else: + coll.append(i[time_dim].values) + return coll + + def _prepare_data(self, data, time_dim, model_type_dim, model_indicator, model_name): + # remove nans at begin and end + nan_locs = data.isnull().all(helpers.remove_items(data.dims, time_dim)) + nans_at_end = self._find_nan_edge(reversed(nan_locs), time_dim) + nans_at_begin = self._find_nan_edge(nan_locs, time_dim) + data = data.drop(nans_at_begin + nans_at_end, time_dim) + # rename nn model + data[model_type_dim] = [v if v != model_indicator else model_name for v in data[model_type_dim].data.tolist()] + return data + + @staticmethod + def _set_ticks(ax, years, months): + from matplotlib.ticker import IndexLocator + ax.xaxis.set_major_locator(IndexLocator(1, 0.5)) + locs = ax.get_xticks(minor=False).tolist()[:len(months)] + ax.set_xticks(locs, minor=True) + ax.set_xticklabels([m[0] for m in months], minor=True, rotation=0) + locs_major = [] + labels_major = [] + for l, major, minor in zip(locs, years, months): + if minor == "Jan": + locs_major.append(l + 0.001) + labels_major.append(major) + if len(locs_major) == 0: # in case there is less than a year and no Jan included + locs_major = locs[0] + 0.001 + labels_major = years[0] + ax.set_xticks(locs_major) + ax.set_xticklabels(labels_major, minor=False, rotation=0) + ax.tick_params(axis="x", which="major", pad=15) + + @staticmethod + def _aspect_cbar(val): + return min(max(1.25 * val + 7.5, 10), 30) + + def _plot(self, data, years, months, vmin=None, vmax=None, subtitle=None): + fig, ax = plt.subplots(figsize=(max(data.shape[1] / 6, 12), max(data.shape[0] / 3.5, 2))) + data.sort_index(inplace=True) + sns.heatmap(data, linewidths=1, cmap="coolwarm", ax=ax, vmin=vmin, vmax=vmax, + cbar_kws={"aspect": self._aspect_cbar(data.shape[0])}) + # or cmap="Spectral_r", cmap="RdYlBu_r", cmap="coolwarm", + # square=True + self._set_ticks(ax, years, months) + ax.set(xlabel=None, ylabel=None, title=self.title if subtitle is None else f"{subtitle}\n{self.title}") + plt.tight_layout() + self._save() + + if __name__ == "__main__": stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] path = "../../testrun_network/forecasts" diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 1dfa22a90261df01b12834783f02d96261032394..07ef1ce46b7951c812fbd17cc06916bbd8cb9caf 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -17,13 +17,13 @@ import xarray as xr from mlair.configuration import path_config from mlair.data_handler import Bootstraps, KerasIterator -from mlair.helpers.datastore import NameNotFoundInDataStore +from mlair.helpers.datastore import NameNotFoundInDataStore, NameNotFoundInScope from mlair.helpers import TimeTracking, TimeTrackingWrapper, statistics, extract_value, remove_items, to_list, tables from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel from mlair.model_modules import AbstractModelClass from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \ PlotCompetitiveSkillScore, PlotTimeSeries, PlotFeatureImportanceSkillScore, PlotConditionalQuantiles, \ - PlotSeparationOfScales, PlotSampleUncertaintyFromBootstrap + PlotSeparationOfScales, PlotSampleUncertaintyFromBootstrap, PlotTimeEvolutionMetric from mlair.plotting.data_insight_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram, \ PlotPeriodogram, PlotDataHistogram from mlair.run_modules.run_environment import RunEnvironment @@ -489,7 +489,7 @@ class PostProcessing(RunEnvironment): """ try: # is only available if a model was trained in training stage model = self.data_store.get("model") - except NameNotFoundInDataStore: + except (NameNotFoundInDataStore, NameNotFoundInScope): logging.info("No model was saved in data store. Try to load model from experiment path.") model_name = self.data_store.get("model_name", "model") model: AbstractModelClass = self.data_store.get("model", "model") @@ -663,7 +663,18 @@ class PostProcessing(RunEnvironment): except Exception as e: logging.error(f"Could not create plot PlotPeriodogram due to the following error: {e}" f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") - + + try: + if "PlotTimeEvolutionMetric" in plot_list: + PlotTimeEvolutionMetric(self.block_mse_per_station, plot_folder=self.plot_path, + model_type_dim=self.model_type_dim, ahead_dim=self.ahead_dim, + error_measure="Mean Squared Error", error_unit=r"ppb$^2$", + model_indicator=self.forecast_indicator, model_name=self.model_display_name, + time_dim=self.index_dim) + except Exception as e: + logging.error(f"Could not create plot PlotTimeEvolutionMetric due to the following error: {e}" + f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") + @TimeTrackingWrapper def calculate_test_score(self): """Evaluate test score of model and save locally."""