import copy
import glob
import json
import time
import logging
import os
import shutil
from typing import Callable

import tensorflow.keras as keras
import mock
import pytest
from tensorflow.keras.callbacks import History

from mlair.data_handler import DataCollection, KerasIterator, DefaultDataHandler
from mlair.helpers import PyTestRegex
from mlair.model_modules.fully_connected_networks import FCN
from mlair.model_modules.flatten import flatten_tail
from mlair.model_modules.inception_model import InceptionModelBase
from mlair.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler, EpoTimingCallback
from mlair.run_modules.run_environment import RunEnvironment
from mlair.run_modules.training import Training


def my_test_model(activation, window_history_size, channels, output_size, dropout_rate, add_minor_branch=False):
    inception_model = InceptionModelBase()
    conv_settings_dict1 = {
        'tower_1': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (3, 1), 'activation': activation},
        'tower_2': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (5, 1), 'activation': activation},
    }
    pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 8 * 2, 'activation': activation}
    X_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels))
    X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1)
    if add_minor_branch:
        out = [flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=4,
                            output_activation='linear', reduction_filter=64,
                            name='Minor_1', dropout_rate=dropout_rate)]
    else:
        out = []
    X_in = keras.layers.Dropout(dropout_rate)(X_in)
    out.append(flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=output_size,
                            output_activation='linear', reduction_filter=64,
                            name='Main', dropout_rate=dropout_rate))
    return keras.Model(inputs=X_input, outputs=out)


class TestTraining:

    @pytest.fixture
    def init_without_run(self, path: str, model: keras.Model, callbacks: CallbackHandler, model_path, batch_path):
        obj = object.__new__(Training)
        super(Training, obj).__init__()
        obj.model = model
        obj.train_set = None
        obj.val_set = None
        obj.test_set = None
        obj.batch_size = 256
        obj.epochs = 2
        clbk, hist, lr = callbacks
        obj.callbacks = clbk
        obj.lr_sc = lr
        obj.hist = hist
        obj.experiment_name = "TestExperiment"
        obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_train_gen"), "general.train")
        obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_val_gen"), "general.val")
        obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_test_gen"), "general.test")
        if not os.path.exists(path):
            os.makedirs(path)
        obj.data_store.set("experiment_path", path, "general")
        os.makedirs(batch_path)
        obj.data_store.set("batch_path", batch_path, "general")
        os.makedirs(model_path)
        obj.data_store.set("model_path", model_path, "general")
        obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
        obj.data_store.set("experiment_name", "TestExperiment", "general")
        path_plot = os.path.join(path, "plots")
        os.makedirs(path_plot)
        obj.data_store.set("plot_path", path_plot, "general")
        obj._train_model = True
        obj._create_new_model = False
        try:
            yield obj
        finally:
            if os.path.exists(path):
                shutil.rmtree(path)
            try:
                RunEnvironment().__del__()
            except AssertionError:
                pass
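    # The fixtures below provide pre-filled callback objects (LearningRateDecay, History,
    # EpoTimingCallback) with two epochs of dummy values; they back the save_callbacks_* and
    # monitoring plot tests further down.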
    @pytest.fixture
    def learning_rate(self):
        lr = LearningRateDecay()
        lr.lr = {"lr": [0.01, 0.0094]}
        return lr

    @pytest.fixture
    def history(self):
        h = History()
        h.epoch = [0, 1]
        h.history = {'val_loss': [0.5586272982587484, 0.45712877659670287],
                     'val_mean_squared_error': [0.5586272982587484, 0.45712877659670287],
                     'val_mean_absolute_error': [0.595368885413389, 0.530547587585537],
                     'loss': [0.6795708956961347, 0.45963566494176616],
                     'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
                     'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
        h.model = mock.MagicMock()
        return h

    @pytest.fixture
    def epo_timing(self):
        epo_timing = EpoTimingCallback()
        epo_timing.epoch = [0, 1]
        epo_timing.epo_timing = {"epo_timing": [0.1, 0.2]}
        return epo_timing

    @pytest.fixture
    def path(self):
        return os.path.join(os.path.dirname(__file__), "TestExperiment")

    @pytest.fixture
    def model_path(self, path):
        return os.path.join(path, "model")

    @pytest.fixture
    def batch_path(self, path):
        return os.path.join(path, "batch")

    @pytest.fixture
    def window_history_size(self):
        return 7

    @pytest.fixture
    def window_lead_time(self):
        return 2

    @pytest.fixture
    def statistics_per_var(self):
        return {'o3': 'dma8eu', 'temp': 'maximum'}

    @pytest.fixture
    def data_collection(self, path, window_history_size, window_lead_time, statistics_per_var):
        data_prep = DefaultDataHandler.build(['DEBW107'], data_path=os.path.join(path, 'data'),
                                             experiment_path=os.path.join(path, 'exp_path'),
                                             statistics_per_var=statistics_per_var, station_type="background",
                                             network="AIRBASE", sampling="daily", target_dim="variables",
                                             target_var="o3", time_dim="datetime",
                                             window_history_size=window_history_size,
                                             window_lead_time=window_lead_time, name_affix="train")
        return DataCollection([data_prep])

    @pytest.fixture
    def model(self, window_history_size, window_lead_time, statistics_per_var):
        channels = len(list(statistics_per_var.keys()))
        return FCN([(window_history_size + 1, 1, channels)], [window_lead_time])

    @pytest.fixture
    def callbacks(self, path):
        clbk = CallbackHandler()
        hist = HistoryAdvanced()
        epo_timing = EpoTimingCallback()
        clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
        lr = LearningRateDecay()
        clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
        clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing")
        clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
                                     save_best_only=True)
        return clbk, hist, lr

    @pytest.fixture
    def ready_to_train(self, data_collection: DataCollection, init_without_run: Training, batch_path: str):
        batch_size = init_without_run.batch_size
        model = init_without_run.model
        init_without_run.train_set = KerasIterator(data_collection, batch_size, batch_path, model=model, name="train")
        init_without_run.val_set = KerasIterator(data_collection, batch_size, batch_path, model=model, name="val")
        init_without_run.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
        return init_without_run

    @pytest.fixture
    def ready_to_run(self, data_collection, init_without_run):
        obj = init_without_run
        obj.data_store.set("data_collection", data_collection, "general.train")
        obj.data_store.set("data_collection", data_collection, "general.val")
        obj.data_store.set("data_collection", data_collection, "general.test")
        obj.model.compile(**obj.model.compile_options)
        keras.utils.get_custom_objects().update(obj.model.custom_objects)
        return obj
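    # ready_to_init populates a fresh RunEnvironment data store with everything Training()
    # reads during construction (model, callbacks, paths, batch size, epochs, flags), so the
    # constructor-based tests below can instantiate Training() directly.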
    @pytest.fixture
    def ready_to_init(self, data_collection, model, callbacks, path, model_path, batch_path):
        if not os.path.exists(path):
            os.makedirs(path)
        os.makedirs(model_path)
        obj = RunEnvironment()
        obj.data_store.set("data_collection", data_collection, "general.train")
        obj.data_store.set("data_collection", data_collection, "general.val")
        obj.data_store.set("data_collection", data_collection, "general.test")
        model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
        obj.data_store.set("model", model, "general.model")
        obj.data_store.set("model_path", model_path, "general")
        obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
        obj.data_store.set("batch_size", 256, "general")
        obj.data_store.set("epochs", 2, "general")
        clbk, hist, lr = callbacks
        obj.data_store.set("callbacks", clbk, "general.model")
        obj.data_store.set("lr_decay", lr, "general.model")
        obj.data_store.set("hist", hist, "general.model")
        obj.data_store.set("experiment_name", "TestExperiment", "general")
        obj.data_store.set("experiment_path", path, "general")
        obj.data_store.set("train_model", True, "general")
        obj.data_store.set("create_new_model", True, "general")
        os.makedirs(batch_path)
        obj.data_store.set("batch_path", batch_path, "general")
        path_plot = os.path.join(path, "plots")
        os.makedirs(path_plot)
        obj.data_store.set("plot_path", path_plot, "general")
        yield obj
        if os.path.exists(path):
            shutil.rmtree(path)

    @staticmethod
    def create_training_obj(epochs, path, data_collection, batch_path, model_path, statistics_per_var,
                            window_history_size, window_lead_time) -> Training:
        channels = len(list(statistics_per_var.keys()))
        model = FCN([(window_history_size + 1, 1, channels)], [window_lead_time])

        obj = object.__new__(Training)
        super(Training, obj).__init__()
        obj.model = model
        obj.train_set = None
        obj.val_set = None
        obj.test_set = None
        obj.batch_size = 256
        obj.epochs = epochs

        clbk = CallbackHandler()
        hist = HistoryAdvanced()
        epo_timing = EpoTimingCallback()
        clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
        lr = LearningRateDecay()
        clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
        clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing")
        clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
                                     save_best_only=True)
        obj.callbacks = clbk
        obj.lr_sc = lr
        obj.hist = hist
        obj.experiment_name = "TestExperiment"
        obj.data_store.set("data_collection", data_collection, "general.train")
        obj.data_store.set("data_collection", data_collection, "general.val")
        obj.data_store.set("data_collection", data_collection, "general.test")
        if not os.path.exists(path):
            os.makedirs(path)
        obj.data_store.set("experiment_path", path, "general")
        os.makedirs(batch_path, exist_ok=True)
        obj.data_store.set("batch_path", batch_path, "general")
        os.makedirs(model_path, exist_ok=True)
        obj.data_store.set("model_path", model_path, "general")
        obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
        obj.data_store.set("experiment_name", "TestExperiment", "general")
        path_plot = os.path.join(path, "plots")
        os.makedirs(path_plot, exist_ok=True)
        obj.data_store.set("plot_path", path_plot, "general")
        obj._train_model = True
        obj._create_new_model = False
        obj.model.compile(**obj.model.compile_options)
        return obj

    def test_init(self, ready_to_init):
        assert isinstance(Training(), Training)  # just test, if nothing fails
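    # test_no_training expects an INFO record (level 20) once the train_model flag is disabled.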
    def test_no_training(self, ready_to_init, caplog):
        caplog.set_level(logging.INFO)
        ready_to_init.data_store.set("train_model", False)
        Training()
        message = "No training has started, because train_model parameter was false."
        assert caplog.record_tuples[-2] == ("root", 20, message)

    def test_run(self, ready_to_run):
        assert ready_to_run._run() is None  # just test, if nothing fails

    def test_make_predict_function(self, init_without_run):
        assert hasattr(init_without_run.model, "predict_function") is True
        assert init_without_run.model.predict_function is None
        init_without_run.make_predict_function()
        assert isinstance(init_without_run.model.predict_function, Callable)

    def test_set_gen(self, init_without_run):
        assert init_without_run.train_set is None
        init_without_run._set_gen("train")
        assert isinstance(init_without_run.train_set, KerasIterator)
        assert init_without_run.train_set._collection.return_value == "mock_train_gen"

    def test_set_generators(self, init_without_run):
        sets = ["train", "val"]
        assert all([getattr(init_without_run, f"{obj}_set") is None for obj in sets])
        init_without_run.set_generators()
        assert not all([getattr(init_without_run, f"{obj}_set") is None for obj in sets])
        assert all(
            [getattr(init_without_run, f"{obj}_set")._collection.return_value == f"mock_{obj}_gen" for obj in sets])

    def test_train(self, ready_to_train, path):
        assert ready_to_train.model.history is None
        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
        ready_to_train.train()
        assert sorted(list(ready_to_train.model.history.history.keys())) == ["loss", "val_loss"]
        assert ready_to_train.model.history.epoch == [0, 1]
        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2

    def test_save_model(self, init_without_run, model_path, caplog):
        caplog.set_level(logging.DEBUG)
        model_name = "test_model.h5"
        assert model_name not in os.listdir(model_path)
        init_without_run.save_model()
        message = PyTestRegex(f"save model to {os.path.join(model_path, model_name)}")
        assert caplog.record_tuples[1] == ("root", 10, message)
        assert model_name in os.listdir(model_path)

    def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
        init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
        assert "history.json" in os.listdir(model_path)

    def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
        init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
        assert "history_lr.json" in os.listdir(model_path)

    def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, epo_timing, model_path):
        init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
        with open(os.path.join(model_path, "history.json")) as jfile:
            hist = json.load(jfile)
            assert hist == history.history

    def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, epo_timing, model_path):
        init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
        with open(os.path.join(model_path, "history_lr.json")) as jfile:
            lr = json.load(jfile)
            assert lr == learning_rate.lr

    def test_create_monitoring_plots(self, init_without_run, learning_rate, history, path):
        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
        history.model.output_names = mock.MagicMock(return_value=["Main"])
        history.model.metrics_names = mock.MagicMock(return_value=["loss", "mean_squared_error"])
        init_without_run.create_monitoring_plots(history, learning_rate, epoch_best=1)
        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
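    # test_resume_training1 trains for 4 epochs and then builds a second Training object with
    # 8 epochs on the same experiment and model paths (with _create_new_model = False), so the
    # second run is intended to continue from the checkpoints written by the first run.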
    def test_resume_training1(self, path: str, model_path, batch_path, data_collection, statistics_per_var,
                              window_history_size, window_lead_time):
        obj_1st = self.create_training_obj(4, path, data_collection, batch_path, model_path, statistics_per_var,
                                           window_history_size, window_lead_time)
        keras.utils.get_custom_objects().update(obj_1st.model.custom_objects)
        assert obj_1st._run() is None
        obj_2nd = self.create_training_obj(8, path, data_collection, batch_path, model_path, statistics_per_var,
                                           window_history_size, window_lead_time)
        assert obj_2nd._run() is None