From 5114b0cc4d02595e49ca4c2253ab44b1b43b78f0 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Thu, 12 May 2022 17:38:11 +0200
Subject: [PATCH] Properly restore best epoch and best metric value from the
 training history when resuming training

---
 mlair/model_modules/keras_extensions.py | 11 +++++++++--
 mlair/run_modules/training.py           |  2 +-
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py
index d36e808b..72f40e45 100644
--- a/mlair/model_modules/keras_extensions.py
+++ b/mlair/model_modules/keras_extensions.py
@@ -163,7 +163,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
     def __init__(self, *args, **kwargs):
         """Initialise ModelCheckpointAdvanced and set callbacks attribute."""
         self.callbacks = kwargs.pop("callbacks")
-        self.epoch_best = kwargs.pop("epoch_best", 0)
+        self.epoch_best = None
         super().__init__(*args, **kwargs)
 
     def update_best(self, hist):
@@ -177,7 +177,14 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
 
         :param hist: The History object from the previous (interrupted) training.
         """
-        self.best = hist.history.get(self.monitor)[-1]
+        f = np.min if self.monitor_op.__name__ == "less" else np.max
+        f_loc = lambda x: np.where(x == f(x))[0][-1]
+        _d = hist.history.get(self.monitor)
+        loc = f_loc(_d)
+        assert f(_d) == _d[loc]
+        self.epoch_best = loc
+        self.best = _d[loc]
+        logging.info(f"Set best epoch {self.epoch_best + 1} with {self.monitor}={self.best}")
 
     def update_callbacks(self, callbacks):
         """
diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py
index 5ddf91eb..53c02332 100644
--- a/mlair/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -221,7 +221,7 @@ class Training(RunEnvironment):
             with open(os.path.join(path, "epo_timing.json"), "w") as f:
                 json.dump(epo_timing.epo_timing, f)
 
-    def create_monitoring_plots(self, history: Callback, lr_sc: Callback, epoch_best: int) -> None:
+    def create_monitoring_plots(self, history: Callback, lr_sc: Callback, epoch_best: int = None) -> None:
         """
         Create plot of history and learning rate in dependence of the number of epochs.
 
-- 
GitLab