[源码解析] PyTorch 分布式 Autograd (3) ---- 上下文相关

0x00 摘要

我们已经知道 dist.autograd 如何发送和接受消息,本文再来看看如何其他支撑部分,就是如何把发送接受两个动作协调起来,如何确定每个发送/接受节点,如何确定每一个消息交互Session。

通过本文大家可以了解:AutogradMetadata 用来在不同节点间传递 autograd 元信息,DistAutogradContext 代表一个分布式autograd 相关信息,DistAutogradContainer 负责在一个worker之上存储 DistAutogradContext。

PyTorch分布式其他文章如下:

深度学习利器之自动微分(1)

深度学习利器之自动微分(2)

[源码解析]深度学习利器之自动微分(3) --- 示例解读

[源码解析]PyTorch如何实现前向传播(1) --- 基础类(上)

[源码解析]PyTorch如何实现前向传播(2) --- 基础类(下)

[源码解析] PyTorch如何实现前向传播(3) --- 具体实现

[源码解析] Pytorch 如何实现后向传播 (1)---- 调用引擎

[源码解析] Pytorch 如何实现后向传播 (2)---- 引擎静态结构

[源码解析] Pytorch 如何实现后向传播 (3)---- 引擎动态逻辑

[源码解析] PyTorch 如何实现后向传播 (4)---- 具体算法

[源码解析] PyTorch 分布式(1)------历史和概述

[源码解析] PyTorch 分布式(2) ----- DataParallel(上)

[源码解析] PyTorch 分布式(3) ----- DataParallel(下)

[源码解析] PyTorch 分布式(4)------分布式应用基础概念

[源码解析] PyTorch分布式(5) ------ DistributedDataParallel 总述&如何使用

[源码解析] PyTorch分布式(6) ---DistributedDataParallel -- 初始化&store

[源码解析] PyTorch 分布式(7) ----- DistributedDataParallel 之进程组

[源码解析] PyTorch 分布式(8) -------- DistributedDataParallel之论文篇

[源码解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化

[源码解析] PyTorch 分布式(10)------DistributedDataParallel 之 Reducer静态架构

[源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer和Join操作

[源码解析] PyTorch 分布式(12) ----- DistributedDataParallel 之 前向传播

[源码解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向传播

[源码解析] PyTorch 分布式 Autograd (1) ---- 设计

[源码解析] PyTorch 分布式 Autograd (2) ---- RPC基础

为了更好的说明,本文代码会依据具体情况来进行相应精简。

0x01 设计脉络

1.1 前文回顾

在前文之中当发送消息时候,我们在 sendMessageWithAutograd 通过 getMessageWithAutograd 来获得了 FORWARD_AUTOGRAD_REQ 类型的消息。

  1. c10::intrusive_ptr<JitFuture> sendMessageWithAutograd(
  2. RpcAgent& agent,
  3. const WorkerInfo& dst,
  4. torch::distributed::rpc::Message&& wrappedRpcMsg,
  5. bool forceGradRecording,
  6. const float rpcTimeoutSeconds,
  7. bool forceDisableProfiling) {
  8. auto msg = getMessageWithAutograd( // 这里会与上下文交互,构建了 FORWARD_AUTOGRAD_REQ
  9. dst.id_,
  10. std::move(wrappedRpcMsg),
  11. MessageType::FORWARD_AUTOGRAD_REQ,
  12. forceGradRecording,
  13. agent.getDeviceMap(dst));
  14. c10::intrusive_ptr<JitFuture> fut;
  15. if (!forceDisableProfiling && torch::autograd::profiler::profilerEnabled()) {
  16. auto profilerConfig = torch::autograd::profiler::getProfilerConfig();
  17. auto msgWithProfiling = getMessageWithProfiling(
  18. std::move(msg),
  19. rpc::MessageType::RUN_WITH_PROFILING_REQ, //构建消息
  20. std::move(profilerConfig));
  21. // 发送消息
  22. fut = agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds);
  23. } else {
  24. // 发送消息
  25. fut = agent.send(dst, std::move(msg), rpcTimeoutSeconds);
  26. }
  27. return fut;
  28. }

而 getMessageWithAutograd 会与上下文交互,其代码位于 torch/csrc/distributed/autograd/utils.cpp。

  1. Message getMessageWithAutograd(
  2. const rpc::worker_id_t dstId,
  3. torch::distributed::rpc::Message&& wrappedRpcMsg,
  4. MessageType msgType,
  5. bool forceGradRecording,
  6. const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
  7. // 获取到 DistAutogradContainer
  8. auto& autogradContainer = DistAutogradContainer::getInstance();
  9. // If there is no valid context and no tensor requires grads, send original
  10. // rpc message. otherwise, attach grad info and grad functions and send
  11. // rpcWithAutograd message.
  12. auto tensorsRequireGrad =
  13. torch::autograd::compute_requires_grad(wrappedRpcMsg.tensors());
  14. if (!autogradContainer.hasValidContext() ||
  15. (!forceGradRecording && !tensorsRequireGrad)) {
  16. return std::move(wrappedRpcMsg);
  17. }
  18. // Retrieve the appropriate context to modify.
  19. auto autogradContext = autogradContainer.currentContext(); // 获取到上下文,每个worker都有自己的上下文
  20. // Wrap the original rpc with autograd information.
  21. // newAutogradMessageId 会生成一个messageID
  22. AutogradMetadata autogradMetadata( // 构建了 AutogradMetadata
  23. autogradContext->contextId(), autogradContainer.newAutogradMessageId());
  24. auto rpcWithAutograd = std::make_unique<RpcWithAutograd>(
  25. RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_,
  26. msgType,
  27. autogradMetadata,
  28. std::move(wrappedRpcMsg),
  29. deviceMap);
  30. if (tensorsRequireGrad) {
  31. // Record autograd information for 'send'.
  32. addSendRpcBackward( // 这里把本地上下文,autograd 的元信息等一起打包
  33. autogradContext, autogradMetadata, rpcWithAutograd->tensors());
  34. }
  35. // Record the workerID
  36. autogradContext->addKnownWorkerId(dstId);
  37. return std::move(*rpcWithAutograd).toMessage(); // 最终构建了一个message
  38. }

因此,就引出了AutogradMetadata,DistAutogradContainer 和 DistAutogradContext 等一系列基础类,我们接下来就仔细分析一下。

1.2 总体思路

我们概括一下总体思路。

先看看问题:假如一套系统包括 a,b,c 三个节点,每个节点运行一个 worker,那么当运行一个传播操作,我们涉及到在这三个节点之间互相传播。因此我们需要一个机制,来在这三个节点之中唯一标示这个传播过程,在这个传播过程之中,也要在每一个节点之上把每一个send/recv都标示出来,这样才能让节点可以支持多个操作并行

再看看解决方案:

  • 使用上下文来唯一标示一个传播过程。DistAutogradContext 存储在一个worker之上的每一个分布式autograd的相关信息,其在分布式 autograd 之中封装前向和后向传播,累积梯度,这避免了多个worker在彼此的梯度上互相影响。每个自动微分过程被赋予一个唯一的 autograd_context_id,在容器中,这个微分过程的上下文(DistAutogradContext) 依据这个autograd_context_id 来唯一确认。
  • 使用autogradMessageId 来表示一对 send/recv autograd 函数。每send-recv对被分配一个全局唯一的autograd_message_id 以唯一地标识该send-recv对。这对于在向后传播期间查找远程节点上的相应函数很有用。
  • 最后,每个worker需要有一个地方来保持上下文和messageid,所以有了DistAutogradContainer这个类。每个worker拥有唯一一个单例DistAutogradContainer,其负责:
    • 对于每一个自动微分过程存储其分布式上下文。
    • 一旦这个自动微分过程结束,就清除其数据。

0x02 AutogradMetadata

2.1 定义

AutogradMetadata 这个类是用来在不同节点之间传递 autograd 的元信息,就是把上下文等信息封装了一下。即,发送方通知接收方自己的上下文信息,接收方会依据收到的这些上下文信息作相应处理。

我们提前剧透,接收方会使用 autogradContextId 和 autogradMessageId 分别作为 上下文 和 消息 的唯一标示。从注释之中可以知道。

  • autogradContextId 是全局唯一整数,用来表示一个唯一的分布式 autograd 传播过程(包括前向传播和后向传播)。一个传播过程会包括在反向传播链条上的多对send/recv autograd 函数。
  • autogradMessageId 是全局唯一整数,用来表示一对 send/recv autograd 函数。每send-recv对被分配一个全局唯一的autograd_message_id 以唯一地标识该send-recv对。这对于在向后传播期间查找远程节点上的相应函数很有用。
  1. // This structure represents autograd metadata that we need to pass across
  2. // different nodes when we call an RPC which needs autograd computation.
  3. struct TORCH_API AutogradMetadata {
  4. AutogradMetadata(int64_t autogradContextId, int64_t autogradMessageId);
  5. // autogradContextId_ is a globally unique integer that identifies a
  6. // particular distributed autograd pass.
  7. int64_t autogradContextId;
  8. // autogradMessageId_ is a globally unique integer that identifies a pair
  9. // of send/recv autograd functions.
  10. int64_t autogradMessageId;
  11. };

那么问题来了,autogradContextId 和 autogradMessageId 分别怎么做到全局(包括多个节点)唯一呢?

2.2 autogradMessageId

我们先概括一下:autogradMessageId 是由 rank 间接生成的,然后在内部进行递增,所以可以保证全局唯一。

我们从后往前推导。

  • 先看 newAutogradMessageId 是如何生成消息 id,原来是在 DistAutogradContainer 之中的成员变量 next_autograd_message_id_ 递增得到。
  1. int64_t DistAutogradContainer::newAutogradMessageId() {
  2. // Check for overflow into workerId_ section.
  3. TORCH_INTERNAL_ASSERT(next_autograd_message_id_ < max_id_);
  4. return next_autograd_message_id_++;
  5. }
  • 然后看如何初始化 next_autograd_message_id_?从 DistAutogradContainer 的 init 函数中可以知道,原来是依据 worker_id 来生成 next_autograd_message_id_。work_id 是 init 函数所得到的参数。
  1. DistAutogradContainer& DistAutogradContainer::init(int64_t worker_id) {
  2. std::lock_guard<std::mutex> guard(dist_container_init_lock_);
  3. auto& container = getInstanceInternal();
  4. container.worker_id_ = worker_id;
  5. container.next_context_id_ = static_cast<int64_t>(worker_id)
  6. << kAutoIncrementBits;
  7. container.next_autograd_message_id_ = static_cast<int64_t>(worker_id)
  8. << kAutoIncrementBits;
  9. container.max_id_ =
  10. (kAutoIncrementMask |
  11. (static_cast<int64_t>(worker_id) << kAutoIncrementBits));
  12. container.initialized_ = true;
  13. return container;
  14. }
  • 我们再推导,看看如何设置 worker id,找到了如下,看来需要看看 python 世界的 _init 方法。
  1. module.def(
  2. "_init",
  3. [](int64_t worker_id) { DistAutogradContainer::init(worker_id); },
  4. py::call_guard<py::gil_scoped_release>());

来到 python 世界,可以看到,使用了 rank 来作为参数,而 rank 是每个 worker 唯一的,这样就保证了 worker ID 唯一,从而 消息 id 唯一。

  1. def init_rpc(
  2. name,
  3. backend=None,
  4. rank=-1,
  5. world_size=None,
  6. rpc_backend_options=None,
  7. ):
  8. dist_autograd._init(rank) # rank是全局唯一

我们把这些逻辑关系总结下来:

  1. worker_id = rank;
  2. container.worker_id_ = worker_id;
  3. container.next_autograd_message_id_ = static_cast<int64_t>(worker_id) << kAutoIncrementBits

然后 next_autograd_message_id_ 内部递增。

  1. int64_t DistAutogradContainer::newAutogradMessageId() {
  2. // Check for overflow into workerId_ section.
  3. TORCH_INTERNAL_ASSERT(next_autograd_message_id_ < max_id_);
  4. return next_autograd_message_id_++;
  5. }

所以,AutogradMessageId 是全局唯一的。我们用图例来看看:

  1. +----------------------------------------------------------------------------------------+
  2. | worker |
  3. | +-------------------------------------+ |
  4. | | DistAutogradContainer | |
  5. | | | |
  6. | | | |
  7. | init() | | |
  8. | rank +--------------+----> worker_id_ | |
  9. | 1 | | | newAutogradMessageId() |
  10. | | +----> next_autograd_message_id_+------------------+ |
  11. | | | 2 | |
  12. | +-------------------------------------+ | |
  13. | | |
  14. | | |
  15. | | |
  16. | | |
  17. | +---------------------------------------------------------------+ |
  18. | | getMessageWithAutograd | | |
  19. | | | | |
  20. | | v | |
  21. | | | |
  22. | | AutogradMetadata autogradMetadata(contextId(), MessageId()) | |
  23. | | 4 3 | |
  24. | | | |
  25. | +---------------------------------------------------------------+ |
  26. | |
  27. +----------------------------------------------------------------------------------------+

为了看看 autogradContextId 为什么可以保证唯一,我们需要先分析 DistAutogradContainer 和 DistAutogradContext。

0x03 DistAutogradContainer

每个worker拥有唯一一个单例DistAutogradContainer,其负责:

  • 对于每一个自动微分过程存储其分布式上下文。
  • 一旦这个自动微分过程结束,就清除其数据。

每个自动微分过程被赋予一个唯一的 autograd_context_id。在每个容器中,这个微分过程的上下文(DistAutogradContext) 依据这个autograd_context_id 来唯一确认。autograd_context_id 是一个 64 bit 的全局唯一id,前 16 bis 是 worker_id,后 48 位是在每个worker内部自动递增id。所以可见,一个Container 之中,是有多个Context的。

此容器还负责维护全局唯一的消息id,用来关联发送/接收自动微分函数对。格式类似于autograd_context_id,是一个64位整数,前16位是工作者id,后48位是worker内部自动递增的。

因为消息 id 和 上下文 id 的前16 位是 worker_id,也就是 rank id,再加上后48位内部自增,所以可以保证 消息 id 和 上下文 id 全局唯一

3.1 定义

DistAutogradContainer 定义如下,其中:

  • worker_id_ : 本 worker 的 ID,其实就是本 worker 的 rank。
  • next_context_id_ :自增的上下文ID,用来给每个自动微分过程赋予一个唯一的autograd_context_id。在一个传播链条上,其实只有第一个节点的 DistAutogradContainer 用到了 next_context_id_ 来生成 Context,后续节点的 DistAutogradContainer 都是依据第一个 DistAutogradContainer 的 context id 信息来在本地生成对应 context id 的 Context。
  • next_autograd_message_id_ :维护全局唯一的消息id,用来关联 发送/接收 自动微分函数对。此变量是在本节点发送时候会使用到。
  1. // Singleton class per worker which is responsible for storing the distributed
  2. // autograd context for each autograd pass and also cleans up data for an
  3. // autograd pass once its done.
  4. //
  5. // Each autograd pass is assigned a unique autograd_context_id and all data for
  6. // that pass (DistAutogradContext) is stored in this container indexed by the
  7. // autograd_context_id. The autograd_context_id itself is a 64 bit globally
  8. // unique id. The first 16 bits is the worker_id and the next 48 bits is an
  9. // auto-incrementing id for each worker.
  10. //
  11. // This container is also responsible for maintaining a globally unique message
  12. // id, which is used to associate send/recv autograd function pairs. The format
  13. // is similar to the autograd_context_id where we have a 64 bit integer with
  14. // first 16 bits being the worker id and next 48 bits are auto-incrementing.
  15. class TORCH_API DistAutogradContainer {
  16. private:
  17. // Number of shards for the map storing autograd contexts. We'd like this
  18. // to be a power of 2 and we don't expect a value much higher than the
  19. // number of cores would provide much benefit.
  20. static constexpr uint32_t kNumDefaultShards = 128;
  21. // Use cache line size for alignment.
  22. static constexpr int kCacheLineSize = 64;
  23. // Structure holding one shard of the sharded autograd context map with its
  24. // associated lock. Align to cache line size to avoid contention between
  25. // adjacent entries.
  26. struct alignas(kCacheLineSize) ContextsShard {
  27. // Lock for this shard.
  28. mutable std::mutex lock;
  29. // Map storing autograd contexts for this shard.
  30. std::unordered_map<int64_t, ContextPtr> contexts; // 这里存储了上下文指针
  31. };
  32. // Auto incrementing context id used to identify unique autograd passes.
  33. // Initialized with the first 16 bits being the worker_id.
  34. std::atomic<int64_t> next_context_id_; // 新增上下文id
  35. // Unique id to identify a worker in the distributed setting.
  36. int16_t worker_id_;
  37. // Whether or not the container has been initialized appropriately.
  38. bool initialized_;
  39. // Sharded autograd context map.
  40. std::vector<ContextsShard> autograd_contexts_; // 存储上下文列表
  41. // Number of shards for the sharded autograd_contexts_ map.
  42. uint32_t num_shards_;
  43. // Autograd message id to identify unique send/recv autograd function pairs.
  44. std::atomic<int64_t> next_autograd_message_id_;
  45. // Maximum allowed value for autograd_context_id or autograd_message_id.
  46. int64_t max_id_;
  47. };

3.2 构建

Init 方法构建了 DistAutogradContainer,主要就是利用 worker_id 对本地成员变量进行相关赋值。

  1. DistAutogradContainer& DistAutogradContainer::init(int64_t worker_id) {
  2. std::lock_guard<std::mutex> guard(dist_container_init_lock_);
  3. TORCH_CHECK(
  4. worker_id >= 0 && worker_id <= kMaxWorkerId,
  5. "worker_id needs to be in the range [0, 65535]")
  6. auto& container = getInstanceInternal();
  7. TORCH_CHECK(
  8. !container.initialized_ || (worker_id == container.worker_id_),
  9. "Container is already initialized with worker_id: ",
  10. container.worker_id_,
  11. ", cannot initialize with different worker_id: ",
  12. worker_id);
  13. if (container.initialized_) {
  14. return container;
  15. }
  16. container.worker_id_ = worker_id;
  17. container.next_context_id_ = static_cast<int64_t>(worker_id)
  18. << kAutoIncrementBits;
  19. container.next_autograd_message_id_ = static_cast<int64_t>(worker_id)
  20. << kAutoIncrementBits;
  21. container.max_id_ =
  22. (kAutoIncrementMask |
  23. (static_cast<int64_t>(worker_id) << kAutoIncrementBits));
  24. container.initialized_ = true;
  25. return container;
  26. }

0x04 DistAutogradContext

DistAutogradContext 存储在一个worker之上的每一个分布式autograd的相关信息,其在分布式 autograd 之中封装前向和后向传播,累积梯度,这避免了多个worker在彼此的梯度上互相影响。

由前面可知道,contextId_ 是全局唯一。

4.1 定义

这里仅仅给出 DistAutogradContext 成员变量,忽略其成员函数。其中成员变量最主要的有三个:

  • contextId_ 是上下文 id。
  • sendAutogradFunctions_ 是一个 map 类型变量,会收集所有发送请求对应的反向传播算子 SendRpcBackward。
  • recvAutogradFunctions_ 是一个 map 类型变量,会收集所有接受送请求对应的反向传播算子 RecvRpcBackward。

关于 SendRpcBackward 和 RecvRpcBackward,我们后续会结合引擎进行分析。

  1. // DistAutogradContext which stores information for a single distributed
  2. // autograd pass on a worker.
  3. class TORCH_API DistAutogradContext {
  4. private:
  5. friend class BackwardPassCleanupGuard;
  6. friend class DistEngine;
  7. friend class RecvRpcBackward;
  8. friend class DistAccumulateGradCaptureHook;
  9. const int64_t contextId_;
  10. // Set containing known worker IDs, used in cleaning up autograd context.
  11. // Whenever a sendRpcBackward is attached to the autograd graph for this
  12. // context, the destination is added here.
  13. std::unordered_set<rpc::worker_id_t> knownWorkerIds_;
  14. // Map from autograd_message_id to appropriate 'send' autograd function.
  15. std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>>
  16. sendAutogradFunctions_;
  17. // Map from autograd_message_id to appropriate 'recv' autograd function.
  18. std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>>
  19. recvAutogradFunctions_;
  20. // Gradients accumulated in this context so far. The key is the variable on
  21. // which the gradient needs to be accumulated and the value is the gradient
  22. // that needs to be accumulated on that variable..
  23. c10::Dict<torch::Tensor, torch::Tensor> accumulatedGrads_;
  24. // See comments for recordGradEvent(c10::Device device);
  25. std::unordered_map<c10::Device, c10::Event> gradReadyEvents_;
  26. const c10::impl::VirtualGuardImpl impl_;
  27. // The autograd GraphTask for the backward pass on this node for this context.
  28. std::shared_ptr<torch::autograd::GraphTask> graphTask_;
  29. // List of futures for RPCs initiated by this node to propagate gradients to
  30. // other nodes. The distributed autograd engine on this node can return
  31. // successfully only if all these futures are done and are successful.
  32. std::vector<c10::intrusive_ptr<rpc::JitFuture>> outStandingRpcs_;
  33. // Lock to protect concurrent modification of the context.
  34. mutable std::mutex lock_;
  35. };

4.2 消息

上下文主要包括几种消息类型,比如:

  1. // Messages with autograd info
  2. FORWARD_AUTOGRAD_REQ = 0x0f | MessageTypeFlags::REQUEST_TYPE,
  3. FORWARD_AUTOGRAD_RESP = 0x10 | MessageTypeFlags::RESPONSE_TYPE,
  4. // Messages to propagate gradients on the backward pass.
  5. BACKWARD_AUTOGRAD_REQ = 0x11 | MessageTypeFlags::REQUEST_TYPE,
  6. BACKWARD_AUTOGRAD_RESP = 0x12 | MessageTypeFlags::RESPONSE_TYPE,

4.3 构建

我们首先看看如何构建上下文。

4.3.1 getOrCreateContext

getOrCreateContext 函数是用来得到上下文,如果已经有,就直接获取,如果没有,就新构建一个。这是一个被动调用,recv 端会用到这个

  1. ContextPtr DistAutogradContainer::getOrCreateContext(int64_t context_id) {
  2. auto& shard = getShard(context_id);
  3. std::lock_guard<std::mutex> guard(shard.lock);
  4. auto it = shard.contexts.find(context_id); // 根据这个context id来查找
  5. if (it != shard.contexts.end()) {
  6. return it->second; // 找到就返回
  7. }
  8. auto& context = // 如果没有,就构建一个 context
  9. shard.contexts
  10. .emplace(
  11. std::piecewise_construct,
  12. std::forward_as_tuple(context_id),
  13. std::forward_as_tuple(
  14. std::make_shared<DistAutogradContext>(context_id)))
  15. .first->second;
  16. return context;
  17. }

4.3.2 newContext

这里是主动调用,send 端会调用这个方法

4.3.2.1 Python

当分布式调用时候,python世界会生成一个context。

  1. with dist_autograd.context() as context_id:
  2. output = model(indices, offsets)
  3. loss = criterion(output, target)
  4. # Run distributed backward pass
  5. dist_autograd.backward(context_id, [loss])
  6. # Run distributed optimizer. Gradients propagated all the way to the parameter servers
  7. opt.step(context_id)

当生成时,__enter__ 会调用 _new_context() 在C++生成一个context。

  1. class context(object):
  2. '''
  3. Context object to wrap forward and backward passes when using
  4. distributed autograd. The ``context_id`` generated in the ``with``
  5. statement is required to uniquely identify a distributed backward pass
  6. on all workers. Each worker stores metadata associated with this
  7. ``context_id``, which is required to correctly execute a distributed
  8. autograd pass.
  9. Example::
  10. >>> import torch.distributed.autograd as dist_autograd
  11. >>> with dist_autograd.context() as context_id:
  12. >>> t1 = torch.rand((3, 3), requires_grad=True)
  13. >>> t2 = torch.rand((3, 3), requires_grad=True)
  14. >>> loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
  15. >>> dist_autograd.backward(context_id, [loss])
  16. '''
  17. def __enter__(self):
  18. self.autograd_context = _new_context() # 这里生成一个上下文
  19. return self.autograd_context._context_id()
  20. def __exit__(self, type, value, traceback):
  21. _release_context(self.autograd_context._context_id())

具体通过如下映射,我们可以看到 C++ 世界之中对应的方法,调用到了 DistAutogradContainer::getInstance().newContext()。

  1. module.def(
  2. "_new_context",
  3. []() -> const ContextPtr {
  4. return DistAutogradContainer::getInstance().newContext();
  5. },
  6. py::return_value_policy::reference);
4.3.2.2 C++

我们来到了C++世界。每一个线程都有一个autograd_context_id。

  1. constexpr int64_t kInvalidContextId = -1;
  2. // Each thread has a single autograd_context_id valid at any point in time.
  3. static thread_local int64_t current_context_id_ = kInvalidContextId;

newContext 就是生成了一个DistAutogradContext,其中通过 Container 的成员变量 next_context_id_ 的递增来指定下一个上下文的id。

  1. const ContextPtr DistAutogradContainer::newContext() {
  2. auto context_id = next_context_id_++; // 递增
  3. current_context_id_ = context_id; // 在这里设置了本地线程的 current_context_id_
  4. // Check for overflow into workerId_ section.
  5. TORCH_INTERNAL_ASSERT(context_id < max_id_);
  6. auto& shard = getShard(context_id);
  7. std::lock_guard<std::mutex> guard(shard.lock);
  8. auto& context =
  9. shard.contexts
  10. .emplace(
  11. std::piecewise_construct,
  12. std::forward_as_tuple(context_id),
  13. std::forward_as_tuple(
  14. std::make_shared<DistAutogradContext>(context_id)))
  15. .first->second;
  16. return context;
  17. }

4.4 如何共享上下文

具体使用中,在with语句中生成的context_id可以用作在所有 worker 之上唯一标识一个分布式后向传播(包括前向传播和后向传播)。每个worker存储与此 context_id关联的元数据,这是正确执行分布式自动加载过程所必需的。

因为需要在多个 worker 之中都存储这个 context_id关联的元数据,所以就需要一个 封装/发送/接受的机制来在 worker 之间传递这个元数据,封装机制就是我们前面提到的 AutogradMetadata。我们接下来看看如何发送/接受上下文元信息

4.4.1 发送方

当发送消息时候,getMessageWithAutograd 会使用 autogradContainer.currentContext() 获取当前上下文,进行发送。

  1. Message getMessageWithAutograd(
  2. const rpc::worker_id_t dstId,
  3. torch::distributed::rpc::Message&& wrappedRpcMsg,
  4. MessageType msgType,
  5. bool forceGradRecording,
  6. const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
  7. auto& autogradContainer = DistAutogradContainer::getInstance();
  8. // If there is no valid context and no tensor requires grads, send original
  9. // rpc message. otherwise, attach grad info and grad functions and send
  10. // rpcWithAutograd message.
  11. auto tensorsRequireGrad =
  12. torch::autograd::compute_requires_grad(wrappedRpcMsg.tensors());
  13. if (!autogradContainer.hasValidContext() ||
  14. (!forceGradRecording && !tensorsRequireGrad)) {
  15. return std::move(wrappedRpcMsg);
  16. }
  17. // Retrieve the appropriate context to modify.
  18. auto autogradContext = autogradContainer.currentContext(); // 获取当前上下文
  19. // Wrap the original rpc with autograd information.
  20. AutogradMetadata autogradMetadata( // 使用上下文id和消息id来构建元数据
  21. autogradContext->contextId(), autogradContainer.newAutogradMessageId());
  22. auto rpcWithAutograd = std::make_unique<RpcWithAutograd>(
  23. RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_,
  24. msgType,
  25. autogradMetadata,
  26. std::move(wrappedRpcMsg),
  27. deviceMap);
  28. if (tensorsRequireGrad) {
  29. // Record autograd information for 'send'.
  30. addSendRpcBackward(
  31. autogradContext, autogradMetadata, rpcWithAutograd->tensors());
  32. }
  33. // Record the workerID
  34. autogradContext->addKnownWorkerId(dstId);
  35. return std::move(*rpcWithAutograd).toMessage();
  36. }

我们之前的图现在可以拓展,加入了上下文ID。

  1. +----------------------------------------------------------------------------------------+
  2. | worker |
  3. | +------------------------------------------+ |
  4. | |DistAutogradContainer | |
  5. | init() | | |
  6. | rank +-------------+----> worker_id_ | |
  7. | | | | |
  8. | | +----> next_context_id_+-------------+ | |
  9. | | | | | |
  10. | | +----> next_autograd_message_id_ +----------------------+ |
  11. | | | | | |
  12. | | | | | |
  13. | +------------------------------------------+ | |
  14. | | | |
  15. | | | |
  16. | | | |
  17. | +------------------------------------------------------------------+ |
  18. | |getMessageWithAutograd | | | |
  19. | | | | | |
  20. | | v v | |
  21. | | | |
  22. | | AutogradMetadata autogradMetadata(contextId(), MessageId()) | |
  23. | | | |
  24. | | | |
  25. | +------------------------------------------------------------------+ |
  26. | |
  27. +----------------------------------------------------------------------------------------+

addSendRpcBackward 就被传入当前上下文之中,后续反向传播时候,会取出这个 addSendRpcBackward。

  1. void addSendRpcBackward(
  2. const ContextPtr& autogradContext,
  3. const AutogradMetadata& autogradMetadata,
  4. std::vector<torch::Tensor>& tensors) {
  5. // Attach autograd information only for tensors requiring grad.
  6. std::vector<torch::Tensor> tensors_with_grad;
  7. std::copy_if(
  8. tensors.begin(),
  9. tensors.end(),
  10. std::back_inserter(tensors_with_grad),
  11. [](const torch::Tensor& t) { return t.requires_grad(); });
  12. // Attach the appropriate autograd edges.
  13. auto grad_fn = std::make_shared<SendRpcBackward>();
  14. grad_fn->set_next_edges(
  15. torch::autograd::collect_next_edges(tensors_with_grad));
  16. // Add the appropriate input metadata for the grad_fn.
  17. for (const auto& tensor : tensors_with_grad) {
  18. grad_fn->add_input_metadata(tensor);
  19. }
  20. // Record the send autograd function in our current context.
  21. autogradContext->addSendFunction(grad_fn, autogradMetadata.autogradMessageId);
  22. }

4.4.2 接受方

在 addRecvRpcBackward 之中,会依据传递过来的 autogradMetadata.autogradContextId 来构建一个上下文。

  1. ContextPtr addRecvRpcBackward(
  2. const AutogradMetadata& autogradMetadata,
  3. std::vector<torch::Tensor>& tensors,
  4. rpc::worker_id_t fromWorkerId,
  5. const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
  6. // Initialize autograd context if necessary.
  7. auto& autogradContainer = DistAutogradContainer::getInstance();
  8. // 生成或者得到一个上下文,把发送方的 autogradContextId 传入,即利用 autogradContextId 作为key后续可以查找到这个上下文
  9. auto autogradContext =
  10. autogradContainer.getOrCreateContext(autogradMetadata.autogradContextId);
  11. if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) {
  12. // Attach the tensors as inputs to the autograd function.
  13. auto grad_fn = std::make_shared<RecvRpcBackward>(
  14. autogradMetadata, autogradContext, fromWorkerId, deviceMap);
  15. for (auto& tensor : tensors) {
  16. if (tensor.requires_grad()) {
  17. torch::autograd::set_history(tensor, grad_fn);
  18. }
  19. }
  20. // Now update the autograd context with the necessary information.
  21. autogradContext->addRecvFunction(
  22. grad_fn, autogradMetadata.autogradMessageId);
  23. }
  24. return autogradContext;
  25. }

这样,发送方和接收方就共享了一个上下文,而且这个上下文的id是全局唯一的。

具体逻辑如下,上方是发送端,下方是接收端。

  • 发送端

    • 利用本地 context_id 构建了 AutogradMetadata,AutogradMetadata含有 ctx_id, msg_id。
    • 利用 AutogradMetadata 构建了 Message。
    • 利用 agent.send 发送了 Message。
  • 接收端:
    • 收到了 Message。
    • 从 Message 之中解析出 AutogradMetadata。
    • 从 AutogradMetadata 提取出 context_id。
    • 利用 context_id 构建了本地的 DistAutogradContext。
  • 发送方和接收方就共享了一个上下文(这个上下文的id是全局唯一的)。
  1. +----------------------------------------------------------------------------------+
  2. | sendMessageWithAutograd |
  3. | |
  4. | +----------------------------------------------------------------------------+ |
  5. | | addSendRpcBackward | |
  6. | | | |
  7. | | | |
  8. | | autogradMetadata = AutogradMetadata(context_id, message_id) | |
  9. | | + | |
  10. | | | | |
  11. | +----------------------------------------------------------------------------+ |
  12. | | |
  13. | v |
  14. | agent.send(message(autogradMetadata) |
  15. | + |
  16. | | |
  17. +----------------------------------------------------------------------------------+
  18. |
  19. |
  20. |
  21. | Sender
  22. +-----------------------------------------------------------------------------------+
  23. | Receiver
  24. | message
  25. v
  26. |
  27. +----------------------------------------------------------------------------------+
  28. | processForwardAutogradReq | |
  29. | | |
  30. | | message.autogradMetadata |
  31. | v |
  32. | +----------------------------------------------------------------------------+ |
  33. | | addSendRpcBackward | | |
  34. | | | | |
  35. | | +--------------------+ | |
  36. | | | | |
  37. | | v | |
  38. | | autogradContext = getOrCreateContext(autogradMetadata.autogradContextId) | |
  39. | | | |
  40. | | | |
  41. | +----------------------------------------------------------------------------+ |
  42. | |
  43. +----------------------------------------------------------------------------------+

0x05 前向传播交互过程

前面的分享过程还是简略,我们接下来把完整的发送/接受过程详细分析一下。

5.1 发送

这里对应设计中的如下文字:

在前向传播期间,我们在上下文中存储每个 autograd 传播的sendrecv函数。这确保我们在 autograd 图中保存对适当节点的引用以使其保持活动状态。除此之外,这也使得在后向传播期间很容易查找到对应的sendrecv函数。

5.1.1 发送逻辑

代码逻辑如下:

  • 生成一个 grad_fn,其类型是 SendRpcBackward。
  • 调用 collect_next_edges 和 set_next_edges 为 SendRpcBackward 添加后续边,这些函数我们在前面系列中有分析。
  • 调用 add_input_metadata 添加输入元数据。
  • 调用 addSendFunction 往上下文添加 grad_fn。
  1. void addSendRpcBackward(
  2. const ContextPtr& autogradContext,
  3. const AutogradMetadata& autogradMetadata,
  4. std::vector<torch::Tensor>& tensors) {
  5. // Attach autograd information only for tensors requiring grad.
  6. std::vector<torch::Tensor> tensors_with_grad;
  7. std::copy_if(
  8. tensors.begin(),
  9. tensors.end(),
  10. std::back_inserter(tensors_with_grad),
  11. [](const torch::Tensor& t) { return t.requires_grad(); });
  12. // Attach the appropriate autograd edges.
  13. auto grad_fn = std::make_shared<SendRpcBackward>();
  14. grad_fn->set_next_edges( // 这里会设置其输出边
  15. torch::autograd::collect_next_edges(tensors_with_grad));
  16. // Add the appropriate input metadata for the grad_fn.
  17. for (const auto& tensor : tensors_with_grad) {
  18. grad_fn->add_input_metadata(tensor);
  19. }
  20. // Record the send autograd function in our current context.
  21. autogradContext->addSendFunction(grad_fn, autogradMetadata.autogradMessageId);
  22. }

5.1.2 设置上下文

我们再回忆一下DistAutogradContext 定义,这里仅仅给出其部分成员变量。

  • contextId_ 是上下文 id。
  • sendAutogradFunctions_ 是一个 map 类型变量,会收集所有发送请求对应的反向传播算子 SendRpcBackward。
  • recvAutogradFunctions_ 是一个 map 类型变量,会收集所有接受送请求对应的反向传播算子 RecvRpcBackward。
  1. // DistAutogradContext which stores information for a single distributed
  2. // autograd pass on a worker.
  3. class TORCH_API DistAutogradContext {
  4. const int64_t contextId_;
  5. // Map from autograd_message_id to appropriate 'send' autograd function.
  6. std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>>
  7. sendAutogradFunctions_;
  8. // Map from autograd_message_id to appropriate 'recv' autograd function.
  9. std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>>
  10. recvAutogradFunctions_;
  11. };

addSendFunction 就是往 sendAutogradFunctions_ 之中添加SendRpcBackward,后续可以按照 message id 来得到这个 SendRpcBackward。

  1. void DistAutogradContext::addSendFunction(
  2. const std::shared_ptr<SendRpcBackward>& func,
  3. int64_t autograd_message_id) {
  4. std::lock_guard<std::mutex> guard(lock_);
  5. TORCH_INTERNAL_ASSERT(
  6. sendAutogradFunctions_.find(autograd_message_id) ==
  7. sendAutogradFunctions_.end());
  8. sendAutogradFunctions_.emplace(autograd_message_id, func);
  9. }

前面是从上下文构建的角度看,本次从上下文内容来看。

此时发送端逻辑如下:

  1. +--------------------------------------------------------------+ +-------------------+
  2. | worker | |SendRpcBackward |
  3. | +---------------------------------------------------------+ | | |
  4. | | DistAutogradContext | | | input_metadata_ |
  5. | | +-------------> | |
  6. | | contextId_ = context_id_1 | | | | next_edges_ |
  7. | | + | | | |
  8. | | sendAutogradFunctions_ = [msg_id_1, SendRpcBackward_1] | | +-------------------+
  9. | | | |
  10. | | | |
  11. | | recvAutogradFunctions_ | |
  12. | | | |
  13. | +---------------------------------------------------------+ |
  14. | |
  15. +--------------------------------------------------------------+
  16. sender
  17. +---------------------------------------------------------------------------------------+

5.2 接受

我们略过 agent 的发送内部处理,转而看看 FORWARD_AUTOGRAD_REQ 的业务流程。

5.2.1 接收消息 ---> 接收方

生成 TensorPipeAgent 时候,把 RequestCallbackImpl 配置为回调函数。这是 agent 的统一响应函数。

前面关于代理接收逻辑时候,我们也提到了,会进入以下函数,其中可以看到有对 processForwardAutogradReq 的处理逻辑。

  1. void RequestCallbackNoPython::processRpc(
  2. RpcCommandBase& rpc,
  3. const MessageType& messageType,
  4. const int64_t messageId,
  5. const c10::intrusive_ptr<JitFuture>& responseFuture,
  6. std::shared_ptr<LazyStreamContext> ctx) const {
  7. case MessageType::FORWARD_AUTOGRAD_REQ: {
  8. // 会来到这里
  9. processForwardAutogradReq(rpc, messageId, responseFuture, std::move(ctx));
  10. return;
  11. }
  12. case MessageType::BACKWARD_AUTOGRAD_REQ: {
  13. processBackwardAutogradReq(rpc, messageId, responseFuture);
  14. return;
  15. };
  16. }

5.2.2 处理消息

processForwardAutogradReq 负责具体处理消息,其处理逻辑如下:

  • 虽然是收到了前向传播请求,但因为此处是接收端,后续需要进行反向传播,所以对deviceMap进行转置。
  • 使用 addRecvRpcBackward 将 rpc 消息 加入上下文。
  • 可能会有nested命令的可能,所以需要再调用一次processRpc。
  • 设置最原始的消息为处理完毕,进行相关操作。
  1. void RequestCallbackNoPython::processForwardAutogradReq(
  2. RpcCommandBase& rpc,
  3. const int64_t messageId,
  4. const c10::intrusive_ptr<JitFuture>& responseFuture,
  5. std::shared_ptr<LazyStreamContext> ctx) const {
  6. auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);
  7. // Need to reverse the device map for the backward pass of distributed
  8. // autograd.
  9. std::unordered_map<c10::Device, c10::Device> reverseDeviceMap;
  10. // 对deviceMap进行转置
  11. for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
  12. reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
  13. }
  14. // Attach 'recv' autograd function.
  15. auto autogradContext = addRecvRpcBackward( // 调用了 addRecvRpcBackward 加入上下文
  16. rpcWithAutograd.autogradMetadata(),
  17. rpcWithAutograd.tensors(),
  18. rpcWithAutograd.fromWorkerId(),
  19. reverseDeviceMap);
  20. // For this recv thread on server side, before processRpc(),
  21. // set current_context_id_ to be context_id passed from client.
  22. // In this way, if there is nested rpc call in python rpc call, original
  23. // context_id from client can be passed in the chain calls.
  24. DistAutogradContextGuard ctxGuard(autogradContext->contextId());
  25. // Process the original RPC.
  26. auto wrappedMessageType = rpcWithAutograd.wrappedMessageType();
  27. // Make an overall future for the wrapped response.
  28. auto wrappedRpcResponseFuture =
  29. c10::make_intrusive<JitFuture>(at::AnyClassType::get());
  30. // Kick off processing for the nested RPC command.
  31. // wrappedRpcResponseFuture will be a Future<T> to the result.
  32. processRpc( // 可能会有nested命令的可能,所以需要再处理一次
  33. rpcWithAutograd.wrappedRpc(),
  34. wrappedMessageType,
  35. messageId,
  36. wrappedRpcResponseFuture,
  37. std::move(ctx));
  38. auto fromWorkerId = rpcWithAutograd.fromWorkerId();
  39. // The original future needs to be marked as completed when the wrapped
  40. // one completes, with the autograd context information wrapped.
  41. wrappedRpcResponseFuture->addCallback(
  42. [responseFuture,
  43. messageId,
  44. fromWorkerId,
  45. ctxId =
  46. autogradContext->contextId()](JitFuture& wrappedRpcResponseFuture) {
  47. // As this callback can be invoked by a different thread, we have to
  48. // make sure that the thread_local states in the previous thread is
  49. // correctly propagated.
  50. // NB: The execution of TorchScript functions can also run on a
  51. // different thread, which is addressed by
  52. // https://github.com/pytorch/pytorch/pull/36395
  53. // NB: when adding async UDF support, we should also propagate
  54. // thread_local states there.
  55. // TODO: Land on a general solution for RPC ThreadLocalState. See
  56. // https://github.com/pytorch/pytorch/issues/38510
  57. DistAutogradContextGuard cbCtxGuard(ctxId);
  58. if (wrappedRpcResponseFuture.hasError()) {
  59. // Propagate error to responseFuture if we had one.
  60. responseFuture->setError(wrappedRpcResponseFuture.exception_ptr());
  61. } else {
  62. auto msg = getMessageWithAutograd(
  63. fromWorkerId,
  64. std::move(
  65. *wrappedRpcResponseFuture.value().toCustomClass<Message>()),
  66. MessageType::FORWARD_AUTOGRAD_RESP);
  67. msg.setId(messageId);
  68. responseFuture->markCompleted(
  69. IValue(c10::make_intrusive<Message>(std::move(msg))));
  70. }
  71. });
  72. }

5.2.3 上下文交互

torch/csrc/distributed/autograd/utils.cpp 之中,addRecvRpcBackward 函数会对上下文进行处理。

这里对应设计中的:

在前向传播期间,我们在上下文中存储每个 autograd 传播的sendrecv函数。这确保我们在 autograd 图中保存对适当节点的引用以使其保持活动状态。除此之外,这也使得在向后传播期间很容易查找到对应的sendrecv函数。

其具体逻辑是:

  • 根据 rpc信息中的 autogradContextId 拿到本地的上下文。
  • 生成一个 RecvRpcBackward。
  • 用 rpc 信息中的张量来对 RecvRpcBackward 进行配置,包括torch::autograd::set_history(tensor, grad_fn)。
  • 调用 addRecvFunction 把 RecvRpcBackward 加入到上下文。
  1. ContextPtr addRecvRpcBackward(
  2. const AutogradMetadata& autogradMetadata,
  3. std::vector<torch::Tensor>& tensors,
  4. rpc::worker_id_t fromWorkerId,
  5. const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
  6. // Initialize autograd context if necessary.
  7. auto& autogradContainer = DistAutogradContainer::getInstance();
  8. auto autogradContext =
  9. autogradContainer.getOrCreateContext(autogradMetadata.autogradContextId);
  10. if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) {
  11. // Attach the tensors as inputs to the autograd function.
  12. auto grad_fn = std::make_shared<RecvRpcBackward>(
  13. autogradMetadata, autogradContext, fromWorkerId, deviceMap);
  14. for (auto& tensor : tensors) {
  15. if (tensor.requires_grad()) {
  16. torch::autograd::set_history(tensor, grad_fn);
  17. }
  18. }
  19. // Now update the autograd context with the necessary information.
  20. autogradContext->addRecvFunction(
  21. grad_fn, autogradMetadata.autogradMessageId);
  22. }
  23. return autogradContext;
  24. }

addRecvFunction 的添加操作如下,就是看看 recvAutogradFunctions_之中是否已经存在这个 message id 对应的算子,如果没有就添加 。

  1. void DistAutogradContext::addRecvFunction(
  2. std::shared_ptr<RecvRpcBackward>& func,
  3. int64_t autograd_message_id) {
  4. TORCH_INTERNAL_ASSERT(func != nullptr);
  5. std::lock_guard<std::mutex> guard(lock_);
  6. TORCH_INTERNAL_ASSERT(
  7. recvAutogradFunctions_.find(autograd_message_id) ==
  8. recvAutogradFunctions_.end());
  9. recvAutogradFunctions_.emplace(autograd_message_id, func);
  10. }

至此,逻辑拓展如下,在发送端和接收端都有一个 DistAutogradContext,其 id 都是 context_id_1。

在 每个 DistAutogradContext 之内,均以 msg_id_1 作为key,一个是 SendRpcBackward,一个建立了 RecvRpcBackward。

这就对应了设计之中提到的:

每个自动微分过程被赋予一个唯一的 autograd_context_id,在容器中,这个微分过程的上下文(DistAutogradContext) 依据这个autograd_context_id 来唯一确认。autograd_context_id 是一个 64 bit 的全局唯一id,前 16 bis 是 worker_id,后 48 位是在每个worker内部自动递增id。所以可见,一个Container 之中,是有多个Context的。

此容器还负责维护全局唯一的消息id,用来关联发送/接收自动微分函数对。格式类似于autograd_context_id,是一个64位整数,前16位是工作者id,后48位是worker内部自动递增的。

  1. +----------------------------------------------------------------+
  2. | worker | +-------------------+
  3. | | |SendRpcBackward |
  4. | +---------------------------------------------------------+ | | |
  5. | | DistAutogradContext | | | input_metadata_ |
  6. | | +-------------> | |
  7. | | contextId_ = context_id_1 | | | | next_edges_ |
  8. | | + | | | |
  9. | | sendAutogradFunctions_ = [msg_id_1, SendRpcBackward_1] | | +-------------------+
  10. | | | |
  11. | | recvAutogradFunctions_ | |
  12. | | | |
  13. | +---------------------------------------------------------+ |
  14. | |
  15. | + |
  16. | | |
  17. +----------------------------------------------------------------+
  18. |
  19. |
  20. | Sender
  21. +-----------------------------------------------------------------------------------------+
  22. | Receiver
  23. |
  24. v
  25. +-----------------------------+----------------------------------+
  26. | worker |
  27. | | +-------------------+
  28. | +---------------------------------------------------------+ | |RecvRpcBackward |
  29. | | DistAutogradContext | | | |
  30. | | | | | |
  31. | | contextId_ = context_id_1 +-----------------> | input_metadata_ |
  32. | | | | | | |
  33. | | sendAutogradFunctions_ | | | | next_edges_ |
  34. | | + | | | |
  35. | | recvAutogradFunctions_ = [msg_id_1, RecvRpcBackward_1]| | +-------------------+
  36. | | | |
  37. | +---------------------------------------------------------+ |
  38. | |
  39. +----------------------------------------------------------------+

我们加入 Container,再拓展一下目前逻辑如下:

  • 每个worker 包括一个DistAutogradContainer。
  • 每个 DistAutogradContainer 包括若干个 DistAutogradContext,依据 context id 提取 DistAutogradContext。
  • 每个 DistAutogradContext 包括 sendAutogradFunctions_ 和 recvAutogradFunctions_,利用 msg id 来获取 SendRpcBackward 或者 RecvRpcBackward。

这样这个反向传播链条就构建了出来。

  1. +------------------------------------------------------------------------------------------------------------------------------------+
  2. | worker |
  3. | |
  4. | +---------------------------------------+ +---------------------------------------------------------+ +-------------------+ |
  5. | | DistAutogradContainer | | DistAutogradContext | |SendRpcBackward | |
  6. | | | | +----------> | | |
  7. | | worker_id_ | | contextId_ = ctx_id_1 | | | input_metadata_ | |
  8. | | | | + | | | |
  9. | | next_autograd_message_id_ +---------> | sendAutogradFunctions_ = [msg_id_1, SendRpcBackward_1] | | next_edges_ | |
  10. | | | | | | | | |
  11. | | next_context_id_ | | | recvAutogradFunctions_ | +-------------------+ |
  12. | | + | | | |
  13. | | autograd_contexts_[ctx_id_1 : ctx] | +---------------------------------------------------------+ |
  14. | | | |
  15. | +----------------------------+----------+ |
  16. | | |
  17. +------------------------------------------------------------------------------------------------------------------------------------+
  18. |
  19. |
  20. +-------------------------------------------------------------------------------------------------------------------------------------+
  21. |
  22. v
  23. +------------------------------+-----------------------------------------------------------------------------------------------------+
  24. | worker |
  25. | |
  26. | +---------------------------------------+ +---------------------------------------------------------+ +-------------------+ |
  27. | | DistAutogradContainer | | DistAutogradContext | |RecvRpcBackward | |
  28. | | | | +----------> | | |
  29. | | worker_id_ | | contextId_ = ctx_id_1 | | | input_metadata_ | |
  30. | | | | | | | | |
  31. | | next_autograd_message_id_ +---------> | sendAutogradFunctions_ | | | next_edges_ | |
  32. | | | | | + | | | |
  33. | | next_context_id_ | | | recvAutogradFunctions_ = [msg_id_1, RecvRpcBackward_1] | +-------------------+ |
  34. | | + | | | |
  35. | | autograd_contexts_[ctx_id_1 : ctx] | +---------------------------------------------------------+ |
  36. | | | |
  37. | +---------------------------------------+ |
  38. | |
  39. +------------------------------------------------------------------------------------------------------------------------------------+

手机如下:

至此,我们初步分析了上下文相关的类,下文我们把目前已经分析的内容结合起来,系统看看业务逻辑。

0xFF 参考

[源码解析] PyTorch 分布式 Autograd (3) ---- 上下文相关的更多相关文章

  1. [源码解析] PyTorch 分布式 Autograd (4) ---- 如何切入引擎

    [源码解析] PyTorch 分布式 Autograd (4) ---- 如何切入引擎 目录 [源码解析] PyTorch 分布式 Autograd (4) ---- 如何切入引擎 0x00 摘要 0 ...

  2. [源码解析] PyTorch 分布式 Autograd (5) ---- 引擎(上)

    [源码解析] PyTorch 分布式 Autograd (5) ---- 引擎(上) 目录 [源码解析] PyTorch 分布式 Autograd (5) ---- 引擎(上) 0x00 摘要 0x0 ...

  3. [源码解析] PyTorch 分布式 Autograd (6) ---- 引擎(下)

    [源码解析] PyTtorch 分布式 Autograd (6) ---- 引擎(下) 目录 [源码解析] PyTtorch 分布式 Autograd (6) ---- 引擎(下) 0x00 摘要 0 ...

  4. [源码解析] PyTorch 分布式 Autograd (1) ---- 设计

    [源码解析] PyTorch 分布式 Autograd (1) ---- 设计 目录 [源码解析] PyTorch 分布式 Autograd (1) ---- 设计 0x00 摘要 0x01 分布式R ...

  5. [源码解析] PyTorch 分布式 Autograd (2) ---- RPC基础

    [源码解析] PyTorch 分布式 Autograd (2) ---- RPC基础 目录 [源码解析] PyTorch 分布式 Autograd (2) ---- RPC基础 0x00 摘要 0x0 ...

  6. [源码解析] PyTorch 分布式(14) --使用 Distributed Autograd 和 Distributed Optimizer

    [源码解析] PyTorch 分布式(14) --使用 Distributed Autograd 和 Distributed Optimizer 目录 [源码解析] PyTorch 分布式(14) - ...

  7. [源码解析] PyTorch分布式优化器(1)----基石篇

    [源码解析] PyTorch分布式优化器(1)----基石篇 目录 [源码解析] PyTorch分布式优化器(1)----基石篇 0x00 摘要 0x01 从问题出发 1.1 示例 1.2 问题点 0 ...

  8. [源码解析] PyTorch分布式优化器(2)----数据并行优化器

    [源码解析] PyTorch分布式优化器(2)----数据并行优化器 目录 [源码解析] PyTorch分布式优化器(2)----数据并行优化器 0x00 摘要 0x01 前文回顾 0x02 DP 之 ...

  9. [源码解析] PyTorch分布式优化器(3)---- 模型并行

    [源码解析] PyTorch分布式优化器(3)---- 模型并行 目录 [源码解析] PyTorch分布式优化器(3)---- 模型并行 0x00 摘要 0x01 前文回顾 0x02 单机模型 2.1 ...

随机推荐

  1. javascript-原生-结构

    1.获取用户输入内容的方法 window.prompt("提示信息","默认值"); 获取用户输入内容(字符串类型),返回用户输入内容. 2.顺序结构:所有语句 ...

  2. Java:ArrayList类小记

    Java:ArrayList类小记 对 Java 中的 ArrayList类,做一个微不足道的小小小小记 概述 java.util.ArrayList 是大小可变的数组的实现,存储在内的数据称为元素. ...

  3. 欧姆龙PLC HostLink协议整理

    欧姆龙PLC HostLink协议整理 1.常用的存储器功能区 CIO: 输入继电器  272 点(17 CH) 0.00-16.15 输出继电器  272 点(17 CH) 100.00-116.1 ...

  4. 航胥:北航教务助手——Beta阶段发布声明

    下载地址在文章末尾! 这里是"航胥",一款更想要了解你的北航教务助手 Beta阶段,我们进化了! Beta阶段我们的新功能有: 课程评价功能 所有用户选过的课程都会在课程评价页面进 ...

  5. 技术博客——微信小程序的架构与原理

    技术博客--微信小程序的架构与原理 在两个月的微信小程序开发过程中,我曾走了不少弯路,也曾被很多现在看来十分可笑的问题所困扰.这些弯路与困扰,基本上都是由于当时对小程序的架构理解不够充分,对小程序的原 ...

  6. [no code][scrum meeting] Alpha 14

    项目 内容 会议时间 2020-04-22 会议主题 周中讨论会议 会议时长 45min 参会人员 全体成员 $( "#cnblogs_post_body" ).catalog() ...

  7. [no code][scrum meeting] Beta 1

    $( "#cnblogs_post_body" ).catalog() 会议纪要 会议在微信群进行:集体反思alpha阶段博客分数尤其是scrum博客分数低的问题,讨论beta阶段 ...

  8. Spring动态添加定时任务

    Spring动态添加定时任务 一.背景 二.需求和实现思路 1.能够动态的添加一个定时任务. 2.能够取消定时任务的执行. 3.动态的修改任务执行的时间. 4.获取定时任务执行的异常 三.代码实现 四 ...

  9. Noip模拟36 2021.8.11

    刚题的习惯还是改不了,怎么办??? T1 Dove打扑克 考场上打的动态开点线段树+并查集,考后发现自己像一个傻子,并查集就行.. 这几天恶补数据结构疯了 用树状数组维护后缀和,$siz_i$表示编号 ...

  10. 一张图彻底搞懂Spring循环依赖

    1 什么是循环依赖? 如下图所示: BeanA类依赖了BeanB类,同时BeanB类又依赖了BeanA类.这种依赖关系形成了一个闭环,我们把这种依赖关系就称之为循环依赖.同理,再如下图的情况: 上图中 ...