1.mnist实例

##1.数据下载 获得mnist的数据包,在caffe根目录下执行./data/mnist/get_mnist.sh脚本。 get_mnist.sh脚本先下载样本库并进行解压缩,得到四个文件。 

2.生成LMDB

成功解压缩下载的样本库后,然后执行./examples/mnist/create_mnist.sh。 create_mnist.sh脚本先利用caffe-master/build/examples/mnist/目录下的convert_mnist_data.bin工具,将mnist data转化为caffe可用的lmdb格式文件,然后将生成的mnist-train-lmdb和mnist-test-lmdb两个文件放在caffe-master/example/mnist目录下面。

3.网络配置

LeNet网络定义在./examples/mnist/lenet_train_test.prototxt 文件中。

  1. name: "LeNet"
  2. layer {
  3. name: "mnist" //输入层的名称mnist
  4. type: "Data" //输入层的类型为Data层
  5. top: "data" //本层下一场连接data层和label blob空间
  6. top: "label"
  7. include {
  8. phase: TRAIN //训练阶段
  9. }
  10. transform_param {
  11. scale: 0.00390625 //输入图片像素归一到[0,1].1除以256为0.00390625
  12. }
  13. data_param {
  14. source: "examples/mnist/mnist_train_lmdb" //从mnist_train_lmdb中读入数据
  15. batch_size: 64 //batch大小为64,一次训练64条数据
  16. backend: LMDB
  17. }
  18. }
  19. layer {
  20. name: "mnist" //输入层的名称mnist
  21. type: "Data" //输入层的类型为Data层
  22. top: "data" //本层下一场连接data层和label blob空间
  23. top: "label"
  24. include {
  25. phase: TEST //测试阶段
  26. }
  27. transform_param {
  28. scale: 0.00390625 //输入图片像素归一到[0,1].1除以256为0.00390625
  29. }
  30. data_param {
  31. source: "examples/mnist/mnist_test_lmdb" //从mnist_test_lmdb中读入数据
  32. batch_size: 100 //batch大小为100,一次训练100条数据
  33. backend: LMDB
  34. }
  35. }
  36. layer {
  37. name: "conv1" //卷积层名称conv1
  38. type: "Convolution" //层类型为卷积层
  39. bottom: "data" //本层使用上一层的data,生成下一层conv1的blob
  40. top: "conv1"
  41. param {
  42. lr_mult: 1 //权重参数w的学习率倍数
  43. }
  44. param {
  45. lr_mult: 2 //偏置参数b的学习率倍数
  46. }
  47. convolution_param {
  48. num_output: 20 //输出单元数20
  49. kernel_size: 5 //卷积核大小为5*5
  50. stride: 1 //步长为1
  51. weight_filler { //允许用随机值初始化权重和偏置值
  52. type: "xavier" //使用xavier算法自动确定基于输入—输出神经元数量的初始规模
  53. }
  54. bias_filler {
  55. type: "constant" //偏置值初始化为常数,默认为0
  56. }
  57. }
  58. }
  59. layer {
  60. name: "pool1" //层名称为pool1
  61. type: "Pooling" //层类型为pooling
  62. bottom: "conv1" //本层的上一层是conv1,生成下一层pool1的blob
  63. top: "pool1"
  64. pooling_param { //pooling层的参数
  65. pool: MAX //pooling的方式是MAX
  66. kernel_size: 2 //pooling核是2*2
  67. stride: 2 //pooling步长是2
  68. }
  69. }
  70. layer {
  71. name: "conv2" //第二个卷积层,同第一个卷积层相同,只是卷积核为50
  72. type: "Convolution"
  73. bottom: "pool1"
  74. top: "conv2"
  75. param {
  76. lr_mult: 1
  77. }
  78. param {
  79. lr_mult: 2
  80. }
  81. convolution_param {
  82. num_output: 50
  83. kernel_size: 5
  84. stride: 1
  85. weight_filler {
  86. type: "xavier"
  87. }
  88. bias_filler {
  89. type: "constant"
  90. }
  91. }
  92. }
  93. layer {
  94. name: "pool2" //第二个pooling层,与第一个pooling层相同
  95. type: "Pooling"
  96. bottom: "conv2"
  97. top: "pool2"
  98. pooling_param {
  99. pool: MAX
  100. kernel_size: 2
  101. stride: 2
  102. }
  103. }
  104. layer { //全连接层
  105. name: "ip1" //全连接层名称ip1
  106. type: "InnerProduct" //层类型为全连接层
  107. bottom: "pool2"
  108. top: "ip1"
  109. param {
  110. lr_mult: 1
  111. }
  112. param {
  113. lr_mult: 2
  114. }
  115. inner_product_param { //全连接层的参数
  116. num_output: 500 //输出500个节点
  117. weight_filler {
  118. type: "xavier"
  119. }
  120. bias_filler {
  121. type: "constant"
  122. }
  123. }
  124. }
  125. layer {
  126. name: "relu1" //ReLU层
  127. type: "ReLU" //层名称为relu1
  128. bottom: "ip1" //层类型为ReLU
  129. top: "ip1"
  130. }
  131. layer {
  132. name: "ip2" //第二个全连接层
  133. type: "InnerProduct"
  134. bottom: "ip1"
  135. top: "ip2"
  136. param {
  137. lr_mult: 1
  138. }
  139. param {
  140. lr_mult: 2
  141. }
  142. inner_product_param {
  143. num_output: 10 //输出10个单元
  144. weight_filler {
  145. type: "xavier"
  146. }
  147. bias_filler {
  148. type: "constant"
  149. }
  150. }
  151. }
  152. layer {
  153. name: "accuracy"
  154. type: "Accuracy"
  155. bottom: "ip2"
  156. bottom: "label"
  157. top: "accuracy"
  158. include {
  159. phase: TEST
  160. }
  161. }
  162. layer { //loss层,softmax_loss层实现softmax和多项Logistic损失
  163. name: "loss"
  164. type: "SoftmaxWithLoss"
  165. bottom: "ip2"
  166. bottom: "label"
  167. top: "loss"
  168. }

4.训练网络

运行./examples/mnist/train_lenet.sh。 执行此脚本是,实际运行的是lenet_solver.prototxt中的定义。

  1. # The train/test net protocol buffer definition
  2. net: "examples/mnist/lenet_train_test.prototxt" //网络具体定义
  3. # test_iter specifies how many forward passes the test should carry out.
  4. # In the case of MNIST, we have test batch size 100 and 100 test iterations,
  5. # covering the full 10,000 testing images.
  6. test_iter: 100 //test迭代次数,若batch_size=100,则100张图一批,训练100次,可覆盖1000张图
  7. # Carry out testing every 500 training iterations.
  8. test_interval: 500 //训练迭代500次,测试一次
  9. # The base learning rate, momentum and the weight decay of the network.
  10. base_lr: 0.01 //网络参数:学习率,动量,权重的衰减
  11. momentum: 0.9
  12. weight_decay: 0.0005
  13. # The learning rate policy //学习策略:有固定学习率和每步递减学习率
  14. lr_policy: "inv" //当前使用递减学习率
  15. gamma: 0.0001
  16. power: 0.75
  17. # Display every 100 iterations //每迭代100次显示一次
  18. display: 100
  19. # The maximum number of iterations //最大迭代数
  20. max_iter: 10000
  21. # snapshot intermediate results //每5000次迭代存储一次数据
  22. snapshot: 5000
  23. snapshot_prefix: "examples/mnist/lenet"
  24. # solver mode: CPU or GPU
  25. solver_mode: CPU //本例用CPU训练

数据训练结束后,会生成以下四个文件: 

5.测试网络

运行./build/tools/caffe.bin test -model=examples/mnist/lenet_train_test.prototxt -weights=examples/mnist/lenet_iter_10000.caffemodel

test:表示对训练好的模型进行Testing,而不是training。其他参数包括train, time, device_query。

-model=XXX:指定模型prototxt文件,这是一个文本文件,详细描述了网络结构和数据集信息。 

从上面的打印输出可看出,测试数据中的accruacy平均成功率为98%。

mnist手写测试

手写数字的图片必须满足以下条件:

  • 必须是256位黑白色
  • 必须是黑底白字
  • 像素大小必须是28*28
  • 数字在图片中间,上下左右没有过多的空白。

测试图片

    

手写数字识别脚本

  1. import os
  2. import sys
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. caffe_root = '/home/lynn/caffe/'
  6. sys.path.insert(0, caffe_root + 'python')
  7. import caffe
  8. MODEL_FILE = '/home/lynn/caffe/examples/mnist/lenet.prototxt'
  9. PRETRAINED = '/home/lynn/caffe/examples/mnist/lenet_iter_10000.caffemodel'
  10. IMAGE_FILE = '/home/lynn/test.bmp'
  11. input_image = caffe.io.load_image(IMAGE_FILE, color=False)
  12. #print input_image
  13. net = caffe.Classifier(MODEL_FILE, PRETRAINED)
  14. prediction = net.predict([input_image], oversample = False)
  15. caffe.set_mode_cpu()
  16. print 'predicted class: ', prediction[0].argmax()

测试结果

 
 

caffe mnist实例 --lenet_train_test.prototxt 网络配置详解的更多相关文章

  1. Vmware在NAT模式下网络配置详解

    Vmware在NAT模式下网络配置详解 Linux中的网络配置对于接触Linux不久的小白菜来说,还是小有难度的,可能是不熟悉这种与windows系列迥然不同的命令行操作,也可能是由于对Linux的结 ...

  2. Docker基础 :网络配置详解

    本篇文章将讲述 Docker 的网络功能,包括使用端口映射机制来将容器内应用服务提供给外部网络,以及通过容器互联系统让多个容器之间进行快捷的网络通信,有兴趣的可以了解下. 大量的互联网应用服务包含多个 ...

  3. CentOS网络配置详解

    转载于CentOS中文站:http://www.centoscn.com/CentOS/2015/0507/5376.html一.配置文件详解 在RHEL或者CentOS等Redhat系的Linux系 ...

  4. linux网络配置详解

    一:相关网络配置的文件 1.网卡名配置相关文件 网卡名命名规则文件: /etc/udev/rules.d/70-persistent-net.rules # PCI device 0x8086:0x1 ...

  5. VMware虚拟机网络配置详解

    VMware网络配置:三种网络模式简介 安装好虚拟机以后,在网络连接里面可以看到多了两块网卡: 其中VMnet1是虚拟机Host-only模式的网络接口,VMnet8是NAT模式的网络接口,这些后面会 ...

  6. 虚拟机网络配置详解(NAT、桥接、Hostonly)

    VirtualBox中有四种网络连接方式: NAT Bridged Adapter Internal Host-only Adapter VMWare中有三种,其实它跟VMWare的网络连接方式都是一 ...

  7. 虚拟机网络配置详解(NAT、桥接、Hostonly) z

    http://www.cnblogs.com/beginmind/p/6379881.html VirtualBox中有四种网络连接方式: NAT Bridged Adapter Internal H ...

  8. CentOS 7 网络配置详解

    今天在一台PC上安装了CentOS 7,当时选择了最小安装模式,安装完成后马上用ifconfig查看本机的ip地址(局域网已经有DHCP),发现报错,提示ifconfig命令没找到. ? 1 2 3 ...

  9. Linux下MongoDB单实例的安装和配置详解

    推荐网站 MongoDB官网:http://www.mongodb.org/ MongoDB学习网站:http://www.runoob.com/mongodb 一.创建MongoDB的资源目录和安装 ...

随机推荐

  1. css 画三角形

    <div class='triangle-rihgt'></div> <div class='triangle-top'></div> <div ...

  2. intelij idea+springMVC+spring+mybatis 初探(持续更新)

    intelij idea+springMVC+spring+mybatis 初探(持续更新) intellij 创建java web项目(maven管理的SSH) http://blog.csdn.n ...

  3. Spring 获取propertise文件中的值

    Spring 获取propertise文件中的值 Spring 获取propertise的方式,除了之前的博文提到的使用@value的注解注入之外,还可以通过编码的方式获取,这里主要说的是要使用Emb ...

  4. js实现图片上传预览功能,使用base64编码来实现

    <!DOCTYPE html> <html> <head> <meta http-equiv="Content-Type" content ...

  5. 小白学习Spark系列二:spark应用打包傻瓜式教程(IntelliJ+maven 和 pycharm+jar)

    在做spark项目时,我们常常面临如何在本地将其打包,上传至装有spark服务器上运行的问题.下面是我在项目中尝试的两种方案,也踩了不少坑,两者相比,方案一比较简单,本博客提供的jar包适用于spar ...

  6. 《Exception》第八次团队作业:Alpha冲刺

    一.项目基本介绍 项目 内容 这个作业属于哪个课程 任课教师博客主页链接 这个作业的要求在哪里 作业链接地址 团队名称 Exception 作业学习目标 1.掌握软件测试基础技术.2.学习迭代式增量软 ...

  7. Problem 5

    Problem 5 # Problem_5.py """ 2520 is the smallest number that can be divided by each ...

  8. 在windows环境中关于 pycharm配置 anaconda 虚拟环境

    因为要在windows系统系统中练习tensorflow,所以需要配置一下环境(来回的开关机切换环境太麻烦了......) 首先安装anaconda3,我选择的版本是Anaconda3 5.1.0,对 ...

  9. String与StringBuffer与StringBuilder

    package test; public class Test { public static void main(String[] args) { StringBuffer sb = new Str ...

  10. ASP.NET-AD开发技巧

    分享一篇很好的介绍AD属性的文章 AD图片插件 如何给AD添加图片 http://www.doc88.com/p-9542932844870.html AD过滤条件 重命名ou使用user.Renam ...