diff --git a/run.py b/run.py
index 8e4d9c46f4a39224335dd65b689f519943166d0f..556cd0b59ed023178fa7e6df1b5b03b9f6c5f157 100644
--- a/run.py
+++ b/run.py
@@ -17,7 +17,7 @@ def main(parser_args):
 
     with RunEnvironment():
         ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
-                        station_type='background', trainable=True)
+                        station_type='background', trainable=False, create_new_model=False)
         PreProcessing()
 
         ModelSetup()
diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py
index 8690785659ab256fc78b4cfe8701461f67236a9b..998ed8c6990d6a16388874c52faf4931ef4ba174 100644
--- a/src/data_handling/bootstraps.py
+++ b/src/data_handling/bootstraps.py
@@ -23,6 +23,7 @@ class BootStrapGenerator:
         self.chunksize = chunksize
         self.bootstrap_path = bootstrap_path
         self._iterator = 0
+        self.bootstrap_meta = []
 
     def __len__(self):
         """
@@ -30,16 +31,7 @@ class BootStrapGenerator:
         """
         return len(self.orig_generator)*self.boots*len(self.variables)
-    # def __iter__(self):
-    #     """
-    #     Define the __iter__ part of the iterator protocol to iterate through this generator. Sets the private attribute
-    #     `_iterator` to 0.
-    #     :return:
-    #     """
-    #     self._iterator = 0
-    #     return self
-
-    def __iter__(self):
+    def get_generator(self):
         """
         This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
         the history and label data of this generator.
         :return:
@@ -50,6 +42,7 @@ class BootStrapGenerator:
             station = self.orig_generator.get_station_key(i)
             logging.info(f"station: {station}")
             hist, label = data
+            len_of_label = len(label)
             shuffled_data = self.load_boot_data(station)
             for var in self.variables:
                 logging.info(f" var: {var}")
@@ -59,6 +52,7 @@ class BootStrapGenerator:
                     shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables")
                     boot_hist = boot_hist.combine_first(shuffled_var)
                     boot_hist = boot_hist.sortby("variables")
+                    self.bootstrap_meta.extend([var]*len_of_label)
                     yield boot_hist, label
         return
 
@@ -66,27 +60,33 @@ class BootStrapGenerator:
         files = os.listdir(self.bootstrap_path)
         regex = re.compile(rf"{station}_\w*\.nc")
         file_name = os.path.join(self.bootstrap_path, list(filter(regex.search, files))[0])
-        # shuffled_data = xr.open_dataarray(file_name, chunks=self.chunksize)
         shuffled_data = xr.open_dataarray(file_name, chunks=100)
         return shuffled_data
 
 
 class BootStraps(RunEnvironment):
 
-    def __init__(self):
+    def __init__(self, data, bootstrap_path, number_bootstraps=10):
         super().__init__()
-        self.test_data: DataGenerator = self.data_store.get("generator", "general.test")
-        self.number_bootstraps = 10
-        self.bootstrap_path = self.data_store.get("bootstrap_path", "general")
+        self.data: DataGenerator = data
+        self.number_bootstraps = number_bootstraps
+        self.bootstrap_path = bootstrap_path
         self.chunks = self.get_chunk_size()
         self.create_shuffled_data()
-        bsg =BootStrapGenerator(self.test_data, self.number_bootstraps, self.chunks, self.bootstrap_path)
-        for bs in bsg:
-            hist, label = bs
+        self._boot_strap_generator = BootStrapGenerator(self.data, self.number_bootstraps, self.chunks, self.bootstrap_path)
+
+    def get_boot_strap_meta(self):
+        return self._boot_strap_generator.bootstrap_meta
+
+    def boot_strap_generator(self):
+        return self._boot_strap_generator.get_generator()
+
+    def get_boot_strap_generator_length(self):
+        return self._boot_strap_generator.__len__()
 
     def get_chunk_size(self):
-        hist, _ = self.test_data[0]
+        hist, _ = self.data[0]
         return (100, *hist.shape[1:], self.number_bootstraps)
 
     def create_shuffled_data(self):
@@ -95,13 +95,14 @@ class BootStraps(RunEnvironment):
        randomly selected variables. If there is a suitable local file for requested window size and number of
        bootstraps, no additional file will be created inside this function.
        """
-        variables_str = '_'.join(sorted(self.test_data.variables))
-        window = self.test_data.window_history_size
-        for station in self.test_data.stations:
+        logging.info("create shuffled bootstrap data")
+        variables_str = '_'.join(sorted(self.data.variables))
+        window = self.data.window_history_size
+        for station in self.data.stations:
             valid, nboot = self.valid_bootstrap_file(station, variables_str, window)
             if not valid:
                 logging.info(f'create bootstap data for {station}')
-                hist, _ = self.test_data[station]
+                hist, _ = self.data[station]
                 data = hist.copy()
                 file_name = f"{station}_{variables_str}_hist{window}_nboots{nboot}_shuffled.nc"
                 file_path = os.path.join(self.bootstrap_path, file_name)
@@ -157,9 +158,16 @@ if __name__ == "__main__":
 
     formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]'
     logging.basicConfig(format=formatter, level=logging.INFO)
 
-    with RunEnvironment():
+    with RunEnvironment() as run_env:
         ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013'], station_type='background', trainable=True,
                         window_history_size=9)
         PreProcessing()
-        BootStraps()
+        data = run_env.data_store.get("generator", "general.test")
+        path = run_env.data_store.get("bootstrap_path", "general")
+        number_bootstraps = 10
+
+        boots = BootStraps(data, path, number_bootstraps)
+        for b in boots.boot_strap_generator():
+            a, c = b
+        logging.info(f"len is {len(boots.get_boot_strap_meta())}")
diff --git a/src/helpers.py b/src/helpers.py
index 8a50b0e723d28652e1eb7e27c53636b506774b74..ab6799057145550f4346e05f29aba7741da03989 100644
--- a/src/helpers.py
+++ b/src/helpers.py
@@ -49,9 +49,10 @@ class TimeTracking(object):
     method. Duration can always be shown by printing the time tracking object or calling get_current_duration.
     """
 
-    def __init__(self, start=True):
+    def __init__(self, start=True, name="undefined job"):
         self.start = None
         self.end = None
+        self._name = name
         if start:
             self._start()
 
@@ -93,7 +94,7 @@ class TimeTracking(object):
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.stop()
-        logging.info(f"undefined job finished after {self}")
+        logging.info(f"{self._name} finished after {self}")
 
 
 def prepare_host(create_new=True, sampling="daily"):
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index e3945a542d60b09dc9855bd28be87cdba729ed72..d2c8d93fc957ecb2990e99000cbd3588e2d83eef 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -10,8 +10,8 @@ import tensorflow as tf
 
 from src.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
 # from src.model_modules.model_class import MyBranchedModel as MyModel
-# from src.model_modules.model_class import MyLittleModel as MyModel
-from src.model_modules.model_class import MyTowerModel as MyModel
+from src.model_modules.model_class import MyLittleModel as MyModel
+# from src.model_modules.model_class import MyTowerModel as MyModel
 from src.run_modules.run_environment import RunEnvironment
 
diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index 06203c879872891f57c719040482fe052824c65e..8a0df43756acfdba0c86be7eecc8e4da37999ce3 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -13,6 +13,7 @@ import xarray as xr
 from src import statistics
 from src.data_handling.data_distributor import Distributor
 from src.data_handling.data_generator import DataGenerator
+from src.data_handling.bootstraps import BootStraps
 from src.datastore import NameNotFoundInDataStore
 from src.helpers import TimeTracking
 from src.model_modules.linear_model import OrdinaryLeastSquaredModel
@@ -50,6 +51,25 @@ class PostProcessing(RunEnvironment):
                          "skip make_prediction() whenever it is possible to save time.")
         self.skill_scores = self.calculate_skill_scores()
         self.plot()
+        self.create_boot_straps()
+
+    def create_boot_straps(self):
+        bootstrap_path = self.data_store.get("bootstrap_path", "general")
+        forecast_path = self.data_store.get("forecast_path", "general")
+        window_lead_time = self.data_store.get("window_lead_time", "general")
+        bootstraps = BootStraps(self.test_data, bootstrap_path, 20)
+        with TimeTracking(name="boot predictions"):
+            bootstrap_predictions = self.model.predict_generator(generator=bootstraps.boot_strap_generator(),
+                                                                 steps=bootstraps.get_boot_strap_generator_length())
+        bootstrap_meta = np.array(bootstraps.get_boot_strap_meta())
+        length = sum(bootstrap_meta == bootstrap_meta[0])
+        variables = np.unique(bootstrap_meta)
+        for boot in variables:
+            ind = (bootstrap_meta == boot)
+            sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1))
+            tmp = xr.DataArray(sel, coords=(range(length), range(window_lead_time), [boot]), dims=["index", "window", "boot"])
+            file_name = os.path.join(forecast_path, f"bootstraps_{boot}.nc")
+            tmp.to_netcdf(file_name)
 
     def _load_model(self):
         try:
diff --git a/test/test_helpers.py b/test/test_helpers.py
index b807d2b8612b9ee006bff43f1ae4cfcfd2dd07e1..07ec244e078f977dca761274260275aab355c183 100644
--- a/test/test_helpers.py
+++ b/test/test_helpers.py
@@ -117,6 +117,14 @@ class TestTimeTracking:
         expression = PyTestRegex(r"undefined job finished after \d+:\d+:\d+ \(hh:mm:ss\)")
         assert caplog.record_tuples[-1] == ('root', 20, expression)
 
+    def test_name_enter_exit(self, caplog):
+        caplog.set_level(logging.INFO)
+        with TimeTracking(name="my job") as t:
+            assert t.start is not None
+            assert t.end is None
+        expression = PyTestRegex(r"my job finished after \d+:\d+:\d+ \(hh:mm:ss\)")
+        assert caplog.record_tuples[-1] == ('root', 20, expression)
+
 
 class TestPrepareHost:
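
Usage note: after this change, BootStraps no longer pulls its inputs from the data store itself; the caller passes the data generator, the bootstrap path and the number of bootstraps explicitly, iterates over boot_strap_generator(), and reads get_boot_strap_meta() afterwards to know which variable each yielded sample was shuffled for. The sketch below only restates the calling pattern already used in the __main__ block and in PostProcessing.create_boot_straps(); the helper name run_bootstrap_predictions and its run_env/model arguments are illustrative and not part of the code base.

    # Hypothetical helper, assuming an active RunEnvironment whose data store already
    # holds "generator"/"general.test" and "bootstrap_path"/"general" (as prepared by
    # ExperimentSetup and PreProcessing) and a compiled keras model.
    from src.data_handling.bootstraps import BootStraps

    def run_bootstrap_predictions(run_env, model, number_bootstraps=10):
        data = run_env.data_store.get("generator", "general.test")
        path = run_env.data_store.get("bootstrap_path", "general")
        boots = BootStraps(data, path, number_bootstraps)
        # one prediction per (station, variable, bootstrap) realisation; bootstrap_meta
        # is filled while the generator is consumed, so read it only afterwards
        predictions = model.predict_generator(generator=boots.boot_strap_generator(),
                                              steps=boots.get_boot_strap_generator_length())
        return predictions, boots.get_boot_strap_meta()

Keeping the variable bookkeeping inside the generator (bootstrap_meta) lets the caller regroup the flat prediction array per variable afterwards, as create_boot_straps() does above, without a second pass over the shuffled data.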