1.文章原文地址

SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation

2.文章摘要

语义分割具有非常广泛的应用,从场景理解、目标相互关系推断到自动驾驶。早期依赖于低水平视觉线索的方法已经快速的被流行的机器学习算法所取代。特别是最近的深度学习在手写数字识别、语音、图像中的分类和目标检测上取得巨大成功。如今有一个活跃的领域是语义分割(对每个像素进行归类)。然而,最近有一些方法直接采用了为图像分类而设计的网络结构来进行语义分割任务。虽然结果十分鼓舞人心,但还是比较粗糙。这首要的原因是最大池化和下采样减小了特征图的分辨率。我们设计SegNet的动机来自于分割任务需要将低分辨率的特征图映射到输入的分辨率并进行像素级分类,这个映射必须产生对准确边界定位有用的特征。

3.网络结构

4.Pytorch实现

 import torch.nn as nn
import torch class conv2DBatchNormRelu(nn.Module):
def __init__(self,in_channels,out_channels,kernel_size,stride,padding,
bias=True,dilation=1,is_batchnorm=True):
super(conv2DBatchNormRelu,self).__init__()
if is_batchnorm:
self.cbr_unit=nn.Sequential(
nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,stride=stride,padding=padding,
bias=bias,dilation=dilation),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
else:
self.cbr_unit=nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
bias=bias, dilation=dilation),
nn.ReLU(inplace=True)
) def forward(self,inputs):
outputs=self.cbr_unit(inputs)
return outputs class segnetDown2(nn.Module):
def __init__(self,in_channels,out_channels):
super(segnetDown2,self).__init__()
self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.maxpool_with_argmax=nn.MaxPool2d(kernel_size=2,stride=2,return_indices=True) def forward(self,inputs):
outputs=self.conv1(inputs)
outputs=self.conv2(outputs)
unpooled_shape=outputs.size()
outputs,indices=self.maxpool_with_argmax(outputs)
return outputs,indices,unpooled_shape class segnetDown3(nn.Module):
def __init__(self,in_channels,out_channels):
super(segnetDown3,self).__init__()
self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv3=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.maxpool_with_argmax=nn.MaxPool2d(kernel_size=2,stride=2,return_indices=True) def forward(self,inputs):
outputs=self.conv1(inputs)
outputs=self.conv2(outputs)
outputs=self.conv3(outputs)
unpooled_shape=outputs.size()
outputs,indices=self.maxpool_with_argmax(outputs)
return outputs,indices,unpooled_shape class segnetUp2(nn.Module):
def __init__(self,in_channels,out_channels):
super(segnetUp2,self).__init__()
self.unpool=nn.MaxUnpool2d(2,2)
self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1) def forward(self,inputs,indices,output_shape):
outputs=self.unpool(inputs,indices=indices,output_size=output_shape)
outputs=self.conv1(outputs)
outputs=self.conv2(outputs)
return outputs class segnetUp3(nn.Module):
def __init__(self,in_channels,out_channels):
super(segnetUp3,self).__init__()
self.unpool=nn.MaxUnpool2d(2,2)
self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv3=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1) def forward(self,inputs,indices,output_shape):
outputs=self.unpool(inputs,indices=indices,output_size=output_shape)
outputs=self.conv1(outputs)
outputs=self.conv2(outputs)
outputs=self.conv3(outputs)
return outputs class segnet(nn.Module):
def __init__(self,in_channels=3,num_classes=21):
super(segnet,self).__init__()
self.down1=segnetDown2(in_channels=in_channels,out_channels=64)
self.down2=segnetDown2(64,128)
self.down3=segnetDown3(128,256)
self.down4=segnetDown3(256,512)
self.down5=segnetDown3(512,512) self.up5=segnetUp3(512,512)
self.up4=segnetUp3(512,256)
self.up3=segnetUp3(256,128)
self.up2=segnetUp2(128,64)
self.up1=segnetUp2(64,64)
self.finconv=conv2DBatchNormRelu(64,num_classes,3,1,1) def forward(self,inputs):
down1,indices_1,unpool_shape1=self.down1(inputs)
down2,indices_2,unpool_shape2=self.down2(down1)
down3,indices_3,unpool_shape3=self.down3(down2)
down4,indices_4,unpool_shape4=self.down4(down3)
down5,indices_5,unpool_shape5=self.down5(down4) up5=self.up5(down5,indices=indices_5,output_shape=unpool_shape5)
up4=self.up4(up5,indices=indices_4,output_shape=unpool_shape4)
up3=self.up3(up4,indices=indices_3,output_shape=unpool_shape3)
up2=self.up2(up3,indices=indices_2,output_shape=unpool_shape2)
up1=self.up1(up2,indices=indices_1,output_shape=unpool_shape1)
outputs=self.finconv(up1) return outputs if __name__=="__main__":
inputs=torch.ones(1,3,224,224)
model=segnet()
print(model(inputs).size())
print(model)

参考

https://github.com/meetshah1995/pytorch-semseg

SegNet网络的Pytorch实现的更多相关文章

  1. 群等变网络的pytorch实现

    CNN对于旋转不具有等变性,对于平移有等变性,data augmentation的提出就是为了解决这个问题,但是data augmentation需要很大的模型容量,更多的迭代次数才能够在训练数据集合 ...

  2. U-Net网络的Pytorch实现

    1.文章原文地址 U-Net: Convolutional Networks for Biomedical Image Segmentation 2.文章摘要 普遍认为成功训练深度神经网络需要大量标注 ...

  3. ResNet网络的Pytorch实现

    1.文章原文地址 Deep Residual Learning for  Image Recognition 2.文章摘要 神经网络的层次越深越难训练.我们提出了一个残差学习框架来简化网络的训练,这些 ...

  4. GoogLeNet网络的Pytorch实现

    1.文章原文地址 Going deeper with convolutions 2.文章摘要 我们提出了一种代号为Inception的深度卷积神经网络,它在ILSVRC2014的分类和检测任务上都取得 ...

  5. AlexNet网络的Pytorch实现

    1.文章原文地址 ImageNet Classification with Deep Convolutional Neural Networks 2.文章摘要 我们训练了一个大型的深度卷积神经网络用于 ...

  6. VGG网络的Pytorch实现

    1.文章原文地址 Very Deep Convolutional Networks for Large-Scale Image Recognition 2.文章摘要 在这项工作中,我们研究了在大规模的 ...

  7. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  8. pytorch预训练

    Pytorch预训练模型以及修改 pytorch中自带几种常用的深度学习网络预训练模型,torchvision.models包中包含alexnet.densenet.inception.resnet. ...

  9. PyTorch使用总览

    PyTorch使用总览 https://blog.csdn.net/u014380165/article/details/79222243 深度学习框架训练模型时的代码主要包含数据读取.网络构建和其他 ...

随机推荐

  1. [LeetCode] 849. Maximize Distance to Closest Person 最大化最近人的距离

    In a row of seats, 1 represents a person sitting in that seat, and 0 represents that the seat is emp ...

  2. web端自动化——Python的smtplib发送电子邮件

    SMTP (Simple Mail Transfer Protocol)是简单邮件传输协议,它是一组用于由源地址到目的地址传送邮件的规则,由它来控制信件的中转方式. Python的smtplib模块提 ...

  3. 老司机带大家领略MySQL中的乐观锁和悲观锁

    原文地址:https://cloud.tencent.com/developer/news/227982 为什么需要锁 在并发环境下,如果多个客户端访问同一条数据,此时就会产生数据不一致的问题,如何解 ...

  4. 最新 上海轻轻java校招面经 (含整理过的面试题大全)

    从6月到10月,经过4个月努力和坚持,自己有幸拿到了网易雷火.京东.去哪儿.上海轻轻等10家互联网公司的校招Offer,因为某些自身原因最终选择了上海轻轻.6.7月主要是做系统复习.项目复盘.Leet ...

  5. HTTP权威指南-报文与状态码

    所有的报文都向下流动 报文流向 报文组成 HTTP方法 状态码 GET示例 HEAD示例 100~199 信息性状态码 200~299 成功状态码 300~399重定向状态码 400~499 客户端错 ...

  6. 多线程(8) — ThreadLocal

    ThreadLocal是一个线程的局部变量,也就是只有当前线程可以访问,是线程安全的.为每一个线程分配不同的对象,需要在应用层面保证ThreadLocal只起到简单的容器作用. ThreadLocal ...

  7. (5)Spring Boot web开发 --- Restful CRUD

    文章目录 `@RestController` vs `@Controller` 默认访问首页 设置项目名 国际化 登陆 & 拦截 Restful 风格 @RestController vs @ ...

  8. golang日志库之glog使用问题总结

    1. 日志默认输出路径为临时路径,可通过执行命令时带上 -log_dir="路径",指定输出,但路径必须已存在,源码如下,日志文件会生成两个 .INFO等后缀是符号链接文件,另一个 ...

  9. DBA职责和任务

    DBA守则在对生产环境进行修改前,一定要进行备份,一定要在测试环境进行测试,否则不要进行轻易的更改一次尽量只做一件事,不要受环境影响 DBA的十大任务1.了解和掌握硬件环境2.规划数据库3.安装数据库 ...

  10. java连接腾讯云上的redis

    目录 腾讯云上的配置 redis连接单机和集群 依赖 pom.xml redis参数的配置文件 遗留问题 腾讯云上的配置 在安全组上打开相关的端口即可 "来源" 就是你的目标服务器 ...