Implementing MobileNet with Keras, using the MNIST dataset as a small recognition example. Environment: tensorflow-gpu 2.0, python=3.7, and a GTX-2070 GPU.

1. Importing the data

  • First, two IPython magic commands so that a cell can display every expression's output, not just the last one:

    %config InteractiveShell.ast_node_interactivity="all"
    %pprint
  • Load the MNIST data that ships with Keras:

    import tensorflow as tf
    import keras

    tf.debugging.set_log_device_placement(True)
    mnist = keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

The call tf.debugging.set_log_device_placement(True) above makes TensorFlow log the device each operation is placed on, so you can confirm that the model is really being trained on the GPU.
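
Note that this call only reports placement; it does not move computation by itself. If you want to pin operations to the GPU explicitly, a minimal sketch (assuming a single GPU visible as '/GPU:0') looks like this:

    import tensorflow as tf

    # tf.device pins the enclosed ops to a device; set_log_device_placement
    # only logs where each op ended up.
    with tf.device('/GPU:0'):
        a = tf.constant([[1.0, 2.0], [3.0, 4.0]])
        b = tf.constant([[1.0, 0.0], [1.0, 1.0]])
        c = tf.matmul(a, b)
    print(c)  # the placement log should show MatMul on GPU:0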

  • Converting the data

    The MNIST images are 28*28 pixels, while MobileNet was designed for ImageNet, whose images are 224*224, so the images are first resized to 224*224.
    from PIL import Image
    import numpy as np

    def convert_mnist_224pix(X):
        # Resize a single 28*28 image to 224*224 and return it as a float array.
        img = Image.fromarray(X)
        x = np.zeros((224, 224))
        img = np.array(img.resize((224, 224)))
        x[:, :] = img
        return x

    iteration = iter(x_train)
    new_train = np.zeros((len(x_train), 224, 224), dtype=np.float32)
    for i in range(len(x_train)):
        data = next(iteration)
        new_train[i] = convert_mnist_224pix(data)
        if i % 5000 == 0:
            print(i)  # progress indicator
    new_train.shape

Note that new_train must be created with dtype=np.float32; the NumPy default is float64, which doubles the memory footprint (roughly 22 GB instead of 11 GB for the 60000 resized images) and may not fit in memory. The final shape is (60000, 224, 224).
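
A quick way to check that memory arithmetic (a sketch):

    import numpy as np

    elems = 60000 * 224 * 224  # number of pixels in the resized training set
    print(elems * np.dtype(np.float32).itemsize / 2**30)  # ~11.2 GiB as float32
    print(elems * np.dtype(np.float64).itemsize / 2**30)  # ~22.4 GiB as float64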

2. Building the model

  • Import all the functions and libraries we need:

    from keras.layers import Conv2D, DepthwiseConv2D, Dense, AveragePooling2D, BatchNormalization, Input
    from keras import Model
    from keras import Sequential
    from keras.layers.advanced_activations import ReLU
    from keras.utils import to_categorical
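
These import paths assume standalone Keras 2.x running on top of TensorFlow. If you would rather stay inside TF 2.0's bundled tf.keras, the equivalent imports would be (a sketch, not what the rest of this post uses):

    from tensorflow.keras.layers import (Conv2D, DepthwiseConv2D, Dense,
                                         AveragePooling2D, BatchNormalization,
                                         Input, ReLU)
    from tensorflow.keras import Model, Sequential
    from tensorflow.keras.utils import to_categorical
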
  • Define the reusable intermediate blocks in one place, so the repetitive parts of the network do not have to be written out by hand. A quick sanity check of the first block follows the definitions.
    def depth_point_conv2d(x, s=[1, 1, 2, 1], channel=[64, 128]):
        """
        s: the strides of the four convolutions in the block
        channel: the output depths of the two pointwise convolutions
        """
        dw1 = DepthwiseConv2D((3, 3), strides=s[0], padding='same')(x)
        bn1 = BatchNormalization()(dw1)
        relu1 = ReLU()(bn1)
        pw1 = Conv2D(channel[0], (1, 1), strides=s[1], padding='same')(relu1)
        bn2 = BatchNormalization()(pw1)
        relu2 = ReLU()(bn2)
        dw2 = DepthwiseConv2D((3, 3), strides=s[2], padding='same')(relu2)
        bn3 = BatchNormalization()(dw2)
        relu3 = ReLU()(bn3)
        pw2 = Conv2D(channel[1], (1, 1), strides=s[3], padding='same')(relu3)
        bn4 = BatchNormalization()(pw2)
        relu4 = ReLU()(bn4)
        return relu4

    def repeat_conv(x, s=[1, 1], channel=512):
        # A single depthwise + pointwise pair, used for the repeated middle blocks.
        dw1 = DepthwiseConv2D((3, 3), strides=s[0], padding='same')(x)
        bn1 = BatchNormalization()(dw1)
        relu1 = ReLU()(bn1)
        pw1 = Conv2D(channel, (1, 1), strides=s[1], padding='same')(relu1)
        bn2 = BatchNormalization()(pw1)
        relu2 = ReLU()(bn2)
        return relu2
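
As a quick sanity check of the block (a sketch reusing the imports above), applying it to a dummy input should halve the spatial size at the stride-2 depthwise convolution:

    from keras import backend as K

    check_in = Input(shape=(112, 112, 32))
    check_out = depth_point_conv2d(check_in, s=[1, 1, 2, 1], channel=[64, 128])
    print(K.int_shape(check_out))  # expected: (None, 56, 56, 128)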

The model is assembled following the layer table in the MobileNet paper.

In the fifth row from the bottom of that table (Conv dw / s2), I never understood why the output size stays unchanged when strides=2; I suspect it is a typo in the paper. Here I set that stride to 1, since only then do the shapes match the outputs that follow.
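
The shape arithmetic behind that choice (a sketch): with padding='same', Keras computes the output size as ceil(input_size / stride), and the input to that layer is 7*7.

    import math

    print(math.ceil(7 / 2))  # 4: a stride of 2 would shrink the 7x7 map to 4x4
    print(math.ceil(7 / 1))  # 7: only stride 1 keeps the 7x7 size needed by the 7x7 average pooling later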

  • Build the network:

    h0 = Input(shape=(224, 224, 1))
    h1 = Conv2D(32, (3, 3), strides=2, padding="same")(h0)
    h2 = BatchNormalization()(h1)
    h3 = ReLU()(h2)
    h4 = depth_point_conv2d(h3, s=[1, 1, 2, 1], channel=[64, 128])
    h5 = depth_point_conv2d(h4, s=[1, 1, 2, 1], channel=[128, 256])
    h6 = depth_point_conv2d(h5, s=[1, 1, 2, 1], channel=[256, 512])
    h7 = repeat_conv(h6)
    h8 = repeat_conv(h7)
    h9 = repeat_conv(h8)
    h10 = repeat_conv(h9)
    h11 = depth_point_conv2d(h10, s=[1, 1, 2, 1], channel=[512, 1024])
    h12 = repeat_conv(h11, channel=1024)
    h13 = AveragePooling2D((7, 7))(h12)
    h14 = Dense(10, activation='softmax')(h13)
    model = Model(inputs=h0, outputs=h14)
    model.summary()
    Model: "model_4"
    _________________________________________________________________
    Layer (type) Output Shape Param #
    =================================================================
    input_11 (InputLayer) (None, 224, 224, 1) 0
    _________________________________________________________________
    conv2d_63 (Conv2D) (None, 112, 112, 32) 320
    _________________________________________________________________
    batch_normalization_120 (Bat (None, 112, 112, 32) 128
    _________________________________________________________________
    re_lu_120 (ReLU) (None, 112, 112, 32) 0
    _________________________________________________________________
    depthwise_conv2d_58 (Depthwi (None, 112, 112, 32) 320
    _________________________________________________________________
    batch_normalization_121 (Bat (None, 112, 112, 32) 128
    _________________________________________________________________
    re_lu_121 (ReLU) (None, 112, 112, 32) 0
    _________________________________________________________________
    conv2d_64 (Conv2D) (None, 112, 112, 64) 2112
    _________________________________________________________________
    batch_normalization_122 (Bat (None, 112, 112, 64) 256
    _________________________________________________________________
    re_lu_122 (ReLU) (None, 112, 112, 64) 0
    _________________________________________________________________
    depthwise_conv2d_59 (Depthwi (None, 56, 56, 64) 640
    _________________________________________________________________
    batch_normalization_123 (Bat (None, 56, 56, 64) 256
    _________________________________________________________________
    re_lu_123 (ReLU) (None, 56, 56, 64) 0
    _________________________________________________________________
    conv2d_65 (Conv2D) (None, 56, 56, 128) 8320
    _________________________________________________________________
    batch_normalization_124 (Bat (None, 56, 56, 128) 512
    _________________________________________________________________
    re_lu_124 (ReLU) (None, 56, 56, 128) 0
    _________________________________________________________________
    depthwise_conv2d_60 (Depthwi (None, 56, 56, 128) 1280
    _________________________________________________________________
    batch_normalization_125 (Bat (None, 56, 56, 128) 512
    _________________________________________________________________
    re_lu_125 (ReLU) (None, 56, 56, 128) 0
    _________________________________________________________________
    conv2d_66 (Conv2D) (None, 56, 56, 128) 16512
    _________________________________________________________________
    batch_normalization_126 (Bat (None, 56, 56, 128) 512
    _________________________________________________________________
    re_lu_126 (ReLU) (None, 56, 56, 128) 0
    _________________________________________________________________
    depthwise_conv2d_61 (Depthwi (None, 28, 28, 128) 1280
    _________________________________________________________________
    batch_normalization_127 (Bat (None, 28, 28, 128) 512
    _________________________________________________________________
    re_lu_127 (ReLU) (None, 28, 28, 128) 0
    _________________________________________________________________
    conv2d_67 (Conv2D) (None, 28, 28, 256) 33024
    _________________________________________________________________
    batch_normalization_128 (Bat (None, 28, 28, 256) 1024
    _________________________________________________________________
    re_lu_128 (ReLU) (None, 28, 28, 256) 0
    _________________________________________________________________
    depthwise_conv2d_62 (Depthwi (None, 28, 28, 256) 2560
    _________________________________________________________________
    batch_normalization_129 (Bat (None, 28, 28, 256) 1024
    _________________________________________________________________
    re_lu_129 (ReLU) (None, 28, 28, 256) 0
    _________________________________________________________________
    conv2d_68 (Conv2D) (None, 28, 28, 256) 65792
    _________________________________________________________________
    batch_normalization_130 (Bat (None, 28, 28, 256) 1024
    _________________________________________________________________
    re_lu_130 (ReLU) (None, 28, 28, 256) 0
    _________________________________________________________________
    depthwise_conv2d_63 (Depthwi (None, 14, 14, 256) 2560
    _________________________________________________________________
    batch_normalization_131 (Bat (None, 14, 14, 256) 1024
    _________________________________________________________________
    re_lu_131 (ReLU) (None, 14, 14, 256) 0
    _________________________________________________________________
    conv2d_69 (Conv2D) (None, 14, 14, 512) 131584
    _________________________________________________________________
    batch_normalization_132 (Bat (None, 14, 14, 512) 2048
    _________________________________________________________________
    re_lu_132 (ReLU) (None, 14, 14, 512) 0
    _________________________________________________________________
    depthwise_conv2d_64 (Depthwi (None, 14, 14, 512) 5120
    _________________________________________________________________
    batch_normalization_133 (Bat (None, 14, 14, 512) 2048
    _________________________________________________________________
    re_lu_133 (ReLU) (None, 14, 14, 512) 0
    _________________________________________________________________
    conv2d_70 (Conv2D) (None, 14, 14, 512) 262656
    _________________________________________________________________
    batch_normalization_134 (Bat (None, 14, 14, 512) 2048
    _________________________________________________________________
    re_lu_134 (ReLU) (None, 14, 14, 512) 0
    _________________________________________________________________
    depthwise_conv2d_65 (Depthwi (None, 14, 14, 512) 5120
    _________________________________________________________________
    batch_normalization_135 (Bat (None, 14, 14, 512) 2048
    _________________________________________________________________
    re_lu_135 (ReLU) (None, 14, 14, 512) 0
    _________________________________________________________________
    conv2d_71 (Conv2D) (None, 14, 14, 512) 262656
    _________________________________________________________________
    batch_normalization_136 (Bat (None, 14, 14, 512) 2048
    _________________________________________________________________
    re_lu_136 (ReLU) (None, 14, 14, 512) 0
    _________________________________________________________________
    depthwise_conv2d_66 (Depthwi (None, 14, 14, 512) 5120
    _________________________________________________________________
    batch_normalization_137 (Bat (None, 14, 14, 512) 2048
    _________________________________________________________________
    re_lu_137 (ReLU) (None, 14, 14, 512) 0
    _________________________________________________________________
    conv2d_72 (Conv2D) (None, 14, 14, 512) 262656
    _________________________________________________________________
    batch_normalization_138 (Bat (None, 14, 14, 512) 2048
    _________________________________________________________________
    re_lu_138 (ReLU) (None, 14, 14, 512) 0
    _________________________________________________________________
    depthwise_conv2d_67 (Depthwi (None, 14, 14, 512) 5120
    _________________________________________________________________
    batch_normalization_139 (Bat (None, 14, 14, 512) 2048
    _________________________________________________________________
    re_lu_139 (ReLU) (None, 14, 14, 512) 0
    _________________________________________________________________
    conv2d_73 (Conv2D) (None, 14, 14, 512) 262656
    _________________________________________________________________
    batch_normalization_140 (Bat (None, 14, 14, 512) 2048
    _________________________________________________________________
    re_lu_140 (ReLU) (None, 14, 14, 512) 0
    _________________________________________________________________
    depthwise_conv2d_68 (Depthwi (None, 14, 14, 512) 5120
    _________________________________________________________________
    batch_normalization_141 (Bat (None, 14, 14, 512) 2048
    _________________________________________________________________
    re_lu_141 (ReLU) (None, 14, 14, 512) 0
    _________________________________________________________________
    conv2d_74 (Conv2D) (None, 14, 14, 512) 262656
    _________________________________________________________________
    batch_normalization_142 (Bat (None, 14, 14, 512) 2048
    _________________________________________________________________
    re_lu_142 (ReLU) (None, 14, 14, 512) 0
    _________________________________________________________________
    depthwise_conv2d_69 (Depthwi (None, 7, 7, 512) 5120
    _________________________________________________________________
    batch_normalization_143 (Bat (None, 7, 7, 512) 2048
    _________________________________________________________________
    re_lu_143 (ReLU) (None, 7, 7, 512) 0
    _________________________________________________________________
    conv2d_75 (Conv2D) (None, 7, 7, 1024) 525312
    _________________________________________________________________
    batch_normalization_144 (Bat (None, 7, 7, 1024) 4096
    _________________________________________________________________
    re_lu_144 (ReLU) (None, 7, 7, 1024) 0
    _________________________________________________________________
    depthwise_conv2d_70 (Depthwi (None, 7, 7, 1024) 10240
    _________________________________________________________________
    batch_normalization_145 (Bat (None, 7, 7, 1024) 4096
    _________________________________________________________________
    re_lu_145 (ReLU) (None, 7, 7, 1024) 0
    _________________________________________________________________
    conv2d_76 (Conv2D) (None, 7, 7, 1024) 1049600
    _________________________________________________________________
    batch_normalization_146 (Bat (None, 7, 7, 1024) 4096
    _________________________________________________________________
    re_lu_146 (ReLU) (None, 7, 7, 1024) 0
    _________________________________________________________________
    average_pooling2d_5 (Average (None, 1, 1, 1024) 0
    _________________________________________________________________
    dense_4 (Dense) (None, 1, 1, 10) 10250
    =================================================================
    Total params: 3,249,482
    Trainable params: 3,227,594
    Non-trainable params: 21,888
    _________________________________________________________________
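
As a quick check on the summary, the parameter counts follow directly from the depthwise/pointwise structure (a sketch of the arithmetic for two of the rows above):

    # depthwise_conv2d_58: a 3x3 depthwise conv over 32 channels, plus 32 biases
    print(3 * 3 * 32 + 32)       # 320, matching the summary
    # conv2d_64: a 1x1 pointwise conv from 32 to 64 channels, plus 64 biases
    print(1 * 1 * 32 * 64 + 64)  # 2112, matching the summary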

Because there are only 10 classes here, the output layer has just 10 neurons; the original MobileNet classifies the 1000 ImageNet categories, so its final layer has 1000 neurons.

    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

The line above defines the optimization algorithm and the loss function.
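
If you want control over the learning rate, the same compile call can be written with an explicit optimizer object (a sketch; 1e-3 is Adam's default rate in Keras):

    from keras.optimizers import Adam

    model.compile(optimizer=Adam(lr=1e-3),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])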

3. Preparing the training data and training

Reshape the training images, one-hot encode the labels, and expand the label dimensions so they match the model's (None, 1, 1, 10) output.

    x_train = np.expand_dims(new_train, 3)
    y_train = to_categorical(y_train)
    y = np.expand_dims(y_train, 1)
    y = np.expand_dims(y, 1)
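
A quick shape check (a sketch): after these transformations the arrays should line up with the model's 4-D input and output.

    print(x_train.shape)  # (60000, 224, 224, 1)
    print(y.shape)        # (60000, 1, 1, 10)
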
  • Define the data-generator function:

    def data_generate(x_train, y_train, batch_size, epochs):
        for i in range(epochs):
            batch_num = len(x_train) // batch_size
            # shuffle the order in which the batches are visited each epoch
            shuffle_index = np.arange(batch_num)
            np.random.shuffle(shuffle_index)
            for j in shuffle_index:
                begin = j * batch_size
                end = begin + batch_size
                x = x_train[begin:end]
                y = y_train[begin:end]
                # feed each batch as dicts keyed by the input / output layer names
                yield ({"input_11": x}, {"dense_4": y})

The dictionary keys above must match the names of the model's input layer and final layer ("input_11" and "dense_4" in the summary); otherwise Keras raises an error.
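
The numeric suffixes depend on how many layers have been created in the current session, so if they differ in your run it is safer to look the names up than to hard-code them (a sketch):

    print(model.input_names)   # e.g. ['input_11']
    print(model.output_names)  # e.g. ['dense_4']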

  • Start training. With a batch size of 100 there are 60000 / 100 = 600 steps per epoch; the generator is created with 11 epochs' worth of batches, presumably one more than the 10 training epochs so it never runs out mid-training:

    model.fit_generator(data_generate(x_train, y, 100, 11), steps_per_epoch=600, epochs=10)

The training log looks like this:

    Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0
    Epoch 1/10
    Executing op __inference_keras_scratch_graph_22639 in device /job:localhost/replica:0/task:0/device:GPU:0
    600/600 [==============================] - 411s 684ms/step - loss: 0.1469 - accuracy: 0.9529
    Epoch 2/10
    600/600 [==============================] - 398s 663ms/step - loss: 0.0375 - accuracy: 0.9884
    Epoch 3/10
    600/600 [==============================] - 401s 668ms/step - loss: 0.0283 - accuracy: 0.9909
    Epoch 4/10
    600/600 [==============================] - 399s 665ms/step - loss: 0.0211 - accuracy: 0.9936
    Epoch 5/10
    600/600 [==============================] - 400s 666ms/step - loss: 0.0216 - accuracy: 0.9932
    Epoch 6/10
    600/600 [==============================] - 401s 668ms/step - loss: 0.0208 - accuracy: 0.9935
    Epoch 7/10
    600/600 [==============================] - 401s 669ms/step - loss: 0.0174 - accuracy: 0.9945
    Epoch 8/10
    131/600 [=====>........................] - ETA: 5:13 - loss: 0.0091 - accuracy: 0.9973

The model contains a lot of convolutions, so each epoch takes a while to run; but because the parameter count is small, the updates themselves are cheap and the network converges quickly.
