From b9c5cf7418db875cf00cb8885eebfbcf68d38827 Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Tue, 16 Feb 2021 15:40:37 +0100 Subject: [PATCH] added custom object for MyPaperModel --- mlair/model_modules/model_class.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py index a603b466..07eaa1ec 100644 --- a/mlair/model_modules/model_class.py +++ b/mlair/model_modules/model_class.py @@ -127,7 +127,7 @@ import keras import tensorflow as tf from mlair.model_modules.inception_model import InceptionModelBase from mlair.model_modules.flatten import flatten_tail -from mlair.model_modules.advanced_paddings import PadUtils, Padding2D +from mlair.model_modules.advanced_paddings import PadUtils, Padding2D, SymmetricPadding2D class AbstractModelClass(ABC): @@ -648,7 +648,8 @@ class MyPaperModel(AbstractModelClass): # apply to model self.set_model() self.set_compile_options() - self.set_custom_objects(loss=self.compile_options["loss"], Padding2D=Padding2D) + self.set_custom_objects(loss=self.compile_options["loss"], SymmetricPadding2D=SymmetricPadding2D, + LearningRateDecay=mlair.model_modules.keras_extensions.LearningRateDecay) def set_model(self): """ @@ -704,8 +705,6 @@ class MyPaperModel(AbstractModelClass): X_input = keras.layers.Input(shape=self._input_shape) pad_size = PadUtils.get_padding_for_same(first_kernel) - # X_in = adv_pad.SymmetricPadding2D(padding=pad_size)(X_input) - # X_in = inception_model.padding_layer("SymPad2D")(padding=pad_size, name="SymPad")(X_input) # adv_pad.SymmetricPadding2D(padding=pad_size)(X_input) X_in = Padding2D("SymPad2D")(padding=pad_size, name="SymPad")(X_input) X_in = keras.layers.Conv2D(filters=first_filters, kernel_size=first_kernel, @@ -748,7 +747,7 @@ class MyPaperModel(AbstractModelClass): def set_compile_options(self): self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9) self.compile_options = {"loss": [keras.losses.mean_squared_error, keras.losses.mean_squared_error], - "metrics": ['mse', 'mea']} + "metrics": ['mse', 'mae']} if __name__ == "__main__": -- GitLab