以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了。在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在《实战Google深度学习框架》第二版这本书P166里只是提了一句,没有做出解答。

  书中说训练时和测试时使用的参数is_training都为True,然后给出了一个链接供参考。本人刚开始使用时也是按照书中的做法没有改动,后来从保存后的checkpoint中加载模型做预测时出了问题:当改变需要预测数据的batchsize时预测的label也跟着变,这意味着checkpoint里面没有保存训练中BN层的参数,使用的BN层参数还是从需要预测的数据中计算而来的。这显然会出问题,当预测的batchsize越大,假如你的预测数据集和训练数据集的分布一致,结果就越接近于训练结果,但如果batchsize=1,那BN层就发挥不了作用,结果很难看。

  那如果在预测时is_traning=false呢,但BN层的参数没有从训练中保存,那使用的就是随机初始化的参数,结果不堪想象。

  所以需要在训练时把BN层的参数保存下来,然后在预测时加载,参考几位大佬的博客,有了以下训练时添加的代码:

 update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_step = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss) # 设置保存模型
var_list = tf.trainable_variables()
g_list = tf.global_variables()
bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
var_list += bn_moving_vars
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

这样就可以在预测时从checkpoint文件加载BN层的参数并设置is_training=False。

最后要说的是,虽然这么做可以解决这个问题,但也可以利用预测数据来计算BN层的参数,不是说一定要保存训练时的参数,两种方案可以作为超参数来调节使用,看哪种方法的结果更好。

感谢几位大佬的博客解惑:

  https://blog.csdn.net/dongjbstrong/article/details/80447110?utm_source=blogxgwz0

  http://www.cnblogs.com/hrlnw/p/7227447.html

Tensorflow训练和预测中的BN层的坑的更多相关文章

  1. tensorflow CNN 卷积神经网络中的卷积层和池化层的代码和效果图

    tensorflow CNN 卷积神经网络中的卷积层和池化层的代码和效果图 因为很多 demo 都比较复杂,专门抽出这两个函数,写的 demo. 更多教程:http://www.tensorflown ...

  2. TensorFlow使用记录 (七): BN 层及 Dropout 层的使用

    参考:tensorflow中的batch_norm以及tf.control_dependencies和tf.GraphKeys.UPDATE_OPS的探究 1. Batch Normalization ...

  3. 【转载】 Pytorch(1) pytorch中的BN层的注意事项

    原文地址: https://blog.csdn.net/weixin_40100431/article/details/84349470 ------------------------------- ...

  4. 吴裕雄 python 神经网络——TensorFlow训练神经网络:不使用隐藏层

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 ...

  5. tensorflow 训练最后预测结果为一个定值,可能的原因

    训练一个分类网络,没想到预测结果为一个定值. 找了很久发现,是因为tensor的维度的原因.  注意:我说的是我的label数据的维度. 我的输入是: y_= tf.placeholder(tf.in ...

  6. 【转载】 【caffe转向pytorch】caffe的BN层+scale层=pytorch的BN层

    原文地址: https://blog.csdn.net/u011668104/article/details/81532592 ------------------------------------ ...

  7. tensorflow在文本处理中的使用——Word2Vec预测

    代码来源于:tensorflow机器学习实战指南(曾益强 译,2017年9月)——第七章:自然语言处理 代码地址:https://github.com/nfmcclure/tensorflow-coo ...

  8. BN层

    论文名字:Batch Normalization: Accelerating Deep Network Training by  Reducing Internal Covariate Shift 论 ...

  9. 【卷积神经网络】对BN层的解释

    前言 Batch Normalization是由google提出的一种训练优化方法.参考论文:Batch Normalization Accelerating Deep Network Trainin ...

随机推荐

  1. 2017(5)软件架构设计,web系统的架构设计,数据库系统,分布式数据库

    试题五(共 25 分) 阅读以下关于 Web 系统架构设计的叙述,在答题纸上回答问题1 至问题 3. [说明] 某公司开发的 B2C 商务平台因业务扩展,导致系统访问量不断增大,现有系统访问速度缓慢, ...

  2. Linux如何实现进程监控和守护

    最近新搭建的亚马逊EC2服务器, 上面部署了一个静态的WEB, 启动了一个nginx做代理.最近发现一个问题: Nginx进程隔一段时间就莫名的挂掉了, 然后就出现了网站无法打开的窘境.. 为了防止这 ...

  3. [elk]验证mapping字段数和数据字段数关系

    验证一个mapping下字段缺少或者超过 结论: 没有什么不可以. 1.如果数据字段不在mapping里,则动态会更新mapping. 2.数据字段数也可以小于mapping里字段数 创建一个mapp ...

  4. OO第一单元总结(表达式求导)

    写在前边:第一次接触面向对象语言,编程思想仍然不可避免的有以前面向过程的影子.从第一次作业的完全面向过程,到第二次学会剥离各个类互不影响到第三次作业的先构思面向对象的基本程序架构再编程.虽然程序有些地 ...

  5. windows10误删Administrator用户的家目录之后

    今天在玩Windows10的用户管理的时候,把Administrator用户给开启了,然后还用这个用户登录了系统. 结果就是,第一次登录的时候,闪过一条条初始化配置欢迎信息,Windows系统为Adm ...

  6. HttpClient 知识点

    1. httpClient 默认超时时间是 60秒 2.httpClient 是模拟表单提交,所以服务端接口要用HttpServletRequest request 接收  例如: request.g ...

  7. 全局路径规划算法Dijkstra(迪杰斯特拉算法)- matlab

    参考博客链接:https://www.cnblogs.com/kex1n/p/4178782.html Dijkstra是常用的全局路径规划算法,其本质上是一个最短路径寻优算法.算法的详细介绍参考上述 ...

  8. Sketchup (待续)

    Sketchup插件 来自20个最好用的SketchUp插件 https://www.bilibili.com/video/av17242031/?from=search&seid=15336 ...

  9. 阿里云 SSL 证书 总结

    历时2天左右的证书上传部署,终于结束了! 因为公司要开发小程序,小程序部署到开发环境必须支持https证书行. 阿里云目前的证书还是比较多的额,大致分为2类,一类是支持单域名,一类是支持泛域名. 自己 ...

  10. 《linux 必读》

    1. linux 内核设计与实现 2. 深入理解 linux 内核