Skip to content
Snippets Groups Projects
Commit 6f29d408 authored by felix kleinert
Browse files

Merge branch 'felix_issue312-callback-track-epoch-runtime' into 'develop'

Resolve "Implement Callback to track epoch-runtime"

Closes #312

See merge request !307
parents e914db41 72c1d3d2
Branches
Tags
5 merge requests: !319 add all changes of dev into release v1.4.0 branch, !318 Resolve "release v1.4.0", !317 enabled window_lead_time=1, !307 Resolve "Implement Callback to track epoch-runtime", !259 Draft: Resolve "WRF-Datahandler should inherit from SingleStationDatahandler"
Pipeline #72461 passed with warnings
...@@ -8,6 +8,7 @@ import math ...@@ -8,6 +8,7 @@ import math
import pickle import pickle
from typing import Union, List from typing import Union, List
from typing_extensions import TypedDict from typing_extensions import TypedDict
from time import time
import numpy as np import numpy as np
from keras import backend as K from keras import backend as K
...@@ -111,6 +112,20 @@ class LearningRateDecay(History): ...@@ -111,6 +112,20 @@ class LearningRateDecay(History):
return K.get_value(self.model.optimizer.lr) return K.get_value(self.model.optimizer.lr)
class EpoTimingCallback(Callback):
    """
    Record how long each training epoch takes.

    The start time is taken in `on_epoch_begin` and the elapsed time (difference of two `time()` calls, i.e. seconds)
    is appended to `self.epo_timing["epo_timing"]` in `on_epoch_end`, one entry per finished epoch.
    """

    def __init__(self):
        """Set up the empty timing record; no epoch has been started yet."""
        # Dictionary layout {"epo_timing": [durations]} is kept as is, because the full dict is what gets stored.
        self.epo_timing = {'epo_timing': []}
        # Not used inside this class; retained so that the object's attributes stay unchanged.
        self.logs = []
        # Timestamp of the most recent `on_epoch_begin` call, None until the first epoch starts.
        self.starttime = None
        super().__init__()

    def on_epoch_begin(self, epoch: int, logs=None):
        """Remember the current time as start of the epoch."""
        self.starttime = time()

    def on_epoch_end(self, epoch: int, logs=None):
        """Append the runtime of the epoch that just finished to the timing record."""
        if self.starttime is None:
            # No start time was recorded, so there is no duration to compute (avoids `float - None` TypeError).
            return
        self.epo_timing["epo_timing"].append(time() - self.starttime)
class ModelCheckpointAdvanced(ModelCheckpoint): class ModelCheckpointAdvanced(ModelCheckpoint):
""" """
Enhance the standard ModelCheckpoint class by additional saves of given callbacks. Enhance the standard ModelCheckpoint class by additional saves of given callbacks.
......
...@@ -12,7 +12,7 @@ import keras ...@@ -12,7 +12,7 @@ import keras
import pandas as pd import pandas as pd
import tensorflow as tf import tensorflow as tf
from mlair.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler from mlair.model_modules.keras_extensions import HistoryAdvanced, EpoTimingCallback, CallbackHandler
from mlair.run_modules.run_environment import RunEnvironment from mlair.run_modules.run_environment import RunEnvironment
from mlair.configuration import path_config from mlair.configuration import path_config
...@@ -119,11 +119,14 @@ class ModelSetup(RunEnvironment): ...@@ -119,11 +119,14 @@ class ModelSetup(RunEnvironment):
""" """
lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None) lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None)
hist = HistoryAdvanced() hist = HistoryAdvanced()
epo_timing = EpoTimingCallback()
self.data_store.set("hist", hist, scope="model") self.data_store.set("hist", hist, scope="model")
self.data_store.set("epo_timing", epo_timing, scope="model")
callbacks = CallbackHandler() callbacks = CallbackHandler()
if lr is not None: if lr is not None:
callbacks.add_callback(lr, self.callbacks_name % "lr", "lr") callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
callbacks.add_callback(hist, self.callbacks_name % "hist", "hist") callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
callbacks.add_callback(epo_timing, self.callbacks_name % "epo_timing", "epo_timing")
callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss', callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
save_best_only=True, mode='auto') save_best_only=True, mode='auto')
self.data_store.set("callbacks", callbacks, self.scope) self.data_store.set("callbacks", callbacks, self.scope)
......
...@@ -166,7 +166,11 @@ class Training(RunEnvironment): ...@@ -166,7 +166,11 @@ class Training(RunEnvironment):
lr = self.callbacks.get_callback_by_name("lr") lr = self.callbacks.get_callback_by_name("lr")
except IndexError: except IndexError:
lr = None lr = None
self.save_callbacks_as_json(history, lr) try:
epo_timing = self.callbacks.get_callback_by_name("epo_timing")
except IndexError:
epo_timing = None
self.save_callbacks_as_json(history, lr, epo_timing)
self.load_best_model(checkpoint.filepath) self.load_best_model(checkpoint.filepath)
self.create_monitoring_plots(history, lr) self.create_monitoring_plots(history, lr)
...@@ -190,7 +194,7 @@ class Training(RunEnvironment): ...@@ -190,7 +194,7 @@ class Training(RunEnvironment):
except OSError: except OSError:
logging.info('no weights to reload...') logging.info('no weights to reload...')
def save_callbacks_as_json(self, history: Callback, lr_sc: Callback) -> None: def save_callbacks_as_json(self, history: Callback, lr_sc: Callback, epo_timing: Callback) -> None:
""" """
Save callbacks (history, learning rate) of training. Save callbacks (history, learning rate) of training.
...@@ -207,6 +211,9 @@ class Training(RunEnvironment): ...@@ -207,6 +211,9 @@ class Training(RunEnvironment):
if lr_sc: if lr_sc:
with open(os.path.join(path, "history_lr.json"), "w") as f: with open(os.path.join(path, "history_lr.json"), "w") as f:
json.dump(lr_sc.lr, f) json.dump(lr_sc.lr, f)
if epo_timing is not None:
with open(os.path.join(path, "epo_timing.json"), "w") as f:
json.dump(epo_timing.epo_timing, f)
def create_monitoring_plots(self, history: Callback, lr_sc: Callback) -> None: def create_monitoring_plots(self, history: Callback, lr_sc: Callback) -> None:
""" """
......
...@@ -80,7 +80,7 @@ class TestModelSetup: ...@@ -80,7 +80,7 @@ class TestModelSetup:
setup._set_callbacks() setup._set_callbacks()
assert "general.model" in setup.data_store.search_name("callbacks") assert "general.model" in setup.data_store.search_name("callbacks")
callbacks = setup.data_store.get("callbacks", "general.model") callbacks = setup.data_store.get("callbacks", "general.model")
assert len(callbacks.get_callbacks()) == 3 assert len(callbacks.get_callbacks()) == 4
def test_set_callbacks_no_lr_decay(self, setup): def test_set_callbacks_no_lr_decay(self, setup):
setup.data_store.set("lr_decay", None, "general.model") setup.data_store.set("lr_decay", None, "general.model")
...@@ -88,7 +88,7 @@ class TestModelSetup: ...@@ -88,7 +88,7 @@ class TestModelSetup:
setup.checkpoint_name = "TestName" setup.checkpoint_name = "TestName"
setup._set_callbacks() setup._set_callbacks()
callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model") callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model")
assert len(callbacks.get_callbacks()) == 2 assert len(callbacks.get_callbacks()) == 3
with pytest.raises(IndexError): with pytest.raises(IndexError):
callbacks.get_callback_by_name("lr_decay") callbacks.get_callback_by_name("lr_decay")
......
...@@ -13,7 +13,7 @@ from mlair.data_handler import DataCollection, KerasIterator, DefaultDataHandler ...@@ -13,7 +13,7 @@ from mlair.data_handler import DataCollection, KerasIterator, DefaultDataHandler
from mlair.helpers import PyTestRegex from mlair.helpers import PyTestRegex
from mlair.model_modules.flatten import flatten_tail from mlair.model_modules.flatten import flatten_tail
from mlair.model_modules.inception_model import InceptionModelBase from mlair.model_modules.inception_model import InceptionModelBase
from mlair.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler from mlair.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler, EpoTimingCallback
from mlair.run_modules.run_environment import RunEnvironment from mlair.run_modules.run_environment import RunEnvironment
from mlair.run_modules.training import Training from mlair.run_modules.training import Training
...@@ -100,6 +100,12 @@ class TestTraining: ...@@ -100,6 +100,12 @@ class TestTraining:
h.model = mock.MagicMock() h.model = mock.MagicMock()
return h return h
@pytest.fixture
def epo_timing(self):
    """Provide an EpoTimingCallback that is pre-filled with the timings of two epochs."""
    epo_timing = EpoTimingCallback()
    epo_timing.epoch = [0, 1]
    epo_timing.epo_timing = {"epo_timing": [0.1, 0.2]}
    # Without this return the fixture yields None and tests silently skip the epo_timing branch.
    return epo_timing
@pytest.fixture @pytest.fixture
def path(self): def path(self):
return os.path.join(os.path.dirname(__file__), "TestExperiment") return os.path.join(os.path.dirname(__file__), "TestExperiment")
...@@ -144,9 +150,11 @@ class TestTraining: ...@@ -144,9 +150,11 @@ class TestTraining:
def callbacks(self, path): def callbacks(self, path):
clbk = CallbackHandler() clbk = CallbackHandler()
hist = HistoryAdvanced() hist = HistoryAdvanced()
epo_timing = EpoTimingCallback()
clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist") clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
lr = LearningRateDecay() lr = LearningRateDecay()
clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr") clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing")
clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss', clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
save_best_only=True) save_best_only=True)
return clbk, hist, lr return clbk, hist, lr
...@@ -256,22 +264,22 @@ class TestTraining: ...@@ -256,22 +264,22 @@ class TestTraining:
assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting")) assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting"))
assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload...")) assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))
def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, model_path): def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate) init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
assert "history.json" in os.listdir(model_path) assert "history.json" in os.listdir(model_path)
def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, model_path): def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate) init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
assert "history_lr.json" in os.listdir(model_path) assert "history_lr.json" in os.listdir(model_path)
def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, model_path): def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, epo_timing, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate) init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
with open(os.path.join(model_path, "history.json")) as jfile: with open(os.path.join(model_path, "history.json")) as jfile:
hist = json.load(jfile) hist = json.load(jfile)
assert hist == history.history assert hist == history.history
def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, model_path): def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, epo_timing, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate) init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
with open(os.path.join(model_path, "history_lr.json")) as jfile: with open(os.path.join(model_path, "history_lr.json")) as jfile:
lr = json.load(jfile) lr = json.load(jfile)
assert lr == learning_rate.lr assert lr == learning_rate.lr
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment