diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py index 90b598bf80362d6dd5507abfaa02efa590c37a4b..a26bd18c75ae34d200669141578b5fa3ea2bb7c8 100644 --- a/test/test_modules/test_training.py +++ b/test/test_modules/test_training.py @@ -205,14 +205,14 @@ class TestTraining: assert ready_to_train.model.history.epoch == [0, 1] assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2 - def test_save_model(self, init_without_run, path, caplog): + def test_save_model(self, init_without_run, model_path, caplog): caplog.set_level(logging.DEBUG) model_name = "test_model.h5" - assert model_name not in os.listdir(path) + assert model_name not in os.listdir(model_path) init_without_run.save_model() - message = PyTestRegex(f"save best model to {os.path.join(path, model_name)}") + message = PyTestRegex(f"save best model to {os.path.join(model_path, model_name)}") assert caplog.record_tuples[1] == ("root", 10, message) - assert model_name in os.listdir(path) + assert model_name in os.listdir(model_path) def test_load_best_model_no_weights(self, init_without_run, caplog): caplog.set_level(logging.DEBUG)