diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py index 44de71171377076e887c099ec1229391daae32d8..3263f5c4562eeac321c7ce621df551fdf6373ba0 100644 --- a/src/run_modules/pre_processing.py +++ b/src/run_modules/pre_processing.py @@ -35,16 +35,12 @@ class PreProcessing(RunEnvironment): def _run(self): args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="general.preprocessing") kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="general.preprocessing") - valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"), load_tmp=False, save_tmp=False) - self.calculate_transformation(args, kwargs, valid_stations, load_tmp=False) + stations = self.data_store.get("stations", "general") + valid_stations = self.check_valid_stations(args, kwargs, stations, load_tmp=False, save_tmp=False) self.data_store.set("stations", valid_stations, "general") self.split_train_val_test() self.report_pre_processing() - def calculate_transformation(self, args: Dict, kwargs: Dict, all_stations: List[str], load_tmp): - - pass - def report_pre_processing(self): logging.debug(20 * '##') n_train = len(self.data_store.get('generator', 'general.train')) @@ -58,11 +54,19 @@ class PreProcessing(RunEnvironment): logging.debug(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'general.test')[0][0].shape}" f"{self.data_store.get('generator', 'general.test')[0][1].shape}") - def split_train_val_test(self): + def split_train_val_test(self) -> None: + """ + Splits all subsets. Currently: train, val, test and train_val (actually this is only the merge of train and val, + but as a separate generator). IMPORTANT: Do not change the order of execution of create_set_split. The train + subset must always be executed first, to set a proper transformation. 
+ """ fraction_of_training = self.data_store.get("fraction_of_training", "general") stations = self.data_store.get("stations", "general") train_index, val_index, test_index, train_val_index = self.split_set_indices(len(stations), fraction_of_training) subset_names = ["train", "val", "test", "train_val"] + if subset_names[0] != "train": # pragma: no cover + raise AssertionError(f"Make sure, that the train subset is always at first execution position! Given subset" + f"order was: {subset_names}.") for (ind, scope) in zip([train_index, val_index, test_index, train_val_index], subset_names): self.create_set_split(ind, scope) @@ -84,7 +88,16 @@ class PreProcessing(RunEnvironment): train_val_index = slice(0, pos_test_split) return train_index, val_index, test_index, train_val_index - def create_set_split(self, index_list, set_name): + def create_set_split(self, index_list: slice, set_name) -> None: + """ + Creates the subset for given split index and stores the DataGenerator with given set name in data store as + `generator`. Checks for all valid stations using the default (kw)args for given scope and creates the + DataGenerator for all valid stations. Also sets all transformation information, if subset is training set. Make + sure that the train set is executed first, and all other subsets afterwards. + :param index_list: list of all stations to use for the set. If attribute use_all_stations_on_all_data_sets=True, + this list is ignored. + :param set_name: name to load/save all information from/to data store without the leading general prefix. + """ scope = f"general.{set_name}" args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope) kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope)