xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/dot_as_convolution_util.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_
18 
19 #include <memory>
20 #include <optional>
21 #include <vector>
22 
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 
25 namespace xla {
26 namespace dot_as_convolution_util {
27 
28 // Classifies how a convolution spatial dimension that is effectively a batch
29 // dimension is encoded. Two "batch equivalent" encodings are currently
30 // detected and this enum tells them apart; kNone is presumably the value used
31 // when neither encoding matches (confirm against SpatialIsBatch()).
32 enum class SpatialBatchRepresentation {
33   kNone,
34   kUnpaddedVersion,  // Presumably the encoding without padding -- confirm.
35   kPaddedVersion,    // Presumably the encoding with low/high padding -- confirm.
36 };
37 
38 // Groups the dimensions of a convolution by logical role, so the convolution
39 // can be interpreted either as a dot or as a normal convolution.
40 struct DotConvolutionDimsInfo {
41   // Where one logical dimension (e.g., batch, contracting, non-contracting)
42   // lives in each operand and in the output. A field is -1 when that operand
43   // or the output does not have the logical dimension. Fields carry no default
44   // initializers, so a default-initialized DimNums holds indeterminate values.
45   struct DimNums {
46     int64_t lhs;     // Dimension number in the LHS operand, or -1.
47     int64_t rhs;     // Dimension number in the RHS operand, or -1.
48     int64_t output;  // Dimension number in the output, or -1.
49     // The corresponding spatial dimension in the convolution's config. Set to
50     // -1 if it's not mapped to a spatial dimension.
51     int64_t spatial_dim;
52   };
53   std::vector<DimNums> batch_dims;
54   std::vector<DimNums> contracting_dims;
55   std::vector<DimNums> lhs_non_contracting_dims;
56   std::vector<DimNums> rhs_non_contracting_dims;
57   std::vector<DimNums> conv_spatial_dims;  // Presumably empty for a pure dot.
58 };
59 
60 // Parses the convolution 'conv' and returns its DotConvolutionDimsInfo. If it
61 // can be interpreted as a dot, the result has no conv_spatial_dims.
62 DotConvolutionDimsInfo ParseConvolutionDimsInfo(const HloInstruction* conv);
63 
64 // Creates a sharded convolution instruction that can be interpreted as a dot.
65 // This is a utility for per-op partitioners. Errors are reported via StatusOr.
66 //  - 'conv' is the original convolution instruction.
67 //  - 'dot_dnums' is the result of ParseConvolutionDimsInfo() for 'conv'.
68 //  - 'sharded_lhs_hlo' and 'sharded_rhs_hlo' are sharded inputs for the result
69 //    convolution instruction.
70 StatusOr<std::unique_ptr<HloInstruction>>
71 CreateShardedConvForDotGeneralConvolution(
72     const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums,
73     HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo);
74 
75 // Checks whether a spatial dim is a parallel batch dimension.
76 // A parallel batch dimension in DotGeneral is represented as a spatial
77 // dimension with window size B (batch dimension size), stride B - 1, and base
78 // dilation B, or alternatively as window size B, stride B, padding low/high
79 // B - 1, base dilation B - 1 and window reversal.
80 SpatialBatchRepresentation SpatialIsBatch(int64_t lhs_spatial_size,
81                                           const WindowDimension& spatial_wd);
82 // Returns true if the spatial dimension represented by 'spatial_wd' is an LHS
83 // non-contracting dimension.
84 bool SpatialIsLhsNonContracting(int64_t rhs_spatial_size,
85                                 const WindowDimension& spatial_wd);
86 // Returns true if the spatial dimension represented by 'spatial_wd' is an RHS
87 // non-contracting dimension.
88 bool SpatialIsRhsNonContracting(int64_t lhs_spatial_size,
89                                 int64_t rhs_spatial_size,
90                                 const WindowDimension& spatial_wd);
91 // Returns true if the spatial dimension represented by 'spatial_wd' ends up
92 // being equivalent to a contracting dimension.
93 bool SpatialIsContracting(int64_t lhs_spatial_size, int64_t rhs_spatial_size,
94                           const WindowDimension& spatial_wd);
95 // Builds a DotConvolutionDimsInfo from the kDot instruction 'dot'. Every
96 // spatial_dim value in the result is set to -1.
97 DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot);
98 
99 }  // namespace dot_as_convolution_util
100 }  // namespace xla
101 
102 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_
103