import glob
import json
import logging
import os
import shutil

import keras
import mock
import pytest
from keras.callbacks import History

from mlair.data_handler import DataCollection, KerasIterator, DefaultDataHandler
from mlair.helpers import PyTestRegex
from mlair.model_modules.flatten import flatten_tail
from mlair.model_modules.inception_model import InceptionModelBase
from mlair.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler
from mlair.run_modules.run_environment import RunEnvironment
from mlair.run_modules.training import Training


def my_test_model(activation, window_history_size, channels, output_size, dropout_rate, add_minor_branch=False):
    """Build a small inception-style Keras model used by the training tests.

    The model consists of one inception block (two convolution towers plus a pooling tower)
    followed by a ``flatten_tail`` output head named ``'Main'``. If ``add_minor_branch`` is set,
    an additional head named ``'Minor_1'`` with 4 outputs is attached directly to the inception
    block. It is listed before ``'Main'`` in the model outputs.

    :param activation: activation passed to every tower and to the output heads
    :param window_history_size: number of history steps; the input has ``window_history_size + 1`` steps
    :param channels: size of the last input axis (number of input variables)
    :param output_size: number of output neurons of the ``'Main'`` head
    :param dropout_rate: rate of the dropout layer in front of the ``'Main'`` head (also passed to the heads)
    :param add_minor_branch: if True, also add the ``'Minor_1'`` output head
    :return: an uncompiled ``keras.Model``
    """
    inception_model = InceptionModelBase()
    conv_settings_dict1 = {
        'tower_1': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (3, 1), 'activation': activation},
        'tower_2': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (5, 1), 'activation': activation}, }
    pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 8 * 2, 'activation': activation}
    X_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels))
    X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1)
    if add_minor_branch:
        # The minor head branches off before the dropout layer that feeds the main head.
        out = [flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=4,
                            output_activation='linear', reduction_filter=64,
                            name='Minor_1', dropout_rate=dropout_rate,
                            )]
    else:
        out = []
    X_in = keras.layers.Dropout(dropout_rate)(X_in)
    out.append(flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=output_size,
                            output_activation='linear', reduction_filter=64,
                            name='Main', dropout_rate=dropout_rate,
                            ))
    return keras.Model(inputs=X_input, outputs=out)


class TestTraining:
    """Tests for the ``Training`` run module."""

    @pytest.fixture
    def init_without_run(self, path: str, model: keras.Model,
                         callbacks: "tuple[CallbackHandler, HistoryAdvanced, LearningRateDecay]",
                         model_path, batch_path):
        """Yield a ``Training`` object whose own ``__init__`` (and therefore its run) was skipped.

        Attributes are assigned by hand. The data store receives mocked data collections for
        train/val/test, and the experiment, batch, model and plot directories are created below
        ``path``. Cleanup sits in a ``finally`` clause so that it also runs if the setup itself
        raises. Otherwise a half-finished setup would leave ``path`` on disk, and every later test
        would fail in ``os.makedirs(path)``.
        """
        obj = object.__new__(Training)
        super(Training, obj).__init__()
        try:
            obj.model = model
            obj.train_set = None
            obj.val_set = None
            obj.test_set = None
            obj.batch_size = 256
            obj.epochs = 2
            clbk, hist, lr = callbacks
            obj.callbacks = clbk
            obj.lr_sc = lr
            obj.hist = hist
            obj.experiment_name = "TestExperiment"
            obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_train_gen"), "general.train")
            obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_val_gen"), "general.val")
            obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_test_gen"), "general.test")
            os.makedirs(path)
            obj.data_store.set("experiment_path", path, "general")
            os.makedirs(batch_path)
            obj.data_store.set("batch_path", batch_path, "general")
            os.makedirs(model_path)
            obj.data_store.set("model_path", model_path, "general")
            obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
            obj.data_store.set("experiment_name", "TestExperiment", "general")
            path_plot = os.path.join(path, "plots")
            os.makedirs(path_plot)
            obj.data_store.set("plot_path", path_plot, "general")
            obj._trainable = True
            obj._create_new_model = False
            yield obj
        finally:
            if os.path.exists(path):
                shutil.rmtree(path)
            # NOTE(review): presumably resets the shared RunEnvironment state/data store -- confirm.
            RunEnvironment().__del__()

    @pytest.fixture
    def learning_rate(self):
        """Return a ``LearningRateDecay`` callback with a fixed two-epoch learning-rate log."""
        lr = LearningRateDecay()
        lr.lr = {"lr": [0.01, 0.0094]}
        return lr

    @pytest.fixture
    def history(self):
        """Return a ``History`` callback filled with fixed values for two epochs and a mocked model."""
        h = History()
        h.epoch = [0, 1]
        h.history = {'val_loss': [0.5586272982587484, 0.45712877659670287],
                     'val_mean_squared_error': [0.5586272982587484, 0.45712877659670287],
                     'val_mean_absolute_error': [0.595368885413389, 0.530547587585537],
                     'loss': [0.6795708956961347, 0.45963566494176616],
                     'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
                     'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
        h.model = mock.MagicMock()
        return h

    @pytest.fixture
    def path(self):
        """Experiment directory ``TestExperiment`` next to this test file (created by other fixtures)."""
        return os.path.join(os.path.dirname(__file__), "TestExperiment")

    @pytest.fixture
    def model_path(self, path):
        """Model directory inside the experiment directory."""
        return os.path.join(path, "model")

    @pytest.fixture
    def batch_path(self, path):
        """Batch directory inside the experiment directory."""
        return os.path.join(path, "batch")

    @pytest.fixture
    def window_history_size(self):
        return 7

    @pytest.fixture
    def window_lead_time(self):
        return 2

    @pytest.fixture
    def statistics_per_var(self):
        return {'o3': 'dma8eu', 'temp': 'maximum'}

    @pytest.fixture
    def data_collection(self, path, window_history_size, window_lead_time, statistics_per_var):
        """Build a ``DataCollection`` holding one data handler for station DEBW107 from the local test data."""
        data_prep = DefaultDataHandler.build(['DEBW107'], data_path=os.path.join(os.path.dirname(__file__), 'data'),
                                             statistics_per_var=statistics_per_var, station_type="background",
                                             network="AIRBASE", sampling="daily", target_dim="variables",
                                             target_var="o3", time_dim="datetime",
                                             window_history_size=window_history_size,
                                             window_lead_time=window_lead_time, name_affix="train")
        return DataCollection([data_prep])

    @pytest.fixture
    def model(self, window_history_size, window_lead_time, statistics_per_var):
        """Uncompiled test model with one input channel per variable and one output per lead time step."""
        channels = len(list(statistics_per_var.keys()))
        return my_test_model(keras.layers.PReLU, window_history_size, channels, window_lead_time, 0.1, False)

    @pytest.fixture
    def callbacks(self, path):
        """Return the tuple ``(callback_handler, history, learning_rate)``.

        The history and learning-rate callbacks are registered in the handler together with their
        checkpoint pickle files. A model checkpoint monitoring ``val_loss`` is registered as well.
        """
        clbk = CallbackHandler()
        hist = HistoryAdvanced()
        clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
        lr = LearningRateDecay()
        clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
        clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
                                     save_best_only=True)
        return clbk, hist, lr

    @pytest.fixture
    def ready_to_train(self, data_collection: DataCollection, init_without_run: Training, batch_path: str):
        """``init_without_run`` with real train/val ``KerasIterator``s and a compiled model."""
        batch_size = init_without_run.batch_size
        model = init_without_run.model
        init_without_run.train_set = KerasIterator(data_collection, batch_size, batch_path, model=model, name="train")
        init_without_run.val_set = KerasIterator(data_collection, batch_size, batch_path, model=model, name="val")
        init_without_run.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
        return init_without_run

    @pytest.fixture
    def ready_to_run(self, data_collection, init_without_run):
        """``init_without_run`` with the real data collection stored for all subsets and a compiled model."""
        obj = init_without_run
        obj.data_store.set("data_collection", data_collection, "general.train")
        obj.data_store.set("data_collection", data_collection, "general.val")
        obj.data_store.set("data_collection", data_collection, "general.test")
        obj.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
        return obj

    @pytest.fixture
    def ready_to_init(self, data_collection, model, callbacks, path, model_path, batch_path):
        """Yield a ``RunEnvironment`` whose data store holds everything ``Training()`` needs.

        The experiment, model, batch and plot directories are created on disk. ``path`` is removed
        in a ``finally`` clause, so the cleanup also runs if the setup raises before ``yield``.
        """
        try:
            os.makedirs(path)
            os.makedirs(model_path)
            obj = RunEnvironment()
            obj.data_store.set("data_collection", data_collection, "general.train")
            obj.data_store.set("data_collection", data_collection, "general.val")
            obj.data_store.set("data_collection", data_collection, "general.test")
            model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
            obj.data_store.set("model", model, "general.model")
            obj.data_store.set("model_path", model_path, "general")
            obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
            obj.data_store.set("batch_size", 256, "general")
            obj.data_store.set("epochs", 2, "general")
            clbk, hist, lr = callbacks
            obj.data_store.set("callbacks", clbk, "general.model")
            obj.data_store.set("lr_decay", lr, "general.model")
            obj.data_store.set("hist", hist, "general.model")
            obj.data_store.set("experiment_name", "TestExperiment", "general")
            obj.data_store.set("experiment_path", path, "general")
            obj.data_store.set("trainable", True, "general")
            obj.data_store.set("create_new_model", True, "general")
            os.makedirs(batch_path)
            obj.data_store.set("batch_path", batch_path, "general")
            path_plot = os.path.join(path, "plots")
            os.makedirs(path_plot)
            obj.data_store.set("plot_path", path_plot, "general")
            yield obj
        finally:
            if os.path.exists(path):
                shutil.rmtree(path)

    def test_init(self, ready_to_init):
        assert isinstance(Training(), Training)  # just test, if nothing fails

    def test_no_training(self, ready_to_init, caplog):
        """With ``trainable`` set to False, ``Training()`` logs that no training has started."""
        caplog.set_level(logging.INFO)
        ready_to_init.data_store.set("trainable", False)
        Training()
        message = "No training has started, because trainable parameter was false."
        assert caplog.record_tuples[-2] == ("root", 20, message)

    def test_run(self, ready_to_run):
        assert ready_to_run._run() is None  # just test, if nothing fails

    def test_make_predict_function(self, init_without_run):
        """``make_predict_function`` attaches a ``predict_function`` attribute to the model."""
        assert not hasattr(init_without_run.model, "predict_function")
        init_without_run.make_predict_function()
        assert hasattr(init_without_run.model, "predict_function")

    def test_set_gen(self, init_without_run):
        """``_set_gen("train")`` wraps the stored train data collection in a ``KerasIterator``."""
        assert init_without_run.train_set is None
        init_without_run._set_gen("train")
        assert isinstance(init_without_run.train_set, KerasIterator)
        assert init_without_run.train_set._collection.return_value == "mock_train_gen"

    def test_set_generators(self, init_without_run):
        """``set_generators`` populates train, val and test sets from their stored data collections."""
        sets = ["train", "val", "test"]
        assert all(getattr(init_without_run, f"{name}_set") is None for name in sets)
        init_without_run.set_generators()
        # every subset must have been populated, not just some of them
        assert all(getattr(init_without_run, f"{name}_set") is not None for name in sets)
        assert all(getattr(init_without_run, f"{name}_set")._collection.return_value == f"mock_{name}_gen"
                   for name in sets)

    def test_train(self, ready_to_train, path):
        """Training for two epochs records a history and writes two monitoring plots."""
        assert not hasattr(ready_to_train.model, "history")
        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
        ready_to_train.train()
        assert list(ready_to_train.model.history.history.keys()) == ["val_loss", "loss"]
        assert ready_to_train.model.history.epoch == [0, 1]
        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2

    def test_save_model(self, init_without_run, model_path, caplog):
        """``save_model`` logs the target file at DEBUG level and writes it to ``model_path``."""
        caplog.set_level(logging.DEBUG)
        model_name = "test_model.h5"
        assert model_name not in os.listdir(model_path)
        init_without_run.save_model()
        message = PyTestRegex(f"save best model to {os.path.join(model_path, model_name)}")
        assert caplog.record_tuples[1] == ("root", 10, message)
        assert model_name in os.listdir(model_path)

    def test_load_best_model_no_weights(self, init_without_run, caplog):
        """Loading a non-existing model logs the attempt and that no weights were reloaded."""
        caplog.set_level(logging.DEBUG)
        init_without_run.load_best_model("notExisting")
        assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting"))
        assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))

    def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, model_path):
        init_without_run.save_callbacks_as_json(history, learning_rate)
        assert "history.json" in os.listdir(model_path)

    def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, model_path):
        init_without_run.save_callbacks_as_json(history, learning_rate)
        assert "history_lr.json" in os.listdir(model_path)

    def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, model_path):
        """The written ``history.json`` round-trips to ``history.history``."""
        init_without_run.save_callbacks_as_json(history, learning_rate)
        with open(os.path.join(model_path, "history.json")) as jfile:
            hist = json.load(jfile)
            assert hist == history.history

    def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, model_path):
        """The written ``history_lr.json`` round-trips to ``learning_rate.lr``."""
        init_without_run.save_callbacks_as_json(history, learning_rate)
        with open(os.path.join(model_path, "history_lr.json")) as jfile:
            lr = json.load(jfile)
            assert lr == learning_rate.lr

    def test_create_monitoring_plots(self, init_without_run, learning_rate, history, path):
        """``create_monitoring_plots`` writes exactly two history PDFs into the plot directory."""
        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
        # NOTE(review): these attributes are set to callable MagicMocks, not to the lists themselves.
        # Iterating/len() on them therefore does not see ["Main"] / ["loss", ...]. The expected
        # count of 2 below presumably relies on that -- confirm before turning them into plain lists.
        history.model.output_names = mock.MagicMock(return_value=["Main"])
        history.model.metrics_names = mock.MagicMock(return_value=["loss", "mean_squared_error"])
        init_without_run.create_monitoring_plots(history, learning_rate)
        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2