diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py index 33358e566ef80f28ee7740531b71d1a83abde115..e0f54282010e765fb3d8b0aca191a75c0b22fdf9 100644 --- a/mlair/model_modules/keras_extensions.py +++ b/mlair/model_modules/keras_extensions.py @@ -8,6 +8,7 @@ import math import pickle from typing import Union, List from typing_extensions import TypedDict +from time import time import numpy as np from keras import backend as K @@ -111,6 +112,20 @@ class LearningRateDecay(History): return K.get_value(self.model.optimizer.lr) +class EpoTimingCallback(Callback): + def __init__(self): + self.epo_timing = {'epo_timing': []} + self.logs = [] + self.starttime = None + super().__init__() + + def on_epoch_begin(self, epoch: int, logs=None): + self.starttime = time() + + def on_epoch_end(self, epoch: int, logs=None): + self.epo_timing["epo_timing"].append(time()-self.starttime) + + class ModelCheckpointAdvanced(ModelCheckpoint): """ Enhance the standard ModelCheckpoint class by additional saves of given callbacks. diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py index 8fae430fb48a28bdd8b21f8bfcfc7c569eb24f6c..83f4a2bd96314d6f8c53f8cc9407cbc12e7b9a16 100644 --- a/mlair/run_modules/model_setup.py +++ b/mlair/run_modules/model_setup.py @@ -12,7 +12,7 @@ import keras import pandas as pd import tensorflow as tf -from mlair.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler +from mlair.model_modules.keras_extensions import HistoryAdvanced, EpoTimingCallback, CallbackHandler from mlair.run_modules.run_environment import RunEnvironment from mlair.configuration import path_config @@ -119,11 +119,14 @@ class ModelSetup(RunEnvironment): """ lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None) hist = HistoryAdvanced() + epo_timing = EpoTimingCallback() self.data_store.set("hist", hist, scope="model") + self.data_store.set("epo_timing", epo_timing, scope="model") callbacks = CallbackHandler() if lr is not None: callbacks.add_callback(lr, self.callbacks_name % "lr", "lr") callbacks.add_callback(hist, self.callbacks_name % "hist", "hist") + callbacks.add_callback(epo_timing, self.callbacks_name % "epo_timing", "epo_timing") callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss', save_best_only=True, mode='auto') self.data_store.set("callbacks", callbacks, self.scope) diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py index 5f895b77d53d45bedc255bc7ff051f9d6a8d20a3..00e8eae1581453666d3ca11f48fcdaedf6a24ad0 100644 --- a/mlair/run_modules/training.py +++ b/mlair/run_modules/training.py @@ -166,7 +166,11 @@ class Training(RunEnvironment): lr = self.callbacks.get_callback_by_name("lr") except IndexError: lr = None - self.save_callbacks_as_json(history, lr) + try: + epo_timing = self.callbacks.get_callback_by_name("epo_timing") + except IndexError: + epo_timing = None + self.save_callbacks_as_json(history, lr, epo_timing) self.load_best_model(checkpoint.filepath) self.create_monitoring_plots(history, lr) @@ -190,7 +194,7 @@ class Training(RunEnvironment): except OSError: logging.info('no weights to reload...') - def save_callbacks_as_json(self, history: Callback, lr_sc: Callback) -> None: + def save_callbacks_as_json(self, history: Callback, lr_sc: Callback, epo_timing: Callback) -> None: """ Save callbacks (history, learning rate) of training. @@ -207,6 +211,9 @@ class Training(RunEnvironment): if lr_sc: with open(os.path.join(path, "history_lr.json"), "w") as f: json.dump(lr_sc.lr, f) + if epo_timing is not None: + with open(os.path.join(path, "epo_timing.json"), "w") as f: + json.dump(epo_timing.epo_timing, f) def create_monitoring_plots(self, history: Callback, lr_sc: Callback) -> None: """ diff --git a/test/test_run_modules/test_model_setup.py b/test/test_run_modules/test_model_setup.py index 8a7572148869537b505b2bd8e7f16cfdf7af1cdd..7cefd0e58f5b9b0787bafddffe1ad07e4851a068 100644 --- a/test/test_run_modules/test_model_setup.py +++ b/test/test_run_modules/test_model_setup.py @@ -80,7 +80,7 @@ class TestModelSetup: setup._set_callbacks() assert "general.model" in setup.data_store.search_name("callbacks") callbacks = setup.data_store.get("callbacks", "general.model") - assert len(callbacks.get_callbacks()) == 3 + assert len(callbacks.get_callbacks()) == 4 def test_set_callbacks_no_lr_decay(self, setup): setup.data_store.set("lr_decay", None, "general.model") @@ -88,7 +88,7 @@ class TestModelSetup: setup.checkpoint_name = "TestName" setup._set_callbacks() callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model") - assert len(callbacks.get_callbacks()) == 2 + assert len(callbacks.get_callbacks()) == 3 with pytest.raises(IndexError): callbacks.get_callback_by_name("lr_decay") diff --git a/test/test_run_modules/test_training.py b/test/test_run_modules/test_training.py index c2b58cbd2160bd958c76ba67649ef8caba09fcb4..ed0d8264326f5299403c47deb46859ccde4a85d7 100644 --- a/test/test_run_modules/test_training.py +++ b/test/test_run_modules/test_training.py @@ -13,7 +13,7 @@ from mlair.data_handler import DataCollection, KerasIterator, DefaultDataHandler from mlair.helpers import PyTestRegex from mlair.model_modules.flatten import flatten_tail from mlair.model_modules.inception_model import InceptionModelBase -from mlair.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler +from mlair.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler, EpoTimingCallback from mlair.run_modules.run_environment import RunEnvironment from mlair.run_modules.training import Training @@ -100,6 +100,12 @@ class TestTraining: h.model = mock.MagicMock() return h + @pytest.fixture + def epo_timing(self): + epo_timing = EpoTimingCallback() + epo_timing.epoch = [0, 1] + epo_timing.epo_timing = {"epo_timing": [0.1, 0.2]} + @pytest.fixture def path(self): return os.path.join(os.path.dirname(__file__), "TestExperiment") @@ -144,9 +150,11 @@ class TestTraining: def callbacks(self, path): clbk = CallbackHandler() hist = HistoryAdvanced() + epo_timing = EpoTimingCallback() clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist") lr = LearningRateDecay() clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr") + clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing") clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss', save_best_only=True) return clbk, hist, lr @@ -256,22 +264,22 @@ class TestTraining: assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting")) assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload...")) - def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, model_path): - init_without_run.save_callbacks_as_json(history, learning_rate) + def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path): + init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing) assert "history.json" in os.listdir(model_path) - def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, model_path): - init_without_run.save_callbacks_as_json(history, learning_rate) + def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, epo_timing, model_path): + init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing) assert "history_lr.json" in os.listdir(model_path) - def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, model_path): - init_without_run.save_callbacks_as_json(history, learning_rate) + def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, epo_timing, model_path): + init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing) with open(os.path.join(model_path, "history.json")) as jfile: hist = json.load(jfile) assert hist == history.history - def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, model_path): - init_without_run.save_callbacks_as_json(history, learning_rate) + def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, epo_timing, model_path): + init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing) with open(os.path.join(model_path, "history_lr.json")) as jfile: lr = json.load(jfile) assert lr == learning_rate.lr