diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py index 33f9ddf62bd91c870643727de4d146ce332fbe07..90b598bf80362d6dd5507abfaa02efa590c37a4b 100644 --- a/test/test_modules/test_training.py +++ b/test/test_modules/test_training.py @@ -47,7 +47,7 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m class TestTraining: @pytest.fixture - def init_without_run(self, path: str, model: keras.Model, callbacks: CallbackHandler): + def init_without_run(self, path: str, model: keras.Model, callbacks: CallbackHandler, model_path): obj = object.__new__(Training) super(Training, obj).__init__() obj.model = model @@ -66,7 +66,9 @@ class TestTraining: obj.data_store.set("generator", mock.MagicMock(return_value="mock_test_gen"), "general.test") os.makedirs(path) obj.data_store.set("experiment_path", path, "general") - obj.data_store.set("model_name", os.path.join(path, "test_model.h5"), "general.model") + os.makedirs(model_path) + obj.data_store.set("model_path", model_path, "general") + obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model") obj.data_store.set("experiment_name", "TestExperiment", "general") path_plot = os.path.join(path, "plots") os.makedirs(path_plot) @@ -100,6 +102,10 @@ class TestTraining: def path(self): return os.path.join(os.path.dirname(__file__), "TestExperiment") + @pytest.fixture + def model_path(self, path): + return os.path.join(path, "model") + @pytest.fixture def generator(self, path): return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', @@ -138,15 +144,17 @@ class TestTraining: return obj @pytest.fixture - def ready_to_init(self, generator, model, callbacks, path): + def ready_to_init(self, generator, model, callbacks, path, model_path): os.makedirs(path) + os.makedirs(model_path) obj = RunEnvironment() obj.data_store.set("generator", generator, "general.train") obj.data_store.set("generator", generator, "general.val") obj.data_store.set("generator", generator, "general.test") model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error) obj.data_store.set("model", model, "general.model") - obj.data_store.set("model_name", os.path.join(path, "test_model.h5"), "general.model") + obj.data_store.set("model_path", model_path, "general") + obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model") obj.data_store.set("batch_size", 256, "general.model") obj.data_store.set("epochs", 2, "general.model") clbk, hist, lr = callbacks @@ -212,23 +220,23 @@ class TestTraining: assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting")) assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload...")) - def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, path): + def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, model_path): init_without_run.save_callbacks_as_json(history, learning_rate) - assert "history.json" in os.listdir(path) + assert "history.json" in os.listdir(model_path) - def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, path): + def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, model_path): init_without_run.save_callbacks_as_json(history, learning_rate) - assert "history_lr.json" in os.listdir(path) + assert "history_lr.json" in os.listdir(model_path) - def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, path): + def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, model_path): init_without_run.save_callbacks_as_json(history, learning_rate) - with open(os.path.join(path, "history.json")) as jfile: + with open(os.path.join(model_path, "history.json")) as jfile: hist = json.load(jfile) assert hist == history.history - def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, path): + def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, model_path): init_without_run.save_callbacks_as_json(history, learning_rate) - with open(os.path.join(path, "history_lr.json")) as jfile: + with open(os.path.join(model_path, "history_lr.json")) as jfile: lr = json.load(jfile) assert lr == learning_rate.lr