Skip to content
Snippets Groups Projects
Commit 6bb881f5 authored by lukas leufen's avatar lukas leufen
Browse files

test implemented, /close #40

parents be542b99 e4244df0
Branches
Tags
2 merge requests!50release for v0.7.0,!41implemented test
Pipeline #29445 passed
...@@ -10,6 +10,8 @@ import numpy as np ...@@ -10,6 +10,8 @@ import numpy as np
from keras import backend as K from keras import backend as K
from keras.callbacks import History, ModelCheckpoint from keras.callbacks import History, ModelCheckpoint
from src import helpers
class HistoryAdvanced(History): class HistoryAdvanced(History):
""" """
...@@ -125,7 +127,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint): ...@@ -125,7 +127,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
Update all stored callback objects. The argument callbacks needs to follow the same convention like described Update all stored callback objects. The argument callbacks needs to follow the same convention like described
in the class description (list of dictionaries). Must be run before resuming a training process. in the class description (list of dictionaries). Must be run before resuming a training process.
""" """
self.callbacks = callbacks self.callbacks = helpers.to_list(callbacks)
def on_epoch_end(self, epoch, logs=None): def on_epoch_end(self, epoch, logs=None):
""" """
...@@ -139,12 +141,12 @@ class ModelCheckpointAdvanced(ModelCheckpoint): ...@@ -139,12 +141,12 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
if self.save_best_only: if self.save_best_only:
current = logs.get(self.monitor) current = logs.get(self.monitor)
if current == self.best: if current == self.best:
if self.verbose > 0: if self.verbose > 0: # pragma: no branch
print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
pickle.dump(callback["callback"], f) pickle.dump(callback["callback"], f)
else: else:
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
if self.verbose > 0: if self.verbose > 0: # pragma: no branch
print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
pickle.dump(callback["callback"], f) pickle.dump(callback["callback"], f)
import keras import keras
import numpy as np import numpy as np
import pytest import pytest
import mock
import os
from src.helpers import l_p_loss from src.helpers import l_p_loss
from src.model_modules.keras_extensions import * from src.model_modules.keras_extensions import *
...@@ -60,3 +62,51 @@ class TestLearningRateDecay: ...@@ -60,3 +62,51 @@ class TestLearningRateDecay:
model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2)) model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2))
model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=5, callbacks=[lr_decay]) model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=5, callbacks=[lr_decay])
assert lr_decay.lr['lr'] == [0.02, 0.02, 0.02 * 0.95, 0.02 * 0.95, 0.02 * 0.95 * 0.95] assert lr_decay.lr['lr'] == [0.02, 0.02, 0.02 * 0.95, 0.02 * 0.95, 0.02 * 0.95 * 0.95]
class TestModelCheckpointAdvanced:
@pytest.fixture()
def callbacks(self):
callbacks_name = os.path.join(os.path.dirname(__file__), "callback_%s")
return [{"callback": LearningRateDecay(), "path": callbacks_name % "lr"},
{"callback": HistoryAdvanced(), "path": callbacks_name % "hist"}]
@pytest.fixture
def ckpt(self, callbacks):
ckpt_name = "ckpt.test"
return ModelCheckpointAdvanced(filepath=ckpt_name, monitor='val_loss', save_best_only=True, callbacks=callbacks, verbose=1)
def test_init(self, ckpt, callbacks):
assert ckpt.callbacks == callbacks
assert ckpt.monitor == "val_loss"
assert ckpt.save_best_only is True
assert ckpt.best == np.inf
def test_update_best(self, ckpt):
hist = HistoryAdvanced()
hist.history["val_loss"] = [10, 6]
ckpt.update_best(hist)
assert ckpt.best == 6
def test_update_callbacks(self, ckpt, callbacks):
ckpt.update_callbacks(callbacks[0])
assert ckpt.callbacks == [callbacks[0]]
def test_on_epoch_end(self, ckpt):
path = os.path.dirname(__file__)
ckpt.set_model(mock.MagicMock())
ckpt.best = 6
ckpt.on_epoch_end(0, {"val_loss": 6})
assert "callback_hist" not in os.listdir(path)
ckpt.on_epoch_end(9, {"val_loss": 10})
assert "callback_hist" not in os.listdir(path)
ckpt.on_epoch_end(10, {"val_loss": 4})
assert "callback_hist" in os.listdir(path)
os.remove(os.path.join(path, "callback_hist"))
os.remove(os.path.join(path, "callback_lr"))
ckpt.save_best_only = False
ckpt.on_epoch_end(10, {"val_loss": 3})
assert "callback_hist" in os.listdir(path)
os.remove(os.path.join(path, "callback_hist"))
os.remove(os.path.join(path, "callback_lr"))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment