深度学习之 rnn 台词生成

写一个台词生成的程序,用 pytorch 写的。

import os
def load_data(path):
with open(path, 'r', encoding="utf-8") as f:
data = f.read()
return data text = load_data('./moes_tavern_lines.txt')[81:] train_count = int(len(text) * 0.6)
val_count = int(len(text) * 0.2)
test_count = int(len(text) * 0.2) train_text = text[:train_count]
val_text = text[train_count: train_count + val_count]
test_text = text[train_count + val_count:] view_sentence_range = (0, 10) import numpy as np print("data set State")
print("Roughly the number of unique words: {}".format(len({word: None for word in text.split()})))
scenes = text.split("\n\n")
print("number of scenes: {}".format(len(scenes)))
sentence_count_scene = [scene.count('\n') for scene in scenes]
print('Average number for sentences in each scene: {}'.format(np.average(sentence_count_scene))) sentences = [sentence for scene in scenes for sentence in scene.split('\n')]
print("Number for lines: {}".format(len(sentences)))
word_count_sentence = [len(sentence.split()) for sentence in sentences]
print('Average number for words in each line: {}'.format(np.average(word_count_sentence))) print()
print('The sentences {} to {}:'.format(*view_sentence_range))
print('\n'.join(text.split('\n')[view_sentence_range[0]:view_sentence_range[1]])) def token_lookup():
return {
'.': '||Period||',
',': '||Comma||',
'"': '||Quotation_Mark||',
';': '||Semicolon||',
'!': '||Exclamation_mark||',
'?': '||Question_mark||',
'(': '||Left_Parentheses||',
')': '||Right_Parentheses||',
'--': '||Dash||',
'\n': '||Return||',
} import os
import torch class Dictionary(object):
def __init__(self):
self.word2idx = {}
self.idx2word = [] def add_word(self, word):
if word not in self.word2idx:
self.idx2word.append(word)
self.word2idx[word] = len(self.idx2word) - 1
return self.word2idx[word] def __len__(self):
return len(self.idx2word) class Corpus(object):
def __init__(self, train, val, test):
self.dictionary = Dictionary()
self.train = self.tokenize(train)
self.valid = self.tokenize(val)
self.test = self.tokenize(test) def tokenize(self, text):
words = text.split()
tokens = len(words)
token = 0
ids = torch.LongTensor(tokens)
for i, word in enumerate(words):
self.dictionary.add_word(word)
ids[i] = self.dictionary.word2idx[word] return ids import numpy as np
import torch i_dict = token_lookup() def create_data(text):
vocab_to_int = {}
int_to_vocab = {} new_text = ""
for t in text:
if t in token_lookup():
new_text += " {} ".format(i_dict[t])
else:
new_text += t return new_text import torch
import torch.nn as nn
from torch.autograd import Variable # 模型 RNN
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size, n_layers=1):
super(RNN, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.n_layers = n_layers self.drop = nn.Dropout(0.5) self.encoder = nn.Embedding(input_size, hidden_size) self.gru = nn.GRU(hidden_size, hidden_size, n_layers) self.decoder = nn.Linear(hidden_size, output_size) def forward(self, input, hidden):
input = self.encoder(input)
output, hidden = self.gru(input, hidden)
output = self.drop(output)
decoded = self.decoder(output.view(output.size(0) * output.size(1), output.size(2)))
return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden def init_hidden(self, batch_size):
return Variable(torch.zeros(self.n_layers, batch_size, self.hidden_size)) # batch 化
def batchify(data, bsz):
# Work out how cleanly we can divide the dataset into bsz parts.
nbatch = data.size(0) // bsz
# Trim off any extra elements that wouldn't cleanly fit (remainders).
data = data.narrow(0, 0, nbatch * bsz)
# Evenly divide the data across the bsz batches.
data = data.view(bsz, -1).t().contiguous() return data n_epochs = 3500
print_every = 500
plot_every = 10
hidden_size = 100
n_layers = 1
lr = 0.005
chunk_len = 10
batch_size = 20
val_batch_size = 10 # 数据生成
train_data = create_data(train_text)
test_data = create_data(test_text)
val_data = create_data(val_text) corpus = Corpus(train_data, val_data, test_data) train_source = batchify(corpus.train, batch_size)
test_source = batchify(corpus.test, batch_size)
val_source = batchify(corpus.valid, batch_size) n_tokens = len(corpus.dictionary) # 模型
model = RNN(n_tokens, hidden_size, n_tokens, n_layers) # 优化器
optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 损失函数
criterion = nn.CrossEntropyLoss() #
def get_batch(source, i , evaluation = False):
seq_len = min(chunk_len, len(source) - 1 - i)
data = Variable(source[i:i+seq_len], volatile=evaluation)
target = Variable(source[i+1:i+1+seq_len].view(-1))
return data,target def repackage_hidden(h):
if type(h) == Variable:
return Variable(h.data)
else:
return tuple(repackage_hidden(v) for v in h) # 训练
def train():
model.train()
total_loss = 0 ntokens = len(corpus.dictionary)
hidden = model.init_hidden(batch_size)
for batch, i in enumerate(range(0, train_source.size(0) - 1, chunk_len)):
data, targets = get_batch(train_source, i) hidden = repackage_hidden(hidden)
optimizer.zero_grad()
output, hidden = model(data, hidden)
loss = criterion(output.view(-1, ntokens), targets)
loss.backward()
optimizer.step() total_loss += loss.data if batch % 10 == 0:
print('epoch {}/{} {}'.format(epoch, batch, loss.data)) # 验证
def evaluate(data_source):
model.eval()
total_loss = 0 ntokens = len(corpus.dictionary)
hidden = model.init_hidden(batch_size)
for i in range(0, data_source.size(0) - 1, chunk_len):
data, targets = get_batch(data_source, i, evaluation=True) output, hidden = model(data, hidden)
output_flat = output.view(-1, ntokens)
total_loss += len(data) * criterion(output_flat, targets).data
hidden = repackage_hidden(hidden) return total_loss[0] / len(data_source) import time, math # 开始训练
for epoch in range(1, n_epochs + 1):
train()
val_loss = evaluate(val_source)
print("epoch {} {} {}".format(epoch, val_loss, math.exp(val_loss))) # 生成一段短语
def gen(n_words):
model.eval()
ntokens = len(corpus.dictionary)
hidden = model.init_hidden(1) input = Variable(torch.rand(1, 1).mul(ntokens).long(), volatile=True) words = []
for i in range(n_words):
output, hidden = model(input, hidden)
word_weights = output.squeeze().data.exp().cpu()
word_idx = torch.multinomial(word_weights, 1)[0]
input.data.fill_(word_idx) word = corpus.dictionary.idx2word[word_idx] isOk = False
for w,s in i_dict.items():
if s == word:
isOk = True
words.append(w)
break if not isOk:
words.append(word) return words words = gen(1000)
print(" ".join(words))

总结

rnn 总是参数不怎么对,耐心调整即可。

深度学习之 rnn 台词生成的更多相关文章

  1. 惊不惊喜, 用深度学习 把设计图 自动生成HTML代码 !

    如何用前端页面原型生成对应的代码一直是我们关注的问题,本文作者根据 pix2code 等论文构建了一个强大的前端代码生成模型,并详细解释了如何利用 LSTM 与 CNN 将设计原型编写为 HTML 和 ...

  2. 深度学习-CNN+RNN笔记

    以下叙述只是简单的叙述,CNN+RNN(LSTM,GRU)的应用相关文章还很多,而且研究的方向不仅仅是下文提到的1. CNN 特征提取,用于RNN语句生成图片标注.2. RNN特征提取用于CNN内容分 ...

  3. [深度学习]理解RNN, GRU, LSTM 网络

    Recurrent Neural Networks(RNN) 人类并不是每时每刻都从一片空白的大脑开始他们的思考.在你阅读这篇文章时候,你都是基于自己已经拥有的对先前所见词的理解来推断当前词的真实含义 ...

  4. 用深度学习技术FCN自动生成口红

    1 这个是什么?        基于全卷积神经网络(FCN)的自动生成口红Python程序. 图1 FCN生成口红的效果(注:此两张人脸图来自人脸公开数据库LFW) 2 怎么使用了?        首 ...

  5. 4.keras实现-->生成式深度学习之用GAN生成图像

    生成式对抗网络(GAN,generative adversarial network)由Goodfellow等人于2014年提出,它可以替代VAE来学习图像的潜在空间.它能够迫使生成图像与真实图像在统 ...

  6. 【深度学习】RNN | GRU | LSTM

    目录: 1.RNN 2.GRU 3.LSTM 一.RNN 1.RNN结构图如下所示: 其中: $a^{(t)} = \boldsymbol{W}h^{t-1} + \boldsymbol{W}_{e} ...

  7. 机器学习(Machine Learning)&深度学习(Deep Learning)资料【转】

    转自:机器学习(Machine Learning)&深度学习(Deep Learning)资料 <Brief History of Machine Learning> 介绍:这是一 ...

  8. 机器学习(Machine Learning)与深度学习(Deep Learning)资料汇总

    <Brief History of Machine Learning> 介绍:这是一篇介绍机器学习历史的文章,介绍很全面,从感知机.神经网络.决策树.SVM.Adaboost到随机森林.D ...

  9. CNCC2017中的深度学习与跨媒体智能

    CNCC2017中的深度学习与跨媒体智能 转载请注明作者:梦里茶 目录 机器学习与跨媒体智能 传统方法与深度学习 图像分割 小数据集下的深度学习 语音前沿技术 生成模型 基于贝叶斯的视觉信息编解码 珠 ...

随机推荐

  1. 【经验随笔】Java通过代理访问互联网平台提供的WebService接口的一种方法

    背景 通常有两点原因需要通过代理访问互联网平台的提供的WebService接口: 1. 在公司企业内网访问外部互联网平台发布的接口,公司要求通过代理访问外网. 2. 频繁访问平台接口,IP被平台封了, ...

  2. 完全卸载hadoop安装的组件(hdp版本)

    yum remove -y hadoop_* zookeeper* ranger* hbase_* ranger* hbase_* ambari-* hadoop_* zookeeper_* hbas ...

  3. javascript获取系统时间

    function GetDateStr(AddDayCount) { var dd = new Date(); dd.setDate(dd.getDate()+AddDayCount); var ye ...

  4. js文本框字符数输入限制

    我们常常在前台页面做一些文本输入长度的验证,为什么呢?因为数据库字段设置了大小,如果不限制输入长度,那么写入库时就会引发字符串截断异常.今天就给大家分享一个jquery插件来解决这一问题. (func ...

  5. python xlsxwriter库生成图表的应用

    xlsxwriter可能用过的人并不是很多,不过使用后就会感觉,他的功能让你叹服,除了可以按要求生成你所需要的excel外 还可以加上很形象的各种图,比如柱状图.饼图.折线图等. 请看本人生成的: 这 ...

  6. Injection of autowired dependencies failed

    error:org.springframework.beans.factory.BeanCreationException: Error creating bean with name 'mainCo ...

  7. MySQL多数据源笔记3-分库分表理论和各种中间件

    一.使用中间件的好处 使用中间件对于主读写分离新增一个从数据库节点来说,可以不用修改代码,达到新增节点数据库而不影响到代码的修改.因为如果不用中间件,那么在代码中自己是先读写分离,如果新增节点, 你进 ...

  8. 新装的Linux服务系统安装MySQL

    目的描述:全新的腾讯云Linux服务器,系统是ubuntu 16.04.需要在上面安装mysql数据库. 使用XShell远程登录,在终端窗口中使用sudo apt-get 指令在线安装mysql. ...

  9. 前端dom元素的操作优化建议

    参考自:http://blog.csdn.net/xuexiaodong009/article/details/51810252 其实在web开发中,单纯因为js导致性能问题的很少,主要都是因为DOM ...

  10. 连不上虚拟机中的Redis的原因分析、以及虚拟机网络配置

    1. 网络最好是桥接方式.我之前用的是"网络地址转换(NAT)",导致虚拟机里用命令ifconfig得到的ip是10.0.2.15,好奇怪的感觉,然后在真实机上一直连不上.有的说用 ...