承前

接上节代码『TensotFlow』RNN中文文本_上

import numpy as np
import tensorflow as tf
from collections import Counter poetry_file = 'poetry.txt' poetrys = []
with open(poetry_file, 'r', encoding='utf-8') as f:
for line in f:
try:
title, content = line.strip().split(':')
content = content.replace(' ','') # 去空格,实际上没用到
if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content:
continue
if len(content) < 5 or len(content) > 79:
continue
content = '[' + content + ']'
poetrys.append(content)
except Exception as e:
pass # 依照每个元素的长度排序
poetrys = sorted(poetrys, key=lambda poetry: len(poetry))
print('唐诗数量:', len(poetrys)) # 统计字出现次数
all_words = []
for poetry in poetrys:
all_words += [word for word in poetry]
counter = Counter(all_words)
# print(counter.items())
# item会把字典中的每一项变成一个2元素元组,字典变成大list
count_pairs = sorted(counter.items(), key=lambda x:-x[1])
# 利用zip提取,因为是原生数据结构,在切片上远不如numpy的结构灵活
words, _ = zip(*count_pairs)
# print(words) words = words[:len(words)] + (' ',) # 后面要用' '来补齐诗句长度
# print(words)
# 转换为字典
word_num_map = dict(zip(words, range(len(words))))
# 把诗词转换为向量
to_num = lambda word: word_num_map.get(word, len(words))
poetry_vector = [list(map(to_num, poetry)) for poetry in poetrys] batch_size = 1
n_chunk = len(poetry_vector) // batch_size
x_batches = []
y_batches = []
for i in range(n_chunk):
start_index = i*batch_size
end_index = start_index + batch_size
batches = poetry_vector[start_index:end_index]
length = max(map(len, batches)) # 记录下最长的诗句的长度
xdata = np.full((batch_size, length), word_num_map[' '], np.int32)
for row in range(batch_size):
xdata[row,:len(batches[row])] = batches[row]
# print(len(xdata[0])) 每个batch中数据长度不相等
ydata = np.copy(xdata)
ydata[:,:-1] = xdata[:,1:]
"""
xdata ydata
[6,2,4,6,9] [2,4,6,9,9]
[1,4,2,8,5] [4,2,8,5,5]
"""
x_batches.append(xdata) # (n_chunk, batch, length)
y_batches.append(ydata)  

这里将数据预处理为3维的数据结构,每次输入后两维度,并将最后的每一个数字映射为一个数组,这是承袭上节的数据处理逻辑结构。

然后我们来看RNN部分。

启后

input_data = tf.placeholder(tf.int32, [batch_size, None])
output_targets = tf.placeholder(tf.int32, [batch_size, None])

下面是RNN网络主体,为了深化对与数据在RNN中流动的理解,我把中间的数据维度进行了输出,注意,

  • 由于我在测试时batch_size设定为1,所以下面的1表示的是batch_size
  • RNN网络的特性决定了不同的batch之间的time_steps可以不相等,但同一个batch中的必须相等,所以输出?,对应的placeholder中的相应维度输入None
  • 这里对应的输入数据维度input_data:(1,?),output_target:(1,?)
# 单层RNN
def neural_network(model='lstm',rnn_size=128,num_layers=2): cell = tf.contrib.rnn.BasicLSTMCell(rnn_size,state_is_tuple=True)
# cell = tf.contrib.rnn.MultiRNNCell([cell for _ in range(num_layers)]) initial_state = cell.zero_state(batch_size,tf.float32) with tf.variable_scope('rnnlm'):
softmax_w = tf.get_variable("softmax_w",[rnn_size,len(words) + 1])
# print(softmax_w) # 128,6111
softmax_b = tf.get_variable("softmax_b",[len(words) + 1])
with tf.device("/cpu:0"):
embedding = tf.get_variable("embedding",[len(words) + 1,rnn_size])
inputs = tf.nn.embedding_lookup(embedding,input_data) # print(input_data) # 1,?
# print(inputs) # 1,?,128 outputs,last_state = tf.nn.dynamic_rnn(cell,inputs,initial_state=initial_state,scope='rnnlm')
output = tf.reshape(outputs,[-1,rnn_size])
# print(outputs) # 1,?,128
# print(output) # ?,128 # ?,128 * 128,6111 -> ?,6111
logits = tf.matmul(output,softmax_w) + softmax_b
probs = tf.nn.softmax(logits)
return logits,last_state

  

训练部分相关函数在『TensorFlow』梯度优化相关中均有介绍,当然这里采用了比较麻烦的做法... ...练习么,

def train_neural_network():
logits,last_state,_,_,_ = neural_network()
targets = tf.reshape(output_targets,[-1])
loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example([logits],
[targets],
[tf.ones_like(targets,dtype=tf.float32)])
cost = tf.reduce_mean(loss)
learning_rate = tf.Variable(0.0,trainable=False)
tvars = tf.trainable_variables()
grads,_ = tf.clip_by_global_norm(tf.gradients(cost,tvars),5)
optimizer = tf.train.AdamOptimizer(learning_rate)
train_op = optimizer.apply_gradients(zip(grads,tvars)) with tf.Session() as sess:
sess.run(tf.initialize_all_variables()) #saver = tf.train.Saver(tf.all_variables())
saver = tf.train.Saver() for epoch in range(50):
sess.run(tf.assign(learning_rate,0.002 * (0.97 ** epoch)))
n = 0
for batche in range(n_chunk):
train_loss,_,_ = sess.run([cost,last_state,train_op],
feed_dict={input_data: x_batches[n],output_targets: y_batches[n]})
n += 1
print(epoch,batche,train_loss)
if n % 5000 == 1:
# saver.save(sess,'poetry.module',global_step=epoch)
saver.save(sess,'./model.ckpt',global_step=epoch) # if epoch % 1 == 0:
# #saver.save(sess,'poetry.module',global_step=epoch)
# saver.save(sess,'./model/model.ckpt',global_step=epoch)
train_neural_network()

实际训练不要忘记把batch_size改大一点,只是个训练程序,没什么其他的补充了。

下面给出利用模型生成文本的部分,

思路是指定初始字符串'[',转换为向量后送入RNN,得到state和下一个字符,利用他们两个进行后续迭代,直到']'出现,生成的字符串即为结果。

def gen_poetry():
def to_word(weights):
t = np.cumsum(weights)
s = np.sum(weights)
sample = np.searchsorted(t,(np.random.rand(1) * s)[0])
return words[sample] _,last_state,probs,cell,initial_state = neural_network() with tf.Session() as sess:
sess.run(tf.initialize_all_variables()) # saver = tf.train.Saver(tf.all_variables())
# saver.restore(sess,'poetry.module-49')
ckpt = tf.train.get_checkpoint_state('./')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
saver.restore(sess,ckpt.model_checkpoint_path) state_ = sess.run(cell.zero_state(1,tf.float32)) x = np.array([list(map(word_num_map.get,'['))])
print(x)
[probs_,state_] = sess.run([probs,last_state],feed_dict={input_data: x,initial_state: state_})
print(probs_.shape)
word = to_word(probs_)
#word = words[np.argmax(probs_)]
poem = ''
while word != ']':
poem += word
x = np.zeros((1,1))
x[0,0] = word_num_map[word]
[probs_,state_] = sess.run([probs,last_state],feed_dict={input_data: x,initial_state: state_})
word = to_word(probs_)
#word = words[np.argmax(probs_)]
print(poem)
return poem print(gen_poetry())

按照正常思路应该使用word = words[np.argmax(probs_)]来还原字符,但这样不收敛(不出现’]‘导致诗句不能够正常结束),所以有上面的另一种生成字符串的方法,不过由于手头没有高性能电脑(迭代次数不够),多层RNN的尝试也不太成功(网络性能不佳),所以也不能说孰优孰劣,给出一个似乎比较成功的例子,Tensorflow-3-使用RNN生成中文小说可以作为日后继续探究的参考(多层RNN的构建&向量还原字符串的方法&更为复杂的中文文本的预处理方法)。

后记

明天就回北京了,这大概就是研究生报道(九月三号)前的最后一篇博客了,回想起4月份至今,学到了不少东西,再回想其去年7月份至今,也真的是有很多成功的尝试,当然也有不少失败的努力,不过即使这样我感觉也蛮不错了,开学后很有可能会被学了4年的计算机本专业的同学吊打,所以我要提前警戒自己的就是既然在这个领域上你能做到的不多,那么就不要贪心勉强,选定一个目标,不要迷茫,或者即使迷茫也不能驻足,前进和不断前进,既然你能做到的事不多,那就把它做得登峰造极。虽然很理想化,但我始终坚信,一个还算合理的大方向规划加上孜孜不倦的努力是可以得到好的结果的,而瓶颈就在于那个说起来容易的‘孜孜不倦’上,我不相信奇迹,或者说我不相信我有得到奇迹的好运气,那么就尽量让所谓的好结果来的水到渠成一点吧,加油,我看好你,不要让我失望。

『TensotFlow』RNN中文文本_下_暨研究生开学感想的更多相关文章

  1. 『TensotFlow』RNN中文文本_上

    中文文字预处理流程 文本处理 读取+去除特殊符号 按照字段长度排序 辅助数据结构生成 生成 {字符:出现次数} 字典 生成按出现次数排序好的字符list 生成 {字符:序号} 字典 生成序号list ...

  2. 『TensotFlow』RNN/LSTM古诗生成

    往期RNN相关工程实践文章 『TensotFlow』基础RNN网络分类问题 『TensotFlow』RNN中文文本_上 『TensotFlow』基础RNN网络回归问题 『TensotFlow』RNN中 ...

  3. 『TensorFlow』通过代码理解gan网络_中

    『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...

  4. 『cs231n』卷积神经网络工程实践技巧_下

    概述 计算加速 方法一: 由于计算机计算矩阵乘法速度非常快,所以这是一个虽然提高内存消耗但是计算速度显著上升的方法,把feature map中的感受野(包含重叠的部分,所以会加大内存消耗)和卷积核全部 ...

  5. 『cs231n』RNN之理解LSTM网络

    概述 LSTM是RNN的增强版,1.RNN能完成的工作LSTM也都能胜任且有更好的效果:2.LSTM解决了RNN梯度消失或爆炸的问题,进而可以具有比RNN更为长时的记忆能力.LSTM网络比较复杂,而恰 ...

  6. 『备注』GDI+ 绘制文本有锯齿,透明背景文本绘制

    背景: GDI+ 绘制文本 时,如果 背景是透明的 —— 则会出现 锯齿. //其实,我不用这三个 属性 好多年了 //而且,这三个属性 在关键时刻还有可能 帮倒忙 //关键是:这三个属性,鸟用都没有 ...

  7. 『TensotFlow』转置卷积

    网上解释 作者:张萌链接:https://www.zhihu.com/question/43609045/answer/120266511来源:知乎著作权归作者所有.商业转载请联系作者获得授权,非商业 ...

  8. 『cs231n』卷积神经网络工程实践技巧_上

    概述 数据增强 思路:在训练的时候引入干扰,在测试的时候避免干扰. 翻转图片增强数据. 随机裁切图片后调整大小用于训练,测试时先图像金字塔制作不同尺寸,然后对每个尺寸在固定位置裁切固定大小进入训练,最 ...

  9. 『TensorFlow』DCGAN生成动漫人物头像_下

    『TensorFlow』以GAN为例的神经网络类范式 『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 『TensorFlow』通过代码理解gan网络_中 一.计算 ...

随机推荐

  1. 【ContextLoaderListener】Web项目启动报错java.lang.ClassNotFoundException: ContextLoaderListener

    错误原因: 进入到tomcat的部署路径.metadata\.plugins\org.eclipse.wst.server.core\tmp0\wtpwebapps\下检查了一下,发现工程部署后在WE ...

  2. 光学定位点(mark点)

     Mark点是使用机器焊接时用于定位的点.  表贴元件的pcb更需要设置Mark点,因为在大批量生产时,贴片机都是操作人员手动或者机器自动寻找Mark点进行校准.极少数不设置Mark点也可以,操作非常 ...

  3. javascript创建函数的20种方式汇总

    http://www.jb51.net/article/68285.htm 工作中常常会创建一个函数来解决一些需求问题,以下是个人在工作中总结出来的创建函数20种方式,你知道多少? function ...

  4. 【Selenium2】【问题】

    [iframe 和 HTML 相互嵌套] 比如126登录页,我的几个方法都不好用 1. iframeFather = driver.find_element(By.XPATH,"//div[ ...

  5. React-navigation物理返回键提示效果BackHandler

    componentWillMount(){    BackHandler.addEventListener('hardwareBackPress', this.onBackAndroid); } co ...

  6. _event_worldstate

    EventId 事件ID ID WorldStateUI.dbc第10列数字部分 StartValue 起始值 Entry 更新世界状态需要击杀生物或摧毁物体的entry,正数为生物,负数为物体 St ...

  7. bzoj 4034: [HAOI2015]树上操作 树链剖分+线段树

    4034: [HAOI2015]树上操作 Time Limit: 10 Sec  Memory Limit: 256 MBSubmit: 4352  Solved: 1387[Submit][Stat ...

  8. spring boot 配置双数据源mysql、sqlServer

    背景:原来一直都是使用mysql数据库,在application.properties 中配置数据库信息 spring.datasource.url=jdbc:mysql://xxxx/test sp ...

  9. python requests post和get

    import requests import time import hashlib import os import json from contextlib import closing impo ...

  10. php格式化数字输出number_format

    <?php $num = 4999.944444; $formattedNum = number_format($num).PHP_EOL; echo $formattedNum; $forma ...