Skip to content
Snippets Groups Projects
Commit a9467437 authored by lukas leufen's avatar lukas leufen
Browse files

refac naming of concat layers in inception, /close #59

See merge request toar/machinelearningtools!48
parents 0b347cc3 84b30014
No related branches found
No related tags found
2 merge requests!59Develop,!48Felix issue059 refac naming of concat layers in inception
Pipeline #31046 passed with warnings
......@@ -153,6 +153,7 @@ class InceptionModelBase:
self.number_of_blocks += 1
self.part_of_block = 0
tower_build = {}
block_name = f"Block_{self.number_of_blocks}"
for part, part_settings in tower_conv_parts.items():
tower_build[part] = self.create_conv_tower(input_x, **part_settings, **kwargs)
if 'max_pooling' in tower_pool_parts.keys():
......@@ -165,7 +166,8 @@ class InceptionModelBase:
tower_build['maxpool'] = self.create_pool_tower(input_x, **tower_pool_parts, **kwargs)
tower_build['avgpool'] = self.create_pool_tower(input_x, **tower_pool_parts, **kwargs, max_pooling=False)
block = keras.layers.concatenate(list(tower_build.values()), axis=3)
block = keras.layers.concatenate(list(tower_build.values()), axis=3,
name=block_name+"_Co")
return block
......
......@@ -152,6 +152,10 @@ class TestInceptionModelBase:
assert self.step_in(block_1b._keras_history[0]).name == "Block_1b_5x5"
assert isinstance(self.step_in(block_pool1._keras_history[0], depth=2), keras.layers.pooling.MaxPooling2D)
assert isinstance(self.step_in(block_pool2._keras_history[0], depth=2), keras.layers.pooling.AveragePooling2D)
# check naming of concat layer
assert block.name == 'Block_1_Co/concat:0'
assert block._keras_history[0].name == 'Block_1_Co'
assert isinstance(block._keras_history[0], keras.layers.merge.Concatenate)
# next block
opts['input_x'] = block
opts['tower_pool_parts']['max_pooling'] = True
......@@ -166,6 +170,10 @@ class TestInceptionModelBase:
assert self.step_in(block_2a._keras_history[0]).name == "Block_2a_3x3"
assert self.step_in(block_2b._keras_history[0]).name == "Block_2b_5x5"
assert isinstance(self.step_in(block_pool._keras_history[0], depth=2), keras.layers.pooling.MaxPooling2D)
# check naming of concat layer
assert block.name == 'Block_2_Co/concat:0'
assert block._keras_history[0].name == 'Block_2_Co'
assert isinstance(block._keras_history[0], keras.layers.merge.Concatenate)
def test_batch_normalisation(self, base, input_x):
base.part_of_block += 1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment