Commit 601aaeef authored by leufen1

training monitoring plot can handle best epoch parameter, training history stores best epoch, not tested for resuming train process
parent 8f51e9d0
4 merge requests:
!432 IOA works now also with xarray and on identical data, IOA is included in...
!431 Resolve "release v2.1.0"
!430 update recent developments
!419 Resolve "loss plot with best result marker"
Pipeline #99868 failed
@@ -163,6 +163,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
     def __init__(self, *args, **kwargs):
         """Initialise ModelCheckpointAdvanced and set callbacks attribute."""
         self.callbacks = kwargs.pop("callbacks")
+        self.epoch_best = kwargs.pop("epoch_best", 0)
         super().__init__(*args, **kwargs)
 
     def update_best(self, hist):
@@ -197,6 +198,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
             if self.save_best_only:
                 current = logs.get(self.monitor)
                 if current == self.best:
+                    self.epoch_best = epoch
                     if self.verbose > 0:  # pragma: no branch
                         print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                     with open(file_path, "wb") as f:
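Outside the MLAir code base, the bookkeeping these two hunks add can be reproduced with a plain Keras callback. The following is a minimal sketch, not MLAir's actual ModelCheckpointAdvanced; the class name BestEpochTracker and the toy model are invented for illustration:

import numpy as np
from tensorflow import keras


class BestEpochTracker(keras.callbacks.Callback):
    """Remember the 0-based epoch at which the monitored value was best."""

    def __init__(self, monitor="val_loss"):
        super().__init__()
        self.monitor = monitor
        self.best = np.inf
        self.epoch_best = 0

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is not None and current < self.best:
            self.best = current
            self.epoch_best = epoch  # same bookkeeping as in the hunk above


# toy usage: fit a one-layer model and report the best epoch (1-based, as in the log output)
model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
x, y = np.random.rand(64, 4), np.random.rand(64, 1)
tracker = BestEpochTracker()
model.fit(x, y, validation_split=0.25, epochs=5, callbacks=[tracker], verbose=0)
print(f"best epoch: {tracker.epoch_best + 1}")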
@@ -27,7 +27,8 @@ class PlotModelHistory:
     parameter filename must include the absolute path for the plot.
     """
 
-    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = False):
+    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = False,
+                 epoch_best: int = None):
         """
         Set attributes and create plot.
 
@@ -37,12 +38,15 @@ class PlotModelHistory:
         :param plot_metric: the metric to plot (e.g. mean_squared_error, mse, mean_absolute_error, loss, default: loss)
         :param main_branch: switch between only looking for metrics that go with 'main' or for all occurrences (default:
             False -> look for losses from all branches, not only from main)
+        :param epoch_best: indicator at which epoch the best train result was achieved (should start counting at 0)
         """
         if isinstance(history, keras.callbacks.History):
             history = history.history
         self._data = pd.DataFrame.from_dict(history)
+        self._data.index += 1
         self._plot_metric = self._get_plot_metric(history, plot_metric, main_branch)
         self._additional_columns = self._filter_columns(history)
+        self._epoch_best = epoch_best
         self._plot(filename)
 
     def _get_plot_metric(self, history, plot_metric, main_branch, correct_names=True):
@@ -88,6 +92,9 @@ class PlotModelHistory:
         :param filename: name (including total path) of the plot to save.
         """
         ax = self._data[[self._plot_metric, f"val_{self._plot_metric}"]].plot(linewidth=0.7)
+        if self._epoch_best is not None:
+            ax.scatter(self._epoch_best+1, self._data[[f"val_{self._plot_metric}"]].iloc[self._epoch_best],
+                       s=100, marker="*", c="black")
         ax.set_yscale('log')
         if len(self._additional_columns) > 0:
             self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax, logy=True)
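The plotting side of the change reduces to one extra scatter call on the axes returned by pandas. A self-contained sketch of the same technique, with a fabricated history dict standing in for a real Keras History object:

import matplotlib.pyplot as plt
import pandas as pd

# fabricated history, shaped like keras History.history
history = {"loss": [1.0, 0.6, 0.4, 0.35, 0.34],
           "val_loss": [1.1, 0.7, 0.5, 0.55, 0.6]}
epoch_best = 2  # 0-based index of the best val_loss

data = pd.DataFrame.from_dict(history)
data.index += 1  # 1-based epoch axis, as in the constructor above
ax = data[["loss", "val_loss"]].plot(linewidth=0.7)
# star marker at the best epoch, mirroring the scatter call in the hunk
ax.scatter(epoch_best + 1, data["val_loss"].iloc[epoch_best], s=100, marker="*", c="black")
ax.set_yscale("log")
ax.set_xlabel("epoch")
plt.savefig("history_loss.pdf")
plt.close("all")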
@@ -165,6 +165,8 @@ class Training(RunEnvironment):
                                   initial_epoch=initial_epoch,
                                   workers=psutil.cpu_count(logical=False))
             history = hist
+            epoch_best = checkpoint.epoch_best
+            logging.info(f"best epoch: {epoch_best + 1}")
         try:
             lr = self.callbacks.get_callback_by_name("lr")
         except IndexError:
@@ -175,7 +177,7 @@ class Training(RunEnvironment):
             epo_timing = None
         self.save_callbacks_as_json(history, lr, epo_timing)
         self.load_best_model(checkpoint.filepath)
-        self.create_monitoring_plots(history, lr)
+        self.create_monitoring_plots(history, lr, epoch_best)
 
     def save_model(self) -> None:
         """Save model in local experiment directory. Model is named as `<experiment_name>_<custom_model_name>.h5`."""
@@ -194,9 +196,9 @@ class Training(RunEnvironment):
         logging.debug(f"load best model: {name}")
         try:
             self.model.load_model(name, compile=True)
-            logging.info('reload model...')
+            logging.info(f"reload model...")
         except OSError:
-            logging.info('no weights to reload...')
+            logging.info("no weights to reload...")
 
     def save_callbacks_as_json(self, history: Callback, lr_sc: Callback, epo_timing: Callback) -> None:
         """
@@ -219,7 +221,7 @@
         with open(os.path.join(path, "epo_timing.json"), "w") as f:
             json.dump(epo_timing.epo_timing, f)
 
-    def create_monitoring_plots(self, history: Callback, lr_sc: Callback) -> None:
+    def create_monitoring_plots(self, history: Callback, lr_sc: Callback, epoch_best: int) -> None:
         """
         Create plot of history and learning rate in dependence of the number of epochs.
 
@@ -228,22 +230,23 @@
         :param history: keras history object with losses to plot (must at least include `loss` and `val_loss`)
         :param lr_sc: learning rate decay object with 'lr' attribute
+        :param epoch_best: number of best epoch (starts counting at 0)
         """
         path = self.data_store.get("plot_path")
         name = self.data_store.get("experiment_name")
 
         # plot history of loss and mse (if available)
         filename = os.path.join(path, f"{name}_history_loss.pdf")
-        PlotModelHistory(filename=filename, history=history)
+        PlotModelHistory(filename=filename, history=history, epoch_best=epoch_best)
         multiple_branches_used = len(history.model.output_names) > 1  # means that there are multiple output branches
         if multiple_branches_used:
             filename = os.path.join(path, f"{name}_history_main_loss.pdf")
-            PlotModelHistory(filename=filename, history=history, main_branch=True)
+            PlotModelHistory(filename=filename, history=history, main_branch=True, epoch_best=epoch_best)
         mse_indicator = list(set(history.model.metrics_names).intersection(["mean_squared_error", "mse"]))
         if len(mse_indicator) > 0:
             filename = os.path.join(path, f"{name}_history_main_mse.pdf")
             PlotModelHistory(filename=filename, history=history, plot_metric=mse_indicator[0],
-                             main_branch=multiple_branches_used)
+                             main_branch=multiple_branches_used, epoch_best=epoch_best)
 
         # plot learning rate
         if lr_sc:
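The hunk is cut off after the learning-rate branch. For completeness, a recorded learning-rate series can be drawn with the same pandas pattern as the loss plot above; the lr values and output filename in this sketch are illustrative only and do not come from MLAir:

import matplotlib.pyplot as plt
import pandas as pd

# illustrative learning-rate schedule, one value per epoch
lr_history = {"lr": [1e-2, 1e-2, 5e-3, 5e-3, 1e-3]}

data = pd.DataFrame.from_dict(lr_history)
data.index += 1  # epochs, 1-based as in the loss plot
ax = data.plot(linewidth=0.7, legend=False)
ax.set_xlabel("epoch")
ax.set_ylabel("learning rate")
ax.set_yscale("log")
plt.savefig("history_learning_rate.pdf")
plt.close("all")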