From bdd19763109d4d82e04089dd418f42a0448ce3c8 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Wed, 4 Dec 2019 13:42:18 +0100 Subject: [PATCH] adjusted tests --- test/test_modules/test_model_setup.py | 6 ++++-- test/test_modules/test_pre_processing.py | 12 ++++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py index d595463f..7f8c7a05 100644 --- a/test/test_modules/test_model_setup.py +++ b/test/test_modules/test_model_setup.py @@ -27,7 +27,10 @@ class TestModelSetup: @pytest.fixture def setup_with_gen(self, setup, gen): setup.data_store.put("generator", gen, "general.train") - return setup + setup.data_store.put("window_history_size", gen.window_history_size, "general") + setup.data_store.put("window_lead_time", gen.window_lead_time, "general") + yield setup + RunEnvironment().__del__() def test_set_checkpoint(self, setup): assert "general.modeltest" not in setup.data_store.search_name("checkpoint") @@ -43,7 +46,6 @@ class TestModelSetup: def test_build_model(self, setup_with_gen): setup_with_gen.my_model_settings() - setup_with_gen.data_store.put("window_lead_time", 2, "general") assert setup_with_gen.model is None setup_with_gen.build_model() assert isinstance(setup_with_gen.model, keras.Model) diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py index 1af910ee..13abe62a 100644 --- a/test/test_modules/test_pre_processing.py +++ b/test/test_modules/test_pre_processing.py @@ -87,8 +87,8 @@ class TestPreProcessing: def test_check_valid_stations(self, caplog, obj_with_exp_setup): pre = obj_with_exp_setup caplog.set_level(logging.INFO) - args = pre._create_args_dict(DEFAULT_ARGS_LIST) - kwargs = pre._create_args_dict(DEFAULT_KWARGS_LIST) + args = pre.data_store.create_args_dict(DEFAULT_ARGS_LIST) + kwargs = pre.data_store.create_args_dict(DEFAULT_KWARGS_LIST) stations = pre.data_store.get("stations", "general") valid_stations = pre.check_valid_stations(args, kwargs, stations) assert len(valid_stations) < len(stations) @@ -105,11 +105,11 @@ class TestPreProcessing: assert dummy_list[test] == list(range(13, 15)) def test_create_args_dict_default_scope(self, obj_super_init): - assert obj_super_init._create_args_dict(["NAME1", "NAME2"]) == {"NAME1": 1, "NAME2": 2} + assert obj_super_init.data_store.create_args_dict(["NAME1", "NAME2"]) == {"NAME1": 1, "NAME2": 2} def test_create_args_dict_given_scope(self, obj_super_init): - assert obj_super_init._create_args_dict(["NAME1", "NAME2"], scope="general.sub") == {"NAME1": 10, "NAME2": 2} + assert obj_super_init.data_store.create_args_dict(["NAME1", "NAME2"], scope="general.sub") == {"NAME1": 10, "NAME2": 2} def test_create_args_dict_missing_entry(self, obj_super_init): - assert obj_super_init._create_args_dict(["NAME5", "NAME2"]) == {"NAME2": 2} - assert obj_super_init._create_args_dict(["NAME4", "NAME2"]) == {"NAME2": 2} + assert obj_super_init.data_store.create_args_dict(["NAME5", "NAME2"]) == {"NAME2": 2} + assert obj_super_init.data_store.create_args_dict(["NAME4", "NAME2"]) == {"NAME2": 2} -- GitLab