Back-propagation, an introduction
Given the sheer number of backpropagation tutorials on the internet, is there really a need for another? One of us (Sanjeev) recently taught backpropagation in an undergrad AI course and couldn’t find any account he was happy with. So here’s our exposition, together with some history and context, as well as a few advanced notions at the end. This article assumes the reader knows the definitions of gradients and neural networks.
What is backpropagation?
It is the basic algorithm in training neural nets, apparently independently rediscovered several times in the 1970-80’s (e.g., see Werbos’ Ph.D. thesis and book, and Rumelhart et al.). Some related ideas existed in control theory in the 1960s.
Backpropagation gives a fast way to compute the sensitivity of the output of a neural network to all of its parameters while keeping the inputs of the network fixed: specifically, it computes all partial derivatives ∂f/∂wi, where f is the output and wi is the ith parameter. (Here parameters can be edge weights or biases associated with nodes or edges of the network, and the precise details of the node computations, e.g., the precise form of the nonlinearity such as sigmoid or ReLU, are unimportant.) Doing so gives the gradient ∇f of f with respect to its network parameters, which allows a gradient descent step in the training: change all parameters simultaneously to move the vector of parameters a small amount in the direction −∇f.
Note that backpropagation computes the gradient exactly, but properly training neural nets needs many more tricks than just backpropagation. Understanding backpropagation is useful for understanding the advanced tricks.
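As a concrete (and entirely toy) illustration of how the gradient is used, here is a single gradient-descent step in Python; the objective and the stand-in `grad_f` are our own choices, not part of any library.

```python
# A minimal sketch of one gradient-descent step on a toy objective
# f(w) = ||w||^2 / 2, whose gradient is simply w.  In a real setting,
# grad_f would be the gradient that backpropagation returns.
import numpy as np

def grad_f(w):
    return w                      # gradient of ||w||^2 / 2

w = np.array([1.0, -2.0, 0.5])    # current parameter vector
eta = 0.1                         # learning rate (step size)
w = w - eta * grad_f(w)           # move a small amount in the direction -grad f
```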
The importance of backpropagation derives from its efficiency. Assuming node operations take unit time, the running time is linear, specifically O(Network Size) = O(V+E), where V is the number of nodes in the network and E is the number of edges. The only technical ingredient is the chain rule from calculus, but applying it naively would result in quadratic running time, which would be hugely inefficient for networks with millions or even thousands of parameters.
Backpropagation can be efficiently implemented using the highly parallel vector operations available in today’s GPUs (Graphical Processing Units), which play an important role in the recent neural nets revolution.
Side Note: Expert readers will recognize that in the standard accounts of neural net training, the actual quantity of interest is the gradient of the training loss, which happens to be a simple function of the network output. But the above phrasing is fully general since one can simply add a new output node to the network that computes the training loss from the old output. Then the quantity of interest is indeed the gradient of this new output with respect to network parameters.
Problem Setup
Backpropagation applies only to acyclic networks with directed edges. (Later we briefly sketch its use on networks with cycles.)
Without loss of generality, acyclic networks can be visualized as being structured in numbered layers, with nodes in the (t+1)-th layer getting all their inputs from the outputs of nodes in layers t and earlier. We use f ∈ R to denote the output of the network. In all our figures, the input of the network is at the bottom and the output at the top.
We start with a simple claim that reduces the problem of computing the gradient to the problem of computing partial derivatives with respect to the nodes:
Claim 1: To compute the desired gradient with respect to the parameters, it suffices to compute ∂f/∂u for every node u.
Let’s be clear what ∂f/∂u means. Suppose we cut off all the incoming edges of the node u, and fix/clamp the current values of all network parameters. Now imagine changing u from its current value. This change may affect the values of nodes at higher levels that are connected to u, and the final output f is one such node. Then ∂f/∂u denotes the rate at which f will change as we vary u. (Aside: Readers familiar with the usual exposition of back-propagation should note that there f is the training error, and this ∂f/∂u turns out to be exactly the “error” propagated back to the node u.)
Claim 1 is a direct application of the chain rule; let’s illustrate it for a simple neural net (we address more general networks later). Suppose node u is a weighted sum of the nodes z1, …, zn (which will be passed through a non-linear activation σ afterwards). That is, we have u = w1z1 + ⋯ + wnzn. By the chain rule, we have

∂f/∂w1 = (∂f/∂u)·(∂u/∂w1) = (∂f/∂u)·z1.
Hence, we see that having computed ∂f/∂u we can compute ∂f/∂w1, and moreover this can be done locally by the endpoints of the edge where w1 resides.
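As a sanity check (the concrete numbers and the choice f = σ(u) with σ = tanh are ours, purely for illustration), the identity ∂f/∂w1 = (∂f/∂u)·z1 can be verified numerically:

```python
# Numerical check of  df/dw1 = (df/du) * z1  for u = w·z and f = tanh(u).
# The choice of tanh and the concrete numbers are illustrative only.
import numpy as np

z = np.array([0.3, -1.2, 0.7])       # inputs z1, z2, z3 feeding into u
w = np.array([0.5, 2.0, -0.4])       # weights w1, w2, w3

def f(w):
    u = np.dot(w, z)                 # u = w1*z1 + ... + wn*zn
    return np.tanh(u)                # nonlinearity sigma applied to u

u = np.dot(w, z)
df_du = 1 - np.tanh(u) ** 2          # derivative of tanh at u
analytic = df_du * z[0]              # Claim 1: df/dw1 = (df/du) * z1

eps = 1e-6                           # finite-difference approximation
w_plus = w.copy(); w_plus[0] += eps
numeric = (f(w_plus) - f(w)) / eps

print(analytic, numeric)             # the two values agree up to ~1e-6
```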
Multivariate Chain Rule
Towards computing the derivatives with respect to the nodes, we first recall the multivariate Chain rule, which handily describes the relationships between these partial derivatives (depending on the graph structure).
Suppose a variable f is a function of variables u1, …, un, which in turn depend on the variable z. Then the multivariate chain rule says that

∂f/∂z = ∑_{j=1}^n (∂f/∂uj)·(∂uj/∂z).
This is a direct generalization of the single-variable chain rule used in Claim 1, and a special case of the fully general chain rule.
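To make the formula concrete, here is a quick numerical check; the particular choices u1 = z², u2 = sin z, and f = u1·u2 are ours and serve only as an example.

```python
# Numerical check of the multivariate chain rule
#   df/dz = sum_j (df/du_j) * (du_j/dz)
# for the illustrative choice u1 = z**2, u2 = sin(z), f = u1 * u2.
import numpy as np

def f_of_z(z):
    u1, u2 = z ** 2, np.sin(z)
    return u1 * u2

z = 0.8
u1, u2 = z ** 2, np.sin(z)
df_du1, df_du2 = u2, u1              # partials of f = u1 * u2
du1_dz, du2_dz = 2 * z, np.cos(z)    # local derivatives of u1, u2
chain = df_du1 * du1_dz + df_du2 * du2_dz

eps = 1e-6
numeric = (f_of_z(z + eps) - f_of_z(z)) / eps
print(chain, numeric)                # agree up to ~1e-6
```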
This formula is perfectly suited to our purposes. Below is the same example as we used before, but with a different focus and numbering of the nodes.
We see that once we have computed the derivatives with respect to all the nodes that are above the node z, we can compute the derivative with respect to the node z via a weighted sum, where the weights involve the local derivatives ∂uj/∂z, which are often easy to compute. This brings us to the question of how we measure running time. For book-keeping, we assume that
Basic assumption: If u is a node at level t+1 and z is any node at level ≤ t whose output is an input to u, then computing ∂u/∂z takes unit time on our computer.
Naive feedforward algorithm (not efficient!)
It is useful to first point out the naive quadratic time algorithm implied by the chain rule. Most authors skip this trivial version, which we think is analogous to teaching sorting using only quicksort, and skipping over the less efficient bubblesort.
The naive algorithm is to compute ∂ui/∂uj for every pair of nodes where ui is at a higher level than uj. Of course, among these V^2 values (where V is the number of nodes) are also the desired ∂f/∂ui for all i, since f is itself the value of the output node.
This computation can be done in a feedforward fashion. If such values have been obtained for every uj at levels up to and including level t, then one can express (by inspecting the multivariate chain rule) the value ∂uℓ/∂uj for any uℓ at level t+1 as a weighted combination of the values ∂ui/∂uj for each ui that is a direct input to uℓ. This description shows that the amount of computation for a fixed j is proportional to the number of edges E. This amount of work happens for all V values of j, letting us conclude that the total work in the algorithm is O(VE).
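A minimal sketch of this naive scheme is given below; the graph representation (a topologically ordered node list, a parents table, and a dictionary of local derivatives evaluated at the current node values) is our own bookkeeping choice, not something prescribed by the text.

```python
# Naive feedforward scheme: compute d u_i / d u_j for every pair of nodes,
# one source node j at a time, sweeping forward through the graph.
#
#   nodes:   list of node ids in topological (input -> output) order
#   parents: parents[i] is the list of direct inputs of node i ([] for inputs)
#   local:   local[(i, j)] = du_i/du_j for each direct edge j -> i,
#            evaluated at the current node values (the "unit time" quantity)

def naive_all_derivatives(nodes, parents, local):
    D = {}                                     # D[(i, j)] = du_i/du_j
    for j in nodes:                            # fix the source node j ...
        D[(j, j)] = 1.0
        for i in nodes[nodes.index(j) + 1:]:   # ... and sweep forward
            D[(i, j)] = sum(local[(i, p)] * D.get((p, j), 0.0)
                            for p in parents[i])
    return D     # O(E) work per source node, hence O(VE) work overall
```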
Backpropagation (Linear Time)
The more efficient backpropagation, as the name suggests, computes the partial derivatives in the reverse direction. Messages are passed in one wave backwards from higher number layers to lower number layers. (Some presentations of the algorithm describe it as dynamic programming.)
Messaging protocol: The node u receives a message along each outgoing edge from the node at the other end of that edge. It sums these messages to get a number S (if u is the output of the entire net, then define S = 1) and then it sends the following message to any node z adjacent to it at a lower level: S · ∂u/∂z.
Clearly, the amount of work done by each node is proportional to its degree, and thus overall work is the sum of the node degrees. Summing all node degrees counts each edge twice, and thus the overall work is O(Network Size).
To prove correctness, we prove the following:
Main Claim: At each node z, the value S is exactly ∂f/∂z.
Base Case: At the output layer this is true, since ∂f/∂f=1.
Inductive case: Suppose the claim is true for layers t+1 and higher, and suppose z is at layer t, with outgoing edges going to nodes u1, u2, …, um at levels t+1 or higher. By the inductive hypothesis, node z receives (∂f/∂uj)·(∂uj/∂z) from each uj. Thus by the chain rule, S = ∑_{j=1}^m (∂f/∂uj)·(∂uj/∂z) = ∂f/∂z. This completes the induction and proves the Main Claim.
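Putting the protocol and the Main Claim together, the whole algorithm is a single backward sweep through the nodes in reverse topological order. The sketch below reuses the same graph representation as the naive sketch above; the function name and bookkeeping are ours.

```python
# Backpropagation as message passing: one backward sweep computes
# S[u] = df/du for every node u in O(V + E) time.
#
#   nodes:   node ids in topological order (inputs first, output last)
#   parents: parents[i] is the list of direct inputs of node i
#   local:   local[(i, j)] = du_i/du_j for each direct edge j -> i
#   output:  the id of the output node f

def backprop(nodes, parents, local, output):
    S = {u: 0.0 for u in nodes}       # accumulated incoming messages
    S[output] = 1.0                   # base case: df/df = 1
    for i in reversed(nodes):         # higher layers send messages first
        for j in parents[i]:          # message along edge j -> i
            S[j] += S[i] * local[(i, j)]
    return S
```

Each edge is touched exactly once in the inner loop, which matches the O(V + E) bound above.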
Auto-differentiation
Since the exposition above used almost no details about the network and the operations that the nodes perform, it extends to every computation that can be organized as an acyclic graph in which each node computes a differentiable function of its incoming neighbors. This observation underlies many auto-differentiation packages such as autograd or tensorflow: they allow computing the gradient of the output of such a computation with respect to the network parameters.
We first observe that Claim 1 continues to hold in this very general setting. This is without loss of generality because we can view the parameters associated with the edges as also sitting on the nodes (actually, on leaf nodes). This can be done via a simple transformation to the network; for a single node it is shown in the picture below, and one would need to continue this transformation in the rest of the network feeding into u1, u2, … from below.
Then, we can use the messaging protocol to compute the derivatives with respect to the nodes, as long as the local partial derivatives can be computed efficiently. We note that the algorithm can be implemented in a fairly modular manner: for every node u, it suffices to specify (a) how it depends on the incoming nodes, say, z1, …, zn and (b) how to compute the partial derivative times S, that is, S · ∂u/∂zj.
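A minimal sketch of such a modular interface is shown below. The class names and method signatures are our own invention for illustration and do not correspond to any real auto-differentiation package.

```python
# Each node type supplies (a) a forward function and (b) a routine that,
# given the incoming message S, returns S * du/dz_j for each input z_j.
import numpy as np

class Multiply:                        # u = z1 * z2 (elementwise product)
    def forward(self, z1, z2):
        self.z1, self.z2 = z1, z2
        return z1 * z2
    def backward(self, S):             # (S * du/dz1, S * du/dz2)
        return S * self.z2, S * self.z1

class Tanh:                            # u = tanh(z)
    def forward(self, z):
        self.out = np.tanh(z)
        return self.out
    def backward(self, S):             # S * du/dz = S * (1 - tanh(z)^2)
        return (S * (1 - self.out ** 2),)

# Usage: f = tanh(z1 * z2); the backward pass recovers df/dz1 and df/dz2.
mul, act = Multiply(), Tanh()
f_val = act.forward(mul.forward(np.array(0.5), np.array(-1.2)))
(S_u,) = act.backward(np.array(1.0))   # message from the output node (S = 1)
S_z1, S_z2 = mul.backward(S_u)
```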
Extension to vector messages: In fact (b) can be done efficiently in more general settings where we allow the output of each node in the network to be a vector (or even a matrix/tensor) instead of only a real number. For example, as illustrated below, suppose the node U ∈ R^{d1×d3} is the product of two matrices W ∈ R^{d1×d2} and Z ∈ R^{d2×d3}. Then ∂U/∂Z is a linear operator that maps R^{d2×d3} to R^{d1×d3}, which would naively require a matrix representation of dimension d1d3 × d2d3. However, the computation (b) can be done efficiently because S · ∂U/∂Z = W^T S.
Such vector operations can also be implemented efficiently using today’s GPUs.
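The following sketch spells out rule (b) for the matrix-product node U = WZ using plain numpy; the dimensions are arbitrary and the spot-check at the end is our own way of validating the formula.

```python
# Backward rule for the matrix-product node U = W Z: the message sent to Z
# is W^T S (and the message sent to W is S Z^T), computed with one matrix
# multiplication instead of forming the (d1*d3) x (d2*d3) Jacobian.
import numpy as np

d1, d2, d3 = 4, 5, 3
W = np.random.randn(d1, d2)
Z = np.random.randn(d2, d3)
S = np.random.randn(d1, d3)        # incoming message, same shape as U = W @ Z

msg_to_Z = W.T @ S                 # shape (d2, d3)
msg_to_W = S @ Z.T                 # shape (d1, d2)

# Spot-check one entry of msg_to_Z against a finite difference of
# the scalar function  h(Z) = <S, W Z>.
eps = 1e-6
E = np.zeros_like(Z); E[1, 2] = eps
numeric = (np.sum(S * (W @ (Z + E))) - np.sum(S * (W @ Z))) / eps
print(msg_to_Z[1, 2], numeric)     # agree up to ~1e-5
```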
Notable Extensions
1) Allowing weight tying. In many neural architectures, the designer wants to force many network units such as edges or nodes to share the same parameter. For example, in convolutional neural nets, the same filter has to be applied all over the image, which implies reusing the same parameter for a large set of edges between the two layers.
For simplicity, suppose two parameters a and b are supposed to share the same value. This is equivalent to adding a new node u and connecting u to both a and b with the operations a = u and b = u. Thus, by the chain rule, ∂f/∂u = (∂f/∂a)·(∂a/∂u) + (∂f/∂b)·(∂b/∂u) = ∂f/∂a + ∂f/∂b. Hence, equivalently, the gradient with respect to a shared parameter is the sum of the gradients with respect to its individual occurrences.
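The toy check below illustrates the rule; the particular function f and the numbers are ours.

```python
# Weight tying: if both occurrences a and b are forced to equal a shared
# parameter u, then df/du = df/da + df/db.
import numpy as np

def f(a, b, x1, x2):
    return np.tanh(a * x1) + np.tanh(b * x2)

u, x1, x2 = 0.7, 1.5, -0.3
df_da = (1 - np.tanh(u * x1) ** 2) * x1     # gradient w.r.t. occurrence a
df_db = (1 - np.tanh(u * x2) ** 2) * x2     # gradient w.r.t. occurrence b
df_du = df_da + df_db                       # gradient w.r.t. the tied parameter

eps = 1e-6                                   # finite-difference check on u
numeric = (f(u + eps, u + eps, x1, x2) - f(u, u, x1, x2)) / eps
print(df_du, numeric)                        # agree up to ~1e-6
```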
2) Backpropagation on networks with loops. The above exposition assumed the network is acyclic. Many cutting-edge applications such as machine translation and language understanding use networks with directed loops (e.g., recurrent neural networks). These architectures, all examples of the “differentiable computing” paradigm mentioned below, can get complicated and may involve operations on a separate memory as well as mechanisms to shift attention to different parts of data and memory.
Networks with loops are trained using gradient descent as well, using back-propagation through time, which consists of expanding the network through a finite number of time steps into an acyclic graph, with replicated copies of the same network. These replicas share the weights (weight tying!), so the gradient can be computed. In practice, issues may arise with exploding or vanishing gradients, which impact convergence. Such issues can be addressed in practice by clipping the gradient or by re-parameterization techniques such as long short-term memory.
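Below is a toy sketch of back-propagation through time for a scalar recurrence h_t = tanh(w·h_{t−1} + x_t) with output f = h_T; the model and the numbers are our own illustration. Because the same weight w is reused at every time step, its gradient is the sum of the per-step contributions, exactly as in the weight-tying rule above.

```python
# Back-propagation through time for the scalar recurrence
#   h_t = tanh(w * h_{t-1} + x_t),   f = h_T.
import numpy as np

def bptt(w, xs, h0=0.0):
    hs = [h0]
    for x in xs:                         # forward pass through the unrolled graph
        hs.append(np.tanh(w * hs[-1] + x))
    grad_w, S = 0.0, 1.0                 # S = df/dh_t, starting from df/dh_T = 1
    for t in range(len(xs), 0, -1):      # backward pass, later steps first
        pre = w * hs[t - 1] + xs[t - 1]
        local = 1 - np.tanh(pre) ** 2    # dh_t / d(pre-activation)
        grad_w += S * local * hs[t - 1]  # contribution of this replica of w
        S = S * local * w                # message passed back to h_{t-1}
    return hs[-1], grad_w

f_val, df_dw = bptt(0.5, [1.0, -0.2, 0.7])
print(f_val, df_dw)
```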
The fact that the gradient can be computed efficiently for such general networks with loops has motivated neural net models with memory or even data structures (see for example neural Turing machines and differentiable neural computer). Using gradient descent, one can optimize over a family of parameterized networks with loops to find the best one that solves a certain computational task (on the training examples). The limits of these ideas are still being explored.
3) Hessian-vector product in linear time. It is possible to generalize backprop to enable 2nd-order optimization in “near-linear” time, not just gradient descent, as shown in recent independent manuscripts of Carmon et al. and Agarwal et al. (NB: Tengyu is a coauthor on the latter.) One essential step is to compute the product of the Hessian matrix and a vector, for which Pearlmutter ’93 gave an efficient algorithm. Here we show how to do this in O(Network size) using the ideas above. We need a slightly stronger version of the back-propagation result than the one in the previous subsection:
Claim (informal): Suppose an acyclic network with V nodes and E edges has output f and leaves z1, …, zm. Then there exists a network of size O(V+E) that has z1, …, zm as input nodes and ∂f/∂z1, …, ∂f/∂zm as output nodes.
The proof of the Claim follows in straightforward fashion from implementing the message passing protocol as an acyclic circuit.
Next we show how to compute ∇²f(z)·v where v is a given fixed vector. Let g(z) = ⟨∇f(z), v⟩ be a function from R^d to R. Then by the Claim above, g(z) can be computed by a network of size O(V+E). Now applying the Claim again to g(z), we obtain that ∇g(z) can also be computed by a network of size O(V+E).
Note that by construction, ∇g(z) = ∇²f(z)·v. Hence we have computed the Hessian-vector product in time proportional to the network size.
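As a concrete illustration of this double back-propagation trick, here is a short sketch using the autograd package mentioned earlier; the test function f, the vector v, and the evaluation point are our own arbitrary choices.

```python
# Hessian-vector product via double back-propagation:
#   g(z) = <grad f(z), v>   and   grad g(z) = Hessian(f)(z) @ v.
import autograd.numpy as np
from autograd import grad

def f(z):
    return np.sum(np.tanh(z) ** 2) + np.dot(z, z)   # a toy smooth function

v = np.array([1.0, -2.0, 0.5])                      # the fixed vector
g = lambda z: np.dot(grad(f)(z), v)                 # g(z) = <grad f(z), v>
hvp = grad(g)                                       # grad g = Hessian(f) @ v

z0 = np.array([0.3, 0.1, -0.7])
print(hvp(z0))                                      # Hessian-vector product at z0
```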
That’s all!
Please write your comments on this exposition and whether it can be improved.