【撸码caffe四】 solver.cpp&&sgd_solver.cpp
caffe中solver的作用就是交替低啊用前向(forward)算法和后向(backward)算法来更新参数,从而最小化loss,实际上就是一种迭代的优化算法。
solver.cpp中的Solver提供了执行模型训练的入口,在caffe.cpp中train函数的最后通过 solver->Solve()调用:
template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
//检查是否是root_solver,有多个GPU的情况下,允许设置多个solver,GPU间并行工作,
//第一个solver设置为root_solver
CHECK(Caffe::root_solver());
//网络名称
LOG(INFO) << "Solving " << net_->name();
//学习策略
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
// Initialize to false every time we start solving.
requested_early_exit_ = false;
//是否需要从指针所指向的内存读取出之前的训练状态并恢复
if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
Restore(resume_file);
}
// For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
int start_iter = iter_;
//逐步迭代开始
Step(param_.max_iter() - iter_);
……
}
Solver首先判断执行模式,输出网络名称以及学习策略,并判断是否需要恢复之前的训练状态,之后开始调用Step函数,开始迭代过程。Solver类中的Step函数完成网络模型的逐步优化迭代过程:
template <typename Dtype>
//Step函数完成实际的逐步迭代优化过程
void Solver<Dtype>::Step(int iters) {
//设置开始的迭代次数,如果之前设置了是从snapshot中恢复的,则会从
//snapshot的训练状态继续执行训练
const int start_iter = iter_;
//总的迭代次数
const int stop_iter = iter_ + iters;
//获取设置的要计算之前多少次的loss均值,默认的average_loss为1
int average_loss = this->param_.average_loss();
//清除保存loss的向量
losses_.clear();
//平均loss初始化为0
smoothed_loss_ = 0;
//执行迭代
while (iter_ < stop_iter) {
//清零上一次反向传输过程中产生的梯度数据
// zero-init the params
net_->ClearParamDiffs();
//判断条件,是否执行一次所有测试
if (param_.test_interval() && iter_ % param_.test_interval() == 0
&& (iter_ > 0 || param_.test_initialization())
&& Caffe::root_solver()) {
TestAll();
if (requested_early_exit_) {
// Break out of the while loop because stop was requested while testing.
break;
}
}
for (int i = 0; i < callbacks_.size(); ++i) {
callbacks_[i]->on_start();
}
//是否输出loss等信息
const bool display = param_.display() && iter_ % param_.display() == 0;
net_->set_debug_info(display && param_.debug_info());
// accumulate the loss and gradient
Dtype loss = 0;
//iter_size是在solver.prototxt中设置的,把数据分为多少批次分开迭代,对应还有一个名称为
//batch_size的变量,是在网络中定义的,batch_size定义每批次包含的样本数量,把一个大的
//样本数量分批次训练可以提高训练效率,总的样本数量=iter_size*batch_size
for (int i = 0; i < param_.iter_size(); ++i) {
//累加所有批次的平均误差
loss += net_->ForwardBackward();
}
//计算批次的平均误差
loss /= param_.iter_size();
//更新输出的当前的average_loss个样本的平均loss
// average the loss across iterations for smoothed reporting
UpdateSmoothedLoss(loss, start_iter, average_loss);
if (display) {
//输出迭代次数,平均loss
LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
<< ", loss = " << smoothed_loss_;
const vector<Blob<Dtype>*>& result = net_->output_blobs();
int score_index = 0;
for (int j = 0; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data();
const string& output_name =
net_->blob_names()[net_->output_blob_indices()[j]];
const Dtype loss_weight =
net_->blob_loss_weights()[net_->output_blob_indices()[j]];
for (int k = 0; k < result[j]->count(); ++k) {
ostringstream loss_msg_stream;
if (loss_weight) {
loss_msg_stream << " (* " << loss_weight
<< " = " << loss_weight * result_vec[k] << " loss)";
}
LOG_IF(INFO, Caffe::root_solver()) << " Train net output #"
<< score_index++ << ": " << output_name << " = "
<< result_vec[k] << loss_msg_stream.str();
}
}
}
for (int i = 0; i < callbacks_.size(); ++i) {
callbacks_[i]->on_gradients_ready();
}
//执行网络更新,每一组网络中的参数的更新都是不同类型的solver实现各自的
//ApplyUpdate函数中完成的
ApplyUpdate();
// Increment the internal iter_ counter -- its value should always indicate
// the number of times the weights have been updated.
++iter_;
SolverAction::Enum request = GetRequestedAction();
// Save a snapshot if needed.
if ((param_.snapshot()
&& iter_ % param_.snapshot() == 0
&& Caffe::root_solver()) ||
(request == SolverAction::SNAPSHOT)) {
Snapshot();
}
if (SolverAction::STOP == request) {
requested_early_exit_ = true;
// Break out of training loop.
break;
}
}
}
一次完整的训练流程包括一次前向传输和一次反向传输,分别计算模型的loss和梯度,通过梯度数据计算出参数的更新,更新是通过在Step函数中调用ApplyUpdate函数完成的,ApplyUpdate是在SGDSolver类中定义的:
template <typename Dtype>
void SGDSolver<Dtype>::ApplyUpdate() {
CHECK(Caffe::root_solver());
//根据设置的lr_policy,依据对应的规则计算当前迭代的learning rete的值
Dtype rate = GetLearningRate();
//是否输出当前的迭代次数和学习率数据
if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
}
//避免梯度爆炸,如果梯度的L1或L2范数超过了某个上限值,则将梯度减小
ClipGradients();
//更新所有参数,包括卷积层和池化层的卷积核和偏置两组参数
for (int param_id = 0; param_id < this->net_->learnable_params().size();
++param_id) {
//将参数的梯度归一化,除以iter_size,其作用是保证实际的batch_size=iter_size*batch_size
Normalize(param_id);
//将正则化部分的梯度存入到每个参数的梯度中
Regularize(param_id);
//计算SGD算法的梯度(momentum等)
ComputeUpdateValue(param_id, rate);
}
//调用Net::Update更新参数
this->net_->Update();
}
template <typename Dtype>
void SGDSolver<Dtype>::Normalize(int param_id) {
//如果训练数据的批次数为1,则不进行归一化,直接返回
if (this->param_.iter_size() == 1) { return; }
//获取所有要优化的参数
// Scale gradient to counterbalance accumulation.
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
//归一化系数
const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();
switch (Caffe::mode()) {
case Caffe::CPU: {
//CPU中执行归一化操作的函数
caffe_scal(net_params[param_id]->count(), accum_normalization,
net_params[param_id]->mutable_cpu_diff());
break;
}
case Caffe::GPU: {
#ifndef CPU_ONLY
caffe_gpu_scal(net_params[param_id]->count(), accum_normalization,
net_params[param_id]->mutable_gpu_diff());
#else
NO_GPU;
#endif
break;
}
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}
template <typename Dtype>
void SGDSolver<Dtype>::Regularize(int param_id) {
//获取所有要优化的参数
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
//获取所有要优化的参数的权重衰减向量
const vector<float>& net_params_weight_decay =
this->net_->params_weight_decay();
//获取网络模型整体的权重衰减
Dtype weight_decay = this->param_.weight_decay();
//获取网络的正则化类型,L1或者L2
string regularization_type = this->param_.regularization_type();
//每一个参数的权重衰减等于每个参数的权重衰减乘以网络整体的权重衰减
Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
switch (Caffe::mode()) {
case Caffe::CPU: {
if (local_decay) { //权重为0时,代表梯度消失
if (regularization_type == "L2") {
// add weight decay
//执行正则化,L2的梯度diff_=weight_decay*data_+diff_
caffe_axpy(net_params[param_id]->count(),
local_decay,
net_params[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
} else if (regularization_type == "L1") {
caffe_cpu_sign(net_params[param_id]->count(),
net_params[param_id]->cpu_data(),
temp_[param_id]->mutable_cpu_data());
caffe_axpy(net_params[param_id]->count(),
local_decay,
temp_[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
} else {
LOG(FATAL) << "Unknown regularization type: " << regularization_type;
}
}
break;
}
……
}
【撸码caffe四】 solver.cpp&&sgd_solver.cpp的更多相关文章
- 【撸码caffe 三】 caffe.cpp
caffe.cpp文件完成对网络模型以及模型配置参数的读入和提取,提供了网络模型训练的入口函数train和对模型的测试入口函数test.文件中使用了很多gflags和glog指令,gflags是goo ...
- 【撸码caffe 五】数据层搭建
caffe.cpp中的train函数内声明了一个类型为Solver类的智能指针solver: // Train / Finetune a model. int train() { -- shared_ ...
- 【撸码caffe 二】 blob.hpp
Blob类是caffe中对处理和传递的实际数据的封装,是caffe中基本的数据存储单元,包括前向传播中的图像数据,反向传播中的梯度数据以及网络层间的中间数据变量(包括权值,偏置等),训练模型的参数等等 ...
- 【撸码caffe 一】syncedmen.hpp
SyncedMemory类主要负责在主机(CPU)和设备(GPU)之间管理内存分配和数据同步工作,封装了CPU和GPU之间的数据交互操作. 补充一点GPU的相关知识: 对CUDA架构而言,主机端的内存 ...
- 36 网络相关函数(四)——live555源码阅读(四)网络
36 网络相关函数(四)——live555源码阅读(四)网络 36 网络相关函数(四)——live555源码阅读(四)网络 简介 7)createSocket创建socket方法 8)closeSoc ...
- 34 网络相关函数(二)——live555源码阅读(四)网络
34 网络相关函数(二)——live555源码阅读(四)网络 34 网络相关函数(二)——live555源码阅读(四)网络 2)socketErr 套接口错误 3)groupsockPriv函数 4) ...
- 响应国家号召,在家撸码之React迁移记
最近这段时间新型冠状病毒肆虐,上海确诊人数每天都在增加,人人提心吊胆,街上都没人了.为了响应国家号召,近期呆在家里撸码,着手将项目迁移到React中,项目比较朴素,是一张线索提交页面,包含表单.图片滚 ...
- 【深度学习】之Caffe的solver文件配置(转载自csdn)
原文: http://blog.csdn.net/czp0322/article/details/52161759 今天在做FCN实验的时候,发现solver.prototxt文件一直用的都是mode ...
- 40 网络相关函数(八)——live555源码阅读(四)网络
40 网络相关函数(八)——live555源码阅读(四)网络 40 网络相关函数(八)——live555源码阅读(四)网络 简介 15)writeSocket向套接口写数据 TTL的概念 函数send ...
随机推荐
- VMware 11安装Mac OS X 10.10 及安装Mac Vmware Tools.
先上一张效果图兴奋一下,博主穷屌丝一个,只能通过虚拟黑苹果体验下高富帅的生活,感觉超爽的,废话不多说的,直接上图了! 目录: 1.安装所需软件下载: 2.Mac OS X10.10 安装基本步骤: 3 ...
- 拍拍贷投资工具|拍拍贷投标工具|PPD投标工具|PPD投资工具介绍
我们先来分析一下现在市场上在PPD投资的途径: 其他解决方案 1.在网站或者手机客户端手动投标 这种方法对于非常小额的资金是可以的,稍微多一点就会发现不可行,目前PPD手动刷新出来的标几乎都是你刚刷新 ...
- [文章转载]-我的Java后端书架-江南白衣
我的Java后端书架 (2016年暮春3.0版) 04月 24, 2016 | Filed under 技术 书架主要针对Java后端开发. 3.0版把一些后来买的.看的书添补进来,又或删掉或降级一些 ...
- centos设置ssh安全只允许用户从指定的IP登陆
1.编辑文件 /etc/ssh/sshd_config vi /etc/ssh/sshd_config 2.root用户只允许在如下ip登录 AllowUsers root@203.212.4.117 ...
- 在Unity中对注册表的信息进行操作
问题1 在对注册表进行操作时无法生成注册表相关的类 解决办法: 增加头文件using Microsft.Win32; 问题2 在运行程序时报错同时注 ...
- UVALive 3026(KMP算法)
UVALive 3026 KMP中next[]数组的应用: 题意:给出一个字符串,问该字符串每个前缀首字母的位置和该前缀的周期. 思路:裸KMP直接上就是了: 设该字符串为str,str字符串 ...
- Nginx反向代理WebSocket(WSS)
1. WebSocket协议 WebSocket 协议提供了一种创建支持客户端和服务端实时双向通信Web应用程序的方法.作为HTML5规范的一部分,WebSockets简化了开发Web实时通信程序的难 ...
- nlogn求LIS(树状数组)
之前一直是用二分 但是因为比较难理解,写的时候也容易忘记怎么写. 今天比赛讲评的时候讲了一种用树状数组求LIS的方法 (1)好理解,自然也好写(但代码量比二分的大) (2)扩展性强.这个解法顺带求出以 ...
- Spring整合Junit框架进行单元测试Demo
一.开发环境 eclipse版本:4.6.1 maven版本:3.3.3 junit版本:4.12 spring版本:4.1.5.RELEASE JDK版本:1.8.0_111 二.项目结构 图 三. ...
- 曾经遇过的sql问题
曾经遇过的sql问题 问题一: 语句1: select SUM(level) from Comment 语句2: ELSE SUM(level) END as totalLevel from Comm ...