diff --git a/src/helpers.py b/src/helpers.py index 50e99cef006a94c959233bd3d4d2d378adabcda5..342a0b5ef77d8b286291792aa007f59c8c7b09b2 100644 --- a/src/helpers.py +++ b/src/helpers.py @@ -7,6 +7,7 @@ import keras import keras.backend as K import math from typing import Union +import numpy as np def to_list(arg): @@ -52,22 +53,15 @@ class LearningRateDecay(keras.callbacks.History): :param upper: right (upper) endpoint of interval, closed :return: unchanged value or raise ValueError """ - if all(v is not None for v in [lower, upper]): - if lower < value <= upper: - return value - else: - raise ValueError(f"{name} is out of allowed range ({lower}, {upper}]: {name}={value}") - elif lower is not None: - if lower < value: - return value - else: - raise ValueError(f"{name} is out of allowed range ({lower}, +inf): {name}={value}") - elif upper is not None: - if value <= upper: - return value - else: - raise ValueError(f"{name} is out of allowed range (-inf, {upper}]: {name}={value}") - return value + if lower is None: + lower = -np.inf + if upper is None: + upper = np.inf + if lower < value <= upper: + return value + else: + raise ValueError(f"{name} is out of allowed range ({lower}, {upper}{')' if upper == np.inf else ']'}: " + f"{name}={value}") def on_epoch_begin(self, epoch: int, logs=None): """ diff --git a/test/test_helpers.py b/test/test_helpers.py index 69c5909ad3f7ee6be1432a5992278d4c6e873df5..e4a23d15b6ef497849af28ffd783b6f51c6c5b5d 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -41,7 +41,7 @@ class TestLearningRateDecay: assert lr_decay.check_param(1.5, "tester", upper=None) == 1.5 with pytest.raises(ValueError) as e: lr_decay.check_param(0, "tester", upper=None) - assert "tester is out of allowed range (0, +inf): tester=0" in e.value.args[0] + assert "tester is out of allowed range (0, inf): tester=0" in e.value.args[0] assert lr_decay.check_param(0.5, "tester", lower=None) == 0.5 with pytest.raises(ValueError) as e: lr_decay.check_param(0.5, "tester", lower=None, upper=0.2)