1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
17
18 #include <optional>
19
20 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/service/shape_inference.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24
25 namespace xla {
26 namespace dot_as_convolution_util {
27
SpatialIsBatch(int64_t lhs_spatial_size,const WindowDimension & spatial_wd)28 SpatialBatchRepresentation SpatialIsBatch(int64_t lhs_spatial_size,
29 const WindowDimension& spatial_wd) {
30 if (lhs_spatial_size == spatial_wd.size() &&
31 lhs_spatial_size == spatial_wd.base_dilation() &&
32 ((std::max<int64_t>(1, lhs_spatial_size - 1) == spatial_wd.stride() &&
33 spatial_wd.window_dilation() == 1) ||
34 (std::max<int64_t>(1, lhs_spatial_size - 1) ==
35 spatial_wd.window_dilation() &&
36 spatial_wd.stride() == 1)) &&
37 spatial_wd.padding_high() == 0 && spatial_wd.padding_low() == 0 &&
38 !spatial_wd.window_reversal()) {
39 return SpatialBatchRepresentation::kUnpaddedVersion;
40 } else if (lhs_spatial_size == spatial_wd.size() &&
41 spatial_wd.padding_high() == lhs_spatial_size - 1 &&
42 spatial_wd.padding_low() == lhs_spatial_size - 1 &&
43 spatial_wd.window_reversal() &&
44 spatial_wd.window_dilation() == 1 &&
45 spatial_wd.stride() == lhs_spatial_size &&
46 spatial_wd.base_dilation() == lhs_spatial_size - 1) {
47 return SpatialBatchRepresentation::kPaddedVersion;
48 }
49 return SpatialBatchRepresentation::kNone;
50 }
51
SpatialIsLhsNonContracting(int64_t rhs_spatial_size,const WindowDimension & spatial_wd)52 bool SpatialIsLhsNonContracting(int64_t rhs_spatial_size,
53 const WindowDimension& spatial_wd) {
54 return spatial_wd.stride() == 1 && spatial_wd.window_dilation() == 1 &&
55 spatial_wd.base_dilation() == 1 && rhs_spatial_size == 1 &&
56 spatial_wd.size() == 1 && spatial_wd.padding_high() == 0 &&
57 spatial_wd.padding_low() == 0 && !spatial_wd.window_reversal();
58 }
59
SpatialIsRhsNonContracting(int64_t lhs_spatial_size,int64_t rhs_spatial_size,const WindowDimension & spatial_wd)60 bool SpatialIsRhsNonContracting(int64_t lhs_spatial_size,
61 int64_t rhs_spatial_size,
62 const WindowDimension& spatial_wd) {
63 return spatial_wd.stride() == 1 && spatial_wd.window_dilation() == 1 &&
64 spatial_wd.base_dilation() == 1 && lhs_spatial_size == 1 &&
65 spatial_wd.size() == rhs_spatial_size &&
66 spatial_wd.padding_high() == rhs_spatial_size - 1 &&
67 spatial_wd.padding_low() == rhs_spatial_size - 1 &&
68 spatial_wd.window_reversal();
69 }
70
SpatialIsContracting(int64_t lhs_spatial_size,int64_t rhs_spatial_size,const WindowDimension & spatial_wd)71 bool SpatialIsContracting(int64_t lhs_spatial_size, int64_t rhs_spatial_size,
72 const WindowDimension& spatial_wd) {
73 return lhs_spatial_size == spatial_wd.size() &&
74 spatial_wd.base_dilation() == 1 && spatial_wd.window_dilation() == 1 &&
75 spatial_wd.padding_high() == 0 && spatial_wd.padding_low() == 0 &&
76 !spatial_wd.window_reversal();
77 }
78
ParseConvolutionDimsInfo(const HloInstruction * conv)79 /* static */ DotConvolutionDimsInfo ParseConvolutionDimsInfo(
80 const HloInstruction* conv) {
81 CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
82 const auto& conv_dims = conv->convolution_dimension_numbers();
83 DotConvolutionDimsInfo dims;
84 dims.lhs_non_contracting_dims.push_back(
85 {conv_dims.input_batch_dimension(), -1,
86 conv_dims.output_batch_dimension(), -1});
87 dims.rhs_non_contracting_dims.push_back(
88 {-1, conv_dims.kernel_output_feature_dimension(),
89 conv_dims.output_feature_dimension(), -1});
90 dims.contracting_dims.push_back({conv_dims.input_feature_dimension(),
91 conv_dims.kernel_input_feature_dimension(),
92 -1, -1});
93
94 for (int64_t i = 0; i < conv_dims.input_spatial_dimensions_size(); ++i) {
95 int64_t lhs = conv_dims.input_spatial_dimensions(i);
96 int64_t lhs_size = conv->operand(0)->shape().dimensions(lhs);
97 int64_t rhs = conv_dims.kernel_spatial_dimensions(i);
98 int64_t rhs_size = conv->operand(1)->shape().dimensions(rhs);
99 int64_t output = conv_dims.output_spatial_dimensions(i);
100 const auto& wd = conv->window().dimensions(i);
101 if (SpatialIsBatch(lhs_size, wd) != SpatialBatchRepresentation::kNone) {
102 dims.batch_dims.push_back({lhs, rhs, output, i});
103 } else if (lhs_size == wd.size() && wd.base_dilation() == 1 &&
104 wd.window_dilation() == 1 && wd.padding_high() == 0 &&
105 wd.padding_low() == 0 && !wd.window_reversal()) {
106 // A contracting dimension be represented as a spatial dimension with
107 // window size C (contracting dimension size). Stride can be any size
108 // since there is only one window.
109 dims.contracting_dims.push_back({lhs, rhs, output, i});
110 } else if (wd.stride() == 1 && wd.window_dilation() == 1 &&
111 wd.base_dilation() == 1) {
112 if (rhs_size == 1 && wd.size() == 1 && wd.padding_high() == 0 &&
113 wd.padding_low() == 0 && !wd.window_reversal()) {
114 // A LHS non-contracting dimension can be represented as a spatial
115 // dimension with window size 1.
116 dims.lhs_non_contracting_dims.push_back({lhs, rhs, output, i});
117 } else if (lhs_size == 1 && wd.size() == rhs_size &&
118 wd.padding_high() == rhs_size - 1 &&
119 wd.padding_low() == rhs_size - 1 && wd.window_reversal()) {
120 // A RHS non-contracting dimension can be represented as a spatial
121 // dimension with window size N (non-contracting dimension size), low
122 // padding N - 1, high padding N - 1 and window reversal.
123 dims.rhs_non_contracting_dims.push_back({lhs, rhs, output, i});
124 } else {
125 dims.conv_spatial_dims.push_back({lhs, rhs, output, i});
126 }
127 } else {
128 dims.conv_spatial_dims.push_back({lhs, rhs, output, i});
129 }
130 }
131
132 return dims;
133 }
134
135 StatusOr<std::unique_ptr<HloInstruction>>
CreateShardedConvForDotGeneralConvolution(const HloInstruction & conv,const DotConvolutionDimsInfo & dot_dnums,HloInstruction * sharded_lhs_hlo,HloInstruction * sharded_rhs_hlo)136 CreateShardedConvForDotGeneralConvolution(
137 const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums,
138 HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo) {
139 CHECK_EQ(conv.opcode(), HloOpcode::kConvolution);
140 const auto& conv_dnums = conv.convolution_dimension_numbers();
141 auto window = conv.window();
142 for (const auto& dim : dot_dnums.batch_dims) {
143 auto wd = window.mutable_dimensions(dim.spatial_dim);
144 wd->set_size(sharded_lhs_hlo->shape().dimensions(
145 conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
146 wd->set_stride(std::max<int64_t>(1, wd->size() - 1));
147 wd->set_base_dilation(wd->size());
148 }
149 for (const auto& dim : dot_dnums.contracting_dims) {
150 if (dim.spatial_dim < 0) {
151 continue;
152 }
153 auto wd = window.mutable_dimensions(dim.spatial_dim);
154 wd->set_size(sharded_lhs_hlo->shape().dimensions(
155 conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
156 }
157 for (const auto& dim : dot_dnums.rhs_non_contracting_dims) {
158 if (dim.spatial_dim < 0) {
159 continue;
160 }
161 auto wd = window.mutable_dimensions(dim.spatial_dim);
162 wd->set_size(sharded_rhs_hlo->shape().dimensions(
163 conv_dnums.kernel_spatial_dimensions(dim.spatial_dim)));
164 wd->set_padding_high(wd->size() - 1);
165 wd->set_padding_low(wd->size() - 1);
166 }
167 TF_ASSIGN_OR_RETURN(
168 Shape sharded_conv_shape,
169 ShapeInference::InferConvolveShape(
170 sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
171 /*feature_group_count=*/conv.feature_group_count(),
172 /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums,
173 /*preferred_element_type=*/conv.shape().element_type()));
174 *sharded_conv_shape.mutable_layout() = conv.shape().layout();
175 return HloInstruction::CreateConvolve(
176 sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo,
177 /*feature_group_count=*/conv.feature_group_count(),
178 /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums,
179 conv.precision_config());
180 }
181
ParseDotGeneralFromDot(const HloInstruction * dot)182 DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot) {
183 const auto& dot_dim_numbs = dot->dot_dimension_numbers();
184 dot_as_convolution_util::DotConvolutionDimsInfo dnums;
185 for (int64_t i = 0; i < dot_dim_numbs.lhs_batch_dimensions().size(); ++i) {
186 dnums.batch_dims.emplace_back();
187 dnums.batch_dims.back().lhs = dot_dim_numbs.lhs_batch_dimensions(i);
188 dnums.batch_dims.back().rhs = dot_dim_numbs.rhs_batch_dimensions(i);
189 dnums.batch_dims.back().output = i;
190 dnums.batch_dims.back().spatial_dim = -1;
191 }
192 for (int64_t i = 0; i < dot_dim_numbs.lhs_contracting_dimensions().size();
193 ++i) {
194 dnums.contracting_dims.emplace_back();
195 dnums.contracting_dims.back().lhs =
196 dot_dim_numbs.lhs_contracting_dimensions(i);
197 dnums.contracting_dims.back().rhs =
198 dot_dim_numbs.rhs_contracting_dimensions(i);
199 dnums.contracting_dims.back().output = -1;
200 dnums.contracting_dims.back().spatial_dim = -1;
201 }
202 for (int64_t i = 0; i < dot->operand(0)->shape().rank(); ++i) {
203 if (!absl::c_linear_search(dot_dim_numbs.lhs_batch_dimensions(), i) &&
204 !absl::c_linear_search(dot_dim_numbs.lhs_contracting_dimensions(), i)) {
205 dnums.lhs_non_contracting_dims.emplace_back();
206 dnums.lhs_non_contracting_dims.back().lhs = i;
207 dnums.lhs_non_contracting_dims.back().rhs = -1;
208 dnums.lhs_non_contracting_dims.back().output =
209 dot_dim_numbs.lhs_batch_dimensions_size() +
210 dnums.lhs_non_contracting_dims.size() - 1;
211 dnums.lhs_non_contracting_dims.back().spatial_dim = -1;
212 }
213 }
214 for (int64_t i = 0; i < dot->operand(1)->shape().rank(); ++i) {
215 if (!absl::c_linear_search(dot_dim_numbs.rhs_batch_dimensions(), i) &&
216 !absl::c_linear_search(dot_dim_numbs.rhs_contracting_dimensions(), i)) {
217 dnums.rhs_non_contracting_dims.emplace_back();
218 dnums.rhs_non_contracting_dims.back().lhs = -1;
219 dnums.rhs_non_contracting_dims.back().rhs = i;
220 dnums.rhs_non_contracting_dims.back().output =
221 dot_dim_numbs.lhs_batch_dimensions_size() +
222 dnums.lhs_non_contracting_dims.size() +
223 dnums.rhs_non_contracting_dims.size() - 1;
224 dnums.rhs_non_contracting_dims.back().spatial_dim = -1;
225 }
226 }
227 return dnums;
228 }
229
230 } // namespace dot_as_convolution_util
231 } // namespace xla
232