一 、Highway Networks 与 Deep Networks 的关系

深层神经网络相比于浅层神经网络具有更好的效果,在很多方面都已经取得了很好的效果,特别是在图像处理方面已经取得了很大的突破,然而,伴随着深度的增加,深层神经网络存在的问题也就越大,像大家所熟知的梯度消失问题,这也就造成了训练深层神经网络困难的难题。2015年由Rupesh Kumar Srivastava等人受到LSTM门机制的启发提出的网络结构(Highway Networks)很好的解决了训练深层神经网络的难题,Highway Networks 允许信息高速无阻碍的通过深层神经网络的各层,这样有效的减缓了梯度的问题,使深层神经网络不在仅仅具有浅层神经网络的效果。

二、Deep Networks 梯度消失/爆炸(vanishing and exploding gradient)问题

我们先来看一下简单的深层神经网络(仅仅几个隐藏层)

先把各个层的公式写出来

C  = sigmoid(W4 * H3 + b4)
H3 = sigmoid(W3 * H2 + b3)
H2 = sigmoid(W2 * H1 + b2)
H1 = sigmoid(W1 * x + b1)

我们对W1求导:

W = W - lr * g(t)

以上公式仅仅是四个隐层的情况,当隐层的数量达到数十层甚至是数百层的情况下,一层一层的反向传播回去,当权值 < 1的时候,反向传播到某一层之后权值近乎不变,相当于输入x的映射,例如,g(t) =〖0.9〗^100已经是很小很小了,这就造成了只有前面几层能够正常的反向传播,后面的那些隐层仅仅相当于输入x的权重的映射,权重不进行更新。反过来,当权值 > 1的时候,会造成梯度爆炸,同样是仅仅前面的几层能更改正常学习,后面的隐层会变得很大。

三、Highway Networks Formula

  • Notation

    (.) 操作代表的是矩阵按位相乘

    sigmoid函数:

  • Highway Networks formula

    对于我们普通的神经网络,用非线性激活函数H将输入的x转换成y,公式1忽略了bias。但是,H不仅仅局限于激活函数,也采用其他的形式,像convolutional和recurrent。

    对于Highway Networks神经网络,增加了两个非线性转换层,一个是 T(transform gate) 和一个是 C(carry gate),通俗来讲,T表示输入信息经过convolutional或者是recurrent的信息被转换的部分,C表示的是原始输入信息x保留的部分 ,其中 T=sigmoid(wx + b)

    为了计算方便,这里定义了 C = 1 - T

    需要注意的是x,y, H, T的维度必须一致,要想保证其维度一致,可以采用sub-sampling或者zero-padding策略,也可以使用普通的线性层改变维度,使其一致,可以采用几个公式相比,公式3要比公式1灵活的多,可以考虑一下特殊的情况,T= 0的时候,y = x,原始输入信息全部保留,不做任何的改变,T = 1的时候,Y = H,原始信息全部转换,不在保留原始信息,仅仅相当于一个普通的神经网络。

四、Highway BiLSTM Networks 搭建##

pytorch搭建神经网络一般需要继承nn.Module这个类,然后实现里面的forward()函数,搭建Highwany BiLSTM Networks写了两个类,并使用nn.ModuleList将两个类联系起来:

    class HBiLSTM(nn.Module):
def __init__(self, args):
super(HBiLSTM, self).__init__()
......
def forward(self, x):
# 实现Highway BiLSTM Networks的公式
......
    class HBiLSTM_model(nn.Module):
def __init__(self, args):
super(HBiLSTM_model, self).__init__()
......
# args.layer_num_highway 代表Highway BiLSTM Networks有几层
self.highway = nn.ModuleList([HBiLSTM(args) for _ in range(args.layer_num_highway)])
......
def forward(self, x):
......
# 调用HBiLSTM类的forward()函数
for current_layer in self.highway:
x, self.hidden = current_layer(x, self.hidden)

HBiLSTM类的forward()函数里面我们实现Highway BiLSTM Networks的的公式

首先我们先来计算H,上文已经说过,H可以是卷积或者是LSTM,在这里,normal_fc就是我们需要的H

	 x, hidden = self.bilstm(x, hidden)
# torch.transpose是转置操作
normal_fc = torch.transpose(x, 0, 1)

上文提及,x,y,H,T的维度必须保持一致,并且提供了两种策略,这里我们使用一个普通的Linear去转换维度

	source_x = source_x.contiguous()
information_source = source_x.view(source_x.size(0) * source_x.size(1), source_x.size(2))
information_source = self.gate_layer(information_source)
information_source = information_source.view(source_x.size(0), source_x.size(1), information_source.size(1))

也可以采用zero-padding的策略保证维度一致

        # you also can choose the strategy that zero-padding
zeros = torch.zeros(source_x.size(0), source_x.size(1), carry_layer.size(2) - source_x.size(2))
source_x = Variable(torch.cat((zeros, source_x.data), 2))

维度一致之后我们就可以根据我们的公式来写代码了:

	# transformation gate layer in the formula is T
transformation_layer = F.sigmoid(information_source)
# carry gate layer in the formula is C
carry_layer = 1 - transformation_layer
# formula Y = H * T + x * C
allow_transformation = torch.mul(normal_fc, transformation_layer)
allow_carry = torch.mul(information_source, carry_layer)
information_flow = torch.add(allow_transformation, allow_carry)

最后的information_flow就是我们的输出,但是,还需要经过转换维度保证维度一致。

更多的请参考Github: Highway Networks implement in pytorch

五、Highway Networks 实验结果

  • 个人实验结果

    任务:情感分类任务 --- 二分类

    数据规模 :

    分析:从图中可以看出,相同的参数情况下,浅层神经网络相互对比变化不是很明显,5层的神经网络就有了一些变化,准确率相差了一个点左右。由于硬件资源,更加深的深层神经网络还没有测试。 但是从图中也可以发现问题就是伴随深度的加深,Highway Networks的准确率也在下降,深度加深,神经网络的参数也就增加的越多,这就需要重新调节超参数。

  • Paper 实验结果

    分析:从论文的实验结果来看,当深层神经网络的层数能够达到50层甚至100层的时候,loss也能够下降的很快,犹如几层的神经网络一样,与普通的深层神经网络形成了鲜明的对比。

References

Highway Networks的更多相关文章

  1. 基于pytorch实现HighWay Networks之Train Deep Networks

    (一)Highway Networks 与 Deep Networks 的关系 理论实践表明神经网络的深度是至关重要的,深层神经网络在很多方面都已经取得了很好的效果,例如,在1000-class Im ...

  2. Highway Networks Pytorch

    导读 本文讨论了深层神经网络训练困难的原因以及如何使用Highway Networks去解决深层神经网络训练的困难,并且在pytorch上实现了Highway Networks. 一 .Highway ...

  3. 基于pytorch实现HighWay Networks之Highway Networks详解

    (一)简述---承接上文---基于pytorch实现HighWay Networks之Train Deep Networks 上文已经介绍过Highway Netwotrks提出的目的就是解决深层神经 ...

  4. Highway Networks(高速路神经网络)

    Rupesh Kumar Srivastava (邮箱:RUPESH@IDSIA.CH)Klaus Greff (邮箱:KLAUS@IDSIA.CH)J¨ urgen Schmidhuber (邮箱: ...

  5. Paper | Highway Networks

    目录 1. 网络结构 2. 分析 解决的问题:在当时,人们认为 提高深度 是 提高精度 的法宝.但是网络训练也变得很困难.本文旨在解决深度网络训练难的问题,本质是解决梯度问题. 提出的网络:本文提出的 ...

  6. 【论文笔记】Training Very Deep Networks - Highway Networks

    目标: 怎么训练很深的神经网络 然而过深的神经网络会造成各种问题,梯度消失之类的,导致很难训练 作者利用了类似LSTM的方法,通过增加gate来控制transform前和transform后的数据的比 ...

  7. Residual Networks <2015 ICCV, ImageNet 图像分类Top1>

    本文介绍一下2015 ImageNet中分类任务的冠军——MSRA何凯明团队的Residual Networks.实际上,MSRA是今年Imagenet的大赢家,不单在分类任务,MSRA还用resid ...

  8. Highway LSTM 学习笔记

    Highway LSTM 学习笔记 zoerywzhou@gmail.com http://www.cnblogs.com/swje/ 作者:Zhouwan  2016-4-5   声明 1)该Dee ...

  9. Re-thinking Deep Residual Networks

    本文是对ImageNet 2015的冠军ResNet(Deep Residual Networks)以及目前围绕ResNet这个工作研究者后续所发论文的总结,主要涉及到下面5篇论文. 1. Link: ...

随机推荐

  1. 如何删除当前正在使用的SQLLite文件?

    从网上搜索一大堆,套路几乎相同,但自己就是不行,怎么也不行,为什么不行呢?不行的话别人肯定不来坑博友了呀.然后放了一会,去拿下午茶回来,再次来看,恍然大悟,What?这么简单. 一开始代码如下: he ...

  2. 一次FCK拿bc全过程

    和大家简单的弄下fckeditor 漏洞在红客我看到好多人对fck 这个漏洞很干兴趣 其实这个漏洞这的很老了 也非常好利用  我也扫了一点fck的漏洞网址  下面我们就来打开一个我们看看这个一号站平台 ...

  3. 002-Apache Maven 构建生命周期

    Maven - 构建生命周期 什么是构建生命周期 构建生命周期是一组阶段的序列(sequence of phases),每个阶段定义了目标被执行的顺序.这里的阶段是生命周期的一部分. 举例说明,一个典 ...

  4. jsp 使用Common-FileUpload组件文件上传及限制上传类型

    1.将commons-fileupload-1.3.3.jar复制到Web应用的lib文件夹下,在WebRoot目录下创建limit.jsp页面,在该页面中添加一个文件域的表单,设置类型为    mu ...

  5. Java中面向字符的输入流

    Java中面向字符的输入流 2016-12-04 Java程序员联盟 Java程序员联盟 Java程序员联盟 微信号 javalm 功能介绍 莫道君行早,更有早行人 全心敲代码,天道自酬勤 字符流是针 ...

  6. mongodb 的服务启动和基本操作命令

    由于在dos 下操作mongodb 很不方便 所以我推荐大家使用mongodb 的可视化工具robomongo  这个是robomongo的下载网址 https://robomongo.org/dow ...

  7. Tomcat正常启动,访问所有页面均报404异常,404异常总结

    今天遇到一个问题:Tomcat正常启动,访问所有页面均报404异常 404异常,很常见,大多情况是路径错误.web.xml文件映射路径写错.服务器设置.servlet的jar包未导进去或者没有随项目发 ...

  8. 垃圾收集器Serial 、Parallel、CMS、G1

    详见:http://blog.yemou.net/article/query/info/tytfjhfascvhzxcyt378 这里介绍4个垃圾收集器,如果进行了错误的选择将会大大的影响程序的性能. ...

  9. 3.修改第一个程序来点亮LED

    在上一节中已经将驱动程序框架搭建好了 接下来开始写硬件的操作(控制LED): (1)看原理图,确定引脚 (2)看2440手册 (3)写代码(需要使用ioremap()函数映射虚拟地址,在linux中只 ...

  10. IOS学习【xcode 7新特性url链接】

    由于xcode7的更新,在访问http链接的时候会输出错误信息 The resource could not be loaded because the App Transport Security ...