diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py
index bc868765019a3d3ce78586adae11d0c9b25ed72f..e0f54282010e765fb3d8b0aca191a75c0b22fdf9 100644
--- a/mlair/model_modules/keras_extensions.py
+++ b/mlair/model_modules/keras_extensions.py
@@ -112,17 +112,18 @@ class LearningRateDecay(History):
         return K.get_value(self.model.optimizer.lr)
 
 
-class TimingCallback(Callback):
+class EpoTimingCallback(Callback):
     def __init__(self):
+        self.epo_timing = {'epo_timing': []}
         self.logs = []
         self.starttime = None
         super().__init__()
 
-    def on_epoch_begin(self, logs={}):
+    def on_epoch_begin(self, epoch: int, logs=None):
         self.starttime = time()
 
-    def on_epoch_end(self, logs={}):
-        self.logs.append(time()-self.starttime)
+    def on_epoch_end(self, epoch: int, logs=None):
+        self.epo_timing["epo_timing"].append(time()-self.starttime)
 
 
 class ModelCheckpointAdvanced(ModelCheckpoint):