From 944d384ef6d2d97f0beb6ff255ca196769f772db Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Wed, 11 Dec 2019 13:46:49 +0100
Subject: [PATCH] two new plot routines: plot history and learning rate vs
 epochs

---
 requirements.txt                              |  1 +
 src/modules/experiment_setup.py               |  6 +-
 src/modules/training.py                       | 17 +++-
 src/plotting/__init__.py                      |  0
 src/plotting/training_monitoring.py           | 94 +++++++++++++++++++
 test/test_modules/test_training.py            | 35 +++++--
 .../test_plotting/test_training_monitoring.py | 83 ++++++++++++++++
 7 files changed, 226 insertions(+), 10 deletions(-)
 create mode 100644 src/plotting/__init__.py
 create mode 100644 src/plotting/training_monitoring.py
 create mode 100644 test/test_plotting/test_training_monitoring.py

diff --git a/requirements.txt b/requirements.txt
index d2a7200b..7ac250ac 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,3 +10,4 @@ pytest-cov
 pytest-html
 pydot
 mock
+matplotlib
diff --git a/src/modules/experiment_setup.py b/src/modules/experiment_setup.py
index 472173c7..5fdc1f1f 100644
--- a/src/modules/experiment_setup.py
+++ b/src/modules/experiment_setup.py
@@ -5,6 +5,7 @@ __date__ = '2019-11-15'
 import logging
 import argparse
 from typing import Union, Dict, Any
+import os
 
 from src import helpers
 from src.modules.run_environment import RunEnvironment
@@ -32,7 +33,7 @@ class ExperimentSetup(RunEnvironment):
                  window_lead_time=None, dimensions=None, interpolate_dim=None, interpolate_method=None,
                  limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
                  test_end=None, use_all_stations_on_all_data_sets=True, trainable=False, fraction_of_train=None,
-                 experiment_path=None):
+                 experiment_path=None, plot_path=None):
 
         # create run framework
         super().__init__()
@@ -48,6 +49,9 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("experiment_name", exp_name)
         self._set_param("experiment_path", exp_path)
         helpers.check_path_and_create(self.data_store.get("experiment_path", "general"))
+        default_plot_path = os.path.join(exp_path, "plots")
+        self._set_param("plot_path", plot_path, default=default_plot_path)
+        helpers.check_path_and_create(self.data_store.get("plot_path", "general"))
 
         # setup for data
         self._set_param("var_all_dict", var_all_dict, default=DEFAULT_VAR_ALL_DICT)
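A minimal sketch of the default resolution introduced above, assuming _set_param falls back to the given default when the caller passes None (an inference from the surrounding calls; all values are illustrative):

    import os

    exp_path = "/data/TestExperiment"  # illustrative experiment path
    plot_path = None                   # caller did not pass plot_path

    default_plot_path = os.path.join(exp_path, "plots")
    resolved = default_plot_path if plot_path is None else plot_path
    print(resolved)  # /data/TestExperiment/plots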
diff --git a/src/modules/training.py b/src/modules/training.py
index 87dcf35e..ba70e3b4 100644
--- a/src/modules/training.py
+++ b/src/modules/training.py
@@ -9,6 +9,8 @@ import keras
 
 from src.modules.run_environment import RunEnvironment
 from src.data_handling.data_distributor import Distributor
+from src.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
+from src.helpers import LearningRateDecay
 
 
 class Training(RunEnvironment):
@@ -84,6 +86,7 @@ class Training(RunEnvironment):
                                  callbacks=[self.checkpoint, self.lr_sc])
         self.save_callbacks(history)
         self.load_best_model(self.checkpoint.filepath)
+        self.create_monitoring_plots(history, self.lr_sc)
 
     def save_model(self) -> None:
         """
@@ -121,5 +124,15 @@ class Training(RunEnvironment):
         with open(os.path.join(path, "history_lr.json"), "w") as f:
             json.dump(self.lr_sc.lr, f)
 
-
-
+    def create_monitoring_plots(self, history: keras.callbacks.History, lr_sc: LearningRateDecay) -> None:
+        """
+        Creates the history and learning rate plots as a function of the number of epochs. The plots are saved to
+        the experiment's plot_path. The history plot is named '<exp_name>_history_loss_val_loss.pdf', the learning
+        rate plot '<exp_name>_history_learning_rate.pdf'.
+
+        :param history: keras history object with losses to plot (must include 'loss' and 'val_loss')
+        :param lr_sc: learning rate decay object with 'lr' attribute
+        """
+        path = self.data_store.get("plot_path", "general")
+        name = self.data_store.get("experiment_name", "general")
+        PlotModelHistory(filename=os.path.join(path, f"{name}_history_loss_val_loss.pdf"), history=history)
+        PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)
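For orientation, a sketch of the two artifacts create_monitoring_plots resolves, assuming the illustrative "plot_path" and "experiment_name" values below (in a real run both come from the data store):

    import os

    path = "/data/TestExperiment/plots"  # assumed "plot_path"
    name = "TestExperiment"              # assumed "experiment_name"

    print(os.path.join(path, f"{name}_history_loss_val_loss.pdf"))
    # /data/TestExperiment/plots/TestExperiment_history_loss_val_loss.pdf
    print(os.path.join(path, f"{name}_history_learning_rate.pdf"))
    # /data/TestExperiment/plots/TestExperiment_history_learning_rate.pdf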
+ """ + ax = self._data[["loss", "val_loss"]].plot(linewidth=0.7) + if len(self._additional_columns) > 0: + self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax) + ax.set(xlabel="epoch", ylabel="loss", title=f"Model loss: best = {self._data[['val_loss']].min().values}") + ax.axhline(y=0, color="gray", linewidth=0.5) + plt.tight_layout() + plt.savefig(filename) + plt.close("all") + + +class PlotModelLearningRate: + """ + Plots the behaviour of the learning rate in dependence of the number of epochs. The plot is saved locally as pdf. + For a proper saving behaviour, the parameter filename must include the absolute path for the plot. + """ + def __init__(self, filename: str, lr_sc: lr_object): + """ + Sets attributes and create plot + :param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs a + format ending like .pdf or .png to work. + :param lr_sc: the learning rate object (or a dict with `lr` as key) to plot from + """ + if isinstance(lr_sc, LearningRateDecay): + lr_sc = lr_sc.lr + self._data = pd.DataFrame.from_dict(lr_sc) + self._plot(filename) + + def _plot(self, filename: str) -> None: + """ + Actual plot routine. Plots the learning rate in dependence of epoch. + :param filename: name (including total path) of the plot to save. + """ + ax = self._data.plot(linewidth=0.7) + ax.set(xlabel="epoch", ylabel="learning rate") + plt.tight_layout() + plt.savefig(filename) + plt.close("all") diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py index c91cb691..ddb301c6 100644 --- a/test/test_modules/test_training.py +++ b/test/test_modules/test_training.py @@ -6,6 +6,7 @@ import os import json import shutil import logging +import glob from src.inception_model import InceptionModelBase from src.flatten import flatten_tail @@ -54,11 +55,23 @@ class TestTraining: os.makedirs(path) obj.data_store.put("experiment_path", path, "general") obj.data_store.put("experiment_name", "TestExperiment", "general") + path_plot = os.path.join(path, "plots") + os.makedirs(path_plot) + obj.data_store.put("plot_path", path_plot, "general") yield obj if os.path.exists(path): shutil.rmtree(path) RunEnvironment().__del__() + @pytest.fixture + def learning_rate(self): + return {"lr": [0.01, 0.0094]} + + @pytest.fixture + def init_with_lr(self, init_without_run, learning_rate): + init_without_run.lr_sc.lr = learning_rate + return init_without_run + @pytest.fixture def history(self): h = History() @@ -120,6 +133,9 @@ class TestTraining: obj.data_store.put("lr_decay", LearningRateDecay(), "general.model") obj.data_store.put("experiment_name", "TestExperiment", "general") obj.data_store.put("experiment_path", path, "general") + path_plot = os.path.join(path, "plots") + os.makedirs(path_plot) + obj.data_store.put("plot_path", path_plot, "general") yield obj if os.path.exists(path): shutil.rmtree(path) @@ -148,11 +164,13 @@ class TestTraining: assert not all([getattr(init_without_run, f"{obj}_set") is None for obj in sets]) assert all([getattr(init_without_run, f"{obj}_set").generator.return_value == f"mock_{obj}_gen" for obj in sets]) - def test_train(self, ready_to_train): + def test_train(self, ready_to_train, path): assert not hasattr(ready_to_train.model, "history") + assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0 ready_to_train.train() assert list(ready_to_train.model.history.history.keys()) == ["val_loss", "loss"] assert ready_to_train.model.history.epoch == [0, 1] + 
diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py
index c91cb691..ddb301c6 100644
--- a/test/test_modules/test_training.py
+++ b/test/test_modules/test_training.py
@@ -6,6 +6,7 @@ import os
 import json
 import shutil
 import logging
+import glob
 
 from src.inception_model import InceptionModelBase
 from src.flatten import flatten_tail
@@ -54,11 +55,23 @@ class TestTraining:
             os.makedirs(path)
         obj.data_store.put("experiment_path", path, "general")
         obj.data_store.put("experiment_name", "TestExperiment", "general")
+        path_plot = os.path.join(path, "plots")
+        os.makedirs(path_plot)
+        obj.data_store.put("plot_path", path_plot, "general")
         yield obj
         if os.path.exists(path):
             shutil.rmtree(path)
         RunEnvironment().__del__()
 
+    @pytest.fixture
+    def learning_rate(self):
+        return {"lr": [0.01, 0.0094]}
+
+    @pytest.fixture
+    def init_with_lr(self, init_without_run, learning_rate):
+        init_without_run.lr_sc.lr = learning_rate
+        return init_without_run
+
     @pytest.fixture
     def history(self):
         h = History()
@@ -120,6 +133,9 @@ class TestTraining:
         obj.data_store.put("lr_decay", LearningRateDecay(), "general.model")
         obj.data_store.put("experiment_name", "TestExperiment", "general")
         obj.data_store.put("experiment_path", path, "general")
+        path_plot = os.path.join(path, "plots")
+        os.makedirs(path_plot)
+        obj.data_store.put("plot_path", path_plot, "general")
         yield obj
         if os.path.exists(path):
             shutil.rmtree(path)
@@ -148,11 +164,13 @@ class TestTraining:
         assert not all([getattr(init_without_run, f"{obj}_set") is None for obj in sets])
         assert all([getattr(init_without_run, f"{obj}_set").generator.return_value == f"mock_{obj}_gen" for obj in sets])
 
-    def test_train(self, ready_to_train):
+    def test_train(self, ready_to_train, path):
         assert not hasattr(ready_to_train.model, "history")
+        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
         ready_to_train.train()
         assert list(ready_to_train.model.history.history.keys()) == ["val_loss", "loss"]
         assert ready_to_train.model.history.epoch == [0, 1]
+        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
 
     def test_save_model(self, init_without_run, path, caplog):
         caplog.set_level(logging.DEBUG)
@@ -172,8 +190,8 @@ class TestTraining:
         init_without_run.save_callbacks(history)
         assert "history.json" in os.listdir(path)
 
-    def test_save_callbacks_lr_created(self, init_without_run, history, path):
-        init_without_run.save_callbacks(history)
+    def test_save_callbacks_lr_created(self, init_with_lr, history, path):
+        init_with_lr.save_callbacks(history)
         assert "history_lr.json" in os.listdir(path)
 
     def test_save_callbacks_inspect_history(self, init_without_run, history, path):
@@ -182,10 +200,13 @@ class TestTraining:
             hist = json.load(jfile)
         assert hist == history.history
 
-    def test_save_callbacks_inspect_lr(self, init_without_run, history, path):
-        init_without_run.save_callbacks(history)
+    def test_save_callbacks_inspect_lr(self, init_with_lr, history, path):
+        init_with_lr.save_callbacks(history)
         with open(os.path.join(path, "history_lr.json")) as jfile:
             lr = json.load(jfile)
-        assert lr == init_without_run.lr_sc.lr
-
+        assert lr == init_with_lr.lr_sc.lr
+
+    def test_create_monitoring_plots(self, init_without_run, learning_rate, history, path):
+        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
+        init_without_run.create_monitoring_plots(history, learning_rate)
+        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
diff --git a/test/test_plotting/test_training_monitoring.py b/test/test_plotting/test_training_monitoring.py
new file mode 100644
index 00000000..d38fc623
--- /dev/null
+++ b/test/test_plotting/test_training_monitoring.py
@@ -0,0 +1,83 @@
+import keras
+import pytest
+import os
+
+from src.plotting.training_monitoring import PlotModelLearningRate, PlotModelHistory
+from src.helpers import LearningRateDecay
+
+
+@pytest.fixture
+def path():
+    p = os.path.join(os.path.dirname(__file__), "TestExperiment")
+    if not os.path.exists(p):
+        os.makedirs(p)
+    return p
+
+
+class TestPlotModelHistory:
+
+    @pytest.fixture
+    def history(self):
+        hist = keras.callbacks.History()
+        hist.epoch = [0, 1]
+        hist.history = {'val_loss': [0.5586272982587484, 0.45712877659670287],
+                        'val_mean_squared_error': [0.5586272982587484, 0.45712877659670287],
+                        'val_mean_absolute_error': [0.595368885413389, 0.530547587585537],
+                        'loss': [0.6795708956961347, 0.45963566494176616],
+                        'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
+                        'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
+        return hist
+
+    @pytest.fixture
+    def history_var(self):
+        hist = keras.callbacks.History()
+        hist.epoch = [0, 1]
+        hist.history = {'val_loss': [0.5586272982587484, 0.45712877659670287],
+                        'test_loss': [0.595368885413389, 0.530547587585537],
+                        'loss': [0.6795708956961347, 0.45963566494176616],
+                        'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
+                        'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
+        return hist
+
+    @pytest.fixture
+    def no_init(self):
+        return object.__new__(PlotModelHistory)
+
+    def test_plot_from_hist_obj(self, history, path):
+        assert "hist_obj.pdf" not in os.listdir(path)
+        PlotModelHistory(os.path.join(path, "hist_obj.pdf"), history)
+        assert "hist_obj.pdf" in os.listdir(path)
+
+    def test_plot_from_hist_dict(self, history, path):
+        assert "hist_dict.pdf" not in os.listdir(path)
+        PlotModelHistory(os.path.join(path, "hist_dict.pdf"), history.history)
+        assert "hist_dict.pdf" in os.listdir(path)
+
+    def test_plot_additional_loss(self, history_var, path):
+        assert "hist_additional.pdf" not in os.listdir(path)
+        PlotModelHistory(os.path.join(path, "hist_additional.pdf"), history_var)
+        assert "hist_additional.pdf" in os.listdir(path)
+
+    def test_filter_list(self, no_init):
+        res = no_init._filter_columns({'loss': None, 'another_loss': None, 'val_loss': None, 'wrong': None})
+        assert res == ['another_loss']
+
+
+class TestPlotModelLearningRate:
+
+    @pytest.fixture
+    def learning_rate(self):
+        lr_obj = LearningRateDecay()
+        lr_obj.lr = {"lr": [0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.0094, 0.0094, 0.0094,
+                            0.0094, 0.0094, 0.0094, 0.0094, 0.0094, 0.0094, 0.0094]}
+        return lr_obj
+
+    def test_plot_from_lr_obj(self, learning_rate, path):
+        assert "lr_obj.pdf" not in os.listdir(path)
+        PlotModelLearningRate(os.path.join(path, "lr_obj.pdf"), learning_rate)
+        assert "lr_obj.pdf" in os.listdir(path)
+
+    def test_plot_from_lr_dict(self, learning_rate, path):
+        assert "lr_dict.pdf" not in os.listdir(path)
+        PlotModelLearningRate(os.path.join(path, "lr_dict.pdf"), learning_rate.lr)
+        assert "lr_dict.pdf" in os.listdir(path)
-- 
GitLab