反向传播

课程来源:PyTorch深度学习实践——河北工业大学

《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili

笔记

在之前课程中介绍的线性模型就是一个最简单的神经网络的结构,其内部参数的更新过程如下:

对于简单的模型来说可以直接使用表达式的方式来更新权重,但是如果网络结构比较复杂(如下图),直接使用解析式的方式来更新显然有些复杂且不太可能实现。

反向传播就是为了解决这种问题。反向传播的基本思想就是将网络看成一张图,在图上传播梯度,从而使用链式传播来计算梯度。首先介绍两层的网络的计算图的方式表示如下图所示:

矩阵求导参考书籍链接如下:https://bicmr.pku.edu.cn/~wenzw/bigdata/matrix-cook-book.pdf

如果把式子展开,将会有如下结果:

也就是多层线性模型的叠加是可以用一个线性模型来实现的。因此为了提高模型的复杂程度,对于每一层的输出增加一个非线性的变化函数,如sigmoid等函数,如下图所示:

反向传播的链式求导的过程一个实例如下图所示:

得到相应导数之后就可以对于权重进行更新,如果x也只是一个中间结果,则可以继续向前传导。

接下来可以看一个完整的线性模型的计算图示例,过程就是先进行前馈过程,在前馈到loss之后进行反向传播,从而完成计算:

接下来介绍在PyTorch中如何进行前馈和反馈计算。

首先需要介绍的是Tensor,这是PyTorch中构建动态图的一个重要组成部分,Tensor中主要元素的是Data(数据)和Grad(导数),分别用于保存权重值和损失函数对权重的导数。

使用PyTorch实现上述的线性模型的代码如下:

import torch
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0] w = torch.tensor([1.0]) #初值为1.0
w.requires_grad = True # 需要计算梯度 def forward(x):
return x*w # 返回tensor def loss(x, y):
y_pred = forward(x)
return (y_pred - y)**2 print("predict (before training)", 4, forward(4).item()) for epoch in range(100):
for x, y in zip(x_data, y_data):
l =loss(x,y) #l是一个张量
l.backward() #将计算链路上需要梯度的地方计算出梯度,这一步之后计算图释放,每一次更新都创建新的计算图
print('\tgrad:', x, y, w.grad.item())#item是为了把梯度中的数值取出为标量
w.data = w.data - 0.01 * w.grad.data # 权重更新时,使用标量,使用data的时候不会建立新的计算图,注意grad也是一个tensor
w.grad.data.zero_() # 更新之后将梯度数据清零
print('progress:', epoch, l.item())
print("predict (after training)", 4, forward(4).item())

作业

1、手动推导线性模型y=w*x,损失函数loss=(ŷ-y)²下,当数据集x=2,y=4的时候,反向传播的过程。

2、手动推导线性模型 y=w*x+b,损失函数loss=(ŷ-y)²下,当数据集x=1,y=2的时候,反向传播的过程。

3、画出二次模型y=w1x²+w2x+b,损失函数loss=(ŷ-y)²的计算图,并且手动推导反向传播的过程,最后用pytorch的代码实现。

代码如下:

import torch
import matplotlib.pyplot as plt
import numpy as np
x_data=[1.0,2.0,3.0]
y_data=[2.0,4.0,6.0]
w1=torch.tensor([1.0],requires_grad=True)
w2=torch.tensor([1.0],requires_grad=True)
b=torch.tensor([1.0],requires_grad=True)
epoch_list=[]
loss_list=[]
def forward(x):
return w1*x**2+w2*x+b
def loss(x,y):
y_pred=forward(x)
return (y_pred-y)**2
print('Predict (befortraining)',4,forward(4))
for epoch in range(100):
for x,y in zip(x_data,y_data):
l=loss(x,y)
l.backward()
print('\tgrad:',x,y,w1.grad.item(),w2.grad.item(),b.grad.item())
w1.data=w1.data-0.01*w1.grad.data
w2.data = w2.data - 0.01 * w2.grad.data
b.data = b.data - 0.01 * b.grad.data
w1.grad.data.zero_()
w2.grad.data.zero_()
b.grad.data.zero_()
print('Epoch:', epoch, l.item())
epoch_list.append(epoch)
loss_list.append(l.data)
print('Predict(after training)', 4, forward(4).item())
print('predict (after training)', 4, forward(4))
plt.plot(epoch_list, loss_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()

可视化loss如下:

PyTorch深度学习实践——反向传播的更多相关文章

  1. PyTorch深度学习实践——多分类问题

    多分类问题 目录 多分类问题 Softmax 在Minist数据集上实现多分类问题 作业 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩 ...

  2. PyTorch深度学习实践——处理多维特征的输入

    处理多维特征的输入 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili 这一讲介绍输入为多维数据时的分类. 一个数据集 ...

  3. 深度学习梯度反向传播出现Nan值的原因归类

    症状:前向计算一切正常.梯度反向传播的时候就出现异常,梯度从某一层开始出现Nan值(Nan: Not a number缩写,在numpy中,np.nan != np.nan,是唯一个不等于自身的数). ...

  4. 深度学习之反向传播算法(BP)代码实现

    反向传播算法实战 本文仅仅是反向传播算法的实现,不涉及公式推导,如果对反向传播算法公式推导不熟悉,强烈建议查看另一篇文章神经网络之反向传播算法(BP)公式推导(超详细) 我们将实现一个 4 层的全连接 ...

  5. PyTorch深度学习实践-Overview

    Overview 1.PyTorch简介 ​ PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序.它主要由Facebookd的人工智能小组开发,不仅能够 实现强 ...

  6. 深度学习实践系列(2)- 搭建notMNIST的深度神经网络

    如果你希望系统性的了解神经网络,请参考零基础入门深度学习系列,下面我会粗略的介绍一下本文中实现神经网络需要了解的知识. 什么是深度神经网络? 神经网络包含三层:输入层(X).隐藏层和输出层:f(x) ...

  7. 使用PyTorch构建神经网络以及反向传播计算

    使用PyTorch构建神经网络以及反向传播计算 前一段时间南京出现了疫情,大概原因是因为境外飞机清洁处理不恰当,导致清理人员感染.话说国外一天不消停,国内就得一直严防死守.沈阳出现了一例感染人员,我在 ...

  8. 深度学习实践系列(3)- 使用Keras搭建notMNIST的神经网络

    前期回顾: 深度学习实践系列(1)- 从零搭建notMNIST逻辑回归模型 深度学习实践系列(2)- 搭建notMNIST的深度神经网络 在第二篇系列中,我们使用了TensorFlow搭建了第一个深度 ...

  9. 对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码

    PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉.自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的Allen N ...

随机推荐

  1. JavaScript如何实现上拉加载,下拉刷新?

    转载地址: 面试官:JavaScript如何实现上拉加载,下拉刷新? 一.前言 下拉刷新和上拉加载这两种交互方式通常出现在移动端中 本质上等同于PC网页中的分页,只是交互形式不同 开源社区也有很多优秀 ...

  2. jsp 4-15

  3. Java微服务监控及与普罗米集成

    一.    背景说明 Java服务级监控用于对每个应用占用的内存.线程池的线程数量.restful调用数量和响应时间.JVM状态.GC信息等进行监控,并可将指标信息同步至普罗米修斯中集中展示和报警.网 ...

  4. RabbitMQ简介及安装

    AMQP简介 AMQP AMQP(Advanced Message Queuing Protocol,高级消息队列协议)是进程之间传递异步消息的网络协议. AMQP工作过程 发布者(Publisher ...

  5. Eclipse不能启动,提示:The Eclipse executable launcher was unable to locate its companion launcher jar

    原因分析:JDK版本与eclipse不匹配 如jdk和eclipse版本号必须统一,64位都是64位,32位都是32位. jdk版本可以用命令,cmd进入命令窗口,然后输入java -version, ...

  6. linux下打包所有文件,包括隐藏文件到压缩包

    命令如下: cd /root/test/ tar czvf test.tar.gz .[!.]* * 解释: tar czvf test.tar.gz * 压缩当前文件夹下非[隐藏文件]的文件 tar ...

  7. 关于viewControllers之间的传值方式

    AViewController----Push----BViewController 1.属性 AViewController---pop----BViewController 1.代理  2.通知  ...

  8. [翻译]Introduction to JSON Web Tokens

    JWT: Json Web Tokens JWT是一种开放标准(RFC 7519),它定义了一种紧凑且独立的方式,用于将各方之间的信息安全地传输为JSON对象.因为它是经过数字签名的,所以该信息可以进 ...

  9. requests实现接口测试

    python+requests实现接口测试 - get与post请求基本使用方法 http://www.cnblogs.com/nizhihong/p/6567928.html   Requests ...

  10. 他人学习Python感悟

    作者:王一 链接:https://www.zhihu.com/question/26235428/answer/36568428 来源:知乎 著作权归作者所有.商业转载请联系作者获得授权,非商业转载请 ...