diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py
index 97d326bc87c72142ab20ea95effbd88f490f2937..65b7fc0f6dd8a537fa8ded2dbc632e4668eaf0d3 100644
--- a/src/plotting/postprocessing_plotting.py
+++ b/src/plotting/postprocessing_plotting.py
@@ -477,3 +477,83 @@ class PlotCompetitiveSkillScore(RunEnvironment):
         logging.debug(f"... save plot to {plot_name}")
         plt.savefig(plot_name, dpi=500)
         plt.close()
+
+
+class PlotTimeSeries(RunEnvironment):
+    """
+    Plot original values and CNN forecasts of all given stations as time series, one subplot per station. The plot is
+    created and saved directly on instantiation.
+    """
+
+    def __init__(self, stations: List, data_path: str, name: str, window_lead_time: int = None, plot_folder: str = "."):
+        """
+        Store all settings, validate the lead time against the data of the first station and create the plot.
+        :param stations: stations to plot, each station is drawn into its own subplot
+        :param data_path: path to the folder containing the forecast files
+        :param name: file name pattern of the forecast files, the station is inserted by %-formatting
+        :param window_lead_time: number of lead times to plot, if not given all lead times found in data are used
+        :param plot_folder: path to save the plot
+        """
+        super().__init__()
+        self._data_path = data_path
+        self._data_name = name
+        self._stations = stations
+        self._window_lead_time = self._get_window_lead_time(window_lead_time)
+        self._plot(plot_folder)
+
+    def _get_window_lead_time(self, window_lead_time: int):
+        """
+        Extract the lead time from data and arguments. If window_lead_time is not given, extract this information from
+        data itself by the number of ahead dimensions. If given, check if data supports the given length. If the number
+        of ahead dimensions in data is lower than the given lead time, data's lead time is used.
+        :param window_lead_time: lead time from arguments to validate
+        :return: validated lead time, comes either from given argument or from data itself
+        """
+        ahead_steps = len(self._load_data(self._stations[0]).ahead)
+        if window_lead_time is None:
+            window_lead_time = ahead_steps
+        return min(ahead_steps, window_lead_time)
+
+    def _load_data(self, station):
+        """
+        Load the forecast file of a single station and select the types "CNN" and "orig".
+        :param station: station to load, is inserted into the file name pattern
+        :return: data of types "CNN" and "orig", loaded into memory
+        """
+        logging.debug(f"... preprocess station {station}")
+        file_name = os.path.join(self._data_path, self._data_name % station)
+        # load eagerly inside the context manager to release the file handle again
+        with xr.open_dataarray(file_name) as data:
+            return data.sel(type=["CNN", "orig"]).load()
+
+    def _plot(self, plot_folder):
+        """
+        Draw the original values (green) and the CNN forecast of each lead time (shades of blue) for all stations
+        and save the plot afterwards.
+        :param plot_folder: path to save the plot
+        """
+        # squeeze=False always returns a 2D array of axes, also if only a single station is given
+        _, axes = plt.subplots(len(self._stations), sharex="all", squeeze=False)
+        orig_color = matplotlib.colors.cnames["green"]
+        forecast_colors = sns.color_palette("Blues_d", self._window_lead_time).as_hex()
+        for ax, station in zip(axes[:, 0], self._stations):
+            data = self._load_data(station)
+            # original values are the reference line; index is shifted by the lead time (one day for ahead=1)
+            ax.plot(data.index + np.timedelta64(1, "D"), data.sel(type="orig", ahead=1).values, color=orig_color)
+            # zip stops after the validated number of lead times, so each plotted lead time has its own color
+            for color, ahead in zip(forecast_colors, data.coords["ahead"].values):
+                plot_data = data.sel(type="CNN", ahead=ahead).drop(["type", "ahead"]).squeeze()
+                ax.plot(plot_data.index + np.timedelta64(int(ahead), "D"), plot_data.values, color=color)
+        self._save(plot_folder)
+
+    @staticmethod
+    def _save(plot_folder):
+        """
+        Standard save method to store plot locally. The name of this plot is static.
+        :param plot_folder: path to save the plot
+        """
+        plot_name = os.path.join(os.path.abspath(plot_folder), 'test_timeseries_plot.pdf')
+        logging.debug(f"... save plot to {plot_name}")
+        plt.savefig(plot_name, dpi=500)
+        plt.close('all')
+
+
diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index e6f271ce3cc6cf2548ff5b06ba40e2fd509f8c8d..1db58b8f3c26e76e816ac5c8059d1fc9b2d20e8a 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -16,7 +16,8 @@ from src.data_handling.data_generator import DataGenerator
 from src.model_modules.linear_model import OrdinaryLeastSquaredModel
 from src import statistics
 from src.plotting.postprocessing_plotting import plot_conditional_quantiles
-from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, PlotCompetitiveSkillScore
+from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \
+    PlotCompetitiveSkillScore, PlotTimeSeries
 from src.datastore import NameNotFoundInDataStore
 from src.helpers import TimeTracking
 
@@ -42,10 +43,10 @@ class PostProcessing(RunEnvironment):
         logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
                      "skip make_prediction() whenever it is possible to save time.")
         with TimeTracking():
-            preds_for_all_stations = self.make_prediction()
+            self.make_prediction()
         logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
                      "skip make_prediction() whenever it is possible to save time.")
-        self.skill_scores = self.calculate_skill_scores()
+        # self.skill_scores = self.calculate_skill_scores()
         self.plot()
 
     def _load_model(self):
@@ -64,17 +65,18 @@ class PostProcessing(RunEnvironment):
         path = self.data_store.get("forecast_path", "general")
         target_var = self.data_store.get("target_var", "general")
 
-        plot_conditional_quantiles(self.test_data.stations, pred_name="CNN", ref_name="orig",
-                                   forecast_path=path, plot_name_affix="cali-ref", plot_folder=self.plot_path)
-        plot_conditional_quantiles(self.test_data.stations, pred_name="orig", ref_name="CNN",
-                                   forecast_path=path, plot_name_affix="like-bas", plot_folder=self.plot_path)
-        PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path)
-        PlotMonthlySummary(self.test_data.stations, path, r"forecasts_%s_test.nc", target_var,
-                           plot_folder=self.plot_path)
-        PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, model_setup="CNN")
-        PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False,
-                                     extra_name_tag="all_terms_", model_setup="CNN")
-        PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, model_setup="CNN")
+        # plot_conditional_quantiles(self.test_data.stations, pred_name="CNN", ref_name="orig",
+        #                            forecast_path=path, plot_name_affix="cali-ref", plot_folder=self.plot_path)
+        # plot_conditional_quantiles(self.test_data.stations, pred_name="orig", ref_name="CNN",
+        #                            forecast_path=path, plot_name_affix="like-bas", plot_folder=self.plot_path)
+        # PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path)
+        # PlotMonthlySummary(self.test_data.stations, path, r"forecasts_%s_test.nc", target_var,
+        #                    plot_folder=self.plot_path)
+        # PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, model_setup="CNN")
+        # PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False,
+        #                              extra_name_tag="all_terms_", model_setup="CNN")
+        # PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, model_setup="CNN")
+        PlotTimeSeries(self.test_data.stations, path, r"forecasts_%s_test.nc", plot_folder=self.plot_path)
 
     def calculate_test_score(self):
         test_score = self.model.evaluate_generator(generator=self.test_data_distributed.distribute_on_batches(),
@@ -93,12 +95,11 @@ class PostProcessing(RunEnvironment):
 
     def make_prediction(self, freq="1D"):
         logging.debug("start make_prediction")
-        nn_prediction_all_stations = []
-        for i, v in enumerate(self.test_data):
+        for i, _ in enumerate(self.test_data):
             data = self.test_data.get_data_generator(i)
 
             nn_prediction, persistence_prediction, ols_prediction = self._create_empty_prediction_arrays(data, count=3)
-            input_data = self.test_data[i][0]
+            input_data = data.get_transposed_history()
 
             # get scaling parameters
             mean, std, transformation_method = data.get_transformation_information(variable='o3')
@@ -129,10 +130,6 @@ class PostProcessing(RunEnvironment):
             file = os.path.join(path, f"forecasts_{data.station[0]}_test.nc")
             all_predictions.to_netcdf(file)
 
-            # save nn forecast to return variable
-            nn_prediction_all_stations.append(nn_prediction)
-        return nn_prediction_all_stations
-
     @staticmethod
     def _create_orig_forecast(data, _, mean, std, transformation_method):
         return statistics.apply_inverse_transformation(data.label.copy(), mean, std, transformation_method)