转 如何阅读TensorFlow源码
通过bazel学习之后,大概了解了TensorFlow的项目的源文件和描述文件。
下面是一篇不错的介绍,搬砖here。
在静下心来默默看了大半年机器学习的资料并做了些实践后,打算学习下现在热门的TensorFlow的实现,毕竟系统这块和自己关系较大。本文会简单的说明一下如何阅读TensorFlow的源码。最重要的是了解其构建工具bazel以及脚本语言调用c或cpp的包裹工具swig。这里假设大家对bazel及swig以及有所了解(不了解的可以google下)。要看代码首先要知道代码怎么构建,因此本文的一大部分会关注构建这块。
如果从源码构建TensorFlow会需要执行如下命令:
bazel build -c opt //tensorflow/tools/pip_package:build_pip_package
对应的BUILD文件的rule为:
sh_binary(
name = "build_pip_package",
srcs = ["build_pip_package.sh"],
data = [
"MANIFEST.in",
"README",
"setup.py",
"//tensorflow/core:framework_headers",
":other_headers",
":simple_console",
"//tensorflow:tensorflow_py",
"//tensorflow/examples/tutorials/mnist:package",
"//tensorflow/models/embedding:package",
"//tensorflow/models/image/cifar10:all_files",
"//tensorflow/models/image/mnist:convolutional",
"//tensorflow/models/rnn:package",
"//tensorflow/models/rnn/ptb:package",
"//tensorflow/models/rnn/translate:package",
"//tensorflow/tensorboard",
],
)
sh_binary在这里的主要作用是生成data的这些依赖。一个一个来看,一开始的三个文件MANIFEST.in、README、setup.py是直接存在的,因此不会有什么操作。
“//tensorflow/core:framework_headers”:
其对应的rule为:
filegroup(
name = "framework_headers",
srcs = [
"framework/allocator.h",
......
"util/device_name_utils.h",
],
)
这里filegroup的作用是给这一堆头文件一个别名,方便其他rule引用。
“:other_headers”:
rule为:
transitive_hdrs(
name = "other_headers",
deps = [
"//third_party/eigen3",
"//tensorflow/core:protos_all_cc",
],
)
transitive_hdrs的定义在:
load("//tensorflow:tensorflow.bzl", "transitive_hdrs")
实现为:
# Bazel rule for collecting the header files that a target depends on.
def _transitive_hdrs_impl(ctx):
outputs = set()
for dep in ctx.attr.deps:
outputs += dep.cc.transitive_headers
return struct(files=outputs) _transitive_hdrs = rule(attrs={
"deps": attr.label_list(allow_files=True,
providers=["cc"]),
},
implementation=_transitive_hdrs_impl,) def transitive_hdrs(name, deps=[], **kwargs):
_transitive_hdrs(name=name + "_gather",
deps=deps)
native.filegroup(name=name,
srcs=[":" + name + "_gather"])
其作用依旧是收集依赖需要的头文件。
“:simple_console”:
其rule为:
py_binary(
name = "simple_console",
srcs = ["simple_console.py"],
srcs_version = "PY2AND3",
deps = ["//tensorflow:tensorflow_py"],
py_library(
name = "tensorflow_py",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = ["//tensorflow/python"],
)
simple_console.py的代码的主要部分是:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function import code
import sys def main(_):
"""Run an interactive console."""
code.interact()
return 0 if __name__ == '__main__':
sys.exit(main(sys.argv))
可以看到起通过deps = [“//tensorflow/python”]构建了依赖包,然后生成了对应的执行文件。看下依赖的rule规则。
//tensorflow/python对应的rule为:
py_library(
name = "python",
srcs = [
"__init__.py",
],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__pkg__"],
deps = [
":client",
":client_testlib",
":framework",
":framework_test_lib",
":kernel_tests/gradient_checker",
":platform",
":platform_test",
":summary",
":training",
"//tensorflow/contrib:contrib_py",
],
)
py_library(
name = "training",
srcs = glob(
["training/**/*.py"],
exclude = ["**/*test*"],
),
srcs_version = "PY2AND3",
deps = [
":client",
":framework",
":lib",
":ops",
":protos_all_py",
":pywrap_tensorflow",
":training_ops",
],
)
这里其依赖的pywrap_tensorflow的rule为:
tf_py_wrap_cc(
name = "pywrap_tensorflow",
srcs = ["tensorflow.i"],
swig_includes = [
"client/device_lib.i",
"client/events_writer.i",
"client/server_lib.i",
"client/tf_session.i",
"framework/python_op_gen.i",
"lib/core/py_func.i",
"lib/core/status.i",
"lib/core/status_helper.i",
"lib/core/strings.i",
"lib/io/py_record_reader.i",
"lib/io/py_record_writer.i",
"platform/base.i",
"platform/numpy.i",
"util/port.i",
"util/py_checkpoint_reader.i",
],
deps = [
":py_func_lib",
":py_record_reader_lib",
":py_record_writer_lib",
":python_op_gen",
":tf_session_helper",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
"//util/python:python_headers",
],
)
tf_py_wrap_cc为其自己实现的一个rule,这里的.i就是SWIG的interface文件。来看下其实现:
def tf_py_wrap_cc(name, srcs, swig_includes=[], deps=[], copts=[], **kwargs):
module_name = name.split("/")[-1]
# Convert a rule name such as foo/bar/baz to foo/bar/_baz.so
# and use that as the name for the rule producing the .so file.
cc_library_name = "/".join(name.split("/")[:-1] + ["_" + module_name + ".so"])
extra_deps = []
_py_wrap_cc(name=name + "_py_wrap",
srcs=srcs,
swig_includes=swig_includes,
deps=deps + extra_deps,
module_name=module_name,
py_module_name=name)
native.cc_binary(
name=cc_library_name,
srcs=[module_name + ".cc"],
copts=(copts + ["-Wno-self-assign", "-Wno-write-strings"]
+ tf_extension_copts()),
linkopts=tf_extension_linkopts(),
linkstatic=1,
linkshared=1,
deps=deps + extra_deps)
native.py_library(name=name,
srcs=[":" + name + ".py"],
srcs_version="PY2AND3",
data=[":" + cc_library_name])
_py_wrap_cc = rule(attrs={
"srcs": attr.label_list(mandatory=True,
allow_files=True,),
"swig_includes": attr.label_list(cfg=DATA_CFG,
allow_files=True,),
"deps": attr.label_list(allow_files=True,
providers=["cc"],),
"swig_deps": attr.label(default=Label(
"//tensorflow:swig")), # swig_templates
"module_name": attr.string(mandatory=True),
"py_module_name": attr.string(mandatory=True),
"swig_binary": attr.label(default=Label("//tensorflow:swig"),
cfg=HOST_CFG,
executable=True,
allow_files=True,),
},
outputs={
"cc_out": "%{module_name}.cc",
"py_out": "%{py_module_name}.py",
},
implementation=_py_wrap_cc_impl,)
_py_wrap_cc_impl的实现为:
# Bazel rules for building swig files.
def _py_wrap_cc_impl(ctx):
srcs = ctx.files.srcs
if len(srcs) != 1:
fail("Exactly one SWIG source file label must be specified.", "srcs")
module_name = ctx.attr.module_name
cc_out = ctx.outputs.cc_out
py_out = ctx.outputs.py_out
src = ctx.files.srcs[0]
args = ["-c++", "-python"]
args += ["-module", module_name]
args += ["-l" + f.path for f in ctx.files.swig_includes]
cc_include_dirs = set()
cc_includes = set()
for dep in ctx.attr.deps:
cc_include_dirs += [h.dirname for h in dep.cc.transitive_headers]
cc_includes += dep.cc.transitive_headers
args += ["-I" + x for x in cc_include_dirs]
args += ["-I" + ctx.label.workspace_root]
args += ["-o", cc_out.path]
args += ["-outdir", py_out.dirname]
args += [src.path]
outputs = [cc_out, py_out]
ctx.action(executable=ctx.executable.swig_binary,
arguments=args,
mnemonic="PythonSwig",
inputs=sorted(set([src]) + cc_includes + ctx.files.swig_includes +
ctx.attr.swig_deps.files),
outputs=outputs,
progress_message="SWIGing {input}".format(input=src.path))
return struct(files=set(outputs))
这里的ctx.executable.swig_binary是一个shell脚本,内容为:
# If possible, read swig path out of "swig_path" generated by configure
SWIG=swig
SWIG_PATH=tensorflow/tools/swig/swig_path
if [ -e $SWIG_PATH ]; then
SWIG=`cat $SWIG_PATH`
fi # If this line fails, rerun configure to set the path to swig correctly
"$SWIG" "$@"
可以看到起就是调用了swig命令。
“//tensorflow:tensorflow_py”:
其rule为:
py_library(
name = "tensorflow_py",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = ["//tensorflow/python"],
)
可以看到起主要依赖了我们上面生成的”//tensorflow/python”这个module。
剩余的几个其实和主框架关系不大,主要是生成一些model、文档啥的。
现在清楚了其构建链后,我们来看个简单的程序,其通过梯度下降算法求线性拟合的W和b。我们会从这个例子入手看下如何找到其使用的函数的具体实现的源码位置:
(python3.5)➜ tmp cat th.py
import tensorflow as tf
import numpy as np x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3 W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y = W * x_data + b loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss) init = tf.initialize_all_variables() sess = tf.Session()
sess.run(init) for step in range(0, 201):
sess.run(train)
if step % 20 == 0:
print(step, sess.run(W), sess.run(b))
(python3.5)➜ tmp python th.py
0 [ 0.42190057] [ 0.17155224]
20 [ 0.1743494] [ 0.26045772]
40 [ 0.11817314] [ 0.29033473]
60 [ 0.10444205] [ 0.29763755]
80 [ 0.10108578] [ 0.29942256]
100 [ 0.10026541] [ 0.29985884]
120 [ 0.10006487] [ 0.2999655]
140 [ 0.10001585] [ 0.29999158]
160 [ 0.10000388] [ 0.29999796]
180 [ 0.10000096] [ 0.29999951]
200 [ 0.10000025] [ 0.29999989]
从我们上面的分析可以看到,import tensorflow as tf来自于tensorflow目录下的__init__.py文件,其内容为:
from tensorflow.python import *
再来看tf.Variable,在tensorflow.python的__init__.py中可以看到其导入了很多符号。但要定位到Variable还是比较困难,因为其很多直接是import *。所以一个快速定位的方法是直接grep这个class:
➜ python grep 'class Variable(' -R ./*
./ops/variables.py:class Variable(object):
对于tf.Session等也可以用同样的方法定位。我们来找个走SWIG包裹的,如果我们去看sess.run,我们会看到如下的代码:
return tf_session.TF_Run(session, options,
feed_dict, fetch_list, target_list,
run_metadata)
这里tf_session就是一个SWIG包裹的模块:
from tensorflow.python import pywrap_tensorflow as tf_session
pywrap_tensorflow在源码里是找不到的,因为这个得从SWIG生成后才有,我们可以从.i文件里找下TF_Run的声明,或者直接grep下这个函数:
➜ tensorflow grep 'TF_Run(' -R ./*
./core/client/tensor_c_api.cc:void TF_Run(TF_Session* s, const TF_Buffer* run_options,
这样就可以看其实现了:
void TF_Run(TF_Session* s, const TF_Buffer* run_options,
// Input tensors
const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
// Output tensors
const char** c_output_tensor_names, TF_Tensor** c_outputs,
int noutputs,
// Target nodes
const char** c_target_node_names, int ntargets,
TF_Buffer* run_metadata, TF_Status* status) {
TF_Run_Helper(s, nullptr, run_options, c_input_names, c_inputs, ninputs,
c_output_tensor_names, c_outputs, noutputs, c_target_node_names,
ntargets, run_metadata, status);
}
转 如何阅读TensorFlow源码的更多相关文章
- Tensorflow[源码安装时bazel行为解析]
0. 引言 通过源码方式安装,并进行一定程度的解读,有助于理解tensorflow源码,本文主要基于tensorflow v1.8源码,并借鉴于如何阅读TensorFlow源码. 首先,自然是需要去b ...
- TensorFlow源码框架 杂记
一.为什么我们需要使用线程池技术(ThreadPool) 线程:采用“即时创建,即时销毁”策略,即接受请求后,创建一个新的线程,执行任务,完毕后,线程退出: 线程池:应用软件启动后,立即创建一定数量的 ...
- tensorflow源码解析系列文章索引
文章索引 framework解析 resource allocator tensor op node kernel graph device function shape_inference 拾遗 c ...
- 如何阅读Java源码 阅读java的真实体会
刚才在论坛不经意间,看到有关源码阅读的帖子.回想自己前几年,阅读源码那种兴奋和成就感(1),不禁又有一种激动. 源码阅读,我觉得最核心有三点:技术基础+强烈的求知欲+耐心. 说到技术基础,我打个比 ...
- newsstand杂志阅读应用源码ipad版
一款newsstand iPad杂志阅读应用源码(newsstand在线下载/动态显示等)可以支持在线下载/动态显示等 ,也是一款newsstand iPad杂志阅读应用源码.运行之后,会在iPad ...
- 如何阅读Java源码
刚才在论坛不经意间,看到有关源码阅读的帖子.回想自己前几年,阅读源码那种兴奋和成就感(1),不禁又有一种激动.源码阅读,我觉得最核心有三点:技术基础+强烈的求知欲+耐心. 说到技术基础,我打个比方吧, ...
- 如何阅读mysql源码
在微博上问mysql高手,如何阅读mysql 源码大致给了下面的一些建议: step 1,知道代码的组织结构(官方文档http://t.cn/z8LoLgh: Step2: 尝试大致了解一条sql涉及 ...
- Ubuntu TensorFlow 源码 Android Demo的编译运行
Ubuntu TensorFlow 源码 Android Demo的编译运行 一. 安装 Android 的SDK和NDK SDK 配置 A:下载 国内下载地址选最新的: SDK: https://d ...
- 编译TensorFlow源码
编译TensorFlow源码 参考: https://www.tensorflow.org/install/install_sources https://github.com/tensorflo ...
随机推荐
- [Codeforces266E]More Queries to Array...——线段树
题目链接: Codeforces266E 题目大意:给出一个序列$a$,要求完成$Q$次操作,操作分为两种:1.$l,r,x$,将$[l,r]$的数都变为$x$.2.$l,r,k$,求$\sum\li ...
- Codeforces1065G Fibonacci Suffix 【递推】【二分答案】
题目分析: 首先为了简便起见我们把前$15$的答案找出来,免得我们还要特判$200$以内之类的麻烦事. 然后我们从$16$开始递推.考虑猜测第i位是$0$还是$1$(这本质上是个二分).一开始先猜是$ ...
- 概念数据模型CDM基础
概念数据模型CDM 概念数据模型是设计数据库不可或缺的一步,是整个数据库设计的关键,CDM的主要作用如下: 1)能够真实地模拟真实世界,是需求分析人员和数据库设计人员沟通的桥梁.2)将系统需求分析得到 ...
- Linux block(1k) block(4k) 换算 gb
输入 df 显示1k blocks 大小 再输入 df -h 显示 gb换算大小 结论 block(1k) 计算公式为: block(1k) /1024/1000 = xx gb ...
- python学习日记(编码再回顾)
当想从一种编码方式转换为另一种编码方式时,执行的就是以上步骤. 在python3里面,默认编码方式是unicode,所以无需解码(decode),直接编码(encode)成你想要的编码方式就可以了. ...
- kubernetes 基础命令及操作
获取集群的基本信息kubectl cluster-infokubectl get nodeskubectl get namespaceskubectl get deployment --all-nam ...
- Educational Codeforces Round 51 (Rated for Div. 2) G. Distinctification(线段树合并 + 并查集)
题意 给出一个长度为 \(n\) 序列 , 每个位置有 \(a_i , b_i\) 两个参数 , \(b_i\) 互不相同 ,你可以进行任意次如下的两种操作 : 若存在 \(j \not = i\) ...
- 【JDK源码】将JDK源码导入IDEA中
新建工程 在IDEA中新建普通JAVA工程,步骤如下: 导入源码 首先可以通过如下方法找到工程目录. 在JDK安装目录下找到源码包src.zip,如下图 将src.zip包解压,并将src目录下的内容 ...
- 板载 SPI-FLASH 的烧写方法
@2018-12-15 [筹划] 通过烧录器(JTAG/SWD)即可方便的烧写板载外部 FLASH [参考] 如何更好地设计面向在板烧录的产品(一)SPI Flash篇 keil将程序装入外部FLAS ...
- NOIP2015斗地主(搜索+模拟+贪心)
%%%Luan 题面就不说了,和斗地主一样,给一组牌,求最少打几次. 注意一点,数据随机,这样我们瞎搞一搞就可以过,虽然直接贪心可以证明是错的. 枚举方法,每次搜索按照(三顺子>二顺子>普 ...