xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/jit/compilability_check_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/compilability_check_util.h"
17 
18 #include <algorithm>
19 #include <atomic>
20 #include <deque>
21 #include <iterator>
22 #include <limits>
23 #include <string>
24 #include <unordered_map>
25 #include <unordered_set>
26 #include <utility>
27 
28 #include "absl/algorithm/container.h"
29 #include "absl/container/flat_hash_map.h"
30 #include "absl/container/flat_hash_set.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_join.h"
33 #include "absl/strings/string_view.h"
34 #include "tensorflow/compiler/jit/defs.h"
35 #include "tensorflow/compiler/jit/device_util.h"
36 #include "tensorflow/compiler/jit/flags.h"
37 #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
38 #include "tensorflow/compiler/jit/xla_activity.pb.h"
39 #include "tensorflow/compiler/jit/xla_activity_listener.h"
40 #include "tensorflow/compiler/jit/xla_cluster_util.h"
41 #include "tensorflow/compiler/tf2xla/const_analysis.h"
42 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
43 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
44 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
45 #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
46 #include "tensorflow/compiler/xla/statusor.h"
47 #include "tensorflow/compiler/xla/union_find.h"
48 #include "tensorflow/compiler/xla/util.h"
49 #include "tensorflow/core/common_runtime/function.h"
50 #include "tensorflow/core/common_runtime/graph_constructor.h"
51 #include "tensorflow/core/framework/attr_value.pb.h"
52 #include "tensorflow/core/framework/bounds_check.h"
53 #include "tensorflow/core/framework/graph_def_util.h"
54 #include "tensorflow/core/framework/memory_types.h"
55 #include "tensorflow/core/framework/node_def.pb.h"
56 #include "tensorflow/core/framework/op_kernel.h"
57 #include "tensorflow/core/framework/types.h"
58 #include "tensorflow/core/graph/algorithm.h"
59 #include "tensorflow/core/graph/control_flow.h"
60 #include "tensorflow/core/lib/gtl/cleanup.h"
61 #include "tensorflow/core/lib/strings/stringprintf.h"
62 #include "tensorflow/core/public/version.h"
63 #include "tensorflow/core/util/dump_graph.h"
64 
65 namespace tensorflow {
66 
67 namespace {
68 
HasResourceInput(const Node & node)69 bool HasResourceInput(const Node& node) {
70   return absl::c_count(node.input_types(), DT_RESOURCE) != 0;
71 }
72 
LogNotCompilable(const Node & node,absl::string_view reason="")73 void LogNotCompilable(const Node& node, absl::string_view reason = "") {
74   VLOG(3) << "Found uncompilable node " << node.name() << " (op "
75           << node.type_string() << ")" << (reason.empty() ? "" : ": ")
76           << reason;
77 }
78 
IsInOutsideCompilationCluster(const Node & n)79 bool IsInOutsideCompilationCluster(const Node& n) {
80   return n.attrs().Find(kXlaOutsideCompilationAttr) != nullptr;
81 }
82 
MakeCallNodeFromAttribute(const Node & node,const std::string & attr_name,NodeDef * node_def)83 Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name,
84                                  NodeDef* node_def) {
85   const NameAttrList* name_attr;
86   TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &name_attr));
87   node_def->set_op(name_attr->name());
88   *(node_def->mutable_attr()) = name_attr->attr();
89   return OkStatus();
90 }
91 
MakeCallNodesFromAttribute(const Node & node,absl::string_view attr_name,absl::string_view call_name)92 StatusOr<std::vector<NodeDef>> MakeCallNodesFromAttribute(
93     const Node& node, absl::string_view attr_name,
94     absl::string_view call_name) {
95   std::vector<NameAttrList> attr_lists;
96   TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &attr_lists));
97 
98   std::vector<NodeDef> out;
99   out.reserve(attr_lists.size());
100   for (int i = 0; i < attr_lists.size(); i++) {
101     out.emplace_back();
102     NodeDef& inserted = out.back();
103     inserted.set_name(absl::StrCat(call_name, "_", i));
104     inserted.set_op(attr_lists[i].name());
105     *inserted.mutable_attr() = attr_lists[i].attr();
106   }
107   return out;
108 }
109 
110 // Utility which searches for values in a sorted list by scanning over it once.
111 // No matter how many times ScanForValue is called, the list is scanned at most
112 // once. However, if a call to ScanForValue skips over a value, that value is
113 // not revisited in future calls to ScanForValue, so callers must take
114 // care to order their calls.
115 //
116 // Useful for merging multiple sorted lists in O(n) time.
117 class SinglePassSearch {
118  public:
119   // Creates a SinglePassSearch object that can be used to search in `values`.
120   // Does not take ownership of `values`. `values` must outlive this.
121   // `values` must be sorted.
SinglePassSearch(absl::Span<int const> values)122   explicit SinglePassSearch(absl::Span<int const> values)
123       : current_index_(0), values_(values) {}
124 
125   // Scans forward in the vector looking for "value", updating the internal
126   // position in to the vector.
127   // Returns true iff the vector contains the given value at or after current
128   // position.
129   // Not thread-safe.
ScanForValue(int value)130   bool ScanForValue(int value) {
131     while (current_index_ < values_.size() &&
132            values_[current_index_] <= value) {
133       if (values_[current_index_] == value) {
134         current_index_++;
135         return true;
136       }
137       current_index_++;
138     }
139     return false;
140   }
141 
142  private:
143   int current_index_;
144   const absl::Span<int const> values_;
145 };
146 
147 }  // anonymous namespace
148 
149 RecursiveCompilabilityChecker::UncompilableNodesMap
FindUncompilableNodes(const Node & node,FunctionLibraryRuntime * lib_runtime,const std::vector<RecursiveCompilabilityChecker::StackFrame> * node_stack_trace) const150 RecursiveCompilabilityChecker::FindUncompilableNodes(
151     const Node& node, FunctionLibraryRuntime* lib_runtime,
152     const std::vector<RecursiveCompilabilityChecker::StackFrame>*
153         node_stack_trace) const {
154   std::vector<StackFrameView> stack_trace;
155   // If `node_stack_trace` is provided, that means `node` is inside
156   // a function body, and therefore, arg nodes and retval nodes are
157   // not considered uncompilable.
158   if (node_stack_trace != nullptr) {
159     for (const auto& frame : *node_stack_trace) {
160       stack_trace.emplace_back(
161           StackFrameView{frame.name, frame.function_name, frame.stack_trace});
162     }
163   }
164   stack_trace.emplace_back(
165       StackFrameView{node.name(), "", node.GetStackTrace()});
166 
167   RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes;
168   IsCompilableNode(node, lib_runtime, &stack_trace,
169                    /*encapsulating_function=*/nullptr, &uncompilable_nodes);
170   return uncompilable_nodes;
171 }
172 
HasXLAKernel(const Node & node,string * uncompilable_reason) const173 bool RecursiveCompilabilityChecker::HasXLAKernel(
174     const Node& node, string* uncompilable_reason) const {
175   // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
176   // is really a kind of function call and will be handled by
177   // IsCompilableCall().
178   if (node.type_string() == "SymbolicGradient") {
179     *uncompilable_reason =
180         "SymbolicGradient should be handled by IsCompilableCall().";
181     return false;
182   }
183 
184   if (node.type_string() == "Const") {
185     const AttrValue* attr = node.attrs().Find("dtype");
186     if (!op_filter_.allow_string_consts && attr != nullptr &&
187         attr->type() == DT_STRING) {
188       *uncompilable_reason =
189           "Const op with type DT_STRING is not supported by XLA.";
190       return false;
191     }
192   }
193 
194   // XLA does not offer guaranteed aliasing between the input and output of the
195   // XLA cluster so it can't implement the forward-tensor-ref semantic.  Leave
196   // such nodes out of XLA clusters.
197   if (HasForwardedRefInput(node)) {
198     VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast.";
199     *uncompilable_reason = "Identity with unsafe cast.";
200     return false;
201   }
202 
203   Status s = FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr);
204   if (!s.ok()) {
205     *uncompilable_reason = s.error_message();
206     return false;
207   }
208   return true;
209 }
210 
211 // Tests whether 'if_node' is compilable. Every operator in the then_branch and
212 // else_branch functions must be compilable for 'if_node' to be compilable.
IsCompilableIf(const Node & if_node,FunctionLibraryRuntime * lib_runtime,std::vector<StackFrameView> * stack_trace,NameAttrList * encapsulating_function,RecursiveCompilabilityChecker::UncompilableNodesMap * uncompilable_nodes) const213 bool RecursiveCompilabilityChecker::IsCompilableIf(
214     const Node& if_node, FunctionLibraryRuntime* lib_runtime,
215     std::vector<StackFrameView>* stack_trace,
216     NameAttrList* encapsulating_function,
217     RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
218     const {
219   bool is_compilable = true;
220   is_compilable &= ExtractNodeDefAndCheckCompilability(
221       if_node, "then_branch", "if_then", encapsulating_function, lib_runtime,
222       stack_trace, uncompilable_nodes);
223   if (!uncompilable_nodes && !is_compilable) return is_compilable;
224 
225   is_compilable &= ExtractNodeDefAndCheckCompilability(
226       if_node, "else_branch", "if_else", encapsulating_function, lib_runtime,
227       stack_trace, uncompilable_nodes);
228 
229   return is_compilable;
230 }
231 
IsCompilableCase(const Node & case_node,FunctionLibraryRuntime * lib_runtime,std::vector<StackFrameView> * stack_trace,NameAttrList * encapsulating_function,RecursiveCompilabilityChecker::UncompilableNodesMap * uncompilable_nodes) const232 bool RecursiveCompilabilityChecker::IsCompilableCase(
233     const Node& case_node, FunctionLibraryRuntime* lib_runtime,
234     std::vector<StackFrameView>* stack_trace,
235     NameAttrList* encapsulating_function,
236     RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
237     const {
238   StatusOr<std::vector<NodeDef>> calls =
239       MakeCallNodesFromAttribute(case_node, "branches", "branch");
240   if (!calls.ok()) {
241     VLOG(2) << "Rejecting node " << case_node.name() << ": "
242             << "missing attribute 'branches'";
243     return false;
244   }
245 
246   bool is_compilable = true;
247 
248   for (const NodeDef& call : *calls) {
249     is_compilable &=
250         IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
251                          uncompilable_nodes);
252   }
253   return is_compilable;
254 }
255 
256 // Tests whether 'while_node' is a completely compilable loop.
257 // Every operator in the condition and body functions must be compilable for a
258 // while loop to be compilable.
IsCompilableWhile(const Node & while_node,FunctionLibraryRuntime * lib_runtime,std::vector<StackFrameView> * stack_trace,NameAttrList * encapsulating_function,RecursiveCompilabilityChecker::UncompilableNodesMap * uncompilable_nodes) const259 bool RecursiveCompilabilityChecker::IsCompilableWhile(
260     const Node& while_node, FunctionLibraryRuntime* lib_runtime,
261     std::vector<StackFrameView>* stack_trace,
262     NameAttrList* encapsulating_function,
263     RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
264     const {
265   bool is_compilable = true;
266   is_compilable &= ExtractNodeDefAndCheckCompilability(
267       while_node, "cond", "while_cond", encapsulating_function, lib_runtime,
268       stack_trace, uncompilable_nodes);
269 
270   if (!uncompilable_nodes && !is_compilable) return is_compilable;
271 
272   is_compilable &= ExtractNodeDefAndCheckCompilability(
273       while_node, "body", "while_body", encapsulating_function, lib_runtime,
274       stack_trace, uncompilable_nodes);
275 
276   return is_compilable;
277 }
278 
ExtractNodeDefAndCheckCompilability(const Node & node,const std::string & attr_name,const std::string & call_name,NameAttrList * encapsulating_function,FunctionLibraryRuntime * lib_runtime,std::vector<StackFrameView> * stack_trace,RecursiveCompilabilityChecker::UncompilableNodesMap * uncompilable_nodes) const279 bool RecursiveCompilabilityChecker::ExtractNodeDefAndCheckCompilability(
280     const Node& node, const std::string& attr_name,
281     const std::string& call_name, NameAttrList* encapsulating_function,
282     FunctionLibraryRuntime* lib_runtime,
283     std::vector<StackFrameView>* stack_trace,
284     RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
285     const {
286   NodeDef call;
287   call.set_name(call_name);
288   if (!MakeCallNodeFromAttribute(node, attr_name, &call).ok()) {
289     const auto uncompilable_reason = absl::StrCat(
290         "missing '", attr_name, "' attribute from node", node.name());
291     MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
292                               encapsulating_function, uncompilable_nodes);
293     VLOG(2) << "Rejecting node " << node.name() << ": " << uncompilable_reason
294             << ".";
295     return false;
296   }
297   if (!IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
298                         uncompilable_nodes)) {
299     VLOG(2) << "Rejecting node " << node.name()
300             << ": can't compile : " << call.op();
301     return false;
302   }
303   return true;
304 }
305 
// Tests whether 'call_def' is a call to a completely compilable function.
// Every operator in the function must be compilable for a function to be
// compilable.
bool RecursiveCompilabilityChecker::IsCompilableCall(
    const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
    std::vector<StackFrameView>* stack_trace,
    NameAttrList* encapsulating_function,
    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
    const {
  // Guard against runaway recursion through deeply nested (or recursive)
  // function calls: the stack trace grows by one frame per call level.
  if (stack_trace->size() > kMaxRecursionDepth) {
    std::string uncompilable_reason = "function depth limit exceeded";
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    VLOG(2) << "Rejecting " << call_def.op() << ": " << uncompilable_reason
            << ".";
    return false;
  }

  // Resolve the call target and instantiate it in the runtime; failure of
  // either step means we cannot inspect the function body at all.
  FunctionLibraryRuntime::Handle handle;
  Status s;
  NameAttrList function;
  s = NameAndAttrsFromFunctionCall(call_def, &function);
  if (s.ok()) {
    s = lib_runtime->Instantiate(function.name(), AttrSlice(&function.attr()),
                                 &handle);
  }
  if (!s.ok()) {
    std::string uncompilable_reason =
        absl::StrCat("could not instantiate call: '", function.name(), "'");
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    VLOG(2) << "Rejecting " << call_def.DebugString() << ": "
            << uncompilable_reason << " : " << s;
    return false;
  }

  // Release the instantiated handle on every exit path below.
  auto release_handle_on_return = gtl::MakeCleanup(
      [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
  const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
  bool is_compilable = true;
  for (const Node* node : fbody->graph->op_nodes()) {
    // Push a frame for this node so nested failures report a full call
    // stack, and pop it again regardless of the outcome.
    stack_trace->emplace_back(
        StackFrameView{node->name(), function.name(), node->GetStackTrace()});
    is_compilable &= IsCompilableNode(*node, lib_runtime, stack_trace,
                                      &function, uncompilable_nodes);
    stack_trace->pop_back();
    // When the caller is not collecting uncompilable nodes, the first
    // failure is conclusive; stop early.
    if (!uncompilable_nodes && !is_compilable) return is_compilable;
  }

  return is_compilable;
}
357 
OpIsInaccurate(const Node & node) const358 bool RecursiveCompilabilityChecker::OpIsInaccurate(const Node& node) const {
359   // b/127344411: SelfAdjointEigV2 and Svd precision issues.
360   return node.type_string() == "SelfAdjointEigV2" ||
361          node.type_string() == "Svd";
362 }
363 
OpIsSlow(const Node & node) const364 bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) const {
365   // b/128001705: SelfAdjointEigV2 and Svd performance issues.
366   // b/135640736: MatrixInverse performance issues.
367   // b/111271662: MatrixSolve performance issues.
368   // https://github.com/tensorflow/tensorflow/pull/31012:
369   //    ResizeNearestNeighbor, ResizeBilinear, and ResizeBilinearGrad sometimes
370   //    create convolutions too large for CuDNN to handle.
371   return node.type_string() == "SelfAdjointEigV2" ||
372          node.type_string() == "Svd" || node.type_string() == "Qr" ||
373          node.type_string() == "MatrixInverse" ||
374          node.type_string() == "MatrixSolve" ||
375          node.type_string() == "ResizeBilinearGrad";
376 }
377 
IsCompilableNode(const Node & node,FunctionLibraryRuntime * lib_runtime,std::vector<StackFrameView> * stack_trace,NameAttrList * encapsulating_function,RecursiveCompilabilityChecker::UncompilableNodesMap * uncompilable_nodes) const378 bool RecursiveCompilabilityChecker::IsCompilableNode(
379     const Node& node, FunctionLibraryRuntime* lib_runtime,
380     std::vector<StackFrameView>* stack_trace,
381     NameAttrList* encapsulating_function,
382     RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
383     const {
384   auto stack_depth = stack_trace->size();
385 
386   if (op_filter_.allow_outside_compiled && IsInOutsideCompilationCluster(node))
387     return true;
388 
389   if (node.IsSource() || node.IsSink()) {
390     absl::string_view uncompilable_reason = "source or sink node";
391     MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
392                               encapsulating_function, uncompilable_nodes);
393     LogNotCompilable(node, uncompilable_reason);
394     return false;
395   }
396 
397   // _Arg nodes in a top-level function represent feeds and _Retval nodes in a
398   // top-level function represent fetches.
399   if (stack_depth == 1 &&
400       (node.type_string() == "_Arg" || node.type_string() == "_Retval")) {
401     absl::string_view uncompilable_reason = "top level _Arg or _Retval";
402     MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
403                               encapsulating_function, uncompilable_nodes);
404     LogNotCompilable(node, uncompilable_reason);
405     return false;
406   }
407 
408   if (node.attrs().Find("_scoped_allocator") ||
409       node.attrs().Find("_forward_from")) {
410     // TODO(b/128858118): XLA does not support _scoped_allocator and
411     // _forward_from.
412     absl::string_view uncompilable_reason =
413         "_scoped_allocator or _forward_from attribute";
414     MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
415                               encapsulating_function, uncompilable_nodes);
416     LogNotCompilable(node, uncompilable_reason);
417     return false;
418   }
419 
420   string uncompilable_reason;
421   if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), node)) {
422     if (!IsCompilableCall(node.def(), lib_runtime, stack_trace,
423                           encapsulating_function, uncompilable_nodes)) {
424       LogNotCompilable(node, "unsupported function");
425       return false;
426     }
427   } else if (!HasXLAKernel(node, &uncompilable_reason)) {
428     MaybeMarkUncompilableNode(
429         absl::StrCat("unsupported op: ", uncompilable_reason), *stack_trace,
430         encapsulating_function, uncompilable_nodes);
431     LogNotCompilable(node, uncompilable_reason);
432     return false;
433   }
434 
435   if (node.IsWhileNode() &&
436       !IsCompilableWhile(node, lib_runtime, stack_trace, encapsulating_function,
437                          uncompilable_nodes)) {
438     LogNotCompilable(node, "unsupported while");
439     return false;
440   }
441 
442   if (node.IsIfNode() &&
443       !IsCompilableIf(node, lib_runtime, stack_trace, encapsulating_function,
444                       uncompilable_nodes)) {
445     LogNotCompilable(node, "unsupported if");
446     return false;
447   }
448 
449   if (op_filter_.require_always_compilable && node.IsCaseNode() &&
450       !IsCompilableCase(node, lib_runtime, stack_trace, encapsulating_function,
451                         uncompilable_nodes)) {
452     LogNotCompilable(node, "unsupported case");
453     return false;
454   }
455 
456   if (!op_filter_.allow_stateful_rng_ops &&
457       IsStatefulRandomOp(node.type_string())) {
458     absl::string_view uncompilable_reason = "stateful random op";
459     MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
460                               encapsulating_function, uncompilable_nodes);
461     LogNotCompilable(node, uncompilable_reason);
462     return false;
463   }
464 
465   if (!op_filter_.allow_control_trigger && node.IsControlTrigger()) {
466     absl::string_view uncompilable_reason = "not allowed control trigger";
467     MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
468                               encapsulating_function, uncompilable_nodes);
469     LogNotCompilable(node, uncompilable_reason);
470     return false;
471   }
472 
473   if (!op_filter_.allow_eliding_assert_and_checknumerics_ops &&
474       IsAssertOrCheckNumerics(node.type_string())) {
475     absl::string_view uncompilable_reason = "Assert or CheckNumerics";
476     MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
477                               encapsulating_function, uncompilable_nodes);
478     LogNotCompilable(node, uncompilable_reason);
479     return false;
480   }
481 
482   if (!op_filter_.allow_collective_reduce_v2 &&
483       node.type_string() == "CollectiveReduceV2") {
484     absl::string_view uncompilable_reason = "Collective op";
485     MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
486                               encapsulating_function, uncompilable_nodes);
487     LogNotCompilable(node, uncompilable_reason);
488     return false;
489   }
490 
491   if (!op_filter_.allow_where_op && node.type_string() == "Where") {
492     absl::string_view uncompilable_reason = "Where op";
493     MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
494                               encapsulating_function, uncompilable_nodes);
495     LogNotCompilable(node, uncompilable_reason);
496     return false;
497   }
498 
499   if (!op_filter_.allow_unique_op && node.type_string() == "Unique") {
500     absl::string_view uncompilable_reason = "Unique op";
501     MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
502                               encapsulating_function, uncompilable_nodes);
503     LogNotCompilable(node, uncompilable_reason);
504     return false;
505   }
506 
507   if (!op_filter_.allow_ops_producing_or_consuming_variant &&
508       OpProducesOrConsumesVariant(node)) {
509     absl::string_view uncompilable_reason = "DT_VARIANT producer/consumer";
510     MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
511                               encapsulating_function, uncompilable_nodes);
512     LogNotCompilable(node, uncompilable_reason);
513     return false;
514   }
515 
516   if (!op_filter_.allow_stack_ops && IsStackOp(node)) {
517     absl::string_view uncompilable_reason = "Stack op";
518     MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
519                               encapsulating_function, uncompilable_nodes);
520     LogNotCompilable(node, uncompilable_reason);
521     return false;
522   }
523 
524   if (!op_filter_.allow_tensor_array_ops && IsTensorArrayOp(node)) {
525     absl::string_view uncompilable_reason = "TensorArray op";
526     MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
527                               encapsulating_function, uncompilable_nodes);
528     LogNotCompilable(node, uncompilable_reason);
529     return false;
530   }
531 
532   if (!op_filter_.allow_resource_ops_in_called_functions && stack_depth > 1 &&
533       HasResourceInput(node)) {
534     absl::string_view uncompilable_reason =
535         "resource variable op in called function";
536     MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
537                               encapsulating_function, uncompilable_nodes);
538     LogNotCompilable(node, uncompilable_reason);
539     return false;
540   }
541 
542   if (!op_filter_.allow_inaccurate_ops && OpIsInaccurate(node)) {
543     absl::string_view uncompilable_reason =
544         "operation with numerical accuracy issues";
545     BroadcastOptimizationRemark(XlaOptimizationRemark::INACCURATE_OPERATION,
546                                 node.DebugString())
547         .IgnoreError();
548     MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
549                               encapsulating_function, uncompilable_nodes);
550     LogNotCompilable(node, uncompilable_reason);
551     return false;
552   }
553 
554   if (!op_filter_.allow_slow_ops && OpIsSlow(node)) {
555     absl::string_view uncompilable_reason = "slow operation";
556     BroadcastOptimizationRemark(XlaOptimizationRemark::SLOW_OPERATION,
557                                 node.DebugString())
558         .IgnoreError();
559     MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
560                               encapsulating_function, uncompilable_nodes);
561     LogNotCompilable(node, uncompilable_reason);
562     return false;
563   }
564 
565   return true;
566 }
567 
CreateOperationFilter(const XlaOpRegistry::DeviceRegistration & registration)568 RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
569     const XlaOpRegistry::DeviceRegistration& registration) {
570   RecursiveCompilabilityChecker::OperationFilter op_filter;
571   op_filter.allow_resource_ops_in_called_functions =
572       registration.cluster_resource_variable_ops_unsafely;
573   op_filter.allow_stack_ops = registration.cluster_stack_ops;
574   op_filter.allow_tensor_array_ops = registration.cluster_tensor_array_ops;
575   op_filter.allow_stateful_rng_ops = registration.cluster_stateful_rng_ops;
576   op_filter.allow_control_trigger = registration.cluster_control_trigger;
577   op_filter.allow_eliding_assert_and_checknumerics_ops =
578       registration.elide_assert_and_checknumerics;
579   op_filter.allow_ops_producing_or_consuming_variant =
580       registration.cluster_variant_ops;
581   op_filter.allow_slow_ops = registration.cluster_slow_ops;
582   op_filter.allow_inaccurate_ops = registration.cluster_inaccurate_ops;
583   return op_filter;
584 }
585 
MaybeMarkUncompilableNode(const absl::string_view reason,const std::vector<StackFrameView> & stack_trace,NameAttrList * encapsulating_function,RecursiveCompilabilityChecker::UncompilableNodesMap * uncompilable_nodes)586 /*static*/ void RecursiveCompilabilityChecker::MaybeMarkUncompilableNode(
587     const absl::string_view reason,
588     const std::vector<StackFrameView>& stack_trace,
589     NameAttrList* encapsulating_function,
590     RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes) {
591   if (!uncompilable_nodes) return;
592 
593   UncompilableNodeInfo node_info;
594   node_info.uncompilable_reason = std::string(reason);
595   absl::c_transform(stack_trace, std::back_inserter(node_info.stack_trace),
596                     [](const StackFrameView& stack_element) {
597                       return StackFrame{
598                           std::string(stack_element.name),
599                           std::string(stack_element.function_name),
600                           stack_element.stack_trace};
601                     });
602 
603   node_info.name = std::string(stack_trace.back().name);
604   auto function =
605       encapsulating_function ? *encapsulating_function : NameAttrList();
606   auto function_identifier = function.ShortDebugString();
607 
608   auto it = uncompilable_nodes->find(function_identifier);
609   if (it == uncompilable_nodes->end()) {
610     std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
611         uncompilable_node_info{std::move(node_info)};
612     uncompilable_nodes->emplace(
613         std::move(function_identifier),
614         std::make_pair(function, std::move(uncompilable_node_info)));
615   } else {
616     it->second.second.emplace_back(std::move(node_info));
617   }
618 }
619 
620 // Returns `true` iff node has a given `attr` set to `true`. Returns `false`
621 // both for the missing attr, and the attr set to `false`.
HasBoolAttr(const NodeDef & node,const char * attr)622 static bool HasBoolAttr(const NodeDef& node, const char* attr) {
623   const auto& it = node.attr().find(attr);
624   return it != node.attr().end() && it->second.b();
625 }
626 
// Returns true iff `node_def` carries the kXlaMustCompileAttr attribute set
// to true, i.e. the node was explicitly marked for XLA compilation.
bool CanCreateXlaKernel(const NodeDef& node_def) {
  return HasBoolAttr(node_def, kXlaMustCompileAttr);
}
630 
GetBodyAndConstantsAndResources(FunctionLibraryRuntime * flr,const NameAttrList & function,const FunctionBody ** fbody,std::vector<int> * constant_arg_indices,std::vector<int> * resource_arg_indices)631 Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
632                                        const NameAttrList& function,
633                                        const FunctionBody** fbody,
634                                        std::vector<int>* constant_arg_indices,
635                                        std::vector<int>* resource_arg_indices) {
636   FunctionLibraryRuntime::Handle handle;
637   TF_RETURN_IF_ERROR(
638       flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
639   *fbody = flr->GetFunctionBody(handle);
640   CHECK(*fbody);  // Can't be nullptr since we just instantiated it.
641   const DataTypeVector& arg_types = (*fbody)->arg_types;
642   std::vector<bool> const_args(arg_types.size());
643   // If we can't analyze the const args. Bail out.
644   TF_RETURN_IF_ERROR(
645       BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
646                              /*compile_time_const_nodes=*/nullptr, flr));
647 
648   for (size_t i = 0; i < const_args.size(); ++i) {
649     if (const_args[i]) {
650       constant_arg_indices->push_back(i);
651     }
652   }
653 
654   // There can be hundreds of resource variables. Reserve the space for them.
655   // We don't reserve for constants above as they are usually few.
656   resource_arg_indices->reserve(arg_types.size());
657   for (size_t i = 0; i < arg_types.size(); ++i) {
658     if (arg_types[i] == DT_RESOURCE) {
659       resource_arg_indices->push_back(i);
660     }
661   }
662 
663   return OkStatus();
664 }
665 
// Computes the memory type of every function argument: compile-time constants
// and resource handles are placed in host memory, everything else in device
// memory.
tensorflow::MemoryTypeVector GetInputMemoryTypes(
    const tensorflow::FunctionBody* fbody,
    absl::Span<int const> constant_arg_indices,
    absl::Span<int const> resource_arg_indices) {
  const size_t num_args = fbody->arg_types.size();
  tensorflow::MemoryTypeVector input_memory_types(num_args,
                                                  tensorflow::DEVICE_MEMORY);
  // Single-pass cursors over the two (sorted) index lists let us classify all
  // arguments while walking each list only once.  This matters because
  // functions can capture a lot of arguments — e.g. the backward pass of
  // ResNet50 takes in all 214 variables and a similar number of activations.
  SinglePassSearch constants_search(constant_arg_indices);
  SinglePassSearch resources_search(resource_arg_indices);
  for (size_t arg = 0; arg < num_args; ++arg) {
    // Short-circuit: the constants cursor is only advanced when the resources
    // cursor does not match this index.
    const bool is_resource = resources_search.ScanForValue(arg);
    if (is_resource || constants_search.ScanForValue(arg)) {
      // Compile-time constants and resource handles are expected to be in
      // host memory.
      input_memory_types[arg] = tensorflow::HOST_MEMORY;
    }
  }
  return input_memory_types;
}
691 
// Computes the memory type of every function return value: DT_RESOURCE
// results live in host memory, all other results in device memory.
tensorflow::MemoryTypeVector GetOutputMemoryTypes(
    const tensorflow::FunctionBody* fbody) {
  const size_t num_rets = fbody->ret_types.size();
  tensorflow::MemoryTypeVector output_memory_types(num_rets,
                                                   tensorflow::DEVICE_MEMORY);
  for (size_t ret = 0; ret < num_rets; ++ret) {
    if (fbody->ret_types[ret] != tensorflow::DT_RESOURCE) continue;
    output_memory_types[ret] = tensorflow::HOST_MEMORY;
  }
  return output_memory_types;
}
703 
// Ops whose presence in a graph or function library forces XLA compilation
// (see NodeCanTriggerXlaCompilation below).  Heap-allocated and never freed
// in this file — presumably intentional, to sidestep static-destruction-order
// issues at process shutdown.
static auto const ops_triggering_xla_compilation =
    new absl::flat_hash_set<std::string>{"XlaBroadcastHelper",
                                         "XlaCallModule",
                                         "XlaConv",
                                         "XlaConvV2",
                                         "XlaDequantize",
                                         "XlaDot",
                                         "XlaDotV2",
                                         "XlaDynamicSlice",
                                         "XlaDynamicUpdateSlice",
                                         "XlaEinsum",
                                         "XlaGather",
                                         "XlaIf",
                                         "XlaKeyValueSort",
                                         "XlaPad",
                                         "XlaRecv",
                                         "XlaReduce",
                                         "XlaReduceWindow",
                                         "XlaReplicaId",
                                         "XlaRngBitGenerator",
                                         "XlaScatter",
                                         "XlaSelectAndScatter",
                                         "XlaSelfAdjointEig",
                                         "XlaSend",
                                         "XlaSharding",
                                         "XlaSort",
                                         "XlaSpmdFullToShardShape",
                                         "XlaSpmdShardToFullShape",
                                         "XlaSvd",
                                         "XlaVariadicReduceV2",
                                         "XlaVariadicSort",
                                         "XlaWhile"};
736 
NodeCanTriggerXlaCompilation(const NodeDef & node)737 static bool NodeCanTriggerXlaCompilation(const NodeDef& node) {
738   return node.attr().find(kXlaClusterIdAttr) != node.attr().end() ||
739          HasBoolAttr(node, kXlaMustCompileAttr) ||
740          HasBoolAttr(node, kXlaCompileAttr) ||
741          HasBoolAttr(node, kXlaScopeAttr) ||
742          HasBoolAttr(node, kXlaInternalScopeAttr) ||
743          ops_triggering_xla_compilation->count(node.op());
744 }
745 
CanTriggerXlaCompilation(const GraphDef & graph)746 bool CanTriggerXlaCompilation(const GraphDef& graph) {
747   for (const FunctionDef& function : graph.library().function()) {
748     for (const NodeDef& node : function.node_def()) {
749       if (NodeCanTriggerXlaCompilation(node)) {
750         return true;
751       }
752     }
753   }
754 
755   for (const NodeDef& node : graph.node()) {
756     if (NodeCanTriggerXlaCompilation(node)) {
757       return true;
758     }
759   }
760 
761   return false;
762 }
763 
764 }  // namespace tensorflow
765