一、前言

本文主要尝试将自己的数据集制作成lmdb格式,送进lenet作训练和测试,参考了http://blog.csdn.net/liuweizj12/article/details/52149743http://blog.csdn.net/xiaoxiao_huitailang/article/details/51361036这两篇博文

二、从训练模型到使用模型预测图片分类

(1)自己准备的图像数据

由于主要是使用lenet模型训练自己的图片数据,我的图像数据共有10个类别,分别是0~9,相应地保存在名为0~9的文件夹,在/homg/您的用户名/下新建一文件夹char_images,用于保存图像数据,在/homg/您的用户名/char_images/下新建两个文件夹,名字分别为train和val,各自都包含了名为0~9的文件夹,例如文件夹0内存放的是字符”0”的图像,我的文件夹 如下:

(2)对图像数据作统一缩放至28*28,并生成txt标签

为了计算均值文件,需要将所有图片缩放至统一的尺寸,在train和val文件夹所在路径下创建python文件,命名getPath.py,并写入以下内容:

  1. #coding:utf-8
  2. import cv2
  3. import os
  4. def IsSubString( SubStrList , Str):  #判断SubStrList的元素
  5. flag = True                  #是否在Str内
  6. for substr in SubStrList:
  7. if not ( substr in Str):
  8. flag = False
  9. return flag
  10. def GetFileList(FindPath,FlagStr=[]):  #搜索目录下的子文件路径
  11. FileList=[]
  12. FileNames=os.listdir(FindPath)
  13. if len(FileNames)>0:
  14. for fn in FileNames:
  15. if len(FlagStr)>0:
  16. if IsSubString(FlagStr,fn): #不明白这里判断是为了啥
  17. fullfilename=os.path.join(FindPath,fn)
  18. FileList.append(fullfilename)
  19. else:
  20. fullfilename=os.path.join(FindPath,fn)
  21. FileList.append(fullfilename)
  22. if len(FileList)>0:
  23. FileList.sort()
  24. return FileList
  25. train_txt = open('train.txt' , 'w') #制作标签数据
  26. classList =['0','1','2','3','4','5','6','7','8','9']
  27. for idx in range(len(classList)) :
  28. imgfile=GetFileList('train/'+ classList[idx])#将数据集放在与.py文件相同目录下
  29. for img in imgfile:
  30. srcImg = cv2.imread( img);
  31. resizedImg = cv2.resize(srcImg , (28,28))
  32. cv2.imwrite( img  ,resizedImg)
  33. strTemp=img+' '+classList[idx]+'\n'        #用空格代替转义字符 \t
  34. train_txt.writelines(strTemp)
  35. train_txt.close()
  36. test_txt = open('val.txt' , 'w') #制作标签数据
  37. for idx in range(len(classList)) :
  38. imgfile=GetFileList('val/'+ classList[idx])
  39. for img in imgfile:
  40. srcImg = cv2.imread( img);
  41. resizedImg = cv2.resize(srcImg , (28,28))
  42. cv2.imwrite( img  ,resizedImg)
  43. strTemp=img+' '+classList[idx]+'\n'        #用空格代替转义字符 \t
  44. test_txt.writelines(strTemp)
  45. test_txt.close()
  46. print("成功生成文件列表")

运行该py文件,可将所有图片缩放至28*28大小,并且在rain和val文件夹所在路径下生成训练和测试图像数据的标签txt文件,文件内容为:
         

(3)生成lmdb格式的数据集

首先于caffe路径下新建一文件夹My_File,并在My_File下新建两个文件夹Build_lmdb和Data_label,将(2)中生成文本文件train.txt和val.txt搬至Data_label下

  

将caffe路径下 examples/imagenet/create_imagenet.sh 复制一份到Build_lmdb文件夹下

打开create_imagenet.sh ,修改内容如下:

  1. #!/usr/bin/env sh
  2. # Create the imagenet lmdb inputs
  3. # N.B. set the path to the imagenet train + val data dirs
  4. set -e
  5. EXAMPLE=My_File/Build_lmdb         #生成的lmdb格式数据保存地址
  6. DATA=My_File/Data_label                 #两个txt标签文件所在路径
  7. TOOLS=build/tools                            #caffe自带工具,不用管
  8. TRAIN_DATA_ROOT=/home/zjy/char_images/    #预先准备的训练图片路径,该路径和train.txt上写的路径合起来是图片完整路径
  9. VAL_DATA_ROOT=/home/zjy/char_images/         #预先准备的测试图片路径,...
  10. # Set RESIZE=true to resize the images to 256x256. Leave as false if images have
  11. # already been resized using another tool.
  12. RESIZE=false
  13. if $RESIZE; then
  14. RESIZE_HEIGHT=28
  15. RESIZE_WIDTH=28
  16. else
  17. RESIZE_HEIGHT=0
  18. RESIZE_WIDTH=0
  19. fi
  20. if [ ! -d "$TRAIN_DATA_ROOT" ]; then
  21. echo "Error: TRAIN_DATA_ROOT is not a path to a directory: $TRAIN_DATA_ROOT"
  22. echo "Set the TRAIN_DATA_ROOT variable in create_imagenet.sh to the path" \
  23. "where the ImageNet training data is stored."
  24. exit 1
  25. fi
  26. if [ ! -d "$VAL_DATA_ROOT" ]; then
  27. echo "Error: VAL_DATA_ROOT is not a path to a directory: $VAL_DATA_ROOT"
  28. echo "Set the VAL_DATA_ROOT variable in create_imagenet.sh to the path" \
  29. "where the ImageNet validation data is stored."
  30. exit 1
  31. fi
  32. echo "Creating train lmdb..."
  33. GLOG_logtostderr=1 $TOOLS/convert_imageset \
  34. --resize_height=$RESIZE_HEIGHT \
  35. --resize_width=$RESIZE_WIDTH \
  36. --shuffle \
  37. --gray \        #灰度图像加上这个
  38. $TRAIN_DATA_ROOT \
  39. $DATA/train.txt \
  40. $EXAMPLE/train_lmdb                   #生成的lmdb格式训练数据集所在的文件夹
  41. echo "Creating val lmdb..."
  42. GLOG_logtostderr=1 $TOOLS/convert_imageset \
  43. --resize_height=$RESIZE_HEIGHT \
  44. --resize_width=$RESIZE_WIDTH \
  45. --shuffle \
  46. --gray \        #灰度图像加上这个
  47. $VAL_DATA_ROOT \
  48. $DATA/val.txt \
  49. $EXAMPLE/val_lmdb              #生成的lmdb格式训练数据集所在的文件夹
  50. echo "Done."

以上只是为了说明修改的地方才添加汉字注释,实际时sh文件不要出现汉字,运行该sh文件,可在Build_lmdb文件夹内生成2个文件夹train_lmdb和val_lmdb,里面各有2个lmdb格式的文件

(4)更改lenet_solver.prototxt和lenet_train_test.prototxt
将caffe/examples/mnist下的 train_lenet.sh 、lenet_solver.prototxt 、lenet_train_test.prototxt 这三个文件复制至 My_File,首先修改train_lenet.sh 如下,只改了solver.prototxt的路径

  1. #!/usr/bin/env sh
  2. set -e
  3. ./build/tools/caffe train --solver=My_File/lenet_solver.prototxt $@    #改路径

然后再更改lenet_solver.prototxt,如下:

  1. # The train/test net protocol buffer definition
  2. net: "My_File/lenet_train_test.prototxt"            #改这里
  3. # test_iter specifies how many forward passes the test should carry out.
  4. # In the case of MNIST, we have test batch size 100 and 100 test iterations,
  5. # covering the full 10,000 testing images.
  6. test_iter: 100
  7. # Carry out testing every 500 training iterations.
  8. test_interval: 500
  9. # The base learning rate, momentum and the weight decay of the network.
  10. base_lr: 0.01
  11. momentum: 0.9
  12. weight_decay: 0.0005
  13. # The learning rate policy
  14. lr_policy: "inv"
  15. gamma: 0.0001
  16. power: 0.75
  17. # Display every 100 iterations
  18. display: 100
  19. # The maximum number of iterations
  20. max_iter: 10000
  21. # snapshot intermediate results
  22. snapshot: 5000
  23. snapshot_prefix: "My_File/"         #改这里
  24. # solver mode: CPU or GPU
  25. solver_mode: GPU

最后修改lenet_train_test.prototxt ,如下:

  1. name: "LeNet"
  2. layer {
  3. name: "mnist"
  4. type: "Data"
  5. top: "data"
  6. top: "label"
  7. include {
  8. phase: TRAIN
  9. }
  10. transform_param {
  11. scale: 0.00390625
  12. }
  13. data_param {
  14. source: "My_File/Build_lmdb/train_lmdb"       #改成自己的
  15. batch_size: 64
  16. backend: LMDB
  17. }
  18. }
  19. layer {
  20. name: "mnist"
  21. type: "Data"
  22. top: "data"
  23. top: "label"
  24. include {
  25. phase: TEST
  26. }
  27. transform_param {
  28. scale: 0.00390625
  29. }
  30. data_param {
  31. source: "My_File/Build_lmdb/val_lmdb"        #改成自己的
  32. batch_size: 100
  33. backend: LMDB
  34. }
  35. }
  36. layer {
  37. name: "conv1"
  38. type: "Convolution"
  39. bottom: "data"
  40. top: "conv1"
  41. param {
  42. lr_mult: 1
  43. }
  44. param {
  45. lr_mult: 2
  46. }
  47. convolution_param {
  48. num_output: 20
  49. kernel_size: 5
  50. stride: 1
  51. weight_filler {
  52. type: "xavier"
  53. }
  54. bias_filler {
  55. type: "constant"
  56. }
  57. }
  58. }
  59. layer {
  60. name: "pool1"
  61. type: "Pooling"
  62. bottom: "conv1"
  63. top: "pool1"
  64. pooling_param {
  65. pool: MAX
  66. kernel_size: 2
  67. stride: 2
  68. }
  69. }
  70. layer {
  71. name: "conv2"
  72. type: "Convolution"
  73. bottom: "pool1"
  74. top: "conv2"
  75. param {
  76. lr_mult: 1
  77. }
  78. param {
  79. lr_mult: 2
  80. }
  81. convolution_param {
  82. num_output: 50
  83. kernel_size: 5
  84. stride: 1
  85. weight_filler {
  86. type: "xavier"
  87. }
  88. bias_filler {
  89. type: "constant"
  90. }
  91. }
  92. }
  93. layer {
  94. name: "pool2"
  95. type: "Pooling"
  96. bottom: "conv2"
  97. top: "pool2"
  98. pooling_param {
  99. pool: MAX
  100. kernel_size: 2
  101. stride: 2
  102. }
  103. }
  104. layer {
  105. name: "ip1"
  106. type: "InnerProduct"
  107. bottom: "pool2"
  108. top: "ip1"
  109. param {
  110. lr_mult: 1
  111. }
  112. param {
  113. lr_mult: 2
  114. }
  115. inner_product_param {
  116. num_output: 500
  117. weight_filler {
  118. type: "xavier"
  119. }
  120. bias_filler {
  121. type: "constant"
  122. }
  123. }
  124. }
  125. layer {
  126. name: "relu1"
  127. type: "ReLU"
  128. bottom: "ip1"
  129. top: "ip1"
  130. }
  131. layer {
  132. name: "ip2"
  133. type: "InnerProduct"
  134. bottom: "ip1"
  135. top: "ip2"
  136. param {
  137. lr_mult: 1
  138. }
  139. param {
  140. lr_mult: 2
  141. }
  142. inner_product_param {
  143. num_output: 10
  144. weight_filler {
  145. type: "xavier"
  146. }
  147. bias_filler {
  148. type: "constant"
  149. }
  150. }
  151. }
  152. layer {
  153. name: "accuracy"
  154. type: "Accuracy"
  155. bottom: "ip2"
  156. bottom: "label"
  157. top: "accuracy"
  158. include {
  159. phase: TEST
  160. }
  161. }
  162. layer {
  163. name: "loss"
  164. type: "SoftmaxWithLoss"
  165. bottom: "ip2"
  166. bottom: "label"
  167. top: "loss"
  168. }

运行 My_File/train_lenet.sh ,得到最后的训练结果,在My_File下生成训练的caffemodel和solverstate。

(5)生成均值文件
均值文件主要用于图像预测的时候,由caffe/build/tools/compute_image_mean生成,在My_File文件夹下新建一文件夹Mean,用于存放均值文件,在caffe/下执行:
build/tools/compute_image_mean My_File/Build_lmdb/train_lmdb My_File/Mean/mean.binaryproto
可在My_File/Mean/下生成均值文件mean.binaryproto 
(6)生成deploy.prototxt
deploy.prototxt是在lenet_train_test.prototxt的基础上删除了开头的Train和Test部分以及结尾的Accuracy、SoftmaxWithLoss层,并在开始时增加了一个data层描述,结尾增加softmax层,可以参照博文http://blog.csdn.net/lanxuecc/article/details/52474476 使用python生成,也可以直接由train_val.prototxt上做修改,在My_File文件夹下新建一文件夹Deploy,将 lenet_train_test.prototxt复制至文件夹Deploy下,并重命名为deploy.prototxt ,修改里面的内容如下:

  1. name: "LeNet"
  2. layer {                   #删去原来的Train和Test部分,增加一个data层
  3. name: "data"
  4. type: "Input"
  5. top: "data"
  6. input_param { shape: { dim: 1 dim: 1 dim: 28 dim: 28 } }
  7. }
  8. layer {
  9. name: "conv1"
  10. type: "Convolution"
  11. bottom: "data"
  12. top: "conv1"
  13. param {
  14. lr_mult: 1
  15. }
  16. param {
  17. lr_mult: 2
  18. }
  19. convolution_param {
  20. num_output: 20
  21. kernel_size: 5
  22. stride: 1
  23. weight_filler {
  24. type: "xavier"
  25. }
  26. bias_filler {
  27. type: "constant"
  28. }
  29. }
  30. }
  31. layer {
  32. name: "pool1"
  33. type: "Pooling"
  34. bottom: "conv1"
  35. top: "pool1"
  36. pooling_param {
  37. pool: MAX
  38. kernel_size: 2
  39. stride: 2
  40. }
  41. }
  42. layer {
  43. name: "conv2"
  44. type: "Convolution"
  45. bottom: "pool1"
  46. top: "conv2"
  47. param {
  48. lr_mult: 1
  49. }
  50. param {
  51. lr_mult: 2
  52. }
  53. convolution_param {
  54. num_output: 50
  55. kernel_size: 5
  56. stride: 1
  57. weight_filler {
  58. type: "xavier"
  59. }
  60. bias_filler {
  61. type: "constant"
  62. }
  63. }
  64. }
  65. layer {
  66. name: "pool2"
  67. type: "Pooling"
  68. bottom: "conv2"
  69. top: "pool2"
  70. pooling_param {
  71. pool: MAX
  72. kernel_size: 2
  73. stride: 2
  74. }
  75. }
  76. layer {
  77. name: "ip1"
  78. type: "InnerProduct"
  79. bottom: "pool2"
  80. top: "ip1"
  81. param {
  82. lr_mult: 1
  83. }
  84. param {
  85. lr_mult: 2
  86. }
  87. inner_product_param {
  88. num_output: 500
  89. weight_filler {
  90. type: "xavier"
  91. }
  92. bias_filler {
  93. type: "constant"
  94. }
  95. }
  96. }
  97. layer {
  98. name: "relu1"
  99. type: "ReLU"
  100. bottom: "ip1"
  101. top: "ip1"
  102. }
  103. layer {
  104. name: "ip2"
  105. type: "InnerProduct"
  106. bottom: "ip1"
  107. top: "ip2"
  108. param {
  109. lr_mult: 1
  110. }
  111. param {
  112. lr_mult: 2
  113. }
  114. inner_product_param {
  115. num_output: 10
  116. weight_filler {
  117. type: "xavier"
  118. }
  119. bias_filler {
  120. type: "constant"
  121. }
  122. }
  123. }
  124. layer {                   #增加softmax层
  125. name: "prob"
  126. type: "Softmax"
  127. bottom: "ip2"
  128. top: "prob"
  129. }

(7)预测图片
在My_File文件夹下创建一文件夹Pic,用于存放测试的图片;在My_File文件夹下创建另一文件夹Synset,在其中新建synset_words.txt文件,之后在里面输入:
0
1
2
3
4
5
6
7
8
9

看看My_File文件夹都有啥了

最后使用caffe/build/examples/cpp_classification/classification.bin对图片作预测,在终端输入:

三、结束语

真是篇又臭又长的博文,高手自行忽略,刚刚入门的可以看看!

使用LeNet训练自己的手写图片数据的更多相关文章

  1. 吴裕雄--天生自然 pythonTensorFlow图形数据处理:将MNIST手写图片数据写入TFRecord文件

    import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...

  2. 吴裕雄--天生自然 pythonTensorFlow图形数据处理:读取MNIST手写图片数据写入的TFRecord文件

    import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...

  3. 07 训练Tensorflow识别手写数字

    打开Python Shell,输入以下代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input ...

  4. 人工智能-深度学习(3)TensorFlow 实战一:手写图片识别

    http://gitbook.cn/gitchat/column/59f7e38160c9361563ebea95/topic/59f7e86d60c9361563ebeee5 wiki.jikexu ...

  5. Pytorch1.0入门实战一:LeNet神经网络实现 MNIST手写数字识别

    记得第一次接触手写数字识别数据集还在学习TensorFlow,各种sess.run(),头都绕晕了.自从接触pytorch以来,一直想写点什么.曾经在2017年5月,Andrej Karpathy发表 ...

  6. pytorch: 准备、训练和测试自己的图片数据

    大部分的pytorch入门教程,都是使用torchvision里面的数据进行训练和测试.如果我们是自己的图片数据,又该怎么做呢? 一.我的数据 我在学习的时候,使用的是fashion-mnist.这个 ...

  7. javaScript 手写图片轮播

    <!DOCTYPE html> <html> <head> <meta charset="UTF-8"> <title> ...

  8. js手写图片查看器(图片的缩放、旋转、拖拽)

    在做一次代码编辑任务中,要查看图片器.在时间允许的条件下,放弃了已经封装好的图片jq插件,现在自己手写js实现图片的缩放.旋转.推拽功能! 具体代码如下: <!DOCTYPE html> ...

  9. SVM:利用SVM算法实现手写图片识别(数据集50000张图片)—Jason niu

    import mnist_loader # Third-party libraries from sklearn import svm def svm_baseline(): training_dat ...

随机推荐

  1. hdu4762Cut the Cake(概率+大数操作(java)+C++高精度模板)

    题目链接:点击打开链接 题目描写叙述:现有一个大蛋糕.上面随机分布了n个草莓,然后将草莓切成m块,问n个草莓全在一块蛋糕上面的概率? 解题思路:细致分析可得:C(n,1)/m^(n-1) 因为m< ...

  2. 号外:java基础班教材永久免费 报名就送

    以前万人疯抢的成都传智播客java基础班教材,今日免费赠送,你hold的住吗? 由成都传智播客传道授业解惑的诸位老师,精心制作的教材.如今免费赠送,你能接的住吗? 书是交融感情.获得知识.传承经验的重 ...

  3. 【cl】多表查询(内、外连接)

    交叉连接(cross join):该连接产生的结果集笛卡尔积 a有7行,b有8行    a的第一行与b的每一行进行连接,就有8条a得第一行 7*8=56条 select a.real_name,s.u ...

  4. 深入理解 C 指针阅读笔记 -- 第二章

    Chapter2.h #ifndef __CHAPTER_2_ #define __CHAPTER_2_ /*<深入理解C指针>学习笔记 -- 第二章*/ /* 内存泄露的两种形式 1.忘 ...

  5. VS2013+ffmpeg开发环境搭建-->【转】

    本文转载自:http://blog.csdn.net/qq_28425595/article/details/51488869 版权声明:本文为博主原创文章,未经博主允许不得转载. 今天整理资料时,发 ...

  6. perl的安装

    http://www.activestate.com/activeperl/downloads 安装的时候,默认把perl放置到环境变量的PATH中 之后,需要重启电脑,确保环境变量生效 执行perl ...

  7. 南海区行政审批管理系统接口规范v0.3(规划)

    1. 会话API 1.1. login [登录验证] {"r_code":"500","r_msg":"操作失败",&q ...

  8. hammer教程

    一.前言 移动端框架当前还处在初级阶段,但相对于移动端的应用来说已经有很长时间了.虽然暂时还没有PC端开发的需求量大,但移动端的Web必然是一种趋势,在接触移动端脚本的过程中,最开始想到的是juqer ...

  9. Nginx 404 500

    Nginx反向代理自定义404错误页面 http中添加 proxy_intercept_errors on; server中添加 error_page 404 = https://www.longda ...

  10. hdu3572Task Schedule 最大流,判断满流 优化的SAP算法

    PS:多校联赛的题目质量还是挺高的.建图不会啊,看了题解才会的. 参考博客:http://blog.csdn.net/luyuncheng/article/details/7944417 看了上面博客 ...