RNN是一个很有意思的模型。早在20年前就有学者发现了它强大的时序记忆能力,另外学术界以证实RNN模型属于Turning-Complete,即理论上可以模拟任何函数。但实际运作上,一开始由于vanishing and exploiting gradient问题导致BPTT算法学习不了长期记忆。虽然之后有了LSTM(长短记忆)模型对普通RNN模型的修改,但是训练上还是公认的比较困难。在Tensorflow框架里,之前的两篇博客已经就官方给出的PTB和Machine Translation模型进行了讲解,现在我们来看一看传说中的机器写诗的模型。原模型出自安德烈.卡帕西大神的char-rnn项目,意在显示RNN强大的能力以及并非那么困难的训练方法。对这个方面有兴趣的朋友请点击这里查看详情。原作的框架为Torch,点击这里查看原作代码。中山大学的zhangzibin以卡帕西大神的代码为样本制作了一款基于卡帕西RNN模型以及Samy Bengio(Bengio大神的亲弟弟)提出的Schedule Sampling算法的可运行中文的RNN模型,源代码请点击这里查看。作为Tensorflow的玩家,我本人当然很想了解下这个框架的运行情况,特别是在Tensorflow框架里的运行情况。好在有人已经捷足先登,将代码移植完毕了。今天我们就来看看这个神奇框架在Tensorflow下的代码。对该项目感兴趣的朋友可以在这里下载到项目的源码并在自己的机器上运行。

既然有了Tensorflow版本的代码了,那么我们开始解剖这段代码吧!

在解剖代码之前,让我们先对代码的运行做一个了解。在运行时,我们需要做的是cd到项目里后,先运行train.py文件来训练代码。默认的迭代数是50个迭代,默认的训练文件是tinyshakespear目录里的input.txt文件,也就是莎士比亚的一些作品。由于默认都是设定好的,我们不需要做任何更改,直接运行python train.py就好了。训练速度还是比较客观的,大约需要运行一个小时(没算具体时间),我们会发现训练完成,参数已经保存。之后,如果我们想看看运行的结果如何,打入python sample.py后,就会随机产生一段文字,该段文字是由机器学习了训练文本后自行计算的。之后我会放上机器在学习了郭敬明的幻成和小时代后自己写出的句子供大家参考。

在了解了运行方式后,既然入口文件是train.py,那么我们就先来看看该文件的设计。不出所料,train.py文件的开始为一系列的parser.add_argument。在之前的代码里我们已经多次见到,无非是加入了运行系统所需的参数,他们的默认值以及参数的解释。从这里我们发现默认的RNN框架为lstm,2层RNN结构,每层有128个神经元节点。另外,我们的sequence length定义为50,也就是每一次可以执行50个时间序列。之后便是train函数。如同往常,我们发现textloader函数为录入训练集的函数,这个函数存在于utils.py文件里。该文件很容易理解,在读入数据后通过collections.Counter收集文本中不一样的character,并将他们写入vocab_file文件做保存,已备后用。之后,根据总数据大小,minibatch大小以及时序长短来界定运行完整个文件需要多少个minibatches,并将文本分类成minibatch的训练以及目标batches。由于这个模型的目的是学习一个character后下一个character的概率,训练集跟目标函数间的差异为一个character,即在训练句子My name is Edward时,假设训练集为: My name is Edwar, 相对应的目标集为y name is Edward。 从逻辑角度上说,不管是这个util.py文件还是之前博客里的CBOW模型,他们的核心逻辑都是相似的,只是在处理上由于目标不同而产生出工程上的差异。有兴趣的朋友可以对比这个util.py文件里的逻辑和CBOW模型里读入输入的函数做对比。

之后,train.py文件对需要的目录以及文件进行确认后就是建立模型了。通过model = Model(args),我们建立了这个RNN所需要的模型。那么模型是如何建立的呢?让我们仔细来看看model.py文件。这个model.py文件里存在两个函数:init函数以及sample函数。他们分别被用来训练模型以及测试模型。让我们首先来看看模型的训练:

def __init__(self, args, infer=False):
self.args = args
# 这里的infer被默认为False,只有在测试效果
# 的时候才会被设计为True,在True的状态下
# 只有一个batch,time step也被设计为1,我们
# 可以由此观测训练成功
if infer:
args.batch_size = 1
args.seq_length = 1 # 这里是选择RNN cell的类型,备选的有lstm, gru和simple rnn
# 这里由输入的arg里的model参数作为测试标准,默认为lstm
# 但是,我们可以看到,这里通过不同的模型我们可以用不同
# 的cell。
if args.model == 'rnn':
cell_fn = rnn_cell.BasicRNNCell
elif args.model == 'gru':
cell_fn = rnn_cell.GRUCell
elif args.model == 'lstm':
cell_fn = rnn_cell.BasicLSTMCell
else:
raise Exception("model type not supported: {}".format(args.model))
# 定义cell的神经元数量,等同于cell = rnn_cell.BasicLSTMCell(args.rnn_size)
cell = cell_fn(args.rnn_size)
# 由于结构为多层结构,我们运用MultiRNNCell来定义神经元层。
self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers)
# 输入,同PTB模型,输入的格式为batch_size X sequence_length(step)
self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
self.initial_state = cell.zero_state(args.batch_size, tf.float32) with tf.variable_scope('rnnlm'):
softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])
softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
with tf.device("/cpu:0"):
# 这里运用embedding来将输入的不同词汇map到隐匿层的神经元上
embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
# 这里对input的shaping很有意思。这个地方如果我们仔细去读PTB模型就会发现在他的
# outputs = []这行附近有一段注释的文字,解释了一个alternative做法,这个做法就是那
# alternative的方法。首先,我们将embedding_loopup所得到的[batch_size, seq_length, rnn_size]
# tensor按照sequence length划分为一个list的[batch_size, 1, rnn_size]的tensor以表示每个
# 步骤的输入。之后通过squeeze把那个1维度去掉,达成一个list的[batch_size, rnn_size]
# 输入来被我们的rnn模型运用。
inputs = tf.split(1, args.seq_length, tf.nn.embedding_lookup(embedding, self.input_data))
inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
# 这里定义的loop实际在于当我们要测试运行结果,即让机器自己写文章时,我们需要对每一步
# 的输出进行查看。如果我们是在训练中,我们并不需要这个loop函数。
def loop(prev, _):
prev = tf.matmul(prev, softmax_w) + softmax_b
prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
return tf.nn.embedding_lookup(embedding, prev_symbol)
# 这里我们得益于tensorflow强大的内部函数,rnn_decoder可以作为黑盒子直接运用,省去了编写
# 的麻烦。另外,上面的loop函数只有在infer是被定为true的时候才会启动,一如我们刚刚所述。另外
# rnn_decoder在tensorflow中的建立方式是以schedule sampling算法为基础制作的,故其自身已经融入
# 了schedule sampling算法。
outputs, last_state = seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if infer else None, scope='rnnlm')
# 这里的过程可以说基本等同于PTB模型,首先通过对output的重新梳理得到一个
# [batch_size*seq_length, rnn_size]的输出,并将之放入softmax里,并通过sequence
# loss by example函数进行训练。
output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size])
self.logits = tf.matmul(output, softmax_w) + softmax_b
self.probs = tf.nn.softmax(self.logits)
loss = seq2seq.sequence_loss_by_example([self.logits],
[tf.reshape(self.targets, [-1])],
[tf.ones([args.batch_size * args.seq_length])],
args.vocab_size)
self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_length
self.final_state = last_state
self.lr = tf.Variable(0.0, trainable=False)
tvars = tf.trainable_variables()
grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),
args.grad_clip)
optimizer = tf.train.AdamOptimizer(self.lr)
self.train_op = optimizer.apply_gradients(zip(grads, tvars))

由上述代码可见,在制作RNN的模型里,不可或缺的步骤如下:

# 制作RNN模型的大概步骤:

# 1.定义cell类型以及模型框架(假设为lstm):
basic_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
cell = tf.nn.rnn_cell.MultiRNNCell([basic_cell]*number_layers) # 2.定义输入
input_data = tf.placeholder(tf.int32, [batch_size, sequence_length])
target = tf.placeholder(tf.int32, [batch_size, sequence_length]) # 3. init zero state
initial_state = cell.zero_state(batch_size, tf.float32) # 4. 整理输入,可以运用PTB的方法或上文介绍的方法,不过要注意
# 你的输入是什么形状的。最后数列要以格式[sequence_length, batch_size, rnn_size]
# 为输入才可以。 # 5. 之后为按照你的应用所需的函数运用了。这里运用的是rnn_decoder, 当然,别的可以
# 运用,比如machine translation里运用的就是embedding_attention_seq2seq # 6. 得到输出,重新编辑输出的结构后可以运用softmax,一般loss为sequence_loss_by_example # 7. 计算loss, final_state以及选用learning rate,之后用clip_by_global norm来定义gradient
# 并运用类似于adam来optimise算法。可以运用minimize或者apply_gradients来训练。

在了解了模型后,我们发现剩下的代码都是比较常见的,例如initialize_all_variables, 以及运用learning rate decay的方式训练模型。由此,train.py文件的训练过程我们已经做了一个大概的了解了。那么,系统又是如何让我们可以测试训练好的模型呢?让我们来看看sample.py文件。通过parser.add_arugment函数,我们发现文件会选取我们保存模型的地点,并会产生500字符的sample,至于sample选项,我们发现设定为0是得到最多的timestep,1是每一个timestep, 2是sample on spaces。之后,我们读取存储的模型内容后将内容传递进Model函数,并将infer设为True。在output的时候,我们运用sample函数来的到输出。在这个函数里,我们发现一般的prime开头是‘The’,这里我们可以通过sample.py里的prime函数来指定一个开头。在之后,我们发现那个sample参数设为0时选取的是argmax,1时是weighted_pick,2时以space为标准,如果有space则选择weighted_pick, 不然就是argmax。好了,实际运行的效果如何呢?让我们来几个例子看看。

这里是the为开始,我们看到了一开始有点乱码。之后,我们看到可是我知道了,他们三个女生等短句都是通顺的,同时,也有一些及其不通的,例如底忘记新地把满些了,这句话什么含义完全不清楚。再看下一个例子,如果以我开头会怎么样呢?

如果把sample设为0又会如何呢?

这里篇幅紧凑多了。再次运行,我们得到了相同的结果,因为是argmax么,所以在没改变的情况下我们会得到相同的结果。

这个运行结果还是很有意思的,有兴趣的朋友可以自行下载项目然后试着去操作一下!

character-RNN模型介绍以及代码解析的更多相关文章

  1. (zhuan) 深度学习全网最全学习资料汇总之模型介绍篇

    This blog from : http://weibo.com/ttarticle/p/show?id=2309351000224077630868614681&u=5070353058& ...

  2. Keras实现RNN模型

    博客作者:凌逆战 博客地址:https://www.cnblogs.com/LXP-Never/p/10940123.html 这篇文章主要介绍使用Keras框架来实现RNN家族模型,TensorFl ...

  3. 【论文笔记】AutoML for MCA on Mobile Devices——论文解读与代码解析

    理论部分 方法介绍 本节将详细介绍AMC的算法流程.AMC旨在自动地找出每层的冗余参数. AMC训练一个强化学习的策略,对每个卷积层会给出其action(即压缩率),然后根据压缩率进行裁枝.裁枝后,A ...

  4. Beam Search快速理解及代码解析(上)

    Beam Search 简单介绍一下在文本生成任务中常用的解码策略Beam Search(集束搜索). 生成式任务相比普通的分类.tagging等NLP任务会复杂不少.在生成的时候,模型的输出是一个时 ...

  5. Beam Search快速理解及代码解析

    目录 Beam Search快速理解及代码解析(上) Beam Search 贪心搜索 Beam Search Beam Search代码解析 准备初始输入 序列扩展 准备输出 总结 Beam Sea ...

  6. R数据分析:二分类因变量的混合效应,多水平logistics模型介绍

    今天给大家写广义混合效应模型Generalised Linear Random Intercept Model的第一部分 ,混合效应logistics回归模型,这个和线性混合效应模型一样也有好几个叫法 ...

  7. Java 集合系列05之 LinkedList详细介绍(源码解析)和使用示例

    概要  前面,我们已经学习了ArrayList,并了解了fail-fast机制.这一章我们接着学习List的实现类——LinkedList.和学习ArrayList一样,接下来呢,我们先对Linked ...

  8. Java 集合系列10之 HashMap详细介绍(源码解析)和使用示例

    概要 这一章,我们对HashMap进行学习.我们先对HashMap有个整体认识,然后再学习它的源码,最后再通过实例来学会使用HashMap.内容包括:第1部分 HashMap介绍第2部分 HashMa ...

  9. Java 集合系列11之 Hashtable详细介绍(源码解析)和使用示例

    概要 前一章,我们学习了HashMap.这一章,我们对Hashtable进行学习.我们先对Hashtable有个整体认识,然后再学习它的源码,最后再通过实例来学会使用Hashtable.第1部分 Ha ...

随机推荐

  1. php获取post参数的几种方式

    php获取post参数的几种方式 1.$_POST['paramName'] 只能接收Content-Type: application/x-www-form-urlencoded提交的数据 2.fi ...

  2. oracle查看系统资源占用情况

    1,连上服务器,使用top命令,可以查看cpu使用率以及内存的使用情况等等,还有当前各用户的使用情况 2,用pl/sql developper,tool里面选sessions,就可以看到当前sessi ...

  3. iOS百度推送的基本使用

    一.iOS证书指导 在 iOS App 中加入消息推送功能时,必须要在 Apple 的开发者中心网站上申请推送证书,每一个 App 需要申请两个证书,一个在开发测试环境下使用,另一个用于上线到 App ...

  4. 关于 css padding 的使用 padding会将使用该属性的元素撑开

    .right_img_box{ width:300px; height:250px; border:1px solid #c9c9c9; margin-bottom:15px; background: ...

  5. silverlight依赖属性

    依赖属性(Dependency Property)和附加属性(Attached Property) 参考 http://www.cnblogs.com/KevinYang/archive/2010/0 ...

  6. delphi 基础书籍推荐

    本文所推荐的书,我均仔细读过,受益良多. 1. Pascal 精要.下载Pascal精要 本书讲Pascal 语言基本知识. 2. Object Pascal 参考(中英文对照版).下载Object ...

  7. Android Activity 的四种启动模式 lunchMode 和 Intent.setFlags();singleTask的两种启动方式。

    原文:Android Activity 的四种启动模式 lunchMode 和 Intent.setFlags();singleTask的两种启动方式. Android Activity 的四种启动模 ...

  8. HDU 3123-GCC(递推)

    GCC Time Limit: 1000/1000 MS (Java/Others)    Memory Limit: 131072/131072 K (Java/Others) Total Subm ...

  9. ceph 参数说明<转>

    //path/to/socket指向某个osd的admin socket文件#> ceph --admin-daemon {path/to/socket} config show | grep ...

  10. JavaScript之arguments.callee

    arguments.callee 在哪一个函数中运行,它就代表哪个函数. 一般用在匿名函数中. 在匿名函数中有时会需要自己调用自己,但是由于是匿名函数,没有名子,无名可调. 这时就可以用argumen ...