diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py index b39de8e957a110121c0e8812608d32aad3431570..854182613cdb63456dc8f62d2421560d829ee629 100644 --- a/src/plotting/postprocessing_plotting.py +++ b/src/plotting/postprocessing_plotting.py @@ -639,15 +639,16 @@ class PlotTimeSeries(RunEnvironment): def _plot_ahead(self, ax, data): color = sns.color_palette("Blues_d", self._window_lead_time).as_hex() for ahead in data.coords["ahead"].values: - plot_data = data.sel(type="CNN", ahead=ahead).drop(["type", "ahead"]).squeeze() - index = plot_data.index + np.timedelta64(int(ahead), self._sampling) + plot_data = data.sel(type="CNN", ahead=ahead).drop(["type", "ahead"]).squeeze().shift(index=ahead) label = f"{ahead}{self._sampling}" - ax.plot(index, plot_data.values, color=color[ahead-1], label=label) + ax.plot(plot_data, color=color[ahead-1], label=label) def _plot_obs(self, ax, data): - obs_data = data.sel(type="obs", ahead=1) - index = data.index + np.timedelta64(1, self._sampling) - ax.plot(index, obs_data.values, color=matplotlib.colors.cnames["green"], label="obs") + ahead = 1 + obs_data = data.sel(type="obs", ahead=ahead).shift(index=ahead) + # index = data.index + np.timedelta64(1, self._sampling) + # ax.plot(index, obs_data.values, color=matplotlib.colors.cnames["green"], label="obs") + ax.plot(obs_data, color=matplotlib.colors.cnames["green"], label="obs") @staticmethod def _get_time_range(data): diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 962c9f52065729381ce11e8a8adcbeed45a4c011..00d3d2d83409bccf478024b553ac889a4517424b 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -162,7 +162,7 @@ class PostProcessing(RunEnvironment): for normalised in [True, False]: # create empty arrays - nn_prediction, persistence_prediction, ols_prediction = self._create_empty_prediction_arrays(data, count=3) + nn_prediction, persistence_prediction, ols_prediction, observation = self._create_empty_prediction_arrays(data, count=4) # nn forecast nn_prediction = self._create_nn_forecast(input_data, nn_prediction, mean, std, transformation_method, normalised) @@ -175,7 +175,7 @@ class PostProcessing(RunEnvironment): ols_prediction = self._create_ols_forecast(input_data, ols_prediction, mean, std, transformation_method, normalised) # observation - observation = self._create_observation(data, None, mean, std, transformation_method, normalised) + observation = self._create_observation(data, observation, mean, std, transformation_method, normalised) # merge all predictions full_index = self.create_fullindex(data.data.indexes['datetime'], self._get_frequency()) @@ -195,12 +195,19 @@ class PostProcessing(RunEnvironment): getter = {"daily": "1D", "hourly": "1H"} return getter.get(self._sampling, None) - @staticmethod - def _create_observation(data, _, mean, std, transformation_method, normalised): + def _create_observation(self, data, observation, mean, std, transformation_method, normalised): obs = data.observation.copy() if not normalised: obs = statistics.apply_inverse_transformation(obs, mean, std, transformation_method) - return obs + window_lead_time = self.data_store.get("window_lead_time", "general") + obs_w = [] + for w in range(window_lead_time): + obs_w.append(obs.shift(datetime=-(w+1))) + if observation is None: + observation = data.label.copy() + observation.values = np.concatenate(obs_w, axis=0) + return observation + def _create_ols_forecast(self, input_data, ols_prediction, mean, std, transformation_method, normalised): tmp_ols = self.ols_model.predict(input_data) @@ -212,7 +219,7 @@ class PostProcessing(RunEnvironment): return ols_prediction def _create_persistence_forecast(self, data, persistence_prediction, mean, std, transformation_method, normalised): - tmp_persi = data.observation.copy().sel({'window': 0}) + tmp_persi = data.observation.copy().sel({'window': 0})#.shift(datetime=1) if not normalised: tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method) window_lead_time = self.data_store.get("window_lead_time", "general")