U-Net网络的Pytorch实现
1.文章原文地址
U-Net: Convolutional Networks for Biomedical Image Segmentation
2.文章摘要
普遍认为成功训练深度神经网络需要大量标注的训练数据。在本文中,我们提出了一个网络结构,以及使用数据增强的策略来训练网络使得可用的标注样本更加有效的被使用。这个网络是由一个捕捉上下文信息的收缩部分和与之相对称的放大部分,后者能够准确的定位。我们的结果展示了这个网络可以进行端到端的训练,使用非常少的数据就可以达到非常好的结果,并且超过了当前的最佳方法(滑动窗网络)在ISBII挑战赛上电子显微镜下神经结构的分割的结果。利用透射光显微镜图像使用相同网络进行训练,我们大幅度的赢得了2015年的ISBI细胞追踪挑战赛。而且,这个网络非常快,在一个当前的GPU上,分割一个512x512的图像所花费的时间少于一秒。完整的代码以及训练好的网络可见(基于Caffe)http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.
3.网络结构
4.Pytorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary class unetConv2(nn.Module):
def __init__(self,in_size,out_size,is_batchnorm):
super(unetConv2,self).__init__() if is_batchnorm:
self.conv1=nn.Sequential(
nn.Conv2d(in_size,out_size,kernel_size=3,stride=1,padding=0),
nn.BatchNorm2d(out_size),
nn.ReLU(inplace=True),
)
self.conv2=nn.Sequential(
nn.Conv2d(out_size,out_size,kernel_size=3,stride=1,padding=0),
nn.BatchNorm2d(out_size),
nn.ReLU(inplace=True),
)
else:
self.conv1=nn.Sequential(
nn.Conv2d(in_size,out_size,kernel_size=3,stride=1,padding=0),
nn.ReLU(inplace=True),
)
self.conv2=nn.Sequential(
nn.Conv2d(out_size,out_size,kernel_size=3,stride=1,padding=0),
nn.ReLU(inplace=True)
)
def forward(self, inputs):
outputs=self.conv1(inputs)
outputs=self.conv2(outputs) return outputs class unetUp(nn.Module):
def __init__(self,in_size,out_size,is_deconv):
super(unetUp,self).__init__()
self.conv=unetConv2(in_size,out_size,False)
if is_deconv:
self.up=nn.ConvTranspose2d(in_size,out_size,kernel_size=2,stride=2)
else:
self.up=nn.UpsamplingBilinear2d(scale_factor=2) def forward(self, inputs1,inputs2):
outputs2=self.up(inputs2)
offset=outputs2.size()[2]-inputs1.size()[2]
padding=2*[offset//2,offset//2]
outputs1=F.pad(inputs1,padding) #padding is negative, size become smaller return self.conv(torch.cat([outputs1,outputs2],1)) class unet(nn.Module):
def __init__(self,feature_scale=4,n_classes=21,is_deconv=True,in_channels=3,is_batchnorm=True):
super(unet,self).__init__()
self.is_deconv=is_deconv
self.in_channels=in_channels
self.is_batchnorm=is_batchnorm
self.feature_scale=feature_scale filters=[64,128,256,512,1024]
filters=[int(x/self.feature_scale) for x in filters] #downsample
self.conv1=unetConv2(self.in_channels,filters[0],self.is_batchnorm)
self.maxpool1=nn.MaxPool2d(kernel_size=2) self.conv2=unetConv2(filters[0],filters[1],self.is_batchnorm)
self.maxpool2=nn.MaxPool2d(kernel_size=2) self.conv3=unetConv2(filters[1],filters[2],self.is_batchnorm)
self.maxpool3=nn.MaxPool2d(kernel_size=2) self.conv4=unetConv2(filters[2],filters[3],self.is_batchnorm)
self.maxpool4=nn.MaxPool2d(kernel_size=2) self.center=unetConv2(filters[3],filters[4],self.is_batchnorm) #umsampling
self.up_concat4=unetUp(filters[4],filters[3],self.is_deconv)
self.up_concat3=unetUp(filters[3],filters[2],self.is_deconv)
self.up_concat2=unetUp(filters[2],filters[1],self.is_deconv)
self.up_concat1=unetUp(filters[1],filters[0],self.is_deconv) #final conv (without and concat)
self.final=nn.Conv2d(filters[0],n_classes,kernel_size=1) def forward(self, inputs):
conv1=self.conv1(inputs)
maxpool1=self.maxpool1(conv1) conv2=self.conv2(maxpool1)
maxpool2=self.maxpool2(conv2) conv3=self.conv3(maxpool2)
maxpool3=self.maxpool3(conv3) conv4=self.conv4(maxpool3)
maxpool4=self.maxpool4(conv4) center=self.center(maxpool4)
up4=self.up_concat4(conv4,center)
up3=self.up_concat3(conv3,up4)
up2=self.up_concat2(conv2,up3)
up1=self.up_concat1(conv1,up2) final=self.final(up1) return final if __name__=="__main__":
model=unet(feature_scale=1)
print(summary(model,(3,572,572)))
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 570, 570] 1,792
BatchNorm2d-2 [-1, 64, 570, 570] 128
ReLU-3 [-1, 64, 570, 570] 0
Conv2d-4 [-1, 64, 568, 568] 36,928
BatchNorm2d-5 [-1, 64, 568, 568] 128
ReLU-6 [-1, 64, 568, 568] 0
unetConv2-7 [-1, 64, 568, 568] 0
MaxPool2d-8 [-1, 64, 284, 284] 0
Conv2d-9 [-1, 128, 282, 282] 73,856
BatchNorm2d-10 [-1, 128, 282, 282] 256
ReLU-11 [-1, 128, 282, 282] 0
Conv2d-12 [-1, 128, 280, 280] 147,584
BatchNorm2d-13 [-1, 128, 280, 280] 256
ReLU-14 [-1, 128, 280, 280] 0
unetConv2-15 [-1, 128, 280, 280] 0
MaxPool2d-16 [-1, 128, 140, 140] 0
Conv2d-17 [-1, 256, 138, 138] 295,168
BatchNorm2d-18 [-1, 256, 138, 138] 512
ReLU-19 [-1, 256, 138, 138] 0
Conv2d-20 [-1, 256, 136, 136] 590,080
BatchNorm2d-21 [-1, 256, 136, 136] 512
ReLU-22 [-1, 256, 136, 136] 0
unetConv2-23 [-1, 256, 136, 136] 0
MaxPool2d-24 [-1, 256, 68, 68] 0
Conv2d-25 [-1, 512, 66, 66] 1,180,160
BatchNorm2d-26 [-1, 512, 66, 66] 1,024
ReLU-27 [-1, 512, 66, 66] 0
Conv2d-28 [-1, 512, 64, 64] 2,359,808
BatchNorm2d-29 [-1, 512, 64, 64] 1,024
ReLU-30 [-1, 512, 64, 64] 0
unetConv2-31 [-1, 512, 64, 64] 0
MaxPool2d-32 [-1, 512, 32, 32] 0
Conv2d-33 [-1, 1024, 30, 30] 4,719,616
BatchNorm2d-34 [-1, 1024, 30, 30] 2,048
ReLU-35 [-1, 1024, 30, 30] 0
Conv2d-36 [-1, 1024, 28, 28] 9,438,208
BatchNorm2d-37 [-1, 1024, 28, 28] 2,048
ReLU-38 [-1, 1024, 28, 28] 0
unetConv2-39 [-1, 1024, 28, 28] 0
ConvTranspose2d-40 [-1, 512, 56, 56] 2,097,664
Conv2d-41 [-1, 512, 54, 54] 4,719,104
ReLU-42 [-1, 512, 54, 54] 0
Conv2d-43 [-1, 512, 52, 52] 2,359,808
ReLU-44 [-1, 512, 52, 52] 0
unetConv2-45 [-1, 512, 52, 52] 0
unetUp-46 [-1, 512, 52, 52] 0
ConvTranspose2d-47 [-1, 256, 104, 104] 524,544
Conv2d-48 [-1, 256, 102, 102] 1,179,904
ReLU-49 [-1, 256, 102, 102] 0
Conv2d-50 [-1, 256, 100, 100] 590,080
ReLU-51 [-1, 256, 100, 100] 0
unetConv2-52 [-1, 256, 100, 100] 0
unetUp-53 [-1, 256, 100, 100] 0
ConvTranspose2d-54 [-1, 128, 200, 200] 131,200
Conv2d-55 [-1, 128, 198, 198] 295,040
ReLU-56 [-1, 128, 198, 198] 0
Conv2d-57 [-1, 128, 196, 196] 147,584
ReLU-58 [-1, 128, 196, 196] 0
unetConv2-59 [-1, 128, 196, 196] 0
unetUp-60 [-1, 128, 196, 196] 0
ConvTranspose2d-61 [-1, 64, 392, 392] 32,832
Conv2d-62 [-1, 64, 390, 390] 73,792
ReLU-63 [-1, 64, 390, 390] 0
Conv2d-64 [-1, 64, 388, 388] 36,928
ReLU-65 [-1, 64, 388, 388] 0
unetConv2-66 [-1, 64, 388, 388] 0
unetUp-67 [-1, 64, 388, 388] 0
Conv2d-68 [-1, 21, 388, 388] 1,365
================================================================
Total params: 31,040,981
Trainable params: 31,040,981
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.74
Forward/backward pass size (MB): 3158.15
Params size (MB): 118.41
Estimated Total Size (MB): 3280.31
参考
https://github.com/meetshah1995/pytorch-semseg
U-Net网络的Pytorch实现的更多相关文章
- 群等变网络的pytorch实现
CNN对于旋转不具有等变性,对于平移有等变性,data augmentation的提出就是为了解决这个问题,但是data augmentation需要很大的模型容量,更多的迭代次数才能够在训练数据集合 ...
- ResNet网络的Pytorch实现
1.文章原文地址 Deep Residual Learning for Image Recognition 2.文章摘要 神经网络的层次越深越难训练.我们提出了一个残差学习框架来简化网络的训练,这些 ...
- GoogLeNet网络的Pytorch实现
1.文章原文地址 Going deeper with convolutions 2.文章摘要 我们提出了一种代号为Inception的深度卷积神经网络,它在ILSVRC2014的分类和检测任务上都取得 ...
- AlexNet网络的Pytorch实现
1.文章原文地址 ImageNet Classification with Deep Convolutional Neural Networks 2.文章摘要 我们训练了一个大型的深度卷积神经网络用于 ...
- VGG网络的Pytorch实现
1.文章原文地址 Very Deep Convolutional Networks for Large-Scale Image Recognition 2.文章摘要 在这项工作中,我们研究了在大规模的 ...
- SegNet网络的Pytorch实现
1.文章原文地址 SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation 2.文章摘要 语义分 ...
- 【转载】PyTorch系列 (二):pytorch数据读取
原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...
- pytorch预训练
Pytorch预训练模型以及修改 pytorch中自带几种常用的深度学习网络预训练模型,torchvision.models包中包含alexnet.densenet.inception.resnet. ...
- PyTorch使用总览
PyTorch使用总览 https://blog.csdn.net/u014380165/article/details/79222243 深度学习框架训练模型时的代码主要包含数据读取.网络构建和其他 ...
随机推荐
- [CareerCup] 9.8 Represent N Cents 组成N分钱
9.8 Given an infinite number of quarters (25 cents), dimes (10 cents), nickels (5 cents) and pennies ...
- Socket测试工具(客户端、服务端)
Socket是什么? SOCKET用于在两个基于TCP/IP协议的应用程序之间相互通信.最早出现在UNIX系统中,是UNIX系统主要的信息传递方式.在WINDOWS系统中,SOCKET称为WINSOC ...
- Java学习,从入门到放弃(二)Linux配置mvn
其实网上的教程很多,随便拿一个,比如:https://www.cnblogs.com/chuijingjing/p/10430649.html 但在实践过程中,发现可能需要将JAVA_HOME也加到 ...
- CF1266C Diverse Matrix
思路:构造题. 实现: #include <bits/stdc++.h> using namespace std; ][]; int main() { int r, c; while (c ...
- c# 基础类型探索
一.前言 本章节主要是探索 C# 的基本类型,一直以来我本人常用都是 int .double.bool.decimal.string 这五个类型,其对其它类型没有认真了解过.只是以前在学习的时候背了些 ...
- JS Maximum call stack size exceeded
一.问题描述 Maximum call stack size exceeded 翻译为:超过最大调用堆栈大小 二.效果截图 三.问题解决方案 出现该问题,说明程序出现了死循环了.所以要去检查出错的程 ...
- Deepin 15.11 install nvidia dirver[mei you an zhuang shu ru fa]
1.firstly, exec: sudo vim /etc/modprobe.d/blacklist-nouveau.conf[create], and input [blacklist nouve ...
- LeetCode 459. 重复的子字符串(Repeated Substring Pattern)
459. 重复的子字符串 459. Repeated Substring Pattern 题目描述 给定一个非空的字符串,判断它是否可以由它的一个子串重复多次构成.给定的字符串只含有小写英文字母,并且 ...
- [转帖]Linux下主机间文件传输命令
Linux下主机间文件传输命令 https://yq.aliyun.com/articles/53631?spm=a2c4e.11155435.0.0.580ce8ef4Q9uzs SCP命令: ...
- Jenkins+maven+gitlab自动化部署之Jenkins系统管理配置(四)
一.Jenkins全局工具配置 在jenkins首页依次进入系统管理>>全局工具配置: 1) jdk.git.maven配置 指定其在服务器中的目录位置 二.插件管理 1)依次点开系统管理 ...