在采用随机梯度下降算法训练神经网络时,使用滑动平均模型可以提高最终模型在测试集数据上的表现。在Tensflow中提供了tf.train.ExponentialMovingAverage来实现滑动平均模型。在初始化ExponentialMovingAverage时,需要提供一个衰减率(decay)。这个衰减率将用于控制模型更新的速度。ExponentialMovingAverage对每一个变量会维护一个影子变量(shadowvariable),这个影子变量的初始值就是相应变量的初始值,而每次运行变量更新时,影子变量的值会更新为:

shadow_variable=decay x shadow_variable+(1-decay) x variable

其中shadow_variable 为影子变量,variable为待更新的变量,decay为衰减率。decay决定了模型更新的速度,decay越大模型越趋于稳定。在实际应用中,decay一般会设成非常接近1的数(比如0.999或0.9999)。为了使得模型在训练前期可以更新得更快,ExponentialMovingAverage还提供了num_updates参数来动态设置decay的大小.

下面是ExponentailMovingAverage使用示例

  1. # -*- coding:UTF- -*-
  2. import tensorflow as tf
  3. # 定义一个初始为0的变量来计算滑动平均
  4.  
  5. v1=tf.Variable(,dtype=tf.float32)
  6.  
  7. #这里的step变量模拟神经网络中迭代的轮数,可以用于动态控制衰减率
  8. step=tf.Variable(,trainable=False)
  9.  
  10. #定义一个滑动平均的类,初始化时给定了衰减率(0.99)和控制衰减率的变量step
  11. ema=tf.train.ExponentialMovingAverage(0.99,step)
  12.  
  13. # 定义一个更新变量滑动平均的操作,这里给定一个列表,每次执行这个操作时,这个列表中的变量的值都会更新
  14.  
  15. maintain_averages_op=ema.apply([v1])
  16. with tf.Session() as sess:
  17. # 初始化所有变量
  18. init_op=tf.global_variables_initializer()
  19. sess.run(init_op)
  20.  
  21. # 通过ema.average(v1)获取滑动平均之后变量的取值。在初始化之后变量v1的值和v1的滑动平均都为0
  22.  
  23. print sess.run([v1,ema.average(v1)])
  24. # 更新变量v1的值到5
  25. sess.run(tf.assign(v1,))
  26. # 更新v1的滑动平均值,衰减率为min{0.99,(+step)/(+step)=0.1}=0.1
  27. # 所以v1的滑动平均会被更新为0.*+0.9*=4.5
  28.  
  29. sess.run(maintain_averages_op)
  30. print sess.run([v1,ema.average(v1)])
  31.  
  32. # 更新 step的值为10000
  33. sess.run(tf.assign(step,))
  34. # 更新 v1的值为10。
  35. sess.run(tf.assign(v1,))
  36. # 更新v1 的滑动平均值。衰减率为min(0.99,(+step)/(+step)≈0.999}=0.99
  37. # 所以v1的滑动平均会被更新为0.*4.5+0.01*=4.555
  38.  
  39. sess.run(maintain_averages_op)
  40. print sess.run([v1,ema.average(v1)])
  41.  
  42. #再次更新滑动平均值,得到的新滑动平均值为0.*4.555+0.01*=4.60945
  43.  
  44. sess.run(maintain_averages_op)
  45. print sess.run([v1,ema.average(v1)])

结果如下

[0.0, 0.0]
[5.0, 4.5]
[10.0, 4.555]
[10.0, 4.60945]

tensorflow随机梯度下降算法使用滑动平均模型的更多相关文章

  1. Tensorflow中的滑动平均模型

    原文链接 在Tensorflow的教程里面,使用梯度下降算法训练神经网络时,都会提到一个使模型更加健壮的策略,即滑动平均模型. 基本思想 在使用梯度下降算法训练模型时,每次更新权重时,为每个权重维护一 ...

  2. tensorflow笔记之滑动平均模型

    tensorflow使用tf.train.ExponentialMovingAverage实现滑动平均模型,在使用随机梯度下降方法训练神经网络时候,使用这个模型可以增强模型的鲁棒性(robust),可 ...

  3. Tensorflow滑动平均模型tf.train.ExponentialMovingAverage解析

    觉得有用的话,欢迎一起讨论相互学习~Follow Me 移动平均法相关知识 移动平均法又称滑动平均法.滑动平均模型法(Moving average,MA) 什么是移动平均法 移动平均法是用一组最近的实 ...

  4. 监督学习:随机梯度下降算法(sgd)和批梯度下降算法(bgd)

    线性回归 首先要明白什么是回归.回归的目的是通过几个已知数据来预测另一个数值型数据的目标值. 假设特征和结果满足线性关系,即满足一个计算公式h(x),这个公式的自变量就是已知的数据x,函数值h(x)就 ...

  5. 监督学习——随机梯度下降算法(sgd)和批梯度下降算法(bgd)

    线性回归 首先要明白什么是回归.回归的目的是通过几个已知数据来预测另一个数值型数据的目标值. 假设特征和结果满足线性关系,即满足一个计算公式h(x),这个公式的自变量就是已知的数据x,函数值h(x)就 ...

  6. tensorflow入门笔记(二) 滑动平均模型

    tensorflow提供的tf.train.ExponentialMovingAverage 类利用指数衰减维持变量的滑动平均. 当训练模型的时候,保持训练参数的滑动平均是非常有益的.评估时使用取平均 ...

  7. 78、tensorflow滑动平均模型,用来更新迭代的衰减系数

    ''' Created on 2017年4月21日 @author: weizhen ''' #4.滑动平均模型 import tensorflow as tf #定义一个变量用于计算滑动平均,这个变 ...

  8. 吴裕雄 PYTHON 神经网络——TENSORFLOW 滑动平均模型

    import tensorflow as tf v1 = tf.Variable(0, dtype=tf.float32) step = tf.Variable(0, trainable=False) ...

  9. 随机梯度下降算法求解SVM

    测试代码(matlab)如下: clear; load E:\dataset\USPS\USPS.mat; % data format: % Xtr n1*dim % Xte n2*dim % Ytr ...

随机推荐

  1. [HihoCoder1398]网络流五·最大权闭合子图

    题目大意:有$N$项活动$M$个人,每个活动$act_i$有一个正的权值$a_i$,每个人$stu_i$有一个负的权值$b_i$.每项活动能够被完成当且仅当该项活动所需的所有人到场.如何选择活动使最终 ...

  2. 虚拟机性能监控与故障处理工具------JDK的命令行工具

    ①jps:虚拟机进程状况工具 功能:列出正在运行的虚拟机进程,并显示1.虚拟机执行主类名称以及2.这些进程的本地虚拟机唯一ID(LVMID). 使用频率最高的JDK命令行工具,其他的JDK工具大多需要 ...

  3. BZOJ3861 : Tree

    把集合看成左边的点,图中的点看成右边的点,若集合$i$不包含$j$,则连边$i->j$,得到一个二分图,等价于求这个二分图的完备匹配个数. 设$f[i][j]$表示考虑了前$i$个集合,匹配了$ ...

  4. IE6条件下的bug与常见的bug及其解决方法

    1.IE6条件下有双倍的margin 解决办法:给这个浮动元素增加display:inline属性 2. 图片底部有3像素问题 解决办法:display:block;或者vertical-align: ...

  5. 推荐两款好用的反编译工具(Luyten,Jadx)

    使用JD-Gui打开单个.class文件,总是报错// INTERNAL ERROR 但当我用jd-gui反编译前面操作获得的jar文件的时,但有一部分类不能显示出来--constants类,仅仅显示 ...

  6. C++.Linux下redis编程:error while loading shared libraries: libhiredis.so.0.13

    编译 sudo gcc -o sltest01 sltest01.c -L/usr/local/lib/ -lhiredis 运行 sudo ./sltest01 编译成功后运行报错信息: ./slt ...

  7. Linux和类Unix系统上5个最佳开源备份工具

    一个好的备份最基本的目的就是为了能够从一些错误中恢复: 人为的失误 磁盘阵列或是硬盘故障 文件系统崩溃 数据中心被破坏等等. 所以,我为大家罗列了一些开源的软件备份工具. 当为一个企业选择备份工具的时 ...

  8. android: 在android studio中使用retrolambda的步骤

    找了各种说明,包括retrolambda官方文档都没有试成功 最后在这个链接中找到答案:http://blog.csdn.net/qq_26819733/article/details/5222565 ...

  9. WIN10平板如何录制视频,为什么录制屏幕无法播放

    你的平板分辨率太高(系统推荐2736X1824),实际上一半就够了(1368X912),因为大部分传统显示器分辨率只有1280X720这种.把分辨率调低还有很多的好处,因为很多软件在分辨率太高的情况下 ...

  10. 基础知识:什么是ASP.NET Razor页面?

    Razor页面与ASP.NET MVC开发使用的视图组件非常相似,它们具有所有相同的语法和功能. 最关键的区别是模型和控制器代码也包含在Razor页面中.它更像是一个MVVM(Model-View-V ...