/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/build_xla_ops_pass.h"

#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/control_flow_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/logging_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {
namespace {
struct DebuggingOpts {
  // If true, insert Print nodes to print every output from an XLA cluster.
  bool print_outputs;

  // If true, insert CheckNumerics nodes for every floating point typed input to
  // an XLA cluster.
  bool check_input_numerics;

  // If true, insert CheckNumerics nodes for every floating point typed output
  // from an XLA cluster.
  bool check_output_numerics;
};

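// Moves every outgoing edge from `old_node` onto `new_node`, preserving each
// edge's source-output and destination-input indices.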
void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
  std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
                                     old_node->out_edges().end());
  for (const Edge* edge : out_edges) {
    // TODO(sanjoy): This does not update NodeDef inputs.  To be able to update
    // NodeDef inputs we first need to fix encapsulate_subgraphs_pass to fix up
    // the NodeDef inputs to the function call nodes.
    g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input());
    g->RemoveEdge(edge);
  }
}

// Returns a data value that is dead iff `control` is dead.
Output ControlToData(const Scope& scope, Node* control) {
  // The choice of data type here is important.
  //
  // We implement a "control merge", which is a control edge that is alive if
  // either of two nodes (denoted as A and B below) is alive, in the following
  // manner:
  //
  //   A --ctrl--> Const0 --data--> Merge --data--> Identity
  //                                 ^                 |
  //                                 |                ctrl
  //   B --ctrl--> Const1 --data-----+                 |
  //                                                   v
  //                                                  ***
  //
  // where *** denotes the merged control output.
  //
  // We want everything starting from Const{0/1} to Identity to either wholly
  // live on the host or wholly live on the device, so we need to pick a data
  // type that is either consistently assigned to the device (e.g. float) or
  // consistently assigned to the host (e.g. int32).  We should *not* pick a
  // data type that is partly placed on the host and partly on the device
  // (e.g. bool constants are placed on the device but bool Identity is placed
  // on the host).
  Output data = ops::Const(scope.WithOpName("ctrl_as_data"),
                           Tensor(DT_INT32, TensorShape({0})));
  scope.graph()->AddControlEdge(control, data.node());
  return Output(data.node());
}

// Returns an operation that can be control-depended on that is dead iff `data`
// is dead.
Operation DataToControl(const Scope& scope, Output data) {
  return Operation(
      ops::Identity(scope.WithOpName("data_as_ctrl"), data).node());
}

// Replaces each outgoing edge from `old_node` with a merge node that merges in
// the corresponding output from `new_node`.
void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node,
                            absl::string_view cluster_name,
                            const DebuggingOpts& debugging_opts) {
  if (!s.status().ok()) {
    return;
  }

  std::vector<Output> merged_outputs(old_node->num_outputs(), Output(nullptr));

  std::vector<const Edge*> data_edges;
  absl::c_copy_if(old_node->out_edges(), std::back_inserter(data_edges),
                  [](const Edge* e) { return !e->IsControlEdge(); });

  for (const Edge* e : data_edges) {
    int oidx = e->src_output();
    Output merged_output = merged_outputs[oidx];
    if (merged_output.node() == nullptr) {
      Output new_output(new_node, oidx);
      if (debugging_opts.print_outputs) {
        string cpu_device = "/job:localhost/replica:0/task:0/device:CPU:0";
        ops::Print print_op(s.WithOpName("print_", oidx)
                                .WithDevice(cpu_device)
                                .WithAssignedDevice(cpu_device),
                            new_output, {new_output},
                            ops::Print::Attrs{}
                                .Message(absl::StrCat("output ", oidx, " from ",
                                                      old_node->name(), " is "))
                                .FirstN(1000)
                                .Summarize(-1));
        new_output = print_op;
      }

      if (debugging_opts.check_output_numerics &&
          DataTypeIsFloating(new_output.type())) {
        ops::CheckNumerics check_numerics_op(
            s.WithOpName("check_output_", oidx)
                .WithDevice(new_node->requested_device())
                .WithAssignedDevice(new_node->assigned_device_name()),
            new_output,
            absl::StrCat("CheckNumerics failed for output ", oidx, "(",
                         new_output.name(), ") from cluster ", cluster_name));
        new_output = check_numerics_op;
      }

      ops::_XlaMerge xla_merge_op(s.WithOpName("merge_oidx_", oidx),
                                  Output(old_node, oidx), new_output);
      merged_output = merged_outputs[oidx] = xla_merge_op.output;
    }

    Node* dst = e->dst();
    int dst_idx = e->dst_input();

    s.graph()->RemoveEdge(e);
    s.graph()->AddEdge(merged_output.node(), merged_output.index(), dst,
                       dst_idx);
  }
}

// Rewires each control successor of `old_node` to execute whenever either
// `old_node` or `new_node` executes.
void MergeOutgoingControlEdges(const Scope& s, Node* old_node, Node* new_node) {
  if (!s.status().ok()) {
    return;
  }

  std::vector<const Edge*> ctrl_edges;
  absl::c_copy_if(old_node->out_edges(), std::back_inserter(ctrl_edges),
                  [](const Edge* e) { return e->IsControlEdge(); });

  if (ctrl_edges.empty()) {
    return;
  }

  if (ctrl_edges.size() == 1 && ctrl_edges.front()->dst()->IsSink()) {
    // Avoid creating a Merge node if we can just add an edge to _SINK
    // instead.
    s.graph()->AddControlEdge(new_node, s.graph()->sink_node());
    return;
  }

  // We can't merge control edges directly, so we instead first "convert" them
  // to normal data values that can be merged, merge the values, and then
  // "convert" the merged value back into control.
  //
  // NB! We need to copy out the outgoing control edges before constructing
  // old_ctrl_as_data, otherwise the control edge from old_node to the constant
  // in ControlToData will be present in ctrl_edges.

  Output old_ctrl_as_data = ControlToData(s, old_node);
  Output new_ctrl_as_data = ControlToData(s, new_node);

  ops::Merge ctrl_merge_as_data(s.WithOpName("ctrl_merge"),
                                {old_ctrl_as_data, new_ctrl_as_data});
  Operation ctrl_merge = DataToControl(s, ctrl_merge_as_data.output);

  for (const Edge* e : ctrl_edges) {
    s.graph()->AddControlEdge(ctrl_merge.node(), e->dst());
    s.graph()->RemoveControlEdge(e);
  }
}

struct XlaClusterInfo {
  std::vector<Output> constant_inputs;
  std::vector<Output> non_constant_inputs;
  std::vector<Output> resource_inputs;
  NameAttrList function;
};

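// Returns the source of edge `e` as an Output (source node plus output index).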
Output IncomingEdgeAsOutput(const Edge* e) {
  return Output(e->src(), e->src_output());
}

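// Collects the constant, non-constant and resource inputs to the XLA cluster
// node `n` (in that order), along with the function the cluster calls, into
// `result`.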
Status GetXlaClusterInfo(Node* n, XlaClusterInfo* result) {
  int num_constant_inputs, num_resource_inputs;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(n->attrs(), kXlaNumConstantArgsAttr, &num_constant_inputs));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(n->attrs(), kXlaNumResourceArgsAttr, &num_resource_inputs));

  if (num_constant_inputs < 0 || num_resource_inputs < 0 ||
      num_constant_inputs + num_resource_inputs > n->num_inputs()) {
    return errors::InvalidArgument(
        "Invalid number of constant/resource arguments to XLA kernel.");
  }

  int num_non_constant_inputs =
      n->num_inputs() - num_constant_inputs - num_resource_inputs;

  std::vector<const Edge*> input_edges_vector;
  TF_RETURN_IF_ERROR(n->input_edges(&input_edges_vector));
  absl::Span<const Edge*> input_edges(input_edges_vector);

  absl::c_transform(input_edges.subspan(0, num_constant_inputs),
                    std::back_inserter(result->constant_inputs),
                    IncomingEdgeAsOutput);

  absl::c_transform(
      input_edges.subspan(num_constant_inputs, num_non_constant_inputs),
      std::back_inserter(result->non_constant_inputs), IncomingEdgeAsOutput);

  absl::c_transform(
      input_edges.subspan(num_constant_inputs + num_non_constant_inputs,
                          num_resource_inputs),
      std::back_inserter(result->resource_inputs), IncomingEdgeAsOutput);

  result->function.set_name(n->type_string());
  *result->function.mutable_attr() = n->def().attr();
  return OkStatus();
}

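// Copies every incoming control edge of `from` onto `to`.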
Status CopyIncomingControlEdges(Graph* g, Node* from, Node* to) {
  for (const Edge* e : from->in_edges()) {
    if (e->IsControlEdge()) {
      g->AddControlEdge(e->src(), to);
    }
  }

  return OkStatus();
}

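// Removes all incoming control edges to `n`.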
void RemoveAllIncomingControlEdges(Graph* g, Node* n) {
  std::vector<const Edge*> incoming_ctrl_edges;
  absl::c_copy_if(n->in_edges(), std::back_inserter(incoming_ctrl_edges),
                  [](const Edge* e) { return e->IsControlEdge(); });
  for (const Edge* e : incoming_ctrl_edges) {
    g->RemoveControlEdge(e);
  }
}

// Returns true (into `result`) if a node placed on `device` must be compiled.
Status DeviceRequiresCompilation(const jit::DeviceInfoCache& device_info_cache,
                                 jit::DeviceId device, bool* result) {
  const XlaOpRegistry::DeviceRegistration* registration =
      device_info_cache.GetCompilationDevice(device);
  *result = registration->autoclustering_policy ==
            XlaOpRegistry::AutoclusteringPolicy::kAlways;
  return OkStatus();
}

// Replaces `n` with a `PartitionedCall` op that calls the same function.
StatusOr<Node*> ReplaceFunctionCallWithPartitionedCall(
    const GraphOptimizationPassOptions& options,
    const FunctionLibraryDefinition& flib_def, Node* n, Graph* g,
    const NameAttrList& func, const Scope& root) {
  string config_string = options.session_options->config.SerializeAsString();

  int input_count = absl::c_count_if(
      n->in_edges(), [](const Edge* e) { return !e->IsControlEdge(); });

  std::vector<Output> args(input_count);
  for (const Edge* e : n->in_edges()) {
    if (!e->IsControlEdge()) {
      args[e->dst_input()] = Output(e->src(), e->src_output());
    }
  }

  // In theory we can use PartitionedCall if the XLA cluster does not have any
  // stateful operations.  However, for now we choose to be conservative since
  // we don't have any evidence that choosing a stateless partitioned call
  // helps performance.
  ops::StatefulPartitionedCall call(
      root.WithOpName("stateful_partitioned_call"), args, n->output_types(),
      func, ops::StatefulPartitionedCall::Attrs{}.ConfigProto(config_string));

  for (const Edge* e : n->in_edges()) {
    if (e->IsControlEdge()) {
      g->AddControlEdge(e->src(), call.operation.node());
    }
  }

  std::vector<const Edge*> edges_to_delete;

  for (const Edge* e : n->out_edges()) {
    edges_to_delete.push_back(e);
    if (e->IsControlEdge()) {
      g->AddControlEdge(call.operation.node(), e->dst());
    } else {
      g->AddEdge(call.operation.node(), e->src_output(), e->dst(),
                 e->dst_input());
    }
  }

  for (const Edge* e : edges_to_delete) {
    g->RemoveEdge(e);
  }

  g->RemoveNode(n);
  return call.operation.node();
}

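// Infers the device on which to place the cluster call node `n` by inspecting
// the device assignments of the nodes in the function it calls (and `n`'s own
// assigned device, if any).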
StatusOr<jit::DeviceId> InferDeviceForCluster(
    jit::DeviceInfoCache* device_info_cache, Node* n,
    const string& function_name, const FunctionLibraryDefinition& flib_def) {
  const FunctionDef* func_def = flib_def.Find(function_name);
  TF_RET_CHECK(func_def) << "Could not find " << function_name;

  jit::DeviceSet device_set;

  for (const NodeDef& ndef : func_def->node_def()) {
    VLOG(3) << ndef.DebugString();
    if (!ndef.device().empty()) {
      TF_ASSIGN_OR_RETURN(jit::DeviceId device_id,
                          device_info_cache->GetIdFor(ndef.device()));
      device_set.Insert(device_id);
    }
  }

  if (!n->assigned_device_name().empty()) {
    // TODO(sanjoy): We need this because EncapsulateSubgraphsPass drops device
    // assignment when constant folding.  We should fix EncapsulateSubgraphsPass
    // instead.
    TF_ASSIGN_OR_RETURN(jit::DeviceId device_id,
                        device_info_cache->GetIdFor(n->assigned_device_name()));
    device_set.Insert(device_id);
  }

  TF_ASSIGN_OR_RETURN(jit::DeviceId result,
                      PickDeviceForXla(*device_info_cache, device_set,
                                       /*allow_mixing_unknown_and_cpu=*/true));
  VLOG(2) << "For " << function_name << " PickDeviceForXla("
          << device_info_cache->DebugString(device_set) << ") -> "
          << device_info_cache->GetNameFor(result);
  return result;
}

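// Builds the argument list for _XlaRun: the cluster's non-constant inputs
// (each optionally wrapped in CheckNumerics) followed by its resource inputs.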
std::vector<Output> GetXlaRunArgs(const Scope& s,
                                  const XlaClusterInfo& cluster_info,
                                  const DebuggingOpts& debugging_opts) {
  std::vector<Output> xla_run_args;
  xla_run_args.reserve(cluster_info.non_constant_inputs.size() +
                       cluster_info.resource_inputs.size());
  int input_idx = 0;
  for (const Output& o : cluster_info.non_constant_inputs) {
    if (debugging_opts.check_input_numerics && DataTypeIsFloating(o.type())) {
      ops::CheckNumerics check_numerics_op(
          s.WithOpName("check_input_", input_idx), o,
          absl::StrCat("CheckNumerics failed for input ", input_idx, "(",
                       o.name(), ") into ", cluster_info.function.name()));
      xla_run_args.push_back(check_numerics_op);
    } else {
      xla_run_args.push_back(o);
    }
    input_idx++;
  }
  absl::c_copy(cluster_info.resource_inputs, std::back_inserter(xla_run_args));
  return xla_run_args;
}

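// Returns the memory types (host vs. device) of `n`'s outputs on its assigned
// device.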
StatusOr<MemoryTypeVector> GetOutputMemoryTypes(const Scope& root, Node* n) {
  MemoryTypeVector input_mtypes, output_mtypes;
  DeviceType device_type("");
  TF_RETURN_IF_ERROR(
      DeviceNameToDeviceType(n->assigned_device_name(), &device_type));
  TF_RETURN_IF_ERROR(MemoryTypesForNode(root.graph()->op_registry(),
                                        device_type, n->def(), &input_mtypes,
                                        &output_mtypes));
  return output_mtypes;
}

// Predicate INT32 typed inputs to `n` on the deadness of
// `predicate_as_control`.
//
// This is a performance optimization.  Since INT32 arguments to a
// PartitionedCall are placed on the host, a producer that produces them on the
// device will incur a D2H copy, even if the PartitionedCall is not executed
// (i.e. even if we choose to execute the XLA compiled computation via _XlaRun).
// To prevent this, we add control dependencies to make the int32 input edges
// into the PartitionedCall dead.  With this change the D2H copy only happens if
// the PartitionedCall is actually executed.
Status PredicateInt32Inputs(const Scope& root, Node* n,
                            Operation predicate_as_control) {
  std::vector<Output> int32_inputs;
  std::vector<int> int32_inputs_input_idxs;
  for (const Edge* e : n->in_edges()) {
    if (e->IsControlEdge()) {
      continue;
    }

    if (e->src()->output_type(e->src_output()) == DT_INT32) {
      TF_ASSIGN_OR_RETURN(MemoryTypeVector source_output_mem_types,
                          GetOutputMemoryTypes(root, e->src()));
      if (source_output_mem_types[e->src_output()] == DEVICE_MEMORY) {
        int32_inputs.push_back(Output(e->src(), e->src_output()));
        int32_inputs_input_idxs.push_back(e->dst_input());
      }
    }
  }

  if (int32_inputs.empty()) {
    return OkStatus();
  }

  // Create a single IdentityN that is dead if and only if
  // `predicate_as_control` is dead.
  //
  // IdentityN is also special in that, unlike `Identity`, it does not place
  // int32 inputs in host memory.  Placing int32 inputs in host memory would
  // defeat the purpose of adding this indirection.
  ops::IdentityN identity_n(root.WithOpName("int32_id_n"), int32_inputs);
  root.graph()->AddControlEdge(predicate_as_control.node(),
                               identity_n.operation.node());

  for (int i = 0, end = int32_inputs.size(); i < end; i++) {
    TF_RETURN_IF_ERROR(root.graph()->UpdateEdge(identity_n[i].node(), i, n,
                                                int32_inputs_input_idxs[i]));
  }

  return OkStatus();
}

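// Rewrites the XLA cluster call node `n` into an explicit _XlaCompile /
// _XlaRun pair.  When lazy compilation is enabled, also keeps a TensorFlow
// function call as a fallback path for when the cluster is not compiled.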
Status ReplaceNodeWithXlaCompileAndXlaRun(
    jit::DeviceInfoCache* device_info_cache,
    const GraphOptimizationPassOptions& options,
    const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled,
    const DebuggingOpts& debugging_opts, Graph* g, Node* n) {
  XlaClusterInfo cluster_info;
  TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));

  TF_ASSIGN_OR_RETURN(
      jit::DeviceId device,
      InferDeviceForCluster(device_info_cache, n, cluster_info.function.name(),
                            flib_def));

  bool requires_compilation;
  TF_RETURN_IF_ERROR(DeviceRequiresCompilation(*device_info_cache, device,
                                               &requires_compilation));
  if (!lazy_compilation_enabled) {
    requires_compilation = true;
  }

  string device_name_str = string(device_info_cache->GetNameFor(device));

  Status status;
  Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr)
                   .NewSubScope(n->name())
                   .WithDevice(n->requested_device())
                   .WithAssignedDevice(device_name_str);

  ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"),
                               /*constants=*/cluster_info.constant_inputs,
                               /*args=*/cluster_info.non_constant_inputs,
                               /*resources=*/cluster_info.resource_inputs,
                               /*must_compile=*/requires_compilation,
                               cluster_info.function);

  bool has_ref_attr;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(n->attrs(), kXlaHasReferenceVarsAttr, &has_ref_attr));
  xla_compile.operation.node()->AddAttr(kXlaHasReferenceVarsAttr, has_ref_attr);
  TF_RETURN_IF_ERROR(
      CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node()));

  std::vector<Output> xla_run_args =
      GetXlaRunArgs(root, cluster_info, debugging_opts);

  if (requires_compilation) {
    // "Strict" compilation:  every _XlaCompile invocation must compile the
    // cluster.
    ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
                         xla_compile.key, n->output_types());

    MoveOutgoingEdges(g, /*old_node=*/n,
                      /*new_node=*/xla_run.operation.node());
    g->RemoveNode(n);
  } else {
    // "Lazy" compilation: an _XlaCompile invocation may decide not to compile
    // the cluster based on profitability heuristics.

    // We generate the following graph:
    //
    //   (use_tf_call, use_xla_run) =
    //       Switch(pred=xla_compile.compilation_successful,
    //              value=xla_compile.key)
    //
    //   tf_call_outputs = cluster_N(..., ^use_tf_call)
    //   xla_run_outputs = _XlaRun(..., key=use_xla_run)
    //   outputs = Merge(tf_call_outputs, xla_run_outputs).
    ops::Switch s(root.WithOpName("predicated_compilation_key"),
                  xla_compile.key, xla_compile.compilation_successful);
    Output predicated_compilation_key = s.output_true;
    Output inverse_predicated_compilation_key = s.output_false;

    ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
                         predicated_compilation_key, n->output_types());

    MergeOutgoingControlEdges(root, /*old_node=*/n,
                              /*new_node=*/xla_run.operation.node());

    MergeOutgoingDataEdges(root, /*old_node=*/n,
                           /*new_node=*/xla_run.operation.node(),
                           cluster_info.function.name(), debugging_opts);

    TF_RETURN_IF_ERROR(root.status());

    // We already have a TensorFlow function call into the cluster -- the
    // original node we set out to rewrite.  We just wire in the correct control
    // deps and we're done.
    RemoveAllIncomingControlEdges(g, n);
    Operation inverse_predicate_as_control =
        DataToControl(root, inverse_predicated_compilation_key);
    g->AddControlEdge(inverse_predicate_as_control.node(), n);
    n->ClearAttr(kXlaCompiledKernelAttr);

    TF_ASSIGN_OR_RETURN(Node* const pco, ReplaceFunctionCallWithPartitionedCall(
                                             options, flib_def, n, g,
                                             cluster_info.function, root));

    TF_RETURN_IF_ERROR(
        PredicateInt32Inputs(root, pco, inverse_predicate_as_control));
  }

  return OkStatus();
}
}  // namespace

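// Rewrites every marked XLA cluster call node in the graph into explicit
// _XlaCompile and _XlaRun ops.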
Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
  Graph* graph = options.graph->get();

  // Copy out the nodes we want to rewrite to avoid modifying the graph while
  // we iterate over graph->op_nodes().
  std::vector<Node*> xla_compiled_kernels;
  absl::c_copy_if(graph->op_nodes(), std::back_inserter(xla_compiled_kernels),
                  [](const Node* n) {
                    if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
                      return false;
                    }

                    // Only compile nodes that are marked for compilation by the
                    // compilation-marking pass (via 'attr_name').
                    return IsXlaCompiledKernel(*n);
                  });

  bool lazy_compilation_enabled =
      enable_lazy_compilation_
          ? *enable_lazy_compilation_
          : GetBuildXlaOpsPassFlags()->tf_xla_enable_lazy_compilation;

  jit::DeviceInfoCache device_info_cache;
  const BuildXlaOpsPassFlags& flags = *GetBuildXlaOpsPassFlags();

  DebuggingOpts debugging_opts;
  debugging_opts.print_outputs = flags.tf_xla_print_cluster_outputs;
  debugging_opts.check_input_numerics =
      flags.tf_xla_check_cluster_input_numerics;
  debugging_opts.check_output_numerics =
      flags.tf_xla_check_cluster_output_numerics;

  VLOG(1) << "print_outputs = " << debugging_opts.print_outputs;
  VLOG(1) << "check_input_numerics = " << debugging_opts.check_input_numerics;
  VLOG(1) << "check_output_numerics = " << debugging_opts.check_output_numerics;

  for (Node* n : xla_compiled_kernels) {
    TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
        &device_info_cache, options, *options.flib_def,
        lazy_compilation_enabled, debugging_opts, graph, n));
  }

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("build_xla_ops", *graph, options.flib_def);
  }

  return OkStatus();
}
}  // namespace tensorflow