该工作的主要目的是为了练习运用pycaffe来进行神经网络一站式训练,并从多个角度来分析对应的结果。

目标:

  1. python的运用训练
  2. pycaffe的接口熟悉
  3. 卷积网络(CNN)和全连接网络(DNN)的效果差异性
  4. 学会从多个角度来分析分类结果
    • 哪些图片被分类错误并进行可视化?
    • 为什么被分错?
    • 每一类是否同等机会被分错?
    • 在迭代过程中,每一类的错误几率如何变化?
    • 是否开始被正确识别后来又被错误识别了?

测试数据集:mnist

代码:https://github.com/TiBAiL/PycaffeTrain-LeNet

环境:Ubuntu 16.04LTS训练,Windows 7+VS2013分析

关于网络架构,在caffe的训练过程中,会涉及到三种不同类型的prototxt文件,分别用于train、test(validation)以及deploy。这三种文件需要保持网络架构上的统一,方可使得程序正常工作。为了达到这一目的,可以通过程序自动生成对应的文件。这三种文件的主要区别在于输入数据层以及loss/accuracy/prob层。

与此同时,针对solver.prototxt文件,我也采用了python程序的方式进行生成。在求解过程中,为了能够统计train loss/test loss以及accuracy信息以及保存任意时刻的model参数,可以采用pycaffe提供的api接口进行处理。

代码如下:

 import numpy as np
import caffe
from caffe import layers as L, params as P, proto, to_proto # file path
root = '/home/your-account/DL-Analysis/'
train_list = root + 'mnist/mnist_train_lmdb'
test_list = root + 'mnist/mnist_test_lmdb' train_proto = root + 'mnist/LeNet/train.prototxt'
test_proto = root + 'mnist/LeNet/test.prototxt' deploy_proto = root + 'mnist/LeNet/deploy.prototxt' solver_proto = root + 'mnist/LeNet/solver.prototxt' def LeNet(data_list, batch_size, IncludeAccuracy = False, deploy = False):
"""
LeNet define
""" if not(deploy):
data, label = L.Data(source = data_list,
backend = P.Data.LMDB,
batch_size = batch_size,
ntop = 2,
transform_param = dict(scale = 0.00390625))
else:
data = L.Input(input_param = {'shape': {'dim': [64, 1, 28, 28]}}) conv1 = L.Convolution(data,
kernel_size = 5,
stride = 1,
num_output = 20,
pad = 0,
weight_filler = dict(type = 'xavier')) pool1 = L.Pooling(conv1,
pool = P.Pooling.MAX,
kernel_size = 2,
stride = 2) conv2 = L.Convolution(pool1,
kernel_size = 5,
stride = 1,
num_output = 50,
pad = 0,
weight_filler = dict(type = 'xavier')) pool2 = L.Pooling(conv2,
pool = P.Pooling.MAX,
kernel_size = 2,
stride = 2) ip1 = L.InnerProduct(pool2,
num_output = 500,
weight_filler = dict(type = 'xavier')) relu1 = L.ReLU(ip1,
in_place = True) ip2 = L.InnerProduct(relu1,
num_output = 10,
weight_filler = dict(type = 'xavier')) #loss = L.SoftmaxWithLoss(ip2, label) if ( not(IncludeAccuracy) and not(deploy) ):
# train net
loss = L.SoftmaxWithLoss(ip2, label)
return to_proto(loss) elif ( IncludeAccuracy and not(deploy) ):
# test net
loss = L.SoftmaxWithLoss(ip2, label)
Accuracy = L.Accuracy(ip2, label)
return to_proto(loss, Accuracy) else:
# deploy net
prob = L.Softmax(ip2)
return to_proto(prob) def WriteNet():
"""
write proto to file
""" # train net
with open(train_proto, 'w') as file:
file.write( str(LeNet(train_list, 64, IncludeAccuracy = False, deploy = False)) ) # test net
with open(test_proto, 'w') as file:
file.write( str(LeNet(test_list, 100, IncludeAccuracy = True, deploy = False)) ) # deploy net
with open(deploy_proto, 'w') as file:
file.write( str(LeNet('not need', 64, IncludeAccuracy = False, deploy = True)) ) def GenerateSolver(solver_file, train_net, test_net):
"""
generate the solver file
""" s = proto.caffe_pb2.SolverParameter()
s.train_net = train_net
s.test_net.append(test_net)
s.test_interval = 100
s.test_iter.append(100)
s.max_iter = 10000
s.base_lr = 0.01
s.momentum = 0.9
s.weight_decay = 5e-4
s.lr_policy = 'step'
s.stepsize = 3000
s.gamma = 0.1
s.display = 100
s.snapshot = 0
s.snapshot_prefix = './lenet'
s.type = 'SGD'
s.solver_mode = proto.caffe_pb2.SolverParameter.GPU with open(solver_file, 'w') as file:
file.write( str(s) ) def Training(solver_file):
"""
training
""" caffe.set_device(0)
caffe.set_mode_gpu()
solver = caffe.get_solver(solver_file)
#solver.solve() # solve completely number_iteration = 10000 # collect the information
display = 100 # test information
test_iteration = 100
test_interval = 100 # loss and accuracy information
train_loss = np.zeros( np.ceil(number_iteration * 1.0 / display) )
test_loss = np.zeros( np.ceil(number_iteration * 1.0 / test_interval) )
test_accuracy = np.zeros( np.ceil(number_iteration * 1.0 / test_interval) ) # tmp variables
_train_loss = 0; _test_loss = 0; _test_accuracy = 0; # main loop
for iter in range(number_iteration):
solver.step(1) # save model during training
if iter in [10, 30, 60, 100, 300, 600, 1000, 3000, 6000, number_iteration - 1]:
string = 'lenet_iter_%(iter)d.caffemodel'%{'iter': iter}
solver.net.save(string) if 0 == iter % display:
train_loss[iter // display] = solver.net.blobs['SoftmaxWithLoss1'].data '''
# accumulate the train loss
_train_loss += solver.net.blobs['SoftmaxWithLoss1'].data
if 0 == iter % display:
train_loss[iter // display] = _train_loss / display
_train_loss = 0
''' if 0 == iter % test_interval:
for test_iter in range(test_iteration):
solver.test_nets[0].forward()
_test_loss += solver.test_nets[0].blobs['SoftmaxWithLoss1'].data
_test_accuracy += solver.test_nets[0].blobs['Accuracy1'].data test_loss[iter / test_interval] = _test_loss / test_iteration
test_accuracy[iter / test_interval] = _test_accuracy / test_iteration
_test_loss = 0
_test_accuracy = 0 # save for analysis
np.save('./train_loss.npy', train_loss)
np.save('./test_loss.npy', test_loss)
np.save('./test_accuracy.npy', test_accuracy) if __name__ == '__main__':
WriteNet()
GenerateSolver(solver_proto, train_proto, test_proto)
Training(solver_proto)

利用上述代码训练出来的model进行预测,并对结果进行分析(相关分析代码参见上述链接):

  • 利用第10000步(最后一步)时的model进行预测,其分类错误率为0.91%。为了能够直观的观察哪些图片被分类错误,这里我们给出了所有分类错误的图片。在对应标题中,第一个数字为预测值,第二个数字为实际真实值。从中我们可以看到,如红框所示,有许多数字确实是鬼斧神工,人都几乎无法有效区分。

  • 现在,我们来看一看,针对每一类,其究竟被分成了哪些数字?这个其实可以从上图看出,这里我给出他们的柱状图,其中子标题表示真实的label,横坐标表示被错误分类的label,纵坐标表示数量,最后一个子图表示在所有的错误分类中,每一类所占的比例。从中可以看出,3/5,4/6,7/2,9/4等较容易混淆,这个也可以非常容易理解。此外,我们也可以发现,数字1最容易分辨,这个可能是因为每个人写1都比较相似,变体较少导致的。

  • 接着,让我们来考察一下,在迭代过程中,每个数字的分类准确度是如何变化的。其中,子标题表示真实的label,横坐标表示迭代步数,纵坐标表示分类错误的数量,最后一个子图表示迭代过程中,总的错误率的变化曲线。从该图中,我们可以看出,在迭代过程中,一个图片是有可能先被分类正确后来又被分类错误的(各子曲线并不呈现单调递减的关系),这点也可以从中间变量中进行定量分析看出(代码中有该变量)。

  • 关于train loss、test loss以及accuracy的曲线。注意,这里的train loss是某一步的值(因此具有强随机性),而test loss以及accuracy则是100次的平均值(因此较为平滑)

  • 最后,让我们来分析一下全连接网络(DNN)的结果。所采用的网络架构为LeNet-300-100。其最终的分类错误率为3.6%。这里我仅给出按比例随机挑选的100张错误图片,以及在迭代过程中每个数字错误率的变换,如下所示。可以看到,DNN网络在第10步时其错误率高达80%左右,而CNN网络在该步时的错误率为30%左右,这其中是否有某种深刻内涵呢?

基于pycaffe的网络训练和结果分析(mnist数据集)的更多相关文章

  1. Cacti 是一套基于PHP,MySQL,SNMP及RRDTool开发的网络流量监测图形分析工具

    Cacti 是一套基于PHP,MySQL,SNMP及RRDTool开发的网络流量监测图形分析工具. mysqlreport是mysql性能监测时最常用的工具,对了解mysql运行状态和配置调整都有很大 ...

  2. 抓住“新代码”的影子 —— 基于GoAhead系列网络摄像头多个漏洞分析

    PDF 版本下载:抓住“新代码”的影子 —— 基于GoAhead系列网络摄像头多个漏洞分析 Author:知道创宇404实验室 Date:2017/03/19 一.漏洞背景 GoAhead作为世界上最 ...

  3. 基于PySpark的网络服务异常检测系统 (四) Mysql与SparkSQL对接同步数据 kmeans算法计算预测异常

    基于Django Restframework和Spark的异常检测系统,数据库为MySQL.Redis, 消息队列为Celery,分析服务为Spark SQL和Spark Mllib,使用kmeans ...

  4. 基于孪生卷积网络(Siamese CNN)和短时约束度量联合学习的tracklet association方法

    基于孪生卷积网络(Siamese CNN)和短时约束度量联合学习的tracklet association方法 Siamese CNN Temporally Constrained Metrics T ...

  5. Android-蓝牙的网络共享与连接分析

    一.概述 本次分析是基于android7.0的源码,主要是介绍如何通过反射来打开蓝牙的网络共享以及互联网的连接. 二.蓝牙的网络共享 1. 网络共享部分源码分析 关于packages/apps/Set ...

  6. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  7. 开源网络抓包与分析框架学习-Packetbeat篇

    开源简介packbeat是一个开源的实时网络抓包与分析框架,内置了很多常见的协议捕获及解析,如HTTP.MySQL.Redis等.在实际使用中,通常和Elasticsearch以及kibana联合使用 ...

  8. 基于ArcGIS for Server的服务部署分析 分类: ArcGIS for server 云计算 2015-07-26 21:28 11人阅读 评论(0) 收藏

    谨以此纪念去年在学海争锋上的演讲. ---------------------------------------------------- 基于ArcGIS for Server的服务部署分析 -- ...

  9. 卷积网络训练太慢?Yann LeCun:已解决CIFAR-10,目标 ImageNet

    原文连接:http://blog.kaggle.com/2014/12/22/convolutional-nets-and-cifar-10-an-interview-with-yan-lecun/ ...

随机推荐

  1. 洛谷P1004 方格取数-四维DP

    题目描述 设有 N \times NN×N 的方格图 (N \le 9)(N≤9) ,我们将其中的某些方格中填入正整数,而其他的方格中则放入数字 00 .如下图所示(见样例): A 0 0 0 0 0 ...

  2. A. Vasya and Chocolate

    链接 [http://codeforces.com/contest/1065/problem/A] 分析 一个公式完事 代码 #include<bits/stdc++.h> using n ...

  3. hots团队项目终审报告

    一.团队成员: 徐钧鸿: 1994年1月19日生人,摩羯座最后一天.所以有摩羯的强迫症和水瓶古怪的性格 暂且算队长吧…… 高中的时候因为兴趣学了竞赛,于是就入坑了,于是就来北航学计算机了 兴趣面很广, ...

  4. HanderBar

    对于java开发,涉及到页面展示时,比较主流的有两种解决方案: 1. struts2+vo+el表达式. 这种方式,重点不在于struts2,而是vo和el表达式,其基本思想是:根据页面需要的信息,构 ...

  5. 【转】STM32和ARM的区别

    转自:http://www.cnblogs.com/nuc-boy/archive/2012/09/11/2680157.html 这个问题大概2009年的时候很多人就在问,请看09年的时候大家给出的 ...

  6. MySQLi面向对象实践--select

    对于update.insert.delete请参考http://www.cnblogs.com/-beyond/p/8457580.html 执行select,如果SQL语句执行成功,那么返回的是一个 ...

  7. spring-web-4.3.3与spring-webmvc-4.3.3的区别

    spring-web-4.3.3 http(http协议的实现类)和web包(应用,上下文,会话,cookies,过滤器等等) spring-webmvc-4.3.3 主要是一些view层的核心封装, ...

  8. python学习笔记七——字典

    4.3 字典结构 字典是Python中重要的数据类型,字典的由“键-值”对组成的集合,字典中的“值”通过“键”来引用. 4.3.1 字典的创建 字典由一系列的“键-值”(key-value)对组成,“ ...

  9. jdk1.8 HashMap的keySet方法详解

    我在看HashMap源码的时候有一个问题让我产生了兴趣,那就是HashMap的keySet方法,没有调用HashMap的有关数据的任何方法就能获取到map的所有的键,他是怎么做到的,然后我就通过模拟k ...

  10. Get The Treasury HDU - 3642(扫描线求三维面积交。。体积交)

    题意: ...就是求体积交... 解析: 把每一层z抽出来,计算面积交, 然后加起来即可..! 去看一下 二维面积交的代码 再看看这个三维面积交的代码.. down函数里 你发现了什么规律!!! 参考 ...