diff --git a/src/model_modules/advanced_paddings.py b/src/model_modules/advanced_paddings.py new file mode 100644 index 0000000000000000000000000000000000000000..1d48dfc0a87c05183fc5b8b7755f48efaf7b5428 --- /dev/null +++ b/src/model_modules/advanced_paddings.py @@ -0,0 +1,280 @@ +__author__ = 'Felix Kleinert' +__date__ = '2020-03-02' + +import tensorflow as tf +import numpy as np +import keras.backend as K + +from keras.layers.convolutional import _ZeroPadding +from keras.legacy import interfaces +from keras.utils import conv_utils +from keras.utils.generic_utils import transpose_shape +from keras.backend.common import normalize_data_format + + +class PadUtils: + """ + Helper class for advanced paddings + """ + + @staticmethod + def get_padding_for_same(kernel_size, strides=1): + """ + This methods calculates the padding size to keep input and output dimensions equal for a given kernel size + (STRIDES HAVE TO BE EQUAL TO ONE!) + :param kernel_size: + :return: + """ + if strides != 1: + raise NotImplementedError("Strides other than 1 not implemented!") + if not all(isinstance(k, int) for k in kernel_size): + raise ValueError(f"The `kernel_size` argument must have a tuple of integers. Got: {kernel_size} " + f"of type {[type(k) for k in kernel_size]}") + + ks = np.array(kernel_size, dtype=np.int64) + + if any(k <= 0 for k in ks): + raise ValueError(f"All values of kernel_size must be > 0. Got: {kernel_size} ") + + if all(k % 2 == 1 for k in ks): # (d & 0x1 for d in ks): + pad = ((ks - 1) / 2).astype(np.int64) + # convert numpy int to base int + pad = [np.asscalar(v) for v in pad] + return tuple(pad) + # return tuple(PadUtils.check_padding_format(pad)) + else: + raise NotImplementedError(f"even kernel size not implemented. Got {kernel_size}") + + @staticmethod + def spatial_2d_padding(padding=((1, 1), (1, 1)), data_format=None): + """Pads the 2nd and 3rd dimensions of a 4D tensor. + + # Arguments + x: Tensor or variable. + padding: Tuple of 2 tuples, padding pattern. + data_format: string, `"channels_last"` or `"channels_first"`. + + # Returns + A padded 4D tensor. + + # Raises + ValueError: if `data_format` is neither `"channels_last"` or `"channels_first"`. + """ + assert len(padding) == 2 + assert len(padding[0]) == 2 + assert len(padding[1]) == 2 + data_format = normalize_data_format(data_format) + + pattern = [[0, 0], + list(padding[0]), + list(padding[1]), + [0, 0]] + pattern = transpose_shape(pattern, data_format, spatial_axes=(1, 2)) + return pattern + + @staticmethod + def check_padding_format(padding): + if isinstance(padding, int): + normalized_padding = ((padding, padding), (padding, padding)) + elif hasattr(padding, '__len__'): + if len(padding) != 2: + raise ValueError('`padding` should have two elements. ' + 'Found: ' + str(padding)) + for idx_pad, sub_pad in enumerate(padding): + if isinstance(sub_pad, str): + raise ValueError(f'`padding[{idx_pad}]` is str but must be int') + if hasattr(sub_pad, '__len__'): + if len(sub_pad) != 2: + raise ValueError(f'`padding[{idx_pad}]` should have one or two elements. ' + f'Found: {padding[idx_pad]}') + if not all(isinstance(sub_k, int) for sub_k in padding[idx_pad]): + raise ValueError(f'`padding[{idx_pad}]` should have one or two elements of type int. ' + f"Found:{padding[idx_pad]} of type {[type(sub_k) for sub_k in padding[idx_pad]]}") + height_padding = conv_utils.normalize_tuple(padding[0], 2, + '1st entry of padding') + if not all(k >= 0 for k in height_padding): + raise ValueError(f"The `1st entry of padding` argument must be >= 0. Received: {padding[0]} of type {type(padding[0])}") + width_padding = conv_utils.normalize_tuple(padding[1], 2, + '2nd entry of padding') + if not all(k >= 0 for k in width_padding): + raise ValueError(f"The `2nd entry of padding` argument must be >= 0. Received: {padding[1]} of type {type(padding[1])}") + normalized_padding = (height_padding, width_padding) + else: + raise ValueError('`padding` should be either an int, ' + 'a tuple of 2 ints ' + '(symmetric_height_pad, symmetric_width_pad), ' + 'or a tuple of 2 tuples of 2 ints ' + '((top_pad, bottom_pad), (left_pad, right_pad)). ' + f'Found: {padding} of type {type(padding)}') + return normalized_padding + + +class ReflectionPadding2D(_ZeroPadding): + """ + Reflection padding layer for 2D input. This custum padding layer is built on keras' zero padding layers. Doc is copy + pasted from the original functions/methods: + + + This layer can add rows and columns of reflected values + at the top, bottom, left and right side of an image like tensor. + + Example: + 6, 5, 4, 5, 6, 5, 4 + _________ + 1, 2, 3 RefPad(padding=[[1, 1,], [2, 2]]) 3, 2,| 1, 2, 3,| 2, 1 + 4, 5, 6 =============================>>>> 6, 5,| 4, 5, 6,| 5, 4 + _________ + 3, 2, 1, 2, 3, 2, 1 + + + + '# Arguments + padding: int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints. + - If int: the same symmetric padding + is applied to height and width. + - If tuple of 2 ints: + interpreted as two different + symmetric padding values for height and width: + `(symmetric_height_pad, symmetric_width_pad)`. + - If tuple of 2 tuples of 2 ints: + interpreted as + `((top_pad, bottom_pad), (left_pad, right_pad))` + data_format: A string, + one of `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. + `"channels_last"` corresponds to inputs with shape + `(batch, height, width, channels)` while `"channels_first"` + corresponds to inputs with shape + `(batch, channels, height, width)`. + It defaults to the `image_data_format` value found in your + Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be "channels_last". + + # Input shape + 4D tensor with shape: + - If `data_format` is `"channels_last"`: + `(batch, rows, cols, channels)` + - If `data_format` is `"channels_first"`: + `(batch, channels, rows, cols)` + + # Output shape + 4D tensor with shape: + - If `data_format` is `"channels_last"`: + `(batch, padded_rows, padded_cols, channels)` + - If `data_format` is `"channels_first"`: + `(batch, channels, padded_rows, padded_cols)` + ' + """ + + @interfaces.legacy_zeropadding2d_support + def __init__(self, + padding=(1, 1), + data_format=None, + **kwargs): + normalized_padding = PadUtils.check_padding_format(padding=padding) + super(ReflectionPadding2D, self).__init__(normalized_padding, + data_format, + **kwargs) + + def call(self, inputs, mask=None): + pattern = PadUtils.spatial_2d_padding(padding=self.padding, data_format=self.data_format) + return tf.pad(inputs, pattern, 'REFLECT') + + +class SymmetricPadding2D(_ZeroPadding): + """ + Symmetric padding layer for 2D input. This custom padding layer is built on keras' zero padding layers. Doc is copy + pasted from the original functions/methods: + + + This layer can add rows and columns of symmetric values + at the top, bottom, left and right side of an image like tensor. + + Example: + 2, 1, 1, 2, 3, 3, 2 + _________ + 1, 2, 3 SymPad(padding=[[1, 1,], [2, 2]]) 2, 1,| 1, 2, 3,| 3, 2 + 4, 5, 6 =============================>>>> 5, 4,| 4, 5, 6,| 6, 5 + _________ + 5, 4, 4, 5, 6, 6, 5 + + + '# Arguments + padding: int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints. + - If int: the same symmetric padding + is applied to height and width. + - If tuple of 2 ints: + interpreted as two different + symmetric padding values for height and width: + `(symmetric_height_pad, symmetric_width_pad)`. + - If tuple of 2 tuples of 2 ints: + interpreted as + `((top_pad, bottom_pad), (left_pad, right_pad))` + data_format: A string, + one of `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. + `"channels_last"` corresponds to inputs with shape + `(batch, height, width, channels)` while `"channels_first"` + corresponds to inputs with shape + `(batch, channels, height, width)`. + It defaults to the `image_data_format` value found in your + Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be "channels_last". + + # Input shape + 4D tensor with shape: + - If `data_format` is `"channels_last"`: + `(batch, rows, cols, channels)` + - If `data_format` is `"channels_first"`: + `(batch, channels, rows, cols)` + + # Output shape + 4D tensor with shape: + - If `data_format` is `"channels_last"`: + `(batch, padded_rows, padded_cols, channels)` + - If `data_format` is `"channels_first"`: + `(batch, channels, padded_rows, padded_cols)` + ' + """ + + @interfaces.legacy_zeropadding2d_support + def __init__(self, + padding=(1, 1), + data_format=None, + **kwargs): + normalized_padding = PadUtils.check_padding_format(padding=padding) + super(SymmetricPadding2D, self).__init__(normalized_padding, + data_format, + **kwargs) + + def call(self, inputs, mask=None): + pattern = PadUtils.spatial_2d_padding(padding=self.padding, data_format=self.data_format) + return tf.pad(inputs, pattern, 'SYMMETRIC') + + +if __name__ == '__main__': + from keras.models import Model + from keras.layers import Conv2D, Flatten, Dense, Input + + kernel_1 = (3, 3) + kernel_2 = (5, 5) + x = np.array(range(2000)).reshape(-1, 10, 10, 1) + y = x.mean(axis=(1, 2)) + + x_input = Input(shape=x.shape[1:]) + pad1 = PadUtils.get_padding_for_same(kernel_size=kernel_1) + x_out = ReflectionPadding2D(padding=pad1, name="RefPAD")(x_input) + x_out = Conv2D(5, kernel_size=kernel_1, activation='relu')(x_out) + + pad2 = PadUtils.get_padding_for_same(kernel_size=kernel_2) + x_out = SymmetricPadding2D(padding=pad2, name="SymPAD")(x_out) + x_out = Conv2D(2, kernel_size=kernel_2, activation='relu')(x_out) + x_out = Flatten()(x_out) + x_out = Dense(1, activation='linear')(x_out) + + model = Model(inputs=x_input, outputs=x_out) + model.compile('adam', loss='mse') + model.summary() + model.fit(x, y, epochs=10) + + diff --git a/test/test_model_modules/test_advanced_paddings.py b/test/test_model_modules/test_advanced_paddings.py new file mode 100644 index 0000000000000000000000000000000000000000..5282eb6df34d4d395dbbdd1fd76fd71a95e9c8df --- /dev/null +++ b/test/test_model_modules/test_advanced_paddings.py @@ -0,0 +1,419 @@ +import keras +import pytest + +from src.model_modules.advanced_paddings import * + + +class TestPadUtils: + + def test_get_padding_for_same_negative_kernel_size(self): + print('In test_get_padding_for_same_negative_kernel_size') + with pytest.raises(ValueError) as einfo: + PadUtils.get_padding_for_same((-1, 2)) + assert 'All values of kernel_size must be > 0. Got: (-1, 2) ' in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + PadUtils.get_padding_for_same((1, -2)) + assert 'All values of kernel_size must be > 0. Got: (1, -2) ' in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + PadUtils.get_padding_for_same((-1, -2)) + assert 'All values of kernel_size must be > 0. Got: (-1, -2) ' in str(einfo.value) + + def test_get_padding_for_same_strides_greater_one(self): + with pytest.raises(NotImplementedError) as einfo: + PadUtils.get_padding_for_same((1, 1), strides=2) + assert 'Strides other than 1 not implemented!' in str(einfo.value) + + with pytest.raises(NotImplementedError) as einfo: + PadUtils.get_padding_for_same((1, 1), strides=-1) + assert 'Strides other than 1 not implemented!' in str(einfo.value) + + def test_get_padding_for_same_non_int_kernel(self): + with pytest.raises(ValueError) as einfo: + PadUtils.get_padding_for_same((1., 1)) + assert "The `kernel_size` argument must have a tuple of integers. Got: (1.0, 1) " \ + "of type [<class 'float'>, <class 'int'>]" in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + PadUtils.get_padding_for_same((1, 1.)) + assert "The `kernel_size` argument must have a tuple of integers. Got: (1, 1.0) " \ + "of type [<class 'int'>, <class 'float'>]" in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + PadUtils.get_padding_for_same((1, '1.')) + assert "The `kernel_size` argument must have a tuple of integers. Got: (1, '1.') " \ + "of type [<class 'int'>, <class 'str'>]" in str(einfo.value) + + def test_get_padding_for_same_stride_3d(self): + kernel = (3, 3, 3) + pad = PadUtils.get_padding_for_same(kernel) + assert pad == (1, 1, 1) + assert isinstance(pad, tuple) + assert isinstance(pad[0], int) and isinstance(pad[1], int) + assert not (isinstance(pad[0], np.int64) and isinstance(pad[1], np.int64) and isinstance(pad[2], np.int64)) + + def test_get_padding_for_same_even_pad(self): + with pytest.raises(NotImplementedError) as einfo: + PadUtils.get_padding_for_same((2, 1)) + assert 'even kernel size not implemented. Got (2, 1)' in str(einfo.value) + + with pytest.raises(NotImplementedError) as einfo: + PadUtils.get_padding_for_same((1, 4)) + assert 'even kernel size not implemented. Got (1, 4)' in str(einfo.value) + + with pytest.raises(NotImplementedError) as einfo: + PadUtils.get_padding_for_same((2, 4)) + assert 'even kernel size not implemented. Got (2, 4)' in str(einfo.value) + + ################################################################################## + + def test_check_padding_format_negative_pads(self): + + with pytest.raises(ValueError) as einfo: + PadUtils.check_padding_format((-2, 1)) + assert "The `1st entry of padding` argument must be >= 0. Received: -2 of type <class 'int'>" in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + PadUtils.check_padding_format((1, -1)) + assert "The `2nd entry of padding` argument must be >= 0. Received: -1 of type <class 'int'>" in str( + einfo.value) + + with pytest.raises(ValueError) as einfo: + PadUtils.check_padding_format((-2, -1)) + assert "The `1st entry of padding` argument must be >= 0. Received: -2 of type <class 'int'>" in str( + einfo.value) + + def test_check_padding_format_len_of_pad_tuple(self): + with pytest.raises(ValueError) as einfo: + PadUtils.check_padding_format((1, 1, 2)) + assert "`padding` should have two elements. Found: (1, 1, 2)" in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + PadUtils.check_padding_format((1, 1, 2, 2)) + assert "`padding` should have two elements. Found: (1, 1, 2, 2)" in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + PadUtils.check_padding_format(((1, 1, 3), (2, 2, 4))) + assert "`padding[0]` should have one or two elements. Found: (1, 1, 3)" in str(einfo.value) + + assert PadUtils.check_padding_format(((1, 1), (2, 2))) == ((1, 1), (2, 2)) + assert PadUtils.check_padding_format((1, 2)) == ((1, 1), (2, 2)) + assert PadUtils.check_padding_format(1) == ((1, 1), (1, 1)) + + def test_check_padding_format_tuple_of_none_integer(self): + with pytest.raises(ValueError) as einfo: + PadUtils.check_padding_format((1.2, 1)) + assert "The `1st entry of padding` argument must be a tuple of 2 integers. Received: 1.2" in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + PadUtils.check_padding_format((1, 1.)) + assert "The `2nd entry of padding` argument must be a tuple of 2 integers. Received: 1.0" in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + PadUtils.check_padding_format(1.2) + assert "`padding` should be either an int, a tuple of 2 ints (symmetric_height_pad, symmetric_width_pad), " \ + "or a tuple of 2 tuples of 2 ints ((top_pad, bottom_pad), (left_pad, right_pad)). Found: 1.2 of type " \ + "<class 'float'>" in str(einfo.value) + + def test_check_padding_format_tuple_of_tuple_none_integer_first(self): + with pytest.raises(ValueError) as einfo: + PadUtils.check_padding_format(((1., 2), (3, 4))) + assert "`padding[0]` should have one or two elements of type int. Found:(1.0, 2) " \ + "of type [<class 'float'>, <class 'int'>]" in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + PadUtils.check_padding_format(((1, 2.), (3, 4))) + assert "`padding[0]` should have one or two elements of type int. Found:(1, 2.0) " \ + "of type [<class 'int'>, <class 'float'>]" in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + PadUtils.check_padding_format(((1, '2'), (3, 4))) + assert "`padding[0]` should have one or two elements of type int. Found:(1, '2') " \ + "of type [<class 'int'>, <class 'str'>]" in str(einfo.value) + + def test_check_padding_format_tuple_of_tuple_none_integer_second(self): + with pytest.raises(ValueError) as einfo: + PadUtils.check_padding_format(((1, 2), (3., 4))) + assert "`padding[1]` should have one or two elements of type int. Found:(3.0, 4) " \ + "of type [<class 'float'>, <class 'int'>]" in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + PadUtils.check_padding_format(((1, 2), (3, 4.))) + assert "`padding[1]` should have one or two elements of type int. Found:(3, 4.0) " \ + "of type [<class 'int'>, <class 'float'>]" in str(einfo.value) + + def test_check_padding_format_valid_mix_of_int_and_tuple(self): + assert PadUtils.check_padding_format(((1, 2), 3)) == ((1, 2), (3, 3)) + assert PadUtils.check_padding_format((1, (2, 3))) == ((1, 1), (2, 3)) + + def test_check_padding_format_invalid_mixed_tuple_and_int(self): + with pytest.raises(ValueError) as einfo: + PadUtils.check_padding_format(((1., 2), 3)) + assert "`padding[0]` should have one or two elements of type int. Found:(1.0, 2) " \ + "of type [<class 'float'>, <class 'int'>]" in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + PadUtils.check_padding_format(((1, 2), 3.)) + assert "The `2nd entry of padding` argument must be a tuple of 2 integers. Received: 3.0" in str(einfo.value) + + def test_check_padding_format_invalid_mixed_int_and_tuple(self): + with pytest.raises(ValueError) as einfo: + PadUtils.check_padding_format((1., (2, 3))) + assert "The `1st entry of padding` argument must be a tuple of 2 integers. Received: 1.0" in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + PadUtils.check_padding_format((1, (2., 3))) + assert "`padding[1]` should have one or two elements of type int. Found:(2.0, 3) " \ + "of type [<class 'float'>, <class 'int'>]" in str(einfo.value) + + +class TestReflectionPadding2D: + + @pytest.fixture + def input_x(self): + return keras.Input(shape=(10, 10, 3)) + + def test_init_tuple_of_valid_int(self): + pad = (1, 3) + layer_name = "RefPAD" + ref_pad = ReflectionPadding2D(padding=pad, name=layer_name) + assert ref_pad.padding == ((1, 1), (3, 3)) + assert ref_pad.name == 'RefPAD' + assert ref_pad.data_format == 'channels_last' + assert ref_pad.rank == 2 + + pad = (0, 1) + ref_pad = ReflectionPadding2D(padding=pad, name=layer_name) + assert ref_pad.padding == ((0, 0), (1, 1)) + assert ref_pad.name == 'RefPAD' + assert ref_pad.data_format == 'channels_last' + assert ref_pad.rank == 2 + + pad = (5, 3) + layer_name = "RefPAD_5x3" + ref_pad = ReflectionPadding2D(padding=pad, name=layer_name) + assert ref_pad.padding == ((5, 5), (3, 3)) + + def test_init_tuple_of_negative_int(self): + with pytest.raises(ValueError) as einfo: + ReflectionPadding2D(padding=(-1, 1)) + assert "The `1st entry of padding` argument must be >= 0. Received: -1 of type <class 'int'>" in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + ReflectionPadding2D(padding=(1, -2)) + assert "The `2nd entry of padding` argument must be >= 0. Received: -2 of type <class 'int'>" in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + ReflectionPadding2D(padding=(-1, -2)) + assert "The `1st entry of padding` argument must be >= 0. Received: -1 of type <class 'int'>" in str(einfo.value) + + def test_init_tuple_of_invalid_format_float(self): + with pytest.raises(ValueError) as einfo: + ReflectionPadding2D(padding=(1., 1)) + assert 'The `1st entry of padding` argument must be a tuple of 2 integers. Received: 1.0' in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + ReflectionPadding2D(padding=(1, 1.2)) + assert 'The `2nd entry of padding` argument must be a tuple of 2 integers. Received: 1.2' in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + ReflectionPadding2D(padding=(1., 1.2)) + assert 'The `1st entry of padding` argument must be a tuple of 2 integers. Received: 1.0' in str(einfo.value) + + def test_init_tuple_of_invalid_format_string(self): + with pytest.raises(ValueError) as einfo: + ReflectionPadding2D(padding=('1', 2)) + # This error message is not the best as it is missing the type information. + # But it is raised by keras.utils.conv_utils which I will not touch. + assert "`padding[0]` is str but must be int" in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + ReflectionPadding2D(padding=(1, '2')) + assert '`padding[1]` is str but must be int' in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + ReflectionPadding2D(padding=('1', '2')) + assert '`padding[0]` is str but must be int' in str(einfo.value) + + def test_init_int(self): + layer_name = "RefPAD" + ref_pad = ReflectionPadding2D(padding=1, name=layer_name) + assert ref_pad.padding == ((1, 1), (1, 1)) + assert ref_pad.name == "RefPAD" + + def test_init_tuple_of_tuple_of_valid_int(self): + ref_pad = ReflectionPadding2D(padding=((0, 1), (2, 3)), name="RefPAD") + assert ref_pad.padding == ((0, 1), (2, 3)) + assert ref_pad.name == "RefPAD" + + def test_init_tuple_of_tuple_of_invalid_int(self): + with pytest.raises(ValueError) as einfo: + ReflectionPadding2D(padding=((-4, 1), (2, 3)), name="RefPAD") + assert "The `1st entry of padding` argument must be >= 0. Received: (-4, 1) of type <class 'tuple'>" in str( + einfo.value) + + with pytest.raises(ValueError) as einfo: + ReflectionPadding2D(padding=((4, -1), (2, 3)), name="RefPAD") + assert "The `1st entry of padding` argument must be >= 0. Received: (4, -1) of type <class 'tuple'>" in str( + einfo.value) + + with pytest.raises(ValueError) as einfo: + ReflectionPadding2D(padding=((4, 1), (-2, 3)), name="RefPAD") + assert "The `2nd entry of padding` argument must be >= 0. Received: (-2, 3) of type <class 'tuple'>" in str( + einfo.value) + + with pytest.raises(ValueError) as einfo: + ReflectionPadding2D(padding=((4, 1), (2, -3)), name="RefPAD") + assert "The `2nd entry of padding` argument must be >= 0. Received: (2, -3) of type <class 'tuple'>" in str( + einfo.value) + + def test_init_tuple_of_tuple_of_invalid_format(self): + with pytest.raises(ValueError) as einfo: + ReflectionPadding2D(padding=((0.1, 1), (2, 3)), name="RefPAD") + assert "`padding[0]` should have one or two elements of type int. Found:(0.1, 1) " \ + "of type [<class 'float'>, <class 'int'>]" in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + ReflectionPadding2D(padding=(1, 2.2)) + assert "The `2nd entry of padding` argument must be a tuple of 2 integers. Received: 2.2" in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + ReflectionPadding2D(padding=((0, 1), ('2', 3)), name="RefPAD") + assert "`padding[1]` should have one or two elements of type int. Found:('2', 3) " \ + "of type [<class 'str'>, <class 'int'>]" in str(einfo.value) + + def test_call(self, input_x): + # here it behaves like a "normal" keras layer, I don't know how to test those + pad = (1, 0) + layer_name = "RefPad_3x1" + ref_pad = ReflectionPadding2D(padding=pad, name=layer_name)(input_x) + assert ref_pad.get_shape().as_list() == [None, 12, 10, 3] + assert ref_pad.name == 'RefPad_3x1/MirrorPad:0' + + +class TestSymmerticPadding2D: + + @pytest.fixture + def input_x(self): + return keras.Input(shape=(10, 10, 3)) + + def test_init_tuple_of_valid_int(self): + pad = (1, 3) + layer_name = "SymPad" + sym_pad = SymmetricPadding2D(padding=pad, name=layer_name) + assert sym_pad.padding == ((1, 1), (3, 3)) + assert sym_pad.name == 'SymPad' + assert sym_pad.data_format == 'channels_last' + assert sym_pad.rank == 2 + + pad = (0, 1) + sym_pad = SymmetricPadding2D(padding=pad, name=layer_name) + assert sym_pad.padding == ((0, 0), (1, 1)) + assert sym_pad.name == 'SymPad' + assert sym_pad.data_format == 'channels_last' + assert sym_pad.rank == 2 + + pad = (5, 3) + layer_name = "SymPad_5x3" + sym_pad = SymmetricPadding2D(padding=pad, name=layer_name) + assert sym_pad.padding == ((5, 5), (3, 3)) + + def test_init_tuple_of_negative_int(self): + with pytest.raises(ValueError) as einfo: + SymmetricPadding2D(padding=(-1, 1)) + assert "The `1st entry of padding` argument must be >= 0. Received: -1 of type <class 'int'>" in str( + einfo.value) + + with pytest.raises(ValueError) as einfo: + SymmetricPadding2D(padding=(1, -2)) + assert "The `2nd entry of padding` argument must be >= 0. Received: -2 of type <class 'int'>" in str( + einfo.value) + + with pytest.raises(ValueError) as einfo: + SymmetricPadding2D(padding=(-1, -2)) + assert "The `1st entry of padding` argument must be >= 0. Received: -1 of type <class 'int'>" in str( + einfo.value) + + def test_init_tuple_of_invalid_format_float(self): + with pytest.raises(ValueError) as einfo: + SymmetricPadding2D(padding=(1., 1)) + assert 'The `1st entry of padding` argument must be a tuple of 2 integers. Received: 1.0' in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + SymmetricPadding2D(padding=(1, 1.2)) + assert 'The `2nd entry of padding` argument must be a tuple of 2 integers. Received: 1.2' in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + SymmetricPadding2D(padding=(1., 1.2)) + assert 'The `1st entry of padding` argument must be a tuple of 2 integers. Received: 1.0' in str(einfo.value) + + def test_init_tuple_of_invalid_format_string(self): + with pytest.raises(ValueError) as einfo: + SymmetricPadding2D(padding=('1', 2)) + # This error message is not the best as it is missing the type information. + # But it is raised by keras.utils.conv_utils which I will not touch. + assert "`padding[0]` is str but must be int" in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + SymmetricPadding2D(padding=(1, '2')) + assert '`padding[1]` is str but must be int' in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + SymmetricPadding2D(padding=('1', '2')) + assert '`padding[0]` is str but must be int' in str(einfo.value) + + def test_init_int(self): + layer_name = "SymPad" + sym_pad = SymmetricPadding2D(padding=1, name=layer_name) + assert sym_pad.padding == ((1, 1), (1, 1)) + assert sym_pad.name == "SymPad" + + def test_init_tuple_of_tuple_of_valid_int(self): + sym_pad = SymmetricPadding2D(padding=((0, 1), (2, 3)), name="SymPad") + assert sym_pad.padding == ((0, 1), (2, 3)) + assert sym_pad.name == "SymPad" + + def test_init_tuple_of_tuple_of_invalid_int(self): + with pytest.raises(ValueError) as einfo: + SymmetricPadding2D(padding=((-4, 1), (2, 3)), name="SymPad") + assert "The `1st entry of padding` argument must be >= 0. Received: (-4, 1) of type <class 'tuple'>" in str( + einfo.value) + + with pytest.raises(ValueError) as einfo: + SymmetricPadding2D(padding=((4, -1), (2, 3)), name="SymPad") + assert "The `1st entry of padding` argument must be >= 0. Received: (4, -1) of type <class 'tuple'>" in str( + einfo.value) + + with pytest.raises(ValueError) as einfo: + SymmetricPadding2D(padding=((4, 1), (-2, 3)), name="SymPad") + assert "The `2nd entry of padding` argument must be >= 0. Received: (-2, 3) of type <class 'tuple'>" in str( + einfo.value) + + with pytest.raises(ValueError) as einfo: + SymmetricPadding2D(padding=((4, 1), (2, -3)), name="SymPad") + assert "The `2nd entry of padding` argument must be >= 0. Received: (2, -3) of type <class 'tuple'>" in str( + einfo.value) + + def test_init_tuple_of_tuple_of_invalid_format(self): + with pytest.raises(ValueError) as einfo: + SymmetricPadding2D(padding=((0.1, 1), (2, 3)), name="SymPad") + assert "`padding[0]` should have one or two elements of type int. Found:(0.1, 1) " \ + "of type [<class 'float'>, <class 'int'>]" in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + SymmetricPadding2D(padding=(1, 2.2)) + assert "The `2nd entry of padding` argument must be a tuple of 2 integers. Received: 2.2" in str(einfo.value) + + with pytest.raises(ValueError) as einfo: + SymmetricPadding2D(padding=((0, 1), ('2', 3)), name="SymPad") + assert "`padding[1]` should have one or two elements of type int. Found:('2', 3) " \ + "of type [<class 'str'>, <class 'int'>]" in str(einfo.value) + + def test_call(self, input_x): + # here it behaves like a "normal" keras layer, I don't know how to test those + pad = (1, 0) + layer_name = "SymPad_3x1" + sym_pad = SymmetricPadding2D(padding=pad, name=layer_name)(input_x) + assert sym_pad.get_shape().as_list() == [None, 12, 10, 3] + assert sym_pad.name == 'SymPad_3x1/MirrorPad:0'