diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py index bc868765019a3d3ce78586adae11d0c9b25ed72f..e0f54282010e765fb3d8b0aca191a75c0b22fdf9 100644 --- a/mlair/model_modules/keras_extensions.py +++ b/mlair/model_modules/keras_extensions.py @@ -112,17 +112,18 @@ class LearningRateDecay(History): return K.get_value(self.model.optimizer.lr) -class TimingCallback(Callback): +class EpoTimingCallback(Callback): def __init__(self): + self.epo_timing = {'epo_timing': []} self.logs = [] self.starttime = None super().__init__() - def on_epoch_begin(self, logs={}): + def on_epoch_begin(self, epoch: int, logs=None): self.starttime = time() - def on_epoch_end(self, logs={}): - self.logs.append(time()-self.starttime) + def on_epoch_end(self, epoch: int, logs=None): + self.epo_timing["epo_timing"].append(time()-self.starttime) class ModelCheckpointAdvanced(ModelCheckpoint):