diff --git a/src/model_modules/advanced_paddings.py b/src/model_modules/advanced_paddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d48dfc0a87c05183fc5b8b7755f48efaf7b5428
--- /dev/null
+++ b/src/model_modules/advanced_paddings.py
@@ -0,0 +1,280 @@
+__author__ = 'Felix Kleinert'
+__date__ = '2020-03-02'
+
+import tensorflow as tf
+import numpy as np
+import keras.backend as K
+
+from keras.layers.convolutional import _ZeroPadding
+from keras.legacy import interfaces
+from keras.utils import conv_utils
+from keras.utils.generic_utils import transpose_shape
+from keras.backend.common import normalize_data_format
+
+
+class PadUtils:
+    """
+    Helper class for advanced paddings
+    """
+
+    @staticmethod
+    def get_padding_for_same(kernel_size, strides=1):
+        """
+        This methods calculates the padding size to keep input and output dimensions equal for a given kernel size
+        (STRIDES HAVE TO BE EQUAL TO ONE!)
+        :param kernel_size:
+        :return:
+        """
+        if strides != 1:
+            raise NotImplementedError("Strides other than 1 not implemented!")
+        if not all(isinstance(k, int) for k in kernel_size):
+            raise ValueError(f"The `kernel_size` argument must have a tuple of integers. Got: {kernel_size} "
+                             f"of type {[type(k) for k in kernel_size]}")
+
+        ks = np.array(kernel_size, dtype=np.int64)
+
+        if any(k <= 0 for k in ks):
+            raise ValueError(f"All values of kernel_size must be > 0. Got: {kernel_size} ")
+
+        if all(k % 2 == 1 for k in ks):  # (d & 0x1 for d in ks):
+            pad = ((ks - 1) / 2).astype(np.int64)
+            # convert numpy int to base int
+            pad = [np.asscalar(v) for v in pad]
+            return tuple(pad)
+            # return tuple(PadUtils.check_padding_format(pad))
+        else:
+            raise NotImplementedError(f"even kernel size not implemented. Got {kernel_size}")
+
+    @staticmethod
+    def spatial_2d_padding(padding=((1, 1), (1, 1)), data_format=None):
+        """Pads the 2nd and 3rd dimensions of a 4D tensor.
+
+        # Arguments
+            x: Tensor or variable.
+            padding: Tuple of 2 tuples, padding pattern.
+            data_format: string, `"channels_last"` or `"channels_first"`.
+
+        # Returns
+            A padded 4D tensor.
+
+        # Raises
+            ValueError: if `data_format` is neither `"channels_last"` or `"channels_first"`.
+        """
+        assert len(padding) == 2
+        assert len(padding[0]) == 2
+        assert len(padding[1]) == 2
+        data_format = normalize_data_format(data_format)
+
+        pattern = [[0, 0],
+                   list(padding[0]),
+                   list(padding[1]),
+                   [0, 0]]
+        pattern = transpose_shape(pattern, data_format, spatial_axes=(1, 2))
+        return pattern
+
+    @staticmethod
+    def check_padding_format(padding):
+        if isinstance(padding, int):
+            normalized_padding = ((padding, padding), (padding, padding))
+        elif hasattr(padding, '__len__'):
+            if len(padding) != 2:
+                raise ValueError('`padding` should have two elements. '
+                                 'Found: ' + str(padding))
+            for idx_pad, sub_pad in enumerate(padding):
+                if isinstance(sub_pad, str):
+                    raise ValueError(f'`padding[{idx_pad}]` is str but must be int')
+                if hasattr(sub_pad, '__len__'):
+                    if len(sub_pad) != 2:
+                        raise ValueError(f'`padding[{idx_pad}]` should have one or two elements. '
+                                         f'Found: {padding[idx_pad]}')
+                    if not all(isinstance(sub_k, int) for sub_k in padding[idx_pad]):
+                        raise ValueError(f'`padding[{idx_pad}]` should have one or two elements of type int. ' 
+                                         f"Found:{padding[idx_pad]} of type {[type(sub_k) for sub_k in padding[idx_pad]]}")
+            height_padding = conv_utils.normalize_tuple(padding[0], 2,
+                                                        '1st entry of padding')
+            if not all(k >= 0 for k in height_padding):
+                raise ValueError(f"The `1st entry of padding` argument must be >= 0. Received: {padding[0]} of type {type(padding[0])}")
+            width_padding = conv_utils.normalize_tuple(padding[1], 2,
+                                                       '2nd entry of padding')
+            if not all(k >= 0 for k in width_padding):
+                raise ValueError(f"The `2nd entry of padding` argument must be >= 0. Received: {padding[1]} of type {type(padding[1])}")
+            normalized_padding = (height_padding, width_padding)
+        else:
+            raise ValueError('`padding` should be either an int, '
+                             'a tuple of 2 ints '
+                             '(symmetric_height_pad, symmetric_width_pad), '
+                             'or a tuple of 2 tuples of 2 ints '
+                             '((top_pad, bottom_pad), (left_pad, right_pad)). '
+                             f'Found: {padding} of type {type(padding)}')
+        return normalized_padding
+
+
+class ReflectionPadding2D(_ZeroPadding):
+    """
+    Reflection padding layer for 2D input. This custum padding layer is built on keras' zero padding layers. Doc is copy
+    pasted from the original functions/methods:
+
+
+    This layer can add rows and columns of reflected values
+    at the top, bottom, left and right side of an image like tensor.
+
+    Example:
+                                                        6, 5,  4, 5, 6,  5, 4
+                                                              _________
+     1, 2, 3     RefPad(padding=[[1, 1,], [2, 2]])      3, 2,| 1, 2, 3,| 2, 1
+     4, 5, 6     =============================>>>>      6, 5,| 4, 5, 6,| 5, 4
+                                                              _________
+                                                        3, 2,  1, 2, 3,  2, 1
+
+
+
+    '# Arguments
+        padding: int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints.
+            - If int: the same symmetric padding
+                is applied to height and width.
+            - If tuple of 2 ints:
+                interpreted as two different
+                symmetric padding values for height and width:
+                `(symmetric_height_pad, symmetric_width_pad)`.
+            - If tuple of 2 tuples of 2 ints:
+                interpreted as
+                `((top_pad, bottom_pad), (left_pad, right_pad))`
+        data_format: A string,
+            one of `"channels_last"` or `"channels_first"`.
+            The ordering of the dimensions in the inputs.
+            `"channels_last"` corresponds to inputs with shape
+            `(batch, height, width, channels)` while `"channels_first"`
+            corresponds to inputs with shape
+            `(batch, channels, height, width)`.
+            It defaults to the `image_data_format` value found in your
+            Keras config file at `~/.keras/keras.json`.
+            If you never set it, then it will be "channels_last".
+
+    # Input shape
+        4D tensor with shape:
+        - If `data_format` is `"channels_last"`:
+            `(batch, rows, cols, channels)`
+        - If `data_format` is `"channels_first"`:
+            `(batch, channels, rows, cols)`
+
+    # Output shape
+        4D tensor with shape:
+        - If `data_format` is `"channels_last"`:
+            `(batch, padded_rows, padded_cols, channels)`
+        - If `data_format` is `"channels_first"`:
+            `(batch, channels, padded_rows, padded_cols)`
+    '
+    """
+
+    @interfaces.legacy_zeropadding2d_support
+    def __init__(self,
+                 padding=(1, 1),
+                 data_format=None,
+                 **kwargs):
+        normalized_padding = PadUtils.check_padding_format(padding=padding)
+        super(ReflectionPadding2D, self).__init__(normalized_padding,
+                                                  data_format,
+                                                  **kwargs)
+
+    def call(self, inputs, mask=None):
+        pattern = PadUtils.spatial_2d_padding(padding=self.padding, data_format=self.data_format)
+        return tf.pad(inputs, pattern, 'REFLECT')
+
+
+class SymmetricPadding2D(_ZeroPadding):
+    """
+    Symmetric padding layer for 2D input. This custom padding layer is built on keras' zero padding layers. Doc is copy
+    pasted from the original functions/methods:
+
+
+    This layer can add rows and columns of symmetric values
+    at the top, bottom, left and right side of an image like tensor.
+
+        Example:
+                                                        2, 1,  1, 2, 3,  3, 2
+                                                              _________
+     1, 2, 3     SymPad(padding=[[1, 1,], [2, 2]])      2, 1,| 1, 2, 3,| 3, 2
+     4, 5, 6     =============================>>>>      5, 4,| 4, 5, 6,| 6, 5
+                                                              _________
+                                                        5, 4,  4, 5, 6,  6, 5
+
+
+    '# Arguments
+        padding: int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints.
+            - If int: the same symmetric padding
+                is applied to height and width.
+            - If tuple of 2 ints:
+                interpreted as two different
+                symmetric padding values for height and width:
+                `(symmetric_height_pad, symmetric_width_pad)`.
+            - If tuple of 2 tuples of 2 ints:
+                interpreted as
+                `((top_pad, bottom_pad), (left_pad, right_pad))`
+        data_format: A string,
+            one of `"channels_last"` or `"channels_first"`.
+            The ordering of the dimensions in the inputs.
+            `"channels_last"` corresponds to inputs with shape
+            `(batch, height, width, channels)` while `"channels_first"`
+            corresponds to inputs with shape
+            `(batch, channels, height, width)`.
+            It defaults to the `image_data_format` value found in your
+            Keras config file at `~/.keras/keras.json`.
+            If you never set it, then it will be "channels_last".
+
+    # Input shape
+        4D tensor with shape:
+        - If `data_format` is `"channels_last"`:
+            `(batch, rows, cols, channels)`
+        - If `data_format` is `"channels_first"`:
+            `(batch, channels, rows, cols)`
+
+    # Output shape
+        4D tensor with shape:
+        - If `data_format` is `"channels_last"`:
+            `(batch, padded_rows, padded_cols, channels)`
+        - If `data_format` is `"channels_first"`:
+            `(batch, channels, padded_rows, padded_cols)`
+    '
+    """
+
+    @interfaces.legacy_zeropadding2d_support
+    def __init__(self,
+                 padding=(1, 1),
+                 data_format=None,
+                 **kwargs):
+        normalized_padding = PadUtils.check_padding_format(padding=padding)
+        super(SymmetricPadding2D, self).__init__(normalized_padding,
+                                                 data_format,
+                                                 **kwargs)
+
+    def call(self, inputs, mask=None):
+        pattern = PadUtils.spatial_2d_padding(padding=self.padding, data_format=self.data_format)
+        return tf.pad(inputs, pattern, 'SYMMETRIC')
+
+
+if __name__ == '__main__':
+    from keras.models import Model
+    from keras.layers import Conv2D, Flatten, Dense, Input
+
+    kernel_1 = (3, 3)
+    kernel_2 = (5, 5)
+    x = np.array(range(2000)).reshape(-1, 10, 10, 1)
+    y = x.mean(axis=(1, 2))
+
+    x_input = Input(shape=x.shape[1:])
+    pad1 = PadUtils.get_padding_for_same(kernel_size=kernel_1)
+    x_out = ReflectionPadding2D(padding=pad1, name="RefPAD")(x_input)
+    x_out = Conv2D(5, kernel_size=kernel_1, activation='relu')(x_out)
+
+    pad2 = PadUtils.get_padding_for_same(kernel_size=kernel_2)
+    x_out = SymmetricPadding2D(padding=pad2, name="SymPAD")(x_out)
+    x_out = Conv2D(2, kernel_size=kernel_2, activation='relu')(x_out)
+    x_out = Flatten()(x_out)
+    x_out = Dense(1, activation='linear')(x_out)
+
+    model = Model(inputs=x_input, outputs=x_out)
+    model.compile('adam', loss='mse')
+    model.summary()
+    model.fit(x, y, epochs=10)
+
+
diff --git a/test/test_model_modules/test_advanced_paddings.py b/test/test_model_modules/test_advanced_paddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..5282eb6df34d4d395dbbdd1fd76fd71a95e9c8df
--- /dev/null
+++ b/test/test_model_modules/test_advanced_paddings.py
@@ -0,0 +1,419 @@
+import keras
+import pytest
+
+from src.model_modules.advanced_paddings import *
+
+
+class TestPadUtils:
+
+    def test_get_padding_for_same_negative_kernel_size(self):
+        print('In test_get_padding_for_same_negative_kernel_size')
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.get_padding_for_same((-1, 2))
+        assert 'All values of kernel_size must be > 0. Got: (-1, 2) ' in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.get_padding_for_same((1, -2))
+        assert 'All values of kernel_size must be > 0. Got: (1, -2) ' in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.get_padding_for_same((-1, -2))
+        assert 'All values of kernel_size must be > 0. Got: (-1, -2) ' in str(einfo.value)
+
+    def test_get_padding_for_same_strides_greater_one(self):
+        with pytest.raises(NotImplementedError) as einfo:
+            PadUtils.get_padding_for_same((1, 1), strides=2)
+        assert 'Strides other than 1 not implemented!' in str(einfo.value)
+
+        with pytest.raises(NotImplementedError) as einfo:
+            PadUtils.get_padding_for_same((1, 1), strides=-1)
+        assert 'Strides other than 1 not implemented!' in str(einfo.value)
+
+    def test_get_padding_for_same_non_int_kernel(self):
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.get_padding_for_same((1., 1))
+        assert "The `kernel_size` argument must have a tuple of integers. Got: (1.0, 1) " \
+               "of type [<class 'float'>, <class 'int'>]" in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.get_padding_for_same((1, 1.))
+        assert "The `kernel_size` argument must have a tuple of integers. Got: (1, 1.0) " \
+               "of type [<class 'int'>, <class 'float'>]" in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.get_padding_for_same((1, '1.'))
+        assert "The `kernel_size` argument must have a tuple of integers. Got: (1, '1.') " \
+               "of type [<class 'int'>, <class 'str'>]" in str(einfo.value)
+
+    def test_get_padding_for_same_stride_3d(self):
+        kernel = (3, 3, 3)
+        pad = PadUtils.get_padding_for_same(kernel)
+        assert pad == (1, 1, 1)
+        assert isinstance(pad, tuple)
+        assert isinstance(pad[0], int) and isinstance(pad[1], int)
+        assert not (isinstance(pad[0], np.int64) and isinstance(pad[1], np.int64) and isinstance(pad[2], np.int64))
+
+    def test_get_padding_for_same_even_pad(self):
+        with pytest.raises(NotImplementedError) as einfo:
+            PadUtils.get_padding_for_same((2, 1))
+        assert 'even kernel size not implemented. Got (2, 1)' in str(einfo.value)
+
+        with pytest.raises(NotImplementedError) as einfo:
+            PadUtils.get_padding_for_same((1, 4))
+        assert 'even kernel size not implemented. Got (1, 4)' in str(einfo.value)
+
+        with pytest.raises(NotImplementedError) as einfo:
+            PadUtils.get_padding_for_same((2, 4))
+        assert 'even kernel size not implemented. Got (2, 4)' in str(einfo.value)
+
+    ##################################################################################
+
+    def test_check_padding_format_negative_pads(self):
+
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.check_padding_format((-2, 1))
+        assert "The `1st entry of padding` argument must be >= 0. Received: -2 of type <class 'int'>" in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.check_padding_format((1, -1))
+        assert "The `2nd entry of padding` argument must be >= 0. Received: -1 of type <class 'int'>" in str(
+            einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.check_padding_format((-2, -1))
+        assert "The `1st entry of padding` argument must be >= 0. Received: -2 of type <class 'int'>" in str(
+            einfo.value)
+
+    def test_check_padding_format_len_of_pad_tuple(self):
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.check_padding_format((1, 1, 2))
+        assert "`padding` should have two elements. Found: (1, 1, 2)" in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.check_padding_format((1, 1, 2, 2))
+        assert "`padding` should have two elements. Found: (1, 1, 2, 2)" in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.check_padding_format(((1, 1, 3), (2, 2, 4)))
+        assert "`padding[0]` should have one or two elements. Found: (1, 1, 3)" in str(einfo.value)
+
+        assert PadUtils.check_padding_format(((1, 1), (2, 2))) == ((1, 1), (2, 2))
+        assert PadUtils.check_padding_format((1, 2)) == ((1, 1), (2, 2))
+        assert PadUtils.check_padding_format(1) == ((1, 1), (1, 1))
+
+    def test_check_padding_format_tuple_of_none_integer(self):
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.check_padding_format((1.2, 1))
+        assert "The `1st entry of padding` argument must be a tuple of 2 integers. Received: 1.2" in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.check_padding_format((1, 1.))
+        assert "The `2nd entry of padding` argument must be a tuple of 2 integers. Received: 1.0" in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.check_padding_format(1.2)
+        assert "`padding` should be either an int, a tuple of 2 ints (symmetric_height_pad, symmetric_width_pad), " \
+               "or a tuple of 2 tuples of 2 ints ((top_pad, bottom_pad), (left_pad, right_pad)). Found: 1.2 of type " \
+               "<class 'float'>" in str(einfo.value)
+
+    def test_check_padding_format_tuple_of_tuple_none_integer_first(self):
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.check_padding_format(((1., 2), (3, 4)))
+        assert "`padding[0]` should have one or two elements of type int. Found:(1.0, 2) " \
+               "of type [<class 'float'>, <class 'int'>]" in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.check_padding_format(((1, 2.), (3, 4)))
+        assert "`padding[0]` should have one or two elements of type int. Found:(1, 2.0) " \
+               "of type [<class 'int'>, <class 'float'>]" in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.check_padding_format(((1, '2'), (3, 4)))
+        assert "`padding[0]` should have one or two elements of type int. Found:(1, '2') " \
+               "of type [<class 'int'>, <class 'str'>]" in str(einfo.value)
+
+    def test_check_padding_format_tuple_of_tuple_none_integer_second(self):
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.check_padding_format(((1, 2), (3., 4)))
+        assert "`padding[1]` should have one or two elements of type int. Found:(3.0, 4) " \
+               "of type [<class 'float'>, <class 'int'>]" in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.check_padding_format(((1, 2), (3, 4.)))
+        assert "`padding[1]` should have one or two elements of type int. Found:(3, 4.0) " \
+               "of type [<class 'int'>, <class 'float'>]" in str(einfo.value)
+
+    def test_check_padding_format_valid_mix_of_int_and_tuple(self):
+        assert PadUtils.check_padding_format(((1, 2), 3)) == ((1, 2), (3, 3))
+        assert PadUtils.check_padding_format((1, (2, 3))) == ((1, 1), (2, 3))
+
+    def test_check_padding_format_invalid_mixed_tuple_and_int(self):
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.check_padding_format(((1., 2), 3))
+        assert "`padding[0]` should have one or two elements of type int. Found:(1.0, 2) " \
+               "of type [<class 'float'>, <class 'int'>]" in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.check_padding_format(((1, 2), 3.))
+        assert "The `2nd entry of padding` argument must be a tuple of 2 integers. Received: 3.0" in str(einfo.value)
+
+    def test_check_padding_format_invalid_mixed_int_and_tuple(self):
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.check_padding_format((1., (2, 3)))
+        assert "The `1st entry of padding` argument must be a tuple of 2 integers. Received: 1.0" in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            PadUtils.check_padding_format((1, (2., 3)))
+        assert "`padding[1]` should have one or two elements of type int. Found:(2.0, 3) " \
+               "of type [<class 'float'>, <class 'int'>]" in str(einfo.value)
+
+
+class TestReflectionPadding2D:
+
+    @pytest.fixture
+    def input_x(self):
+        return keras.Input(shape=(10, 10, 3))
+
+    def test_init_tuple_of_valid_int(self):
+        pad = (1, 3)
+        layer_name = "RefPAD"
+        ref_pad = ReflectionPadding2D(padding=pad, name=layer_name)
+        assert ref_pad.padding == ((1, 1), (3, 3))
+        assert ref_pad.name == 'RefPAD'
+        assert ref_pad.data_format == 'channels_last'
+        assert ref_pad.rank == 2
+
+        pad = (0, 1)
+        ref_pad = ReflectionPadding2D(padding=pad, name=layer_name)
+        assert ref_pad.padding == ((0, 0), (1, 1))
+        assert ref_pad.name == 'RefPAD'
+        assert ref_pad.data_format == 'channels_last'
+        assert ref_pad.rank == 2
+
+        pad = (5, 3)
+        layer_name = "RefPAD_5x3"
+        ref_pad = ReflectionPadding2D(padding=pad, name=layer_name)
+        assert ref_pad.padding == ((5, 5), (3, 3))
+
+    def test_init_tuple_of_negative_int(self):
+        with pytest.raises(ValueError) as einfo:
+            ReflectionPadding2D(padding=(-1, 1))
+        assert "The `1st entry of padding` argument must be >= 0. Received: -1 of type <class 'int'>" in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            ReflectionPadding2D(padding=(1, -2))
+        assert "The `2nd entry of padding` argument must be >= 0. Received: -2 of type <class 'int'>" in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            ReflectionPadding2D(padding=(-1, -2))
+        assert "The `1st entry of padding` argument must be >= 0. Received: -1 of type <class 'int'>" in str(einfo.value)
+
+    def test_init_tuple_of_invalid_format_float(self):
+        with pytest.raises(ValueError) as einfo:
+            ReflectionPadding2D(padding=(1., 1))
+        assert 'The `1st entry of padding` argument must be a tuple of 2 integers. Received: 1.0' in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            ReflectionPadding2D(padding=(1, 1.2))
+        assert 'The `2nd entry of padding` argument must be a tuple of 2 integers. Received: 1.2' in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            ReflectionPadding2D(padding=(1., 1.2))
+        assert 'The `1st entry of padding` argument must be a tuple of 2 integers. Received: 1.0' in str(einfo.value)
+
+    def test_init_tuple_of_invalid_format_string(self):
+        with pytest.raises(ValueError) as einfo:
+            ReflectionPadding2D(padding=('1', 2))
+        # This error message is not the best as it is missing the type information.
+        # But it is raised by keras.utils.conv_utils which I will not touch.
+        assert "`padding[0]` is str but must be int" in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            ReflectionPadding2D(padding=(1, '2'))
+        assert '`padding[1]` is str but must be int' in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            ReflectionPadding2D(padding=('1', '2'))
+        assert '`padding[0]` is str but must be int' in str(einfo.value)
+
+    def test_init_int(self):
+        layer_name = "RefPAD"
+        ref_pad = ReflectionPadding2D(padding=1, name=layer_name)
+        assert ref_pad.padding == ((1, 1), (1, 1))
+        assert ref_pad.name == "RefPAD"
+
+    def test_init_tuple_of_tuple_of_valid_int(self):
+        ref_pad = ReflectionPadding2D(padding=((0, 1), (2, 3)), name="RefPAD")
+        assert ref_pad.padding == ((0, 1), (2, 3))
+        assert ref_pad.name == "RefPAD"
+
+    def test_init_tuple_of_tuple_of_invalid_int(self):
+        with pytest.raises(ValueError) as einfo:
+            ReflectionPadding2D(padding=((-4, 1), (2, 3)), name="RefPAD")
+        assert "The `1st entry of padding` argument must be >= 0. Received: (-4, 1) of type <class 'tuple'>" in str(
+            einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            ReflectionPadding2D(padding=((4, -1), (2, 3)), name="RefPAD")
+        assert "The `1st entry of padding` argument must be >= 0. Received: (4, -1) of type <class 'tuple'>" in str(
+            einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            ReflectionPadding2D(padding=((4, 1), (-2, 3)), name="RefPAD")
+        assert "The `2nd entry of padding` argument must be >= 0. Received: (-2, 3) of type <class 'tuple'>" in str(
+            einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            ReflectionPadding2D(padding=((4, 1), (2, -3)), name="RefPAD")
+        assert "The `2nd entry of padding` argument must be >= 0. Received: (2, -3) of type <class 'tuple'>" in str(
+            einfo.value)
+
+    def test_init_tuple_of_tuple_of_invalid_format(self):
+        with pytest.raises(ValueError) as einfo:
+            ReflectionPadding2D(padding=((0.1, 1), (2, 3)), name="RefPAD")
+        assert "`padding[0]` should have one or two elements of type int. Found:(0.1, 1) " \
+               "of type [<class 'float'>, <class 'int'>]" in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            ReflectionPadding2D(padding=(1, 2.2))
+        assert "The `2nd entry of padding` argument must be a tuple of 2 integers. Received: 2.2" in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            ReflectionPadding2D(padding=((0, 1), ('2', 3)), name="RefPAD")
+        assert "`padding[1]` should have one or two elements of type int. Found:('2', 3) " \
+               "of type [<class 'str'>, <class 'int'>]" in str(einfo.value)
+
+    def test_call(self, input_x):
+        # here it behaves like a "normal" keras layer, I don't know how to test those
+        pad = (1, 0)
+        layer_name = "RefPad_3x1"
+        ref_pad = ReflectionPadding2D(padding=pad, name=layer_name)(input_x)
+        assert ref_pad.get_shape().as_list() == [None, 12, 10, 3]
+        assert ref_pad.name == 'RefPad_3x1/MirrorPad:0'
+
+
+class TestSymmerticPadding2D:
+
+    @pytest.fixture
+    def input_x(self):
+        return keras.Input(shape=(10, 10, 3))
+
+    def test_init_tuple_of_valid_int(self):
+        pad = (1, 3)
+        layer_name = "SymPad"
+        sym_pad = SymmetricPadding2D(padding=pad, name=layer_name)
+        assert sym_pad.padding == ((1, 1), (3, 3))
+        assert sym_pad.name == 'SymPad'
+        assert sym_pad.data_format == 'channels_last'
+        assert sym_pad.rank == 2
+
+        pad = (0, 1)
+        sym_pad = SymmetricPadding2D(padding=pad, name=layer_name)
+        assert sym_pad.padding == ((0, 0), (1, 1))
+        assert sym_pad.name == 'SymPad'
+        assert sym_pad.data_format == 'channels_last'
+        assert sym_pad.rank == 2
+
+        pad = (5, 3)
+        layer_name = "SymPad_5x3"
+        sym_pad = SymmetricPadding2D(padding=pad, name=layer_name)
+        assert sym_pad.padding == ((5, 5), (3, 3))
+
+    def test_init_tuple_of_negative_int(self):
+        with pytest.raises(ValueError) as einfo:
+            SymmetricPadding2D(padding=(-1, 1))
+        assert "The `1st entry of padding` argument must be >= 0. Received: -1 of type <class 'int'>" in str(
+            einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            SymmetricPadding2D(padding=(1, -2))
+        assert "The `2nd entry of padding` argument must be >= 0. Received: -2 of type <class 'int'>" in str(
+            einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            SymmetricPadding2D(padding=(-1, -2))
+        assert "The `1st entry of padding` argument must be >= 0. Received: -1 of type <class 'int'>" in str(
+            einfo.value)
+
+    def test_init_tuple_of_invalid_format_float(self):
+        with pytest.raises(ValueError) as einfo:
+            SymmetricPadding2D(padding=(1., 1))
+        assert 'The `1st entry of padding` argument must be a tuple of 2 integers. Received: 1.0' in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            SymmetricPadding2D(padding=(1, 1.2))
+        assert 'The `2nd entry of padding` argument must be a tuple of 2 integers. Received: 1.2' in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            SymmetricPadding2D(padding=(1., 1.2))
+        assert 'The `1st entry of padding` argument must be a tuple of 2 integers. Received: 1.0' in str(einfo.value)
+
+    def test_init_tuple_of_invalid_format_string(self):
+        with pytest.raises(ValueError) as einfo:
+            SymmetricPadding2D(padding=('1', 2))
+        # This error message is not the best as it is missing the type information.
+        # But it is raised by keras.utils.conv_utils which I will not touch.
+        assert "`padding[0]` is str but must be int" in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            SymmetricPadding2D(padding=(1, '2'))
+        assert '`padding[1]` is str but must be int' in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            SymmetricPadding2D(padding=('1', '2'))
+        assert '`padding[0]` is str but must be int' in str(einfo.value)
+
+    def test_init_int(self):
+        layer_name = "SymPad"
+        sym_pad = SymmetricPadding2D(padding=1, name=layer_name)
+        assert sym_pad.padding == ((1, 1), (1, 1))
+        assert sym_pad.name == "SymPad"
+
+    def test_init_tuple_of_tuple_of_valid_int(self):
+        sym_pad = SymmetricPadding2D(padding=((0, 1), (2, 3)), name="SymPad")
+        assert sym_pad.padding == ((0, 1), (2, 3))
+        assert sym_pad.name == "SymPad"
+
+    def test_init_tuple_of_tuple_of_invalid_int(self):
+        with pytest.raises(ValueError) as einfo:
+            SymmetricPadding2D(padding=((-4, 1), (2, 3)), name="SymPad")
+        assert "The `1st entry of padding` argument must be >= 0. Received: (-4, 1) of type <class 'tuple'>" in str(
+            einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            SymmetricPadding2D(padding=((4, -1), (2, 3)), name="SymPad")
+        assert "The `1st entry of padding` argument must be >= 0. Received: (4, -1) of type <class 'tuple'>" in str(
+            einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            SymmetricPadding2D(padding=((4, 1), (-2, 3)), name="SymPad")
+        assert "The `2nd entry of padding` argument must be >= 0. Received: (-2, 3) of type <class 'tuple'>" in str(
+            einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            SymmetricPadding2D(padding=((4, 1), (2, -3)), name="SymPad")
+        assert "The `2nd entry of padding` argument must be >= 0. Received: (2, -3) of type <class 'tuple'>" in str(
+            einfo.value)
+
+    def test_init_tuple_of_tuple_of_invalid_format(self):
+        with pytest.raises(ValueError) as einfo:
+            SymmetricPadding2D(padding=((0.1, 1), (2, 3)), name="SymPad")
+        assert "`padding[0]` should have one or two elements of type int. Found:(0.1, 1) " \
+               "of type [<class 'float'>, <class 'int'>]" in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            SymmetricPadding2D(padding=(1, 2.2))
+        assert "The `2nd entry of padding` argument must be a tuple of 2 integers. Received: 2.2" in str(einfo.value)
+
+        with pytest.raises(ValueError) as einfo:
+            SymmetricPadding2D(padding=((0, 1), ('2', 3)), name="SymPad")
+        assert "`padding[1]` should have one or two elements of type int. Found:('2', 3) " \
+               "of type [<class 'str'>, <class 'int'>]" in str(einfo.value)
+
+    def test_call(self, input_x):
+        # here it behaves like a "normal" keras layer, I don't know how to test those
+        pad = (1, 0)
+        layer_name = "SymPad_3x1"
+        sym_pad = SymmetricPadding2D(padding=pad, name=layer_name)(input_x)
+        assert sym_pad.get_shape().as_list() == [None, 12, 10, 3]
+        assert sym_pad.name == 'SymPad_3x1/MirrorPad:0'