Skip to content
Snippets Groups Projects
Commit 21b2d231 authored by lukas leufen's avatar lukas leufen
Browse files

refac some remaining o3 statements by target var from datastore

parent 609a4a7c
Branches
Tags
2 merge requests!37include new development,!36include using of hourly data
......@@ -34,6 +34,7 @@ class PostProcessing(RunEnvironment):
self.train_data: DataGenerator = self.data_store.get("generator", "general.train")
self.train_val_data: DataGenerator = self.data_store.get("generator", "general.train_val")
self.plot_path: str = self.data_store.get("plot_path", "general")
self.target_var = self.data_store.get("target_var", "general")
self.skill_scores = None
self._run()
......@@ -63,14 +64,13 @@ class PostProcessing(RunEnvironment):
def plot(self):
logging.debug("Run plotting routines...")
path = self.data_store.get("forecast_path", "general")
target_var = self.data_store.get("target_var", "general")
plot_conditional_quantiles(self.test_data.stations, pred_name="CNN", ref_name="orig",
forecast_path=path, plot_name_affix="cali-ref", plot_folder=self.plot_path)
plot_conditional_quantiles(self.test_data.stations, pred_name="orig", ref_name="CNN",
forecast_path=path, plot_name_affix="like-bas", plot_folder=self.plot_path)
PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path)
PlotMonthlySummary(self.test_data.stations, path, r"forecasts_%s_test.nc", target_var,
PlotMonthlySummary(self.test_data.stations, path, r"forecasts_%s_test.nc", self.target_var,
plot_folder=self.plot_path)
PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, model_setup="CNN")
PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False,
......@@ -102,7 +102,7 @@ class PostProcessing(RunEnvironment):
input_data = self.test_data[i][0]
# get scaling parameters
mean, std, transformation_method = data.get_transformation_information(variable='o3')
mean, std, transformation_method = data.get_transformation_information(variable=self.target_var)
# nn forecast
nn_prediction = self._create_nn_forecast(input_data, nn_prediction, mean, std, transformation_method)
......@@ -147,7 +147,7 @@ class PostProcessing(RunEnvironment):
return ols_prediction
def _create_persistence_forecast(self, input_data, persistence_prediction, mean, std, transformation_method):
tmp_persi = input_data.sel({'window': 0, 'variables': 'o3'})
tmp_persi = input_data.sel({'window': 0, 'variables': self.target_var})
tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method)
window_lead_time = self.data_store.get("window_lead_time", "general")
persistence_prediction.values = np.expand_dims(np.tile(tmp_persi.squeeze('Stations'), (window_lead_time, 1)),
......@@ -227,7 +227,7 @@ class PostProcessing(RunEnvironment):
def _get_external_data(self, station):
try:
data = self.train_val_data.get_data_generator(station)
mean, std, transformation_method = data.get_transformation_information(variable='o3')
mean, std, transformation_method = data.get_transformation_information(variable=self.target_var)
external_data = self._create_orig_forecast(data, None, mean, std, transformation_method)
external_data = external_data.squeeze("Stations").sel(window=1).drop(["window", "Stations", "variables"])
return external_data.rename({'datetime': 'index'})
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment