diff --git a/src/plotting/training_monitoring.py b/src/plotting/training_monitoring.py index b18cce7a8993899621295644016f1e126d0dfac8..dee36166abacf275c79213305ac15e1918e1957c 100644 --- a/src/plotting/training_monitoring.py +++ b/src/plotting/training_monitoring.py @@ -19,12 +19,13 @@ lr_object = Union[Dict, LearningRateDecay] class PlotModelHistory: """ - Plots history of all losses for a training event. For default loss and val_loss are plotted. If further losses are - provided (name must somehow include the word `loss`), this additional information is added to the plot with an - separate y-axis scale on the right side (shared for all additional losses). The plot is saved locally. For a proper - saving behaviour, the parameter filename must include the absolute path for the plot. + Plots history of all plot_metrics (default: loss) for a training event. For default plot_metric and val_plot_metric + are plotted. If further metrics are provided (name must somehow include the word `<plot_metric>`), this additional + information is added to the plot with an separate y-axis scale on the right side (shared for all additional + metrics). The plot is saved locally. For a proper saving behaviour, the parameter filename must include the absolute + path for the plot. """ - def __init__(self, filename: str, history: history_object): + def __init__(self, filename: str, history: history_object, plot_metric: str = "loss"): """ Sets attributes and create plot :param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs a @@ -34,31 +35,34 @@ class PlotModelHistory: if isinstance(history, keras.callbacks.History): history = history.history self._data = pd.DataFrame.from_dict(history) + self._plot_metric = plot_metric self._additional_columns = self._filter_columns(history) self._plot(filename) - @staticmethod - def _filter_columns(history: Dict) -> List[str]: + def _filter_columns(self, history: Dict) -> List[str]: """ - Select only columns named like %loss%. The default losses 'loss' and 'val_loss' are also removed. - :param history: a dict with at least 'loss' and 'val_loss' as keys (can be derived from keras History.history) - :return: filtered columns including all loss variations except loss and val_loss. + Select only columns named like %<plot_metric>%. The default metrics '<plot_metric>' and 'val_<plot_metric>' are + also removed. + :param history: a dict with at least '<plot_metric>' and 'val_<plot_metric>' as keys (can be derived from keras + History.history) + :return: filtered columns including all plot_metric variations except <plot_metric> and val_<plot_metric>. """ - cols = list(filter(lambda x: "loss" in x, history.keys())) - cols.remove("val_loss") - cols.remove("loss") + cols = list(filter(lambda x: self._plot_metric in x, history.keys())) + cols.remove(f"val_{self._plot_metric}") + cols.remove(self._plot_metric) return cols def _plot(self, filename: str) -> None: """ - Actual plot routine. Plots loss and val_loss as default. If more losses are provided, they will be added with - an additional yaxis on the right side. The plot is saved in filename. + Actual plot routine. Plots <plot_metric> and val_<plot_metric> as default. If more plot_metrics are provided, + they will be added with an additional yaxis on the right side. The plot is saved in filename. :param filename: name (including total path) of the plot to save. """ - ax = self._data[["loss", "val_loss"]].plot(linewidth=0.7) + ax = self._data[[self._plot_metric, f"val_{self._plot_metric}"]].plot(linewidth=0.7) if len(self._additional_columns) > 0: self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax) - ax.set(xlabel="epoch", ylabel="loss", title=f"Model loss: best = {self._data[['val_loss']].min().values}") + title = f"Model {self._plot_metric}: best = {self._data[[f'val_{self._plot_metric}']].min().values}" + ax.set(xlabel="epoch", ylabel=self._plot_metric, title=title) ax.axhline(y=0, color="gray", linewidth=0.5) plt.tight_layout() plt.savefig(filename)