Skip to content
Snippets Groups Projects
Commit 8ab3f95b authored by lukas leufen's avatar lukas leufen
Browse files

first try on faster prediction

parent 74284889
Branches
Tags
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!61Resolve "REFAC: clean-up bootstrap workflow"
Pipeline #31314 passed
......@@ -61,6 +61,31 @@ class BootStrapGenerator:
yield boot_hist, label
return
def get_generator_refactored(self):
    """
    Yield bootstrap inputs together with the variable and station they were built for.

    For each element of ``self.orig_generator`` (one per station key) and each entry of ``self.variables``,
    ``self.boots`` items are yielded. In each item the selected variable is taken from the data returned by
    ``self.load_boot_data(station)`` (presumably the shuffled realisations -- confirm against ``load_boot_data``)
    and combined with the history of the remaining variables. For every yielded item, ``len(label)`` copies of
    ``[var, station]`` are appended to ``self.bootstrap_meta``.

    The generator is exhausted after a single pass over ``self.orig_generator``.

    :return: generator of ``(boot_hist, label, var, station)`` tuples
    """
    while True:
        for i, data in enumerate(self.orig_generator):
            station = self.orig_generator.get_station_key(i)
            logging.info(f"station: {station}")
            hist, label = data
            len_of_label = len(label)
            shuffled_data = self.load_boot_data(station)
            for var in self.variables:
                logging.info(f" var: {var}")
                # The history of all other variables depends only on `var`, not on the boot index, so it is
                # selected once per variable instead of once per boot. `combine_first` below returns a new
                # object and leaves this selection untouched.
                # NOTE(review): assumes helpers.list_pop returns a new list without mutating self.variables.
                remaining_hist = hist.sel(variables=helpers.list_pop(self.variables, var))
                for boot in range(self.boots):
                    logging.debug(f"boot: {boot}")
                    shuffled_var = (
                        shuffled_data.sel(variables=var, boots=boot)
                        .expand_dims("variables")
                        .drop("boots")
                        .transpose("datetime", "window", "Stations", "variables")
                    )
                    boot_hist = remaining_hist.combine_first(shuffled_var).sortby("variables")
                    self.bootstrap_meta.extend([[var, station]]*len_of_label)
                    yield boot_hist, label, var, station
        # Leave the `while True` loop after one full pass so that a plain `for` loop over this generator ends.
        return
def get_orig_prediction(self, path, file_name, prediction_name="CNN"):
file = os.path.join(path, file_name)
data = xr.open_dataarray(file)
......@@ -93,6 +118,9 @@ class BootStraps(RunEnvironment):
def boot_strap_generator(self):
    """Return whatever ``get_generator()`` of the wrapped ``_boot_strap_generator`` object returns."""
    wrapped = self._boot_strap_generator
    return wrapped.get_generator()
def boot_strap_generator_refactored(self):
    """Return whatever ``get_generator_refactored()`` of the wrapped ``_boot_strap_generator`` object returns."""
    wrapped = self._boot_strap_generator
    return wrapped.get_generator_refactored()
def get_boot_strap_generator_length(self):
    """
    Return the length reported by the wrapped ``_boot_strap_generator`` object.

    Uses the built-in ``len()`` instead of calling ``__len__()`` directly, so the usual protocol checks
    (integer, non-negative result) apply.

    :return: length of the bootstrap generator as int
    """
    return len(self._boot_strap_generator)
......
......@@ -5,6 +5,7 @@ __date__ = '2019-12-11'
import logging
import os
import dask.array as da
import keras
import numpy as np
import pandas as pd
......@@ -51,10 +52,42 @@ class PostProcessing(RunEnvironment):
logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
"skip make_prediction() whenever it is possible to save time.")
if self.data_store.get("evaluate_bootstraps", "general.postprocessing"):
self.bootstrap_skill_scores = self.create_boot_straps()
self.bootstrap_skill_scores = self.create_boot_straps_refactored()
self.skill_scores = self.calculate_skill_scores()
self.plot()
def create_boot_straps_refactored(self, number_of_boots=20):
    """
    Predict on every bootstrap sample and store predictions and labels as netCDF files in the forecast path.

    Iterates over ``BootStraps.boot_strap_generator_refactored()``, runs ``self.model.predict`` on each yielded
    input and writes the result to ``bootstraps_<variable>_<station>.nc``. The labels of a station are written
    once to ``bootstraps_labels_<station>.nc`` when the station changes between consecutive items.

    Work in progress: nothing is returned yet and no skill scores are calculated here.

    :param number_of_boots: value passed as third positional argument to ``BootStraps`` (presumably the number
        of bootstrap realisations per variable -- confirm against ``BootStraps``); defaults to the former
        hard-coded value 20
    :return: None
    """
    bootstrap_path = self.data_store.get("bootstrap_path", "general")
    forecast_path = self.data_store.get("forecast_path", "general")
    window_lead_time = self.data_store.get("window_lead_time", "general")
    bootstraps = BootStraps(self.test_data, bootstrap_path, number_of_boots)
    # Learning phase 0 switches the Keras backend to test (inference) mode.
    keras.backend.set_learning_phase(0)
    # The lead-time coordinate is identical for all predictions, so build it once.
    ahead = range(1, window_lead_time + 1)
    with TimeTracking(name="boot predictions"):
        station_previous = None
        for input_data, label, variable, station in bootstraps.boot_strap_generator_refactored():
            predictions = self.model.predict(input_data)
            if isinstance(predictions, list):
                # If the model returns a list of outputs, only the last one is used.
                predictions = predictions[-1]
            # Append a trailing axis that becomes the "type" dimension.
            predictions = np.expand_dims(predictions, 2)
            tmp = xr.DataArray(predictions,
                               coords=(range(predictions.shape[0]), ahead, [variable]),
                               dims=["index", "ahead", "type"])
            # NOTE(review): the file name contains no boot index, so consecutive boots of the same
            # variable/station overwrite each other and only the last one remains on disk. Decide on the
            # intended layout (e.g. one file per boot or an additional dimension) before relying on these files.
            file_name = os.path.join(forecast_path, f"bootstraps_{variable}_{station}.nc")
            tmp.to_netcdf(file_name)
            if station_previous != station:
                # Labels do not depend on variable or boot; store them once per station. This relies on the
                # generator yielding all items of one station consecutively.
                labels = (
                    label.assign_coords(type="obs")
                    .expand_dims("type")
                    .drop(["Stations", "variables"])
                    .rename({"datetime": "index", "window": "ahead"})
                )
                file_name = os.path.join(forecast_path, f"bootstraps_labels_{station}.nc")
                labels.to_netcdf(file_name)
            station_previous = station
    # TODO: stopped here, this implementation is slower than the old one, take a look at
    # https://towardsdatascience.com/keras-data-generators-and-how-to-use-them-b69129ed779c
def create_boot_straps(self):
# forecast
......@@ -66,6 +99,7 @@ class PostProcessing(RunEnvironment):
with TimeTracking(name="boot predictions"):
bootstrap_predictions = self.model.predict_generator(generator=bootstraps.boot_strap_generator(),
steps=bootstraps.get_boot_strap_generator_length())
if isinstance(bootstrap_predictions, list):
bootstrap_predictions = bootstrap_predictions[-1]
bootstrap_meta = np.array(bootstraps.get_boot_strap_meta())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment