Skip to content
Snippets Groups Projects
Commit bc989d62 authored by lukas leufen's avatar lukas leufen
Browse files

update tests

parent 02e43daa
No related branches found
No related tags found
3 merge requests!125Release v0.10.0,!124Update Master to new version v0.10.0,!101Resolve "model folder in experiment"
Pipeline #39298 failed
......@@ -47,7 +47,7 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m
class TestTraining:
@pytest.fixture
def init_without_run(self, path: str, model: keras.Model, callbacks: CallbackHandler):
def init_without_run(self, path: str, model: keras.Model, callbacks: CallbackHandler, model_path):
obj = object.__new__(Training)
super(Training, obj).__init__()
obj.model = model
......@@ -66,7 +66,9 @@ class TestTraining:
obj.data_store.set("generator", mock.MagicMock(return_value="mock_test_gen"), "general.test")
os.makedirs(path)
obj.data_store.set("experiment_path", path, "general")
obj.data_store.set("model_name", os.path.join(path, "test_model.h5"), "general.model")
os.makedirs(model_path)
obj.data_store.set("model_path", model_path, "general")
obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
obj.data_store.set("experiment_name", "TestExperiment", "general")
path_plot = os.path.join(path, "plots")
os.makedirs(path_plot)
......@@ -100,6 +102,10 @@ class TestTraining:
def path(self):
return os.path.join(os.path.dirname(__file__), "TestExperiment")
@pytest.fixture
def model_path(self, path):
return os.path.join(path, "model")
@pytest.fixture
def generator(self, path):
return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE',
......@@ -138,15 +144,17 @@ class TestTraining:
return obj
@pytest.fixture
def ready_to_init(self, generator, model, callbacks, path):
def ready_to_init(self, generator, model, callbacks, path, model_path):
os.makedirs(path)
os.makedirs(model_path)
obj = RunEnvironment()
obj.data_store.set("generator", generator, "general.train")
obj.data_store.set("generator", generator, "general.val")
obj.data_store.set("generator", generator, "general.test")
model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
obj.data_store.set("model", model, "general.model")
obj.data_store.set("model_name", os.path.join(path, "test_model.h5"), "general.model")
obj.data_store.set("model_path", model_path, "general")
obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
obj.data_store.set("batch_size", 256, "general.model")
obj.data_store.set("epochs", 2, "general.model")
clbk, hist, lr = callbacks
......@@ -212,23 +220,23 @@ class TestTraining:
assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting"))
assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))
def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, path):
def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate)
assert "history.json" in os.listdir(path)
assert "history.json" in os.listdir(model_path)
def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, path):
def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate)
assert "history_lr.json" in os.listdir(path)
assert "history_lr.json" in os.listdir(model_path)
def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, path):
def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate)
with open(os.path.join(path, "history.json")) as jfile:
with open(os.path.join(model_path, "history.json")) as jfile:
hist = json.load(jfile)
assert hist == history.history
def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, path):
def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate)
with open(os.path.join(path, "history_lr.json")) as jfile:
with open(os.path.join(model_path, "history_lr.json")) as jfile:
lr = json.load(jfile)
assert lr == learning_rate.lr
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment