一、DenseNet的优点

  • 减轻梯度消失问题
  • 加强特征的传递
  • 充分利用特征
  • 减少了参数量

二、网络结构公式

对于每一个DenseBlock中的每一个层,

[x0,x1,…,xl-1]表示将0到l-1层的输出feature map做concatenation。concatenation是做通道的合并,就像Inception那样。而前面resnet是做值的相加,通道数是不变的。Hl包括BN,ReLU和3*3的卷积。

而在ResNet中的每一个残差块,

三、Growth Rate

指的是DenseBlock中每一个非线性变换Hl(BN,ReLU和3*3的卷积)的输出,这个输出与输入Concate.一个DenseBlock的输出=输入+Hl数×growth_rate。在要给DenseBlock中,Feature Map的size保持不变。

四、Bottleneck

这个组件位于DenseBlock中,当一个DenseBlock包含的非线性变换Hl较多时(如nHl=48),此时的grow rate为k=32,那么第48层的输入变成input+47×32,这是一个很大的数,如果不用bottleneck进行降维,那么计算量很大。

因此,使用4×k个1x1卷积进行降维。使得3×3线性变换的输入通道变成4×k。同时,bottleneck起到特征融合的效果。

五、Transition

这个组件位于DenseBlock之间,使用1×1卷积进行降维,降维后的通道数为input_channels*reduction. 参数reduction默认为0.5,后接池化层进行下采样,减小Feature Map 分辨率。

六、网络结构

 

七、代码实现(Pytorch)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math class Bottleneck(nn.Module):
def __init__(self,nChannels,growthRate):
super(Bottleneck,self).__init__()
interChannels = 4*growthRate
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels,interChannels,kernel_size=1,
stride=1,bias=False)
self.bn2 = nn.BatchNorm2d(interChannels)
self.conv2 = nn.Conv2d(interChannels,growthRate,kernel_size=3,
stride=1,padding=1,bias=False) def forward(self, *input):
#先进行BN(pytorch的BN已经包含了Scale),然后进行relu,conv1起到bottleneck的作用
out = self.conv1(F.relu(self.bn1(input)))
out = self.conv2(F.relu(self.bn2(out)))
out = torch.cat(input,out)
return out class SingleLayer(nn.Module):
def __init__(self,nChannels,growthRate):
super(SingleLayer,self).__init__()
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels,growthRate,kernel_size=3,
padding=1,bias=False) def forward(self, *input):
out = self.conv1(F.relu(self.bn1(input)))
out = torch.cat(input,out)
return out class Transition(nn.Module):
def __int__(self,nChannels,nOutChannels):
super(Transition,self).__init__() self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels,nOutChannels,kernel_size=1,bias=False) def forward(self, *input):
out = self.conv1(F.relu(self.bn1(input)))
out = F.avg_pool2d(out,2)
return out class DenseNet(nn.Module):
def __init__(self,growthRate,depth,reduction,nClasses,bottleneck):
super(DenseNet,self).__init__()
#DenseBlock中非线性变换模块的个数
nNoneLinears = (depth-4)//3
if bottleneck:
nNoneLinears //=2 nChannels = 2*growthRate
self.conv1 = nn.Conv2d(3,nChannels,kernel_size=3,padding=1,bias=False)
self.denseblock1 = self._make_dense(nChannels,growthRate,nNoneLinears,bottleneck)
nChannels += nNoneLinears*growthRate
nOutChannels = int(math.floor(nChannels*reduction)) #向下取整
self.transition1 = Transition(nChannels,nOutChannels) nChannels = nOutChannels
self.denseblock2 = self._make_dense(nChannels,growthRate,nNoneLinears,bottleneck)
nChannels += nNoneLinears*growthRate
nOutChannels = int(math.floor(nChannels*reduction))
self.transition2 = Transition(nChannels, nOutChannels) nChannels = nOutChannels
self.denseblock3 = self._make_dense(nChannels, growthRate, nNoneLinears, bottleneck)
nChannels += nNoneLinears * growthRate self.bn1 = nn.BatchNorm2d(nChannels)
self.fc = nn.Linear(nChannels,nClasses) #参数初始化
for m in self.modules():
if isinstance(m,nn.Conv2d):
n = m.kernel_size[0]*m.kernel_size[1]*m.out_channels
m.weight.data.normal_(0,math.sqrt(2./n))
elif isinstance(m,nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m,nn.Linear):
m.bias.data.zero_() def _make_dense(self,nChannels,growthRate,nDenseBlocks,bottleneck):
layers = []
for i in range(int(nDenseBlocks)):
if bottleneck:
layers.append(Bottleneck(nChannels,growthRate))
else:
layers.append(SingleLayer(nChannels,growthRate))
nChannels+=growthRate
return nn.Sequential(*layers) def forward(self, *input):
out = self.conv1(input)
out = self.transition1(self.denseblock1(out))
out = self.transition2(self.denseblock2(out))
out = self.denseblock3(out)
out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)),8))
out = F.log_softmax(self.fc(out))
return out

DenseNet笔记的更多相关文章

  1. 论文笔记——DenseNet

    <Densely Connected Convolutional Networks>阅读笔记 代码地址:https://github.com/liuzhuang13/DenseNet 首先 ...

  2. 论文笔记:CNN经典结构2(WideResNet,FractalNet,DenseNet,ResNeXt,DPN,SENet)

    前言 在论文笔记:CNN经典结构1中主要讲了2012-2015年的一些经典CNN结构.本文主要讲解2016-2017年的一些经典CNN结构. CIFAR和SVHN上,DenseNet-BC优于ResN ...

  3. DenseNet 论文阅读笔记

    Densely Connected Convolutional Networks 原文链接 摘要 研究表明,如果卷积网络在接近输入和接近输出地层之间包含较短地连接,那么,该网络可以显著地加深,变得更精 ...

  4. tensorflow学习笔记——DenseNet

    完整代码及其数据,请移步小编的GitHub地址 传送门:请点击我 如果点击有误:https://github.com/LeBron-Jian/DeepLearningNote 这里结合网络的资料和De ...

  5. 论文笔记系列-Neural Network Search :A Survey

    论文笔记系列-Neural Network Search :A Survey 论文 笔记 NAS automl survey review reinforcement learning Bayesia ...

  6. 论文笔记:CNN经典结构1(AlexNet,ZFNet,OverFeat,VGG,GoogleNet,ResNet)

    前言 本文主要介绍2012-2015年的一些经典CNN结构,从AlexNet,ZFNet,OverFeat到VGG,GoogleNetv1-v4,ResNetv1-v2. 在论文笔记:CNN经典结构2 ...

  7. 转载:DenseNet算法详解

    原文连接:http://blog.csdn.net/u014380165/article/details/75142664 参考连接:http://blog.csdn.net/u012938704/a ...

  8. Dual Path Networks(DPN)——一种结合了ResNet和DenseNet优势的新型卷积网络结构。深度残差网络通过残差旁支通路再利用特征,但残差通道不善于探索新特征。密集连接网络通过密集连接通路探索新特征,但有高冗余度。

    如何评价Dual Path Networks(DPN)? 论文链接:https://arxiv.org/pdf/1707.01629v1.pdf在ImagNet-1k数据集上,浅DPN超过了最好的Re ...

  9. DenseNet算法详解——思路就是highway,DneseNet在训练时十分消耗内存

    论文笔记:Densely Connected Convolutional Networks(DenseNet模型详解) 2017年09月28日 11:58:49 阅读数:1814 [ 转载自http: ...

随机推荐

  1. 【刷题】洛谷 P3455 [POI2007]ZAP-Queries

    题目描述 Byteasar the Cryptographer works on breaking the code of BSA (Byteotian Security Agency). He ha ...

  2. 使用Dom4解析xml

    XML是一种通用的数据交换格式,它的平台无关性.语言无关性.系统无关性.给数据集成与交互带来了极大的方便. XML在不同的语言环境中解析方式都是一样的,只不过实现的语法不同而已. XML的解析方式分为 ...

  3. python之旅:三元表达式、列表推导式、生成器表达式、函数递归、匿名函数、内置函数

    三元表达式 #以下是比较大小,并返回值 def max2(x,y): if x > y: return x else: return y res=max2(10,11) print(res) # ...

  4. Maven settings.xml配置(指定本地仓库、阿里云镜像设置)

    转: 详解Maven settings.xml配置(指定本地仓库.阿里云镜像设置) 更新时间:2018年12月18日 11:14:45   作者:AmaniZ   我要评论   一.settings. ...

  5. BZOJ 3160 FFT+马拉车

    题意显然 ans=回文子序列数目 - 回文子串数目 回文子串直接用马拉车跑出来 回文子序列一开始总是不知道怎么求 (太蠢了) 后面看了题解 构造一个神奇的卷积 (这个是我盗的图)地址 后面还有一些细节 ...

  6. 特征选择实践---python

    作者:城东链接:https://www.zhihu.com/question/28641663/answer/110165221来源:知乎著作权归作者所有.商业转载请联系作者获得授权,非商业转载请注明 ...

  7. R语言画图

    转http://www.cnblogs.com/jiangmiaomiao/p/6991632.html 0 引言 R支持4种图形类型: base graphics, grid graphics, l ...

  8. NATS_08:NATS客户端Go语言手动编写

    NATS客户端    一个NATS客户端是基于NATS服务端来说既可以是一个生产数据的也可以是消费数据的.生产数据的叫生产者英文为 publishers,消费数据的叫消费者英文为 subscriber ...

  9. caffe 配置文件详解

    主要是遇坑了,要记录一下. solver算是caffe的核心的核心,它协调着整个模型的运作.caffe程序运行必带的一个参数就是solver配置文件.运行代码一般为 # caffe train --s ...

  10. NAT—网络地址转换

    参考链接:http://www.qingsword.com/qing/745.html 视频链接: 一.什么是NAT? NAT --- Network Address Translation  也就是 ...