From 0b171807595ecfa36d15910eaa1fbf5a681dad23 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Fri, 31 Jan 2020 12:31:32 +0100
Subject: [PATCH] minor refac

---
 run.py                             |  2 +-
 src/run_modules/training.py        | 30 +++++++++++++-----------------
 test/test_modules/test_training.py |  8 ++++----
 3 files changed, 18 insertions(+), 22 deletions(-)

diff --git a/run.py b/run.py
index c06bf648..9f38fdca 100644
--- a/run.py
+++ b/run.py
@@ -24,7 +24,7 @@ def main(parser_args):
 
         Training()
 
-        # PostProcessing()
+        PostProcessing()
 
 
 if __name__ == "__main__":
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index d1962605..195ae28a 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -86,31 +86,28 @@ class Training(RunEnvironment):
                                                verbose=2,
                                                validation_data=self.val_set.distribute_on_batches(),
                                                validation_steps=len(self.val_set),
-                                               # callbacks=self.callbacks)
                                                callbacks=[self.lr_sc, self.hist, self.checkpoint])
         else:
             lr_filepath = self.checkpoint.callbacks[0]["path"]
             hist_filepath = self.checkpoint.callbacks[1]["path"]
-            lr_callbacks = pickle.load(open(lr_filepath, "rb"))
-            hist_callbacks = pickle.load(open(hist_filepath, "rb"))
-            self.lr_sc = lr_callbacks
-            self.hist = hist_callbacks
+            self.lr_sc = pickle.load(open(lr_filepath, "rb"))
+            self.hist = pickle.load(open(hist_filepath, "rb"))
             self.model = keras.models.load_model(self.checkpoint.filepath)
-            initial_epoch = max(hist_callbacks.epoch) + 1
+            initial_epoch = max(self.hist.epoch) + 1
             callbacks = [{"callback": self.lr_sc, "path": lr_filepath},
                          {"callback": self.hist, "path": hist_filepath}]
             self.checkpoint.update_callbacks(callbacks)
-            self.checkpoint.update_best(hist_callbacks)
+            self.checkpoint.update_best(self.hist)
             _ = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
-                                                 steps_per_epoch=len(self.train_set),
-                                                 epochs=self.epochs,
-                                                 verbose=2,
-                                                 validation_data=self.val_set.distribute_on_batches(),
-                                                 validation_steps=len(self.val_set),
-                                                 callbacks=[self.lr_sc, self.hist, self.checkpoint],
-                                                 initial_epoch=initial_epoch)
+                                         steps_per_epoch=len(self.train_set),
+                                         epochs=self.epochs,
+                                         verbose=2,
+                                         validation_data=self.val_set.distribute_on_batches(),
+                                         validation_steps=len(self.val_set),
+                                         callbacks=[self.lr_sc, self.hist, self.checkpoint],
+                                         initial_epoch=initial_epoch)
             history = self.hist
-        self.save_callbacks(history)
+        self.save_callbacks_as_json(history)
         self.load_best_model(self.checkpoint.filepath)
         self.create_monitoring_plots(history, self.lr_sc)
 
@@ -137,7 +134,7 @@ class Training(RunEnvironment):
         except OSError:
             logging.info('no weights to reload...')
 
-    def save_callbacks(self, history: keras.callbacks.History) -> None:
+    def save_callbacks_as_json(self, history: keras.callbacks.History) -> None:
         """
         Save callbacks (history, learning rate) of training.
         * history.history -> history.json
@@ -150,7 +147,6 @@ class Training(RunEnvironment):
             json.dump(history.history, f)
         with open(os.path.join(path, "history_lr.json"), "w") as f:
             json.dump(self.lr_sc.lr, f)
-            # json.dump(self.callbacks["learning_rate"].lr, f)
 
     def create_monitoring_plots(self, history: keras.callbacks.History, lr_sc: LearningRateDecay) -> None:
         """
diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py
index 4631fe5a..580e7925 100644
--- a/test/test_modules/test_training.py
+++ b/test/test_modules/test_training.py
@@ -189,21 +189,21 @@ class TestTraining:
         assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))
 
     def test_save_callbacks_history_created(self, init_without_run, history, path):
-        init_without_run.save_callbacks(history)
+        init_without_run.save_callbacks_as_json(history)
         assert "history.json" in os.listdir(path)
 
     def test_save_callbacks_lr_created(self, init_with_lr, history, path):
-        init_with_lr.save_callbacks(history)
+        init_with_lr.save_callbacks_as_json(history)
         assert "history_lr.json" in os.listdir(path)
 
     def test_save_callbacks_inspect_history(self, init_without_run, history, path):
-        init_without_run.save_callbacks(history)
+        init_without_run.save_callbacks_as_json(history)
         with open(os.path.join(path, "history.json")) as jfile:
             hist = json.load(jfile)
             assert hist == history.history
 
     def test_save_callbacks_inspect_lr(self, init_with_lr, history, path):
-        init_with_lr.save_callbacks(history)
+        init_with_lr.save_callbacks_as_json(history)
         with open(os.path.join(path, "history_lr.json")) as jfile:
             lr = json.load(jfile)
             assert lr == init_with_lr.lr_sc.lr
-- 
GitLab