__author__ = 'Felix Kleinert, Lukas Leufen'
__date__ = '2019-10-22'

import logging

import keras
import keras.layers as layers


class InceptionModelBase:
    """
    This class contains all necessary construction blocks for inception-style networks built with keras.

    The instance keeps counters that are used to give every created layer a unique, readable name of the form
    ``Block_<number_of_blocks><part letter>_<layer description>``.
    """

    def __init__(self):
        # number of inception blocks created so far; incremented by inception_block()
        self.number_of_blocks = 0
        # index of the current tower (part) inside a block; incremented by each create_*_tower() call
        self.part_of_block = 0
        # index of the next activation layer inside the current tower; reset to 1 at the start of each tower
        self.act_number = 0
        self.ord_base = 96  # set to 96 as always add +1 for new part of block, chr(97)='a'

    def block_part_name(self):
        """
        Return the letter identifying the current part of a block ('a' for part 1, 'b' for part 2, ...).

        The letter is ``chr(ord_base + part_of_block)`` and is embedded in layer names.
        NOTE(review): more than 26 parts per block would leave the a-z range; not guarded here.
        :return: single character naming the current block part
        """
        return chr(self.ord_base + self.part_of_block)

    def batch_normalisation(self, input_x, **kwargs):
        """
        Apply a BatchNormalization layer named after the current block part.

        :param input_x: input tensor
        :param kwargs: forwarded unchanged to keras.layers.BatchNormalization
        :return: output tensor of the BatchNormalization layer
        """
        block_name = f"Block_{self.number_of_blocks}{self.block_part_name()}_BN"
        return layers.BatchNormalization(name=block_name, **kwargs)(input_x)

    def create_conv_tower(self,
                          input_x,
                          reduction_filter,
                          tower_filter,
                          tower_kernel,
                          activation='relu',
                          batch_normalisation=False,
                          **kwargs):
        """
        This function creates a "convolution tower block" containing a 1x1 convolution to reduce filter size followed by
        convolution with given filter and kernel size. For a 1x1 tower kernel only a single 1x1 convolution is built.

        :param input_x: Input to network part
        :param reduction_filter: Number of filters used in 1x1 convolution to reduce overall filter size before conv.
                                 (unused if tower_kernel is (1, 1))
        :param tower_filter: Number of filters for n x m convolution
        :param tower_kernel: kernel size for convolution (n,m); a tuple or list of two ints
        :param activation: activation function for convolution (string name or layer class, see act())
        :param batch_normalisation: if True, apply batch normalisation after the last convolution of the tower and
                                    before its activation (for 1x1 and n x m towers alike)
        :param kwargs: optional 'regularizer' (kernel regularizer of all convolutions, default l2(0.01)),
                       'bn_settings' (kwargs for BatchNormalization) and 'act_settings' (kwargs for activation layers)
        :return: output tensor of the tower
        """
        self.part_of_block += 1
        self.act_number = 1
        regularizer = kwargs.get('regularizer', keras.regularizers.l2(0.01))
        bn_settings = kwargs.get('bn_settings', {})
        act_settings = kwargs.get('act_settings', {})
        logging.debug(f'Inception Block with activation: {activation}')

        block_name = f'Block_{self.number_of_blocks}{self.block_part_name()}_{tower_kernel[0]}x{tower_kernel[1]}'

        # normalise to tuple so that a list [1, 1] is recognised too (otherwise two layers named "..._1x1" are built)
        if tuple(tower_kernel) == (1, 1):
            # a 1x1 kernel needs no separate reduction step: the single 1x1 conv below already sets the filter size
            tower = input_x
        else:
            tower = layers.Conv2D(reduction_filter,
                                  (1, 1),
                                  padding='same',
                                  kernel_regularizer=regularizer,
                                  name=f'Block_{self.number_of_blocks}{self.block_part_name()}_1x1')(input_x)
            tower = self.act(tower, activation, **act_settings)

        tower = layers.Conv2D(tower_filter,
                              tower_kernel,
                              padding='same',
                              kernel_regularizer=regularizer,
                              name=block_name)(tower)
        # batch normalisation is applied in both branches (previously it was silently skipped for 1x1 towers)
        if batch_normalisation:
            tower = self.batch_normalisation(tower, **bn_settings)
        tower = self.act(tower, activation, **act_settings)

        return tower

    def act(self, input_x, activation, **act_settings):
        """
        Apply an activation layer to input_x and advance the per-tower activation counter.

        ``activation`` is first resolved to a keras layer class (e.g. 'relu' -> layers.ReLU, or a class such as
        LeakyReLU via its ``__name__``); such a layer is built with ``act_settings``. If keras.layers has no attribute
        of that name, ``activation`` is treated as a string activation name and wrapped in layers.Activation
        (``act_settings`` are not used in that case).

        :param input_x: input tensor
        :param activation: activation as string name or layer class
        :param act_settings: keyword arguments for the dedicated activation layer
        :return: output tensor of the activation layer
        """
        block_name = f"Block_{self.number_of_blocks}{self.block_part_name()}_act_{self.act_number}"
        try:
            act_layer = getattr(layers, self._get_act_name(activation))
        except AttributeError:
            # no dedicated keras layer with this name: fall back to the generic Activation layer
            block_name += f"_{activation.lower()}"
            out = layers.Activation(activation.lower(), name=block_name)(input_x)
        else:
            # only the lookup is guarded; errors while building/calling the layer must surface, not trigger the fallback
            out = act_layer(**act_settings, name=block_name)(input_x)
        self.act_number += 1
        return out

    @staticmethod
    def _get_act_name(act_name):
        """
        Map an activation specification to the attribute name looked up in keras.layers.

        :param act_name: string (case-insensitive for 'relu', 'prelu', 'elu'; other strings are returned unchanged)
                         or an object with a ``__name__`` (e.g. a layer class)
        :return: attribute name to look up in keras.layers
        """
        if isinstance(act_name, str):
            mapping = {'relu': 'ReLU', 'prelu': 'PReLU', 'elu': 'ELU'}
            return mapping.get(act_name.lower(), act_name)
        else:
            return act_name.__name__

    def create_pool_tower(self, input_x, pool_kernel, tower_filter, activation='relu', max_pooling=True, **kwargs):
        """
        This function creates a "pooling tower block": a pooling layer (stride 1, 'same' padding) followed by a 1x1
        convolution and an activation.

        :param input_x: Input to network part
        :param pool_kernel: size of pooling kernel
        :param tower_filter: Number of filters used in 1x1 convolution to reduce filter size
        :param activation: activation function applied after the 1x1 convolution (string name or layer class)
        :param max_pooling: if True use MaxPooling2D, otherwise AveragePooling2D
        :param kwargs: optional 'act_settings' (kwargs for the activation layer)
        :return: output tensor of the tower
        """
        self.part_of_block += 1
        self.act_number = 1
        act_settings = kwargs.get('act_settings', {})

        # pooling block
        block_name = f"Block_{self.number_of_blocks}{self.block_part_name()}_"
        if max_pooling:
            block_type = "MaxPool"
            pooling = layers.MaxPooling2D
        else:
            block_type = "AvgPool"
            pooling = layers.AveragePooling2D
        tower = pooling(pool_kernel, strides=(1, 1), padding='same', name=block_name+block_type)(input_x)

        # convolution block
        tower = layers.Conv2D(tower_filter, (1, 1), padding='same', name=block_name+"1x1")(tower)
        tower = self.act(tower, activation, **act_settings)

        return tower

    def inception_block(self, input_x, tower_conv_parts, tower_pool_parts, **kwargs):
        """
        Create an inception block: all convolution towers and the pooling tower(s) are applied to the same input and
        their outputs are concatenated along axis 3 (presumably the channel axis of channels_last 4D tensors).

        :param input_x: Input to block
        :param tower_conv_parts: dict containing settings for parts of inception block; Example:
                                 tower_conv_parts = {'tower_1': {'reduction_filter': 32,
                                                                 'tower_filter': 64,
                                                                 'tower_kernel': (3, 1)},
                                                     'tower_2': {'reduction_filter': 32,
                                                                 'tower_filter': 64,
                                                                 'tower_kernel': (5, 1)},
                                                     'tower_3': {'reduction_filter': 32,
                                                                 'tower_filter': 64,
                                                                 'tower_kernel': (1, 1)},
                                                    }
        :param tower_pool_parts: dict containing settings for pool part of inception block; Example:
                                 tower_pool_parts = {'pool_kernel': (3, 1), 'tower_filter': 64}
                                 If the optional bool 'max_pooling' is given, a single max (True) or average (False)
                                 pooling tower is built; if it is absent, both a max and an average pooling tower are built.
        :param kwargs: forwarded to every create_conv_tower() / create_pool_tower() call
        :raises AttributeError: if 'max_pooling' is given but is not a bool
        :return: concatenated output tensor of the block
        """
        self.number_of_blocks += 1
        self.part_of_block = 0
        tower_build = {}
        block_name = f"Block_{self.number_of_blocks}"
        for part, part_settings in tower_conv_parts.items():
            tower_build[part] = self.create_conv_tower(input_x, **part_settings, **kwargs)
        if 'max_pooling' in tower_pool_parts:
            max_pooling = tower_pool_parts.get('max_pooling')
            if not isinstance(max_pooling, bool):
                # AttributeError is kept (rather than TypeError) for compatibility with existing callers
                raise AttributeError(f"max_pooling has to be either a bool or empty. Given was: {max_pooling}")
            pool_name = '{}pool'.format('max' if max_pooling else 'avg')
            tower_build[pool_name] = self.create_pool_tower(input_x, **tower_pool_parts, **kwargs)
        else:
            tower_build['maxpool'] = self.create_pool_tower(input_x, **tower_pool_parts, **kwargs)
            tower_build['avgpool'] = self.create_pool_tower(input_x, **tower_pool_parts, **kwargs, max_pooling=False)

        block = keras.layers.concatenate(list(tower_build.values()), axis=3,
                                         name=block_name+"_Co")
        return block


if __name__ == '__main__':
    # small demo: build a single inception block on CIFAR-10 sized input, compile it and plot the architecture
    print(__name__)
    from keras.datasets import cifar10
    from keras.utils import np_utils
    from keras.layers import Input
    from keras.layers.advanced_activations import LeakyReLU
    from keras.optimizers import SGD
    from keras.layers import Dense, Flatten
    from keras.models import Model

    # network settings
    conv_settings_dict = {'tower_1': {'reduction_filter': 64,
                                      'tower_filter': 64,
                                      'tower_kernel': (3, 3),
                                      'activation': LeakyReLU},
                          'tower_2': {'reduction_filter': 64,
                                      'tower_filter': 64,
                                      'tower_kernel': (5, 5),
                                      'activation': 'relu'}
                          }
    pool_settings_dict = {'pool_kernel': (3, 3),
                          'tower_filter': 64,
                          'activation': 'relu'}

    # load data and scale pixel values to [0, 1]; labels are one-hot encoded
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')
    X_train = X_train / 255.0
    X_test = X_test / 255.0
    y_train = np_utils.to_categorical(y_train)
    y_test = np_utils.to_categorical(y_test)
    input_img = Input(shape=(32, 32, 3))

    # create inception net
    inception_net = InceptionModelBase()
    output = inception_net.inception_block(input_img, conv_settings_dict, pool_settings_dict)
    output = Flatten()(output)
    output = Dense(10, activation='softmax')(output)
    model = Model(inputs=input_img, outputs=output)
    # summary() prints by itself and returns None; wrapping it in print() only added a stray "None" line
    model.summary()

    # compile (the model is not trained in this demo; epochs only determines the learning-rate decay)
    epochs = 10
    lrate = 0.01
    decay = lrate/epochs
    sgd = SGD(lr=lrate, momentum=0.9, decay=decay, nesterov=False)
    model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
    print(X_train.shape)
    keras.utils.plot_model(model, to_file='model.pdf', show_shapes=True, show_layer_names=True)