xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/dot_as_convolution_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
17 
18 #include <optional>
19 
20 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/service/shape_inference.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 
25 namespace xla {
26 namespace dot_as_convolution_util {
27 
SpatialIsBatch(int64_t lhs_spatial_size,const WindowDimension & spatial_wd)28 SpatialBatchRepresentation SpatialIsBatch(int64_t lhs_spatial_size,
29                                           const WindowDimension& spatial_wd) {
30   if (lhs_spatial_size == spatial_wd.size() &&
31       lhs_spatial_size == spatial_wd.base_dilation() &&
32       ((std::max<int64_t>(1, lhs_spatial_size - 1) == spatial_wd.stride() &&
33         spatial_wd.window_dilation() == 1) ||
34        (std::max<int64_t>(1, lhs_spatial_size - 1) ==
35             spatial_wd.window_dilation() &&
36         spatial_wd.stride() == 1)) &&
37       spatial_wd.padding_high() == 0 && spatial_wd.padding_low() == 0 &&
38       !spatial_wd.window_reversal()) {
39     return SpatialBatchRepresentation::kUnpaddedVersion;
40   } else if (lhs_spatial_size == spatial_wd.size() &&
41              spatial_wd.padding_high() == lhs_spatial_size - 1 &&
42              spatial_wd.padding_low() == lhs_spatial_size - 1 &&
43              spatial_wd.window_reversal() &&
44              spatial_wd.window_dilation() == 1 &&
45              spatial_wd.stride() == lhs_spatial_size &&
46              spatial_wd.base_dilation() == lhs_spatial_size - 1) {
47     return SpatialBatchRepresentation::kPaddedVersion;
48   }
49   return SpatialBatchRepresentation::kNone;
50 }
51 
SpatialIsLhsNonContracting(int64_t rhs_spatial_size,const WindowDimension & spatial_wd)52 bool SpatialIsLhsNonContracting(int64_t rhs_spatial_size,
53                                 const WindowDimension& spatial_wd) {
54   return spatial_wd.stride() == 1 && spatial_wd.window_dilation() == 1 &&
55          spatial_wd.base_dilation() == 1 && rhs_spatial_size == 1 &&
56          spatial_wd.size() == 1 && spatial_wd.padding_high() == 0 &&
57          spatial_wd.padding_low() == 0 && !spatial_wd.window_reversal();
58 }
59 
SpatialIsRhsNonContracting(int64_t lhs_spatial_size,int64_t rhs_spatial_size,const WindowDimension & spatial_wd)60 bool SpatialIsRhsNonContracting(int64_t lhs_spatial_size,
61                                 int64_t rhs_spatial_size,
62                                 const WindowDimension& spatial_wd) {
63   return spatial_wd.stride() == 1 && spatial_wd.window_dilation() == 1 &&
64          spatial_wd.base_dilation() == 1 && lhs_spatial_size == 1 &&
65          spatial_wd.size() == rhs_spatial_size &&
66          spatial_wd.padding_high() == rhs_spatial_size - 1 &&
67          spatial_wd.padding_low() == rhs_spatial_size - 1 &&
68          spatial_wd.window_reversal();
69 }
70 
SpatialIsContracting(int64_t lhs_spatial_size,int64_t rhs_spatial_size,const WindowDimension & spatial_wd)71 bool SpatialIsContracting(int64_t lhs_spatial_size, int64_t rhs_spatial_size,
72                           const WindowDimension& spatial_wd) {
73   return lhs_spatial_size == spatial_wd.size() &&
74          spatial_wd.base_dilation() == 1 && spatial_wd.window_dilation() == 1 &&
75          spatial_wd.padding_high() == 0 && spatial_wd.padding_low() == 0 &&
76          !spatial_wd.window_reversal();
77 }
78 
ParseConvolutionDimsInfo(const HloInstruction * conv)79 /* static */ DotConvolutionDimsInfo ParseConvolutionDimsInfo(
80     const HloInstruction* conv) {
81   CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
82   const auto& conv_dims = conv->convolution_dimension_numbers();
83   DotConvolutionDimsInfo dims;
84   dims.lhs_non_contracting_dims.push_back(
85       {conv_dims.input_batch_dimension(), -1,
86        conv_dims.output_batch_dimension(), -1});
87   dims.rhs_non_contracting_dims.push_back(
88       {-1, conv_dims.kernel_output_feature_dimension(),
89        conv_dims.output_feature_dimension(), -1});
90   dims.contracting_dims.push_back({conv_dims.input_feature_dimension(),
91                                    conv_dims.kernel_input_feature_dimension(),
92                                    -1, -1});
93 
94   for (int64_t i = 0; i < conv_dims.input_spatial_dimensions_size(); ++i) {
95     int64_t lhs = conv_dims.input_spatial_dimensions(i);
96     int64_t lhs_size = conv->operand(0)->shape().dimensions(lhs);
97     int64_t rhs = conv_dims.kernel_spatial_dimensions(i);
98     int64_t rhs_size = conv->operand(1)->shape().dimensions(rhs);
99     int64_t output = conv_dims.output_spatial_dimensions(i);
100     const auto& wd = conv->window().dimensions(i);
101     if (SpatialIsBatch(lhs_size, wd) != SpatialBatchRepresentation::kNone) {
102       dims.batch_dims.push_back({lhs, rhs, output, i});
103     } else if (lhs_size == wd.size() && wd.base_dilation() == 1 &&
104                wd.window_dilation() == 1 && wd.padding_high() == 0 &&
105                wd.padding_low() == 0 && !wd.window_reversal()) {
106       // A contracting dimension be represented as a spatial dimension with
107       // window size C (contracting dimension size). Stride can be any size
108       // since there is only one window.
109       dims.contracting_dims.push_back({lhs, rhs, output, i});
110     } else if (wd.stride() == 1 && wd.window_dilation() == 1 &&
111                wd.base_dilation() == 1) {
112       if (rhs_size == 1 && wd.size() == 1 && wd.padding_high() == 0 &&
113           wd.padding_low() == 0 && !wd.window_reversal()) {
114         // A LHS non-contracting dimension can be represented as a spatial
115         // dimension with window size 1.
116         dims.lhs_non_contracting_dims.push_back({lhs, rhs, output, i});
117       } else if (lhs_size == 1 && wd.size() == rhs_size &&
118                  wd.padding_high() == rhs_size - 1 &&
119                  wd.padding_low() == rhs_size - 1 && wd.window_reversal()) {
120         // A RHS non-contracting dimension can be represented as a spatial
121         // dimension with window size N (non-contracting dimension size), low
122         // padding N - 1,  high padding N - 1 and window reversal.
123         dims.rhs_non_contracting_dims.push_back({lhs, rhs, output, i});
124       } else {
125         dims.conv_spatial_dims.push_back({lhs, rhs, output, i});
126       }
127     } else {
128       dims.conv_spatial_dims.push_back({lhs, rhs, output, i});
129     }
130   }
131 
132   return dims;
133 }
134 
135 StatusOr<std::unique_ptr<HloInstruction>>
CreateShardedConvForDotGeneralConvolution(const HloInstruction & conv,const DotConvolutionDimsInfo & dot_dnums,HloInstruction * sharded_lhs_hlo,HloInstruction * sharded_rhs_hlo)136 CreateShardedConvForDotGeneralConvolution(
137     const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums,
138     HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo) {
139   CHECK_EQ(conv.opcode(), HloOpcode::kConvolution);
140   const auto& conv_dnums = conv.convolution_dimension_numbers();
141   auto window = conv.window();
142   for (const auto& dim : dot_dnums.batch_dims) {
143     auto wd = window.mutable_dimensions(dim.spatial_dim);
144     wd->set_size(sharded_lhs_hlo->shape().dimensions(
145         conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
146     wd->set_stride(std::max<int64_t>(1, wd->size() - 1));
147     wd->set_base_dilation(wd->size());
148   }
149   for (const auto& dim : dot_dnums.contracting_dims) {
150     if (dim.spatial_dim < 0) {
151       continue;
152     }
153     auto wd = window.mutable_dimensions(dim.spatial_dim);
154     wd->set_size(sharded_lhs_hlo->shape().dimensions(
155         conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
156   }
157   for (const auto& dim : dot_dnums.rhs_non_contracting_dims) {
158     if (dim.spatial_dim < 0) {
159       continue;
160     }
161     auto wd = window.mutable_dimensions(dim.spatial_dim);
162     wd->set_size(sharded_rhs_hlo->shape().dimensions(
163         conv_dnums.kernel_spatial_dimensions(dim.spatial_dim)));
164     wd->set_padding_high(wd->size() - 1);
165     wd->set_padding_low(wd->size() - 1);
166   }
167   TF_ASSIGN_OR_RETURN(
168       Shape sharded_conv_shape,
169       ShapeInference::InferConvolveShape(
170           sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
171           /*feature_group_count=*/conv.feature_group_count(),
172           /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums,
173           /*preferred_element_type=*/conv.shape().element_type()));
174   *sharded_conv_shape.mutable_layout() = conv.shape().layout();
175   return HloInstruction::CreateConvolve(
176       sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo,
177       /*feature_group_count=*/conv.feature_group_count(),
178       /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums,
179       conv.precision_config());
180 }
181 
ParseDotGeneralFromDot(const HloInstruction * dot)182 DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot) {
183   const auto& dot_dim_numbs = dot->dot_dimension_numbers();
184   dot_as_convolution_util::DotConvolutionDimsInfo dnums;
185   for (int64_t i = 0; i < dot_dim_numbs.lhs_batch_dimensions().size(); ++i) {
186     dnums.batch_dims.emplace_back();
187     dnums.batch_dims.back().lhs = dot_dim_numbs.lhs_batch_dimensions(i);
188     dnums.batch_dims.back().rhs = dot_dim_numbs.rhs_batch_dimensions(i);
189     dnums.batch_dims.back().output = i;
190     dnums.batch_dims.back().spatial_dim = -1;
191   }
192   for (int64_t i = 0; i < dot_dim_numbs.lhs_contracting_dimensions().size();
193        ++i) {
194     dnums.contracting_dims.emplace_back();
195     dnums.contracting_dims.back().lhs =
196         dot_dim_numbs.lhs_contracting_dimensions(i);
197     dnums.contracting_dims.back().rhs =
198         dot_dim_numbs.rhs_contracting_dimensions(i);
199     dnums.contracting_dims.back().output = -1;
200     dnums.contracting_dims.back().spatial_dim = -1;
201   }
202   for (int64_t i = 0; i < dot->operand(0)->shape().rank(); ++i) {
203     if (!absl::c_linear_search(dot_dim_numbs.lhs_batch_dimensions(), i) &&
204         !absl::c_linear_search(dot_dim_numbs.lhs_contracting_dimensions(), i)) {
205       dnums.lhs_non_contracting_dims.emplace_back();
206       dnums.lhs_non_contracting_dims.back().lhs = i;
207       dnums.lhs_non_contracting_dims.back().rhs = -1;
208       dnums.lhs_non_contracting_dims.back().output =
209           dot_dim_numbs.lhs_batch_dimensions_size() +
210           dnums.lhs_non_contracting_dims.size() - 1;
211       dnums.lhs_non_contracting_dims.back().spatial_dim = -1;
212     }
213   }
214   for (int64_t i = 0; i < dot->operand(1)->shape().rank(); ++i) {
215     if (!absl::c_linear_search(dot_dim_numbs.rhs_batch_dimensions(), i) &&
216         !absl::c_linear_search(dot_dim_numbs.rhs_contracting_dimensions(), i)) {
217       dnums.rhs_non_contracting_dims.emplace_back();
218       dnums.rhs_non_contracting_dims.back().lhs = -1;
219       dnums.rhs_non_contracting_dims.back().rhs = i;
220       dnums.rhs_non_contracting_dims.back().output =
221           dot_dim_numbs.lhs_batch_dimensions_size() +
222           dnums.lhs_non_contracting_dims.size() +
223           dnums.rhs_non_contracting_dims.size() - 1;
224       dnums.rhs_non_contracting_dims.back().spatial_dim = -1;
225     }
226   }
227   return dnums;
228 }
229 
230 }  // namespace dot_as_convolution_util
231 }  // namespace xla
232