Skip to content
Snippets Groups Projects
Commit b03b3a0c authored by lukas leufen's avatar lukas leufen
Browse files

updated tests

parent 39a058b2
No related branches found
No related tags found
1 merge request!50release for v0.7.0
Pipeline #29420 passed
......@@ -47,7 +47,8 @@ class TestExperimentSetup:
data_store = exp_setup.data_store
# experiment setup
assert data_store.get("data_path", "general") == prepare_host()
assert data_store.get("trainable", "general") is False
assert data_store.get("trainable", "general") is True
assert data_store.get("create_new_model", "general") is True
assert data_store.get("fraction_of_training", "general") == 0.8
# set experiment name
assert data_store.get("experiment_name", "general") == "TestExperiment"
......@@ -104,13 +105,14 @@ class TestExperimentSetup:
target_var="temp", target_dim="target", window_lead_time=10, dimensions="dim1",
interpolate_dim="int_dim", interpolate_method="cubic", limit_nan_fill=5, train_start="2000-01-01",
train_end="2000-01-02", val_start="2000-01-03", val_end="2000-01-04", test_start="2000-01-05",
test_end="2000-01-06", use_all_stations_on_all_data_sets=False, trainable=True,
fraction_of_train=0.5, experiment_path=experiment_path)
test_end="2000-01-06", use_all_stations_on_all_data_sets=False, trainable=False,
fraction_of_train=0.5, experiment_path=experiment_path, create_new_model=True)
exp_setup = ExperimentSetup(**kwargs)
data_store = exp_setup.data_store
# experiment setup
assert data_store.get("data_path", "general") == prepare_host()
assert data_store.get("trainable", "general") is True
assert data_store.get("create_new_model", "general") is True
assert data_store.get("fraction_of_training", "general") == 0.5
# set experiment name
assert data_store.get("experiment_name", "general") == "TODAY_network"
......@@ -150,10 +152,30 @@ class TestExperimentSetup:
# use all stations on all data sets (train, val, test)
assert data_store.get("use_all_stations_on_all_data_sets", "general.test") is False
def test_init_trainable_behaviour(self):
exp_setup = ExperimentSetup(trainable=False, create_new_model=True)
data_store = exp_setup.data_store
assert data_store.get("trainable", "general") is True
assert data_store.get("create_new_model", "general") is True
exp_setup = ExperimentSetup(trainable=False, create_new_model=False)
data_store = exp_setup.data_store
assert data_store.get("trainable", "general") is False
assert data_store.get("create_new_model", "general") is False
exp_setup = ExperimentSetup(trainable=True, create_new_model=True)
data_store = exp_setup.data_store
assert data_store.get("trainable", "general") is True
assert data_store.get("create_new_model", "general") is True
exp_setup = ExperimentSetup(trainable=True, create_new_model=False)
data_store = exp_setup.data_store
assert data_store.get("trainable", "general") is True
assert data_store.get("create_new_model", "general") is False
def test_compare_variables_and_statistics(self):
experiment_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder"))
kwargs = dict(parser_args={"experiment_date": "TODAY"},
var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'},
stations=['DEBY053', 'DEBW059', 'DEBW027'], variables=["o3", "relhum"], statistics_per_var=None)
stations=['DEBY053', 'DEBW059', 'DEBW027'], variables=["o3", "relhum"], statistics_per_var=None,
experiment_path=experiment_path)
with pytest.raises(ValueError) as e:
ExperimentSetup(**kwargs)
assert "for the variables: {'relhum'}" in e.value.args[0]
......
......@@ -20,6 +20,7 @@ class TestModelSetup:
obj.callbacks_name = "placeholder_%s_str.pickle"
obj.data_store.set("lr_decay", "dummy_str", "general.model")
obj.data_store.set("hist", "dummy_str", "general.model")
obj.model_name = "%s.h5"
yield obj
RunEnvironment().__del__()
......
......@@ -57,10 +57,12 @@ class TestTraining:
obj.data_store.set("generator", mock.MagicMock(return_value="mock_test_gen"), "general.test")
os.makedirs(path)
obj.data_store.set("experiment_path", path, "general")
obj.data_store.set("model_name", os.path.join(path, "test_model.h5"), "general.model")
obj.data_store.set("experiment_name", "TestExperiment", "general")
path_plot = os.path.join(path, "plots")
os.makedirs(path_plot)
obj.data_store.set("plot_path", path_plot, "general")
obj._trainable = True
yield obj
if os.path.exists(path):
shutil.rmtree(path)
......@@ -131,6 +133,7 @@ class TestTraining:
obj.data_store.set("generator", generator, "general.test")
model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
obj.data_store.set("model", model, "general.model")
obj.data_store.set("model_name", os.path.join(path, "test_model.h5"), "general.model")
obj.data_store.set("batch_size", 256, "general.model")
obj.data_store.set("epochs", 2, "general.model")
obj.data_store.set("checkpoint", checkpoint, "general.model")
......@@ -138,6 +141,9 @@ class TestTraining:
obj.data_store.set("hist", HistoryAdvanced(), "general.model")
obj.data_store.set("experiment_name", "TestExperiment", "general")
obj.data_store.set("experiment_path", path, "general")
obj.data_store.set("trainable", True, "general")
obj.data_store.set("create_new_model"
"", True, "general")
path_plot = os.path.join(path, "plots")
os.makedirs(path_plot)
obj.data_store.set("plot_path", path_plot, "general")
......@@ -179,7 +185,7 @@ class TestTraining:
def test_save_model(self, init_without_run, path, caplog):
caplog.set_level(logging.DEBUG)
model_name = "TestExperiment_my_model.h5"
model_name = "test_model.h5"
assert model_name not in os.listdir(path)
init_without_run.save_model()
assert caplog.record_tuples[0] == ("root", 10, PyTestRegex(f"save best model to {os.path.join(path, model_name)}"))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment