Skip to content
Snippets Groups Projects
Select Git revision
  • c8d7c137c3e8e53d148e70a2042267668b73dd12
  • main default protected
  • feature-gp
3 results

utils.py

Blame
  • test_keras_extensions.py 11.44 KiB
    import os
    import pickle

    import keras
    import mock
    import numpy as np
    import pytest

    from mlair.model_modules.loss import l_p_loss
    from mlair.model_modules.keras_extensions import *
    
    
    class TestHistoryAdvanced:
        """Tests for HistoryAdvanced, a history callback whose recorded state is kept when training begins."""

        def test_init(self):
            """A fresh instance has no model, no validation data and nothing recorded yet."""
            fresh = HistoryAdvanced()
            assert fresh.validation_data is None
            assert fresh.model is None
            assert isinstance(fresh.epoch, list)
            assert not fresh.epoch
            assert isinstance(fresh.history, dict)
            assert not fresh.history

        def test_on_train_begin(self):
            """on_train_begin keeps previously recorded epochs and history instead of resetting them."""
            recorder = HistoryAdvanced()
            recorder.epoch, recorder.history = [1, 2, 3], {"mse": [10, 7, 4]}
            recorder.on_train_begin()
            assert (recorder.epoch, recorder.history) == ([1, 2, 3], {"mse": [10, 7, 4]})
    
    
    class TestLearningRateDecay:
        """Tests for the LearningRateDecay callback, which lowers the learning rate every `epochs_drop` epochs."""

        def test_init(self):
            """Default parameters and an empty learning-rate record."""
            lr_decay = LearningRateDecay()
            assert lr_decay.lr == {'lr': []}
            assert lr_decay.base_lr == 0.01
            assert lr_decay.drop == 0.96
            assert lr_decay.epochs_drop == 8

        def test_check_param(self):
            """check_param returns in-range values and raises ValueError with the range in the message otherwise."""
            # Called on an uninitialised instance: check_param must not depend on state set in __init__.
            lr_decay = object.__new__(LearningRateDecay)
            assert lr_decay.check_param(1, "tester") == 1
            assert lr_decay.check_param(0.5, "tester") == 0.5
            with pytest.raises(ValueError) as e:
                lr_decay.check_param(0, "tester")
            assert "tester is out of allowed range (0, 1]: tester=0" in e.value.args[0]
            with pytest.raises(ValueError) as e:
                lr_decay.check_param(1.5, "tester")
            assert "tester is out of allowed range (0, 1]: tester=1.5" in e.value.args[0]
            # upper=None removes the upper bound ...
            assert lr_decay.check_param(1.5, "tester", upper=None) == 1.5
            with pytest.raises(ValueError) as e:
                lr_decay.check_param(0, "tester", upper=None)
            assert "tester is out of allowed range (0, inf): tester=0" in e.value.args[0]
            # ... and lower=None removes the lower bound.
            assert lr_decay.check_param(0.5, "tester", lower=None) == 0.5
            with pytest.raises(ValueError) as e:
                lr_decay.check_param(0.5, "tester", lower=None, upper=0.2)
            assert "tester is out of allowed range (-inf, 0.2]: tester=0.5" in e.value.args[0]
            # Unbounded on both sides: the value must be returned unchanged, not merely be truthy.
            assert lr_decay.check_param(10, "tester", upper=None, lower=None) == 10

        def test_on_epoch_begin(self):
            """The recorded learning rate drops by `drop` every `epochs_drop` epochs during a real fit."""
            lr_decay = LearningRateDecay(base_lr=0.02, drop=0.95, epochs_drop=2)
            model = keras.Sequential()
            model.add(keras.layers.Dense(1, input_dim=1))
            model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2))
            model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=5, callbacks=[lr_decay], verbose=0)
            # Compare approximately: the schedule may compute the power differently from these literal products.
            expected = [0.02, 0.02, 0.02 * 0.95, 0.02 * 0.95, 0.02 * 0.95 * 0.95]
            assert lr_decay.lr['lr'] == pytest.approx(expected)
    
    
    class TestModelCheckpointAdvanced:
        """Tests for ModelCheckpointAdvanced, which also writes each registered callback to its own path."""

        @pytest.fixture()
        def callbacks(self, tmp_path):
            """Two callbacks whose files go into the per-test temporary directory.

            Using tmp_path instead of the test source folder means a failing test cannot leave stale callback
            files behind that would break the `not in os.listdir(...)` checks of later runs.
            """
            callbacks_name = os.path.join(str(tmp_path), "callback_%s")
            return [{"callback": LearningRateDecay(), "path": callbacks_name % "lr"},
                    {"callback": HistoryAdvanced(), "path": callbacks_name % "hist"}]

        @pytest.fixture
        def ckpt(self, callbacks):
            ckpt_name = "ckpt.test"
            return ModelCheckpointAdvanced(filepath=ckpt_name, monitor='val_loss', save_best_only=True,
                                           callbacks=callbacks, verbose=1)

        def test_init(self, ckpt, callbacks):
            assert ckpt.callbacks == callbacks
            assert ckpt.monitor == "val_loss"
            assert ckpt.save_best_only is True
            assert ckpt.best == np.inf

        def test_update_best(self, ckpt):
            """update_best takes the last monitored value from a history object."""
            hist = HistoryAdvanced()
            hist.history["val_loss"] = [10, 6]
            ckpt.update_best(hist)
            assert ckpt.best == 6

        def test_update_callbacks(self, ckpt, callbacks):
            ckpt.update_callbacks(callbacks[0])
            assert ckpt.callbacks == [callbacks[0]]

        def test_on_epoch_end(self, ckpt, tmp_path):
            # tmp_path is the same directory the `callbacks` fixture writes into (function-scoped fixture).
            path = str(tmp_path)
            # Mocked model, so any model-saving call made by the checkpoint is a no-op.
            ckpt.set_model(mock.MagicMock())
            ckpt.best = 6
            # With save_best_only, a value equal to (epoch 0) or worse than (epoch 9) the best writes nothing ...
            ckpt.on_epoch_end(0, {"val_loss": 6})
            assert "callback_hist" not in os.listdir(path)
            ckpt.on_epoch_end(9, {"val_loss": 10})
            assert "callback_hist" not in os.listdir(path)
            # ... while an improvement writes the callback files.
            ckpt.on_epoch_end(10, {"val_loss": 4})
            assert "callback_hist" in os.listdir(path)
            # Remove them so the next check cannot pass on leftovers from this step.
            os.remove(os.path.join(path, "callback_hist"))
            os.remove(os.path.join(path, "callback_lr"))
            ckpt.save_best_only = False
            ckpt.on_epoch_end(10, {"val_loss": 3})
            # Both callback files are written again; tmp_path is cleaned up by pytest.
            assert {"callback_hist", "callback_lr"} <= set(os.listdir(path))
    
    
    class TestCallbackHandler:
        """Tests for CallbackHandler, which registers named callbacks with file paths and an optional checkpoint."""

        @pytest.fixture
        def clbk_handler(self):
            return CallbackHandler()

        @pytest.fixture
        def clbk_handler_with_dummies(self, clbk_handler):
            """Handler with two string placeholders: one under the default name, one named "other_clbk"."""
            clbk_handler.add_callback("callback_new_instance", "this_path")
            clbk_handler.add_callback("callback_other", "otherpath", "other_clbk")
            return clbk_handler

        @pytest.fixture
        def callback_handler(self, clbk_handler):
            """Handler with real callbacks whose (relative) paths match the files from `prepare_pickle_files`."""
            clbk_handler.add_callback(HistoryAdvanced(), "callbacks_hist.pickle", "hist")
            clbk_handler.add_callback(LearningRateDecay(), "callbacks_lr.pickle", "lr")
            return clbk_handler

        @pytest.fixture
        def prepare_pickle_files(self, tmp_path, monkeypatch):
            """Write pickled callbacks with 3 recorded epochs to the relative paths used by `callback_handler`.

            The test runs in a temporary working directory. The relative pickle paths therefore neither
            overwrite nor delete same-named files in the caller's directory. Both tmp_path and the original
            working directory are restored or cleaned up by pytest.
            """
            monkeypatch.chdir(tmp_path)
            hist = HistoryAdvanced()
            hist.epoch = [1, 2, 3]
            hist.history = {"val_loss": [10, 5, 4]}
            lr = LearningRateDecay()
            lr.epoch = [1, 2, 3]
            # Close each file explicitly so the data is flushed before the test loads it.
            for obj, file_name in ((hist, "callbacks_hist.pickle"), (lr, "callbacks_lr.pickle")):
                with open(file_name, "wb") as f:
                    pickle.dump(obj, f)

        def test_init(self, clbk_handler):
            # `__callbacks` is name-mangled, hence the `_CallbackHandler__callbacks` access.
            assert len(clbk_handler._CallbackHandler__callbacks) == 0
            assert clbk_handler._checkpoint is None
            assert clbk_handler.editable is True

        def test_callbacks_set(self, clbk_handler):
            """Assigning a (name, callback, path) tuple appends an entry keyed by that name."""
            clbk_handler._callbacks = ("default", "callback_instance", "callback_path")
            assert clbk_handler._CallbackHandler__callbacks == [{"name": "default", "default": "callback_instance",
                                                                 "path": "callback_path"}]
            clbk_handler._callbacks = ("another", "callback_instance2", "callback_path")
            assert clbk_handler._CallbackHandler__callbacks == [{"name": "default", "default": "callback_instance",
                                                                 "path": "callback_path"},
                                                                {"name": "another", "another": "callback_instance2",
                                                                 "path": "callback_path"}]

        def test_callbacks_get(self, clbk_handler):
            """Reading `_callbacks` returns name-independent {"callback", "path"} dicts."""
            clbk_handler._callbacks = ("default", "callback_instance", "callback_path")
            clbk_handler._callbacks = ("another", "callback_instance2", "callback_path2")
            assert clbk_handler._callbacks == [{"callback": "callback_instance", "path": "callback_path"},
                                               {"callback": "callback_instance2", "path": "callback_path2"}]

        def test_update_callback(self, clbk_handler_with_dummies):
            clbk_handler_with_dummies._update_callback(0, "old_instance")
            assert clbk_handler_with_dummies.get_callbacks() == [{"callback": "old_instance", "path": "this_path"},
                                                                 {"callback": "callback_other", "path": "otherpath"}]

        def test_add_callback(self, clbk_handler):
            """Without an explicit name, add_callback registers the entry under the name "callback"."""
            clbk_handler.add_callback("callback_new_instance", "this_path")
            assert clbk_handler._CallbackHandler__callbacks == [{"name": "callback", "callback": "callback_new_instance",
                                                                 "path": "this_path"}]
            clbk_handler.add_callback("callback_other", "otherpath", "other_clbk")
            assert clbk_handler._CallbackHandler__callbacks == [{"name": "callback", "callback": "callback_new_instance",
                                                                 "path": "this_path"},
                                                                {"name": "other_clbk", "other_clbk": "callback_other",
                                                                 "path": "otherpath"}]

        def test_add_callback_raise(self, clbk_handler):
            """A non-editable handler rejects new callbacks with PermissionError."""
            clbk_handler.editable = False
            with pytest.raises(PermissionError) as einfo:
                clbk_handler.add_callback("callback_new_instance", "this_path")
            assert 'CallbackHandler is protected and cannot be edited.' in str(einfo.value)

        def test_get_callbacks_as_dict(self, clbk_handler_with_dummies):
            clbk = clbk_handler_with_dummies
            assert clbk.get_callbacks() == [{"callback": "callback_new_instance", "path": "this_path"},
                                            {"callback": "callback_other", "path": "otherpath"}]
            # as_dict=True is the default.
            assert clbk.get_callbacks() == clbk.get_callbacks(as_dict=True)

        def test_get_callbacks_no_dict(self, clbk_handler_with_dummies):
            assert clbk_handler_with_dummies.get_callbacks(as_dict=False) == ["callback_new_instance", "callback_other"]

        def test_get_callback_by_name(self, clbk_handler_with_dummies):
            assert clbk_handler_with_dummies.get_callback_by_name("other_clbk") == "callback_other"
            # The default name "callback" is not resolvable by name.
            assert clbk_handler_with_dummies.get_callback_by_name("callback") is None

        def test__get_callbacks(self, clbk_handler_with_dummies):
            """_get_callbacks appends the checkpoint (with its filepath) once one is set."""
            clbk = clbk_handler_with_dummies
            assert clbk._get_callbacks() == [{"callback": "callback_new_instance", "path": "this_path"},
                                             {"callback": "callback_other", "path": "otherpath"}]
            ckpt = keras.callbacks.ModelCheckpoint("testFilePath")
            clbk._checkpoint = ckpt
            assert clbk._get_callbacks() == [{"callback": "callback_new_instance", "path": "this_path"},
                                             {"callback": "callback_other", "path": "otherpath"},
                                             {"callback": ckpt, "path": "testFilePath"}]

        def test_get_checkpoint(self, clbk_handler):
            assert clbk_handler.get_checkpoint() is None
            clbk_handler._checkpoint = "testCKPT"
            assert clbk_handler.get_checkpoint() == "testCKPT"

        def test_create_model_checkpoint(self, callback_handler):
            """Creating a checkpoint locks the handler and monitors val_loss."""
            callback_handler.create_model_checkpoint(filepath="tester_path", verbose=1)
            assert callback_handler.editable is False
            assert isinstance(callback_handler._checkpoint, ModelCheckpointAdvanced)
            assert callback_handler._checkpoint.filepath == "tester_path"
            assert callback_handler._checkpoint.verbose == 1
            assert callback_handler._checkpoint.monitor == "val_loss"

        def test_load_callbacks(self, callback_handler, prepare_pickle_files):
            """load_callbacks replaces the registered callbacks with the pickled ones (3 epochs each)."""
            assert len(callback_handler.get_callback_by_name("hist").epoch) == 0
            assert len(callback_handler.get_callback_by_name("lr").epoch) == 0
            callback_handler.load_callbacks()
            assert len(callback_handler.get_callback_by_name("hist").epoch) == 3
            assert len(callback_handler.get_callback_by_name("lr").epoch) == 3

        def test_update_checkpoint(self, callback_handler, prepare_pickle_files):
            """After loading, update_checkpoint takes `best` from the loaded history's last val_loss (4)."""
            assert len(callback_handler.get_callback_by_name("hist").epoch) == 0
            assert len(callback_handler.get_callback_by_name("lr").epoch) == 0
            callback_handler.create_model_checkpoint(filepath="tester_path", verbose=1)
            callback_handler.load_callbacks()
            callback_handler.update_checkpoint()
            assert len(callback_handler.get_callback_by_name("hist").epoch) == 3
            assert len(callback_handler.get_callback_by_name("lr").epoch) == 3
            assert callback_handler._checkpoint.best == 4