1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
17
18 #include <functional>
19 #include <memory>
20 #include <numeric>
21 #include <string>
22 #include <unordered_map>
23 #include <vector>
24
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/match.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/types/optional.h"
29 #include "tensorflow/compiler/jit/flags.h"
30 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
31 #include "tensorflow/compiler/jit/shape_inference_helpers.h"
32 #include "tensorflow/compiler/jit/xla_cluster_util.h"
33 #include "tensorflow/compiler/tf2xla/const_analysis.h"
34 #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
35 #include "tensorflow/compiler/xla/status_macros.h"
36 #include "tensorflow/core/common_runtime/device_factory.h"
37 #include "tensorflow/core/common_runtime/function.h"
38 #include "tensorflow/core/common_runtime/optimization_registry.h"
39 #include "tensorflow/core/common_runtime/shape_refiner.h"
40 #include "tensorflow/core/framework/function.h"
41 #include "tensorflow/core/framework/graph_def_util.h"
42 #include "tensorflow/core/framework/graph_to_functiondef.h"
43 #include "tensorflow/core/framework/node_def_builder.h"
44 #include "tensorflow/core/framework/node_def_util.h"
45 #include "tensorflow/core/framework/tensor.pb.h"
46 #include "tensorflow/core/graph/algorithm.h"
47 #include "tensorflow/core/graph/control_flow.h"
48 #include "tensorflow/core/graph/graph.h"
49 #include "tensorflow/core/graph/graph_def_builder.h"
50 #include "tensorflow/core/graph/tensor_id.h"
51 #include "tensorflow/core/lib/gtl/map_util.h"
52 #include "tensorflow/core/lib/hash/hash.h"
53 #include "tensorflow/core/public/session_options.h"
54 #include "tensorflow/core/public/version.h"
55 #include "tensorflow/core/util/device_name_utils.h"
56 #include "tensorflow/core/util/dump_graph.h"
57
58 namespace tensorflow {
59
60 const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel";
61 const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs";
62 const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
63 const char* const kXlaHostTransferSequencerAttr =
64 "_xla_host_transfer_sequencer";
65 const char* const kXlaHasReferenceVarsAttr = "_XlaHasReferenceVars";
66
67 namespace {
68
69 bool AreAllParentsGuaranteedConst(
70 const Node& n,
71 const absl::flat_hash_set<const Node*>& runtime_const_nodes) {
72 if (n.type_string() == "GuaranteeConst") {
73 // If the current node is itself a cast-to-const, no need
74 // to look at the incoming edges.
75 return true;
76 }
77
78 bool all_parents_const = true;
79 bool atleast_one_non_control_edge = false;
80 for (const Edge* in : n.in_edges()) {
81 atleast_one_non_control_edge =
82 atleast_one_non_control_edge || !in->IsControlEdge();
83 if (!in->IsControlEdge() && runtime_const_nodes.count(in->src()) == 0) {
84 all_parents_const = false;
85 break;
86 }
87 }
88 return all_parents_const && atleast_one_non_control_edge;
89 }
90
91 void MarkGuaranteedConstants(
92 const Graph& graph,
93 const std::vector<std::pair<const Node*, Node*>>& src_arg_pairs) {
94 absl::flat_hash_set<const Node*> guaranteed_const_nodes;
95 std::vector<const Node*> srcs;
96 srcs.reserve(src_arg_pairs.size());
97 for (const auto& src_arg : src_arg_pairs) {
98 srcs.push_back(src_arg.first);
99 }
100 ReverseDFSFrom(
101 graph, srcs, /*enter=*/nullptr,
102 /*leave=*/[&guaranteed_const_nodes](const Node* n) {
103 // TODO(vinuraja): Doesn't work in the presence of loops.
104 if (AreAllParentsGuaranteedConst(*n, guaranteed_const_nodes)) {
105 guaranteed_const_nodes.insert(n);
106 }
107 });
108
109 for (auto& src_arg : src_arg_pairs) {
110 if (guaranteed_const_nodes.count(src_arg.first) != 0) {
111 VLOG(1) << "Guaranteed const found: " << src_arg.first->DebugString();
112 src_arg.second->AddAttr("_is_guaranteed_constant", true);
113 }
114 }
115 }
116
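// Hash functor for (producer tensor, consumer tensor) pairs. It is used below
// by AddEdgesToOutputGraph to deduplicate edges copied into the output graph.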
117 struct OutputInputTensorPairHasher {
118   uint64 operator()(std::pair<OutputTensor, InputTensor> const& s) const {
119 return Hash64Combine(OutputTensor::Hash()(s.first),
120 InputTensor::Hash()(s.second));
121 }
122 };
123
124 // TODO(phawkins) add a canonical copy of these operator names and refactor
125 // everything to use it.
126 static const char* const kArgOp = "_Arg";
127 static const char* const kRetValOp = "_Retval";
128
129 class Encapsulator {
130 public:
131   Encapsulator(string group_attribute, Graph const* graph_in)
132 : group_attribute_(std::move(group_attribute)), graph_in_(graph_in) {}
133
134 // Find subgraphs marked with 'group_attribute', and build a new
135 // subgraph, one for each value of 'group_attribute'.
136 Status SplitIntoSubgraphs(FunctionLibraryDefinition* library);
137
138   // Build a FunctionDef for each subgraph, and add it to 'library'. The values of
139 // the 'group_attribute' annotations become the function names.
140 // If 'reuse_existing_functions' is set, use an existing function with the
141 // same name, if any.
142 // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
143 // function conversion.
144 Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn,
145 bool reuse_existing_functions,
146 FunctionLibraryDefinition* library);
147
148 // Write a copy of the input graph to 'graph_out', where the subgraphs are
149 // replaced with calls to the new functions.
150 Status BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library);
151
152 private:
153 // A subgraph of the input, all marked with a common 'group_attribute'
154 // value.
155 //
156 // In the following simple example, A, B, ..., E are nodes in the original
157 // graph. The group attributes g are each shown as either 0 or empty.
158 //
159 // A --> B --> C --> D --> E
160 // g: g:0 g:0 g:0 g:
161 //
162 // The example is rewritten to two graphs; one on the host and one to be
163 // compiled. The host graph is as follows.
164 //
165 // A --> Call --> E
166 //
167 // The compiled cluster is as follows.
168 //
169 // Arg --> B --> C --> D --> Retval
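  //
  // Illustrative sketch (not from the original comments): with
  // group_attribute = "_XlaCluster", nodes B, C and D above would each carry a
  // NodeDef attribute such as
  //
  //   attr { key: "_XlaCluster" value { s: "0" } }
  //
  // and the function built for cluster "0" takes its name from that attribute
  // value, with the call node in the host graph invoking it.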
170 class Subgraph {
171 public:
172 // Creates a graph to build the subgraph in, if it doesn't already exist,
173 // using the same op registry and versions as graph_in.
174 Node* MakeNodeImage(const Graph* graph_in, Node* node);
175
176 // Returns the graph the subgraph is being built in.
177 Graph* GetGraph() const;
178
179 // Builds a FunctionDef, and adds it to 'library'. The value of the
180 // 'group_attribute' annotations becomes the function name. If
181 // 'reuse_existing_functions' is set, use an existing function with the same
182 // name, if any. If 'rewrite_subgraph_fn' is set, it is applied to the
183 // subgraph before function conversion.
184 Status BuildFunctionDef(const string& name_in,
185 const RewriteSubgraphFn& rewrite_subgraph_fn,
186 bool reuse_existing_functions,
187 FunctionLibraryDefinition* library);
188
189 // Adds the function call node to graph_out.
190 Status AddFunctionCallNode(
191 const std::unordered_map<const Node*, Node*>& node_images,
192 Graph* graph_out);
193
194 // Returns the Node that the inputs and outputs of the function should be
195 // wired up to.
196 Node* GetCallNode() const;
197
198 // Returns the index of the arg that the dst of edge should connect to.
199 int GetArgIndexForEdge(const Edge* edge) const;
200
201 // Returns the index of the result that the src of edge should connect to.
202 int GetResultIndexForEdge(const Edge* edge) const;
203
204     // Creates an _Arg node for the src node of edge, and adds its index to
205 // args_by_src_, if none exists yet. Also adds its index to args_by_dst_,
206 // and adds the edge within the subgraph from the _Arg node to the image of
207 // the dst node.
208 Status RecordArg(const Edge* edge,
209 const std::unordered_map<const Node*, Node*>& node_images,
210 std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
211
212 // Records the src of the given edge as a control result of the graph.
213 // Used during graph to function conversion to tie control results to
214 // the function signature.
215 Status RecordControlResult(
216 const Edge* edge,
217 const std::unordered_map<const Node*, Node*>& node_images);
218
219     // Creates a _Retval node for the src node of edge, and adds it to results_,
220 // if none exists yet. If a new _Retval node is created, also adds the edge
221 // within the subgraph from the src to the _Retval node.
222 Status RecordResult(
223 const Edge* edge,
224 const std::unordered_map<const Node*, Node*>& node_images);
225
226 // Creates the sequencer node if it doesn't exist, adding it to graph_out.
227 Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out);
228
229 // If there is a sequencer node, adds a control edge from the sequencer to
230 // the call node.
231 void ConnectSequencerToCallNode(Graph* graph_out);
232
233 Status ReplaceFunctionDef(FunctionLibraryDefinition* library);
234
235 private:
236 // The subgraph extracted from the input graph, suitable for being turned
237 // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are
238 // returned by _Retval nodes.
239 std::unique_ptr<Graph> graph_;
240
241 // Which device are these nodes on? Used to assign a device to the call
242 // node.
243 string device_;
244
245 // NodeDef for the function call node.
246 NodeDef call_node_def_;
247
248 // Name that is used for the call node. This may not be
249 // call_node_def_.name() if the client supplies a rewrite lambda.
250 string function_def_name_;
251
252 // Placeholder node simulating the host compute key in the output graph.
253 // Not owned.
254 Node* host_compute_key_placeholder_ = nullptr;
255
256 // Function call node in the output graph. Not owned.
257 Node* call_node_;
258
259 // Maps from source (producer node/slot) and destination
260 // (consumer node/slot) tensors in the input graph to _Arg numbers in
261 // the subgraph. The source map is one-to-one, whereas the dest map may be
262 // many-to-one.
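    // For example (illustrative, not part of the original comment): if tensor
    // A:0 outside the cluster feeds both B:0 and C:1 inside it, args_by_src_
    // holds a single entry for A:0, while args_by_dst_ maps both B:0 and C:1
    // to that same _Arg index.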
263 std::unordered_map<OutputTensor, int, OutputTensor::Hash> args_by_src_;
264 std::unordered_map<InputTensor, int, InputTensor::Hash> args_by_dst_;
265
266 // The arguments to the subgraph, in order.
267 std::vector<Node*> args_;
268
269 // Map from source tensor in the input graph to result #.
270 std::unordered_map<OutputTensor, int, OutputTensor::Hash> results_;
271
272 // Set of node names that are the source of a control output of the
273 // subgraph. We store strings here so that we can tolerate nodes being
274 // removed from the graph.
275 absl::flat_hash_set<string> control_output_nodes_;
276
277 // NoOp node in the output graph that is sequenced after the call node.
278 Node* sequencer_ = nullptr;
279 };
280
281   // Returns in 'attr' the function name associated with a node, i.e. the value of
282   // its 'group_attribute_' annotation, or the empty string if the attribute is not set.
283 Status GetFunctionNameAttr(Node const* node, string* attr) const;
284
285 // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to
286 // subgraphs for data edges that cross subgraph boundaries.
287 Status CopySubgraphEdges(
288 const std::unordered_map<const Node*, Node*>& node_images,
289 std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
290
291 // Copies all marked nodes to a subgraph. Does nothing for unmarked nodes.
292 Status CopySubgraphNodes(std::unordered_map<const Node*, Node*>* node_images);
293
294 // Copies all nodes that aren't in a compiled subgraph to the output graph.
295 Status CopyNodesToOutputGraph(
296 Graph* graph_out, std::unordered_map<const Node*, Node*>* node_images);
297
298 // Adds function call nodes for each compiled subgraph.
299 Status AddFunctionCallNodes(
300 const std::unordered_map<const Node*, Node*>& node_images,
301 Graph* graph_out);
302
303 // Finds the image of an edge source in the output graph. If the edge crosses
304 // a subgraph boundary it is the output of a call node, otherwise it is a node
305 // in the output graph.
306 Status FindOutputImageOfEdgeSrc(
307 const string& src_func_id, const string& dst_func_id,
308 const std::unordered_map<const Node*, Node*>& node_images,
309 const Node* original_src_node, Node** src_image);
310
311 // Finds an edge source slot in the output graph. If the edge crosses a
312 // subgraph boundary it is a slot on the output of a call node, otherwise it
313 // is a slot on a node in the output graph.
314 int FindOutputSlotOfEdgeSrc(const string& src_func_id,
315 const string& dst_func_id,
316 const Edge* edge);
317
318 // Finds the image of an edge destination in the output graph. If the edge
319 // crosses a subgraph boundary it is the input of a call node, otherwise it is
320 // a node in the output graph.
321 Status FindOutputImageOfEdgeDst(
322 const string& src_func_id, const string& dst_func_id,
323 const std::unordered_map<const Node*, Node*>& node_images,
324 const Node* original_dst_node, Node** dst_image);
325
326 // Finds an edge destination slot in the output graph. If the edge crosses a
327 // subgraph boundary it is a slot on the input of a call node, otherwise it is
328 // a slot on a node in the output graph.
329 int FindOutputSlotOfEdgeDst(const string& src_func_id,
330 const string& dst_func_id,
331 const Edge* edge);
332
333 // Copies a single edge to the output graph. The edge is either entirely
334 // within the output graph, or crosses into or out of a compiled subgraph.
335 Status CopyEdgeToOutputGraph(
336 const Edge* edge, const string& src_func_id, const string& dst_func_id,
337 const std::unordered_map<const Node*, Node*>& node_images,
338 Graph* graph_out,
339 std::unordered_set<std::pair<OutputTensor, InputTensor>,
340 OutputInputTensorPairHasher>* edges_added);
341
342 // Adds all edges to the output graph.
343 Status AddEdgesToOutputGraph(
344 const std::unordered_map<const Node*, Node*>& node_images,
345 Graph* graph_out);
346
347   // Makes a copy of graph containing only nodes that are ancestors of at least
348   // one node in 'sink_nodes' and stores it in 'pruned_graph'. On exit
349   // 'node_images' contains a mapping from nodes in graph to nodes in
350 // pruned_graph. All functions in the copied graph are inlined.
351 Status MakePrunedGraphCopyAndInline(
352 const Graph& graph, const std::vector<Node*>& sink_nodes,
353 std::unique_ptr<Graph>* pruned_graph,
354 std::unordered_map<const Node*, Node*>* node_images,
355 FunctionLibraryDefinition* library);
356
357 const string group_attribute_;
358 const Graph* graph_in_;
359
360 std::unordered_map<string, Subgraph> subgraphs_;
361
362 TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator);
363 };
364
365 namespace {
366
367 // Return in 'sorted' a topological sort of clusters according to the
368 // dependencies encoded in ancestors. clusters is the list of all clusters
369 // including clusters that are not present in the ancestors map. has_successors
370 // is the set of clusters that are ancestors of some other cluster.
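//
// Small worked example (illustrative only): with clusters = {A, B, C},
// ancestors = {B -> {A}, C -> {B}} and has_successors = {A, B}, the traversal
// is seeded with C (the only cluster without successors) and produces
// sorted = {A, B, C}.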
371 void TopologicalClusterSort(
372 const std::unordered_set<string>& clusters,
373 const std::unordered_set<string>& has_successors,
374 const std::unordered_map<string, std::unordered_set<string>>& ancestors,
375 std::vector<string>* sorted) {
376 // The nodes are placed in 'sorted' in topological order.
377 sorted->clear();
378 // We don't use the standard DFS because we are not operating on Node*
379 // objects.
380 struct Work {
381 string cluster;
382 bool leave;
383 };
384 std::set<string> visited;
385 std::vector<Work> stack;
386 // Seed the processing list with clusters that have no successors.
387 for (const auto& cluster : clusters) {
388 if (has_successors.find(cluster) == has_successors.end()) {
389 stack.push_back({cluster, false});
390 }
391 }
392 while (!stack.empty()) {
393 const Work item = stack.back();
394 stack.pop_back();
395 if (item.leave) {
396 sorted->push_back(item.cluster);
397 continue;
398 }
399
400 if (visited.find(item.cluster) != visited.end()) continue;
401 visited.insert(item.cluster);
402
403 stack.push_back({item.cluster, true});
404 const auto& iter = ancestors.find(item.cluster);
405 if (iter != ancestors.end()) {
406 for (const auto& ancestor : iter->second) {
407 stack.push_back({ancestor, false});
408 }
409 }
410 }
411 CHECK(sorted->size() == clusters.size());
412 }
413
414 } // namespace
415
416 Node* Encapsulator::Subgraph::GetCallNode() const { return call_node_; }
417
418 int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const {
419 return args_by_dst_.at(InputTensor(edge->dst(), edge->dst_input()));
420 }
421
422 int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const {
423 return results_.at(OutputTensor(edge->src(), edge->src_output()));
424 }
425
426 Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) {
427 if (!graph_) {
428 graph_.reset(new Graph(graph_in->op_registry()));
429 graph_->set_versions(graph_in->versions());
430 }
431
432 // TODO(b/116981129): Enhance how the device for the encapsulated subgraph is
433 // determined. In case of hard placement, ensure all the encapsulated nodes
434 // have the same requested device, which in turn will be the requested device
435 // for the entire encapsulated subgraph. In case of soft placement, use a
436 // deterministic approach to fill in the requested device. Handle co-location
437 // constraints similarly if they exist.
438 if (device_.empty()) {
439 device_ = node->assigned_device_name().empty()
440 ? node->requested_device()
441 : node->assigned_device_name();
442 }
443
444 return graph_->CopyNode(node);
445 }
446
447 Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); }
448
449 Status Encapsulator::Subgraph::RecordArg(
450 const Edge* edge, const std::unordered_map<const Node*, Node*>& node_images,
451 std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
452 Node* src_node = edge->src();
453 int src_slot = edge->src_output();
454 std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
455 bool inserted;
456 std::tie(iter, inserted) = args_by_src_.emplace(
457 OutputTensor(src_node, src_slot), args_by_src_.size());
458 int arg_index = iter->second;
459 if (inserted) {
460 NodeDef arg_def;
461 NodeDefBuilder builder(
462 absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp,
463 NodeDebugInfo(src_node->def()));
464 DataType dtype = edge->dst()->input_type(edge->dst_input());
465 builder.Attr("T", dtype);
466 builder.Attr("index", arg_index);
467 Status s = builder.Finalize(&arg_def);
468 if (!s.ok()) return s;
469
470 TF_ASSIGN_OR_RETURN(Node * arg, graph_->AddNode(arg_def));
471 src_arg_pairs->push_back({src_node, arg});
472 args_.push_back(arg);
473 }
474 Node* dst_node = edge->dst();
475 Node* dst_image = node_images.at(dst_node);
476 int dst_slot = edge->dst_input();
477 args_by_dst_[InputTensor(dst_node, dst_slot)] = arg_index;
478 graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot);
479 return OkStatus();
480 }
481
482 Status Encapsulator::Subgraph::RecordControlResult(
483 const Edge* edge,
484 const std::unordered_map<const Node*, Node*>& node_images) {
485 Node* src_node = edge->src();
486 Node* src_image = node_images.at(src_node);
487 control_output_nodes_.insert(src_image->name());
488 return OkStatus();
489 }
490
491 Status Encapsulator::Subgraph::RecordResult(
492 const Edge* edge,
493 const std::unordered_map<const Node*, Node*>& node_images) {
494 Node* src_node = edge->src();
495 Node* src_image = node_images.at(src_node);
496 int src_slot = edge->src_output();
497 std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
498 bool inserted;
499 std::tie(iter, inserted) =
500 results_.emplace(OutputTensor(src_node, src_slot), results_.size());
501 int ret_index = iter->second;
502 if (inserted) {
503 NodeDef ret_def;
504 NodeDefBuilder builder(
505 absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp,
506 NodeDebugInfo(src_node->def()));
507 DataType dtype = src_node->output_type(src_slot);
508 builder.Attr("T", dtype);
509 builder.Attr("index", ret_index);
510 builder.Input(src_image->name(), src_slot, dtype);
511 Status s = builder.Finalize(&ret_def);
512 if (!s.ok()) return s;
513 TF_ASSIGN_OR_RETURN(Node * ret, graph_->AddNode(ret_def));
514 graph_->AddEdge(src_image, src_slot, ret, 0);
515 }
516 return OkStatus();
517 }
518
519 Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name,
520 Graph* graph_out) {
521 if (sequencer_ == nullptr) {
522 NodeDef seq_def;
523 // TODO(shikharagarwal): What source node should we use for errors?
524 NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp");
525 builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name);
526 builder.Device(device_);
527 Status s = builder.Finalize(&seq_def);
528 if (!s.ok()) return s;
529
530 TF_ASSIGN_OR_RETURN(sequencer_, graph_out->AddNode(seq_def));
531 }
532 return OkStatus();
533 }
534
535 void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) {
536 if (sequencer_ != nullptr) {
537 VLOG(2) << "ConnectSequencerToCallNode";
538 graph_out->AddControlEdge(sequencer_, call_node_,
539 /* allow_duplicates= */ true);
540 }
541 }
542
543 Status Encapsulator::Subgraph::BuildFunctionDef(
544 const string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
545 bool reuse_existing_functions, FunctionLibraryDefinition* library) {
546 // name_in is copied here because name may be modified below if
547   // rewrite_subgraph_fn is set.
548 string name = name_in;
549 call_node_def_.set_op(name);
550 call_node_def_.set_name(name);
551 call_node_def_.set_device(device_);
552
553 if (rewrite_subgraph_fn) {
554 std::vector<OutputTensor> arg_source_tensors(args_by_src_.size());
555 for (const auto& arg : args_by_src_) {
556 arg_source_tensors.at(arg.second) = arg.first;
557 }
558 // Initialize the input and output permutations to the identity.
559 std::vector<int> input_permutation(args_by_src_.size());
560 std::iota(input_permutation.begin(), input_permutation.end(), 0);
561 std::vector<int> output_permutation(results_.size());
562 std::iota(output_permutation.begin(), output_permutation.end(), 0);
563
564 TF_RETURN_IF_ERROR(
565 rewrite_subgraph_fn(arg_source_tensors, &graph_, &input_permutation,
566 &output_permutation, &call_node_def_));
567
568 // Apply the input/output permutations to the 'args_by_...' and 'results_'
569 // mappings, so when we build edges in BuildOutputGraph() we
570 // connect them to the right input/output positions.
571 if (input_permutation.size() != args_by_src_.size()) {
572 return errors::InvalidArgument("Input permutation has incorrect size.");
573 }
574 if (output_permutation.size() != results_.size()) {
575 return errors::InvalidArgument("Output permutation has incorrect size.");
576 }
577 for (auto& arg : args_by_src_) {
578 arg.second = input_permutation[arg.second];
579 }
580 for (auto& arg : args_by_dst_) {
581 arg.second = input_permutation[arg.second];
582 }
583 for (auto& result : results_) {
584 result.second = output_permutation[result.second];
585 }
586
587 name = call_node_def_.op();
588 }
589
590 function_def_name_ = name;
591
592 FunctionDef fdef;
593 auto lookup = [this](const Node* node) -> std::optional<string> {
594 if (control_output_nodes_.contains(node->name())) {
595 return absl::make_optional(node->name());
596 }
597 return std::nullopt;
598 };
599 // Verify that the graph has well-formed control flow structure.
600 std::vector<ControlFlowInfo> dummy;
601 TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &dummy));
602 TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, lookup, &fdef));
603
604 if (VLOG_IS_ON(1)) {
605 VLOG(2) << "Build function def " << name;
606 DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name), *graph_,
607 library);
608 DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name), fdef);
609 }
610
611 const FunctionDef* original_fdef = library->Find(name);
612 if (!reuse_existing_functions || original_fdef == nullptr) {
613 TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
614 } else if (!FunctionDefsEqual(*original_fdef, fdef)) {
615 TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
616 }
617 return OkStatus();
618 }
619
620 Status Encapsulator::Subgraph::ReplaceFunctionDef(
621 FunctionLibraryDefinition* library) {
622 const string& name = function_def_name_;
623
624 FunctionDef fdef;
625 TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef));
626
627 if (VLOG_IS_ON(1)) {
628 VLOG(2) << "Replace function def " << name;
629 DumpGraphToFile(absl::StrCat("replace_encapsulate_fdef_graph_", name),
630 *graph_, library);
631 DumpFunctionDefToFile(absl::StrCat("replace_encapsulate_fdef_", name),
632 fdef);
633 }
634
635 TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
636 return OkStatus();
637 }
638
639 Status Encapsulator::Subgraph::AddFunctionCallNode(
640 const std::unordered_map<const Node*, Node*>& node_images,
641 Graph* graph_out) {
642 TF_ASSIGN_OR_RETURN(call_node_, graph_out->AddNode(call_node_def_));
643
644 // Copy the assigned device and the key_annotation over.
645 call_node_->set_assigned_device_name(device_);
646
647 return OkStatus();
648 }
649
650 Status Encapsulator::GetFunctionNameAttr(Node const* node, string* attr) const {
651 AttrSlice attrs = node->attrs();
652 attr->clear();
653 for (const auto& node_attr : attrs) {
654 if (node_attr.first == group_attribute_) {
655 TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string"));
656 *attr = node_attr.second.s();
657 break;
658 }
659 }
660 return OkStatus();
661 }
662
663 bool IsInSubgraph(const string& func_id) { return !func_id.empty(); }
664
665 Status Encapsulator::CopySubgraphNodes(
666 std::unordered_map<const Node*, Node*>* node_images) {
667 for (Node* node : graph_in_->op_nodes()) {
668 string func_id;
669 TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id));
670 if (!IsInSubgraph(func_id)) continue;
671
672 Subgraph& subgraph = subgraphs_[func_id];
673 Node* image = subgraph.MakeNodeImage(graph_in_, node);
674 image->ClearAttr(group_attribute_);
675 (*node_images)[node] = image;
676 }
677 return OkStatus();
678 }
679
680 Status Encapsulator::CopySubgraphEdges(
681 const std::unordered_map<const Node*, Node*>& node_images,
682 std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
683 for (const Edge* edge : graph_in_->edges()) {
684 string src_func_id;
685 TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id));
686 string dst_func_id;
687 TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id));
688 Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr);
689 Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr);
690
691 // Copy edges that are local to a subgraph.
692 if (IsInSubgraph(src_func_id) && IsInSubgraph(dst_func_id) &&
693 src_func_id == dst_func_id) {
694 Graph* g = subgraphs_[src_func_id].GetGraph();
695 if (edge->IsControlEdge()) {
696 g->AddControlEdge(src_image, dst_image,
697 /* allow_duplicates= */ true);
698 } else {
699 g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input());
700 }
701 continue;
702 }
703
704 // Record 'src' as an output of its subgraph, if applicable.
705 if (IsInSubgraph(src_func_id)) {
706 if (!edge->IsControlEdge()) {
707 DataType dtype = edge->src()->output_type(edge->src_output());
708 if (IsRefType(dtype)) {
709 return errors::InvalidArgument(
710 "Ref Tensors (e.g., Variables) are not supported as results: "
711 "tensor ",
712 edge->src()->name(), ":", edge->src_output());
713 }
714 }
715
716 Subgraph& src_subgraph = subgraphs_[src_func_id];
717 if (edge->IsControlEdge()) {
718 TF_RETURN_IF_ERROR(src_subgraph.RecordControlResult(edge, node_images));
719 } else {
720 TF_RETURN_IF_ERROR(src_subgraph.RecordResult(edge, node_images));
721 }
722 }
723
724 // Record 'dst' as an input of its subgraph, if applicable.
725 if (IsInSubgraph(dst_func_id)) {
726 // Look at the type of the destination not the source, since Ref output
727 // Tensors can be automatically cast to non-Ref Tensors at the
728 // destination.
729 if (!edge->IsControlEdge()) {
730 DataType dtype = edge->dst()->input_type(edge->dst_input());
731 if (IsRefType(dtype)) {
732 return errors::InvalidArgument(
733 "Ref Tensors (e.g., Variables) are not supported as args: "
734 "tensor ",
735 edge->src()->name(), ":", edge->src_output());
736 }
737 }
738
739 Subgraph& dst_subgraph = subgraphs_[dst_func_id];
740 // Ignore control edges entering the subgraph. We will lift them onto
741 // the enclosing call operators in BuildOutputGraph().
742 if (!edge->IsControlEdge()) {
743 TF_RETURN_IF_ERROR(
744 dst_subgraph.RecordArg(edge, node_images, src_arg_pairs));
745 }
746 }
747 }
748 return OkStatus();
749 }
750
751 Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) {
752 Status s;
753
754 // Map from input graph nodes to subgraph nodes.
755 std::unordered_map<const Node*, Node*> node_images;
756
757 // Each entry of src_arg_pairs is a pair whose first element is a node in the
758 // original graph that has an output edge in the subgraph, and whose second
759   // element is the arg node in the subgraph that it sends to. The vector is
760   // filled in by RecordArg() while edges are copied in CopySubgraphEdges().
761 std::vector<std::pair<const Node*, Node*>> src_arg_pairs;
762
763 TF_RETURN_IF_ERROR(CopySubgraphNodes(&node_images));
764 TF_RETURN_IF_ERROR(CopySubgraphEdges(node_images, &src_arg_pairs));
765 MarkGuaranteedConstants(*graph_in_, src_arg_pairs);
766
767 for (auto& entry : subgraphs_) {
768 Subgraph& subgraph = entry.second;
769 FixupSourceAndSinkEdges(subgraph.GetGraph());
770 }
771
772 if (VLOG_IS_ON(1)) {
773 // Dump subgraphs.
774 for (auto& entry : subgraphs_) {
775 DumpGraphToFile(
776 absl::StrCat("encapsulate_subgraphs_subgraph_", entry.first),
777 *entry.second.GetGraph(), library);
778 }
779 }
780
781 return s;
782 }
783
784 Status Encapsulator::BuildFunctionDefs(
785 const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
786 FunctionLibraryDefinition* library) {
787 for (auto& subgraph_entry : subgraphs_) {
788 string name = subgraph_entry.first;
789 Subgraph& subgraph = subgraph_entry.second;
790 TF_RETURN_IF_ERROR(subgraph.BuildFunctionDef(
791 name, rewrite_subgraph_fn, reuse_existing_functions, library));
792 }
793 return OkStatus();
794 }
795
796 Status Encapsulator::CopyNodesToOutputGraph(
797 Graph* graph_out, std::unordered_map<const Node*, Node*>* node_images) {
798 for (Node* node : graph_in_->op_nodes()) {
799 string func_id;
800 TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id));
801
802 // Don't copy nodes that are going to be encapsulated.
803 if (IsInSubgraph(func_id)) continue;
804
805 Node* image = graph_out->CopyNode(node);
806 (*node_images)[node] = image;
807 }
808 (*node_images)[graph_in_->source_node()] = graph_out->source_node();
809 (*node_images)[graph_in_->sink_node()] = graph_out->sink_node();
810 return OkStatus();
811 }
812
813 Status Encapsulator::AddFunctionCallNodes(
814 const std::unordered_map<const Node*, Node*>& node_images,
815 Graph* graph_out) {
816 for (auto& subgraph_entry : subgraphs_) {
817 TF_RETURN_IF_ERROR(
818 subgraph_entry.second.AddFunctionCallNode(node_images, graph_out));
819 }
820 return OkStatus();
821 }
822
823 Status Encapsulator::FindOutputImageOfEdgeSrc(
824 const string& src_func_id, const string& dst_func_id,
825 const std::unordered_map<const Node*, Node*>& node_images,
826 const Node* original_src_node, Node** src_image) {
827 if (IsInSubgraph(src_func_id)) {
828 // The edge is from a subgraph to a regular node in the output graph so
829 // use the subgraph's call node output.
830 *src_image = subgraphs_.at(src_func_id).GetCallNode();
831 } else {
832 // The source of the edge is in the output graph so use the node image in
833 // the output graph.
834 *src_image = node_images.at(original_src_node);
835 }
836 return OkStatus();
837 }
838
839 int Encapsulator::FindOutputSlotOfEdgeSrc(const string& src_func_id,
840 const string& dst_func_id,
841 const Edge* edge) {
842 if (IsInSubgraph(src_func_id)) {
843 const Subgraph& src_subgraph = subgraphs_.at(src_func_id);
844 // 'src' is in a subgraph and 'dst' is a regular node in the output
845 // graph. Use the corresponding call output instead.
846 return src_subgraph.GetResultIndexForEdge(edge);
847 } else {
848 // The source of the edge is in the output graph so use the regular edge
849 // slot.
850 return edge->src_output();
851 }
852 }
853
854 Status Encapsulator::FindOutputImageOfEdgeDst(
855 const string& src_func_id, const string& dst_func_id,
856 const std::unordered_map<const Node*, Node*>& node_images,
857 const Node* original_dst_node, Node** dst_image) {
858 if (IsInSubgraph(dst_func_id)) {
859 // The edge is to a subgraph from a regular node in the output graph so
860 // use the subgraph's call node input.
861 *dst_image = subgraphs_.at(dst_func_id).GetCallNode();
862 } else {
863 // The destination of the edge is in the output graph so use the node image
864 // in the output graph.
865 *dst_image = node_images.at(original_dst_node);
866 }
867 return OkStatus();
868 }
869
870 int Encapsulator::FindOutputSlotOfEdgeDst(const string& src_func_id,
871 const string& dst_func_id,
872 const Edge* edge) {
873 if (IsInSubgraph(dst_func_id)) {
874 const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id);
875 // 'dst' is in a subgraph and 'src' is a regular node in the output
876 // graph. Use the corresponding call input instead.
877 return dst_subgraph.GetArgIndexForEdge(edge);
878 } else {
879 // The destination of the edge is in the output graph so use the regular
880 // edge slot.
881 return edge->dst_input();
882 }
883 }
884
885 Status Encapsulator::CopyEdgeToOutputGraph(
886 const Edge* edge, const string& src_func_id, const string& dst_func_id,
887 const std::unordered_map<const Node*, Node*>& node_images, Graph* graph_out,
888 std::unordered_set<std::pair<OutputTensor, InputTensor>,
889 OutputInputTensorPairHasher>* edges_added) {
890 Node* src_image;
891 TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc(
892 src_func_id, dst_func_id, node_images, edge->src(), &src_image));
893 Node* dst_image;
894 TF_RETURN_IF_ERROR(FindOutputImageOfEdgeDst(
895 src_func_id, dst_func_id, node_images, edge->dst(), &dst_image));
896
897 // If this is a control edge then copy it and return. Lift control edges onto
898 // the enclosing call operator.
899 if (edge->IsControlEdge()) {
900 // Add the control edge, if we have not already added it, using the images
901 // determined above (potentially call operators or RecvAtHost/SendFromHost).
902 if (edges_added
903 ->emplace(OutputTensor(src_image, -1), InputTensor(dst_image, -1))
904 .second) {
905 graph_out->AddControlEdge(src_image, dst_image,
906 /* allow_duplicates= */ true);
907 }
908
909 return OkStatus();
910 }
911
912 int src_output = FindOutputSlotOfEdgeSrc(src_func_id, dst_func_id, edge);
913
914 int dst_input = FindOutputSlotOfEdgeDst(src_func_id, dst_func_id, edge);
915
916 // Add the edge, if we have not already added it.
917 if (edges_added
918 ->emplace(OutputTensor(src_image, src_output),
919 InputTensor(dst_image, dst_input))
920 .second) {
921 graph_out->AddEdge(src_image, src_output, dst_image, dst_input);
922 }
923 return OkStatus();
924 }
925
926 Status Encapsulator::AddEdgesToOutputGraph(
927 const std::unordered_map<const Node*, Node*>& node_images,
928 Graph* graph_out) {
929 // Set of edges already added to the output graph, represented as (src, dst)
930 // pairs. We use the set to deduplicate edges; multiple edges in the input
931 // graph may map to one edge in the output graph.
932 std::unordered_set<std::pair<OutputTensor, InputTensor>,
933 OutputInputTensorPairHasher>
934 edges_added;
935
936 for (const Edge* edge : graph_in_->edges()) {
937 string src_func_id;
938 TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id));
939 string dst_func_id;
940 TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id));
941
942 // Ignore edges that are strictly contained within one subgraph, unless
943 // we are constructing parallel check graphs.
944 if (IsInSubgraph(src_func_id) && IsInSubgraph(dst_func_id) &&
945 src_func_id == dst_func_id) {
946 continue;
947 }
948
949 // We have an edge that crosses a cluster boundary or is entirely within the
950 // unclustered graph.
951 TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph(
952 edge, src_func_id, dst_func_id, node_images, graph_out, &edges_added));
953 }
954
955 for (auto& subgraph_entry : subgraphs_) {
956 Subgraph& subgraph = subgraph_entry.second;
957 subgraph.ConnectSequencerToCallNode(graph_out);
958 }
959
960 return OkStatus();
961 }
962
963 namespace {
964
965 // Adds a dummy Const node to graph_out. The "constant" has the type of
966 // data_type and the shape indicated in 'shape'. The dummy node is not a valid
967 // Const node because it does not have any value defined, but this doesn't
968 // matter because it will only be used subsequently for shape inference. (It
969 // would be possible to add a switch statement over data_type to create a value
970 // for the constant, but that would entail maintaining the logic as new types
971 // are added, and is not necessary.) If the node being replaced was within a
972 // control flow frame, adds appropriate Enter nodes so that the use of the Const
973 // is well-formed.
974 Node* AddDummyShapedNode(const Node* src_node, int src_port,
975 const std::vector<ControlFlowInfo>& control_flow_info,
976 const TensorShapeProto& shape, Graph* graph_out) {
977 DataType data_type = src_node->output_type(src_port);
978 TensorProto dummy_proto;
979 dummy_proto.set_dtype(data_type);
980 *dummy_proto.mutable_tensor_shape() = shape;
981 // Don't set any value field in the proto, since it is only going to be used
982 // for shape inference.
983
984 GraphDefBuilder::Options options(graph_out, /*status=*/nullptr);
985 NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const",
986 options.op_registry());
987 node_builder.Attr("dtype", data_type).Attr("value", dummy_proto);
988 Node* node = options.FinalizeBuilder(&node_builder);
989 // Add any Enter nodes required to bring the constant to the correct control
990 // flow frame.
991 while (!control_flow_info[src_node->id()].frame_name.empty()) {
992 NodeDebugInfo debug_info(*src_node);
993 NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter",
994 options.op_registry(), &debug_info);
995 enter_builder.Attr("frame_name",
996 control_flow_info[src_node->id()].frame_name);
997 enter_builder.Attr("is_constant", true);
998 enter_builder.Input(node, 0);
999 Node* enter_node = options.FinalizeBuilder(&enter_builder);
1000 // Adopt the new Enter node as the value in the current frame.
1001 node = enter_node;
1002 // Recurse to the parent frame to see if more Enter nodes need to be added.
1003 src_node = control_flow_info[src_node->id()].parent_frame;
1004 }
1005 return node;
1006 }
1007
1008 } // namespace
1009
1010 Status Encapsulator::MakePrunedGraphCopyAndInline(
1011 const Graph& graph, const std::vector<Node*>& sink_nodes,
1012 std::unique_ptr<Graph>* pruned_graph,
1013 std::unordered_map<const Node*, Node*>* node_images,
1014 FunctionLibraryDefinition* library) {
1015 // First copy all ancestor nodes of sink_nodes into a new graph.
1016 pruned_graph->reset(new Graph(library));
1017 (*pruned_graph)->set_versions(graph.versions());
1018 ReverseDFSFrom(graph, sink_nodes,
1019 /*enter=*/nullptr,
1020 /*leave=*/[&](Node* n) {
1021 if (!n->IsSource()) {
1022 Node* copied = (*pruned_graph)->CopyNode(n);
1023 node_images->emplace(n, copied);
1024 }
1025 });
1026
1027 // Add all the edges between copied nodes.
1028 for (auto entry : *node_images) {
1029 const Node* orig = entry.first;
1030 Node* image = entry.second;
1031 for (const Edge* out_edge : orig->out_edges()) {
1032 auto iter = node_images->find(out_edge->dst());
1033 if (iter != node_images->end()) {
1034 // The source and destination are both in the copied graph.
1035 (*pruned_graph)
1036 ->AddEdge(image, out_edge->src_output(), iter->second,
1037 out_edge->dst_input());
1038 }
1039 }
1040 }
1041
1042 // Find all the function call nodes, and inline them.
1043 std::vector<Node*> function_nodes;
1044 for (auto node : (*pruned_graph)->nodes()) {
1045 const OpRegistrationData* op_reg_data;
1046 TF_RETURN_IF_ERROR(library->LookUp(node->type_string(), &op_reg_data));
1047 if (op_reg_data->is_function_op) {
1048 function_nodes.push_back(node);
1049 }
1050 }
1051 for (auto node : function_nodes) {
1052 VLOG(2) << "Inlining function " << node->name();
1053 const FunctionDef* fdef = library->Find(node->type_string());
1054 if (fdef == nullptr) {
1055 return errors::Internal("Failed to find function ", node->type_string(),
1056 " in function library.");
1057 }
1058 std::unique_ptr<FunctionBody> fbody;
1059 TF_RETURN_IF_ERROR(
1060 FunctionDefToBodyHelper(*fdef, node->attrs(), library, &fbody));
1061
1062 InlineFunctionBodyOptions inline_opts;
1063 TF_RETURN_IF_ERROR(InlineFunctionBody(*library, pruned_graph->get(), node,
1064 fbody.get(), inline_opts));
1065 }
1066
1067 return OkStatus();
1068 }
1069
1070 Status Encapsulator::BuildOutputGraph(Graph* graph_out,
1071 FunctionLibraryDefinition* library) {
1072 // Map from nodes in the input graph to nodes in the output graph.
1073 std::unordered_map<const Node*, Node*> node_images;
1074
1075 TF_RETURN_IF_ERROR(CopyNodesToOutputGraph(graph_out, &node_images));
1076 TF_RETURN_IF_ERROR(AddFunctionCallNodes(node_images, graph_out));
1077 TF_RETURN_IF_ERROR(AddEdgesToOutputGraph(node_images, graph_out));
1078
1079 return OkStatus();
1080 }
1081
1082 } // anonymous namespace
1083
1084 Status EncapsulateSubgraphsInFunctions(
1085 string group_attribute, const Graph& graph_in,
1086 const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
1087 std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library) {
1088 Encapsulator encapsulator(std::move(group_attribute),
1089 &graph_in);
1090 TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs(library));
1091
1092 TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs(
1093 rewrite_subgraph_fn, reuse_existing_functions, library));
1094
1095 std::unique_ptr<Graph> out(new Graph(library));
1096 out->set_versions(graph_in.versions());
1097 TF_RETURN_IF_ERROR(encapsulator.BuildOutputGraph(out.get(), library));
1098
1099 *graph_out = std::move(out);
1100 return OkStatus();
1101 }
1102
1103 // Finds the types of the _Arg nodes, indexed by position.
1104 static Status GetArgTypes(const Graph& graph, DataTypeVector* types) {
1105 for (Node* n : graph.op_nodes()) {
1106 if (n->type_string() == kArgOp) {
1107 int index;
1108 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
1109 const int num_types = types->size();
1110 if (index < 0 || index >= num_types) {
1111 return errors::InvalidArgument("Invalid argument number");
1112 }
1113 (*types)[index] = n->output_type(0);
1114 }
1115 }
1116 return OkStatus();
1117 }
1118
1119 // Renumbers the indices of _Arg nodes in a graph according to 'permutation',
1120 // which maps old indices to new indices.
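// For example (illustrative only): with permutation = {2, 0, 1}, the _Arg node
// whose "index" attribute was 0 is rewritten to index 2, index 1 becomes 0,
// and index 2 becomes 1.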
1121 static Status RenumberArguments(Graph* graph,
1122 const std::vector<int>& permutation) {
1123 for (Node* n : graph->op_nodes()) {
1124 if (n->type_string() == kArgOp) {
1125 int index;
1126 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
1127 const int permutation_size = permutation.size();
1128 if (index < 0 || index >= permutation_size) {
1129 return errors::InvalidArgument("Invalid argument number");
1130 }
1131 n->AddAttr("index", permutation[index]);
1132 }
1133 }
1134 return OkStatus();
1135 }
1136
1137 Status EncapsulateSubgraphsPass::Run(
1138 const GraphOptimizationPassOptions& options) {
1139 VLOG(1) << "EncapsulateSubgraphsPass::Run";
1140 if (VLOG_IS_ON(1)) {
1141 DumpGraphToFile("encapsulate_subgraphs_before", **options.graph,
1142 options.flib_def);
1143 }
1144
1145 // TODO(b/195757077): Remove this once there is a better way to disable
1146 // GraphOptimizationPasses that are not needed due to MLIR bridge.
1147 for (Node* n : (*options.graph)->nodes()) {
1148 // Skip the pass if we found TPUExecute or TPUExecuteAndUpdateVariables ops
1149 // in the graph, which indicates the graph is produced by TPU TF-XLA bridge
1150 // and doesn't require auto clustering.
1151 if (n->type_string() == "TPUExecute" ||
1152 n->type_string() == "TPUExecuteAndUpdateVariables") {
1153 return OkStatus();
1154 }
1155 }
1156
1157 std::unique_ptr<Graph> graph_out;
1158 FunctionLibraryDefinition* const library = options.flib_def;
1159
1160 // Constant folding below might need to run part of the function to compute
1161   // constants. Create a FunctionLibraryRuntime with a single CPU device
1162 // that can run the part of the function.
1163 // NOTE: If this turns out to be slow, we can cache the FLRs keyed by
1164 // `options`.
1165 SessionOptions session_options;
1166 auto* device_count = session_options.config.mutable_device_count();
1167 device_count->insert({"CPU", 1});
1168 std::vector<std::unique_ptr<Device>> devices;
1169
1170 DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
1171 if (!cpu_factory) {
1172 return errors::NotFound(
1173 "CPU Factory not registered. Can't run EncapsulateSubgraphsPass");
1174 }
1175 TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
1176 session_options, "/job:localhost/replica:0/task:0", &devices));
1177 if (devices.empty()) {
1178 return errors::NotFound(
1179 "Failed to create a CPU device for EncapsulateSubgraphsPass");
1180 }
1181
1182 std::unique_ptr<DeviceMgr> device_mgr =
1183 std::make_unique<StaticDeviceMgr>(std::move(devices));
1184 const auto* config = &options.session_options->config;
1185 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
1186 new ProcessFunctionLibraryRuntime(
1187 device_mgr.get(), options.session_options->env,
1188 /*config=*/config, TF_GRAPH_DEF_VERSION, library,
1189 config->graph_options().optimizer_options()));
1190 FunctionLibraryRuntime* flr =
1191 pflr->GetFLR("/job:localhost/replica:0/task:0/device:CPU:0");
1192 if (flr == nullptr) {
1193 return errors::Internal(
1194 "Failed to create and retrieve function library runtime to run "
1195 "constant folding");
1196 }
1197
1198 auto rewrite_subgraph =
1199 [flr](const std::vector<OutputTensor>& arg_source_tensors,
1200 std::unique_ptr<Graph>* subgraph,
1201 std::vector<int>* input_permutation,
1202 std::vector<int>* output_permutation, NodeDef* node) {
1203 // Optimize the subgraph.
1204 // Do not constant fold nodes that output DT_VARIANT type tensors.
1205 // XLA does not support Const nodes of Variant type since it needs
1206 // to know the original ops to be able to compile them to the relevant
1207 // XLA form.
1208 // TODO(srbs): This filter is a little conservative. E.g. a subgraph of
1209 // the form:
1210 // Const
1211 // |
1212 // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op
1213 // |
1214 // (Discard popped list)
1215 //
1216 // Would have been reduced to "Const -> Op" without this filter.
1217 // However since we are only allowed to specify the filter at the "Node"
1218 // level there is no good way to allow the above behavior. So we
1219 // disallow any sort of constant folding on Variant nodes for now.
1220 bool disable_constant_folding =
1221 GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding;
1222 auto cf_consider_fn = [disable_constant_folding](const Node* n) {
1223 if (disable_constant_folding) return false;
1224 for (const auto& output_arg : n->op_def().output_arg()) {
1225 if (output_arg.type() == DT_VARIANT) {
1226 return false;
1227 }
1228 }
1229 return true;
1230 };
1231 GraphOptimizer::Options graph_optimizer_options;
1232 graph_optimizer_options.cf_consider_fn = cf_consider_fn;
1233 OptimizeGraph(flr, subgraph, graph_optimizer_options);
1234
1235 const int num_args = input_permutation->size();
1236 std::vector<bool> const_args(num_args);
1237 TF_RETURN_IF_ERROR(
1238 BackwardsConstAnalysis(**subgraph, &const_args,
1239 /*compile_time_const_nodes=*/nullptr, flr));
1240
1241 DataTypeVector arg_types(num_args);
1242 TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types));
1243
1244 // Compute a permutation of the arguments such that the constant
1245 // arguments are first.
1246 const int num_consts =
1247 std::count(const_args.begin(), const_args.end(), true);
1248
1249 const int num_resources =
1250 std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE);
1251 const int num_nonconsts = num_args - num_resources - num_consts;
1252 if (num_nonconsts < 0) {
1253 return errors::Internal("num_nonconsts should be >= 0, was ",
1254 num_nonconsts);
1255 }
1256
1257 int const_pos = 0;
1258 int arg_pos = num_consts;
1259 int resource_pos = num_consts + num_nonconsts;
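          // Worked example (illustrative, not from the original source): with
          // const_args = {true, false, false, true} and
          // arg_types = {DT_FLOAT, DT_FLOAT, DT_RESOURCE, DT_INT32}, the loop
          // below produces input_permutation = {0, 2, 3, 1}: the two constant
          // args come first, then the non-constant float, then the resource.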
1260 for (int i = 0; i < num_args; ++i) {
1261 if (const_args[i]) {
1262 if (arg_types[i] == DT_RESOURCE) {
1263 return errors::Internal(
1264 "Resource arguments cannot be constant (argument ", i, ")");
1265 }
1266 (*input_permutation)[i] = const_pos;
1267 ++const_pos;
1268 } else if (arg_types[i] == DT_RESOURCE) {
1269 (*input_permutation)[i] = resource_pos;
1270 ++resource_pos;
1271 } else {
1272 (*input_permutation)[i] = arg_pos;
1273 ++arg_pos;
1274 }
1275 }
1276
1277 // Renumber argument nodes in the graph.
1278 TF_RETURN_IF_ERROR(
1279 RenumberArguments(subgraph->get(), *input_permutation));
1280
1281 // TODO(phawkins): add a forward is-constant analysis, similarly split
1282 // outputs into host-memory constants and device-memory non-constants.
1283
1284 AddNodeAttr(kXlaCompiledKernelAttr, true, node);
1285 AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node);
1286 AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node);
1287 return OkStatus();
1288 };
1289
1290 TF_RETURN_WITH_CONTEXT_IF_ERROR(
1291 EncapsulateSubgraphsInFunctions(
1292 kXlaClusterAttr, **options.graph, rewrite_subgraph,
1293 /*reuse_existing_functions=*/false, &graph_out, library),
1294 "EncapsulateSubgraphsPass failed");
1295 if (VLOG_IS_ON(1)) {
1296 DumpGraphToFile("encapsulate_subgraphs_after", *graph_out,
1297 options.flib_def);
1298 }
1299
1300 *options.graph = std::move(graph_out);
1301
1302 TF_ASSIGN_OR_RETURN(absl::flat_hash_set<Node*> ref_related_nodes,
1303 GetNodesRelatedToRefVariables(**options.graph, flr));
1304 for (Node* node : (*options.graph)->nodes()) {
1305 bool has_ref_vars = ref_related_nodes.contains(node);
1306 node->AddAttr(kXlaHasReferenceVarsAttr, has_ref_vars);
1307 VLOG(3) << "Has ref vars = " << has_ref_vars
1308 << ", node: " << node->def().DebugString();
1309 }
1310 return OkStatus();
1311 }
1312
1313 bool IsXlaCompiledKernel(const Node& node) {
1314 bool is_compiled = false;
1315 bool has_compilation_attr =
1316 TryGetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled) &&
1317 is_compiled;
1318 return has_compilation_attr ? is_compiled : false;
1319 }
1320
1321 } // namespace tensorflow
1322