diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index 307fd63018df1e4825fa8fbee1fb07f6c8fef67e..54d150e0bb44aa1ade473f5a184652ad2c3444d8 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -70,11 +70,12 @@ class ModelSetup(RunEnvironment):
         Set all callbacks for the training phase. Add all callbacks with the .add_callback statement. Finally, the
         advanced model checkpoint is added.
         """
-        lr = self.data_store.get("lr_decay", scope="general.model")
+        lr = self.data_store.get_default("lr_decay", scope="general.model", default=None)
         hist = HistoryAdvanced()
         self.data_store.set("hist", hist, scope="general.model")
         callbacks = CallbackHandler()
-        callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
+        if lr:
+            callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
         callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
         callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
                                           save_best_only=True, mode='auto')
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index df60c4f2f8dff4a9acb82920ad3c1d203813033d..55b5c2964de3155a8d34cf87a646c0d53deebbef 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -111,7 +111,10 @@ class Training(RunEnvironment):
                                          callbacks=self.callbacks.get_callbacks(as_dict=False),
                                          initial_epoch=initial_epoch)
             history = hist
-        lr = self.callbacks.get_callback_by_name("lr")
+        try:
+            lr = self.callbacks.get_callback_by_name("lr")
+        except IndexError:
+            lr = None
         self.save_callbacks_as_json(history, lr)
         self.load_best_model(checkpoint.filepath)
         self.create_monitoring_plots(history, lr)
@@ -148,8 +151,9 @@ class Training(RunEnvironment):
         path = self.data_store.get("experiment_path", "general")
         with open(os.path.join(path, "history.json"), "w") as f:
             json.dump(history.history, f)
-        with open(os.path.join(path, "history_lr.json"), "w") as f:
-            json.dump(lr_sc.lr, f)
+        if lr_sc:
+            with open(os.path.join(path, "history_lr.json"), "w") as f:
+                json.dump(lr_sc.lr, f)
 
     def create_monitoring_plots(self, history: keras.callbacks.History, lr_sc: LearningRateDecay) -> None:
         """
@@ -174,4 +178,5 @@ class Training(RunEnvironment):
             PlotModelHistory(filename=filename, history=history, plot_metric="mse", main_branch=multiple_branches_used)
 
         # plot learning rate
-        PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)
+        if lr_sc:
+            PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)
diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py
index ade35a244601d138d22af6305e67b5aeae964680..9ff7494ff0540c9c96c1343b4f44fece08bfe4ce 100644
--- a/test/test_modules/test_model_setup.py
+++ b/test/test_modules/test_model_setup.py
@@ -4,6 +4,7 @@ import pytest
 
 from src.data_handling.data_generator import DataGenerator
 from src.datastore import EmptyScope
+from src.model_modules.keras_extensions import CallbackHandler
 from src.model_modules.model_class import AbstractModelClass
 from src.run_modules.model_setup import ModelSetup
 from src.run_modules.run_environment import RunEnvironment
@@ -61,6 +62,18 @@ class TestModelSetup:
         setup.checkpoint_name = "TestName"
         setup._set_callbacks()
         assert "general.modeltest" in setup.data_store.search_name("callbacks")
+        callbacks = setup.data_store.get("callbacks", "general.modeltest")
+        assert len(callbacks.get_callbacks()) == 3
+
+    def test_set_callbacks_no_lr_decay(self, setup):
+        setup.data_store.set("lr_decay", None, "general.model")
+        assert "general.modeltest" not in setup.data_store.search_name("callbacks")
+        setup.checkpoint_name = "TestName"
+        setup._set_callbacks()
+        callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.modeltest")
+        assert len(callbacks.get_callbacks()) == 2
+        with pytest.raises(IndexError):
+            callbacks.get_callback_by_name("lr_decay")
 
     def test_get_model_settings(self, setup_with_model):
         with pytest.raises(EmptyScope):
@@ -73,7 +86,7 @@ class TestModelSetup:
         setup_with_gen.build_model()
         assert isinstance(setup_with_gen.model, AbstractModelClass)
         expected = {"window_history_size", "window_lead_time", "channels", "dropout_rate", "regularizer", "initial_lr",
-                    "optimizer", "lr_decay", "epochs", "batch_size", "activation"}
+                    "optimizer", "epochs", "batch_size", "activation"}
         assert expected <= self.current_scope_as_set(setup_with_gen)
 
     def test_set_channels(self, setup_with_gen_tiny):