tensorflow中moving average的用法
一般在保存模型参数的时候,都会保存一份moving average,是取了不同迭代次数模型的移动平均,移动平均后的模型往往在性能上会比最后一次迭代保存的模型要好一些。
tensorflow-models项目中tutorials下cifar中相关的代码写的有点问题,在这写下我自己的做法:
1.构建训练模型时,添加如下代码
- variable_averages = tf.train.ExponentialMovingAverage(0.999, global_step)
- variables_averages_op = variable_averages.apply(tf.trainable_variables())
- ave_vars = [variable_averages.average(var) for var in tf.trainable_variables()]
- train_op = tf.group(train_op, variables_averages_op)
第1行创建了一个指数移动平均类 variable_averages
第2行将variable_averages作用于当前模型中所有可训练的变量上,得到 variables_averages_op操作符
第3行获得所有可训练变量对应的移动平均变量列表集合,后续用于保存模型
第4行在原有的训练操作符基础上,再添加variables_averages_op操作符,后续session执行run的时候,除了训练时前向后向,梯度更新,还会对相应的变量做移动平均
2.开始训练前,创建saver时,使用如下代码
- save_vars = tf.trainable_variables() + ave_vars
saver = tf.train.Saver(var_list=save_vars, max_to_keep=5)
第1行获取所有需要保存的变量列表,这个时候 ave_vars就派上用场了。
第2行创建saver,指定var_list为所有可训练变量及其对应的移动平均变量。
另外需要注意的是,如果你的模型中有bn或者类似层,包含有统计参数(均值、方差等),这些不属于可训练参数,还需要额外添加进save_vars中,可以参考我的这篇博客
3.在做inference的时候,利用如下代码从checkpoint中恢复出移动平均模型
- variable_averages = tf.train.ExponentialMovingAverage(0.999)
- variables_to_restore = variable_averages.variables_to_restore()
- saver = tf.train.Saver(variables_to_restore)
- saver.restore(sess, model_path)
这几行很简单,就不做解释了。
实际上,在inference的时候,刚刚的做法除了可以从checkpoint文件中恢复出移动平均参数,还可以恢复出对应迭代的模型参数,可以用来对比两种方式,哪种效果更好,这时只需要将上面代码的第3行改为saver = tf.train.Saver(tf.trainable_variables())即可(和保存时相同,如果有bn,也需要额外考虑)。在我的测试中,使用移动平均参数效果更佳。
tensorflow中moving average的用法的更多相关文章
- tensorflow中batch normalization的用法
网上找了下tensorflow中使用batch normalization的博客,发现写的都不是很好,在此总结下: 1.原理 公式如下: y=γ(x-μ)/σ+β 其中x是输入,y是输出,μ是均值,σ ...
- [LeetCode] Moving Average from Data Stream 从数据流中移动平均值
Given a stream of integers and a window size, calculate the moving average of all integers in the sl ...
- [Swift]LeetCode346. 从数据流中移动平均值 $ Moving Average from Data Stream
Given a stream of integers and a window size, calculate the moving average of all integers in the sl ...
- [转载]Tensorflow中reduction_indices 的用法
Tensorflow中reduction_indices 的用法 默认时None 压缩成一维
- LeetCode 346. Moving Average from Data Stream (数据流动中的移动平均值)$
Given a stream of integers and a window size, calculate the moving average of all integers in the sl ...
- [LeetCode] 346. Moving Average from Data Stream 从数据流中移动平均值
Given a stream of integers and a window size, calculate the moving average of all integers in the sl ...
- TensorFlow中的L2正则化函数:tf.nn.l2_loss()与tf.contrib.layers.l2_regularizerd()的用法与异同
tf.nn.l2_loss()与tf.contrib.layers.l2_regularizerd()都是TensorFlow中的L2正则化函数,tf.contrib.layers.l2_regula ...
- 第十八节,TensorFlow中使用批量归一化(BN)
在深度学习章节里,已经介绍了批量归一化的概念,详情请点击这里:第九节,改善深层神经网络:超参数调试.正则化以优化(下) 神经网络在进行训练时,主要是用来学习数据的分布规律,如果数据的训练部分和测试部分 ...
- 理解滑动平均(exponential moving average)
1. 用滑动平均估计局部均值 滑动平均(exponential moving average),或者叫做指数加权平均(exponentially weighted moving average),可以 ...
随机推荐
- 受欢迎的牛 [HAOI2006] [强连通] [传递闭包(划)]
Description 每一头牛的愿望就是变成一头最受欢迎的牛.现在有N头牛,给你M对整数(A,B),表示牛 A 认为牛 B受欢迎.这种关系是具有传递性的,如果A认为B受欢迎,B认为C受欢迎,那么牛A ...
- Flask蓝图
它的作用就是将 功能 与 主服务 分开怎么理解呢? 比如说,你有一个客户管理系统,最开始的时候,只有一个查看客户列表的功能,后来你又加入了一个添加客户的功能(add_user)模块, 然后又加入了一个 ...
- Java 构造器 考虑用静态构造方法代替构造器
类可以提供一个公有的静态工厂方法,它是一个返回类的实例的静态方法.静态工厂方法与设计模式中的工厂方法模式不同. 优势: 静态工厂方法与构造器不同的第一大优势在于,它们有名称.一个类只能有一个带有指定签 ...
- Gird Layout代码解释
<div class="wrapper"> <!--定义一个类名为wrapper的div盒子--> <div class="one" ...
- GMA Round 1 相交
传送门 相交 在实数范围内,设抛物线$C_1:y^2=2x$,双曲线:$C_2:\frac{y^2}{b^2}-\frac{x^2}{a^2}=1$(a,b为参数). 假如a和b都在(0,16)这个区 ...
- Ubunut操作系统下nDPI的部署及简单使用
[系统:Ubuntu16.04LTS ] [ nDPI版本:2.5.0] [ 内核:4.15.0-39-generic] 前期准备工作--依赖安装 所需依赖包(前两个ubuntu16已有不需安装) g ...
- 前端工程化系列[03]-Grunt构建工具的运转机制
在前端工程化系列[02]-Grunt构建工具的基本使用这篇文章中,已经对Grunt做了简单的介绍,此外,我们还知道了该如何来安装Grunt环境,以及使用一些常见的插件了,这篇文章主要介绍Grunt的核 ...
- 基于SOUI开发一个简单的小工具
基于DriectUI有很多库,比如 Duilib (免费) soui (免费) DuiVision (免费) 炫彩 (界面库免费,UI设计器付费,不提供源码) skinui (免费使用,但不开放源码, ...
- EAS开发环境搭建.
一:EAS开发环境安装 解压EAS服务器安装包到E盘即可,内含BOS开发环境. 二:EAS客户端安装 EAS8.0.exe安装到D盘,这是客户端. 三:远程数据库 使用远程运维系统,登陆数据库.
- kafka注册异常
问题描述: kafka注册异常,提示brokers id已经被注册过 -- ::,] FATAL [Kafka Server ], Fatal error during KafkaServer sta ...