以CapsNet为例谈深度学习源码阅读
本文的参考的github工程链接:https://github.com/laubonghaudoi/CapsNet_guide_PyTorch
之前是看过一些深度学习的代码,但是没有养成良好的阅读规范,由于最近在学习CapsNet的原理,在Github找到了一个很好的示例教程,作者甚至给出了比较好的代码阅读顺序,私以为该顺序具有较强的代码阅读迁移性,遂以此工程为例将该代码分析过程记录于此:
1、代码先看main(),main()为工程中最为顶层的设计,能够给人对于整个流程的把控。而对于深度学习而言,main一般即为加载数据、构建模型、确定优化算法、训练网络模型、保存模型参数这种很具有规范性的结构。
if __name__ == "__main__":
# Default configurations
opt = get_opts()
train_loader, test_loader = get_dataloader(opt) # Initialize CapsNet
model = CapsNet(opt) # Enable GPU usage
if opt.use_cuda & torch.cuda.is_available():
model.cuda() # Print the model architecture and parameters
print("Model architectures: ")
print(model) print("\nSizes of parameters: ")
for name, param in model.named_parameters():
print("{}: {}".format(name, list(param.size())))
n_params = sum([p.nelement() for p in model.parameters()])
# The coupling coefficients b_ij are not included in the parameter list,
# we need to add them mannually, which is 1152 * 10 = 11520.
print('\nTotal number of parameters: %d \n' % (n_params+11520)) # Make model checkpoint directory
if not os.path.exists('ckpt'):
os.makedirs('ckpt') # Start training
train(opt, train_loader, test_loader, model, writer)
2、后面看utils.py文件里面的函数,很多比较复杂的工程中都会有这个文件,一般都是一些工程中较为基础的函数,在CapsNet这个工程中,这个文件中包含了相关的配置以及dataloarder。
def get_dataloader(opt):
# MNIST Dataset ... # Data Loader (Input Pipeline) ... return train_loader, test_loader def get_opts():
parser = argparse.ArgumentParser(description='CapsNet')
# ....
opt = parser.parse_args() return opt
3、然后在弄明白前向传播中最为顶层的设计,一般就是顶层神经网络的__init__()以及forward()
该工程中的CapsNet主要分为四个大部分:
- Conv2d, 用了256个 9×9的卷积核,步长为1,后面跟着Relu。这样对于28*28的图片,输出为[256,20,20 ]
- PrimaryCaps: capsule层,具体构造后面再讲
- DigitCaps:capsule层,具体构造后面再讲
- Decoder:全连接层
4、在网络前向传播的顶层肯定调用了一些层级稍微低一些的module,下面就看这些module,本工程中主要是PrimaryCaps和DigitCaps。
PrimaryCaps
PrimaryCaps包含了32个 capsule units, 每个capsule unit都会接收来自于第一层卷积所输出的feature map的所有数据。首先获得32个张量u,这32个张量u是通过32个卷积运算得到的,前面输入的为第一层卷积所得[256,20,20 ]的feature maps,32个卷积每个都是(out_channels=8, kernel_size=9, stride=2),这个地方使用了Modulelist来构造重复的卷积运算module,值得学习。在forward中将每个卷积moduel计算所得的结果append到list中,这样后面使用torch.cat的时候可以直接使用了。问题在于后面对于这32个张量的维度顺序做了变换。
坐标顺序变换记录于此:
- 每个conv_module输出为[batch_size, 8 ,6,6],便变成了[batch_size, 8 ,36, 1]的形式,也就是这8个feature map中的每个6×6的feature map变成了一个向量
- 对32个conv_module输出的张量cat,保存形式为[batch_size, 8, 36, 32]
- 再次变换为[batch_size,8,36×32] ,这个地方我并没有搞懂这么做有什么意义,这和直接拿32*8个卷积核去卷积的区别在哪呢?直接拿32个卷积核卷积,然后将这32*8个卷积核再分为8组不也一样吗?
- 做了一次维度变换,变为[batch_size, 36×32,8]的形式
上步计算完成后,后面计算squash,这步计算类似于Relu,相当于向量的Relu操作。这个地方可以看出一个很重要的一点,就是向量v是几维的,一个基本的v包含几个数,从代码中看是8个数,也就是说PrimaryCaps开始时的每个卷积module输出的channels数为8,是这个维度组成了向量。
DigitCaps
这一层和上一层都是由capsule组成的,中间的连接是类似于全连接但又有很多的不同。
下面的表示均忽略batch_size:
上一层的输入[36*32,8], 也就是有36*32个输入向量u。计算步骤如下:
- 首先计算u_hat,将输入变换为[36*32,1,1,8]的形式,中间权重为[36*32, 10, 8, 16],这样矩阵相乘的结果为[36*32, 10, 1, 16], 此处的16应该就是输出向量的维度
- 后面的处理与10这个维度有关系,在图中就是c_ij,需要构造的c_ij的数量为[36*32, 10,1],在一次整个网路的前向传播过程中,c_ij的初始值为0,会在一次前向传播过程中内部迭代几次,叫做动态路由算法。如下图所示:
- u_hat的维度为 [ 36*32, 10, 16],s的维度为[10, 16],v的维度为[10,16],这中间有将36*32个数相加的过程,更新c_ij是这样的:先将v变为[1,10,16],再计算u_hat*v得到[36*32, 10, 16],将里层维度相加,急求的是向量相乘,就会有方向的信息。由此更新c_ij
(注:该图来自于https://blog.csdn.net/wc781708249/article/details/80015997)
Decoder:
Decoder 部分是由三层全连接层组成的。这部分是一个重构部分,希望借此部分重新构建出图片。(有点像自编码器)
下面的维度忽略batch_size。
前面输出的是[10,16], 这个地方是将10个16维向量中与target中1相对的那个16维的向量取出作为后面全连接层的输入,后面全连接的维度为16,512,1024,784。 784即28*28。
5、损失函数
损失函数主要包括两部分,一部分是DigitCaps输出的loss,一部分是Decorder的loss。
DigitCaps层的输出是10个16维向量:
计算时,先根据上式计算出每个向量的损失值,然后将10个损失值相加得到最终损失。每个训练样本都有正确的标签,在这种情况下,标签将是一个10维one-hot编码向量。假设正确的标签是1,这意味着第一个DigitCap负责编码数字1的存在。这一DigitCap的损失函数的Tc为1,其余9个DigitCap的Tc为0。当Tc为1时,损失函数的第二项为零,损失函数的值通过第一项计算。在我们的例子中,为了计算第一个DigitCap的损失,我们从m+减去这一DigitCap的输出向量,其中,m+取固定值0.9。接着,我们保留所得值(仅当所得值大于零时)并取平方。否则,返回0。换句话说,当正确DigitCap预测正确标签的概率大于0.9时,损失函数为零,当概率小于0.9时,损失函数不为零。
公式包括了一个lambda系数以确保训练中的数值稳定性(lambda为固定值0.5)。这两项取平方是为了让损失函数符合L2正则,看起来作者们认为这样正则化一下效果更好。
对于Decoder的loss,loss就是求得输入的Image与Decorder输出的784个数的欧式距离平方和。
对于CapsNet的基本原理,该博客给出了比较好的解释:http://www.cnblogs.com/CZiFan/p/9803067.html
以CapsNet为例谈深度学习源码阅读的更多相关文章
- NLP大赛冠军总结:300万知乎多标签文本分类任务(附深度学习源码)
NLP大赛冠军总结:300万知乎多标签文本分类任务(附深度学习源码) 七月,酷暑难耐,认识的几位同学参加知乎看山杯,均取得不错的排名.当时天池AI医疗大赛初赛结束,官方正在为复赛进行平台调 ...
- 源码阅读经验谈-slim,darknet,labelimg,caffe(1)
本文首先谈自己的源码阅读体验,然后给几个案例解读,选的例子都是比较简单.重在说明我琢磨的点线面源码阅读方法.我不是专业架构师,是从一个深度学习算法工程师的角度来谈的,不专业的地方请大家轻拍. 经常看别 ...
- fw: 专访许鹏:谈C程序员修养及大型项目源码阅读与学习
C家最近也有一篇关于如何阅读大型c项目源代码的文章,学习..融合.. -------------------- ref:http://www.csdn.net/article/2014-06-05 ...
- 深度学习(七十一)darknet 源码阅读
深度学习(七十一)darknet 源码阅读
- 转:浅谈深度学习(Deep Learning)的基本思想和方法
浅谈深度学习(Deep Learning)的基本思想和方法 参考:http://blog.csdn.net/xianlingmao/article/details/8478562 深度学习(Deep ...
- 【 js 基础 】【 源码学习 】backbone 源码阅读(三)浅谈 REST 和 CRUD
最近看完了 backbone.js 的源码,这里对于源码的细节就不再赘述了,大家可以 star 我的源码阅读项目(https://github.com/JiayiLi/source-code-stud ...
- 学习源码的第八个月,我成了Spring的开源贡献者
@ 目录 我的经历 碰到的问题 1.担心闹乌龙 2.不知道要怎么提交 3.英文 4.担心问题描述的不清楚 给你的建议 我的经历 关注我的朋友都知道,关注两个字划重点,要考! 我最近一直在写Spring ...
- Java源码阅读的真实体会(一种学习思路)
Java源码阅读的真实体会(一种学习思路) 刚才在论坛不经意间,看到有关源码阅读的帖子.回想自己前几年,阅读源码那种兴奋和成就感(1),不禁又有一种激动. 源码阅读,我觉得最核心有三点:技术基础+强烈 ...
- Java源码阅读的真实体会(一种学习思路)【转】
Java源码阅读的真实体会(一种学习思路) 刚才在论坛不经意间,看到有关源码阅读的帖子.回想自己前几年,阅读源码那种兴奋和成就感(1),不禁又有一种激动. 源码阅读,我觉得最核心有三点:技术基础+ ...
随机推荐
- Python3基础 list append 向尾部添加一个元素
Python : 3.7.0 OS : Ubuntu 18.04.1 LTS IDE : PyCharm 2018.2.4 Conda ...
- Spring整合Redis&JSON序列化&Spring/Web项目部署相关
几种JSON框架用法和效率对比: https://blog.csdn.net/sisyphus_z/article/details/53333925 https://blog.csdn.net/wei ...
- 题解——POJ 2234 Matches Game
这道题也是一个博弈论 根据一个性质 对于\( Nim \)游戏,即双方可以任取石子的游戏,\( SG(x) = x \) 所以直接读入后异或起来输出就好了 代码 #include <cstdio ...
- 【转载】大连商品交易所-新套利撮合算法FAQ
原文网址:http://www.dce.com.cn/dalianshangpin/yw/fw/ywzy/jyywzy/498201/1500371/index.html 大连商品交易所 新套利撮 ...
- sublime 代码段
demo 展示助手中有经常用到个标签. <textarea type="text/md_x" style="display:none"> ## de ...
- 在C#中理解和实现策略模式的绝对入门教程
介绍 本文的目的是理解战略模式的基础知识,并试图了解何时可以使用,并有一个基本的实现,以便更好地理解.在现实世界的应用中,这是无法实施战略模式的,所采用的例子也远没有实际可行.这篇文章的想法只是为了说 ...
- HDU 3401 Trade(斜率优化dp)
http://acm.hdu.edu.cn/showproblem.php?pid=3401 题意:有一个股市,现在有T天让你炒股,在第i天,买进股票的价格为APi,卖出股票的价格为BPi,同时最多买 ...
- openlayers空间点查询之GetFeatureInfo
在map对象上注册点击方法监听, 这里我用的是wms,当然你也可以查询wfs map.events.register('click', map, function (e) { ...
- latex建立参考文献的超链接
在Latex生成的pdf文档中建立超链接(如从正文到参考文献,从目录到相应内容,从页码编号到实际页面等),有利于读者快速定位当前阅读的信息. 如何在生成的pdf文件中包含超链接呢?需要注意一下两点: ...
- Spring-JDBC依赖
<dependency> <groupId>mysql</groupId> <artifactId>mysql-connector-java</a ...