0. Introduction

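Installing TensorFlow from source, and reading the build files along the way, is a good way to start understanding the TensorFlow code base. This article is based on the TensorFlow v1.8 source and draws on the article "How to Read the TensorFlow Source Code".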

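The first step is to pick up the essentials from the Bazel website: (1) what Bazel is; (2) how Bazel builds C++ projects; (3) the reference of functions available when writing BUILD files. After that come the more detailed parts of the Bazel documentation. Links to the relevant pages are given along the way.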

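PS: After searching for quite a while, I am fairly sure there are no books or similar resources on Bazel apart from the official website, so the only ways to learn it are the official documentation and other people's blog posts.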


The documentation can be mirrored locally for offline reading:

```
wget -m -c -x -np -k -E -p https://docs.bazel.build/versions/master/bazel-overview.html
```

1. Building TensorFlow from source


Figure 1.1: TensorFlow v1.8 source tree on GitHub

1.1 Configure first

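The root of the source tree contains a bash script named configure. This script asks you for the path names of all relevant TensorFlow dependencies and lets you specify other build configuration options, such as compiler flags. You must run this script before building the pip package and installing TensorFlow.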


A sample session (the notes after # explain each question):

```
$ cd tensorflow  # cd to the top-level directory created
$ ./configure
Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python2.7  # path to the Python interpreter
Found possible Python library paths:
  /usr/local/lib/python2.7/dist-packages
  /usr/lib/python2.7/dist-packages
Please input the desired Python library path to use. Default is [/usr/lib/python2.7/dist-packages]  # Python library path
Using python library path: /usr/local/lib/python2.7/dist-packages
Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native]:  # extra optimization flags applied with --config=opt
Do you wish to use jemalloc as the malloc implementation? [Y/n]  # use jemalloc as the malloc implementation?
jemalloc enabled
Do you wish to build TensorFlow with Google Cloud Platform support? [y/N]  # enable Google Cloud Platform support?
No Google Cloud Platform support will be enabled for TensorFlow
Do you wish to build TensorFlow with Hadoop File System support? [y/N]  # enable HDFS support?
No Hadoop File System support will be enabled for TensorFlow
Do you wish to build TensorFlow with the XLA just-in-time compiler (experimental)? [y/N]  # enable the (still experimental) XLA JIT compiler?
No XLA support will be enabled for TensorFlow
Do you wish to build TensorFlow with VERBS support? [y/N]  # enable VERBS support?
No VERBS support will be enabled for TensorFlow
Do you wish to build TensorFlow with OpenCL support? [y/N]  # enable OpenCL support?
No OpenCL support will be enabled for TensorFlow
Do you wish to build TensorFlow with CUDA support? [y/N] Y  # enable CUDA support?
CUDA support will be enabled for TensorFlow
Do you want to use clang as CUDA compiler? [y/N]  # use clang instead of nvcc as the CUDA compiler?
nvcc will be used as CUDA compiler
Please specify the CUDA SDK version you want to use, e.g. 7.0. [Leave empty to default to CUDA 9.0]: 9.0  # CUDA version
Please specify the location where CUDA 9.0 toolkit is installed. Refer to README.md for more details. [Default is /usr/local/cuda]:  # CUDA install path
Please specify which gcc should be used by nvcc as the host compiler. [Default is /usr/bin/gcc]:  # host-side compiler used by nvcc
Please specify the cuDNN version you want to use. [Leave empty to default to cuDNN 7.0]: 7  # cuDNN version
Please specify the location where cuDNN 7 library is installed. Refer to README.md for more details. [Default is /usr/local/cuda]:  # cuDNN install path
Please specify a list of comma-separated CUDA compute capabilities you want to build with.  # compute capabilities of the GPUs on this machine
You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus.
Please note that each additional compute capability significantly increases your build time and binary size.
Do you wish to build TensorFlow with MPI support? [y/N]  # enable MPI support?
MPI support will not be enabled for TensorFlow
Configuration finished
```


The configure script itself is only a thin wrapper:

```bash
#!/usr/bin/env bash

set -e
set -o pipefail

if [ -z "$PYTHON_BIN_PATH" ]; then
  PYTHON_BIN_PATH=$(which python || which python3 || true)
fi

# Set all env variables
CONFIGURE_DIR=$(dirname "$0")
# The actual configuration is done by calling the accompanying configure.py
"$PYTHON_BIN_PATH" "${CONFIGURE_DIR}/configure.py" "$@"

echo "Configuration finished"
```
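The PYTHON_BIN_PATH fallback in the wrapper (`which python || which python3 || true`) can be mirrored in Python to make the order of preference explicit. This is only an illustration written for this article: the function name `resolve_python_bin` and its injectable `which` parameter are hypothetical, and the real script does this in bash.

```python
import shutil
from typing import Callable, Mapping, Optional


def resolve_python_bin(environ: Mapping[str, str],
                       which: Callable[[str], Optional[str]] = shutil.which) -> str:
    """Sketch of: PYTHON_BIN_PATH=$(which python || which python3 || true).

    A non-empty PYTHON_BIN_PATH is kept as-is; otherwise `python` is tried,
    then `python3`; if neither is found the result is empty (the `|| true` case).
    """
    explicit = environ.get("PYTHON_BIN_PATH", "")
    if explicit:  # corresponds to `[ -z "$PYTHON_BIN_PATH" ]` being false
        return explicit
    for candidate in ("python", "python3"):
        path = which(candidate)
        if path:
            return path
    return ""


# Usage with a fake lookup table instead of the real PATH:
fake_which = {"python3": "/usr/bin/python3"}.get
print(resolve_python_bin({}, which=fake_which))  # /usr/bin/python3
```

Injecting `which` keeps the sketch testable without depending on what happens to be installed on the machine.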


Inside configure.py, each optional feature is handled by a call to set_build_var, for example:

```python
set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc',
              'with_jemalloc', True)
set_build_var(environ_cp, 'TF_NEED_GCP', 'Google Cloud Platform',
              'with_gcp_support', True, 'gcp')
set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System',
              'with_hdfs_support', True, 'hdfs')
set_build_var(environ_cp, 'TF_NEED_AWS', 'Amazon AWS Platform',
              'with_aws_support', True, 'aws')
set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform',
              'with_kafka_support', True, 'kafka')
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
              False, 'xla')
set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support',
              False, 'gdr')
set_build_var(environ_cp, 'TF_NEED_VERBS', 'VERBS', 'with_verbs_support',
              False, 'verbs')
```
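Comparing these calls with the generated .tf_configure.bazelrc shown further down, set_build_var appears to follow this rule: an enabled feature writes `build --define <option>=true`; a disabled feature that has a config name writes `build:<name> --define <option>=true`; anything else writes nothing. The sketch below is a simplified Python model of that behaviour, written for this article and not TensorFlow's actual code. The function name `bazelrc_line_for` and the `answer` parameter (standing in for the interactive y/N prompt) are hypothetical, and it assumes an existing "1"/"0" environment variable is used instead of asking.

```python
from typing import Dict, Optional


def bazelrc_line_for(environ: Dict[str, str], var_name: str, option_name: str,
                     enabled_by_default: bool, answer: Optional[bool] = None,
                     bazel_config_name: Optional[str] = None) -> Optional[str]:
    """Simplified model of the line set_build_var() appends to .tf_configure.bazelrc."""
    if environ.get(var_name) in ("0", "1"):
        enabled = environ[var_name] == "1"   # preset environment variable wins
    elif answer is not None:
        enabled = answer                     # stands in for the y/N prompt
    else:
        enabled = enabled_by_default         # user just pressed Enter
    environ[var_name] = "1" if enabled else "0"
    if enabled:
        # Feature on: the define applies to every build.
        return "build --define %s=true" % option_name
    if bazel_config_name is not None:
        # Feature off but named: it only applies with --config=<name>.
        return "build:%s --define %s=true" % (bazel_config_name, option_name)
    return None  # feature off and no config name: nothing is written


env: Dict[str, str] = {}
print(bazelrc_line_for(env, "TF_NEED_JEMALLOC", "with_jemalloc", True))
# build --define with_jemalloc=true
print(bazelrc_line_for(env, "TF_NEED_HDFS", "with_hdfs_support", True,
                       answer=False, bazel_config_name="hdfs"))
# build:hdfs --define with_hdfs_support=true
```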



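configure.py writes the answers into .tf_configure.bazelrc and adds an import of that file to the .bazelrc in the source root:

```
import /mnt/d/tensorflow/tensorflow-master/.tf_configure.bazelrc
```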


The generated .tf_configure.bazelrc (here from a CPU-only configuration using Anaconda Python 3.6) contains:

```
build --action_env PYTHON_BIN_PATH="/home/shouhuxianjian/anaconda3/bin/python"
build --action_env PYTHON_LIB_PATH="/home/shouhuxianjian/anaconda3/lib/python3.6/site-packages"
build --python_path="/home/shouhuxianjian/anaconda3/bin/python"
build --define with_jemalloc=true
build:gcp --define with_gcp_support=true
build:hdfs --define with_hdfs_support=true
build:aws --define with_aws_support=true
build:kafka --define with_kafka_support=true
build:xla --define with_xla_support=true
build:gdr --define with_gdr_support=true
build:verbs --define with_verbs_support=true
build --action_env TF_NEED_OPENCL_SYCL="0"
build --action_env TF_NEED_CUDA="0"
build --action_env TF_DOWNLOAD_CLANG="0"
build --define grpc_no_ares=true
build:opt --copt=-march=native
build:opt --host_copt=-march=native
build:opt --define with_default_optimizations=true
build --strip=always
```

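Lines of the form build:hdfs are only applied when bazel build is invoked with --config=hdfs (see the --config option in the Bazel user manual); plain build lines apply to every build.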
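As a rough mental model of how such rc lines are consumed, the following sketch collects the options a `bazel build` would pick up for a given set of `--config` values. It is a simplification written for this article (the function name `effective_build_args` is hypothetical): real Bazel also handles command inheritance (for example, `test` inherits `build` options), expands a config at the position where `--config` appears, and reads several rc files.

```python
import shlex
from typing import Iterable, List


def effective_build_args(rc_lines: Iterable[str], configs: Iterable[str]) -> List[str]:
    """Options a `bazel build` picks up from rc lines, for the given --config names."""
    wanted = set(configs)
    args: List[str] = []
    for line in rc_lines:
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        head, _, rest = line.partition(" ")      # e.g. "build:hdfs", "--define ..."
        command, _, config = head.partition(":")
        if command != "build":
            continue
        if config and config not in wanted:      # build:NAME needs --config=NAME
            continue
        args.extend(shlex.split(rest))
    return args


rc = [
    "build --define with_jemalloc=true",
    "build:hdfs --define with_hdfs_support=true",
    "build:opt --copt=-march=native",
]
print(effective_build_args(rc, []))
# ['--define', 'with_jemalloc=true']
print(effective_build_args(rc, ["opt", "hdfs"]))
# ['--define', 'with_jemalloc=true', '--define', 'with_hdfs_support=true', '--copt=-march=native']
```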


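Answering yes to the GCP, HDFS, AWS and Kafka questions instead writes unconditional build lines, so those defines apply to every build:

```
build --define with_gcp_support=true
build --define with_hdfs_support=true
build --define with_aws_support=true
build --define with_kafka_support=true
```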


jemalloc has no config name, so when it is enabled only a plain build line is written:

```
build --define with_jemalloc=true
```


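For reference, the Hadoop-related third-party target is defined in third_party/hadoop/BUILD and only exposes the hdfs.h header:

```python
# File: tensorflow-master/third_party/hadoop/BUILD
package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE.txt"])

cc_library(
    name = "hdfs",
    hdrs = ["hdfs.h"],
)
```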


1.2 Then build with Bazel


CPU-only build:

```
$ bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package
```

Build with GPU (CUDA) support:

```
$ bazel build --config=opt --config=cuda //tensorflow/tools/pip_package:build_pip_package
```

1.2.1 Recommended BUILD file structure

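When reading tensorflow-master/tensorflow/tools/pip_package/BUILD, it helps to review the reference of functions available in BUILD files, as well as Bazel's officially recommended BUILD file structure (File structure), which is: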

  1. Package description (a comment)
  2. All load() statements
  3. The package() function.
  4. Calls to rules and macros
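To make the recommended order concrete, here is a small, hypothetical checker (written for this article, not a Bazel tool) that flags unindented top-level calls appearing out of the order load(), then package(), then rules and macros. It is only a heuristic: it ignores the description comment and treats any unindented `name(` line as a call. Note that the pip_package BUILD file quoted next calls package() before its load() statements, which such a check would report.

```python
import re
from typing import List

# Only unindented `name(` lines are treated as top-level calls.
_TOP_LEVEL_CALL = re.compile(r"^([A-Za-z_]\w*)\s*\(")
_RANK = {"load": 0, "package": 1}  # any other call counts as a rule or macro (rank 2)


def order_violations(build_text: str) -> List[str]:
    """Flag top-level calls that break the order load() -> package() -> rules/macros."""
    problems: List[str] = []
    highest = -1
    for lineno, line in enumerate(build_text.splitlines(), start=1):
        match = _TOP_LEVEL_CALL.match(line)
        if not match:
            continue
        name = match.group(1)
        rank = _RANK.get(name, 2)
        if rank < highest:
            problems.append("line %d: %s() comes after a later section" % (lineno, name))
        highest = max(highest, rank)
    return problems


sample = '''package(default_visibility = ["//visibility:private"])
load("//tensorflow:tensorflow.bzl", "if_cuda")
py_binary(
    name = "simple_console",
)
'''
print(order_violations(sample))  # ['line 2: load() comes after a later section']
```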

1.2.2 Reading tensorflow/tools/pip_package/BUILD


```python
# Description:
#   Tools for building the TensorFlow pip package.

# Signature: package(default_deprecation, default_testonly, default_visibility, features)
# Declares metadata that applies to every subsequent rule in the package.
# It can be used at most once per package (BUILD file), and should be called
# at the top of the file, after all load() statements and before any rule.
# [package](https://docs.bazel.build/versions/master/be/functions.html#package)
# "private" makes the following rules visible only inside this package by default:
# https://docs.bazel.build/versions/master/be/common-definitions.html#common-attributes
package(default_visibility = ["//visibility:private"])

# Bazel extensions are files ending in .bzl. A load() statement imports
# symbols from an extension file for use in the current BUILD file.
# [load](https://docs.bazel.build/versions/master/skylark/concepts.html)
load(
    "//tensorflow:tensorflow.bzl",
    "if_not_windows",
    "if_windows",
    "transitive_hdrs",
)
load("//third_party/mkl:build_defs.bzl", "if_mkl")
load("//tensorflow:tensorflow.bzl", "if_cuda")
load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps")

# This returns a list of headers of all public header libraries (e.g.,
# framework, lib), and all of the transitive dependencies of those
# public headers. Not all of the headers returned by the filegroup
# are public (e.g., internal headers that are included by public
# headers), but the internal headers need to be packaged in the
# pip_package for the public headers to be properly included.
#
# Public headers are therefore defined by those that are both:
#
# 1) "publicly visible" as defined by bazel
# 2) Have documentation.
#
# This matches the policy of "public" for our python API.
transitive_hdrs(
    name = "included_headers",
    deps = [
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:stream_executor",
        "//third_party/eigen3",
    ] + if_cuda([
        "@local_config_cuda//cuda:cuda_headers",
    ]),
)

py_binary(
    name = "simple_console",
    srcs = ["simple_console.py"],
    srcs_version = "PY2AND3",
    deps = ["//tensorflow:tensorflow_py"],
)

COMMON_PIP_DEPS = [
    ":licenses",
    "MANIFEST.in",
    "README",
    "setup.py",
    ":included_headers",
    "//tensorflow:tensorflow_py",
    "//tensorflow/contrib/autograph:autograph",
    "//tensorflow/contrib/autograph/converters:converters",
    "//tensorflow/contrib/autograph/converters:test_lib",
    "//tensorflow/contrib/autograph/impl:impl",
    "//tensorflow/contrib/autograph/pyct:pyct",
    "//tensorflow/contrib/autograph/pyct/static_analysis:static_analysis",
    "//tensorflow/contrib/boosted_trees:boosted_trees_pip",
    "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
    "//tensorflow/contrib/data/python/kernel_tests:dataset_serialization_test",
    "//tensorflow/contrib/data/python/ops:contrib_op_loader",
    "//tensorflow/contrib/eager/python/examples:examples_pip",
    "//tensorflow/contrib/eager/python:checkpointable_utils",
    "//tensorflow/contrib/eager/python:evaluator",
    "//tensorflow/contrib/gan:gan",
    "//tensorflow/contrib/graph_editor:graph_editor_pip",
    "//tensorflow/contrib/keras:keras",
    "//tensorflow/contrib/labeled_tensor:labeled_tensor_pip",
    "//tensorflow/contrib/nn:nn_py",
    "//tensorflow/contrib/predictor:predictor_pip",
    "//tensorflow/contrib/proto:proto_pip",
    "//tensorflow/contrib/receptive_field:receptive_field_pip",
    "//tensorflow/contrib/rpc:rpc_pip",
    "//tensorflow/contrib/session_bundle:session_bundle_pip",
    "//tensorflow/contrib/signal:signal_py",
    "//tensorflow/contrib/signal:test_util",
    "//tensorflow/contrib/slim:slim",
    "//tensorflow/contrib/slim/python/slim/data:data_pip",
    "//tensorflow/contrib/slim/python/slim/nets:nets_pip",
    "//tensorflow/contrib/specs:specs",
    "//tensorflow/contrib/summary:summary_test_util",
    "//tensorflow/contrib/tensor_forest:init_py",
    "//tensorflow/contrib/tensor_forest/hybrid:hybrid_pip",
    "//tensorflow/contrib/timeseries:timeseries_pip",
    "//tensorflow/contrib/tpu",
    "//tensorflow/examples/tutorials/mnist:package",
    "//tensorflow/python:distributed_framework_test_lib",
    "//tensorflow/python:meta_graph_testdata",
    "//tensorflow/python:spectral_ops_test_util",
    "//tensorflow/python:util_example_parser_configuration",
    "//tensorflow/python/debug:debug_pip",
    "//tensorflow/python/eager:eager_pip",
    "//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files",
    "//tensorflow/python/saved_model:saved_model",
    "//tensorflow/python/tools:tools_pip",
    "//tensorflow/python:test_ops",
    "//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
]

# On Windows, python binary is a zip file of runfiles tree.
# Add everything to its data dependency for generating a runfiles tree
# for building the pip package on Windows.
py_binary(
    name = "simple_console_for_windows",
    srcs = ["simple_console_for_windows.py"],
    data = COMMON_PIP_DEPS,
    srcs_version = "PY2AND3",
    deps = ["//tensorflow:tensorflow_py"],
)

filegroup(
    name = "licenses",
    data = [
        "//third_party/eigen3:LICENSE",
        "//third_party/fft2d:LICENSE",
        "//third_party/hadoop:LICENSE.txt",
        "@absl_py//absl/flags:LICENSE",
        "@arm_neon_2_x86_sse//:LICENSE",
        "@astor_archive//:LICENSE",
        "@aws//:LICENSE",
        "@boringssl//:LICENSE",
        "@com_google_absl//:LICENSE",
        "@com_googlesource_code_re2//:LICENSE",
        "@cub_archive//:LICENSE.TXT",
        "@curl//:COPYING",
        "@eigen_archive//:COPYING.MPL2",
        "@farmhash_archive//:COPYING",
        "@fft2d//:fft/readme.txt",
        "@flatbuffers//:LICENSE.txt",
        "@gast_archive//:PKG-INFO",
        "@gemmlowp//:LICENSE",
        "@gif_archive//:COPYING",
        "@grpc//:LICENSE",
        "@highwayhash//:LICENSE",
        "@jemalloc//:COPYING",
        "@jpeg//:LICENSE.md",
        "@kafka//:LICENSE",
        "@libxsmm_archive//:LICENSE",
        "@lmdb//:LICENSE",
        "@local_config_nccl//:LICENSE",
        "@local_config_sycl//sycl:LICENSE.text",
        "@grpc//third_party/nanopb:LICENSE.txt",
        "@grpc//third_party/address_sorting:LICENSE",
        "@nasm//:LICENSE",
        "@nsync//:LICENSE",
        "@pcre//:LICENCE",
        "@png_archive//:LICENSE",
        "@protobuf_archive//:LICENSE",
        "@six_archive//:LICENSE",
        "@snappy//:COPYING",
        "@swig//:LICENSE",
        "@termcolor_archive//:COPYING.txt",
        "@zlib_archive//:zlib.h",
        "@org_python_pypi_backports_weakref//:LICENSE",
    ] + if_mkl([
        "//third_party/mkl:LICENSE",
    ]) + tf_additional_license_deps(),
)

# sh_binary rule for the packaging script. It uses select(), which is mainly
# used for platform-dependent choices. The bazel build command above does not
# enable any of the Windows config_settings, so on Linux the
# //conditions:default branch is used.
# [select](https://docs.bazel.build/versions/master/skylark/lib/globals.html#select)
# [select](https://docs.bazel.build/versions/master/be/functions.html#select)
sh_binary(
    name = "build_pip_package",
    srcs = ["build_pip_package.sh"],
    data = select({
        "//tensorflow:windows": [":simple_console_for_windows"],
        "//tensorflow:windows_msvc": [":simple_console_for_windows"],
        "//conditions:default": COMMON_PIP_DEPS + [
            ":simple_console",
            "//tensorflow/contrib/lite/python:interpreter_test_data",
            "//tensorflow/contrib/lite/python:tf_lite_py_pip",
            "//tensorflow/contrib/lite/toco:toco",
            "//tensorflow/contrib/lite/toco/python:toco_wrapper",
            "//tensorflow/contrib/lite/toco/python:toco_from_protos",
        ],
    }) + if_mkl(["//third_party/mkl:intel_binary_blob"]) + if_tensorrt([
        "//tensorflow/contrib/tensorrt:init_py",
    ]),
)

# A genrule for generating a marker file for the pip package on Windows
#
# This only works on Windows, because :simple_console_for_windows is a
# python zip file containing everything we need for building the pip package.
# However, on other platforms, due to https://github.com/bazelbuild/bazel/issues/4223,
# when C++ extensions change, this generule doesn't rebuild.
genrule(
    name = "win_pip_package_marker",
    srcs = if_windows([
        ":build_pip_package",
        ":simple_console_for_windows",
    ]),
    outs = ["win_pip_package_marker_file"],
    cmd = select({
        "//conditions:default": "touch $@",
        "//tensorflow:windows": "md5sum $(locations :build_pip_package) $(locations :simple_console_for_windows) > $@",
    }),
    visibility = ["//visibility:public"],
)
```
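How the select() in build_pip_package resolves can be modelled in a few lines. This is a simplified sketch written for this article (`resolve_select` is a hypothetical name): a key applies when the corresponding config_setting matches and `//conditions:default` is the fallback. Where real Bazel would accept several matches if one condition specializes the others, the sketch simply raises an error.

```python
from typing import Dict, Iterable, List

DEFAULT = "//conditions:default"


def resolve_select(branches: Dict[str, List[str]], active: Iterable[str]) -> List[str]:
    """Return the branch of a select() chosen for the given matching conditions."""
    active_set = set(active)
    matches = [key for key in branches if key != DEFAULT and key in active_set]
    if len(matches) > 1:
        # Simplification: real Bazel accepts this if one condition specializes the others.
        raise ValueError("ambiguous select(): %s" % matches)
    if matches:
        return branches[matches[0]]
    if DEFAULT in branches:
        return branches[DEFAULT]
    raise ValueError("no condition matched and there is no %s" % DEFAULT)


data = {
    "//tensorflow:windows": [":simple_console_for_windows"],
    "//tensorflow:windows_msvc": [":simple_console_for_windows"],
    DEFAULT: [":simple_console"],  # COMMON_PIP_DEPS omitted for brevity
}
print(resolve_select(data, []))                        # Linux build: [':simple_console']
print(resolve_select(data, ["//tensorflow:windows"]))  # [':simple_console_for_windows']
```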

1.2.3 How build_pip_package gets built


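On non-Windows platforms the //conditions:default branch of the select() is taken, so the data of build_pip_package is:

```python
"//conditions:default": COMMON_PIP_DEPS + [
    ":simple_console",
    "//tensorflow/contrib/lite/python:interpreter_test_data",
    "//tensorflow/contrib/lite/python:tf_lite_py_pip",
    "//tensorflow/contrib/lite/toco:toco",
    "//tensorflow/contrib/lite/toco/python:toco_wrapper",
    "//tensorflow/contrib/lite/toco/python:toco_from_protos",
]
```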

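So the focus now moves to COMMON_PIP_DEPS. Apart from :licenses, its first entries MANIFEST.in, README and setup.py are plain files that already exist in the source tree, so nothing needs to be done for them. Next comes:

```
:included_headers
```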


```python
# This matches the policy of "public" for our python API.
transitive_hdrs(
    name = "included_headers",
    deps = [
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:stream_executor",
        "//third_party/eigen3",
    ] + if_cuda([
        "@local_config_cuda//cuda:cuda_headers",
    ]),
)
```

transitive_hdrs is not a built-in Bazel function; it is a macro imported by the load() statement at the top of the file:

```python
load(
    "//tensorflow:tensorflow.bzl",
    "if_not_windows",
    "if_windows",
    "transitive_hdrs",
)
```


Its definition in tensorflow/tensorflow.bzl (written against the Starlark API of the Bazel versions of that time; constructs such as `depset +=` have since been removed from Bazel):

```python
# Bazel rule for collecting the header files that a target depends on.
def _transitive_hdrs_impl(ctx):
  outputs = depset()
  for dep in ctx.attr.deps:
    outputs += dep.cc.transitive_headers
  return struct(files=outputs)

# Define the private rule with the rule() function.
# [rule](https://docs.bazel.build/versions/master/skylark/lib/globals.html#rule)
_transitive_hdrs = rule(
    attrs = {
        "deps": attr.label_list(
            allow_files = True,
            providers = ["cc"],
        ),
    },
    implementation = _transitive_hdrs_impl,
)

# The transitive_hdrs macro itself: it instantiates the private
# _transitive_hdrs rule (implemented by _transitive_hdrs_impl) and wraps
# its output in a filegroup carrying the requested name.
def transitive_hdrs(name, deps=[], **kwargs):
  _transitive_hdrs(name=name + "_gather", deps=deps)
  native.filegroup(name=name, srcs=[":" + name + "_gather"])
```
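What _transitive_hdrs_impl collects, the union of the headers of all direct and indirect dependencies, corresponds to a plain graph traversal. The sketch below runs that traversal on a made-up toy graph (the `//toy:*` labels and header names are hypothetical); in Bazel the traversal has already been done by the C++ rules, and the rule only reads the precomputed `transitive_headers`. It also shows the two target names the macro creates.

```python
from typing import Dict, List, Set, Tuple


def transitive_headers(target: str,
                       hdrs: Dict[str, List[str]],
                       deps: Dict[str, List[str]]) -> Set[str]:
    """Union of the headers of `target` and everything it depends on, directly or not."""
    result: Set[str] = set()
    seen: Set[str] = set()
    stack = [target]
    while stack:
        current = stack.pop()
        if current in seen:
            continue  # a target reachable by several paths is visited once
        seen.add(current)
        result.update(hdrs.get(current, []))
        stack.extend(deps.get(current, []))
    return result


def transitive_hdrs_targets(name: str) -> Tuple[str, str]:
    """The macro creates a private rule `<name>_gather` and a filegroup `<name>`."""
    return name + "_gather", name


hdrs = {
    "//toy:framework": ["framework/tensor.h"],
    "//toy:lib": ["lib/status.h"],
    "//toy:platform": ["platform/types.h"],
}
deps = {
    "//toy:framework": ["//toy:lib", "//toy:platform"],
    "//toy:lib": ["//toy:platform"],
}
print(sorted(transitive_headers("//toy:framework", hdrs, deps)))
# ['framework/tensor.h', 'lib/status.h', 'platform/types.h']
print(transitive_hdrs_targets("included_headers"))
# ('included_headers_gather', 'included_headers')
```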


The next entry in COMMON_PIP_DEPS is:

```python
"//tensorflow:tensorflow_py",
```


```python
# tensorflow/BUILD, lines 539-548
py_library(
    name = "tensorflow_py",
    srcs = ["__init__.py"],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/python",
        "//tensorflow/tools/api/generator:python_api",
    ],
)
```


Its dependency //tensorflow/python is defined as:

```python
# tensorflow/python/BUILD
py_library(
    name = "python",
    srcs = ["__init__.py"],
    srcs_version = "PY2AND3",
    visibility = [
        "//tensorflow:__pkg__",
        "//tensorflow/compiler/aot/tests:__pkg__",  # TODO(b/34059704): remove when fixed
        "//tensorflow/contrib/learn:__pkg__",  # TODO(b/34059704): remove when fixed
        "//tensorflow/contrib/learn/python/learn/datasets:__pkg__",  # TODO(b/34059704): remove when fixed
        "//tensorflow/contrib/lite/toco/python:__pkg__",  # TODO(b/34059704): remove when fixed
        "//tensorflow/python/debug:__pkg__",  # TODO(b/34059704): remove when fixed
        "//tensorflow/python/tools:__pkg__",  # TODO(b/34059704): remove when fixed
        "//tensorflow/tools/api/generator:__pkg__",
        "//tensorflow/tools/quantization:__pkg__",  # TODO(b/34059704): remove when fixed
    ],
    deps = [
        ":no_contrib",
        "//tensorflow/contrib:contrib_py",
    ],
)
```

It depends on the :no_contrib target, so let's look at that:

```python
# tensorflow/python/BUILD
py_library(
    name = "no_contrib",
    srcs = ["__init__.py"],
    srcs_version = "PY2AND3",
    visibility = [
        "//tensorflow:__pkg__",
    ],
    deps = [
        ":array_ops",
        ":bitwise_ops",
        ":boosted_trees_ops",
        ":check_ops",
        ":client",
        ":client_testlib",
        ":confusion_matrix",
        ":control_flow_ops",
        ":cudnn_rnn_ops_gen",
        ":errors",
        ":framework",
        ":framework_for_generated_wrappers",
        ":functional_ops",
        ":gradient_checker",
        ":graph_util",
        ":histogram_ops",
        ":image_ops",
        ":initializers_ns",
        ":io_ops",
        ":layers",
        ":lib",
        ":list_ops",
        ":manip_ops",
        ":math_ops",
        ":metrics",
        ":nn",
        ":ops",
        ":platform",
        ":pywrap_tensorflow",
        ":saver_test_utils",
        ":script_ops",
        ":session_ops",
        ":sets",
        ":sparse_ops",
        ":spectral_ops",
        ":spectral_ops_test_util",
        ":standard_ops",
        ":state_ops",
        ":string_ops",
        ":subscribe",
        ":summary",
        ":tensor_array_ops",
        ":test_ops",  # TODO: Break testing code out into separate rule.
        ":tf_cluster",
        ":tf_item",
        ":tf_optimizer",
        ":training",
        ":util",
        ":weights_broadcast_ops",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python/data",
        "//tensorflow/python/estimator:estimator_py",
        "//tensorflow/python/feature_column:feature_column_py",
        "//tensorflow/python/keras",
        "//tensorflow/python/ops/distributions",
        "//tensorflow/python/ops/linalg",
        "//tensorflow/python/ops/losses",
        "//tensorflow/python/profiler",
        "//tensorflow/python/saved_model",
        "//third_party/py/numpy",
    ],
)
```

Following "How to Read the TensorFlow Source Code", we now look for pywrap_tensorflow. The pywrap_tensorflow target depends on pywrap_tensorflow_internal, and pywrap_tensorflow_internal is where SWIG generates the Python interface from the C++ code:

```python
# tensorflow/python/BUILD, line 3421
py_library(
    name = "pywrap_tensorflow",
    srcs = [
        "pywrap_tensorflow.py",
    ] + if_static(
        ["pywrap_dlopen_global_flags.py"],
        # Import will fail, indicating no global dlopen flags
        otherwise = [],
    ),
    srcs_version = "PY2AND3",
    deps = [":pywrap_tensorflow_internal"],
)

tf_py_wrap_cc(
    name = "pywrap_tensorflow_internal",
    srcs = ["tensorflow.i"],
    swig_includes = [
        "client/device_lib.i",
        "client/events_writer.i",
        "client/tf_session.i",
        "client/tf_sessionrun_wrapper.i",
        "framework/cpp_shape_inference.i",
        "framework/python_op_gen.i",
        "grappler/cluster.i",
        "grappler/cost_analyzer.i",
        "grappler/item.i",
        "grappler/model_analyzer.i",
        "grappler/tf_optimizer.i",
        "lib/core/bfloat16.i",
        "lib/core/py_exception_registry.i",
        "lib/core/py_func.i",
        "lib/core/strings.i",
        "lib/io/file_io.i",
        "lib/io/py_record_reader.i",
        "lib/io/py_record_writer.i",
        "platform/base.i",
        "platform/stacktrace_handler.i",
        "pywrap_tfe.i",
        "training/quantize_training.i",
        "training/server_lib.i",
        "util/kernel_registry.i",
        "util/port.i",
        "util/py_checkpoint_reader.i",
        "util/stat_summarizer.i",
        "util/tfprof.i",
        "util/transform_graph.i",
        "util/util.i",
    ],
    win_def_file = select({
        "//tensorflow:windows": ":pywrap_tensorflow_filtered_def_file",
        "//conditions:default": None,
    }),
    deps = [
        ":bfloat16_lib",
        ":cost_analyzer_lib",
        ":model_analyzer_lib",
        ":cpp_python_util",
        ":cpp_shape_inference",
        ":kernel_registry",
        ":numpy_lib",
        ":safe_ptr",
        ":py_exception_registry",
        ":py_func_lib",
        ":py_record_reader_lib",
        ":py_record_writer_lib",
        ":python_op_gen",
        ":tf_session_helper",
        "//tensorflow/c:c_api",
        "//tensorflow/c:checkpoint_reader",
        "//tensorflow/c:python_api",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/core/distributed_runtime/rpc:grpc_rpc_factory_registration",
        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
        "//tensorflow/core/distributed_runtime/rpc:grpc_session",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler:grappler_item_builder",
        "//tensorflow/core/grappler/clusters:cluster",
        "//tensorflow/core/grappler/clusters:single_machine",
        "//tensorflow/core/grappler/clusters:virtual_cluster",
        "//tensorflow/core/grappler/costs:graph_memory",
        "//tensorflow/core/grappler/optimizers:meta_optimizer",
        "//tensorflow/core:lib",
        "//tensorflow/core:reader_base",
        "//tensorflow/core/debug",
        "//tensorflow/core/distributed_runtime:server_lib",
        "//tensorflow/core/profiler/internal:print_model_analysis",
        "//tensorflow/tools/graph_transforms:transform_graph_lib",
        "//tensorflow/python/eager:pywrap_tfe_lib",
        "//tensorflow/python/eager:python_eager_op_gen",
        "//util/python:python_headers",
    ] + (tf_additional_lib_deps() +
         tf_additional_plugin_deps() +
         tf_additional_verbs_deps() +
         tf_additional_mpi_deps() +
         tf_additional_gdr_deps()),
)
```
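The chain followed so far, from build_pip_package down to the SWIG target, is the kind of question `bazel query 'somepath(//tensorflow/tools/pip_package:build_pip_package, //tensorflow/python:pywrap_tensorflow_internal)'` answers inside a configured workspace. The sketch below does the same kind of search in plain Python over a hand-copied subset of the edges quoted in this article; the graph is deliberately incomplete and `some_path` is a hypothetical helper, not part of Bazel.

```python
from collections import deque
from typing import Dict, List, Optional


def some_path(graph: Dict[str, List[str]], src: str, dst: str) -> Optional[List[str]]:
    """Breadth-first search returning one dependency path from src to dst, or None."""
    parent: Dict[str, Optional[str]] = {src: None}
    queue = deque([src])
    while queue:
        node = queue.popleft()
        if node == dst:
            path: List[str] = []
            step: Optional[str] = node
            while step is not None:  # walk back to the source
                path.append(step)
                step = parent[step]
            return path[::-1]
        for nxt in graph.get(node, []):
            if nxt not in parent:
                parent[nxt] = node
                queue.append(nxt)
    return None


# A hand-copied subset of the edges quoted above (far from the full graph).
graph = {
    "//tensorflow/tools/pip_package:build_pip_package": [
        "//tensorflow/tools/pip_package:simple_console",
        "//tensorflow:tensorflow_py",
    ],
    "//tensorflow/tools/pip_package:simple_console": ["//tensorflow:tensorflow_py"],
    "//tensorflow:tensorflow_py": ["//tensorflow/python:python"],
    "//tensorflow/python:python": [
        "//tensorflow/python:no_contrib",
        "//tensorflow/contrib:contrib_py",
    ],
    "//tensorflow/python:no_contrib": ["//tensorflow/python:pywrap_tensorflow"],
    "//tensorflow/python:pywrap_tensorflow": ["//tensorflow/python:pywrap_tensorflow_internal"],
}
for label in some_path(graph,
                       "//tensorflow/tools/pip_package:build_pip_package",
                       "//tensorflow/python:pywrap_tensorflow_internal"):
    print(label)
```

Breadth-first search returns a shortest path in this toy graph, which keeps the printed chain as short as possible.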


tf_py_wrap_cc is not a built-in rule either; it comes from tensorflow.bzl:

```python
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
```


```python
# tensorflow/tensorflow.bzl, line 1404
def tf_py_wrap_cc(name,
                  srcs,
                  swig_includes=[],
                  deps=[],
                  copts=[],
                  **kwargs):
  module_name = name.split("/")[-1]
  # Convert a rule name such as foo/bar/baz to foo/bar/_baz.so
  # and use that as the name for the rule producing the .so file.
  cc_library_name = "/".join(name.split("/")[:-1] + ["_" + module_name + ".so"])
  cc_library_pyd_name = "/".join(
      name.split("/")[:-1] + ["_" + module_name + ".pyd"])
  extra_deps = []
  _py_wrap_cc(
      name=name + "_py_wrap",
      srcs=srcs,
      swig_includes=swig_includes,
      deps=deps + extra_deps,
      toolchain_deps=["//tools/defaults:crosstool"],
      module_name=module_name,
      py_module_name=name)
  vscriptname=name+"_versionscript"
  _append_init_to_versionscript(
      name=vscriptname,
      module_name=module_name,
      is_version_script=select({
          "@local_config_cuda//cuda:darwin":False,
          "//conditions:default":True,
          }),
      template_file=select({
          "@local_config_cuda//cuda:darwin":clean_dep("//tensorflow:tf_exported_symbols.lds"),
          "//conditions:default":clean_dep("//tensorflow:tf_version_script.lds")
      })
  )
  extra_linkopts = select({
      "@local_config_cuda//cuda:darwin": [
          "-Wl,-exported_symbols_list",
          "%s.lds"%vscriptname,
      ],
      clean_dep("//tensorflow:windows"): [],
      clean_dep("//tensorflow:windows_msvc"): [],
      "//conditions:default": [
          "-Wl,--version-script",
          "%s.lds"%vscriptname,
      ]
  })
  extra_deps += select({
      "@local_config_cuda//cuda:darwin": [
          "%s.lds"%vscriptname,
      ],
      clean_dep("//tensorflow:windows"): [],
      clean_dep("//tensorflow:windows_msvc"): [],
      "//conditions:default": [
          "%s.lds"%vscriptname,
      ]
  })

  tf_cc_shared_object(
      name=cc_library_name,
      srcs=[module_name + ".cc"],
      copts=(copts + if_not_windows([
          "-Wno-self-assign", "-Wno-sign-compare", "-Wno-write-strings"
      ]) + tf_extension_copts()),
      linkopts=tf_extension_linkopts() + extra_linkopts,
      linkstatic=1,
      deps=deps + extra_deps,
      **kwargs)
  native.genrule(
      name="gen_" + cc_library_pyd_name,
      srcs=[":" + cc_library_name],
      outs=[cc_library_pyd_name],
      cmd="cp $< $@",)
  native.py_library(
      name=name,
      srcs=[":" + name + ".py"],
      srcs_version="PY2AND3",
      data=select({
          clean_dep("//tensorflow:windows"): [":" + cc_library_pyd_name],
          "//conditions:default": [":" + cc_library_name],
      }))
```



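  • _py_wrap_cc runs SWIG on the .i file and produces the wrapper source <module_name>.cc and the Python module <name>.py;
  • tf_cc_shared_object compiles the wrapper into the shared library _<module_name>.so;
  • the genrule copies that .so to a .pyd file, which is what Windows uses;
  • native.py_library wraps the SWIG-generated .py file and lists the .so (or the .pyd on Windows) as data, so that importing the Python module can load the shared library at runtime.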
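The label arithmetic at the top of tf_py_wrap_cc is plain string manipulation, so it can be reproduced directly in Python to see which targets a call creates. `wrap_target_names` is a hypothetical helper written for this article; the expressions inside it are copied from the macro.

```python
from typing import Dict


def wrap_target_names(name: str) -> Dict[str, str]:
    """Names derived by tf_py_wrap_cc for a given `name` (same expressions as the macro)."""
    module_name = name.split("/")[-1]
    prefix = name.split("/")[:-1]
    return {
        "module_name": module_name,
        "swig_rule": name + "_py_wrap",          # the _py_wrap_cc target
        "generated_cc": module_name + ".cc",     # SWIG wrapper compiled into the .so
        "generated_py": name + ".py",            # SWIG Python module (py_library srcs)
        "so": "/".join(prefix + ["_" + module_name + ".so"]),
        "pyd": "/".join(prefix + ["_" + module_name + ".pyd"]),
        "version_script": name + "_versionscript",
    }


for key, value in wrap_target_names("pywrap_tensorflow_internal").items():
    print(key, "=", value)
print(wrap_target_names("foo/bar/baz")["so"])  # foo/bar/_baz.so
```

For pywrap_tensorflow_internal this yields _pywrap_tensorflow_internal.so, which is the shared library the generated pywrap_tensorflow_internal.py loads.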


```python
# tensorflow/tensorflow.bzl, line 1090; _py_wrap_cc itself is defined at line 1122.
# Bazel rules for building swig files.
def _py_wrap_cc_impl(ctx):
  srcs = ctx.files.srcs
  if len(srcs) != 1:
    fail("Exactly one SWIG source file label must be specified.", "srcs")
  module_name = ctx.attr.module_name
  src = ctx.files.srcs[0]
  inputs = depset([src])
  inputs += ctx.files.swig_includes
  for dep in ctx.attr.deps:
    inputs += dep.cc.transitive_headers
  inputs += ctx.files._swiglib
  inputs += ctx.files.toolchain_deps
  swig_include_dirs = depset(_get_repository_roots(ctx, inputs))
  swig_include_dirs += sorted([f.dirname for f in ctx.files._swiglib])
  args = [
      "-c++", "-python", "-module", module_name, "-o", ctx.outputs.cc_out.path,
      "-outdir", ctx.outputs.py_out.dirname
  ]
  args += ["-l" + f.path for f in ctx.files.swig_includes]
  args += ["-I" + i for i in swig_include_dirs]
  args += [src.path]
  outputs = [ctx.outputs.cc_out, ctx.outputs.py_out]
  ctx.action(
      executable=ctx.executable._swig,
      arguments=args,
      inputs=list(inputs),
      outputs=outputs,
      mnemonic="PythonSwig",
      progress_message="SWIGing " + src.path)
  return struct(files=depset(outputs))

_py_wrap_cc = rule(
    attrs = {
        "srcs": attr.label_list(
            mandatory = True,
            allow_files = True,
        ),
        "swig_includes": attr.label_list(
            cfg = "data",
            allow_files = True,
        ),
        "deps": attr.label_list(
            allow_files = True,
            providers = ["cc"],
        ),
        "toolchain_deps": attr.label_list(
            allow_files = True,
        ),
        "module_name": attr.string(mandatory = True),
        "py_module_name": attr.string(mandatory = True),
        "_swig": attr.label(
            default = Label("@swig//:swig"),
            executable = True,
            cfg = "host",
        ),
        "_swiglib": attr.label(
            default = Label("@swig//:templates"),
            allow_files = True,
        ),
    },
    outputs = {
        "cc_out": "%{module_name}.cc",
        "py_out": "%{py_module_name}.py",
    },
    implementation = _py_wrap_cc_impl,
)
```

In the code above, ctx.executable._swig is the executable the action runs. It comes from the _swig attribute:

```python
"_swig": attr.label(
    default = Label("@swig//:swig"),
    executable = True,
    cfg = "host",
),
```


The label @swig//:swig points to a cc_binary that builds SWIG itself from its sources (TensorFlow supplies this BUILD file for the @swig external repository as third_party/swig.BUILD):

```python
licenses(["restricted"])  # GPLv3

exports_files(["LICENSE"])

cc_binary(
    name = "swig",
    srcs = [
        "Source/CParse/cparse.h",
        "Source/CParse/cscanner.c",
        "Source/CParse/parser.c",
        "Source/CParse/parser.h",
        "Source/CParse/templ.c",
        "Source/CParse/util.c",
        "Source/DOH/base.c",
        "Source/DOH/doh.h",
        "Source/DOH/dohint.h",
        "Source/DOH/file.c",
        "Source/DOH/fio.c",
        "Source/DOH/hash.c",
        "Source/DOH/list.c",
        "Source/DOH/memory.c",
        "Source/DOH/string.c",
        "Source/DOH/void.c",
        "Source/Include/swigconfig.h",
        "Source/Include/swigwarn.h",
        "Source/Modules/allocate.cxx",
        "Source/Modules/browser.cxx",
        "Source/Modules/contract.cxx",
        "Source/Modules/directors.cxx",
        ......
```


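  • First, the swig executable itself is built from source;
  • SWIG then turns the .i files into a wrapper .cc file and a .py file, and the .cc file is compiled into a .so;
  • after that, the module can be imported from Python as usual.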
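Putting the pieces of _py_wrap_cc_impl together, the SWIG invocation it registers has roughly the shape below. This sketch only assembles the argument list in the same order as the rule; `swig_command` is a hypothetical helper, and the output paths and include directories in the usage example are placeholders, since the real ones are chosen by Bazel under its output tree.

```python
import shlex
from typing import List, Sequence


def swig_command(module_name: str, src: str, cc_out: str, py_out_dir: str,
                 swig_includes: Sequence[str], include_dirs: Sequence[str],
                 swig_bin: str = "swig") -> List[str]:
    """Assemble a SWIG command line in the same order as _py_wrap_cc_impl."""
    args = ["-c++", "-python", "-module", module_name,
            "-o", cc_out,             # generated C++ wrapper (cc_out)
            "-outdir", py_out_dir]    # directory for the generated .py (py_out)
    args += ["-l" + f for f in swig_includes]   # extra SWIG interface files
    args += ["-I" + d for d in include_dirs]    # include search paths
    args += [src]                               # the single .i source
    return [swig_bin] + args


cmd = swig_command(
    module_name="pywrap_tensorflow_internal",
    src="tensorflow/python/tensorflow.i",
    cc_out="out/tensorflow/python/pywrap_tensorflow_internal.cc",  # placeholder path
    py_out_dir="out/tensorflow/python",                            # placeholder path
    swig_includes=["tensorflow/python/util/port.i"],
    include_dirs=["."],
)
print(" ".join(shlex.quote(part) for part in cmd))
```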

1.2.4 Finding the C++ source behind a Python call


Take the following TF 1.x program as an example:

```python
import numpy as np
import tensorflow as tf

# Synthetic data on the line y = 0.1 * x + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3

W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y = W * x_data + b

loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

# tf.initialize_all_variables() is deprecated in TF 1.x;
# tf.global_variables_initializer() replaces it.
init = tf.global_variables_initializer()

sess = tf.Session()
sess.run(init)
for step in range(0, 201):
    sess.run(train)
    if step % 20 == 0:
        print(step, sess.run(W), sess.run(b))
```


1.2.5 Mapping between Python and C++ op names



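When the Python op wrappers are generated, the Python function name is derived from the op's registered name:

```cpp
string function_name;
python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
                                                &function_name);
```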


GenerateLowerCaseOpName is implemented as:

```cpp
void GenerateLowerCaseOpName(const string& str, string* result) {
  const char joiner = '_';
  const int last_index = str.size() - 1;
  for (int i = 0; i <= last_index; ++i) {
    const char c = str[i];
    // Emit a joiner only if a previous-lower-to-now-upper or a
    // now-upper-to-next-lower transition happens.
    if (isupper(c) && (i > 0)) {
      if (islower(str[i - 1]) || ((i < last_index) && islower(str[i + 1]))) {
        result->push_back(joiner);
      }
    }
    result->push_back(tolower(c));
  }
}
```
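The rule is easy to try out with a direct Python port. This port is written for this article (the name `generate_lower_case_op_name` is hypothetical) and assumes ASCII op names, like the C++ isupper/islower/tolower calls it mirrors; the example inputs are real TensorFlow op names.

```python
def generate_lower_case_op_name(name: str) -> str:
    """Python port of GenerateLowerCaseOpName: CamelCase op name -> snake_case."""
    joiner = "_"
    last_index = len(name) - 1
    out = []
    for i, c in enumerate(name):
        # Insert the joiner at a lower->upper transition, or before an upper-case
        # letter that is followed by a lower-case one (end of an acronym).
        if c.isupper() and i > 0:
            if name[i - 1].islower() or (i < last_index and name[i + 1].islower()):
                out.append(joiner)
        out.append(c.lower())
    return "".join(out)


for op in ["MatMul", "Conv2D", "BiasAddGrad", "TFRecordReader", "LRN"]:
    print(op, "->", generate_lower_case_op_name(op))
# MatMul -> mat_mul
# Conv2D -> conv2d
# BiasAddGrad -> bias_add_grad
# TFRecordReader -> tf_record_reader
# LRN -> lrn
```

Digits are neither upper nor lower case, which is why Conv2D becomes conv2d rather than conv2_d, and runs of capitals such as TF or LRN stay together.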



