简单明朗的 RNN 写诗教程

本来想做一个标题党的,取了一个史上最简单的 RNN 写诗教程这标题,但是后来想了想,这TM不就是标题党吗?怎么活成了自己最讨厌的模样?后来就改成了这个标题。

在上篇博客网络流量预测入门(一)之RNN 介绍中,介绍了RNN的原理,而在这篇博客中,将介绍如何使用keras构建RNN,然后自动写诗。

项目地址:Github:https://github.com/xiaohuiduan/rnn_chinese_poetry

数据集介绍

既然是写诗,当然得有数据集,不过还好有大神已经将数据集准备好了,具体数据集的来源已不可知,因为网上基本上都是使用这个数据集。(如果有人知道,可以在评论区指出,然后我再添加上)

数据集地址:Github,数据集部分数据如下所示:

在数据集中,每一行都是一首唐诗,其中,诗的题目和内容以 ":" 分开,每一首诗都有题目,但是不一定有内容(也就是说内容可能为空)。其中,诗内容中的标点符号都是全角符号。有一些诗五言诗,不过也有一些诗不是五言的。当然,我们只考虑五言诗(大概有27k首)。

代码思路

输入 and 输出

首先我们得先弄清我们要干什么,然后才能更好得写代码。如标题所示,目的是使用RNN写诗,那么必然有输入和输出。那么问题来了,RNN的输入是什么,输出是什么?

我们希望rnn能够写诗,那么怎么写呢?我们这样定义如下的方式:

![](imgs/rnn_io (1).svg)

RNN接受 6个字符(5个字+一个标点符号),然后输出下一个字符。至于怎么生成一首完整的诗词,等到后面讨论。

RNN当然不能够直接接受 "床前明月光," 这个中文的输入,我们要对其进行 Encode,变成数字,然后才能够输入到RNN网络中。同理,RNN输出的肯定也不是一个中文字符,我们也要对其进行Decode 才能将输出变成一个中文字符。

怎么进行Encode,有一个很简单的方法,那就是进行one-hot编码,对于每一个字(包括标点符号在内)我们都进行onehot编码,这样就可以了。但实际上,这个这样会有一点小问题。在数据集中,所有符合条件的诗,大概由近 7,000 个字符组成,如果对每一个字都进行onehot编码的话,就会消耗大量的内存,同时也会加大计算的复杂度。

因此,我们定义如下:只对前出现频率最多的 2999 个字符进行 one-hot 编码,对于剩下的字,用 “ ”(空格字符)代替。这样一共只需要对3000个字符进行one-hot编码就了(2999个字符+一个空格字符)。

训练集构建

在前面我们定义了RNN的输入和输出,同时也有诗的数据集, 那么我们构建训练集呢?参考RNN模型与NLP应用(6/9):Text Generation (自动文本生成)

具体步骤如下图所示:我们将一句诗可以进行如下切分。然后将切分得到的数据进行one-hot编码,然后进行训练即可。(这样看来,每一首诗可以生成很多的数据集)

![](imgs/RNN_split_data (1).svg)

生成一首完整的诗

前面我们讨论了关于网络的输入和输出,以及数据集的构建,那么,假如我们有一个已经训练好的模型,如何来产生一首诗的?

生成一首完整的诗的流程如下所示,与训练的操作有点类似,只不过会将RNN的输出重新当作RNN的输入。(以此来产生符合字数要求的诗)

经过上述的操作,大家实际上可以尝试的写一些代码了,基本上不会有很大的问题。接下来,我将讲一讲具体怎么实现。

代码实现

首先定义一些配置:

  • DISALLOWED_WORDS:如果在诗中出现了DISALLOWED_WORDS,则舍弃这首诗。
# 诗data的地址
poetry_data_path = "./data/poetry.txt"
# 如果诗词中出现这些词,则将诗舍弃
DISALLOWED_WORDS = ['(', ')', '(', ')', '__', '《', '》', '【', '】', '[', ']']
# 取3000个字作诗,其中包括空格字符
WORD_NUM = 3000
# 将出现少的字使用空格代替
UNKONW_CHAR = " "
# 根据前6个字预测下一个字,比如说根据“寒随穷律变,”预测“春”
TRAIN_NUM = 6

读取文件

针对于数据集,我们有如下的要求:

  • 必须是五言诗(不过下面的代码无法完全保证是五言诗),同时至少要有两句诗
  • 不能出现上文中定义的DISALLOWED_WORDS

前面我们说了,每一首诗必有题目和内容(内容可以为空),其中,题目和内容以 ":"(半角)分开,因此,我们可以通过 line.split(":")[1]获得诗的内容。

下述代码实现了两个功能:

  1. 获得符合要求的诗:(len(poetry)-1) % 6,每一首五言诗,包括“,。”一共有\(6*n\) 个字,同时每一首诗是以 "\n" 结尾的,因为我们(len(poetry)-1)%6==0则就代表符合要求。同时五言诗的第6个字符是","——> 使用poetrys保存。
  2. 获得诗中出现的字符。——>使用all_word保存。
# 保存诗词
poetrys = []
# 保存在诗词中出现的字
all_word = [] with open(poetry_data_path,encoding="utf-8") as f:
for line in f:
# 获得诗的内容
poetry = line.split(":")[1].replace(" ","")
flag = True
# 如果在句子中出现'(', ')', '(', ')', '__', '《', '》', '【', '】', '[', ']'则舍弃
for dis_word in DISALLOWED_WORDS:
if dis_word in poetry:
flag = False
break # 只需要5言的诗(两句诗包括标点符号就是12个字),假如少于两句诗则舍弃
if len(poetry) < 12 or poetry[5] != ',' or (len(poetry)-1) % 6 != 0:
flag = False if flag:
# 统计出现的词
for word in poetry:
all_word.append(word)
poetrys.append(poetry)

统计字数

前面我们说过,在数据集中,所有符合条件的诗,大概由近 7,000 个字组成,如果对每一个字都进行one-hot编码的话,就会浪费大量的内存,加大计算的复杂度。解决方法可以这样做:

使用Counter对字数进行统计,然后根据出现的次数进行排序,最后得到出现频率最多的2999个字。

from collections import Counter
# 对字数进行统计
counter = Counter(all_word)
# 根据出现的次数,进行从大到小的排序
word_count = sorted(counter.items(),key=lambda x : -x[1])
most_num_word,_ = zip(*word_count)
# 取前2999个字,然后在最后加上" "
use_words = most_num_word[:WORD_NUM - 1] + (UNKONW_CHAR,)

构建word 与 id的映射

我们需要对word进行onehot编码,怎么编呢?很简单,每一个word对应一个id,然后对这个id进行one-hot编码就行了。因此我们需要构建word到id的映射。

举个例子:如果一共只有3个字“唐”,“宋”,“明”,然后我们可以构建如下的映射:

"唐" ——> 0 ;"宋"——>1;"明"——>2;进行one-hot编码后,则就变成了:

  • 唐:[1,0,0]
  • 宋:[0,1,0]
  • 明:[0,0,1]

构建word与id的映射是必须的,经过如下简单的代码,便构成了映射。

# word 到 id的映射 {',': 0,'。': 1,'\n': 2,'不': 3,'人': 4,'山': 5,……}
word_id_dict = {word:index for index,word in enumerate(use_words)} # id 到 word的映射 {0: ',',1: '。',2: '\n',3: '不',4: '人',5: '山',……}
id_word_dict = {index:word for index,word in enumerate(use_words)}

转成one-hot代码

下面定义两个函数:

  • word_to_one_hot将一个字转成one-hot 形式

  • phrase_to_one_hot 将一个句子转成one-hot形式

import numpy as np
def word_to_one_hot(word):
"""将一个字转成onehot形式 :param word: [一个字]
:type word: [str]
"""
one_hot_word = np.zeros(WORD_NUM)
# 假如字是生僻字,则变成空格
if word not in word_id_dict.keys():
word = UNKONW_CHAR
index = word_id_dict[word]
one_hot_word[index] = 1
return one_hot_word def phrase_to_one_hot(phrase):
"""将一个句子转成onehot :param phrase: [一个句子]
:type poetry: [str]
"""
one_hot_phrase = []
for word in phrase:
one_hot_phrase.append(word_to_one_hot(word))
return one_hot_phrase

随机打乱数据

np.random.shuffle(poetrys)

构建训练集

然后我们需要进行如下操作,根据诗构建数据集(one-hot编码之前的数据集)。

![](imgs/RNN_split_data (1).svg)

构建数据集的时候我们需要注意一件事情,需要区分不同的诗(因为我们总不可能用A的诗去预测B的诗噻,hhh)。每一首诗都是以 "\n" 结尾的,因此,当循环到"\n"时,就代表对于这首诗,我们已经构建好数据集了(上图中的X_Data【用X_train_word表示】,Y_Data【用Y_train_word表示】)。

X_train_word = []
Y_train_word = [] for poetry in poetrys:
for i in range(len(poetry)):
X = poetry[i:i+TRAIN_NUM]
Y = poetry[i+TRAIN_NUM]
if "\n" not in X and "\n" not in Y:
X_train_word.append(X)
Y_train_word.append(Y)
else:
break

在没有打乱顺序的情况下,部分结果如下所示:

构建模型

使用的框架:

  • keras:2.4.3:如果想使用我训练好的模型,请保持版本一致。如果自己训练的话,就无所谓了。

模型图如下所示,模型结构参考Poems_generator_Keras,关于SimpleRNN的介绍可以参考Keras-SimpleRNN,关于如何使用keras构建神经网络可以参考数据挖掘入门系列教程(十一)之keras入门使用以及构建DNN网络识别MNIST

在前面说了,RNN模型输入的是一个 6个字符 的句子,因此经过one-hot编码后就会变成shape为(6,3000)的数组,而输出为一个字符,对应one-hot编码的shape为(3000)。

代码如下所示:

import keras
from keras.callbacks import LambdaCallback,ModelCheckpoint
from keras.models import Input, Model
from keras.layers import Dropout, Dense,SimpleRNN
from keras.optimizers import Adam
from keras.utils import plot_model def build_model():
print('building model')
# 输入的dimension
input_tensor = Input(shape=(TRAIN_NUM,WORD_NUM))
rnn = SimpleRNN(512,return_sequences=True)(input_tensor)
dropout = Dropout(0.6)(rnn) rnn = SimpleRNN(256)(dropout)
dropout = Dropout(0.6)(rnn)
dense = Dense(WORD_NUM, activation='softmax')(dropout) model = Model(inputs=input_tensor, outputs=dense)
optimizer = Adam(lr=0.001)
model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
model.summary()
# 画出模型图
# plot_model(model, to_file='model.png', show_shapes=True, expand_nested=True, dpi=500)
return model

对于SimpleRNN,如果return_sequences=True,则代表其返回如下:

如果return_sequences=False(默认),则代表返回如下所示:

这个模型中,套了两层RNN。

model = build_model()

批加载数据

这次数据集比较大,一共有\(1559196\)份数据,一般来说没有这么大内存将其所有的数据一次性全部转成one-hot形式。

因此,我们可以这样做:在训练的时候才开始加载数据,每一次只需要加载batch_size的数据,然后只需要将batch_size 大小的数据转成one-hot形式,然后进行训练。在这种情况下,只需要将batch-size的数据转成one-hot,可以大大减小内存消耗。

so,使用keras训练的时候,不能使用fit(因为fit需要一次将数据集全部放入RAM中),而应该使用fit_generator,关于其使用推荐看看:Keras 如何使用fit和fit_generator

import math
def get_batch(batch_size = 32):
"""源源不断产生产生one-hot编码的训练数据 :param batch_size: [一次产生训练数据的大小], defaults to 32
:type batch_size: int, optional
:yield: [返回X(np.array(X_train_batch))和Y(np.array(Y_train_batch))]
:rtype: [X.shape为(batch_size, 6, 3000) , Y.shape数据的shape(batch_size, 3000)]
"""
# 确定每轮有多少个batch
steps = math.ceil(len(X_train_word) / batch_size)
while True:
for i in range(steps):
X_train_batch = []
Y_train_batch = []
X_batch_datas = X_train_word[i*batch_size:(i+1)*batch_size]
Y_batch_datas = Y_train_word[i*batch_size:(i+1)*batch_size] for x,y in zip(X_batch_datas,Y_batch_datas):
X_train_batch.append(phrase_to_one_hot(x))
Y_train_batch.append(word_to_one_hot(y))
yield np.array(X_train_batch),np.array(Y_train_batch)

训练的过程中生成诗句

在训练的过程中,可以每经过一定数量的epoch生成一首诗,生成诗的操作如下:

在训练的过程中,调用generate_sample_result,即可产生五言诗,然后将生成的诗写入到out/out.txt中。

def predict_next(x):
""" 根据X预测下一个字符 :param x: [输入数据]
:type x: [x的shape为(1,TRAIN_NUM,WORD_NUM)]
:return: [最大概率字符的索引,有可能为为2999,也就是预测的字符可能为“ ”]
:rtype: [int]
"""
predict_y = model.predict(x)[0]
# 获得最大概率的索引
index = np.argmax(predict_y)
return index def generate_sample_result(epoch, logs):
"""生成五言诗 :param epoch: [目前模型训练的epoch]
:type epoch: [int]
:param logs: [模型训练日志]
:type logs: [list]
"""
# 每个epoch都产生输出
if epoch % 1 == 0:
# 根据“一朝春夏改,”生成诗
predict_sen = "一朝春夏改,"
predict_data = predict_sen
# 生成的4句五言诗(4 * 6 = 24)
while len(predict_sen) < 24:
X_data = np.array(phrase_to_one_hot(predict_data)).reshape(1,TRAIN_NUM,WORD_NUM)
# 根据6个字符预测下一个字符
y = predict_next(X_data)
predict_sen = predict_sen+ id_word_dict[y]
# “寒随穷律变,” ——> “随穷律变,春”
predict_data = predict_data[1:]+id_word_dict[y]
# 将数据写入文件
with open('out/out.txt', 'a',encoding='utf-8') as f:
f.write(write_data+'\n')

开始训练

在训练的时候,每隔一个epoch,都会将模型进行保存,每个epoch完成的时候,都会调用generate_sample_result生成诗。

batch_size = 2048
model.fit_generator(
generator=get_batch(batch_size),
verbose=True,
steps_per_epoch=math.ceil(len(X_train_word) / batch_size),
epochs=1000000,
callbacks=[
ModelCheckpoint("poetry_model.hdf5",verbose=1,monitor='val_loss',period=1),
# 每次完成一个epoch会调用generate_sample_result产生五言诗
LambdaCallback(on_epoch_end=generate_sample_result)
]
)

因为我的电脑就是一个mx250小水管,我就放在kaggle上面跑了,毕竟白嫖它不香吗?如果实在想自己跑,但是有没有比较好的GPU,可以尝试将len(X_train_word)改成其他的数,比如说“100000”。要在如下的两个地方改,这样的话,很快就可以出训练的结果。(这样会导致训练的时候无法覆盖整个数据集。)

诗词生成

我在Github中提供了训练好的模型(注意keras版本是2.4.3),在 test.ipynb 中提供了如何加载模型然后生成诗句的方法,在这里就不赘述了。

最后简单的展示一下生成的结果(实际上模型训练的效果并不是很好,

简单明朗的 RNN 写诗教程的更多相关文章

  1. 深度学习(三)之LSTM写诗

    目录 数据预处理 构建数据集 模型结构 生成诗 根据上文生成诗 生成藏头诗 参考 根据前文生成诗: 机器学习业,圣贤不可求.临戎辞蜀计,忠信尽封疆.天子咨两相,建章应四方.自疑非俗态,谁复念鹪鹩. 生 ...

  2. idea搭建简单ssm框架的最详细教程(新)

    为开发一个测试程序,特搭建一个简单的ssm框架,因为网上看到很多都是比较老旧的教程,很多包都不能用了,eclipes搭建并且其中还附带了很多的其他东西,所以特此记录一下mac中idea搭建过程. 另: ...

  3. 最简单的SQLserver,发布订阅教程,保证一次就成功

    最简单的SQLserver,发布订阅教程,保证一次就成功 发布订阅用来做数据库的读写分离,还是很好用的 当单台数据库的压力太大时,可以考虑这种方案,一主多从,主服务器的数据库只管写入,其他的数据库都是 ...

  4. 为你写诗:3 步搭建 Serverless AI 应用

    作者 | 杜万(倚贤) 阿里巴巴技术专家 本文整理自 1 月 2 日社群分享,每月 2 场高质量分享,点击加入社群. 关注"阿里巴巴云原生"公众号,回复关键词 0102 即可下载本 ...

  5. 【阿里云产品公测】简单日志服务SLS使用评测 + 教程

    [阿里云产品公测]简单日志服务SLS使用评测 + 教程 评测介绍 被测产品: 简单日志服务SLS 评测环境: 阿里云基础ECS x2(1核, 512M, 1M) 操作系统: CentOS 6.5 x6 ...

  6. js写插件教程深入

    原文地址:https://github.com/lianxiaozhuang/blog 转载请注明出处 js 写插件教程深入 1.介绍具有安全作用域的构造函数 function Fn(name){ t ...

  7. Qt侠:像写诗一样写代码,玩游戏一样的开心心情,还能领工资!

    [软]上海-Qt侠 2017/7/12 16:11:20我完全是兴趣主导,老板不给我钱,我也要写好代码!白天干,晚上干,周一周五干,周末继续干!编程已经深入我的基因,深入我的骨髓,深入我的灵魂!当我解 ...

  8. AI:为你写诗,为你做不可能的事

    最近,一档全程高能的神仙节目,高调地杀入了我们的视野: 没错,就是撒贝宁主持,董卿.康辉等央视名嘴作为评审嘉宾,同时集齐央视"三大名嘴"同台的央视<主持人大赛>,这够不 ...

  9. 最新最最最简单的Snagit傻瓜式破解教程(带下载地址)

    最新最最最简单的Snagit傻瓜式破解教程(带下载地址) 下载地址 直接滑至文章底部下载 软件介绍 一个非常著名的优秀屏幕.文本和视频捕获.编辑与转换软件.可以捕获Windows屏幕.DOS屏幕:RM ...

随机推荐

  1. DVWA-文件包含-目录遍历学习笔记

    参考文献资料: https://www.cnblogs.com/s0ky1xd/p/5823685.html https://www.cnblogs.com/yuzly/p/10799486.html ...

  2. RabbitMq基本概念理解

    RabbitMQ的基本概念 RabbitMQ github项目地址 RabbitMQ 2007年发布,是一个在AMQP(高级消息队列协议)基础上完成的,可复用的企业消息系统,是当前最主流的 消息中间件 ...

  3. .NET5 API 网关Ocelot+Consul服务注册

    1|0网关介绍 网关其实就是将我们写好的API全部放在一个统一的地址暴露在公网,提供访问的一个入口.在 .NET Core下可以使用Ocelot来帮助我们很方便的接入API 网关.与之类似的库还有Pr ...

  4. Git的使用以及命令

    个人常用命令 git初始化操作 git init 把当前的目录变成git仓库,生成隐藏.git文件. git remote add origin url 把本地仓库的内容推送到GitHub仓库. gi ...

  5. ASP.NET网站部署到服务器IIS上和本地局域网服务器

    控制面板>>>管理工具>>>打开Internet信息服务 2,如果找不到 可以控制面板>>>程序和功能>>>  打开或关闭win ...

  6. (五)cp命令复制文件或者目录

    一.cp的含义.功能及命令格式 cp(英文copy的缩写)命令可以将一个文件或者目录从一个位置复制到另外一个位置.cp的功能就是将一个文件复制成 一个指定的目的文件或者复制到一个指定的目录中,兼具复制 ...

  7. 2.自定义view-QQ运动步数

    1.效果 2.实现 2.1自定义属性 在res/values 文件夹中新建xx.xml,内容如下 <?xml version="1.0" encoding="utf ...

  8. [LeetCode]547. Friend Circles朋友圈数量--不相邻子图问题

    /* 思路就是遍历所有人,对于每一个人,寻找他的好友,找到好友后再找这个好友的好友 ,这样深度优先遍历下去,设置一个flag记录是否已经遍历了这个人. 其实dfs真正有用的是flag这个变量,因为如果 ...

  9. [LeetCode]501. Find Mode in Binary Search Tree二叉搜索树寻找众数

    这次是二叉搜索树的遍历 感觉只要和二叉搜索树的题目,都要用到一个重要性质: 中序遍历二叉搜索树的结果是一个递增序列: 而且要注意,在递归遍历树的时候,有些参数如果是要随递归不断更新(也就是如果递归返回 ...

  10. [Machine Learning] 多变量线性回归(Linear Regression with Multiple Variable)-特征缩放-正规方程

    我们从上一篇博客中知道了关于单变量线性回归的相关问题,例如:什么是回归,什么是代价函数,什么是梯度下降法. 本节我们讲一下多变量线性回归.依然拿房价来举例,现在我们对房价模型增加更多的特征,例如房间数 ...