Skip to content
Snippets Groups Projects
Commit 1415a760 authored by leufen1's avatar leufen1
Browse files

removed unnecessary changes from previous commits, there is now a new method...

Removed unnecessary changes from previous commits; there is now a dedicated method for the keras custom-object magic to emphasise its importance.
parent 85d067ff
No related branches found
No related tags found
4 merge requests!253include current develop,!252Resolve "release v1.3.0",!245update #275 branch,!242Resolve "BUG: loading of custom objects not working"
Pipeline #60058 passed
......@@ -174,8 +174,8 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
def on_epoch_end(self, epoch, logs=None):
"""Save model as usual (see ModelCheckpoint class), but also save additional callbacks."""
super().on_epoch_end(epoch, logs)
for callback in self.callbacks:
print(callback.keys())
file_path = callback["path"]
if self.epochs_since_last_save == 0 and epoch != 0:
if self.save_best_only:
......@@ -184,7 +184,6 @@ class ModelCheckpointAdvanced(ModelCheckpoint):
if self.verbose > 0: # pragma: no branch
print('\nEpoch %05d: save to %s' % (epoch + 1, file_path))
with open(file_path, "wb") as f:
# ToDo: create "save" method
pickle.dump(callback["callback"], f)
else:
with open(file_path, "wb") as f:
......@@ -259,8 +258,7 @@ class CallbackHandler:
@property
def _callbacks(self):
    """Return the registered callbacks paired with their storage paths.

    Each entry of the internal list is a dict holding the callback object
    under its own name plus a "path"; this property flattens that into
    ``{"callback": <callback object>, "path": <file path>}`` dicts.
    """
    # NOTE: the diff artifact left two copies of this return; only one is kept.
    return [{"callback": clbk[clbk["name"]], "path": clbk["path"]} for clbk in self.__callbacks]
@_callbacks.setter
def _callbacks(self, value):
......
......@@ -636,20 +636,22 @@ class MyPaperModel(AbstractModelClass):
assert len(output_shape) == 1
super().__init__(input_shape[0], output_shape[0])
from mlair.model_modules.keras_extensions import LearningRateDecay
# settings
self.dropout_rate = .3
self.regularizer = keras.regularizers.l2(0.001)
self.initial_lr = 1e-3
self.lr_decay = mlair.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
epochs_drop=10)
self.lr_decay = LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10)
self.activation = keras.layers.ELU
self.padding = "SymPad2D"
# apply to model
self.set_model()
self.set_compile_options()
self.set_custom_objects(loss=self.compile_options["loss"], SymmetricPadding2D=SymmetricPadding2D,
LearningRateDecay=mlair.model_modules.keras_extensions.LearningRateDecay)
self.set_custom_objects(loss=self.compile_options["loss"],
SymmetricPadding2D=SymmetricPadding2D,
LearningRateDecay=LearningRateDecay)
def set_model(self):
"""
......
......@@ -76,6 +76,9 @@ class ModelSetup(RunEnvironment):
# build model graph using settings from my_model_settings()
self.build_model()
# broadcast custom objects
self.broadcast_custom_objects()
# plot model structure
self.plot_model()
......@@ -140,6 +143,15 @@ class ModelSetup(RunEnvironment):
model = self.data_store.get("model_class")
self.model = model(**args)
self.get_model_settings()
def broadcast_custom_objects(self):
    """Register the model's custom objects with the keras utilities.

    This step is very important: adding the model's custom objects to the
    keras global registry makes them behave like standard keras modules,
    which resolves problems when loading models or callbacks from disk.
    """
    custom_objects = self.model.custom_objects
    keras.utils.get_custom_objects().update(custom_objects)
def get_model_settings(self):
......
......@@ -145,7 +145,7 @@ class Training(RunEnvironment):
logging.info("Found locally stored model and checkpoints. Training is resumed from the last checkpoint.")
self.callbacks.load_callbacks()
self.callbacks.update_checkpoint()
self.model = keras.models.load_model(checkpoint.filepath, self.model.custom_objects)
self.model = keras.models.load_model(checkpoint.filepath)
hist: History = self.callbacks.get_callback_by_name("hist")
initial_epoch = max(hist.epoch) + 1
_ = self.model.fit_generator(generator=self.train_set,
......
0% loaded — still loading, or an error occurred.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment