# train_net.py
#!/usr/bin/env python # --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# -------------------------------------------------------- """Train a Fast R-CNN network on a region of interest database.""" import _init_paths
from fast_rcnn.train import get_training_roidb, train_net
from fast_rcnn.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
from datasets.factory import get_imdb
import datasets.imdb
import caffe
import argparse
import pprint
import numpy as np
import sys def parse_args():
# 运行时命令行参数
"""
Parse input arguments
"""
parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
parser.add_argument('--gpu', dest='gpu_id',
help='GPU device id to use [0]',
default=0, type=int)
parser.add_argument('--solver', dest='solver',
help='solver prototxt',
default=None, type=str)
parser.add_argument('--iters', dest='max_iters',
help='number of iterations to train',
default=40000, type=int)
parser.add_argument('--weights', dest='pretrained_model',
help='initialize with pretrained model weights',
default=None, type=str)
parser.add_argument('--cfg', dest='cfg_file',
help='optional config file',
default=None, type=str)
parser.add_argument('--imdb', dest='imdb_name',
help='dataset to train on',
default='voc_2007_trainval', type=str)
parser.add_argument('--rand', dest='randomize',
help='randomize (do not use a fixed seed)',
action='store_true')
parser.add_argument('--set', dest='set_cfgs',
help='set config keys', default=None,
nargs=argparse.REMAINDER) if len(sys.argv) == 1:
parser.print_help()
sys.exit(1) args = parser.parse_args()
return args def combined_roidb(imdb_names):
# 融合roidb,roidb来自于数据集(实验可能用到多个),所以需要combine多个数据集的roidb
#
def get_roidb(imdb_name):
imdb = get_imdb(imdb_name)
print 'Loaded dataset `{:s}` for training'.format(imdb.name)
# 设置proposal方法,这里是selective search(config.py)
imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
print 'Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD)
     # 得到用于训练的roidb,定义在train.py,进行了水平翻转,以及为原始roidb添加了一些说明性的属性
roidb = get_training_roidb(imdb)
return roidb roidbs = [get_roidb(s) for s in imdb_names.split('+')]
roidb = roidbs[0]
# 这里进行combine roidb
if len(roidbs) > 1:
for r in roidbs[1:]:
roidb.extend(r)
imdb = datasets.imdb.imdb(imdb_names)
else:
# get_imdb方法定义在dataset/factory.py,通过名字得到imdb
imdb = get_imdb(imdb_names)
return imdb, roidb if __name__ == '__main__':
args = parse_args() print('Called with args:')
print(args)
#定义在config.py,见py-faster-rcnn代码阅读2-config.py if args.cfg_file is not None:
cfg_from_file(args.cfg_file)
if args.set_cfgs is not None:
cfg_from_list(args.set_cfgs) cfg.GPU_ID = args.gpu_id print('Using config:')
pprint.pprint(cfg) if not args.randomize:
# fix the random seeds (numpy and caffe) for reproducibility
np.random.seed(cfg.RNG_SEED)
caffe.set_random_seed(cfg.RNG_SEED) # set up caffe
caffe.set_mode_gpu()
caffe.set_device(args.gpu_id) imdb, roidb = combined_roidb(args.imdb_name)
print '{:d} roidb entries'.format(len(roidb)) output_dir = get_output_dir(imdb)
print 'Output will be saved to `{:s}`'.format(output_dir)
#定义在train.py
train_net(args.solver, roidb, output_dir,
pretrained_model=args.pretrained_model,
max_iters=args.max_iters)
# train.py

"""Train a Fast R-CNN network."""

import caffe
from fast_rcnn.config import cfg
import roi_data_layer.roidb as rdl_roidb
from utils.timer import Timer
import numpy as np
import os from caffe.proto import caffe_pb2
import google.protobuf as pb2 class SolverWrapper(object):
# 对caffe solver的简单封装,封装允许我们控制snapshot过程,用于去除归一化学习到的bb回归权重
"""A simple wrapper around Caffe's solver.
This wrapper gives us control over the snapshotting process, which we
use to unnormalize the learned bounding-box regression weights.
""" def __init__(self, solver_prototxt, roidb, output_dir,
pretrained_model=None):
"""Initialize the SolverWrapper."""
self.output_dir = output_dir if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and
cfg.TRAIN.BBOX_NORMALIZE_TARGETS):
# RPN can only use precomputed normalization because there are no
# fixed statistics to compute a priori assert cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED
# 求bb回归器的target,返回box的均值和标准差
if cfg.TRAIN.BBOX_REG:
print 'Computing bounding-box regression targets...'
self.bbox_means, self.bbox_stds = \
rdl_roidb.add_bbox_regression_targets(roidb)
print 'done'
# solver,copy预训练模型的权重,set roidb
self.solver = caffe.SGDSolver(solver_prototxt)
if pretrained_model is not None:
print ('Loading pretrained model '
'weights from {:s}').format(pretrained_model)
self.solver.net.copy_from(pretrained_model) self.solver_param = caffe_pb2.SolverParameter()
with open(solver_prototxt, 'rt') as f:
pb2.text_format.Merge(f.read(), self.solver_param) self.solver.net.layers[0].set_roidb(roidb) def snapshot(self):
"""Take a snapshot of the network after unnormalizing the learned
bounding-box regression weights. This enables easy use at test-time.
"""
# 给snapshot的bb预测参数做去归一化
net = self.solver.net scale_bbox_params = (cfg.TRAIN.BBOX_REG and
cfg.TRAIN.BBOX_NORMALIZE_TARGETS and
net.params.has_key('bbox_pred')) if scale_bbox_params:
# save original values
orig_0 = net.params['bbox_pred'][0].data.copy()
orig_1 = net.params['bbox_pred'][1].data.copy() # scale and shift with bbox reg unnormalization; then save snapshot
# 去除归一化,乘标准差,加上均值
net.params['bbox_pred'][0].data[...] = \
(net.params['bbox_pred'][0].data *
self.bbox_stds[:, np.newaxis])
net.params['bbox_pred'][1].data[...] = \
(net.params['bbox_pred'][1].data *
self.bbox_stds + self.bbox_means)
# 给snapshot命名
infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
filename = (self.solver_param.snapshot_prefix + infix +
'_iter_{:d}'.format(self.solver.iter) + '.caffemodel')
filename = os.path.join(self.output_dir, filename)
# 存snapshot
net.save(str(filename))
print 'Wrote snapshot to: {:s}'.format(filename) # 只是存入的snapshot的bb参数做了去归一化用于测试,但是训练部分仍需要保持归一化的状态
if scale_bbox_params:
# restore net to original state
net.params['bbox_pred'][0].data[...] = orig_0
net.params['bbox_pred'][1].data[...] = orig_1
return filename def train_model(self, max_iters):
"""Network training loop."""
last_snapshot_iter = -1
timer = Timer()
model_paths = []
while self.solver.iter < max_iters:
# Make one SGD update
timer.tic()
self.solver.step(1)
timer.toc()
if self.solver.iter % (10 * self.solver_param.display) == 0:
print 'speed: {:.3f}s / iter'.format(timer.average_time)
# 达到预设次数保存snapshot
if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
last_snapshot_iter = self.solver.iter
model_paths.append(self.snapshot())
# 整体迭代完成后也要存snapshot
if last_snapshot_iter != self.solver.iter:
model_paths.append(self.snapshot())
return model_paths def get_training_roidb(imdb):
"""Returns a roidb (Region of Interest database) for use in training."""
if cfg.TRAIN.USE_FLIPPED:
print 'Appending horizontally-flipped training examples...'
# 水平翻转,得到更多的roidb
imdb.append_flipped_images()
print 'done' print 'Preparing training data...'
# 为原始数据集的roidb添加一些说明性的属性,max-overlap,max-classes...
rdl_roidb.prepare_roidb(imdb)
print 'done' return imdb.roidb def filter_roidb(roidb):
"""Remove roidb entries that have no usable RoIs."""
# 删掉没用的RoIs, 有效的图片必须各有前景和背景ROI
def is_valid(entry):
# Valid images have:
# (1) At least one foreground RoI OR
# (2) At least one background RoI
overlaps = entry['max_overlaps']
# find boxes with sufficient overlap
fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]
# Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &
(overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
# image is only valid if such boxes exist
valid = len(fg_inds) > 0 or len(bg_inds) > 0
return valid num = len(roidb)
filtered_roidb = [entry for entry in roidb if is_valid(entry)]
num_after = len(filtered_roidb)
print 'Filtered {} roidb entries: {} -> {}'.format(num - num_after,
num, num_after)
return filtered_roidb def train_net(solver_prototxt, roidb, output_dir,
pretrained_model=None, max_iters=40000):
"""Train a Fast R-CNN network."""
# 训练Fast R-CNN
roidb = filter_roidb(roidb)
sw = SolverWrapper(solver_prototxt, roidb, output_dir,
pretrained_model=pretrained_model) print 'Solving...'
# 训练...
model_paths = sw.train_model(max_iters)
print 'done solving'
return model_paths

py-faster-rcnn代码阅读1-train_net.py & train.py的更多相关文章

  1. py faster rcnn+ 1080Ti+cudnn5.0

    看了py-faster-rcnn上的issue,原来大家都遇到各种问题. 我要好好琢磨一下,看看到底怎么样才能更好地把GPU卡发挥出来.最近真是和GPU卡较上劲了. 上午解决了g++的问题不是. 然后 ...

  2. Faster RCNN代码理解(Python)

    转自http://www.infocool.net/kb/Python/201611/209696.html#原文地址 第一步,准备 从train_faster_rcnn_alt_opt.py入: 初 ...

  3. Faster rcnn代码理解(1)

    这段时间看了不少论文,回头看看,感觉还是有必要将Faster rcnn的源码理解一下,毕竟后来很多方法都和它有相近之处,同时理解该框架也有助于以后自己修改和编写自己的框架.好的开始吧- 这里我们跟着F ...

  4. Faster RCNN代码解析

    1.faster_rcnn_end2end训练 1.1训练入口及配置 def train(): cfg.GPU_ID = 0 cfg_file = "../experiments/cfgs/ ...

  5. Faster rcnn代码理解(2)

    接着上篇的博客,咱们继续看一下Faster RCNN的代码- 上次大致讲完了Faster rcnn在训练时是如何获取imdb和roidb文件的,主要都在train_rpn()的get_roidb()函 ...

  6. Faster R-CNN代码例子

    主要参考文章:1,从编程实现角度学习Faster R-CNN(附极简实现) 经常是做到一半发现收敛情况不理想,然后又回去看看这篇文章的细节. 另外两篇: 2,Faster R-CNN学习总结      ...

  7. Faster rcnn代码理解(4)

    上一篇我们说完了AnchorTargetLayer层,然后我将Faster rcnn中的其他层看了,这里把ROIPoolingLayer层说一下: 我先说一下它的实现原理:RPN生成的roi区域大小是 ...

  8. Faster R-CNN论文阅读摘要

    论文链接: https://arxiv.org/pdf/1506.01497.pdf 代码下载: https://github.com/ShaoqingRen/faster_rcnn (MATLAB) ...

  9. Faster rcnn代码理解(3)

    紧接着之前的博客,我们继续来看faster rcnn中的AnchorTargetLayer层: 该层定义在lib>rpn>中,见该层定义: 首先说一下这一层的目的是输出在特征图上所有点的a ...

  10. tensorflow faster rcnn 代码分析一 demo.py

    os.environ["CUDA_VISIBLE_DEVICES"]=2 # 设置使用的GPU tfconfig=tf.ConfigProto(allow_soft_placeme ...

随机推荐

  1. Leetcode——30.与所有单词相关联的字串【##】

    @author: ZZQ @software: PyCharm @file: leetcode30_findSubstring.py @time: 2018/11/20 19:14 题目要求: 给定一 ...

  2. VS2013安装及测试

    一.Visual Studio的安装 首先是Visual Studio英文版的安装,安装完成后,为了用的时候方便,我从官网下载Visual Studio 2013的语言包并安装. 二.进行单元测试. ...

  3. Oracle client 使用 .net程序连接 数据库时 出现 8.1.7 的解决办法

    1. GS产品 连接oracle数据库时出现错误图示 2. 其实解决这个问题的办法很简单 一般是 修改一下 Oracle的app 目录的权限 最简单的办法是增加 everyone 权限 然后重启机器即 ...

  4. PHP4个载入语句的区别

    4个载入语句的区别 include和require的区别: include载入文件失败时(即没有找到该文件),报一个“提示错误”,然后继续执行后续代码: requre载入文件失败时,报错并立即终止执行 ...

  5. js 复制到剪切板

    function copyTextToClipboard(text) { var copyFrom = $('<textarea/>'); copyFrom.text(text); $(' ...

  6. 去掉DataTable列中的重复行

    DataTable  dt = ds.Tables[0];    //获得 DataTable  DataView dv = new DataView(dt);DataTable dt2 = dv.T ...

  7. Ryuji doesn't want to study 2018 徐州赛区网络预赛

    题意: 1.区间求 a[l]×L+a[l+1]×(L−1)+⋯+a[r−1]×2+a[r](L is the length of [ l, r ] that equals to r - l + 1) ...

  8. dp乱写1:状态压缩dp(状压dp)炮兵阵地

    https://www.luogu.org/problem/show?pid=2704 题意: 炮兵在地图上的摆放位子只能在平地('P') 炮兵可以攻击上下左右各两格的格子: 而高原('H')上炮兵能 ...

  9. AtCoder Regular Contest 069 F - Flags

    题意: 有n个点需要摆在一个数轴上,每个点需要摆在ai这个位置或者bi上,问怎么摆能使数轴上相邻两个点之间的距离的最小值最大. 二分答案后显然是个2-sat判定问题,因为边很多而连边的又是一个区间,所 ...

  10. 洛谷_Cx的故事_解题报告_第四题70

    1.并查集求最小生成树 Code: #include <stdio.h> #include <stdlib.h>   struct node {     long x,y,c; ...