1. The mean and variance are computed per channel.
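For an NCHW input this means one statistic per channel c, reduced over the batch and spatial axes; this is exactly the quantity the forward code below computes with its 1/(num * spatial_dim) scaling:

$$\mu_c = \frac{1}{NHW} \sum_{n,h,w} x_{n,c,h,w}, \qquad \sigma_c^2 = \frac{1}{NHW} \sum_{n,h,w} \left(x_{n,c,h,w} - \mu_c\right)^2$$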

2. At test/predict time, or whenever use_global_stats is set, the stored moving averages are used directly.

use_global_stats indicates whether to use statistics accumulated over all the training data (computed during the train phase by the moving-average method). In the training phase it is set to false, so the statistics are computed from the current mini-batch; in the test/predict phase the statistics accumulated over all the data are used instead.
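As a usage sketch, a BatchNorm layer is typically declared in a prototxt like the following. The layer and blob names here are hypothetical; use_global_stats, moving_average_fraction, and eps are the actual BatchNormParameter fields (the latter two default to 0.999 and 1e-5):

layer {
  name: "bn1"          # hypothetical layer name
  type: "BatchNorm"
  bottom: "conv1"      # hypothetical input blob
  top: "conv1"         # in-place use is common
  batch_norm_param {
    use_global_stats: true        # true for test/predict; false while training
    moving_average_fraction: 0.999
    eps: 1e-5
  }
}

In practice use_global_stats is usually left unset, since the layer infers it from the phase (false in TRAIN, true in TEST) and only honors an explicit setting when one is given.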

So what is a moving average here?
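The layer does not store a directly usable mean and variance. It keeps exponentially decayed running sums (blobs_[0] for the mean, blobs_[1] for the variance) plus a running normalizer in blobs_[2]. Reading the updates off the code below, with $\lambda$ = moving_average_fraction_, $(\mu_t, \sigma_t^2)$ the current mini-batch statistics, and $m = N \cdot H \cdot W$:

$$s_t = \lambda\, s_{t-1} + 1, \qquad M_t = \lambda\, M_{t-1} + \mu_t, \qquad V_t = \lambda\, V_{t-1} + \frac{m}{m-1}\,\sigma_t^2$$

Here $s_t$ is blobs_[2], $M_t$ is blobs_[0], $V_t$ is blobs_[1], and $\frac{m}{m-1}$ is the unbiasing correction applied to the batch variance. At test time the usable estimates are recovered as $M_t/s_t$ and $V_t/s_t$, which is exactly the division by blobs_[2] that caffe_cpu_scale performs at the top of Forward_cpu.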

Backpropagation:
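With $y = (x - \mu)/\sqrt{\sigma^2 + \varepsilon}$ and no learnable scale/shift (in Caffe those come from a separate Scale layer), the input gradient implemented below is the one stated in the code's own comment:

$$\frac{\partial E}{\partial x} = \left(\frac{\partial E}{\partial y} - \operatorname{mean}\!\left(\frac{\partial E}{\partial y}\right) - \operatorname{mean}\!\left(\frac{\partial E}{\partial y} \cdot y\right) \cdot y\right) \bigg/ \sqrt{\operatorname{var}(x) + \varepsilon}$$

where $\cdot$ denotes the elementwise (Hadamard) product and the means are taken over all dimensions except the channel dimension.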

Source code (note: caffe_cpu_scale computes y = alpha * x; when recovering the moving mean/variance here, alpha is the reciprocal of the accumulated moving-average coefficient stored in blobs_[2], and x is the accumulated moving sum stored in blobs_[0] or blobs_[1]):

template <typename Dtype>
void BatchNormLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  const Dtype* bottom_data = bottom[0]->cpu_data();
  Dtype* top_data = top[0]->mutable_cpu_data();
  int num = bottom[0]->shape(0);
  int spatial_dim = bottom[0]->count()/(bottom[0]->shape(0)*channels_);

  if (bottom[0] != top[0]) {
    caffe_copy(bottom[0]->count(), bottom_data, top_data);
  }

  if (use_global_stats_) {
    // use the stored mean/variance estimates.
    const Dtype scale_factor = this->blobs_[2]->cpu_data()[0] == 0 ?
        0 : 1 / this->blobs_[2]->cpu_data()[0];
    caffe_cpu_scale(variance_.count(), scale_factor,
        this->blobs_[0]->cpu_data(), mean_.mutable_cpu_data());
    caffe_cpu_scale(variance_.count(), scale_factor,
        this->blobs_[1]->cpu_data(), variance_.mutable_cpu_data());
  } else {
    // compute mean
    caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim,
        1. / (num * spatial_dim), bottom_data,
        spatial_sum_multiplier_.cpu_data(), 0.,
        num_by_chans_.mutable_cpu_data());
    caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
        num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
        mean_.mutable_cpu_data());
  }

  // subtract mean
  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
      batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0.,
      num_by_chans_.mutable_cpu_data());
  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_ * num,
      spatial_dim, 1, -1, num_by_chans_.cpu_data(),
      spatial_sum_multiplier_.cpu_data(), 1., top_data);

  if (!use_global_stats_) {
    // compute variance using var(X) = E((X-EX)^2)
    caffe_powx(top[0]->count(), top_data, Dtype(2),
        temp_.mutable_cpu_data());  // (X-EX)^2
    caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim,
        1. / (num * spatial_dim), temp_.cpu_data(),
        spatial_sum_multiplier_.cpu_data(), 0.,
        num_by_chans_.mutable_cpu_data());
    caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
        num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
        variance_.mutable_cpu_data());  // E((X-EX)^2)

    // compute and save moving average
    this->blobs_[2]->mutable_cpu_data()[0] *= moving_average_fraction_;
    this->blobs_[2]->mutable_cpu_data()[0] += 1;
    caffe_cpu_axpby(mean_.count(), Dtype(1), mean_.cpu_data(),
        moving_average_fraction_, this->blobs_[0]->mutable_cpu_data());
    int m = bottom[0]->count()/channels_;
    Dtype bias_correction_factor = m > 1 ? Dtype(m)/(m-1) : 1;
    caffe_cpu_axpby(variance_.count(), bias_correction_factor,
        variance_.cpu_data(), moving_average_fraction_,
        this->blobs_[1]->mutable_cpu_data());
  }

  // normalize variance
  caffe_add_scalar(variance_.count(), eps_, variance_.mutable_cpu_data());
  caffe_powx(variance_.count(), variance_.cpu_data(), Dtype(0.5),
      variance_.mutable_cpu_data());

  // replicate variance to input size
  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
      batch_sum_multiplier_.cpu_data(), variance_.cpu_data(), 0.,
      num_by_chans_.mutable_cpu_data());
  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_ * num,
      spatial_dim, 1, 1., num_by_chans_.cpu_data(),
      spatial_sum_multiplier_.cpu_data(), 0., temp_.mutable_cpu_data());
  caffe_div(temp_.count(), top_data, temp_.cpu_data(), top_data);
  // TODO(cdoersch): The caching is only needed because later in-place layers
  // might clobber the data.  Can we skip this if they won't?
  caffe_copy(x_norm_.count(), top_data,
      x_norm_.mutable_cpu_data());
}

template <typename Dtype>
void BatchNormLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down,
    const vector<Blob<Dtype>*>& bottom) {
  const Dtype* top_diff;
  if (bottom[0] != top[0]) {
    top_diff = top[0]->cpu_diff();
  } else {
    caffe_copy(x_norm_.count(), top[0]->cpu_diff(), x_norm_.mutable_cpu_diff());
    top_diff = x_norm_.cpu_diff();
  }
  Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
  if (use_global_stats_) {
    caffe_div(temp_.count(), top_diff, temp_.cpu_data(), bottom_diff);
    return;
  }
  const Dtype* top_data = x_norm_.cpu_data();
  int num = bottom[0]->shape()[0];
  int spatial_dim = bottom[0]->count()/(bottom[0]->shape(0)*channels_);
  // if Y = (X-mean(X))/(sqrt(var(X)+eps)), then
  //
  // dE(Y)/dX =
  //   (dE/dY - mean(dE/dY) - mean(dE/dY \cdot Y) \cdot Y)
  //     ./ sqrt(var(X) + eps)
  //
  // where \cdot and ./ are hadamard product and elementwise division,
  // respectively, dE/dY is the top diff, and mean/var/sum are all computed
  // along all dimensions except the channels dimension.  In the above
  // equation, the operations allow for expansion (i.e. broadcast) along all
  // dimensions except the channels dimension where required.

  // sum(dE/dY \cdot Y)
  caffe_mul(temp_.count(), top_data, top_diff, bottom_diff);
  caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim, 1.,
      bottom_diff, spatial_sum_multiplier_.cpu_data(), 0.,
      num_by_chans_.mutable_cpu_data());
  caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
      num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
      mean_.mutable_cpu_data());

  // reshape (broadcast) the above
  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
      batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0.,
      num_by_chans_.mutable_cpu_data());
  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_ * num,
      spatial_dim, 1, 1., num_by_chans_.cpu_data(),
      spatial_sum_multiplier_.cpu_data(), 0., bottom_diff);

  // sum(dE/dY \cdot Y) \cdot Y
  caffe_mul(temp_.count(), top_data, bottom_diff, bottom_diff);

  // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y
  caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim, 1.,
      top_diff, spatial_sum_multiplier_.cpu_data(), 0.,
      num_by_chans_.mutable_cpu_data());
  caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
      num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
      mean_.mutable_cpu_data());
  // reshape (broadcast) the above to make
  // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y
  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
      batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0.,
      num_by_chans_.mutable_cpu_data());
  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num * channels_,
      spatial_dim, 1, 1., num_by_chans_.cpu_data(),
      spatial_sum_multiplier_.cpu_data(), 1., bottom_diff);

  // dE/dY - mean(dE/dY)-mean(dE/dY \cdot Y) \cdot Y
  caffe_cpu_axpby(temp_.count(), Dtype(1), top_diff,
      Dtype(-1. / (num * spatial_dim)), bottom_diff);

  // note: temp_ still contains sqrt(var(X)+eps), computed during the forward
  // pass.
  caffe_div(temp_.count(), bottom_diff, temp_.cpu_data(), bottom_diff);
}

#ifdef CPU_ONLY
STUB_GPU(BatchNormLayer);
#endif

INSTANTIATE_CLASS(BatchNormLayer);
REGISTER_LAYER_CLASS(BatchNorm);

}  // namespace caffe
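To make the per-channel reductions concrete, here is a minimal, framework-free C++ sketch of what Forward_cpu computes for one mini-batch in training mode (use_global_stats_ == false), leaving out the moving-average bookkeeping and the x_norm_ caching. The function and variable names are illustrative, not Caffe API; the real layer expresses the same sums as gemv/gemm products against the all-ones vectors spatial_sum_multiplier_ and batch_sum_multiplier_.

#include <cmath>
#include <vector>

// y = (x - mean_c) / sqrt(var_c + eps), with one mean/variance per channel.
// x and y are NCHW tensors flattened to length N*C*spatial (spatial = H*W).
void batch_norm_forward(const std::vector<float>& x, std::vector<float>& y,
                        int N, int C, int spatial, float eps) {
  for (int c = 0; c < C; ++c) {
    // per-channel mean over the batch and spatial dimensions
    double mean = 0.0;
    for (int n = 0; n < N; ++n)
      for (int s = 0; s < spatial; ++s)
        mean += x[(n * C + c) * spatial + s];
    mean /= double(N) * spatial;

    // biased variance, var(X) = E((X-EX)^2), matching the layer
    double var = 0.0;
    for (int n = 0; n < N; ++n)
      for (int s = 0; s < spatial; ++s) {
        double d = x[(n * C + c) * spatial + s] - mean;
        var += d * d;
      }
    var /= double(N) * spatial;

    // normalize
    float inv_std = 1.0f / std::sqrt(static_cast<float>(var) + eps);
    for (int n = 0; n < N; ++n)
      for (int s = 0; s < spatial; ++s) {
        int i = (n * C + c) * spatial + s;
        y[i] = (x[i] - static_cast<float>(mean)) * inv_std;
      }
  }
}

Note also that this layer only whitens its input: the learnable gamma/beta of the batch-normalization paper are supplied by pairing BatchNorm with a Scale layer (bias_term: true) in Caffe models.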

  
