Skip to content
Snippets Groups Projects
Commit 72c1d3d2 authored by Felix Kleinert's avatar Felix Kleinert
Browse files

include epotimingcallback into workflow

parent 160bf15f
No related branches found
No related tags found
6 merge requests: !319 add all changes of dev into release v1.4.0 branch, !318 Resolve "release v1.4.0", !317 enabled window_lead_time=1, !308 Felix issue312 callback track epoch runtime, !307 Resolve "Implement Callback to track epoch-runtime", !259 Draft: Resolve "WRF-Datahandler should inherit from SingleStationDatahandler"
Pipeline #72179 passed
......@@ -12,7 +12,7 @@ import keras
import pandas as pd
import tensorflow as tf
from mlair.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
from mlair.model_modules.keras_extensions import HistoryAdvanced, EpoTimingCallback, CallbackHandler
from mlair.run_modules.run_environment import RunEnvironment
from mlair.configuration import path_config
......@@ -119,11 +119,14 @@ class ModelSetup(RunEnvironment):
"""
lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None)
hist = HistoryAdvanced()
epo_timing = EpoTimingCallback()
self.data_store.set("hist", hist, scope="model")
self.data_store.set("epo_timing", epo_timing, scope="model")
callbacks = CallbackHandler()
if lr is not None:
callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
callbacks.add_callback(epo_timing, self.callbacks_name % "epo_timing", "epo_timing")
callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
save_best_only=True, mode='auto')
self.data_store.set("callbacks", callbacks, self.scope)
......
......@@ -166,7 +166,11 @@ class Training(RunEnvironment):
lr = self.callbacks.get_callback_by_name("lr")
except IndexError:
lr = None
self.save_callbacks_as_json(history, lr)
try:
epo_timing = self.callbacks.get_callback_by_name("epo_timing")
except IndexError:
epo_timing = None
self.save_callbacks_as_json(history, lr, epo_timing)
self.load_best_model(checkpoint.filepath)
self.create_monitoring_plots(history, lr)
......@@ -190,7 +194,7 @@ class Training(RunEnvironment):
except OSError:
logging.info('no weights to reload...')
def save_callbacks_as_json(self, history: Callback, lr_sc: Callback) -> None:
def save_callbacks_as_json(self, history: Callback, lr_sc: Callback, epo_timing: Callback) -> None:
"""
Save callbacks (history, learning rate, epoch timing) of training.
......@@ -207,6 +211,9 @@ class Training(RunEnvironment):
if lr_sc:
with open(os.path.join(path, "history_lr.json"), "w") as f:
json.dump(lr_sc.lr, f)
if epo_timing is not None:
with open(os.path.join(path, "epo_timing.json"), "w") as f:
json.dump(epo_timing.epo_timing, f)
def create_monitoring_plots(self, history: Callback, lr_sc: Callback) -> None:
"""
......
......@@ -80,7 +80,7 @@ class TestModelSetup:
setup._set_callbacks()
assert "general.model" in setup.data_store.search_name("callbacks")
callbacks = setup.data_store.get("callbacks", "general.model")
assert len(callbacks.get_callbacks()) == 3
assert len(callbacks.get_callbacks()) == 4
def test_set_callbacks_no_lr_decay(self, setup):
setup.data_store.set("lr_decay", None, "general.model")
......@@ -88,7 +88,7 @@ class TestModelSetup:
setup.checkpoint_name = "TestName"
setup._set_callbacks()
callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model")
assert len(callbacks.get_callbacks()) == 2
assert len(callbacks.get_callbacks()) == 3
with pytest.raises(IndexError):
callbacks.get_callback_by_name("lr_decay")
......
......@@ -13,7 +13,7 @@ from mlair.data_handler import DataCollection, KerasIterator, DefaultDataHandler
from mlair.helpers import PyTestRegex
from mlair.model_modules.flatten import flatten_tail
from mlair.model_modules.inception_model import InceptionModelBase
from mlair.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler
from mlair.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler, EpoTimingCallback
from mlair.run_modules.run_environment import RunEnvironment
from mlair.run_modules.training import Training
......@@ -100,6 +100,12 @@ class TestTraining:
h.model = mock.MagicMock()
return h
@pytest.fixture
def epo_timing(self):
    """
    Provide an EpoTimingCallback pre-filled with timing data for two epochs.

    :return: callback whose ``epoch`` is ``[0, 1]`` and whose ``epo_timing`` attribute holds
        ``{"epo_timing": [0.1, 0.2]}``
    """
    callback = EpoTimingCallback()
    callback.epoch = [0, 1]
    callback.epo_timing = {"epo_timing": [0.1, 0.2]}
    # Return the configured callback: without an explicit return the fixture would resolve to None
    # and tests requesting it would never receive the prepared object.
    return callback
@pytest.fixture
def path(self):
return os.path.join(os.path.dirname(__file__), "TestExperiment")
......@@ -144,9 +150,11 @@ class TestTraining:
def callbacks(self, path):
clbk = CallbackHandler()
hist = HistoryAdvanced()
epo_timing = EpoTimingCallback()
clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
lr = LearningRateDecay()
clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing")
clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
save_best_only=True)
return clbk, hist, lr
......@@ -256,22 +264,22 @@ class TestTraining:
assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting"))
assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))
def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate)
def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
assert "history.json" in os.listdir(model_path)
def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate)
def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
assert "history_lr.json" in os.listdir(model_path)
def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate)
def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, epo_timing, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
with open(os.path.join(model_path, "history.json")) as jfile:
hist = json.load(jfile)
assert hist == history.history
def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate)
def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, epo_timing, model_path):
init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
with open(os.path.join(model_path, "history_lr.json")) as jfile:
lr = json.load(jfile)
assert lr == learning_rate.lr
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment