Caffe源码-Solver类
Solver类简介
Net类中实现了网络的前向/反向计算和参数更新,而Solver类中则是对此进行进一步封装,包含可用于逐次训练网络的Step()
函数,和用于求解网络的优化解的Solve()
函数,同时还实现了一些存储、读取网络模型快照的接口函数。
solver.cpp源码
template<typename Dtype>
void Solver<Dtype>::SetActionFunction(ActionCallback func) {
action_request_function_ = func; //设置回调函数,该函数会返回求解器的动作类型
}
template<typename Dtype>
SolverAction::Enum Solver<Dtype>::GetRequestedAction() { //返回求解器的动作类型
if (action_request_function_) {
// If the external request function has been set, call it.
return action_request_function_(); //运行回调函数,该函数会返回求解器的动作类型
}
return SolverAction::NONE;
}
template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param) //构造函数,使用param消息初始化求解器
: net_(), callbacks_(), requested_early_exit_(false) {
Init(param); //使用param消息初始化当前求解器
}
template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file)
: net_(), callbacks_(), requested_early_exit_(false) { //构造函数,从文本类型的proto文件中读取求解器参数
SolverParameter param;
ReadSolverParamsFromTextFileOrDie(param_file, ¶m); //从param_file中读取消息数据到param中
Init(param); //初始化求解器
}
template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) { //Solver类初始化
LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
<< std::endl << param.DebugString(); //主线程中打印信息
param_ = param;
//loss的滑动平均窗的长度,每次计算最近average_loss_次的平均loss
CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
CheckSnapshotWritePermissions(); //检查是否能够打开快照文件
if (param_.random_seed() >= 0) { //SolverParameter消息中设置了随机种子
Caffe::set_random_seed(param_.random_seed() + Caffe::solver_rank()); //设置
}
// Scaffolding code
InitTrainNet(); //初始化训练网络
InitTestNets(); //初始化所有测试网络 //训练网络只有一个,但是测试网络可以有多个
if (Caffe::root_solver()) {
LOG(INFO) << "Solver scaffolding done."; //只在主线程中打印
}
iter_ = 0; //初始化参数
current_step_ = 0;
}
// Load weights from the caffemodel(s) specified in "weights" solver parameter
// into the train and test nets.
template <typename Dtype>
void LoadNetWeights(shared_ptr<Net<Dtype> > net, const std::string& model_list) { //加载权重文件
std::vector<std::string> model_names;
boost::split(model_names, model_list, boost::is_any_of(",")); //拆分文件名,权重文件名在model_list中以","中分隔开
for (int i = 0; i < model_names.size(); ++i) {
boost::trim(model_names[i]); //删除首位空格
LOG(INFO) << "Finetuning from " << model_names[i]; //打印权重文件名
net->CopyTrainedLayersFrom(model_names[i]); //从文件中拷贝blob数据到网络的同名参数中
}
}
template <typename Dtype>
void Solver<Dtype>::InitTrainNet() { //初始化训练网络,配置网络参数,加载预训练模型
//训练网络的proto文件名可通过SolverParameter消息中的train_net_param, train_net, net_param, net四个中的任意一个指定
const int num_train_nets = param_.has_net() + param_.has_net_param() +
param_.has_train_net() + param_.has_train_net_param(); //这四个参数总共设置的训练网络个数
const string field_names = "net, net_param, train_net, train_net_param";
CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
<< "using one of these fields: " << field_names; //检查是否大于等于1
CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
<< "one of these fields specifying a train_net: " << field_names; //检查是否小于等于1 //四个中只能有一个设置了true
NetParameter net_param;
if (param_.has_train_net_param()) { //训练网络的名称在train_net_param中设置了
LOG_IF(INFO, Caffe::root_solver()) << "Creating training net specified in train_net_param."; //主线程中打印
net_param.CopyFrom(param_.train_net_param()); //从NetParameter消息中拷贝网络参数至net_param
} else if (param_.has_train_net()) { //训练网络的名称在train_net中设置了
LOG_IF(INFO, Caffe::root_solver()) << "Creating training net from train_net file: " << param_.train_net();
ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param); //从proto文件中读取网络参数
}
if (param_.has_net_param()) { //训练网络的名称在net_param中设置了
LOG_IF(INFO, Caffe::root_solver()) << "Creating training net specified in net_param.";
net_param.CopyFrom(param_.net_param()); //从NetParameter类型的消息中拷贝网络参数
}
if (param_.has_net()) { //训练网络的名称在net中设置了
LOG_IF(INFO, Caffe::root_solver()) << "Creating training net from net file: " << param_.net();
ReadNetParamsFromTextFileOrDie(param_.net(), &net_param); //从proto文件中读取网络参数
}
// Set the correct NetState. We start with the solver defaults (lowest
// precedence); then, merge in any NetState specified by the net_param itself;
// finally, merge in any NetState specified by the train_state (highest
// precedence).
//Message::MergeFrom()的机制,单字段的值会被覆盖,嵌套消息的值会被融合在一起,重复字段的值会被拼接在一起
//Message::CopyFrom()的机制,清空当前的消息,然后将指定消息MergeFrom()到当前消息中
//net_param中的状态值先是设置为默认值,然后使用从上面四个设置中读取到的网络参数net_param中的网络状态覆盖其中相同的,
//再用当前求解器中设置的SolverParameter消息中的train_state覆盖其中相同的.
//在网络中设置的网络状态优先级低,会被求解器中设置的网络状态覆盖
NetState net_state;
net_state.set_phase(TRAIN); //设置网络的状态,训练模式
net_state.MergeFrom(net_param.state()); //先使用上面的从文件或者消息中读取的网络参数中的网络状态
net_state.MergeFrom(param_.train_state()); //再使用当前求解器中设置的训练网络状态
net_param.mutable_state()->CopyFrom(net_state); //将最终的到的网络状态存入网络参数中
net_.reset(new Net<Dtype>(net_param)); //使用该网络参数初始化网络,存入net_中
for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) { //weights参数的个数
LoadNetWeights(net_, param_.weights(w_idx)); //加载每个参数中的一个或者多个预训练模型到net_中
}
}
template <typename Dtype>
void Solver<Dtype>::InitTestNets() { //初始化测试网络
const bool has_net_param = param_.has_net_param();
const bool has_net_file = param_.has_net();
const int num_generic_nets = has_net_param + has_net_file; //是否设置了模型参数,是否设置了模型文件名
CHECK_LE(num_generic_nets, 1)
<< "Both net_param and net_file may not be specified."; //检查是否小于等于1,这两个不能同时指定
const int num_test_net_params = param_.test_net_param_size(); //设置的测试网络的参数的个数
const int num_test_net_files = param_.test_net_size(); //设置的测试网络的个数
const int num_test_nets = num_test_net_params + num_test_net_files; //总个数
if (num_generic_nets) {
//test_iter_表示每个测试网络迭代的次数,test_iter_参数设置的个数必须与测试网络的个数相等
//如果设置了模型参数或者模型文件名,那么这里面也可能设置了test net,所以test_iter_的个数必须大于等于num_test_nets
CHECK_GE(param_.test_iter_size(), num_test_nets)
<< "test_iter must be specified for each test network.";
} else {
//没有设置net_parma或者net的话,test net全部在test_net_parma和test_net中指定,个数需相等
CHECK_EQ(param_.test_iter_size(), num_test_nets)
<< "test_iter must be specified for each test network.";
}
// If we have a generic net (specified by net or net_param, rather than
// test_net or test_net_param), we may have an unlimited number of actual
// test networks -- the actual number is given by the number of remaining
// test_iters after any test nets specified by test_net_param and/or test_net
// are evaluated.
const int num_generic_net_instances = param_.test_iter_size() - num_test_nets; //相减得到在net_parma或者net中定义的test net的个数
const int num_test_net_instances = num_test_nets + num_generic_net_instances; //总的test net的个数,即为param_.test_iter_size()
if (param_.test_state_size()) { //设置了test_state_,则个数必须与测试网络的个数相等
CHECK_EQ(param_.test_state_size(), num_test_net_instances)
<< "test_state must be unspecified or specified once per test net."; //检查个数是否相等
}
if (num_test_net_instances) {
CHECK_GT(param_.test_interval(), 0); //检查设置的测试的迭代间隔是否大于0
}
int test_net_id = 0;
vector<string> sources(num_test_net_instances);
vector<NetParameter> net_params(num_test_net_instances);
//caffe.proto文件中注明了test net运行的优先级,(1) test_net_param, (2) test_net, (3) net_param/net.
for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) {
sources[test_net_id] = "test_net_param"; //保存定义该测试网络的来源
net_params[test_net_id].CopyFrom(param_.test_net_param(i)); //从NetParameter类型的消息中拷贝网络参数
}
for (int i = 0; i < num_test_net_files; ++i, ++test_net_id) {
sources[test_net_id] = "test_net file: " + param_.test_net(i); //保存来源,加上文件名
ReadNetParamsFromTextFileOrDie(param_.test_net(i),
&net_params[test_net_id]); //从proto文件中读取网络参数,存入net_param中
}
const int remaining_test_nets = param_.test_iter_size() - test_net_id; //net_param/net中定义的网络的个数
if (has_net_param) { //定义了net_param,则剩余的测试网络都定义在此处
for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
sources[test_net_id] = "net_param";
net_params[test_net_id].CopyFrom(param_.net_param()); //拷贝网络参数
}
}
if (has_net_file) { //同样,从net文件中定义的测试网络文件名中读取网络参数
for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
sources[test_net_id] = "net file: " + param_.net();
ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);
}
}
test_nets_.resize(num_test_net_instances); //调整大小
for (int i = 0; i < num_test_net_instances; ++i) {
// Set the correct NetState. We start with the solver defaults (lowest
// precedence); then, merge in any NetState specified by the net_param
// itself; finally, merge in any NetState specified by the test_state
// (highest precedence).
//与InitTrainNet()中的操作类似,先使用默认值,然后使用网络参数中的网络状态覆盖默认值,再使用
//求解器中设置的测试网络状态覆盖之前的值,得到最终的测试网络状态
NetState net_state;
net_state.set_phase(TEST); //设置模式为test
net_state.MergeFrom(net_params[i].state()); //先使用网络参数中设置的网络状态覆盖
if (param_.test_state_size()) {
net_state.MergeFrom(param_.test_state(i)); //然后使用求解器中设置的测试网络状态覆盖
}
net_params[i].mutable_state()->CopyFrom(net_state); //将最终的测试网络状态存入net_params[i]中
LOG(INFO) << "Creating test net (#" << i << ") specified by " << sources[i]; //打印之前保存的来源信息
test_nets_[i].reset(new Net<Dtype>(net_params[i])); //使用net_params[i]创建网络,存入test_nets_中
test_nets_[i]->set_debug_info(param_.debug_info()); //将求解器的是否打印信息的设置存入网络中
for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {
LoadNetWeights(test_nets_[i], param_.weights(w_idx)); //加载预训练模型文件,每个测试网络都会尝试加载所有的预训练模型文件
}
}
}
//求解器单步迭代iters次
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
const int start_iter = iter_; //当前已迭代的次数
const int stop_iter = iter_ + iters; //终止迭代时的次数
int average_loss = this->param_.average_loss(); //loss的滑动平均窗的长度
losses_.clear(); //清空历史loss值
smoothed_loss_ = 0; //清空
iteration_timer_.Start(); //打开计时器
while (iter_ < stop_iter) {
// zero-init the params
net_->ClearParamDiffs(); //清空网络中所有可学习参数的梯度数据
if (param_.test_interval() && iter_ % param_.test_interval() == 0 //两次测试之间的迭代间隔不为0,且当前轮到测试
&& (iter_ > 0 || param_.test_initialization())) { //初始时可以进入测试模式
//test_initialization()仅仅用于表示初始(iter_==0)时是否运行一次测试网络
//该值为true时,(iter_ % test_interval == 0)总是成立,每次开始迭代时都会先进入测试模式.该值为false时只在iter_ > 0时进入测试
if (Caffe::root_solver()) { //测试网络只在主线程中运行
TestAll(); //运行所有测试网络,并打印输出信息
}
if (requested_early_exit_) { //测试过程中出现提前退出动作,退出循环
// Break out of the while loop because stop was requested while testing.
break;
}
}
for (int i = 0; i < callbacks_.size(); ++i) { //solver的回调函数,在多gpu训练时用于同步求解器
callbacks_[i]->on_start();
}
const bool display = param_.display() && iter_ % param_.display() == 0; //设置了打印间隔并且当前迭代轮到打印了
net_->set_debug_info(display && param_.debug_info()); //设置是否打印调试信息
// accumulate the loss and gradient
Dtype loss = 0;
for (int i = 0; i < param_.iter_size(); ++i) { //单次迭代会执行iter_size次前向反向过程
loss += net_->ForwardBackward(); //执行一次前向计算和反向传播,并累加iter_size次计算得到的loss
}
loss /= param_.iter_size(); //每次迭代的平均loss
// average the loss across iterations for smoothed reporting
UpdateSmoothedLoss(loss, start_iter, average_loss); //将loss保存在losses_中,并计算新的均值smoothed_loss_
if (display) { //需要打印此次迭代的信息
float lapse = iteration_timer_.Seconds(); //关闭计时器,返回已运行的时间,单位s
float per_s = (iter_ - iterations_last_) / (lapse ? lapse : 1); //iterations_last_为上次开启计时器时的迭代次数,得到每秒可迭代的次数
LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
<< " (" << per_s << " iter/s, " << lapse << "s/"
<< param_.display() << " iters), loss = " << smoothed_loss_; //打印迭代次数,迭代速度,运行时间等信息
iteration_timer_.Start(); //重新打开计时器
iterations_last_ = iter_; //保存当前的迭代次数
const vector<Blob<Dtype>*>& result = net_->output_blobs(); //训练网络的所有输出blob
int score_index = 0;
for (int j = 0; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data(); //第j个输出blob的data_数据
const string& output_name = net_->blob_names()[net_->output_blob_indices()[j]]; //该输出blob的名称
const Dtype loss_weight = net_->blob_loss_weights()[net_->output_blob_indices()[j]]; //该输出blob的loss权重
for (int k = 0; k < result[j]->count(); ++k) {
ostringstream loss_msg_stream;
if (loss_weight) { //权重不为0时,保存权重和加权后的输出值
loss_msg_stream << " (* " << loss_weight
<< " = " << loss_weight * result_vec[k] << " loss)";
}
LOG_IF(INFO, Caffe::root_solver()) << " Train net output #"
<< score_index++ << ": " << output_name << " = "
<< result_vec[k] << loss_msg_stream.str(); //打印信息
}
}
}
//求解器的回调函数,在梯度计算完毕之后调用.同样也是用于多gpu计算时梯度数据的同步
for (int i = 0; i < callbacks_.size(); ++i) {
callbacks_[i]->on_gradients_ready();
}
ApplyUpdate(); //根据学习率,冲量,权重衰减值等参数计算参数更新时使用的梯度,并更新网络中的参数,在SGDSolver类中实现
SolverAction::Enum request = GetRequestedAction(); //获取当前求解器的动作
// Save a snapshot if needed.
if ((param_.snapshot()
&& iter_ % param_.snapshot() == 0
&& Caffe::root_solver()) ||
(request == SolverAction::SNAPSHOT)) { //当前迭代次数轮到存储快照,或者当前的解器动作为存快照
Snapshot(); //生成快照文件
}
if (SolverAction::STOP == request) { //当前动作为退出,则提前退出
requested_early_exit_ = true;
// Break out of training loop.
break;
}
}
}
template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) { //从resume_file文件中恢复网络和求解器状态,并训练网络
CHECK(Caffe::root_solver()); //在主线程中进行该操作
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy(); //打印网络名称和学习率更新策略
// Initialize to false every time we start solving.
requested_early_exit_ = false; //每次求解时初始化下状态
if (resume_file) { //文件名不为空
LOG(INFO) << "Restoring previous solver status from " << resume_file;
Restore(resume_file); //从文件中还原网络参数和求解器的状态
}
// For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
int start_iter = iter_; //当前已迭代的次数
Step(param_.max_iter() - iter_); //max_iter_为最大迭代次数,计算当前需要迭代的次数
// If we haven't already, save a snapshot after optimization, unless
// overridden by setting snapshot_after_train := false
if (param_.snapshot_after_train()
&& (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
//如果设置了训练结束后保存快照,并且当前迭代次数在并未轮到保存快照
//满足 param_.snapshot() && iter_ % param_.snapshot() == 0 的话会在Step()函数中保存当前iter_的快照,此处自然无需再保存
Snapshot();
}
if (requested_early_exit_) { //同样判断下求解器的动作
LOG(INFO) << "Optimization stopped early.";
return;
}
// After the optimization is done, run an additional train and test pass to
// display the train and test loss/outputs if appropriate (based on the
// display and test_interval settings, respectively). Unlike in the rest of
// training, for the train net we only run a forward pass as we've already
// updated the parameters "max_iter" times -- this final pass is only done to
// display the loss, which is computed in the forward pass.
//如果需要显示,会额外进行一次前向计算.这与Step()中的最后一次计算不同,Step()中的最后一次计算包括前向和反向计算,
//还包括参数的更新,此时参数更新之后网络的loss并不知道,所以此处会使用更新后的参数再计算一次前向过程,得到对应的loss
if (param_.display() && iter_ % param_.display() == 0) { //设置了打印求解器的信息并且当前迭代轮到打印了
int average_loss = this->param_.average_loss(); //loss的滑动平均窗的长度
Dtype loss;
net_->Forward(&loss); //一次前向计算
UpdateSmoothedLoss(loss, start_iter, average_loss); //更新losses_,并计算平均loss
LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_; //打印信息
}
if (param_.test_interval() && iter_ % param_.test_interval() == 0) { //设置了测试网络的运行间隔,并且当前轮到测试网络
TestAll(); //运行所有测试网络
}
LOG(INFO) << "Optimization Done."; //求解器优化完成
}
template <typename Dtype>
void Solver<Dtype>::TestAll() { //运行全部测试网络
for (int test_net_id = 0;
test_net_id < test_nets_.size() && !requested_early_exit_; //没有要求提前退出
++test_net_id) {
Test(test_net_id); //执行第test_net_id个测试网络
}
}
template <typename Dtype>
void Solver<Dtype>::Test(const int test_net_id) { //执行第test_net_id个测试网络
CHECK(Caffe::root_solver()); //测试网络只在主线程中运行
LOG(INFO) << "Iteration " << iter_
<< ", Testing net (#" << test_net_id << ")"; //打印迭代信息,测试网络的id
//共享网络,将训练网络net_中的参数blob的数据指针赋给当前的测试网络,只修改测试网络的指针指向位置,不会拷贝数据
CHECK_NOTNULL(test_nets_[test_net_id].get())->ShareTrainedLayersWith(net_.get());
vector<Dtype> test_score;
vector<int> test_score_output_id;
const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id]; //当前的测试网络
Dtype loss = 0;
//test_iter(test_net_id)为第test_net_id个测试网络在测试时需要迭代的次数
for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
SolverAction::Enum request = GetRequestedAction(); //获取当前的求解器动作
// Check to see if stoppage of testing/training has been requested.
while (request != SolverAction::NONE) { //非NONE类型的话,则会执行相应的动作
if (SolverAction::SNAPSHOT == request) { //拍摄快照,并继续训练
Snapshot(); //生成快照文件,并继续当前操作
} else if (SolverAction::STOP == request) { //提前退出
requested_early_exit_ = true;
}
request = GetRequestedAction();
}
if (requested_early_exit_) { //退出,不进行后续的操作
// break out of test loop.
break;
}
Dtype iter_loss;
//执行test_net的一次前向计算过程,loss存入iter_loss中,result为网络的输出blob(net_output_blobs_)
const vector<Blob<Dtype>*>& result = test_net->Forward(&iter_loss);
if (param_.test_compute_loss()) { //是否计算测试网络的平均loss
loss += iter_loss; //累加每次计算出的loss
}
if (i == 0) { //初次计算时,先确定好test_score和test_score_output_id的大小
for (int j = 0; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data(); //网络输出的第j个blob的data_
for (int k = 0; k < result[j]->count(); ++k) {
test_score.push_back(result_vec[k]); //将输出blob的data中的数据全部存入test_score中
test_score_output_id.push_back(j); //将数据在输出blob中的来源存入test_score_output_id中
}
}
} else {
int idx = 0;
for (int j = 0; j < result.size(); ++j) { //每个输出blob
const Dtype* result_vec = result[j]->cpu_data(); //输出blob的data_数据
for (int k = 0; k < result[j]->count(); ++k) {
test_score[idx++] += result_vec[k]; //累加测试网络每次迭代时得到的输出blob数据
}
}
}
}
if (requested_early_exit_) { //提前退出?
LOG(INFO) << "Test interrupted.";
return;
}
if (param_.test_compute_loss()) { //是否计算测试网络的平均loss
loss /= param_.test_iter(test_net_id); //计算该测试网络test_iter(test_net_id)次迭代的loss均值
LOG(INFO) << "Test loss: " << loss;
}
for (int i = 0; i < test_score.size(); ++i) {
//数据test_score[i]来源于blob类型的net_output_blobs_[test_score_output_id[i]]中,output_blob_index为该blob在blobs_的索引
const int output_blob_index = test_net->output_blob_indices()[test_score_output_id[i]];
const string& output_name = test_net->blob_names()[output_blob_index]; //该blob的名称
const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index]; //该blob的loss权重
ostringstream loss_msg_stream;
const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id); //除以迭代次数,得到输出blob的均值
if (loss_weight) { //权重非0时,权重和加权值
loss_msg_stream << " (* " << loss_weight << " = " << loss_weight * mean_score << " loss)";
}
LOG(INFO) << " Test net output #" << i << ": " << output_name << " = "
<< mean_score << loss_msg_stream.str(); //打印测试网络的每个输出blob中的每个数据的均值
}
}
template <typename Dtype>
void Solver<Dtype>::Snapshot() { //生成两个快照文件,分别保存网络参数(NetParameter类型)和求解器的状态(SolverState类型)
CHECK(Caffe::root_solver()); //同样,存快照只在主线程中操作
string model_filename;
switch (param_.snapshot_format()) { //设置的快照文件格式
case caffe::SolverParameter_SnapshotFormat_BINARYPROTO: //二进制proto类型
model_filename = SnapshotToBinaryProto(); //将训练网络的网络参数存为".caffemodel"后缀的文件,返回其文件名
break;
case caffe::SolverParameter_SnapshotFormat_HDF5: //hdf5类型
model_filename = SnapshotToHDF5(); //将训练网络的网络参数写入文件中,返回其文件名
break;
default:
LOG(FATAL) << "Unsupported snapshot format.";
}
SnapshotSolverState(model_filename); //将求解器的状态(SolverState类型)保存为文件
}
template <typename Dtype>
void Solver<Dtype>::CheckSnapshotWritePermissions() { //检查是否能够创建快照文件(只检查是否能够以写方式创建文件,不会存数据进去)
if (Caffe::root_solver() && param_.snapshot()) { //只在主线程中操作
CHECK(param_.has_snapshot_prefix())
<< "In solver params, snapshot is specified but snapshot_prefix is not"; //检查是否设置了快照文件名的前缀
string probe_filename = SnapshotFilename(".tempfile"); //生成快照的文件名,".tempfile"为后缀
std::ofstream probe_ofs(probe_filename.c_str()); //创建临时文件文件
if (probe_ofs.good()) { //判断是否发生错误
probe_ofs.close(); //关闭
std::remove(probe_filename.c_str()); //删除文件
} else {
LOG(FATAL) << "Cannot write to snapshot prefix '"
<< param_.snapshot_prefix() << "'. Make sure "
<< "that the directory exists and is writable."; //创建失败,报错
}
}
}
//生成快照的文件名,前缀字符串 + "_iter_" + 迭代次数转字符串 + 扩展名extension
template <typename Dtype>
string Solver<Dtype>::SnapshotFilename(const string& extension) {
return param_.snapshot_prefix() + "_iter_" + caffe::format_int(iter_)
+ extension;
}
template <typename Dtype>
string Solver<Dtype>::SnapshotToBinaryProto() { //将训练网络的网络参数保存为二进制proto文件,并返回文件名
string model_filename = SnapshotFilename(".caffemodel"); //生成文件名,扩展名为".caffemodel"
LOG(INFO) << "Snapshotting to binary proto file " << model_filename; //打印信息
NetParameter net_param;
//将训练网络net_中的所有layer的参数写入到net_param中,snapshot_diff()表示是否需要保存梯度信息到快照中
net_->ToProto(&net_param, param_.snapshot_diff());
WriteProtoToBinaryFile(net_param, model_filename); //将NetParameter类型的消息写入到文件中
return model_filename; //返回快照文件名
}
template <typename Dtype>
string Solver<Dtype>::SnapshotToHDF5() { //将训练网络的参数存为hdf5文件中,返回文件名
string model_filename = SnapshotFilename(".caffemodel.h5"); //快照的文件名
LOG(INFO) << "Snapshotting to HDF5 file " << model_filename; //打印
net_->ToHDF5(model_filename, param_.snapshot_diff()); //将net_的各layer的参数写入hdf5文件中
return model_filename; //返回文件名
}
//还原网络参数和训练状态,从state_file文件中读取求解器的状态,如果里面还设置了网络参数的模型文件,则还会加载网络参数
template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
string state_filename(state_file);
if (state_filename.size() >= 3 &&
state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) { //根据文件名判断hdf5还是proto类型,稍微粗糙了点
RestoreSolverStateFromHDF5(state_filename); //从hdf5文件中读取
} else {
RestoreSolverStateFromBinaryProto(state_filename); //从二进制proto文件中读取
}
}
//start_iter为初始迭代的次数
//losses_中存放loss值,初始时(iter_ < start_iter + average_loss)存放的loss的个数逐渐增加,个数达到average_loss时不再增加.
//之后新的loss值都是从前往后依次覆盖之前的保存的值,不断循环.
template <typename Dtype>
void Solver<Dtype>::UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss) {
if (losses_.size() < average_loss) { //个数还不到滑动平均窗的大小,会逐渐增加losses_的大小
losses_.push_back(loss); //将loss存入
int size = losses_.size();
//smoothed_loss_为当前loss存入之前losses_的均值,存入后更新下均值
smoothed_loss_ = (smoothed_loss_ * (size - 1) + loss) / size;
} else {
int idx = (iter_ - start_iter) % average_loss; //将iter_对应的loss存入losses_中的对应位置
smoothed_loss_ += (loss - losses_[idx]) / average_loss; //先计算平均loss,再将值存入
losses_[idx] = loss;
}
}
solver.hpp源码
/**
* @brief Enumeration of actions that a client of the Solver may request by
* implementing the Solver's action request function, which a
* client may optionally provide in order to request early termination
* or saving a snapshot without exiting. In the executable caffe, this
* mechanism is used to allow the snapshot to be saved when stopping
* execution with a SIGINT (Ctrl-C).
*/
namespace SolverAction {
enum Enum {
NONE = 0, // Take no special action.
STOP = 1, // Stop training. snapshot_after_train controls whether a
// snapshot is created. //停止,提前退出
SNAPSHOT = 2 // Take a snapshot, and keep training. //将当前的训练网络的参数存为快照文件,并继续后续操作
};
}
/**
* @brief Type of a function that returns a Solver Action enumeration.
*/
typedef boost::function<SolverAction::Enum()> ActionCallback;
/**
* @brief An interface for classes that perform optimization on Net%s.
*
* Requires implementation of ApplyUpdate to compute a parameter update
* given the current state of the Net parameters.
*/
template <typename Dtype>
class Solver {
public:
explicit Solver(const SolverParameter& param);
explicit Solver(const string& param_file);
void Init(const SolverParameter& param);
void InitTrainNet();
void InitTestNets();
// Client of the Solver optionally may call this in order to set the function
// that the solver uses to see what action it should take (e.g. snapshot or
// exit training early).
void SetActionFunction(ActionCallback func); //设置求解器动作的回调函数
SolverAction::Enum GetRequestedAction();
// The main entry of the solver function. In default, iter will be zero. Pass
// in a non-zero iter number to resume training for a pre-trained net.
virtual void Solve(const char* resume_file = NULL);
inline void Solve(const string& resume_file) { Solve(resume_file.c_str()); }
void Step(int iters);
// The Restore method simply dispatches to one of the
// RestoreSolverStateFrom___ protected methods. You should implement these
// methods to restore the state from the appropriate snapshot type.
void Restore(const char* resume_file);
// The Solver::Snapshot function implements the basic snapshotting utility
// that stores the learned net. You should implement the SnapshotSolverState()
// function that produces a SolverState protocol buffer that needs to be
// written to disk together with the learned net.
void Snapshot();
virtual ~Solver() {}
inline const SolverParameter& param() const { return param_; }
inline shared_ptr<Net<Dtype> > net() { return net_; }
inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
return test_nets_;
}
int iter() const { return iter_; }
// Invoked at specific points during an iteration
//迭代过程中调用的回调类,里面实现了两个函数,用于多gpu训练中的同步
class Callback {
protected:
virtual void on_start() = 0;
virtual void on_gradients_ready() = 0;
template <typename T>
friend class Solver;
};
const vector<Callback*>& callbacks() const { return callbacks_; }
void add_callback(Callback* value) {
callbacks_.push_back(value); //加入
}
void CheckSnapshotWritePermissions();
/**
* @brief Returns the solver type.
*/
virtual inline const char* type() const { return ""; }
// Make and apply the update value for the current iteration.
virtual void ApplyUpdate() = 0;
protected:
string SnapshotFilename(const string& extension);
string SnapshotToBinaryProto();
string SnapshotToHDF5();
// The test routine
void TestAll();
void Test(const int test_net_id = 0);
virtual void SnapshotSolverState(const string& model_filename) = 0;
virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
void DisplayOutputBlobs(const int net_id);
void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);
SolverParameter param_;
int iter_; //当前的迭代次数
int current_step_; //当前迭代的阶段,在学习率更新策略为step和multistep中使用
shared_ptr<Net<Dtype> > net_; //训练网络
vector<shared_ptr<Net<Dtype> > > test_nets_; //所有的测试网络
vector<Callback*> callbacks_; //回调函数
vector<Dtype> losses_; //保存最近average_loss_次迭代的loss值
Dtype smoothed_loss_; //losses_的均值
// A function that can be set by a client of the Solver to provide indication
// that it wants a snapshot saved and/or to exit early.
ActionCallback action_request_function_; //返回值为求解器动作的回调函数
// True iff a request to stop early was received.
bool requested_early_exit_; //是否需要提前退出
// Timing information, handy to tune e.g. nbr of GPUs
Timer iteration_timer_; //计时器
float iterations_last_; //上一次开启计时器的iter_的值
DISABLE_COPY_AND_ASSIGN(Solver);
};
小结
- 求解器的动作回调函数在caffe.cpp文件中设置,为
SignalHandler::CheckForSignals()
的函数指针。当Unix系统中出现SIGINT或SIGHUP信号时,GotSIGINT()
或GotSIGHUP()
函数会返回相应标志,并清空信号。而SignalHandler::CheckForSignals()
函数则会根据标志返回对应的求解器动作类型(NONE/STOP/SNAPSHOT),具体可参考signal_handler.cpp文件。 Step()
函数中每次迭代计算前向/反向过程时,都使用了ClearParamDiffs()
函数清空梯度。这是因为caffe中每次反向传播时的梯度数据都是累加在原数据上的,所以每次迭代时都需要手动清空,这与PyTorch中需要手动将梯度清零一致。
Caffe的源码笔者是第一次阅读,一边阅读一边记录,对代码的理解和分析可能会存在错误或遗漏,希望各位读者批评指正,谢谢支持!
Caffe源码-Solver类的更多相关文章
- Caffe源码-SyncedMemory类
SyncedMemory类简介 最近在阅读caffe源码,代码来自BVLC/caffe,基本是参照网络上比较推荐的 Blob-->Layer-->Net-->Solver 的顺序来分 ...
- Caffe源码-SGDSolver类
SGDSolver类简介 Solver类用于网络参数的更新,而SGDSolver类实现了优化方法中的随机梯度下降法(stochastic gradient descent),此外还具备缩放.正则化梯度 ...
- Caffe源码-Net类(下)
net.cpp部分源码 // 接着上一篇博客的介绍,此部分为Net类中前向反向计算函数,以及一些与HDF5文件或proto文件相互转换的函数. template <typename Dtype& ...
- Caffe源码-Blob类
Blob类简介 Blob是caffe中的数据传递的一个基本类,网络各层的输入输出数据以及网络层中的可学习参数(learnable parameters,如卷积层的权重和偏置参数)都是Blob类型.Bl ...
- Caffe源码-Net类(上)
Net类简介 Net类主要处理各个Layer之间的输入输出数据和参数数据共享等的关系.由于Net类的代码较多,本次主要介绍网络初始化部分的代码.Net类在初始化的时候将各个Layer的输出blob都统 ...
- Caffe源码-Layer类
Layer类简介 Layer是caffe中搭建网络的基本单元,caffe代码中包含大量Layer基类派生出来的各种各样的层,各自通过虚函数 Forward() 和 Backward() 实现自己的功能 ...
- Caffe源码-几种优化算法
SGD简介 caffe中的SGDSolver类中实现了带动量的梯度下降法,其原理如下,\(lr\)为学习率,\(m\)为动量参数. 计算新的动量:history_data = local_rate * ...
- caffe源码阅读
参考网址:https://www.cnblogs.com/louyihang-loves-baiyan/p/5149628.html 1.caffe代码层次熟悉blob,layer,net,solve ...
- Caffe源码中common文件分析
Caffe源码(caffe version:09868ac , date: 2015.08.15)中的一些重要头文件如caffe.hpp.blob.hpp等或者外部调用Caffe库使用时,一般都会in ...
随机推荐
- Flex容器拖动(Bordercontainer为例)
Bordercontainer的拖放到任意位置. mxml: 为Bordercontainer添加鼠标按下和弹起事件 <s:BorderContainer id="bdShow&quo ...
- PowerDesigner从安装到同步数据库
前言 最近项目在如火如荼的进行着4.0版本的数据库设计工作,我们几个后端小伙伴也马不停蹄的进行着数据库的设计.使用的设计软件是PowerDesigner,这里记录一些常见的问题以备日后查看 安装 链接 ...
- React中使用create-react-app创建项目,运行npm run eject建立灰度报错
我在运行npm run eject建立测试环境和正式环境时候报错 这里的问题是是脚手架添加.gitgnore文件,但是却没有本地仓库,按照以下顺序就可以正常使用 git add . git commi ...
- CSV数据存取
CSV数据的读取十分地简单 分为两部分 读 读取csv文件可以使用csv模块下的reader(f)以及DictReader(f) mport csv with open("text.csv& ...
- jQuery学习笔记3
* 动画效果 * 在一定的时间内, 不断改变元素样式 * slideDown()/slideUp()/slideToggle() * fadeOut()/fadeIn()/fadeToggle() * ...
- JAVA中快速生成get与set
快捷键 ctrl+Alt+S generate getters and setters
- python接口设计中的__all__和del
最近在实现python接口中遇到了一些小问题,解决后总结如下. 目的:在设计接口时,只暴露某个文件的特定方法. 例如: t.py import os import sys def a(): pass ...
- 新浪短网址最新api接口
1,雨林短网址 网站链接:http://yldwz.cn 雨林短网址采用新浪.腾讯官方API接口,强大的多功能API,简单易用,质量高官 网提供强技术支持,99.9% SLA服务稳定安全可靠的校验机制 ...
- Linux安装python环境脚本
自动安装python环境的脚本 1.首先判断是不是root用户 2.判断是否安装 3.是否下载成功(网络可能有问题) 4.是否解压成功(文件下载可能缺少) 5.安装配置python环境 # codin ...
- MySQL 库、表、记录、相关操作(3)
MySQL 库.表.记录.相关操作(3) 单表查询 """ 增: insert [into] [数据库名.]表名[(字段1[, ..., 字段n])] values (数 ...