Skip to content
Snippets Groups Projects
Commit b5666fa2 authored by Mehdi Cherti's avatar Mehdi Cherti
Browse files

remove Flatten layer in resnet

parent fcc14470
Branches
Tags
No related merge requests found
......@@ -2,7 +2,6 @@ from tensorflow.keras import layers
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import AveragePooling2D
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import BatchNormalization
......@@ -72,7 +71,6 @@ def ResNet(classes, name, input_shape, block_layers_num, weight_decay):
for i in range(block_layers_num - 1):
x = ResidualBlock(x, filters=64, kernel_size=(3, 3), weight_decay=weight_decay, downsample=False)
x = GlobalAveragePooling2D()(x)
x = Flatten()(x)
x = Dense(classes, activation='softmax')(x)
model = Model(input, x, name=name)
return model
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment