Pytorch LSTM 词性判断
首先,我们定义好一个LSTM网络,然后给出一个句子,每个句子都有很多个词构成,每个词可以用一个词向量表示,这样一句话就可以形成一个序列,我们将这个序列依次传入LSTM,然后就可以得到与序列等长的输出,每个输出都表示的是一种词性,比如名词,动词之类的,还是一种分类问题,每个单词都属于几种词性中的一种。
我们可以思考一下为什么LSTM在这个问题里面起着重要的作用。如果我们完全孤立的对一个词做词性的判断这样我们需要特别高维的词向量,但是对于LSTM,它有着一个记忆的特性,这样我们就能够通过这个单词前面记忆的一些词语来对其做一个判断,比如前面如果是my,那么他紧跟的词有很大可能就是一个名词,这样就能够充分的利用上文来做这个问题。
同时我们还可以通过引入字符来增强表达,什么意思呢?也就是说一个单词有一些前缀和后缀,比如-ly这种后缀很大可能是一个副词,这样我们就能够在字符水平得到一个词性判断的更好结果。
具体怎么做呢?还是用LSTM。每个单词有不同的字母组成,比如 apple 由a p p l e构成,我们同样给这些字符词向量,这样形成了一个长度为5的序列,然后传入另外一个LSTM网络,只取最后输出的状态层作为它的一种字符表达,我们并不需要关心到底提取出来的字符表达是什么样的,在learning的过程中这些都是会被更新的参数,使得最终我们能够正确预测。
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.autograd import Variable training_data = [("The dog ate the apple".split(),
["DET", "NN", "V", "DET", "NN"]),
("Everybody read that book".split(), ["NN", "V", "DET",
"NN"])]
# 每个单词就用一个数字表示,每种词性也用一个数字表示
word_to_idx = {}
tag_to_idx = {}
for context, tag in training_data:
for word in context:
if word not in word_to_idx:
# 对词进行编码
word_to_idx[word] = len(word_to_idx)
for label in tag:
if label not in tag_to_idx:
# 对词性编码
tag_to_idx[label] = len(tag_to_idx)
alphabet = 'abcdefghijklmnopqrstuvwxyz'
character_to_idx = {}
for i in range(len(alphabet)):
# 对字母编码
character_to_idx[alphabet[i]] = i # 字符LSTM
class CharLSTM(nn.Module):
def __init__(self, n_char, char_dim, char_hidden):
super(CharLSTM, self).__init__()
self.char_embedding = nn.Embedding(n_char, char_dim)
self.char_lstm = nn.LSTM(char_dim, char_hidden, batch_first=True) def forward(self, x):
x = self.char_embedding(x)
_, h = self.char_lstm(x)
# 取隐层
return h[0] class LSTMTagger(nn.Module):
def __init__(self, n_word, n_char, char_dim, n_dim, char_hidden, n_hidden,
n_tag):
super(LSTMTagger, self).__init__()
self.word_embedding = nn.Embedding(n_word, n_dim)
self.char_lstm = CharLSTM(n_char, char_dim, char_hidden)
self.lstm = nn.LSTM(n_dim + char_hidden, n_hidden, batch_first=True)
self.linear1 = nn.Linear(n_hidden, n_tag) def forward(self, x, word):
char = torch.FloatTensor()
for each in word:
char_list = []
for letter in each:
# 对词进行字母编码
char_list.append(character_to_idx[letter.lower()])
char_list = torch.LongTensor(char_list)
char_list = char_list.unsqueeze(0)
if torch.cuda.is_available():
tempchar = self.char_lstm(Variable(char_list).cuda())
else:
tempchar = self.char_lstm(Variable(char_list))
tempchar = tempchar.squeeze(0)
char = torch.cat((char, tempchar.cpu().data), 0)
if torch.cuda.is_available():
char = char.cuda()
char = Variable(char)
x = self.word_embedding(x)
x = torch.cat((x, char), 1) # char编码与word编码cat
x = x.unsqueeze(0)
# 取输出层 句长*n_hidden
x, _ = self.lstm(x)
x = x.squeeze(0)
x = self.linear1(x)
y = F.log_softmax(x)
return y model = LSTMTagger(
len(word_to_idx), len(character_to_idx), 10, 100, 50, 128, len(tag_to_idx))
if torch.cuda.is_available():
model = model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-2) def make_sequence(x, dic):
idx = [dic[i] for i in x]
idx = Variable(torch.LongTensor(idx))
return idx for epoch in range(300):
print('*' * 10)
print('epoch {}'.format(epoch + 1))
running_loss = 0
for data in training_data:
word, tag = data
word_list = make_sequence(word, word_to_idx)
tag = make_sequence(tag, tag_to_idx)
if torch.cuda.is_available():
word_list = word_list.cuda()
tag = tag.cuda()
# forward
out = model(word_list, word)
loss = criterion(out, tag)
running_loss += loss.data[0]
# backward 三步常规操作
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Loss: {}'.format(running_loss / len(data)))
print()
input = make_sequence("Everybody ate the apple".split(), word_to_idx)
if torch.cuda.is_available():
input = input.cuda()
model.eval() #对dropout和batch normalization的操作在训练和测试的时候是不一样
out = model(input, "Everybody ate the apple".split())
print(out)
首先n_word 和 n_dim来定义单词的词向量维度,n_char和char_dim来定义字符的词向量维度,char_hidden表示CharLSTM输出的维度,n_hidden表示每个单词作为序列输入的LSTM输出维度,最后n_tag表示输出的词性的种类。
接着开始前向传播,不仅要传入一个编码之后的句子,同时还需要传入原本的单词,因为需要对字符做一个LSTM,所以传入的参数多了一个word_data表示一个句子的所有单词。
然后就是将每个单词传入CharLSTM,得到的结果和单词的词向量拼在一起形成一个新的输入,将输入传入LSTM里面,得到输出,最后接一个全连接层,将输出维数定义为label的数目。
特别要注意里面有一些unsqueeze(增维)和squeeze(降维)是因为LSTM的输入要求要带上batch_size(这里是1),torch.cat里面0和1分别表示沿着行和列来拼接。
预测一下 Everybody ate the apple 这句话每个词的词性,一共有3种词性,DET,NN,V。最后得到的结果为:
一共有4行,每行里面取最大的,那么第一个词的词性就是NN,第二个词是V,第三个词是DET,第四个词是NN。这个是相符的。
参考自:https://sherlockliao.github.io/2017/06/05/lstm%20language/
Pytorch LSTM 词性判断的更多相关文章
- pytorch lstm crf 代码理解 重点
好久没有写博客了,这一次就将最近看的pytorch 教程中的lstm+crf的一些心得与困惑记录下来. 原文 PyTorch Tutorials 参考了很多其他大神的博客,https://blog.c ...
- pytorch lstm crf 代码理解
好久没有写博客了,这一次就将最近看的pytorch 教程中的lstm+crf的一些心得与困惑记录下来. 原文 PyTorch Tutorials 参考了很多其他大神的博客,https://blog.c ...
- pytorch, LSTM介绍
本文中的RNN泛指LSTM,GRU等等 CNN中和RNN中batchSize的默认位置是不同的. CNN中:batchsize的位置是position 0. RNN中:batchsize的位置是pos ...
- pytorch LSTM情感分类全部代码
先运行main.py进行文本序列化,再train.py模型训练 dataset.py from torch.utils.data import DataLoader,Dataset import to ...
- Python中利用LSTM模型进行时间序列预测分析
时间序列模型 时间序列预测分析就是利用过去一段时间内某事件时间的特征来预测未来一段时间内该事件的特征.这是一类相对比较复杂的预测建模问题,和回归分析模型的预测不同,时间序列模型是依赖于事件发生的先后顺 ...
- 神经网络与数字货币量化交易系列(1)——LSTM预测比特币价格
首发地址:https://www.fmz.com/digest-topic/4035 1.简单介绍 深度神经网络这些年越来越热门,在很多领域解决了过去无法解决的难题,体现了强大的能力.在时间序列的预测 ...
- 预训练语言模型的前世今生 - 从Word Embedding到BERT
预训练语言模型的前世今生 - 从Word Embedding到BERT 本篇文章共 24619 个词,一个字一个字手码的不容易,转载请标明出处:预训练语言模型的前世今生 - 从Word Embeddi ...
- 大话循环神经网络(RNN)
在上一篇文章中,介绍了 卷积神经网络(CNN)的算法原理,CNN在图像识别中有着强大.广泛的应用,但有一些场景用CNN却无法得到有效地解决,例如: 语音识别,要按顺序处理每一帧的声音信息,有些结果 ...
- AI 智能写情诗、藏头诗
一.AI 智能情诗.藏头诗展示 最近使用PyTorch的LSTM训练一个写情诗(七言)的模型,可以随机生成情诗.也可以生成藏头情诗. 在特殊的日子用AI生成一首这样的诗,是不是很酷!下面分享下AI 智 ...
随机推荐
- 创建一个yum源,rpm安装二进制包
作者:邓聪聪 安装mariadb vi /etc/yum.repos.d/mariadb.repo [mariadb]name=mariadbbaseurl=http://mirrors.neusof ...
- ubuntu14.04上引入thinkphp5类库遇到的一个问题
ubuntu14.04 上加载OSS\OssClient() ;--->在vendor文件夹下的文件要用大写OSS 小写的报错 无法加载类库 Vendor('OSS.autoload');//引 ...
- OpenStack实践系列③镜像服务Glance
OpenStack实践系列③镜像服务Glance 3.5 Glance部署 修改glance-api和glance-registry的配置文件,同步数据库 [root@node1 ~]# vim /e ...
- java乱码问题解决
1.通过统一的过滤器进行了页面过滤(问题排除) 2.通过debug功能发现页面传到servelet和DAO中文都是OK的,可以说明在web程序端没有问题 问题就可能出现在数据库上面 首先查看数据库编码 ...
- centos7怎么永久修改hosname
centos7怎么永久修改hosname 其实,一般来说安装好虚拟机之后,一般都会进行修改hostname,之前也是在修改的时候,遇到过问题,但是没有深究,今天在修改的时候,好好研究了一下,之前看到好 ...
- POJ 2115
ax=b (mod n) 该方程有解的充要条件为 gcd(a,n) | b ,即 b% gcd(a,n)==0 令d=gcd(a,n) 有该方程的 最小整数解为 x = e (mod n/d) 其中e ...
- sublime text3 golang插件(golang build)
1 前言 先前条件: sublime text3:下载地址:http://www.sublimetext.com/3 golang:下载地址:https://golang.google.cn/dl/ ...
- jqgrid获取数据条数
function getResult() {//获取结果结合的函数,可以通过此函数获取查询后匹配的所有数据行. var o = jQuery("#jqgrid"); ...
- 3)django-路由系统url
一:django路由系统说明 路由都在urls文件里,它将浏览器输入的url映射到相应的业务处理逻辑 二:django 常用路由系统配置 1)URL常用有模式一FBV(function base v ...
- MVVM 简介
转:https://objccn.io/issue-13-1/ 所以,MVVM 到底是什么?与其专注于说明 MVVM 的来历,不如让我们看一个典型的 iOS 是如何构建的,并从那里了解 MVVM: 我 ...