from tensorflow.keras import layers
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import AveragePooling2D
from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.layers import Activation
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2


def conv2d_bn(x, filters, kernel_size, weight_decay=.0, strides=(1, 1)):
    """Apply a bias-free, 'same'-padded Conv2D followed by BatchNormalization.

    Args:
        x: Input tensor.
        filters: Number of convolution filters.
        kernel_size: Convolution kernel size (int or tuple).
        weight_decay: L2 regularisation factor applied to the conv kernel.
        strides: Convolution strides (int or tuple).

    Returns:
        The batch-normalised convolution output (no activation applied).
    """
    layer = Conv2D(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding='same',
        use_bias=False,
        kernel_regularizer=l2(weight_decay),
    )(x)
    layer = BatchNormalization()(layer)
    return layer


def conv2d_bn_relu(x, filters, kernel_size, weight_decay=.0, strides=(1, 1)):
    """Apply `conv2d_bn` followed by a ReLU activation.

    Args:
        x: Input tensor.
        filters: Number of convolution filters.
        kernel_size: Convolution kernel size (int or tuple).
        weight_decay: L2 regularisation factor applied to the conv kernel.
        strides: Convolution strides (int or tuple).

    Returns:
        The ReLU-activated, batch-normalised convolution output.
    """
    layer = conv2d_bn(x, filters, kernel_size, weight_decay, strides)
    layer = Activation('relu')(layer)
    return layer


def ResidualBlock(x, filters, kernel_size, weight_decay, downsample=True):
    """Build a two-convolution residual block with an additive shortcut.

    The residual path is conv-BN-ReLU followed by conv-BN. Its output is
    added to the shortcut and the sum is passed through a ReLU.

    Args:
        x: Input tensor.
        filters: Number of filters for every convolution in the block.
        kernel_size: Kernel size of the two residual-path convolutions.
        weight_decay: L2 regularisation factor for every conv kernel in the
            block, including the projection shortcut.
        downsample: If True, the first residual convolution uses stride 2
            and the shortcut is a stride-2 1x1 conv + BN projection.
            If False, the shortcut is the unmodified input and all
            convolutions use stride 1.

    Returns:
        The output tensor of the block.
    """
    if downsample:
        # Projection shortcut. It is regularised with the same weight decay
        # as the other convolutions of the block (previously this call fell
        # back to the default of 0 and the kernel was left unregularised).
        residual_x = conv2d_bn(x, filters, kernel_size=1,
                               weight_decay=weight_decay, strides=2)
        stride = 2
    else:
        residual_x = x
        stride = 1

    residual = conv2d_bn_relu(
        x,
        filters=filters,
        kernel_size=kernel_size,
        weight_decay=weight_decay,
        strides=stride,
    )
    residual = conv2d_bn(
        residual,
        filters=filters,
        kernel_size=kernel_size,
        weight_decay=weight_decay,
        strides=1,
    )

    out = layers.add([residual_x, residual])
    out = Activation('relu')(out)
    return out


def ResNet(classes, name, input_shape, block_layers_num, weight_decay):
    """Build a three-stage residual classification network.

    Layout: a 3x3 conv stem with 16 filters, then three stages of residual
    blocks with 16, 32 and 64 filters, global average pooling and a softmax
    Dense classifier. Stages two and three each open with a downsampling
    block followed by ``block_layers_num - 1`` regular blocks, so with
    ``block_layers_num = n`` the stem, residual-path convolutions and
    classifier add up to ``6 * n + 2`` layers (shortcut projections not
    counted).

    Args:
        classes: Number of output units of the softmax classifier.
        name: Name given to the resulting Keras model.
        input_shape: Shape of a single input sample (without batch axis).
        block_layers_num: Number of residual blocks per stage.
        weight_decay: L2 regularisation factor for the conv kernels.

    Returns:
        The assembled (uncompiled) Keras `Model`.
    """
    inputs = Input(shape=input_shape)

    x = conv2d_bn_relu(inputs, filters=16, kernel_size=(3, 3),
                       weight_decay=weight_decay, strides=(1, 1))

    # Stage 1 (conv 2): 16 filters, no downsampling.
    for _ in range(block_layers_num):
        x = ResidualBlock(x, filters=16, kernel_size=(3, 3),
                          weight_decay=weight_decay, downsample=False)

    # Stage 2 (conv 3): 32 filters, first block downsamples.
    x = ResidualBlock(x, filters=32, kernel_size=(3, 3),
                      weight_decay=weight_decay, downsample=True)
    for _ in range(block_layers_num - 1):
        x = ResidualBlock(x, filters=32, kernel_size=(3, 3),
                          weight_decay=weight_decay, downsample=False)

    # Stage 3 (conv 4): 64 filters, first block downsamples.
    x = ResidualBlock(x, filters=64, kernel_size=(3, 3),
                      weight_decay=weight_decay, downsample=True)
    for _ in range(block_layers_num - 1):
        x = ResidualBlock(x, filters=64, kernel_size=(3, 3),
                          weight_decay=weight_decay, downsample=False)

    x = GlobalAveragePooling2D()(x)
    x = Dense(classes, activation='softmax')(x)

    model = Model(inputs, x, name=name)
    return model


def ResNet20(classes, input_shape, weight_decay):
    """Build the 'resnet20' model: `ResNet` with 3 residual blocks per stage."""
    return ResNet(classes, 'resnet20', input_shape, 3, weight_decay)


def ResNet32(classes, input_shape, weight_decay):
    """Build the 'resnet32' model: `ResNet` with 5 residual blocks per stage."""
    return ResNet(classes, 'resnet32', input_shape, 5, weight_decay)


def ResNet56(classes, input_shape, weight_decay):
    """Build the 'resnet56' model: `ResNet` with 9 residual blocks per stage."""
    return ResNet(classes, 'resnet56', input_shape, 9, weight_decay)