Skip to content
Snippets Groups Projects
Commit 6bb881f5 authored by lukas leufen's avatar lukas leufen
Browse files

test implemented, /close #40

parents be542b99 e4244df0
No related branches found
No related tags found
2 merge requests: !50 release for v0.7.0, !41 implemented test
Pipeline #29445 passed
......@@ -10,6 +10,8 @@ import numpy as np
from keras import backend as K
from keras.callbacks import History, ModelCheckpoint
from src import helpers
class HistoryAdvanced(History):
"""
......@@ -125,7 +127,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
Update all stored callback objects. The argument callbacks needs to follow the same convention like described
in the class description (list of dictionaries). Must be run before resuming a training process.
"""
self.callbacks = callbacks
self.callbacks = helpers.to_list(callbacks)
def on_epoch_end(self, epoch, logs=None):
"""
......@@ -139,12 +141,12 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
if self.save_best_only:
current = logs.get(self.monitor)
if current == self.best:
if self.verbose > 0:
if self.verbose > 0: # pragma: no branch
print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
with open(file_path, "wb") as f:
pickle.dump(callback["callback"], f)
else:
with open(file_path, "wb") as f:
if self.verbose > 0:
if self.verbose > 0: # pragma: no branch
print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
pickle.dump(callback["callback"], f)
import keras
import numpy as np
import pytest
import mock
import os
from src.helpers import l_p_loss
from src.model_modules.keras_extensions import *
......@@ -60,3 +62,51 @@ class TestLearningRateDecay:
model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2))
model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=5, callbacks=[lr_decay])
assert lr_decay.lr['lr'] == [0.02, 0.02, 0.02 * 0.95, 0.02 * 0.95, 0.02 * 0.95 * 0.95]
class TestModelCheckpointAdvanced:
    """Tests for ModelCheckpointAdvanced: storing of callbacks, tracking of the
    best monitored value, and pickling of callbacks on epoch end."""

    @pytest.fixture()
    def callbacks(self):
        """Two callback/path pairs; paths point next to this test file."""
        path_template = os.path.join(os.path.dirname(__file__), "callback_%s")
        return [{"callback": LearningRateDecay(), "path": path_template % "lr"},
                {"callback": HistoryAdvanced(), "path": path_template % "hist"}]

    @pytest.fixture
    def ckpt(self, callbacks):
        """A verbose checkpoint monitoring val_loss with save_best_only enabled."""
        return ModelCheckpointAdvanced(filepath="ckpt.test", monitor='val_loss',
                                       save_best_only=True, callbacks=callbacks,
                                       verbose=1)

    def test_init(self, ckpt, callbacks):
        """Constructor arguments must be stored unchanged."""
        assert ckpt.monitor == "val_loss"
        assert ckpt.save_best_only is True
        assert ckpt.callbacks == callbacks
        assert ckpt.best == np.inf

    def test_update_best(self, ckpt):
        """update_best adopts the last recorded value of the monitored metric."""
        history = HistoryAdvanced()
        history.history["val_loss"] = [10, 6]
        ckpt.update_best(history)
        assert ckpt.best == 6

    def test_update_callbacks(self, ckpt, callbacks):
        """A single callback dict is wrapped into a one-element list."""
        ckpt.update_callbacks(callbacks[0])
        assert ckpt.callbacks == [callbacks[0]]

    def test_on_epoch_end(self, ckpt):
        """Callback pickles are expected on disk only after an improved score,
        or on every epoch once save_best_only is switched off."""
        test_dir = os.path.dirname(__file__)
        ckpt.set_model(mock.MagicMock())
        ckpt.best = 6
        # score equal to the current best: no file expected
        ckpt.on_epoch_end(0, {"val_loss": 6})
        assert "callback_hist" not in os.listdir(test_dir)
        # worse score: no file expected
        ckpt.on_epoch_end(9, {"val_loss": 10})
        assert "callback_hist" not in os.listdir(test_dir)
        # improved score: callbacks must be pickled
        ckpt.on_epoch_end(10, {"val_loss": 4})
        assert "callback_hist" in os.listdir(test_dir)
        os.remove(os.path.join(test_dir, "callback_hist"))
        os.remove(os.path.join(test_dir, "callback_lr"))
        # with save_best_only disabled, every epoch writes the callbacks
        ckpt.save_best_only = False
        ckpt.on_epoch_end(10, {"val_loss": 3})
        assert "callback_hist" in os.listdir(test_dir)
        os.remove(os.path.join(test_dir, "callback_hist"))
        os.remove(os.path.join(test_dir, "callback_lr"))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment