概述

网络层的构建是在Net<Dtype>::Init()函数中完成的,构建的流程图如下所示:

从图中可以看出网络层的构建分为三个主要部分:解析网络文件、开始建立网络层、网络层需要参与计算的位置。

解析网络文件

该部分主要有两个函数FilterNet()、InsertSplits()。

 void Net<Dtype>::Init(const NetParameter& in_param) {
CHECK(Caffe::root_solver() || root_net_)
<< "root_net_ needs to be set for all non-root solvers";
// Set phase from the state.
phase_ = in_param.state().phase();
// Filter layers based on their include/exclude rules and
// the current NetState.
NetParameter filtered_param;
FilterNet(in_param, &filtered_param);

FilterNet()的作用是模型参数文件(*.prototxt)中的不符合规则的层去掉。例如:在caffe的examples/mnist中的lenet网络中,如果只是用于网络的前向,则需要将包含train的数据层去掉。

 /*
*调用InsertSplits()函数,对于底层的一个输出blob对应多个上层的情况,
*则要在加入分裂层,形成新的网络。
*函数从filtered_param读入新网络到param
**/
InsertSplits(filtered_param, &param);

InsertSplits()函数的作用是对于底层的一个输出blob对应多个上层的情况,则要在加入分裂层,形成新的网络。这么做的主要原因是多个层反传给该blob的梯度需要累加。例如:LeNet网络中的数据层的top label blob对应两个输入层,分别是accuracy层和loss层,那么需要在数据层在插入一层。如下图:

建立网络层

该部分重要的函数有CreateLayer()、AppendBottom()、AppendTop()、SetUp()。

   ...............
//(很大的一个for循环)对每一层处理
for (int layer_id = ; layer_id < param.layer_size(); ++layer_id) {//开始遍历所有层
............
// Setup layer.
//param.layers(i)返回的是关于第当前层的参数:
const LayerParameter& layer_param = param.layer(layer_id);
if (share_from_root) {
............
} else {
/*
*把当前层的参数转换为shared_ptr<Layer<Dtype>>,
*创建一个具体的层,并压入到layers_中
*/
layers_.push_back(LayerRegistry<Dtype>::CreateLayer(layer_param));
}

对于CreateLayer()函数,把解析的当前层调用CreatorRegistry类进行注册,从而获取到当前层。然后会调用AppendBottom()和AppendTop()函数具体创建层结构。

 //下面开始产生当前层:分别处理bottom的blob和top的blob两个步骤
for (int bottom_id = ; bottom_id < layer_param.bottom_size(); ++bottom_id) {
const int blob_id = AppendBottom(param, layer_id, bottom_id,
&available_blobs, &blob_name_to_idx);
need_backward |= blob_need_backward_[blob_id];
}

对于AppendBottom()函数,其作用是为该层创建bottom blob,由于网络是堆叠而成,即:当前层的输出 bottom是前一层的输出top blob,因此此函数并没没有真正的创建blob,只是在将前一层的指针压入到了bottom_vecs_中。

 int num_top = layer_param.top_size();
for (int top_id = ; top_id < num_top; ++top_id) {
AppendTop(param, layer_id, top_id, &available_blobs, &blob_name_to_idx);
...............
}

对于AppendBottom()函数,其作用是为该层创建top blob,该函数真正的new的一个blob的对象。并将top blob 的指针压入到top_vecs_中。经过这两个函数网络层创建出该层所有的输入、输出blob,接下来就是调用SetUp()函数,正式建立层结构,并为blob分配内存空间。

 //层已经连接完成,开始建立关系
if (share_from_root) {
// Set up size of top blobs using root_net_
const vector<Blob<Dtype>*>& base_top = root_net_->top_vecs_[layer_id];
const vector<Blob<Dtype>*>& this_top = this->top_vecs_[layer_id];
for (int top_id = ; top_id < base_top.size(); ++top_id) {
this_top[top_id]->ReshapeLike(*base_top[top_id]);
}
} else {
layers_[layer_id]->SetUp(bottom_vecs_[layer_id], top_vecs_[layer_id]);
} //SetUp()函数的具体内容
void SetUp(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
InitMutex();
CheckBlobCounts(bottom, top);
LayerSetUp(bottom, top);
Reshape(bottom, top);
SetLossWeights(top);
}

对于SetUp()函数,包含了CheckBlobCounts()、LayerSetUp()、SetLossWeights()、Reshape()等子函数,CheckBlobCounts()函数式读取Blob的数量,LayerSetUp()和Reshape()是虚函数,会在相应的层中实现这两个函数,SetLossWeights(top)函数会把top(输出blob)的loss weight进行初始化,loss weight是用来表示不同Layer产生的loss的重要性,Layer名称中以Loss结尾表示这是一个会产生loss的Layer,其他的Layer只是单纯的用于中间计算,同时每一层的loss值就是所有输出top blob的loss值的和。到此当前层的结构建立完成。经过多次循环,就可以构建整个网络。

确定网络层需要计算的blob

该部分的作用是确定哪些层或哪些层的blob需要参与计算,比如前向时需要确定哪些层的blob需要计算loss,后向时确定哪些层的blob需要计算diff。一个layer是否需要backward computation,主要依据两个方面:

(1)该layer的top blob 是否参与loss的计算;

(2)该layer的bottom blob 是否需要backward computation,比如Data层一般就不需要backward computation

对于前向的过程,部分源码如下:

     ..............
for (int param_id = ; param_id < num_param_blobs; ++param_id) {
const ParamSpec* param_spec = (param_id < param_size) ?
&layer_param.param(param_id) : &default_param_spec;
const bool param_need_backward = param_spec->lr_mult() != ;
need_backward |= param_need_backward;
layers_[layer_id]->set_param_propagate_down(param_id, param_need_backward);
}
for (int param_id = ; param_id < num_param_blobs; ++param_id) {
...........
AppendParam(param, layer_id, param_id);
}

AppendParam()函数的作用是记录带有参数的层或者blob,对于某些有参数的层,例如:卷基层、全连接层有weight和bias。该函数主要是修改和参数有关的变量,实际的层参数的blob在上面提到的setup()函数中已经创建。对于后向的过程和前向类似,部分源码如下:

 if (param.force_backward()) {
for (int layer_id = ; layer_id < layers_.size(); ++layer_id) {//迭代所有层
layer_need_backward_[layer_id] = true;//需要参与backward
for (int bottom_id = ;
bottom_id < bottom_need_backward_[layer_id].size(); ++bottom_id) {//每一层下的需要计算diff的所有blob
bottom_need_backward_[layer_id][bottom_id] =
bottom_need_backward_[layer_id][bottom_id] ||
layers_[layer_id]->AllowForceBackward(bottom_id);
blob_need_backward_[bottom_id_vecs_[layer_id][bottom_id]] =
blob_need_backward_[bottom_id_vecs_[layer_id][bottom_id]] ||
bottom_need_backward_[layer_id][bottom_id];
}
for (int param_id = ; param_id < layers_[layer_id]->blobs().size();
++param_id) {//设置不需要计算参数的层
layers_[layer_id]->set_param_propagate_down(param_id, true);
}
}
}

Net的网络层的构建(源码分析)的更多相关文章

  1. MyBatis源码分析(4)—— Cache构建以及应用

    @(MyBatis)[Cache] MyBatis源码分析--Cache构建以及应用 SqlSession使用缓存流程 如果开启了二级缓存,而Executor会使用CachingExecutor来装饰 ...

  2. Flink源码分析 - 源码构建

    原文地址:https://mp.weixin.qq.com/s?__biz=MzU2Njg5Nzk0NQ==&mid=2247483692&idx=1&sn=18cddc1ee ...

  3. Elasticsearch源码分析 - 源码构建

    原文地址:https://mp.weixin.qq.com/s?__biz=MzU2Njg5Nzk0NQ==&mid=2247483694&idx=1&sn=bd03afe5a ...

  4. 构建锁与同步组件的基石AQS:深入AQS的实现原理与源码分析

    Java并发包(JUC)中提供了很多并发工具,这其中,很多我们耳熟能详的并发工具,譬如ReentrangLock.Semaphore,它们的实现都用到了一个共同的基类--AbstractQueuedS ...

  5. 鸿蒙内核源码分析(构建工具篇) | 顺瓜摸藤调试鸿蒙构建过程 | 百篇博客分析OpenHarmony源码 | v59.01

    百篇博客系列篇.本篇为: v59.xx 鸿蒙内核源码分析(构建工具篇) | 顺瓜摸藤调试鸿蒙构建过程 | 51.c.h.o 编译构建相关篇为: v50.xx 鸿蒙内核源码分析(编译环境篇) | 编译鸿 ...

  6. AFNetworking源码分析

    来源:zongmumask 链接:http://www.jianshu.com/p/8eac5b1975de 简述 在iOS开发中,与直接使用苹果框架中提供的NSURLConnection或NSURL ...

  7. Kafka服务端之网络连接源码分析

    #### 简介 上次我们通过分析KafkaProducer的源码了解了生产端的主要流程,今天学习下服务端的网络层主要做了什么,先看下 KafkaServer的整体架构图 ![file](https:/ ...

  8. spark源码分析以及优化

    第一章.spark源码分析之RDD四种依赖关系 一.RDD四种依赖关系 RDD四种依赖关系,分别是 ShuffleDependency.PrunDependency.RangeDependency和O ...

  9. ABP源码分析三:ABP Module

    Abp是一种基于模块化设计的思想构建的.开发人员可以将自定义的功能以模块(module)的形式集成到ABP中.具体的功能都可以设计成一个单独的Module.Abp底层框架提供便捷的方法集成每个Modu ...

随机推荐

  1. 声明对象的方式/构造函数/原型/this指向

      函数的发展历程(声明函数的方式):     1.通过Object构造函数或字面量的方式创建单个对象 var obj = new Object; obj.name="新华"; o ...

  2. k3 cloud提示超出产品激活有效期

    k3 cloud提示超出产品激活有效期,请联系系统管理员登录管理中心进行产品激活(激活路径:许可中心-许可管理-产品激活) 首先进入管理中心:一次点击许可中心-产品激活 复制激活串号并点击金蝶正版验证 ...

  3. Spring基础08——使用p命名空间

    1.使用p命名空间 为了简化xml配置,越来越多的xml文件采用属性而非子元素配置信息,Spring从2.5版本开始引入了一个新的p命名空间,可以通过<bean>元素属性的方式配置Bean ...

  4. Python自动化学习--异常提示

    举例:打开一个不存在的文件时: >>open("abc.txt","r")会提示错误 Traceback (most recent call las ...

  5. do{}while(0);里面有continue

    do{}while(0);里面有continue,退出的只是do{}while(0);

  6. 以太坊智能合约开发工具 Truffle 入门1

    Truffle是以太坊(Ethereum)智能合约开发的瑞士军刀,小巧好用,上手简单. 本篇文章主要展示如何用Truffle 开发第一个Ethereum智能合约. 1.准备工作:(本人针对window ...

  7. ubuntu 系统问题总结

    一.主题问题 高分辨率显示可能会造成虚拟机中的显示很小,需要调整合适的显示比例.但是 ubuntu 18.04 中的 display 的 scale 只能调整 100% 300%,可能使用 gnome ...

  8. c++ 创建线程用CreateThread后,线程直接就开始执行了吗

    //CreateThread函数的参数原型如下 HANDLE CreateThread( LPSECURITY_ATTRIBUTES lpThreadAttributes, // SD SIZE_T ...

  9. Linux为程序员添加了行为准则

    假如你是开发人员,如果您密切关注Linux开发,您就会知道Linux内核讨论会非常热烈.最近,LinusTorvalds承认Linux内核邮件列表(LKML)和其他Linux开发空间对许多人都是敌对的 ...

  10. 2019年8月19日~8月25日 第八周JAVA学习总结

    临近开学,本周的任务完成情况不够好,平常乱七八糟的事情比较多,所以放在学习上的心思比较少.平均每天放在JAVA学习的时间约1个小时,放在编程的时间约半小时,解决问题的时间约1小时. 下一个星期就要开学 ...