Select Git revision
build_mpi.sh
-
Stephan Schulz authored
test_training_monitoring.py 3.28 KiB
import keras
import pytest
import os
from src.plotting.training_monitoring import PlotModelLearningRate, PlotModelHistory
from src.helpers import LearningRateDecay
@pytest.fixture
def path():
    """Return the ``TestExperiment`` output directory located next to this test file.

    The directory is created on demand. ``exist_ok=True`` replaces the previous
    exists()/makedirs() pair: with that pair, another process creating the
    directory between the check and the call made ``os.makedirs`` raise
    ``FileExistsError``.

    NOTE(review): this fixture does not delete the plots written into the
    directory. The tests assert their output file is absent before plotting, so a
    rerun only passes if something else cleans this directory between runs --
    confirm, or consider pytest's ``tmp_path`` fixture instead.
    """
    p = os.path.join(os.path.dirname(__file__), "TestExperiment")
    os.makedirs(p, exist_ok=True)
    return p
class TestPlotModelHistory:
    """Tests for ``PlotModelHistory``: plotting from a History object or its dict, and loss-column filtering."""

    @staticmethod
    def _two_epoch_history(metrics):
        """Wrap ``metrics`` in a ``keras.callbacks.History`` whose epoch list is [0, 1]."""
        hist = keras.callbacks.History()
        hist.epoch = [0, 1]
        hist.history = metrics
        return hist

    @staticmethod
    def _assert_plot_written(directory, file_name, source):
        """Assert ``file_name`` is absent from ``directory``, plot ``source`` there, then assert it exists."""
        assert file_name not in os.listdir(directory)
        PlotModelHistory(os.path.join(directory, file_name), source)
        assert file_name in os.listdir(directory)

    @pytest.fixture
    def history(self):
        """Two-epoch history with train/validation loss, MSE and MAE entries."""
        return self._two_epoch_history({
            'val_loss': [0.5586272982587484, 0.45712877659670287],
            'val_mean_squared_error': [0.5586272982587484, 0.45712877659670287],
            'val_mean_absolute_error': [0.595368885413389, 0.530547587585537],
            'loss': [0.6795708956961347, 0.45963566494176616],
            'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
            'mean_absolute_error': [0.6523177288928538, 0.5363963260296364],
        })

    @pytest.fixture
    def history_var(self):
        """Two-epoch history that carries an extra ``test_loss`` entry besides the usual losses."""
        return self._two_epoch_history({
            'val_loss': [0.5586272982587484, 0.45712877659670287],
            'test_loss': [0.595368885413389, 0.530547587585537],
            'loss': [0.6795708956961347, 0.45963566494176616],
            'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
            'mean_absolute_error': [0.6523177288928538, 0.5363963260296364],
        })

    @pytest.fixture
    def no_init(self):
        """``PlotModelHistory`` instance allocated without running ``__init__``, for calling helpers directly."""
        return object.__new__(PlotModelHistory)

    def test_plot_from_hist_obj(self, history, path):
        self._assert_plot_written(path, "hist_obj.pdf", history)

    def test_plot_from_hist_dict(self, history, path):
        self._assert_plot_written(path, "hist_dict.pdf", history.history)

    def test_plot_additional_loss(self, history_var, path):
        self._assert_plot_written(path, "hist_additional.pdf", history_var)

    def test_filter_list(self, no_init):
        # Keys to filter; the values are None placeholders.
        columns = dict.fromkeys(['loss', 'another_loss', 'val_loss', 'wrong'])
        assert no_init._filter_columns(columns) == ['another_loss']
class TestPlotModelLearningRate:
    """Tests for ``PlotModelLearningRate`` fed from a ``LearningRateDecay`` object or from its ``lr`` dict."""

    @staticmethod
    def _assert_plot_written(directory, file_name, source):
        """Assert ``file_name`` is absent from ``directory``, plot ``source`` there, then assert it exists."""
        assert file_name not in os.listdir(directory)
        PlotModelLearningRate(os.path.join(directory, file_name), source)
        assert file_name in os.listdir(directory)

    @pytest.fixture
    def learning_rate(self):
        """``LearningRateDecay`` whose ``lr`` record holds ten values of 0.01 followed by ten of 0.0094."""
        lr_obj = LearningRateDecay()
        lr_obj.lr = {"lr": [0.01] * 10 + [0.0094] * 10}
        return lr_obj

    def test_plot_from_lr_obj(self, learning_rate, path):
        self._assert_plot_written(path, "lr_obj.pdf", learning_rate)

    def test_plot_from_lr_dict(self, learning_rate, path):
        self._assert_plot_written(path, "lr_dict.pdf", learning_rate.lr)