xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/shape_inference.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Shape inference is used by the XLA service as the user builds up
17 // computation requests.
18 
19 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
20 #define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
21 
22 #include <vector>
23 
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 
31 namespace xla {
32 
33 // For a given operation and input shapes, infers what the resulting shape is
34 // for the operation. With this functionality, the user does not need to specify
35 // the expected result type for computations that are built up via the API --
36 // the shape that results from an operation is inferred. Some methods have
37 // overloads for inferring shape at the HLO level.
38 //
39 // TODO(b/73352135): Shape inference does not issue very good error messages, in
40 // part because HloInstruction::ToString() is not available since shape
41 // inference runs before the HloInstruction object is created. We need a
42 // solution for this.
43 class ShapeInference {
44  public:
45   // Infers the shape produced by applying the given unary operation to the
46   // given input shape.
47   static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode,
48                                            const Shape& shape);
49   static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode,
50                                            const HloInstruction* operand);
51 
52   // Infers the shape produced by applying the given binary operation to the
53   // given input shapes.
54   static StatusOr<Shape> InferBinaryOpShape(
55       HloOpcode opcode, const Shape& lhs, const Shape& rhs,
56       absl::Span<const int64_t> broadcast_dimensions);
57   static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode,
58                                             const HloInstruction* lhs,
59                                             const HloInstruction* rhs);
60 
61   // Infers the shape produced by applying the given ternary operation to the
62   // given input shapes.
63   static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode, const Shape& lhs,
64                                              const Shape& rhs,
65                                              const Shape& ehs);
66   static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode,
67                                              const HloInstruction* lhs,
68                                              const HloInstruction* rhs,
69                                              const HloInstruction* ehs);
70 
71   // Infers the shape produced by applying the given variadic operation to the
72   // given input operand shapes.
73   static StatusOr<Shape> InferVariadicOpShape(
74       HloOpcode opcode, absl::Span<const Shape* const> operand_shapes);
75   static StatusOr<Shape> InferVariadicOpShape(
76       HloOpcode opcode, absl::Span<const HloInstruction* const> operands);
77 
78   // Infers the shape produced by applying the given mapping computation shape
79   // to the given operand shapes.
80   static StatusOr<Shape> InferMapShape(
81       absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
82       absl::Span<const int64_t> dimensions);
83 
84   // Infers the shape produced by InferBatchNormTraining with the given
85   // operands.
86   static StatusOr<Shape> InferBatchNormTrainingShape(const Shape& operand_shape,
87                                                      const Shape& scale_shape,
88                                                      const Shape& offset_shape,
89                                                      int64_t feature_index);
90 
91   // Infers the shape produced by InferBatchNormInference with the given
92   // operands.
93   static StatusOr<Shape> InferBatchNormInferenceShape(
94       const Shape& operand_shape, const Shape& scale_shape,
95       const Shape& offset_shape, const Shape& mean_shape,
96       const Shape& variance_shape, int64_t feature_index);
97 
98   // Infers the shape produced by InferBatchNormGrad with the given operands.
99   static StatusOr<Shape> InferBatchNormGradShape(const Shape& operand_shape,
100                                                  const Shape& scale_shape,
101                                                  const Shape& mean_shape,
102                                                  const Shape& var_shape,
103                                                  const Shape& output_grad_shape,
104                                                  int64_t feature_index);
105 
106   // Infers the shape produced by applying the given convolutional filter (rhs)
107   // to lhs in the way specified by the fields on window. An optional
108   // preferred_element_type can be specified to upcast the element type.
109   static StatusOr<Shape> InferConvolveShape(
110       const Shape& lhs, const Shape& rhs, int64_t feature_group_count,
111       int64_t batch_group_count, const Window& window,
112       const ConvolutionDimensionNumbers& dimension_numbers,
113       std::optional<PrimitiveType> preferred_element_type);
114 
115   // Infers the shape produced by the given FFT type on the given operand.
116   static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type,
117                                        absl::Span<const int64_t> fft_length);
118 
119   // Infers the shape produced by the given triangular solve operation.
120   static StatusOr<Shape> InferTriangularSolveShape(
121       const Shape& a, const Shape& b, const TriangularSolveOptions& options);
122 
123   // Infers the shape produced by the given triangular solve operation.
124   static StatusOr<Shape> InferCholeskyShape(const Shape& a);
125 
126   // Infers the shape produced by an all-gather with the given operand shape,
127   // concat dimension, and shard count.
128   static StatusOr<Shape> InferAllGatherShape(
129       absl::Span<const Shape* const> operand_shapes,
130       int64_t all_gather_dimension, int64_t shard_count);
131 
132   // Infers the shape produced by an all-gather-start with the given operand
133   // shape, concat dimension, and shard count.
134   static StatusOr<Shape> InferAllGatherStartShape(
135       absl::Span<const Shape* const> operand_shapes,
136       int64_t all_gather_dimension, int64_t shard_count);
137 
138   // Infers the shape produced by an all-gather-done given a certain
139   // all-gather-start shape.
140   static StatusOr<Shape> InferAllGatherDoneShape(
141       const Shape& all_gather_start_shape);
142 
143   // Infers the shape produced by a cross replica sum with the given operand
144   // shapes.
145   static StatusOr<Shape> InferAllReduceShape(
146       absl::Span<const Shape* const> operand_shapes);
147 
148   // Infers the shape produced by a reduce-scatter with the given operand
149   // shape, scatter dimension, and shard count.
150   static StatusOr<Shape> InferReduceScatterShape(
151       absl::Span<const Shape* const> operand_shapes, int64_t scatter_dimension,
152       int64_t shard_count);
153 
154   // Infers the shape produced by a cross replica sum start.
155   static StatusOr<Shape> InferAllReduceStartShape(
156       absl::Span<const Shape* const> operand_shapes);
157 
158   // Infers the shape produced by a cross replica sum done.
159   static StatusOr<Shape> InferAllReduceDoneShape(const Shape& operand_shape);
160 
161   // Infers final shape of an Alltoall operation that is created by the xla
162   // builder.
163   static StatusOr<Shape> InferAllToAllShape(const Shape& shape,
164                                             int64_t split_dimension,
165                                             int64_t concat_dimension,
166                                             int64_t split_count);
167 
168   // Infers the shape of an HLO all-to-all instruction.
169   static StatusOr<Shape> InferAllToAllTupleShape(
170       absl::Span<const Shape* const> operand_shapes);
171 
172   // Infers the shape of a collective permute operation.
173   static StatusOr<Shape> InferCollectivePermuteShape(
174       absl::Span<const Shape* const> operand_shapes);
175 
176   // Infers the shape of a collective permute start operation.
177   static StatusOr<Shape> InferCollectivePermuteStartShape(
178       absl::Span<const Shape* const> operand_shapes);
179 
180   // Infers the shape of a collective permute operation.
181   static StatusOr<Shape> InferCollectivePermuteDoneShape(
182       const Shape& operand_shape);
183 
184   // Infers the shape produced by applying the given reduction computation
185   // shape to the given input operand shape.
186   //
187   // If pass_index is true, the reduce function is invoked with the element
188   // index as the leading parameter, and the program shape should match
189   // accordingly (or an error will result).
190   static StatusOr<Shape> InferReduceShape(
191       absl::Span<const Shape* const> arg_shapes,
192       absl::Span<const int64_t> dimensions_to_reduce,
193       const ProgramShape& to_apply);
194 
195   // Infers the shape produced by applying the given computation to the operand
196   // shape with the given window and stride dimensions.
197   static StatusOr<Shape> InferReduceWindowShape(
198       const Shape& operand_shape, const Shape& init_value, const Window& window,
199       const ProgramShape& to_apply_shape);
200   static StatusOr<Shape> InferReduceWindowShape(const Shape& operand_shape,
201                                                 const Shape& init_value,
202                                                 const Window& window);
203   static StatusOr<Shape> InferReduceWindowShape(
204       absl::Span<const Shape* const> operands,
205       absl::Span<const Shape* const> init_values, const Window& window,
206       const ProgramShape& to_apply_shape);
207 
208   static StatusOr<Shape> InferReduceWindowShape(
209       absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values,
210       const Window& window);
211 
212   // Infers the shape produced by scattering the given source shape to the
213   // selected indices of each window on the operand shape.
214   static StatusOr<Shape> InferSelectAndScatterShape(
215       const Shape& operand_shape, const ProgramShape& select_shape,
216       const Window& window, const Shape& source_shape,
217       const Shape& init_value_shape, const ProgramShape& scatter_shape);
218 
219   // Infers the shape produced by a reverse operation that reverses the order
220   // of the elements in the given dimensions.
221   static StatusOr<Shape> InferReverseShape(
222       const Shape& operand_shape, absl::Span<const int64_t> dimensions);
223 
224   // Infers the shape produced by a slice operation spanning from the starts to
225   // the limits in the original shape's dimensions.
226   //
227   // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16]
228   static StatusOr<Shape> InferSliceShape(const Shape& arg,
229                                          absl::Span<const int64_t> starts,
230                                          absl::Span<const int64_t> limits,
231                                          absl::Span<const int64_t> strides);
232 
233   // Infers the shape produced by a dynamic slice operation of size specified
234   // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'.
235   static StatusOr<Shape> InferDynamicSliceShape(
236       const Shape& operand_shape, absl::Span<const Shape> start_index_shapes,
237       absl::Span<const int64_t> slice_sizes, bool allow_scalar_indices = true);
238 
239   // Infers the shape produced by a dynamic update slice operation based
240   // on the shape of operand and update.
241   static StatusOr<Shape> InferDynamicUpdateSliceShape(
242       const Shape& operand_shape, const Shape& update_shape,
243       absl::Span<const Shape> start_index_shapes,
244       bool allow_scalar_indices = true);
245 
246   // Infers the shape produced by doing a compile-time-constant indexing into
247   // the given input shape. This is essential for operations on tuples, because
248   // it is impossible to infer the type that comes out of the tuple indexing if
249   // it is not a compile time constant.
250   static StatusOr<Shape> InferGetTupleElementShape(const Shape& arg,
251                                                    int64_t index);
252 
253   // Infers the shape produced from a while node. condition and body are the
254   // shapes of computations for the condition and the body of a while node, and
255   // init is the shape of data initially passed in to the body as an argument.
256   // The shapes must match; condition: T -> PRED, body: T -> T, init: T
257   static StatusOr<Shape> InferWhileShape(const ProgramShape& condition,
258                                          const ProgramShape& body,
259                                          const Shape& init);
260 
261   // Infers the shape produced by a predicated or indexed conditional operation.
262   static StatusOr<Shape> InferConditionalShape(
263       const Shape& branch_index,
264       absl::Span<const ProgramShape> branch_computations,
265       absl::Span<const Shape> branch_operands);
266 
267   // Infers the shape produced by a broadcast operation.
268   static StatusOr<Shape> InferBroadcastShape(
269       const Shape& operand, absl::Span<const int64_t> broadcast_sizes);
270 
271   // Checks whether the given parameters can form a broadcast. Returns the same
272   // output_shape if it's legal.
273   static StatusOr<Shape> InferBroadcastShape(
274       const Shape& operand_shape, const Shape& output_shape,
275       absl::Span<const int64_t> broadcast_dimensions);
276 
277   // Infers the shape produced by a reshape operation from the element type of
278   // its operand and the new dimension sizes specified.
279   static StatusOr<Shape> InferReshapeShape(const Shape& operand,
280                                            absl::Span<const int64_t> dimensions,
281                                            absl::Span<const int64_t> new_sizes,
282                                            int64_t inferred_dimension);
283 
284   // Infers the shape produced by a dynamic reshape operation from the element
285   // type of its operand and the new dimension sizes specified. The result shape
286   // will have dynamic dimensions as specific in `dim_is_dynamic` and bound
287   // `new_size_bounds`.
288   static StatusOr<Shape> InferDynamicReshapeShape(
289       const Shape& operand, absl::Span<const Shape* const> dim_size_shapes,
290       absl::Span<const int64_t> new_size_bounds,
291       const std::vector<bool>& dims_are_dynamic);
292 
293   // Infers the shape produced by a transpose operation from the element type of
294   // its operand and its dimensions field.
295   static StatusOr<Shape> InferTransposeShape(
296       const Shape& operand, absl::Span<const int64_t> dimensions);
297 
298   // Helper that infers the shape produced by performing a concatenate operation
299   // with the given operand shapes.
300   static StatusOr<Shape> InferConcatOpShape(
301       absl::Span<const Shape* const> arg_shapes, int64_t dimension);
302 
303   // Helper that validates the given operand shape can be converted to the
304   // target output_shape via a convert instruction -- the requirement is that
305   // the shape is identical except for the element type.
306   static StatusOr<Shape> InferConvertShape(const Shape& operand_shape,
307                                            PrimitiveType new_element_type);
308 
309   // Helper that validates the given operand shape can be bitcast converted to
310   // the target output_shape via a bitcast convert instruction -- the
311   // requirement is that the shape is identical except for the element type and
312   // the element types have identical bit-widths.
313   static StatusOr<Shape> InferBitcastConvertShape(
314       const Shape& operand_shape, PrimitiveType new_element_type);
315 
316   // Helper that validates the input data type for a reduce-precision operation,
317   // and returns the result shape.
318   static StatusOr<Shape> InferReducePrecisionShape(const Shape& operand_shape,
319                                                    const int exponent_bits,
320                                                    const int mantissa_bits);
321 
322   // Helper that infers the shape produced by a pad operation based on the
323   // padding configuration.
324   static StatusOr<Shape> InferPadShape(const Shape& operand_shape,
325                                        const Shape& padding_value_shape,
326                                        const PaddingConfig& padding_config);
327 
328   // Helper that validates the given arg_shapes are compatible with the shape of
329   // the to_apply parameters, and returns the to_apply result shape.
330   static StatusOr<Shape> InferCallShape(
331       absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply);
332 
333   // Helper that infers the shape produced by performing a dot operation with
334   // the given LHS and RHS shapes. An optional preferred_element_type can be
335   // specified to upcast the element type.
336   static StatusOr<Shape> InferDotOpShape(
337       const Shape& lhs, const Shape& rhs,
338       const DotDimensionNumbers& dimension_numbers,
339       std::optional<PrimitiveType> preferred_element_type);
340 
341   // Helper that infers the shape of the tensor produced by a gather operation
342   // with the given input shape, gather indices shape and gather dimension
343   // numbers.
344   static StatusOr<Shape> InferGatherShape(
345       const Shape& input_shape, const Shape& start_indices_shape,
346       const GatherDimensionNumbers& gather_dim_numbers,
347       absl::Span<const int64_t> slice_sizes);
348 
349   // Helper that validates the given input shape, scatter indices shape, updates
350   // shape, and scatter dimension numbers that constitute a scatter operation,
351   // and returns the result shape of the scatter operation.
352   static StatusOr<Shape> InferScatterShape(
353       absl::Span<const Shape* const> arg_shapes,
354       const ProgramShape& to_apply_shape,
355       const ScatterDimensionNumbers& scatter_dim_numbers);
356 
357   // Helper that validates the given input shape to GetDimensionSize.
358   static StatusOr<Shape> InferGetDimensionSizeShape(const Shape& shape,
359                                                     int64_t dimension);
360 
361   // Helper that validates the given input shape to SetDimensionSize.
362   static StatusOr<Shape> InferSetDimensionSizeShape(const Shape& operand_shape,
363                                                     const Shape& val_shape,
364                                                     int64_t dimension);
365 
366   // Helper function for creating a Window proto from user-supplied data.
367   // Returns error if the user-supplied data was invalid.
368   static StatusOr<Window> InferWindowFromDimensions(
369       absl::Span<const int64_t> window_dimensions,
370       absl::Span<const int64_t> window_strides,
371       absl::Span<const std::pair<int64_t, int64_t>> padding,
372       absl::Span<const int64_t> lhs_dilation,
373       absl::Span<const int64_t> rhs_dilation,
374       std::optional<std::vector<bool>> window_reversal = std::nullopt);
375 
376  private:
377   // Helper that infers the shape produced by performing an element-wise binary
378   // operation with the given LHS and RHS shapes.
379   // Note: By "element-wise" we mean operations that look at a single element in
380   // the LHS and a single element in the RHS to produce a single output element,
381   // even in the presence of broadcasting of one of the operands over the other.
382   static StatusOr<Shape> InferElementwiseBinaryOpShape(
383       HloOpcode operation, const Shape& lhs, const Shape& rhs,
384       absl::Span<const int64_t> broadcast_dimensions);
385 
386   // Helper for inferring the shape of Clamp ops.
387   static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand,
388                                          const Shape& max);
389 
390   // Helper for inferring the shape of Select ops.
391   static StatusOr<Shape> InferSelectShape(const Shape& pred,
392                                           const Shape& on_true,
393                                           const Shape& on_false);
394 
395   // Helper for inferring shapes of binary operations which use degenerate
396   // dimension broadcasting (a dimension of size 1 in one operand is broadcast
397   // up to match the size of the dimension in the other operand).
398   static StatusOr<Shape> InferDegenerateDimensionBroadcastShape(
399       HloOpcode operation, const Shape& lhs, const Shape& rhs);
400 
401   // Helper for inferring shapes of binary operations using "InDim"
402   // broadcasting. This is the broadcasting used in the *InDim binary operations
403   // (for example ComputationBuilder::AddInDim). smaller_shape must be a
404   // lower-rank shape than larger_shape. Returns the shape that the
405   // smaller_shape is broadcast to.
406   static StatusOr<Shape> InferInDimBroadcastShape(
407       const Shape& smaller_shape, const Shape& larger_shape,
408       absl::Span<const int64_t> broadcast_dimensions);
409 
410   ShapeInference(const ShapeInference&) = delete;
411   ShapeInference& operator=(const ShapeInference&) = delete;
412 };
413 
414 }  // namespace xla
415 
416 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
417