From df8e3cba29a1028dbf31b2697976b04dde0fd642 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Fri, 13 Mar 2020 14:45:47 +0100 Subject: [PATCH] added missing lr availability checks --- src/run_modules/training.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/run_modules/training.py b/src/run_modules/training.py index df60c4f2..55b5c296 100644 --- a/src/run_modules/training.py +++ b/src/run_modules/training.py @@ -111,7 +111,10 @@ class Training(RunEnvironment): callbacks=self.callbacks.get_callbacks(as_dict=False), initial_epoch=initial_epoch) history = hist - lr = self.callbacks.get_callback_by_name("lr") + try: + lr = self.callbacks.get_callback_by_name("lr") + except IndexError: + lr = None self.save_callbacks_as_json(history, lr) self.load_best_model(checkpoint.filepath) self.create_monitoring_plots(history, lr) @@ -148,8 +151,9 @@ class Training(RunEnvironment): path = self.data_store.get("experiment_path", "general") with open(os.path.join(path, "history.json"), "w") as f: json.dump(history.history, f) - with open(os.path.join(path, "history_lr.json"), "w") as f: - json.dump(lr_sc.lr, f) + if lr_sc: + with open(os.path.join(path, "history_lr.json"), "w") as f: + json.dump(lr_sc.lr, f) def create_monitoring_plots(self, history: keras.callbacks.History, lr_sc: LearningRateDecay) -> None: """ @@ -174,4 +178,5 @@ class Training(RunEnvironment): PlotModelHistory(filename=filename, history=history, plot_metric="mse", main_branch=multiple_branches_used) # plot learning rate - PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc) + if lr_sc: + PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc) -- GitLab