diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py
index 8b99acd0f5723d3b00ec1bd0098712753da21b52..d36e808b1024e597e04d25c38853d79425cd89e7 100644
--- a/mlair/model_modules/keras_extensions.py
+++ b/mlair/model_modules/keras_extensions.py
@@ -163,6 +163,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
     def __init__(self, *args, **kwargs):
         """Initialise ModelCheckpointAdvanced and set callbacks attribute."""
         self.callbacks = kwargs.pop("callbacks")
+        self.epoch_best = kwargs.pop("epoch_best", 0)
         super().__init__(*args, **kwargs)
 
     def update_best(self, hist):
@@ -197,6 +198,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
                 if self.save_best_only:
                     current = logs.get(self.monitor)
                     if current == self.best:
+                        self.epoch_best = epoch
                         if self.verbose > 0:  # pragma: no branch
                             print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                         with open(file_path, "wb") as f:
diff --git a/mlair/plotting/training_monitoring.py b/mlair/plotting/training_monitoring.py
index 39dd80651226519463d7b503fb612e43983d73cf..4884dcb81c2b98546da3edce099c02b47aebd7b2 100644
--- a/mlair/plotting/training_monitoring.py
+++ b/mlair/plotting/training_monitoring.py
@@ -27,7 +27,8 @@ class PlotModelHistory:
     parameter filename must include the absolute path for the plot.
     """
 
-    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = False):
+    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = False,
+                 epoch_best: int = None):
         """
         Set attributes and create plot.
 
@@ -37,12 +38,15 @@
         :param plot_metric: the metric to plot (e.b. mean_squared_error, mse, mean_absolute_error, loss, default: loss)
         :param main_branch: switch between only looking for metrics that go with 'main' or for all occurrences (default:
             False -> look for losses from all branches, not only from main)
+        :param epoch_best: epoch at which the best monitored result was achieved (counting starts at 0)
         """
         if isinstance(history, keras.callbacks.History):
             history = history.history
         self._data = pd.DataFrame.from_dict(history)
+        self._data.index += 1
         self._plot_metric = self._get_plot_metric(history, plot_metric, main_branch)
         self._additional_columns = self._filter_columns(history)
+        self._epoch_best = epoch_best
         self._plot(filename)
 
     def _get_plot_metric(self, history, plot_metric, main_branch, correct_names=True):
@@ -88,6 +92,9 @@
         :param filename: name (including total path) of the plot to save.
         """
         ax = self._data[[self._plot_metric, f"val_{self._plot_metric}"]].plot(linewidth=0.7)
+        if self._epoch_best is not None:
+            ax.scatter(self._epoch_best + 1, self._data[[f"val_{self._plot_metric}"]].iloc[self._epoch_best],
+                       s=100, marker="*", c="black")
         ax.set_yscale('log')
         if len(self._additional_columns) > 0:
             self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax, logy=True)
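Note on the indexing in training_monitoring.py above: `self._data.index += 1` makes the plotted x-axis start at epoch 1, while `epoch_best` stays 0-based, so the marker needs `epoch_best + 1` for the x coordinate but a plain positional `.iloc[epoch_best]` for the y value. The following is a minimal, self-contained sketch of that convention; the history values are made up and this is not MLAir code, only the indexing logic mirrors the change above.

# Standalone sketch of the 1-based x-axis vs. 0-based epoch_best convention (toy data).
import pandas as pd
import matplotlib.pyplot as plt

history = {"loss": [0.9, 0.5, 0.4, 0.45], "val_loss": [1.0, 0.6, 0.55, 0.7]}
epoch_best = 2  # 0-based: the third epoch had the lowest val_loss

data = pd.DataFrame.from_dict(history)
data.index += 1  # plot epochs as 1..n instead of 0..n-1

ax = data[["loss", "val_loss"]].plot(linewidth=0.7)
if epoch_best is not None:
    # .iloc is positional and therefore still 0-based; only the x coordinate needs the +1.
    ax.scatter(epoch_best + 1, data["val_loss"].iloc[epoch_best], s=100, marker="*", c="black")
ax.set_xlabel("epoch")
plt.savefig("history_loss_sketch.pdf")
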
""" ax = self._data[[self._plot_metric, f"val_{self._plot_metric}"]].plot(linewidth=0.7) + if self._epoch_best is not None: + ax.scatter(self._epoch_best+1, self._data[[f"val_{self._plot_metric}"]].iloc[self._epoch_best], + s=100, marker="*", c="black") ax.set_yscale('log') if len(self._additional_columns) > 0: self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax, logy=True) diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py index a38837dce041295d37fae1ea86ef2a215d51dc89..5ddf91ebf6659d08e1163aceee6000a8082f0bef 100644 --- a/mlair/run_modules/training.py +++ b/mlair/run_modules/training.py @@ -165,6 +165,8 @@ class Training(RunEnvironment): initial_epoch=initial_epoch, workers=psutil.cpu_count(logical=False)) history = hist + epoch_best = checkpoint.epoch_best + logging.info(f"best epoch: {epoch_best + 1}") try: lr = self.callbacks.get_callback_by_name("lr") except IndexError: @@ -175,7 +177,7 @@ class Training(RunEnvironment): epo_timing = None self.save_callbacks_as_json(history, lr, epo_timing) self.load_best_model(checkpoint.filepath) - self.create_monitoring_plots(history, lr) + self.create_monitoring_plots(history, lr, epoch_best) def save_model(self) -> None: """Save model in local experiment directory. Model is named as `<experiment_name>_<custom_model_name>.h5`.""" @@ -194,9 +196,9 @@ class Training(RunEnvironment): logging.debug(f"load best model: {name}") try: self.model.load_model(name, compile=True) - logging.info('reload model...') + logging.info(f"reload model...") except OSError: - logging.info('no weights to reload...') + logging.info("no weights to reload...") def save_callbacks_as_json(self, history: Callback, lr_sc: Callback, epo_timing: Callback) -> None: """ @@ -219,7 +221,7 @@ class Training(RunEnvironment): with open(os.path.join(path, "epo_timing.json"), "w") as f: json.dump(epo_timing.epo_timing, f) - def create_monitoring_plots(self, history: Callback, lr_sc: Callback) -> None: + def create_monitoring_plots(self, history: Callback, lr_sc: Callback, epoch_best: int) -> None: """ Create plot of history and learning rate in dependence of the number of epochs. @@ -228,22 +230,23 @@ class Training(RunEnvironment): :param history: keras history object with losses to plot (must at least include `loss` and `val_loss`) :param lr_sc: learning rate decay object with 'lr' attribute + :param epoch_best: number of best epoch (starts counting as 0) """ path = self.data_store.get("plot_path") name = self.data_store.get("experiment_name") # plot history of loss and mse (if available) filename = os.path.join(path, f"{name}_history_loss.pdf") - PlotModelHistory(filename=filename, history=history) + PlotModelHistory(filename=filename, history=history, epoch_best=epoch_best) multiple_branches_used = len(history.model.output_names) > 1 # means that there are multiple output branches if multiple_branches_used: filename = os.path.join(path, f"{name}_history_main_loss.pdf") - PlotModelHistory(filename=filename, history=history, main_branch=True) + PlotModelHistory(filename=filename, history=history, main_branch=True, epoch_best=epoch_best) mse_indicator = list(set(history.model.metrics_names).intersection(["mean_squared_error", "mse"])) if len(mse_indicator) > 0: filename = os.path.join(path, f"{name}_history_main_mse.pdf") PlotModelHistory(filename=filename, history=history, plot_metric=mse_indicator[0], - main_branch=multiple_branches_used) + main_branch=multiple_branches_used, epoch_best=epoch_best) # plot learning rate if lr_sc: