U-Net网络的Pytorch实现
1.文章原文地址
U-Net: Convolutional Networks for Biomedical Image Segmentation
2.文章摘要
普遍认为成功训练深度神经网络需要大量标注的训练数据。在本文中,我们提出了一个网络结构,以及使用数据增强的策略来训练网络使得可用的标注样本更加有效的被使用。这个网络是由一个捕捉上下文信息的收缩部分和与之相对称的放大部分,后者能够准确的定位。我们的结果展示了这个网络可以进行端到端的训练,使用非常少的数据就可以达到非常好的结果,并且超过了当前的最佳方法(滑动窗网络)在ISBII挑战赛上电子显微镜下神经结构的分割的结果。利用透射光显微镜图像使用相同网络进行训练,我们大幅度的赢得了2015年的ISBI细胞追踪挑战赛。而且,这个网络非常快,在一个当前的GPU上,分割一个512x512的图像所花费的时间少于一秒。完整的代码以及训练好的网络可见(基于Caffe)http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.
3.网络结构
4.Pytorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary class unetConv2(nn.Module):
def __init__(self,in_size,out_size,is_batchnorm):
super(unetConv2,self).__init__() if is_batchnorm:
self.conv1=nn.Sequential(
nn.Conv2d(in_size,out_size,kernel_size=3,stride=1,padding=0),
nn.BatchNorm2d(out_size),
nn.ReLU(inplace=True),
)
self.conv2=nn.Sequential(
nn.Conv2d(out_size,out_size,kernel_size=3,stride=1,padding=0),
nn.BatchNorm2d(out_size),
nn.ReLU(inplace=True),
)
else:
self.conv1=nn.Sequential(
nn.Conv2d(in_size,out_size,kernel_size=3,stride=1,padding=0),
nn.ReLU(inplace=True),
)
self.conv2=nn.Sequential(
nn.Conv2d(out_size,out_size,kernel_size=3,stride=1,padding=0),
nn.ReLU(inplace=True)
)
def forward(self, inputs):
outputs=self.conv1(inputs)
outputs=self.conv2(outputs) return outputs class unetUp(nn.Module):
def __init__(self,in_size,out_size,is_deconv):
super(unetUp,self).__init__()
self.conv=unetConv2(in_size,out_size,False)
if is_deconv:
self.up=nn.ConvTranspose2d(in_size,out_size,kernel_size=2,stride=2)
else:
self.up=nn.UpsamplingBilinear2d(scale_factor=2) def forward(self, inputs1,inputs2):
outputs2=self.up(inputs2)
offset=outputs2.size()[2]-inputs1.size()[2]
padding=2*[offset//2,offset//2]
outputs1=F.pad(inputs1,padding) #padding is negative, size become smaller return self.conv(torch.cat([outputs1,outputs2],1)) class unet(nn.Module):
def __init__(self,feature_scale=4,n_classes=21,is_deconv=True,in_channels=3,is_batchnorm=True):
super(unet,self).__init__()
self.is_deconv=is_deconv
self.in_channels=in_channels
self.is_batchnorm=is_batchnorm
self.feature_scale=feature_scale filters=[64,128,256,512,1024]
filters=[int(x/self.feature_scale) for x in filters] #downsample
self.conv1=unetConv2(self.in_channels,filters[0],self.is_batchnorm)
self.maxpool1=nn.MaxPool2d(kernel_size=2) self.conv2=unetConv2(filters[0],filters[1],self.is_batchnorm)
self.maxpool2=nn.MaxPool2d(kernel_size=2) self.conv3=unetConv2(filters[1],filters[2],self.is_batchnorm)
self.maxpool3=nn.MaxPool2d(kernel_size=2) self.conv4=unetConv2(filters[2],filters[3],self.is_batchnorm)
self.maxpool4=nn.MaxPool2d(kernel_size=2) self.center=unetConv2(filters[3],filters[4],self.is_batchnorm) #umsampling
self.up_concat4=unetUp(filters[4],filters[3],self.is_deconv)
self.up_concat3=unetUp(filters[3],filters[2],self.is_deconv)
self.up_concat2=unetUp(filters[2],filters[1],self.is_deconv)
self.up_concat1=unetUp(filters[1],filters[0],self.is_deconv) #final conv (without and concat)
self.final=nn.Conv2d(filters[0],n_classes,kernel_size=1) def forward(self, inputs):
conv1=self.conv1(inputs)
maxpool1=self.maxpool1(conv1) conv2=self.conv2(maxpool1)
maxpool2=self.maxpool2(conv2) conv3=self.conv3(maxpool2)
maxpool3=self.maxpool3(conv3) conv4=self.conv4(maxpool3)
maxpool4=self.maxpool4(conv4) center=self.center(maxpool4)
up4=self.up_concat4(conv4,center)
up3=self.up_concat3(conv3,up4)
up2=self.up_concat2(conv2,up3)
up1=self.up_concat1(conv1,up2) final=self.final(up1) return final if __name__=="__main__":
model=unet(feature_scale=1)
print(summary(model,(3,572,572)))
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 570, 570] 1,792
BatchNorm2d-2 [-1, 64, 570, 570] 128
ReLU-3 [-1, 64, 570, 570] 0
Conv2d-4 [-1, 64, 568, 568] 36,928
BatchNorm2d-5 [-1, 64, 568, 568] 128
ReLU-6 [-1, 64, 568, 568] 0
unetConv2-7 [-1, 64, 568, 568] 0
MaxPool2d-8 [-1, 64, 284, 284] 0
Conv2d-9 [-1, 128, 282, 282] 73,856
BatchNorm2d-10 [-1, 128, 282, 282] 256
ReLU-11 [-1, 128, 282, 282] 0
Conv2d-12 [-1, 128, 280, 280] 147,584
BatchNorm2d-13 [-1, 128, 280, 280] 256
ReLU-14 [-1, 128, 280, 280] 0
unetConv2-15 [-1, 128, 280, 280] 0
MaxPool2d-16 [-1, 128, 140, 140] 0
Conv2d-17 [-1, 256, 138, 138] 295,168
BatchNorm2d-18 [-1, 256, 138, 138] 512
ReLU-19 [-1, 256, 138, 138] 0
Conv2d-20 [-1, 256, 136, 136] 590,080
BatchNorm2d-21 [-1, 256, 136, 136] 512
ReLU-22 [-1, 256, 136, 136] 0
unetConv2-23 [-1, 256, 136, 136] 0
MaxPool2d-24 [-1, 256, 68, 68] 0
Conv2d-25 [-1, 512, 66, 66] 1,180,160
BatchNorm2d-26 [-1, 512, 66, 66] 1,024
ReLU-27 [-1, 512, 66, 66] 0
Conv2d-28 [-1, 512, 64, 64] 2,359,808
BatchNorm2d-29 [-1, 512, 64, 64] 1,024
ReLU-30 [-1, 512, 64, 64] 0
unetConv2-31 [-1, 512, 64, 64] 0
MaxPool2d-32 [-1, 512, 32, 32] 0
Conv2d-33 [-1, 1024, 30, 30] 4,719,616
BatchNorm2d-34 [-1, 1024, 30, 30] 2,048
ReLU-35 [-1, 1024, 30, 30] 0
Conv2d-36 [-1, 1024, 28, 28] 9,438,208
BatchNorm2d-37 [-1, 1024, 28, 28] 2,048
ReLU-38 [-1, 1024, 28, 28] 0
unetConv2-39 [-1, 1024, 28, 28] 0
ConvTranspose2d-40 [-1, 512, 56, 56] 2,097,664
Conv2d-41 [-1, 512, 54, 54] 4,719,104
ReLU-42 [-1, 512, 54, 54] 0
Conv2d-43 [-1, 512, 52, 52] 2,359,808
ReLU-44 [-1, 512, 52, 52] 0
unetConv2-45 [-1, 512, 52, 52] 0
unetUp-46 [-1, 512, 52, 52] 0
ConvTranspose2d-47 [-1, 256, 104, 104] 524,544
Conv2d-48 [-1, 256, 102, 102] 1,179,904
ReLU-49 [-1, 256, 102, 102] 0
Conv2d-50 [-1, 256, 100, 100] 590,080
ReLU-51 [-1, 256, 100, 100] 0
unetConv2-52 [-1, 256, 100, 100] 0
unetUp-53 [-1, 256, 100, 100] 0
ConvTranspose2d-54 [-1, 128, 200, 200] 131,200
Conv2d-55 [-1, 128, 198, 198] 295,040
ReLU-56 [-1, 128, 198, 198] 0
Conv2d-57 [-1, 128, 196, 196] 147,584
ReLU-58 [-1, 128, 196, 196] 0
unetConv2-59 [-1, 128, 196, 196] 0
unetUp-60 [-1, 128, 196, 196] 0
ConvTranspose2d-61 [-1, 64, 392, 392] 32,832
Conv2d-62 [-1, 64, 390, 390] 73,792
ReLU-63 [-1, 64, 390, 390] 0
Conv2d-64 [-1, 64, 388, 388] 36,928
ReLU-65 [-1, 64, 388, 388] 0
unetConv2-66 [-1, 64, 388, 388] 0
unetUp-67 [-1, 64, 388, 388] 0
Conv2d-68 [-1, 21, 388, 388] 1,365
================================================================
Total params: 31,040,981
Trainable params: 31,040,981
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.74
Forward/backward pass size (MB): 3158.15
Params size (MB): 118.41
Estimated Total Size (MB): 3280.31
参考
https://github.com/meetshah1995/pytorch-semseg
U-Net网络的Pytorch实现的更多相关文章
- 群等变网络的pytorch实现
CNN对于旋转不具有等变性,对于平移有等变性,data augmentation的提出就是为了解决这个问题,但是data augmentation需要很大的模型容量,更多的迭代次数才能够在训练数据集合 ...
- ResNet网络的Pytorch实现
1.文章原文地址 Deep Residual Learning for Image Recognition 2.文章摘要 神经网络的层次越深越难训练.我们提出了一个残差学习框架来简化网络的训练,这些 ...
- GoogLeNet网络的Pytorch实现
1.文章原文地址 Going deeper with convolutions 2.文章摘要 我们提出了一种代号为Inception的深度卷积神经网络,它在ILSVRC2014的分类和检测任务上都取得 ...
- AlexNet网络的Pytorch实现
1.文章原文地址 ImageNet Classification with Deep Convolutional Neural Networks 2.文章摘要 我们训练了一个大型的深度卷积神经网络用于 ...
- VGG网络的Pytorch实现
1.文章原文地址 Very Deep Convolutional Networks for Large-Scale Image Recognition 2.文章摘要 在这项工作中,我们研究了在大规模的 ...
- SegNet网络的Pytorch实现
1.文章原文地址 SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation 2.文章摘要 语义分 ...
- 【转载】PyTorch系列 (二):pytorch数据读取
原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...
- pytorch预训练
Pytorch预训练模型以及修改 pytorch中自带几种常用的深度学习网络预训练模型,torchvision.models包中包含alexnet.densenet.inception.resnet. ...
- PyTorch使用总览
PyTorch使用总览 https://blog.csdn.net/u014380165/article/details/79222243 深度学习框架训练模型时的代码主要包含数据读取.网络构建和其他 ...
随机推荐
- 07点睛Spring MVC4.1-ContentNegotiatingViewResolver
转发地址:https://www.iteye.com/blog/wiselyman-2214965 7.1 ContentNegotiatingViewResolver ContentNegotiat ...
- Flutter 路由传入中文参数报错无法push问题
flutter自带路由传递参数和使用第三方库fluro路由传递参数都可以通过一下方式解决问题 String jsonString = json.encode(mapValue); var jsons ...
- 了解 Selenium 定位方式
※元素定位的重要性:在于查找元素 And 执行元素 定位元素的三种方法 1.定位单个元素:在定位单个元素时,selenium-webdriver 提示了如下一些方法对元素进行定位.在这些定位方式中,优 ...
- 说说Spring XML的头
部分内容截取自(http://blog.csdn.net/zhch152/article/details/8191377,http://iswift.iteye.com/blog/1657537) 在 ...
- iOS 13 DeviceToken获取发生变化
问题描述: iOS 13 通过[deviceToken description]获取到的内容已经变了,这段代码运行在 iOS 13 上已经无法获取到准确的DeviceToken字符串了, NSStri ...
- 洛谷 题解 P2540 【斗地主增强版】
[分析] 暴力搜顺子,贪心出散牌 为什么顺子要暴力? 玩过斗地主的都知道,并不是出越长的顺子越好,如果你有一组手牌,3,4,5,6,7,6,7,8,9,10,你一下把最长的出了去,你会单两张牌,不如出 ...
- 011 Android AutoCompleteTextView(自动完成文本框)的基本使用
1.XML布局 android:completionThreshold="1":这里我们设置了输入一个字就显示提示 (1)主界面布局 <?xml version=" ...
- Python 解LeetCode:394 Decode String
题目描述:按照规定,把字符串解码,具体示例见题目链接 思路:使用两个栈分别存储数字和字母 注意1: 数字是多位的话,要处理后入数字栈 注意2: 出栈时过程中产生的组合后的字符串要继续入字母栈 注意3: ...
- 介绍几款常用的在线API管理工具
在项目开发过程中,总会涉及到接口文档的设计编写,之前使用的都是ms office工具,不够漂亮也不直观,变更频繁的话维护成本也更高,及时性也是大问题.基于这个背景,下面介绍几个常用的API管理工具,方 ...
- NOP法破解
目录 步骤 步骤 OD载入目标软件,汇编窗口右键搜索字符串,发现错误类提示字符串,双击该字符串来到该段代码位置. 向上寻找到跳转到本段错误提示代码的跳转指令,用NOP指令填充跳转指令. 保存修改后的代 ...