"""Tests for the training monitoring plots (model history and learning rate)."""
import os

import keras
import pytest

from mlair.model_modules.keras_extensions import LearningRateDecay
from mlair.plotting.training_monitoring import PlotModelLearningRate, PlotModelHistory

# Two-epoch metric values shared by the history fixtures below. They are kept
# as tuples and copied into fresh lists for every history entry, so no two
# entries ever share the same list object.
_VAL_MSE = (0.5586272982587484, 0.45712877659670287)
_VAL_MAE = (0.595368885413389, 0.530547587585537)
_TRAIN_MSE = (0.6795708956961347, 0.45963566494176616)
_TRAIN_MAE = (0.6523177288928538, 0.5363963260296364)


def _check_plot_is_written(plot_class, directory, file_name, data):
    """Assert that calling ``plot_class`` makes ``file_name`` appear in ``directory``.

    The file must be absent before the call and present afterwards.
    """
    assert file_name not in os.listdir(directory)
    plot_class(os.path.join(directory, file_name), data)
    assert file_name in os.listdir(directory)


@pytest.fixture
def path():
    """Return the ``TestExperiment`` directory next to this file, creating it if missing."""
    target = os.path.join(os.path.dirname(__file__), "TestExperiment")
    if os.path.exists(target):
        return target
    os.makedirs(target)
    return target


class TestPlotModelHistory:
    """Tests for ``PlotModelHistory``."""

    @pytest.fixture
    def default_history(self):
        """Keras ``History`` with two epochs of loss, MSE and MAE (train and val)."""
        history_cb = keras.callbacks.History()
        history_cb.epoch = [0, 1]
        history_cb.history = {
            'val_loss': list(_VAL_MSE),
            'val_mean_squared_error': list(_VAL_MSE),
            'val_mean_absolute_error': list(_VAL_MAE),
            'loss': list(_TRAIN_MSE),
            'mean_squared_error': list(_TRAIN_MSE),
            'mean_absolute_error': list(_TRAIN_MAE),
        }
        return history_cb

    @pytest.fixture
    def history_additional_loss(self):
        """Keras ``History`` that carries an extra ``test_loss`` entry."""
        history_cb = keras.callbacks.History()
        history_cb.epoch = [0, 1]
        history_cb.history = {
            'val_loss': list(_VAL_MSE),
            'test_loss': list(_VAL_MAE),
            'loss': list(_TRAIN_MSE),
            'mean_squared_error': list(_TRAIN_MSE),
            'mean_absolute_error': list(_TRAIN_MAE),
        }
        return history_cb

    @pytest.fixture
    def history_with_main(self, default_history):
        """Default history extended by ``main_val_loss`` and ``main_loss`` entries."""
        default_history.history["main_val_loss"] = list(_VAL_MSE)
        default_history.history["main_loss"] = list(_TRAIN_MSE)
        return default_history

    @pytest.fixture
    def no_init(self):
        """``PlotModelHistory`` instance created without running ``__init__``."""
        return object.__new__(PlotModelHistory)

    def test_get_plot_metric(self, no_init, default_history):
        """Full metric names are returned unchanged."""
        records = default_history.history
        for requested in ("loss", "mean_squared_error"):
            resolved = no_init._get_plot_metric(records, plot_metric=requested, main_branch=False)
            assert resolved == requested

    def test_get_plot_metric_short_metric(self, no_init, default_history):
        """The short names ``mse`` and ``mae`` resolve to the full metric names."""
        records = default_history.history
        expectations = (
            ("mse", "mean_squared_error"),
            ("mae", "mean_absolute_error"),
        )
        for short_name, full_name in expectations:
            resolved = no_init._get_plot_metric(records, plot_metric=short_name, main_branch=False)
            assert resolved == full_name

    def test_get_plot_metric_main_branch(self, no_init, history_with_main):
        """With ``main_branch=True`` the metric ``loss`` resolves to ``main_loss``."""
        records = history_with_main.history
        resolved = no_init._get_plot_metric(records, plot_metric="loss", main_branch=True)
        assert resolved == "main_loss"

    def test_filter_columns(self, no_init):
        """Check the columns returned by ``_filter_columns`` for two plot metrics."""
        # Plot metric "loss": only 'another_loss' is expected to be returned.
        no_init._plot_metric = "loss"
        loss_columns = dict.fromkeys(['loss', 'another_loss', 'val_loss', 'wrong'])
        assert no_init._filter_columns(loss_columns) == ['another_loss']

        # Plot metric "mean_squared_error": nothing is expected to be returned.
        no_init._plot_metric = "mean_squared_error"
        mse_columns = dict.fromkeys(
            ['mean_squared_error', 'another_loss', 'val_mean_squared_error', 'wrong'])
        assert no_init._filter_columns(mse_columns) == []

    def test_plot_from_hist_obj(self, default_history, path):
        """Plotting from a ``History`` object writes ``hist_obj.pdf``."""
        _check_plot_is_written(PlotModelHistory, path, "hist_obj.pdf", default_history)

    def test_plot_from_hist_dict(self, default_history, path):
        """Plotting from the plain history dict writes ``hist_dict.pdf``."""
        _check_plot_is_written(PlotModelHistory, path, "hist_dict.pdf", default_history.history)

    def test_plot_additional_loss(self, history_additional_loss, path):
        """Plotting a history with an additional loss writes ``hist_additional.pdf``."""
        _check_plot_is_written(PlotModelHistory, path, "hist_additional.pdf", history_additional_loss)


class TestPlotModelLearningRate:
    """Tests for ``PlotModelLearningRate``."""

    @pytest.fixture
    def learning_rate(self):
        """``LearningRateDecay`` holding ten epochs at 0.01 followed by ten at 0.0094."""
        decay = LearningRateDecay()
        decay.lr = {"lr": [0.01] * 10 + [0.0094] * 10}
        return decay

    def test_plot_from_lr_obj(self, learning_rate, path):
        """Plotting from the learning rate object writes ``lr_obj.pdf``."""
        _check_plot_is_written(PlotModelLearningRate, path, "lr_obj.pdf", learning_rate)

    def test_plot_from_lr_dict(self, learning_rate, path):
        """Plotting from the plain learning rate dict writes ``lr_dict.pdf``."""
        _check_plot_is_written(PlotModelLearningRate, path, "lr_dict.pdf", learning_rate.lr)