caffe的python接口学习(4)mnist实例手写数字识别
以下主要是摘抄denny博文的内容,更多内容大家去看原作者吧
一 数据准备
准备训练集和测试集图片的列表清单;
二 导入caffe库,设定文件路径
- # -*- coding: utf-8 -*-
- import caffe
- from caffe import layers as L,params as P,proto,to_proto
- #设定文件的保存路径
- root='/home/xxx/' #根目录
- train_list=root+'mnist/train/train.txt' #训练图片列表
- test_list=root+'mnist/test/test.txt' #测试图片列表
- train_proto=root+'mnist/train.prototxt' #训练配置文件
- test_proto=root+'mnist/test.prototxt' #测试配置文件
- solver_proto=root+'mnist/solver.prototxt' #参数文件

其中train.txt 和test.txt文件已经有了,其它三个文件,我们需要自己编写。
此处注意:一般caffe程序都是先将图片转换成lmdb文件,但这样做有点麻烦。因此我就不转换了,我直接用原始图片进行操作,所不同的就是直接用图片操作,均值很难计算,因此可以不减均值。
三 生成配置文件
配置文件实际上就是一些txt文档,只是后缀名是prototxt,我们可以直接到编辑器里编写,也可以用代码生成。此处,我用python来生成。

- #编写一个函数,生成配置文件prototxt
- def Lenet(img_list,batch_size,include_acc=False):
- #第一层,数据输入层,以ImageData格式输入
- data, label = L.ImageData(source=img_list, batch_size=batch_size, ntop=2,root_folder=root,
- transform_param=dict(scale= 0.00390625))
- #第二层:卷积层
- conv1=L.Convolution(data, kernel_size=5, stride=1,num_output=20, pad=0,weight_filler=dict(type='xavier'))
- #池化层
- pool1=L.Pooling(conv1, pool=P.Pooling.MAX, kernel_size=2, stride=2)
- #卷积层
- conv2=L.Convolution(pool1, kernel_size=5, stride=1,num_output=50, pad=0,weight_filler=dict(type='xavier'))
- #池化层
- pool2=L.Pooling(conv2, pool=P.Pooling.MAX, kernel_size=2, stride=2)
- #全连接层
- fc3=L.InnerProduct(pool2, num_output=500,weight_filler=dict(type='xavier'))
- #激活函数层
- relu3=L.ReLU(fc3, in_place=True)
- #全连接层
- fc4 = L.InnerProduct(relu3, num_output=10,weight_filler=dict(type='xavier'))
- #softmax层
- loss = L.SoftmaxWithLoss(fc4, label)
- if include_acc: # test阶段需要有accuracy层
- acc = L.Accuracy(fc4, label)
- return to_proto(loss, acc)
- else:
- return to_proto(loss)
- def write_net():
- #写入train.prototxt
- with open(train_proto, 'w') as f:
- f.write(str(Lenet(train_list,batch_size=64)))
- #写入test.prototxt
- with open(test_proto, 'w') as f:
- f.write(str(Lenet(test_list,batch_size=100, include_acc=True)))

配置文件里面存放的,就是我们所说的network。我这里生成的network,可能和原始的Lenet不太一样,不过影响不大。
四 生成solver文件
同样,可以在编辑器里面直接书写,也可以用代码生成。

- #编写一个函数,生成参数文件
- def gen_solver(solver_file,train_net,test_net):
- s=proto.caffe_pb2.SolverParameter()
- s.train_net =train_net
- s.test_net.append(test_net)
- s.test_interval = 938 #60000/64,测试间隔参数:训练完一次所有的图片,进行一次测试
- s.test_iter.append(100) #10000/100 测试迭代次数,需要迭代100次,才完成一次所有数据的测试
- s.max_iter = 9380 #10 epochs , 938*10,最大训练次数
- s.base_lr = 0.01 #基础学习率
- s.momentum = 0.9 #动量
- s.weight_decay = 5e-4 #权值衰减项
- s.lr_policy = 'step' #学习率变化规则
- s.stepsize=3000 #学习率变化频率
- s.gamma = 0.1 #学习率变化指数
- s.display = 20 #屏幕显示间隔
- s.snapshot = 938 #保存caffemodel的间隔
- s.snapshot_prefix =root+'mnist/lenet' #caffemodel前缀
- s.type ='SGD' #优化算法
- s.solver_mode = proto.caffe_pb2.SolverParameter.GPU #加速
- #写入solver.prototxt
- with open(solver_file, 'w') as f:
- f.write(str(s))
五 开始训练模型
训练过程中,也在不停的测试。
- #开始训练
- def training(solver_proto):
- caffe.set_device(0)
- caffe.set_mode_gpu()
- solver = caffe.SGDSolver(solver_proto)
- solver.solve()
最后,调用以上的函数就可以了。
- if __name__ == '__main__':
- write_net()
- gen_solver(solver_proto,train_proto,test_proto)
- training(solver_proto)
六 完成的python文件
mnist.py
我将此文件放在根目录下的mnist文件夹下,因此可用以下代码执行
- sudo python mnist/mnist.py
在训练过程中,会保存一些caffemodel。多久保存一次,保存多少次,都可以在solver参数文件里进行设置。
我设置为训练10 epoch,9000多次,测试精度可以达到99%
caffe的python接口学习(4)mnist实例手写数字识别的更多相关文章
- caffe的python接口学习(4):mnist实例---手写数字识别
深度学习的第一个实例一般都是mnist,只要这个例子完全弄懂了,其它的就是举一反三的事了.由于篇幅原因,本文不具体介绍配置文件里面每个参数的具体函义,如果想弄明白的,请参看我以前的博文: 数据层及参数 ...
- keras实现mnist数据集手写数字识别
一. Tensorflow环境的安装 这里我们只讲CPU版本,使用 Anaconda 进行安装 a.首先我们要安装 Anaconda 链接:https://pan.baidu.com/s/1AxdGi ...
- NN:利用深度学习之神经网络实现手写数字识别(数据集50000张图片)—Jason niu
import mnist_loader import network training_data, validation_data, test_data = mnist_loader.load_dat ...
- 分类-MNIST(手写数字识别)
这是学习<Hands-On Machine Learning with Scikit-Learn and TensorFlow>的笔记,如果此笔记对该书有侵权内容,请联系我,将其删除. 这 ...
- CNN完成mnist数据集手写数字识别
# coding: utf-8 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data d ...
- mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)
前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...
- [Python]基于CNN的MNIST手写数字识别
目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...
- 深度学习之 mnist 手写数字识别
深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...
- 用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别
用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 http://phunter.farbox.com/post/mxnet-tutorial1 用MXnet实战深度学 ...
随机推荐
- 【JAVA习题十五】两个乒乓球队进行比赛,各出三人。甲队为a,b,c三人,乙队为x,y,z三人。已抽签决定比赛名单。有人向队员打听比赛的名单。a说他不和x比,c说他不和x,z比,请编程序找出三队赛手的名单。
package erase; public class 选人比赛 { public static void main(String[] args) { // TODO Auto-generated m ...
- JAVASE(十二) Java常用类: 包装类、String类、StringBuffer类、时间日期API、其他类
个人博客网:https://wushaopei.github.io/ (你想要这里多有) 1.包装类 1 .1 八个包装类 1. 2 基本数据类型,包装类,String者之间的转换 2. ...
- Java实现 LeetCode 696 计数二进制子串(暴力)
696. 计数二进制子串 给定一个字符串 s,计算具有相同数量0和1的非空(连续)子字符串的数量,并且这些子字符串中的所有0和所有1都是组合在一起的. 重复出现的子串要计算它们出现的次数. 示例 1 ...
- PAT 人口普查
某城镇进行人口普查,得到了全体居民的生日.现请你写个程序,找出镇上最年长和最年轻的人. 这里确保每个输入的日期都是合法的,但不一定是合理的,假设已知镇上没有超过 200 岁的老人,而今天是 2014 ...
- JSP基础知识点(转传智)
一.JSP概述 1.JSP:Java Server Pages(运行在服务器端的页面).就是Servlet. 学习JSP学好的关键:时刻联想到Servlet即可. 2.JSP的原理 ...
- 记录RecyclerView的位置并进行恢复
//监听RecyclerView滚动状态 mRecyclerView.addOnScrollListener(new RecyclerView.OnScrollListener() { @Overri ...
- k8s学习-安全
4.8.安全 4.8.1.概念 一些内容可参考4.6.2.Secret的内容 说明 Kubernetes 作为一个分布式集群的管理工具,保证集群的安全性是其一个重要的任务.API Server 是集群 ...
- 指定web默认首页,导致访问路径的问题
今天写了一个登陆页面,登陆成功跳转时,url中的路径不对 这是目录结构 |-web |---login |-----login.jsp |---success |-----success.jsp 这是 ...
- 【请帮帮我】为什么www.52pjb.net总是不收录,最多只收录首页?
做的好多个网站百度搜索都百度收录了,可是在其中一个一直不百度收录?http://www.52pjb.net,求大神帮忙看看,很着急很着急
- IDEA环境Spring Boot 2.3整合Activiti 6.0,启动项目初始化表并创建核心服务
如下步骤照着抄就完事了. 一.新建一个spring boot项目,并引入相关依赖 <?xml version="1.0" encoding="UTF-8" ...