From df8e3cba29a1028dbf31b2697976b04dde0fd642 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Fri, 13 Mar 2020 14:45:47 +0100
Subject: [PATCH] added missing lr availability checks

---
 src/run_modules/training.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index df60c4f2..55b5c296 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -111,7 +111,10 @@ class Training(RunEnvironment):
                                          callbacks=self.callbacks.get_callbacks(as_dict=False),
                                          initial_epoch=initial_epoch)
             history = hist
-        lr = self.callbacks.get_callback_by_name("lr")
+        try:
+            lr = self.callbacks.get_callback_by_name("lr")
+        except IndexError:
+            lr = None
         self.save_callbacks_as_json(history, lr)
         self.load_best_model(checkpoint.filepath)
         self.create_monitoring_plots(history, lr)
@@ -148,8 +151,9 @@ class Training(RunEnvironment):
         path = self.data_store.get("experiment_path", "general")
         with open(os.path.join(path, "history.json"), "w") as f:
             json.dump(history.history, f)
-        with open(os.path.join(path, "history_lr.json"), "w") as f:
-            json.dump(lr_sc.lr, f)
+        if lr_sc:
+            with open(os.path.join(path, "history_lr.json"), "w") as f:
+                json.dump(lr_sc.lr, f)
 
     def create_monitoring_plots(self, history: keras.callbacks.History, lr_sc: LearningRateDecay) -> None:
         """
@@ -174,4 +178,5 @@ class Training(RunEnvironment):
             PlotModelHistory(filename=filename, history=history, plot_metric="mse", main_branch=multiple_branches_used)
 
         # plot learning rate
-        PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)
+        if lr_sc:
+            PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)
-- 
GitLab