TensorFlow - 框架实现中的三种 Graph
TensorFlow - 框架实现中的三种 Graph
图(Graph) 是 TensorFlow 用于表达计算任务的一个核心概念.
从前端(python) 描述神经网络的结构,到后端在多机和分布式系统上部署,到底层 Device(CPU、GPU、TPU)上运行,都是基于图来完成.
然而在实际使用过程中遇到了三对API,
[1] - tf.train.Saver()/saver.restore()
[2] - export_meta_graph/Import_meta_graph
[3] - tf.train.write_graph()/tf.Import_graph_def()
它们都是用于对图的保存和恢复.
同一个计算框架,为什么需要三对不同的API呢?他们保存/恢复的图在使用时又有什么区别呢?
初学的时候,常常闹不清楚他们的区别,以至常常写出了错误的程序,经过一番研究,本文中对Tensorflow中围绕Graph的核心概念进行了总结.
1. Graph
首先介绍一下关于 TensorFlow 中 Graph 和它的序列化表示 Graph_def
.
在 TensorFlow 官方文档中,Graph 被定义为 “一些 Operation 和 Tensor 的集合”.
例如表达如下的一个计算的 python代码,
import tensorflow as tf
a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
c = tf.placeholder(tf.float32)
d = a*b+c
e = d*2
就会生成相应的一张图,在Tensorboard中看到的图大概如图:
其中,每一个圆圈表示一个 Operation
(输入处为Placeholder),椭圆到椭圆的边为Tensor
,箭头的指向表示了这张图 Operation 输入输出 Tensor 的传递关系.
在真实的 TensorFlow 运行中,Python 构建的“图Graph” 并不是启动一个 Session 之后始终不变的. 因为 TensorFlow 在运行时,真实的计算会被分配到多CPUs,或 GPUs,或 ARM 等,以进行高性能/能效的计算. 单纯使用 Python 肯定是无法有效完成的.
实际上,TensorFlow 是首先将 python 代码所描绘的图转换(即“序列化”)成 Protocol Buffer,再通过 C/C++/CUDA 运行 Protocol Buffer 所定义的图. (Protocol Buffer 可参考:https://www.ibm.com/developerworks/cn/linux/l-cn-gpb/).
2. GraphDef
从 python Graph中序列化出来的图就叫做 GraphDef
(这是一种不严格的说法,先这样进行理解).
而 GraphDef
又是由许多叫做 NodeDef
的 Protocol Buffer 组成. 在概念上 NodeDef
与(Python Graph 中的) Operation
相对应.
如下就是 GraphDef 的 ProtoBuf,由许多node 组成的图表示. 这是与上文 Python 图对应的 GraphDef:
node {
name: "Placeholder" # 注:这是一个叫做 "Placeholder" 的node
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
unknown_rank: true
}
}
}
}
node {
name: "Placeholder_1" # 注:这是一个叫做 "Placeholder_1" 的node
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
unknown_rank: true
}
}
}
}
node {
name: "mul" # 注:一个 Mul(乘法)操作
op: "Mul"
input: "Placeholder" # 使用上面的node(即Placeholder和Placeholder_1)
input: "Placeholder_1" # 作为这个Node的输入
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
以上三个 NodeDef
定义了两个 Placeholde r和一个Multiply.
Placeholder 通过 attr(attribute的缩写)来定义数据类型和 Tensor 的形状.
Multiply 通过 input 属性定义了两个 placeholder 作为其输入.
无论是 Placeholder 还是 Multiply 都没有关于输出(output)的信息.
其实 Tensorflow 中都是通过 Input 来定义 Node 之间的连接信息.
那么既然 tf.Operation
的序列化 ProtoBuf 是 NodeDef
,那么 tf.Variable
呢?在这个 GraphDef
中只有网络的连接信息,却没有任何 Variables呀?
没错,Graphdef
中不保存任何 Variable 的信息,所以如果从 graph_def
来构建图并恢复训练的话,是不能成功的.
如,
with tf.Graph().as_default() as graph:
tf.import_graph_def("graph_def_path")
saver= tf.train.Saver()
with tf.Session() as sess:
tf.trainable_variables()
其中 tf.trainable_variables()
只会返回一个空的list. tf.train.Saver()
也会报告 no variables to save.
然而,在实际线上 inference 中,通常就是使用 GraphDef
. 但,GraphDef
中连 Variable都没有,怎么存储 weight 呢?
原来, GraphDef
虽然不能保存 Variable,但可以保存 Constant. 通过 tf.constant
将 weight 直接存储在 NodeDef
里,tensorflow 1.3.0 版本也提供了一套叫做 freeze_graph
的工具来自动的将图中的 Variable 替换成 constant 存储在 GraphDef
里面,并将该图导出为 Proto.
https://www.tensorflow.org/extend/tool_developers/https://www.tensorflow.org/mobile/prepare_models
tf.train.write_graph()/tf.Import_graph_def()
就是用来进行 GraphDef
读写的API. 那么,我们怎么才能从序列化的图中,得到 Variables呢?这就要学习下一个重要概念,MetaGraph
.
3. MetaGraph
Meta graph 的官方解释是:一个 Meta Graph
由一个计算图和其相关的元数据构成, 其包含了用于继续训练,实施评估和(在已训练好的的图上)做前向推断的信息.
A MetaGraph consists of both a computational graph and its associated metadata.
A MetaGraph contains the information required to continue training, perform evaluation, or run inference on a previously trained graph.
From https://www.tensorflow.org/versions/r1.1/programmers_guide/
这一段看的云里雾里,不过这篇文章(https://www.tensorflow.org/versions/r1.1/programmers_guide/meta_graph)进一步解释说,Meta Graph在具体实现上就是一个 MetaGraphDef
(同样是由 Protocol Buffer来定义的). 其包含了四种主要的信息,根据Tensorflow官网,这四种 Protobuf 分别是:
[1] - MetaInfoDef
,存一些元信息(比如版本和其他用户信息)
[2] - GraphDef
, MetaGraph 的核心内容之一
[3] - SaverDef
,图的Saver信息(比如最多同时保存的check-point数量,需保存的Tensor名字等,但并不保存Tensor中的实际内容)
[4] - CollectionDef
,任何需要特殊注意的 Python 对象,需要特殊的标注以方便import_meta_graph
后取回(如 train_op
, prediction
等等)
在以上四种 ProtoBuf 里面,[1] 和 [3] 都比较容易理解,[2] 刚刚总结过. 这里特别要讲一下 Collection
(CollectionDef是对应的ProtoBuf).
TensorFlow 中并没有一个官方的定义说 collection
是什么. 简单的理解,它就是为了方别用户对图中的操作和变量进行管理,而创建的一个概念. 它可以说是一种“集合”,通过一个 key (string类型) 来对一组 Python 对象进行命名的集合. 这个 key 既可以是TensorFlow 在内部定义的一些 key,也可以是用户自己定义的名字(string).
TensorFlow 内部定义了许多标准 Key,全部定义在了 tf.GraphKeys
这个类中. 其中有一些常用的,tf.GraphKeys.TRAINABLE_VARIABLES
, tf.GraphKeys.GLOBAL_VARIABLES
等等. tf.trainable_variables()
与 tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
是等价的;tf.global_variables()
与 tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
是等价的.
集合类型 | 集合内容 | 使用环境 |
---|---|---|
tf.GraphKeys.VARIABLES | 神经网络参数 | |
tf.GraphKeys.TRAINABLE_VARIABLES | 模型训练,生产模型可视化内容 | |
tf.GraphKeys.SUMMARIES | 日志生成相关张量 | 计算可视化 |
tf.GraphKeys.QUEUE_RUNNER | 处理输入的QueueRunner | 输入处理 |
tf.MOVING_AVERAGE_BARIABLES | 所有计算了滑动平均值的变量 | 计算变量滑动平均值 |
对于用户定义的 key,举一个例子, 例如:
pred = model_network(X)
loss=tf.reduce_mean(…, pred, …)
train_op=tf.train.AdamOptimizer(lr).minimize(loss)
这样一段 Tensorflow程序,用户希望特别关注 pred
, loss
, train_op
这几个操作,那么就可以使用如下代码,将这几个变量加入到 collection
中去. (假设我们将其命名为 “training_collection”)
tf.add_to_collection("training_collection", pred)
tf.add_to_collection("training_collection", loss)
tf.add_to_collection("training_collection", train_op)
并且可以通过 Train_collect = tf.get_collection(“training_collection”)
得到一个python list,其中的内容就是pred
, loss
, train_op
的 Tensor. 这通常是为了在一个新的 session 中打开这张图时,方便我们获取想要的操作. 比如我们可以直接通过 get_collection()
得到 train_op
,然后通过 sess.run(train_op)
来开启一段训练,而无需重新构建 loss
和optimizer
.
通过 export_meta_graph
保存图,并且通过 add_to_collection
将 train_op
加入到 collection
中:
with tf.Session() as sess:
pred = model_network(X)
loss=tf.reduce_mean(…,pred, …)
train_op=tf.train.AdamOptimizer(lr).minimize(loss)
tf.add_to_collection("training_collection", train_op)
Meta_graph_def =
tf.train.export_meta_graph(tf.get_default_graph(), 'my_graph.meta')
通过 import_meta_graph
将图恢复(同时初始化为本 Session的 default 图),并且通过 get_collection
重新获得 train_op
,以及通过 train_op
来开始一段训练(sess.run() ).
with tf.Session() as new_sess:
tf.train.import_meta_graph('my_graph.meta')
train_op = tf.get_collection("training_collection")[0]
new_sess.run(train_op)
更多的代码例子可以在这篇文档(https://www.tensorflow.org/api_guides/python/meta_graph)中的 Import a MetaGraph
章节中看到.
那么,从 Meta Graph
中恢复构建的图可以被训练吗?是可以的. TensorFlow 的官方文档 https://www.tensorflow.org/api_guides/python/meta_graph 说明了使用方法. 这里要特殊的说明一下,Meta Graph
中虽然包含 Variable 的信息,却没有 Variable 的实际值. 所以, 从Meta Graph
中恢复的图,其训练是从随机初始化的值开始的. 训练中 Variable的实际值都保存在 checkpoint 中,如果要从之前训练的状态继续恢复训练,就要从checkpoint 中 restore. 进一步读一下 Export Meta Graph
的代码,可以看到,事实上variables 并没有被 export 到 meta_graph
中.
https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/training/saver.py (1872行)
https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/framework/meta_graph.py (829,845行)
export_meta_graph/Import_meta_graph
就是用来进行 Meta Graph
读写的API.
tf.train.saver.save()
在保存checkpoint的同时也会保存Meta Graph
. 但是在恢复图时,tf.train.saver.restore()
只恢复 Variable,如果要从MetaGraph
恢复图,需要使用 import_meta_graph
. 这是其实为了方便用户,有时我们不需要从MetaGraph
恢复的图,而是需要在 python 中构建神经网络图,并恢复对应的 Variable.
4. Checkpoint
Checkpoint 里全面保存了训练某时间截面的信息,包括参数,超参数,梯度等等. tf.train.Saver()/saver.restore()
则能够完完整整保存和恢复神经网络的训练.
Checkpoint 分为两个文件保存Variable的二进制信息. ckpt
文件保存了Variable的二进制信息,index
文件用于保存 ckpt 文件中对应 Variable 的偏移量信息.
5. 总结
TensorFlow 三种 API 所保存和恢复的图是不一样的.
这三种图是从 TensorFlow 框架设计的角度出发而定义的.
但是从用户的角度来看,TensorFlow 文档的写作难免有些云里雾里,弄不清他们的区别.需要读一读Tensorflow的代码,做一些实验来进行辨析.
简而言之,TensorFlow 在前端 Python 中构建图,并且通过将该图序列化到 ProtoBuf GraphDef
,以方便在后端运行. 在这个过程中,图的保存、恢复和运行都通过 ProtoBuf 来实现. GraphDef
,MetaGraph
,以及Variable
,Collection
和 Saver
等都有对应的 ProtoBuf 定义. ProtoBuf 的定义也决定了用户能对图进行的操作. 例如用户只能找到 Node的前一个Node,却无法得知自己的输出会由哪个Node接收.
TensorFlow - 框架实现中的三种 Graph的更多相关文章
- Java三大框架之——Hibernate中的三种数据持久状态和缓存机制
Hibernate中的三种状态 瞬时状态:刚创建的对象还没有被Session持久化.缓存中不存在这个对象的数据并且数据库中没有这个对象对应的数据为瞬时状态这个时候是没有OID. 持久状态:对象经过 ...
- .net core 注入中的三种模式:Singleton、Scoped 和 Transient
从上篇内容不如题的文章<.net core 并发下的线程安全问题>扩展认识.net core注入中的三种模式:Singleton.Scoped 和 Transient 我们都知道在 Sta ...
- Asp.Net中的三种分页方式
Asp.Net中的三种分页方式 通常分页有3种方法,分别是asp.net自带的数据显示空间如GridView等自带的分页,第三方分页控件如aspnetpager,存储过程分页等. 第一种:使用Grid ...
- httpClient中的三种超时设置小结
httpClient中的三种超时设置小结 本文章给大家介绍一下关于Java中httpClient中的三种超时设置小结,希望此教程能给各位朋友带来帮助. ConnectTimeoutExceptio ...
- MySQL buffer pool中的三种链
三种page.三种list.LRU控制调优 一.innodb buffer pool中的三种页 1.free page:从未用过的页 2.clean page:干净的页,数据页的数据和磁盘一致 3.d ...
- 研究分析JS中的三种逻辑语句
JS中的三种逻辑语句:顺序.分支和循环语句. 一.顺序语句 代码规范如下:1. <script type="text/javascript"> var a = 10; ...
- JavaScript中的三种弹出对话框
学习过js的小伙伴会发现,我们在一些实例中用到了alert()方法.prompt()方法.prompt()方法,他们都是在屏幕上弹出一个对话框,并且在上面显示括号内的内容,使用这种方法使得页面的交互性 ...
- java多线程中的三种特性
java多线程中的三种特性 原子性(Atomicity) 原子性是指在一个操作中就是cpu不可以在中途暂停然后再调度,既不被中断操作,要不执行完成,要不就不执行. 如果一个操作时原子性的,那么多线程并 ...
- python中的三种输入方式
python中的三种输入方式 python2.X python2.x中以下三个函数都支持: raw_input() input() sys.stdin.readline() raw_input( )将 ...
- Netty中的三种Reactor(反应堆)
目录: Reactor(反应堆)和Proactor(前摄器) <I/O模型之三:两种高性能 I/O 设计模式 Reactor 和 Proactor> <[转]第8章 前摄器(Proa ...
随机推荐
- antd动态tree 自定义样式
import React, { useEffect, useState } from 'react';import { Tree } from 'antd';import './index.less' ...
- 微信小程序按下去的样式
微信小程序设置 hover-class,实现点击态效果 目前支持 hover-class 属性的组件有三个:view.button.navigator. 不支持 hover-class 属性的组件,同 ...
- C#清空控件的值
/// 清除容器里面某些控件的值 /// </summary> /// <param name="parContainer">容器类控件</param ...
- 【StoneDB 模块介绍】服务器模块
[StoneDB 模块介绍]服务器模块 一.介绍 客户端程序和服务器程序本质上都是计算机上的一个进程,客户端进程向服务器进程发送请求的过程本质上是一种进程间通信的过程,StoneDB 数据库服务程序作 ...
- App测试之appium参数入门
Appium入门参数: platformName:平台名称,一般是Android或iOS: platformVersion:平台的版本号,可以使用以下命令: adb shell getprop ro. ...
- binder机制分析
1. binder基本概念 1.1 特点 1)binder 是一种基于C/S通信模式的IPC(Inter_Process Communication). 2)在传输过程中近需要一次copy,为发送添加 ...
- simpleini库的介绍和使用(面向业务编程-格式处理)
simpleini库的介绍和使用(面向业务编程-格式处理) 介绍 simpleini是一个跨平台的ini格式处理库,提供了一些简单的API来读取和写入ini风格的配置文件.它支持ASCII.MBCS和 ...
- 【桥接设计模式详解】Java/JS/Go/Python/TS不同语言实现
[桥接设计模式详解]Java/JS/Go/Python/TS不同语言实现 简介 桥接模式(Bridge Pattern)是一种结构型设计模式,它将一个大类或一系列紧密相关的类拆分为抽象和实现两个独立的 ...
- linux环境下部署mysql环境
一.部署步骤 1.将安装包上传到Linux服务器上(目录随意),然后解压缩 2.进入到解压后的目录下,分别执行以下命令安装四个包(严格按照顺序执行) rpm -ivh mysql-community- ...
- LockSupport 详解
更多内容,前往IT-BLOG LockSupport 用来创建锁和其他同步类的基本线程阻塞原语.简而言之,当调用 LockSupport.park时,表示当前线程将会等待,直至获得许可,当调用 Loc ...