/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/compilability_check_util.h"

#include <algorithm>
#include <atomic>
#include <deque>
#include <iterator>
#include <limits>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {

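// Returns true if `node` has at least one input of type DT_RESOURCE.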
bool HasResourceInput(const Node& node) {
  return absl::c_count(node.input_types(), DT_RESOURCE) != 0;
}

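// Logs at VLOG level 3 that `node` was found uncompilable, with an optional
// human-readable reason.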
void LogNotCompilable(const Node& node, absl::string_view reason = "") {
  VLOG(3) << "Found uncompilable node " << node.name() << " (op "
          << node.type_string() << ")" << (reason.empty() ? "" : ": ")
          << reason;
}

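// Returns true if `n` belongs to an outside-compilation cluster, i.e. it
// carries the kXlaOutsideCompilationAttr attribute.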
bool IsInOutsideCompilationCluster(const Node& n) {
  return n.attrs().Find(kXlaOutsideCompilationAttr) != nullptr;
}

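// Builds a call NodeDef from the function named by the `attr_name` attribute
// of `node`, copying that function's attributes onto `node_def`.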
Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name,
                                 NodeDef* node_def) {
  const NameAttrList* name_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &name_attr));
  node_def->set_op(name_attr->name());
  *(node_def->mutable_attr()) = name_attr->attr();
  return OkStatus();
}

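// Builds one call NodeDef per function listed in the `attr_name` attribute of
// `node` (e.g. the "branches" attribute of a Case node). The i-th call node
// is named `<call_name>_<i>`.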
StatusOr<std::vector<NodeDef>> MakeCallNodesFromAttribute(
    const Node& node, absl::string_view attr_name,
    absl::string_view call_name) {
  std::vector<NameAttrList> attr_lists;
  TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &attr_lists));

  std::vector<NodeDef> out;
  out.reserve(attr_lists.size());
  for (int i = 0; i < attr_lists.size(); i++) {
    out.emplace_back();
    NodeDef& inserted = out.back();
    inserted.set_name(absl::StrCat(call_name, "_", i));
    inserted.set_op(attr_lists[i].name());
    *inserted.mutable_attr() = attr_lists[i].attr();
  }
  return out;
}

// Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is
// not revisited in future calls to ScanForValue, so callers must take care to
// order their calls.
//
// Useful for merging multiple sorted lists in O(n) time.
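//
// For example, given a sorted list of values:
//
//   const std::vector<int> values = {1, 3, 5};
//   SinglePassSearch search(values);
//   search.ScanForValue(3);  // true; the scan is now positioned past 3.
//   search.ScanForValue(1);  // false; 1 was already skipped over.
//   search.ScanForValue(5);  // true.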
class SinglePassSearch {
 public:
  // Creates a SinglePassSearch object that can be used to search in `values`.
  // Does not take ownership of `values`. `values` must outlive this.
  // `values` must be sorted.
  explicit SinglePassSearch(absl::Span<int const> values)
      : current_index_(0), values_(values) {}

  // Scans forward in the vector looking for "value", updating the internal
  // position into the vector.
  // Returns true iff the vector contains the given value at or after the
  // current position.
  // Not thread-safe.
  bool ScanForValue(int value) {
    while (current_index_ < values_.size() &&
           values_[current_index_] <= value) {
      if (values_[current_index_] == value) {
        current_index_++;
        return true;
      }
      current_index_++;
    }
    return false;
  }

 private:
  int current_index_;
  const absl::Span<int const> values_;
};

}  // anonymous namespace

RecursiveCompilabilityChecker::UncompilableNodesMap
RecursiveCompilabilityChecker::FindUncompilableNodes(
    const Node& node, FunctionLibraryRuntime* lib_runtime,
    const std::vector<RecursiveCompilabilityChecker::StackFrame>*
        node_stack_trace) const {
  std::vector<StackFrameView> stack_trace;
  // If `node_stack_trace` is provided, that means `node` is inside
  // a function body, and therefore, arg nodes and retval nodes are
  // not considered uncompilable.
  if (node_stack_trace != nullptr) {
    for (const auto& frame : *node_stack_trace) {
      stack_trace.emplace_back(
          StackFrameView{frame.name, frame.function_name, frame.stack_trace});
    }
  }
  stack_trace.emplace_back(
      StackFrameView{node.name(), "", node.GetStackTrace()});

  RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes;
  IsCompilableNode(node, lib_runtime, &stack_trace,
                   /*encapsulating_function=*/nullptr, &uncompilable_nodes);
  return uncompilable_nodes;
}

bool RecursiveCompilabilityChecker::HasXLAKernel(
    const Node& node, string* uncompilable_reason) const {
  // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
  // is really a kind of function call and will be handled by
  // IsCompilableCall().
  if (node.type_string() == "SymbolicGradient") {
    *uncompilable_reason =
        "SymbolicGradient should be handled by IsCompilableCall().";
    return false;
  }

  if (node.type_string() == "Const") {
    const AttrValue* attr = node.attrs().Find("dtype");
    if (!op_filter_.allow_string_consts && attr != nullptr &&
        attr->type() == DT_STRING) {
      *uncompilable_reason =
          "Const op with type DT_STRING is not supported by XLA.";
      return false;
    }
  }

  // XLA does not offer guaranteed aliasing between the input and output of the
  // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave
  // such nodes out of XLA clusters.
  if (HasForwardedRefInput(node)) {
    VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast.";
    *uncompilable_reason = "Identity with unsafe cast.";
    return false;
  }

  Status s = FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr);
  if (!s.ok()) {
    *uncompilable_reason = s.error_message();
    return false;
  }
  return true;
}

// Tests whether 'if_node' is compilable. Every operator in the then_branch and
// else_branch functions must be compilable for 'if_node' to be compilable.
bool RecursiveCompilabilityChecker::IsCompilableIf(
    const Node& if_node, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  bool is_compilable = true;
  is_compilable &= ExtractNodeDefAndCheckCompilability(
      if_node, "then_branch", "if_then", encapsulating_function, lib_runtime,
      stack_trace, uncompilable_nodes);
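  // If the caller did not ask for a detailed map of uncompilable nodes, the
  // first failure already answers the query, so bail out early.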
  if (!uncompilable_nodes && !is_compilable) return is_compilable;

  is_compilable &= ExtractNodeDefAndCheckCompilability(
      if_node, "else_branch", "if_else", encapsulating_function, lib_runtime,
      stack_trace, uncompilable_nodes);

  return is_compilable;
}

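// Tests whether 'case_node' is compilable. Every operator in every branch
// function must be compilable for 'case_node' to be compilable.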
bool RecursiveCompilabilityChecker::IsCompilableCase(
    const Node& case_node, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  StatusOr<std::vector<NodeDef>> calls =
      MakeCallNodesFromAttribute(case_node, "branches", "branch");
  if (!calls.ok()) {
    VLOG(2) << "Rejecting node " << case_node.name() << ": "
            << "missing attribute 'branches'";
    return false;
  }

  bool is_compilable = true;

  for (const NodeDef& call : *calls) {
    is_compilable &=
        IsCompilableCall(call, lib_runtime, stack_trace,
                         encapsulating_function, uncompilable_nodes);
  }
  return is_compilable;
}

// Tests whether 'while_node' is a completely compilable loop.
// Every operator in the condition and body functions must be compilable for a
// while loop to be compilable.
bool RecursiveCompilabilityChecker::IsCompilableWhile(
    const Node& while_node, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  bool is_compilable = true;
  is_compilable &= ExtractNodeDefAndCheckCompilability(
      while_node, "cond", "while_cond", encapsulating_function, lib_runtime,
      stack_trace, uncompilable_nodes);

  if (!uncompilable_nodes && !is_compilable) return is_compilable;

  is_compilable &= ExtractNodeDefAndCheckCompilability(
      while_node, "body", "while_body", encapsulating_function, lib_runtime,
      stack_trace, uncompilable_nodes);

  return is_compilable;
}

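// Builds a call NodeDef (named `call_name`) from the function referenced by
// the `attr_name` attribute of `node`, then recursively checks that the
// called function is compilable.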
bool RecursiveCompilabilityChecker::ExtractNodeDefAndCheckCompilability(
    const Node& node, const std::string& attr_name,
    const std::string& call_name, NameAttrList* encapsulating_function,
    FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  NodeDef call;
  call.set_name(call_name);
  if (!MakeCallNodeFromAttribute(node, attr_name, &call).ok()) {
    const auto uncompilable_reason = absl::StrCat(
        "missing '", attr_name, "' attribute from node ", node.name());
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    VLOG(2) << "Rejecting node " << node.name() << ": " << uncompilable_reason
            << ".";
    return false;
  }
  if (!IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
                        uncompilable_nodes)) {
    VLOG(2) << "Rejecting node " << node.name()
            << ": can't compile : " << call.op();
    return false;
  }
  return true;
}

// Tests whether 'call_def' is a call to a completely compilable function.
// Every operator in the function must be compilable for a function to be
// compilable.
bool RecursiveCompilabilityChecker::IsCompilableCall(
    const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  if (stack_trace->size() > kMaxRecursionDepth) {
    std::string uncompilable_reason = "function depth limit exceeded";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    VLOG(2) << "Rejecting " << call_def.op() << ": " << uncompilable_reason
            << ".";
    return false;
  }

  FunctionLibraryRuntime::Handle handle;
  Status s;
  NameAttrList function;
  s = NameAndAttrsFromFunctionCall(call_def, &function);
  if (s.ok()) {
    s = lib_runtime->Instantiate(function.name(), AttrSlice(&function.attr()),
                                 &handle);
  }
  if (!s.ok()) {
    std::string uncompilable_reason =
        absl::StrCat("could not instantiate call: '", function.name(), "'");
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    VLOG(2) << "Rejecting " << call_def.DebugString() << ": "
            << uncompilable_reason << " : " << s;
    return false;
  }

  auto release_handle_on_return = gtl::MakeCleanup(
      [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
  const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
  bool is_compilable = true;
  for (const Node* node : fbody->graph->op_nodes()) {
    stack_trace->emplace_back(
        StackFrameView{node->name(), function.name(), node->GetStackTrace()});
    is_compilable &= IsCompilableNode(*node, lib_runtime, stack_trace,
                                      &function, uncompilable_nodes);
    stack_trace->pop_back();
    if (!uncompilable_nodes && !is_compilable) return is_compilable;
  }

  return is_compilable;
}

bool RecursiveCompilabilityChecker::OpIsInaccurate(const Node& node) const {
  // b/127344411: SelfAdjointEigV2 and Svd precision issues.
  return node.type_string() == "SelfAdjointEigV2" ||
         node.type_string() == "Svd";
}

bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) const {
  // b/128001705: SelfAdjointEigV2 and Svd performance issues.
  // b/135640736: MatrixInverse performance issues.
  // b/111271662: MatrixSolve performance issues.
  // https://github.com/tensorflow/tensorflow/pull/31012:
  // ResizeNearestNeighbor, ResizeBilinear, and ResizeBilinearGrad sometimes
  // create convolutions too large for CuDNN to handle.
  return node.type_string() == "SelfAdjointEigV2" ||
         node.type_string() == "Svd" || node.type_string() == "Qr" ||
         node.type_string() == "MatrixInverse" ||
         node.type_string() == "MatrixSolve" ||
         node.type_string() == "ResizeBilinearGrad";
}

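// Checks every rule in the operation filter against `node`, recursing into
// function calls and functional control flow (While/If/Case) bodies. Returns
// true iff `node` and everything it transitively calls is compilable.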
bool RecursiveCompilabilityChecker::IsCompilableNode(
    const Node& node, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  auto stack_depth = stack_trace->size();

  if (op_filter_.allow_outside_compiled && IsInOutsideCompilationCluster(node))
    return true;

  if (node.IsSource() || node.IsSink()) {
    absl::string_view uncompilable_reason = "source or sink node";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  // _Arg nodes in a top-level function represent feeds and _Retval nodes in a
  // top-level function represent fetches.
  if (stack_depth == 1 &&
      (node.type_string() == "_Arg" || node.type_string() == "_Retval")) {
    absl::string_view uncompilable_reason = "top level _Arg or _Retval";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (node.attrs().Find("_scoped_allocator") ||
      node.attrs().Find("_forward_from")) {
    // TODO(b/128858118): XLA does not support _scoped_allocator and
    // _forward_from.
    absl::string_view uncompilable_reason =
        "_scoped_allocator or _forward_from attribute";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  string uncompilable_reason;
  if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), node)) {
    if (!IsCompilableCall(node.def(), lib_runtime, stack_trace,
                          encapsulating_function, uncompilable_nodes)) {
      LogNotCompilable(node, "unsupported function");
      return false;
    }
  } else if (!HasXLAKernel(node, &uncompilable_reason)) {
    MaybeMarkUncompilableNode(
        absl::StrCat("unsupported op: ", uncompilable_reason), *stack_trace,
        encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (node.IsWhileNode() &&
      !IsCompilableWhile(node, lib_runtime, stack_trace,
                         encapsulating_function, uncompilable_nodes)) {
    LogNotCompilable(node, "unsupported while");
    return false;
  }

  if (node.IsIfNode() &&
      !IsCompilableIf(node, lib_runtime, stack_trace, encapsulating_function,
                      uncompilable_nodes)) {
    LogNotCompilable(node, "unsupported if");
    return false;
  }

  if (op_filter_.require_always_compilable && node.IsCaseNode() &&
      !IsCompilableCase(node, lib_runtime, stack_trace, encapsulating_function,
                        uncompilable_nodes)) {
    LogNotCompilable(node, "unsupported case");
    return false;
  }

  if (!op_filter_.allow_stateful_rng_ops &&
      IsStatefulRandomOp(node.type_string())) {
    absl::string_view uncompilable_reason = "stateful random op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_control_trigger && node.IsControlTrigger()) {
    absl::string_view uncompilable_reason = "not allowed control trigger";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_eliding_assert_and_checknumerics_ops &&
      IsAssertOrCheckNumerics(node.type_string())) {
    absl::string_view uncompilable_reason = "Assert or CheckNumerics";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_collective_reduce_v2 &&
      node.type_string() == "CollectiveReduceV2") {
    absl::string_view uncompilable_reason = "Collective op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_where_op && node.type_string() == "Where") {
    absl::string_view uncompilable_reason = "Where op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_unique_op && node.type_string() == "Unique") {
    absl::string_view uncompilable_reason = "Unique op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_ops_producing_or_consuming_variant &&
      OpProducesOrConsumesVariant(node)) {
    absl::string_view uncompilable_reason = "DT_VARIANT producer/consumer";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_stack_ops && IsStackOp(node)) {
    absl::string_view uncompilable_reason = "Stack op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_tensor_array_ops && IsTensorArrayOp(node)) {
    absl::string_view uncompilable_reason = "TensorArray op";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_resource_ops_in_called_functions && stack_depth > 1 &&
      HasResourceInput(node)) {
    absl::string_view uncompilable_reason =
        "resource variable op in called function";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_inaccurate_ops && OpIsInaccurate(node)) {
    absl::string_view uncompilable_reason =
        "operation with numerical accuracy issues";
    BroadcastOptimizationRemark(XlaOptimizationRemark::INACCURATE_OPERATION,
                                node.DebugString())
        .IgnoreError();
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  if (!op_filter_.allow_slow_ops && OpIsSlow(node)) {
    absl::string_view uncompilable_reason = "slow operation";
    BroadcastOptimizationRemark(XlaOptimizationRemark::SLOW_OPERATION,
                                node.DebugString())
        .IgnoreError();
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    LogNotCompilable(node, uncompilable_reason);
    return false;
  }

  return true;
}

RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
    const XlaOpRegistry::DeviceRegistration& registration) {
  RecursiveCompilabilityChecker::OperationFilter op_filter;
  op_filter.allow_resource_ops_in_called_functions =
      registration.cluster_resource_variable_ops_unsafely;
  op_filter.allow_stack_ops = registration.cluster_stack_ops;
  op_filter.allow_tensor_array_ops = registration.cluster_tensor_array_ops;
  op_filter.allow_stateful_rng_ops = registration.cluster_stateful_rng_ops;
  op_filter.allow_control_trigger = registration.cluster_control_trigger;
  op_filter.allow_eliding_assert_and_checknumerics_ops =
      registration.elide_assert_and_checknumerics;
  op_filter.allow_ops_producing_or_consuming_variant =
      registration.cluster_variant_ops;
  op_filter.allow_slow_ops = registration.cluster_slow_ops;
  op_filter.allow_inaccurate_ops = registration.cluster_inaccurate_ops;
  return op_filter;
}

/*static*/ void RecursiveCompilabilityChecker::MaybeMarkUncompilableNode(
    const absl::string_view reason,
    const std::vector<StackFrameView>& stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes) {
  if (!uncompilable_nodes) return;

  UncompilableNodeInfo node_info;
  node_info.uncompilable_reason = std::string(reason);
  absl::c_transform(stack_trace, std::back_inserter(node_info.stack_trace),
                    [](const StackFrameView& stack_element) {
                      return StackFrame{
                          std::string(stack_element.name),
                          std::string(stack_element.function_name),
                          stack_element.stack_trace};
                    });

  node_info.name = std::string(stack_trace.back().name);
  auto function =
      encapsulating_function ? *encapsulating_function : NameAttrList();
  auto function_identifier = function.ShortDebugString();

  auto it = uncompilable_nodes->find(function_identifier);
  if (it == uncompilable_nodes->end()) {
    std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
        uncompilable_node_info{std::move(node_info)};
    uncompilable_nodes->emplace(
        std::move(function_identifier),
        std::make_pair(function, std::move(uncompilable_node_info)));
  } else {
    it->second.second.emplace_back(std::move(node_info));
  }
}

// Returns `true` iff the node has a given `attr` set to `true`. Returns
// `false` both when the attr is missing and when it is set to `false`.
static bool HasBoolAttr(const NodeDef& node, const char* attr) {
  const auto& it = node.attr().find(attr);
  return it != node.attr().end() && it->second.b();
}

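// Returns true if `node_def` is explicitly marked for XLA compilation via the
// kXlaMustCompileAttr attribute.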
bool CanCreateXlaKernel(const NodeDef& node_def) {
  return HasBoolAttr(node_def, kXlaMustCompileAttr);
}

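// Instantiates `function` through `flr` and returns its FunctionBody in
// `fbody`, together with the indices of the arguments that must be
// compile-time constants (`constant_arg_indices`) and of the arguments that
// carry DT_RESOURCE handles (`resource_arg_indices`). Both index vectors are
// filled in ascending order.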
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
                                       const NameAttrList& function,
                                       const FunctionBody** fbody,
                                       std::vector<int>* constant_arg_indices,
                                       std::vector<int>* resource_arg_indices) {
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(
      flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
  *fbody = flr->GetFunctionBody(handle);
  CHECK(*fbody);  // Can't be nullptr since we just instantiated it.
  const DataTypeVector& arg_types = (*fbody)->arg_types;
  std::vector<bool> const_args(arg_types.size());
  // If we can't analyze the const args, bail out.
  TF_RETURN_IF_ERROR(
      BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
                             /*compile_time_const_nodes=*/nullptr, flr));

  for (size_t i = 0; i < const_args.size(); ++i) {
    if (const_args[i]) {
      constant_arg_indices->push_back(i);
    }
  }

  // There can be hundreds of resource variables. Reserve the space for them.
  // We don't reserve for constants above as they are usually few.
  resource_arg_indices->reserve(arg_types.size());
  for (size_t i = 0; i < arg_types.size(); ++i) {
    if (arg_types[i] == DT_RESOURCE) {
      resource_arg_indices->push_back(i);
    }
  }

  return OkStatus();
}

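// Returns the memory types of the arguments of `fbody`: compile-time
// constants and DT_RESOURCE handles are expected in host memory; all other
// arguments stay in device memory.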
tensorflow::MemoryTypeVector GetInputMemoryTypes(
    const tensorflow::FunctionBody* fbody,
    absl::Span<int const> constant_arg_indices,
    absl::Span<int const> resource_arg_indices) {
  // Set input and output memory types.
  tensorflow::MemoryTypeVector input_memory_types(fbody->arg_types.size(),
                                                  tensorflow::DEVICE_MEMORY);
  // These indices are used only for optimization purposes. They allow us to
  // loop over constant_arg_indices and resource_arg_indices only once while
  // iterating over all the function arguments, checking whether each one is a
  // resource or a constant.
  // The reason we optimized this code is because functions can have a lot of
  // captured arguments. For example, the backward pass of ResNet50 takes in
  // all 214 variables and a similar number of activations.
  SinglePassSearch constants_search(constant_arg_indices);
  SinglePassSearch resources_search(resource_arg_indices);
  for (size_t i = 0; i < fbody->arg_types.size(); ++i) {
    if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
      // Compile-time constants and resource handles are expected to be in
      // host memory.
      input_memory_types[i] = tensorflow::HOST_MEMORY;
    }
  }
  return input_memory_types;
}

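// Returns the memory types of the results of `fbody`: DT_RESOURCE results are
// produced in host memory.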
tensorflow::MemoryTypeVector GetOutputMemoryTypes(
    const tensorflow::FunctionBody* fbody) {
  tensorflow::MemoryTypeVector output_memory_types(fbody->ret_types.size(),
                                                   tensorflow::DEVICE_MEMORY);
  for (size_t i = 0; i < fbody->ret_types.size(); ++i) {
    if (fbody->ret_types[i] == tensorflow::DT_RESOURCE) {
      output_memory_types[i] = tensorflow::HOST_MEMORY;
    }
  }
  return output_memory_types;
}

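// Ops whose presence in a graph or its function library indicates that the
// graph was written with XLA in mind and should trigger XLA compilation.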
static auto const ops_triggering_xla_compilation =
    new absl::flat_hash_set<std::string>{"XlaBroadcastHelper",
                                         "XlaCallModule",
                                         "XlaConv",
                                         "XlaConvV2",
                                         "XlaDequantize",
                                         "XlaDot",
                                         "XlaDotV2",
                                         "XlaDynamicSlice",
                                         "XlaDynamicUpdateSlice",
                                         "XlaEinsum",
                                         "XlaGather",
                                         "XlaIf",
                                         "XlaKeyValueSort",
                                         "XlaPad",
                                         "XlaRecv",
                                         "XlaReduce",
                                         "XlaReduceWindow",
                                         "XlaReplicaId",
                                         "XlaRngBitGenerator",
                                         "XlaScatter",
                                         "XlaSelectAndScatter",
                                         "XlaSelfAdjointEig",
                                         "XlaSend",
                                         "XlaSharding",
                                         "XlaSort",
                                         "XlaSpmdFullToShardShape",
                                         "XlaSpmdShardToFullShape",
                                         "XlaSvd",
                                         "XlaVariadicReduceV2",
                                         "XlaVariadicSort",
                                         "XlaWhile"};

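// Returns true if `node` carries an XLA-related attribute (cluster id,
// must-compile, compile, or scope) or is one of the XLA ops listed above.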
static bool NodeCanTriggerXlaCompilation(const NodeDef& node) {
  return node.attr().find(kXlaClusterIdAttr) != node.attr().end() ||
         HasBoolAttr(node, kXlaMustCompileAttr) ||
         HasBoolAttr(node, kXlaCompileAttr) ||
         HasBoolAttr(node, kXlaScopeAttr) ||
         HasBoolAttr(node, kXlaInternalScopeAttr) ||
         ops_triggering_xla_compilation->count(node.op());
}

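// Returns true if any node in `graph`, or in any function in its function
// library, can trigger XLA compilation.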
bool CanTriggerXlaCompilation(const GraphDef& graph) {
  for (const FunctionDef& function : graph.library().function()) {
    for (const NodeDef& node : function.node_def()) {
      if (NodeCanTriggerXlaCompilation(node)) {
        return true;
      }
    }
  }

  for (const NodeDef& node : graph.node()) {
    if (NodeCanTriggerXlaCompilation(node)) {
      return true;
    }
  }

  return false;
}

}  // namespace tensorflow