/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_

#include <string>

#include "absl/algorithm/container.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {
// Checks whether a TF node can be compiled or not. "Recursive" as in, for
// call and functional while nodes, it recursively checks whether the callee
// functions can be compiled.
class RecursiveCompilabilityChecker {
 public:
  // Contains the node name and function name. If the node is not inside a
  // function body, the function name is an empty string.
  struct StackFrame {
    std::string name;
    std::string function_name;
    std::shared_ptr<AbstractStackTrace> stack_trace;
  };

  // Contains information about an uncompilable node inside a function body.
  struct UncompilableNodeInfo {
    std::string name;
    // A list representing a stack trace from the highest-level node, in
    // increasing call depth, to the immediate node that fails the
    // compilability checker.
    std::vector<StackFrame> stack_trace;
    std::string uncompilable_reason;
  };

  // Aggregates information about what kinds of ops are allowed.
  struct OperationFilter {  // TODO(lzr): Add AllowEverything() helper.
    // Whether resource variable ops are allowed in callees. We do not allow
    // resource variable ops in called functions (either as direct TF calls or
    // as higher order control flow ops) because we do not yet model their
    // memory effects in jit/resource_operation_safety_analysis.
    bool allow_resource_ops_in_called_functions = false;

    // Whether Stack operations are allowed. We avoid auto-clustering Stack
    // operations in general because we do not support snapshotting them.
    //
    // TODO(b/112837194): This restriction can be lifted with some work.
    bool allow_stack_ops = false;

    // Whether TensorArray operations are allowed. We avoid auto-clustering
    // TensorArray operations in general because we do not support
    // snapshotting them.
    //
    // TODO(b/112837194): This restriction can be lifted with some work.
    bool allow_tensor_array_ops = false;

    // Whether stateful RNG ops are allowed. XLA's RNG does not have the same
    // seeding behavior as TensorFlow's RNG (b/34749654), so we avoid
    // auto-clustering stateful RNG ops.
    bool allow_stateful_rng_ops = false;

    // TODO(b/118970344): Whether ControlTrigger ops are allowed. It is
    // unsound to cluster ControlTrigger because of how we use deadness
    // analysis.
    bool allow_control_trigger = false;

    // Whether it is okay to "cluster" Assert and CheckNumerics by simply
    // removing them (they're not removed during clustering, but their
    // XlaOpKernel is a no-op kernel). We avoid auto-clustering these ops so
    // that the user is not surprised when XLA is implicitly enabled. If the
    // user explicitly specifies to use XLA, it is fine to resort to a dummy
    // implementation. Currently Assert and CheckNumerics ops have dummy XLA
    // implementations.
    bool allow_eliding_assert_and_checknumerics_ops = false;

    // Whether ops that produce or consume DT_VARIANT values are allowed. We
    // don't auto-cluster these ops because we don't yet support live-in or
    // live-out DT_VARIANT values.
    bool allow_ops_producing_or_consuming_variant = false;

    // Whether ops known to be slow on XLA-GPU should be considered
    // compilable.
    bool allow_slow_ops = false;

    // Whether ops known to have numerical accuracy issues should be
    // considered compilable.
    bool allow_inaccurate_ops = false;

    // Require the function to be always compilable, regardless of whether
    // some control flow branches might be dead for a given input.
    bool require_always_compilable = false;

    // Whether string constants are compilable.
    bool allow_string_consts = true;

    // Whether to allow the compilation of CollectiveReduceV2Op.
    bool allow_collective_reduce_v2 = true;

    // Whether to allow the compilation of WhereOp.
    bool allow_where_op = true;

    // Whether to allow the compilation of UniqueOp. Compilation of UniqueOp
    // generates output with a bounded dynamic shape, which may cause failures
    // with auto clustering.
    // TODO(b/209813421): Enable tf.unique during autoclustering once all
    // failures are fixed.
    bool allow_unique_op = true;

    // Whether ops that are marked as outside compiled are always considered
    // compilable.
    // TODO(b/191502757): Make this behavior true by default and remove this
    // option once the inference converter supports outside compilation.
    bool allow_outside_compiled = false;
  };
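
  // A minimal usage sketch (illustrative only, not part of this API): fill in
  // a filter by hand and construct a checker for the XLA_GPU_JIT device.
  // `DEVICE_GPU_XLA_JIT` comes from tensorflow/compiler/tf2xla/
  // xla_op_registry.h, which is included above.
  //
  //   RecursiveCompilabilityChecker::OperationFilter filter;
  //   filter.allow_stack_ops = true;         // cluster Stack ops anyway
  //   filter.allow_tensor_array_ops = true;  // cluster TensorArray ops anyway
  //   RecursiveCompilabilityChecker checker(filter,
  //                                         DeviceType(DEVICE_GPU_XLA_JIT));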

  RecursiveCompilabilityChecker(OperationFilter op_filter,
                                DeviceType jit_device_type)
      : op_filter_(std::move(op_filter)),
        jit_device_type_(std::move(jit_device_type)) {}

  using UncompilableNodesMap =
      std::map<std::string,
               std::pair<NameAttrList, std::vector<UncompilableNodeInfo>>>;

  // Returns a map where the key is the function identifier (short debug
  // string) of the function encapsulating the uncompilable nodes, and the
  // value is a pair of the function's NameAttrList and a vector of
  // uncompilable node info. When an uncompilable node is not inside any
  // function call node, the key is the ShortDebugString() of an empty
  // NameAttrList.
  //
  // Also, when `node` is inside a function body, users can set
  // `node_stack_trace` to provide additional context for `node`'s placement
  // within the outermost graph.
  UncompilableNodesMap FindUncompilableNodes(
      const Node& node, FunctionLibraryRuntime* lib_runtime,
      const std::vector<StackFrame>* node_stack_trace = nullptr) const;

  // Returns true if `node` can be compiled by XLA.
  bool IsCompilableNode(const Node& node,
                        FunctionLibraryRuntime* lib_runtime) const {
    std::vector<StackFrameView> stack_trace;
    stack_trace.emplace_back(StackFrameView{node.name(), ""});
    return IsCompilableNode(node, lib_runtime, &stack_trace);
  }

  // Returns true if XLA supports this Op, but we don't want to cluster it
  // (e.g. due to performance or correctness concerns).
  bool OpIsInaccurate(const Node& node) const;
  bool OpIsSlow(const Node& node) const;
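
  // A usage sketch (hypothetical; reuses `checker` from the sketch above and
  // assumes `graph` is a Graph* and `flr` a FunctionLibraryRuntime*). It
  // walks the graph and logs each uncompilable node with its reason:
  //
  //   for (Node* n : graph->op_nodes()) {
  //     for (const auto& it : checker.FindUncompilableNodes(*n, flr)) {
  //       for (const auto& node_info : it.second.second) {
  //         LOG(INFO) << node_info.name << ": "
  //                   << node_info.uncompilable_reason;
  //       }
  //     }
  //   }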

 private:
  struct StackFrameView {
    absl::string_view name;
    absl::string_view function_name;
    std::shared_ptr<AbstractStackTrace> stack_trace;
  };

  bool IsCompilableNode(
      const Node& node, FunctionLibraryRuntime* lib_runtime,
      std::vector<StackFrameView>* stack_trace,
      NameAttrList* encapsulating_function = nullptr,
      UncompilableNodesMap* uncompilable_nodes = nullptr) const;
  bool IsCompilableCall(
      const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime,
      std::vector<StackFrameView>* stack_trace,
      NameAttrList* encapsulating_function = nullptr,
      UncompilableNodesMap* uncompilable_nodes = nullptr) const;
  bool IsCompilableIf(const Node& if_node, FunctionLibraryRuntime* lib_runtime,
                      std::vector<StackFrameView>* stack_trace,
                      NameAttrList* encapsulating_function,
                      UncompilableNodesMap* uncompilable_nodes) const;
  bool IsCompilableWhile(const Node& while_node,
                         FunctionLibraryRuntime* lib_runtime,
                         std::vector<StackFrameView>* stack_trace,
                         NameAttrList* encapsulating_function,
                         UncompilableNodesMap* uncompilable_nodes) const;

  // Tests whether 'case_node' is compilable. Every operator in all branches
  // must be compilable.
  bool IsCompilableCase(const Node& case_node,
                        FunctionLibraryRuntime* lib_runtime,
                        std::vector<StackFrameView>* stack_trace,
                        NameAttrList* encapsulating_function,
                        UncompilableNodesMap* uncompilable_nodes) const;

  // Returns the compilability of the node def retrieved from `node`'s
  // attribute with name `attr_name`.
  bool ExtractNodeDefAndCheckCompilability(
      const Node& node, const std::string& attr_name,
      const std::string& call_name, NameAttrList* encapsulating_function,
      FunctionLibraryRuntime* lib_runtime,
      std::vector<StackFrameView>* stack_trace,
      UncompilableNodesMap* uncompilable_nodes) const;

  bool IsStackOp(const Node& node) const {
    const XlaResourceOpInfo* op_info =
        GetResourceOpInfoForOp(node.type_string());
    return op_info && op_info->resource_kind() == XlaResourceKind::kStack;
  }

  bool IsTensorArrayOp(const Node& node) const {
    const XlaResourceOpInfo* op_info =
        GetResourceOpInfoForOp(node.type_string());
    return op_info &&
           op_info->resource_kind() == XlaResourceKind::kTensorArray;
  }

  bool IsAssertOrCheckNumerics(absl::string_view op_name) const {
    return op_name == "Assert" || op_name == "CheckNumerics";
  }

  bool IsStatefulRandomOp(absl::string_view op_name) const {
    return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
           op_name == "RandomUniformInt" ||
           op_name == "RandomStandardNormal" ||
           op_name == "TruncatedNormal" || op_name == "Multinomial";
  }

  bool OpProducesOrConsumesVariant(const Node& node) const {
    auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; };
    return absl::c_any_of(node.input_types(), is_variant) ||
           absl::c_any_of(node.output_types(), is_variant);
  }

  bool HasXLAKernel(const Node& node,
                    string* uncompilable_reason = nullptr) const;

  static void MaybeMarkUncompilableNode(
      const absl::string_view reason,
      const std::vector<StackFrameView>& stack_trace,
      NameAttrList* encapsulating_function,
      UncompilableNodesMap* uncompilable_nodes_map);

  // Make sure we don't recurse infinitely on recursive functions.
  const size_t kMaxRecursionDepth = 50;

  const OperationFilter op_filter_;
  const DeviceType jit_device_type_;
};

RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
    const XlaOpRegistry::DeviceRegistration& registration);

// Given a FunctionLibraryRuntime and a `function`, returns this function's
// body in `fbody` as well as the indices of its constant and resource
// arguments. `fbody` is owned by `flr`.
// `constant_arg_indices` and `resource_arg_indices` should be empty vectors.
// They are sorted in ascending order on this function's return.
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
                                       const NameAttrList& function,
                                       const FunctionBody** fbody,
                                       std::vector<int>* constant_arg_indices,
                                       std::vector<int>* resource_arg_indices);
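
// A usage sketch (hypothetical; `flr` is an assumed FunctionLibraryRuntime*
// and `func` an assumed NameAttrList naming a function in its library):
//
//   const FunctionBody* fbody = nullptr;
//   std::vector<int> constant_arg_indices;
//   std::vector<int> resource_arg_indices;
//   TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
//       flr, func, &fbody, &constant_arg_indices, &resource_arg_indices));
//   // On success, both index vectors are sorted in ascending order and
//   // `fbody` remains owned by `flr`.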

// Given a NodeDef `node_def`, returns true iff `node_def` has the
// kXlaCompileAttr attribute set.
bool CanCreateXlaKernel(const NodeDef& node_def);

// Returns memory types for the input.
// `constant_arg_indices` and `resource_arg_indices` are sorted arrays of
// indices corresponding to constant and resource arguments respectively.
//
// One might wonder about the case where a compile-time constant argument
// (which must be in host memory) is also used as an input into an op,
// e.g. `Add`, that expects its inputs in device memory. Here is how it
// works now.
//
// First, what do we mean by "op expects an input in XYZ memory"?
// There are two types of "ops" here: the tf2xla kernel and the HLO
// computation it builds. The tf2xla kernel needs to retrieve the actual
// numeric value of the compile-time constant tensors, so it really expects
// them to be in host memory. However, for other inputs, it refers to them
// using xla::ComputationDataHandle, which is just a symbolic handle that
// xla::ComputationBuilder assigns. How does this handle get assigned for
// constant arguments? Even constant arguments get an _Arg node in the graph
// instantiated for function compilation. The tf2xla kernel for constant _Arg
// nodes takes the constant value, converts it to an XlaLiteral, and feeds it
// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
// constant XlaLiteral is included in the HLO graph, and subsequently in the
// actual executable, which is copied to the device before being executed.
// Thus, when this executable runs, the constant is available in device
// memory.
tensorflow::MemoryTypeVector GetInputMemoryTypes(
    const tensorflow::FunctionBody* fbody,
    absl::Span<int const> constant_arg_indices,
    absl::Span<int const> resource_arg_indices);

// Returns output memory types.
//
// The XlaLaunch kernel keeps all outputs (including constants, which it
// copies) in device memory except for resources.
tensorflow::MemoryTypeVector GetOutputMemoryTypes(
    const tensorflow::FunctionBody* fbody);

// Checks whether the graph can trigger XLA compilation.
bool CanTriggerXlaCompilation(const GraphDef& graph);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_