/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_

#include <stack>

#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/layout_util.h"
#include "tensorflow/compiler/tf2xla/xla_argument.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

class XlaContext;

// The XlaCompiler class is responsible for compiling a self-contained
// subgraph of a TensorFlow computation using the XLA linear algebra runtime.
// It performs a symbolic execution of the graph starting from specific input
// shapes, using a JIT device to convert operators into XLA computations.
//
// XlaCompiler is typically invoked from an `XlaLaunch` operator once the
// shapes of all input parameters to the computation are known. This is
// because the symbolic execution requires known shapes for all operations.
//
// XlaCompiler compiles TensorFlow graphs that receive inputs via _Arg nodes
// and return outputs via _Retval nodes.
//
// The XlaCompiler requires one Argument struct for each _Arg index; each
// Argument describes the corresponding argument. Arguments can be
// compile-time constants (kind kConstant), run-time parameters
// (kind kParameter), or resources (kind kResource).
//
// Only kParameter and initialized kResource arguments become runtime
// parameters to the generated XLA computation.
//
// The run-time outputs of the XLA computation are arranged in the following
// order:
//   +------------------+-----------------------------------------+
//   |  _Retval values  |  Updated values of kResource arguments  |
//   +------------------+-----------------------------------------+
// _Retval values are ordered by _Retval index, whereas kResource values are
// ordered by the original _Arg position of the variable.
//
// If a shape representation function is provided as part of
// XlaCompiler::Options, kParameter arguments and return values of an
// entry computation will be reshaped in accordance with the shape function.
// Arguments and return values of a non-entry computation are not reshaped.
// Variable resource arguments are passed and returned in reshaped form, even
// for non-entry computations. This feature allows TensorFlow to keep
// on-device tensors with a different shape from their representation inside
// the XLA computation.
//
// In computation outputs, updated kResource values are placed at the end.
// When emitting While loop bodies, we must ensure that the loop body has
// identical input and output signatures. By passing variable values at the
// end of the argument list and using the
// `return_updated_values_for_all_resources` option, we can ensure that the
// input and output values of resources appear at the same positions.
//
// Resources are passed as parameters or returned as resource updates in
// "packed" form.
// kStack resources are packed as (array, size of stack) XLA tuples.
// kTensorArray resources without gradients are packed as the array that
// backs the TensorArray. If gradients are present (`tensor_array_gradients`),
// the packed representation is an (array, gradient0, gradient1, ...) tuple,
// where gradient_k is the value of the k-th gradient in the
// `tensor_array_gradients` ordered set.
class XlaCompiler {
 public:
  using Argument = ::tensorflow::XlaArgument;

  // Options pertaining to an individual call to CompileGraph() or
  // CompileFunction().
  struct CompileOptions {
    // If `use_tuple_arg` is true, a single tuple parameter will be used for
    // all arguments; if false, each argument gets its own parameter.
    bool use_tuple_arg = false;

    // If 'return_updated_values_for_all_resources' is true, then updated
    // values of all resource arguments will be included in the
    // 'resource_updates' of the computation, even if the resource was not
    // modified by the computation. Used when compiling loop bodies to ensure
    // the input and output signatures match.
    bool return_updated_values_for_all_resources = false;

    // If 'always_return_tuple' is true, then the output of a computation will
    // always be a tuple. Otherwise, a single-element output will not be
    // wrapped in a tuple.
    bool always_return_tuple = true;

    // True when compiling the entry computation, false for subcomputations
    // (while, call, etc.).
    bool is_entry_computation = true;

    // True when we should add token input & output to the graph/function.
    bool add_token_input_output = false;

    // Resource updates are converted into XLA inputs / outputs. The two
    // buffers are aliased with each other if this option is true.
    bool alias_resource_update = false;
  };
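
  // Usage sketch (illustrative, not part of the API): when compiling a While
  // loop body, where the input and output signatures must match, a caller
  // might configure CompileOptions roughly as follows:
  //
  //   XlaCompiler::CompileOptions compile_options;
  //   compile_options.is_entry_computation = false;  // loop body is a
  //                                                  // subcomputation
  //   compile_options.return_updated_values_for_all_resources = true;
  //   compile_options.always_return_tuple = true;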

  using OutputDescription = ::tensorflow::XlaOutputDescription;

  using ResourceUpdate = ::tensorflow::XlaResourceUpdate;

  using CompilationResult = ::tensorflow::XlaCompilationResult;

  struct Options {
    // Name of the compilation device to use. It must be set by the caller.
    // The default empty value is invalid.
    DeviceType device_type = DeviceType("");

    // The device to use during compilation to execute instructions on, for
    // example for auto-tuning.
    // Valid values are defined by `xla::Backend::devices_ordinal_supported()`.
    // -1 indicates the default device should be used.
    int device_ordinal = -1;

    xla::Client* client = nullptr;

    // Function library in which to find function definitions. Must be
    // non-null.
    const FunctionLibraryDefinition* flib_def = nullptr;

    // The graph def version to be compiled.
    int graph_def_version = TF_GRAPH_DEF_VERSION;

    // If 'allow_cpu_custom_calls' is true, kernels may make use of
    // CustomCall() for CPU.
    bool allow_cpu_custom_calls = false;

    // A ShapeDeterminationFns (i.e., a bundle of LayoutSelectionFn and
    // ShapeRepresentationFn). The bundle describes how arguments are
    // represented to XLA: each argument takes the shape given by the shape
    // function. Arguments are input activations or weights to an XLA entry
    // computation. Variables are reshaped to this shape on write, and
    // reshaped to their original shape on read.
    XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns;

    // If not nullptr, populate_resource_manager is called with the
    // compilation device's resource manager when the compilation
    // device is created, and can be used to create metadata objects
    // that can be accessed by XLA op kernels.
    std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;

    // If not nullptr, this memory allocator can be used by the compiler for
    // temporary allocations it might want to make during compilation.
    //
    // For example, the compiler may want to try out different algorithms and
    // choose the fastest one, and it might run those algorithms over buffers
    // created using this allocator.
    //
    // The compiler can function correctly without an explicit allocator given
    // here, but on some devices (notably, GPUs), TensorFlow tends to eagerly
    // allocate most or all available memory on the device, leaving none for
    // the compiler to access, unless it can use TensorFlow's allocator.
    // This must be a shared_ptr, as it is passed all the way down to the
    // cluster compilation. This allows asynchronous compilation to hold a
    // reference until the compilation is finished.
    std::shared_ptr<se::DeviceMemoryAllocator> device_allocator;

    // Alias input and output buffers for parameters that are passed through
    // XLA modules without being changed.
    bool alias_passthrough_params = false;

    // Enable detailed logging of compilation metadata.
    bool detailed_logging = true;
  };
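
  // A minimal end-to-end sketch, assuming the caller already holds an
  // `xla::Client* client` and a non-null `FunctionLibraryDefinition*
  // flib_def` (both hypothetical here, as are `fn_name_attrs` and `args`):
  //
  //   XlaCompiler::Options options;
  //   options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
  //   options.client = client;
  //   options.flib_def = flib_def;
  //   XlaCompiler compiler(options);
  //
  //   XlaCompiler::CompilationResult result;
  //   TF_RETURN_IF_ERROR(compiler.CompileFunction(
  //       XlaCompiler::CompileOptions(), fn_name_attrs, args, &result));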

  // Argument for compiling a single op.
  struct SingleOpCompileArgument {
    // Data types of the output tensors. These are used to create the _Retval
    // nodes.
    std::vector<DataType> output_dtypes;

    // The NodeDef representing the op.
    NodeDef node_def;

    // This is currently only used to obtain the MLIR TPU bridge rollout
    // state. Can be removed once the full rollout is complete.
    ConfigProto config_proto;
  };

  explicit XlaCompiler(Options options);

  ~XlaCompiler();

  // Helper function to populate an XlaCompiler::Argument from an XlaResource.
  static void PopulateArgumentFromResource(const XlaResource& resource,
                                           Argument* arg);

  // Compiles the TensorFlow function given by `fn_name_attrs` into an
  // xla::XlaComputation, placing the output in `result`.
  Status CompileFunction(const CompileOptions& options,
                         const NameAttrList& fn_name_attrs,
                         absl::Span<const Argument> args,
                         CompilationResult* result);

  // Compiles a tensorflow::Graph into an xla::XlaComputation.
  // Similar to CompileFunction, but takes a Graph as input rather than a
  // function.
  Status CompileGraph(
      const CompileOptions& options, string const& name,
      std::unique_ptr<Graph> graph, absl::Span<const Argument> args,
      CompilationResult* result);

  // Returns the shape of the XLA parameter for an argument 'arg'.
  // See the class comment for more details about the argument-passing
  // convention.
  Status XLAShapeForArgument(
      const Argument& arg, bool is_entry_computation,
      const std::optional<xla::HloSharding>& arg_sharding,
      xla::Shape* xla_shape) const;

  // Retrieves the channel handle associated with `key`. Allocates
  // a new channel handle if none exists.
  // Channel handles can be used to communicate between different
  // computations. Computations that communicate should be compiled with the
  // same XlaCompiler.
  Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);

  // Retrieves the host-to-device channel handle associated with `key`.
  // Allocates a new channel handle if none exists.
  Status GetHostToDeviceChannelHandle(const string& key,
                                      xla::ChannelHandle* channel);

  // Retrieves the device-to-host channel handle associated with `key`.
  // Allocates a new channel handle if none exists.
  Status GetDeviceToHostChannelHandle(const string& key,
                                      xla::ChannelHandle* channel);

  // Sets the shapes and types for the device-to-host transfer associated
  // with 'key'.
  Status SetDeviceToHostMetadata(const string& key,
                                 absl::Span<const DataType> types,
                                 absl::Span<const TensorShape> shapes);

  // Gets the shapes of the device-to-host transfer associated with 'key'.
  Status GetDeviceToHostShapes(const string& key,
                               std::vector<TensorShape>* shapes) const;

  // Sets the shapes and types for the host-to-device transfer associated
  // with 'key'.
  Status SetHostToDeviceMetadata(const string& key,
                                 absl::Span<const DataType> types,
                                 absl::Span<const TensorShape> shapes);

  // In order to avoid deadlocks from dependencies in host computations, it
  // can be necessary to enforce a partial order on the execution of
  // HostCompute ops. In particular, it may be necessary to constrain the
  // SendToHost for one HostCompute to run before blocking on the RecvAtHost
  // for another HostCompute. The compiler maintains a mapping from
  // 'host_compute_name' to handle, where the handle is an 'output' of the
  // HostCompute op corresponding to 'host_compute_name'. Another HostCompute
  // op that needs to be sequenced later can add the handle as an 'input' to
  // enforce the constraints. 'host_compute_name' can be any string the
  // client wishes to use to identify a given HostCompute op, as long as the
  // names are unique within the compilation.
  Status GetHostComputeControlDependency(const string& host_compute_name,
                                         xla::XlaOp* handle);
  Status SetHostComputeControlDependency(const string& host_compute_name,
                                         const xla::XlaOp& handle);
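
  // Sequencing sketch for the partial order described above (the name
  // "host_compute_1" and the XlaOp variables are hypothetical). The earlier
  // HostCompute op records one of its outputs as a handle, and a later
  // HostCompute op retrieves that handle to thread it in as an input:
  //
  //   TF_RETURN_IF_ERROR(compiler->SetHostComputeControlDependency(
  //       "host_compute_1", send_to_host_token));
  //   ...
  //   xla::XlaOp ordering_handle;
  //   TF_RETURN_IF_ERROR(compiler->GetHostComputeControlDependency(
  //       "host_compute_1", &ordering_handle));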

  const Options& options() const { return options_; }
  xla::Client* client() const { return options_.client; }
  FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }

  void PushNodeTokenMapping();
  Status PopNodeTokenMapping();
  Status SetNodeToken(const string& node_name, const xla::XlaOp& op);
  StatusOr<xla::XlaOp> GetNodeToken(const string& node_name);

  // Sets the function body `fbody` to the one registered as `function`.
  Status FindFunctionBody(const NameAttrList& function,
                          const FunctionBody** fbody,
                          const ConfigProto** config_proto = nullptr);
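
  // Token-mapping sketch for side-effecting ops within one CompileGraph()
  // level (the node name and XlaOp variables are hypothetical).
  // CompileGraph() itself pushes and pops the mapping; op kernels only set
  // and get tokens:
  //
  //   TF_RETURN_IF_ERROR(compiler->SetNodeToken("node_a", token_of_a));
  //   ...
  //   TF_ASSIGN_OR_RETURN(xla::XlaOp token_in,
  //                       compiler->GetNodeToken("node_a"));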

 private:
  // Returns the optimized graph object in this function body.
  std::unique_ptr<Graph> GetGraph(const FunctionBody* fbody);

  // Builds XLA computations for each of the arguments to the computation.
  // `args` are the arguments to the computation.
  Status BuildArguments(const Graph& graph,
                        const std::vector<XlaCompiler::Argument>& args,
                        bool use_tuple_arg, xla::XlaBuilder* builder,
                        XlaContext* context,
                        const std::map<int, xla::OpSharding>& arg_shardings,
                        std::vector<XlaExpression>* arg_expressions,
                        std::vector<int>* input_to_args,
                        std::vector<xla::Shape>* input_shapes,
                        bool is_entry_computation);

  // The graph compiler needs to know how to get an optimized graph from a
  // function body.
  friend class GraphCompiler;
  friend class XlaCompilerTest;

  Options options_;

  // Status set to non-OK in the constructor if initialization fails.
  Status initialization_status_;

  // Returns the next step sequence number.
  int64_t NextStepId();

  // Internal sequence number for steps executed on the compilation device.
  int64_t next_step_id_;

  XlaCompilationDevice* device_;  // Owned by device_mgr_
  StaticDeviceMgr device_mgr_;

  // To avoid copying the client's function library, use a local function
  // library and runtime for functions created as part of the functionalize
  // control flow transformation.
  std::unique_ptr<FunctionLibraryDefinition> local_flib_def_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> local_pflr_;

  FunctionLibraryRuntime* local_flib_runtime_;  // owned by local_pflr_.
  FunctionLibraryRuntime* flib_runtime_;        // owned by pflr_.

  struct SignatureHash {
    uint64 operator()(
        const std::pair<string, std::vector<Argument>>& signature) const;
  };

  std::unordered_map<std::pair<string, std::vector<Argument>>,
                     CompilationResult, SignatureHash>
      cache_;

  std::unordered_map<string, xla::ChannelHandle> channels_;

  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_sends_;
  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_recvs_;

  std::unordered_map<string, xla::XlaOp> host_compute_control_output_;

  // This is used to store the <node name, token output> mapping.
  // Side-effecting ops call SetNodeToken() to record their token output, so
  // that later side-effecting ops can use GetNodeToken() to retrieve it and
  // use it as a token input.
  //
  // It's a stack because we need a mapping like this for each level of nested
  // CompileGraph() call. In CompileGraph(), we push a new mapping onto the
  // stack, and pop the mapping before returning.
  std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_