导读

本文讨论了深层神经网络训练困难的原因以及如何使用Highway Networks去解决深层神经网络训练的困难,并且在pytorch上实现了Highway Networks。

一 、Highway Networks 与 Deep Networks 的关系

深层神经网络相比于浅层神经网络具有更好的效果,在很多方面都已经取得了很好的效果,特别是在图像处理方面已经取得了很大的突破,然而,伴随着深度的增加,深层神经网络存在的问题也就越大,像大家所熟知的梯度消失问题,这也就造成了训练深层神经网络困难的难题。2015年由Rupesh Kumar Srivastava等人受到LSTM门机制的启发提出的网络结构(Highway Networks)很好的解决了训练深层神经网络的难题,Highway Networks 允许信息高速无阻碍的通过深层神经网络的各层,这样有效的减缓了梯度的问题,使深层神经网络不在仅仅具有浅层神经网络的效果。

二、Deep Networks 梯度消失/爆炸(vanishing and exploding gradient)问题

我们先来看一下简单的深层神经网络(仅仅几个隐层)

先把各个层的公式写出来

我们对W1求导:

W = W - lr * g(t)

以上公式仅仅是四个隐层的情况,当隐层的数量达到数十层甚至是数百层的情况下,一层一层的反向传播回去,当权值 < 1的时候,反向传播到某一层之后权值近乎不变,相当于输入x的映射,例如,g(t) =〖0.9〗^100已经是很小很小了,这就造成了只有前面几层能够正常的反向传播,后面的那些隐层仅仅相当于输入x的权重的映射,权重不进行更新。反过来,当权值 > 1的时候,会造成梯度爆炸,同样是仅仅前面的几层能更改正常学习,后面的隐层会变得很大。

三、Highway Networks Formula

  • Notation

    (.) 操作代表的是矩阵按位相乘

    sigmoid函数:

  • Highway Networks formula

    对于我们普通的神经网络,用非线性激活函数H将输入的x转换成y,公式1忽略了bias。但是,H不仅仅局限于激活函数,也采用其他的形式,像convolutional和recurrent。

    对于Highway Networks神经网络,增加了两个非线性转换层,一个是 T(transform gate) 和一个是 C(carry gate),通俗来讲,T表示输入信息经过convolutional或者是recurrent的信息被转换的部分,C表示的是原始输入信息x保留的部分 ,其中 T=sigmoid(wx + b)

    为了计算方便,这里定义了 C = 1 - T

    需要注意的是x,y, H, T的维度必须一致,要想保证其维度一致,可以采用sub-sampling或者zero-padding策略,也可以使用普通的线性层改变维度,使其一致。

    几个公式相比,公式3要比公式1灵活的多,可以考虑一下特殊的情况,T= 0的时候,y = x,原始输入信息全部保留,不做任何的改变,T = 1的时候,Y = H,原始信息全部转换,不在保留原始信息,仅仅相当于一个普通的神经网络。

四、Highway BiLSTM Networks

  • Highway BiLSTM Networks Structure Diagram

    下图是 Highway BiLSTM Networks 结构图:

    input:代表输入的词向量

    B:在本任务代表bidirection lstm,代表公式(2)中的 H

    T:代表公式(2)中的 T,是Highway Networks中的transform gate

    C:代表公式(2)中的 C,是Highway Networks中的carry gate

    Layer = n,代表Highway Networks中的第n层

    Highway:框出来的代表一层Highway Networks

    在这个结构图中,Highway Networks第 n - 1 层的输出作为第n层的输入

  • Highway BiLSTM Networks Demo

    pytorch搭建神经网络一般需要继承nn.Module这个类,然后实现里面的forward()函数,搭建Highway BiLSTM Networks写了两个类,并使用nn.ModuleList将两个类联系起来:

      class HBiLSTM(nn.Module):
    def __init__(self, args):
    super(HBiLSTM, self).__init__()
    ......
    def forward(self, x):
    # 实现Highway BiLSTM Networks的公式
    ......
      class HBiLSTM_model(nn.Module):
    def __init__(self, args):
    super(HBiLSTM_model, self).__init__()
    ......
    # args.layer_num_highway 代表Highway BiLSTM Networks有几层
    self.highway = nn.ModuleList([HBiLSTM(args) for _ in range(args.layer_num_highway)])
    ......
    def forward(self, x):
    ......
    # 调用HBiLSTM类的forward()函数
    for current_layer in self.highway:
    x, self.hidden = current_layer(x, self.hidden)

    HBiLSTM类的forward()函数里面我们实现Highway BiLSTM Networks的的公式

    首先我们先来计算H,上文已经说过,H可以是卷积或者是LSTM,在这里,normal_fc就是我们需要的H

       x, hidden = self.bilstm(x, hidden)
    # torch.transpose是转置操作
    normal_fc = torch.transpose(x, 0, 1)

    上文提及,x,y,H,T的维度必须保持一致,并且提供了两种策略,这里我们使用一个普通的Linear去转换维度

      source_x = source_x.contiguous()
    information_source = source_x.view(source_x.size(0) * source_x.size(1), source_x.size(2))
    information_source = self.gate_layer(information_source)
    information_source = information_source.view(source_x.size(0), source_x.size(1), information_source.size(1))

    也可以采用zero-padding的策略保证维度一致

      # you also can choose the strategy that zero-padding
    zeros = torch.zeros(source_x.size(0), source_x.size(1), carry_layer.size(2) - source_x.size(2))
    source_x = Variable(torch.cat((zeros, source_x.data), 2))

    维度一致之后我们就可以根据我们的公式来写代码了:

      # transformation gate layer in the formula is T
    transformation_layer = F.sigmoid(information_source)
    # carry gate layer in the formula is C
    carry_layer = 1 - transformation_layer
    # formula Y = H * T + x * C
    allow_transformation = torch.mul(normal_fc, transformation_layer)
    allow_carry = torch.mul(information_source, carry_layer)
    information_flow = torch.add(allow_transformation, allow_carry)

    最后的information_flow就是我们的输出,但是,还需要经过转换维度保证维度一致。

    更多的请参考Github: Highway Networks implement in pytorch

五、Highway BiLSTM Networks 实验结果

本次实验任务是使用Highway BiLSTM Networks 完成情感分类任务(一句话的态度分成积极或者是消极),数据来源于Twitter情感分类数据集,以下是数据集中的各个标签的句子个数:

下图是本次实验任务在2-class数据集中的测试结果。图中1-300在Highway BiLSTM Networks中表示Layer = 1,BiLSTM 隐层的维度是300维。

实验结果:从图中可以看出,简单的多层双向LSTM并没有带来情感分析性能的提升,尤其是是到了10层之后,效果有不如随机的猜测。当用上Highway Networks之后,虽然性能也在逐步的下降,但是下降的幅度有了明显的改善。

References

Highway Networks Pytorch的更多相关文章

  1. 基于pytorch实现HighWay Networks之Highway Networks详解

    (一)简述---承接上文---基于pytorch实现HighWay Networks之Train Deep Networks 上文已经介绍过Highway Netwotrks提出的目的就是解决深层神经 ...

  2. 基于pytorch实现HighWay Networks之Train Deep Networks

    (一)Highway Networks 与 Deep Networks 的关系 理论实践表明神经网络的深度是至关重要的,深层神经网络在很多方面都已经取得了很好的效果,例如,在1000-class Im ...

  3. Highway Networks

    一 .Highway Networks 与 Deep Networks 的关系 深层神经网络相比于浅层神经网络具有更好的效果,在很多方面都已经取得了很好的效果,特别是在图像处理方面已经取得了很大的突破 ...

  4. Highway Networks(高速路神经网络)

    Rupesh Kumar Srivastava (邮箱:RUPESH@IDSIA.CH)Klaus Greff (邮箱:KLAUS@IDSIA.CH)J¨ urgen Schmidhuber (邮箱: ...

  5. Paper | Highway Networks

    目录 1. 网络结构 2. 分析 解决的问题:在当时,人们认为 提高深度 是 提高精度 的法宝.但是网络训练也变得很困难.本文旨在解决深度网络训练难的问题,本质是解决梯度问题. 提出的网络:本文提出的 ...

  6. 【论文笔记】Training Very Deep Networks - Highway Networks

    目标: 怎么训练很深的神经网络 然而过深的神经网络会造成各种问题,梯度消失之类的,导致很难训练 作者利用了类似LSTM的方法,通过增加gate来控制transform前和transform后的数据的比 ...

  7. (转)Awesome PyTorch List

    Awesome-Pytorch-list 2018-08-10 09:25:16 This blog is copied from: https://github.com/Epsilon-Lee/Aw ...

  8. DenseNet算法详解——思路就是highway,DneseNet在训练时十分消耗内存

    论文笔记:Densely Connected Convolutional Networks(DenseNet模型详解) 2017年09月28日 11:58:49 阅读数:1814 [ 转载自http: ...

  9. Residual Networks <2015 ICCV, ImageNet 图像分类Top1>

    本文介绍一下2015 ImageNet中分类任务的冠军——MSRA何凯明团队的Residual Networks.实际上,MSRA是今年Imagenet的大赢家,不单在分类任务,MSRA还用resid ...

随机推荐

  1. 利用JavaScript来切换样式表

    切换样式表 html页 <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http:/ ...

  2. Android studio一些常见技巧(不断更新)

    一.Android studio取消默认每次打开时打开最后一个项目 二.as添加jar包 新建一个libs目录,在java下 进行手动gragle同步或者如下图 三.代码边上的小图标 1.布局文件中存 ...

  3. 关于jsp页面转换成excel格式下载遇到问题及解决

    jsp页面转成excel格式的实现思路: 1.使用poi包:poi-bin-3.9-20121203 下载连接地址:http://www.apache.org/dyn/closer.cgi/poi/r ...

  4. 2017最新的Python教程分享

    Python在数据科学盛行的今天,其易于阅读和编写的特点,越来越受编程者追捧.在IEEE发布的2017年编程语言排行榜中,Python也高居首位.如果你有学Python的计划,快来看看小编分享的Pyt ...

  5. vim代码粘贴缩进混乱的问题[Linux]

    详见: http://blog.yemou.net/article/query/info/tytfjhfascvhzxcytp76   直接在vim插入模式下粘贴: 直接粘贴,剪贴板上的每个字符都相当 ...

  6. poj 2723 二分+2-sat判定

    题意:给出n对钥匙,每对钥匙只能选其中一个,在给出每层门需要的两个钥匙,只要一个钥匙就能开门,问最多能到哪层. 思路:了解了2-SAT判定的问题之后主要就是建图的问题了,这里建图就是对于2*n个钥匙, ...

  7. docker在CentOS7下部署指南

    docker只支持CentOS7.x系统,所以近期根据docker官网指南自己搭建了一套,供大家参考. 1.部署Centos7.x系统,查看系统版本. 2.执行 sudo yum update 更新到 ...

  8. 详解 mpls vpn 的实现

    MPLS VPN的实现 一.实验目的 该实验通过MPLS VPN的数据配置,使学生掌握路由器相关接口的IP地址设置.路由协议的配置以及MPLS VPN的完整的创建过程, 从而加深对IP网络的IP编址. ...

  9. 24点游戏详细截图介绍以及原型、Alpha、Beta对比

    原型设计 图片展示 功能与界面设计 1.登录注册 2.手机号验证 3.24点游戏 4.粉色系女生界面 Alpha 图片展示 功能与界面设计 1.24点游戏 2.背景音乐 3.可查看多种可能的答案 4. ...

  10. 201521123017 《Java程序设计》第3周学习总结

    1. 本周学习总结 2. 书面作业 Q1.代码阅读 public class Test1 { private int i = 1;//这行不能修改 private static int j = 2; ...