一般在保存模型参数的时候,都会保存一份moving average,是取了不同迭代次数模型的移动平均,移动平均后的模型往往在性能上会比最后一次迭代保存的模型要好一些。

tensorflow-models项目中tutorials下cifar中相关的代码写的有点问题,在这写下我自己的做法:

1.构建训练模型时,添加如下代码

 variable_averages = tf.train.ExponentialMovingAverage(0.999, global_step)
variables_averages_op = variable_averages.apply(tf.trainable_variables())
ave_vars = [variable_averages.average(var) for var in tf.trainable_variables()]
train_op = tf.group(train_op, variables_averages_op)

第1行创建了一个指数移动平均类 variable_averages

第2行将variable_averages作用于当前模型中所有可训练的变量上,得到 variables_averages_op操作符

第3行获得所有可训练变量对应的移动平均变量列表集合,后续用于保存模型

第4行在原有的训练操作符基础上,再添加variables_averages_op操作符,后续session执行run的时候,除了训练时前向后向,梯度更新,还会对相应的变量做移动平均

2.开始训练前,创建saver时,使用如下代码

 save_vars = tf.trainable_variables() + ave_vars
saver = tf.train.Saver(var_list=save_vars, max_to_keep=5)

第1行获取所有需要保存的变量列表,这个时候 ave_vars就派上用场了。

第2行创建saver,指定var_list为所有可训练变量及其对应的移动平均变量。

另外需要注意的是,如果你的模型中有bn或者类似层,包含有统计参数(均值、方差等),这些不属于可训练参数,还需要额外添加进save_vars中,可以参考我的这篇博客

3.在做inference的时候,利用如下代码从checkpoint中恢复出移动平均模型

 variable_averages = tf.train.ExponentialMovingAverage(0.999)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
saver.restore(sess, model_path)

这几行很简单,就不做解释了。

实际上,在inference的时候,刚刚的做法除了可以从checkpoint文件中恢复出移动平均参数,还可以恢复出对应迭代的模型参数,可以用来对比两种方式,哪种效果更好,这时只需要将上面代码的第3行改为saver = tf.train.Saver(tf.trainable_variables())即可(和保存时相同,如果有bn,也需要额外考虑)。在我的测试中,使用移动平均参数效果更佳。

tensorflow中moving average的用法的更多相关文章

  1. tensorflow中batch normalization的用法

    网上找了下tensorflow中使用batch normalization的博客,发现写的都不是很好,在此总结下: 1.原理 公式如下: y=γ(x-μ)/σ+β 其中x是输入,y是输出,μ是均值,σ ...

  2. [LeetCode] Moving Average from Data Stream 从数据流中移动平均值

    Given a stream of integers and a window size, calculate the moving average of all integers in the sl ...

  3. [Swift]LeetCode346. 从数据流中移动平均值 $ Moving Average from Data Stream

    Given a stream of integers and a window size, calculate the moving average of all integers in the sl ...

  4. [转载]Tensorflow中reduction_indices 的用法

    Tensorflow中reduction_indices 的用法 默认时None 压缩成一维

  5. LeetCode 346. Moving Average from Data Stream (数据流动中的移动平均值)$

    Given a stream of integers and a window size, calculate the moving average of all integers in the sl ...

  6. [LeetCode] 346. Moving Average from Data Stream 从数据流中移动平均值

    Given a stream of integers and a window size, calculate the moving average of all integers in the sl ...

  7. TensorFlow中的L2正则化函数:tf.nn.l2_loss()与tf.contrib.layers.l2_regularizerd()的用法与异同

    tf.nn.l2_loss()与tf.contrib.layers.l2_regularizerd()都是TensorFlow中的L2正则化函数,tf.contrib.layers.l2_regula ...

  8. 第十八节,TensorFlow中使用批量归一化(BN)

    在深度学习章节里,已经介绍了批量归一化的概念,详情请点击这里:第九节,改善深层神经网络:超参数调试.正则化以优化(下) 神经网络在进行训练时,主要是用来学习数据的分布规律,如果数据的训练部分和测试部分 ...

  9. 理解滑动平均(exponential moving average)

    1. 用滑动平均估计局部均值 滑动平均(exponential moving average),或者叫做指数加权平均(exponentially weighted moving average),可以 ...

随机推荐

  1. BZOJ4167 : 永远的竹笋采摘

    首先枚举出所有可能成为区间最小差值的点对$(j,i)$. 枚举每个位置作为右端点$i$,假设$a[j]>a[i]$. 找到第一个这样的$j$,那么可以将下一个$a[j]$的范围缩小到$(a[i] ...

  2. Linux硬盘管理

    管理好硬盘/dev/xxynsd SCSI SATA USBhd IDE主分区扩展分区 1-4逻辑分区5以后fdisk -l 硬盘名/分区名fdisk -l /dev/sda 如何给硬盘分区?把500 ...

  3. CSS学习之路,指定值,计算值,使用值。

    前面被问过这几个值得区别,没太研究,有点抠文字的感觉,既然到这儿了 ,就简答梳理下吧. 指定值(specified value):通过样式表样式规则定义的值:可以来自层叠样式表,如果没有指定,则考虑父 ...

  4. JSP(6)—JavaBean及案例

    基础: 一.JavaBean ①用作JavaBean的类必须是具有一个公共的无参数的构造方法 ②JavaBean的属性是以方法定义的形式出现的. ③JavaBean的属性名是根据Setter和gett ...

  5. MySQL匹配指定字符串的查询

    MySQL匹配指定字符串的查询 使用正则表达式查询时,正则表达式可以匹配字符串.当表中的记录包含这个字符串时,就可以将该记录查询出来.如果指定多个字符串时,需要用“|”符号隔开,只要匹配这些字符串中的 ...

  6. 微信公众号申请+新浪SAE申请

    一. 新浪SAE服务申请 1. 注冊地址:http://t.cn/RqMHPto 2. 选择控制台>>云应用SAE 3. 创建新应用 4. 填写域名 5. 代码管理选择SVN 6. 创建版 ...

  7. jsp中添加过滤器,实现校验用户身份

    我现在需要实现一个功能,就是用户登录前不允许访问系统,我使用的是jsp的过滤器来实现的. 先把filter过滤器的代码粘出来: package com.day8.filter; import java ...

  8. 利用StringEscapeUtils来转义和反转义html/xml/javascript中的特殊字符

    我们经常遇到html或者xml在Java程序中被某些库转义成了特殊字符. 例如: 各种逻辑运算符: > >= < <= == 被转义成了 &#x3D;&#x3D ...

  9. servlet的xx方式传值中文乱码

    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOEx ...

  10. Android 实现登录界面和功能实例

    近期一个android小程序须要登录功能,我简单实现了一下.如今记录下来也当做个笔记,同一时候也希望能够相互学习.所以,假设我的代码有问题,还各位请提出来.多谢了! 以下.就简述一下此实例的主要内容: ...