pycaffe训练的完整组件示例

为什么写这篇博客

1. 需要用到pycaffe

因为用到的开源代码基于Caffe;要维护的项目基于Caffe。基本上是用Caffe的Python接口。

2. 训练中想穿插验证并输出关注的指标

比如每训练完1个epoch就应该在完整的validation集合上执行evaluation,输出测量出的、关注的指标,例如AP、Accuracy、F1-score等。Caffe通过solver.prototxt中配置test_net能执行测试,但基本只能输出Accuracy而且是各个test_batch上的平均Accuracy,而不是想关注的验证集整体上的AP(见Solver.cpp源码)

3. 训练中期望有可视化输出

Caffe训练输出在屏幕终端,也可自行重定向到日志文件。的确可以自行解析日志文件,并结合flask搭建web页面实时显示输出。但是这不够标准和鲁棒。期望有专门的可视化工具,避免自己造难用的轮子。

本文给出很简陋的pyCaffe和VisualDL结合的例子。

解决方案

用pycaffe接管训练接口

通过自行编写python代码来执行训练,而不是用$CAFFE_ROOT/build/tools/caffe train --solver solver.prototxt的方式来启动。

  • solver.prototxt中需要配置test_net, test_iter, test_interval,保证solver有test_net对象
  • test_interval设置为999999999,以避开Solver.cpp中执行的TestAll()函数,转而在python代码中手动判断和执行validation
  • 执行validation之前注意test_net.share_with(train_net)
  • 利用solver.step(1)执行训练网络的一次迭代,利用solver.test_net[0].forward()执行测试网络的一次前传
  • 利用net.blobs['prob'].data的形式取出网络输出
  • 利用sklearn.metrics包,将取出的数据执行evaluation
  • 利用VisualDL等可视化工具,将取出的数据执行绘图

依赖项

VisualDL,是PaddlePaddle和ECharts团队联合推出的,应该是对抗谷歌的Tensorboarde的。相信ECharts的实力。

sudo pip install visualdl

看起来VisualDL和Tensorboard类似,不过对于Caffe,用不了Tensorboard,能用VisualDL也是好事。

参考代码

solve.py

#!/usr/bin/env python2
# coding: utf-8 """
inspired and adapted from:
- https://github.com/shelhamer/fcn.berkeleyvision.org
- https://github.com/rbgirshick/py-faster-rcnn
- https://github.com/PaddlePaddle/VisualDL/blob/develop/docs/quick_start_en.md
""" from __future__ import print_function
import _init_paths
import caffe
import argparse
import os
import sys
from datetime import datetime
import cv2 from caffe.proto import caffe_pb2
import google.protobuf as pb2
import google.protobuf.text_format
import numpy as np
import perfeval from visualdl import LogWriter #for visualization during training def parse_args():
"""Parse input arguments"""
parser = argparse.ArgumentParser(description='Train a classification network')
parser.add_argument('--solver', dest='solver',
help='solver prototxt',
default=None, type=str, required=True) parser.add_argument('--weights', dest='pretrained_model',
help='initialize with pretrained model weights',
default=None, type=str) if len(sys.argv) == 1:
parser.print_help()
sys.exit(1) args = parser.parse_args()
return args class SolverWrapper:
"""对于Solver进行封装,便于外部调用"""
def __init__(self, solver_prototxt, num_epoch, num_example, pretrained_model=None):
self.solver = caffe.SGDSolver(solver_prototxt)
if pretrained_model is not None:
print('Loading pretrained model weights from {:s}'.format(pretrained_model))
self.solver.net.copy_from(pretrained_model) self.solver_param = caffe_pb2.SolverParameter()
with open(solver_prototxt, 'rt') as f:
pb2.text_format.Merge(f.read(), self.solver_param)
self.cur_epoch = 0
self.test_interval = 100 #用来替代self.solver_param.test_interval
self.logw = LogWriter("catdog_log", sync_cycle=100)
with self.logw.mode('train') as logger:
self.sc_train_loss = logger.scalar("loss")
self.sc_train_acc = logger.scalar("Accuracy")
with self.logw.mode('val') as logger:
self.sc_val_acc = logger.scalar("Accuracy")
self.sc_val_mAP = logger.scalar("mAP") def train_model(self):
"""执行训练的整个流程,穿插了validation"""
cur_iter = 0
test_batch_size, num_classes = self.solver.test_nets[0].blobs['prob'].shape
num_test_images_tot = test_batch_size * self.solver_param.test_iter[0]
while cur_iter < self.solver_param.max_iter:
#self.solver.step(self.test_interval)
for i in range(self.test_interval):
self.solver.step(1)
loss = self.solver.net.blobs['loss'].data
acc = self.solver.net.blobs['accuracy'].data
step = self.solver.iter
self.sc_train_loss.add_record(step, loss)
self.sc_train_acc.add_record(step, acc) self.eval_on_val(num_classes, num_test_images_tot, test_batch_size)
cur_iter += self.test_interval def eval_on_val(self, num_classes, num_test_images_tot, test_batch_size):
"""在整个验证集上执行inference和evaluation"""
self.solver.test_nets[0].share_with(self.solver.net)
self.cur_epoch += 1
scores = np.zeros((num_classes, num_test_images_tot), dtype=float)
gt_labels = np.zeros((1, num_test_images_tot), dtype=float).squeeze()
for t in range(self.solver_param.test_iter[0]):
output = self.solver.test_nets[0].forward()
probs = output['prob']
labels = self.solver.test_nets[0].blobs['label'].data gt_labels[t*test_batch_size:(t+1)*test_batch_size] = labels.T.astype(float)
scores[:,t*test_batch_size:(t+1)*test_batch_size] = probs.T ap, acc = perfeval.cls_eval(scores, gt_labels)
print('====================================================================\n')
print('\tDo validation after the {:d}-th training epoch\n'.format(self.cur_epoch))
print('>>>>', end='\t') #设定标记,方便于解析日志获取出数据
for i in range(num_classes):
print('AP[{:d}]={:.2f}'.format(i, ap[i]), end=', ')
mAP = np.average(ap)
print('mAP={:.2f}, Accuracy={:.2f}'.format(mAP, acc))
print('\n====================================================================\n')
step = self.solver.iter
self.sc_val_mAP.add_record(step, mAP)
self.sc_val_acc.add_record(step, acc) if __name__ == '__main__':
args = parse_args()
solver_prototxt = args.solver
num_epoch = args.num_epoch
num_batch = args.num_batch
pretrained_model = args.pretrained_model # init
caffe.set_mode_gpu()
caffe.set_device(0) sw = SolverWrapper(solver_prototxt, num_epoch, num_batch, pretrained_model)
sw.train_model()

perfeval.py

#!/usr/bin/env python2
# coding: utf-8 from __future__ import print_function
import numpy as np import sklearn.metrics as metrics def cls_eval(scores, gt_labels):
"""
分类任务的evaluation
@param scores: cxm np-array, m为样本数量(例如一个epoch)
@param gt_labels: 1xm np-array, 元素属于{0,1,2,...,K-1},表示K个类别的索引
"""
num_classes, num_test_imgs = scores.shape pred_labels = scores.argmax(axis=0) ap = np.zeros((num_classes, 1), dtype=float).squeeze()
for i in range(num_classes):
cls_labels = np.zeros((1, num_test_imgs), dtype=float).squeeze()
for j in range(num_test_imgs):
if gt_labels[j]==i:
cls_labels[j]=1
ap[i] = metrics.average_precision_score(cls_labels, scores[i]) acc = metrics.accuracy_score(gt_labels, pred_labels) return ap, acc

样例输出

首先需要开启训练,比如:

python solve.py

然后启动VisualDL:

visualDL --logdir=catdog_log --port=8080

打开浏览器获取训练的实时更新的绘图输出:http://localhost:8080。这里仅截图展示:





pycaffe训练的完整组件示例的更多相关文章

  1. 利用webuploader插件上传图片文件,完整前端示例demo,服务端使用SpringMVC接收

    利用WebUploader插件上传图片文件完整前端示例demo,服务端使用SpringMVC接收 Webuploader简介   WebUploader是由Baidu WebFE(FEX)团队开发的一 ...

  2. Vue列表组件与弹窗组件示例

    列表组件 <!DOCTYPE html> <html> <head> <meta charset="utf-8" /> <me ...

  3. [Nginx]Nginx的基本配置与优化1(完整配置示例与虚拟主机配置)

    ---------------------------------------------------------------------------------------- 完整配置示例: [ n ...

  4. 实战SpringCloud响应式微服务系列教程(第十章)响应式RESTful服务完整代码示例

    本文为实战SpringCloud响应式微服务系列教程第十章,本章给出响应式RESTful服务完整代码示例.建议没有之前基础的童鞋,先看之前的章节,章节目录放在文末. 1.搭建响应式RESTful服务. ...

  5. [deviceone开发]-do_Socket组件示例

    一.简介 do_Socket只实现了socket的客户端的功能,这个示例完整了展示了组件的基本用法,需要和sockettest3工具配合使用,sockettest3做为一个socket server来 ...

  6. SpringMVC札集(01)——SpringMVC入门完整详细示例(上)

    自定义View系列教程00–推翻自己和过往,重学自定义View 自定义View系列教程01–常用工具介绍 自定义View系列教程02–onMeasure源码详尽分析 自定义View系列教程03–onL ...

  7. android四大组件学习总结以及各个组件示例(1)

    android四大组件分别为activity.service.content provider.broadcast receiver. 一.android四大组件详解 1.activity (1)一个 ...

  8. asp.net core封装layui组件示例分享

    用什么封装?自然是TagHelper啊,是啥?自己瞅文档去 在学习使用TagHelper的时候,最希望的就是能有个Demo能够让自己作为参考 怎么去封装一个组件? 不同的情况怎么去实现? 有没有更好更 ...

  9. WebRTC 音频采样算法 附完整C++示例代码

    之前有大概介绍了音频采样相关的思路,详情见<简洁明了的插值音频重采样算法例子 (附完整C代码)>. 音频方面的开源项目很多很多. 最知名的莫过于谷歌开源的WebRTC, 其中的音频模块就包 ...

随机推荐

  1. [转] 如何轻松愉快地理解条件随机场(CRF)?

    原文链接:https://www.jianshu.com/p/55755fc649b1 如何轻松愉快地理解条件随机场(CRF)?   理解条件随机场最好的办法就是用一个现实的例子来说明它.但是目前中文 ...

  2. [转] Understanding-LSTMs 理解LSTM

    图文并茂,讲得极清晰. 原文:http://colah.github.io/posts/2015-08-Understanding-LSTMs/ colah's blog Blog About Con ...

  3. 2017-2018-2 20165325 实验一《Java开发环境的熟悉》实验报告

    一.Java开发环境的熟悉-1 1.实验要求: 0 参考实验要求: 1 建立"自己学号exp1"的目录 : 2 在"自己学号exp1"目录下建立src,bin等 ...

  4. aiojobs

    import asyncio import aiojobs async def coro(timeout): print(timeout) await asyncio.sleep(timeout) p ...

  5. liunx本地网卡流量监控

    作者:邓聪聪 公司网络异常,由于可监控设备有限,无法快速读取网络异常的设备,所以找到了这个办法,部署在服务端用以解决网络突发异常流量故障的查找! 环境:CentOS release 6.8 Linux ...

  6. 正则表达式处理BT的html嵌套问题

    在博问里面求教大神,把问题搞定.在此做个记录备份,也给碰到类似问题的园友提供解决思路. 简化的业务场景就是,在页面html标签中的属性中嵌套了html标签,怎么用用正则表达式过滤闭合的html标签(& ...

  7. python装饰器的4种类型:函数装饰函数、函数装饰类、类装饰函数、类装饰类

    一:函数装饰函数 def wrapFun(func): def inner(a, b): print('function name:', func.__name__) r = func(a, b) r ...

  8. 无线桥接(WDS)如何设置?

    一.WDS使用介绍 无线桥接(WDS)可以将多台无线路由器通过无线方式互联,从而将无线信号扩展放大.无线终端在移动过程中可以自动切换较好的信号,实现无线漫游. 本文指导将TL-WR740N当作副路由器 ...

  9. hibernate框架学习第一天:hibernate介绍及基本操作

    框架辅助开发者进行开发,半成品软件,开发者与框架进行合作开发 Hibernate3Hibernate是一种基于Java的轻量级的ORM框架 基于Java:底层实现是Java语言,可以脱离WEB,在纯J ...

  10. 005_tcp/ip监控

    system.monitor.tcpstat 一.listen+established+time wait+close wait. listen:SELECT mean("listen&qu ...