diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py index 983868c8244c40063c84c3e4df2e9e79a960dd01..998ed8c6990d6a16388874c52faf4931ef4ba174 100644 --- a/src/data_handling/bootstraps.py +++ b/src/data_handling/bootstraps.py @@ -31,16 +31,7 @@ class BootStrapGenerator: """ return len(self.orig_generator)*self.boots*len(self.variables) - # def __iter__(self): - # """ - # Define the __iter__ part of the iterator protocol to iterate through this generator. Sets the private attribute - # `_iterator` to 0. - # :return: - # """ - # self._iterator = 0 - # return self - - def __iter__(self): + def get_generator(self): """ This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return the history and label data of this generator. @@ -69,7 +60,6 @@ class BootStrapGenerator: files = os.listdir(self.bootstrap_path) regex = re.compile(rf"{station}_\w*\.nc") file_name = os.path.join(self.bootstrap_path, list(filter(regex.search, files))[0]) - # shuffled_data = xr.open_dataarray(file_name, chunks=self.chunksize) shuffled_data = xr.open_dataarray(file_name, chunks=100) return shuffled_data @@ -79,10 +69,8 @@ class BootStraps(RunEnvironment): def __init__(self, data, bootstrap_path, number_bootstraps=10): super().__init__() - # self.data: DataGenerator = self.data_store.get("generator", "general.test") self.data: DataGenerator = data self.number_bootstraps = number_bootstraps - # self.bootstrap_path = self.data_store.get("bootstrap_path", "general") self.bootstrap_path = bootstrap_path self.chunks = self.get_chunk_size() self.create_shuffled_data() @@ -92,7 +80,10 @@ class BootStraps(RunEnvironment): return self._boot_strap_generator.bootstrap_meta def boot_strap_generator(self): - return self._boot_strap_generator + return self._boot_strap_generator.get_generator() + + def get_boot_strap_generator_length(self): + return self._boot_strap_generator.__len__() def get_chunk_size(self): hist, _ = self.data[0] @@ -104,6 +95,7 @@ class BootStraps(RunEnvironment): randomly selected variables. If there is a suitable local file for requested window size and number of bootstraps, no additional file will be created inside this function. """ + logging.info("create shuffled bootstrap data") variables_str = '_'.join(sorted(self.data.variables)) window = self.data.window_history_size for station in self.data.stations: diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 06203c879872891f57c719040482fe052824c65e..8a0df43756acfdba0c86be7eecc8e4da37999ce3 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -13,6 +13,7 @@ import xarray as xr from src import statistics from src.data_handling.data_distributor import Distributor from src.data_handling.data_generator import DataGenerator +from src.data_handling.bootstraps import BootStraps from src.datastore import NameNotFoundInDataStore from src.helpers import TimeTracking from src.model_modules.linear_model import OrdinaryLeastSquaredModel @@ -50,6 +51,25 @@ class PostProcessing(RunEnvironment): "skip make_prediction() whenever it is possible to save time.") self.skill_scores = self.calculate_skill_scores() self.plot() + self.create_boot_straps() + + def create_boot_straps(self): + bootstrap_path = self.data_store.get("bootstrap_path", "general") + forecast_path = self.data_store.get("forecast_path", "general") + window_lead_time = self.data_store.get("window_lead_time", "general") + bootstraps = BootStraps(self.test_data, bootstrap_path, 20) + with TimeTracking(name="boot predictions"): + bootstrap_predictions = self.model.predict_generator(generator=bootstraps.boot_strap_generator(), + steps=bootstraps.get_boot_strap_generator_length()) + bootstrap_meta = np.array(bootstraps.get_boot_strap_meta()) + length = sum(bootstrap_meta == bootstrap_meta[0]) + variables = np.unique(bootstrap_meta) + for boot in variables: + ind = (bootstrap_meta == boot) + sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1)) + tmp = xr.DataArray(sel, coords=(range(length), range(window_lead_time), [boot]), dims=["index", "window", "boot"]) + file_name = os.path.join(forecast_path, f"bootstraps_{boot}.nc") + tmp.to_netcdf(file_name) def _load_model(self): try: