Skip to content
Snippets Groups Projects
Commit 72259773 authored by lukas leufen's avatar lukas leufen
Browse files

current status, started to fix time shift

parent 0e36df1e
No related branches found
No related tags found
2 merge requests!59Develop,!57Lukas issue 064 bug check time axis
Pipeline #31190 passed
......@@ -639,15 +639,16 @@ class PlotTimeSeries(RunEnvironment):
def _plot_ahead(self, ax, data):
color = sns.color_palette("Blues_d", self._window_lead_time).as_hex()
for ahead in data.coords["ahead"].values:
plot_data = data.sel(type="CNN", ahead=ahead).drop(["type", "ahead"]).squeeze()
index = plot_data.index + np.timedelta64(int(ahead), self._sampling)
plot_data = data.sel(type="CNN", ahead=ahead).drop(["type", "ahead"]).squeeze().shift(index=ahead)
label = f"{ahead}{self._sampling}"
ax.plot(index, plot_data.values, color=color[ahead-1], label=label)
ax.plot(plot_data, color=color[ahead-1], label=label)
def _plot_obs(self, ax, data):
obs_data = data.sel(type="obs", ahead=1)
index = data.index + np.timedelta64(1, self._sampling)
ax.plot(index, obs_data.values, color=matplotlib.colors.cnames["green"], label="obs")
ahead = 1
obs_data = data.sel(type="obs", ahead=ahead).shift(index=ahead)
# index = data.index + np.timedelta64(1, self._sampling)
# ax.plot(index, obs_data.values, color=matplotlib.colors.cnames["green"], label="obs")
ax.plot(obs_data, color=matplotlib.colors.cnames["green"], label="obs")
@staticmethod
def _get_time_range(data):
......
......@@ -162,7 +162,7 @@ class PostProcessing(RunEnvironment):
for normalised in [True, False]:
# create empty arrays
nn_prediction, persistence_prediction, ols_prediction = self._create_empty_prediction_arrays(data, count=3)
nn_prediction, persistence_prediction, ols_prediction, observation = self._create_empty_prediction_arrays(data, count=4)
# nn forecast
nn_prediction = self._create_nn_forecast(input_data, nn_prediction, mean, std, transformation_method, normalised)
......@@ -175,7 +175,7 @@ class PostProcessing(RunEnvironment):
ols_prediction = self._create_ols_forecast(input_data, ols_prediction, mean, std, transformation_method, normalised)
# observation
observation = self._create_observation(data, None, mean, std, transformation_method, normalised)
observation = self._create_observation(data, observation, mean, std, transformation_method, normalised)
# merge all predictions
full_index = self.create_fullindex(data.data.indexes['datetime'], self._get_frequency())
......@@ -195,12 +195,19 @@ class PostProcessing(RunEnvironment):
getter = {"daily": "1D", "hourly": "1H"}
return getter.get(self._sampling, None)
@staticmethod
def _create_observation(data, _, mean, std, transformation_method, normalised):
def _create_observation(self, data, observation, mean, std, transformation_method, normalised):
obs = data.observation.copy()
if not normalised:
obs = statistics.apply_inverse_transformation(obs, mean, std, transformation_method)
return obs
window_lead_time = self.data_store.get("window_lead_time", "general")
obs_w = []
for w in range(window_lead_time):
obs_w.append(obs.shift(datetime=-(w+1)))
if observation is None:
observation = data.label.copy()
observation.values = np.concatenate(obs_w, axis=0)
return observation
def _create_ols_forecast(self, input_data, ols_prediction, mean, std, transformation_method, normalised):
tmp_ols = self.ols_model.predict(input_data)
......@@ -212,7 +219,7 @@ class PostProcessing(RunEnvironment):
return ols_prediction
def _create_persistence_forecast(self, data, persistence_prediction, mean, std, transformation_method, normalised):
tmp_persi = data.observation.copy().sel({'window': 0})
tmp_persi = data.observation.copy().sel({'window': 0})#.shift(datetime=1)
if not normalised:
tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method)
window_lead_time = self.data_store.get("window_lead_time", "general")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment