1. Caffe源码
  2. Blob
  3. protected:
  4. shared_ptr<SyncedMemory> data_;
  5. shared_ptr<SyncedMemory> diff_;
  6. shared_ptr<SyncedMemory> shape_data_;
  7. vector<int> shape_;
  8. int count_;
  9. int capacity_;
  10.  
  11. Blob的构造函数
  12. Blob<Dtype>::Blob(const int num, const int channels, const int height,
  13. const int width)
  14. // capacity_ must be initialized before calling Reshape
  15. : capacity_(0) {
  16. Reshape(num, channels, height, width);
  17. }
  18.  
  19. 会调用reshape函数,为data_,diff_分配内存
  20. template <typename Dtype>
  21. void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
  22. const int width) {
  23. vector<int> shape(4);
  24. shape[0] = num;
  25. shape[1] = channels;
  26. shape[2] = height;
  27. shape[3] = width;
  28. Reshape(shape);
  29. }
  30.  
  31. template <typename Dtype>
  32. void Blob<Dtype>::Reshape(const vector<int>& shape) {
  33. CHECK_LE(shape.size(), kMaxBlobAxes);
  34. count_ = 1;
  35. shape_.resize(shape.size());
  36. if (!shape_data_ || shape_data_->size() < shape.size() * sizeof(int)) {
  37. shape_data_.reset(new SyncedMemory(shape.size() * sizeof(int)));
  38. }
  39. int* shape_data = static_cast<int*>(shape_data_->mutable_cpu_data());
  40. for (int i = 0; i < shape.size(); ++i) {
  41. CHECK_GE(shape[i], 0);
  42. CHECK_LE(shape[i], INT_MAX / count_) << "blob size exceeds INT_MAX";
  43. count_ *= shape[i];
  44. shape_[i] = shape[i];
  45. shape_data[i] = shape[i];
  46. }
  47. if (count_ > capacity_) {
  48. capacity_ = count_;
  49. data_.reset(new SyncedMemory(capacity_ * sizeof(Dtype)));
  50. diff_.reset(new SyncedMemory(capacity_ * sizeof(Dtype)));
  51. }
  52.  
  53. Blob的序列化函数:
  54. //in blob.hpp
  55. void FromProto(const BlobProto& proto, bool reshape = true);
  56. void ToProto(BlobProto* proto, bool write_diff = false) const;
  57. ToProtoBlobshape_,data_,diff_分别copyBlobProtoshape,data,diff,完成序列化。FromProtoBlobProtoshape,data,diff分别copyBlobshape_,data_,diff_,完成数据解析。最后数据持久化函数由Protocol Buffers的工具实现
  58.  
  59. Blob中还有个更新参数的函数update(),data=data-diff
  60. void Blob<Dtype>::Update() {
  61. // We will perform update based on where the data is located.
  62. switch (data_->head()) {
  63. case SyncedMemory::HEAD_AT_CPU:
  64. // perform computation on CPU
  65. caffe_axpy<Dtype>(count_, Dtype(-1),
  66. static_cast<const Dtype*>(diff_->cpu_data()),
  67. static_cast<Dtype*>(data_->mutable_cpu_data()));
  68. break;
  69. case SyncedMemory::HEAD_AT_GPU:
  70. case SyncedMemory::SYNCED:
  71. #ifndef CPU_ONLY
  72. // perform computation on GPU
  73. caffe_gpu_axpy<Dtype>(count_, Dtype(-1),
  74. static_cast<const Dtype*>(diff_->gpu_data()),
  75. static_cast<Dtype*>(data_->mutable_gpu_data()));
  76. #else
  77. NO_GPU;
  78. #endif
  79. break;
  80. default:
  81. LOG(FATAL) << "Syncedmem not initialized.";
  82. }
  83. }
  84.  
  85. Layer5纯虚函数
  86. Reshape()
  87. Forward_cpu()
  88. Backword_cpu()
  89. Forward_gpu()
  90. Backword_gpu()
  91.  
  92. Layer层:
  93. Loss_layer
  94. Common_layer没有了(softmax,innerproduct)
  95. Neuron_layer(tanh)
  96. Vision layer没有了(pooling,conv)
  97. Data_layer变成了BasePrefetchingDataLayer(hdf5 input)
  98.  
  99. Net
  100. Solver
  101. 整个过程
  102. solver变量的构造函数中有init(param)
  103. init中有initTrainNet()函数,initTrainNet()函数有net_.reset(new Net<Dtype>(net_param));
  104. 然后调用net的构造函数
  105. template <typename Dtype>
  106. Net<Dtype>::Net(const NetParameter& param, const Net* root_net)
  107. : root_net_(root_net) {
  108. Init(param);
  109. }
  110. 通过一个for循环将layer一个一个串起来,并且调用layersetup函数
  111. // layer 初始化设置
  112. void SetUp(const vector<Blob<Dtype>*>& bottom,
  113. const vector<Blob<Dtype>*>& top) {
  114. InitMutex();
  115. CheckBlobCounts(bottom, top);
  116. LayerSetUp(bottom, top);
  117. Reshape(bottom, top);
  118. SetLossWeights(top);
  119. }
  120. LayerSetUp(bottom, top):由Layer类派生出的特定类都需要重写这个函数,主要功能是设置权值参数(包括偏置)的空间以及对权值参数经行随机初始化。 
  121. Reshape(bottom, top):根据输出blob和权值参数计算输出blob的维数,并申请空间。
  122.  
  123. 经过上述过程基本上就完成了初始化的工作,总体的流程大概就是新建一个Solver对象,然后调用Solver类的构造函数,然后在Solver的构造函数中又会新建Net类实例,在Net类的构造函数中又会新建各个Layer的实例,一直具体到设置每个Blob,大概就介绍完了网络初始化的工作,当然里面还有很多具体的细节,但大概的流程就是这样。
  124.  
  125. 上面过程就是从shared_ptr<caffe::Solver<float> > //初始化
  126. solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
  127. 这个solver开始的
  128. solver->Solve();
  129. template <typename Dtype>
  130. void Solver<Dtype>::Solve(const char* resume_file) {
  131. ...
  132. int start_iter = iter_;
  133. ...
  134. // 然后调用了'Step'函数,这个函数执行了实际的逐步的迭代过程
  135. Step(param_.max_iter() - iter_);
  136. ...
  137. LOG(INFO) << "Optimization Done.";
  138. }
  139.  
  140. step函数如下
  141. template <typename Dtype>
  142. void Solver<Dtype>::Step(int iters) {
  143. ...
  144. //迭代
  145. while (iter_ < stop_iter) {
  146. ...
  147. // iter_size也是在solver.prototxt里设置,实际上的batch_size=iter_size*网络定义里的batch_size,
  148. // 因此每一次迭代的loss是iter_size次迭代的和,再除以iter_size,这个loss是通过调用`Net::ForwardBackward`函数得到的
  149. // accumulate gradients over `iter_size` x `batch_size` instances
  150. for (int i = 0; i < param_.iter_size(); ++i) {
  151. /*
  152. * 调用了Net中的代码,主要完成了前向后向的计算,
  153. * 前向用于计算模型的最终输出和Loss,后向用于
  154. * 计算每一层网络和参数的梯度。
  155. */
  156. loss += net_->ForwardBackward();
  157. }
  158.  
  159. ...
  160.  
  161. /*
  162. * 这个函数主要做Loss的平滑。由于Caffe的训练方式是SGD,我们无法把所有的数据同时
  163. * 放入模型进行训练,那么部分数据产生的Loss就可能会和全样本的平均Loss不同,在必要
  164. * 时候将Loss和历史过程中更新的Loss求平均就可以减少Loss的震荡问题。
  165. */
  166. UpdateSmoothedLoss(loss, start_iter, average_loss);
  167.  
  168. ...
  169. // 执行梯度的更新,这个函数在基类`Solver`中没有实现,会调用每个子类自己的实现
  170. //,后面具体分析`SGDSolver`的实现
  171. ApplyUpdate();
  172.  
  173. // 迭代次数加1
  174. ++iter_;
  175. ...
  176.  
  177. }
  178. }
  179.  
  180. // 进行一次正向传播,一次反向传播
  181. Dtype ForwardBackward() {
  182. Dtype loss;
  183. Forward(&loss);
  184. Backward();
  185. return loss;
  186. }
  187.  
  188. for (int i = start; i <= end; ++i) {
  189. // 对每一层进行前向计算,返回每层的loss,其实只有最后一层loss不为0
  190. Dtype layer_loss = layers_[i]->Forward(bottom_vecs_[i], top_vecs_[i]);
  191. loss += layer_loss;
  192. if (debug_info_) { ForwardDebugInfo(i); }
  193. }
  194.  
  195. ApplyUpdate();
  196. 这个函数是Solver类的纯虚函数,需要派生类来实现,比如SGDSolver类实现的ApplyUpdate();函数如下,主要内容包括:设置参数的学习率;对梯度进行Normalize;对反向求导得到的梯度添加正则项的梯度;最后根据SGD算法计算最终的梯度;最后的最后把计算得到的最终梯度对权值进行更新。
  197.  
  198. template <typename Dtype>
  199. void SGDSolver<Dtype>::ApplyUpdate() {
  200. CHECK(Caffe::root_solver());
  201.  
  202. // GetLearningRate根据设置的lr_policy来计算当前迭代的learning rate的值
  203. Dtype rate = GetLearningRate();
  204.  
  205. // 判断是否需要输出当前的learning rate
  206. if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
  207. LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
  208. }
  209.  
  210. // 避免梯度爆炸,如果梯度的二范数超过了某个数值则进行scale操作,将梯度减小
  211. ClipGradients();
  212.  
  213. // 对所有可更新的网络参数进行操作
  214. for (int param_id = 0; param_id < this->net_->learnable_params().size();
  215. ++param_id) {
  216. // 将第param_id个参数的梯度除以iter_size,
  217. // 这一步的作用是保证实际的batch_size=iter_size*设置的batch_size
  218. Normalize(param_id);
  219.  
  220. // 将正则化部分的梯度降入到每个参数的梯度中
  221. Regularize(param_id);
  222.  
  223. // 计算SGD算法的梯度(momentum等)
  224. ComputeUpdateValue(param_id, rate);
  225. }
  226. // 调用`Net::Update`更新所有的参数
  227. this->net_->Update();
  228. }

  

caffe源码整个训练过程的更多相关文章

  1. caffe源码阅读

    参考网址:https://www.cnblogs.com/louyihang-loves-baiyan/p/5149628.html 1.caffe代码层次熟悉blob,layer,net,solve ...

  2. caffe源码学习

    本文转载自:https://buptldy.github.io/2016/10/09/2016-10-09-Caffe_Code/ Caffe简介 Caffe作为一个优秀的深度学习框架网上已经有很多内 ...

  3. caffe源码学习之Proto数据格式【1】

    前言: 由于业务需要,接触caffe已经有接近半年,一直忙着阅读各种论文,重现大大小小的模型. 期间也总结过一些caffe源码学习笔记,断断续续,这次打算系统的记录一下caffe源码学习笔记,巩固一下 ...

  4. Caffe源码-几种优化算法

    SGD简介 caffe中的SGDSolver类中实现了带动量的梯度下降法,其原理如下,\(lr\)为学习率,\(m\)为动量参数. 计算新的动量:history_data = local_rate * ...

  5. Symfony2源码分析——启动过程2

    文章地址:http://www.hcoding.com/?p=46 上一篇分析Symfony2框架源码,探究Symfony2如何完成一个请求的前半部分,前半部分可以理解为Symfony2框架为处理请求 ...

  6. Caffe源码理解2:SyncedMemory CPU和GPU间的数据同步

    目录 写在前面 成员变量的含义及作用 构造与析构 内存同步管理 参考 博客:blog.shinelee.me | 博客园 | CSDN 写在前面 在Caffe源码理解1中介绍了Blob类,其中的数据成 ...

  7. c#源码的执行过程

    我想也许要写些东西,记录我做程序员的日子吧 ================================================ 要讲到C#源码的执行过程 首先要提下程序集,因为Clr并不 ...

  8. Caffe源码中syncedmem文件分析

    Caffe源码(caffe version:09868ac , date: 2015.08.15)中有一些重要文件,这里介绍下syncedmem文件. 1.      include文件: (1).& ...

  9. Caffe源码中math_functions文件分析

    Caffe源码(caffe version:09868ac , date: 2015.08.15)中有一些重要文件,这里介绍下math_functions文件. 1.      include文件: ...

随机推荐

  1. [NOI2011]兔兔与蛋蛋游戏 二分图博弈

    题面 题面 题解 通过观察,我们可以发现如下性质: 可以看做是2个人在不断移动空格,只是2个人能移动的边不同 一个位置不会被重复经过 : 根据题目要求,因为是按黑白轮流走,所以不可能重复经过一个点,不 ...

  2. [COCI2011-2012#5] POPLOCAVANJE 后缀自动机

    题面:洛谷 题解: 其实还可以用AC自动机做,但是没调出来,,,不知道发生了什么... AC自动机做法如下: 观察到如果我们对给定的每个串建AC自动机,那么直接拿大串在上面匹配,如果遇到了一个单词的终 ...

  3. windows2016上如何通过攻击ETERNALBLUE获得meterpreter反弹

    windows2016上如何通过攻击ETERNALBLUE获得meterpreter反弹 译:by  backlion 0x00前言 当微软发布MS17-010漏洞的补丁时,该漏洞影响的范围是从Win ...

  4. 框架----Django之Form提交验证(一)

    一.Form提交验证与Ajax提交验证的运用实例 Form表单提交时会刷新页面,输入失败时,输入框内内容也会随之刷新不能保留:而Ajax提交是在后台偷偷提交,不会刷新页面,因此也就可以保留页面输入框内 ...

  5. Android MediaRecorder解析

    源码路径:frameworks/base/media/java/android/media/MediaRecorder.javaframeworks/base/media/jni/android_me ...

  6. source 导入文件

    有时候,phpmyadmin 导入  是有大小限制的: 只可以用sql命令的source来导入文件

  7. LVS三种模式的区别及负载均衡算法

    LVS简介 LVS(Linux Virtual Server)即Linux虚拟服务器,是一个虚拟的服务器集群系统,由章文嵩博士在1998年5月成立,在linux2.6+后将lvs自动加入了kernel ...

  8. java格式化字符串,在指定位置插入指定字符串,兼容中英文以及特殊字符,例如:换行,用于解决生成pdf换行问题等问题

    本博客是自己在学习和工作途中的积累与总结,仅供自己参考,也欢迎大家转载,转载时请注明出处.  http://www.cnblogs.com/king-xg/p/6370890.html 如果觉得对您有 ...

  9. TersorflowTutorial_MNIST数据集上简单CNN实现

    MNIST数据集上简单CNN实现 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 Tensorflow机器学习实战指南 源代码请点击下方链接欢迎加星 Tesorflow实现基于MNI ...

  10. Handlebars 使用

    引入js <script src="js/json3.min.js"></script> <script src="js/handlebar ...