/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/spmd/convolution_handler.h"

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/numbers.h"

namespace xla {
namespace spmd {

namespace {

// Partition convolution with batch group count.
StatusOr<HloInstruction*> PartitionConvolutionWithBatchGroupCount(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    int64_t num_partitions, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
  if (original_hlo->batch_group_count() == 1 ||
      original_hlo->batch_group_count() % num_partitions != 0) {
    return nullptr;
  }

  const auto& dnums = original_hlo->convolution_dimension_numbers();
  // Only supports the case where batch_group_size equals input_batch_size.
  const int64_t input_batch_size =
      lhs.base_shape().dimensions(dnums.input_batch_dimension());
  const int64_t kernel_output_feature_size =
      rhs.base_shape().dimensions(dnums.kernel_output_feature_dimension());
  if (input_batch_size != kernel_output_feature_size ||
      original_hlo->batch_group_count() != input_batch_size) {
    return nullptr;
  }

  // Map RHS indices to LHS indices.
  std::vector<int64_t> rhs_to_lhs_indices(output_base_shape.rank());
  rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
      dnums.input_batch_dimension();
  rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
      dnums.input_feature_dimension();
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
        dnums.input_spatial_dimensions(i);
  }

  // Map LHS indices to RHS indices.
  std::vector<int64_t> lhs_to_rhs_indices(output_base_shape.rank());
  for (int64_t i = 0; i < rhs_to_lhs_indices.size(); ++i) {
    lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
  }

  // Map LHS indices to output indices.
  std::vector<int64_t> lhs_to_output_indices(lhs.base_shape().rank(), -1);
  lhs_to_output_indices[dnums.input_batch_dimension()] =
      dnums.output_feature_dimension();
  lhs_to_output_indices[dnums.input_feature_dimension()] =
      dnums.output_batch_dimension();
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    lhs_to_output_indices[dnums.input_spatial_dimensions(i)] =
        dnums.output_spatial_dimensions(i);
  }
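  // Illustrative example (dimension numbers assumed for illustration): with
  // input dims {batch=0, feature=1, spatial={2, 3}} and kernel dims
  // {input_feature=0, output_feature=1, spatial={2, 3}}, the maps above are
  //   rhs_to_lhs_indices = {1, 0, 2, 3}   (kernel dim -> input dim)
  //   lhs_to_rhs_indices = {1, 0, 2, 3}   (its inverse)
  // and lhs_to_output_indices sends the input batch dim to the output
  // feature dim, reflecting that a batch-group-count convolution emits one
  // output feature per input batch element.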

  // Align LHS or RHS to the other operand if the input batch dim or kernel
  // output feature dim is partitioned.
  auto aligned_rhs_sharding =
      hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices);
  auto aligned_lhs_sharding =
      hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);

  bool lhs_batch_dim_is_partitioned =
      (ShardCountAtDim(lhs.sharding(), dnums.input_batch_dimension()) ==
       num_partitions);
  bool rhs_output_feature_dim_is_partitioned =
      (ShardCountAtDim(rhs.sharding(),
                       dnums.kernel_output_feature_dimension()) ==
       num_partitions);
  if (!lhs_batch_dim_is_partitioned && !rhs_output_feature_dim_is_partitioned) {
    return nullptr;
  }
  // Reshard LHS or RHS so that its batch dimension or output feature
  // dimension is partitioned consistently with the other operand.
  if (lhs_batch_dim_is_partitioned) {
    rhs = rhs.Reshard(aligned_rhs_sharding);
  } else {
    lhs = lhs.Reshard(aligned_lhs_sharding);
  }
  // Align output sharding after LHS and RHS sharding are consistent.
  auto aligned_output_sharding = hlo_sharding_util::TransposeSharding(
      lhs.sharding(), lhs_to_output_indices);

  // Create partitioned convolution.
  TF_ASSIGN_OR_RETURN(
      auto sharded_conv,
      create_sharded_conv(lhs.hlo(), rhs.hlo(), b, conv_window));
  sharded_conv->set_sharding(aligned_output_sharding);
  return PartitionedHlo(sharded_conv, output_base_shape, lhs.state())
      .Reshard(output_sharding)
      .hlo();
}

// Partition convolution with feature group count.
StatusOr<HloInstruction*> PartitionConvolutionWithFeatureGroupCount(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    int64_t num_partitions, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
  if (original_hlo->feature_group_count() == 1 ||
      original_hlo->feature_group_count() % num_partitions != 0) {
    return nullptr;
  }

  const auto& dnums = original_hlo->convolution_dimension_numbers();
  const int64_t input_feature_size =
      lhs.base_shape().dimensions(dnums.input_feature_dimension());
  const int64_t kernel_output_feature_size =
      rhs.base_shape().dimensions(dnums.kernel_output_feature_dimension());
  if (kernel_output_feature_size % original_hlo->feature_group_count() != 0 ||
      input_feature_size % original_hlo->feature_group_count() != 0) {
    return nullptr;
  }

  // Align RHS indices to LHS.
  std::vector<int64_t> rhs_to_lhs_indices(output_base_shape.rank());
  rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
      dnums.input_feature_dimension();
  rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
      dnums.input_batch_dimension();
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
        dnums.input_spatial_dimensions(i);
  }

  // Align LHS indices to RHS.
  std::vector<int64_t> lhs_to_rhs_indices(output_base_shape.rank());
  for (int64_t i = 0; i < rhs_to_lhs_indices.size(); ++i) {
    lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
  }

  // Align LHS indices to output.
  std::vector<int64_t> lhs_to_output_indices(output_base_shape.rank());
  lhs_to_output_indices[dnums.input_feature_dimension()] =
      dnums.output_feature_dimension();
  lhs_to_output_indices[dnums.input_batch_dimension()] =
      dnums.output_batch_dimension();
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    lhs_to_output_indices[dnums.input_spatial_dimensions(i)] =
        dnums.output_spatial_dimensions(i);
  }

  // Align LHS or RHS if input_feature_dim or kernel_output_feature_dim is
  // partitioned.
  auto aligned_rhs_sharding =
      hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices);
  auto aligned_lhs_sharding =
      hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);

  bool lhs_feature_dim_is_partitioned =
      (ShardCountAtDim(lhs.sharding(), dnums.input_feature_dimension()) ==
       num_partitions);
  bool rhs_output_feature_dim_is_partitioned =
      (ShardCountAtDim(rhs.sharding(),
                       dnums.kernel_output_feature_dimension()) ==
       num_partitions);
  if (!lhs_feature_dim_is_partitioned &&
      !rhs_output_feature_dim_is_partitioned) {
    return nullptr;
  }
  // Reshard LHS or RHS so that its input feature dimension or output feature
  // dimension is partitioned consistently with the other operand.
  if (lhs_feature_dim_is_partitioned) {
    rhs = rhs.Reshard(aligned_rhs_sharding);
  } else {
    lhs = lhs.Reshard(aligned_lhs_sharding);
  }

  // Align output sharding after LHS and RHS sharding are consistent.
  auto aligned_output_sharding = hlo_sharding_util::TransposeSharding(
      lhs.sharding(), lhs_to_output_indices);

  TF_ASSIGN_OR_RETURN(
      auto sharded_conv,
      create_sharded_conv(lhs.hlo(), rhs.hlo(), b, conv_window));
  sharded_conv->set_sharding(aligned_output_sharding);
  return PartitionedHlo(sharded_conv, output_base_shape, lhs.state())
      .Reshard(output_sharding)
      .hlo();
}

// Partition convolution when both LHS and RHS are partitioned at spatial
// dimensions. Halo exchange will happen on RHS only.
StatusOr<HloInstruction*>
PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
  TF_RET_CHECK(!lhs.sharding().IsTileMaximal() &&
               !rhs.sharding().IsTileMaximal());

  const auto& dnums = original_hlo->convolution_dimension_numbers();
  std::vector<int64_t> rhs_to_lhs_indices(output_base_shape.rank());
  rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
      dnums.input_batch_dimension();
  rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
      dnums.input_feature_dimension();
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
        dnums.input_spatial_dimensions(i);
  }
  std::vector<int64_t> lhs_to_rhs_indices(output_base_shape.rank());
  for (int64_t i = 0; i < rhs_to_lhs_indices.size(); ++i) {
    lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
  }
  auto aligned_rhs_sharding =
      hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices);
  auto aligned_lhs_sharding =
      hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);

  auto unsupported_sharding = [&](const HloSharding& lhs_sharding,
                                  const HloSharding& rhs_sharding) {
    // We currently don't support partitioning input batch or output feature
    // dimensions.
    return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) !=
               1 ||
           rhs_sharding.tile_assignment().dim(
               dnums.kernel_output_feature_dimension()) != 1;
  };

  if (ShapeSizeInBytes(lhs.base_shape()) < ShapeSizeInBytes(rhs.base_shape())) {
    if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) {
      return nullptr;
    }
    lhs = lhs.Reshard(aligned_lhs_sharding).PadWithZero();
    rhs = rhs.PadWithZero();
  } else {
    if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) {
      return nullptr;
    }
    lhs = lhs.PadWithZero();
    rhs = rhs.Reshard(aligned_rhs_sharding).PadWithZero();
  }

  if (original_hlo->feature_group_count() > 1 &&
      (lhs.sharding().tile_assignment().dim(dnums.input_feature_dimension()) >
           1 ||
       rhs.sharding().tile_assignment().dim(
           dnums.kernel_output_feature_dimension()) > 1)) {
    return nullptr;
  }

  if (original_hlo->batch_group_count() > 1 &&
      (lhs.sharding().tile_assignment().dim(dnums.input_batch_dimension()) >
           1 ||
       rhs.sharding().tile_assignment().dim(
           dnums.kernel_output_feature_dimension()) > 1)) {
    return nullptr;
  }

  // Reshard RHS so that each shard computes the partial sum of the full
  // shape result, and add AllReduce. See HandleConvolutionTiledLhsAndRhs()
  // that reshards LHS.
  //
  // The size of the halo on each dimension can be calculated from the
  // projection onto the RHS that shard i needs to read. RHS and LHS below
  // refer to the shard sizes of RHS and LHS, WC is the number of windows,
  // and D is the window dilation.
  //
  // * offset(i): LHS * i + low_padding - (WC - 1) * stride
  // * limit(i): LHS * (i + 1) + low_padding
  //
  // Since shard i has RHS of range [i * RHS * D, (i + 1) * RHS * D)
  // * left-halo: i * RHS * D - offset(i)
  //              = i * (RHS * D - LHS) + (WC - 1) * stride - low_padding
  // * right-halo: limit(i) - (i + 1) * RHS * D
  //              = (i + 1) * (LHS - RHS * D) + low_padding
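  //
  // A small worked example (all values assumed for illustration): a length-16
  // input and a length-6 kernel split across 2 partitions (per-shard sizes
  // LHS = 8, RHS = 3), with D = 1, stride = 1, low_padding = 1 and WC = 13:
  //   left-halo(i)  = i * (3 - 8) + 12 * 1 - 1 = 11 - 5 * i
  //   right-halo(i) = (i + 1) * (8 - 3) + 1    = 5 * i + 6
  // i.e. later shards need less left halo and more right halo, since each
  // RHS shard is much narrower than the LHS region it convolves against.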
  const auto& collective_ops_creator = lhs.state().collective_ops_creator;
  std::vector<int64_t> shard_counts(dnums.input_spatial_dimensions_size());
  std::vector<int64_t> lhs_shard_sizes(dnums.input_spatial_dimensions_size());
  std::vector<int64_t> rhs_shard_sizes(dnums.input_spatial_dimensions_size());

  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    int64_t lhs_dimension = dnums.input_spatial_dimensions(i);
    int64_t rhs_dimension = dnums.kernel_spatial_dimensions(i);
    int64_t shard_count = rhs.sharding().tile_assignment().dim(rhs_dimension);
    const auto& wd = conv_window.dimensions(i);
    if (wd.base_dilation() != 1 || wd.window_reversal()) {
      return nullptr;
    }

    int64_t lhs_shard_size =
        CeilOfRatio(lhs.base_shape().dimensions(lhs_dimension), shard_count);
    int64_t rhs_shard_size =
        CeilOfRatio(rhs.base_shape().dimensions(rhs_dimension), shard_count);
    shard_counts[i] = shard_count;
    lhs_shard_sizes[i] = lhs_shard_size;
    rhs_shard_sizes[i] = rhs_shard_size;
  }

  std::vector<OffsetCalculation> left_halo_size_functions(
      output_base_shape.rank());
  std::vector<OffsetCalculation> right_halo_size_functions(
      output_base_shape.rank());
  Window new_window = conv_window;

  // Data structures needed for Pad and DynamicSlice on LHS if needed.
  bool need_dynamic_slice_lhs = false;
  auto partition_ordinals =
      MakeTiledPartitionOrdinals(lhs.sharding(), partition_id, b);
  std::vector<int64_t> zero_padding(output_base_shape.rank());
  PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding);
  auto zero_s32 =
      b->AddInstruction(HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
  std::vector<HloInstruction*> dynamic_slice_start_indices(
      output_base_shape.rank(), zero_s32);
  Shape dynamic_slice_shape = lhs.hlo()->shape();
  Shape pad_shape = lhs.hlo()->shape();

  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    int64_t lhs_dimension = dnums.input_spatial_dimensions(i);
    int64_t rhs_dimension = dnums.kernel_spatial_dimensions(i);
    int64_t lhs_shard_size = lhs_shard_sizes[i];
    int64_t rhs_shard_size = rhs_shard_sizes[i];

    if (shard_counts[i] == 1) {
      continue;
    }

    // Calculate the left and right halo sizes as described in the comments
    // above. The formulas compute the halo sizes in dilated units, so we
    // apply CeilOfRatio({left,right}_halo_size, window_dilation).
    const auto& wd = conv_window.dimensions(i);
    int64_t padding_low = wd.padding_low();
    int64_t padding_high = wd.padding_high();
    int64_t base = lhs.base_shape().dimensions(lhs_dimension);
    int64_t window_count = 1 + (padding_low + padding_high + base -
                                (1 + (wd.size() - 1) * wd.window_dilation())) /
                                   wd.stride();
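    // window_count is the number of output elements along this dimension for
    // the full (unsharded) shape. For example (values assumed): with
    // base = 13, padding_low = padding_high = 1, size = 3, window_dilation = 1
    // and stride = 2, the dilated window spans 3 elements and
    // window_count = 1 + (1 + 1 + 13 - 3) / 2 = 7.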
    left_halo_size_functions[rhs_dimension] =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            rhs_shard_size * wd.window_dilation() - lhs_shard_size,
            (window_count - 1) * wd.stride() - padding_low +
                wd.window_dilation() - 1,
            wd.window_dilation()));
    right_halo_size_functions[rhs_dimension] =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            lhs_shard_size - rhs_shard_size * wd.window_dilation(),
            lhs_shard_size - rhs_shard_size * wd.window_dilation() +
                padding_low + wd.window_dilation() - 1,
            wd.window_dilation()));

    // New RHS window size includes the maximum of both left and right
    // halos.
    int64_t halo_size =
        left_halo_size_functions[rhs_dimension].MaxInRange(1, shard_counts[i]) +
        right_halo_size_functions[rhs_dimension].MaxInRange(
            0, shard_counts[i] - 1);
    int64_t new_window_size =
        rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size;

    // The amount of new low padding could be dynamic (e.g., window_dilation
    // != 1), which requires pad (to the maximum) and dynamic slice on LHS.
    //
    // If we consider the first window, the offset of the dilated RHS that
    // aligns with the first valid LHS element for shard i is 'padding_low +
    // LHS * i'. When the left halo is added to RHS, the offset of the first
    // RHS element is (RHS * i - left_halo) * window_dilation. The
    // difference between the two values is the amount of padding_low we
    // need on LHS.
    auto new_padding_low_function =
        OffsetCalculation(HloOpcode::kMultiply,
                          left_halo_size_functions[rhs_dimension],
                          OffsetCalculation(MultiplyAddDivideOffsetCalculation(
                              0, wd.window_dilation(), 1))) -
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            rhs_shard_size * wd.window_dilation() - lhs_shard_size,
            -padding_low, 1));

    int64_t new_padding_low_max =
        new_padding_low_function.MaxInRange(0, shard_counts[i]);
    int64_t new_padding_low = new_padding_low_max;
    int64_t new_padding_high = window_count * wd.stride() +
                               (new_window_size - 1) * wd.window_dilation() -
                               new_padding_low - lhs_shard_size;

    // We do pad/dynamic-slice only when the padding is dynamic.
    if (!new_padding_low_function.IsConstant()) {
      need_dynamic_slice_lhs = true;
      new_padding_low = 0;
      pad_config.mutable_dimensions(lhs_dimension)
          ->set_edge_padding_low(new_padding_low_max);
      pad_config.mutable_dimensions(lhs_dimension)
          ->set_edge_padding_high(new_padding_low_max);
      pad_shape.set_dimensions(lhs_dimension,
                               lhs_shard_size + 2 * new_padding_low_max);
      dynamic_slice_start_indices[lhs_dimension] =
          (OffsetCalculation(
               MultiplyAddDivideOffsetCalculation(0, new_padding_low_max, 1)) -
           new_padding_low_function)
              .Calculate(partition_ordinals[lhs_dimension], b);
      dynamic_slice_shape.set_dimensions(lhs_dimension,
                                         lhs_shard_size + new_padding_low_max);
    }

    // Since the convolution RHS operand size increased with halos, adjust
    // the window config accordingly.
    new_window.mutable_dimensions(i)->set_padding_low(new_padding_low);
    new_window.mutable_dimensions(i)->set_padding_high(new_padding_high);
    new_window.mutable_dimensions(i)->set_size(
        rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size);
  }

  HloInstruction* conv_lhs = lhs.hlo();
  if (need_dynamic_slice_lhs) {
    auto zero = b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::Zero(lhs.hlo()->shape().element_type())));
    auto pad = b->AddInstruction(
        HloInstruction::CreatePad(pad_shape, lhs.hlo(), zero, pad_config));
    conv_lhs = b->AddInstruction(HloInstruction::CreateDynamicSlice(
        dynamic_slice_shape, pad, dynamic_slice_start_indices,
        dynamic_slice_shape.dimensions()));
  }

  // Exchange halo and concatenate.
  HloInstruction* rhs_with_halo = rhs.hlo();
  for (int i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
    int64_t dim = dnums.kernel_spatial_dimensions(i);
    int64_t explicit_left_padding_on_full_shape =
        left_halo_size_functions[dim].Calculate(0);
    int64_t shard_size_with_halo = new_window.dimensions(i).size();

    // offset_on_padded_shape and padded_full_shape_size are needed only if
    // we want to mask out-of-range values in ExchangeHaloAndGetValidData().
    // Since the default value for the collective-permute is zero and we also
    // call PadWithZero() on both operands at the beginning, we don't need to
    // mask here.
    //
    // TODO(hyoulkee): Consider removing one of the two PadWithZero() calls
    // if it's always safe.
    auto offset_on_padded_shape =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            rhs_shard_sizes[i], explicit_left_padding_on_full_shape, 1)) -
        left_halo_size_functions[dim];
    int64_t padded_full_shape_size =
        offset_on_padded_shape.Calculate(shard_counts[i] - 1) +
        new_window.dimensions(i).size();
    auto zero = b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::Zero(rhs.hlo()->shape().element_type())));
    auto concat = ExchangeHaloAndGetValidData(
        rhs_with_halo, rhs.base_shape(), left_halo_size_functions[dim],
        right_halo_size_functions[dim], explicit_left_padding_on_full_shape,
        padded_full_shape_size, shard_size_with_halo, dim, rhs.sharding(),
        offset_on_padded_shape.Calculate(partition_ordinals[dim], b), zero,
        partition_ordinals[dim], collective_ops_creator,
        lhs.state().next_channel_id, b,
        /*mask_invalid_region=*/false);
    if (!concat) {
      return nullptr;
    }
    rhs_with_halo = *concat;
  }

  TF_ASSIGN_OR_RETURN(
      auto conv, create_sharded_conv(conv_lhs, rhs_with_halo, b, new_window));

  auto ar = collective_ops_creator.create_cross_partition_all_reduce(
      b, conv, MakeBinaryAdd(original_hlo->shape().element_type(), module), {},
      (*lhs.state().next_channel_id)++);
  ar->set_sharding(HloSharding::Replicate());
  return PartitionedHlo(ar, output_base_shape, lhs.state())
      .Reshard(output_sharding)
      .hlo();
}

// Partition convolution when both LHS and RHS are partitioned at spatial
// dimensions. Halo exchange will happen on LHS only.
StatusOr<HloInstruction*>
PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
  TF_RET_CHECK(!lhs.sharding().IsTileMaximal() &&
               !rhs.sharding().IsTileMaximal());

  const auto& dnums = original_hlo->convolution_dimension_numbers();

  // Check if the operand shardings are aligned. Also we currently don't
  // support partitioning non-spatial dimensions.
  std::vector<int64_t> rhs_to_lhs_indices(output_base_shape.rank());
  rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
      dnums.input_batch_dimension();
  rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
      dnums.input_feature_dimension();
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
        dnums.input_spatial_dimensions(i);
  }
  std::vector<int64_t> lhs_to_rhs_indices(output_base_shape.rank());
  for (int64_t i = 0; i < rhs_to_lhs_indices.size(); ++i) {
    lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
  }

  const Window& window = conv_window;
  std::vector<int64_t> reversed_rhs_dims;
  for (int64_t i = 0; i < window.dimensions_size(); ++i) {
    if (window.dimensions(i).window_reversal()) {
      reversed_rhs_dims.push_back(dnums.kernel_spatial_dimensions(i));
    }
  }
  if (!reversed_rhs_dims.empty()) {
    // Make the reversed dims left-padded to prepare for window reversal.
    auto left_padded_rhs = HaloExchangeToPadOnLeft(rhs, reversed_rhs_dims);
    if (left_padded_rhs == nullptr) {
      return nullptr;
    }
    left_padded_rhs->set_sharding(rhs.sharding());
    rhs = PartitionedHlo(left_padded_rhs, rhs.base_shape(), rhs.state());
  }
  // Consider window reversal when resharding RHS or LHS. Note: this will not
  // reverse the data in the shard. We use window reversal to do that.
  auto aligned_rhs_sharding = hlo_sharding_util::ReverseSharding(
      hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices),
      reversed_rhs_dims);
  auto aligned_lhs_sharding = hlo_sharding_util::TransposeSharding(
      hlo_sharding_util::ReverseSharding(rhs.sharding(), reversed_rhs_dims),
      lhs_to_rhs_indices);

  auto unsupported_sharding = [&](const HloSharding& lhs_sharding,
                                  const HloSharding& rhs_sharding) {
    return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) !=
               1 ||
           rhs_sharding.tile_assignment().dim(
               dnums.kernel_output_feature_dimension()) != 1;
  };

  if (ShapeSizeInBytes(lhs.base_shape()) < ShapeSizeInBytes(rhs.base_shape())) {
    if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) {
      return nullptr;
    }
    lhs = lhs.Reshard(aligned_lhs_sharding).PadWithZero();
    rhs = rhs.PadWithZero(reversed_rhs_dims);
  } else {
    if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) {
      return nullptr;
    }
    lhs = lhs.PadWithZero();
    rhs = rhs.Reshard(aligned_rhs_sharding).PadWithZero(reversed_rhs_dims);
  }

  if (original_hlo->feature_group_count() > 1 &&
      (lhs.sharding().tile_assignment().dim(dnums.input_feature_dimension()) >
           1 ||
       rhs.sharding().tile_assignment().dim(
           dnums.kernel_output_feature_dimension()) > 1)) {
    return nullptr;
  }

  if (original_hlo->batch_group_count() > 1 &&
      (lhs.sharding().tile_assignment().dim(dnums.input_batch_dimension()) >
           1 ||
       rhs.sharding().tile_assignment().dim(
           dnums.kernel_output_feature_dimension()) > 1)) {
    return nullptr;
  }
  // Reshard LHS by exchanging halos such that each shard computes the partial
  // sum of the full shape result, and add an AllReduce.
  //
  // The size of the halo on each dimension can be calculated from the
  // projection onto the LHS that each RHS shard i needs to read. RHS and LHS
  // below refer to the shard sizes of RHS and LHS, WC is the number of
  // windows, and D is the window dilation.
  //
  // * offset(i): RHS * D * i - low_padding
  // * limit(i): {RHS * (i + 1) * D - (D - 1)} + (WC - 1) * stride - low_padding
  //
  // Since shard i has LHS of range [i * LHS, (i + 1) * LHS)
  // * left-halo: i * LHS - offset(i)
  //              = (LHS - RHS * D) * i + low_padding
  // * right-halo: limit(i) - (i + 1) * LHS
  //   = (RHS * D - LHS) * (i + 1) + (1 - D) + (WC - 1) * stride - low_padding
  //   = (RHS * D - LHS) * i + (RHS * D - LHS) + (1 - D)
  //     + (WC - 1) * stride - low_padding
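  //
  // A small worked example (all values assumed for illustration): a length-16
  // input and a length-6 kernel split across 2 partitions (per-shard sizes
  // LHS = 8, RHS = 3), with D = 1, stride = 1, low_padding = 1 and WC = 13:
  //   left-halo(i)  = (8 - 3) * i + 1                    = 5 * i + 1
  //   right-halo(i) = (3 - 8) * i + (3 - 8) + 0 + 12 - 1 = 6 - 5 * i
  // Each shard then holds LHS + left-halo(i) + right-halo(i) = 15 elements,
  // which matches shard_size_with_halo = stride * (WC - 1) + dilated RHS
  // shard size = 12 + 3 as computed below.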
  std::vector<int64_t> shard_counts(dnums.input_spatial_dimensions_size());
  std::vector<int64_t> lhs_shard_sizes(dnums.input_spatial_dimensions_size());
  std::vector<int64_t> rhs_shard_sizes(dnums.input_spatial_dimensions_size());
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    int64_t lhs_dimension = dnums.input_spatial_dimensions(i);
    int64_t rhs_dimension = dnums.kernel_spatial_dimensions(i);
    int64_t shard_count = lhs.sharding().tile_assignment().dim(lhs_dimension);
    const auto& wd = window.dimensions(i);
    if (wd.base_dilation() != 1) {
      // TODO(wangtao): support the parallel dim if it is replicated here.
      return nullptr;
    }

    int64_t lhs_shard_size =
        CeilOfRatio(lhs.base_shape().dimensions(lhs_dimension), shard_count);
    int64_t rhs_shard_size =
        CeilOfRatio(rhs.base_shape().dimensions(rhs_dimension), shard_count);
    shard_counts[i] = shard_count;
    lhs_shard_sizes[i] = lhs_shard_size;
    rhs_shard_sizes[i] = rhs_shard_size;
  }

  std::vector<OffsetCalculation> left_halo_size_functions(
      output_base_shape.rank());
  std::vector<OffsetCalculation> right_halo_size_functions(
      output_base_shape.rank());
  Window new_window = window;

  auto partition_ordinals =
      MakeTiledPartitionOrdinals(lhs.sharding(), partition_id, b);
  HloInstruction* lhs_with_halo = lhs.hlo();
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    int64_t lhs_dimension = dnums.input_spatial_dimensions(i);
    int64_t lhs_shard_size = lhs_shard_sizes[i];
    int64_t rhs_shard_size = rhs_shard_sizes[i];

    if (shard_counts[i] == 1) {
      continue;
    }

    // Calculate the left and right halo sizes as described in the comments
    // above.
    const auto& wd = window.dimensions(i);
    int64_t padding_low = wd.padding_low();
    int64_t padding_high = wd.padding_high();
    int64_t base = lhs.base_shape().dimensions(lhs_dimension);
    int64_t window_count = 1 + (padding_low + padding_high + base -
                                (1 + (wd.size() - 1) * wd.window_dilation())) /
                                   wd.stride();
    int64_t rhs_shard_size_dilated =
        (rhs_shard_size - 1) * wd.window_dilation() + 1;

    left_halo_size_functions[lhs_dimension] =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            lhs_shard_size - rhs_shard_size * wd.window_dilation(), padding_low,
            1));
    right_halo_size_functions[lhs_dimension] =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            rhs_shard_size * wd.window_dilation() - lhs_shard_size,
            rhs_shard_size * wd.window_dilation() - lhs_shard_size + 1 -
                wd.window_dilation() + wd.stride() * (window_count - 1) -
                padding_low,
            1));

    // Exchange halo and concatenate.
    int64_t dim = dnums.input_spatial_dimensions(i);
    int64_t explicit_left_padding_on_full_shape = padding_low;
    int64_t shard_size_with_halo =
        wd.stride() * (window_count - 1) + rhs_shard_size_dilated;

    new_window.mutable_dimensions(i)->set_padding_low(0);
    new_window.mutable_dimensions(i)->set_padding_high(0);
    new_window.mutable_dimensions(i)->set_size(rhs_shard_size);

    // offset_on_padded_shape and padded_full_shape_size are needed only if
    // we want to mask out-of-range values in ExchangeHaloAndGetValidData().
    // Since the default value for the collective-permute is zero and we also
    // call PadWithZero() on both operands at the beginning, we don't need to
    // mask here.
    //
    // TODO(hyoulkee): Consider removing one of the two PadWithZero() calls
    // if it's always safe.
    auto offset_on_padded_shape =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation());
    int64_t padded_full_shape_size = 0;

    auto zero = b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::Zero(lhs.hlo()->shape().element_type())));
    auto concat = ExchangeHaloAndGetValidData(
        lhs_with_halo, lhs.base_shape(), left_halo_size_functions[dim],
        right_halo_size_functions[dim], explicit_left_padding_on_full_shape,
        padded_full_shape_size, shard_size_with_halo, dim, lhs.sharding(),
        offset_on_padded_shape.Calculate(partition_ordinals[dim], b), zero,
        partition_ordinals[dim], lhs.state().collective_ops_creator,
        lhs.state().next_channel_id, b,
        /*mask_invalid_region=*/false);
    if (!concat) {
      return nullptr;
    }
    lhs_with_halo = *concat;
  }

  TF_ASSIGN_OR_RETURN(
      auto conv, create_sharded_conv(lhs_with_halo, rhs.hlo(), b, new_window));
  auto ar =
      lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
          b, conv, MakeBinaryAdd(output_base_shape.element_type(), module), {},
          (*lhs.state().next_channel_id)++);
  ar->set_sharding(HloSharding::Replicate());
  return PartitionedHlo(ar, output_base_shape, lhs.state())
      .Reshard(output_sharding)
      .hlo();
}

// Partition convolution when output is sharded. Will shard LHS with replicated
// RHS.
StatusOr<HloInstruction*> PartitionConvolutionTiledOutput(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
  const auto& dnums = original_hlo->convolution_dimension_numbers();
  TF_RET_CHECK(!output_sharding.IsTileMaximal());
  // We don't currently support sharding on output feature dimension.
  if (output_sharding.tile_assignment().dim(dnums.output_feature_dimension()) >
      1) {
    return nullptr;
  }

  // Check if the operand and the output sharding are aligned.
  std::vector<int64_t> input_to_output_indices(output_base_shape.rank());
  input_to_output_indices[dnums.input_batch_dimension()] =
      dnums.output_batch_dimension();
  input_to_output_indices[dnums.input_feature_dimension()] =
      dnums.output_feature_dimension();
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    input_to_output_indices[dnums.input_spatial_dimensions(i)] =
        dnums.output_spatial_dimensions(i);
  }
  auto target_operand_sharding = hlo_sharding_util::TransposeSharding(
      output_sharding, input_to_output_indices);
  lhs = lhs.Reshard(target_operand_sharding);

  // Replicate the RHS.
  rhs = rhs.Reshard(HloSharding::Replicate());

  // Convolution window config does not include batch and feature dimensions,
  // whereas ReshardAsWindowedInput() expects the same number of window
  // dimensions as the rank of the operand. So add two more trivial
  // dimensions.
  std::vector<int64_t> ones(output_base_shape.rank(), 1);
  auto operand_window = window_util::MakeWindow(ones);
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    *operand_window.mutable_dimensions(dnums.input_spatial_dimensions(i)) =
        conv_window.dimensions(i);
  }
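  // For example, for a rank-4 operand operand_window starts as a trivial
  // 1x1x1x1 window, and only the entries at dnums.input_spatial_dimensions()
  // are overwritten with the real convolution window dimensions; the batch
  // and feature dimensions keep the trivial size-1 window.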

  auto zero = b->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(output_base_shape.element_type())));
  auto resharded_operand_and_window =
      lhs.ReshardAsWindowedInput(operand_window, target_operand_sharding, zero);
  if (!resharded_operand_and_window.has_value()) {
    return nullptr;
  }
  Window new_window;
  for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
    *new_window.add_dimensions() =
        resharded_operand_and_window->shard_window.dimensions(
            dnums.input_spatial_dimensions(i));
  }

  TF_ASSIGN_OR_RETURN(
      auto sharded_conv,
      create_sharded_conv(resharded_operand_and_window->sharded_input,
                          rhs.hlo(), b, new_window));

  auto shard_shape = MakePartitionedShape(output_base_shape, output_sharding);
  if (!resharded_operand_and_window->dynamic_slice_index_on_output
           .has_value()) {
    CHECK(ShapeUtil::Compatible(shard_shape, sharded_conv->shape()));
    return sharded_conv;
  }
  return b->AddInstruction(HloInstruction::CreateDynamicSlice(
      shard_shape, sharded_conv,
      *resharded_operand_and_window->dynamic_slice_index_on_output,
      shard_shape.dimensions()));
}

// Partition convolution when only one kind of dimension is partitioned.
StatusOr<HloInstruction*> PartitionConvolutionBaseCase(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    int64_t num_partitions, const SpmdPartitionerOptions& options,
    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);

  // Case 1: Handle depthwise convolution with batch group count or
  // feature group count.
  if (original_hlo->batch_group_count() > 1) {
    TF_ASSIGN_OR_RETURN(
        auto parallel_partitioned_conv,
        PartitionConvolutionWithBatchGroupCount(
            lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
            conv_window, original_hlo, num_partitions, b));
    if (parallel_partitioned_conv) {
      return parallel_partitioned_conv;
    }
  }

  if (original_hlo->feature_group_count() > 1) {
    TF_ASSIGN_OR_RETURN(
        auto parallel_partitioned_conv,
        PartitionConvolutionWithFeatureGroupCount(
            lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
            conv_window, original_hlo, num_partitions, b));
    if (parallel_partitioned_conv) {
      return parallel_partitioned_conv;
    }
  }

  // Case 2: both RHS and LHS are tiled.
  // Handle cases where both operands' shardings are aligned. We check that
  // the LHS batch dimension is not partitioned because it is mapped to the
  // output feature dimension in aligned_rhs_sharding, which is not the same
  // dimension.
  if (!lhs.sharding().IsTileMaximal() && !rhs.sharding().IsTileMaximal()) {
    if (options.conv_halo_exchange_always_on_lhs) {
      TF_ASSIGN_OR_RETURN(
          auto partitioned_conv,
          PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS(
              lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
              conv_window, original_hlo, partition_id, module, b));
      if (partitioned_conv) {
        return partitioned_conv;
      }
    } else {
      TF_ASSIGN_OR_RETURN(
          auto partitioned_conv,
          PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS(
              lhs, rhs, output_base_shape, output_sharding, create_sharded_conv,
              conv_window, original_hlo, partition_id, module, b));

      if (partitioned_conv) {
        return partitioned_conv;
      }
    }
  }

  // Case 3: output is tiled.
  if (!output_sharding.IsTileMaximal()) {
    TF_ASSIGN_OR_RETURN(auto partitioned_conv,
                        PartitionConvolutionTiledOutput(
                            lhs, rhs, output_base_shape, output_sharding,
                            create_sharded_conv, conv_window, original_hlo, b));

    if (partitioned_conv) {
      return partitioned_conv;
    }
  }
  return nullptr;
}

StatusOr<std::unique_ptr<HloInstruction>> CreateShardedConvConvolution(
    const HloInstruction& conv,
    const dot_as_convolution_util::DotConvolutionDimsInfo& dot_dnums,
    HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo,
    const Window& conv_window) {
  CHECK_EQ(conv.opcode(), HloOpcode::kConvolution);
  const auto& conv_dnums = conv.convolution_dimension_numbers();
  auto window = conv.window();
  for (const auto& dim : dot_dnums.batch_dims) {
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_lhs_hlo->shape().dimensions(
        conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
    wd->set_stride(std::max<int64_t>(1, wd->size() - 1));
    wd->set_base_dilation(wd->size());
  }
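  // Note on the window settings above: for a batch dimension of size n
  // (n > 1) expressed as a spatial dimension, a window of size n with base
  // dilation n and stride n - 1 produces exactly n output positions, each
  // covering exactly one non-dilated input element, which mimics a batch
  // dimension of the equivalent dot.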
  for (const auto& dim : dot_dnums.contracting_dims) {
    if (dim.spatial_dim < 0) {
      continue;
    }
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_lhs_hlo->shape().dimensions(
        conv_dnums.input_spatial_dimensions(dim.spatial_dim)));
  }
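  // Note: a window that spans the full (sharded) length of a contracting
  // dimension yields a single output position that sums products over the
  // whole dimension, i.e. the dimension is contracted away.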
  for (const auto& dim : dot_dnums.rhs_non_contracting_dims) {
    if (dim.spatial_dim < 0) {
      continue;
    }
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    wd->set_size(sharded_rhs_hlo->shape().dimensions(
        conv_dnums.kernel_spatial_dimensions(dim.spatial_dim)));
    wd->set_padding_high(wd->size() - 1);
    wd->set_padding_low(wd->size() - 1);
  }
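  // Note: with padding of size - 1 on both sides and a size-1 LHS dimension,
  // the number of output positions is (1 + 2 * (size - 1)) - size + 1 = size,
  // so each kernel element maps to its own output element, as expected for an
  // RHS non-contracting dimension.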

  for (const auto& dim : dot_dnums.conv_spatial_dims) {
    auto wd = window.mutable_dimensions(dim.spatial_dim);
    const auto& new_window_dimension = conv_window.dimensions(dim.spatial_dim);
    wd->set_size(new_window_dimension.size());
    wd->set_padding_high(new_window_dimension.padding_high());
    wd->set_padding_low(new_window_dimension.padding_low());
  }

  int64_t feature_group_count = conv.feature_group_count();
  if (feature_group_count > 1) {
    feature_group_count = sharded_lhs_hlo->shape().dimensions(
                              conv_dnums.input_feature_dimension()) /
                          sharded_rhs_hlo->shape().dimensions(
                              conv_dnums.kernel_input_feature_dimension());
  }
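  // For example (values assumed): with 64 input features and
  // feature_group_count = 8, the kernel has 64 / 8 = 8 input features. If the
  // input features are sharded 2 ways while the kernel's input-feature
  // dimension is unsharded, the sharded group count becomes 32 / 8 = 4.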

  int64_t batch_group_count = conv.batch_group_count();
  if (batch_group_count > 1) {
    batch_group_count =
        sharded_lhs_hlo->shape().dimensions(conv_dnums.input_batch_dimension());
  }

  TF_ASSIGN_OR_RETURN(
      Shape sharded_conv_shape,
      ShapeInference::InferConvolveShape(
          sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
          feature_group_count, batch_group_count, window, conv_dnums,
          /*preferred_element_type=*/conv.shape().element_type()));
  *sharded_conv_shape.mutable_layout() = conv.shape().layout();
  return HloInstruction::CreateConvolve(
      sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo, feature_group_count,
      batch_group_count, window, conv_dnums, conv.precision_config());
}

}  // namespace

// Partition convolution.
StatusOr<HloInstruction*> PartitionConvolution(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_conv,
    const Window& conv_window, HloInstruction* original_hlo,
    int64_t num_partitions, const SpmdPartitionerOptions& options,
    HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) {
  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);

  TF_ASSIGN_OR_RETURN(auto try_partitioned_conv,
                      PartitionConvolutionBaseCase(
                          lhs, rhs, output_base_shape, output_sharding,
                          create_sharded_conv, conv_window, original_hlo,
                          num_partitions, options, partition_id, module, b));
  if (try_partitioned_conv) {
    return try_partitioned_conv;
  }

  return nullptr;
}

Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) {
  if (hlo->sharding().HasUniqueDevice()) {
    return DefaultAction(hlo);
  }
  auto dims_info = dot_as_convolution_util::ParseConvolutionDimsInfo(hlo);
  spmd::DotConvDimsMapping mapping;
  for (const auto& dims : dims_info.batch_dims) {
    mapping.batch_dims.emplace_back();
    mapping.batch_dims.back().lhs = dims.lhs;
    mapping.batch_dims.back().rhs = dims.rhs;
    mapping.batch_dims.back().output = dims.output;
    mapping.batch_dims.back().spatial = dims.spatial_dim;
  }
  for (const auto& dims : dims_info.contracting_dims) {
    mapping.contracting_dims.emplace_back();
    mapping.contracting_dims.back().lhs = dims.lhs;
    mapping.contracting_dims.back().rhs = dims.rhs;
    mapping.contracting_dims.back().output = dims.output;
    mapping.contracting_dims.back().spatial = dims.spatial_dim;
  }
  for (const auto& dims : dims_info.lhs_non_contracting_dims) {
    mapping.lhs_non_contracting_dims.emplace_back();
    mapping.lhs_non_contracting_dims.back().lhs = dims.lhs;
    mapping.lhs_non_contracting_dims.back().rhs = dims.rhs;
    mapping.lhs_non_contracting_dims.back().output = dims.output;
    mapping.lhs_non_contracting_dims.back().spatial = dims.spatial_dim;
  }
  for (const auto& dims : dims_info.rhs_non_contracting_dims) {
    mapping.rhs_non_contracting_dims.emplace_back();
    mapping.rhs_non_contracting_dims.back().lhs = dims.lhs;
    mapping.rhs_non_contracting_dims.back().rhs = dims.rhs;
    mapping.rhs_non_contracting_dims.back().output = dims.output;
    mapping.rhs_non_contracting_dims.back().spatial = dims.spatial_dim;
  }
  for (const auto& dims : dims_info.conv_spatial_dims) {
    mapping.conv_spatial_dims.emplace_back();
    mapping.conv_spatial_dims.back().lhs = dims.lhs;
    mapping.conv_spatial_dims.back().rhs = dims.rhs;
    mapping.conv_spatial_dims.back().output = dims.output;
    mapping.conv_spatial_dims.back().spatial = dims.spatial_dim;
  }
  auto create_sharded_conv =
      [&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo,
          spmd::SpmdBuilder* b,
          const Window& conv_window) -> StatusOr<HloInstruction*> {
    if (dims_info.conv_spatial_dims.empty() &&
        hlo->feature_group_count() == 1 && hlo->batch_group_count() == 1) {
      TF_ASSIGN_OR_RETURN(
          auto sharded_conv,
          dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
              *hlo, dims_info, lhs_hlo, rhs_hlo));
      return b->AddInstruction(std::move(sharded_conv));
    } else {
      TF_ASSIGN_OR_RETURN(auto sharded_conv,
                          CreateShardedConvConvolution(*hlo, dims_info, lhs_hlo,
                                                       rhs_hlo, conv_window));
      return b->AddInstruction(std::move(sharded_conv));
    }
  };

  return HandleDotHelper(hlo, mapping, create_sharded_conv);
}

}  // namespace spmd
}  // namespace xla