Skip to content
Snippets Groups Projects
Commit b5666fa2 authored by Mehdi Cherti's avatar Mehdi Cherti
Browse files

remove Flatten layer in resnet

parent fcc14470
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,6 @@ from tensorflow.keras import layers ...@@ -2,7 +2,6 @@ from tensorflow.keras import layers
from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Conv2D from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import AveragePooling2D from tensorflow.keras.layers import AveragePooling2D
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import MaxPool2D from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.layers import Input from tensorflow.keras.layers import Input
from tensorflow.keras.layers import BatchNormalization from tensorflow.keras.layers import BatchNormalization
...@@ -72,7 +71,6 @@ def ResNet(classes, name, input_shape, block_layers_num, weight_decay): ...@@ -72,7 +71,6 @@ def ResNet(classes, name, input_shape, block_layers_num, weight_decay):
for i in range(block_layers_num - 1): for i in range(block_layers_num - 1):
x = ResidualBlock(x, filters=64, kernel_size=(3, 3), weight_decay=weight_decay, downsample=False) x = ResidualBlock(x, filters=64, kernel_size=(3, 3), weight_decay=weight_decay, downsample=False)
x = GlobalAveragePooling2D()(x) x = GlobalAveragePooling2D()(x)
x = Flatten()(x)
x = Dense(classes, activation='softmax')(x) x = Dense(classes, activation='softmax')(x)
model = Model(input, x, name=name) model = Model(input, x, name=name)
return model return model
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment