NLP(二十二)使用LSTM进行语言建模以预测最优词
原文链接:http://www.one2know.cn/nlp22/
- 预处理
数据集使用Facebook上的BABI数据集
将文件提取成可训练的数据集,包括:文章 问题 答案
def get_data(infile):
stories,questions,answers = [],[],[]
story_text = []
fin = open(infile,'rb')
for line in fin:
line = line.decode('utf-8').strip()
lno,text = line.split(' ',1)
if '\t' in text:
question,answer,_ = text.split('\t')
stories.append(story_text)
questions.append(question)
answers.append(answer)
story_text = []
else:
story_text.append(text)
fin.close()
return stories,questions,answers
data_train = get_data('qa1_single-supporting-fact_train.txt')
data_test = get_data('qa1_single-supporting-fact_test.txt')
print('\nTrain observations:',len(data_train[0]),
'Test observations:',len(data_test[0]),'\n')
输出:
Train observations: 10000 Test observations: 1000
- 如何实现
1.预处理:创建字典并将文章,问题和答案映射到词表,进一步映射成向量形式
2.模型创建和验证:训练模型并在验证数据集上测试
3.预测结果:测试集测试数据的结果 - 代码
from __future__ import division,print_function
import collections
import itertools
import nltk
import numpy as np
import matplotlib.pyplot as plt
import os
import random
def get_data(infile):
stories,questions,answers = [],[],[]
story_text = []
fin = open(infile,'rb')
for line in fin:
line = line.decode('utf-8').strip()
lno,text = line.split(' ',1) # 去掉前面的数字标记
if '\t' in text: # 有制表符的是 问题 和 答案
question,answer,_ = text.split('\t')
stories.append(story_text)
questions.append(question)
answers.append(answer)
story_text = []
else: # 没制表符的是文章
story_text.append(text)
fin.close()
return stories,questions,answers
data_train = get_data('qa1_single-supporting-fact_train.txt')
data_test = get_data('qa1_single-supporting-fact_test.txt')
print('\nTrain observations:',len(data_train[0]),
'Test observations:',len(data_test[0]),'\n')
print(data_train[0][1],data_train[1][1],data_train[2][1])
# ['Daniel went back to the hallway.', 'Sandra moved to the garden.'] Where is Daniel? hallway
print(np.array(data_train).shape)
# (3, 10000)
dictnry = collections.Counter() # 返回列表元素出现次数的 字典,这里没有参数是一个空字典
for stories,questions,answers in [data_train,data_test]:
for story in stories:
for sent in story:
for word in nltk.word_tokenize(sent):
dictnry[word.lower()] += 1
for question in questions:
for word in nltk.word_tokenize(question):
dictnry[word.lower()] += 1
for answer in answers:
for word in nltk.word_tokenize(answer):
dictnry[word.lower()] += 1
word2indx = {w:(i+1) for i,(w,_) in enumerate(dictnry.most_common())} # 按词频排序
word2indx['PAD'] = 0
indx2word = {v:k for k,v in word2indx.items()}
vocab_size = len(word2indx) # 一共有22个不重复单词
print('vocabulary size:',len(word2indx))
story_maxlen = 0
question_maxlen = 0
for stories,questions,answers in [data_train,data_test]:
for story in stories:
story_len = 0
for sent in story:
swords = nltk.word_tokenize(sent)
story_len += len(swords)
if story_len > story_maxlen:
story_maxlen = story_len
for question in questions:
question_len = len(nltk.word_tokenize(question))
if question_len > question_maxlen:
question_maxlen = question_len
print('Story maximum length:',story_maxlen,'Question maximum length:',question_maxlen)
# 文章单词最大长度为14,问题中的单词最大长度为4,长度不够的补0,维度相同便于并向计算
from keras.layers import Input
from keras.layers.core import Activation,Dense,Dropout,Permute
from keras.layers.embeddings import Embedding
from keras.layers.merge import add,concatenate,dot
from keras.layers.recurrent import LSTM
from keras.models import Model
from keras.preprocessing.sequence import pad_sequences
from keras.utils import np_utils
def data_vectorization(data,word2indx,story_maxlen,question_maxlen): # 词 => 词向量
Xs,Xq,Y = [],[],[]
stories,questions,answers = data
for story,question,answer in zip(stories,questions,answers):
xs = [[word2indx[w.lower()] for w in nltk.word_tokenize(s)] for s in story] #
xs = list(itertools.chain.from_iterable(xs))
xq = [word2indx[w.lower()] for w in nltk.word_tokenize(question)]
Xs.append(xs)
Xq.append(xq)
Y.append(word2indx[answer.lower()])
return pad_sequences(Xs,maxlen=story_maxlen),pad_sequences(Xq,maxlen=question_maxlen),\
np_utils.to_categorical(Y,num_classes=len(word2indx))
Xstrain,Xqtrain,Ytrain = data_vectorization(data_train,word2indx,story_maxlen,question_maxlen)
Xstest,Xqtest,Ytest = data_vectorization(data_test,word2indx,story_maxlen,question_maxlen)
print('Train story',Xstrain.shape,'Train question',Xqtrain.shape,'Train answer',Ytrain.shape)
print('Test story',Xstest.shape,'Test question',Xqtest.shape,'Test answer',Ytest.shape)
# 超参数
EMBEDDING_SIZE = 128
LATENT_SIZE = 64
BATCH_SIZE = 64
NUM_EPOCHS = 40
# 输入层
story_input = Input(shape=(story_maxlen,))
question_input = Input(shape=(question_maxlen,))
# Story encoder embedding
# 将正整数(索引)转换为固定大小的密集向量。
# 例如,[[4],[20]]->[[0.25,0.1],[0.6,-0.2]] 此层只能用作模型中的第一层
story_encoder = Embedding(input_dim=vocab_size,output_dim=EMBEDDING_SIZE,input_length=story_maxlen)(story_input)
story_encoder = Dropout(0.2)(story_encoder)
# Question encoder embedding
question_encoder = Embedding(input_dim=vocab_size,output_dim=EMBEDDING_SIZE,input_length=question_maxlen)(question_input)
question_encoder = Dropout(0.3)(question_encoder)
# 返回两个张量的点积
match = dot([story_encoder,question_encoder],axes=[2,2])
# 将故事编码为问题的向量空间
story_encoder_c = Embedding(input_dim=vocab_size,output_dim=question_maxlen,input_length=story_maxlen)(story_input)
story_encoder_c = Dropout(0.3)(story_encoder_c)
# 结合两个向量 match和story_encoder_c
response = add([match,story_encoder_c])
response = Permute((2,1))(response)
# 结合两个向量 response和question_encoder
answer = concatenate([response, question_encoder], axis=-1)
answer = LSTM(LATENT_SIZE)(answer)
answer = Dropout(0.2)(answer)
answer = Dense(vocab_size)(answer)
output = Activation("softmax")(answer)
model = Model(inputs=[story_input, question_input], outputs=output)
model.compile(optimizer="adam", loss="categorical_crossentropy",metrics=["accuracy"])
print(model.summary())
# 模型训练
history = model.fit([Xstrain,Xqtrain],[Ytrain],batch_size=BATCH_SIZE,epochs=NUM_EPOCHS,
validation_data=([Xstest,Xqtest],[Ytest]))
# 画出准确率和损失函数
plt.title('Episodic Memory Q&A Accuracy')
plt.plot(history.history['acc'],color='g',label='train')
plt.plot(history.history['val_acc'],color='r',label='validation')
plt.legend(loc='best')
plt.show()
# get predictions of labels
ytest = np.argmax(Ytest, axis=1)
Ytest_ = model.predict([Xstest, Xqtest])
ytest_ = np.argmax(Ytest_, axis=1)
# 随机选择几个问题测试
NUM_DISPLAY = 10
for i in random.sample(range(Xstest.shape[0]),NUM_DISPLAY):
story = " ".join([indx2word[x] for x in Xstest[i].tolist() if x != 0])
question = " ".join([indx2word[x] for x in Xqtest[i].tolist()])
label = indx2word[ytest[i]]
prediction = indx2word[ytest_[i]]
print(story, question, label, prediction)
输出:
NLP(二十二)使用LSTM进行语言建模以预测最优词的更多相关文章
- NLP(二十三)使用LSTM进行语言建模以预测最优词
N元模型 预测要输入的连续词,比如 如果抽取两个连续的词汇,则称之为二元模型 准备工作 数据集使用 Alice in Wonderland 将初始数据提取N-grams import nltk imp ...
- JAVA基础知识总结:一到二十二全部总结
>一: 一.软件开发的常识 1.什么是软件? 一系列按照特定顺序组织起来的计算机数据或者指令 常见的软件: 系统软件:Windows\Mac OS \Linux 应用软件:QQ,一系列的播放器( ...
- (C/C++学习笔记) 二十二. 标准模板库
二十二. 标准模板库 ● STL基本介绍 标准模板库(STL, standard template library): C++提供的大量的函数模板(通用算法)和类模板. ※ 为什么我们一般不需要自己写 ...
- 智课雅思词汇---二十二、-al即是名词性后缀又是形容词后缀
智课雅思词汇---二十二.-al即是名词性后缀又是形容词后缀 一.总结 一句话总结: 后缀:-al ②[名词后缀] 1.构成抽象名词,表示行为.状况.事情 refusal 拒绝 proposal 提议 ...
- 学习笔记:CentOS7学习之二十二: 结构化命令case和for、while循环
目录 学习笔记:CentOS7学习之二十二: 结构化命令case和for.while循环 22.1 流程控制语句:case 22.2 循环语句 22.1.2 for-do-done 22.3 whil ...
- [分享] IT天空的二十二条军规
Una 发表于 2014-9-19 20:25:06 https://www.itsk.com/thread-335975-1-1.html IT天空的二十二条军规 第一条.你不是什么都会,也不是什么 ...
- Bootstrap <基础二十二>超大屏幕(Jumbotron)
Bootstrap 支持的另一个特性,超大屏幕(Jumbotron).顾名思义该组件可以增加标题的大小,并为登陆页面内容添加更多的外边距(margin).使用超大屏幕(Jumbotron)的步骤如下: ...
- Web 前端开发精华文章推荐(HTML5、CSS3、jQuery)【系列二十二】
<Web 前端开发精华文章推荐>2014年第一期(总第二十二期)和大家见面了.梦想天空博客关注 前端开发 技术,分享各类能够提升网站用户体验的优秀 jQuery 插件,展示前沿的 HTML ...
- 二十二、OGNL的一些其他操作
二十二.OGNL的一些其他操作 投影 ?判断满足条件 动作类代码: ^ $ public class Demo2Action extends ActionSupport { public ...
随机推荐
- 快速清理maven仓库中下载错误的文件
有时候使用pom文件下载依赖文件的时候突然网络异常,可能会出现依赖文件出现破损,导致怎么都不能使用,也没有重新下载. 之前解决办法是找到出现破损的文件并删除,让其重新下载,但是这样效率很低,也很难找到 ...
- web前端开发-博客目录
web前端开发是一个新的领域,知识连接范围广,处于设计与后端数据交互的桥梁,并且现在很多web前端相关语言标准,框架库都在高速发展.在学习过程中也常常处于烦躁与迷茫,有时候一直在想如何能够使自己更加系 ...
- dz6.0的一个sql注入漏洞
今天开始着手分析第一个漏洞,找了一上午靶机,发现一个含有成人内容的违法网站是用dz6.0搭的,今天就看看dz这个版本的洞了 问题函数位置:my.php第623行 if(is_array($descri ...
- 【iOS】iOS CocoaPods 整理
github 上下载 Demo 时第一次遇到这个情况,当时有些不知所措,也没怎么在意,后来项目调整结构时正式见到了这个,并且自己去了解学习了. CocoaPods安装和使用教程 这篇文章写得很好!ma ...
- Java的几种常见排序算法
一.所谓排序,就是使一串记录,按照其中的某个或某些关键字的大小,递增或递减的排列起来的操作.排序算法,就是如何使得记录按照要求排列的方法.排序算法在很多领域得到相当地重视,尤其是在大量数据的处理方面. ...
- 10分钟了解一致性hash算法
应用场景 当我们的数据表超过500万条或更多时,我们就会考虑到采用分库分表:当我们的系统使用了一台缓存服务器还是不能满足的时候,我们会使用多台缓存服务器,那我们如何去访问背后的库表或缓存服务器呢,我们 ...
- 解决报错:类型“System.Object”在未被引用的程序集中定义。必须添加对程序集“System.Runtime, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a”的引用
Razor视图引擎中,使用部分视图编译报错 类型“System.Object”在未被引用的程序集中定义.必须添加对程序集“System.Runtime, Version=4.0.0.0, Cultur ...
- 改 Anaconda Jupyter Notebook 开发文件保存目录
1.打开cmd,输入命令找到配置文件路径 jupyter notebook --generate-config 2.打开 jupyter_notebook_config.py 修改配置 c.Noteb ...
- collection介绍
1.collection介绍 在mongodb中,collection相当于关系型数据库的表,但并不需提前创建,更不需要预先定义字段 db.collect1.save({username:'mayj' ...
- JavaWeb——Servlet开发2
1.HttpServletRequest的使用 获取Request的参数的方法. 方法getParameter将返回参数的单个值 方法getParameterValues将返回参数的值的数组 方法ge ...