深度学习之 rnn 台词生成

写一个台词生成的程序,用 pytorch 写的。

import os
def load_data(path):
with open(path, 'r', encoding="utf-8") as f:
data = f.read()
return data text = load_data('./moes_tavern_lines.txt')[81:] train_count = int(len(text) * 0.6)
val_count = int(len(text) * 0.2)
test_count = int(len(text) * 0.2) train_text = text[:train_count]
val_text = text[train_count: train_count + val_count]
test_text = text[train_count + val_count:] view_sentence_range = (0, 10) import numpy as np print("data set State")
print("Roughly the number of unique words: {}".format(len({word: None for word in text.split()})))
scenes = text.split("\n\n")
print("number of scenes: {}".format(len(scenes)))
sentence_count_scene = [scene.count('\n') for scene in scenes]
print('Average number for sentences in each scene: {}'.format(np.average(sentence_count_scene))) sentences = [sentence for scene in scenes for sentence in scene.split('\n')]
print("Number for lines: {}".format(len(sentences)))
word_count_sentence = [len(sentence.split()) for sentence in sentences]
print('Average number for words in each line: {}'.format(np.average(word_count_sentence))) print()
print('The sentences {} to {}:'.format(*view_sentence_range))
print('\n'.join(text.split('\n')[view_sentence_range[0]:view_sentence_range[1]])) def token_lookup():
return {
'.': '||Period||',
',': '||Comma||',
'"': '||Quotation_Mark||',
';': '||Semicolon||',
'!': '||Exclamation_mark||',
'?': '||Question_mark||',
'(': '||Left_Parentheses||',
')': '||Right_Parentheses||',
'--': '||Dash||',
'\n': '||Return||',
} import os
import torch class Dictionary(object):
def __init__(self):
self.word2idx = {}
self.idx2word = [] def add_word(self, word):
if word not in self.word2idx:
self.idx2word.append(word)
self.word2idx[word] = len(self.idx2word) - 1
return self.word2idx[word] def __len__(self):
return len(self.idx2word) class Corpus(object):
def __init__(self, train, val, test):
self.dictionary = Dictionary()
self.train = self.tokenize(train)
self.valid = self.tokenize(val)
self.test = self.tokenize(test) def tokenize(self, text):
words = text.split()
tokens = len(words)
token = 0
ids = torch.LongTensor(tokens)
for i, word in enumerate(words):
self.dictionary.add_word(word)
ids[i] = self.dictionary.word2idx[word] return ids import numpy as np
import torch i_dict = token_lookup() def create_data(text):
vocab_to_int = {}
int_to_vocab = {} new_text = ""
for t in text:
if t in token_lookup():
new_text += " {} ".format(i_dict[t])
else:
new_text += t return new_text import torch
import torch.nn as nn
from torch.autograd import Variable # 模型 RNN
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size, n_layers=1):
super(RNN, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.n_layers = n_layers self.drop = nn.Dropout(0.5) self.encoder = nn.Embedding(input_size, hidden_size) self.gru = nn.GRU(hidden_size, hidden_size, n_layers) self.decoder = nn.Linear(hidden_size, output_size) def forward(self, input, hidden):
input = self.encoder(input)
output, hidden = self.gru(input, hidden)
output = self.drop(output)
decoded = self.decoder(output.view(output.size(0) * output.size(1), output.size(2)))
return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden def init_hidden(self, batch_size):
return Variable(torch.zeros(self.n_layers, batch_size, self.hidden_size)) # batch 化
def batchify(data, bsz):
# Work out how cleanly we can divide the dataset into bsz parts.
nbatch = data.size(0) // bsz
# Trim off any extra elements that wouldn't cleanly fit (remainders).
data = data.narrow(0, 0, nbatch * bsz)
# Evenly divide the data across the bsz batches.
data = data.view(bsz, -1).t().contiguous() return data n_epochs = 3500
print_every = 500
plot_every = 10
hidden_size = 100
n_layers = 1
lr = 0.005
chunk_len = 10
batch_size = 20
val_batch_size = 10 # 数据生成
train_data = create_data(train_text)
test_data = create_data(test_text)
val_data = create_data(val_text) corpus = Corpus(train_data, val_data, test_data) train_source = batchify(corpus.train, batch_size)
test_source = batchify(corpus.test, batch_size)
val_source = batchify(corpus.valid, batch_size) n_tokens = len(corpus.dictionary) # 模型
model = RNN(n_tokens, hidden_size, n_tokens, n_layers) # 优化器
optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 损失函数
criterion = nn.CrossEntropyLoss() #
def get_batch(source, i , evaluation = False):
seq_len = min(chunk_len, len(source) - 1 - i)
data = Variable(source[i:i+seq_len], volatile=evaluation)
target = Variable(source[i+1:i+1+seq_len].view(-1))
return data,target def repackage_hidden(h):
if type(h) == Variable:
return Variable(h.data)
else:
return tuple(repackage_hidden(v) for v in h) # 训练
def train():
model.train()
total_loss = 0 ntokens = len(corpus.dictionary)
hidden = model.init_hidden(batch_size)
for batch, i in enumerate(range(0, train_source.size(0) - 1, chunk_len)):
data, targets = get_batch(train_source, i) hidden = repackage_hidden(hidden)
optimizer.zero_grad()
output, hidden = model(data, hidden)
loss = criterion(output.view(-1, ntokens), targets)
loss.backward()
optimizer.step() total_loss += loss.data if batch % 10 == 0:
print('epoch {}/{} {}'.format(epoch, batch, loss.data)) # 验证
def evaluate(data_source):
model.eval()
total_loss = 0 ntokens = len(corpus.dictionary)
hidden = model.init_hidden(batch_size)
for i in range(0, data_source.size(0) - 1, chunk_len):
data, targets = get_batch(data_source, i, evaluation=True) output, hidden = model(data, hidden)
output_flat = output.view(-1, ntokens)
total_loss += len(data) * criterion(output_flat, targets).data
hidden = repackage_hidden(hidden) return total_loss[0] / len(data_source) import time, math # 开始训练
for epoch in range(1, n_epochs + 1):
train()
val_loss = evaluate(val_source)
print("epoch {} {} {}".format(epoch, val_loss, math.exp(val_loss))) # 生成一段短语
def gen(n_words):
model.eval()
ntokens = len(corpus.dictionary)
hidden = model.init_hidden(1) input = Variable(torch.rand(1, 1).mul(ntokens).long(), volatile=True) words = []
for i in range(n_words):
output, hidden = model(input, hidden)
word_weights = output.squeeze().data.exp().cpu()
word_idx = torch.multinomial(word_weights, 1)[0]
input.data.fill_(word_idx) word = corpus.dictionary.idx2word[word_idx] isOk = False
for w,s in i_dict.items():
if s == word:
isOk = True
words.append(w)
break if not isOk:
words.append(word) return words words = gen(1000)
print(" ".join(words))

总结

rnn 总是参数不怎么对,耐心调整即可。

深度学习之 rnn 台词生成的更多相关文章

  1. 惊不惊喜, 用深度学习 把设计图 自动生成HTML代码 !

    如何用前端页面原型生成对应的代码一直是我们关注的问题,本文作者根据 pix2code 等论文构建了一个强大的前端代码生成模型,并详细解释了如何利用 LSTM 与 CNN 将设计原型编写为 HTML 和 ...

  2. 深度学习-CNN+RNN笔记

    以下叙述只是简单的叙述,CNN+RNN(LSTM,GRU)的应用相关文章还很多,而且研究的方向不仅仅是下文提到的1. CNN 特征提取,用于RNN语句生成图片标注.2. RNN特征提取用于CNN内容分 ...

  3. [深度学习]理解RNN, GRU, LSTM 网络

    Recurrent Neural Networks(RNN) 人类并不是每时每刻都从一片空白的大脑开始他们的思考.在你阅读这篇文章时候,你都是基于自己已经拥有的对先前所见词的理解来推断当前词的真实含义 ...

  4. 用深度学习技术FCN自动生成口红

    1 这个是什么?        基于全卷积神经网络(FCN)的自动生成口红Python程序. 图1 FCN生成口红的效果(注:此两张人脸图来自人脸公开数据库LFW) 2 怎么使用了?        首 ...

  5. 4.keras实现-->生成式深度学习之用GAN生成图像

    生成式对抗网络(GAN,generative adversarial network)由Goodfellow等人于2014年提出,它可以替代VAE来学习图像的潜在空间.它能够迫使生成图像与真实图像在统 ...

  6. 【深度学习】RNN | GRU | LSTM

    目录: 1.RNN 2.GRU 3.LSTM 一.RNN 1.RNN结构图如下所示: 其中: $a^{(t)} = \boldsymbol{W}h^{t-1} + \boldsymbol{W}_{e} ...

  7. 机器学习(Machine Learning)&深度学习(Deep Learning)资料【转】

    转自:机器学习(Machine Learning)&深度学习(Deep Learning)资料 <Brief History of Machine Learning> 介绍:这是一 ...

  8. 机器学习(Machine Learning)与深度学习(Deep Learning)资料汇总

    <Brief History of Machine Learning> 介绍:这是一篇介绍机器学习历史的文章,介绍很全面,从感知机.神经网络.决策树.SVM.Adaboost到随机森林.D ...

  9. CNCC2017中的深度学习与跨媒体智能

    CNCC2017中的深度学习与跨媒体智能 转载请注明作者:梦里茶 目录 机器学习与跨媒体智能 传统方法与深度学习 图像分割 小数据集下的深度学习 语音前沿技术 生成模型 基于贝叶斯的视觉信息编解码 珠 ...

随机推荐

  1. LTS和其他解决方案的比较(官方)

    主要根据LTS支持的几种任务(实时任务.定时任务.Cron任务,Repeat任务)和其他一些 开源框架在应用场景上做比较. 实时任务,实时执行 这种场景下,当任务量比较小的时候,单机都可以完成的时候. ...

  2. Angular4---部署---将Angular项目部署到IIS上

    ---恢复内容开始--- Angular项目部署到一个IIS服务器上 1.安装URL rewrite组件: 网址:https://www.microsoft.com/en-us/download/de ...

  3. WP Super Cache+七牛云配置CDN加速,让你的网站秒开

    CDN加速网站是几乎所有的站长都在考虑的问题,CDN,全称是Content Delivery Network,即内容分发网络.所谓CDN加速,通俗的来说就是把原服务器上数据复制到其他服务器上,用户访问 ...

  4. IPFS的竞争对手们(一)

    IPFS的竞争对手 IPFS这个项目真的开发很慢,相比其它区块链项目,IPFS的进度可真是让小编捉急,恨铁不成钢啊.然而小编仍然对他们充满信心,来,借用一句盗梦空间里的经典台词: 既然做梦,那就做大点 ...

  5. PHP实现KMP算法

    KMP算法是一种比较高效的字符串匹配算法,关于其讲解,可参考文章 字符串匹配的KMP算法,本文只给出具体的PHP代码实现. /** * @desc构建next数组 * @param string $s ...

  6. Mecanim之IK动画

    序言:IK动画全名是Inverse Kinematics 意思是逆向动力学,就是子骨骼节点带动父骨骼节点运动. 比如体操运动员,只靠手来带动身体各个部位的移动.手就是子骨骼,身体就是它的父骨骼,这时运 ...

  7. 【Python】 sys和os模块

    sys sys模块能使程序访问于python解释器联系紧密的变量和函数 ● sys中的一些函数和变量 argv 命令行参数构成的列表 path 查找所有可用模块所在的目录名的列表 platform 查 ...

  8. iOS CocoaPods一些特别的用法 指定版本、版本介绍、忽略警告

    简介 介绍一些CocoaPods一些特别的用法 CocoaPods github地址 CocoaPods 官方地址 1. 指定第三方库版本 1. 固定版本 target 'MyApp' do use_ ...

  9. 大数据 --> 分布式文件系统HDFS的工作原理

    分布式文件系统HDFS的工作原理 Hadoop分布式文件系统(HDFS)是一种被设计成适合运行在通用硬件上的分布式文件系统.HDFS是一个高度容错性的系统,适合部署在廉价的机器上.它能提供高吞吐量的数 ...

  10. 网络通信 --> 消息队列

    消息队列 消息队列提供了一种从一个进程向另一个进程发送一个数据块的方法.可以通过发送消息来避免命名管道的同步和阻塞问题.但是消息队列与命名管道一样,每个数据块都有一个最大长度的限制. Linux用宏M ...