__author__ = 'Felix Kleinert'
__date__ = '2020-03-02'

import tensorflow as tf
import numpy as np
import keras.backend as K

from keras.layers.convolutional import _ZeroPadding
from keras.legacy import interfaces
from keras.utils import conv_utils
from keras.utils.generic_utils import transpose_shape
from keras.backend.common import normalize_data_format


class PadUtils:
    """
    Helper class for advanced paddings.

    Bundles static helpers used by the custom 2D padding layers: computing the "same" padding for a kernel size,
    validating/normalising a padding argument, and building the padding pattern consumed by ``tf.pad``.
    """

    @staticmethod
    def get_padding_for_same(kernel_size, strides=1):
        """
        Calculate the padding that keeps input and output dimensions equal for a given kernel size.

        Only ``strides == 1`` and odd kernel sizes are supported. For an odd kernel size ``k`` the required
        symmetric padding is ``(k - 1) // 2`` on each side.

        :param kernel_size: sequence of integers (built-in or NumPy integers), one per spatial dimension,
            e.g. ``(3, 3)``. Every entry must be positive and odd.
        :param strides: stride of the following convolution. Only ``1`` is supported.
        :return: tuple of built-in Python ints, one padding size per entry of ``kernel_size``,
            e.g. ``(1, 2)`` for ``kernel_size=(3, 5)``.
        :raises NotImplementedError: if ``strides != 1`` or any kernel size is even.
        :raises ValueError: if an entry of ``kernel_size`` is not an integer or is not positive.
        """
        if strides != 1:
            raise NotImplementedError("Strides other than 1 not implemented!")
        if not all(isinstance(k, (int, np.integer)) for k in kernel_size):
            raise ValueError(f"The `kernel_size` argument must have a tuple of integers. Got: {kernel_size} "
                             f"of type {[type(k) for k in kernel_size]}")

        # Work with built-in ints so the returned padding never contains NumPy scalars
        # (check_padding_format only accepts built-in ints inside padding tuples).
        ks = [int(k) for k in kernel_size]

        if any(k <= 0 for k in ks):
            raise ValueError(f"All values of kernel_size must be > 0. Got: {kernel_size} ")

        if not all(k % 2 == 1 for k in ks):
            raise NotImplementedError(f"even kernel size not implemented. Got {kernel_size}")

        # k is odd, so k - 1 is even and the floor division is exact. Pure-int arithmetic avoids
        # np.asscalar, which was removed in NumPy 1.23.
        return tuple((k - 1) // 2 for k in ks)

    @staticmethod
    def spatial_2d_padding(padding=((1, 1), (1, 1)), data_format=None):
        """Build the ``tf.pad`` padding pattern for the two spatial dimensions of a 4D tensor.

        Only the pattern is returned; the caller applies it (e.g. via ``tf.pad``).

        # Arguments
            padding: Tuple of 2 tuples of 2 ints, ``((top_pad, bottom_pad), (left_pad, right_pad))``.
            data_format: string, `"channels_last"` or `"channels_first"`; ``None`` is resolved by Keras'
                ``normalize_data_format``.

        # Returns
            A list of 4 ``[before, after]`` lists, one per tensor axis. The batch and channel axes get
            ``[0, 0]``. The entries are ordered for ``"channels_last"`` and then reordered by Keras'
            ``transpose_shape`` to match ``data_format``.

        # Raises
            AssertionError: if ``padding`` is not a pair of pairs.
            ValueError: if `data_format` is neither `"channels_last"` or `"channels_first"`.
        """
        # Callers pass padding already normalised by check_padding_format, so these are sanity checks.
        assert len(padding) == 2
        assert len(padding[0]) == 2
        assert len(padding[1]) == 2
        data_format = normalize_data_format(data_format)

        # Pattern written for channels_last: (batch, rows, cols, channels).
        pattern = [[0, 0],
                   list(padding[0]),
                   list(padding[1]),
                   [0, 0]]
        pattern = transpose_shape(pattern, data_format, spatial_axes=(1, 2))
        return pattern

    @staticmethod
    def check_padding_format(padding):
        """
        Validate a 2D padding argument and normalise it to ``((top_pad, bottom_pad), (left_pad, right_pad))``.

        :param padding: one of
            - int: the same symmetric padding for height and width,
            - tuple of 2 ints: ``(symmetric_height_pad, symmetric_width_pad)``,
            - tuple of 2 tuples of 2 ints: ``((top_pad, bottom_pad), (left_pad, right_pad))``.
        :return: tuple of two 2-tuples with the (before, after) padding for height and width.
        :raises ValueError: if ``padding`` has an unsupported structure, contains strings or other
            non-int values, or contains negative values.
        """
        if isinstance(padding, int):
            # Negative values are rejected here too, consistent with the sequence branch below.
            if padding < 0:
                raise ValueError(f"The `padding` argument must be >= 0. Received: {padding} of type {type(padding)}")
            normalized_padding = ((padding, padding), (padding, padding))
        elif hasattr(padding, '__len__'):
            if len(padding) != 2:
                raise ValueError('`padding` should have two elements. '
                                 'Found: ' + str(padding))
            for idx_pad, sub_pad in enumerate(padding):
                # A str has __len__ and would otherwise be treated like a sequence of padding values.
                if isinstance(sub_pad, str):
                    raise ValueError(f'`padding[{idx_pad}]` is str but must be int')
                if hasattr(sub_pad, '__len__'):
                    if len(sub_pad) != 2:
                        raise ValueError(f'`padding[{idx_pad}]` should have two elements. '
                                         f'Found: {sub_pad}')
                    if not all(isinstance(sub_k, int) for sub_k in sub_pad):
                        raise ValueError(f'`padding[{idx_pad}]` should only contain elements of type int. '
                                         f"Found:{sub_pad} of type {[type(sub_k) for sub_k in sub_pad]}")
            # Keras' normalize_tuple expands a single int to an (int, int) pair and validates 2-sequences.
            height_padding = conv_utils.normalize_tuple(padding[0], 2,
                                                        '1st entry of padding')
            if not all(k >= 0 for k in height_padding):
                raise ValueError(f"The `1st entry of padding` argument must be >= 0. Received: {padding[0]} of type {type(padding[0])}")
            width_padding = conv_utils.normalize_tuple(padding[1], 2,
                                                       '2nd entry of padding')
            if not all(k >= 0 for k in width_padding):
                raise ValueError(f"The `2nd entry of padding` argument must be >= 0. Received: {padding[1]} of type {type(padding[1])}")
            normalized_padding = (height_padding, width_padding)
        else:
            raise ValueError('`padding` should be either an int, '
                             'a tuple of 2 ints '
                             '(symmetric_height_pad, symmetric_width_pad), '
                             'or a tuple of 2 tuples of 2 ints '
                             '((top_pad, bottom_pad), (left_pad, right_pad)). '
                             f'Found: {padding} of type {type(padding)}')
        return normalized_padding


class ReflectionPadding2D(_ZeroPadding):
    """
    Reflection padding layer for 2D input.

    This custom padding layer is built on Keras' zero padding layers, and the argument documentation below is
    adapted from them. Instead of zeros, it adds rows and columns at the top, bottom, left and right side of an
    image-like tensor. Their values mirror the input at its border, excluding the border row/column itself
    (``tf.pad`` mode ``'REFLECT'``).

    Example:
                                                        6, 5,  4, 5, 6,  5, 4
                                                              _________
     1, 2, 3     RefPad(padding=[[1, 1,], [2, 2]])      3, 2,| 1, 2, 3,| 2, 1
     4, 5, 6     =============================>>>>      6, 5,| 4, 5, 6,| 5, 4
                                                              _________
                                                        3, 2,  1, 2, 3,  2, 1

    Note: per the ``tf.pad`` documentation, in ``'REFLECT'`` mode each padding amount must be smaller than
    the size of the padded dimension.

    # Arguments
        padding: int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints.
            - If int: the same symmetric padding
                is applied to height and width.
            - If tuple of 2 ints:
                interpreted as two different
                symmetric padding values for height and width:
                `(symmetric_height_pad, symmetric_width_pad)`.
            - If tuple of 2 tuples of 2 ints:
                interpreted as
                `((top_pad, bottom_pad), (left_pad, right_pad))`
        data_format: A string,
            one of `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs.
            `"channels_last"` corresponds to inputs with shape
            `(batch, height, width, channels)` while `"channels_first"`
            corresponds to inputs with shape
            `(batch, channels, height, width)`.
            It defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "channels_last".

    # Input shape
        4D tensor with shape:
        - If `data_format` is `"channels_last"`:
            `(batch, rows, cols, channels)`
        - If `data_format` is `"channels_first"`:
            `(batch, channels, rows, cols)`

    # Output shape
        4D tensor with shape:
        - If `data_format` is `"channels_last"`:
            `(batch, padded_rows, padded_cols, channels)`
        - If `data_format` is `"channels_first"`:
            `(batch, channels, padded_rows, padded_cols)`
    """

    @interfaces.legacy_zeropadding2d_support
    def __init__(self,
                 padding=(1, 1),
                 data_format=None,
                 **kwargs):
        # Validate/normalise first so the base class always stores ((top, bottom), (left, right)).
        super().__init__(PadUtils.check_padding_format(padding=padding), data_format, **kwargs)

    def call(self, inputs, mask=None):
        # Per-axis [before, after] amounts, ordered according to self.data_format.
        paddings = PadUtils.spatial_2d_padding(padding=self.padding, data_format=self.data_format)
        return tf.pad(inputs, paddings, 'REFLECT')


class SymmetricPadding2D(_ZeroPadding):
    """
    Symmetric padding layer for 2D input.

    This custom padding layer is built on Keras' zero padding layers, and the argument documentation below is
    adapted from them. Instead of zeros, it adds rows and columns at the top, bottom, left and right side of an
    image-like tensor. Their values mirror the input at its border, including the border row/column itself
    (``tf.pad`` mode ``'SYMMETRIC'``).

    Example:
                                                        2, 1,  1, 2, 3,  3, 2
                                                              _________
     1, 2, 3     SymPad(padding=[[1, 1,], [2, 2]])      2, 1,| 1, 2, 3,| 3, 2
     4, 5, 6     =============================>>>>      5, 4,| 4, 5, 6,| 6, 5
                                                              _________
                                                        5, 4,  4, 5, 6,  6, 5

    Note: per the ``tf.pad`` documentation, in ``'SYMMETRIC'`` mode each padding amount must not exceed
    the size of the padded dimension.

    # Arguments
        padding: int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints.
            - If int: the same symmetric padding
                is applied to height and width.
            - If tuple of 2 ints:
                interpreted as two different
                symmetric padding values for height and width:
                `(symmetric_height_pad, symmetric_width_pad)`.
            - If tuple of 2 tuples of 2 ints:
                interpreted as
                `((top_pad, bottom_pad), (left_pad, right_pad))`
        data_format: A string,
            one of `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs.
            `"channels_last"` corresponds to inputs with shape
            `(batch, height, width, channels)` while `"channels_first"`
            corresponds to inputs with shape
            `(batch, channels, height, width)`.
            It defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "channels_last".

    # Input shape
        4D tensor with shape:
        - If `data_format` is `"channels_last"`:
            `(batch, rows, cols, channels)`
        - If `data_format` is `"channels_first"`:
            `(batch, channels, rows, cols)`

    # Output shape
        4D tensor with shape:
        - If `data_format` is `"channels_last"`:
            `(batch, padded_rows, padded_cols, channels)`
        - If `data_format` is `"channels_first"`:
            `(batch, channels, padded_rows, padded_cols)`
    """

    @interfaces.legacy_zeropadding2d_support
    def __init__(self,
                 padding=(1, 1),
                 data_format=None,
                 **kwargs):
        # Validate/normalise first so the base class always stores ((top, bottom), (left, right)).
        super().__init__(PadUtils.check_padding_format(padding=padding), data_format, **kwargs)

    def call(self, inputs, mask=None):
        # Per-axis [before, after] amounts, ordered according to self.data_format.
        paddings = PadUtils.spatial_2d_padding(padding=self.padding, data_format=self.data_format)
        return tf.pad(inputs, paddings, 'SYMMETRIC')


if __name__ == '__main__':
    # Demo: train a toy model that places each custom padding layer in front of a Conv2D,
    # using PadUtils.get_padding_for_same to pick the padding for that Conv2D's kernel size.
    from keras.models import Model
    from keras.layers import Conv2D, Flatten, Dense, Input

    kernel_1 = (3, 3)
    kernel_2 = (5, 5)
    x = np.arange(2000).reshape(-1, 10, 10, 1)
    y = x.mean(axis=(1, 2))

    x_input = Input(shape=x.shape[1:])
    x_out = x_input
    # (padding layer class, layer name, conv kernel size, number of conv filters), applied in order.
    stages = ((ReflectionPadding2D, "RefPAD", kernel_1, 5),
              (SymmetricPadding2D, "SymPAD", kernel_2, 2))
    for pad_layer_cls, pad_layer_name, kernel, n_filters in stages:
        same_padding = PadUtils.get_padding_for_same(kernel_size=kernel)
        x_out = pad_layer_cls(padding=same_padding, name=pad_layer_name)(x_out)
        x_out = Conv2D(n_filters, kernel_size=kernel, activation='relu')(x_out)
    x_out = Flatten()(x_out)
    x_out = Dense(1, activation='linear')(x_out)

    model = Model(inputs=x_input, outputs=x_out)
    model.compile('adam', loss='mse')
    model.summary()
    model.fit(x, y, epochs=10)