Skip to content
Snippets Groups Projects
Commit fb17c979 authored by leufen1's avatar leufen1
Browse files

update tests

parent 22d23dbb
No related branches found
No related tags found
5 merge requests!432IOA works now also with xarray and on identical data, IOA is included in...,!431Resolve "release v2.1.0",!430update recent developments,!420Resolve "disable/enable early stopping",!419Resolve "loss plot with best result marker"
Pipeline #100528 passed
...@@ -187,19 +187,6 @@ class Training(RunEnvironment): ...@@ -187,19 +187,6 @@ class Training(RunEnvironment):
self.model.save(model_name, save_format="tf") self.model.save(model_name, save_format="tf")
self.data_store.set("model", self.model) self.data_store.set("model", self.model)
def load_best_model(self, name: str) -> None:
"""
Load model weights for model with name. Skip if no weights are available.
:param name: name of the model to load weights for
"""
logging.debug(f"load best model: {name}")
try:
self.model.load_model(name, compile=True)
logging.info(f"reload model...")
except OSError:
logging.info("no weights to reload...")
def save_callbacks_as_json(self, history: Callback, lr_sc: Callback, epo_timing: Callback) -> None: def save_callbacks_as_json(self, history: Callback, lr_sc: Callback, epo_timing: Callback) -> None:
""" """
Save callbacks (history, learning rate) of training. Save callbacks (history, learning rate) of training.
......
...@@ -80,7 +80,7 @@ class TestModelSetup: ...@@ -80,7 +80,7 @@ class TestModelSetup:
setup._set_callbacks() setup._set_callbacks()
assert "general.model" in setup.data_store.search_name("callbacks") assert "general.model" in setup.data_store.search_name("callbacks")
callbacks = setup.data_store.get("callbacks", "general.model") callbacks = setup.data_store.get("callbacks", "general.model")
assert len(callbacks.get_callbacks()) == 4 assert len(callbacks.get_callbacks()) == 5
def test_set_callbacks_no_lr_decay(self, setup): def test_set_callbacks_no_lr_decay(self, setup):
setup.data_store.set("lr_decay", None, "general.model") setup.data_store.set("lr_decay", None, "general.model")
...@@ -88,7 +88,7 @@ class TestModelSetup: ...@@ -88,7 +88,7 @@ class TestModelSetup:
setup.checkpoint_name = "TestName" setup.checkpoint_name = "TestName"
setup._set_callbacks() setup._set_callbacks()
callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model") callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model")
assert len(callbacks.get_callbacks()) == 3 assert len(callbacks.get_callbacks()) == 4
with pytest.raises(IndexError): with pytest.raises(IndexError):
callbacks.get_callback_by_name("lr_decay") callbacks.get_callback_by_name("lr_decay")
......
...@@ -326,16 +326,10 @@ class TestTraining: ...@@ -326,16 +326,10 @@ class TestTraining:
model_name = "test_model.h5" model_name = "test_model.h5"
assert model_name not in os.listdir(model_path) assert model_name not in os.listdir(model_path)
init_without_run.save_model() init_without_run.save_model()
message = PyTestRegex(f"save best model to {os.path.join(model_path, model_name)}") message = PyTestRegex(f"save model to {os.path.join(model_path, model_name)}")
assert caplog.record_tuples[1] == ("root", 10, message) assert caplog.record_tuples[1] == ("root", 10, message)
assert model_name in os.listdir(model_path) assert model_name in os.listdir(model_path)
def test_load_best_model_no_weights(self, init_without_run, caplog):
caplog.set_level(logging.DEBUG)
init_without_run.load_best_model("notExisting.h5")
assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting.h5"))
assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))
def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path): def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing) init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
assert "history.json" in os.listdir(model_path) assert "history.json" in os.listdir(model_path)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment