Caffe2源码解析之core
写在前面
在对Tensorflow的后端源码进行了拆解(参见tensorflow源码解析系列文章索引)之后,很想跟其它深度学习框架的实现进行对比,根据框架的流行程度,先选择了Pytorch。Pytorch的后端核心是直接复用了Caffe2,因此本文针对Caffe2源码的core模块进行了简单拆解。
目录
- 数据存储与表示
- storage
- tensor
- blob
- qtensor
- 操作
- observer observable
- operator
- 操作求导
- operator_schema
- context
- 计算图
- graph
- net
- transform
- 运行时
- allocator
- db
- registry
- module
- scope_guard
- workspace
- init
1. 数据存储与表示
1.1 storage
Caffe2中对数据存储的最底层的描述是Storage,它实际上是指向StorageImpl的共享指针,后者包含数据类型、数据指针、容量、数据所在设备等信息。Storage的定义如下:
using Storage = std::shared_ptr<StorageImpl>;
class StorageImpl {
public:
//...
protected:
using DataPtr = std::shared_ptr<void>;
int64_t capacity_ = 0;
DataType data_type_;
DataPtr data_ptr_;
DeviceType device_type_ = CPU;
};
1.2 tensor
Caffe2中的数据统一使用Tensor表示,Tensor由TensorImpl实现,后者包含一个Storage。
Tensor-->|包含|TensorImpl
TensorImpl-->|包含|Storage
Storage-->|指向|StorageImpl
TensorImpl的定义如下:
class TensorImpl {
public:
//...
protected:
using DimVector = std::vector<TIndex>;
DimVector dims_; //张量的维度
TIndex size_ = -1; //张量中包含的元素数量
Storage storage_; //底层存储
};
Tensor并非继承自TensorImpl,而是在内部包含了一个指向TensorImpl的指针,如下:
class Tensor final {
protected:
using TensorImplPtr = c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
TensorImplPtr impl_;
//...
};
对Tensor的方法调用,通过重定向给TensorImpl实现。
1.3 blob
Blob是一个容器,包含了一个指针和这个指针指向内存的数据类型,在Caffe2中,大部分情况下Blob都包含一个指向Tensor的指针。
class Blob {
public:
//...
private:
TypeMeta meta_;
void* pointer_ = nullptr;
DestroyCall destroy_ = nullptr;
};
为了方便对Blob进行传输,定义了其序列化和反序列化的类,分别是BlobSerializerBase和BlobDeserializerBase,以及对应的为Tensor准备的序列化和反序列化类。
BlobSerializerBase-->|派生|TensorSerializer
BlobDeserializerBase-->|派生|TensorDeserializer
1.4 qtensor
低精度的张量,为了便于快速进行低精度的整数乘法计算。具体的做法是,用更低的位数来表示整数,比如,用3个bit表示无符号整数,用4个bit表示有符号整数。低精度张量可以在略微损失模型精度的情况下,大大降低计算复杂度和存储空间大小。
操作
2.1 Observer Observable
Caffe2使用ObserverBase和Observable两个类实现了观察者模式。ObserverBase是基础观察器,用户可以通过继承此类创建新的观察器,而Observable是可被观察属性,用户可以通过继承此类获得可观察属性。
ObserverBase提供了观察器的统一接口,比较简单:
class ObserverBase {
public:
virtual void Start() {}
virtual void Stop() {}
T* subject() const {
return subject_;
}
protected:
T* subject_;
};
其中,subject_表示被观察对象的指针。
Observable封装了可被观察属性,内部包含了一个观察器的列表,结构如下:
class Observable {
public:
using Observer = ObserverBase<T>;
const Observer* AttachObserver(std::unique_ptr<Observer> observer){} //添加观察器
std::unique_ptr<Observer> DetachObserver(const Observer* observer_ptr){} //解除观察器
virtual size_t NumObservers() {
return num_observers_;
} //观察器的数量
void StartAllObservers(){} //启动所有观察器
void StopAllObservers(){} //关闭所有观察器
private:
Observer* observer_cache_;
size_t num_observers_ = 0;
protected:
std::vector<std::unique_ptr<Observer>> observer_list_; //观察器列表
};
2.2 Operator
Operator代表操作的具体实现,相当于Tensorflow中的kernel。Operator继承自OperatorBase,而后者继承自Observable,所以在Caffe2中,“操作”本质上是一个可观察的对象。
Observable-->|派生|OperatorBase
OperatorBase-->|派生|Operator
OperatorBase类包含了操作需要的基本数据元素和接口:
class OperatorBase {
private:
Workspace* operator_ws_;
std::shared_ptr<const OperatorDef> operator_def_;
DeviceOption device_option_;
std::string engine_;
std::string type_;
vector<const Blob*> inputs_;
vector<Blob*> outputs_;
};
OperatorBase中包含了输入和输出的内存指针,可见,在Caffe2中,Operator本质上是一个运行时的对象,这与Tensorflow中Op的设计理念不同,在Tensorflow中,Op是一个编译时对象,仅规定了操作的类型和目标,并不包含具体数据,具体的计算实际上是通过Kernel完成的。
Operator继承自OperatorBase类:
class Operator : public OperatorBase {
public:
bool Run(int stream_id = 0) final {...}
bool RunAsync(int stream_id = 0) final {...}
virtual bool RunOnDevice() = 0;
};
实际上,Run和RunAsync最终都调用了RunOnDevice,完成实际的计算。
如果我们需要使用一些c10中定义的操作,需要将其转换为在Caffe2中可以调用的操作,可以通过如下的宏进行转换:
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(C10Add, C2MyAddOpName)
上述例子中,我们把一个C10Add操作,包装成C2MyAddOpName操作,供我们使用。为了实现这个功能,Caffe2还提供了一个包装类,C10OperatorWrapper。
2.3 操作求导
为了对操作求导,Caffe2推出了一个导数操作生成类,GradientMakerBase,方便用户定义对于某个操作的导数。类包含的数据成员如下:
//为密集和稀疏的blob提供统一的接口
struct GradientWrapper {
string dense_;
string indices_;
string values_;
inline bool IsDense(){}
inline bool IsSparse(){}
inline bool IsEmpty(){}
};
class GradientMakerBase {
protected:
const OperatorDef& def_;
const vector<GradientWrapper>& g_output_;
vector<GradientWrapper> g_input_;
};
可见,GradientMakerBase仅提供了输入输出,以及原操作。用户可以根据原操作,定制导数。
2.3 operator_schema
OpSchema是对操作的静态描述,相当于Tensorflow中的Op,包含的信息如下:
class OpSchema {
private:
string type_;
string file_;
string doc_;
string onnx_schema_;
std::vector<Argument> args_{};
std::vector<std::pair<const char*, const char*>> input_desc_{};
std::vector<std::pair<const char*, const char*>> output_desc_{};
int line_ = 0;
int min_input_ = 0;
int max_input_ = std::numeric_limits<int>::max();
int min_output_ = 0;
int max_output_ = std::numeric_limits<int>::max();
bool private_ = false;
bool inputs_can_cross_devices_ = false;
std::function<bool(int)> num_inputs_allowed = [](int) { return true; }
std::function<bool(int)> num_outputs_allowed = [](int) { return true; }
std::function<bool(int,int)> num_inputs_outputs_allowed_ = [](int,int) { return true; }
std::function<int(int)> calculate_output_;
std::function<bool(int,int)> inplace_allowed_ = [](int,int){}
std::function<bool(int,int)> inplace_enforced_ = [](int,int){}
TensorInferenceFunctionType tensor_inference_function_ = {...}
std::unique_ptr<CostInferenceFunctionType> cost_inference_function_ = nullptr;
DeviceInferenceFunctionType device_inference_function_ = {...}
};
另外Caffe2也提供了一个对于OpSchema的注册类OpSchemaRegistry,如下:
class OpSchemaRegistry {
private:
static CaffeMap<string, OpSchema>& map();
};
2.4 context
Caffe2中的context,其实就是Tensorflow中的OpKernelContext,为操作的实际计算提供通用的支持,主要包含内存拷贝的接口。所有实际的Context类必须继承自BaseContext,而Caffe2为我们准备了一个标准的Context接口,CPUContext类。另外,也同样为GPU准备了一个CUDAContext类。
BaseContext-->|派生|CPUContext
BaseContext-->|派生|CUDAContext
3. 计算图
3.1 graph
Graph表示图的结构,图包含节点,节点包含操作。
Graph-->|包含|Node
Node-->|包含|OperatorDef
Node包含的数据成员:
class Node {
public:
OperatorDef op;
bool active = true; //操作是否被transformation删除
std::map<int, std::vector<string>> parents;
std::vector<int, std::vector<string>> children;
}
Graph包含的私有数据成员:
class Graph {
private:
NetDef netdef_;
std::set<string> external_input_;
std::set<string> external_output_;
std::vector<Node> nodes_;
}
3.2 net
Net是一个可运行的Graph,包含了一个图的所有“操作”,以及它们的上下文。它继承自Observable,本质上是一个可观察的对象。数据成员如下:
class NetBase : public Observable<NetBase>{
public:
virtual bool Run(){...}
virtual bool RunAsync();
protected:
vector<string> external_input_;
vector<string> external_output_;
string name_;
vector<const Event*> events_;
std::shared_ptr<const NetDef> net_def_;
};
NetBase派生出了三种子类,第一种是AsyncNetBase,它包含了异步执行网络所必须的数据和接口:
class AsyncNetBase : public NetBase {
public:
bool RunAsync() override;
protected:
bool canSchedule(...);
std::vector<OperatorBase*> operators_;
std::vector<dag_utils::OperatorNode> operator_nodes_;
std::vector<std::vector<int>> chains_;
std::vector<dag_utils::OpGraphNode> chain_nodes_;
dag_utils::ExecutionChains execution_chains_;
};
第二种是SimpleNet,它表示了一种对图的单线程的顺序执行模式。
第三种是DAGNetBase,它表示了一种对图的多线程的dag执行模式。
相关的net类形成了一个继承体系:
Observable-->|派生|NetBase
NetBase-->|派生|AsyncNetBase
AsyncNetBase-->|派生|AsyncSchedulingNet
NetBase-->|派生|DAGNetBase
DAGNetBase-->|派生|DAGNet
NetBase-->|派生|SimpleNet
DAGNetBase-->|派生|AsyncDAGNet
AsyncNetBase-->|派生|AsyncPollingNet
3.3 transform
transform是一种针对Caffe2的NetDef结构的操作,它将NetDef作为输入,输出新的经过变换的NetDef。它的工作步骤包括:
- 从旧的NetDef中构建一张图,这张图中保存了节点的连接信息;
- 在图中匹配指定的模式,找到它想要更改的子图;
- 用新的操作替换匹配到的子图;
- 根据图构建一个新的NetDef并返回;
Transform功能的实现,依赖于三个功能函数,如下:
- PatternRule(模式规则),它决定了对于一张子图和一个节点,是否可以将这个节点加入这个子图中;
- ValidatorRule(验证规则),它决定了一张子图是否是匹配的;
- ReplaceRule(替换规则),它对一张匹配的子图进行替换;
常用的模式如下:
- CONNECTED_SUBGRAPH,连接子图,它只能匹配连接的子图。比如对于图(1)-->(2)-->(3)-->(4),它能够匹配到[2,3]和[4,3],但不能匹配到[2,4];
- SORTED_WRT_EXECUTION_ORDER,执行序模式,它只能匹配符合执行顺序的子图,节点之间不一定需要有连接,它比General模式要快,例如对于图(1)-->(2)-->(3)-->(4),它可以匹配到[2,4],[3,4],但不能匹配到[3,1],[4,3];
- GENERAL,它可以匹配到任何子图,比如,对于图(1)-->(2)-->(3)-->(4)来说,它可以匹配到子图[2,4],[3,4],[4,2,1]等;
4. 运行时
4.1 allocator
内存分配器。
CPUAllocator-->|派生|DefaultCPUAllocator
CPUAllocator-->|派生|PinnedCPUAllocator
4.2 db
DB类是对kv存储的抽象。包含了用于读取DB数据的Cursor类,用于写DB数据的Transaction类,DB读取的包裹类DBReader,对DBReader进行序列化和反序列化的DBReaderSerializer和DBReaderDeserializer类。
DB-->|读数据时的游标类|Cursor
DB-->|写数据时的事务类|Transaction
DB-->|读数据包装|DBReader
DBReader-->|序列化|DBReaderSerilizer
DBReader-->|反序列化|DBReaderDeserilizer
4.3 registry
注册类,key为字符串,value可以为任意的类。结构如下:
class Registry {
private:
CaffeMap<SrcType, Creator> registry_;
CaffeMap<SrcType, string> help_message_;
};
4.4 module
查看Caffe2已载入的模块,以及载入指定模块。模块指的是动态链接库。
4.5 scope_guard
是“初始化即资源获取”原语的实现,它保证了,如果不显式说明,函数的执行就会离开当前的scope。
4.6 workspace
Workspace包含了所有的运行时对象,包括blob和net,它是所有这些对象的拥有者,负责对这些对象进行管理。
class Workspace {
private:
typedef CaffeMap<string, unique_ptr<Blob>> BlobMap;
BlobMap blob_map_;
typedef CaffeMap<string, unique_ptr<NetBase>> NetMap;
NetMap net_map_;
const string root_folder_;
const Workspace* shared_;
std::unordered_map<string, std::pair<const Workspace*, string>> forwarded_blobs_;
std::unique_ptr<ThreadPool> thread_pool_;
std::mutex thread_pool_creation_mutex_;
std::shared_ptr<Bookkeeper> bookkeeper_;
};
4.7 init
初始化整个Caffe2的运行环境,运行机制是,把需要在环境初始化中运行的函数注册到注册器中,初始化时,会在不同时期运行不同注册器中的函数。核心的函数如下:
CAFFE2_API bool GlobalInit(int* pargc, char*** argv);
整个初始化过程分为三步:
- 先运行通过REGISTER_CAFFE2_EARLY_INIT_FUNCTION注册的函数;
- 再解析Caffe的命令行参数,并启动日志记录系统;
- 最后运行通过REGISTER_CAFFE2_INIT_FUNCTION注册的函数;
写在后面
我在github上新建了一个repo,pytorch_notes,欢迎大家点星星。
Caffe2源码解析之core的更多相关文章
- Caffe2源码解析
写在前面 上一篇文章对Caffe2中的core模块进行了简单拆解Caffe2源码解析之core,本篇给出其它模块的拆解,目的是大致了解每个模块的内容和目标,进一步理解Caffe2的整体框架.内容不多, ...
- .NET Core实战项目之CMS 第三章 入门篇-源码解析配置文件及依赖注入
作者:依乐祝 原文链接:https://www.cnblogs.com/yilezhu/p/9998021.html 写在前面 上篇文章我给大家讲解了ASP.NET Core的概念及为什么使用它,接着 ...
- .Net Core缓存组件(Redis)源码解析
上一篇文章已经介绍了MemoryCache,MemoryCache存储的数据类型是Object,也说了Redis支持五中数据类型的存储,但是微软的Redis缓存组件只实现了Hash类型的存储.在分析源 ...
- .Net Core中的配置文件源码解析
一.配置简述 之前在.Net Framework平台开发时,一般配置文件都是xml格式的Web.config,而需要配置其他格式的文件就需要自己去读取内容,加载配置了..而Net Core支持从命令行 ...
- .Net Core 认证系统之Cookie认证源码解析
接着上文.Net Core 认证系统源码解析,Cookie认证算是常用的认证模式,但是目前主流都是前后端分离,有点鸡肋但是,不考虑移动端的站点或者纯管理后台网站可以使用这种认证方式.注意:基于浏览器且 ...
- .Net Core 认证系统之基于Identity Server4 Token的JwtToken认证源码解析
介绍JwtToken认证之前,必须要掌握.Net Core认证系统的核心原理,如果你还不了解,请参考.Net Core 认证组件源码解析,且必须对jwt有基本的了解,如果不知道,请百度.最重要的是你还 ...
- abp vnext2.0核心组件之.Net Core默认DI组件切换到AutoFac源码解析
老版Abp对Castle的严重依赖在vnext中已经得到了解决,vnext中DI容器可以任意更换,为了实现这个功能,底层架构相较于老版abp,可以说是进行了高度重构.当然这得益于.Net Core的D ...
- .Net Core 中间件之主机地址过滤(HostFiltering)源码解析
一.介绍 主机地址过滤中间件相当于一个白名单,标记哪些主机地址能访问接口. 二.使用 新建WebAPI项目,修改Startup中的代码段如下所示.下面表示允许主机名为“localhost”的主机访问( ...
- .Net Core缓存组件(MemoryCache)源码解析
一.介绍 由于CPU从内存中读取数据的速度比从磁盘读取快几个数量级,并且存在内存中,减小了数据库访问的压力,所以缓存几乎每个项目都会用到.一般常用的有MemoryCache.Redis.MemoryC ...
随机推荐
- NSMutableString常用方法
1.NSMutableString常用方法 - (void)appendString:(NSString *)aString; 拼接aString到最后面 NSMutableString *strM ...
- FileInputStream 类与 FileReader 类的区别
FileInputStream 类与 FileReader 类的区别: 两个类的构造函数的形式和参数都是相同的,参数为 File 对象或者表示路径的 String ,它们到底有何区别呢? FileIn ...
- nginx实现后端tomcat的负载均衡调度
一.编译安装nginx 请跳转查阅:编译安装nginx 二.tomcat的二进制安装 请跳转查阅:二进制安装tomcat 三.配置nginx代理 # cat /apps/nginx/conf/ngin ...
- pytest(4)-测试用例执行顺序
前言 上一篇文章我们讲了在pytest中测试用例的命名规则,那么在pytest中又是以怎样的顺序执行测试用例的呢? 在unittest框架中,默认按照ACSII码的顺序加载测试用例并执行,顺序为:09 ...
- Java面试-常见基础笔试题目
1. Vector, ArrayList, LinkedList的区别 Vector:①长度可变,以类似数组的形式存储于内存中,线程安全(同步),因为其大部分方法都有synchronized 关键字, ...
- python使用泛型
所谓的泛型, 就是将数据类型作为参数进行传递, 即在我们用的时候确定数据类型, 这是一种在面向对象语言中经常使用的特性 一般类使用 以SQLAlchemy举例 比如: 我们统一写个将数据保存到数据库的 ...
- Solution -「洛谷 P6577」「模板」二分图最大权完美匹配
\(\mathcal{Description}\) Link. 给定二分图 \(G=(V=X\cup Y,E)\),\(|X|=|Y|=n\),边 \((u,v)\in E\) 有权 \(w( ...
- Spring Boot AOP 扫盲,实现接口访问的统一日志记录
AOP 是 Spring 体系中非常重要的两个概念之一(另外一个是 IoC),今天这篇文章就来带大家通过实战的方式,在编程猫 SpringBoot 项目中使用 AOP 技术为 controller 层 ...
- Centos7.+系统,二进制包脚本安装Mysql
#!/bin/bash #配置数据库要安装的目录,可以根据自己的安装路径修改PATHDIRPATHDIR="/project"BASEDIR="$PATHDIR/mysq ...
- IPv6 OSPFv3路由协议测试——信而泰网络测试仪实操
关键词 IPv6; OSPFv3; OSPFv2. 前言:在网络部署中,动态路由协议是重要的组成部分.良好的路由协议,是保证网络可靠.稳定运行的基础.为了适应IPv6协议栈的变化,OSPFv3协议在保 ...