反向传播(BPN)算法是神经网络中研究最多、使用最多的算法之一,它用于将输出层中的误差传播到隐藏层的神经元,然后用于更新权重。

学习 BPN 算法可以分成以下两个过程:

  1. 正向传播:输入被馈送到网络,信号从输入层通过隐藏层传播到输出层。在输出层,计算误差和损失函数。
  2. 反向传播:在反向传播中,首先计算输出层神经元损失函数的梯度,然后计算隐藏层神经元损失函数的梯度。接下来用梯度更新权重。

这两个过程重复迭代直到收敛。

前期准备

首先给网络提供 M 个训练对(X,Y),X 为输入,Y 为期望的输出。输入通过激活函数 g(h) 和隐藏层传播到输出层。输出 Yhat 是网络的输出,得到 error=Y-Yhat。其损失函数 J(W) 如下:


 

其中,i 取遍所有输出层的神经元(1 到 N)。然后可以使用 J(W) 的梯度并使用链式法则求导,来计算连接第 i 个输出层神经元到第 j 个隐藏层神经元的权重 Wij 的变化:

这里,Oj 是隐藏层神经元的输出,h 表示隐藏层的输入值。这很容易理解,但现在怎么更新连接第 n 个隐藏层的神经元 k 到第 n+1 个隐藏层的神经元 j 的权值 Wjk?过程是相同的:将使用损失函数的梯度和链式法则求导,但这次计算 Wjk

现在已经有方程了,看看如何在 TensorFlow 中做到这一点。在这里,还是使用 MNIST 数据集(http://yann.lecun.com/exdb/MNIST/)。

具体实现过程

现在开始使用反向传播算法:

  1. 导入模块:


     
  2. 加载数据集,通过设置 one_hot=True 来使用独热编码标签:

     
  3. 定义超参数和其他常量。这里,每个手写数字的尺寸是 28×28=784 像素。数据集被分为 10 类,以 0 到 9 之间的数字表示。这两点是固定的。学习率、最大迭代周期数、每次批量训练的批量大小以及隐藏层中的神经元数量都是超参数。可以通过调整这些超参数,看看它们是如何影响网络表现的:

     
  4. 需要 Sigmoid 函数的导数来进行权重更新,所以定义它:

     
  5. 为训练数据创建占位符:

     
  6. 创建模型:

     
  7. 定义权重和偏置变量:

     
  8. 为正向传播、误差、梯度和更新计算创建计算图:

     
  9. 定义计算精度 accuracy 的操作:

     
  10. 初始化变量:

     
  11. 执行图:

     
  12. 结果如下:

解读分析

在这里,训练网络时的批量大小为 10,如果增加批量的值,网络性能就会下降。另外,需要在测试数据集上检测训练好的网络的精度,这里测试数据集的大小是 1000。

单隐藏层多层感知机在训练数据集上的准确率为 84.45,在测试数据集上的准确率为 92.1。这是好的,但不够好。MNIST 数据集被用作机器学习中分类问题的基准。接下来,看一下如何使用 TensorFlow 的内置优化器影响网络性能。

TensorFlow从0到1之TensorFlow实现反向传播算法(21)的更多相关文章

  1. [2] TensorFlow 向前传播算法(forward-propagation)与反向传播算法(back-propagation)

    TensorFlow Playground http://playground.tensorflow.org 帮助更好的理解,游乐场Playground可以实现可视化训练过程的工具 TensorFlo ...

  2. TensorFlow反向传播算法实现

    TensorFlow反向传播算法实现 反向传播(BPN)算法是神经网络中研究最多.使用最多的算法之一,用于将输出层中的误差传播到隐藏层的神经元,然后用于更新权重. 学习 BPN 算法可以分成以下两个过 ...

  3. TensorFlow从0到1之TensorFlow优化器(13)

    高中数学学过,函数在一阶导数为零的地方达到其最大值和最小值.梯度下降算法基于相同的原理,即调整系数(权重和偏置)使损失函数的梯度下降. 在回归中,使用梯度下降来优化损失函数并获得系数.本节将介绍如何使 ...

  4. Tensorflow笔记——神经网络图像识别(一)前反向传播,神经网络八股

      第一讲:人工智能概述       第三讲:Tensorflow框架         前向传播: 反向传播: 总的代码: #coding:utf-8 #1.导入模块,生成模拟数据集 import t ...

  5. TensorFlow从0到1之TensorFlow Keras及其用法(25)

    Keras 是与 TensorFlow 一起使用的更高级别的作为后端的 API.添加层就像添加一行代码一样简单.在模型架构之后,使用一行代码,你可以编译和拟合模型.之后,它可以用于预测.变量声明.占位 ...

  6. TensorFlow从0到1之TensorFlow多层感知机函数逼近过程(23)

    Hornik 等人的工作(http://www.cs.cmu.edu/~bhiksha/courses/deeplearning/Fall.2016/notes/Sonia_Hornik.pdf)证明 ...

  7. TensorFlow从0到1之TensorFlow常用激活函数(19)

    每个神经元都必须有激活函数.它们为神经元提供了模拟复杂非线性数据集所必需的非线性特性.该函数取所有输入的加权和,进而生成一个输出信号.你可以把它看作输入和输出之间的转换.使用适当的激活函数,可以将输出 ...

  8. TensorFlow从0到1之TensorFlow逻辑回归处理MNIST数据集(17)

    本节基于回归学习对 MNIST 数据集进行处理,但将添加一些 TensorBoard 总结以便更好地理解 MNIST 数据集. MNIST由https://www.tensorflow.org/get ...

  9. TensorFlow从0到1之TensorFlow csv文件读取数据(14)

    大多数人了解 Pandas 及其在处理大数据文件方面的实用性.TensorFlow 提供了读取这种文件的方法. 前面章节中,介绍了如何在 TensorFlow 中读取文件,本节将重点介绍如何从 CSV ...

随机推荐

  1. Spring 基于 Java 的配置

    前面已经学习如何使用 XML 配置文件来配置 Spring bean. 基于 Java 的配置可以达到基于XML配置的相同效果. 基于 Java 的配置选项,可以使你在不用配置 XML 的情况下编写大 ...

  2. LightOJ1236

    题目大意: 给你一个 n,请你找出共有多少对(i,j)满足 lcm(i,j) = n (i<=j) . 解题思路: 我们利用算术基本定理将 n,i,j 进行分解: n = P1a1 * P2a2 ...

  3. Spring Boot 教程 (3) - RESTful

    Spring Boot 教程 - RESTful 1. RESTful风格 1.1 简介与特点 RESTful是一种网络应用程序的设计风格和开发方式,基于HTTP,可以使用XML格式定义或JSON格式 ...

  4. U-Learning 后端开发日志(建设中...)

    目录 U-Learning--基于泛在学习的教学系统 项目背景 技术栈 框架 中间件 插件 里程碑 CentOS 7搭建JAVA开发环境 接口参数校验(不使用hibernate-validator,规 ...

  5. 【Leetcode】560. 和为K的子数组&974. 和可被 K 整除的子数组(前缀和+哈希表)

    public class Solution { public int subarraySum(int[] nums, int k) { int count = 0, pre = 0; HashMap ...

  6. 读Pyqt4教程,带你入门Pyqt4 _005

    对话框窗体或对话框是现代GUI应用不可或缺的一部分.dialog定义为两个或多个人之间的交谈.在计算机程序中dialog是一个窗体,用来和程序“交谈”.对话框用来输入数据.修改数据.改变程序设置等等. ...

  7. 中文分词工具——jieba

    汉字是智慧和想象力的宝库. --索尼公司创始人井深大 简介 在英语中,单词就是"词"的表达,一个句子是由空格来分隔的,而在汉语中,词以字为基本单位,但是一篇文章的表达是以词来划分的 ...

  8. IDEA 插件推荐 —— 让你写出好代码的神器!

    概述 今天介绍的插件主要是围绕编码规范的.有追求的程序员,往往都有代码洁癖,要尽量减少代码的「坏味道」. 代码静态检查是有很多种类,例如圈复杂度.重复率等.业界提供了很多静态检查的插件来识别这些不合规 ...

  9. Azure AD(三)知识补充-Azure资源的托管标识

    一,引言 来个惯例,吹水! 前一周因为考试,还有个人的私事,一下子差点颓废了.想了想,写博客这种的东西还是得坚持,再忙,也要检查.要养成一种习惯,同时这也是自我约束的一种形式.虽然说不能浪费大量时间在 ...

  10. PAT1033 旧键盘打字 (20分) (关于测试点4超时问题)

    1033 旧键盘打字 (20分)   旧键盘上坏了几个键,于是在敲一段文字的时候,对应的字符就不会出现.现在给出应该输入的一段文字.以及坏掉的那些键,打出的结果文字会是怎样? 输入格式: 输入在 2 ...