diff --git a/test/test_modules/test_experiment_setup.py b/test/test_modules/test_experiment_setup.py index 7a4f16fd055e0af0aea95181d475d061c185ca92..9e6d17627d1697a2150ea7f74a373a720d2f02ac 100644 --- a/test/test_modules/test_experiment_setup.py +++ b/test/test_modules/test_experiment_setup.py @@ -47,7 +47,8 @@ class TestExperimentSetup: data_store = exp_setup.data_store # experiment setup assert data_store.get("data_path", "general") == prepare_host() - assert data_store.get("trainable", "general") is False + assert data_store.get("trainable", "general") is True + assert data_store.get("create_new_model", "general") is True assert data_store.get("fraction_of_training", "general") == 0.8 # set experiment name assert data_store.get("experiment_name", "general") == "TestExperiment" @@ -104,13 +105,14 @@ class TestExperimentSetup: target_var="temp", target_dim="target", window_lead_time=10, dimensions="dim1", interpolate_dim="int_dim", interpolate_method="cubic", limit_nan_fill=5, train_start="2000-01-01", train_end="2000-01-02", val_start="2000-01-03", val_end="2000-01-04", test_start="2000-01-05", - test_end="2000-01-06", use_all_stations_on_all_data_sets=False, trainable=True, - fraction_of_train=0.5, experiment_path=experiment_path) + test_end="2000-01-06", use_all_stations_on_all_data_sets=False, trainable=False, + fraction_of_train=0.5, experiment_path=experiment_path, create_new_model=True) exp_setup = ExperimentSetup(**kwargs) data_store = exp_setup.data_store # experiment setup assert data_store.get("data_path", "general") == prepare_host() assert data_store.get("trainable", "general") is True + assert data_store.get("create_new_model", "general") is True assert data_store.get("fraction_of_training", "general") == 0.5 # set experiment name assert data_store.get("experiment_name", "general") == "TODAY_network" @@ -150,10 +152,30 @@ class TestExperimentSetup: # use all stations on all data sets (train, val, test) assert data_store.get("use_all_stations_on_all_data_sets", "general.test") is False + def test_init_trainable_behaviour(self): + exp_setup = ExperimentSetup(trainable=False, create_new_model=True) + data_store = exp_setup.data_store + assert data_store.get("trainable", "general") is True + assert data_store.get("create_new_model", "general") is True + exp_setup = ExperimentSetup(trainable=False, create_new_model=False) + data_store = exp_setup.data_store + assert data_store.get("trainable", "general") is False + assert data_store.get("create_new_model", "general") is False + exp_setup = ExperimentSetup(trainable=True, create_new_model=True) + data_store = exp_setup.data_store + assert data_store.get("trainable", "general") is True + assert data_store.get("create_new_model", "general") is True + exp_setup = ExperimentSetup(trainable=True, create_new_model=False) + data_store = exp_setup.data_store + assert data_store.get("trainable", "general") is True + assert data_store.get("create_new_model", "general") is False + def test_compare_variables_and_statistics(self): + experiment_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder")) kwargs = dict(parser_args={"experiment_date": "TODAY"}, var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}, - stations=['DEBY053', 'DEBW059', 'DEBW027'], variables=["o3", "relhum"], statistics_per_var=None) + stations=['DEBY053', 'DEBW059', 'DEBW027'], variables=["o3", "relhum"], statistics_per_var=None, + experiment_path=experiment_path) with pytest.raises(ValueError) as e: ExperimentSetup(**kwargs) assert "for the variables: {'relhum'}" in e.value.args[0] diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py index 2864ae45bcd7d3c6109d6d84fe5ea152a7d86384..35c5f8ee7581856a9feee3abd0face73ee83952c 100644 --- a/test/test_modules/test_model_setup.py +++ b/test/test_modules/test_model_setup.py @@ -20,6 +20,7 @@ class TestModelSetup: obj.callbacks_name = "placeholder_%s_str.pickle" obj.data_store.set("lr_decay", "dummy_str", "general.model") obj.data_store.set("hist", "dummy_str", "general.model") + obj.model_name = "%s.h5" yield obj RunEnvironment().__del__() diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py index 485348ceca740d8263394fca36efbfbde6dd2d0d..ac040c3a286c25dc84853c26c8509278642a1495 100644 --- a/test/test_modules/test_training.py +++ b/test/test_modules/test_training.py @@ -57,10 +57,12 @@ class TestTraining: obj.data_store.set("generator", mock.MagicMock(return_value="mock_test_gen"), "general.test") os.makedirs(path) obj.data_store.set("experiment_path", path, "general") + obj.data_store.set("model_name", os.path.join(path, "test_model.h5"), "general.model") obj.data_store.set("experiment_name", "TestExperiment", "general") path_plot = os.path.join(path, "plots") os.makedirs(path_plot) obj.data_store.set("plot_path", path_plot, "general") + obj._trainable = True yield obj if os.path.exists(path): shutil.rmtree(path) @@ -131,6 +133,7 @@ class TestTraining: obj.data_store.set("generator", generator, "general.test") model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error) obj.data_store.set("model", model, "general.model") + obj.data_store.set("model_name", os.path.join(path, "test_model.h5"), "general.model") obj.data_store.set("batch_size", 256, "general.model") obj.data_store.set("epochs", 2, "general.model") obj.data_store.set("checkpoint", checkpoint, "general.model") @@ -138,6 +141,9 @@ class TestTraining: obj.data_store.set("hist", HistoryAdvanced(), "general.model") obj.data_store.set("experiment_name", "TestExperiment", "general") obj.data_store.set("experiment_path", path, "general") + obj.data_store.set("trainable", True, "general") + obj.data_store.set("create_new_model" + "", True, "general") path_plot = os.path.join(path, "plots") os.makedirs(path_plot) obj.data_store.set("plot_path", path_plot, "general") @@ -179,7 +185,7 @@ class TestTraining: def test_save_model(self, init_without_run, path, caplog): caplog.set_level(logging.DEBUG) - model_name = "TestExperiment_my_model.h5" + model_name = "test_model.h5" assert model_name not in os.listdir(path) init_without_run.save_model() assert caplog.record_tuples[0] == ("root", 10, PyTestRegex(f"save best model to {os.path.join(path, model_name)}"))