Skip to content
Snippets Groups Projects
Commit 7b302f4e authored by lukas leufen's avatar lukas leufen
Browse files

small bug fix for history return statement

parent 5de64e13
Branches
Tags
2 merge requests!37include new development,!29Lukas issue030 feat continue training
Pipeline #28847 passed
...@@ -63,11 +63,11 @@ class ModelSetup(RunEnvironment): ...@@ -63,11 +63,11 @@ class ModelSetup(RunEnvironment):
self.data_store.set("model", self.model, self.scope) self.data_store.set("model", self.model, self.scope)
def _set_checkpoint(self): def _set_checkpoint(self):
"""
Must be run after all callback functions that shall be tracked during training have been created (currently this
affects the learning rate decay and the advanced history [actually created in this method]).
"""
lr = self.data_store.get("lr_decay", scope="general.model") lr = self.data_store.get("lr_decay", scope="general.model")
# checkpoint = ModelCheckpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss', save_best_only=True, mode='auto')
# checkpoint = ModelCheckpointAdvanced(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
# save_best_only=True, mode='auto', callbacks_to_save=lr,
# callbacks_filepath=self.callbacks_name)
hist = HistoryAdvanced() hist = HistoryAdvanced()
self.data_store.set("hist", hist, scope="general.model") self.data_store.set("hist", hist, scope="general.model")
callbacks = [{"callback": lr, "path": self.callbacks_name % "lr"}, callbacks = [{"callback": lr, "path": self.callbacks_name % "lr"},
......
...@@ -89,7 +89,7 @@ class Training(RunEnvironment): ...@@ -89,7 +89,7 @@ class Training(RunEnvironment):
# callbacks=self.callbacks) # callbacks=self.callbacks)
callbacks=[self.lr_sc, self.hist, self.checkpoint]) callbacks=[self.lr_sc, self.hist, self.checkpoint])
else: else:
lr_filepath = self.checkpoint.callbacks[0]["path"] # TODO: stopped here. why does training start 1 epoch too early or doesn't it? lr_filepath = self.checkpoint.callbacks[0]["path"]
hist_filepath = self.checkpoint.callbacks[1]["path"] hist_filepath = self.checkpoint.callbacks[1]["path"]
lr_callbacks = pickle.load(open(lr_filepath, "rb")) lr_callbacks = pickle.load(open(lr_filepath, "rb"))
hist_callbacks = pickle.load(open(hist_filepath, "rb")) hist_callbacks = pickle.load(open(hist_filepath, "rb"))
...@@ -101,7 +101,7 @@ class Training(RunEnvironment): ...@@ -101,7 +101,7 @@ class Training(RunEnvironment):
{"callback": self.hist, "path": hist_filepath}] {"callback": self.hist, "path": hist_filepath}]
self.checkpoint.update_callbacks(callbacks) self.checkpoint.update_callbacks(callbacks)
self.checkpoint.update_best(hist_callbacks) self.checkpoint.update_best(hist_callbacks)
self.hist = self.model.fit_generator(generator=self.train_set.distribute_on_batches(), _ = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
steps_per_epoch=len(self.train_set), steps_per_epoch=len(self.train_set),
epochs=self.epochs, epochs=self.epochs,
verbose=2, verbose=2,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment