Skip to content
Snippets Groups Projects
Commit a043bfdc authored by lukas leufen's avatar lukas leufen
Browse files

include addition from #54

parents ac1d06aa ae6b6f49
No related branches found
No related tags found
2 merge requests!50release for v0.7.0,!49Lukas issue054 feat transformation on entire dataset
Pipeline #31014 passed
......@@ -35,16 +35,12 @@ class PreProcessing(RunEnvironment):
def _run(self):
args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="general.preprocessing")
kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="general.preprocessing")
valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"), load_tmp=False, save_tmp=False)
self.calculate_transformation(args, kwargs, valid_stations, load_tmp=False)
stations = self.data_store.get("stations", "general")
valid_stations = self.check_valid_stations(args, kwargs, stations, load_tmp=False, save_tmp=False)
self.data_store.set("stations", valid_stations, "general")
self.split_train_val_test()
self.report_pre_processing()
def calculate_transformation(self, args: Dict, kwargs: Dict, all_stations: List[str], load_tmp):
pass
def report_pre_processing(self):
logging.debug(20 * '##')
n_train = len(self.data_store.get('generator', 'general.train'))
......@@ -58,11 +54,19 @@ class PreProcessing(RunEnvironment):
logging.debug(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'general.test')[0][0].shape}"
f"{self.data_store.get('generator', 'general.test')[0][1].shape}")
def split_train_val_test(self):
def split_train_val_test(self) -> None:
"""
Splits all subsets. Currently: train, val, test and train_val (actually this is only the merge of train and val,
but as an separate generator). IMPORTANT: Do not change to order of the execution of create_set_split. The train
subset needs always to be executed at first, to set a proper transformation.
"""
fraction_of_training = self.data_store.get("fraction_of_training", "general")
stations = self.data_store.get("stations", "general")
train_index, val_index, test_index, train_val_index = self.split_set_indices(len(stations), fraction_of_training)
subset_names = ["train", "val", "test", "train_val"]
if subset_names[0] != "train": # pragma: no cover
raise AssertionError(f"Make sure, that the train subset is always at first execution position! Given subset"
f"order was: {subset_names}.")
for (ind, scope) in zip([train_index, val_index, test_index, train_val_index], subset_names):
self.create_set_split(ind, scope)
......@@ -84,7 +88,16 @@ class PreProcessing(RunEnvironment):
train_val_index = slice(0, pos_test_split)
return train_index, val_index, test_index, train_val_index
def create_set_split(self, index_list, set_name):
def create_set_split(self, index_list: slice, set_name) -> None:
"""
Create the subset for given split index and stores the DataGenerator with given set name in data store as
`generator`. Checks for all valid stations using the default (kw)args for given scope and creates the
DataGenerator for all valid stations. Also sets all transformation information, if subset is training set. Make
sure, that the train set is executed first, and all other subsets afterwards.
:param index_list: list of all stations to use for the set. If attribute use_all_stations_on_all_data_sets=True,
this list is ignored.
:param set_name: name to load/save all information from/to data store without the leading general prefix.
"""
scope = f"general.{set_name}"
args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope)
kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment