越简单越接近本质。

参考资料

U-Net: Convolutional Networks for Biomedical Image Segmentation

Abstract & Introduction

论文中有几个关键词:

  • contracting path 收缩路径;
  • expansive path 扩张路径;
  • precise localization 更精确的位置信息;
  • overlap-tile 边界镜像翻转;
  • random elastic deformations 随机弹性形变;
  • invariance 尺度不变性;
  • touching cells 指距离很近的两个细胞;
  • seamless tilling 无缝拼接;

好了,说完这些关键词我们来看看这篇论文,这篇论文和他的结构一样简单易懂,很能说明问题。

首先,作者主要拿自己的网络和一个基于sliding window的方法做对比,作者先diss了一下这个方法存在以下问题:

Deep neural networks segment neuronal membranes in electron microscopy images (NIPS2012)

  • 非常慢,计算冗余(sliding window的毛病大家都懂);
  • 在位置精确性和特征提取之间存在一个平衡,因为更多的特征意味着更多的max-pooling,则会丢失掉更多位置信息。

作者的输入多层特征的思想是受以下论文启发:

  • Hypercolumns for object segmentation and fine-grained localization (2014)
  • Image segmentation with cascaded hierarchical models and logistic disjunctive normal networks (2013)

这两篇论文指出把多层特征(the features from multiple layers)输入到classifier能够得到更好的特征提取和更好的位置信息(good localization and the use of context are possible at the same time)。

U-Net和其他网络的不同之处在于,上采样(Upsampling)过程中也有很多维特征,让特征流向更高分辨率的卷积层。

由于网络使用的卷积是3x3 unpadded convolutions,所以特征图会缩小,为了让输出的图像和输入图像的大小无缝拼接(seamless tilling),则要用到边界镜像翻转(overlap-tile),具体做法如下图:

Architecture

网络结构

使用3x3 unpadded convolutions,所以特征图会不断缩小,在横向拼接特征的时候,也要对特征图进行裁剪,以保持特征图大小一致。

全部使用ReLU激活函数。

权值初始化使用何恺明的方法:

Surpassing humanlevel performance on imagenet classification

具体做法就是一个标准差满足sqrt(2/N)的高斯分布,其中的N代表一个神经元的输入节点数(例如一个3x3卷积核的输入是64维的话,那么N=9x64=576)

训练

在训练时作者更倾向于更大的图像输入,所以干脆将batch_size设置为1,所以在优化器的使用方面,使用到了带有动量的优化器,并且动量设置的很大(0.99),这么做是为了让以前的样本可以决定当前梯度更新的方向(因为batch_size太小啦,可以理解)。

损失函数就是pixel-wise soft-max + cross_entropy了。

数据增强

随机弹性形变和weight map:

随机弹性形变就是先用3x3的粗网格初始化随机形变,然后从标准差为10pixel的高斯分布中初始化随机位移矢量,再用bicubic双三次插值来计算每个像素的位移。

随机弹性形变的目的是让网络有invariance(尺度不变性)。

那么weight map是为了强制让网络学习touching cells之间的背景,这些位于touching cells之间的背景在损失函数的权重很高,如下图:

weight map的具体计算方式如下:

代码

最后来看看代码吧:https://github.com/milesial/Pytorch-UNet

整体模型:

class UNet(nn.Module):
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()
self.inc = inconv(n_channels, 64)
self.down1 = down(64, 128)
self.down2 = down(128, 256)
self.down3 = down(256, 512)
self.down4 = down(512, 512)
self.up1 = up(1024, 256)
self.up2 = up(512, 128)
self.up3 = up(256, 64)
self.up4 = up(128, 64)
self.outc = outconv(64, n_classes) def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.outc(x)
return F.sigmoid(x)

细节部分:

class double_conv(nn.Module):
'''(conv => BN => ReLU) * 2'''
def __init__(self, in_ch, out_ch):
super(double_conv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
) def forward(self, x):
x = self.conv(x)
return x class inconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(inconv, self).__init__()
self.conv = double_conv(in_ch, out_ch) def forward(self, x):
x = self.conv(x)
return x class down(nn.Module):
def __init__(self, in_ch, out_ch):
super(down, self).__init__()
self.mpconv = nn.Sequential(
nn.MaxPool2d(2),
double_conv(in_ch, out_ch)
) def forward(self, x):
x = self.mpconv(x)
return x class up(nn.Module):
def __init__(self, in_ch, out_ch, bilinear=True):
super(up, self).__init__() # would be a nice idea if the upsampling could be learned too,
# but my machine do not have enough memory to handle all those weights
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) self.conv = double_conv(in_ch, out_ch) def forward(self, x1, x2):
x1 = self.up(x1) # input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3] x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
diffY // 2, diffY - diffY//2)) x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
return x class outconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(outconv, self).__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 1) def forward(self, x):
x = self.conv(x)
return x

训练:

optimizer = optim.SGD(net.parameters(),
lr=lr,
momentum=0.9,
weight_decay=0.0005) criterion = nn.BCELoss()

[AI] 论文笔记 - U-Net 简单而又接近本质的分割网络的更多相关文章

  1. [AI] 论文笔记 - CVPR2018 Super SloMo: High Quality Estimation of Multiple Intermediate Frames for Video Interpolation

    写在前面 原始视频(30fps) 补帧后的视频(240fps) 本文是博主在做实验的过程中使用到的方法,刚好也做为了本科毕设的翻译文章,现在把它搬运到博客上来,因为觉得这篇文章的思路真的不错. 这篇文 ...

  2. 【论文笔记】Learning Fashion Compatibility with Bidirectional LSTMs

    论文:<Learning Fashion Compatibility with Bidirectional LSTMs> 论文地址:https://arxiv.org/abs/1707.0 ...

  3. Deep Learning论文笔记之(八)Deep Learning最新综述

    Deep Learning论文笔记之(八)Deep Learning最新综述 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文,但老感觉看完 ...

  4. 论文笔记:Mastering the game of Go with deep neural networks and tree search

    Mastering the game of Go with deep neural networks and tree search Nature 2015  这是本人论文笔记系列第二篇 Nature ...

  5. 论文笔记:CNN经典结构1(AlexNet,ZFNet,OverFeat,VGG,GoogleNet,ResNet)

    前言 本文主要介绍2012-2015年的一些经典CNN结构,从AlexNet,ZFNet,OverFeat到VGG,GoogleNetv1-v4,ResNetv1-v2. 在论文笔记:CNN经典结构2 ...

  6. AI理论学习笔记(一):深度学习的前世今生

    AI理论学习笔记(一):深度学习的前世今生 大家还记得以深度学习技术为基础的电脑程序AlphaGo吗?这是人类历史中在某种意义的第一次机器打败人类的例子,其最大的魅力就是深度学习(Deep Learn ...

  7. Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现(转)

    Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现 zouxy09@qq.com http://blog.csdn.net/zouxy09          自己平时看了一些论文, ...

  8. 论文笔记之:Visual Tracking with Fully Convolutional Networks

    论文笔记之:Visual Tracking with Fully Convolutional Networks ICCV 2015  CUHK 本文利用 FCN 来做跟踪问题,但开篇就提到并非将其看做 ...

  9. Twitter 新一代流处理利器——Heron 论文笔记之Heron架构

    Twitter 新一代流处理利器--Heron 论文笔记之Heron架构 标签(空格分隔): Streaming-process realtime-process Heron Architecture ...

随机推荐

  1. Tplink路由器怎么设置无线桥接(转载)

    原始文章路径:http://www.pc6.com/video/8548.html .以下为转载内容(少量修改) 第1步 进入副路由器后台,选择网络参数,LAN口设置,把IP地址改成目标网络同一段的I ...

  2. nginx入门系列之应用场景介绍

    目录 HTTP服务器 反向代理服务器 作为一个虚拟主机下多个应用的反向代理 作为多个虚拟主机的反向代理 负载均衡器 简单轮训策略 最小连接数策略 客户端IP哈希策略 服务器权重策略 邮件代理服务器 官 ...

  3. [转]解决 gem install 安装失败

    链接地址:https://blog.csdn.net/weixin_42414461/article/details/85337465

  4. npm 的一些命令

    查看项目中是否安装某个插件 npm [name] -v [name] 为要查询的插件的名字,如果已经安装就会显示该插件的版本号 npm list 查看项目中所有已安装的插件

  5. win7安装 truffle

    1. 最近有个项目需要用到区块链,第一次玩不太熟悉.现在电脑上安装个  truffle,作为一个区块链节点 2. 安装 truffle ,之前需要安装其他几个软件 truffle的安装需要首先装有:n ...

  6. 阿里云k8s事件监控

    事件监控是Kubernetes中的另一种监控方式,可以弥补资源监控在实时性.准确性和场景上的缺欠.Kubernetes的架构设计是基于状态机的,不同的状态之间进行转换则会生成相应的事件,正常的状态之间 ...

  7. 【Docker学习之三】Docker查找拉取镜像、启动容器、容器使用

    环境 docker-ce-19.03.1-3.el7.x86_64 CentOS 7 一.查找.拉取镜像.启动容器1.查找镜像-docker search默认查找Docker Hub上的镜像,举例:D ...

  8. 如何录制高清GIF格式的图片

    如何录制高清GIF格式的图片 工具:傲软GIF 下载地址:https://www.apowersoft.cn/gif 特点:质量高,能够一帧一帧的修改 使用简单.就不说了.自行尝试.这里只是提供一个制 ...

  9. jquery如何生成图片验证码

    jQuery(function($){ /**生成一个随机数**/ function randomNum(min, max) { return Math.floor(Math.random() * ( ...

  10. 22 Oracle数据库基础入门

    1.Oracle数据库的介绍 ORACLE 数据库系统是美国ORACLE 公司(甲骨文)提供的以分布式数据库为核心的一组软件产品,是目前最流行的客户/服务器(CLIENT/SERVER)或 B/S 体 ...