diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index 47daaf084d1170385ffa5869385331f000a8bc40..6373d3b61a4aa7ffb4ce782365b9ba1234761f0e 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -20,6 +20,7 @@ from matplotlib.backends.backend_pdf import PdfPages from matplotlib.offsetbox import AnchoredText import matplotlib.dates as mdates from scipy.stats import mannwhitneyu +import datetime as dt from mlair import helpers from mlair.data_handler.iterator import DataCollection @@ -177,16 +178,26 @@ class PlotConditionalQuantiles(AbstractPlotClass): # pragma: no cover warnings.filterwarnings("ignore", message="Attempted to set non-positive bottom ylim on a log-scaled axis.") def __init__(self, stations: List, data_pred_path: str, plot_folder: str = ".", plot_per_seasons=True, - rolling_window: int = 3, forecast_indicator: str = "nn", obs_indicator: str = "obs", **kwargs): + rolling_window: int = 3, forecast_indicator: str = "nn", obs_indicator: str = "obs", + competitors=None, model_type_dim: str = "type", index_dim: str = "index", ahead_dim: str = "ahead", + competitor_path: str = None, sampling: str = "daily", model_name: str = "nn", **kwargs): """Initialise.""" super().__init__(plot_folder, "conditional_quantiles") self._data_pred_path = data_pred_path self._stations = stations self._rolling_window = rolling_window self._forecast_indicator = forecast_indicator + self.model_type_dim = model_type_dim + self.index_dim = index_dim + self.ahead_dim = ahead_dim + self.iter_dim = "station" + self.model_name = model_name self._obs_name = obs_indicator + self._sampling = sampling self._opts = self._get_opts(kwargs) self._seasons = ['DJF', 'MAM', 'JJA', 'SON'] if plot_per_seasons is True else "" + self.competitors = self._correct_persi_name(competitors or []) + self.competitor_path = competitor_path or data_pred_path self._data = self._load_data() self._bins = self._get_bins_from_rage_of_data() self._plot() @@ -211,11 +222,95 @@ class PlotConditionalQuantiles(AbstractPlotClass): # pragma: no cover for station in self._stations: file = os.path.join(self._data_pred_path, f"forecasts_{station}_test.nc") data_tmp = xr.open_dataarray(file) - data_collector.append(data_tmp.loc[:, :, [self._forecast_indicator, - self._obs_name]].assign_coords(station=station)) - res = xr.concat(data_collector, dim='station').transpose('index', 'type', 'ahead', 'station') + start = data_tmp.coords[self.index_dim].min().values + end = data_tmp.coords[self.index_dim].max().values + competitor = self.load_competitors(station, start, end) + combined = self._combine_forecasts(data_tmp, competitor, dim=self.model_type_dim) + sel = combined.sel({self.model_type_dim: [self._forecast_indicator, self._obs_name, *self.competitors]}) + data_collector.append(sel.assign_coords({self.iter_dim: station})) + res = xr.concat(data_collector, dim=self.iter_dim).transpose(self.index_dim, self.model_type_dim, + self.ahead_dim, self.iter_dim) + return res + + def _combine_forecasts(self, forecast, competitor, dim=None): + """ + Combine forecast and competitor if both are xarray. If competitor is None, this returns forecasts and vise + versa. + """ + if dim is None: + dim = self.model_type_dim + try: + return xr.concat([forecast, competitor], dim=dim) + except (TypeError, AttributeError): + return forecast if competitor is None else competitor + + def load_competitors(self, station_name: str, start, end) -> xr.DataArray: + """ + Load all requested and available competitors for a given station. Forecasts must be available in the competitor + path like `<competitor_path>/<target_var>/forecasts_<station_name>_test.nc`. The naming style is equal for all + forecasts of MLAir, so that forecasts of a different experiment can easily be copied into the competitor path + without any change. + + :param station_name: station indicator to load competitors for + + :return: a single xarray with all competing forecasts + """ + competing_predictions = [] + for competitor_name in self.competitors: + try: + prediction = self._create_competitor_forecast(station_name, competitor_name, start, end) + competing_predictions.append(prediction) + except (FileNotFoundError, KeyError): + logging.debug(f"No competitor found for combination '{station_name}' and '{competitor_name}'.") + continue + return xr.concat(competing_predictions, self.model_type_dim) if len(competing_predictions) > 0 else None + + @staticmethod + def create_full_time_dim(data, dim, sampling, start, end): + """Ensure time dimension to be equidistant. Sometimes dates if missing values have been dropped.""" + start_data = data.coords[dim].values[0] + freq = {"daily": "1D", "hourly": "1H"}.get(sampling) + _ind = pd.date_range(start, end, freq=freq) # two steps required to include all hours of end interval + datetime_index = pd.DataFrame(index=pd.date_range(_ind.min(), _ind.max() + dt.timedelta(days=1), + closed="left", freq=freq)) + t = data.sel({dim: start_data}, drop=True) + res = xr.DataArray(coords=[datetime_index.index, *[t.coords[c] for c in t.coords]], dims=[dim, *t.coords]) + res = res.transpose(*data.dims) + if data.shape == res.shape: + res.loc[data.coords] = data + else: + _d = data.sel({dim: slice(start, end)}) + res.loc[_d.coords] = _d return res + def _create_competitor_forecast(self, station_name: str, competitor_name: str, start, end) -> xr.DataArray: + """ + Load and format the competing forecast of a distinct model indicated by `competitor_name` for a distinct station + indicated by `station_name`. The name of the competitor is set in the `type` axis as indicator. This method will + raise either a `FileNotFoundError` or `KeyError` if no competitor could be found for the given station. Either + there is no file provided in the expected path or no forecast for given `competitor_name` in the forecast file. + Forecast is trimmed on interval start and end of test subset. + + :param station_name: name of the station to load data for + :param competitor_name: name of the model + :return: the forecast of the given competitor + """ + path = os.path.join(self.competitor_path, competitor_name) + file = os.path.join(path, f"forecasts_{station_name}_test.nc") + with xr.open_dataarray(file) as da: + data = da.load() + if self._forecast_indicator in data.coords[self.model_type_dim]: + forecast = data.sel({self.model_type_dim: [self._forecast_indicator]}) + forecast.coords[self.model_type_dim] = [competitor_name] + else: + forecast = data.sel({self.model_type_dim: [competitor_name]}) + # limit forecast to time range of test subset + return self.create_full_time_dim(forecast, self.index_dim, self._sampling, start, end) + + @staticmethod + def _correct_persi_name(competitors): + return ["persi" if x == "Persistence" else x for x in competitors] + def _segment_data(self, data: xr.DataArray, x_model: str) -> xr.DataArray: """ Segment data into bins. @@ -227,12 +322,13 @@ class PlotConditionalQuantiles(AbstractPlotClass): # pragma: no cover """ logging.debug("... segment data") # combine index and station to multi index - data = data.stack(z=['index', 'station']) + data = data.stack(z=[self.index_dim, self.iter_dim]) # replace multi index by simple position index (order is not relevant anymore) data.coords['z'] = range(len(data.coords['z'])) # segment data of x_model into bins - data.loc[x_model, ...] = data.loc[x_model, ...].to_pandas().T.apply(pd.cut, bins=self._bins, - labels=self._bins[1:]).T.values + data_sel = data.sel({self.model_type_dim: x_model}) + data.loc[{self.model_type_dim: x_model}] = data_sel.to_pandas().T.apply(pd.cut, bins=self._bins, + labels=self._bins[1:]).T.values return data @staticmethod @@ -245,7 +341,7 @@ class PlotConditionalQuantiles(AbstractPlotClass): # pragma: no cover :return: tuple with y and x labels """ - names = (f"forecast concentration (in {data_unit})", f"observed concentration (in {data_unit})") + names = (f"forecasted concentration (in {data_unit})", f"observed concentration (in {data_unit})") if plot_type == "obs": return names else: @@ -273,9 +369,9 @@ class PlotConditionalQuantiles(AbstractPlotClass): # pragma: no cover # create empty xarray with dims: time steps ahead, quantiles, bin index (numbers create in previous step) quantile_panel = xr.DataArray( np.full([data.ahead.shape[0], len(self._opts["q"]), self._bins[1:].shape[0]], np.nan), - coords=[data.ahead, self._opts["q"], self._bins[1:]], dims=['ahead', 'quantiles', 'categories']) + coords=[data.ahead, self._opts["q"], self._bins[1:]], dims=[self.ahead_dim, 'quantiles', 'categories']) # ensure that the coordinates are in the right order - quantile_panel = quantile_panel.transpose('ahead', 'quantiles', 'categories') + quantile_panel = quantile_panel.transpose(self.ahead_dim, 'quantiles', 'categories') # calculate for each bin of the pred_name data the quantiles of the ref_name data for bin in self._bins[1:]: mask = (data.loc[x_model, ...] == bin) @@ -309,28 +405,34 @@ class PlotConditionalQuantiles(AbstractPlotClass): # pragma: no cover def _plot(self): """Start plotting routines: overall plot and seasonal (if enabled).""" - logging.info( - f"start plotting {self.__class__.__name__}, scheduled number of plots: {(len(self._seasons) + 1) * 2}") - + logging.info(f"start plotting {self.__class__.__name__}, scheduled number of plots: " + f"{(len(self._seasons) + 1) * (len(self.competitors) + 1) * 2}") if len(self._seasons) > 0: self._plot_seasons() self._plot_all() def _plot_seasons(self): """Create seasonal plots.""" - for season in self._seasons: - self._plot_base(data=self._data.where(self._data['index.season'] == season), x_model=self._forecast_indicator, - y_model=self._obs_name, plot_name_affix="cali-ref", season=season) - self._plot_base(data=self._data.where(self._data['index.season'] == season), x_model=self._obs_name, - y_model=self._forecast_indicator, plot_name_affix="like-base", season=season) + for model in [self._forecast_indicator, *self.competitors]: + for season in self._seasons: + self._plot_base(data=self._data.where(self._data[f"{self.index_dim}.season"] == season), + x_model=model, y_model=self._obs_name, plot_name_affix="cali-ref", + season=season, model_name=model) + self._plot_base(data=self._data.where(self._data[f"{self.index_dim}.season"] == season), + x_model=self._obs_name, y_model=model, plot_name_affix="like-base", + season=season, model_name=model) def _plot_all(self): """Plot overall conditional quantiles on full data.""" - self._plot_base(data=self._data, x_model=self._forecast_indicator, y_model=self._obs_name, plot_name_affix="cali-ref") - self._plot_base(data=self._data, x_model=self._obs_name, y_model=self._forecast_indicator, plot_name_affix="like-base") + for model in [self._forecast_indicator, *self.competitors]: + self._plot_base(data=self._data, x_model=model, y_model=self._obs_name, + plot_name_affix="cali-ref", model_name=model) + self._plot_base(data=self._data, x_model=self._obs_name, y_model=model, + plot_name_affix="like-base", model_name=model) @TimeTrackingWrapper - def _plot_base(self, data: xr.DataArray, x_model: str, y_model: str, plot_name_affix: str, season: str = ""): + def _plot_base(self, data: xr.DataArray, x_model: str, y_model: str, plot_name_affix: str, season: str = "", + model_name: str = ""): """ Create conditional quantile plots. @@ -342,7 +444,8 @@ class PlotConditionalQuantiles(AbstractPlotClass): # pragma: no cover """ segmented_data, quantile_panel = self._prepare_plots(data, x_model, y_model) ylabel, xlabel = self._labels(x_model, self._opts["data_unit"]) - plot_name = f"{self.plot_name}{self.add_affix(season)}{self.add_affix(plot_name_affix)}_plot.pdf" + plot_name = f"{self.plot_name}{self.add_affix(season)}{self.add_affix(plot_name_affix)}" \ + f"{self.add_affix(model_name)}.pdf" plot_path = os.path.join(os.path.abspath(self.plot_folder), plot_name) pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path) logging.debug(f"... plot path is {plot_path}") @@ -380,7 +483,9 @@ class PlotConditionalQuantiles(AbstractPlotClass): # pragma: no cover ax2.set_ylabel(' sample size', fontsize='x-large') ax2.tick_params(axis='y', which='major', labelsize=15) # set title and save current figure - title = f"{d.values} time step(s) ahead{f' ({season})' if len(season) > 0 else ''}" + sampling_letter = {"daily": "D", "hourly": "H"}.get(self._sampling) + model_name = self.model_name if model_name == self._forecast_indicator else model_name + title = f"{model_name} ({sampling_letter}{d.values}{f', {season}' if len(season) > 0 else ''})" plt.title(title) pdf_pages.savefig() # close all open figures / plots diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 0fb14f55cf0d6270a8c26937b955e09758567101..e4cc34f66db7426876209fece0dc902341ac6582 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -583,7 +583,10 @@ class PostProcessing(RunEnvironment): if "PlotConditionalQuantiles" in plot_list: PlotConditionalQuantiles(self.test_data.keys(), data_pred_path=self.forecast_path, plot_folder=self.plot_path, forecast_indicator=self.forecast_indicator, - obs_indicator=self.observation_indicator) + obs_indicator=self.observation_indicator, competitors=self.competitors, + model_type_dim=self.model_type_dim, index_dim=self.index_dim, + ahead_dim=self.ahead_dim, competitor_path=self.competitor_path, + model_name=self.model_display_name) except Exception as e: logging.error(f"Could not create plot PlotConditionalQuantiles due to the following error:" f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")