概述

网络层的构建是在Net<Dtype>::Init()函数中完成的,构建的流程图如下所示:

从图中可以看出网络层的构建分为三个主要部分:解析网络文件、开始建立网络层、网络层需要参与计算的位置。

解析网络文件

该部分主要有两个函数FilterNet()、InsertSplits()。

 void Net<Dtype>::Init(const NetParameter& in_param) {
CHECK(Caffe::root_solver() || root_net_)
<< "root_net_ needs to be set for all non-root solvers";
// Set phase from the state.
phase_ = in_param.state().phase();
// Filter layers based on their include/exclude rules and
// the current NetState.
NetParameter filtered_param;
FilterNet(in_param, &filtered_param);

FilterNet()的作用是模型参数文件(*.prototxt)中的不符合规则的层去掉。例如:在caffe的examples/mnist中的lenet网络中,如果只是用于网络的前向,则需要将包含train的数据层去掉。

 /*
*调用InsertSplits()函数,对于底层的一个输出blob对应多个上层的情况,
*则要在加入分裂层,形成新的网络。
*函数从filtered_param读入新网络到param
**/
InsertSplits(filtered_param, &param);

InsertSplits()函数的作用是对于底层的一个输出blob对应多个上层的情况,则要在加入分裂层,形成新的网络。这么做的主要原因是多个层反传给该blob的梯度需要累加。例如:LeNet网络中的数据层的top label blob对应两个输入层,分别是accuracy层和loss层,那么需要在数据层在插入一层。如下图:

建立网络层

该部分重要的函数有CreateLayer()、AppendBottom()、AppendTop()、SetUp()。

   ...............
//(很大的一个for循环)对每一层处理
for (int layer_id = ; layer_id < param.layer_size(); ++layer_id) {//开始遍历所有层
............
// Setup layer.
//param.layers(i)返回的是关于第当前层的参数:
const LayerParameter& layer_param = param.layer(layer_id);
if (share_from_root) {
............
} else {
/*
*把当前层的参数转换为shared_ptr<Layer<Dtype>>,
*创建一个具体的层,并压入到layers_中
*/
layers_.push_back(LayerRegistry<Dtype>::CreateLayer(layer_param));
}

对于CreateLayer()函数,把解析的当前层调用CreatorRegistry类进行注册,从而获取到当前层。然后会调用AppendBottom()和AppendTop()函数具体创建层结构。

 //下面开始产生当前层:分别处理bottom的blob和top的blob两个步骤
for (int bottom_id = ; bottom_id < layer_param.bottom_size(); ++bottom_id) {
const int blob_id = AppendBottom(param, layer_id, bottom_id,
&available_blobs, &blob_name_to_idx);
need_backward |= blob_need_backward_[blob_id];
}

对于AppendBottom()函数,其作用是为该层创建bottom blob,由于网络是堆叠而成,即:当前层的输出 bottom是前一层的输出top blob,因此此函数并没没有真正的创建blob,只是在将前一层的指针压入到了bottom_vecs_中。

 int num_top = layer_param.top_size();
for (int top_id = ; top_id < num_top; ++top_id) {
AppendTop(param, layer_id, top_id, &available_blobs, &blob_name_to_idx);
...............
}

对于AppendBottom()函数,其作用是为该层创建top blob,该函数真正的new的一个blob的对象。并将top blob 的指针压入到top_vecs_中。经过这两个函数网络层创建出该层所有的输入、输出blob,接下来就是调用SetUp()函数,正式建立层结构,并为blob分配内存空间。

 //层已经连接完成,开始建立关系
if (share_from_root) {
// Set up size of top blobs using root_net_
const vector<Blob<Dtype>*>& base_top = root_net_->top_vecs_[layer_id];
const vector<Blob<Dtype>*>& this_top = this->top_vecs_[layer_id];
for (int top_id = ; top_id < base_top.size(); ++top_id) {
this_top[top_id]->ReshapeLike(*base_top[top_id]);
}
} else {
layers_[layer_id]->SetUp(bottom_vecs_[layer_id], top_vecs_[layer_id]);
} //SetUp()函数的具体内容
void SetUp(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
InitMutex();
CheckBlobCounts(bottom, top);
LayerSetUp(bottom, top);
Reshape(bottom, top);
SetLossWeights(top);
}

对于SetUp()函数,包含了CheckBlobCounts()、LayerSetUp()、SetLossWeights()、Reshape()等子函数,CheckBlobCounts()函数式读取Blob的数量,LayerSetUp()和Reshape()是虚函数,会在相应的层中实现这两个函数,SetLossWeights(top)函数会把top(输出blob)的loss weight进行初始化,loss weight是用来表示不同Layer产生的loss的重要性,Layer名称中以Loss结尾表示这是一个会产生loss的Layer,其他的Layer只是单纯的用于中间计算,同时每一层的loss值就是所有输出top blob的loss值的和。到此当前层的结构建立完成。经过多次循环,就可以构建整个网络。

确定网络层需要计算的blob

该部分的作用是确定哪些层或哪些层的blob需要参与计算,比如前向时需要确定哪些层的blob需要计算loss,后向时确定哪些层的blob需要计算diff。一个layer是否需要backward computation,主要依据两个方面:

(1)该layer的top blob 是否参与loss的计算;

(2)该layer的bottom blob 是否需要backward computation,比如Data层一般就不需要backward computation

对于前向的过程,部分源码如下:

     ..............
for (int param_id = ; param_id < num_param_blobs; ++param_id) {
const ParamSpec* param_spec = (param_id < param_size) ?
&layer_param.param(param_id) : &default_param_spec;
const bool param_need_backward = param_spec->lr_mult() != ;
need_backward |= param_need_backward;
layers_[layer_id]->set_param_propagate_down(param_id, param_need_backward);
}
for (int param_id = ; param_id < num_param_blobs; ++param_id) {
...........
AppendParam(param, layer_id, param_id);
}

AppendParam()函数的作用是记录带有参数的层或者blob,对于某些有参数的层,例如:卷基层、全连接层有weight和bias。该函数主要是修改和参数有关的变量,实际的层参数的blob在上面提到的setup()函数中已经创建。对于后向的过程和前向类似,部分源码如下:

 if (param.force_backward()) {
for (int layer_id = ; layer_id < layers_.size(); ++layer_id) {//迭代所有层
layer_need_backward_[layer_id] = true;//需要参与backward
for (int bottom_id = ;
bottom_id < bottom_need_backward_[layer_id].size(); ++bottom_id) {//每一层下的需要计算diff的所有blob
bottom_need_backward_[layer_id][bottom_id] =
bottom_need_backward_[layer_id][bottom_id] ||
layers_[layer_id]->AllowForceBackward(bottom_id);
blob_need_backward_[bottom_id_vecs_[layer_id][bottom_id]] =
blob_need_backward_[bottom_id_vecs_[layer_id][bottom_id]] ||
bottom_need_backward_[layer_id][bottom_id];
}
for (int param_id = ; param_id < layers_[layer_id]->blobs().size();
++param_id) {//设置不需要计算参数的层
layers_[layer_id]->set_param_propagate_down(param_id, true);
}
}
}

Net的网络层的构建(源码分析)的更多相关文章

  1. MyBatis源码分析(4)—— Cache构建以及应用

    @(MyBatis)[Cache] MyBatis源码分析--Cache构建以及应用 SqlSession使用缓存流程 如果开启了二级缓存,而Executor会使用CachingExecutor来装饰 ...

  2. Flink源码分析 - 源码构建

    原文地址:https://mp.weixin.qq.com/s?__biz=MzU2Njg5Nzk0NQ==&mid=2247483692&idx=1&sn=18cddc1ee ...

  3. Elasticsearch源码分析 - 源码构建

    原文地址:https://mp.weixin.qq.com/s?__biz=MzU2Njg5Nzk0NQ==&mid=2247483694&idx=1&sn=bd03afe5a ...

  4. 构建锁与同步组件的基石AQS:深入AQS的实现原理与源码分析

    Java并发包(JUC)中提供了很多并发工具,这其中,很多我们耳熟能详的并发工具,譬如ReentrangLock.Semaphore,它们的实现都用到了一个共同的基类--AbstractQueuedS ...

  5. 鸿蒙内核源码分析(构建工具篇) | 顺瓜摸藤调试鸿蒙构建过程 | 百篇博客分析OpenHarmony源码 | v59.01

    百篇博客系列篇.本篇为: v59.xx 鸿蒙内核源码分析(构建工具篇) | 顺瓜摸藤调试鸿蒙构建过程 | 51.c.h.o 编译构建相关篇为: v50.xx 鸿蒙内核源码分析(编译环境篇) | 编译鸿 ...

  6. AFNetworking源码分析

    来源:zongmumask 链接:http://www.jianshu.com/p/8eac5b1975de 简述 在iOS开发中,与直接使用苹果框架中提供的NSURLConnection或NSURL ...

  7. Kafka服务端之网络连接源码分析

    #### 简介 上次我们通过分析KafkaProducer的源码了解了生产端的主要流程,今天学习下服务端的网络层主要做了什么,先看下 KafkaServer的整体架构图 ![file](https:/ ...

  8. spark源码分析以及优化

    第一章.spark源码分析之RDD四种依赖关系 一.RDD四种依赖关系 RDD四种依赖关系,分别是 ShuffleDependency.PrunDependency.RangeDependency和O ...

  9. ABP源码分析三:ABP Module

    Abp是一种基于模块化设计的思想构建的.开发人员可以将自定义的功能以模块(module)的形式集成到ABP中.具体的功能都可以设计成一个单独的Module.Abp底层框架提供便捷的方法集成每个Modu ...

随机推荐

  1. 浏览器常用12种兼容问题(JS)

    //1.滚动条到顶端的距离(滚动高度) var scrollTop = document.documentElement.scrollTop || document.body.scrollTop; / ...

  2. Django中常用字段

    一.Django框架的常用字段 Django ORM 常用字段和参数 常用字段 常用字段 AutoField int自增列,必须填入参数 primary_key=True.当model中如果没有自增列 ...

  3. github配置及使用

    安装git 对于linux系统,不同发行版本的安装方法不一样,请参考https://git-scm.com/download/linux.以ubuntu为例: sudo add-apt-reposit ...

  4. verilog版插值

      开发环境:IDE:LIBERO 9.0(ACTEL公司的)芯片:AFS600 (BGA256),是混合系列的FPGA仿真软件:modelsim atcel 6.5d综合软件:synplify pr ...

  5. fpga新建nios

    [FPGA]基于Qsys的第一个Nios II系统设计 (2013-12-12 21:50:08) 转载▼ 标签: fpga niosii qsys helloword quartusii 分类: 嵌 ...

  6. 系统环境: CentOS 64位+千万不要在生产环境中升级glibc!

    # cd /lib64# LD_PRELOAD=/lib64/libc-2.15.so ln -sf /lib64/libc-2.15.so libc.so.6 libc-2.15.so 这个文件名根 ...

  7. Batchsize与learning rate

    https://www.zhihu.com/question/64134994 1.增加batch size会使得梯度更准确,但也会导致variance变小,可能会使模型陷入局部最优: 2.因此增大b ...

  8. v-if 和 v-show

    关于条件渲染 所谓条件渲染,就是根据不同的条件,使用不同的模板来生成 html. 在 Vue.js 中,使用 v-if 和 v-show 指令来控制条件渲染. 区别 v-show 会在app初始化的时 ...

  9. man gzip

    GZIP(1)                                                                GZIP(1) NAME/名称       gzip, g ...

  10. vue-cli3.0配置

    仅在项目根目录中新建vue.config.js文件即可,部分配置如下 module.exports = { // 基本路径 baseUrl: '/', // 输出文件目录 outputDir: 'di ...