diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py index d36e808b1024e597e04d25c38853d79425cd89e7..72f40e453c37bcfe16566336fbfd56eb2734ae9c 100644 --- a/mlair/model_modules/keras_extensions.py +++ b/mlair/model_modules/keras_extensions.py @@ -163,7 +163,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint): def __init__(self, *args, **kwargs): """Initialise ModelCheckpointAdvanced and set callbacks attribute.""" self.callbacks = kwargs.pop("callbacks") - self.epoch_best = kwargs.pop("epoch_best", 0) + self.epoch_best = None super().__init__(*args, **kwargs) def update_best(self, hist): @@ -177,7 +177,14 @@ class ModelCheckpointAdvanced(ModelCheckpoint): :param hist: The History object from the previous (interrupted) training. """ - self.best = hist.history.get(self.monitor)[-1] + f = np.min if self.monitor_op.__name__ == "less" else np.max + f_loc = lambda x: np.where(x == f(x))[0][-1] + _d = hist.history.get(self.monitor) + loc = f_loc(_d) + assert f(_d) == _d[loc] + self.epoch_best = loc + self.best = _d[loc] + logging.info(f"Set best epoch {self.epoch_best + 1} with {self.monitor}={self.best}") def update_callbacks(self, callbacks): """ diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py index 5ddf91ebf6659d08e1163aceee6000a8082f0bef..53c023325d6ab94ff40c84b5e5ed1045f450f372 100644 --- a/mlair/run_modules/training.py +++ b/mlair/run_modules/training.py @@ -221,7 +221,7 @@ class Training(RunEnvironment): with open(os.path.join(path, "epo_timing.json"), "w") as f: json.dump(epo_timing.epo_timing, f) - def create_monitoring_plots(self, history: Callback, lr_sc: Callback, epoch_best: int) -> None: + def create_monitoring_plots(self, history: Callback, lr_sc: Callback, epoch_best: int = None) -> None: """ Create plot of history and learning rate in dependence of the number of epochs.