/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"

#include <algorithm>
#include <memory>
#include <optional>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace spmd {

namespace {
using hlo_sharding_util::GroupedSharding;
}  // namespace
bool HasReplicatedSharding(const HloSharding& sharding) {
  if (sharding.IsTuple()) {
    return absl::c_any_of(sharding.tuple_elements(), HasReplicatedSharding);
  }
  return sharding.IsReplicated();
}

HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) {
  HloComputation::Builder sum_b("add");
  auto x = sum_b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, ShapeUtil::MakeShape(type, {}), "x"));
  auto y = sum_b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, ShapeUtil::MakeShape(type, {}), "y"));
  if (type == PRED) {
    sum_b.AddInstruction(HloInstruction::CreateBinary(
        ShapeUtil::MakeShape(type, {}), HloOpcode::kOr, x, y));
  } else {
    sum_b.AddInstruction(HloInstruction::CreateBinary(
        ShapeUtil::MakeShape(type, {}), HloOpcode::kAdd, x, y));
  }
  HloComputation* reduction = module->AddEmbeddedComputation(sum_b.Build());
  return reduction;
}
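
// Illustrative sketch (not part of the original file): the computation built
// above is typically used as the reduction of a cross-partition all-reduce.
// For F32 it is equivalent to the HLO
//   add { x = f32[] parameter(0); y = f32[] parameter(1);
//         ROOT s = f32[] add(x, y) }
// while for PRED the body uses `or`, which acts as the natural "sum" for
// booleans.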

bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding) {
  if (sharding.IsTuple()) {
    for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      if (!EvenlyPartitions(ShapeUtil::GetTupleElementShape(shape, i),
                            sharding.GetSubSharding(shape, {i}))) {
        return false;
      }
    }
  }

  if (sharding.IsTileMaximal()) {
    return sharding.IsReplicated();
  }
  for (int64_t i = 0; i < shape.dimensions_size(); ++i) {
    if (shape.dimensions(i) % sharding.tile_assignment().dim(i) != 0) {
      return false;
    }
  }
  return true;
}
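
// Worked example (illustrative, not in the original): for a shape with
// dimensions {6, 8} and a tile assignment of {4, 1}, 6 % 4 != 0, so the
// partitioning is uneven and the function returns false; with tiles {2, 1},
// both 6 % 2 == 0 and 8 % 1 == 0 hold, so it returns true.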

Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) {
  if (sharding.IsTuple()) {
    std::vector<Shape> subshapes;
    const int64_t shape_n = ShapeUtil::TupleElementCount(shape);
    subshapes.reserve(shape_n);
    for (int64_t i = 0; i < shape_n; ++i) {
      subshapes.push_back(
          MakePartitionedShape(ShapeUtil::GetTupleElementShape(shape, i),
                               sharding.GetSubSharding(shape, {i})));
    }
    return ShapeUtil::MakeTupleShape(subshapes);
  }
  return sharding.TileShape(shape);
}
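
// Illustrative example (assuming TileShape's round-up behavior, not part of
// the original file): TileShape rounds up on unevenly divisible dimensions,
// so f32[7,8] tiled {2, 1} yields a per-partition shape of f32[4,8].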

int64_t ShapeSizeInBytes(const Shape& shape) {
  return ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()) *
         ShapeUtil::ElementsIn(shape);
}

Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
                                          const HloSharding& sharding,
                                          int64_t partition_id) {
  if (sharding.IsTuple()) {
    std::vector<Shape> subshapes;
    const int64_t shape_n = ShapeUtil::TupleElementCount(shape);
    subshapes.reserve(shape_n);
    for (int64_t i = 0; i < shape_n; ++i) {
      subshapes.push_back(MakeNonPaddedShapeForGivenPartition(
          ShapeUtil::GetTupleElementShape(shape, i),
          sharding.GetSubSharding(shape, {i}), partition_id));
    }
    return ShapeUtil::MakeTupleShape(subshapes);
  }

  if (sharding.IsReplicated()) {
    return shape;
  }
  if (sharding.IsTileMaximal()) {
    if (partition_id == *sharding.UniqueDevice()) {
      return shape;
    }
    return ShapeUtil::MakeTupleShape({});
  }

  auto partition_shape = shape;
  std::vector<int64_t> tile_offset =
      sharding.TileOffsetForDevice(shape, partition_id);
  std::vector<int64_t> tile_limit =
      sharding.TileLimitForDevice(shape, partition_id);
  for (int64_t i = 0; i < tile_offset.size(); ++i) {
    if (sharding.UsesDevice(partition_id)) {
      partition_shape.set_dimensions(i, tile_limit[i] - tile_offset[i]);
    } else {
      partition_shape.set_dimensions(i, 0);
    }
  }
  return partition_shape;
}
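
// Worked example (illustrative, not in the original): f32[7] sharded two
// ways has tile offsets {0} and {4} and limits {4} and {7}, so partition 0
// holds f32[4] while partition 1 holds the non-padded remainder f32[3].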

std::vector<HloInstruction*> MakePartitionOffsets(
    const Shape& shape, const HloSharding& sharding,
    HloInstruction* partition_id, SpmdBuilder* b,
    absl::Span<const int64_t> dims) {
  CHECK(!shape.IsTuple());

  std::vector<std::vector<int32_t>> offset_arrays(shape.rank());
  for (int64_t i = 0; i < shape.rank(); ++i) {
    offset_arrays[i].resize(sharding.tile_assignment().num_elements());
  }
  auto shard_shape = MakePartitionedShape(shape, sharding);
  sharding.tile_assignment().Each(
      [&](absl::Span<const int64_t> indices, int64_t device) {
        for (int64_t i = 0; i < shape.rank(); ++i) {
          offset_arrays[i][device] = indices[i] * shard_shape.dimensions(i);
        }
      });
  std::vector<HloInstruction*> offsets;
  for (int64_t i = 0; i < shape.rank(); ++i) {
    if (sharding.tile_assignment().dim(i) == 1 ||
        (!dims.empty() && !absl::c_linear_search(dims, i))) {
      offsets.push_back(b->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(S32))));
    } else {
      auto offset_table = b->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<int32_t>(offset_arrays[i])));
      auto index = b->AddInstruction(HloInstruction::CreateDynamicSlice(
          ShapeUtil::MakeShape(S32, {1}), offset_table, {partition_id}, {1}));
      offsets.push_back(b->AddInstruction(
          HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), index)));
    }
  }
  return offsets;
}
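
// Worked example (illustrative, not in the original; assumes an identity
// device order in the tile assignment): for f32[8] tiled four ways, the
// per-shard size is 2 and the emitted HLO amounts to
//   table = s32[4] constant({0, 2, 4, 6})
//   off   = s32[1] dynamic-slice(table, partition-id), reshaped to s32[]
// so each partition looks up its own start offset at run time.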

std::vector<HloInstruction*> MakeTiledPartitionOrdinals(
    const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) {
  CHECK(!sharding.IsTileMaximal());
  auto dimensions = sharding.tile_assignment().dimensions();
  if (sharding.ReplicateOnLastTileDim()) {
    dimensions.pop_back();
  }
  auto table_shape = ShapeUtil::MakeShape(S32, dimensions);
  return MakePartitionOffsets(table_shape, sharding, partition_id, b);
}

Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape,
                                          const HloSharding& sharding) {
  if (sharding.IsTileMaximal()) {
    return base_shape;
  }
  if (EvenlyPartitions(base_shape, sharding)) {
    return base_shape;
  }
  auto shard_shape = MakePartitionedShape(base_shape, sharding);
  Shape padded_base_shape = base_shape;
  for (int64_t i = 0; i < padded_base_shape.rank(); ++i) {
    padded_base_shape.set_dimensions(
        i, shard_shape.dimensions(i) * sharding.tile_assignment().dim(i));
  }
  return padded_base_shape;
}
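
// Worked example (illustrative, not in the original): f32[7] tiled two ways
// has a shard shape of f32[4], so the padded base shape is f32[4 * 2] =
// f32[8]; an evenly partitioned shape is returned unchanged.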

std::optional<HloSharding> PartialReplicateReshardCompatibleSharding(
    const HloSharding& partial_sharding, const HloSharding& target_sharding) {
  if (!partial_sharding.ReplicateOnLastTileDim()) {
    return std::nullopt;
  }
  int64_t rank = partial_sharding.tile_assignment().num_dimensions() - 1;
  int64_t target_rank = target_sharding.tile_assignment().num_dimensions() -
                        (target_sharding.ReplicateOnLastTileDim() ? 1 : 0);
  if (target_rank != rank) {
    return std::nullopt;
  }

  absl::flat_hash_map<int64_t, int64_t> device_to_replication_group;
  partial_sharding.tile_assignment().Each(
      [&](absl::Span<const int64_t> indices, int64_t device) {
        int64_t gid = 0;
        for (int64_t i = 0; i < rank; ++i) {
          gid *= partial_sharding.tile_assignment().dim(i);
          gid += indices[i];
        }
        device_to_replication_group[device] = gid;
      });

  // A dimension is expanded when target_tile_size > partial_tile_size and
  // target_tile_size % partial_tile_size == 0.
  // expand_tile_dims_indices[dim] is the position of dim among the expanded
  // dimensions, or -1 if dim is not expanded.
  std::vector<int64_t> expand_tile_dims_indices(rank, -1);
  // expand_tile_size = target_tile_size / partial_tile_size.
  std::vector<int64_t> expand_tile_sizes;
  int num_expand_dims = 0;
  for (int64_t dim = 0; dim < rank; dim++) {
    int64_t partial_tile_size = partial_sharding.tile_assignment().dim(dim);
    int64_t target_tile_size = target_sharding.tile_assignment().dim(dim);
    if (target_tile_size % partial_tile_size != 0 ||
        target_tile_size < partial_tile_size) {
      return std::nullopt;
    }

    if (target_tile_size > partial_tile_size) {
      expand_tile_dims_indices[dim] = num_expand_dims++;
      expand_tile_sizes.emplace_back(target_tile_size / partial_tile_size);
    }
  }

  // Reshape the partial replicate tile_dimensions.
  int64_t num_target_replication = 1;
  if (target_sharding.ReplicateOnLastTileDim()) {
    num_target_replication =
        target_sharding.tile_assignment().dimensions().back();
  }
  auto reshape_dimensions = partial_sharding.tile_assignment().dimensions();
  int64_t num_replication = reshape_dimensions.back();
  if (num_replication / num_target_replication != Product(expand_tile_sizes) ||
      num_replication % num_target_replication != 0) {
    return std::nullopt;
  }

  reshape_dimensions.pop_back();
  reshape_dimensions.insert(reshape_dimensions.end(), expand_tile_sizes.begin(),
                            expand_tile_sizes.end());

  if (target_sharding.ReplicateOnLastTileDim()) {
    reshape_dimensions.push_back(num_target_replication);
  }

  auto reshape_tile_assignment = partial_sharding.tile_assignment();
  reshape_tile_assignment.Reshape(reshape_dimensions);

  // Transpose.
  std::vector<int64_t> perm;
  perm.reserve(rank + expand_tile_sizes.size());
  for (int64_t dim = 0; dim < rank; dim++) {
    perm.emplace_back(dim);
    if (expand_tile_dims_indices[dim] > -1) {
      perm.emplace_back(expand_tile_dims_indices[dim] + rank);
    }
  }
  auto transpose_sharding = hlo_sharding_util::TransposeSharding(
      target_sharding.ReplicateOnLastTileDim()
          ? HloSharding::PartialTile(reshape_tile_assignment)
          : HloSharding::Tile(reshape_tile_assignment),
      perm);

  // Reshape to the target shape.
  auto transpose_tile_assignment = transpose_sharding.tile_assignment();
  transpose_tile_assignment.Reshape(
      target_sharding.tile_assignment().dimensions());

  bool groups_matching = true;
  target_sharding.tile_assignment().Each(
      [&](absl::Span<const int64_t> indices, int64_t device) {
        if (device_to_replication_group[device] !=
            device_to_replication_group[transpose_tile_assignment(indices)]) {
          groups_matching = false;
        }
      });

  if (groups_matching) {
    return target_sharding;
  }
  return target_sharding.ReplicateOnLastTileDim()
             ? HloSharding::PartialTile(transpose_tile_assignment)
             : HloSharding::Tile(transpose_tile_assignment);
}
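
// Worked example (illustrative, not in the original): reshaping a partial
// sharding with tile assignment dims {2, 4} (2-way tiled, 4-way replicated)
// to a target with dims {4, 2}: the tiled dimension expands by a factor of
// 2, and num_replication (4) / num_target_replication (2) equals
// Product({2}), so the replication dimension is split into the expansion
// factors, the expanded sub-dimensions are interleaved next to the
// dimensions they expand, and the result is reshaped to the target dims.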

std::optional<HloInstruction*> TileToPartialReplicateHaloExchange(
    HloInstruction* hlo, const Shape& base_shape,
    const HloSharding& src_sharding, const HloSharding& dst_sharding,
    const std::vector<int64_t>& replicate_dims,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {
  // Source is tile sharding.
  auto padded_src_shape =
      GetPaddedShapeForUnevenPartitioning(base_shape, src_sharding);
  // Target is partial replicate.
  auto padded_dst_shape =
      GetPaddedShapeForUnevenPartitioning(base_shape, dst_sharding);
  if (ShapeUtil::Compatible(padded_dst_shape, hlo->shape())) {
    return hlo;
  }

  auto partition_ordinals =
      MakeTiledPartitionOrdinals(dst_sharding, partition_id, b);

  auto result = hlo;
  auto hlo_shape = hlo->shape();
  for (auto dim : replicate_dims) {
    int64_t dst_shard_count = dst_sharding.tile_assignment().dim(dim);
    int64_t src_per_shard_size =
        padded_src_shape.dimensions(dim) / dst_shard_count;
    // Calculate the per-shard size under dst_sharding to check whether
    // dst_sharding needs more padding at the end.
    int64_t dst_per_shard_size =
        padded_dst_shape.dimensions(dim) / dst_shard_count;

    // If the src shards have no redundant data, nothing to exchange.
    if (src_per_shard_size <= dst_per_shard_size || dst_shard_count == 1) {
      continue;
    }

    // If src_per_shard_size > dst_per_shard_size, the data must be
    // re-distributed between the shards using collective permute. For
    // example, if the dimension size is 6, sharded 4 ways in the src but 2
    // ways in the dst: 4-way sharding has 2 elements in each shard while
    // 2-way sharding has 3, so the last element in the first shard would be
    // sliced out and re-distribution is needed.
    //
    // 1. Calculate the left_halo size.
    // left-halo size is
    //   (src_per_shard_size - dst_per_shard_size) * i / replicate_factor
    int64_t replicate_factor = src_sharding.tile_assignment().dim(dim) /
                               dst_sharding.tile_assignment().dim(dim);
    OffsetCalculation left_halo_size_function =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            src_per_shard_size - dst_per_shard_size, 0, replicate_factor));

    // 2. Calculate the right_halo size.
    // right-halo size is 0
    OffsetCalculation right_halo_size_function =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1));

    auto concat = result;
    // 3. Halo exchange.
    auto halo_exchange_result = ExchangeHalo(
        result, left_halo_size_function, right_halo_size_function, dim,
        src_sharding, collective_ops_creator, next_channel_id, b);

    if (halo_exchange_result.has_value()) {
      concat = halo_exchange_result.value();
    } else {
      return std::nullopt;
    }

    // 4. Slice the valid result.
    // The slice offset is
    //   (dst_shard_count - i - 1) * (src_per_shard_size - dst_per_shard_size)
    // where i is the shard index under dst_sharding.
    auto zero_s32 = b->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
    OffsetCalculation start_offset_on_padded_concat_calculation =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            dst_per_shard_size - src_per_shard_size,
            (src_per_shard_size - dst_per_shard_size) * (dst_shard_count - 1),
            1));
    auto slice_shape = concat->shape();
    slice_shape.set_dimensions(dim,
                               padded_src_shape.dimensions(dim) /
                                   src_sharding.tile_assignment().dim(dim));
    std::vector<HloInstruction*> slice_offsets(concat->shape().rank(),
                                               zero_s32);
    slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate(
        partition_ordinals[dim], b);
    result = b->AddInstruction(HloInstruction::CreateDynamicSlice(
        slice_shape, concat, slice_offsets, slice_shape.dimensions()));
  }
  return result;
}

std::optional<HloInstruction*> PadFromPartialReplicateShape(
    HloInstruction* hlo, const Shape& base_shape,
    const HloSharding& src_sharding, const HloSharding& dst_sharding,
    const std::vector<int64_t>& expand_tile_dims,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {
  auto padded_src_shape =
      GetPaddedShapeForUnevenPartitioning(base_shape, src_sharding);
  auto padded_dst_shape =
      GetPaddedShapeForUnevenPartitioning(base_shape, dst_sharding);
  if (ShapeUtil::Compatible(padded_dst_shape, hlo->shape())) {
    return hlo;
  }

  auto partition_ordinals =
      MakeTiledPartitionOrdinals(src_sharding, partition_id, b);

  HloInstruction* result = hlo;
  auto zero = b->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(hlo->shape().element_type())));
  std::vector<int64_t> expand_dims_without_halo_exchange;
  // Pad the dimensions that need halo exchange and record the padded dims
  // that won't need halo exchange.
  for (auto dim : expand_tile_dims) {
    int64_t src_shard_count = src_sharding.tile_assignment().dim(dim);
    int64_t src_per_shard_size =
        padded_src_shape.dimensions(dim) / src_shard_count;
    // Calculate the per-shard size under dst_sharding to check whether
    // dst_sharding needs more padding at the end.
    int64_t dst_per_shard_size =
        padded_dst_shape.dimensions(dim) / src_shard_count;

    // If dst_sharding doesn't need more padding at the end.
    if (src_per_shard_size >= dst_per_shard_size) {
      continue;
    }
    // If src sharding at this dimension is not partitioned, simply pad to
    // the desired shape.
    if (src_shard_count == 1) {
      expand_dims_without_halo_exchange.emplace_back(dim);
      continue;
    }

    // If dst_sharding needs more padding at the end, the data must be
    // re-distributed between the shards using collective permute. For
    // example, if the dimension size is 6, sharded 2 ways in the src but 4
    // ways in the dst: the 4-way sharding needs 2 zeros of padding at the
    // end and has 2 elements in each shard, while the 2-way sharding has 3
    // elements in each shard, so re-distribution is needed.
    //
    // 1. Calculate the left_halo size.
    // left-halo size is 0
    OffsetCalculation left_halo_size_function =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1));

    // 2. Calculate the right_halo size.
    // right-halo size is D * (i + 1) - S * (i + 1) = (D - S) * i + (D - S)
    OffsetCalculation right_halo_size_function =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            dst_per_shard_size - src_per_shard_size,
            dst_per_shard_size - src_per_shard_size, 1));

    auto concat = result;
    // 3. Halo exchange.
    auto halo_exchange_result = ExchangeHalo(
        result, left_halo_size_function, right_halo_size_function, dim,
        src_sharding, collective_ops_creator, next_channel_id, b);

    if (halo_exchange_result.has_value()) {
      concat = halo_exchange_result.value();
    } else {
      return std::nullopt;
    }

    // 4. Pad.
    std::vector<int64_t> zero_padding(concat->shape().rank());
    PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding);
    pad_config.mutable_dimensions(dim)->set_edge_padding_low(0);
    int64_t max_right_halo_size =
        right_halo_size_function.MaxInRange(0, src_shard_count - 1);
    pad_config.mutable_dimensions(dim)->set_edge_padding_high(
        std::max(int64_t{0}, padded_dst_shape.dimensions(dim) -
                                 padded_src_shape.dimensions(dim) -
                                 max_right_halo_size));
    auto padded_concat_shape = ShapeInference::InferPadShape(
                                   concat->shape(), zero->shape(), pad_config)
                                   .ValueOrDie();
    concat = b->AddInstruction(HloInstruction::CreatePad(
        padded_concat_shape, concat, zero, pad_config));

    // 5. Slice the valid result.
    // The slice offset is (D - S) * i.
    auto zero_s32 = b->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
    OffsetCalculation start_offset_on_padded_concat_calculation =
        OffsetCalculation(MultiplyAddDivideOffsetCalculation(
            dst_per_shard_size - src_per_shard_size, 0, 1));
    auto slice_shape = concat->shape();
    slice_shape.set_dimensions(dim, dst_per_shard_size);
    std::vector<HloInstruction*> slice_offsets(concat->shape().rank(),
                                               zero_s32);
    slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate(
        partition_ordinals[dim], b);
    result = b->AddInstruction(HloInstruction::CreateDynamicSlice(
        slice_shape, concat, slice_offsets, slice_shape.dimensions()));
  }

  // Pad other dimensions that won't need halo exchange with a single pad.
  if (!expand_dims_without_halo_exchange.empty()) {
    std::vector<int64_t> zero_padding(result->shape().rank());
    PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding);

    auto padded_shape = result->shape();
    for (auto dim : expand_dims_without_halo_exchange) {
      pad_config.mutable_dimensions(dim)->set_edge_padding_low(0);
      pad_config.mutable_dimensions(dim)->set_edge_padding_high(
          padded_dst_shape.dimensions(dim) - padded_src_shape.dimensions(dim));
      padded_shape.set_dimensions(dim, result->shape().dimensions(dim) +
                                           padded_dst_shape.dimensions(dim) -
                                           padded_src_shape.dimensions(dim));
    }
    result = b->AddInstruction(
        HloInstruction::CreatePad(padded_shape, result, zero, pad_config));
  }

  return result;
}
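
// Worked trace of the example above (illustrative, not in the original):
// with dimension size 6, src sharded 2 ways (S = 3) and dst sharded 4 ways
// (padded dim 8, so D = 8 / 2 = 4 per src shard), the right-halo size for
// src shard i is (D - S) * (i + 1) = i + 1 and the slice offset is
// (D - S) * i = i: shard 0 keeps concat elements [0, 4) and shard 1 keeps
// [1, 5), which together cover the padded dst layout.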

std::optional<int64_t> UniqueTiledDim(const HloSharding& sharding) {
  if (sharding.IsTileMaximal()) {
    return std::nullopt;
  }
  int64_t dim = -1;
  int64_t rank = sharding.ReplicateOnLastTileDim()
                     ? sharding.tile_assignment().num_dimensions() - 1
                     : sharding.tile_assignment().num_dimensions();
  for (int64_t i = 0; i < rank; ++i) {
    if (sharding.tile_assignment().dim(i) > 1) {
      if (dim != -1) {
        return std::nullopt;
      }
      dim = i;
    }
  }
  CHECK_NE(dim, -1);
  return dim;
}

MultiplyAddDivideOffsetCalculation::MultiplyAddDivideOffsetCalculation(
    int64_t multiplier, int64_t offset, int64_t divisor)
    : multiplier_(multiplier), offset_(offset), divisor_(divisor) {
  CHECK_GT(divisor_, 0);
  Simplify();
}

OffsetCalculation MultiplyAddDivideOffsetCalculation::operator-(
    const MultiplyAddDivideOffsetCalculation& other) const {
  if (divisor_ == 1 && other.divisor_ == 1) {
    return OffsetCalculation(MultiplyAddDivideOffsetCalculation(
        multiplier_ - other.multiplier_, offset_ - other.offset_, 1));
  }
  return OffsetCalculation(HloOpcode::kSubtract, *this, other);
}

void MultiplyAddDivideOffsetCalculation::Simplify() {
  // We could simplify the calculation when multiplier_ is a multiple of
  // divisor_. However, when offset_ is not a multiple of divisor_, we must
  // make sure that offset_ and multiplier_ are both non-negative or both
  // non-positive. E.g., (3 * i - 1) / 3 is not equivalent to i or i - 1.
  if (divisor_ != 1 && multiplier_ % divisor_ == 0 &&
      (offset_ % divisor_ == 0 || offset_ * multiplier_ > 0)) {
    multiplier_ /= divisor_;
    offset_ /= divisor_;
    divisor_ = 1;
  }
}
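
// Worked example (illustrative, not in the original): (6 * i + 9) / 3
// simplifies to 2 * i + 3 since both 6 and 9 are multiples of 3, whereas
// (3 * i - 1) / 3 is left untouched per the caveat above.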

int64_t MultiplyAddDivideOffsetCalculation::Calculate(
    int64_t shard_ordinal) const {
  return (shard_ordinal * multiplier_ + offset_) / divisor_;
}
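
// Worked example (illustrative, not in the original): with multiplier_ = 2,
// offset_ = 1, and divisor_ = 3, Calculate(4) = (4 * 2 + 1) / 3 = 3.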

HloInstruction* MultiplyAddDivideOffsetCalculation::Calculate(
    HloInstruction* shard_ordinal, SpmdBuilder* b) const {
  auto scalar_shape = ShapeUtil::MakeShape(S32, {});
  if (multiplier_ == 0) {
    return b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR0<int32_t>(offset_ / divisor_)));
  }
  HloInstruction* result = shard_ordinal;
  if (multiplier_ != 1) {
    result = b->AddInstruction(HloInstruction::CreateBinary(
        scalar_shape, HloOpcode::kMultiply, shard_ordinal,
        b->AddInstruction(HloInstruction::CreateConstant(
            LiteralUtil::CreateR0<int32_t>(multiplier_)))));
  }
  if (offset_ != 0) {
    auto offset = b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR0<int32_t>(offset_)));
    result = b->AddInstruction(HloInstruction::CreateBinary(
        scalar_shape, HloOpcode::kAdd, result, offset));
  }
  if (divisor_ != 1) {
    auto divisor = b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR0<int32_t>(divisor_)));
    result = b->AddInstruction(HloInstruction::CreateBinary(
        scalar_shape, HloOpcode::kDivide, result, divisor));
  }
  return result;
}

int64_t MultiplyAddDivideOffsetCalculation::MaxInRange(
    int64_t start_ordinal, int64_t limit_ordinal) const {
  int64_t max = Calculate(start_ordinal);
  for (int64_t i = start_ordinal + 1; i < limit_ordinal; ++i) {
    max = std::max(max, Calculate(i));
  }
  return max;
}

OffsetCalculation& OffsetCalculation::operator=(
    const OffsetCalculation& other) {
  opcode_ = other.opcode_;
  copy_from_ = other.copy_from_;
  if (opcode_ != HloOpcode::kCopy) {
    lhs_ = std::make_unique<OffsetCalculation>(*other.lhs_);
    rhs_ = std::make_unique<OffsetCalculation>(*other.rhs_);
  }
  return *this;
}

bool OffsetCalculation::IsConstant() const {
  if (opcode_ == HloOpcode::kCopy) {
    return copy_from_.IsConstant();
  }
  if (opcode_ == HloOpcode::kSubtract && *lhs_ == *rhs_) {
    return true;
  }
  return lhs_->IsConstant() && rhs_->IsConstant();
}

OffsetCalculation OffsetCalculation::operator-(
    const OffsetCalculation& other) const {
  if (opcode_ == HloOpcode::kCopy && other.opcode_ == HloOpcode::kCopy) {
    return copy_from_ - other.copy_from_;
  }
  return OffsetCalculation(HloOpcode::kSubtract, *this, other);
}

bool OffsetCalculation::operator==(const OffsetCalculation& other) const {
  if (opcode_ != other.opcode_) {
    return false;
  }
  if (opcode_ == HloOpcode::kCopy) {
    return copy_from_ == other.copy_from_;
  }
  return *lhs_ == *other.lhs_ && *rhs_ == *other.rhs_;
}

int64_t OffsetCalculation::Calculate(int64_t shard_ordinal) const {
  switch (opcode_) {
    case HloOpcode::kCopy:
      return copy_from_.Calculate(shard_ordinal);
    case HloOpcode::kSubtract:
      return lhs_->Calculate(shard_ordinal) - rhs_->Calculate(shard_ordinal);
    case HloOpcode::kMultiply:
      return lhs_->Calculate(shard_ordinal) * rhs_->Calculate(shard_ordinal);
    default:
      LOG(FATAL) << "Should not happen";
  }
}

HloInstruction* OffsetCalculation::Calculate(HloInstruction* shard_ordinal,
                                             SpmdBuilder* b) const {
  if (opcode_ == HloOpcode::kCopy) {
    return copy_from_.Calculate(shard_ordinal, b);
  }
  auto lhs = lhs_->Calculate(shard_ordinal, b);
  auto rhs = rhs_->Calculate(shard_ordinal, b);
  return b->AddInstruction(
      HloInstruction::CreateBinary(lhs->shape(), opcode_, lhs, rhs));
}

int64_t OffsetCalculation::MaxInRange(int64_t start_ordinal,
                                      int64_t limit_ordinal) const {
  if (IsConstant()) {
    return Calculate(start_ordinal);
  }
  if (opcode_ == HloOpcode::kCopy) {
    return std::max(Calculate(start_ordinal), Calculate(limit_ordinal - 1));
  }
  int64_t max = Calculate(start_ordinal);
  for (int64_t i = start_ordinal + 1; i < limit_ordinal; ++i) {
    max = std::max(max, Calculate(i));
  }
  return max;
}
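
// Usage sketch (illustrative, not in the original): OffsetCalculations
// compose into expression trees; for instance, the dynamic-slice start used
// later in this file is built as
//   OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, max_left, 1)) -
//       left_halo_size_function;
// (where `max_left` stands for max_left_halo_or_padding_size below) and can
// be evaluated either at compile time via Calculate(i) or as emitted HLO via
// Calculate(ordinal, b).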

std::optional<HloInstruction*> ExchangeHalo(
    HloInstruction* hlo, const OffsetCalculation& left_halo_size_function,
    const OffsetCalculation& right_halo_size_function, int64_t dim,
    const HloSharding& target,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t* next_channel_id, SpmdBuilder* b) {
  int64_t input_shard_size = hlo->shape().dimensions(dim);
  int64_t shard_count = target.tile_assignment().dim(dim);

  std::vector<HloInstruction*> concat_pieces;

  int64_t max_left_halo_size =
      left_halo_size_function.MaxInRange(1, shard_count);
  int64_t max_right_halo_size =
      right_halo_size_function.MaxInRange(0, shard_count - 1);
  if (max_left_halo_size + max_right_halo_size + input_shard_size >=
          input_shard_size * shard_count &&
      (max_left_halo_size > input_shard_size ||
       max_right_halo_size > input_shard_size)) {
    return std::nullopt;
  }
  // Since max halo sizes could be negative, we only need to include data
  // within certain bounds. Useful region is [left_bound, right_bound).
  const int64_t left_bound =
      -left_halo_size_function.MaxInRange(0, shard_count);
  const int64_t right_bound =
      input_shard_size + right_halo_size_function.MaxInRange(0, shard_count);
  if (left_bound >= right_bound) {
    return std::nullopt;
  }
  // Left halo.
  for (int64_t i = CeilOfRatio(max_left_halo_size, input_shard_size) - 1;
       i >= 0 && (-i - 1) * input_shard_size < right_bound; --i) {
    std::vector<std::pair<int64_t, int64_t>> source_target_pairs;
    target.tile_assignment().Each(
        [&](absl::Span<const int64_t> indices, int64_t device) {
          if (indices[dim] > i) {
            std::vector<int64_t> source_indices(indices.begin(), indices.end());
            source_indices[dim] -= i + 1;
            source_target_pairs.emplace_back(
                target.tile_assignment()(source_indices), device);
          }
        });
    int64_t halo_size_including_skips =
        std::min(max_left_halo_size - input_shard_size * i, input_shard_size);
    int64_t halo_right_skips =
        std::max<int64_t>(-i * input_shard_size - right_bound, 0);
    int64_t halo_size = halo_size_including_skips - halo_right_skips;
    auto halo_shape = hlo->shape();
    auto source_halo_slice = hlo;
    if (halo_size != hlo->shape().dimensions(dim)) {
      halo_shape.set_dimensions(dim, halo_size);
      std::vector<int64_t> halo_start_indices(halo_shape.rank(), 0);
      halo_start_indices[dim] =
          hlo->shape().dimensions(dim) - halo_size_including_skips;
      std::vector<int64_t> halo_limit_indices(hlo->shape().dimensions().begin(),
                                              hlo->shape().dimensions().end());
      halo_limit_indices[dim] -= halo_right_skips;
      std::vector<int64_t> halo_slice_strides(halo_shape.rank(), 1);
      source_halo_slice = b->AddInstruction(
          HloInstruction::CreateSlice(halo_shape, hlo, halo_start_indices,
                                      halo_limit_indices, halo_slice_strides));
    }
    auto left_halo =
        collective_ops_creator.create_cross_partition_collective_permute(
            b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
    concat_pieces.push_back(left_halo);
  }

  if (left_bound < input_shard_size && right_bound > 0) {
    int64_t self_start = std::max<int64_t>(0, left_bound);
    int64_t self_limit = std::min<int64_t>(input_shard_size, right_bound);
    if (self_start == 0 && self_limit == input_shard_size) {
      concat_pieces.push_back(hlo);
    } else {
      auto self_shape = hlo->shape();
      self_shape.set_dimensions(dim, self_limit - self_start);
      std::vector<int64_t> start_indices(self_shape.rank(), 0);
      start_indices[dim] = self_start;
      std::vector<int64_t> limit_indices(hlo->shape().dimensions().begin(),
                                         hlo->shape().dimensions().end());
      limit_indices[dim] = self_limit;
      std::vector<int64_t> slice_strides(self_shape.rank(), 1);
      concat_pieces.push_back(b->AddInstruction(HloInstruction::CreateSlice(
          self_shape, hlo, start_indices, limit_indices, slice_strides)));
    }
  }

  int64_t skipped_right_halos =
      std::min<int64_t>(std::max<int64_t>(left_bound - input_shard_size, 0),
                        std::max<int64_t>(max_right_halo_size, 0)) /
      input_shard_size;
  // Right halo.
  for (int64_t i = skipped_right_halos;
       i < CeilOfRatio(max_right_halo_size, input_shard_size); ++i) {
    std::vector<std::pair<int64_t, int64_t>> source_target_pairs;
    target.tile_assignment().Each(
        [&](absl::Span<const int64_t> indices, int64_t device) {
          if (indices[dim] > i) {
            std::vector<int64_t> target_indices(indices.begin(), indices.end());
            target_indices[dim] -= i + 1;
            source_target_pairs.emplace_back(
                device, target.tile_assignment()(target_indices));
          }
        });
    int64_t halo_size_including_skips =
        std::min(max_right_halo_size - input_shard_size * i, input_shard_size);
    int64_t halo_left_skips =
        std::max<int64_t>(left_bound - (i + 1) * input_shard_size, 0);
    int64_t halo_size = halo_size_including_skips - halo_left_skips;
    auto halo_shape = hlo->shape();
    HloInstruction* source_halo_slice = hlo;
    if (halo_size != halo_shape.dimensions(dim)) {
      halo_shape.set_dimensions(dim, halo_size);
      std::vector<int64_t> halo_start_indices(halo_shape.rank(), 0);
      halo_start_indices[dim] = halo_left_skips;
      std::vector<int64_t> halo_limit_indices(halo_shape.dimensions().begin(),
                                              halo_shape.dimensions().end());
      halo_limit_indices[dim] += halo_left_skips;
      std::vector<int64_t> halo_slice_strides(halo_shape.rank(), 1);
      source_halo_slice = b->AddInstruction(
          HloInstruction::CreateSlice(halo_shape, hlo, halo_start_indices,
                                      halo_limit_indices, halo_slice_strides));
    }
    auto right_halo =
        collective_ops_creator.create_cross_partition_collective_permute(
            b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
    concat_pieces.push_back(right_halo);
  }

  auto concat = concat_pieces[0];
  // Concat with halos/padding.
  if (concat_pieces.size() > 1) {
    auto concat_shape = hlo->shape();
    int64_t concat_dim_size = 0;
    for (auto piece : concat_pieces) {
      concat_dim_size += piece->shape().dimensions(dim);
    }
    concat_shape.set_dimensions(dim, concat_dim_size);
    concat = b->AddInstruction(
        HloInstruction::CreateConcatenate(concat_shape, concat_pieces, dim));
  }

  return concat;
}
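
// Worked example (illustrative, not in the original): with input_shard_size
// = 4, a constant left halo of 2, and a zero right halo, the left-halo loop
// runs once (i = 0), each shard except the first receives the last 2
// elements of its left neighbor via collective permute, and the result is a
// concat of size 2 + 4 = 6 along `dim`.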

std::optional<HloInstruction*> ExchangeHalo(
    HloInstruction* hlo,
    std::vector<OffsetCalculation> left_halo_size_functions,
    std::vector<OffsetCalculation> right_halo_size_functions,
    const HloSharding& target,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t* next_channel_id, SpmdBuilder* b) {
  CHECK(left_halo_size_functions.size() == hlo->shape().rank());
  CHECK(right_halo_size_functions.size() == hlo->shape().rank());

  HloInstruction* visiting_hlo = hlo;
  for (int dim = 0; dim < hlo->shape().rank(); ++dim) {
    auto concat = ExchangeHalo(visiting_hlo, left_halo_size_functions[dim],
                               right_halo_size_functions[dim], dim, target,
                               collective_ops_creator, next_channel_id, b);
    if (!concat) {
      return std::nullopt;
    }
    visiting_hlo = *concat;
  }
  return visiting_hlo;
}

std::optional<HloInstruction*> ExchangeHaloAndGetValidData(
    HloInstruction* hlo, const Shape& base_shape,
    const OffsetCalculation& left_halo_size_function,
    const OffsetCalculation& right_halo_size_function,
    int64_t explicit_left_padding_on_full_shape, int64_t padded_full_shape_size,
    int64_t shard_size_with_halo, int64_t dim, const HloSharding& target,
    HloInstruction* offset_on_padded_shape, HloInstruction* pad_value,
    HloInstruction* partition_ordinal,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t* next_channel_id, SpmdBuilder* b, bool mask_invalid_region) {
  auto halo_exchange_result =
      ExchangeHalo(hlo, left_halo_size_function, right_halo_size_function, dim,
                   target, collective_ops_creator, next_channel_id, b);
  if (!halo_exchange_result) {
    return std::nullopt;
  }
  auto concat = *halo_exchange_result;
  int64_t shard_count = target.tile_assignment().dim(dim);
  int64_t max_left_halo_size =
      left_halo_size_function.MaxInRange(1, shard_count);

  // Now we determine if we need extra padding after the concat.
  //
  // The max of halo size or the first shard's explicit left padding.
  int64_t max_left_halo_or_padding_size =
      std::max(max_left_halo_size, explicit_left_padding_on_full_shape);
  // The calculation that returns the dynamic slice index for a shard on the
  // padded concat, which is the difference between
  // max_left_halo_or_padding_size and its left halo size.
  auto start_offset_on_padded_concat_calculation =
      OffsetCalculation(MultiplyAddDivideOffsetCalculation(
          0, max_left_halo_or_padding_size, 1)) -
      left_halo_size_function;

  // See if we need to pad the concat before dynamic slice.
  int64_t extra_left_padding =
      std::max(int64_t{0}, max_left_halo_or_padding_size -
                               std::max(int64_t{0}, max_left_halo_size));
  int64_t extra_right_padding =
      start_offset_on_padded_concat_calculation.MaxInRange(0, shard_count) +
      shard_size_with_halo - concat->shape().dimensions(dim) -
      extra_left_padding;
  extra_right_padding = std::max(int64_t{0}, extra_right_padding);
  if (extra_left_padding > 0 || extra_right_padding > 0) {
    PaddingConfig padding_config;
    auto padded_concat_shape = concat->shape();
    for (int64_t i = 0; i < base_shape.rank(); ++i) {
      auto padding_config_dim = padding_config.add_dimensions();
      padding_config_dim->set_interior_padding(0);
      padding_config_dim->set_edge_padding_low(0);
      padding_config_dim->set_edge_padding_high(0);
      if (i != dim) {
        continue;
      }
      padding_config_dim->set_edge_padding_low(extra_left_padding);
      padding_config_dim->set_edge_padding_high(extra_right_padding);
      padded_concat_shape.set_dimensions(dim, concat->shape().dimensions(dim) +
                                                  extra_left_padding +
                                                  extra_right_padding);
    }
    concat = b->AddInstruction(HloInstruction::CreatePad(
        padded_concat_shape, concat, pad_value, padding_config));
  }

  auto valid_slice = concat;
  if (shard_size_with_halo != concat->shape().dimensions(dim)) {
    // Concat is bigger than the shard shape, so we need a dynamic slice.
    CHECK_LT(shard_size_with_halo, concat->shape().dimensions(dim));
    auto slice_shape = concat->shape();
    slice_shape.set_dimensions(dim, shard_size_with_halo);

    if (left_halo_size_function.IsConstant() &&
        left_halo_size_function.Calculate(0) ==
            explicit_left_padding_on_full_shape) {
      std::vector<int64_t> start_indices(slice_shape.rank(), 0);
      std::vector<int64_t> strides(slice_shape.rank(), 1);
      valid_slice = b->AddInstruction(
          HloInstruction::CreateSlice(slice_shape, concat, start_indices,
                                      slice_shape.dimensions(), strides));
    } else {
      auto zero = b->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
      std::vector<HloInstruction*> slice_offsets(base_shape.rank(), zero);
      slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate(
          partition_ordinal, b);
      valid_slice = b->AddInstruction(HloInstruction::CreateDynamicSlice(
          slice_shape, concat, slice_offsets, slice_shape.dimensions()));
    }
  }

  if (!mask_invalid_region) {
    return valid_slice;
  }

  int64_t total_right_padding = padded_full_shape_size -
                                base_shape.dimensions(dim) -
                                explicit_left_padding_on_full_shape;
  // Mask off garbage data due to uneven partition or low/high padding.
  if (explicit_left_padding_on_full_shape > 0 || total_right_padding > 0) {
    auto index_shape = ShapeUtil::ChangeElementType(valid_slice->shape(), S32);
    auto iota = b->AddInstruction(HloInstruction::CreateIota(index_shape, dim));
    auto broadcast_start_index_in_padded_shape =
        b->AddInstruction(HloInstruction::CreateBroadcast(
            index_shape, offset_on_padded_shape, {}));
    auto index_in_padded_shape = b->AddInstruction(
        HloInstruction::CreateBinary(index_shape, HloOpcode::kAdd, iota,
                                     broadcast_start_index_in_padded_shape));
    auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED);
    std::vector<HloInstruction*> predicates;
    if (explicit_left_padding_on_full_shape > 0) {
      auto valid_index_start =
          b->AddInstruction(HloInstruction::CreateBroadcast(
              index_shape,
              b->AddInstruction(
                  HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
                      explicit_left_padding_on_full_shape))),
              {}));
      predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare(
          mask_shape, index_in_padded_shape, valid_index_start,
          ComparisonDirection::kGe)));
    }
    if (total_right_padding > 0) {
      auto valid_index_limit =
          b->AddInstruction(HloInstruction::CreateBroadcast(
              index_shape,
              b->AddInstruction(
                  HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
                      base_shape.dimensions(dim) +
                      explicit_left_padding_on_full_shape))),
              {}));
      predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare(
          mask_shape, index_in_padded_shape, valid_index_limit,
          ComparisonDirection::kLt)));
    }
    CHECK(!predicates.empty());
    auto is_valid =
        predicates.size() == 2
            ? b->AddInstruction(HloInstruction::CreateBinary(
                  mask_shape, HloOpcode::kAnd, predicates[0], predicates[1]))
            : predicates[0];
    auto masking_value = b->AddInstruction(
        HloInstruction::CreateBroadcast(valid_slice->shape(), pad_value, {}));
    valid_slice = b->AddInstruction(
        HloInstruction::CreateTernary(valid_slice->shape(), HloOpcode::kSelect,
                                      is_valid, valid_slice, masking_value));
  }
  return valid_slice;
}
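
// Worked example of the masking step (illustrative, not in the original):
// for a base dimension of 7 padded to 10 with an explicit left padding of 1,
// a position is valid iff its global index (iota + offset_on_padded_shape)
// lies in [1, 1 + 7); anything outside that range is replaced by pad_value.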

HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original,
                                        absl::Span<const int64_t> dims) {
  if (original.sharding().IsTileMaximal()) {
    return original.hlo();
  }
  // Create a window config to halo exchange for unevenly partitioned reverse
  // dimensions.
  Window window;
  for (int64_t i = 0; i < original.base_shape().rank(); ++i) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(1);
    dim->set_stride(1);
    dim->set_window_dilation(1);
    dim->set_window_reversal(false);
    int64_t low_padding = 0;
    if (absl::c_linear_search(dims, i)) {
      low_padding = RoundUpTo(original.base_shape().dimensions(i),
                              original.sharding().tile_assignment().dim(i)) -
                    original.base_shape().dimensions(i);
    }
    dim->set_padding_low(low_padding);
    dim->set_padding_high(0);
    dim->set_base_dilation(1);
  }

  auto reshard_window = original.ReshardAsWindowedInput(
      window, original.sharding(),
      CreateZero(ShapeUtil::MakeShape(original.base_shape().element_type(), {}),
                 original.state().b),
      /*mask_invalid_region=*/false);
  if (!reshard_window.has_value()) {
    return nullptr;
  }
  CHECK(!reshard_window->dynamic_slice_index_on_output.has_value());
  return reshard_window->sharded_input;
}
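
// Worked example (illustrative, not in the original): a dimension of size 7
// sharded 2 ways gets low_padding = RoundUpTo(7, 2) - 7 = 1, so the windowed
// reshard shifts the data right by one and the padding ends up on the left.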

bool IsNanSafeGt(HloComputation* comp) {
  namespace m = match;
  auto match_bitcast_f32 = [](int64_t parameter_number) {
    auto param = m::Parameter(parameter_number)
                     .WithShape(m::Shape().WithElementType(F32));
    auto param_s32 =
        m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
    auto param_u32 =
        m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
    return m::Select(
        m::Lt(param_s32, m::ConstantScalar(0)),
        m::BitcastConvert(
            m::Subtract(m::ConstantScalar(std::numeric_limits<int32_t>::max()),
                        param_u32))
            .WithShape(m::Shape().WithElementType(S32)),
        param_s32);
  };
  auto match_bitcast_bf16 = [](int64_t parameter_number) {
    auto param = m::Convert(m::Parameter(parameter_number)
                                .WithShape(m::Shape().WithElementType(BF16)))
                     .WithShape(m::Shape().WithElementType(F32));
    auto param_s32 =
        m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
    auto param_u32 =
        m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
    return m::Select(
        m::Lt(param_s32, m::ConstantScalar(0)),
        m::BitcastConvert(
            m::Subtract(m::ConstantScalar(std::numeric_limits<int32_t>::max()),
                        param_u32))
            .WithShape(m::Shape().WithElementType(S32)),
        param_s32);
  };
  // The root may be a kSelect that breaks ties by comparing indices when the
  // values are equal; in that case the value comparison is operand 2.
  if (comp->root_instruction()->opcode() == HloOpcode::kSelect) {
    return Match(comp->root_instruction()->operand(2),
                 m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) ||
           Match(comp->root_instruction()->operand(2),
                 m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1)));
  }
  return Match(comp->root_instruction(),
               m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) ||
         Match(comp->root_instruction(),
               m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1)));
}
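
// Background on the pattern above (an explanatory note, not in the original
// file): bitcasting an f32 to s32 preserves ordering for non-negative
// floats, while negative floats compare in reverse; mapping negatives to
// INT_MAX - u32(x) yields a total order in which NaNs take a deterministic
// position, which is what makes the matched comparator "NaN-safe".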
1115 
GetKValueInTopKWhenPartitionSortDim(HloInstruction * hlo)1116 std::optional<int64_t> GetKValueInTopKWhenPartitionSortDim(
1117     HloInstruction* hlo) {
1118   HloSortInstruction* sort = DynCast<HloSortInstruction>(hlo);
1119   if (sort == nullptr || sort->operand_count() != 2) {
1120     return std::nullopt;
1121   }
1122   if (!IsNanSafeGt(sort->to_apply())) {
1123     return std::nullopt;
1124   }
1125   HloInstruction* data = sort->mutable_operand(0);
1126   HloIotaInstruction* iota =
1127       DynCast<HloIotaInstruction>(sort->mutable_operand(1));
1128   const PrimitiveType element_type = data->shape().element_type();
1129   if (iota == nullptr || iota->shape().element_type() != S32 ||
1130       iota->opcode() != HloOpcode::kIota ||
1131       iota->iota_dimension() != sort->sort_dimension()) {
1132     return std::nullopt;
1133   }
1134 
1135   const int64_t sort_dim = sort->sort_dimension();
1136 
1137   if (element_type != F32 && element_type != BF16 && element_type != S32 &&
1138       element_type != U32) {
1139     return std::nullopt;
1140   }
1141 
1142   bool supported = true;
1143   std::optional<int64_t> k;
1144   for (HloInstruction* gte : sort->users()) {
1145     if (gte->opcode() != HloOpcode::kGetTupleElement) {
1146       supported = false;
1147       break;
1148     }
1149 
1150     const HloInstruction* slice = gte->users()[0];
1151     if (slice->opcode() != HloOpcode::kSlice) {
1152       // Non-slice user means we are not doing a TopK
1153       supported = false;
1154       break;
1155     }
1156     if (absl::c_any_of(slice->slice_starts(), [](int x) { return x != 0; }) ||
1157         absl::c_any_of(slice->slice_strides(), [](int x) { return x != 1; })) {
1158       // Strided slice or slicing at the beginning isn't supported.
1159       supported = false;
1160       break;
1161     }
1162     for (int64_t dim = 0; dim < data->shape().dimensions_size(); dim++) {
1163       if (dim == sort_dim) {
1164         continue;
1165       }
1166       if (slice->slice_limits(dim) !=
1167           slice->operand(0)->shape().dimensions(dim)) {
1168         // Slicing along the other dimension isn't supported.
1169         supported = false;
1170         break;
1171       }
1172     }
1173     if (!k.has_value()) {
1174       k = slice->slice_limits(sort_dim);
1175     } else if (k != slice->slice_limits(sort_dim)) {
1176       // Different k for the different operands isn't supported.
1177       supported = false;
1178       break;
1179     }
1180   }
1181   if (k == std::nullopt || !supported) {
1182     return std::nullopt;
1183   }
1184 
1185   // Only support when sort dim is sharded.
1186   if (!data->has_sharding()) {
1187     return std::nullopt;
1188   }
1189   const HloSharding& sharding = sort->operand(0)->sharding();
1190 
1191   if (sharding.IsTileMaximal()) {
1192     return std::nullopt;
1193   }
1194 
1195   // Check that only the sort dimension is partitioned.
1196   for (int64_t dim = 0; dim < sort->shape().tuple_shapes(0).dimensions_size();
1197        ++dim) {
1198     if (sharding.tile_assignment().dim(dim) > 1) {
1199       if (dim != sort_dim) {
1200         return std::nullopt;
1201       }
1202     }
1203   }
1204 
1205   // Check that k is smaller than the per-partition size.
1206   const int64_t shard_count = sharding.tile_assignment().dim(sort_dim);
1207 
1208   if (shard_count <= 1) {
1209     return std::nullopt;
1210   }
1211 
1212   const int64_t input_size = hlo->operand(0)->shape().dimensions(sort_dim);
1213   const int64_t per_partition_size = CeilOfRatio(input_size, shard_count);
1214 
1215   if (k.value() >= per_partition_size) {
1216     return std::nullopt;
1217   }
1218 
1219   return k;
1220 }
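// Schematically, the TopK pattern recognized above looks like (hypothetical
// sizes, 4 partitions along the sort dimension):
//   %sort = (f32[1024], s32[1024]) sort(%data, %iota), dimensions={0}
//   %values = get-tuple-element(%sort), index=0
//   %topk = f32[64] slice(%values), slice={[0:64]}
// Each partition holds ceil(1024 / 4) = 256 elements and k = 64 < 256, so
// every shard can first be reduced to its local top-64 (see SliceFirstK
// below) before any cross-partition combination.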
1221 
1222 // Slices the first k elements along slice_dim.
1223 HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
1224                             int64_t slice_dim, int64_t k) {
1225   const Shape& hlo_shape = hlo->shape();
1226   auto hlo_dims = hlo_shape.dimensions();
1227   std::vector<int64_t> start_indices(hlo_shape.dimensions_size(), 0);
1228   std::vector<int64_t> limit_indices(hlo_dims.begin(), hlo_dims.end());
1229   std::vector<int64_t> strides(hlo_shape.dimensions_size(), 1);
1230   limit_indices[slice_dim] = k;
1231   auto output_shape = hlo_shape;
1232   output_shape.set_dimensions(slice_dim, k);
1233   return builder->AddInstruction(HloInstruction::CreateSlice(
1234       output_shape, hlo, start_indices, limit_indices, strides));
1235 }
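// Usage sketch (hypothetical shapes): for an f32[8,1024] input with
// slice_dim = 1 and k = 64, this emits
//   slice(hlo), start={0,0}, limit={8,64}, strides={1,1}
// producing an f32[8,64] result.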
1236 
1237 // Returns the shard count along the given dimension (1 if unsharded).
1238 int64_t ShardCountAtDim(const HloSharding& sharding, int64_t dim) {
1239   if (sharding.IsTileMaximal()) {
1240     return 1;
1241   }
1242   return sharding.tile_assignment().dim(dim);
1243 }
1244 
1245 std::optional<std::vector<std::pair<int64_t, int64_t>>>
1246 GetReshardAllToAllSourceTargetDims(const HloSharding& source,
1247                                    const HloSharding& target) {
1248   if (source.IsTileMaximal() || target.IsTileMaximal() ||
1249       source.tile_assignment().num_dimensions() !=
1250           target.tile_assignment().num_dimensions() ||
1251       source.NumTiles() != target.NumTiles()) {
1252     return std::nullopt;
1253   }
1254   // Map each partition count to the dimensions that have that count,
1255   // restricted to dimensions whose counts differ between source and target.
1256   std::map<int64_t, std::vector<int64_t>> source_size_to_dim;
1257   std::map<int64_t, std::vector<int64_t>> target_size_to_dim;
1258   for (int64_t i = 0; i < source.tile_assignment().num_dimensions(); ++i) {
1259     if (source.tile_assignment().dim(i) == target.tile_assignment().dim(i)) {
1260       continue;
1261     }
1262     source_size_to_dim[source.tile_assignment().dim(i)].push_back(i);
1263     target_size_to_dim[target.tile_assignment().dim(i)].push_back(i);
1264   }
1265   // In order to shard via AllToAll, source_size_to_dim and target_size_to_dim
1266   // must have the same distribution.
1267   if (source_size_to_dim.empty() ||
1268       source_size_to_dim.size() != target_size_to_dim.size()) {
1269     return std::nullopt;
1270   }
1271   for (const auto& entry : source_size_to_dim) {
1272     auto target_it = target_size_to_dim.find(entry.first);
1273     if (target_it == target_size_to_dim.end() ||
1274         target_it->second.size() != entry.second.size()) {
1275       return std::nullopt;
1276     }
1277   }
1278   std::vector<std::pair<int64_t, int64_t>> result;
1279   auto remove_entry = [](int64_t size, int64_t dim,
1280                          std::map<int64_t, std::vector<int64_t>>& size_to_dim) {
1281     size_to_dim[size].erase(
1282         std::remove_if(size_to_dim[size].begin(), size_to_dim[size].end(),
1283                        [dim](int64_t a) { return a == dim; }),
1284         size_to_dim[size].end());
1285     if (size_to_dim[size].empty()) {
1286       size_to_dim.erase(size);
1287     }
1288   };
1289   // Find one pair of dimensions to swap at a time.
1290   while (!source_size_to_dim.empty()) {
1291     int64_t source_size = source_size_to_dim.begin()->first;
1292     int64_t i = source_size_to_dim.begin()->second.back();
1293     int64_t target_i_size = target.tile_assignment().dim(i);
1294     if (target_i_size == source_size) {
1295       remove_entry(source_size, i, source_size_to_dim);
1296       remove_entry(source_size, i, target_size_to_dim);
1297       continue;
1298     }
1299     auto j_it = source_size_to_dim[target_i_size].begin();
1300     int64_t j = *j_it;
1301     if (source_size == 1) {
1302       // If possible, find a j where the target partition count is not one, so
1303       // that when we swap, the resulting size-1 dimension will still be useful
1304       // to other dimensions.
1305       while (target.tile_assignment().dim(j) == 1) {
1306         if (++j_it == source_size_to_dim[target_i_size].end()) {
1307           break;
1308         }
1309         j = *j_it;
1310       }
1311     } else if (target_i_size % source_size == 0) {
1312       // If possible, find a j where the target partition count is source_size,
1313       // so that we can do a single swap.
1314       while (target.tile_assignment().dim(j) != source_size) {
1315         if (++j_it == source_size_to_dim[target_i_size].end()) {
1316           break;
1317         }
1318         j = *j_it;
1319       }
1320     } else {
1321       return std::nullopt;
1322     }
1323     result.emplace_back(j, i);
1324     remove_entry(target_i_size, i, target_size_to_dim);
1325     source_size_to_dim.begin()->second.back() = j;
1326     remove_entry(target_i_size, j, source_size_to_dim);
1327   }
1328   return result;
1329 }
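// Worked example (hypothetical shardings): resharding tile-assignment dims
// [2,4] to [4,2] records source_size_to_dim = {2:[0], 4:[1]} and
// target_size_to_dim = {4:[0], 2:[1]}, and the loop emits the single pair
// (1, 0): one all-to-all moves the 4-way partitioning from dim 1 to dim 0.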
1330 
1331 bool CanReshardWithCollectivePermute(const HloSharding& source,
1332                                      const HloSharding& target) {
1333   return !source.IsTileMaximal() && !target.IsTileMaximal() &&
1334          source.tile_assignment().dimensions() ==
1335              target.tile_assignment().dimensions() &&
1336          source.ReplicateOnLastTileDim() == target.ReplicateOnLastTileDim() &&
1337          source.tile_assignment() != target.tile_assignment();
1338 }
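// E.g., two shardings that both tile one dimension 2 ways but assign the
// tiles to devices {0,1} vs. {1,0} differ only in device order, so the
// reshard is a pure data exchange expressible as a collective-permute with
// pairs {0->1, 1->0}.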
1339 
1340 std::optional<GroupedSharding> AlignGroupsWithInternal(
1341     GroupedSharding grouped_sharding, const GroupedSharding& reference,
1342     bool requires_compatibility, bool ignore_group_order) {
1343   // Returns src -> dst index mapping.
1344   auto get_permutation = [](absl::Span<const int64_t> src,
1345                             absl::Span<const int64_t> dst) {
1346     CHECK_EQ(src.size(), dst.size());
1347     absl::flat_hash_map<int64_t, int64_t> dst_reverse_map;
1348     for (int64_t i = 0; i < dst.size(); ++i) {
1349       dst_reverse_map[dst[i]] = i;
1350     }
1351     std::vector<int64_t> permutation(src.size());
1352     for (int64_t i = 0; i < src.size(); ++i) {
1353       auto it = dst_reverse_map.find(src[i]);
1354       CHECK(it != dst_reverse_map.end());
1355       permutation[i] = it->second;
1356     }
1357     return permutation;
1358   };
1359   CHECK_EQ(grouped_sharding.device_groups.size(),
1360            reference.device_groups.size());
1361   absl::flat_hash_map<int64_t, int64_t> device_to_ref_group;
1362   for (int64_t g = 0; g < reference.device_groups.size(); ++g) {
1363     for (int64_t device : reference.device_groups[g]) {
1364       device_to_ref_group[device] = g;
1365     }
1366   }
1367   auto unique_ref_dev_group =
1368       [&](absl::Span<const int64_t> devices) -> int64_t {
1369     int64_t ref_g = -1;
1370     for (int64_t device : devices) {
1371       if (ref_g == -1) {
1372         ref_g = device_to_ref_group[device];
1373       } else if (ref_g != device_to_ref_group[device]) {
1374         return -1;
1375       }
1376     }
1377     return ref_g;
1378   };
1379   bool matching_groups = true;
1380   std::vector<int64_t> original_src_to_ref_permutation;
1381   for (int64_t g = 0; g < grouped_sharding.device_groups.size(); ++g) {
1382     int64_t ref_g = unique_ref_dev_group(grouped_sharding.device_groups[g]);
1383     if (ref_g < 0 || (!ignore_group_order && g != ref_g)) {
1384       if (requires_compatibility) {
1385         return std::nullopt;
1386       }
1387       matching_groups = false;
1388       break;
1389     }
1390     if (g == 0) {
1391       original_src_to_ref_permutation = get_permutation(
1392           grouped_sharding.device_groups[g], reference.device_groups[ref_g]);
1393     } else if (requires_compatibility) {
1394       if (original_src_to_ref_permutation !=
1395           get_permutation(grouped_sharding.device_groups[g],
1396                           reference.device_groups[ref_g])) {
1397         return std::nullopt;
1398       }
1399     }
1400   }
1401   if (matching_groups && !grouped_sharding.sharding.IsTileMaximal()) {
1402     auto tiles = grouped_sharding.sharding.tile_assignment();
1403     tiles.Each([&](absl::Span<const int64_t> indices, int64_t* device) {
1404       *device = original_src_to_ref_permutation[*device];
1405     });
1406     grouped_sharding.sharding =
1407         grouped_sharding.sharding.ReplicateOnLastTileDim()
1408             ? HloSharding::PartialTile(tiles)
1409             : HloSharding::Tile(tiles);
1410   }
1411   grouped_sharding.device_groups = reference.device_groups;
1412   return grouped_sharding;
1413 }
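// Worked example (hypothetical groups): with device groups {{0,1},{2,3}} and
// reference groups {{1,0},{3,2}}, every source group maps onto a unique
// reference group in the same order, get_permutation({0,1}, {1,0}) returns
// [1,0], that permutation is applied to the per-group tile assignment, and
// the reference's device groups are adopted.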
1414 
1415 GroupedSharding AlignGroupsWith(GroupedSharding grouped_sharding,
1416                                 const GroupedSharding& reference,
1417                                 bool ignore_group_order) {
1418   return *AlignGroupsWithInternal(std::move(grouped_sharding), reference,
1419                                   /*requires_compatibility=*/false,
1420                                   ignore_group_order);
1421 }
1422 
1423 std::optional<GroupedSharding> AlignGroupsWithIfCompatible(
1424     GroupedSharding grouped_sharding, const GroupedSharding& reference) {
1425   return AlignGroupsWithInternal(std::move(grouped_sharding), reference,
1426                                  /*requires_compatibility=*/true,
1427                                  /*ignore_group_order=*/false);
1428 }
1429 
1430 HloSharding AlignShardingOnDims(const HloSharding& sharding,
1431                                 absl::Span<const int64_t> sharding_dims,
1432                                 const HloSharding& reference,
1433                                 absl::Span<const int64_t> reference_dims) {
1434   auto sharding_grouped =
1435       hlo_sharding_util::GroupShardingOnDims(sharding, sharding_dims);
1436   auto reference_grouped =
1437       hlo_sharding_util::GroupShardingOnDims(reference, reference_dims);
1438   return hlo_sharding_util::UngroupSharding(
1439       AlignGroupsWith(sharding_grouped, reference_grouped));
1440 }
1441 
1442 Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding,
1443                            const Shape& original_base_shape) {
1444   auto result = original_base_shape;
1445   for (int64_t i = 0; i < grouped_sharding.group_dims.size(); ++i) {
1446     int64_t dim = grouped_sharding.group_dims[i];
1447     if (dim >= original_base_shape.rank()) {
1448       continue;
1449     }
1450     int64_t groups = grouped_sharding.group_dim_sizes[i];
1451     result.set_dimensions(dim, CeilOfRatio(result.dimensions(dim), groups));
1452   }
1453   return result;
1454 }
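// E.g., a base shape f32[12,8] with group_dims = {0} and group_dim_sizes =
// {3} yields a per-group base shape of f32[4,8].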
1455 
1456 namespace {
1457 
1458 HloInstruction* GetInGroupPartitionId(
1459     HloInstruction* partition_id,
1460     const std::vector<std::vector<int64_t>>& device_groups, SpmdBuilder* b) {
1461   int64_t total_devices = device_groups.size() * device_groups[0].size();
1462   std::vector<uint32_t> in_group_ids(total_devices);
1463   for (uint32_t i = 0; i < device_groups.size(); ++i) {
1464     for (uint32_t j = 0; j < device_groups[i].size(); ++j) {
1465       in_group_ids[device_groups[i][j]] = j;
1466     }
1467   }
1468   auto id_table = b->AddInstruction(HloInstruction::CreateConstant(
1469       LiteralUtil::CreateR1<uint32_t>(in_group_ids)));
1470   return b->AddInstruction(HloInstruction::CreateReshape(
1471       ShapeUtil::MakeScalarShape(U32),
1472       b->AddInstruction(HloInstruction::CreateDynamicSlice(
1473           ShapeUtil::MakeShape(U32, {1}), id_table, {partition_id}, {1}))));
1474 }
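// E.g., device_groups = {{0,2},{1,3}} builds the id table [0, 0, 1, 1]
// (indexed by global partition id), so global partition 2 resolves to
// in-group id 1.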
1475 
1476 SPMDCollectiveOpsCreator GetPerGroupCollectiveOpsCreator(
1477     const SPMDCollectiveOpsCreator& creator,
1478     const std::vector<std::vector<int64_t>>& device_groups) {
1479   SPMDCollectiveOpsCreator result;
1480   result.create_partition_id = [creator, device_groups](SpmdBuilder* b) {
1481     return GetInGroupPartitionId(creator.create_partition_id(b), device_groups,
1482                                  b);
1483   };
1484   auto expand_partition_groups =
1485       [device_groups](
1486           const std::vector<std::vector<int64_t>>& partition_subgroups) {
1487         if (partition_subgroups.empty()) {
1488           return device_groups;
1489         }
1490         std::vector<std::vector<int64_t>> result(partition_subgroups.size() *
1491                                                  device_groups.size());
1492         for (int64_t g = 0; g < device_groups.size(); ++g) {
1493           for (int64_t i = 0; i < partition_subgroups.size(); ++i) {
1494             result[g * partition_subgroups.size() + i].resize(
1495                 partition_subgroups[i].size());
1496             for (int64_t j = 0; j < partition_subgroups[i].size(); ++j) {
1497               result[g * partition_subgroups.size() + i][j] =
1498                   device_groups[g][partition_subgroups[i][j]];
1499             }
1500           }
1501         }
1502         return result;
1503       };
1504   result.create_cross_partition_all_reduce =
1505       [creator, expand_partition_groups](
1506           SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction,
1507           const std::vector<std::vector<int64_t>>& partition_subgroups,
1508           int64_t channel_id) {
1509         return creator.create_cross_partition_all_reduce(
1510             b, operand, reduction, expand_partition_groups(partition_subgroups),
1511             channel_id);
1512       };
1513   result.create_cross_partition_collective_permute =
1514       [creator, device_groups](
1515           SpmdBuilder* b, HloInstruction* operand,
1516           std::vector<std::pair<int64_t, int64_t>>& src_dst_pairs,
1517           int64_t next_channel_id) {
1518         std::vector<std::pair<int64_t, int64_t>> expanded_pairs(
1519             src_dst_pairs.size() * device_groups.size());
1520         for (int64_t g = 0; g < device_groups.size(); ++g) {
1521           for (int64_t i = 0; i < src_dst_pairs.size(); ++i) {
1522             expanded_pairs[g * src_dst_pairs.size() + i] =
1523                 std::pair<int64_t, int64_t>{
1524                     device_groups[g][src_dst_pairs[i].first],
1525                     device_groups[g][src_dst_pairs[i].second]};
1526           }
1527         }
1528         return creator.create_cross_partition_collective_permute(
1529             b, operand, expanded_pairs, next_channel_id);
1530       };
1531   result.create_cross_partition_all_to_all =
1532       [creator, expand_partition_groups](
1533           SpmdBuilder* b, absl::Span<HloInstruction* const> operands,
1534           const std::vector<std::vector<int64_t>>& partition_subgroups,
1535           int64_t channel_id, std::optional<int64_t> split_dimension) {
1536         return creator.create_cross_partition_all_to_all(
1537             b, operands, expand_partition_groups(partition_subgroups),
1538             channel_id, split_dimension);
1539       };
1540   if (creator.create_cross_partition_all_gather) {
1541     result.create_cross_partition_all_gather =
1542         [creator, expand_partition_groups](
1543             SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape,
1544             const std::vector<std::vector<int64_t>>& partition_subgroups,
1545             int64_t channel_id, int64_t all_gather_dimension) {
1546           return creator.create_cross_partition_all_gather(
1547               b, operand, ag_shape,
1548               expand_partition_groups(partition_subgroups), channel_id,
1549               all_gather_dimension);
1550         };
1551   }
1552   return result;
1553 }
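// E.g., with device_groups = {{0,2},{1,3}}, expand_partition_groups({{0,1}})
// returns {{0,2},{1,3}}: subgroup entries are positions within a group, so
// each in-group collective is replicated across all groups over the
// corresponding global device ids.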
1554 
1555 }  // namespace
1556 
1557 PartitionedHlo::PartitioningState CreatePerGroupPartitioningState(
1558     const PartitionedHlo::PartitioningState& state,
1559     const std::vector<std::vector<int64_t>>& device_groups, SpmdBuilder* b) {
1560   auto result = state;
1561   result.collective_ops_creator = GetPerGroupCollectiveOpsCreator(
1562       state.collective_ops_creator, device_groups);
1563   result.partition_id =
1564       GetInGroupPartitionId(state.partition_id, device_groups, b);
1565   // Create a string key for the groups.
1566   std::vector<std::string> per_group_strings(device_groups.size());
1567   for (int64_t i = 0; i < per_group_strings.size(); ++i) {
1568     per_group_strings[i] = absl::StrJoin(device_groups[i], ",");
1569   }
1570   auto& grouped_cache =
1571       state.reshard_cache->groupd_caches[absl::StrJoin(per_group_strings, ";")];
1572   if (!grouped_cache) {
1573     grouped_cache = std::make_unique<PartitionedHlo::ReshardCache>();
1574   }
1575   result.reshard_cache = grouped_cache.get();
1576   return result;
1577 }
1578 
1579 HloInstruction* PerGroupSliceFromReplicated(
1580     HloInstruction* replicated, HloInstruction* partition_id,
1581     const std::vector<std::vector<int64_t>>& device_groups,
1582     absl::Span<const int64_t> group_dims,
1583     absl::Span<const int64_t> group_dim_sizes, SpmdBuilder* b) {
1584   std::vector<uint32_t> group_ids(device_groups.size() *
1585                                   device_groups[0].size());
1586   for (int64_t g = 0; g < device_groups.size(); ++g) {
1587     for (int64_t device : device_groups[g]) {
1588       group_ids[device] = g;
1589     }
1590   }
1591   auto group_id_table = b->AddInstruction(HloInstruction::CreateConstant(
1592       LiteralUtil::CreateR1<uint32_t>(group_ids)));
1593   auto group_id = b->AddInstruction(HloInstruction::CreateReshape(
1594       ShapeUtil::MakeScalarShape(U32),
1595       b->AddInstruction(HloInstruction::CreateDynamicSlice(
1596           ShapeUtil::MakeShape(U32, {1}), group_id_table, {partition_id},
1597           {1}))));
1598   std::vector<int64_t> group_level_tile_dims(replicated->shape().rank(), 1);
1599   for (int64_t i = 0; i < group_dims.size(); ++i) {
1600     group_level_tile_dims[group_dims[i]] = group_dim_sizes[i];
1601   }
1602   Array<int64_t> group_level_tile(group_level_tile_dims);
1603   group_level_tile.Each([&](absl::Span<const int64_t> indices, int64_t* group) {
1604     *group = 0;
1605     for (int64_t dim : group_dims) {
1606       *group *= group_level_tile.dim(dim);
1607       *group += indices[dim];
1608     }
1609   });
1610   auto group_level_sharding = HloSharding::Tile(group_level_tile);
1611   auto padded_hlo = PadBaseShapeBeforeUnevenTiledSharding(
1612       replicated, group_level_sharding, b);
1613   auto shard_shape =
1614       MakePartitionedShape(replicated->shape(), group_level_sharding);
1615   return b->AddInstruction(HloInstruction::CreateDynamicSlice(
1616       shard_shape, padded_hlo,
1617       MakePartitionOffsets(replicated->shape(), group_level_sharding, group_id,
1618                            b),
1619       shard_shape.dimensions()));
1620 }
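// E.g., a replicated f32[8] with two device groups, group_dims = {0} and
// group_dim_sizes = {2} produces a group-level sharding that tiles dim 0
// two ways, so group 0 dynamic-slices elements [0, 4) and group 1 elements
// [4, 8).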
1621 
1622 std::optional<std::vector<int64_t>> FindMatchingPartitionedDimsForGrouping(
1623     const HloSharding& sharding,
1624     const std::vector<std::vector<int64_t>>& device_groups) {
1625   if (sharding.IsTileMaximal() || device_groups.size() < 2) {
1626     return std::nullopt;
1627   }
1628   std::vector<int64_t> dims;
1629   if (device_groups[0].size() < 2) {
1630     // Trivial case: single-member groups.
1631     for (int64_t i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
1632       if (sharding.tile_assignment().dim(i) > 1) {
1633         dims.push_back(i);
1634       }
1635     }
1636     return dims;
1637   }
1638   int64_t rank = sharding.tile_assignment().num_dimensions();
1639   absl::flat_hash_map<int64_t, std::vector<int64_t>> device_to_index;
1640   sharding.tile_assignment().Each(
1641       [&](absl::Span<const int64_t> index, int64_t device) {
1642         device_to_index[device] =
1643             std::vector<int64_t>(index.begin(), index.begin() + rank);
1644       });
1645   int64_t group_count = 1;
1646   for (int64_t i = 0; i < rank; ++i) {
1647     if (device_to_index[device_groups[0][0]][i] ==
1648         device_to_index[device_groups[0][1]][i]) {
1649       dims.push_back(i);
1650       group_count *= sharding.tile_assignment().dim(i);
1651     }
1652   }
1653   if (group_count != device_groups.size()) {
1654     return std::nullopt;
1655   }
1656   for (const auto& group : device_groups) {
1657     for (int64_t i = 1; i < group.size(); ++i) {
1658       if (absl::c_any_of(dims, [&](const int64_t dim) {
1659             return device_to_index[group[i]][dim] !=
1660                    device_to_index[group[0]][dim];
1661           })) {
1662         return std::nullopt;
1663       }
1664     }
1665   }
1666   return dims;
1667 }
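// E.g., a 2x2 tile assignment [[0,1],[2,3]] with device_groups = {{0,1},
// {2,3}} returns dims = {0}: the members of each group share their tile
// index along dim 0, and dim 0's shard count (2) equals the group count.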
1668 
1669 HloSharding CreateMatchingShardingOnDims(
1670     const Shape& target_shape, const HloSharding& source_sharding,
1671     absl::Span<const int64_t> target_dims,
1672     absl::Span<const int64_t> source_dims) {
1673   CHECK(target_dims.size() == source_dims.size())
1674       << "Expected 1:1 match between parallel dimensions";
1675   if (source_sharding.IsReplicated()) {
1676     return HloSharding::Replicate();
1677   }
1678   absl::InlinedVector<int64_t, 4> tile_dims(target_shape.dimensions_size(), 1);
1679   int num_tiles = 1;
1680   for (int i = 0, end = target_dims.size(); i < end; ++i) {
1681     num_tiles *= source_sharding.tile_assignment().dim(source_dims[i]);
1682     tile_dims[target_dims[i]] =
1683         source_sharding.tile_assignment().dim(source_dims[i]);
1684   }
1685   // If the other operand is also partitioned along non-parallel
1686   // dimensions, partially replicate the new sharding over them.
1687   bool to_be_partially_replicated = false;
1688   if (num_tiles != source_sharding.tile_assignment().num_elements()) {
1689     CHECK_EQ(source_sharding.tile_assignment().num_elements() % num_tiles, 0);
1690     to_be_partially_replicated = true;
1691     tile_dims.push_back(source_sharding.tile_assignment().num_elements() /
1692                         num_tiles);
1693   }
1694   auto tgt_tile_assignment = source_sharding.tile_assignment();
1695   tgt_tile_assignment.Reshape(tile_dims);
1696   if (to_be_partially_replicated) {
1697     return AlignShardingOnDims(HloSharding::PartialTile(tgt_tile_assignment),
1698                                target_dims, source_sharding, source_dims);
1699   } else {
1700     return AlignShardingOnDims(HloSharding::Tile(tgt_tile_assignment),
1701                                target_dims, source_sharding, source_dims);
1702   }
1703 }
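// E.g., if the source sharding tiles its dim 0 four ways, source_dims = {0}
// and target_dims = {1} for a rank-2 target shape, the result tiles the
// target's dim 1 four ways; any source partitioning beyond the parallel
// dims is turned into a partial-replication dimension and then aligned.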
1704 
1705 std::optional<GatherParallelDimSharding>
1706 GatherOperandsShardedAcrossParallelDims(
1707     const HloInstruction& operand, const HloInstruction& indices,
1708     const hlo_sharding_util::GatherParallelDims& parallel_dims) {
1709   auto& indices_parallel_dims = parallel_dims.indices_parallel_dims;
1710   auto& operand_parallel_dims = parallel_dims.operand_parallel_dims;
1711   if (indices_parallel_dims.size() != operand_parallel_dims.size()) {
1712     return std::nullopt;
1713   }
1714   auto new_index_shard = indices.sharding();
1715   auto new_operand_shard = operand.sharding();
1716   int idx_parallel_tiles_num = new_index_shard.NumTiles(indices_parallel_dims);
1717   int op_parallel_tiles_num = new_operand_shard.NumTiles(operand_parallel_dims);
1718   if (idx_parallel_tiles_num == 1 && op_parallel_tiles_num == 1) {
1719     return std::nullopt;
1720   }
1721   absl::InlinedVector<int64_t, 1> indices_parallel_dims_ordered_as_operand;
1722   for (int idx : parallel_dims.index_parallel_in_dim) {
1723     if (idx != -1) {
1724       indices_parallel_dims_ordered_as_operand.push_back(idx);
1725     }
1726   }
1727   if (new_index_shard.IsReplicated()) {
1728     return GatherParallelDimSharding{
1729         CreateMatchingShardingOnDims(indices.shape(), new_operand_shard,
1730                                      indices_parallel_dims_ordered_as_operand,
1731                                      operand_parallel_dims),
1732         new_operand_shard};
1733   }
1734   if (new_operand_shard.IsReplicated()) {
1735     return GatherParallelDimSharding{
1736         new_index_shard,
1737         CreateMatchingShardingOnDims(operand.shape(), new_index_shard,
1738                                      operand_parallel_dims,
1739                                      indices_parallel_dims_ordered_as_operand)};
1740   }
1741 
1742   // Parallel dimension distribution needs to be the same, so try to steal
1743   // sharding from partial replication to compensate.
1744   if (idx_parallel_tiles_num != op_parallel_tiles_num) {
1745     auto to_adjust_dims = operand_parallel_dims;
1746     auto target_dims = indices_parallel_dims_ordered_as_operand;
1747     HloSharding* target = &new_index_shard;
1748     HloSharding* to_adjust = &new_operand_shard;
1749     if (idx_parallel_tiles_num < op_parallel_tiles_num) {
1750       std::swap(to_adjust_dims, target_dims);
1751       std::swap(to_adjust, target);
1752     }
1753     if (!to_adjust->ReplicateOnLastTileDim()) {
1754       return std::nullopt;
1755     }
1756     auto new_tile_assignment_dims = to_adjust->tile_assignment().dimensions();
1757     for (int i = 0; i < to_adjust_dims.size(); ++i) {
1758       int64_t target_dim = target->tile_assignment().dim(target_dims[i]);
1759       int64_t to_adjust_dim =
1760           to_adjust->tile_assignment().dim(to_adjust_dims[i]);
1761       if (target_dim < to_adjust_dim) {
1762         return std::nullopt;
1763       }
1764       if (target_dim == to_adjust_dim) {
1765         continue;
1766       }
1767       int64_t ratio = target_dim / to_adjust_dim;
1768       if (target_dim % to_adjust_dim != 0 ||
1769           new_tile_assignment_dims.back() % ratio != 0) {
1770         return std::nullopt;
1771       }
1772       new_tile_assignment_dims[to_adjust_dims[i]] *= ratio;
1773       new_tile_assignment_dims.back() /= ratio;
1774     }
1775     CHECK_GE(new_tile_assignment_dims.back(), 1);
1776     bool to_partially_replicate = true;
1777     if (new_tile_assignment_dims.back() == 1) {
1778       new_tile_assignment_dims.pop_back();
1779       to_partially_replicate = false;
1780     }
1781     auto new_tile_assignment = to_adjust->tile_assignment();
1782     new_tile_assignment.Reshape(new_tile_assignment_dims);
1783     if (to_partially_replicate) {
1784       *to_adjust =
1785           AlignShardingOnDims(HloSharding::PartialTile(new_tile_assignment),
1786                               to_adjust_dims, *target, target_dims);
1787     } else {
1788       *to_adjust = AlignShardingOnDims(HloSharding::Tile(new_tile_assignment),
1789                                        to_adjust_dims, *target, target_dims);
1790     }
1791   }
1792   // Make sure that the parallel dimensions are aligned.
1793   auto operand_shard_tile_dims =
1794       new_operand_shard.tile_assignment().dimensions();
1795   for (int i = 0; i < indices_parallel_dims_ordered_as_operand.size(); ++i) {
1796     operand_shard_tile_dims[operand_parallel_dims[i]] =
1797         new_index_shard.tile_assignment().dim(
1798             indices_parallel_dims_ordered_as_operand[i]);
1799   }
1800   auto operand_shard_tiles = new_operand_shard.tile_assignment();
1801   operand_shard_tiles.Reshape(operand_shard_tile_dims);
1802   new_operand_shard =
1803       AlignShardingOnDims(new_operand_shard.ReplicateOnLastTileDim()
1804                               ? HloSharding::PartialTile(operand_shard_tiles)
1805                               : HloSharding::Tile(operand_shard_tiles),
1806                           operand_parallel_dims, new_index_shard,
1807                           indices_parallel_dims_ordered_as_operand);
1808   return GatherParallelDimSharding{new_index_shard, new_operand_shard};
1809 }
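// E.g., if the indices are tiled 4 ways on their parallel dim while the
// operand is tiled 2 ways with a partial-replication factor of 2, the
// operand's tile assignment is reshaped from [2,2] to [4] (consuming the
// replication factor) and then aligned to the indices' device assignment.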
1810 
1811 int64_t FindRotateRightPattern(const HloInstruction* concat,
1812                                const HloInstruction* lhs,
1813                                const HloInstruction* rhs) {
1814   if (lhs->opcode() != HloOpcode::kSlice ||
1815       rhs->opcode() != HloOpcode::kSlice ||
1816       lhs->operand(0) != rhs->operand(0)) {
1817     return -1;
1818   }
1819   const HloInstruction* to_rotate = lhs->operand(0);
1820   if (!ShapeUtil::Compatible(to_rotate->shape(), concat->shape()) ||
1821       concat->sharding() != to_rotate->sharding()) {
1822     return -1;
1823   }
1824   const int64_t dim = concat->concatenate_dimension();
1825   if (lhs->slice_strides(dim) != 1 || rhs->slice_strides(dim) != 1 ||
1826       lhs->slice_starts(dim) != rhs->slice_limits(dim)) {
1827     return -1;
1828   }
1829   return lhs->shape().dimensions(dim);
1830 }
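// E.g., concat(x[3:8], x[0:3]) along a size-8 dimension is a rotate-right by
// 5, the size of the lhs slice, which is the amount returned above.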
1831 
1832 std::optional<PadWithWrapPattern> FindPadWithWrapPattern(
1833     const HloInstruction* concat, const HloInstruction* lhs,
1834     const HloInstruction* mid, const HloInstruction* rhs) {
1835   if (!lhs || !mid || !rhs) {
1836     return std::nullopt;
1837   }
1838 
1839   // Skips elementwise unary operations applied to inst, returning the list
1840   // of skipped operations and the instruction they were applied to.
1841   auto skip_elementwise_ops = [&](const HloInstruction* inst) {
1842     std::vector<const HloInstruction*> modifiers;
1843     while (inst->IsElementwise() && inst->operand_count() == 1 &&
1844            inst->user_count() == 1) {
1845       if (inst->opcode() != HloOpcode::kCopy) {
1846         modifiers.push_back(inst);
1847       }
1848       inst = inst->operand(0);
1849     }
1850     return std::make_pair(modifiers, inst);
1851   };
1852 
1853   PadWithWrapPattern pad_pattern;
1854   auto skip_result = skip_elementwise_ops(lhs);
1855   pad_pattern.lhs_modifiers = std::move(skip_result.first);
1856   lhs = skip_result.second;
1857 
1858   skip_result = skip_elementwise_ops(rhs);
1859   pad_pattern.rhs_modifiers = std::move(skip_result.first);
1860   rhs = skip_result.second;
1861 
1862   const int64_t dim = concat->concatenate_dimension();
1863   if (lhs->opcode() != HloOpcode::kSlice ||
1864       rhs->opcode() != HloOpcode::kSlice || lhs->operand(0) != mid ||
1865       rhs->operand(0) != mid || lhs->slice_strides(dim) != 1 ||
1866       rhs->slice_strides(dim) != 1 || lhs->sharding() != mid->sharding() ||
1867       rhs->sharding() != mid->sharding() ||
1868       lhs->sharding() != concat->sharding()) {
1869     return std::nullopt;
1870   }
1871   pad_pattern.lhs_slice_start = lhs->slice_starts(dim);
1872   pad_pattern.rhs_slice_start = rhs->slice_starts(dim);
1873   return pad_pattern;
1874 }
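// E.g., concat(mid[n-2:n], mid, mid[0:2]) along the concatenate dimension
// wraps the last two and first two elements of mid around as padding; the
// returned pattern records lhs_slice_start = n-2 and rhs_slice_start = 0.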
1875 
1876 }  // namespace spmd
1877 }  // namespace xla
1878