TensorFlow的图切割模块——Graph Partitioner
Graph Partition切割流程

第一步——分析构建Control Flow相关信息
- GraphInfo g_info;
- if (!opts.control_flow_added) {
- // Add the "code" for distributed execution of control flow. Code is
- // added only for the frames that are placed on multiple devices. The
- // new graph is an equivalent transformation of the original graph and
- // has the property that it can be subsequently partitioned arbitrarily
- // (down to the level of individual device) for distributed execution.
- status = AddControlFlow(opts, g, &g_info);
- if (!status.ok()) return status;
- }
第二步——构建Op的Input和Output Memory类型信息
- // MemoryType is used to describe whether input or output Tensors of
- // an OpKernel should reside in "Host memory" (e.g., CPU memory) or
- // "Device" Memory (CPU memory for CPU devices, GPU memory for GPU
- // devices).
- enum MemoryType {
- DEVICE_MEMORY = 0,
- HOST_MEMORY = 1,
- };
- #define REGISTER_GPU_KERNEL(type) \
- REGISTER_KERNEL_BUILDER(Name("Reshape") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tshape"), \
- ReshapeOp); \
- REGISTER_KERNEL_BUILDER(Name("Reshape") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int64>("Tshape"), \
- ReshapeOp);
上面的宏显示,虽然Reshape Op确实在GPU上有注册的实现版本,但是它的"shape"输入依然要放在HostMemory中。另外,某些Tensor的数据类型也决定了其是否可以被放置到Device Memory上,一般情况下float类型的数据对于计算设备是非常友好的,而String类型就不是这样,所以在types.cc文件中规定了一些强制被放在HostMemory的数据类型,如下代码所示。
- bool DataTypeAlwaysOnHost(DataType dt) {
- // Includes DT_STRING and DT_RESOURCE.
- switch (dt) {
- case DT_STRING:
- case DT_STRING_REF:
- case DT_RESOURCE:
- return true;
- default:
- return false;
- }
- }

- // Check whether there is already a send/recv pair transferring
- // the same tensor/control from the src to dst partition.
- const bool on_host = IsDstInputOnHost(edge, g_info);
- DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host};
- auto iter = dup_recv.find(key);
- if (iter != dup_recv.end()) {
- // We found one. Reuse the data/control transferred already.
- const string& recv_node_name = iter->second.recv->name();
- if (edge->IsControlEdge()) {
- AddInput(dst_def, recv_node_name, Graph::kControlSlot);
- } else {
- AddInput(dst_def, recv_node_name, 0);
- }
- ref_control_inputs.push_back(recv_node_name);
- // We want the start_time for the recv to be the smallest of the start
- // times of it's consumers. So we update this whenever we use a recv,
- // and write it out to the attribute at the end of the subroutine
- if (iter->second.start_time > recv_start_time) {
- iter->second.start_time = recv_start_time;
- }
- continue;
- }
- const FunctionLibraryDefinition* flib_def = opts.flib_def;
- if (flib_def == nullptr) {
- flib_def = &g->flib_def();
- }
- // Set versions, function library and send/recv incarnation.
- for (auto& it : *partitions) {
- GraphDef* gdef = &it.second;
- *gdef->mutable_versions() = g->versions();
- // Prune unreachable functions from `flib_def` before adding them to `gdef`.
- *gdef->mutable_library() = flib_def->ReachableDefinitions(*gdef).ToProto();
- // Traverse the graph to fill every send/recv op's incarnation
- // information.
- SetIncarnation(opts, gdef);
- }
- // Need to split edge by placing matching send/recv nodes on
- // the src/dst sides of the edge.
- NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from,
- send_start_time, &status);
- if (!status.ok()) return status;
- NodeDef* real_recv = nullptr;
- NodeDef* recv =
- AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status);
- if (!status.ok()) return status;
因为同一个Device上的Send和Recv节点在执行过程中实际上是Memory Copy,而Recv的kernel又是异步的,所以需要有一种机制保证Recv一定在Send之后执行,因此需要在Send和Recv之间插入一个Control Edge,从图的依赖上保证它们的执行顺序。
这个过程的关键是在插入Send和Recv节点之后,需要插入额外的Control Edge,代码如下。
- // Fix up the control flow edge.
- // NOTE(yuanbyu): 'real_recv' must be the real recv node.
- if (src_graph == dst_graph) {
- // For same device send/recv, add a control edge from send to recv.
- // This prevents the asynchronous recv kernel from being scheduled
- // before the data is available.
- AddInput(real_recv, send->name(), Graph::kControlSlot);
- }


- NodeDefBuilder::NodeOut send_from;
- if (edge->IsControlEdge()) {
- // Insert a dummy const node that will generate a tiny
- // data element to be sent from send to recv.
- VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "["
- << src->name() << "] -> " << dst->assigned_device_name() << "["
- << dst->name() << "]";
- NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status);
- if (!status.ok()) return status;
- // Set the start time for this dummy node.
- if (opts.scheduling_for_recvs) {
- AddNodeAttr("_start_time", send_start_time, dummy);
- }
- AddInput(dummy, src->name(), Graph::kControlSlot);
- send_from.Reset(dummy->name(), 0, DT_FLOAT);
- } else {
- send_from.Reset(src->name(), edge->src_output(), EdgeType(edge));
- }
- // Add the cast node (from cast_dtype to dtype) or an Identity node.
- if (dtype != cast_dtype) {
- const string cast_op = (host_memory) ? "_HostCast" : "Cast";
- NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op);
- cast_builder.Attr("DstT", dtype);
- cast_builder.Device(dst->assigned_device_name())
- .Input(recv->name(), 0, cast_dtype);
- NodeDef* cast = gdef->add_node();
- *status = cast_builder.Finalize(cast);
- if (!status->ok()) return nullptr;
- return cast;
- } else if (edge->IsControlEdge()) {
- // An Identity is only needed for control edges.
- NodeDefBuilder id_builder(opts.new_name(src->name()), "Identity");
- id_builder.Device(dst->assigned_device_name())
- .Input(recv->name(), 0, cast_dtype);
- NodeDef* id = gdef->add_node();
- *status = id_builder.Finalize(id);
- if (!status->ok()) return nullptr;
- return id;
- } else {
- return recv;
- }
- graph_options = tf.GraphOptions(enable_bfloat16_sendrecv=True)
- session_config = tf.ConfigProto(graph_options=graph_options)
