diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py index 7781a0d2cb10f474240a30a6b7fd06a72b1754fa..4140f15b7afb496452357793c8e4d4c914fd9f86 100644 --- a/mlair/model_modules/keras_extensions.py +++ b/mlair/model_modules/keras_extensions.py @@ -9,12 +9,9 @@ import pickle from typing import Union, List from typing_extensions import TypedDict -from mlair.helpers import TimeTracking - import numpy as np from keras import backend as K from keras.callbacks import History, ModelCheckpoint, Callback -import keras from mlair import helpers @@ -150,7 +147,6 @@ class ModelCheckpointAdvanced(ModelCheckpoint): def __init__(self, *args, **kwargs): """Initialise ModelCheckpointAdvanced and set callbacks attribute.""" self.callbacks = kwargs.pop("callbacks") - self.custom_objects = kwargs.pop("custom_objects") super().__init__(*args, **kwargs) def update_best(self, hist): @@ -184,23 +180,12 @@ class ModelCheckpointAdvanced(ModelCheckpoint): if self.epochs_since_last_save == 0 and epoch != 0: if self.save_best_only: current = logs.get(self.monitor) - model_save = None - if hasattr(callback["callback"], "model"): - # ToDo: store model in cache? - callback["callback"].model.save( - callback["path"].rsplit(".", maxsplit=1)[0] + "model_save_tmp.h5") - callback["callback"].model = None if current == self.best: if self.verbose > 0: # pragma: no branch print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) with open(file_path, "wb") as f: # ToDo: create "save" method pickle.dump(callback["callback"], f) - if callback["callback"].model is None: - with TimeTracking("load_model"): - callback["callback"].model = keras.models.load_model( - callback["path"].rsplit(".", maxsplit=1)[0] + "model_save_tmp.h5", - custom_objects=callback["custom_objects"]) else: with open(file_path, "wb") as f: if self.verbose > 0: # pragma: no branch @@ -269,16 +254,12 @@ class CallbackHandler: def __init__(self): """Initialise CallbackHandler.""" self.__callbacks: List[clbk_type] = [] - self.custom_objects = {} self._checkpoint = None self.editable = True - def add_custom_objects(self, custom_objects): - self.custom_objects = custom_objects - @property def _callbacks(self): - return [{"callback": clbk[clbk["name"]], "path": clbk["path"], "custom_objects": self.custom_objects} for clbk + return [{"callback": clbk[clbk["name"]], "path": clbk["path"]} for clbk in self.__callbacks] @_callbacks.setter @@ -334,8 +315,7 @@ class CallbackHandler: """Return all callbacks and append checkpoint if available on last position.""" clbks = self._callbacks if self._checkpoint is not None: - clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath, - "custom_objects": self.custom_objects}] + clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath}] return clbks def get_checkpoint(self) -> ModelCheckpointAdvanced: @@ -345,8 +325,7 @@ class CallbackHandler: def create_model_checkpoint(self, **kwargs): """Create a model checkpoint and enable edit.""" - self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, custom_objects=self.custom_objects, - **kwargs) + self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, **kwargs) self.editable = False def load_callbacks(self) -> None: diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py index 26c61ff37748d6110752fe5afaf5122e2df16303..0ce13790a10999d69de7a39676d2a1f92997bbf9 100644 --- a/mlair/run_modules/model_setup.py +++ b/mlair/run_modules/model_setup.py @@ -118,7 +118,6 @@ class ModelSetup(RunEnvironment): hist = HistoryAdvanced() self.data_store.set("hist", hist, scope="model") callbacks = CallbackHandler() - callbacks.add_custom_objects(self.model.custom_objects) if lr is not None: callbacks.add_callback(lr, self.callbacks_name % "lr", "lr") callbacks.add_callback(hist, self.callbacks_name % "hist", "hist") @@ -141,6 +140,7 @@ class ModelSetup(RunEnvironment): model = self.data_store.get("model_class") self.model = model(**args) self.get_model_settings() + keras.utils.get_custom_objects().update(self.model.custom_objects) def get_model_settings(self): """Load all model settings and store in data store."""