diff --git a/src/helpers.py b/src/helpers.py
index f4ac45b2e25156bf8d61990f8e45f09cc0aff4a8..50e99cef006a94c959233bd3d4d2d378adabcda5 100644
--- a/src/helpers.py
+++ b/src/helpers.py
@@ -6,6 +6,7 @@ import logging
 import keras
 import keras.backend as K
 import math
+from typing import Union
 
 
 def to_list(arg):
@@ -26,35 +27,57 @@ def l_p_loss(power: int):
     return loss
 
 
-class lrDecay(keras.callbacks.History):
+class LearningRateDecay(keras.callbacks.History):
+    """
+    Decay the learning rate during model training. Starting from a base learning rate, the rate is multiplied
+    by a drop factor in (0, 1] after every epochs_drop epochs; a drop of 1 means no decay.
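+
+    The effective learning rate at epoch e is base_lr * drop ** floor(e / epochs_drop).
+    Illustrative usage (assumes a compiled Keras model and training arrays x_train, y_train):
+
+        lr_decay = LearningRateDecay(base_lr=0.01, drop=0.96, epochs_drop=8)
+        model.fit(x_train, y_train, epochs=20, callbacks=[lr_decay])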
+    """
 
     def __init__(self, base_lr: float = 0.01, drop: float = 0.96, epochs_drop: int = 8):
-        super(lrDecay, self).__init__()
-
+        super().__init__()
         self.lr = {'lr': []}
-        self.base_lr = base_lr
-        self.drop = drop
-        self.epochs_drop = epochs_drop
+        self.base_lr = self.check_param(base_lr, 'base_lr')
+        self.drop = self.check_param(drop, 'drop')
+        self.epochs_drop = self.check_param(epochs_drop, 'epochs_drop', upper=None)
+
+    @staticmethod
+    def check_param(value: float, name: str, lower: Union[float, None] = 0, upper: Union[float, None] = 1):
+        """
+        Check if the given value lies in the interval. The left (lower) endpoint is open, the right (upper)
+        endpoint is closed. To check only one side of the interval, set the other endpoint to None. If both
+        endpoints are set to None, the value is returned without any check.
+        :param value: value to check
+        :param name: name of the variable to display in the error message
+        :param lower: left (lower) endpoint of the interval, open
+        :param upper: right (upper) endpoint of the interval, closed
+        :return: the unchanged value
+        :raises ValueError: if value is outside the allowed interval
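+
+        Illustrative calls, using parameter names from this class:
+
+            check_param(0.96, 'drop')                  # returns 0.96, inside (0, 1]
+            check_param(0, 'drop')                     # raises ValueError, the lower endpoint is open
+            check_param(8, 'epochs_drop', upper=None)  # returns 8, only the lower bound is checked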
+        """
+        if all(v is not None for v in [lower, upper]):
+            if lower < value <= upper:
+                return value
+            else:
+                raise ValueError(f"{name} is out of allowed range ({lower}, {upper}]: {name}={value}")
+        elif lower is not None:
+            if lower < value:
+                return value
+            else:
+                raise ValueError(f"{name} is out of allowed range ({lower}, +inf): {name}={value}")
+        elif upper is not None:
+            if value <= upper:
+                return value
+            else:
+                raise ValueError(f"{name} is out of allowed range (-inf, {upper}]: {name}={value}")
+        return value
 
     def on_epoch_begin(self, epoch: int, logs=None):
-        if epoch > 0:
-            current_lr = self.base_lr * math.pow(self.drop, math.floor(1 + epoch) / self.epochs_drop)
-        else:
-            current_lr = self.base_lr
+        """
+        Lower the learning rate by the factor drop every epochs_drop epochs and write it to the optimizer.
+        :param epoch: current epoch
+        :param logs: dict passed by Keras; unused here
+        :return: current learning rate, read back from the optimizer
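+
+        For example, with base_lr=0.02, drop=0.95 and epochs_drop=2 (the values used in the tests),
+        epochs 0 and 1 run at 0.02, epochs 2 and 3 at 0.02*0.95 and epoch 4 at 0.02*0.95**2.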
+        """
+        current_lr = self.base_lr * math.pow(self.drop, math.floor(epoch / self.epochs_drop))
         K.set_value(self.model.optimizer.lr, current_lr)
         self.lr['lr'].append(current_lr)
         logging.info(f"Set learning rate to {current_lr}")
         return K.get_value(self.model.optimizer.lr)
-
-
-class lrCallback(keras.callbacks.History):
-
-    def __init__(self):
-        super(lrCallback, self).__init__()
-        self.lr = None
-
-    def on_train_begin(self, logs=None):
-        self.lr = {}
-
-    def on_epoch_end(self, epoch, logs=None):
-        self.lr.append(self.model.optimizer.lr)
\ No newline at end of file
diff --git a/test/test_helpers.py b/test/test_helpers.py
index bc64176e10cac71471ecc78efbcb639bc9fab81f..69c5909ad3f7ee6be1432a5992278d4c6e873df5 100644
--- a/test/test_helpers.py
+++ b/test/test_helpers.py
@@ -1,9 +1,8 @@
 import pytest
-from src.helpers import l_p_loss
+from src.helpers import l_p_loss, LearningRateDecay
 import logging
 import os
 import keras
-import keras.backend as K
 import numpy as np
 
 
@@ -19,3 +18,40 @@ class TestLoss:
         hist = model.fit(np.array([1, 0, -2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=1)
         assert hist.history['loss'][0] == 2.25
 
+
+class TestLearningRateDecay:
+
+    def test_init(self):
+        lr_decay = LearningRateDecay()
+        assert lr_decay.lr == {'lr': []}
+        assert lr_decay.base_lr == 0.01
+        assert lr_decay.drop == 0.96
+        assert lr_decay.epochs_drop == 8
+
+    def test_check_param(self):
+        lr_decay = LearningRateDecay()
+        assert lr_decay.check_param(1, "tester") == 1
+        assert lr_decay.check_param(0.5, "tester") == 0.5
+        with pytest.raises(ValueError) as e:
+            lr_decay.check_param(0, "tester")
+        assert "tester is out of allowed range (0, 1]: tester=0" in e.value.args[0]
+        with pytest.raises(ValueError) as e:
+            lr_decay.check_param(1.5, "tester")
+        assert "tester is out of allowed range (0, 1]: tester=1.5" in e.value.args[0]
+        assert lr_decay.check_param(1.5, "tester", upper=None) == 1.5
+        with pytest.raises(ValueError) as e:
+            lr_decay.check_param(0, "tester", upper=None)
+        assert "tester is out of allowed range (0, +inf): tester=0" in e.value.args[0]
+        assert lr_decay.check_param(0.5, "tester", lower=None) == 0.5
+        with pytest.raises(ValueError) as e:
+            lr_decay.check_param(0.5, "tester", lower=None, upper=0.2)
+        assert "tester is out of allowed range (-inf, 0.2]: tester=0.5" in e.value.args[0]
+        assert lr_decay.check_param(10, "tester", upper=None, lower=None) == 10
+
+    def test_on_epoch_begin(self):
+        lr_decay = LearningRateDecay(base_lr=0.02, drop=0.95, epochs_drop=2)
+        model = keras.Sequential()
+        model.add(keras.layers.Dense(1, input_dim=1))
+        model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2))
+        model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=5, callbacks=[lr_decay])
+        assert lr_decay.lr['lr'] == pytest.approx([0.02, 0.02, 0.02*0.95, 0.02*0.95, 0.02*0.95*0.95])
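+
+    def test_no_decay_when_drop_is_one(self):
+        # Illustrative check of the documented edge case: drop=1 means no decay,
+        # so the recorded rate should stay at base_lr for every epoch.
+        lr_decay = LearningRateDecay(base_lr=0.01, drop=1, epochs_drop=2)
+        model = keras.Sequential()
+        model.add(keras.layers.Dense(1, input_dim=1))
+        model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2))
+        model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=3, callbacks=[lr_decay])
+        assert lr_decay.lr['lr'] == [0.01, 0.01, 0.01]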