diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 258e0e0c9a53652fedee36ad16d24019c6a4775d..d1a22885325dae0483fae2a2e6493a391c4596b0 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -34,6 +34,7 @@ class PostProcessing(RunEnvironment): self.train_data: DataGenerator = self.data_store.get("generator", "general.train") self.train_val_data: DataGenerator = self.data_store.get("generator", "general.train_val") self.plot_path: str = self.data_store.get("plot_path", "general") + self.target_var = self.data_store.get("target_var", "general") self.skill_scores = None self._run() @@ -63,14 +64,13 @@ class PostProcessing(RunEnvironment): def plot(self): logging.debug("Run plotting routines...") path = self.data_store.get("forecast_path", "general") - target_var = self.data_store.get("target_var", "general") plot_conditional_quantiles(self.test_data.stations, pred_name="CNN", ref_name="orig", forecast_path=path, plot_name_affix="cali-ref", plot_folder=self.plot_path) plot_conditional_quantiles(self.test_data.stations, pred_name="orig", ref_name="CNN", forecast_path=path, plot_name_affix="like-bas", plot_folder=self.plot_path) PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path) - PlotMonthlySummary(self.test_data.stations, path, r"forecasts_%s_test.nc", target_var, + PlotMonthlySummary(self.test_data.stations, path, r"forecasts_%s_test.nc", self.target_var, plot_folder=self.plot_path) PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, model_setup="CNN") PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False, @@ -102,7 +102,7 @@ class PostProcessing(RunEnvironment): input_data = self.test_data[i][0] # get scaling parameters - mean, std, transformation_method = data.get_transformation_information(variable='o3') + mean, std, transformation_method = data.get_transformation_information(variable=self.target_var) # nn forecast nn_prediction = self._create_nn_forecast(input_data, nn_prediction, mean, std, transformation_method) @@ -147,7 +147,7 @@ class PostProcessing(RunEnvironment): return ols_prediction def _create_persistence_forecast(self, input_data, persistence_prediction, mean, std, transformation_method): - tmp_persi = input_data.sel({'window': 0, 'variables': 'o3'}) + tmp_persi = input_data.sel({'window': 0, 'variables': self.target_var}) tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method) window_lead_time = self.data_store.get("window_lead_time", "general") persistence_prediction.values = np.expand_dims(np.tile(tmp_persi.squeeze('Stations'), (window_lead_time, 1)), @@ -227,7 +227,7 @@ class PostProcessing(RunEnvironment): def _get_external_data(self, station): try: data = self.train_val_data.get_data_generator(station) - mean, std, transformation_method = data.get_transformation_information(variable='o3') + mean, std, transformation_method = data.get_transformation_information(variable=self.target_var) external_data = self._create_orig_forecast(data, None, mean, std, transformation_method) external_data = external_data.squeeze("Stations").sel(window=1).drop(["window", "Stations", "variables"]) return external_data.rename({'datetime': 'index'})