63. Transfer Learning for Animal Classification
63.1. Introduction
In the image classification experiment, we learned the concept of transfer learning and completed a cat and dog recognition task using the pre-trained AlexNet model. In this challenge, we use TensorFlow Keras to train a transfer learning model for animal classification.
63.2. Key Points
Transfer Learning
Pre-trained Model
TensorFlow Keras
Note: The first half of this challenge can be completed by following the methods from the experiment.
This challenge uses the animal image dataset we provide, which contains three kinds of animals: cats, dogs, and horses.
```bash
# Download the data file from the course image server
wget -nc "https://cdn.aibydoing.com/aibydoing/files/transfer-train.zip"
unzip -o "transfer-train.zip"
```
After extraction, there are only 50 images for each animal. We want to see how well transfer learning works on an extremely small dataset.
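A quick way to confirm the per-class image count is to walk the extracted folder. The following is a minimal sketch that assumes the archive unpacks into one subdirectory per animal; the helper name `count_images_per_class` and the directory name in the commented usage line are hypothetical and should be adapted to whatever `unzip` actually creates.

```python
from pathlib import Path

# File extensions treated as images (an assumption about the dataset).
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images_per_class(root):
    """Count image files in each immediate subdirectory of `root`."""
    counts = {}
    for class_dir in sorted(Path(root).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1
                for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
            )
    return counts


# Hypothetical usage; replace the path with the directory created by unzip:
# print(count_images_per_class("transfer-train"))
```

If the layout matches this assumption, each class should report 50 images.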
The official Keras documentation lists the pre-trained models it provides. We choose MobileNetV2, which is relatively small and performs well. Its ImageNet-trained weights file is about 14 MB.
Keras downloads the pre-trained weights the first time the model is loaded, which can be slow:
import tensorflow as tf
# Read and load the model
mobilenet = tf.keras.applications.MobileNetV2()
mobilenet.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224.h5
14536120/14536120 [==============================] - 2s 0us/step
Model: "mobilenetv2_1.00_224"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 224, 224, 3)] 0 []
Conv1 (Conv2D) (None, 112, 112, 32) 864 ['input_1[0][0]']
bn_Conv1 (BatchNormalizati (None, 112, 112, 32) 128 ['Conv1[0][0]']
on)
Conv1_relu (ReLU) (None, 112, 112, 32) 0 ['bn_Conv1[0][0]']
expanded_conv_depthwise (D (None, 112, 112, 32) 288 ['Conv1_relu[0][0]']
epthwiseConv2D)
expanded_conv_depthwise_BN (None, 112, 112, 32) 128 ['expanded_conv_depthwise[0][0
(BatchNormalization) ]']
expanded_conv_depthwise_re (None, 112, 112, 32) 0 ['expanded_conv_depthwise_BN[0
lu (ReLU) ][0]']
expanded_conv_project (Con (None, 112, 112, 16) 512 ['expanded_conv_depthwise_relu
v2D) [0][0]']
expanded_conv_project_BN ( (None, 112, 112, 16) 64 ['expanded_conv_project[0][0]'
BatchNormalization) ]
block_1_expand (Conv2D) (None, 112, 112, 96) 1536 ['expanded_conv_project_BN[0][
0]']
block_1_expand_BN (BatchNo (None, 112, 112, 96) 384 ['block_1_expand[0][0]']
rmalization)
block_1_expand_relu (ReLU) (None, 112, 112, 96) 0 ['block_1_expand_BN[0][0]']
block_1_pad (ZeroPadding2D (None, 113, 113, 96) 0 ['block_1_expand_relu[0][0]']
)
block_1_depthwise (Depthwi (None, 56, 56, 96) 864 ['block_1_pad[0][0]']
seConv2D)
block_1_depthwise_BN (Batc (None, 56, 56, 96) 384 ['block_1_depthwise[0][0]']
hNormalization)
block_1_depthwise_relu (Re (None, 56, 56, 96) 0 ['block_1_depthwise_BN[0][0]']
LU)
block_1_project (Conv2D) (None, 56, 56, 24) 2304 ['block_1_depthwise_relu[0][0]
']
block_1_project_BN (BatchN (None, 56, 56, 24) 96 ['block_1_project[0][0]']
ormalization)
block_2_expand (Conv2D) (None, 56, 56, 144) 3456 ['block_1_project_BN[0][0]']
block_2_expand_BN (BatchNo (None, 56, 56, 144) 576 ['block_2_expand[0][0]']
rmalization)
block_2_expand_relu (ReLU) (None, 56, 56, 144) 0 ['block_2_expand_BN[0][0]']
block_2_depthwise (Depthwi (None, 56, 56, 144) 1296 ['block_2_expand_relu[0][0]']
seConv2D)
block_2_depthwise_BN (Batc (None, 56, 56, 144) 576 ['block_2_depthwise[0][0]']
hNormalization)
block_2_depthwise_relu (Re (None, 56, 56, 144) 0 ['block_2_depthwise_BN[0][0]']
LU)
block_2_project (Conv2D) (None, 56, 56, 24) 3456 ['block_2_depthwise_relu[0][0]
']
block_2_project_BN (BatchN (None, 56, 56, 24) 96 ['block_2_project[0][0]']
ormalization)
block_2_add (Add) (None, 56, 56, 24) 0 ['block_1_project_BN[0][0]',
'block_2_project_BN[0][0]']
block_3_expand (Conv2D) (None, 56, 56, 144) 3456 ['block_2_add[0][0]']
block_3_expand_BN (BatchNo (None, 56, 56, 144) 576 ['block_3_expand[0][0]']
rmalization)
block_3_expand_relu (ReLU) (None, 56, 56, 144) 0 ['block_3_expand_BN[0][0]']
block_3_pad (ZeroPadding2D (None, 57, 57, 144) 0 ['block_3_expand_relu[0][0]']
)
block_3_depthwise (Depthwi (None, 28, 28, 144) 1296 ['block_3_pad[0][0]']
seConv2D)
block_3_depthwise_BN (Batc (None, 28, 28, 144) 576 ['block_3_depthwise[0][0]']
hNormalization)
block_3_depthwise_relu (Re (None, 28, 28, 144) 0 ['block_3_depthwise_BN[0][0]']
LU)
block_3_project (Conv2D) (None, 28, 28, 32) 4608 ['block_3_depthwise_relu[0][0]
']
block_3_project_BN (BatchN (None, 28, 28, 32) 128 ['block_3_project[0][0]']
ormalization)
block_4_expand (Conv2D) (None, 28, 28, 192) 6144 ['block_3_project_BN[0][0]']
block_4_expand_BN (BatchNo (None, 28, 28, 192) 768 ['block_4_expand[0][0]']
rmalization)
block_4_expand_relu (ReLU) (None, 28, 28, 192) 0 ['block_4_expand_BN[0][0]']
block_4_depthwise (Depthwi (None, 28, 28, 192) 1728 ['block_4_expand_relu[0][0]']
seConv2D)
block_4_depthwise_BN (Batc (None, 28, 28, 192) 768 ['block_4_depthwise[0][0]']
hNormalization)
block_4_depthwise_relu (Re (None, 28, 28, 192) 0 ['block_4_depthwise_BN[0][0]']
LU)
block_4_project (Conv2D) (None, 28, 28, 32) 6144 ['block_4_depthwise_relu[0][0]
']
block_4_project_BN (BatchN (None, 28, 28, 32) 128 ['block_4_project[0][0]']
ormalization)
block_4_add (Add) (None, 28, 28, 32) 0 ['block_3_project_BN[0][0]',
'block_4_project_BN[0][0]']
block_5_expand (Conv2D) (None, 28, 28, 192) 6144 ['block_4_add[0][0]']
block_5_expand_BN (BatchNo (None, 28, 28, 192) 768 ['block_5_expand[0][0]']
rmalization)
block_5_expand_relu (ReLU) (None, 28, 28, 192) 0 ['block_5_expand_BN[0][0]']
block_5_depthwise (Depthwi (None, 28, 28, 192) 1728 ['block_5_expand_relu[0][0]']
seConv2D)
block_5_depthwise_BN (Batc (None, 28, 28, 192) 768 ['block_5_depthwise[0][0]']
hNormalization)
block_5_depthwise_relu (Re (None, 28, 28, 192) 0 ['block_5_depthwise_BN[0][0]']
LU)
block_5_project (Conv2D) (None, 28, 28, 32) 6144 ['block_5_depthwise_relu[0][0]
']
block_5_project_BN (BatchN (None, 28, 28, 32) 128 ['block_5_project[0][0]']
ormalization)
block_5_add (Add) (None, 28, 28, 32) 0 ['block_4_add[0][0]',
'block_5_project_BN[0][0]']
block_6_expand (Conv2D) (None, 28, 28, 192) 6144 ['block_5_add[0][0]']
block_6_expand_BN (BatchNo (None, 28, 28, 192) 768 ['block_6_expand[0][0]']
rmalization)
block_6_expand_relu (ReLU) (None, 28, 28, 192) 0 ['block_6_expand_BN[0][0]']
block_6_pad (ZeroPadding2D (None, 29, 29, 192) 0 ['block_6_expand_relu[0][0]']
)
block_6_depthwise (Depthwi (None, 14, 14, 192) 1728 ['block_6_pad[0][0]']
seConv2D)
block_6_depthwise_BN (Batc (None, 14, 14, 192) 768 ['block_6_depthwise[0][0]']
hNormalization)
block_6_depthwise_relu (Re (None, 14, 14, 192) 0 ['block_6_depthwise_BN[0][0]']
LU)
block_6_project (Conv2D) (None, 14, 14, 64) 12288 ['block_6_depthwise_relu[0][0]
']
block_6_project_BN (BatchN (None, 14, 14, 64) 256 ['block_6_project[0][0]']
ormalization)
block_7_expand (Conv2D) (None, 14, 14, 384) 24576 ['block_6_project_BN[0][0]']
block_7_expand_BN (BatchNo (None, 14, 14, 384) 1536 ['block_7_expand[0][0]']
rmalization)
block_7_expand_relu (ReLU) (None, 14, 14, 384) 0 ['block_7_expand_BN[0][0]']
block_7_depthwise (Depthwi (None, 14, 14, 384) 3456 ['block_7_expand_relu[0][0]']
seConv2D)
block_7_depthwise_BN (Batc (None, 14, 14, 384) 1536 ['block_7_depthwise[0][0]']
hNormalization)
block_7_depthwise_relu (Re (None, 14, 14, 384) 0 ['block_7_depthwise_BN[0][0]']
LU)
block_7_project (Conv2D) (None, 14, 14, 64) 24576 ['block_7_depthwise_relu[0][0]
']
block_7_project_BN (BatchN (None, 14, 14, 64) 256 ['block_7_project[0][0]']
ormalization)
block_7_add (Add) (None, 14, 14, 64) 0 ['block_6_project_BN[0][0]',
'block_7_project_BN[0][0]']
block_8_expand (Conv2D) (None, 14, 14, 384) 24576 ['block_7_add[0][0]']
block_8_expand_BN (BatchNo (None, 14, 14, 384) 1536 ['block_8_expand[0][0]']
rmalization)
block_8_expand_relu (ReLU) (None, 14, 14, 384) 0 ['block_8_expand_BN[0][0]']
block_8_depthwise (Depthwi (None, 14, 14, 384) 3456 ['block_8_expand_relu[0][0]']
seConv2D)
block_8_depthwise_BN (Batc (None, 14, 14, 384) 1536 ['block_8_depthwise[0][0]']
hNormalization)
block_8_depthwise_relu (Re (None, 14, 14, 384) 0 ['block_8_depthwise_BN[0][0]']
LU)
block_8_project (Conv2D) (None, 14, 14, 64) 24576 ['block_8_depthwise_relu[0][0]
']
block_8_project_BN (BatchN (None, 14, 14, 64) 256 ['block_8_project[0][0]']
ormalization)
block_8_add (Add) (None, 14, 14, 64) 0 ['block_7_add[0][0]',
'block_8_project_BN[0][0]']
block_9_expand (Conv2D) (None, 14, 14, 384) 24576 ['block_8_add[0][0]']
block_9_expand_BN (BatchNo (None, 14, 14, 384) 1536 ['block_9_expand[0][0]']
rmalization)
block_9_expand_relu (ReLU) (None, 14, 14, 384) 0 ['block_9_expand_BN[0][0]']
block_9_depthwise (Depthwi (None, 14, 14, 384) 3456 ['block_9_expand_relu[0][0]']
seConv2D)
block_9_depthwise_BN (Batc (None, 14, 14, 384) 1536 ['block_9_depthwise[0][0]']
hNormalization)
block_9_depthwise_relu (Re (None, 14, 14, 384) 0 ['block_9_depthwise_BN[0][0]']
LU)
block_9_project (Conv2D) (None, 14, 14, 64) 24576 ['block_9_depthwise_relu[0][0]
']
block_9_project_BN (BatchN (None, 14, 14, 64) 256 ['block_9_project[0][0]']
ormalization)
block_9_add (Add) (None, 14, 14, 64) 0 ['block_8_add[0][0]',
'block_9_project_BN[0][0]']
block_10_expand (Conv2D) (None, 14, 14, 384) 24576 ['block_9_add[0][0]']
block_10_expand_BN (BatchN (None, 14, 14, 384) 1536 ['block_10_expand[0][0]']
ormalization)
block_10_expand_relu (ReLU (None, 14, 14, 384) 0 ['block_10_expand_BN[0][0]']
)
block_10_depthwise (Depthw (None, 14, 14, 384) 3456 ['block_10_expand_relu[0][0]']
iseConv2D)
block_10_depthwise_BN (Bat (None, 14, 14, 384) 1536 ['block_10_depthwise[0][0]']
chNormalization)
block_10_depthwise_relu (R (None, 14, 14, 384) 0 ['block_10_depthwise_BN[0][0]'
eLU) ]
block_10_project (Conv2D) (None, 14, 14, 96) 36864 ['block_10_depthwise_relu[0][0
]']
block_10_project_BN (Batch (None, 14, 14, 96) 384 ['block_10_project[0][0]']
Normalization)
block_11_expand (Conv2D) (None, 14, 14, 576) 55296 ['block_10_project_BN[0][0]']
block_11_expand_BN (BatchN (None, 14, 14, 576) 2304 ['block_11_expand[0][0]']
ormalization)
block_11_expand_relu (ReLU (None, 14, 14, 576) 0 ['block_11_expand_BN[0][0]']
)
block_11_depthwise (Depthw (None, 14, 14, 576) 5184 ['block_11_expand_relu[0][0]']
iseConv2D)
block_11_depthwise_BN (Bat (None, 14, 14, 576) 2304 ['block_11_depthwise[0][0]']
chNormalization)
block_11_depthwise_relu (R (None, 14, 14, 576) 0 ['block_11_depthwise_BN[0][0]'
eLU) ]
block_11_project (Conv2D) (None, 14, 14, 96) 55296 ['block_11_depthwise_relu[0][0
]']
block_11_project_BN (Batch (None, 14, 14, 96) 384 ['block_11_project[0][0]']
Normalization)
block_11_add (Add) (None, 14, 14, 96) 0 ['block_10_project_BN[0][0]',
'block_11_project_BN[0][0]']
block_12_expand (Conv2D) (None, 14, 14, 576) 55296 ['block_11_add[0][0]']
block_12_expand_BN (BatchN (None, 14, 14, 576) 2304 ['block_12_expand[0][0]']
ormalization)
block_12_expand_relu (ReLU (None, 14, 14, 576) 0 ['block_12_expand_BN[0][0]']
)
block_12_depthwise (Depthw (None, 14, 14, 576) 5184 ['block_12_expand_relu[0][0]']
iseConv2D)
block_12_depthwise_BN (Bat (None, 14, 14, 576) 2304 ['block_12_depthwise[0][0]']
chNormalization)
block_12_depthwise_relu (R (None, 14, 14, 576) 0 ['block_12_depthwise_BN[0][0]'
eLU) ]
block_12_project (Conv2D) (None, 14, 14, 96) 55296 ['block_12_depthwise_relu[0][0
]']
block_12_project_BN (Batch (None, 14, 14, 96) 384 ['block_12_project[0][0]']
Normalization)
block_12_add (Add) (None, 14, 14, 96) 0 ['block_11_add[0][0]',
'block_12_project_BN[0][0]']
block_13_expand (Conv2D) (None, 14, 14, 576) 55296 ['block_12_add[0][0]']
block_13_expand_BN (BatchN (None, 14, 14, 576) 2304 ['block_13_expand[0][0]']
ormalization)
block_13_expand_relu (ReLU (None, 14, 14, 576) 0 ['block_13_expand_BN[0][0]']
)
block_13_pad (ZeroPadding2 (None, 15, 15, 576) 0 ['block_13_expand_relu[0][0]']
D)
block_13_depthwise (Depthw (None, 7, 7, 576) 5184 ['block_13_pad[0][0]']
iseConv2D)
block_13_depthwise_BN (Bat (None, 7, 7, 576) 2304 ['block_13_depthwise[0][0]']
chNormalization)
block_13_depthwise_relu (R (None, 7, 7, 576) 0 ['block_13_depthwise_BN[0][0]'
eLU) ]
block_13_project (Conv2D) (None, 7, 7, 160) 92160 ['block_13_depthwise_relu[0][0
]']
block_13_project_BN (Batch (None, 7, 7, 160) 640 ['block_13_project[0][0]']
Normalization)
block_14_expand (Conv2D) (None, 7, 7, 960) 153600 ['block_13_project_BN[0][0]']
block_14_expand_BN (BatchN (None, 7, 7, 960) 3840 ['block_14_expand[0][0]']
ormalization)
block_14_expand_relu (ReLU (None, 7, 7, 960) 0 ['block_14_expand_BN[0][0]']
)
block_14_depthwise (Depthw (None, 7, 7, 960) 8640 ['block_14_expand_relu[0][0]']
iseConv2D)
block_14_depthwise_BN (Bat (None, 7, 7, 960) 3840 ['block_14_depthwise[0][0]']
chNormalization)
block_14_depthwise_relu (R (None, 7, 7, 960) 0 ['block_14_depthwise_BN[0][0]'
eLU) ]
block_14_project (Conv2D) (None, 7, 7, 160) 153600 ['block_14_depthwise_relu[0][0
]']
block_14_project_BN (Batch (None, 7, 7, 160) 640 ['block_14_project[0][0]']
Normalization)
block_14_add (Add) (None, 7, 7, 160) 0 ['block_13_project_BN[0][0]',
'block_14_project_BN[0][0]']
block_15_expand (Conv2D) (None, 7, 7, 960) 153600 ['block_14_add[0][0]']
block_15_expand_BN (BatchN (None, 7, 7, 960) 3840 ['block_15_expand[0][0]']
ormalization)
block_15_expand_relu (ReLU (None, 7, 7, 960) 0 ['block_15_expand_BN[0][0]']
)
block_15_depthwise (Depthw (None, 7, 7, 960) 8640 ['block_15_expand_relu[0][0]']
iseConv2D)
block_15_depthwise_BN (Bat (None, 7, 7, 960) 3840 ['block_15_depthwise[0][0]']
chNormalization)
block_15_depthwise_relu (R (None, 7, 7, 960) 0 ['block_15_depthwise_BN[0][0]'
eLU) ]
block_15_project (Conv2D) (None, 7, 7, 160) 153600 ['block_15_depthwise_relu[0][0
]']
block_15_project_BN (Batch (None, 7, 7, 160) 640 ['block_15_project[0][0]']
Normalization)
block_15_add (Add) (None, 7, 7, 160) 0 ['block_14_add[0][0]',
'block_15_project_BN[0][0]']
block_16_expand (Conv2D) (None, 7, 7, 960) 153600 ['block_15_add[0][0]']
block_16_expand_BN (BatchN (None, 7, 7, 960) 3840 ['block_16_expand[0][0]']
ormalization)
block_16_expand_relu (ReLU (None, 7, 7, 960) 0 ['block_16_expand_BN[0][0]']
)
block_16_depthwise (Depthw (None, 7, 7, 960) 8640 ['block_16_expand_relu[0][0]']
iseConv2D)
block_16_depthwise_BN (Bat (None, 7, 7, 960) 3840 ['block_16_depthwise[0][0]']
chNormalization)
block_16_depthwise_relu (R (None, 7, 7, 960) 0 ['block_16_depthwise_BN[0][0]'
eLU) ]
block_16_project (Conv2D) (None, 7, 7, 320) 307200 ['block_16_depthwise_relu[0][0
]']
block_16_project_BN (Batch (None, 7, 7, 320) 1280 ['block_16_project[0][0]']
Normalization)
Conv_1 (Conv2D) (None, 7, 7, 1280) 409600 ['block_16_project_BN[0][0]']
Conv_1_bn (BatchNormalizat (None, 7, 7, 1280) 5120 ['Conv_1[0][0]']
ion)
out_relu (ReLU) (None, 7, 7, 1280) 0 ['Conv_1_bn[0][0]']
global_average_pooling2d ( (None, 1280) 0 ['out_relu[0][0]']
GlobalAveragePooling2D)
predictions (Dense) (None, 1000) 1281000 ['global_average_pooling2d[0][
0]']
==================================================================================================
Total params: 3538984 (13.50 MB)
Trainable params: 3504872 (13.37 MB)
Non-trainable params: 34112 (133.25 KB)
__________________________________________________________________________________________________
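The sizes printed in the summary follow directly from the parameter counts. As a small worked check, assuming every parameter is stored as a 4-byte float32 (the Keras default) and that "MB" and "KB" in the summary are binary units, the arithmetic looks like this:

```python
BYTES_PER_FLOAT32 = 4  # assumption: weights are stored as float32


def size_mib(n_params):
    """Memory taken by n_params float32 values, in units of 1024**2 bytes."""
    return n_params * BYTES_PER_FLOAT32 / 1024**2


# Parameter counts copied from the summary above.
total, trainable, non_trainable = 3538984, 3504872, 34112

print(f"Total:         {size_mib(total):.2f} MB")      # 13.50 MB
print(f"Trainable:     {size_mib(trainable):.2f} MB")  # 13.37 MB
print(f"Non-trainable: {non_trainable * BYTES_PER_FLOAT32 / 1024:.2f} KB")  # 133.25 KB

# The final Dense layer maps 1280 features to 1000 classes: weights plus biases.
print(1280 * 1000 + 1000)  # 1281000, the "predictions" row in the summary
```

The classifier alone accounts for more than a third of all parameters, which is one reason to drop it when only three classes are needed.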
The final output of the model has shape `(None, 1000)`, which corresponds to the 1000 ImageNet classes. This challenge is a 3-class problem, so the output layer needs to be changed. With a TensorFlow Keras pre-trained model we could simply append a fully connected layer with 3 outputs, but the original output layer already applies a Softmax activation, so the new layer would be fed class probabilities rather than features. We therefore prefer to replace the output layer itself.
With fine-tuning in mind, TensorFlow Keras provides pre-trained models without the classifier (output layer). You only need to pass `include_top=False` when loading the model.
# Read and load the model
mobilenet_notop = tf.keras.applications.MobileNetV2(include_top=False)
mobilenet_notop.summary()
WARNING:tensorflow:`input_shape` is undefined or non-square, or `rows` is not in [96, 128, 160, 192, 224]. Weights for input shape (224, 224) will be loaded as the default.
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
9406464/9406464 [==============================] - 1s 0us/step
Model: "mobilenetv2_1.00_224"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_2 (InputLayer) [(None, None, None, 3)] 0 []
Conv1 (Conv2D) (None, None, None, 32) 864 ['input_2[0][0]']
bn_Conv1 (BatchNormalizati (None, None, None, 32) 128 ['Conv1[0][0]']
on)
Conv1_relu (ReLU) (None, None, None, 32) 0 ['bn_Conv1[0][0]']
expanded_conv_depthwise (D (None, None, None, 32) 288 ['Conv1_relu[0][0]']
epthwiseConv2D)
expanded_conv_depthwise_BN (None, None, None, 32) 128 ['expanded_conv_depthwise[0][0
(BatchNormalization) ]']
expanded_conv_depthwise_re (None, None, None, 32) 0 ['expanded_conv_depthwise_BN[0
lu (ReLU) ][0]']
expanded_conv_project (Con (None, None, None, 16) 512 ['expanded_conv_depthwise_relu
v2D) [0][0]']
expanded_conv_project_BN ( (None, None, None, 16) 64 ['expanded_conv_project[0][0]'
BatchNormalization) ]
block_1_expand (Conv2D) (None, None, None, 96) 1536 ['expanded_conv_project_BN[0][
0]']
block_1_expand_BN (BatchNo (None, None, None, 96) 384 ['block_1_expand[0][0]']
rmalization)
block_1_expand_relu (ReLU) (None, None, None, 96) 0 ['block_1_expand_BN[0][0]']
block_1_pad (ZeroPadding2D (None, None, None, 96) 0 ['block_1_expand_relu[0][0]']
)
block_1_depthwise (Depthwi (None, None, None, 96) 864 ['block_1_pad[0][0]']
seConv2D)
block_1_depthwise_BN (Batc (None, None, None, 96) 384 ['block_1_depthwise[0][0]']
hNormalization)
block_1_depthwise_relu (Re (None, None, None, 96) 0 ['block_1_depthwise_BN[0][0]']
LU)
block_1_project (Conv2D) (None, None, None, 24) 2304 ['block_1_depthwise_relu[0][0]
']
block_1_project_BN (BatchN (None, None, None, 24) 96 ['block_1_project[0][0]']
ormalization)
block_2_expand (Conv2D) (None, None, None, 144) 3456 ['block_1_project_BN[0][0]']
block_2_expand_BN (BatchNo (None, None, None, 144) 576 ['block_2_expand[0][0]']
rmalization)
block_2_expand_relu (ReLU) (None, None, None, 144) 0 ['block_2_expand_BN[0][0]']
block_2_depthwise (Depthwi (None, None, None, 144) 1296 ['block_2_expand_relu[0][0]']
seConv2D)
block_2_depthwise_BN (Batc (None, None, None, 144) 576 ['block_2_depthwise[0][0]']
hNormalization)
block_2_depthwise_relu (Re (None, None, None, 144) 0 ['block_2_depthwise_BN[0][0]']
LU)
block_2_project (Conv2D) (None, None, None, 24) 3456 ['block_2_depthwise_relu[0][0]
']
block_2_project_BN (BatchN (None, None, None, 24) 96 ['block_2_project[0][0]']
ormalization)
block_2_add (Add) (None, None, None, 24) 0 ['block_1_project_BN[0][0]',
'block_2_project_BN[0][0]']
block_3_expand (Conv2D) (None, None, None, 144) 3456 ['block_2_add[0][0]']
block_3_expand_BN (BatchNo (None, None, None, 144) 576 ['block_3_expand[0][0]']
rmalization)
block_3_expand_relu (ReLU) (None, None, None, 144) 0 ['block_3_expand_BN[0][0]']
block_3_pad (ZeroPadding2D (None, None, None, 144) 0 ['block_3_expand_relu[0][0]']
)
block_3_depthwise (Depthwi (None, None, None, 144) 1296 ['block_3_pad[0][0]']
seConv2D)
block_3_depthwise_BN (Batc (None, None, None, 144) 576 ['block_3_depthwise[0][0]']
hNormalization)
block_3_depthwise_relu (Re (None, None, None, 144) 0 ['block_3_depthwise_BN[0][0]']
LU)
block_3_project (Conv2D) (None, None, None, 32) 4608 ['block_3_depthwise_relu[0][0]
']
block_3_project_BN (BatchN (None, None, None, 32) 128 ['block_3_project[0][0]']
ormalization)
block_4_expand (Conv2D) (None, None, None, 192) 6144 ['block_3_project_BN[0][0]']
block_4_expand_BN (BatchNo (None, None, None, 192) 768 ['block_4_expand[0][0]']
rmalization)
block_4_expand_relu (ReLU) (None, None, None, 192) 0 ['block_4_expand_BN[0][0]']
block_4_depthwise (Depthwi (None, None, None, 192) 1728 ['block_4_expand_relu[0][0]']
seConv2D)
block_4_depthwise_BN (Batc (None, None, None, 192) 768 ['block_4_depthwise[0][0]']
hNormalization)
block_4_depthwise_relu (Re (None, None, None, 192) 0 ['block_4_depthwise_BN[0][0]']
LU)
block_4_project (Conv2D) (None, None, None, 32) 6144 ['block_4_depthwise_relu[0][0]
']
block_4_project_BN (BatchN (None, None, None, 32) 128 ['block_4_project[0][0]']
ormalization)
block_4_add (Add) (None, None, None, 32) 0 ['block_3_project_BN[0][0]',
'block_4_project_BN[0][0]']
block_5_expand (Conv2D) (None, None, None, 192) 6144 ['block_4_add[0][0]']
block_5_expand_BN (BatchNo (None, None, None, 192) 768 ['block_5_expand[0][0]']
rmalization)
block_5_expand_relu (ReLU) (None, None, None, 192) 0 ['block_5_expand_BN[0][0]']
block_5_depthwise (Depthwi (None, None, None, 192) 1728 ['block_5_expand_relu[0][0]']
seConv2D)
block_5_depthwise_BN (Batc (None, None, None, 192) 768 ['block_5_depthwise[0][0]']
hNormalization)
block_5_depthwise_relu (Re (None, None, None, 192) 0 ['block_5_depthwise_BN[0][0]']
LU)
block_5_project (Conv2D) (None, None, None, 32) 6144 ['block_5_depthwise_relu[0][0]
']
block_5_project_BN (BatchN (None, None, None, 32) 128 ['block_5_project[0][0]']
ormalization)
block_5_add (Add) (None, None, None, 32) 0 ['block_4_add[0][0]',
'block_5_project_BN[0][0]']
block_6_expand (Conv2D) (None, None, None, 192) 6144 ['block_5_add[0][0]']
block_6_expand_BN (BatchNo (None, None, None, 192) 768 ['block_6_expand[0][0]']
rmalization)
block_6_expand_relu (ReLU) (None, None, None, 192) 0 ['block_6_expand_BN[0][0]']
block_6_pad (ZeroPadding2D (None, None, None, 192) 0 ['block_6_expand_relu[0][0]']
)
block_6_depthwise (Depthwi (None, None, None, 192) 1728 ['block_6_pad[0][0]']
seConv2D)
block_6_depthwise_BN (Batc (None, None, None, 192) 768 ['block_6_depthwise[0][0]']
hNormalization)
block_6_depthwise_relu (Re (None, None, None, 192) 0 ['block_6_depthwise_BN[0][0]']
LU)
block_6_project (Conv2D) (None, None, None, 64) 12288 ['block_6_depthwise_relu[0][0]
']
block_6_project_BN (BatchN (None, None, None, 64) 256 ['block_6_project[0][0]']
ormalization)
block_7_expand (Conv2D) (None, None, None, 384) 24576 ['block_6_project_BN[0][0]']
block_7_expand_BN (BatchNo (None, None, None, 384) 1536 ['block_7_expand[0][0]']
rmalization)
block_7_expand_relu (ReLU) (None, None, None, 384) 0 ['block_7_expand_BN[0][0]']
block_7_depthwise (Depthwi (None, None, None, 384) 3456 ['block_7_expand_relu[0][0]']
seConv2D)
block_7_depthwise_BN (Batc (None, None, None, 384) 1536 ['block_7_depthwise[0][0]']
hNormalization)
block_7_depthwise_relu (Re (None, None, None, 384) 0 ['block_7_depthwise_BN[0][0]']
LU)
block_7_project (Conv2D) (None, None, None, 64) 24576 ['block_7_depthwise_relu[0][0]
']
block_7_project_BN (BatchN (None, None, None, 64) 256 ['block_7_project[0][0]']
ormalization)
block_7_add (Add) (None, None, None, 64) 0 ['block_6_project_BN[0][0]',
'block_7_project_BN[0][0]']
block_8_expand (Conv2D) (None, None, None, 384) 24576 ['block_7_add[0][0]']
block_8_expand_BN (BatchNo (None, None, None, 384) 1536 ['block_8_expand[0][0]']
rmalization)
block_8_expand_relu (ReLU) (None, None, None, 384) 0 ['block_8_expand_BN[0][0]']
block_8_depthwise (Depthwi (None, None, None, 384) 3456 ['block_8_expand_relu[0][0]']
seConv2D)
block_8_depthwise_BN (Batc (None, None, None, 384) 1536 ['block_8_depthwise[0][0]']
hNormalization)
block_8_depthwise_relu (Re (None, None, None, 384) 0 ['block_8_depthwise_BN[0][0]']
LU)
block_8_project (Conv2D) (None, None, None, 64) 24576 ['block_8_depthwise_relu[0][0]
']
block_8_project_BN (BatchN (None, None, None, 64) 256 ['block_8_project[0][0]']
ormalization)
block_8_add (Add) (None, None, None, 64) 0 ['block_7_add[0][0]',
'block_8_project_BN[0][0]']
block_9_expand (Conv2D) (None, None, None, 384) 24576 ['block_8_add[0][0]']
block_9_expand_BN (BatchNo (None, None, None, 384) 1536 ['block_9_expand[0][0]']
rmalization)
block_9_expand_relu (ReLU) (None, None, None, 384) 0 ['block_9_expand_BN[0][0]']
block_9_depthwise (Depthwi (None, None, None, 384) 3456 ['block_9_expand_relu[0][0]']
seConv2D)
block_9_depthwise_BN (Batc (None, None, None, 384) 1536 ['block_9_depthwise[0][0]']
hNormalization)
block_9_depthwise_relu (Re (None, None, None, 384) 0 ['block_9_depthwise_BN[0][0]']
LU)
block_9_project (Conv2D) (None, None, None, 64) 24576 ['block_9_depthwise_relu[0][0]
']
block_9_project_BN (BatchN (None, None, None, 64) 256 ['block_9_project[0][0]']
ormalization)
block_9_add (Add) (None, None, None, 64) 0 ['block_8_add[0][0]',
'block_9_project_BN[0][0]']
block_10_expand (Conv2D) (None, None, None, 384) 24576 ['block_9_add[0][0]']
block_10_expand_BN (BatchN (None, None, None, 384) 1536 ['block_10_expand[0][0]']
ormalization)
block_10_expand_relu (ReLU (None, None, None, 384) 0 ['block_10_expand_BN[0][0]']
)
block_10_depthwise (Depthw (None, None, None, 384) 3456 ['block_10_expand_relu[0][0]']
iseConv2D)
block_10_depthwise_BN (Bat (None, None, None, 384) 1536 ['block_10_depthwise[0][0]']
chNormalization)
block_10_depthwise_relu (R (None, None, None, 384) 0 ['block_10_depthwise_BN[0][0]'
eLU) ]
block_10_project (Conv2D) (None, None, None, 96) 36864 ['block_10_depthwise_relu[0][0
]']
block_10_project_BN (Batch (None, None, None, 96) 384 ['block_10_project[0][0]']
Normalization)
block_11_expand (Conv2D) (None, None, None, 576) 55296 ['block_10_project_BN[0][0]']
block_11_expand_BN (BatchN (None, None, None, 576) 2304 ['block_11_expand[0][0]']
ormalization)
block_11_expand_relu (ReLU (None, None, None, 576) 0 ['block_11_expand_BN[0][0]']
)
block_11_depthwise (Depthw (None, None, None, 576) 5184 ['block_11_expand_relu[0][0]']
iseConv2D)
block_11_depthwise_BN (Bat (None, None, None, 576) 2304 ['block_11_depthwise[0][0]']
chNormalization)
block_11_depthwise_relu (R (None, None, None, 576) 0 ['block_11_depthwise_BN[0][0]'
eLU) ]
block_11_project (Conv2D) (None, None, None, 96) 55296 ['block_11_depthwise_relu[0][0
]']
block_11_project_BN (Batch (None, None, None, 96) 384 ['block_11_project[0][0]']
Normalization)
block_11_add (Add) (None, None, None, 96) 0 ['block_10_project_BN[0][0]',
'block_11_project_BN[0][0]']
block_12_expand (Conv2D) (None, None, None, 576) 55296 ['block_11_add[0][0]']
block_12_expand_BN (BatchN (None, None, None, 576) 2304 ['block_12_expand[0][0]']
ormalization)
block_12_expand_relu (ReLU (None, None, None, 576) 0 ['block_12_expand_BN[0][0]']
)
block_12_depthwise (Depthw (None, None, None, 576) 5184 ['block_12_expand_relu[0][0]']
iseConv2D)
block_12_depthwise_BN (Bat (None, None, None, 576) 2304 ['block_12_depthwise[0][0]']
chNormalization)
block_12_depthwise_relu (R (None, None, None, 576) 0 ['block_12_depthwise_BN[0][0]'
eLU) ]
block_12_project (Conv2D) (None, None, None, 96) 55296 ['block_12_depthwise_relu[0][0
]']
block_12_project_BN (Batch (None, None, None, 96) 384 ['block_12_project[0][0]']
Normalization)
block_12_add (Add) (None, None, None, 96) 0 ['block_11_add[0][0]',
'block_12_project_BN[0][0]']
block_13_expand (Conv2D) (None, None, None, 576) 55296 ['block_12_add[0][0]']
block_13_expand_BN (BatchN (None, None, None, 576) 2304 ['block_13_expand[0][0]']
ormalization)
block_13_expand_relu (ReLU (None, None, None, 576) 0 ['block_13_expand_BN[0][0]']
)
block_13_pad (ZeroPadding2 (None, None, None, 576) 0 ['block_13_expand_relu[0][0]']
D)
block_13_depthwise (Depthw (None, None, None, 576) 5184 ['block_13_pad[0][0]']
iseConv2D)
block_13_depthwise_BN (Bat (None, None, None, 576) 2304 ['block_13_depthwise[0][0]']
chNormalization)
block_13_depthwise_relu (R (None, None, None, 576) 0 ['block_13_depthwise_BN[0][0]'
eLU) ]
block_13_project (Conv2D) (None, None, None, 160) 92160 ['block_13_depthwise_relu[0][0
]']
block_13_project_BN (Batch (None, None, None, 160) 640 ['block_13_project[0][0]']
Normalization)
block_14_expand (Conv2D) (None, None, None, 960) 153600 ['block_13_project_BN[0][0]']
block_14_expand_BN (BatchN (None, None, None, 960) 3840 ['block_14_expand[0][0]']
ormalization)
block_14_expand_relu (ReLU (None, None, None, 960) 0 ['block_14_expand_BN[0][0]']
)
block_14_depthwise (Depthw (None, None, None, 960) 8640 ['block_14_expand_relu[0][0]']
iseConv2D)
block_14_depthwise_BN (Bat (None, None, None, 960) 3840 ['block_14_depthwise[0][0]']
chNormalization)
block_14_depthwise_relu (R (None, None, None, 960) 0 ['block_14_depthwise_BN[0][0]'
eLU) ]
block_14_project (Conv2D) (None, None, None, 160) 153600 ['block_14_depthwise_relu[0][0
]']
block_14_project_BN (Batch (None, None, None, 160) 640 ['block_14_project[0][0]']
Normalization)
block_14_add (Add) (None, None, None, 160) 0 ['block_13_project_BN[0][0]',
'block_14_project_BN[0][0]']
block_15_expand (Conv2D) (None, None, None, 960) 153600 ['block_14_add[0][0]']
block_15_expand_BN (BatchN (None, None, None, 960) 3840 ['block_15_expand[0][0]']
ormalization)
block_15_expand_relu (ReLU (None, None, None, 960) 0 ['block_15_expand_BN[0][0]']
)
block_15_depthwise (Depthw (None, None, None, 960) 8640 ['block_15_expand_relu[0][0]']
iseConv2D)
block_15_depthwise_BN (Bat (None, None, None, 960) 3840 ['block_15_depthwise[0][0]']
chNormalization)
block_15_depthwise_relu (R (None, None, None, 960) 0 ['block_15_depthwise_BN[0][0]'
eLU) ]
block_15_project (Conv2D) (None, None, None, 160) 153600 ['block_15_depthwise_relu[0][0
]']
block_15_project_BN (Batch (None, None, None, 160) 640 ['block_15_project[0][0]']
Normalization)
block_15_add (Add) (None, None, None, 160) 0 ['block_14_add[0][0]',
'block_15_project_BN[0][0]']
block_16_expand (Conv2D) (None, None, None, 960) 153600 ['block_15_add[0][0]']
block_16_expand_BN (BatchN (None, None, None, 960) 3840 ['block_16_expand[0][0]']
ormalization)
block_16_expand_relu (ReLU (None, None, None, 960) 0 ['block_16_expand_BN[0][0]']
)
block_16_depthwise (Depthw (None, None, None, 960) 8640 ['block_16_expand_relu[0][0]']
iseConv2D)
block_16_depthwise_BN (Bat (None, None, None, 960) 3840 ['block_16_depthwise[0][0]']
chNormalization)
block_16_depthwise_relu (R (None, None, None, 960) 0 ['block_16_depthwise_BN[0][0]'
eLU) ]
block_16_project (Conv2D) (None, None, None, 320) 307200 ['block_16_depthwise_relu[0][0
]']
block_16_project_BN (Batch (None, None, None, 320) 1280 ['block_16_project[0][0]']
Normalization)
Conv_1 (Conv2D) (None, None, None, 1280) 409600 ['block_16_project_BN[0][0]']
Conv_1_bn (BatchNormalizat (None, None, None, 1280) 5120 ['Conv_1[0][0]']
ion)
out_relu (ReLU) (None, None, None, 1280) 0 ['Conv_1_bn[0][0]']
==================================================================================================
Total params: 2257984 (8.61 MB)
Trainable params: 2223872 (8.48 MB)
Non-trainable params: 34112 (133.25 KB)
__________________________________________________________________________________________________
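The `None` spatial dimensions in the summary above, and the warning printed before it, both come from not specifying an input size. A minimal sketch, assuming 224x224 RGB inputs, passes `input_shape` explicitly so the feature-map shape becomes concrete. Here `weights=None` is used only to skip the download in this illustration; for actual transfer learning the default `weights="imagenet"` would be kept.

```python
import tensorflow as tf

# weights=None builds the architecture with random weights and avoids the
# download; it is used here purely to inspect shapes and parameter counts.
backbone = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3),
    include_top=False,
    weights=None,
)

print(backbone.output_shape)    # (None, 7, 7, 1280)
print(backbone.count_params())  # 2257984, matching the summary above
```

The 7x7x1280 feature map is what the pooling and output layers added next will consume.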
Next, we can simply add a custom output layer after `mobilenet_notop`. Compared with the original complete model, we also need to add a `GlobalAveragePooling2D` layer, because `include_top=False` removes the pooling layer along with the classifier.
pool = tf.keras.layers.GlobalAveragePooling2D()(mobilenet_notop.output)
outs = tf.keras.layers.Dense(3, activation="softmax")(pool)
outs
<KerasTensor: shape=(None, 3) dtype=float32 (created by layer 'dense')>
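To see what these two layers do to the data, the following sketch applies them to a random tensor standing in for the backbone output. The 7x7 spatial size assumes 224x224 inputs, and the variable names are illustrative only.

```python
import numpy as np
import tensorflow as tf

# Stand-in for backbone features of a batch of 2 images: (batch, 7, 7, 1280).
features = tf.constant(np.random.rand(2, 7, 7, 1280).astype("float32"))

pool_layer = tf.keras.layers.GlobalAveragePooling2D()
dense_layer = tf.keras.layers.Dense(3, activation="softmax")

pooled = pool_layer(features)  # average over the 7x7 grid -> shape (2, 1280)
probs = dense_layer(pooled)    # three class probabilities -> shape (2, 3)

print(pooled.shape, probs.shape)
print(probs.numpy().sum(axis=1))   # each row sums to 1 because of softmax
print(dense_layer.count_params())  # 1280 * 3 weights + 3 biases = 3843
```

The new output layer has only 3,843 parameters, compared with 1,281,000 in the original 1000-class classifier.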
Next, we wrap the backbone and the new output layer into a single model.
model = tf.keras.models.Model(inputs=mobilenet_notop.input, outputs=outs)
model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_2 (InputLayer) [(None, None, None, 3)] 0 []
Conv1 (Conv2D) (None, None, None, 32) 864 ['input_2[0][0]']
bn_Conv1 (BatchNormalizati (None, None, None, 32) 128 ['Conv1[0][0]']
on)
Conv1_relu (ReLU) (None, None, None, 32) 0 ['bn_Conv1[0][0]']
expanded_conv_depthwise (D (None, None, None, 32) 288 ['Conv1_relu[0][0]']
epthwiseConv2D)
expanded_conv_depthwise_BN (None, None, None, 32) 128 ['expanded_conv_depthwise[0][0
(BatchNormalization) ]']
expanded_conv_depthwise_re (None, None, None, 32) 0 ['expanded_conv_depthwise_BN[0
lu (ReLU) ][0]']
expanded_conv_project (Con (None, None, None, 16) 512 ['expanded_conv_depthwise_relu
v2D) [0][0]']
expanded_conv_project_BN ( (None, None, None, 16) 64 ['expanded_conv_project[0][0]'
BatchNormalization) ]
block_1_expand (Conv2D) (None, None, None, 96) 1536 ['expanded_conv_project_BN[0][
0]']
block_1_expand_BN (BatchNo (None, None, None, 96) 384 ['block_1_expand[0][0]']
rmalization)
block_1_expand_relu (ReLU) (None, None, None, 96) 0 ['block_1_expand_BN[0][0]']
block_1_pad (ZeroPadding2D (None, None, None, 96) 0 ['block_1_expand_relu[0][0]']
)
block_1_depthwise (Depthwi (None, None, None, 96) 864 ['block_1_pad[0][0]']
seConv2D)
block_1_depthwise_BN (Batc (None, None, None, 96) 384 ['block_1_depthwise[0][0]']
hNormalization)
block_1_depthwise_relu (Re (None, None, None, 96) 0 ['block_1_depthwise_BN[0][0]']
LU)
block_1_project (Conv2D) (None, None, None, 24) 2304 ['block_1_depthwise_relu[0][0]
']
block_1_project_BN (BatchN (None, None, None, 24) 96 ['block_1_project[0][0]']
ormalization)
block_2_expand (Conv2D) (None, None, None, 144) 3456 ['block_1_project_BN[0][0]']
block_2_expand_BN (BatchNo (None, None, None, 144) 576 ['block_2_expand[0][0]']
rmalization)
block_2_expand_relu (ReLU) (None, None, None, 144) 0 ['block_2_expand_BN[0][0]']
block_2_depthwise (Depthwi (None, None, None, 144) 1296 ['block_2_expand_relu[0][0]']
seConv2D)
block_2_depthwise_BN (Batc (None, None, None, 144) 576 ['block_2_depthwise[0][0]']
hNormalization)
block_2_depthwise_relu (Re (None, None, None, 144) 0 ['block_2_depthwise_BN[0][0]']
LU)
block_2_project (Conv2D) (None, None, None, 24) 3456 ['block_2_depthwise_relu[0][0]
']
block_2_project_BN (BatchN (None, None, None, 24) 96 ['block_2_project[0][0]']
ormalization)
block_2_add (Add) (None, None, None, 24) 0 ['block_1_project_BN[0][0]',
'block_2_project_BN[0][0]']
block_3_expand (Conv2D) (None, None, None, 144) 3456 ['block_2_add[0][0]']
block_3_expand_BN (BatchNo (None, None, None, 144) 576 ['block_3_expand[0][0]']
rmalization)
block_3_expand_relu (ReLU) (None, None, None, 144) 0 ['block_3_expand_BN[0][0]']
block_3_pad (ZeroPadding2D (None, None, None, 144) 0 ['block_3_expand_relu[0][0]']
)
block_3_depthwise (Depthwi (None, None, None, 144) 1296 ['block_3_pad[0][0]']
seConv2D)
block_3_depthwise_BN (Batc (None, None, None, 144) 576 ['block_3_depthwise[0][0]']
hNormalization)
block_3_depthwise_relu (Re (None, None, None, 144) 0 ['block_3_depthwise_BN[0][0]']
LU)
block_3_project (Conv2D) (None, None, None, 32) 4608 ['block_3_depthwise_relu[0][0]
']
block_3_project_BN (BatchN (None, None, None, 32) 128 ['block_3_project[0][0]']
ormalization)
block_4_expand (Conv2D) (None, None, None, 192) 6144 ['block_3_project_BN[0][0]']
block_4_expand_BN (BatchNo (None, None, None, 192) 768 ['block_4_expand[0][0]']
rmalization)
block_4_expand_relu (ReLU) (None, None, None, 192) 0 ['block_4_expand_BN[0][0]']
block_4_depthwise (Depthwi (None, None, None, 192) 1728 ['block_4_expand_relu[0][0]']
seConv2D)
block_4_depthwise_BN (Batc (None, None, None, 192) 768 ['block_4_depthwise[0][0]']
hNormalization)
block_4_depthwise_relu (Re (None, None, None, 192) 0 ['block_4_depthwise_BN[0][0]']
LU)
block_4_project (Conv2D) (None, None, None, 32) 6144 ['block_4_depthwise_relu[0][0]
']
block_4_project_BN (BatchN (None, None, None, 32) 128 ['block_4_project[0][0]']
ormalization)
block_4_add (Add) (None, None, None, 32) 0 ['block_3_project_BN[0][0]',
'block_4_project_BN[0][0]']
block_5_expand (Conv2D) (None, None, None, 192) 6144 ['block_4_add[0][0]']
block_5_expand_BN (BatchNo (None, None, None, 192) 768 ['block_5_expand[0][0]']
rmalization)
block_5_expand_relu (ReLU) (None, None, None, 192) 0 ['block_5_expand_BN[0][0]']
block_5_depthwise (Depthwi (None, None, None, 192) 1728 ['block_5_expand_relu[0][0]']
seConv2D)
block_5_depthwise_BN (Batc (None, None, None, 192) 768 ['block_5_depthwise[0][0]']
hNormalization)
block_5_depthwise_relu (Re (None, None, None, 192) 0 ['block_5_depthwise_BN[0][0]']
LU)
block_5_project (Conv2D) (None, None, None, 32) 6144 ['block_5_depthwise_relu[0][0]
']
block_5_project_BN (BatchN (None, None, None, 32) 128 ['block_5_project[0][0]']
ormalization)
block_5_add (Add) (None, None, None, 32) 0 ['block_4_add[0][0]',
'block_5_project_BN[0][0]']
block_6_expand (Conv2D) (None, None, None, 192) 6144 ['block_5_add[0][0]']
block_6_expand_BN (BatchNo (None, None, None, 192) 768 ['block_6_expand[0][0]']
rmalization)
block_6_expand_relu (ReLU) (None, None, None, 192) 0 ['block_6_expand_BN[0][0]']
block_6_pad (ZeroPadding2D (None, None, None, 192) 0 ['block_6_expand_relu[0][0]']
)
block_6_depthwise (Depthwi (None, None, None, 192) 1728 ['block_6_pad[0][0]']
seConv2D)
block_6_depthwise_BN (Batc (None, None, None, 192) 768 ['block_6_depthwise[0][0]']
hNormalization)
block_6_depthwise_relu (Re (None, None, None, 192) 0 ['block_6_depthwise_BN[0][0]']
LU)
block_6_project (Conv2D) (None, None, None, 64) 12288 ['block_6_depthwise_relu[0][0]
']
block_6_project_BN (BatchN (None, None, None, 64) 256 ['block_6_project[0][0]']
ormalization)
block_7_expand (Conv2D) (None, None, None, 384) 24576 ['block_6_project_BN[0][0]']
block_7_expand_BN (BatchNo (None, None, None, 384) 1536 ['block_7_expand[0][0]']
rmalization)
block_7_expand_relu (ReLU) (None, None, None, 384) 0 ['block_7_expand_BN[0][0]']
block_7_depthwise (Depthwi (None, None, None, 384) 3456 ['block_7_expand_relu[0][0]']
seConv2D)
block_7_depthwise_BN (Batc (None, None, None, 384) 1536 ['block_7_depthwise[0][0]']
hNormalization)
block_7_depthwise_relu (Re (None, None, None, 384) 0 ['block_7_depthwise_BN[0][0]']
LU)
block_7_project (Conv2D) (None, None, None, 64) 24576 ['block_7_depthwise_relu[0][0]
']
block_7_project_BN (BatchN (None, None, None, 64) 256 ['block_7_project[0][0]']
ormalization)
block_7_add (Add) (None, None, None, 64) 0 ['block_6_project_BN[0][0]',
'block_7_project_BN[0][0]']
block_8_expand (Conv2D) (None, None, None, 384) 24576 ['block_7_add[0][0]']
block_8_expand_BN (BatchNo (None, None, None, 384) 1536 ['block_8_expand[0][0]']
rmalization)
block_8_expand_relu (ReLU) (None, None, None, 384) 0 ['block_8_expand_BN[0][0]']
block_8_depthwise (Depthwi (None, None, None, 384) 3456 ['block_8_expand_relu[0][0]']
seConv2D)
block_8_depthwise_BN (Batc (None, None, None, 384) 1536 ['block_8_depthwise[0][0]']
hNormalization)
block_8_depthwise_relu (Re (None, None, None, 384) 0 ['block_8_depthwise_BN[0][0]']
LU)
block_8_project (Conv2D) (None, None, None, 64) 24576 ['block_8_depthwise_relu[0][0]
']
block_8_project_BN (BatchN (None, None, None, 64) 256 ['block_8_project[0][0]']
ormalization)
block_8_add (Add) (None, None, None, 64) 0 ['block_7_add[0][0]',
'block_8_project_BN[0][0]']
block_9_expand (Conv2D) (None, None, None, 384) 24576 ['block_8_add[0][0]']
block_9_expand_BN (BatchNo (None, None, None, 384) 1536 ['block_9_expand[0][0]']
rmalization)
block_9_expand_relu (ReLU) (None, None, None, 384) 0 ['block_9_expand_BN[0][0]']
block_9_depthwise (Depthwi (None, None, None, 384) 3456 ['block_9_expand_relu[0][0]']
seConv2D)
block_9_depthwise_BN (Batc (None, None, None, 384) 1536 ['block_9_depthwise[0][0]']
hNormalization)
block_9_depthwise_relu (Re (None, None, None, 384) 0 ['block_9_depthwise_BN[0][0]']
LU)
block_9_project (Conv2D) (None, None, None, 64) 24576 ['block_9_depthwise_relu[0][0]
']
block_9_project_BN (BatchN (None, None, None, 64) 256 ['block_9_project[0][0]']
ormalization)
block_9_add (Add) (None, None, None, 64) 0 ['block_8_add[0][0]',
'block_9_project_BN[0][0]']
block_10_expand (Conv2D) (None, None, None, 384) 24576 ['block_9_add[0][0]']
block_10_expand_BN (BatchN (None, None, None, 384) 1536 ['block_10_expand[0][0]']
ormalization)
block_10_expand_relu (ReLU (None, None, None, 384) 0 ['block_10_expand_BN[0][0]']
)
block_10_depthwise (Depthw (None, None, None, 384) 3456 ['block_10_expand_relu[0][0]']
iseConv2D)
block_10_depthwise_BN (Bat (None, None, None, 384) 1536 ['block_10_depthwise[0][0]']
chNormalization)
block_10_depthwise_relu (R (None, None, None, 384) 0 ['block_10_depthwise_BN[0][0]'
eLU) ]
block_10_project (Conv2D) (None, None, None, 96) 36864 ['block_10_depthwise_relu[0][0
]']
block_10_project_BN (Batch (None, None, None, 96) 384 ['block_10_project[0][0]']
Normalization)
block_11_expand (Conv2D) (None, None, None, 576) 55296 ['block_10_project_BN[0][0]']
block_11_expand_BN (BatchN (None, None, None, 576) 2304 ['block_11_expand[0][0]']
ormalization)
block_11_expand_relu (ReLU (None, None, None, 576) 0 ['block_11_expand_BN[0][0]']
)
block_11_depthwise (Depthw (None, None, None, 576) 5184 ['block_11_expand_relu[0][0]']
iseConv2D)
block_11_depthwise_BN (Bat (None, None, None, 576) 2304 ['block_11_depthwise[0][0]']
chNormalization)
block_11_depthwise_relu (R (None, None, None, 576) 0 ['block_11_depthwise_BN[0][0]'
eLU) ]
block_11_project (Conv2D) (None, None, None, 96) 55296 ['block_11_depthwise_relu[0][0
]']
block_11_project_BN (Batch (None, None, None, 96) 384 ['block_11_project[0][0]']
Normalization)
block_11_add (Add) (None, None, None, 96) 0 ['block_10_project_BN[0][0]',
'block_11_project_BN[0][0]']
block_12_expand (Conv2D) (None, None, None, 576) 55296 ['block_11_add[0][0]']
block_12_expand_BN (BatchN (None, None, None, 576) 2304 ['block_12_expand[0][0]']
ormalization)
block_12_expand_relu (ReLU (None, None, None, 576) 0 ['block_12_expand_BN[0][0]']
)
block_12_depthwise (Depthw (None, None, None, 576) 5184 ['block_12_expand_relu[0][0]']
iseConv2D)
block_12_depthwise_BN (Bat (None, None, None, 576) 2304 ['block_12_depthwise[0][0]']
chNormalization)
block_12_depthwise_relu (R (None, None, None, 576) 0 ['block_12_depthwise_BN[0][0]'
eLU) ]
block_12_project (Conv2D) (None, None, None, 96) 55296 ['block_12_depthwise_relu[0][0
]']
block_12_project_BN (Batch (None, None, None, 96) 384 ['block_12_project[0][0]']
Normalization)
block_12_add (Add) (None, None, None, 96) 0 ['block_11_add[0][0]',
'block_12_project_BN[0][0]']
block_13_expand (Conv2D) (None, None, None, 576) 55296 ['block_12_add[0][0]']
block_13_expand_BN (BatchN (None, None, None, 576) 2304 ['block_13_expand[0][0]']
ormalization)
block_13_expand_relu (ReLU (None, None, None, 576) 0 ['block_13_expand_BN[0][0]']
)
block_13_pad (ZeroPadding2 (None, None, None, 576) 0 ['block_13_expand_relu[0][0]']
D)
block_13_depthwise (Depthw (None, None, None, 576) 5184 ['block_13_pad[0][0]']
iseConv2D)
block_13_depthwise_BN (Bat (None, None, None, 576) 2304 ['block_13_depthwise[0][0]']
chNormalization)
block_13_depthwise_relu (R (None, None, None, 576) 0 ['block_13_depthwise_BN[0][0]'
eLU) ]
block_13_project (Conv2D) (None, None, None, 160) 92160 ['block_13_depthwise_relu[0][0
]']
block_13_project_BN (Batch (None, None, None, 160) 640 ['block_13_project[0][0]']
Normalization)
block_14_expand (Conv2D) (None, None, None, 960) 153600 ['block_13_project_BN[0][0]']
block_14_expand_BN (BatchN (None, None, None, 960) 3840 ['block_14_expand[0][0]']
ormalization)
block_14_expand_relu (ReLU (None, None, None, 960) 0 ['block_14_expand_BN[0][0]']
)
block_14_depthwise (Depthw (None, None, None, 960) 8640 ['block_14_expand_relu[0][0]']
iseConv2D)
block_14_depthwise_BN (Bat (None, None, None, 960) 3840 ['block_14_depthwise[0][0]']
chNormalization)
block_14_depthwise_relu (R (None, None, None, 960) 0 ['block_14_depthwise_BN[0][0]'
eLU) ]
block_14_project (Conv2D) (None, None, None, 160) 153600 ['block_14_depthwise_relu[0][0
]']
block_14_project_BN (Batch (None, None, None, 160) 640 ['block_14_project[0][0]']
Normalization)
block_14_add (Add) (None, None, None, 160) 0 ['block_13_project_BN[0][0]',
'block_14_project_BN[0][0]']
block_15_expand (Conv2D) (None, None, None, 960) 153600 ['block_14_add[0][0]']
block_15_expand_BN (BatchN (None, None, None, 960) 3840 ['block_15_expand[0][0]']
ormalization)
block_15_expand_relu (ReLU (None, None, None, 960) 0 ['block_15_expand_BN[0][0]']
)
block_15_depthwise (Depthw (None, None, None, 960) 8640 ['block_15_expand_relu[0][0]']
iseConv2D)
block_15_depthwise_BN (Bat (None, None, None, 960) 3840 ['block_15_depthwise[0][0]']
chNormalization)
block_15_depthwise_relu (R (None, None, None, 960) 0 ['block_15_depthwise_BN[0][0]'
eLU) ]
block_15_project (Conv2D) (None, None, None, 160) 153600 ['block_15_depthwise_relu[0][0
]']
block_15_project_BN (Batch (None, None, None, 160) 640 ['block_15_project[0][0]']
Normalization)
block_15_add (Add) (None, None, None, 160) 0 ['block_14_add[0][0]',
'block_15_project_BN[0][0]']
block_16_expand (Conv2D) (None, None, None, 960) 153600 ['block_15_add[0][0]']
block_16_expand_BN (BatchN (None, None, None, 960) 3840 ['block_16_expand[0][0]']
ormalization)
block_16_expand_relu (ReLU (None, None, None, 960) 0 ['block_16_expand_BN[0][0]']
)
block_16_depthwise (Depthw (None, None, None, 960) 8640 ['block_16_expand_relu[0][0]']
iseConv2D)
block_16_depthwise_BN (Bat (None, None, None, 960) 3840 ['block_16_depthwise[0][0]']
chNormalization)
block_16_depthwise_relu (R (None, None, None, 960) 0 ['block_16_depthwise_BN[0][0]'
eLU) ]
block_16_project (Conv2D) (None, None, None, 320) 307200 ['block_16_depthwise_relu[0][0
]']
block_16_project_BN (Batch (None, None, None, 320) 1280 ['block_16_project[0][0]']
Normalization)
Conv_1 (Conv2D) (None, None, None, 1280) 409600 ['block_16_project_BN[0][0]']
Conv_1_bn (BatchNormalizat (None, None, None, 1280) 5120 ['Conv_1[0][0]']
ion)
out_relu (ReLU) (None, None, None, 1280) 0 ['Conv_1_bn[0][0]']
global_average_pooling2d_1 (None, 1280) 0 ['out_relu[0][0]']
(GlobalAveragePooling2D)
dense (Dense) (None, 3) 3843 ['global_average_pooling2d_1[0
][0]']
==================================================================================================
Total params: 2261827 (8.63 MB)
Trainable params: 2227715 (8.50 MB)
Non-trainable params: 34112 (133.25 KB)
__________________________________________________________________________________________________
Compared with the original MobileNetV2, it has now been changed into the model we need.
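The parameter counts in the two summaries can be cross-checked with a little arithmetic. The sketch below only uses numbers printed above; the variable names are illustrative.

```python
# Values read from the two summaries above.
backbone_params = 2_257_984   # total params of mobilenet_notop
feature_channels = 1280       # channels coming out of out_relu
num_classes = 3               # cat, dog, horse

# A Dense layer has one weight per (input, output) pair plus one bias per output.
dense_params = feature_channels * num_classes + num_classes
total_params = backbone_params + dense_params

print(dense_params)  # 3843
print(total_params)  # 2261827
```

Both values agree with the summary: 3843 for the dense layer and 2261827 in total.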
In PyTorch, when we want to freeze the parameters of some layers in a pre-trained model, we can set requires_grad = False. In TensorFlow, we use the following method:
for layer in model.layers[:-2]:
    layer.trainable = False  # freeze every layer except the last two
for layer in model.layers[-2:]:
    layer.trainable = True  # keep the last two layers trainable
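The two slices are easy to misread, so here is a small sketch of what they select, using a plain Python list. The layer names are illustrative: the last three match the summary above, and the leading ones stand in for the rest of the backbone.

```python
# Illustrative layer names only: the last three entries match the summary above,
# the leading ones stand in for the rest of the MobileNetV2 backbone.
layer_names = [
    "input_2", "Conv1", "bn_Conv1", "Conv_1_bn", "out_relu",
    "global_average_pooling2d_1", "dense",
]

frozen = layer_names[:-2]     # everything except the last two entries
trainable = layer_names[-2:]  # only the last two entries

print(len(frozen), frozen[-1])  # 5 out_relu
print(trainable)                # ['global_average_pooling2d_1', 'dense']
```

Since the pooling layer has no weights, the dense layer is the only part of the model whose parameters are updated during training.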
Next, use tf.keras.preprocessing.image.ImageDataGenerator to build the data preprocessing rules. The images must be preprocessed in exactly the same way as for the pre-trained model, which is done by passing tf.keras.applications.mobilenet_v2.preprocess_input as the preprocessing function. Finally, use the .flow_from_directory method to load the data folder. The whole process seems complex, but the code is actually very clear.
train_data = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input
)
train_generator = train_data.flow_from_directory(
    "train/",
    target_size=(224, 224),
    color_mode="rgb",
    batch_size=16,
    class_mode="categorical",
    shuffle=True,
)
Found 150 images belonging to 3 classes.
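As background on the preprocessing function used above: for MobileNetV2 it rescales pixel values from [0, 255] to [-1, 1]. A minimal NumPy sketch of that arithmetic follows; the helper name scale_to_unit_range is hypothetical and not part of Keras.

```python
import numpy as np

def scale_to_unit_range(pixels):
    """Hypothetical helper mirroring the scaling: [0, 255] -> [-1, 1]."""
    return pixels.astype("float32") / 127.5 - 1.0

sample = np.array([0, 127.5, 255])
scaled = scale_to_unit_range(sample)
print(scaled.tolist())  # [-1.0, 0.0, 1.0]
```

Any image fed to the model at inference time needs the same scaling, otherwise the inputs fall outside the range the pre-trained weights expect.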
As can be seen, the data loader automatically labels the data by subfolder: 150 training images are loaded across 3 categories. You can view the class labels through train_generator.classes and explore the meaning of the other attributes yourself.
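The labeling rule can be sketched without Keras. flow_from_directory assigns class indices in alphanumeric order of the subfolder names, and class_mode="categorical" turns each index into a one-hot vector. The folder names below are an assumption; check your own extracted train/ directory.

```python
import numpy as np

# Assumed subfolder names under train/ (hypothetical; check your own extraction).
subfolders = ["horse", "cat", "dog"]

# Class indices follow the alphanumeric order of the subfolder names.
class_indices = {name: idx for idx, name in enumerate(sorted(subfolders))}
print(class_indices)  # {'cat': 0, 'dog': 1, 'horse': 2}

# class_mode="categorical" yields one-hot label vectors.
one_hot = np.eye(len(class_indices), dtype="float32")[class_indices["dog"]]
print(one_hot.tolist())  # [0.0, 1.0, 0.0]
```

On the real generator, the mapping can be read from train_generator.class_indices.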
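Finally, we compile the model and complete the training. Note that here we use fit_generator to read images from the data loader (iterator) and train. In recent TensorFlow 2 releases fit_generator is deprecated, and model.fit accepts the same generator directly.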
model.compile(optimizer="Adam", loss="categorical_crossentropy", metrics=["accuracy"])
history = model.fit_generator(generator=train_generator, steps_per_epoch=10, epochs=5)
Epoch 1/5
10/10 [==============================] - 2s 109ms/step - loss: 0.8937 - accuracy: 0.5667
Epoch 2/5
10/10 [==============================] - 1s 99ms/step - loss: 0.3466 - accuracy: 0.9133
Epoch 3/5
10/10 [==============================] - 1s 97ms/step - loss: 0.1866 - accuracy: 0.9733
Epoch 4/5
10/10 [==============================] - 1s 100ms/step - loss: 0.1223 - accuracy: 0.9733
Epoch 5/5
10/10 [==============================] - 1s 115ms/step - loss: 0.0877 - accuracy: 1.0000
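One way to read steps_per_epoch=10 is as the number of batches needed to pass over all 150 images once with a batch size of 16. A quick check of that arithmetic, using the values from the cells above:

```python
import math

num_images = 150  # "Found 150 images belonging to 3 classes."
batch_size = 16   # value passed to flow_from_directory

# Number of batches needed to see every image once.
batches_per_epoch = math.ceil(num_images / batch_size)
# The final batch holds whatever is left over.
last_batch_size = num_images - (batches_per_epoch - 1) * batch_size

print(batches_per_epoch)  # 10
print(last_batch_size)    # 6
```

This matches the 10/10 progress counter in the log; with a different dataset size or batch size, steps_per_epoch would need to change accordingly.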
During training, we can assign the return value of fit_generator to a variable history, then read the per-epoch changes in metrics such as training accuracy and loss from history.history and plot them. Since no validation set is used here, the validation-related code is commented out below.
from matplotlib import pyplot as plt
%matplotlib inline
# Training accuracy and loss
acc = history.history["accuracy"]
loss = history.history["loss"]
# Validation accuracy and loss
# val_acc = history.history["val_accuracy"]
# val_loss = history.history["val_loss"]
epochs = range(len(acc))
plt.plot(epochs, acc, "r", label="Training accuracy")
plt.plot(epochs, loss, "b", label="Training loss")
plt.title("Training accuracy and loss")
plt.legend(loc=0)
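history.history is a plain dictionary of per-epoch lists, so it can also be inspected without plotting. The sketch below rebuilds that dictionary by hand from the training log above; the name history_dict is illustrative.

```python
# Per-epoch values copied from the training log above.
history_dict = {
    "accuracy": [0.5667, 0.9133, 0.9733, 0.9733, 1.0000],
    "loss": [0.8937, 0.3466, 0.1866, 0.1223, 0.0877],
}

# Epochs are 1-based in the log, so add 1 to the list index.
best_epoch = history_dict["accuracy"].index(max(history_dict["accuracy"])) + 1
loss_drop = history_dict["loss"][0] - history_dict["loss"][-1]

print(best_epoch)           # 5
print(round(loss_drop, 4))  # 0.806
```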
Since this is our first time working with transfer learning in TensorFlow Keras, the first half of the challenge walks through the training process in detail. Next, building on the content above, you need to complete model inference on your own.
The above training process takes a little time. You can start thinking about the challenge questions.
Exercise 63.1
Open-ended Challenge
Challenge: Save the trained TensorFlow Keras model in the correct format and implement inference on images from any link.
Hint: Please use various tools such as search engines to think and solve this problem on your own. The implementation method is not fixed, but the goal is similar to the GIF demonstration below.
### Add your code here ###
Solution to Exercise 63.1
import requests
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
%matplotlib inline

# Download the image from the link entered by the user
IMAGE_URL = input("Please enter the image link:")
res = requests.get(IMAGE_URL)
with open("test.jpg", "wb") as f:
    f.write(res.content)

# Load and preprocess the image in the same way as during training
image = tf.keras.preprocessing.image.load_img("test.jpg", target_size=(224, 224))
image_a = tf.keras.preprocessing.image.img_to_array(image)
image_e = np.expand_dims(image_a, axis=0)  # add the batch dimension
image_i = tf.keras.applications.mobilenet_v2.preprocess_input(image_e)

# Predict and map the class index to its label
preds = np.argmax(model.predict(image_i))
if preds == 0:
    title = "cat"
elif preds == 1:
    title = "dog"
else:
    title = "horse"
plt.title("prediction: {}".format(title))
plt.imshow(image)
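The if/elif chain in the solution can also be written as a lookup in a list of class names. The sketch below shows that variant on made-up probabilities: the probs array is a hypothetical stand-in for the output of model.predict, and the class order assumes the same index mapping as the solution.

```python
import numpy as np

# Class names in index order, matching the if/elif mapping in the solution.
class_names = ["cat", "dog", "horse"]

# Made-up stand-in for model.predict(image_i): one row per image,
# one softmax probability per class.
probs = np.array([[0.05, 0.15, 0.80]])

pred_index = int(np.argmax(probs[0]))
label = class_names[pred_index]
confidence = float(probs[0, pred_index])

print(label)  # horse
```

A lookup list keeps working unchanged if more classes are added, and the confidence value can be shown in the image title alongside the label.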
Expected output:
Finally, we expect to be able to perform inference on images from any external link. In the notebook, interactive input can be achieved through input(). After filling in the link and pressing Enter, the input image and the prediction result are displayed.
Specifically, the prediction result is shown in the image title. The displayed image can be either the original image or the preprocessed image; either is acceptable. The dynamic display process is as follows:
