目录

  1. 什么是op
  2. op_def定义
  3. op注册
  4. op构建与注册辅助结构
  5. op重写
  6. 关系图
  7. 涉及的文件
  8. 迭代记录

1. 什么是op

op和kernel是TF框架中最重要的两个概念,如果一定要做一个类比的话,可以认为op相当于函数声明,kernel相当于函数实现。举个例子,对于矩阵相乘,我可以声明一个op叫做MatMul,指明它的名称,输入,输出,参数,以及对参数的限制等。op只是告诉我们,这个操作的目的是什么,操作内部有哪些可定制的东西,但不会提供具体实现。操作在某种设备上的具体实现方法,是由kernel决定的。TF的计算图由节点构成,而每个节点对应了一个op,在构建计算图时,我们只知道不同节点对应的操作是什么,而不知道运行时这个操作是怎样实现的。也就是说,op是编译期概念,而kernel是运行期概念

那为什么要把操作和它的实现分离呢?是为了实现TF代码的可移植性。我们可以把TF构建的计算图想象为Java的字节码,而计算图在执行的时候,需要考虑可用的设备资源,相当于我们在运行Java字节码的时候,需要考虑当前所在的操作系统,选择合适的字节码实现。因为TF的目标是在多设备上运行,但我们在编码的时候,是无法预先知道某一个操作具体是在哪种设备上运行的,因此,将操作和它的实现分离,可以让我们在设计计算图的时候,更专注于它的结构,而不是具体实现。当我们构建完成一个计算图之后,在一个包含GPU的设备上,它可以利用对应操作在GPU上的kernel,充分利用GPU的高计算性能,在一个仅包含CPU的设备上,它也可以利用对应操作在CPU上的kenrel,完成计算功能。这就提高了TF代码在不同设备之间的可移植性。

2. op_def定义

由于仅是操作的声明,OpDef不需要包含太多的API,它被定义在一个proto中。由于这个概念极端重要,我们在这里完整列出它的代码:

message OpDef {
string name = 1;//操作的名称
message ArgDef { //对输入输出的定义
string name = 1;
string description = 2;
DataType type = 3;//以下4个字段说明了数据的类型,详见正文
string type_attr = 4;
string number_attr = 5;
string type_list_attr = 6;
bool is_ref = 16;//输入或输出是否为引用
};
repeated ArgDef input_arg = 2;//输入描述
repeated ArgDef output_arg = 3;//输出描述
message AttrDef {
string name = 1;
string type = 2;
AttrValue default_value = 3;
string description = 4;
bool has_minimum = 5;
int64 minumum = 6;
AttrValue allowed_values = 7;
}
repeated AttrDef attr = 4;
OpDeprecation deprecation = 8;
string summary = 5;
string description = 6;
bool is_commutative = 18;//是否可交换,即op(a,b) == op(b,a)
bool is_aggregate = 16;//是否可聚集
bool is_stateful = 17;//是否带有状态
bool allows_uninitialized_input = 19;//针对赋值操作
};
message OpDeprecation {
int32 version = 1;
string explanation = 2;
};
message OpList {
repeated OpDef op = 1;
};

我们看到,OpDef中最核心的数据成员是操作名称、输入、输出、参数。其中的参数怎样理解呢?我们之前提到op相当于函数声明,这个函数是带参数的,具体使用该操作时,我们需要给参数赋予实际的数值,这个在接下来分析node_def时会详细讲到。

对于其中的几个难理解的点,作出说明:

  • ArgDef中的3-6四个字段,是为了描述输入或输出的类型。当输入或输出是一个张量时,type或type_attr被设置为这个张量的数据类型,当输入或输出是一个由相同数据类型的张量构成的序列时,number_attr被设置为int对应的标识,当输入或输出是一个由张量构成的列表时,type_list_attr被设置为list(type)对应的标识;
  • AttrDef中的has_minimum字段,表明这个属性是否有最小值,如果数据类型是int,那么minimum就是允许的最小值,如果数据类型是列表,那么minimum就是列表的最短长度;
  • is_aggregate这个字段,表明当前的操作是否是可聚集的,一个可聚集的操作是,能接受任意数量相同类型和形状的输入,并且保持输出与每个输入的类型和形状相同,这个字段对于操作的优化非常重要,如果一个操作是可聚集的,并且其输入来自多个不同的设备,那么我们就可以把聚集优化成一个树形的操作,先在设备内部对输入做聚集,最后在操作所在的设备集中,这样可以提高效率。这种优化对于分布式的机器学习模型训练非常有帮助,Spark ML中的TreeAggregate就实现了这样的优化。可惜截止笔者看到的TF1.2版本,还没有实现这个优化;
  • is_stateful这个字段,表明当前的操作是否是带有状态的,什么操作会带有状态呢?比如Variable;

为了方便进行OpDef的构建,TF还设计了OpDefBuilder类,它的私有数据成员如下:

class OpDefBuilder {
//...
private:
OpRegistrationData op_reg_data_;
std::vector<string> attrs_;
std::vector<string> inputs_;
std::vector<string> outputs_;
string doc_;
std::vector<string> errors_;
}

可以看到,除了errors_字段之外,其它内容几乎就是把OpDef的结构原封不动的搬了过来。这里面我们发现了一个新的结构,OpRegistrationData,它的结构如下:

struct OpRegistrationData {
public:
//...
OpDef op_def;
OpShapeInferenceFn shape_inference_fn;
}

在这个结构中,除了我们熟知的OpDef之外,还包含了一个OpShapeInferenceFn结构,它的定义如下:

typedef std::function<Status(shape_inference::InferenceContext* c)> OpShapeInferenceFn;

这个结构的定义中,涉及到了我们后面要讲到的形状推断的内容,这里我们只需要知道,OpShapeInferenceFn是一个帮助操作根据输入形状对输出形状进行推断的函数即可。

3. op注册

为了方便对操作进行统一管理,TF提出了操作注册器的概念。对于核心数据的统一管理类型,我们并不陌生,回想之前介绍的ResourceMgr和AllocatorRegistry,原理如出一辙。因此,这个操作注册器的作用,就是为各种操作提供一个统一的管理接口。

操作注册类的继承结构如下:

graph TB
OpRegistryInterface-->|派生|OpRegistry
OpRegistryInterface-->|派生|OpListOpRegistry

其中,OpRegistryInterface是一个接口类,它提供了注册类最基础的查找功能:

class OpRegistryInterface {
public:
//...
//操作查找方法
virtual Status LookUp(const string& op_type_name, const OpRegistrationData** op_reg_data) const = 0;
Status LookUpOpDef(const string& op_type_name, const OpDef** op_def) const;
}

OpRegistry就是操作注册器,它的核心接口和数据如下:

class OpRegistry : public OpRegistryInterface {
public:
typedef std::function<Status(OpRegistrationData*)> OpRegistrationFactory;
void Register(const OpRegistrationDataFactory& op_data_factory);//操作注册
static OpRegistry* Global();//返回一个全局静态对象
typedef std::function<Status<const Status&, const OpDef&)> Watcher;
Status SetWatcher(const Watcher& watcher);
private:
mutable mutex mu_;
mutable std::vector<OpRegistrationDataFactory> deferred_ GUARDED_BY(mu_);
mutable std::unordered_map<string, const OpRegistrationData*> registry_ GUARDED_BY(mu_);
mutable bool initialized_ GUARDED_BY(mu_);
mutable Watcher watcher_ GUARDED_BY(mu_);
}

这里面有几个有意思的地方:

  • 注册函数Register的输入,是一个函数引用,这个函数接收一个OpRegistrationData指针作为输入,那么这个函数引用的作用究竟是什么呢?它的源代码如下,原来,我们先建立了一个OpRegistrationData的对象,然后将它作为参数传入op_data_factory函数,这个函数会帮我们填充对象的内容,然后再用这个对象的信息进行注册;
Status OpRegistry::RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory) const {
std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData);
Status s = op_data_factory(op_reg_data.get())
//...
}
  • Watcher是一个监视器,每次当我们注册了一个操作的时候,在注册步骤的最后都要调用一下这个Watcher函数,它可以方便我们对注册的操作进行监控,所有的操作注册动作都逃不过它的眼睛,我们可以根据自己的需要定制Watcher;
  • registry_是已注册的操作真正存放的位置,它的结构很简单,是一个操作名到操作数据的映射;
  • initialized_和deferred_是与注册模式相关的两个数据成员,注册模式的概念接下来将会详细阐述;

注册器在注册操作时,分为两种模式,一种是即时注册模式,一种是懒惰注册模式。注册模式通过initialized_字段区分,true表示即时注册模式,false表示懒惰注册模式。在懒惰注册模式中,带注册的操作先被保存在deferred_向量中,在特定的函数调用时再将deferred_中的操作注册到registry_,而即时注册模式下,待注册的操作不用经过deferred_,直接注册到registry_。设计懒惰注册模式的原因是,我们希望部分操作组合的注册是原子的,即要么全部注册,要么全部不注册,因为这些操作之间可能会有相互依赖关系。

为了更加透彻的理解注册模式的转换,我们绘制了OpRegistry类中,与注册相关的函数的调用关系,以及对initialized_的修改如下:

graph TB
LookUpDef-->LookUp
Register-->RegisterAlreadyLocked
LookUp-->MustCallDeferred
GetRegisteredOps-->MustCallDeferred
Export-->MustCallDeferred
ProcessRegistrations-->CallDeferred
DebugString-->Export
MustCallDeferred-->RegisterAlreadyLocked
CallDeferred-->RegisterAlreadyLocked
DeferRegistrations-.设置为false.->initialized_
MustCallDeferred-.设置为true.->initialized_
CallDeferred-.设置为true.->initialized_
OpRegistry-.设置为false.->initialized_

构造函数将initialized_设置为false,进入懒惰注册模式,随后一旦调用了MustCallDeferred或者CallDeferred中的任意一个,都会将initialized_设置为true,进入即时注册模式。想要重新返回懒惰注册模式也很简单,只需要调用DeferRegistrations即可。

最后简单介绍一下OpListRegistry,它允许我们用OpList初始化一个注册器,请注意,OpList仅仅是OpDef的列表,它并不包含形状推断函数这个信息,因此这个注册器中的操作,是不包含形状推断函数的。如果我们要查找的操作不需要形状推断函数,就可以使用这个注册器。它的私有数据如下:

class OpListOpRegistry : public OpRegistryInterface {
public:
//...
private:
std::unordered_map<string, const OpRegistrationData*> index_;
}

4. op构建与注册辅助结构

为了方便对操作的注册,TF提出了专为注册操作的宏,举例如下:

REGISTER_OP("my_op_name")
.Attr("<name>:<type>")
.Attr("<name>:<type>=<default>")
.Input("<name>:<type-expr>")
.Output("<name>:<type-expr>")
.Doc(R"(
<1-line summary>
<rest of the description (potensitally many lines)>
...
)");

这种写法大大方便了注册操作的过程。但想要实现这种宏操作,目前的类还满足不了。TF设计了两个类来实现这个功能,一个类为op的构建提供链式语法支持,另外一个类接受op构建结果,提供操作注册功能。这两个类分别是OpDefBuilderWrapper和OpDefBuilderReceiver。我们先来看前者:

class OpDefBuilderWrapper<true> {
public:
OpDefBuilderWrapper(const char name[]) : builder_(name){}
OpDefBuilderWrapper<true>& Attr(StringPiece spec){
builder_.Attr(spec);
return *this;
}
//...
private:
mutable ::tensorflow::OpDefBuilder builder_;
}

有两点比较有意思,首先顾名思义这个类基本上是对OpDefBuilder的一个封装,提供了几乎完全一致的API;其次,它的API都是设置型,且都返回对象本身,这就为链式的属性设置提供了可能。值得注意的是,这个类名后面跟着一个true,它的含义我们待会儿揭晓。

再来看看OpDefBuilderReceiver:

struct OpDefBuilderReceiver {
OpDefBuilderReceiver(const OpDefBuilderWrapper<true>& wrapper);
constexpr OpDefBuilderReceiver(const OpDefBuilderWrapper<false>&){}
};

它提供的构造函数,以OpDefBuilderWrapper作为输入参数,也就是说,我们可以通过赋值构造把后者直接赋值给前者,看下REGISTER_OP的宏定义:

//为了忽略不必要的细节,以下代码做了适当删减
#define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name)
#define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name)
#define REGISTER_OP_UNIQ(ctr, name) \
static OpDefBuilderReceiver register_op##ctr = OpDefBuilderWrapper<SHOULD_REGISTOR_OP(name)>(name)

我们发现,REGISTER_OP绕了一圈,最终就是先用OpDefBuilderWrapper对操作进行封装,然后把它作为参数传递给OpDefBuilderReceiver的构造函数,而在这个构造函数中,完成了对操作的注册:

OpDefBuilderReceiver::OpDefBuilderReceiver(const OpDefBuilderWrapper<true>& wrapper) {
OpRegistry::Global()->Register([wrapper](OpRegistrationData* op_reg_data) -> Status {
return wrapper.builder().Finalize(op_reg_data);
});
}
}

最后我们来解释下刚才卖的关子,OpDefBuilderWrapper<true>后面的这个true到底代表什么。我们知道,TF为我们准备了很多的操作,但有些时候我们可能用不着这所有的操作,仅需要其中一部分。如果不加限制全部编译,会给我们的运行时库带来很大的负担。因此,TF允许我们添加一个头文件,用宏SHOULD_REGISTOR_OP定义我们想要导出的操作,比如:

#define SHOULD_REGISTOR_OP(Add) true
#define SHOULD_REGISTOR_OP(Subtract) false

表示我们希望导出Add操作,但希望屏蔽Subtract操作。这样就能够根据需要定制自己的TF运行时库了。因此源代码中除了这个OpDefBuilderWrapper<true>类之外,还有一个OpDefBuilderWrapper<false>类,最后,有些操作系统是必须要导出的,比如一些内部操作,TF为此设计了另外一个宏,可以无视SHOULD_REGISTOR_OP的宏定义,感兴趣的读者可以去看下源代码。

5. op重写

随着TF的不断拓展,操作本身也在不断的迭代,比如重命名。为了与已有的图实现向前兼容,TF提出了OpGenOverrides的结构,如下:

message OpGenOverride {
string name = 1;
bool skip = 2;//直接废弃这个操作
bool hide = 3;//对外隐藏
string rename_to = 4;
repeated string alias = 5;//更新API的名称
message AttrDefault {
string name = 1;
AttrValue value = 2;
}
repeated AttrDefault attr_default = 6;//修改参数默认值
message Rename {
string from = 1;
string to = 2;
}
repeated Rename attr_rename = 7;
repeated Rename input_rename = 8;
repeated Rename output_rename = 9;
}
message OpGenOverrides {
repeated OpGenOverride op = 1;
}

具体的替换操作是由OpGenOverrideMap这个类实现的,它读入一系列包含OpGenOverrides proto的文本文件,然后允许你查找针对每个已有操作的迭代:

class OpGenOverrideMap {
public:
Status LoadFile(Env* env, const string& filenames);
const OpGenOverride* ApplyOverride(OpDef* op_def) const;
private:
std::unordered_map<string, std::unique_ptr<OpGenOverride>> map_;
};

6. 关系图

graph TB
OpDefBuilder-.包含.->OpRegistrationData
OpRegistrationData-.包含.->OpDef
OpDefBuilder-.构建.->OpDef
OpRegistryInterface-->|派生|OpRegistry
OpRegistryInterface-->|派生|OpListOpRegistry
OpDefBuilder-.包裹.->OpDefBuilderWrapper
OpDefBuilderWrapper-.传递给.->OpDefBuilderReceiver
OpDefBuilderReceiver-.注册.->OpRegistry

7. 涉及的文件

  • op
  • op_def_builder
  • op_def
  • op_gen_lib
  • op_gen_overrides

8. 迭代记录

  • v1.0 2018-08-26 文档创建
  • v2.0 2018-09-09 文档重构

github地址

tensorflow源码解析之framework-op的更多相关文章

  1. tensorflow源码解析之framework拾遗

    把framework中剩余的内容,按照文件名进行了简单解析.时间原因写的很仓促,算是占个坑,后面有了新的理解再来补充. allocation_description.proto 一个对单次内存分配结果 ...

  2. tensorflow源码解析系列文章索引

    文章索引 framework解析 resource allocator tensor op node kernel graph device function shape_inference 拾遗 c ...

  3. Tensorflow源码解析1 -- 内核架构和源码结构

    1 主流深度学习框架对比 当今的软件开发基本都是分层化和模块化的,应用层开发会基于框架层.比如开发Linux Driver会基于Linux kernel,开发Android app会基于Android ...

  4. tensorflow源码解析之common_runtime-executor-上

    目录 核心概念 executor.h Executor NewLocalExecutor ExecutorBarrier executor.cc structs GraphView ExecutorI ...

  5. tensorflow源码解析之common_runtime-executor-下

    目录 核心概念 executor.h Executor NewLocalExecutor ExecutorBarrier executor.cc structs GraphView ExecutorI ...

  6. tensorflow源码解析之framework-allocator

    目录 什么是allocator 内存分配器的管理 内存分配追踪 其它结构 关系图 涉及的文件 迭代记录 1. 什么是allocator Allocator是所有内存分配器的基类,它定义了内存分配器需要 ...

  7. Tensorflow源码解析2 -- 前后端连接的桥梁 - Session

    Session概述 1. Session是TensorFlow前后端连接的桥梁.用户利用session使得client能够与master的执行引擎建立连接,并通过session.run()来触发一次计 ...

  8. tensorflow源码解析之distributed_runtime

    本篇主要介绍TF的分布式运行时的基本概念.为了对TF的分布式运行机制有一个大致的了解,我们先结合/tensorflow/core/protobuf中的文件给出对TF分布式集群的初步理解,然后介绍/te ...

  9. tensorflow源码解析之common_runtime拾遗

    把common_runtime中剩余的内容,按照文件名排序进行了简单的解析,时间原因写的很仓促,算是占个坑,后续有了新的理解再来补充. allocator_retry 有时候内存分配不可能一次完成,为 ...

随机推荐

  1. 为CentOS 6、7升级gcc至4.8、4.9、5.2、6.3、7.3等高版本

    CentOS 7虽然已经出了很多年了,但依然会有很多人选择安装CentOS 6,CentOS 6有些依赖包和软件都比较老旧,如今天的主角gcc编译器,CentOS 6的gcc版本为4.4,CentOS ...

  2. Unsupported major.minor version 52.0报错问题解决方案

    感谢原文:https://blog.csdn.net/wangmaohong0717/article/details/82869359 1.问题描述 工程启动的时候,报错如下: nested exce ...

  3. Visual Studio 下error C2471: 无法更新程序数据库

    转载请注明来源:https://www.cnblogs.com/hookjc/ 解决方案:修改项目属性 右击项目 --> "属性" 1. "C/C++" ...

  4. 关于linux shell编程,alias rm='cp $@ ~/backup; rm $@'

    书上的这个例子需要在ubuntu的低版本的系统才支持,现在基本上都不支持了,想实现也很简单自己写一个脚本先备份再删除. alias也只是做了一次替换alias rm='cp $@ ~/backup; ...

  5. xcode 常用插件 加快开发速度 --严焕培

    1.KSImageNamed-Xcode 为项目中使用的UIImage的imageNamed提供文件名自动补全功能.使用[UIImage imageNamed:@"xxx"]时,该 ...

  6. Express中使用session

    1.安装express-session npm install express-session --save-dev //注意-g无效 2.app.jsvar session = require('e ...

  7. 如何看懂时序图,以DHT21为例

    有很多传感器手册给了我们时序图,我们只要按照时序图操作就行了,还有一些是标准接口,例如SPI,IIC,UART,这些可以利用硬件提供的收发器通信,还有一些我们没有足够的接口,或者没有对应的接口与之通信 ...

  8. 技术管理进阶——谁能成为Leader,大Leader该做什么

    原创不易,求分享.求一键三连 两个故事 谁能成为Leader 之前接手了一块产品业务线,于是与原Leader说了下分工,大概意思是: 我是过来学习的,也能给团队带来更多的资源,团队内的工作你继续管理, ...

  9. SpringBoot中请求参数 @MatrixVariable 矩阵变量

    一.矩阵变量请求格式 /users;id=1,uname=jack 二.SpringBoot开启矩阵请求 首先查看springboot源码关于矩阵部分的内容 在 WebMvcAutoConfigura ...

  10. [LeetCode]1470. 重新排列数组

    给你一个数组 nums ,数组中有 2n 个元素,按 [x1,x2,...,xn,y1,y2,...,yn] 的格式排列. 请你将数组按 [x1,y1,x2,y2,...,xn,yn] 格式重新排列, ...