深度学习(三)之LSTM写诗

根据前文生成诗:
机器学习业,圣贤不可求。临戎辞蜀计,忠信尽封疆。天子咨两相,建章应四方。自疑非俗态,谁复念鹪鹩。
生成藏头诗:
国步平生不愿君,古人今在古人风。
科公既得忘机者,白首空山道姓名。
大道不应无散处,未曾进退却还征。
环境:
- python:3.9.7
- pytorch:1.11.0
- numpy:1.21.2
代码地址:https://github.com/xiaohuiduan/deeplearning-study/tree/main/写诗
数据预处理
数据集文件由3部分组成:ix2word,word2ix,data:
- ix2word:id到word的映射,如{23:'姑'},一共有8293个word。
- word2ix2:word到id的映射,如{'姑':23}
- data:保存了诗的数据,一共有57580首诗,每条数据由125个word构成;如果诗的长度大于125则截断,如果诗的长度小于125,则使用""进行填充。
每条数据的构成规则为:</s></s></s>\(\dots\)<START>诗词<EOP>。

在训练的时候,不考虑填充数据,因此,将数据中的填充数据</s>去除,去除后,部分数据显示如下:

构建数据集
模型输入输出决定了数据集怎么构建,下图是模型的输入输出示意图。诗词生成实际上是一个语言模型,我们希望Model能够根据当前输入\(x_0,x_1,x_2\dots x_{n-1}\)去预测下一个状态\(x_n\)。如图中所示例子,则是希望在训练的过程中,模型能够根据输入<START>床前明月光生成床前明月光,。

因此根据“<START>床前明月光,凝是地上霜。举头望明月,低头思故乡<EOP>”,可以生成如下的X和Y(seq_len=6)。
X:<START>床前明月光,Y:床前明月光,
X:,凝是地上霜,Y:凝是地上霜。
X:。举头望明月,Y:举头望明月,
X:,低头思故乡,Y:低头思故乡。
代码示意图如下所示,seq_len代表每条训练数据的长度。
seq_len = 48
X = []
Y = []
poems_data = [j for i in poems for j in i] # 将所有诗的内容变成一个一维数组
for i in range(0,len(poems_data) - seq_len -1,seq_len):
X.append(poems_data[i:i+seq_len])
Y.append(poems_data[i+1:i+seq_len+1])
模型结构
模型结构如下所示,模型一共由3部分构成,Embedding层,LSTM层和全连接层。输入数据首先输入Embedding层,进行word2vec,然后将Word2Vec后的数据输入到LSTM中,最后将LSTM的输出输入到全连接层中得到预测结果。

模型构建代码如下,其中在本文中embedding_dim=200,hidden_dim=1024。
import torch
import torch.nn.functional as F
import torch.nn as nn
class PoemNet(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim):
"""
vocab_size:训练集合字典大小(8293)
embedding_dim:word2vec的维度
hidden_dim:LSTM的hidden_dim
"""
super(PoemNet, self).__init__()
self.hidden_dim = hidden_dim
self.embeddings = nn.Embedding(vocab_size, embedding_dim)
self.lstm = nn.LSTM(embedding_dim, self.hidden_dim,batch_first=True)
self.fc = nn.Sequential(
nn.Linear(self.hidden_dim,2048),
nn.ReLU(),
nn.Dropout(0.25),
nn.Linear(2048,4096),
nn.Dropout(0.2),
nn.ReLU(),
nn.Linear(4096,vocab_size),
)
def forward(self, input,hidden=None):
"""
input:输入的诗词
hidden:在生成诗词的时候需要使用,在pytorch中,如果不指定初始状态h_0和C_0,则其
默认为0.
pytorch的LSTM的输出是(output,(h_n,c_n))。实际上,output就是h_1,h_2,……h_n
"""
embeds = self.embeddings(input)
batch_size, seq_len = input.size()
if hidden is None:
output, hidden = self.lstm(embeds)
else:
# h_0,c_0 = hidden
output, hidden = self.lstm(embeds,hidden)
output = self.fc(output)
output = output.reshape(batch_size * seq_len, -1)
output = F.log_softmax(output,dim=1)
return output,hidden
优化器使用的是Adam优化器,lr=0.001,损失函数是CrossEntropyLoss。训练次数为100个epcoh。
生成诗
因为在模型构建的过程中,使用了dropout,所以在模型使用的时候,需要将model设置为eval模式。
生成诗的逻辑图:

根据上文生成诗
根据上图的原理,写出的代码如下所示:
def generate_poem(my_words,max_len=128):
'''
根据前文my_words生成一首诗。max_len表示生成诗的最大长度。
'''
def __generate_next(idx,hidden=None):
"""
根据input和hidden输出下一个预测
"""
input = torch.Tensor([idx]).view(1,1).long().to(device)
output,hidden = my_net(input,hidden)
return output,hidden
# 初始化hidden状态
output,hidden = __generate_next(word2ix["<START>"])
my_words_len = len(my_words)
result = []
for word in my_words:
result.append(word)
# 积累hidden状态(h & c)
output,hidden = __generate_next(word2ix[word],hidden)
_,top_index = torch.max(output,1)
word = idx2word[top_index[0].item()]
result.append(word)
for i in range(max_len-my_words_len):
output,hidden = __generate_next(top_index[0].item(),hidden)
_,top_index = torch.max(output,1)
if top_index[0].item() == word2ix['<EOP>']: # 如果诗词已经预测到结尾
break
word = idx2word[top_index[0].item()]
result.append(word)
return "".join(result)
generate_poem("睡觉")
睡觉寒炉火,晨钟坐中朝。炉烟沾煖露,池月静清砧。自有传心法,曾无住处传。不知尘世隔,一觉一壺秋。皎洁垂银液,浮航入绿醪。谁知旧邻里,相对似相亲。
生成藏头诗
生成藏头诗的方法与根据上文生成诗的方法大同小异。
def acrostic_poetry(my_words):
def __generate_next(idx,hidden=None):
"""
根据input和hidden输出下一个预测词
"""
input = torch.Tensor([idx]).view(1,1).long().to(device)
output,hidden = my_net(input,hidden)
return output,hidden
def __generate(word,hidden):
"""
根据word生成一句诗(以“。”结尾的话) 如根据床生成“床前明月光,凝是地上霜。”
"""
generate_word = word2ix[word]
sentence = []
sentence.append(word)
while generate_word != word2ix["。"]:
output,hidden = __generate_next(generate_word,hidden)
_,top_index = torch.max(output,1)
generate_word = top_index[0].item()
sentence.append(idx2word[generate_word])
# 根据"。"生成下一个隐状态。
_,hidden = __generate_next(generate_word,hidden)
return sentence,hidden
_,hidden = __generate_next(word2ix["<START>"])
result = []
for word in my_words:
sentence,hidden = __generate(word,hidden)
result.append("".join(sentence))
print("\n".join(result))
acrostic_poetry("滚去读书")
滚发初生光,三乘如太白。 去去冥冥没,冥茫寄天海。 读书三十年,手把棼琴策。 书罢华省郎,忧人惜凋病。
参考
- 简单明朗的 RNN 写诗教程 - 段小辉 - 博客园 (cnblogs.com)
- LSTM — PyTorch 1.11.0 documentation
- Embedding — PyTorch 1.11.0 documentation
深度学习(三)之LSTM写诗的更多相关文章
- 深度学习之 mnist 手写数字识别
深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...
- [ZZ] 深度学习三巨头之一来清华演讲了,你只需要知道这7点
深度学习三巨头之一来清华演讲了,你只需要知道这7点 http://wemedia.ifeng.com/10939074/wemedia.shtml Yann LeCun还提到了一项FAIR开发的,用于 ...
- 时间序列深度学习:状态 LSTM 模型预测太阳黑子
目录 时间序列深度学习:状态 LSTM 模型预测太阳黑子 教程概览 商业应用 长短期记忆(LSTM)模型 太阳黑子数据集 构建 LSTM 模型预测太阳黑子 1 若干相关包 2 数据 3 探索性数据分析 ...
- 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
- 时间序列深度学习:状态 LSTM 模型预測太阳黑子(一)
版权声明:本文为博主原创文章,未经博主同意不得转载. https://blog.csdn.net/kMD8d5R/article/details/82111558 作者:徐瑞龙,量化分析师,R语言中文 ...
- 深度学习--RNN,LSTM
一.RNN 1.定义 递归神经网络(RNN)是两种人工神经网络的总称.一种是时间递归神经网络(recurrent neural network),另一种是结构递归神经网络(recursive neur ...
- 深度学习 循环神经网络 LSTM 示例
最近在网上找到了一个使用LSTM 网络解决 世界银行中各国 GDP预测的一个问题,感觉比较实用,毕竟这是找到的唯一一个可以正确运行的程序. #encoding:UTF-8 import pandas ...
- pytorch深度学习神经网络实现手写字体识别
利用平pytorch搭建简单的神经网络实现minist手写字体的识别,采用三层线性函数迭代运算,使得其具备一定的非线性转化与运算能力,其数学原理如下: 其具体实现代码如下所示:import torch ...
- go微服务框架go-micro深度学习(三) Registry服务的注册和发现
服务的注册与发现是微服务必不可少的功能,这样系统才能有更高的性能,更高的可用性.go-micro框架的服务发现有自己能用的接口Registry.只要实现这个接口就可以定制自己的服务注册和发现. go- ...
随机推荐
- MySQL优化之索引解析
索引的本质 MySQL索引或者说其他关系型数据库的索引的本质就只有一句话,以空间换时间. 索引的作用 索引关系型数据库为了加速对表中行数据检索的(磁盘存储的)数据结构 索引的分类 数据结构上面的分类 ...
- LGP6276题解
众所周知,排列是一个置换,一个置换是一车环. 步数就是这些环长的 \(lcm\). 如果你去思考直接 DP,会发现很困难,根本设不出来状态.于是考虑正难则反:每个质数幂 \(p^k\) 对答案的贡献. ...
- BSOJ6388题解
看上去就很神秘...考虑建出图论模型. 我们将一张牌的两面 \(a,b\) 连一条边. 考虑一个连通块的意义是什么. 边是一张牌,容易发现,如果连通块是一棵树,那么选择一个根节点相当于可以打出除了根节 ...
- [NOIP2013 普及组] 表达式求值
[NOIP2013 普及组] 表达式求值 给定一个只包含加法和乘法的算术表达式,请你编程计算表达式的值. Input 一行,为需要你计算的表达式,表达式中只包含数字.加法运算符"+" ...
- web自动化之svg标签定位
今天在定位元素的时候,发现页面有一个svg标签需要进行定位. 于是便使用常规的xpath定位方法试了一下,很明显结果是不行的,哈哈哈... 错误定位方法://div[@class="oper ...
- java对配置文件properties的操作
1.读取配置文件的键值对,转为Properties对象:将Properties(键值对)对象写入到指定文件. package com.ricoh.rapp.ezcx.admintoolweb.util ...
- Python knife 一款伪菜刀
Python knife 一款伪菜刀. 设计之初,本想只写个命令行的就可以了,但又想与众不同,想用python写代码,又不想用c#写前端(c#太卡了),万分无奈之下,找到一个替代品,Pyqt, ...
- Python编写简易木马程序(转载乌云)
Python编写简易木马程序 light · 2015/01/26 10:07 0x00 准备 文章内容仅供学习研究.切勿用于非法用途! 这次我们使用Python编写一个具有键盘记录.截屏以及通信功能 ...
- Java基础(中)
面向对象基础 面向对象和面向过程的区别 两者的主要区别在于解决问题的方式不同: 面向过程把解决问题的过程拆成一个个方法,通过一个个方法的执行解决问题. 面向对象会先抽象出对象,然后用对象执行方法的方式 ...
- Lua协程的一个例子
很久没记录笔记了,还是养成不了记录的习惯 下面是来自 programming in lua的一个协程的例(生产者与用户的例子) 帖代码,慢慢理解 -- Programming in Lua Corou ...