转 如何阅读TensorFlow源码
通过bazel学习之后,大概了解了TensorFlow的项目的源文件和描述文件。
下面是一篇不错的介绍,搬砖here。
在静下心来默默看了大半年机器学习的资料并做了些实践后,打算学习下现在热门的TensorFlow的实现,毕竟系统这块和自己关系较大。本文会简单的说明一下如何阅读TensorFlow的源码。最重要的是了解其构建工具bazel以及脚本语言调用c或cpp的包裹工具swig。这里假设大家对bazel及swig以及有所了解(不了解的可以google下)。要看代码首先要知道代码怎么构建,因此本文的一大部分会关注构建这块。
如果从源码构建TensorFlow会需要执行如下命令:
bazel build -c opt //tensorflow/tools/pip_package:build_pip_package
对应的BUILD文件的rule为:
sh_binary(
name = "build_pip_package",
srcs = ["build_pip_package.sh"],
data = [
"MANIFEST.in",
"README",
"setup.py",
"//tensorflow/core:framework_headers",
":other_headers",
":simple_console",
"//tensorflow:tensorflow_py",
"//tensorflow/examples/tutorials/mnist:package",
"//tensorflow/models/embedding:package",
"//tensorflow/models/image/cifar10:all_files",
"//tensorflow/models/image/mnist:convolutional",
"//tensorflow/models/rnn:package",
"//tensorflow/models/rnn/ptb:package",
"//tensorflow/models/rnn/translate:package",
"//tensorflow/tensorboard",
],
)
sh_binary在这里的主要作用是生成data的这些依赖。一个一个来看,一开始的三个文件MANIFEST.in、README、setup.py是直接存在的,因此不会有什么操作。
“//tensorflow/core:framework_headers”:
其对应的rule为:
filegroup(
name = "framework_headers",
srcs = [
"framework/allocator.h",
......
"util/device_name_utils.h",
],
)
这里filegroup的作用是给这一堆头文件一个别名,方便其他rule引用。
“:other_headers”:
rule为:
transitive_hdrs(
name = "other_headers",
deps = [
"//third_party/eigen3",
"//tensorflow/core:protos_all_cc",
],
)
transitive_hdrs的定义在:
load("//tensorflow:tensorflow.bzl", "transitive_hdrs")
实现为:
# Bazel rule for collecting the header files that a target depends on.
def _transitive_hdrs_impl(ctx):
outputs = set()
for dep in ctx.attr.deps:
outputs += dep.cc.transitive_headers
return struct(files=outputs) _transitive_hdrs = rule(attrs={
"deps": attr.label_list(allow_files=True,
providers=["cc"]),
},
implementation=_transitive_hdrs_impl,) def transitive_hdrs(name, deps=[], **kwargs):
_transitive_hdrs(name=name + "_gather",
deps=deps)
native.filegroup(name=name,
srcs=[":" + name + "_gather"])
其作用依旧是收集依赖需要的头文件。
“:simple_console”:
其rule为:
py_binary(
name = "simple_console",
srcs = ["simple_console.py"],
srcs_version = "PY2AND3",
deps = ["//tensorflow:tensorflow_py"],
py_library(
name = "tensorflow_py",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = ["//tensorflow/python"],
)
simple_console.py的代码的主要部分是:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function import code
import sys def main(_):
"""Run an interactive console."""
code.interact()
return 0 if __name__ == '__main__':
sys.exit(main(sys.argv))
可以看到起通过deps = [“//tensorflow/python”]构建了依赖包,然后生成了对应的执行文件。看下依赖的rule规则。
//tensorflow/python对应的rule为:
py_library(
name = "python",
srcs = [
"__init__.py",
],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__pkg__"],
deps = [
":client",
":client_testlib",
":framework",
":framework_test_lib",
":kernel_tests/gradient_checker",
":platform",
":platform_test",
":summary",
":training",
"//tensorflow/contrib:contrib_py",
],
)
py_library(
name = "training",
srcs = glob(
["training/**/*.py"],
exclude = ["**/*test*"],
),
srcs_version = "PY2AND3",
deps = [
":client",
":framework",
":lib",
":ops",
":protos_all_py",
":pywrap_tensorflow",
":training_ops",
],
)
这里其依赖的pywrap_tensorflow的rule为:
tf_py_wrap_cc(
name = "pywrap_tensorflow",
srcs = ["tensorflow.i"],
swig_includes = [
"client/device_lib.i",
"client/events_writer.i",
"client/server_lib.i",
"client/tf_session.i",
"framework/python_op_gen.i",
"lib/core/py_func.i",
"lib/core/status.i",
"lib/core/status_helper.i",
"lib/core/strings.i",
"lib/io/py_record_reader.i",
"lib/io/py_record_writer.i",
"platform/base.i",
"platform/numpy.i",
"util/port.i",
"util/py_checkpoint_reader.i",
],
deps = [
":py_func_lib",
":py_record_reader_lib",
":py_record_writer_lib",
":python_op_gen",
":tf_session_helper",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
"//util/python:python_headers",
],
)
tf_py_wrap_cc为其自己实现的一个rule,这里的.i就是SWIG的interface文件。来看下其实现:
def tf_py_wrap_cc(name, srcs, swig_includes=[], deps=[], copts=[], **kwargs):
module_name = name.split("/")[-1]
# Convert a rule name such as foo/bar/baz to foo/bar/_baz.so
# and use that as the name for the rule producing the .so file.
cc_library_name = "/".join(name.split("/")[:-1] + ["_" + module_name + ".so"])
extra_deps = []
_py_wrap_cc(name=name + "_py_wrap",
srcs=srcs,
swig_includes=swig_includes,
deps=deps + extra_deps,
module_name=module_name,
py_module_name=name)
native.cc_binary(
name=cc_library_name,
srcs=[module_name + ".cc"],
copts=(copts + ["-Wno-self-assign", "-Wno-write-strings"]
+ tf_extension_copts()),
linkopts=tf_extension_linkopts(),
linkstatic=1,
linkshared=1,
deps=deps + extra_deps)
native.py_library(name=name,
srcs=[":" + name + ".py"],
srcs_version="PY2AND3",
data=[":" + cc_library_name])
_py_wrap_cc = rule(attrs={
"srcs": attr.label_list(mandatory=True,
allow_files=True,),
"swig_includes": attr.label_list(cfg=DATA_CFG,
allow_files=True,),
"deps": attr.label_list(allow_files=True,
providers=["cc"],),
"swig_deps": attr.label(default=Label(
"//tensorflow:swig")), # swig_templates
"module_name": attr.string(mandatory=True),
"py_module_name": attr.string(mandatory=True),
"swig_binary": attr.label(default=Label("//tensorflow:swig"),
cfg=HOST_CFG,
executable=True,
allow_files=True,),
},
outputs={
"cc_out": "%{module_name}.cc",
"py_out": "%{py_module_name}.py",
},
implementation=_py_wrap_cc_impl,)
_py_wrap_cc_impl的实现为:
# Bazel rules for building swig files.
def _py_wrap_cc_impl(ctx):
srcs = ctx.files.srcs
if len(srcs) != 1:
fail("Exactly one SWIG source file label must be specified.", "srcs")
module_name = ctx.attr.module_name
cc_out = ctx.outputs.cc_out
py_out = ctx.outputs.py_out
src = ctx.files.srcs[0]
args = ["-c++", "-python"]
args += ["-module", module_name]
args += ["-l" + f.path for f in ctx.files.swig_includes]
cc_include_dirs = set()
cc_includes = set()
for dep in ctx.attr.deps:
cc_include_dirs += [h.dirname for h in dep.cc.transitive_headers]
cc_includes += dep.cc.transitive_headers
args += ["-I" + x for x in cc_include_dirs]
args += ["-I" + ctx.label.workspace_root]
args += ["-o", cc_out.path]
args += ["-outdir", py_out.dirname]
args += [src.path]
outputs = [cc_out, py_out]
ctx.action(executable=ctx.executable.swig_binary,
arguments=args,
mnemonic="PythonSwig",
inputs=sorted(set([src]) + cc_includes + ctx.files.swig_includes +
ctx.attr.swig_deps.files),
outputs=outputs,
progress_message="SWIGing {input}".format(input=src.path))
return struct(files=set(outputs))
这里的ctx.executable.swig_binary是一个shell脚本,内容为:
# If possible, read swig path out of "swig_path" generated by configure
SWIG=swig
SWIG_PATH=tensorflow/tools/swig/swig_path
if [ -e $SWIG_PATH ]; then
SWIG=`cat $SWIG_PATH`
fi # If this line fails, rerun configure to set the path to swig correctly
"$SWIG" "$@"
可以看到起就是调用了swig命令。
“//tensorflow:tensorflow_py”:
其rule为:
py_library(
name = "tensorflow_py",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = ["//tensorflow/python"],
)
可以看到起主要依赖了我们上面生成的”//tensorflow/python”这个module。
剩余的几个其实和主框架关系不大,主要是生成一些model、文档啥的。
现在清楚了其构建链后,我们来看个简单的程序,其通过梯度下降算法求线性拟合的W和b。我们会从这个例子入手看下如何找到其使用的函数的具体实现的源码位置:
(python3.5)➜ tmp cat th.py
import tensorflow as tf
import numpy as np x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3 W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y = W * x_data + b loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss) init = tf.initialize_all_variables() sess = tf.Session()
sess.run(init) for step in range(0, 201):
sess.run(train)
if step % 20 == 0:
print(step, sess.run(W), sess.run(b))
(python3.5)➜ tmp python th.py
0 [ 0.42190057] [ 0.17155224]
20 [ 0.1743494] [ 0.26045772]
40 [ 0.11817314] [ 0.29033473]
60 [ 0.10444205] [ 0.29763755]
80 [ 0.10108578] [ 0.29942256]
100 [ 0.10026541] [ 0.29985884]
120 [ 0.10006487] [ 0.2999655]
140 [ 0.10001585] [ 0.29999158]
160 [ 0.10000388] [ 0.29999796]
180 [ 0.10000096] [ 0.29999951]
200 [ 0.10000025] [ 0.29999989]
从我们上面的分析可以看到,import tensorflow as tf来自于tensorflow目录下的__init__.py文件,其内容为:
from tensorflow.python import *
再来看tf.Variable,在tensorflow.python的__init__.py中可以看到其导入了很多符号。但要定位到Variable还是比较困难,因为其很多直接是import *。所以一个快速定位的方法是直接grep这个class:
➜ python grep 'class Variable(' -R ./*
./ops/variables.py:class Variable(object):
对于tf.Session等也可以用同样的方法定位。我们来找个走SWIG包裹的,如果我们去看sess.run,我们会看到如下的代码:
return tf_session.TF_Run(session, options,
feed_dict, fetch_list, target_list,
run_metadata)
这里tf_session就是一个SWIG包裹的模块:
from tensorflow.python import pywrap_tensorflow as tf_session
pywrap_tensorflow在源码里是找不到的,因为这个得从SWIG生成后才有,我们可以从.i文件里找下TF_Run的声明,或者直接grep下这个函数:
➜ tensorflow grep 'TF_Run(' -R ./*
./core/client/tensor_c_api.cc:void TF_Run(TF_Session* s, const TF_Buffer* run_options,
这样就可以看其实现了:
void TF_Run(TF_Session* s, const TF_Buffer* run_options,
// Input tensors
const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
// Output tensors
const char** c_output_tensor_names, TF_Tensor** c_outputs,
int noutputs,
// Target nodes
const char** c_target_node_names, int ntargets,
TF_Buffer* run_metadata, TF_Status* status) {
TF_Run_Helper(s, nullptr, run_options, c_input_names, c_inputs, ninputs,
c_output_tensor_names, c_outputs, noutputs, c_target_node_names,
ntargets, run_metadata, status);
}
转 如何阅读TensorFlow源码的更多相关文章
- Tensorflow[源码安装时bazel行为解析]
0. 引言 通过源码方式安装,并进行一定程度的解读,有助于理解tensorflow源码,本文主要基于tensorflow v1.8源码,并借鉴于如何阅读TensorFlow源码. 首先,自然是需要去b ...
- TensorFlow源码框架 杂记
一.为什么我们需要使用线程池技术(ThreadPool) 线程:采用“即时创建,即时销毁”策略,即接受请求后,创建一个新的线程,执行任务,完毕后,线程退出: 线程池:应用软件启动后,立即创建一定数量的 ...
- tensorflow源码解析系列文章索引
文章索引 framework解析 resource allocator tensor op node kernel graph device function shape_inference 拾遗 c ...
- 如何阅读Java源码 阅读java的真实体会
刚才在论坛不经意间,看到有关源码阅读的帖子.回想自己前几年,阅读源码那种兴奋和成就感(1),不禁又有一种激动. 源码阅读,我觉得最核心有三点:技术基础+强烈的求知欲+耐心. 说到技术基础,我打个比 ...
- newsstand杂志阅读应用源码ipad版
一款newsstand iPad杂志阅读应用源码(newsstand在线下载/动态显示等)可以支持在线下载/动态显示等 ,也是一款newsstand iPad杂志阅读应用源码.运行之后,会在iPad ...
- 如何阅读Java源码
刚才在论坛不经意间,看到有关源码阅读的帖子.回想自己前几年,阅读源码那种兴奋和成就感(1),不禁又有一种激动.源码阅读,我觉得最核心有三点:技术基础+强烈的求知欲+耐心. 说到技术基础,我打个比方吧, ...
- 如何阅读mysql源码
在微博上问mysql高手,如何阅读mysql 源码大致给了下面的一些建议: step 1,知道代码的组织结构(官方文档http://t.cn/z8LoLgh: Step2: 尝试大致了解一条sql涉及 ...
- Ubuntu TensorFlow 源码 Android Demo的编译运行
Ubuntu TensorFlow 源码 Android Demo的编译运行 一. 安装 Android 的SDK和NDK SDK 配置 A:下载 国内下载地址选最新的: SDK: https://d ...
- 编译TensorFlow源码
编译TensorFlow源码 参考: https://www.tensorflow.org/install/install_sources https://github.com/tensorflo ...
随机推荐
- mysql 集群方案
试试基于Galera的MySQL高可用集群 mha mgr
- BZOJ3438小M的作物——最小割
题目描述 小M在MC里开辟了两块巨大的耕地A和B(你可以认为容量是无穷),现在,小P有n中作物的种子,每种作物的种子 有1个(就是可以种一棵作物)(用1...n编号),现在,第i种作物种植在A中种植可 ...
- 图灵机器人API接口
调用图灵API接口实现人机交互 流程一: 注册 图灵机器人官网: http://www.tuling123.com/ 第一步: 先注册, 然后创建机器人, 拿到一个32位的key 编码方式 UTF-8 ...
- Python小练习
1.计算x的n次方 2.计算x的阶乘 3.计算1x1 + 2x2 + 3x3 ...+ NxN之和 def fun(n): s=0 while n > 0: s = s + n*n n = n ...
- springMVC整理03--处理数据模型 & 试图解析器 & @ResponseBody & HttpEntity
1.处理模型数据 SpringMVC 中的模型数据是非常重要的,因为 MVC 中的控制(C)请求处理业务逻辑来生成数据模型(M),而视图(V)就是为了渲染数据模型的数据.当有一个查询的请求,控制器(C ...
- [洛谷P1273] 有线电视网
类型:树形背包 传送门:>Here< 题意:给出一棵树,根节点在转播足球赛,每个叶子节点是一个观众在收看.每个叶子结点到根节点的路径权值之和是该点转播的费用,每个叶子节点的观众都会付val ...
- Centos 7.3 安装Grafana 6.0
grafana简介 Grafana是一个完全开源的度量分析与可视化平台,可对来自各种各种数据源的数据进行查询.分析.可视化处理以及配置告警. Grafana支持的数据源: 官方:Graphite,In ...
- 求集合中选一个数与当前值进行位运算的max
求集合中选一个数与当前值进行位运算的max 这是一个听来的神仙东西. 先确定一下值域把,大概\(2^{16}\),再大点也可以,但是这里就只是写写,所以无所谓啦. 我们先看看如果暴力求怎么做,位运算需 ...
- HDU6333 Harvest of Apples (杭电多校4B)
这莫队太强啦 先推公式S(n,m)表示从C(n, 0) 到 C(n, m)的总和 1.S(n, m) = S(n, m-1) + C(n, m) 这个直接可以转移得到 2.S(n, m) = ...
- Linux的wget命令详解【转载】
Linux wget是一个下载文件的工具,它用在命令行下.对于Linux用户是必不可少的工具,尤其对于网络管理员,经常要下载一些软件或从远程服务器恢复备份到本地服务器.如果我们使用虚拟主机,处理这样的 ...