[源码解析] PyTorch 分布式之弹性训练(5)---Rendezvous 引擎

0x00 摘要

在前面的文章之中,我们已经学习了PyTorch 分布式的基本模块,介绍了官方的几个例子,我们接下来会介绍PyTorch的弹性训练,本文是第五篇,看看Rendezvous 的内部引擎,比如如何处理节点加入,节点离开,等待,心跳等等。

弹性训练系列文章如下:

[源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路

[源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程

[源码解析] PyTorch 分布式之弹性训练(3)---代理

[源码解析] PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑

0x01 前言

1.1 总体系统

弹性训练可以理解为在 Rendezvous 基础之上的一个运行系统。

  • Agent 偏重具体节点上的逻辑

    • Agent 负责具体业务逻辑相关操作,比如启动进程执行用户程序,监控用户程序运行情况,如果有异常就通知 Rendezvous。
    • Agent 是一个 worker manager,负责启动/管理 workers 进程,组成一个 worker group,监控 workers 运行状态,捕获失效 workers,如果有故障/新加入worker,则重启 worker group。
    • Agent负责维护 WORLD_SIZE 以及 RANK 信息。用户不需要再手动提供,Agent会自动处理这些。
    • Agent 是具体节点上的后台进程,是独立个体。Agent自己无法实现整体上的弹性训练,所以需要一个机制来完成 worker 之间的相互发现,变更同步等等(WORLD_SIZE 和 RANK 这些信息其实也需要多个节点同步才能确定),这就是下面的 Rendezvous 概念。
  • Rendezvous 负责

    集群逻辑

    ,保证节点之间对于""有哪些节点参与训练"达成强一致共识。

    • 每一个 Agent 内部包括一个 Rendezvous handler,这些 handler 总体上构成了一个 Rendezvous 集群,从而构成了一个 Agent 集群。
    • Rendezvous 完成之后,会创建一个共享键值存储(shared key-value store),这个store实现了一个torch.distributed.Store API。此存储仅由已完成Rendezvous的成员共享,它旨在让Torch Distributed Elastic在初始化作业过程之中交换控制和数据信息。
    • Rendezvous 负责在每个agent之上维护当前 group 所有相关信息。每个 agent 之上有一个 rendezvous,它们会互相通信,总体维护一套信息,这些信息存储在上面提到的Store 之中。
    • Rendezvous 负责集群逻辑相关,比如新加入节点,移除节点,分配rank等等。

1.2 Rendezvous

目前为止,Rendezvous 信息如下,DynamicRendezvousHandler 属于动态逻辑,其中,_RendezvousStateHolder 是状态等元信息存储(静态结构),大家会发现图中还有一个 _RendezvousOpExecutor 没有介绍,这就是运行时引擎,所以我们本文看看 _RendezvousOpExecutor 如何处理。

+-----------------------------+      +------------------------------------------------+
| LocalElasticAgent | | WorkerSpec |
| | | |
| +------------------------+ | | rdzv_handler = {DynamicRendezvousHandler} -------+
| |WorkerGroup | | | | |
| | spec +--------------> | entry = worker_fn | |
| | workers | | | | |
| | store | | | role = {str} 'trainer' | |
| | group_rank | | | | |
| | group_world_size | | +------------------------------------------------+ |
| | | | |
| +------------------------+ | |
| | |
| rdzv_run_id | |
| store | +-----------------------------------------+ |
| | |DynamicRendezvousHandler | |
+-----------------------------+ | | |
| | |
| _settings: RendezvousSettings | <--+
| |
| _store: Store |
| |
| _state_holder: _RendezvousStateHolder |
| |
| _op_executor: _RendezvousOpExecutor |
| |
+-----------------------------------------+

1.3 解耦

_RendezvousOpExecutor 把功能分割解耦:

  • 业务逻辑被抽象成为一系列算子,比如 _RendevzousJoinOp
  • Rendezvous 内部维护了一套由业务函数组成的状态机,比如函数 _add_to_participants 用来添加参与者。
  • _RendezvousOpExecutor 引擎来执行各种算子,依据算子结果,得到一个 Action,再利用 Action 调用业务函数进行操作。

本文主要介绍C10d 后端对应的 Rendezvous 引擎。

0x02 引擎实现

2.1 基类

_RendezvousOpExecutor 是引擎的基类,只是定义了run这个虚函数。

class _RendezvousOpExecutor(ABC):
"""Executes rendezvous operations.""" @abstractmethod
def run(
self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
) -> None:
"""Executes a rendezvous operation. An operation is run inside a state machine and is expected to transition
the rendezvous from one state to another. Args:
state_handler:
A callable that is expected to return the next state transition
action based on the current state of the rendezvous.
deadline:
The time, in seconds, at which the operation will be considered
timed-out.
"""

这里用到了 _RendezvousContext,其作用是把 Rendezvous 的各种信息封装了起来,提供给操作引擎。这里就有了 _RendezvousState 和 RendezvousSettings 的使用。

class _RendezvousContext:
"""Holds the context of the rendezvous. Attributes:
node:
The node descriptor associated with the current rendezvous handler
instance.
state:
The current state of the rendezvous.
settings:
The rendezvous settings.
""" node: _NodeDesc
state: _RendezvousState
settings: RendezvousSettings def __init__(
self, node: _NodeDesc, state: _RendezvousState, settings: RendezvousSettings
) -> None:
self.node = node
self.state = state
self.settings = settings

2.2 分布式操作引擎

_DistributedRendezvousOpExecutor 拓展了 _RendezvousOpExecutor,是 ElasticTorch 的实际执行者。类似于 Looper,负责消息分发,调用业务,状态维护

2.2.1 定义

与其基类相比,_DistributedRendezvousOpExecutor 加入了比如节点信息,状态,配置这样的成员变量。

class _DistributedRendezvousOpExecutor(_RendezvousOpExecutor):
"""Executes rendezvous operations using a shared state. Args:
node:
The node descriptor associated with the current rendezvous handler
instance.
state_holder:
The ``RendezvousStateHolder`` to use to sync the rendezvous state
with other nodes.
settings:
The rendezvous settings.
""" _node: _NodeDesc
_state: _RendezvousState
_state_holder: _RendezvousStateHolder
_settings: RendezvousSettings def __init__(
self,
node: _NodeDesc,
state_holder: _RendezvousStateHolder,
settings: RendezvousSettings,
) -> None:
self._node = node
self._state_holder = state_holder
self._settings = settings

逻辑如下:

+---------------------------------------------------------------+
| _DistributedRendezvousOpExecutor |
| |
| +------------------------+ |
| _state +---> | _RendezvousState | |
| | | |
| | participants | |
| | wait_list | |
| | last_heartbeats | |
| | deadline | |
| +------------------------+ |
| |
| +-------------------------+ |
| _settings +--> | RendezvousSettings | |
| | | |
| +-------------------------+ |
| |
| +--------------------------------------+ |
| _state_holder +---> | _BackendRendezvousStateHolder | |
| | | |
| | _backend: RendezvousBackend | |
| | _state: _RendezvousState | |
| | _settings: RendezvousSettings | |
| | | |
| +--------------------------------------+ |
| +--------------------------------------+ |
| | _NodeDesc | |
| _node +-------> | fqdn: str | |
| | pid: int | |
| | local_id: int | |
| | | |
| +--------------------------------------+ |
+---------------------------------------------------------------+

2.2.2 调用

我们举出几个例子来看看如何调用引擎,可以看到都是先设置算子,然后调用引擎的run函数。

2.2.2.1 _RendezvousKeepAliveOp
def _keep_alive(self) -> None:
self._heartbeat_lock.acquire()
op = _RendezvousKeepAliveOp() # 设置算子
deadline = self._get_deadline(self._settings.timeout.heartbeat)
self._op_executor.run(op, deadline) # 调用
2.2.2.2 _RendezvousCloseOp
def _close(self) -> None:
op = _RendezvousCloseOp() # 设置算子
deadline = self._get_deadline(self._settings.timeout.close)
self._op_executor.run(op, deadline) # 调用
2.2.2.3 _RendezvousJoinOp
def next_rendezvous(self) -> Tuple[Store, int, int]:
"""See base class.""" self._stop_heartbeats() # Delay the execution for a small random amount of time if this is our
# first run. This will slightly skew the rendezvous attempts across the
# nodes and reduce the load on the backend.
if self._state_holder.state.round == 0:
_delay(seconds=(0, 0.3)) exit_op = _RendezvousExitOp() # 设置算子
join_op = _RendezvousJoinOp() # 设置算子 deadline = self._get_deadline(self._settings.timeout.join) self._op_executor.run(exit_op, deadline) # 这里会进行调用
self._op_executor.run(join_op, deadline) # 调用 self._start_heartbeats() rank, world_size = self._get_world()
store = self._get_store() return store, rank, world_size

2.2.3 功能

_DistributedRendezvousOpExecutor 之中,run 函数实现了基础逻辑,就是依据 action 类型进行各种操作。

2.2.3.1 主体循环

run 具体代码如下:

    def run(
self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
) -> None:
"""See base class."""
action = None while action != _Action.FINISH: # 循环,一直到获得一个FINISH action 为止
# Reads or writes the latest rendezvous state shared by all nodes in
# the rendezvous. Note that our local changes might get overridden
# by another node if that node synced its changes before us. # 这里很重要,在所有node之间做信息同步
has_set = self._state_holder.sync() # 因为最新状态在 rendezvous。 self._state = self._state_holder.state ctx = _RendezvousContext(self._node, self._state, self._settings) # Determine the next action to take based on the current state of
# the rendezvous.
action = state_handler(ctx, deadline) # 决定下一个操作,state_handler 就是算子 if action == _Action.FINISH:
continue if action == _Action.ERROR_CLOSED:
raise RendezvousClosedError() if action == _Action.ERROR_TIMEOUT:
raise RendezvousTimeoutError() if action == _Action.SYNC:
# Delay the execution by one second to avoid overloading the
# backend if we are asked to poll for state changes.
_delay(seconds=1)
else:
if action == _Action.KEEP_ALIVE:
self._keep_alive()
elif action == _Action.ADD_TO_PARTICIPANTS:
self._add_to_participants()
elif action == _Action.ADD_TO_WAIT_LIST:
self._add_to_wait_list()
elif action == _Action.REMOVE_FROM_PARTICIPANTS:
self._remove_from_participants()
elif action == _Action.REMOVE_FROM_WAIT_LIST:
self._remove_from_wait_list()
elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
self._mark_rendezvous_complete()
elif action == _Action.MARK_RENDEZVOUS_CLOSED:
self._mark_rendezvous_closed() # Attempt to sync our changes back to other nodes.
self._state_holder.mark_dirty()

具体如下图。

+-----------------------------------------+                          +---------------------------------------------------------------+
|DynamicRendezvousHandler | | _DistributedRendezvousOpExecutor |
| | | |
| | | +------------------------+ |
| _settings: RendezvousSettings | | _state +---> | _RendezvousState | |
| | | | | |
| | | | participants | |
| _store: Store | | | wait_list | |
| | | | last_heartbeats | |
| | | | deadline | |
| _state_holder: _RendezvousStateHolder | | +------------------------+ |
| | run(_RendezvousJoinOp()) | +-------------------------+ |
| | | _settings +--> | RendezvousSettings | |
| _op_executor +------------------------------------------------> | | | |
| | | +-------------------------+ |
| | | +--------------------------------------+ |
+-----------------------------------------+ | _state_holder +---> | _BackendRendezvousStateHolder | |
| | | |
| | _backend: RendezvousBackend | |
| | _state: _RendezvousState | |
| | _settings: RendezvousSettings | |
| | | |
| +--------------------------------------+ |
| +--------------------------------------+ |
| | _NodeDesc | |
| _node +-------> | fqdn: str | |
| | pid: int | |
| | local_id: int | |
| | | |
| +--------------------------------------+ |
+---------------------------------------------------------------+

手机如下:

2.2.3.2 同步

在 run 函数之中,需要注意的是:在执行各种算子操作之前,会调用 self._state_holder.sync() 在各个 worker 之间进行一个状态同步,达成共识 (consensus)

def sync(self) -> Optional[bool]:
"""See base class."""
state_bits: Optional[bytes] = None
token = None
has_set: Optional[bool] if self._dirty: # 如果本node状态变化了
has_set = False
state_bits = pickle.dumps(self._state)
# 把自己的状态设置到backend之中
set_response = self._backend.set_state(state_bits, self._token)
if set_response is not None:
state_bits, token, has_set = set_response
else: # 自己没变化,只能从后端获取
has_set = None
if self._cache_duration > 0:
# Avoid overloading the backend if we are asked to retrieve the
# state repeatedly. Try to serve the cached state.
if self._last_sync_time >= max(time.monotonic() - self._cache_duration, 0):
return None
get_response = self._backend.get_state() # 从backend获取其他节点最新状态
if get_response is not None:
state_bits, token = get_response if state_bits is not None:
try:
self._state = pickle.loads(state_bits) # 用后端状态更新本身的状态
except pickle.PickleError as exc:
raise RendezvousStateError(
"The rendezvous state is corrupt. See inner exception for details."
) from exc
else:
self._state = _RendezvousState() if has_set and self._dead_nodes and log.isEnabledFor(logging.DEBUG):
node_list = ", ".join(f"'{dead_node}'" for dead_node in self._dead_nodes)
msg = (
f"As part of the sync operation the node(s) {node_list} have been removed from the "
f"rendezvous '{self._settings.run_id}' since they had no heartbeat."
)
self._record(message=msg) self._token = token
self._dirty = False
self._last_sync_time = time.monotonic()
self._sanitize() return has_set
后端

torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py 之中是对应后端代码。

后端这里使用 store 作为一个集中式存储,是master。每个 node 是 client,会去master更新自己状态,并且获取其他node状态。这样所有node就会互通有无,达成共识。这里也会定期删除不更新元数据的clients。

get_state 就是简单的从 store 提取。

def get_state(self) -> Optional[Tuple[bytes, Token]]:
"""See base class."""
base64_state: bytes = self._call_store("get", self._key) return self._decode_state(base64_state)

set_state 会做一个compare set,其返回new state和是否更新了state。

def set_state(
self, state: bytes, token: Optional[Token] = None
) -> Optional[Tuple[bytes, Token, bool]]:
"""See base class."""
base64_state_str: str = b64encode(state).decode() if token:
# Shortcut if we know for sure that the token is not valid.
if not isinstance(token, bytes):
result = self.get_state()
if result is not None:
tmp = *result, False
# Python 3.6 does not support tuple unpacking in return
# statements.
return tmp
return None token = token.decode()
else:
token = self._NULL_SENTINEL base64_state: bytes = self._call_store("compare_set", self._key, token, base64_state_str) state_token_pair = self._decode_state(base64_state)
if state_token_pair is None:
return None new_state, new_token = state_token_pair # C10d Store's compare_set method does not offer an easy way to find out
# whether our write attempt was successful. As a brute-force solution we
# perform a bitwise comparison of our local state and the remote state.
return new_state, new_token, new_state == state
_sanitize

_sanitize 方法用来依据其他节点消息做处理,比如清理故障节点。即,如果上一次的心跳时间超过了一定阈值范围,则会把这些节点标记为dead_node,并且从 participant或者wait list中清除这些节点。

def _sanitize(self) -> None:
state = self._state expire_time = datetime.utcnow() - (
self._settings.keep_alive_interval * self._settings.keep_alive_max_attempt
) # Filter out the dead nodes.
self._dead_nodes = [
node
for node, last_heartbeat in state.last_heartbeats.items()
if last_heartbeat < expire_time
] participant_removed = False for dead_node in self._dead_nodes:
del state.last_heartbeats[dead_node] # 移除故障节点 try:
del state.participants[dead_node] # 移除故障节点 participant_removed = True
except KeyError:
pass try:
state.wait_list.remove(dead_node) # 移除故障节点
except KeyError:
pass if participant_removed:
# Common epilogue shared with the _remove_from_participants()
# function of _DistributedRendezvousOpExecutor.
_remove_participant_epilogue(state, self._settings)

介绍完毕如何运行引擎,我们接下来看看具体算子。

0x03 算子

_RendezvousOpExecutor 引擎的业务逻辑被分成两层:用户操作 和 内部业务逻辑。用户操作和内部业务机制之间被解耦。

  • 用户操作被分成各种算子,包括:心跳,Join,关闭,结束。比如Join 算子就是 _RendevzousJoinOp

  • 内部业务逻辑被分成各种业务函数,比如 _add_to_participants 方法从等待列表中移除节点,往 participants 加入这个节点。

  • 算子和内部业务逻辑并不是一一对应,需要一个类似状态机的机制来控制。

    • 比如,心跳操作算子的结果可能是:超时/keep alive/正常结束,所以应该根据这个结果调用不同的内部业务函数。这种对应关系逻辑就是通过 Action 来完成的
    • 各种算子联合起来,聚合成了一个状态机。
    • 算子内部就是生成各种 Action,决定了状态机的下一步操作。
  • 引擎内部就是根据 Action 来执行具体业务逻辑,或者可以说,是通过 Action 进行解耦。

具体如下,引擎从逻辑上可以分成三层:最上面是算子层,中间是 Action 层,下面是业务函数层。

+-----------------------------------------------------------------------------------------+
| |
| _RendezvousKeepAliveOp _RendezvousCloseOp _RendezvousExitOp _RendezvousJoinOp |
| |
+-------------+---------------------+--------------------+------------------+-------------+
| | | |
| | | |
| | | |
| | | |
v v v v +-----------------------------------------------------------------------------------------+
| |
| KEEP_ALIVE ADD_TO_PARTICIPANTS ADD_TO_WAIT_LIST REMOVE_FROM_WAIT_LIST ...... |
| |
+-------------+----------+----------+----------+---------+---------+---------+------------+
| | | | | | |
| | | | | | |
| | | | | | |
| | | | | | |
v v v v v v v +-----------------------------------------------------------------------------------------+
| |
| _add_to_participants _remove_from_participants _add_to_wait_list ...... |
| |
| |
+-----------------------------------------------------------------------------------------+

我们逐一解析。

3.1 操作

先来解析中间层 Action,看看有多少 Action。基于 rendezvous 的状态,引擎的actions具体如下。代码位于 torch/distributed/elastic/rendezvous/dynamic_rendezvous.py

class _Action(Enum):
"""Specifies the possible actions based on the state of the rendezvous.""" KEEP_ALIVE = 1
ADD_TO_PARTICIPANTS = 2
ADD_TO_WAIT_LIST = 3
REMOVE_FROM_PARTICIPANTS = 4
REMOVE_FROM_WAIT_LIST = 5
MARK_RENDEZVOUS_COMPLETE = 6
MARK_RENDEZVOUS_CLOSED = 7
SYNC = 8
ERROR_CLOSED = 9
ERROR_TIMEOUT = 10
FINISH = 11

3.2 算子

引擎之中实现了一些算子,基本上,一个操作对应一个算子,我们给出几个操作算子的例子,算子就是依据rendezvous的状态来设置操作类型

3.2.1 心跳

3.2.1.1 检查心跳

_RendezvousKeepAliveOp 的作用是:依据当前状态和时间来确定下一步Action。主要是定期检查本Node是否故障。

class _RendezvousKeepAliveOp:
"""Represents a rendezvous keep-alive update operation.""" def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
if _should_keep_alive(ctx):
if time.monotonic() > deadline:
return _Action.ERROR_TIMEOUT
return _Action.KEEP_ALIVE
return _Action.FINISH

_should_keep_alive 方法为:

def _should_keep_alive(ctx: _RendezvousContext) -> bool:
"""Determines whether a keep-alive heartbeat should be sent."""
try:
last_heartbeat = ctx.state.last_heartbeats[ctx.node]
except KeyError:
return False return last_heartbeat <= datetime.utcnow() - ctx.settings.keep_alive_interval
3.2.1.2 定期调用

这里要注意的是,因为做任何算子之前,都要调用 sync 操作,而 sync 会在 node 之间同步状态,因为心跳是定期的,所以同步状态也是定期的。

DynamicRendezvousHandler 之中会启动一个timer,定期调用_keep_alive_weak方法。

def _start_heartbeats(self) -> None:
self._keep_alive_timer = _PeriodicTimer(
self._settings.keep_alive_interval, self._keep_alive_weak, weakref.ref(self)
) self._keep_alive_timer.set_name(f"RendezvousKeepAliveTimer_{self._this_node.local_id}")
self._keep_alive_timer.start()

其次,_keep_alive_weak 会调用 self._keep_alive()

@staticmethod
def _keep_alive_weak(weak_self) -> None:
self = weak_self()
if self is not None:
self._keep_alive()

_keep_alive 会调用 _RendezvousKeepAliveOp。

def _keep_alive(self) -> None:
self._heartbeat_lock.acquire()
op = _RendezvousKeepAliveOp()
deadline = self._get_deadline(self._settings.timeout.heartbeat) try:
self._op_executor.run(op, deadline)
msg = (
f"The node '{self._this_node}' has sent a keep-alive heartbeat to the rendezvous "
f"'{self._settings.run_id}'."
)
self._record(message=msg)
log.debug(msg)
except RendezvousError as ex:
msg = (
f"The node '{self._this_node}' has failed to send a keep-alive heartbeat to the "
f"rendezvous '{self._settings.run_id}' due to an error of type {type(ex).__name__}."
)
self._record(message=msg, node_state=NodeState.FAILED)
finally:
self._heartbeat_lock.release()
3.2.1.2 设置心跳

另外,_DistributedRendezvousOpExecutor 有一个 _keep_alive 同名函数,是用来实现内部逻辑,我们后续会讲到。

3.2.2 关闭

_RendezvousCloseOp 会依据当前状态和时间来确定下一步Action。

class _RendezvousCloseOp:
"""Represents a rendezvous close operation.""" def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
if ctx.state.closed:
return _Action.FINISH
if time.monotonic() > deadline:
return _Action.ERROR_TIMEOUT
return _Action.MARK_RENDEZVOUS_CLOSED

3.2.3 结束

_RendezvousExitOp 依据当前状态和时间来确定下一步Action。如果本Node不在participants之中,不处理。否则返回一个从 participants 列表删除的下一步Action。如果超时则返回对应Action。

class _RendezvousExitOp:
"""Represents a rendezvous exit operation.""" def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
if ctx.node in ctx.state.participants:
if time.monotonic() > deadline:
return _Action.ERROR_TIMEOUT
return _Action.REMOVE_FROM_PARTICIPANTS
return _Action.FINISH

3.2.4 Join

_RendezvousJoinOp 这里依据系统状态不同,做不同处理,比如试图把本Node加入到participant,或者 waiting list,或者继续等待,具体可以参见代码注释。

  • 从上下文之中提取 _RendezvousState 状态,把结果存放在 state 之中。
  • 如果状态是closed,则说明此时rendezvous已经结束,则返回_Action.ERROR_CLOSED。
  • 看看是不是参与者,把结果存放在is_participant。
  • 如果状态已经结束,且本节点已经是参与者,则说明 rendezvous 可以结束,返回 _Action.FINISH。
  • 获取当前时间 now。
  • 如果 now > deadline,说明已经超时。
    • 如果还有时间做 rollback,说明本节点要返回之前的状态。

      • 如果本节点已经是参与者,说明此时总节点数目没有达到 min,虽然已经是参与者,但是需要从参与者列表移除,所以返回 _Action.REMOVE_FROM_PARTICIPANTS。
      • 如果本节点在等待列表之中,说明此时总节点数目没有达到 max,虽然在等待列表之中,但是需要从等待列表移除,所以返回_Action.REMOVE_FROM_WAIT_LIST。
    • 否则返回_Action.ERROR_TIMEOUT。
  • 否则没有超时,继续处理。
    • 如果state.complete 并且本节点不是参与者(如果节点是参与者,前面已经处理过了),说明rendezvous 已经结束,如果还没有达到最大节点数目,并且当前node不在等待列表之中,就需要添加到等待节点列表,等待下次监控周期到的时候,重新做rendezvous,就可以把等待列表中的节点加入到参与列表之中。所以返回_Action.ADD_TO_WAIT_LIST。
    • 如果本节点是参与者并且state不是complete状态(如果是complete状态,前面已经处理过了),如果已经达到了最小节点数 & 已经超时了,则说明rendezvous 已经结束,则返回_Action.MARK_RENDEZVOUS_COMPLETE。
    • 否则说明没结束,本节点也不是参与者,则直接加入到参与者列表,返回_Action.ADD_TO_PARTICIPANTS。
  • 如果需要保持心跳,就返回 _Action.KEEP_ALIVE。
  • 否则返回_Action.SYNC。
class _RendezvousJoinOp:
"""Represents a rendezvous join operation.""" def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
state = ctx.state # 从上下文之中提取 _RendezvousState 状态 # A closed rendezvous means that it no longer accepts new nodes.
if state.closed:
return _Action.ERROR_CLOSED # 如果已经结束,就返回 _Action.ERROR_CLOSED is_participant = ctx.node in state.participants # 看看是不是参与者 # If we are part of the rendezvous and it is already complete there is
# no further action to take.
if state.complete and is_participant: # 如果是参与者且状态是结束,就返回 _Action.FINISH
return _Action.FINISH now = time.monotonic()
if now > deadline: # 如果已经超时
rollback_period = 5 # 5 seconds # If we still have time to rollback (a short period on top of the
# operation deadline), try to remove ourself from the rendezvous.
# It is okay if we can't though as our keep-alive will eventually
# expire.
if now <= deadline + rollback_period: # 如果还有时间来 rollback
# If we are part of the rendezvous, it means we couldn't find
# enough participants to complete it on time.
if is_participant: # 此时尚未达到min,虽然已经是参与者,但是需要移除
return _Action.REMOVE_FROM_PARTICIPANTS # 需要从参与者列表移除
# If we are in the wait list, it means we couldn't wait till the
# next round of the rendezvous.
if ctx.node in state.wait_list: # 此时已经达到 max,虽然已经在等待列表之中,需要移除
return _Action.REMOVE_FROM_WAIT_LIST # 需要从等待列表移除
return _Action.ERROR_TIMEOUT # 返回超时 if state.complete: # 如果 rendezvous 已经结束
# If we are here, it means we are not part of the rendezvous. In
# case the rendezvous has capacity for additional participants add
# ourself to the wait list for the next round.
if len(state.participants) < ctx.settings.max_nodes: # 如果还没有达到最大节点数
if ctx.node not in state.wait_list: # 如果当前node不在等待列表之中
return _Action.ADD_TO_WAIT_LIST # 就加入到等待列表,发送一个等待action
elif is_participant: # 如果已经在参与者列表
# If the rendezvous has enough number of participants including us,
# check whether we have passed the rendezvous deadline. If yes,
# complete it.
if len(state.participants) >= ctx.settings.min_nodes: # 如果达到了最小节点数
if cast(datetime, state.deadline) < datetime.utcnow(): # 如果达到了超时
return _Action.MARK_RENDEZVOUS_COMPLETE # 标示 rendezvous 已经结束
else: # 否则就直接加入到参与者
# The rendezvous is not complete yet and we are not part of it. Try
# to join.
return _Action.ADD_TO_PARTICIPANTS if _should_keep_alive(ctx): # 如果需要保持心跳,就返回 _Action.KEEP_ALIVE
return _Action.KEEP_ALIVE # At this point either the rendezvous is not complete, but we are part
# of it, which means we have to wait for other participants to join; or
# the rendezvous is complete, but we are not part of it, which means we
# have to wait for the next round.
return _Action.SYNC # 否则返回同步状态 _Action.SYNC

具体逻辑如下:

                           state.closed
+--------------------------> _Action.ERROR_CLOSED
|
|
| complete & participant
+--------------------------> _Action.FINISH
|
|
| timeout & participant
+--------------------------> _Action.REMOVE_FROM_PARTICIPANTS
|
|
| timeout & wait
+--------------------------> _Action.REMOVE_FROM_WAIT_LIST
|
+-------------------+ |
| | | timeout
| _RendezvousJoinOp +------------------------------> _Action.ERROR_TIMEOUT
| | |
+-------------------+ | complete & < max & not wait
|
+--------------------------> _Action.ADD_TO_WAIT_LIST
|
| complete & participant & > min & deadline
|
+--------------------------> _Action.MARK_RENDEZVOUS_COMPLETE
|
| not complete & not participant
|
+--------------------------> _Action.ADD_TO_PARTICIPANTS
|
| _should_keep_alive
|
+--------------------------> _Action.KEEP_ALIVE
|
| else
|
+--------------------------> _Action.SYNC

以下是源码之中 ETCD 后端 Rendezvous 状态描述图,我们可以大致参考比对 c10d的状态。

可见,etcd 后端的Join可以分为4个阶段:

  • setup 阶段,会往固定目录写一个值,这是一个排他锁,如果写失败,说明目前正有一个 rendezvous 过程在进行中。
  • join(joinable) 阶段。如果写值成功,则进入join 阶段。如果在等待时间结束或者参与训练的节点达到了最大值,则进入 frozen 阶段。
  • frozen(confirm)阶段。需要所有节点都确认,进入最后的 final 阶段。
  • final 阶段。分配rank,RANK 0 的实例成为 master。

仿照上图,我们把 c10d 拓展如下。

      +
|
|
v
+-----+------+
| |
| closed +---------------> ERROR_CLOSED
| |
+-----+------+
|
|
v
+-----+------+ is_participant
| |
| complete +---------------> FINISH
| |
+-----+------+
| is_participant
|
v +----> REMOVE_FROM_PARTICIPANTS
+-----+-------+ now > deadline +-----------+ now < rollback +-----------+ |
| | | | | | |
| join +----------------> | timeout +---------------------->+ rollback +-----+
| | | | | | |
+-----+-------+ +----+------+ +-----------+ |
| | | in state.wait_list
| | now > rollback |
| now < deadline | +----> REMOVE_FROM_WAIT_LIST
| +----------> ERROR_TIMEOUT
|
| complete && not is_participant && < max && not in state.wait_list
|
+------------------------------------------------------------------> ADD_TO_WAIT_LIST
|
| not complete && is_participant && > min && > deadline
|
+------------------------------------------------------------------> MARK_RENDEZVOUS_COMPLETE
|
| not complete && not is_participant
|
+-----------------------------------------> ADD_TO_PARTICIPANTS
|
| _should_keep_alive
|
+---------------------------> KEEP_ALIVE
|
|
v
SYNC

手机如下:

0x04 业务操作

_DistributedRendezvousOpExecutor.run 的内部就是依据 action 选择不同的业务函数来执行。

            if action == _Action.KEEP_ALIVE:
self._keep_alive()
elif action == _Action.ADD_TO_PARTICIPANTS:
self._add_to_participants()
elif action == _Action.ADD_TO_WAIT_LIST:
self._add_to_wait_list()
elif action == _Action.REMOVE_FROM_PARTICIPANTS:
self._remove_from_participants()
elif action == _Action.REMOVE_FROM_WAIT_LIST:
self._remove_from_wait_list()
elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
self._mark_rendezvous_complete()
elif action == _Action.MARK_RENDEZVOUS_CLOSED:
self._mark_rendezvous_closed()

我们接下来就看看具体这些内部函数逻辑。

4.1 加入参与者

接受到 ADD_TO_PARTICIPANTS 之后,调用 _add_to_participants 从等待列表中移除节点,往 participants 加入这个节点。

    def _add_to_participants(self) -> None:

        state = self._state

        try:
state.wait_list.remove(self._node)
except KeyError:
pass # The ranks of the participants will be set once the rendezvous is
# complete.
state.participants[self._node] = 0 self._keep_alive() if len(state.participants) == self._settings.min_nodes:
state.deadline = datetime.utcnow() + self._settings.timeout.last_call if len(state.participants) == self._settings.max_nodes:
self._mark_rendezvous_complete()

4.2 移除参与者

接受到 REMOVE_FROM_PARTICIPANTS 之后,调用 _remove_from_participants 从 participants 和 last_heartbeats 中删除参与者。

    def _remove_from_participants(self) -> None:

        state = self._state
del state.participants[self._node]
del state.last_heartbeats[self._node] if state.complete:
# If we do not have any participants left, move to the next round.
if not state.participants:
state.complete = False
state.round += 1
else:
if len(state.participants) < self._settings.min_nodes:
state.deadline = None

4.3 加入等待序列

接受到 ADD_TO_WAIT_LIST 之后,调用 _add_to_wait_list 网 wait_list 中加入节点。

    def _add_to_wait_list(self) -> None:
self._state.wait_list.add(self._node)
self._keep_alive()

4.4 移除等待序列

接受到 REMOVE_FROM_WAIT_LIST 之后,调用 _remove_from_wait_list 从 wait_list 移除节点。

    def _remove_from_wait_list(self) -> None:
self._state.wait_list.remove(self._node)
del self._state.last_heartbeats[self._node]

4.5 设置结束

接受到 MARK_RENDEZVOUS_COMPLETE 之后,当 rendezvous 聚合操作结束之后,给每一个参与者设置 rank。

每个节点上都是按照同样算法排序,所以rank在每个节点上都是一样的。

    def _mark_rendezvous_complete(self) -> None:
state = self._state state.complete = True
state.deadline = None # Assign the ranks.
for rank, node in enumerate(sorted(state.participants)):
state.participants[node] = rank def _mark_rendezvous_closed(self) -> None:
self._state.closed = True

4.6 心跳

接收到 KEEP_ALIVE action之后,会调用到 _keep_alive 来维持心跳。另外,keep_alive 也会在 _add_to_participants等方法内被调用,会更新本地state之中的last heartbeats,下一次 sync 时候,会把 last_heartbeats 写入键值存储,这样其他Node就可以知道这个节点的状态了。而本地则会在 _sanitize 之中依据 last_heartbeats 做处理,我们之前提到过。

def _keep_alive(self) -> None:
msg = (
f"The node '{self._node}' updated its keep-alive heartbeat time for the rendezvous "
f"'{self._settings.run_id}'. Pending sync."
)
self._record(message=msg)
self._state.last_heartbeats[self._node] = datetime.utcnow()

_record 方法如下:

def _record(self, message: str, node_state: NodeState = NodeState.RUNNING) -> None:
construct_and_record_rdzv_event(
name=f"{self.__class__.__name__}.{get_method_name()}",
run_id=self._settings.run_id,
message=message,
node_state=node_state,
hostname=self._node.fqdn,
pid=self._node.pid,
local_id=self._node.local_id,
)

其就是调用如下代码记录log。

def record_rdzv_event(event: RdzvEvent) -> None:
_get_or_create_logger("dynamic_rendezvous").info(event.serialize()) def construct_and_record_rdzv_event(
run_id: str,
message: str,
node_state: NodeState,
name: str = "",
hostname: str = "",
pid: Optional[int] = None,
master_endpoint: str = "",
local_id: Optional[int] = None,
rank: Optional[int] = None,
) -> None:
# We don't want to perform an extra computation if not needed.
if isinstance(get_logging_handler("dynamic_rendezvous"), logging.NullHandler):
return # Set up parameters.
if not hostname:
hostname = socket.getfqdn()
if not pid:
pid = os.getpid() # Determines which file called this function.
callstack = inspect.stack()
filename = "no_file"
if len(callstack) > 1:
stack_depth_1 = callstack[1]
filename = os.path.basename(stack_depth_1.filename)
if not name:
name = stack_depth_1.function # Delete the callstack variable. If kept, this can mess with python's
# garbage collector as we are holding on to stack frame information in
# the inspect module.
del callstack # Set up error trace if this is an exception
if node_state == NodeState.FAILED:
error_trace = traceback.format_exc()
else:
error_trace = "" # Initialize event object
event = RdzvEvent(
name=f"{filename}:{name}",
run_id=run_id,
message=message,
hostname=hostname,
pid=pid,
node_state=node_state,
master_endpoint=master_endpoint,
rank=rank,
local_id=local_id,
error_trace=error_trace,
) # Finally, record the event.
record_rdzv_event(event)

至此,引擎部分也已经分析完毕,下一篇我们看看是否可以从整体角度再做一下全面梳理。

0xFF 参考

[源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路

[源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程

[源码解析] PyTorch 分布式之弹性训练(3)---代理

[源码解析] PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑

[源码解析] PyTorch 分布式之弹性训练(5)---Rendezvous 引擎的更多相关文章

  1. [源码解析] PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑

    [源码解析] PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑 目录 [源码解析] PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑 0x00 ...

  2. [源码解析] PyTorch 分布式之弹性训练(6)---监控/容错

    [源码解析] PyTorch 分布式之弹性训练(6)---监控/容错 目录 [源码解析] PyTorch 分布式之弹性训练(6)---监控/容错 0x00 摘要 0x01 总体逻辑 1.1 Node集 ...

  3. [源码解析] PyTorch 分布式之弹性训练(7)---节点变化

    [源码解析] PyTorch 分布式之弹性训练(7)---节点变化 目录 [源码解析] PyTorch 分布式之弹性训练(7)---节点变化 0x00 摘要 0x01 变化方式 1.1 Scale-d ...

  4. [源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路

    [源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路 目录 [源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路 0x00 摘要 0x01 痛点 0x02 难点 0 ...

  5. [源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程

    [源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程 目录 [源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程 0x00 摘要 0x01 ...

  6. [源码解析] PyTorch 分布式之弹性训练(3)---代理

    [源码解析] PyTorch 分布式之弹性训练(3)---代理 目录 [源码解析] PyTorch 分布式之弹性训练(3)---代理 0x00 摘要 0x01 总体背景 1.1 功能分离 1.2 Re ...

  7. [源码解析] PyTorch 分布式(1)------历史和概述

    [源码解析] PyTorch 分布式(1)------历史和概述 目录 [源码解析] PyTorch 分布式(1)------历史和概述 0x00 摘要 0x01 PyTorch分布式的历史 1.1 ...

  8. [源码解析] PyTorch分布式(5) ------ DistributedDataParallel 总述&如何使用

    [源码解析] PyTorch 分布式(5) ------ DistributedDataParallel 总述&如何使用 目录 [源码解析] PyTorch 分布式(5) ------ Dis ...

  9. [源码解析] PyTorch 分布式(18) --- 使用 RPC 的分布式管道并行

    [源码解析] PyTorch 分布式(18) --- 使用 RPC 的分布式管道并行 目录 [源码解析] PyTorch 分布式(18) --- 使用 RPC 的分布式管道并行 0x00 摘要 0x0 ...

随机推荐

  1. Go语言核心36讲(Go语言实战与应用二十一)--学习笔记

    43 | bufio包中的数据类型(下) 在上一篇文章中,我提到了bufio包中的数据类型主要有Reader.Scanner.Writer和ReadWriter.并着重讲到了bufio.Reader类 ...

  2. 非寻常方式学习ApacheTomcat架构及10.0.12源码编译

    概述 开启博客分享已近三个月,感谢所有花时间精力和小编一路学习和成长的伙伴们,有你们的支持,我们继续再接再厉 **本人博客网站 **IT小神 www.itxiaoshen.com 定义 Tomcat官 ...

  3. Shell 打印空行的行号

    目录 Shell 打印空行的行号 题解 Shell 打印空行的行号 写一个 bash脚本以输出一个文本文件 nowcoder.txt中空行的行号,可能连续,从1开始 示例: 假设 nowcoder.t ...

  4. day08 Nginx模块

    day08 Nginx模块 lnmp架构 l :Linux n :Nginx m :MySQL p :Python/PHP lnmp架构:是最简单的架构 Nginx中的模块(Python模块):前提是 ...

  5. Hive相关知识点

    ---恢复内容开始--- 转载:Hive 性能优化 介绍 首先,我们来看看Hadoop的计算框架特性,在此特性下会衍生哪些问题? 数据量大不是问题,数据倾斜是个问题. jobs数比较多的作业运行效率相 ...

  6. 【leetocode】55. Jump Game

    You are given an integer array nums. You are initially positioned at the array's first index, and ea ...

  7. 【原创】Altium生成Gerber时跳出The Film is too small for this PCB的解决办法

    在用altium Designer画板子的时候,要生成gerber文件的时候,会出错,出现这样的提示框:"The Film is too small for this PCB" 原 ...

  8. 【Linux】【Shell】【text】grep

    grep: Global search REgular expression and Print out the line. 作用:文本搜索工具,根据用户指定的"模式(过滤条件)" ...

  9. 测试JDBCUtils的重用性

    package cn.itcast.jdbc;import cn.itcast.util.JDBCUtils;import java.sql.*;import java.util.Properties ...

  10. Android工具-DDMS

    原创文章,如有转载,请注明出处:http://blog.csdn.net/yihui823/article/details/6686578 本文章的前提:已经安装了Eclipse和ADT.androi ...