diff --git a/run.py b/run.py index ea8c04ebde02a80b899a356eb0f7794055abe2d6..0f88f37bb15811cbb2fabb08900618c1d72945da 100644 --- a/run.py +++ b/run.py @@ -9,7 +9,8 @@ from src.modules.experiment_setup import ExperimentSetup from src.modules.run_environment import RunEnvironment from src.modules.pre_processing import PreProcessing from src.modules.model_setup import ModelSetup -from src.modules.modules import Training, PostProcessing +from src.modules.training import Training +from src.modules.modules import PostProcessing def main(parser_args): @@ -32,8 +33,8 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) parser = argparse.ArgumentParser() - parser.add_argument('--experiment_date', metavar='--exp_date', type=str, nargs=1, default=None, + parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None, help="set experiment date as string") - args = parser.parse_args() + args = parser.parse_args(["--experiment_date", "testrun"]) main(args) diff --git a/src/helpers.py b/src/helpers.py index d73a39e15d43d25f7be717f11fe7f67260f07ac0..2ef776898e35a16b0bfd54b5984864c740dbf341 100644 --- a/src/helpers.py +++ b/src/helpers.py @@ -178,11 +178,11 @@ def set_experiment_name(experiment_date=None, experiment_path=None): if experiment_date is None: experiment_name = "TestExperiment" else: - experiment_name = f"{experiment_date}_network/" + experiment_name = f"{experiment_date}_network" if experiment_path is None: experiment_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", experiment_name)) else: - experiment_path = os.path.abspath(experiment_path) + experiment_path = os.path.join(os.path.abspath(experiment_path), experiment_name) return experiment_name, experiment_path diff --git a/src/modules/model_setup.py b/src/modules/model_setup.py index 8a29fd3d1a204affed70517d3ebb807c820b9160..d75ecb36477cb5d8a1739fbc9bfe663f5b359077 100644 --- a/src/modules/model_setup.py +++ b/src/modules/model_setup.py @@ -24,7 +24,9 @@ class ModelSetup(RunEnvironment): # create run framework super().__init__() self.model = None - self.model_name = self.data_store.get("experiment_name", "general") + "model-best.h5" + path = self.data_store.get("experiment_path", "general") + exp_name = self.data_store.get("experiment_name", "general") + self.model_name = os.path.join(path, f"{exp_name}_model-best.h5") self.scope = "general.model" self._run() @@ -74,7 +76,7 @@ class ModelSetup(RunEnvironment): def plot_model(self): # pragma: no cover with tf.device("/cpu:0"): path = self.data_store.get("experiment_path", "general") - name = self.data_store.get("experiment_name", "general") + "model.pdf" + name = self.data_store.get("experiment_name", "general") + "_model.pdf" file_name = os.path.join(path, name) keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True) @@ -100,7 +102,7 @@ class ModelSetup(RunEnvironment): self.data_store.put("lr_decay", LearningRateDecay(base_lr=initial_lr, drop=.94, epochs_drop=10), self.scope) # learning settings - self.data_store.put("epochs", 2, self.scope) + self.data_store.put("epochs", 10, self.scope) self.data_store.put("batch_size", int(256), self.scope) # activation diff --git a/src/modules/training.py b/src/modules/training.py index 866e9405acec35d4602a1ca6b079fdc53a05b71f..71f765843123f45d2e09098a3261ad9000e5146c 100644 --- a/src/modules/training.py +++ b/src/modules/training.py @@ -3,6 +3,8 @@ __author__ = "Lukas Leufen, Felix Kleinert" __date__ = '2019-12-05' import logging +import os +import json from src.modules.run_environment import RunEnvironment from src.data_handling.data_distributor import Distributor @@ -13,13 +15,20 @@ class Training(RunEnvironment): def __init__(self): super().__init__() self.model = self.data_store.get("model", "general.model") - self.train_generator = None - self.val_generator = None - self.test_generator = None + self.train_set = None + self.val_set = None + self.test_set = None self.batch_size = self.data_store.get("batch_size", "general.model") self.epochs = self.data_store.get("epochs", "general.model") self.checkpoint = self.data_store.get("checkpoint", "general.model") - self.lr_sc = self.data_store.get("epochs", "general.model") + self.lr_sc = self.data_store.get("lr_decay", "general.model") + self.experiment_name = self.data_store.get("experiment_name", "general") + self._run() + + def _run(self): + self.set_generators() + self.make_predict_function() + self.train() def make_predict_function(self): # create the predict function before distributing. This is necessary, because tf will compile the predict @@ -30,19 +39,29 @@ class Training(RunEnvironment): def _set_gen(self, mode): gen = self.data_store.get("generator", f"general.{mode}") - setattr(self, f"{mode}_generator", Distributor(gen, self.model, self.batch_size)) + setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size)) def set_generators(self): - map(lambda mode: self._set_gen(mode), ["train", "val", "test"]) + for mode in ["train", "val", "test"]: + self._set_gen(mode) def train(self): - logging.info(f"Train with {len(self.train_generator)} mini batches.") - history = self.model.fit_generator(generator=self.train_generator.distribute_on_batches(), - steps_per_epoch=len(self.train_generator), + logging.info(f"Train with {len(self.train_set)} mini batches.") + history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(), + steps_per_epoch=len(self.train_set), epochs=self.epochs, verbose=2, - validation_data=self.val_generator.distribute_on_batches(), - validation_steps=len(self.val_generator), + validation_data=self.val_set.distribute_on_batches(), + validation_steps=len(self.val_set), callbacks=[self.checkpoint, self.lr_sc]) + self.save_callbacks(history) + + def save_callbacks(self, history): + path = self.data_store.get("experiment_path", "general") + with open(os.path.join(path, "history.json"), "w") as f: + json.dump(history.history, f) + with open(os.path.join(path, "history_lr.json"), "w") as f: + json.dump(self.lr_sc.lr, f) + diff --git a/test/test_helpers.py b/test/test_helpers.py index f5b41c5da8fa7b5297910e54bf36e2eb09be080c..181c5f2976508da7828a5e67015b26e12903bb7f 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -183,9 +183,9 @@ class TestSetExperimentName: assert exp_name == "TestExperiment" assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "TestExperiment")) exp_name, exp_path = set_experiment_name(experiment_date="2019-11-14", experiment_path="./test2") - assert exp_name == "2019-11-14_network/" + assert exp_name == "2019-11-14_network" assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "test2")) def test_set_experiment_from_sys(self): exp_name, _ = set_experiment_name(experiment_date="2019-11-14") - assert exp_name == "2019-11-14_network/" + assert exp_name == "2019-11-14_network" diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py index c333322a911732470fc25f413c10f2db14514515..c884b14657447a50377dc38ec2dea10ba300f4d7 100644 --- a/test/test_modules/test_pre_processing.py +++ b/test/test_modules/test_pre_processing.py @@ -39,8 +39,8 @@ class TestPreProcessing: with PreProcessing(): assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started') assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started') - assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\). ' - r'Found 5/5 valid stations.')) + assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 5 ' + r'station\(s\). Found 5/5 valid stations.')) RunEnvironment().__del__() def test_run(self, obj_with_exp_setup): @@ -88,8 +88,8 @@ class TestPreProcessing: assert len(valid_stations) < len(stations) assert valid_stations == stations[:-1] assert caplog.record_tuples[0] == ('root', 20, 'check valid stations started') - assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 6 station\(s\). Found ' - r'5/6 valid stations.')) + assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 6 ' + r'station\(s\). Found 5/6 valid stations.')) def test_split_set_indices(self, obj_super_init): dummy_list = list(range(0, 15))