diff --git a/run.py b/run.py index 1579ae35f0d270dc0d2529cf1b6d36bc410e317a..572d8f9f9c25c08b60d2df035d8332c75375c975 100644 --- a/run.py +++ b/run.py @@ -4,20 +4,23 @@ __date__ = '2019-11-14' import logging import argparse -from src.modules.experiment_setup import ExperimentSetup -from src.modules import run, PreProcessing, Training, PostProcessing +from src.modules.experiment_setup import ExperimentSetup +from src.modules.run_environment import RunEnvironment +from src.modules.pre_processing import PreProcessing +from src.modules.modules import Training, PostProcessing -def main(): - with run(): - exp_setup = ExperimentSetup(args, trainable=True, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']) +def main(parser_args): - PreProcessing(exp_setup) + with RunEnvironment(): + ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'], + station_type='background') + PreProcessing() - Training(exp_setup) + Training() - PostProcessing(exp_setup) + PostProcessing() if __name__ == "__main__": @@ -30,6 +33,4 @@ if __name__ == "__main__": help="set experiment date as string") args = parser.parse_args() - experiment = ExperimentSetup(args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']) - a = 1 - # main() + main(args) diff --git a/src/modules/modules.py b/src/modules/modules.py index 033fd0779d8d140e684103b27fc7c025dedcdb81..888c7e06f0ef34b17f6c3f2fc2da6fe0316282f4 100644 --- a/src/modules/modules.py +++ b/src/modules/modules.py @@ -8,16 +8,14 @@ from src.modules.pre_processing import PreProcessing class Training(RunEnvironment): - def __init__(self, setup): + def __init__(self): super().__init__() - self.setup = setup class PostProcessing(RunEnvironment): - def __init__(self, setup): + def __init__(self): super().__init__() - self.setup = setup if __name__ == "__main__": diff --git a/src/modules/pre_processing.py b/src/modules/pre_processing.py index d3056f52bd0a60e0c9e7ed97fa593f3b596898a4..2b830f97bcb67251864003357a87762bcdeca07b 100644 --- a/src/modules/pre_processing.py +++ b/src/modules/pre_processing.py @@ -1,3 +1,7 @@ +__author__ = "Lukas Leufen" +__date__ = '2019-11-25' + + import logging from typing import Any, Tuple, Dict, List @@ -29,22 +33,26 @@ class PreProcessing(RunEnvironment): # self._run() - def _create_args_dict(self, arg_list, scope="general"): - args = {} - for arg in arg_list: - try: - args[arg] = self.data_store.get(arg, scope) - except (NameNotFoundInDataStore, NameNotFoundInScope): - pass - return args - def _run(self): - args = self._create_args_dict(DEFAULT_ARGS_LIST) - kwargs = self._create_args_dict(DEFAULT_KWARGS_LIST) + args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST) + kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST) valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general")) self.data_store.put("stations", valid_stations, "general") self.split_train_val_test() + def report_pre_processing(self): + logging.info(20 * '##') + n_train = len(self.data_store.get('generator', 'general.train')) + n_val = len(self.data_store.get('generator', 'general.val')) + n_test = len(self.data_store.get('generator', 'general.test')) + n_total = n_train + n_val + n_test + logging.info(f"Number of all stations: {n_total}") + logging.info(f"Number of training stations: {n_train}") + logging.info(f"Number of val stations: {n_val}") + logging.info(f"Number of test stations: {n_test}") + logging.info(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'general.test')[0][0].shape}" + f"{self.data_store.get('generator', 'general.test')[0][1].shape}") + def split_train_val_test(self): fraction_of_training = self.data_store.get("fraction_of_training", "general") stations = self.data_store.get("stations", "general") @@ -71,8 +79,8 @@ class PreProcessing(RunEnvironment): def create_set_split(self, index_list, set_name): scope = f"general.{set_name}" - args = self._create_args_dict(DEFAULT_ARGS_LIST, scope) - kwargs = self._create_args_dict(DEFAULT_KWARGS_LIST, scope) + args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope) + kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope) stations = args["stations"] if self.data_store.get("use_all_stations_on_all_data_sets", scope): set_stations = stations @@ -81,7 +89,7 @@ class PreProcessing(RunEnvironment): logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}") set_stations = self.check_valid_stations(args, kwargs, set_stations) self.data_store.put("stations", set_stations, scope) - set_args = self._create_args_dict(DEFAULT_ARGS_LIST, scope) + set_args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope) data_set = DataGenerator(**set_args, **kwargs) self.data_store.put("generator", data_set, scope)