Skip to content
Snippets Groups Projects
Commit f1739aff authored by lukas leufen's avatar lukas leufen
Browse files

abstracted PlotModelHistory to be not only able to plot loss but also other...

abstracted PlotModelHistory to be not only able to plot loss but also other metrics from the history callback object for a given plot metric
parent 68b40993
No related branches found
No related tags found
2 merge requests!37include new development,!26Lukas issue034 feat plot mse history
......@@ -19,12 +19,13 @@ lr_object = Union[Dict, LearningRateDecay]
class PlotModelHistory:
"""
Plots history of all losses for a training event. For default loss and val_loss are plotted. If further losses are
provided (name must somehow include the word `loss`), this additional information is added to the plot with an
separate y-axis scale on the right side (shared for all additional losses). The plot is saved locally. For a proper
saving behaviour, the parameter filename must include the absolute path for the plot.
Plots history of all plot_metrics (default: loss) for a training event. For default plot_metric and val_plot_metric
are plotted. If further metrics are provided (name must somehow include the word `<plot_metric>`), this additional
information is added to the plot with an separate y-axis scale on the right side (shared for all additional
metrics). The plot is saved locally. For a proper saving behaviour, the parameter filename must include the absolute
path for the plot.
"""
def __init__(self, filename: str, history: history_object):
def __init__(self, filename: str, history: history_object, plot_metric: str = "loss"):
"""
Sets attributes and create plot
:param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs a
......@@ -34,31 +35,34 @@ class PlotModelHistory:
if isinstance(history, keras.callbacks.History):
history = history.history
self._data = pd.DataFrame.from_dict(history)
self._plot_metric = plot_metric
self._additional_columns = self._filter_columns(history)
self._plot(filename)
@staticmethod
def _filter_columns(history: Dict) -> List[str]:
def _filter_columns(self, history: Dict) -> List[str]:
"""
Select only columns named like %loss%. The default losses 'loss' and 'val_loss' are also removed.
:param history: a dict with at least 'loss' and 'val_loss' as keys (can be derived from keras History.history)
:return: filtered columns including all loss variations except loss and val_loss.
Select only columns named like %<plot_metric>%. The default metrics '<plot_metric>' and 'val_<plot_metric>' are
also removed.
:param history: a dict with at least '<plot_metric>' and 'val_<plot_metric>' as keys (can be derived from keras
History.history)
:return: filtered columns including all plot_metric variations except <plot_metric> and val_<plot_metric>.
"""
cols = list(filter(lambda x: "loss" in x, history.keys()))
cols.remove("val_loss")
cols.remove("loss")
cols = list(filter(lambda x: self._plot_metric in x, history.keys()))
cols.remove(f"val_{self._plot_metric}")
cols.remove(self._plot_metric)
return cols
def _plot(self, filename: str) -> None:
"""
Actual plot routine. Plots loss and val_loss as default. If more losses are provided, they will be added with
an additional yaxis on the right side. The plot is saved in filename.
Actual plot routine. Plots <plot_metric> and val_<plot_metric> as default. If more plot_metrics are provided,
they will be added with an additional yaxis on the right side. The plot is saved in filename.
:param filename: name (including total path) of the plot to save.
"""
ax = self._data[["loss", "val_loss"]].plot(linewidth=0.7)
ax = self._data[[self._plot_metric, f"val_{self._plot_metric}"]].plot(linewidth=0.7)
if len(self._additional_columns) > 0:
self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax)
ax.set(xlabel="epoch", ylabel="loss", title=f"Model loss: best = {self._data[['val_loss']].min().values}")
title = f"Model {self._plot_metric}: best = {self._data[[f'val_{self._plot_metric}']].min().values}"
ax.set(xlabel="epoch", ylabel=self._plot_metric, title=title)
ax.axhline(y=0, color="gray", linewidth=0.5)
plt.tight_layout()
plt.savefig(filename)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment