之前看了Google官网的object_dectect 的源码,感觉Google大神写的还不错。最近想玩下Mask RCNN,就看了下源码,这里刚好当做总结和梳理。链接如下:

Google官网的object_dectect:https://github.com/tensorflow/models/tree/master/research/object_detection

Mask RCNN: https://github.com/matterport/Mask_RCNN

一个使用tensorflow 写的,一个是用keras写的,我自己是对tensorflow 会熟悉,但是kearas没用过,不过不影响看代码哈。有个比较困惑的地方,

好像我记得faster rcnn 中的rpn网络first stage的loss是proposals 和 gt_box的loss,而mask rcnn 是调出来进入第二步骤的正样本的proposals 和对应gt_box的loss,

虽然实际training中可能结果是一样的(我们一般会设置一个比较大的值,使得图像中所有的正样本都被框中,且进入second stage training).恩,废话不多说,开始

写mask rcnn 的源码阅读理解啦。这里简单的梳理下数据流向,就是图像被处理的一个个步骤,不过得对faster rcnn 和 fpn两个网络有所了解才好。

(一般阅读代码,使用py的文件比较多,进入函数,单步调试之类的,mask rcnn的例子都是ipynb,一般调成.py,在pycharm上单步调试。)

简而言之,mask rcnn 使用的是faster rcnn 的框架,和使用fpn的网络提取特征,在这个基础上增加了mask的预测。

事前准备:

训练数据 image

数据label:关于image的目标分割图

label的特征是对目标可以分割出来

然后,处理流程是这样的:

输入一张图,然后进入使用fpn网络提取特征,基本特征网络结构是resnet,框架是fpn的金字塔结构,如下图:

简要介绍是:

1-卷积过程:

一张图进入,resnet 不是有4个block,每个block提取的特征输出保存住,构成list [ c2,c3,c4,c5],代码如下:

2-上采样过程

将顶端的c5进行上采样,生成list [p2,p3,p4,p5,p6],其中p6是通过p5,加一次池化polling得到的。代码如下:

得到以后,first stage 和second stage的公共部分就得到啦,现在要分别处理了哦:

然后,就是两个过程,一个是 rpn过程,一个是mrcnn过程

 1-rnp过程:

rpn做啥???全称是Region Proposal Network,就是生成目标的矩形块,在faster rcnn中,只有一个feature map最终

的输出作为rpn的输入,就是 一个feature map 加上几层卷积,fully connect network ,然后就输出了rpn_bbox,rpn_prob,rpn_logit,

代码如下,输入的 x就是feature map

然而!!!,fpn网络是有多个 feature map 作为支撑输出 rpn的,所有作者这里就 把rpn网络包装成一个model,然后,

rpn_feature_map 中有的p2,p3,p4,p5,p6,一个个输入,再一个个输出,就可以的得到不同尺度的proposals box:

包装成model:

一个个提取proposals box:

这时候我们肯定会有疑问啦,不同尺寸的bbox可以放在一起吗?答案是可以的,因为,他们是归一化以后的,就像1/2 =2/4,一个样子。

接下来怎么办?,rpn还没有结束,rpn要变成roi,就是生成的框还要挑选下,变成生成的框+感兴趣的框(interest),所以要进入nms:

这里有几点要注意,就是rpn网络预测的实际上不是 object bbox,而是基于anchor的deata,就是图片中本来就有bbox(anchor),rpn预测的只是

这些anchor 离object 的bbox的中心偏差多少,长宽比偏差多少?示意图如下:

   

这里的rpn讲的不错:http://www.cnblogs.com/dudumiaomiao/p/6560841.html  

到此,first stage 基本结束,其实first stage做的是将目标找出来(不管目标是哪一个类别),接下来就是生成first stage rpn 和 anchor 一起生成bbox的过程(在Google object detection 代码中

这个过程叫做解码,把bbox信息解码出来):

自此,first stage 暂时到这里,就是输出了 bbox proposals ,object_classifier_prob,(目的是是否目标)

2-mrcnn过程:

对于DEtectionTargetLayer 这个函数对于:功能是

1、通过first stage 生成的proposals 和gt_box计算出iou矩阵 ,并通过iou值,取(crop)出正样本的feature map 和负样本的\feature map (

为什么这么搞,因为后期训练需要正负样本才能训练,如果全是正样本那就模型失效啦,为什么coco数据集是默认0.33正样本,因为其实我们

经常识别的目标在场景中都是占的比例是较少的,所以这里训练时给个经验值让正样本偏少些,但是正负样本比例不能差距太大,不然就训练

不收敛,假如0.1,则不训练随机猜测为负样本的准确率都有90%),其中正负样本比例通过config.ROI_POSITIVE_RATIO设置。

代码如下:

同时通过proposals和gt_box的偏差获得mask的偏差

这里大家可能有疑问:不是得用image height和width吗?为什么这里分母是:gt_h,gt_w,且mask又是怎样的一个形式。在config中有一个

配置项:config.USE_MINI_MASK,这个配置项目默认值是True,目的是为了在存储使用mask的时候使用小mask,也就是以object的bbox

为height和width(这样多余的0就不需要存储了)。所以就有了上面的gt_h,gt_w,当时看到这里的时候在想,不是得用image height和width

吗?后面看到 if config.USE_MINI_MASK就明白啦。

 这里再多说两句,关于mask又是怎样的一个形式

我们使用load_mask()函数,输出的是一个个和图像一样大小的mask,且该图像中有多少个目标,就mask就有多少个channal,一个目标

一个channal,用以后期提取bbox方便。然后,config.USE_MINI_MASK = True,就会再配置

这样,一张和图片同样大小的mask输入后,先将目标的bbox提取出来,再从大图片中crop出目标,然后在resize到(56*56),这样就节省

内存啦。

至此,准备工作做好了,接下来就是要求出每个proposals的类别(是人,还是自行车。。。。),并refine bbox和预测mask.

接下来讲两个函数,一个是fpn_classifier_graph,一个是build_fpn_mask_graph。前面的是对proposals中预测类别,后面的是预测mask,输入的都是

mrcnn_featur_map.

2、通过fpn_classifier_graph,

(更新)这里有个挺有趣的,就是,我们得到proposals后,要获得目标的feature map 信息,但是

我们的特征是 金字塔的(pyramid)的,从哪个图层crop下来呢?这里就有一个根据你的框的大小定位到一个图层的计算公式:

具体公式看fpn 那篇论文,表达的是框越小,level 越小,(目标小,level低的图层保留信息较完善,所以crop  level低的图层)

目标识别的传统操作,输出logits,probs,bbox。

2、通过buil_fpn_mask_grph,都是传统操作,卷积、fully connect,输出想要的结果:

至此,程序就差不多结束啦,后面就是loss了:

(更新)对于loss的简要介绍:

rpn_class_loss_graph:是 在rpn阶段识别的前景/背景的loss计算,这种分类的一般是交叉熵loss,见下:

:这个是rpn bbox loss,有一个trick 就是,这里只统计偏差(diff,见下面代码)比例小于1的,因为大于1的没有必要

就像faster rcnn选择iou > 0.7,为正样本,小于0.3的为负样本,当偏差大于1的时候,距离太远了,iou肯定是小的,这种loss不计算在内:

:同样的,对于类别分类使用交叉熵,这里就不给出了。

:使用l1 loss 为什么和上面不一样呢?实际和rpn bbox 差不多,就是取差值

:使用的是binary 交叉熵,做损失统计,我记得之前传统的分割方法,比如(

FC-DenseNet56)等,使用的loss是cross_entropy 有,还有一种是lovasz

Q/A(更新):

1、mask cnn引入预测K个输出的机制,允许每个类都生成独立的掩膜,避免类间竞争。这样做解耦了掩膜和种类预测,这句话的理解?

嗯,其实就是对于一个目标(或proposal)都生成number_classes 的mask,然后,只有与预测相一致类别的mask才会对loss有贡献,

并且,使用binary cross-entropy loss (具体是tf.nn.sigmoid_cross_entropy_with_logits)进行训练,这和传统的图像分割不一样,传统的图像经过fcn网络在像素级上分割,每个像素进行分类,这样就会导致类间的区域竞争,最后识别分割的结果出现颗粒感比较重的感觉,fcn使用的loss计算是SoftmaxLoss(具体是:tf.nn.softmax_cross_entropy_with_logits)。其实,binary cross-entropy loss常用在2分类,SoftmaxLoss是在多分类的,基本计算loss的原理都差不多,都是和ground truth的差(或“距离”)。

---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

最最后面放一下我的笔记,按我的建议是不建议大家看的,主要目的是笔记本的笔记容易丢,放在网上看的时候方便,给我自己看的哈(字漂,一般人都看不懂!)

Mask RCNN 源码阅读(update)的更多相关文章

  1. faster rcnn源码阅读笔记1

    自己保存的源码阅读笔记哈 faster rcnn 的主要识别过程(粗略) (开始填坑了): 一张3通道,1600*1600图像输入中,经过特征提取网络,得到100*100*512的feature ma ...

  2. faster rcnn源码阅读笔记3

  3. faster rcnn源码阅读笔记2

  4. faster rcnn 源码学习-------数据读入及RoIDataLayer相关模块解读

    参考博客:::https://www.cnblogs.com/Dzhen/p/6845852.html 非常全面的解读参考:::https://blog.csdn.net/DaVinciL/artic ...

  5. Pytorch版本yolov3源码阅读

    目录 Pytorch版本yolov3源码阅读 1. 阅读test.py 1.1 参数解读 1.2 data文件解析 1.3 cfg文件解析 1.4 根据cfg文件创建模块 1.5 YOLOLayer ...

  6. 【原】FMDB源码阅读(二)

    [原]FMDB源码阅读(二) 本文转载请注明出处 -- polobymulberry-博客园 1. 前言 上一篇只是简单地过了一下FMDB一个简单例子的基本流程,并没有涉及到FMDB的所有方方面面,比 ...

  7. 【原】FMDB源码阅读(一)

    [原]FMDB源码阅读(一) 本文转载请注明出处 —— polobymulberry-博客园 1. 前言 说实话,之前的SDWebImage和AFNetworking这两个组件我还是使用过的,但是对于 ...

  8. java8 ArrayList源码阅读

    转载自 java8 ArrayList源码阅读 本文基于jdk1.8 JavaCollection库中有三类:List,Queue,Set 其中List,有三个子实现类:ArrayList,Vecto ...

  9. 20 BasicTaskScheduler0 基本任务调度类基类(二)——Live555源码阅读(一)任务调度相关类

    这是Live555源码阅读的第二部分,包括了任务调度相关的三个类.任务调度是Live555源码中很重要的部分. 本文由乌合之众 lym瞎编,欢迎转载 http://www.cnblogs.com/ol ...

随机推荐

  1. datePecker时间控件区间写法

    成交时间: <input type="text" onclick="WdatePicker({dateFmt:'yyyy-MM-dd',maxDate:'#F{$d ...

  2. 样本失衡会对SVM的影响

    假设正类样本远多于负类 1.线性可分的情况 假设真实数据集如下: 由于负类样本量太少,可能会出现下面这种情况 使得分隔超平面偏向负类.严格意义上,这种样本不平衡不是因为样本数量的问题,而是因为边界点发 ...

  3. Web服务器之Nginx详解(理论部分)

    大纲 一.前言 二.Web服务器提供服务的方式 三.多进程.多线程.异步模式的对比 四.Web 服务请求过程 五.Linux I/O 模型 六.Linux I/O 模型具体说明 七.Linux I/O ...

  4. LeetCode - Maximum Frequency Stack

    Implement FreqStack, a class which simulates the operation of a stack-like data structure. FreqStack ...

  5. SharpDevelope 在 Windows 7 SP1 with .net framework4.0 下编译时找不到resgen.exe 解决办法

    如果在vs下编译正常,在SharpDevelope下编译报这个错误,可以更改编译时的.netframework版本和C#版本.在 Tool->Project Upgrade 进行项目转换后,一般 ...

  6. leetcode习题练习

    day001 #!user/bin/env python # -*- coding:utf-8 -*- #day001 两数之和 #方法1 def Sum(nbs,tgt): len_nums = l ...

  7. ubuntu彻底卸载opencv

    说正事之前,先啰嗦两句背景,算是拿个小本本记下了. 我本打算下载opencv2.4.在github上找到源码,在Branch处选择切换到2.4,然后复制URL,在terminal里面使用git clo ...

  8. 爬虫-day01-基础知识

    '''爬虫的构成下载器: 抓取页面 urllib equests selenium + webdriver解析器: 解释并提取页面元素 BeautifulSoup4 PyQuery Xpath Reg ...

  9. 分享一个生成反遗忘复习计划的java程序

    想必这个曲线大家都认识,这是遗忘曲线,展示人的记忆会随着时间的延长慢慢遗忘的规律,同时还展示了如果我们过一段时间复习一次对遗忘的有利影响. 道理大家都懂,关键怎么做到? 靠在本子上记下今天我该复习哪一 ...

  10. mosquitto centos安装配置

    周末弄wordpress的Mysql,一不小心把wordpress弄不好了,写了的好几遍文章也没有了,一怒之下,把整个系统重装了,安装了不带任何软件的新系统,重新搭一遍. 0.安装ftp服务器 #yu ...