/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Shape inference is used by the XLA service as the user builds up
// computation requests.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_

#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

// For a given operation and input shapes, infers what the resulting shape is
// for the operation. With this functionality, the user does not need to specify
// the expected result type for computations that are built up via the API --
// the shape that results from an operation is inferred. Some methods have
// overloads for inferring shape at the HLO level.
//
// TODO(b/73352135): Shape inference does not issue very good error messages, in
// part because HloInstruction::ToString() is not available since shape
// inference runs before the HloInstruction object is created. We need a
// solution for this.
43 class ShapeInference { 44 public: 45 // Infers the shape produced by applying the given unary operation to the 46 // given input shape. 47 static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode, 48 const Shape& shape); 49 static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode, 50 const HloInstruction* operand); 51 52 // Infers the shape produced by applying the given binary operation to the 53 // given input shapes. 54 static StatusOr<Shape> InferBinaryOpShape( 55 HloOpcode opcode, const Shape& lhs, const Shape& rhs, 56 absl::Span<const int64_t> broadcast_dimensions); 57 static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode, 58 const HloInstruction* lhs, 59 const HloInstruction* rhs); 60 61 // Infers the shape produced by applying the given ternary operation to the 62 // given input shapes. 63 static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode, const Shape& lhs, 64 const Shape& rhs, 65 const Shape& ehs); 66 static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode, 67 const HloInstruction* lhs, 68 const HloInstruction* rhs, 69 const HloInstruction* ehs); 70 71 // Infers the shape produced by applying the given variadic operation to the 72 // given input operand shapes. 73 static StatusOr<Shape> InferVariadicOpShape( 74 HloOpcode opcode, absl::Span<const Shape* const> operand_shapes); 75 static StatusOr<Shape> InferVariadicOpShape( 76 HloOpcode opcode, absl::Span<const HloInstruction* const> operands); 77 78 // Infers the shape produced by applying the given mapping computation shape 79 // to the given operand shapes. 80 static StatusOr<Shape> InferMapShape( 81 absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply, 82 absl::Span<const int64_t> dimensions); 83 84 // Infers the shape produced by InferBatchNormTraining with the given 85 // operands. 
86 static StatusOr<Shape> InferBatchNormTrainingShape(const Shape& operand_shape, 87 const Shape& scale_shape, 88 const Shape& offset_shape, 89 int64_t feature_index); 90 91 // Infers the shape produced by InferBatchNormInference with the given 92 // operands. 93 static StatusOr<Shape> InferBatchNormInferenceShape( 94 const Shape& operand_shape, const Shape& scale_shape, 95 const Shape& offset_shape, const Shape& mean_shape, 96 const Shape& variance_shape, int64_t feature_index); 97 98 // Infers the shape produced by InferBatchNormGrad with the given operands. 99 static StatusOr<Shape> InferBatchNormGradShape(const Shape& operand_shape, 100 const Shape& scale_shape, 101 const Shape& mean_shape, 102 const Shape& var_shape, 103 const Shape& output_grad_shape, 104 int64_t feature_index); 105 106 // Infers the shape produced by applying the given convolutional filter (rhs) 107 // to lhs in the way specified by the fields on window. An optional 108 // preferred_element_type can be specified to upcast the element type. 109 static StatusOr<Shape> InferConvolveShape( 110 const Shape& lhs, const Shape& rhs, int64_t feature_group_count, 111 int64_t batch_group_count, const Window& window, 112 const ConvolutionDimensionNumbers& dimension_numbers, 113 std::optional<PrimitiveType> preferred_element_type); 114 115 // Infers the shape produced by the given FFT type on the given operand. 116 static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type, 117 absl::Span<const int64_t> fft_length); 118 119 // Infers the shape produced by the given triangular solve operation. 120 static StatusOr<Shape> InferTriangularSolveShape( 121 const Shape& a, const Shape& b, const TriangularSolveOptions& options); 122 123 // Infers the shape produced by the given triangular solve operation. 124 static StatusOr<Shape> InferCholeskyShape(const Shape& a); 125 126 // Infers the shape produced by an all-gather with the given operand shape, 127 // concat dimension, and shard count. 
128 static StatusOr<Shape> InferAllGatherShape( 129 absl::Span<const Shape* const> operand_shapes, 130 int64_t all_gather_dimension, int64_t shard_count); 131 132 // Infers the shape produced by an all-gather-start with the given operand 133 // shape, concat dimension, and shard count. 134 static StatusOr<Shape> InferAllGatherStartShape( 135 absl::Span<const Shape* const> operand_shapes, 136 int64_t all_gather_dimension, int64_t shard_count); 137 138 // Infers the shape produced by an all-gather-done given a certain 139 // all-gather-start shape. 140 static StatusOr<Shape> InferAllGatherDoneShape( 141 const Shape& all_gather_start_shape); 142 143 // Infers the shape produced by a cross replica sum with the given operand 144 // shapes. 145 static StatusOr<Shape> InferAllReduceShape( 146 absl::Span<const Shape* const> operand_shapes); 147 148 // Infers the shape produced by a reduce-scatter with the given operand 149 // shape, scatter dimension, and shard count. 150 static StatusOr<Shape> InferReduceScatterShape( 151 absl::Span<const Shape* const> operand_shapes, int64_t scatter_dimension, 152 int64_t shard_count); 153 154 // Infers the shape produced by a cross replica sum start. 155 static StatusOr<Shape> InferAllReduceStartShape( 156 absl::Span<const Shape* const> operand_shapes); 157 158 // Infers the shape produced by a cross replica sum done. 159 static StatusOr<Shape> InferAllReduceDoneShape(const Shape& operand_shape); 160 161 // Infers final shape of an Alltoall operation that is created by the xla 162 // builder. 163 static StatusOr<Shape> InferAllToAllShape(const Shape& shape, 164 int64_t split_dimension, 165 int64_t concat_dimension, 166 int64_t split_count); 167 168 // Infers the shape of an HLO all-to-all instruction. 169 static StatusOr<Shape> InferAllToAllTupleShape( 170 absl::Span<const Shape* const> operand_shapes); 171 172 // Infers the shape of a collective permute operation. 
173 static StatusOr<Shape> InferCollectivePermuteShape( 174 absl::Span<const Shape* const> operand_shapes); 175 176 // Infers the shape of a collective permute start operation. 177 static StatusOr<Shape> InferCollectivePermuteStartShape( 178 absl::Span<const Shape* const> operand_shapes); 179 180 // Infers the shape of a collective permute operation. 181 static StatusOr<Shape> InferCollectivePermuteDoneShape( 182 const Shape& operand_shape); 183 184 // Infers the shape produced by applying the given reduction computation 185 // shape to the given input operand shape. 186 // 187 // If pass_index is true, the reduce function is invoked with the element 188 // index as the leading parameter, and the program shape should match 189 // accordingly (or an error will result). 190 static StatusOr<Shape> InferReduceShape( 191 absl::Span<const Shape* const> arg_shapes, 192 absl::Span<const int64_t> dimensions_to_reduce, 193 const ProgramShape& to_apply); 194 195 // Infers the shape produced by applying the given computation to the operand 196 // shape with the given window and stride dimensions. 197 static StatusOr<Shape> InferReduceWindowShape( 198 const Shape& operand_shape, const Shape& init_value, const Window& window, 199 const ProgramShape& to_apply_shape); 200 static StatusOr<Shape> InferReduceWindowShape(const Shape& operand_shape, 201 const Shape& init_value, 202 const Window& window); 203 static StatusOr<Shape> InferReduceWindowShape( 204 absl::Span<const Shape* const> operands, 205 absl::Span<const Shape* const> init_values, const Window& window, 206 const ProgramShape& to_apply_shape); 207 208 static StatusOr<Shape> InferReduceWindowShape( 209 absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values, 210 const Window& window); 211 212 // Infers the shape produced by scattering the given source shape to the 213 // selected indices of each window on the operand shape. 
214 static StatusOr<Shape> InferSelectAndScatterShape( 215 const Shape& operand_shape, const ProgramShape& select_shape, 216 const Window& window, const Shape& source_shape, 217 const Shape& init_value_shape, const ProgramShape& scatter_shape); 218 219 // Infers the shape produced by a reverse operation that reverses the order 220 // of the elements in the given dimensions. 221 static StatusOr<Shape> InferReverseShape( 222 const Shape& operand_shape, absl::Span<const int64_t> dimensions); 223 224 // Infers the shape produced by a slice operation spanning from the starts to 225 // the limits in the original shape's dimensions. 226 // 227 // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16] 228 static StatusOr<Shape> InferSliceShape(const Shape& arg, 229 absl::Span<const int64_t> starts, 230 absl::Span<const int64_t> limits, 231 absl::Span<const int64_t> strides); 232 233 // Infers the shape produced by a dynamic slice operation of size specified 234 // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. 235 static StatusOr<Shape> InferDynamicSliceShape( 236 const Shape& operand_shape, absl::Span<const Shape> start_index_shapes, 237 absl::Span<const int64_t> slice_sizes, bool allow_scalar_indices = true); 238 239 // Infers the shape produced by a dynamic update slice operation based 240 // on the shape of operand and update. 241 static StatusOr<Shape> InferDynamicUpdateSliceShape( 242 const Shape& operand_shape, const Shape& update_shape, 243 absl::Span<const Shape> start_index_shapes, 244 bool allow_scalar_indices = true); 245 246 // Infers the shape produced by doing a compile-time-constant indexing into 247 // the given input shape. This is essential for operations on tuples, because 248 // it is impossible to infer the type that comes out of the tuple indexing if 249 // it is not a compile time constant. 
250 static StatusOr<Shape> InferGetTupleElementShape(const Shape& arg, 251 int64_t index); 252 253 // Infers the shape produced from a while node. condition and body are the 254 // shapes of computations for the condition and the body of a while node, and 255 // init is the shape of data initially passed in to the body as an argument. 256 // The shapes must match; condition: T -> PRED, body: T -> T, init: T 257 static StatusOr<Shape> InferWhileShape(const ProgramShape& condition, 258 const ProgramShape& body, 259 const Shape& init); 260 261 // Infers the shape produced by a predicated or indexed conditional operation. 262 static StatusOr<Shape> InferConditionalShape( 263 const Shape& branch_index, 264 absl::Span<const ProgramShape> branch_computations, 265 absl::Span<const Shape> branch_operands); 266 267 // Infers the shape produced by a broadcast operation. 268 static StatusOr<Shape> InferBroadcastShape( 269 const Shape& operand, absl::Span<const int64_t> broadcast_sizes); 270 271 // Checks whether the given parameters can form a broadcast. Returns the same 272 // output_shape if it's legal. 273 static StatusOr<Shape> InferBroadcastShape( 274 const Shape& operand_shape, const Shape& output_shape, 275 absl::Span<const int64_t> broadcast_dimensions); 276 277 // Infers the shape produced by a reshape operation from the element type of 278 // its operand and the new dimension sizes specified. 279 static StatusOr<Shape> InferReshapeShape(const Shape& operand, 280 absl::Span<const int64_t> dimensions, 281 absl::Span<const int64_t> new_sizes, 282 int64_t inferred_dimension); 283 284 // Infers the shape produced by a dynamic reshape operation from the element 285 // type of its operand and the new dimension sizes specified. The result shape 286 // will have dynamic dimensions as specific in `dim_is_dynamic` and bound 287 // `new_size_bounds`. 
288 static StatusOr<Shape> InferDynamicReshapeShape( 289 const Shape& operand, absl::Span<const Shape* const> dim_size_shapes, 290 absl::Span<const int64_t> new_size_bounds, 291 const std::vector<bool>& dims_are_dynamic); 292 293 // Infers the shape produced by a transpose operation from the element type of 294 // its operand and its dimensions field. 295 static StatusOr<Shape> InferTransposeShape( 296 const Shape& operand, absl::Span<const int64_t> dimensions); 297 298 // Helper that infers the shape produced by performing a concatenate operation 299 // with the given operand shapes. 300 static StatusOr<Shape> InferConcatOpShape( 301 absl::Span<const Shape* const> arg_shapes, int64_t dimension); 302 303 // Helper that validates the given operand shape can be converted to the 304 // target output_shape via a convert instruction -- the requirement is that 305 // the shape is identical except for the element type. 306 static StatusOr<Shape> InferConvertShape(const Shape& operand_shape, 307 PrimitiveType new_element_type); 308 309 // Helper that validates the given operand shape can be bitcast converted to 310 // the target output_shape via a bitcast convert instruction -- the 311 // requirement is that the shape is identical except for the element type and 312 // the element types have identical bit-widths. 313 static StatusOr<Shape> InferBitcastConvertShape( 314 const Shape& operand_shape, PrimitiveType new_element_type); 315 316 // Helper that validates the input data type for a reduce-precision operation, 317 // and returns the result shape. 318 static StatusOr<Shape> InferReducePrecisionShape(const Shape& operand_shape, 319 const int exponent_bits, 320 const int mantissa_bits); 321 322 // Helper that infers the shape produced by a pad operation based on the 323 // padding configuration. 
324 static StatusOr<Shape> InferPadShape(const Shape& operand_shape, 325 const Shape& padding_value_shape, 326 const PaddingConfig& padding_config); 327 328 // Helper that validates the given arg_shapes are compatible with the shape of 329 // the to_apply parameters, and returns the to_apply result shape. 330 static StatusOr<Shape> InferCallShape( 331 absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply); 332 333 // Helper that infers the shape produced by performing a dot operation with 334 // the given LHS and RHS shapes. An optional preferred_element_type can be 335 // specified to upcast the element type. 336 static StatusOr<Shape> InferDotOpShape( 337 const Shape& lhs, const Shape& rhs, 338 const DotDimensionNumbers& dimension_numbers, 339 std::optional<PrimitiveType> preferred_element_type); 340 341 // Helper that infers the shape of the tensor produced by a gather operation 342 // with the given input shape, gather indices shape and gather dimension 343 // numbers. 344 static StatusOr<Shape> InferGatherShape( 345 const Shape& input_shape, const Shape& start_indices_shape, 346 const GatherDimensionNumbers& gather_dim_numbers, 347 absl::Span<const int64_t> slice_sizes); 348 349 // Helper that validates the given input shape, scatter indices shape, updates 350 // shape, and scatter dimension numbers that constitute a scatter operation, 351 // and returns the result shape of the scatter operation. 352 static StatusOr<Shape> InferScatterShape( 353 absl::Span<const Shape* const> arg_shapes, 354 const ProgramShape& to_apply_shape, 355 const ScatterDimensionNumbers& scatter_dim_numbers); 356 357 // Helper that validates the given input shape to GetDimensionSize. 358 static StatusOr<Shape> InferGetDimensionSizeShape(const Shape& shape, 359 int64_t dimension); 360 361 // Helper that validates the given input shape to SetDimensionSize. 
362 static StatusOr<Shape> InferSetDimensionSizeShape(const Shape& operand_shape, 363 const Shape& val_shape, 364 int64_t dimension); 365 366 // Helper function for creating a Window proto from user-supplied data. 367 // Returns error if the user-supplied data was invalid. 368 static StatusOr<Window> InferWindowFromDimensions( 369 absl::Span<const int64_t> window_dimensions, 370 absl::Span<const int64_t> window_strides, 371 absl::Span<const std::pair<int64_t, int64_t>> padding, 372 absl::Span<const int64_t> lhs_dilation, 373 absl::Span<const int64_t> rhs_dilation, 374 std::optional<std::vector<bool>> window_reversal = std::nullopt); 375 376 private: 377 // Helper that infers the shape produced by performing an element-wise binary 378 // operation with the given LHS and RHS shapes. 379 // Note: By "element-wise" we mean operations that look at a single element in 380 // the LHS and a single element in the RHS to produce a single output element, 381 // even in the presence of broadcasting of one of the operands over the other. 382 static StatusOr<Shape> InferElementwiseBinaryOpShape( 383 HloOpcode operation, const Shape& lhs, const Shape& rhs, 384 absl::Span<const int64_t> broadcast_dimensions); 385 386 // Helper for inferring the shape of Clamp ops. 387 static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand, 388 const Shape& max); 389 390 // Helper for inferring the shape of Select ops. 391 static StatusOr<Shape> InferSelectShape(const Shape& pred, 392 const Shape& on_true, 393 const Shape& on_false); 394 395 // Helper for inferring shapes of binary operations which use degenerate 396 // dimension broadcasting (a dimension of size 1 in one operand is broadcast 397 // up to match the size of the dimension in the other operand). 
398 static StatusOr<Shape> InferDegenerateDimensionBroadcastShape( 399 HloOpcode operation, const Shape& lhs, const Shape& rhs); 400 401 // Helper for inferring shapes of binary operations using "InDim" 402 // broadcasting. This is the broadcasting used in the *InDim binary operations 403 // (for example ComputationBuilder::AddInDim). smaller_shape must be a 404 // lower-rank shape than larger_shape. Returns the shape that the 405 // smaller_shape is broadcast to. 406 static StatusOr<Shape> InferInDimBroadcastShape( 407 const Shape& smaller_shape, const Shape& larger_shape, 408 absl::Span<const int64_t> broadcast_dimensions); 409 410 ShapeInference(const ShapeInference&) = delete; 411 ShapeInference& operator=(const ShapeInference&) = delete; 412 }; 413 414 } // namespace xla 415 416 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_ 417