From 1415a76076e2a7020001ec517324fa9d572e8eab Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Tue, 16 Feb 2021 19:39:42 +0100 Subject: [PATCH] removed unnecessary changes from previous commits, there is now a new method for the keras magic to emphasise its importance --- mlair/model_modules/keras_extensions.py | 6 ++---- mlair/model_modules/model_class.py | 10 ++++++---- mlair/run_modules/model_setup.py | 12 ++++++++++++ mlair/run_modules/training.py | 2 +- 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py index 4140f15b..33358e56 100644 --- a/mlair/model_modules/keras_extensions.py +++ b/mlair/model_modules/keras_extensions.py @@ -174,8 +174,8 @@ class ModelCheckpointAdvanced(ModelCheckpoint): def on_epoch_end(self, epoch, logs=None): """Save model as usual (see ModelCheckpoint class), but also save additional callbacks.""" super().on_epoch_end(epoch, logs) + for callback in self.callbacks: - print(callback.keys()) file_path = callback["path"] if self.epochs_since_last_save == 0 and epoch != 0: if self.save_best_only: @@ -184,7 +184,6 @@ class ModelCheckpointAdvanced(ModelCheckpoint): if self.verbose > 0: # pragma: no branch print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) with open(file_path, "wb") as f: - # ToDo: create "save" method pickle.dump(callback["callback"], f) else: with open(file_path, "wb") as f: @@ -259,8 +258,7 @@ class CallbackHandler: @property def _callbacks(self): - return [{"callback": clbk[clbk["name"]], "path": clbk["path"]} for clbk - in self.__callbacks] + return [{"callback": clbk[clbk["name"]], "path": clbk["path"]} for clbk in self.__callbacks] @_callbacks.setter def _callbacks(self, value): diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py index fb2902dc..a2eda6e8 100644 --- a/mlair/model_modules/model_class.py +++ b/mlair/model_modules/model_class.py @@ -636,20 +636,22 @@ class MyPaperModel(AbstractModelClass): assert len(output_shape) == 1 super().__init__(input_shape[0], output_shape[0]) + from mlair.model_modules.keras_extensions import LearningRateDecay + # settings self.dropout_rate = .3 self.regularizer = keras.regularizers.l2(0.001) self.initial_lr = 1e-3 - self.lr_decay = mlair.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94, - epochs_drop=10) + self.lr_decay = LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10) self.activation = keras.layers.ELU self.padding = "SymPad2D" # apply to model self.set_model() self.set_compile_options() - self.set_custom_objects(loss=self.compile_options["loss"], SymmetricPadding2D=SymmetricPadding2D, - LearningRateDecay=mlair.model_modules.keras_extensions.LearningRateDecay) + self.set_custom_objects(loss=self.compile_options["loss"], + SymmetricPadding2D=SymmetricPadding2D, + LearningRateDecay=LearningRateDecay) def set_model(self): """ diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py index 0ce13790..dda18fac 100644 --- a/mlair/run_modules/model_setup.py +++ b/mlair/run_modules/model_setup.py @@ -76,6 +76,9 @@ class ModelSetup(RunEnvironment): # build model graph using settings from my_model_settings() self.build_model() + # broadcast custom objects + self.broadcast_custom_objects() + # plot model structure self.plot_model() @@ -140,6 +143,15 @@ class ModelSetup(RunEnvironment): model = self.data_store.get("model_class") self.model = model(**args) self.get_model_settings() + + def broadcast_custom_objects(self): + """ + Broadcast custom objects to keras utils. + + This method is very important, because it adds the model's custom objects to the keras utils. By doing so, all + custom objects can be treated as standard keras modules. Therefore, problems related to model or callback + loading are solved. + """ keras.utils.get_custom_objects().update(self.model.custom_objects) def get_model_settings(self): diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py index 3878e79c..113765e0 100644 --- a/mlair/run_modules/training.py +++ b/mlair/run_modules/training.py @@ -145,7 +145,7 @@ class Training(RunEnvironment): logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.") self.callbacks.load_callbacks() self.callbacks.update_checkpoint() - self.model = keras.models.load_model(checkpoint.filepath, self.model.custom_objects) + self.model = keras.models.load_model(checkpoint.filepath) hist: History = self.callbacks.get_callback_by_name("hist") initial_epoch = max(hist.epoch) + 1 _ = self.model.fit_generator(generator=self.train_set, -- GitLab