Skip to content
Snippets Groups Projects
Commit 0ec29f6a authored by lukas leufen's avatar lukas leufen
Browse files

refactoring of check_param

parent 602ecea1
Branches
Tags
2 merge requests!9new version v0.2.0,!7l_p_loss and lrdecay implementation
Pipeline #25748 passed
......@@ -7,6 +7,7 @@ import keras
import keras.backend as K
import math
from typing import Union
import numpy as np
def to_list(arg):
......@@ -52,22 +53,15 @@ class LearningRateDecay(keras.callbacks.History):
:param upper: right (upper) endpoint of interval, closed
:return: unchanged value or raise ValueError
"""
if all(v is not None for v in [lower, upper]):
if lower is None:
lower = -np.inf
if upper is None:
upper = np.inf
if lower < value <= upper:
return value
else:
raise ValueError(f"{name} is out of allowed range ({lower}, {upper}]: {name}={value}")
elif lower is not None:
if lower < value:
return value
else:
raise ValueError(f"{name} is out of allowed range ({lower}, +inf): {name}={value}")
elif upper is not None:
if value <= upper:
return value
else:
raise ValueError(f"{name} is out of allowed range (-inf, {upper}]: {name}={value}")
return value
raise ValueError(f"{name} is out of allowed range ({lower}, {upper}{')' if upper == np.inf else ']'}: "
f"{name}={value}")
def on_epoch_begin(self, epoch: int, logs=None):
"""
......
......@@ -41,7 +41,7 @@ class TestLearningRateDecay:
assert lr_decay.check_param(1.5, "tester", upper=None) == 1.5
with pytest.raises(ValueError) as e:
lr_decay.check_param(0, "tester", upper=None)
assert "tester is out of allowed range (0, +inf): tester=0" in e.value.args[0]
assert "tester is out of allowed range (0, inf): tester=0" in e.value.args[0]
assert lr_decay.check_param(0.5, "tester", lower=None) == 0.5
with pytest.raises(ValueError) as e:
lr_decay.check_param(0.5, "tester", lower=None, upper=0.2)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment