• 明天博士论文要答辩了,只有一张12G二手卡,今晚通宵要搞定10个模型实验
  • 挖槽,突然想出一个T9开天霹雳模型,加载不进去我那张12G的二手卡,感觉要错过今年上台Best Paper领奖
 
上面出现的问题主要是机器不够、内存不够用。在深度学习训练的时候,数据的batch size大小受到GPU内存限制,batch size大小会影响模型最终的准确性和训练过程的性能。在GPU内存不变的情况下,模型越来越大,那么这就意味着数据的batch size智能缩小,这个时候,梯度累积(Gradient Accumulation)可以作为一种简单的解决方案来解决这个问题。
 
下面这个图中橙色部分HERE就是梯度累积算法在AI系统中的大致位置,一般在AI框架/AI系统的表达层,跟算法结合比较紧密。

Batch size的作用

训练数据的Batch size大小对训练过程的收敛性,以及训练模型的最终准确性具有关键影响。通常,每个神经网络和数据集的Batch size大小都有一个最佳值或值范围。

不同的神经网络和不同的数据集可能有不同的最佳Batch size大小。
选择Batch size的时候主要考虑两个问题:
 
泛化性:大的Batch size可能陷入局部最小值。陷入局部最小值则意味着神经网络将在训练集之外的样本上表现得很好,这个过程称为泛化。因此,泛化性一般表示过度拟合。
 
收敛速度:小的Batch size可能导致算法学习收敛速度慢。网络模型在每个Batch的更新将会确定下一次Batch的更新起点。每次Batch都会训练数据集中,随机抽取训练样本,因此所得到的梯度是基于部分数据噪声的估计。在单次Batch中使用的样本越少,梯度估计准确度越低。换句话说,较小的Batch size可能会使学习过程波动性更大,从本质上延长算法收敛所需要的时间。
 
考虑到上面两个主要的问题,所以在训练之前需要选择一个合适的Batch size。

Batch size对内存的影响

虽然传统计算机在CPU上面可以访问大量RAM,还可以利用SSD进行二级缓存或者虚拟缓存机制。但是如GPU等AI加速芯片上的内存要少得多。这个时候训练数据Batch size的大小对GPU的内存有很大影响。
为了进一步理解这一点,让我们首先检查训练时候AI芯片内存中内存的内容:
  • 模型参数:网络模型需要用到的权重参数和偏差。
  • 优化器变量:优化器算法需要的变量,例如动量momentum。
  • 中间计算变量:网络模型计算产生的中间值,这些值临时存储在AI加速芯片的内存中,例如,每层激活的输出。
  • 工作区Workspace:AI加速芯片的内核实现是需要用到的局部变量,其产生的临时内存,例如算子D=A+B/C中B/C计算时产生的局部变量。
因此,Batch size越大,意味着神经网络训练的时候所需要的样本就越多,导致需要存储在AI芯片内存变量激增。在许多情况下,没有足够的AI加速芯片内存,Batch size设置得太大,就会出现OOM报错(Out Off Memor)。
 

使用大Batch size的方法

解决AI加速芯片内存限制,并运行大Batch size的一种方法是将数据Sample的Batch拆分为更小的Batch,叫做Mini-Batch。这些小Mini-Batch可以独立运行,并且在网络模型训练的时候,对梯度进行平均或者求和。主要实现有两种方式。
 
1)数据并行:使用多个AI加速芯片并行训练所有Mini-Batch,每份数据都在单个AI加速芯片上。累积所有Mini-Batch的梯度,结果用于在每个Epoch结束时求和更新网络参数。
 
2)梯度累积:按顺序执行Mini-Batch,同时对梯度进行累积,累积的结果在最后一个Mini-Batch计算后求平均更新模型变量。
 
虽然两种技术都挺像的,解决的问题都是内存无法执行更大的Batch size,但梯度累积可以使用单个AI加速芯片就可以完成啦,而数据并行则需要多块AI加速芯片,所以手头上只有一台12G二手卡的同学们赶紧把梯度累积用起来。
 

梯度累积原理

梯度累积是一种训练神经网络的数据Sample样本按Batch拆分为几个小Batch的方式,然后按顺序计算。
 
在进一步讨论梯度累积之前,我们来看看神经网络的计算过程。
 
深度学习模型由许多相互连接的神经网络单元所组成,在所有神经网络层中,样本数据会不断向前传播。在通过所有层后,网络模型会输出样本的预测值,通过损失函数然后计算每个样本的损失值(误差)。神经网络通过反向传播,去计算损失值相对于模型参数的梯度。最后这些梯度信息用于对网络模型中的参数进行更新。
 
优化器用于对网络模型模型权重参数更新的数学公式。以一个简单随机梯度下降(SGD)算法为例。
 
假设Loss Function函数公式为:

 
在构建模型时,优化器用于计算最小化损失的算法。这里SGD算法利用Loss函数来更新权重参数公式为:
 

 
其中theta是网络模型中的可训练参数(权重或偏差),lr是学习率,grad是相对于网络模型参数的损失。
 
梯度累积则是只计算神经网络模型,但是并不及时更新网络模型的参数,同时在计算的时候累积计算时候得到的梯度信息,最后统一使用累积的梯度来对参数进行更新。
 

 
在不更新模型变量的时候,实际上是把原来的数据Batch分成几个小的Mini-Batch,每个step中使用的样本实际上是更小的数据集。
 
在N个step内不更新变量,使所有Mini-Batch使用相同的模型变量来计算梯度,以确保计算出来得到相同的梯度和权重信息,算法上等价于使用原来没有切分的Batch size大小一样。即:
 

 
最终在上面步骤中累积梯度会产生与使用全局Batch size大小相同的梯度总和。
 
 
当然在实际工程当中,关于调参和算法上有两点需要注意的:
学习率 learning rate:一定条件下,Batch size越大训练效果越好,梯度累积则模拟了batch size增大的效果,如果accumulation steps为4,则Batch size增大了4倍,根据ZOMI的经验,使用梯度累积的时候需要把学习率适当放大。 归一化 Batch Norm:accumulation steps为4时进行Batch size模拟放大效果,和真实Batch size相比,数据的分布其实并不完全相同,4倍Batch size的BN计算出来的均值和方差与实际数据均值和方差不太相同,因此有些实现中会使用Group Norm来代替Batch Norm。

梯度累积实现

 
正常训练一个batch的伪代码:
 
for i, (images, labels) in enumerate(train_data):
# 1. forwared 前向计算
outputs = model(images)
loss = criterion(outputs, labels) # 2. backward 反向传播计算梯度
optimizer.zero_grad()
loss.backward()
optimizer.step()

  

  • model(images) 输入图像和标签,前向计算。
  • criterion(outputs, labels) 通过前向计算得到预测值,计算损失函数。
  • ptimizer.zero_grad() 清空历史的梯度信息。
  • loss.backward() 进行反向传播,计算当前batch的梯度。
  • optimizer.step() 根据反向传播得到的梯度,更新网络参数。
 
即在网络中输入一个batch的数据,就计算一次梯度,更新一次网络。
 
使用梯度累加后:
 
# 梯度累加参数
accumulation_steps = 4 for i, (images, labels) in enumerate(train_data):
# 1. forwared 前向计算
outputs = model(imgaes)
loss = criterion(outputs, labels) # 2.1 loss regularization loss正则化
loss += loss / accumulation_steps # 2.2 backward propagation 反向传播计算梯度
loss.backward() # 3. update parameters of net
if ((i+1) % accumulation)==0:
# optimizer the net
optimizer.step()
optimizer.zero_grad() # reset grdient

  

  • model(images) 输入图像和标签,前向计算。
  • criterion(outputs, labels) 通过前向计算得到预测值,计算损失函数。
  • loss / accumulation_steps loss每次更新,因此每次除以steps累积到原梯度上。
  • loss.backward() 进行反向传播,计算当前batch的梯度。
  • 多次循环伪代码步骤1-2,不清空梯度,使梯度累加在历史梯度上。
  • optimizer.step() 梯度累加一定次数后,根据所累积的梯度更新网络参数。
  • optimizer.zero_grad() 清空历史梯度,为下一次梯度累加做准备。
梯度累积就是,每次获取1个batch的数据,计算1次梯度,此时梯度不清空,不断累积,累积一定次数后,根据累积的梯度更新网络参数,然后清空所有梯度信息,进行下一次循环。
 

参考文献

AI系统——梯度累积算法的更多相关文章

  1. Silverlight 2.5D RPG游戏技巧与特效处理:(十一)AI系统

    Silverlight 2.5D RPG游戏技巧与特效处理:(十一)AI系统 作者: 深蓝色右手  来源: 博客园  发布时间: 2011-04-19 11:18  阅读: 1282 次  推荐: 0 ...

  2. AI技术原理|机器学习算法

    摘要 机器学习算法分类:监督学习.半监督学习.无监督学习.强化学习 基本的机器学习算法:线性回归.支持向量机(SVM).最近邻居(KNN).逻辑回归.决策树.k平均.随机森林.朴素贝叶斯.降维.梯度增 ...

  3. 广告系统中weak-and算法原理及编码验证

    wand(weak and)算法基本思路 一般搜索的query比较短,但如果query比较长,如是一段文本,需要搜索相似的文本,这时候一般就需要wand算法,该算法在广告系统中有比较成熟的应 该,主要 ...

  4. 腾讯公司数据分析岗位的hadoop工作 线性回归 k-means算法 朴素贝叶斯算法 SpringMVC组件 某公司的广告投放系统 KNN算法 社交网络模型 SpringMVC注解方式

    腾讯公司数据分析岗位的hadoop工作 线性回归 k-means算法 朴素贝叶斯算法 SpringMVC组件 某公司的广告投放系统 KNN算法 社交网络模型 SpringMVC注解方式 某移动公司实时 ...

  5. 基于R语言的梯度推进算法介绍

    通常来说,我们可以从两个方面来提高一个预测模型的准确性:完善特征工程(feature engineering)或是直接使用Boosting算法.通过大量数据科学竞赛的试炼,我们可以发现人们更钟爱于Bo ...

  6. 科学家开发新AI系统,可读取大脑信息并表达复杂思想

    我们终于找到了一种方法,可以在核磁共振成像的信号中看到这种复杂的想法.美国卡内基梅隆大学的Marcel Just说,思维和大脑活动模式之间的对应关系告诉我们这些想法是如何构建的. 人工智能系统表明,大 ...

  7. 梯度优化算法Adam

    最近读一个代码发现用了一个梯度更新方法, 刚开始还以为是什么奇奇怪怪的梯度下降法, 最后分析一下是用一阶梯度及其二次幂做的梯度更新.网上搜了一下, 果然就是称为Adam的梯度更新算法, 全称是:自适应 ...

  8. [源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积

    [源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积 目录 [源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积 0x00 摘要 0x01 概述 1.1 前文回 ...

  9. ptorch常用代码梯度篇(梯度裁剪、梯度累积、冻结预训练层等)

    梯度裁剪(Gradient Clipping) 在训练比较深或者循环神经网络模型的过程中,我们有可能发生梯度爆炸的情况,这样会导致我们模型训练无法收敛. 我们可以采取一个简单的策略来避免梯度的爆炸,那 ...

随机推荐

  1. VSCode上发布第一篇博客

    在VSCode上发布到博客园的第一篇博客 前段时间在VSCode安装好插件WriteCnblog,多次检查writeCnblog configuration配置信息也是完全正确的,但是一直没能在VSC ...

  2. C++内存管理:简易内存池的实现

    什么是内存池? 在上一篇 C++内存管理:new / delete 和 cookie中谈到,频繁的调用 malloc 会影响运行效率以及产生额外的 cookie, 而内存池的思想是预先申请一大块内存, ...

  3. Java oop 笔记

    摘要网址:http://note.youdao.com/noteshare?id=bbdc0b970721e40d327db983a2f96371

  4. MySQL慢日志优化

    慢日志的性能问题 造成 I/O 和 CPU 资源消耗:慢日志通常会扫描大量非目的的数据,自然就会造成 I/O 和 CPU 的资源消耗,影响到其他业务的正常使用,有可能因为单个慢 SQL 就能拖慢整个数 ...

  5. JAVA使用netty建立websocket连接

    依赖 <!-- https://mvnrepository.com/artifact/commons-io/commons-io --> <dependency> <gr ...

  6. protoc格式生成java文件

    下载protoc.exe 地址:https://yvioo.lanzoui.com/i12opqs7q9g 下载好之后 ,把protoc文件和exe放在一个文件夹内 用记事本打开protoc,删掉包路 ...

  7. ubuntu web服务器配置

    1.安装Apachesudo apt-get install apache2 查看状态: service apache2 status/start/stop/restartWeb目录: /var/ww ...

  8. C++基础之虚析构函数原理

    结论 虚函数表指针 + 虚函数表 共同实现 演示 VS2017(32位) 基类有析虚构函数 一段代码演示 #include <iostream> #include <memory&g ...

  9. LeetCode每日一题打卡组队监督!刷题群!

    近 2000 人已经加入共同刷题啦! 群友每天都会在群里给大家讲解算法题 每周日「负雪明烛」组织直播讲题 我相信来看我博客的大部分人都是通过LeetCode刷题过来的.最近发现LeetCode中文网站 ...

  10. 【LeetCode】616. Add Bold Tag in String 解题报告(C++)

    作者: 负雪明烛 id: fuxuemingzhu 个人博客:http://fuxuemingzhu.cn/ 目录 题目描述 题目大意 解题方法 遍历 日期 题目地址:https://leetcode ...