第四周:卷积神经网络 part 3

视频学习

语义分割中的自注意力机制和低秩重建

  • 语义分割(Semantic Segmentation)

    • 概念:语义分割是在像素级别上的分类,属于同一类的像素都要被归为一类,因此语义分割是从像素级别来理解图像的。
    • 思路:
      • 传统方法:

        • TextonForest和基于随机森林分类器等语义分割方法
      • 深度学习方法:
        • Patch classification
        • 全卷积方法(FCN)
        • encoder-decoder架构
        • 空洞卷积(Dilated/Atrous)
        • 条件随机场
    • 几种架构:

ps:空洞卷积、池化目的都是增大感受野。

  • 自注意力机制(Self-attention Mechanism)

    • 是注意力机制的改进,其减少了对外部信息的依赖,更擅长捕捉数据或特征的内部相关性。

背景:

语义分割是计算机视觉几大主任务之一,被广泛应用到自动驾驶、遥感监测等领域中。语义分割研究中的若干成果,也被诸多相关领域沿用。自注意力机制继在 NLP 领域取得主导地位之后,近两年在计算机视觉领域也开始独领风骚。自注意力机制在语义分割网络中的应用,并由之衍生出的一系列低秩重建相关的方法。

图像语义分割前沿进展

超像素、语义分割、实例分割、全景分割的区别

既需要细节,又需要捕捉全局信息

  • 得到大尺度信息方法

    • 1.Non-local modules(非局部模块)
    • 2.self-attention(自注意力)
    • 3.Dilated convolution(空洞卷积)
    • 4.Pyramid/global pooling(金字塔/全局池化)
  • 缺陷

    • 1、2计算资源消耗多
    • 3、4虽相对低代价,但各向同性,很难获得各向异性

提高CNNs中远程依赖关系建模能力的一种方法是采用self-attention机制non-local模块。然而,它们会消耗大量内存。而其他的远程上下文建模方法包括:

1)扩张卷积,其目的是在不引入额外参数的情况下扩大CNNs的接受域;

2)全局/金字塔池化,它总结了图像的全局线索。

然而,这些方法的一个常见限制,包括扩张卷积和池化在内,它们都在方形窗口中探测输入特征图。这限制了它们在捕获广泛存在于现实场景中的各向异性的上下文上的灵活性。

  • CVPR2020 带状池化(Strip Pooling)

    • 为了更有效地捕获长依赖关系,本文在空间池化扩大CNNs的感受野和捕获上下文信息的基础上,提出了条纹池化的概念,作为全局池化的替代方案。
    • 优点:
      • 它沿着一个空间维度部署一个长条状的池化核形状,因此能够捕获孤立区域的长距离关系
      • 在其他空间维度上保持较窄的内核形状,便于捕获局部上下文,防止不相关区域干扰标签预测
    • 集成这种长而窄的池内核使语义分割网络能够同时聚合全局和本地上下文。这与传统的从固定的正方形区域收集上下文的池化有本质的不同。
    • 基于条纹池化的想法,作者提出了两种即插即用的池化模块:
      • Strip Pooling Module (SPM)
      • Mixed Pooling module (MPM)

代码练习

完善 HybridSN 高光谱分类网络

模型的网络结构为如下图所示:

三维卷积部分:

  • conv1:(1, 30, 25, 25), 8个 7x3x3 的卷积核 ==>(8, 24, 23, 23)
  • conv2:(8, 24, 23, 23), 16个 5x3x3 的卷积核 ==>(16, 20, 21, 21)
  • conv3:(16, 20, 21, 21),32个 3x3x3 的卷积核 ==>(32, 18, 19, 19)

接下来要进行二维卷积,因此把前面的 32*18 reshape 一下,得到 (576, 19, 19)

二维卷积:(576, 19, 19) 64个 3x3 的卷积核,得到 (64, 17, 17)

接下来是一个 flatten 操作,变为 18496 维的向量,

接下来依次为256,128节点的全连接层,都使用比例为0.4的 Dropout,

最后输出为 16 个节点,是最终的分类类别数。

下面是 HybridSN 类的代码:

class_num = 16

class SELayer(nn.Module):

  def __init__(self, channel, reduction=16):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
) def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x) class HybridSN(nn.Module): def __init__(self):
super(HybridSN, self).__init__()
self.conv1=nn.Conv3d(in_channels=1,out_channels=8,kernel_size=(7,3,3))
self.bn1 = nn.BatchNorm3d(8)
self.conv2=nn.Conv3d(in_channels=8,out_channels=16,kernel_size=(5,3,3))
self.bn2 = nn.BatchNorm3d(16)
self.conv3=nn.Conv3d(in_channels=16,out_channels=32,kernel_size=(3,3,3))
self.bn3 = nn.BatchNorm3d(32)
self.se1 = SELayer(576, 16)
self.conv4 = nn.Conv2d(576, 64, 3)
self.bn4 = nn.BatchNorm2d(64)
self.se2 = SELayer(64, 16)
self.fc1=nn.Linear(18496,256)
self.fc2=nn.Linear(256,128)
self.fc3=nn.Linear(128,class_num)
self.dropout = nn.Dropout(p=0.4)
# self.soft = nn.Softmax(dim = 1) def forward(self, x):
x = self.conv1(x)
# x = self.bn1(x)
x = F.relu(x)
x = self.conv2(x)
# x = self.bn2(x)
x = F.relu(x)
x = self.conv3(x)
# x = self.bn3(x)
x = F.relu(x) x = torch.reshape(x,[x.shape[0],576,19,19]) # x = self.se1(x)
x = self.conv4(x)
# x = self.se2(x)
# x = self.bn4(x)
x = F.relu(x) x = torch.flatten(x,start_dim=1) x = self.dropout(F.relu(self.fc1(x)))
x = self.dropout(F.relu(self.fc2(x)))
x = self.fc3(x)
# x = self.soft(x) return x # 随机输入,测试网络结构是否通
x = torch.randn(1, 1, 30, 25, 25)
net = HybridSN()
y = net(x)
print(y.shape)

准确率为 95.96%

性能良好,测试结果稳定。

也可以考虑加入BN,进一步提升性能。

准确率为 96.61%

先后顺序:Batch Normalization 层恰恰插入在 Conv 层或全连接层之后,而在 ReLU等激活层之前。而对于 dropout 则应当置于 activation layer 之后。

注意:BN和Dropout单独使用都能减少过拟合并加速训练速度,但如果一起使用的话并不会产生1+1>2的效果,相反可能会得到比单独使用更差的效果。

SENet 实现

该网络通过学习的方式获取每个特征通道的重要程度,然后依照这个重要程度去提升有用的特征并抑制对当前任务用处不大的特征。

将上方代码SE注释部分解开,加入SE模块。

准确率为 97.36%

SE模块主要为了提升模型对channel特征的敏感性,这个模块是轻量级的,而且可以应用在现有的网络结构中,只需要增加较少的计算量就可以带来性能的提升。

总结:

方法 accuracy
普通HybridSN 95.96%
加入BN 96.61%
加入SENet 97.36%

其他的还可以通过添加学习率衰减函数来提升性能,这里mark一下,以后做个实验验证。

ResNet预训练模型 垃圾分类识别

AI研习社最新的比赛:垃圾分类识别

老师让我们试一试,我找了个预训练的ResNet模型微调了一下,加了几个简单的trick,结果如下:

离标准分_85还是有些距离,看来通用的网络并不比针对特定问题设计的网络表现更好。

后续我会继续关注比赛,争取学习更多的深度学习知识,设计出更加强大的网络。

第四周:卷积神经网络 part 3的更多相关文章

  1. Tensorflow之卷积神经网络(CNN)

    前馈神经网络的弊端 前一篇文章介绍过MNIST,是采用的前馈神经网络的结构,这种结构有一个很大的弊端,就是提供的样本必须面面俱到,否则就容易出现预测失败.如下图: 同样是在一个图片中找圆形,如果左边为 ...

  2. Deep Learning.ai学习笔记_第四门课_卷积神经网络

    目录 第一周 卷积神经网络基础 第二周 深度卷积网络:实例探究 第三周 目标检测 第四周 特殊应用:人脸识别和神经风格转换 第一周 卷积神经网络基础 垂直边缘检测器,通过卷积计算,可以把多维矩阵进行降 ...

  3. torch_06_卷积神经网络

    1.概述 卷积神经网络的参数,由一些可学习的滤波器集合构成的,每个滤波器在空间上都计较小,但是深度和输入数据的深度保持一致.在前向传播中,让每个滤波器都在输入数据的宽度和高度上滑动(卷积),然后计算整 ...

  4. 卷积神经网络概念及使用 PyTorch 简单实现

    卷积神经网络 卷积神经网络(CNN)是深度学习的代表算法之一 .具有表征学习能力,能够按其阶层结构对输入信息进行平移不变分类,因此也被称为“平移不变人工神经网络”.随着深度学习理论的提出和数值计算设备 ...

  5. SIGAI深度学习第八集 卷积神经网络2

    讲授Lenet.Alexnet.VGGNet.GoogLeNet等经典的卷积神经网络.Inception模块.小尺度卷积核.1x1卷积核.使用反卷积实现卷积层可视化等. 大纲: LeNet网络 Ale ...

  6. CNN(卷积神经网络)原理讲解及简单代码

    一.原理讲解 1. 卷积神经网络的应用 分类(分类预测) 检索(检索出该物体的类别) 检测(检测出图像中的物体,并标注) 分割(将图像分割出来) 人脸识别 图像生成(生成不同状态的图像) 自动驾驶 等 ...

  7. 卷积神经网络提取特征并用于SVM

    模式识别课程的一次作业.其目标是对UCI的手写数字数据集进行识别,样本数量大约是1600个.图片大小为16x16.要求必须使用SVM作为二分类的分类器. 本文重点是如何使用卷积神经网络(CNN)来提取 ...

  8. tensorflow学习笔记五:mnist实例--卷积神经网络(CNN)

    mnist的卷积神经网络例子和上一篇博文中的神经网络例子大部分是相同的.但是CNN层数要多一些,网络模型需要自己来构建. 程序比较复杂,我就分成几个部分来叙述. 首先,下载并加载数据: import ...

  9. Deep Learning模型之:CNN卷积神经网络(一)深度解析CNN

    http://m.blog.csdn.net/blog/wu010555688/24487301 本文整理了网上几位大牛的博客,详细地讲解了CNN的基础结构与核心思想,欢迎交流. [1]Deep le ...

随机推荐

  1. Presto 函数开发

    0. 写在前面 Presto Functions 并不能像 Hive UDF 一样动态加载,需要根据 Function 的类型,实现 Presto 内部定义的不同接口,在 Presto 服务启动时进行 ...

  2. flask下直接展示mysql数据库 字段

    在工作中,会导出一份mysql的html来查看,用的是就是路过秋天大神的那个工具,所以想自己用那个样式直接在后端写一个页面做展示! 前端页面 from flask import Flask,reque ...

  3. Mybatis(二)简化Mybatis实现数据库操作

    要操作的数据库: 一.与数据库对应的bean类 public class User { private String username; private String sex; private Str ...

  4. onepill Android端

    使用的框架 第三方登录集成基于ThinkPHP5的第三方登录插件 QQ第三方登录集成QQ互联.qq第三方接入 SharedPreference实现记住账号密码功能参考.参考2

  5. IO—》字节流&字符流

    字节流 一.字节输出流OutputStream OutputStream此抽象类,是表示输出字节流的所有类的超类.操作的数据都是字节,定义了输出字节流的基本共性功能方法. FileOutputStre ...

  6. DVWA SQL 注入关卡初探

    1. 判断回显 给id参数赋不同的值,发现有不同的返回信息 2. 判断参数类型 在参数后加 ' ,查看报错信息 数字型参数左右无引号,字符型参数左右有引号 4. 引号闭合与布尔类型判断 由于是字符型参 ...

  7. 女生学Java编程是什么感受?

    那我就代表女生来说说感受 在编程的世界很难遇到好看的帅哥 记得当年15年7月4号是我实习生入职的日子,因为是校企合作,所以没有面试.老师推荐.直接入职.刚来北京第一个感觉就是人多,还有就是热.刚到公司 ...

  8. Python os.utime() 方法

    概述 os.utime() 方法用于设置指定路径文件最后的修改和访问时间.高佣联盟 www.cgewang.com 在Unix,Windows中有效. 语法 utime()方法语法格式如下: os.u ...

  9. PHP MySQL Delete删除数据库中的数据

    PHP MySQL Delete DELETE 语句用于从数据库表中删除行. 删除数据库中的数据 DELETE FROM 语句用于从数据库表中删除记录. 语法 DELETE FROM table_na ...

  10. luogu P4525 自适应辛普森法1

    LINK:自适应辛普森法1 观察题目 这个东西 凭借我们的数学知识应该是化简不了的. 可以直接认为是一个函数 求定积分直接使用辛普森就行辣. 一种写法: double a,b,c,d; double ...