From 5843d2f64b3e166a07157fbb2d845d92cf7582c2 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Fri, 31 Jan 2020 15:17:57 +0100
Subject: [PATCH] added docs

---
 src/model_modules/keras_extensions.py | 84 +++++++++++++++++++++++++++--
 1 file changed, 77 insertions(+), 7 deletions(-)

diff --git a/src/model_modules/keras_extensions.py b/src/model_modules/keras_extensions.py
index f84f33d1..e2a4b932 100644
--- a/src/model_modules/keras_extensions.py
+++ b/src/model_modules/keras_extensions.py
@@ -12,10 +12,30 @@ from keras.callbacks import History, ModelCheckpoint
 
 
 class HistoryAdvanced(History):
+    """
+    This is almost an identical clone of the original History class. The only difference is that attributes epoch and
+    history are instantiated during the init phase and not during on_train_begin. This is required to resume an already
+    started but disrupted training from an saved state. This HistoryAdvanced callback needs to be added separately as
+    additional callback. To get the full history use this object for further steps instead of the default return of
+    training methods like fit_generator().
+
+        hist = HistoryAdvanced()
+        history = model.fit_generator(generator=.... , callbacks=[hist])
+        history = hist
+
+    If training was started from beginning this class is identical to the returned history class object.
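+
+    To resume a disrupted training, restore this callback before the next call (a sketch; persisting and loading
+    via pickle is an assumption and not handled by this class itself):
+
+        with open("hist.pickle", "rb") as f:  # hypothetical path
+            hist = pickle.load(f)
+        history = model.fit_generator(generator=.... , callbacks=[hist])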
+    """
 
-    def __init__(self, old_epoch=None, old_history=None):
-        self.epoch = old_epoch or []
-        self.history = old_history or {}
+    def __init__(self):
+        self.epoch = []
+        self.history = {}
         super().__init__()
 
     def on_train_begin(self, logs=None):
@@ -78,20 +98,67 @@ class LearningRateDecay(History):
 
 class ModelCheckpointAdvanced(ModelCheckpoint):
     """
-    IMPORTANT: Always add the model checkpoint advanced as last callback to properly update all tracked callbacks, e.g.
-    fit_generator(callbacks=[..., <last_here>])
+    Enhances the standard ModelCheckpoint class by additionally saving given callbacks. Specify these callbacks as follows:
+
+        lr = CustomLearningRate()
+        hist = CustomHistory()
+        callbacks_name = "your_custom_path_%s.pickle"
+        callbacks = [{"callback": lr, "path": callbacks_name % "lr"},
+                     {"callback": hist, "path": callbacks_name % "hist"}]
+        ckpt_callbacks = ModelCheckpointAdvanced(filepath=.... , callbacks=callbacks)
+
+    Add ckpt_callbacks to the callback list like any other additional callback. IMPORTANT: Always add it as the
+    last callback to properly update all tracked callbacks, e.g.
+
+        fit_generator(.... , callbacks=[lr, hist, ckpt_callbacks])
+
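+    When resuming a training process, update the tracked objects first (a sketch; assumes lr, hist and the
+    callbacks list above were rebuilt from the restored pickle files):
+
+        ckpt_callbacks.update_best(hist)
+        ckpt_callbacks.update_callbacks(callbacks)
+        fit_generator(.... , initial_epoch=len(hist.epoch), callbacks=[lr, hist, ckpt_callbacks])
+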
     """
     def __init__(self, *args, **kwargs):
         self.callbacks = kwargs.pop("callbacks")
         super().__init__(*args, **kwargs)
 
     def update_best(self, hist):
+        """
+        Update the internal best value when resuming a training process. Otherwise best is set to +/- inf depending
+        on the performance metric, and the first model trained after resuming will always be saved as best model
+        because its performance will be better than infinity. To prevent this behaviour and compare against the
+        best performance of the interrupted run, call this method before resuming the training process.
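+
+        A sketch, assuming the default monitor "val_loss":
+
+            ckpt_callbacks.update_best(hist)  # best becomes hist.history["val_loss"][-1]
+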
+        :param hist: The History object from the previous (interrupted) training.
+        """
         self.best = hist.history.get(self.monitor)[-1]
 
     def update_callbacks(self, callbacks):
+        """
+        Update all stored callback objects. The callbacks argument needs to follow the same convention as described
+        in the class docstring (a list of dictionaries). Must be run before resuming a training process.
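+
+        A sketch, reusing the callbacks list from the class docstring:
+
+            ckpt_callbacks.update_callbacks(callbacks)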
+        """
         self.callbacks = callbacks
 
     def on_epoch_end(self, epoch, logs=None):
+        """
+        Save model as usual (see ModelCheckpoint class), but also save additional callbacks.
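+
+        A pickled callback can later be restored to resume training (a sketch; loading is not part of this class):
+
+            with open(callbacks_name % "hist", "rb") as f:
+                hist = pickle.load(f)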
+        """
         super().on_epoch_end(epoch, logs)
 
         for callback in self.callbacks:
@@ -100,9 +167,12 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
                 if self.save_best_only:
                     current = logs.get(self.monitor)
                     if current == self.best:
-                        print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
+                        if self.verbose > 0:
+                            print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                         with open(file_path, "wb") as f:
                             pickle.dump(callback["callback"], f)
                 else:
                     with open(file_path, "wb") as f:
-                        pickle.dump(callback["callback"], f)
\ No newline at end of file
+                        if self.verbose > 0:
+                            print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
+                        pickle.dump(callback["callback"], f)
-- 
GitLab