import keras
import pytest

from src.helpers import PyTestRegex
from src.model_modules.advanced_paddings import ReflectionPadding2D, SymmetricPadding2D
from src.model_modules.inception_model import InceptionModelBase


class TestInceptionModelBase:
    """Tests for tower and block construction in ``InceptionModelBase``.

    Tensor names (``tensor.name``) are matched with ``PyTestRegex`` allowing an
    optional numeric suffix such as ``_1``. Op names get de-duplicated, so the
    suffix depends on which layers earlier tests already created. For example,
    layer ``Block_0a_act_2`` may output ``Block_0a_act_2_1/Relu:0``. Layer
    names (``layer.name``) carry no such suffix and are compared exactly.
    Together this keeps every test independent of execution order and of
    ``-k`` selection.
    """

    @pytest.fixture
    def base(self):
        """Return a fresh ``InceptionModelBase`` instance for each test."""
        return InceptionModelBase()

    @pytest.fixture
    def input_x(self):
        """Return a symbolic Keras input whose per-sample shape is (32, 32, 3)."""
        return keras.Input(shape=(32, 32, 3))

    @staticmethod
    def step_in(element, depth=1):
        """Walk ``depth`` layers upstream from ``element``.

        Each step takes ``element.input`` and replaces ``element`` with the
        layer that produced that tensor (``_keras_history[0]``). This assumes
        each visited layer has a single input tensor.

        :param element: Keras layer to start from.
        :param depth: number of layers to step back.
        :return: the layer ``depth`` steps upstream of ``element``.
        """
        for _ in range(depth):
            element = element.input._keras_history[0]
        return element

    def test_init(self, base):
        """A new instance starts with all counters reset and ord_base 96."""
        assert base.number_of_blocks == 0
        assert base.part_of_block == 0
        assert base.ord_base == 96
        assert base.act_number == 0

    def test_block_part_name(self, base):
        """Part 0 maps to chr(96); after incrementing the part counter it maps to 'a'."""
        assert base.block_part_name() == chr(96)
        base.part_of_block += 1
        assert base.block_part_name() == 'a'

    def test_create_conv_tower_3x3(self, base, input_x):
        """A 3x3 tower is: 1x1 conv -> act -> SymmetricPadding2D -> 3x3 conv -> act."""
        opts = {'input_x': input_x, 'reduction_filter': 64, 'tower_filter': 32, 'tower_kernel': (3, 3),
                'padding': 'SymPad2D'}
        tower = base.create_conv_tower(**opts)
        # check last element of tower (activation)
        assert base.part_of_block == 1
        assert tower.name == PyTestRegex(r'Block_0a_act_2(_\d+)?/Relu:0')
        act_layer = tower._keras_history[0]
        assert isinstance(act_layer, keras.layers.advanced_activations.ReLU)
        assert act_layer.name == "Block_0a_act_2"
        # check previous element of tower (conv2D)
        conv_layer = self.step_in(act_layer)
        assert isinstance(conv_layer, keras.layers.Conv2D)
        assert conv_layer.filters == 32
        assert conv_layer.padding == 'valid'
        assert conv_layer.kernel_size == (3, 3)
        assert conv_layer.strides == (1, 1)
        assert conv_layer.name == "Block_0a_3x3"
        # check previous element of tower (padding)
        pad_layer = self.step_in(conv_layer)
        assert isinstance(pad_layer, SymmetricPadding2D)
        assert pad_layer.padding == ((1, 1), (1, 1))
        assert pad_layer.name == 'Block_0a_Pad'
        # check previous element of tower (activation)
        act_layer2 = self.step_in(pad_layer)
        assert isinstance(act_layer2, keras.layers.advanced_activations.ReLU)
        assert act_layer2.name == "Block_0a_act_1"
        # check previous element of tower (conv2D)
        conv_layer2 = self.step_in(act_layer2)
        assert isinstance(conv_layer2, keras.layers.Conv2D)
        assert conv_layer2.filters == 64
        assert conv_layer2.kernel_size == (1, 1)
        assert conv_layer2.padding == 'valid'
        assert conv_layer2.name == 'Block_0a_1x1'
        assert conv_layer2.input._keras_shape == (None, 32, 32, 3)

    def test_create_conv_tower_3x3_batch_norm(self, base, input_x):
        """With batch_normalisation=True, a BatchNormalization layer sits between the 3x3 conv and the final act."""
        opts = {'input_x': input_x, 'reduction_filter': 64, 'tower_filter': 32, 'tower_kernel': (3, 3),
                'padding': 'SymPad2D', 'batch_normalisation': True}
        tower = base.create_conv_tower(**opts)
        # check last element of tower (activation)
        assert base.part_of_block == 1
        assert tower.name == PyTestRegex(r'Block_0a_act_2(_\d+)?/Relu:0')
        act_layer = tower._keras_history[0]
        assert isinstance(act_layer, keras.layers.advanced_activations.ReLU)
        assert act_layer.name == "Block_0a_act_2"
        # check previous element of tower (batch_normal)
        batch_layer = self.step_in(act_layer)
        assert isinstance(batch_layer, keras.layers.BatchNormalization)
        assert batch_layer.name == 'Block_0a_BN'
        # check previous element of tower (conv2D)
        conv_layer = self.step_in(batch_layer)
        assert isinstance(conv_layer, keras.layers.Conv2D)
        assert conv_layer.filters == 32
        assert conv_layer.padding == 'valid'
        assert conv_layer.kernel_size == (3, 3)
        assert conv_layer.strides == (1, 1)
        assert conv_layer.name == "Block_0a_3x3"
        # check previous element of tower (padding)
        pad_layer = self.step_in(conv_layer)
        assert isinstance(pad_layer, SymmetricPadding2D)
        assert pad_layer.padding == ((1, 1), (1, 1))
        assert pad_layer.name == 'Block_0a_Pad'
        # check previous element of tower (activation)
        act_layer2 = self.step_in(pad_layer)
        assert isinstance(act_layer2, keras.layers.advanced_activations.ReLU)
        assert act_layer2.name == "Block_0a_act_1"
        # check previous element of tower (conv2D)
        conv_layer2 = self.step_in(act_layer2)
        assert isinstance(conv_layer2, keras.layers.Conv2D)
        assert conv_layer2.filters == 64
        assert conv_layer2.kernel_size == (1, 1)
        assert conv_layer2.padding == 'valid'
        assert conv_layer2.name == 'Block_0a_1x1'
        assert conv_layer2.input._keras_shape == (None, 32, 32, 3)

    def test_create_conv_tower_3x3_activation(self, base, input_x):
        """The activation may be given as a string (wrapped in Activation) or as a layer class."""
        opts = {'input_x': input_x, 'reduction_filter': 64, 'tower_filter': 32, 'tower_kernel': (3, 3)}
        # create tower with standard activation function
        tower = base.create_conv_tower(activation='tanh', **opts)
        assert tower.name == PyTestRegex(r'Block_0a_act_2_tanh(_\d+)?/Tanh:0')
        act_layer = tower._keras_history[0]
        assert isinstance(act_layer, keras.layers.core.Activation)
        assert act_layer.name == "Block_0a_act_2_tanh"
        # create tower with activation function class
        tower = base.create_conv_tower(activation=keras.layers.LeakyReLU, **opts)
        assert tower.name == PyTestRegex(r'Block_0b_act_2(_\d+)?/LeakyRelu:0')
        act_layer = tower._keras_history[0]
        assert isinstance(act_layer, keras.layers.advanced_activations.LeakyReLU)
        assert act_layer.name == "Block_0b_act_2"

    def test_create_conv_tower_1x1(self, base, input_x):
        """A 1x1 tower is a single 1x1 conv followed by one activation (no reduction, no padding)."""
        opts = {'input_x': input_x, 'reduction_filter': 64, 'tower_filter': 32, 'tower_kernel': (1, 1)}
        tower = base.create_conv_tower(**opts)
        # check last element of tower (activation)
        assert base.part_of_block == 1
        assert tower.name == PyTestRegex(r'Block_0a_act_1(_\d+)?/Relu:0')
        act_layer = tower._keras_history[0]
        assert isinstance(act_layer, keras.layers.advanced_activations.ReLU)
        assert act_layer.name == "Block_0a_act_1"
        # check previous element of tower (conv2D)
        conv_layer = self.step_in(act_layer)
        assert isinstance(conv_layer, keras.layers.Conv2D)
        assert conv_layer.filters == 32
        assert conv_layer.padding == 'valid'
        assert conv_layer.kernel_size == (1, 1)
        assert conv_layer.strides == (1, 1)
        assert conv_layer.name == "Block_0a_1x1"
        assert conv_layer.input._keras_shape == (None, 32, 32, 3)

    def test_create_conv_towers(self, base, input_x):
        """Each additional tower advances part_of_block, so the second tower is named 'Block_0b...'."""
        opts = {'input_x': input_x, 'reduction_filter': 64, 'tower_filter': 32, 'tower_kernel': (3, 3)}
        _ = base.create_conv_tower(**opts)
        tower = base.create_conv_tower(**opts)
        assert base.part_of_block == 2
        assert tower.name == PyTestRegex(r'Block_0b_act_2(_\d+)?/Relu:0')

    def test_create_pool_tower(self, base, input_x):
        """A pool tower is: ZeroPadding2D -> Max/AvgPooling2D -> 1x1 conv -> act."""
        opts = {'input_x': input_x, 'pool_kernel': (3, 3), 'tower_filter': 32}
        tower = base.create_pool_tower(**opts)
        # check last element of tower (activation)
        assert base.part_of_block == 1
        assert tower.name == PyTestRegex(r'Block_0a_act_1(_\d+)?/Relu:0')
        act_layer = tower._keras_history[0]
        assert isinstance(act_layer, keras.layers.advanced_activations.ReLU)
        assert act_layer.name == "Block_0a_act_1"
        # check previous element of tower (conv2D)
        conv_layer = self.step_in(act_layer)
        assert isinstance(conv_layer, keras.layers.Conv2D)
        assert conv_layer.filters == 32
        assert conv_layer.padding == 'valid'
        assert conv_layer.kernel_size == (1, 1)
        assert conv_layer.strides == (1, 1)
        assert conv_layer.name == "Block_0a_1x1"
        # check previous element of tower (maxpool)
        pool_layer = self.step_in(conv_layer)
        assert isinstance(pool_layer, keras.layers.pooling.MaxPooling2D)
        assert pool_layer.name == "Block_0a_MaxPool"
        assert pool_layer.pool_size == (3, 3)
        assert pool_layer.padding == 'valid'
        # check previous element of tower(padding)
        pad_layer = self.step_in(pool_layer)
        assert isinstance(pad_layer, keras.layers.convolutional.ZeroPadding2D)
        assert pad_layer.name == "Block_0a_Pad"
        assert pad_layer.padding == ((1, 1), (1, 1))
        # check avg pool tower (same options; ** unpacking leaves opts unmodified)
        tower = base.create_pool_tower(max_pooling=False, **opts)
        # act <- 1x1 conv <- pooling: the pooling layer is two steps upstream of the activation
        pool_layer = self.step_in(tower._keras_history[0], depth=2)
        assert isinstance(pool_layer, keras.layers.pooling.AveragePooling2D)
        assert pool_layer.name == "Block_0b_AvgPool"
        assert pool_layer.pool_size == (3, 3)
        assert pool_layer.padding == 'valid'

    def test_inception_block(self, base, input_x):
        """Check that an inception block concatenates its conv towers and pool tower(s).

        Without ``max_pooling`` in the pool options, both a max-pool and an
        avg-pool tower are built (4 concatenated inputs). With
        ``max_pooling=True`` only the max-pool tower remains (3 inputs).
        """
        conv = {'tower_1': {'reduction_filter': 64,
                            'tower_kernel': (3, 3),
                            'tower_filter': 64, },
                'tower_2': {'reduction_filter': 64,
                            'tower_kernel': (5, 5),
                            'tower_filter': 64,
                            'activation': 'tanh',
                            'padding': 'SymPad2D', },
                }
        pool = {'pool_kernel': (3, 3), 'tower_filter': 64, 'padding': ReflectionPadding2D}
        opts = {'input_x': input_x, 'tower_conv_parts': conv, 'tower_pool_parts': pool}
        block = base.inception_block(**opts)
        assert base.number_of_blocks == 1
        concatenated = block._keras_history[0].input
        assert len(concatenated) == 4
        block_1a, block_1b, block_pool1, block_pool2 = concatenated
        # tensor names may carry an op de-duplication suffix (see class docstring)
        assert block_1a.name == PyTestRegex(r'Block_1a_act_2(_\d+)?/Relu:0')
        assert block_1b.name == PyTestRegex(r'Block_1b_act_2_tanh(_\d+)?/Tanh:0')
        assert block_pool1.name == PyTestRegex(r'Block_1c_act_1(_\d+)?/Relu:0')
        assert block_pool2.name == PyTestRegex(r'Block_1d_act_1(_\d+)?/Relu:0')
        assert self.step_in(block_1a._keras_history[0]).name == "Block_1a_3x3"
        assert self.step_in(block_1b._keras_history[0]).name == "Block_1b_5x5"
        assert self.step_in(block_1a._keras_history[0], depth=2).name == 'Block_1a_Pad'
        assert isinstance(self.step_in(block_1a._keras_history[0], depth=2), keras.layers.ZeroPadding2D)
        assert self.step_in(block_1b._keras_history[0], depth=2).name == 'Block_1b_Pad'
        assert isinstance(self.step_in(block_1b._keras_history[0], depth=2), SymmetricPadding2D)
        # pooling
        assert isinstance(self.step_in(block_pool1._keras_history[0], depth=2), keras.layers.pooling.MaxPooling2D)
        assert self.step_in(block_pool1._keras_history[0], depth=3).name == 'Block_1c_Pad'
        assert isinstance(self.step_in(block_pool1._keras_history[0], depth=3), ReflectionPadding2D)

        assert isinstance(self.step_in(block_pool2._keras_history[0], depth=2), keras.layers.pooling.AveragePooling2D)
        assert self.step_in(block_pool2._keras_history[0], depth=3).name == 'Block_1d_Pad'
        assert isinstance(self.step_in(block_pool2._keras_history[0], depth=3), ReflectionPadding2D)
        # check naming of concat layer
        assert block.name == PyTestRegex(r'Block_1_Co(_\d+)?/concat:0')
        assert block._keras_history[0].name == 'Block_1_Co'
        assert isinstance(block._keras_history[0], keras.layers.merge.Concatenate)
        # next block: feed the first block's output and request max pooling only
        opts['input_x'] = block
        opts['tower_pool_parts']['max_pooling'] = True
        block = base.inception_block(**opts)
        assert base.number_of_blocks == 2
        concatenated = block._keras_history[0].input
        assert len(concatenated) == 3
        block_2a, block_2b, block_pool = concatenated
        assert block_2a.name == PyTestRegex(r'Block_2a_act_2(_\d+)?/Relu:0')
        assert block_2b.name == PyTestRegex(r'Block_2b_act_2_tanh(_\d+)?/Tanh:0')
        assert block_pool.name == PyTestRegex(r'Block_2c_act_1(_\d+)?/Relu:0')
        assert self.step_in(block_2a._keras_history[0]).name == "Block_2a_3x3"
        assert self.step_in(block_2a._keras_history[0], depth=2).name == "Block_2a_Pad"
        assert isinstance(self.step_in(block_2a._keras_history[0], depth=2), keras.layers.ZeroPadding2D)
        # block 2b
        assert self.step_in(block_2b._keras_history[0]).name == "Block_2b_5x5"
        assert self.step_in(block_2b._keras_history[0], depth=2).name == "Block_2b_Pad"
        assert isinstance(self.step_in(block_2b._keras_history[0], depth=2), SymmetricPadding2D)
        # block pool
        assert isinstance(self.step_in(block_pool._keras_history[0], depth=2), keras.layers.pooling.MaxPooling2D)
        assert self.step_in(block_pool._keras_history[0], depth=3).name == 'Block_2c_Pad'
        assert isinstance(self.step_in(block_pool._keras_history[0], depth=3), ReflectionPadding2D)
        # check naming of concat layer
        assert block.name == PyTestRegex(r'Block_2_Co(_\d+)?/concat:0')
        assert block._keras_history[0].name == 'Block_2_Co'
        assert isinstance(block._keras_history[0], keras.layers.merge.Concatenate)

    def test_inception_block_invalid_batchnorm(self, base, input_x):
        """Check that a non-bool ``max_pooling`` in ``tower_pool_parts`` raises AttributeError.

        Despite the test name, batch normalisation is not involved here; the
        name is kept so existing test IDs stay stable.
        """
        conv = {'tower_1': {'reduction_filter': 64,
                            'tower_kernel': (3, 3),
                            'tower_filter': 64, },
                'tower_2': {'reduction_filter': 64,
                            'tower_kernel': (5, 5),
                            'tower_filter': 64,
                            'activation': 'tanh',
                            'padding': 'SymPad2D', },
                }
        pool = {'pool_kernel': (3, 3), 'tower_filter': 64, 'padding': ReflectionPadding2D, 'max_pooling': 'yes'}
        opts = {'input_x': input_x, 'tower_conv_parts': conv, 'tower_pool_parts': pool, }
        with pytest.raises(AttributeError) as einfo:
            base.inception_block(**opts)
        assert "max_pooling has to be either a bool or empty. Given was: yes" in str(einfo.value)

    def test_batch_normalisation(self, base, input_x):
        """batch_normalisation wraps its input in a BatchNormalization layer named after the current block part."""
        base.part_of_block += 1
        bn = base.batch_normalisation(input_x)._keras_history[0]
        assert isinstance(bn, keras.layers.normalization.BatchNormalization)
        assert bn.name == "Block_0a_BN"