具体代码见https://github.com/zhiyishou/py-faster-rcnn



这是我对cup, glasses训练的识别

faster-rcnn在fast-rcnn的基础上加了rpn来将整个训练都置于GPU内,以用来提高效率,这里我们将使用ImageNet的数据集来在faster-rcnn上来训练自己的分类器。从ImageNet上可下载到很多类别的Image与bounding box annotation来进行训练(每一个类别下的annotation都少于等于image的个数,所以我们从annotation来建立索引)。

lib/dataset/factory.py中提供了coco与voc的数据集获取方法,而我们要做的就是在这里加上我们自己的ImageNet获取方法,我们先来建立ImageNet数据获取主文件。coco与pascal_voc的获取都是继承于父类imdb,所以我们可根据pascal_voc的获取方法来做模板修改完成我们的ImageNet类。

创建ImageNet类

由于在faster-rcnn里使用rpn来代替了selective_search,所以我们可以在使用时直接略过有关selective_search的方法,根据pascal_voc类做模板,我们需要留下的方法有:

__init__ //初始化
image_path_at //根据数据集列表的index来取图片绝对地址
image_path_from_index //配合上面
_load_image_set_index //获取数据集列表
_gt_roidb //获取ground-truth数据
rpn_roidb //获取region proposal数据
_load_rpn_roidb //根据gt_roidb生成rpn_roidb数据并合成
_load_psacal_annotation //加载annotation文件并对bounding box进行数据整理

__init__:

def __init__(self, image_set):
imdb.__init__(self, 'imagenet')
self._image_set = image_set
self._data_path = os.path.join(cfg.DATA_DIR, "imagenet")
#类别与对应的wnid,可以修改成自己要训练的类别
self._class_wnids = {
'cup': 'n03147509',
'glasses': 'n04272054'
} #类别,修改类别时同时要修改这里
self._classes = ('__background__', self._class_wnids['cup'], self._class_wnids['glasses'])
self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
#bounding box annotation 文件的目录
self._xml_path = os.path.join(self._data_path, "Annotations")
self._image_ext = '.JPEG'
#我们使用xml文件名来做数据集的索引
# the xml file name and each one corresponding to image file name
self._image_index = self._load_xml_filenames()
self._salt = str(uuid.uuid4())
self._comp_id = 'comp4' self.config = {'cleanup' : True,
'use_salt' : True,
'use_diff' : False,
'matlab_eval' : False,
'rpn_file' : None,
'min_size' : 2} assert os.path.exists(self._data_path), \
'Path does not exist: {}'.format(self._data_path)

image_path_at

def image_path_at(self, i):
#使用index来从xml_filenames取到filename,生成绝对路径
return self.image_path_from_image_filename(self._image_index[i])

image_path_from_image_filename(类似pascal_voc中的image_path_from_index)

def image_path_from_image_filename(self, image_filename):
image_path = os.path.join(self._data_path, 'Images',
image_filename + self._image_ext)
assert os.path.exists(image_path), \
'Path does not exist: {}'.format(image_path)
return image_path

_load_xml_filenames(类似pascal_voc中的_load_image_set_index)

def _load_xml_filenames(self):
#从Annotations文件夹中拿取到bounding box annotation文件名
#用来做数据集的索引
xml_folder_path = os.path.join(self._data_path, "Annotations")
assert os.path.exists(xml_folder_path), \
'Path does not exist: {}'.format(xml_folder_path) for dirpath, dirnames, filenames in os.walk(xml_folder_path):
xml_filenames = [xml_filename.split(".")[0] for xml_filename in filenames] return xml_filenames

gt_roidb

def gt_roidb(self):
#Ground-Truth 数据缓存
cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
if os.path.exists(cache_file):
with open(cache_file, 'rb') as fid:
roidb = cPickle.load(fid)
print '{} gt roidb loaded from {}'.format(self.name, cache_file)
return roidb #从xml中获取Ground-Truth数据
gt_roidb = [self._load_imagenet_annotation(xml_filename)
for xml_filename in self._image_index]
with open(cache_file, 'wb') as fid:
cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
print 'wrote gt roidb to {}'.format(cache_file) return gt_roidb

rpn_roidb

def rpn_roidb(self):
#根据gt_roidb生成rpn_roidb,并进行合并
gt_roidb = self.gt_roidb()
rpn_roidb = self._load_rpn_roidb(gt_roidb)
roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb) return roidb

_load_rpn_roidb

def _load_rpn_roidb(self, gt_roidb):
filename = self.config['rpn_file']
print 'loading {}'.format(filename)
assert os.path.exists(filename), \
'rpn data not found at: {}'.format(filename)
with open(filename, 'rb') as f:
box_list = cPickle.load(f)
return self.create_roidb_from_box_list(box_list, gt_roidb)

_load_imagenet_annotation(类似于pascal_voc中的_load_pascal_annotation)

def _load_imagenet_annotation(self, xml_filename):
#从annotation的xml文件中拿取bounding box数据
filepath = os.path.join(self._data_path, 'Annotations', xml_filename + '.xml')
#这里使用了ap,是我写的一个annotation parser,在后面贴出代码
#它会返回这个xml文件的wnid, 图像文件名,以及里面包含的注解物体
wnid, image_name, objects = ap.parse(filepath)
num_objs = len(objects) boxes = np.zeros((num_objs, 4), dtype=np.uint16)
gt_classes = np.zeros((num_objs), dtype=np.int32)
overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
seg_areas = np.zeros((num_objs), dtype=np.float32) # Load object bounding boxes into a data frame.
for ix, obj in enumerate(objects):
box = obj["box"]
x1 = box['xmin']
y1 = box['ymin']
x2 = box['xmax']
y2 = box['ymax']
# 如果这个bounding box并不是我们想要学习的类别,那则跳过
# go next if the wnid not exist in declared classes
try:
cls = self._class_to_ind[obj["wnid"]]
except KeyError:
print "wnid %s isn't show in given"%obj["wnid"]
continue
boxes[ix, :] = [x1, y1, x2, y2]
gt_classes[ix] = cls
overlaps[ix, cls] = 1.0
seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1) overlaps = scipy.sparse.csr_matrix(overlaps) return {'boxes' : boxes,
'gt_classes': gt_classes,
'gt_overlaps' : overlaps,
'flipped' : False,
'seg_areas' : seg_areas}

annotation_parser.py文件

import os
import xml.dom.minidom def getText(node):
return node.firstChild.nodeValue def getWnid(node):
return getText(node.getElementsByTagName("name")[0]) def getImageName(node):
return getText(node.getElementsByTagName("filename")[0]) def getObjects(node):
objects = []
for obj in node.getElementsByTagName("object"):
objects.append({
"wnid": getText(obj.getElementsByTagName("name")[0]),
"box":{
"xmin": int(getText(obj.getElementsByTagName("xmin")[0])),
"ymin": int(getText(obj.getElementsByTagName("ymin")[0])),
"xmax": int(getText(obj.getElementsByTagName("xmax")[0])),
"ymax": int(getText(obj.getElementsByTagName("ymax")[0])),
}
})
return objects def parse(filepath):
dom = xml.dom.minidom.parse(filepath)
root = dom.documentElement
image_name = getImageName(root)
wnid = getWnid(root)
objects = getObjects(root) return wnid, image_name, objects

则对数据结构的要求是:

|---data
|---imagenet
|---Annotations
|---n03147509
|---n03147509_*.xml
|---...
|---n04272054
|---n04272054_*.xml
|---...
|---Images
|---n03147508_*.JPEG
|---...
|---n04272054_*.JPEG
|---...

同时我在github上也提供了draw方法,可以用来将bounding box画于Image文件上,用来甄别该annotation的正确性

训练

这样,我们的ImageNet类则是生成好了,下面我们则可以训练我们的数据,但是在开始之前,还有一件事情,那就是修改prototxt中的与类别数目有关的值,我将models/pascal_voc拷贝到了models/imagenet进行修改,比如我想要训练ZF,如果使用的是train_faster_rcnn_alt_opt.py,则需要修改models/imagenet/ZF/faster_rcnn_alt_opt/下的所有pt文件里的内容,用如下的法则去替换:

//num为类别的个数
input-data->num_classes = num
class_score->num_output = num
bbox_pred->num_output = num*4

我这里使用train_faster_rcnn_alt_opt.py进行的训练,这样的话则需要把添加的models/imagenet作为可选项

//pt_type 则是添加的选择项,默认使用psacal_voc的models
./tools/train_faster_rcnn_alt_opt.py --gpu 0 \
--net_name ZF \
--weights data/imagenet_models/ZF.v2.caffemodel[optional] \
--imdb imagenet \
--cfg experiments/cfgs/faster_rcnn_alt_opt.yml \
--pt_type imagenet

识别

这里我们则需要使用刚训练出来的模型进行识别

#就像demo.py一样,但是使用训练的models,我创建了tools/classify.py来单独识别
prototxt = os.path.join(cfg.ROOT_DIR, 'models/imagenet', NETS[args.demo_net][0], 'faster_rcnn_alt_opt', 'faster_rcnn_test.pt')
caffemodel = os.path.join(cfg.ROOT_DIR, 'output/faster_rcnn_alt_opt/imagenet/'+ NETS[args.demo_net][0] +'_faster_rcnn_final.caffemodel')

同样,在识别前我们要对识别方法里的Classes进行修改,修改成你自己训练的类别后

执行

./tools/classify.py --net zf

则可对data/demo下的图片文件使用训练的zf网络进行识别

Have fun

使用ImageNet在faster-rcnn上训练自己的分类网络的更多相关文章

  1. Faster RCNN算法训练代码解析(1)

    这周看完faster-rcnn后,应该对其源码进行一个解析,以便后面的使用. 那首先直接先主函数出发py-faster-rcnn/tools/train_faster_rcnn_alt_opt.py ...

  2. Faster RCNN算法训练代码解析(3)

    四个层的forward函数分析: RoIDataLayer:读数据,随机打乱等 AnchorTargetLayer:输出所有anchors(这里分析这个) ProposalLayer:用产生的anch ...

  3. Faster RCNN算法训练代码解析(2)

    接着上篇的博客,我们获取imdb和roidb的数据后,就可以搭建网络进行训练了. 我们回到trian_rpn()函数里面,此时运行完了roidb, imdb = get_roidb(imdb_name ...

  4. 目标检测(四)Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks

    作者:Shaoqing Ren, Kaiming He, Ross Girshick, and Jian Sun SPPnet.Fast R-CNN等目标检测算法已经大幅降低了目标检测网络的运行时间. ...

  5. Faster R-CNN利用新的网络结构来训练

    前言 最近利用Faster R-CNN训练数据,使用ZF模型,效果无法有效提高.就想尝试对ZF的网络结构进行改造,记录下具体操作. 一.更改网络,训练初始化模型 这里为了方便,我们假设更换的网络名为L ...

  6. object detection[faster rcnn]

    这部分,写一写faster rcnn 0. faster rcnn 经过了rcnn,spp,fast rcnn,又到了faster rcnn,作者在对前面的模型回顾中发现,fast rcnn提出的ro ...

  7. 基于候选区域的深度学习目标检测算法R-CNN,Fast R-CNN,Faster R-CNN

    参考文献 [1]Rich feature hierarchies for accurate object detection and semantic segmentation [2]Fast R-C ...

  8. 【神经网络与深度学习】【计算机视觉】Faster R-CNN

    Faster R-CNN Fast-RCNN基本实现端对端(除了proposal阶段外),下一步自然就是要把proposal阶段也用CNN实现(放到GPU上).这就出现了Faster-RCNN,一个完 ...

  9. Paper Reading:Faster RCNN

    Faster R-CNN 论文:Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks 发表时间: ...

随机推荐

  1. c# string 数组转 list

    string str = "1,11,121,131"; var arr = str.Split(','); List<string> list = new List& ...

  2. maven本地仓库的配置以及如何修改默认.m2仓库位置

    本地仓库是远程仓库的一个缓冲和子集,当你构建Maven项目的时候,首先会从本地仓库查找资源,如果没有,那么Maven会从远程仓库下载到你本地仓库.这样在你下次使用的时候就不需要从远程下载了.如果你所需 ...

  3. How to install Wordpress 4.0 on CentOS 7.0

    This document describes how to install and configure Wordpress 4.0 on CentOS 7.0. WordPress started ...

  4. KB975517 "The update does not apply to your system"

    https://www.manageengine.com/products//desktop-central/patch-management/Windows-Vista-Ultimate-Editi ...

  5. netty常用代码

    一. Server public class TimeServer_argu { public void bind(int port) throws InterruptedException { Ev ...

  6. python 读取sqlite3 数据库

    import sqlite3 name = "tom" age = 30 con = sqlite3.connect("d:\\test.db") cur = ...

  7. android的m、mm、mmm编译命令

    android的m.mm.mmm编译命令的使用 android源码目录下的build/envsetup.sh文件,描述编译的命令 - m:       Makes from the top of th ...

  8. GL_GL系列 - 预算管理分析(案例)

    2014-07-09 Created By BaoXinjian

  9. Form_通过Custom.pll新增菜单项(案例)

    2014-05-31 Created By BaoXinjian

  10. NeHe OpenGL教程 第三十四课:地形

    转自[翻译]NeHe OpenGL 教程 前言 声明,此 NeHe OpenGL教程系列文章由51博客yarin翻译(2010-08-19),本博客为转载并稍加整理与修改.对NeHe的OpenGL管线 ...