diff --git a/src/plotting/training_monitoring.py b/src/plotting/training_monitoring.py
index 617ff3135734056e51746ae2924a123a3eb34f8f..87e4071a0ac98c15c22131cd5d3418eb7c1b6976 100644
--- a/src/plotting/training_monitoring.py
+++ b/src/plotting/training_monitoring.py
@@ -25,12 +25,15 @@ class PlotModelHistory:
     metrics). The plot is saved locally. For a proper saving behaviour, the parameter filename must include the
     absolute path for the plot.
     """
-    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = True):
+    def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = False):
         """
         Sets attributes and create plot
         :param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs
             a format ending like .pdf or .png to work.
         :param history: the history object (or a dict with at least 'loss' and 'val_loss' as keys) to plot loss from
+        :param plot_metric: the metric to plot (e.g. mean_squared_error, mse, mean_absolute_error, loss, default: loss)
+        :param main_branch: switch between only looking for metrics that go with 'main' or for all occurrences (default:
+            False -> look for losses from all branches, not only from main)
         """
         if isinstance(history, keras.callbacks.History):
             history = history.history
@@ -45,10 +48,8 @@ class PlotModelHistory:
             plot_metric = "mean_squared_error"
         elif plot_metric.lower() == "mae":
             plot_metric = "mean_absolute_error"
-        available_keys = [k for k in history.keys() if plot_metric in k and ("main" in k if main_branch else True)]
-        print(available_keys)
+        available_keys = [k for k in history.keys() if plot_metric in k and ("main" in k.lower() if main_branch else True)]
         available_keys.sort(key=len)
-        print(available_keys)
         return available_keys[0]

     def _filter_columns(self, history: Dict) -> List[str]:
@@ -60,6 +61,9 @@ class PlotModelHistory:
         :return: filtered columns including all plot_metric variations except <plot_metric> and val_<plot_metric>.
         """
         cols = list(filter(lambda x: self._plot_metric in x, history.keys()))
+        # heuristic: val_<plot_metric> and <plot_metric> are always available in cols, because they are generated by
+        # the keras framework. If this metric isn't available, self._get_plot_metric() will fail beforehand (but only
+        # because it is executed first)
         cols.remove(f"val_{self._plot_metric}")
         cols.remove(self._plot_metric)
         return cols
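The main_branch default flips to False here, so a plain PlotModelHistory(...) call now considers losses from all branches, and dropping the print calls stops _get_plot_metric from writing to stdout. The selection rule itself is unchanged: collect every history key that contains the requested metric, sort by length, take the shortest. A minimal standalone sketch of that rule (hypothetical helper name and keys, not part of the patch):

    def pick_metric_key(history_keys, plot_metric, main_branch=False):
        # mirror of _get_plot_metric: keep keys containing the metric
        # (and containing "main", if requested) ...
        available_keys = [k for k in history_keys
                          if plot_metric in k and ("main" in k.lower() if main_branch else True)]
        # ... then sorting by length makes the shortest match win,
        # e.g. "loss" beats "val_loss" and "main_loss"
        available_keys.sort(key=len)
        return available_keys[0]

    assert pick_metric_key(["val_loss", "main_loss", "loss"], "loss") == "loss"
    assert pick_metric_key(["val_loss", "main_loss", "loss"], "loss", main_branch=True) == "main_loss"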
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index d15060daeb2596d048ff44f5f9948c0002069762..fee1b38b97d8f4649730f0f7110cd3801ba7db33 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -134,10 +134,17 @@ class Training(RunEnvironment):
         """
         path = self.data_store.get("plot_path", "general")
         name = self.data_store.get("experiment_name", "general")
+
+        # plot history of loss and mse (if available)
         filename = os.path.join(path, f"{name}_history_loss.pdf")
-        PlotModelHistory(filename=filename, history=history, main_branch=False)
-        filename = os.path.join(path, f"{name}_history_main_loss.pdf")
         PlotModelHistory(filename=filename, history=history)
-        filename = os.path.join(path, f"{name}_history_main_mse.pdf")
-        PlotModelHistory(filename=filename, history=history, plot_metric="mse")
+        multiple_branches_used = len(history.model.output_names) > 1  # means that there are multiple output branches
+        if multiple_branches_used:
+            filename = os.path.join(path, f"{name}_history_main_loss.pdf")
+            PlotModelHistory(filename=filename, history=history, main_branch=True)
+        if "mean_squared_error" in history.model.metrics_names:
+            filename = os.path.join(path, f"{name}_history_main_mse.pdf")
+            PlotModelHistory(filename=filename, history=history, plot_metric="mse", main_branch=multiple_branches_used)
+
+        # plot learning rate
         PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)
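create_monitoring_plots now draws the plain loss history unconditionally, adds the main-branch loss plot only when the model actually has several output branches, and adds the mse plot only when mean_squared_error was compiled as a metric. A standalone sketch of that decision table (hypothetical helper, plot names shortened to their filename suffixes):

    def expected_history_plots(output_names, metrics_names):
        plots = ["history_loss"]                   # always plotted
        if len(output_names) > 1:                  # multiple output branches
            plots.append("history_main_loss")
        if "mean_squared_error" in metrics_names:  # mse only if compiled as metric
            plots.append("history_main_mse")
        return plots

    assert expected_history_plots(["main"], ["loss"]) == ["history_loss"]
    assert expected_history_plots(["main", "aux"], ["loss", "mean_squared_error"]) == [
        "history_loss", "history_main_loss", "history_main_mse"]

Note that for a single-branch model the mse plot is still created (with main_branch=False); it just keeps the _history_main_mse filename.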
"mean_squared_error"]) init_without_run.create_monitoring_plots(history, learning_rate) assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2 diff --git a/test/test_plotting/test_training_monitoring.py b/test/test_plotting/test_training_monitoring.py index d38fc623b7fd2f5f8c779ea03757d7bf0794089a..e2d17057facee8ae07284109392f3e42e8e2fb6e 100644 --- a/test/test_plotting/test_training_monitoring.py +++ b/test/test_plotting/test_training_monitoring.py @@ -17,50 +17,81 @@ def path(): class TestPlotModelHistory: @pytest.fixture - def history(self): + def default_history(self): hist = keras.callbacks.History() hist.epoch = [0, 1] hist.history = {'val_loss': [0.5586272982587484, 0.45712877659670287], - 'val_mean_squared_error': [0.5586272982587484, 0.45712877659670287], - 'val_mean_absolute_error': [0.595368885413389, 0.530547587585537], - 'loss': [0.6795708956961347, 0.45963566494176616], - 'mean_squared_error': [0.6795708956961347, 0.45963566494176616], - 'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]} + 'val_mean_squared_error': [0.5586272982587484, 0.45712877659670287], + 'val_mean_absolute_error': [0.595368885413389, 0.530547587585537], + 'loss': [0.6795708956961347, 0.45963566494176616], + 'mean_squared_error': [0.6795708956961347, 0.45963566494176616], + 'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]} return hist @pytest.fixture - def history_var(self): + def history_additional_loss(self): hist = keras.callbacks.History() hist.epoch = [0, 1] hist.history = {'val_loss': [0.5586272982587484, 0.45712877659670287], - 'test_loss': [0.595368885413389, 0.530547587585537], - 'loss': [0.6795708956961347, 0.45963566494176616], - 'mean_squared_error': [0.6795708956961347, 0.45963566494176616], - 'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]} + 'test_loss': [0.595368885413389, 0.530547587585537], + 'loss': [0.6795708956961347, 0.45963566494176616], + 'mean_squared_error': [0.6795708956961347, 0.45963566494176616], + 'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]} return hist + @pytest.fixture + def history_with_main(self, default_history): + default_history.history["main_val_loss"] = [0.5586272982587484, 0.45712877659670287] + default_history.history["main_loss"] = [0.6795708956961347, 0.45963566494176616] + return default_history + @pytest.fixture def no_init(self): return object.__new__(PlotModelHistory) - def test_plot_from_hist_obj(self, history, path): + def test_get_plot_metric(self, no_init, default_history): + history = default_history.history + metric = no_init._get_plot_metric(history, plot_metric="loss", main_branch=False) + assert metric == "loss" + metric = no_init._get_plot_metric(history, plot_metric="mean_squared_error", main_branch=False) + assert metric == "mean_squared_error" + + def test_get_plot_metric_short_metric(self, no_init, default_history): + history = default_history.history + metric = no_init._get_plot_metric(history, plot_metric="mse", main_branch=False) + assert metric == "mean_squared_error" + metric = no_init._get_plot_metric(history, plot_metric="mae", main_branch=False) + assert metric == "mean_absolute_error" + + def test_get_plot_metric_main_branch(self, no_init, history_with_main): + history = history_with_main.history + metric = no_init._get_plot_metric(history, plot_metric="loss", main_branch=True) + assert metric == "main_loss" + + def test_filter_columns(self, no_init): + no_init._plot_metric = "loss" + res = no_init._filter_columns({'loss': None, 
diff --git a/test/test_plotting/test_training_monitoring.py b/test/test_plotting/test_training_monitoring.py
index d38fc623b7fd2f5f8c779ea03757d7bf0794089a..e2d17057facee8ae07284109392f3e42e8e2fb6e 100644
--- a/test/test_plotting/test_training_monitoring.py
+++ b/test/test_plotting/test_training_monitoring.py
@@ -17,50 +17,81 @@ def path():
 class TestPlotModelHistory:

     @pytest.fixture
-    def history(self):
+    def default_history(self):
         hist = keras.callbacks.History()
         hist.epoch = [0, 1]
         hist.history = {'val_loss': [0.5586272982587484, 0.45712877659670287],
-                       'val_mean_squared_error': [0.5586272982587484, 0.45712877659670287],
-                       'val_mean_absolute_error': [0.595368885413389, 0.530547587585537],
-                       'loss': [0.6795708956961347, 0.45963566494176616],
-                       'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
-                       'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
+                        'val_mean_squared_error': [0.5586272982587484, 0.45712877659670287],
+                        'val_mean_absolute_error': [0.595368885413389, 0.530547587585537],
+                        'loss': [0.6795708956961347, 0.45963566494176616],
+                        'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
+                        'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
         return hist

     @pytest.fixture
-    def history_var(self):
+    def history_additional_loss(self):
         hist = keras.callbacks.History()
         hist.epoch = [0, 1]
         hist.history = {'val_loss': [0.5586272982587484, 0.45712877659670287],
-                       'test_loss': [0.595368885413389, 0.530547587585537],
-                       'loss': [0.6795708956961347, 0.45963566494176616],
-                       'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
-                       'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
+                        'test_loss': [0.595368885413389, 0.530547587585537],
+                        'loss': [0.6795708956961347, 0.45963566494176616],
+                        'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
+                        'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
         return hist

+    @pytest.fixture
+    def history_with_main(self, default_history):
+        default_history.history["main_val_loss"] = [0.5586272982587484, 0.45712877659670287]
+        default_history.history["main_loss"] = [0.6795708956961347, 0.45963566494176616]
+        return default_history
+
     @pytest.fixture
     def no_init(self):
         return object.__new__(PlotModelHistory)

-    def test_plot_from_hist_obj(self, history, path):
+    def test_get_plot_metric(self, no_init, default_history):
+        history = default_history.history
+        metric = no_init._get_plot_metric(history, plot_metric="loss", main_branch=False)
+        assert metric == "loss"
+        metric = no_init._get_plot_metric(history, plot_metric="mean_squared_error", main_branch=False)
+        assert metric == "mean_squared_error"
+
+    def test_get_plot_metric_short_metric(self, no_init, default_history):
+        history = default_history.history
+        metric = no_init._get_plot_metric(history, plot_metric="mse", main_branch=False)
+        assert metric == "mean_squared_error"
+        metric = no_init._get_plot_metric(history, plot_metric="mae", main_branch=False)
+        assert metric == "mean_absolute_error"
+
+    def test_get_plot_metric_main_branch(self, no_init, history_with_main):
+        history = history_with_main.history
+        metric = no_init._get_plot_metric(history, plot_metric="loss", main_branch=True)
+        assert metric == "main_loss"
+
+    def test_filter_columns(self, no_init):
+        no_init._plot_metric = "loss"
+        res = no_init._filter_columns({'loss': None, 'another_loss': None, 'val_loss': None, 'wrong': None})
+        assert res == ['another_loss']
+        no_init._plot_metric = "mean_squared_error"
+        res = no_init._filter_columns({'mean_squared_error': None, 'another_loss': None, 'val_mean_squared_error': None,
+                                       'wrong': None})
+        assert res == []
+
+    def test_plot_from_hist_obj(self, default_history, path):
         assert "hist_obj.pdf" not in os.listdir(path)
-        PlotModelHistory(os.path.join(path, "hist_obj.pdf"), history)
+        PlotModelHistory(os.path.join(path, "hist_obj.pdf"), default_history)
         assert "hist_obj.pdf" in os.listdir(path)

-    def test_plot_from_hist_dict(self, history, path):
+    def test_plot_from_hist_dict(self, default_history, path):
         assert "hist_dict.pdf" not in os.listdir(path)
-        PlotModelHistory(os.path.join(path, "hist_dict.pdf"), history.history)
+        PlotModelHistory(os.path.join(path, "hist_dict.pdf"), default_history.history)
         assert "hist_dict.pdf" in os.listdir(path)

-    def test_plot_additional_loss(self, history_var, path):
+    def test_plot_additional_loss(self, history_additional_loss, path):
         assert "hist_additional.pdf" not in os.listdir(path)
-        PlotModelHistory(os.path.join(path, "hist_additional.pdf"), history_var)
+        PlotModelHistory(os.path.join(path, "hist_additional.pdf"), history_additional_loss)
         assert "hist_additional.pdf" in os.listdir(path)

-    def test_filter_list(self, no_init):
-        res = no_init._filter_columns({'loss': None, 'another_loss': None, 'val_loss': None, 'wrong': None})
-        assert res == ['another_loss']


 class TestPlotModelLearningRate:
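The renamed fixtures also document the public contract: PlotModelHistory accepts either a keras History object or its bare .history dict, as long as 'loss' and 'val_loss' are present. A minimal usage sketch (hypothetical output path; assumes the repository root is on sys.path; mirrors test_plot_from_hist_dict):

    from src.plotting.training_monitoring import PlotModelHistory

    # a bare dict works as well as a keras.callbacks.History object
    history = {"loss": [0.68, 0.46], "val_loss": [0.56, 0.46]}
    PlotModelHistory(filename="/tmp/history_loss.pdf", history=history)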