keras BatchNormalization 之坑
任务简述:最近做一个图像分类的任务, 一开始拿vgg跑一个baseline,输出看起来很正常:
随后,我尝试其他的一些经典的模型架构,比如resnet50, xception,但训练输出显示明显异常:
val_loss 一直乱蹦,val_acc基本不发生变化。
检查了输入数据没发现问题,因此怀疑是网络构造有问题, 对比了vgg同xception, resnet在使用layer上的异同,认为问题可能出在BN层上,将vgg添加了BN层之后再训练果然翻车。
翻看keras BN 的源码, 原来keras 的BN层的call函数里面有个默认参数traing, 默认是None。此参数意义如下:
training=False/0, 训练时通过每个batch的移动平均的均值、方差去做批归一化,测试时拿整个训练集的均值、方差做归一化
training=True/1/None,训练时通过当前batch的均值、方差去做批归一化,测试时拿整个训练集的均值、方差做归一化
当training=None时,训练和测试的批归一化方式不一致,导致validation的输出指标翻车。
当training=True时,拿训练完的模型预测一个样本和预测一个batch的样本的差异非常大,也就是预测的结果根据batch的大小会不同!导致模型结果无法准确评估!也是个坑!
用keras的BN时切记要设置training=False!!!
def build_model():
Inputs = Input(shape=intput_shape, name='input')
x_tmp = Lambda(lambda c: tf.image.rgb_to_grayscale(c))(Inputs)
x_tmp = Conv2D(64, (3, 3), activation='relu')(x_tmp)
x_tmp = Conv2D(64, (3, 3), activation='relu')(x_tmp)
x_tmp = BatchNormalization(x_tmp, training=False)
x_tmp = MaxPooling2D(pool_size=(2, 2))(x_tmp) x_tmp = Flatten()(x_tmp)
x_tmp = Dense(128, activation='relu')(x_tmp)
outputs = Dense(10, activation='softmax')(x_tmp)
model = Model(Inputs, outputs)
return model
参考:
https://arxiv.org/pdf/1502.03167v3.pdf
https://github.com/keras-team/keras/blob/master/keras/layers/normalization.py#L16
keras BatchNormalization 之坑的更多相关文章
- win10+anaconda安装tensorflow和keras遇到的坑小结
win10下利用anaconda安装tensorflow和keras的教程都大同小异(针对CPU版本,我的gpu是1050TI的MAX-Q,不知为啥一直没安装成功),下面简单说下步骤. 一 Anaco ...
- tensorflow 2.0 技巧 | 自定义tf.keras.Model的坑
自定义tf.keras.Model需要注意的点 model.save() subclass Model 是不能直接save的,save成.h5,但是能够save_weights,或者save_form ...
- tf.keras遇见的坑:Output tensors to a Model must be the output of a TensorFlow `Layer`
经过网上查找,找到了问题所在:在使用keras编程模式是,中间插入了tf.reshape()方法便遇到此问题. 解决办法:对于遇到相同问题的任何人,可以使用keras的Lambda层来包装张量流操作, ...
- keras用法
关于Keras的“层”(Layer) 所有的Keras层对象都有如下方法: layer.get_weights():返回层的权重(numpy array) layer.set_weights(weig ...
- 『计算机视觉』Mask-RCNN_推断网络其二:基于ReNet101的FPN共享网络暨TensorFlow和Keras交互简介
零.参考资料 有关FPN的介绍见『计算机视觉』FPN特征金字塔网络. 网络构架部分代码见Mask_RCNN/mrcnn/model.py中class MaskRCNN的build方法的"in ...
- [Tensorflow] 使用 Mask_RCNN 完成目标检测与实例分割,同时输出每个区域的 Feature Map
Mask_RCNN-2.0 网页链接:https://github.com/matterport/Mask_RCNN/releases/tag/v2.0 Mask_RCNN-master(matter ...
- Windows 下安装 tensorflow & keras & opencv 的避坑指南!
安装 Anaconda3 关键的一步: conda update pip 下面再去安装各种你需要的包,一般不会再报错. pip install -U tensorflow pip install -U ...
- Keras实现Hierarchical Attention Network时的一些坑
Reshape 对于的张量x,x.shape=(a, b, c, d)的情况 若调用keras.layer.Reshape(target_shape=(-1, c, d)), 处理后的张量形状为(?, ...
- Keras + Flask 提供接口服务的坑~~~
最近在搞Keras,训练完的模型要提供个预测服务出来.就想了个办法,通过Flask提供一个http服务,后来发现也能正常跑,但是每次预测都需要加载模型,效率非常低. 然后就把模型加载到全局,每次要用的 ...
随机推荐
- 【Linux】服务器识别ntfs移动磁盘方法
Linux服务器无法识别ntfs磁盘 如果想识别的话,需要安装一个包ntfs-3g 安装好后,将移动磁盘插入到服务器的usb口中 新建一个目录,将磁盘挂载在新建的目录上 挂载命令如下: mount - ...
- 快速查询表中的NULL数据
正常情况下,NULL值是不会放入B-TREE索引的,因此根据IS NULL查询的时候走的通常是全表扫描,如果记录比较少还好,记录比较多,查询会非常耗时 可以通过创建一个索引来解决 CREATE IND ...
- MySQL全面瓦解18:自定义函数
定义 我们之前学习了MySQL的内置函数,非常丰富,满足了我们对数据操作的大部分需求. 但是如果有一些复杂的业务逻辑在数据库层面就可以完成,无需在程序层面完成的时候,这时候就可以写成MySQL自定义函 ...
- Ubuntu18.04完全卸载mysql5.7并安装mysql8.0的安装方法
Ubuntu18.04版本下,如果直接输入: sudo apt install mysql-server 命令,会默认安装mysql5.7版本,安装过程并没有提示输入密码,安装完成后也无法正常登录,这 ...
- linux/git常用命令收集中
1.进入文件夹 cd 文件名 进入某个文件 cd .. 返回上一级目录 cd / 进入根目录 cd ~ 切换到当前 cd - 切换到上一个目录 2.查看 pwd 文件名 查看路 ...
- 正向代理 forward proxy、反向代理 reverse proxy、透明代理 transparent proxy
https://zh.wikipedia.org/wiki/反向代理 反向代理在计算机网络中是代理服务器的一种.服务器根据客户端的请求,从其关系的一组或多组后端服务器(如Web服务器)上获取资源,然后 ...
- wireshark使用手册
Wireshark的过滤器 使用Wireshark时最常见的问题,是当您使用默认设置时,会得到大量冗余信息,以至于很难找到自己需要的部分. 过犹不及. 这就是为什么过滤器会如此重要.它们可以帮助我们在 ...
- setTimeout、Promise、Async/Await 的区别
事件循环中分为宏任务队列和微任务队列其中setTimeout的回调函数放到宏任务队列里,等到执行栈清空以后执行promise.then里的回调函数会放到相应宏任务的微任务队列里,等宏任务里面的同步代码 ...
- Docker Desktop启动Kubernetes
Docker_Desktop启动Kubernetes 参考仓库:https://github.com/AliyunContainerService/k8s-for-docker-desktop 视频参 ...
- Java三种IO模型和LinuxIO五种IO模型
Java: https://github.com/Snailclimb/JavaGuide/blob/master/docs/java/BIO-NIO-AIO.md https://github.co ...