torch.nn.utils.rnn:

pack_padded_sequence() pad_packed_sequence()

Notice:

  • The padded embedding metrix must be sorted by the ground length of each sentence.
  • The parameter "batch_first=True" controls the first demension of the embedding metrix is the batch_size (\(B\times T \times W_{emb}\)). Otherwise, the first demension is the length of the longest sentence (\(T\times B \times W_{emb}\))
  • The two functions will inflence the padding word, especially on bidirectional RNN (the backward rnn will go through some padding words first and the forward rnn will go through some padding words last).

Examplt

a = torch.FloatTensor([[[2,3,4], [2, 3,1]], [[2,3,4], [2, 3,1]], [[2,4,5], [0, 0, 0]]])

packed_x = pack_padded_sequence(a, [2, 2, 1], batch_first=True)
# a[0,0,:] = torch.FloatTensor([0,0,0]) h0 = Variable(torch.randn(1, 3, 4))
lstm = nn.LSTM(3, 4,num_layers=1,batch_first=True,bidirectional=True) rnn_packed, (h_last, c_last) = lstm(packed_x) rnn_out, length = pad_packed_sequence(rnn_packed, batch_first=True)

Nan

Mask

When using mask on rnn outputs with \(-\infty\), it will cause the grad to be nan when back-propogating. But it's right when appling to attention. The \(softmax\) operation does not cause the nan problem.

Pytorch的更多相关文章

  1. Ubutnu16.04安装pytorch

    1.下载Anaconda3 首先需要去Anaconda官网下载最新版本Anaconda3(https://www.continuum.io/downloads),我下载是是带有python3.6的An ...

  2. 解决运行pytorch程序多线程问题

    当我使用pycharm运行  (https://github.com/Joyce94/cnn-text-classification-pytorch )  pytorch程序的时候,在Linux服务器 ...

  3. 基于pytorch实现word2vec

    一.介绍 word2vec是Google于2013年推出的开源的获取词向量word2vec的工具包.它包括了一组用于word embedding的模型,这些模型通常都是用浅层(两层)神经网络训练词向量 ...

  4. 基于pytorch的CNN、LSTM神经网络模型调参小结

    (Demo) 这是最近两个月来的一个小总结,实现的demo已经上传github,里面包含了CNN.LSTM.BiLSTM.GRU以及CNN与LSTM.BiLSTM的结合还有多层多通道CNN.LSTM. ...

  5. pytorch实现VAE

    一.VAE的具体结构 二.VAE的pytorch实现 1加载并规范化MNIST import相关类: from __future__ import print_function import argp ...

  6. PyTorch教程之Training a classifier

    我们已经了解了如何定义神经网络,计算损失并对网络的权重进行更新. 接下来的问题就是: 一.What about data? 通常处理图像.文本.音频或视频数据时,可以使用标准的python包将数据加载 ...

  7. PyTorch教程之Neural Networks

    我们可以通过torch.nn package构建神经网络. 现在我们已经了解了autograd,nn基于autograd来定义模型并对他们有所区分. 一个 nn.Module模块由如下部分构成:若干层 ...

  8. PyTorch教程之Autograd

    在PyTorch中,autograd是所有神经网络的核心内容,为Tensor所有操作提供自动求导方法. 它是一个按运行方式定义的框架,这意味着backprop是由代码的运行方式定义的. 一.Varia ...

  9. Linux安装pytorch的具体过程以及其中出现问题的解决办法

    1.安装Anaconda 安装步骤参考了官网的说明:https://docs.anaconda.com/anaconda/install/linux.html 具体步骤如下: 首先,在官网下载地址 h ...

  10. Highway Networks Pytorch

    导读 本文讨论了深层神经网络训练困难的原因以及如何使用Highway Networks去解决深层神经网络训练的困难,并且在pytorch上实现了Highway Networks. 一 .Highway ...

随机推荐

  1. H5_0003:JS禁用调试,禁用右键,监听F12事件的方法

    1,禁用调试 // 这个方法是防止恶意调试的 (function () { console["log"]("=============================== ...

  2. “不能在dropdownlist中选择多个项

    DropDownList.ClearSelection(); DropDownList.SelectedItem.Text = "value值";

  3. js关联数组

    标准javascript 是内含支持hash关联数组,经查找资料并测试,有关标准javascript内含的hash关联数组操作备忘如下 1.Hash关联数组定义 // 定义空数组 myhash = { ...

  4. Mysql查看登录用户以及修改密码和创建用户以及授权(转载)

    本文转自(https://www.cnblogs.com/manzb/p/6491924.html) 1.mysql查看当前登录用户,当前数据库: select user(); select data ...

  5. JAVA的抽象类和接口

    抽象类 在面向对象的概念中,所有的对象都是通过类来描述的,但是反过来,并不是所有的类都是用来描述对象的,如果一个类中没有包含足够的信息来描绘一个具体的对象,这样的类就是抽象类. 抽象类除了不能实例化对 ...

  6. CentOS7离线安装mysql5.7

    下载mysql5.7,系统选择redhat,版本选择RHEL7,下载RPM Bundle后得到一个tar文件.这里得到文件mysql-5.7.25-1.el7.x86_64.rpm-bundle.ta ...

  7. js较深入的知识点

    浏览器渲染过程是怎样的?重绘重排是什么?如何避免过多的重绘重排? 将html解析为dom树; 将css解析为cssom; 结合DOM树和CSSOM树,生成一棵渲染树(Render Tree); 生成布 ...

  8. shiro教程

    ref https://www.jianshu.com/p/5a35d0100a71 https://www.jianshu.com/p/0366a1675bb6 https://blog.csdn. ...

  9. 文件比较与同步工具——FreeFileSync

    1. 基本介绍 FreeFileSync是一个用于文件同步的免费开源程序.FreeFileSync通过比较其内容,日期或文件大小上的一个或多个文件夹,然后根据用户定义的设置同步内容.除了支持本地文件系 ...

  10. 05mycat父子表

    表连接的难题在mycat中是不允许跨分片做表连接查询的 创建t_orders表 create table t_orders( id int PRIMARY key, customer_id int n ...