caffe中solver的作用就是交替低啊用前向(forward)算法和后向(backward)算法来更新参数,从而最小化loss,实际上就是一种迭代的优化算法。

solver.cpp中的Solver提供了执行模型训练的入口,在caffe.cpp中train函数的最后通过 solver->Solve()调用:

  1. template <typename Dtype>
  2. void Solver<Dtype>::Solve(const char* resume_file) {
  3. //检查是否是root_solver,有多个GPU的情况下,允许设置多个solver,GPU间并行工作,
  4. //第一个solver设置为root_solver
  5. CHECK(Caffe::root_solver());
  6. //网络名称
  7. LOG(INFO) << "Solving " << net_->name();
  8. //学习策略
  9. LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
  10. // Initialize to false every time we start solving.
  11. requested_early_exit_ = false;
  12. //是否需要从指针所指向的内存读取出之前的训练状态并恢复
  13. if (resume_file) {
  14. LOG(INFO) << "Restoring previous solver status from " << resume_file;
  15. Restore(resume_file);
  16. }
  17. // For a network that is trained by the solver, no bottom or top vecs
  18. // should be given, and we will just provide dummy vecs.
  19. int start_iter = iter_;
  20. //逐步迭代开始
  21. Step(param_.max_iter() - iter_);
  22. ……
  23. }

Solver首先判断执行模式,输出网络名称以及学习策略,并判断是否需要恢复之前的训练状态,之后开始调用Step函数,开始迭代过程。Solver类中的Step函数完成网络模型的逐步优化迭代过程:

  1. template <typename Dtype>
  2. //Step函数完成实际的逐步迭代优化过程
  3. void Solver<Dtype>::Step(int iters) {
  4. //设置开始的迭代次数,如果之前设置了是从snapshot中恢复的,则会从
  5. //snapshot的训练状态继续执行训练
  6. const int start_iter = iter_;
  7. //总的迭代次数
  8. const int stop_iter = iter_ + iters;
  9. //获取设置的要计算之前多少次的loss均值,默认的average_loss为1
  10. int average_loss = this->param_.average_loss();
  11. //清除保存loss的向量
  12. losses_.clear();
  13. //平均loss初始化为0
  14. smoothed_loss_ = 0;
  15. //执行迭代
  16. while (iter_ < stop_iter) {
  17. //清零上一次反向传输过程中产生的梯度数据
  18. // zero-init the params
  19. net_->ClearParamDiffs();
  20. //判断条件,是否执行一次所有测试
  21. if (param_.test_interval() && iter_ % param_.test_interval() == 0
  22. && (iter_ > 0 || param_.test_initialization())
  23. && Caffe::root_solver()) {
  24. TestAll();
  25. if (requested_early_exit_) {
  26. // Break out of the while loop because stop was requested while testing.
  27. break;
  28. }
  29. }
  30. for (int i = 0; i < callbacks_.size(); ++i) {
  31. callbacks_[i]->on_start();
  32. }
  33. //是否输出loss等信息
  34. const bool display = param_.display() && iter_ % param_.display() == 0;
  35. net_->set_debug_info(display && param_.debug_info());
  36. // accumulate the loss and gradient
  37. Dtype loss = 0;
  38. //iter_size是在solver.prototxt中设置的,把数据分为多少批次分开迭代,对应还有一个名称为
  39. //batch_size的变量,是在网络中定义的,batch_size定义每批次包含的样本数量,把一个大的
  40. //样本数量分批次训练可以提高训练效率,总的样本数量=iter_size*batch_size
  41. for (int i = 0; i < param_.iter_size(); ++i) {
  42. //累加所有批次的平均误差
  43. loss += net_->ForwardBackward();
  44. }
  45. //计算批次的平均误差
  46. loss /= param_.iter_size();
  47. //更新输出的当前的average_loss个样本的平均loss
  48. // average the loss across iterations for smoothed reporting
  49. UpdateSmoothedLoss(loss, start_iter, average_loss);
  50. if (display) {
  51. //输出迭代次数,平均loss
  52. LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
  53. << ", loss = " << smoothed_loss_;
  54. const vector<Blob<Dtype>*>& result = net_->output_blobs();
  55. int score_index = 0;
  56. for (int j = 0; j < result.size(); ++j) {
  57. const Dtype* result_vec = result[j]->cpu_data();
  58. const string& output_name =
  59. net_->blob_names()[net_->output_blob_indices()[j]];
  60. const Dtype loss_weight =
  61. net_->blob_loss_weights()[net_->output_blob_indices()[j]];
  62. for (int k = 0; k < result[j]->count(); ++k) {
  63. ostringstream loss_msg_stream;
  64. if (loss_weight) {
  65. loss_msg_stream << " (* " << loss_weight
  66. << " = " << loss_weight * result_vec[k] << " loss)";
  67. }
  68. LOG_IF(INFO, Caffe::root_solver()) << " Train net output #"
  69. << score_index++ << ": " << output_name << " = "
  70. << result_vec[k] << loss_msg_stream.str();
  71. }
  72. }
  73. }
  74. for (int i = 0; i < callbacks_.size(); ++i) {
  75. callbacks_[i]->on_gradients_ready();
  76. }
  77. //执行网络更新,每一组网络中的参数的更新都是不同类型的solver实现各自的
  78. //ApplyUpdate函数中完成的
  79. ApplyUpdate();
  80. // Increment the internal iter_ counter -- its value should always indicate
  81. // the number of times the weights have been updated.
  82. ++iter_;
  83. SolverAction::Enum request = GetRequestedAction();
  84. // Save a snapshot if needed.
  85. if ((param_.snapshot()
  86. && iter_ % param_.snapshot() == 0
  87. && Caffe::root_solver()) ||
  88. (request == SolverAction::SNAPSHOT)) {
  89. Snapshot();
  90. }
  91. if (SolverAction::STOP == request) {
  92. requested_early_exit_ = true;
  93. // Break out of training loop.
  94. break;
  95. }
  96. }
  97. }

一次完整的训练流程包括一次前向传输和一次反向传输,分别计算模型的loss和梯度,通过梯度数据计算出参数的更新,更新是通过在Step函数中调用ApplyUpdate函数完成的,ApplyUpdate是在SGDSolver类中定义的:

  1. template <typename Dtype>
  2. void SGDSolver<Dtype>::ApplyUpdate() {
  3. CHECK(Caffe::root_solver());
  4. //根据设置的lr_policy,依据对应的规则计算当前迭代的learning rete的值
  5. Dtype rate = GetLearningRate();
  6. //是否输出当前的迭代次数和学习率数据
  7. if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
  8. LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
  9. }
  10. //避免梯度爆炸,如果梯度的L1或L2范数超过了某个上限值,则将梯度减小
  11. ClipGradients();
  12. //更新所有参数,包括卷积层和池化层的卷积核和偏置两组参数
  13. for (int param_id = 0; param_id < this->net_->learnable_params().size();
  14. ++param_id) {
  15. //将参数的梯度归一化,除以iter_size,其作用是保证实际的batch_size=iter_size*batch_size
  16. Normalize(param_id);
  17. //将正则化部分的梯度存入到每个参数的梯度中
  18. Regularize(param_id);
  19. //计算SGD算法的梯度(momentum等)
  20. ComputeUpdateValue(param_id, rate);
  21. }
  22. //调用Net::Update更新参数
  23. this->net_->Update();
  24. }
  25. template <typename Dtype>
  26. void SGDSolver<Dtype>::Normalize(int param_id) {
  27. //如果训练数据的批次数为1,则不进行归一化,直接返回
  28. if (this->param_.iter_size() == 1) { return; }
  29. //获取所有要优化的参数
  30. // Scale gradient to counterbalance accumulation.
  31. const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  32. //归一化系数
  33. const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();
  34. switch (Caffe::mode()) {
  35. case Caffe::CPU: {
  36. //CPU中执行归一化操作的函数
  37. caffe_scal(net_params[param_id]->count(), accum_normalization,
  38. net_params[param_id]->mutable_cpu_diff());
  39. break;
  40. }
  41. case Caffe::GPU: {
  42. #ifndef CPU_ONLY
  43. caffe_gpu_scal(net_params[param_id]->count(), accum_normalization,
  44. net_params[param_id]->mutable_gpu_diff());
  45. #else
  46. NO_GPU;
  47. #endif
  48. break;
  49. }
  50. default:
  51. LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  52. }
  53. }
  54. template <typename Dtype>
  55. void SGDSolver<Dtype>::Regularize(int param_id) {
  56. //获取所有要优化的参数
  57. const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  58. //获取所有要优化的参数的权重衰减向量
  59. const vector<float>& net_params_weight_decay =
  60. this->net_->params_weight_decay();
  61. //获取网络模型整体的权重衰减
  62. Dtype weight_decay = this->param_.weight_decay();
  63. //获取网络的正则化类型,L1或者L2
  64. string regularization_type = this->param_.regularization_type();
  65. //每一个参数的权重衰减等于每个参数的权重衰减乘以网络整体的权重衰减
  66. Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
  67. switch (Caffe::mode()) {
  68. case Caffe::CPU: {
  69. if (local_decay) { //权重为0时,代表梯度消失
  70. if (regularization_type == "L2") {
  71. // add weight decay
  72. //执行正则化,L2的梯度diff_=weight_decay*data_+diff_
  73. caffe_axpy(net_params[param_id]->count(),
  74. local_decay,
  75. net_params[param_id]->cpu_data(),
  76. net_params[param_id]->mutable_cpu_diff());
  77. } else if (regularization_type == "L1") {
  78. caffe_cpu_sign(net_params[param_id]->count(),
  79. net_params[param_id]->cpu_data(),
  80. temp_[param_id]->mutable_cpu_data());
  81. caffe_axpy(net_params[param_id]->count(),
  82. local_decay,
  83. temp_[param_id]->cpu_data(),
  84. net_params[param_id]->mutable_cpu_diff());
  85. } else {
  86. LOG(FATAL) << "Unknown regularization type: " << regularization_type;
  87. }
  88. }
  89. break;
  90. }
  91. ……
  92. }

【撸码caffe四】 solver.cpp&&sgd_solver.cpp的更多相关文章

  1. 【撸码caffe 三】 caffe.cpp

    caffe.cpp文件完成对网络模型以及模型配置参数的读入和提取,提供了网络模型训练的入口函数train和对模型的测试入口函数test.文件中使用了很多gflags和glog指令,gflags是goo ...

  2. 【撸码caffe 五】数据层搭建

    caffe.cpp中的train函数内声明了一个类型为Solver类的智能指针solver: // Train / Finetune a model. int train() { -- shared_ ...

  3. 【撸码caffe 二】 blob.hpp

    Blob类是caffe中对处理和传递的实际数据的封装,是caffe中基本的数据存储单元,包括前向传播中的图像数据,反向传播中的梯度数据以及网络层间的中间数据变量(包括权值,偏置等),训练模型的参数等等 ...

  4. 【撸码caffe 一】syncedmen.hpp

    SyncedMemory类主要负责在主机(CPU)和设备(GPU)之间管理内存分配和数据同步工作,封装了CPU和GPU之间的数据交互操作. 补充一点GPU的相关知识: 对CUDA架构而言,主机端的内存 ...

  5. 36 网络相关函数(四)——live555源码阅读(四)网络

    36 网络相关函数(四)——live555源码阅读(四)网络 36 网络相关函数(四)——live555源码阅读(四)网络 简介 7)createSocket创建socket方法 8)closeSoc ...

  6. 34 网络相关函数(二)——live555源码阅读(四)网络

    34 网络相关函数(二)——live555源码阅读(四)网络 34 网络相关函数(二)——live555源码阅读(四)网络 2)socketErr 套接口错误 3)groupsockPriv函数 4) ...

  7. 响应国家号召,在家撸码之React迁移记

    最近这段时间新型冠状病毒肆虐,上海确诊人数每天都在增加,人人提心吊胆,街上都没人了.为了响应国家号召,近期呆在家里撸码,着手将项目迁移到React中,项目比较朴素,是一张线索提交页面,包含表单.图片滚 ...

  8. 【深度学习】之Caffe的solver文件配置(转载自csdn)

    原文: http://blog.csdn.net/czp0322/article/details/52161759 今天在做FCN实验的时候,发现solver.prototxt文件一直用的都是mode ...

  9. 40 网络相关函数(八)——live555源码阅读(四)网络

    40 网络相关函数(八)——live555源码阅读(四)网络 40 网络相关函数(八)——live555源码阅读(四)网络 简介 15)writeSocket向套接口写数据 TTL的概念 函数send ...

随机推荐

  1. DOM对象之window

    window的属性 top:返回当前窗口的最顶层的先辈窗口 document:返回HTML文档对象 location:当前窗口的地址 self:返回对自身窗口的引用 parent:返回父窗口 如何引用 ...

  2. sublime之markdown快捷键

    目录 sublime 快捷键 markdown技能 sublime 快捷键 ctrl + shift + p 进入命令面板 package install 进入下载仓库 ctrl + alt + O ...

  3. java_File对象

    package File; import java.io.File; import java.io.IOException; public class file { public static voi ...

  4. iOS crash log 解析 symbol address = stack address - slide 运行时获取slide的api 利用dwarfdump从dsym文件中得到symbol

    概述: 为什么 crash log 内 Exception Backtrace 部分的地址(stack address)不能从 dsym 文件中查出对应的代码? 因为 ASLR(Address spa ...

  5. 如何在Centos里面,把.net core程序设为开机自启动

    确定你的.net core程序可以在centos手动启动后,下一步,就是把这个程序做成一个服务,让它开机自自动了 1.创建脚本文件 到目录/etc/rc.d/init.d下面,创建一个myserver ...

  6. return和return false的区别

    1. return返回null,起到中断方法执行的效果,只要不return false事件处理函数将会继续执行,表单将提交2. return false,事件处理函数会取消事件,不再继续向下执行.比如 ...

  7. C# 泛基

    1 你有时候希望在父类规定一些行为,让子类无法修改,但是这些实现是依赖一个子类才能获取的值,你又不可能知道所有的子类 ,没办法替它在父类里面初始化,这时候就需要在父类里面定义一个每个子类一个的,但又是 ...

  8. 27.7 并行语言集成查询(PLinq)

    static void Main() { ObsoleteMethods(Assembly.Load("mscorlib.dll")); Console.ReadKey(); } ...

  9. 1 WebService 常见问题

    <binding name="> <readerQuotas maxStringContentLength=" /> </binding> &l ...

  10. 【解题报告】 Leapin' Lizards HDU 2732 网络流

    [解题报告] Leapin' Lizards HDU 2732 网络流 题外话 在正式讲这个题目之前我想先说几件事 1. 如果大家要做网络流的题目,我在网上看到一个家伙,他那里列出了一堆网络流的题目, ...