diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py index 3598fe0b2939fd16951608bb7846b6487ad57cc1..c91cb691464c37456840e2c3150d43f29fc4859b 100644 --- a/test/test_modules/test_training.py +++ b/test/test_modules/test_training.py @@ -12,6 +12,7 @@ from src.flatten import flatten_tail from src.modules.training import Training from src.modules.run_environment import RunEnvironment from src.data_handling.data_distributor import Distributor +from src.data_handling.data_generator import DataGenerator from src.helpers import LearningRateDecay, PyTestRegex @@ -35,16 +36,16 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m class TestTraining: @pytest.fixture - def init_without_run(self, path): + def init_without_run(self, path, model, checkpoint): obj = object.__new__(Training) super(Training, obj).__init__() - obj.model = my_test_model(keras.layers.PReLU, 5, 3, 0.1, False) + obj.model = model obj.train_set = None obj.val_set = None obj.test_set = None obj.batch_size = 256 obj.epochs = 2 - obj.checkpoint = ModelCheckpoint("model_checkpoint", monitor='val_loss', save_best_only=True, mode='auto') + obj.checkpoint = checkpoint obj.lr_sc = LearningRateDecay() obj.experiment_name = "TestExperiment" obj.data_store.put("generator", mock.MagicMock(return_value="mock_train_gen"), "general.train") @@ -74,11 +75,60 @@ class TestTraining: def path(self): return os.path.join(os.path.dirname(__file__), "TestExperiment") - def test_init(self): - pass + @pytest.fixture + def generator(self, path): + return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', + ['DEBW107'], ['o3', 'temp'], 'datetime', 'variables', + 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) + + @pytest.fixture + def model(self): + return my_test_model(keras.layers.PReLU, 7, 2, 0.1, False) + + @pytest.fixture + def checkpoint(self, path): + return ModelCheckpoint(os.path.join(path, "model_checkpoint"), monitor='val_loss', save_best_only=True) + + @pytest.fixture + def ready_to_train(self, generator, init_without_run): + init_without_run.train_set = Distributor(generator, init_without_run.model, init_without_run.batch_size) + init_without_run.val_set = Distributor(generator, init_without_run.model, init_without_run.batch_size) + init_without_run.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error) + return init_without_run + + @pytest.fixture + def ready_to_run(self, generator, init_without_run): + obj = init_without_run + obj.data_store.put("generator", generator, "general.train") + obj.data_store.put("generator", generator, "general.val") + obj.data_store.put("generator", generator, "general.test") + obj.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error) + return obj + + @pytest.fixture + def ready_to_init(self, generator, model, checkpoint, path): + os.makedirs(path) + obj = RunEnvironment() + obj.data_store.put("generator", generator, "general.train") + obj.data_store.put("generator", generator, "general.val") + obj.data_store.put("generator", generator, "general.test") + model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error) + obj.data_store.put("model", model, "general.model") + obj.data_store.put("batch_size", 256, "general.model") + obj.data_store.put("epochs", 2, "general.model") + obj.data_store.put("checkpoint", checkpoint, "general.model") + obj.data_store.put("lr_decay", LearningRateDecay(), "general.model") + obj.data_store.put("experiment_name", "TestExperiment", "general") + obj.data_store.put("experiment_path", path, "general") + yield obj + if os.path.exists(path): + shutil.rmtree(path) + + def test_init(self, ready_to_init): + assert isinstance(Training(), Training) # just test, if nothing fails - def test_run(self): - pass + def test_run(self, ready_to_run): + assert ready_to_run._run() is None # just test, if nothing fails def test_make_predict_function(self, init_without_run): assert hasattr(init_without_run.model, "predict_function") is False @@ -98,8 +148,11 @@ class TestTraining: assert not all([getattr(init_without_run, f"{obj}_set") is None for obj in sets]) assert all([getattr(init_without_run, f"{obj}_set").generator.return_value == f"mock_{obj}_gen" for obj in sets]) - def test_train(self, init_without_run): - pass + def test_train(self, ready_to_train): + assert not hasattr(ready_to_train.model, "history") + ready_to_train.train() + assert list(ready_to_train.model.history.history.keys()) == ["val_loss", "loss"] + assert ready_to_train.model.history.epoch == [0, 1] def test_save_model(self, init_without_run, path, caplog): caplog.set_level(logging.DEBUG)