Skip to content
Snippets Groups Projects
Commit d8838556 authored by lukas leufen's avatar lukas leufen
Browse files

implemented different bootstrap algorithms. Test performance on zam347.

parent ad302db1
No related branches found
No related tags found
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!61Resolve "REFAC: clean-up bootstrap workflow"
Pipeline #31718 passed
......@@ -17,7 +17,7 @@ def main(parser_args):
with RunEnvironment():
ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
station_type='background', trainable=False, create_new_model=False)
station_type='background', trainable=False, create_new_model=True)
PreProcessing()
ModelSetup()
......
......@@ -48,6 +48,76 @@ class BootStrapGenerator:
yield boot_hist, label
return
def get_generator_station_wise(self, station):
"""
This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
the history and label data of this generator.
:return:
"""
# logging.info(f"station: {station}")
hist, label = self.orig_generator[station]
shuffled_data = self.load_shuffled_data(station, self.variables)
def f():
while True:
for var in self.variables:
logging.debug(f" var: {var}")
for boot in range(self.number_of_boots):
logging.debug(f"boot: {boot}")
boot_hist = hist.sel(variables=helpers.list_pop(self.variables, var))
shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables")
boot_hist = boot_hist.combine_first(shuffled_var)
boot_hist = boot_hist.sortby("variables")
yield boot_hist
return
return hist, label, f, self.number_of_boots * len(self.variables)
def get_bootstrap_meta_station_wise(self, station) -> List:
"""
Create meta data on ordering of variable bootstraps according to ordering from get_generator method.
:return: list with bootstrapped variable first and its corresponding station second.
"""
bootstrap_meta = []
label = self.orig_generator.get_data_generator(station).get_transposed_label()
for var in self.variables:
for boot in range(self.number_of_boots):
bootstrap_meta.extend([[var, station]] * len(label))
return bootstrap_meta
def get_generator_station_var_wise(self, station, var):
"""
This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
the history and label data of this generator.
:return:
"""
hist, label = self.orig_generator[station]
shuffled_data = self.load_shuffled_data(station, self.variables)
def f():
while True:
for boot in range(self.number_of_boots):
logging.debug(f"boot: {boot}")
boot_hist = hist.sel(variables=helpers.list_pop(self.variables, var))
shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables")
boot_hist = boot_hist.combine_first(shuffled_var)
boot_hist = boot_hist.sortby("variables")
yield boot_hist
return
return hist, label, f, self.number_of_boots
def get_bootstrap_meta_station_var_wise(self, station, var) -> List:
"""
Create meta data on ordering of variable bootstraps according to ordering from get_generator method.
:return: list with bootstrapped variable first and its corresponding station second.
"""
bootstrap_meta = []
label = self.orig_generator.get_data_generator(station).get_transposed_label()
for boot in range(self.number_of_boots):
bootstrap_meta.extend([[var, station]] * len(label))
return bootstrap_meta
def get_bootstrap_meta(self) -> List:
"""
Create meta data on ordering of variable bootstraps according to ordering from get_generator method.
......@@ -116,6 +186,26 @@ class BootStraps:
self.create_shuffled_data()
self._boot_strap_generator = BootStrapGenerator(self.data, self.number_bootstraps, self.bootstrap_path)
@property
def stations(self):
return self._boot_strap_generator.stations
@property
def variables(self):
return self._boot_strap_generator.variables
def get_generator_station_wise(self, station):
return self._boot_strap_generator.get_generator_station_wise(station)
def get_generator_station_var_wise(self, station, var):
return self._boot_strap_generator.get_generator_station_var_wise(station, var)
def get_bootstrap_meta_station_wise(self, station):
return self._boot_strap_generator.get_bootstrap_meta_station_wise(station)
def get_bootstrap_meta_station_var_wise(self, station, var):
return self._boot_strap_generator.get_bootstrap_meta_station_var_wise(station, var)
def get_boot_strap_meta(self):
return self._boot_strap_generator.get_bootstrap_meta()
......
......@@ -10,8 +10,8 @@ import tensorflow as tf
from src.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
# from src.model_modules.model_class import MyBranchedModel as MyModel
from src.model_modules.model_class import MyLittleModel as MyModel
# from src.model_modules.model_class import MyTowerModel as MyModel
# from src.model_modules.model_class import MyLittleModel as MyModel
from src.model_modules.model_class import MyTowerModel as MyModel
from src.run_modules.run_environment import RunEnvironment
......
......@@ -51,16 +51,28 @@ class PostProcessing(RunEnvironment):
logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
"skip make_prediction() whenever it is possible to save time.")
# skill scores
self.skill_scores = self.calculate_skill_scores()
# bootstraps
if self.data_store.get("evaluate_bootstraps", "general.postprocessing"):
self.create_boot_straps()
self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores()
bootstrap_path = self.data_store.get("bootstrap_path", "general")
BootStraps(self.test_data, bootstrap_path, 20)
with TimeTracking(name="split (refac_1)"):
self.create_boot_straps_refac_2()
self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores()
with TimeTracking(name="split (refac)"):
self.create_boot_straps_refac()
self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores()
with TimeTracking(name="merged"):
self.bootstrap_skill_scores = self.combined_boot_forecast_and_skill()
with TimeTracking(name="original version"):
self.create_boot_straps()
self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores()
# skill scores
# self.skill_scores = self.calculate_skill_scores()
# plotting
self.plot()
# self.plot()
def create_boot_straps(self):
# forecast
......@@ -70,6 +82,7 @@ class PostProcessing(RunEnvironment):
window_lead_time = self.data_store.get("window_lead_time", "general")
bootstraps = BootStraps(self.test_data, bootstrap_path, 20)
# make bootstrap predictions
logging.info("predictions")
bootstrap_predictions = self.model.predict_generator(generator=bootstraps.boot_strap_generator(),
steps=bootstraps.get_boot_strap_generator_length(),
use_multiprocessing=True)
......@@ -81,6 +94,7 @@ class PostProcessing(RunEnvironment):
# save bootstrap predictions separately for each station and variable combination
variables = np.unique(bootstrap_meta[:, 0])
for station in np.unique(bootstrap_meta[:, 1]):
logging.info(station)
coords = None
for boot in variables:
# store each variable - station - combination
......@@ -97,6 +111,81 @@ class PostProcessing(RunEnvironment):
labels = xr.DataArray(labels, coords=(*coords, ["obs"]), dims=["index", "ahead", "type"])
labels.to_netcdf(file_name)
def create_boot_straps_refac(self):
# forecast
with TimeTracking(name="boot predictions"):
bootstrap_path = self.data_store.get("bootstrap_path", "general")
forecast_path = self.data_store.get("forecast_path", "general")
window_lead_time = self.data_store.get("window_lead_time", "general")
bootstraps = BootStraps(self.test_data, bootstrap_path, 20)
for station in bootstraps.stations:
with TimeTracking(name=station):
logging.info(station)
hist, label, station_bootstrap, length = bootstraps.get_generator_station_wise(station)
# make bootstrap predictions
bootstrap_predictions = self.model.predict_generator(generator=station_bootstrap(),
steps=length,
use_multiprocessing=True)
if isinstance(bootstrap_predictions, list):
bootstrap_predictions = bootstrap_predictions[-1]
# get bootstrap prediction meta data
bootstrap_meta = np.array(bootstraps.get_bootstrap_meta_station_wise(station))
# save bootstrap predictions separately for each station and variable combination
variables = np.unique(bootstrap_meta[:, 0])
coords = None
for boot in variables:
# store each variable - station - combination
ind = np.all(bootstrap_meta == [boot, station], axis=1)
length = sum(ind)
sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1))
coords = (range(length), range(1, window_lead_time + 1))
tmp = xr.DataArray(sel, coords=(*coords, [boot]), dims=["index", "ahead", "type"])
file_name = os.path.join(forecast_path, f"bootstraps_{boot}_{station}.nc")
tmp.to_netcdf(file_name)
# store also true labels for each station
labels = bootstraps.get_labels(station).reshape((length, window_lead_time, 1))
file_name = os.path.join(forecast_path, f"bootstraps_labels_{station}.nc")
labels = xr.DataArray(labels, coords=(*coords, ["obs"]), dims=["index", "ahead", "type"])
labels.to_netcdf(file_name)
def create_boot_straps_refac_2(self):
# forecast
with TimeTracking(name="boot predictions"):
bootstrap_path = self.data_store.get("bootstrap_path", "general")
forecast_path = self.data_store.get("forecast_path", "general")
window_lead_time = self.data_store.get("window_lead_time", "general")
bootstraps = BootStraps(self.test_data, bootstrap_path, 20)
for station in bootstraps.stations:
with TimeTracking(name=station):
logging.info(station)
for var in bootstraps.variables:
hist, label, station_bootstrap, length = bootstraps.get_generator_station_var_wise(station, var)
# make bootstrap predictions
bootstrap_predictions = self.model.predict_generator(generator=station_bootstrap(),
steps=length,
use_multiprocessing=True)
if isinstance(bootstrap_predictions, list):
bootstrap_predictions = bootstrap_predictions[-1]
# get bootstrap prediction meta data
bootstrap_meta = np.array(bootstraps.get_bootstrap_meta_station_var_wise(station, var))
# save bootstrap predictions separately for each station and variable combination
# store each variable - station - combination
ind = np.all(bootstrap_meta == [var, station], axis=1)
length = sum(ind)
sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1))
coords = (range(length), range(1, window_lead_time + 1))
tmp = xr.DataArray(sel, coords=(*coords, [var]), dims=["index", "ahead", "type"])
file_name = os.path.join(forecast_path, f"bootstraps_{var}_{station}.nc")
tmp.to_netcdf(file_name)
# store also true labels for each station
labels = bootstraps.get_labels(station).reshape((length, window_lead_time, 1))
file_name = os.path.join(forecast_path, f"bootstraps_labels_{station}.nc")
labels = xr.DataArray(labels, coords=(*coords, ["obs"]), dims=["index", "ahead", "type"])
labels.to_netcdf(file_name)
def calculate_bootstrap_skill_scores(self):
with TimeTracking(name="boot skill scores"):
......@@ -110,6 +199,7 @@ class PostProcessing(RunEnvironment):
skill_scores = statistics.SkillScores(None)
score = {}
for station in self.test_data.stations:
logging.info(station)
file_name = os.path.join(forecast_path, f"bootstraps_labels_{station}.nc")
labels = xr.open_dataarray(file_name)
shape = labels.shape
......@@ -129,6 +219,62 @@ class PostProcessing(RunEnvironment):
score[station] = xr.DataArray(skill, dims=["boot_var", "ahead"])
return score
def combined_boot_forecast_and_skill(self):
# forecast
with TimeTracking(name="boot predictions"):
bootstrap_path = self.data_store.get("bootstrap_path", "general")
forecast_path = self.data_store.get("forecast_path", "general")
window_lead_time = self.data_store.get("window_lead_time", "general")
bootstraps = BootStraps(self.test_data, bootstrap_path, 20)
skill_scores = statistics.SkillScores(None)
score = {}
for station in bootstraps.stations:
with TimeTracking(name=station):
logging.info(station)
# store also true labels for each station
labels = bootstraps.get_labels(station)
shape = labels.shape
labels = labels.reshape((*shape, 1))
coords = (range(labels.shape[0]), range(1, labels.shape[1] + 1))
# file_name = os.path.join(forecast_path, f"bootstraps_labels_{station}.nc")
labels = xr.DataArray(labels, coords=(*coords, ["obs"]), dims=["index", "ahead", "type"])
# labels.to_netcdf(file_name)
shape = labels.shape
orig = bootstraps.get_orig_prediction(forecast_path, f"forecasts_norm_{station}_test.nc").reshape(shape)
coords = (range(shape[0]), range(1, shape[1] + 1), ["orig"])
orig = xr.DataArray(orig, coords=coords, dims=["index", "ahead", "type"])
skill = pd.DataFrame(columns=range(1, window_lead_time + 1))
for var in bootstraps.variables:
hist, label, station_bootstrap, length = bootstraps.get_generator_station_var_wise(station, var)
# make bootstrap predictions
bootstrap_predictions = self.model.predict_generator(generator=station_bootstrap(),
steps=length,
use_multiprocessing=True)
if isinstance(bootstrap_predictions, list):
bootstrap_predictions = bootstrap_predictions[-1]
# get bootstrap prediction meta data
bootstrap_meta = np.array(bootstraps.get_bootstrap_meta_station_var_wise(station, var))
# save bootstrap predictions separately for each station and variable combination
# store each variable - station - combination
ind = np.all(bootstrap_meta == [var, station], axis=1)
length = sum(ind)
sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1))
coords = (range(length), range(1, window_lead_time + 1))
boot_data = xr.DataArray(sel, coords=(*coords, [var]), dims=["index", "ahead", "type"])
# file_name = os.path.join(forecast_path, f"bootstraps_{var}_{station}.nc")
# boot_data.to_netcdf(file_name)
boot_data = boot_data.combine_first(labels).combine_first(orig)
boot_scores = []
for ahead in range(1, window_lead_time + 1):
data = boot_data.sel(ahead=ahead)
boot_scores.append(skill_scores.general_skill_score(data, forecast_name=var, reference_name="orig"))
skill.loc[var] = np.array(boot_scores)
score[station] = xr.DataArray(skill, dims=["boot_var", "ahead"])
return score
def _load_model(self):
try:
model = self.data_store.get("best_model", "general")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment