Dynamic Routing Between Capsules
概
虽然11年就提出了capsule的概念, 但是走入人们视线的应该还是这篇文章吧. 虽然现阶段, capsule没有体现出什么优势. 不过, capsule相较于传统的CNN融入了很多先验知识, 更能够拟合人类的视觉系统(我不知), 或许有一天它会大放异彩.
主要内容
直接从这个结构图讲起吧.
- Input: 1 x 28 x 28 的图片 经过 9 x 9的卷积核(stride=1, padding=0, out_channels=256)作用;
- 256 x 20 x 20的特征图, 经过primarycaps作用(9 x 9 的卷积核(strde=2, padding=0, out_channels=256);
- (32 x 8) x 6 x 6的特征图, 理解为32 x 6 x 6 x 8 = 1152 x 8, 即1152个胶囊, 每个胶囊由一个8D的向量表示\(u_{i}\); (这个地方要不要squash, 大部分实现都是要的.)
- 接下来digitcaps中有10个caps(对应10个类别), 1152caps和10个caps一一对应, 分别用\(i, j\)表示, 前一层的caps为后一层提供输入, 输入为
\]
可见, 应当有1152 x 10个\(W_{ij}\in \mathbb{R}^{16\times 8}\), 其中16是输出胶囊的维度. 最后10个caps的输出为
\]
其中\(c_{ij}\)是通过一个路由算法决定的, \(v_j\), 即最后的输入如此定义是出于一种直觉, 即保持原始输出(\(s\))的方向, 同时让\(v\)的长度表示一个概率(这一步称为squash).
首先初始化\(b_{ij}=0\) (这里在程序实现的时候有一个考量, 是每一次都要初始化吗, 我看大部分的实现都是如此的).
上面的Eq.3就是
c_{ij}=\frac{\exp(b_{ij})}{\sum_{k}\exp(b_{ik})}.
\]
另外\(\hat{\mu}_{j|i} \cdot v_j=\hat{\mu}_{j|i}^Tv_j\)是一种cos相似度度量.
损失函数
损失函数采用的是margin loss:
L_k = T_k \max(0, m^+ - \|v_k\|)^2 + \lambda (1 - T_k) \max(0, \|v_k\|-m^-)^2.
\]
\(m^+, m^-\)通常取0.9和0.1, \(\lambda\)通常取0.5.
代码
我的代码, 在sgd下可以训练(但是准确率只有98), 在adam下就死翘翘了, 所以代码肯定是有问题, 但是我实在是找不出来了, 这里有很多实现的汇总.
"""
Sabour S., Frosst N., Hinton G. Dynamic Routing Between Capsules.
Neural Information Processing Systems, pp. 3856-3866, 2017.
https://arxiv.org/pdf/1710.09829.pdf
The implement below refers to https://github.com/adambielski/CapsNet-pytorch.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
def squash(s):
temp = s.norm(dim=-1, keepdim=True)
return (temp / (1. + temp ** 2)) * s
class PrimaryCaps(nn.Module):
def __init__(
self, in_channel, out_entities,
out_dims, kernel_size, stride, padding
):
super(PrimaryCaps, self).__init__()
self.conv = nn.Conv2d(in_channel, out_entities * out_dims,
kernel_size, stride, padding)
self.out_entities = out_entities
self.out_dims = out_dims
def forward(self, inputs):
conv_outs = self.conv(inputs).permute(0, 2, 3, 1).contiguous()
outs = conv_outs.view(conv_outs.size(0), -1, self.out_dims)
return squash(outs)
class AgreeRouting(nn.Module):
def __init__(self, in_caps, out_caps, out_dims, iterations=3):
super(AgreeRouting, self).__init__()
self.in_caps = in_caps
self.out_caps = out_caps
self.out_dims = out_dims
self.iterations = iterations
@staticmethod
def softmax(inputs, dim=-1):
return F.softmax(inputs, dim=dim)
def forward(self, inputs):
# inputs N x in_caps x out_caps x out_dims
b = torch.zeros(inputs.size(0), self.in_caps, self.out_caps).to(inputs.device)
for r in range(self.iterations):
c = self.softmax(b) # N x in_caps x out_caps !!!!!!!!!
s = (c.unsqueeze(-1) * inputs).sum(dim=1) # N x out_caps x out_dims
v = squash(s) # N x out_caps x out_dims
b = b + (v.unsqueeze(dim=1) * inputs).sum(dim=-1)
return v
class CapsLayer(nn.Module):
def __init__(self, in_caps, in_dims, out_caps, out_dims, routing):
super(CapsLayer, self).__init__()
self.in_caps = in_caps
self.in_dims = in_dims
self.routing = routing
self.weights = nn.Parameter(torch.rand(in_caps, out_caps, in_dims, out_dims))
nn.init.kaiming_uniform_(self.weights)
def forward(self, inputs):
# inputs: N x in_caps x in_dims
inputs = inputs.view(inputs.size(0), self.in_caps, 1, 1, self.in_dims)
u_pres = (inputs @ self.weights).squeeze() # N x in_caps x out_caps x out_dims
outs = self.routing(u_pres) # N x out_caps x out_dims
return outs
class CapsNet(nn.Module):
def __init__(self):
super(CapsNet, self).__init__()
# N x 1 x 28 x 28
self.conv = nn.Conv2d(1, 256, 9, 1, padding=0) # N x (32 * 8) x 20 x 20
self.primarycaps = PrimaryCaps(256, 32, 8, 9, 2, 0) # N x (6 x 6 x 32) x 8
routing = AgreeRouting(32 * 6 * 6, 10, 8, 3)
self.digitlayer = CapsLayer(32 * 6 * 6, 8, 10, 16, routing)
def forward(self, inputs):
conv_outs = F.relu(self.conv(inputs))
pri_outs = self.primarycaps(conv_outs)
outs = self.digitlayer(pri_outs)
probs = outs.norm(dim=-1)
return probs
if __name__ == "__main__":
x = torch.randn(4, 1, 28 ,28)
capsnet = CapsNet()
print(capsnet(x))
def margin_loss(logits, labels, m=0.9, leverage=0.5, adverage=True):
# outs: N x num_classes x dim
# labels: N
temp1 = F.relu(m - logits) ** 2
temp2 = F.relu(logits + m - 1) ** 2
T = F.one_hot(labels.long(), logits.size(-1))
loss = (temp1 * T + leverage * temp2 * (1 - T)).sum()
if adverage:
loss = loss / logits.size(0)
# Another implement is using scatter_
# T = torch.zero(logits.size()).long()
# T.scatter_(dim=1, index=labels.view(-1, 1), 1.).cuda() if cuda()
return loss
Dynamic Routing Between Capsules的更多相关文章
- Hinton's paper Dynamic Routing Between Capsules 的 Tensorflow , Keras ,Pytorch实现
Tensorflow 实现 A Tensorflow implementation of CapsNet(Capsules Net) in Hinton's paper Dynamic Routing ...
- 【论文笔记】Dynamic Routing Between Capsules
Dynamic Routing Between Capsules 2018-09-16 20:18:30 Paper:https://arxiv.org/pdf/1710.09829.pdf%20 P ...
- Dynamic Routing Based On Redis
Dynamic Routing Based On Redis Ngnix技术研究系列2-基于Redis实现动态路由 上篇博文我们写了个引子: Ngnix技术研究系列1-通过应用场景看Nginx的反 ...
- dynamic routing between captual
对于人脑 决策树形式 对于CNN 层级与层级间的传递 人在识别物体的时候会进行坐标框架的设置 CNN无法识别,只能通过大量训练 胶囊 :一个神经元集合,有一个活动的向量,来表示物体的各类信息,向量的长 ...
- Paper | SkipNet: Learning Dynamic Routing in Convolutional Networks
目录 1. 概括 2. 相关工作 3. 方法细节 门限模块的结构 训练方法 4. 总结 作者对residual network进行了改进:加入了gating network,基于上一层的激活值,得到一 ...
- 总结近期CNN模型的发展(一)---- ResNet [1, 2] Wide ResNet [3] ResNeXt [4] DenseNet [5] DPNet [9] NASNet [10] SENet [11] Capsules [12]
总结近期CNN模型的发展(一) from:https://zhuanlan.zhihu.com/p/30746099 余俊 计算机视觉及深度学习 1.前言 好久没有更新专栏了,最近因为项目的原因接 ...
- 百年老图难倒谷歌AI,兔还是鸭?这是个问题
上面这张图,画的是鸭子还是兔子? 自从1892年首次出现在一本德国杂志上之后,这张图就一直持续引发争议.有些人只能看到一只兔子,有些人只能看到一只鸭子,有些人两个都能看出来. 心理学家用这张图证明了一 ...
- 浅析 Hinton 最近提出的 Capsule 计划
[原文] 浅析 Hinton 最近提出的 Capsule 计划 关于最新的 Hinton 的论文 Dynamic Routing Between Capsules,参见 https:// ...
- Hinton“深度学习之父”和“神经网络先驱”,新论文Capsule将推翻自己积累了30年的学术成果时
Hinton“深度学习之父”和“神经网络先驱”,新论文Capsule将推翻自己积累了30年的学术成果时 在论文中,Capsule被Hinton大神定义为这样一组神经元:其活动向量所表示的是特定实体类型 ...
随机推荐
- 27.0 linux VM虚拟机IP问题
我的虚拟机是每次换一个不同的网络,b不同的ip,使用桥接模式就无法连接,就需要重新还原默认设置才行: 第一步:点击虚拟机中的编辑-->虚拟网络编辑器 第二步:点击更改设置以管理员权限进入 第三步 ...
- HDFS【Java API操作】
通过java的api对hdfs的资源进行操作 代码:上传.下载.删除.移动/修改.文件详情.判断目录or文件.IO流操作上传/下载 package com.atguigu.hdfsdemo; impo ...
- 大数据学习day23-----spark06--------1. Spark执行流程(知识补充:RDD的依赖关系)2. Repartition和coalesce算子的区别 3.触发多次actions时,速度不一样 4. RDD的深入理解(错误例子,RDD数据是如何获取的)5 购物的相关计算
1. Spark执行流程 知识补充:RDD的依赖关系 RDD的依赖关系分为两类:窄依赖(Narrow Dependency)和宽依赖(Shuffle Dependency) (1)窄依赖 窄依赖指的是 ...
- 大数据学习day20-----spark03-----RDD编程实战案例(1 计算订单分类成交金额,2 将订单信息关联分类信息,并将这些数据存入Hbase中,3 使用Spark读取日志文件,根据Ip地址,查询地址对应的位置信息
1 RDD编程实战案例一 数据样例 字段说明: 其中cid中1代表手机,2代表家具,3代表服装 1.1 计算订单分类成交金额 需求:在给定的订单数据,根据订单的分类ID进行聚合,然后管理订单分类名称, ...
- 零基础学习java------38---------spring中关于通知类型的补充,springmvc,springmvc入门程序,访问保护资源,参数的绑定(简单数据类型,POJO,包装类),返回数据类型,三大组件,注解
一. 通知类型 spring aop通知(advice)分成五类: (1)前置通知[Before advice]:在连接点前面执行,前置通知不会影响连接点的执行,除非此处抛出异常. (2)正常返回通知 ...
- 如何从 100 亿 URL 中找出相同的 URL?
题目描述 给定 a.b 两个文件,各存放 50 亿个 URL,每个 URL 各占 64B,内存限制是 4G.请找出 a.b 两个文件共同的 URL. 解答思路 每个 URL 占 64B,那么 50 亿 ...
- spring boot项目创建与使用
概述 spring boot通常使用maven创建,重点在于pom.xml配置,有了pom.xml配置,可以先创建一个空的maven项目,然后从maven下载相关jar包. spring boot d ...
- ubantu安装maven
下载地址 http://maven.apache.org/download.cgi 或直接命令行下载 wget https://downloads.apache.org/maven/maven-3/3 ...
- 【Java基础】Java反射——Private Fields and Methods
Despite the common belief it is actually possible to access private fields and methods of other clas ...
- 解决在进行socket通信时,一端输出流OutputStream不关闭,另一端输入流就接收不到数据
输出的数据需要达到一定的量才会向另一端输出,所以在传输数据的末端添加 \r\n 可以保证不管数据量是多少,都立刻传输到另一端.