From 090e1e08f84631baeafe0fb8eb4ff3813c37bb87 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Fri, 13 Mar 2020 11:18:19 +0100 Subject: [PATCH] added custom_objects to model_class.py --- src/model_modules/advanced_paddings.py | 72 +++++++++++++------------- src/model_modules/model_class.py | 26 +++++++++- 2 files changed, 61 insertions(+), 37 deletions(-) diff --git a/src/model_modules/advanced_paddings.py b/src/model_modules/advanced_paddings.py index 2e2892d8..d9e55c78 100644 --- a/src/model_modules/advanced_paddings.py +++ b/src/model_modules/advanced_paddings.py @@ -110,42 +110,6 @@ class PadUtils: return normalized_padding -class Padding2D: - ''' - This class combines the implemented padding methods. You can call this method by defining a specific padding type. - The __call__ method will return the corresponding Padding layer. - ''' - def __init__(self, padding_type): - self.padding_type = padding_type - self.allowed_paddings = { - **dict.fromkeys(("RefPad2D", "ReflectionPadding2D"), ReflectionPadding2D), - **dict.fromkeys(("SymPad2D", "SymmetricPadding2D"), SymmetricPadding2D), - **dict.fromkeys(("ZeroPad2D", "ZeroPadding2D"), ZeroPadding2D) - } - - def _check_and_get_padding(self): - if isinstance(self.padding_type, str): - try: - pad2d = self.allowed_paddings[self.padding_type] - except KeyError as einfo: - raise NotImplementedError( - f"`{einfo}' is not implemented as padding. " - "Use one of those: i) `RefPad2D', ii) `SymPad2D', iii) `ZeroPad2D'") - else: - if self.padding_type in self.allowed_paddings.values(): - pad2d = self.padding_type - else: - raise TypeError(f"`{self.padding_type.__name__}' is not a valid padding layer type. " - "Use one of those: " - "i) ReflectionPadding2D, ii) SymmetricPadding2D, iii) ZeroPadding2D") - return pad2d - - def __call__(self, *args, **kwargs): - return self._check_and_get_padding()(*args, **kwargs) - - - - class ReflectionPadding2D(_ZeroPadding): """ Reflection padding layer for 2D input. This custum padding layer is built on keras' zero padding layers. Doc is copy @@ -289,6 +253,42 @@ class SymmetricPadding2D(_ZeroPadding): return tf.pad(inputs, pattern, 'SYMMETRIC') +class Padding2D: + ''' + This class combines the implemented padding methods. You can call this method by defining a specific padding type. + The __call__ method will return the corresponding Padding layer. + ''' + + allowed_paddings = { + **dict.fromkeys(("RefPad2D", "ReflectionPadding2D"), ReflectionPadding2D), + **dict.fromkeys(("SymPad2D", "SymmetricPadding2D"), SymmetricPadding2D), + **dict.fromkeys(("ZeroPad2D", "ZeroPadding2D"), ZeroPadding2D) + } + + def __init__(self, padding_type): + self.padding_type = padding_type + + def _check_and_get_padding(self): + if isinstance(self.padding_type, str): + try: + pad2d = self.allowed_paddings[self.padding_type] + except KeyError as einfo: + raise NotImplementedError( + f"`{einfo}' is not implemented as padding. " + "Use one of those: i) `RefPad2D', ii) `SymPad2D', iii) `ZeroPad2D'") + else: + if self.padding_type in self.allowed_paddings.values(): + pad2d = self.padding_type + else: + raise TypeError(f"`{self.padding_type.__name__}' is not a valid padding layer type. " + "Use one of those: " + "i) ReflectionPadding2D, ii) SymmetricPadding2D, iii) ZeroPadding2D") + return pad2d + + def __call__(self, *args, **kwargs): + return self._check_and_get_padding()(*args, **kwargs) + + if __name__ == '__main__': from keras.models import Model from keras.layers import Conv2D, Flatten, Dense, Input diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py index 7796ea4f..27e13c66 100644 --- a/src/model_modules/model_class.py +++ b/src/model_modules/model_class.py @@ -5,7 +5,7 @@ __date__ = '2019-12-12' from abc import ABC -from typing import Any, Callable +from typing import Any, Callable, Dict import keras from src.model_modules.inception_model import InceptionModelBase @@ -31,6 +31,7 @@ class AbstractModelClass(ABC): self.__model = None self.__loss = None self.model_name = self.__class__.__name__ + self.__custom_objects = {} def __getattr__(self, name: str) -> Any: @@ -79,9 +80,28 @@ class AbstractModelClass(ABC): def loss(self, value) -> None: self.__loss = value + @property + def custom_objects(self) -> Dict: + return self.__custom_objects + + @custom_objects.setter + def custom_objects(self, value) -> None: + self.__custom_objects = value + def get_settings(self): return dict((k, v) for (k, v) in self.__dict__.items() if not k.startswith("_AbstractModelClass__")) + def set_model(self): + pass + + def set_loss(self): + pass + + def set_custom_objects(self, **kwargs): + if "Padding2D" in kwargs.keys(): + kwargs.update(kwargs["Padding2D"].allowed_paddings) + self.custom_objects.update(kwargs) + class MyLittleModel(AbstractModelClass): @@ -121,6 +141,7 @@ class MyLittleModel(AbstractModelClass): # apply to model self.set_model() self.set_loss() + self.set_custom_objects(loss=self.loss) def set_model(self): @@ -201,6 +222,7 @@ class MyBranchedModel(AbstractModelClass): # apply to model self.set_model() self.set_loss() + self.set_custom_objects(loss=self.loss) def set_model(self): @@ -277,6 +299,7 @@ class MyTowerModel(AbstractModelClass): # apply to model self.set_model() self.set_loss() + self.set_custom_objects(loss=self.loss) def set_model(self): @@ -388,6 +411,7 @@ class MyPaperModel(AbstractModelClass): # apply to model self.set_model() self.set_loss() + self.set_custom_objects(loss=self.loss, Padding2D=Padding2D) def set_model(self): -- GitLab