diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index b16028fad56c128cf7431effbab7b25687d485fe..3b9b563426a80816f7cf1ea9e114a8395d9fbba0 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -108,6 +108,7 @@ class PostProcessing(RunEnvironment): skill_score_competitive, skill_score_climatological, errors = self.calculate_error_metrics() self.skill_scores = (skill_score_competitive, skill_score_climatological) self.report_error_metrics(errors) + self.report_error_metrics(skill_score_climatological) # plotting self.plot() @@ -716,6 +717,10 @@ class PostProcessing(RunEnvironment): path_config.check_path_and_create(report_path) metric_collection = {} for station, station_errors in errors.items(): + if isinstance(station_errors, xr.DataArray): + dim = station_errors.dims[0] + sel_index = [sel for sel in station_errors.coords[dim] if "CASE" in str(sel)] + station_errors = {str(i.values): station_errors.sel(**{dim: i}) for i in sel_index} for metric, vals in station_errors.items(): if metric == "n": continue @@ -726,7 +731,9 @@ class PostProcessing(RunEnvironment): metric_collection[metric] = mc for metric, error_df in metric_collection.items(): df = error_df.sort_index() - df.reindex(df.index.drop(["total"]).to_list() + ["total"], ) + if "total" in df.index: + df.reindex(df.index.drop(["total"]).to_list() + ["total"], ) column_format = tables.create_column_format_for_tex(df) - tables.save_to_tex(report_path, f"error_report_{metric}.tex", column_format=column_format, df=df) - tables.save_to_md(report_path, f"error_report_{metric}.md", df=df) + file_name = f"error_report_{metric}.%s".replace(' ', '_') + tables.save_to_tex(report_path, file_name % "tex", column_format=column_format, df=df) + tables.save_to_md(report_path, file_name % "md", df=df)