diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py
index 4140f15b7afb496452357793c8e4d4c914fd9f86..33358e566ef80f28ee7740531b71d1a83abde115 100644
--- a/mlair/model_modules/keras_extensions.py
+++ b/mlair/model_modules/keras_extensions.py
@@ -174,8 +174,8 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
     def on_epoch_end(self, epoch, logs=None):
         """Save model as usual (see ModelCheckpoint class), but also save additional callbacks."""
         super().on_epoch_end(epoch, logs)
+
         for callback in self.callbacks:
-            print(callback.keys())
             file_path = callback["path"]
             if self.epochs_since_last_save == 0 and epoch != 0:
                 if self.save_best_only:
@@ -184,7 +184,6 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
                         if self.verbose > 0:  # pragma: no branch
                             print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
                         with open(file_path, "wb") as f:
-                            # ToDo: create "save" method
                             pickle.dump(callback["callback"], f)
                 else:
                     with open(file_path, "wb") as f:
@@ -259,8 +258,7 @@ class CallbackHandler:
 
     @property
     def _callbacks(self):
-        return [{"callback": clbk[clbk["name"]], "path": clbk["path"]} for clbk
-                in self.__callbacks]
+        return [{"callback": clbk[clbk["name"]], "path": clbk["path"]} for clbk in self.__callbacks]
 
     @_callbacks.setter
     def _callbacks(self, value):
diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py
index fb2902dce93d37b52e91804ab92ed895d85a876a..a2eda6e8287af2ce489bf75b02d7b205549ff144 100644
--- a/mlair/model_modules/model_class.py
+++ b/mlair/model_modules/model_class.py
@@ -636,20 +636,22 @@ class MyPaperModel(AbstractModelClass):
         assert len(output_shape) == 1
         super().__init__(input_shape[0], output_shape[0])
 
+        from mlair.model_modules.keras_extensions import LearningRateDecay
+
         # settings
         self.dropout_rate = .3
         self.regularizer = keras.regularizers.l2(0.001)
         self.initial_lr = 1e-3
-        self.lr_decay = mlair.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
-                                                                               epochs_drop=10)
+        self.lr_decay = LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
         self.activation = keras.layers.ELU
         self.padding = "SymPad2D"
 
         # apply to model
         self.set_model()
         self.set_compile_options()
-        self.set_custom_objects(loss=self.compile_options["loss"], SymmetricPadding2D=SymmetricPadding2D,
-                                LearningRateDecay=mlair.model_modules.keras_extensions.LearningRateDecay)
+        self.set_custom_objects(loss=self.compile_options["loss"],
+                                SymmetricPadding2D=SymmetricPadding2D,
+                                LearningRateDecay=LearningRateDecay)
 
     def set_model(self):
         """
diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index 0ce13790a10999d69de7a39676d2a1f92997bbf9..dda18fac5d8546c6e399334f3d89415d246a1975 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -76,6 +76,9 @@ class ModelSetup(RunEnvironment):
 
         # build model graph using settings from my_model_settings()
         self.build_model()
 
+        # broadcast custom objects
+        self.broadcast_custom_objects()
+
         # plot model structure
         self.plot_model()
@@ -140,6 +143,15 @@ class ModelSetup(RunEnvironment):
         model = self.data_store.get("model_class")
         self.model = model(**args)
         self.get_model_settings()
+
+    def broadcast_custom_objects(self):
+        """
+        Broadcast custom objects to keras utils.
+
+        This method is very important, because it adds the model's custom objects to the keras utils. By doing so, all
+        custom objects can be treated as standard keras modules. Therefore, problems related to model or callback
+        loading are solved.
+        """
         keras.utils.get_custom_objects().update(self.model.custom_objects)
 
     def get_model_settings(self):
diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py
index 3878e79cb66365ba186ee736f3a4927c076d2dee..113765e0d295bb0b1d756cd1cefba85093b20089 100644
--- a/mlair/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -145,7 +145,7 @@ class Training(RunEnvironment):
             logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.")
             self.callbacks.load_callbacks()
             self.callbacks.update_checkpoint()
-            self.model = keras.models.load_model(checkpoint.filepath, self.model.custom_objects)
+            self.model = keras.models.load_model(checkpoint.filepath)
             hist: History = self.callbacks.get_callback_by_name("hist")
             initial_epoch = max(hist.epoch) + 1
             _ = self.model.fit_generator(generator=self.train_set,