test_training.py

    import copy
    import glob
    import json
    import logging
    import os
    import shutil
    import time
    from typing import Callable
    
    import tensorflow.keras as keras
    import mock
    import pytest
    from tensorflow.keras.callbacks import History
    
    from mlair.data_handler import DataCollection, KerasIterator, DefaultDataHandler
    from mlair.helpers import PyTestRegex
    from mlair.model_modules.fully_connected_networks import FCN
    from mlair.model_modules.flatten import flatten_tail
    from mlair.model_modules.inception_model import InceptionModelBase
    from mlair.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler, EpoTimingCallback
    from mlair.run_modules.run_environment import RunEnvironment
    from mlair.run_modules.training import Training
    
    
    def my_test_model(activation, window_history_size, channels, output_size, dropout_rate, add_minor_branch=False):
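        """Build a small Keras test model: one inception block followed by a main flatten_tail output branch and, optionally, a minor branch."""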
        inception_model = InceptionModelBase()
        conv_settings_dict1 = {
            'tower_1': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (3, 1), 'activation': activation},
            'tower_2': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (5, 1), 'activation': activation}, }
        pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 8 * 2, 'activation': activation}
        X_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels))
        X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1)
        if add_minor_branch:
            out = [flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=4,
                                output_activation='linear', reduction_filter=64,
                                name='Minor_1', dropout_rate=dropout_rate,
                                )]
        else:
            out = []
        X_in = keras.layers.Dropout(dropout_rate)(X_in)
        out.append(flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=output_size,
                                output_activation='linear', reduction_filter=64,
                                name='Main', dropout_rate=dropout_rate,
                                ))
        return keras.Model(inputs=X_input, outputs=out)
    
    
    class TestTraining:
    
        @pytest.fixture
        def init_without_run(self, path: str, model: keras.Model, callbacks: CallbackHandler, model_path, batch_path):
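            """Create a Training instance without running its setup: the constructor is bypassed, mocked data
            collections and temporary experiment, batch, model and plot paths are wired into the data store,
            and everything is cleaned up again on teardown."""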
            obj = object.__new__(Training)
            super(Training, obj).__init__()
            obj.model = model
            obj.train_set = None
            obj.val_set = None
            obj.test_set = None
            obj.batch_size = 256
            obj.epochs = 2
            clbk, hist, lr = callbacks
            obj.callbacks = clbk
            obj.lr_sc = lr
            obj.hist = hist
            obj.experiment_name = "TestExperiment"
            obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_train_gen"), "general.train")
            obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_val_gen"), "general.val")
            obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_test_gen"), "general.test")
            if not os.path.exists(path):
                os.makedirs(path)
            obj.data_store.set("experiment_path", path, "general")
            os.makedirs(batch_path)
            obj.data_store.set("batch_path", batch_path, "general")
            os.makedirs(model_path)
            obj.data_store.set("model_path", model_path, "general")
            obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
            obj.data_store.set("experiment_name", "TestExperiment", "general")
    
            path_plot = os.path.join(path, "plots")
            os.makedirs(path_plot)
            obj.data_store.set("plot_path", path_plot, "general")
            obj._train_model = True
            obj._create_new_model = False
            try:
                yield obj
            finally:
                if os.path.exists(path):
                    shutil.rmtree(path)
                try:
                    RunEnvironment().__del__()
                except AssertionError:
                    pass
    
        @pytest.fixture
        def learning_rate(self):
            lr = LearningRateDecay()
            lr.lr = {"lr": [0.01, 0.0094]}
            return lr
    
        @pytest.fixture
        def history(self):
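            """Return a Keras History object pre-filled with two epochs of loss and metric values."""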
            h = History()
            h.epoch = [0, 1]
            h.history = {'val_loss': [0.5586272982587484, 0.45712877659670287],
                         'val_mean_squared_error': [0.5586272982587484, 0.45712877659670287],
                         'val_mean_absolute_error': [0.595368885413389, 0.530547587585537],
                         'loss': [0.6795708956961347, 0.45963566494176616],
                         'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
                         'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
            h.model = mock.MagicMock()
            return h
    
        @pytest.fixture
        def epo_timing(self):
            epo_timing = EpoTimingCallback()
            epo_timing.epoch = [0, 1]
            epo_timing.epo_timing = {"epo_timing": [0.1, 0.2]}
            return epo_timing
    
        @pytest.fixture
        def path(self):
            return os.path.join(os.path.dirname(__file__), "TestExperiment")
    
        @pytest.fixture
        def model_path(self, path):
            return os.path.join(path, "model")
    
        @pytest.fixture
        def batch_path(self, path):
            return os.path.join(path, "batch")
    
        @pytest.fixture
        def window_history_size(self):
            return 7
    
        @pytest.fixture
        def window_lead_time(self):
            return 2
    
        @pytest.fixture
        def statistics_per_var(self):
            return {'o3': 'dma8eu', 'temp': 'maximum'}
    
        @pytest.fixture
        def data_collection(self, path, window_history_size, window_lead_time, statistics_per_var):
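            """Create a DataCollection holding a single DefaultDataHandler for station DEBW107 (daily o3/temp data)."""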
            data_prep = DefaultDataHandler.build(['DEBW107'], data_path=os.path.join(path, 'data'),
                                                 experiment_path=os.path.join(path, 'exp_path'),
                                                 statistics_per_var=statistics_per_var, station_type="background",
                                                 network="AIRBASE", sampling="daily", target_dim="variables",
                                                 target_var="o3", time_dim="datetime",
                                                 window_history_size=window_history_size,
                                                 window_lead_time=window_lead_time, name_affix="train")
            return DataCollection([data_prep])
    
        @pytest.fixture
        def model(self, window_history_size, window_lead_time, statistics_per_var):
            channels = len(list(statistics_per_var.keys()))
            return FCN([(window_history_size + 1, 1, channels)], [window_lead_time])
    
        @pytest.fixture
        def callbacks(self, path):
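            """Set up a CallbackHandler with history, learning rate decay, epoch timing and model checkpoint callbacks."""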
            clbk = CallbackHandler()
            hist = HistoryAdvanced()
            epo_timing = EpoTimingCallback()
            clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
            lr = LearningRateDecay()
            clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
            clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing")
            clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
                                         save_best_only=True)
            return clbk, hist, lr
    
        @pytest.fixture
        def ready_to_train(self, data_collection: DataCollection, init_without_run: Training, batch_path: str):
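            """Attach KerasIterator train/val sets to the Training object and compile its model so train() can run."""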
            batch_size = init_without_run.batch_size
            model = init_without_run.model
            init_without_run.train_set = KerasIterator(data_collection, batch_size, batch_path, model=model, name="train")
            init_without_run.val_set = KerasIterator(data_collection, batch_size, batch_path, model=model, name="val")
            init_without_run.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
            return init_without_run
    
        @pytest.fixture
        def ready_to_run(self, data_collection, init_without_run):
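            """Replace the mocked data collections with a real one and compile the model so _run() can be executed."""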
            obj = init_without_run
            obj.data_store.set("data_collection", data_collection, "general.train")
            obj.data_store.set("data_collection", data_collection, "general.val")
            obj.data_store.set("data_collection", data_collection, "general.test")
            obj.model.compile(**obj.model.compile_options)
            keras.utils.get_custom_objects().update(obj.model.custom_objects)
            return obj
    
        @pytest.fixture
        def ready_to_init(self, data_collection, model, callbacks, path, model_path, batch_path):
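            """Prepare a RunEnvironment whose data store contains everything Training() needs to initialise itself."""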
            if not os.path.exists(path):
                os.makedirs(path)
            os.makedirs(model_path)
            obj = RunEnvironment()
            obj.data_store.set("data_collection", data_collection, "general.train")
            obj.data_store.set("data_collection", data_collection, "general.val")
            obj.data_store.set("data_collection", data_collection, "general.test")
            model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
            obj.data_store.set("model", model, "general.model")
            obj.data_store.set("model_path", model_path, "general")
            obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
            obj.data_store.set("batch_size", 256, "general")
            obj.data_store.set("epochs", 2, "general")
            clbk, hist, lr = callbacks
            obj.data_store.set("callbacks", clbk, "general.model")
            obj.data_store.set("lr_decay", lr, "general.model")
            obj.data_store.set("hist", hist, "general.model")
            obj.data_store.set("experiment_name", "TestExperiment", "general")
            obj.data_store.set("experiment_path", path, "general")
            obj.data_store.set("train_model", True, "general")
            obj.data_store.set("create_new_model", True, "general")
            os.makedirs(batch_path)
            obj.data_store.set("batch_path", batch_path, "general")
            path_plot = os.path.join(path, "plots")
            os.makedirs(path_plot)
            obj.data_store.set("plot_path", path_plot, "general")
            yield obj
            if os.path.exists(path):
                shutil.rmtree(path)
    
        @staticmethod
        def create_training_obj(epochs, path, data_collection, batch_path, model_path,
                                statistics_per_var, window_history_size, window_lead_time) -> Training:
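            """Build a fully wired Training object (model, callbacks, data store entries) for the given number of epochs."""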
    
            channels = len(list(statistics_per_var.keys()))
            model = FCN([(window_history_size + 1, 1, channels)], [window_lead_time])
    
            obj = object.__new__(Training)
            super(Training, obj).__init__()
            obj.model = model
            obj.train_set = None
            obj.val_set = None
            obj.test_set = None
            obj.batch_size = 256
            obj.epochs = epochs
    
            clbk = CallbackHandler()
            hist = HistoryAdvanced()
            epo_timing = EpoTimingCallback()
            clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
            lr = LearningRateDecay()
            clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
            clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing")
            clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
                                         save_best_only=True)
            obj.callbacks = clbk
            obj.lr_sc = lr
            obj.hist = hist
            obj.experiment_name = "TestExperiment"
            obj.data_store.set("data_collection", data_collection, "general.train")
            obj.data_store.set("data_collection", data_collection, "general.val")
            obj.data_store.set("data_collection", data_collection, "general.test")
            if not os.path.exists(path):
                os.makedirs(path)
            obj.data_store.set("experiment_path", path, "general")
            os.makedirs(batch_path, exist_ok=True)
            obj.data_store.set("batch_path", batch_path, "general")
            os.makedirs(model_path, exist_ok=True)
            obj.data_store.set("model_path", model_path, "general")
            obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
            obj.data_store.set("experiment_name", "TestExperiment", "general")
    
            path_plot = os.path.join(path, "plots")
            os.makedirs(path_plot, exist_ok=True)
            obj.data_store.set("plot_path", path_plot, "general")
            obj._train_model = True
            obj._create_new_model = False
    
            obj.model.compile(**obj.model.compile_options)
            return obj
    
        def test_init(self, ready_to_init):
            assert isinstance(Training(), Training)  # just test that nothing fails
    
        def test_no_training(self, ready_to_init, caplog):
            caplog.set_level(logging.INFO)
            ready_to_init.data_store.set("train_model", False)
            Training()
            message = "No training has started, because train_model parameter was false."
            assert caplog.record_tuples[-2] == ("root", 20, message)
    
        def test_run(self, ready_to_run):
            assert ready_to_run._run() is None  # just test that nothing fails
    
        def test_make_predict_function(self, init_without_run):
            assert hasattr(init_without_run.model, "predict_function") is True
            assert init_without_run.model.predict_function is None
            init_without_run.make_predict_function()
            assert isinstance(init_without_run.model.predict_function, Callable)
    
        def test_set_gen(self, init_without_run):
            assert init_without_run.train_set is None
            init_without_run._set_gen("train")
            assert isinstance(init_without_run.train_set, KerasIterator)
            assert init_without_run.train_set._collection.return_value == "mock_train_gen"
    
        def test_set_generators(self, init_without_run):
            sets = ["train", "val"]
            assert all([getattr(init_without_run, f"{obj}_set") is None for obj in sets])
            init_without_run.set_generators()
            assert not all([getattr(init_without_run, f"{obj}_set") is None for obj in sets])
            assert all(
                [getattr(init_without_run, f"{obj}_set")._collection.return_value == f"mock_{obj}_gen" for obj in sets])
    
        def test_train(self, ready_to_train, path):
            assert ready_to_train.model.history is None
            assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
            ready_to_train.train()
            assert sorted(list(ready_to_train.model.history.history.keys())) == ["loss", "val_loss"]
            assert ready_to_train.model.history.epoch == [0, 1]
            assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
    
        def test_save_model(self, init_without_run, model_path, caplog):
            caplog.set_level(logging.DEBUG)
            model_name = "test_model.h5"
            assert model_name not in os.listdir(model_path)
            init_without_run.save_model()
            message = PyTestRegex(f"save model to {os.path.join(model_path, model_name)}")
            assert caplog.record_tuples[1] == ("root", 10, message)
            assert model_name in os.listdir(model_path)
    
        def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
            init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
            assert "history.json" in os.listdir(model_path)
    
        def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
            init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
            assert "history_lr.json" in os.listdir(model_path)
    
        def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, epo_timing, model_path):
            init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
            with open(os.path.join(model_path, "history.json")) as jfile:
                hist = json.load(jfile)
                assert hist == history.history
    
        def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, epo_timing, model_path):
            init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
            with open(os.path.join(model_path, "history_lr.json")) as jfile:
                lr = json.load(jfile)
                assert lr == learning_rate.lr
    
        def test_create_monitoring_plots(self, init_without_run, learning_rate, history, path):
            assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
            history.model.output_names = mock.MagicMock(return_value=["Main"])
            history.model.metrics_names = mock.MagicMock(return_value=["loss", "mean_squared_error"])
            init_without_run.create_monitoring_plots(history, learning_rate, epoch_best=1)
            assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
    
        def test_resume_training1(self, path: str, model_path, batch_path, data_collection, statistics_per_var,
                                  window_history_size, window_lead_time):
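            """Train for 4 epochs, then create a second Training object with 8 epochs to check that training resumes."""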
    
            obj_1st = self.create_training_obj(4, path, data_collection, batch_path, model_path, statistics_per_var,
                                               window_history_size, window_lead_time)
            keras.utils.get_custom_objects().update(obj_1st.model.custom_objects)
            assert obj_1st._run() is None
            obj_2nd = self.create_training_obj(8, path, data_collection, batch_path, model_path, statistics_per_var,
                                               window_history_size, window_lead_time)
            assert obj_2nd._run() is None