diff --git a/src/run_modules/training.py b/src/run_modules/training.py index df60c4f2f8dff4a9acb82920ad3c1d203813033d..55b5c2964de3155a8d34cf87a646c0d53deebbef 100644 --- a/src/run_modules/training.py +++ b/src/run_modules/training.py @@ -111,7 +111,10 @@ class Training(RunEnvironment): callbacks=self.callbacks.get_callbacks(as_dict=False), initial_epoch=initial_epoch) history = hist - lr = self.callbacks.get_callback_by_name("lr") + try: + lr = self.callbacks.get_callback_by_name("lr") + except IndexError: + lr = None self.save_callbacks_as_json(history, lr) self.load_best_model(checkpoint.filepath) self.create_monitoring_plots(history, lr) @@ -148,8 +151,9 @@ class Training(RunEnvironment): path = self.data_store.get("experiment_path", "general") with open(os.path.join(path, "history.json"), "w") as f: json.dump(history.history, f) - with open(os.path.join(path, "history_lr.json"), "w") as f: - json.dump(lr_sc.lr, f) + if lr_sc: + with open(os.path.join(path, "history_lr.json"), "w") as f: + json.dump(lr_sc.lr, f) def create_monitoring_plots(self, history: keras.callbacks.History, lr_sc: LearningRateDecay) -> None: """ @@ -174,4 +178,5 @@ class Training(RunEnvironment): PlotModelHistory(filename=filename, history=history, plot_metric="mse", main_branch=multiple_branches_used) # plot learning rate - PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc) + if lr_sc: + PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)