/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/build_xla_ops_pass.h"

#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/control_flow_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/logging_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {
namespace {
struct DebuggingOpts {
  // If true, insert Print nodes to print every output from an XLA cluster.
  bool print_outputs;

  // If true, insert CheckNumerics nodes for every floating point typed input
  // to an XLA cluster.
  bool check_input_numerics;

  // If true, insert CheckNumerics nodes for every floating point typed output
  // from an XLA cluster.
  bool check_output_numerics;
};

void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
  std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
                                     old_node->out_edges().end());
  for (const Edge* edge : out_edges) {
    // TODO(sanjoy): This does not update NodeDef inputs. To be able to update
    // NodeDef inputs we first need to fix encapsulate_subgraphs_pass to fix up
    // the NodeDef inputs to the function call nodes.
    g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input());
    g->RemoveEdge(edge);
  }
}

// Returns a data value that is dead iff `control` is dead.
Output ControlToData(const Scope& scope, Node* control) {
  // The choice of data type here is important.
  //
  // We implement a "control merge", which is a control edge that is alive if
  // either of two nodes (denoted as A and B below) is alive, in the following
  // manner:
  //
  //   A --ctrl--> Const0 --data--> Merge --data--> Identity
  //                                  ^                |
  //                                  |               ctrl
  //   B --ctrl--> Const1 --data------+                |
  //                                                   v
  //                                                  ***
  //
  // where *** denotes the merged control output.
  //
  // We want everything starting from Const{0/1} to Identity to either wholly
  // live on the host or wholly live on the device, so we need to pick a data
  // type that is either consistently assigned to the device (e.g. float) or
  // consistently assigned to the host (e.g. int32). We should *not* pick a
  // data type that is placed partly on the host and partly on the device
  // (e.g. bool constants are placed on the device but bool Identity is placed
  // on the host).
  Output data = ops::Const(scope.WithOpName("ctrl_as_data"),
                           Tensor(DT_INT32, TensorShape({0})));
  scope.graph()->AddControlEdge(control, data.node());
  return Output(data.node());
}

// Returns an operation that can be control-depended on, and that is dead iff
// `data` is dead.
Operation DataToControl(const Scope& scope, Output data) {
  return Operation(
      ops::Identity(scope.WithOpName("data_as_ctrl"), data).node());
}

// Replaces each outgoing edge from `old_node` with a merge node that merges in
// the corresponding output from `new_node`.
void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node,
                            absl::string_view cluster_name,
                            const DebuggingOpts& debugging_opts) {
  if (!s.status().ok()) {
    return;
  }

  std::vector<Output> merged_outputs(old_node->num_outputs(), Output(nullptr));

  std::vector<const Edge*> data_edges;
  absl::c_copy_if(old_node->out_edges(), std::back_inserter(data_edges),
                  [](const Edge* e) { return !e->IsControlEdge(); });

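  // Lazily create one _XlaMerge per output index: the first data edge seen
  // from a given output builds the merge (plus any requested debugging ops);
  // later edges from the same output reuse it.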
  for (const Edge* e : data_edges) {
    int oidx = e->src_output();
    Output merged_output = merged_outputs[oidx];
    if (merged_output.node() == nullptr) {
      Output new_output(new_node, oidx);
      if (debugging_opts.print_outputs) {
        string cpu_device = "/job:localhost/replica:0/task:0/device:CPU:0";
        ops::Print print_op(s.WithOpName("print_", oidx)
                                .WithDevice(cpu_device)
                                .WithAssignedDevice(cpu_device),
                            new_output, {new_output},
                            ops::Print::Attrs{}
                                .Message(absl::StrCat("output ", oidx, " from ",
                                                      old_node->name(), " is "))
                                .FirstN(1000)
                                .Summarize(-1));
        new_output = print_op;
      }

      if (debugging_opts.check_output_numerics &&
          DataTypeIsFloating(new_output.type())) {
        ops::CheckNumerics check_numerics_op(
            s.WithOpName("check_output_", oidx)
                .WithDevice(new_node->requested_device())
                .WithAssignedDevice(new_node->assigned_device_name()),
            new_output,
            absl::StrCat("CheckNumerics failed for output ", oidx, "(",
                         new_output.name(), ") from cluster ", cluster_name));
        new_output = check_numerics_op;
      }

      ops::_XlaMerge xla_merge_op(s.WithOpName("merge_oidx_", oidx),
                                  Output(old_node, oidx), new_output);
      merged_output = merged_outputs[oidx] = xla_merge_op.output;
    }

    Node* dst = e->dst();
    int dst_idx = e->dst_input();

    s.graph()->RemoveEdge(e);
    s.graph()->AddEdge(merged_output.node(), merged_output.index(), dst,
                       dst_idx);
  }
}

// Updates each control successor of `old_node` to execute whenever either
// `old_node` or `new_node` is executed.
void MergeOutgoingControlEdges(const Scope& s, Node* old_node,
                               Node* new_node) {
  if (!s.status().ok()) {
    return;
  }

  std::vector<const Edge*> ctrl_edges;
  absl::c_copy_if(old_node->out_edges(), std::back_inserter(ctrl_edges),
                  [](const Edge* e) { return e->IsControlEdge(); });

  if (ctrl_edges.empty()) {
    return;
  }

  if (ctrl_edges.size() == 1 && ctrl_edges.front()->dst()->IsSink()) {
    // Avoid creating a Merge node if we can just add an edge to _SINK
    // instead.
    s.graph()->AddControlEdge(new_node, s.graph()->sink_node());
    return;
  }

  // We can't merge control edges directly, so we instead first "convert" them
  // to normal values that can be merged, merge the values, and then "convert"
  // the merged value back into control.
  //
  // NB! We need to copy out the outgoing control edges before constructing
  // old_ctrl_as_data, otherwise the control edge from old_node to the constant
  // in ControlToData will be present in ctrl_edges.

  Output old_ctrl_as_data = ControlToData(s, old_node);
  Output new_ctrl_as_data = ControlToData(s, new_node);

  ops::Merge ctrl_merge_as_data(s.WithOpName("ctrl_merge"),
                                {old_ctrl_as_data, new_ctrl_as_data});
  Operation ctrl_merge = DataToControl(s, ctrl_merge_as_data.output);

  for (const Edge* e : ctrl_edges) {
    s.graph()->AddControlEdge(ctrl_merge.node(), e->dst());
    s.graph()->RemoveControlEdge(e);
  }
}

struct XlaClusterInfo {
  std::vector<Output> constant_inputs;
  std::vector<Output> non_constant_inputs;
  std::vector<Output> resource_inputs;
  NameAttrList function;
};

Output IncomingEdgeAsOutput(const Edge* e) {
  return Output(e->src(), e->src_output());
}

Status GetXlaClusterInfo(Node* n, XlaClusterInfo* result) {
  int num_constant_inputs, num_resource_inputs;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(n->attrs(), kXlaNumConstantArgsAttr, &num_constant_inputs));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(n->attrs(), kXlaNumResourceArgsAttr, &num_resource_inputs));

  if (num_constant_inputs < 0 || num_resource_inputs < 0 ||
      num_constant_inputs + num_resource_inputs > n->num_inputs()) {
    return errors::InvalidArgument(
        "Invalid number of constant/resource arguments to XLA kernel.");
  }

  int num_non_constant_inputs =
      n->num_inputs() - num_constant_inputs - num_resource_inputs;

  std::vector<const Edge*> input_edges_vector;
  TF_RETURN_IF_ERROR(n->input_edges(&input_edges_vector));
  absl::Span<const Edge*> input_edges(input_edges_vector);

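  // The cluster's inputs are laid out as [constant args, non-constant args,
  // resource args], so slice the incoming edges into those three groups.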
  absl::c_transform(input_edges.subspan(0, num_constant_inputs),
                    std::back_inserter(result->constant_inputs),
                    IncomingEdgeAsOutput);

  absl::c_transform(
      input_edges.subspan(num_constant_inputs, num_non_constant_inputs),
      std::back_inserter(result->non_constant_inputs), IncomingEdgeAsOutput);

  absl::c_transform(
      input_edges.subspan(num_constant_inputs + num_non_constant_inputs,
                          num_resource_inputs),
      std::back_inserter(result->resource_inputs), IncomingEdgeAsOutput);

  result->function.set_name(n->type_string());
  *result->function.mutable_attr() = n->def().attr();
  return OkStatus();
}

Status CopyIncomingControlEdges(Graph* g, Node* from, Node* to) {
  for (const Edge* e : from->in_edges()) {
    if (e->IsControlEdge()) {
      g->AddControlEdge(e->src(), to);
    }
  }

  return OkStatus();
}

void RemoveAllIncomingControlEdges(Graph* g, Node* n) {
  std::vector<const Edge*> incoming_ctrl_edges;
  absl::c_copy_if(n->in_edges(), std::back_inserter(incoming_ctrl_edges),
                  [](const Edge* e) { return e->IsControlEdge(); });
  for (const Edge* e : incoming_ctrl_edges) {
    g->RemoveControlEdge(e);
  }
}

// Returns true (into `result`) if a node placed on `device` must be compiled.
Status DeviceRequiresCompilation(const jit::DeviceInfoCache& device_info_cache,
                                 jit::DeviceId device, bool* result) {
  const XlaOpRegistry::DeviceRegistration* registration =
      device_info_cache.GetCompilationDevice(device);
  *result = registration->autoclustering_policy ==
            XlaOpRegistry::AutoclusteringPolicy::kAlways;
  return OkStatus();
}

// Replaces `n` with a `PartitionedCall` op that calls the same function.
StatusOr<Node*> ReplaceFunctionCallWithPartitionedCall(
    const GraphOptimizationPassOptions& options,
    const FunctionLibraryDefinition& flib_def, Node* n, Graph* g,
    const NameAttrList& func, const Scope& root) {
  string config_string = options.session_options->config.SerializeAsString();

  int input_count = absl::c_count_if(
      n->in_edges(), [](const Edge* e) { return !e->IsControlEdge(); });

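  // Gather the data inputs indexed by destination input slot so they line up
  // with the callee's argument order.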
  std::vector<Output> args(input_count);
  for (const Edge* e : n->in_edges()) {
    if (!e->IsControlEdge()) {
      args[e->dst_input()] = Output(e->src(), e->src_output());
    }
  }

  // In theory we could use PartitionedCall if the XLA cluster does not have
  // any stateful operations. However, for now we choose to be conservative
  // since we don't have any evidence that a stateless PartitionedCall helps
  // performance.
  ops::StatefulPartitionedCall call(
      root.WithOpName("stateful_partitioned_call"), args, n->output_types(),
      func, ops::StatefulPartitionedCall::Attrs{}.ConfigProto(config_string));

  for (const Edge* e : n->in_edges()) {
    if (e->IsControlEdge()) {
      g->AddControlEdge(e->src(), call.operation.node());
    }
  }

  std::vector<const Edge*> edges_to_delete;

  for (const Edge* e : n->out_edges()) {
    edges_to_delete.push_back(e);
    if (e->IsControlEdge()) {
      g->AddControlEdge(call.operation.node(), e->dst());
    } else {
      g->AddEdge(call.operation.node(), e->src_output(), e->dst(),
                 e->dst_input());
    }
  }

  for (const Edge* e : edges_to_delete) {
    g->RemoveEdge(e);
  }

  g->RemoveNode(n);
  return call.operation.node();
}

StatusOr<jit::DeviceId> InferDeviceForCluster(
    jit::DeviceInfoCache* device_info_cache, Node* n,
    const string& function_name, const FunctionLibraryDefinition& flib_def) {
  const FunctionDef* func_def = flib_def.Find(function_name);
  TF_RET_CHECK(func_def) << "Could not find " << function_name;

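  // Collect every device mentioned inside the cluster's function body (and on
  // the cluster node itself), then pick a single XLA device for the whole
  // cluster.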
  jit::DeviceSet device_set;

  for (const NodeDef& ndef : func_def->node_def()) {
    VLOG(3) << ndef.DebugString();
    if (!ndef.device().empty()) {
      TF_ASSIGN_OR_RETURN(jit::DeviceId device_id,
                          device_info_cache->GetIdFor(ndef.device()));
      device_set.Insert(device_id);
    }
  }

  if (!n->assigned_device_name().empty()) {
    // TODO(sanjoy): We need this because EncapsulateSubgraphsPass drops device
    // assignment when constant folding. We should fix EncapsulateSubgraphsPass
    // instead.
    TF_ASSIGN_OR_RETURN(
        jit::DeviceId device_id,
        device_info_cache->GetIdFor(n->assigned_device_name()));
    device_set.Insert(device_id);
  }

  TF_ASSIGN_OR_RETURN(jit::DeviceId result,
                      PickDeviceForXla(*device_info_cache, device_set,
                                       /*allow_mixing_unknown_and_cpu=*/true));
  VLOG(2) << "For " << function_name << " PickDeviceForXla("
          << device_info_cache->DebugString(device_set) << ") -> "
          << device_info_cache->GetNameFor(result);
  return result;
}

std::vector<Output> GetXlaRunArgs(const Scope& s,
                                  const XlaClusterInfo& cluster_info,
                                  const DebuggingOpts& debugging_opts) {
  std::vector<Output> xla_run_args;
  xla_run_args.reserve(cluster_info.non_constant_inputs.size() +
                       cluster_info.resource_inputs.size());
  int input_idx = 0;
  for (const Output& o : cluster_info.non_constant_inputs) {
    if (debugging_opts.check_input_numerics && DataTypeIsFloating(o.type())) {
      ops::CheckNumerics check_numerics_op(
          s.WithOpName("check_input_", input_idx), o,
          absl::StrCat("CheckNumerics failed for input ", input_idx, "(",
                       o.name(), ") into ", cluster_info.function.name()));
      xla_run_args.push_back(check_numerics_op);
    } else {
      xla_run_args.push_back(o);
    }
    input_idx++;
  }
  absl::c_copy(cluster_info.resource_inputs, std::back_inserter(xla_run_args));
  return xla_run_args;
}

StatusOr<MemoryTypeVector> GetOutputMemoryTypes(const Scope& root, Node* n) {
  MemoryTypeVector input_mtypes, output_mtypes;
  DeviceType device_type("");
  TF_RETURN_IF_ERROR(
      DeviceNameToDeviceType(n->assigned_device_name(), &device_type));
  TF_RETURN_IF_ERROR(MemoryTypesForNode(root.graph()->op_registry(),
                                        device_type, n->def(), &input_mtypes,
                                        &output_mtypes));
  return output_mtypes;
}

// Predicate INT32 typed inputs to `n` on the deadness of
// `predicate_as_control`.
//
// This is a performance optimization. Since INT32 arguments to a
// PartitionedCall are placed on the host, a producer that produces them on the
// device will incur a D2H copy, even if the PartitionedCall is not executed
// (i.e. even if we choose to execute the XLA compiled computation via
// _XlaRun). To prevent this, we add control dependencies to make the int32
// input edges into the PartitionedCall dead. With this change the D2H copy
// only happens if the PartitionedCall is actually executed.
Status PredicateInt32Inputs(const Scope& root, Node* n,
                            Operation predicate_as_control) {
  std::vector<Output> int32_inputs;
  std::vector<int> int32_inputs_input_idxs;
  for (const Edge* e : n->in_edges()) {
    if (e->IsControlEdge()) {
      continue;
    }

    if (e->src()->output_type(e->src_output()) == DT_INT32) {
      TF_ASSIGN_OR_RETURN(MemoryTypeVector source_output_mem_types,
                          GetOutputMemoryTypes(root, e->src()));
      if (source_output_mem_types[e->src_output()] == DEVICE_MEMORY) {
        int32_inputs.push_back(Output(e->src(), e->src_output()));
        int32_inputs_input_idxs.push_back(e->dst_input());
      }
    }
  }

  if (int32_inputs.empty()) {
    return OkStatus();
  }

  // Create a single IdentityN that is dead if and only if
  // `predicate_as_control` is dead.
  //
  // IdentityN is also special in that, unlike `Identity`, it does not place
  // int32 inputs in host memory. Placing int32 inputs in host memory would
  // defeat the purpose of adding this indirection.
  ops::IdentityN identity_n(root.WithOpName("int32_id_n"), int32_inputs);
  root.graph()->AddControlEdge(predicate_as_control.node(),
                               identity_n.operation.node());

  for (int i = 0, end = int32_inputs.size(); i < end; i++) {
    TF_RETURN_IF_ERROR(root.graph()->UpdateEdge(identity_n[i].node(), i, n,
                                                int32_inputs_input_idxs[i]));
  }

  return OkStatus();
}

Status ReplaceNodeWithXlaCompileAndXlaRun(
    jit::DeviceInfoCache* device_info_cache,
    const GraphOptimizationPassOptions& options,
    const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled,
    const DebuggingOpts& debugging_opts, Graph* g, Node* n) {
  XlaClusterInfo cluster_info;
  TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));

  TF_ASSIGN_OR_RETURN(
      jit::DeviceId device,
      InferDeviceForCluster(device_info_cache, n, cluster_info.function.name(),
                            flib_def));

  bool requires_compilation;
  TF_RETURN_IF_ERROR(DeviceRequiresCompilation(*device_info_cache, device,
                                               &requires_compilation));
  if (!lazy_compilation_enabled) {
    requires_compilation = true;
  }

  string device_name_str = string(device_info_cache->GetNameFor(device));

  Status status;
  Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr)
                   .NewSubScope(n->name())
                   .WithDevice(n->requested_device())
                   .WithAssignedDevice(device_name_str);

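  // Emit the _XlaCompile op for the cluster function. It produces a
  // compilation key (consumed by _XlaRun below) and a `compilation_successful`
  // boolean used in the lazy path to choose between the compiled path and the
  // TensorFlow fallback.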
  ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"),
                               /*constants=*/cluster_info.constant_inputs,
                               /*args=*/cluster_info.non_constant_inputs,
                               /*resources=*/cluster_info.resource_inputs,
                               /*must_compile=*/requires_compilation,
                               cluster_info.function);

  bool has_ref_attr;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(n->attrs(), kXlaHasReferenceVarsAttr, &has_ref_attr));
  xla_compile.operation.node()->AddAttr(kXlaHasReferenceVarsAttr, has_ref_attr);
  TF_RETURN_IF_ERROR(
      CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node()));

  std::vector<Output> xla_run_args =
      GetXlaRunArgs(root, cluster_info, debugging_opts);

  if (requires_compilation) {
    // "Strict" compilation: every _XlaCompile invocation must compile the
    // cluster.
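    //
    // In this mode the rewrite is simply:
    //
    //   xla_run_outputs = _XlaRun(..., key=xla_compile.key)
    //
    // with the original cluster node removed and its edges moved to _XlaRun.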
    ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
                         xla_compile.key, n->output_types());

    MoveOutgoingEdges(g, /*old_node=*/n,
                      /*new_node=*/xla_run.operation.node());
    g->RemoveNode(n);
  } else {
    // "Lazy" compilation: an _XlaCompile invocation may decide not to compile
    // the cluster based on profitability heuristics.

    // We generate the following graph:
    //
    //   (use_tf_call, use_xla_run) =
    //          Switch(pred=xla_compile.compilation_successful,
    //                 value=xla_compile.key)
    //
    //   tf_call_outputs = cluster_N(..., ^use_tf_call)
    //   xla_run_outputs = _XlaRun(..., key=use_xla_run)
    //   outputs = Merge(tf_call_outputs, xla_run_outputs).
    ops::Switch s(root.WithOpName("predicated_compilation_key"),
                  xla_compile.key, xla_compile.compilation_successful);
    Output predicated_compilation_key = s.output_true;
    Output inverse_predicated_compilation_key = s.output_false;

    ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
                         predicated_compilation_key, n->output_types());

    MergeOutgoingControlEdges(root, /*old_node=*/n,
                              /*new_node=*/xla_run.operation.node());

    MergeOutgoingDataEdges(root, /*old_node=*/n,
                           /*new_node=*/xla_run.operation.node(),
                           cluster_info.function.name(), debugging_opts);

    TF_RETURN_IF_ERROR(root.status());

    // We already have a TensorFlow function call into the cluster -- the
    // original node we set out to rewrite. We just wire in the correct control
    // deps and we're done.
    RemoveAllIncomingControlEdges(g, n);
    Operation inverse_predicate_as_control =
        DataToControl(root, inverse_predicated_compilation_key);
    g->AddControlEdge(inverse_predicate_as_control.node(), n);
    n->ClearAttr(kXlaCompiledKernelAttr);

    TF_ASSIGN_OR_RETURN(Node* const pco, ReplaceFunctionCallWithPartitionedCall(
                                             options, flib_def, n, g,
                                             cluster_info.function, root));

    TF_RETURN_IF_ERROR(
        PredicateInt32Inputs(root, pco, inverse_predicate_as_control));
  }

  return OkStatus();
}
}  // namespace

Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
  Graph* graph = options.graph->get();

  // Copy out the nodes we want to rewrite to avoid modifying the graph while
  // we iterate on graph->op_nodes().
  std::vector<Node*> xla_compiled_kernels;
  absl::c_copy_if(graph->op_nodes(), std::back_inserter(xla_compiled_kernels),
                  [](const Node* n) {
                    if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
                      return false;
                    }

                    // Only compile nodes that are marked for compilation by
                    // the compilation-marking pass (via 'attr_name').
                    return IsXlaCompiledKernel(*n);
                  });

  bool lazy_compilation_enabled =
      enable_lazy_compilation_
          ? *enable_lazy_compilation_
          : GetBuildXlaOpsPassFlags()->tf_xla_enable_lazy_compilation;

  jit::DeviceInfoCache device_info_cache;
  const BuildXlaOpsPassFlags& flags = *GetBuildXlaOpsPassFlags();

  DebuggingOpts debugging_opts;
  debugging_opts.print_outputs = flags.tf_xla_print_cluster_outputs;
  debugging_opts.check_input_numerics =
      flags.tf_xla_check_cluster_input_numerics;
  debugging_opts.check_output_numerics =
      flags.tf_xla_check_cluster_output_numerics;

  VLOG(1) << "print_outputs = " << debugging_opts.print_outputs;
  VLOG(1) << "check_input_numerics = " << debugging_opts.check_input_numerics;
  VLOG(1) << "check_output_numerics = "
          << debugging_opts.check_output_numerics;

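  // Rewrite each marked cluster node into _XlaCompile/_XlaRun (plus a
  // PartitionedCall fallback when lazy compilation is enabled).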
  for (Node* n : xla_compiled_kernels) {
    TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
        &device_info_cache, options, *options.flib_def,
        lazy_compilation_enabled, debugging_opts, graph, n));
  }

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("build_xla_ops", *graph, options.flib_def);
  }

  return OkStatus();
}
}  // namespace tensorflow