在采用随机梯度下降算法训练神经网络时,使用 tf.train.ExponentialMovingAverage 滑动平均操作的意义在于提高模型在测试数据上的健壮性(robustness)

tensorflow 下的 tf.train.ExponentialMovingAverage 需要提供一个衰减率(decay)。该衰减率用于控制模型更新的速度。该衰减率用于控制模型更新的速度,ExponentialMovingAverage 对每一个(待更新训练学习的)变量(variable)都会维护一个影子变量(shadow variable)。影子变量的初始值就是这个变量的初始值,

shadow_variable=decay×shadow_variable+(1−decay)×variable" role="presentation">shadow_variable=decay×shadow_variable+(1−decay)×variableshadow_variable=decay×shadow_variable+(1−decay)×variable

由上述公式可知, decay" role="presentation">decaydecay 控制着模型更新的速度,越大越趋于稳定。实际运用中,decay" role="presentation">decaydecay 一般会设置为十分接近 1 的常数(0.99或0.999)。为了使得模型在训练的初始阶段更新得更快,ExponentialMovingAverage 还提供了 num_updates 参数来动态设置 decay 的大小:

decay=min{decay,1+num_updates10+num_updates}" role="presentation">decay=min{decay,1+num_updates10+num_updates}decay=min{decay,1+num_updates10+num_updates}
import tensorflow as tf

v1 =tf.Variable(dtype=tf.float32, initial_value=0.)
decay = .99
num_updates = tf.Variable(0, trainable=False)
ema = tf.train.ExponentialMovingAverage(decay=decay, num_updates=num_updates) update_var_list = [v1] # 定义更新变量列表
ema_apply = ema.apply(update_var_list) with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run([v1, ema.average(v1)]))
# [0.0, 0.0](此时 num_updates = 0 ⇒ decay = .1, ),shadow_variable = variable = 0. sess.run(tf.assign(v1, 5))
sess.run(ema_apply)
print(sess.run([v1, ema.average(v1)]))
# 此时,num_updates = 0 ⇒ decay =.1, v1 = 5;
# shadow_variable = 0.1 * 0 + 0.9 * 5 = 4.5 ⇒ variable
sess.run(tf.assign(num_updates, 10000))
sess.run(tf.assign(v1, 10))
sess.run(ema_apply)
print(sess.run([v1, ema.average(v1)]))
# decay = .99,
# shadow_variable = 0.99 * 4.5 + .01*10 ⇒ 4.555
sess.run(ema_apply)
print(sess.run([v1, ema.average(v1)]))
# decay = .99
# shadow_variable = .99*4.555 + .01*10 = 4.609

tensorflow 下的滑动平均模型 —— tf.train.ExponentialMovingAverage的更多相关文章

  1. Tensorflow滑动平均模型tf.train.ExponentialMovingAverage解析

    觉得有用的话,欢迎一起讨论相互学习~Follow Me 移动平均法相关知识 移动平均法又称滑动平均法.滑动平均模型法(Moving average,MA) 什么是移动平均法 移动平均法是用一组最近的实 ...

  2. tf.train.ExponentialMovingAverage

    这个函数可以参考吴恩达deeplearning.ai中的指数加权平均. 和指数加权平均不一样的是,tensorflow中提供的这个函数,能够让decay_rate随着step的变化而变化.(在训练初期 ...

  3. TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model

      TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model Checkmate is designed to be a simple drop-i ...

  4. TensorFlow 实战(二)—— tf.train(优化算法)

    Training | TensorFlow tf 下以大写字母开头的含义为名词的一般表示一个类(class) 1. 优化器(optimizer) 优化器的基类(Optimizer base class ...

  5. tensorflow:实战Google深度学习框架第四章02神经网络优化(学习率,避免过拟合,滑动平均模型)

    1.学习率的设置既不能太小,又不能太大,解决方法:使用指数衰减法 例如: 假设我们要最小化函数 y=x2y=x2, 选择初始点 x0=5x0=5  1. 学习率为1的时候,x在5和-5之间震荡. im ...

  6. (转)深入解析TensorFlow中滑动平均模型与代码实现

    本文链接:https://blog.csdn.net/m0_38106113/article/details/81542863 指数加权平均算法的原理 TensorFlow中的滑动平均模型使用的是滑动 ...

  7. day-18 滑动平均模型测试样例

    为了使训练模型在测试数据上有更好的效果,可以引入一种新的方法:滑动平均模型.通过维护一个影子变量,来代替最终训练参数,进行训练模型的验证. 在tensorflow中提供了ExponentialMovi ...

  8. deep_learning_Function_tf.train.ExponentialMovingAverage()滑动平均

    近来看batch normalization的代码时,遇到tf.train.ExponentialMovingAverage()函数,特此记录. tf.train.ExponentialMovingA ...

  9. TensorFlow函数(四)tf.trainable_variable() 和 tf.all_variable()

    tf.trainable_variable() 此函数返回的是需要训练的变量列表 tf.all_variable() 此函数返回的是所有变量列表 v = tf.Variable(tf.constant ...

随机推荐

  1. Json应用案例

    Json应用案例之FastJson   这几天在网上找关于Json的一些案例,无意当中找到了一个我个人感觉比较好的就是阿里巴巴工程师写的FastJson. package com.jerehedu.f ...

  2. launcher- 第三方应用图标替换

    有时候我们感觉第三方应用的icon不美观,或者跟我们主题风格不一致,这时候我们希望换成我们想要的icon,那我们可以这么做(以更换QQ应用icon为例): 1.首先我们当然要根据自己的需要做一张替换i ...

  3. spring的BeanWrapper类的原理和使用方法

    转自:http://blog.sina.com.cn/s/blog_79ae79b30100t4hh.html 如果动态设置一个对象属性,可以借助Java的Reflection机制完成: Class ...

  4. div+css制作表格

    html: <div class="table"> <h2 class="table-caption">花名册:</h2> ...

  5. Unity实现发送QQ邮件功能

    闲来无聊,用Unity简单实现了一个发送邮件的功能,希望与大家互相交流互相进步,大神勿喷,测试的是QQ邮件用到的是MailMessage类和SmtpClient类首先如果发送方使用的是个人QQ邮箱账号 ...

  6. CISP/CISA 每日一题 二

    CISA 观察和测试用户操作程序 1.职责分离:确保没人具有执行多于一个下列处理过程的能力:启动.授权.验证或分发 2.输入授权:可以通过在输入文件上的书面授权或唯一口令的使用来获得证据 3.平衡:验 ...

  7. Java中的线程模型及实现方式

    概念: 线程是一个程序内部的顺序控制流 线程和进程的比较: 每个进程都有独立的代码和数据空间(进程上下文),进程切换的开销大. 线程:轻量的进程,同一类线程共享代码和数据空间,每个线程有独立的运行栈和 ...

  8. Android Service com.android.exchange.ExchangeService has leaked ServiceConnection

    启动Android项目的时候,clean  Project的时候,报错: android.app.ServiceConnectionLeaked: Service com.android.exchan ...

  9. c# 调用 C++ dll 传入传出 字符串

    c# 调用 C++ dll 传入传出 字符串 2013-07-02 09:30 7898人阅读 评论(2) 收藏 举报 本文章已收录于:   分类: windows 版权声明:随便转载,随便使用. C ...

  10. [Android 性能优化系列]内存之提升篇--应用应该怎样管理内存

    大家假设喜欢我的博客,请关注一下我的微博,请点击这里(http://weibo.com/kifile),谢谢 转载请标明出处(http://blog.csdn.net/kifile),再次感谢 原文地 ...