1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
17
18 #include <algorithm>
19 #include <memory>
20 #include <optional>
21
22 #include "absl/algorithm/container.h"
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/container/inlined_vector.h"
25 #include "absl/strings/str_join.h"
26 #include "tensorflow/compiler/xla/literal_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
28 #include "tensorflow/compiler/xla/service/hlo_computation.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
31 #include "tensorflow/compiler/xla/service/hlo_module.h"
32 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
33 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
34 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
35 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
36 #include "tensorflow/compiler/xla/service/shape_inference.h"
37 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
38 #include "tensorflow/compiler/xla/shape_util.h"
39 #include "tensorflow/compiler/xla/util.h"
40 #include "tensorflow/compiler/xla/window_util.h"
41 #include "tensorflow/compiler/xla/xla_data.pb.h"
42
43 namespace xla {
44 namespace spmd {
45
46 namespace {
47 using hlo_sharding_util::GroupedSharding;
48 } // namespace
49
50 bool HasReplicatedSharding(const HloSharding& sharding) {
51 if (sharding.IsTuple()) {
52 return absl::c_any_of(sharding.tuple_elements(), HasReplicatedSharding);
53 }
54 return sharding.IsReplicated();
55 }
56
57 HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) {
58 HloComputation::Builder sum_b("add");
59 auto x = sum_b.AddInstruction(HloInstruction::CreateParameter(
60 /*parameter_number=*/0, ShapeUtil::MakeShape(type, {}), "x"));
61 auto y = sum_b.AddInstruction(HloInstruction::CreateParameter(
62 /*parameter_number=*/1, ShapeUtil::MakeShape(type, {}), "y"));
63 if (type == PRED) {
64 sum_b.AddInstruction(HloInstruction::CreateBinary(
65 ShapeUtil::MakeShape(type, {}), HloOpcode::kOr, x, y));
66 } else {
67 sum_b.AddInstruction(HloInstruction::CreateBinary(
68 ShapeUtil::MakeShape(type, {}), HloOpcode::kAdd, x, y));
69 }
70 HloComputation* reduction = module->AddEmbeddedComputation(sum_b.Build());
71 return reduction;
72 }
73
74 bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding) {
75 if (sharding.IsTuple()) {
76 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
77 if (!EvenlyPartitions(ShapeUtil::GetTupleElementShape(shape, i),
78 sharding.GetSubSharding(shape, {i}))) {
79 return false;
80 }
81 }
82 }
83
84 if (sharding.IsTileMaximal()) {
85 return sharding.IsReplicated();
86 }
87 for (int64_t i = 0; i < shape.dimensions_size(); ++i) {
88 if (shape.dimensions(i) % sharding.tile_assignment().dim(i) != 0) {
89 return false;
90 }
91 }
92 return true;
93 }
94
95 Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) {
96 if (sharding.IsTuple()) {
97 std::vector<Shape> subshapes;
98 const int64_t shape_n = ShapeUtil::TupleElementCount(shape);
99 subshapes.reserve(shape_n);
100 for (int64_t i = 0; i < shape_n; ++i) {
101 subshapes.push_back(
102 MakePartitionedShape(ShapeUtil::GetTupleElementShape(shape, i),
103 sharding.GetSubSharding(shape, {i})));
104 }
105 return ShapeUtil::MakeTupleShape(subshapes);
106 }
107 return sharding.TileShape(shape);
108 }
109
110 int64_t ShapeSizeInBytes(const Shape& shape) {
111 return ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()) *
112 ShapeUtil::ElementsIn(shape);
113 }
114
115 Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
116 const HloSharding& sharding,
117 int64_t partition_id) {
118 if (sharding.IsTuple()) {
119 std::vector<Shape> subshapes;
120 const int64_t shape_n = ShapeUtil::TupleElementCount(shape);
121 subshapes.reserve(shape_n);
122 for (int64_t i = 0; i < shape_n; ++i) {
123 subshapes.push_back(MakeNonPaddedShapeForGivenPartition(
124 ShapeUtil::GetTupleElementShape(shape, i),
125 sharding.GetSubSharding(shape, {i}), partition_id));
126 }
127 return ShapeUtil::MakeTupleShape(subshapes);
128 }
129
130 if (sharding.IsReplicated()) {
131 return shape;
132 }
133 if (sharding.IsTileMaximal()) {
134 if (partition_id == *sharding.UniqueDevice()) {
135 return shape;
136 }
137 return ShapeUtil::MakeTupleShape({});
138 }
139
140 auto partition_shape = shape;
141 std::vector<int64_t> tile_offset =
142 sharding.TileOffsetForDevice(shape, partition_id);
143 std::vector<int64_t> tile_limit =
144 sharding.TileLimitForDevice(shape, partition_id);
145 for (int64_t i = 0; i < tile_offset.size(); ++i) {
146 if (sharding.UsesDevice(partition_id)) {
147 partition_shape.set_dimensions(i, tile_limit[i] - tile_offset[i]);
148 } else {
149 partition_shape.set_dimensions(i, 0);
150 }
151 }
152 return partition_shape;
153 }
154
155 std::vector<HloInstruction*> MakePartitionOffsets(
156 const Shape& shape, const HloSharding& sharding,
157 HloInstruction* partition_id, SpmdBuilder* b,
158 absl::Span<const int64_t> dims) {
159 CHECK(!shape.IsTuple());
160
161 std::vector<std::vector<int32_t>> offset_arrays(shape.rank());
162 for (int64_t i = 0; i < shape.rank(); ++i) {
163 offset_arrays[i].resize(sharding.tile_assignment().num_elements());
164 }
165 auto shard_shape = MakePartitionedShape(shape, sharding);
166 sharding.tile_assignment().Each(
167 [&](absl::Span<const int64_t> indices, int64_t device) {
168 for (int64_t i = 0; i < shape.rank(); ++i) {
169 offset_arrays[i][device] = indices[i] * shard_shape.dimensions(i);
170 }
171 });
172 std::vector<HloInstruction*> offsets;
173 for (int64_t i = 0; i < shape.rank(); ++i) {
174 if (sharding.tile_assignment().dim(i) == 1 ||
175 (!dims.empty() && !absl::c_linear_search(dims, i))) {
176 offsets.push_back(b->AddInstruction(
177 HloInstruction::CreateConstant(LiteralUtil::Zero(S32))));
178 } else {
179 auto offset_table = b->AddInstruction(HloInstruction::CreateConstant(
180 LiteralUtil::CreateR1<int32_t>(offset_arrays[i])));
181 auto index = b->AddInstruction(HloInstruction::CreateDynamicSlice(
182 ShapeUtil::MakeShape(S32, {1}), offset_table, {partition_id}, {1}));
183 offsets.push_back(b->AddInstruction(
184 HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), index)));
185 }
186 }
187 return offsets;
188 }
189
190 std::vector<HloInstruction*> MakeTiledPartitionOrdinals(
191 const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) {
192 CHECK(!sharding.IsTileMaximal());
193 auto dimensions = sharding.tile_assignment().dimensions();
194 if (sharding.ReplicateOnLastTileDim()) {
195 dimensions.pop_back();
196 }
197 auto table_shape = ShapeUtil::MakeShape(S32, dimensions);
198 return MakePartitionOffsets(table_shape, sharding, partition_id, b);
199 }
200
201 Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape,
202 const HloSharding& sharding) {
203 if (sharding.IsTileMaximal()) {
204 return base_shape;
205 }
206 if (EvenlyPartitions(base_shape, sharding)) {
207 return base_shape;
208 }
209 auto shard_shape = MakePartitionedShape(base_shape, sharding);
210 Shape padded_base_shape = base_shape;
211 for (int64_t i = 0; i < padded_base_shape.rank(); ++i) {
212 padded_base_shape.set_dimensions(
213 i, shard_shape.dimensions(i) * sharding.tile_assignment().dim(i));
214 }
215 return padded_base_shape;
216 }
217
218 std::optional<HloSharding> PartialReplicateReshardCompatibleSharding(
219 const HloSharding& partial_sharding, const HloSharding& target_sharding) {
220 if (!partial_sharding.ReplicateOnLastTileDim()) {
221 return std::nullopt;
222 }
223 int64_t rank = partial_sharding.tile_assignment().num_dimensions() - 1;
224 int64_t target_rank = target_sharding.tile_assignment().num_dimensions() -
225 (target_sharding.ReplicateOnLastTileDim() ? 1 : 0);
226 if (target_rank != rank) {
227 return std::nullopt;
228 }
229
230 absl::flat_hash_map<int64_t, int64_t> device_to_replication_group;
231 partial_sharding.tile_assignment().Each(
232 [&](absl::Span<const int64_t> indices, int64_t device) {
233 int64_t gid = 0;
234 for (int64_t i = 0; i < rank; ++i) {
235 gid *= partial_sharding.tile_assignment().dim(i);
236 gid += indices[i];
237 }
238 device_to_replication_group[device] = gid;
239 });
240
241 // A dimension is expanded when target_tile_size > partial_tile_size and
242 // target_tile_size % partial_tile_size == 0.
243 // expand_tile_dims_indices[dim] is dim's index among the expanded dims (-1 if dim is not expanded).
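// For example, going from a partial-replicate sharding with tile dims
// [2,2,4] (last dim replicated) to [4,2,2]: dim 0 expands by a factor of 2,
// so expand_tile_dims_indices = {0, -1} and expand_tile_sizes = {2}.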
244 std::vector<int64_t> expand_tile_dims_indices(rank, -1);
245 // expand_tile_size = target_tile_size / partial_tile_size.
246 std::vector<int64_t> expand_tile_sizes;
247 int num_expand_dims = 0;
248 for (int64_t dim = 0; dim < rank; dim++) {
249 int64_t partial_tile_size = partial_sharding.tile_assignment().dim(dim);
250 int64_t target_tile_size = target_sharding.tile_assignment().dim(dim);
251 if (target_tile_size % partial_tile_size != 0 ||
252 target_tile_size < partial_tile_size) {
253 return std::nullopt;
254 }
255
256 if (target_tile_size > partial_tile_size) {
257 expand_tile_dims_indices[dim] = num_expand_dims++;
258 expand_tile_sizes.emplace_back(target_tile_size / partial_tile_size);
259 }
260 }
261
262 // Reshape the partial replicate tile_dimensions.
263 int64_t num_target_replication = 1;
264 if (target_sharding.ReplicateOnLastTileDim()) {
265 num_target_replication =
266 target_sharding.tile_assignment().dimensions().back();
267 }
268 auto reshape_dimensions = partial_sharding.tile_assignment().dimensions();
269 int64_t num_replication = reshape_dimensions.back();
270 if (num_replication / num_target_replication != Product(expand_tile_sizes) ||
271 num_replication % num_target_replication != 0) {
272 return std::nullopt;
273 }
274
275 reshape_dimensions.pop_back();
276 reshape_dimensions.insert(reshape_dimensions.end(), expand_tile_sizes.begin(),
277 expand_tile_sizes.end());
278
279 if (target_sharding.ReplicateOnLastTileDim()) {
280 reshape_dimensions.push_back(num_target_replication);
281 }
282
283 auto reshape_tile_assignment = partial_sharding.tile_assignment();
284 reshape_tile_assignment.Reshape(reshape_dimensions);
285
286 // Transpose.
287 std::vector<int64_t> perm;
288 perm.reserve(rank + expand_tile_sizes.size());
289 for (int64_t dim = 0; dim < rank; dim++) {
290 perm.emplace_back(dim);
291 if (expand_tile_dims_indices[dim] > -1) {
292 perm.emplace_back(expand_tile_dims_indices[dim] + rank);
293 }
294 }
295 auto transpose_sharding = hlo_sharding_util::TransposeSharding(
296 target_sharding.ReplicateOnLastTileDim()
297 ? HloSharding::PartialTile(reshape_tile_assignment)
298 : HloSharding::Tile(reshape_tile_assignment),
299 perm);
300
301 // Reshape to target shape
302 auto transpose_tile_assignment = transpose_sharding.tile_assignment();
303 transpose_tile_assignment.Reshape(
304 target_sharding.tile_assignment().dimensions());
305
306 bool groups_matching = true;
307 target_sharding.tile_assignment().Each(
308 [&](absl::Span<const int64_t> indices, int64_t device) {
309 if (device_to_replication_group[device] !=
310 device_to_replication_group[transpose_tile_assignment(indices)]) {
311 groups_matching = false;
312 }
313 });
314
315 if (groups_matching) {
316 return target_sharding;
317 }
318 return target_sharding.ReplicateOnLastTileDim()
319 ? HloSharding::PartialTile(transpose_tile_assignment)
320 : HloSharding::Tile(transpose_tile_assignment);
321 }
322
323 std::optional<HloInstruction*> TileToPartialReplicateHaloExchange(
324 HloInstruction* hlo, const Shape& base_shape,
325 const HloSharding& src_sharding, const HloSharding& dst_sharding,
326 const std::vector<int64_t>& replicate_dims,
327 const SPMDCollectiveOpsCreator& collective_ops_creator,
328 int64_t* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {
329 // Source is tile sharding.
330 auto padded_src_shape =
331 GetPaddedShapeForUnevenPartitioning(base_shape, src_sharding);
332 // Target is partial replicate.
333 auto padded_dst_shape =
334 GetPaddedShapeForUnevenPartitioning(base_shape, dst_sharding);
335 if (ShapeUtil::Compatible(padded_dst_shape, hlo->shape())) {
336 return hlo;
337 }
338
339 auto partition_ordinals =
340 MakeTiledPartitionOrdinals(dst_sharding, partition_id, b);
341
342 auto result = hlo;
343 auto hlo_shape = hlo->shape();
344 for (auto dim : replicate_dims) {
345 int64_t dst_shard_count = dst_sharding.tile_assignment().dim(dim);
346 int64_t src_per_shard_size =
347 padded_src_shape.dimensions(dim) / dst_shard_count;
348 // Calculate the per-shard size using the sharding to check whether
349 // dst_sharding needs more padding at the end.
350 int64_t dst_per_shard_size =
351 padded_dst_shape.dimensions(dim) / dst_shard_count;
352
353 // If src per shard doesn't have redundant data.
354 if (src_per_shard_size <= dst_per_shard_size || dst_shard_count == 1) {
355 continue;
356 }
357
358 // If src_per_shard_size > dst_per_shard_size, the data needs to be
359 // re-distributed between shards using collective permute. For example, if
360 // the dimension size is 6 and it is sharded 4 ways in the src but 2 ways
361 // in the dst, 4-way sharding has 2 elements in each shard while 2-way
362 // sharding has 3, so the last element of the first shard would be sliced
363 // out and re-distribution is needed.
364 //
365 // 1. Calculate left_halo size.
366 // left-halo size is
367 // (src_per_shard_size - dst_per_shard_size) * i / replicate_factor
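// With the 6-element example above: src_per_shard_size = 8 / 2 = 4 (the
// padded src shape is 8), dst_per_shard_size = 6 / 2 = 3 and
// replicate_factor = 4 / 2 = 2, so left_halo(i) = (4 - 3) * i / 2.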
368 int64_t replicate_factor = src_sharding.tile_assignment().dim(dim) /
369 dst_sharding.tile_assignment().dim(dim);
370 OffsetCalculation left_halo_size_function =
371 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
372 src_per_shard_size - dst_per_shard_size, 0, replicate_factor));
373
374 // 2. Calculate right_halo size.
375 // right-halo size is 0
376 OffsetCalculation right_halo_size_function =
377 OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1));
378
379 auto concat = result;
380 // 3. Halo exchange.
381 auto halo_exchange_result = ExchangeHalo(
382 result, left_halo_size_function, right_halo_size_function, dim,
383 src_sharding, collective_ops_creator, next_channel_id, b);
384
385 if (halo_exchange_result.has_value()) {
386 concat = halo_exchange_result.value();
387 } else {
388 return std::nullopt;
389 }
390
391 // 4. Slice the valid result.
392 // Slice offset is
393 // (dst_shard_count - i - 1) *
394 // (src_per_shard_size - dst_per_shard_size)
395 // i is the index in dst_sharding.
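// With the example above this is (2 - i - 1) * (4 - 3) = 1 - i, i.e. dst
// shard 0 slices at offset 1 and dst shard 1 at offset 0.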
396 auto zero_s32 = b->AddInstruction(
397 HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
398 OffsetCalculation start_offset_on_padded_concat_calculation =
399 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
400 dst_per_shard_size - src_per_shard_size,
401 (src_per_shard_size - dst_per_shard_size) * (dst_shard_count - 1),
402 1));
403 auto slice_shape = concat->shape();
404 slice_shape.set_dimensions(dim,
405 padded_src_shape.dimensions(dim) /
406 src_sharding.tile_assignment().dim(dim));
407 std::vector<HloInstruction*> slice_offsets(concat->shape().rank(),
408 zero_s32);
409 slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate(
410 partition_ordinals[dim], b);
411 result = b->AddInstruction(HloInstruction::CreateDynamicSlice(
412 slice_shape, concat, slice_offsets, slice_shape.dimensions()));
413 }
414 return result;
415 }
416
417 std::optional<HloInstruction*> PadFromPartialReplicateShape(
418 HloInstruction* hlo, const Shape& base_shape,
419 const HloSharding& src_sharding, const HloSharding& dst_sharding,
420 const std::vector<int64_t>& expand_tile_dims,
421 const SPMDCollectiveOpsCreator& collective_ops_creator,
422 int64_t* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {
423 auto padded_src_shape =
424 GetPaddedShapeForUnevenPartitioning(base_shape, src_sharding);
425 auto padded_dst_shape =
426 GetPaddedShapeForUnevenPartitioning(base_shape, dst_sharding);
427 if (ShapeUtil::Compatible(padded_dst_shape, hlo->shape())) {
428 return hlo;
429 }
430
431 auto partition_ordinals =
432 MakeTiledPartitionOrdinals(src_sharding, partition_id, b);
433
434 HloInstruction* result = hlo;
435 auto zero = b->AddInstruction(HloInstruction::CreateConstant(
436 LiteralUtil::Zero(hlo->shape().element_type())));
437 std::vector<int64_t> expand_dims_without_halo_exchange;
438 // Pad the dimensions that need halo exchange, and record the padded dims
439 // that won't need halo exchange.
440 for (auto dim : expand_tile_dims) {
441 int64_t src_shard_count = src_sharding.tile_assignment().dim(dim);
442 int64_t src_per_shard_size =
443 padded_src_shape.dimensions(dim) / src_shard_count;
444 // Calculate the per-shard size using the sharding to check whether
445 // dst_sharding needs more padding at the end.
446 int64_t dst_per_shard_size =
447 padded_dst_shape.dimensions(dim) / src_shard_count;
448
449 // If dst_sharding doesn't need more padding at the end.
450 if (src_per_shard_size >= dst_per_shard_size) {
451 continue;
452 }
453 // If src sharding at this dimension is not partitioned, simply pad to
454 // the desired shape.
455 if (src_shard_count == 1) {
456 expand_dims_without_halo_exchange.emplace_back(dim);
457 continue;
458 }
459
460 // If dst_sharding needs more padding at the end, the data needs to be
461 // re-distributed between shards using collective permute.
462 // For example, if the dimension size is 6 and it is sharded 2 ways in the
463 // src but 4 ways in the dst, the 4-way sharding needs 2 zeros of padding at
464 // the end and has 2 elements in each shard, while the 2-way sharding has 3
465 // elements in each shard, so re-distribution is needed.
466 //
467 // 1. Calculate left_halo size.
468 // left-halo size is 0
469 OffsetCalculation left_halo_size_function =
470 OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1));
471
472 // 2. Calculate right_halo size.
473 // right-halo size is D * (i + 1) - S * (i + 1) = (D - S) * i + (D - S)
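// With the 6-element example above: S = src_per_shard_size = 3 and
// D = dst_per_shard_size = 4, so right_halo(i) = i + 1.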
474 OffsetCalculation right_halo_size_function =
475 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
476 dst_per_shard_size - src_per_shard_size,
477 dst_per_shard_size - src_per_shard_size, 1));
478
479 auto concat = result;
480 // 3. Halo exchange.
481 auto halo_exchange_result = ExchangeHalo(
482 result, left_halo_size_function, right_halo_size_function, dim,
483 src_sharding, collective_ops_creator, next_channel_id, b);
484
485 if (halo_exchange_result.has_value()) {
486 concat = halo_exchange_result.value();
487 } else {
488 return std::nullopt;
489 }
490
491 // 4. Pad.
492 std::vector<int64_t> zero_padding(concat->shape().rank());
493 PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding);
494 pad_config.mutable_dimensions(dim)->set_edge_padding_low(0);
495 int64_t max_right_halo_size =
496 right_halo_size_function.MaxInRange(0, src_shard_count - 1);
497 pad_config.mutable_dimensions(dim)->set_edge_padding_high(
498 std::max(int64_t{0}, padded_dst_shape.dimensions(dim) -
499 padded_src_shape.dimensions(dim) -
500 max_right_halo_size));
501 auto padded_concat_shape = ShapeInference::InferPadShape(
502 concat->shape(), zero->shape(), pad_config)
503 .ValueOrDie();
504 concat = b->AddInstruction(HloInstruction::CreatePad(
505 padded_concat_shape, concat, zero, pad_config));
506
507 // 5. Slice the valid result.
508 // Slice offset is (D-S) * i
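// With the example above this is (4 - 3) * i = i.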
509 auto zero_s32 = b->AddInstruction(
510 HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
511 OffsetCalculation start_offset_on_padded_concat_calculation =
512 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
513 dst_per_shard_size - src_per_shard_size, 0, 1));
514 auto slice_shape = concat->shape();
515 slice_shape.set_dimensions(dim, dst_per_shard_size);
516 std::vector<HloInstruction*> slice_offsets(concat->shape().rank(),
517 zero_s32);
518 slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate(
519 partition_ordinals[dim], b);
520 result = b->AddInstruction(HloInstruction::CreateDynamicSlice(
521 slice_shape, concat, slice_offsets, slice_shape.dimensions()));
522 }
523
524 // Pad other dimensions that won't need halo exchange with a single pad.
525 if (!expand_dims_without_halo_exchange.empty()) {
526 std::vector<int64_t> zero_padding(result->shape().rank());
527 PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding);
528
529 auto padded_shape = result->shape();
530 for (auto dim : expand_dims_without_halo_exchange) {
531 pad_config.mutable_dimensions(dim)->set_edge_padding_low(0);
532 pad_config.mutable_dimensions(dim)->set_edge_padding_high(
533 padded_dst_shape.dimensions(dim) - padded_src_shape.dimensions(dim));
534 padded_shape.set_dimensions(dim, result->shape().dimensions(dim) +
535 padded_dst_shape.dimensions(dim) -
536 padded_src_shape.dimensions(dim));
537 }
538 result = b->AddInstruction(
539 HloInstruction::CreatePad(padded_shape, result, zero, pad_config));
540 }
541
542 return result;
543 }
544
545 std::optional<int64_t> UniqueTiledDim(const HloSharding& sharding) {
546 if (sharding.IsTileMaximal()) {
547 return std::nullopt;
548 }
549 int64_t dim = -1;
550 int64_t rank = sharding.ReplicateOnLastTileDim()
551 ? sharding.tile_assignment().num_dimensions() - 1
552 : sharding.tile_assignment().num_dimensions();
553 for (int64_t i = 0; i < rank; ++i) {
554 if (sharding.tile_assignment().dim(i) > 1) {
555 if (dim != -1) {
556 return std::nullopt;
557 }
558 dim = i;
559 }
560 }
561 CHECK_NE(dim, -1);
562 return dim;
563 }
564
565 MultiplyAddDivideOffsetCalculation::MultiplyAddDivideOffsetCalculation(
566 int64_t multiplier, int64_t offset, int64_t divisor)
567 : multiplier_(multiplier), offset_(offset), divisor_(divisor) {
568 CHECK_GT(divisor_, 0);
569 Simplify();
570 }
571
572 OffsetCalculation MultiplyAddDivideOffsetCalculation::operator-(
573 const MultiplyAddDivideOffsetCalculation& other) const {
574 if (divisor_ == 1 && other.divisor_ == 1) {
575 return OffsetCalculation(MultiplyAddDivideOffsetCalculation(
576 multiplier_ - other.multiplier_, offset_ - other.offset_, 1));
577 }
578 return OffsetCalculation(HloOpcode::kSubtract, *this, other);
579 }
580
581 void MultiplyAddDivideOffsetCalculation::Simplify() {
582 // We could simplify the calculation when multiplier is a multiple of
583 // divisor_. However, when offset_ is not a multiple of divisor_, we must
584 // make sure that offset_ and multiplier_ are both non-negative or both
585 // non-positive. E.g., (3 * i - 1) / 3 is not equivalent to i or i - 1.
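// For example, at i = 0, (3 * i - 1) / 3 = 0 (matching i), but at i = 1 it is
// 2 / 3 = 0 (matching i - 1), so no single simplification is always correct.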
586 if (divisor_ != 1 && multiplier_ % divisor_ == 0 &&
587 (offset_ % divisor_ == 0 || offset_ * multiplier_ > 0)) {
588 multiplier_ /= divisor_;
589 offset_ /= divisor_;
590 divisor_ = 1;
591 }
592 }
593
594 int64_t MultiplyAddDivideOffsetCalculation::Calculate(
595 int64_t shard_ordinal) const {
596 return (shard_ordinal * multiplier_ + offset_) / divisor_;
597 }
598
599 HloInstruction* MultiplyAddDivideOffsetCalculation::Calculate(
600 HloInstruction* shard_ordinal, SpmdBuilder* b) const {
601 auto scalar_shape = ShapeUtil::MakeShape(S32, {});
602 if (multiplier_ == 0) {
603 return b->AddInstruction(HloInstruction::CreateConstant(
604 LiteralUtil::CreateR0<int32_t>(offset_ / divisor_)));
605 }
606 HloInstruction* result = shard_ordinal;
607 if (multiplier_ != 1) {
608 result = b->AddInstruction(HloInstruction::CreateBinary(
609 scalar_shape, HloOpcode::kMultiply, shard_ordinal,
610 b->AddInstruction(HloInstruction::CreateConstant(
611 LiteralUtil::CreateR0<int32_t>(multiplier_)))));
612 }
613 if (offset_ != 0) {
614 auto offset = b->AddInstruction(HloInstruction::CreateConstant(
615 LiteralUtil::CreateR0<int32_t>(offset_)));
616 result = b->AddInstruction(HloInstruction::CreateBinary(
617 scalar_shape, HloOpcode::kAdd, result, offset));
618 }
619 if (divisor_ != 1) {
620 auto divisor = b->AddInstruction(HloInstruction::CreateConstant(
621 LiteralUtil::CreateR0<int32_t>(divisor_)));
622 result = b->AddInstruction(HloInstruction::CreateBinary(
623 scalar_shape, HloOpcode::kDivide, result, divisor));
624 }
625 return result;
626 }
627
628 int64_t MultiplyAddDivideOffsetCalculation::MaxInRange(
629 int64_t start_ordinal, int64_t limit_ordinal) const {
630 int64_t max = Calculate(start_ordinal);
631 for (int64_t i = start_ordinal + 1; i < limit_ordinal; ++i) {
632 max = std::max(max, Calculate(i));
633 }
634 return max;
635 }
636
637 OffsetCalculation& OffsetCalculation::operator=(
638 const OffsetCalculation& other) {
639 opcode_ = other.opcode_;
640 copy_from_ = other.copy_from_;
641 if (opcode_ != HloOpcode::kCopy) {
642 lhs_ = std::make_unique<OffsetCalculation>(*other.lhs_);
643 rhs_ = std::make_unique<OffsetCalculation>(*other.rhs_);
644 }
645 return *this;
646 }
647
648 bool OffsetCalculation::IsConstant() const {
649 if (opcode_ == HloOpcode::kCopy) {
650 return copy_from_.IsConstant();
651 }
652 if (opcode_ == HloOpcode::kSubtract && *lhs_ == *rhs_) {
653 return true;
654 }
655 return lhs_->IsConstant() && rhs_->IsConstant();
656 }
657
658 OffsetCalculation OffsetCalculation::operator-(
659 const OffsetCalculation& other) const {
660 if (opcode_ == HloOpcode::kCopy && other.opcode_ == HloOpcode::kCopy) {
661 return copy_from_ - other.copy_from_;
662 }
663 return OffsetCalculation(HloOpcode::kSubtract, *this, other);
664 }
665
666 bool OffsetCalculation::operator==(const OffsetCalculation& other) const {
667 if (opcode_ != other.opcode_) {
668 return false;
669 }
670 if (opcode_ == HloOpcode::kCopy) {
671 return copy_from_ == other.copy_from_;
672 }
673 return *lhs_ == *other.lhs_ && *rhs_ == *other.rhs_;
674 }
675
676 int64_t OffsetCalculation::Calculate(int64_t shard_ordinal) const {
677 switch (opcode_) {
678 case HloOpcode::kCopy:
679 return copy_from_.Calculate(shard_ordinal);
680 case HloOpcode::kSubtract:
681 return lhs_->Calculate(shard_ordinal) - rhs_->Calculate(shard_ordinal);
682 case HloOpcode::kMultiply:
683 return lhs_->Calculate(shard_ordinal) * rhs_->Calculate(shard_ordinal);
684 default:
685 LOG(FATAL) << "Should not happen";
686 }
687 }
688
689 HloInstruction* OffsetCalculation::Calculate(HloInstruction* shard_ordinal,
690 SpmdBuilder* b) const {
691 if (opcode_ == HloOpcode::kCopy) {
692 return copy_from_.Calculate(shard_ordinal, b);
693 }
694 auto lhs = lhs_->Calculate(shard_ordinal, b);
695 auto rhs = rhs_->Calculate(shard_ordinal, b);
696 return b->AddInstruction(
697 HloInstruction::CreateBinary(lhs->shape(), opcode_, lhs, rhs));
698 }
699
700 int64_t OffsetCalculation::MaxInRange(int64_t start_ordinal,
701 int64_t limit_ordinal) const {
702 if (IsConstant()) {
703 return Calculate(start_ordinal);
704 }
705 if (opcode_ == HloOpcode::kCopy) {
706 return std::max(Calculate(start_ordinal), Calculate(limit_ordinal - 1));
707 }
708 int64_t max = Calculate(start_ordinal);
709 for (int64_t i = start_ordinal + 1; i < limit_ordinal; ++i) {
710 max = std::max(max, Calculate(i));
711 }
712 return max;
713 }
714
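// Exchanges halos for `hlo` along `dim` under the `target` sharding: each
// shard receives its left halo from lower-ordinal shards and its right halo
// from higher-ordinal shards via collective permutes, then concatenates the
// received pieces with (a slice of) its own data. Returns std::nullopt when a
// requested halo is larger than the local shard while the combined pieces
// would cover the whole dimension, or when the useful region is empty.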
715 std::optional<HloInstruction*> ExchangeHalo(
716 HloInstruction* hlo, const OffsetCalculation& left_halo_size_function,
717 const OffsetCalculation& right_halo_size_function, int64_t dim,
718 const HloSharding& target,
719 const SPMDCollectiveOpsCreator& collective_ops_creator,
720 int64_t* next_channel_id, SpmdBuilder* b) {
721 int64_t input_shard_size = hlo->shape().dimensions(dim);
722 int64_t shard_count = target.tile_assignment().dim(dim);
723
724 std::vector<HloInstruction*> concat_pieces;
725
726 int64_t max_left_halo_size =
727 left_halo_size_function.MaxInRange(1, shard_count);
728 int64_t max_right_halo_size =
729 right_halo_size_function.MaxInRange(0, shard_count - 1);
730 if (max_left_halo_size + max_right_halo_size + input_shard_size >=
731 input_shard_size * shard_count &&
732 (max_left_halo_size > input_shard_size ||
733 max_right_halo_size > input_shard_size)) {
734 return std::nullopt;
735 }
736 // Since max halo sizes could be negative, we only need to include data within
737 // certain bounds. Useful region is [left_bound, right_bound).
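// For example, if every left halo size is negative, left_bound is positive
// and the first left_bound elements of the local shard are excluded from its
// own piece of the concat.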
738 const int64_t left_bound =
739 -left_halo_size_function.MaxInRange(0, shard_count);
740 const int64_t right_bound =
741 input_shard_size + right_halo_size_function.MaxInRange(0, shard_count);
742 if (left_bound >= right_bound) {
743 return std::nullopt;
744 }
745 // Left halo.
746 for (int64_t i = CeilOfRatio(max_left_halo_size, input_shard_size) - 1;
747 i >= 0 && (-i - 1) * input_shard_size < right_bound; --i) {
748 std::vector<std::pair<int64_t, int64_t>> source_target_pairs;
749 target.tile_assignment().Each(
750 [&](absl::Span<const int64_t> indices, int64_t device) {
751 if (indices[dim] > i) {
752 std::vector<int64_t> source_indices(indices.begin(), indices.end());
753 source_indices[dim] -= i + 1;
754 source_target_pairs.emplace_back(
755 target.tile_assignment()(source_indices), device);
756 }
757 });
758 int64_t halo_size_including_skips =
759 std::min(max_left_halo_size - input_shard_size * i, input_shard_size);
760 int64_t halo_right_skips =
761 std::max<int64_t>(-i * input_shard_size - right_bound, 0);
762 int64_t halo_size = halo_size_including_skips - halo_right_skips;
763 auto halo_shape = hlo->shape();
764 auto source_halo_slice = hlo;
765 if (halo_size != hlo->shape().dimensions(dim)) {
766 halo_shape.set_dimensions(dim, halo_size);
767 std::vector<int64_t> halo_start_indices(halo_shape.rank(), 0);
768 halo_start_indices[dim] =
769 hlo->shape().dimensions(dim) - halo_size_including_skips;
770 std::vector<int64_t> halo_limit_indices(hlo->shape().dimensions().begin(),
771 hlo->shape().dimensions().end());
772 halo_limit_indices[dim] -= halo_right_skips;
773 std::vector<int64_t> halo_slice_strides(halo_shape.rank(), 1);
774 source_halo_slice = b->AddInstruction(
775 HloInstruction::CreateSlice(halo_shape, hlo, halo_start_indices,
776 halo_limit_indices, halo_slice_strides));
777 }
778 auto left_halo =
779 collective_ops_creator.create_cross_partition_collective_permute(
780 b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
781 concat_pieces.push_back(left_halo);
782 }
783
784 if (left_bound < input_shard_size && right_bound > 0) {
785 int64_t self_start = std::max<int64_t>(0, left_bound);
786 int64_t self_limit = std::min<int64_t>(input_shard_size, right_bound);
787 if (self_start == 0 && self_limit == input_shard_size) {
788 concat_pieces.push_back(hlo);
789 } else {
790 auto self_shape = hlo->shape();
791 self_shape.set_dimensions(dim, self_limit - self_start);
792 std::vector<int64_t> start_indices(self_shape.rank(), 0);
793 start_indices[dim] = self_start;
794 std::vector<int64_t> limit_indices(hlo->shape().dimensions().begin(),
795 hlo->shape().dimensions().end());
796 limit_indices[dim] = self_limit;
797 std::vector<int64_t> slice_strides(self_shape.rank(), 1);
798 concat_pieces.push_back(b->AddInstruction(HloInstruction::CreateSlice(
799 self_shape, hlo, start_indices, limit_indices, slice_strides)));
800 }
801 }
802
803 int64_t skipped_right_halos =
804 std::min<int64_t>(std::max<int64_t>(left_bound - input_shard_size, 0),
805 std::max<int64_t>(max_right_halo_size, 0)) /
806 input_shard_size;
807 // Right halo.
808 for (int64_t i = skipped_right_halos;
809 i < CeilOfRatio(max_right_halo_size, input_shard_size); ++i) {
810 std::vector<std::pair<int64_t, int64_t>> source_target_pairs;
811 target.tile_assignment().Each(
812 [&](absl::Span<const int64_t> indices, int64_t device) {
813 if (indices[dim] > i) {
814 std::vector<int64_t> target_indices(indices.begin(), indices.end());
815 target_indices[dim] -= i + 1;
816 source_target_pairs.emplace_back(
817 device, target.tile_assignment()(target_indices));
818 }
819 });
820 int64_t halo_size_including_skips =
821 std::min(max_right_halo_size - input_shard_size * i, input_shard_size);
822 int64_t halo_left_skips =
823 std::max<int64_t>(left_bound - (i + 1) * input_shard_size, 0);
824 int64_t halo_size = halo_size_including_skips - halo_left_skips;
825 auto halo_shape = hlo->shape();
826 HloInstruction* source_halo_slice = hlo;
827 if (halo_size != halo_shape.dimensions(dim)) {
828 halo_shape.set_dimensions(dim, halo_size);
829 std::vector<int64_t> halo_start_indices(halo_shape.rank(), 0);
830 halo_start_indices[dim] = halo_left_skips;
831 std::vector<int64_t> halo_limit_indices(halo_shape.dimensions().begin(),
832 halo_shape.dimensions().end());
833 halo_limit_indices[dim] += halo_left_skips;
834 std::vector<int64_t> halo_slice_strides(halo_shape.rank(), 1);
835 source_halo_slice = b->AddInstruction(
836 HloInstruction::CreateSlice(halo_shape, hlo, halo_start_indices,
837 halo_limit_indices, halo_slice_strides));
838 }
839 auto right_halo =
840 collective_ops_creator.create_cross_partition_collective_permute(
841 b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
842 concat_pieces.push_back(right_halo);
843 }
844
845 auto concat = concat_pieces[0];
846 // Concat with halos/padding.
847 if (concat_pieces.size() > 1) {
848 auto concat_shape = hlo->shape();
849 int64_t concat_dim_size = 0;
850 for (auto piece : concat_pieces) {
851 concat_dim_size += piece->shape().dimensions(dim);
852 }
853 concat_shape.set_dimensions(dim, concat_dim_size);
854 concat = b->AddInstruction(
855 HloInstruction::CreateConcatenate(concat_shape, concat_pieces, dim));
856 }
857
858 return concat;
859 }
860
861 std::optional<HloInstruction*> ExchangeHalo(
862 HloInstruction* hlo,
863 std::vector<OffsetCalculation> left_halo_size_functions,
864 std::vector<OffsetCalculation> right_halo_size_functions,
865 const HloSharding& target,
866 const SPMDCollectiveOpsCreator& collective_ops_creator,
867 int64_t* next_channel_id, SpmdBuilder* b) {
868 CHECK(left_halo_size_functions.size() == hlo->shape().rank());
869 CHECK(right_halo_size_functions.size() == hlo->shape().rank());
870
871 HloInstruction* visiting_hlo = hlo;
872 for (int dim = 0; dim < hlo->shape().rank(); ++dim) {
873 auto concat = ExchangeHalo(visiting_hlo, left_halo_size_functions[dim],
874 right_halo_size_functions[dim], dim, target,
875 collective_ops_creator, next_channel_id, b);
876 if (!concat) {
877 return std::nullopt;
878 }
879 visiting_hlo = *concat;
880 }
881 return visiting_hlo;
882 }
883
884 std::optional<HloInstruction*> ExchangeHaloAndGetValidData(
885 HloInstruction* hlo, const Shape& base_shape,
886 const OffsetCalculation& left_halo_size_function,
887 const OffsetCalculation& right_halo_size_function,
888 int64_t explicit_left_padding_on_full_shape, int64_t padded_full_shape_size,
889 int64_t shard_size_with_halo, int64_t dim, const HloSharding& target,
890 HloInstruction* offset_on_padded_shape, HloInstruction* pad_value,
891 HloInstruction* partition_ordinal,
892 const SPMDCollectiveOpsCreator& collective_ops_creator,
893 int64_t* next_channel_id, SpmdBuilder* b, bool mask_invalid_region) {
894 auto halo_exchange_result =
895 ExchangeHalo(hlo, left_halo_size_function, right_halo_size_function, dim,
896 target, collective_ops_creator, next_channel_id, b);
897 if (!halo_exchange_result) {
898 return std::nullopt;
899 }
900 auto concat = *halo_exchange_result;
901 int64_t shard_count = target.tile_assignment().dim(dim);
902 int64_t max_left_halo_size =
903 left_halo_size_function.MaxInRange(1, shard_count);
904
905 // Now we determine if we need extra padding after the concat.
906 //
907 // The max of halo size or the first shard's explicit left padding.
908 int64_t max_left_halo_or_padding_size =
909 std::max(max_left_halo_size, explicit_left_padding_on_full_shape);
910 // The calculation that returns the dynamic slice index for a shard on the
911 // padded concat, which is the difference between
912 // max_left_halo_or_padding_size and its left halo size.
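// For example, if max_left_halo_or_padding_size is 3 and shard i has a left
// halo of size 1, shard i's slice starts at offset 2 of the padded concat.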
913 auto start_offset_on_padded_concat_calculation =
914 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
915 0, max_left_halo_or_padding_size, 1)) -
916 left_halo_size_function;
917
918 // See if we need to pad the concat before dynamic slice.
919 int64_t extra_left_padding =
920 std::max(int64_t{0}, max_left_halo_or_padding_size -
921 std::max(int64_t{0}, max_left_halo_size));
922 int64_t extra_right_padding =
923 start_offset_on_padded_concat_calculation.MaxInRange(0, shard_count) +
924 shard_size_with_halo - concat->shape().dimensions(dim) -
925 extra_left_padding;
926 extra_right_padding = std::max(int64_t{0}, extra_right_padding);
927 if (extra_left_padding > 0 || extra_right_padding > 0) {
928 PaddingConfig padding_config;
929 auto padded_concat_shape = concat->shape();
930 for (int64_t i = 0; i < base_shape.rank(); ++i) {
931 auto padding_config_dim = padding_config.add_dimensions();
932 padding_config_dim->set_interior_padding(0);
933 padding_config_dim->set_edge_padding_low(0);
934 padding_config_dim->set_edge_padding_high(0);
935 if (i != dim) {
936 continue;
937 }
938 padding_config_dim->set_edge_padding_low(extra_left_padding);
939 padding_config_dim->set_edge_padding_high(extra_right_padding);
940 padded_concat_shape.set_dimensions(dim, concat->shape().dimensions(dim) +
941 extra_left_padding +
942 extra_right_padding);
943 }
944 concat = b->AddInstruction(HloInstruction::CreatePad(
945 padded_concat_shape, concat, pad_value, padding_config));
946 }
947
948 auto valid_slice = concat;
949 if (shard_size_with_halo != concat->shape().dimensions(dim)) {
950 // Concat is bigger than the shard shape, so we need a dynamic slice.
951 CHECK_LT(shard_size_with_halo, concat->shape().dimensions(dim));
952 auto slice_shape = concat->shape();
953 slice_shape.set_dimensions(dim, shard_size_with_halo);
954
955 if (left_halo_size_function.IsConstant() &&
956 left_halo_size_function.Calculate(0) ==
957 explicit_left_padding_on_full_shape) {
958 std::vector<int64_t> start_indices(slice_shape.rank(), 0);
959 std::vector<int64_t> strides(slice_shape.rank(), 1);
960 valid_slice = b->AddInstruction(
961 HloInstruction::CreateSlice(slice_shape, concat, start_indices,
962 slice_shape.dimensions(), strides));
963 } else {
964 auto zero = b->AddInstruction(
965 HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
966 std::vector<HloInstruction*> slice_offsets(base_shape.rank(), zero);
967 slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate(
968 partition_ordinal, b);
969 valid_slice = b->AddInstruction(HloInstruction::CreateDynamicSlice(
970 slice_shape, concat, slice_offsets, slice_shape.dimensions()));
971 }
972 }
973
974 if (!mask_invalid_region) {
975 return valid_slice;
976 }
977
978 int64_t total_right_padding = padded_full_shape_size -
979 base_shape.dimensions(dim) -
980 explicit_left_padding_on_full_shape;
981 // Mask off garbage data due to uneven partition or low/high padding.
982 if (explicit_left_padding_on_full_shape > 0 || total_right_padding > 0) {
983 auto index_shape = ShapeUtil::ChangeElementType(valid_slice->shape(), S32);
984 auto iota = b->AddInstruction(HloInstruction::CreateIota(index_shape, dim));
985 auto broadcast_start_index_in_padded_shape =
986 b->AddInstruction(HloInstruction::CreateBroadcast(
987 index_shape, offset_on_padded_shape, {}));
988 auto index_in_padded_shape = b->AddInstruction(
989 HloInstruction::CreateBinary(index_shape, HloOpcode::kAdd, iota,
990 broadcast_start_index_in_padded_shape));
991 auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED);
992 std::vector<HloInstruction*> predicates;
993 if (explicit_left_padding_on_full_shape > 0) {
994 auto valid_index_start =
995 b->AddInstruction(HloInstruction::CreateBroadcast(
996 index_shape,
997 b->AddInstruction(
998 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
999 explicit_left_padding_on_full_shape))),
1000 {}));
1001 predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare(
1002 mask_shape, index_in_padded_shape, valid_index_start,
1003 ComparisonDirection::kGe)));
1004 }
1005 if (total_right_padding > 0) {
1006 auto valid_index_limit =
1007 b->AddInstruction(HloInstruction::CreateBroadcast(
1008 index_shape,
1009 b->AddInstruction(
1010 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
1011 base_shape.dimensions(dim) +
1012 explicit_left_padding_on_full_shape))),
1013 {}));
1014 predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare(
1015 mask_shape, index_in_padded_shape, valid_index_limit,
1016 ComparisonDirection::kLt)));
1017 }
1018 CHECK(!predicates.empty());
1019 auto is_valid =
1020 predicates.size() == 2
1021 ? b->AddInstruction(HloInstruction::CreateBinary(
1022 mask_shape, HloOpcode::kAnd, predicates[0], predicates[1]))
1023 : predicates[0];
1024 auto masking_value = b->AddInstruction(
1025 HloInstruction::CreateBroadcast(valid_slice->shape(), pad_value, {}));
1026 valid_slice = b->AddInstruction(
1027 HloInstruction::CreateTernary(valid_slice->shape(), HloOpcode::kSelect,
1028 is_valid, valid_slice, masking_value));
1029 }
1030 return valid_slice;
1031 }
1032
1033 HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original,
1034 absl::Span<const int64_t> dims) {
1035 if (original.sharding().IsTileMaximal()) {
1036 return original.hlo();
1037 }
1038 // Create a window config to perform halo exchange for unevenly partitioned
1039 // reverse dimensions.
1040 Window window;
1041 for (int64_t i = 0; i < original.base_shape().rank(); ++i) {
1042 WindowDimension* dim = window.add_dimensions();
1043 dim->set_size(1);
1044 dim->set_stride(1);
1045 dim->set_window_dilation(1);
1046 dim->set_window_reversal(false);
1047 int64_t low_padding = 0;
1048 if (absl::c_linear_search(dims, i)) {
1049 low_padding = RoundUpTo(original.base_shape().dimensions(i),
1050 original.sharding().tile_assignment().dim(i)) -
1051 original.base_shape().dimensions(i);
1052 }
1053 dim->set_padding_low(low_padding);
1054 dim->set_padding_high(0);
1055 dim->set_base_dilation(1);
1056 }
1057
1058 auto reshard_window = original.ReshardAsWindowedInput(
1059 window, original.sharding(),
1060 CreateZero(ShapeUtil::MakeShape(original.base_shape().element_type(), {}),
1061 original.state().b),
1062 /*mask_invalid_region=*/false);
1063 if (!reshard_window.has_value()) {
1064 return nullptr;
1065 }
1066 CHECK(!reshard_window->dynamic_slice_index_on_output.has_value());
1067 return reshard_window->sharded_input;
1068 }
1069
1070 bool IsNanSafeGt(HloComputation* comp) {
1071 namespace m = match;
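// The matchers below recognize the canonical NaN-safe comparator: the f32 (or
// bf16-converted) key is bitcast to an integer and negative values are
// remapped as INT_MAX minus the unsigned bits, so that a signed integer Gt
// yields a total order over the keys, including NaNs.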
1072 auto match_bitcast_f32 = [](int64_t parameter_number) {
1073 auto param = m::Parameter(parameter_number)
1074 .WithShape(m::Shape().WithElementType(F32));
1075 auto param_s32 =
1076 m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
1077 auto param_u32 =
1078 m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
1079 return m::Select(
1080 m::Lt(param_s32, m::ConstantScalar(0)),
1081 m::BitcastConvert(
1082 m::Subtract(m::ConstantScalar(std::numeric_limits<int32_t>::max()),
1083 param_u32))
1084 .WithShape(m::Shape().WithElementType(S32)),
1085 param_s32);
1086 };
1087 auto match_bitcast_bf16 = [](int64_t parameter_number) {
1088 auto param = m::Convert(m::Parameter(parameter_number)
1089 .WithShape(m::Shape().WithElementType(BF16)))
1090 .WithShape(m::Shape().WithElementType(F32));
1091 auto param_s32 =
1092 m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
1093 auto param_u32 =
1094 m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
1095 return m::Select(
1096 m::Lt(param_s32, m::ConstantScalar(0)),
1097 m::BitcastConvert(
1098 m::Subtract(m::ConstantScalar(std::numeric_limits<int32_t>::max()),
1099 param_u32))
1100 .WithShape(m::Shape().WithElementType(S32)),
1101 param_s32);
1102 };
1103 // If the root is a kSelect (indices break ties), match its value comparison.
1104 if (comp->root_instruction()->opcode() == HloOpcode::kSelect) {
1105 return Match(comp->root_instruction()->operand(2),
1106 m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) ||
1107 Match(comp->root_instruction()->operand(2),
1108 m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1)));
1109 }
1110 return Match(comp->root_instruction(),
1111 m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) ||
1112 Match(comp->root_instruction(),
1113 m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1)));
1114 }
1115
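// Returns k if `hlo` is a sort-based TopK pattern: a two-operand sort of
// values and an S32 iota of indices with a NaN-safe comparator, whose users
// slice out the first k elements along the sort dimension, where the sort
// dimension is the only sharded dimension and k is smaller than the
// per-partition size. Otherwise returns std::nullopt.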
1116 std::optional<int64_t> GetKValueInTopKWhenPartitionSortDim(
1117 HloInstruction* hlo) {
1118 HloSortInstruction* sort = DynCast<HloSortInstruction>(hlo);
1119 if (sort == nullptr || sort->operand_count() != 2) {
1120 return std::nullopt;
1121 }
1122 if (!IsNanSafeGt(sort->to_apply())) {
1123 return std::nullopt;
1124 }
1125 HloInstruction* data = sort->mutable_operand(0);
1126 HloIotaInstruction* iota =
1127 DynCast<HloIotaInstruction>(sort->mutable_operand(1));
1128 const PrimitiveType element_type = data->shape().element_type();
1129 if (iota == nullptr || iota->shape().element_type() != S32 ||
1130 iota->opcode() != HloOpcode::kIota ||
1131 iota->iota_dimension() != sort->sort_dimension()) {
1132 return std::nullopt;
1133 }
1134
1135 const int64_t sort_dim = sort->sort_dimension();
1136
1137 if (element_type != F32 && element_type != BF16 && element_type != S32 &&
1138 element_type != U32) {
1139 return std::nullopt;
1140 }
1141
1142 bool supported = true;
1143 std::optional<int64_t> k;
1144 for (HloInstruction* gte : sort->users()) {
1145 if (gte->opcode() != HloOpcode::kGetTupleElement) {
1146 supported = false;
1147 break;
1148 }
1149
1150 const HloInstruction* slice = gte->users()[0];
1151 if (slice->opcode() != HloOpcode::kSlice) {
1152 // Non-slice user means we are not doing a TopK
1153 supported = false;
1154 break;
1155 }
1156 if (absl::c_any_of(slice->slice_starts(), [](int x) { return x != 0; }) ||
1157 absl::c_any_of(slice->slice_strides(), [](int x) { return x != 1; })) {
1158 // Strided slices, or slices that don't start at the beginning, aren't supported.
1159 supported = false;
1160 break;
1161 }
1162 for (int64_t dim = 0; dim < data->shape().dimensions_size(); dim++) {
1163 if (dim == sort_dim) {
1164 continue;
1165 }
1166 if (slice->slice_limits(dim) !=
1167 slice->operand(0)->shape().dimensions(dim)) {
1168 // Slicing along the other dimension isn't supported.
1169 supported = false;
1170 break;
1171 }
1172 }
1173 if (!k.has_value()) {
1174 k = slice->slice_limits(sort_dim);
1175 } else if (k != slice->slice_limits(sort_dim)) {
1176 // Different k for the different operands isn't supported.
1177 supported = false;
1178 break;
1179 }
1180 }
1181 if (k == std::nullopt || !supported) {
1182 return std::nullopt;
1183 }
1184
1185 // Only support when sort dim is sharded.
1186 if (!data->has_sharding()) {
1187 return std::nullopt;
1188 }
1189 const HloSharding& sharding = sort->operand(0)->sharding();
1190
1191 if (sharding.IsTileMaximal()) {
1192 return std::nullopt;
1193 }
1194
1195 // Check if partitioned at sort dimension.
1196 for (int64_t dim = 0; dim < sort->shape().tuple_shapes(0).dimensions_size();
1197 ++dim) {
1198 if (sharding.tile_assignment().dim(dim) > 1) {
1199 if (dim != sort_dim) {
1200 return std::nullopt;
1201 }
1202 }
1203 }
1204
1205 // Checks if partition size is smaller than k.
1206 const int64_t shard_count = sharding.tile_assignment().dim(sort_dim);
1207
1208 if (shard_count <= 1) {
1209 return std::nullopt;
1210 }
1211
1212 const int64_t input_size = hlo->operand(0)->shape().dimensions(sort_dim);
1213 const int64_t per_partition_size = CeilOfRatio(input_size, shard_count);
1214
1215 if (k.value() >= per_partition_size) {
1216 return std::nullopt;
1217 }
1218
1219 return k;
1220 }
1221
1222 // Slice first k elements from sort_dim.
1223 HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
1224 int64_t slice_dim, int64_t k) {
1225 const Shape& hlo_shape = hlo->shape();
1226 auto hlo_dims = hlo_shape.dimensions();
1227 std::vector<int64_t> start_indices(hlo_shape.dimensions_size(), 0);
1228 std::vector<int64_t> limit_indices(hlo_dims.begin(), hlo_dims.end());
1229 std::vector<int64_t> strides(hlo_shape.dimensions_size(), 1);
1230 limit_indices[slice_dim] = k;
1231 auto output_shape = hlo_shape;
1232 output_shape.set_dimensions(slice_dim, k);
1233 return builder->AddInstruction(HloInstruction::CreateSlice(
1234 output_shape, hlo, start_indices, limit_indices, strides));
1235 }
1236
1237 // Returns the number of shards along the given dimension (1 if unsharded).
1238 int64_t ShardCountAtDim(const HloSharding& sharding, int64_t dim) {
1239 if (sharding.IsTileMaximal()) {
1240 return 1;
1241 }
1242 return sharding.tile_assignment().dim(dim);
1243 }
1244
1245 std::optional<std::vector<std::pair<int64_t, int64_t>>>
1246 GetReshardAllToAllSourceTargetDims(const HloSharding& source,
1247 const HloSharding& target) {
1248 if (source.IsTileMaximal() || target.IsTileMaximal() ||
1249 source.tile_assignment().num_dimensions() !=
1250 target.tile_assignment().num_dimensions() ||
1251 source.NumTiles() != target.NumTiles()) {
1252 return std::nullopt;
1253 }
1254 // Map each partition count to the dimensions with that count, for the
1255 // dimensions whose partition counts differ between source and target.
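// For example, with source tile dims [4,2,8] and target tile dims [8,2,4],
// dims 0 and 2 differ, so source_size_to_dim = {4: [0], 8: [2]} and
// target_size_to_dim = {8: [0], 4: [2]}; the loop below then emits the
// single source-target pair (2, 0).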
1256 std::map<int64_t, std::vector<int64_t>> source_size_to_dim;
1257 std::map<int64_t, std::vector<int64_t>> target_size_to_dim;
1258 for (int64_t i = 0; i < source.tile_assignment().num_dimensions(); ++i) {
1259 if (source.tile_assignment().dim(i) == target.tile_assignment().dim(i)) {
1260 continue;
1261 }
1262 source_size_to_dim[source.tile_assignment().dim(i)].push_back(i);
1263 target_size_to_dim[target.tile_assignment().dim(i)].push_back(i);
1264 }
1265 // In order to shard via AllToAll, source_size_to_dim and target_size_to_dim
1266 // must have the same distribution.
1267 if (source_size_to_dim.empty() ||
1268 source_size_to_dim.size() != target_size_to_dim.size()) {
1269 return std::nullopt;
1270 }
1271 for (const auto& entry : source_size_to_dim) {
1272 auto target_it = target_size_to_dim.find(entry.first);
1273 if (target_it == target_size_to_dim.end() ||
1274 target_it->second.size() != entry.second.size()) {
1275 return std::nullopt;
1276 }
1277 }
1278 std::vector<std::pair<int64_t, int64_t>> result;
1279 auto remove_entry = [](int64_t size, int64_t dim,
1280 std::map<int64_t, std::vector<int64_t>>& size_to_dim) {
1281 size_to_dim[size].erase(
1282 std::remove_if(size_to_dim[size].begin(), size_to_dim[size].end(),
1283 [dim](int64_t a) { return a == dim; }),
1284 size_to_dim[size].end());
1285 if (size_to_dim[size].empty()) {
1286 size_to_dim.erase(size);
1287 }
1288 };
1289 // Find one pair of dimensions to swap at a time.
1290 while (!source_size_to_dim.empty()) {
1291 int64_t source_size = source_size_to_dim.begin()->first;
1292 int64_t i = source_size_to_dim.begin()->second.back();
1293 int64_t target_i_size = target.tile_assignment().dim(i);
1294 if (target_i_size == source_size) {
1295 remove_entry(source_size, i, source_size_to_dim);
1296 remove_entry(source_size, i, target_size_to_dim);
1297 continue;
1298 }
1299 auto j_it = source_size_to_dim[target_i_size].begin();
1300 int64_t j = *j_it;
1301 if (source_size == 1) {
1302 // If possible, find a j where the target partition count is not one, so
1303 // that when we swap, the resulting size-1 dimension will still be useful
1304 // to other dimensions.
1305 while (target.tile_assignment().dim(j) == 1) {
1306 if (++j_it == source_size_to_dim[target_i_size].end()) {
1307 break;
1308 }
1309 j = *j_it;
1310 }
1311 } else if (target_i_size % source_size == 0) {
1312 // If possible, find a j where the target partition count is source_size,
1313 // so that we can do a single swap.
1314 while (target.tile_assignment().dim(j) != source_size) {
1315 if (++j_it == source_size_to_dim[target_i_size].end()) {
1316 break;
1317 }
1318 j = *j_it;
1319 }
1320 } else {
1321 return std::nullopt;
1322 }
1323 result.emplace_back(j, i);
1324 remove_entry(target_i_size, i, target_size_to_dim);
1325 source_size_to_dim.begin()->second.back() = j;
1326 remove_entry(target_i_size, j, source_size_to_dim);
1327 }
1328 return result;
1329 }
1330
1331 bool CanReshardWithCollectivePermute(const HloSharding& source,
1332 const HloSharding& target) {
1333 return !source.IsTileMaximal() && !target.IsTileMaximal() &&
1334 source.tile_assignment().dimensions() ==
1335 target.tile_assignment().dimensions() &&
1336 source.ReplicateOnLastTileDim() == target.ReplicateOnLastTileDim() &&
1337 source.tile_assignment() != target.tile_assignment();
1338 }
1339
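// Aligns the device groups of `grouped_sharding` with those of `reference`:
// every group must map onto exactly one reference group (the same-numbered
// one unless `ignore_group_order` is set), and the group-local sharding is
// permuted so that it refers to devices in the reference group's order.
// Returns std::nullopt if `requires_compatibility` is set and the groups
// cannot be aligned consistently.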
1340 std::optional<GroupedSharding> AlignGroupsWithInternal(
1341 GroupedSharding grouped_sharding, const GroupedSharding& reference,
1342 bool requires_compatibility, bool ignore_group_order) {
1343 // Returns src -> dst index mapping.
1344 auto get_permutation = [](absl::Span<const int64_t> src,
1345 absl::Span<const int64_t> dst) {
1346 CHECK_EQ(src.size(), dst.size());
1347 absl::flat_hash_map<int64_t, int64_t> dst_reverse_map;
1348 for (int64_t i = 0; i < dst.size(); ++i) {
1349 dst_reverse_map[dst[i]] = i;
1350 }
1351 std::vector<int64_t> permutation(src.size());
1352 for (int64_t i = 0; i < src.size(); ++i) {
1353 auto it = dst_reverse_map.find(src[i]);
1354 CHECK(it != dst_reverse_map.end());
1355 permutation[i] = it->second;
1356 }
1357 return permutation;
1358 };
1359 CHECK_EQ(grouped_sharding.device_groups.size(),
1360 reference.device_groups.size());
1361 absl::flat_hash_map<int64_t, int64_t> device_to_ref_group;
1362 for (int64_t g = 0; g < reference.device_groups.size(); ++g) {
1363 for (int64_t device : reference.device_groups[g]) {
1364 device_to_ref_group[device] = g;
1365 }
1366 }
1367 auto unique_ref_dev_group =
1368 [&](absl::Span<const int64_t> devices) -> int64_t {
1369 int64_t ref_g = -1;
1370 for (int64_t device : devices) {
1371 if (ref_g == -1) {
1372 ref_g = device_to_ref_group[device];
1373 } else if (ref_g != device_to_ref_group[device]) {
1374 return -1;
1375 }
1376 }
1377 return ref_g;
1378 };
1379 bool matching_groups = true;
1380 std::vector<int64_t> original_src_to_ref_permutation;
1381 for (int64_t g = 0; g < grouped_sharding.device_groups.size(); ++g) {
1382 int64_t ref_g = unique_ref_dev_group(grouped_sharding.device_groups[g]);
1383 if (ref_g < 0 || (!ignore_group_order && g != ref_g)) {
1384 if (requires_compatibility) {
1385 return std::nullopt;
1386 }
1387 matching_groups = false;
1388 break;
1389 }
1390 if (g == 0) {
1391 original_src_to_ref_permutation = get_permutation(
1392 grouped_sharding.device_groups[g], reference.device_groups[ref_g]);
1393 } else if (requires_compatibility) {
1394 if (original_src_to_ref_permutation !=
1395 get_permutation(grouped_sharding.device_groups[g],
1396 reference.device_groups[ref_g])) {
1397 return std::nullopt;
1398 }
1399 }
1400 }
1401 if (matching_groups && !grouped_sharding.sharding.IsTileMaximal()) {
1402 auto tiles = grouped_sharding.sharding.tile_assignment();
1403 tiles.Each([&](absl::Span<const int64_t> indices, int64_t* device) {
1404 *device = original_src_to_ref_permutation[*device];
1405 });
1406 grouped_sharding.sharding =
1407 grouped_sharding.sharding.ReplicateOnLastTileDim()
1408 ? HloSharding::PartialTile(tiles)
1409 : HloSharding::Tile(tiles);
1410 }
1411 grouped_sharding.device_groups = std::move(reference.device_groups);
1412 return grouped_sharding;
1413 }
1414
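// Alignment without the compatibility requirement cannot fail, so the optional
// returned by AlignGroupsWithInternal is unconditionally dereferenced here.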
1415 GroupedSharding AlignGroupsWith(GroupedSharding grouped_sharding,
1416 const GroupedSharding& reference,
1417 bool ignore_group_order) {
1418 return *AlignGroupsWithInternal(std::move(grouped_sharding), reference,
1419 /*requires_compatibility=*/false,
1420 ignore_group_order);
1421 }
1422
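// Strict variant: returns std::nullopt unless the groups can be aligned with a
// single consistent device permutation and in matching group order.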
1423 std::optional<GroupedSharding> AlignGroupsWithIfCompatible(
1424 GroupedSharding grouped_sharding, const GroupedSharding& reference) {
1425 return AlignGroupsWithInternal(std::move(grouped_sharding), reference,
1426 /*requires_compatibility=*/true,
1427 /*ignore_group_order=*/false);
1428 }
1429
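// Groups `sharding` on sharding_dims and `reference` on reference_dims, aligns
// the former's device groups to the latter's, and ungroups the result so both
// shardings place the corresponding dimensions on the same devices.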
1430 HloSharding AlignShardingOnDims(const HloSharding& sharding,
1431 absl::Span<const int64_t> sharding_dims,
1432 const HloSharding& reference,
1433 absl::Span<const int64_t> reference_dims) {
1434 auto sharding_grouped =
1435 hlo_sharding_util::GroupShardingOnDims(sharding, sharding_dims);
1436 auto reference_grouped =
1437 hlo_sharding_util::GroupShardingOnDims(reference, reference_dims);
1438 return hlo_sharding_util::UngroupSharding(
1439 AlignGroupsWith(sharding_grouped, reference_grouped));
1440 }
1441
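// Shrinks the base shape to its per-group size: each grouped data dimension is
// divided (rounding up) by its group count; group dims at or beyond the
// shape's rank (e.g. a partial-replication dim) are ignored.
// Example: f32[8,8] with group_dims={0}, group_dim_sizes={2} -> f32[4,8].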
1442 Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding,
1443 const Shape& original_base_shape) {
1444 auto result = original_base_shape;
1445 for (int64_t i = 0; i < grouped_sharding.group_dims.size(); ++i) {
1446 int64_t dim = grouped_sharding.group_dims[i];
1447 if (dim >= original_base_shape.rank()) {
1448 continue;
1449 }
1450 int64_t groups = grouped_sharding.group_dim_sizes[i];
1451 result.set_dimensions(dim, CeilOfRatio(result.dimensions(dim), groups));
1452 }
1453 return result;
1454 }
1455
1456 namespace {
1457
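// Converts the global partition id into the partition's index within its own
// device group, using a constant lookup table indexed by device id.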
1458 HloInstruction* GetInGroupPartitionId(
1459 HloInstruction* partition_id,
1460 const std::vector<std::vector<int64_t>>& device_groups, SpmdBuilder* b) {
1461 int64_t total_devices = device_groups.size() * device_groups[0].size();
1462 std::vector<uint32_t> in_group_ids(total_devices);
1463 for (uint32_t i = 0; i < device_groups.size(); ++i) {
1464 for (uint32_t j = 0; j < device_groups[i].size(); ++j) {
1465 in_group_ids[device_groups[i][j]] = j;
1466 }
1467 }
1468 auto id_table = b->AddInstruction(HloInstruction::CreateConstant(
1469 LiteralUtil::CreateR1<uint32_t>(in_group_ids)));
1470 return b->AddInstruction(HloInstruction::CreateReshape(
1471 ShapeUtil::MakeScalarShape(U32),
1472 b->AddInstruction(HloInstruction::CreateDynamicSlice(
1473 ShapeUtil::MakeShape(U32, {1}), id_table, {partition_id}, {1}))));
1474 }
1475
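// Wraps an SPMDCollectiveOpsCreator so that collectives run within each device
// group: the partition id is remapped to the in-group id, and every partition
// subgroup is expanded into per-group lists of global device ids.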
1476 SPMDCollectiveOpsCreator GetPerGroupCollectiveOpsCreator(
1477 const SPMDCollectiveOpsCreator& creator,
1478 const std::vector<std::vector<int64_t>>& device_groups) {
1479 SPMDCollectiveOpsCreator result;
1480 result.create_partition_id = [creator, device_groups](SpmdBuilder* b) {
1481 return GetInGroupPartitionId(creator.create_partition_id(b), device_groups,
1482 b);
1483 };
1484 auto expand_partition_groups =
1485 [device_groups](
1486 const std::vector<std::vector<int64_t>>& partition_subgroups) {
1487 if (partition_subgroups.empty()) {
1488 return device_groups;
1489 }
1490 std::vector<std::vector<int64_t>> result(partition_subgroups.size() *
1491 device_groups.size());
1492 for (int64_t g = 0; g < device_groups.size(); ++g) {
1493 for (int64_t i = 0; i < partition_subgroups.size(); ++i) {
1494 result[g * partition_subgroups.size() + i].resize(
1495 partition_subgroups[i].size());
1496 for (int64_t j = 0; j < partition_subgroups[i].size(); ++j) {
1497 result[g * partition_subgroups.size() + i][j] =
1498 device_groups[g][partition_subgroups[i][j]];
1499 }
1500 }
1501 }
1502 return result;
1503 };
1504 result.create_cross_partition_all_reduce =
1505 [creator, expand_partition_groups](
1506 SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction,
1507 const std::vector<std::vector<int64_t>>& partition_subgroups,
1508 int64_t channel_id) {
1509 return creator.create_cross_partition_all_reduce(
1510 b, operand, reduction, expand_partition_groups(partition_subgroups),
1511 channel_id);
1512 };
1513 result.create_cross_partition_collective_permute =
1514 [creator, device_groups](
1515 SpmdBuilder* b, HloInstruction* operand,
1516 std::vector<std::pair<int64_t, int64_t>>& src_dst_pairs,
1517 int64_t next_channel_id) {
1518 std::vector<std::pair<int64_t, int64_t>> expanded_pairs(
1519 src_dst_pairs.size() * device_groups.size());
1520 for (int64_t g = 0; g < device_groups.size(); ++g) {
1521 for (int64_t i = 0; i < src_dst_pairs.size(); ++i) {
1522 expanded_pairs[g * src_dst_pairs.size() + i] =
1523 std::pair<int64_t, int64_t>{
1524 device_groups[g][src_dst_pairs[i].first],
1525 device_groups[g][src_dst_pairs[i].second]};
1526 }
1527 }
1528 return creator.create_cross_partition_collective_permute(
1529 b, operand, expanded_pairs, next_channel_id);
1530 };
1531 result.create_cross_partition_all_to_all =
1532 [creator, expand_partition_groups](
1533 SpmdBuilder* b, absl::Span<HloInstruction* const> operands,
1534 const std::vector<std::vector<int64_t>>& partition_subgroups,
1535 int64_t channel_id, std::optional<int64_t> split_dimension) {
1536 return creator.create_cross_partition_all_to_all(
1537 b, operands, expand_partition_groups(partition_subgroups),
1538 channel_id, split_dimension);
1539 };
1540 if (creator.create_cross_partition_all_gather) {
1541 result.create_cross_partition_all_gather =
1542 [creator, expand_partition_groups](
1543 SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape,
1544 const std::vector<std::vector<int64_t>>& partition_subgroups,
1545 int64_t channel_id, int64_t all_gather_dimension) {
1546 return creator.create_cross_partition_all_gather(
1547 b, operand, ag_shape,
1548 expand_partition_groups(partition_subgroups), channel_id,
1549 all_gather_dimension);
1550 };
1551 }
1552 return result;
1553 }
1554
1555 } // namespace
1556
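// Builds a partitioning state scoped to the given device groups: collectives
// and the partition id operate in-group, and reshards are cached per group
// assignment (keyed by the stringified device groups) so cache entries are not
// shared across different groupings.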
1557 PartitionedHlo::PartitioningState CreatePerGroupPartitioningState(
1558 const PartitionedHlo::PartitioningState& state,
1559 const std::vector<std::vector<int64_t>>& device_groups, SpmdBuilder* b) {
1560 auto result = state;
1561 result.collective_ops_creator = GetPerGroupCollectiveOpsCreator(
1562 state.collective_ops_creator, device_groups);
1563 result.partition_id =
1564 GetInGroupPartitionId(state.partition_id, device_groups, b);
1565 // Create a string key for the groups.
1566 std::vector<std::string> per_group_strings(device_groups.size());
1567 for (int64_t i = 0; i < per_group_strings.size(); ++i) {
1568 per_group_strings[i] = absl::StrJoin(device_groups[i], ",");
1569 }
1570 auto& grouped_cache =
1571 state.reshard_cache->groupd_caches[absl::StrJoin(per_group_strings, ";")];
1572 if (!grouped_cache) {
1573 grouped_cache = std::make_unique<PartitionedHlo::ReshardCache>();
1574 }
1575 result.reshard_cache = grouped_cache.get();
1576 return result;
1577 }
1578
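// Slices a group's shard out of a replicated value. A lookup table maps the
// partition id to its group id, which then drives a dynamic-slice using a
// group-level tile sharding over group_dims, so all devices in a group get the
// same slice.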
1579 HloInstruction* PerGroupSliceFromReplicated(
1580 HloInstruction* replicated, HloInstruction* partition_id,
1581 const std::vector<std::vector<int64_t>>& device_groups,
1582 absl::Span<const int64_t> group_dims,
1583 absl::Span<const int64_t> group_dim_sizes, SpmdBuilder* b) {
1584 std::vector<uint32_t> group_ids(device_groups.size() *
1585 device_groups[0].size());
1586 for (int64_t g = 0; g < device_groups.size(); ++g) {
1587 for (int64_t device : device_groups[g]) {
1588 group_ids[device] = g;
1589 }
1590 }
1591 auto group_id_table = b->AddInstruction(HloInstruction::CreateConstant(
1592 LiteralUtil::CreateR1<uint32_t>(group_ids)));
1593 auto group_id = b->AddInstruction(HloInstruction::CreateReshape(
1594 ShapeUtil::MakeScalarShape(U32),
1595 b->AddInstruction(HloInstruction::CreateDynamicSlice(
1596 ShapeUtil::MakeShape(U32, {1}), group_id_table, {partition_id},
1597 {1}))));
1598 std::vector<int64_t> group_level_tile_dims(replicated->shape().rank(), 1);
1599 for (int64_t i = 0; i < group_dims.size(); ++i) {
1600 group_level_tile_dims[group_dims[i]] = group_dim_sizes[i];
1601 }
1602 Array<int64_t> group_level_tile(group_level_tile_dims);
1603 group_level_tile.Each([&](absl::Span<const int64_t> indices, int64_t* group) {
1604 *group = 0;
1605 for (int64_t dim : group_dims) {
1606 *group *= group_level_tile.dim(dim);
1607 *group += indices[dim];
1608 }
1609 });
1610 auto group_level_sharding = HloSharding::Tile(group_level_tile);
1611 auto padded_hlo = PadBaseShapeBeforeUnevenTiledSharding(
1612 replicated, group_level_sharding, b);
1613 auto shard_shape =
1614 MakePartitionedShape(replicated->shape(), group_level_sharding);
1615 return b->AddInstruction(HloInstruction::CreateDynamicSlice(
1616 shard_shape, padded_hlo,
1617 MakePartitionOffsets(replicated->shape(), group_level_sharding, group_id,
1618 b),
1619 shard_shape.dimensions()));
1620 }
1621
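// Finds tile dimensions of `sharding` along which grouping reproduces exactly
// the given device_groups (the same grouping of devices); returns std::nullopt
// if no such set of dimensions exists.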
1622 std::optional<std::vector<int64_t>> FindMatchingPartitionedDimsForGrouping(
1623 const HloSharding& sharding,
1624 const std::vector<std::vector<int64_t>>& device_groups) {
1625 if (sharding.IsTileMaximal() || device_groups.size() < 2) {
1626 return std::nullopt;
1627 }
1628 std::vector<int64_t> dims;
1629 if (device_groups[0].size() < 2) {
1630     // Trivial case: single-member groups.
1631 for (int64_t i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
1632 if (sharding.tile_assignment().dim(i) > 1) {
1633 dims.push_back(i);
1634 }
1635 }
1636 return dims;
1637 }
1638 int64_t rank = sharding.tile_assignment().num_dimensions();
1639 absl::flat_hash_map<int64_t, std::vector<int64_t>> device_to_index;
1640 sharding.tile_assignment().Each(
1641 [&](absl::Span<const int64_t> index, int64_t device) {
1642 device_to_index[device] =
1643 std::vector<int64_t>(index.begin(), index.begin() + rank);
1644 });
1645 int64_t group_count = 1;
1646 for (int64_t i = 0; i < rank; ++i) {
1647 if (device_to_index[device_groups[0][0]][i] ==
1648 device_to_index[device_groups[0][1]][i]) {
1649 dims.push_back(i);
1650 group_count *= sharding.tile_assignment().dim(i);
1651 }
1652 }
1653 if (group_count != device_groups.size()) {
1654 return std::nullopt;
1655 }
1656 for (const auto& group : device_groups) {
1657 for (int64_t i = 1; i < group.size(); ++i) {
1658 if (absl::c_any_of(dims, [&](const int64_t dim) {
1659 return device_to_index[group[i]][dim] !=
1660 device_to_index[group[0]][dim];
1661 })) {
1662 return std::nullopt;
1663 }
1664 }
1665 }
1666 return dims;
1667 }
1668
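// Creates a sharding for target_shape whose tiling along target_dims matches
// source_sharding's tiling along source_dims; any partitions the source has on
// other dimensions are folded into a partial-replication dimension, and the
// result is device-aligned with the source on those dims.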
1669 HloSharding CreateMatchingShardingOnDims(
1670 const Shape& target_shape, const HloSharding& source_sharding,
1671 absl::Span<const int64_t> target_dims,
1672 absl::Span<const int64_t> source_dims) {
1673 CHECK(target_dims.size() == source_dims.size())
1674 << "Expected 1:1 match between parallel dimensions";
1675 if (source_sharding.IsReplicated()) {
1676 return HloSharding::Replicate();
1677 }
1678 absl::InlinedVector<int64_t, 4> tile_dims(target_shape.dimensions_size(), 1);
1679 int num_tiles = 1;
1680 for (int i = 0, end = target_dims.size(); i < end; ++i) {
1681 num_tiles *= source_sharding.tile_assignment().dim(source_dims[i]);
1682 tile_dims[target_dims[i]] =
1683 source_sharding.tile_assignment().dim(source_dims[i]);
1684 }
1685   // If there is some partition across non-parallel dimensions in the other
1686   // operand, partially replicate the new sharding over those partitions.
1687 bool to_be_partially_replicated = false;
1688 if (num_tiles != source_sharding.tile_assignment().num_elements()) {
1689 CHECK_EQ(source_sharding.tile_assignment().num_elements() % num_tiles, 0);
1690 to_be_partially_replicated = true;
1691 tile_dims.push_back(source_sharding.tile_assignment().num_elements() /
1692 num_tiles);
1693 }
1694 auto tgt_tile_assignment = source_sharding.tile_assignment();
1695 tgt_tile_assignment.Reshape(tile_dims);
1696 if (to_be_partially_replicated) {
1697 return AlignShardingOnDims(HloSharding::PartialTile(tgt_tile_assignment),
1698 target_dims, source_sharding, source_dims);
1699 } else {
1700 return AlignShardingOnDims(HloSharding::Tile(tgt_tile_assignment),
1701 target_dims, source_sharding, source_dims);
1702 }
1703 }
1704
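// Computes index and operand shardings whose parallel dimensions are tiled
// identically and device-aligned. If one side has fewer tiles on the parallel
// dims, partitions are borrowed from its partial-replication dimension;
// returns std::nullopt when the two shardings cannot be reconciled.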
1705 std::optional<GatherParallelDimSharding>
1706 GatherOperandsShardedAcrossParallelDims(
1707 const HloInstruction& operand, const HloInstruction& indices,
1708 const hlo_sharding_util::GatherParallelDims& parallel_dims) {
1709 auto& indices_parallel_dims = parallel_dims.indices_parallel_dims;
1710 auto& operand_parallel_dims = parallel_dims.operand_parallel_dims;
1711 if (indices_parallel_dims.size() != operand_parallel_dims.size()) {
1712 return std::nullopt;
1713 }
1714 auto new_index_shard = indices.sharding();
1715 auto new_operand_shard = operand.sharding();
1716 int idx_parallel_tiles_num = new_index_shard.NumTiles(indices_parallel_dims);
1717 int op_parallel_tiles_num = new_operand_shard.NumTiles(operand_parallel_dims);
1718 if (idx_parallel_tiles_num == 1 && op_parallel_tiles_num == 1) {
1719 return std::nullopt;
1720 }
1721 absl::InlinedVector<int64_t, 1> indices_parallel_dims_ordered_as_operand;
1722 for (int idx : parallel_dims.index_parallel_in_dim) {
1723 if (idx != -1) {
1724 indices_parallel_dims_ordered_as_operand.push_back(idx);
1725 }
1726 }
1727 if (new_index_shard.IsReplicated()) {
1728 return GatherParallelDimSharding{
1729 CreateMatchingShardingOnDims(indices.shape(), new_operand_shard,
1730 indices_parallel_dims_ordered_as_operand,
1731 operand_parallel_dims),
1732 new_operand_shard};
1733 }
1734 if (new_operand_shard.IsReplicated()) {
1735 return GatherParallelDimSharding{
1736 new_index_shard,
1737 CreateMatchingShardingOnDims(operand.shape(), new_index_shard,
1738 operand_parallel_dims,
1739 indices_parallel_dims_ordered_as_operand)};
1740 }
1741
1742 // Parallel dimension distribution needs to be the same, so try to steal
1743 // sharding from partial replication to compensate.
1744 if (idx_parallel_tiles_num != op_parallel_tiles_num) {
1745 auto to_adjust_dims = operand_parallel_dims;
1746 auto target_dims = indices_parallel_dims_ordered_as_operand;
1747 HloSharding* target = &new_index_shard;
1748 HloSharding* to_adjust = &new_operand_shard;
1749 if (idx_parallel_tiles_num < op_parallel_tiles_num) {
1750 std::swap(to_adjust_dims, target_dims);
1751 std::swap(to_adjust, target);
1752 }
1753 if (!to_adjust->ReplicateOnLastTileDim()) {
1754 return std::nullopt;
1755 }
1756 auto new_tile_assignment_dims = to_adjust->tile_assignment().dimensions();
1757 for (int i = 0; i < to_adjust_dims.size(); ++i) {
1758 int64_t target_dim = target->tile_assignment().dim(target_dims[i]);
1759 int64_t to_adjust_dim =
1760 to_adjust->tile_assignment().dim(to_adjust_dims[i]);
1761 if (target_dim < to_adjust_dim) {
1762 return std::nullopt;
1763 }
1764 if (target_dim == to_adjust_dim) {
1765 continue;
1766 }
1767 int64_t ratio = target_dim / to_adjust_dim;
1768 if (target_dim % to_adjust_dim != 0 ||
1769 new_tile_assignment_dims.back() % ratio != 0) {
1770 return std::nullopt;
1771 }
1772 new_tile_assignment_dims[to_adjust_dims[i]] *= ratio;
1773 new_tile_assignment_dims.back() /= ratio;
1774 }
1775 CHECK_GE(new_tile_assignment_dims.back(), 1);
1776 bool to_partially_replicate = true;
1777 if (new_tile_assignment_dims.back() == 1) {
1778 new_tile_assignment_dims.pop_back();
1779 to_partially_replicate = false;
1780 }
1781 auto new_tile_assignment = to_adjust->tile_assignment();
1782 new_tile_assignment.Reshape(new_tile_assignment_dims);
1783 if (to_partially_replicate) {
1784 *to_adjust =
1785 AlignShardingOnDims(HloSharding::PartialTile(new_tile_assignment),
1786 to_adjust_dims, *target, target_dims);
1787 } else {
1788 *to_adjust = AlignShardingOnDims(HloSharding::Tile(new_tile_assignment),
1789 to_adjust_dims, *target, target_dims);
1790 }
1791 }
1792 // Make sure that the parallel dimensions are aligned.
1793 auto operand_shard_tile_dims =
1794 new_operand_shard.tile_assignment().dimensions();
1795 for (int i = 0; i < indices_parallel_dims_ordered_as_operand.size(); ++i) {
1796 operand_shard_tile_dims[operand_parallel_dims[i]] =
1797 new_index_shard.tile_assignment().dim(
1798 indices_parallel_dims_ordered_as_operand[i]);
1799 }
1800 auto operand_shard_tiles = new_operand_shard.tile_assignment();
1801 operand_shard_tiles.Reshape(operand_shard_tile_dims);
1802 new_operand_shard =
1803 AlignShardingOnDims(new_operand_shard.ReplicateOnLastTileDim()
1804 ? HloSharding::PartialTile(operand_shard_tiles)
1805 : HloSharding::Tile(operand_shard_tiles),
1806 operand_parallel_dims, new_index_shard,
1807 indices_parallel_dims_ordered_as_operand);
1808 return GatherParallelDimSharding{new_index_shard, new_operand_shard};
1809 }
1810
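// Recognizes concat(slice(x), slice(x)) along the concatenate dimension as a
// rotate-right of x and returns the rotation amount (the lhs slice size), or
// -1 if the pattern does not match. Example: on an 8-element dimension,
// concat(x[2:8], x[0:2]) is a rotate-right by 6.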
1811 int64_t FindRotateRightPattern(const HloInstruction* concat,
1812 const HloInstruction* lhs,
1813 const HloInstruction* rhs) {
1814 if (lhs->opcode() != HloOpcode::kSlice ||
1815 rhs->opcode() != HloOpcode::kSlice ||
1816 lhs->operand(0) != rhs->operand(0)) {
1817 return -1;
1818 }
1819 const HloInstruction* to_rotate = lhs->operand(0);
1820 if (!ShapeUtil::Compatible(to_rotate->shape(), concat->shape()) ||
1821 concat->sharding() != to_rotate->sharding()) {
1822 return -1;
1823 }
1824 const int64_t dim = concat->concatenate_dimension();
1825 if (lhs->slice_strides(dim) != 1 || rhs->slice_strides(dim) != 1 ||
1826 lhs->slice_starts(dim) != rhs->slice_limits(dim)) {
1827 return -1;
1828 }
1829 return lhs->shape().dimensions(dim);
1830 }
1831
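// Matches concat(slice(mid), mid, slice(mid)) along the concatenate dimension,
// where the outer slices may be wrapped in elementwise unary ops; this is a
// pad-with-wrap of mid. Records the lhs/rhs slice starts and the skipped
// modifier instructions, or returns std::nullopt if the pattern is absent.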
1832 std::optional<PadWithWrapPattern> FindPadWithWrapPattern(
1833 const HloInstruction* concat, const HloInstruction* lhs,
1834 const HloInstruction* mid, const HloInstruction* rhs) {
1835 if (!lhs || !mid || !rhs) {
1836 return std::nullopt;
1837 }
1838
1839   // Skip elementwise unary operations applied to inst, returning the list of
1840   // skipped operations (copies excluded) and the instruction they were applied to.
1841 auto skip_elementwise_ops = [&](const HloInstruction* inst) {
1842 std::vector<const HloInstruction*> modifiers;
1843 while (inst->IsElementwise() && inst->operand_count() == 1 &&
1844 inst->user_count() == 1) {
1845 if (inst->opcode() != HloOpcode::kCopy) {
1846 modifiers.push_back(inst);
1847 }
1848 inst = inst->operand(0);
1849 }
1850 return std::make_pair(modifiers, inst);
1851 };
1852
1853 PadWithWrapPattern pad_pattern;
1854 auto skip_result = skip_elementwise_ops(lhs);
1855 pad_pattern.lhs_modifiers = std::move(skip_result.first);
1856 lhs = skip_result.second;
1857
1858 skip_result = skip_elementwise_ops(rhs);
1859 pad_pattern.rhs_modifiers = std::move(skip_result.first);
1860 rhs = skip_result.second;
1861
1862 const int64_t dim = concat->concatenate_dimension();
1863 if (lhs->opcode() != HloOpcode::kSlice ||
1864 rhs->opcode() != HloOpcode::kSlice || lhs->operand(0) != mid ||
1865 rhs->operand(0) != mid || lhs->slice_strides(dim) != 1 ||
1866 rhs->slice_strides(dim) != 1 || lhs->sharding() != mid->sharding() ||
1867 rhs->sharding() != mid->sharding() ||
1868 lhs->sharding() != concat->sharding()) {
1869 return std::nullopt;
1870 }
1871 pad_pattern.lhs_slice_start = lhs->slice_starts(dim);
1872 pad_pattern.rhs_slice_start = rhs->slice_starts(dim);
1873 return pad_pattern;
1874 }
1875
1876 } // namespace spmd
1877 } // namespace xla
1878