Commit 5114b0cc authored by leufen1

can now properly load best epoch and best metric value when resuming training

parent 4574ee9f
4 merge requests: !432 IOA works now also with xarray and on identical data, IOA is included in..., !431 Resolve "release v2.1.0", !430 update recent developments, !419 Resolve "loss plot with best result marker"
Pipeline #100269 failed
@@ -163,7 +163,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
     def __init__(self, *args, **kwargs):
         """Initialise ModelCheckpointAdvanced and set callbacks attribute."""
         self.callbacks = kwargs.pop("callbacks")
-        self.epoch_best = kwargs.pop("epoch_best", 0)
+        self.epoch_best = None
         super().__init__(*args, **kwargs)

     def update_best(self, hist):
@@ -177,7 +177,14 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
         :param hist: The History object from the previous (interrupted) training.
         """
-        self.best = hist.history.get(self.monitor)[-1]
+        f = np.min if self.monitor_op.__name__ == "less" else np.max
+        f_loc = lambda x: np.where(x == f(x))[0][-1]
+        _d = hist.history.get(self.monitor)
+        loc = f_loc(_d)
+        assert f(_d) == _d[loc]
+        self.epoch_best = loc
+        self.best = _d[loc]
+        logging.info(f"Set best epoch {self.epoch_best + 1} with {self.monitor}={self.best}")

     def update_callbacks(self, callbacks):
         """
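For context, a minimal, self-contained sketch of what the new update_best logic computes when training is resumed: it scans the monitored metric over all completed epochs and keeps the last epoch at which the optimum occurs, instead of simply taking the final value. The monitor name and the sample history values below are illustrative assumptions, not taken from the commit.

import numpy as np

monitor = "val_loss"                                 # metric watched by the checkpoint callback (assumed)
minimize = True                                      # True when monitor_op is "less", i.e. for losses
history = {"val_loss": [0.9, 0.5, 0.7, 0.5, 0.6]}    # hypothetical interrupted run

f = np.min if minimize else np.max                   # optimum over all epochs seen so far
_d = np.asarray(history[monitor])
loc = np.where(_d == f(_d))[0][-1]                   # last index where the optimum occurs

epoch_best, best = int(loc), float(_d[loc])
print(f"Set best epoch {epoch_best + 1} with {monitor}={best}")   # -> Set best epoch 4 with val_loss=0.5

With ties (epochs 2 and 4 both reach 0.5 here), the trailing [-1] selects the latest occurrence of the optimum.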
@@ -221,7 +221,7 @@ class Training(RunEnvironment):
         with open(os.path.join(path, "epo_timing.json"), "w") as f:
             json.dump(epo_timing.epo_timing, f)

-    def create_monitoring_plots(self, history: Callback, lr_sc: Callback, epoch_best: int) -> None:
+    def create_monitoring_plots(self, history: Callback, lr_sc: Callback, epoch_best: int = None) -> None:
         """
         Create plot of history and learning rate in dependence of the number of epochs.
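Making epoch_best default to None in create_monitoring_plots keeps the argument optional, so presumably a best-epoch marker is only drawn when a best epoch is actually known. As a rough, stand-alone illustration of using such a value (not MLAir's actual plotting code; the function name and output file are assumptions):

import matplotlib.pyplot as plt

def plot_history_with_best_marker(history: dict, monitor: str = "val_loss", epoch_best: int = None):
    """Plot the monitored metric per epoch and, if given, mark the best epoch."""
    values = history[monitor]
    epochs = list(range(1, len(values) + 1))
    plt.plot(epochs, values, label=monitor)
    if epoch_best is not None:                       # mirrors the new optional default
        plt.scatter([epoch_best + 1], [values[epoch_best]], zorder=3, label="best epoch")
    plt.xlabel("epoch")
    plt.ylabel(monitor)
    plt.legend()
    plt.savefig("monitoring_plot.png")               # hypothetical output file

plot_history_with_best_marker({"val_loss": [0.9, 0.5, 0.7, 0.5, 0.6]}, epoch_best=3)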