diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py index 7ec262033f781294c1ee71a885b3b243136fa47c..7911a572d25e58b96007ab26b35a3eff153acf88 100644 --- a/mlair/helpers/helpers.py +++ b/mlair/helpers/helpers.py @@ -281,6 +281,14 @@ def str2bool(v): raise argparse.ArgumentTypeError('Boolean value expected.') +def squeeze_coords(d): + """Look for unused coords and remove them. Does only work for xarray DataArrays.""" + try: + return d.drop(set(d.coords.keys()).difference(d.dims)) + except Exception: + return d + + # def convert_size(size_bytes): # if size_bytes == 0: # return "0B" diff --git a/mlair/helpers/statistics.py b/mlair/helpers/statistics.py index 5f3aa45161530ff7d425ccbc7625dd7e081d8839..ee07d7674c73439254cec5eb9f046fb540f3e9a3 100644 --- a/mlair/helpers/statistics.py +++ b/mlair/helpers/statistics.py @@ -12,6 +12,7 @@ from typing import Union, Tuple, Dict, List import itertools from collections import OrderedDict from mlair.helpers import to_list +from mlair.helpers.helpers import squeeze_coords Data = Union[xr.DataArray, pd.DataFrame] @@ -248,7 +249,17 @@ def calculate_error_metrics(a, b, dim): ioa = index_of_agreement(a, b, dim) mnmb = modified_normalized_mean_bias(a, b, dim) n = (a - b).notnull().sum(dim) - return {"mse": mse, "rmse": rmse, "mae": mae, "ioa": ioa, "mnmb": mnmb, "n": n} + results = {"mse": mse, "rmse": rmse, "mae": mae, "ioa": ioa, "mnmb": mnmb, "n": n} + return {k: squeeze_coords(v) for k, v in results.items()} + + +def get_error_metrics_units(base_unit): + return {"mse": f"{base_unit}$^2$", "rmse": base_unit, "mae": base_unit, "ioa": None, "mnmb": None, "n": None} + + +def get_error_metrics_long_name(): + return {"mse": "mean squared error", "rmse": "root mean squared error", "mae": "mean absolute error", + "ioa": "index of agreement", "mnmb": "modified normalized mean bias", "n": "count"} def mann_whitney_u_test(data: pd.DataFrame, reference_col_name: str, **kwargs): diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index 22e55220c6aa54f8352435cc5f5ddaf4f072f0b7..bb3e4ba13440b037d93ae7e568f8ac66a7f4131c 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -1097,8 +1097,9 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover def __init__(self, data: xr.DataArray, plot_folder: str = ".", model_type_dim: str = "type", error_measure: str = "mse", error_unit: str = None, dim_name_boots: str = 'boots', block_length: str = None, model_name: str = "NN", model_indicator: str = "nn", - ahead_dim: str = "ahead", sampling: Union[str, Tuple[str]] = "", season_annotation: str = None): - super().__init__(plot_folder, "sample_uncertainty_from_bootstrap") + ahead_dim: str = "ahead", sampling: Union[str, Tuple[str]] = "", season_annotation: str = None, + apply_root: bool = True, plot_name="sample_uncertainty_from_bootstrap"): + super().__init__(plot_folder, plot_name) self.default_plot_name = self.plot_name self.model_type_dim = model_type_dim self.ahead_dim = ahead_dim @@ -1119,10 +1120,11 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover for orientation, utest, agg_type in variants: self._plot(orientation=orientation, apply_u_test=utest, agg_type=agg_type, season=_season) - # plot root of metric (rmse) - self._apply_root() - for orientation, utest, agg_type in variants: - self._plot(orientation=orientation, apply_u_test=utest, agg_type=agg_type, tag="_sqrt", season=_season) + if apply_root is True: + # plot root of metric (rmse) + self._apply_root() + for orientation, utest, agg_type in variants: + self._plot(orientation=orientation, apply_u_test=utest, agg_type=agg_type, tag="_sqrt", season=_season) self._data_table = None self._n_boots = None @@ -1169,6 +1171,9 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover if orientation == "v": figsize, width = (size, 5), 0.4 elif orientation == "h": + if agg_type == "multi": + size *= np.sqrt(len(data_table.index.unique(self.ahead_dim))) + size = max(size, 8) figsize, width = (7, (1+.5*size)), 0.65 else: raise ValueError(f"orientation must be `v' or `h' but is: {orientation}") @@ -1185,7 +1190,7 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover sns.boxplot(data=data_table.stack(self.model_type_dim).reset_index(), ax=ax, whis=1.5, palette=color_palette, showmeans=True, meanprops={"markersize": 6, "markeredgecolor": "k", "markerfacecolor": "white"}, flierprops={"marker": "o", "markerfacecolor": "black", "markeredgecolor": "none", "markersize": 3}, - boxprops={'edgecolor': 'k'}, width=.9, orient=orientation, **xy, hue=self.ahead_dim) + boxprops={'edgecolor': 'k'}, width=.8, orient=orientation, **xy, hue=self.ahead_dim) _labels = [str(i) + self.sampling for i in data_table.index.levels[1].values] handles, _ = ax.get_legend_handles_labels() @@ -1206,17 +1211,18 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover axi.set_title(title) plt.setp(axi.lines, color='k') + error_label = self.error_measure if self.error_unit is None else f"{self.error_measure} (in {self.error_unit})" if agg_type == "panel": if orientation == "v": for axi in ax.axes.flatten(): axi.set_xlabel(None) axi.set_xticklabels(axi.get_xticklabels(), rotation=45) - ax.set_ylabels(f"{self.error_measure} (in {self.error_unit})") + ax.set_ylabels(error_label) loc = "upper left" else: for axi in ax.axes.flatten(): axi.set_ylabel(None) - ax.set_xlabels(f"{self.error_measure} (in {self.error_unit})") + ax.set_xlabels(error_label) loc = "upper right" text = f"n={n_boots}" if self.block_length is not None: @@ -1232,13 +1238,13 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover ax = self.set_significance_bars(asteriks, ax, data_table, orientation) ylims = list(ax.get_ylim()) ax.set_ylim([ylims[0], ylims[1]*1.025]) - ax.set_ylabel(f"{self.error_measure} (in {self.error_unit})") + ax.set_ylabel(error_label) ax.set_xticklabels(ax.get_xticklabels(), rotation=45) ax.set_xlabel(None) elif orientation == "h": if apply_u_test: ax = self.set_significance_bars(asteriks, ax, data_table, orientation) - ax.set_xlabel(f"{self.error_measure} (in {self.error_unit})") + ax.set_xlabel(error_label) xlims = list(ax.get_xlim()) ax.set_xlim([xlims[0], xlims[1] * 1.015]) ax.set_ylabel(None) diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 459a1928958a96238af89caa4241911554df416f..3aec881f8b9f8d55315309bd9e3c5e36e7396108 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -87,6 +87,7 @@ class PostProcessing(RunEnvironment): self._sampling = self.data_store.get("sampling") self.window_lead_time = extract_value(self.data_store.get("output_shape", "model")) self.skill_scores = None + self.errors = None self.feature_importance_skill_scores = None self.uncertainty_estimate = None self.uncertainty_estimate_seasons = {} @@ -140,6 +141,7 @@ class PostProcessing(RunEnvironment): self.report_error_metrics(errors) self.report_error_metrics({self.forecast_indicator: skill_score_climatological}) self.report_error_metrics({"skill_score": skill_score_competitive}) + self.store_errors(errors) # plotting self.plot() @@ -626,6 +628,25 @@ class PostProcessing(RunEnvironment): logging.error(f"Could not create plot PlotSampleUncertaintyFromBootstrap due to the following error: {e}" f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") + try: + if "PlotErrorMetrics" in plot_list and self.errors is not None: + error_metric_units = statistics.get_error_metrics_units("ppb") + error_metrics_name = statistics.get_error_metrics_long_name() + for error_metric in self.errors.keys(): + try: + PlotSampleUncertaintyFromBootstrap( + data=self.errors[error_metric], plot_folder=self.plot_path, model_type_dim=self.model_type_dim, + dim_name_boots="station", error_measure=error_metrics_name[error_metric], + error_unit=error_metric_units[error_metric], model_name=self.model_display_name, + model_indicator=self.model_display_name, sampling=self._sampling, apply_root=False, + plot_name=f"error_plot_{error_metric}") + except Exception as e: + logging.error(f"Could not create plot PlotErrorMetrics for {error_metric} due to the following " + f"error: {e}\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") + except Exception as e: + logging.error(f"Could not create plot PlotErrorMetrics due to the following error: {e}" + f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") + try: if "PlotStationMap" in plot_list: gens = [(self.train_data, {"marker": 5, "ms": 9}), @@ -1163,3 +1184,18 @@ class PostProcessing(RunEnvironment): file_name = f"error_report_{metric}_{model_type}.%s".replace(' ', '_').replace('/', '_') tables.save_to_tex(report_path, file_name % "tex", column_format=column_format, df=df) tables.save_to_md(report_path, file_name % "md", df=df) + + def store_errors(self, errors): + metric_collection = {} + error_dim = "error_metric" + station_dim = "station" + for model_type in errors.keys(): + station_collection = {} + for station, station_errors in errors[model_type].items(): + if station == "total": + continue + station_collection[station] = xr.Dataset(station_errors).to_array(error_dim) + metric_collection[model_type] = xr.Dataset(station_collection).to_array(station_dim) + coll = xr.Dataset(metric_collection).to_array(self.model_type_dim) + coll = coll.transpose(station_dim, self.ahead_dim, self.model_type_dim, error_dim) + self.errors = {k: coll.sel({error_dim: k}, drop=True) for k in coll.coords[error_dim].values}