Skip to content
Snippets Groups Projects
Commit 0ec29f6a authored by lukas leufen's avatar lukas leufen
Browse files

refactoring of check_param

parent 602ecea1
No related branches found
No related tags found
2 merge requests!9new version v0.2.0,!7l_p_loss and lrdecay implementation
Pipeline #25748 passed
...@@ -7,6 +7,7 @@ import keras ...@@ -7,6 +7,7 @@ import keras
import keras.backend as K import keras.backend as K
import math import math
from typing import Union from typing import Union
import numpy as np
def to_list(arg): def to_list(arg):
...@@ -52,22 +53,15 @@ class LearningRateDecay(keras.callbacks.History): ...@@ -52,22 +53,15 @@ class LearningRateDecay(keras.callbacks.History):
:param upper: right (upper) endpoint of interval, closed :param upper: right (upper) endpoint of interval, closed
:return: unchanged value or raise ValueError :return: unchanged value or raise ValueError
""" """
if all(v is not None for v in [lower, upper]): if lower is None:
lower = -np.inf
if upper is None:
upper = np.inf
if lower < value <= upper: if lower < value <= upper:
return value return value
else: else:
raise ValueError(f"{name} is out of allowed range ({lower}, {upper}]: {name}={value}") raise ValueError(f"{name} is out of allowed range ({lower}, {upper}{')' if upper == np.inf else ']'}: "
elif lower is not None: f"{name}={value}")
if lower < value:
return value
else:
raise ValueError(f"{name} is out of allowed range ({lower}, +inf): {name}={value}")
elif upper is not None:
if value <= upper:
return value
else:
raise ValueError(f"{name} is out of allowed range (-inf, {upper}]: {name}={value}")
return value
def on_epoch_begin(self, epoch: int, logs=None): def on_epoch_begin(self, epoch: int, logs=None):
""" """
......
...@@ -41,7 +41,7 @@ class TestLearningRateDecay: ...@@ -41,7 +41,7 @@ class TestLearningRateDecay:
assert lr_decay.check_param(1.5, "tester", upper=None) == 1.5 assert lr_decay.check_param(1.5, "tester", upper=None) == 1.5
with pytest.raises(ValueError) as e: with pytest.raises(ValueError) as e:
lr_decay.check_param(0, "tester", upper=None) lr_decay.check_param(0, "tester", upper=None)
assert "tester is out of allowed range (0, +inf): tester=0" in e.value.args[0] assert "tester is out of allowed range (0, inf): tester=0" in e.value.args[0]
assert lr_decay.check_param(0.5, "tester", lower=None) == 0.5 assert lr_decay.check_param(0.5, "tester", lower=None) == 0.5
with pytest.raises(ValueError) as e: with pytest.raises(ValueError) as e:
lr_decay.check_param(0.5, "tester", lower=None, upper=0.2) lr_decay.check_param(0.5, "tester", lower=None, upper=0.2)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment