Skip to content
Snippets Groups Projects
Select Git revision
  • e2b8a816bdb633f5edb133065dca1c2ba057e097
  • 2023 default
  • pages protected
  • 2022-matse
  • 2022
  • 2021
  • 2019
  • master
8 results

Introduction-to-Pandas--master.ipynb

Blame
  • test_model_setup.py 3.19 KiB
    import pytest
    import os
    import keras
    import mock
    
    from src.modules.model_setup import ModelSetup
    from src.modules.run_environment import RunEnvironment
    from src.data_handling.data_generator import DataGenerator
    from src.model_modules.model_class import AbstractModelClass
    from src.datastore import EmptyScope
    
    
    class TestModelSetup:
        """Unit tests for individual ``ModelSetup`` methods, run against a hand-assembled instance."""

        @pytest.fixture
        def setup(self):
            """
            Yield a ``ModelSetup`` object whose own ``__init__`` has not been executed.

            Only the initialiser of the parent class is run; ``scope`` and ``model`` are then assigned by hand.
            After the test, a ``RunEnvironment`` is created and its ``__del__`` is invoked directly (presumably to
            reset state shared between tests -- confirm against ``RunEnvironment``).
            """
            model_setup = object.__new__(ModelSetup)  # allocate without triggering ModelSetup.__init__
            super(ModelSetup, model_setup).__init__()
            model_setup.scope = "general.modeltest"
            model_setup.model = None
            yield model_setup
            RunEnvironment().__del__()

        @pytest.fixture
        def gen(self):
            """Return a ``DataGenerator`` built for the ``data`` directory that sits next to this test file."""
            data_path = os.path.join(os.path.dirname(__file__), 'data')
            statistics = {'o3': 'dma8eu', 'temp': 'maximum'}
            return DataGenerator(data_path, 'AIRBASE', 'DEBW107', ['o3', 'temp'], 'datetime', 'variables', 'o3',
                                 statistics_per_var=statistics)

        @pytest.fixture
        def setup_with_gen(self, setup, gen):
            """
            Yield the ``setup`` object after registering the generator and its window sizes in the data store.

            The generator is stored under scope ``general.train``; ``window_history_size`` and
            ``window_lead_time`` are copied from the generator into scope ``general``.
            """
            store = setup.data_store
            store.set("generator", gen, "general.train")
            for attribute in ("window_history_size", "window_lead_time"):
                store.set(attribute, getattr(gen, attribute), "general")
            yield setup
            RunEnvironment().__del__()

        @pytest.fixture
        def setup_with_model(self, setup_with_gen):
            """
            Yield ``setup_with_gen`` extended by a ``channels`` entry and an ``AbstractModelClass`` model.

            The model gets ``epochs = 2`` and ``batch_size = 256`` assigned as plain attributes.
            """
            setup_with_gen.data_store.set("channels", 2, "general")
            model = AbstractModelClass()
            setup_with_gen.model = model
            model.epochs = 2
            model.batch_size = 256
            yield setup_with_gen
            RunEnvironment().__del__()

        @staticmethod
        def current_scope_as_set(model_cls):
            """
            Collect what the data store reports for exactly the scope of ``model_cls`` and return it as a set.

            :param model_cls: object providing a ``data_store`` and a ``scope`` attribute
            :return: set built from ``data_store.search_scope(scope, current_scope_only=True)``
            """
            found = model_cls.data_store.search_scope(model_cls.scope, current_scope_only=True)
            return set(found)

        def test_set_checkpoint(self, setup):
            """``_set_checkpoint`` must register ``checkpoint`` under the scope of the setup object."""
            scope = "general.modeltest"
            assert scope not in setup.data_store.search_name("checkpoint")
            setup.checkpoint_name = "TestName"
            setup._set_checkpoint()
            assert scope in setup.data_store.search_name("checkpoint")

        def test_get_model_settings(self, setup_with_model):
            """``get_model_settings`` must store at least ``epochs`` and ``batch_size`` in the current scope."""
            # before the call nothing has been stored in this scope, so looking it up raises EmptyScope
            with pytest.raises(EmptyScope):
                self.current_scope_as_set(setup_with_model)
            # the call stores the model parameters epochs and batch_size in the scope
            setup_with_model.get_model_settings()
            missing = {"epochs", "batch_size"} - self.current_scope_as_set(setup_with_model)
            assert not missing

        def test_build_model(self, setup_with_gen):
            """``build_model`` must create the model and store all expected settings in the current scope."""
            assert setup_with_gen.model is None
            setup_with_gen.build_model()
            assert isinstance(setup_with_gen.model, AbstractModelClass)
            required = {
                "window_history_size", "window_lead_time", "channels", "dropout_rate", "regularizer", "initial_lr",
                "optimizer", "lr_decay", "epochs", "batch_size", "activation",
            }
            missing = required - self.current_scope_as_set(setup_with_gen)
            assert not missing

        def test_set_channels(self, setup_with_gen):
            """``_set_channels`` must store ``channels == 2`` for the generator used in these tests."""
            store = setup_with_gen.data_store
            assert len(store.search_name("channels")) == 0
            setup_with_gen._set_channels()
            channels = setup_with_gen.data_store.get("channels", setup_with_gen.scope)
            assert channels == 2

        def test_load_weights(self):
            """Placeholder, no checks implemented yet."""

        def test_compile_model(self):
            """Placeholder, no checks implemented yet."""

        def test_run(self):
            """Placeholder, no checks implemented yet."""

        def test_init(self):
            """Placeholder, no checks implemented yet."""