/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <numeric>
21 #include <string>
22 #include <unordered_map>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/match.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/types/optional.h"
29 #include "tensorflow/compiler/jit/flags.h"
30 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
31 #include "tensorflow/compiler/jit/shape_inference_helpers.h"
32 #include "tensorflow/compiler/jit/xla_cluster_util.h"
33 #include "tensorflow/compiler/tf2xla/const_analysis.h"
34 #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
35 #include "tensorflow/compiler/xla/status_macros.h"
36 #include "tensorflow/core/common_runtime/device_factory.h"
37 #include "tensorflow/core/common_runtime/function.h"
38 #include "tensorflow/core/common_runtime/optimization_registry.h"
39 #include "tensorflow/core/common_runtime/shape_refiner.h"
40 #include "tensorflow/core/framework/function.h"
41 #include "tensorflow/core/framework/graph_def_util.h"
42 #include "tensorflow/core/framework/graph_to_functiondef.h"
43 #include "tensorflow/core/framework/node_def_builder.h"
44 #include "tensorflow/core/framework/node_def_util.h"
45 #include "tensorflow/core/framework/tensor.pb.h"
46 #include "tensorflow/core/graph/algorithm.h"
47 #include "tensorflow/core/graph/control_flow.h"
48 #include "tensorflow/core/graph/graph.h"
49 #include "tensorflow/core/graph/graph_def_builder.h"
50 #include "tensorflow/core/graph/tensor_id.h"
51 #include "tensorflow/core/lib/gtl/map_util.h"
52 #include "tensorflow/core/lib/hash/hash.h"
53 #include "tensorflow/core/public/session_options.h"
54 #include "tensorflow/core/public/version.h"
55 #include "tensorflow/core/util/device_name_utils.h"
56 #include "tensorflow/core/util/dump_graph.h"
57 
namespace tensorflow {

const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel";
const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs";
const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
const char* const kXlaHostTransferSequencerAttr =
    "_xla_host_transfer_sequencer";
const char* const kXlaHasReferenceVarsAttr = "_XlaHasReferenceVars";

namespace {

bool AreAllParentsGuaranteedConst(
    const Node& n,
    const absl::flat_hash_set<const Node*>& runtime_const_nodes) {
  if (n.type_string() == "GuaranteeConst") {
    // If the current node is itself a cast-to-const, no need
    // to look at the incoming edges.
    return true;
  }

  bool all_parents_const = true;
  bool atleast_one_non_control_edge = false;
  for (const Edge* in : n.in_edges()) {
    atleast_one_non_control_edge =
        atleast_one_non_control_edge || !in->IsControlEdge();
    if (!in->IsControlEdge() && runtime_const_nodes.count(in->src()) == 0) {
      all_parents_const = false;
      break;
    }
  }
  return all_parents_const && atleast_one_non_control_edge;
}
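
// Illustrative sketch (hypothetical three-node graph, not built by this
// pass): for a chain
//
//   Const(c) --> GuaranteeConst(g) --> Identity(i)
//
// g is marked immediately because it is a GuaranteeConst; i is marked only
// once g is already in runtime_const_nodes; c is never marked, because it has
// no non-control input edges, so atleast_one_non_control_edge stays false.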

void MarkGuaranteedConstants(
    const Graph& graph,
    const std::vector<std::pair<const Node*, Node*>>& src_arg_pairs) {
  absl::flat_hash_set<const Node*> guaranteed_const_nodes;
  std::vector<const Node*> srcs;
  srcs.reserve(src_arg_pairs.size());
  for (const auto& src_arg : src_arg_pairs) {
    srcs.push_back(src_arg.first);
  }
  ReverseDFSFrom(
      graph, srcs, /*enter=*/nullptr,
      /*leave=*/[&guaranteed_const_nodes](const Node* n) {
        // TODO(vinuraja): Doesn't work in the presence of loops.
        if (AreAllParentsGuaranteedConst(*n, guaranteed_const_nodes)) {
          guaranteed_const_nodes.insert(n);
        }
      });

  for (auto& src_arg : src_arg_pairs) {
    if (guaranteed_const_nodes.count(src_arg.first) != 0) {
      VLOG(1) << "Guaranteed const found: " << src_arg.first->DebugString();
      src_arg.second->AddAttr("_is_guaranteed_constant", true);
    }
  }
}
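
// Illustrative sketch (hypothetical node names): if src_arg_pairs contains
// {g, arg_for_g} for an input graph edge Const(c) -> GuaranteeConst(g) -> ...,
// the reverse DFS above marks g as guaranteed-const (it is a GuaranteeConst
// op), so arg_for_g receives _is_guaranteed_constant = true; c itself is not
// marked, since Const is not special-cased and c has no non-control inputs.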

struct OutputInputTensorPairHasher {
  uint64 operator()(std::pair<OutputTensor, InputTensor> const& s) const {
    return Hash64Combine(OutputTensor::Hash()(s.first),
                         InputTensor::Hash()(s.second));
  }
};

// TODO(phawkins) add a canonical copy of these operator names and refactor
// everything to use it.
static const char* const kArgOp = "_Arg";
static const char* const kRetValOp = "_Retval";

class Encapsulator {
 public:
  Encapsulator(string group_attribute, Graph const* graph_in)
      : group_attribute_(std::move(group_attribute)), graph_in_(graph_in) {}

  // Find subgraphs marked with 'group_attribute', and build a new
  // subgraph, one for each value of 'group_attribute'.
  Status SplitIntoSubgraphs(FunctionLibraryDefinition* library);

  // Build a FunctionDef for each subgraph, and add it to 'library'. The
  // values of the 'group_attribute' annotations become the function names.
  // If 'reuse_existing_functions' is set, use an existing function with the
  // same name, if any.
  // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
  // function conversion.
  Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn,
                           bool reuse_existing_functions,
                           FunctionLibraryDefinition* library);

  // Write a copy of the input graph to 'graph_out', where the subgraphs are
  // replaced with calls to the new functions.
  Status BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library);
 private:
  // A subgraph of the input, all marked with a common 'group_attribute'
  // value.
  //
  // In the following simple example, A, B, ..., E are nodes in the original
  // graph. The group attributes g are each shown as either 0 or empty.
  //
  //  A  -->  B  -->  C  -->  D  -->  E
  //  g:      g:0     g:0     g:0     g:
  //
  // The example is rewritten to two graphs; one on the host and one to be
  // compiled. The host graph is as follows.
  //
  //  A  -->  Call  -->  E
  //
  // The compiled cluster is as follows.
  //
  //  Arg  --> B  --> C  --> D --> Retval
  class Subgraph {
   public:
    // Creates a graph to build the subgraph in, if it doesn't already exist,
    // using the same op registry and versions as graph_in.
    Node* MakeNodeImage(const Graph* graph_in, Node* node);

    // Returns the graph the subgraph is being built in.
    Graph* GetGraph() const;

    // Builds a FunctionDef, and adds it to 'library'. The value of the
    // 'group_attribute' annotations becomes the function name.  If
    // 'reuse_existing_functions' is set, use an existing function with the same
    // name, if any.  If 'rewrite_subgraph_fn' is set, it is applied to the
    // subgraph before function conversion.
    Status BuildFunctionDef(const string& name_in,
                            const RewriteSubgraphFn& rewrite_subgraph_fn,
                            bool reuse_existing_functions,
                            FunctionLibraryDefinition* library);

    // Adds the function call node to graph_out.
    Status AddFunctionCallNode(
        const std::unordered_map<const Node*, Node*>& node_images,
        Graph* graph_out);

    // Returns the Node that the inputs and outputs of the function should be
    // wired up to.
    Node* GetCallNode() const;

    // Returns the index of the arg that the dst of edge should connect to.
    int GetArgIndexForEdge(const Edge* edge) const;

    // Returns the index of the result that the src of edge should connect to.
    int GetResultIndexForEdge(const Edge* edge) const;

    // Creates an _Arg node for the src node of edge, and adds its index to
    // args_by_src_, if none exists yet. Also adds its index to args_by_dst_,
    // and adds the edge within the subgraph from the _Arg node to the image of
    // the dst node.
    Status RecordArg(const Edge* edge,
                     const std::unordered_map<const Node*, Node*>& node_images,
                     std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);

    // Records the src of the given edge as a control result of the graph.
    // Used during graph to function conversion to tie control results to
    // the function signature.
    Status RecordControlResult(
        const Edge* edge,
        const std::unordered_map<const Node*, Node*>& node_images);

    // Creates a _Retval node for the src node of edge, and adds it to
    // results_, if none exists yet. If a new _Retval node is created, also
    // adds the edge within the subgraph from the src to the _Retval node.
    Status RecordResult(
        const Edge* edge,
        const std::unordered_map<const Node*, Node*>& node_images);

    // Creates the sequencer node if it doesn't exist, adding it to graph_out.
    Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out);

    // If there is a sequencer node, adds a control edge from the sequencer to
    // the call node.
    void ConnectSequencerToCallNode(Graph* graph_out);

    Status ReplaceFunctionDef(FunctionLibraryDefinition* library);

   private:
    // The subgraph extracted from the input graph, suitable for being turned
    // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are
    // returned by _Retval nodes.
    std::unique_ptr<Graph> graph_;

    // Which device are these nodes on? Used to assign a device to the call
    // node.
    string device_;

    // NodeDef for the function call node.
    NodeDef call_node_def_;

    // Name that is used for the call node. This may not be
    // call_node_def_.name() if the client supplies a rewrite lambda.
    string function_def_name_;

    // Placeholder node simulating the host compute key in the output graph.
    // Not owned.
    Node* host_compute_key_placeholder_ = nullptr;

    // Function call node in the output graph. Not owned.
    Node* call_node_;

    // Maps from source (producer node/slot) and destination
    // (consumer node/slot) tensors in the input graph to _Arg numbers in
    // the subgraph. The source map is one-to-one, whereas the dest map may be
    // many-to-one.
    std::unordered_map<OutputTensor, int, OutputTensor::Hash> args_by_src_;
    std::unordered_map<InputTensor, int, InputTensor::Hash> args_by_dst_;

    // The arguments to the subgraph, in order.
    std::vector<Node*> args_;

    // Map from source tensor in the input graph to result #.
    std::unordered_map<OutputTensor, int, OutputTensor::Hash> results_;

    // Set of node names that are the source of a control output of the
    // subgraph. We store strings here so that we can tolerate nodes being
    // removed from the graph.
    absl::flat_hash_set<string> control_output_nodes_;

    // NoOp node in the output graph that is sequenced after the call node.
    Node* sequencer_ = nullptr;
  };

  // Returns in 'attr' the value of the group attribute ('group_attribute_')
  // associated with the node. Sets 'attr' to the empty string if the
  // attribute is not found.
  Status GetFunctionNameAttr(Node const* node, string* attr) const;

  // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to
  // subgraphs for data edges that cross subgraph boundaries.
  Status CopySubgraphEdges(
      const std::unordered_map<const Node*, Node*>& node_images,
      std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);

  // Copies all marked nodes to a subgraph. Does nothing for unmarked nodes.
  Status CopySubgraphNodes(std::unordered_map<const Node*, Node*>* node_images);

  // Copies all nodes that aren't in a compiled subgraph to the output graph.
  Status CopyNodesToOutputGraph(
      Graph* graph_out, std::unordered_map<const Node*, Node*>* node_images);

  // Adds function call nodes for each compiled subgraph.
  Status AddFunctionCallNodes(
      const std::unordered_map<const Node*, Node*>& node_images,
      Graph* graph_out);

  // Finds the image of an edge source in the output graph. If the edge crosses
  // a subgraph boundary it is the output of a call node, otherwise it is a node
  // in the output graph.
  Status FindOutputImageOfEdgeSrc(
      const string& src_func_id, const string& dst_func_id,
      const std::unordered_map<const Node*, Node*>& node_images,
      const Node* original_src_node, Node** src_image);

  // Finds an edge source slot in the output graph. If the edge crosses a
  // subgraph boundary it is a slot on the output of a call node, otherwise it
  // is a slot on a node in the output graph.
  int FindOutputSlotOfEdgeSrc(const string& src_func_id,
                              const string& dst_func_id,
                              const Edge* edge);

  // Finds the image of an edge destination in the output graph. If the edge
  // crosses a subgraph boundary it is the input of a call node, otherwise it is
  // a node in the output graph.
  Status FindOutputImageOfEdgeDst(
      const string& src_func_id, const string& dst_func_id,
      const std::unordered_map<const Node*, Node*>& node_images,
      const Node* original_dst_node, Node** dst_image);

  // Finds an edge destination slot in the output graph. If the edge crosses a
  // subgraph boundary it is a slot on the input of a call node, otherwise it is
  // a slot on a node in the output graph.
  int FindOutputSlotOfEdgeDst(const string& src_func_id,
                              const string& dst_func_id,
                              const Edge* edge);

  // Copies a single edge to the output graph. The edge is either entirely
  // within the output graph, or crosses into or out of a compiled subgraph.
  Status CopyEdgeToOutputGraph(
      const Edge* edge, const string& src_func_id, const string& dst_func_id,
      const std::unordered_map<const Node*, Node*>& node_images,
      Graph* graph_out,
      std::unordered_set<std::pair<OutputTensor, InputTensor>,
                         OutputInputTensorPairHasher>* edges_added);

  // Adds all edges to the output graph.
  Status AddEdgesToOutputGraph(
      const std::unordered_map<const Node*, Node*>& node_images,
      Graph* graph_out);

  // Makes a copy of graph containing only nodes that are ancestors of at least
  // one node in sink_nodes, and stores it in pruned_graph. On exit node_images
  // contains a mapping from nodes in graph to nodes in pruned_graph. All
  // functions in the copied graph are inlined.
  Status MakePrunedGraphCopyAndInline(
      const Graph& graph, const std::vector<Node*>& sink_nodes,
      std::unique_ptr<Graph>* pruned_graph,
      std::unordered_map<const Node*, Node*>* node_images,
      FunctionLibraryDefinition* library);

  const string group_attribute_;
  const Graph* graph_in_;

  std::unordered_map<string, Subgraph> subgraphs_;

  TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator);
};

namespace {

// Returns in 'sorted' a topological sort of clusters according to the
// dependencies encoded in ancestors. clusters is the list of all clusters,
// including clusters that are not present in the ancestors map.
// has_successors is the set of clusters that are ancestors of some other
// cluster.
void TopologicalClusterSort(
    const std::unordered_set<string>& clusters,
    const std::unordered_set<string>& has_successors,
    const std::unordered_map<string, std::unordered_set<string>>& ancestors,
    std::vector<string>* sorted) {
  // The nodes are placed in 'sorted' in topological order.
  sorted->clear();
  // We don't use the standard DFS because we are not operating on Node*
  // objects.
  struct Work {
    string cluster;
    bool leave;
  };
  std::set<string> visited;
  std::vector<Work> stack;
  // Seed the processing list with clusters that have no successors.
  for (const auto& cluster : clusters) {
    if (has_successors.find(cluster) == has_successors.end()) {
      stack.push_back({cluster, false});
    }
  }
  while (!stack.empty()) {
    const Work item = stack.back();
    stack.pop_back();
    if (item.leave) {
      sorted->push_back(item.cluster);
      continue;
    }

    if (visited.find(item.cluster) != visited.end()) continue;
    visited.insert(item.cluster);

    stack.push_back({item.cluster, true});
    const auto& iter = ancestors.find(item.cluster);
    if (iter != ancestors.end()) {
      for (const auto& ancestor : iter->second) {
        stack.push_back({ancestor, false});
      }
    }
  }
  CHECK(sorted->size() == clusters.size());
}
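
// Illustrative sketch (hypothetical cluster names) of TopologicalClusterSort:
// with
//
//   clusters       = {"a", "b", "c"}
//   has_successors = {"a", "b"}
//   ancestors      = {"b" -> {"a"}, "c" -> {"b"}}
//
// only "c" seeds the stack, and the leave-order emission above produces
// *sorted == {"a", "b", "c"}.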

}  // namespace

Node* Encapsulator::Subgraph::GetCallNode() const { return call_node_; }

int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const {
  return args_by_dst_.at(InputTensor(edge->dst(), edge->dst_input()));
}

int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const {
  return results_.at(OutputTensor(edge->src(), edge->src_output()));
}

Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) {
  if (!graph_) {
    graph_.reset(new Graph(graph_in->op_registry()));
    graph_->set_versions(graph_in->versions());
  }

  // TODO(b/116981129): Enhance how the device for the encapsulated subgraph is
  // determined. In case of hard placement, ensure all the encapsulated nodes
  // have the same requested device, which in turn will be the requested device
  // for the entire encapsulated subgraph. In case of soft placement, use a
  // deterministic approach to fill in the requested device. Handle co-location
  // constraints similarly if they exist.
  if (device_.empty()) {
    device_ = node->assigned_device_name().empty()
                  ? node->requested_device()
                  : node->assigned_device_name();
  }

  return graph_->CopyNode(node);
}

Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); }

Status Encapsulator::Subgraph::RecordArg(
    const Edge* edge, const std::unordered_map<const Node*, Node*>& node_images,
    std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
  Node* src_node = edge->src();
  int src_slot = edge->src_output();
  std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
  bool inserted;
  std::tie(iter, inserted) = args_by_src_.emplace(
      OutputTensor(src_node, src_slot), args_by_src_.size());
  int arg_index = iter->second;
  if (inserted) {
    NodeDef arg_def;
    NodeDefBuilder builder(
        absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp,
        NodeDebugInfo(src_node->def()));
    DataType dtype = edge->dst()->input_type(edge->dst_input());
    builder.Attr("T", dtype);
    builder.Attr("index", arg_index);
    Status s = builder.Finalize(&arg_def);
    if (!s.ok()) return s;

    TF_ASSIGN_OR_RETURN(Node * arg, graph_->AddNode(arg_def));
    src_arg_pairs->push_back({src_node, arg});
    args_.push_back(arg);
  }
  Node* dst_node = edge->dst();
  Node* dst_image = node_images.at(dst_node);
  int dst_slot = edge->dst_input();
  args_by_dst_[InputTensor(dst_node, dst_slot)] = arg_index;
  graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot);
  return OkStatus();
}
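
// Illustrative sketch (hypothetical node names): for a data edge A:0 -> B:1
// where only B is in the subgraph, RecordArg creates a new _Arg node named
// "A_0_arg" carrying the next free arg index (unless A:0 already has one),
// appends {A, A_0_arg} to src_arg_pairs, and wires A_0_arg:0 -> image(B):1
// inside the subgraph. A second edge from A:0 into the same subgraph reuses
// the same _Arg index.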

Status Encapsulator::Subgraph::RecordControlResult(
    const Edge* edge,
    const std::unordered_map<const Node*, Node*>& node_images) {
  Node* src_node = edge->src();
  Node* src_image = node_images.at(src_node);
  control_output_nodes_.insert(src_image->name());
  return OkStatus();
}

Status Encapsulator::Subgraph::RecordResult(
    const Edge* edge,
    const std::unordered_map<const Node*, Node*>& node_images) {
  Node* src_node = edge->src();
  Node* src_image = node_images.at(src_node);
  int src_slot = edge->src_output();
  std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
  bool inserted;
  std::tie(iter, inserted) =
      results_.emplace(OutputTensor(src_node, src_slot), results_.size());
  int ret_index = iter->second;
  if (inserted) {
    NodeDef ret_def;
    NodeDefBuilder builder(
        absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp,
        NodeDebugInfo(src_node->def()));
    DataType dtype = src_node->output_type(src_slot);
    builder.Attr("T", dtype);
    builder.Attr("index", ret_index);
    builder.Input(src_image->name(), src_slot, dtype);
    Status s = builder.Finalize(&ret_def);
    if (!s.ok()) return s;
    TF_ASSIGN_OR_RETURN(Node * ret, graph_->AddNode(ret_def));
    graph_->AddEdge(src_image, src_slot, ret, 0);
  }
  return OkStatus();
}

Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name,
                                                  Graph* graph_out) {
  if (sequencer_ == nullptr) {
    NodeDef seq_def;
    // TODO(shikharagarwal): What source node should we use for errors?
    NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp");
    builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name);
    builder.Device(device_);
    Status s = builder.Finalize(&seq_def);
    if (!s.ok()) return s;

    TF_ASSIGN_OR_RETURN(sequencer_, graph_out->AddNode(seq_def));
  }
  return OkStatus();
}

void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) {
  if (sequencer_ != nullptr) {
    VLOG(2) << "ConnectSequencerToCallNode";
    graph_out->AddControlEdge(sequencer_, call_node_,
                              /* allow_duplicates= */ true);
  }
}

Status Encapsulator::Subgraph::BuildFunctionDef(
    const string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
    bool reuse_existing_functions, FunctionLibraryDefinition* library) {
  // name_in is copied here because name may be modified below if
  // rewrite_subgraph_fn is set.
  string name = name_in;
  call_node_def_.set_op(name);
  call_node_def_.set_name(name);
  call_node_def_.set_device(device_);

  if (rewrite_subgraph_fn) {
    std::vector<OutputTensor> arg_source_tensors(args_by_src_.size());
    for (const auto& arg : args_by_src_) {
      arg_source_tensors.at(arg.second) = arg.first;
    }
    // Initialize the input and output permutations to the identity.
    std::vector<int> input_permutation(args_by_src_.size());
    std::iota(input_permutation.begin(), input_permutation.end(), 0);
    std::vector<int> output_permutation(results_.size());
    std::iota(output_permutation.begin(), output_permutation.end(), 0);

    TF_RETURN_IF_ERROR(
        rewrite_subgraph_fn(arg_source_tensors, &graph_, &input_permutation,
                            &output_permutation, &call_node_def_));

    // Apply the input/output permutations to the 'args_by_...' and 'results_'
    // mappings, so when we build edges in BuildOutputGraph() we
    // connect them to the right input/output positions.
    if (input_permutation.size() != args_by_src_.size()) {
      return errors::InvalidArgument("Input permutation has incorrect size.");
    }
    if (output_permutation.size() != results_.size()) {
      return errors::InvalidArgument("Output permutation has incorrect size.");
    }
    for (auto& arg : args_by_src_) {
      arg.second = input_permutation[arg.second];
    }
    for (auto& arg : args_by_dst_) {
      arg.second = input_permutation[arg.second];
    }
    for (auto& result : results_) {
      result.second = output_permutation[result.second];
    }

    name = call_node_def_.op();
  }

  function_def_name_ = name;

  FunctionDef fdef;
  auto lookup = [this](const Node* node) -> std::optional<string> {
    if (control_output_nodes_.contains(node->name())) {
      return absl::make_optional(node->name());
    }
    return std::nullopt;
  };
  // Verify that the graph has well-formed control flow structure.
  std::vector<ControlFlowInfo> dummy;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &dummy));
  TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, lookup, &fdef));

  if (VLOG_IS_ON(1)) {
    VLOG(2) << "Build function def " << name;
    DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name), *graph_,
                    library);
    DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name), fdef);
  }

  const FunctionDef* original_fdef = library->Find(name);
  if (!reuse_existing_functions || original_fdef == nullptr) {
    TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
  } else if (!FunctionDefsEqual(*original_fdef, fdef)) {
    TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
  }
  return OkStatus();
}

Status Encapsulator::Subgraph::ReplaceFunctionDef(
    FunctionLibraryDefinition* library) {
  const string& name = function_def_name_;

  FunctionDef fdef;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef));

  if (VLOG_IS_ON(1)) {
    VLOG(2) << "Replace function def " << name;
    DumpGraphToFile(absl::StrCat("replace_encapsulate_fdef_graph_", name),
                    *graph_, library);
    DumpFunctionDefToFile(absl::StrCat("replace_encapsulate_fdef_", name),
                          fdef);
  }

  TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
  return OkStatus();
}

Status Encapsulator::Subgraph::AddFunctionCallNode(
    const std::unordered_map<const Node*, Node*>& node_images,
    Graph* graph_out) {
  TF_ASSIGN_OR_RETURN(call_node_, graph_out->AddNode(call_node_def_));

  // Copy the assigned device and the key_annotation over.
  call_node_->set_assigned_device_name(device_);

  return OkStatus();
}

Status Encapsulator::GetFunctionNameAttr(Node const* node, string* attr) const {
  AttrSlice attrs = node->attrs();
  attr->clear();
  for (const auto& node_attr : attrs) {
    if (node_attr.first == group_attribute_) {
      TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string"));
      *attr = node_attr.second.s();
      break;
    }
  }
  return OkStatus();
}

bool IsInSubgraph(const string& func_id) { return !func_id.empty(); }

Status Encapsulator::CopySubgraphNodes(
    std::unordered_map<const Node*, Node*>* node_images) {
  for (Node* node : graph_in_->op_nodes()) {
    string func_id;
    TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id));
    if (!IsInSubgraph(func_id)) continue;

    Subgraph& subgraph = subgraphs_[func_id];
    Node* image = subgraph.MakeNodeImage(graph_in_, node);
    image->ClearAttr(group_attribute_);
    (*node_images)[node] = image;
  }
  return OkStatus();
}

Status Encapsulator::CopySubgraphEdges(
    const std::unordered_map<const Node*, Node*>& node_images,
    std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
  for (const Edge* edge : graph_in_->edges()) {
    string src_func_id;
    TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id));
    string dst_func_id;
    TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id));
    Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr);
    Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr);

    // Copy edges that are local to a subgraph.
    if (IsInSubgraph(src_func_id) && IsInSubgraph(dst_func_id) &&
        src_func_id == dst_func_id) {
      Graph* g = subgraphs_[src_func_id].GetGraph();
      if (edge->IsControlEdge()) {
        g->AddControlEdge(src_image, dst_image,
                          /* allow_duplicates= */ true);
      } else {
        g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input());
      }
      continue;
    }

    // Record 'src' as an output of its subgraph, if applicable.
    if (IsInSubgraph(src_func_id)) {
      if (!edge->IsControlEdge()) {
        DataType dtype = edge->src()->output_type(edge->src_output());
        if (IsRefType(dtype)) {
          return errors::InvalidArgument(
              "Ref Tensors (e.g., Variables) are not supported as results: "
              "tensor ",
              edge->src()->name(), ":", edge->src_output());
        }
      }

      Subgraph& src_subgraph = subgraphs_[src_func_id];
      if (edge->IsControlEdge()) {
        TF_RETURN_IF_ERROR(src_subgraph.RecordControlResult(edge, node_images));
      } else {
        TF_RETURN_IF_ERROR(src_subgraph.RecordResult(edge, node_images));
      }
    }

    // Record 'dst' as an input of its subgraph, if applicable.
    if (IsInSubgraph(dst_func_id)) {
      // Look at the type of the destination not the source, since Ref output
      // Tensors can be automatically cast to non-Ref Tensors at the
      // destination.
      if (!edge->IsControlEdge()) {
        DataType dtype = edge->dst()->input_type(edge->dst_input());
        if (IsRefType(dtype)) {
          return errors::InvalidArgument(
              "Ref Tensors (e.g., Variables) are not supported as args: "
              "tensor ",
              edge->src()->name(), ":", edge->src_output());
        }
      }

      Subgraph& dst_subgraph = subgraphs_[dst_func_id];
      // Ignore control edges entering the subgraph. We will lift them onto
      // the enclosing call operators in BuildOutputGraph().
      if (!edge->IsControlEdge()) {
        TF_RETURN_IF_ERROR(
            dst_subgraph.RecordArg(edge, node_images, src_arg_pairs));
      }
    }
  }
  return OkStatus();
}

Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) {
  // Map from input graph nodes to subgraph nodes.
  std::unordered_map<const Node*, Node*> node_images;

  // Each entry of src_arg_pairs is a pair whose first element is a node in the
  // original graph that has an output edge in the subgraph, and whose second
  // element is the arg node in the subgraph that it sends to. The vector is
  // filled in by RecordArg, called from CopySubgraphEdges below.
  std::vector<std::pair<const Node*, Node*>> src_arg_pairs;

  TF_RETURN_IF_ERROR(CopySubgraphNodes(&node_images));
  TF_RETURN_IF_ERROR(CopySubgraphEdges(node_images, &src_arg_pairs));
  MarkGuaranteedConstants(*graph_in_, src_arg_pairs);

  for (auto& entry : subgraphs_) {
    Subgraph& subgraph = entry.second;
    FixupSourceAndSinkEdges(subgraph.GetGraph());
  }

  if (VLOG_IS_ON(1)) {
    // Dump subgraphs.
    for (auto& entry : subgraphs_) {
      DumpGraphToFile(
          absl::StrCat("encapsulate_subgraphs_subgraph_", entry.first),
          *entry.second.GetGraph(), library);
    }
  }

  return OkStatus();
}

Status Encapsulator::BuildFunctionDefs(
    const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
    FunctionLibraryDefinition* library) {
  for (auto& subgraph_entry : subgraphs_) {
    string name = subgraph_entry.first;
    Subgraph& subgraph = subgraph_entry.second;
    TF_RETURN_IF_ERROR(subgraph.BuildFunctionDef(
        name, rewrite_subgraph_fn, reuse_existing_functions, library));
  }
  return OkStatus();
}

Status Encapsulator::CopyNodesToOutputGraph(
    Graph* graph_out, std::unordered_map<const Node*, Node*>* node_images) {
  for (Node* node : graph_in_->op_nodes()) {
    string func_id;
    TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id));

    // Don't copy nodes that are going to be encapsulated.
    if (IsInSubgraph(func_id)) continue;

    Node* image = graph_out->CopyNode(node);
    (*node_images)[node] = image;
  }
  (*node_images)[graph_in_->source_node()] = graph_out->source_node();
  (*node_images)[graph_in_->sink_node()] = graph_out->sink_node();
  return OkStatus();
}

Status Encapsulator::AddFunctionCallNodes(
    const std::unordered_map<const Node*, Node*>& node_images,
    Graph* graph_out) {
  for (auto& subgraph_entry : subgraphs_) {
    TF_RETURN_IF_ERROR(
        subgraph_entry.second.AddFunctionCallNode(node_images, graph_out));
  }
  return OkStatus();
}

Status Encapsulator::FindOutputImageOfEdgeSrc(
    const string& src_func_id, const string& dst_func_id,
    const std::unordered_map<const Node*, Node*>& node_images,
    const Node* original_src_node, Node** src_image) {
  if (IsInSubgraph(src_func_id)) {
    // The edge is from a subgraph to a regular node in the output graph so
    // use the subgraph's call node output.
    *src_image = subgraphs_.at(src_func_id).GetCallNode();
  } else {
    // The source of the edge is in the output graph so use the node image in
    // the output graph.
    *src_image = node_images.at(original_src_node);
  }
  return OkStatus();
}

int Encapsulator::FindOutputSlotOfEdgeSrc(const string& src_func_id,
                                          const string& dst_func_id,
                                          const Edge* edge) {
  if (IsInSubgraph(src_func_id)) {
    const Subgraph& src_subgraph = subgraphs_.at(src_func_id);
    // 'src' is in a subgraph and 'dst' is a regular node in the output
    // graph. Use the corresponding call output instead.
    return src_subgraph.GetResultIndexForEdge(edge);
  } else {
    // The source of the edge is in the output graph so use the regular edge
    // slot.
    return edge->src_output();
  }
}

Status Encapsulator::FindOutputImageOfEdgeDst(
    const string& src_func_id, const string& dst_func_id,
    const std::unordered_map<const Node*, Node*>& node_images,
    const Node* original_dst_node, Node** dst_image) {
  if (IsInSubgraph(dst_func_id)) {
    // The edge is to a subgraph from a regular node in the output graph so
    // use the subgraph's call node input.
    *dst_image = subgraphs_.at(dst_func_id).GetCallNode();
  } else {
    // The destination of the edge is in the output graph so use the node image
    // in the output graph.
    *dst_image = node_images.at(original_dst_node);
  }
  return OkStatus();
}

int Encapsulator::FindOutputSlotOfEdgeDst(const string& src_func_id,
                                          const string& dst_func_id,
                                          const Edge* edge) {
  if (IsInSubgraph(dst_func_id)) {
    const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id);
    // 'dst' is in a subgraph and 'src' is a regular node in the output
    // graph. Use the corresponding call input instead.
    return dst_subgraph.GetArgIndexForEdge(edge);
  } else {
    // The destination of the edge is in the output graph so use the regular
    // edge slot.
    return edge->dst_input();
  }
}

Status Encapsulator::CopyEdgeToOutputGraph(
    const Edge* edge, const string& src_func_id, const string& dst_func_id,
    const std::unordered_map<const Node*, Node*>& node_images, Graph* graph_out,
    std::unordered_set<std::pair<OutputTensor, InputTensor>,
                       OutputInputTensorPairHasher>* edges_added) {
  Node* src_image;
  TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc(
      src_func_id, dst_func_id, node_images, edge->src(), &src_image));
  Node* dst_image;
  TF_RETURN_IF_ERROR(FindOutputImageOfEdgeDst(
      src_func_id, dst_func_id, node_images, edge->dst(), &dst_image));

  // If this is a control edge then copy it and return. Lift control edges onto
  // the enclosing call operator.
  if (edge->IsControlEdge()) {
    // Add the control edge, if we have not already added it, using the images
    // determined above (potentially call operators or RecvAtHost/SendFromHost).
    if (edges_added
            ->emplace(OutputTensor(src_image, -1), InputTensor(dst_image, -1))
            .second) {
      graph_out->AddControlEdge(src_image, dst_image,
                                /* allow_duplicates= */ true);
    }

    return OkStatus();
  }

  int src_output = FindOutputSlotOfEdgeSrc(src_func_id, dst_func_id, edge);

  int dst_input = FindOutputSlotOfEdgeDst(src_func_id, dst_func_id, edge);

  // Add the edge, if we have not already added it.
  if (edges_added
          ->emplace(OutputTensor(src_image, src_output),
                    InputTensor(dst_image, dst_input))
          .second) {
    graph_out->AddEdge(src_image, src_output, dst_image, dst_input);
  }
  return OkStatus();
}

Status Encapsulator::AddEdgesToOutputGraph(
    const std::unordered_map<const Node*, Node*>& node_images,
    Graph* graph_out) {
  // Set of edges already added to the output graph, represented as (src, dst)
  // pairs. We use the set to deduplicate edges; multiple edges in the input
  // graph may map to one edge in the output graph.
  std::unordered_set<std::pair<OutputTensor, InputTensor>,
                     OutputInputTensorPairHasher>
      edges_added;

  for (const Edge* edge : graph_in_->edges()) {
    string src_func_id;
    TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id));
    string dst_func_id;
    TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id));

    // Ignore edges that are strictly contained within one subgraph; they were
    // already copied into the subgraph by CopySubgraphEdges.
    if (IsInSubgraph(src_func_id) && IsInSubgraph(dst_func_id) &&
        src_func_id == dst_func_id) {
      continue;
    }

    // We have an edge that crosses a cluster boundary or is entirely within
    // the unclustered graph.
    TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph(
        edge, src_func_id, dst_func_id, node_images, graph_out, &edges_added));
  }

  for (auto& subgraph_entry : subgraphs_) {
    Subgraph& subgraph = subgraph_entry.second;
    subgraph.ConnectSequencerToCallNode(graph_out);
  }

  return OkStatus();
}

namespace {

// Adds a dummy Const node to graph_out. The "constant" has the type of
// data_type and the shape indicated in 'shape'. The dummy node is not a valid
// Const node because it does not have any value defined, but this doesn't
// matter because it will only be used subsequently for shape inference. (It
// would be possible to add a switch statement over data_type to create a value
// for the constant, but that would entail maintaining the logic as new types
// are added, and is not necessary.) If the node being replaced was within a
// control flow frame, adds appropriate Enter nodes so that the use of the Const
// is well-formed.
Node* AddDummyShapedNode(const Node* src_node, int src_port,
                         const std::vector<ControlFlowInfo>& control_flow_info,
                         const TensorShapeProto& shape, Graph* graph_out) {
  DataType data_type = src_node->output_type(src_port);
  TensorProto dummy_proto;
  dummy_proto.set_dtype(data_type);
  *dummy_proto.mutable_tensor_shape() = shape;
  // Don't set any value field in the proto, since it is only going to be used
  // for shape inference.

  GraphDefBuilder::Options options(graph_out, /*status=*/nullptr);
  NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const",
                           options.op_registry());
  node_builder.Attr("dtype", data_type).Attr("value", dummy_proto);
  Node* node = options.FinalizeBuilder(&node_builder);
  // Add any Enter nodes required to bring the constant to the correct control
  // flow frame.
  while (!control_flow_info[src_node->id()].frame_name.empty()) {
    NodeDebugInfo debug_info(*src_node);
    NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter",
                              options.op_registry(), &debug_info);
    enter_builder.Attr("frame_name",
                       control_flow_info[src_node->id()].frame_name);
    enter_builder.Attr("is_constant", true);
    enter_builder.Input(node, 0);
    Node* enter_node = options.FinalizeBuilder(&enter_builder);
    // Adopt the new Enter node as the value in the current frame.
    node = enter_node;
    // Recurse to the parent frame to see if more Enter nodes need to be added.
    src_node = control_flow_info[src_node->id()].parent_frame;
  }
  return node;
}
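
// Illustrative sketch (hypothetical frame name): if src_node sits in a single
// control flow frame "while_ctx" whose parent is the root frame, the function
// returns
//
//   KnownShape(Const) --> Enter(frame_name="while_ctx", is_constant=true)
//
// with the Enter node as the result; for nested frames one Enter node is
// appended per frame, following parent_frame until the root (empty frame
// name) is reached.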

}  // namespace

Status Encapsulator::MakePrunedGraphCopyAndInline(
    const Graph& graph, const std::vector<Node*>& sink_nodes,
    std::unique_ptr<Graph>* pruned_graph,
    std::unordered_map<const Node*, Node*>* node_images,
    FunctionLibraryDefinition* library) {
  // First copy all ancestor nodes of sink_nodes into a new graph.
  pruned_graph->reset(new Graph(library));
  (*pruned_graph)->set_versions(graph.versions());
  ReverseDFSFrom(graph, sink_nodes,
                 /*enter=*/nullptr,
                 /*leave=*/[&](Node* n) {
                   if (!n->IsSource()) {
                     Node* copied = (*pruned_graph)->CopyNode(n);
                     node_images->emplace(n, copied);
                   }
                 });

  // Add all the edges between copied nodes.
  for (auto entry : *node_images) {
    const Node* orig = entry.first;
    Node* image = entry.second;
    for (const Edge* out_edge : orig->out_edges()) {
      auto iter = node_images->find(out_edge->dst());
      if (iter != node_images->end()) {
        // The source and destination are both in the copied graph.
        (*pruned_graph)
            ->AddEdge(image, out_edge->src_output(), iter->second,
                      out_edge->dst_input());
      }
    }
  }

  // Find all the function call nodes, and inline them.
  std::vector<Node*> function_nodes;
  for (auto node : (*pruned_graph)->nodes()) {
    const OpRegistrationData* op_reg_data;
    TF_RETURN_IF_ERROR(library->LookUp(node->type_string(), &op_reg_data));
    if (op_reg_data->is_function_op) {
      function_nodes.push_back(node);
    }
  }
  for (auto node : function_nodes) {
    VLOG(2) << "Inlining function " << node->name();
    const FunctionDef* fdef = library->Find(node->type_string());
    if (fdef == nullptr) {
      return errors::Internal("Failed to find function ", node->type_string(),
                              " in function library.");
    }
    std::unique_ptr<FunctionBody> fbody;
    TF_RETURN_IF_ERROR(
        FunctionDefToBodyHelper(*fdef, node->attrs(), library, &fbody));

    InlineFunctionBodyOptions inline_opts;
    TF_RETURN_IF_ERROR(InlineFunctionBody(*library, pruned_graph->get(), node,
                                          fbody.get(), inline_opts));
  }

  return OkStatus();
}

Status Encapsulator::BuildOutputGraph(Graph* graph_out,
                                      FunctionLibraryDefinition* library) {
  // Map from nodes in the input graph to nodes in the output graph.
  std::unordered_map<const Node*, Node*> node_images;

  TF_RETURN_IF_ERROR(CopyNodesToOutputGraph(graph_out, &node_images));
  TF_RETURN_IF_ERROR(AddFunctionCallNodes(node_images, graph_out));
  TF_RETURN_IF_ERROR(AddEdgesToOutputGraph(node_images, graph_out));

  return OkStatus();
}

}  // anonymous namespace

Status EncapsulateSubgraphsInFunctions(
    string group_attribute, const Graph& graph_in,
    const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
    std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library) {
  Encapsulator encapsulator(std::move(group_attribute), &graph_in);
  TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs(library));

  TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs(
      rewrite_subgraph_fn, reuse_existing_functions, library));

  std::unique_ptr<Graph> out(new Graph(library));
  out->set_versions(graph_in.versions());
  TF_RETURN_IF_ERROR(encapsulator.BuildOutputGraph(out.get(), library));

  *graph_out = std::move(out);
  return OkStatus();
}
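
// Usage sketch for the function above (illustrative only; 'graph' and
// 'flib_def' are assumed to be the caller's graph and function library):
//
//   std::unique_ptr<Graph> rewritten;
//   TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
//       kXlaClusterAttr, *graph, /*rewrite_subgraph_fn=*/{},
//       /*reuse_existing_functions=*/false, &rewritten, flib_def));
//   graph = std::move(rewritten);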

// Finds the types of the _Arg nodes, indexed by position.
static Status GetArgTypes(const Graph& graph, DataTypeVector* types) {
  for (Node* n : graph.op_nodes()) {
    if (n->type_string() == kArgOp) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      const int num_types = types->size();
      if (index < 0 || index >= num_types) {
        return errors::InvalidArgument("Invalid argument number");
      }
      (*types)[index] = n->output_type(0);
    }
  }
  return OkStatus();
}

// Renumbers the indices of _Arg nodes in a graph according to 'permutation',
// which maps old indices to new indices.
static Status RenumberArguments(Graph* graph,
                                const std::vector<int>& permutation) {
  for (Node* n : graph->op_nodes()) {
    if (n->type_string() == kArgOp) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      const int permutation_size = permutation.size();
      if (index < 0 || index >= permutation_size) {
        return errors::InvalidArgument("Invalid argument number");
      }
      n->AddAttr("index", permutation[index]);
    }
  }
  return OkStatus();
}
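
// Illustrative sketch (hypothetical values): with permutation = {2, 0, 1},
// the _Arg node whose "index" attr is currently 0 is rewritten to index 2,
// index 1 becomes index 0, and index 2 becomes index 1.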

Status EncapsulateSubgraphsPass::Run(
    const GraphOptimizationPassOptions& options) {
  VLOG(1) << "EncapsulateSubgraphsPass::Run";
  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("encapsulate_subgraphs_before", **options.graph,
                    options.flib_def);
  }

  // TODO(b/195757077): Remove this once there is a better way to disable
  // GraphOptimizationPasses that are not needed due to MLIR bridge.
  for (Node* n : (*options.graph)->nodes()) {
    // Skip the pass if we found TPUExecute or TPUExecuteAndUpdateVariables ops
    // in the graph, which indicates the graph is produced by the TPU TF-XLA
    // bridge and doesn't require auto clustering.
    if (n->type_string() == "TPUExecute" ||
        n->type_string() == "TPUExecuteAndUpdateVariables") {
      return OkStatus();
    }
  }

  std::unique_ptr<Graph> graph_out;
  FunctionLibraryDefinition* const library = options.flib_def;

  // Constant folding below might need to run part of the function to compute
  // constants. Create a FunctionLibraryRuntime with a single CPU device
  // that can run that part of the function.
  // NOTE: If this turns out to be slow, we can cache the FLRs keyed by
  // `options`.
  SessionOptions session_options;
  auto* device_count = session_options.config.mutable_device_count();
  device_count->insert({"CPU", 1});
  std::vector<std::unique_ptr<Device>> devices;

  DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
  if (!cpu_factory) {
    return errors::NotFound(
        "CPU Factory not registered. Can't run EncapsulateSubgraphsPass");
  }
  TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
      session_options, "/job:localhost/replica:0/task:0", &devices));
  if (devices.empty()) {
    return errors::NotFound(
        "Failed to create a CPU device for EncapsulateSubgraphsPass");
  }

  std::unique_ptr<DeviceMgr> device_mgr =
      std::make_unique<StaticDeviceMgr>(std::move(devices));
  const auto* config = &options.session_options->config;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(
          device_mgr.get(), options.session_options->env,
          /*config=*/config, TF_GRAPH_DEF_VERSION, library,
          config->graph_options().optimizer_options()));
  FunctionLibraryRuntime* flr =
      pflr->GetFLR("/job:localhost/replica:0/task:0/device:CPU:0");
  if (flr == nullptr) {
    return errors::Internal(
        "Failed to create and retrieve function library runtime to run "
        "constant folding");
  }

  auto rewrite_subgraph =
      [flr](const std::vector<OutputTensor>& arg_source_tensors,
            std::unique_ptr<Graph>* subgraph,
            std::vector<int>* input_permutation,
            std::vector<int>* output_permutation, NodeDef* node) {
        // Optimize the subgraph.
        // Do not constant fold nodes that output DT_VARIANT type tensors.
        // XLA does not support Const nodes of Variant type since it needs
        // to know the original ops to be able to compile them to the relevant
        // XLA form.
        // TODO(srbs): This filter is a little conservative. E.g. a subgraph of
        // the form:
        //                          Const
        //                            |
        // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op
        //                                                  |
        //                                        (Discard popped list)
        //
        // would have been reduced to "Const -> Op" without this filter.
        // However, since we are only allowed to specify the filter at the
        // "Node" level there is no good way to allow the above behavior. So we
        // disallow any sort of constant folding on Variant nodes for now.
        bool disable_constant_folding =
            GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding;
        auto cf_consider_fn = [disable_constant_folding](const Node* n) {
          if (disable_constant_folding) return false;
          for (const auto& output_arg : n->op_def().output_arg()) {
            if (output_arg.type() == DT_VARIANT) {
              return false;
            }
          }
          return true;
        };
        GraphOptimizer::Options graph_optimizer_options;
        graph_optimizer_options.cf_consider_fn = cf_consider_fn;
        OptimizeGraph(flr, subgraph, graph_optimizer_options);

        const int num_args = input_permutation->size();
        std::vector<bool> const_args(num_args);
        TF_RETURN_IF_ERROR(
            BackwardsConstAnalysis(**subgraph, &const_args,
                                   /*compile_time_const_nodes=*/nullptr, flr));

        DataTypeVector arg_types(num_args);
        TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types));

        // Compute a permutation of the arguments such that the constant
        // arguments come first, followed by the non-constant arguments, with
        // the resource arguments last.
        const int num_consts =
            std::count(const_args.begin(), const_args.end(), true);

        const int num_resources =
            std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE);
        const int num_nonconsts = num_args - num_resources - num_consts;
        if (num_nonconsts < 0) {
          return errors::Internal("num_nonconsts should be >= 0, was ",
                                  num_nonconsts);
        }

        int const_pos = 0;
        int arg_pos = num_consts;
        int resource_pos = num_consts + num_nonconsts;
        for (int i = 0; i < num_args; ++i) {
          if (const_args[i]) {
            if (arg_types[i] == DT_RESOURCE) {
              return errors::Internal(
                  "Resource arguments cannot be constant (argument ", i, ")");
            }
            (*input_permutation)[i] = const_pos;
            ++const_pos;
          } else if (arg_types[i] == DT_RESOURCE) {
            (*input_permutation)[i] = resource_pos;
            ++resource_pos;
          } else {
            (*input_permutation)[i] = arg_pos;
            ++arg_pos;
          }
        }
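
        // Illustrative sketch (hypothetical values) of the loop above: with
        //   const_args = {false, true, false}
        //   arg_types  = {DT_RESOURCE, DT_INT32, DT_FLOAT}
        // we get num_consts = 1, num_nonconsts = 1, num_resources = 1, and
        // the loop yields *input_permutation == {2, 0, 1}: the constant arg
        // moves to the front and the resource arg to the back.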

        // Renumber argument nodes in the graph.
        TF_RETURN_IF_ERROR(
            RenumberArguments(subgraph->get(), *input_permutation));

        // TODO(phawkins): add a forward is-constant analysis, similarly split
        // outputs into host-memory constants and device-memory non-constants.

        AddNodeAttr(kXlaCompiledKernelAttr, true, node);
        AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node);
        AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node);
        return OkStatus();
      };

  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      EncapsulateSubgraphsInFunctions(
          kXlaClusterAttr, **options.graph, rewrite_subgraph,
          /*reuse_existing_functions=*/false, &graph_out, library),
      "EncapsulateSubgraphsPass failed");
  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("encapsulate_subgraphs_after", *graph_out,
                    options.flib_def);
  }

  *options.graph = std::move(graph_out);

  TF_ASSIGN_OR_RETURN(absl::flat_hash_set<Node*> ref_related_nodes,
                      GetNodesRelatedToRefVariables(**options.graph, flr));
  for (Node* node : (*options.graph)->nodes()) {
    bool has_ref_vars = ref_related_nodes.contains(node);
    node->AddAttr(kXlaHasReferenceVarsAttr, has_ref_vars);
    VLOG(3) << "Has ref vars = " << has_ref_vars
            << ", node: " << node->def().DebugString();
  }
  return OkStatus();
}

bool IsXlaCompiledKernel(const Node& node) {
  bool is_compiled = false;
  // TryGetNodeAttr returns false if the attribute is absent; a missing
  // attribute is treated the same as is_compiled == false.
  return TryGetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled) &&
         is_compiled;
}

}  // namespace tensorflow