本文翻译节选自1998-Efficient BackProp, Yann LeCun et al..

4.1 随机VS批训练

每一次迭代, 传统训练方式都需要遍历所有数据集来计算平均梯度. 批训练也同样. 但你也可以使用随机训练的方法: 每次随机选择一个样本$\{Z^t, D^t\}$. 使用它来计算对应的梯度从而更新权值:

$W(t+1) = W(t) - \eta \frac{\partial E^t}{\partial W}$ (11).

这种估计梯度的方式是有噪的, 可能不会每次迭代权值都会精确地沿着正确的梯度下降. 恰恰是这种噪声使得随机训练有如下的优势:

1. 随机训练比批训练快

2. 随机训练结果更好

3. 随机训练可以用来追踪变化

随机训练在大数据集上往往比批训练更快. 为何? 我们举个简单的例子: 如果一个大小为1000的训练集恰好由10个完全一样的子集组成. 在1000个样本中求取的平均梯度和仅仅计算前100个的平均梯度是一样的. 因此批训练浪费了很多时间. 另一方面, 随机梯度将一个epoch视为在大小为100的训练集训练10次. 在实际中, 样本很少在数据集中出现超过1次, 但通常会有很多相似的特征聚合在一起. 在音素分类中, 所有音素/ae/的特征都会包含有相同的信息. 正是这种冗余, 导致了批训练的慢速.

随机训练的结果往往更好是由于权值更新中的噪声. 非线性网络通常有多个深度不同的局部最小值. 训练的目的是找到其中的一个. 批训练会找到权值初始值附近的一个最小值. 在随机训练中, 权值更新中的噪声会导致权值跳转到另一个, 并且有可能是更好的一个解. 这可以参考文献[15,30].

当建模的函数随时间变化时随机训练的方式会更有用. 一个常见的工业应用场景就是数据的分布随着时间的变化而变化(机器随使用时间而老化). 如果训练机不能检测到这种变化并随着它改变, 它就不能正确地学习到数据, 泛化误差也会变大. 使用批训练, 这些变化就无从检测, 学习到的结果也会很差. 使用随机训练, 如果恰当地进行处理的话(见4.7), 它就能够跟踪这些变化, 并产生较好的估计结果.

尽管随机训练有一些优势, 批训练也有如下的几个优点, 使得我们有时也不得不使用它:

1. 收敛条件已经被证明

2. 许多加速训练的方法仅对批训练有效

3. 权值动态特性和收敛速率的理论分析更加简单

这些优点使得我们可以忽略那些使随机训练更好的噪声. 这种噪声, 对于找到更好的局部最佳解非常关键, 同时也避免了最小值的全收敛(full convergence). 和收敛到确切的最小值不同, 收敛进程根据权值的波动而减慢(stall out). 波动的大小取决于随机更新的噪声幅度. 在局部最小值波动的方差和学习速率成比例[28,27,6]. 为了减小波动, 有两种方法: 一是减小学习速率; 二是使用一个自适应的批大小. 在[13,30,36,35]的理论中, 学习速率的最佳退火步骤的形式为:

$\eta \sim \frac{c}{t}$ (12)

其中$t$是模式的数量, $c$是一个常量.

另一种去除噪声的方法是使用"最小批", 即从一个最小的批尺寸开始, 随着训练进行不断增大这个尺寸. Moller使用过这种方法[25], Orr讨论过用这种方法解决线性问题[31]. 然而, 批尺寸增加的速率和批中放置的输入和学习速率一样非常难以确定. 在随机学习中的学习速率的大小可以参考最小批的大小.

注意, 移除数据中的噪声的问题可能在一些人看来并不重要, 因为噪声有助于泛化. 实际情况却是在噪声生效之前, 过训练已经发生.

批训练的另一个优点就是你可以使用二阶的方法来加速学习过程. 二阶方法通过估计梯度和误差平面的曲率来加速学习. 给定曲率, 你就可以大致估计实际最小值的位置.

尽管批更新的好处很多, 随机训练仍然被更多人青睐, 尤其是处理大数据集的时候, 因为它更快.

[NN] 随机VS批训练的更多相关文章

  1. pytorch1.0批训练神经网络

    pytorch1.0批训练神经网络 import torch import torch.utils.data as Data # Torch 中提供了一种帮助整理数据结构的工具, 叫做 DataLoa ...

  2. 随机切分csv训练集和测试集

    使用numpy切分训练集和测试集 觉得有用的话,欢迎一起讨论相互学习~Follow Me 序言 在机器学习的任务中,时常需要将一个完整的数据集切分为训练集和测试集.此处我们使用numpy完成这个任务. ...

  3. [Python] 波士顿房价的7种模型(线性拟合、二次多项式、Ridge、Lasso、SVM、决策树、随机森林)的训练效果对比

    目录 1. 载入数据 列解释Columns: 2. 数据分析 2.1 预处理 2.2 可视化 3. 训练模型 3.1 线性拟合 3.2 多项式回归(二次) 3.3 脊回归(Ridge Regressi ...

  4. pytorch 6 batch_train 批训练

    import torch import torch.utils.data as Data torch.manual_seed(1) # reproducible # BATCH_SIZE = 5 BA ...

  5. daal4py 随机森林模型训练mnist并保存模型给C++ daal predict使用

    # daal4py Decision Forest Classification Training example Serialization import daal4py as d4p import ...

  6. pytorch批训练数据构造

    这是对莫凡python的学习笔记. 1.创建数据 import torch import torch.utils.data as Data BATCH_SIZE = 8 x = torch.linsp ...

  7. 利用VGG19实现火灾分类(附tensorflow代码及训练集)

    源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...

  8. 利用卷积神经网络(VGG19)实现火灾分类(附tensorflow代码及训练集)

    源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...

  9. Tensorflow之训练MNIST(1)

    先说我遇到的一个坑,在下载MNIST训练数据的时候,代码报错: urllib.error.URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FA ...

随机推荐

  1. android升级adt和sdk之后无法识别SDK Location的一个解决方式

    我把android的adt和sdk从4.0升级到4.2,发现eclipse的android设置里面原来列出的各种api level的platform消失了,而且无法新建android工程.而且检查过了 ...

  2. python 小练习 10

    给你一个十进制数a,将它转换成b进制数,如果b>10,用大写字母表示(10用A表示,等等) a为32位整数,2 <= b <= 16 如a=3,b = 2, 则输出11 AC: di ...

  3. http1.0 1.1 2.0区别

    http1.0 1.1 2.0区别 转载:https://blog.csdn.net/linsongbin1/article/details/54980801/ 1.HTTP1.0 1.1区别 (1) ...

  4. 【vue系列】elementUI 穿梭框右侧获取当前选中项的值的思路

    最近 做了一个需求 在查询结果的表格中,选取(可多选)一些值,获取到保单号后,打开一个elementUI的穿梭框,然后获取到所有业务员,选取一些业务员后,将上一步获取到的保单号传递给业务员. 画个示意 ...

  5. memory prefix vice ,with out 1

    1● vice 副的   2● with 向后,相反  

  6. DiskGenius注册算法简析

    初次接触DiskGenius已经成为遥远的记忆,那个时候还只有DOS版本.后来到Windows版,用它来处理过几个找回丢失分区的案例,方便实用.到现在它的功能越来越强大,成为喜好启动技术和桌面支持人员 ...

  7. Wii Party U 游戏简介

  8. 深入理解Linux网络技术内幕——用户空间与内核空间交互

    概述:     内核空间与用户空间经常需要进行交互.举个例子:当用户空间使用一些配置命令如ifconfig或route时,内核处理程序就要响应这些处理请求.     用户空间与内核有多种交互方式,最常 ...

  9. ContentType&CORS&Git

    ContentType django内置的ContentType组件就是帮我们做连表操作 如果一个表与其他表有多个外键关系,我们可以通过ContentType来解决这种关联 from django.d ...

  10. python3.6 django2.06 使用QQ邮箱发送邮件

    开通QQ邮箱IMAP/SMTP服务,忘记了,重新开通一下,记住密码串. import smtplib from email.mime.text import MIMEText # 收件人列表 mail ...