/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/spmd/convolution_handler.h"

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/numbers.h"

namespace xla {
namespace spmd {

namespace {

// Partition convolution with batch group count.
StatusOr<HloInstruction*> PartitionConvolutionWithBatchGroupCount(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    int64_t num_partitions, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
  if (original_hlo->batch_group_count() == 1 ||
      original_hlo->batch_group_count() % num_partitions != 0) {
    return nullptr;
  }

  const auto& dnums = original_hlo->convolution_dimension_numbers();
  // Only the case where batch_group_count equals the input batch size is
  // supported.
  const int64_t input_batch_size =
      lhs.base_shape().dimensions(dnums.input_batch_dimension());
  const int64_t kernel_output_feature_size =
      rhs.base_shape().dimensions(dnums.kernel_output_feature_dimension());
  if (input_batch_size != kernel_output_feature_size ||
      original_hlo->batch_group_count() != input_batch_size) {
    return nullptr;
  }

  // Map RHS indices to LHS indices.
  std::vector<int64_t> rhs_to_lhs_indices(output_base_shape.rank());
  rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
      dnums.input_batch_dimension();
  rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
      dnums.input_feature_dimension();
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
        dnums.input_spatial_dimensions(i);
  }

  // Map LHS indices to RHS indices.
  std::vector<int64_t> lhs_to_rhs_indices(output_base_shape.rank());
  for (int64_t i = 0; i < rhs_to_lhs_indices.size(); ++i) {
    lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
  }

  // Map LHS indices to output indices.
  std::vector<int64_t> lhs_to_output_indices(lhs.base_shape().rank(), -1);
  lhs_to_output_indices[dnums.input_batch_dimension()] =
      dnums.output_feature_dimension();
  lhs_to_output_indices[dnums.input_feature_dimension()] =
      dnums.output_batch_dimension();
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    lhs_to_output_indices[dnums.input_spatial_dimensions(i)] =
        dnums.output_spatial_dimensions(i);
  }
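  // Illustrative example (assumed layouts, not taken from any particular
  // HLO): with an NHWC input (batch=0, spatial={1,2}, feature=3), an HWIO
  // kernel (spatial={0,1}, input_feature=2, output_feature=3) and an NHWC
  // output, the maps built above would be
  //   rhs_to_lhs_indices    = {1, 2, 3, 0}
  //   lhs_to_rhs_indices    = {3, 0, 1, 2}
  //   lhs_to_output_indices = {3, 1, 2, 0}
  // i.e. the kernel output feature dim lines up with the input batch dim,
  // which in turn maps to the output feature dim.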

  // Align LHS or RHS to the other operand if the input batch dim or kernel
  // output feature dim is partitioned.
  auto aligned_rhs_sharding =
      hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices);
  auto aligned_lhs_sharding =
      hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);

  bool lhs_batch_dim_is_partitioned =
      (ShardCountAtDim(lhs.sharding(), dnums.input_batch_dimension()) ==
       num_partitions);
  bool rhs_output_feature_dim_is_partitioned =
      (ShardCountAtDim(rhs.sharding(),
                       dnums.kernel_output_feature_dimension()) ==
       num_partitions);
  if (!lhs_batch_dim_is_partitioned && !rhs_output_feature_dim_is_partitioned) {
    return nullptr;
  }
  // Reshard LHS or RHS so that its batch dimension or kernel output feature
  // dimension is partitioned in the same way as the other operand.
  if (lhs_batch_dim_is_partitioned) {
    rhs = rhs.Reshard(aligned_rhs_sharding);
  } else {
    lhs = lhs.Reshard(aligned_lhs_sharding);
  }
  // Align output sharding after LHS and RHS sharding are consistent.
  auto aligned_output_sharding = hlo_sharding_util::TransposeSharding(
      lhs.sharding(), lhs_to_output_indices);

  // Create partitioned convolution.
  TF_ASSIGN_OR_RETURN(
      auto sharded_conv,
      create_sharded_conv(lhs.hlo(), rhs.hlo(), b, conv_window));
  sharded_conv->set_sharding(aligned_output_sharding);
  return PartitionedHlo(sharded_conv, output_base_shape, lhs.state())
      .Reshard(output_sharding)
      .hlo();
}

// Partition convolution with feature group count.
StatusOr<HloInstruction*> PartitionConvolutionWithFeatureGroupCount(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    int64_t num_partitions, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
  if (original_hlo->feature_group_count() == 1 ||
      original_hlo->feature_group_count() % num_partitions != 0) {
    return nullptr;
  }

  const auto& dnums = original_hlo->convolution_dimension_numbers();
  const int64_t input_feature_size =
      lhs.base_shape().dimensions(dnums.input_feature_dimension());
  const int64_t kernel_output_feature_size =
      rhs.base_shape().dimensions(dnums.kernel_output_feature_dimension());
  if (kernel_output_feature_size % original_hlo->feature_group_count() != 0 ||
      input_feature_size % original_hlo->feature_group_count() != 0) {
    return nullptr;
  }

  // Align RHS indices to LHS.
  std::vector<int64_t> rhs_to_lhs_indices(output_base_shape.rank());
  rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
      dnums.input_feature_dimension();
  rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
      dnums.input_batch_dimension();
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
        dnums.input_spatial_dimensions(i);
  }

  // Align LHS indices to RHS.
  std::vector<int64_t> lhs_to_rhs_indices(output_base_shape.rank());
  for (int64_t i = 0; i < rhs_to_lhs_indices.size(); ++i) {
    lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
  }

  // Align LHS indices to output.
  std::vector<int64_t> lhs_to_output_indices(output_base_shape.rank());
  lhs_to_output_indices[dnums.input_feature_dimension()] =
      dnums.output_feature_dimension();
  lhs_to_output_indices[dnums.input_batch_dimension()] =
      dnums.output_batch_dimension();
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    lhs_to_output_indices[dnums.input_spatial_dimensions(i)] =
        dnums.output_spatial_dimensions(i);
  }
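  // Illustrative example (assumed NHWC/HWIO layouts, as in the batch group
  // count case above): here the kernel output feature dim lines up with the
  // input feature dim instead, so rhs_to_lhs_indices = {1, 2, 0, 3} and
  // lhs_to_output_indices is the identity {0, 1, 2, 3}.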

  // Align LHS or RHS if input_feature_dim or kernel_output_feature_dim is
  // partitioned.
  auto aligned_rhs_sharding =
      hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices);
  auto aligned_lhs_sharding =
      hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);

  bool lhs_feature_dim_is_partitioned =
      (ShardCountAtDim(lhs.sharding(), dnums.input_feature_dimension()) ==
       num_partitions);
  bool rhs_output_feature_dim_is_partitioned =
      (ShardCountAtDim(rhs.sharding(),
                       dnums.kernel_output_feature_dimension()) ==
       num_partitions);
  if (!lhs_feature_dim_is_partitioned &&
      !rhs_output_feature_dim_is_partitioned) {
    return nullptr;
  }
  // Reshard LHS or RHS so that its input feature dimension or kernel output
  // feature dimension is partitioned in the same way as the other operand.
  if (lhs_feature_dim_is_partitioned) {
    rhs = rhs.Reshard(aligned_rhs_sharding);
  } else {
    lhs = lhs.Reshard(aligned_lhs_sharding);
  }

  // Align output sharding after LHS and RHS sharding are consistent.
  auto aligned_output_sharding = hlo_sharding_util::TransposeSharding(
      lhs.sharding(), lhs_to_output_indices);

  TF_ASSIGN_OR_RETURN(
      auto sharded_conv,
      create_sharded_conv(lhs.hlo(), rhs.hlo(), b, conv_window));
  sharded_conv->set_sharding(aligned_output_sharding);
  return PartitionedHlo(sharded_conv, output_base_shape, lhs.state())
      .Reshard(output_sharding)
      .hlo();
}

// Partition convolution when both LHS and RHS are partitioned at spatial
// dimensions. Halo exchange will happen on RHS only.
StatusOr<HloInstruction*>
PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
  TF_RET_CHECK(!lhs.sharding().IsTileMaximal() &&
               !rhs.sharding().IsTileMaximal());

  const auto& dnums = original_hlo->convolution_dimension_numbers();
  std::vector<int64_t> rhs_to_lhs_indices(output_base_shape.rank());
  rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
      dnums.input_batch_dimension();
  rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
      dnums.input_feature_dimension();
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
        dnums.input_spatial_dimensions(i);
  }
  std::vector<int64_t> lhs_to_rhs_indices(output_base_shape.rank());
  for (int64_t i = 0; i < rhs_to_lhs_indices.size(); ++i) {
    lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
  }
  auto aligned_rhs_sharding =
      hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices);
  auto aligned_lhs_sharding =
      hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);

  auto unsupported_sharding = [&](const HloSharding& lhs_sharding,
                                  const HloSharding& rhs_sharding) {
    // We currently don't support partitioning input batch or output feature
    // dimensions.
    return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) !=
               1 ||
           rhs_sharding.tile_assignment().dim(
               dnums.kernel_output_feature_dimension()) != 1;
  };

  if (ShapeSizeInBytes(lhs.base_shape()) < ShapeSizeInBytes(rhs.base_shape())) {
    if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) {
      return nullptr;
    }
    lhs = lhs.Reshard(aligned_lhs_sharding).PadWithZero();
    rhs = rhs.PadWithZero();
  } else {
    if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) {
      return nullptr;
    }
    lhs = lhs.PadWithZero();
    rhs = rhs.Reshard(aligned_rhs_sharding).PadWithZero();
  }

  if (original_hlo->feature_group_count() > 1 &&
      (lhs.sharding().tile_assignment().dim(dnums.input_feature_dimension()) >
           1 ||
       rhs.sharding().tile_assignment().dim(
           dnums.kernel_output_feature_dimension()) > 1)) {
    return nullptr;
  }

  if (original_hlo->batch_group_count() > 1 &&
      (lhs.sharding().tile_assignment().dim(dnums.input_batch_dimension()) >
           1 ||
       rhs.sharding().tile_assignment().dim(
           dnums.kernel_output_feature_dimension()) > 1)) {
    return nullptr;
  }

  // Reshard RHS so that each shard computes the partial sum of the full
  // shape result, and add an AllReduce. See HandleConvolutionTiledLhsAndRhs(),
  // which reshards LHS instead.
  //
  // The size of the halo on each dimension can be calculated from the
  // projection onto the RHS that shard i needs to read. RHS and LHS below
  // refer to the shard sizes of RHS and LHS, WC is the number of windows,
  // and D is the window dilation.
  //
  // * offset(i): LHS * i + low_padding - (WC - 1) * stride
  // * limit(i): LHS * (i + 1) + low_padding
  //
  // Since shard i has RHS of range [i * RHS * D, (i + 1) * RHS * D)
  // * left-halo: i * RHS - offset(i)
  //            = i * (RHS * D - LHS) + (WC - 1) * stride - low_padding
  // * right-halo: limit(i) - (i + 1) * RHS
  //             = (i + 1) * (LHS - RHS * D) + low_padding
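  //
  // For illustration only (numbers plugged into the formulas above, not
  // derived from any particular HLO): with LHS = 5, RHS = 3, D = 1,
  // stride = 1, low_padding = 0 and WC = 4,
  //   left-halo(1)  = 1 * (3 - 5) + 3 * 1 - 0 = 1
  //   right-halo(0) = (0 + 1) * (5 - 3) + 0   = 2
  // so shard 1 reads one extra element on its left and shard 0 reads two
  // extra elements on its right.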
  const auto& collective_ops_creator = lhs.state().collective_ops_creator;
  std::vector<int64_t> shard_counts(dnums.input_spatial_dimensions_size());
  std::vector<int64_t> lhs_shard_sizes(dnums.input_spatial_dimensions_size());
  std::vector<int64_t> rhs_shard_sizes(dnums.input_spatial_dimensions_size());

  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    int64_t lhs_dimension = dnums.input_spatial_dimensions(i);
    int64_t rhs_dimension = dnums.kernel_spatial_dimensions(i);
    int64_t shard_count = rhs.sharding().tile_assignment().dim(rhs_dimension);
    const auto& wd = conv_window.dimensions(i);
    if (wd.base_dilation() != 1 || wd.window_reversal()) {
      return nullptr;
    }

    int64_t lhs_shard_size =
        CeilOfRatio(lhs.base_shape().dimensions(lhs_dimension), shard_count);
    int64_t rhs_shard_size =
        CeilOfRatio(rhs.base_shape().dimensions(rhs_dimension), shard_count);
    shard_counts[i] = shard_count;
    lhs_shard_sizes[i] = lhs_shard_size;
    rhs_shard_sizes[i] = rhs_shard_size;
  }

  std::vector<OffsetCalculation> left_halo_size_functions(
      output_base_shape.rank());
  std::vector<OffsetCalculation> right_halo_size_functions(
      output_base_shape.rank());
  Window new_window = conv_window;

  // Data structures for Pad and DynamicSlice on LHS, if needed.
  bool need_dynamic_slice_lhs = false;
  auto partition_ordinals =
      MakeTiledPartitionOrdinals(lhs.sharding(), partition_id, b);
  std::vector<int64_t> zero_padding(output_base_shape.rank());
  PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding);
  auto zero_s32 =
      b->AddInstruction(HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
  std::vector<HloInstruction*> dynamic_slice_start_indices(
      output_base_shape.rank(), zero_s32);
  Shape dynamic_slice_shape = lhs.hlo()->shape();
  Shape pad_shape = lhs.hlo()->shape();

  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    int64_t lhs_dimension = dnums.input_spatial_dimensions(i);
    int64_t rhs_dimension = dnums.kernel_spatial_dimensions(i);
    int64_t lhs_shard_size = lhs_shard_sizes[i];
    int64_t rhs_shard_size = rhs_shard_sizes[i];

    if (shard_counts[i] == 1) {
      continue;
    }

    // Calculate the left and right halo sizes as described in the comments
    // above. The formulas compute the halo sizes with dilation, so we apply
    // CeilOfRatio({left,right}_halo_size, window_dilation).
    const auto& wd = conv_window.dimensions(i);
    int64_t padding_low = wd.padding_low();
    int64_t padding_high = wd.padding_high();
    int64_t base = lhs.base_shape().dimensions(lhs_dimension);
    int64_t window_count = 1 + (padding_low + padding_high + base -
                                (1 + (wd.size() - 1) * wd.window_dilation())) /
                                   wd.stride();
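    // For example (illustrative numbers only): base = 224, size = 7,
    // window_dilation = 1, stride = 2, padding_low = padding_high = 3 gives
    // window_count = 1 + (3 + 3 + 224 - 7) / 2 = 112.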
    left_halo_size_functions[rhs_dimension] =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            rhs_shard_size * wd.window_dilation() - lhs_shard_size,
            (window_count - 1) * wd.stride() - padding_low +
                wd.window_dilation() - 1,
            wd.window_dilation()));
    right_halo_size_functions[rhs_dimension] =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            lhs_shard_size - rhs_shard_size * wd.window_dilation(),
            lhs_shard_size - rhs_shard_size * wd.window_dilation() +
                padding_low + wd.window_dilation() - 1,
            wd.window_dilation()));

    // New RHS window size includes the maximum of both left and right
    // halos.
    int64_t halo_size =
        left_halo_size_functions[rhs_dimension].MaxInRange(1, shard_counts[i]) +
        right_halo_size_functions[rhs_dimension].MaxInRange(
            0, shard_counts[i] - 1);
    int64_t new_window_size =
        rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size;

    // The amount of new low padding could be dynamic (e.g., window_dilation
    // != 1), which requires pad (to the maximum) and dynamic slice on LHS.
    //
    // If we consider the first window, the offset of the dilated RHS that
    // aligns with the first valid LHS element for shard i is 'padding_low +
    // LHS * i'. When the left halo is added to RHS, the offset of the first
    // RHS element is (RHS * i - left_halo) * window_dilation. The
    // difference between the two values is the amount of padding_low we
    // need on LHS.
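    // Roughly, in the notation of the halo comment above, the expression
    // below evaluates to
    //   new_padding_low(i) = left_halo(i) * D - (RHS * D - LHS) * i +
    //                        low_padding
    // (a paraphrase for readability, not an exact quote of the code).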
    auto new_padding_low_function =
        OffsetCalculation(HloOpcode::kMultiply,
                          left_halo_size_functions[rhs_dimension],
                          OffsetCalculation(MultiplyAddDivideOffsetCalculation(
                              0, wd.window_dilation(), 1))) -
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            rhs_shard_size * wd.window_dilation() - lhs_shard_size,
            -padding_low, 1));

    int64_t new_padding_low_max =
        new_padding_low_function.MaxInRange(0, shard_counts[i]);
    int64_t new_padding_low = new_padding_low_max;
    int64_t new_padding_high = window_count * wd.stride() +
                               (new_window_size - 1) * wd.window_dilation() -
                               new_padding_low - lhs_shard_size;

    // We do pad/dynamic-slice only when the padding is dynamic.
    if (!new_padding_low_function.IsConstant()) {
      need_dynamic_slice_lhs = true;
      new_padding_low = 0;
      pad_config.mutable_dimensions(lhs_dimension)
          ->set_edge_padding_low(new_padding_low_max);
      pad_config.mutable_dimensions(lhs_dimension)
          ->set_edge_padding_high(new_padding_low_max);
      pad_shape.set_dimensions(lhs_dimension,
                               lhs_shard_size + 2 * new_padding_low_max);
      dynamic_slice_start_indices[lhs_dimension] =
          (OffsetCalculation(
               MultiplyAddDivideOffsetCalculation(0, new_padding_low_max, 1)) -
           new_padding_low_function)
              .Calculate(partition_ordinals[lhs_dimension], b);
      dynamic_slice_shape.set_dimensions(lhs_dimension,
                                         lhs_shard_size + new_padding_low_max);
    }

    // Since the convolution RHS operand size increased with halos, adjust
    // the window config accordingly.
    new_window.mutable_dimensions(i)->set_padding_low(new_padding_low);
    new_window.mutable_dimensions(i)->set_padding_high(new_padding_high);
    new_window.mutable_dimensions(i)->set_size(
        rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size);
  }

  HloInstruction* conv_lhs = lhs.hlo();
  if (need_dynamic_slice_lhs) {
    auto zero = b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::Zero(lhs.hlo()->shape().element_type())));
    auto pad = b->AddInstruction(
        HloInstruction::CreatePad(pad_shape, lhs.hlo(), zero, pad_config));
    conv_lhs = b->AddInstruction(HloInstruction::CreateDynamicSlice(
        dynamic_slice_shape, pad, dynamic_slice_start_indices,
        dynamic_slice_shape.dimensions()));
  }

  // Exchange halo and concatenate.
  HloInstruction* rhs_with_halo = rhs.hlo();
  for (int i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
    int64_t dim = dnums.kernel_spatial_dimensions(i);
    int64_t explicit_left_padding_on_full_shape =
        left_halo_size_functions[dim].Calculate(0);
    int64_t shard_size_with_halo = new_window.dimensions(i).size();

    // offset_on_padded_shape and padded_full_shape_size are needed only if
    // we want to mask out-of-range values in ExchangeHaloAndGetValidData().
    // Since the default value for the collective-permute is zero and we also
    // call PadWithValue() on both operands at the beginning, we don't need to
    // mask here.
    //
    // TODO(hyoulkee): Consider removing one of the two PadWithValue() calls
    // if it's always safe.
    auto offset_on_padded_shape =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            rhs_shard_sizes[i], explicit_left_padding_on_full_shape, 1)) -
        left_halo_size_functions[dim];
    int64_t padded_full_shape_size =
        offset_on_padded_shape.Calculate(shard_counts[i] - 1) +
        new_window.dimensions(i).size();
    auto zero = b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::Zero(rhs.hlo()->shape().element_type())));
    auto concat = ExchangeHaloAndGetValidData(
        rhs_with_halo, rhs.base_shape(), left_halo_size_functions[dim],
        right_halo_size_functions[dim], explicit_left_padding_on_full_shape,
        padded_full_shape_size, shard_size_with_halo, dim, rhs.sharding(),
        offset_on_padded_shape.Calculate(partition_ordinals[dim], b), zero,
        partition_ordinals[dim], collective_ops_creator,
        lhs.state().next_channel_id, b,
        /*mask_invalid_region=*/false);
    if (!concat) {
      return nullptr;
    }
    rhs_with_halo = *concat;
  }

  TF_ASSIGN_OR_RETURN(
      auto conv, create_sharded_conv(conv_lhs, rhs_with_halo, b, new_window));

  auto ar = collective_ops_creator.create_cross_partition_all_reduce(
      b, conv, MakeBinaryAdd(original_hlo->shape().element_type(), module), {},
      (*lhs.state().next_channel_id)++);
  ar->set_sharding(HloSharding::Replicate());
  return PartitionedHlo(ar, output_base_shape, lhs.state())
      .Reshard(output_sharding)
      .hlo();
}

// Partition convolution when both LHS and RHS are partitioned at spatial
// dimensions. Halo exchange will happen on LHS only.
StatusOr<HloInstruction*>
PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
  TF_RET_CHECK(!lhs.sharding().IsTileMaximal() &&
               !rhs.sharding().IsTileMaximal());

  const auto& dnums = original_hlo->convolution_dimension_numbers();

  // Check if the operand shardings are aligned. Also we currently don't
  // support partitioning non-spatial dimensions.
  std::vector<int64_t> rhs_to_lhs_indices(output_base_shape.rank());
  rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
      dnums.input_batch_dimension();
  rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
      dnums.input_feature_dimension();
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
        dnums.input_spatial_dimensions(i);
  }
  std::vector<int64_t> lhs_to_rhs_indices(output_base_shape.rank());
  for (int64_t i = 0; i < rhs_to_lhs_indices.size(); ++i) {
    lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
  }

  const Window& window = conv_window;
  std::vector<int64_t> reversed_rhs_dims;
  for (int64_t i = 0; i < window.dimensions_size(); ++i) {
    if (window.dimensions(i).window_reversal()) {
      reversed_rhs_dims.push_back(dnums.kernel_spatial_dimensions(i));
    }
  }
  if (!reversed_rhs_dims.empty()) {
    // Make the reversed dims left-padded to prepare for window reversal.
    auto left_padded_rhs = HaloExchangeToPadOnLeft(rhs, reversed_rhs_dims);
    if (left_padded_rhs == nullptr) {
      return nullptr;
    }
    left_padded_rhs->set_sharding(rhs.sharding());
    rhs = PartitionedHlo(left_padded_rhs, rhs.base_shape(), rhs.state());
  }
  // Consider window reversal when resharding RHS or LHS. Note: this will not
  // reverse the data in the shard. We use window reversal to do that.
  auto aligned_rhs_sharding = hlo_sharding_util::ReverseSharding(
      hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices),
      reversed_rhs_dims);
  auto aligned_lhs_sharding = hlo_sharding_util::TransposeSharding(
      hlo_sharding_util::ReverseSharding(rhs.sharding(), reversed_rhs_dims),
      lhs_to_rhs_indices);

  auto unsupported_sharding = [&](const HloSharding& lhs_sharding,
                                  const HloSharding& rhs_sharding) {
    return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) !=
               1 ||
           rhs_sharding.tile_assignment().dim(
               dnums.kernel_output_feature_dimension()) != 1;
  };

  if (ShapeSizeInBytes(lhs.base_shape()) < ShapeSizeInBytes(rhs.base_shape())) {
    if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) {
      return nullptr;
    }
    lhs = lhs.Reshard(aligned_lhs_sharding).PadWithZero();
    rhs = rhs.PadWithZero(reversed_rhs_dims);
  } else {
    if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) {
      return nullptr;
    }
    lhs = lhs.PadWithZero();
    rhs = rhs.Reshard(aligned_rhs_sharding).PadWithZero(reversed_rhs_dims);
  }

  if (original_hlo->feature_group_count() > 1 &&
      (lhs.sharding().tile_assignment().dim(dnums.input_feature_dimension()) >
           1 ||
       rhs.sharding().tile_assignment().dim(
           dnums.kernel_output_feature_dimension()) > 1)) {
    return nullptr;
  }

  if (original_hlo->batch_group_count() > 1 &&
      (lhs.sharding().tile_assignment().dim(dnums.input_batch_dimension()) >
           1 ||
       rhs.sharding().tile_assignment().dim(
           dnums.kernel_output_feature_dimension()) > 1)) {
    return nullptr;
  }
  // Reshard LHS by exchanging halos such that each shard computes the partial
  // sum of the full shape result, and add an AllReduce.
  //
  // The size of the halo on each dimension can be calculated from the
  // projection onto the LHS that each RHS shard i needs to read. RHS and LHS
  // below refer to the shard sizes of RHS and LHS, WC is the number of
  // windows, and D is the window dilation.
  //
  // * offset(i): RHS * D * i - low_padding
  // * limit(i): {RHS * (i + 1) * D - (D - 1)} + (WC - 1) * stride - low_padding
  //
  // Since shard i has LHS of range [i * LHS, (i + 1) * LHS)
  // * left-halo: i * LHS - offset(i)
  //            = (LHS - RHS * D) * i + low_padding
  // * right-halo: limit(i) - (i + 1) * LHS
  //             = (RHS * D - LHS) * (i + 1) + (1 - D) + (WC - 1) * stride -
  //               low_padding
  //             = (RHS * D - LHS) * i + (RHS * D - LHS) + (1 - D)
  //               + (WC - 1) * stride - low_padding
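  //
  // For illustration only (numbers plugged into the formulas above, not
  // derived from any particular HLO): with LHS = 8, RHS = 3, D = 1,
  // stride = 1, low_padding = 1 and WC = 13,
  //   left-halo(1)  = (8 - 3) * 1 + 1              = 6
  //   right-halo(0) = (3 - 8) * 1 + 0 + 12 * 1 - 1 = 6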
  std::vector<int64_t> shard_counts(dnums.input_spatial_dimensions_size());
  std::vector<int64_t> lhs_shard_sizes(dnums.input_spatial_dimensions_size());
  std::vector<int64_t> rhs_shard_sizes(dnums.input_spatial_dimensions_size());
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    int64_t lhs_dimension = dnums.input_spatial_dimensions(i);
    int64_t rhs_dimension = dnums.kernel_spatial_dimensions(i);
    int64_t shard_count = lhs.sharding().tile_assignment().dim(lhs_dimension);
    const auto& wd = window.dimensions(i);
    if (wd.base_dilation() != 1) {
      // TODO(wangtao): support parallel dim if it is replicated here.
      return nullptr;
    }

    int64_t lhs_shard_size =
        CeilOfRatio(lhs.base_shape().dimensions(lhs_dimension), shard_count);
    int64_t rhs_shard_size =
        CeilOfRatio(rhs.base_shape().dimensions(rhs_dimension), shard_count);
    shard_counts[i] = shard_count;
    lhs_shard_sizes[i] = lhs_shard_size;
    rhs_shard_sizes[i] = rhs_shard_size;
  }

  std::vector<OffsetCalculation> left_halo_size_functions(
      output_base_shape.rank());
  std::vector<OffsetCalculation> right_halo_size_functions(
      output_base_shape.rank());
  Window new_window = window;

  auto partition_ordinals =
      MakeTiledPartitionOrdinals(lhs.sharding(), partition_id, b);
  HloInstruction* lhs_with_halo = lhs.hlo();
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    int64_t lhs_dimension = dnums.input_spatial_dimensions(i);
    int64_t lhs_shard_size = lhs_shard_sizes[i];
    int64_t rhs_shard_size = rhs_shard_sizes[i];

    if (shard_counts[i] == 1) {
      continue;
    }

    // Calculate the left and right halo sizes as described in the comments
    // above.
    const auto& wd = window.dimensions(i);
    int64_t padding_low = wd.padding_low();
    int64_t padding_high = wd.padding_high();
    int64_t base = lhs.base_shape().dimensions(lhs_dimension);
    int64_t window_count = 1 + (padding_low + padding_high + base -
                                (1 + (wd.size() - 1) * wd.window_dilation())) /
                                   wd.stride();
    int64_t rhs_shard_size_dilated =
        (rhs_shard_size - 1) * wd.window_dilation() + 1;

    left_halo_size_functions[lhs_dimension] =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            lhs_shard_size - rhs_shard_size * wd.window_dilation(), padding_low,
            1));
    right_halo_size_functions[lhs_dimension] =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            rhs_shard_size * wd.window_dilation() - lhs_shard_size,
            rhs_shard_size * wd.window_dilation() - lhs_shard_size + 1 -
                wd.window_dilation() + wd.stride() * (window_count - 1) -
                padding_low,
            1));

    // Exchange halo and concatenate.
    int64_t dim = dnums.input_spatial_dimensions(i);
    int64_t explicit_left_padding_on_full_shape = padding_low;
    int64_t shard_size_with_halo =
        wd.stride() * (window_count - 1) + rhs_shard_size_dilated;

    new_window.mutable_dimensions(i)->set_padding_low(0);
    new_window.mutable_dimensions(i)->set_padding_high(0);
    new_window.mutable_dimensions(i)->set_size(rhs_shard_size);

    // offset_on_padded_shape and padded_full_shape_size are needed only if
    // we want to mask out-of-range values in ExchangeHaloAndGetValidData().
    // Since the default value for the collective-permute is zero and we also
    // call PadWithValue() on both operands at the beginning, we don't need to
    // mask here.
    //
    // TODO(hyoulkee): Consider removing one of the two PadWithValue() calls
    // if it's always safe.
    auto offset_on_padded_shape =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation());
    int64_t padded_full_shape_size = 0;

    auto zero = b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::Zero(lhs.hlo()->shape().element_type())));
    auto concat = ExchangeHaloAndGetValidData(
        lhs_with_halo, lhs.base_shape(), left_halo_size_functions[dim],
        right_halo_size_functions[dim], explicit_left_padding_on_full_shape,
        padded_full_shape_size, shard_size_with_halo, dim, lhs.sharding(),
        offset_on_padded_shape.Calculate(partition_ordinals[dim], b), zero,
        partition_ordinals[dim], lhs.state().collective_ops_creator,
        lhs.state().next_channel_id, b,
        /*mask_invalid_region=*/false);
    if (!concat) {
      return nullptr;
    }
    lhs_with_halo = *concat;
  }

  TF_ASSIGN_OR_RETURN(
      auto conv, create_sharded_conv(lhs_with_halo, rhs.hlo(), b, new_window));
  auto ar =
      lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
          b, conv, MakeBinaryAdd(output_base_shape.element_type(), module), {},
          (*lhs.state().next_channel_id)++);
  ar->set_sharding(HloSharding::Replicate());
  return PartitionedHlo(ar, output_base_shape, lhs.state())
      .Reshard(output_sharding)
      .hlo();
}

// Partition convolution when the output is sharded. The LHS is sharded to
// match the output and the RHS is replicated.
StatusOr<HloInstruction*> PartitionConvolutionTiledOutput(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
  const auto& dnums = original_hlo->convolution_dimension_numbers();
  TF_RET_CHECK(!output_sharding.IsTileMaximal());
  // We don't currently support sharding on the output feature dimension.
  if (output_sharding.tile_assignment().dim(dnums.output_feature_dimension()) >
      1) {
    return nullptr;
  }

  // Check if the operand and the output sharding are aligned.
  std::vector<int64_t> input_to_output_indices(output_base_shape.rank());
  input_to_output_indices[dnums.input_batch_dimension()] =
      dnums.output_batch_dimension();
  input_to_output_indices[dnums.input_feature_dimension()] =
      dnums.output_feature_dimension();
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    input_to_output_indices[dnums.input_spatial_dimensions(i)] =
        dnums.output_spatial_dimensions(i);
  }
  auto target_operand_sharding = hlo_sharding_util::TransposeSharding(
      output_sharding, input_to_output_indices);
  lhs = lhs.Reshard(target_operand_sharding);

  // Replicate the RHS.
  rhs = rhs.Reshard(HloSharding::Replicate());

  // Convolution window config does not include batch and feature dimensions,
  // whereas ReshardAsWindowedInput() expects the same number of window
  // dimensions as the rank of the operand. So add two more trivial
  // dimensions.
  std::vector<int64_t> ones(output_base_shape.rank(), 1);
  auto operand_window = window_util::MakeWindow(ones);
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    *operand_window.mutable_dimensions(dnums.input_spatial_dimensions(i)) =
        conv_window.dimensions(i);
  }
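  // For example (assumed NHWC layout, illustrative only): for a rank-4 input
  // and a conv_window with two 3x3 spatial dimensions, operand_window ends up
  // with window sizes {1, 3, 3, 1}, i.e. trivial dimensions for batch and
  // feature.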

  auto zero = b->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(output_base_shape.element_type())));
  auto resharded_operand_and_window =
      lhs.ReshardAsWindowedInput(operand_window, target_operand_sharding, zero);
  if (!resharded_operand_and_window.has_value()) {
    return nullptr;
  }
  Window new_window;
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    *new_window.add_dimensions() =
        resharded_operand_and_window->shard_window.dimensions(
            dnums.input_spatial_dimensions(i));
  }

  TF_ASSIGN_OR_RETURN(
      auto sharded_conv,
      create_sharded_conv(resharded_operand_and_window->sharded_input,
                          rhs.hlo(), b, new_window));

  auto shard_shape = MakePartitionedShape(output_base_shape, output_sharding);
  if (!resharded_operand_and_window->dynamic_slice_index_on_output
           .has_value()) {
    CHECK(ShapeUtil::Compatible(shard_shape, sharded_conv->shape()));
    return sharded_conv;
  }
  return b->AddInstruction(HloInstruction::CreateDynamicSlice(
      shard_shape, sharded_conv,
      *resharded_operand_and_window->dynamic_slice_index_on_output,
      shard_shape.dimensions()));
}

// Partition convolution with only one kind of dimension partitioned.
StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    int64_t num_partitions, const SpmdPartitionerOptions& options,
    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);

  // Case 1: Handle depthwise convolution with batch group count or
  // feature group count.
  if (original_hlo->batch_group_count() > 1) {
    TF_ASSIGN_OR_RETURN(
        auto parallel_partitioned_conv,
        PartitionConvolutionWithBatchGroupCount(
            lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
            conv_window, original_hlo, num_partitions, b));
    if (parallel_partitioned_conv) {
      return parallel_partitioned_conv;
    }
  }

  if (original_hlo->feature_group_count() > 1) {
    TF_ASSIGN_OR_RETURN(
        auto parallel_partitioned_conv,
        PartitionConvolutionWithFeatureGroupCount(
            lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
            conv_window, original_hlo, num_partitions, b));
    if (parallel_partitioned_conv) {
      return parallel_partitioned_conv;
    }
  }

  // Case 2: both RHS and LHS are tiled.
  // Handling cases where both operands' shardings are aligned. We check that
  // the LHS batch dimension is not partitioned because it is mapped to the
  // output feature dimension in aligned_rhs_sharding, which is not the same
  // dimension.
  if (!lhs.sharding().IsTileMaximal() && !rhs.sharding().IsTileMaximal()) {
    if (options.conv_halo_exchange_always_on_lhs) {
      TF_ASSIGN_OR_RETURN(
          auto partitioned_conv,
          PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
              lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
              conv_window, original_hlo, partition_id, module, b));
      if (partitioned_conv) {
        return partitioned_conv;
      }
    } else {
      TF_ASSIGN_OR_RETURN(
          auto partitioned_conv,
          PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
              lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
              conv_window, original_hlo, partition_id, module, b));

      if (partitioned_conv) {
        return partitioned_conv;
      }
    }
  }

  // Case 3: output is tiled.
  if (!output_sharding.IsTileMaximal()) {
    TF_ASSIGN_OR_RETURN(auto partitioned_conv,
                        PartitionConvolutionTiledOutput(
                            lhs, rhs, output_base_shape, output_sharding,
                            create_sharded_conv, conv_window, original_hlo, b));

    if (partitioned_conv) {
      return partitioned_conv;
    }
  }
  return nullptr;
}

StatusOr<std::unique_ptr<HloInstruction>> CreateShardedConvConvolution(
    const HloInstruction& conv,
    const dot_as_convolution_util::DotConvolutionDimsInfo& dot_dnums,
    HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo,
    const Window& conv_window) {
  CHECK_EQ(conv.opcode(), HloOpcode::kConvolution);
  const auto& conv_dnums = conv.convolution_dimension_numbers();
  auto window = conv.window();
  for (const auto& dim : dot_dnums.batch_dims) {
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_lhs_hlo->shape().dimensions(
        conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
    wd->set_stride(std::max<int64_t>(1, wd->size() - 1));
    wd->set_base_dilation(wd->size());
  }
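  // Explanatory note (not from the original source): this follows the pattern
  // dot_as_convolution_util uses to express a batch dimension as a spatial
  // dimension: window size and base dilation equal to the (sharded) input
  // size with stride size - 1, so each output element pairs exactly one input
  // element with one kernel element.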
  for (const auto& dim : dot_dnums.contracting_dims) {
    if (dim.spatial_dim < 0) {
      continue;
    }
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_lhs_hlo->shape().dimensions(
        conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
  }
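  // Explanatory note: a contracting dimension expressed as a spatial
  // dimension uses a window that spans the whole (sharded) input with no
  // padding, so the convolution reduces that dimension to size 1.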
  for (const auto& dim : dot_dnums.rhs_non_contracting_dims) {
    if (dim.spatial_dim < 0) {
      continue;
    }
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_rhs_hlo->shape().dimensions(
        conv_dnums.kernel_spatial_dimensions(dim.spatial_dim)));
    wd->set_padding_high(wd->size() - 1);
    wd->set_padding_low(wd->size() - 1);
  }
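  // Explanatory note: for an RHS non-contracting dimension the LHS extent is
  // trivial, so padding both sides by size - 1 makes the output dimension as
  // large as the (sharded) kernel dimension.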

  for (const auto& dim : dot_dnums.conv_spatial_dims) {
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    const auto& new_window_dimension = conv_window.dimensions(dim.spatial_dim);
    wd->set_size(new_window_dimension.size());
    wd->set_padding_high(new_window_dimension.padding_high());
    wd->set_padding_low(new_window_dimension.padding_low());
  }

  int64_t feature_group_count = conv.feature_group_count();
  if (feature_group_count > 1) {
    feature_group_count = sharded_lhs_hlo->shape().dimensions(
                              conv_dnums.input_feature_dimension()) /
                          sharded_rhs_hlo->shape().dimensions(
                              conv_dnums.kernel_input_feature_dimension());
  }
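  // Explanatory note: after partitioning, the group counts are recomputed
  // from the sharded operand shapes, e.g. a sharded input feature size of 128
  // with a sharded kernel input feature size of 2 yields
  // feature_group_count = 64 (illustrative numbers only).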

  int64_t batch_group_count = conv.batch_group_count();
  if (batch_group_count > 1) {
    batch_group_count =
        sharded_lhs_hlo->shape().dimensions(conv_dnums.input_batch_dimension());
  }

  TF_ASSIGN_OR_RETURN(
      Shape sharded_conv_shape,
      ShapeInference::InferConvolveShape(
          sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
          feature_group_count, batch_group_count, window, conv_dnums,
          /*preferred_element_type=*/conv.shape().element_type()));
  *sharded_conv_shape.mutable_layout() = conv.shape().layout();
  return HloInstruction::CreateConvolve(
      sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo, feature_group_count,
      batch_group_count, window, conv_dnums, conv.precision_config());
}

}  // namespace

// Partition convolution.
StatusOr<HloInstruction*> PartitionConvolution(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    int64_t num_partitions, const SpmdPartitionerOptions& options,
    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);

  TF_ASSIGN_OR_RETURN(auto try_partitioned_conv,
                      PartitionConvolutionBaseCase(
                          lhs, rhs, output_base_shape, output_sharding,
                          create_sharded_conv, conv_window, original_hlo,
                          num_partitions, options, partition_id, module, b));
  if (try_partitioned_conv) {
    return try_partitioned_conv;
  }

  return nullptr;
}

Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) {
  if (hlo->sharding().HasUniqueDevice()) {
    return DefaultAction(hlo);
  }
  auto dims_info = dot_as_convolution_util::ParseConvolutionDimsInfo(hlo);
  spmd::DotConvDimsMapping mapping;
  for (const auto& dims : dims_info.batch_dims) {
    mapping.batch_dims.emplace_back();
    mapping.batch_dims.back().lhs = dims.lhs;
    mapping.batch_dims.back().rhs = dims.rhs;
    mapping.batch_dims.back().output = dims.output;
    mapping.batch_dims.back().spatial = dims.spatial_dim;
  }
  for (const auto& dims : dims_info.contracting_dims) {
    mapping.contracting_dims.emplace_back();
    mapping.contracting_dims.back().lhs = dims.lhs;
    mapping.contracting_dims.back().rhs = dims.rhs;
    mapping.contracting_dims.back().output = dims.output;
    mapping.contracting_dims.back().spatial = dims.spatial_dim;
  }
  for (const auto& dims : dims_info.lhs_non_contracting_dims) {
    mapping.lhs_non_contracting_dims.emplace_back();
    mapping.lhs_non_contracting_dims.back().lhs = dims.lhs;
    mapping.lhs_non_contracting_dims.back().rhs = dims.rhs;
    mapping.lhs_non_contracting_dims.back().output = dims.output;
    mapping.lhs_non_contracting_dims.back().spatial = dims.spatial_dim;
  }
  for (const auto& dims : dims_info.rhs_non_contracting_dims) {
    mapping.rhs_non_contracting_dims.emplace_back();
    mapping.rhs_non_contracting_dims.back().lhs = dims.lhs;
    mapping.rhs_non_contracting_dims.back().rhs = dims.rhs;
    mapping.rhs_non_contracting_dims.back().output = dims.output;
    mapping.rhs_non_contracting_dims.back().spatial = dims.spatial_dim;
  }
  for (const auto& dims : dims_info.conv_spatial_dims) {
    mapping.conv_spatial_dims.emplace_back();
    mapping.conv_spatial_dims.back().lhs = dims.lhs;
    mapping.conv_spatial_dims.back().rhs = dims.rhs;
    mapping.conv_spatial_dims.back().output = dims.output;
    mapping.conv_spatial_dims.back().spatial = dims.spatial_dim;
  }
  auto create_sharded_conv =
      [&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo,
          spmd::SpmdBuilder* b,
          const Window& conv_window) -> StatusOr<HloInstruction*> {
    if (dims_info.conv_spatial_dims.empty() &&
        hlo->feature_group_count() == 1 && hlo->batch_group_count() == 1) {
      TF_ASSIGN_OR_RETURN(
          auto sharded_conv,
          dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
              *hlo, dims_info, lhs_hlo, rhs_hlo));
      return b->AddInstruction(std::move(sharded_conv));
    } else {
      TF_ASSIGN_OR_RETURN(auto sharded_conv,
                          CreateShardedConvConvolution(*hlo, dims_info, lhs_hlo,
                                                       rhs_hlo, conv_window));
      return b->AddInstruction(std::move(sharded_conv));
    }
  };

  return HandleDotHelper(hlo, mapping, create_sharded_conv);
}

}  // namespace spmd
}  // namespace xla