pycaffe训练的完整组件示例

为什么写这篇博客

1. 需要用到pycaffe

因为用到的开源代码基于Caffe;要维护的项目基于Caffe。基本上是用Caffe的Python接口。

2. 训练中想穿插验证并输出关注的指标

比如每训练完1个epoch就应该在完整的validation集合上执行evaluation,输出测量出的、关注的指标,例如AP、Accuracy、F1-score等。Caffe通过solver.prototxt中配置test_net能执行测试,但基本只能输出Accuracy而且是各个test_batch上的平均Accuracy,而不是想关注的验证集整体上的AP(见Solver.cpp源码)

3. 训练中期望有可视化输出

Caffe训练输出在屏幕终端,也可自行重定向到日志文件。的确可以自行解析日志文件,并结合flask搭建web页面实时显示输出。但是这不够标准和鲁棒。期望有专门的可视化工具,避免自己造难用的轮子。

本文给出很简陋的pyCaffe和VisualDL结合的例子。

解决方案

用pycaffe接管训练接口

通过自行编写python代码来执行训练,而不是用$CAFFE_ROOT/build/tools/caffe train --solver solver.prototxt的方式来启动。

  • solver.prototxt中需要配置test_net, test_iter, test_interval,保证solver有test_net对象
  • test_interval设置为999999999,以避开Solver.cpp中执行的TestAll()函数,转而在python代码中手动判断和执行validation
  • 执行validation之前注意test_net.share_with(train_net)
  • 利用solver.step(1)执行训练网络的一次迭代,利用solver.test_net[0].forward()执行测试网络的一次前传
  • 利用net.blobs['prob'].data的形式取出网络输出
  • 利用sklearn.metrics包,将取出的数据执行evaluation
  • 利用VisualDL等可视化工具,将取出的数据执行绘图

依赖项

VisualDL,是PaddlePaddle和ECharts团队联合推出的,应该是对抗谷歌的Tensorboarde的。相信ECharts的实力。

sudo pip install visualdl

看起来VisualDL和Tensorboard类似,不过对于Caffe,用不了Tensorboard,能用VisualDL也是好事。

参考代码

solve.py

#!/usr/bin/env python2
# coding: utf-8 """
inspired and adapted from:
- https://github.com/shelhamer/fcn.berkeleyvision.org
- https://github.com/rbgirshick/py-faster-rcnn
- https://github.com/PaddlePaddle/VisualDL/blob/develop/docs/quick_start_en.md
""" from __future__ import print_function
import _init_paths
import caffe
import argparse
import os
import sys
from datetime import datetime
import cv2 from caffe.proto import caffe_pb2
import google.protobuf as pb2
import google.protobuf.text_format
import numpy as np
import perfeval from visualdl import LogWriter #for visualization during training def parse_args():
"""Parse input arguments"""
parser = argparse.ArgumentParser(description='Train a classification network')
parser.add_argument('--solver', dest='solver',
help='solver prototxt',
default=None, type=str, required=True) parser.add_argument('--weights', dest='pretrained_model',
help='initialize with pretrained model weights',
default=None, type=str) if len(sys.argv) == 1:
parser.print_help()
sys.exit(1) args = parser.parse_args()
return args class SolverWrapper:
"""对于Solver进行封装,便于外部调用"""
def __init__(self, solver_prototxt, num_epoch, num_example, pretrained_model=None):
self.solver = caffe.SGDSolver(solver_prototxt)
if pretrained_model is not None:
print('Loading pretrained model weights from {:s}'.format(pretrained_model))
self.solver.net.copy_from(pretrained_model) self.solver_param = caffe_pb2.SolverParameter()
with open(solver_prototxt, 'rt') as f:
pb2.text_format.Merge(f.read(), self.solver_param)
self.cur_epoch = 0
self.test_interval = 100 #用来替代self.solver_param.test_interval
self.logw = LogWriter("catdog_log", sync_cycle=100)
with self.logw.mode('train') as logger:
self.sc_train_loss = logger.scalar("loss")
self.sc_train_acc = logger.scalar("Accuracy")
with self.logw.mode('val') as logger:
self.sc_val_acc = logger.scalar("Accuracy")
self.sc_val_mAP = logger.scalar("mAP") def train_model(self):
"""执行训练的整个流程,穿插了validation"""
cur_iter = 0
test_batch_size, num_classes = self.solver.test_nets[0].blobs['prob'].shape
num_test_images_tot = test_batch_size * self.solver_param.test_iter[0]
while cur_iter < self.solver_param.max_iter:
#self.solver.step(self.test_interval)
for i in range(self.test_interval):
self.solver.step(1)
loss = self.solver.net.blobs['loss'].data
acc = self.solver.net.blobs['accuracy'].data
step = self.solver.iter
self.sc_train_loss.add_record(step, loss)
self.sc_train_acc.add_record(step, acc) self.eval_on_val(num_classes, num_test_images_tot, test_batch_size)
cur_iter += self.test_interval def eval_on_val(self, num_classes, num_test_images_tot, test_batch_size):
"""在整个验证集上执行inference和evaluation"""
self.solver.test_nets[0].share_with(self.solver.net)
self.cur_epoch += 1
scores = np.zeros((num_classes, num_test_images_tot), dtype=float)
gt_labels = np.zeros((1, num_test_images_tot), dtype=float).squeeze()
for t in range(self.solver_param.test_iter[0]):
output = self.solver.test_nets[0].forward()
probs = output['prob']
labels = self.solver.test_nets[0].blobs['label'].data gt_labels[t*test_batch_size:(t+1)*test_batch_size] = labels.T.astype(float)
scores[:,t*test_batch_size:(t+1)*test_batch_size] = probs.T ap, acc = perfeval.cls_eval(scores, gt_labels)
print('====================================================================\n')
print('\tDo validation after the {:d}-th training epoch\n'.format(self.cur_epoch))
print('>>>>', end='\t') #设定标记,方便于解析日志获取出数据
for i in range(num_classes):
print('AP[{:d}]={:.2f}'.format(i, ap[i]), end=', ')
mAP = np.average(ap)
print('mAP={:.2f}, Accuracy={:.2f}'.format(mAP, acc))
print('\n====================================================================\n')
step = self.solver.iter
self.sc_val_mAP.add_record(step, mAP)
self.sc_val_acc.add_record(step, acc) if __name__ == '__main__':
args = parse_args()
solver_prototxt = args.solver
num_epoch = args.num_epoch
num_batch = args.num_batch
pretrained_model = args.pretrained_model # init
caffe.set_mode_gpu()
caffe.set_device(0) sw = SolverWrapper(solver_prototxt, num_epoch, num_batch, pretrained_model)
sw.train_model()

perfeval.py

#!/usr/bin/env python2
# coding: utf-8 from __future__ import print_function
import numpy as np import sklearn.metrics as metrics def cls_eval(scores, gt_labels):
"""
分类任务的evaluation
@param scores: cxm np-array, m为样本数量(例如一个epoch)
@param gt_labels: 1xm np-array, 元素属于{0,1,2,...,K-1},表示K个类别的索引
"""
num_classes, num_test_imgs = scores.shape pred_labels = scores.argmax(axis=0) ap = np.zeros((num_classes, 1), dtype=float).squeeze()
for i in range(num_classes):
cls_labels = np.zeros((1, num_test_imgs), dtype=float).squeeze()
for j in range(num_test_imgs):
if gt_labels[j]==i:
cls_labels[j]=1
ap[i] = metrics.average_precision_score(cls_labels, scores[i]) acc = metrics.accuracy_score(gt_labels, pred_labels) return ap, acc

样例输出

首先需要开启训练,比如:

python solve.py

然后启动VisualDL:

visualDL --logdir=catdog_log --port=8080

打开浏览器获取训练的实时更新的绘图输出:http://localhost:8080。这里仅截图展示:





pycaffe训练的完整组件示例的更多相关文章

  1. 利用webuploader插件上传图片文件,完整前端示例demo,服务端使用SpringMVC接收

    利用WebUploader插件上传图片文件完整前端示例demo,服务端使用SpringMVC接收 Webuploader简介   WebUploader是由Baidu WebFE(FEX)团队开发的一 ...

  2. Vue列表组件与弹窗组件示例

    列表组件 <!DOCTYPE html> <html> <head> <meta charset="utf-8" /> <me ...

  3. [Nginx]Nginx的基本配置与优化1(完整配置示例与虚拟主机配置)

    ---------------------------------------------------------------------------------------- 完整配置示例: [ n ...

  4. 实战SpringCloud响应式微服务系列教程(第十章)响应式RESTful服务完整代码示例

    本文为实战SpringCloud响应式微服务系列教程第十章,本章给出响应式RESTful服务完整代码示例.建议没有之前基础的童鞋,先看之前的章节,章节目录放在文末. 1.搭建响应式RESTful服务. ...

  5. [deviceone开发]-do_Socket组件示例

    一.简介 do_Socket只实现了socket的客户端的功能,这个示例完整了展示了组件的基本用法,需要和sockettest3工具配合使用,sockettest3做为一个socket server来 ...

  6. SpringMVC札集(01)——SpringMVC入门完整详细示例(上)

    自定义View系列教程00–推翻自己和过往,重学自定义View 自定义View系列教程01–常用工具介绍 自定义View系列教程02–onMeasure源码详尽分析 自定义View系列教程03–onL ...

  7. android四大组件学习总结以及各个组件示例(1)

    android四大组件分别为activity.service.content provider.broadcast receiver. 一.android四大组件详解 1.activity (1)一个 ...

  8. asp.net core封装layui组件示例分享

    用什么封装?自然是TagHelper啊,是啥?自己瞅文档去 在学习使用TagHelper的时候,最希望的就是能有个Demo能够让自己作为参考 怎么去封装一个组件? 不同的情况怎么去实现? 有没有更好更 ...

  9. WebRTC 音频采样算法 附完整C++示例代码

    之前有大概介绍了音频采样相关的思路,详情见<简洁明了的插值音频重采样算法例子 (附完整C代码)>. 音频方面的开源项目很多很多. 最知名的莫过于谷歌开源的WebRTC, 其中的音频模块就包 ...

随机推荐

  1. 记录linux 命令

    1.du:查询文件或文件夹的磁盘使用空间 如果当前目录下文件和文件夹很多,使用不带参数du的命令,可以循环列出所有文件和文件夹所使用的空间.这对查看究竟是那个地方过大是不利的,所以得指定深入目录的层数 ...

  2. LwIP Application Developers Manual5---高层协议之DHCP,AUTOIP,SNMP,PPP

    1.前言 本文主要讲述高层协议,包括DHCP 2.DHCP 2.1 从应用的角度看DHCP 你必须确保在编译和链接时使能DHCP,可通过在文件lwipopts.h里面定义LWIP_DHCP选项,该选项 ...

  3. 题解-bzoj2560 串珠子

    刚被教练数落了一通,心情不好,来写篇题解 Problem bzoj2560 题目简述:给定\(n\)个点的,每两个点\(i,j\)之间有\(c_{i,j}\)条直接相连的路(其中只能选一条或不选),问 ...

  4. 题解-BOI 2004 Sequence

    Problem bzoj & Luogu 题目大意: 给定序列\(\{a_i\}\),求一个严格递增序列\(\{b_i\}\),使得\(\sum \bigl |a_i-b_i\bigr|\)最 ...

  5. CodeForces 937D 936B Sleepy Game 有向图判环,拆点,DFS

    题意: 一种游戏,2个人轮流控制棋子在一块有向图上移动,每次移动一条边,不能移动的人为输,无限循环则为平局,棋子初始位置为$S$ 现在有一个人可以同时控制两个玩家,问是否能使得第一个人必胜,并输出一个 ...

  6. sqlserver 2012 分页

    --2012的OFFSET分页方式 select number from spt_values where type='p' order by number offset 10 rows fetch ...

  7. python笔记---需求文件requirements.txt的创建及使用

    /******************************************/ 感谢大家一直以来的关注与支持. 本站迁移至 http://qingkang.me 也欢迎大家继续关注. /** ...

  8. 让NotePad++添加到右键快捷方式

    添加后的效果图: 操作方式: 第一步:在桌面上新建一个txt文本文档,然后将写入如下内容 Windows Registry Editor Version 5.00 [HKEY_CLASSES_ROOT ...

  9. MySQL建库建表

    一直使用SQL SERVER 数据库:最近项目使用MY SQL感觉还是有一点不适应.不过熟悉之后就会好很多. MY SQL 安装之后会有一个管理工具MySQL Workbench 感觉不太好用,数据库 ...

  10. elasticsearch中的java.io.IOException: 远程主机强迫关闭了一个现有的连接

    [2018-07-31T14:29:41,289][WARN ][o.e.x.s.t.n.SecurityNetty4HttpServerTransport] [9rTGh-y] caught exc ...