/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_

#include <string>

#include "absl/algorithm/container.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {
// Checks whether a TF node can be compiled.  "Recursive" in the sense that
// for call and functional While nodes it recursively checks whether the
// callee functions can be compiled.
class RecursiveCompilabilityChecker {
 public:
  // Contains node name and function name. If the node is not inside a function
  // body, the function name is an empty string.
  struct StackFrame {
    std::string name;
    std::string function_name;
    std::shared_ptr<AbstractStackTrace> stack_trace;
  };

  // Contains information about an uncompilable node inside a function body.
  struct UncompilableNodeInfo {
    std::string name;
    // A list representing a stack trace from the highest-level node, in
    // increasing call depth, down to the immediate node that fails the
    // compilability check.
    std::vector<StackFrame> stack_trace;
    std::string uncompilable_reason;
  };

  // Aggregates information about what kinds of ops are allowed.
  struct OperationFilter {  // TODO(lzr): Add AllowEverything() helper.
    // Whether resource variable ops are allowed in callees.  We do not allow
    // resource variable ops in called functions (either as direct TF calls or
    // as higher order control flow ops) because we do not yet model their
    // memory effects in jit/resource_operation_safety_analysis.
    bool allow_resource_ops_in_called_functions = false;

    // Whether Stack operations are allowed.  We avoid auto-clustering Stack
    // operations in general because we do not support snapshotting them.
    //
    // TODO(b/112837194): This restriction can be lifted with some work.
    bool allow_stack_ops = false;

    // Whether TensorArray operations are allowed.  We avoid auto-clustering
    // TensorArray operations in general because we do not support snapshotting
    // them.
    //
    // TODO(b/112837194): This restriction can be lifted with some work.
    bool allow_tensor_array_ops = false;

    // Whether stateful RNG ops are allowed.  XLA's RNG does not have the same
    // seeding behavior as TensorFlow's RNG (b/34749654).  So we avoid
    // auto-clustering stateful RNG ops.
    bool allow_stateful_rng_ops = false;

    // TODO(b/118970344): Whether ControlTrigger ops are allowed.  It is unsound
    // to cluster ControlTrigger because of how we use deadness analysis.
    bool allow_control_trigger = false;

    // Whether it is okay to "cluster" Assert and CheckNumerics by simply
    // removing them (they're not removed during clustering, but their
    // XlaOpKernel is a no-op kernel).  We avoid auto-clustering these ops so
    // that the user is not surprised when XLA is implicitly enabled. If the
    // user explicitly specifies to use XLA, it is fine to resort to a dummy
    // implementation. Currently Assert and CheckNumerics ops have dummy XLA
    // implementations.
    bool allow_eliding_assert_and_checknumerics_ops = false;

    // Whether ops that produce or consume DT_VARIANT values are allowed.  We
    // don't auto-cluster these ops because we don't yet support live-in or
    // live-out DT_VARIANT values.
    bool allow_ops_producing_or_consuming_variant = false;

    // Whether ops known to be slow on XLA-GPU should be considered compilable.
    bool allow_slow_ops = false;

    // Whether ops known to have numerical accuracy issues should be considered
    // compilable.
    bool allow_inaccurate_ops = false;

    // Require the function to be always compilable, regardless of whether some
    // control flow branches might be dead for a given input.
    bool require_always_compilable = false;

    // Whether string constants are compilable.
    bool allow_string_consts = true;

    // Whether to allow the compilation of CollectiveReduceV2Op.
    bool allow_collective_reduce_v2 = true;

    // Whether to allow the compilation of WhereOp.
    bool allow_where_op = true;

    // Whether to allow the compilation of UniqueOp. Compilation of the UniqueOp
    // generates output with a bounded dynamic shape that may cause failures
    // with auto-clustering.
    // TODO(b/209813421): Enable tf.unique during auto-clustering once all
    // failures are fixed.
    bool allow_unique_op = true;

    // Whether ops that are marked as outside compiled are always considered
    // compilable.
    // TODO(b/191502757): Make this behavior true by default and remove this
    // option once the inference converter supports outside compilation.
    bool allow_outside_compiled = false;
  };

  RecursiveCompilabilityChecker(OperationFilter op_filter,
                                DeviceType jit_device_type)
      : op_filter_(std::move(op_filter)),
        jit_device_type_(std::move(jit_device_type)) {}

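  // Example usage (a minimal sketch, not part of this header: it assumes an
  // in-scope `FunctionLibraryRuntime* lib_runtime` and a `Node& node`;
  // DEVICE_GPU_XLA_JIT is declared in xla_op_registry.h):
  //
  //   RecursiveCompilabilityChecker::OperationFilter filter;
  //   filter.allow_stack_ops = true;  // e.g. opt in to Stack ops
  //   RecursiveCompilabilityChecker checker(
  //       std::move(filter), DeviceType(DEVICE_GPU_XLA_JIT));
  //   bool compilable = checker.IsCompilableNode(node, lib_runtime);
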
  using UncompilableNodesMap =
      std::map<std::string,
               std::pair<NameAttrList, std::vector<UncompilableNodeInfo>>>;

  // Returns a map where the key is the function identifier (short debug
  // string) of the function encapsulating the uncompilable nodes, and the
  // value is a pair of the function's NameAttrList and a vector of
  // uncompilable node info. When an uncompilable node is not inside any
  // function call node, the key is the ShortDebugString() of an empty
  // NameAttrList.
  //
  // Also, when `node` is inside a function body, users can set
  // `node_stack_trace` to provide additional context for `node`'s
  // placement within the outermost graph.
  UncompilableNodesMap FindUncompilableNodes(
      const Node& node, FunctionLibraryRuntime* lib_runtime,
      const std::vector<StackFrame>* node_stack_trace = nullptr) const;

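  // Example (a sketch under the same assumptions as the constructor example
  // above):
  //
  //   RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable =
  //       checker.FindUncompilableNodes(node, lib_runtime);
  //   for (const auto& entry : uncompilable) {
  //     for (const auto& info : entry.second.second) {
  //       VLOG(1) << info.name << ": " << info.uncompilable_reason;
  //     }
  //   }
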
  // Returns true if `node` can be compiled by XLA.
  bool IsCompilableNode(const Node& node,
                        FunctionLibraryRuntime* lib_runtime) const {
    std::vector<StackFrameView> stack_trace;
    stack_trace.emplace_back(StackFrameView{node.name(), ""});
    return IsCompilableNode(node, lib_runtime, &stack_trace);
  }

  // Returns true if XLA supports this Op, but we don't want to cluster it
  // (i.e., due to performance or correctness concerns).
  bool OpIsInaccurate(const Node& node) const;
  bool OpIsSlow(const Node& node) const;

 private:
  struct StackFrameView {
    absl::string_view name;
    absl::string_view function_name;
    std::shared_ptr<AbstractStackTrace> stack_trace;
  };

  bool IsCompilableNode(
      const Node& node, FunctionLibraryRuntime* lib_runtime,
      std::vector<StackFrameView>* stack_trace,
      NameAttrList* encapsulating_function = nullptr,
      UncompilableNodesMap* uncompilable_nodes = nullptr) const;
  bool IsCompilableCall(
      const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
      std::vector<StackFrameView>* stack_trace,
      NameAttrList* encapsulating_function = nullptr,
      UncompilableNodesMap* uncompilable_nodes = nullptr) const;
  bool IsCompilableIf(const Node& if_node, FunctionLibraryRuntime* lib_runtime,
                      std::vector<StackFrameView>* stack_trace,
                      NameAttrList* encapsulating_function,
                      UncompilableNodesMap* uncompilable_nodes) const;
  bool IsCompilableWhile(const Node& while_node,
                         FunctionLibraryRuntime* lib_runtime,
                         std::vector<StackFrameView>* stack_trace,
                         NameAttrList* encapsulating_function,
                         UncompilableNodesMap* uncompilable_nodes) const;

  // Tests whether 'case_node' is compilable. Every operator in all branches
  // must be compilable.
  bool IsCompilableCase(const Node& case_node,
                        FunctionLibraryRuntime* lib_runtime,
                        std::vector<StackFrameView>* stack_trace,
                        NameAttrList* encapsulating_function,
                        UncompilableNodesMap* uncompilable_nodes) const;

  // Returns the compilability of the node def retrieved from `node`'s
  // attribute with name `attr_name`.
  bool ExtractNodeDefAndCheckCompilability(
      const Node& node, const std::string& attr_name,
      const std::string& call_name, NameAttrList* encapsulating_function,
      FunctionLibraryRuntime* lib_runtime,
      std::vector<StackFrameView>* stack_trace,
      UncompilableNodesMap* uncompilable_nodes) const;

  bool IsStackOp(const Node& node) const {
    const XlaResourceOpInfo* op_info =
        GetResourceOpInfoForOp(node.type_string());
    return op_info && op_info->resource_kind() == XlaResourceKind::kStack;
  }

  bool IsTensorArrayOp(const Node& node) const {
    const XlaResourceOpInfo* op_info =
        GetResourceOpInfoForOp(node.type_string());
    return op_info && op_info->resource_kind() == XlaResourceKind::kTensorArray;
  }

  bool IsAssertOrCheckNumerics(absl::string_view op_name) const {
    return op_name == "Assert" || op_name == "CheckNumerics";
  }

  bool IsStatefulRandomOp(absl::string_view op_name) const {
    return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
           op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" ||
           op_name == "TruncatedNormal" || op_name == "Multinomial";
  }

  bool OpProducesOrConsumesVariant(const Node& node) const {
    auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; };
    return absl::c_any_of(node.input_types(), is_variant) ||
           absl::c_any_of(node.output_types(), is_variant);
  }

  bool HasXLAKernel(const Node& node,
                    string* uncompilable_reason = nullptr) const;

  static void MaybeMarkUncompilableNode(
      const absl::string_view reason,
      const std::vector<StackFrameView>& stack_trace,
      NameAttrList* encapsulating_function,
      UncompilableNodesMap* uncompilable_nodes_map);

  // Make sure we don't recurse infinitely on recursive functions.
  const size_t kMaxRecursionDepth = 50;

  const OperationFilter op_filter_;
  const DeviceType jit_device_type_;
};

RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
    const XlaOpRegistry::DeviceRegistration& registration);

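// Example (a sketch; obtaining `registration` via
// XlaOpRegistry::GetCompilationDevice is one typical source, assumed here
// rather than mandated by this header):
//
//   const XlaOpRegistry::DeviceRegistration* registration;
//   if (XlaOpRegistry::GetCompilationDevice(device_type.type(),
//                                           &registration)) {
//     RecursiveCompilabilityChecker::OperationFilter filter =
//         CreateOperationFilter(*registration);
//   }
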
// Given a FunctionLibraryRuntime and a `function`, returns this function's body
// in `fbody` as well as the indices of its constant and resource arguments.
// `fbody` is owned by `flr`.
// `constant_arg_indices` and `resource_arg_indices` should be empty vectors.
// They are sorted in ascending order when this function returns.
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
                                       const NameAttrList& function,
                                       const FunctionBody** fbody,
                                       std::vector<int>* constant_arg_indices,
                                       std::vector<int>* resource_arg_indices);

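// Example (a minimal sketch; `flr` and `function` are assumed to be in
// scope, and TF_RETURN_IF_ERROR implies a Status-returning caller):
//
//   const FunctionBody* fbody;
//   std::vector<int> constant_arg_indices;
//   std::vector<int> resource_arg_indices;
//   TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
//       flr, function, &fbody, &constant_arg_indices,
//       &resource_arg_indices));
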
// Given a NodeDef `node_def`, returns true iff `node_def` has kXlaCompileAttr
// set.
bool CanCreateXlaKernel(const NodeDef& node_def);

// Returns memory types for the inputs.
// `constant_arg_indices` and `resource_arg_indices` are sorted arrays of
// indices corresponding to constant and resource arguments respectively.
//
// One might wonder about the case where a compile-time constant argument
// (which must be in host memory) is also used as an input into an op,
// e.g. `Add`, that expects its inputs in device memory. Here is how it
// works now.
// First, what do we mean by "op expects an input in XYZ memory"?
// There are two types of "ops" here: the tf2xla kernel and the HLO
// computation it builds. The tf2xla kernel needs to retrieve the actual
// numeric value of the compile-time constant tensors, so it really expects
// them to be in host memory. However, for other inputs, it refers to them
// using xla::ComputationDataHandle, which is just a symbolic handle that
// xla::ComputationBuilder assigns. How does this handle get assigned for
// constant arguments? Even constant arguments get an _Arg node in the graph
// instantiated for function compilation. The tf2xla kernel for constant _Arg
// nodes takes the constant value, converts it to an XlaLiteral, and feeds it
// to xla::ComputationBuilder::ConstantLiteral, which returns the handle. This
// constant XlaLiteral is included in the HLO graph, and subsequently, in
// the actual executable, which is copied to the device before being
// executed. Thus, when this executable runs, the constant is available in
// device memory.
tensorflow::MemoryTypeVector GetInputMemoryTypes(
    const tensorflow::FunctionBody* fbody,
    absl::Span<int const> constant_arg_indices,
    absl::Span<int const> resource_arg_indices);

// Returns output memory types.
//
// The XlaLaunch kernel keeps all outputs (including constants, which it
// copies) in device memory, except for resources.
tensorflow::MemoryTypeVector GetOutputMemoryTypes(
    const tensorflow::FunctionBody* fbody);

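// Example (a sketch reusing `fbody`, `constant_arg_indices`, and
// `resource_arg_indices` from the GetBodyAndConstantsAndResources example
// above):
//
//   MemoryTypeVector input_memory_types = GetInputMemoryTypes(
//       fbody, constant_arg_indices, resource_arg_indices);
//   MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
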
// Checks whether the graph can trigger XLA compilation.
bool CanTriggerXlaCompilation(const GraphDef& graph);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_