PyTorch seq2seq chatbot: returning results with beam search
decoder.py
"""
实现解码器
"""
import heapq import torch.nn as nn
import config
import torch
import torch.nn.functional as F
import numpy as np
import random
from chatbot.attention import Attention class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(num_embeddings=len(config.target_ws),
                                      embedding_dim=config.chatbot_decoder_embedding_dim,
                                      padding_idx=config.target_ws.PAD)
        # required hidden_state shape: [1, batch_size, 64]
        self.gru = nn.GRU(input_size=config.chatbot_decoder_embedding_dim,
                          hidden_size=config.chatbot_decoder_hidden_size,
                          num_layers=config.chatbot_decoder_number_layer,
                          bidirectional=False,
                          batch_first=True,
                          dropout=config.chatbot_decoder_dropout)
        # if the encoder has hidden_size=64 and num_layers=1, encoder_hidden is [2, batch_size, 64]
        self.fc = nn.Linear(config.chatbot_decoder_hidden_size, len(config.target_ws))
        self.attn = Attention(method="general")
        self.fc_attn = nn.Linear(config.chatbot_decoder_hidden_size * 2, config.chatbot_decoder_hidden_size, bias=False)

    def forward(self, encoder_hidden, target, encoder_outputs):
# print("target size:",target.size())
#第一个时间步的输入的hidden_state
decoder_hidden = encoder_hidden #[1,batch_size,128*2]
#第一个时间步的输入的input
batch_size = encoder_hidden.size(1)
decoder_input = torch.LongTensor([[config.target_ws.SOS]]*batch_size).to(config.device) #[batch_size,1]
# print("decoder_input:",decoder_input.size()) #使用全为0的数组保存数据,[batch_size,max_len,vocab_size]
decoder_outputs = torch.zeros([batch_size,config.chatbot_target_max_len,len(config.target_ws)]).to(config.device) if random.random()>0.5: #teacher_forcing机制 for t in range(config.chatbot_target_max_len):
decoder_output_t,decoder_hidden = self.forward_step(decoder_input,decoder_hidden,encoder_outputs)
decoder_outputs[:,t,:] = decoder_output_t #获取当前时间步的预测值
value,index = decoder_output_t.max(dim=-1)
decoder_input = index.unsqueeze(-1) #[batch_size,1]
# print("decoder_input:",decoder_input.size())
else:
for t in range(config.chatbot_target_max_len):
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden,encoder_outputs)
decoder_outputs[:, t, :] = decoder_output_t
#把真实值作为下一步的输入
decoder_input = target[:,t].unsqueeze(-1)
# print("decoder_input size:",decoder_input.size())
return decoder_outputs,decoder_hidden def forward_step(self,decoder_input,decoder_hidden,encoder_outputs):
        '''
        Compute a single time step.
        :param decoder_input: [batch_size, 1]
        :param decoder_hidden: [1, batch_size, 128*2]
        :return:
        '''
        decoder_input_embeded = self.embedding(decoder_input)
        # print("decoder_input_embeded:", decoder_input_embeded.size())
        # out: [batch_size, 1, 128*2]
        # decoder_hidden: [1, batch_size, 128*2]
        # print(decoder_hidden.size())
        out, decoder_hidden = self.gru(decoder_input_embeded, decoder_hidden)

        ##### attention starts ############
        ### 1. compute the attention weights
        attn_weight = self.attn(decoder_hidden, encoder_outputs)  # [batch_size, encoder_max_len]
        ### 2. compute the context vector
        # encoder_outputs: [batch_size, encoder_max_len, 128*2]
        context_vector = torch.bmm(attn_weight.unsqueeze(1), encoder_outputs).squeeze(1)  # [batch_size, 128*2]
        ### 3. compute the attention result
        # out.squeeze(1): [batch_size, 128*2], context_vector: [batch_size, 128*2] --> concatenated: 128*4
        # attention_result: [batch_size, 128*4] ---> [batch_size, 128*2]
        attention_result = torch.tanh(self.fc_attn(torch.cat([context_vector, out.squeeze(1)], dim=-1)))
        # attention_result = torch.tanh(torch.cat([context_vector, out.squeeze(1)], dim=-1))
        ##### attention ends #####
        # print("decoder_hidden size:", decoder_hidden.size())
        # out: [batch_size, 1, hidden_size]
        # out_squeezed = out.squeeze(dim=1)  # drop the dimension of size 1
        out_fc = F.log_softmax(self.fc(attention_result), dim=-1)  # [batch_size, vocab_size]
        # print("out_fc:", out_fc.size())
        return out_fc, decoder_hidden

    def evaluate(self, encoder_hidden, encoder_outputs):
        # hidden_state fed into the first time step
        decoder_hidden = encoder_hidden  # [1, batch_size, 128*2]
        # input fed into the first time step
        batch_size = encoder_hidden.size(1)
        decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)  # [batch_size, 1]
        # print("decoder_input:", decoder_input.size())
        # an all-zero tensor to collect the outputs, [batch_size, max_len, vocab_size]
        decoder_outputs = torch.zeros([batch_size, config.chatbot_target_max_len, len(config.target_ws)]).to(
            config.device)
        predict_result = []
        for t in range(config.chatbot_target_max_len):
            decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
            decoder_outputs[:, t, :] = decoder_output_t
            # take the prediction of the current time step
            value, index = decoder_output_t.max(dim=-1)
            predict_result.append(index.cpu().detach().numpy())  # [[batch], [batch], ...]
            decoder_input = index.unsqueeze(-1)  # [batch_size, 1]
            # print("decoder_input:", decoder_input.size())
            # predict_result.append(decoder_input)
        # convert the result to an ndarray; each row is one prediction
        predict_result = np.array(predict_result).transpose()
        return decoder_outputs, predict_result

    def evaluate_with_beam_search(self, encoder_hidden, encoder_outputs):
"""
使用beam search完成评估,只能输入一个句子,得到一个输出
:param encoder_hidden:
:param encoder_outputs:
:return:
"""
# 第一个时间步的输入的hidden_state
decoder_hidden = encoder_hidden # [1,batch_size,128*2]
# 第一个时间步的输入的input
batch_size = encoder_hidden.size(1)
assert batch_size == 1, "beam search的过程中,batch_size只能为1"
decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device) # [batch_size,1] prev_beam = Beam()
prev_beam.add(1, False, [decoder_input], decoder_input, decoder_hidden) while True:
cur_beam = Beam()
for prob, complete, seq_list, decoder_input, decoder_hidden in prev_beam:
if complete: # 有可能前一次已经到达eos了,但是概率不是最大的
cur_beam.add(prob, complete, seq_list, decoder_input, decoder_hidden)
else:
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs) value, index = torch.topk(decoder_output_t, config.beam_width)
# print("value index size:",value[0].size(),index[0].size())
for m, n in zip(value[0], index[0]):
# print("m,n size:",m.size(),n.size(),m,n)
cur_prob = prob * m.item()
decoder_input = torch.LongTensor([[n.item()]]).to(config.device)
cur_seq_list = seq_list + [decoder_input]
if n == config.target_ws.EOS:
cur_complete = True
else:
cur_complete = False
cur_beam.add(cur_prob, cur_complete, cur_seq_list, decoder_input, decoder_hidden) best_prob, best_complete, best_seq, _, _ = max(cur_beam)
if best_complete or len(best_seq) - 1 == config.chatbot_target_max_len: best_seq = [i.item() for i in best_seq]
if best_seq[0] == config.target_ws.SOS:
best_seq = best_seq[1:]
if best_seq[-1] == config.target_ws.EOS:
best_seq = best_seq[:-1]
return best_seq else:
prev_beam = cur_beam class Beam:
"""保存每一个时间步的数据""" def __init__(self):
self.heapq = list()
self.beam_width = config.beam_width def add(self, prob, complete, seq_list, decoder_input, decoder_hidden):
heapq.heappush(self.heapq, [prob, complete, seq_list, decoder_input, decoder_hidden])
# 保证最终只有一个beam width个结果
if len(self.heapq) > self.beam_width:
heapq.heappop(self.heapq) def __iter__(self):
for item in self.heapq:
yield item
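
The Attention module imported at the top of decoder.py lives in chatbot/attention.py and is not reproduced in this post. For readers who don't have that file, here is a minimal sketch of what a Luong-style "general" score could look like; the layer name Wa is my own, and it assumes the encoder output feature size equals config.chatbot_decoder_hidden_size (which the concatenation in forward_step already requires), so it is not necessarily identical to the original attention.py.

import torch
import torch.nn as nn
import torch.nn.functional as F
import config


class Attention(nn.Module):
    """Sketch of "general" attention: score(h_t, h_s) = h_t^T * Wa * h_s."""

    def __init__(self, method="general"):
        super(Attention, self).__init__()
        assert method == "general", "only the general score is sketched here"
        self.Wa = nn.Linear(config.chatbot_decoder_hidden_size,
                            config.chatbot_decoder_hidden_size, bias=False)

    def forward(self, decoder_hidden, encoder_outputs):
        # decoder_hidden: [1, batch_size, hidden_size] -> take the last layer: [batch_size, hidden_size]
        query = self.Wa(decoder_hidden[-1])
        # encoder_outputs: [batch_size, encoder_max_len, hidden_size]
        score = torch.bmm(encoder_outputs, query.unsqueeze(-1)).squeeze(-1)  # [batch_size, encoder_max_len]
        # normalise so the weights over the encoder steps sum to 1
        return F.softmax(score, dim=-1)  # [batch_size, encoder_max_len]

The returned shape matches how forward_step uses it: attn_weight.unsqueeze(1) followed by torch.bmm with encoder_outputs yields the context vector.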
seq2seq.py
"""
完成seq2seq模型
"""
import torch.nn as nn
from chatbot.encoder import Encoder
from chatbot.decoder import Decoder class Seq2Seq(nn.Module):
def __init__(self):
super(Seq2Seq,self).__init__()
self.encoder = Encoder()
self.decoder = Decoder() def forward(self, input,input_len,target):
encoder_outputs,encoder_hidden = self.encoder(input,input_len)
decoder_outputs,decoder_hidden = self.decoder(encoder_hidden,target,encoder_outputs)
return decoder_outputs def evaluate(self,input,input_len):
encoder_outputs, encoder_hidden = self.encoder(input, input_len)
decoder_outputs, predict_result = self.decoder.evaluate(encoder_hidden,encoder_outputs)
return decoder_outputs,predict_result def evaluate_with_beam_search(self,input,input_len):
encoder_outputs, encoder_hidden = self.encoder(input, input_len)
best_seq = self.decoder.evaluate_with_beam_search(encoder_hidden, encoder_outputs)
return best_seq
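
The training script is not part of this post. As a rough reference for how Seq2Seq.forward pairs with the NLL loss used in eval.py below, here is a minimal training-step sketch; the Adam optimizer, learning rate, the train=True flag of get_dataloader and the epoch loop are assumptions, not the original train.py.

import torch
import torch.nn.functional as F
from torch.optim import Adam

import config
from chatbot.dataset import get_dataloader
from chatbot.seq2seq import Seq2Seq


def train(epochs):
    model = Seq2Seq().to(config.device)
    optimizer = Adam(model.parameters(), lr=1e-3)  # optimizer and lr are assumptions
    for epoch in range(epochs):
        for idx, (input, target, input_len, target_len) in enumerate(get_dataloader(train=True)):
            input, target = input.to(config.device), target.to(config.device)
            input_len = input_len.to(config.device)
            optimizer.zero_grad()
            # decoder_outputs: [batch_size, max_len, vocab_size], already log-softmaxed by the decoder
            decoder_outputs = model(input, input_len, target)
            loss = F.nll_loss(decoder_outputs.view(-1, len(config.target_ws)),
                              target.view(-1), ignore_index=config.target_ws.PAD)
            loss.backward()
            optimizer.step()
    torch.save(model.state_dict(), "./models/model.pkl")  # the path eval.py loads from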
eval.py
"""
进行模型的评估
""" import torch
import torch.nn.functional as F
from chatbot.dataset import get_dataloader
from tqdm import tqdm
import config
import numpy as np
import pickle
from chatbot.seq2seq import Seq2Seq def eval():
model = Seq2Seq().to(config.device)
model.eval()
model.load_state_dict(torch.load("./models/model.pkl")) loss_list = []
data_loader = get_dataloader(train=False)
bar = tqdm(data_loader,total=len(data_loader),desc="当前进行评估")
with torch.no_grad():
for idx,(input,target,input_len,target_len) in enumerate(bar):
input = input.to(config.device)
target = target.to(config.device)
input_len = input_len.to(config.device) decoder_outputs,predict_result = model.evaluate(input,input_len) #[batch_Size,max_len,vocab_size]
loss = F.nll_loss(decoder_outputs.view(-1,len(config.target_ws)),target.view(-1),ignore_index=config.input_ws.PAD)
loss_list.append(loss.item())
bar.set_description("idx:{} loss:{:.6f}".format(idx,np.mean(loss_list)))
print("当前的平均损失为:",np.mean(loss_list)) def interface():
    from chatbot.cut_sentence import cut
    import config
    # load the model
    model = Seq2Seq().to(config.device)
    model.eval()
    model.load_state_dict(torch.load("./models/model.pkl"))
    # prepare the data to predict on
    while True:
        origin_input = input("me>>:")
        # if "你是谁" in origin_input or "你叫什么" in origin_input:
        #     result = "我是小智。"
        # elif "你好" in origin_input or "hello" in origin_input:
        #     result = "Hello"
        # else:
        _input = cut(origin_input, by_word=True)
        input_len = torch.LongTensor([len(_input)]).to(config.device)
        _input = torch.LongTensor([config.input_ws.transform(_input, max_len=config.chatbot_input_max_len)]).to(config.device)
        outputs, predict = model.evaluate(_input, input_len)
        result = config.target_ws.inverse_transform(predict[0])
        print("chatbot>>:", result)


def interface_with_beamsearch():
    from chatbot.cut_sentence import cut
    import config
    # load the model
    model = Seq2Seq().to(config.device)
    model.eval()
    model.load_state_dict(torch.load("./models/model.pkl"))
    # prepare the data to predict on
    while True:
        origin_input = input("me>>:")
        _input = cut(origin_input, by_word=True)
        input_len = torch.LongTensor([len(_input)]).to(config.device)
        _input = torch.LongTensor([config.input_ws.transform(_input, max_len=config.chatbot_input_max_len)]).to(
            config.device)
        best_seq = model.evaluate_with_beam_search(_input, input_len)
        result = config.target_ws.inverse_transform(best_seq)
        print("chatbot>>:", result)


if __name__ == '__main__':
    # interface()
    interface_with_beamsearch()
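
All three files read their hyper-parameters from a shared config module whose contents are not shown in this post. The field names below are exactly the ones referenced above; the concrete values, the pickle paths and the way the two vocabularies are loaded are illustrative assumptions only.

"""
config.py (sketch) -- only the field names come from the code above; the values are illustrative.
"""
import pickle
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# vocabularies built beforehand; each exposes PAD/SOS/EOS, transform() and inverse_transform()
input_ws = pickle.load(open("./models/ws_input.pkl", "rb"))    # assumed path
target_ws = pickle.load(open("./models/ws_target.pkl", "rb"))  # assumed path

chatbot_input_max_len = 20
chatbot_target_max_len = 20

chatbot_decoder_embedding_dim = 300
chatbot_decoder_hidden_size = 128 * 2  # must match the encoder output feature size
chatbot_decoder_number_layer = 1
chatbot_decoder_dropout = 0

beam_width = 3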