CapsuleNet

前言

找了很多资料,终于把整个流程搞懂了,其实要懂这个运算并不难,难的对我来说是怎么用代码实现,也找了github上的一些代码来看,对我来说都有点冗长,变量分布太远导致我脑袋炸了,所以我就在B站找视频看看有没有代码讲解,算是不负苦心吧,终于把实现部分解决了。

不写论文解读,因为原文实在太难读了,这个老外的英文我基本上每看一句都要取查翻译,很难受,而且网上的教程、解析非常非常之多,所以我留个代码,以后看一下就能想起来了。

Capsule是干什么的

capsule是换了一种神经元的表达方式,原来每个神经元我们是用一个scalar来表示的,现在在capsule中我们中vector来表示一个神经元。这样做的好处是可以多维度描述一个神经元,而在capsue中,我们用vector的模长来表示概率,其他每个维度可以表征神经元的属性。比如某个维度表征特征的朝向,当特征朝向改变时,神经元的模长并没有改变,而是该维度的值改变了,这是一个很好的理解。

这部分网上资料简直太多了,上面说的只是我个人的见解,可以看看别人的版本。

Capsule代码怎么写

网络的结构图还是得贴一张

整体网络分三层,第一层卷积层,将(3,28,28)的输入映射到(256,20,20),第二层称为primary_caps,拿32个filter分8次卷积,得到(32,6,6,8)的输出,然后reshape成(1152,1,8)这里就是为了后面vector in vector out做准备了。

这里表达的意思就是有1152个capsule,每个capsule里有1个8维的vector,老有意思了。

然后就是后面digit_caps层了,我们目标vector应该是(10,1,16),输入是(1152,1,8),所以我们在这里思考作者是如何得到这样的映射关系的。

利用动态路由算法,我们成功得到的v。

好,结束。重建的代码我就不写了。

附上总代码:

import torch
import torch.nn as nn from torchsummary import summary from torch.autograd import Variable
class CapsuleLayer(nn.Module):
def __init__(self,routing = False):
super(CapsuleLayer,self).__init__()
self.routing = routing
def create_conv(unit_idx):
conv_unit = nn.Conv2d(256,32,kernel_size = 9,stride = 2)
self.add_module("conv_unit_{}".format(unit_idx),conv_unit)
return conv_unit
self.conv_units = [create_conv(i) for i in range(8)]
self.w = Variable(torch.randn(1,1152,10,16))
self.fc = nn.Linear(8,16)
def forward(self,x):
if self.routing:
return self.use_routing(x)
else:
return self.no_routing(x)
@staticmethod
def squash(x):
f = torch.sum(x**2,dim =2,keepdim = True)
return f / (1 + f) / (x / torch.sqrt(f))
def use_routing(self,x):# (-1,8,32*6*6)
x = x.transpose(1,2).view(-1,32*6*6,1,8)
x = self.fc(x)
w = torch.cat([self.w] * x.size(0), dim = 0)
u = w * x # (b,1152,10,8)
b = Variable(torch.zeros(x.size(0),x.size(1),10,1,1)) for iter in range(3):
c = torch.softmax(u,dim = -1)
s = torch.sum(c,dim = 1,keepdim = True)
v = self.squash(s).view(-1,1,10,16,1)
b = b + u.view(x.size(0),1152,10,1,16) @ v.view(x.size(0),1,10,16,1) return v.view(x.size(0),10,16) def no_routing(self,x):
u = [self.conv_units[i](x) for i in range(8)]
# every u (-1,32,6,6) # (-1,8,32,6,6)
u = torch.stack(u,dim =1)
u = u.view(-1,8,32*6*6)
return self.squash(u)
class CapsuleNet(nn.Module):
def __init__(self):
super(CapsuleNet,self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(1,256,kernel_size = 9,stride = 1),
nn.ReLU()
)
self.pri_caps = CapsuleLayer()
self.digit_caps = CapsuleLayer(routing = True)
def forward(self,x):
x = self.conv(x) # (-1,256,20,20)
x = self.pri_caps(x)
x = self.digit_caps(x)
return x
if __name__ == "__main__":
x = torch.randn(2,1,28,28)
net = CapsuleNet()
y = net(x)
print(y.size())

[论文理解] CapsuleNet的更多相关文章

  1. [论文理解]关于ResNet的进一步理解

    [论文理解]关于ResNet的理解 这两天回忆起resnet,感觉残差结构还是不怎么理解(可能当时理解了,时间长了忘了吧),重新梳理一下两点,关于resnet结构的思考. 要解决什么问题 论文的一大贡 ...

  2. [论文理解] CornerNet: Detecting Objects as Paired Keypoints

    [论文理解] CornerNet: Detecting Objects as Paired Keypoints 简介 首先这是一篇anchor free的文章,看了之后觉得方法挺好的,预测左上角和右下 ...

  3. R-FCN论文理解

    一.R-FCN初探 1. R-FCN贡献 提出Position-sensitive score maps来解决目标检测的位置敏感性问题: 区域为基础的,全卷积网络的二阶段目标检测框架: 比Faster ...

  4. YOLO V3论文理解

    YOLO3主要的改进有:调整了网络结构:利用多尺度特征进行对象检测:对象分类用Logistic取代了softmax. 1.Darknet-53 network在论文中虽然有给网络的图,但我还是简单说一 ...

  5. YOLO V2论文理解

    概述 YOLO(You Only Look Once: Unified, Real-Time Object Detection)从v1版本进化到了v2版本,作者在darknet主页先行一步放出源代码, ...

  6. ssd算法论文理解

    这篇博客主要是讲下我在阅读ssd论文时对论文的理解,并且自行使用pytorch实现了下论文的内容,并测试可以用. 开篇放下论文地址https://arxiv.org/abs/1512.02325,可以 ...

  7. [论文理解]Deep Residual Learning for Image Recognition

    Deep Residual Learning for Image Recognition 简介 这是何大佬的一篇非常经典的神经网络的论文,也就是大名鼎鼎的ResNet残差网络,论文主要通过构建了一种新 ...

  8. [论文理解] Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks

    Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks 简介 Faster R-CNN是很经典的t ...

  9. [论文理解]Selective Search for Object Recognition

    Selective Search for Object Recognition 简介 Selective Search是现在目标检测里面非常常用的方法,rcnn.frcnn等就是通过selective ...

随机推荐

  1. Hadoop环境安装和集群创建

    虚拟机使用vmware,vmware可以直接百度下载安装 秘钥也能百度到 安装很简单 CentOS 7下载: 进入官网 https://www.centos.org/download/ 这里有三种 第 ...

  2. jquery事件绑定方式总结(补充)

    总结 : 1.简单事件绑定方式:事件名()  如:click() 2.高级事件绑定方式:bind(事件名,数据参数,function)    3.动态生成元素事件绑定方式:live(事件名,数据参数, ...

  3. 第一个简单的Echarts实例

    该示例使用 vue-cli 脚手架搭建 安装echarts依赖 npm install echarts -S 或者使用国内的淘宝镜像: 安装 npm install -g cnpm --registr ...

  4. wampserver2.2 在window2003下的安装的主要问题

    准备安装最新的wampserver 2.2c,   1.安装问题,安装完成后总是无法启动服务   系统事件中提示错误 找不到附属汇编 Microsoft.VC90.CRT,上一个错误是 参照的汇编没有 ...

  5. centos7.3安装docker

    一.写随笔的原因:最近在阿里云上买了个centos7.3服务器,想将一些demo运行在上面,所以需要做一些环境的安装,通过此篇文章MAKR一下.下面来记录下安装步骤(参考网上的一些教程,有坑的话会实时 ...

  6. vue typescript curd

    用typescript 完成了一个页面 import { Component, Prop } from 'vue-property-decorator'; import Vue, { VNode } ...

  7. Coinbase 雇员被 Firefox 0day 漏洞攻击

    Firefox 刚刚修复的 0day 漏洞被用于攻击 Coinbase 雇员.Coinbase 安全团队的 Philip Martin 称,攻击者组合利用了两个 0day 漏洞,其一是远程代码执行漏洞 ...

  8. 升级python导致yum报错的解决方法

    把python从2.7升级到3.6后 , 使用yum报错 File ‘’/usr/bin/yum'', line 30 except KeyboardInterrupt, e: ^ 故障原因:yum采 ...

  9. 009(1)-saltstack之salt-ssh的使用及配置管理LAMP状态的实现

    1 salt-ssh的使用 1. 安装salt-ssh[root@slave1 .ssh]# yum install -y salt-ssh 2. 配置salt-ssh # Sample salt-s ...

  10. local_time

    time_t time(time_t *tloc); 功能:获取纪元1970-01-01 00:00:00以来所经历的秒数 参数: tloc:用来存储返回时间 返回值:成功:返回秒数, 失败:-1 - ...