SSD源码解读——网络搭建
之前,对SSD的论文进行了解读,可以回顾之前的博客:https://www.cnblogs.com/dengshunge/p/11665929.html。
为了加深对SSD的理解,因此对SSD的源码进行了复现,主要参考的github项目是ssd.pytorch。同时,我自己对该项目增加了大量注释:https://github.com/Dengshunge/mySSD_pytorch
搭建SSD的项目,可以分成以下三个部分:
接下来,本篇博客重点分析网络搭建。
该部分整体比较简单,思路也很清晰。
首先,在train.py中,网络搭建的函数入口是函数build_ssd(),该函数需要传入以下几个参数:"train"或者"test"字符串、图片尺寸、类别数。其中,"train"或者"test"字符串用于区分该网络是用于训练还是测试,这两个阶段的网络有些许不同,本文主要将训练阶段的网络;而类别数需要加上背景,对于VOC而言,有20个类别,加上1个背景,即类别数是21。
ssd_net = build_ssd('train', voc['min_dim'], voc['num_classes'])
这里,先放一张SSD的网络结构图,可以看出,SSD网络是有3部分组成的,vgg主干网络,新增网络(Conv6之后的层)和用于检测的头部网络(Extra Feature Layers)。
接着,在ssd.py中,首先定了一个参数,如下所示。这里主要以SSD300为例。这些参数有什么用呢?字典base的参数指的是用于搭建VGG主干网络输出通道数,其中“M”表示需要进行maxpooling;字典extras的参数同样表示新增层的输出通道数,其中“S”表示需要stride=2的降采样;字典mbo的参数表示用于特征融合的层中,每个层对应未知(x,y)的锚点框数量,在SSD300中,使用了6个层进行特征融合,如Conv_4层中,每个位置使用4个锚点框进行预测。
base = {
'': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M',
512, 512, 512], # M表示maxpolling
'': [],
}
extras = {
'': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256], # S表示stride=2
'': [],
}
mbox = {
'': [4, 6, 6, 6, 4, 4], # 每个特征图的每个点,对应锚点框的数量
'': [],
}
当定义完需要使用到的参数后,可以进行如具体搭建的环节。函数build_ssd()的定义如下所示。利用函数multibox()来构建SSD网络的各个部分,分别是VGG主干网络,新增层和用于检测的头部网络(或许可以理解为分类头和回归头)。而VGG主干网络是通过函数vgg()来实现,新增层是通过函数add_extras()来实现,而函数multibox()则搭建用于检测的头部网络。最后用这些层来初始化类SSD。
def build_ssd(phase, size=300, num_class=21):
if phase != 'test' and phase != 'train':
raise ("ERROR: Phase: " + phase + " not recognized")
base_, extras_, head_ = multibox(vgg(base[str(size)]),
add_extras(extras[str(size)], in_channels=1024),
mbox[str(size)],
num_class)
return SSD(phase, size, base_, extras_, head_, num_class)
我们来看一下VGG主干网络是如何搭建的。函数vgg()需要将上述的base字典传入进去,根据base字典,来搭建卷积层和池化层。作者对vgg网络进行了改进,即将fc6和fc7更改成conv6和conv7。值得留意的是,在conv6中,使用了空点卷积,dilation=6,增大感受野。在SSD论文的最后,也讨论了空洞卷积对结果有好的影响。最后,将这些卷积层和池化层放入list中,并返回这个list。
def vgg(cfg=base[''], batch_norm=False):
'''
该函数来源于torchvision.models.vgg19()中的make_layers()
'''
layers = []
in_channels = 3 # vgg主体部分,到论文的conv6之前
for v in cfg:
if v == 'M':
# ceil_mode是向上取整
layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = v
pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True) conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)
conv7 = nn.Conv2d(1024, 1024, kernel_size=1)
layers += [pool5, conv6, nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)] return layers
接下来,我们来了解一下SSD在vgg中新增层,即conv7之后的网络层。同样,函数add_extras()需要传入字典extras,来构建网络层。这里可以留意一下,kernel_size的写法,(1,3)为一个元祖tuple,flag来控制取哪个值,即可变换使用3*3或者1*1的卷积核,减少代码的冗余。最后将新构建的层存入list中,并返回这个list。
def add_extras(cfg=extras[''], in_channels=1024):
'''
完成SSD后半部分的网络构建,即作者新加上去的网络,从conv7之后到conv11_2
'''
layers = []
flag = False # 交替控制卷积核,使用1*1或者使用3*3
for k, v in enumerate(cfg):
if in_channels != 'S':
if v == 'S':
layers += [nn.Conv2d(in_channels, cfg[k + 1], kernel_size=(1, 3)[flag], stride=2, padding=1)]
else:
layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])]
flag = not flag
in_channels = v
return layers
当有了vgg的主干网络和新增层后,可以将某些层进行特征融合和预测了。这里,就需要使用到函数multibox()。需要将vgg主干网络和新增层的list、字典mbox和类别数传入函数中。首先,函数multibox()会创建两个list,用于保存位置回归的层和置信度的层。对于每个用于融合的特征层,会分成两部分,一个用于回归,使用3*3的卷积,输出通道数是cfg[k] * 4,其中cfg[k]表示每个位置上锚点框的数量,4表示[x_min,y_min,x_max,y_max];另外一个用于类别的判断,也是使用3*3的卷积,输出通道数是cfg[k] * num_class,表示每个锚点框判断其属于哪一个类别,在voc中,num_class=21(包含背景)。可以理解成将此特征层分成了分类头和回归头,每个锚点框会输出4个坐标和21个类别置信度。最后将vgg主干网络、新增层、分类头和回归头返回。
def multibox(vgg, extra_layers, cfg, num_class):
'''
返回vgg网络,新增网络,位置网络和置信度网络
'''
loc_layers = [] # 判断位置
conf_layers = [] # 判断置信度 vgg_source = [21, -2] # 21表示conv4_3的序号,-2表示conv7的序号
for k, v in enumerate(vgg_source):
# vgg[v]表示需要提取的特征图
# cfg[k]代表该特征图下每个点对应的锚点框数量
loc_layers += [nn.Conv2d(vgg[v].out_channels, cfg[k] * 4, kernel_size=3, padding=1)]
conf_layers += [nn.Conv2d(vgg[v].out_channels, cfg[k] * num_class, kernel_size=3, padding=1)] for k, v in enumerate(extra_layers[1::2], 2):
# [1::2]表示,从第1位开始,步长为2
# 这么做的目的是,新增加的层,都包含2层卷积,需要提取后面那层卷积的结果作为特征图
loc_layers += [nn.Conv2d(v.out_channels, cfg[k] * 4, kernel_size=3, padding=1)]
conf_layers += [nn.Conv2d(v.out_channels, cfg[k] * num_class, kernel_size=3, padding=1)] return vgg, extra_layers, (loc_layers, conf_layers)
函数multibox()返回的各个层,用于初始化类SSD。首先,由于因为“train”阶段和“test”阶段是有点区别的,本节依然主要将“train”阶段,因此,需要传入phase参数,参数只能是两个值(train,test)。函数PriorBox()的作用是来创建先验锚点框,返回的shape为[8732,4],其中具有8732个锚点框,4表示每个锚点框的坐标[中心点x,中心点y,宽,高],这里的坐标值有点不太一样。由于传入的网络层是以list列表的形式,因此,用nn.ModuleList()将其转换为pytorch的网络结构。
接下来看类SSD中的函数forward(),用于前向推理。按顺序对输入图片进行处理,在conv4中,需要对特征图进行L2正则化。并将用于特征融合的特征图存在放sources中。在得到5个用于融合的特征图后,将这些特征图输入到分类头和回归头中,每个特征图对应各自的分类头和回归头。这里注意一下,分类头或者回归头卷积后,使用了permute()函数。该函数的作用是交换维度,原本的维度是[batch_size,channel,height,weight],交换维度后变成了[batch_size,height,weight,channel],这样做的目的是方便后续的处理。将处理后的结果保存在loc和conf这两个List中。后续接着对loc和conf进行变换,利用view()函数,最终,loc的shape为[batch_size,8732*4],conf的shape为[batch_size,8732*21]。
最后,将loc和conf这两个List又变换维度,返回出去,用于计算loss损失函数(感觉这么多变换,有点重复呀,应该可以省略一部分)。"train"阶段和"test"阶段返回的结果类似,其中不同点是,在test阶段,置信度需要经过softmax。
class SSD(nn.Module):
'''
构建SSD的主函数,将base(vgg)、新增网络和位置网络与置信度网络组合起来
''' def __init__(self, phase, size, base, extras, head, num_classes):
super(SSD, self).__init__()
self.phase = phase
self.num_classes = num_classes
self.priors = torch.Tensor(PriorBox(voc))
self.size = size # SSD网络
self.vgg = nn.ModuleList(base)
# 对conv4_3的特征图进行L2正则化
self.L2Norm = L2Norm(512, 20) self.extras = nn.ModuleList(extras)
self.loc = nn.ModuleList(head[0])
self.conf = nn.ModuleList(head[1]) if phase == 'test':
self.softmax = nn.Softmax(dim=-1)
self.detect = Detect(num_classes=self.num_classes, top_k=200,
conf_thresh=0.01, nms_thresh=0.45) def forward(self, x):
sources = [] # 保存特征图
loc = [] # 保存每个特征图进行位置网络后的信息
conf = [] # 保存每个特征图进行置信度网络后的信息 # 处理输入至conv4_3
for k in range(23):
x = self.vgg[k](x) # 对conv4_3进行L2正则化
s = self.L2Norm(x)
sources.append(s) # 完成vgg后续的处理
for k in range(23, len(self.vgg)):
x = self.vgg[k](x)
sources.append(x) # 使用新增网络进行处理
for k, v in enumerate(self.extras):
x = F.relu(v(x), inplace=True)
if k % 2 == 1:
sources.append(x) # 将特征图送入位置网络和置信度网络
# l(x)或者c(x)的shape为[batch_size,channel,height,weight],使用了permute后,变成[batch_size,height,weight,channel]
# 这样做应该是为了方便后续处理
for (x, l, c) in zip(sources, self.loc, self.conf):
loc.append(l(x).permute(0, 2, 3, 1).contiguous())
conf.append(c(x).permute(0, 2, 3, 1).contiguous()) # 进行格式变换
loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1) # [batch_size,34928],锚点框的数量8732*4=34928
conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1) if self.phase == 'train':
output = (loc.view(loc.size(0), -1, 4), # [batch_size,num_priors,4]
conf.view(conf.size(0), -1, self.num_classes), # [batch_size,num_priors,21]
self.priors) # [num_priors,4]
else: # Test
output = self.detect(
loc.view(loc.size(0), -1, 4), # 位置预测
self.softmax(conf.view((conf.size(0), -1, self.num_classes))), # 置信度预测
self.priors.cuda() # 先验锚点框
) return output
在上面类SSD中,提及到了先验锚点框的构建函数PriorBox(),这个函数在models/prior_box.py中。首先,根据用于融合的特征图尺寸和product()函数,生成一系列的点,如(0,0),(0,1),(0,2)等。然后根据这些像素点位置,偏移0.5作为锚点框的中心点,即cx和cy,并将其归一化。然后计算论文中的$s_k$和${s_k}'$,对应s_k和s_k_prime,先计算$a_r=1$的情况,再计算其余$a_r$的情况。此时,mean的shape为[1,34928],因此,需要使用view()函数,将其切割出来,变成[8732,4]。记得,这里的锚点框的坐标是[中心点x,中心点y,宽,高]。
def PriorBox(cfg):
'''
为所有特征图生成预设的锚点框,返回所有生成的锚点框,尺寸为[8732,4],
每行表示[中心点x,中心点y,宽,高]
'''
image_size = cfg['min_dim'] #
feature_maps = cfg['feature_maps'] # [38, 19, 10, 5, 3, 1],特征图尺寸
steps = cfg['steps'] # [8, 16, 32, 64, 100, 300]
min_sizes = cfg['min_sizes'] # [30, 60, 111, 162, 213, 264]
max_sizes = cfg['max_sizes'] # [60, 111, 162, 213, 264, 315]
aspect_ratios = cfg['aspect_ratios'] # [[2], [2, 3], [2, 3], [2, 3], [2], [2]] mean = []
# 为所有特征图生成锚点框
for k, f in enumerate(feature_maps):
# product(list1,list2)的作用是依次取出list1中的每1个元素,与list2中的每1个元素,
# 组成元组,然后,将所有的元组组成一个列表,返回
# 而这里使用了repeat,说明1个list重复2次
for i, j in product(range(f), repeat=2):
f_k = image_size / steps[k]
# 计算中心点,这里的j是沿x方向变化的
cx = (j + 0.5) / f_k
cy = (i + 0.5) / f_k # aspect_ratio=1有两种情况,s_k=s_k,s_k=sqrt(s_k*s_(k+1))
s_k = min_sizes[k] / image_size
mean += [cx, cy, s_k, s_k] s_k_prime = sqrt(s_k * (max_sizes[k] / image_size))
mean += [cx, cy, s_k_prime, s_k_prime] # 剩余的aspect_ratio
for ar in aspect_ratios[k]:
mean += [cx, cy, s_k * sqrt(ar), s_k / sqrt(ar)]
mean += [cx, cy, s_k / sqrt(ar), s_k * sqrt(ar)] # 此时的mean是1*34928的list,要4个数就分割出来,所以需要用view,从而变成[8732,4],即有8732个锚点框
output = torch.Tensor(mean).view(-1, 4)
if cfg['clip']:
# 对每个元素进行截断限制,限制为[0,1]之间
output.clamp_(min=0, max=1)
return output
最后,类SSD中还对conv4的特征层使用了L2正则化,该函数在models/l2norm.py中。在函数forwand()中,按每个通道对其值进行L2正则化,即除以通道的平方根来实现归一化。
class L2Norm(nn.Module):
'''
对conv4_3进行l2归一化
''' def __init__(self, n_channels, scale):
super(L2Norm, self).__init__()
self.n_channels = n_channels
self.gamma = scale
self.eps = 1e-10
self.weight = nn.Parameter(torch.Tensor(self.n_channels)) # n_channels个随机数
self.reset_parameters() def reset_parameters(self):
# 使用gamma来填充weight的每个值
nn.init.constant_(self.weight, self.gamma) def forward(self, x):
# 按通道进行求值
norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps # [1,1,38,38]
x = torch.div(x, norm)
# 将weight通过3个unsqueeze展开成[1,512,1,1],然后通过expand_as进行扩展,形成[1,512,38,38]
out = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x
return out
至此,SSD的网络搭建过程已经完成了,通过类SSD的forward()函数,即能返回预测框的坐标和类别置信度,以此可以计算损失函数。
SSD源码解读——网络搭建的更多相关文章
- SSD源码解读——网络测试
之前,对SSD的论文进行了解读,可以回顾之前的博客:https://www.cnblogs.com/dengshunge/p/11665929.html. 为了加深对SSD的理解,因此对SSD的源码进 ...
- SSD源码解读——损失函数的构建
之前,对SSD的论文进行了解读,可以回顾之前的博客:https://www.cnblogs.com/dengshunge/p/11665929.html. 为了加深对SSD的理解,因此对SSD的源码进 ...
- SSD源码解读——数据读取
之前,对SSD的论文进行了解读,可以回顾之前的博客:https://www.cnblogs.com/dengshunge/p/11665929.html. 为了加深对SSD的理解,因此对SSD的源码进 ...
- Alamofire源码解读系列(七)之网络监控(NetworkReachabilityManager)
Alamofire源码解读系列(七)之网络监控(NetworkReachabilityManager) 本篇主要讲解iOS开发中的网络监控 前言 在开发中,有时候我们需要获取这些信息: 手机是否联网 ...
- 基于Docker的TensorFlow机器学习框架搭建和实例源码解读
概述:基于Docker的TensorFlow机器学习框架搭建和实例源码解读,TensorFlow作为最火热的机器学习框架之一,Docker是的容器,可以很好的结合起来,为机器学习或者科研人员提供便捷的 ...
- AFNetworking 3.0 源码解读(六)之 AFHTTPSessionManager
AFHTTPSessionManager相对来说比较好理解,代码也比较短.但却是我们平时可能使用最多的类. AFNetworking 3.0 源码解读(一)之 AFNetworkReachabilit ...
- (转)go语言nsq源码解读二 nsqlookupd、nsqd与nsqadmin
转自:http://www.baiyuxiong.com/?p=886 ---------------------------------------------------------------- ...
- swoft| 源码解读系列一: 好难! swoft demo 都跑不起来怎么破? docker 了解一下呗~
title: swoft| 源码解读系列一: 好难! swoft demo 都跑不起来怎么破? docker 了解一下呗~description: 阅读 sowft 框架源码, swoft 第一步, ...
- SDWebImage源码解读 之 UIImage+GIF
第二篇 前言 本篇是和GIF相关的一个UIImage的分类.主要提供了三个方法: + (UIImage *)sd_animatedGIFNamed:(NSString *)name ----- 根据名 ...
随机推荐
- 爬虫 lxml 模块
Xpath 在 XML 文档中查找信息的语言, 同样适用于 HTML 辅助工具 Xpath Helper Chrome插件 快捷键 Ctrl + shift + x XML Quire xpath ...
- springboot2.0数据制作为excel表格
注意:由于公司需要大量导出数据成excel表格,因此在网上找了方法,亲测有效. 声明:该博客参考于https://blog.csdn.net/long530439142/article/details ...
- Rxjava2实战--第三章 创建操作符
Rxjava2实战--第三章 创建操作符 Rxjava的创建操作符 操作符 用途 just() 将一个或多个对象转换成发射这个或者这些对象的一个Observable from() 将一个Iterabl ...
- CnPack 开源软件项目
Cnpack公共窗体库 ------------------------------ CnPack 2009-09-14 SVN 包,包括以下内容: 1. CnPack 组件包所有源代码.2. CnP ...
- Centos7桥接网络、DNS、时间同步配置
Centos配置桥接网络.DNS服务和时间同步 1.配置桥接网络 2.配置虚拟机网卡,采用的是静态ip方式 重启network服务 3.配置dns 4.关闭防火墙和selinux 5.ping外网域名 ...
- Java 操作Word表格
本文将对如何在Java程序中操作Word表格作进一步介绍.操作要点包括 如何在Word中创建嵌套表格. 对已有表格添加行或者列 复制已有表格中的指定行或者列 对跨页的表格可设置是否禁止跨页断行 创建表 ...
- CTF—攻防练习之HTTP—命令注入
主机:192.168.32.152 靶机:192.168.32.167 首先nmap,nikto -host,dirb 探测robots.txt目录下 在/nothing目录中,查看源码发现pass ...
- 状态压缩DP:蒙德里安的梦想
代码 #include<bits/stdc++.h> using namespace std; int n,m; long long f[12][1<<11]; bool yy ...
- 使用mint ui 的picker解决城市三级联动
<mt-popup v-model="popupVisible" position="bottom"> <div class="po ...
- 使用Python过程中遇到的一些坑及其解决方法(持续更新)
1.列表不能直接赋值 nums1 = nums2 x nums1[:] = nums2 正确 2.返回列表某一元素的值可以使用index函数 aList = [123, 'xyz', 'runoob' ...