diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py index d595463ff0f42818d634018fa8ddb03e5b54a2db..7f8c7a051542ff7fc317c0c92454c28f1d0d70b5 100644 --- a/test/test_modules/test_model_setup.py +++ b/test/test_modules/test_model_setup.py @@ -27,7 +27,10 @@ class TestModelSetup: @pytest.fixture def setup_with_gen(self, setup, gen): setup.data_store.put("generator", gen, "general.train") - return setup + setup.data_store.put("window_history_size", gen.window_history_size, "general") + setup.data_store.put("window_lead_time", gen.window_lead_time, "general") + yield setup + RunEnvironment().__del__() def test_set_checkpoint(self, setup): assert "general.modeltest" not in setup.data_store.search_name("checkpoint") @@ -43,7 +46,6 @@ class TestModelSetup: def test_build_model(self, setup_with_gen): setup_with_gen.my_model_settings() - setup_with_gen.data_store.put("window_lead_time", 2, "general") assert setup_with_gen.model is None setup_with_gen.build_model() assert isinstance(setup_with_gen.model, keras.Model) diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py index 1af910ee660510c5667e6170c82c079dcf515bb2..13abe62a2b9199ad8d92528ff5363bd54f1be221 100644 --- a/test/test_modules/test_pre_processing.py +++ b/test/test_modules/test_pre_processing.py @@ -87,8 +87,8 @@ class TestPreProcessing: def test_check_valid_stations(self, caplog, obj_with_exp_setup): pre = obj_with_exp_setup caplog.set_level(logging.INFO) - args = pre._create_args_dict(DEFAULT_ARGS_LIST) - kwargs = pre._create_args_dict(DEFAULT_KWARGS_LIST) + args = pre.data_store.create_args_dict(DEFAULT_ARGS_LIST) + kwargs = pre.data_store.create_args_dict(DEFAULT_KWARGS_LIST) stations = pre.data_store.get("stations", "general") valid_stations = pre.check_valid_stations(args, kwargs, stations) assert len(valid_stations) < len(stations) @@ -105,11 +105,11 @@ class TestPreProcessing: assert dummy_list[test] == list(range(13, 15)) def test_create_args_dict_default_scope(self, obj_super_init): - assert obj_super_init._create_args_dict(["NAME1", "NAME2"]) == {"NAME1": 1, "NAME2": 2} + assert obj_super_init.data_store.create_args_dict(["NAME1", "NAME2"]) == {"NAME1": 1, "NAME2": 2} def test_create_args_dict_given_scope(self, obj_super_init): - assert obj_super_init._create_args_dict(["NAME1", "NAME2"], scope="general.sub") == {"NAME1": 10, "NAME2": 2} + assert obj_super_init.data_store.create_args_dict(["NAME1", "NAME2"], scope="general.sub") == {"NAME1": 10, "NAME2": 2} def test_create_args_dict_missing_entry(self, obj_super_init): - assert obj_super_init._create_args_dict(["NAME5", "NAME2"]) == {"NAME2": 2} - assert obj_super_init._create_args_dict(["NAME4", "NAME2"]) == {"NAME2": 2} + assert obj_super_init.data_store.create_args_dict(["NAME5", "NAME2"]) == {"NAME2": 2} + assert obj_super_init.data_store.create_args_dict(["NAME4", "NAME2"]) == {"NAME2": 2}