七 测试网络

模型测试包含于test.py文件,Detector类的image_detector()函数用于检测目标。

import os
import cv2
import argparse
import numpy as np
import tensorflow as tf
import yolo.config as cfg
from yolo.yolo_net import YOLONet
from utils.timer import Timer '''
用于测试
''' class Detector(object):

1、类初始化函数

 def __init__(self, net, weight_file):
'''
构造函数
利用 cfg 文件对网络参数进行初始化,
其中 offset 的作用应该是一个定长的偏移
boundery1和boundery2 作用是在输出中确定每种信息的长度(如类别,置信度等)。
其中 boundery1 指的是对于所有的 cell 的类别的预测的张量维度,所以是 self.cell_size * self.cell_size * self.num_class
boundery2 指的是在类别之后每个cell 所对应的 bounding boxes 的数量的总和,所以是self.boundary1 + self.cell_size * self.cell_size * self.boxes_per_cell args:
net:YOLONet对象
weight_file:检查点文件路径
'''
#yolo网络
self.net = net
#检查点文件路径
self.weights_file = weight_file
#输出文件夹路径
self.output_dir = os.path.dirname(self.weights_file)
#VOC 2012数据集类别名
self.classes = cfg.CLASSES
# #VOC 2012数据类别数
self.num_class = len(self.classes)
##图像大小
self.image_size = cfg.IMAGE_SIZE
#单元格大小S
self.cell_size = cfg.CELL_SIZE
#每个网格边界框的个数B=2
self.boxes_per_cell = cfg.BOXES_PER_CELL
#阈值参数
self.threshold = cfg.THRESHOLD
#IoU 阈值参数
self.iou_threshold = cfg.IOU_THRESHOLD
'''#将网络输出分离为类别和置信度以及边界框的大小,输出维度为7*7*20 + 7*7*2 + 7*7*2*4=1470'''
#7*7*20
self.boundary1 = self.cell_size * self.cell_size * self.num_class
#7*7*20+7*7*2
self.boundary2 = self.boundary1 +\
self.cell_size * self.cell_size * self.boxes_per_cell #运行图之前,初始化变量
self.sess = tf.Session()
self.sess.run(tf.global_variables_initializer()) #恢复模型
print('Restoring weights from: ' + self.weights_file)
self.saver = tf.train.Saver()
#直接载入最近保存的检查点文件
ckpt = tf.train.latest_checkpoint(self.output_dir)
print("ckpt:",ckpt)
#如果存在检查点文件 则恢复模型
if ckpt!=None:
#恢复最近的检查点文件
self.saver.restore(self.sess, ckpt)
else:
#从指定检查点文件恢复
self.saver.restore(self.sess, self.weights_file)

2、draw_result()函数

在原始图像上绘制边界框,并添加一些附件信息,如目标类别,置信度。

    def draw_result(self, img, result):
'''
在原图上绘制边界框,以及附加信息 args:
img:原始图片数据
result:yolo网络目标检测到的边界框,list类型 每一个元素对应一个目标框
包含{类别名,x_center,y_center,w,h,置信度}
'''
#遍历每一个边界框
for i in range(len(result)):
#x_center
x = int(result[i][1])
#y_center
y = int(result[i][2])
#w/2
w = int(result[i][3] / 2)
#h/2
h = int(result[i][4] / 2)
#绘制矩形框(目标边界框) 矩形左上角,矩形右下角
cv2.rectangle(img, (x - w, y - h), (x + w, y + h), (0, 255, 0), 2)
#绘制矩形框,用于存放类别名称,使用灰度填充
cv2.rectangle(img, (x - w, y - h - 20),
(x + w, y - h), (125, 125, 125), -1)
#线型
lineType = cv2.LINE_AA if cv2.__version__ > '3' else cv2.CV_AA
#绘制文本信息 写上类别名和置信度
cv2.putText(
img, result[i][0] + ' : %.2f' % result[i][5],
(x - w + 5, y - h - 7), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
(0, 0, 0), 1, lineType)

3、detect()函数

detect()函数用来对图像进行目标检测。

 def detect(self, img):
'''
图片目标检测 args:
img:原始图片数据 return:
result:返回检测到的边界框,list类型 每一个元素对应一个目标框
包含{类别名,x_center,y_center,w,h,置信度}
'''
#获取图片的高和宽
img_h, img_w, _ = img.shape
#图片缩放 [448,448,3]
inputs = cv2.resize(img, (self.image_size, self.image_size))
#BGR->RGB uint->float32
inputs = cv2.cvtColor(inputs, cv2.COLOR_BGR2RGB).astype(np.float32)
#归一化处理 [-1.0,1.0]
inputs = (inputs / 255.0) * 2.0 - 1.0
#reshape [1,448,448,3]
inputs = np.reshape(inputs, (1, self.image_size, self.image_size, 3)) #获取网络输出第一项(即第一张图片) [1,1470]
result = self.detect_from_cvmat(inputs)[0] #对检测的图片的边界框进行缩放处理,一张图片可以有多个边界框
for i in range(len(result)):
#x_center, y_center, w, h都是真实值,分别表示预测边界框的中心坐标,宽和高,都是浮点型
result[i][1] *= (1.0 * img_w / self.image_size) #x_center
result[i][2] *= (1.0 * img_h / self.image_size) #y_center
result[i][3] *= (1.0 * img_w / self.image_size) #w
result[i][4] *= (1.0 * img_h / self.image_size) #h #<class 'list'> 6 ['person', 405.83171163286482, 161.40340532575334, 166.17623397282193, 298.85661533900668, 0.69636690616607666]
#Average detecting time: 0.571s
print(type(result),len(result),result[0])
return result

4、detect_from_cvmat()函数

 def detect_from_cvmat(self, inputs):
'''
运行yolo网络,开始检测 args:
inputs:输入数据 [None,448,448,3] return:
results:返回目标检测的结果,每一个元素对应一个测试图片,每个元素包含着若干个边界框 '''
#返回网络最后一层,激活函数处理之前的值 形状[None,1470]
net_output = self.sess.run(self.net.logits,
feed_dict={self.net.images: inputs})
results = [] #对网络输出每一行数据进行处理
for i in range(net_output.shape[0]):
results.append(self.interpret_output(net_output[i])) #返回处理后的结果
return results

5、interpret_output()函数

该函数对yolo网络输出的结果进行处理,提取出有目标的边界框,方便后续的处理。

 def interpret_output(self, output):
'''
对yolo网络输出进行处理 args:
output:yolo网络输出的每一行数据 大小为[1470,]
0:7*7*20:表示预测类别
7*7*20:7*7*20 + 7*7*2:表示预测置信度,即预测的边界框与实际边界框之间的IOU
7*7*20 + 7*7*2:1470:预测边界框 目标中心是相对于当前格子的,宽度和高度的开根号是相对当前整张图像的(归一化的) return:
result:yolo网络目标检测到的边界框,list类型 每一个元素对应一个目标框
包含{类别名,x_center,y_center,w,h,置信度} 实际上这个置信度是yolo网络输出的置信度confidence和预测对应的类别概率的乘积
'''
#[7,7,2,20]
probs = np.zeros((self.cell_size, self.cell_size,
self.boxes_per_cell, self.num_class))
#类别概率 [7,7,20]
class_probs = np.reshape(
output[0:self.boundary1],
(self.cell_size, self.cell_size, self.num_class))
#置信度 [7,7,2]
scales = np.reshape(
output[self.boundary1:self.boundary2],
(self.cell_size, self.cell_size, self.boxes_per_cell))
#边界框 [7,7,2,4]
boxes = np.reshape(
output[self.boundary2:],
(self.cell_size, self.cell_size, self.boxes_per_cell, 4))
#[14,7] 每一行[0,1,2,3,4,5,6]
offset = np.array(
[np.arange(self.cell_size)] * self.cell_size * self.boxes_per_cell)
#[7,7,2] 每一行都是 [[0,0],[1,1],[2,2],[3,3],[4,4],[5,5],[6,6]]
offset = np.transpose(
np.reshape(
offset,
[self.boxes_per_cell, self.cell_size, self.cell_size]),
(1, 2, 0)) #目标中心是相对于整个图片的
boxes[:, :, :, 0] += offset
boxes[:, :, :, 1] += np.transpose(offset, (1, 0, 2))
boxes[:, :, :, :2] = 1.0 * boxes[:, :, :, 0:2] / self.cell_size
#宽度、高度相对整个图片的
boxes[:, :, :, 2:] = np.square(boxes[:, :, :, 2:]) #转换成实际的编辑框(没有归一化的)
boxes *= self.image_size #遍历每一个边界框的置信度
for i in range(self.boxes_per_cell):
#遍历每一个类别
for j in range(self.num_class):
#在测试时,乘以条件类概率和单个盒子的置信度预测,这些分数编码了j类出现在框i中的概率以及预测框拟合目标的程度。
probs[:, :, i, j] = np.multiply(
class_probs[:, :, j], scales[:, :, i]) #[7,7,2,20] 如果第i个边界框检测到类别j 则[;,;,i,j]=1
filter_mat_probs = np.array(probs >= self.threshold, dtype='bool')
#返回filter_mat_probs非0值的索引 返回4个List,每个list长度为n 即检测到的边界框的个数
filter_mat_boxes = np.nonzero(filter_mat_probs)
#获取检测到目标的边界框 [n,4] n表示边界框的个数
boxes_filtered = boxes[filter_mat_boxes[0],
filter_mat_boxes[1], filter_mat_boxes[2]]
#获取检测到目标的边界框的置信度 (n,)
probs_filtered = probs[filter_mat_probs]
#获取检测到目标的边界框对应的目标类别 (n,)
classes_num_filtered = np.argmax(
filter_mat_probs, axis=3)[
filter_mat_boxes[0], filter_mat_boxes[1], filter_mat_boxes[2]]
#按置信度倒序排序,返回对应的索引
argsort = np.array(np.argsort(probs_filtered))[::-1]
boxes_filtered = boxes_filtered[argsort]
probs_filtered = probs_filtered[argsort]
classes_num_filtered = classes_num_filtered[argsort] for i in range(len(boxes_filtered)):
if probs_filtered[i] == 0:
continue
for j in range(i + 1, len(boxes_filtered)):
#计算n各边界框,两两之间的IoU是否大于阈值,非极大值抑制
if self.iou(boxes_filtered[i], boxes_filtered[j]) :
probs_filtered[j] = 0.0 #非极大值抑制后的输出
filter_iou = np.array(probs_filtered > 0.0, dtype='bool')
boxes_filtered = boxes_filtered[filter_iou]
probs_filtered = probs_filtered[filter_iou]
classes_num_filtered = classes_num_filtered[filter_iou] result = []
#遍历每一个边界框
for i in range(len(boxes_filtered)):
result.append(
[self.classes[classes_num_filtered[i]], #类别名
boxes_filtered[i][0], #x中心
boxes_filtered[i][1], #y中心
boxes_filtered[i][2], #宽度
boxes_filtered[i][3], #高度
probs_filtered[i]]) #置信度 return result

6、iou()函数

计算两个边界框的IoU值。

    def iou(self, box1, box2):
'''
计算两个边界框的IoU args:
box1:边界框1 [4,] 真实值
box2:边界框2 [4,] 真实值
'''
tb = min(box1[0] + 0.5 * box1[2], box2[0] + 0.5 * box2[2]) - \
max(box1[0] - 0.5 * box1[2], box2[0] - 0.5 * box2[2])
lr = min(box1[1] + 0.5 * box1[3], box2[1] + 0.5 * box2[3]) - \
max(box1[1] - 0.5 * box1[3], box2[1] - 0.5 * box2[3])
inter = 0 if tb < 0 or lr < 0 else tb * lr
return inter / (box1[2] * box1[3] + box2[2] * box2[3] - inter)

7、camera_detector()函数

调用摄像头实现实时目标检测。

    def camera_detector(self, cap, wait=10):
'''
打开摄像头,实时检测 '''
#测试时间
detect_timer = Timer()
#读取一帧
ret, _ = cap.read() while ret:
#读取一帧
ret, frame = cap.read()
#测试其实时间
detect_timer.tic()
result = self.detect(frame)
#测试结束时间
detect_timer.toc()
print('Average detecting time: {:.3f}s'.format(
detect_timer.average_time))
#绘制边界框,以及添加附加信息
self.draw_result(frame, result)
#显示
cv2.imshow('Camera', frame)
cv2.waitKey(wait)

8、image_detector()函数

对图片进行目标检测。

    def image_detector(self, imname, wait=0):
'''
目标检测 args:
imname:测试图片路径
'''
#检测时间
detect_timer = Timer()
#读取图片
image = cv2.imread(imname)
#image = cv2.resize(image,(int(image.shape[1]/2),int(image.shape[0]/2)))
#检测的起始时间
detect_timer.tic()
#开始检测
result = self.detect(image)
#检测的结束时间
detect_timer.toc()
print('Average detecting time: {:.3f}s'.format(
detect_timer.average_time))
#绘制检测结果
self.draw_result(image, result)
cv2.imshow('Image', image)
cv2.waitKey(wait)

介绍完了Detector这个类,我们来看一下main函数。该函数比较检测,首先解析命令行参数,然后创建yolo网络,以及检测器对象,最后调用image_detector()函数对图片进行目标检测。

def main():
#创建一个解析器对象,并告诉它将会有些什么参数。当程序运行时,该解析器就可以用于处理命令行参数。
#https://www.cnblogs.com/lovemyspring/p/3214598.html
parser = argparse.ArgumentParser()
#定义参数
parser.add_argument('--weights', default="YOLO_small.ckpt", type=str)
parser.add_argument('--weight_dir', default='weights', type=str)
parser.add_argument('--data_dir', default="data", type=str)
parser.add_argument('--gpu', default='', type=str)
#定义了所有参数之后,你就可以给 parse_args() 传递一组参数字符串来解析命令行。默认情况下,参数是从 sys.argv[1:] 中获取
#parse_args() 的返回值是一个命名空间,包含传递给命令的参数。该对象将参数保存其属性
args = parser.parse_args() #设置环境变量
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu #创建YOLO网络对象
yolo = YOLONet(False)
#加载检查点文件
weight_file = os.path.join(args.data_dir, args.weight_dir, args.weights)
weight_file = './data/pascal_voc/weights/YOLO_small.ckpt'
#weight_file = './data/pascal_voc/output/2018_07_09_17_00/yolo.ckpt-1000' #创建测试对象
detector = Detector(yolo, weight_file) # detect from camera
# cap = cv2.VideoCapture(-1)
# detector.camera_detector(cap) # detect from image file
imname = 'test/car.jpg'
detector.image_detector(imname)

我们执行如下代码,开始测试网络:

if __name__ == '__main__':
tf.reset_default_graph()
main()

我们可以看到yolo网络对小目标检测效果并不好,漏检了一个目标。这主要与yolo的网络结构以及损失函数有关。除此之外yolo网络还有一些其他缺点,我们总结如下:

  • 漏检。每个网格只预测一个类别的边界框,而且最后只取置信度最大的那个边界框。这就导致如果多个不同物体(或者同类物体的不同实体)的中心落在同一个网格中,会造成漏检。yolo对相互靠的很近的物体,还有很小的群体检测效果不好,这是因为一个网格中只预测了两个框,并且只属于一类。
  • 位置精准性差。召回率低。由于损失函数的问题,定位误差是影响检测效果的主要原因。尤其是大小物体的处理上,还有待加强。
  • 对测试图像中,同一类物体出现的新的不常见的长宽比和其他情况是。泛化能力偏弱。

参考文章:

[1]argparse - 命令行选项与参数解析(转)

[2]Yolo v1详解及相关问题解答

yolo源码解析(三)的更多相关文章

  1. Celery 源码解析三: Task 对象的实现

    Task 的实现在 Celery 中你会发现有两处,一处位于 celery/app/task.py,这是第一个:第二个位于 celery/task/base.py 中,这是第二个.他们之间是有关系的, ...

  2. Mybatis源码解析(三) —— Mapper代理类的生成

    Mybatis源码解析(三) -- Mapper代理类的生成   在本系列第一篇文章已经讲述过在Mybatis-Spring项目中,是通过 MapperFactoryBean 的 getObject( ...

  3. 第三十六节,目标检测之yolo源码解析

    在一个月前,我就已经介绍了yolo目标检测的原理,后来也把tensorflow实现代码仔细看了一遍.但是由于这个暑假事情比较大,就一直搁浅了下来,趁今天有时间,就把源码解析一下.关于yolo目标检测的 ...

  4. ReactiveCocoa源码解析(三) Signal代码的基本实现

    上篇博客我们详细的聊了ReactiveSwift源码中的Bag容器,详情请参见<ReactiveSwift源码解析之Bag容器>.本篇博客我们就来聊一下信号量,也就是Signal的的几种状 ...

  5. ReactiveSwift源码解析(三) Signal代码的基本实现

    上篇博客我们详细的聊了ReactiveSwift源码中的Bag容器,详情请参见<ReactiveSwift源码解析之Bag容器>.本篇博客我们就来聊一下信号量,也就是Signal的的几种状 ...

  6. yolo源码解析(一)

    原文:https://www.cnblogs.com/zyly/p/9534063.html yolo源码来源于网址:https://github.com/hizhangp/yolo_tensorfl ...

  7. yolo源码解析(1):代码逻辑

    一. 整体代码逻辑 yolo中源码分为三个部分,\example,\include,以及\src文件夹下都有源代码存在. 结构如下所示 ├── examples │ ├── darknet.c(主程序 ...

  8. React的React.createRef()/forwardRef()源码解析(三)

    1.refs三种使用用法 1.字符串 1.1 dom节点上使用 获取真实的dom节点 //使用步骤: 1. <input ref="stringRef" /> 2. t ...

  9. yolo源码解析(3):视频检测流程

    代码在自己电脑中!!!!不在服务器 根据前文所说yolo代码逻辑: ├── examples │ ├── darknet.c(主程序) │ │── xxx1.c │ └── xxx2.c │ ├── ...

  10. Spring源码解析三:IOC容器的依赖注入

    一般情况下,依赖注入的过程是发生在用户第一次向容器索要Bean是触发的,而触发依赖注入的地方就是BeanFactory的getBean方法. 这里以DefaultListableBeanFactory ...

随机推荐

  1. 001.MySQL高可用主从复制简介

    一 简介 1.1 概述 Mysql内建的复制功能是构建大型,高性能应用程序的基础.将Mysql的数据分布在多个系统之上,这种分布的机制,是通过将Mysql的某一台主机的数据复制到其它主机(slaves ...

  2. 【python学习-4】可复用函数与模块

    1.自定义函数 自定义函数格式如下: def <函数名> (参数列表): <函数语句> return <返回值> #!/usr/bin/python # 定义函数, ...

  3. StringBuffer StringBuilder append

    StringBuilder is not thread safe. So, it performs better in situations where thread safety is not re ...

  4. PC端meta标签

    下面介绍meta标签的几个属性,charset,content,http-equiv,name. 一.charset 此特性声明当前文档所使用的字符编码,但该声明可以被任何一个元素的lang特性的值覆 ...

  5. python 修改文件中的内容

    在python的文件操作中,是没有办法对文件中具体某行或者某个位置的内容进行局部的修改的,如果需要对文件的某一行内容进行修改,可以先将文件中的所有的内容全部读取出来,再进行内容判断,是否是需要修改的内 ...

  6. 【BZOJ-3532】Lis 最小割 + 退流

    3532: [Sdoi2014]Lis Time Limit: 10 Sec  Memory Limit: 512 MBSubmit: 704  Solved: 264[Submit][Status] ...

  7. hdu 5734 Acperience 水题

    Acperience 题目连接: http://acm.hdu.edu.cn/showproblem.php?pid=5734 Description Deep neural networks (DN ...

  8. HDU 4423 Simple Function(数学题,2012长春D题)

    Simple Function Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others)T ...

  9. LPC-LINK 2 LPC4370 简化线路图

  10. LPC18xx/43xx OTP Controller driver

    LPC18xx/43xx OTP Controller driver /* * @brief LPC18xx/43xx OTP Controller driver * * @note * Copyri ...