diff --git a/src/model_modules/keras_extensions.py b/src/model_modules/keras_extensions.py index cfb8638816e85ec8427f3df5d45d480d21fd929f..180e324602da25e1df8fb218c1d3bba180004ac8 100644 --- a/src/model_modules/keras_extensions.py +++ b/src/model_modules/keras_extensions.py @@ -150,3 +150,64 @@ class ModelCheckpointAdvanced(ModelCheckpoint): if self.verbose > 0: # pragma: no branch print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) pickle.dump(callback["callback"], f) + + +class CallbackHandler: + + def __init__(self): + self.__callbacks = [] + self._checkpoint = None + self.editable = True + + @property + def _callbacks(self): + return [{"callback": clbk[clbk["name"]], "path": clbk["path"]} for clbk in self.__callbacks] + + @_callbacks.setter + def _callbacks(self, value): + name, callback, callback_path = value + self.__callbacks.append({"name": name, name: callback, "path": callback_path}) + + def _update_callback(self, pos, value): + name = self.__callbacks[pos]["name"] + self.__callbacks[pos][name] = value + + def add_callback(self, callback, callback_path, name="callback"): + if self.editable: + self._callbacks = (name, callback, callback_path) + else: + raise PermissionError(f"{__class__.__name__} is protected and cannot be edited.") + + def get_callbacks(self, as_dict=True): + if as_dict: + return self._get_callbacks() + else: + return [clb["callback"] for clb in self._get_callbacks()] + + def get_callback_by_name(self, obj_name): + if obj_name != "callback": + return [clbk[clbk["name"]] for clbk in self.__callbacks if clbk["name"] == obj_name][0] + + def _get_callbacks(self): + clbks = self._callbacks + if self._checkpoint is not None: + clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath}] + return clbks + + def get_checkpoint(self): + if self._checkpoint is not None: + return self._checkpoint + + def create_model_checkpoint(self, **kwargs): + self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, **kwargs) + self.editable = False + + def load_callbacks(self): + for pos, callback in enumerate(self.__callbacks): + path = callback["path"] + clb = pickle.load(open(path, "rb")) + self._update_callback(pos, clb) + + def update_checkpoint(self, history_name="hist"): + self._checkpoint.update_callbacks(self._callbacks) + self._checkpoint.update_best(self.get_callback_by_name(history_name)) diff --git a/test/test_model_modules/test_keras_extensions.py b/test/test_model_modules/test_keras_extensions.py index 9141434d448f8d076836ef357279ae5815686767..17ab4f6d65c95a5a54c9d931818f889acadef532 100644 --- a/test/test_model_modules/test_keras_extensions.py +++ b/test/test_model_modules/test_keras_extensions.py @@ -110,3 +110,124 @@ class TestModelCheckpointAdvanced: assert "callback_hist" in os.listdir(path) os.remove(os.path.join(path, "callback_hist")) os.remove(os.path.join(path, "callback_lr")) + + +class TestCallbackHandler: + + @pytest.fixture + def clbk_handler(self): + return CallbackHandler() + + @pytest.fixture + def clbk_handler_with_dummies(self, clbk_handler): + clbk_handler.add_callback("callback_new_instance", "this_path") + clbk_handler.add_callback("callback_other", "otherpath", "other_clbk") + return clbk_handler + + @pytest.fixture + def callback_handler(self, clbk_handler): + clbk_handler.add_callback(HistoryAdvanced(), "callbacks_hist.pickle", "hist") + clbk_handler.add_callback(LearningRateDecay(), "callbacks_lr.pickle", "lr") + return clbk_handler + + @pytest.fixture + def prepare_pickle_files(self): + hist = HistoryAdvanced() + hist.epoch = [1, 2, 3] + hist.history = {"val_loss": [10, 5, 4]} + lr = LearningRateDecay() + lr.epoch = [1, 2, 3] + pickle.dump(hist, open("callbacks_hist.pickle", "wb")) + pickle.dump(lr, open("callbacks_lr.pickle", "wb")) + yield + os.remove("callbacks_hist.pickle") + os.remove("callbacks_lr.pickle") + + def test_init(self, clbk_handler): + assert len(clbk_handler._CallbackHandler__callbacks) == 0 + assert clbk_handler._checkpoint is None + assert clbk_handler.editable is True + + def test_callbacks_set(self, clbk_handler): + clbk_handler._callbacks = ("default", "callback_instance", "callback_path") + assert clbk_handler._CallbackHandler__callbacks == [{"name": "default", "default": "callback_instance", + "path": "callback_path"}] + clbk_handler._callbacks = ("another", "callback_instance2", "callback_path") + assert clbk_handler._CallbackHandler__callbacks == [{"name": "default", "default": "callback_instance", + "path": "callback_path"}, + {"name": "another", "another": "callback_instance2", + "path": "callback_path"}] + + def test_callbacks_get(self, clbk_handler): + clbk_handler._callbacks = ("default", "callback_instance", "callback_path") + clbk_handler._callbacks = ("another", "callback_instance2", "callback_path2") + assert clbk_handler._callbacks == [{"callback": "callback_instance", "path": "callback_path"}, + {"callback": "callback_instance2", "path": "callback_path2"}] + + def test_update_callback(self, clbk_handler_with_dummies): + clbk_handler_with_dummies._update_callback(0, "old_instance") + assert clbk_handler_with_dummies.get_callbacks() == [{"callback": "old_instance", "path": "this_path"}, + {"callback": "callback_other", "path": "otherpath"}] + + def test_add_callback(self, clbk_handler): + clbk_handler.add_callback("callback_new_instance", "this_path") + assert clbk_handler._CallbackHandler__callbacks == [{"name": "callback", "callback": "callback_new_instance", + "path": "this_path"}] + clbk_handler.add_callback("callback_other", "otherpath", "other_clbk") + assert clbk_handler._CallbackHandler__callbacks == [{"name": "callback", "callback": "callback_new_instance", + "path": "this_path"}, + {"name": "other_clbk", "other_clbk": "callback_other", + "path": "otherpath"}] + + def test_get_callbacks_as_dict(self, clbk_handler_with_dummies): + clbk = clbk_handler_with_dummies + assert clbk.get_callbacks() == [{"callback": "callback_new_instance", "path": "this_path"}, + {"callback": "callback_other", "path": "otherpath"}] + assert clbk.get_callbacks() == clbk.get_callbacks(as_dict=True) + + def test_get_callbacks_no_dict(self, clbk_handler_with_dummies): + assert clbk_handler_with_dummies.get_callbacks(as_dict=False) == ["callback_new_instance", "callback_other"] + + def test_get_callback_by_name(self, clbk_handler_with_dummies): + assert clbk_handler_with_dummies.get_callback_by_name("other_clbk") == "callback_other" + assert clbk_handler_with_dummies.get_callback_by_name("callback") is None + + def test__get_callbacks(self, clbk_handler_with_dummies): + clbk = clbk_handler_with_dummies + assert clbk._get_callbacks() == [{"callback": "callback_new_instance", "path": "this_path"}, + {"callback": "callback_other", "path": "otherpath"}] + ckpt = keras.callbacks.ModelCheckpoint("testFilePath") + clbk._checkpoint = ckpt + assert clbk._get_callbacks() == [{"callback": "callback_new_instance", "path": "this_path"}, + {"callback": "callback_other", "path": "otherpath"}, + {"callback": ckpt, "path": "testFilePath"}] + + def test_get_checkpoint(self, clbk_handler): + assert clbk_handler.get_checkpoint() is None + clbk_handler._checkpoint = "testCKPT" + assert clbk_handler.get_checkpoint() == "testCKPT" + + def test_create_model_checkpoint(self, callback_handler): + callback_handler.create_model_checkpoint(filepath="tester_path", verbose=1) + assert callback_handler.editable is False + assert isinstance(callback_handler._checkpoint, ModelCheckpointAdvanced) + assert callback_handler._checkpoint.filepath == "tester_path" + assert callback_handler._checkpoint.verbose == 1 + assert callback_handler._checkpoint.monitor == "val_loss" + + def test_load_callbacks(self, callback_handler, prepare_pickle_files): + assert len(callback_handler.get_callback_by_name("hist").epoch) == 0 + assert len(callback_handler.get_callback_by_name("lr").epoch) == 0 + callback_handler.load_callbacks() + assert len(callback_handler.get_callback_by_name("hist").epoch) == 3 + assert len(callback_handler.get_callback_by_name("lr").epoch) == 3 + + def test_update_checkpoint(self, callback_handler, prepare_pickle_files): + assert len(callback_handler.get_callback_by_name("hist").epoch) == 0 + assert len(callback_handler.get_callback_by_name("lr").epoch) == 0 + callback_handler.create_model_checkpoint(filepath="tester_path", verbose=1) + callback_handler.load_callbacks() + callback_handler.update_checkpoint() + assert len(callback_handler.get_callback_by_name("hist").epoch) == 3 + assert len(callback_handler.get_callback_by_name("lr").epoch) == 3 + assert callback_handler._checkpoint.best == 4