越简单越接近本质。

参考资料

U-Net: Convolutional Networks for Biomedical Image Segmentation

Abstract & Introduction

论文中有几个关键词:

  • contracting path 收缩路径;
  • expansive path 扩张路径;
  • precise localization 更精确的位置信息;
  • overlap-tile 边界镜像翻转;
  • random elastic deformations 随机弹性形变;
  • invariance 尺度不变性;
  • touching cells 指距离很近的两个细胞;
  • seamless tilling 无缝拼接;

好了,说完这些关键词我们来看看这篇论文,这篇论文和他的结构一样简单易懂,很能说明问题。

首先,作者主要拿自己的网络和一个基于sliding window的方法做对比,作者先diss了一下这个方法存在以下问题:

Deep neural networks segment neuronal membranes in electron microscopy images (NIPS2012)

  • 非常慢,计算冗余(sliding window的毛病大家都懂);
  • 在位置精确性和特征提取之间存在一个平衡,因为更多的特征意味着更多的max-pooling,则会丢失掉更多位置信息。

作者的输入多层特征的思想是受以下论文启发:

  • Hypercolumns for object segmentation and fine-grained localization (2014)
  • Image segmentation with cascaded hierarchical models and logistic disjunctive normal networks (2013)

这两篇论文指出把多层特征(the features from multiple layers)输入到classifier能够得到更好的特征提取和更好的位置信息(good localization and the use of context are possible at the same time)。

U-Net和其他网络的不同之处在于,上采样(Upsampling)过程中也有很多维特征,让特征流向更高分辨率的卷积层。

由于网络使用的卷积是3x3 unpadded convolutions,所以特征图会缩小,为了让输出的图像和输入图像的大小无缝拼接(seamless tilling),则要用到边界镜像翻转(overlap-tile),具体做法如下图:

Architecture

网络结构

使用3x3 unpadded convolutions,所以特征图会不断缩小,在横向拼接特征的时候,也要对特征图进行裁剪,以保持特征图大小一致。

全部使用ReLU激活函数。

权值初始化使用何恺明的方法:

Surpassing humanlevel performance on imagenet classification

具体做法就是一个标准差满足sqrt(2/N)的高斯分布,其中的N代表一个神经元的输入节点数(例如一个3x3卷积核的输入是64维的话,那么N=9x64=576)

训练

在训练时作者更倾向于更大的图像输入,所以干脆将batch_size设置为1,所以在优化器的使用方面,使用到了带有动量的优化器,并且动量设置的很大(0.99),这么做是为了让以前的样本可以决定当前梯度更新的方向(因为batch_size太小啦,可以理解)。

损失函数就是pixel-wise soft-max + cross_entropy了。

数据增强

随机弹性形变和weight map:

随机弹性形变就是先用3x3的粗网格初始化随机形变,然后从标准差为10pixel的高斯分布中初始化随机位移矢量,再用bicubic双三次插值来计算每个像素的位移。

随机弹性形变的目的是让网络有invariance(尺度不变性)。

那么weight map是为了强制让网络学习touching cells之间的背景,这些位于touching cells之间的背景在损失函数的权重很高,如下图:

weight map的具体计算方式如下:

代码

最后来看看代码吧:https://github.com/milesial/Pytorch-UNet

整体模型:

class UNet(nn.Module):
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()
self.inc = inconv(n_channels, 64)
self.down1 = down(64, 128)
self.down2 = down(128, 256)
self.down3 = down(256, 512)
self.down4 = down(512, 512)
self.up1 = up(1024, 256)
self.up2 = up(512, 128)
self.up3 = up(256, 64)
self.up4 = up(128, 64)
self.outc = outconv(64, n_classes) def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.outc(x)
return F.sigmoid(x)

细节部分:

class double_conv(nn.Module):
'''(conv => BN => ReLU) * 2'''
def __init__(self, in_ch, out_ch):
super(double_conv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
) def forward(self, x):
x = self.conv(x)
return x class inconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(inconv, self).__init__()
self.conv = double_conv(in_ch, out_ch) def forward(self, x):
x = self.conv(x)
return x class down(nn.Module):
def __init__(self, in_ch, out_ch):
super(down, self).__init__()
self.mpconv = nn.Sequential(
nn.MaxPool2d(2),
double_conv(in_ch, out_ch)
) def forward(self, x):
x = self.mpconv(x)
return x class up(nn.Module):
def __init__(self, in_ch, out_ch, bilinear=True):
super(up, self).__init__() # would be a nice idea if the upsampling could be learned too,
# but my machine do not have enough memory to handle all those weights
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) self.conv = double_conv(in_ch, out_ch) def forward(self, x1, x2):
x1 = self.up(x1) # input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3] x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
diffY // 2, diffY - diffY//2)) x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
return x class outconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(outconv, self).__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 1) def forward(self, x):
x = self.conv(x)
return x

训练:

optimizer = optim.SGD(net.parameters(),
lr=lr,
momentum=0.9,
weight_decay=0.0005) criterion = nn.BCELoss()

[AI] 论文笔记 - U-Net 简单而又接近本质的分割网络的更多相关文章

  1. [AI] 论文笔记 - CVPR2018 Super SloMo: High Quality Estimation of Multiple Intermediate Frames for Video Interpolation

    写在前面 原始视频(30fps) 补帧后的视频(240fps) 本文是博主在做实验的过程中使用到的方法,刚好也做为了本科毕设的翻译文章,现在把它搬运到博客上来,因为觉得这篇文章的思路真的不错. 这篇文 ...

  2. 【论文笔记】Learning Fashion Compatibility with Bidirectional LSTMs

    论文:<Learning Fashion Compatibility with Bidirectional LSTMs> 论文地址:https://arxiv.org/abs/1707.0 ...

  3. Deep Learning论文笔记之(八)Deep Learning最新综述

    Deep Learning论文笔记之(八)Deep Learning最新综述 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文,但老感觉看完 ...

  4. 论文笔记:Mastering the game of Go with deep neural networks and tree search

    Mastering the game of Go with deep neural networks and tree search Nature 2015  这是本人论文笔记系列第二篇 Nature ...

  5. 论文笔记:CNN经典结构1(AlexNet,ZFNet,OverFeat,VGG,GoogleNet,ResNet)

    前言 本文主要介绍2012-2015年的一些经典CNN结构,从AlexNet,ZFNet,OverFeat到VGG,GoogleNetv1-v4,ResNetv1-v2. 在论文笔记:CNN经典结构2 ...

  6. AI理论学习笔记(一):深度学习的前世今生

    AI理论学习笔记(一):深度学习的前世今生 大家还记得以深度学习技术为基础的电脑程序AlphaGo吗?这是人类历史中在某种意义的第一次机器打败人类的例子,其最大的魅力就是深度学习(Deep Learn ...

  7. Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现(转)

    Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现 zouxy09@qq.com http://blog.csdn.net/zouxy09          自己平时看了一些论文, ...

  8. 论文笔记之:Visual Tracking with Fully Convolutional Networks

    论文笔记之:Visual Tracking with Fully Convolutional Networks ICCV 2015  CUHK 本文利用 FCN 来做跟踪问题,但开篇就提到并非将其看做 ...

  9. Twitter 新一代流处理利器——Heron 论文笔记之Heron架构

    Twitter 新一代流处理利器--Heron 论文笔记之Heron架构 标签(空格分隔): Streaming-process realtime-process Heron Architecture ...

随机推荐

  1. python入门之垃圾回收机制

    目录 一 引入 二.什么是垃圾回收机制? 三.为什么要用垃圾回收机制? 四.垃圾回收机制原理分析 4.1.什么是引用计数? 4.2.引用计数扩展阅读 4.2.1 标记-清除 4.2.2 分代回收 一 ...

  2. Docker下打包FastDFS镜像以及上传遇到的问题

    官方地址:https://github.com/happyfish100/fastdfs 一.先下载个包,然后解压(自己找个目录下载即可) [root@localhost soft]# wget ht ...

  3. Django-07-Model操作

    一.数据库的配置 1. 数据库支持 django默认支持sqlite.mysql.oracle.postgresql数据库  <1> sqlite django默认使用sqlite的数据库 ...

  4. C 语言 基础篇

    1.机器语言 2.汇编语言 3.高级语言:C.C++.Java(基于虚拟机) C语言开发:Unix,Linux,Mac OS,iOS,Android,Windows,Ubuntu 开发环境:visua ...

  5. MySQL学习一:建表

    目标:创建三张表,学生表student(sid,name,gender), 课程表course(cid,name), 分数mark(mid, sid, cid, gender); 要求sid, cid ...

  6. Go操作ini文件

    除了采用json,yaml等格式之外,常用的配置文件还有ini格式的. cfg, err := ini.Load(fyPath + "\\ServerSystem.ini") // ...

  7. Python与MogoDB交互

    睡了大半天,终于有时间整理下拖欠的MongoDB的封装啦. 首先我们先进行下数据库的连接: conn = MongoClient('localhost',27017) # 建立连接 result = ...

  8. Spring boot java.lang.NoClassDefFoundError: org/springframework/boot/bind/RelaxedPropertyResolver

    Spring boot 2.0.3 RELEASE 配置报错 java.lang.NoClassDefFoundError: org/springframework/boot/bind/Relaxed ...

  9. jedis异常:Could not get a resource from the pool

    前几天公司后端系统出现了故障,导致app多个功能无法使用,查看日志,发现日志出现较多的redis.clients.jedis.exceptions.JedisConnectionException: ...

  10. ubuntu16.04 打开chrome弹出“Enter password to unlock your login keyring”解决方法

    问题如图 输入开机密码发现验证失败. 解决 命令: find ~/ -name login.keyring 查找相关文件. 命令: sudo rm -rf /home/la/.local/share/ ...