diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index fc3aa055679b2ad1f03719c6300b59f7ca2371c2..62ab3de3ca3647a41c61a0e8ac5ff94abe2ace47 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -11,6 +11,7 @@ import itertools import matplotlib import matplotlib.pyplot as plt +import matplotlib.colors as colors import numpy as np import pandas as pd import seaborn as sns @@ -23,6 +24,7 @@ from scipy.stats import mannwhitneyu from mlair import helpers from mlair.data_handler.iterator import DataCollection from mlair.helpers import TimeTrackingWrapper +from mlair.helpers.helpers import relative_round from mlair.plotting.abstract_plot_class import AbstractPlotClass from mlair.helpers.statistics import mann_whitney_u_test, represent_p_values_as_asteriks @@ -132,8 +134,7 @@ class PlotMonthlySummary(AbstractPlotClass): # pragma: no cover """ data = self._data.to_dataset(name='values').to_dask_dataframe() logging.debug("... start plotting") - color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", - self._window_lead_time).as_hex() + color_palette = [colors.cnames["green"]] + sns.color_palette("Blues_d", self._window_lead_time).as_hex() ax = sns.boxplot(x='index', y='values', hue='ahead', data=data.compute(), whis=1.5, palette=color_palette, flierprops={'marker': '.', 'markersize': 1}, showmeans=True, meanprops={'markersize': 1, 'markeredgecolor': 'k'}) @@ -1044,7 +1045,7 @@ class PlotTimeSeries: # pragma: no cover def _plot_obs(self, ax, data): ahead = 1 obs_data = data.sel(type="obs", ahead=ahead).shift(index=ahead) - ax.plot(obs_data, color=matplotlib.colors.cnames["green"], label="obs") + ax.plot(obs_data, color=colors.cnames["green"], label="obs") @staticmethod def _get_time_range(data): @@ -1528,9 +1529,237 @@ class PlotSeasonalMSEStack(AbstractPlotClass): fig.tight_layout(rect=[0, 0, 1, 0.95]) -if __name__ == "__main__": - stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] - path = "../../testrun_network/forecasts" - plt_path = "../../" +@TimeTrackingWrapper +class PlotErrorsOnMap(AbstractPlotClass): + from mlair.plotting.data_insight_plotting import PlotStationMap + + def __init__(self, data_gen, errors, error_metric, plot_folder: str = ".", iter_dim: str = "station", + model_type_dim: str = "type", ahead_dim: str = "ahead"): + + super().__init__(plot_folder, f"map_plot_{error_metric}") + + plot_path = os.path.join(self.plot_folder, f"{self.plot_name}.pdf") + pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path) + error_metric_units = helpers.statistics.get_error_metrics_units("ppb")[error_metric] + error_metric_name = helpers.statistics.get_error_metrics_long_name()[error_metric] + + coords = self._extract_coords(data_gen) + error_data = {} + for model_type in errors.coords[model_type_dim].values: + error_data[model_type] = self._prepare_data(errors, model_type_dim, model_type, ahead_dim, error_metric) + + limits = self._calculate_limits(error_data) + + for model_type, error in error_data.items(): + plot_data = pd.concat([coords, error], axis=1) + self.plot(plot_data, error_metric, error_metric_name, error_metric_units, model_type, limits) + pdf_pages.savefig() + pdf_pages.close() + plt.close('all') + + @staticmethod + def _calculate_limits(data): + vmin, vmax = np.inf, -np.inf + for v in data.values(): + vmin = min(vmin, v.min().values) + vmax = max(vmax, v.max().values) + return relative_round(float(vmin), 2, floor=True), relative_round(float(vmax), 2, ceil=True) - con_quan_cls = PlotConditionalQuantiles(stations, path, plt_path) + @staticmethod + def _set_bounds(limits, ncolors, error_metric): + bound_lims = {"ioa": [0, 1], "mnmb": [-2, 2]}.get(error_metric, limits) + vmin = relative_round(bound_lims[0], 2, floor=True) + vmax = relative_round(bound_lims[1], 2, ceil=True) + interval = relative_round((vmax - vmin) / ncolors, 1, ceil=True) + bounds = np.arange(vmin, vmax, interval) + return bounds + + @staticmethod + def _get_colorpalette(error_metric): + # cmap = matplotlib.cm.coolwarm + # cmap = sns.color_palette("magma_r", as_cmap=True) + # cmap="Spectral_r", cmap="RdYlBu_r", cmap="coolwarm", + # cmap = sns.cubehelix_palette(8, start=2, rot=0, dark=0, light=.95, as_cmap=True) + if error_metric == "mnmb": + cmap = sns.mpl_palette("coolwarm", as_cmap=True) + elif error_metric == "ioa": + cmap = sns.mpl_palette("coolwarm_r", as_cmap=True) + else: + cmap = sns.color_palette("magma_r", as_cmap=True) + return cmap + + def plot(self, plot_data, error_metric, error_long_name, error_units, model_type, limits): + import cartopy.crs as ccrs + from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER + fig = plt.figure(figsize=(10, 5)) + ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree()) + _gl = ax.gridlines(xlocs=range(-180, 180, 5), ylocs=range(-90, 90, 2), draw_labels=True) + _gl.xformatter = LONGITUDE_FORMATTER + _gl.yformatter = LATITUDE_FORMATTER + self._draw_background(ax) + cmap = self._get_colorpalette(error_metric) + ncolors = 20 + bounds = self._set_bounds(limits, ncolors, error_metric) + norm = colors.BoundaryNorm(bounds, cmap.N, extend='both') + cb = ax.scatter(plot_data["lon"], plot_data["lat"], c=plot_data[error_metric], marker='o', s=50, + transform=ccrs.PlateCarree(), zorder=2, cmap=cmap, norm=norm) + cbar_label = f"{error_long_name} (in {error_units})" if error_units is not None else error_long_name + plt.colorbar(cb, label=cbar_label) + self._adjust_extent(ax) + plt.title(model_type) + plt.tight_layout() + + @staticmethod + def _adjust_extent(ax): + import cartopy.crs as ccrs + + def diff(arr): + return arr[1] - arr[0], arr[3] - arr[2] + + def find_ratio(delta, reference=5): + return min(max(abs(reference / delta[0]), abs(reference / delta[1])), 5) + + extent = ax.get_extent(crs=ccrs.PlateCarree()) + ratio = find_ratio(diff(extent)) + new_extent = extent + np.array([-1, 1, -1, 1]) * ratio + ax.set_extent(new_extent, crs=ccrs.PlateCarree()) + + @staticmethod + def _extract_coords(gen): + coll = [] + for station in gen: + coords = station.get_coordinates() + coll.append((str(station), coords["lon"], coords["lat"])) + return pd.DataFrame(coll, columns=["station", "lon", "lat"]).set_index("station") + + @staticmethod + def _prepare_data(errors, model_type_dim, model_type, ahead_dim, error_metric, split_ahead=False): + e = errors.sel({model_type_dim: model_type}, drop=True) + if split_ahead is False: + e = e.mean(ahead_dim) + return e.to_dataframe(error_metric) + + @staticmethod + def _draw_background(ax): + """Draw coastline, lakes, ocean, rivers and country borders as background on the map.""" + + import cartopy.feature as cfeature + + ax.add_feature(cfeature.LAND.with_scale("50m")) + ax.natural_earth_shp(resolution='50m') + ax.add_feature(cfeature.COASTLINE.with_scale("50m"), edgecolor='black') + ax.add_feature(cfeature.LAKES.with_scale("50m")) + ax.add_feature(cfeature.OCEAN.with_scale("50m")) + ax.add_feature(cfeature.RIVERS.with_scale("50m")) + ax.add_feature(cfeature.BORDERS.with_scale("50m"), facecolor='none', edgecolor='black') + + + + + + + def _plot_individual(self): + import cartopy.feature as cfeature + import cartopy.crs as ccrs + from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER + from mpl_toolkits.axes_grid1 import make_axes_locatable + + for competitor in self.reference_models: + file_name = os.path.join(self.skill_score_report_path, + f"error_report_skill_score_{self.model_name}_-_{competitor}.csv" + ) + + plot_path = os.path.join(os.path.abspath(self.plot_folder), + f"{self.plot_name}_{self.model_name}_-_{competitor}.pdf") + pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path) + + for i, lead_name in enumerate(df.columns[:-2]): # last two are lat lon + fig = plt.figure() + self._ax.scatter(df.lon.values, df.lat.values, c=df[lead_name], + transform=ccrs.PlateCarree(), + norm=norm, cmap=cmap) + self._gl = self._ax.gridlines(xlocs=range(0, 21, 5), ylocs=range(44, 59, 2), draw_labels=True) + self._gl.xformatter = LONGITUDE_FORMATTER + self._gl.yformatter = LATITUDE_FORMATTER + label = f"Skill Score: {lead_name.replace('-', 'vs.').replace('(t+', ' (').replace(')', 'd)')}" + self._cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), + orientation='horizontal', ticks=ticks, + label=label, + # cax=cax + ) + + # close all open figures / plots + pdf_pages.savefig() + pdf_pages.close() + plt.close('all') + + def _plot(self, ncol: int = 2): + import cartopy.feature as cfeature + import cartopy.crs as ccrs + from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER + import string + base_plot_name = self.plot_name + for competitor in self.reference_models: + file_name = os.path.join(self.skill_score_report_path, + f"error_report_skill_score_{self.model_name}_-_{competitor}.csv" + ) + + self.plot_name = f"{base_plot_name}_{self.model_name}_-_{competitor}" + df = self.open_data(file_name) + + nrow = int(np.ceil(len(df.columns[:-2])/ncol)) + bounds = np.linspace(-1, 1, 100) + cmap = mpl.cm.coolwarm + norm = colors.BoundaryNorm(bounds, cmap.N, extend='both') + ticks = np.arange(norm.vmin, norm.vmax + .2, .2) + fig, self._axes = plt.subplots(nrows=nrow, ncols=ncol, subplot_kw={'projection': ccrs.PlateCarree()}) + for i, ax in enumerate(self._axes.reshape(-1)): # last two are lat lon + + sub_name = f"({string.ascii_lowercase[i]})" + lead_name = df.columns[i] + ax.add_feature(cfeature.LAND.with_scale("50m")) + ax.add_feature(cfeature.COASTLINE.with_scale("50m"), edgecolor='black') + ax.add_feature(cfeature.OCEAN.with_scale("50m")) + ax.add_feature(cfeature.RIVERS.with_scale("50m")) + ax.add_feature(cfeature.BORDERS.with_scale("50m"), facecolor='none', edgecolor='black') + ax.scatter(df.lon.values, df.lat.values, c=df[lead_name], + marker='.', + transform=ccrs.PlateCarree(), + norm=norm, cmap=cmap) + gl = ax.gridlines(xlocs=range(0, 21, 5), ylocs=range(44, 59, 2), draw_labels=True) + gl.xformatter = LONGITUDE_FORMATTER + gl.yformatter = LATITUDE_FORMATTER + gl.top_labels = [] + gl.right_labels = [] + ax.text(0.01, 1.09, f'{sub_name} {lead_name.split("+")[1][:-1]}d', + verticalalignment='top', horizontalalignment='left', + transform=ax.transAxes, + color='black', + ) + label = f"Skill Score: {lead_name.replace('-', 'vs.').split('(')[0]}" + + fig.subplots_adjust(bottom=0.18) + cax = fig.add_axes([0.15, 0.1, 0.7, 0.02]) + self._cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), + orientation='horizontal', + ticks=ticks, + label=label, + cax=cax + ) + + fig.subplots_adjust(wspace=.001, hspace=.2) + self._save(bbox_inches="tight") + plt.close('all') + + + @staticmethod + def get_coords_from_index(name_string: str) -> List[float]: + """ + + :param name_string: + :type name_string: + :return: List of coords [lat, lon] + :rtype: List + """ + res = [float(frac.replace("_", ".")) for frac in name_string.split(sep="__")[1:]] + return res diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index f8e2677862e23c56fe38b13f7c6dfb78c6a7f964..0ff09e92d90a31dde6ac5fe01d75776bce6ac4e5 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -24,7 +24,8 @@ from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel from mlair.model_modules import AbstractModelClass from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \ PlotCompetitiveSkillScore, PlotTimeSeries, PlotFeatureImportanceSkillScore, PlotConditionalQuantiles, \ - PlotSeparationOfScales, PlotSampleUncertaintyFromBootstrap, PlotTimeEvolutionMetric, PlotSeasonalMSEStack + PlotSeparationOfScales, PlotSampleUncertaintyFromBootstrap, PlotTimeEvolutionMetric, PlotSeasonalMSEStack, \ + PlotErrorsOnMap from mlair.plotting.data_insight_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram, \ PlotPeriodogram, PlotDataHistogram from mlair.run_modules.run_environment import RunEnvironment @@ -729,6 +730,23 @@ class PostProcessing(RunEnvironment): logging.error(f"Could not create plot PlotSeasonalMSEStack due to the following error: {e}" f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") + try: + if "PlotErrorsOnMap" in plot_list and self.errors is not None: + for error_metric in self.errors.keys(): + try: + PlotErrorsOnMap(self.test_data, self.errors[error_metric], error_metric, + plot_folder=self.plot_path) + except Exception as e: + logging.error(f"Could not create plot PlotErrorsOnMap for {error_metric} due to the following " + f"error: {e}\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") + except Exception as e: + logging.error(f"Could not create plot PlotErrorsOnMap due to the following error: {e}" + f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") + + + + + @TimeTrackingWrapper def calculate_test_score(self): """Evaluate test score of model and save locally."""