diff --git a/src/helpers.py b/src/helpers.py
index f4ac45b2e25156bf8d61990f8e45f09cc0aff4a8..50e99cef006a94c959233bd3d4d2d378adabcda5 100644
--- a/src/helpers.py
+++ b/src/helpers.py
@@ -6,6 +6,7 @@ import logging
 import keras
 import keras.backend as K
 import math
+from typing import Union
 
 
 def to_list(arg):
@@ -26,35 +27,57 @@ def l_p_loss(power: int):
     return loss
 
 
-class lrDecay(keras.callbacks.History):
+class LearningRateDecay(keras.callbacks.History):
+    """
+    Decay the learning rate during model training. Start from a base learning rate and multiply it by the drop
+    factor in (0, 1] after every epochs_drop epochs; drop = 1 means no decay of the learning rate.
+    """
 
     def __init__(self, base_lr: float = 0.01, drop: float = 0.96, epochs_drop: int = 8):
-        super(lrDecay, self).__init__()
-
+        super().__init__()
         self.lr = {'lr': []}
-        self.base_lr = base_lr
-        self.drop = drop
-        self.epochs_drop = epochs_drop
+        self.base_lr = self.check_param(base_lr, 'base_lr')
+        self.drop = self.check_param(drop, 'drop')
+        self.epochs_drop = self.check_param(epochs_drop, 'epochs_drop', upper=None)
+
+    @staticmethod
+    def check_param(value: float, name: str, lower: Union[float, None] = 0, upper: Union[float, None] = 1):
+        """
+        Check whether the given value lies in the interval (lower, upper]. The left (lower) endpoint is open, the
+        right (upper) endpoint is closed. To check against only one side of the interval, set the other endpoint to
+        None. If both endpoints are None, the value is returned without any check.
+        :param value: value to check
+        :param name: name of the variable, used in the error message
+        :param lower: left (lower) endpoint of the interval, open
+        :param upper: right (upper) endpoint of the interval, closed
+        :return: the unchanged value; a ValueError is raised if the value is outside the interval
+        """
+        if all(v is not None for v in [lower, upper]):
+            if lower < value <= upper:
+                return value
+            else:
+                raise ValueError(f"{name} is out of allowed range ({lower}, {upper}]: {name}={value}")
+        elif lower is not None:
+            if lower < value:
+                return value
+            else:
+                raise ValueError(f"{name} is out of allowed range ({lower}, +inf): {name}={value}")
+        elif upper is not None:
+            if value <= upper:
+                return value
+            else:
+                raise ValueError(f"{name} is out of allowed range (-inf, {upper}]: {name}={value}")
+        return value
 
     def on_epoch_begin(self, epoch: int, logs=None):
-        if epoch > 0:
-            current_lr = self.base_lr * math.pow(self.drop, math.floor(1 + epoch) / self.epochs_drop)
-        else:
-            current_lr = self.base_lr
+        """
+        Lower the learning rate by the factor drop every epochs_drop epochs.
+        :param epoch: index of the current epoch
+        :param logs: dict of metrics passed by the keras callback API (unused)
+        :return: learning rate currently set in the keras optimizer
+        """
+        current_lr = self.base_lr * math.pow(self.drop, math.floor(epoch / self.epochs_drop))
         K.set_value(self.model.optimizer.lr, current_lr)
         self.lr['lr'].append(current_lr)
         logging.info(f"Set learning rate to {current_lr}")
         return K.get_value(self.model.optimizer.lr)
-
-
-class lrCallback(keras.callbacks.History):
-
-    def __init__(self):
-        super(lrCallback, self).__init__()
-        self.lr = None
-
-    def on_train_begin(self, logs=None):
-        self.lr = {}
-
-    def on_epoch_end(self, epoch, logs=None):
-        self.lr.append(self.model.optimizer.lr)
\ No newline at end of file
diff --git a/test/test_helpers.py b/test/test_helpers.py
index bc64176e10cac71471ecc78efbcb639bc9fab81f..69c5909ad3f7ee6be1432a5992278d4c6e873df5 100644
--- a/test/test_helpers.py
+++ b/test/test_helpers.py
@@ -1,9 +1,8 @@
 import pytest
-from src.helpers import l_p_loss
+from src.helpers import l_p_loss, LearningRateDecay
 import logging
 import os
 import keras
-import keras.backend as K
 import numpy as np
 
 
@@ -19,3 +18,40 @@ class TestLoss:
         hist = model.fit(np.array([1, 0, -2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=1)
         assert hist.history['loss'][0] == 2.25
 
+
+class TestLearningRateDecay:
+
+    def test_init(self):
+        lr_decay = LearningRateDecay()
+        assert lr_decay.lr == {'lr': []}
+        assert lr_decay.base_lr == 0.01
+        assert lr_decay.drop == 0.96
+        assert lr_decay.epochs_drop == 8
+
+    def test_check_param(self):
+        lr_decay = object.__new__(LearningRateDecay)
+        assert lr_decay.check_param(1, "tester") == 1
+        assert lr_decay.check_param(0.5, "tester") == 0.5
+        with pytest.raises(ValueError) as e:
+            lr_decay.check_param(0, "tester")
+        assert "tester is out of allowed range (0, 1]: tester=0" in e.value.args[0]
+        with pytest.raises(ValueError) as e:
+            lr_decay.check_param(1.5, "tester")
+        assert "tester is out of allowed range (0, 1]: tester=1.5" in e.value.args[0]
+        assert lr_decay.check_param(1.5, "tester", upper=None) == 1.5
+        with pytest.raises(ValueError) as e:
+            lr_decay.check_param(0, "tester", upper=None)
+        assert "tester is out of allowed range (0, +inf): tester=0" in e.value.args[0]
+        assert lr_decay.check_param(0.5, "tester", lower=None) == 0.5
+        with pytest.raises(ValueError) as e:
+            lr_decay.check_param(0.5, "tester", lower=None, upper=0.2)
+        assert "tester is out of allowed range (-inf, 0.2]: tester=0.5" in e.value.args[0]
+        assert lr_decay.check_param(10, "tester", upper=None, lower=None) == 10
+
+    def test_on_epoch_begin(self):
+        lr_decay = LearningRateDecay(base_lr=0.02, drop=0.95, epochs_drop=2)
+        model = keras.Sequential()
+        model.add(keras.layers.Dense(1, input_dim=1))
+        model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2))
+        model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=5, callbacks=[lr_decay])
+        assert lr_decay.lr['lr'] == [0.02, 0.02, 0.02*0.95, 0.02*0.95, 0.02*0.95*0.95]
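For reference, here is a minimal standalone sketch of the step-decay schedule that LearningRateDecay.on_epoch_begin applies. The helper name step_decay and the printed list are illustrative only (not part of the diff); the parameter values mirror those used in test_on_epoch_begin.

```python
import math

def step_decay(epoch: int, base_lr: float = 0.02, drop: float = 0.95, epochs_drop: int = 2) -> float:
    # Same formula as LearningRateDecay.on_epoch_begin:
    # lr(epoch) = base_lr * drop ** floor(epoch / epochs_drop)
    return base_lr * math.pow(drop, math.floor(epoch / epochs_drop))

# Matches the schedule asserted in test_on_epoch_begin for epochs 0..4:
# [0.02, 0.02, 0.019, 0.019, 0.01805]
print([step_decay(e) for e in range(5)])
```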