diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py index 8690785659ab256fc78b4cfe8701461f67236a9b..983868c8244c40063c84c3e4df2e9e79a960dd01 100644 --- a/src/data_handling/bootstraps.py +++ b/src/data_handling/bootstraps.py @@ -23,6 +23,7 @@ class BootStrapGenerator: self.chunksize = chunksize self.bootstrap_path = bootstrap_path self._iterator = 0 + self.bootstrap_meta = [] def __len__(self): """ @@ -50,6 +51,7 @@ class BootStrapGenerator: station = self.orig_generator.get_station_key(i) logging.info(f"station: {station}") hist, label = data + len_of_label = len(label) shuffled_data = self.load_boot_data(station) for var in self.variables: logging.info(f" var: {var}") @@ -59,6 +61,7 @@ class BootStrapGenerator: shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables") boot_hist = boot_hist.combine_first(shuffled_var) boot_hist = boot_hist.sortby("variables") + self.bootstrap_meta.extend([var]*len_of_label) yield boot_hist, label return @@ -73,20 +76,26 @@ class BootStrapGenerator: class BootStraps(RunEnvironment): - def __init__(self): + def __init__(self, data, bootstrap_path, number_bootstraps=10): super().__init__() - self.test_data: DataGenerator = self.data_store.get("generator", "general.test") - self.number_bootstraps = 10 - self.bootstrap_path = self.data_store.get("bootstrap_path", "general") + # self.data: DataGenerator = self.data_store.get("generator", "general.test") + self.data: DataGenerator = data + self.number_bootstraps = number_bootstraps + # self.bootstrap_path = self.data_store.get("bootstrap_path", "general") + self.bootstrap_path = bootstrap_path self.chunks = self.get_chunk_size() self.create_shuffled_data() - bsg =BootStrapGenerator(self.test_data, self.number_bootstraps, self.chunks, self.bootstrap_path) - for bs in bsg: - hist, label = bs + self._boot_strap_generator = BootStrapGenerator(self.data, self.number_bootstraps, self.chunks, self.bootstrap_path) + + def get_boot_strap_meta(self): + return self._boot_strap_generator.bootstrap_meta + + def boot_strap_generator(self): + return self._boot_strap_generator def get_chunk_size(self): - hist, _ = self.test_data[0] + hist, _ = self.data[0] return (100, *hist.shape[1:], self.number_bootstraps) def create_shuffled_data(self): @@ -95,13 +104,13 @@ class BootStraps(RunEnvironment): randomly selected variables. If there is a suitable local file for requested window size and number of bootstraps, no additional file will be created inside this function. """ - variables_str = '_'.join(sorted(self.test_data.variables)) - window = self.test_data.window_history_size - for station in self.test_data.stations: + variables_str = '_'.join(sorted(self.data.variables)) + window = self.data.window_history_size + for station in self.data.stations: valid, nboot = self.valid_bootstrap_file(station, variables_str, window) if not valid: logging.info(f'create bootstap data for {station}') - hist, _ = self.test_data[station] + hist, _ = self.data[station] data = hist.copy() file_name = f"{station}_{variables_str}_hist{window}_nboots{nboot}_shuffled.nc" file_path = os.path.join(self.bootstrap_path, file_name) @@ -157,9 +166,16 @@ if __name__ == "__main__": formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]' logging.basicConfig(format=formatter, level=logging.INFO) - with RunEnvironment(): + with RunEnvironment() as run_env: ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013'], station_type='background', trainable=True, window_history_size=9) PreProcessing() - BootStraps() + data = run_env.data_store.get("generator", "general.test") + path = run_env.data_store.get("bootstrap_path", "general") + number_bootstraps = 10 + + boots = BootStraps(data, path, number_bootstraps) + for b in boots.boot_strap_generator(): + a, c = b + logging.info(f"len is {len(boots.get_boot_strap_meta())}")