# Method of ``BootStrapGenerator`` (src/data_handling/bootstraps.py).
def get_generator_refactored(self):
    """
    Yield every bootstrap sample of the wrapped data generator in a single pass.

    For each element of ``self.orig_generator`` (one per station key) and each variable in ``self.variables``,
    ``self.boots`` samples are yielded. In every sample the current variable of the history is taken from the
    shuffled data returned by ``self.load_boot_data(station)``, all remaining variables are taken from the
    original history. As a side effect ``self.bootstrap_meta`` is extended by ``len(label)`` entries
    ``[var, station]`` for each yielded sample.

    The generator is exhausted after one pass over ``self.orig_generator``; the former ``while True: ... return``
    wrapper executed its body exactly once and has been removed.
    NOTE(review): indentation was lost in the reviewed diff; this assumes ``return`` sat inside the ``while`` loop
    (otherwise the ``for`` loop consuming this generator would never terminate) - please confirm.

    :return: generator yielding tuples ``(boot_hist, label, var, station)``
    """
    for i, data in enumerate(self.orig_generator):
        station = self.orig_generator.get_station_key(i)
        logging.info(f"station: {station}")
        hist, label = data
        len_of_label = len(label)
        shuffled_data = self.load_boot_data(station)
        for var in self.variables:
            logging.info(f" var: {var}")
            # The untouched part of the history only depends on the variable, not on the boot number, so it is
            # selected once per variable. Assumes helpers.list_pop does not modify self.variables (it is called
            # while iterating over that very list) - TODO confirm.
            untouched_hist = hist.sel(variables=helpers.list_pop(self.variables, var))
            for boot in range(self.boots):
                logging.debug(f"boot: {boot}")
                shuffled_var = (
                    shuffled_data.sel(variables=var, boots=boot)
                    .expand_dims("variables")
                    .drop("boots")
                    .transpose("datetime", "window", "Stations", "variables")
                )
                # fill the gap left by the removed variable with its shuffled counterpart and restore the order
                boot_hist = untouched_hist.combine_first(shuffled_var).sortby("variables")
                self.bootstrap_meta.extend([[var, station]] * len_of_label)
                yield boot_hist, label, var, station


# Method of ``BootStraps`` (src/data_handling/bootstraps.py).
def boot_strap_generator_refactored(self):
    """
    Return the generator created by ``self._boot_strap_generator.get_generator_refactored()``.

    :return: generator object as returned by ``get_generator_refactored``
    """
    return self._boot_strap_generator.get_generator_refactored()
# Method of ``PostProcessing`` (src/run_modules/post_processing.py).
def create_boot_straps_refactored(self, number_of_boots=20):
    """
    Predict on each bootstrap sample individually and store predictions and labels as netCDF files.

    Every sample delivered by ``BootStraps.boot_strap_generator_refactored()`` is passed to ``self.model.predict``.
    The prediction (last output if the model returns a list) is written to
    ``<forecast_path>/bootstraps_<variable>_<station>.nc`` with dims ``("index", "ahead", "type")``. The labels are
    written to ``<forecast_path>/bootstraps_labels_<station>.nc`` whenever the station differs from the one of
    the previous sample.

    NOTE(review): the generator yields one sample per boot, but the prediction file name only depends on variable
    and station. Each boot therefore overwrites the file of the previous boot and only the last one remains on
    disk, whereas ``create_boot_straps`` stores all rows of a variable/station pair in one file. Left unchanged on
    purpose, because the expectations of the code reading these files are not visible here.
    NOTE(review): this method returns nothing, although the caller assigns its result to
    ``self.bootstrap_skill_scores``.

    :param number_of_boots: value handed to ``BootStraps`` as third argument (formerly hard-coded to 20)
    :return: None
    """
    bootstrap_path = self.data_store.get("bootstrap_path", "general")
    forecast_path = self.data_store.get("forecast_path", "general")
    window_lead_time = self.data_store.get("window_lead_time", "general")
    bootstraps = BootStraps(self.test_data, bootstrap_path, number_of_boots)
    # the "ahead" coordinate is identical for all samples, so build it only once
    ahead = range(1, window_lead_time + 1)
    keras.backend.set_learning_phase(0)
    with TimeTracking(name="boot predictions"):
        station_previous = None
        for input_data, label, variable, station in bootstraps.boot_strap_generator_refactored():
            predictions = self.model.predict(input_data)
            if isinstance(predictions, list):
                # models with several outputs: only the last output is evaluated
                predictions = predictions[-1]

            # append a trailing axis of length one that becomes the "type" dimension
            predictions = np.expand_dims(predictions, 2)
            coords = (range(predictions.shape[0]), ahead, [variable])
            tmp = xr.DataArray(predictions, coords=coords, dims=["index", "ahead", "type"])
            file_name = os.path.join(forecast_path, f"bootstraps_{variable}_{station}.nc")
            tmp.to_netcdf(file_name)

            if station_previous != station:
                # write the labels once per run of consecutive samples that share a station
                labels = (
                    label.assign_coords(type="obs")
                    .expand_dims("type")
                    .drop(["Stations", "variables"])
                    .rename({"datetime": "index", "window": "ahead"})
                )
                file_name = os.path.join(forecast_path, f"bootstraps_labels_{station}.nc")
                labels.to_netcdf(file_name)
            station_previous = station

    # TODO: work stopped here. This implementation is slower than the old one (create_boot_straps); have a look at
    # https://towardsdatascience.com/keras-data-generators-and-how-to-use-them-b69129ed779c
np.all(bootstrap_meta == [boot, station], axis=1) + length = sum(ind) + sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1)) + coords = (range(length), range(1, window_lead_time + 1)) + tmp = xr.DataArray(sel, coords=(*coords, [boot]), dims=["index", "ahead", "type"]) + file_name = os.path.join(forecast_path, f"bootstraps_{boot}_{station}.nc") + tmp.to_netcdf(file_name) + labels = bootstraps.get_labels(station).reshape((length, window_lead_time, 1)) + file_name = os.path.join(forecast_path, f"bootstraps_labels_{station}.nc") + labels = xr.DataArray(labels, coords=(*coords, ["obs"]), dims=["index", "ahead", "type"]) + labels.to_netcdf(file_name) # file_name = os.path.join(forecast_path, f"bootstraps_orig.nc") # orig = xr.open_dataarray(file_name)