diff --git a/requirements.txt b/requirements.txt
index b46f44416cf6560ecc0b62f8d22dd7d547a036c6..71bb1338effff38092510982d4a2c1f37f7b026a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -59,6 +59,7 @@ tensorflow-estimator==1.13.0
 tensorflow==1.13.1
 termcolor==1.1.0
 toolz==0.10.0
+typing-extensions
 urllib3==1.25.8
 wcwidth==0.1.8
 Werkzeug==1.0.0
diff --git a/requirements_gpu.txt b/requirements_gpu.txt
index 6ce4df8fe164408024e21db5ea94a692fb5dbf26..5ddb56acc71e0a51abb99b9447f871ddcb715a5d 100644
--- a/requirements_gpu.txt
+++ b/requirements_gpu.txt
@@ -59,6 +59,7 @@ tensorflow-estimator==1.13.0
 tensorflow-gpu==1.13.1
 termcolor==1.1.0
 toolz==0.10.0
+typing-extensions
 urllib3==1.25.8
 wcwidth==0.1.8
 Werkzeug==1.0.0
diff --git a/src/model_modules/keras_extensions.py b/src/model_modules/keras_extensions.py
index 180e324602da25e1df8fb218c1d3bba180004ac8..0b374bc4cfb55c945aeceb54112579716e1c6c17 100644
--- a/src/model_modules/keras_extensions.py
+++ b/src/model_modules/keras_extensions.py
@@ -1,25 +1,31 @@
+"""Collection of different extensions to keras framework."""
+
 __author__ = 'Lukas Leufen, Felix Kleinert'
 __date__ = '2020-01-31'
 
 import logging
 import math
 import pickle
-from typing import Union
+from typing import Union, List
+from typing_extensions import TypedDict
 
 import numpy as np
 from keras import backend as K
-from keras.callbacks import History, ModelCheckpoint
+from keras.callbacks import History, ModelCheckpoint, Callback
 
 from src import helpers
 
 
 class HistoryAdvanced(History):
     """
-    This is almost an identical clone of the original History class. The only difference is that attributes epoch and
-    history are instantiated during the init phase and not during on_train_begin. This is required to resume an already
-    started but disrupted training from an saved state. This HistoryAdvanced callback needs to be added separately as
-    additional callback. To get the full history use this object for further steps instead of the default return of
-    training methods like fit_generator().
+    This is almost an identical clone of the original History class.
+
+    The only difference is that attributes epoch and history are instantiated during the init phase and not during
+    on_train_begin. This is required to resume an already started but disrupted training from an saved state. This
+    HistoryAdvanced callback needs to be added separately as additional callback. To get the full history use this
+    object for further steps instead of the default return of training methods like fit_generator().
+
+    .. code-block:: python
 
         hist = HistoryAdvanced()
         history = model.fit_generator(generator=.... , callbacks=[hist])
@@ -29,21 +35,30 @@ class HistoryAdvanced(History):
     """
 
     def __init__(self):
+        """Set up HistoryAdvanced."""
         self.epoch = []
         self.history = {}
         super().__init__()
 
     def on_train_begin(self, logs=None):
+        """Overload on_train_begin method to do nothing instead of resetting epoch and history."""
         pass
 
 
 class LearningRateDecay(History):
     """
-    Decay learning rate during model training. Start with a base learning rate and lower this rate after every
-    n(=epochs_drop) epochs by drop value (0, 1], drop value = 1 means no decay in learning rate.
+    Decay learning rate during model training.
+
+    Start with a base learning rate and lower this rate after every n(=epochs_drop) epochs by drop value (0, 1], drop
+    value = 1 means no decay in learning rate.
+
+    :param base_lr: base learning rate to start with
+    :param drop: ratio to drop after epochs_drop
+    :param epochs_drop: number of epochs after that drop takes place
     """
 
     def __init__(self, base_lr: float = 0.01, drop: float = 0.96, epochs_drop: int = 8):
+        """Set up LearningRateDecay."""
         super().__init__()
         self.lr = {'lr': []}
         self.base_lr = self.check_param(base_lr, 'base_lr')
@@ -55,13 +70,16 @@ class LearningRateDecay(History):
     @staticmethod
     def check_param(value: float, name: str, lower: Union[float, None] = 0, upper: Union[float, None] = 1):
         """
-        Check if given value is in interval. The left (lower) endpoint is open, right (upper) endpoint is closed. To
-        only one side of the interval, set the other endpoint to None. If both ends are set to None, just return the
-        value without any check.
+        Check if given value is in interval.
+
+        The left (lower) endpoint is open, right (upper) endpoint is closed. To use only one side of the interval, set
+        the other endpoint to None. If both ends are set to None, just return the value without any check.
+
         :param value: value to check
         :param name: name of the variable to display in error message
         :param lower: left (lower) endpoint of interval, opened
         :param upper: right (upper) endpoint of interval, closed
+
         :return: unchanged value or raise ValueError
         """
         if lower is None:
@@ -75,11 +93,13 @@ class LearningRateDecay(History):
                              f"{name}={value}")
 
     def on_train_begin(self, logs=None):
+        """Overload on_train_begin method to do nothing instead of resetting epoch and history."""
         pass
 
     def on_epoch_begin(self, epoch: int, logs=None):
         """
         Lower learning rate every epochs_drop epochs by factor drop.
+
         :param epoch: current epoch
         :param logs: ?
         :return: update keras learning rate
@@ -93,46 +113,66 @@ class LearningRateDecay(History):
 
 class ModelCheckpointAdvanced(ModelCheckpoint):
     """
-    Enhance the standard ModelCheckpoint class by additional saves of given callbacks. Specify this callbacks as follow:
+    Enhance the standard ModelCheckpoint class by additional saves of given callbacks.
+
+    **We recommend to use CallbackHandler instead of ModelCheckpointAdvanced.** CallbackHandler will handler all your
+    callbacks and the ModelCheckpointAdvanced and prevent you from pitfalls like wrong ordering of callbacks. Actually,
+    CallbackHandler makes use of ModelCheckpointAdvanced.
+
+    However, if you want to use the ModelCheckpointAdvanced explicitly, follow these instructions:
 
+    .. code-block:: python
+
+        # load your callbacks
         lr = CustomLearningRate()
         hist = CustomHistory()
+
+        # set your callbacks with a list dictionary structure
         callbacks_name = "your_custom_path_%s.pickle"
         callbacks = [{"callback": lr, "path": callbacks_name % "lr"},
-                 {"callback": hist, "path": callbacks_name % "hist"}]
+                     {"callback": hist, "path": callbacks_name % "hist"}]
+        # initialise ModelCheckpointAdvanced like the normal ModelCheckpoint (see keras callbacks)
         ckpt_callbacks = ModelCheckpointAdvanced(filepath=.... , callbacks=callbacks)
 
-    Add this ckpt_callbacks as all other additional callbacks to the callback list. IMPORTANT: Always add ckpt_callbacks
-    as last callback to properly update all tracked callbacks, e.g.
+    Add ModelCheckpointAdvanced as all other additional callbacks to the callback list. IMPORTANT: Always add
+    ModelCheckpointAdvanced as last callback to properly update all tracked callbacks, e.g.
+
+    .. code-block:: python
 
+        # always add ModelCheckpointAdvanced as last element
         fit_generator(.... , callbacks=[lr, hist, ckpt_callbacks])
 
     """
+
     def __init__(self, *args, **kwargs):
+        """Initialise ModelCheckpointAdvanced and set callbacks attribute."""
         self.callbacks = kwargs.pop("callbacks")
         super().__init__(*args, **kwargs)
 
     def update_best(self, hist):
         """
-        Update internal best on resuming a training process. Otherwise best is set to +/- inf depending on the
-        performance metric and the first trained model (first of the resuming training process) will always saved as
-        best model because its performance will be better than infinity. To prevent this behaviour and compare the
-        performance with the best model performance, call this method before resuming the training process.
+        Update internal best on resuming a training process.
+
+        If no best object is available, best is set to +/- inf depending on the performance metric and the first trained
+        model (first of the resuming training process) will always saved as best model because its performance will be
+        better than infinity. To prevent this behaviour and compare the performance with the best model performance,
+        call this method before resuming the training process.
+
         :param hist: The History object from the previous (interrupted) training.
         """
         self.best = hist.history.get(self.monitor)[-1]
 
     def update_callbacks(self, callbacks):
         """
-        Update all stored callback objects. The argument callbacks needs to follow the same convention like described
-        in the class description (list of dictionaries). Must be run before resuming a training process.
+        Update all stored callback objects.
+
+        The argument callbacks needs to follow the same convention like described in the class description (list of
+        dictionaries). Must be run before resuming a training process.
         """
         self.callbacks = helpers.to_list(callbacks)
 
     def on_epoch_end(self, epoch, logs=None):
-        """
-        Save model as usual (see ModelCheckpoint class), but also save additional callbacks.
-        """
+        """Save model as usual (see ModelCheckpoint class), but also save additional callbacks."""
         super().on_epoch_end(epoch, logs)
 
         for callback in self.callbacks:
@@ -152,10 +192,65 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
                         pickle.dump(callback["callback"], f)
 
 
+clbk_type = TypedDict("clbk_type", {"name": str, str: Callback, "path": str})
+
+
 class CallbackHandler:
+    r"""Use the CallbackHandler for better controlling of custom callbacks.
+
+    The callback handler will always keep your callbacks in the right order and adds a model checkpoint at last position
+    if required. You can add an arbitrary number of callbacks to the handler.
+
+    .. code-block:: python
+
+        # init callbacks handler
+        callbacks = CallbackHandler()
+
+        # set history object (add further elements like this example)
+        hist = keras.callbacks.History()
+        callbacks.add_callback(hist, "callbacks-hist.pickle", "hist")
+
+        # create advanced checkpoint (details see ModelCheckpointAdvanced)
+        ckpt_name = "model-best.h5"
+        callbacks.create_model_checkpoint(filepath=ckpt_name, verbose=1, ...)
+
+        # get checkpoint
+        ckpt = callbacks.get_checkpoint()
+
+        # fit already compiled model and add callbacks, it is important to call get_callbacks with as_dict=False
+        history = model.fit(..., callbacks=self.callbacks.get_callbacks(as_dict=False))
+
+    If you want to continue a training, you can use the callback handler to load already stored callbacks. First you
+    need to reload all callbacks. Make sure, that all callbacks are available from previous training. If the callback
+    handler was set up like in the former code example, this will work.
+
+    .. code-block:: python
+
+        # load callbacks and update checkpoint
+        callbacks.load_callbacks()
+        callbacks.update_checkpoint()
+
+        # optional: load your model using checkpoint path
+        model = keras.models.load_model(ckpt.filepath)
+
+        # extract history object and set starting epoch
+        hist = callbacks.get_callback_by_name("hist")
+        initial_epoch = max(hist.epoch) + 1
+
+        # resume training (including initial_epoch) and use callback handler's history object
+        _ = self.model.fit(..., callbacks=self.callbacks.get_callbacks(as_dict=False), initial_epoch=initial_epoch)
+        history = hist
+
+    Important notes: Do not use the returned history object of model.fit, but use the history object from callback
+    handler. The fit history will only contain the new history, whereas callback handler's history contains the full
+    history including the resumed and new history. For a correct epoch counting, you need to add the initial epoch to
+    the fit method too.
+
+    """
 
     def __init__(self):
-        self.__callbacks = []
+        """Initialise CallbackHandler."""
+        self.__callbacks: List[clbk_type] = []
         self._checkpoint = None
         self.editable = True
 
@@ -168,46 +263,79 @@ class CallbackHandler:
         name, callback, callback_path = value
         self.__callbacks.append({"name": name, name: callback, "path": callback_path})
 
-    def _update_callback(self, pos, value):
+    def _update_callback(self, pos: int, value: Callback) -> None:
+        """Update callback entry with given value."""
         name = self.__callbacks[pos]["name"]
         self.__callbacks[pos][name] = value
 
-    def add_callback(self, callback, callback_path, name="callback"):
+    def add_callback(self, callback: Callback, callback_path: str, name: str = "callback") -> None:
+        """
+        Add given callback on last position if CallbackHandler is editable.
+
+        Save callback with given name. Will raise a PermissionError, if editable is False.
+
+        :param callback: callback object to store
+        :param callback_path: path to callback
+        :param name: name of the callback
+        """
         if self.editable:
             self._callbacks = (name, callback, callback_path)
         else:
             raise PermissionError(f"{__class__.__name__} is protected and cannot be edited.")
 
-    def get_callbacks(self, as_dict=True):
+    def get_callbacks(self, as_dict=True) -> Union[List[clbk_type], List[Callback]]:
+        """
+        Get all callbacks including checkpoint on last position.
+
+        :param as_dict: set return format, either clbk_type with dictionary structure (as_dict=True, default) or list
+
+        :return: all callbacks either as callback dictionary structure (embedded in a list) or as raw objects in a list
+        """
         if as_dict:
             return self._get_callbacks()
         else:
             return [clb["callback"] for clb in self._get_callbacks()]
 
-    def get_callback_by_name(self, obj_name):
+    def get_callback_by_name(self, obj_name: str) -> Callback:
+        """
+        Get single callback by its name.
+
+        :param obj_name: name of callback to look for
+
+        :return: requested callback object
+        """
         if obj_name != "callback":
             return [clbk[clbk["name"]] for clbk in self.__callbacks if clbk["name"] == obj_name][0]
 
-    def _get_callbacks(self):
+    def _get_callbacks(self) -> List[clbk_type]:
+        """Return all callbacks and append checkpoint if available on last position."""
         clbks = self._callbacks
         if self._checkpoint is not None:
             clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath}]
         return clbks
 
-    def get_checkpoint(self):
+    def get_checkpoint(self) -> ModelCheckpointAdvanced:
+        """Return current checkpoint if available."""
         if self._checkpoint is not None:
             return self._checkpoint
 
     def create_model_checkpoint(self, **kwargs):
+        """Create a model checkpoint and enable edit."""
         self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, **kwargs)
         self.editable = False
 
-    def load_callbacks(self):
+    def load_callbacks(self) -> None:
+        """Load callbacks from path and save in callback attribute."""
         for pos, callback in enumerate(self.__callbacks):
             path = callback["path"]
             clb = pickle.load(open(path, "rb"))
             self._update_callback(pos, clb)
 
-    def update_checkpoint(self, history_name="hist"):
+    def update_checkpoint(self, history_name: str = "hist") -> None:
+        """
+        Update callbacks and history's best elements.
+
+        :param history_name: name of history object
+        """
         self._checkpoint.update_callbacks(self._callbacks)
         self._checkpoint.update_best(self.get_callback_by_name(history_name))