"""Collection of customised keras callbacks: a resumable History, a stepwise LearningRateDecay and an extended
ModelCheckpoint that additionally pickles tracked callbacks."""

__author__ = 'Lukas Leufen, Felix Kleinert'
__date__ = '2020-01-31'

import logging
import math
import pickle
from typing import Union

import numpy as np
from keras import backend as K
from keras.callbacks import History, ModelCheckpoint


class HistoryAdvanced(History):
    """
    History callback that can be preloaded with the epochs and history entries of a previous training run, so that
    training can be resumed without losing the already recorded results.
    """

    def __init__(self, old_epoch=None, old_history=None):
        self.epoch = old_epoch or []
        self.history = old_history or {}
        super().__init__()

    def on_train_begin(self, logs=None):
        # Do not reset epoch and history here; the parent class would clear the preloaded entries.
        pass
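
# Illustrative sketch (comments only, not executed): how HistoryAdvanced can be used to resume a training run from a
# previously pickled history. The file name "hist.pickle" is a hypothetical example.
#
#   with open("hist.pickle", "rb") as f:
#       old_hist = pickle.load(f)
#   hist = HistoryAdvanced(old_hist.epoch, old_hist.history)
#   model.fit_generator(..., callbacks=[hist, ...])  # recording continues where the previous run stopped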


class LearningRateDecay(History):
    """
    Decay the learning rate during model training. Training starts with a base learning rate, which is lowered after
    every `epochs_drop` epochs by multiplying it with the drop factor, a value in (0, 1]. A drop value of 1 means no
    decay of the learning rate.
    """

def __init__(self, base_lr: float = 0.01, drop: float = 0.96, epochs_drop: int = 8):
super().__init__()
self.lr = {'lr': []}
self.base_lr = self.check_param(base_lr, 'base_lr')
self.drop = self.check_param(drop, 'drop')
self.epochs_drop = self.check_param(epochs_drop, 'epochs_drop', upper=None)
self.epoch = []
self.history = {}

    @staticmethod
    def check_param(value: float, name: str, lower: Union[float, None] = 0, upper: Union[float, None] = 1):
        """
        Check if the given value lies in the interval (lower, upper]: the left (lower) endpoint is open, the right
        (upper) endpoint is closed. To check only one side of the interval, set the other endpoint to None. If both
        endpoints are set to None, the value is returned without any check.
        :param value: value to check
        :param name: name of the variable to display in the error message
        :param lower: left (lower) endpoint of the interval, open
        :param upper: right (upper) endpoint of the interval, closed
        :return: the unchanged value, or raise a ValueError if it is out of range
        """
if lower is None:
lower = -np.inf
if upper is None:
upper = np.inf
if lower < value <= upper:
return value
else:
raise ValueError(f"{name} is out of allowed range ({lower}, {upper}{')' if upper == np.inf else ']'}: "
f"{name}={value}")

    def on_train_begin(self, logs=None):
        # Do not reset epoch and history here; the parent History callback would clear them.
        pass

    def on_epoch_begin(self, epoch: int, logs=None):
        """
        Lower the learning rate by the factor `drop` every `epochs_drop` epochs.
        :param epoch: current epoch number
        :param logs: dict of logs (not used here, part of the keras callback signature)
        :return: the learning rate that is now set in the optimizer
        """
current_lr = self.base_lr * math.pow(self.drop, math.floor(epoch / self.epochs_drop))
K.set_value(self.model.optimizer.lr, current_lr)
self.lr['lr'].append(current_lr)
logging.info(f"Set learning rate to {current_lr}")
return K.get_value(self.model.optimizer.lr)
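
# Worked example (comments only): with the defaults base_lr=0.01, drop=0.96 and epochs_drop=8, on_epoch_begin sets
#   epochs  0- 7: lr = 0.01 * 0.96**0 = 0.01
#   epochs  8-15: lr = 0.01 * 0.96**1 = 0.0096
#   epochs 16-23: lr = 0.01 * 0.96**2 = 0.009216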


class ModelCheckpointAdvanced(ModelCheckpoint):
    """
    Extend the keras ModelCheckpoint callback: whenever a model checkpoint is written, all tracked callbacks are
    pickled to their given paths as well.
    IMPORTANT: Always add ModelCheckpointAdvanced as the last callback to properly update all tracked callbacks, e.g.
    fit_generator(callbacks=[..., <last_here>])
    """

    def __init__(self, *args, **kwargs):
        # Tracked callbacks: a list of dicts with keys "callback" (the callback instance) and "path" (pickle target).
        self.callbacks = kwargs.pop("callbacks")
        super().__init__(*args, **kwargs)

    def update_best(self, hist):
        """
        Update the internal best value from a given (e.g. reloaded) history, using the last entry of the monitored
        quantity. Useful when training is resumed from a previous run.
        """
        self.best = hist.history.get(self.monitor)[-1]

    def update_callbacks(self, callbacks):
        """
        Replace the list of tracked callbacks.
        """
        self.callbacks = callbacks

    def on_epoch_end(self, epoch, logs=None):
        """
        Pickle all tracked callbacks to their paths whenever the parent ModelCheckpoint saves the model (respecting
        save_best_only and skipping the very first epoch).
        """
        super().on_epoch_end(epoch, logs)
        if self.epochs_since_last_save == 0 and epoch != 0:
            for callback in self.callbacks:
                file_path = callback["path"]
                if self.save_best_only:
                    current = logs.get(self.monitor)
                    if current == self.best:
                        # self.best was already updated by the parent class, so equality marks a new best epoch.
                        print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                        with open(file_path, "wb") as f:
                            pickle.dump(callback["callback"], f)
                else:
                    with open(file_path, "wb") as f:
                        pickle.dump(callback["callback"], f)
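
# Illustrative usage sketch (comments only, not executed): how the three callbacks above are meant to be combined.
# ``model``, the generators and all file names below are hypothetical placeholders.
#
#   hist = HistoryAdvanced()
#   lr_decay = LearningRateDecay(base_lr=0.01, drop=0.96, epochs_drop=8)
#   tracked = [{"callback": hist, "path": "hist.pickle"},
#              {"callback": lr_decay, "path": "lr.pickle"}]
#   checkpoint = ModelCheckpointAdvanced("model-best.h5", monitor="val_loss", save_best_only=True,
#                                        callbacks=tracked)
#   # ModelCheckpointAdvanced must come last so it sees the updated state of all tracked callbacks for the epoch.
#   model.fit_generator(generator, epochs=20, validation_data=val_generator, callbacks=[hist, lr_decay, checkpoint])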