diff --git a/requirements.txt b/requirements.txt
index b46f44416cf6560ecc0b62f8d22dd7d547a036c6..71bb1338effff38092510982d4a2c1f37f7b026a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -59,6 +59,7 @@ tensorflow-estimator==1.13.0
 tensorflow==1.13.1
 termcolor==1.1.0
 toolz==0.10.0
+typing-extensions
 urllib3==1.25.8
 wcwidth==0.1.8
 Werkzeug==1.0.0
diff --git a/requirements_gpu.txt b/requirements_gpu.txt
index 6ce4df8fe164408024e21db5ea94a692fb5dbf26..5ddb56acc71e0a51abb99b9447f871ddcb715a5d 100644
--- a/requirements_gpu.txt
+++ b/requirements_gpu.txt
@@ -59,6 +59,7 @@ tensorflow-estimator==1.13.0
 tensorflow-gpu==1.13.1
 termcolor==1.1.0
 toolz==0.10.0
+typing-extensions
 urllib3==1.25.8
 wcwidth==0.1.8
 Werkzeug==1.0.0
diff --git a/src/model_modules/keras_extensions.py b/src/model_modules/keras_extensions.py
index 180e324602da25e1df8fb218c1d3bba180004ac8..0b374bc4cfb55c945aeceb54112579716e1c6c17 100644
--- a/src/model_modules/keras_extensions.py
+++ b/src/model_modules/keras_extensions.py
@@ -1,25 +1,31 @@
+"""Collection of different extensions to the Keras framework."""
+
 __author__ = 'Lukas Leufen, Felix Kleinert'
 __date__ = '2020-01-31'
 
 import logging
 import math
 import pickle
-from typing import Union
+from typing import Union, List
+from typing_extensions import TypedDict
 
 import numpy as np
 from keras import backend as K
-from keras.callbacks import History, ModelCheckpoint
+from keras.callbacks import History, ModelCheckpoint, Callback
 
 from src import helpers
 
 
 class HistoryAdvanced(History):
     """
-    This is almost an identical clone of the original History class. The only difference is that attributes epoch and
-    history are instantiated during the init phase and not during on_train_begin. This is required to resume an already
-    started but disrupted training from an saved state. This HistoryAdvanced callback needs to be added separately as
-    additional callback. To get the full history use this object for further steps instead of the default return of
-    training methods like fit_generator().
+    This is almost an identical clone of the original History class.
+
+    The only difference is that attributes epoch and history are instantiated during the init phase and not during
+    on_train_begin. This is required to resume an already started but disrupted training from a saved state. This
+    HistoryAdvanced callback needs to be added separately as an additional callback. To get the full history, use
+    this object for further steps instead of the default return of training methods like fit_generator().
+
+    .. code-block:: python
 
         hist = HistoryAdvanced()
         history = model.fit_generator(generator=.... , callbacks=[hist])
@@ -29,21 +35,30 @@ class HistoryAdvanced(History):
     """
 
     def __init__(self):
+        """Set up HistoryAdvanced."""
         self.epoch = []
         self.history = {}
         super().__init__()
 
     def on_train_begin(self, logs=None):
+        """Overload on_train_begin method to do nothing instead of resetting epoch and history."""
         pass
 
 
 class LearningRateDecay(History):
     """
-    Decay learning rate during model training. Start with a base learning rate and lower this rate after every
-    n(=epochs_drop) epochs by drop value (0, 1], drop value = 1 means no decay in learning rate.
+    Decay learning rate during model training.
+
+    Start with a base learning rate and lower this rate after every n(=epochs_drop) epochs by drop value (0, 1], drop
+    value = 1 means no decay in learning rate.
+
+    :param base_lr: base learning rate to start with
+    :param drop: ratio to drop after epochs_drop
+    :param epochs_drop: number of epochs after which the drop takes place
     """
 
     def __init__(self, base_lr: float = 0.01, drop: float = 0.96, epochs_drop: int = 8):
+        """Set up LearningRateDecay."""
         super().__init__()
         self.lr = {'lr': []}
         self.base_lr = self.check_param(base_lr, 'base_lr')
@@ -55,13 +70,16 @@ class LearningRateDecay(History):
 
     @staticmethod
     def check_param(value: float, name: str, lower: Union[float, None] = 0, upper: Union[float, None] = 1):
         """
-        Check if given value is in interval. The left (lower) endpoint is open, right (upper) endpoint is closed. To
-        only one side of the interval, set the other endpoint to None. If both ends are set to None, just return the
-        value without any check.
+        Check if given value is in interval.
+
+        The left (lower) endpoint is open, right (upper) endpoint is closed. To use only one side of the interval, set
+        the other endpoint to None. If both ends are set to None, just return the value without any check.
+
         :param value: value to check
         :param name: name of the variable to display in error message
         :param lower: left (lower) endpoint of interval, opened
         :param upper: right (upper) endpoint of interval, closed
+
         :return: unchanged value or raise ValueError
         """
         if lower is None:
@@ -75,11 +93,13 @@
                              f"{name}={value}")
 
     def on_train_begin(self, logs=None):
+        """Overload on_train_begin method to do nothing instead of resetting epoch and history."""
         pass
 
     def on_epoch_begin(self, epoch: int, logs=None):
         """
         Lower learning rate every epochs_drop epochs by factor drop.
+
         :param epoch: current epoch
         :param logs: ?
         :return: update keras learning rate
@@ -93,46 +113,66 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
     """
-    Enhance the standard ModelCheckpoint class by additional saves of given callbacks. Specify this callbacks as follow:
+    Enhance the standard ModelCheckpoint class by additional saves of given callbacks.
+
+    **We recommend using CallbackHandler instead of ModelCheckpointAdvanced.** CallbackHandler will handle all your
+    callbacks including the ModelCheckpointAdvanced and protect you from pitfalls like a wrong ordering of callbacks.
+    Internally, CallbackHandler makes use of ModelCheckpointAdvanced.
+
+    However, if you want to use the ModelCheckpointAdvanced explicitly, follow these instructions:
+
+    .. code-block:: python
+
+        # load your callbacks
         lr = CustomLearningRate()
         hist = CustomHistory()
+
+        # store your callbacks in a list of dictionaries
         callbacks_name = "your_custom_path_%s.pickle"
         callbacks = [{"callback": lr, "path": callbacks_name % "lr"},
-            {"callback": hist, "path": callbacks_name % "hist"}]
+                     {"callback": hist, "path": callbacks_name % "hist"}]
 
+        # initialise ModelCheckpointAdvanced like the normal ModelCheckpoint (see keras callbacks)
         ckpt_callbacks = ModelCheckpointAdvanced(filepath=.... , callbacks=callbacks)
 
-    Add this ckpt_callbacks as all other additional callbacks to the callback list. IMPORTANT: Always add ckpt_callbacks
-    as last callback to properly update all tracked callbacks, e.g.
+    Add ModelCheckpointAdvanced like all other additional callbacks to the callback list. IMPORTANT: Always add
+    ModelCheckpointAdvanced as the last callback to properly update all tracked callbacks, e.g.
+
+    .. code-block:: python
 
+        # always add ModelCheckpointAdvanced as last element
         fit_generator(.... , callbacks=[lr, hist, ckpt_callbacks])
 
     """
+
     def __init__(self, *args, **kwargs):
+        """Initialise ModelCheckpointAdvanced and set callbacks attribute."""
         self.callbacks = kwargs.pop("callbacks")
         super().__init__(*args, **kwargs)
 
     def update_best(self, hist):
         """
-        Update internal best on resuming a training process. Otherwise best is set to +/- inf depending on the
-        performance metric and the first trained model (first of the resuming training process) will always saved as
-        best model because its performance will be better than infinity. To prevent this behaviour and compare the
-        performance with the best model performance, call this method before resuming the training process.
+        Update internal best on resuming a training process.
+
+        If no best object is available, best is set to +/- inf depending on the performance metric, and the first
+        trained model (first of the resumed training process) will always be saved as best model because its
+        performance will be better than infinity. To prevent this behaviour and compare the performance with the best
+        model performance, call this method before resuming the training process.
+
        :param hist: The History object from the previous (interrupted) training.
         """
         self.best = hist.history.get(self.monitor)[-1]
 
     def update_callbacks(self, callbacks):
         """
-        Update all stored callback objects. The argument callbacks needs to follow the same convention like described
-        in the class description (list of dictionaries). Must be run before resuming a training process.
+        Update all stored callback objects.
+
+        The argument callbacks needs to follow the same convention as described in the class description (list of
+        dictionaries). Must be run before resuming a training process.
         """
         self.callbacks = helpers.to_list(callbacks)
 
     def on_epoch_end(self, epoch, logs=None):
-        """
-        Save model as usual (see ModelCheckpoint class), but also save additional callbacks.
-        """
+        """Save model as usual (see ModelCheckpoint class), but also save additional callbacks."""
         super().on_epoch_end(epoch, logs)
 
         for callback in self.callbacks:
@@ -152,10 +192,65 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
                         pickle.dump(callback["callback"], f)
 
 
+clbk_type = TypedDict("clbk_type", {"name": str, str: Callback, "path": str})
+
+
 class CallbackHandler:
+    r"""Use the CallbackHandler for better control of custom callbacks.
+
+    The callback handler will always keep your callbacks in the right order and adds a model checkpoint in the last
+    position if required. You can add an arbitrary number of callbacks to the handler.
+
+    .. code-block:: python
+
+        # init callbacks handler
+        callbacks = CallbackHandler()
+
+        # set history object (add further elements like this example)
+        hist = keras.callbacks.History()
+        callbacks.add_callback(hist, "callbacks-hist.pickle", "hist")
+
+        # create advanced checkpoint (for details see ModelCheckpointAdvanced)
+        ckpt_name = "model-best.h5"
+        callbacks.create_model_checkpoint(filepath=ckpt_name, verbose=1, ...)
+
+        # get checkpoint
+        ckpt = callbacks.get_checkpoint()
+
+        # fit already compiled model and add callbacks, it is important to call get_callbacks with as_dict=False
+        history = model.fit(..., callbacks=callbacks.get_callbacks(as_dict=False))
+
+    If you want to continue a training, you can use the callback handler to load already stored callbacks. First you
+    need to reload all callbacks. Make sure that all callbacks are available from the previous training. If the
+    callback handler was set up as in the previous code example, this will work.
+
+    .. code-block:: python
+
+        # load callbacks and update checkpoint
+        callbacks.load_callbacks()
+        callbacks.update_checkpoint()
+
+        # optional: load your model using the checkpoint path
+        model = keras.models.load_model(ckpt.filepath)
+
+        # extract history object and set starting epoch
+        hist = callbacks.get_callback_by_name("hist")
+        initial_epoch = max(hist.epoch) + 1
+
+        # resume training (including initial_epoch) and use the callback handler's history object
+        _ = model.fit(..., callbacks=callbacks.get_callbacks(as_dict=False), initial_epoch=initial_epoch)
+        history = hist
+
+    Important notes: Do not use the returned history object of model.fit, but use the history object from the
+    callback handler. The fit history will only contain the new history, whereas the callback handler's history
+    contains the full history including the resumed and new history. For a correct epoch counting, you need to pass
+    the initial epoch to the fit method too.
+
+    """
 
     def __init__(self):
-        self.__callbacks = []
+        """Initialise CallbackHandler."""
+        self.__callbacks: List[clbk_type] = []
         self._checkpoint = None
         self.editable = True
 
@@ -168,46 +263,79 @@ class CallbackHandler:
             name, callback, callback_path = value
             self.__callbacks.append({"name": name, name: callback, "path": callback_path})
 
-    def _update_callback(self, pos, value):
+    def _update_callback(self, pos: int, value: Callback) -> None:
+        """Update callback entry with given value."""
         name = self.__callbacks[pos]["name"]
         self.__callbacks[pos][name] = value
 
-    def add_callback(self, callback, callback_path, name="callback"):
+    def add_callback(self, callback: Callback, callback_path: str, name: str = "callback") -> None:
+        """
+        Add given callback at last position if CallbackHandler is editable.
+
+        Save callback with given name. Will raise a PermissionError if editable is False.
+
+        :param callback: callback object to store
+        :param callback_path: path to callback
+        :param name: name of the callback
+        """
         if self.editable:
             self._callbacks = (name, callback, callback_path)
         else:
             raise PermissionError(f"{__class__.__name__} is protected and cannot be edited.")
 
-    def get_callbacks(self, as_dict=True):
+    def get_callbacks(self, as_dict=True) -> Union[List[clbk_type], List[Callback]]:
+        """
+        Get all callbacks including checkpoint on last position.
+
+        :param as_dict: set return format, either callback dictionary structure (as_dict=True, default) or plain list of callback objects (as_dict=False)
+
+        :return: all callbacks either as callback dictionary structure (embedded in a list) or as raw objects in a list
+        """
         if as_dict:
             return self._get_callbacks()
         else:
             return [clb["callback"] for clb in self._get_callbacks()]
 
-    def get_callback_by_name(self, obj_name):
+    def get_callback_by_name(self, obj_name: str) -> Callback:
+        """
+        Get single callback by its name.
+
+        :param obj_name: name of callback to look for
+
+        :return: requested callback object
+        """
         if obj_name != "callback":
             return [clbk[clbk["name"]] for clbk in self.__callbacks if clbk["name"] == obj_name][0]
 
-    def _get_callbacks(self):
+    def _get_callbacks(self) -> List[clbk_type]:
+        """Return all callbacks and append the checkpoint, if available, at the last position."""
         clbks = self._callbacks
         if self._checkpoint is not None:
             clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath}]
         return clbks
 
-    def get_checkpoint(self):
+    def get_checkpoint(self) -> ModelCheckpointAdvanced:
+        """Return current checkpoint if available."""
        if self._checkpoint is not None:
             return self._checkpoint
 
     def create_model_checkpoint(self, **kwargs):
+        """Create a model checkpoint and lock the handler against further callback edits."""
         self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, **kwargs)
         self.editable = False
 
-    def load_callbacks(self):
+    def load_callbacks(self) -> None:
+        """Load callbacks from path and save in callback attribute."""
         for pos, callback in enumerate(self.__callbacks):
             path = callback["path"]
             clb = pickle.load(open(path, "rb"))
             self._update_callback(pos, clb)
 
-    def update_checkpoint(self, history_name="hist"):
+    def update_checkpoint(self, history_name: str = "hist") -> None:
+        """
+        Update callbacks and history's best elements.
+
+        :param history_name: name of history object
+        """
         self._checkpoint.update_callbacks(self._callbacks)
         self._checkpoint.update_best(self.get_callback_by_name(history_name))
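
Note on the decay schedule documented above: the diff does not show the body of LearningRateDecay.on_epoch_begin, so the following is only a minimal standalone sketch of the step-decay rule described in the class docstring (lower the rate by factor drop every epochs_drop epochs), i.e. lr = base_lr * drop ** floor(epoch / epochs_drop). The helper name step_decay is hypothetical and not part of the module.

.. code-block:: python

    import math

    def step_decay(base_lr: float, drop: float, epochs_drop: int, epoch: int) -> float:
        """Hypothetical helper: learning rate at a given epoch under step decay."""
        # every epochs_drop epochs, multiply the base rate by drop once more
        return base_lr * math.pow(drop, math.floor(epoch / epochs_drop))

    # with the class defaults (base_lr=0.01, drop=0.96, epochs_drop=8) the rate
    # stays at 0.01 for epochs 0-7 and is scaled by 0.96 at epochs 8, 16, 24, ...
    for epoch in (0, 7, 8, 16, 24):
        print(epoch, step_decay(0.01, 0.96, 8, epoch))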