diff --git a/test/test_model_modules/test_inception_model.py b/test/test_model_modules/test_inception_model.py index ae227184eb575b8192ffc008413e327d0aca4ba1..d441925e39c88f9c0dba2a3e811991bb380d61b3 100644 --- a/test/test_model_modules/test_inception_model.py +++ b/test/test_model_modules/test_inception_model.py @@ -1,5 +1,6 @@ import keras import pytest +import re from src.model_modules.inception_model import InceptionModelBase @@ -20,6 +21,19 @@ class TestInceptionModelBase: element = element.input._keras_history[0] return element + @staticmethod + def keras_name_part_split(test_value: str, target_value: str, sep='/'): + layer_name, layer_type = test_value.split(sep) + target_name, target_type = target_value.split(sep) + n_layer_name = len(re.findall("_", layer_name)) + n_target_name = len(re.findall("_", target_name)) + if (n_layer_name == n_target_name) or (n_layer_name == n_target_name + 1): + name_test = target_name in layer_name + else: + name_test = False + type_test = (layer_type == target_type) + return name_test and type_test + def test_init(self, base): assert base.number_of_blocks == 0 assert base.part_of_block == 0 @@ -144,16 +158,17 @@ class TestInceptionModelBase: concatenated = block._keras_history[0].input assert len(concatenated) == 4 block_1a, block_1b, block_pool1, block_pool2 = concatenated - assert block_1a.name == 'Block_1a_act_2/Relu:0' # <- sometimes keras changes given name (I don't know why yet) - assert block_1b.name == 'Block_1b_act_2_tanh/Tanh:0' - assert block_pool1.name == 'Block_1c_act_1/Relu:0' - assert block_pool2.name == 'Block_1d_act_1/Relu:0' + # keras_name_part_split + assert self.keras_name_part_split(block_1a.name, 'Block_1a_act_2/Relu:0') + assert self.keras_name_part_split(block_1b.name, 'Block_1b_act_2_tanh/Tanh:0') + assert self.keras_name_part_split(block_pool1.name, 'Block_1c_act_1/Relu:0') + assert self.keras_name_part_split(block_pool2.name, 'Block_1d_act_1/Relu:0') assert self.step_in(block_1a._keras_history[0]).name == "Block_1a_3x3" assert self.step_in(block_1b._keras_history[0]).name == "Block_1b_5x5" assert isinstance(self.step_in(block_pool1._keras_history[0], depth=2), keras.layers.pooling.MaxPooling2D) assert isinstance(self.step_in(block_pool2._keras_history[0], depth=2), keras.layers.pooling.AveragePooling2D) # check naming of concat layer - assert block.name == 'Block_1_Co/concat:0' + assert self.keras_name_part_split(block.name, 'Block_1_Co/concat:0') assert block._keras_history[0].name == 'Block_1_Co' assert isinstance(block._keras_history[0], keras.layers.merge.Concatenate) # next block @@ -164,14 +179,14 @@ class TestInceptionModelBase: concatenated = block._keras_history[0].input assert len(concatenated) == 3 block_2a, block_2b, block_pool = concatenated - assert block_2a.name == 'Block_2a_act_2/Relu:0' - assert block_2b.name == 'Block_2b_act_2_tanh/Tanh:0' - assert block_pool.name == 'Block_2c_act_1/Relu:0' + assert self.keras_name_part_split(block_2a.name, 'Block_2a_act_2/Relu:0') + assert self.keras_name_part_split(block_2b.name, 'Block_2b_act_2_tanh/Tanh:0') + assert self.keras_name_part_split(block_pool.name, 'Block_2c_act_1/Relu:0') assert self.step_in(block_2a._keras_history[0]).name == "Block_2a_3x3" assert self.step_in(block_2b._keras_history[0]).name == "Block_2b_5x5" assert isinstance(self.step_in(block_pool._keras_history[0], depth=2), keras.layers.pooling.MaxPooling2D) # check naming of concat layer - assert block.name == 'Block_2_Co/concat:0' + assert self.keras_name_part_split(block.name, 'Block_2_Co/concat:0') assert block._keras_history[0].name == 'Block_2_Co' assert isinstance(block._keras_history[0], keras.layers.merge.Concatenate)