/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/graph/graph_partition.h"

#include <deque>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {

inline bool IsMerge(const NodeDef& node_def) {
  return node_def.op() == "Merge" || node_def.op() == "RefMerge" ||
         node_def.op() == "_XlaMerge";
}

inline bool IsNextIteration(const NodeDef& node_def) {
  return node_def.op() == "NextIteration" ||
         node_def.op() == "RefNextIteration";
}

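// Key for deduplicating Recv nodes: edges that carry the same src output into
// the same destination subgraph (with the same host placement) can share a
// single send/recv pair. See the dup_recv table in Partition() below.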
struct DupRecvKey {
  int src_node_id;           // Edge's src node id
  int src_output_slot;       // Edge's src node output slot
  GraphDef* dst_graph;       // Edge's dst node is in this subgraph
  bool recv_output_on_host;  // The output of recv is on host

  template <typename H>
  friend H AbslHashValue(H h, const DupRecvKey& c) {
    return H::combine(std::move(h), c.src_node_id, c.src_output_slot,
                      reinterpret_cast<std::uintptr_t>(c.dst_graph),
                      c.recv_output_on_host);
  }

  friend bool operator==(const DupRecvKey& x, const DupRecvKey& y) {
    return (x.src_node_id == y.src_node_id) &&
           (x.src_output_slot == y.src_output_slot) &&
           (x.dst_graph == y.dst_graph) &&
           (x.recv_output_on_host == y.recv_output_on_host);
  }
};

// Struct used to store the recvs, so that start times can be properly
// updated.
struct RecvInfo {
  NodeDef* recv;
  NodeDef* real_recv;
  int64_t start_time;
};

typedef absl::flat_hash_map<DupRecvKey, RecvInfo> DupRecvTable;

// A map used to store memory types for the inputs/outputs of every node.
// The key is a pair of ints consisting of a node id and input/output index.
// TODO(power): migrate back to std::pair when absl::Hash is fixed for MSVC.
struct NodePort {
  int node_id;
  int index;

  friend bool operator==(const NodePort& x, const NodePort& y) {
    return x.node_id == y.node_id && x.index == y.index;
  }

  template <typename H>
  friend H AbslHashValue(H h, const NodePort& c) {
    return H::combine(std::move(h), c.node_id, c.index);
  }
};

typedef absl::flat_hash_map<NodePort, MemoryType> MemoryTypeMap;

// We collect the following information about the graph before performing
// graph partitioning.
struct GraphInfo {
  std::vector<DeviceType> device_types;
  MemoryTypeMap input_types;
  MemoryTypeMap output_types;
  std::vector<ControlFlowInfo> cf_info;
};

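// Returns the type of the tensor carried by edge 'e'. A control edge carries
// no tensor, so it is reported as DT_FLOAT, the type of the dummy float
// constant used to materialize control edges across partitions (see
// AddDummyConst below).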
DataType EdgeType(const Edge* e) {
  if (e->IsControlEdge()) {
    return DT_FLOAT;
  } else {
    return e->dst()->input_type(e->dst_input());
  }
}

// Return true iff we need to add a same-device send/recv for 'edge'.
bool NeedSameDeviceSendRecv(const Edge* edge, const GraphInfo& info) {
  if (edge->IsControlEdge()) {
    return false;
  }

  const Node* src = edge->src();
  const Node* dst = edge->dst();
  if (src->assigned_device_name() == dst->assigned_device_name()) {
    int src_port = edge->src_output();
    int dst_port = edge->dst_input();
    if (info.device_types[src->id()] != DEVICE_CPU) {
      auto src_it = info.output_types.find({src->id(), src_port});
      DCHECK(src_it != info.output_types.end());
      auto dst_it = info.input_types.find({dst->id(), dst_port});
      DCHECK(dst_it != info.input_types.end());
      return src_it->second != dst_it->second;
    }
  }
  return false;
}

// Return true iff (dst, dst_input) is specified on host memory.
bool IsDstInputOnHost(const Edge* edge, const GraphInfo& info) {
  const Node* dst = edge->dst();
  int dst_port = edge->dst_input();
  if (info.device_types[dst->id()] != DEVICE_CPU) {
    if (edge->IsControlEdge()) return false;
    auto dst_it = info.input_types.find({dst->id(), dst_port});
    DCHECK(dst_it != info.input_types.end());
    return dst_it->second == HOST_MEMORY;
  }
  return true;
}

// Add an input to dst that comes from the "src_slot" output of the
// node named by "src_name".
void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
  if (src_slot == Graph::kControlSlot) {
    dst->add_input(strings::StrCat("^", src_name));
  } else if (src_slot == 0) {
    dst->add_input(src_name.data(), src_name.size());
  } else {
    dst->add_input(strings::StrCat(src_name, ":", src_slot));
  }
}

// Add a control edge from each input to each recv.
void AddReadControl(const std::vector<NodeDef*>& recvs,
                    const std::vector<string>& inputs) {
  for (NodeDef* recv : recvs) {
    for (const string& input : inputs) {
      recv->add_input(strings::StrCat("^", input));
    }
  }
}

void SetSendRecvAttrs(const PartitionOptions& opts, const Edge* edge,
                      const string& tensor_name_attr, NodeDefBuilder* builder) {
  builder->Attr("tensor_name", tensor_name_attr);
  builder->Attr("send_device", edge->src()->assigned_device_name());
  builder->Attr("send_device_incarnation",
                static_cast<int64_t>(
                    opts.get_incarnation(edge->src()->assigned_device_name())));
  builder->Attr("recv_device", edge->dst()->assigned_device_name());
  builder->Attr("client_terminated", false);
  builder->Attr("_src", edge->src()->name());
  builder->Attr("_dst", edge->dst()->name());
}

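// Adds a _Send (or _HostSend) node for 'edge' to the source partition 'gdef'.
// If the edge requires a dtype conversion, a Cast (or _HostCast) node is
// inserted in front of the send. Returns the send node, or nullptr on error
// (with *status set).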
NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
                 GraphDef* gdef, const Edge* edge,
                 NodeDefBuilder::NodeOut send_from, int64_t start_time,
                 const string& tensor_name_attr, Status* status) {
  const DataType dtype = send_from.data_type;
  const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype;
  const Node* src = edge->src();
  const int src_port = edge->src_output();

  // host_memory = true iff we need to use HostSend/HostCast.
  bool host_memory = false;
  if (!edge->IsControlEdge()) {
    auto src_it = g_info.output_types.find({src->id(), src_port});
    DCHECK(src_it != g_info.output_types.end());
    host_memory = (src_it->second == HOST_MEMORY);
  }

  // Add a cast node that casts dtype to cast_dtype.
  // NOTE(yuanbyu): Only cast for cross-device send/recv.
  if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) {
    const string cast_op = (host_memory) ? "_HostCast" : "Cast";
    NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
                                NodeDebugInfo(*src));
    cast_builder.Device(src->assigned_device_name()).Input(send_from);
    if (opts.scheduling_for_recvs) {
      cast_builder.Attr("_start_time", start_time);
    }
    cast_builder.Attr("DstT", cast_dtype);

    if (cast_dtype == DT_BFLOAT16) {
      // The below attribute specifies that the cast to bfloat16 should use
      // truncation. This is needed to retain legacy behavior when we change
      // the default bfloat16 casts to use rounding instead of truncation.
      cast_builder.Attr("Truncate", true);
    }

    NodeDef* cast = gdef->add_node();
    *status = cast_builder.Finalize(cast, /*consume=*/true);
    if (!status->ok()) return nullptr;

    // Connect the Send op to the cast.
    send_from.Reset(cast->name(), 0, cast_dtype);
  }

  // Add the send node.
  const string send_op = (host_memory) ? "_HostSend" : "_Send";
  NodeDefBuilder send_builder(opts.new_name(src->name()), send_op,
                              NodeDebugInfo(*src));
  SetSendRecvAttrs(opts, edge, tensor_name_attr, &send_builder);
  send_builder.Device(src->assigned_device_name()).Input(send_from);
  if (opts.scheduling_for_recvs) {
    send_builder.Attr("_start_time", start_time);
  }
  NodeDef* send = gdef->add_node();
  *status = send_builder.Finalize(send, /*consume=*/true);
  return send;
}

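// Adds a _Recv (or _HostRecv) node for 'edge' to the destination partition
// 'gdef'. '*real_recv' is set to the actual recv node; the returned node is
// either that same node or a Cast/Identity appended after it, i.e. the node
// whose output downstream consumers should be wired to.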
NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info,
                 GraphDef* gdef, const Edge* edge, NodeDef** real_recv,
                 const string& tensor_name_attr, Status* status) {
  const DataType dtype = EdgeType(edge);
  const Node* src = edge->src();
  const Node* dst = edge->dst();
  const int dst_port = edge->dst_input();
  DataType cast_dtype = dtype;

  // NOTE(yuanbyu): Only cast for cross-device send/recv.
  if (opts.should_cast && !NeedSameDeviceSendRecv(edge, g_info)) {
    cast_dtype = opts.should_cast(edge);
  }

  // host_memory = true iff we need to use HostRecv/HostCast.
  // Also log the introduction of the send-recv pair, for performance
  // debugging.
  bool host_memory = false;
  if (!edge->IsControlEdge()) {
    auto dst_it = g_info.input_types.find({dst->id(), dst_port});
    DCHECK(dst_it != g_info.input_types.end());
    host_memory = (dst_it->second == HOST_MEMORY);
    bool src_host_memory = false;
    if (VLOG_IS_ON(1)) {
      const int src_port = edge->src_output();
      auto src_it = g_info.output_types.find({src->id(), src_port});
      DCHECK(src_it != g_info.output_types.end());
      src_host_memory = (src_it->second == HOST_MEMORY);
    }
    VLOG(1) << "Receiving data"
            << " from " << src->name() << " (" << src->type_string() << ")"
            << " on " << src->assigned_device_name() << " in "
            << (src_host_memory ? "host memory" : "device memory") << " for "
            << dst->name() << " (" << dst->type_string() << ")"
            << " on " << dst->assigned_device_name() << " in "
            << (host_memory ? "host memory" : "device memory");
  } else {
    // Log control-edge transfers too, but don't mention memory space since
    // it's irrelevant.
    VLOG(1) << "Receiving control"
            << " from " << src->name() << " (" << src->type_string() << ")"
            << " on " << src->assigned_device_name() << " for " << dst->name()
            << " (" << dst->type_string() << ")"
            << " on " << dst->assigned_device_name();
  }

  // Add the recv node.
  const string recv_op = (host_memory) ? "_HostRecv" : "_Recv";
  NodeDefBuilder recv_builder(opts.new_name(src->name()), recv_op,
                              NodeDebugInfo(*src));
  SetSendRecvAttrs(opts, edge, tensor_name_attr, &recv_builder);
  recv_builder.Device(dst->assigned_device_name())
      .Attr("tensor_type", cast_dtype);
  NodeDef* recv = gdef->add_node();
  *status = recv_builder.Finalize(recv, /*consume=*/true);
  if (!status->ok()) return nullptr;
  *real_recv = recv;

  // Add the cast node (from cast_dtype to dtype) or an Identity node.
  if (dtype != cast_dtype) {
    const string cast_op = (host_memory) ? "_HostCast" : "Cast";
    NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
                                NodeDebugInfo(*src));
    cast_builder.Attr("DstT", dtype);
    cast_builder.Device(dst->assigned_device_name())
        .Input(recv->name(), 0, cast_dtype);
    NodeDef* cast = gdef->add_node();
    *status = cast_builder.Finalize(cast, /*consume=*/true);
    if (!status->ok()) return nullptr;
    return cast;
  } else if (edge->IsControlEdge()) {
    // An Identity is only needed for control edges.
    NodeDefBuilder id_builder(opts.new_name(src->name()), "Identity",
                              NodeDebugInfo(*src));
    id_builder.Device(dst->assigned_device_name())
        .Input(recv->name(), 0, cast_dtype);
    NodeDef* id = gdef->add_node();
    *status = id_builder.Finalize(id, /*consume=*/true);
    if (!status->ok()) return nullptr;
    return id;
  } else {
    return recv;
  }
}

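// Adds a dummy Const node producing an empty float tensor to 'gdef'. It is
// used as the payload of a send/recv pair that materializes a control edge
// across partitions.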
NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef,
                       const Edge* edge, Status* status) {
  const Node* src = edge->src();
  Tensor tensor(DT_FLOAT, TensorShape({0}));
  NodeDef* result = gdef->add_node();
  *status = NodeDefBuilder(opts.new_name(src->name()), "Const")
                .Device(src->assigned_device_name())
                .Attr("dtype", DT_FLOAT)
                .Attr("value", tensor)
                .Finalize(result, /*consume=*/true);
  return result;
}

// A dummy node for scheduling.
NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
                           const string& assigned_device_name, int64_t epoch,
                           int64_t starttime, Status* status) {
  NodeDef* result = gdef->add_node();
  *status = NodeDefBuilder(opts.new_name(strings::StrCat("synch_", epoch)),
                           "ControlTrigger")
                .Device(assigned_device_name)
                .Attr("_start_time", starttime)
                .Finalize(result, /*consume=*/true);
  return result;
}

// Optimize colocation for control flow nodes. For cond, we want each switch
// node to be colocated with its data input. This is particularly needed for
// conditional reading of a remote variable. It may also reduce the number of
// devices involved in a loop.
// TODO(yuanbyu): In this case, we don't respect the requested device in
// the GraphDef for these nodes. Ideally, the placer would enforce the
// colocation to render this unnecessary.
void OptimizeControlFlowColocation(Graph* graph) {
  auto visit = [](Node* node) {
    if (IsSwitch(node)) {
      for (const Edge* in_edge : node->in_edges()) {
        if (in_edge->dst_input() == 0) {
          // Colocate with the data input.
          node->set_assigned_device_name(
              in_edge->src()->assigned_device_name());
          return;
        }
      }
    } else if (IsExit(node)) {
      for (const Edge* in_edge : node->in_edges()) {
        if (!in_edge->IsControlEdge()) {
          // Colocate with upstream node.
          node->set_assigned_device_name(
              in_edge->src()->assigned_device_name());
          return;
        }
      }
    } else {
      if ((IsEnter(node) && !IsRefType(node->input_type(0))) ||
          IsNextIteration(node)) {
        const Edge* data_edge = nullptr;
        for (const Edge* out_edge : node->out_edges()) {
          if (!out_edge->IsControlEdge()) {
            data_edge = out_edge;
            break;
          }
        }
        // Colocate with the first downstream data node.
        if (data_edge) {
          node->set_assigned_device_name(
              data_edge->dst()->assigned_device_name());
        }
      }
    }
  };
  DFS(*graph, visit, {});
}

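// Synthetic control loop nodes are all named with the "_cloop" prefix, which
// is what IsControlLoop() below keys on.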
string ControlLoopName(const string& name) {
  return strings::StrCat("_cloop", name);
}

bool IsControlLoop(const Node* node) {
  const string& name = node->name();
  return absl::StartsWith(name, "_cloop");
}

// An enter node for control flow.
Node* AddControlEnter(Graph* g, const string& node_name,
                      const string& device_name, const string& frame_name,
                      const int parallel_iterations, Status* status) {
  NodeBuilder node_builder(node_name, "Enter", g->op_registry());
  node_builder.Input({"dummy", 0, DT_FLOAT});
  node_builder.Attr("frame_name", frame_name);
  node_builder.Attr("parallel_iterations", parallel_iterations);
  Node* res_node;
  *status = node_builder.Finalize(g, &res_node, /*consume=*/true);
  if (!status->ok()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}

// A merge node for control flow.
Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g,
                      const string& node_name, const string& device_name,
                      Status* status) {
  NodeBuilder node_builder(node_name, "Merge", g->op_registry());
  node_builder.Input({{in_name1, 0, DT_FLOAT}, {in_name2, 0, DT_FLOAT}});
  Node* res_node;
  *status = node_builder.Finalize(g, &res_node, /*consume=*/true);
  if (!status->ok()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}

// A switch node for control flow.
Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2,
                       const string& device_name,
                       const GraphDefBuilder::Options& bopts) {
  Node* res_node =
      ops::BinaryOp("Switch", std::move(input1), std::move(input2), bopts);
  if (bopts.HaveError()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}

// A next_iteration node for control flow.
Node* AddControlNext(NodeBuilder::NodeOut input, const string& device_name,
                     const GraphDefBuilder::Options& bopts) {
  Node* res_node = ops::UnaryOp("NextIteration", std::move(input), bopts);
  if (bopts.HaveError()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}

Node* EmptyConst(const GraphDefBuilder::Options& options) {
  if (options.HaveError()) return nullptr;
  NodeBuilder node_builder(options.GetNameForOp("Const"), "Const",
                           options.op_registry());
  const DataType dt = DataTypeToEnum<float>::v();
  TensorProto proto;
  proto.set_dtype(dt);
  TensorShape empty_shape({0});
  empty_shape.AsProto(proto.mutable_tensor_shape());
  node_builder.Attr("dtype", dt).Attr("value", proto);
  return options.FinalizeBuilder(&node_builder);
}

// A dummy const node for control flow.
Node* AddControlConst(const string& device_name,
                      const GraphDefBuilder::Options& bopts) {
  Node* res_node = EmptyConst(bopts);
  if (bopts.HaveError()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}

// A synthetic loop, made up of dummy nodes. It performs control-flow actions
// on behalf of a leader on a different device.
struct ControlLoop {
  Node* enter = nullptr;
  Node* merge = nullptr;
  Node* switch_node = nullptr;
};

// Add the control flow info of a new node added during partitioning.
// The new node has the same control flow info as src.
void AddControlFlowInfo(const Node* node, const Node* src,
                        std::vector<ControlFlowInfo>* cf_info) {
  int id = node->id();
  if (static_cast<size_t>(id) >= cf_info->size()) {
    cf_info->resize(id + 1);
  }
  const ControlFlowInfo& src_info = (*cf_info)[src->id()];
  ControlFlowInfo* info = &(*cf_info)[id];
  info->frame = src_info.frame;
  info->parent_frame = src_info.parent_frame;
  info->frame_name = src_info.frame_name;
}

// Constructs a control loop. Returns a struct containing the newly created
// enter, merge, and switch nodes. The enter and merge nodes are used in the
// recursive construction of control loops for nested frames (loops). The
// switch node will be connected to the LoopCond node. The merge node will
// be connected to all the recvs of the same frame by control edges when
// the actual partitioning happens.
Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src,
                      const Edge* edge, Node* loop_cond,
                      std::vector<ControlFlowInfo>* cf_info,
                      ControlLoop* loop) {
  Status status;
  GraphDefBuilder::Options bopts(g, &status);
  const ControlFlowInfo& src_info = (*cf_info)[src->id()];
  const string& device_name = edge->dst()->assigned_device_name();
  const string& frame_name = src_info.frame_name;
  int parallel_iterations;
  status = GetNodeAttr(src_info.frame->attrs(), "parallel_iterations",
                       &parallel_iterations);
  if (!status.ok()) return status;

  // The names of the nodes to be added.
  const string& enter_name =
      ControlLoopName(opts.new_name(edge->dst()->name()));
  const string& merge_name =
      ControlLoopName(opts.new_name(edge->dst()->name()));
  const string& switch_name =
      ControlLoopName(opts.new_name(edge->dst()->name()));
  const string& next_name = ControlLoopName(opts.new_name(edge->dst()->name()));

  // Add the nodes to the graph g.
  Node* enter = AddControlEnter(g, enter_name, device_name, frame_name,
                                parallel_iterations, &status);
  if (!status.ok()) return status;
  Node* merge = AddControlMerge(enter_name, next_name, g, merge_name,
                                device_name, &status);
  if (!status.ok()) return status;
  Node* switch_node = AddControlSwitch(merge, loop_cond, device_name,
                                       bopts.WithName(switch_name));
  if (!status.ok()) return status;
  Node* next =
      AddControlNext({switch_node, 1}, device_name, bopts.WithName(next_name));
  if (!status.ok()) return status;

  // Add control flow info for these new nodes:
  AddControlFlowInfo(enter, src, cf_info);
  AddControlFlowInfo(merge, src, cf_info);
  AddControlFlowInfo(switch_node, src, cf_info);
  AddControlFlowInfo(next, src, cf_info);

  // Add input edges for the newly created merge node:
  g->AddEdge(enter, 0, merge, 0);
  g->AddEdge(next, 0, merge, 1);

  loop->enter = enter;
  loop->merge = merge;
  loop->switch_node = switch_node;
  return OkStatus();
}

// Build memory and device type info for every node in the graph.
// TODO(yuanbyu): It might be simpler if we convert MemoryType to
// DeviceType for the inputs/outputs of each node.
Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) {
  MemoryTypeVector input_memory_types;
  MemoryTypeVector output_memory_types;

  info->device_types.resize(g.num_node_ids(), DEVICE_CPU);
  for (const Node* node : g.op_nodes()) {
    DeviceNameUtils::ParsedName parsed;
    if (!DeviceNameUtils::ParseFullName(node->assigned_device_name(),
                                        &parsed)) {
      return errors::Internal("Malformed assigned device '",
                              node->assigned_device_name(), "'");
    }

    TF_RETURN_IF_ERROR(MemoryTypesForNode(
        g.op_registry(), DeviceType(parsed.type), node->def(),
        &input_memory_types, &output_memory_types));

    int node_id = node->id();
    info->device_types[node_id] = DeviceType(parsed.type);
    for (int i = 0; i < input_memory_types.size(); ++i) {
      info->input_types[{node_id, i}] = input_memory_types[i];
    }
    for (int i = 0; i < output_memory_types.size(); ++i) {
      info->output_types[{node_id, i}] = output_memory_types[i];
    }
  }
  return OkStatus();
}

const Node* InputFrame(const Node* node,
                       const std::vector<ControlFlowInfo>& cf_info) {
  // An input is in the same frame as the node except for Enter nodes.
  // The input of Enter is in the parent frame of the Enter node.
  if (!node->IsEnter()) {
    return node;
  }
  return cf_info[node->id()].parent_frame;
}

const Node* OutputFrame(const Node* node,
                        const std::vector<ControlFlowInfo>& cf_info) {
  // An output is in the same frame as the node except for Exit nodes.
  // The output of Exit is in the parent frame of the Exit node.
  if (!node->IsExit()) {
    return node;
  }
  return cf_info[node->id()].parent_frame;
}

// Each participating device needs to decide a) if there is a next iteration,
// and b) if the loop terminates. We take the approach of encoding this control
// flow logic in the dataflow graph. There are at least two possible encodings.
// In a completely decentralized encoding, the participants communicate peer
// to peer. The other encoding uses a frame leader (the participant who owns
// the pivot termination predicate) to broadcast the termination condition to
// all the participants. For now we take the latter because it is simpler.
//
// TODO(yuanbyu): The correctness of this construction is rather subtle. I got
// it wrong many times so it would be nice to write a proof to be sure.
Status AddControlFlow(const PartitionOptions& opts, Graph* g,
                      GraphInfo* g_info) {
  Status status;
  GraphDefBuilder::Options bopts(g, &status);
  std::vector<ControlFlowInfo>& cf_info = g_info->cf_info;

  // Build the control flow info for every node.
  status = BuildControlFlowInfo(g, &cf_info);
  if (!status.ok()) return status;

  OptimizeControlFlowColocation(g);

  // The map from frames to their LoopCond nodes.
  std::unordered_map<string, Node*> frame_cond_map;
  int num_node_ids = g->num_node_ids();
  for (int i = 0; i < num_node_ids; ++i) {
    Node* node = g->FindNodeId(i);
    if (node == nullptr) continue;

    if (IsLoopCond(node)) {
      const string& frame_name = cf_info[node->id()].frame_name;
      DCHECK(!frame_name.empty());
      frame_cond_map[frame_name] = node;
    }
  }

  // Add all control loops for cross-device frames.
  // A control loop is added only when there is a cross-device edge in a
  // non-root frame. Nothing is added if there are no loops. We also don't
  // add anything for a frame that is completely local to a device. For
  // nested loops, we stack the control loops together by connecting
  // the merge of the outer loop to the enter of the inner loop.
  //
  // A map from <frame_name, device_name> to ControlLoop.
  std::unordered_map<string, ControlLoop> control_loops;
  int num_edge_ids = g->num_edge_ids();
  for (int i = 0; i < num_edge_ids; ++i) {
    const Edge* edge = g->FindEdgeId(i);
    if (edge == nullptr) continue;

    const Node* src = edge->src();
    const Node* dst = edge->dst();
    // Skip Sink/Source nodes.
    if (!src->IsOp() || !dst->IsOp()) continue;

    const string& src_device = src->assigned_device_name();
    const string& dst_device = dst->assigned_device_name();
    // Skip local edges.
    if (src_device == dst_device) continue;

    const Node* src_frame = OutputFrame(src, cf_info);
    const Node* dst_frame = InputFrame(dst, cf_info);
    const string& src_frame_name = cf_info[src_frame->id()].frame_name;
    const string& dst_frame_name = cf_info[dst_frame->id()].frame_name;
    // Skip if src and dst are not in the same frame.
    if (src_frame_name.empty() || src_frame_name != dst_frame_name) {
      continue;
    }

    // Add the control loop. Start by adding the control loop for the
    // current frame if needed, and recursively adding the control loop
    // for its outer frame when nested.
    ControlLoop child_loop;
    while (true) {
      const string& curr_frame_name = cf_info[src_frame->id()].frame_name;
      if (curr_frame_name.empty()) {
        // We have reached the root frame.
        if (child_loop.merge != nullptr) {
          const string& node_name = opts.new_name(edge->dst()->name());
          const string& device_name = edge->dst()->assigned_device_name();
          Node* const_node =
              AddControlConst(device_name, bopts.WithName(node_name));
          if (!status.ok()) return status;
          AddControlFlowInfo(const_node, src_frame, &cf_info);
          g->AddEdge(const_node, 0, child_loop.enter, 0);
        }
        break;
      }

      const string& cl_key = strings::StrCat(curr_frame_name, "$$", dst_device);
      auto it = control_loops.find(cl_key);
      if (it != control_loops.end()) {
        if (child_loop.enter != nullptr) {
          g->AddEdge(it->second.merge, 0, child_loop.enter, 0);
        }
        break;
      }

      // Get the frame's LoopCond.
      auto cond_it = frame_cond_map.find(curr_frame_name);
      if (cond_it == frame_cond_map.end()) {
        return errors::InvalidArgument(
            "A cross-device loop must have a pivot predicate: ",
            curr_frame_name);
      }
      Node* loop_cond = cond_it->second;

      // Add the control loop.
      ControlLoop curr_loop;
      status = AddControlLoop(opts, g, src_frame, edge, loop_cond, &cf_info,
                              &curr_loop);
      if (!status.ok()) return status;
      control_loops[cl_key] = curr_loop;

      if (child_loop.enter != nullptr) {
        // Connect the merge of the outer loop to the enter of the inner.
        g->AddEdge(curr_loop.merge, 0, child_loop.enter, 0);
      }
      src_frame = cf_info[src_frame->id()].parent_frame;
      child_loop = curr_loop;
    }
  }

  // For a cross-device edge, on the dst device, add a control edge
  // from the merge node of the control loop to dst. If a send/recv is
  // introduced for this edge in future partitioning, we delete this
  // control edge and add a new control edge from the merge to the recv.
  num_edge_ids = g->num_edge_ids();
  for (int i = 0; i < num_edge_ids; ++i) {
    const Edge* edge = g->FindEdgeId(i);
    if (edge == nullptr) continue;

    const Node* src = edge->src();
    Node* dst = edge->dst();
    // Skip Sink/Source nodes.
    if (!src->IsOp() || !dst->IsOp()) continue;

    const string& src_device = src->assigned_device_name();
    const string& dst_device = dst->assigned_device_name();
    if (src_device != dst_device) {
      const Node* src_frame = OutputFrame(src, cf_info);
      const Node* dst_frame = InputFrame(dst, cf_info);
      const string& src_frame_name = cf_info[src_frame->id()].frame_name;
      const string& dst_frame_name = cf_info[dst_frame->id()].frame_name;
      if (!src_frame_name.empty() && src_frame_name == dst_frame_name) {
        const string& cl_key =
            strings::StrCat(dst_frame_name, "$$", dst_device);
        ControlLoop loop = control_loops[cl_key];
        DCHECK(loop.enter != nullptr);
        // Note that we'll create multiple duplicate edges if dst has multiple
        // cross-device inputs. This is expected by the logic in Partition(),
        // so it can add control edges to the recv nodes once they're created.
        g->AddControlEdge(loop.merge, dst, /*allow_duplicates=*/true);
      }
    }
  }
  return OkStatus();
}

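// A (node, start_time) pair for the priority queue used in
// TopologicalSortNodesWithTimePriority() below; together with the Greater
// comparator, the queue pops the node with the smallest start time first.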
struct PriorityTopoSortNode {
  PriorityTopoSortNode(const NodeDef* n, int64_t st)
      : node(n), start_time(st) {}

  const NodeDef* node;
  int64_t start_time;
};

struct PriorityTopoSortNodeGreater {
  bool operator()(const PriorityTopoSortNode& left,
                  const PriorityTopoSortNode& right) {
    return left.start_time > right.start_time;
  }
};

}  // namespace

// Returns in <nodes> the nodes that should participate in epoch-based recv
// scheduling, along with their times; <nodes> is ordered by increasing
// start_time. Returns in <node_to_start_time_out> the timing for all nodes,
// even those not in <nodes>.
//
// Compared to sorting on the node's start time only, this also processes the
// nodes in dependency order, and updates start times to ensure a node's
// start_time > the start time for all dependencies.
//
// Note that graph_partition_test.cc accesses this function for testing, even
// though it's not declared in the header.
Status TopologicalSortNodesWithTimePriority(
    const GraphDef* gdef,
    std::vector<std::pair<const NodeDef*, int64_t>>* nodes,
    std::unordered_map<const NodeDef*, int64_t>* node_to_start_time_out) {
  // Queue of nodes to process; lowest start time is returned first.
  std::priority_queue<PriorityTopoSortNode, std::vector<PriorityTopoSortNode>,
                      PriorityTopoSortNodeGreater>
      q;
  std::unordered_map<const NodeDef*, int64_t> node_to_start_time;
  auto enqueue = [&q, &node_to_start_time](const NodeDef* node) {
    const int64_t start_time = node_to_start_time[node];
    q.emplace(node, start_time);
  };

  // Build initial structures, initial contents of queue.
  std::unordered_map<string, std::vector<const NodeDef*>> node_to_output_nodes;
  std::unordered_map<const NodeDef*, int> inputs_needed;
  for (int n = 0; n < gdef->node_size(); ++n) {
    const NodeDef* ndef = &gdef->node(n);
    for (int i = 0; i < ndef->input_size(); ++i) {
      node_to_output_nodes[string(ParseTensorName(ndef->input(i)).first)]
          .push_back(ndef);
    }
    int64_t start_time;
    TF_RETURN_IF_ERROR(GetNodeAttr(*ndef, "_start_time", &start_time));
    node_to_start_time[ndef] = start_time;
    inputs_needed[ndef] = ndef->input_size();
    if (ndef->input_size() == 0) {
      enqueue(ndef);
    }
  }

  // Determine which merge nodes are part of loops; these
  // need to happen in the traversal after all non-NextIteration inputs
  // are run.
  for (int n = 0; n < gdef->node_size(); ++n) {
    const NodeDef* ndef = &gdef->node(n);
    if (IsNextIteration(*ndef)) {
      for (const NodeDef* n : node_to_output_nodes[ndef->name()]) {
        if (IsMerge(*n)) {
          // n is a merge that is part of a loop structure.
          // It doesn't need to wait for this NextIteration loop
          // when doing the traversal.
          --inputs_needed[n];
        }
      }
    }
  }

  // Traverse.
  std::vector<std::pair<const NodeDef*, int64_t>> start_times;
  start_times.reserve(gdef->node_size());
  while (!q.empty()) {
    PriorityTopoSortNode cur = q.top();
    q.pop();

    start_times.emplace_back(cur.node, cur.start_time);

    for (const NodeDef* n : node_to_output_nodes[cur.node->name()]) {
      auto& output_start_time = node_to_start_time[n];
      if (output_start_time <= cur.start_time) {
        output_start_time = cur.start_time + 1;
      }
      if (--inputs_needed[n] == 0) {
        enqueue(n);
      }
    }
  }

  // Done.
  nodes->swap(start_times);
  node_to_start_time_out->swap(node_to_start_time);
  return OkStatus();
}

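// For each partition, slices the schedule into fixed-size epochs, adds one
// ControlTrigger node per epoch (chained to a node of the preceding epoch),
// and gives each _Recv a control input from the trigger 'prefetch' epochs
// before its start time, so that recvs are not issued arbitrarily early.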
Status AddControlEdges(const PartitionOptions& opts,
                       std::unordered_map<string, GraphDef>* partitions) {
  Status status;
  // TODO(yuanbyu): Very naive for now. To be improved.
  const int num_epochs = 100;
  const int prefetch = 6;

  for (auto& part : *partitions) {
    GraphDef* gdef = &part.second;
    std::vector<std::pair<const NodeDef*, int64_t>> start_times;
    std::unordered_map<const NodeDef*, int64_t> node_to_start_time;
    status = TopologicalSortNodesWithTimePriority(gdef, &start_times,
                                                  &node_to_start_time);
    if (!status.ok()) {
      return status;
    }

    // Add a dummy node for every epoch, and add a control edge from the
    // "last" node in the preceding epoch to the dummy node.
    string device_name = gdef->node(0).device();
    int64_t makespan = start_times.back().second;
    int64_t resolution = (makespan / num_epochs) + 1;

    int i = 0;
    int j = 0;
    std::vector<NodeDef*> dummys;
    while (i < num_epochs && static_cast<size_t>(j) < start_times.size()) {
      if (i * resolution > start_times[j].second) {
        j++;
      } else {
        NodeDef* dummy = AddControlTrigger(opts, gdef, device_name, i,
                                           i * resolution, &status);
        if (!status.ok()) {
          return status;
        }
        dummys.push_back(dummy);
        if (j > 0) {
          string src_name = start_times[j - 1].first->name();
          AddInput(dummy, src_name, Graph::kControlSlot);
        }
        i++;
      }
    }

    // Finally, add the control edges to recvs.
    for (int n = 0; n < gdef->node_size(); ++n) {
      NodeDef* ndef = gdef->mutable_node(n);
      if (ndef->op() == "_Recv") {
        const int64_t start_time = node_to_start_time[ndef];
        const int recv_epoch = start_time / resolution;
        if (recv_epoch >= prefetch) {
          NodeDef* dummy = dummys[recv_epoch - prefetch];
          AddInput(ndef, dummy->name(), Graph::kControlSlot);
        }
      }
    }
  }
  return OkStatus();
}

// If 'ndef' is a Send or Recv, fills its attr send_device_incarnation
// if possible.
void SetIncarnation(const PartitionOptions& opts, NodeDef* ndef) {
  StringPiece op(ndef->op());
  if (op != "_Send" && op != "_Recv") {
    // Not related to send/recv.
    return;
  }
  const string& send_device = GetNodeAttrString(*ndef, "send_device");
  if (send_device.empty()) {
    // No known send_device. The runtime will detect it later.
    return;
  }
  int64_t incarnation = PartitionOptions::kIllegalIncarnation;
  if (!TryGetNodeAttr(*ndef, "send_device_incarnation", &incarnation) ||
      (incarnation == PartitionOptions::kIllegalIncarnation)) {
    incarnation = opts.get_incarnation(send_device);
    SetAttrValue(incarnation,
                 &((*ndef->mutable_attr())["send_device_incarnation"]));
  }
}

// Sets attribute send_device_incarnation of all Send/Recv nodes in
// 'gdef', if possible.
void SetIncarnation(const PartitionOptions& opts, GraphDef* gdef) {
  for (NodeDef& ndef : *gdef->mutable_node()) {
    SetIncarnation(opts, &ndef);
  }
  for (FunctionDef& fdef : *gdef->mutable_library()->mutable_function()) {
    for (NodeDef& ndef : *fdef.mutable_node_def()) {
      SetIncarnation(opts, &ndef);
    }
  }
}

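// Partitions 'g' into one GraphDef per location returned by
// opts.node_to_loc(), rewriting every edge that crosses a partition boundary
// (and every same-device edge that needs a host/device copy) into a matching
// send/recv pair.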
Status Partition(const PartitionOptions& opts, Graph* g,
                 std::unordered_map<string, GraphDef>* partitions) {
  Status status;
  partitions->clear();

  GraphInfo g_info;
  if (!opts.control_flow_added) {
    // Add the "code" for distributed execution of control flow. Code is
    // added only for the frames that are placed on multiple devices. The
    // new graph is an equivalent transformation of the original graph and
    // has the property that it can be subsequently partitioned arbitrarily
    // (down to the level of individual device) for distributed execution.
    status = AddControlFlow(opts, g, &g_info);
    if (!status.ok()) return status;
  }

  // At this point, all the graph mutations have been done. Build memory
  // and device type info for every node and edge in the graph.
  status = BuildMemoryDeviceInfo(*g, &g_info);
  if (!status.ok()) return status;

  string dstp;
  std::vector<const Edge*> inputs;
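  // Recv nodes already created, keyed by DupRecvKey: multiple edges that
  // carry the same tensor into a partition share a single send/recv pair.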
  DupRecvTable dup_recv(3);
  // For a node dst, 'ref_recvs' remembers the recvs introduced by a ref
  // edge to dst. 'ref_control_inputs' remembers the inputs via a non-ref
  // edge to dst. We will add a control edge for every pair in
  // (ref_recvs x ref_control_inputs).
  std::vector<NodeDef*> ref_recvs;
  std::vector<string> ref_control_inputs;

  int32_t num_data = 0;
  int32_t num_control = 0;
  for (const Node* dst : g->op_nodes()) {
    dstp = opts.node_to_loc(dst);
    GraphDef* dst_graph = &(*partitions)[dstp];
    NodeDef* dst_def = dst_graph->add_node();
    *dst_def = dst->def();
    MergeDebugInfo(NodeDebugInfo(dst->def()), dst_def);
    dst_def->set_device(dst->assigned_device_name());
    dst_def->clear_input();  // Inputs are filled below
    if (opts.need_to_record_start_times) {
      int64_t start_time;
      status = GetNodeAttr(*dst_def, "_start_time", &start_time);
      if (errors::IsNotFound(status)) {
        start_time = opts.start_times[dst->id()].value();
        AddNodeAttr("_start_time", start_time, dst_def);
      } else if (!status.ok()) {
        return status;
      }
    }

    // Arrange the incoming edges to dst so that input[i] holds the
    // input flowing into slot numbered i. Trailing entries in input[]
    // hold control edges.
    inputs.clear();
    inputs.resize(dst->num_inputs(), nullptr);
    ref_recvs.clear();
    ref_control_inputs.clear();
    const Edge* control_flow_edge = nullptr;
    int32_t num_control_flow_edges = 0;
    int32_t num_input_edges = 0;
    for (const Edge* edge : dst->in_edges()) {
      if (edge->IsControlEdge()) {
        if (IsMerge(edge->src()) && IsControlLoop(edge->src())) {
          // This is one of the control edges added for control flow. There
          // can be multiple such edges as the dest node may have multiple
          // remote inputs. We keep track of the number of such edges.
          control_flow_edge = edge;
          ++num_control_flow_edges;
        } else {
          inputs.push_back(edge);
        }
      } else {
        DCHECK(inputs[edge->dst_input()] == nullptr);
        inputs[edge->dst_input()] = edge;
        ++num_input_edges;
      }
    }

    if (num_input_edges != dst->num_inputs()) {
      return errors::InvalidArgument("Incomplete graph, missing ",
                                     (dst->num_inputs() - num_input_edges),
                                     " inputs for ", dst->name());
    }

    // Process in order so that all data edges are added as inputs to
    // dst in Edge::dst_input() order.
    for (const Edge* edge : inputs) {
      const Node* src = edge->src();
      if (!src->IsOp()) continue;  // Skip Sink/Source nodes.

      GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)];
      if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) {
        // Same partition and compatible memory types:
        AddInput(dst_def, src->name(), edge->src_output());
        if (edge->IsControlEdge() ||
            !IsRefType(src->output_type(edge->src_output()))) {
          ref_control_inputs.push_back(src->name());
        }
        continue;
      }

      int64_t send_start_time = 0;
      int64_t recv_start_time = 0;
      if (opts.scheduling_for_recvs) {
        status = GetNodeAttr(src->attrs(), "_start_time", &send_start_time);
        if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
          send_start_time = opts.start_times[src->id()].value();
        } else if (!status.ok()) {
          return status;
        }

        status = GetNodeAttr(dst->attrs(), "_start_time", &recv_start_time);
        if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
          recv_start_time = opts.start_times[dst->id()].value();
        } else if (!status.ok()) {
          return status;
        }
      }

      // Check whether there is already a send/recv pair transferring
      // the same tensor/control from the src to dst partition.
      const bool on_host = IsDstInputOnHost(edge, g_info);
      DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host};
      auto iter = dup_recv.find(key);
      if (iter != dup_recv.end()) {
        // We found one. Reuse the data/control transferred already.
        const string& recv_node_name = iter->second.recv->name();
        if (edge->IsControlEdge()) {
          AddInput(dst_def, recv_node_name, Graph::kControlSlot);
        } else {
          AddInput(dst_def, recv_node_name, 0);
        }
        ref_control_inputs.push_back(recv_node_name);

        // We want the start_time for the recv to be the smallest of the start
        // times of its consumers. So we update this whenever we use a recv,
        // and write it out to the attribute at the end of the subroutine.
        if (iter->second.start_time > recv_start_time) {
          iter->second.start_time = recv_start_time;
        }
        continue;
      }

      NodeDefBuilder::NodeOut send_from;
      if (edge->IsControlEdge()) {
        // Insert a dummy const node that will generate a tiny
        // data element to be sent from send to recv.
        VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "["
                << src->name() << "] -> " << dst->assigned_device_name() << "["
                << dst->name() << "]";
        NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status);
        if (!status.ok()) return status;
        // Set the start time for this dummy node.
        if (opts.scheduling_for_recvs) {
          AddNodeAttr("_start_time", send_start_time, dummy);
        }
        AddInput(dummy, src->name(), Graph::kControlSlot);
        send_from.Reset(dummy->name(), 0, DT_FLOAT);
      } else {
        send_from.Reset(src->name(), edge->src_output(), EdgeType(edge));
      }

      string tensor_name_attr;
      if (opts.get_tensor_name_attr) {
        tensor_name_attr = opts.get_tensor_name_attr(edge);
      } else {
        tensor_name_attr =
            strings::StrCat("edge_", edge->id(), "_", edge->src()->name());
      }

      // Need to split edge by placing matching send/recv nodes on
      // the src/dst sides of the edge.
      NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from,
                              send_start_time, tensor_name_attr, &status);
      if (!status.ok()) return status;

      NodeDef* real_recv = nullptr;
      NodeDef* recv = AddRecv(opts, g_info, dst_graph, edge, &real_recv,
                              tensor_name_attr, &status);
      if (!status.ok()) return status;

      // Fix up the control flow edge.
      // NOTE(yuanbyu): 'real_recv' must be the real recv node.
      if (src_graph == dst_graph) {
        // For same device send/recv, add a control edge from send to recv.
        // This prevents the asynchronous recv kernel from being scheduled
        // before the data is available.
        AddInput(real_recv, send->name(), Graph::kControlSlot);
      } else if (control_flow_edge != nullptr) {
        // Redirect control edge to the real recv since this is not the same
        // device send/recv.
        --num_control_flow_edges;
        AddInput(real_recv, control_flow_edge->src()->name(),
                 Graph::kControlSlot);
      }

      if (!edge->IsControlEdge() &&
          IsRefType(src->output_type(edge->src_output()))) {
        AddNodeAttr("_start_time", recv_start_time, recv);
        if (real_recv != recv) {
          AddNodeAttr("_start_time", recv_start_time, real_recv);
        }
        // If src is of ref type and the edge is not a control edge, dst has
        // read semantics and therefore we must control the recv.
        ref_recvs.push_back(real_recv);
      } else {
        // Memorize the send/recv pair, only if this is not a "ref" edge.
        // NOTE(yuanbyu): Collapsing ref edges requires extreme care so
        // for now we don't do it.
        dup_recv[key] = {recv, real_recv, recv_start_time};
        ref_control_inputs.push_back(recv->name());
      }

      if (edge->IsControlEdge()) {
        ++num_control;
        AddInput(dst_def, recv->name(), Graph::kControlSlot);
      } else {
        ++num_data;
        AddInput(dst_def, recv->name(), 0);
      }
    }

    // Add control edges from 'ref_control_inputs' to 'ref_recvs'.
    // NOTE(yuanbyu): Adding these control edges should not introduce
    // deadlocks. 'dst' has implicit "read" nodes that, when we split
    // across devices, are made explicit. Retargeting the dependencies
    // on 'dst' to those nodes would not introduce cycles if there wasn't
    // one before the transformation.
    // NOTE(yuanbyu): This may impact performance because it defers the
    // execution of recvs until all the other inputs become available.
    AddReadControl(ref_recvs, ref_control_inputs);

    // Add back the control edges for control flow that are not used.
    if (control_flow_edge != nullptr) {
      for (int i = 0; i < num_control_flow_edges; ++i) {
        AddInput(dst_def, control_flow_edge->src()->name(),
                 Graph::kControlSlot);
      }
    }
  }

  const FunctionLibraryDefinition* flib_def = opts.flib_def;
  if (flib_def == nullptr) {
    flib_def = &g->flib_def();
  }

  // Set versions, function library and send/recv incarnation.
  for (auto& it : *partitions) {
    GraphDef* gdef = &it.second;
    *gdef->mutable_versions() = g->versions();
    // Prune unreachable functions from `flib_def` before adding them to
    // `gdef`.
    *gdef->mutable_library() = flib_def->ReachableDefinitions(*gdef).ToProto();

    // Traverse the graph to fill every send/recv op's incarnation
    // information.
    SetIncarnation(opts, gdef);
  }

  // Set the start times for recvs at the very end.
  if (opts.scheduling_for_recvs) {
    for (auto& it : dup_recv) {
      AddNodeAttr("_start_time", it.second.start_time, it.second.recv);
      if (it.second.real_recv != it.second.recv) {
        AddNodeAttr("_start_time", it.second.start_time, it.second.real_recv);
      }
    }
  }

  VLOG(1) << "Added send/recv: controls=" << num_control
          << ", data=" << num_data;
  if (VLOG_IS_ON(2)) {
    for (auto& it : *partitions) {
      GraphDef* gdef = &it.second;
      DumpGraphDefToFile(strings::StrCat("partition_", it.first, "_",
                                         reinterpret_cast<uintptr_t>(gdef)),
                         *gdef);
    }
  }
  return OkStatus();
}

}  // namespace tensorflow