diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py index 33358e566ef80f28ee7740531b71d1a83abde115..7781a0d2cb10f474240a30a6b7fd06a72b1754fa 100644 --- a/mlair/model_modules/keras_extensions.py +++ b/mlair/model_modules/keras_extensions.py @@ -9,9 +9,12 @@ import pickle from typing import Union, List from typing_extensions import TypedDict +from mlair.helpers import TimeTracking + import numpy as np from keras import backend as K from keras.callbacks import History, ModelCheckpoint, Callback +import keras from mlair import helpers @@ -147,6 +150,7 @@ class ModelCheckpointAdvanced(ModelCheckpoint): def __init__(self, *args, **kwargs): """Initialise ModelCheckpointAdvanced and set callbacks attribute.""" self.callbacks = kwargs.pop("callbacks") + self.custom_objects = kwargs.pop("custom_objects") super().__init__(*args, **kwargs) def update_best(self, hist): @@ -174,17 +178,29 @@ class ModelCheckpointAdvanced(ModelCheckpoint): def on_epoch_end(self, epoch, logs=None): """Save model as usual (see ModelCheckpoint class), but also save additional callbacks.""" super().on_epoch_end(epoch, logs) - for callback in self.callbacks: + print(callback.keys()) file_path = callback["path"] if self.epochs_since_last_save == 0 and epoch != 0: if self.save_best_only: current = logs.get(self.monitor) + model_save = None + if hasattr(callback["callback"], "model"): + # ToDo: store model in cache? + callback["callback"].model.save( + callback["path"].rsplit(".", maxsplit=1)[0] + "model_save_tmp.h5") + callback["callback"].model = None if current == self.best: if self.verbose > 0: # pragma: no branch print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) with open(file_path, "wb") as f: + # ToDo: create "save" method pickle.dump(callback["callback"], f) + if callback["callback"].model is None: + with TimeTracking("load_model"): + callback["callback"].model = keras.models.load_model( + callback["path"].rsplit(".", maxsplit=1)[0] + "model_save_tmp.h5", + custom_objects=callback["custom_objects"]) else: with open(file_path, "wb") as f: if self.verbose > 0: # pragma: no branch @@ -253,12 +269,17 @@ class CallbackHandler: def __init__(self): """Initialise CallbackHandler.""" self.__callbacks: List[clbk_type] = [] + self.custom_objects = {} self._checkpoint = None self.editable = True + def add_custom_objects(self, custom_objects): + self.custom_objects = custom_objects + @property def _callbacks(self): - return [{"callback": clbk[clbk["name"]], "path": clbk["path"]} for clbk in self.__callbacks] + return [{"callback": clbk[clbk["name"]], "path": clbk["path"], "custom_objects": self.custom_objects} for clbk + in self.__callbacks] @_callbacks.setter def _callbacks(self, value): @@ -313,7 +334,8 @@ class CallbackHandler: """Return all callbacks and append checkpoint if available on last position.""" clbks = self._callbacks if self._checkpoint is not None: - clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath}] + clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath, + "custom_objects": self.custom_objects}] return clbks def get_checkpoint(self) -> ModelCheckpointAdvanced: @@ -323,7 +345,8 @@ class CallbackHandler: def create_model_checkpoint(self, **kwargs): """Create a model checkpoint and enable edit.""" - self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, **kwargs) + self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, custom_objects=self.custom_objects, + **kwargs) self.editable = False def load_callbacks(self) -> None: diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py index 07eaa1ec19fbd394ad0bffd262d2698ccfd81055..fb2902dce93d37b52e91804ab92ed895d85a876a 100644 --- a/mlair/model_modules/model_class.py +++ b/mlair/model_modules/model_class.py @@ -669,10 +669,10 @@ class MyPaperModel(AbstractModelClass): conv_settings_dict1 = { 'tower_1': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (3, 1), 'activation': activation}, - 'tower_2': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (5, 1), - 'activation': activation}, - 'tower_3': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (1, 1), - 'activation': activation}, + # 'tower_2': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (5, 1), + # 'activation': activation}, + # 'tower_3': {'reduction_filter': 8, 'tower_filter': 16 * 2, 'tower_kernel': (1, 1), + # 'activation': activation}, # 'tower_4':{'reduction_filter':8, 'tower_filter':8*2, 'tower_kernel':(7,1), 'activation':activation}, } pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 16, 'activation': activation} @@ -680,10 +680,10 @@ class MyPaperModel(AbstractModelClass): conv_settings_dict2 = { 'tower_1': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (3, 1), 'activation': activation}, - 'tower_2': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (5, 1), - 'activation': activation}, - 'tower_3': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (1, 1), - 'activation': activation}, + # 'tower_2': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (5, 1), + # 'activation': activation}, + # 'tower_3': {'reduction_filter': 64, 'tower_filter': 32 * 2, 'tower_kernel': (1, 1), + # 'activation': activation}, # 'tower_4':{'reduction_filter':8*2, 'tower_filter':16*2, 'tower_kernel':(7,1), 'activation':activation}, } pool_settings_dict2 = {'pool_kernel': (3, 1), 'tower_filter': 32, 'activation': activation} diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py index c6af13b02e818431578c7423d837f95e64ca3d15..26c61ff37748d6110752fe5afaf5122e2df16303 100644 --- a/mlair/run_modules/model_setup.py +++ b/mlair/run_modules/model_setup.py @@ -118,6 +118,7 @@ class ModelSetup(RunEnvironment): hist = HistoryAdvanced() self.data_store.set("hist", hist, scope="model") callbacks = CallbackHandler() + callbacks.add_custom_objects(self.model.custom_objects) if lr is not None: callbacks.add_callback(lr, self.callbacks_name % "lr", "lr") callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")