From 5114b0cc4d02595e49ca4c2253ab44b1b43b78f0 Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Thu, 12 May 2022 17:38:11 +0200 Subject: [PATCH] can now properly load best epoch and best metric value when resuming training --- mlair/model_modules/keras_extensions.py | 11 +++++++++-- mlair/run_modules/training.py | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py index d36e808b..72f40e45 100644 --- a/mlair/model_modules/keras_extensions.py +++ b/mlair/model_modules/keras_extensions.py @@ -163,7 +163,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint): def __init__(self, *args, **kwargs): """Initialise ModelCheckpointAdvanced and set callbacks attribute.""" self.callbacks = kwargs.pop("callbacks") - self.epoch_best = kwargs.pop("epoch_best", 0) + self.epoch_best = None super().__init__(*args, **kwargs) def update_best(self, hist): @@ -177,7 +177,14 @@ class ModelCheckpointAdvanced(ModelCheckpoint): :param hist: The History object from the previous (interrupted) training. """ - self.best = hist.history.get(self.monitor)[-1] + f = np.min if self.monitor_op.__name__ == "less" else np.max + f_loc = lambda x: np.where(x == f(x))[0][-1] + _d = hist.history.get(self.monitor) + loc = f_loc(_d) + assert f(_d) == _d[loc] + self.epoch_best = loc + self.best = _d[loc] + logging.info(f"Set best epoch {self.epoch_best + 1} with {self.monitor}={self.best}") def update_callbacks(self, callbacks): """ diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py index 5ddf91eb..53c02332 100644 --- a/mlair/run_modules/training.py +++ b/mlair/run_modules/training.py @@ -221,7 +221,7 @@ class Training(RunEnvironment): with open(os.path.join(path, "epo_timing.json"), "w") as f: json.dump(epo_timing.epo_timing, f) - def create_monitoring_plots(self, history: Callback, lr_sc: Callback, epoch_best: int) -> None: + def create_monitoring_plots(self, history: Callback, lr_sc: Callback, epoch_best: int = None) -> None: """ Create plot of history and learning rate in dependence of the number of epochs. -- GitLab