diff --git a/src/run_modules/training.py b/src/run_modules/training.py index 5dab3da2173727fdcc9ab8e88a3e91fcd9d54a59..53a505484f4e3d59ef37a37f3d0c6996059485d9 100644 --- a/src/run_modules/training.py +++ b/src/run_modules/training.py @@ -93,6 +93,7 @@ class Training(RunEnvironment): model_name = os.path.join(path, name) logging.debug(f"save best model to {model_name}") self.model.save(model_name) + self.data_store.put("best_model", self.model, "general") def load_best_model(self, name: str) -> None: """ @@ -119,6 +120,3 @@ class Training(RunEnvironment): json.dump(history.history, f) with open(os.path.join(path, "history_lr.json"), "w") as f: json.dump(self.lr_sc.lr, f) - - -