Skip to content
Snippets Groups Projects
Commit 82a25fc8 authored by lukas leufen's avatar lukas leufen
Browse files

include bootstrap skill scores /close #51

parents cf202273 113e891a
No related branches found
No related tags found
2 merge requests!59Develop,!52implemented bootstraps
...@@ -31,6 +31,11 @@ class BootStrapGenerator: ...@@ -31,6 +31,11 @@ class BootStrapGenerator:
""" """
return len(self.orig_generator)*self.boots*len(self.variables) return len(self.orig_generator)*self.boots*len(self.variables)
def get_labels(self, key):
    """Yield the labels of station ``key`` once per bootstrap realisation.

    The label array is fetched a single time from the original generator and
    then re-emitted ``self.boots`` times, mirroring the bootstrapped inputs.
    """
    label = self.orig_generator[key][1]
    remaining = self.boots
    while remaining > 0:
        remaining -= 1
        yield label
def get_generator(self): def get_generator(self):
""" """
This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
...@@ -52,10 +57,16 @@ class BootStrapGenerator: ...@@ -52,10 +57,16 @@ class BootStrapGenerator:
shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables") shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables")
boot_hist = boot_hist.combine_first(shuffled_var) boot_hist = boot_hist.combine_first(shuffled_var)
boot_hist = boot_hist.sortby("variables") boot_hist = boot_hist.sortby("variables")
self.bootstrap_meta.extend([var]*len_of_label) self.bootstrap_meta.extend([[var, station]]*len_of_label)
yield boot_hist, label yield boot_hist, label
return return
def get_orig_prediction(self, path, file_name, prediction_name="CNN"):
    """Yield the stored original model prediction once per bootstrap iteration.

    The forecast file is opened once; the selected prediction type is then
    re-emitted ``self.boots`` times so it aligns with the bootstrapped samples.

    :param path: directory containing the forecast file
    :param file_name: name of the netCDF file holding the saved forecasts
    :param prediction_name: forecast type to select from the array (default "CNN")
    """
    file = os.path.join(path, file_name)
    data = xr.open_dataarray(file)
    for _ in range(self.boots):
        yield data.sel(type=prediction_name).squeeze()
def load_boot_data(self, station): def load_boot_data(self, station):
files = os.listdir(self.bootstrap_path) files = os.listdir(self.bootstrap_path)
regex = re.compile(rf"{station}_\w*\.nc") regex = re.compile(rf"{station}_\w*\.nc")
...@@ -85,6 +96,27 @@ class BootStraps(RunEnvironment): ...@@ -85,6 +96,27 @@ class BootStraps(RunEnvironment):
def get_boot_strap_generator_length(self): def get_boot_strap_generator_length(self):
return self._boot_strap_generator.__len__() return self._boot_strap_generator.__len__()
def get_labels(self, key):
    """Collect all bootstrap label repetitions for ``key`` into one array.

    Each label block yielded by the underlying generator is wrapped in a dask
    array (chunk size derived from the first block) and all blocks are
    concatenated along axis 0 before being computed into memory.
    """
    collected = []
    chunk_shape = None
    for item in self._boot_strap_generator.get_labels(key):
        if not collected:
            chunk_shape = (100, item.data.shape[1])
        collected.append(da.from_array(item.data, chunks=chunk_shape))
    return da.concatenate(collected, axis=0).compute()
def get_orig_prediction(self, path, name):
    """Assemble the repeated original predictions and drop rows containing NaN.

    The prediction blocks from the underlying generator are concatenated along
    axis 0 via dask (chunking taken from the first block); any row of the
    computed result that contains a NaN entry is removed before returning.
    """
    parts = []
    chunk_shape = None
    for block in self._boot_strap_generator.get_orig_prediction(path, name):
        if not parts:
            chunk_shape = (100, block.data.shape[1])
        parts.append(da.from_array(block.data, chunks=chunk_shape))
    stacked = da.concatenate(parts, axis=0).compute()
    valid_rows = ~np.isnan(stacked).any(axis=1)
    return stacked[valid_rows, :]
def get_chunk_size(self): def get_chunk_size(self):
hist, _ = self.data[0] hist, _ = self.data[0]
return (100, *hist.shape[1:], self.number_bootstraps) return (100, *hist.shape[1:], self.number_bootstraps)
......
...@@ -49,27 +49,60 @@ class PostProcessing(RunEnvironment): ...@@ -49,27 +49,60 @@ class PostProcessing(RunEnvironment):
self.make_prediction() self.make_prediction()
logging.info("take a look on the next reported time measure. If this increases a lot, one should think to " logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
"skip make_prediction() whenever it is possible to save time.") "skip make_prediction() whenever it is possible to save time.")
self.skill_scores = self.calculate_skill_scores() # self.skill_scores = self.calculate_skill_scores()
self.plot() # self.plot()
self.create_boot_straps() self.create_boot_straps()
def create_boot_straps(self): def create_boot_straps(self):
# forecast
bootstrap_path = self.data_store.get("bootstrap_path", "general") bootstrap_path = self.data_store.get("bootstrap_path", "general")
forecast_path = self.data_store.get("forecast_path", "general") forecast_path = self.data_store.get("forecast_path", "general")
window_lead_time = self.data_store.get("window_lead_time", "general") window_lead_time = self.data_store.get("window_lead_time", "general")
bootstraps = BootStraps(self.test_data, bootstrap_path, 20) bootstraps = BootStraps(self.test_data, bootstrap_path, 2)
with TimeTracking(name="boot predictions"): with TimeTracking(name="boot predictions"):
bootstrap_predictions = self.model.predict_generator(generator=bootstraps.boot_strap_generator(), bootstrap_predictions = self.model.predict_generator(generator=bootstraps.boot_strap_generator(),
steps=bootstraps.get_boot_strap_generator_length()) steps=bootstraps.get_boot_strap_generator_length())
bootstrap_meta = np.array(bootstraps.get_boot_strap_meta()) bootstrap_meta = np.array(bootstraps.get_boot_strap_meta())
length = sum(bootstrap_meta == bootstrap_meta[0]) variables = np.unique(bootstrap_meta[:, 0])
variables = np.unique(bootstrap_meta) for station in np.unique(bootstrap_meta[:, 1]):
coords = None
for boot in variables: for boot in variables:
ind = (bootstrap_meta == boot) ind = np.all(bootstrap_meta == [boot, station], axis=1)
length = sum(ind)
sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1)) sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1))
tmp = xr.DataArray(sel, coords=(range(length), range(window_lead_time), [boot]), dims=["index", "window", "boot"]) coords = (range(length), range(window_lead_time))
file_name = os.path.join(forecast_path, f"bootstraps_{boot}.nc") tmp = xr.DataArray(sel, coords=(*coords, [boot]), dims=["index", "window", "type"])
file_name = os.path.join(forecast_path, f"bootstraps_{boot}_{station}.nc")
tmp.to_netcdf(file_name) tmp.to_netcdf(file_name)
labels = bootstraps.get_labels(station).reshape((length, window_lead_time, 1))
file_name = os.path.join(forecast_path, f"bootstraps_labels_{station}.nc")
labels = xr.DataArray(labels, coords=(*coords, ["obs"]), dims=["index", "window", "type"])
labels.to_netcdf(file_name)
# file_name = os.path.join(forecast_path, f"bootstraps_orig.nc")
# orig = xr.open_dataarray(file_name)
# calc skill scores
skill_scores = statistics.SkillScores(None)
score = {}
for station in np.unique(bootstrap_meta[:, 1]):
file_name = os.path.join(forecast_path, f"bootstraps_labels_{station}.nc")
labels = xr.open_dataarray(file_name)
shape = labels.shape
orig = bootstraps.get_orig_prediction(forecast_path, f"forecasts_norm_{station}_test.nc").reshape(shape)
orig = xr.DataArray(orig, coords=(range(shape[0]), range(shape[1]), ["orig"]), dims=["index", "window", "type"])
score[station] = {}
for boot in variables:
file_name = os.path.join(forecast_path, f"bootstraps_{boot}_{station}.nc")
boot_data = xr.open_dataarray(file_name)
boot_data = boot_data.combine_first(labels)
boot_data = boot_data.combine_first(orig)
score[station][boot] = skill_scores.general_skill_score(boot_data, forecast_name=boot, reference_name="orig")
# plot
def _load_model(self): def _load_model(self):
try: try:
...@@ -116,25 +149,27 @@ class PostProcessing(RunEnvironment): ...@@ -116,25 +149,27 @@ class PostProcessing(RunEnvironment):
logging.debug("start make_prediction") logging.debug("start make_prediction")
for i, _ in enumerate(self.test_data): for i, _ in enumerate(self.test_data):
data = self.test_data.get_data_generator(i) data = self.test_data.get_data_generator(i)
nn_prediction, persistence_prediction, ols_prediction = self._create_empty_prediction_arrays(data, count=3)
input_data = data.get_transposed_history() input_data = data.get_transposed_history()
# get scaling parameters # get scaling parameters
mean, std, transformation_method = data.get_transformation_information(variable=self.target_var) mean, std, transformation_method = data.get_transformation_information(variable=self.target_var)
for normalised in [True, False]:
# create empty arrays
nn_prediction, persistence_prediction, ols_prediction = self._create_empty_prediction_arrays(data, count=3)
# nn forecast # nn forecast
nn_prediction = self._create_nn_forecast(input_data, nn_prediction, mean, std, transformation_method) nn_prediction = self._create_nn_forecast(input_data, nn_prediction, mean, std, transformation_method, normalised)
# persistence # persistence
persistence_prediction = self._create_persistence_forecast(input_data, persistence_prediction, mean, std, persistence_prediction = self._create_persistence_forecast(input_data, persistence_prediction, mean, std,
transformation_method) transformation_method, normalised)
# ols # ols
ols_prediction = self._create_ols_forecast(input_data, ols_prediction, mean, std, transformation_method) ols_prediction = self._create_ols_forecast(input_data, ols_prediction, mean, std, transformation_method, normalised)
# observation # observation
observation = self._create_observation(data, None, mean, std, transformation_method) observation = self._create_observation(data, None, mean, std, transformation_method, normalised)
# merge all predictions # merge all predictions
full_index = self.create_fullindex(data.data.indexes['datetime'], self._get_frequency()) full_index = self.create_fullindex(data.data.indexes['datetime'], self._get_frequency())
...@@ -146,7 +181,8 @@ class PostProcessing(RunEnvironment): ...@@ -146,7 +181,8 @@ class PostProcessing(RunEnvironment):
# save all forecasts locally # save all forecasts locally
path = self.data_store.get("forecast_path", "general") path = self.data_store.get("forecast_path", "general")
file = os.path.join(path, f"forecasts_{data.station[0]}_test.nc") prefix = "forecasts_norm" if normalised else "forecasts"
file = os.path.join(path, f"{prefix}_{data.station[0]}_test.nc")
all_predictions.to_netcdf(file) all_predictions.to_netcdf(file)
def _get_frequency(self): def _get_frequency(self):
...@@ -154,26 +190,31 @@ class PostProcessing(RunEnvironment): ...@@ -154,26 +190,31 @@ class PostProcessing(RunEnvironment):
return getter.get(self._sampling, None) return getter.get(self._sampling, None)
@staticmethod @staticmethod
def _create_observation(data, _, mean, std, transformation_method): def _create_observation(data, _, mean, std, transformation_method, normalised):
return statistics.apply_inverse_transformation(data.label.copy(), mean, std, transformation_method) obs = data.label.copy()
if not normalised:
obs = statistics.apply_inverse_transformation(obs, mean, std, transformation_method)
return obs
def _create_ols_forecast(self, input_data, ols_prediction, mean, std, transformation_method): def _create_ols_forecast(self, input_data, ols_prediction, mean, std, transformation_method, normalised):
tmp_ols = self.ols_model.predict(input_data) tmp_ols = self.ols_model.predict(input_data)
if not normalised:
tmp_ols = statistics.apply_inverse_transformation(tmp_ols, mean, std, transformation_method) tmp_ols = statistics.apply_inverse_transformation(tmp_ols, mean, std, transformation_method)
tmp_ols = np.expand_dims(tmp_ols, axis=1) tmp_ols = np.expand_dims(tmp_ols, axis=1)
target_shape = ols_prediction.values.shape target_shape = ols_prediction.values.shape
ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols
return ols_prediction return ols_prediction
def _create_persistence_forecast(self, input_data, persistence_prediction, mean, std, transformation_method): def _create_persistence_forecast(self, input_data, persistence_prediction, mean, std, transformation_method, normalised):
tmp_persi = input_data.sel({'window': 0, 'variables': self.target_var}) tmp_persi = input_data.sel({'window': 0, 'variables': self.target_var})
if not normalised:
tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method) tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method)
window_lead_time = self.data_store.get("window_lead_time", "general") window_lead_time = self.data_store.get("window_lead_time", "general")
persistence_prediction.values = np.expand_dims(np.tile(tmp_persi.squeeze('Stations'), (window_lead_time, 1)), persistence_prediction.values = np.expand_dims(np.tile(tmp_persi.squeeze('Stations'), (window_lead_time, 1)),
axis=1) axis=1)
return persistence_prediction return persistence_prediction
def _create_nn_forecast(self, input_data, nn_prediction, mean, std, transformation_method): def _create_nn_forecast(self, input_data, nn_prediction, mean, std, transformation_method, normalised):
""" """
create the nn forecast for given input data. Inverse transformation is applied to the forecast to get the output create the nn forecast for given input data. Inverse transformation is applied to the forecast to get the output
in the original space. Furthermore, only the output of the main branch is returned (not all minor branches, if in the original space. Furthermore, only the output of the main branch is returned (not all minor branches, if
...@@ -186,6 +227,7 @@ class PostProcessing(RunEnvironment): ...@@ -186,6 +227,7 @@ class PostProcessing(RunEnvironment):
:return: :return:
""" """
tmp_nn = self.model.predict(input_data) tmp_nn = self.model.predict(input_data)
if not normalised:
tmp_nn = statistics.apply_inverse_transformation(tmp_nn, mean, std, transformation_method) tmp_nn = statistics.apply_inverse_transformation(tmp_nn, mean, std, transformation_method)
if tmp_nn.ndim == 3: if tmp_nn.ndim == 3:
nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn[-1, ...], axis=1), 2, 0) nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn[-1, ...], axis=1), 2, 0)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment