使用ImageNet在faster-rcnn上训练自己的分类网络
具体代码见https://github.com/zhiyishou/py-faster-rcnn
这是我对cup, glasses训练的识别
faster-rcnn在fast-rcnn的基础上加了rpn来将整个训练都置于GPU内,以用来提高效率,这里我们将使用ImageNet的数据集来在faster-rcnn上来训练自己的分类器。从ImageNet上可下载到很多类别的Image与bounding box annotation来进行训练(每一个类别下的annotation都少于等于image的个数,所以我们从annotation来建立索引)。
在lib/dataset/factory.py
中提供了coco与voc的数据集获取方法,而我们要做的就是在这里加上我们自己的ImageNet获取方法,我们先来建立ImageNet数据获取主文件。coco与pascal_voc的获取都是继承于父类imdb,所以我们可根据pascal_voc的获取方法来做模板修改完成我们的ImageNet类。
创建ImageNet类
由于在faster-rcnn里使用rpn来代替了selective_search,所以我们可以在使用时直接略过有关selective_search的方法,根据pascal_voc类做模板,我们需要留下的方法有:
__init__ //初始化
image_path_at //根据数据集列表的index来取图片绝对地址
image_path_from_index //配合上面
_load_image_set_index //获取数据集列表
_gt_roidb //获取ground-truth数据
rpn_roidb //获取region proposal数据
_load_rpn_roidb //根据gt_roidb生成rpn_roidb数据并合成
_load_psacal_annotation //加载annotation文件并对bounding box进行数据整理
__init__:
def __init__(self, image_set):
imdb.__init__(self, 'imagenet')
self._image_set = image_set
self._data_path = os.path.join(cfg.DATA_DIR, "imagenet")
#类别与对应的wnid,可以修改成自己要训练的类别
self._class_wnids = {
'cup': 'n03147509',
'glasses': 'n04272054'
}
#类别,修改类别时同时要修改这里
self._classes = ('__background__', self._class_wnids['cup'], self._class_wnids['glasses'])
self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
#bounding box annotation 文件的目录
self._xml_path = os.path.join(self._data_path, "Annotations")
self._image_ext = '.JPEG'
#我们使用xml文件名来做数据集的索引
# the xml file name and each one corresponding to image file name
self._image_index = self._load_xml_filenames()
self._salt = str(uuid.uuid4())
self._comp_id = 'comp4'
self.config = {'cleanup' : True,
'use_salt' : True,
'use_diff' : False,
'matlab_eval' : False,
'rpn_file' : None,
'min_size' : 2}
assert os.path.exists(self._data_path), \
'Path does not exist: {}'.format(self._data_path)
image_path_at
def image_path_at(self, i):
#使用index来从xml_filenames取到filename,生成绝对路径
return self.image_path_from_image_filename(self._image_index[i])
image_path_from_image_filename(类似pascal_voc中的image_path_from_index)
def image_path_from_image_filename(self, image_filename):
image_path = os.path.join(self._data_path, 'Images',
image_filename + self._image_ext)
assert os.path.exists(image_path), \
'Path does not exist: {}'.format(image_path)
return image_path
_load_xml_filenames(类似pascal_voc中的_load_image_set_index)
def _load_xml_filenames(self):
#从Annotations文件夹中拿取到bounding box annotation文件名
#用来做数据集的索引
xml_folder_path = os.path.join(self._data_path, "Annotations")
assert os.path.exists(xml_folder_path), \
'Path does not exist: {}'.format(xml_folder_path)
for dirpath, dirnames, filenames in os.walk(xml_folder_path):
xml_filenames = [xml_filename.split(".")[0] for xml_filename in filenames]
return xml_filenames
gt_roidb
def gt_roidb(self):
#Ground-Truth 数据缓存
cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
if os.path.exists(cache_file):
with open(cache_file, 'rb') as fid:
roidb = cPickle.load(fid)
print '{} gt roidb loaded from {}'.format(self.name, cache_file)
return roidb
#从xml中获取Ground-Truth数据
gt_roidb = [self._load_imagenet_annotation(xml_filename)
for xml_filename in self._image_index]
with open(cache_file, 'wb') as fid:
cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
print 'wrote gt roidb to {}'.format(cache_file)
return gt_roidb
rpn_roidb
def rpn_roidb(self):
#根据gt_roidb生成rpn_roidb,并进行合并
gt_roidb = self.gt_roidb()
rpn_roidb = self._load_rpn_roidb(gt_roidb)
roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb)
return roidb
_load_rpn_roidb
def _load_rpn_roidb(self, gt_roidb):
filename = self.config['rpn_file']
print 'loading {}'.format(filename)
assert os.path.exists(filename), \
'rpn data not found at: {}'.format(filename)
with open(filename, 'rb') as f:
box_list = cPickle.load(f)
return self.create_roidb_from_box_list(box_list, gt_roidb)
_load_imagenet_annotation(类似于pascal_voc中的_load_pascal_annotation)
def _load_imagenet_annotation(self, xml_filename):
#从annotation的xml文件中拿取bounding box数据
filepath = os.path.join(self._data_path, 'Annotations', xml_filename + '.xml')
#这里使用了ap,是我写的一个annotation parser,在后面贴出代码
#它会返回这个xml文件的wnid, 图像文件名,以及里面包含的注解物体
wnid, image_name, objects = ap.parse(filepath)
num_objs = len(objects)
boxes = np.zeros((num_objs, 4), dtype=np.uint16)
gt_classes = np.zeros((num_objs), dtype=np.int32)
overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
seg_areas = np.zeros((num_objs), dtype=np.float32)
# Load object bounding boxes into a data frame.
for ix, obj in enumerate(objects):
box = obj["box"]
x1 = box['xmin']
y1 = box['ymin']
x2 = box['xmax']
y2 = box['ymax']
# 如果这个bounding box并不是我们想要学习的类别,那则跳过
# go next if the wnid not exist in declared classes
try:
cls = self._class_to_ind[obj["wnid"]]
except KeyError:
print "wnid %s isn't show in given"%obj["wnid"]
continue
boxes[ix, :] = [x1, y1, x2, y2]
gt_classes[ix] = cls
overlaps[ix, cls] = 1.0
seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)
overlaps = scipy.sparse.csr_matrix(overlaps)
return {'boxes' : boxes,
'gt_classes': gt_classes,
'gt_overlaps' : overlaps,
'flipped' : False,
'seg_areas' : seg_areas}
annotation_parser.py文件
import os
import xml.dom.minidom
def getText(node):
return node.firstChild.nodeValue
def getWnid(node):
return getText(node.getElementsByTagName("name")[0])
def getImageName(node):
return getText(node.getElementsByTagName("filename")[0])
def getObjects(node):
objects = []
for obj in node.getElementsByTagName("object"):
objects.append({
"wnid": getText(obj.getElementsByTagName("name")[0]),
"box":{
"xmin": int(getText(obj.getElementsByTagName("xmin")[0])),
"ymin": int(getText(obj.getElementsByTagName("ymin")[0])),
"xmax": int(getText(obj.getElementsByTagName("xmax")[0])),
"ymax": int(getText(obj.getElementsByTagName("ymax")[0])),
}
})
return objects
def parse(filepath):
dom = xml.dom.minidom.parse(filepath)
root = dom.documentElement
image_name = getImageName(root)
wnid = getWnid(root)
objects = getObjects(root)
return wnid, image_name, objects
则对数据结构的要求是:
|---data
|---imagenet
|---Annotations
|---n03147509
|---n03147509_*.xml
|---...
|---n04272054
|---n04272054_*.xml
|---...
|---Images
|---n03147508_*.JPEG
|---...
|---n04272054_*.JPEG
|---...
同时我在github上也提供了draw方法,可以用来将bounding box画于Image文件上,用来甄别该annotation的正确性
训练
这样,我们的ImageNet类则是生成好了,下面我们则可以训练我们的数据,但是在开始之前,还有一件事情,那就是修改prototxt中的与类别数目有关的值,我将models/pascal_voc
拷贝到了models/imagenet
进行修改,比如我想要训练ZF,如果使用的是train_faster_rcnn_alt_opt.py,则需要修改models/imagenet/ZF/faster_rcnn_alt_opt/
下的所有pt文件里的内容,用如下的法则去替换:
//num为类别的个数
input-data->num_classes = num
class_score->num_output = num
bbox_pred->num_output = num*4
我这里使用train_faster_rcnn_alt_opt.py进行的训练,这样的话则需要把添加的models/imagenet
作为可选项
//pt_type 则是添加的选择项,默认使用psacal_voc的models
./tools/train_faster_rcnn_alt_opt.py --gpu 0 \
--net_name ZF \
--weights data/imagenet_models/ZF.v2.caffemodel[optional] \
--imdb imagenet \
--cfg experiments/cfgs/faster_rcnn_alt_opt.yml \
--pt_type imagenet
识别
这里我们则需要使用刚训练出来的模型进行识别
#就像demo.py一样,但是使用训练的models,我创建了tools/classify.py来单独识别
prototxt = os.path.join(cfg.ROOT_DIR, 'models/imagenet', NETS[args.demo_net][0], 'faster_rcnn_alt_opt', 'faster_rcnn_test.pt')
caffemodel = os.path.join(cfg.ROOT_DIR, 'output/faster_rcnn_alt_opt/imagenet/'+ NETS[args.demo_net][0] +'_faster_rcnn_final.caffemodel')
同样,在识别前我们要对识别方法里的Classes进行修改,修改成你自己训练的类别后
执行
./tools/classify.py --net zf
则可对data/demo
下的图片文件使用训练的zf网络进行识别
Have fun
使用ImageNet在faster-rcnn上训练自己的分类网络的更多相关文章
- Faster RCNN算法训练代码解析(1)
这周看完faster-rcnn后,应该对其源码进行一个解析,以便后面的使用. 那首先直接先主函数出发py-faster-rcnn/tools/train_faster_rcnn_alt_opt.py ...
- Faster RCNN算法训练代码解析(3)
四个层的forward函数分析: RoIDataLayer:读数据,随机打乱等 AnchorTargetLayer:输出所有anchors(这里分析这个) ProposalLayer:用产生的anch ...
- Faster RCNN算法训练代码解析(2)
接着上篇的博客,我们获取imdb和roidb的数据后,就可以搭建网络进行训练了. 我们回到trian_rpn()函数里面,此时运行完了roidb, imdb = get_roidb(imdb_name ...
- 目标检测(四)Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks
作者:Shaoqing Ren, Kaiming He, Ross Girshick, and Jian Sun SPPnet.Fast R-CNN等目标检测算法已经大幅降低了目标检测网络的运行时间. ...
- Faster R-CNN利用新的网络结构来训练
前言 最近利用Faster R-CNN训练数据,使用ZF模型,效果无法有效提高.就想尝试对ZF的网络结构进行改造,记录下具体操作. 一.更改网络,训练初始化模型 这里为了方便,我们假设更换的网络名为L ...
- object detection[faster rcnn]
这部分,写一写faster rcnn 0. faster rcnn 经过了rcnn,spp,fast rcnn,又到了faster rcnn,作者在对前面的模型回顾中发现,fast rcnn提出的ro ...
- 基于候选区域的深度学习目标检测算法R-CNN,Fast R-CNN,Faster R-CNN
参考文献 [1]Rich feature hierarchies for accurate object detection and semantic segmentation [2]Fast R-C ...
- 【神经网络与深度学习】【计算机视觉】Faster R-CNN
Faster R-CNN Fast-RCNN基本实现端对端(除了proposal阶段外),下一步自然就是要把proposal阶段也用CNN实现(放到GPU上).这就出现了Faster-RCNN,一个完 ...
- Paper Reading:Faster RCNN
Faster R-CNN 论文:Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks 发表时间: ...
随机推荐
- c# string 数组转 list
string str = "1,11,121,131"; var arr = str.Split(','); List<string> list = new List& ...
- maven本地仓库的配置以及如何修改默认.m2仓库位置
本地仓库是远程仓库的一个缓冲和子集,当你构建Maven项目的时候,首先会从本地仓库查找资源,如果没有,那么Maven会从远程仓库下载到你本地仓库.这样在你下次使用的时候就不需要从远程下载了.如果你所需 ...
- How to install Wordpress 4.0 on CentOS 7.0
This document describes how to install and configure Wordpress 4.0 on CentOS 7.0. WordPress started ...
- KB975517 "The update does not apply to your system"
https://www.manageengine.com/products//desktop-central/patch-management/Windows-Vista-Ultimate-Editi ...
- netty常用代码
一. Server public class TimeServer_argu { public void bind(int port) throws InterruptedException { Ev ...
- python 读取sqlite3 数据库
import sqlite3 name = "tom" age = 30 con = sqlite3.connect("d:\\test.db") cur = ...
- android的m、mm、mmm编译命令
android的m.mm.mmm编译命令的使用 android源码目录下的build/envsetup.sh文件,描述编译的命令 - m: Makes from the top of th ...
- GL_GL系列 - 预算管理分析(案例)
2014-07-09 Created By BaoXinjian
- Form_通过Custom.pll新增菜单项(案例)
2014-05-31 Created By BaoXinjian
- NeHe OpenGL教程 第三十四课:地形
转自[翻译]NeHe OpenGL 教程 前言 声明,此 NeHe OpenGL教程系列文章由51博客yarin翻译(2010-08-19),本博客为转载并稍加整理与修改.对NeHe的OpenGL管线 ...