Commit 5accca18 authored by Lukas Leufen

a little bit more docs and additional log if training is resumed

parent 09f7e2d2
2 merge requests: !37 include new development, !29 Lukas issue030 feat continue training
Pipeline #28883 passed
@@ -24,7 +24,6 @@ class Training(RunEnvironment):
         self.batch_size = self.data_store.get("batch_size", "general.model")
         self.epochs = self.data_store.get("epochs", "general.model")
         self.checkpoint: ModelCheckpointAdvanced = self.data_store.get("checkpoint", "general.model")
-        # self.callbacks = self.data_store.get("callbacks", "general.model")
         self.lr_sc = self.data_store.get("lr_decay", "general.model")
         self.hist = self.data_store.get("hist", "general.model")
         self.experiment_name = self.data_store.get("experiment_name", "general")
@@ -38,7 +37,7 @@ class Training(RunEnvironment):
         2) make_predict_function():
             create predict function before distribution on multiple nodes (detailed information in method description)
         3) train():
-            train model and save callbacks
+            start or resume training of the model and save callbacks
         4) save_model():
             save best model from training as final model
         """
@@ -76,7 +75,10 @@ class Training(RunEnvironment):
     def train(self) -> None:
         """
         Perform training using keras fit_generator(). Callbacks are stored locally in the experiment directory. Best
-        model from training is saved for class variable model.
+        model from training is saved for class variable model. If the checkpoint file path already exists, this
+        method assumes that this is not a new training starting from the very beginning, but the resumption of a
+        previously started and interrupted (or deliberately stopped and now continued) training. Train will then
+        automatically load the locally stored information and the corresponding model and proceed with that training.
         """
         logging.info(f"Train with {len(self.train_set)} mini batches.")
         if not os.path.exists(self.checkpoint.filepath):
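
The new docstring ties the resume behaviour to the presence of the checkpoint file on disk, which is exactly what the `if not os.path.exists(self.checkpoint.filepath)` branch tests. As a rough illustration of that decision, here is a minimal sketch using a plain Keras `ModelCheckpoint` instead of MLAir's `ModelCheckpointAdvanced`; the function and argument names are hypothetical, not the project's actual API:

```python
import os

import keras


def run_training(model, train_gen, val_gen, epochs, checkpoint_path):
    """Hypothetical helper: only start a fresh run if no checkpoint exists yet."""
    # plain Keras checkpoint callback standing in for ModelCheckpointAdvanced
    checkpoint = keras.callbacks.ModelCheckpoint(checkpoint_path, monitor="val_loss",
                                                 save_best_only=True, mode="min")
    if not os.path.exists(checkpoint_path):
        # no checkpoint on disk -> a brand-new training run starting at epoch 0;
        # train_gen/val_gen are assumed to be keras.utils.Sequence objects, so len() works
        model.fit_generator(generator=train_gen,
                            steps_per_epoch=len(train_gen),
                            epochs=epochs,
                            validation_data=val_gen,
                            validation_steps=len(val_gen),
                            callbacks=[checkpoint])
    else:
        # resume path: see the second sketch below the diff
        pass
    return model
```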
@@ -88,6 +90,7 @@ class Training(RunEnvironment):
                                  validation_steps=len(self.val_set),
                                  callbacks=[self.lr_sc, self.hist, self.checkpoint])
         else:
+            logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.")
             lr_filepath = self.checkpoint.callbacks[0]["path"]
             hist_filepath = self.checkpoint.callbacks[1]["path"]
             self.lr_sc = pickle.load(open(lr_filepath, "rb"))
......
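
In the resume branch, the pickled callback objects are read back from the paths that `ModelCheckpointAdvanced` keeps in its `callbacks` list, and the previously saved model is reloaded before training continues. The sketch below shows the same idea with standard Keras pieces; deriving `initial_epoch` from the restored history object, as well as every name and path here, is an assumption for illustration, not MLAir's implementation:

```python
import pickle

import keras


def resume_training(model_path, train_gen, val_gen, epochs,
                    lr_filepath, hist_filepath, checkpoint):
    """Hypothetical resume branch: reload saved state and continue training."""
    # restore the pickled callback objects that were written alongside the checkpoint
    with open(lr_filepath, "rb") as f:
        lr_sc = pickle.load(f)
    with open(hist_filepath, "rb") as f:
        hist = pickle.load(f)

    # reload the best model saved so far; the number of epochs already completed is
    # read from the restored history so that fit_generator() continues counting from
    # there instead of starting again at epoch 0
    model = keras.models.load_model(model_path)
    initial_epoch = len(hist.epoch)

    model.fit_generator(generator=train_gen,
                        steps_per_epoch=len(train_gen),
                        epochs=epochs,
                        initial_epoch=initial_epoch,
                        validation_data=val_gen,
                        validation_steps=len(val_gen),
                        callbacks=[lr_sc, checkpoint])
    return model, hist
```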