From 160bf15f11ba4e711ca394d79b4d5dab786f9957 Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Mon, 5 Jul 2021 08:49:59 +0200
Subject: [PATCH] update timing callback

---
 mlair/model_modules/keras_extensions.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py
index bc868765..e0f54282 100644
--- a/mlair/model_modules/keras_extensions.py
+++ b/mlair/model_modules/keras_extensions.py
@@ -112,17 +112,18 @@ class LearningRateDecay(History):
         return K.get_value(self.model.optimizer.lr)
 
 
-class TimingCallback(Callback):
+class EpoTimingCallback(Callback):
     def __init__(self):
+        self.epo_timing = {'epo_timing': []}
         self.logs = []
         self.starttime = None
         super().__init__()
 
-    def on_epoch_begin(self, logs={}):
+    def on_epoch_begin(self, epoch: int, logs=None):
         self.starttime = time()
 
-    def on_epoch_end(self, logs={}):
-        self.logs.append(time()-self.starttime)
+    def on_epoch_end(self, epoch: int, logs=None):
+        self.epo_timing["epo_timing"].append(time()-self.starttime)
 
 
 class ModelCheckpointAdvanced(ModelCheckpoint):
-- 
GitLab