Skip to content
Snippets Groups Projects
Commit f1739aff authored by lukas leufen's avatar lukas leufen
Browse files

abstracted PlotModelHistory to be not only able to plot loss but also other...

abstracted PlotModelHistory to be not only able to plot loss but also other metrics from the history callback object for a given plot metric
parent 68b40993
No related branches found
No related tags found
2 merge requests!37include new development,!26Lukas issue034 feat plot mse history
...@@ -19,12 +19,13 @@ lr_object = Union[Dict, LearningRateDecay] ...@@ -19,12 +19,13 @@ lr_object = Union[Dict, LearningRateDecay]
class PlotModelHistory: class PlotModelHistory:
""" """
Plots history of all losses for a training event. For default loss and val_loss are plotted. If further losses are Plots history of all plot_metrics (default: loss) for a training event. For default plot_metric and val_plot_metric
provided (name must somehow include the word `loss`), this additional information is added to the plot with an are plotted. If further metrics are provided (name must somehow include the word `<plot_metric>`), this additional
separate y-axis scale on the right side (shared for all additional losses). The plot is saved locally. For a proper information is added to the plot with an separate y-axis scale on the right side (shared for all additional
saving behaviour, the parameter filename must include the absolute path for the plot. metrics). The plot is saved locally. For a proper saving behaviour, the parameter filename must include the absolute
path for the plot.
""" """
def __init__(self, filename: str, history: history_object): def __init__(self, filename: str, history: history_object, plot_metric: str = "loss"):
""" """
Sets attributes and create plot Sets attributes and create plot
:param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs a :param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs a
...@@ -34,31 +35,34 @@ class PlotModelHistory: ...@@ -34,31 +35,34 @@ class PlotModelHistory:
if isinstance(history, keras.callbacks.History): if isinstance(history, keras.callbacks.History):
history = history.history history = history.history
self._data = pd.DataFrame.from_dict(history) self._data = pd.DataFrame.from_dict(history)
self._plot_metric = plot_metric
self._additional_columns = self._filter_columns(history) self._additional_columns = self._filter_columns(history)
self._plot(filename) self._plot(filename)
@staticmethod def _filter_columns(self, history: Dict) -> List[str]:
def _filter_columns(history: Dict) -> List[str]:
""" """
Select only columns named like %loss%. The default losses 'loss' and 'val_loss' are also removed. Select only columns named like %<plot_metric>%. The default metrics '<plot_metric>' and 'val_<plot_metric>' are
:param history: a dict with at least 'loss' and 'val_loss' as keys (can be derived from keras History.history) also removed.
:return: filtered columns including all loss variations except loss and val_loss. :param history: a dict with at least '<plot_metric>' and 'val_<plot_metric>' as keys (can be derived from keras
History.history)
:return: filtered columns including all plot_metric variations except <plot_metric> and val_<plot_metric>.
""" """
cols = list(filter(lambda x: "loss" in x, history.keys())) cols = list(filter(lambda x: self._plot_metric in x, history.keys()))
cols.remove("val_loss") cols.remove(f"val_{self._plot_metric}")
cols.remove("loss") cols.remove(self._plot_metric)
return cols return cols
def _plot(self, filename: str) -> None: def _plot(self, filename: str) -> None:
""" """
Actual plot routine. Plots loss and val_loss as default. If more losses are provided, they will be added with Actual plot routine. Plots <plot_metric> and val_<plot_metric> as default. If more plot_metrics are provided,
an additional yaxis on the right side. The plot is saved in filename. they will be added with an additional yaxis on the right side. The plot is saved in filename.
:param filename: name (including total path) of the plot to save. :param filename: name (including total path) of the plot to save.
""" """
ax = self._data[["loss", "val_loss"]].plot(linewidth=0.7) ax = self._data[[self._plot_metric, f"val_{self._plot_metric}"]].plot(linewidth=0.7)
if len(self._additional_columns) > 0: if len(self._additional_columns) > 0:
self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax) self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax)
ax.set(xlabel="epoch", ylabel="loss", title=f"Model loss: best = {self._data[['val_loss']].min().values}") title = f"Model {self._plot_metric}: best = {self._data[[f'val_{self._plot_metric}']].min().values}"
ax.set(xlabel="epoch", ylabel=self._plot_metric, title=title)
ax.axhline(y=0, color="gray", linewidth=0.5) ax.axhline(y=0, color="gray", linewidth=0.5)
plt.tight_layout() plt.tight_layout()
plt.savefig(filename) plt.savefig(filename)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment