之前一直和小伙伴探讨batch normalization层的实现机理,作用在这里不谈,知乎上有一篇paper在讲这个,链接

这里只探究其具体运算过程,我们假设在网络中间经过某些卷积操作之后的输出的feature map的尺寸为4×3×2×2

4为batch的大小,3为channel的数目,2×2为feature map的长宽

整个BN层的运算过程如下图

上图中,batch size一共是4, 对于每一个batch的feature map的size是3×2×2

对于所有batch中的同一个channel的元素进行求均值与方差,比如上图,对于所有的batch,都拿出来最后一个channel,一共有4×4=16个元素,

然后求区这16个元素的均值与方差(上图只求了mean,没有求方差。。。),

求取完了均值与方差之后,对于这16个元素中的每个元素进行减去求取得到的均值与方差,然后乘以gamma加上beta,公式如下

所以对于一个batch normalization层而言,求取的均值与方差是对于所有batch中的同一个channel进行求取,batch normalization中的batch体现在这个地方

batch normalization层能够学习到的参数,对于一个特定的channel而言实际上是两个参数,gamma与beta,对于total的channel而言实际上是channel数目的两倍。

用pytorch验证上述想法是否准确,用上述方法求取均值,以及用batch normalization层输出的均值,看看是否一样

上代码

  1. # -*-coding:utf-8-*-
  2. from torch import nn
  3. import torch
  4.  
  5. m = nn.BatchNorm2d(3) # bn设置的参数实际上是channel的参数
  6. input = torch.randn(4, 3, 2, 2)
  7. output = m(input)
  8. # print(output)
  9. a = (input[0, 0, :, :]+input[1, 0, :, :]+input[2, 0, :, :]+input[3, 0, :, :]).sum()/16
  10. b = (input[0, 1, :, :]+input[1, 1, :, :]+input[2, 1, :, :]+input[3, 1, :, :]).sum()/16
  11. c = (input[0, 2, :, :]+input[1, 2, :, :]+input[2, 2, :, :]+input[3, 2, :, :]).sum()/16
  12. print('The mean value of the first channel is %f' % a.data)
  13. print('The mean value of the first channel is %f' % b.data)
  14. print('The mean value of the first channel is %f' % c.data)
  15. print('The output mean value of the BN layer is %f, %f, %f' % (m.running_mean.data[0],m.running_mean.data[0],m.running_mean.data[0]))
  16. print(m)

  1. m = nn.BatchNorm2d(3)

声明新的batch normalization层,用

  1. input = torch.randn(4, 3, 2, 2)

模拟feature map的尺寸

输出值

咦,怎么不一样,貌似差了一个小数点,可能与BN层的momentum变量有关系,在生命batch normalization层的时候将momentum设置为1试一试

  1. m.momentum=1

输出结果

没毛病

至于方差以及输出值,大抵也是这样进行计算的吧,留个坑

Pytorch中的Batch Normalization操作的更多相关文章

  1. PyTorch中的Batch Normalization

    Pytorch中的BatchNorm的API主要有: 1 torch.nn.BatchNorm1d(num_features, 2 3 eps=1e-05, 4 5 momentum=0.1, 6 7 ...

  2. Tensorflow BatchNormalization详解:4_使用tf.nn.batch_normalization函数实现Batch Normalization操作

    使用tf.nn.batch_normalization函数实现Batch Normalization操作 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 吴恩达deeplearnin ...

  3. 使用TensorFlow中的Batch Normalization

    问题 训练神经网络是一个很复杂的过程,在前面提到了深度学习中常用的激活函数,例如ELU或者Relu的变体能够在开始训练的时候很大程度上减少梯度消失或者爆炸问题.但是却不能保证在训练过程中不出现该问题, ...

  4. 在tensorflow中使用batch normalization

    问题 训练神经网络是一个很复杂的过程,在前面提到了深度学习中常用的激活函数,例如ELU或者Relu的变体能够在开始训练的时候很大程度上减少梯度消失或者爆炸问题,但是却不能保证在训练过程中不出现该问题, ...

  5. tensorflow中使用Batch Normalization

    在深度学习中为了提高训练速度,经常会使用一些正正则化方法,如L2.dropout,后来Sergey Ioffe 等人提出Batch Normalization方法,可以防止数据分布的变化,影响神经网络 ...

  6. 神经网络中使用Batch Normalization 解决梯度问题

    BN本质上解决的是反向传播过程中的梯度问题. 详细点说,反向传播时经过该层的梯度是要乘以该层的参数的,即前向有: 那么反向传播时便有: 那么考虑从l层传到k层的情况,有: 上面这个 便是问题所在.因为 ...

  7. Batch Normalization原理及其TensorFlow实现——为了减少深度神经网络中的internal covariate shift,论文中提出了Batch Normalization算法,首先是对”每一层“的输入做一个Batch Normalization 变换

    批标准化(Bactch Normalization,BN)是为了克服神经网络加深导致难以训练而诞生的,随着神经网络深度加深,训练起来就会越来越困难,收敛速度回很慢,常常会导致梯度弥散问题(Vanish ...

  8. tensorflow中batch normalization的用法

    网上找了下tensorflow中使用batch normalization的博客,发现写的都不是很好,在此总结下: 1.原理 公式如下: y=γ(x-μ)/σ+β 其中x是输入,y是输出,μ是均值,σ ...

  9. pytorch中文文档-torch.nn常用函数-待添加-明天继续

    https://pytorch.org/docs/stable/nn.html 1)卷积层 class torch.nn.Conv2d(in_channels, out_channels, kerne ...

随机推荐

  1. JDK动态代理源码分析

    先抛出一个问题,JDK的动态代理为什么不支持对实现类的代理,只支持接口的代理??? 首先来看一下如何使用JDK动态代理.JDK提供了Java.lang.reflect.Proxy类来实现动态代理的,可 ...

  2. swap分区不足ubuntu休眠

    安装uswsusp Ubuntu gnu/linux只需 代码: sudo aptitude install uswsusp Arch gnu/linux系统 代码: sudo pacman -S u ...

  3. 【Jenkins】Jenkins安装修改默认路径和端口的方法

    一.修改默认的jenkins安装路径 因为jenkins默认安装在c盘 C:\Users\Administrator\.jenkins下,那怎样将安装路径修改至d盘呢? 新建一个系统变量:JENKIN ...

  4. python2.x 与 python3.x的区别

    从语言的源码角度: python2.x 的源码书写不够规范,且源码有重复,代码的复用率不高; python3.x 的源码清晰.优美.简单 从语言的特性角度: python2.x 默认为ASCII字符编 ...

  5. Win10系列:VC++媒体播放控制3

    (5)添加视频进度条 视频进度条可以用来显示当前视频的播放进度,并可以通过拖动视频进度条来改变视频的播放进度.接下来介绍如何实现视频进度条,首先打开MainPage.xaml文件,并在Grid元素中添 ...

  6. do_bootrk

    1. LMB (logical memory blocks) lmb为uboot下的一种内存管理机制,用于管理镜像的内存.lmb所记录的内存信息最终会传递给kernel.在/include/lmb.h ...

  7. js判断一个值是空的最快方法是不是if(!value){alert("这个变量的值是null");}

    !逻辑非 操作符(js)-操作于任何值,(!undefined)(!Null)(!任何对象)(!"")(!"lihuan")(!任何非零数字值) (!0)(!N ...

  8. vue-11-自定义指令

    用于对纯 DOM 元素进行底层操作. // 注册一个全局自定义指令 v-focus Vue.directive('focus', { // 当绑定元素插入到 DOM 中. inserted: func ...

  9. (Java学习笔记) Java Networking (Java 网络)

    Java Networking (Java 网络) 1. 网络通信协议 Network Communication Protocols Network Protocol is a set of rul ...

  10. DevExpress ASP.NET Bootstrap Controls v18.2新功能详解(二)

    行业领先的.NET界面控件2018年第二次重大更新——DevExpress v18.2日前正式发布,本站将以连载的形式为大家介绍新版本新功能.本文将介绍了DevExpress ASP.NET Boot ...