tensorflow提供的tf.train.ExponentialMovingAverage 类利用指数衰减维持变量的滑动平均。

当训练模型的时候,保持训练参数的滑动平均是非常有益的。评估时使用取平均后的参数有时会产生比使用最终训练好的参数值好很多的效果。方法apply()会添加被训练变量的影子副本和在影子副本中维持被训练变量的滑动平均的若干操作。该方法在创建训练模型时使用。那些保持维持滑动平均的操作(ops)一般会在每个训练步骤之后被执行。average()和average_name()方法分别提供了对影子变量和影子变量名字访问的途径。它们在建立评估模型或者从checkpoint文件恢复模型时能够用到,主要是帮助使用滑动平均代替最终训练结果进行评估。

滑动平均计算时使用指数衰减。当创建ExponentialMovingAverage对象时,衰减率应该被指定。影子变量和被训练参数的初始值相同。当执行更新滑动平均的操作时,每个影子变量会按照下面的公式进行更新:

shadow_variable -= (1 - decay) * (shadow_variable - variable)

上面的公式与下面的公式相同:

shadow_variable = decay * shadow_variable + (1 - decay) * variable

decay决定了模型更新的速度,越大越趋于稳定。decay的合理取值接近1.0,所以 decay的取值一般包含多个9,如0.999、0.9999等。

创建训练模型时的用法示例:

# Create variables.
var0 = tf.Variable(...)
var1 = tf.Variable(...)
# ... use the variables to build a training model...
...
# Create an op that applies the optimizer. This is what we usually
# would use as a training op.
opt_op = opt.minimize(my_loss, [var0, var1]) # Create an ExponentialMovingAverage object
ema = tf.train.ExponentialMovingAverage(decay=0.9999) with tf.control_dependencies([opt_op]):
# Create the shadow variables, and add ops to maintain moving averages
# of var0 and var1. This also creates an op that will update the moving
# averages after each training step. This is what we will use in place
# of the usual training op.
training_op = ema.apply([var0, var1]) ...train the model by running training_op...

有两种使用滑动平均进行评估的方法:

  • 建立一个使用影子变量(shadow variables)而非变量(variables)的模型。为此,需要使用返回给定变量的影子变量的average()方法
  • 创建一个正常的模型,但是使用影子变量名加载checkpoint文件进行评估。为此,需要使用average_name()方法

恢复影子变量值的示例:

# Create a Saver that loads variables from their saved shadow values.
shadow_var0_name = ema.average_name(var0)
shadow_var1_name = ema.average_name(var1)
saver = tf.train.Saver({shadow_var0_name: var0, shadow_var1_name: var1})
saver.restore(...checkpoint filename...)
# var0 and var1 now hold the moving average values

部分方法:

__init__(decay,
num_updates=None,
zero_debias=False,
name='ExponentialMovingAverage')
# 创建一个ExponentialMovingAverage对象
# 为了创建影子变量和添加维持滑动平均的操作,apply()方法必须被调用
        # 可选参数num_updates允许对衰减率进行动态微调。典型的方式是通过记录训练次数,在每次训练开始之前降低衰减率。这样做可以使模型在训练的初始阶段更新
        # 得更快
        # zero_debias: 如果为True,Tensor objects会被初始化为无偏滑动平均
        # 衰减率更新公式为:
actual_decay = min(decay, (1 + num_updates) / (10 + num_updates))
        可选参数name是被添加到apply()方法中的操作名称的前缀。
apply(var_list=None)
# 维持变量的滑动平均,即对shadow variables进行计算
# var_list必须是Variable或者Tensor objects构成的列表。该方法会为列表中的所有元素创建影子变量,且变量对象的影子变量初始值和变量相同。影子变量
           也会被添加到GraphKeys.MOVING_AVERAGE_VARIABLES集合中。对于Tensor objects,影子变量会被初始化为0,同时被设置为无偏。
# 影子变量被设置trainable=False,并且被添加到GraphKeys.MOVING_AVERAGE_VARIABLES集合中,它们会在调用tf.global_variables()时被返回。
# 该方法返回一个按照要求更新所有影子变量的操作。同时需要注意的是,apply()可以在不同的var_list下被多次调用。
average(var)
# 返回变量的影子变量值,即读取影子变量shadow variables

average_name(var)
# 返回变量的影子变量名,即读取影子变量名
# 在模型训练期间计算变量的滑动平均和在评估时从计算得到的滑动平均恢复变量是ExponentialMovingAverage的典型应用。
# 为了恢复变量,必须知道影子变量名。然后影子变量名和对应的变量被传给Saver()对象来从计算得到滑动平均值恢复变量。
# Saver=tf.train.Saver({ema.average_name(var):var})
# 不管apply()方法有没有被调用,average_name()都可以被调用
variables_to_restore(moving_avg_variables=None)
# 返回要恢复的变量的名称映射
# moving_avg_variables : 需要使用滑动平均名进行恢复的变量构成的list;如果为None,会默认为variables.moving_average_variables() + va
                                   riables.trainable_variables()
# 如果变量有滑动平均,那么使用滑动平均变量名作为恢复时使用的名称;否则,使用变量名。
# 例如:
# variables_to_restore = ema.variables_to_restore()
# saver = tf.train.Saver(variables_to_restore)
# 以下是返回的一个映射的示例:
# conv/batchnorm/gamma/ExponentialMovingAverage: conv/batchnorm/gamma,
# conv_4/conv2d_params/ExponentialMovingAverage: conv_4/conv2d_params,
# global_step: global_step

示例:参考链接

import os
import tensorflow as tf os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # 创建待训练参数
variable1 = tf.Variable(initial_value=0, trainable=True, dtype=tf.float32)
# 训练次数,不可训练
step_var = tf.Variable(initial_value=0, trainable=False)
# 创建滑动平均对象
ema = tf.train.ExponentialMovingAverage(decay=0.999, num_updates=step_var)
# 计算变量variable1的滑动平均操作
maintain_average_op = ema.apply([variable1]) # 初始化操作
init_op = tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init_op)
# 初始值输出
# 更新影子变量
sess.run(maintain_average_op)
# 输出变量和变量的影子变量
print(sess.run([variable1, ema.average(variable1)])) # 更新变量
sess.run(tf.assign(variable1, 5))
# 更新影子变量
# decay = min(decay, (1+step_var) / (10+step_var))
# shadow_variable = decay * shadow_variable + (1 - decay) * variable
sess.run(maintain_average_op)
# 输出变量和变量的影子变量
print(sess.run([variable1, ema.average(variable1)])) # 更新step_var
sess.run(tf.assign(step_var, 10000))
# 更新变量
sess.run(tf.assign(variable1, 10))
# 更新影子变量
sess.run(maintain_average_op)
# 输出变量和变量的影子变量
print(sess.run([variable1, ema.average(variable1)])) # 更新影子变量 # 更新影子变量
sess.run(maintain_average_op)
# 输出变量和变量的影子变量
print(sess.run([variable1, ema.average(variable1)]))

输出如下:

[0.0, 0.0]
[5.0, 4.5]
[10.0, 4.5054998]
[10.0, 4.5109944]

tensorflow入门笔记(二) 滑动平均模型的更多相关文章

  1. tensorflow笔记之滑动平均模型

    tensorflow使用tf.train.ExponentialMovingAverage实现滑动平均模型,在使用随机梯度下降方法训练神经网络时候,使用这个模型可以增强模型的鲁棒性(robust),可 ...

  2. Tensorflow滑动平均模型tf.train.ExponentialMovingAverage解析

    觉得有用的话,欢迎一起讨论相互学习~Follow Me 移动平均法相关知识 移动平均法又称滑动平均法.滑动平均模型法(Moving average,MA) 什么是移动平均法 移动平均法是用一组最近的实 ...

  3. tensorflow随机梯度下降算法使用滑动平均模型

    在采用随机梯度下降算法训练神经网络时,使用滑动平均模型可以提高最终模型在测试集数据上的表现.在Tensflow中提供了tf.train.ExponentialMovingAverage来实现滑动平均模 ...

  4. Tensorflow中的滑动平均模型

    原文链接 在Tensorflow的教程里面,使用梯度下降算法训练神经网络时,都会提到一个使模型更加健壮的策略,即滑动平均模型. 基本思想 在使用梯度下降算法训练模型时,每次更新权重时,为每个权重维护一 ...

  5. tensorflow学习笔记二:入门基础 好教程 可用

    http://www.cnblogs.com/denny402/p/5852083.html tensorflow学习笔记二:入门基础   TensorFlow用张量这种数据结构来表示所有的数据.用一 ...

  6. 1 TensorFlow入门笔记之基础架构

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  7. 78、tensorflow滑动平均模型,用来更新迭代的衰减系数

    ''' Created on 2017年4月21日 @author: weizhen ''' #4.滑动平均模型 import tensorflow as tf #定义一个变量用于计算滑动平均,这个变 ...

  8. tensorflow入门笔记(三) tf.GraphKeys

    tf.GraphKeys类存放了图集用到的标准名称. 该标准库使用各种已知的名称收集和检索图中相关的值.例如,tf.Optimizer子类在没有明确指定待优化变量的情况下默认优化被收集到tf.Grap ...

  9. 吴裕雄 PYTHON 神经网络——TENSORFLOW 滑动平均模型

    import tensorflow as tf v1 = tf.Variable(0, dtype=tf.float32) step = tf.Variable(0, trainable=False) ...

随机推荐

  1. java框架篇---hibernate(一对多)映射关系

    一对多关系可以分为单向和双向. 一对多关系单向 单向就是只能从一方找到另一方,通常是从主控类找到拥有外键的类(表).比如一个母亲可以有多个孩子,并且孩子有母亲的主键作为外键.母亲与孩子的关系就是一对多 ...

  2. 【iCore4 双核心板_FPGA】例程一:GPIO输出实验——点亮LED

    实验现象: 三色LED循环点亮. 核心源代码: module led_ctrl( input clk_25m, input rst_n, output fpga_ledr, output fpga_l ...

  3. 在win10企业版x64下使用curl命令

    一.curl命令介绍 curl是利用URL语法在命令行方式下工作的开源文件传输工具.它被广泛应用在Unix.多种Linux发行版中,并且有DOS和Win32.Win64下的移植版本. 详情查看百度百科 ...

  4. Mysql 地区经纬度 查询

    摘要: Mysql数据库,根据地区的经纬度信息,查询附近相邻的地区 2015-03-27 修改,增加 MySQL的空间扩展(MySQL Spatial Extensions)的解决方案: MySQL的 ...

  5. C语言程序设计--文件操作

    前言 这里尝试与Python对别的方法来学习C语言的文件操作,毕竟我是Pythoner. 文件打开与关闭 Python #因为是和C语言比对,所以不使用with filename = "/e ...

  6. WinForm资源管理器开发(TreeView&ListView)

    在C# WinForm开发当中,有三大View控件值得深入应用,分别为DataGridView.ListView.TreeView.如果这三大控件能够熟练的应用,其它的控件也就基本没有问题.所以这篇博 ...

  7. cordova 企业应用打包Archive的时候报 "#import <Cordova file not found"

    可能原因是Cordova的路径问题: For xcode7 add "$(OBJROOT)/UninstalledProducts/$(PLATFORM_NAME)/include" ...

  8. YSQL获取自增ID的四种方法(转发)

    YSQL获取自增ID的四种方法(转发) 1. select max(id) from tablename 2.SELECT LAST_INSERT_ID() 函数 LAST_INSERT_ID 是与t ...

  9. I - 迷宫问题

    定义一个二维数组: int maze[5][5] = { 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, ...

  10. centos7 安装redis服务及phpredis扩展

    闲话少说 服务器版本:centos7.6 64位 软件包:https://pan.baidu.com/s/1Gb4iz5mqLqNVWvvZdBiOMQ 提取码: xrhx 一.安装redis 放在/ ...