From 755ea62e74d0e997a0d1b1f7c74183c0cf11aebf Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz.juelich.de> Date: Mon, 21 Oct 2019 17:16:56 +0200 Subject: [PATCH] the inception block class was copied from an outdated branch. This is now the most recent version. Tests are not updated yet and will fail --- requirements.txt | 3 +- src/inception_model.py | 140 ++++++++++++++++++++++++----------- test/test_inception_model.py | 22 ++++-- 3 files changed, 113 insertions(+), 52 deletions(-) diff --git a/requirements.txt b/requirements.txt index 4dd28a7d..a956e8ff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ Keras==2.2.4 numpy==1.15.4 tensorflow==1.12.0 -pytest==5.2.1 \ No newline at end of file +pytest==5.2.1 +pydot \ No newline at end of file diff --git a/src/inception_model.py b/src/inception_model.py index 8ffbb3d5..fddd3e04 100644 --- a/src/inception_model.py +++ b/src/inception_model.py @@ -2,10 +2,12 @@ __author__ = 'Felix Kleinert, Lukas Leufen' import keras from keras.layers import Input, Dense, Conv2D, MaxPooling2D, AveragePooling2D, ZeroPadding2D, Dropout, Flatten, \ - Concatenate, Reshape, Activation + Concatenate, Reshape, Activation, ReLU +import keras.layers as layers from keras.models import Model from keras.regularizers import l2 from keras.optimizers import SGD +from keras.layers.advanced_activations import LeakyReLU, PReLU, ELU class InceptionModelBase: @@ -16,6 +18,7 @@ class InceptionModelBase: def __init__(self): self.number_of_blocks = 0 self.part_of_block = 0 + self.act_number = 0 # conversion between chr and ord: # >>> chr(97) # 'a' @@ -31,71 +34,118 @@ class InceptionModelBase: """ return chr(self.ord_base + self.part_of_block) + def batch_normalisation(self, input_x, **kwargs): + block_name = f"Block_{self.number_of_blocks}{self.block_part_name()}_BN" + return layers.BatchNormalization(name=block_name, **kwargs)(input_x) + def create_conv_tower(self, - input_X, + input_x, reduction_filter, tower_filter, tower_kernel, 
activation='relu', - regularizer=l2(0.01)): + batch_normalisation=False, + **kwargs): """ - This function creates a "convolution tower block" containing a 1x1 convolution to reduce filter size followed by convolution - with given filter and kernel size - :param input_X: Input to network part + This function creates a "convolution tower block" containing a 1x1 convolution to reduce filter size followed by + convolution with given filter and kernel size + :param input_x: Input to network part :param reduction_filter: Number of filters used in 1x1 convolution to reduce overall filter size before conv. :param tower_filter: Number of filters for n x m convolution :param tower_kernel: kernel size for convolution (n,m) :param activation: activation function for convolution + :param batch_normalisation: :return: """ self.part_of_block += 1 + self.act_number = 1 + regularizer = kwargs.get('regularizer', l2(0.01)) + bn_settings = kwargs.get('bn_settings', {}) + act_settings = kwargs.get('act_settings', {}) + print(f'Inception Block with activation: {activation}') + + block_name = f'Block_{self.number_of_blocks}{self.block_part_name()}_{tower_kernel[0]}x{tower_kernel[1]}' if tower_kernel == (1, 1): tower = Conv2D(tower_filter, tower_kernel, - activation=activation, padding='same', kernel_regularizer=regularizer, - name='Block_{}{}_{}x{}'.format(self.number_of_blocks, - self.block_part_name(), - tower_kernel[0], - tower_kernel[1]))(input_X) + name=block_name)(input_x) + tower = self.act(tower, activation, **act_settings) else: tower = Conv2D(reduction_filter, (1, 1), - activation=activation, padding='same', kernel_regularizer=regularizer, - name='Block_{}{}_1x1'.format(self.number_of_blocks, self.block_part_name()))(input_X) + name=f'Block_{self.number_of_blocks}{self.block_part_name()}_1x1')(input_x) + tower = self.act(tower, activation, **act_settings) tower = Conv2D(tower_filter, tower_kernel, - activation=activation, padding='same', kernel_regularizer=regularizer, - 
name='Block_{}{}_{}x{}'.format(self.number_of_blocks, - self.block_part_name(), - tower_kernel[0], - tower_kernel[1]))(tower) + name=block_name)(tower) + if batch_normalisation: + tower = self.batch_normalisation(tower, **bn_settings) + tower = self.act(tower, activation, **act_settings) + return tower + def act(self, input_x, activation, **act_settings): + block_name = f"Block_{self.number_of_blocks}{self.block_part_name()}_act_{self.act_number}" + try: + out = getattr(layers, self._get_act_name(activation))(**act_settings, name=block_name)(input_x) + except AttributeError: + block_name += f"_{activation.lower()}" + out = layers.Activation(activation.lower(), name=block_name)(input_x) + self.act_number += 1 + return out + @staticmethod - def create_pool_tower(input_X, pool_kernel, tower_filter): + def _get_act_name(act_name): + if isinstance(act_name, str): + mapping = {'relu': 'ReLU', 'prelu': 'PReLU', 'elu': 'ELU'} + return mapping.get(act_name.lower(), act_name) + else: + return act_name.__name__ + + def create_pool_tower(self, input_x, pool_kernel, tower_filter, activation='relu', max_pooling=True, **kwargs): """ This function creates a "MaxPooling tower block" - :param input_X: Input to network part + :param input_x: Input to network part :param pool_kernel: size of pooling kernel :param tower_filter: Number of filters used in 1x1 convolution to reduce filter size + :param activation: + :param max_pooling: :return: """ - tower = MaxPooling2D(pool_kernel, strides=(1, 1), padding='same')(input_X) - tower = Conv2D(tower_filter, (1, 1), padding='same', activation='relu')(tower) + self.part_of_block += 1 + self.act_number = 1 + act_settings = kwargs.get('act_settings', {}) + + # pooling block + block_name = f"Block_{self.number_of_blocks}{self.block_part_name()}_" + if max_pooling: + block_type = "MaxPool" + pooling = MaxPooling2D + else: + block_type = "AvgPool" + pooling = AveragePooling2D + tower = pooling(pool_kernel, strides=(1, 1), padding='same', 
name=block_name+block_type)(input_x) + # tower = MaxPooling2D(pool_kernel, strides=(1, 1), padding='same', name=block_name)(input_x) + # tower = AveragePooling2D(pool_kernel, strides=(1, 1), padding='same', name=block_name)(input_x) + + # convolution block + tower = Conv2D(tower_filter, (1, 1), padding='same', name=block_name+"1x1")(tower) + tower = self.act(tower, activation, **act_settings) + return tower - def inception_block(self, input_X, tower_conv_parts, tower_pool_parts): + def inception_block(self, input_x, tower_conv_parts, tower_pool_parts, **kwargs): """ Crate a inception block - :param input_X: Input to block + :param input_x: Input to block :param tower_conv_parts: dict containing settings for parts of inception block; Example: tower_conv_parts = {'tower_1': {'reduction_filter': 32, 'tower_filter': 64, @@ -115,22 +165,24 @@ class InceptionModelBase: self.part_of_block = 0 tower_build = {} for part, part_settings in tower_conv_parts.items(): - tower_build[part] = self.create_conv_tower(input_X, - part_settings['reduction_filter'], - part_settings['tower_filter'], - part_settings['tower_kernel'] - ) - tower_build['pool'] = self.create_pool_tower(input_X, - tower_pool_parts['pool_kernel'], - tower_pool_parts['tower_filter'] - ) + tower_build[part] = self.create_conv_tower(input_x, **part_settings, **kwargs) + if 'max_pooling' in tower_pool_parts.keys(): + max_pooling = tower_pool_parts.get('max_pooling') + if not isinstance(max_pooling, bool): + raise AttributeError(f"max_pooling has to be either a bool or empty. 
Given was: {max_pooling}") + pool_name = '{}pool'.format('max' if max_pooling else 'avg') + tower_build[pool_name] = self.create_pool_tower(input_x, **tower_pool_parts, **kwargs) + else: + tower_build['maxpool'] = self.create_pool_tower(input_x, **tower_pool_parts, **kwargs) + tower_build['avgpool'] = self.create_pool_tower(input_x, **tower_pool_parts, **kwargs, max_pooling=False) + block = keras.layers.concatenate(list(tower_build.values()), axis=3) return block @staticmethod - def flatten_tail(input_X, tail_block): - input_X = Flatten()(input_X) - tail = tail_block(input_X) + def flatten_tail(input_x, tail_block): + input_x = Flatten()(input_x) + tail = tail_block(input_x) return tail @@ -141,13 +193,16 @@ if __name__ == '__main__': from keras.layers import Input conv_settings_dict = {'tower_1': {'reduction_filter': 64, 'tower_filter': 64, - 'tower_kernel': (3, 3)}, + 'tower_kernel': (3, 3), + 'activation': LeakyReLU}, 'tower_2': {'reduction_filter': 64, 'tower_filter': 64, - 'tower_kernel': (5, 5)}, + 'tower_kernel': (5, 5), + 'activation': 'relu'} } pool_settings_dict = {'pool_kernel': (3, 3), - 'tower_filter': 64} + 'tower_filter': 64, + 'activation': 'relu'} myclass = True (X_train, y_train), (X_test, y_test) = cifar10.load_data() @@ -175,8 +230,8 @@ if __name__ == '__main__': output = keras.layers.concatenate([tower_1, tower_2, tower_3], axis=3) output = Flatten()(output) - out = Dense(10, activation='softmax')(output) - model = Model(inputs=input_img, outputs=out) + output = Dense(10, activation='softmax')(output) + model = Model(inputs=input_img, outputs=output) print(model.summary()) epochs = 10 @@ -187,9 +242,8 @@ if __name__ == '__main__': model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy']) print(X_train.shape) + keras.utils.plot_model(model, to_file='model.pdf', show_shapes=True, show_layer_names=True) # model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=epochs, batch_size=32) # # scores = 
model.evaluate(X_test, y_test, verbose=0) # print("Accuracy: %.2f%%" % (scores[1]*100)) - - diff --git a/test/test_inception_model.py b/test/test_inception_model.py index a03e50c1..172693b0 100644 --- a/test/test_inception_model.py +++ b/test/test_inception_model.py @@ -18,6 +18,7 @@ class TestInceptionModelBase: assert base.number_of_blocks == 0 assert base.part_of_block == 0 assert base.ord_base == 96 + assert base.act_number == 0 def test_block_part_name(self, base): assert base.block_part_name() == chr(96) @@ -25,7 +26,7 @@ class TestInceptionModelBase: assert base.block_part_name() == 'a' def test_create_conv_tower_3x3(self, base, input_x): - opts = {'input_X': input_x, 'reduction_filter': 64, 'tower_filter': 32, 'tower_kernel': (3, 3)} + opts = {'input_x': input_x, 'reduction_filter': 64, 'tower_filter': 32, 'tower_kernel': (3, 3)} tower = base.create_conv_tower(**opts) # check second element of tower assert base.part_of_block == 1 @@ -46,7 +47,7 @@ class TestInceptionModelBase: assert tower._keras_history[0].input._keras_history[0].input._keras_shape == (None, 32, 32, 3) def test_create_conv_tower_1x1(self, base, input_x): - opts = {'input_X': input_x, 'reduction_filter': 64, 'tower_filter': 32, 'tower_kernel': (1, 1)} + opts = {'input_x': input_x, 'reduction_filter': 64, 'tower_filter': 32, 'tower_kernel': (1, 1)} tower = base.create_conv_tower(**opts) # check second element of tower assert base.part_of_block == 1 @@ -61,14 +62,14 @@ class TestInceptionModelBase: assert tower._keras_history[0].strides == (1, 1) def test_create_conv_towers(self, base, input_x): - opts = {'input_X': input_x, 'reduction_filter': 64, 'tower_filter': 32, 'tower_kernel': (3, 3)} + opts = {'input_x': input_x, 'reduction_filter': 64, 'tower_filter': 32, 'tower_kernel': (3, 3)} _ = base.create_conv_tower(**opts) tower = base.create_conv_tower(**opts) assert base.part_of_block == 2 assert tower.name == 'Block_0b_3x3/Relu:0' def test_create_pool_tower(self, base, input_x): - opts = 
{'input_X': input_x, 'pool_kernel': (3, 3), 'tower_filter': 32} + opts = {'input_x': input_x, 'pool_kernel': (3, 3), 'tower_filter': 32} tower = base.create_pool_tower(**opts) # check second element of tower assert base.part_of_block == 0 @@ -89,7 +90,7 @@ class TestInceptionModelBase: conv = {'tower_1': {'reduction_filter': 64, 'tower_kernel': (3, 3), 'tower_filter': 64}, 'tower_2': {'reduction_filter': 64, 'tower_kernel': (5, 5), 'tower_filter': 64}} pool = {'pool_kernel': (3, 3), 'tower_filter': 64} - opts = {'input_X': input_x, 'tower_conv_parts': conv, 'tower_pool_parts': pool} + opts = {'input_x': input_x, 'tower_conv_parts': conv, 'tower_pool_parts': pool} block = base.inception_block(**opts) assert base.number_of_blocks == 1 concatenated = block._keras_history[0].input @@ -97,9 +98,9 @@ class TestInceptionModelBase: block_1a, block_1b, block_pool = concatenated assert block_1a.name == 'Block_1a_3x3/Relu:0' assert block_1b.name == 'Block_1b_5x5/Relu:0' - assert block_pool.name == 'conv2d_2/Relu:0' + assert block_pool.name == 'conv2d_1/Relu:0' # next block - opts['input_X'] = block + opts['input_x'] = block block = base.inception_block(**opts) assert base.number_of_blocks == 2 concatenated = block._keras_history[0].input @@ -107,4 +108,9 @@ class TestInceptionModelBase: block_1a, block_1b, block_pool = concatenated assert block_1a.name == 'Block_2a_3x3/Relu:0' assert block_1b.name == 'Block_2b_5x5/Relu:0' - assert block_pool.name == 'conv2d_3/Relu:0' + assert block_pool.name == 'conv2d_2/Relu:0' + m = keras.models.Model(input=input_x, output=block) + keras.utils.plot_model(m, to_file='model.pdf', show_shapes=True, show_layer_names=True) + + def test_batch_normalisation(self): + pass -- GitLab