Skip to content
Snippets Groups Projects
Commit 85d067ff authored by leufen1's avatar leufen1
Browse files

found magic command to add custom objects to entire keras world

parent 5cb07ed5
No related branches found
No related tags found
4 merge requests!253include current develop,!252Resolve "release v1.3.0",!245update #275 branch,!242Resolve "BUG: loading of custom objects not working"
Pipeline #60057 passed
......@@ -9,12 +9,9 @@ import pickle
from typing import Union, List
from typing_extensions import TypedDict
from mlair.helpers import TimeTracking
import numpy as np
from keras import backend as K
from keras.callbacks import History, ModelCheckpoint, Callback
import keras
from mlair import helpers
......@@ -150,7 +147,6 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
def __init__(self, *args, **kwargs):
"""Initialise ModelCheckpointAdvanced and set callbacks attribute."""
self.callbacks = kwargs.pop("callbacks")
self.custom_objects = kwargs.pop("custom_objects")
super().__init__(*args, **kwargs)
def update_best(self, hist):
......@@ -184,23 +180,12 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
if self.epochs_since_last_save == 0 and epoch != 0:
if self.save_best_only:
current = logs.get(self.monitor)
model_save = None
if hasattr(callback["callback"], "model"):
# ToDo: store model in cache?
callback["callback"].model.save(
callback["path"].rsplit(".", maxsplit=1)[0] + "model_save_tmp.h5")
callback["callback"].model = None
if current == self.best:
if self.verbose > 0: # pragma: no branch
print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
with open(file_path, "wb") as f:
# ToDo: create "save" method
pickle.dump(callback["callback"], f)
if callback["callback"].model is None:
with TimeTracking("load_model"):
callback["callback"].model = keras.models.load_model(
callback["path"].rsplit(".", maxsplit=1)[0] + "model_save_tmp.h5",
custom_objects=callback["custom_objects"])
else:
with open(file_path, "wb") as f:
if self.verbose > 0: # pragma: no branch
......@@ -269,16 +254,12 @@ class CallbackHandler:
def __init__(self):
"""Initialise CallbackHandler."""
self.__callbacks: List[clbk_type] = []
self.custom_objects = {}
self._checkpoint = None
self.editable = True
def add_custom_objects(self, custom_objects):
self.custom_objects = custom_objects
@property
def _callbacks(self):
return [{"callback": clbk[clbk["name"]], "path": clbk["path"], "custom_objects": self.custom_objects} for clbk
return [{"callback": clbk[clbk["name"]], "path": clbk["path"]} for clbk
in self.__callbacks]
@_callbacks.setter
......@@ -334,8 +315,7 @@ class CallbackHandler:
"""Return all callbacks and append checkpoint if available on last position."""
clbks = self._callbacks
if self._checkpoint is not None:
clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath,
"custom_objects": self.custom_objects}]
clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath}]
return clbks
def get_checkpoint(self) -> ModelCheckpointAdvanced:
......@@ -345,8 +325,7 @@ class CallbackHandler:
def create_model_checkpoint(self, **kwargs):
"""Create a model checkpoint and enable edit."""
self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, custom_objects=self.custom_objects,
**kwargs)
self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, **kwargs)
self.editable = False
def load_callbacks(self) -> None:
......
......@@ -118,7 +118,6 @@ class ModelSetup(RunEnvironment):
hist = HistoryAdvanced()
self.data_store.set("hist", hist, scope="model")
callbacks = CallbackHandler()
callbacks.add_custom_objects(self.model.custom_objects)
if lr is not None:
callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
......@@ -141,6 +140,7 @@ class ModelSetup(RunEnvironment):
model = self.data_store.get("model_class")
self.model = model(**args)
self.get_model_settings()
keras.utils.get_custom_objects().update(self.model.custom_objects)
def get_model_settings(self):
"""Load all model settings and store in data store."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment