caffe Solve函数
下面来看Solver<Dtype>::Solve(const char* resume_file)
solver.cpp
template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
CHECK(Caffe::root_solver());
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy(); // Initialize to false every time we start solving.
requested_early_exit_ = false; if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
// 从以前中断的训练状态中恢复训练
Restore(resume_file);
} // For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
int start_iter = iter_;
// 主要的迭代过程都在这里
Step(param_.max_iter() - iter_);
// If we haven't already, save a snapshot after optimization, unless
// overridden by setting snapshot_after_train := false
if (param_.snapshot_after_train()
&& (!param_.snapshot() || iter_ % param_.snapshot() != )) {
Snapshot();
}
if (requested_early_exit_) {
LOG(INFO) << "Optimization stopped early.";
return;
}
// After the optimization is done, run an additional train and test pass to
// display the train and test loss/outputs if appropriate (based on the
// display and test_interval settings, respectively). Unlike in the rest of
// training, for the train net we only run a forward pass as we've already
// updated the parameters "max_iter" times -- this final pass is only done to
// display the loss, which is computed in the forward pass.
if (param_.display() && iter_ % param_.display() == ) {
int average_loss = this->param_.average_loss();
Dtype loss;
net_->Forward(&loss); UpdateSmoothedLoss(loss, start_iter, average_loss); LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
}
if (param_.test_interval() && iter_ % param_.test_interval() == ) {
TestAll();
}
LOG(INFO) << "Optimization Done.";
}
下面先看Solve中的Restore(resume_file)
solver.cpp
template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
string state_filename(state_file);
if (state_filename.size() >= &&
state_filename.compare(state_filename.size() - , , ".h5") == ) {
RestoreSolverStateFromHDF5(state_filename);
} else {
RestoreSolverStateFromBinaryProto(state_filename);
}
}
上面的RestoreSolverStateFromHDF5(state_filename)和RestoreSolverStateFromBinaryProto(state_filename)都是虚函数,调用的其实是其派生类的同名方法。例如,若使用SGD求解,SGDSolver类中的RestoreSolverStateFromBinaryProto方法如下
sgd_solver.cpp
template <typename Dtype>
void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
const string& state_file) {
SolverState state;
ReadProtoFromBinaryFile(state_file, &state);
// 此处获取上次训练中断时的迭代次数
this->iter_ = state.iter();
if (state.has_learned_net()) {
NetParameter net_param;
ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
this->net_->CopyTrainedLayersFrom(net_param);
}
this->current_step_ = state.current_step();
CHECK_EQ(state.history_size(), history_.size())
<< "Incorrect length of history blobs.";
LOG(INFO) << "SGDSolver: restoring history";
for (int i = ; i < history_.size(); ++i) {
history_[i]->FromProto(state.history(i));
}
}
下面主要分析Solve中的Step(param_.max_iter() - iter_)
solver.cpp
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
const int start_iter = iter_;
const int stop_iter = iter_ + iters;
int average_loss = this->param_.average_loss();
losses_.clear();
smoothed_loss_ = ;
iteration_timer_.Start(); while (iter_ < stop_iter) {
// zero-init the params
// 将网络中参数的梯度清零
net_->ClearParamDiffs();
if (param_.test_interval() && iter_ % param_.test_interval() ==
&& (iter_ > || param_.test_initialization())) {
if (Caffe::root_solver()) {
TestAll();
}
if (requested_early_exit_) {
// Break out of the while loop because stop was requested while testing.
break;
}
} for (int i = ; i < callbacks_.size(); ++i) {
callbacks_[i]->on_start();
}
const bool display = param_.display() && iter_ % param_.display() == ;
net_->set_debug_info(display && param_.debug_info());
// accumulate the loss and gradient
Dtype loss = ;
// param.iter_size_默认是1,正常情况下,此处其实只进行了以次前向和反向传播
for (int i = ; i < param_.iter_size(); ++i) {
loss += net_->ForwardBackward();
}
loss /= param_.iter_size();
// average the loss across iterations for smoothed reporting
UpdateSmoothedLoss(loss, start_iter, average_loss);
if (display) {
float lapse = iteration_timer_.Seconds();
float per_s = (iter_ - iterations_last_) / (lapse ? lapse : );
LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
<< " (" << per_s << " iter/s, " << lapse << "s/"
<< param_.display() << " iters), loss = " << smoothed_loss_;
iteration_timer_.Start();
iterations_last_ = iter_;
const vector<Blob<Dtype>*>& result = net_->output_blobs();
int score_index = ;
for (int j = ; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data();
const string& output_name =
net_->blob_names()[net_->output_blob_indices()[j]];
const Dtype loss_weight =
net_->blob_loss_weights()[net_->output_blob_indices()[j]];
for (int k = ; k < result[j]->count(); ++k) {
ostringstream loss_msg_stream;
if (loss_weight) {
loss_msg_stream << " (* " << loss_weight
<< " = " << loss_weight * result_vec[k] << " loss)";
}
LOG_IF(INFO, Caffe::root_solver()) << " Train net output #"
<< score_index++ << ": " << output_name << " = "
<< result_vec[k] << loss_msg_stream.str();
}
}
}
for (int i = ; i < callbacks_.size(); ++i) {
callbacks_[i]->on_gradients_ready();
}
// 网络的参数在此处更新。该函数是一个虚函数,具体由Solver的派生类来实现
ApplyUpdate(); // Increment the internal iter_ counter -- its value should always indicate
// the number of times the weights have been updated.
// 每次迭代其实是一个batch_size个样本输入网络中,将它们产生的网络参数的梯度加起来作为一次迭代的参数梯度。然后用这个梯度跟据一定的正则化方法、参数更新策略来更新参数
++iter_; SolverAction::Enum request = GetRequestedAction(); // Save a snapshot if needed.
if ((param_.snapshot()
&& iter_ % param_.snapshot() ==
&& Caffe::root_solver()) ||
(request == SolverAction::SNAPSHOT)) {
Snapshot();
}
if (SolverAction::STOP == request) {
requested_early_exit_ = true;
// Break out of training loop.
break;
}
}
}
上面的loss += net_->ForwardBackward()是训练过程的核心。这行代码的功能是取一个batch_size数据,让其在网络中进行一次前向传播,得出损失的均值;再进行一次反向传播,得出网络参数的梯度(该梯度是一个batch_size数据产生的梯度的均值)。详细分析见下一章节。
caffe Solve函数的更多相关文章
- MATLAB利用solve函数解多元一次方程组
matlab求解多元方程组示例: syms k1 k2 k3; [k1 k2 k3] = solve(-3-k3==6, 2-k1-k2+2*k3==11, 2*k1+k2-k3+1==6)或者用[k ...
- pycaffe︱caffe中fine-tuning模型三重天(函数详解、框架简述)
本文主要参考caffe官方文档[<Fine-tuning a Pretrained Network for Style Recognition>](http://nbviewer.jupy ...
- 非线性方程(组):MATLAB内置函数 solve, vpasolve, fsolve, fzero, roots [MATLAB]
MATLAB函数 solve, vpasolve, fsolve, fzero, roots 功能和信息概览 求解函数 多项式型 非多项式型 一维 高维 符号 数值 算法 solve 支持,得到全部符 ...
- Matlab的solve()函数的使用方法
Matlab的solve()函数的使用方法 1.首先是对方程的求解 不废话直接上例子 syms x: eq=x^2+2*x+1; s=solve(eq,x); 结果如下 完美的算出了方程的解 现在对上 ...
- 从零开始山寨Caffe·捌:IO系统(二)
生产者 双缓冲组与信号量机制 在第陆章中提到了,如何模拟,以及取代根本不存的Q.full()函数. 其本质是:除了为生产者提供一个成品缓冲队列,还提供一个零件缓冲队列. 当我们从外部给定了固定容量的零 ...
- Caffe 源碼閱讀(五) Solver.cpp
1.Solver类两个构造函数 Solver(const SolverParameter& param) Solver(const string& param_file) 初始化两个类 ...
- caffe简单介绍
从四个层次来理解caffe:Blob.Layer.Net.Solver. 1.BlobBlob是caffe基本的数据结构,用四维矩阵 Batch×Channel×Height×Weight表示,存储了 ...
- caffe源码阅读(1)_整体框架和简介(摘录)
原文链接:https://www.zhihu.com/question/27982282 1.Caffe代码层次.回答里面有人说熟悉Blob,Layer,Net,Solver这样的几大类,我比较赞同. ...
- caffe源码学习
本文转载自:https://buptldy.github.io/2016/10/09/2016-10-09-Caffe_Code/ Caffe简介 Caffe作为一个优秀的深度学习框架网上已经有很多内 ...
随机推荐
- ZOJ 2314 无源汇可行流(输出方案)
Time Limit: 5 Seconds Memory Limit: 32768 KB Special Judge The terrorist group leaded by a ...
- 【转】iPhone获取状态栏和导航栏尺寸(宽度和高度)
原文网址:http://blog.csdn.net/chadeltu/article/details/42708605 iPhone开发当中,有时需要获取状态栏和导航栏高度.宽度信息,方便布局其他控件 ...
- C语言程序读写文件(文件内存一个十进制数,每读一次数值加一)
1.问题:C语言程序实现读写一个txt文件,txt文件中存储一个十进制数.每读一次该数值加一. 2.实现:新建一个文件夹,在该文件夹中建一个outputFileName.txt文件.内容是:1,再在该 ...
- [Apple开发者帐户帮助]五、管理标识符(5)创建一个iCloud容器
您必须拥有一个或多个iCloud容器才能启用iCloud. 所需角色:帐户持有人或管理员. 在“ 证书”,“标识符和配置文件”中,从左侧的弹出菜单中选择操作系统. 在“标识符”下,选择“iCloud容 ...
- SDOI 2018 round2游记
Day 0 早上起来从北京到济南 住宿环境不错 不过比赛环境怎么这么low啊 而且我在最偏僻的考场中最偏僻的角落里 身边居然是个妹子?! Day1 7:40到的考试地点 发现诸位大佬已经打完板子了or ...
- 利用MediaSession发送信息到蓝牙音箱
1.利用MediaSession发送信息到蓝牙音箱,如:播放音乐时接收的歌曲信息,但是每一首歌连续播放时,再次发送的重复信息会被丢弃.则利用MediaSession发现信息时,要保证信息的不重复性. ...
- tp实现多语言支持测试
用tp框架实现网页多种语言切换 时间:2016-11-11 浏览次数:1120 编辑:youjiejie 网页如何设计多种语言切换,本文用tp框架实现网页多种语言切换方法结合实例形式较为详细的分析 ...
- Spring AOP之静态代理
软件151 李飞瑶 一.SpringAOP: ⒈AOP:Aspect Oriented Programming 面向切面编程, 实现的是核心业务和非核心业务之间的的分离,让核心类只做核心业务,代理类只 ...
- 西门子Step7中DB块结构导出
Step7 通过变量表可以导出内存M地址和I,Q,T,C地址的变量,以及DB块的名称.怎么导出DB块的内部结构结构呢.即如何导出结构内的定义呢? 可以通过“选择某个DB块”,通过菜单命令“File&g ...
- 进行https通信时服务器端下发的是一个证书链
进行https通信时服务器端下发的是一个证书链,否则无法验证证书的有效性.