RNN是一个很有意思的模型。早在20年前就有学者发现了它强大的时序记忆能力,另外学术界以证实RNN模型属于Turning-Complete,即理论上可以模拟任何函数。但实际运作上,一开始由于vanishing and exploiting gradient问题导致BPTT算法学习不了长期记忆。虽然之后有了LSTM(长短记忆)模型对普通RNN模型的修改,但是训练上还是公认的比较困难。在Tensorflow框架里,之前的两篇博客已经就官方给出的PTB和Machine Translation模型进行了讲解,现在我们来看一看传说中的机器写诗的模型。原模型出自安德烈.卡帕西大神的char-rnn项目,意在显示RNN强大的能力以及并非那么困难的训练方法。对这个方面有兴趣的朋友请点击这里查看详情。原作的框架为Torch,点击这里查看原作代码。中山大学的zhangzibin以卡帕西大神的代码为样本制作了一款基于卡帕西RNN模型以及Samy Bengio(Bengio大神的亲弟弟)提出的Schedule Sampling算法的可运行中文的RNN模型,源代码请点击这里查看。作为Tensorflow的玩家,我本人当然很想了解下这个框架的运行情况,特别是在Tensorflow框架里的运行情况。好在有人已经捷足先登,将代码移植完毕了。今天我们就来看看这个神奇框架在Tensorflow下的代码。对该项目感兴趣的朋友可以在这里下载到项目的源码并在自己的机器上运行。

既然有了Tensorflow版本的代码了,那么我们开始解剖这段代码吧!

在解剖代码之前,让我们先对代码的运行做一个了解。在运行时,我们需要做的是cd到项目里后,先运行train.py文件来训练代码。默认的迭代数是50个迭代,默认的训练文件是tinyshakespear目录里的input.txt文件,也就是莎士比亚的一些作品。由于默认都是设定好的,我们不需要做任何更改,直接运行python train.py就好了。训练速度还是比较客观的,大约需要运行一个小时(没算具体时间),我们会发现训练完成,参数已经保存。之后,如果我们想看看运行的结果如何,打入python sample.py后,就会随机产生一段文字,该段文字是由机器学习了训练文本后自行计算的。之后我会放上机器在学习了郭敬明的幻成和小时代后自己写出的句子供大家参考。

在了解了运行方式后,既然入口文件是train.py,那么我们就先来看看该文件的设计。不出所料,train.py文件的开始为一系列的parser.add_argument。在之前的代码里我们已经多次见到,无非是加入了运行系统所需的参数,他们的默认值以及参数的解释。从这里我们发现默认的RNN框架为lstm,2层RNN结构,每层有128个神经元节点。另外,我们的sequence length定义为50,也就是每一次可以执行50个时间序列。之后便是train函数。如同往常,我们发现textloader函数为录入训练集的函数,这个函数存在于utils.py文件里。该文件很容易理解,在读入数据后通过collections.Counter收集文本中不一样的character,并将他们写入vocab_file文件做保存,已备后用。之后,根据总数据大小,minibatch大小以及时序长短来界定运行完整个文件需要多少个minibatches,并将文本分类成minibatch的训练以及目标batches。由于这个模型的目的是学习一个character后下一个character的概率,训练集跟目标函数间的差异为一个character,即在训练句子My name is Edward时,假设训练集为: My name is Edwar, 相对应的目标集为y name is Edward。 从逻辑角度上说,不管是这个util.py文件还是之前博客里的CBOW模型,他们的核心逻辑都是相似的,只是在处理上由于目标不同而产生出工程上的差异。有兴趣的朋友可以对比这个util.py文件里的逻辑和CBOW模型里读入输入的函数做对比。

之后,train.py文件对需要的目录以及文件进行确认后就是建立模型了。通过model = Model(args),我们建立了这个RNN所需要的模型。那么模型是如何建立的呢?让我们仔细来看看model.py文件。这个model.py文件里存在两个函数:init函数以及sample函数。他们分别被用来训练模型以及测试模型。让我们首先来看看模型的训练:

def __init__(self, args, infer=False):
self.args = args
# 这里的infer被默认为False,只有在测试效果
# 的时候才会被设计为True,在True的状态下
# 只有一个batch,time step也被设计为1,我们
# 可以由此观测训练成功
if infer:
args.batch_size = 1
args.seq_length = 1 # 这里是选择RNN cell的类型,备选的有lstm, gru和simple rnn
# 这里由输入的arg里的model参数作为测试标准,默认为lstm
# 但是,我们可以看到,这里通过不同的模型我们可以用不同
# 的cell。
if args.model == 'rnn':
cell_fn = rnn_cell.BasicRNNCell
elif args.model == 'gru':
cell_fn = rnn_cell.GRUCell
elif args.model == 'lstm':
cell_fn = rnn_cell.BasicLSTMCell
else:
raise Exception("model type not supported: {}".format(args.model))
# 定义cell的神经元数量,等同于cell = rnn_cell.BasicLSTMCell(args.rnn_size)
cell = cell_fn(args.rnn_size)
# 由于结构为多层结构,我们运用MultiRNNCell来定义神经元层。
self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers)
# 输入,同PTB模型,输入的格式为batch_size X sequence_length(step)
self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
self.initial_state = cell.zero_state(args.batch_size, tf.float32) with tf.variable_scope('rnnlm'):
softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])
softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
with tf.device("/cpu:0"):
# 这里运用embedding来将输入的不同词汇map到隐匿层的神经元上
embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
# 这里对input的shaping很有意思。这个地方如果我们仔细去读PTB模型就会发现在他的
# outputs = []这行附近有一段注释的文字,解释了一个alternative做法,这个做法就是那
# alternative的方法。首先,我们将embedding_loopup所得到的[batch_size, seq_length, rnn_size]
# tensor按照sequence length划分为一个list的[batch_size, 1, rnn_size]的tensor以表示每个
# 步骤的输入。之后通过squeeze把那个1维度去掉,达成一个list的[batch_size, rnn_size]
# 输入来被我们的rnn模型运用。
inputs = tf.split(1, args.seq_length, tf.nn.embedding_lookup(embedding, self.input_data))
inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
# 这里定义的loop实际在于当我们要测试运行结果,即让机器自己写文章时,我们需要对每一步
# 的输出进行查看。如果我们是在训练中,我们并不需要这个loop函数。
def loop(prev, _):
prev = tf.matmul(prev, softmax_w) + softmax_b
prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
return tf.nn.embedding_lookup(embedding, prev_symbol)
# 这里我们得益于tensorflow强大的内部函数,rnn_decoder可以作为黑盒子直接运用,省去了编写
# 的麻烦。另外,上面的loop函数只有在infer是被定为true的时候才会启动,一如我们刚刚所述。另外
# rnn_decoder在tensorflow中的建立方式是以schedule sampling算法为基础制作的,故其自身已经融入
# 了schedule sampling算法。
outputs, last_state = seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if infer else None, scope='rnnlm')
# 这里的过程可以说基本等同于PTB模型,首先通过对output的重新梳理得到一个
# [batch_size*seq_length, rnn_size]的输出,并将之放入softmax里,并通过sequence
# loss by example函数进行训练。
output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size])
self.logits = tf.matmul(output, softmax_w) + softmax_b
self.probs = tf.nn.softmax(self.logits)
loss = seq2seq.sequence_loss_by_example([self.logits],
[tf.reshape(self.targets, [-1])],
[tf.ones([args.batch_size * args.seq_length])],
args.vocab_size)
self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_length
self.final_state = last_state
self.lr = tf.Variable(0.0, trainable=False)
tvars = tf.trainable_variables()
grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),
args.grad_clip)
optimizer = tf.train.AdamOptimizer(self.lr)
self.train_op = optimizer.apply_gradients(zip(grads, tvars))

由上述代码可见,在制作RNN的模型里,不可或缺的步骤如下:

# 制作RNN模型的大概步骤:

# 1.定义cell类型以及模型框架(假设为lstm):
basic_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
cell = tf.nn.rnn_cell.MultiRNNCell([basic_cell]*number_layers) # 2.定义输入
input_data = tf.placeholder(tf.int32, [batch_size, sequence_length])
target = tf.placeholder(tf.int32, [batch_size, sequence_length]) # 3. init zero state
initial_state = cell.zero_state(batch_size, tf.float32) # 4. 整理输入,可以运用PTB的方法或上文介绍的方法,不过要注意
# 你的输入是什么形状的。最后数列要以格式[sequence_length, batch_size, rnn_size]
# 为输入才可以。 # 5. 之后为按照你的应用所需的函数运用了。这里运用的是rnn_decoder, 当然,别的可以
# 运用,比如machine translation里运用的就是embedding_attention_seq2seq # 6. 得到输出,重新编辑输出的结构后可以运用softmax,一般loss为sequence_loss_by_example # 7. 计算loss, final_state以及选用learning rate,之后用clip_by_global norm来定义gradient
# 并运用类似于adam来optimise算法。可以运用minimize或者apply_gradients来训练。

在了解了模型后,我们发现剩下的代码都是比较常见的,例如initialize_all_variables, 以及运用learning rate decay的方式训练模型。由此,train.py文件的训练过程我们已经做了一个大概的了解了。那么,系统又是如何让我们可以测试训练好的模型呢?让我们来看看sample.py文件。通过parser.add_arugment函数,我们发现文件会选取我们保存模型的地点,并会产生500字符的sample,至于sample选项,我们发现设定为0是得到最多的timestep,1是每一个timestep, 2是sample on spaces。之后,我们读取存储的模型内容后将内容传递进Model函数,并将infer设为True。在output的时候,我们运用sample函数来的到输出。在这个函数里,我们发现一般的prime开头是‘The’,这里我们可以通过sample.py里的prime函数来指定一个开头。在之后,我们发现那个sample参数设为0时选取的是argmax,1时是weighted_pick,2时以space为标准,如果有space则选择weighted_pick, 不然就是argmax。好了,实际运行的效果如何呢?让我们来几个例子看看。

这里是the为开始,我们看到了一开始有点乱码。之后,我们看到可是我知道了,他们三个女生等短句都是通顺的,同时,也有一些及其不通的,例如底忘记新地把满些了,这句话什么含义完全不清楚。再看下一个例子,如果以我开头会怎么样呢?

如果把sample设为0又会如何呢?

这里篇幅紧凑多了。再次运行,我们得到了相同的结果,因为是argmax么,所以在没改变的情况下我们会得到相同的结果。

这个运行结果还是很有意思的,有兴趣的朋友可以自行下载项目然后试着去操作一下!

character-RNN模型介绍以及代码解析的更多相关文章

  1. (zhuan) 深度学习全网最全学习资料汇总之模型介绍篇

    This blog from : http://weibo.com/ttarticle/p/show?id=2309351000224077630868614681&u=5070353058& ...

  2. Keras实现RNN模型

    博客作者:凌逆战 博客地址:https://www.cnblogs.com/LXP-Never/p/10940123.html 这篇文章主要介绍使用Keras框架来实现RNN家族模型,TensorFl ...

  3. 【论文笔记】AutoML for MCA on Mobile Devices——论文解读与代码解析

    理论部分 方法介绍 本节将详细介绍AMC的算法流程.AMC旨在自动地找出每层的冗余参数. AMC训练一个强化学习的策略,对每个卷积层会给出其action(即压缩率),然后根据压缩率进行裁枝.裁枝后,A ...

  4. Beam Search快速理解及代码解析(上)

    Beam Search 简单介绍一下在文本生成任务中常用的解码策略Beam Search(集束搜索). 生成式任务相比普通的分类.tagging等NLP任务会复杂不少.在生成的时候,模型的输出是一个时 ...

  5. Beam Search快速理解及代码解析

    目录 Beam Search快速理解及代码解析(上) Beam Search 贪心搜索 Beam Search Beam Search代码解析 准备初始输入 序列扩展 准备输出 总结 Beam Sea ...

  6. R数据分析:二分类因变量的混合效应,多水平logistics模型介绍

    今天给大家写广义混合效应模型Generalised Linear Random Intercept Model的第一部分 ,混合效应logistics回归模型,这个和线性混合效应模型一样也有好几个叫法 ...

  7. Java 集合系列05之 LinkedList详细介绍(源码解析)和使用示例

    概要  前面,我们已经学习了ArrayList,并了解了fail-fast机制.这一章我们接着学习List的实现类——LinkedList.和学习ArrayList一样,接下来呢,我们先对Linked ...

  8. Java 集合系列10之 HashMap详细介绍(源码解析)和使用示例

    概要 这一章,我们对HashMap进行学习.我们先对HashMap有个整体认识,然后再学习它的源码,最后再通过实例来学会使用HashMap.内容包括:第1部分 HashMap介绍第2部分 HashMa ...

  9. Java 集合系列11之 Hashtable详细介绍(源码解析)和使用示例

    概要 前一章,我们学习了HashMap.这一章,我们对Hashtable进行学习.我们先对Hashtable有个整体认识,然后再学习它的源码,最后再通过实例来学会使用Hashtable.第1部分 Ha ...

随机推荐

  1. 深入理解Java虚拟机:OutOfMemory实战

    在Java虚拟机规范的描述中,除了程序计数器外,虚拟机内存的其他几个运行时区域都有发生OutOfMemoryError(下文称OOM)异常的可能,本节将通过若干实例来验证异常发生的场景.并且会初步介绍 ...

  2. win32 控件的创建和消息响应

    1. 控件的创建 控件的创建和窗口创建是一样的,例如: ,,,, hWnd,(HMENU)IDB_BUTTON01,hInst,NULL); 是一个按钮的创建,其中hWnd是窗口句柄,hInst是应用 ...

  3. ubuntu中Mysql常用命令整理

    启动mysql服务sudo /etc/init.d/mysql start 关闭mysql服务sudo /etc/init.d/mysql stop

  4. 读jQuery源码 jQuery.data

    var rbrace = /(?:\{[\s\S]*\}|\[[\s\S]*\])$/, rmultiDash = /([A-Z])/g; function internalData( elem, n ...

  5. codeforces 652D . Nested Segments 线段树

    题目链接 我们将线段按照右端点从小到大排序, 如果相同, 那么按照左端点从大到小排序. 然后对每一个l, 查询之前有多少个l比他大, 答案就是多少.因为之前的r都是比自己的r小的, 如果l还比自己大的 ...

  6. cocos2d-x -------之笔记篇 3D动作说明

    CCShaky3D::create(时间,晃动网格大小,晃动范围,Z轴是否晃动);    //创建一个3D晃动的效果 CCShakyTiles3D::create(时间,晃动网格大小,晃动范围,Z轴是 ...

  7. Linux常用命令--网络管理篇(三)

    ping –b 10.0.0.255 扫描子网网段 ifconfig 查看网络信息 netconfig 配置网络,配置网络后用service network restart重新启动网络 ifconfi ...

  8. Oracle单个数据文件超过32G后扩容

    Oracle单个数据文件超过32G后扩容   表空间数据文件容量与DB_BLOCK_SIZE的设置有关,而这个参数在创建数据库实例的时候就已经指定.DB_BLOCK_SIZE参数可以设置为4K.8K. ...

  9. c# webBrowser 获取Ajax信息 .

    c#中 webbrowser控件对Ajax的执行,没有任何的响应,难于判断Ajax是否已经执行完毕,我GG了一下午,找到一个方法,介绍一下: 假如在页面中有个<div id=result> ...

  10. stack的应用

    STL除了给我们提供了一些容器(container)以外,还给我们提供了几个容器适配器(container adapters),stack便是其中之一 看过STL源码的人都知道,stack其实是内部封 ...