[AI] 论文笔记 - U-Net 简单而又接近本质的分割网络
越简单越接近本质。
参考资料
U-Net: Convolutional Networks for Biomedical Image Segmentation
Abstract & Introduction
论文中有几个关键词:
- contracting path 收缩路径;
- expansive path 扩张路径;
- precise localization 更精确的位置信息;
- overlap-tile 边界镜像翻转;
- random elastic deformations 随机弹性形变;
- invariance 尺度不变性;
- touching cells 指距离很近的两个细胞;
- seamless tilling 无缝拼接;
好了,说完这些关键词我们来看看这篇论文,这篇论文和他的结构一样简单易懂,很能说明问题。
首先,作者主要拿自己的网络和一个基于sliding window的方法做对比,作者先diss了一下这个方法存在以下问题:
Deep neural networks segment neuronal membranes in electron microscopy images (NIPS2012)
- 非常慢,计算冗余(sliding window的毛病大家都懂);
- 在位置精确性和特征提取之间存在一个平衡,因为更多的特征意味着更多的max-pooling,则会丢失掉更多位置信息。
作者的输入多层特征的思想是受以下论文启发:
- Hypercolumns for object segmentation and fine-grained localization (2014)
- Image segmentation with cascaded hierarchical models and logistic disjunctive normal networks (2013)
这两篇论文指出把多层特征(the features from multiple layers)输入到classifier能够得到更好的特征提取和更好的位置信息(good localization and the use of context are possible at the same time)。
U-Net和其他网络的不同之处在于,上采样(Upsampling)过程中也有很多维特征,让特征流向更高分辨率的卷积层。
由于网络使用的卷积是3x3 unpadded convolutions,所以特征图会缩小,为了让输出的图像和输入图像的大小无缝拼接(seamless tilling),则要用到边界镜像翻转(overlap-tile),具体做法如下图:
Architecture
网络结构
使用3x3 unpadded convolutions,所以特征图会不断缩小,在横向拼接特征的时候,也要对特征图进行裁剪,以保持特征图大小一致。
全部使用ReLU激活函数。
权值初始化使用何恺明的方法:
Surpassing humanlevel performance on imagenet classification
具体做法就是一个标准差满足sqrt(2/N)
的高斯分布,其中的N代表一个神经元的输入节点数(例如一个3x3卷积核的输入是64维的话,那么N=9x64=576)
训练
在训练时作者更倾向于更大的图像输入,所以干脆将batch_size设置为1,所以在优化器的使用方面,使用到了带有动量的优化器,并且动量设置的很大(0.99),这么做是为了让以前的样本可以决定当前梯度更新的方向(因为batch_size太小啦,可以理解)。
损失函数就是pixel-wise soft-max + cross_entropy了。
数据增强
随机弹性形变和weight map:
随机弹性形变就是先用3x3的粗网格初始化随机形变,然后从标准差为10pixel的高斯分布中初始化随机位移矢量,再用bicubic双三次插值来计算每个像素的位移。
随机弹性形变的目的是让网络有invariance(尺度不变性)。
那么weight map是为了强制让网络学习touching cells之间的背景,这些位于touching cells之间的背景在损失函数的权重很高,如下图:
weight map的具体计算方式如下:
代码
最后来看看代码吧:https://github.com/milesial/Pytorch-UNet
整体模型:
class UNet(nn.Module):
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()
self.inc = inconv(n_channels, 64)
self.down1 = down(64, 128)
self.down2 = down(128, 256)
self.down3 = down(256, 512)
self.down4 = down(512, 512)
self.up1 = up(1024, 256)
self.up2 = up(512, 128)
self.up3 = up(256, 64)
self.up4 = up(128, 64)
self.outc = outconv(64, n_classes)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.outc(x)
return F.sigmoid(x)
细节部分:
class double_conv(nn.Module):
'''(conv => BN => ReLU) * 2'''
def __init__(self, in_ch, out_ch):
super(double_conv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
def forward(self, x):
x = self.conv(x)
return x
class inconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(inconv, self).__init__()
self.conv = double_conv(in_ch, out_ch)
def forward(self, x):
x = self.conv(x)
return x
class down(nn.Module):
def __init__(self, in_ch, out_ch):
super(down, self).__init__()
self.mpconv = nn.Sequential(
nn.MaxPool2d(2),
double_conv(in_ch, out_ch)
)
def forward(self, x):
x = self.mpconv(x)
return x
class up(nn.Module):
def __init__(self, in_ch, out_ch, bilinear=True):
super(up, self).__init__()
# would be a nice idea if the upsampling could be learned too,
# but my machine do not have enough memory to handle all those weights
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
self.conv = double_conv(in_ch, out_ch)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
diffY // 2, diffY - diffY//2))
x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
return x
class outconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(outconv, self).__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 1)
def forward(self, x):
x = self.conv(x)
return x
训练:
optimizer = optim.SGD(net.parameters(),
lr=lr,
momentum=0.9,
weight_decay=0.0005)
criterion = nn.BCELoss()
[AI] 论文笔记 - U-Net 简单而又接近本质的分割网络的更多相关文章
- [AI] 论文笔记 - CVPR2018 Super SloMo: High Quality Estimation of Multiple Intermediate Frames for Video Interpolation
写在前面 原始视频(30fps) 补帧后的视频(240fps) 本文是博主在做实验的过程中使用到的方法,刚好也做为了本科毕设的翻译文章,现在把它搬运到博客上来,因为觉得这篇文章的思路真的不错. 这篇文 ...
- 【论文笔记】Learning Fashion Compatibility with Bidirectional LSTMs
论文:<Learning Fashion Compatibility with Bidirectional LSTMs> 论文地址:https://arxiv.org/abs/1707.0 ...
- Deep Learning论文笔记之(八)Deep Learning最新综述
Deep Learning论文笔记之(八)Deep Learning最新综述 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文,但老感觉看完 ...
- 论文笔记:Mastering the game of Go with deep neural networks and tree search
Mastering the game of Go with deep neural networks and tree search Nature 2015 这是本人论文笔记系列第二篇 Nature ...
- 论文笔记:CNN经典结构1(AlexNet,ZFNet,OverFeat,VGG,GoogleNet,ResNet)
前言 本文主要介绍2012-2015年的一些经典CNN结构,从AlexNet,ZFNet,OverFeat到VGG,GoogleNetv1-v4,ResNetv1-v2. 在论文笔记:CNN经典结构2 ...
- AI理论学习笔记(一):深度学习的前世今生
AI理论学习笔记(一):深度学习的前世今生 大家还记得以深度学习技术为基础的电脑程序AlphaGo吗?这是人类历史中在某种意义的第一次机器打败人类的例子,其最大的魅力就是深度学习(Deep Learn ...
- Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现(转)
Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文, ...
- 论文笔记之:Visual Tracking with Fully Convolutional Networks
论文笔记之:Visual Tracking with Fully Convolutional Networks ICCV 2015 CUHK 本文利用 FCN 来做跟踪问题,但开篇就提到并非将其看做 ...
- Twitter 新一代流处理利器——Heron 论文笔记之Heron架构
Twitter 新一代流处理利器--Heron 论文笔记之Heron架构 标签(空格分隔): Streaming-process realtime-process Heron Architecture ...
随机推荐
- python入门之垃圾回收机制
目录 一 引入 二.什么是垃圾回收机制? 三.为什么要用垃圾回收机制? 四.垃圾回收机制原理分析 4.1.什么是引用计数? 4.2.引用计数扩展阅读 4.2.1 标记-清除 4.2.2 分代回收 一 ...
- Docker下打包FastDFS镜像以及上传遇到的问题
官方地址:https://github.com/happyfish100/fastdfs 一.先下载个包,然后解压(自己找个目录下载即可) [root@localhost soft]# wget ht ...
- Django-07-Model操作
一.数据库的配置 1. 数据库支持 django默认支持sqlite.mysql.oracle.postgresql数据库 <1> sqlite django默认使用sqlite的数据库 ...
- C 语言 基础篇
1.机器语言 2.汇编语言 3.高级语言:C.C++.Java(基于虚拟机) C语言开发:Unix,Linux,Mac OS,iOS,Android,Windows,Ubuntu 开发环境:visua ...
- MySQL学习一:建表
目标:创建三张表,学生表student(sid,name,gender), 课程表course(cid,name), 分数mark(mid, sid, cid, gender); 要求sid, cid ...
- Go操作ini文件
除了采用json,yaml等格式之外,常用的配置文件还有ini格式的. cfg, err := ini.Load(fyPath + "\\ServerSystem.ini") // ...
- Python与MogoDB交互
睡了大半天,终于有时间整理下拖欠的MongoDB的封装啦. 首先我们先进行下数据库的连接: conn = MongoClient('localhost',27017) # 建立连接 result = ...
- Spring boot java.lang.NoClassDefFoundError: org/springframework/boot/bind/RelaxedPropertyResolver
Spring boot 2.0.3 RELEASE 配置报错 java.lang.NoClassDefFoundError: org/springframework/boot/bind/Relaxed ...
- jedis异常:Could not get a resource from the pool
前几天公司后端系统出现了故障,导致app多个功能无法使用,查看日志,发现日志出现较多的redis.clients.jedis.exceptions.JedisConnectionException: ...
- ubuntu16.04 打开chrome弹出“Enter password to unlock your login keyring”解决方法
问题如图 输入开机密码发现验证失败. 解决 命令: find ~/ -name login.keyring 查找相关文件. 命令: sudo rm -rf /home/la/.local/share/ ...