pycaffe训练的完整组件示例
pycaffe训练的完整组件示例
为什么写这篇博客
1. 需要用到pycaffe
因为用到的开源代码基于Caffe;要维护的项目基于Caffe。基本上是用Caffe的Python接口。
2. 训练中想穿插验证并输出关注的指标
比如每训练完1个epoch就应该在完整的validation集合上执行evaluation,输出测量出的、关注的指标,例如AP、Accuracy、F1-score等。Caffe通过solver.prototxt中配置test_net
能执行测试,但基本只能输出Accuracy而且是各个test_batch
上的平均Accuracy,而不是想关注的验证集整体上的AP(见Solver.cpp源码)
3. 训练中期望有可视化输出
Caffe训练输出在屏幕终端,也可自行重定向到日志文件。的确可以自行解析日志文件,并结合flask搭建web页面实时显示输出。但是这不够标准和鲁棒。期望有专门的可视化工具,避免自己造难用的轮子。
本文给出很简陋的pyCaffe和VisualDL结合的例子。
解决方案
用pycaffe接管训练接口
通过自行编写python代码来执行训练,而不是用$CAFFE_ROOT/build/tools/caffe train --solver solver.prototxt
的方式来启动。
- solver.prototxt中需要配置test_net, test_iter, test_interval,保证solver有test_net对象
- test_interval设置为999999999,以避开Solver.cpp中执行的TestAll()函数,转而在python代码中手动判断和执行validation
- 执行validation之前注意test_net.share_with(train_net)
- 利用solver.step(1)执行训练网络的一次迭代,利用solver.test_net[0].forward()执行测试网络的一次前传
- 利用
net.blobs['prob'].data
的形式取出网络输出 - 利用sklearn.metrics包,将取出的数据执行evaluation
- 利用VisualDL等可视化工具,将取出的数据执行绘图
依赖项
VisualDL,是PaddlePaddle和ECharts团队联合推出的,应该是对抗谷歌的Tensorboarde的。相信ECharts的实力。
sudo pip install visualdl
看起来VisualDL和Tensorboard类似,不过对于Caffe,用不了Tensorboard,能用VisualDL也是好事。
参考代码
solve.py
#!/usr/bin/env python2
# coding: utf-8
"""
inspired and adapted from:
- https://github.com/shelhamer/fcn.berkeleyvision.org
- https://github.com/rbgirshick/py-faster-rcnn
- https://github.com/PaddlePaddle/VisualDL/blob/develop/docs/quick_start_en.md
"""
from __future__ import print_function
import _init_paths
import caffe
import argparse
import os
import sys
from datetime import datetime
import cv2
from caffe.proto import caffe_pb2
import google.protobuf as pb2
import google.protobuf.text_format
import numpy as np
import perfeval
from visualdl import LogWriter #for visualization during training
def parse_args():
"""Parse input arguments"""
parser = argparse.ArgumentParser(description='Train a classification network')
parser.add_argument('--solver', dest='solver',
help='solver prototxt',
default=None, type=str, required=True)
parser.add_argument('--weights', dest='pretrained_model',
help='initialize with pretrained model weights',
default=None, type=str)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
args = parser.parse_args()
return args
class SolverWrapper:
"""对于Solver进行封装,便于外部调用"""
def __init__(self, solver_prototxt, num_epoch, num_example, pretrained_model=None):
self.solver = caffe.SGDSolver(solver_prototxt)
if pretrained_model is not None:
print('Loading pretrained model weights from {:s}'.format(pretrained_model))
self.solver.net.copy_from(pretrained_model)
self.solver_param = caffe_pb2.SolverParameter()
with open(solver_prototxt, 'rt') as f:
pb2.text_format.Merge(f.read(), self.solver_param)
self.cur_epoch = 0
self.test_interval = 100 #用来替代self.solver_param.test_interval
self.logw = LogWriter("catdog_log", sync_cycle=100)
with self.logw.mode('train') as logger:
self.sc_train_loss = logger.scalar("loss")
self.sc_train_acc = logger.scalar("Accuracy")
with self.logw.mode('val') as logger:
self.sc_val_acc = logger.scalar("Accuracy")
self.sc_val_mAP = logger.scalar("mAP")
def train_model(self):
"""执行训练的整个流程,穿插了validation"""
cur_iter = 0
test_batch_size, num_classes = self.solver.test_nets[0].blobs['prob'].shape
num_test_images_tot = test_batch_size * self.solver_param.test_iter[0]
while cur_iter < self.solver_param.max_iter:
#self.solver.step(self.test_interval)
for i in range(self.test_interval):
self.solver.step(1)
loss = self.solver.net.blobs['loss'].data
acc = self.solver.net.blobs['accuracy'].data
step = self.solver.iter
self.sc_train_loss.add_record(step, loss)
self.sc_train_acc.add_record(step, acc)
self.eval_on_val(num_classes, num_test_images_tot, test_batch_size)
cur_iter += self.test_interval
def eval_on_val(self, num_classes, num_test_images_tot, test_batch_size):
"""在整个验证集上执行inference和evaluation"""
self.solver.test_nets[0].share_with(self.solver.net)
self.cur_epoch += 1
scores = np.zeros((num_classes, num_test_images_tot), dtype=float)
gt_labels = np.zeros((1, num_test_images_tot), dtype=float).squeeze()
for t in range(self.solver_param.test_iter[0]):
output = self.solver.test_nets[0].forward()
probs = output['prob']
labels = self.solver.test_nets[0].blobs['label'].data
gt_labels[t*test_batch_size:(t+1)*test_batch_size] = labels.T.astype(float)
scores[:,t*test_batch_size:(t+1)*test_batch_size] = probs.T
ap, acc = perfeval.cls_eval(scores, gt_labels)
print('====================================================================\n')
print('\tDo validation after the {:d}-th training epoch\n'.format(self.cur_epoch))
print('>>>>', end='\t') #设定标记,方便于解析日志获取出数据
for i in range(num_classes):
print('AP[{:d}]={:.2f}'.format(i, ap[i]), end=', ')
mAP = np.average(ap)
print('mAP={:.2f}, Accuracy={:.2f}'.format(mAP, acc))
print('\n====================================================================\n')
step = self.solver.iter
self.sc_val_mAP.add_record(step, mAP)
self.sc_val_acc.add_record(step, acc)
if __name__ == '__main__':
args = parse_args()
solver_prototxt = args.solver
num_epoch = args.num_epoch
num_batch = args.num_batch
pretrained_model = args.pretrained_model
# init
caffe.set_mode_gpu()
caffe.set_device(0)
sw = SolverWrapper(solver_prototxt, num_epoch, num_batch, pretrained_model)
sw.train_model()
perfeval.py
#!/usr/bin/env python2
# coding: utf-8
from __future__ import print_function
import numpy as np
import sklearn.metrics as metrics
def cls_eval(scores, gt_labels):
"""
分类任务的evaluation
@param scores: cxm np-array, m为样本数量(例如一个epoch)
@param gt_labels: 1xm np-array, 元素属于{0,1,2,...,K-1},表示K个类别的索引
"""
num_classes, num_test_imgs = scores.shape
pred_labels = scores.argmax(axis=0)
ap = np.zeros((num_classes, 1), dtype=float).squeeze()
for i in range(num_classes):
cls_labels = np.zeros((1, num_test_imgs), dtype=float).squeeze()
for j in range(num_test_imgs):
if gt_labels[j]==i:
cls_labels[j]=1
ap[i] = metrics.average_precision_score(cls_labels, scores[i])
acc = metrics.accuracy_score(gt_labels, pred_labels)
return ap, acc
样例输出
首先需要开启训练,比如:
python solve.py
然后启动VisualDL:
visualDL --logdir=catdog_log --port=8080
打开浏览器获取训练的实时更新的绘图输出:http://localhost:8080。这里仅截图展示:
pycaffe训练的完整组件示例的更多相关文章
- 利用webuploader插件上传图片文件,完整前端示例demo,服务端使用SpringMVC接收
利用WebUploader插件上传图片文件完整前端示例demo,服务端使用SpringMVC接收 Webuploader简介 WebUploader是由Baidu WebFE(FEX)团队开发的一 ...
- Vue列表组件与弹窗组件示例
列表组件 <!DOCTYPE html> <html> <head> <meta charset="utf-8" /> <me ...
- [Nginx]Nginx的基本配置与优化1(完整配置示例与虚拟主机配置)
---------------------------------------------------------------------------------------- 完整配置示例: [ n ...
- 实战SpringCloud响应式微服务系列教程(第十章)响应式RESTful服务完整代码示例
本文为实战SpringCloud响应式微服务系列教程第十章,本章给出响应式RESTful服务完整代码示例.建议没有之前基础的童鞋,先看之前的章节,章节目录放在文末. 1.搭建响应式RESTful服务. ...
- [deviceone开发]-do_Socket组件示例
一.简介 do_Socket只实现了socket的客户端的功能,这个示例完整了展示了组件的基本用法,需要和sockettest3工具配合使用,sockettest3做为一个socket server来 ...
- SpringMVC札集(01)——SpringMVC入门完整详细示例(上)
自定义View系列教程00–推翻自己和过往,重学自定义View 自定义View系列教程01–常用工具介绍 自定义View系列教程02–onMeasure源码详尽分析 自定义View系列教程03–onL ...
- android四大组件学习总结以及各个组件示例(1)
android四大组件分别为activity.service.content provider.broadcast receiver. 一.android四大组件详解 1.activity (1)一个 ...
- asp.net core封装layui组件示例分享
用什么封装?自然是TagHelper啊,是啥?自己瞅文档去 在学习使用TagHelper的时候,最希望的就是能有个Demo能够让自己作为参考 怎么去封装一个组件? 不同的情况怎么去实现? 有没有更好更 ...
- WebRTC 音频采样算法 附完整C++示例代码
之前有大概介绍了音频采样相关的思路,详情见<简洁明了的插值音频重采样算法例子 (附完整C代码)>. 音频方面的开源项目很多很多. 最知名的莫过于谷歌开源的WebRTC, 其中的音频模块就包 ...
随机推荐
- 记录linux 命令
1.du:查询文件或文件夹的磁盘使用空间 如果当前目录下文件和文件夹很多,使用不带参数du的命令,可以循环列出所有文件和文件夹所使用的空间.这对查看究竟是那个地方过大是不利的,所以得指定深入目录的层数 ...
- LwIP Application Developers Manual5---高层协议之DHCP,AUTOIP,SNMP,PPP
1.前言 本文主要讲述高层协议,包括DHCP 2.DHCP 2.1 从应用的角度看DHCP 你必须确保在编译和链接时使能DHCP,可通过在文件lwipopts.h里面定义LWIP_DHCP选项,该选项 ...
- 题解-bzoj2560 串珠子
刚被教练数落了一通,心情不好,来写篇题解 Problem bzoj2560 题目简述:给定\(n\)个点的,每两个点\(i,j\)之间有\(c_{i,j}\)条直接相连的路(其中只能选一条或不选),问 ...
- 题解-BOI 2004 Sequence
Problem bzoj & Luogu 题目大意: 给定序列\(\{a_i\}\),求一个严格递增序列\(\{b_i\}\),使得\(\sum \bigl |a_i-b_i\bigr|\)最 ...
- CodeForces 937D 936B Sleepy Game 有向图判环,拆点,DFS
题意: 一种游戏,2个人轮流控制棋子在一块有向图上移动,每次移动一条边,不能移动的人为输,无限循环则为平局,棋子初始位置为$S$ 现在有一个人可以同时控制两个玩家,问是否能使得第一个人必胜,并输出一个 ...
- sqlserver 2012 分页
--2012的OFFSET分页方式 select number from spt_values where type='p' order by number offset 10 rows fetch ...
- python笔记---需求文件requirements.txt的创建及使用
/******************************************/ 感谢大家一直以来的关注与支持. 本站迁移至 http://qingkang.me 也欢迎大家继续关注. /** ...
- 让NotePad++添加到右键快捷方式
添加后的效果图: 操作方式: 第一步:在桌面上新建一个txt文本文档,然后将写入如下内容 Windows Registry Editor Version 5.00 [HKEY_CLASSES_ROOT ...
- MySQL建库建表
一直使用SQL SERVER 数据库:最近项目使用MY SQL感觉还是有一点不适应.不过熟悉之后就会好很多. MY SQL 安装之后会有一个管理工具MySQL Workbench 感觉不太好用,数据库 ...
- elasticsearch中的java.io.IOException: 远程主机强迫关闭了一个现有的连接
[2018-07-31T14:29:41,289][WARN ][o.e.x.s.t.n.SecurityNetty4HttpServerTransport] [9rTGh-y] caught exc ...