From 5843d2f64b3e166a07157fbb2d845d92cf7582c2 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Fri, 31 Jan 2020 15:17:57 +0100 Subject: [PATCH] added docs --- src/model_modules/keras_extensions.py | 56 +++++++++++++++++++++++---- 1 file changed, 49 insertions(+), 7 deletions(-) diff --git a/src/model_modules/keras_extensions.py b/src/model_modules/keras_extensions.py index f84f33d1..e2a4b932 100644 --- a/src/model_modules/keras_extensions.py +++ b/src/model_modules/keras_extensions.py @@ -12,10 +12,23 @@ from keras.callbacks import History, ModelCheckpoint class HistoryAdvanced(History): + """ + This is almost an identical clone of the original History class. The only difference is that attributes epoch and + history are instantiated during the init phase and not during on_train_begin. This is required to resume an already + started but disrupted training from an saved state. This HistoryAdvanced callback needs to be added separately as + additional callback. To get the full history use this object for further steps instead of the default return of + training methods like fit_generator(). + + hist = HistoryAdvanced() + history = model.fit_generator(generator=.... , callbacks=[hist]) + history = hist + + If training was started from beginning this class is identical to the returned history class object. + """ - def __init__(self, old_epoch=None, old_history=None): - self.epoch = old_epoch or [] - self.history = old_history or {} + def __init__(self): + self.epoch = [] + self.history = {} super().__init__() def on_train_begin(self, logs=None): @@ -78,20 +91,46 @@ class LearningRateDecay(History): class ModelCheckpointAdvanced(ModelCheckpoint): """ - IMPORTANT: Always add the model checkpoint advanced as last callback to properly update all tracked callbacks, e.g. - fit_generator(callbacks=[..., <last_here>]) + Enhance the standard ModelCheckpoint class by additional saves of given callbacks. Specify this callbacks as follow: + + lr = CustomLearningRate() + hist = CustomHistory() + callbacks_name = "your_custom_path_%s.pickle" + callbacks = [{"callback": lr, "path": callbacks_name % "lr"}, + {"callback": hist, "path": callbacks_name % "hist"}] + ckpt_callbacks = ModelCheckpointAdvanced(filepath=.... , callbacks=callbacks) + + Add this ckpt_callbacks as all other additional callbacks to the callback list. IMPORTANT: Always add ckpt_callbacks + as last callback to properly update all tracked callbacks, e.g. + + fit_generator(.... , callbacks=[lr, hist, ckpt_callbacks]) + """ def __init__(self, *args, **kwargs): self.callbacks = kwargs.pop("callbacks") super().__init__(*args, **kwargs) def update_best(self, hist): + """ + Update internal best on resuming a training process. Otherwise best is set to +/- inf depending on the + performance metric and the first trained model (first of the resuming training process) will always saved as + best model because its performance will be better than infinity. To prevent this behaviour and compare the + performance with the best model performance, call this method before resuming the training process. + :param hist: The History object from the previous (interrupted) training. + """ self.best = hist.history.get(self.monitor)[-1] def update_callbacks(self, callbacks): + """ + Update all stored callback objects. The argument callbacks needs to follow the same convention like described + in the class description (list of dictionaries). Must be run before resuming a training process. + """ self.callbacks = callbacks def on_epoch_end(self, epoch, logs=None): + """ + Save model as usual (see ModelCheckpoint class), but also save additional callbacks. + """ super().on_epoch_end(epoch, logs) for callback in self.callbacks: @@ -100,9 +139,12 @@ class ModelCheckpointAdvanced(ModelCheckpoint): if self.save_best_only: current = logs.get(self.monitor) if current == self.best: - print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) + if self.verbose > 0: + print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) with open(file_path, "wb") as f: pickle.dump(callback["callback"], f) else: with open(file_path, "wb") as f: - pickle.dump(callback["callback"], f) \ No newline at end of file + if self.verbose > 0: + print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) + pickle.dump(callback["callback"], f) -- GitLab