faster RCNN(keras版本)代码讲解(3)-训练流程详情
转载:https://blog.csdn.net/u011311291/article/details/81121519
https://blog.csdn.net/qq_34564612/article/details/79138876
faster RCNN(keras版本)代码讲解博客索引:
1.faster RCNN(keras版本)代码讲解(1)-概述
2.faster RCNN(keras版本)代码讲解(2)-数据准备
3.faster RCNN(keras版本)代码讲解(3)-训练流程详情
4.faster RCNN(keras版本)代码讲解(4)-共享卷积层详情
5.faster RCNN(keras版本)代码讲解(5)-RPN层详情
6.faster RCNN(keras版本)代码讲解(6)-ROI Pooling层详情
一.整体流程概述
1.输入参数,其实输入1个就行了(D:\tempFile\VOCdevkit),另外一个resnet权重只是为了加快训练,如图:
2.从VOC2007数据集中读取数据,变成想要的数据格式
3.定义生成数据的迭代器
4.定义3个网络,一个是resnet共享卷积层,一个rpn层,一个分类器层
5.进入迭代,每次只训练一张图片
6.是否要进行图片增强
7.根据特征图和定义框的比例,IOU等计算出y_train值,作为网络的label
8.训练rpn层,输出物体,和物体框的坐标
9.然后再进行分类器层层的训练
二.代码(关键部位已经给出注释)
from __future__ import division
import random
import pprint
import sys
import time
import numpy as np
from optparse import OptionParser
import pickle
from keras import backend as K
from keras.optimizers import Adam, SGD, RMSprop
from keras.layers import Input
from keras.models import Model
from keras_frcnn import config, data_generators
from keras_frcnn import losses as losses
import keras_frcnn.roi_helpers as roi_helpers
from keras.utils import generic_utils
sys.setrecursionlimit(40000)
parser = OptionParser()
parser.add_option("-p", "--path", dest="train_path", help="Path to training data.")
parser.add_option("-o", "--parser", dest="parser", help="Parser to use. One of simple or pascal_voc",
default="pascal_voc")
parser.add_option("-n", "--num_rois", type="int", dest="num_rois", help="Number of RoIs to process at once.", default=32)
parser.add_option("--network", dest="network", help="Base network to use. Supports vgg or resnet50.", default='resnet50')
parser.add_option("--hf", dest="horizontal_flips", help="Augment with horizontal flips in training. (Default=false).", action="store_true", default=False)
parser.add_option("--vf", dest="vertical_flips", help="Augment with vertical flips in training. (Default=false).", action="store_true", default=False)
parser.add_option("--rot", "--rot_90", dest="rot_90", help="Augment with 90 degree rotations in training. (Default=false).",
action="store_true", default=False)
parser.add_option("--num_epochs", type="int", dest="num_epochs", help="Number of epochs.", default=2000)
parser.add_option("--config_filename", dest="config_filename", help=
"Location to store all the metadata related to the training (to be used when testing).",
default="config.pickle")
parser.add_option("--output_weight_path", dest="output_weight_path", help="Output path for weights.", default='./model_frcnn.hdf5')
parser.add_option("--input_weight_path", dest="input_weight_path", help="Input path for weights. If not specified, will try to load default weights provided by keras.")
(options, args) = parser.parse_args()
if not options.train_path: # if filename is not given
parser.error('Error: path to training data must be specified. Pass --path to command line')
if options.parser == 'pascal_voc':
from keras_frcnn.pascal_voc_parser import get_data
elif options.parser == 'simple':
from keras_frcnn.simple_parser import get_data
else:
raise ValueError("Command line option parser must be one of 'pascal_voc' or 'simple'")
# pass the settings from the command line, and persist them in the config object
C = config.Config()
C.use_horizontal_flips = bool(options.horizontal_flips)
C.use_vertical_flips = bool(options.vertical_flips)
C.rot_90 = bool(options.rot_90)
C.model_path = options.output_weight_path
C.num_rois = int(options.num_rois)
#有基于VGG和resnet两种网络模型
if options.network == 'vgg':
C.network = 'vgg'
from keras_frcnn import vgg as nn
elif options.network == 'resnet50':
from keras_frcnn import resnet as nn
C.network = 'resnet50'
else:
print('Not a valid model')
raise ValueError
# check if weight path was passed via command line
if options.input_weight_path:
C.base_net_weights = options.input_weight_path
else:
# set the path to weights based on backend and model
C.base_net_weights = nn.get_weight_path()
all_imgs, classes_count, class_mapping = get_data(options.train_path)
print(len(all_imgs)) #所有图片的信息,图片名称,位置等
print(len(classes_count)) #dict,类别:数量,例如'chair': 1432
print(len(class_mapping)) #dict,各个类别对应的标签向量,0-19,例如chair:0,car:1
#再加入'背景'这个类别
if 'bg' not in classes_count:
classes_count['bg'] = 0
class_mapping['bg'] = len(class_mapping)
C.class_mapping = class_mapping
# 将class_mapping中的key和value对调
inv_map = {v: k for k, v in class_mapping.items()}
print('Training images per class:')
pprint.pprint(classes_count)
print('Num classes (including bg) = {}'.format(len(classes_count)))
config_output_filename = options.config_filename
with open(config_output_filename, 'wb') as config_f:
pickle.dump(C,config_f)
print('Config has been written to {}, and can be loaded when testing to ensure correct results'.format(config_output_filename))
# shuffle数据
random.shuffle(all_imgs)
num_imgs = len(all_imgs)
# 将all_imgs分为训练集和测试集
train_imgs = [s for s in all_imgs if s['imageset'] == 'trainval']
val_imgs = [s for s in all_imgs if s['imageset'] == 'test']
print('Num train samples {}'.format(len(train_imgs)))
print('Num val samples {}'.format(len(val_imgs)))
# 生成anchor
data_gen_train = data_generators.get_anchor_gt(train_imgs, classes_count, C, nn.get_img_output_length, K.image_dim_ordering(), mode='train')
# data_gen_train = data_generators.get_anchor_gt(train_imgs, classes_count, C, nn.get_img_output_length, K.image_dim_ordering(), mode='train')
data_gen_val = data_generators.get_anchor_gt(val_imgs, classes_count, C, nn.get_img_output_length,K.image_dim_ordering(), mode='val')
#查看后端是th还是tf,纠正输入方式
if K.image_dim_ordering() == 'th':
input_shape_img = (3, None, None)
else:
input_shape_img = (None, None, 3)
img_input = Input(shape=input_shape_img)
roi_input = Input(shape=(None, 4))
# define the base network (resnet here, can be VGG, Inception, etc)
#定义nn的输入层,还有faster rcnn共享卷积层
shared_layers = nn.nn_base(img_input, trainable=True)
print("shared_layers",shared_layers.shape)
# define the RPN, built on the base layers
#获取anchor的个数,3重基准大小快,3种比例框,3*3=9
num_anchors = len(C.anchor_box_scales) * len(C.anchor_box_ratios)
#定义rpn层,return [x_class, x_regr, base_layers]
rpn = nn.rpn(shared_layers, num_anchors)
classifier = nn.classifier(shared_layers, roi_input, C.num_rois, nb_classes=len(classes_count), trainable=True)
#定义rpn模型的输入和输出一个框2分类(最后使用的sigmod而不是softmax)和框的回归
model_rpn = Model(img_input, rpn[:2])
#定义classifier的输入和输出
model_classifier = Model([img_input, roi_input], classifier)
# this is a model that holds both the RPN and the classifier, used to load/save weights for the models
model_all = Model([img_input, roi_input], rpn[:2] + classifier)
try:
print('loading weights from {}'.format(C.base_net_weights))
model_rpn.load_weights(C.base_net_weights, by_name=True)
model_classifier.load_weights(C.base_net_weights, by_name=True)
except:
print('Could not load pretrained model weights. Weights can be found in the keras application folder \
https://github.com/fchollet/keras/tree/master/keras/applications')
optimizer = Adam(lr=1e-5)
optimizer_classifier = Adam(lr=1e-5)
model_rpn.compile(optimizer=optimizer, loss=[losses.rpn_loss_cls(num_anchors), losses.rpn_loss_regr(num_anchors)])
model_classifier.compile(optimizer=optimizer_classifier, loss=[losses.class_loss_cls, losses.class_loss_regr(len(classes_count)-1)], metrics={'dense_class_{}'.format(len(classes_count)): 'accuracy'})
model_all.compile(optimizer='sgd', loss='mae')
epoch_length = 1000
num_epochs = int(options.num_epochs)
iter_num = 0
losses = np.zeros((epoch_length, 5))
rpn_accuracy_rpn_monitor = []
rpn_accuracy_for_epoch = []
start_time = time.time()
best_loss = np.Inf
class_mapping_inv = {v: k for k, v in class_mapping.items()}
print('Starting training')
vis = True
for epoch_num in range(num_epochs):
progbar = generic_utils.Progbar(epoch_length)
print('Epoch {}/{}'.format(epoch_num + 1, num_epochs))
while True:
try:
if len(rpn_accuracy_rpn_monitor) == epoch_length and C.verbose:
mean_overlapping_bboxes = float(sum(rpn_accuracy_rpn_monitor))/len(rpn_accuracy_rpn_monitor)
rpn_accuracy_rpn_monitor = []
print('Average number of overlapping bounding boxes from RPN = {} for {} previous iterations'.format(mean_overlapping_bboxes, epoch_length))
if mean_overlapping_bboxes == 0:
print('RPN is not producing bounding boxes that overlap the ground truth boxes. Check RPN settings or keep training.')
print("生成data_gen_train")
#X为经过最小边600的比例变换的原始图像,Y为[所有框位置的和类别(正例还是反例),所有框的前36层为位置和后36层(框和gt的比值)],img_data增强图像后的图像信息
#那么RPN的reg输出也是比值
X, Y, img_data = next(data_gen_train)
print(X.shape,Y[0].shape,Y[1].shape)
loss_rpn = model_rpn.train_on_batch(X, Y)
print("loss_rpn",len(loss_rpn))
print("loss_rpn0",loss_rpn[0])
print("loss_rpn1",loss_rpn[1])
print("loss_rpn2",loss_rpn[2])
P_rpn = model_rpn.predict_on_batch(X)
# print("P_rpn_cls",P_rpn[0].reshape((P_rpn[0].shape[1],P_rpn[0].shape[2],P_rpn[0].shape[3]))[:,:,0])
print("P_rpn_cls",P_rpn[0].shape)
print("P_rpn_reg",P_rpn[1].shape)
#获得最终选中的框
R = roi_helpers.rpn_to_roi(P_rpn[0], P_rpn[1], C, K.image_dim_ordering(), use_regr=True, overlap_thresh=0.7, max_boxes=300)
# note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
#再对回归出来的框进行一次iou的计算,再一次过滤,只保留bg框和物体框
#X2 from (x1,y1,x2,y2) to (x,y,w,h)
#Y1为每个框对应类别标签,one-host编码
#Y2为每个框和gt的比值,(x,x,160),前80表示框是否正确,后80为20个类别可能的框
X2, Y1, Y2, IouS = roi_helpers.calc_iou(R, img_data, C, class_mapping)
print("X2",X2.shape)
# print("X2_0",X2[0,0,:])
# print("X2_1",X2[0,1,:])
print("Y1",Y1.shape)
print("Y2",Y2.shape)
if X2 is None:
rpn_accuracy_rpn_monitor.append(0)
rpn_accuracy_for_epoch.append(0)
continue
#选出正例还是反例的index,背景的为反例,物体为正例
neg_samples = np.where(Y1[0, :, -1] == 1)
pos_samples = np.where(Y1[0, :, -1] == 0)
print("neg_samples",len(neg_samples[0]))
print("pos_samples",len(pos_samples[0]))
if len(neg_samples) > 0:
neg_samples = neg_samples[0]
else:
neg_samples = []
if len(pos_samples) > 0:
pos_samples = pos_samples[0]
else:
pos_samples = []
rpn_accuracy_rpn_monitor.append(len(pos_samples))
rpn_accuracy_for_epoch.append((len(pos_samples)))
#num_rois=32,正例要求小于num_rois//2,其它全部由反例填充
if C.num_rois > 1:
if len(pos_samples) < C.num_rois//2:
selected_pos_samples = pos_samples.tolist()
print("selected_pos_samples",len(selected_pos_samples))
else:
selected_pos_samples = np.random.choice(pos_samples, C.num_rois//2, replace=False).tolist()
print("selected_pos_samples",len(selected_pos_samples))
try:
selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=False).tolist()
print("selected_neg_samples",len(selected_neg_samples))
except:
selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=True).tolist()
print("selected_neg_samples",len(selected_neg_samples))
sel_samples = selected_pos_samples + selected_neg_samples
else:
# in the extreme case where num_rois = 1, we pick a random pos or neg sample
selected_pos_samples = pos_samples.tolist()
selected_neg_samples = neg_samples.tolist()
if np.random.randint(0, 2):
sel_samples = random.choice(neg_samples)
else:
sel_samples = random.choice(pos_samples)
print("sel_samples",len(sel_samples))
print("sel_samples",sel_samples)
loss_class = model_classifier.train_on_batch([X, X2[:, sel_samples, :]], [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])
# P_classifier = model_classifier.predict([X, X2[:, sel_samples, :]])
# #[out_class, out_regr]
# print("P_classifier_out_class",P_classifier[0].shape)
# print("P_classifier_out_regr",P_classifier[1].shape)
# import cv2
# cv2.waitKey(0)
losses[iter_num, 0] = loss_rpn[1]
losses[iter_num, 1] = loss_rpn[2]
losses[iter_num, 2] = loss_class[1]
losses[iter_num, 3] = loss_class[2]
losses[iter_num, 4] = loss_class[3]
iter_num += 1
progbar.update(iter_num, [('rpn_cls', np.mean(losses[:iter_num, 0])), ('rpn_regr', np.mean(losses[:iter_num, 1])),
('detector_cls', np.mean(losses[:iter_num, 2])), ('detector_regr', np.mean(losses[:iter_num, 3]))])
if iter_num == epoch_length:
loss_rpn_cls = np.mean(losses[:, 0])
loss_rpn_regr = np.mean(losses[:, 1])
loss_class_cls = np.mean(losses[:, 2])
loss_class_regr = np.mean(losses[:, 3])
class_acc = np.mean(losses[:, 4])
mean_overlapping_bboxes = float(sum(rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
rpn_accuracy_for_epoch = []
if C.verbose:
print('Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'.format(mean_overlapping_bboxes))
print('Classifier accuracy for bounding boxes from RPN: {}'.format(class_acc))
print('Loss RPN classifier: {}'.format(loss_rpn_cls))
print('Loss RPN regression: {}'.format(loss_rpn_regr))
print('Loss Detector classifier: {}'.format(loss_class_cls))
print('Loss Detector regression: {}'.format(loss_class_regr))
print('Elapsed time: {}'.format(time.time() - start_time))
curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
iter_num = 0
start_time = time.time()
if curr_loss < best_loss:
if C.verbose:
print('Total loss decreased from {} to {}, saving weights'.format(best_loss,curr_loss))
best_loss = curr_loss
model_all.save_weights(C.model_path)
break
except Exception as e:
print('Exception: {}'.format(e))
continue
print('Training complete, exiting.')
faster RCNN(keras版本)代码讲解(3)-训练流程详情的更多相关文章
- 新人如何运行Faster RCNN的tensorflow代码
0.目的 刚刚学习faster rcnn目标检测算法,在尝试跑通github上面Xinlei Chen的tensorflow版本的faster rcnn代码时候遇到很多问题(我真是太菜),代码地址如下 ...
- (原)faster rcnn的tensorflow代码的理解
转载请注明出处: https://www.cnblogs.com/darkknightzh/p/10043864.html 参考网址: 论文:https://arxiv.org/abs/1506.01 ...
- Faster RCNN算法demo代码解析
一. Faster-RCNN代码解释 先看看代码结构: Data: This directory holds (after you download them): Caffe models pre-t ...
- Faster R-CNN利用新的网络结构来训练
前言 最近利用Faster R-CNN训练数据,使用ZF模型,效果无法有效提高.就想尝试对ZF的网络结构进行改造,记录下具体操作. 一.更改网络,训练初始化模型 这里为了方便,我们假设更换的网络名为L ...
- Windows10 Faster R-CNN(GPU版) 配置训练自己的模型
参考链接 1. 找到合适自己的版本,下载安装Anaconda 点击跳转下载安装 Anaconda,双击下载好的 .exe 文件安装,只勾选第一个把 conda 添加到 PATH 路径.
- faster rcnn相关内容
转自: https://zhuanlan.zhihu.com/p/31426458 faster rcnn的基本结构 Faster RCNN其实可以分为4个主要内容: Conv layers.作为一种 ...
- faster rcnn 详解
转自:https://zhuanlan.zhihu.com/p/31426458 经过R-CNN和Fast RCNN的积淀,Ross B. Girshick在2016年提出了新的Faster RCNN ...
- 实战 | 源码入门之Faster RCNN
前言 学习深度学习和计算机视觉,特别是目标检测方向的学习者,一定听说过Faster Rcnn:在目标检测领域,Faster Rcnn表现出了极强的生命力,被大量的学习者学习,研究和工程应用.网上有很多 ...
- Faster RCNN学习笔记
感谢知乎大神的分享 https://zhuanlan.zhihu.com/p/31426458 Ross B. Girshick在2016年提出了新的Faster RCNN,在结构上,Faster R ...
随机推荐
- 图像检索:CEDD(Color and Edge Directivity Descriptor)算法 颜色和边缘的方向性描述符
颜色和边缘的方向性描述符(Color and Edge Directivity Descriptor,CEDD) 本文节选自论文<Android手机上图像分类技术的研究>. CEDD具有抽 ...
- Easy_Re
这题比较简单,一波常规的操作之后直接上ida(小白的常规操作在以前的博客里都有所以这里不在赘述了),ida打开之后查看一下, 这里应该就是一个入口点了,接着搜索flag字符串, 上面的黄色的部分转换成 ...
- JS - 判断字符串某个下标的值
<html><body> <script type="text/javascript"> var str="0123456789!&q ...
- UML图表示类之间的关系
一.泛化(Generanization) 图: 泛化简单的说就是继承关系,在java中就是extend.表示一般与特殊的关系.如鸭子是鸟的一种,即有鸭子的特性也有鸟的共性.用带空心的三角箭头的实线指向 ...
- 《Interest Rate Risk Modeling》阅读笔记——第十章 主成分模型与 VaR 分析
目录 第十章:主成分模型与 VaR 分析 思维导图 一些想法 推导 PCD.PCC 和 KRD.KRC 的关系 PCD 和 KRD PCC 和 KRC 第十章:主成分模型与 VaR 分析 思维导图 一 ...
- Day7 - B - Super A^B mod C FZU - 1759
Given A,B,C, You should quickly calculate the result of A^B mod C. (1<=A,C<=1000000000,1<=B ...
- 玩个JAVA爬虫,没想玩大
想玩个爬虫,爬些数据玩玩,不成想把自己玩“进去”了 想爬这个新浪的股票 大额交易页面 本以为用 HttpClient 直接爬链接,结果发现这个页面中,翻页数据压根就是动态赋值的,根本没有,那我根本无法 ...
- PHP计划任务
server 2008:D:\SOFT_PHP_PACKAGE\php\php-cgi.exe -f D:\wwwroot\tlbuyuncom\wwwroot\Up_Data.phpPHP路径 -f ...
- css画布
绘制基本图形 绘制直线 <style> .canvas{ } </style> <canvas id="myCanvas1" style=" ...
- 五十、在SAP程序中应用其他单元,INCLUDE的用法
一.在SAP程序中写入以下代码 二.双击引用的单元,会弹出以下窗口 三.点击是 四.点击保存 五.保存在本地 六.此文件被包含进来 七.我们把在GET_DATA和SHOW_DATA写到INCLUDE里 ...