滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量。

1、滑动平均求解对象初始化

ema = tf.train.ExponentialMovingAverage(decay,num_updates)

参数decay

`shadow_variable = decay * shadow_variable + (1 - decay) * variable`

参数num_updates

`min(decay, (1 + num_updates) / (10 + num_updates))`

2、添加/更新变量

添加目标变量,为之维护影子变量

注意维护不是自动的,需要每轮训练中运行此句,所以一般都会使用tf.control_dependencies使之和train_op绑定,以至于每次train_op都会更新影子变量

ema.apply([var0, var1])

3、获取影子变量值

这一步不需要定义图中,从影子变量集合中提取目标值

sess.run(ema.average([var0, var1]))

4、保存&载入影子变量

我们知道,在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身。

保存影子变量

建立tf.train.ExponentialMovingAverage对象后,Saver正常保存就会存入影子变量,命名规则是"v/ExponentialMovingAverage"对应变量”v“

import tensorflow as tf  

if __name__ == "__main__":
v = tf.Variable(0.,name="v")
#设置滑动平均模型的系数
ema = tf.train.ExponentialMovingAverage(0.99)
#设置变量v使用滑动平均模型,tf.all_variables()设置所有变量
op = ema.apply([v])
#获取变量v的名字
print(v.name)
#v:0
#创建一个保存模型的对象
save = tf.train.Saver()
sess = tf.Session()
#初始化所有变量
init = tf.initialize_all_variables()
sess.run(init)
#给变量v重新赋值
sess.run(tf.assign(v,10))
#应用平均滑动设置
sess.run(op)
#保存模型文件
save.save(sess,"./model.ckpt")
#输出变量v之前的值和使用滑动平均模型之后的值
print(sess.run([v,ema.average(v)]))
#[10.0, 0.099999905]

载入影子变量并映射到变量

利用了Saver载入模型的变量名映射功能,实际上对所有的变量都可以如此操作『TensorFlow』模型载入方法汇总

v = tf.Variable(1.,name="v")
#定义模型对象
saver = tf.train.Saver({"v/ExponentialMovingAverage":v})
sess = tf.Session()
saver.restore(sess,"./model.ckpt")
print(sess.run(v))
#0.0999999

这里特别需要注意的一个地方就是,在使用tf.train.Saver函数中,所传递的模型参数是{"v/ExponentialMovingAverage":v}而不是{"v":v},如果你使用的是后面的参数,那么你得到的结果将是10而不是0.09,那是因为后者获取的是变量本身而不是影子变量

使用这种方式来读取模型文件的时候,还需要输入一大串的变量名称。

variables_to_restore函数的使用

v = tf.Variable(1.,name="v")
#滑动模型的参数的大小并不会影响v的值
ema = tf.train.ExponentialMovingAverage(0.99)
print(ema.variables_to_restore())
#{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
sess = tf.Session()
saver = tf.train.Saver(ema.variables_to_restore())
saver.restore(sess,"./model.ckpt")
print(sess.run(v))
#0.0999999

variables_to_restore会识别网络中的变量,并自动生成影子变量名。

通过使用variables_to_restore函数,可以使在加载模型的时候将影子变量直接映射到变量的本身,所以我们在获取变量的滑动平均值的时候只需要获取到变量的本身值而不需要去获取影子变量。

5、官方文档例子

官方文档中将每次apply更新就会自动训练一边模型,实际上可以反过来两者关系,《tf实战google》P128中有示例

|  Example usage when creating a training model:
 |  
 |  ```python
 |  # Create variables.
 |  var0 = tf.Variable(...)
 |  var1 = tf.Variable(...)
 |  # ... use the variables to build a training model...
 |  ...
 |  # Create an op that applies the optimizer.  This is what we usually
 |  # would use as a training op.
 |  opt_op = opt.minimize(my_loss, [var0, var1])
 |  
 |  # Create an ExponentialMovingAverage object
 |  ema = tf.train.ExponentialMovingAverage(decay=0.9999)
 |  
 |  with tf.control_dependencies([opt_op]):
 |      # Create the shadow variables, and add ops to maintain moving averages
 |      # of var0 and var1. This also creates an op that will update the moving
 |      # averages after each training step.  This is what we will use in place
 |      # of the usual training op.
 |      training_op = ema.apply([var0, var1])
 |  
 |  ...train the model by running training_op...
 |  ```

6、batch_normal的例子

和上面不太一样的是,batch_normal中不太容易绑定到train_op(位于函数体外面),则强行将两个variable的输出过程化为节点,绑定给参数更新步骤

def batch_norm(x,beta,gamma,phase_train,scope='bn',decay=0.9,eps=1e-5):
with tf.variable_scope(scope):
# beta = tf.get_variable(name='beta', shape=[n_out], initializer=tf.constant_initializer(0.0), trainable=True)
# gamma = tf.get_variable(name='gamma', shape=[n_out],
# initializer=tf.random_normal_initializer(1.0, stddev), trainable=True)
batch_mean,batch_var = tf.nn.moments(x,[0,1,2],name='moments')
ema = tf.train.ExponentialMovingAverage(decay=decay) def mean_var_with_update():
ema_apply_op = ema.apply([batch_mean,batch_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean),tf.identity(batch_var)
# identity之后会把Variable转换为Tensor并入图中,
# 否则由于Variable是独立于Session的,不会被图控制control_dependencies限制 mean,var = tf.cond(phase_train,
mean_var_with_update,
lambda: (ema.average(batch_mean),ema.average(batch_var)))
   normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)
return normed

『TensorFlow』滑动平均的更多相关文章

  1. 『TensorFlow』专题汇总

    TensorFlow:官方文档 TensorFlow:项目地址 本篇列出文章对于全零新手不太合适,可以尝试TensorFlow入门系列博客,搭配其他资料进行学习. Keras使用tf.Session训 ...

  2. 『TensorFlow』模型保存和载入方法汇总

    『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...

  3. 『TensorFlow』SSD源码学习_其一:论文及开源项目文档介绍

    一.论文介绍 读论文系列:Object Detection ECCV2016 SSD 一句话概括:SSD就是关于类别的多尺度RPN网络 基本思路: 基础网络后接多层feature map 多层feat ...

  4. 『TensorFlow』流程控制

    『PyTorch』第六弹_最小二乘法对比PyTorch和TensorFlow TensorFlow 控制流程操作 TensorFlow 提供了几个操作和类,您可以使用它们来控制操作的执行并向图中添加条 ...

  5. 『TensorFlow』读书笔记_降噪自编码器

    『TensorFlow』降噪自编码器设计  之前学习过的代码,又敲了一遍,新的收获也还是有的,因为这次注释写的比较详尽,所以再次记录一下,具体的相关知识查阅之前写的文章即可(见上面链接). # Aut ...

  6. 『TensorFlow』梯度优化相关

    tf.trainable_variables可以得到整个模型中所有trainable=True的Variable,也是自由处理梯度的基础 基础梯度操作方法: tf.gradients 用来计算导数.该 ...

  7. 『TensorFlow』命令行参数解析

    argparse很强大,但是我们未必需要使用这么繁杂的东西,TensorFlow自己封装了一个简化版本的解析方式,实际上是对argparse的封装 脚本化调用tensorflow的标准范式: impo ...

  8. 『TensorFlow』TFR数据预处理探究以及框架搭建

    一.TFRecord文件书写效率对比(单线程和多线程对比) 1.准备工作 # Author : Hellcat # Time : 18-1-15 ''' import os os.environ[&q ...

  9. 『TensorFlow』第七弹_保存&载入会话_霸王回马

    首更: 由于TensorFlow的奇怪形式,所以载入保存的是sess,把会话中当前激活的变量保存下来,所以必须保证(其他网络也要求这个)保存网络和载入网络的结构一致,且变量名称必须一致,这是caffe ...

随机推荐

  1. mysql新建数据库、新建用户及授权操作

    1.创建数据库create database if not exists test176 default charset utf8 collate utf8_general_ci; #utf8_gen ...

  2. winform做的excel与数据库的导入导出

    闲来无事,就来做一个常用的demo,也方便以后查阅 先看效果图 中间遇到的主要问题是获取当前连接下的所有的数据库以及数据库下所有的表 在网上查了查,找到如下的方法 首先是要先建立一个连接 _connM ...

  3. jQuery 学习笔记(3)(内容选择器、attr方法、prop方法,类的操作)

    内容选择器: 1.$("div:empty"): 空的div元素 2.$("div:parent"): 非空div元素 3.$("div:contai ...

  4. java框架之SpringBoot(2)-配置

    规范 SpringBoot 使用一个全局的配置文件,配置文件名固定为 application.properties 或 application.yml .比如我们要配置程序启动使用的端口号,如下: s ...

  5. [js]设计模式小结&对原型的修改

    js设计模式小结 工厂模式/构造函数--减少重复 - 创建对象有new - 自动创建obj,this赋值 - 无return 原型链模式 - 进一步去重 类是函数数据类型,每个函数都有prototyp ...

  6. WIN7虚拟桌面创建(多屏幕多桌面)

    Windows7/WIN7虚拟桌面怎么用怎么创建多桌面(摘录) 在使用电脑中经常会遇到桌面软件太多了不够用的感慨,那么要是一台电脑有多个桌面就好了.在windows10中自带已经支持了虚拟桌面,在wi ...

  7. 《linux就该这么学》第八节课:第六章存储结构与磁盘划分

     笔记 (借鉴请修改) 6.3.文件系统与数据资料 目前linux最常见的文件系统: ext3:日志文件系统.宕机时可自动恢复数据资料,容量越大恢复时间越长,且不能保证百分百不丢失.   ext4:e ...

  8. Ngnix 配置文件

    配置文件路径/usr/local/nginx/conf/nginx.conf user www www; #nginx 服务的伪用户和用户组 worker_processes auto; #启动进程, ...

  9. poj1733(并查集+离散化)

    题目大意:有一个长度为n的0,1字符串, 给m条信息,每条信息表示第x到第y个字符中间1的个数为偶数个或奇数个, 若这些信息中第k+1是第一次与前面的话矛盾, 输出k; 思路:x, y之间1的个数为偶 ...

  10. python 匿名函数捕获变量值 (执行时的值)