diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py
index 0e31cd66fd64d19f1cecc7c9906a5c3b9446fe75..6b0fe236ff8ee726c34a721a6be0ed8be91f2bb8 100644
--- a/src/model_modules/model_class.py
+++ b/src/model_modules/model_class.py
@@ -219,15 +219,15 @@ class MyBranchedModel(AbstractModelClass):
         x_in = keras.layers.Dense(64, name='{}_Dense_64'.format("major"))(x_in)
         x_in = self.activation()(x_in)
         out_minor_1 = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("minor_1"))(x_in)
-        out_minor_1 = self.activation()(out_minor_1)
+        out_minor_1 = self.activation(name="minor_1")(out_minor_1)
         x_in = keras.layers.Dense(32, name='{}_Dense_32'.format("major"))(x_in)
         x_in = self.activation()(x_in)
         out_minor_2 = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("minor_2"))(x_in)
-        out_minor_2 = self.activation()(out_minor_2)
+        out_minor_2 = self.activation(name="minor_2")(out_minor_2)
         x_in = keras.layers.Dense(16, name='{}_Dense_16'.format("major"))(x_in)
         x_in = self.activation()(x_in)
         x_in = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("major"))(x_in)
-        out_main = self.activation()(x_in)
+        out_main = self.activation(name="main")(x_in)
         self.model = keras.Model(inputs=x_input, outputs=[out_minor_1, out_minor_2, out_main])

     def set_loss(self):
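Why the activation layers above get explicit names: keras derives the per-branch keys in History.history from the names of the output layers, and the plotting changes below search those keys. A minimal, self-contained sketch of that behaviour (plain keras + numpy, not code from this repo; layer names and shapes are illustrative):

# Sketch: named output layers determine the History.history keys that
# PlotModelHistory later searches (e.g. 'minor_1_loss', 'val_main_loss').
import keras
import numpy as np

x_input = keras.layers.Input(shape=(4,))
x = keras.layers.Dense(8)(x_input)
out_minor_1 = keras.layers.Dense(1, name="minor_1")(x)
out_main = keras.layers.Dense(1, name="main")(x)
model = keras.Model(inputs=x_input, outputs=[out_minor_1, out_main])
model.compile(optimizer="sgd", loss="mse")

X = np.random.rand(16, 4)
y = np.random.rand(16, 1)
hist = model.fit(X, [y, y], validation_split=0.25, epochs=1, verbose=0)
# expected keys include 'loss', 'minor_1_loss', 'main_loss',
# 'val_loss', 'val_minor_1_loss', 'val_main_loss'
print(sorted(hist.history.keys()))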
diff --git a/src/plotting/training_monitoring.py b/src/plotting/training_monitoring.py
index b18cce7a8993899621295644016f1e126d0dfac8..87e4071a0ac98c15c22131cd5d3418eb7c1b6976 100644
--- a/src/plotting/training_monitoring.py
+++ b/src/plotting/training_monitoring.py
@@ -19,46 +19,66 @@ lr_object = Union[Dict, LearningRateDecay]

 class PlotModelHistory:
     """
-    Plots history of all losses for a training event. For default loss and val_loss are plotted. If further losses are
-    provided (name must somehow include the word `loss`), this additional information is added to the plot with an
-    separate y-axis scale on the right side (shared for all additional losses). The plot is saved locally. For a proper
-    saving behaviour, the parameter filename must include the absolute path for the plot.
+    Plots the history of a given metric (default: loss) for a training event. By default, <plot_metric> and
+    val_<plot_metric> are plotted. If further metrics are provided (their names must include `<plot_metric>` as a
+    substring), this additional information is added to the plot with a separate y-axis scale on the right side
+    (shared by all additional metrics). The plot is saved locally. To save properly, the parameter filename must
+    include the absolute path for the plot.
     """

-    def __init__(self, filename: str, history: history_object):
+    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = False):
         """
         Sets attributes and create plot
         :param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs
             a format ending like .pdf or .png to work.
         :param history: the history object (or a dict with at least 'loss' and 'val_loss' as keys) to plot loss from
+        :param plot_metric: the metric to plot (e.g. mean_squared_error, mse, mean_absolute_error, or loss; default:
+            loss)
+        :param main_branch: if True, only consider metrics that belong to the 'main' branch; if False, consider all
+            occurrences (default: False -> use losses from all branches, not only from main)
         """
         if isinstance(history, keras.callbacks.History):
             history = history.history
         self._data = pd.DataFrame.from_dict(history)
+        self._plot_metric = self._get_plot_metric(history, plot_metric, main_branch)
         self._additional_columns = self._filter_columns(history)
         self._plot(filename)

     @staticmethod
-    def _filter_columns(history: Dict) -> List[str]:
+    def _get_plot_metric(history, plot_metric, main_branch):
+        if plot_metric.lower() == "mse":
+            plot_metric = "mean_squared_error"
+        elif plot_metric.lower() == "mae":
+            plot_metric = "mean_absolute_error"
+        available_keys = [k for k in history.keys() if plot_metric in k and ("main" in k.lower() if main_branch else True)]
+        available_keys.sort(key=len)
+        return available_keys[0]
+
+    def _filter_columns(self, history: Dict) -> List[str]:
         """
-        Select only columns named like %loss%. The default losses 'loss' and 'val_loss' are also removed.
-        :param history: a dict with at least 'loss' and 'val_loss' as keys (can be derived from keras History.history)
-        :return: filtered columns including all loss variations except loss and val_loss.
+        Select only columns named like %<plot_metric>%. The default metrics '<plot_metric>' and 'val_<plot_metric>'
+        are also removed.
+        :param history: a dict with at least '<plot_metric>' and 'val_<plot_metric>' as keys (can be derived from
+            keras History.history)
+        :return: filtered columns including all <plot_metric> variations except <plot_metric> and val_<plot_metric>.
         """
-        cols = list(filter(lambda x: "loss" in x, history.keys()))
-        cols.remove("val_loss")
-        cols.remove("loss")
+        cols = list(filter(lambda x: self._plot_metric in x, history.keys()))
+        # note: val_<plot_metric> and <plot_metric> are always present in cols, because keras generates both keys.
+        # If the metric were missing entirely, self._get_plot_metric() would already have failed, since it runs first.
+        cols.remove(f"val_{self._plot_metric}")
+        cols.remove(self._plot_metric)
         return cols

     def _plot(self, filename: str) -> None:
         """
-        Actual plot routine. Plots loss and val_loss as default. If more losses are provided, they will be added with
-        an additional yaxis on the right side. The plot is saved in filename.
+        Actual plot routine. Plots <plot_metric> and val_<plot_metric> by default. If more matching metrics are
+        provided, they are added with an additional y-axis on the right side. The plot is saved to filename.
         :param filename: name (including total path) of the plot to save.
         """
-        ax = self._data[["loss", "val_loss"]].plot(linewidth=0.7)
+        ax = self._data[[self._plot_metric, f"val_{self._plot_metric}"]].plot(linewidth=0.7)
         if len(self._additional_columns) > 0:
             self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax)
-        ax.set(xlabel="epoch", ylabel="loss", title=f"Model loss: best = {self._data[['val_loss']].min().values}")
+        title = f"Model {self._plot_metric}: best = {self._data[[f'val_{self._plot_metric}']].min().values}"
+        ax.set(xlabel="epoch", ylabel=self._plot_metric, title=title)
         ax.axhline(y=0, color="gray", linewidth=0.5)
         plt.tight_layout()
         plt.savefig(filename)
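The shortest-key heuristic in _get_plot_metric can be read in isolation: among all history keys containing the metric name, the plain metric itself is the shortest, so sorting by length and taking the first element selects it. A standalone rephrasing with hypothetical history keys:

# Sketch of the selection heuristic: 'loss' beats 'val_loss' and 'main_loss'
# because it is the shortest key containing the metric name.
history = {"val_loss": [0.5], "loss": [0.6], "main_loss": [0.7], "val_main_loss": [0.5]}

def pick(history, plot_metric, main_branch=False):
    keys = [k for k in history if plot_metric in k and ("main" in k.lower() if main_branch else True)]
    keys.sort(key=len)
    return keys[0]

print(pick(history, "loss"))                    # -> 'loss'
print(pick(history, "loss", main_branch=True))  # -> 'main_loss'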
""" - ax = self._data[["loss", "val_loss"]].plot(linewidth=0.7) + ax = self._data[[self._plot_metric, f"val_{self._plot_metric}"]].plot(linewidth=0.7) if len(self._additional_columns) > 0: self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax) - ax.set(xlabel="epoch", ylabel="loss", title=f"Model loss: best = {self._data[['val_loss']].min().values}") + title = f"Model {self._plot_metric}: best = {self._data[[f'val_{self._plot_metric}']].min().values}" + ax.set(xlabel="epoch", ylabel=self._plot_metric, title=title) ax.axhline(y=0, color="gray", linewidth=0.5) plt.tight_layout() plt.savefig(filename) diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py index 0f3ff6d436b8a65528626f5f80508af222a1e68f..a7722018c52275b390a10199cb30b7b936ed37a3 100644 --- a/src/run_modules/model_setup.py +++ b/src/run_modules/model_setup.py @@ -15,8 +15,8 @@ from src.run_modules.run_environment import RunEnvironment from src.helpers import l_p_loss, LearningRateDecay from src.model_modules.inception_model import InceptionModelBase from src.model_modules.flatten import flatten_tail -# from src.model_modules.model_class import MyBranchedModel as MyModel -from src.model_modules.model_class import MyLittleModel as MyModel +from src.model_modules.model_class import MyBranchedModel as MyModel +# from src.model_modules.model_class import MyLittleModel as MyModel class ModelSetup(RunEnvironment): diff --git a/src/run_modules/training.py b/src/run_modules/training.py index 272609a31a3e3c91d6857ed841d5dd2783c66f35..fee1b38b97d8f4649730f0f7110cd3801ba7db33 100644 --- a/src/run_modules/training.py +++ b/src/run_modules/training.py @@ -134,5 +134,17 @@ class Training(RunEnvironment): """ path = self.data_store.get("plot_path", "general") name = self.data_store.get("experiment_name", "general") - PlotModelHistory(filename=os.path.join(path, f"{name}_history_loss_val_loss.pdf"), history=history) + + # plot history of loss and mse (if available) + filename = os.path.join(path, f"{name}_history_loss.pdf") + PlotModelHistory(filename=filename, history=history) + multiple_branches_used = len(history.model.output_names) > 1 # means that there are multiple output branches + if multiple_branches_used: + filename = os.path.join(path, f"{name}_history_main_loss.pdf") + PlotModelHistory(filename=filename, history=history, main_branch=True) + if "mean_squared_error" in history.model.metrics_names: + filename = os.path.join(path, f"{name}_history_main_mse.pdf") + PlotModelHistory(filename=filename, history=history, plot_metric="mse", main_branch=multiple_branches_used) + + # plot learning rate PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc) diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py index 3426eb1d355a6690ee57c3ee45e5088d7df9c249..accb32e5e3ec0fc425065ae6199c0418c524b174 100644 --- a/test/test_modules/test_training.py +++ b/test/test_modules/test_training.py @@ -37,7 +37,7 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m class TestTraining: @pytest.fixture - def init_without_run(self, path, model, checkpoint): + def init_without_run(self, path: str, model: keras.Model, checkpoint: ModelCheckpoint): obj = object.__new__(Training) super(Training, obj).__init__() obj.model = model @@ -82,6 +82,7 @@ class TestTraining: 'loss': [0.6795708956961347, 0.45963566494176616], 'mean_squared_error': [0.6795708956961347, 0.45963566494176616], 'mean_absolute_error': 
diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py
index 3426eb1d355a6690ee57c3ee45e5088d7df9c249..accb32e5e3ec0fc425065ae6199c0418c524b174 100644
--- a/test/test_modules/test_training.py
+++ b/test/test_modules/test_training.py
@@ -37,7 +37,7 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m
 class TestTraining:

     @pytest.fixture
-    def init_without_run(self, path, model, checkpoint):
+    def init_without_run(self, path: str, model: keras.Model, checkpoint: ModelCheckpoint):
         obj = object.__new__(Training)
         super(Training, obj).__init__()
         obj.model = model
@@ -82,6 +82,7 @@ class TestTraining:
                        'loss': [0.6795708956961347, 0.45963566494176616],
                        'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
                        'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
+        h.model = mock.MagicMock()
         return h

     @pytest.fixture
@@ -103,7 +104,7 @@ class TestTraining:
         return ModelCheckpoint(os.path.join(path, "model_checkpoint"), monitor='val_loss', save_best_only=True)

     @pytest.fixture
-    def ready_to_train(self, generator, init_without_run):
+    def ready_to_train(self, generator: DataGenerator, init_without_run: Training):
         init_without_run.train_set = Distributor(generator, init_without_run.model, init_without_run.batch_size)
         init_without_run.val_set = Distributor(generator, init_without_run.model, init_without_run.batch_size)
         init_without_run.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
@@ -208,5 +209,7 @@ class TestTraining:

     def test_create_monitoring_plots(self, init_without_run, learning_rate, history, path):
         assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
+        history.model.output_names = ["Main"]
+        history.model.metrics_names = ["loss", "mean_squared_error"]
         init_without_run.create_monitoring_plots(history, learning_rate)
-        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
+        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 3
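A note on the mocking above: plain lists are assigned to the mocked model's attributes, because a bare MagicMock attribute answers len() with 0 and `in` with False, which would silently skip both new branches of create_monitoring_plots. A small demonstration using only unittest.mock:

# Sketch: default MagicMock magic-method behaviour vs. plain list attributes.
from unittest import mock

m = mock.MagicMock()
print(len(m.output_names))                      # -> 0 (MagicMock default __len__)
print("mean_squared_error" in m.metrics_names)  # -> False (default __contains__)

m.output_names = ["minor_1", "main"]
m.metrics_names = ["loss", "mean_squared_error"]
print(len(m.output_names) > 1)                  # -> True
print("mean_squared_error" in m.metrics_names)  # -> True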
diff --git a/test/test_plotting/test_training_monitoring.py b/test/test_plotting/test_training_monitoring.py
index d38fc623b7fd2f5f8c779ea03757d7bf0794089a..e2d17057facee8ae07284109392f3e42e8e2fb6e 100644
--- a/test/test_plotting/test_training_monitoring.py
+++ b/test/test_plotting/test_training_monitoring.py
@@ -17,50 +17,81 @@ def path():
 class TestPlotModelHistory:

     @pytest.fixture
-    def history(self):
+    def default_history(self):
         hist = keras.callbacks.History()
         hist.epoch = [0, 1]
         hist.history = {'val_loss': [0.5586272982587484, 0.45712877659670287],
-                         'val_mean_squared_error': [0.5586272982587484, 0.45712877659670287],
-                         'val_mean_absolute_error': [0.595368885413389, 0.530547587585537],
-                         'loss': [0.6795708956961347, 0.45963566494176616],
-                         'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
-                         'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
+                        'val_mean_squared_error': [0.5586272982587484, 0.45712877659670287],
+                        'val_mean_absolute_error': [0.595368885413389, 0.530547587585537],
+                        'loss': [0.6795708956961347, 0.45963566494176616],
+                        'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
+                        'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
         return hist

     @pytest.fixture
-    def history_var(self):
+    def history_additional_loss(self):
         hist = keras.callbacks.History()
         hist.epoch = [0, 1]
         hist.history = {'val_loss': [0.5586272982587484, 0.45712877659670287],
-                         'test_loss': [0.595368885413389, 0.530547587585537],
-                         'loss': [0.6795708956961347, 0.45963566494176616],
-                         'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
-                         'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
+                        'test_loss': [0.595368885413389, 0.530547587585537],
+                        'loss': [0.6795708956961347, 0.45963566494176616],
+                        'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
+                        'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
         return hist

+    @pytest.fixture
+    def history_with_main(self, default_history):
+        default_history.history["main_val_loss"] = [0.5586272982587484, 0.45712877659670287]
+        default_history.history["main_loss"] = [0.6795708956961347, 0.45963566494176616]
+        return default_history
+
     @pytest.fixture
     def no_init(self):
         return object.__new__(PlotModelHistory)

-    def test_plot_from_hist_obj(self, history, path):
+    def test_get_plot_metric(self, no_init, default_history):
+        history = default_history.history
+        metric = no_init._get_plot_metric(history, plot_metric="loss", main_branch=False)
+        assert metric == "loss"
+        metric = no_init._get_plot_metric(history, plot_metric="mean_squared_error", main_branch=False)
+        assert metric == "mean_squared_error"
+
+    def test_get_plot_metric_short_metric(self, no_init, default_history):
+        history = default_history.history
+        metric = no_init._get_plot_metric(history, plot_metric="mse", main_branch=False)
+        assert metric == "mean_squared_error"
+        metric = no_init._get_plot_metric(history, plot_metric="mae", main_branch=False)
+        assert metric == "mean_absolute_error"
+
+    def test_get_plot_metric_main_branch(self, no_init, history_with_main):
+        history = history_with_main.history
+        metric = no_init._get_plot_metric(history, plot_metric="loss", main_branch=True)
+        assert metric == "main_loss"
+
+    def test_filter_columns(self, no_init):
+        no_init._plot_metric = "loss"
+        res = no_init._filter_columns({'loss': None, 'another_loss': None, 'val_loss': None, 'wrong': None})
+        assert res == ['another_loss']
+        no_init._plot_metric = "mean_squared_error"
+        res = no_init._filter_columns({'mean_squared_error': None, 'another_loss': None, 'val_mean_squared_error': None,
+                                       'wrong': None})
+        assert res == []
+
+    def test_plot_from_hist_obj(self, default_history, path):
         assert "hist_obj.pdf" not in os.listdir(path)
-        PlotModelHistory(os.path.join(path, "hist_obj.pdf"), history)
+        PlotModelHistory(os.path.join(path, "hist_obj.pdf"), default_history)
         assert "hist_obj.pdf" in os.listdir(path)

-    def test_plot_from_hist_dict(self, history, path):
+    def test_plot_from_hist_dict(self, default_history, path):
         assert "hist_dict.pdf" not in os.listdir(path)
-        PlotModelHistory(os.path.join(path, "hist_dict.pdf"), history.history)
+        PlotModelHistory(os.path.join(path, "hist_dict.pdf"), default_history.history)
         assert "hist_dict.pdf" in os.listdir(path)

-    def test_plot_additional_loss(self, history_var, path):
+    def test_plot_additional_loss(self, history_additional_loss, path):
         assert "hist_additional.pdf" not in os.listdir(path)
-        PlotModelHistory(os.path.join(path, "hist_additional.pdf"), history_var)
+        PlotModelHistory(os.path.join(path, "hist_additional.pdf"), history_additional_loss)
         assert "hist_additional.pdf" in os.listdir(path)

-    def test_filter_list(self, no_init):
-        res = no_init._filter_columns({'loss': None, 'another_loss': None, 'val_loss': None, 'wrong': None})
-        assert res == ['another_loss']

 class TestPlotModelLearningRate:
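Taken together, the extended interface can be called directly on a history dict; a usage sketch, assuming the repo is on the import path (the filename and key values below are hypothetical):

# Sketch: plot only the main branch of a branched training history.
from src.plotting.training_monitoring import PlotModelHistory

history = {"loss": [0.7, 0.5], "val_loss": [0.6, 0.5],
           "main_loss": [0.7, 0.4], "val_main_loss": [0.6, 0.5]}
# selects 'main_loss'/'val_main_loss' via the shortest-key heuristic and
# writes the plot to the given (hypothetical) path
PlotModelHistory(filename="/tmp/example_history_main_loss.pdf", history=history,
                 plot_metric="loss", main_branch=True)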