diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py index 0e63125d921029d28a672a5e5e8d0ecd2995d050..01a565e3db284e56bf0b8c94420b71268fd21a80 100644 --- a/src/plotting/postprocessing_plotting.py +++ b/src/plotting/postprocessing_plotting.py @@ -481,14 +481,23 @@ class PlotCompetitiveSkillScore(RunEnvironment): class PlotTimeSeries(RunEnvironment): - def __init__(self, stations: List, data_path: str, name: str, window_lead_time: int = None, plot_folder: str = "."): + def __init__(self, stations: List, data_path: str, name: str, window_lead_time: int = None, plot_folder: str = ".", + sampling="daily"): super().__init__() self._data_path = data_path self._data_name = name self._stations = stations self._window_lead_time = self._get_window_lead_time(window_lead_time) + self._sampling = self._get_sampling(sampling) self._plot(plot_folder) + @staticmethod + def _get_sampling(sampling): + if sampling == "daily": + return "D" + elif sampling == "hourly": + return "h" + def _get_window_lead_time(self, window_lead_time: int): """ Extract the lead time from data and arguments. If window_lead_time is not given, extract this information from @@ -509,32 +518,67 @@ class PlotTimeSeries(RunEnvironment): return data.sel(type=["CNN", "orig"]) def _plot(self, plot_folder): - pdf_pages = self._save_pdf_pages(plot_folder) + pdf_pages = self._create_pdf_pages(plot_folder) start, end = self._get_time_range(self._load_data(self._stations[0])) - color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", self._window_lead_time).as_hex() for pos, station in enumerate(self._stations): data = self._load_data(station) - f, axes = plt.subplots(end - start + 1, sharey=True, figsize=(40, 20)) + fig, axes, factor = self._create_subplots(start, end) nan_list = [] - for i in range(end - start + 1): - data_year = data.sel(index=f"{start + i}") - orig_data = data_year.sel(type="orig", ahead=1).values - axes[i].plot(data_year.index + np.timedelta64(1, "D"), orig_data, color=color_palette[0], label="orig") - for ahead in data.coords["ahead"].values: - plot_data = data_year.sel(type="CNN", ahead=ahead).drop(["type", "ahead"]).squeeze() - axes[i].plot(plot_data.index + np.timedelta64(int(ahead), "D"), plot_data.values, color=color_palette[ahead], label=f"{ahead}d") - if np.isnan(orig_data).all(): - nan_list.append(i) - for i in reversed(nan_list): - f.delaxes(axes[i]) - - plt.suptitle(station) - plt.legend() - plt.tight_layout() - pdf_pages.savefig(dpi=500) + for i_year in range(end - start + 1): + data_year = data.sel(index=f"{start + i_year}") + for i_half_of_year in range(factor): + pos = 2 * i_year + i_half_of_year + plot_data = self._create_plot_data(data_year, factor, i_half_of_year) + self._plot_orig(axes[pos], plot_data) + self._plot_ahead(axes[pos], plot_data) + if np.isnan(plot_data.values).all(): + nan_list.append(pos) + self._clean_up_axes(nan_list, axes, fig) + self._save_page(station, pdf_pages) pdf_pages.close() plt.close('all') + @staticmethod + def _clean_up_axes(nan_list, axes, fig): + for i in reversed(nan_list): + fig.delaxes(axes[i]) + + @staticmethod + def _save_page(station, pdf_pages): + plt.suptitle(station) + plt.legend() + plt.tight_layout() + pdf_pages.savefig(dpi=500) + + @staticmethod + def _create_plot_data(data, factor, running_index): + if factor > 1: + if running_index == 0: + data = data.where(data["index.month"] < 7) + else: + data = data.where(data["index.month"] >= 7) + return data + + def _create_subplots(self, start, end): + factor = 1 + if self._sampling == "h": + factor = 2 + f, ax = plt.subplots((end - start + 1) * factor, sharey=True, figsize=(50, 30)) + return f, ax, factor + + def _plot_ahead(self, ax, data): + color = sns.color_palette("Blues_d", self._window_lead_time).as_hex() + for ahead in data.coords["ahead"].values: + plot_data = data.sel(type="CNN", ahead=ahead).drop(["type", "ahead"]).squeeze() + index = plot_data.index + np.timedelta64(int(ahead), self._sampling) + label = f"{ahead}{self._sampling}" + ax.plot(index, plot_data.values, color=color[ahead-1], label=label) + + def _plot_orig(self, ax, data): + orig_data = data.sel(type="orig", ahead=1) + index = data.index + np.timedelta64(1, self._sampling) + ax.plot(index, orig_data.values, color=matplotlib.colors.cnames["green"], label="orig") + @staticmethod def _get_time_range(data): def f(x, f_x): @@ -542,7 +586,7 @@ class PlotTimeSeries(RunEnvironment): return f(data, min), f(data, max) @staticmethod - def _save_pdf_pages(plot_folder): + def _create_pdf_pages(plot_folder): """ Standard save method to store plot locally. The name of this plot is static. :param plot_folder: path to save the plot diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index fdc691c33e40acd7f6b6c9ca9e80acaa33d9e055..03d2e36e8662a573b96c970747e9fe4445244e9b 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -35,6 +35,7 @@ class PostProcessing(RunEnvironment): self.train_val_data: DataGenerator = self.data_store.get("generator", "general.train_val") self.plot_path: str = self.data_store.get("plot_path", "general") self.target_var = self.data_store.get("target_var", "general") + self._sampling = self.data_store.get("sampling", "general") self.skill_scores = None self._run() @@ -76,7 +77,7 @@ class PostProcessing(RunEnvironment): PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False, extra_name_tag="all_terms_", model_setup="CNN") PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, model_setup="CNN") - PlotTimeSeries(self.test_data.stations, path, r"forecasts_%s_test.nc", plot_folder=self.plot_path) + PlotTimeSeries(self.test_data.stations, path, r"forecasts_%s_test.nc", plot_folder=self.plot_path, sampling=self._sampling) def calculate_test_score(self): test_score = self.model.evaluate_generator(generator=self.test_data_distributed.distribute_on_batches(), @@ -93,7 +94,7 @@ class PostProcessing(RunEnvironment): def train_ols_model(self): self.ols_model = OrdinaryLeastSquaredModel(self.train_data) - def make_prediction(self, freq="1D"): + def make_prediction(self): logging.debug("start make_prediction") for i, _ in enumerate(self.test_data): data = self.test_data.get_data_generator(i) @@ -118,7 +119,7 @@ class PostProcessing(RunEnvironment): orig_pred = self._create_orig_forecast(data, None, mean, std, transformation_method) # merge all predictions - full_index = self.create_fullindex(data.data.indexes['datetime'], freq) + full_index = self.create_fullindex(data.data.indexes['datetime'], self._get_frequency()) all_predictions = self.create_forecast_arrays(full_index, list(data.label.indexes['window']), CNN=nn_prediction, persi=persistence_prediction, @@ -130,6 +131,10 @@ class PostProcessing(RunEnvironment): file = os.path.join(path, f"forecasts_{data.station[0]}_test.nc") all_predictions.to_netcdf(file) + def _get_frequency(self): + getter = {"daily": "1D", "hourly": "1H"} + return getter.get(self._sampling, None) + @staticmethod def _create_orig_forecast(data, _, mean, std, transformation_method): return statistics.apply_inverse_transformation(data.label.copy(), mean, std, transformation_method)