"""Tests for the Keras callback extensions in ``src.model_modules.keras_extensions``.

Covered classes: ``HistoryAdvanced``, ``LearningRateDecay``,
``ModelCheckpointAdvanced`` and ``CallbackHandler``.
"""
import os
import pickle
from unittest import mock

import keras
import numpy as np
import pytest

from src.helpers import l_p_loss
# Star import supplies HistoryAdvanced, LearningRateDecay, ModelCheckpointAdvanced
# and CallbackHandler, which are used below but not defined in this file.
from src.model_modules.keras_extensions import *


class TestHistoryAdvanced:
    """Tests for ``HistoryAdvanced``."""

    def test_init(self):
        """A fresh instance has no data/model and empty epoch/history containers."""
        hist = HistoryAdvanced()
        assert hist.validation_data is None
        assert hist.model is None
        assert isinstance(hist.epoch, list) and len(hist.epoch) == 0
        assert isinstance(hist.history, dict) and len(hist.history.keys()) == 0

    def test_on_train_begin(self):
        """``on_train_begin`` must keep previously recorded epochs and history."""
        hist = HistoryAdvanced()
        hist.epoch = [1, 2, 3]
        hist.history = {"mse": [10, 7, 4]}
        hist.on_train_begin()
        assert hist.epoch == [1, 2, 3]
        assert hist.history == {"mse": [10, 7, 4]}


class TestLearningRateDecay:
    """Tests for ``LearningRateDecay``."""

    def test_init(self):
        """Default hyper-parameters and an empty learning-rate log."""
        lr_decay = LearningRateDecay()
        assert lr_decay.lr == {'lr': []}
        assert lr_decay.base_lr == 0.01
        assert lr_decay.drop == 0.96
        assert lr_decay.epochs_drop == 8

    def test_check_param(self):
        """``check_param`` accepts values in (lower, upper] and raises ``ValueError`` otherwise.

        Passing ``None`` for a bound disables it (reported as -inf / inf).
        """
        # Bypass __init__: only the check_param method is under test here.
        lr_decay = object.__new__(LearningRateDecay)
        assert lr_decay.check_param(1, "tester") == 1
        assert lr_decay.check_param(0.5, "tester") == 0.5
        with pytest.raises(ValueError) as e:
            lr_decay.check_param(0, "tester")
        assert "tester is out of allowed range (0, 1]: tester=0" in e.value.args[0]
        with pytest.raises(ValueError) as e:
            lr_decay.check_param(1.5, "tester")
        assert "tester is out of allowed range (0, 1]: tester=1.5" in e.value.args[0]
        assert lr_decay.check_param(1.5, "tester", upper=None) == 1.5
        with pytest.raises(ValueError) as e:
            lr_decay.check_param(0, "tester", upper=None)
        assert "tester is out of allowed range (0, inf): tester=0" in e.value.args[0]
        assert lr_decay.check_param(0.5, "tester", lower=None) == 0.5
        with pytest.raises(ValueError) as e:
            lr_decay.check_param(0.5, "tester", lower=None, upper=0.2)
        assert "tester is out of allowed range (-inf, 0.2]: tester=0.5" in e.value.args[0]
        # Pin the returned value rather than only its truthiness.
        assert lr_decay.check_param(10, "tester", upper=None, lower=None) == 10

    def test_on_epoch_begin(self):
        """With ``epochs_drop=2`` the rate is multiplied by ``drop`` every two epochs."""
        lr_decay = LearningRateDecay(base_lr=0.02, drop=0.95, epochs_drop=2)
        model = keras.Sequential()
        model.add(keras.layers.Dense(1, input_dim=1))
        model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2))
        model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=5, callbacks=[lr_decay])
        expected = [0.02, 0.02, 0.02 * 0.95, 0.02 * 0.95, 0.02 * 0.95 * 0.95]
        # Compare approximately: the callback may compute the power differently
        # (e.g. drop ** k), which need not be bit-identical to repeated products.
        assert lr_decay.lr['lr'] == pytest.approx(expected)


class TestModelCheckpointAdvanced:
    """Tests for ``ModelCheckpointAdvanced``."""

    @pytest.fixture()
    def callbacks(self):
        """Two callbacks whose pickle paths sit next to this test module."""
        callbacks_name = os.path.join(os.path.dirname(__file__), "callback_%s")
        return [{"callback": LearningRateDecay(), "path": callbacks_name % "lr"},
                {"callback": HistoryAdvanced(), "path": callbacks_name % "hist"}]

    @pytest.fixture
    def ckpt(self, callbacks):
        """Checkpoint monitoring ``val_loss`` that only saves on improvement."""
        ckpt_name = "ckpt.test"
        return ModelCheckpointAdvanced(filepath=ckpt_name, monitor='val_loss', save_best_only=True,
                                       callbacks=callbacks, verbose=1)

    def test_init(self, ckpt, callbacks):
        """Constructor stores callbacks/monitor/flags and starts with ``best == inf``."""
        assert ckpt.callbacks == callbacks
        assert ckpt.monitor == "val_loss"
        assert ckpt.save_best_only is True
        assert ckpt.best == np.inf

    def test_update_best(self, ckpt):
        """``update_best`` takes the last monitored value from a history object."""
        hist = HistoryAdvanced()
        hist.history["val_loss"] = [10, 6]
        ckpt.update_best(hist)
        assert ckpt.best == 6

    def test_update_callbacks(self, ckpt, callbacks):
        """``update_callbacks`` replaces the stored callbacks with the given one."""
        ckpt.update_callbacks(callbacks[0])
        assert ckpt.callbacks == [callbacks[0]]

    def test_on_epoch_end(self, ckpt):
        """Callback pickles are written only when the monitored loss strictly improves."""
        path = os.path.dirname(__file__)
        hist_file = os.path.join(path, "callback_hist")
        lr_file = os.path.join(path, "callback_lr")
        ckpt.set_model(mock.MagicMock())
        ckpt.best = 6
        try:
            ckpt.on_epoch_end(0, {"val_loss": 6})
            assert "callback_hist" not in os.listdir(path)
            ckpt.on_epoch_end(9, {"val_loss": 10})
            assert "callback_hist" not in os.listdir(path)
            ckpt.on_epoch_end(10, {"val_loss": 4})
            assert "callback_hist" in os.listdir(path)
            # os.remove also asserts both files were written (raises if missing).
            os.remove(hist_file)
            os.remove(lr_file)
            ckpt.save_best_only = False
            # NOTE(review): val_loss=3 also improves on best=4, so this step does not
            # isolate the save_best_only=False path; consider a non-improving value.
            ckpt.on_epoch_end(10, {"val_loss": 3})
            assert "callback_hist" in os.listdir(path)
            os.remove(hist_file)
            os.remove(lr_file)
        finally:
            # Do not leak pickles into the source tree when an assertion fails midway.
            for leftover in (hist_file, lr_file):
                if os.path.exists(leftover):
                    os.remove(leftover)


class TestCallbackHandler:
    """Tests for ``CallbackHandler``."""

    @pytest.fixture
    def clbk_handler(self):
        """An empty handler."""
        return CallbackHandler()

    @pytest.fixture
    def clbk_handler_with_dummies(self, clbk_handler):
        """Handler holding two string placeholders: one default-named, one named ``other_clbk``."""
        clbk_handler.add_callback("callback_new_instance", "this_path")
        clbk_handler.add_callback("callback_other", "otherpath", "other_clbk")
        return clbk_handler

    @pytest.fixture
    def callback_handler(self, clbk_handler):
        """Handler holding real ``hist`` and ``lr`` callbacks with pickle paths in the CWD."""
        clbk_handler.add_callback(HistoryAdvanced(), "callbacks_hist.pickle", "hist")
        clbk_handler.add_callback(LearningRateDecay(), "callbacks_lr.pickle", "lr")
        return clbk_handler

    @pytest.fixture
    def prepare_pickle_files(self):
        """Write pickled callbacks with three recorded epochs to the CWD; remove them afterwards."""
        hist = HistoryAdvanced()
        hist.epoch = [1, 2, 3]
        hist.history = {"val_loss": [10, 5, 4]}
        lr = LearningRateDecay()
        lr.epoch = [1, 2, 3]
        # Close (and therefore flush) the files before the test reads them back.
        with open("callbacks_hist.pickle", "wb") as f:
            pickle.dump(hist, f)
        with open("callbacks_lr.pickle", "wb") as f:
            pickle.dump(lr, f)
        yield
        os.remove("callbacks_hist.pickle")
        os.remove("callbacks_lr.pickle")

    def test_init(self, clbk_handler):
        """A fresh handler has no callbacks, no checkpoint and is editable."""
        assert len(clbk_handler._CallbackHandler__callbacks) == 0
        assert clbk_handler._checkpoint is None
        assert clbk_handler.editable is True

    def test_callbacks_set(self, clbk_handler):
        """Assigning ``(name, instance, path)`` to ``_callbacks`` appends a named entry."""
        clbk_handler._callbacks = ("default", "callback_instance", "callback_path")
        assert clbk_handler._CallbackHandler__callbacks == [
            {"name": "default", "default": "callback_instance", "path": "callback_path"}]
        clbk_handler._callbacks = ("another", "callback_instance2", "callback_path")
        assert clbk_handler._CallbackHandler__callbacks == [
            {"name": "default", "default": "callback_instance", "path": "callback_path"},
            {"name": "another", "another": "callback_instance2", "path": "callback_path"}]

    def test_callbacks_get(self, clbk_handler):
        """Reading ``_callbacks`` yields ``{"callback", "path"}`` dicts in insertion order."""
        clbk_handler._callbacks = ("default", "callback_instance", "callback_path")
        clbk_handler._callbacks = ("another", "callback_instance2", "callback_path2")
        assert clbk_handler._callbacks == [{"callback": "callback_instance", "path": "callback_path"},
                                           {"callback": "callback_instance2", "path": "callback_path2"}]

    def test_update_callback(self, clbk_handler_with_dummies):
        """``_update_callback`` replaces the instance at the given position."""
        clbk_handler_with_dummies._update_callback(0, "old_instance")
        assert clbk_handler_with_dummies.get_callbacks() == [{"callback": "old_instance", "path": "this_path"},
                                                             {"callback": "callback_other", "path": "otherpath"}]

    def test_add_callback(self, clbk_handler):
        """``add_callback`` defaults the name to ``callback`` and honours an explicit name."""
        clbk_handler.add_callback("callback_new_instance", "this_path")
        assert clbk_handler._CallbackHandler__callbacks == [
            {"name": "callback", "callback": "callback_new_instance", "path": "this_path"}]
        clbk_handler.add_callback("callback_other", "otherpath", "other_clbk")
        assert clbk_handler._CallbackHandler__callbacks == [
            {"name": "callback", "callback": "callback_new_instance", "path": "this_path"},
            {"name": "other_clbk", "other_clbk": "callback_other", "path": "otherpath"}]

    def test_get_callbacks_as_dict(self, clbk_handler_with_dummies):
        """``get_callbacks()`` defaults to the dict form."""
        clbk = clbk_handler_with_dummies
        assert clbk.get_callbacks() == [{"callback": "callback_new_instance", "path": "this_path"},
                                        {"callback": "callback_other", "path": "otherpath"}]
        assert clbk.get_callbacks() == clbk.get_callbacks(as_dict=True)

    def test_get_callbacks_no_dict(self, clbk_handler_with_dummies):
        """``get_callbacks(as_dict=False)`` returns only the instances."""
        assert clbk_handler_with_dummies.get_callbacks(as_dict=False) == ["callback_new_instance", "callback_other"]

    def test_get_callback_by_name(self, clbk_handler_with_dummies):
        """Lookup by explicit name works; the default name ``callback`` yields ``None``."""
        assert clbk_handler_with_dummies.get_callback_by_name("other_clbk") == "callback_other"
        assert clbk_handler_with_dummies.get_callback_by_name("callback") is None

    def test__get_callbacks(self, clbk_handler_with_dummies):
        """``_get_callbacks`` appends the checkpoint (with its filepath) when one is set."""
        clbk = clbk_handler_with_dummies
        assert clbk._get_callbacks() == [{"callback": "callback_new_instance", "path": "this_path"},
                                         {"callback": "callback_other", "path": "otherpath"}]
        ckpt = keras.callbacks.ModelCheckpoint("testFilePath")
        clbk._checkpoint = ckpt
        assert clbk._get_callbacks() == [{"callback": "callback_new_instance", "path": "this_path"},
                                         {"callback": "callback_other", "path": "otherpath"},
                                         {"callback": ckpt, "path": "testFilePath"}]

    def test_get_checkpoint(self, clbk_handler):
        """``get_checkpoint`` returns whatever is stored in ``_checkpoint``."""
        assert clbk_handler.get_checkpoint() is None
        clbk_handler._checkpoint = "testCKPT"
        assert clbk_handler.get_checkpoint() == "testCKPT"

    def test_create_model_checkpoint(self, callback_handler):
        """Creating a checkpoint freezes the handler and builds a ``ModelCheckpointAdvanced``."""
        callback_handler.create_model_checkpoint(filepath="tester_path", verbose=1)
        assert callback_handler.editable is False
        assert isinstance(callback_handler._checkpoint, ModelCheckpointAdvanced)
        assert callback_handler._checkpoint.filepath == "tester_path"
        assert callback_handler._checkpoint.verbose == 1
        assert callback_handler._checkpoint.monitor == "val_loss"

    def test_load_callbacks(self, callback_handler, prepare_pickle_files):
        """``load_callbacks`` replaces the callbacks with the pickled versions."""
        assert len(callback_handler.get_callback_by_name("hist").epoch) == 0
        assert len(callback_handler.get_callback_by_name("lr").epoch) == 0
        callback_handler.load_callbacks()
        assert len(callback_handler.get_callback_by_name("hist").epoch) == 3
        assert len(callback_handler.get_callback_by_name("lr").epoch) == 3

    def test_update_checkpoint(self, callback_handler, prepare_pickle_files):
        """``update_checkpoint`` sets ``best`` from the loaded history (last ``val_loss`` = 4)."""
        assert len(callback_handler.get_callback_by_name("hist").epoch) == 0
        assert len(callback_handler.get_callback_by_name("lr").epoch) == 0
        callback_handler.create_model_checkpoint(filepath="tester_path", verbose=1)
        callback_handler.load_callbacks()
        callback_handler.update_checkpoint()
        assert len(callback_handler.get_callback_by_name("hist").epoch) == 3
        assert len(callback_handler.get_callback_by_name("lr").epoch) == 3
        assert callback_handler._checkpoint.best == 4