diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py index 65b7fc0f6dd8a537fa8ded2dbc632e4668eaf0d3..4c5d5ca0f9a9c26d0b8085e89e4bbacf2525e33a 100644 --- a/src/plotting/postprocessing_plotting.py +++ b/src/plotting/postprocessing_plotting.py @@ -509,25 +509,44 @@ class PlotTimeSeries(RunEnvironment): return data.sel(type=["CNN", "orig"]) def _plot(self, plot_folder): - f, axes = plt.subplots(len(self._stations), sharex="all") + pdf_pages = self._save_pdf_pages(plot_folder) + start, end = self._get_time_range(self._load_data(self._stations[0])) color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", self._window_lead_time).as_hex() for pos, station in enumerate(self._stations): data = self._load_data(station) - axes[pos].plot(data.index+ np.timedelta64(1, "D"), data.sel(type="CNN", ahead=1).values, color=color_palette[0]) - for ahead in data.coords["ahead"].values: - plot_data = data.sel(type="CNN", ahead=ahead).drop(["type", "ahead"]).squeeze() - axes[pos].plot(plot_data.index + np.timedelta64(int(ahead), "D"), plot_data.values, color=color_palette[ahead]) - self._save(plot_folder) + f, axes = plt.subplots(end - start + 1, sharey=True, figsize=(40, 20)) + nan_list = [] + for i in range(end - start + 1): + data_year = data.sel(index=f"{start + i}") + orig_data = data_year.sel(type="orig", ahead=1).values + axes[i].plot(data_year.index + np.timedelta64(1, "D"), orig_data, color=color_palette[0], label="orig") + for ahead in data.coords["ahead"].values: + plot_data = data_year.sel(type="CNN", ahead=ahead).drop(["type", "ahead"]).squeeze() + axes[i].plot(plot_data.index + np.timedelta64(int(ahead), "D"), plot_data.values, color=color_palette[ahead], label=f"{ahead}d") + if np.isnan(orig_data).all(): + nan_list.append(i) + for i in reversed(nan_list): + f.delaxes(axes[i]) + + plt.suptitle(station) + plt.legend() + plt.tight_layout() + pdf_pages.savefig(dpi=500) + pdf_pages.close() + plt.close('all') @staticmethod - def _save(plot_folder): + def _get_time_range(data): + def f(x, f_x): + return pd.to_datetime(f_x(x.index.values)).year + return f(data, min), f(data, max) + + @staticmethod + def _save_pdf_pages(plot_folder): """ Standard save method to store plot locally. The name of this plot is static. :param plot_folder: path to save the plot """ - plot_name = os.path.join(os.path.abspath(plot_folder), 'test_timeseries_plot.pdf') + plot_name = os.path.join(os.path.abspath(plot_folder), 'timeseries_plot.pdf') logging.debug(f"... save plot to {plot_name}") - plt.savefig(plot_name, dpi=500) - plt.close('all') - - + return matplotlib.backends.backend_pdf.PdfPages(plot_name) diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 1db58b8f3c26e76e816ac5c8059d1fc9b2d20e8a..37651d0cf38f8d669018c5fbd5bcd113115905b1 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -65,17 +65,17 @@ class PostProcessing(RunEnvironment): path = self.data_store.get("forecast_path", "general") target_var = self.data_store.get("target_var", "general") - # plot_conditional_quantiles(self.test_data.stations, pred_name="CNN", ref_name="orig", - # forecast_path=path, plot_name_affix="cali-ref", plot_folder=self.plot_path) - # plot_conditional_quantiles(self.test_data.stations, pred_name="orig", ref_name="CNN", - # forecast_path=path, plot_name_affix="like-bas", plot_folder=self.plot_path) - # PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path) - # PlotMonthlySummary(self.test_data.stations, path, r"forecasts_%s_test.nc", target_var, - # plot_folder=self.plot_path) - # PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, model_setup="CNN") - # PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False, - # extra_name_tag="all_terms_", model_setup="CNN") - # PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, model_setup="CNN") + plot_conditional_quantiles(self.test_data.stations, pred_name="CNN", ref_name="orig", + forecast_path=path, plot_name_affix="cali-ref", plot_folder=self.plot_path) + plot_conditional_quantiles(self.test_data.stations, pred_name="orig", ref_name="CNN", + forecast_path=path, plot_name_affix="like-bas", plot_folder=self.plot_path) + PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path) + PlotMonthlySummary(self.test_data.stations, path, r"forecasts_%s_test.nc", target_var, + plot_folder=self.plot_path) + PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, model_setup="CNN") + PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False, + extra_name_tag="all_terms_", model_setup="CNN") + PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, model_setup="CNN") PlotTimeSeries(self.test_data.stations, path, r"forecasts_%s_test.nc", plot_folder=self.plot_path) def calculate_test_score(self):