Skip to content
Snippets Groups Projects
Commit aee2fe01 authored by lukas leufen's avatar lukas leufen
Browse files

implemented CallbackHandler

parent 6bb881f5
No related branches found
No related tags found
2 merge requests!50release for v0.7.0,!42implemented CallbackHandler
Pipeline #29481 passed
...@@ -150,3 +150,64 @@ class ModelCheckpointAdvanced(ModelCheckpoint): ...@@ -150,3 +150,64 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
if self.verbose > 0: # pragma: no branch if self.verbose > 0: # pragma: no branch
print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
pickle.dump(callback["callback"], f) pickle.dump(callback["callback"], f)
class CallbackHandler:
def __init__(self):
self.__callbacks = []
self._checkpoint = None
self.editable = True
@property
def _callbacks(self):
return [{"callback": clbk[clbk["name"]], "path": clbk["path"]} for clbk in self.__callbacks]
@_callbacks.setter
def _callbacks(self, value):
name, callback, callback_path = value
self.__callbacks.append({"name": name, name: callback, "path": callback_path})
def _update_callback(self, pos, value):
name = self.__callbacks[pos]["name"]
self.__callbacks[pos][name] = value
def add_callback(self, callback, callback_path, name="callback"):
if self.editable:
self._callbacks = (name, callback, callback_path)
else:
raise PermissionError(f"{__class__.__name__} is protected and cannot be edited.")
def get_callbacks(self, as_dict=True):
if as_dict:
return self._get_callbacks()
else:
return [clb["callback"] for clb in self._get_callbacks()]
def get_callback_by_name(self, obj_name):
if obj_name != "callback":
return [clbk[clbk["name"]] for clbk in self.__callbacks if clbk["name"] == obj_name][0]
def _get_callbacks(self):
clbks = self._callbacks
if self._checkpoint is not None:
clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath}]
return clbks
def get_checkpoint(self):
if self._checkpoint is not None:
return self._checkpoint
def create_model_checkpoint(self, **kwargs):
self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, **kwargs)
self.editable = False
def load_callbacks(self):
for pos, callback in enumerate(self.__callbacks):
path = callback["path"]
clb = pickle.load(open(path, "rb"))
self._update_callback(pos, clb)
def update_checkpoint(self, history_name="hist"):
self._checkpoint.update_callbacks(self._callbacks)
self._checkpoint.update_best(self.get_callback_by_name(history_name))
...@@ -110,3 +110,124 @@ class TestModelCheckpointAdvanced: ...@@ -110,3 +110,124 @@ class TestModelCheckpointAdvanced:
assert "callback_hist" in os.listdir(path) assert "callback_hist" in os.listdir(path)
os.remove(os.path.join(path, "callback_hist")) os.remove(os.path.join(path, "callback_hist"))
os.remove(os.path.join(path, "callback_lr")) os.remove(os.path.join(path, "callback_lr"))
class TestCallbackHandler:
@pytest.fixture
def clbk_handler(self):
return CallbackHandler()
@pytest.fixture
def clbk_handler_with_dummies(self, clbk_handler):
clbk_handler.add_callback("callback_new_instance", "this_path")
clbk_handler.add_callback("callback_other", "otherpath", "other_clbk")
return clbk_handler
@pytest.fixture
def callback_handler(self, clbk_handler):
clbk_handler.add_callback(HistoryAdvanced(), "callbacks_hist.pickle", "hist")
clbk_handler.add_callback(LearningRateDecay(), "callbacks_lr.pickle", "lr")
return clbk_handler
@pytest.fixture
def prepare_pickle_files(self):
hist = HistoryAdvanced()
hist.epoch = [1, 2, 3]
hist.history = {"val_loss": [10, 5, 4]}
lr = LearningRateDecay()
lr.epoch = [1, 2, 3]
pickle.dump(hist, open("callbacks_hist.pickle", "wb"))
pickle.dump(lr, open("callbacks_lr.pickle", "wb"))
yield
os.remove("callbacks_hist.pickle")
os.remove("callbacks_lr.pickle")
def test_init(self, clbk_handler):
assert len(clbk_handler._CallbackHandler__callbacks) == 0
assert clbk_handler._checkpoint is None
assert clbk_handler.editable is True
def test_callbacks_set(self, clbk_handler):
clbk_handler._callbacks = ("default", "callback_instance", "callback_path")
assert clbk_handler._CallbackHandler__callbacks == [{"name": "default", "default": "callback_instance",
"path": "callback_path"}]
clbk_handler._callbacks = ("another", "callback_instance2", "callback_path")
assert clbk_handler._CallbackHandler__callbacks == [{"name": "default", "default": "callback_instance",
"path": "callback_path"},
{"name": "another", "another": "callback_instance2",
"path": "callback_path"}]
def test_callbacks_get(self, clbk_handler):
clbk_handler._callbacks = ("default", "callback_instance", "callback_path")
clbk_handler._callbacks = ("another", "callback_instance2", "callback_path2")
assert clbk_handler._callbacks == [{"callback": "callback_instance", "path": "callback_path"},
{"callback": "callback_instance2", "path": "callback_path2"}]
def test_update_callback(self, clbk_handler_with_dummies):
clbk_handler_with_dummies._update_callback(0, "old_instance")
assert clbk_handler_with_dummies.get_callbacks() == [{"callback": "old_instance", "path": "this_path"},
{"callback": "callback_other", "path": "otherpath"}]
def test_add_callback(self, clbk_handler):
clbk_handler.add_callback("callback_new_instance", "this_path")
assert clbk_handler._CallbackHandler__callbacks == [{"name": "callback", "callback": "callback_new_instance",
"path": "this_path"}]
clbk_handler.add_callback("callback_other", "otherpath", "other_clbk")
assert clbk_handler._CallbackHandler__callbacks == [{"name": "callback", "callback": "callback_new_instance",
"path": "this_path"},
{"name": "other_clbk", "other_clbk": "callback_other",
"path": "otherpath"}]
def test_get_callbacks_as_dict(self, clbk_handler_with_dummies):
clbk = clbk_handler_with_dummies
assert clbk.get_callbacks() == [{"callback": "callback_new_instance", "path": "this_path"},
{"callback": "callback_other", "path": "otherpath"}]
assert clbk.get_callbacks() == clbk.get_callbacks(as_dict=True)
def test_get_callbacks_no_dict(self, clbk_handler_with_dummies):
assert clbk_handler_with_dummies.get_callbacks(as_dict=False) == ["callback_new_instance", "callback_other"]
def test_get_callback_by_name(self, clbk_handler_with_dummies):
assert clbk_handler_with_dummies.get_callback_by_name("other_clbk") == "callback_other"
assert clbk_handler_with_dummies.get_callback_by_name("callback") is None
def test__get_callbacks(self, clbk_handler_with_dummies):
clbk = clbk_handler_with_dummies
assert clbk._get_callbacks() == [{"callback": "callback_new_instance", "path": "this_path"},
{"callback": "callback_other", "path": "otherpath"}]
ckpt = keras.callbacks.ModelCheckpoint("testFilePath")
clbk._checkpoint = ckpt
assert clbk._get_callbacks() == [{"callback": "callback_new_instance", "path": "this_path"},
{"callback": "callback_other", "path": "otherpath"},
{"callback": ckpt, "path": "testFilePath"}]
def test_get_checkpoint(self, clbk_handler):
assert clbk_handler.get_checkpoint() is None
clbk_handler._checkpoint = "testCKPT"
assert clbk_handler.get_checkpoint() == "testCKPT"
def test_create_model_checkpoint(self, callback_handler):
callback_handler.create_model_checkpoint(filepath="tester_path", verbose=1)
assert callback_handler.editable is False
assert isinstance(callback_handler._checkpoint, ModelCheckpointAdvanced)
assert callback_handler._checkpoint.filepath == "tester_path"
assert callback_handler._checkpoint.verbose == 1
assert callback_handler._checkpoint.monitor == "val_loss"
def test_load_callbacks(self, callback_handler, prepare_pickle_files):
assert len(callback_handler.get_callback_by_name("hist").epoch) == 0
assert len(callback_handler.get_callback_by_name("lr").epoch) == 0
callback_handler.load_callbacks()
assert len(callback_handler.get_callback_by_name("hist").epoch) == 3
assert len(callback_handler.get_callback_by_name("lr").epoch) == 3
def test_update_checkpoint(self, callback_handler, prepare_pickle_files):
assert len(callback_handler.get_callback_by_name("hist").epoch) == 0
assert len(callback_handler.get_callback_by_name("lr").epoch) == 0
callback_handler.create_model_checkpoint(filepath="tester_path", verbose=1)
callback_handler.load_callbacks()
callback_handler.update_checkpoint()
assert len(callback_handler.get_callback_by_name("hist").epoch) == 3
assert len(callback_handler.get_callback_by_name("lr").epoch) == 3
assert callback_handler._checkpoint.best == 4
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment