From 40bc20bd4fd9206809535f9b6d9a9f80394696c3 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Mon, 9 Dec 2019 17:24:02 +0100
Subject: [PATCH] can run training and save callbacks

---
 run.py                                   |  7 ++--
 src/helpers.py                           |  4 +--
 src/modules/model_setup.py               |  8 +++--
 src/modules/training.py                  | 41 +++++++++++++++++-------
 test/test_helpers.py                     |  4 +--
 test/test_modules/test_pre_processing.py |  8 ++---
 6 files changed, 47 insertions(+), 25 deletions(-)

diff --git a/run.py b/run.py
index ea8c04eb..0f88f37b 100644
--- a/run.py
+++ b/run.py
@@ -9,7 +9,8 @@ from src.modules.experiment_setup import ExperimentSetup
 from src.modules.run_environment import RunEnvironment
 from src.modules.pre_processing import PreProcessing
 from src.modules.model_setup import ModelSetup
-from src.modules.modules import Training, PostProcessing
+from src.modules.training import Training
+from src.modules.modules import PostProcessing
 
 
 def main(parser_args):
@@ -32,8 +33,8 @@ if __name__ == "__main__":
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     parser = argparse.ArgumentParser()
-    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, nargs=1, default=None,
+    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
                         help="set experiment date as string")
-    args = parser.parse_args()
+    args = parser.parse_args(["--experiment_date", "testrun"])
 
     main(args)
diff --git a/src/helpers.py b/src/helpers.py
index d73a39e1..2ef77689 100644
--- a/src/helpers.py
+++ b/src/helpers.py
@@ -178,11 +178,11 @@ def set_experiment_name(experiment_date=None, experiment_path=None):
     if experiment_date is None:
         experiment_name = "TestExperiment"
     else:
-        experiment_name = f"{experiment_date}_network/"
+        experiment_name = f"{experiment_date}_network"
     if experiment_path is None:
         experiment_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", experiment_name))
     else:
-        experiment_path = os.path.abspath(experiment_path)
+        experiment_path = os.path.join(os.path.abspath(experiment_path), experiment_name)
     return experiment_name, experiment_path
 
 
diff --git a/src/modules/model_setup.py b/src/modules/model_setup.py
index 8a29fd3d..d75ecb36 100644
--- a/src/modules/model_setup.py
+++ b/src/modules/model_setup.py
@@ -24,7 +24,9 @@ class ModelSetup(RunEnvironment):
         # create run framework
         super().__init__()
         self.model = None
-        self.model_name = self.data_store.get("experiment_name", "general") + "model-best.h5"
+        path = self.data_store.get("experiment_path", "general")
+        exp_name = self.data_store.get("experiment_name", "general")
+        self.model_name = os.path.join(path, f"{exp_name}_model-best.h5")
         self.scope = "general.model"
         self._run()
 
@@ -74,7 +76,7 @@ class ModelSetup(RunEnvironment):
     def plot_model(self):  # pragma: no cover
         with tf.device("/cpu:0"):
             path = self.data_store.get("experiment_path", "general")
-            name = self.data_store.get("experiment_name", "general") + "model.pdf"
+            name = self.data_store.get("experiment_name", "general") + "_model.pdf"
             file_name = os.path.join(path, name)
             keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
 
@@ -100,7 +102,7 @@ class ModelSetup(RunEnvironment):
         self.data_store.put("lr_decay", LearningRateDecay(base_lr=initial_lr, drop=.94, epochs_drop=10), self.scope)
 
         # learning settings
-        self.data_store.put("epochs", 2, self.scope)
+        self.data_store.put("epochs", 10, self.scope)
         self.data_store.put("batch_size", int(256), self.scope)
 
         # activation
diff --git a/src/modules/training.py b/src/modules/training.py
index 866e9405..71f76584 100644
--- a/src/modules/training.py
+++ b/src/modules/training.py
@@ -3,6 +3,8 @@ __author__ = "Lukas Leufen, Felix Kleinert"
 __date__ = '2019-12-05'
 
 import logging
+import os
+import json
 
 from src.modules.run_environment import RunEnvironment
 from src.data_handling.data_distributor import Distributor
@@ -13,13 +15,20 @@ class Training(RunEnvironment):
     def __init__(self):
         super().__init__()
         self.model = self.data_store.get("model", "general.model")
-        self.train_generator = None
-        self.val_generator = None
-        self.test_generator = None
+        self.train_set = None
+        self.val_set = None
+        self.test_set = None
         self.batch_size = self.data_store.get("batch_size", "general.model")
         self.epochs = self.data_store.get("epochs", "general.model")
         self.checkpoint = self.data_store.get("checkpoint", "general.model")
-        self.lr_sc = self.data_store.get("epochs", "general.model")
+        self.lr_sc = self.data_store.get("lr_decay", "general.model")
+        self.experiment_name = self.data_store.get("experiment_name", "general")
+        self._run()
+
+    def _run(self):
+        self.set_generators()
+        self.make_predict_function()
+        self.train()
 
     def make_predict_function(self):
         # create the predict function before distributing. This is necessary, because tf will compile the predict
@@ -30,19 +39,29 @@ class Training(RunEnvironment):
 
     def _set_gen(self, mode):
         gen = self.data_store.get("generator", f"general.{mode}")
-        setattr(self, f"{mode}_generator", Distributor(gen, self.model, self.batch_size))
+        setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size))
 
     def set_generators(self):
-        map(lambda mode: self._set_gen(mode), ["train", "val", "test"])
+        for mode in ["train", "val", "test"]:
+            self._set_gen(mode)
 
     def train(self):
-        logging.info(f"Train with {len(self.train_generator)} mini batches.")
-        history = self.model.fit_generator(generator=self.train_generator.distribute_on_batches(),
-                                           steps_per_epoch=len(self.train_generator),
+        logging.info(f"Train with {len(self.train_set)} mini batches.")
+        history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
+                                           steps_per_epoch=len(self.train_set),
                                            epochs=self.epochs,
                                            verbose=2,
-                                           validation_data=self.val_generator.distribute_on_batches(),
-                                           validation_steps=len(self.val_generator),
+                                           validation_data=self.val_set.distribute_on_batches(),
+                                           validation_steps=len(self.val_set),
                                            callbacks=[self.checkpoint, self.lr_sc])
+        self.save_callbacks(history)
+
+    def save_callbacks(self, history):
+        path = self.data_store.get("experiment_path", "general")
+        with open(os.path.join(path, "history.json"), "w") as f:
+            json.dump(history.history, f)
+        with open(os.path.join(path, "history_lr.json"), "w") as f:
+            json.dump(self.lr_sc.lr, f)
+
 
 
diff --git a/test/test_helpers.py b/test/test_helpers.py
index f5b41c5d..181c5f29 100644
--- a/test/test_helpers.py
+++ b/test/test_helpers.py
@@ -183,9 +183,9 @@ class TestSetExperimentName:
         assert exp_name == "TestExperiment"
         assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "TestExperiment"))
         exp_name, exp_path = set_experiment_name(experiment_date="2019-11-14", experiment_path="./test2")
-        assert exp_name == "2019-11-14_network/"
+        assert exp_name == "2019-11-14_network"
         assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "test2"))
 
     def test_set_experiment_from_sys(self):
         exp_name, _ = set_experiment_name(experiment_date="2019-11-14")
-        assert exp_name == "2019-11-14_network/"
+        assert exp_name == "2019-11-14_network"
diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py
index c333322a..c884b146 100644
--- a/test/test_modules/test_pre_processing.py
+++ b/test/test_modules/test_pre_processing.py
@@ -39,8 +39,8 @@ class TestPreProcessing:
         with PreProcessing():
             assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
             assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started')
-            assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\). '
-                                                                        r'Found 5/5 valid stations.'))
+            assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 5 '
+                                                                        r'station\(s\). Found 5/5 valid stations.'))
         RunEnvironment().__del__()
 
     def test_run(self, obj_with_exp_setup):
@@ -88,8 +88,8 @@ class TestPreProcessing:
         assert len(valid_stations) < len(stations)
         assert valid_stations == stations[:-1]
         assert caplog.record_tuples[0] == ('root', 20, 'check valid stations started')
-        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 6 station\(s\). Found '
-                                                                    r'5/6 valid stations.'))
+        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 6 '
+                                                                    r'station\(s\). Found 5/6 valid stations.'))
 
     def test_split_set_indices(self, obj_super_init):
         dummy_list = list(range(0, 15))
-- 
GitLab
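
Not part of the patch: a minimal sketch of how the two JSON files written by the new Training.save_callbacks method above could be read back for later inspection. The file names history.json and history_lr.json, and the fact that they are written into experiment_path, come from the training.py hunk; the concrete path value and the printed summary are illustrative assumptions only.

import json
import os

# Hypothetical experiment path: per the helpers.py hunk, set_experiment_name()
# builds "<date>_network" in the repository root when no path is given, and
# run.py hard-codes the date "testrun" above.
experiment_path = os.path.abspath("testrun_network")

# history.json holds the keras History.history dict dumped by save_callbacks.
with open(os.path.join(experiment_path, "history.json")) as f:
    history = json.load(f)

# history_lr.json holds whatever the LearningRateDecay callback recorded in its
# lr attribute (dumped as-is by save_callbacks).
with open(os.path.join(experiment_path, "history_lr.json")) as f:
    lr_history = json.load(f)

print("epochs trained:", len(history.get("loss", [])))
print("final val_loss:", history.get("val_loss", [None])[-1])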