xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/xla_helpers.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines helper routines for the XLA device.
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_

#include <functional>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/mlir/xla/layout_util.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"

32 namespace tensorflow {
33 
34 using XlaLayoutPreference = mlir::XlaLayoutPreference;
35 
36 // Helper methods for building XLA computations.
37 class XlaHelpers {
38  public:
39   // Returns a handle representing the zero value of a scalar
40   // element of data_type.
41   static xla::XlaOp Zero(xla::XlaBuilder* b, DataType data_type);
42 
43   // Returns a handle representing the one value of a scalar
44   // element of data_type.
45   static xla::XlaOp One(xla::XlaBuilder* b, DataType data_type);
46 
47   // Returns a handle representing the given value of an integer scalar
48   // element of data_type.
49   // Note that unlike One and Zero, does not work on boolean types.
50   static xla::XlaOp IntegerLiteral(xla::XlaBuilder* b, DataType data_type,
51                                    int64_t value);
52 
53   // Returns a handle representing the given value of a floating-point scalar
54   // element of data_type.
55   static xla::XlaOp FloatLiteral(xla::XlaBuilder* b, DataType data_type,
56                                  double value);
57 
58   // Reshapes literal 'input' to have 'shape'. Both the original shape and
59   // 'shape' must contain the same number of elements.
60   static Status ReshapeLiteral(const xla::Literal& input,
61                                absl::Span<const int64_t> shape,
62                                xla::Literal* output);
63 
64   // Converts `indices` into a one-hot representation. `depth` is the size
65   // of the new axis to add. `axis` is the position at which to add the new
66   // axis. `indices_shape` is the shape of `indices`. `on_value` and
67   // `off_value` represent the values to use for the on and off positions,
68   // respectively.
69   static Status OneHot(xla::XlaBuilder* builder, int64_t depth, int axis,
70                        DataType index_type, const TensorShape& indices_shape,
71                        const xla::XlaOp& indices, const xla::XlaOp& on_value,
72                        const xla::XlaOp& off_value, xla::XlaOp* one_hot);
73 
74   // Certain DataTypes should use increased precision DataTypes when performing
75   // reductions.  This function remaps a given DataType to a higher precision
76   // DataType if needed.
77   static DataType SumAccumulationType(const DataType& dtype);
78 
79   // A helper for creating a ConvertElementType xla op given a DataType rather
80   // than the xla::PrimitiveType.
81   static xla::XlaOp ConvertElementType(const xla::XlaOp& operand,
82                                        const DataType new_element_type);
83 
84   typedef std::function<StatusOr<xla::Shape>(const TensorShape&, DataType, bool,
85                                              XlaLayoutPreference)>
86       ShapeRepresentationFn;
87 };
88 
89 // Creates an identity shape representation function.
90 XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn();
91 
92 struct XlaOutputDescription {
93   // Type and shape of the output. The shape is the unflattened shape.
94   // When `type` is DT_RESOURCE, `shape` is the shape of the resource
95   // variable's value.
96   DataType type;
97   TensorShape shape;
98 
99   // Constant output value, if known to be constant at JIT compilation time.
100   // 'Tensor' is in host memory.
101   bool is_constant = false;
102   Tensor constant_value;
103 
104   // When this output is a resource, i.e. `type == DT_RESOURCE`, this is
105   // the index of the input that contains the resource.
106   int input_index;
107 
108   // Whether this output is a TensorList.
109   bool is_tensor_list = false;
110 };
111 
112 // Describes a variable write side effect of the computation.
113 struct XlaResourceUpdate {
114   // Index of the input that contains the variable resource to write to.
115   int input_index;
116 
117   // Type and shape of the tensor to be written back.
118   // The `shape` field has the same meaning as the Argument::shape field.
119   DataType type;
120   TensorShape shape;
121 
122   // Was the value of the variable modified by the computation?
123   // (Always true, unless `return_updated_values_for_all_resources` is true.)
124   bool modified;
125 
126   // If the resource is a TensorArray, the set of gradients read or written.
127   std::set<string> tensor_array_gradients_accessed;
128 };
129 
130 struct XlaCompilationResult {
131   // Vector that maps from the parameters of the XLA computation to their
132   // original argument positions. To handle compile-time constant inputs, the
133   // parameters to the XLA computation may be a subset of the original
134   // arguments. The relative ordering of parameters are maintained.
135   std::vector<int> input_mapping;
136 
137   // Input shapes of the computation. If we are flattening inputs, these are
138   // the flattened shapes.
139   std::vector<xla::Shape> xla_input_shapes;
140 
141   // Output shape in XLA format. The output shape is always a tuple. If we
142   // are flattening outputs, these are the flattened shapes.
143   xla::Shape xla_output_shape;
144 
145   // TensorFlow shapes of outputs, together with the values of any
146   // constant arguments. Vector indexed by Tensorflow _Retval number,
147   // containing both constant and non-constant results.
148   std::vector<XlaOutputDescription> outputs;
149 
150   // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
151   // matching RecvAtHost/SendFromHost Ops in the outer graph.
152   tf2xla::HostComputeMetadata host_compute_metadata;
153 
154   // Resources whose values were updated by the computation, ordered
155   // by return value position (which is the same as the order the resources
156   // were passed as arguments). Resource updates follow the non-constant
157   // results in the outputs of XLA computation.
158   std::vector<XlaResourceUpdate> resource_updates;
159 
160   // The XLA computation built from the tensorflow subgraph.
161   std::shared_ptr<xla::XlaComputation> computation;
162 
163   // Meta-info about encountered collective ops.
164   struct CollectiveInfo {
165     int group_key;
166     int group_size;
167     int next_id;
168 
169     template <typename H>
AbslHashValueXlaCompilationResult::CollectiveInfo170     friend H AbslHashValue(H h, const CollectiveInfo& info) {
171       return H::combine(std::move(h), info.group_key, info.group_size,
172                         info.next_id);
173     }
174 
175     friend bool operator==(const CollectiveInfo& lhs,
176                            const CollectiveInfo& rhs) {
177       return lhs.group_key == rhs.group_key &&
178              lhs.group_size == rhs.group_size && lhs.next_id == rhs.next_id;
179     }
180   };
181 
182   // Information of the collectives encountered during the translation.
183   std::optional<CollectiveInfo> collective_info;
184 };
185 
186 // Resolves the device assignment based on CollectiveInfo.
187 // CollectiveInfo records collective ops in the cluster. Note that
188 // this relies on a rendezvous and blocks until all replicas are there.
189 //
190 // Takes several extra configuration objects by reference since
191 // xla::ExecutableRunOptions does not take ownership; these are configured and
192 // bundled into `run_options` if applicable.
193 Status ResolveDeviceAssignment(
194     OpKernelContext* ctx,
195     const XlaCompilationResult::CollectiveInfo& collective_info,
196     xla::ExecutableRunOptions& run_options,
197     xla::DeviceAssignment& device_assignment,
198     xla::gpu::GpuExecutableRunOptions& gpu_options);
199 
200 }  // end namespace tensorflow
201 
202 #endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_
203