decoder.py

"""
实现解码器
"""
import heapq import torch.nn as nn
import config
import torch
import torch.nn.functional as F
import numpy as np
import random
from chatbot.attention import Attention class Decoder(nn.Module):
def __init__(self):
super(Decoder,self).__init__() self.embedding = nn.Embedding(num_embeddings=len(config.target_ws),
embedding_dim=config.chatbot_decoder_embedding_dim,
padding_idx=config.target_ws.PAD) #需要的hidden_state形状:[1,batch_size,64]
self.gru = nn.GRU(input_size=config.chatbot_decoder_embedding_dim,
hidden_size=config.chatbot_decoder_hidden_size,
num_layers=config.chatbot_decoder_number_layer,
bidirectional=False,
batch_first=True,
dropout=config.chatbot_decoder_dropout) #假如encoder的hidden_size=64,num_layer=1 encoder_hidden :[2,batch_sizee,64] self.fc = nn.Linear(config.chatbot_decoder_hidden_size,len(config.target_ws))
self.attn = Attention(method="general")
self.fc_attn = nn.Linear(config.chatbot_decoder_hidden_size * 2, config.chatbot_decoder_hidden_size, bias=False) def forward(self, encoder_hidden,target,encoder_outputs):
# print("target size:",target.size())
#第一个时间步的输入的hidden_state
decoder_hidden = encoder_hidden #[1,batch_size,128*2]
#第一个时间步的输入的input
batch_size = encoder_hidden.size(1)
decoder_input = torch.LongTensor([[config.target_ws.SOS]]*batch_size).to(config.device) #[batch_size,1]
# print("decoder_input:",decoder_input.size()) #使用全为0的数组保存数据,[batch_size,max_len,vocab_size]
decoder_outputs = torch.zeros([batch_size,config.chatbot_target_max_len,len(config.target_ws)]).to(config.device) if random.random()>0.5: #teacher_forcing机制 for t in range(config.chatbot_target_max_len):
decoder_output_t,decoder_hidden = self.forward_step(decoder_input,decoder_hidden,encoder_outputs)
decoder_outputs[:,t,:] = decoder_output_t #获取当前时间步的预测值
value,index = decoder_output_t.max(dim=-1)
decoder_input = index.unsqueeze(-1) #[batch_size,1]
# print("decoder_input:",decoder_input.size())
else:
for t in range(config.chatbot_target_max_len):
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden,encoder_outputs)
decoder_outputs[:, t, :] = decoder_output_t
#把真实值作为下一步的输入
decoder_input = target[:,t].unsqueeze(-1)
# print("decoder_input size:",decoder_input.size())
return decoder_outputs,decoder_hidden def forward_step(self,decoder_input,decoder_hidden,encoder_outputs):
'''
计算一个时间步的结果
:param decoder_input: [batch_size,1]
:param decoder_hidden: [1,batch_size,128*2]
:return:
''' decoder_input_embeded = self.embedding(decoder_input)
# print("decoder_input_embeded:",decoder_input_embeded.size()) #out:[batch_size,1,128*2]
#decoder_hidden :[1,bathc_size,128*2]
# print(decoder_hidden.size())
out,decoder_hidden = self.gru(decoder_input_embeded,decoder_hidden) ##### 开始attention ############
### 1. 计算attention weight
attn_weight = self.attn(decoder_hidden,encoder_outputs) #[batch_size,1,encoder_max_len]
### 2. 计算context vector
#encoder_ouputs :[batch_size,encoder_max_len,128*2]
context_vector = torch.bmm(attn_weight.unsqueeze(1),encoder_outputs).squeeze(1) #[batch_szie,128*2]
### 3. 计算 attention的结果
#[batch_size,128*2] #context_vector:[batch_size,128*2] --> 128*4
#attention_result = [batch_size,128*4] --->[batch_size,128*2]
attention_result = torch.tanh(self.fc_attn(torch.cat([context_vector,out.squeeze(1)],dim=-1)))
# attention_result = torch.tanh(torch.cat([context_vector,out.squeeze(1)],dim=-1))
#### attenion 结束 # print("decoder_hidden size:",decoder_hidden.size())
#out :【batch_size,1,hidden_size】 # out_squeezed = out.squeeze(dim=1) #去掉为1的维度
out_fc = F.log_softmax(self.fc(attention_result),dim=-1) #[bathc_size,vocab_size]
# print("out_fc:",out_fc.size())
return out_fc,decoder_hidden def evaluate(self,encoder_hidden,encoder_outputs): # 第一个时间步的输入的hidden_state
decoder_hidden = encoder_hidden # [1,batch_size,128*2]
# 第一个时间步的输入的input
batch_size = encoder_hidden.size(1)
decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device) # [batch_size,1]
# print("decoder_input:",decoder_input.size()) # 使用全为0的数组保存数据,[batch_size,max_len,vocab_size]
decoder_outputs = torch.zeros([batch_size, config.chatbot_target_max_len, len(config.target_ws)]).to(
config.device) predict_result = []
for t in range(config.chatbot_target_max_len):
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden,encoder_outputs)
decoder_outputs[:, t, :] = decoder_output_t # 获取当前时间步的预测值
value, index = decoder_output_t.max(dim=-1)
predict_result.append(index.cpu().detach().numpy()) #[[batch],[batch]...]
decoder_input = index.unsqueeze(-1) # [batch_size,1]
# print("decoder_input:",decoder_input.size())
# predict_result.append(decoder_input)
#把结果转化为ndarray,每一行是一条预测结果
predict_result = np.array(predict_result).transpose()
return decoder_outputs, predict_result def evaluate_with_beam_search(self, encoder_hidden, encoder_outputs):
"""
使用beam search完成评估,只能输入一个句子,得到一个输出
:param encoder_hidden:
:param encoder_outputs:
:return:
"""
# 第一个时间步的输入的hidden_state
decoder_hidden = encoder_hidden # [1,batch_size,128*2]
# 第一个时间步的输入的input
batch_size = encoder_hidden.size(1)
assert batch_size == 1, "beam search的过程中,batch_size只能为1"
decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device) # [batch_size,1] prev_beam = Beam()
prev_beam.add(1, False, [decoder_input], decoder_input, decoder_hidden) while True:
cur_beam = Beam()
for prob, complete, seq_list, decoder_input, decoder_hidden in prev_beam:
if complete: # 有可能前一次已经到达eos了,但是概率不是最大的
cur_beam.add(prob, complete, seq_list, decoder_input, decoder_hidden)
else:
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs) value, index = torch.topk(decoder_output_t, config.beam_width)
# print("value index size:",value[0].size(),index[0].size())
for m, n in zip(value[0], index[0]):
# print("m,n size:",m.size(),n.size(),m,n)
cur_prob = prob * m.item()
decoder_input = torch.LongTensor([[n.item()]]).to(config.device)
cur_seq_list = seq_list + [decoder_input]
if n == config.target_ws.EOS:
cur_complete = True
else:
cur_complete = False
cur_beam.add(cur_prob, cur_complete, cur_seq_list, decoder_input, decoder_hidden) best_prob, best_complete, best_seq, _, _ = max(cur_beam)
if best_complete or len(best_seq) - 1 == config.chatbot_target_max_len: best_seq = [i.item() for i in best_seq]
if best_seq[0] == config.target_ws.SOS:
best_seq = best_seq[1:]
if best_seq[-1] == config.target_ws.EOS:
best_seq = best_seq[:-1]
return best_seq else:
prev_beam = cur_beam class Beam:
"""保存每一个时间步的数据""" def __init__(self):
self.heapq = list()
self.beam_width = config.beam_width def add(self, prob, complete, seq_list, decoder_input, decoder_hidden):
heapq.heappush(self.heapq, [prob, complete, seq_list, decoder_input, decoder_hidden])
# 保证最终只有一个beam width个结果
if len(self.heapq) > self.beam_width:
heapq.heappop(self.heapq) def __iter__(self):
for item in self.heapq:
yield item

  seq2seq.py

"""
完成seq2seq模型
"""
import torch.nn as nn
from chatbot.encoder import Encoder
from chatbot.decoder import Decoder class Seq2Seq(nn.Module):
def __init__(self):
super(Seq2Seq,self).__init__()
self.encoder = Encoder()
self.decoder = Decoder() def forward(self, input,input_len,target):
encoder_outputs,encoder_hidden = self.encoder(input,input_len)
decoder_outputs,decoder_hidden = self.decoder(encoder_hidden,target,encoder_outputs)
return decoder_outputs def evaluate(self,input,input_len):
encoder_outputs, encoder_hidden = self.encoder(input, input_len)
decoder_outputs, predict_result = self.decoder.evaluate(encoder_hidden,encoder_outputs)
return decoder_outputs,predict_result def evaluate_with_beam_search(self,input,input_len):
encoder_outputs, encoder_hidden = self.encoder(input, input_len)
best_seq = self.decoder.evaluate_with_beam_search(encoder_hidden, encoder_outputs)
return best_seq

  eval.py

"""
进行模型的评估
""" import torch
import torch.nn.functional as F
from chatbot.dataset import get_dataloader
from tqdm import tqdm
import config
import numpy as np
import pickle
from chatbot.seq2seq import Seq2Seq def eval():
model = Seq2Seq().to(config.device)
model.eval()
model.load_state_dict(torch.load("./models/model.pkl")) loss_list = []
data_loader = get_dataloader(train=False)
bar = tqdm(data_loader,total=len(data_loader),desc="当前进行评估")
with torch.no_grad():
for idx,(input,target,input_len,target_len) in enumerate(bar):
input = input.to(config.device)
target = target.to(config.device)
input_len = input_len.to(config.device) decoder_outputs,predict_result = model.evaluate(input,input_len) #[batch_Size,max_len,vocab_size]
loss = F.nll_loss(decoder_outputs.view(-1,len(config.target_ws)),target.view(-1),ignore_index=config.input_ws.PAD)
loss_list.append(loss.item())
bar.set_description("idx:{} loss:{:.6f}".format(idx,np.mean(loss_list)))
print("当前的平均损失为:",np.mean(loss_list)) def interface():
from chatbot.cut_sentence import cut
import config
#加载模型
model = Seq2Seq().to(config.device)
model.eval()
model.load_state_dict(torch.load("./models/model.pkl")) #准备待预测的数据
while True:
origin_input =input("me>>:")
# if "你是谁" in origin_input or "你叫什么" in origin_input:
# result = "我是小智。"
# elif "你好" in origin_input or "hello" in origin_input:
# result = "Hello"
# else:
_input = cut(origin_input, by_word=True)
input_len = torch.LongTensor([len(_input)]).to(config.device)
_input = torch.LongTensor([config.input_ws.transform(_input,max_len=config.chatbot_input_max_len)]).to(config.device) outputs,predict = model.evaluate(_input,input_len)
result = config.target_ws.inverse_transform(predict[0])
print("chatbot>>:",result) def interface_with_beamsearch():
from chatbot.cut_sentence import cut
import config
# 加载模型
model = Seq2Seq().to(config.device)
model.eval()
model.load_state_dict(torch.load("./models/model.pkl")) # 准备待预测的数据
while True:
origin_input = input("me>>:")
_input = cut(origin_input, by_word=True)
input_len = torch.LongTensor([len(_input)]).to(config.device)
_input = torch.LongTensor([config.input_ws.transform(_input, max_len=config.chatbot_input_max_len)]).to(
config.device) best_seq = model.evaluate_with_beam_search(_input, input_len)
result = config.target_ws.inverse_transform(best_seq)
print("chatbot>>:", result) if __name__ == '__main__':
# interface()
interface_with_beamsearch()

  

pytorch seq2seq闲聊机器人beam search返回结果的更多相关文章

  1. pytorch seq2seq闲聊机器人

    cut_sentence.py """ 实现句子的分词 注意点: 1. 实现单个字分词 2. 实现按照词语分词 2.1 加载词典 3. 使用停用词 "" ...

  2. pytorch seq2seq闲聊机器人加入attention机制

    attention.py """ 实现attention """ import torch import torch.nn as nn im ...

  3. 实现nlp文本生成中的beam search解码器

    自然语言处理任务,比如caption generation(图片描述文本生成).机器翻译中,都需要进行词或者字符序列的生成.常见于seq2seq模型或者RNNLM模型中. 这篇博文主要介绍文本生成解码 ...

  4. Beam Search快速理解及代码解析(下)

    Beam Search的问题 先解释一下什么要对Beam Search进行改进.因为Beam Search虽然比贪心强了不少,但还是会生成出空洞.重复.前后矛盾的文本.如果你有文本生成经验,一定对这些 ...

  5. Beam Search快速理解及代码解析

    目录 Beam Search快速理解及代码解析(上) Beam Search 贪心搜索 Beam Search Beam Search代码解析 准备初始输入 序列扩展 准备输出 总结 Beam Sea ...

  6. 【NLP】选择目标序列:贪心搜索和Beam search

    构建seq2seq模型,并训练完成后,我们只要将源句子输入进训练好的模型,执行一次前向传播就能得到目标句子,但是值得注意的是: seq2seq模型的decoder部分实际上相当于一个语言模型,相比于R ...

  7. Beam Search快速理解及代码解析(上)

    Beam Search 简单介绍一下在文本生成任务中常用的解码策略Beam Search(集束搜索). 生成式任务相比普通的分类.tagging等NLP任务会复杂不少.在生成的时候,模型的输出是一个时 ...

  8. Beam Search(集束搜索/束搜索)

    找遍百度也没有找到关于Beam Search的详细解释,只有一些比较泛泛的讲解,于是有了这篇博文. 首先给出wiki地址:http://en.wikipedia.org/wiki/Beam_searc ...

  9. 关于Beam Search

    Wiki定义:In computer science, beam search is a heuristic search algorithm that explores a graph by exp ...

随机推荐

  1. CodeForces - 1249E 楼梯和电梯

    题意:第一行输入n和c,表示有n层楼,电梯来到需要时间c 输入两行数,每行n-1个,表示从一楼到二楼,二楼到三楼.....n-1楼到n楼,a[ ] 走楼梯和 b[ ] 乘电梯花费的时间 思路:动态规划 ...

  2. 从使用到原理,探究Java线程池

    什么是线程池 当我们需要处理某个任务的时候,可以新创建一个线程,让线程去执行任务.线程池的字面意思就是存放线程的池子,当我们需要处理某个任务的时候,可以从线程池里取出一条线程去执行. 为什么需要线程池 ...

  3. ArcSDE数据库、文件地理数据库和个人地理数据库的区别

    Geodatabase地理数据库分为: Personal Geodastabase个人地理数据库, File Geodatabase文件地理数据库, ArcSDE Geodatabase SDE地理数 ...

  4. Python第十一章-常用的核心模块01-collections模块

    python 自称 "Batteries included"(自带电池, 自备干粮?), 就是因为他提供了很多内置的模块, 使用这些模块无需安装和配置即可使用. 本章主要介绍 py ...

  5. Codeforces 631 (Div. 2) D. Dreamoon Likes Sequences 位运算^ 组合数 递推

    https://codeforces.com/contest/1330/problem/D 给出d,m, 找到一个a数组,满足以下要求: a数组的长度为n,n≥1; 1≤a1<a2<⋯&l ...

  6. Python执行js之PyexecJs

    利用Python执行js 爬虫中会经常碰到JS加密,当我们找到他加密的JS代码之后我们需要获取它的返回值,python虽然可以模仿js写一个python版本的加密,但是这样有点费时间,也有点费头发~ ...

  7. 《Java基础复习》—常识与入门

    突然发现自己Java基础的底子不到位,复习! 所记知识会发布在CSDN与博客网站jirath.cn <Java基础复习>-常识与入门 一.Java语言的知识体系图 分为三部分 编程语言核心 ...

  8. 1005 Spell It Right (20 分)

    Given a non-negative integer N, your task is to compute the sum of all the digits of N, and output e ...

  9. Python Count函数的应用

    Python Count函数的应用 通过LeetCode Origin:https://leetcode-cn.com/problems/robot-return-to-origin/ 学会了Pyth ...

  10. Zabbix报警机制,Zabbix进阶操作,监控案例

                                                                                                        ...