RCNN算法的tensorflow实现
RCNN算法的tensorflow实现
转载自:https://blog.csdn.net/MyJournal/article/details/77841348?locationNum=9&fps=1
这个算法的思路大致如下:
1、训练人脸分类模型 输入:图像;输出:这张图像的特征
1-1、在Caltech256数据集上pre-trained,训练出一个较大的图片识别库;
1-2、利用之前人脸与非人脸的数据集对预训练模型进行fine tune,得到一个人脸分类模型。
2、训练SVM模型(重新定义正负样本)输入:图像的特征 输出:图像类别
3、将图片分为多个矩形选框,用SVM模型对这些选框区域进行分类,即判定该区域中是否包含人脸
4、使用回归器精细修正候选框位置
下面将进行具体的解释。
1、训练人脸分类模型
以初学者的思维(在基本掌握了MNIST手写数字识别后),我们通常是设置一个神经网络(通常是借鉴在图片分类中较好的模型的网络层次结构,例如Alexnet、VGG16等,但据说VGG16的计算量较大,这个我也没有试过)直接开始训练即可。但这时需要考虑到一个问题:我们为模型选择的数据集的规模如何?
如果我的神经网络结构是七层,前四层是卷积池化层,后三层是连接层,对于这样较复杂的网络使用多少的数据量合适呢?几千张?几万张?可能都有些少了。当图片较少时,模型很容易欠拟合,因此需要借用别人用大数据量作为数据集已经训练好的模型。但需要注意的是,一旦借用别人的模型,之后fine-tuning定义的模型结构需要与之相同,除了最终的图片分类数目不同以外。
以下是我定义的神经网络:
def inference(input_tensor, train, regularizer,num):
with tf.name_scope('layer1-conv1'):
conv1_weights = tf.get_variable("weight1",[5,5,3,32],initializer=tf.truncated_normal_initializer(stddev=0.1))
conv1_biases = tf.get_variable("bias1", [32], initializer=tf.constant_initializer(0.0))
conv1 = tf.nn.conv2d(input_tensor, conv1_weights, strides=[1, 1, 1, 1], padding='SAME')
relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_biases))
with tf.name_scope("layer2-pool1"):
pool1 = tf.nn.max_pool(relu1, ksize = [1,2,2,1],strides=[1,2,2,1],padding="VALID")
with tf.variable_scope("layer3-conv2"):
conv2_weights = tf.get_variable("weight2",[5,5,32,64],initializer=tf.truncated_normal_initializer(stddev=0.1))
conv2_biases = tf.get_variable("bias2", [64], initializer=tf.constant_initializer(0.0))
conv2 = tf.nn.conv2d(pool1, conv2_weights, strides=[1, 1, 1, 1], padding='SAME')
relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_biases))
with tf.name_scope("layer4-pool2"):
pool2 = tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
with tf.variable_scope("layer5-conv3"):
conv3_weights = tf.get_variable("weight3",[3,3,64,128],initializer=tf.truncated_normal_initializer(stddev=0.1))
conv3_biases = tf.get_variable("bias3", [128], initializer=tf.constant_initializer(0.0))
conv3 = tf.nn.conv2d(pool2, conv3_weights, strides=[1, 1, 1, 1], padding='SAME')
relu3 = tf.nn.relu(tf.nn.bias_add(conv3, conv3_biases))
with tf.name_scope("layer6-pool3"):
pool3 = tf.nn.max_pool(relu3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
with tf.variable_scope("layer7-conv4"):
conv4_weights = tf.get_variable("weight4",[3,3,128,128],initializer=tf.truncated_normal_initializer(stddev=0.1))
conv4_biases = tf.get_variable("bias4", [128], initializer=tf.constant_initializer(0.0))
conv4 = tf.nn.conv2d(pool3, conv4_weights, strides=[1, 1, 1, 1], padding='SAME')
relu4 = tf.nn.relu(tf.nn.bias_add(conv4, conv4_biases))
with tf.name_scope("layer8-pool4"):
pool4 = tf.nn.max_pool(relu4, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
nodes = 6*6*128
reshaped = tf.reshape(pool4,[-1,nodes])
with tf.variable_scope('layer9-fc1'):
fc1_weights = tf.get_variable("weight5", [nodes, 1024],initializer=tf.truncated_normal_initializer(stddev=0.1))
if regularizer != None:
tf.add_to_collection('losses1', regularizer(fc1_weights))
fc1_biases = tf.get_variable("bias5", [1024], initializer=tf.constant_initializer(0.1))
fc1 = tf.nn.relu(tf.matmul(reshaped, fc1_weights) + fc1_biases)
if train:
fc1 = tf.nn.dropout(fc1, 0.5)
with tf.variable_scope('layer10-fc2'):
fc2_weights = tf.get_variable("weight6", [1024, 512],initializer=tf.truncated_normal_initializer(stddev=0.1))
if regularizer != None:
tf.add_to_collection('losses2', regularizer(fc2_weights))
fc2_biases = tf.get_variable("bias6", [512], initializer=tf.constant_initializer(0.1))
fc2 = tf.nn.relu(tf.matmul(fc1, fc2_weights) + fc2_biases)
if train:
fc2 = tf.nn.dropout(fc2, 0.5)
with tf.variable_scope('layer11-fc3'):
fc3_weights = tf.get_variable("weight7", [512, num],initializer=tf.truncated_normal_initializer(stddev=0.1))
if regularizer != None:
tf.add_to_collection('losses3', regularizer(fc3_weights))
fc3_biases = tf.get_variable("bias7", [num], initializer=tf.constant_initializer(0.1))
logit = tf.matmul(fc2, fc3_weights) + fc3_biases
return logit #fc3
1-1、pre-trained——Caltech256数据集(有256类的图片,包括静物、动物、人物等)
最终分的类别是256类,将上述网络中num设置为256即可。
将训练好的模型保存到model.ckpt中,之后fine-tuning需要将预训练模型重新加载。
checkpoint_file = os.path.join(log_dir, 'model.ckpt')
saver.save(sess,checkpoint_file)
1-2、fine tuning
在这里我先介绍论文中的做法。(由于我的电脑运行速度太慢了,我就没有这么做,只是找了人脸及非人脸的数据集拉进去fine tuning,但是效果不是很好…)
如果做的目标定位系统是定位男人、女人、猫、狗这四类目标,那我们将fine tuning的神经网络中的最后一层num设置为5(4+1),加的这一类代表背景。那么背景如何获得呢? 首先,需要我们提前对图片数据提前标定目标位置,对于每张图可能获得一个或更多的标定矩形框(x,y,w,h分别表示横坐标的最小值,纵坐标的最小值、矩形框宽度、矩形框长度)。其次,我们通过Python selectivesearch库中的selectivesearch指令获得多个目标框(Proposals)(selectivesearch指令根据图片的颜色变化、纹理等将多个像素合并为多个选框)。接着,我们通过定义并计算出的IoU(目标框与标定框的重合程度,即IoU=重合面积/两个矩形所占的面积(其中一个矩形是标定框,另一个矩形是目标框))与阈值比较,若大于这个阈值则表示该目标框标出的是男人、女人、猫或狗四类中的一类,若小于这个阈值则表示该标定框标出的是背景。论文中选取的阈值threshold=0.5。最后,加载pre-trained模型后,训练这些图片,在预训练模型的基础上对各个参数进行微调。
IOU的定义如下:
def if_intersection(xmin_a, xmax_a, ymin_a, ymax_a, xmin_b, xmax_b, ymin_b, ymax
_b):
if_intersect = False
# 通过四条if来查看两个方框是否有交集。如果四种状况都不存在,我们视为无交集
if xmin_a < xmax_b <= xmax_a and (ymin_a < ymax_b <= ymax_a or ymin_a <= ymin_b < ymax_a):
if_intersect = True
elif xmin_a <= xmin_b < xmax_a and (ymin_a < ymax_b <= ymax_a or ymin_a <= ymin_b < ymax_a):
if_intersect = True
elif xmin_b < xmax_a <= xmax_b and (ymin_b < ymax_a <= ymax_b or ymin_b <= ymin_a < ymax_b):
if_intersect = True
elif xmin_b <= xmin_a < xmax_b and (ymin_b < ymax_a <= ymax_b or ymin_b <= ymin_a < ymax_b):
if_intersect = True else:
return False
# 在有交集的情况下,我们通过大小关系整理两个方框各自的四个顶点, 通过它们得到交集面积
if if_intersect == True:
x_sorted_list = sorted([xmin_a, xmax_a, xmin_b, xmax_b])#from small to big number
y_sorted_list = sorted([ymin_a, ymax_a, ymin_b, ymax_b])
x_intersect_w = x_sorted_list[2] - x_sorted_list[1]
y_intersect_h = y_sorted_list[2] - y_sorted_list[1]
area_inter = x_intersect_w * y_intersect_h
return area_inter def IOU(ver1, ver2):
vertice1 = [ver1[0], ver1[1], ver1[0]+ver1[2], ver1[1]+ver1[3]]
vertice2 = [ver2[0], ver2[1], ver2[0]+ver2[2], ver2[1]+ver2[3]]
area_inter = if_intersection(vertice1[0], vertice1[2], vertice1[1], vertice1[3], vertice2[0], vertice2[2], vertice2[1], vertice2[3])
# 如果有交集,计算IOU
if area_inter:
area_1 = ver1[2] * ver1[3]
area_2 = ver2[2] * ver2[3]
iou = float(area_inter) / (area_1 + area_2 - area_inter)
return iou
iou = 0
return iou
加载pre-trained模型并进行fine-tune训练:
def load_with_skip(data_path, session, skip_layer):
reader = pywrap_tensorflow.NewCheckpointReader(ckpt.model_checkpoint_path)
data_dict = reader.get_variable_to_shape_map()
for key in data_dict:
print("tensor_name: ", key)
if key not in skip_layer:
print ( data_dict[key])
print (reader.get_tensor(key))
session.run([key]) saver = tf.train.Saver()
with tf.Session() as sess:
restore = False
sess.run(tf.global_variables_initializer())
ckpt1 = tf.train.get_checkpoint_state(aim_dir)
if ckpt1 and ckpt1.model_checkpoint_path:
restore = True
saver.restore(sess,ckpt1.model_checkpoint_path)
print ('fine-tuning model has already exist!')
print("Continue training") else:
ckpt = tf.train.get_checkpoint_state(log_dir)
if ckpt and ckpt.model_checkpoint_path:
restore = True
print ('original model has already exist!')
print("Continue training")
load_with_skip(ckpt.model_checkpoint_path, sess, ['layer11-fc3','layer11-fc2','layer11-fc1'])
2、训练SVM模型,论文中是这么说的:
(1)SVM分类与CNN分类的数据集区别:
‘for finetuning we map each object proposal to the ground-truth instance with which it has maximum IoU overlap (if any) and label it as a positive for the matched ground-truth class if the IoU is at least 0.5. All other proposals are labeled “background” (i.e., negative examples for all classes). For training SVMs, in contrast, we take only the ground-truth boxes as positive examples for their respective classes and label proposals with less than 0.3 IoU overlap with all instances of a class as a negative for that class. Proposals that fall into the grey zone (more than 0.3 IoU overlap, but are not ground truth) are ignored.’
Fine tuning 阶段我们将IoU大于0.5的目标框圈定的图片作为正样本,小于0.5的目标框圈定的图片作为负样本。而在对每一类目标分类的SVM训练阶段,我们将标定框圈定的图片作为正样本,IoU小于0.3的目标框圈定的图片作为负样本,其余目标框舍弃。
(2)对每一类目标选择SVM模型
‘Once features are extracted and training labels are applied, we optimize one linear SVM per class.’
对SVM(支持向量机)简单的理解就是:寻找一个(超)平面将一个事物与其对立面尽可能划分开来。(二分类问题)
我们将正样本作为输入送入fine-tune模型中,输出是某一连接层得到的特征值,将这个输出与其标签(上面标定过的正负样本)作为SVM的样本进行训练,得到SVM模型。
(3)为什么选择SVM?
‘In Appendix B we discuss why the positive and negative examples are defined differently in fine-tuning versus SVM training. We also discuss the trade-offs involved in training detection SVMs rather than simply using the outputs from the final softmax layer of the fine-tuned CNN.’
论文的附录中提到了为什么不直接选择CNN模型及softmax对目标分类,而是选择SVM来分类。
def load_from_pkl(dataset_file):
X, Y = pickle.load(open(dataset_file, 'rb'))
return X,Y def load_train_proposals(datafile, num_clss, threshold = 0.5, svm = False, save=False, save_path='dataset.pkl'):
train_list = open(datafile,'r')
labels = []
images = []
n = 0
for line in train_list:
n = n+1
print ('n: '+str(n))
tmp = line.strip().split(' ')
# tmp0 = image address
# tmp1 = label
# tmp2 = rectangle vertices
img = skimage.io.imread(tmp[0])
ref_rect = tmp[2].split(',')
ref_rect_int = [int(i) for i in ref_rect]
print (ref_rect)
# im_orig:输入图片 scale:表示felzenszwalb分割时,值越大,表示保留的下来的集合就越大
# sigma:表示felzenszwalb分割时,用的高斯核宽度 min_size:表示分割后最小组尺寸
img_lbl, regions = selectivesearch.selective_search(img, scale=200, sigma=0.3, min_size=25)
candidates = set()
for r in regions:
# excluding same rectangle (with different segments)
if r['rect'] in candidates:# 剔除重复的方框
continue
if r['size'] < 220:# 剔除太小的方框
continue
if r['size'] > 4000:
continue
proposal_img, proposal_vertice = clip_pic(img, r['rect']) if len(proposal_img) == 0:# Delete Empty array
continue
x, y, w, h = r['rect']
if w == 0 or h == 0: # 长或宽为0的方框,剔除
continue
if h/w <= 0.7 or h/w>=1.3:
continue
# Check if any 0-dimension exist image array的dim里有0的,剔除
[a, b, c] = np.shape(proposal_img)
if a == 0 or b == 0 or c == 0:
continue im = Image.fromarray(proposal_img)
resized_proposal_img = resize_image(im, 100, 100,resize_mode=3) # 重整方框的大小
candidates.add(r['rect']) img_float = pil_to_nparray(resized_proposal_img)
images.append(img_float)
# 计算IOU
iou_val = IOU(ref_rect_int, proposal_vertice) # labels, let 0 represent default class, which is background
index = int(tmp[1])
if svm == False:
label = np.zeros(num_clss+1)
if iou_val < threshold:
labels.append(0)
else:
labels.append(index)
labels.append(label) else:
if iou_val < threshold:
labels.append(0)
else:
labels.append(index)
print (r['rect'])
print ('iou_val: '+str(iou_val))
print ('labels append!') if save:
pickle.dump((images, labels), open(save_path, 'wb'))
return images, labels def generate_single_svm_train(one_class_train_file):#获取SVM训练样本
trainfile = one_class_train_file
savepath = one_class_train_file.replace('txt', 'pkl')
print (savepath)
images = []
Y = []
if os.path.isfile(savepath):
print("restoring svm dataset " + savepath)
images, Y = load_from_pkl(savepath)
else:
print("loading svm dataset " + savepath)
images, Y = load_train_proposals(trainfile, 3, threshold=0.3, svm=True, save=True, save_path=savepath)
return images, Y def train_svms(train_file_folder, model):
listings = os.listdir(train_file_folder)
print (listings)
svms = []
for train_file in listings:
if "pkl" in train_file:
continue
X, Y = generate_single_svm_train(train_file_folder+train_file)
print (np.shape(X))
print ('success!')
train_features = [] for i in range(0,len(Y)):
imgsvm = X[i]
labelsvm = Y[i]
print ('svm LABEL:'+str(labelsvm))
feats,prelabel = Restore_show(imgsvm)
train_features.append(feats[0])
print("feature dimension") clf = svm.LinearSVC()
print("fit svm")
clf.fit(train_features,Y)
print (clf)
print(clf.score(train_features, Y)) # 打印拟合优度 joblib.dump(clf,os.getcwd()+'/svm/filename.pkl')#保存SVM模型
svms.append(clf)
print (svms)
return svms
3、将图片用selectivesearch指令分为多个矩形选框,用SVM模型对这些选框区域进行分类,即判定该区域中是否包含人脸,并将标签为1(即包含人脸的图片)记录下来:
imgs, verts = image_proposal(img_path)#image_proposal函数类似于之前的load_train_proposals函数,用于将选框筛选出来 with tf.Session() as sess:
features = []
box_images = []
print("predict image:")
results = []
results_label = []
results_ratio = []
count = 0
number = 0
temp = []
for f in imgs:
feats ,prelabel ,ratio= Restore_show(f)#Restore_show函数是将图片送入CNN分类模型预测,输出分别是特征、预测标签、是人脸的概率
clf=joblib.load(os.getcwd()+'/svm/filename.pkl')#载入SVM模型
pred = clf.predict(feats[0])#用模型进行预测,feats[0]是图片的特征
print(pred)
if pred[0] != 0:
results.append(verts[count])
results_label.append(pred[0])
results_ratio.append(ratio)
temp.append ((ratio,verts[count][0],verts[count][1],verts[count][2],verts[count][3]))
number += 1
count += 1
4、使用回归器精细修正候选框位置 (box regression)
至于这一部分论文中及许多博客上都仔细讲过,主要计算公式我就不再赘述。大致的原理就是标定框与目标框之间存在一定误差,我们需要寻找一种关系重新对目标框设置中心点及大小。为了保持这个关系为线性关系,我们在使用ridge regression时选择的目标框应是与标定框之间的IoU在0.6以上的值(论文中选择的值,我选取的是0.7,感觉效果也可以。)
4-1、ridge regression训练的输入是:图片标定框的特征值,标定框的中心点坐标、长、宽(x,y,w,h),目标框的中心点坐标、长、宽(x,y,w,h)
4-2、预测:
feature, classnum = Output_show(img_path,0,0,size[0],size[1])
#Output_show函数类似于Restore_show函数,将图片送入CNN分类模型预测,输出分别是特征、预测标签
clf=joblib.load(os.getcwd()+'/boxregression/filenamex.pkl')#载入ridge regression模型
predx = clf.predict(feature)
clf=joblib.load(os.getcwd()+'/boxregression/filenamey.pkl')
predy = clf.predict(feature)
clf=joblib.load(os.getcwd()+'/boxregression/filenamew.pkl')
predw = clf.predict(feature)
clf=joblib.load(os.getcwd()+'/boxregression/filenameh.pkl')
predh = clf.predict(feature) for i in range(number-1,-1,-1):
if i not in flag_not:
print (temp[i][1],temp[i][2],temp[i][3],temp[i][4])
x = float(temp[i][1])
y = float(temp[i][2])
w = float(temp[i][3])
h = float(temp[i][4])
x1 = max(w*predx+x,0)
y1 = max(h*predy+y,0)
w1 = w*math.exp(predw)
h1 = h*math.exp(predh)
print (str(x1)+' '+str(y1)+' '+str(w1)+' '+str(h1)) rect = mpatches.Rectangle(
(x1, y1), w1, h1, fill=False, edgecolor='red', linewidth=2)
ax.add_patch(rect)#画出边框回归后的矩形 rect1 = mpatches.Rectangle(
(x, y), w, h, fill=False, edgecolor='white', linewidth=2)
ax.add_patch(rect1)#画出为边框回归的矩形 out_ratio = str(temp[i][1])
plt.text(x1+15, y1+15, str(temp[i][0]),color='red') #在矩形框上写出预测概率
1、http://blog.csdn.net/bixiwen_liu/article/details/53840913
2、http://blog.csdn.net/ture_dream/article/details/52896452
3、http://blog.csdn.net/daunxx/article/details/51578787
4、https://github.com/rbgirshick/rcnn
5、http://www.cnblogs.com/edwardbi/p/5647522.html
RCNN算法的tensorflow实现的更多相关文章
- 目标检测算法(1)目标检测中的问题描述和R-CNN算法
目标检测(object detection)是计算机视觉中非常具有挑战性的一项工作,一方面它是其他很多后续视觉任务的基础,另一方面目标检测不仅需要预测区域,还要进行分类,因此问题更加复杂.最近的5年使 ...
- 第三十一节,目标检测算法之 Faster R-CNN算法详解
Ren, Shaoqing, et al. “Faster R-CNN: Towards real-time object detection with region proposal network ...
- 第三十节,目标检测算法之Fast R-CNN算法详解
Girshick, Ross. “Fast r-cnn.” Proceedings of the IEEE International Conference on Computer Vision. 2 ...
- 第二十九节,目标检测算法之R-CNN算法详解
Girshick, Ross, et al. “Rich feature hierarchies for accurate object detection and semantic segmenta ...
- 【目标检测】Faster RCNN算法详解
Ren, Shaoqing, et al. “Faster R-CNN: Towards real-time object detection with region proposal network ...
- 【目标检测】RCNN算法详解
网址: 1. https://blog.csdn.net/zijin0802034/article/details/77685438 (box regression 边框回归) 2. https:// ...
- R-CNN算法概要
参考论文:Rich feature hierarchies for accurate object detection and semantic segmentation 下载地址:https://a ...
- 目标检测算法之Faster R-CNN算法详解
Fast R-CNN存在的问题:选择性搜索,非常耗时. 解决:加入一个提取边缘的神经网络,将候选框的选取交给神经网络. 在Fast R-CNN中引入Region Proposal Network(RP ...
- 目标检测算法之R-CNN算法详解
R-CNN全称为Region-CNN,它可以说是第一个成功地将深度学习应用到目标检测上的算法.后面提到的Fast R-CNN.Faster R-CNN全部都是建立在R-CNN的基础上的. 传统目标检测 ...
随机推荐
- mybatis基础(下)
mybatis和spring整合 需要spring通过单例方式管理SqlSessionFactory spring和mybatis整合生成代理对象,使用SqlSessionFactory创建SqlSe ...
- SAP MM ME1M报表结果不科学?
SAP MM ME1M报表结果不科学? 做过SAP MM顾问的都知道,报表ME1M可以查询物料的info record列表,即是说可以以列表的形式批量显示多个物料的采购价格主数据. 但是这个报表有个不 ...
- springmvc实现视频上传+进度条
前台表单: <form id="uploadform" method="post" enctype="multipart/form-data&q ...
- 彻底删除mysql服务(清理注册表)
由于安装某个项目的执行文件,提示要卸载MySQL以便它自身MySQL安装,然后我禁用了MYSQL服务,再把这个文件夹删除后,发现还是提示请卸载MYSQL服务. 解决步骤: 1.以管理员身份运行命令提示 ...
- centos7网络配置方法
方法一:nmtui 这个是字符界面的图形化网络配置工具 方法二:nmcli 命令行配置 方法三:直接vim /etc/sysconfig/network-scripts/ens---- 编辑 ...
- Chinese word segment based on character representation learning 论文笔记
论文名和编号 摘要/引言 相关背景和工作 论文方法/模型 实验(数据集)及 分析(一些具体数据) 未来工作/不足 是否有源码 问题 原因 解决思路 优势 基于表示学习的中文分词 编号:1001-908 ...
- 1.2 NCE22 By heart
Some plays are so successful that they run/are performed/ for years on end/successively/in a row/con ...
- Linux Collection:源和更新
PAS 配置sources.list软件源 参考例子(Debian 9,文件/etc/apt/sources.list): deb https://mirrors.ustc.edu.cn/debian ...
- yum源 Python3 Django mysql安装
yum 源安装 yum源位置: yum源仓库的地址 在/etc/yum.repos.d/,并且只能读出第一层的repo文件 yum仓库的文件都是以.repo结尾的 linux软件包管理 yum工具如同 ...
- 迭代与JDB
1.题目要求 2.程序设计 首先,命令行输入,还是考虑先将输入的数据转化为整型变量 然后,看到C(n,m)=C(n-1,m-1)+C(n-1,m)公式以及"迭代"这两个字,首先想到 ...