pack_padded_sequence是将句子按照batch优先的原则记录每个句子的词,变化为不定长tensor,方便计算损失函数。

pad_packed_sequence是将pack_padded_sequence生成的结构转化为原先的结构,定长的tensor。

其中test.txt的内容

As they sat in a nice coffee shop,
he was too nervous to say anything and she felt uncomfortable.
Suddenly, he asked the waiter,
"Could you please give me some salt? I'd like to put it in my coffee."

具体参见如下代码

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import wordfreq vocab = {}
token_id = 1
lengths = [] #读取文件,生成词典
with open('test.txt', 'r') as f:
lines=f.readlines()
for line in lines:
tokens = wordfreq.tokenize(line.strip(), 'en')
lengths.append(len(tokens))
#将每个词加入到vocab中,并同时保存对应的index
for word in tokens:
if word not in vocab:
vocab[word] = token_id
token_id += 1 x = np.zeros((len(lengths), max(lengths)))
l_no = 0
#将词转化为数字
with open('test.txt', 'r') as f:
lines = f.readlines()
for line in lines:
tokens = wordfreq.tokenize(line.strip(), 'en')
for i in range(len(tokens)):
x[l_no, i] = vocab[tokens[i]]
l_no += 1 x=torch.Tensor(x)
x = Variable(x)
print(x)
'''
tensor([[ 1., 2., 3., 4., 5., 6., 7., 8., 0., 0., 0., 0., 0., 0.],
[ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 0., 0., 0.],
[20., 9., 21., 22., 23., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34., 4., 7.]])
'''
lengths = torch.Tensor(lengths)
print(lengths)#tensor([ 8., 11., 5., 14.]) _, idx_sort = torch.sort(torch.Tensor(lengths), dim=0, descending=True)
print(_) #tensor([14., 11., 8., 5.])
print(idx_sort)#tensor([3, 1, 0, 2]) lengths = list(lengths[idx_sort])#按下标取元素 [tensor(14.), tensor(11.), tensor(8.), tensor(5.)]
t = x.index_select(0, idx_sort)#按下标取元素
print(t)
'''
tensor([[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34., 4., 7.],
[ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 0., 0., 0.],
[ 1., 2., 3., 4., 5., 6., 7., 8., 0., 0., 0., 0., 0., 0.],
[20., 9., 21., 22., 23., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
'''
x_packed = nn.utils.rnn.pack_padded_sequence(input=t, lengths=lengths, batch_first=True)
print(x_packed)
'''
PackedSequence(data=tensor([24., 9., 1., 20., 25., 10., 2., 9., 26., 11., 3., 21., 27., 12.,
4., 22., 28., 13., 5., 23., 29., 14., 6., 30., 15., 7., 31., 16.,
8., 32., 17., 13., 18., 33., 19., 34., 4., 7.]), batch_sizes=tensor([4, 4, 4, 4, 4, 3, 3, 3, 2, 2, 2, 1, 1, 1]))
''' x_padded = nn.utils.rnn.pad_packed_sequence(x_packed, batch_first=True)#x_padded是tuple
print(x_padded)
'''
(tensor([[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34., 4., 7.],
[ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 0., 0., 0.],
[ 1., 2., 3., 4., 5., 6., 7., 8., 0., 0., 0., 0., 0., 0.],
[20., 9., 21., 22., 23., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), tensor([14, 11, 8, 5]))
'''
#还原tensor
_, idx_unsort = torch.sort(idx_sort)
output = x_padded[0].index_select(0, idx_unsort)
print(output)
'''
tensor([[ 1., 2., 3., 4., 5., 6., 7., 8., 0., 0., 0., 0., 0., 0.],
[ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 0., 0., 0.],
[20., 9., 21., 22., 23., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34., 4., 7.]])
'''

pytorch中的pack_padded_sequence和pad_packed_sequence用法的更多相关文章

  1. [转载]PyTorch中permute的用法

    [转载]PyTorch中permute的用法 来源:https://blog.csdn.net/york1996/article/details/81876886 permute(dims) 将ten ...

  2. Pytorch中randn和rand函数的用法

    Pytorch中randn和rand函数的用法 randn torch.randn(*sizes, out=None) → Tensor 返回一个包含了从标准正态分布中抽取的一组随机数的张量 size ...

  3. Pytorch中nn.Conv2d的用法

    Pytorch中nn.Conv2d的用法 nn.Conv2d是二维卷积方法,相对应的还有一维卷积方法nn.Conv1d,常用于文本数据的处理,而nn.Conv2d一般用于二维图像. 先看一下接口定义: ...

  4. PyTorch中view的用法

    相当于numpy中resize()的功能,但是用法可能不太一样. 我的理解是: 把原先tensor中的数据按照行优先的顺序排成一个一维的数据(这里应该是因为要求地址是连续存储的),然后按照参数组合成其 ...

  5. pytorch中tensorboardX的用法

    在代码中改好存储Log的路径 命令行中输入 tensorboard --logdir /home/huihua/NewDisk1/PycharmProjects/pytorch-deeplab-xce ...

  6. [PyTorch]PyTorch中反卷积的用法

    文章来源:https://www.jianshu.com/p/01577e86e506 pytorch中的 2D 卷积层 和 2D 反卷积层 函数分别如下: class torch.nn.Conv2d ...

  7. pytorch中如何处理RNN输入变长序列padding

    一.为什么RNN需要处理变长输入 假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示: 思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练 ...

  8. pytorch中如何使用DataLoader对数据集进行批处理

    最近搞了搞minist手写数据集的神经网络搭建,一个数据集里面很多个数据,不能一次喂入,所以需要分成一小块一小块喂入搭建好的网络. pytorch中有很方便的dataloader函数来方便我们进行批处 ...

  9. PyTorch中使用深度学习(CNN和LSTM)的自动图像标题

    介绍 深度学习现在是一个非常猖獗的领域 - 有如此多的应用程序日复一日地出现.深入了解深度学习的最佳方法是亲自动手.尽可能多地参与项目,并尝试自己完成.这将帮助您更深入地掌握主题,并帮助您成为更好的深 ...

随机推荐

  1. Python爬虫根据关键词爬取知网论文摘要并保存到数据库中【入门必学】

    前言 本文的文字及图片来源于网络,仅供学习.交流使用,不具有任何商业用途,版权归原作者所有,如有问题请及时联系我们以作处理. 作者:崩坏的芝麻 由于实验室需要一些语料做研究,语料要求是知网上的论文摘要 ...

  2. windows下安装ssdb

    官方下载 http://ssdb.io/docs/install.html 这是官方网站 官方建议 Do not run SSDB server on Windows system for a pro ...

  3. Java Properties 加载

    static{ Properties prop = new Properties(); prop.load(Thread.currentThread().getContextClassLoader() ...

  4. WinForm 自定义控件 - RooF

    由3个标签组成 直接代码 public partial class Roof : UserControl { public Roof() { InitializeComponent(); } priv ...

  5. Docker下载tomcat

    命令 下载tomcat docker pull tomcat //默认是latest版本具体可以到 hub.docker.com上查询 //如果想下其他版本以9.0.16示例那么: docker pu ...

  6. hdu 1052 Tian Ji -- The Horse Racing (田忌赛马)

    Tian Ji -- The Horse Racing Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 65536/32768 K (J ...

  7. GitHub 上的 12306 抢票神器,助力回家过年

    又到周末了,不过本周末有些略微的特殊. 距离每年一次的全球最大规模的人类大迁徙活动已经只剩下一个多月了,各位在外工作一年的小伙伴大多数人又要和小编一样摩拳擦掌的对待史上最难抢的抢票活动. 然鹅,身为一 ...

  8. 《Java基础知识》Java技术总结

    1. Java 知识点总结 Java标示符.保留字和数制:https://www.cnblogs.com/jssj/p/11114041.html Java数据类型以及变量的定义:https://ww ...

  9. VS2019 开发Django(八)------视图

    导航:VS2019开发Django系列 这几天学习了一下Django的视图和模板,从这几天的学习进度来看,视图这里并没有花很多的时间,相反的,模板花费了大量的时间,主要原因还是因为对Jquery操作d ...

  10. python多线程编程-queue模块和生产者-消费者问题

    摘录python核心编程 本例中演示生产者-消费者模型:商品或服务的生产者生产商品,然后将其放到类似队列的数据结构中.生产商品中的时间是不确定的,同样消费者消费商品的时间也是不确定的. 使用queue ...