SSD算法的实现
本文目的:介绍一个超赞的项目——用Keras来实现SSD算法。
本文目录:
- 0 前言
- 1 如何训练SSD模型
- 2 如何评估SSD模型
- 3 如何微调SSD模型
- 4 其他注意点
0 前言
我在学习完SSD算法之后,对具体细节有很多的疑惑,记录如下:
- SSD的网络是怎么实现的?
- 已有的数据是什么样子的?
- 如何把一张图像打散成anchors?
- 如何根据标注把各anchors打上标签?
- 正负样本是如何定义的?匹配策略是咋回事?
- 困难负样本挖掘是怎么实现的?
- 数据是怎么喂进去的?出来的又是什么?
- L2 Normalization在哪,如何实现的?
- Atrous层在哪?
- SSD的损失函数是怎么实现的?
- 数据在模型中是怎么流动的?
- 数据增强是怎么实现的?
- 预测的结果如何在原图上画框?
- 如何计算模型在Pascal VOC或MS COCO上的mAP?
在github上搜索时,发现了这个超赞的项目——用Keras来实现SSD算法,非常适合那些学习了SSD算法,但具体细节有些模糊的同学(主要是我自己)。文档注释非常详细,且提供非常清晰的操作指导,比如如何训练SSD模型,如何评估模型的性能,如何在自己的数据集上微调预训练的模型等等。
为了便于快速理解,现以其中一个简单版本的SSD网络为例(ssd7_training.ipynb文件)来记录总结,具体细节可参考项目文档注释。
1 如何训练SSD模型
主函数流程中,除了库函数的引入,以及通用参数的预定义之外,主要分为四块内容:
- 准备模型;
- 构建模型;
- 自定义损失函数,并编译;
- 准备数据;
- 定义训练集和验证集的图像生成器对象datagen;
- 利用图像生成器的函数读取文件图像和标签信息;
- 定义图像增强方法链;
- 利用编码器将标签信息编码成损失函数需要的格式;
- 定义数据对的迭代器generator;
- 训练;
- 定义回调函数;
- 训练;
- 训练结果可视化;
- 预测(可视化检测效果);
- 定义数据迭代器,并获取一个batch的样本;
- 将样本送入模型进行预测,并解码得到预测框;
- 将预测框和真值框画在原图上,对比效果。
1.1 准备模型
1.1.1 搭建模型
以training模式搭建一个小型的SSD模型。从4个地方引出predictor。各个预测特征图上每一个像素点均对应4个锚框。
模型搭建流程:
- 搭建base network;
- 从四个特征图处分别引出predictor;
- 每个predictor分为三条路径(按照第一篇参考文章的图示);
- 分类那路需要softmax,最后三路在最后一维度concatenate, 得到(batches, n_total_boxes, 21+4+8),即为模型原始输出raw output,其中n_boxes_total表示四个做预测的特征图对应的锚框总数,最后一个维度是21个类别+gt框偏移量+锚框坐标+variance;
- 若模式为inference,还需要在最后一层接上解码器DecodeDetections(输出经过置信度阈值、NMS等筛选后的预测框)
备注:
- 用AnchorBoxes层生成锚框,但为啥接在boxes4后面,而不是conv4后面,两者一样吗?答:一样的,因为只用了中间两个维度的数值,即特征图的高宽。但根据函数描述,应该接在conv4后面,即输入为(batch, height, width, channels)
AnchorBoxes层
- 目的是为了根据输入的特征图,将原图打散成一系列的锚框。
- 过程:根据参数缩放因子和高宽比,可以计算出特征图一个像素点对应的锚框数量和尺寸,再有特征图的高和宽,即可求得锚框的中心。
- 输入(batch, height, width, channels),即特征图的尺寸。
- 输出(batch, height, width, n_boxes, 8),这里n_boxes表示一个特征图对应的锚框总数,8表示锚框信息,即坐标+variance;
DecodeDetections层
- 在建模中mode=inference时,接在predictor后面的解码器;
- 过程:根据置信度阈值、NMS、最大输出数量等参数,对每张图筛选出前top_K个预测框;
- 输入即为模型的原始输出(batch,n_boxes_total,n_classes+4+8),最后一维是 类别(21)+框偏移量+锚框和variance(centroids格式);
- 输出(batch,top_k,6),最后一维是 (class_id, confidences, box_coordinates),坐标格式是xmin, ymin, xmax, ymax。这里top_K=200,即便合理的预测框不够,也会凑出200个。
备注:
- 输入参数说明里要求,只支持坐标输出为coords='centroids',这里coords='centroids'指的是输入的格式,实际输出格式是[xmin, ymin, xmax, ymax]。
1.1.2 自定义损失函数,并编译
(在SSD300模型中,需要先加载预训练的VGG16权重。)
自定义损失函数keras_ssd_loss.py
- 定义了一个损失类SSDLoss,里面有各种具体的损失函数,比如smooth L1和log损失;
- smooth L1损失:两个参数都是(batch_size, #boxes, 4),输出(batch_size, #boxes)。疑惑:这是直接求smooth L1,直接用坐标值求损失?照理说应该是求偏移量的损失啊?还是说输入的本来就应该是偏移量而非直接坐标值?答:在compute_loss函数中调用时传入的就是偏移量,所以OK。log损失很简单;
- compute_loss函数计算总损失,参数y_true和y_pred都是(batch_size, #boxes, #classes+12),输出scalar。疑问1,总损失为啥除以正样本个数而非总个数?答:正负样本比例为1:3,只是差个倍数,对结果不影响。疑问2,返回结果仍然是(batch,),并非标量?那乘以batch_size还有意义吗?答:keras强制以batch的方式计算各个值,即始终保证batch维度,实际运算的时候会给出(batch,...)对batch的平均值,因为compute_loss计算的是一个batch总的损失,所以keras强制平均后再乘以batch_size即为总和。
备注:
- 自定义的损失函数,传给compile的是对象的一个函数,这个函数返回的是根据y_true和y_predict计算的损失;
- 这个y_pred格式是(batch_size, #boxes, #classes + 12),即模型的raw output;y_true是后续SSDInputEncoder类实例将真值框编码后的输出;
- 如果是加载保存的模型,注意通过load_model中custom_objects传入自定义的层和函数。
1.2 准备数据
这一块内容大体上是通过自定义的图像生成器DataGenerator类及其方法来实现的。其中DataGenerator里面的函数generate()需要接收图像增强链和真值框编码器等参数,所以需要另外自定义两个类。
DataGenerator类
- DataGenerator实例化时自动调用__ init __ (),在这里面可以先进行图像增强处理(keras就是这么处理的,此处是在generate函数中做变形处理);
- DataGenerator里面的函数parser_csv()从文件中读入数据和标签(即真值框),读进来的真值框格式是一个长度为样本个数的list,其中每个元素为一个2D array,即array[[class_id, xmin, ymin, xmax, ymax]...],shape为(n_boxes,5),其中n_boxes为该样本的真值框个数;
- DataGenerator里面的函数generate()接收图像增强链和真值框编码器等参数,作用是产生一批批的数据对(X,y);(注意:keras内置的flow_from_directory实现了读取文件数据和生成(X,y)两个功能,但由于此处需要解析的文件除了CSV,可能还有其他形式,所以分成了两个函数);
- 关于加速的方法之一:第一次先用parser_csv读取图像和标签,然后利用create_hdf5_dataset()函数,将图像和标签转成h5文件(训练集近8G,验证集近2G,均已包含真值框,但未编码)。以后创建DataGenerator时就可以直接读取h5文件,然后用generate函数生成数据对,不再需要用parser_csv。但是经过测试,训练时用不用h5文件貌似没有区别,训练一个epoch的时间均为12分钟。
定义数据增强链
- DataAugmentationConstantInputSize类中,图像变形,真值框也要变形?否则就对不上了。如何变?答:在形变模块data_generator.object_detection_2d_geometric_ops中的方法,将labels一同放入进行了处理。
- Python中,如果在创建class的时候写了__ call __ ()方法,那么该class实例化出实例后,实例名()就是调用__ call __ ()方法;在keras中自定义层时用call(),而不是__ call __ ();
- DataAugmentationConstantInputSize中__ init __ ()集成了一系列变形对象置于sequence中,并在__ call __ ()函数中调用。
用SSDInputEncoder类实例将真值框编码成损失函数需要的格式y_true(这里用y_encoded表示)
- 输入的gt_labels是一个长度为batch_size的list,其中每个元素为一个2D array,即array[[class_id, xmin, ymin, xmax, ymax]...];
- 主要功能在__ call __ ()函数中实现,分为三步:
- 根据原图尺寸、缩放因子与高宽比、特征图尺寸三个条件,创建y_encoded模板(即一系列的anchors,shape为(batch,#boxes,21+4+4+4),最后为21个类别+gt框坐标+锚框坐标+variance);
- 匹配真值框和锚框,即更新最后一个维度的21+4;
- 将gt的坐标转换为锚框的偏移量;
调用train_dataset.generate()生成需要的数据对(X,y)
- 准备好数据增强链对象和SSDInputEncoder对象后,同其他参数一起传入train_dataset.generate中,指定生成器返回数据格式为(processed_images,encoded_labels),(前者shape为(batch,h_img,w_img,channel),后者即为用SSDInputEncoder类实例的输出结果),供后续model.fit_generator使用。
1.3 训练
定义了几个有用的回调函数,ModelCheckpoint(保存每次epoch之后的模型)、CSVLogger(将损失和指标保存至CSV文件)、EarlyStopping(早停)、ReduceLROnPlateau(平缓区自动减小学习率),在SSD300中还用了LearningRateScheduler(按计划调整学习率)、TerminateOnNaN(遇到NaN数据即停止训练)。其中最为常用的是ModelCheckpoint和CSVLogger。
训练时参数initial_epoch和final_epoch也很有意思,允许用户从上次中断的地方开始训练。(再也不怕中午睡觉被吵了:-))
训练结果可视化:可以直接调用fit的返回值,也可以读取CSV文件中记录值。
1.4 预测(可视化检测效果)
获取预测值
- 定义数据迭代器,并获取一个batch的样本;
- 将这个batch的样本送入模型进行预测,得到预测值;(这时候得到的y_pred是模型的raw output)
解码器对预测值进行后处理
- decode_detections函数功能同模型架构中解码器层DecodeDetections功能一样,都是:
- 偏移量转为坐标(可以是绝对坐标,也可以是相对坐标),同时后12个数转成4个数;
- 针对每一个类别,进行置信度过滤和NMS;
- 选取前top_k个预测框(若设置top_k),不足top_k的话直接输出。
- 输入的y_pred参数:training模式下SSD模型的原始输出(batch,#boxes,21+4+4+4),其中#boxes为所有锚框;
- 返回值:(batch,filtered_boxes,6),其中filtered_boxes为经过筛选后的预测框数量,6为[class_id, confidence, xmin, ymin, xmax, ymax];
- 注意: decode_detections函数和DecodeDetections层有不一样处:若经过筛选后预测框数量不足top_k,前者是直接输出,但后者会填充成top_k个(为了计算损失时维度一致)。
将预测框显示在图像上,对比效果
- 显示图像,画标注框和预测框的方法;
- plt中plt.cm可将数值映射成伪色彩,(很有用,因为相对于亮度,人们对颜色的变化更敏感),参考
1.5 SSD300训练的区别
- 训练SSD300的模型时,用的是Pascal VOC的数据,标签文件是XML文件;
- SSD300的模型结构中,有三点需要注意:
- 模型的结构按照原生SSD搭建;
- 空洞卷积层:fc6 = Conv2D(1024, (3, 3), dilation_rate=(6, 6),...);
- L2 Normalization层:conv4_3_norm = L2Normalization(gamma_init=20,...)(conv4_3);
- 疑问:SSD300定义模型参数的时候,将图像通道换成了BGR来训练,但是最后预测的时候图像通道没有换成BGR?
2 如何评估SSD模型
大致要点:
- 这一块单独列了一个文件,即SSD300_evaluation.ipynb;
- SSD Evaluation中,创建模型用的是inference模式,下载的权重文件VGG_VOC0712Plus_SSD_300x300_ft_iter_160000.h5 是以training模式创建的模型训练的(既然是权重文件,那肯定是训练得到的,所以模型肯定是以training模式创建的),所以model.load_weights(weights_path, by_name=True)中需要加上by_name,否则对不上号;
- 绘制PR曲线的方法。
3 如何微调SSD模型
这一块内容详见 weight_sampling_tutorial.ipynb。
作者提供了几种训练好的SSD模型,那么如何微调这些模型,使其能在自己的数据集上完成自己的任务?比如现在我想识别8种物体,而作者提供的是在MS COCO上训练的能识别80种物体的模型,那么该如何操作?
作者提出了3种方法,并认为最好的方法是直接对分类器的结果进行下采样。比如SSD第一个predictor的分类器输出是(batch, h1, w1, 81 * 4),其中h1和w1是conv4_3特征图的高度和宽度,对输出下采样得到(batch, h1, w1, 9 * 4),其中9表示8种物体和背景,然后在自己的数据集上微调即可。这种方法对那些目标物体在MS COCO的80个类别之内的任务特别有效。
4 其他注意点
- model.load_weights('./ssd7_weights.h5', by_name=True):这里by_name是指只加载同名层的权重,适合加载那些结构不同的模型权重,详见
- 尽量使用model.save保存模型整体,因为分开保存后,重新加载时optimizer的状态会被重置,详见
Reference:
SSD算法的实现的更多相关文章
- Bug2算法的实现(RobotBASIC环境中仿真)
移动机器人智能的一个重要标志就是自主导航,而实现机器人自主导航有个基本要求--避障.之前简单介绍过Bug避障算法,但仅仅了解大致理论而不亲自动手实现一遍很难有深刻的印象,只能说似懂非懂.我不是天才,不 ...
- Canny边缘检测算法的实现
图像边缘信息主要集中在高频段,通常说图像锐化或检测边缘,实质就是高频滤波.我们知道微分运算是求信号的变化率,具有加强高频分量的作用.在空域运算中来说,对图像的锐化就是计算微分.由于数字图像的离散信号, ...
- java基础解析系列(四)---LinkedHashMap的原理及LRU算法的实现
java基础解析系列(四)---LinkedHashMap的原理及LRU算法的实现 java基础解析系列(一)---String.StringBuffer.StringBuilder java基础解析 ...
- SSE图像算法优化系列十三:超高速BoxBlur算法的实现和优化(Opencv的速度的五倍)
在SSE图像算法优化系列五:超高速指数模糊算法的实现和优化(10000*10000在100ms左右实现) 一文中,我曾经说过优化后的ExpBlur比BoxBlur还要快,那个时候我比较的BoxBlur ...
- 详解Linux内核红黑树算法的实现
转自:https://blog.csdn.net/npy_lp/article/details/7420689 内核源码:linux-2.6.38.8.tar.bz2 关于二叉查找树的概念请参考博文& ...
- 详细MATLAB 中BP神经网络算法的实现
MATLAB 中BP神经网络算法的实现 BP神经网络算法提供了一种普遍并且实用的方法从样例中学习值为实数.离散值或者向量的函数,这里就简单介绍一下如何用MATLAB编程实现该算法. 具体步骤 这里 ...
- Python学习(三) 八大排序算法的实现(下)
本文Python实现了插入排序.基数排序.希尔排序.冒泡排序.高速排序.直接选择排序.堆排序.归并排序的后面四种. 上篇:Python学习(三) 八大排序算法的实现(上) 1.高速排序 描写叙述 通过 ...
- C++基础代码--20余种数据结构和算法的实现
C++基础代码--20余种数据结构和算法的实现 过年了,闲来无事,翻阅起以前写的代码,无意间找到了大学时写的一套C++工具集,主要是关于数据结构和算法.以及语言层面的工具类.过去好几年了,现在几乎已经 ...
- Python八大算法的实现,插入排序、希尔排序、冒泡排序、快速排序、直接选择排序、堆排序、归并排序、基数排序。
Python八大算法的实现,插入排序.希尔排序.冒泡排序.快速排序.直接选择排序.堆排序.归并排序.基数排序. 1.插入排序 描述 插入排序的基本操作就是将一个数据插入到已经排好序的有序数据中,从而得 ...
随机推荐
- null,blank,default
null 是针对数据库而言,如果 null=True, 表示数据库的该字段可以为空. blank 是针对表单的,如果 blank=True,表示你的表单填写该字段的时候可以不填,比如 admin 界面 ...
- Codeforces 499C:Crazy Town(计算几何)
题目链接 给出点A(x1,y1),B(x2,y2),和n条直线(ai,bi,ci,aix + biy + ci = 0),求A到B穿过多少条直线 枚举每条直线判断A.B是否在该直线两侧即可 #incl ...
- gensim Load embeddings
gensim package from gensim.models.keyedvectors import KeyedVectors twitter_embedding_path = 'twitter ...
- mybatis源码分析之06二级缓存
上一篇整合redis框架作为mybatis的二级缓存, 该篇从源码角度去分析mybatis是如何做到的. 通过上一篇文章知道,整合redis时需要在FemaleMapper.xml中添加如下配置 &l ...
- webpack对html模板的处理
一.打包html模板到相应目录并且引入js 需要安装 html-webpack-plugin 然后在plugins里实例化 new HtmlWebpackPlugin({ template:'./sr ...
- python 操作yaml文件
yaml 5.1版后弃用了yaml.load(file)这个用法,因为觉得很不安全,5.1版后就修改了需要指定Loader,通过默认加载器(FullLoader)禁止执行任意函数yaml 5.1之 ...
- 文字在线中间,CSS巧妙实现分隔线的几种方法
单个标签实现分隔线: .demo_line_01{ padding: 0 20px 0; margin: 20px 0; line-height: 1px; border-left: 200px so ...
- grid布局快速入门
Grid布局快速入门 常用Grid布局属性介绍 下面从一个简单Grid布局例子说起.CSS Grid 布局由两个核心组成部分是 wrapper(父元素)和 items(子元素). wrapper 是实 ...
- AtCoder Grand Contest 012 A - AtCoder Group Contest(贪心)
Time limit : 2sec / Memory limit : 256MB Score : 300 points Problem Statement There are 3N participa ...
- 公司-IT-SanSan:SanSan
ylbtech-公司-IT-SanSan:SanSan 毫不费力的组织.无缝简单.基于名片的联系人管理 SanSan是一个名片管理应用,为企业提供内部联系人管理和分享服务,此外该公司也是日本最大的.基 ...