Commit ae104ccb authored by lukas leufen

keras extensions are documented now

parent a4644e7a
3 merge requests: !125 Release v0.10.0, !124 Update Master to new version v0.10.0, !91 WIP: Resolve "create sphinx docu"
Pipeline #35398 passed
@@ -59,6 +59,7 @@ tensorflow-estimator==1.13.0
tensorflow==1.13.1
termcolor==1.1.0
toolz==0.10.0
typing-extensions
urllib3==1.25.8
wcwidth==0.1.8
Werkzeug==1.0.0
@@ -59,6 +59,7 @@ tensorflow-estimator==1.13.0
tensorflow-gpu==1.13.1
termcolor==1.1.0
toolz==0.10.0
typing-extensions
urllib3==1.25.8
wcwidth==0.1.8
Werkzeug==1.0.0
"""Collection of different extensions to keras framework."""
__author__ = 'Lukas Leufen, Felix Kleinert'
__date__ = '2020-01-31'
import logging
import math
import pickle
from typing import Union, List
from typing_extensions import TypedDict
import numpy as np
from keras import backend as K
from keras.callbacks import History, ModelCheckpoint, Callback
from src import helpers
class HistoryAdvanced(History):
"""
This is almost an identical clone of the original History class.

The only difference is that the attributes epoch and history are instantiated during the init phase and not during
on_train_begin. This is required to resume an already started but disrupted training from a saved state. This
HistoryAdvanced callback needs to be added separately as an additional callback. To get the full history, use this
object for further steps instead of the default return of training methods like fit_generator().
.. code-block:: python
hist = HistoryAdvanced()
history = model.fit_generator(generator=.... , callbacks=[hist])
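To resume a disrupted training, reload the stored HistoryAdvanced object and pass it again as callback. A minimal
sketch, assuming the callback was pickled beforehand to the hypothetical file "callbacks-hist.pickle":

.. code-block:: python

    import pickle

    hist = pickle.load(open("callbacks-hist.pickle", "rb"))
    initial_epoch = max(hist.epoch) + 1
    history = model.fit_generator(generator=.... , callbacks=[hist], initial_epoch=initial_epoch)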
@@ -29,21 +35,30 @@ class HistoryAdvanced(History):
"""
def __init__(self):
"""Set up HistoryAdvanced."""
self.epoch = []
self.history = {}
super().__init__()
def on_train_begin(self, logs=None):
"""Overload on_train_begin method to do nothing instead of resetting epoch and history."""
pass
class LearningRateDecay(History):
"""
Decay learning rate during model training.

Start with a base learning rate and lower this rate after every n (= epochs_drop) epochs by the drop value from the
interval (0, 1]; a drop value of 1 means no decay of the learning rate.

:param base_lr: base learning rate to start with
:param drop: ratio to drop after epochs_drop
:param epochs_drop: number of epochs after which the drop takes place
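The resulting schedule is a standard step decay. As a sketch (the exact exponent, here floor(epoch / epochs_drop),
is an assumption and may differ slightly from the implementation), the learning rate develops as follows:

.. code-block:: python

    import math

    def step_decay(epoch, base_lr=0.01, drop=0.96, epochs_drop=8):
        # multiply base_lr by drop once every epochs_drop epochs
        return base_lr * math.pow(drop, math.floor(epoch / epochs_drop))

    step_decay(0)   # 0.01
    step_decay(8)   # 0.0096
    step_decay(16)  # 0.009216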
"""
def __init__(self, base_lr: float = 0.01, drop: float = 0.96, epochs_drop: int = 8):
"""Set up LearningRateDecay."""
super().__init__()
self.lr = {'lr': []}
self.base_lr = self.check_param(base_lr, 'base_lr')
@@ -55,13 +70,16 @@ class LearningRateDecay(History):
@staticmethod
def check_param(value: float, name: str, lower: Union[float, None] = 0, upper: Union[float, None] = 1):
"""
Check if a given value is in the interval.

The left (lower) endpoint is open, the right (upper) endpoint is closed. To use only one side of the interval, set
the other endpoint to None. If both ends are set to None, just return the value without any check.

:param value: value to check
:param name: name of the variable to display in the error message
:param lower: left (lower) endpoint of interval, open
:param upper: right (upper) endpoint of interval, closed
:return: unchanged value or raise ValueError
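A short sketch of the expected behaviour for the default interval (0, 1], following from the interval definition
above:

.. code-block:: python

    LearningRateDecay.check_param(0.5, "drop")              # returns 0.5
    LearningRateDecay.check_param(1.0, "drop")              # returns 1.0, the upper endpoint is closed
    LearningRateDecay.check_param(0.0, "drop")              # raises ValueError, the lower endpoint is open
    LearningRateDecay.check_param(1.2, "drop", upper=None)  # returns 1.2, no upper bound is checked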
"""
if lower is None:
@@ -75,11 +93,13 @@ class LearningRateDecay(History):
f"{name}={value}")
def on_train_begin(self, logs=None):
"""Overload on_train_begin method to do nothing instead of resetting epoch and history."""
pass
def on_epoch_begin(self, epoch: int, logs=None):
"""
Lower learning rate every epochs_drop epochs by factor drop.
:param epoch: current epoch
:param logs: logs dict passed by keras
:return: update keras learning rate
@@ -93,46 +113,66 @@ class LearningRateDecay(History):
class ModelCheckpointAdvanced(ModelCheckpoint):
"""
Enhance the standard ModelCheckpoint class by additional saves of given callbacks.

**We recommend using CallbackHandler instead of ModelCheckpointAdvanced.** CallbackHandler will handle all your
callbacks including the ModelCheckpointAdvanced and prevent you from pitfalls like a wrong ordering of callbacks.
Internally, CallbackHandler makes use of ModelCheckpointAdvanced.
However, if you want to use the ModelCheckpointAdvanced explicitly, follow these instructions:
.. code-block:: python
# load your callbacks
lr = CustomLearningRate()
hist = CustomHistory()
# set your callbacks with a list dictionary structure
callbacks_name = "your_custom_path_%s.pickle"
callbacks = [{"callback": lr, "path": callbacks_name % "lr"},
{"callback": hist, "path": callbacks_name % "hist"}]
# initialise ModelCheckpointAdvanced like the normal ModelCheckpoint (see keras callbacks)
ckpt_callbacks = ModelCheckpointAdvanced(filepath=.... , callbacks=callbacks)
Add ModelCheckpointAdvanced like all other additional callbacks to the callback list. IMPORTANT: Always add
ModelCheckpointAdvanced as the last callback to properly update all tracked callbacks, e.g.
.. code-block:: python
# always add ModelCheckpointAdvanced as last element
fit_generator(.... , callbacks=[lr, hist, ckpt_callbacks])
"""
def __init__(self, *args, **kwargs):
"""Initialise ModelCheckpointAdvanced and set callbacks attribute."""
self.callbacks = kwargs.pop("callbacks")
super().__init__(*args, **kwargs)
def update_best(self, hist):
"""
Update internal best on resuming a training process.

If no best object is available, best is set to +/- inf depending on the performance metric, and the first trained
model (the first of the resumed training process) will always be saved as best model because its performance will
be better than infinity. To prevent this behaviour and compare the performance with the best model performance of
the previous training, call this method before resuming the training process.
:param hist: The History object from the previous (interrupted) training.
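Building on the class example above, a minimal sketch of the resume step (the pickle file is assumed to exist from
the previous run):

.. code-block:: python

    hist = pickle.load(open(callbacks_name % "hist", "rb"))
    ckpt_callbacks.update_best(hist)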
"""
self.best = hist.history.get(self.monitor)[-1]
def update_callbacks(self, callbacks):
"""
Update all stored callback objects.

The argument callbacks needs to follow the same convention as described in the class description (list of
dictionaries). Must be run before resuming a training process.
"""
self.callbacks = helpers.to_list(callbacks)
def on_epoch_end(self, epoch, logs=None):
"""
Save model as usual (see ModelCheckpoint class), but also save additional callbacks.
"""
"""Save model as usual (see ModelCheckpoint class), but also save additional callbacks."""
super().on_epoch_end(epoch, logs)
for callback in self.callbacks:
@@ -152,10 +192,65 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
pickle.dump(callback["callback"], f)
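# single callback entry as handled by CallbackHandler: the callback's name, the callback object itself (stored
# under its name as key), and the path it is pickled to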
clbk_type = TypedDict("clbk_type", {"name": str, str: Callback, "path": str})
class CallbackHandler:
r"""Use the CallbackHandler for better controlling of custom callbacks.
The callback handler will always keep your callbacks in the right order and adds a model checkpoint at last position
if required. You can add an arbitrary number of callbacks to the handler.
.. code-block:: python
# init callbacks handler
callbacks = CallbackHandler()
# set history object (add further elements like this example)
hist = keras.callbacks.History()
callbacks.add_callback(hist, "callbacks-hist.pickle", "hist")
# create advanced checkpoint (details see ModelCheckpointAdvanced)
ckpt_name = "model-best.h5"
callbacks.create_model_checkpoint(filepath=ckpt_name, verbose=1, ...)
# get checkpoint
ckpt = callbacks.get_checkpoint()
# fit the already compiled model and add callbacks; it is important to call get_callbacks with as_dict=False
history = model.fit(..., callbacks=self.callbacks.get_callbacks(as_dict=False))
If you want to continue a training, you can use the callback handler to load already stored callbacks. First you
need to reload all callbacks. Make sure that all callbacks from the previous training are available. If the
callback handler was set up like in the former code example, this will work.
.. code-block:: python
# load callbacks and update checkpoint
callbacks.load_callbacks()
callbacks.update_checkpoint()
# optional: load your model using checkpoint path
model = keras.models.load_model(ckpt.filepath)
# extract history object and set starting epoch
hist = callbacks.get_callback_by_name("hist")
initial_epoch = max(hist.epoch) + 1
# resume training (including initial_epoch) and use callback handler's history object
_ = self.model.fit(..., callbacks=self.callbacks.get_callbacks(as_dict=False), initial_epoch=initial_epoch)
history = hist
Important notes: Do not use the history object returned by model.fit, but use the history object from the callback
handler. The fit history will only contain the new history, whereas the callback handler's history contains the
full history including the resumed and the new history. For a correct epoch counting, you also need to pass the
initial epoch to the fit method.
"""
def __init__(self):
"""Initialise CallbackHandler."""
self.__callbacks: List[clbk_type] = []
self._checkpoint = None
self.editable = True
@@ -168,46 +263,79 @@ class CallbackHandler:
name, callback, callback_path = value
self.__callbacks.append({"name": name, name: callback, "path": callback_path})
def _update_callback(self, pos: int, value: Callback) -> None:
"""Update callback entry with given value."""
name = self.__callbacks[pos]["name"]
self.__callbacks[pos][name] = value
def add_callback(self, callback: Callback, callback_path: str, name: str = "callback") -> None:
"""
Add given callback at the last position if CallbackHandler is editable.

Save callback with the given name. Raises a PermissionError if editable is False.
:param callback: callback object to store
:param callback_path: path to callback
:param name: name of the callback
"""
if self.editable:
self._callbacks = (name, callback, callback_path)
else:
raise PermissionError(f"{__class__.__name__} is protected and cannot be edited.")
def get_callbacks(self, as_dict=True) -> Union[List[clbk_type], List[Callback]]:
"""
Get all callbacks including the checkpoint at the last position.

:param as_dict: set the return format, either the clbk_type dictionary structure (as_dict=True, default) or a plain list
:return: all callbacks either as callback dictionary structure (embedded in a list) or as raw objects in a list
"""
if as_dict:
return self._get_callbacks()
else:
return [clb["callback"] for clb in self._get_callbacks()]
def get_callback_by_name(self, obj_name: str) -> Callback:
"""
Get single callback by its name.
:param obj_name: name of callback to look for
:return: requested callback object
"""
if obj_name != "callback":
return [clbk[clbk["name"]] for clbk in self.__callbacks if clbk["name"] == obj_name][0]
def _get_callbacks(self) -> List[clbk_type]:
"""Return all callbacks and append checkpoint if available on last position."""
clbks = self._callbacks
if self._checkpoint is not None:
clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath}]
return clbks
def get_checkpoint(self) -> ModelCheckpointAdvanced:
"""Return current checkpoint if available."""
if self._checkpoint is not None:
return self._checkpoint
def create_model_checkpoint(self, **kwargs):
"""Create a model checkpoint and enable edit."""
self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, **kwargs)
self.editable = False
def load_callbacks(self) -> None:
"""Load callbacks from path and save in callback attribute."""
for pos, callback in enumerate(self.__callbacks):
path = callback["path"]
clb = pickle.load(open(path, "rb"))
self._update_callback(pos, clb)
def update_checkpoint(self, history_name: str = "hist") -> None:
"""
Update the checkpoint's callbacks and its best value using the stored history object.
:param history_name: name of history object
"""
self._checkpoint.update_callbacks(self._callbacks)
self._checkpoint.update_best(self.get_callback_by_name(history_name))