Skip to content
Snippets Groups Projects
Commit 85d067ff authored by leufen1's avatar leufen1
Browse files

found magic command to add custom objects to entire keras world

parent 5cb07ed5
Branches
No related tags found
4 merge requests!253include current develop,!252Resolve "release v1.3.0",!245update #275 branch,!242Resolve "BUG: loading of custom objects not working"
Pipeline #60057 passed
...@@ -9,12 +9,9 @@ import pickle ...@@ -9,12 +9,9 @@ import pickle
from typing import Union, List from typing import Union, List
from typing_extensions import TypedDict from typing_extensions import TypedDict
from mlair.helpers import TimeTracking
import numpy as np import numpy as np
from keras import backend as K from keras import backend as K
from keras.callbacks import History, ModelCheckpoint, Callback from keras.callbacks import History, ModelCheckpoint, Callback
import keras
from mlair import helpers from mlair import helpers
...@@ -150,7 +147,6 @@ class ModelCheckpointAdvanced(ModelCheckpoint): ...@@ -150,7 +147,6 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Initialise ModelCheckpointAdvanced and set callbacks attribute.""" """Initialise ModelCheckpointAdvanced and set callbacks attribute."""
self.callbacks = kwargs.pop("callbacks") self.callbacks = kwargs.pop("callbacks")
self.custom_objects = kwargs.pop("custom_objects")
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def update_best(self, hist): def update_best(self, hist):
...@@ -184,23 +180,12 @@ class ModelCheckpointAdvanced(ModelCheckpoint): ...@@ -184,23 +180,12 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
if self.epochs_since_last_save == 0 and epoch != 0: if self.epochs_since_last_save == 0 and epoch != 0:
if self.save_best_only: if self.save_best_only:
current = logs.get(self.monitor) current = logs.get(self.monitor)
model_save = None
if hasattr(callback["callback"], "model"):
# ToDo: store model in cache?
callback["callback"].model.save(
callback["path"].rsplit(".", maxsplit=1)[0] + "model_save_tmp.h5")
callback["callback"].model = None
if current == self.best: if current == self.best:
if self.verbose > 0: # pragma: no branch if self.verbose > 0: # pragma: no branch
print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
# ToDo: create "save" method # ToDo: create "save" method
pickle.dump(callback["callback"], f) pickle.dump(callback["callback"], f)
if callback["callback"].model is None:
with TimeTracking("load_model"):
callback["callback"].model = keras.models.load_model(
callback["path"].rsplit(".", maxsplit=1)[0] + "model_save_tmp.h5",
custom_objects=callback["custom_objects"])
else: else:
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
if self.verbose > 0: # pragma: no branch if self.verbose > 0: # pragma: no branch
...@@ -269,16 +254,12 @@ class CallbackHandler: ...@@ -269,16 +254,12 @@ class CallbackHandler:
def __init__(self): def __init__(self):
"""Initialise CallbackHandler.""" """Initialise CallbackHandler."""
self.__callbacks: List[clbk_type] = [] self.__callbacks: List[clbk_type] = []
self.custom_objects = {}
self._checkpoint = None self._checkpoint = None
self.editable = True self.editable = True
def add_custom_objects(self, custom_objects):
self.custom_objects = custom_objects
@property @property
def _callbacks(self): def _callbacks(self):
return [{"callback": clbk[clbk["name"]], "path": clbk["path"], "custom_objects": self.custom_objects} for clbk return [{"callback": clbk[clbk["name"]], "path": clbk["path"]} for clbk
in self.__callbacks] in self.__callbacks]
@_callbacks.setter @_callbacks.setter
...@@ -334,8 +315,7 @@ class CallbackHandler: ...@@ -334,8 +315,7 @@ class CallbackHandler:
"""Return all callbacks and append checkpoint if available on last position.""" """Return all callbacks and append checkpoint if available on last position."""
clbks = self._callbacks clbks = self._callbacks
if self._checkpoint is not None: if self._checkpoint is not None:
clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath, clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath}]
"custom_objects": self.custom_objects}]
return clbks return clbks
def get_checkpoint(self) -> ModelCheckpointAdvanced: def get_checkpoint(self) -> ModelCheckpointAdvanced:
...@@ -345,8 +325,7 @@ class CallbackHandler: ...@@ -345,8 +325,7 @@ class CallbackHandler:
def create_model_checkpoint(self, **kwargs): def create_model_checkpoint(self, **kwargs):
"""Create a model checkpoint and enable edit.""" """Create a model checkpoint and enable edit."""
self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, custom_objects=self.custom_objects, self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, **kwargs)
**kwargs)
self.editable = False self.editable = False
def load_callbacks(self) -> None: def load_callbacks(self) -> None:
......
...@@ -118,7 +118,6 @@ class ModelSetup(RunEnvironment): ...@@ -118,7 +118,6 @@ class ModelSetup(RunEnvironment):
hist = HistoryAdvanced() hist = HistoryAdvanced()
self.data_store.set("hist", hist, scope="model") self.data_store.set("hist", hist, scope="model")
callbacks = CallbackHandler() callbacks = CallbackHandler()
callbacks.add_custom_objects(self.model.custom_objects)
if lr is not None: if lr is not None:
callbacks.add_callback(lr, self.callbacks_name % "lr", "lr") callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
callbacks.add_callback(hist, self.callbacks_name % "hist", "hist") callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
...@@ -141,6 +140,7 @@ class ModelSetup(RunEnvironment): ...@@ -141,6 +140,7 @@ class ModelSetup(RunEnvironment):
model = self.data_store.get("model_class") model = self.data_store.get("model_class")
self.model = model(**args) self.model = model(**args)
self.get_model_settings() self.get_model_settings()
keras.utils.get_custom_objects().update(self.model.custom_objects)
def get_model_settings(self): def get_model_settings(self):
"""Load all model settings and store in data store.""" """Load all model settings and store in data store."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment