From 0ec29f6afd2cce559ef09fa12e8d79e1877c60ef Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Tue, 12 Nov 2019 15:40:50 +0100
Subject: [PATCH] refactoring of check_param

---
 src/helpers.py       | 26 ++++++++++----------------
 test/test_helpers.py |  2 +-
 2 files changed, 11 insertions(+), 17 deletions(-)

diff --git a/src/helpers.py b/src/helpers.py
index 50e99cef..342a0b5e 100644
--- a/src/helpers.py
+++ b/src/helpers.py
@@ -7,6 +7,7 @@ import keras
 import keras.backend as K
 import math
 from typing import Union
+import numpy as np
 
 
 def to_list(arg):
@@ -52,22 +53,15 @@ class LearningRateDecay(keras.callbacks.History):
         :param upper: right (upper) endpoint of interval, closed
         :return: unchanged value or raise ValueError
         """
-        if all(v is not None for v in [lower, upper]):
-            if lower < value <= upper:
-                return value
-            else:
-                raise ValueError(f"{name} is out of allowed range ({lower}, {upper}]: {name}={value}")
-        elif lower is not None:
-            if lower < value:
-                return value
-            else:
-                raise ValueError(f"{name} is out of allowed range ({lower}, +inf): {name}={value}")
-        elif upper is not None:
-            if value <= upper:
-                return value
-            else:
-                raise ValueError(f"{name} is out of allowed range (-inf, {upper}]: {name}={value}")
-        return value
+        if lower is None:
+            lower = -np.inf
+        if upper is None:
+            upper = np.inf
+        if lower < value <= upper:
+            return value
+        else:
+            raise ValueError(f"{name} is out of allowed range ({lower}, {upper}{')' if upper == np.inf else ']'}: "
+                             f"{name}={value}")
 
     def on_epoch_begin(self, epoch: int, logs=None):
         """
diff --git a/test/test_helpers.py b/test/test_helpers.py
index 69c5909a..e4a23d15 100644
--- a/test/test_helpers.py
+++ b/test/test_helpers.py
@@ -41,7 +41,7 @@ class TestLearningRateDecay:
         assert lr_decay.check_param(1.5, "tester", upper=None) == 1.5
         with pytest.raises(ValueError) as e:
             lr_decay.check_param(0, "tester", upper=None)
-        assert "tester is out of allowed range (0, +inf): tester=0" in e.value.args[0]
+        assert "tester is out of allowed range (0, inf): tester=0" in e.value.args[0]
         assert lr_decay.check_param(0.5, "tester", lower=None) == 0.5
         with pytest.raises(ValueError) as e:
             lr_decay.check_param(0.5, "tester", lower=None, upper=0.2)
-- 
GitLab