diff --git a/run.py b/run.py index 6d9f1018f6eae671f36c3bac7443c902b35a3270..9809712876dc886007b042a52d7b46c027800faf 100644 --- a/run.py +++ b/run.py @@ -16,7 +16,8 @@ def main(parser_args): with RunEnvironment(): ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'], - station_type='background', trainable=False, create_new_model=True) + station_type='background', trainable=False, create_new_model=False, window_history_size=6, + create_new_bootstraps=True) PreProcessing() ModelSetup() diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py index 0d2eb7ef0a825c9b29c1e153ce95e0504ce45558..8888f6b1b5967fcb29fdd58d1408f78b4e21fbc9 100644 --- a/src/data_handling/bootstraps.py +++ b/src/data_handling/bootstraps.py @@ -5,6 +5,7 @@ __date__ = '2020-02-07' from src.data_handling.data_generator import DataGenerator import numpy as np import logging +import keras import dask.array as da import xarray as xr import os @@ -13,6 +14,142 @@ from src import helpers from typing import List, Union, Pattern +class RealBootStrapGenerator(keras.utils.Sequence): + + def __getitem__(self, index): + logging.debug(f"boot: {index}") + boot_hist = self.history.copy() + boot_hist = boot_hist.combine_first(self.__get_shuffled(index)) + return boot_hist.reindex_like(self.history_orig) + + def __get_shuffled(self, index): + shuffled_var = self.shuffled.sel(boots=index).expand_dims("variables").drop("boots") + return shuffled_var.transpose("datetime", "window", "Stations", "variables") + + def __init__(self, number_of_boots, history, shuffled, variables, shuffled_variable): + self.number_of_boots = number_of_boots + self.variables = variables + self.history_orig = history + self.history = history.sel(variables=helpers.list_pop(self.variables, shuffled_variable)) + self.shuffled = shuffled.sel(variables=shuffled_variable) + + def __len__(self): + return self.number_of_boots + + +class BootStrapGeneratorNew: + + def __init__(self, orig_generator, number_of_boots, bootstrap_path): + self.orig_generator: DataGenerator = orig_generator + self.stations = self.orig_generator.stations + self.variables = self.orig_generator.variables + self.number_of_boots = number_of_boots + self.bootstrap_path = bootstrap_path + + def __len__(self): + return len(self.orig_generator) * self.number_of_boots + + def get_generator_station_var_wise(self, station, var): + """ + This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return + the history and label data of this generator. + :return: + """ + hist, label = self.orig_generator[station] + shuffled_data = self.load_shuffled_data(station, self.variables) + gen = RealBootStrapGenerator(self.number_of_boots, hist, shuffled_data, self.variables, var) + return hist, label, gen, self.number_of_boots + + def get_bootstrap_meta_station_var_wise(self, station, var) -> List: + """ + Create meta data on ordering of variable bootstraps according to ordering from get_generator method. + :return: list with bootstrapped variable first and its corresponding station second. + """ + bootstrap_meta = [] + label = self.orig_generator.get_data_generator(station).get_transposed_label() + for boot in range(self.number_of_boots): + bootstrap_meta.extend([[var, station]] * len(label)) + return bootstrap_meta + + def get_labels(self, key: Union[str, int]): + """ + Reepats labels for given key by the number of boots and yield it one by one. + :param key: key of station (either station name as string or the position in generator as integer) + :return: yields labels for length of boots + """ + _, label = self.orig_generator[key] + for _ in range(self.number_of_boots): + yield label + + def get_orig_prediction(self, path: str, file_name: str, prediction_name: str = "CNN"): + """ + Repeats predictions from given file(_name) in path by the number of boots. + :param path: path to file + :param file_name: file name + :param prediction_name: name of the prediction to select from loaded file + :return: yields predictions for length of boots + """ + file = os.path.join(path, file_name) + data = xr.open_dataarray(file) + for _ in range(self.number_of_boots): + yield data.sel(type=prediction_name).squeeze() + + def load_shuffled_data(self, station: str, variables: List[str]) -> xr.DataArray: + """ + Load shuffled data from bootstrap path. Data is stored as + '<station>_<var1>_<var2>_..._hist<histsize>_nboots<nboots>_shuffled.nc', e.g. + 'DEBW107_cloudcover_no_no2_temp_u_v_hist13_nboots20_shuffled.nc' + :param station: + :param variables: + :return: shuffled data as xarray + """ + file_name = self.get_shuffled_data_file(station, variables) + shuffled_data = xr.open_dataarray(file_name, chunks=100) + return shuffled_data + + def get_shuffled_data_file(self, station, variables): + files = os.listdir(self.bootstrap_path) + regex = self.create_file_regex(station, variables) + file = self.filter_files(regex, files, self.orig_generator.window_history_size, self.number_of_boots) + if file: + return os.path.join(self.bootstrap_path, file) + else: + raise FileNotFoundError(f"Could not find a file to match pattern {regex}") + + @staticmethod + def create_file_regex(station: str, variables: List[str]) -> Pattern: + """ + Creates regex for given station and variables to look for shuffled data with pattern: + `<station>(_<var>)*_hist(<hist>)_nboots(<nboots>)_shuffled.nc` + :param station: station name to use as prefix + :param variables: variables to add after station + :return: compiled regular expression + """ + var_regex = "".join([rf"(_\w+)*_{v}(_\w+)*" for v in sorted(variables)]) + regex = re.compile(rf"{station}{var_regex}_hist(\d+)_nboots(\d+)_shuffled\.nc") + return regex + + @staticmethod + def filter_files(regex: Pattern, files: List[str], window: int, nboot: int) -> Union[str, None]: + """ + Filter list of files by regex. Regex has to be structured to match the following string structure + `<station>(_<var>)*_hist(<hist>)_nboots(<nboots>)_shuffled.nc`. Hist and nboots values have to be included as + group. All matches are compared to given window and nboot parameters. A valid file must have the same value (or + larger) than these parameters and contain all variables. + :param regex: compiled regular expression pattern following the style from method description + :param files: list of file names to filter + :param window: minimum length of window to look for + :param nboot: minimal number of boots to search + :return: matching file name or None, if no valid file was found + """ + for f in files: + match = regex.match(f) + if match: + last = match.lastindex + if (int(match.group(last-1)) >= window) and (int(match.group(last)) >= nboot): + return f + + class BootStrapGenerator: def __init__(self, orig_generator, number_of_boots, bootstrap_path): @@ -23,7 +160,7 @@ class BootStrapGenerator: self.bootstrap_path = bootstrap_path def __len__(self): - return len(self.orig_generator) * self.number_of_boots * len(self.variables) + return len(self.orig_generator) * self.number_of_boots def get_generator_station_var_wise(self, station, var): """ @@ -146,7 +283,7 @@ class BootStraps: self.bootstrap_path = bootstrap_path self.chunks = self.get_chunk_size() self.create_shuffled_data() - self._boot_strap_generator = BootStrapGenerator(self.data, self.number_bootstraps, self.bootstrap_path) + self._boot_strap_generator = BootStrapGeneratorNew(self.data, self.number_bootstraps, self.bootstrap_path) @property def stations(self): diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index aa7cce0e9c52e54ecdedaf53f2e0e03dded2b795..bb37b9c2be6728667a4bdf6e6cbc0c292347c991 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -111,8 +111,9 @@ class PostProcessing(RunEnvironment): hist, label, station_bootstrap, length = bootstraps.get_generator_station_var_wise(station, var) # make bootstrap predictions - bootstrap_predictions = self.model.predict_generator(generator=station_bootstrap(), + bootstrap_predictions = self.model.predict_generator(generator=station_bootstrap, steps=length, + workers=4, use_multiprocessing=True) if isinstance(bootstrap_predictions, list): # if model is branched model bootstrap_predictions = bootstrap_predictions[-1]