Dynamic Routing Between Capsules
概
虽然11年就提出了capsule的概念, 但是走入人们视线的应该还是这篇文章吧. 虽然现阶段, capsule没有体现出什么优势. 不过, capsule相较于传统的CNN融入了很多先验知识, 更能够拟合人类的视觉系统(我不知), 或许有一天它会大放异彩.
主要内容

直接从这个结构图讲起吧.
- Input: 1 x 28 x 28 的图片 经过 9 x 9的卷积核(stride=1, padding=0, out_channels=256)作用;
- 256 x 20 x 20的特征图, 经过primarycaps作用(9 x 9 的卷积核(strde=2, padding=0, out_channels=256);
- (32 x 8) x 6 x 6的特征图, 理解为32 x 6 x 6 x 8 = 1152 x 8, 即1152个胶囊, 每个胶囊由一个8D的向量表示\(u_{i}\); (这个地方要不要squash, 大部分实现都是要的.)
- 接下来digitcaps中有10个caps(对应10个类别), 1152caps和10个caps一一对应, 分别用\(i, j\)表示, 前一层的caps为后一层提供输入, 输入为
\]
可见, 应当有1152 x 10个\(W_{ij}\in \mathbb{R}^{16\times 8}\), 其中16是输出胶囊的维度. 最后10个caps的输出为
\]
其中\(c_{ij}\)是通过一个路由算法决定的, \(v_j\), 即最后的输入如此定义是出于一种直觉, 即保持原始输出(\(s\))的方向, 同时让\(v\)的长度表示一个概率(这一步称为squash).
首先初始化\(b_{ij}=0\) (这里在程序实现的时候有一个考量, 是每一次都要初始化吗, 我看大部分的实现都是如此的).

上面的Eq.3就是
c_{ij}=\frac{\exp(b_{ij})}{\sum_{k}\exp(b_{ik})}.
\]
另外\(\hat{\mu}_{j|i} \cdot v_j=\hat{\mu}_{j|i}^Tv_j\)是一种cos相似度度量.
损失函数
损失函数采用的是margin loss:
L_k = T_k \max(0, m^+ - \|v_k\|)^2 + \lambda (1 - T_k) \max(0, \|v_k\|-m^-)^2.
\]
\(m^+, m^-\)通常取0.9和0.1, \(\lambda\)通常取0.5.
代码
我的代码, 在sgd下可以训练(但是准确率只有98), 在adam下就死翘翘了, 所以代码肯定是有问题, 但是我实在是找不出来了, 这里有很多实现的汇总.
"""
Sabour S., Frosst N., Hinton G. Dynamic Routing Between Capsules.
Neural Information Processing Systems, pp. 3856-3866, 2017.
https://arxiv.org/pdf/1710.09829.pdf
The implement below refers to https://github.com/adambielski/CapsNet-pytorch.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
def squash(s):
temp = s.norm(dim=-1, keepdim=True)
return (temp / (1. + temp ** 2)) * s
class PrimaryCaps(nn.Module):
def __init__(
self, in_channel, out_entities,
out_dims, kernel_size, stride, padding
):
super(PrimaryCaps, self).__init__()
self.conv = nn.Conv2d(in_channel, out_entities * out_dims,
kernel_size, stride, padding)
self.out_entities = out_entities
self.out_dims = out_dims
def forward(self, inputs):
conv_outs = self.conv(inputs).permute(0, 2, 3, 1).contiguous()
outs = conv_outs.view(conv_outs.size(0), -1, self.out_dims)
return squash(outs)
class AgreeRouting(nn.Module):
def __init__(self, in_caps, out_caps, out_dims, iterations=3):
super(AgreeRouting, self).__init__()
self.in_caps = in_caps
self.out_caps = out_caps
self.out_dims = out_dims
self.iterations = iterations
@staticmethod
def softmax(inputs, dim=-1):
return F.softmax(inputs, dim=dim)
def forward(self, inputs):
# inputs N x in_caps x out_caps x out_dims
b = torch.zeros(inputs.size(0), self.in_caps, self.out_caps).to(inputs.device)
for r in range(self.iterations):
c = self.softmax(b) # N x in_caps x out_caps !!!!!!!!!
s = (c.unsqueeze(-1) * inputs).sum(dim=1) # N x out_caps x out_dims
v = squash(s) # N x out_caps x out_dims
b = b + (v.unsqueeze(dim=1) * inputs).sum(dim=-1)
return v
class CapsLayer(nn.Module):
def __init__(self, in_caps, in_dims, out_caps, out_dims, routing):
super(CapsLayer, self).__init__()
self.in_caps = in_caps
self.in_dims = in_dims
self.routing = routing
self.weights = nn.Parameter(torch.rand(in_caps, out_caps, in_dims, out_dims))
nn.init.kaiming_uniform_(self.weights)
def forward(self, inputs):
# inputs: N x in_caps x in_dims
inputs = inputs.view(inputs.size(0), self.in_caps, 1, 1, self.in_dims)
u_pres = (inputs @ self.weights).squeeze() # N x in_caps x out_caps x out_dims
outs = self.routing(u_pres) # N x out_caps x out_dims
return outs
class CapsNet(nn.Module):
def __init__(self):
super(CapsNet, self).__init__()
# N x 1 x 28 x 28
self.conv = nn.Conv2d(1, 256, 9, 1, padding=0) # N x (32 * 8) x 20 x 20
self.primarycaps = PrimaryCaps(256, 32, 8, 9, 2, 0) # N x (6 x 6 x 32) x 8
routing = AgreeRouting(32 * 6 * 6, 10, 8, 3)
self.digitlayer = CapsLayer(32 * 6 * 6, 8, 10, 16, routing)
def forward(self, inputs):
conv_outs = F.relu(self.conv(inputs))
pri_outs = self.primarycaps(conv_outs)
outs = self.digitlayer(pri_outs)
probs = outs.norm(dim=-1)
return probs
if __name__ == "__main__":
x = torch.randn(4, 1, 28 ,28)
capsnet = CapsNet()
print(capsnet(x))
def margin_loss(logits, labels, m=0.9, leverage=0.5, adverage=True):
# outs: N x num_classes x dim
# labels: N
temp1 = F.relu(m - logits) ** 2
temp2 = F.relu(logits + m - 1) ** 2
T = F.one_hot(labels.long(), logits.size(-1))
loss = (temp1 * T + leverage * temp2 * (1 - T)).sum()
if adverage:
loss = loss / logits.size(0)
# Another implement is using scatter_
# T = torch.zero(logits.size()).long()
# T.scatter_(dim=1, index=labels.view(-1, 1), 1.).cuda() if cuda()
return loss
Dynamic Routing Between Capsules的更多相关文章
- Hinton's paper Dynamic Routing Between Capsules 的 Tensorflow , Keras ,Pytorch实现
Tensorflow 实现 A Tensorflow implementation of CapsNet(Capsules Net) in Hinton's paper Dynamic Routing ...
- 【论文笔记】Dynamic Routing Between Capsules
Dynamic Routing Between Capsules 2018-09-16 20:18:30 Paper:https://arxiv.org/pdf/1710.09829.pdf%20 P ...
- Dynamic Routing Based On Redis
Dynamic Routing Based On Redis Ngnix技术研究系列2-基于Redis实现动态路由 上篇博文我们写了个引子: Ngnix技术研究系列1-通过应用场景看Nginx的反 ...
- dynamic routing between captual
对于人脑 决策树形式 对于CNN 层级与层级间的传递 人在识别物体的时候会进行坐标框架的设置 CNN无法识别,只能通过大量训练 胶囊 :一个神经元集合,有一个活动的向量,来表示物体的各类信息,向量的长 ...
- Paper | SkipNet: Learning Dynamic Routing in Convolutional Networks
目录 1. 概括 2. 相关工作 3. 方法细节 门限模块的结构 训练方法 4. 总结 作者对residual network进行了改进:加入了gating network,基于上一层的激活值,得到一 ...
- 总结近期CNN模型的发展(一)---- ResNet [1, 2] Wide ResNet [3] ResNeXt [4] DenseNet [5] DPNet [9] NASNet [10] SENet [11] Capsules [12]
总结近期CNN模型的发展(一) from:https://zhuanlan.zhihu.com/p/30746099 余俊 计算机视觉及深度学习 1.前言 好久没有更新专栏了,最近因为项目的原因接 ...
- 百年老图难倒谷歌AI,兔还是鸭?这是个问题
上面这张图,画的是鸭子还是兔子? 自从1892年首次出现在一本德国杂志上之后,这张图就一直持续引发争议.有些人只能看到一只兔子,有些人只能看到一只鸭子,有些人两个都能看出来. 心理学家用这张图证明了一 ...
- 浅析 Hinton 最近提出的 Capsule 计划
[原文] 浅析 Hinton 最近提出的 Capsule 计划 关于最新的 Hinton 的论文 Dynamic Routing Between Capsules,参见 https:// ...
- Hinton“深度学习之父”和“神经网络先驱”,新论文Capsule将推翻自己积累了30年的学术成果时
Hinton“深度学习之父”和“神经网络先驱”,新论文Capsule将推翻自己积累了30年的学术成果时 在论文中,Capsule被Hinton大神定义为这样一组神经元:其活动向量所表示的是特定实体类型 ...
随机推荐
- LeetCode两数之和
LeetCode 两数之和 题目描述 给定一个整数数组 nums 和一个目标值 target,请你在该数组中找出和为目标值的那两个整数,并返回他们的数组下标. 你可以假设每种输入只会对应一个答案.但是 ...
- A Child's History of England.14
At first, Elfrida possessed great influence over the young King, but, as he grew older and came of a ...
- docker之镜像制作
#:下载镜像并初始化系统 root@ubuntu:~# docker pull centos #:创建目录 root@ubuntu:/opt# mkdir dockerfile/{web/{nginx ...
- .NET6使用DOCFX自动生成开发文档
本文内容来自我写的开源电子书<WoW C#>,现在正在编写中,可以去WOW-Csharp/学习路径总结.md at master · sogeisetsu/WOW-Csharp (gith ...
- Java知识点总结——IO流框架
IO框架 一.流的概念 概念:内存与存储设备之间传输数据的通道. 二.流的分类 按方向分类: 输入流:将<存储设备>中的内容读入到<内存>中 输出流:将<内存>中的 ...
- 企业级BI是自研还是采购?
企业级BI是自研还是采购? 上一篇<企业级BI为什么这么难做?>,谈到了企业级BI项目所具有的特殊背景,以及在"破局"方面的一点思考,其中谈论的焦点主要是在IT开发项目 ...
- Redis监控参数
目录 一.客户端 二.服务端 一.客户端 127.0.0.1:6379> info stats #Redis自启动以来处理的客户端连接数总数 total_connections_received ...
- DevOps和SRE的区别
目录 一.误区 二.DevOps 和 SRE 定义 三.两者产生背景和历史 四.两者的职能不同 五.工作内容不同 六.DevOps 和 SRE 关系 七.附录:技能点 DevOps SRE 一.误区 ...
- 一文详解 纹理采样与Mipmap纹理——构建山地渲染效果
在开发一些相对较大的场景时,例如:一片铺满相同草地纹理的丘陵地形,如果不采用一些技术手段,就会出现远处的丘陵较近处的丘陵相比更加的清晰的视觉效果,而这种效果与真实世界中近处的物体清晰远处物体模糊的效果 ...
- Google Earth Engine 中的位运算
Google Earth Engine中的位运算 按位运算是编程中一个难点,同时也是在我们后续处理影像数据,尤其要使用影像自带的波段比如QA波段经常会用到的一个东西.通过按位运算我们可以筛选出我们想要 ...