Skip to content
Snippets Groups Projects
Commit 6f72afab authored by Felix Kleinert's avatar Felix Kleinert
Browse files

update tests for keras namings

parent 0e36df1e
No related branches found
No related tags found
2 merge requests!59Develop,!55update tests for keras namings
Pipeline #31178 passed
import keras import keras
import pytest import pytest
import re
from src.model_modules.inception_model import InceptionModelBase from src.model_modules.inception_model import InceptionModelBase
...@@ -20,6 +21,19 @@ class TestInceptionModelBase: ...@@ -20,6 +21,19 @@ class TestInceptionModelBase:
element = element.input._keras_history[0] element = element.input._keras_history[0]
return element return element
@staticmethod
def keras_name_part_split(test_value: str, target_value: str, sep='/'):
layer_name, layer_type = test_value.split(sep)
target_name, target_type = target_value.split(sep)
n_layer_name = len(re.findall("_", layer_name))
n_target_name = len(re.findall("_", target_name))
if (n_layer_name == n_target_name) or (n_layer_name == n_target_name + 1):
name_test = target_name in layer_name
else:
name_test = False
type_test = (layer_type == target_type)
return name_test and type_test
def test_init(self, base): def test_init(self, base):
assert base.number_of_blocks == 0 assert base.number_of_blocks == 0
assert base.part_of_block == 0 assert base.part_of_block == 0
...@@ -144,16 +158,17 @@ class TestInceptionModelBase: ...@@ -144,16 +158,17 @@ class TestInceptionModelBase:
concatenated = block._keras_history[0].input concatenated = block._keras_history[0].input
assert len(concatenated) == 4 assert len(concatenated) == 4
block_1a, block_1b, block_pool1, block_pool2 = concatenated block_1a, block_1b, block_pool1, block_pool2 = concatenated
assert block_1a.name == 'Block_1a_act_2/Relu:0' # <- sometimes keras changes given name (I don't know why yet) # keras_name_part_split
assert block_1b.name == 'Block_1b_act_2_tanh/Tanh:0' assert self.keras_name_part_split(block_1a.name, 'Block_1a_act_2/Relu:0')
assert block_pool1.name == 'Block_1c_act_1/Relu:0' assert self.keras_name_part_split(block_1b.name, 'Block_1b_act_2_tanh/Tanh:0')
assert block_pool2.name == 'Block_1d_act_1/Relu:0' assert self.keras_name_part_split(block_pool1.name, 'Block_1c_act_1/Relu:0')
assert self.keras_name_part_split(block_pool2.name, 'Block_1d_act_1/Relu:0')
assert self.step_in(block_1a._keras_history[0]).name == "Block_1a_3x3" assert self.step_in(block_1a._keras_history[0]).name == "Block_1a_3x3"
assert self.step_in(block_1b._keras_history[0]).name == "Block_1b_5x5" assert self.step_in(block_1b._keras_history[0]).name == "Block_1b_5x5"
assert isinstance(self.step_in(block_pool1._keras_history[0], depth=2), keras.layers.pooling.MaxPooling2D) assert isinstance(self.step_in(block_pool1._keras_history[0], depth=2), keras.layers.pooling.MaxPooling2D)
assert isinstance(self.step_in(block_pool2._keras_history[0], depth=2), keras.layers.pooling.AveragePooling2D) assert isinstance(self.step_in(block_pool2._keras_history[0], depth=2), keras.layers.pooling.AveragePooling2D)
# check naming of concat layer # check naming of concat layer
assert block.name == 'Block_1_Co/concat:0' assert self.keras_name_part_split(block.name, 'Block_1_Co/concat:0')
assert block._keras_history[0].name == 'Block_1_Co' assert block._keras_history[0].name == 'Block_1_Co'
assert isinstance(block._keras_history[0], keras.layers.merge.Concatenate) assert isinstance(block._keras_history[0], keras.layers.merge.Concatenate)
# next block # next block
...@@ -164,14 +179,14 @@ class TestInceptionModelBase: ...@@ -164,14 +179,14 @@ class TestInceptionModelBase:
concatenated = block._keras_history[0].input concatenated = block._keras_history[0].input
assert len(concatenated) == 3 assert len(concatenated) == 3
block_2a, block_2b, block_pool = concatenated block_2a, block_2b, block_pool = concatenated
assert block_2a.name == 'Block_2a_act_2/Relu:0' assert self.keras_name_part_split(block_2a.name, 'Block_2a_act_2/Relu:0')
assert block_2b.name == 'Block_2b_act_2_tanh/Tanh:0' assert self.keras_name_part_split(block_2b.name, 'Block_2b_act_2_tanh/Tanh:0')
assert block_pool.name == 'Block_2c_act_1/Relu:0' assert self.keras_name_part_split(block_pool.name, 'Block_2c_act_1/Relu:0')
assert self.step_in(block_2a._keras_history[0]).name == "Block_2a_3x3" assert self.step_in(block_2a._keras_history[0]).name == "Block_2a_3x3"
assert self.step_in(block_2b._keras_history[0]).name == "Block_2b_5x5" assert self.step_in(block_2b._keras_history[0]).name == "Block_2b_5x5"
assert isinstance(self.step_in(block_pool._keras_history[0], depth=2), keras.layers.pooling.MaxPooling2D) assert isinstance(self.step_in(block_pool._keras_history[0], depth=2), keras.layers.pooling.MaxPooling2D)
# check naming of concat layer # check naming of concat layer
assert block.name == 'Block_2_Co/concat:0' assert self.keras_name_part_split(block.name, 'Block_2_Co/concat:0')
assert block._keras_history[0].name == 'Block_2_Co' assert block._keras_history[0].name == 'Block_2_Co'
assert isinstance(block._keras_history[0], keras.layers.merge.Concatenate) assert isinstance(block._keras_history[0], keras.layers.merge.Concatenate)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment