[Source Code Analysis] TensorFlow Distributed Environment (7) --- Worker Dynamic Logic

In the previous article, the Master issued gRPC calls to remote workers along its workflow: every method of GrpcRemoteWorker initiates an asynchronous gRPC call via IssueRequest(). GrpcRemoteWorker sent two requests in total, RegisterGraphAsync and RunGraphAsync. In this article we look at how GrpcWorkerService handles them.

As before, this article draws heavily on the work of two experts.

Other articles in this series:

[Translation] TensorFlow Distributed, Papers: "TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems"

[Translation] TensorFlow Distributed, Papers: "Implementation of Control Flow in TensorFlow"

[Source Code Analysis] TensorFlow Distributed Environment (1) --- Overall Architecture

[Source Code Analysis] TensorFlow Distributed Environment (2) --- Master Static Logic

[Source Code Analysis] TensorFlow Distributed Environment (3) --- Worker Static Logic

[Source Code Analysis] TensorFlow Distributed Environment (4) --- WorkerCache

[Source Code Analysis] TensorFlow Distributed Environment (5) --- Session

1. Overview

1.1 Review

Let us first review how the concepts introduced so far relate to one another.

  • The Client builds the complete computation graph (FullGraph). The full graph cannot be executed in parallel as-is, so it needs to be partitioned and optimized.
  • The Master processes the full graph (pruning and other transformations) to produce the ClientGraph, the minimal dependency subgraph that actually needs to run. It then splits the ClientGraph into multiple PartitionGraphs according to the Worker information and registers each PartitionGraph with the corresponding Worker.
  • Upon receiving the registration request, each Worker further splits the received PartitionGraph into multiple PartitionGraphs according to its local set of devices, and starts one Executor per device to execute the PartitionGraph assigned to that device.

1.2 Preview

Next, an outline of the Worker-side flow. When a Worker node receives a RegisterGraphRequest, the message carries the session_handle assigned by the MasterSession and the subgraph graph_def (in GraphDef form). GraphDef is the result of serializing the computation graph built by the Client with Protocol Buffers; it contains all of the graph's metadata. It can be converted into a Graph by ConvertGraphDefToGraph; a Graph carries not only the graph metadata but also the additional information needed at run time.
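As a quick illustration, here is a minimal sketch of that conversion step (not code from this article; it assumes the standard TensorFlow C++ API, and header locations vary slightly across TensorFlow versions). The option values mirror the ones GraphMgr uses later in InitItem:

    // Header paths differ between TensorFlow versions.
    #include "tensorflow/core/common_runtime/graph_constructor.h"
    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/graph/graph.h"

    // Rebuild a runtime Graph from the serialized GraphDef sent by the Master.
    tensorflow::Status BuildGraphFromDef(const tensorflow::GraphDef& gdef,
                                         tensorflow::Graph* graph) {
      tensorflow::GraphConstructorOptions opts;
      opts.allow_internal_ops = true;  // worker graphs may contain _Send/_Recv ops
      opts.expect_device_spec = true;  // nodes arrive with devices already assigned
      return tensorflow::ConvertGraphDefToGraph(opts, gdef, graph);
    }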

The Worker splits the received graph into multiple PartitionGraphs according to its local device set, assigns one PartitionGraph to each device, and then starts one Executor per device, waiting for subsequent execution commands. The Executor class is TensorFlow's abstraction of a session executor; it provides the virtual method RunAsync, which runs a local graph asynchronously, together with Run, a synchronous wrapper around RunAsync.
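The relation between RunAsync and Run can be seen from a simplified sketch of the Executor interface (the real class in tensorflow/core/common_runtime/executor.h has more members; this is trimmed for illustration):

    class Executor {
     public:
      virtual ~Executor() {}

      struct Args {
        int64_t step_id = 0;
        // rendezvous, runner, cancellation_manager, step_container, ...
      };
      typedef std::function<void(const Status&)> DoneCallback;

      // Asynchronously run the local graph; "done" is invoked on completion.
      virtual void RunAsync(const Args& args, DoneCallback done) = 0;

      // Synchronous wrapper: block on a Notification until RunAsync finishes.
      Status Run(const Args& args) {
        Status ret;
        Notification n;
        RunAsync(args, [&ret, &n](const Status& s) {
          ret = s;
          n.Notify();
        });
        n.WaitForNotification();
        return ret;
      }
    };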

When the Worker node receives RunGraphAsync, the devices start executing. The WorkerSession calls session->graph_mgr()->ExecuteAsync, which in turn calls StartParallelExecutors; this creates an ExecutorBarrier. Whenever a device finishes executing its assigned PartitionGraph, it invokes the completion callback obtained from the ExecutorBarrier; once all devices have finished their PartitionGraphs, the barrier fires the overall done callback and the step completes.

Let us now walk through this flow step by step.

2. Registering the Subgraph

When the worker node receives a RegisterGraphRequest, the request first arrives at GrpcWorkerService; the method actually invoked is "/tensorflow.WorkerService/RegisterGraph". The corresponding code is shown below; expanding the macro yields RegisterGraphHandler:

    #define HANDLE_CALL(method, may_block_on_compute_pool)                        \
      void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
        auto closure = [this, call]() {                                           \
          Status s = worker_->method(&call->request, &call->response);            \
          if (!s.ok()) {                                                          \
            VLOG(3) << "Bad response from " << #method << ": " << s;              \
          }                                                                       \
          call->SendResponse(ToGrpcStatus(s));                                    \
        };                                                                         \
        if ((may_block_on_compute_pool)) {                                         \
          worker_->env()->env->SchedClosure(std::move(closure));                   \
        } else {                                                                    \
          worker_->env()->compute_pool->Schedule(std::move(closure));              \
        }                                                                           \
        ENQUEUE_REQUEST(method, false);                                             \
      }

    HANDLE_CALL(RegisterGraph, false);
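Hand-expanding HANDLE_CALL(RegisterGraph, false) makes the generated handler easier to read (a sketch of the expansion, slightly reformatted):

    void RegisterGraphHandler(
        WorkerCall<RegisterGraphRequest, RegisterGraphResponse>* call) {
      auto closure = [this, call]() {
        Status s = worker_->RegisterGraph(&call->request, &call->response);
        if (!s.ok()) {
          VLOG(3) << "Bad response from " << "RegisterGraph" << ": " << s;
        }
        call->SendResponse(ToGrpcStatus(s));
      };
      // may_block_on_compute_pool is false, so the closure goes to the compute pool.
      worker_->env()->compute_pool->Schedule(std::move(closure));
      // Re-enqueue so the service keeps accepting further RegisterGraph calls.
      ENQUEUE_REQUEST(RegisterGraph, false);
    }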

2.1 GrpcWorker

RegisterGraph actually calls a method on WorkerInterface, which internally forwards to RegisterGraphAsync.

    Status WorkerInterface::RegisterGraph(const RegisterGraphRequest* request,
                                          RegisterGraphResponse* response) {
      return CallAndWait(&ME::RegisterGraphAsync, request, response);
    }
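CallAndWait turns the asynchronous call into a synchronous one by blocking on a Notification. Roughly, following the pattern in tensorflow/core/distributed_runtime/worker_interface.h (simplified here):

    template <typename Method, typename Req, typename Resp>
    Status CallAndWait(Method func, const Req* req, Resp* resp) {
      Status ret;
      Notification n;
      // Invoke the async variant and capture its status in the done callback.
      (this->*func)(req, resp, [&ret, &n](const Status& s) {
        ret = s;
        n.Notify();
      });
      n.WaitForNotification();  // block until the async call completes
      return ret;
    }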

RegisterGraphAsync finally lands in the Worker implementation. It first looks up the WorkerSession by session_handle and then calls into GraphMgr.

    GraphMgr* WorkerSession::graph_mgr() const { return graph_mgr_.get(); }
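The lookup of the WorkerSession itself, by session_handle, happens in SessionMgr. A simplified sketch of that lookup is shown below; the real SessionMgr::WorkerSessionForSession in tensorflow/core/distributed_runtime/session_mgr.cc differs in details (richer error payloads, version-dependent handling), so treat this as an assumption about its shape rather than the exact code:

    Status SessionMgr::WorkerSessionForSession(
        const std::string& session_handle,
        std::shared_ptr<WorkerSession>* out_session) {
      mutex_lock l(mu_);
      if (session_handle.empty()) {
        // An empty handle refers to the legacy (implicitly created) session.
        *out_session = legacy_session_;
        return Status::OK();
      }
      auto it = sessions_.find(session_handle);
      if (it == sessions_.end()) {
        return errors::Aborted("Session handle is not found: ", session_handle);
      }
      *out_session = it->second;
      return Status::OK();
    }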

RegisterGraphAsync itself looks like this:

    void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
                                    RegisterGraphResponse* response,
                                    StatusCallback done) {
      std::shared_ptr<WorkerSession> session;
      Status s;
      if (request->create_worker_session_called()) {
        s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                       &session);
      } else {
        session = env_->session_mgr->LegacySession();
      }
      if (s.ok()) {
        s = session->graph_mgr()->Register(
            request->session_handle(), request->graph_def(), session.get(),
            request->graph_options(), request->debug_options(),
            request->config_proto(), request->collective_graph_key(),
            session->cluster_flr(), response->mutable_graph_handle());
      }
      done(s);
    }

2.2 GraphMgr

GraphMgr keeps track of a set of computation graphs registered with a TensorFlow worker. Each registered graph is identified by a handle, graph_handle, generated by GraphMgr and returned to the caller. After a successful registration, the caller uses the graph handle to execute the graph. Each execution is distinguished from the others by a caller-generated, globally unique id "step_id". As long as different "step_id"s are used, multiple executions can use the same graph concurrently and independently, and multiple threads can call GraphMgr methods concurrently.
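To make the calling convention concrete, here is a hypothetical usage sketch (all local variable names are illustrative, not from the article), using the Register and ExecuteAsync signatures shown later in this section:

    // 1) Register: the worker hands the GraphDef to GraphMgr and gets back a
    //    graph_handle identifying the registered graph.
    string graph_handle;
    TF_RETURN_IF_ERROR(graph_mgr->Register(
        session_handle, graph_def, worker_session, graph_options, debug_options,
        config_proto, collective_graph_key, cluster_flr, &graph_handle));

    // 2) Execute: each run of the registered graph uses a fresh, unique step_id.
    GraphMgr::NamedTensors inputs;  // feeds keyed by rendezvous name
    graph_mgr->ExecuteAsync(graph_handle, step_id, worker_session, exec_opts,
                            /*collector=*/nullptr, response,
                            cancellation_manager, inputs,
                            [](const Status& s) { /* run finished */ });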

2.2.1 Definition

GraphMgr is defined as follows:

    class GraphMgr {
     private:
      typedef GraphMgr ME;

      struct ExecutionUnit {
        std::unique_ptr<Graph> graph = nullptr;
        Device* device = nullptr;               // not owned.
        Executor* root = nullptr;               // not owned.
        FunctionLibraryRuntime* lib = nullptr;  // not owned.
        // Build the cost model if this value is strictly positive.
        int64_t build_cost_model = 0;
      };

      struct Item : public core::RefCounted {
        ~Item() override;

        // Session handle.
        string session;

        // Graph handle.
        string handle;

        std::unique_ptr<FunctionLibraryDefinition> lib_def;
        // Owns the FunctionLibraryRuntime objects needed to execute functions, one
        // per device.
        std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
        // A graph is partitioned over multiple devices. Each partition
        // has a root executor which may call into the runtime library.
        std::vector<ExecutionUnit> units;
        // Used to deregister a cost model when cost model is required in graph
        // manager.
        GraphMgr* graph_mgr;
        int64_t collective_graph_key;
      };

      const WorkerEnv* worker_env_;  // Not owned.
      const DeviceMgr* device_mgr_;

      CostModelManager cost_model_manager_;

      // Owned.
      mutex mu_;
      int64_t next_id_ TF_GUARDED_BY(mu_) = 0;

      // If true, blocks until device has finished all queued operations in a step.
      bool sync_on_finish_ = true;

      // Table mapping graph handles to registered graphs.
      //
      // TODO(zhifengc): If the client does not call Deregister, we'll
      // lose memory over time. We should implement a timeout-based
      // mechanism to gc these graphs.
      std::unordered_map<string, Item*> table_;

      TF_DISALLOW_COPY_AND_ASSIGN(GraphMgr);
    };

The relationships and responsibilities of these classes are as follows: registering a graph means inserting a new Item into GraphMgr's table_ member, and running a graph means executing a specific Item.

2.2.2 Registering a Graph

The registration code is shown below; it essentially hands the work over to InitItem, so let us look at InitItem next.

    Status GraphMgr::Register(
        const string& handle, const GraphDef& gdef, WorkerSession* session,
        const GraphOptions& graph_options, const DebugOptions& debug_options,
        const ConfigProto& config_proto, int64_t collective_graph_key,
        DistributedFunctionLibraryRuntime* cluster_flr, string* graph_handle) {
      Item* item = new Item;
      Status s = InitItem(handle, gdef, session, graph_options, debug_options,
                          config_proto, collective_graph_key, cluster_flr, item);
      if (!s.ok()) {
        item->Unref();
        return s;
      }

      // Inserts one item into table_.
      {
        mutex_lock l(mu_);
        *graph_handle =
            strings::Printf("%016llx", static_cast<long long>(++next_id_));
        item->handle = *graph_handle;
        CHECK(table_.insert({*graph_handle, item}).second);
      }
      return Status::OK();
    }

The main responsibilities of InitItem are:

  • Create executors given a graph definition "gdef" of a "session".

  • If a node in "gdef" is shared by other graphs in the "session", the same op kernel is reused. For example, a params node is typically shared by multiple graphs in a session.

  • If "gdef" is assigned to multiple devices, extra nodes (e.g., send/recv nodes) may be added. The names of the extra nodes are generated by calling "new_name(old_name)".

  • On success, "executors" is filled with one executor per device, and the caller takes ownership of the returned executors.

    // Creates executors given a graph definition "gdef" of a "session".
    // If a node in "gdef" is shared by other graphs in "session", the
    // same op kernel is reused. E.g., typically a params node is shared
    // by multiple graphs in a session.
    //
    // If "gdef" is assigned to multiple devices, extra nodes (e.g.,
    // send/recv nodes) maybe added. The extra nodes' name are generated
    // by calling "new_name(old_name)".
    //
    // "executors" are filled with one executor per device if success and
    // the caller takes the ownership of returned executors.
    Status GraphMgr::InitItem(
        const string& handle, const GraphDef& gdef, WorkerSession* session,
        const GraphOptions& graph_options, const DebugOptions& debug_options,
        const ConfigProto& config_proto, int64_t collective_graph_key,
        DistributedFunctionLibraryRuntime* cluster_flr, Item* item) {
      item->session = handle;
      item->collective_graph_key = collective_graph_key;
      item->lib_def.reset(
          new FunctionLibraryDefinition(OpRegistry::Global(), gdef.library()));

      TF_RETURN_IF_ERROR(ValidateGraphDefForDevices(gdef));

      // We don't explicitly Validate the graph def because ConvertGraphDefToGraph
      // does that below.
      item->proc_flr.reset(new ProcessFunctionLibraryRuntime(
          device_mgr_, worker_env_->env, /*config=*/&config_proto,
          gdef.versions().producer(), item->lib_def.get(),
          graph_options.optimizer_options(), worker_env_->compute_pool, cluster_flr,
          /*session_metadata=*/nullptr,
          Rendezvous::Factory{
              [this, session](const int64_t step_id, const DeviceMgr*,
                              Rendezvous** r) -> Status {
                auto* remote_r = this->worker_env_->rendezvous_mgr->Find(step_id);
                TF_RETURN_IF_ERROR(remote_r->Initialize(session));
                *r = remote_r;
                return Status::OK();
              },
              [this](const int64_t step_id) {
                this->worker_env_->rendezvous_mgr->Cleanup(step_id);
                return Status::OK();
              }}));

      // Constructs the graph out of "gdef".
      Graph graph(OpRegistry::Global());
      GraphConstructorOptions opts;
      opts.allow_internal_ops = true;
      opts.expect_device_spec = true;
      opts.validate_nodes = true;
      TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, gdef, &graph));

      // Splits "graph" into multiple subgraphs by device names.
      std::unordered_map<string, GraphDef> partitions;
      PartitionOptions popts;
      popts.node_to_loc = SplitByDevice;  // SplitByDevice is used here (see below).
      popts.new_name = [this](const string& prefix) {
        mutex_lock l(mu_);
        return strings::StrCat(prefix, "_G", next_id_++);
      };
      popts.get_incarnation = [this](const string& name) -> int64 {
        Device* device = nullptr;
        Status s = device_mgr_->LookupDevice(name, &device);
        if (s.ok()) {
          return device->attributes().incarnation();
        } else {
          return PartitionOptions::kIllegalIncarnation;
        }
      };
      popts.flib_def = item->lib_def.get();
      popts.control_flow_added = true;
      popts.scheduling_for_recvs = graph_options.enable_recv_scheduling();
      TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions));
      if (popts.scheduling_for_recvs) {
        TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions));
      }

      // Convert each partition back into a Graph.
      std::unordered_map<string, std::unique_ptr<Graph>> partition_graphs;
      for (auto& partition : partitions) {
        std::unique_ptr<Graph> device_graph(new Graph(OpRegistry::Global()));
        GraphConstructorOptions device_opts;
        // There are internal operations (e.g., send/recv) that we now allow.
        device_opts.allow_internal_ops = true;
        device_opts.expect_device_spec = true;
        TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
            device_opts, std::move(partition.second), device_graph.get()));
        partition_graphs.emplace(partition.first, std::move(device_graph));
      }

      GraphOptimizationPassOptions optimization_options;
      optimization_options.flib_def = item->lib_def.get();
      optimization_options.partition_graphs = &partition_graphs;
      TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
          OptimizationPassRegistry::POST_PARTITIONING, optimization_options));

      LocalExecutorParams params;

      item->units.reserve(partitions.size());
      item->graph_mgr = this;
      const auto& optimizer_opts = graph_options.optimizer_options();
      GraphOptimizer optimizer(optimizer_opts);
      for (auto& p : partition_graphs) {
        const string& device_name = p.first;
        std::unique_ptr<Graph>& subgraph = p.second;
        item->units.resize(item->units.size() + 1);
        ExecutionUnit* unit = &(item->units.back());

        // Find the device.
        Status s = device_mgr_->LookupDevice(device_name, &unit->device);
        if (!s.ok()) {
          // Remove the empty unit from the item as the item destructor wants all
          // units to have valid devices.
          item->units.pop_back();
          return s;
        }

        // Give the device an opportunity to rewrite its subgraph.
        TF_RETURN_IF_ERROR(unit->device->MaybeRewriteGraph(&subgraph));

        // Top-level nodes in the graph uses the op segment to cache
        // kernels. Therefore, as long as the executor is alive, we need
        // to ensure the kernels cached for the session are alive.
        auto opseg = unit->device->op_segment();
        opseg->AddHold(handle);

        // Function library runtime.
        FunctionLibraryRuntime* lib = item->proc_flr->GetFLR(unit->device->name());

        // Construct the root executor for the subgraph.
        params.device = unit->device;
        params.function_library = lib;
        params.create_kernel =
            [handle, lib, opseg](const std::shared_ptr<const NodeProperties>& props,
                                 OpKernel** kernel) {
              // NOTE(mrry): We must not share function kernels (implemented
              // using `CallOp`) between subgraphs, because `CallOp::handle_`
              // is tied to a particular subgraph. Even if the function itself
              // is stateful, the `CallOp` that invokes it is not.
              if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) {
                return lib->CreateKernel(props, kernel);
              }
              auto create_fn = [lib, &props](OpKernel** kernel) {
                return lib->CreateKernel(props, kernel);
              };
              // Kernels created for subgraph nodes need to be cached. On
              // cache miss, create_fn() is invoked to create a kernel based
              // on the function library here + global op registry.
              return opseg->FindOrCreate(handle, props->node_def.name(), kernel,
                                         create_fn);
            };
        params.delete_kernel = [lib](OpKernel* kernel) {
          if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) {
            delete kernel;
          }
        };

        // Optimize the subgraph.
        optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph,
                           GraphOptimizer::Options());

        TF_RETURN_IF_ERROR(
            EnsureMemoryTypes(DeviceType(unit->device->device_type()),
                              unit->device->name(), subgraph.get()));
        unit->graph = std::move(subgraph);
        unit->build_cost_model = graph_options.build_cost_model();
        if (unit->build_cost_model > 0) {
          skip_cost_models_ = false;
        }
        TF_RETURN_IF_ERROR(NewLocalExecutor(params, *unit->graph, &unit->root));
      }
      return Status::OK();
    }

One thing to note above is that SplitByDevice is used to partition the graph a second time, this time by device.

    // NOTE: node->device_name() is not set by GraphConstructor. We
    // expects that NodeDef in GraphDef given to workers fully specifies
    // device names.
    static string SplitByDevice(const Node* node) {
      return node->assigned_device_name();
    }

    inline const std::string& Node::assigned_device_name() const {
      return graph_->get_assigned_device_name(*this);
    }

The result of graph registration is roughly as follows: the information sent by the Master is used to build an Item, which is registered in GraphMgr's table_; one ExecutionUnit is created per partition inside the Item, and a graph_handle is generated (from GraphMgr's internal counter) and returned to the caller.

Once the subgraph is registered, it can subsequently be run.

3. Running the Subgraph

The Master uses RunGraphRequest to execute all subgraphs registered under a graph_handle. The Master generates a globally unique step_id to distinguish the different steps in which the graph is run. Subgraphs can use the step_id to communicate with each other (e.g., send/recv operations) so that tensors produced by different runs are kept apart.

The send field of a RunGraphRequest carries the input tensors of the subgraph, and recv_key names the output tensors of the subgraph. RunGraphResponse returns the list of Tensors corresponding to the recv_keys.
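In proto terms, the request roughly looks like this: a sketch based on the RunGraphRequest message in tensorflow/core/protobuf/worker.proto (the tensor names "a:0" and "c:0" are made up for illustration):

    RunGraphRequest req;
    req.set_session_handle(session_handle);  // which WorkerSession to use
    req.set_graph_handle(graph_handle);      // which registered subgraph to run
    req.set_step_id(step_id);                // unique id for this run step

    // "send": input tensors fed into the subgraph.
    NamedTensorProto* feed = req.add_send();
    feed->set_name("a:0");
    input_tensor.AsProtoField(feed->mutable_tensor());

    // "recv_key": names of the tensors to fetch; the worker returns them
    // as the "recv" entries of RunGraphResponse.
    req.add_recv_key("c:0");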

3.1 Service

The request again first arrives at GrpcWorkerService; the method invoked is "/tensorflow.WorkerService/RunGraph", and the corresponding code is:

    void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
      // Schedule() puts the work item onto the compute thread-pool queue.
      Schedule([this, call]() {
        CallOptions* call_opts = new CallOptions;
        ProtoRunGraphRequest* wrapped_request =
            new ProtoRunGraphRequest(&call->request);
        NonOwnedProtoRunGraphResponse* wrapped_response =
            new NonOwnedProtoRunGraphResponse(&call->response);
        call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
        worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
                               [call, call_opts, wrapped_request,
                                wrapped_response](const Status& s) {
                                 call->ClearCancelCallback();
                                 delete call_opts;
                                 delete wrapped_request;
                                 delete wrapped_response;
                                 call->SendResponse(ToGrpcStatus(s));
                               });
      });
      ENQUEUE_REQUEST(RunGraph, true);
    }

Here the computation task is put onto the thread-pool queue; the actual business logic lives in Worker::RunGraphAsync.

    void Schedule(std::function<void()> f) {
      worker_->env()->compute_pool->Schedule(std::move(f));
    }

3.2 GrpcWorker

RunGraphAsync supports two execution modes; we pick DoRunGraph for analysis.

    void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
                               MutableRunGraphResponseWrapper* response,
                               StatusCallback done) {
      if (request->store_errors_in_response_body()) {
        done = [response, done](const Status& status) {
          response->set_status(status);
          done(Status::OK());
        };
      }
      if (request->is_partial()) {
        DoPartialRunGraph(opts, request, response, std::move(done));  // interested readers can explore this path
      } else {
        DoRunGraph(opts, request, response, std::move(done));  // analyzed below
      }
    }

DoRunGraph mainly calls session->graph_mgr()->ExecuteAsync to execute the computation graph.

    void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
                            MutableRunGraphResponseWrapper* response,
                            StatusCallback done) {
      const int64_t step_id = request->step_id();
      Status s = recent_request_ids_.TrackUnique(request->request_id(),
                                                 "RunGraph (Worker)", request);
      if (!s.ok()) {
        done(s);
        return;
      }

      std::shared_ptr<WorkerSession> session;
      if (request->create_worker_session_called()) {
        s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                       &session);
      } else {
        session = env_->session_mgr->LegacySession();
      }
      if (!s.ok()) {
        done(s);
        return;
      }

      GraphMgr::NamedTensors in;
      GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
      s = PrepareRunGraph(request, &in, out);
      if (!s.ok()) {
        delete out;
        done(s);
        return;
      }

      StepStatsCollector* collector = nullptr;
      if (request->exec_opts().report_tensor_allocations_upon_oom() ||
          request->exec_opts().record_timeline() ||
          request->exec_opts().record_costs()) {
        collector = new StepStatsCollector(response->mutable_step_stats());
      }

      DeviceProfilerSession* device_profiler_session = nullptr;
      if (collector && request->exec_opts().record_timeline()) {
        // If timeline was requested, assume we want hardware level tracing.
        device_profiler_session = DeviceProfilerSession::Create().release();
      }

      CancellationManager* cm = new CancellationManager;
      opts->SetCancelCallback([this, cm, step_id]() {
        cm->StartCancel();
        AbortStep(step_id);
      });
      CancellationToken token;
      token = cancellation_manager_.get_cancellation_token();
      bool already_cancelled = !cancellation_manager_.RegisterCallback(
          token, [cm]() { cm->StartCancel(); });
      if (already_cancelled) {
        opts->ClearCancelCallback();
        delete cm;
        delete collector;
        delete device_profiler_session;
        delete out;
        done(errors::Aborted("Call was aborted"));
        return;
      }

      session->graph_mgr()->ExecuteAsync(
          request->graph_handle(), step_id, session.get(), request->exec_opts(),
          collector, response, cm, in,
          [this, step_id, response, session, cm, out, token, collector,
           device_profiler_session, opts, done](const Status& status) {
            Status s = status;
            if (s.ok()) {
              // Receive the output tensors for this step.
              s = session->graph_mgr()->RecvOutputs(step_id, out);
            }
            opts->ClearCancelCallback();
            cancellation_manager_.DeregisterCallback(token);
            delete cm;

            if (device_profiler_session) {
              device_profiler_session->CollectData(response->mutable_step_stats())
                  .IgnoreError();
            }

            if (s.ok()) {
              for (const auto& p : *out) {
                const string& key = p.first;
                const Tensor& val = p.second;
                response->AddRecv(key, val);
              }
            }

            if (collector) collector->Finalize();
            delete collector;
            delete device_profiler_session;
            delete out;
            done(s);
          });
    }

3.3 GraphMgr

ExecuteAsync calls StartParallelExecutors to perform the parallel computation. Its logic is roughly:

  • look up the registered subgraph (Item);
  • collect cost information for the subgraph;
  • obtain a rendezvous and initialize it with this session; all subsequent communication for this step goes through this rendezvous, which uses the session to communicate;
  • send the input tensors to the rendezvous;
  • call StartParallelExecutors to execute the sub-computation graphs.

    void GraphMgr::ExecuteAsync(const string& handle, const int64_t step_id,
                                WorkerSession* session, const ExecutorOpts& opts,
                                StepStatsCollector* collector,
                                MutableRunGraphResponseWrapper* response,
                                CancellationManager* cancellation_manager,
                                const NamedTensors& in, StatusCallback done) {
      const uint64 start_time_usecs = Env::Default()->NowMicros();
      profiler::TraceMeProducer activity(
          // To TraceMeConsumers in ExecutorState::Process/Finish or RunGraphDone.
          [step_id] {
            return profiler::TraceMeEncode(
                "RunGraph", {{"id", step_id}, {"_r", 1} /*root_event*/});
          },
          profiler::ContextType::kTfExecutor, step_id,
          profiler::TraceMeLevel::kInfo);

      // Lookup an item. Holds one ref while executing.
      Item* item = nullptr;
      {
        mutex_lock l(mu_);
        auto iter = table_.find(handle);
        if (iter != table_.end()) {
          item = iter->second;
          item->Ref();
        }
      }

      // Prepare the cost graph output and optionally record the partition graphs.
      CostGraphDef* cost_graph = nullptr;
      if (response != nullptr) {
        cost_graph = response->mutable_cost_graph();
        if (opts.record_partition_graphs()) {
          for (const ExecutionUnit& unit : item->units) {
            GraphDef graph_def;
            unit.graph->ToGraphDef(&graph_def);
            response->AddPartitionGraph(graph_def);
          }
        }
      }

      // Find the rendezvous for this step and initialize it with this session;
      // all subsequent communication for this step goes through this rendezvous.
      RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
      Status s = rendezvous->Initialize(session);
      CollectiveExecutor::Handle* ce_handle =
          item->collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey
              ? new CollectiveExecutor::Handle(
                    worker_env_->collective_executor_mgr->FindOrCreate(step_id),
                    true)
              : nullptr;

      // Sends values specified by the caller.
      // Send the input tensors into the rendezvous.
      size_t input_size = 0;
      if (s.ok()) {
        std::vector<string> keys;
        std::vector<Tensor> tensors_to_send;
        keys.reserve(in.size());
        tensors_to_send.reserve(in.size());
        for (auto& p : in) {
          keys.push_back(p.first);
          tensors_to_send.push_back(p.second);
          input_size += p.second.AllocatedBytes();
        }
        s = SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
      }

      if (!s.ok()) {
        done(s);
        delete ce_handle;
        item->Unref();
        rendezvous->Unref();
        return;
      }

      // Run the per-device subgraphs in parallel.
      StartParallelExecutors(
          handle, step_id, item, rendezvous, ce_handle, collector, cost_graph,
          cancellation_manager, session, start_time_usecs,
          [item, rendezvous, ce_handle, done, start_time_usecs, input_size,
           step_id](const Status& s) {
            profiler::TraceMeConsumer activity(
                // From TraceMeProducer in GraphMgr::ExecuteAsync.
                [step_id] {
                  return profiler::TraceMeEncode("RunGraphDone", {{"id", step_id}});
                },
                profiler::ContextType::kTfExecutor, step_id,
                profiler::TraceMeLevel::kInfo);
            done(s);
            metrics::RecordGraphInputTensors(input_size);
            metrics::UpdateGraphExecTime(Env::Default()->NowMicros() -
                                         start_time_usecs);
            rendezvous->Unref();
            item->Unref();
            delete ce_handle;
          });
    }

The overall picture is roughly as follows: ExecuteAsync uses the handle to look up the Item and thereby the computation graph. The session is used for communication and execution, and the step_id is tied to the communication, as can be seen in the code above.

StartParallelExecutors creates an ExecutorBarrier. Whenever a device finishes executing its assigned PartitionGraph, it invokes the completion callback obtained from barrier->Get(); once all devices have finished their PartitionGraphs, the ExecutorBarrier fires the overall done callback and the step can be finalized.

    void GraphMgr::StartParallelExecutors(
        const string& handle, int64_t step_id, Item* item, Rendezvous* rendezvous,
        CollectiveExecutor::Handle* ce_handle, StepStatsCollector* collector,
        CostGraphDef* cost_graph, CancellationManager* cancellation_manager,
        WorkerSession* session, int64_t start_time_usecs, StatusCallback done) {
      const int num_units = item->units.size();
      ScopedStepContainer* step_container = new ScopedStepContainer(
          step_id,
          [this](const string& name) { device_mgr_->ClearContainers({name}); });

      ExecutorBarrier* barrier =
          new ExecutorBarrier(num_units, rendezvous,
                              [this, item, collector, cost_graph, step_container,
                               done](const Status& s) {
                                BuildCostModel(item, collector, cost_graph);
                                done(s);
                                delete step_container;
                              });
      Executor::Args args;
      args.step_id = step_id;
      args.rendezvous = rendezvous;
      args.collective_executor = ce_handle ? ce_handle->get() : nullptr;
      args.cancellation_manager = cancellation_manager;
      args.stats_collector = collector;
      args.step_container = step_container;
      args.sync_on_finish = sync_on_finish_;
      args.start_time_usecs = start_time_usecs;
      if (LogMemory::IsEnabled()) {
        LogMemory::RecordStep(args.step_id, handle);
      }
      thread::ThreadPool* pool = worker_env_->compute_pool;
      using std::placeholders::_1;
      // Line below is equivalent to this code, but does one less indirect call:
      //   args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
      auto default_runner = std::bind(&thread::ThreadPool::Schedule, pool, _1);
      for (const auto& unit : item->units) {
        thread::ThreadPool* device_thread_pool =
            unit.device->tensorflow_device_thread_pool();
        if (!device_thread_pool) {
          args.runner = default_runner;
        } else {
          args.runner =
              std::bind(&thread::ThreadPool::Schedule, device_thread_pool, _1);
        }
        unit.root->RunAsync(args, barrier->Get());
      }
    }
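The fan-in logic of ExecutorBarrier is small enough to sketch here. This is a simplified version; the real class in tensorflow/core/common_runtime/executor.h also aborts the rendezvous when an executor reports an error:

    class ExecutorBarrier {
     public:
      typedef std::function<void(const Status&)> StatusCallback;
      ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done)
          : rendez_(r), done_cb_(std::move(done)), pending_(num) {}

      // Each per-device executor gets this callback as its RunAsync "done".
      StatusCallback Get() {
        return std::bind(&ExecutorBarrier::WhenDone, this, std::placeholders::_1);
      }

     private:
      Rendezvous* rendez_;
      StatusCallback done_cb_;
      mutex mu_;
      int pending_ TF_GUARDED_BY(mu_);
      Status status_ TF_GUARDED_BY(mu_);

      void WhenDone(const Status& s) {
        StatusCallback done = nullptr;
        Status status;
        {
          mutex_lock l(mu_);
          if (!s.ok()) status_.Update(s);  // remember the first error
          if (--pending_ == 0) {           // the last executor has finished
            done = done_cb_;
            status = status_;
          }
        }
        if (done != nullptr) done(status);  // fire the overall completion callback
      }
    };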

3.4 Summary

We summarize subgraph registration and execution with a figure.

Figure 1: Registering/running subgraphs

4. Conclusion

We summarize the whole distributed computation flow in the following figure:

Figure 2: Distributed computation flow

0xFF References
