RNN(Recurrent Neural Networks)公式推导和实现

http://x-algo.cn/index.php/2016/04/25/rnn-recurrent-neural-networks-derivation-and-implementation/
2016-04-25 分类:Deep Learning / NLP / RNN 阅读(6997) 评论(7) 

本文主要参考wildml的博客所写,所有的代码都是python实现。没有使用任何深度学习的工具,公式推导虽然枯燥,但是推导一遍之后对RNN的理解会更加的深入。看本文之前建议对传统的神经网络的基本知识已经了解,如果不了解的可以看此文:『神经网络(Neural Network)实现』。

所有可执行代码:Code to follow along is on Github.

文章目录 [展开]

语言模型

熟悉NLP的应该会比较熟悉,就是将自然语言的一句话『概率化』。具体的,如果一个句子有m个词,那么这个句子生成的概率就是:

P(w1,...,wm)=∏mi=1P(wi∣w1,...,wi−1)P(w1,...,wm)=∏i=1mP(wi∣w1,...,wi−1)

其实就是假设下一次词生成的概率和只和句子前面的词有关,例如句子『He went to buy some chocolate』生成的概率可以表示为:  P(他喜欢吃巧克力) = P(他喜欢吃) * P(巧克力|他喜欢吃) 。

数据预处理

训练模型总需要语料,这里语料是来自google big query的reddit的评论数据,语料预处理会去掉一些低频词从而控制词典大小,低频词使用一个统一标识替换(这里是UNKNOWN_TOKEN),预处理之后每一个词都会使用一个唯一的编号替换;为了学出来哪些词常常作为句子开始和句子结束,引入SENTENCE_START和SENTENCE_END两个特殊字符。具体就看代码吧:

点击展开代码

 

网络结构

和传统的nn不同,但是也很好理解,rnn的网络结构如下图:

A recurrent neural network and the unfolding in time of the computation involved in its forward computation.

不同之处就在于rnn是一个『循环网络』,并且有『状态』的概念。

如上图,t表示的是状态, xtxt 表示的状态t的输入, stst 表示状态t时隐层的输出, otot 表示输出。特别的地方在于,隐层的输入有两个来源,一个是当前的 xtxt 输入、一个是上一个状态隐层的输出 st−1st−1 , W,U,VW,U,V 为参数。使用公式可以将上面结构表示为:

sty^t=tanh(Uxt+Wst−1)=softmax(Vst)st=tanh⁡(Uxt+Wst−1)y^t=softmax(Vst)

如果隐层节点个数为100,字典大小C=8000,参数的维度信息为:

xtotstUVW∈R8000∈R8000∈R100∈R100×8000∈R8000×100∈R100×100xt∈R8000ot∈R8000st∈R100U∈R100×8000V∈R8000×100W∈R100×100

初始化

参数的初始化有很多种方法,都初始化为0将会导致『symmetric calculations 』(我也不懂),如何初始化其实是和具体的激活函数有关系,我们这里使用的是tanh,一种推荐的方式是初始化为 [−1n√,1n√][−1n,1n] ,其中n是前一层接入的链接数。更多信息请点击查看更多

 
1
2
3
4
5
6
7
8
9
10
11
class RNNNumpy:
    
    def __init__(self, word_dim, hidden_dim=100, bptt_truncate=4):
        # Assign instance variables
        self.word_dim = word_dim
        self.hidden_dim = hidden_dim
        self.bptt_truncate = bptt_truncate
        # Randomly initialize the network parameters
        self.U = np.random.uniform(-np.sqrt(1./word_dim), np.sqrt(1./word_dim), (hidden_dim, word_dim))
        self.V = np.random.uniform(-np.sqrt(1./hidden_dim), np.sqrt(1./hidden_dim), (word_dim, hidden_dim))
        self.W = np.random.uniform(-np.sqrt(1./hidden_dim), np.sqrt(1./hidden_dim), (hidden_dim, hidden_dim))

前向传播

类似传统的nn的方法,计算几个矩阵乘法即可:

 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def forward_propagation(self, x):
    # The total number of time steps
    T = len(x)
    # During forward propagation we save all hidden states in s because need them later.
    # We add one additional element for the initial hidden, which we set to 0
    s = np.zeros((T + 1, self.hidden_dim))
    s[-1] = np.zeros(self.hidden_dim)
    # The outputs at each time step. Again, we save them for later.
    o = np.zeros((T, self.word_dim))
    # For each time step...
    for t in np.arange(T):
        # Note that we are indxing U by x[t]. This is the same as multiplying U with a one-hot vector.
        s[t] = np.tanh(self.U[:,x[t]] + self.W.dot(s[t-1]))
        o[t] = softmax(self.V.dot(s[t]))
    return [o, s]

预测函数可以写为:

 
1
2
3
4
def predict(self, x):
    # Perform forward propagation and return index of the highest score
    o, s = self.forward_propagation(x)
    return np.argmax(o, axis=1)

损失函数

类似nn方法,使用交叉熵作为损失函数,如果有N个样本,损失函数可以写为:

L(y,o)=−1N∑n∈NynlogonL(y,o)=−1N∑n∈Nynlog⁡on

下面两个函数用来计算损失:

 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def calculate_total_loss(self, x, y):
    L = 0
    # For each sentence...
    for i in np.arange(len(y)):
        o, s = self.forward_propagation(x[i])
        # We only care about our prediction of the "correct" words
        correct_word_predictions = o[np.arange(len(y[i])), y[i]]
        # Add to the loss based on how off we were
        L += -1 * np.sum(np.log(correct_word_predictions))
    return L
 
def calculate_loss(self, x, y):
    # Divide the total loss by the number of training examples
    N = np.sum((len(y_i) for y_i in y))
    return self.calculate_total_loss(x,y)/N

BPTT学习参数

BPTT( Backpropagation Through Time)是一种非常直观的方法,和传统的BP类似,只不过传播的路径是个『循环』,并且路径上的参数是共享的。

损失是交叉熵,损失可以表示为:

Et(yt,y^t)E(y,y^)=−ytlogy^t=∑tEt(yt,y^t)=−∑tytlogy^tEt(yt,y^t)=−ytlog⁡y^tE(y,y^)=∑tEt(yt,y^t)=−∑tytlog⁡y^t

其中 ytyt 是真实值, (^yt)(^yt) 是预估值,将误差展开可以用图表示为:

所以对所有误差求W的偏导数为:

∂E∂W=∑t∂Et∂W∂E∂W=∑t∂Et∂W

进一步可以将 EtEt 表示为:

∂E3∂V=∂E3∂y^3∂y^3∂V=∂E3∂y^3∂y^3∂z3∂z3∂V=(y^3−y3)⊗s3∂E3∂V=∂E3∂y^3∂y^3∂V=∂E3∂y^3∂y^3∂z3∂z3∂V=(y^3−y3)⊗s3

根据链式法则和RNN中W权值共享,可以得到:

∂E3∂W=∑k=03∂E3∂y^3∂y^3∂s3∂s3∂sk∂sk∂W∂E3∂W=∑k=03∂E3∂y^3∂y^3∂s3∂s3∂sk∂sk∂W

下图将这个过程表示的比较形象

BPTT更新梯度的代码:

 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def bptt(self, x, y):
    T = len(y)
    # Perform forward propagation
    o, s = self.forward_propagation(x)
    # We accumulate the gradients in these variables
    dLdU = np.zeros(self.U.shape)
    dLdV = np.zeros(self.V.shape)
    dLdW = np.zeros(self.W.shape)
    delta_o = o
    delta_o[np.arange(len(y)), y] -= 1.
    # For each output backwards...
    for t in np.arange(T)[::-1]:
        dLdV += np.outer(delta_o[t], s[t].T)
        # Initial delta calculation: dL/dz
        delta_t = self.V.T.dot(delta_o[t]) * (1 - (s[t] ** 2))
        # Backpropagation through time (for at most self.bptt_truncate steps)
        for bptt_step in np.arange(max(0, t-self.bptt_truncate), t+1)[::-1]:
            # print "Backpropagation step t=%d bptt step=%d " % (t, bptt_step)
            # Add to gradients at each previous step
            dLdW += np.outer(delta_t, s[bptt_step-1])              
            dLdU[:,x[bptt_step]] += delta_t
            # Update delta for next step dL/dz at t-1
            delta_t = self.W.T.dot(delta_t) * (1 - s[bptt_step-1] ** 2)
    return [dLdU, dLdV, dLdW]

梯度弥散现象

tanh和sigmoid函数和导数的取值返回如下图,可以看到导数取值是[0-1],用几次链式法则就会将梯度指数级别缩小,所以传播不了几层就会出现梯度非常弱。克服这个问题的LSTM是一种最近比较流行的解决方案。

Gradient Checking

梯度检验是非常有用的,检查的原理是一个点的『梯度』等于这个点的『斜率』,估算一个点的斜率可以通过求极限的方式:

∂L∂θ≈limh→0J(θ+h)−J(θ−h)2h∂L∂θ≈limh→0J(θ+h)−J(θ−h)2h

通过比较『斜率』和『梯度』的值,我们就可以判断梯度计算的是否有问题。需要注意的是这个检验成本还是很高的,因为我们的参数个数是百万量级的。

梯度检验的代码:

 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def gradient_check(self, x, y, h=0.001, error_threshold=0.01):
    # Calculate the gradients using backpropagation. We want to checker if these are correct.
    bptt_gradients = self.bptt(x, y)
    # List of all parameters we want to check.
    model_parameters = ['U', 'V', 'W']
    # Gradient check for each parameter
    for pidx, pname in enumerate(model_parameters):
        # Get the actual parameter value from the mode, e.g. model.W
        parameter = operator.attrgetter(pname)(self)
        print "Performing gradient check for parameter %s with size %d." % (pname, np.prod(parameter.shape))
        # Iterate over each element of the parameter matrix, e.g. (0,0), (0,1), ...
        it = np.nditer(parameter, flags=['multi_index'], op_flags=['readwrite'])
        while not it.finished:
            ix = it.multi_index
            # Save the original value so we can reset it later
            original_value = parameter[ix]
            # Estimate the gradient using (f(x+h) - f(x-h))/(2*h)
            parameter[ix] = original_value + h
            gradplus = self.calculate_total_loss([x],[y])
            parameter[ix] = original_value - h
            gradminus = self.calculate_total_loss([x],[y])
            estimated_gradient = (gradplus - gradminus)/(2*h)
            # Reset parameter to original value
            parameter[ix] = original_value
            # The gradient for this parameter calculated using backpropagation
            backprop_gradient = bptt_gradients[pidx][ix]
            # calculate The relative error: (|x - y|/(|x| + |y|))
            relative_error = np.abs(backprop_gradient - estimated_gradient)/(np.abs(backprop_gradient) + np.abs(estimated_gradient))
            # If the error is to large fail the gradient check
            if relative_error > error_threshold:
                print "Gradient Check ERROR: parameter=%s ix=%s" % (pname, ix)
                print "+h Loss: %f" % gradplus
                print "-h Loss: %f" % gradminus
                print "Estimated_gradient: %f" % estimated_gradient
                print "Backpropagation gradient: %f" % backprop_gradient
                print "Relative Error: %f" % relative_error
                return
            it.iternext()
        print "Gradient check for parameter %s passed." % (pname)

SGD实现

这个公式应该非常熟悉:

W=W−λΔWW=W−λΔW

其中 ΔWΔW 就是梯度,具体代码:

 
1
2
3
4
5
6
7
8
# Performs one step of SGD.
def numpy_sdg_step(self, x, y, learning_rate):
    # Calculate the gradients
    dLdU, dLdV, dLdW = self.bptt(x, y)
    # Change parameters according to gradients and learning rate
    self.U -= learning_rate * dLdU
    self.V -= learning_rate * dLdV
    self.W -= learning_rate * dLdW

生成文本

生成过程其实就是模型的应用过程,只需要反复执行预测函数即可:

 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def generate_sentence(model):
    # We start the sentence with the start token
    new_sentence = [word_to_index[sentence_start_token]]
    # Repeat until we get an end token
    while not new_sentence[-1] == word_to_index[sentence_end_token]:
        next_word_probs = model.forward_propagation(new_sentence)
        sampled_word = word_to_index[unknown_token]
        # We don't want to sample unknown words
        while sampled_word == word_to_index[unknown_token]:
            samples = np.random.multinomial(1, next_word_probs[-1])
            sampled_word = np.argmax(samples)
        new_sentence.append(sampled_word)
    sentence_str = [index_to_word[x] for x in new_sentence[1:-1]]
    return sentence_str
 
num_sentences = 10
senten_min_length = 7
 
for i in range(num_sentences):
    sent = []
    # We want long sentences, not sentences with one or two words
    while len(sent) < senten_min_length:
        sent = generate_sentence(model)
    print " ".join(sent)

参考文献

Recurrent Neural Networks Tutorial, Part 2 – Implementing a RNN with Python, Numpy and Theano

Recurrent Neural Networks Tutorial, Part 3 – Backpropagation Through Time and Vanishing Gradients

转:RNN(Recurrent Neural Networks)的更多相关文章

  1. 循环神经网络(RNN, Recurrent Neural Networks)介绍(转载)

    循环神经网络(RNN, Recurrent Neural Networks)介绍    这篇文章很多内容是参考:http://www.wildml.com/2015/09/recurrent-neur ...

  2. RNN(Recurrent Neural Networks)公式推导和实现

    RNN(Recurrent Neural Networks)公式推导和实现 http://x-algo.cn/index.php/2016/04/25/rnn-recurrent-neural-net ...

  3. 循环神经网络(RNN, Recurrent Neural Networks)介绍

    原文地址: http://blog.csdn.net/heyongluoyao8/article/details/48636251# 循环神经网络(RNN, Recurrent Neural Netw ...

  4. 《转》循环神经网络(RNN, Recurrent Neural Networks)学习笔记:基础理论

    转自 http://blog.csdn.net/xingzhedai/article/details/53144126 更多参考:http://blog.csdn.net/mafeiyu80/arti ...

  5. 简述RNN Recurrent Neural Networks

    本文结构: 什么是 Recurrent Neural Networks ? Recurrent Neural Networks 的优点和应用? 训练 Recurrent Neural Networks ...

  6. 循环神经网络(RNN, Recurrent Neural Networks)——无非引入了环,解决时间序列问题

    摘自:http://blog.csdn.net/heyongluoyao8/article/details/48636251 不同于传统的FNNs(Feed-forward Neural Networ ...

  7. 循环神经网络(Recurrent Neural Networks, RNN)介绍

    目录 1 什么是RNNs 2 RNNs能干什么 2.1 语言模型与文本生成Language Modeling and Generating Text 2.2 机器翻译Machine Translati ...

  8. The Unreasonable Effectiveness of Recurrent Neural Networks (RNN)

    http://karpathy.github.io/2015/05/21/rnn-effectiveness/ There’s something magical about Recurrent Ne ...

  9. Attention and Augmented Recurrent Neural Networks

    Attention and Augmented Recurrent Neural Networks CHRIS OLAHGoogle Brain SHAN CARTERGoogle Brain Sep ...

随机推荐

  1. Intel Code Challenge Elimination Round (Div.1 + Div.2, combined) A. Broken Clock 水题

    A. Broken Clock 题目连接: http://codeforces.com/contest/722/problem/A Description You are given a broken ...

  2. 使用 IntraWeb (2) - Hello IntraWeb

    IntraWeb 比我相像中的更贴近 VCL, 传统的非可视组件在这里大都可用(其内部很多复合属性是 TStringList 类型的), 它的诸多可视控件也是从 TControl 继承下来的. 这或许 ...

  3. PG的集群技术:Pgpool-II与Postgres-XC Postgres-XL Postgres-XZ Postges-x2

    https://segmentfault.com/a/1190000007012082 https://www.postgres-xl.org/ https://www.biaodianfu.com/ ...

  4. 多用StringBuilder,少用字符串拼接

    在C#中,在处理字符串拼接的时候,使用StringBuilder的效率会比硬拼接字符串高很多.到底有多高,如下: static void Main(string[] args) { string st ...

  5. 在ASP.NET MVC中使用Knockout实践05,基本验证

    本篇体验View Model验证.Knockout的subscribe方法能为View Model成员注册验证规则. @{ ViewBag.Title = "Index"; Lay ...

  6. 报错:System.Data.Entity.Validation.DbEntityValidationException: 对一个或多个实体的验证失败

    使用MVC和EF,在保存数据的时候报错:System.Data.Entity.Validation.DbEntityValidationException: 对一个或多个实体的验证失败.有关详细信息, ...

  7. delphi连接mysql不用添加DSN(mysql connector odbc 5.1版)

    一.下载安装mysql驱动http://mysql.com/downloads/connector/odbc/二.添加adoconnection,adoquery,使用以下连接字符串http://ww ...

  8. Windows Phone本地数据库(SQLCE):7、Database mapping(翻译)

    这是“windows phone mango本地数据库(sqlce)”系列短片文章的第七篇. 为了让你开始在Windows Phone Mango中使用数据库,这一系列短片文章将覆盖所有你需要知道的知 ...

  9. 选股:“均线是水,K线是舟,量是马达!”的选美理念!

    选股:“均线是水,K线是舟,量是马达!”的选美理念! 很多庄家就是故意做数据,让某只股票的数据非常符合“理论”,引诱“技术派”股民

  10. Reflector_8.3.0.93_安装文件及破解工具

    Reflector_8.3.0.93_安装文件及破解工具 下载地址:http://pan.baidu.com/s/1jGwsYYM   约 8.9MB