1 简介

胶囊网络(CapsNet)由 Hinton 于2017年10月在《Dynamic Routing Between Capsules》中提出,目的在于解决 CNN 只能提取特征,而不能提取特征的状态方向位置等信息,导致模型的泛化(举一反三)能力较差,如:

(1)将图片旋转180°,CNN 模型可能不能准确识别,因为 CNN 模型不能识别特征的方向信息,如需识别倒着的图片,需要使用倒着的图片训练 CNN 模型;

(2)将某人脸部图片的眼睛和嘴巴位置调换,CNN 模型可能也会识别为此人,因为 CNN 模型只关注眼睛、鼻子、耳朵、嘴巴等脸部特征是否准确,而不关注特征位置是否正确。

传统神经网络的基本单位是神经元,表示一个标量;胶囊网络的基本单位是胶囊(capsule),表示一个向量,由于向量有方向和大小,因此能够识别特征的状态、方向、位置等信息。

2 基本原理

胶囊网络(CapsNet)主要由动态路由层(Dynamic Routing Layer)堆叠而成,本节主要讲解动态路由层的前向传播原理。

胶囊网络的结构与全连接层相似,如下图所示:

胶囊网络结构

其中,u1,u2,...,um 为底层胶囊,v1,v2,...,vn 为高层胶囊,W 和 C 为待调参数,W 通过网络反向传播更新,C 通过动态路由算法更新

2.1 前向传播

前向传播公式如下:

u 为上层胶囊输出,v 为本层胶囊输出。W 不一定是全连接形式的参数,“”也不一定是向量乘法运算;通常,“”指一维卷积运算,W 指卷积核。c 为耦合系数,由动态路由算法获取。squash 为激活函数,如下式所示:

从上式可以看出,vj 和 sj 方向相同,并且 vj 的模长经squash()函数非线性映射到 [0, 1) 上,增加了模型的非线性映射能力。

2.2 动态路由

c 通过动态路由算法获取,在模型的反向传播中,不更新 c 的取值。动态路由算法的核心公式如下:

其中,<>表示内积运算,bi 的维度与 ui 的维度相同,初值为0。

为方便清晰查看胶囊网络中数据流向,笔者绘制了数据流向图如下:(注:图中省略了复杂的变量角标)

其中,红线表示前向传播数据流向,绿线表示动态路由数据流向;带紫色圆圈背景的变量表示在动态路由算法中会更新的变量。从图中可以看出,cij 的大小主要取决于 <ui, vj> ,即 vj 与 ui 的相似度,因此,当 vj 与 ui 较相似时,cij 较大,这点有点像注意力机制。

2.3 损失函数——Margin Loss

其中,y 为预测值,m 和 λ 为待调参数,通常 m=0.1 、 λ =0.5

3 实验

本文以 MNIST 手写数字分类为例,讲解胶囊网络模型。关于 MNIST 数据集的说明,见使用TensorFlow实现MNIST数据集分类。通过在 CNN 模型的最后一层叠加一层动态路由层,并将每个胶囊输出向量的模长作为提取数字特征的强度。

笔者工作空间如下:

代码资源见--> CapsNet

CapsuleLayer.py

from keras import backend as K
from keras.layers import Layer """
压缩函数,使用0.5替代hinton论文中的1,如果是1,所有的向量的范数都将被缩小。
如果是0.5,小于0.5的范数将缩小,大于0.5的将被放大
"""
def squash(x, axis=-1):
s_quared_norm = K.sum(K.square(x), axis, keepdims=True) + K.epsilon() #||x||^2
scale = K.sqrt(s_quared_norm) / (0.5 + s_quared_norm) #||x||/(0.5+||x||^2)
result = scale * x
return result # 定义我们自己的softmax函数,而不是K.softmax.因为K.softmax不能指定轴
def softmax(x, axis=-1):
ex = K.exp(x - K.max(x, axis=axis, keepdims=True))
result = ex / K.sum(ex, axis=axis, keepdims=True)
return result # 定义边缘损失,输入y_true, p_pred,返回分数,传入fit即可
def margin_loss(y_true, y_pred):
lamb, margin = 0.5, 0.1
result = K.sum(y_true * K.square(K.relu(1 - margin -y_pred))
+ lamb * (1-y_true) * K.square(K.relu(y_pred - margin)), axis=-1)
return result class Capsule(Layer):
def __init__(self,
num_capsule,
dim_capsule,
routings=3,
share_weights=True,
activation='squash',
**kwargs):
super(Capsule, self).__init__(**kwargs) # Capsule继承**kwargs参数
self.num_capsule = num_capsule
self.dim_capsule = dim_capsule
self.routings = routings
self.share_weights = share_weights
if activation == 'squash':
self.activation = squash
else:
self.activation = activation.get(activation) # 得到激活函数 # 定义权重
def build(self, input_shape):
input_dim_capsule = input_shape[-1]
if self.share_weights:
# 自定义权重
self.kernel = self.add_weight( #[row,col,channel]->[1,input_dim_capsule,num_capsule*dim_capsule]
name='capsule_kernel',
shape=(1, input_dim_capsule,
self.num_capsule * self.dim_capsule),
initializer='glorot_uniform',
trainable=True)
else:
input_num_capsule = input_shape[-2]
self.kernel = self.add_weight(
name='capsule_kernel',
shape=(input_num_capsule, input_dim_capsule,
self.num_capsule * self.dim_capsule),
initializer='glorot_uniform',
trainable=True)
super(Capsule, self).build(input_shape) # 必须继承Layer的build方法 # 层的功能逻辑(核心)
def call(self, inputs):
if self.share_weights:
#inputs: [batch, input_num_capsule, input_dim_capsule]
#kernel: [1, input_dim_capsule, num_capsule*dim_capsule]
#hat_inputs: [batch, input_num_capsule, num_capsule*dim_capsule]
hat_inputs = K.conv1d(inputs, self.kernel)
else:
hat_inputs = K.local_conv1d(inputs, self.kernel, [1], [1]) batch_size = K.shape(inputs)[0]
input_num_capsule = K.shape(inputs)[1]
hat_inputs = K.reshape(hat_inputs,
(batch_size, input_num_capsule,
self.num_capsule, self.dim_capsule))
#hat_inputs: [batch, input_num_capsule, num_capsule, dim_capsule]
hat_inputs = K.permute_dimensions(hat_inputs, (0, 2, 1, 3))
#hat_inputs: [batch, num_capsule, input_num_capsule, dim_capsule]
b = K.zeros_like(hat_inputs[:, :, :, 0])
#b: [batch, num_capsule, input_num_capsule]
for i in range(self.routings):
c = softmax(b, 1)
o = self.activation(K.batch_dot(c, hat_inputs, [2, 2]))
if K.backend() == 'theano':
o = K.sum(o, axis=1)
if i < self.routings-1:
b += K.batch_dot(o, hat_inputs, [2, 3])
if K.backend() == 'theano':
o = K.sum(o, axis=1)
return o def compute_output_shape(self, input_shape): # 自动推断shape
return (None, self.num_capsule, self.dim_capsule)

注:以上代码来自--> https://github.com/bojone/Capsule

CapsNet.py

from keras import backend as K
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.layers import Input,Conv2D,MaxPooling2D,Reshape,Lambda
from CapsuleLayer import Capsule,margin_loss #载入数据
def read_data(path):
mnist=input_data.read_data_sets(path,one_hot=True)
train_x,train_y=mnist.train.images.reshape(-1,28,28,1),mnist.train.labels,
valid_x,valid_y=mnist.validation.images.reshape(-1,28,28,1),mnist.validation.labels,
test_x,test_y=mnist.test.images.reshape(-1,28,28,1),mnist.test.labels
return train_x,train_y,valid_x,valid_y,test_x,test_y #模型
def MODEL():
inputs=Input(shape=(28, 28, 1))
x=Conv2D(16, (5, 5), padding='same', activation='relu')(inputs)
x=MaxPooling2D(pool_size=(2,2))(x)
x=Conv2D(32, (5, 5), padding='same', activation='relu')(x)
x=MaxPooling2D(pool_size=(2,2))(x)
x=Reshape((-1, 32))(x) # [None, 49, 32] 即前一层胶囊 [None, input_num, input_dim]
x=Capsule(num_capsule=10, dim_capsule=30, routings=5)(x)
output=Lambda(lambda z: K.sqrt(K.sum(K.square(z), axis=2)))(x) #每个胶囊取模长
model=Model(inputs=inputs, output=output)
return model #主函数
def main(train_x,train_y,valid_x,valid_y,test_x,test_y):
model = MODEL()
model.compile(loss=margin_loss, optimizer='adam', metrics=['accuracy'])
model.summary()
model.fit(train_x,train_y,batch_size=500,nb_epoch=20,verbose=2)
pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2)
print('test_loss:',pre[0],'- test_acc:',pre[1]) train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
main(train_x,train_y,valid_x,valid_y,test_x,test_y)

网络各层输出尺寸:

_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 28, 28, 1) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 28, 28, 16) 416
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 16) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 14, 14, 32) 12832
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 7, 7, 32) 0
_________________________________________________________________
reshape_1 (Reshape) (None, 49, 32) 0
_________________________________________________________________
capsule_1 (Capsule) (None, 10, 30) 9600
_________________________________________________________________
lambda_1 (Lambda) (None, 10) 0
=================================================================
Total params: 22,848
Trainable params: 22,848
Non-trainable params: 0

网络训练结果:

Epoch 18/20
- 42s - loss: 0.0247 - acc: 0.9816
Epoch 19/20
- 43s - loss: 0.0239 - acc: 0.9828
Epoch 20/20
- 43s - loss: 0.0234 - acc: 0.9825
test_loss: 0.024427699740044773 - test_acc: 0.9793000012636185

4 扩展阅读

揭开迷雾,来一顿美味的Capsule盛宴

再来一顿贺岁宴:从K-Means到Capsule

​ 声明:本文转自基于keras的胶囊网络(CapsNet)

基于keras的胶囊网络(CapsNet)的更多相关文章

  1. 基于 Keras 用 LSTM 网络做时间序列预测

    目录 基于 Keras 用 LSTM 网络做时间序列预测 问题描述 长短记忆网络 LSTM 网络回归 LSTM 网络回归结合窗口法 基于时间步的 LSTM 网络回归 在批量训练之间保持 LSTM 的记 ...

  2. CapsNet胶囊网络(理解)

    0 - 背景 Geoffrey Hinton是深度学习的开创者之一,反向传播等神经网络经典算法发明人,他在去年年底和他的团队发表了两篇论文,介绍了一种全新的神经网络,这种网络基于一种称为胶囊(caps ...

  3. [深度应用]·首届中国心电智能大赛初赛开源Baseline(基于Keras val_acc: 0.88)

    [深度应用]·首届中国心电智能大赛初赛开源Baseline(基于Keras val_acc: 0.88) 个人主页--> https://xiaosongshine.github.io/ 项目g ...

  4. 胶囊网络 -- Capsule Networks

    胶囊网络是 vector in vector out的结构,最后对每个不同的类别,输出不一个向量,向量的模长表示属于该类别的概率. 例如,在数字识别中,两个数字虽然重叠在一起,Capsule中的两个向 ...

  5. Hinton胶囊网络后最新研究:用“在线蒸馏”训练大规模分布式神经网络

    Hinton胶囊网络后最新研究:用“在线蒸馏”训练大规模分布式神经网络 朱晓霞发表于目标检测和深度学习订阅 457 广告关闭 11.11 智慧上云 云服务器企业新用户优先购,享双11同等价格 立即抢购 ...

  6. 基于keras实现的中文实体识别

    1.简介 NER(Named Entity Recognition,命名实体识别)又称作专名识别,是自然语言处理中常见的一项任务,使用的范围非常广.命名实体通常指的是文本中具有特别意义或者指代性非常强 ...

  7. 寻找研究基于NS2研究覆盖网络的小伙伴:)

    如题,本人菜鸟刚刚入门,想找些基于NS2研究覆盖网络方面的小伙伴,具体点是关于覆盖网络中QoS服务调度方法方面的,有的小伙伴可以留下联系方式,或者加我QQ:245939069  :P:P:P

  8. 开源一个基于nio的java网络程序

    因为最近要从公司离职,害怕用nio写的网络程序没有人能看懂(或许是因为写的不好吧),就调整成了mina(这样大家接触起来非常方便,即使没有socket基础,用起来也不难),所以之前基于nio写的网络程 ...

  9. 基于Android Volley的网络请求工具

    基于Android Volley的网络请求工具. 一.说明 AndroidVolley,Android Volley核心库及扩展工程.AndroidVolleySample,网络请求工具示例工程.Re ...

  10. [AI开发]centOS7.5上基于keras/tensorflow深度学习环境搭建

    这篇文章详细介绍在centOS7.5上搭建基于keras/tensorflow的深度学习环境,该环境可用于实际生产.本人现在非常熟练linux(Ubuntu/centOS/openSUSE).wind ...

随机推荐

  1. DataGrip连接MySql数据库失败:dataGrip java.net.ConnectException: Connection refused: connect.

    1.问题 报错:dataGrip java.net.ConnectException: Connection refused: connect. 详细错误:[08S01] Communications ...

  2. MPC 是下一代私钥安全的7大原因

    PrimiHub一款由密码学专家团队打造的开源隐私计算平台,专注于分享数据安全.密码学.联邦学习.同态加密等隐私计算领域的技术和内容. 多重签名钱包与单一密钥钱包相比,因其提升了资产安全性,如今已成为 ...

  3. TLS 加密套件的学习与了解

    TLS 加密套件的学习与了解 加密套件 什么是加密套件? 加密套件是用于在SSL / TLS握手期间协商安全设置的算法的组合. 在ClientHello和ServerHello消息交换之后,客户端发送 ...

  4. [转帖]Oracle数据库中ITL详解

    首先说明这篇文章是转载的,原文地址:http://blog.sina.com.cn/s/blog_616b428f0100lwvq.html 1.什么是ITL ITL(Interested Trans ...

  5. [转帖]sqluldr2 oracle直接导出数据为文本的小工具使用

    https://www.cnblogs.com/ocp-100/p/11098373.html 近期客户有需求,导出某些审计数据,供审计人进行核查,只能导出成文本或excel格式的进行查看,这里我们使 ...

  6. Nginx 系列 | (转)Nginx 上传文件:client_max_body_size 、client_body_buffer_size

    原文:http://php-note.com/article/detail/488 client_max_body_size client_max_body_size 默认 1M,表示 客户端请求服务 ...

  7. 公司内部自建DNS的办法 使用私有域名的方法

    最近总是有一个需求,需要自己弄一些服务器域名之类的. 修改hosts总是比较麻烦,所以想了一个简单办法, 自己搭建一个dns服务器, 本来想用最简单的 dnsmasq 但是发现总是不成功, 然后找了另 ...

  8. net core部署iis执行此操作时出错web.config

    页面访问会报服务器内部错误,你点对应的IIS下的默认页面或模块会出现下面的错语. 请到官网下载对应的运行时:https://www.microsoft.com/net/download 如果是服务器, ...

  9. 使用svn.externals(外链)提升美术多个svn目录的svn up速度

    svn up多个目录耗时大 svn上的美术资源项目,在打包机上对一个很久没有变化的目录进行svn up也是需要消耗不少时间的,特别打包时需要对多个目录进行svn up,比如空跑54个目录的svn up ...

  10. Python 使用Scapy构造特殊数据包

    Scapy是一款Python库,可用于构建.发送.接收和解析网络数据包.除了实现端口扫描外,它还可以用于实现各种网络安全工具,例如SynFlood攻击,Sockstress攻击,DNS查询攻击,ARP ...