本文主要基于MXNet1.6.0版本进行分析。

上一篇文章中,我们分析了MXNet中KVStore的进程内通信机制。在这篇文章中,我们主要分析KVStore如何进行多节点分布式通信。

在KVStore的实现中,KVStoreDistKVStoreDistServer分别对应参数服务器中的worker节点与server节点。KVStoreDist继承自KVStoreLocal,通过封装PS-Lite中的KVWorker实现了PushPull等接口,从而向server发送各类请求;而KVStoreDistServer则封装了PS-Lite中的KVServer,用来处理并响应worker发来的各类请求。

worker端执行逻辑

worker创建

KVStoreDist的构造函数为每个worker节点创建一个ps::KVWorker<char>类型的对象。如果当前worker节点不是一个recovery的节点,那么就阻塞到所有的worker和server启动。

  1. explicit KVStoreDist(bool use_device_comm)
  2. : KVStoreLocal(use_device_comm), ps_worker_(nullptr), server_(nullptr) {
  3. if (IsWorkerNode()) {
  4. int new_customer_id = GetNewCustomerId();
  5. ps_worker_ = new ps::KVWorker<char>(0, new_customer_id);
  6. ps::StartAsync(new_customer_id, "mxnet\0");
  7. if (!ps::Postoffice::Get()->is_recovery()) {
  8. ps::Postoffice::Get()->Barrier(
  9. new_customer_id,
  10. ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler);
  11. }
  12. }
  13. bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 1000 * 1000);
  14. log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false);
  15. }

worker的初始化过程

在初始化时,每个worker首先检查key的唯一性,随后调用comm_->Init为每个key初始化进行本地通信的资源。本地初始化完成后,worker0把自己本地的权重发送给所有的server。worker0在其push操作完成后,会将数据写入到comm_buf_compr_buf_这两个缓冲区中。

  1. void InitImpl(const std::vector<int>& keys,
  2. const std::vector<NDArray>& values) override {
  3. CheckUnique(keys);
  4. for (size_t i = 0; i < keys.size(); ++i) {
  5. comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
  6. }
  7. if (get_rank() == 0 && this->ps_worker_->get_customer()->customer_id() == 0) {
  8. Push_(keys, values, 0, false);
  9. // wait until the push is finished
  10. for (const int key : keys) {
  11. comm_buf_[key].WaitToWrite();
  12. compr_buf_[key].WaitToWrite();
  13. }
  14. } else {
  15. // do nothing
  16. }
  17. if (!ps::Postoffice::Get()->is_recovery()) {
  18. Barrier();
  19. }
  20. }

worker发送控制消息

worker端通过SendCommandToServers函数向server端发送控制消息。例如,在KVStoreDist的析构函数中有如下代码,用来从worker0节点向所有server节点发送一个终止的命令。

  1. if (get_rank() == 0 && ps_worker_->get_customer()->customer_id() == 0) {
  2. // stop the executor at servers
  3. SendCommandToServers(static_cast<int>(CommandType::kStopServer), "");
  4. }

worker发送数据消息

worker会调用Push_函数向server发送数据请求,它的核心逻辑如下所示(省略部分代码)。与之前提到的本地通信类似,在向server节点发送数据之前,会先调用GroupPairsPush把具有相同key的value汇总到一个vector中。对于每个key,先在本地进行一次Reduce操作聚合所有设备上的梯度,并将结果存放到comm_buf中。随后,通过EncodeDefaultKey把key和value编码成PS-Lite支持的数据结构,再调用PushDefault把对应的数据发送出去。

  1. void KVStoreDist::Push_(const std::vector<int>& keys,
  2. const std::vector<NDArray>& values,
  3. int priority,
  4. bool do_merge) {
  5. std::vector<int> uniq_keys;
  6. std::vector<std::vector<NDArray>> grouped_val;
  7. GroupKVPairsPush(keys, values, &uniq_keys, &grouped_val, false);
  8. for (size_t i = 0; i < uniq_keys.size(); ++i) {
  9. int key = uniq_keys[i];
  10. const auto& vals = grouped_vals[i];
  11. NDArray merged = do_merge ? comm_->Reduce(key, vals, priority) : vals[0];
  12. auto &comm_buf = comm_buf_[key];
  13. if (merged.ctx().dev_mask() == cpu::kDevMask) {
  14. // Start of a push doesn't guarantee that the previous pushes are completed.
  15. // This shouldn't affect training of networks though because training involves
  16. // a sequence of push, pull, then push. This imposes ordering that the
  17. // second push happens after the first pull, and the pull happens after first push.
  18. comm_buf = merged; // avoid memory copy
  19. } else {
  20. if (comm_buf.is_none()) {
  21. comm_buf = NDArray(merged.shape(), pinned_ctx_, true, merged.dtype());
  22. }
  23. CopyFromTo(merged, &comm_buf);
  24. }
  25. const int dtype = merged.dtype();
  26. const int num_bytes = mshadow::mshadow_sizeof(dtype);
  27. PSKV& pskv = EncodeDefaultKey(key, comm_buf.shape().Size(), num_bytes);
  28. PushDefault(key, comm_buf, pskv, priority);
  29. }
  30. }

PushDefault会调用ps_worker_->ZPush来完成梯度的发送,梯度发送以及发送之前的一些准备操作都被封装到一个lambda表达式中,这个lambda表达式随后被压入到MXNet后端的依赖引擎中等待执行。

  1. void PushDefault(int key, const NDArray &send_buf, const PSKV& pskv, int priority) {
  2. auto push_to_servers =
  3. [this, key, pskv, send_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
  4. const int dtype = send_buf.dtype();
  5. // convert to ps keys
  6. const size_t size = send_buf.shape().Size() * mshadow::mshadow_sizeof(dtype);
  7. char* data = static_cast<char *>(send_buf.data().dptr_);
  8. // do push. false means no delete
  9. ps::SArray<char> vals(data, size, false);
  10. int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype);
  11. CHECK_NOTNULL(ps_worker_)->ZPush(
  12. pskv.keys, vals, pskv.lens,
  13. cmd, [cb]() { cb(); });
  14. };
  15. Engine::Get()->PushAsync(
  16. push_to_servers,
  17. pinned_ctx_,
  18. {send_buf.var()},
  19. {},
  20. FnProperty::kNormal,
  21. priority,
  22. "KVStoreDistDefaultPush");
  23. }

Pull操作的过程如下所示。在准备工作完成后,调用ps_server_->ZPull完成权重的拉取,最后在本地执行Broadcast把从server端拉回的权重广播到所有设备上。

  1. void PullImpl(const std::vector<int>& keys,
  2. const std::vector<NDArray*>& values,
  3. int priority, bool ignore_sparse) override {
  4. CHECK(ignore_sparse) << "dist kvstore pull doesn't support ignore_sparse=False";
  5. std::vector<int> uniq_keys;
  6. std::vector<std::vector<NDArray*> > grouped_vals;
  7. GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals, true);
  8. for (size_t i = 0; i < uniq_keys.size(); ++i) {
  9. int key = uniq_keys[i];
  10. // use the same array for merging to guarantee that pull always happens
  11. // after the previous push on this key
  12. auto& recv_buf = comm_buf_[key];
  13. const auto storage_type = grouped_vals[i][0]->storage_type();
  14. CHECK_EQ(storage_type, kDefaultStorage)
  15. << "Expected stype of value to be kDefaultStorage";
  16. if (recv_buf.is_none()) {
  17. // it may happen for the first time a no-rank-0 worker pull the weight.
  18. recv_buf = NDArray(grouped_vals[i][0]->shape(), pinned_ctx_,
  19. true, grouped_vals[i][0]->dtype());
  20. }
  21. auto pull_from_servers = [this, key, recv_buf](
  22. RunContext rctx, Engine::CallbackOnComplete cb) {
  23. // convert to ps keys
  24. size_t size = recv_buf.shape().Size();
  25. const int dtype = recv_buf.dtype();
  26. const int num_bytes = mshadow::mshadow_sizeof(dtype);
  27. PSKV& pskv = EncodeDefaultKey(key, size, num_bytes) :
  28. char* data = static_cast<char*> (recv_buf.data().dptr_);
  29. // false means not to delete data when SArray is deleted
  30. auto vals = new ps::SArray<char>(data, size * num_bytes, false);
  31. // issue pull
  32. RequestType mode = RequestType::kDefaultPushPull;
  33. const int cmd = GetCommandType(mode, dtype);
  34. CHECK_NOTNULL(ps_worker_)->ZPull(
  35. pskv.keys, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); });
  36. };
  37. CHECK_NOTNULL(Engine::Get())->PushAsync(
  38. pull_from_servers,
  39. pinned_ctx_,
  40. {},
  41. {recv_buf.var()},
  42. FnProperty::kNormal,
  43. priority,
  44. "KVStoreDistDefaultStoragePull");
  45. comm_->Broadcast(key, recv_buf, grouped_vals[i], priority);
  46. }
  47. }

server端执行逻辑

server的创建以及启动

首先在KVStoreDistServer的构造函数中为ps_server_绑定处理命令请求的CommandHandle以及处理数据请求的DataHandleEx。注意到在绑定CommandHandle时,ps_server_被向上转型成ps::SimpleApp*类型。这是因为ps::SimpleApp中实现的set_request_handle只能接收包含两个形参的函数对象,而ps::KVServer继承了ps::SimpleApp并且重载了set_request_handle,使之可以接收包含三个形参的函数对象。这样一来,就完成了对控制请求和数据请求的分开处理。

  1. KVStoreDistServer() {
  2. using namespace std::placeholders;
  3. ps_server_ = new ps::KVServer<char>(0);
  4. static_cast<ps::SimpleApp*>(ps_server_)->set_request_handle(
  5. std::bind(&KVStoreDistServer::CommandHandle, this, _1, _2));
  6. ps_server_->set_request_handle(
  7. std::bind(&KVStoreDistServer::DataHandleEx, this, _1, _2, _3));
  8. sync_mode_ = false;
  9. gradient_compression_ = std::make_shared<GradientCompression>();
  10. log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false);
  11. }

处理控制请求

server接收到worker0发来的命令后,会根据命令的类型,执行不同的操作。例如,当worker发来StopServer的命令后,server就会被停止。相应的命令执行完毕后,server会发送一个响应给worker0。注意这里负责发送响应的不是ps::KVWorker<char>类型的对象,而是ps::SimpleApp类型的对象。

  1. void CommandHandle(const ps::SimpleData& recved, ps::SimpleApp* app) {
  2. CommandType recved_type = static_cast<CommandType>(recved.head);
  3. switch (recved_type) {
  4. case CommandType::kStopServer:
  5. exec_.Stop();
  6. break;
  7. case CommandType::kSyncMode:
  8. sync_mode_ = true;
  9. break;
  10. case CommandType::kSetGradientCompression:
  11. gradient_compression_->DecodeParams(recved.body);
  12. break;
  13. case CommandType::kSetProfilerParams:
  14. // last char is the type of profiler command
  15. ProcessServerProfilerCommands(static_cast<KVStoreServerProfilerCommand>
  16. (recved.body.back() - '0'),
  17. recved.body);
  18. break;
  19. case CommandType::kSetMultiPrecision:
  20. // uses value 1 for message id from frontend
  21. if (!multi_precision_) {
  22. multi_precision_ = true;
  23. CreateMultiPrecisionCopies();
  24. }
  25. break;
  26. case CommandType::kController:
  27. // this uses value 0 for message id from frontend
  28. // let the main thread to execute ctrl, which is necessary for python
  29. exec_.Exec([this, recved]() {
  30. CHECK(controller_);
  31. controller_(recved.head, recved.body);
  32. });
  33. break;
  34. }
  35. app->Response(recved);
  36. }

处理数据请求

前面提到,DataHandleEx被注册为处理数据请求的函数,它会根据数据请求类型去调用不同的处理函数。默认情况下会调用DataHandleDefalut,该函数会对worker发来的push和pull请求分开处理。当worker节点push梯度到server时,如果某个key是第一次被push,那么server会为相应的key申请内存空间;否则会根据sync_mode_的值分别进行处理。在sync_mode_ == true(即同步训练模式)的情况下,所有worker上的梯度会被聚合到update_buf_[key].merged中;而在异步训练模式下,server把从某个worker接收的梯度放在update_buf_[key].temp_array中。随后,worker发来的push请求信息会被记录到update_buf_[key].request中。待上面的工作完成后,会调用ApplyUpdates函数去更新key对应的模型参数。当worker节点向server节点发送pull请求时,server会直接调用DefaultStorageResponse把server节点最新的模型参数发送给worker。

  1. void DataHandleDefault(const DataHandleType type, const ps::KVMeta& req_meta,
  2. const ps::KVPairs<char>& req_data, ps::KVServer<char>* server) {
  3. int key = DecodeKey(req_data.keys[0]);
  4. auto& stored = store_[key];
  5. if (req_meta.push) { // push operation
  6. size_t ds[] = {(size_t) req_data.lens[0] / mshadow::mshadow_sizeof(type.dtype)};
  7. mxnet::TShape dshape(ds, ds + 1);
  8. TBlob recv_blob;
  9. MSHADOW_REAL_TYPE_SWITCH(type.dtype, DType, {
  10. recv_blob = TBlob(reinterpret_cast<DType*>(req_data.vals.data()), dshape, cpu::kDevMask);
  11. })
  12. NDArray recved = NDArray(recv_blob, 0);
  13. if (stored.is_none()) { // the first push request
  14. // initialization
  15. stored = NDArray(dshape, Context(), false, type.dtype);
  16. CopyFromTo(recved, &stored, 0);
  17. server->Response(req_meta);
  18. stored.WaitToRead();
  19. } else {
  20. auto& updates = update_buf_[key];
  21. if (sync_mode_ && updates.merged.is_none() {
  22. updates.merged = NDArray(dshape, Context(), false, type.dtype);
  23. }
  24. if (updates.request.empty()) { // the first
  25. if (sync_mode_) {
  26. CopyFromTo(recvd, updates.merged);
  27. } else { // async training
  28. updates.temp_array = recved;
  29. }
  30. } else {
  31. updates.merged += recved;
  32. }
  33. updates.request.push_back(req_meta);
  34. ApplyUpdates(type, key, req_data, &updates, server);
  35. } else { // pull operation
  36. DefaultStorageResponse(type, key, req_meta, req_data, server);
  37. }
  38. }

函数ApplyUpdates实现了模型权重更新的核心逻辑。如果是异步训练模式,或者当前的update_buf中的push请求数量等于worker的数量(意味着server收到了所有worker上的梯度),那么就会执行参数的更新过程;否则就不进行更新,直接调用server->Response给worker发一个不带任何数据的响应消息,表示收到了相关的数据。如果server端设置了更新器updater_,那么就会在server端执行更新操作;否则,server只对梯度进行聚合。如下代码的7~16行描述了这一过程,更新或聚合的结果会被存放到store_[key]中。由于update_buf_[key].request中保存的请求既有可能是push,也有可能是pushpull(唯独不可能是pull,因为我们只在req_meta.push==true时才把req_meta加入到update_buf_[key].request中),因此我们还要额外处理pushpull这类请求。对于update_buf_[key].request中的每个请求,如果该请求req.pull==true,那么就调用DefaultStorageResponse把模型权重传输给worker。在更新过程完成后,update_buf_[key].request就会被清空,以等待下一次更新。

  1. inline void ApplyUpdates(const DataHandleType type, const int key,
  2. const ps::KVPairs<char>& req_data, UpdateBuf *update_buf,
  3. ps::KVServer<char>* server) {
  4. if (!sync_mode_ || update_buf->request.size() == (size_t) ps::NumWorkers()) {
  5. // let the main thread to execute updater_, which is necessary for python
  6. auto& stored = store_[key];
  7. auto& update = sync_mode_ ? update_buf->merged : update_buf->temp_array;
  8. if (updater_) { // update_on_kvstore == True
  9. exec_.Exec([this, key, &update, &stored](){
  10. CHECK(updater_);
  11. updater_(key, update, &stored);
  12. });
  13. } else { // update_on_kvstore == False, only support for sync mode
  14. CHECK(sync_mode_) << "Updater needs to be set for async mode";
  15. // if no updater, just copy
  16. CopyFromTo(update_buf->merged, &stored);
  17. }
  18. /**
  19. * Request can be for either push or pushpull
  20. * If pull flag is set, respond immediately with the updated values
  21. * Otherwise, only send the notification
  22. */
  23. bool has_pull = false;
  24. for (const auto& req : update_buf->request) {
  25. has_pull = has_pull || req.pull;
  26. }
  27. if (has_pull) {
  28. // if there is a pull request, perform WaitToRead() once before DefaultStorageResponse
  29. stored.WaitToRead();
  30. for (const auto& req : update_buf->request) {
  31. if (req.pull) {
  32. DefaultStorageResponse(type, key, req, req_data, server);
  33. }
  34. }
  35. update_buf->request.clear();
  36. } else {
  37. // otherwise, send response directly
  38. for (const auto& req : update_buf->request) {
  39. server->Response(req);
  40. }
  41. update_buf->request.clear();
  42. stored.WaitToRead();
  43. }
  44. } else { // donot perform update operation
  45. update_buf->merged.WaitToRead();
  46. }
  47. }

DefaultStorageResponse会根据传入的req_metareq_data这两个参数针对worker的push请求构建出对应的带数据的响应消息。响应是一个ps::KVPairs<char>类型的对象,其中的数据部分拷贝自store_[key]。响应对象构建完成后,同样会调用server->Response将消息发回对应的worker。

  1. void DefaultStorageResponse(const DataHandleType type,
  2. const int key,
  3. const ps::KVMeta& req_meta,
  4. const ps::KVPairs<char> &req_data,
  5. ps::KVServer<char>* server) {
  6. ps::KVPairs<char> response;
  7. const NDArray& stored = store_[key];
  8. CHECK(!stored.is_none()) << "init " << key << " first";
  9. auto len = stored.shape().Size() * mshadow::mshadow_sizeof(stored.dtype());
  10. response.keys = req_data.keys;
  11. response.lens = {len};
  12. // TODO(mli) try to remove this CopyFrom
  13. response.vals.CopyFrom(static_cast<const char*>(stored.data().dptr_), len);
  14. server->Response(req_meta, response);
  15. }

MXNet源码分析 | KVStore进程间通信的更多相关文章

  1. MXNet源码分析 | KVStore进程内通信

    本文主要基于MXNet1.6.0版本进行分析. MXNet的KVStore模块下有几个比较重要的类.KVStore是一个抽象类,提供了一些通用的API,例如Init.Push和Pull等.因为KVSo ...

  2. MXNet源码分析 | Gluon接口分布式训练流程

    本文主要基于MXNet1.6.0版本,对Gluon接口的分布式训练过程进行简要分析. 众所周知,KVStore负责MXNet分布式训练过程中参数的同步,那么它究竟是如何应用在训练中的呢?下面我们将从G ...

  3. 源码分析——从AIDL的使用开始理解Binder进程间通信的流程

    源码分析——从AIDL的使用开始理解Binder进程间通信的流程 Binder通信是Android系统架构的基础.本文尝试从AIDL的使用开始理解系统的Binder通信. 0x00 一个AIDL的例子 ...

  4. Android源码分析-消息队列和Looper

    转载请注明出处:http://blog.csdn.net/singwhatiwanna/article/details/17361775 前言 上周对Android中的事件派发机制进行了分析,这次博主 ...

  5. android消息处理源码分析

    一.简介消息处理机制主要涉及到这几个类:1.Looper2.MessageQueue3.Message4.Handler 二.源码分析 Looper.class的关键源码: //保存Looper对象, ...

  6. wifidog源码分析 - wifidog原理 tiger

    转:http://www.cnblogs.com/tolimit/p/4223644.html wifidog源码分析 - wifidog原理 wifidog是一个用于配合认证服务器实现无线网页认证功 ...

  7. 6. SOFAJRaft源码分析— 透过RheaKV看线性一致性读

    开篇 其实这篇文章我本来想在讲完选举的时候就开始讲线性一致性读的,但是感觉直接讲没头没尾的看起来比比较困难,所以就有了RheaKV的系列,这是RheaKV,终于可以讲一下SOFAJRaft的线性一致性 ...

  8. python基础-11 socket,IO多路复用,select伪造多线程,select读写分离。socketserver源码分析

    Socket socket通常也称作"套接字",用于描述IP地址和端口,是一个通信链的句柄,应用程序通常通过"套接字"向网络发出请求或者应答网络请求. sock ...

  9. 🏆【Alibaba微服务技术系列】「Dubbo3.0技术专题」回顾Dubbo2.x的技术原理和功能实现及源码分析(温故而知新)

    RPC服务 什么叫RPC? RPC[Remote Procedure Call]是指远程过程调用,是一种进程间通信方式,他是一种技术的思想,而不是规范.它允许程序调用另一个地址空间(通常是共享网络的另 ...

随机推荐

  1. 解决excel两表之间数据关联关系,知道这几招就够了

    用过SAP的凭证批量录入模板(Excel文件)的都知道,一个凭证由[抬头]和多个[行项目]组成,这是一个关于excel两表信息关联的典型场景. 这里头蕴藏着一个麻烦:当我们需要一次性录入多个凭证时,如 ...

  2. Ubuntu16桌面版编译OpenCV4的java库和so库

    欢迎访问我的GitHub https://github.com/zq2599/blog_demos 内容:所有原创文章分类汇总及配套源码,涉及Java.Docker.Kubernetes.DevOPS ...

  3. Api自动生成

    如果经常对接api, 可以自己写一个自动化生成代码,提高效率 只抛出一个思路,暂不提供源码 使用json+字符串处理+生成文件 发送一个请求,返回字符串转换为 Newtonsoft.Json.Linq ...

  4. 今日学习——蓝桥杯 2019年 C语言 B组

    1.手淦(亲身体验,,,没啥大用,最终还是代码) 2.代码(下面是我看其他博主代码答案能看的懂的....具体的可以直接去下面的网址看) https://blog.csdn.net/qq_4452491 ...

  5. jsp文本框输入限制问题

    1.jsp文本窗口实现控制输入格式 <input onkeyup = "value=value.replace(/[\W]/g,'')" onbeforepaste=&quo ...

  6. vs python2.7 bug

    微软vs里面小细节的bug真他妈的多

  7. 推荐召回--基于用户的协同过滤UserCF

    目录 1. 前言 2. 原理 3. 数据及相似度计算 4. 根据相似度计算结果 5. 相关问题 5.1 如何提炼用户日志数据? 5.2 用户相似度计算很耗时,有什么好的方法? 5.3 有哪些改进措施? ...

  8. Django 优化杂谈

    Django 优化杂谈 Apr 21 2017 总结下最近看过的一些文章,然后想到的一些优化点,整理一下. 数据库连接池 http://mt.dbanotes.net/arch/instagram.h ...

  9. 字的研究(3)fontTools-TrueType轮廓坐标的获取以及基于TrueType的Glyph实例的构建

    前言 本文主要介绍如果使用Python第三方库fontTools提取OpenType字体文件中的TrueType轮廓坐标以及如何构建基于TrueType的Glyph实例 TrueType轮廓坐标的获取 ...

  10. 阅读笔记——长文本匹配《Matching Article Pairs with Graphical Decomposition and Convolutions》

    论文题目:Matching Article Pairs with Graphical Decomposition and Convolutions 发表情况:ACL2019 腾讯PCG小组 模型简介 ...