diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py
index 732a7efdf8f360b49823dfb6ca5ca3239cc774af..92ff8b718dbb7bbbcbebe4e80fb82e0e2a7886c6 100644
--- a/src/data_handling/data_generator.py
+++ b/src/data_handling/data_generator.py
@@ -90,8 +90,7 @@ class DataGenerator(keras.utils.Sequence):
         :return: The generator's time series of history data and its labels
         """
         data = self.get_data_generator(key=item)
-        return data.history.transpose("datetime", "window", "Stations", "variables"), \
-               data.label.squeeze("Stations").transpose("datetime", "window")
+        return data.get_transposed_history(), data.label.squeeze("Stations").transpose("datetime", "window")
 
     def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep:
         """
@@ -124,7 +123,10 @@ class DataGenerator(keras.utils.Sequence):
         Save given data locally as .pickle in self.data_path_tmp with name '<station>_<var1>_<var2>_..._<varX>.pickle'
         :param data: any data, that should be saved
         """
-        file = os.path.join(self.data_path_tmp, f"{''.join(data.station)}_{'_'.join(sorted(data.variables))}.pickle")
+        date = f"{self.kwargs.get('start')}_{self.kwargs.get('end')}"
+        vars = '_'.join(sorted(data.variables))
+        station = ''.join(data.station)
+        file = os.path.join(self.data_path_tmp, f"{station}_{vars}_{date}_.pickle")
         with open(file, "wb") as f:
             pickle.dump(data, f)
         logging.debug(f"save pickle data to {file}")
@@ -136,7 +138,10 @@ class DataGenerator(keras.utils.Sequence):
         :param variables: list of variables to load
         :return: loaded data
         """
-        file = os.path.join(self.data_path_tmp, f"{''.join(station)}_{'_'.join(sorted(variables))}.pickle")
+        date = f"{self.kwargs.get('start')}_{self.kwargs.get('end')}"
+        vars = '_'.join(sorted(variables))
+        station = ''.join(station)
+        file = os.path.join(self.data_path_tmp, f"{station}_{vars}_{date}_.pickle")
         with open(file, "rb") as f:
             data = pickle.load(f)
         logging.debug(f"load pickle data from {file}")
diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py
index c39625b1e02506696ee5b4c13ac86c7e73420acf..81ce5cddf05cc0158f81a7666cd7a4956bf0a400 100644
--- a/src/data_handling/data_preparation.py
+++ b/src/data_handling/data_preparation.py
@@ -385,6 +385,10 @@ class DataPrep(object):
             data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
         return data
 
+    def get_transposed_history(self):
+        if self.history is not None:
+            return self.history.transpose("datetime", "window", "Stations", "variables")
+
 
 if __name__ == "__main__":
     dp = DataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py
index 97d326bc87c72142ab20ea95effbd88f490f2937..4c5d5ca0f9a9c26d0b8085e89e4bbacf2525e33a 100644
--- a/src/plotting/postprocessing_plotting.py
+++ b/src/plotting/postprocessing_plotting.py
@@ -477,3 +477,76 @@ class PlotCompetitiveSkillScore(RunEnvironment):
         logging.debug(f"... save plot to {plot_name}")
         plt.savefig(plot_name, dpi=500)
         plt.close()
+
+
+class PlotTimeSeries(RunEnvironment):
+
+    def __init__(self, stations: List, data_path: str, name: str, window_lead_time: int = None, plot_folder: str = "."):
+        super().__init__()
+        self._data_path = data_path
+        self._data_name = name
+        self._stations = stations
+        self._window_lead_time = self._get_window_lead_time(window_lead_time)
+        self._plot(plot_folder)
+
+    def _get_window_lead_time(self, window_lead_time: int):
+        """
+        Extract the lead time from data and arguments. If window_lead_time is not given, extract this information from
+        the data itself via the number of ahead dimensions. If given, check that the data supports the given length;
+        if the number of ahead dimensions in the data is lower than the given lead time, the data's lead time is used.
+        :param window_lead_time: lead time from arguments to validate
+        :return: validated lead time, either from the given argument or from the data itself
+        """
+        ahead_steps = len(self._load_data(self._stations[0]).ahead)
+        if window_lead_time is None:
+            window_lead_time = ahead_steps
+        return min(ahead_steps, window_lead_time)
+
+    def _load_data(self, station):
+        logging.debug(f"... preprocess station {station}")
+        file_name = os.path.join(self._data_path, self._data_name % station)
+        data = xr.open_dataarray(file_name)
+        return data.sel(type=["CNN", "orig"])
+
+    def _plot(self, plot_folder):
+        pdf_pages = self._save_pdf_pages(plot_folder)
+        start, end = self._get_time_range(self._load_data(self._stations[0]))
+        color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", self._window_lead_time).as_hex()
+        for pos, station in enumerate(self._stations):
+            data = self._load_data(station)
+            f, axes = plt.subplots(end - start + 1, sharey=True, figsize=(40, 20))
+            nan_list = []
+            for i in range(end - start + 1):
+                data_year = data.sel(index=f"{start + i}")
+                orig_data = data_year.sel(type="orig", ahead=1).values
+                axes[i].plot(data_year.index + np.timedelta64(1, "D"), orig_data, color=color_palette[0], label="orig")
+                for ahead in data.coords["ahead"].values:
+                    plot_data = data_year.sel(type="CNN", ahead=ahead).drop(["type", "ahead"]).squeeze()
+                    axes[i].plot(plot_data.index + np.timedelta64(int(ahead), "D"), plot_data.values, color=color_palette[ahead], label=f"{ahead}d")
+                if np.isnan(orig_data).all():
+                    nan_list.append(i)
+            for i in reversed(nan_list):
+                f.delaxes(axes[i])
+
+            plt.suptitle(station)
+            plt.legend()
+            plt.tight_layout()
+            pdf_pages.savefig(dpi=500)
+        pdf_pages.close()
+        plt.close('all')
+
+    @staticmethod
+    def _get_time_range(data):
+        def f(x, f_x):
+            return pd.to_datetime(f_x(x.index.values)).year
+        return f(data, min), f(data, max)
+
+    @staticmethod
+    def _save_pdf_pages(plot_folder):
+        """
+        Standard save method to store plot locally. The name of this plot is static.
+        :param plot_folder: path to save the plot
+        """
+        plot_name = os.path.join(os.path.abspath(plot_folder), 'timeseries_plot.pdf')
+        logging.debug(f"... save plot to {plot_name}")
+        return matplotlib.backends.backend_pdf.PdfPages(plot_name)
diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index e6f271ce3cc6cf2548ff5b06ba40e2fd509f8c8d..37651d0cf38f8d669018c5fbd5bcd113115905b1 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -16,7 +16,8 @@ from src.data_handling.data_generator import DataGenerator
 from src.model_modules.linear_model import OrdinaryLeastSquaredModel
 from src import statistics
 from src.plotting.postprocessing_plotting import plot_conditional_quantiles
-from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, PlotCompetitiveSkillScore
+from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \
+    PlotCompetitiveSkillScore, PlotTimeSeries
 from src.datastore import NameNotFoundInDataStore
 from src.helpers import TimeTracking
 
@@ -42,10 +43,10 @@ class PostProcessing(RunEnvironment):
         logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
                      "skip make_prediction() whenever it is possible to save time.")
         with TimeTracking():
-            preds_for_all_stations = self.make_prediction()
+            self.make_prediction()
         logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
                      "skip make_prediction() whenever it is possible to save time.")
-        self.skill_scores = self.calculate_skill_scores()
+        # self.skill_scores = self.calculate_skill_scores()
         self.plot()
 
     def _load_model(self):
@@ -75,6 +76,7 @@ class PostProcessing(RunEnvironment):
         PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False,
                                      extra_name_tag="all_terms_", model_setup="CNN")
         PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, model_setup="CNN")
+        PlotTimeSeries(self.test_data.stations, path, r"forecasts_%s_test.nc", plot_folder=self.plot_path)
 
     def calculate_test_score(self):
         test_score = self.model.evaluate_generator(generator=self.test_data_distributed.distribute_on_batches(),
@@ -93,12 +95,11 @@ class PostProcessing(RunEnvironment):
 
     def make_prediction(self, freq="1D"):
         logging.debug("start make_prediction")
-        nn_prediction_all_stations = []
-        for i, v in enumerate(self.test_data):
+        for i, _ in enumerate(self.test_data):
             data = self.test_data.get_data_generator(i)
             nn_prediction, persistence_prediction, ols_prediction = self._create_empty_prediction_arrays(data,
                                                                                                          count=3)
-            input_data = self.test_data[i][0]
+            input_data = data.get_transposed_history()
 
             # get scaling parameters
             mean, std, transformation_method = data.get_transformation_information(variable='o3')
@@ -129,10 +130,6 @@ class PostProcessing(RunEnvironment):
             file = os.path.join(path, f"forecasts_{data.station[0]}_test.nc")
             all_predictions.to_netcdf(file)
 
-            # save nn forecast to return variable
-            nn_prediction_all_stations.append(nn_prediction)
-        return nn_prediction_all_stations
-
     @staticmethod
     def _create_orig_forecast(data, _, mean, std, transformation_method):
         return statistics.apply_inverse_transformation(data.label.copy(), mean, std, transformation_method)
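
Note on the cache-key change in src/data_handling/data_generator.py: the pickle file name now embeds the generator's start and end dates, so runs configured for different time ranges no longer collide on the same cached file. A minimal sketch of the naming scheme; the helper name, station id, variables, and dates below are made up for illustration:

    import os

    def cache_file_name(tmp_path, station, variables, start, end):
        # mirrors the scheme in the patch: <station>_<sorted vars>_<start>_<end>_.pickle
        date = f"{start}_{end}"
        vars_str = '_'.join(sorted(variables))
        return os.path.join(tmp_path, f"{station}_{vars_str}_{date}_.pickle")

    print(cache_file_name("data/tmp", "DEBW107", ["temp", "o3"], "1997-01-01", "2017-12-31"))
    # data/tmp/DEBW107_o3_temp_1997-01-01_2017-12-31_.pickle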
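
Note on DataPrep.get_transposed_history(): it centralizes the transpose so that DataGenerator.__getitem__ and PostProcessing.make_prediction both receive the (datetime, window, Stations, variables) layout, and it returns None when no history has been built yet. A toy xarray example; the raw dimension order below is an assumption, only the target order comes from the patch:

    import numpy as np
    import xarray as xr

    # hypothetical raw history with the four dimensions used in the patch
    history = xr.DataArray(np.zeros((1, 2, 10, 3)),
                           dims=["Stations", "variables", "datetime", "window"])
    print(history.transpose("datetime", "window", "Stations", "variables").dims)
    # ('datetime', 'window', 'Stations', 'variables')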
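
Note on PlotTimeSeries in src/plotting/postprocessing_plotting.py: it writes one PDF page per station by calling pdf_pages.savefig() inside the station loop and closing the PdfPages object once after the loop. A self-contained sketch of that multi-page pattern; station ids and data are invented for illustration:

    import numpy as np
    import matplotlib
    matplotlib.use("Agg")  # render without a display, e.g. on a batch system
    import matplotlib.pyplot as plt
    from matplotlib.backends.backend_pdf import PdfPages

    stations = ["STATION_A", "STATION_B"]  # hypothetical station ids
    pdf_pages = PdfPages("timeseries_plot.pdf")
    for station in stations:
        fig, ax = plt.subplots(figsize=(8, 4))
        ax.plot(np.random.randn(100).cumsum(), label="orig")  # dummy series
        plt.suptitle(station)
        plt.legend()
        pdf_pages.savefig(dpi=500)  # appends the current figure as a new page
    pdf_pages.close()  # finalizes the PDF
    plt.close("all")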