Skip to content
Snippets Groups Projects
Select Git revision
  • b7dc46f15ee0b7fe8aba804325b3a322b9ce10e4
  • master default protected
  • enxhi_issue460_remove_TOAR-I_access
  • michael_issue459_preprocess_german_stations
  • sh_pollutants
  • develop protected
  • release_v2.4.0
  • michael_issue450_feat_load-ifs-data
  • lukas_issue457_feat_set-config-paths-as-parameter
  • lukas_issue454_feat_use-toar-statistics-api-v2
  • lukas_issue453_refac_advanced-retry-strategy
  • lukas_issue452_bug_update-proj-version
  • lukas_issue449_refac_load-era5-data-from-toar-db
  • lukas_issue451_feat_robust-apriori-estimate-for-short-timeseries
  • lukas_issue448_feat_load-model-from-path
  • lukas_issue447_feat_store-and-load-local-clim-apriori-data
  • lukas_issue445_feat_data-insight-plot-monthly-distribution
  • lukas_issue442_feat_bias-free-evaluation
  • lukas_issue444_feat_choose-interp-method-cams
  • 414-include-crps-analysis-and-other-ens-verif-methods-or-plots
  • lukas_issue384_feat_aqw-data-handler
  • v2.4.0 protected
  • v2.3.0 protected
  • v2.2.0 protected
  • v2.1.0 protected
  • Kleinert_etal_2022_initial_submission
  • v2.0.0 protected
  • v1.5.0 protected
  • v1.4.0 protected
  • v1.3.0 protected
  • v1.2.1 protected
  • v1.2.0 protected
  • v1.1.0 protected
  • IntelliO3-ts-v1.0_R1-submit
  • v1.0.0 protected
  • v0.12.2 protected
  • v0.12.1 protected
  • v0.12.0 protected
  • v0.11.0 protected
  • v0.10.0 protected
  • IntelliO3-ts-v1.0_initial-submit
41 results

run_mixed_sampling.py

Blame
  • test_model_setup.py 4.30 KiB
    import os
    
    import pytest
    
    from src.data_handling.data_generator import DataGenerator
    from src.helpers.datastore import EmptyScope
    from src.model_modules.keras_extensions import CallbackHandler
    from src.model_modules.model_class import AbstractModelClass, MyLittleModel
    from src.run_modules.model_setup import ModelSetup
    from src.run_modules.run_environment import RunEnvironment
    
    
    class TestModelSetup:
        """Tests for ``ModelSetup``: callback creation, model settings, model building and channel detection."""

        @pytest.fixture
        def setup(self):
            """Yield a ``ModelSetup`` instance created without running ``ModelSetup.__init__``.

            ``object.__new__`` bypasses the ``ModelSetup`` constructor; only the ``RunEnvironment`` base
            initialiser is invoked. The attributes and data store entries read by the methods under test are
            then set by hand.

            This fixture is the single teardown point for every fixture that depends on it.
            """
            obj = object.__new__(ModelSetup)
            super(ModelSetup, obj).__init__()
            obj.scope = "general.model"
            obj.model = None
            obj.callbacks_name = "placeholder_%s_str.pickle"
            obj.data_store.set("model_class", MyLittleModel)
            obj.data_store.set("lr_decay", "dummy_str", "general.model")
            obj.data_store.set("hist", "dummy_str", "general.model")
            obj.data_store.set("epochs", 2)
            obj.model_name = "%s.h5"
            yield obj
            # NOTE(review): assumes the RunEnvironment finaliser resets the shared data store so that tests do
            # not see each other's entries -- confirm in run_environment.
            RunEnvironment().__del__()

        @pytest.fixture
        def gen(self):
            """Return a ``DataGenerator`` over the ``data`` directory next to this file ('DEBW107', o3 and temp)."""
            return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
                                 'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})

        @pytest.fixture
        def setup_with_gen(self, setup, gen):
            """Extend ``setup`` with a training generator plus window and channel settings in scope "general"."""
            setup.data_store.set("generator", gen, "general.train")
            setup.data_store.set("window_history_size", gen.window_history_size, "general")
            setup.data_store.set("window_lead_time", gen.window_lead_time, "general")
            setup.data_store.set("channels", 2, "general")
            # No own teardown: the ``setup`` dependency tears down after this fixture.
            return setup

        @pytest.fixture
        def setup_with_gen_tiny(self, setup, gen):
            """Extend ``setup`` with only a training generator (no "channels" entry), for ``_set_channels``."""
            setup.data_store.set("generator", gen, "general.train")
            # No own teardown: the ``setup`` dependency tears down after this fixture.
            return setup

        @pytest.fixture
        def setup_with_model(self, setup):
            """Extend ``setup`` with an ``AbstractModelClass`` model carrying one extra attribute ``test_param``."""
            setup.model = AbstractModelClass()
            setup.model.test_param = "42"
            # No own teardown: the ``setup`` dependency tears down after this fixture.
            return setup

        @staticmethod
        def current_scope_as_set(model_cls):
            """Return the names stored directly in ``model_cls.scope`` (parent scopes excluded) as a set."""
            return set(model_cls.data_store.search_scope(model_cls.scope, current_scope_only=True))

        def test_set_callbacks(self, setup):
            """``_set_callbacks`` stores a CallbackHandler under "general.model"."""
            assert "general.model" not in setup.data_store.search_name("callbacks")
            setup.checkpoint_name = "TestName"
            setup._set_callbacks()
            assert "general.model" in setup.data_store.search_name("callbacks")
            callbacks = setup.data_store.get("callbacks", "general.model")
            # lr_decay is configured by the fixture, so there is one more callback than in the no-lr_decay test.
            assert len(callbacks.get_callbacks()) == 3

        def test_set_callbacks_no_lr_decay(self, setup):
            """Without lr_decay, no "lr_decay" callback is registered."""
            setup.data_store.set("lr_decay", None, "general.model")
            assert "general.model" not in setup.data_store.search_name("callbacks")
            setup.checkpoint_name = "TestName"
            setup._set_callbacks()
            callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model")
            assert len(callbacks.get_callbacks()) == 2
            with pytest.raises(IndexError):
                callbacks.get_callback_by_name("lr_decay")

        def test_get_model_settings(self, setup_with_model):
            """``get_model_settings`` copies model attributes (e.g. ``test_param``) into the setup's scope."""
            setup_with_model.scope = "model_test"
            with pytest.raises(EmptyScope):
                self.current_scope_as_set(setup_with_model)  # scope does not exist yet
            setup_with_model.get_model_settings()  # creates the scope and stores test_param in it
            assert {"test_param", "model_name"} <= self.current_scope_as_set(setup_with_model)

        def test_build_model(self, setup_with_gen):
            """``build_model`` instantiates a model and stores its hyperparameters in the setup's scope."""
            assert setup_with_gen.model is None
            setup_with_gen.build_model()
            assert isinstance(setup_with_gen.model, AbstractModelClass)
            expected = {"window_history_size", "window_lead_time", "channels", "dropout_rate", "regularizer", "initial_lr",
                        "optimizer", "batch_size", "activation"}
            assert expected <= self.current_scope_as_set(setup_with_gen)

        def test_set_channels(self, setup_with_gen_tiny):
            """``_set_channels`` derives the channel count from the generator and stores it in the setup's scope."""
            assert len(setup_with_gen_tiny.data_store.search_name("channels")) == 0
            setup_with_gen_tiny._set_channels()
            # presumably one channel per variable of ``gen`` ('o3', 'temp') -- TODO confirm
            assert setup_with_gen_tiny.data_store.get("channels", setup_with_gen_tiny.scope) == 2

        # The tests below are placeholders. They are skipped rather than left as bare ``pass`` bodies so that the
        # test report does not count them as passing coverage.
        @pytest.mark.skip(reason="not implemented yet")
        def test_load_weights(self):
            pass

        @pytest.mark.skip(reason="not implemented yet")
        def test_compile_model(self):
            pass

        @pytest.mark.skip(reason="not implemented yet")
        def test_run(self):
            pass

        @pytest.mark.skip(reason="not implemented yet")
        def test_init(self):
            pass