简单明朗的 RNN 写诗教程
简单明朗的 RNN 写诗教程
本来想做一个标题党的,取了一个史上最简单的 RNN 写诗教程这标题,但是后来想了想,这TM不就是标题党吗?怎么活成了自己最讨厌的模样?后来就改成了这个标题。
在上篇博客网络流量预测入门(一)之RNN 介绍中,介绍了RNN的原理,而在这篇博客中,将介绍如何使用keras构建RNN,然后自动写诗。
项目地址:Github:https://github.com/xiaohuiduan/rnn_chinese_poetry
数据集介绍
既然是写诗,当然得有数据集,不过还好有大神已经将数据集准备好了,具体数据集的来源已不可知,因为网上基本上都是使用这个数据集。(如果有人知道,可以在评论区指出,然后我再添加上)
数据集地址:Github,数据集部分数据如下所示:
在数据集中,每一行都是一首唐诗,其中,诗的题目和内容以 ":" 分开,每一首诗都有题目,但是不一定有内容(也就是说内容可能为空)。其中,诗内容中的标点符号都是全角符号。有一些诗五言诗,不过也有一些诗不是五言的。当然,我们只考虑五言诗(大概有27k首)。
代码思路
输入 and 输出
首先我们得先弄清我们要干什么,然后才能更好得写代码。如标题所示,目的是使用RNN写诗,那么必然有输入和输出。那么问题来了,RNN的输入是什么,输出是什么?
我们希望rnn能够写诗,那么怎么写呢?我们这样定义如下的方式:
![](imgs/rnn_io (1).svg)
RNN接受 6个字符(5个字+一个标点符号),然后输出下一个字符。至于怎么生成一首完整的诗词,等到后面讨论。
RNN当然不能够直接接受 "床前明月光," 这个中文的输入,我们要对其进行 Encode,变成数字,然后才能够输入到RNN网络中。同理,RNN输出的肯定也不是一个中文字符,我们也要对其进行Decode 才能将输出变成一个中文字符。
怎么进行Encode,有一个很简单的方法,那就是进行one-hot编码,对于每一个字(包括标点符号在内)我们都进行onehot编码,这样就可以了。但实际上,这个这样会有一点小问题。在数据集中,所有符合条件的诗,大概由近 7,000 个字符组成,如果对每一个字都进行onehot编码的话,就会消耗大量的内存,同时也会加大计算的复杂度。
因此,我们定义如下:只对前出现频率最多的 2999 个字符进行 one-hot 编码,对于剩下的字,用 “ ”(空格字符)代替。这样一共只需要对3000个字符进行one-hot编码就了(2999个字符+一个空格字符)。
训练集构建
在前面我们定义了RNN的输入和输出,同时也有诗的数据集, 那么我们构建训练集呢?参考RNN模型与NLP应用(6/9):Text Generation (自动文本生成)
具体步骤如下图所示:我们将一句诗可以进行如下切分。然后将切分得到的数据进行one-hot编码,然后进行训练即可。(这样看来,每一首诗可以生成很多的数据集)
![](imgs/RNN_split_data (1).svg)
生成一首完整的诗
前面我们讨论了关于网络的输入和输出,以及数据集的构建,那么,假如我们有一个已经训练好的模型,如何来产生一首诗的?
生成一首完整的诗的流程如下所示,与训练的操作有点类似,只不过会将RNN的输出重新当作RNN的输入。(以此来产生符合字数要求的诗)
经过上述的操作,大家实际上可以尝试的写一些代码了,基本上不会有很大的问题。接下来,我将讲一讲具体怎么实现。
代码实现
首先定义一些配置:
- DISALLOWED_WORDS:如果在诗中出现了DISALLOWED_WORDS,则舍弃这首诗。
# 诗data的地址
poetry_data_path = "./data/poetry.txt"
# 如果诗词中出现这些词,则将诗舍弃
DISALLOWED_WORDS = ['(', ')', '(', ')', '__', '《', '》', '【', '】', '[', ']']
# 取3000个字作诗,其中包括空格字符
WORD_NUM = 3000
# 将出现少的字使用空格代替
UNKONW_CHAR = " "
# 根据前6个字预测下一个字,比如说根据“寒随穷律变,”预测“春”
TRAIN_NUM = 6
读取文件
针对于数据集,我们有如下的要求:
- 必须是五言诗(不过下面的代码无法完全保证是五言诗),同时至少要有两句诗
- 不能出现上文中定义的DISALLOWED_WORDS
前面我们说了,每一首诗必有题目和内容(内容可以为空),其中,题目和内容以 ":"(半角)分开,因此,我们可以通过 line.split(":")[1]
获得诗的内容。
下述代码实现了两个功能:
- 获得符合要求的诗:
(len(poetry)-1) % 6
,每一首五言诗,包括“,。”一共有\(6*n\) 个字,同时每一首诗是以 "\n" 结尾的,因为我们(len(poetry)-1)%6==0
则就代表符合要求。同时五言诗的第6个字符是","——> 使用poetrys
保存。 - 获得诗中出现的字符。——>使用
all_word
保存。
# 保存诗词
poetrys = []
# 保存在诗词中出现的字
all_word = []
with open(poetry_data_path,encoding="utf-8") as f:
for line in f:
# 获得诗的内容
poetry = line.split(":")[1].replace(" ","")
flag = True
# 如果在句子中出现'(', ')', '(', ')', '__', '《', '》', '【', '】', '[', ']'则舍弃
for dis_word in DISALLOWED_WORDS:
if dis_word in poetry:
flag = False
break
# 只需要5言的诗(两句诗包括标点符号就是12个字),假如少于两句诗则舍弃
if len(poetry) < 12 or poetry[5] != ',' or (len(poetry)-1) % 6 != 0:
flag = False
if flag:
# 统计出现的词
for word in poetry:
all_word.append(word)
poetrys.append(poetry)
统计字数
前面我们说过,在数据集中,所有符合条件的诗,大概由近 7,000 个字组成,如果对每一个字都进行one-hot编码的话,就会浪费大量的内存,加大计算的复杂度。解决方法可以这样做:
使用Counter对字数进行统计,然后根据出现的次数进行排序,最后得到出现频率最多的2999个字。
from collections import Counter
# 对字数进行统计
counter = Counter(all_word)
# 根据出现的次数,进行从大到小的排序
word_count = sorted(counter.items(),key=lambda x : -x[1])
most_num_word,_ = zip(*word_count)
# 取前2999个字,然后在最后加上" "
use_words = most_num_word[:WORD_NUM - 1] + (UNKONW_CHAR,)
构建word 与 id的映射
我们需要对word进行onehot编码,怎么编呢?很简单,每一个word对应一个id,然后对这个id进行one-hot编码就行了。因此我们需要构建word到id的映射。
举个例子:如果一共只有3个字“唐”,“宋”,“明”,然后我们可以构建如下的映射:
"唐" ——> 0 ;"宋"——>1;"明"——>2;进行one-hot编码后,则就变成了:
- 唐:[1,0,0]
- 宋:[0,1,0]
- 明:[0,0,1]
构建word与id的映射是必须的,经过如下简单的代码,便构成了映射。
# word 到 id的映射 {',': 0,'。': 1,'\n': 2,'不': 3,'人': 4,'山': 5,……}
word_id_dict = {word:index for index,word in enumerate(use_words)}
# id 到 word的映射 {0: ',',1: '。',2: '\n',3: '不',4: '人',5: '山',……}
id_word_dict = {index:word for index,word in enumerate(use_words)}
转成one-hot代码
下面定义两个函数:
word_to_one_hot将一个字转成one-hot 形式
phrase_to_one_hot 将一个句子转成one-hot形式
import numpy as np
def word_to_one_hot(word):
"""将一个字转成onehot形式
:param word: [一个字]
:type word: [str]
"""
one_hot_word = np.zeros(WORD_NUM)
# 假如字是生僻字,则变成空格
if word not in word_id_dict.keys():
word = UNKONW_CHAR
index = word_id_dict[word]
one_hot_word[index] = 1
return one_hot_word
def phrase_to_one_hot(phrase):
"""将一个句子转成onehot
:param phrase: [一个句子]
:type poetry: [str]
"""
one_hot_phrase = []
for word in phrase:
one_hot_phrase.append(word_to_one_hot(word))
return one_hot_phrase
随机打乱数据
np.random.shuffle(poetrys)
构建训练集
然后我们需要进行如下操作,根据诗构建数据集(one-hot编码之前的数据集)。
![](imgs/RNN_split_data (1).svg)
构建数据集的时候我们需要注意一件事情,需要区分不同的诗(因为我们总不可能用A的诗去预测B的诗噻,hhh)。每一首诗都是以 "\n" 结尾的,因此,当循环到"\n"时,就代表对于这首诗,我们已经构建好数据集了(上图中的X_Data【用X_train_word
表示】,Y_Data【用Y_train_word
表示】)。
X_train_word = []
Y_train_word = []
for poetry in poetrys:
for i in range(len(poetry)):
X = poetry[i:i+TRAIN_NUM]
Y = poetry[i+TRAIN_NUM]
if "\n" not in X and "\n" not in Y:
X_train_word.append(X)
Y_train_word.append(Y)
else:
break
在没有打乱顺序的情况下,部分结果如下所示:
构建模型
使用的框架:
- keras:2.4.3:如果想使用我训练好的模型,请保持版本一致。如果自己训练的话,就无所谓了。
模型图如下所示,模型结构参考Poems_generator_Keras,关于SimpleRNN的介绍可以参考Keras-SimpleRNN,关于如何使用keras构建神经网络可以参考数据挖掘入门系列教程(十一)之keras入门使用以及构建DNN网络识别MNIST。
在前面说了,RNN模型输入的是一个 6个字符 的句子,因此经过one-hot编码后就会变成shape为(6,3000)的数组,而输出为一个字符,对应one-hot编码的shape为(3000)。
代码如下所示:
import keras
from keras.callbacks import LambdaCallback,ModelCheckpoint
from keras.models import Input, Model
from keras.layers import Dropout, Dense,SimpleRNN
from keras.optimizers import Adam
from keras.utils import plot_model
def build_model():
print('building model')
# 输入的dimension
input_tensor = Input(shape=(TRAIN_NUM,WORD_NUM))
rnn = SimpleRNN(512,return_sequences=True)(input_tensor)
dropout = Dropout(0.6)(rnn)
rnn = SimpleRNN(256)(dropout)
dropout = Dropout(0.6)(rnn)
dense = Dense(WORD_NUM, activation='softmax')(dropout)
model = Model(inputs=input_tensor, outputs=dense)
optimizer = Adam(lr=0.001)
model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
model.summary()
# 画出模型图
# plot_model(model, to_file='model.png', show_shapes=True, expand_nested=True, dpi=500)
return model
对于SimpleRNN,如果return_sequences=True
,则代表其返回如下:
如果return_sequences=False
(默认),则代表返回如下所示:
这个模型中,套了两层RNN。
model = build_model()
批加载数据
这次数据集比较大,一共有\(1559196\)份数据,一般来说没有这么大内存将其所有的数据一次性全部转成one-hot形式。
因此,我们可以这样做:在训练的时候才开始加载数据,每一次只需要加载batch_size的数据,然后只需要将batch_size 大小的数据转成one-hot形式,然后进行训练。在这种情况下,只需要将batch-size的数据转成one-hot,可以大大减小内存消耗。
so,使用keras训练的时候,不能使用fit(因为fit需要一次将数据集全部放入RAM中),而应该使用fit_generator,关于其使用推荐看看:Keras 如何使用fit和fit_generator
import math
def get_batch(batch_size = 32):
"""源源不断产生产生one-hot编码的训练数据
:param batch_size: [一次产生训练数据的大小], defaults to 32
:type batch_size: int, optional
:yield: [返回X(np.array(X_train_batch))和Y(np.array(Y_train_batch))]
:rtype: [X.shape为(batch_size, 6, 3000) , Y.shape数据的shape(batch_size, 3000)]
"""
# 确定每轮有多少个batch
steps = math.ceil(len(X_train_word) / batch_size)
while True:
for i in range(steps):
X_train_batch = []
Y_train_batch = []
X_batch_datas = X_train_word[i*batch_size:(i+1)*batch_size]
Y_batch_datas = Y_train_word[i*batch_size:(i+1)*batch_size]
for x,y in zip(X_batch_datas,Y_batch_datas):
X_train_batch.append(phrase_to_one_hot(x))
Y_train_batch.append(word_to_one_hot(y))
yield np.array(X_train_batch),np.array(Y_train_batch)
训练的过程中生成诗句
在训练的过程中,可以每经过一定数量的epoch生成一首诗,生成诗的操作如下:
在训练的过程中,调用generate_sample_result
,即可产生五言诗,然后将生成的诗写入到out/out.txt
中。
def predict_next(x):
""" 根据X预测下一个字符
:param x: [输入数据]
:type x: [x的shape为(1,TRAIN_NUM,WORD_NUM)]
:return: [最大概率字符的索引,有可能为为2999,也就是预测的字符可能为“ ”]
:rtype: [int]
"""
predict_y = model.predict(x)[0]
# 获得最大概率的索引
index = np.argmax(predict_y)
return index
def generate_sample_result(epoch, logs):
"""生成五言诗
:param epoch: [目前模型训练的epoch]
:type epoch: [int]
:param logs: [模型训练日志]
:type logs: [list]
"""
# 每个epoch都产生输出
if epoch % 1 == 0:
# 根据“一朝春夏改,”生成诗
predict_sen = "一朝春夏改,"
predict_data = predict_sen
# 生成的4句五言诗(4 * 6 = 24)
while len(predict_sen) < 24:
X_data = np.array(phrase_to_one_hot(predict_data)).reshape(1,TRAIN_NUM,WORD_NUM)
# 根据6个字符预测下一个字符
y = predict_next(X_data)
predict_sen = predict_sen+ id_word_dict[y]
# “寒随穷律变,” ——> “随穷律变,春”
predict_data = predict_data[1:]+id_word_dict[y]
# 将数据写入文件
with open('out/out.txt', 'a',encoding='utf-8') as f:
f.write(write_data+'\n')
开始训练
在训练的时候,每隔一个epoch,都会将模型进行保存,每个epoch完成的时候,都会调用generate_sample_result
生成诗。
batch_size = 2048
model.fit_generator(
generator=get_batch(batch_size),
verbose=True,
steps_per_epoch=math.ceil(len(X_train_word) / batch_size),
epochs=1000000,
callbacks=[
ModelCheckpoint("poetry_model.hdf5",verbose=1,monitor='val_loss',period=1),
# 每次完成一个epoch会调用generate_sample_result产生五言诗
LambdaCallback(on_epoch_end=generate_sample_result)
]
)
因为我的电脑就是一个mx250小水管,我就放在kaggle上面跑了,毕竟白嫖它不香吗?如果实在想自己跑,但是有没有比较好的GPU,可以尝试将len(X_train_word)
改成其他的数,比如说“100000”。要在如下的两个地方改,这样的话,很快就可以出训练的结果。(这样会导致训练的时候无法覆盖整个数据集。)
诗词生成
我在Github中提供了训练好的模型(注意keras版本是2.4.3),在 test.ipynb 中提供了如何加载模型然后生成诗句的方法,在这里就不赘述了。
最后简单的展示一下生成的结果(实际上模型训练的效果并不是很好,
简单明朗的 RNN 写诗教程的更多相关文章
- 深度学习(三)之LSTM写诗
目录 数据预处理 构建数据集 模型结构 生成诗 根据上文生成诗 生成藏头诗 参考 根据前文生成诗: 机器学习业,圣贤不可求.临戎辞蜀计,忠信尽封疆.天子咨两相,建章应四方.自疑非俗态,谁复念鹪鹩. 生 ...
- idea搭建简单ssm框架的最详细教程(新)
为开发一个测试程序,特搭建一个简单的ssm框架,因为网上看到很多都是比较老旧的教程,很多包都不能用了,eclipes搭建并且其中还附带了很多的其他东西,所以特此记录一下mac中idea搭建过程. 另: ...
- 最简单的SQLserver,发布订阅教程,保证一次就成功
最简单的SQLserver,发布订阅教程,保证一次就成功 发布订阅用来做数据库的读写分离,还是很好用的 当单台数据库的压力太大时,可以考虑这种方案,一主多从,主服务器的数据库只管写入,其他的数据库都是 ...
- 为你写诗:3 步搭建 Serverless AI 应用
作者 | 杜万(倚贤) 阿里巴巴技术专家 本文整理自 1 月 2 日社群分享,每月 2 场高质量分享,点击加入社群. 关注"阿里巴巴云原生"公众号,回复关键词 0102 即可下载本 ...
- 【阿里云产品公测】简单日志服务SLS使用评测 + 教程
[阿里云产品公测]简单日志服务SLS使用评测 + 教程 评测介绍 被测产品: 简单日志服务SLS 评测环境: 阿里云基础ECS x2(1核, 512M, 1M) 操作系统: CentOS 6.5 x6 ...
- js写插件教程深入
原文地址:https://github.com/lianxiaozhuang/blog 转载请注明出处 js 写插件教程深入 1.介绍具有安全作用域的构造函数 function Fn(name){ t ...
- Qt侠:像写诗一样写代码,玩游戏一样的开心心情,还能领工资!
[软]上海-Qt侠 2017/7/12 16:11:20我完全是兴趣主导,老板不给我钱,我也要写好代码!白天干,晚上干,周一周五干,周末继续干!编程已经深入我的基因,深入我的骨髓,深入我的灵魂!当我解 ...
- AI:为你写诗,为你做不可能的事
最近,一档全程高能的神仙节目,高调地杀入了我们的视野: 没错,就是撒贝宁主持,董卿.康辉等央视名嘴作为评审嘉宾,同时集齐央视"三大名嘴"同台的央视<主持人大赛>,这够不 ...
- 最新最最最简单的Snagit傻瓜式破解教程(带下载地址)
最新最最最简单的Snagit傻瓜式破解教程(带下载地址) 下载地址 直接滑至文章底部下载 软件介绍 一个非常著名的优秀屏幕.文本和视频捕获.编辑与转换软件.可以捕获Windows屏幕.DOS屏幕:RM ...
随机推荐
- DVWA-文件包含-目录遍历学习笔记
参考文献资料: https://www.cnblogs.com/s0ky1xd/p/5823685.html https://www.cnblogs.com/yuzly/p/10799486.html ...
- RabbitMq基本概念理解
RabbitMQ的基本概念 RabbitMQ github项目地址 RabbitMQ 2007年发布,是一个在AMQP(高级消息队列协议)基础上完成的,可复用的企业消息系统,是当前最主流的 消息中间件 ...
- .NET5 API 网关Ocelot+Consul服务注册
1|0网关介绍 网关其实就是将我们写好的API全部放在一个统一的地址暴露在公网,提供访问的一个入口.在 .NET Core下可以使用Ocelot来帮助我们很方便的接入API 网关.与之类似的库还有Pr ...
- Git的使用以及命令
个人常用命令 git初始化操作 git init 把当前的目录变成git仓库,生成隐藏.git文件. git remote add origin url 把本地仓库的内容推送到GitHub仓库. gi ...
- ASP.NET网站部署到服务器IIS上和本地局域网服务器
控制面板>>>管理工具>>>打开Internet信息服务 2,如果找不到 可以控制面板>>>程序和功能>>> 打开或关闭win ...
- (五)cp命令复制文件或者目录
一.cp的含义.功能及命令格式 cp(英文copy的缩写)命令可以将一个文件或者目录从一个位置复制到另外一个位置.cp的功能就是将一个文件复制成 一个指定的目的文件或者复制到一个指定的目录中,兼具复制 ...
- 2.自定义view-QQ运动步数
1.效果 2.实现 2.1自定义属性 在res/values 文件夹中新建xx.xml,内容如下 <?xml version="1.0" encoding="utf ...
- [LeetCode]547. Friend Circles朋友圈数量--不相邻子图问题
/* 思路就是遍历所有人,对于每一个人,寻找他的好友,找到好友后再找这个好友的好友 ,这样深度优先遍历下去,设置一个flag记录是否已经遍历了这个人. 其实dfs真正有用的是flag这个变量,因为如果 ...
- [LeetCode]501. Find Mode in Binary Search Tree二叉搜索树寻找众数
这次是二叉搜索树的遍历 感觉只要和二叉搜索树的题目,都要用到一个重要性质: 中序遍历二叉搜索树的结果是一个递增序列: 而且要注意,在递归遍历树的时候,有些参数如果是要随递归不断更新(也就是如果递归返回 ...
- [Machine Learning] 多变量线性回归(Linear Regression with Multiple Variable)-特征缩放-正规方程
我们从上一篇博客中知道了关于单变量线性回归的相关问题,例如:什么是回归,什么是代价函数,什么是梯度下降法. 本节我们讲一下多变量线性回归.依然拿房价来举例,现在我们对房价模型增加更多的特征,例如房间数 ...