/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_

#include <stack>

#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/layout_util.h"
#include "tensorflow/compiler/tf2xla/xla_argument.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

class XlaContext;
// The XlaCompiler class is responsible for compilation of a self-contained
// subgraph of a TensorFlow computation using the XLA linear algebra runtime.
// It performs a symbolic execution of the graph starting from specific input
// shapes, using a JIT device to convert operators into XLA computations.
//
// XlaCompiler is typically invoked from an `XlaLaunch` operator once the
// shapes of all input parameters to the computation are known. This is
// because the symbolic execution requires known shapes for all operations.
//
// XlaCompiler compiles TensorFlow graphs that receive inputs via _Arg nodes
// and return outputs via _Retval nodes.
//
// The XlaCompiler requires one Argument struct for each _Arg index that
// describes the corresponding argument. Arguments can be compile-time
// constants (kind kConstant), run-time parameters (kind kParameter), or
// resources (kind kResource).
//
// Only kParameter and initialized kResource arguments become runtime
// parameters to the generated XLA computation.
//
// The run-time outputs of the XLA computation are arranged in the following
// order:
//   +------------------+-----------------------------------------+
//   |  _Retval values  |  Updated values of kResource arguments  |
//   +------------------+-----------------------------------------+
// _Retval values are ordered by _Retval index, whereas kResource values are
// ordered by the original _Arg position of the variable.
//
// If a shape representation function is provided as part of
// XlaCompiler::Options, kParameter arguments and return values to an
// entry computation will be reshaped in accordance with the shape function.
// Arguments and return values to a non-entry computation are not reshaped.
// Variable resource arguments are passed and returned in reshaped form, even
// for non-entry computations. This feature allows TensorFlow to keep
// on-device tensors with a different shape from their representation inside
// the XLA computation.
//
// In computation outputs, updated kResource values are placed at the end.
// When emitting While loop bodies, we must ensure that the loop body has
// identical input and output signatures. By passing variable values
// at the end of the argument list and using the
// `return_updated_values_for_all_resources` option, we can ensure that the
// input and output values of resources appear at the same positions.
//
// Resources are passed as parameters or returned as resource updates in
// "packed" form.
// kStack resources are packed as (array, size of stack) XLA tuples.
// kTensorArray resources without gradients are packed as the array that
// backs the TensorArray. If gradients are present (`tensor_array_gradients`),
// the packed representation is an (array, gradient0, gradient1, ...) tuple,
// where gradient_k is the value of the k-th gradient in the
// `tensor_array_gradients` ordered set.
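//
// A minimal usage sketch, not prescriptive: `client`, `flib_def`, and `graph`
// are assumed to be supplied by the caller, and the device type is just one
// example of a registered JIT device.
//
//   XlaCompiler::Options options;
//   options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
//   options.client = client;      // An xla::Client owned by the caller.
//   options.flib_def = flib_def;  // Caller's FunctionLibraryDefinition.
//   XlaCompiler compiler(options);
//
//   std::vector<XlaCompiler::Argument> args(1);
//   args[0].kind = XlaCompiler::Argument::kParameter;
//   args[0].type = DT_FLOAT;
//   args[0].shape = TensorShape({2, 3});
//
//   XlaCompiler::CompilationResult result;
//   Status status =
//       compiler.CompileGraph(XlaCompiler::CompileOptions(), "my_graph",
//                             std::move(graph), args, &result);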
class XlaCompiler {
 public:
  using Argument = ::tensorflow::XlaArgument;

  // Options pertaining to an individual call to CompileGraph() or
  // CompileFunction().
  struct CompileOptions {
    // If `use_tuple_arg` is true, a single tuple parameter will be used for
    // all arguments; if false, each argument gets its own parameter.
    bool use_tuple_arg = false;

    // If 'return_updated_values_for_all_resources' is true, then updated
    // values of all resource arguments will be included in the
    // 'resource_updates' of the computation, even if the resource was not
    // modified by the computation. Used when compiling loop bodies to ensure
    // the input and output signatures match.
    bool return_updated_values_for_all_resources = false;

    // If 'always_return_tuple' is true, then the output of a computation will
    // always be a tuple. Otherwise, a single-element output will not be
    // wrapped in a tuple.
    bool always_return_tuple = true;

    // True when compiling the entry computation, false for subcomputations
    // (while, call, etc.)
    bool is_entry_computation = true;

    // True when we should add a token input and output to the graph/function.
    bool add_token_input_output = false;

    // Resource updates are converted into XLA inputs and outputs. The two
    // buffers are aliased with each other if this option is true.
    bool alias_resource_update = false;
  };
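
  // A hedged sketch of how these flags compose when compiling a While loop
  // body, per the class comment: tuple inputs plus updated values for every
  // resource keep the body's input and output signatures identical.
  //
  //   XlaCompiler::CompileOptions body_options;
  //   body_options.use_tuple_arg = true;
  //   body_options.return_updated_values_for_all_resources = true;
  //   body_options.is_entry_computation = false;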

  using OutputDescription = ::tensorflow::XlaOutputDescription;

  using ResourceUpdate = ::tensorflow::XlaResourceUpdate;

  using CompilationResult = ::tensorflow::XlaCompilationResult;

  struct Options {
    // Name of the compilation device to use. It must be set by the caller.
    // The default empty value is invalid.
    DeviceType device_type = DeviceType("");

    // The device to use during compilation to execute instructions on, for
    // example for auto-tuning.
    // Valid values are defined by `xla::Backend::device_ordinal_supported()`.
    // -1 indicates the default device should be used.
    int device_ordinal = -1;

    xla::Client* client = nullptr;

    // Function library in which to find function definitions. Must be
    // non-null.
    const FunctionLibraryDefinition* flib_def = nullptr;

    // The graph def version to be compiled.
    int graph_def_version = TF_GRAPH_DEF_VERSION;

    // If 'allow_cpu_custom_calls' is true, kernels may make use of
    // CustomCall() for CPU.
    bool allow_cpu_custom_calls = false;

    // A ShapeDeterminationFns (i.e., a bundle of LayoutSelectionFn and
    // ShapeRepresentationFn). The bundle describes how arguments (input
    // activations or weights to an XLA entry computation) are represented to
    // XLA: each argument takes the shape given by the shape representation
    // function. Variables are reshaped to this shape on write, and reshaped
    // to their original shape on read.
    XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns;

    // If not nullptr, populate_resource_manager is called with the
    // compilation device's resource manager when the compilation
    // device is created, and can be used to create metadata objects
    // that can be accessed by XLA op kernels.
    std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;

    // If not nullptr, this memory allocator can be used by the compiler for
    // temporary allocations it might want to make during compilation.
    //
    // For example, the compiler may want to try out different algorithms and
    // choose the fastest one, and it might run those algorithms over buffers
    // created using this allocator.
    //
    // The compiler can function correctly without an explicit allocator given
    // here, but on some devices (notably, GPUs), TensorFlow tends to eagerly
    // allocate most or all available memory on the device, leaving none for
    // the compiler to access, unless it can use TensorFlow's allocator.
    // This must be a shared_ptr, as it is passed all the way down to the
    // cluster compilation. This allows asynchronous compilation to hold a
    // reference until the compilation is finished.
    std::shared_ptr<se::DeviceMemoryAllocator> device_allocator;

    // Alias input and output buffers for parameters that are passed through
    // XLA modules without being changed.
    bool alias_passthrough_params = false;

    // Enable detailed logging of compilation metadata.
    bool detailed_logging = true;
  };

  // Argument for compiling a single op.
  struct SingleOpCompileArgument {
    // Data types of the output tensors. Used to create the _Retval nodes.
    std::vector<DataType> output_dtypes;

    // The NodeDef representing the op.
    NodeDef node_def;

    // This is currently only used to obtain the MLIR TPU bridge rollout
    // state. Can be removed once the full rollout is complete.
    ConfigProto config_proto;
  };

  explicit XlaCompiler(Options options);

  ~XlaCompiler();

  // Helper function to populate an XlaCompiler::Argument from an XlaResource.
  static void PopulateArgumentFromResource(const XlaResource& resource,
                                           Argument* arg);

  // Compiles the function identified by `fn_name_attrs` into an
  // xla::XlaComputation, with `args` describing each _Arg index.
  Status CompileFunction(const CompileOptions& options,
                         const NameAttrList& fn_name_attrs,
                         absl::Span<const Argument> args,
                         CompilationResult* result);

  // Compiles a tensorflow::Graph into an xla::XlaComputation.
  // Similar to CompileFunction, but takes a Graph as input rather than a
  // function.
  Status CompileGraph(
      const CompileOptions& options, string const& name,
      std::unique_ptr<Graph> graph, absl::Span<const Argument> args,
      CompilationResult* result);

  // Returns the shape of the XLA parameter for an argument 'arg'.
  // See the class comment for more details about the argument passing
  // convention.
  Status XLAShapeForArgument(
      const Argument& arg, bool is_entry_computation,
      const std::optional<xla::HloSharding>& arg_sharding,
      xla::Shape* xla_shape) const;

  // Retrieves the channel handle associated with `key`. Allocates
  // a new channel handle if none exists.
  // Channel handles can be used to communicate between different
  // computations. Computations that communicate should be compiled with the
  // same XlaCompiler.
  Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);

  // Retrieves the host-to-device channel handle associated with `key`.
  // Allocates a new channel handle if none exists.
  Status GetHostToDeviceChannelHandle(const string& key,
                                      xla::ChannelHandle* channel);

  // Retrieves the device-to-host channel handle associated with `key`.
  // Allocates a new channel handle if none exists.
  Status GetDeviceToHostChannelHandle(const string& key,
                                      xla::ChannelHandle* channel);
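
  // Sketch (hypothetical key "xfer_0"): computations compiled with the same
  // XlaCompiler can communicate over the channel handle obtained for a key.
  //
  //   xla::ChannelHandle handle;
  //   TF_RETURN_IF_ERROR(compiler.GetChannelHandle("xfer_0", &handle));
  //   // Use `handle` in both the sending and the receiving computation.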

  // Sets the shapes and types for the device-to-host transfer associated with
  // 'key'.
  Status SetDeviceToHostMetadata(const string& key,
                                 absl::Span<const DataType> types,
                                 absl::Span<const TensorShape> shapes);

  // Gets the shapes of the device-to-host transfer associated with 'key'.
  Status GetDeviceToHostShapes(const string& key,
                               std::vector<TensorShape>* shapes) const;

  // Sets the shapes and types for the host-to-device transfer associated with
  // 'key'.
  Status SetHostToDeviceMetadata(const string& key,
                                 absl::Span<const DataType> types,
                                 absl::Span<const TensorShape> shapes);
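
  // Sketch (hypothetical key "host_xfer" and shapes): metadata set here can
  // be read back with GetDeviceToHostShapes() under the same key.
  //
  //   TF_RETURN_IF_ERROR(compiler.SetDeviceToHostMetadata(
  //       "host_xfer", {DT_FLOAT}, {TensorShape({128})}));
  //   std::vector<TensorShape> shapes;
  //   TF_RETURN_IF_ERROR(
  //       compiler.GetDeviceToHostShapes("host_xfer", &shapes));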

  // In order to avoid deadlocks from dependencies in host computations, it can
  // be necessary to enforce a partial order on the execution of HostCompute
  // Ops. In particular it may be necessary to constrain the SendToHost for one
  // HostCompute to run before blocking on the RecvAtHost for another
  // HostCompute. The compiler maintains a mapping from 'host_compute_name' to
  // handle, where the handle is an 'output' of the HostCompute Op corresponding
  // to 'host_compute_name'. Another HostCompute Op that needs to be sequenced
  // later can add the handle as an 'input' to enforce the constraints.
  // 'host_compute_name' can be any string the client wishes to use to identify
  // a given HostCompute Op as long as the names are unique within the
  // compilation.
  Status GetHostComputeControlDependency(const string& host_compute_name,
                                         xla::XlaOp* handle);
  Status SetHostComputeControlDependency(const string& host_compute_name,
                                         const xla::XlaOp& handle);
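
  // Sketch (hypothetical name "hc_first"): the handle recorded for one
  // HostCompute op can be fetched later and added as an input of another
  // HostCompute op to sequence them.
  //
  //   compiler.SetHostComputeControlDependency("hc_first", first_handle);
  //   xla::XlaOp dep;
  //   TF_RETURN_IF_ERROR(
  //       compiler.GetHostComputeControlDependency("hc_first", &dep));
  //   // Thread `dep` in as an input of the later HostCompute op.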

  const Options& options() const { return options_; }
  xla::Client* client() const { return options_.client; }
  FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }

  void PushNodeTokenMapping();
  Status PopNodeTokenMapping();
  Status SetNodeToken(const string& node_name, const xla::XlaOp& op);
  StatusOr<xla::XlaOp> GetNodeToken(const string& node_name);
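
  // Sketch of the intended nesting discipline (hypothetical node name
  // "send_op"): tokens recorded at one level are popped before returning.
  //
  //   compiler.PushNodeTokenMapping();
  //   TF_RETURN_IF_ERROR(compiler.SetNodeToken("send_op", token_out));
  //   TF_ASSIGN_OR_RETURN(xla::XlaOp token_in,
  //                       compiler.GetNodeToken("send_op"));
  //   TF_RETURN_IF_ERROR(compiler.PopNodeTokenMapping());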

  // Sets the function body `fbody` to the one registered as `function`.
  Status FindFunctionBody(const NameAttrList& function,
                          const FunctionBody** fbody,
                          const ConfigProto** config_proto = nullptr);

 private:
  // Returns the optimized graph object in this function body.
  std::unique_ptr<Graph> GetGraph(const FunctionBody* fbody);

  // Builds XLA computations for each of the arguments to the computation.
  // `args` are the arguments to the computation.
  Status BuildArguments(const Graph& graph,
                        const std::vector<XlaCompiler::Argument>& args,
                        bool use_tuple_arg, xla::XlaBuilder* builder,
                        XlaContext* context,
                        const std::map<int, xla::OpSharding>& arg_shardings,
                        std::vector<XlaExpression>* arg_expressions,
                        std::vector<int>* input_to_args,
                        std::vector<xla::Shape>* input_shapes,
                        bool is_entry_computation);

  // Graph compiler needs to know how to get an optimized graph from a
  // function body.
  friend class GraphCompiler;
  friend class XlaCompilerTest;

  Options options_;

  // Status set to non-OK in the constructor if initialization fails.
  Status initialization_status_;

  // Returns the next step sequence number.
  int64_t NextStepId();

  // Internal sequence number for steps executed on the compilation device.
  int64_t next_step_id_;

  XlaCompilationDevice* device_;  // Owned by device_mgr_
  StaticDeviceMgr device_mgr_;

  // To avoid copying the client's function library, use a local function
  // library and runtime for functions created as part of the functionalize
  // control flow transformation.
  std::unique_ptr<FunctionLibraryDefinition> local_flib_def_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> local_pflr_;

  FunctionLibraryRuntime* local_flib_runtime_;  // owned by local_pflr_.
  FunctionLibraryRuntime* flib_runtime_;        // owned by pflr_.

  struct SignatureHash {
    uint64 operator()(
        const std::pair<string, std::vector<Argument>>& signature) const;
  };

  std::unordered_map<std::pair<string, std::vector<Argument>>,
                     CompilationResult, SignatureHash>
      cache_;

  std::unordered_map<string, xla::ChannelHandle> channels_;

  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_sends_;
  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_recvs_;

  std::unordered_map<string, xla::XlaOp> host_compute_control_output_;

  // This is used to store a <node name, token output> mapping. Side-effecting
  // ops call SetNodeToken() to record their token output, so that later
  // side-effecting ops can use GetNodeToken() to fetch it and use it as a
  // token input.
  //
  // It's a stack because we need a mapping like this for each level of nested
  // CompileGraph() call. In CompileGraph(), we will push a new mapping onto
  // the stack, and pop the mapping before returning.
  std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_