Skip to content
Snippets Groups Projects
Commit abef8764 authored by Felix Kleinert's avatar Felix Kleinert
Browse files

update paddings and helper functions (#56)

parent f2ccaa3d
No related branches found
No related tags found
2 merge requests!59Develop,!46Felix #56 advanced paddings
...@@ -12,7 +12,7 @@ from keras.utils.generic_utils import transpose_shape ...@@ -12,7 +12,7 @@ from keras.utils.generic_utils import transpose_shape
from keras.backend.common import normalize_data_format from keras.backend.common import normalize_data_format
class pad_utils: class PadUtils:
""" """
Helper class for advanced paddings Helper class for advanced paddings
""" """
...@@ -76,8 +76,12 @@ class pad_utils: ...@@ -76,8 +76,12 @@ class pad_utils:
'Found: ' + str(padding)) 'Found: ' + str(padding))
height_padding = conv_utils.normalize_tuple(padding[0], 2, height_padding = conv_utils.normalize_tuple(padding[0], 2,
'1st entry of padding') '1st entry of padding')
if not all(k >= 0 for k in height_padding):
raise ValueError(f"The `1st entry of padding` argument must be >= 0. Received: {padding[0]} of type {type(padding[0])}")
width_padding = conv_utils.normalize_tuple(padding[1], 2, width_padding = conv_utils.normalize_tuple(padding[1], 2,
'2nd entry of padding') '2nd entry of padding')
if not all(k >= 0 for k in width_padding):
raise ValueError(f"The `2nd entry of padding` argument must be >= 0. Received: {padding[1]} of type {type(padding[1])}")
normalized_padding = (height_padding, width_padding) normalized_padding = (height_padding, width_padding)
else: else:
raise ValueError('`padding` should be either an int, ' raise ValueError('`padding` should be either an int, '
...@@ -85,7 +89,7 @@ class pad_utils: ...@@ -85,7 +89,7 @@ class pad_utils:
'(symmetric_height_pad, symmetric_width_pad), ' '(symmetric_height_pad, symmetric_width_pad), '
'or a tuple of 2 tuples of 2 ints ' 'or a tuple of 2 tuples of 2 ints '
'((top_pad, bottom_pad), (left_pad, right_pad)). ' '((top_pad, bottom_pad), (left_pad, right_pad)). '
'Found: ' + str(padding)) f'Found: {padding} of type {type(padding)}')
return normalized_padding return normalized_padding
...@@ -151,13 +155,13 @@ class ReflectionPadding2D(_ZeroPadding): ...@@ -151,13 +155,13 @@ class ReflectionPadding2D(_ZeroPadding):
padding=(1, 1), padding=(1, 1),
data_format=None, data_format=None,
**kwargs): **kwargs):
normalized_padding = pad_utils.check_padding_format(padding=padding) normalized_padding = PadUtils.check_padding_format(padding=padding)
super(ReflectionPadding2D, self).__init__(normalized_padding, super(ReflectionPadding2D, self).__init__(normalized_padding,
data_format, data_format,
**kwargs) **kwargs)
def call(self, inputs, mask=None): def call(self, inputs, mask=None):
pattern = pad_utils.spatial_2d_padding(padding=self.padding, data_format=self.data_format) pattern = PadUtils.spatial_2d_padding(padding=self.padding, data_format=self.data_format)
return tf.pad(inputs, pattern, 'REFLECT') return tf.pad(inputs, pattern, 'REFLECT')
...@@ -222,13 +226,13 @@ class SymmetricPadding2D(_ZeroPadding): ...@@ -222,13 +226,13 @@ class SymmetricPadding2D(_ZeroPadding):
padding=(1, 1), padding=(1, 1),
data_format=None, data_format=None,
**kwargs): **kwargs):
normalized_padding = pad_utils.check_padding_format(padding=padding) normalized_padding = PadUtils.check_padding_format(padding=padding)
super(SymmetricPadding2D, self).__init__(normalized_padding, super(SymmetricPadding2D, self).__init__(normalized_padding,
data_format, data_format,
**kwargs) **kwargs)
def call(self, inputs, mask=None): def call(self, inputs, mask=None):
pattern = pad_utils.spatial_2d_padding(padding=self.padding, data_format=self.data_format) pattern = PadUtils.spatial_2d_padding(padding=self.padding, data_format=self.data_format)
return tf.pad(inputs, pattern, 'SYMMETRIC') return tf.pad(inputs, pattern, 'SYMMETRIC')
...@@ -241,11 +245,11 @@ if __name__ == '__main__': ...@@ -241,11 +245,11 @@ if __name__ == '__main__':
y = x.mean(axis=(1, 2)) y = x.mean(axis=(1, 2))
x_input = Input(shape=x.shape[1:]) x_input = Input(shape=x.shape[1:])
pad1 = pad_utils.get_padding_for_same(kernel_1) pad1 = PadUtils.get_padding_for_same(kernel_1)
x_out = ReflectionPadding2D(padding=pad1, name="RefPAD")(x_input) x_out = ReflectionPadding2D(padding=pad1, name="RefPAD")(x_input)
x_out = Conv2D(1, kernel_size=kernel_1, activation='relu')(x_out) x_out = Conv2D(5, kernel_size=kernel_1, activation='relu')(x_out)
pad2 = pad_utils.get_padding_for_same(kernel_2) pad2 = PadUtils.get_padding_for_same(kernel_2)
x_out = SymmetricPadding2D(padding=pad2, name="SymPAD")(x_out) x_out = SymmetricPadding2D(padding=pad2, name="SymPAD")(x_out)
x_out = Conv2D(2, kernel_size=kernel_2, activation='relu')(x_out) x_out = Conv2D(2, kernel_size=kernel_2, activation='relu')(x_out)
x_out = Flatten()(x_out) x_out = Flatten()(x_out)
...@@ -254,6 +258,5 @@ if __name__ == '__main__': ...@@ -254,6 +258,5 @@ if __name__ == '__main__':
model = Model(inputs=x_input, outputs=x_out) model = Model(inputs=x_input, outputs=x_out)
model.compile('adam', loss='mse') model.compile('adam', loss='mse')
model.summary() model.summary()
# hist = model.fit(x, y, epochs=10)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment