Skip to content
Snippets Groups Projects
Commit cf2d78fa authored by lukas leufen's avatar lukas leufen
Browse files

method to create an args dictionary is shifted to data store, added report pre-processing function

parent a1525109
No related branches found
No related tags found
2 merge requests!24include recent development,!18include setup ml model
Pipeline #26749 passed
......@@ -4,20 +4,23 @@ __date__ = '2019-11-14'
import logging
import argparse
from src.modules.experiment_setup import ExperimentSetup
from src.modules import run, PreProcessing, Training, PostProcessing
from src.modules.experiment_setup import ExperimentSetup
from src.modules.run_environment import RunEnvironment
from src.modules.pre_processing import PreProcessing
from src.modules.modules import Training, PostProcessing
def main():
with run():
exp_setup = ExperimentSetup(args, trainable=True, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'])
def main(parser_args):
PreProcessing(exp_setup)
with RunEnvironment():
ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
station_type='background')
PreProcessing()
Training(exp_setup)
Training()
PostProcessing(exp_setup)
PostProcessing()
if __name__ == "__main__":
......@@ -30,6 +33,4 @@ if __name__ == "__main__":
help="set experiment date as string")
args = parser.parse_args()
experiment = ExperimentSetup(args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'])
a = 1
# main()
main(args)
......@@ -8,16 +8,14 @@ from src.modules.pre_processing import PreProcessing
class Training(RunEnvironment):
def __init__(self, setup):
def __init__(self):
super().__init__()
self.setup = setup
class PostProcessing(RunEnvironment):
def __init__(self, setup):
def __init__(self):
super().__init__()
self.setup = setup
if __name__ == "__main__":
......
__author__ = "Lukas Leufen"
__date__ = '2019-11-25'
import logging
from typing import Any, Tuple, Dict, List
......@@ -29,22 +33,26 @@ class PreProcessing(RunEnvironment):
#
self._run()
def _create_args_dict(self, arg_list, scope="general"):
args = {}
for arg in arg_list:
try:
args[arg] = self.data_store.get(arg, scope)
except (NameNotFoundInDataStore, NameNotFoundInScope):
pass
return args
def _run(self):
args = self._create_args_dict(DEFAULT_ARGS_LIST)
kwargs = self._create_args_dict(DEFAULT_KWARGS_LIST)
args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST)
kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST)
valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"))
self.data_store.put("stations", valid_stations, "general")
self.split_train_val_test()
def report_pre_processing(self):
logging.info(20 * '##')
n_train = len(self.data_store.get('generator', 'general.train'))
n_val = len(self.data_store.get('generator', 'general.val'))
n_test = len(self.data_store.get('generator', 'general.test'))
n_total = n_train + n_val + n_test
logging.info(f"Number of all stations: {n_total}")
logging.info(f"Number of training stations: {n_train}")
logging.info(f"Number of val stations: {n_val}")
logging.info(f"Number of test stations: {n_test}")
logging.info(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'general.test')[0][0].shape}"
f"{self.data_store.get('generator', 'general.test')[0][1].shape}")
def split_train_val_test(self):
fraction_of_training = self.data_store.get("fraction_of_training", "general")
stations = self.data_store.get("stations", "general")
......@@ -71,8 +79,8 @@ class PreProcessing(RunEnvironment):
def create_set_split(self, index_list, set_name):
scope = f"general.{set_name}"
args = self._create_args_dict(DEFAULT_ARGS_LIST, scope)
kwargs = self._create_args_dict(DEFAULT_KWARGS_LIST, scope)
args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope)
kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope)
stations = args["stations"]
if self.data_store.get("use_all_stations_on_all_data_sets", scope):
set_stations = stations
......@@ -81,7 +89,7 @@ class PreProcessing(RunEnvironment):
logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}")
set_stations = self.check_valid_stations(args, kwargs, set_stations)
self.data_store.put("stations", set_stations, scope)
set_args = self._create_args_dict(DEFAULT_ARGS_LIST, scope)
set_args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope)
data_set = DataGenerator(**set_args, **kwargs)
self.data_store.put("generator", data_set, scope)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment