Skip to content
Snippets Groups Projects
Commit c027b543 authored by lukas leufen's avatar lukas leufen
Browse files

include time axis fix, /close #64

parents 0e36df1e 3bdf3ea0
No related branches found
No related tags found
2 merge requests!59Develop,!57Lukas issue 064 bug check time axis
Pipeline #31231 passed
......@@ -639,15 +639,16 @@ class PlotTimeSeries(RunEnvironment):
def _plot_ahead(self, ax, data):
color = sns.color_palette("Blues_d", self._window_lead_time).as_hex()
for ahead in data.coords["ahead"].values:
plot_data = data.sel(type="CNN", ahead=ahead).drop(["type", "ahead"]).squeeze()
index = plot_data.index + np.timedelta64(int(ahead), self._sampling)
plot_data = data.sel(type="CNN", ahead=ahead).drop(["type", "ahead"]).squeeze().shift(index=ahead)
label = f"{ahead}{self._sampling}"
ax.plot(index, plot_data.values, color=color[ahead-1], label=label)
ax.plot(plot_data, color=color[ahead-1], label=label)
def _plot_obs(self, ax, data):
obs_data = data.sel(type="obs", ahead=1)
index = data.index + np.timedelta64(1, self._sampling)
ax.plot(index, obs_data.values, color=matplotlib.colors.cnames["green"], label="obs")
ahead = 1
obs_data = data.sel(type="obs", ahead=ahead).shift(index=ahead)
# index = data.index + np.timedelta64(1, self._sampling)
# ax.plot(index, obs_data.values, color=matplotlib.colors.cnames["green"], label="obs")
ax.plot(obs_data, color=matplotlib.colors.cnames["green"], label="obs")
@staticmethod
def _get_time_range(data):
......
......@@ -65,6 +65,8 @@ class PostProcessing(RunEnvironment):
with TimeTracking(name="boot predictions"):
bootstrap_predictions = self.model.predict_generator(generator=bootstraps.boot_strap_generator(),
steps=bootstraps.get_boot_strap_generator_length())
if isinstance(bootstrap_predictions, list):
bootstrap_predictions = bootstrap_predictions[-1]
bootstrap_meta = np.array(bootstraps.get_boot_strap_meta())
variables = np.unique(bootstrap_meta[:, 0])
for station in np.unique(bootstrap_meta[:, 1]):
......@@ -162,7 +164,7 @@ class PostProcessing(RunEnvironment):
for normalised in [True, False]:
# create empty arrays
nn_prediction, persistence_prediction, ols_prediction = self._create_empty_prediction_arrays(data, count=3)
nn_prediction, persistence_prediction, ols_prediction, observation = self._create_empty_prediction_arrays(data, count=4)
# nn forecast
nn_prediction = self._create_nn_forecast(input_data, nn_prediction, mean, std, transformation_method, normalised)
......@@ -175,7 +177,7 @@ class PostProcessing(RunEnvironment):
ols_prediction = self._create_ols_forecast(input_data, ols_prediction, mean, std, transformation_method, normalised)
# observation
observation = self._create_observation(data, None, mean, std, transformation_method, normalised)
observation = self._create_observation(data, observation, mean, std, transformation_method, normalised)
# merge all predictions
full_index = self.create_fullindex(data.data.indexes['datetime'], self._get_frequency())
......@@ -195,9 +197,8 @@ class PostProcessing(RunEnvironment):
getter = {"daily": "1D", "hourly": "1H"}
return getter.get(self._sampling, None)
@staticmethod
def _create_observation(data, _, mean, std, transformation_method, normalised):
obs = data.observation.copy()
def _create_observation(self, data, _, mean, std, transformation_method, normalised):
obs = data.label.copy()
if not normalised:
obs = statistics.apply_inverse_transformation(obs, mean, std, transformation_method)
return obs
......@@ -235,7 +236,9 @@ class PostProcessing(RunEnvironment):
tmp_nn = self.model.predict(input_data)
if not normalised:
tmp_nn = statistics.apply_inverse_transformation(tmp_nn, mean, std, transformation_method)
if tmp_nn.ndim == 3:
if isinstance(tmp_nn, list):
nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn[-1], axis=1), 2, 0)
elif tmp_nn.ndim == 3:
nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn[-1, ...], axis=1), 2, 0)
elif tmp_nn.ndim == 2:
nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn, axis=1), 2, 0)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment