Caffe源码
Blob
protected:
shared_ptr<SyncedMemory> data_;
shared_ptr<SyncedMemory> diff_;
shared_ptr<SyncedMemory> shape_data_;
vector<int> shape_;
int count_;
int capacity_; Blob的构造函数
Blob<Dtype>::Blob(const int num, const int channels, const int height,
const int width)
// capacity_ must be initialized before calling Reshape
: capacity_(0) {
Reshape(num, channels, height, width);
} 会调用reshape函数,为data_,diff_分配内存
template <typename Dtype>
void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
const int width) {
vector<int> shape(4);
shape[0] = num;
shape[1] = channels;
shape[2] = height;
shape[3] = width;
Reshape(shape);
} template <typename Dtype>
void Blob<Dtype>::Reshape(const vector<int>& shape) {
CHECK_LE(shape.size(), kMaxBlobAxes);
count_ = 1;
shape_.resize(shape.size());
if (!shape_data_ || shape_data_->size() < shape.size() * sizeof(int)) {
shape_data_.reset(new SyncedMemory(shape.size() * sizeof(int)));
}
int* shape_data = static_cast<int*>(shape_data_->mutable_cpu_data());
for (int i = 0; i < shape.size(); ++i) {
CHECK_GE(shape[i], 0);
CHECK_LE(shape[i], INT_MAX / count_) << "blob size exceeds INT_MAX";
count_ *= shape[i];
shape_[i] = shape[i];
shape_data[i] = shape[i];
}
if (count_ > capacity_) {
capacity_ = count_;
data_.reset(new SyncedMemory(capacity_ * sizeof(Dtype)));
diff_.reset(new SyncedMemory(capacity_ * sizeof(Dtype)));
} Blob的序列化函数:
//in blob.hpp
void FromProto(const BlobProto& proto, bool reshape = true);
void ToProto(BlobProto* proto, bool write_diff = false) const;
ToProto将Blob的shape_,data_,diff_分别copy到BlobProto的shape,data,diff,完成序列化。FromProto将BlobProto的shape,data,diff分别copy到Blob的shape_,data_,diff_,完成数据解析。最后数据持久化函数由Protocol Buffers的工具实现 Blob中还有个更新参数的函数update(),data=data-diff
void Blob<Dtype>::Update() {
// We will perform update based on where the data is located.
switch (data_->head()) {
case SyncedMemory::HEAD_AT_CPU:
// perform computation on CPU
caffe_axpy<Dtype>(count_, Dtype(-1),
static_cast<const Dtype*>(diff_->cpu_data()),
static_cast<Dtype*>(data_->mutable_cpu_data()));
break;
case SyncedMemory::HEAD_AT_GPU:
case SyncedMemory::SYNCED:
#ifndef CPU_ONLY
// perform computation on GPU
caffe_gpu_axpy<Dtype>(count_, Dtype(-1),
static_cast<const Dtype*>(diff_->gpu_data()),
static_cast<Dtype*>(data_->mutable_gpu_data()));
#else
NO_GPU;
#endif
break;
default:
LOG(FATAL) << "Syncedmem not initialized.";
}
} Layer有5纯虚函数
Reshape()
Forward_cpu()
Backword_cpu()
Forward_gpu()
Backword_gpu() Layer层:
Loss_layer
Common_layer没有了(softmax,innerproduct)
Neuron_layer(tanh)
Vision layer没有了(pooling,conv)
Data_layer变成了BasePrefetchingDataLayer(hdf5 input) Net
Solver
整个过程
solver变量的构造函数中有init(param)
init中有initTrainNet()函数,initTrainNet()函数有net_.reset(new Net<Dtype>(net_param));
然后调用net的构造函数
template <typename Dtype>
Net<Dtype>::Net(const NetParameter& param, const Net* root_net)
: root_net_(root_net) {
Init(param);
}
通过一个for循环将layer一个一个串起来,并且调用layer的setup函数
// layer 初始化设置
void SetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
InitMutex();
CheckBlobCounts(bottom, top);
LayerSetUp(bottom, top);
Reshape(bottom, top);
SetLossWeights(top);
}
LayerSetUp(bottom, top):由Layer类派生出的特定类都需要重写这个函数,主要功能是设置权值参数(包括偏置)的空间以及对权值参数经行随机初始化。 
Reshape(bottom, top):根据输出blob和权值参数计算输出blob的维数,并申请空间。 经过上述过程基本上就完成了初始化的工作,总体的流程大概就是新建一个Solver对象,然后调用Solver类的构造函数,然后在Solver的构造函数中又会新建Net类实例,在Net类的构造函数中又会新建各个Layer的实例,一直具体到设置每个Blob,大概就介绍完了网络初始化的工作,当然里面还有很多具体的细节,但大概的流程就是这样。 上面过程就是从shared_ptr<caffe::Solver<float> > //初始化
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
这个solver开始的
solver->Solve();
template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
...
int start_iter = iter_;
...
// 然后调用了'Step'函数,这个函数执行了实际的逐步的迭代过程
Step(param_.max_iter() - iter_);
...
LOG(INFO) << "Optimization Done.";
} step函数如下
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
...
//迭代
while (iter_ < stop_iter) {
...
// iter_size也是在solver.prototxt里设置,实际上的batch_size=iter_size*网络定义里的batch_size,
// 因此每一次迭代的loss是iter_size次迭代的和,再除以iter_size,这个loss是通过调用`Net::ForwardBackward`函数得到的
// accumulate gradients over `iter_size` x `batch_size` instances
for (int i = 0; i < param_.iter_size(); ++i) {
/*
* 调用了Net中的代码,主要完成了前向后向的计算,
* 前向用于计算模型的最终输出和Loss,后向用于
* 计算每一层网络和参数的梯度。
*/
loss += net_->ForwardBackward();
} ... /*
* 这个函数主要做Loss的平滑。由于Caffe的训练方式是SGD,我们无法把所有的数据同时
* 放入模型进行训练,那么部分数据产生的Loss就可能会和全样本的平均Loss不同,在必要
* 时候将Loss和历史过程中更新的Loss求平均就可以减少Loss的震荡问题。
*/
UpdateSmoothedLoss(loss, start_iter, average_loss); ...
// 执行梯度的更新,这个函数在基类`Solver`中没有实现,会调用每个子类自己的实现
//,后面具体分析`SGDSolver`的实现
ApplyUpdate(); // 迭代次数加1
++iter_;
... }
} // 进行一次正向传播,一次反向传播
Dtype ForwardBackward() {
Dtype loss;
Forward(&loss);
Backward();
return loss;
} for (int i = start; i <= end; ++i) {
// 对每一层进行前向计算,返回每层的loss,其实只有最后一层loss不为0
Dtype layer_loss = layers_[i]->Forward(bottom_vecs_[i], top_vecs_[i]);
loss += layer_loss;
if (debug_info_) { ForwardDebugInfo(i); }
} ApplyUpdate();
这个函数是Solver类的纯虚函数,需要派生类来实现,比如SGDSolver类实现的ApplyUpdate();函数如下,主要内容包括:设置参数的学习率;对梯度进行Normalize;对反向求导得到的梯度添加正则项的梯度;最后根据SGD算法计算最终的梯度;最后的最后把计算得到的最终梯度对权值进行更新。 template <typename Dtype>
void SGDSolver<Dtype>::ApplyUpdate() {
CHECK(Caffe::root_solver()); // GetLearningRate根据设置的lr_policy来计算当前迭代的learning rate的值
Dtype rate = GetLearningRate(); // 判断是否需要输出当前的learning rate
if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
} // 避免梯度爆炸,如果梯度的二范数超过了某个数值则进行scale操作,将梯度减小
ClipGradients(); // 对所有可更新的网络参数进行操作
for (int param_id = 0; param_id < this->net_->learnable_params().size();
++param_id) {
// 将第param_id个参数的梯度除以iter_size,
// 这一步的作用是保证实际的batch_size=iter_size*设置的batch_size
Normalize(param_id); // 将正则化部分的梯度降入到每个参数的梯度中
Regularize(param_id); // 计算SGD算法的梯度(momentum等)
ComputeUpdateValue(param_id, rate);
}
// 调用`Net::Update`更新所有的参数
this->net_->Update();
}

  

caffe源码整个训练过程的更多相关文章

  1. caffe源码阅读

    参考网址:https://www.cnblogs.com/louyihang-loves-baiyan/p/5149628.html 1.caffe代码层次熟悉blob,layer,net,solve ...

  2. caffe源码学习

    本文转载自:https://buptldy.github.io/2016/10/09/2016-10-09-Caffe_Code/ Caffe简介 Caffe作为一个优秀的深度学习框架网上已经有很多内 ...

  3. caffe源码学习之Proto数据格式【1】

    前言: 由于业务需要,接触caffe已经有接近半年,一直忙着阅读各种论文,重现大大小小的模型. 期间也总结过一些caffe源码学习笔记,断断续续,这次打算系统的记录一下caffe源码学习笔记,巩固一下 ...

  4. Caffe源码-几种优化算法

    SGD简介 caffe中的SGDSolver类中实现了带动量的梯度下降法,其原理如下,\(lr\)为学习率,\(m\)为动量参数. 计算新的动量:history_data = local_rate * ...

  5. Symfony2源码分析——启动过程2

    文章地址:http://www.hcoding.com/?p=46 上一篇分析Symfony2框架源码,探究Symfony2如何完成一个请求的前半部分,前半部分可以理解为Symfony2框架为处理请求 ...

  6. Caffe源码理解2:SyncedMemory CPU和GPU间的数据同步

    目录 写在前面 成员变量的含义及作用 构造与析构 内存同步管理 参考 博客:blog.shinelee.me | 博客园 | CSDN 写在前面 在Caffe源码理解1中介绍了Blob类,其中的数据成 ...

  7. c#源码的执行过程

    我想也许要写些东西,记录我做程序员的日子吧 ================================================ 要讲到C#源码的执行过程 首先要提下程序集,因为Clr并不 ...

  8. Caffe源码中syncedmem文件分析

    Caffe源码(caffe version:09868ac , date: 2015.08.15)中有一些重要文件,这里介绍下syncedmem文件. 1.      include文件: (1).& ...

  9. Caffe源码中math_functions文件分析

    Caffe源码(caffe version:09868ac , date: 2015.08.15)中有一些重要文件,这里介绍下math_functions文件. 1.      include文件: ...

随机推荐

  1. 【刷题】HDU 1695 GCD

    Problem Description Given 5 integers: a, b, c, d, k, you're to find x in a...b, y in c...d that GCD( ...

  2. 洛谷 3201 [HNOI2009]梦幻布丁 解题报告

    3201 [HNOI2009]梦幻布丁 题目描述 \(N\)个布丁摆成一行,进行\(M\)次操作.每次将某个颜色的布丁全部变成另一种颜色的,然后再询问当前一共有多少段颜色.例如颜色分别为\(1,2,2 ...

  3. linux内核分析 第三周 构造一个简单的Linux系统MenuOS

    一.计算机的三个法宝 存储程序计算机,函数调用堆栈,中断二.操作系统的两把剑:1.中断上下文的切换,保存现场和恢复现场2.进程上下文的切换. 三.linux内核源代码的分析: ·arch/目录保存支持 ...

  4. Codeforces 864E dp

    题意: 房间着火了,里面有n件物品,每件物品有营救需要的时间t,被烧坏的最晚时间d,他的价值p,问能得到的最大价值,并且输出营救出来的物品编号 代码: //必然是先救存活时间短的即d小的,所以先排个序 ...

  5. C++ ------ 互斥锁、原子操作的性能测试

    atomic原子操作:是在新标准C++11,引入了原子操作的概念,并通过这个新的头文件提供了多种原子操作数据类型,例如,atomic_bool,atomic_int等等 测试程序 #include & ...

  6. HDU 4978 计算 凸包

    有以无限间隔$D$的水平线分割的平面,在上面随机投下一个圆,圆中有一些点,点之间两两成一条线段,问随机投下至少有一条线段于平行线相交的概率. 以下是不严(luan)谨(lai)的思路. 首先都知道对于 ...

  7. ACM选修HUST1058(市赛题) Lucky Sequence 同余定理

    Description Edward  得到了一个长度为  N  的整数序列,他想找出这里面有多少个“幸运的”连续子序列.一个连续子序列被称为“幸运的”,当且仅当该子序列内的整数之和恰好是  K  的 ...

  8. PHP扩展--vld查看opcode代码

    vld安装 wget http://pecl.php.net/get/vld-0.13.0.tgz tar zxvf vld-0.13.0.tgz cd vld-0.13.0 /usr/local/p ...

  9. StringUtils.htmlEncode()--html标签过滤方法实现

    package org.guyezhai.utils; import java.text.CharacterIterator; import java.text.StringCharacterIter ...

  10. [php]http响应头解析

    (Status-Line) HTTP/ OK Cache-Control no-cache Content-Length Content-Type image/gif Date Sat, Dec :: ...