目录

  1. 核心概念
  2. graph_optimizer
  3. function
  4. optimization_registry

1. 核心概念

本篇主要讲图的优化迭代器。我们在构建原始图的时候,专注于达到目的,但不会去考虑图的执行效率。如果把图的设计过程比喻为高级语言的编写,那么图的优化过程就相当于,将高级语言编译为机器语言的过程中,为了能够加速进行的编译优化。比如,将相同的常数折叠,将Identity节点去除等等。本节主要用来讨论,跟图优化相关的类和函数。

2. graph_optimizer

进行图优化,需要有一个统一的入口,它的输入是图本身,以及图执行的环境,以及优化的配置,输出是优化后的图。这个入口就是GraphOptimizer,我们先来看看它的结构和接口:

class GraphOptimizer {
public:
GraphOptimizer(const OptimizerOptions& opts);
void Optimize(FunctionLibraryRuntime* runtime, Env* env, Device* device, std::unique_ptr<Graph>* graph, const std::unordered_map<const Node*, std::vector<PartialTensorShape>>* shape_map);
private:
OptimizerOptions opts_;
};

显然,其中的Optimize就是这个类最重要的API,它将图优化配置opts中的优化过程应用的graph上。可能会将graph替换为另外一个图对象。device是这张图将要运行的设备,它使得优化算法可以考虑针对设备应当考虑的优化选项。shape_map如果是非空的话,它将图中节点的名称映射为部分可知的节点输出形状,可能在某些图优化中会被应用,比如常量折叠优化。

关于图优化,我们需要了解的更为细致一些,所以,先看一下这个类的构造函数具体的实现方式。

GraphOptimizer::GraphOptimizer(const OptimizerOptions& opts) : opts_(opts) {
if(opts_.opt_level()>=OptimizerOptions::L1){
opts_.set_do_common_subexpression_elimination(true);
opts_.set_do_constant_folding(true);
}
}

通过这个函数我们了解到,优化配置是有级别概念的,当级别大于等于1时,某些默认的优化配置需要被开启,比如“公共子项消除”和“常量折叠”。这些内容我们在具体的优化步骤中也会看到。下面就来看一下核心API,Optimize的内容:

void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env, Device* device, std::unique_ptr<Graph>* graph, const std::unordered_map<const Node*, std::vector<PartialTensorShape>>* shape_map){
Graph* g = graph->get();
DumpGraph("Initial",g);//导出当前图的结构 bool changed = true;
const int kMaxRounds = 10;
for(int rounds = 0; rounds < kMaxRounds; ++rounds){
changed = false;
if(RemoveListArrayConverter(g)){
DumpGraph("RemoveListArrayConverter", g);
changed = true;
}
if(opts_.do_function_inlining() && RemoveDeadNodes(g)){
DumpGraph("RemoveDeadNodes", g);
changed = true;
}
if(opts_.do_function_inlining() && RemoveIdentityNodes(g)){
DumpGraph("RemoveIdentityNodes", g);
changed = true;
}
if(opts_.do_constant_folding()){
ConstantFoldingOptions cf_opts;
cf_opts.shape_map = shape_map;
bool was_mutated;
ConstantFold(cf_opts, runtime, env, device, g, &was_mutated).IgnoreError();
if(was_mutated){
RemoveDeadNodes(g);
DumpGraph("ConstFolding",g);
changed = true;
}
}
if(opts_.do_function_inlining() && FixupSourceAndSinkEdges(g)){
DumpGraph("FixupSourceAndSinkEdges",g);
changed = true;
}
if(opts_.do_common_subexpression_elimination() && OptimizeCSE(g,nullptr)){
DumpGraph("ExpandInlineFunctions",g);
changed = true;
}
if(!changed) break;
} //由于flib_def永远不会消失,因此我们可以放心的使用它来构建新图
std::unique_ptr<Graph> copy(new Graph(g->flib_def()));
CopyGraph(*g, copy.get());
graph->swap(copy); DumpGraph("ReCopy", graph->get());
}

在对图进行优化时,我们不可能一蹴而就的,因为优化之间会相互影响,比如我们对图进行了A优化,对于A优化来说,此时图已经是最优的了,但之后我们又对图进行了B优化,此时对于B优化来说,图已经是最优的了,但对于A优化来说则未必。因此图优化是一个循环上升的过程,TF设置了最高的优化是10遍,对于大多数图来说,也就足够了。

在图优化的过程中,我们发现了很多之前没见过的函数,这些函数的定义都在function.h文件中,为了加深对于图优化过程的理解,下面我们了解下这个文件中的函数。

3. function

function.h文件中,没有类定义,全部都是硬生生的函数定义,干货满满。

//kernel生成器,根据FunctionLibraryRuntime和NodeDef来生成kernel
typedef std::function<Status(FunctionLibraryRuntime*, const NodeDef&, std::unique_ptr<OpKernel>*)> CustomKernelCreator;
void RegisterDefaultCustomKernelCreator(CusteomKernelCreator cb);//kernel生成器的注册器 //创建一个FunctionLibraryRuntime,用来实例化lib_def中的函数,并在device上运行,如果custom_kernel_creator是非空的,它会被返回的runtime用来生成kernel
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, CusteomKernelCreator custom_kernel_creator); //与之前的函数类似,只不过返回的runtime直接利用RegisterDefaultCustomKernelCreator注册的全局custom_kernel_creator来生成新的kernel
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options); //函数体的内容
struct FunctionBody {
FunctionDef fdef;
Graph* graph = nullptr;
DataTypeVector arg_types;
DataTypeVector ret_types;
gtl::InlinedVector<Node*, 4> arg_nodes;
gtl::InlinedVector<Node*, 4> ret_nodes; FuntionBody(){}
FunctionBody(const FunctionDef& f, DataTypeSlice arg_types, DataTypeSlice ret_types, Graph* g);
~FunctionBody();
}; //删除以下节点,第一,无状态的,第二,无参数的,第三,对输出无贡献的
bool RemoveDeadNodes(Graph* g); //寻找如下的模式,src-(in)->node-(out)->dst,如果node是identity节点,in是唯一的输入数据边,out是唯一的输出数据边,则使用src->dst重写以上模式
bool RemoveIdentityNodes(Graph* g); //将图中的_ListToArray和_ArrayToList转化为Identity节点
bool RemoveListArrayConverter(Graph* g); //对于图中的每个节点,如果lib指明这个节点是一个函数调用,那么内联这个函数体。如果至少一个节点被内联了,返回true。
bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph); //将graph中的内容导出到日志文件,如果日志级别足够高的话
void DumpGraph(StringPiece label, const Graph* g); //应用图重写的优化,例如内联、死节点移除等
void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g); //将一个函数的图转化为GraphDef
void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false); //给定一个数值函数,返回它的导数函数
FunctionBody* SymbolicGradient(const FunctionBody& f); //将一个FunctionDef示例化为一个graph,设置fbody指向拥有FunctionDef的FunctionBody
Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs, const FunctionLibraryDefinition* const lib_def, const std::function<Status(const string&, const OpDef**)>& get_func_sig, FunctionBody** fbody);

现在回过头来看GraphOptimizer类中的Optimize函数,首先它把Array和List相互转换节点变为Identity节点,然后删除了死节点,删除Identity节点,进行常量折叠,修复输入输出边,进行公共子项消除,最终完成了对图的优化。

4. optimization_registry

optimization_registry.h文件中,包含了一些维护一个全局的图优化遍历注册器所需要的类,在会话初始化一张图时,会使用这个全局优化遍历注册器来对图进行优化。

首先我们来看第一个类,GraphOptimizationPassOptions,顾名思义,它包含了图优化遍历所需要的参数。这些足够作为一个字典的键值,我们通常会使用一个字典来保持各个图优化遍历器的状态。

struct GraphOptimizationPassOptions {
string session_handle;
const SessionOptions* session_options = nullptr;
const CostModel* cost_model = nullptr;
FunctionLibraryDefinition* flib_def = nullptr;
const DeviceSet* device_set = nullptr;
//如果优化遍历在图分割之前被使用,那么它优化的对象就是这个graph,如果是图分割之后被使用,那么这个graph是null
std::unique_ptr<Graph>* graph = nullptr;
//进行图分割后的优化遍历时使用
std::unordered_map<string, std::unique_ptr<Graph>* partition_graphs = nullptr;
};

图优化遍历,按照在图分割之前还是之后进行,可以分为两类,但我们使用了GraphOptimizationPassOptions这样一个接口。

接下来是GraphOptimizationPass类,所有的图优化遍历类,都是这个类的子类,它的结构也非常简单。

class GraphOptimizationPass {
public:
virtual ~GraphOptimizationPass() {}
virtual Status Run(const GraphOptimizationPassOption& options) = 0;
};

当我们拥有了多种图优化遍历的算法之后,需要对这些进行统一管理,因此TF提出了一种对图优化遍历算法进行统一注册和管理的类:

//这里的键值为phase,图优化遍历算法是按照phase的升序顺序执行的,在一个phase内部,执行顺序是未定义的
typedef std::map<int, std::vector<std::unique_ptr<GraphOptimizationPass>>> GraphOptimizationPasses; class OptimizationPassRegistry {
public:
enum Grouping {
PRE_PLACEMENT,//在cost model赋值之后,在节点放置算法之前
POST_PLACEMENT,//在节点放置算法之后
POST_REWRITE_FOR_EXEC,//在利用feed/fetch节点进行重写之后
POST_PARTITIONING,//在图分割之后
};
void Register(Grouping grouping, int phase, std::unique_ptr<GraphOptimizationPass> pass);//注册图优化遍历算法
Status RunGrouping(Grouping grouping, const GraphOptimizationPassOptions& options);//运行一个groupping中所有的图优化遍历算法,按照phase的升序运行
static OptimizationPassRegistry* Global();//返回一个全局的图优化遍历注册器
private:
std::map<Grouping, GraphOptimizationPasses> groups_;
};

总结一下,groups是一个双层的映射,先从Grouping映射到图优化遍历算法组,这个算法组本身也是个映射,从phase映射到真正的图优化遍历算法,如下:

graph LR
Grouping-->GraphOptimizationPasses
phase-->GraphOptimizationPass

最后,TF为刚才的注册器提供了一个全局的入口:

class OptimizationPassRegistration {
public:
OptimizationPassRegistration(OptimizationPassRegistry::Grouping grouping, int phase, std::unique_ptr<GraphOptimizationPass> pass){
OptimizationPassRegistry::Global->Register(grouping,phase,std::move(pass));
}
};

tensorflow源码解析之common_runtime-graph_optimizer的更多相关文章

  1. tensorflow源码解析之common_runtime拾遗

    把common_runtime中剩余的内容,按照文件名排序进行了简单的解析,时间原因写的很仓促,算是占个坑,后续有了新的理解再来补充. allocator_retry 有时候内存分配不可能一次完成,为 ...

  2. tensorflow源码解析系列文章索引

    文章索引 framework解析 resource allocator tensor op node kernel graph device function shape_inference 拾遗 c ...

  3. Tensorflow源码解析1 -- 内核架构和源码结构

    1 主流深度学习框架对比 当今的软件开发基本都是分层化和模块化的,应用层开发会基于框架层.比如开发Linux Driver会基于Linux kernel,开发Android app会基于Android ...

  4. tensorflow源码解析之framework拾遗

    把framework中剩余的内容,按照文件名进行了简单解析.时间原因写的很仓促,算是占个坑,后面有了新的理解再来补充. allocation_description.proto 一个对单次内存分配结果 ...

  5. tensorflow源码解析之common_runtime-executor-上

    目录 核心概念 executor.h Executor NewLocalExecutor ExecutorBarrier executor.cc structs GraphView ExecutorI ...

  6. tensorflow源码解析之framework-allocator

    目录 什么是allocator 内存分配器的管理 内存分配追踪 其它结构 关系图 涉及的文件 迭代记录 1. 什么是allocator Allocator是所有内存分配器的基类,它定义了内存分配器需要 ...

  7. tensorflow源码解析之common_runtime-executor-下

    目录 核心概念 executor.h Executor NewLocalExecutor ExecutorBarrier executor.cc structs GraphView ExecutorI ...

  8. tensorflow源码解析之distributed_runtime

    本篇主要介绍TF的分布式运行时的基本概念.为了对TF的分布式运行机制有一个大致的了解,我们先结合/tensorflow/core/protobuf中的文件给出对TF分布式集群的初步理解,然后介绍/te ...

  9. tensorflow源码解析之framework-op

    目录 什么是op op_def定义 op注册 op构建与注册辅助结构 op重写 关系图 涉及的文件 迭代记录 1. 什么是op op和kernel是TF框架中最重要的两个概念,如果一定要做一个类比的话 ...

  10. Tensorflow源码解析2 -- 前后端连接的桥梁 - Session

    Session概述 1. Session是TensorFlow前后端连接的桥梁.用户利用session使得client能够与master的执行引擎建立连接,并通过session.run()来触发一次计 ...

随机推荐

  1. Git标签 简单操作

    感谢廖雪峰老师,以下内容多数来自老师的Git教程. 另有部分参考Git中文文档. 创建 命令git tag <tagname> [commit id]用于新建一个标签,默认为HEAD; 也 ...

  2. WebSocket协议详解及应用

    WebSocket协议详解及应用(七)-WebSocket协议关闭帧 本篇介绍WebSocket协议的关闭帧,包括客户端及服务器如何发送并处理关闭帧.关闭帧错误码及错误处理方法.本篇内容主要翻译自RF ...

  3. NSArray 遍历

    1.NSArray的下标遍历 NSArray *arr = @[p1, p2, p3, p4, p5]; for (int i = 0; i < arr.count; ++i) { Person ...

  4. autorelease基本使用

    1.autorelease基本概念 autorelease是一种支持引用计数的内存管理方式,只要给对象发送一条autorelease消息,会将对象放到一个自动释放池中,当自动释放池被销毁时,会对池子里 ...

  5. 【转】zabbix监控mysql

    纯属搬家收藏,原文链接 https://www.cnblogs.com/shenjianyu/p/6627843.html 注意: 1.关注的重点在agent端部分 2.zabbix_get命令没有, ...

  6. AFNetWorking 文件上传 By-H罗

    一.文件上传(图片,音频,视频,文本等)(不带进度) /** * 文件上传 导入 #import "AFNetworking.h" * @param filePath 上传文件本地 ...

  7. GRC: 个人信息保护法, 个人隐私, 企业风险合规治理

    声明 个人原创, 转载需注明来源 https://www.cnblogs.com/milton/p/15885344.html 个人信息保护的历史和现状 个人信息保护的立法可追溯至德国黑森州1970年 ...

  8. Java中的多线程你只要看这一篇就够了(引用)

    引 如果对什么是线程.什么是进程仍存有疑惑,请先Google之,因为这两个概念不在本文的范围之内. 用多线程只有一个目的,那就是更好的利用cpu的资源,因为所有的多线程代码都可以用单线程来实现.说这个 ...

  9. suse 12 二进制部署 Kubernetets 1.19.7 - 第03章 - 部署flannel插件

    文章目录 1.3.部署flannel网络 1.3.0.下载flannel二进制文件 1.3.1.创建flannel证书和私钥 1.3.2.生成flannel证书和私钥 1.3.3.将pod网段写入et ...

  10. Python实例:贪吃蛇(简单贪吃蛇编写)🐍

    d=====( ̄▽ ̄*)b 叮~ Python -- 简易贪吃蛇实现 目录: 1.基本原理 2.需要学习的库 3.代码实现 1.基本原理 基本贪吃蛇所需要的东西其实很少,只需要有一块让蛇动的屏幕, 在 ...