Introduction to the Solver Class

The Net class implements the network's forward/backward computation and parameter updates; the Solver class is a further wrapper around it. It provides Step(), which trains the network for a given number of iterations, and Solve(), which drives the full optimization of the network, along with interface functions for saving and restoring snapshots of the network model.
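As a quick orientation before the source walkthrough, here is a minimal sketch of how client code typically drives a Solver: read a SolverParameter from a text prototxt, create the concrete solver through the SolverRegistry factory (the base Solver is abstract), and call Solve(). The file name "solver.prototxt" is a placeholder, not something fixed by Caffe.

    #include "caffe/caffe.hpp"

    int main() {
      caffe::SolverParameter solver_param;
      // "solver.prototxt" is a hypothetical path to a text-format SolverParameter file.
      caffe::ReadSolverParamsFromTextFileOrDie("solver.prototxt", &solver_param);

      // CreateSolver() instantiates the subclass named by solver_param.type(),
      // e.g. SGDSolver; Solver itself cannot be instantiated (ApplyUpdate() is pure virtual).
      boost::shared_ptr<caffe::Solver<float> > solver(
          caffe::SolverRegistry<float>::CreateSolver(solver_param));

      solver->Solve();  // train until max_iter; pass a .solverstate path instead to resume
      return 0;
    }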

solver.cpp Source Code

template<typename Dtype>
void Solver<Dtype>::SetActionFunction(ActionCallback func) {
  action_request_function_ = func;  // set the callback that reports the solver's requested action
}

template<typename Dtype>
SolverAction::Enum Solver<Dtype>::GetRequestedAction() {  // return the requested solver action
  if (action_request_function_) {
    // If the external request function has been set, call it.
    return action_request_function_();  // run the callback; it returns the requested action type
  }
  return SolverAction::NONE;
}

template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param)  // constructor; initialize the solver from a param message
    : net_(), callbacks_(), requested_early_exit_(false) {
  Init(param);  // initialize this solver with the param message
}

template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file)
    : net_(), callbacks_(), requested_early_exit_(false) {  // constructor; read solver parameters from a text-format proto file
  SolverParameter param;
  ReadSolverParamsFromTextFileOrDie(param_file, &param);  // parse the message in param_file into param
  Init(param);  // initialize the solver
}

template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {  // Solver initialization
  LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
      << std::endl << param.DebugString();  // printed only by the root solver
  param_ = param;
  // average_loss is the length of the sliding window used to smooth the loss; each report
  // averages the loss over the most recent average_loss iterations
  CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
  CheckSnapshotWritePermissions();  // check that snapshot files can be created
  if (param_.random_seed() >= 0) {  // a random seed was set in the SolverParameter message
    Caffe::set_random_seed(param_.random_seed() + Caffe::solver_rank());  // set it
  }
  // Scaffolding code
  InitTrainNet();  // initialize the train net
  InitTestNets();  // initialize all test nets; there is exactly one train net, but there may be several test nets
  if (Caffe::root_solver()) {
    LOG(INFO) << "Solver scaffolding done.";  // printed only by the root solver
  }
  iter_ = 0;  // initialize the counters
  current_step_ = 0;
}

// Load weights from the caffemodel(s) specified in "weights" solver parameter
// into the train and test nets.
template <typename Dtype>
void LoadNetWeights(shared_ptr<Net<Dtype> > net, const std::string& model_list) {  // load weight files
  std::vector<std::string> model_names;
  boost::split(model_names, model_list, boost::is_any_of(","));  // split the names; the weight file names in model_list are separated by ","
  for (int i = 0; i < model_names.size(); ++i) {
    boost::trim(model_names[i]);  // strip leading/trailing whitespace
    LOG(INFO) << "Finetuning from " << model_names[i];  // print the weight file name
    net->CopyTrainedLayersFrom(model_names[i]);  // copy the blob data in the file into the same-named parameters of the net
  }
}

template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {  // initialize the train net: configure its parameters and load pretrained models
  // The train net's proto may be specified through any one of the SolverParameter fields
  // train_net_param, train_net, net_param, and net
  const int num_train_nets = param_.has_net() + param_.has_net_param() +
      param_.has_train_net() + param_.has_train_net_param();  // how many of the four fields are set
  const string field_names = "net, net_param, train_net, train_net_param";
  CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
      << "using one of these fields: " << field_names;  // at least one must be set
  CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
      << "one of these fields specifying a train_net: " << field_names;  // at most one of the four may be set
  NetParameter net_param;
  if (param_.has_train_net_param()) {  // the train net is given in train_net_param
    LOG_IF(INFO, Caffe::root_solver()) << "Creating training net specified in train_net_param.";  // root solver only
    net_param.CopyFrom(param_.train_net_param());  // copy the net parameters from the NetParameter message into net_param
  } else if (param_.has_train_net()) {  // the train net file name is given in train_net
    LOG_IF(INFO, Caffe::root_solver()) << "Creating training net from train_net file: " << param_.train_net();
    ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);  // read the net parameters from the proto file
  }
  if (param_.has_net_param()) {  // the train net is given in net_param
    LOG_IF(INFO, Caffe::root_solver()) << "Creating training net specified in net_param.";
    net_param.CopyFrom(param_.net_param());  // copy the net parameters from the NetParameter message
  }
  if (param_.has_net()) {  // the train net file name is given in net
    LOG_IF(INFO, Caffe::root_solver()) << "Creating training net from net file: " << param_.net();
    ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);  // read the net parameters from the proto file
  }
  // Set the correct NetState. We start with the solver defaults (lowest
  // precedence); then, merge in any NetState specified by the net_param itself;
  // finally, merge in any NetState specified by the train_state (highest
  // precedence).
  // Message::MergeFrom() semantics: singular fields are overwritten, nested messages are
  // merged together, and repeated fields are concatenated.
  // Message::CopyFrom() semantics: clear the current message, then MergeFrom() the given message.
  // The NetState starts from the solver defaults, is then overridden by the state found in the
  // net_param read from one of the four sources above, and finally overridden by the solver's
  // own train_state. A NetState set in the net itself therefore has lower precedence than one
  // set in the solver.
  NetState net_state;
  net_state.set_phase(TRAIN);  // set the net's phase to TRAIN
  net_state.MergeFrom(net_param.state());  // first apply the state read from the file or message above
  net_state.MergeFrom(param_.train_state());  // then apply the train state set in the solver
  net_param.mutable_state()->CopyFrom(net_state);  // store the resulting state back into the net parameters
  net_.reset(new Net<Dtype>(net_param));  // build the net from these parameters and store it in net_
  for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {  // number of weights entries
    LoadNetWeights(net_, param_.weights(w_idx));  // load the pretrained model(s) named in each entry into net_
  }
}

template <typename Dtype>
void Solver<Dtype>::InitTestNets() {  // initialize the test nets
  const bool has_net_param = param_.has_net_param();
  const bool has_net_file = param_.has_net();
  const int num_generic_nets = has_net_param + has_net_file;  // whether a generic net message or net file was given
  CHECK_LE(num_generic_nets, 1)
      << "Both net_param and net_file may not be specified.";  // at most one of the two may be specified
  const int num_test_net_params = param_.test_net_param_size();  // number of test nets given as messages
  const int num_test_net_files = param_.test_net_size();  // number of test nets given as files
  const int num_test_nets = num_test_net_params + num_test_net_files;  // total
  if (num_generic_nets) {
    // test_iter gives the number of iterations for each test net, so the number of test_iter
    // entries must match the number of test nets. If a generic net message or net file was
    // given, it may define additional test nets, so the number of test_iter entries must be
    // at least num_test_nets.
    CHECK_GE(param_.test_iter_size(), num_test_nets)
        << "test_iter must be specified for each test network.";
  } else {
    // Without net_param or net, all test nets come from test_net_param and test_net,
    // so the counts must match exactly.
    CHECK_EQ(param_.test_iter_size(), num_test_nets)
        << "test_iter must be specified for each test network.";
  }
  // If we have a generic net (specified by net or net_param, rather than
  // test_net or test_net_param), we may have an unlimited number of actual
  // test networks -- the actual number is given by the number of remaining
  // test_iters after any test nets specified by test_net_param and/or test_net
  // are evaluated.
  const int num_generic_net_instances = param_.test_iter_size() - num_test_nets;  // the difference is the number of test nets defined via net_param or net
  const int num_test_net_instances = num_test_nets + num_generic_net_instances;  // total number of test nets, i.e. param_.test_iter_size()
  if (param_.test_state_size()) {  // if test_state is set, its count must match the number of test nets
    CHECK_EQ(param_.test_state_size(), num_test_net_instances)
        << "test_state must be unspecified or specified once per test net.";  // check that the counts match
  }
  if (num_test_net_instances) {
    CHECK_GT(param_.test_interval(), 0);  // the configured test interval must be positive
  }
  int test_net_id = 0;
  vector<string> sources(num_test_net_instances);
  vector<NetParameter> net_params(num_test_net_instances);
  // caffe.proto documents the precedence of test net sources: (1) test_net_param, (2) test_net, (3) net_param/net.
  for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) {
    sources[test_net_id] = "test_net_param";  // record where this test net came from
    net_params[test_net_id].CopyFrom(param_.test_net_param(i));  // copy the net parameters from the NetParameter message
  }
  for (int i = 0; i < num_test_net_files; ++i, ++test_net_id) {
    sources[test_net_id] = "test_net file: " + param_.test_net(i);  // record the source, including the file name
    ReadNetParamsFromTextFileOrDie(param_.test_net(i),
        &net_params[test_net_id]);  // read the net parameters from the proto file into net_params
  }
  const int remaining_test_nets = param_.test_iter_size() - test_net_id;  // number of test nets defined via net_param/net
  if (has_net_param) {  // net_param was given, so the remaining test nets are all defined there
    for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
      sources[test_net_id] = "net_param";
      net_params[test_net_id].CopyFrom(param_.net_param());  // copy the net parameters
    }
  }
  if (has_net_file) {  // likewise, read the remaining test nets from the file named by net
    for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
      sources[test_net_id] = "net file: " + param_.net();
      ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);
    }
  }
  test_nets_.resize(num_test_net_instances);  // resize
  for (int i = 0; i < num_test_net_instances; ++i) {
    // Set the correct NetState. We start with the solver defaults (lowest
    // precedence); then, merge in any NetState specified by the net_param
    // itself; finally, merge in any NetState specified by the test_state
    // (highest precedence).
    // As in InitTrainNet(): start from the defaults, override them with the state in the net
    // parameters, then override that with the test state set in the solver to obtain the final
    // test net state.
    NetState net_state;
    net_state.set_phase(TEST);  // set the phase to TEST
    net_state.MergeFrom(net_params[i].state());  // first apply the state from the net parameters
    if (param_.test_state_size()) {
      net_state.MergeFrom(param_.test_state(i));  // then apply the test state set in the solver
    }
    net_params[i].mutable_state()->CopyFrom(net_state);  // store the final test net state in net_params[i]
    LOG(INFO) << "Creating test net (#" << i << ") specified by " << sources[i];  // print the recorded source
    test_nets_[i].reset(new Net<Dtype>(net_params[i]));  // build the net from net_params[i] and store it in test_nets_
    test_nets_[i]->set_debug_info(param_.debug_info());  // propagate the solver's debug-info setting to the net
    for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {
      LoadNetWeights(test_nets_[i], param_.weights(w_idx));  // each test net also tries to load all pretrained model files
    }
  }
}

// Run iters single-step iterations of the solver
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
  const int start_iter = iter_;  // iterations completed so far
  const int stop_iter = iter_ + iters;  // iteration count at which to stop
  int average_loss = this->param_.average_loss();  // length of the loss-smoothing window
  losses_.clear();  // clear the loss history
  smoothed_loss_ = 0;  // reset
  iteration_timer_.Start();  // start the timer
  while (iter_ < stop_iter) {
    // zero-init the params
    net_->ClearParamDiffs();  // zero the gradients of all learnable parameters in the net
    if (param_.test_interval() && iter_ % param_.test_interval() == 0  // a test interval is set and this iteration is due for testing
        && (iter_ > 0 || param_.test_initialization())) {  // testing may also run at iteration 0
      // test_initialization() only controls whether a test pass runs at the very start
      // (iter_ == 0), where (iter_ % test_interval == 0) always holds: if true, every call
      // to Step() begins with a test pass; if false, testing only starts once iter_ > 0
      if (Caffe::root_solver()) {  // test nets run only in the root solver
        TestAll();  // run all test nets and print their outputs
      }
      if (requested_early_exit_) {  // an early-exit request arrived while testing; leave the loop
        // Break out of the while loop because stop was requested while testing.
        break;
      }
    }
    for (int i = 0; i < callbacks_.size(); ++i) {  // solver callbacks, used to synchronize solvers in multi-GPU training
      callbacks_[i]->on_start();
    }
    const bool display = param_.display() && iter_ % param_.display() == 0;  // a display interval is set and this iteration is due for printing
    net_->set_debug_info(display && param_.debug_info());  // decide whether to print debug info
    // accumulate the loss and gradient
    Dtype loss = 0;
    for (int i = 0; i < param_.iter_size(); ++i) {  // one iteration runs iter_size forward/backward passes
      loss += net_->ForwardBackward();  // run one forward and one backward pass, accumulating the loss over the iter_size passes
    }
    loss /= param_.iter_size();  // average loss of this iteration
    // average the loss across iterations for smoothed reporting
    UpdateSmoothedLoss(loss, start_iter, average_loss);  // store loss in losses_ and recompute the running mean smoothed_loss_
    if (display) {  // this iteration's info should be printed
      float lapse = iteration_timer_.Seconds();  // stop the timer and return the elapsed time in seconds
      float per_s = (iter_ - iterations_last_) / (lapse ? lapse : 1);  // iterations_last_ is the iteration count when the timer was last started; this gives iterations per second
      LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
          << " (" << per_s << " iter/s, " << lapse << "s/"
          << param_.display() << " iters), loss = " << smoothed_loss_;  // print the iteration count, speed, elapsed time, etc.
      iteration_timer_.Start();  // restart the timer
      iterations_last_ = iter_;  // remember the current iteration count
      const vector<Blob<Dtype>*>& result = net_->output_blobs();  // all output blobs of the train net
      int score_index = 0;
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();  // data_ of the j-th output blob
        const string& output_name = net_->blob_names()[net_->output_blob_indices()[j]];  // name of this output blob
        const Dtype loss_weight = net_->blob_loss_weights()[net_->output_blob_indices()[j]];  // loss weight of this output blob
        for (int k = 0; k < result[j]->count(); ++k) {
          ostringstream loss_msg_stream;
          if (loss_weight) {  // if the weight is non-zero, show the weight and the weighted output value
            loss_msg_stream << " (* " << loss_weight
                << " = " << loss_weight * result_vec[k] << " loss)";
          }
          LOG_IF(INFO, Caffe::root_solver()) << "    Train net output #"
              << score_index++ << ": " << output_name << " = "
              << result_vec[k] << loss_msg_stream.str();  // print the info
        }
      }
    }
    // Solver callbacks invoked after the gradients have been computed; likewise used to
    // synchronize gradient data in multi-GPU training
    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_gradients_ready();
    }
    ApplyUpdate();  // compute the final update from the learning rate, momentum, weight decay, etc. and update the net's parameters; implemented in SGDSolver
    SolverAction::Enum request = GetRequestedAction();  // get the currently requested solver action
    // Save a snapshot if needed.
    if ((param_.snapshot()
         && iter_ % param_.snapshot() == 0
         && Caffe::root_solver()) ||
        (request == SolverAction::SNAPSHOT)) {  // a snapshot is due at this iteration, or a snapshot was requested
      Snapshot();  // write the snapshot files
    }
    if (SolverAction::STOP == request) {  // a stop was requested; exit early
      requested_early_exit_ = true;
      // Break out of training loop.
      break;
    }
  }
}

template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {  // restore the net and solver state from resume_file, then train the net
  CHECK(Caffe::root_solver());  // this runs only in the root solver
  LOG(INFO) << "Solving " << net_->name();
  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();  // print the net name and the learning-rate policy
  // Initialize to false every time we start solving.
  requested_early_exit_ = false;  // reset the flag at the start of every solve
  if (resume_file) {  // a file name was given
    LOG(INFO) << "Restoring previous solver status from " << resume_file;
    Restore(resume_file);  // restore the net parameters and the solver state from the file
  }
  // For a network that is trained by the solver, no bottom or top vecs
  // should be given, and we will just provide dummy vecs.
  int start_iter = iter_;  // iterations completed so far
  Step(param_.max_iter() - iter_);  // max_iter is the maximum iteration count; compute how many iterations remain
  // If we haven't already, save a snapshot after optimization, unless
  // overridden by setting snapshot_after_train := false
  if (param_.snapshot_after_train()
      && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
    // Save a snapshot after training if so configured and the current iteration is not itself
    // a snapshot iteration; when param_.snapshot() && iter_ % param_.snapshot() == 0 holds,
    // Step() already saved a snapshot for this iter_, so none is needed here.
    Snapshot();
  }
  if (requested_early_exit_) {  // again check for an early-exit request
    LOG(INFO) << "Optimization stopped early.";
    return;
  }
  // After the optimization is done, run an additional train and test pass to
  // display the train and test loss/outputs if appropriate (based on the
  // display and test_interval settings, respectively). Unlike in the rest of
  // training, for the train net we only run a forward pass as we've already
  // updated the parameters "max_iter" times -- this final pass is only done to
  // display the loss, which is computed in the forward pass.
  // If display is due, one extra forward pass is run. This differs from the last iteration in
  // Step(), which runs both forward and backward passes and then updates the parameters; the
  // loss under those updated parameters is not yet known, so a final forward pass computes it.
  if (param_.display() && iter_ % param_.display() == 0) {  // display is enabled and this iteration is due for printing
    int average_loss = this->param_.average_loss();  // length of the loss-smoothing window
    Dtype loss;
    net_->Forward(&loss);  // one forward pass
    UpdateSmoothedLoss(loss, start_iter, average_loss);  // update losses_ and the smoothed loss
    LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;  // print the info
  }
  if (param_.test_interval() && iter_ % param_.test_interval() == 0) {  // a test interval is set and this iteration is due for testing
    TestAll();  // run all test nets
  }
  LOG(INFO) << "Optimization Done.";  // the solver has finished
}

template <typename Dtype>
void Solver<Dtype>::TestAll() {  // run all test nets
  for (int test_net_id = 0;
       test_net_id < test_nets_.size() && !requested_early_exit_;  // unless an early exit was requested
       ++test_net_id) {
    Test(test_net_id);  // run the test_net_id-th test net
  }
}

template <typename Dtype>
void Solver<Dtype>::Test(const int test_net_id) {  // run the test_net_id-th test net
  CHECK(Caffe::root_solver());  // test nets run only in the root solver
  LOG(INFO) << "Iteration " << iter_
      << ", Testing net (#" << test_net_id << ")";  // print the iteration and the test net's id
  // Share the train net net_'s parameter blobs with this test net: only the test net's data
  // pointers are redirected; no data is copied
  CHECK_NOTNULL(test_nets_[test_net_id].get())->ShareTrainedLayersWith(net_.get());
  vector<Dtype> test_score;
  vector<int> test_score_output_id;
  const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];  // the current test net
  Dtype loss = 0;
  // test_iter(test_net_id) is the number of iterations the test_net_id-th test net runs during testing
  for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
    SolverAction::Enum request = GetRequestedAction();  // get the currently requested solver action
    // Check to see if stoppage of testing/training has been requested.
    while (request != SolverAction::NONE) {  // if not NONE, perform the corresponding action
      if (SolverAction::SNAPSHOT == request) {  // take a snapshot and keep training
        Snapshot();  // write the snapshot files, then continue with the current operation
      } else if (SolverAction::STOP == request) {  // exit early
        requested_early_exit_ = true;
      }
      request = GetRequestedAction();
    }
    if (requested_early_exit_) {  // exit; skip everything that follows
      // break out of test loop.
      break;
    }
    Dtype iter_loss;
    // Run one forward pass of test_net; the loss goes into iter_loss, and result holds the
    // net's output blobs (net_output_blobs_)
    const vector<Blob<Dtype>*>& result = test_net->Forward(&iter_loss);
    if (param_.test_compute_loss()) {  // whether to compute the test net's average loss
      loss += iter_loss;  // accumulate the per-pass losses
    }
    if (i == 0) {  // on the first pass, size test_score and test_score_output_id
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();  // data_ of the j-th output blob
        for (int k = 0; k < result[j]->count(); ++k) {
          test_score.push_back(result_vec[k]);  // store every value of the output blob's data in test_score
          test_score_output_id.push_back(j);  // record which output blob each value came from
        }
      }
    } else {
      int idx = 0;
      for (int j = 0; j < result.size(); ++j) {  // for each output blob
        const Dtype* result_vec = result[j]->cpu_data();  // the output blob's data_
        for (int k = 0; k < result[j]->count(); ++k) {
          test_score[idx++] += result_vec[k];  // accumulate the output blob data over the test iterations
        }
      }
    }
  }
  if (requested_early_exit_) {  // early exit?
    LOG(INFO) << "Test interrupted.";
    return;
  }
  if (param_.test_compute_loss()) {  // whether to compute the test net's average loss
    loss /= param_.test_iter(test_net_id);  // mean loss over this test net's test_iter(test_net_id) iterations
    LOG(INFO) << "Test loss: " << loss;
  }
  for (int i = 0; i < test_score.size(); ++i) {
    // test_score[i] came from the blob net_output_blobs_[test_score_output_id[i]];
    // output_blob_index is that blob's index in blobs_
    const int output_blob_index = test_net->output_blob_indices()[test_score_output_id[i]];
    const string& output_name = test_net->blob_names()[output_blob_index];  // the blob's name
    const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];  // the blob's loss weight
    ostringstream loss_msg_stream;
    const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);  // divide by the iteration count to get the mean output value
    if (loss_weight) {  // if the weight is non-zero, show the weight and the weighted value
      loss_msg_stream << " (* " << loss_weight << " = " << loss_weight * mean_score << " loss)";
    }
    LOG(INFO) << "    Test net output #" << i << ": " << output_name << " = "
        << mean_score << loss_msg_stream.str();  // print the mean of every value of every output blob of the test net
  }
}

template <typename Dtype>
void Solver<Dtype>::Snapshot() {  // write two snapshot files: the net parameters (NetParameter) and the solver state (SolverState)
  CHECK(Caffe::root_solver());  // again, snapshots are written only by the root solver
  string model_filename;
  switch (param_.snapshot_format()) {  // the configured snapshot format
    case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:  // binary proto
      model_filename = SnapshotToBinaryProto();  // save the train net's parameters to a ".caffemodel" file; returns its name
      break;
    case caffe::SolverParameter_SnapshotFormat_HDF5:  // HDF5
      model_filename = SnapshotToHDF5();  // write the train net's parameters to a file; returns its name
      break;
    default:
      LOG(FATAL) << "Unsupported snapshot format.";
  }
  SnapshotSolverState(model_filename);  // save the solver state (SolverState) to a file
}

template <typename Dtype>
void Solver<Dtype>::CheckSnapshotWritePermissions() {  // check that snapshot files can be created (only verifies a file can be opened for writing; no data is stored)
  if (Caffe::root_solver() && param_.snapshot()) {  // root solver only
    CHECK(param_.has_snapshot_prefix())
        << "In solver params, snapshot is specified but snapshot_prefix is not";  // check that a snapshot file-name prefix is set
    string probe_filename = SnapshotFilename(".tempfile");  // build a snapshot file name with ".tempfile" as the suffix
    std::ofstream probe_ofs(probe_filename.c_str());  // create the temporary file
    if (probe_ofs.good()) {  // check for errors
      probe_ofs.close();  // close it
      std::remove(probe_filename.c_str());  // delete the file
    } else {
      LOG(FATAL) << "Cannot write to snapshot prefix '"
          << param_.snapshot_prefix() << "'.  Make sure "
          << "that the directory exists and is writable.";  // creation failed; abort with an error
    }
  }
}

// Build the snapshot file name: prefix + "_iter_" + the iteration count as a string + extension
template <typename Dtype>
string Solver<Dtype>::SnapshotFilename(const string& extension) {
  return param_.snapshot_prefix() + "_iter_" + caffe::format_int(iter_)
      + extension;
}

template <typename Dtype>
string Solver<Dtype>::SnapshotToBinaryProto() {  // save the train net's parameters as a binary proto file and return the file name
  string model_filename = SnapshotFilename(".caffemodel");  // build the file name with the ".caffemodel" extension
  LOG(INFO) << "Snapshotting to binary proto file " << model_filename;  // print the info
  NetParameter net_param;
  // Write the parameters of all layers of the train net net_ into net_param;
  // snapshot_diff() says whether gradients are also stored in the snapshot
  net_->ToProto(&net_param, param_.snapshot_diff());
  WriteProtoToBinaryFile(net_param, model_filename);  // write the NetParameter message to the file
  return model_filename;  // return the snapshot file name
}

template <typename Dtype>
string Solver<Dtype>::SnapshotToHDF5() {  // save the train net's parameters to an HDF5 file and return the file name
  string model_filename = SnapshotFilename(".caffemodel.h5");  // the snapshot file name
  LOG(INFO) << "Snapshotting to HDF5 file " << model_filename;  // print
  net_->ToHDF5(model_filename, param_.snapshot_diff());  // write the parameters of net_'s layers to the HDF5 file
  return model_filename;  // return the file name
}

// Restore the net parameters and training state: read the solver state from state_file;
// if it also names a model file with net parameters, those are loaded as well
template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
  string state_filename(state_file);
  if (state_filename.size() >= 3 &&
      state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {  // decide HDF5 vs. proto from the file name; a bit crude
    RestoreSolverStateFromHDF5(state_filename);  // read from the HDF5 file
  } else {
    RestoreSolverStateFromBinaryProto(state_filename);  // read from the binary proto file
  }
}

// start_iter is the iteration count at the start of Step().
// losses_ holds the loss history: while iter_ < start_iter + average_loss it grows by one
// entry per iteration, and once it holds average_loss entries it stops growing. After that,
// new loss values overwrite the stored ones cyclically, front to back.
template <typename Dtype>
void Solver<Dtype>::UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss) {
  if (losses_.size() < average_loss) {  // fewer entries than the window size; losses_ still grows
    losses_.push_back(loss);  // store the loss
    int size = losses_.size();
    // smoothed_loss_ was the mean of losses_ before the new value was added; update the mean
    smoothed_loss_ = (smoothed_loss_ * (size - 1) + loss) / size;
  } else {
    int idx = (iter_ - start_iter) % average_loss;  // slot in losses_ for the loss of iteration iter_
    smoothed_loss_ += (loss - losses_[idx]) / average_loss;  // update the mean first, then store the value
    losses_[idx] = loss;
  }
}
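To make the sliding-window arithmetic in UpdateSmoothedLoss() concrete, here is a small standalone sketch of the same logic (variable names are mine, not Caffe's) with a window of 3, plus the numbers it prints:

    #include <iostream>
    #include <vector>

    // Standalone re-implementation of the smoothing in Solver::UpdateSmoothedLoss().
    int main() {
      const int average_loss = 3;      // window length, as param_.average_loss()
      std::vector<float> losses;       // circular buffer, as losses_
      float smoothed = 0.0f;           // running mean, as smoothed_loss_

      const float incoming[] = {2.0f, 4.0f, 6.0f, 8.0f};  // per-iteration losses
      for (int iter = 0; iter < 4; ++iter) {
        const float loss = incoming[iter];
        if (losses.size() < average_loss) {  // window not yet full: grow it
          losses.push_back(loss);
          const int size = losses.size();
          smoothed = (smoothed * (size - 1) + loss) / size;
        } else {  // window full: the new value cyclically replaces the oldest one
          const int idx = iter % average_loss;
          smoothed += (loss - losses[idx]) / average_loss;
          losses[idx] = loss;
        }
        std::cout << "iter " << iter << ": smoothed = " << smoothed << std::endl;
      }
      return 0;
    }

This prints 2, 3, 4, 6: at iteration 3 the window is full, so the oldest loss (2) drops out, 8 takes its slot, and the mean of {8, 4, 6} is 6.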

solver.hpp Source Code

/**
 * @brief Enumeration of actions that a client of the Solver may request by
 * implementing the Solver's action request function, which a
 * client may optionally provide in order to request early termination
 * or saving a snapshot without exiting. In the executable caffe, this
 * mechanism is used to allow the snapshot to be saved when stopping
 * execution with a SIGINT (Ctrl-C).
 */
namespace SolverAction {
  enum Enum {
    NONE = 0,     // Take no special action.
    STOP = 1,     // Stop training. snapshot_after_train controls whether a
                  // snapshot is created. (stop and exit early)
    SNAPSHOT = 2  // Take a snapshot, and keep training. (save the current train net's
                  // parameters to a snapshot file, then continue)
  };
}

/**
 * @brief Type of a function that returns a Solver Action enumeration.
 */
typedef boost::function<SolverAction::Enum()> ActionCallback;

/**
 * @brief An interface for classes that perform optimization on Net%s.
 *
 * Requires implementation of ApplyUpdate to compute a parameter update
 * given the current state of the Net parameters.
 */
template <typename Dtype>
class Solver {
 public:
  explicit Solver(const SolverParameter& param);
  explicit Solver(const string& param_file);
  void Init(const SolverParameter& param);
  void InitTrainNet();
  void InitTestNets();
  // Client of the Solver optionally may call this in order to set the function
  // that the solver uses to see what action it should take (e.g. snapshot or
  // exit training early).
  void SetActionFunction(ActionCallback func);  // set the callback that reports the requested solver action
  SolverAction::Enum GetRequestedAction();
  // The main entry of the solver function. In default, iter will be zero. Pass
  // in a non-zero iter number to resume training for a pre-trained net.
  virtual void Solve(const char* resume_file = NULL);
  inline void Solve(const string& resume_file) { Solve(resume_file.c_str()); }
  void Step(int iters);
  // The Restore method simply dispatches to one of the
  // RestoreSolverStateFrom___ protected methods. You should implement these
  // methods to restore the state from the appropriate snapshot type.
  void Restore(const char* resume_file);
  // The Solver::Snapshot function implements the basic snapshotting utility
  // that stores the learned net. You should implement the SnapshotSolverState()
  // function that produces a SolverState protocol buffer that needs to be
  // written to disk together with the learned net.
  void Snapshot();
  virtual ~Solver() {}
  inline const SolverParameter& param() const { return param_; }
  inline shared_ptr<Net<Dtype> > net() { return net_; }
  inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
    return test_nets_;
  }
  int iter() const { return iter_; }
  // Invoked at specific points during an iteration
  // Callback class invoked during iterations; its two functions are used for
  // synchronization in multi-GPU training
  class Callback {
   protected:
    virtual void on_start() = 0;
    virtual void on_gradients_ready() = 0;
    template <typename T>
    friend class Solver;
  };
  const vector<Callback*>& callbacks() const { return callbacks_; }
  void add_callback(Callback* value) {
    callbacks_.push_back(value);  // append
  }
  void CheckSnapshotWritePermissions();
  /**
   * @brief Returns the solver type.
   */
  virtual inline const char* type() const { return ""; }
  // Make and apply the update value for the current iteration.
  virtual void ApplyUpdate() = 0;

 protected:
  string SnapshotFilename(const string& extension);
  string SnapshotToBinaryProto();
  string SnapshotToHDF5();
  // The test routine
  void TestAll();
  void Test(const int test_net_id = 0);
  virtual void SnapshotSolverState(const string& model_filename) = 0;
  virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
  virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
  void DisplayOutputBlobs(const int net_id);
  void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);

  SolverParameter param_;
  int iter_;                                    // current iteration count
  int current_step_;                            // current stage index, used by the "step" and "multistep" learning-rate policies
  shared_ptr<Net<Dtype> > net_;                 // the train net
  vector<shared_ptr<Net<Dtype> > > test_nets_;  // all test nets
  vector<Callback*> callbacks_;                 // callbacks
  vector<Dtype> losses_;                        // loss values of the most recent average_loss iterations
  Dtype smoothed_loss_;                         // mean of losses_
  // A function that can be set by a client of the Solver to provide indication
  // that it wants a snapshot saved and/or to exit early.
  ActionCallback action_request_function_;      // callback whose return value is the requested solver action
  // True iff a request to stop early was received.
  bool requested_early_exit_;                   // whether an early exit was requested
  // Timing information, handy to tune e.g. nbr of GPUs
  Timer iteration_timer_;                       // timer
  float iterations_last_;                       // value of iter_ when the timer was last started
  DISABLE_COPY_AND_ASSIGN(Solver);
};
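Since ApplyUpdate(), SnapshotSolverState(), and the two RestoreSolverStateFrom...() methods are pure virtual, every concrete solver must supply them; in Caffe this is done by SGDSolver and its descendants. The hypothetical skeleton below only illustrates which hooks a subclass has to fill in; it is not a working optimizer (the "update" ignores the learning rate entirely):

    // Hypothetical skeleton of a Solver subclass, showing the pure-virtual hooks
    // a concrete solver must implement (SGDSolver is the real example in Caffe).
    template <typename Dtype>
    class ToySolver : public caffe::Solver<Dtype> {
     public:
      explicit ToySolver(const caffe::SolverParameter& param)
          : caffe::Solver<Dtype>(param) {}
      virtual inline const char* type() const { return "Toy"; }

     protected:
      // Called once per iteration by Step(), after the gradients are ready:
      // turn the diffs into an update and apply it to the parameters.
      virtual void ApplyUpdate() {
        this->net_->Update();  // data -= diff; purely illustrative, no learning rate
        ++this->iter_;         // SGDSolver::ApplyUpdate() also increments iter_ here
      }
      // Persist/restore solver-side state (iteration count, momentum history, ...).
      virtual void SnapshotSolverState(const std::string& model_filename) {}
      virtual void RestoreSolverStateFromHDF5(const std::string& state_file) {}
      virtual void RestoreSolverStateFromBinaryProto(const std::string& state_file) {}
    };

In Caffe proper, such a subclass would also be registered with the REGISTER_SOLVER_CLASS macro so that SolverRegistry can create it by its type name.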

Summary

  1. The solver's action callback is set in caffe.cpp, as a function pointer to SignalHandler::CheckForSignals(). When a SIGINT or SIGHUP signal arrives on a Unix system, GotSIGINT()/GotSIGHUP() return the corresponding flag and clear the signal, and SignalHandler::CheckForSignals() then maps that flag to the matching solver action type (NONE/STOP/SNAPSHOT); see signal_handler.cpp for the details. A simplified sketch of this wiring follows the list.
  2. In Step(), ClearParamDiffs() is called to zero the gradients before every iteration's forward/backward pass. This is necessary because Caffe accumulates gradients onto the existing diff data during each backward pass, so the gradients must be cleared manually at every iteration, just as they have to be manually zeroed in PyTorch.
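As promised in item 1, here is a rough sketch of how an action callback can be wired up. It is simplified from the SignalHandler mechanism in caffe.cpp; the flag handling below is illustrative, not the real signal_handler.cpp code:

    #include <csignal>
    #include "caffe/solver.hpp"

    // Simplified illustration of the action-callback mechanism; the real
    // implementation lives in SignalHandler (see signal_handler.cpp).
    static volatile sig_atomic_t got_sigint = 0;

    static void HandleSignal(int signum) {
      if (signum == SIGINT) { got_sigint = 1; }
    }

    // Matches the ActionCallback signature: boost::function<SolverAction::Enum()>.
    static caffe::SolverAction::Enum CheckForSignals() {
      if (got_sigint) {
        got_sigint = 0;                    // clear the flag, as GotSIGINT() does
        return caffe::SolverAction::STOP;  // Ctrl-C requests an early stop
      }
      return caffe::SolverAction::NONE;
    }

    void WireUp(caffe::Solver<float>* solver) {
      std::signal(SIGINT, HandleSignal);
      solver->SetActionFunction(CheckForSignals);  // polled via GetRequestedAction()
    }

The solver never blocks on signals: Step() and Test() simply poll GetRequestedAction() once per iteration, which is why a Ctrl-C takes effect at the next iteration boundary rather than immediately.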

This is my first read through the Caffe source code, and these notes were written as I went, so the understanding and analysis may contain mistakes or omissions. Corrections from readers are very welcome; thank you for your support!
