Skip to content
Snippets Groups Projects
Commit 1696e914 authored by lukas leufen's avatar lukas leufen
Browse files

implemented all tests for training. /close #17

parent fa64b136
Branches
Tags
2 merge requests!24include recent development,!20not distributed training
Pipeline #27226 passed
...@@ -12,6 +12,7 @@ from src.flatten import flatten_tail ...@@ -12,6 +12,7 @@ from src.flatten import flatten_tail
from src.modules.training import Training from src.modules.training import Training
from src.modules.run_environment import RunEnvironment from src.modules.run_environment import RunEnvironment
from src.data_handling.data_distributor import Distributor from src.data_handling.data_distributor import Distributor
from src.data_handling.data_generator import DataGenerator
from src.helpers import LearningRateDecay, PyTestRegex from src.helpers import LearningRateDecay, PyTestRegex
...@@ -35,16 +36,16 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m ...@@ -35,16 +36,16 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m
class TestTraining: class TestTraining:
@pytest.fixture @pytest.fixture
def init_without_run(self, path): def init_without_run(self, path, model, checkpoint):
obj = object.__new__(Training) obj = object.__new__(Training)
super(Training, obj).__init__() super(Training, obj).__init__()
obj.model = my_test_model(keras.layers.PReLU, 5, 3, 0.1, False) obj.model = model
obj.train_set = None obj.train_set = None
obj.val_set = None obj.val_set = None
obj.test_set = None obj.test_set = None
obj.batch_size = 256 obj.batch_size = 256
obj.epochs = 2 obj.epochs = 2
obj.checkpoint = ModelCheckpoint("model_checkpoint", monitor='val_loss', save_best_only=True, mode='auto') obj.checkpoint = checkpoint
obj.lr_sc = LearningRateDecay() obj.lr_sc = LearningRateDecay()
obj.experiment_name = "TestExperiment" obj.experiment_name = "TestExperiment"
obj.data_store.put("generator", mock.MagicMock(return_value="mock_train_gen"), "general.train") obj.data_store.put("generator", mock.MagicMock(return_value="mock_train_gen"), "general.train")
...@@ -74,11 +75,60 @@ class TestTraining: ...@@ -74,11 +75,60 @@ class TestTraining:
def path(self): def path(self):
return os.path.join(os.path.dirname(__file__), "TestExperiment") return os.path.join(os.path.dirname(__file__), "TestExperiment")
def test_init(self): @pytest.fixture
pass def generator(self, path):
return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE',
['DEBW107'], ['o3', 'temp'], 'datetime', 'variables',
'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
@pytest.fixture
def model(self):
return my_test_model(keras.layers.PReLU, 7, 2, 0.1, False)
@pytest.fixture
def checkpoint(self, path):
return ModelCheckpoint(os.path.join(path, "model_checkpoint"), monitor='val_loss', save_best_only=True)
@pytest.fixture
def ready_to_train(self, generator, init_without_run):
init_without_run.train_set = Distributor(generator, init_without_run.model, init_without_run.batch_size)
init_without_run.val_set = Distributor(generator, init_without_run.model, init_without_run.batch_size)
init_without_run.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
return init_without_run
@pytest.fixture
def ready_to_run(self, generator, init_without_run):
obj = init_without_run
obj.data_store.put("generator", generator, "general.train")
obj.data_store.put("generator", generator, "general.val")
obj.data_store.put("generator", generator, "general.test")
obj.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
return obj
@pytest.fixture
def ready_to_init(self, generator, model, checkpoint, path):
os.makedirs(path)
obj = RunEnvironment()
obj.data_store.put("generator", generator, "general.train")
obj.data_store.put("generator", generator, "general.val")
obj.data_store.put("generator", generator, "general.test")
model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
obj.data_store.put("model", model, "general.model")
obj.data_store.put("batch_size", 256, "general.model")
obj.data_store.put("epochs", 2, "general.model")
obj.data_store.put("checkpoint", checkpoint, "general.model")
obj.data_store.put("lr_decay", LearningRateDecay(), "general.model")
obj.data_store.put("experiment_name", "TestExperiment", "general")
obj.data_store.put("experiment_path", path, "general")
yield obj
if os.path.exists(path):
shutil.rmtree(path)
def test_init(self, ready_to_init):
assert isinstance(Training(), Training) # just test, if nothing fails
def test_run(self): def test_run(self, ready_to_run):
pass assert ready_to_run._run() is None # just test, if nothing fails
def test_make_predict_function(self, init_without_run): def test_make_predict_function(self, init_without_run):
assert hasattr(init_without_run.model, "predict_function") is False assert hasattr(init_without_run.model, "predict_function") is False
...@@ -98,8 +148,11 @@ class TestTraining: ...@@ -98,8 +148,11 @@ class TestTraining:
assert not all([getattr(init_without_run, f"{obj}_set") is None for obj in sets]) assert not all([getattr(init_without_run, f"{obj}_set") is None for obj in sets])
assert all([getattr(init_without_run, f"{obj}_set").generator.return_value == f"mock_{obj}_gen" for obj in sets]) assert all([getattr(init_without_run, f"{obj}_set").generator.return_value == f"mock_{obj}_gen" for obj in sets])
def test_train(self, init_without_run): def test_train(self, ready_to_train):
pass assert not hasattr(ready_to_train.model, "history")
ready_to_train.train()
assert list(ready_to_train.model.history.history.keys()) == ["val_loss", "loss"]
assert ready_to_train.model.history.epoch == [0, 1]
def test_save_model(self, init_without_run, path, caplog): def test_save_model(self, init_without_run, path, caplog):
caplog.set_level(logging.DEBUG) caplog.set_level(logging.DEBUG)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment