diff --git a/mlair/plotting/abstract_plot_class.py b/mlair/plotting/abstract_plot_class.py index a26023bb6cb8772623479491ac8bcc731dd42223..377dc6b7abbbb693c9d18175983101e622063a70 100644 --- a/mlair/plotting/abstract_plot_class.py +++ b/mlair/plotting/abstract_plot_class.py @@ -92,11 +92,10 @@ class AbstractPlotClass: # pragma: no cover plt.rcParams.update(self.rc_params) @staticmethod - def _get_sampling(sampling): - if sampling == "daily": - return "D" - elif sampling == "hourly": - return "h" + def _get_sampling(sampling, pos=1): + sampling = (sampling, sampling) if isinstance(sampling, str) else sampling + sampling_letter = {"hourly": "H", "daily": "d"}.get(sampling[pos], "") + return sampling, sampling_letter @staticmethod def get_dataset_colors(): diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py index db2b3340e06545f988c81503df2aa27b655095bb..d33f3abb2e2cd04399a2d98f34bf2ed06acfcb6b 100644 --- a/mlair/plotting/data_insight_plotting.py +++ b/mlair/plotting/data_insight_plotting.py @@ -188,7 +188,7 @@ class PlotAvailability(AbstractPlotClass): # pragma: no cover super().__init__(plot_folder, "data_availability") self.time_dim = time_dimension self.window_dim = window_dimension - self.sampling = self._get_sampling(sampling) + self.sampling = self._get_sampling(sampling)[1] self.linewidth = None if self.sampling == 'h': self.linewidth = 0.001 @@ -321,7 +321,7 @@ class PlotAvailabilityHistogram(AbstractPlotClass): # pragma: no cover def _set_dims_from_datahandler(self, data_handler): self.temporal_dim = data_handler.id_class.time_dim self.target_dim = data_handler.id_class.target_dim - self.freq = self._get_sampling(data_handler.id_class.sampling) + self.freq = self._get_sampling(data_handler.id_class.sampling)[1] @property def allowed_plot_types(self): diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index 62ab3de3ca3647a41c61a0e8ac5ff94abe2ace47..d1a68896edf0b794b644bc325014efb4c7fe785f 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -748,19 +748,13 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass): # pragma: no cover data = data.assign_coords({self._x_name: new_boot_coords}) except NotImplementedError: pass - _, sampling_letter = self._get_target_sampling(sampling, 1) + _, sampling_letter = self._get_sampling(sampling, 1) self._labels = [str(i) + sampling_letter for i in data.coords[self._ahead_dim].values] if station_dim not in data.dims: data = data.expand_dims(station_dim) self._number_of_bootstraps = np.unique(data.coords[self._boot_dim].values).shape[0] return data.to_dataframe("data").reset_index(level=np.arange(len(data.dims)).tolist()).dropna() - @staticmethod - def _get_target_sampling(sampling, pos): - sampling = (sampling, sampling) if isinstance(sampling, str) else sampling - sampling_letter = {"hourly": "H", "daily": "d"}.get(sampling[pos], "") - return sampling, sampling_letter - def _return_vars_without_number_tag(self, values, split_by, keep, as_unique=False): arr = np.array([v.split(split_by) for v in values]) num = arr[:, 0] @@ -1462,12 +1456,6 @@ class PlotSeasonalMSEStack(AbstractPlotClass): season_share = xr_data.sel({season_dim: "total"}) * factor return season_share.sortby(season_share.sum([self.season_dim, self.ahead_dim])).transpose(*self.dim_order) - @staticmethod - def _get_target_sampling(sampling, pos): - sampling = (sampling, sampling) if isinstance(sampling, str) else sampling - sampling_letter = {"hourly": "H", "daily": "d"}.get(sampling[pos], "") - return sampling, sampling_letter - @staticmethod def _set_bar_label(ax): opts = {} @@ -1481,7 +1469,7 @@ class PlotSeasonalMSEStack(AbstractPlotClass): ax.bar_label(c, labels=_l, label_type='center') def _plot(self, dim, split_ahead=True, sampling="daily", orientation="vertical"): - _, sampling_letter = self._get_target_sampling(sampling, 1) + _, sampling_letter = self._get_sampling(sampling, 1) if split_ahead is False: self.plot_name = self.plot_name_orig + "_total_" + orientation data = self._data.mean(dim) @@ -1534,26 +1522,34 @@ class PlotErrorsOnMap(AbstractPlotClass): from mlair.plotting.data_insight_plotting import PlotStationMap def __init__(self, data_gen, errors, error_metric, plot_folder: str = ".", iter_dim: str = "station", - model_type_dim: str = "type", ahead_dim: str = "ahead"): + model_type_dim: str = "type", ahead_dim: str = "ahead", sampling: str = "daily"): super().__init__(plot_folder, f"map_plot_{error_metric}") - plot_path = os.path.join(self.plot_folder, f"{self.plot_name}.pdf") pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path) error_metric_units = helpers.statistics.get_error_metrics_units("ppb")[error_metric] error_metric_name = helpers.statistics.get_error_metrics_long_name()[error_metric] + self.sampling = self._get_sampling(sampling, 1)[1] coords = self._extract_coords(data_gen) - error_data = {} - for model_type in errors.coords[model_type_dim].values: - error_data[model_type] = self._prepare_data(errors, model_type_dim, model_type, ahead_dim, error_metric) - - limits = self._calculate_limits(error_data) - - for model_type, error in error_data.items(): - plot_data = pd.concat([coords, error], axis=1) - self.plot(plot_data, error_metric, error_metric_name, error_metric_units, model_type, limits) - pdf_pages.savefig() + for split_ahead in [False, True]: + error_data = {} + for model_type in errors.coords[model_type_dim].values: + error_data[model_type] = self._prepare_data(errors, model_type_dim, model_type, ahead_dim, error_metric, + split_ahead=split_ahead) + limits = self._calculate_limits(error_data) + for model_type, error in error_data.items(): + if split_ahead is True: + for ahead in error.index.unique(1).to_list(): + error_ahead = error.query(f"{ahead_dim} == {ahead}").droplevel(1) + plot_data = pd.concat([coords, error_ahead], axis=1) + self.plot(plot_data, error_metric, error_metric_name, error_metric_units, model_type, limits, + ahead=ahead) + pdf_pages.savefig() + else: + plot_data = pd.concat([coords, error], axis=1) + self.plot(plot_data, error_metric, error_metric_name, error_metric_units, model_type, limits) + pdf_pages.savefig() pdf_pages.close() plt.close('all') @@ -1571,7 +1567,7 @@ class PlotErrorsOnMap(AbstractPlotClass): vmin = relative_round(bound_lims[0], 2, floor=True) vmax = relative_round(bound_lims[1], 2, ceil=True) interval = relative_round((vmax - vmin) / ncolors, 1, ceil=True) - bounds = np.arange(vmin, vmax, interval) + bounds = np.sort(np.arange(vmax, vmin, -interval)) return bounds @staticmethod @@ -1588,7 +1584,7 @@ class PlotErrorsOnMap(AbstractPlotClass): cmap = sns.color_palette("magma_r", as_cmap=True) return cmap - def plot(self, plot_data, error_metric, error_long_name, error_units, model_type, limits): + def plot(self, plot_data, error_metric, error_long_name, error_units, model_type, limits, ahead=None): import cartopy.crs as ccrs from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER fig = plt.figure(figsize=(10, 5)) @@ -1606,7 +1602,8 @@ class PlotErrorsOnMap(AbstractPlotClass): cbar_label = f"{error_long_name} (in {error_units})" if error_units is not None else error_long_name plt.colorbar(cb, label=cbar_label) self._adjust_extent(ax) - plt.title(model_type) + title = model_type if ahead is None else f"{model_type} ({ahead}{self.sampling})" + plt.title(title) plt.tight_layout() @staticmethod diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 0ff09e92d90a31dde6ac5fe01d75776bce6ac4e5..0fb14f55cf0d6270a8c26937b955e09758567101 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -735,7 +735,7 @@ class PostProcessing(RunEnvironment): for error_metric in self.errors.keys(): try: PlotErrorsOnMap(self.test_data, self.errors[error_metric], error_metric, - plot_folder=self.plot_path) + plot_folder=self.plot_path, sampling=self._sampling) except Exception as e: logging.error(f"Could not create plot PlotErrorsOnMap for {error_metric} due to the following " f"error: {e}\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")