test_training.py
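    """Tests for mlair.run_modules.training.Training."""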

    import glob
    import json
    import logging
    import os
    import shutil
    
    import keras
    import mock
    import pytest
    from keras.callbacks import History
    
    from mlair.data_handler import DataCollection, KerasIterator, DefaultDataHandler
    from mlair.helpers import PyTestRegex
    from mlair.model_modules.flatten import flatten_tail
    from mlair.model_modules.inception_model import InceptionModelBase
    from mlair.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler
    from mlair.run_modules.run_environment import RunEnvironment
    from mlair.run_modules.training import Training
    
    
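    # Helper that builds a small Keras test model: one inception block followed by a flatten_tail
    # output branch (plus an optional minor branch when add_minor_branch=True).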
    def my_test_model(activation, window_history_size, channels, output_size, dropout_rate, add_minor_branch=False):
        inception_model = InceptionModelBase()
        conv_settings_dict1 = {
            'tower_1': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (3, 1), 'activation': activation},
            'tower_2': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (5, 1), 'activation': activation}, }
        pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 8 * 2, 'activation': activation}
        X_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels))
        X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1)
        if add_minor_branch:
            out = [flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=4,
                                output_activation='linear', reduction_filter=64,
                                name='Minor_1', dropout_rate=dropout_rate,
                                )]
        else:
            out = []
        X_in = keras.layers.Dropout(dropout_rate)(X_in)
        out.append(flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=output_size,
                                output_activation='linear', reduction_filter=64,
                                name='Main', dropout_rate=dropout_rate,
                                ))
        return keras.Model(inputs=X_input, outputs=out)
    
    
    class TestTraining:
    
        @pytest.fixture
        def init_without_run(self, path: str, model: keras.Model, callbacks: CallbackHandler, model_path, batch_path):
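            # create a Training instance without calling Training.__init__ (which would trigger a run);
            # only the RunEnvironment initializer is executed so that the data store is available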
            obj = object.__new__(Training)
            super(Training, obj).__init__()
            obj.model = model
            obj.train_set = None
            obj.val_set = None
            obj.test_set = None
            obj.batch_size = 256
            obj.epochs = 2
            clbk, hist, lr = callbacks
            obj.callbacks = clbk
            obj.lr_sc = lr
            obj.hist = hist
            obj.experiment_name = "TestExperiment"
            obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_train_gen"), "general.train")
            obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_val_gen"), "general.val")
            obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_test_gen"), "general.test")
            if not os.path.exists(path):
                os.makedirs(path)
            obj.data_store.set("experiment_path", path, "general")
            os.makedirs(batch_path)
            obj.data_store.set("batch_path", batch_path, "general")
            os.makedirs(model_path)
            obj.data_store.set("model_path", model_path, "general")
            obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
            obj.data_store.set("experiment_name", "TestExperiment", "general")
    
            path_plot = os.path.join(path, "plots")
            os.makedirs(path_plot)
            obj.data_store.set("plot_path", path_plot, "general")
            obj._train_model = True
            obj._create_new_model = False
            yield obj
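            # teardown: remove the temporary experiment directory and finalise the shared RunEnvironment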
            if os.path.exists(path):
                shutil.rmtree(path)
            RunEnvironment().__del__()
    
        @pytest.fixture
        def learning_rate(self):
            lr = LearningRateDecay()
            lr.lr = {"lr": [0.01, 0.0094]}
            return lr
    
        @pytest.fixture
        def history(self):
            h = History()
            h.epoch = [0, 1]
            h.history = {'val_loss': [0.5586272982587484, 0.45712877659670287],
                         'val_mean_squared_error': [0.5586272982587484, 0.45712877659670287],
                         'val_mean_absolute_error': [0.595368885413389, 0.530547587585537],
                         'loss': [0.6795708956961347, 0.45963566494176616],
                         'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
                         'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
            h.model = mock.MagicMock()
            return h
    
        @pytest.fixture
        def path(self):
            return os.path.join(os.path.dirname(__file__), "TestExperiment")
    
        @pytest.fixture
        def model_path(self, path):
            return os.path.join(path, "model")
    
        @pytest.fixture
        def batch_path(self, path):
            return os.path.join(path, "batch")
    
        @pytest.fixture
        def window_history_size(self):
            return 7
    
        @pytest.fixture
        def window_lead_time(self):
            return 2
    
        @pytest.fixture
        def statistics_per_var(self):
            return {'o3': 'dma8eu', 'temp': 'maximum'}
    
        @pytest.fixture
        def data_collection(self, path, window_history_size, window_lead_time, statistics_per_var):
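            # DataCollection holding a single DefaultDataHandler for station DEBW107 (daily o3/temp data)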
            data_prep = DefaultDataHandler.build(['DEBW107'], data_path=os.path.join(path, 'data'),
                                                 experiment_path=os.path.join(path, 'exp_path'),
                                                 statistics_per_var=statistics_per_var, station_type="background",
                                                 network="AIRBASE", sampling="daily", target_dim="variables",
                                                 target_var="o3", time_dim="datetime",
                                                 window_history_size=window_history_size,
                                                 window_lead_time=window_lead_time, name_affix="train")
            return DataCollection([data_prep])
    
        @pytest.fixture
        def model(self, window_history_size, window_lead_time, statistics_per_var):
            channels = len(list(statistics_per_var.keys()))
            return my_test_model(keras.layers.PReLU, window_history_size, channels, window_lead_time, 0.1, False)
    
        @pytest.fixture
        def callbacks(self, path):
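            # CallbackHandler bundling HistoryAdvanced, LearningRateDecay and a model checkpoint callback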
            clbk = CallbackHandler()
            hist = HistoryAdvanced()
            clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
            lr = LearningRateDecay()
            clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
            clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
                                         save_best_only=True)
            return clbk, hist, lr
    
        @pytest.fixture
        def ready_to_train(self, data_collection: DataCollection, init_without_run: Training, batch_path: str):
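            # Training object with real KerasIterator train/val sets and a compiled model, ready for train()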
            batch_size = init_without_run.batch_size
            model = init_without_run.model
            init_without_run.train_set = KerasIterator(data_collection, batch_size, batch_path, model=model, name="train")
            init_without_run.val_set = KerasIterator(data_collection, batch_size, batch_path, model=model, name="val")
            init_without_run.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
            return init_without_run
    
        @pytest.fixture
        def ready_to_run(self, data_collection, init_without_run):
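            # Training object whose data store holds the real data collection and whose model is compiled,
            # so that _run() can be executed end to end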
            obj = init_without_run
            obj.data_store.set("data_collection", data_collection, "general.train")
            obj.data_store.set("data_collection", data_collection, "general.val")
            obj.data_store.set("data_collection", data_collection, "general.test")
            obj.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
            return obj
    
        @pytest.fixture
        def ready_to_init(self, data_collection, model, callbacks, path, model_path, batch_path):
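            # fresh RunEnvironment whose data store contains everything Training.__init__ expects
            # (model, callbacks, paths, batch size, epochs), so that Training() can be instantiated from it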
            if not os.path.exists(path):
                os.makedirs(path)
            os.makedirs(model_path)
            obj = RunEnvironment()
            obj.data_store.set("data_collection", data_collection, "general.train")
            obj.data_store.set("data_collection", data_collection, "general.val")
            obj.data_store.set("data_collection", data_collection, "general.test")
            model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
            obj.data_store.set("model", model, "general.model")
            obj.data_store.set("model_path", model_path, "general")
            obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
            obj.data_store.set("batch_size", 256, "general")
            obj.data_store.set("epochs", 2, "general")
            clbk, hist, lr = callbacks
            obj.data_store.set("callbacks", clbk, "general.model")
            obj.data_store.set("lr_decay", lr, "general.model")
            obj.data_store.set("hist", hist, "general.model")
            obj.data_store.set("experiment_name", "TestExperiment", "general")
            obj.data_store.set("experiment_path", path, "general")
            obj.data_store.set("train_model", True, "general")
            obj.data_store.set("create_new_model", True, "general")
            os.makedirs(batch_path)
            obj.data_store.set("batch_path", batch_path, "general")
            path_plot = os.path.join(path, "plots")
            os.makedirs(path_plot)
            obj.data_store.set("plot_path", path_plot, "general")
            yield obj
            if os.path.exists(path):
                shutil.rmtree(path)
    
        def test_init(self, ready_to_init):
            assert isinstance(Training(), Training)  # just check that construction does not fail
    
        def test_no_training(self, ready_to_init, caplog):
            caplog.set_level(logging.INFO)
            ready_to_init.data_store.set("train_model", False)
            Training()
            message = "No training has started, because train_model parameter was false."
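            # record level 20 corresponds to logging.INFO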
            assert caplog.record_tuples[-2] == ("root", 20, message)
    
        def test_run(self, ready_to_run):
            assert ready_to_run._run() is None  # just check that the run completes without failure
    
        def test_make_predict_function(self, init_without_run):
            assert hasattr(init_without_run.model, "predict_function") is False
            init_without_run.make_predict_function()
            assert hasattr(init_without_run.model, "predict_function")
    
        def test_set_gen(self, init_without_run):
            assert init_without_run.train_set is None
            init_without_run._set_gen("train")
            assert isinstance(init_without_run.train_set, KerasIterator)
            assert init_without_run.train_set._collection.return_value == "mock_train_gen"
    
        def test_set_generators(self, init_without_run):
            sets = ["train", "val", "test"]
            assert all([getattr(init_without_run, f"{obj}_set") is None for obj in sets])
            init_without_run.set_generators()
            assert not all([getattr(init_without_run, f"{obj}_set") is None for obj in sets])
            assert all(
                [getattr(init_without_run, f"{obj}_set")._collection.return_value == f"mock_{obj}_gen" for obj in sets])
    
        def test_train(self, ready_to_train, path):
            assert not hasattr(ready_to_train.model, "history")
            assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
            ready_to_train.train()
            assert list(ready_to_train.model.history.history.keys()) == ["val_loss", "loss"]
            assert ready_to_train.model.history.epoch == [0, 1]
            assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
    
        def test_save_model(self, init_without_run, model_path, caplog):
            caplog.set_level(logging.DEBUG)
            model_name = "test_model.h5"
            assert model_name not in os.listdir(model_path)
            init_without_run.save_model()
            message = PyTestRegex(f"save best model to {os.path.join(model_path, model_name)}")
            assert caplog.record_tuples[1] == ("root", 10, message)
            assert model_name in os.listdir(model_path)
    
        def test_load_best_model_no_weights(self, init_without_run, caplog):
            caplog.set_level(logging.DEBUG)
            init_without_run.load_best_model("notExisting")
            assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting"))
            assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))
    
        def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, model_path):
            init_without_run.save_callbacks_as_json(history, learning_rate)
            assert "history.json" in os.listdir(model_path)
    
        def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, model_path):
            init_without_run.save_callbacks_as_json(history, learning_rate)
            assert "history_lr.json" in os.listdir(model_path)
    
        def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, model_path):
            init_without_run.save_callbacks_as_json(history, learning_rate)
            with open(os.path.join(model_path, "history.json")) as jfile:
                hist = json.load(jfile)
                assert hist == history.history
    
        def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, model_path):
            init_without_run.save_callbacks_as_json(history, learning_rate)
            with open(os.path.join(model_path, "history_lr.json")) as jfile:
                lr = json.load(jfile)
                assert lr == learning_rate.lr
    
        def test_create_monitoring_plots(self, init_without_run, learning_rate, history, path):
            assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
            history.model.output_names = mock.MagicMock(return_value=["Main"])
            history.model.metrics_names = mock.MagicMock(return_value=["loss", "mean_squared_error"])
            init_without_run.create_monitoring_plots(history, learning_rate)
            assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2