pytorch中的词向量的使用

在pytorch我们使用nn.embedding进行词嵌入的工作。

具体用法就是:

import torch
word_to_ix={'hello':0,'world':1}
embeds = torch.nn.Embedding(2,5)
hello_idx=torch.LongTensor([word_to_ix['hello']])
hello_embed = embeds(hello_idx)
print(hello_embed)
print(embeds.weight) tensor([[ 0.6584, 0.2991, -1.2654, 0.9369, 0.6088]], grad_fn=<EmbeddingBackward>) Parameter containing:
tensor([[ 0.6584, 0.2991, -1.2654, 0.9369, 0.6088],
[ 0.1922, 1.5374, 0.5737, -0.8007, -0.4896]], requires_grad=True)

在torch.nn.Embedding的源代码中,它是这么解释,

This module is often used to store word embeddings and retrieve them using indices.

The input to the module is a list of indices, and the output is the corresponding

word embeddings.

对于这个,我的理解是这样的torch.nn.Embedding 是一个矩阵类,当我传入参数之后,我可以得到一个矩阵对象,比如上面代码中的

embeds = torch.nn.Embedding(2,5) 通过这个代码,我就获得了一个两行三列的矩阵对象embeds。这个时候,矩阵对象embeds的输入就是一个索引列表(当然这个列表

应该是longtensor格式,得到的结果就是对应索引的词向量)

我们这里有一点需要格外注意,在上面的结果中,有个这个东西 requires_grad=True

我在开始接触pytorch的时候,对embedding的一个疑惑就是它是如何定义自动更新的。因为现在我们得到的这个词向量是随机初始化的结果,

在后续神经网络反向传递过程中,这个参数是需要更新的。

这里我想要点出一点来,就是词向量在这里是使用标准正态分布进行的初始化。我们可以通过查看源代码来进行验证。

在源代码中

if _weight is None:
self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim)) ##定义一个Parameter对象
self.reset_parameters() #随后对这个对象进行初始化
...
... def reset_parameters(self): #标准正态进行初始化
init.normal_(self.weight)
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)

pytorch中词向量生成的原理的更多相关文章

  1. Pytorch中的自动求导函数backward()所需参数含义

    摘要:一个神经网络有N个样本,经过这个网络把N个样本分为M类,那么此时backward参数的维度应该是[N X M] 正常来说backward()函数是要传入参数的,一直没弄明白backward需要传 ...

  2. 新手如何入门pytorch?

    我最近的文章中,专门为想学Pytorch的新手推荐了一些学习资源,包括教程.视频.项目.论文和书籍.希望能对你有帮助:一.PyTorch学习教程.手册 (1)PyTorch英文版官方手册:https: ...

  3. 新手必备 | 史上最全的PyTorch学习资源汇总

    目录: PyTorch学习教程.手册 PyTorch视频教程 PyTorch项目资源      - NLP&PyTorch实战      - CV&PyTorch实战 PyTorch论 ...

  4. TFIDF<细读>

    概念 TF-IDF(term frequency–inverse document frequency)是一种用于资讯检索与资讯探勘的常用加权技术.TF-IDF是一种统计方法,用以评估一字词对于一个文 ...

  5. 【目标检测】YOLO:

    PPT 可以说是讲得相当之清楚了... deepsystems.io 中文翻译: https://zhuanlan.zhihu.com/p/24916786 图解YOLO YOLO核心思想:从R-CN ...

  6. 3D点云重建原理及Pytorch实现

    3D点云重建原理及Pytorch实现 Pytorch: Learning Efficient Point Cloud Generation for Dense 3D Object Reconstruc ...

  7. 空间金字塔池化(Spatial Pyramid Pooling, SPP)原理和代码实现(Pytorch)

    想直接看公式的可跳至第三节 3.公式修正 一.为什么需要SPP 首先需要知道为什么会需要SPP. 我们都知道卷积神经网络(CNN)由卷积层和全连接层组成,其中卷积层对于输入数据的大小并没有要求,唯一对 ...

  8. PyTorch-Adam优化算法原理,公式,应用

    概念:Adam 是一种可以替代传统随机梯度下降过程的一阶优化算法,它能基于训练数据迭代地更新神经网络权重.Adam 最开始是由 OpenAI 的 Diederik Kingma 和多伦多大学的 Jim ...

  9. 一文看懂Transformer内部原理(含PyTorch实现)

    Transformer注解及PyTorch实现 原文:http://nlp.seas.harvard.edu/2018/04/03/attention.html 作者:Alexander Rush 转 ...

随机推荐

  1. Hibernate系列3-----之修改

    只是修改测试类,先看其他的代码的同学,请翻看我的博客Hibernate1,嘿嘿,我就在这不在重写一遍来 @Test public void testhibernate() { updateStuden ...

  2. SSRS 参数 单选 多选

    前段时间 公司要求报表的选项可以多选. 知道需求后,研究了下实现. 首先我们创建一个报表,然后添加3个数据集,2个参数,如下图. DataSet1数据集:存放主数据. ddl_emplid数据集:存放 ...

  3. maven课程 项目管理利器-maven 3-3 maven中的坐标和仓库

    本节主要讲了两大方面: 1 maven坐标 1.0  构件定义 任何依赖,插件,项目构建输出 都称之为构件. 1.1 maven坐标概念 groupid 公司或组织的域名倒序+当前项目名称 artif ...

  4. Hibernate课程 初探多对多映射1-1 多对多应用场景

    1 用途: 员工和项目之间的多对多关系 2 实现: 员工表和项目表之外,建立员工和项目关联表来实现: 3 hibernate应用: set元素和many-to-many来实现

  5. 超详细Hexo+Github博客搭建小白教程

    原文链接:超详细Hexo+Github博客搭建小白教程 去年9月的时候开始搭建了第一个自己的独立博客,到现在也稍微像模像样了.很多小伙伴应该也想过搭建一个自己的博客,网上也有一堆详细教程.我在此稍稍总 ...

  6. 【Unity3D学习笔记】解决放大后场景消失不显示问题

    不知道为啥,我的Unity场景放大到一定大小后,就会消失... 解决方案: 选中一个GameObject,然后按F键. F键作用是聚焦,视图将移动,以选中对象为中心.

  7. Coppermine-1.5.46 (Ubuntu 16.04.1)

      平台: Ubuntu 类型: 虚拟机镜像 软件包: coppermine-1.5.46 commercial content management coppermine media sharing ...

  8. vue checkbox 双向绑定及初始化渲染

    双向绑定可以绑定到同一个数组 <input type="checkbox" id="jack" value="Jack" v-mode ...

  9. 详细讲解:通过composer安装TP5.1(Thinkphp5.1)

    现在TP5越来越火了,TP5也更新到了5.1版本,但是5.1以上版本只能通过composer来进行安装,那么这里贴出详细的步骤 前提:PHP版本必须要5.6以上 参考网址:http://www.thi ...

  10. 缓存的set、getAndTouch一定要谨慎使用

    缓存的set.getAndTouch一定要谨慎使用. 很多人认为缓存在内存中性能良好,频繁更新,却不想机器的IO无法支撑,结果就是缓存成了系统的瓶颈.