From 5accca186a358ab8cd06b74b6e6801f47e9a339f Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Fri, 31 Jan 2020 15:49:35 +0100
Subject: [PATCH] a little bit more docs and additional log if training is resumed

---
 src/run_modules/training.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index 195ae28a..99afd830 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -24,7 +24,6 @@ class Training(RunEnvironment):
         self.batch_size = self.data_store.get("batch_size", "general.model")
         self.epochs = self.data_store.get("epochs", "general.model")
         self.checkpoint: ModelCheckpointAdvanced = self.data_store.get("checkpoint", "general.model")
-        # self.callbacks = self.data_store.get("callbacks", "general.model")
         self.lr_sc = self.data_store.get("lr_decay", "general.model")
         self.hist = self.data_store.get("hist", "general.model")
         self.experiment_name = self.data_store.get("experiment_name", "general")
@@ -38,7 +37,7 @@ class Training(RunEnvironment):
            2) make_predict_function():
                create predict function before distribution on multiple nodes (detailed information in method description)
            3) train():
-                train model and save callbacks
+                start or resume training of the model and save callbacks
            4) save_model():
                save best model from training as final model
        """
@@ -76,7 +75,10 @@ class Training(RunEnvironment):
     def train(self) -> None:
         """
         Perform training using keras fit_generator(). Callbacks are stored locally in the experiment directory. Best
-        model from training is saved for class variable model.
+        model from training is saved for class variable model. If the checkpoint file path already exists, this method
+        assumes that this is not a new training starting from scratch, but the resumption of a previously started and
+        interrupted (or intentionally stopped) training. In this case, train() automatically loads the locally stored
+        callback information and the corresponding model and continues the training from the last checkpoint.
         """
         logging.info(f"Train with {len(self.train_set)} mini batches.")
         if not os.path.exists(self.checkpoint.filepath):
@@ -88,6 +90,7 @@ class Training(RunEnvironment):
                                      validation_steps=len(self.val_set),
                                      callbacks=[self.lr_sc, self.hist, self.checkpoint])
         else:
+            logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.")
             lr_filepath = self.checkpoint.callbacks[0]["path"]
             hist_filepath = self.checkpoint.callbacks[1]["path"]
             self.lr_sc = pickle.load(open(lr_filepath, "rb"))
--
GitLab
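
Editor's note: the resume logic touched by this patch follows a common Keras pattern: if a checkpoint file exists, reload the saved model and the persisted training state instead of starting from scratch. The following is a minimal, self-contained sketch of that pattern only, not the project's implementation (which uses ModelCheckpointAdvanced and pickled callback objects); the file names, the toy model, and the pickled epoch counter are assumptions for illustration.

import os
import pickle

import numpy as np
from tensorflow import keras

MODEL_PATH = "model_checkpoint.h5"    # assumed file names, not taken from the patch
STATE_PATH = "train_state.pickle"


def build_model():
    # tiny toy model, stands in for the project's real architecture
    model = keras.Sequential([keras.Input(shape=(4,)),
                              keras.layers.Dense(8, activation="relu"),
                              keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")
    return model


def train_or_resume(x, y, total_epochs=10):
    """Start a new training run, or resume from the last saved checkpoint."""
    if not os.path.exists(MODEL_PATH):
        # fresh start: new model, begin at epoch 0
        model, initial_epoch = build_model(), 0
    else:
        # resume: reload the saved model and the pickled epoch counter
        model = keras.models.load_model(MODEL_PATH)
        with open(STATE_PATH, "rb") as f:
            initial_epoch = pickle.load(f)["epoch"]

    for epoch in range(initial_epoch, total_epochs):
        # train exactly one epoch, then persist model and minimal state so an
        # interrupted run can continue from here
        model.fit(x, y, epochs=epoch + 1, initial_epoch=epoch, verbose=0)
        model.save(MODEL_PATH)
        with open(STATE_PATH, "wb") as f:
            pickle.dump({"epoch": epoch + 1}, f)
    return model


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    train_or_resume(rng.normal(size=(64, 4)), rng.normal(size=(64, 1)))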