[源码解析] 机器学习参数服务器ps-lite(2) ----- 通信模块Van
[源码解析] 机器学习参数服务器ps-lite(2) ----- 通信模块Van
0x00 摘要
本文是参数服务器系列第二篇,介绍ps-lite的通信模块 Van。
本系列其他文章是:
[源码解析] 机器学习参数服务器ps-lite 之(1) ----- PostOffice
0x01 功能概述
邮局里有了地址簿,就需要有货车来负责拉送物件,Van 就是整个Parameter Server的通信模块,其特点如下。
- PostOffice 类在实例化的时候,会创建一个 Van 类的实例 作为成员变量。该 Van 实例与所属 PostOffice 实例生命周期相同(每个节点只有一个该对象);
- Van 负责具体的节点间通信。具体来说就是负责建立起节点之间的互相连接(例如Worker与Scheduler之间的连接),并且开启本地的receiving thread用来监听收到的message。
VAN 目前有两个实现:
- ZMQVan是基于zeromq的Van的实现,即用zmq库实现了连接的底层细节(zmq库是一个开源库,对socket进行了优良的封装,他使得Socket编程更加简单、简洁和性能更高)。
- IBVerbsVan 是字节跳动的实现,具体没有深入研究。
0x02 定义
2.1 UML图
首先给出 UML 图。
2.2 主要说明
下面我们只给出Van对象关键变量和成员函数说明。
其主要变量如下:
Node scheduler_ :Scheduler 节点参数,每一个node都会记录Scheduler 节点的信息;
Node my_node_ : 本节点参数。如果本节点是Scheduler,则 my_node_ 会指向上面的 scheduler_ ;
bool is_scheduler_ : 本节点是否是 scheduler;
std::unique_ptr< std::thread> receiver_thread_ :接收消息线程指针;
std::unique_ptr< std::thread> heartbeat_thread_ :发送心跳线程指针;
std::vector barrier_count_ :barrier 计数,用来记录登记节点数目,只有所有节点都登记之后,系统才到了 ready 状态,scheduler 才会给所有节点发送 ready 消息,系统才正式启动。
Resender *resender_ = nullptr :重新发送消息指针;
std::atomic timestamp_{0} :message 自增 id,原子变量;
std::unordered_map<std::string, int> connected_nodes_ : 记录了目前连接到哪些 nodes;
其主要函数功能如下:
start :建立通信初始化;
Receiving :接收消息线程的处理函数;
Heartbeat :发送心跳线程的处理函数;
ProcessAddNodeCommandAtScheduler :scheduler 的 AddNode 消息处理函数;
ProcessHearbeat:心跳包处理函数;
ProcessDataMsg :数据消息(push & pull)处理函数;
ProcessAddNodeCommand :worker 和 server 的 AddNode 消息处理函数;
ProcessBarrierCommand :Barrier 消息处理函数;
2.3 线程管理
PS Lite 定义的三种角色采用多线程机制工作,每个线程承担特定的职责,在所属的 Van 实例启动时被创建。
具体描述如下:
- Scheduler,Worker 和 Server 的 Van 实例里均持有一个接受数据的线程。
- Worker 和 Server 的 Van 实例里还持有一个间歇地向 Scheduler 发送心跳的线程。
- 如果定义了值不为 0 环境变量 PS_RESEND,那么 Scheduler、Worker 和 Server 还会启动一个监控线程。
2.4 类定义
详细代码(摘要)如下:
class Van {
public:
static Van *Create(const std::string &type);
virtual void Start(int customer_id);
int Send(const Message &msg);
virtual void Stop();
inline int GetTimestamp() { return timestamp_++; }
inline bool IsReady() { return ready_; }
protected:
//连结节点
virtual void Connect(const Node &node) = 0;
//绑定到自己节点之上
virtual int Bind(const Node &node, int max_retry) = 0;
//接收消息,用阻塞方式
virtual int RecvMsg(Message *msg) = 0;
//发送消息
virtual int SendMsg(const Message &msg) = 0;
/**
* \brief pack meta into a string
*/
void PackMeta(const Meta &meta, char **meta_buf, int *buf_size);
/**
* \brief pack meta into protobuf
*/
void PackMetaPB(const Meta &meta, PBMeta *pb);
/**
* \brief unpack meta from a string
*/
void UnpackMeta(const char *meta_buf, int buf_size, Meta *meta);
Node scheduler_;
Node my_node_;
bool is_scheduler_;
std::mutex start_mu_;
private:
/** thread function for receving */
void Receiving();
/** thread function for heartbeat */
void Heartbeat();
// node's address string (i.e. ip:port) -> node id
// this map is updated when ip:port is received for the first time
std::unordered_map<std::string, int> connected_nodes_;
// maps the id of node which is added later to the id of node
// which is with the same ip:port and added first
std::unordered_map<int, int> shared_node_mapping_;
/** whether it is ready for sending */
std::atomic<bool> ready_{false};
std::atomic<size_t> send_bytes_{0};
size_t recv_bytes_ = 0;
int num_servers_ = 0;
int num_workers_ = 0;
/** the thread for receiving messages */
std::unique_ptr<std::thread> receiver_thread_;
/** the thread for sending heartbeat */
std::unique_ptr<std::thread> heartbeat_thread_;
std::vector<int> barrier_count_;
/** msg resender */
Resender *resender_ = nullptr;
int drop_rate_ = 0;
std::atomic<int> timestamp_{0};
int init_stage = 0;
//以下是处理各种类型消息
void ProcessAddNodeCommandAtScheduler(Message *msg, Meta *nodes,
Meta *recovery_nodes);
void ProcessTerminateCommand();
void ProcessAddNodeCommand(Message *msg, Meta *nodes, Meta *recovery_nodes);
void ProcessBarrierCommand(Message *msg);
void ProcessHearbeat(Message *msg);
void ProcessDataMsg(Message *msg);
//更新本地NodeID
void UpdateLocalID(Message *msg, std::unordered_set<int> *deadnodes_set,
Meta *nodes, Meta *recovery_nodes);
const char *heartbeat_timeout_val =
Environment::Get()->find("PS_HEARTBEAT_TIMEOUT");
int heartbeat_timeout_ =
heartbeat_timeout_val ? atoi(heartbeat_timeout_val) : 0;
DISALLOW_COPY_AND_ASSIGN(Van);
};
0x03 初始化
Van对象的初始化函数作用就是依据本地节点类型的不同,做不同设置,从而启动端口,建立到scheduler的连结,启动接收消息线程,心跳线程等,这样就可以进行通信了。具体如下:
- 首先从环境变量中得到相关信息,比如scheduler 的 "ip,port"(这两个是预先设置的),本节点的角色(Worker/Server/Scheduler)等等,然后 初始化scheduler_这个成员变量;
- 如果本节点是 scheduler,则把 scheduler_ 赋值给 my_node_;
- 如果本节点不是 scheduler,则:
- 从系统中获取本节点的ip信息;
- 使用 GetAvailablePort 获取一个port;
- 使用 Bind 绑定一个端口;
- 调用 Connect 建立到 Scheduler 的连接(scheduler也连接到自己的那个预先设置的固定端口);
- 启动本地Node的接收消息线程
receiver_thread_
,执行Van::Receiving
; - 如果本节点不是 scheduler,给 Scheduler 发送一个 ADD_NODE 消息,这样可以将本地Node的信息告知Scheduler,即注册到 scheduler;
- 然后进入等待状态,等待Scheduler通知 Ready(scheduler 会等待所有节点都完成注册后,统一发送 ready); 注意,这里 scheduler 节点也会等,但是不影响 scheduler 节点 的 recevie 线程接受处理消息;
- Ready后启动心跳线程,建立到Scheduler的Heartbeat 连接;
关于7,8两点的进一步说明就是:
- 当worker和server节点绑定ip和port后,便向scheduler节点发送ADD_NODE message。
- 当 scheduler收到所有worker和server的ADD_NODE message后,则依次应答ADD_NODE message,
- 各个节点在此过程中通过原子变量ready_等待上述过程完成。
具体代码如下:
void Van::Start(int customer_id) {
// get scheduler info
start_mu_.lock();
if (init_stage == 0) {
// 初始化scheduler_这个成员变量
scheduler_.hostname = std::string(
CHECK_NOTNULL(Environment::Get()->find("DMLC_PS_ROOT_URI")));
scheduler_.port =
atoi(CHECK_NOTNULL(Environment::Get()->find("DMLC_PS_ROOT_PORT")));
scheduler_.role = Node::SCHEDULER;
scheduler_.id = kScheduler;
// 确认本节点是scheduler节点
is_scheduler_ = Postoffice::Get()->is_scheduler();
// get my node info
if (is_scheduler_) {
// 初始化本节点,因为是scheduler,所以直接赋值
my_node_ = scheduler_;
} else {
auto role = Postoffice::Get()->is_worker() ? Node::WORKER : Node::SERVER;
const char* nhost = Environment::Get()->find("DMLC_NODE_HOST");
std::string ip;
if (nhost) ip = std::string(nhost);
if (ip.empty()) {
const char* itf = Environment::Get()->find("DMLC_INTERFACE");
std::string interface;
if (itf) interface = std::string(itf);
if (interface.size()) {
GetIP(interface, &ip);
} else {
GetAvailableInterfaceAndIP(&interface, &ip);
}
}
int port = GetAvailablePort();
const char* pstr = Environment::Get()->find("PORT");
if (pstr) port = atoi(pstr);
my_node_.hostname = ip;
my_node_.role = role;
my_node_.port = port;
// cannot determine my id now, the scheduler will assign it later
// set it explicitly to make re-register within a same process possible
my_node_.id = Node::kEmpty;
my_node_.customer_id = customer_id;
}
// bind.
//绑定接口,把本节点绑定到ip:port这个socket上,理论来说这个函数就是初始化了receiver_
my_node_.port = Bind(my_node_, is_scheduler_ ? 0 : 40);
// connect to the scheduler
// 连接上scheduler_,由于本节点就是scheduler_,其实就是初始化senders_,由于发送的节点很多,所以这里是一个map<int,void*>
// 在这里就是senders_[1] = socket_1, socket_1中的body设置一点字符“ps1***”, 注意链接不是sendMsg
Connect(scheduler_);
// for debug use
if (Environment::Get()->find("PS_DROP_MSG")) {
drop_rate_ = atoi(Environment::Get()->find("PS_DROP_MSG"));
}
// start receiver
// 开启一个接收消息的线程,这里就是处理消息
receiver_thread_ =
std::unique_ptr<std::thread>(new std::thread(&Van::Receiving, this));
init_stage++;
}
start_mu_.unlock();
if (!is_scheduler_) {
// let the scheduler know myself
// worker和server节点会通过 ADD_NODE 消息把本地节点的信息告诉scheduler,比如角色,ip,port...
Message msg;
Node customer_specific_node = my_node_;
customer_specific_node.customer_id = customer_id;
msg.meta.recver = kScheduler;
msg.meta.control.cmd = Control::ADD_NODE;
msg.meta.control.node.push_back(customer_specific_node);
msg.meta.timestamp = timestamp_++;
Send(msg);
}
// wait until ready
// 等待 ready_ 从false变成true,当是scheduler的时候,必须要有等worker和server节点过来,不然一直都是阻塞在这,如果是 worker/server,则是等待 scheduler 发送系统allready消息。
while (!ready_.load()) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
start_mu_.lock();
if (init_stage == 1) {
// resender
if (Environment::Get()->find("PS_RESEND") &&
atoi(Environment::Get()->find("PS_RESEND")) != 0) {
int timeout = 1000;
if (Environment::Get()->find("PS_RESEND_TIMEOUT")) {
timeout = atoi(Environment::Get()->find("PS_RESEND_TIMEOUT"));
}
// 如果设置了超时重传,就初始化resender_这个变量
resender_ = new Resender(timeout, 10, this);
}
if (!is_scheduler_) {
// start heartbeat thread
// 初始化心跳线程
heartbeat_thread_ =
std::unique_ptr<std::thread>(new std::thread(&Van::Heartbeat, this));
}
init_stage++;
}
start_mu_.unlock();
}
0x04 接受消息
我们首先介绍后台线程是如何运行,然后会具体分析如何处理各种消息。
4.1 后台处理消息线程
ps-lite 启动了一个后台线程 receiver_thread_ 进行接受/处理消息。
// start receiver
receiver_thread_ =
std::unique_ptr<std::thread>(new std::thread(&Van::Receiving, this));
4.2 处理函数
receiver_thread_ 使用 Receiving 函数进行消息处理。
4.2.1 控制信息
除了传递参数的数据消息外,各个节点之间控制信息有:
- ADD_NODE:worker和server向shceduler进行节点注册;
- BARRIER:节点间的同步阻塞消息;
- HEARTBEAT:节点间的心跳信号;
- TERMINATE:节点退出信号;
- ACK:确认消息,ACK 类型只有启用了 Resender 类才会出现。
- EMPTY:push or pull;
因此在 Receiving 之中会调用 不同处理函数处理不同类型的消息:
- ProcessTerminateCommand :处理 TERMINATE;
- ProcessAddNodeCommand :处理 ADD_NODE;
- ProcessBarrierCommand :处理 BARRIER(在上文已经分析);
- ProcessHearbeat :处理 HEARTBEAT;
4.2.2 线程内全局变量
线程内有两个变量,因为其是在 while (true) 循环之外,所以属于线程内的全局变量,这点在阅读代码时候需要注意。
- nodes :只有 scheduler 在处理 ADD_NODE 时候会用到,存储目前 scheduler 内部拥有的所有 nodes;
- recovery_nodes :只有 scheduler 在处理 ADD_NODE 时候会用到,存储目前 scheduler 内部拥有的所有 recovery nodes(康复重启的节点);
4.2.3 具体实现
Receiving 逻辑如下:
- 调用 RecvMsg(派生类会实现)获取最新消息;
- 如果设定了采样,则进行 drop;
- 若设置了重传机制,则会检测此消息是否重复,利用 resender_->AddIncomming(msg) 来处理重复消息;
- 处理控制消息或者数据消息;
具体代码如下
void Van::Receiving() {
Meta nodes;
// 以下两个可以认为是全局变量
Meta recovery_nodes; // store recovery nodes 储存康复重启的节点
recovery_nodes.control.cmd = Control::ADD_NODE; // 康复重启节点的control.cmd 都设置为 ADD_NODE
while (true) {
Message msg;
int recv_bytes = RecvMsg(&msg); //利用receiver_ 变量拿到消息
// For debug, drop received message
if (ready_.load() && drop_rate_ > 0) {
unsigned seed = time(NULL) + my_node_.id;
if (rand_r(&seed) % 100 < drop_rate_) {
LOG(WARNING) << "Drop message " << msg.DebugString();
continue;
}
}
CHECK_NE(recv_bytes, -1);
recv_bytes_ += recv_bytes; //收到的字节数累加
if (Postoffice::Get()->verbose() >= 2) {
PS_VLOG(2) << msg.DebugString();
}
// duplicated message
if (resender_ && resender_->AddIncomming(msg)) continue; //重传确认机制
if (!msg.meta.control.empty()) { //如果是控制类型的消息
// control msg
auto& ctrl = msg.meta.control;
if (ctrl.cmd == Control::TERMINATE) {
ProcessTerminateCommand();
break;
} else if (ctrl.cmd == Control::ADD_NODE) {
ProcessAddNodeCommand(&msg, &nodes, &recovery_nodes); //当执行到这个位置的时候继续跳转
} else if (ctrl.cmd == Control::BARRIER) {
ProcessBarrierCommand(&msg);
} else if (ctrl.cmd == Control::HEARTBEAT) {
ProcessHearbeat(&msg); // 发回Heartbeat的ACK
} else {
LOG(WARNING) << "Drop unknown typed message " << msg.DebugString();
}
} else { //非控制类型的消息处理方式
ProcessDataMsg(&msg);
}
}
}
4.3 处理 ADD_NODE 消息
ADD_NODE 是 worker / server 用来向 scheduler 注册自身的控制消息。
4.3.1 注册逻辑
先回忆下注册基本思路。
- 当worker和server节点绑定ip和port后,便向scheduler节点发送ADD_NODE message。
- 当 scheduler收到所有worker和server的ADD_NODE message后则依次应答ADD_NODE message,注意,应答的也是 同类型ADD_NODE 消息。
- 各个节点(scheduler, worker, server)在此过程中通过原子变量ready_等待上述过程完成。
4.3.2 ProcessAddNodeCommand
ProcessAddNodeCommand 逻辑如下。
- 查出心跳包超时的id,转存到dead_set之中。
- 拿到收到消息里面的control信息。
- 调用 UpdateLocalID,在其中:
- 如果是新node,Scheduler记录这个新的node。
- 如果这个node是重启产生的,则将旧node的信息更新。
- 如果是 scheduler,则:
- 调用 ProcessAddNodeCommandAtScheduler 收到所有worker和server的ADD_NODE 的消息后进行节点id分配并应答,即 设定最新的所有node的rank并发送给所有Worker和Server。
- 如果不是 scheduler,说明 work & server 收到了 scheduler 回答的 ADD_NODE 消息,则:
- 如果自身是现有节点,则在 connected_nodes_ 之中不会找到这个新节点,则先有节点会调用 Connect 与新节点建立连接。
- 如果自身是新节点,则会连接所有现有节点(非同类型)。
- 在 connected_nodes_ 之中更新 全局节点信息,包括 global rank(本地Node的全局rank等信息是由receiver_thread_在这里获取);
- 最后设置 ready_ = true,即本节点也可以运行了,因为主线程会阻塞在其上。
具体代码如下:
void Van::ProcessAddNodeCommand(Message* msg, Meta* nodes,
Meta* recovery_nodes) {
auto dead_nodes = Postoffice::Get()->GetDeadNodes(heartbeat_timeout_);//查出心跳包超时的id
std::unordered_set<int> dead_set(dead_nodes.begin(), dead_nodes.end());//转存到dead_set之中
auto& ctrl = msg->meta.control; //拿到收到消息里面的control信息
UpdateLocalID(msg, &dead_set, nodes, recovery_nodes);
if (is_scheduler_) { // Scheduler 节点
ProcessAddNodeCommandAtScheduler(msg, nodes, recovery_nodes);
} else { // Worker & Server 节点
for (const auto& node : ctrl.node) {
std::string addr_str = node.hostname + ":" + std::to_string(node.port);
if (connected_nodes_.find(addr_str) == connected_nodes_.end()) {
// 现有节点会在自己连接之中查找这个新节点,发现现有连接中没有这个新节点
// 如果是新节点,则会连接现有节点(非同类型)
Connect(node); // 与新节点进行连接
connected_nodes_[addr_str] = node.id; // 加入已经连接的节点
}
if (!node.is_recovery && node.role == Node::SERVER) ++num_servers_;
if (!node.is_recovery && node.role == Node::WORKER) ++num_workers_;
}
ready_ = true;
}
}
4.3.3 UpdateLocalID
此函数作用是更新节点内部的node id 信息,也是分为两种情况,函数逻辑如下:
- 如果msg->meta.sender是Meta::kEmpty,即未设定,则处理此message的一定是Scheduler,会进入 if 分支。
- 如果目前 nodes 的control.node数目小于 "配置的server数目 + 配置的worker数目",则说明是系统启动阶段,将当前消息的node信息加入到 control.node 之中。
- 否则说明是系统运行阶段,应该是有些节点死掉重启后再次连接。那么,就从 nodes 的control.node 之中找到一个已经死掉的且节点role 与当前消息一致(同类型)的 node id,把这个 node id 赋给这个重启的节点。并且更新 nodes->control.node 和 recovery_nodes。
- 下面就是普通节点处理的逻辑:
- 即在 scheduler 传回来的所有节点信息中查找,目的是找到与自己的ip,port一致的节点。
- 如果找到,就更新本地节点信息(因为在本节点启动时候,并没有设置 node_id 这个信息,这个需要scheduler统一设置,从注释看,目的是为了使重新注册成为可能)。包括全局 rank 信息。
具体代码如下:
void Van::UpdateLocalID(Message* msg, std::unordered_set<int>* deadnodes_set,
Meta* nodes, Meta* recovery_nodes) {
auto& ctrl = msg->meta.control;
size_t num_nodes =
Postoffice::Get()->num_servers() + Postoffice::Get()->num_workers();
// assign an id
if (msg->meta.sender == Meta::kEmpty) { //如果sender未设定,则处理此message的一定是Scheduler
CHECK(is_scheduler_);
CHECK_EQ(ctrl.node.size(), 1); //msg中的control命令中的节点集合就是worker自己,所以就是1个节点
if (nodes->control.node.size() < num_nodes) { //没有到齐
nodes->control.node.push_back(ctrl.node[0]);
} else { //如果所有work和server到齐了,就进入else
// some node dies and restarts
CHECK(ready_.load());
for (size_t i = 0; i < nodes->control.node.size() - 1; ++i) {
const auto& node = nodes->control.node[i];
if (deadnodes_set->find(node.id) != deadnodes_set->end() &&
node.role == ctrl.node[0].role) {
auto& recovery_node = ctrl.node[0];
// assign previous node id
recovery_node.id = node.id;
recovery_node.is_recovery = true;
nodes->control.node[i] = recovery_node;
recovery_nodes->control.node.push_back(recovery_node);
break;
}
}
}
}
// update my id / 对普通的node,更新其rank,scheduler 节点不会起作用(因为找不到)。
// schedule发给此work节点的消息,如果发现本地的ip和port和消息中的某个一点重合,那么就把本地节点的ID(初始化时候没有ID,只是等于Empty)改为schedule发过来的 node id。
for (size_t i = 0; i < ctrl.node.size(); ++i) {
const auto& node = ctrl.node[i];
if (my_node_.hostname == node.hostname && my_node_.port == node.port) {
if (getenv("DMLC_RANK") == nullptr || my_node_.id == Meta::kEmpty) {
my_node_ = node;
std::string rank = std::to_string(Postoffice::IDtoRank(node.id));
#ifdef _MSC_VER
_putenv_s("DMLC_RANK", rank.c_str());
#else
setenv("DMLC_RANK", rank.c_str(), true);
#endif
}
}
}
}
4.3.4 ProcessAddNodeCommandAtScheduler
ProcessAddNodeCommandAtScheduler 是在 Scheduler 之内运行,是对控制类型消息的处理。
对于Scheduler节点来说,scheduler收到所有worker和server的ADD_NODE的消息后进行节点id分配并应答,即,需要设定 最新的所有node的 全局rank 并发送给所有Worker和Server。
- 当接受到所有 worker & server 的注册消息之后(
nodes->control.node.size() == num_nodes
):- 将节点按照 ip + port 组合排序。
- Scheduler 与所有注册的节点建立连接、更新心跳时间戳,给 scheduler所有连接的节点分配全局 rank。
- 向所有的worker和server发送ADD_NODE消息(携带scheduler之中的所有node信息)。
- 会把
ready_ = true
; 即 scheduler 是一个 ready 状态了,不管 worker 和 server 是否确认收到ADD_NODE消息。 - 而在接收端(worker & server)的,每一个本地Node的全局rank等信息是由接收端 receiver_thread_(其他函数)获取,就是得到了 scheduler 返回的这些 nodes 信息。
- 如果
!recovery_nodes->control.node.empty()
,这就表明是处理某些重启节点的注册行为:- 查出心跳包超时的id,转存到dead_set之中。
- 与重启节点建立连接(因为接收到了一个ADD_NODE),所以只与这个新重启节点建立连接即可(在代码中有
CHECK_EQ(recovery_nodes->control.node.size(), 1)
来确认重启节点为 1 个)。 - 更新重启节点的心跳。
- 因为新加入了重启节点,所以用一个发送达到两个目的:
- 向所有 recovery 的worker和server发送ADD_NODE消息(携带scheduler之中的目前所有node信息)。
- 向 alive 节点发送 recovery 节点信息。
- 这样,收到消息的节点会则分别与新节点相互建立连接;
具体代码如下:
void Van::ProcessAddNodeCommandAtScheduler(Message* msg, Meta* nodes,
Meta* recovery_nodes) {
recovery_nodes->control.cmd = Control::ADD_NODE;
time_t t = time(NULL);
size_t num_nodes =
Postoffice::Get()->num_servers() + Postoffice::Get()->num_workers();
// scheduler收到所有worker和server的ADD_NODE的消息后进行节点id分配并应答
if (nodes->control.node.size() == num_nodes) { // 节点收集完全
// sort the nodes according their ip and port, 根据IP和port给worker,server排个序
std::sort(nodes->control.node.begin(), nodes->control.node.end(),
[](const Node& a, const Node& b) {
return (a.hostname.compare(b.hostname) | (a.port < b.port)) > 0;
});
// assign node rank
for (auto& node : nodes->control.node) {
// 建立连接、更新心跳时间戳,给 scheduler所有连接的节点分配全局 rank。
std::string node_host_ip =
node.hostname + ":" + std::to_string(node.port);
if (connected_nodes_.find(node_host_ip) == connected_nodes_.end()) { //如果ip:port不存在van_中的话
CHECK_EQ(node.id, Node::kEmpty); //判断是不是初始化节点
int id = node.role == Node::SERVER
? Postoffice::ServerRankToID(num_servers_)
: Postoffice::WorkerRankToID(num_workers_); //如果是sever的话,就id产生一个id号,num_servers_初始化为0
node.id = id; //将这个新节点的id赋值为id
Connect(node); //连接这个新节点, 即建立一个socket, 然后senders_[id] = sender; 就是将目标id的socket存放起来后面使用
Postoffice::Get()->UpdateHeartbeat(node.id, t);//更新心跳包
connected_nodes_[node_host_ip] = id; //既然 worker, server 已经发message来了,scheduler要把这个节点作为已经链接的节点
} else {
int id = node.role == Node::SERVER
? Postoffice::ServerRankToID(num_servers_)
: Postoffice::WorkerRankToID(num_workers_);
shared_node_mapping_[id] = connected_nodes_[node_host_ip];
node.id = connected_nodes_[node_host_ip];
}
if (node.role == Node::SERVER) num_servers_++;//更新rank
if (node.role == Node::WORKER) num_workers_++;
}
nodes->control.node.push_back(my_node_); //把本节点放到里面
nodes->control.cmd = Control::ADD_NODE;
Message back;
back.meta = *nodes;
// 向所有的worker和server发送ADD_NODE消息
for (int r : Postoffice::Get()->GetNodeIDs(kWorkerGroup + kServerGroup)) {
int recver_id = r;
if (shared_node_mapping_.find(r) == shared_node_mapping_.end()) {
back.meta.recver = recver_id;
back.meta.timestamp = timestamp_++;
Send(back);
}
}
ready_ = true; //scheduler已经准备好了
} else if (!recovery_nodes->control.node.empty()) { // 节点没有收集完全
auto dead_nodes = Postoffice::Get()->GetDeadNodes(heartbeat_timeout_);//查出心跳包超时的id
std::unordered_set<int> dead_set(dead_nodes.begin(), dead_nodes.end());//转存到dead_set
// send back the recovery node
CHECK_EQ(recovery_nodes->control.node.size(), 1);
Connect(recovery_nodes->control.node[0]);
Postoffice::Get()->UpdateHeartbeat(recovery_nodes->control.node[0].id, t);
Message back;
for (int r : Postoffice::Get()->GetNodeIDs(kWorkerGroup + kServerGroup)) {
if (r != recovery_nodes->control.node[0].id &&
dead_set.find(r) != dead_set.end()) {
// do not try to send anything to dead node
continue;
}
// only send recovery_node to nodes already exist
// but send all nodes to the recovery_node
back.meta =
(r == recovery_nodes->control.node[0].id) ? *nodes : *recovery_nodes;
back.meta.recver = r;
back.meta.timestamp = timestamp_++;
Send(back);
}
}
}
此部分流程逻辑如下:
+
Scheduler | Worker
|
+ | +
| | |
| | |
v | |
Postoffice::Start +----> Van::Start | |
+ | |
| | |
| | |
v | |
Connect--do nothing | |
+ | v
| |
| | Postoffice::Start +-----> Van::Start
| | +
v | |
receiver_thread_ +---+ | |
+ | | v
| | | Connect--to scheduler
| | | +
| | | |
| | | |
| | | |
| | | v
| | | receiver_thread_ +----->+
| | | + |
| | | | |
| | | | |
| | | v |
| | <---------------------------------------+ Send |
| | | ADD_NODE + |
| v | | |
| | | |
| ProcessAddNodeCommand | | |
| + | | |
| | | | |
| | All nodes OK | | |
| | | | |
v | | | |
| set rank | | |
wait until ready | | | |
+ | | | |
| +----------------------------------------------------------------> |
| | | ADD_NODE response(nodes info) | |
| | | | ProcessAddNodeCommand
| | | v |
| | | |
| <--------------+ | wait until ready |
| ready_ = true | + |
| | | <---------------+
+-------------------+ v | |
| | +--------------------+ v
| | |
v | |
| v
Postoffice::Barrier |
| Postoffice::Barrier
+
手机如下,左侧是 Scheduler,右侧是 worker:
4.3.5 一个新加节点的序列
其互联过程可以分为3步:
第一步:worker/server节点初始化的时候,向schedular节点发送一个连接信息,假定自身是节点 2;
if (!is_scheduler_) {
// let the scheduler know myself
Message msg;
Node customer_specific_node = my_node_;
customer_specific_node.customer_id = customer_id;
msg.meta.recver = kScheduler;
msg.meta.control.cmd = Control::ADD_NODE;
msg.meta.control.node.push_back(customer_specific_node);
msg.meta.timestamp = timestamp_++;
Send(msg); //发送给schedular, 建立链接信息。
}
第二步:Scheduler 节点收到信息后,在 ProcessAddNodeCommandAtScheduler 之中,首先会和 节点 2 建立一个连接。会向所有已经和schedular建立连接的worker节点/server节点 广播此 "节点的加入信息“,并把 节点 2 请求连接的信息放入meta信息中。
// assign node rank
for (auto& node : nodes->control.node) {
std::string node_host_ip =
node.hostname + ":" + std::to_string(node.port);
if (connected_nodes_.find(node_host_ip) == connected_nodes_.end()) {
int id = node.role == Node::SERVER
? Postoffice::ServerRankToID(num_servers_)
: Postoffice::WorkerRankToID(num_workers_);
node.id = id;
Connect(node); // 连接这个新节点, 即建立一个socket, 然后senders_[id] = sender; 就是将目标id的socket存放起来后面使用
Postoffice::Get()->UpdateHeartbeat(node.id, t);
connected_nodes_[node_host_ip] = id;
} else {
int id = node.role == Node::SERVER
? Postoffice::ServerRankToID(num_servers_)
: Postoffice::WorkerRankToID(num_workers_);
shared_node_mapping_[id] = connected_nodes_[node_host_ip];
node.id = connected_nodes_[node_host_ip];
}
if (node.role == Node::SERVER) num_servers_++;
if (node.role == Node::WORKER) num_workers_++;
}
nodes->control.node.push_back(my_node_);
nodes->control.cmd = Control::ADD_NODE;
Message back;
back.meta = *nodes;
// 向所有已经和schedular建立连接的worker节点/server节点 广播此 "节点的加入信息“,并把 节点 2 请求连接的信息放入meta信息中。
for (int r : Postoffice::Get()->GetNodeIDs(kWorkerGroup + kServerGroup)) {
int recver_id = r;
if (shared_node_mapping_.find(r) == shared_node_mapping_.end()) {
back.meta.recver = recver_id;
back.meta.timestamp = timestamp_++;
Send(back);
}
}
第三步:现有worker/server节点收到这个命令后,在 ProcessAddNodeCommand 之中 会和 节点 2 形成连接。
for (const auto& node : ctrl.node) {
std::string addr_str = node.hostname + ":" + std::to_string(node.port);
if (connected_nodes_.find(addr_str) == connected_nodes_.end()) { // 现有连接中没有这个新节点
Connect(node); // 与新节点进行连接
connected_nodes_[addr_str] = node.id;
}
if (!node.is_recovery && node.role == Node::SERVER) ++num_servers_;
if (!node.is_recovery && node.role == Node::WORKER) ++num_workers_;
至此,整个过程就描述完了。每个新节点加入后,已经加入的节点都会通过schedular节点和这个新节点建立连接。
4.4 处理 HEARTBEAT 消息
我们接下来分析心跳机制。
4.4.1 心跳机制
为了记录网络的可达性,PS Lite 设计了心跳机制。具体而言:
- 每一个节点的 PostOffice 单例中维护了一个 MAP 结构,存储了心跳关联的节点的活跃信息。键为节点编号,值为上次收到其 HEARTBEAT 消息的时间戳。
- Worker/Server 只记录 Scheduler 的心跳,Scheduler 则记录所有节点的心跳。基于时间戳和心跳超时,可以输出所有的死亡节点。
- 每一个 Worker/Server 节点,会新建立一个心跳线程,每隔 PS_HEARTBEAT_INTERVAL 秒向 Scheduler 发送一条 HEARTBEAT 消息;
- Scheduler 节点收到后,响应一个 HEARTBEAT 消息。
- scheduler进行应答,通过当前时间与心跳包接收时间之差判断是否alive。
- Scheduler 会依据心跳节点的时间戳来判断死亡节点。如果新增的节点id在dead_node容器里,表示这个节点是重新恢复的;而新增节点通过schedular的中转与现有节点形成互联。
具体如下:
4.4.2 数据结构
std::unordered_map<int, time_t> heartbeats_ 就是存储了心跳关联的节点的活跃信息。键为节点编号,值为上次收到其 HEARTBEAT 消息的时间戳。
UpdateHeartbeat 会定期更新心跳。
void UpdateHeartbeat(int node_id, time_t t) {
std::lock_guard<std::mutex> lk(heartbeat_mu_);
heartbeats_[node_id] = t;
}
std::unordered_map<int, time_t> heartbeats_;
4.4.3 Worker / Server 发送心跳
在这两种节点中,启动了一个线程,每一个 Worker/Server 节点,每隔 PS_HEARTBEAT_INTERVAL 秒向 Scheduler 发送一条 HEARTBEAT 消息:
if (!is_scheduler_) {
// start heartbeat thread
heartbeat_thread_ =
std::unique_ptr<std::thread>(new std::thread(&Van::Heartbeat, this));
}
具体心跳函数是:
void Van::Heartbeat() {
const char* val = Environment::Get()->find("PS_HEARTBEAT_INTERVAL");
const int interval = val ? atoi(val) : kDefaultHeartbeatInterval;
while (interval > 0 && ready_.load()) {
std::this_thread::sleep_for(std::chrono::seconds(interval));
Message msg;
msg.meta.recver = kScheduler;
msg.meta.control.cmd = Control::HEARTBEAT;
msg.meta.control.node.push_back(my_node_);
msg.meta.timestamp = timestamp_++;
Send(msg);
}
}
4.4.4 Scheduler 节点处理心跳
Scheduler 节点收到后 HEARTBEAT 消息后,响应一个 HEARTBEAT 消息。UpdateHeartbeat 会定期更新心跳。
void Van::ProcessHearbeat(Message* msg) {
auto& ctrl = msg->meta.control;
time_t t = time(NULL);
for (auto& node : ctrl.node) {
Postoffice::Get()->UpdateHeartbeat(node.id, t);
if (is_scheduler_) {
Message heartbeat_ack;
heartbeat_ack.meta.recver = node.id;
heartbeat_ack.meta.control.cmd = Control::HEARTBEAT;
heartbeat_ack.meta.control.node.push_back(my_node_);
heartbeat_ack.meta.timestamp = timestamp_++;
// send back heartbeat
Send(heartbeat_ack);
}
}
}
4.4.5 死亡节点
Scheduler 在处理 ADD_NODE 消息时候,会看看是否已经有死亡节点,具体判通过当前时间戳与心跳包接收时间戳之差判断是否alive。
std::vector<int> Postoffice::GetDeadNodes(int t) {
std::vector<int> dead_nodes;
if (!van_->IsReady() || t == 0) return dead_nodes;
time_t curr_time = time(NULL);
const auto& nodes = is_scheduler_
? GetNodeIDs(kWorkerGroup + kServerGroup)
: GetNodeIDs(kScheduler);
{
std::lock_guard<std::mutex> lk(heartbeat_mu_);
for (int r : nodes) {
auto it = heartbeats_.find(r);
if ((it == heartbeats_.end() || it->second + t < curr_time)
&& start_time_ + t < curr_time) {
dead_nodes.push_back(r);
}
}
}
return dead_nodes;
}
逻辑如下:
+----------------------------------------------------+
| Scheduler |
| |
| |
| |
| heartbeats_ |
| |
| receiver_thread_+--------> ProcessHearbeat |
| ^ + ^ + |
| | | | | |
| | | | | |
| | | | | |
+----------------------------------------------------+
| | | |
| | | | RESPONSE
| | | +-------------------------------------+
| | | |
| | +-------------------------------+ |
| | | |
HEARTBEAT | | RESPONSE HEARTBEAT | |
| | | |
+-----------------------------------------+ +-----------------------------------------+
| Worker | | | | Server | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
| heartbeats_ | | | | heartbeats_ | | |
| + | | | + | |
| heartbeat_thread_+----> Heartbeat | | | heartbeat_thread_+--> Heartbeat | |
| | | | | |
| v | | v |
| receiver_thread_ +---> ProcessHearbeat | | receiver_thread_ +--> ProcessHearbeat |
| | | |
| | | |
| | | |
+-----------------------------------------+ +-----------------------------------------+
4.5 处理 TERMINATE 消息
ProcessTerminateCommand 会处理结束消息,具体就是设定 ready_ 为 false。
这样就预示着 Van 状态不对,不可以继续处理。
void Van::ProcessTerminateCommand() {
PS_VLOG(1) << my_node().ShortDebugString() << " is stopped";
ready_ = false;
}
inline bool IsReady() { return ready_; }
4.6 处理 ACK 消息
4.6.1 Ack机制
在分布式系统中,通信也是不可靠的,丢包、延时都是必须考虑的场景。PS Lite 设计了 Resender类来提高通信的可靠性,它引入了 ACK 机制。即:
- 每一个节点,对于收到的非 ACK/TERMINATE 消息,必须响应一个 ACK 消息。
- 每一个节点,对于发送的每一个非 ACK/TERMINATE 消息,必须在本地缓存下来。存储的数据结构是一个 MAP,根据消息的内容生产唯一的键。
- 每一个节点,对于收到的 ACK 消息,必须根据反馈的键从本地缓存中移除对应的消息。
- 每一个节点运行一个监控线程,每隔 PS_RESEND_TIMEOUT 毫秒检查一下本地缓存。根据每个消息的发送时间戳和当前时间,找出超时的消息进行重发,并累加其重试次数。
4.6.2 Resender类
定义如下,其中 send_buff_ 就是发送缓存,用来存储发送了的消息列表。acked_ 就是已经确认的消息。
class Resender {
std::thread* monitor_;
std::unordered_set<uint64_t> acked_;
std::atomic<bool> exit_{false};
std::mutex mu_;
int timeout_;
int max_num_retry_;
Van* van_;
using Time = std::chrono::milliseconds;
// the buffer entry
struct Entry {
Message msg;
Time send;
int num_retry = 0;
};
std::unordered_map<uint64_t, Entry> send_buff_;
};
4.6.3 监控线程
监控线程以及函数如下如下,就是被唤醒时候,从send_buff_(本地缓存)找到每个消息的发送时间戳和当前时间,找出超时的消息进行重发,并累加其重试次数。 :
monitor_ = new std::thread(&Resender::Monitoring, this);
void Monitoring() {
while (!exit_) {
std::this_thread::sleep_for(Time(timeout_));
std::vector<Message> resend;
Time now = Now();
mu_.lock();
for (auto& it : send_buff_) {
if (it.second.send + Time(timeout_) * (1+it.second.num_retry) < now) {
resend.push_back(it.second.msg);
++it.second.num_retry;
CHECK_LT(it.second.num_retry, max_num_retry_);
}
}
mu_.unlock();
for (const auto& msg : resend) van_->Send(msg);
}
}
4.6.4 发送时缓存
当 Van 发送消息时候,如果配置了重传,就调用AddOutgoing函数把消息加入到发送缓存。
int Van::Send(const Message& msg) {
int send_bytes = SendMsg(msg);
CHECK_NE(send_bytes, -1);
send_bytes_ += send_bytes;
if (resender_) resender_->AddOutgoing(msg);
if (Postoffice::Get()->verbose() >= 2) {
PS_VLOG(2) << msg.DebugString();
}
return send_bytes;
}
下面函数就是加入到发送缓存。
/**
* \brief add an outgoining message
*
*/
void AddOutgoing(const Message& msg) {
if (msg.meta.control.cmd == Control::ACK) return;
CHECK_NE(msg.meta.timestamp, Meta::kEmpty) << msg.DebugString();
auto key = GetKey(msg);
std::lock_guard<std::mutex> lk(mu_);
// already buffered, which often due to call Send by the monitor thread
if (send_buff_.find(key) != send_buff_.end()) return;
auto& ent = send_buff_[key];
ent.msg = msg;
ent.send = Now();
ent.num_retry = 0;
}
4.6.5 清除缓存
下面函数有两个作用:
- 检查是否是重复消息,则已经收到的确认消息;
- 如果是确认消息,则从发送缓存中清除。
/**
* \brief add an incomming message
* \brief return true if msg has been added before or a ACK message
*/
bool AddIncomming(const Message& msg) {
// a message can be received by multiple times
if (msg.meta.control.cmd == Control::TERMINATE) {
return false;
} else if (msg.meta.control.cmd == Control::ACK) {
mu_.lock();
auto key = msg.meta.control.msg_sig;
auto it = send_buff_.find(key);
if (it != send_buff_.end()) send_buff_.erase(it);
mu_.unlock();
return true;
} else {
mu_.lock();
auto key = GetKey(msg);
auto it = acked_.find(key);
bool duplicated = it != acked_.end();
if (!duplicated) acked_.insert(key);
mu_.unlock();
// send back ack message (even if it is duplicated)
Message ack;
ack.meta.recver = msg.meta.sender;
ack.meta.sender = msg.meta.recver;
ack.meta.control.cmd = Control::ACK;
ack.meta.control.msg_sig = key;
van_->Send(ack);
// warning
if (duplicated) LOG(WARNING) << "Duplicated message: " << msg.DebugString();
return duplicated;
}
}
4.7 处理数据消息
ProcessDataMsg 用来处理 worker 发过来的数据消息(就是worker向server更新梯度),具体是取得对应的Customer后,调用 Customer 的方法进行处理,直接将msg
放入处理队列中。
我们会放在 Customer 之中进行介绍。
void Van::ProcessDataMsg(Message* msg) {
// data msg
int app_id = msg->meta.app_id;
int customer_id =
Postoffice::Get()->is_worker() ? msg->meta.customer_id : app_id;
auto* obj = Postoffice::Get()->GetCustomer(app_id, customer_id, 5);
obj->Accept(*msg); // 这里给 Customer 添加消息
}
0x05 ZMQVan
ZMQVan是基于zeromq的Van的实现,即为用zmq库实现了连接的底层细节(zmq库是一个开源库,对socket进行了优良的封装,他使得Socket编程更加简单、简洁和性能更高)。
5.1 定义
ZMQVan定义如下:
ZMQVan 继承于Van ,在这个类的基础上加了两个成员变量,分别是:
- unordered_map<int, void*> senders_ :senders_是一个集合,就是本节点发送 socket 的集合,即node id 与 socket 的映射。比如 8号节点要给9号节点发消息,那么只要找到(9,socket_9)这个组合就行了,然后调用 socket_9.send(message),
- void *receiver_ = nullptr :是 Bind 函数得到的 socket 连接,因为是接受端,所以只有一个 socket 就行。
具体如下:
class ZMQVan : public Van {
void *context_ = nullptr;
/**
* \brief node_id to the socket for sending data to this node
*/
std::unordered_map<int, void*> senders_;
std::mutex mu_;
void *receiver_ = nullptr;
};
5.2 Van 函数
Van类 有如下函数会调用到 ZMQVan 或者被 ZMQVan 调用。
5.2.1 发送消息
Send 函数就是调用 ZMQVan 的 SendMsg 函数进行发送消息,发送之后如果设定了ACK机制,则会调用 resender_->AddOutgoing。
int Van::Send(const Message& msg) {
int send_bytes = SendMsg(msg);
CHECK_NE(send_bytes, -1);
send_bytes_ += send_bytes;
if (resender_) resender_->AddOutgoing(msg);
if (Postoffice::Get()->verbose() >= 2) {
PS_VLOG(2) << msg.DebugString();
}
return send_bytes;
}
5.2.2 Meta 类
Meta封装了元数据,发送者,接受者,时间戳,请求还是响应等。
/**
* \brief meta info of a message
*/
struct Meta {
/** \brief the empty value */
static const int kEmpty;
/** \brief an int head */
int head;
/** \brief the unique id of the application of messsage is for*/
int app_id;
/** \brief customer id*/
int customer_id;
/** \brief the timestamp of this message */
int timestamp;
/** \brief the node id of the sender of this message */
int sender;
/** \brief the node id of the receiver of this message */
int recver;
/** \brief whether or not this is a request message*/
bool request;
/** \brief whether or not a push message */
bool push;
/** \brief whether or not a pull message */
bool pull;
/** \brief whether or not it's for SimpleApp */
bool simple_app;
/** \brief an string body */
std::string body;
/** \brief data type of message.data[i] */
std::vector<DataType> data_type;
/** \brief system control message */
Control control;
/** \brief the byte size */
int data_size = 0;
/** \brief message priority */
int priority = 0;
};
为了缓解通信压力,ps-lite 使用了Protobuf对 Meta 进行数据压缩。
5.2.3 压缩 Meta
就是按照 protobuf 来进行数据压缩。
void Van::PackMeta(const Meta& meta, char** meta_buf, int* buf_size) {
// convert into protobuf
PBMeta pb;
pb.set_head(meta.head);
if (meta.app_id != Meta::kEmpty) pb.set_app_id(meta.app_id);
if (meta.timestamp != Meta::kEmpty) pb.set_timestamp(meta.timestamp);
if (meta.body.size()) pb.set_body(meta.body);
pb.set_push(meta.push);
pb.set_pull(meta.pull);
pb.set_request(meta.request);
pb.set_simple_app(meta.simple_app);
pb.set_priority(meta.priority);
pb.set_customer_id(meta.customer_id);
for (auto d : meta.data_type) pb.add_data_type(d);
if (!meta.control.empty()) {
auto ctrl = pb.mutable_control();
ctrl->set_cmd(meta.control.cmd);
if (meta.control.cmd == Control::BARRIER) {
ctrl->set_barrier_group(meta.control.barrier_group);
} else if (meta.control.cmd == Control::ACK) {
ctrl->set_msg_sig(meta.control.msg_sig);
}
for (const auto& n : meta.control.node) {
auto p = ctrl->add_node();
p->set_id(n.id);
p->set_role(n.role);
p->set_port(n.port);
p->set_hostname(n.hostname);
p->set_is_recovery(n.is_recovery);
p->set_customer_id(n.customer_id);
}
}
// to string
*buf_size = pb.ByteSize();
*meta_buf = new char[*buf_size + 1];
CHECK(pb.SerializeToArray(*meta_buf, *buf_size))
<< "failed to serialize protobuf";
}
5.2.3 解压 UnpackMeta
按照protobuf 预先生成的 PBMeta 格式进行解压。
void Van::UnpackMeta(const char* meta_buf, int buf_size, Meta* meta) {
// to protobuf
PBMeta pb;
CHECK(pb.ParseFromArray(meta_buf, buf_size))
<< "failed to parse string into protobuf";
// to meta
meta->head = pb.head();
meta->app_id = pb.has_app_id() ? pb.app_id() : Meta::kEmpty;
meta->timestamp = pb.has_timestamp() ? pb.timestamp() : Meta::kEmpty;
meta->request = pb.request();
meta->push = pb.push();
meta->pull = pb.pull();
meta->simple_app = pb.simple_app();
meta->priority = pb.priority();
meta->body = pb.body();
meta->customer_id = pb.customer_id();
meta->data_type.resize(pb.data_type_size());
for (int i = 0; i < pb.data_type_size(); ++i) {
meta->data_type[i] = static_cast<DataType>(pb.data_type(i));
}
if (pb.has_control()) {
const auto& ctrl = pb.control();
meta->control.cmd = static_cast<Control::Command>(ctrl.cmd());
meta->control.barrier_group = ctrl.barrier_group();
meta->control.msg_sig = ctrl.msg_sig();
for (int i = 0; i < ctrl.node_size(); ++i) {
const auto& p = ctrl.node(i);
Node n;
n.role = static_cast<Node::Role>(p.role());
n.port = p.port();
n.hostname = p.hostname();
n.id = p.has_id() ? p.id() : Node::kEmpty;
n.is_recovery = p.is_recovery();
n.customer_id = p.customer_id();
meta->control.node.push_back(n);
}
} else {
meta->control.cmd = Control::EMPTY;
}
}
5.2.4 PackMetaPB
PackMetaPB 从注释看,是字节跳动提交的,主要用于 ibverbs_van.h,所以我们不做深入研究。
void Van::PackMetaPB(const Meta& meta, PBMeta* pb) {
pb->set_head(meta.head);
if (meta.app_id != Meta::kEmpty) pb->set_app_id(meta.app_id);
if (meta.timestamp != Meta::kEmpty) pb->set_timestamp(meta.timestamp);
if (meta.body.size()) pb->set_body(meta.body);
pb->set_push(meta.push);
pb->set_request(meta.request);
pb->set_simple_app(meta.simple_app);
pb->set_priority(meta.priority);
pb->set_customer_id(meta.customer_id);
for (auto d : meta.data_type) pb->add_data_type(d);
if (!meta.control.empty()) {
auto ctrl = pb->mutable_control();
ctrl->set_cmd(meta.control.cmd);
if (meta.control.cmd == Control::BARRIER) {
ctrl->set_barrier_group(meta.control.barrier_group);
} else if (meta.control.cmd == Control::ACK) {
ctrl->set_msg_sig(meta.control.msg_sig);
}
for (const auto& n : meta.control.node) {
auto p = ctrl->add_node();
p->set_id(n.id);
p->set_role(n.role);
p->set_port(n.port);
p->set_hostname(n.hostname);
p->set_is_recovery(n.is_recovery);
p->set_customer_id(n.customer_id);
}
}
pb->set_data_size(meta.data_size);
}
5.3 ZMQVan 派生函数
ZMQVan 有如下重要的派生函数。
5.3.1 Bind
Bind 逻辑如下:
- 使用 zmq_bind() 来把一个socket绑定在一个本地的网络节点(endpoint)上,然后开始接收发送到本节点上的消息。
- 节点地址信息是一个字符串,它包括一个协议 / 然后跟着一个address。
- Bind 函数会依据配置的变量 "DMLC_LOCAL" 来决定是启用 ipc 方式还是 tcp 方式,从而配置节点地址信息。
- 如果是 schedule节点调用,则不需要指定port,但是对于work和server需要自己查找一个本地可用端口。
- 在查找端口时候,会设置最大重试次数。
int Bind(const Node& node, int max_retry) override {
receiver_ = zmq_socket(context_, ZMQ_ROUTER);
int local = GetEnv("DMLC_LOCAL", 0);
std::string hostname = node.hostname.empty() ? "*" : node.hostname;
int use_kubernetes = GetEnv("DMLC_USE_KUBERNETES", 0);
if (use_kubernetes > 0 && node.role == Node::SCHEDULER) {
hostname = "0.0.0.0";
}
std::string addr = local ? "ipc:///tmp/" : "tcp://" + hostname + ":";
int port = node.port;
unsigned seed = static_cast<unsigned>(time(NULL) + port);
for (int i = 0; i < max_retry + 1; ++i) {
auto address = addr + std::to_string(port);
if (zmq_bind(receiver_, address.c_str()) == 0) break;
if (i == max_retry) {
port = -1;
} else {
port = 10000 + rand_r(&seed) % 40000;
}
}
return port;
}
5.3.2 Connect
主要就是初始化 Sender_,逻辑如下:
- 如果找到了对应socket就关闭socket。
- 如果发现是 worker 发给同类,或者 server 发给同类,并且不是自己发给自己(Scheduler 可以自己发给自己),则返回。
- 建立一个ZMQ套接字(socket),并且以一个不透明指针的形式把这新创建的socket赋值给 sender。
- 如果本身是scheduler,则配置socket,把自己的 id 绑定到 socket上。
- 将sender这个socket和目标地址连接。
- 将目标id的socket存放起来,即把 socket 加入到Sender_。
具体如下:
void Connect(const Node& node) override {
int id = node.id;
auto it = senders_.find(id);
if (it != senders_.end()) {
zmq_close(it->second); // 如果找到了对应socket就关闭socket
}
// worker doesn't need to connect to the other workers. same for server
if ((node.role == my_node_.role) && (node.id != my_node_.id)) {
return;
}
void *sender = zmq_socket(context_, ZMQ_DEALER); //建立一个socket
//如果本身是scheduler,则一开始就是知道自己的id = 1,所以这个if条件就是说把自己的id绑定到socket上
if (my_node_.id != Node::kEmpty) {
std::string my_id = "ps" + std::to_string(my_node_.id);
zmq_setsockopt(sender, ZMQ_IDENTITY, my_id.data(), my_id.size());
const char* watermark = Environment::Get()->find("DMLC_PS_WATER_MARK");
if (watermark) {
const int hwm = atoi(watermark);
zmq_setsockopt(sender, ZMQ_SNDHWM, &hwm, sizeof(hwm));
}
}
// connect
std::string addr = "tcp://" + node.hostname + ":" + std::to_string(node.port);
if (GetEnv("DMLC_LOCAL", 0)) {
addr = "ipc:///tmp/" + std::to_string(node.port);
}
if (zmq_connect(sender, addr.c_str()) != 0) { //将sender这个socket和目标地址连接
LOG(FATAL) << "connect to " + addr + " failed: " + zmq_strerror(errno);
}
senders_[id] = sender; //将目标id的socket存放起来后面使用
}
5.3.3 SendMsg
逻辑如下:
- 从保存的 sender_ 之中找到之前保留的socket;
- 压缩 meta;
- 发送 meta;
- 循环分段发送data;
int SendMsg(const Message& msg) override {
std::lock_guard<std::mutex> lk(mu_);
// find the socket
int id = msg.meta.recver;
CHECK_NE(id, Meta::kEmpty);
auto it = senders_.find(id);
if (it == senders_.end()) {
LOG(WARNING) << "there is no socket to node " << id;
return -1;
}
void *socket = it->second;
// send meta
int meta_size; char* meta_buf;
PackMeta(msg.meta, &meta_buf, &meta_size);
int tag = ZMQ_SNDMORE;
int n = msg.data.size();
if (n == 0) tag = 0;
zmq_msg_t meta_msg;
zmq_msg_init_data(&meta_msg, meta_buf, meta_size, FreeData, NULL);
while (true) {
if (zmq_msg_send(&meta_msg, socket, tag) == meta_size) break;
if (errno == EINTR) continue;
return -1;
}
// zmq_msg_close(&meta_msg);
int send_bytes = meta_size;
// send data
for (int i = 0; i < n; ++i) {
zmq_msg_t data_msg;
SArray<char>* data = new SArray<char>(msg.data[i]);
int data_size = data->size();
zmq_msg_init_data(&data_msg, data->data(), data->size(), FreeData, data);
if (i == n - 1) tag = 0;
while (true) {
if (zmq_msg_send(&data_msg, socket, tag) == data_size) break;
if (errno == EINTR) continue;
return -1;
}
// zmq_msg_close(&data_msg);
send_bytes += data_size;
}
return send_bytes;
}
5.3.4 RecvMsg
RecvMsg 就是在绑定的端口上接受消息。
接受消息时候,会判断是第几个消息,然后做不同的处理。
int RecvMsg(Message* msg) override {
msg->data.clear();
size_t recv_bytes = 0;
for (int i = 0; ; ++i) {
zmq_msg_t* zmsg = new zmq_msg_t;
CHECK(zmq_msg_init(zmsg) == 0) << zmq_strerror(errno);
while (true) {
if (zmq_msg_recv(zmsg, receiver_, 0) != -1) break;
if (errno == EINTR) {
std::cout << "interrupted";
continue;
}
return -1;
}
char* buf = CHECK_NOTNULL((char *)zmq_msg_data(zmsg));
size_t size = zmq_msg_size(zmsg);
recv_bytes += size;
if (i == 0) {
// identify
msg->meta.sender = GetNodeID(buf, size);
msg->meta.recver = my_node_.id;
CHECK(zmq_msg_more(zmsg));
zmq_msg_close(zmsg);
delete zmsg;
} else if (i == 1) {
// task
UnpackMeta(buf, size, &(msg->meta));
zmq_msg_close(zmsg);
bool more = zmq_msg_more(zmsg);
delete zmsg;
if (!more) break;
} else {
// zero-copy
SArray<char> data;
data.reset(buf, size, [zmsg, size](char* buf) {
zmq_msg_close(zmsg);
delete zmsg;
});
msg->data.push_back(data);
if (!zmq_msg_more(zmsg)) {
break;
}
}
}
return recv_bytes;
}
GetNodeID 函数是
/**
* return the node id given the received identity
* \return -1 if not find
*/
int GetNodeID(const char* buf, size_t size) {
if (size > 2 && buf[0] == 'p' && buf[1] == 's') {
int id = 0;
size_t i = 2;
for (; i < size; ++i) {
if (buf[i] >= '0' && buf[i] <= '9') {
id = id * 10 + buf[i] - '0';
} else {
break;
}
}
if (i == size) return id;
}
return Meta::kEmpty;
}
0x06 总结
我们最后进行一下总结:
邮局里有了地址簿,就需要有货车来负责拉送物件,Van 就是整个Parameter Server的通信模块,其特点如下。
- PostOffice 类在实例化的时候,会创建一个 Van 类的实例 作为成员变量。该 Van 实例与所属 PostOffice 实例生命周期相同(每个节点只有一个该对象);
- Van 负责具体的节点间通信。具体来说就是负责建立起节点之间的互相连接(例如Worker与Scheduler之间的连接),并且开启本地的receiving thread用来监听收到的message。
- Van对象的初始化函数作用就是依据本地节点类型的不同,做不同设置,从而启动端口,建立本地节点到scheduler的连结,启动接收消息线程,心跳线程等,这样就可以进行通信了。
- Parameter Server在后台线程 receiver_thread_ 进行接受/处理消息。除了传递参数的数据消息外,各个节点之间控制信息有:
- ADD_NODE:worker和server向shceduler进行节点注册;
- BARRIER:节点间的同步阻塞消息;
- HEARTBEAT:节点间的心跳信号;
- TERMINATE:节点退出信号;
- ACK:确认消息,ACK 类型只有启用了 Resender 类才会出现。
- EMPTY:push or pull;
0xEE 个人信息
★★★★★★关于生活和技术的思考★★★★★★
微信公众账号:罗西的思考
如果您想及时得到个人撰写文章的消息推送,或者想看看个人推荐的技术资料,敬请关注。
0xFF 参考
[源码解析] 机器学习参数服务器ps-lite(2) ----- 通信模块Van的更多相关文章
- [源码解析] 机器学习参数服务器ps-lite (1) ----- PostOffice
[源码解析] 机器学习参数服务器ps-lite 之(1) ----- PostOffice 目录 [源码解析] 机器学习参数服务器ps-lite 之(1) ----- PostOffice 0x00 ...
- [源码解析] 机器学习参数服务器ps-lite 之(3) ----- 代理人Customer
[源码解析] 机器学习参数服务器ps-lite 之(3) ----- 代理人Customer 目录 [源码解析] 机器学习参数服务器ps-lite 之(3) ----- 代理人Customer 0x0 ...
- [源码解析]机器学习参数服务器ps-lite(4) ----- 应用节点实现
[源码解析]机器学习参数服务器ps-lite(4) ----- 应用节点实现 目录 [源码解析]机器学习参数服务器ps-lite(4) ----- 应用节点实现 0x00 摘要 0x01 基础类 1. ...
- [源码解析] 机器学习参数服务器 Paracel (2)--------SSP控制协议实现
[源码解析] 机器学习参数服务器 Paracel (2)-----SSP实现 目录 [源码解析] 机器学习参数服务器 Paracel (2)-----SSP实现 0x00 摘要 0x01 背景知识 1 ...
- [源码解析] 机器学习参数服务器 Paracel (1)-----总体架构
[源码解析] 机器学习参数服务器 Paracel (1)-----总体架构 目录 [源码解析] 机器学习参数服务器 Paracel (1)-----总体架构 0x00 摘要 0x01使用 1.1 配置 ...
- [源码解析] 机器学习参数服务器Paracel (3)------数据处理
[源码解析] 机器学习参数服务器Paracel (3)------数据处理 目录 [源码解析] 机器学习参数服务器Paracel (3)------数据处理 0x00 摘要 0x01 切分需要 1.1 ...
- [源码解析] 深度学习流水线并行 PipeDream(5)--- 通信模块
[源码解析] 深度学习流水线并行 PipeDream(5)--- 通信模块 目录 [源码解析] 深度学习流水线并行 PipeDream(5)--- 通信模块 0x00 摘要 0x01 前言 0x02 ...
- springMVC源码解析--HandlerMethodArgumentResolverComposite参数解析器集合(二)
上一篇博客springMVC源码分析--HandlerMethodArgumentResolver参数解析器(一)中我们已经介绍了参数解析相关的东西,并且也提到了HandlerMethodArgume ...
- 【Hibernate实战】源码解析Hibernate参数绑定及PreparedStatement防SQL注入原理
本文采用mysql驱动是5.1.38版本. 本篇文章涉及内容比较多,单就Hibernate来讲就很大,再加上数据库驱动和数据库相关,非一篇文章或一篇专题就能说得完.本文从使用入手在[Spr ...
随机推荐
- Visual Studio 2019本地不能运行Azure Functions
最近一个项目,需要维护同事写得代码,主要是一堆基于 .net core 3.1 的 Azure Functions.想起2年前第一次接触 Azure Functions(那次是基于.net frame ...
- 详解apollo的设计与使用
简介 apollo 是一款由携程团队开发的配置中心,可以实现配置的集中管理.分环境管理.即时生效等等.在这篇博客中,我们可以了解到: 为什么使用配置中心 如何设计一个配置中心 apollo 是如何设计 ...
- Manacher(马拉车)————O(n)回文子串
Manacher 一.背景 1975年,Manacher发明了Manacher算法(中文名:马拉车算法),是一个可以在O(n)的复杂度中返回字符串s中最长回文子串长度的算法,十分巧妙. 让我们举个栗子 ...
- 复习Spring第一课--Spring的基本知识及使用
关于Spring: spring容器是Spring的核心,该容器负责管理spring中的java组件, ApplicationContext ctx = new ClassPathXmlApplic ...
- 我用段子讲.NET之依赖注入其一
<我用段子讲.NET之依赖注入其一> 1) 西城的某个人工湖畔,湖水清澈见底,湖畔柳树成荫.人工湖往北,坐落着两幢写字楼,水晶大厦靠近地铁站,由于为了与湖面天际线保持一致,楼层只有26层高 ...
- 用户RFM模型及应用
RMF含义 R(Recency)(用户粘性,越小越好):用户最近一次交易时间的间隔.R值越大,表示用户交易发生的日期越久,反之则表示用户交易发生的日期越近 F(Frequency)(用户忠诚度,越大越 ...
- JS replace 替换全部数据
(1)使用具有全局标志g的正则表达式 var str = "dogdogdog"; var str2 = str.replace(/dog/g,"cat");/ ...
- salesforce零基础学习(一百零五)Change Data Capture
本篇参考: https://developer.salesforce.com/docs/atlas.en-us.232.0.api_streaming.meta/api_streaming/using ...
- MySQL主从异常恢复
说明 MySQL主从出现不同步的情况时,或者要添加新的从库时,可以使用以下方法进行操作回复主从. 停止业务应用 停止所有连接到主从库上的应用,在恢复主从期间禁止任何增删改等操作,否则恢复失败 停止主从 ...
- Java:java获取Linux下的路径
指定Linux的路径 //Linux系统路径 StringBuilder sb = new StringBuilder(File.separator); String Url = sb.append( ...