1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
17
18 #include <algorithm>
19 #include <iostream>
20 #include <iterator>
21 #include <map>
22 #include <memory>
23 #include <optional>
24 #include <set>
25 #include <string>
26 #include <utility>
27 #include <vector>
28
29 #include "absl/algorithm/container.h"
30 #include "absl/container/flat_hash_set.h"
31 #include "tensorflow/compiler/xla/array.h"
32 #include "tensorflow/compiler/xla/literal_util.h"
33 #include "tensorflow/compiler/xla/protobuf_util.h"
34 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
35 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
36 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
37 #include "tensorflow/compiler/xla/util.h"
38 #include "tensorflow/compiler/xla/xla_data.pb.h"
39
40 namespace xla {
41 namespace hlo_sharding_util {
42
43 bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) {
44 CHECK_EQ(lhs.IsTuple(), rhs.IsTuple()) << lhs << " <> " << rhs;
45 if (lhs.IsTuple()) {
46 // For tuples we consider lhs to have a better sharding if none of the
47 // elements are worse and at least one element is better than in the rhs
48 // sharding.
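// Illustrative example (not exhaustive): lhs = {replicated, devices=[2]0,1}
// vs. rhs = {replicated, replicated} -> lhs is more specific, since element 0
// is no worse and element 1 is strictly better.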
49 const auto& lhs_shardings = lhs.tuple_elements();
50 const auto& rhs_shardings = rhs.tuple_elements();
51 CHECK_EQ(lhs_shardings.size(), rhs_shardings.size());
52 bool is_better = false;
53 for (int64_t i = 0; i < lhs_shardings.size(); ++i) {
54 if (IsShardingMoreSpecific(rhs_shardings[i], lhs_shardings[i])) {
55 return false;
56 }
57 if (IsShardingMoreSpecific(lhs_shardings[i], rhs_shardings[i])) {
58 is_better = true;
59 }
60 }
61 return is_better;
62 }
63 if (lhs.IsManual() || rhs.IsManual()) {
64 return false;
65 }
66 if (!rhs.IsTileMaximal()) {
67 return lhs.NumTiles() > rhs.NumTiles();
68 } else if (!rhs.IsReplicated()) {
69 // If we are not replicated then only tiled (not tile maximal) shardings
70 // can improve us.
71 return !lhs.IsTileMaximal();
72 } else {
73 // If we are replicated then any non-replicated sharding can improve us.
74 return !lhs.IsReplicated();
75 }
76 }
77
78 bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
79 bool may_combine_partial_sharding) {
80 if (old.IsTuple()) {
81 CHECK(to_merge->IsTuple());
82 bool changed = false;
83 for (int64_t i = 0; i < old.tuple_elements().size(); ++i) {
84 changed |=
85 MergeSharding(old.tuple_elements()[i], &to_merge->tuple_elements()[i],
86 may_combine_partial_sharding);
87 }
88 return changed;
89 }
90 if (!may_combine_partial_sharding || !old.HasPartialReplication() ||
91 !to_merge->HasPartialReplication() ||
92 old.tile_assignment().num_elements() !=
93 to_merge->tile_assignment().num_elements()) {
94 return IsShardingMoreSpecific(*to_merge, old);
95 }
96
97 if (MergeShardingIfCompatible(
98 old,
99 /*minimum_tiles=*/std::max(old.NumTiles(), to_merge->NumTiles()) + 1,
100 to_merge)) {
101 return true;
102 }
103 return IsShardingMoreSpecific(*to_merge, old);
104 }
105
106 bool MergeShardingIfCompatible(const HloSharding& to_merge,
107 int64_t minimum_tiles, HloSharding* dst) {
108 if (to_merge.IsTileMaximal()) {
109 return false;
110 }
111 if (dst->IsTileMaximal()) {
112 *dst = to_merge;
113 return true;
114 }
115 if (!dst->HasPartialReplication()) {
116 return false;
117 }
118 // Combine the tile dimension sizes from dst and to_merge.
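// Illustrative example: dst tile dims [2,1] and to_merge tile dims [1,4]
// combine into [2,4]; a conflict such as [2,1] vs. [3,1] is rejected below.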
119 int64_t num_devices = to_merge.tile_assignment().num_elements();
120 std::vector<int64_t> merged_tile_dims;
121 merged_tile_dims.reserve(dst->tile_assignment().num_dimensions());
122 for (int64_t i = 0; i < to_merge.TiledDataRank(); ++i) {
123 int64_t dst_dim = dst->tile_assignment().dim(i);
124 int64_t merge_dim = to_merge.tile_assignment().dim(i);
125 if (dst_dim == 1) {
126 merged_tile_dims.push_back(merge_dim);
127 } else if (merge_dim == 1) {
128 merged_tile_dims.push_back(dst_dim);
129 } else if (dst_dim == merge_dim) {
130 merged_tile_dims.push_back(dst_dim);
131 } else {
132 return false;
133 }
134 }
135 const int64_t num_tiles = Product(merged_tile_dims);
136 if (num_devices % num_tiles != 0 || num_tiles < minimum_tiles) {
137 return false;
138 }
139 int64_t to_merge_man_dim = to_merge.SubgroupManualDim();
140 int64_t dst_man_dim = dst->SubgroupManualDim();
141 if (to_merge_man_dim >= 0) {
142 if (dst_man_dim < 0) {
143 return false;
144 }
145 int64_t man_group_size = to_merge.tile_assignment().dim(to_merge_man_dim);
146 if (man_group_size != dst->tile_assignment().dim(dst_man_dim)) {
147 return false;
148 }
149 merged_tile_dims.push_back(man_group_size);
150 }
151 int64_t replication = num_devices / Product(merged_tile_dims);
152 merged_tile_dims.push_back(replication);
153 Array<int64_t> merged_tile(merged_tile_dims);
154 // Maps from replication group ID to sorted members.
155 absl::flat_hash_map<int64_t, std::set<int64_t>> merge_group_members;
156 absl::flat_hash_map<int64_t, std::set<int64_t>> dst_group_members;
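// get_group_index linearizes the data (and manual) tile coordinates into a
// replication-group id. Illustrative example: with data tile dims [2,3] and
// no manual dim, tile_indices {1,2} maps to group 1 * 3 + 2 = 5.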
157 auto get_group_index = [&](absl::Span<const int64_t> tile_indices,
158 const HloSharding& sharding, int64_t manual_dim) {
159 int64_t group_id = 0;
160 for (int64_t i = 0; i < to_merge.TiledDataRank(); ++i) {
161 group_id *= sharding.tile_assignment().dim(i);
162 group_id += tile_indices[i];
163 }
164 if (manual_dim >= 0) {
165 group_id *= sharding.tile_assignment().dim(manual_dim);
166 group_id += tile_indices[manual_dim];
167 }
168 return group_id;
169 };
170 to_merge.tile_assignment().Each([&](absl::Span<const int64_t> indices,
171 int64_t device) {
172 merge_group_members[get_group_index(indices, to_merge, to_merge_man_dim)]
173 .insert(device);
174 });
175 dst->tile_assignment().Each(
176 [&](absl::Span<const int64_t> indices, int64_t device) {
177 dst_group_members[get_group_index(indices, *dst, dst_man_dim)].insert(
178 device);
179 });
180 // Try to find the intersection of to_merge and dst replication groups, in
181 // order to determine the merged tile assignment.
182 Status compatible = merged_tile.EachStatus(
183 [&](absl::Span<const int64_t> indices, int64_t* device) {
184 std::vector<int64_t> to_merge_index(
185 to_merge.tile_assignment().num_dimensions());
186 std::vector<int64_t> dst_index(dst->tile_assignment().num_dimensions());
187 for (int64_t i = 0; i < to_merge.TiledDataRank(); ++i) {
188 if (to_merge.tile_assignment().dim(i) == 1) {
189 to_merge_index[i] = 0;
190 } else {
191 to_merge_index[i] = indices[i];
192 }
193 if (dst->tile_assignment().dim(i) == 1) {
194 dst_index[i] = 0;
195 } else {
196 dst_index[i] = indices[i];
197 }
198 }
199 if (to_merge_man_dim >= 0) {
200 to_merge_index[to_merge_man_dim] = indices[to_merge.TiledDataRank()];
201 dst_index[dst_man_dim] = indices[to_merge.TiledDataRank()];
202 }
203 if (to_merge.HasPartialReplication()) {
204 to_merge_index[to_merge.SubgroupReplicationDim()] = indices.back();
205 }
206 dst_index[dst->SubgroupReplicationDim()] = indices.back();
207 int64_t to_merge_group_id =
208 get_group_index(to_merge_index, to_merge, to_merge_man_dim);
209 int64_t dst_group_id = get_group_index(dst_index, *dst, dst_man_dim);
210 if (merge_group_members[to_merge_group_id].empty() ||
211 dst_group_members[dst_group_id].empty()) {
212 return InvalidArgument("Not compatible");
213 }
214
215 int64_t smallest_to_merge =
216 *merge_group_members[to_merge_group_id].begin();
217 int64_t smallest_dst = *dst_group_members[dst_group_id].begin();
218 if (smallest_to_merge < smallest_dst) {
219 if (merge_group_members[to_merge_group_id].count(smallest_dst) == 0) {
220 return InvalidArgument("Not compatible");
221 }
222 *device = smallest_dst;
223 } else {
224 if (dst_group_members[dst_group_id].count(smallest_to_merge) == 0) {
225 return InvalidArgument("Not compatible");
226 }
227 *device = smallest_to_merge;
228 }
229 merge_group_members[to_merge_group_id].erase(*device);
230 dst_group_members[dst_group_id].erase(*device);
231 return OkStatus();
232 });
233 if (!compatible.ok()) {
234 return false;
235 }
236 std::vector<OpMetadata> merged_metadata(std::move(dst->metadata()));
237 merged_metadata.reserve(merged_metadata.size() + to_merge.metadata().size());
238 const absl::flat_hash_set<OpMetadata, protobuf_util::ProtobufHashWrapper,
239 protobuf_util::ProtobufEqualsWrapper>
240 metadata_set(merged_metadata.begin(), merged_metadata.end());
241 absl::c_copy_if(to_merge.metadata(), std::back_inserter(merged_metadata),
242 [&metadata_set](const OpMetadata& data) {
243 return !ContainsKey(metadata_set, data);
244 });
245 std::vector<OpSharding::Type> subgroup_types;
246 if (to_merge_man_dim >= 0) {
247 subgroup_types.push_back(OpSharding::MANUAL);
248 }
249 subgroup_types.push_back(OpSharding::REPLICATED);
250 *dst = HloSharding::Subgroup(merged_tile, subgroup_types, merged_metadata);
251 return true;
252 }
253
254 std::optional<int64_t> SelectDominantDevice(
255 const std::map<int64_t, int64_t>& device_map, int64_t* top_count) {
256 int64_t device = 0;
257 int64_t count = 0;
258 for (auto& it : device_map) {
259 if (it.second > count) {
260 count = it.second;
261 device = it.first;
262 }
263 }
264 if (top_count != nullptr) {
265 *top_count = count;
266 }
267 return count > 0 ? std::optional<int64_t>(device) : std::optional<int64_t>();
268 }
269
270 void AssignComputationDevice(HloComputation* computation, int64_t device) {
271 VLOG(4) << "Assigning device " << device << " to " << computation->name()
272 << " computation";
273 for (HloInstruction* instruction : computation->instructions()) {
274 if (!instruction->has_sharding()) {
275 VLOG(4) << "Assigning device " << device << " to " << instruction->name();
276 instruction->set_device_sharding(device);
277 }
278 }
279 }
280
281 std::optional<int64_t> GetMostOccurringDevice(
282 absl::Span<HloInstruction* const> instructions) {
283 std::map<int64_t, int64_t> device_map;
284 for (HloInstruction* instruction : instructions) {
285 if (instruction->has_sharding()) {
286 for (auto& it : instruction->sharding().UsedDevices(nullptr)) {
287 // The UsedDevices() API returns a map<device, occurrence_count>.
288 device_map[it.first] += it.second;
289 }
290 }
291 }
292 return SelectDominantDevice(device_map, nullptr);
293 }
294
295 std::optional<int64_t> GetDominantDevice(
296 absl::Span<HloComputation* const> computations, double dominant_factor) {
297 int64_t instruction_count = 0;
298 std::map<int64_t, int64_t> device_map;
299 for (HloComputation* computation : computations) {
300 for (HloInstruction* instruction : computation->instructions()) {
301 int64_t count = 1;
302 if (instruction->has_sharding()) {
303 for (auto& it : instruction->sharding().UsedDevices(&count)) {
304 // The UsedDevices() API returns a map<device, occurrence_count>.
305 device_map[it.first] += it.second;
306 }
307 }
308 instruction_count += count;
309 }
310 }
311 int64_t count;
312 std::optional<int64_t> device = SelectDominantDevice(device_map, &count);
313 std::optional<int64_t> dominant_device;
314 if (device) {
315 double factor =
316 static_cast<double>(count) / static_cast<double>(instruction_count);
317 if (factor >= dominant_factor) {
318 dominant_device = device;
319 }
320 }
321 return dominant_device;
322 }
323
324 HloSharding TransposeSharding(const HloSharding& sharding,
325 absl::Span<const int64_t> dimensions) {
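// Illustrative example: for a sharding tiled as [2,4] over 8 devices,
// dimensions={1,0} produces a sharding tiled as [4,2], with the device
// assignment transposed to match.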
326 if (sharding.IsTileMaximal()) {
327 return sharding;
328 }
329 DimensionVector perm_dimensions(dimensions.begin(), dimensions.end());
330 // Add subgroup dims if missing.
331 if (sharding.TiledDataRank() == dimensions.size()) {
332 for (int64_t i = sharding.TiledDataRank();
333 i < sharding.tile_assignment().num_dimensions(); ++i) {
334 perm_dimensions.push_back(i);
335 }
336 } else {
337 CHECK_EQ(sharding.tile_assignment().num_dimensions(), dimensions.size());
338 }
339 Array<int64_t> tile_assignment = sharding.tile_assignment();
340 tile_assignment.TransposeDimensions(perm_dimensions);
341 if (!sharding.ReplicateOnLastTileDim()) {
342 std::vector<OpSharding::Type> subgroup_types;
343 for (int64_t i = sharding.TiledDataRank(); i < perm_dimensions.size();
344 ++i) {
345 int64_t src_i = perm_dimensions[i] - sharding.TiledDataRank();
346 subgroup_types.push_back(sharding.subgroup_types()[src_i]);
347 }
348 return HloSharding::Subgroup(tile_assignment, subgroup_types,
349 sharding.metadata());
350 } else {
351 return HloSharding::PartialTile(tile_assignment, sharding.metadata());
352 }
353 }
354
355 std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
356 const Shape& target_shape,
357 const HloSharding& sharding) {
358 if (sharding.IsTileMaximal()) {
359 return sharding;
360 }
361
362 // In case of a tiled sharding, the reshaped sharding will be valid if the
363 // reshape is composed from the following operations:
364 // * Adding or removing dimensions with size 1.
365 // * Merging consecutive dimensions where only the most major is sharded.
366 // * Splitting a dimension to consecutive dimensions.
367 // * Any reshaping of unsharded dimensions.
368 // Note that merge and split can happen consecutively on the same dimension,
369 // e.g., for f32[1024,256,1024] to f32[128,2048,1024], the 1024 can be viewed
370 // as split into 128 and 8, with the 8 then merged into 256. We use stacks
371 // to make supporting such cases easy.
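// Illustrative example: an f32[8,16] sharded as [4,1] reshaped to f32[128]
// keeps the split on the (major) merged dimension and yields a [4] tiling.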
372 const Shape tile_shape = sharding.TileShape(source_shape);
373 std::vector<int64_t> target_tile_assignment_dimensions;
374 std::vector<int64_t> source_dims_stack(source_shape.rank());
375 std::vector<int64_t> target_dims_stack(target_shape.rank());
376 std::vector<int64_t> sharding_tile_dims_stack(source_shape.rank());
377 int64_t added_to_partially_replicated = 1;
378 for (int64_t i = 0; i < source_shape.rank(); ++i) {
379 source_dims_stack[i] = source_shape.dimensions(source_shape.rank() - 1 - i);
380 sharding_tile_dims_stack[i] =
381 sharding.tile_assignment().dim(source_shape.rank() - 1 - i);
382 }
383 for (int64_t i = 0; i < target_shape.rank(); ++i) {
384 target_dims_stack[i] = target_shape.dimensions(target_shape.rank() - 1 - i);
385 }
386 while (!source_dims_stack.empty() || !target_dims_stack.empty()) {
387 if (target_dims_stack.empty()) {
388 if (Product(sharding_tile_dims_stack) != 1) {
389 return std::nullopt;
390 }
391 break;
392 }
393 int64_t s_size = 1;
394 int64_t t_size = 1;
395 int64_t s_partitions = 1;
396 if (!source_dims_stack.empty()) {
397 s_size = source_dims_stack.back();
398 source_dims_stack.pop_back();
399 s_partitions = sharding_tile_dims_stack.back();
400 sharding_tile_dims_stack.pop_back();
401 }
402 t_size = target_dims_stack.back();
403 target_dims_stack.pop_back();
404 if (s_partitions * Product(sharding_tile_dims_stack) == 1) {
405 // No more partitions left.
406 target_tile_assignment_dimensions.push_back(1);
407 continue;
408 }
409 if (s_size == t_size) {
410 // Same dimension.
411 target_tile_assignment_dimensions.push_back(s_partitions);
412 } else if (t_size == 1) {
413 // Trivial dimension added.
414 target_tile_assignment_dimensions.push_back(1);
415 source_dims_stack.push_back(s_size);
416 sharding_tile_dims_stack.push_back(s_partitions);
417 } else if (s_size == 1) {
418 // Trivial dimension removed.
419 if (s_partitions != 1) {
420 added_to_partially_replicated *= s_partitions;
421 }
422 target_dims_stack.push_back(t_size);
423 } else if (s_size > t_size) {
424 // Dimension split.
425 if (s_size % t_size != 0 || s_size % s_partitions != 0) {
426 return std::nullopt;
427 }
428 if (t_size % s_partitions == 0) {
429 target_tile_assignment_dimensions.push_back(s_partitions);
430 // We have part of the s_size unprocessed, so put it back on the stack.
431 source_dims_stack.push_back(s_size / t_size);
432 sharding_tile_dims_stack.push_back(1);
433 } else if (s_partitions % t_size == 0) {
434 target_tile_assignment_dimensions.push_back(t_size);
435 // We have part of the s_size unprocessed, so put it back on the stack.
436 source_dims_stack.push_back(s_size / t_size);
437 sharding_tile_dims_stack.push_back(s_partitions / t_size);
438 } else {
439 return std::nullopt;
440 }
441 } else {
442 // Dimension merge. Also merge the source dimension with the next, and
443 // process it next time.
444 if (s_size % s_partitions != 0) {
445 return std::nullopt;
446 }
447 CHECK(!source_dims_stack.empty());
448 if (sharding_tile_dims_stack.back() != 1 && s_size != s_partitions) {
449 // If the next dimension to combine is sharded, we require the current
450 // dimension's shard size to be 1. Otherwise, the new shard would be
451 // non-contiguous.
452 return std::nullopt;
453 }
454 source_dims_stack.back() *= s_size;
455 sharding_tile_dims_stack.back() *= s_partitions;
456 target_dims_stack.push_back(t_size);
457 }
458 }
459 if (Product(target_tile_assignment_dimensions) == 1) {
460 return std::nullopt;
461 }
462 Array<int64_t> new_tile_assignment = sharding.tile_assignment();
463 for (int64_t i = sharding.TiledDataRank();
464 i < sharding.tile_assignment().num_dimensions(); ++i) {
465 target_tile_assignment_dimensions.push_back(
466 sharding.tile_assignment().dim(i));
467 }
468
469 auto subgroup_types = sharding.subgroup_types();
470 // If we added dimensions to the partially replicated dimension then add the
471 // additional dimension on the partially replicated tiling.
472 if (added_to_partially_replicated > 1) {
473 if (sharding.ReplicateOnLastTileDim()) {
474 target_tile_assignment_dimensions.back() *= added_to_partially_replicated;
475 } else {
476 target_tile_assignment_dimensions.push_back(
477 added_to_partially_replicated);
478 }
479 }
480 // If subgroup_types doesn't already contain REPLICATED as a sharding type,
481 // then add it.
482 if ((sharding.ReplicateOnLastTileDim() ||
483 added_to_partially_replicated > 1) &&
484 (subgroup_types.empty() ||
485 subgroup_types.back() != OpSharding::REPLICATED)) {
486 subgroup_types.push_back(OpSharding::REPLICATED);
487 }
488 new_tile_assignment.Reshape(target_tile_assignment_dimensions);
489 return HloSharding::Subgroup(new_tile_assignment, subgroup_types,
490 sharding.metadata());
491 }
492
493 HloSharding ReverseSharding(const HloSharding& sharding,
494 absl::Span<const int64_t> dimensions) {
495 if (sharding.IsTileMaximal() || dimensions.empty()) {
496 return sharding;
497 }
498
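// Illustrative example: reversing dimension 0 of a [2] tiling with devices
// {0,1} yields devices {1,0}; the tile at index d is assigned the device that
// previously owned index (dim_size - 1 - d).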
499 Array<int64_t> new_tile_assignment(sharding.tile_assignment().dimensions());
500 new_tile_assignment.Each(
501 [&](absl::Span<const int64_t> indices, int64_t* device) {
502 std::vector<int64_t> original_indices(indices.begin(), indices.end());
503 for (int64_t d : dimensions) {
504 original_indices[d] =
505 new_tile_assignment.dim(d) - 1 - original_indices[d];
506 }
507 *device = sharding.tile_assignment()(original_indices);
508 });
509 return sharding.ReplicateOnLastTileDim()
510 ? HloSharding::PartialTile(new_tile_assignment,
511 sharding.metadata())
512 : HloSharding::Subgroup(new_tile_assignment,
513 sharding.subgroup_types(),
514 sharding.metadata());
515 }
516
517 HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64_t dim,
518 absl::Span<const int64_t> dims) {
519 CHECK(!sharding.IsTuple() && !sharding.IsTileMaximal());
520 CHECK_NE(absl::c_find(dims, dim), dims.end()) << "dim is not in dims";
521 // We optimize the tile assignment on the single dimension dim so as to
522 // minimize the communication among devices caused by the reshard:
523 // +---+---+ +---+---+ +-+-+-+-+
524 // | | | | 0 | | | | | |
525 // | 0 | 1 | +-------+ | | | | |
526 // | | | reshape on | 1 | reshape on | | | | |
527 // +---+---+ dim 0 => +-------+ dim 1 => |0|2|1|3|
528 // | | | | 2 | | | | | |
529 // | 2 | 3 | +-------+ | | | | |
530 // | | | | 3 | | | | | |
531 // +---+---+ +---+---+ +-+-+-+-+
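// Illustrative example matching the figure above: a [2,2] tile assignment
// {0,1,2,3} reshaped to tile only dimension 1 (with dims={0,1}) becomes a
// [1,4] tile assignment with device order {0,2,1,3}.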
532
533 std::vector<int64_t> tile_dims(sharding.tile_assignment().num_dimensions(),
534 1);
535 // Handle ignore dimensions.
536 std::vector<int64_t> ignore_sizes;
537 int64_t ignore_size = 1;
538 for (int64_t i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
539 if (absl::c_find(dims, i) == dims.end()) {
540 int64_t size = sharding.tile_assignment().dim(i);
541 ignore_sizes.push_back(size);
542 tile_dims[i] = size;
543 ignore_size *= size;
544 }
545 }
546
547 using Buckets = std::vector<std::vector<int64_t>>;
548 Array<Buckets> buckets(ignore_sizes,
549 Buckets(sharding.tile_assignment().dim(dim)));
550 sharding.tile_assignment().Each(
551 [&](absl::Span<const int64_t> index, int64_t device) {
552 std::vector<int64_t> ignore_index;
553 for (int64_t i = 0; i < index.size(); ++i) {
554 if (absl::c_find(dims, i) == dims.end()) {
555 ignore_index.push_back(index[i]);
556 }
557 }
558 buckets(ignore_index)[index[dim]].push_back(device);
559 });
560 std::vector<int64_t> devices;
561 buckets.Each([&](absl::Span<const int64_t> index, const Buckets& buckets) {
562 for (auto& bucket : buckets) {
563 devices.insert(devices.end(), bucket.begin(), bucket.end());
564 }
565 });
566 tile_dims[dim] = devices.size() / ignore_size;
567 Array<int64_t> tile_assignment(tile_dims);
568 tile_assignment.SetValues(devices);
569 return HloSharding::Tile(tile_assignment, sharding.metadata());
570 }
571
572 bool ContainsTileSharding(const HloModule& module) {
573 for (const HloComputation* computation : module.computations()) {
574 for (const HloInstruction* instruction : computation->instructions()) {
575 if (instruction->has_sharding() &&
576 !instruction->sharding().IsTileMaximal()) {
577 return true;
578 }
579 }
580 }
581 return false;
582 }
583
584 HloSharding GatherOutputSharding(const HloSharding& index_sharding,
585 const HloInstruction* hlo) {
586 if (index_sharding.IsTileMaximal()) {
587 return index_sharding;
588 }
589
590 const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
591 std::vector<int64_t> output_tile_assignment_dims;
592 const int64_t rank = hlo->shape().rank(),
593 num_dimensions =
594 index_sharding.tile_assignment().num_dimensions();
595 output_tile_assignment_dims.reserve(rank + num_dimensions);
596 for (int64_t i = 0, index_dim = 0; i < rank; ++i) {
597 if (absl::c_binary_search(dnums.offset_dims(), i)) {
598 output_tile_assignment_dims.push_back(1);
599 } else {
600 const int64_t new_tile_dimension =
601 index_dim >= dnums.index_vector_dim() ? index_dim + 1 : index_dim;
602 output_tile_assignment_dims.push_back(
603 index_sharding.tile_assignment().dim(new_tile_dimension));
604 ++index_dim;
605 }
606 }
607
608 for (int64_t i = index_sharding.TiledDataRank(); i < num_dimensions; ++i) {
609 output_tile_assignment_dims.push_back(
610 index_sharding.tile_assignment().dim(i));
611 }
612
613 Array<int64_t> new_tile_assignment = index_sharding.tile_assignment();
614 if (new_tile_assignment.num_elements() !=
615 Product(output_tile_assignment_dims)) {
616 return HloSharding::Replicate(index_sharding.metadata());
617 }
618 new_tile_assignment.Reshape(output_tile_assignment_dims);
619 return index_sharding.ReplicateOnLastTileDim()
620 ? HloSharding::PartialTile(new_tile_assignment,
621 index_sharding.metadata())
622 : HloSharding::Subgroup(new_tile_assignment,
623 index_sharding.subgroup_types(),
624 index_sharding.metadata());
625 }
626
627 HloSharding GatherIndexSharding(const HloSharding& output_sharding,
628 const HloInstruction* hlo) {
629 CHECK(hlo->opcode() == HloOpcode::kGather);
630 if (output_sharding.IsTileMaximal()) {
631 return output_sharding;
632 }
633
634 const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
635 std::vector<int64_t> index_tile_assignment_dims;
636 // Relevant output dims have shardings passed to the index.
637 std::vector<int64_t> relevant_output_dims;
638 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
639 if (!absl::c_binary_search(dnums.offset_dims(), i)) {
640 index_tile_assignment_dims.push_back(
641 output_sharding.tile_assignment().dim(i));
642 relevant_output_dims.push_back(i);
643 }
644 }
645 int64_t index_rank = hlo->operand(1)->shape().rank();
646
647 // Indices sharding on `index_vector_dim` is not supported yet.
648 if (index_rank > index_tile_assignment_dims.size()) {
649 index_tile_assignment_dims.insert(
650 index_tile_assignment_dims.begin() + dnums.index_vector_dim(), 1);
651 }
652
653 if (Product(index_tile_assignment_dims) == 1) {
654 return HloSharding::Replicate(output_sharding.metadata());
655 }
656 HloSharding relevant_output_sharding =
657 PartiallyReplicateTiledShardingOnAllDimsExcept(output_sharding,
658 relevant_output_dims);
659 if (relevant_output_sharding.IsTileMaximal()) {
660 return relevant_output_sharding;
661 }
662 for (int64_t i = relevant_output_sharding.TiledDataRank();
663 i < relevant_output_sharding.tile_assignment().num_dimensions(); ++i) {
664 index_tile_assignment_dims.push_back(
665 relevant_output_sharding.tile_assignment().dim(i));
666 }
667
668 Array<int64_t> new_tile_assignment =
669 relevant_output_sharding.tile_assignment();
670 new_tile_assignment.Reshape(index_tile_assignment_dims);
671 return relevant_output_sharding.ReplicateOnLastTileDim()
672 ? HloSharding::PartialTile(new_tile_assignment,
673 output_sharding.metadata())
674 : HloSharding::Subgroup(new_tile_assignment,
675 relevant_output_sharding.subgroup_types(),
676 output_sharding.metadata());
677 }
678
679 HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) {
680 if (hlo.sharding().IsTileMaximal()) {
681 return hlo.sharding();
682 }
683
684 const GatherDimensionNumbers& dnums = hlo.gather_dimension_numbers();
685 std::vector<int64_t> tile_assignment_dims(hlo.shape().rank());
686 int64_t num_elements = 1;
687 for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
688 if (!absl::c_binary_search(dnums.offset_dims(), i)) {
689 tile_assignment_dims[i] = hlo.sharding().tile_assignment().dim(i);
690 num_elements *= hlo.sharding().tile_assignment().dim(i);
691 } else {
692 tile_assignment_dims[i] = 1;
693 }
694 }
695 if (num_elements == hlo.sharding().tile_assignment().num_elements()) {
696 // Output sharding is only on non offset dimensions. We use output sharding
697 // to shard this gather op directly.
698 return hlo.sharding();
699 }
700
701 if (num_elements == 1) {
702 // Output sharding is only on offset dimensions. We do not shard this gather
703 // op. Return a tile maximal sharding with the first device in output
704 // sharding tile assignment.
705 return HloSharding::AssignDevice(*hlo.sharding().tile_assignment().begin(),
706 hlo.sharding().metadata());
707 }
708
709 // Output sharding is on both offset and non offset dimensions. We shard the
710 // gather op only on non offset dimensions.
711 // For example:
712 // - the gather op has sharding [2,2]{0,1,2,3},
713 // - first dimension is non offset dimension,
714 // - second dimension is offset dimension,
715 // Then the result sharding will be [2,1]{0,2}.
716 std::vector<int64_t> slice_starts(hlo.shape().rank(), 0LL),
717 slice_limits(hlo.shape().rank());
718 for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
719 if (!absl::c_binary_search(dnums.offset_dims(), i)) {
720 slice_limits[i] = hlo.sharding().tile_assignment().dim(i);
721 } else {
722 slice_limits[i] = 1;
723 }
724 }
725 Array<int64_t> tile_assignment =
726 hlo.sharding().tile_assignment().Slice(slice_starts, slice_limits);
727 return HloSharding::Tile(tile_assignment, hlo.sharding().metadata());
728 }
729
730 HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
731 const HloScatterInstruction* scatter) {
732 if (data_sharding.IsTileMaximal()) {
733 return data_sharding;
734 }
735
736 const ScatterDimensionNumbers& dnums = scatter->scatter_dimension_numbers();
737 std::vector<int64_t> index_tile_assignment_dims;
738 std::vector<int64_t> relevant_data_dims;
739 for (int64_t i = 0; i < scatter->scatter_updates()[0]->shape().rank(); ++i) {
740 if (!absl::c_binary_search(dnums.update_window_dims(), i)) {
741 index_tile_assignment_dims.push_back(
742 data_sharding.tile_assignment().dim(i));
743 relevant_data_dims.push_back(i);
744 }
745 }
746 // Indices sharding on `index_vector_dim` is not supported yet.
747 if (index_tile_assignment_dims.size() <
748 scatter->scatter_indices()->shape().rank()) {
749 index_tile_assignment_dims.insert(
750 index_tile_assignment_dims.begin() + dnums.index_vector_dim(), 1);
751 }
752 HloSharding relevant_data_sharding =
753 PartiallyReplicateTiledShardingOnAllDimsExcept(data_sharding,
754 relevant_data_dims);
755 if (relevant_data_sharding.IsTileMaximal()) {
756 return relevant_data_sharding;
757 }
758 for (int64_t i = relevant_data_sharding.TiledDataRank();
759 i < relevant_data_sharding.tile_assignment().num_dimensions(); ++i) {
760 index_tile_assignment_dims.push_back(
761 relevant_data_sharding.tile_assignment().dim(i));
762 }
763 auto new_tile_assignment = relevant_data_sharding.tile_assignment();
764 new_tile_assignment.Reshape(index_tile_assignment_dims);
765 return relevant_data_sharding.ReplicateOnLastTileDim()
766 ? HloSharding::PartialTile(new_tile_assignment,
767 data_sharding.metadata())
768 : HloSharding::Subgroup(new_tile_assignment,
769 relevant_data_sharding.subgroup_types(),
770 data_sharding.metadata());
771 }
772
773 HloSharding ScatterDataSharding(const HloSharding& index_sharding,
774 const HloScatterInstruction* scatter) {
775 if (index_sharding.IsTileMaximal()) {
776 return index_sharding;
777 }
778
779 const ScatterDimensionNumbers& dnums = scatter->scatter_dimension_numbers();
780 std::vector<int64_t> data_tile_assignment_dims;
781 std::vector<int64_t> relevant_index_dims;
782 const int64_t rank = scatter->scatter_updates()[0]->shape().rank();
783 data_tile_assignment_dims.reserve(rank);
784 for (int64_t i = 0, index_dim = 0; i < rank; ++i) {
785 if (absl::c_binary_search(dnums.update_window_dims(), i)) {
786 data_tile_assignment_dims.push_back(1);
787 } else {
788 data_tile_assignment_dims.push_back(
789 index_sharding.tile_assignment().dim(index_dim));
790 relevant_index_dims.push_back(index_dim);
791 index_dim++;
792 }
793 }
794 auto relevant_index_sharding = PartiallyReplicateTiledShardingOnAllDimsExcept(
795 index_sharding, relevant_index_dims);
796 if (relevant_index_sharding.IsTileMaximal()) {
797 return relevant_index_sharding;
798 }
799 for (int64_t i = relevant_index_sharding.TiledDataRank();
800 i < relevant_index_sharding.tile_assignment().num_dimensions(); ++i) {
801 data_tile_assignment_dims.push_back(
802 relevant_index_sharding.tile_assignment().dim(i));
803 }
804 Array<int64_t> new_tile_assignment =
805 relevant_index_sharding.tile_assignment();
806 new_tile_assignment.Reshape(data_tile_assignment_dims);
807 return relevant_index_sharding.ReplicateOnLastTileDim()
808 ? HloSharding::PartialTile(new_tile_assignment,
809 index_sharding.metadata())
810 : HloSharding::Subgroup(new_tile_assignment,
811 relevant_index_sharding.subgroup_types(),
812 index_sharding.metadata());
813 }
814
815 HloSharding ScatterEffectiveIndexSharding(
816 const HloSharding& index_sharding, const HloScatterInstruction& scatter) {
817 if (index_sharding.IsTileMaximal()) {
818 return index_sharding;
819 }
820
821 // Only shard on first "number of scatter_window_dims" dimensions.
822 const ScatterDimensionNumbers& dnums = scatter.scatter_dimension_numbers();
823 int64_t num_elements = 1;
824 int64_t index_dim = 0;
825 for (int64_t i = 0; i < scatter.shape().rank(); ++i) {
826 if (absl::c_binary_search(dnums.inserted_window_dims(), i)) {
827 num_elements *= index_sharding.tile_assignment().dim(index_dim);
828 index_dim++;
829 }
830 }
831 if (num_elements == index_sharding.tile_assignment().num_elements()) {
832 // Index sharding is only on scatter_window_dims. We use this index sharding
833 // directly.
834 return index_sharding;
835 }
836
837 // Index sharding is only on update_window_dims. We do not shard this scatter
838 // op. Return a tile maximal sharding with the first device in index sharding
839 // tile assignment.
840 if (num_elements == 1) {
841 return HloSharding::AssignDevice(*index_sharding.tile_assignment().begin(),
842 index_sharding.metadata());
843 }
844
845 const int64_t index_rank = scatter.scatter_indices()->shape().rank();
846 std::vector<int64_t> slice_starts(index_rank, 0LL), slice_limits(index_rank);
847 for (int64_t i = 0; i < index_rank; ++i) {
848 if (i < index_dim) {
849 slice_limits[i] = index_sharding.tile_assignment().dim(i);
850 } else {
851 slice_limits[i] = 1;
852 }
853 }
854 Array<int64_t> tile_assignment =
855 index_sharding.tile_assignment().Slice(slice_starts, slice_limits);
856 return HloSharding::Tile(tile_assignment, index_sharding.metadata());
857 }
858
859 HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
860 const HloScatterInstruction& scatter) {
861 if (data_sharding.IsTileMaximal()) {
862 return data_sharding;
863 }
864
865 const ScatterDimensionNumbers& dnums = scatter.scatter_dimension_numbers();
866 const int64_t data_rank = scatter.scatter_updates()[0]->shape().rank();
867 std::vector<int64_t> tile_assignment_dims(data_rank, 1LL);
868 int64_t num_elements = 1;
869 for (int64_t i = 0; i < scatter.shape().rank(); ++i) {
870 if (absl::c_binary_search(dnums.inserted_window_dims(), i)) {
871 CHECK_LT(i, data_rank);
872 tile_assignment_dims[i] = data_sharding.tile_assignment().dim(i);
873 num_elements *= data_sharding.tile_assignment().dim(i);
874 }
875 }
876 if (num_elements == data_sharding.tile_assignment().num_elements()) {
877 // Data sharding is only on scatter_window_dims. We use this data sharding
878 // directly.
879 return data_sharding;
880 }
881
882 if (num_elements == 1) {
883 // Data sharding is only on update_window_dims. We do not shard this
884 // scatter op. Return a tile maximal sharding with the first device in
885 // data sharding tile assignment.
886 return HloSharding::AssignDevice(*data_sharding.tile_assignment().begin(),
887 data_sharding.metadata());
888 }
889
890 // Data sharding is on both update_window_dims and scatter_window_dims. We
891 // shard the scatter op only on scatter_window_dims. For example:
892 // - the scatter data has sharding [2,2]{0,1,2,3},
893 // - first dimension is scatter_window_dims,
894 // - second dimension is update_window_dims,
895 // Then the result sharding will be [2,1]{0,2}.
896 std::vector<int64_t> slice_starts(data_rank, 0LL);
897 Array<int64_t> tile_assignment =
898 data_sharding.tile_assignment().Slice(slice_starts, tile_assignment_dims);
899 return HloSharding::Tile(tile_assignment, data_sharding.metadata());
900 }
901
902 namespace {
903
904 // If the operand is partitioned only along passthrough dimensions (offset
905 // dimensions in the gather output (or scatter update) that have the same size
906 // as in the operand), returns the corresponding output (or update) sharding by
907 // passing through the input sharding.
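// Illustrative example: gathering rows of an f32[10,20] operand sharded [1,2]
// with collapsed_or_inserted_dims={0}, offset_or_window_dims={1}, and
// slice_size={1,20} passes the column partitioning through, producing an
// output (or update) sharding tiled [1,2] on the offset dimension.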
908 std::optional<HloSharding> PassthroughOperandToGatherOutputOrScatterUpdate(
909 const Shape& operand_shape, const HloSharding& operand_sharding,
910 const Shape& update_or_gather_shape,
911 absl::Span<const int64_t> collapsed_or_inserted_dims,
912 absl::Span<const int64_t> index_map,
913 absl::Span<const int64_t> offset_or_window_dims,
914 absl::Span<const int64_t> slice_size) {
915 if (operand_sharding.IsTileMaximal()) {
916 return operand_sharding;
917 }
918 std::vector<int64_t> passthrough_tile(update_or_gather_shape.rank(), 1);
919 int64_t collapsed = 0;
920 for (int64_t i = 0; i < operand_shape.rank(); ++i) {
921 int64_t dim_partitions = operand_sharding.tile_assignment().dim(i);
922 if (absl::c_linear_search(collapsed_or_inserted_dims, i)) {
923 if (dim_partitions > 1) {
924 return std::nullopt;
925 }
926 collapsed++;
927 continue;
928 }
929 if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) {
930 return std::nullopt;
931 }
932 int64_t offset_dim = offset_or_window_dims[i - collapsed];
933 if (i - collapsed > 0 &&
934 offset_dim < offset_or_window_dims[i - collapsed - 1]) {
935 // Output offsets are transposed; we do not support this case.
936 return std::nullopt;
937 }
938 passthrough_tile[offset_dim] = dim_partitions;
939 }
940 for (int64_t i = operand_sharding.TiledDataRank();
941 i < operand_sharding.tile_assignment().num_dimensions(); ++i) {
942 passthrough_tile.push_back(operand_sharding.tile_assignment().dim(i));
943 }
944 Array<int64_t> tile_assignment = operand_sharding.tile_assignment();
945 tile_assignment.Reshape(passthrough_tile);
946 return operand_sharding.ReplicateOnLastTileDim()
947 ? HloSharding::PartialTile(tile_assignment,
948 operand_sharding.metadata())
949 : HloSharding::Subgroup(tile_assignment,
950 operand_sharding.subgroup_types(),
951 operand_sharding.metadata());
952 }
953
954 // Inverse of PassthroughOperandToGatherOutputOrScatterUpdate.
955 std::optional<HloSharding> PassthroughGatherOutputOrScatterUpdateToOperand(
956 const Shape& operand_shape, const HloSharding& update_or_gather_sharding,
957 absl::Span<const int64_t> collapsed_or_inserted_dims,
958 absl::Span<const int64_t> index_map,
959 absl::Span<const int64_t> offset_or_window_dims,
960 absl::Span<const int64_t> slice_size) {
961 if (update_or_gather_sharding.IsTileMaximal()) {
962 return update_or_gather_sharding;
963 }
964 std::vector<int64_t> passthrough_tile(operand_shape.rank(), 1);
965 int64_t collapsed = 0;
966 // Relevant dims have shardings passed to the operand.
967 std::vector<int64_t> relevant_update_or_gather_dims;
968 for (int64_t i = 0; i < operand_shape.rank(); ++i) {
969 if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
970 absl::c_linear_search(index_map, i)) {
971 collapsed++;
972 continue;
973 }
974 int64_t offset_dim = offset_or_window_dims[i - collapsed];
975 int64_t dim_partitions =
976 update_or_gather_sharding.tile_assignment().dim(offset_dim);
977 if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) {
978 return std::nullopt;
979 }
980 if (i - collapsed > 0 &&
981 offset_dim < offset_or_window_dims[i - collapsed - 1]) {
982 // Output offsets are transposed; we do not support this case.
983 return std::nullopt;
984 }
985 relevant_update_or_gather_dims.push_back(offset_dim);
986 passthrough_tile[i] = dim_partitions;
987 }
988
989 HloSharding relevant_sharding =
990 PartiallyReplicateTiledShardingOnAllDimsExcept(
991 update_or_gather_sharding, relevant_update_or_gather_dims);
992 if (relevant_sharding.IsTileMaximal()) {
993 return relevant_sharding;
994 }
995 for (int64_t i = relevant_sharding.TiledDataRank();
996 i < relevant_sharding.tile_assignment().num_dimensions(); ++i) {
997 passthrough_tile.push_back(relevant_sharding.tile_assignment().dim(i));
998 }
999 Array<int64_t> tile_assignment = relevant_sharding.tile_assignment();
1000 tile_assignment.Reshape(passthrough_tile);
1001 return relevant_sharding.ReplicateOnLastTileDim()
1002 ? HloSharding::PartialTile(tile_assignment,
1003 update_or_gather_sharding.metadata())
1004 : HloSharding::Subgroup(tile_assignment,
1005 relevant_sharding.subgroup_types(),
1006 update_or_gather_sharding.metadata());
1007 }
1008
1009 // Collect data operand sharding for a gather with parallel dimensions from
1010 // the output.
1011 std::optional<HloSharding> GatherParallelDataOperandSharding(
1012 const HloSharding& output_sharding, const HloInstruction& gather,
1013 const GatherParallelDims& parallel_dims) {
1014 if (output_sharding.IsTileMaximal()) {
1015 return output_sharding;
1016 }
1017 auto output_parallel_dims = GatherParallelOutputDims(gather, parallel_dims);
1018 auto output_aligned_operand_parallel_dims =
1019 GatherOutputAlignedOperandParallelDims(gather, parallel_dims);
1020 const Shape gather_shape = gather.shape();
1021 CHECK_EQ(output_parallel_dims.size(),
1022 output_aligned_operand_parallel_dims.size());
1023 std::vector<int64_t> operand_tile_assignment(
1024 gather.operand(0)->shape().rank(), 1);
1025 std::vector<int64_t> relevant_output_dims;
1026 for (int i = 0, parallel_idx = 0; i < gather_shape.rank(); ++i) {
1027 if (parallel_idx >= output_parallel_dims.size() ||
1028 output_parallel_dims[parallel_idx] != i) {
1029 continue;
1030 }
1031 const int64_t operand_dim =
1032 output_aligned_operand_parallel_dims[parallel_idx++];
1033 operand_tile_assignment[operand_dim] =
1034 output_sharding.tile_assignment().dim(i);
1035 relevant_output_dims.push_back(i);
1036 }
1037 HloSharding relevant_output_sharding =
1038 PartiallyReplicateTiledShardingOnAllDimsExcept(output_sharding,
1039 relevant_output_dims);
1040 if (relevant_output_sharding.IsTileMaximal()) {
1041 return std::move(relevant_output_sharding);
1042 }
1043
1044 for (int64_t i = relevant_output_sharding.TiledDataRank();
1045 i < relevant_output_sharding.tile_assignment().num_dimensions(); ++i) {
1046 operand_tile_assignment.push_back(
1047 relevant_output_sharding.tile_assignment().dim(i));
1048 }
1049 Array<int64_t> tile_assignment = relevant_output_sharding.tile_assignment();
1050 tile_assignment.Reshape(operand_tile_assignment);
1051 return relevant_output_sharding.ReplicateOnLastTileDim()
1052 ? HloSharding::PartialTile(tile_assignment,
1053 output_sharding.metadata())
1054 : HloSharding::Subgroup(tile_assignment,
1055 relevant_output_sharding.subgroup_types(),
1056 output_sharding.metadata());
1057 }
1058
1059 } // namespace
1060
1061 std::optional<HloSharding> GatherOutputShardingFromDataOperand(
1062 const HloSharding& data_operand_sharding, const HloInstruction& hlo,
1063 absl::Span<const int64_t> slice_sizes, const Shape& output_shape,
1064 const Shape& operand_shape) {
1065 const auto& dnums = hlo.gather_dimension_numbers();
1066 std::vector<int64_t> collapsed_slice_dims(
1067 dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end());
1068 std::vector<int64_t> start_index_map(dnums.start_index_map().begin(),
1069 dnums.start_index_map().end());
1070 std::vector<int64_t> offset_dims(dnums.offset_dims().begin(),
1071 dnums.offset_dims().end());
1072 return PassthroughOperandToGatherOutputOrScatterUpdate(
1073 operand_shape, data_operand_sharding, output_shape, collapsed_slice_dims,
1074 start_index_map, offset_dims, slice_sizes);
1075 }
1076
1077 std::optional<HloSharding> GatherDataOperandShardingFromOutput(
1078 const HloSharding& output_sharding, const HloInstruction& hlo) {
1079 const auto& dnums = hlo.gather_dimension_numbers();
1080 std::vector<int64_t> collapsed_slice_dims(
1081 dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end());
1082 std::vector<int64_t> start_index_map(dnums.start_index_map().begin(),
1083 dnums.start_index_map().end());
1084 std::vector<int64_t> offset_dims(dnums.offset_dims().begin(),
1085 dnums.offset_dims().end());
1086
1087 std::optional<HloSharding> parallel_sharding;
1088 auto parallel_dims = GetGatherBatchParallelDims(hlo);
1089 if (parallel_dims) {
1090 // Prioritize the parallel sharding, since that is what the spmd_partitioner
1091 // does.
1092 parallel_sharding =
1093 GatherParallelDataOperandSharding(output_sharding, hlo, *parallel_dims);
1094 }
1095 std::optional<HloSharding> passthrough_sharding =
1096 PassthroughGatherOutputOrScatterUpdateToOperand(
1097 hlo.operand(0)->shape(), output_sharding, collapsed_slice_dims,
1098 start_index_map, offset_dims, hlo.gather_slice_sizes());
1099 // Try to merge the two shardings; if only one of the two is present, return
1100 // that one.
1101 if (!passthrough_sharding) {
1102 return parallel_sharding;
1103 }
1104 if (!parallel_sharding) {
1105 return passthrough_sharding;
1106 }
1107 if (MergeSharding(*parallel_sharding, &*passthrough_sharding,
1108 /*may_combine_partial_sharding=*/true)) {
1109 return passthrough_sharding;
1110 }
1111 if (MergeSharding(*passthrough_sharding, &*parallel_sharding,
1112 /*may_combine_partial_sharding=*/true)) {
1113 return parallel_sharding;
1114 }
1115 return parallel_sharding;
1116 }
1117
1118 std::vector<int64_t> GetScatterSliceSize(const Shape& operand_shape,
1119 const Shape& update_shape,
1120 const ScatterDimensionNumbers& dnums) {
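// Illustrative example: operand f32[10,9,8], inserted_window_dims={0} and
// update_window_dims={1,2} with an update of shape f32[N,9,8] give
// slice_size = {1, 9, 8}.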
1121 std::vector<int64_t> slice_size(operand_shape.rank(), 1);
1122 int64_t num_update_window_dims = 0;
1123 for (int64_t i = 0; i < operand_shape.rank(); ++i) {
1124 if (absl::c_linear_search(dnums.inserted_window_dims(), i)) {
1125 continue;
1126 }
1127 slice_size[i] = update_shape.dimensions(
1128 dnums.update_window_dims(num_update_window_dims++));
1129 }
1130 return slice_size;
1131 }
1132
1133 std::optional<HloSharding> ScatterOutputShardingFromUpdate(
1134 const HloSharding& update_sharding, const HloScatterInstruction& scatter) {
1135 const auto& dnums = scatter.scatter_dimension_numbers();
1136 std::vector<int64_t> inserted_window_dims(
1137 dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end());
1138 std::vector<int64_t> scatter_dims_to_operand_dims(
1139 dnums.scatter_dims_to_operand_dims().begin(),
1140 dnums.scatter_dims_to_operand_dims().end());
1141 std::vector<int64_t> update_window_dims(dnums.update_window_dims().begin(),
1142 dnums.update_window_dims().end());
1143 std::vector<int64_t> slice_size =
1144 GetScatterSliceSize(scatter.scatter_operands()[0]->shape(),
1145 scatter.scatter_updates()[0]->shape(), dnums);
1146 return PassthroughGatherOutputOrScatterUpdateToOperand(
1147 scatter.shape(), update_sharding, inserted_window_dims,
1148 scatter_dims_to_operand_dims, update_window_dims, slice_size);
1149 }
1150
1151 std::optional<HloSharding> ScatterUpdateShardingFromOutput(
1152 const HloSharding& per_output_sharding,
1153 const HloScatterInstruction& scatter) {
1154 const auto& dnums = scatter.scatter_dimension_numbers();
1155 std::vector<int64_t> inserted_window_dims(
1156 dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end());
1157 std::vector<int64_t> scatter_dims_to_operand_dims(
1158 dnums.scatter_dims_to_operand_dims().begin(),
1159 dnums.scatter_dims_to_operand_dims().end());
1160 std::vector<int64_t> update_window_dims(dnums.update_window_dims().begin(),
1161 dnums.update_window_dims().end());
1162 std::vector<int64_t> slice_size =
1163 GetScatterSliceSize(scatter.scatter_operands()[0]->shape(),
1164 scatter.scatter_updates()[0]->shape(), dnums);
1165 return PassthroughOperandToGatherOutputOrScatterUpdate(
1166 scatter.scatter_operands()[0]->shape(), per_output_sharding,
1167 scatter.scatter_updates()[0]->shape(), inserted_window_dims,
1168 scatter_dims_to_operand_dims, update_window_dims, slice_size);
1169 }
1170
1171 StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>>
1172 IdentityValueAndHloOpcodeForScatterReduceComputation(
1173 const HloScatterInstruction& scatter) {
1174 auto computation = scatter.to_apply();
1175 // We only handle computations with 2 parameters and only 1 calculation.
1176 if (computation->instruction_count() != 3) {
1177 return Status(
1178 tensorflow::error::Code::INVALID_ARGUMENT,
1179 "Expected scatter reduce computation with 2 parameters and only 1 "
1180 "calculation");
1181 }
1182
1183 auto root_instruction = computation->root_instruction();
1184 if (root_instruction->opcode() == HloOpcode::kAdd ||
1185 root_instruction->opcode() == HloOpcode::kOr) {
1186 return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::Zero(
1187 scatter.shape().element_type())),
1188 root_instruction->opcode());
1189 } else if (root_instruction->opcode() == HloOpcode::kMultiply ||
1190 root_instruction->opcode() == HloOpcode::kAnd) {
1191 return std::make_pair(HloInstruction::CreateConstant(
1192 LiteralUtil::One(scatter.shape().element_type())),
1193 root_instruction->opcode());
1194 } else if (root_instruction->opcode() == HloOpcode::kMaximum) {
1195 return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MinValue(
1196 scatter.shape().element_type())),
1197 root_instruction->opcode());
1198 } else if (root_instruction->opcode() == HloOpcode::kMinimum) {
1199 return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MaxValue(
1200 scatter.shape().element_type())),
1201 root_instruction->opcode());
1202 }
1203
1204 return Status(tensorflow::error::Code::INVALID_ARGUMENT,
1205 "Expected scatter reduce computation which is "
1206 "add/or/multiply/add/min/max");
1207 }
1208
1209 namespace {
1210
1211 void DevicesForShardingInternal(
1212 const HloSharding& sharding,
1213 const absl::flat_hash_set<int64_t>& available_devices,
1214 absl::flat_hash_set<int64_t>* used) {
1215 if (sharding.IsTuple()) {
1216 for (const auto& subsharding : sharding.tuple_elements()) {
1217 DevicesForShardingInternal(subsharding, available_devices, used);
1218 }
1219 return;
1220 }
1221
1222 if (sharding.IsReplicated()) {
1223 for (int64_t device : available_devices) {
1224 if (!HloSharding::IsReservedDevice(device)) {
1225 used->insert(device);
1226 }
1227 }
1228 return;
1229 }
1230
1231 DCHECK(std::all_of(
1232 sharding.tile_assignment().begin(), sharding.tile_assignment().end(),
1233 [&](int64_t device) { return available_devices.contains(device); }));
1234 sharding.tile_assignment().Each(
1235 [&](absl::Span<const int64_t> /*indices*/, int64_t device) {
1236 used->insert(device);
1237 });
1238 }
1239
1240 } // namespace
1241
1242 std::vector<int64_t> DevicesForSharding(
1243 const HloSharding& sharding, absl::Span<const int64_t> available_devices) {
1244 absl::flat_hash_set<int64_t> available_set;
1245 for (int64_t device : available_devices) {
1246 available_set.insert(device);
1247 }
1248 absl::flat_hash_set<int64_t> used_set;
1249 DevicesForShardingInternal(sharding, available_set, &used_set);
1250 std::vector<int64_t> devices;
1251 for (int64_t device : available_devices) {
1252 if (used_set.contains(device)) {
1253 devices.push_back(device);
1254 }
1255 }
1256 return devices;
1257 }
1258
1259 HloSharding PartiallyReplicateTiledShardingOnDims(
1260 const HloSharding& sharding, absl::Span<const int64_t> dims_to_replicate) {
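// Illustrative example: partially replicating {devices=[2,2]0,1,2,3} on
// dimension 1 gives {devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}, i.e.
// devices {0,1} and {2,3} each become a replication group.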
1261 if (sharding.IsTileMaximal()) {
1262 return sharding;
1263 }
1264 int64_t group_count = 1;
1265 std::vector<int64_t> valid_dims_to_replicate;
1266 for (int64_t dim : dims_to_replicate) {
1267 if (dim >= sharding.TiledDataRank()) {
1268 continue;
1269 }
1270 valid_dims_to_replicate.push_back(dim);
1271 group_count *= sharding.tile_assignment().dim(dim);
1272 }
1273 if (group_count == 1) {
1274 return sharding;
1275 }
1276 if (group_count == sharding.NumTiles() && sharding.subgroup_types().empty()) {
1277 return HloSharding::Replicate(sharding.metadata());
1278 }
1279 std::vector<int64_t> dim_permutation(sharding.TiledDataRank());
1280 absl::c_iota(dim_permutation, 0);
1281 absl::c_stable_sort(dim_permutation, [&](const int64_t a, const int64_t b) {
1282 return absl::c_linear_search(valid_dims_to_replicate, a) <
1283 absl::c_linear_search(valid_dims_to_replicate, b);
1284 });
1285 auto new_tile =
1286 TransposeSharding(sharding, dim_permutation).tile_assignment();
1287 std::vector<int64_t> new_tile_shape(
1288 sharding.tile_assignment().dimensions().begin(),
1289 sharding.tile_assignment().dimensions().end());
1290 for (int64_t dim : valid_dims_to_replicate) {
1291 new_tile_shape[dim] = 1;
1292 }
1293 if (sharding.ReplicateOnLastTileDim()) {
1294 new_tile_shape.back() *= group_count;
1295 new_tile.Reshape(new_tile_shape);
1296 return HloSharding::PartialTile(new_tile, sharding.metadata());
1297 } else {
1298 new_tile_shape.insert(new_tile_shape.begin() + sharding.TiledDataRank(),
1299 group_count);
1300 new_tile.Reshape(new_tile_shape);
1301 std::vector<OpSharding::Type> subgroup_types;
1302 subgroup_types.push_back(OpSharding::REPLICATED);
1303 for (OpSharding::Type type : sharding.subgroup_types()) {
1304 subgroup_types.push_back(type);
1305 }
1306 return HloSharding::Subgroup(new_tile, subgroup_types, sharding.metadata());
1307 }
1308 }
1309
1310 HloSharding PartiallyReplicateTiledShardingOnAllDimsExcept(
1311 const HloSharding& sharding, absl::Span<const int64_t> dims_to_keep) {
1312 if (sharding.IsTileMaximal()) {
1313 return sharding;
1314 }
1315 std::vector<int64_t> dims_to_replicate(sharding.TiledDataRank());
1316 absl::c_iota(dims_to_replicate, 0);
1317
1318 dims_to_replicate.erase(
1319 std::remove_if(
1320 dims_to_replicate.begin(), dims_to_replicate.end(),
1321 [&](int64_t i) { return absl::c_linear_search(dims_to_keep, i); }),
1322 dims_to_replicate.end());
1323 return PartiallyReplicateTiledShardingOnDims(sharding, dims_to_replicate);
1324 }
1325
1326 HloSharding ReplicateAllDataDims(const HloSharding& sharding,
1327 int64_t data_rank) {
1328 if (sharding.IsManual()) {
1329 return sharding;
1330 }
1331 if (sharding.subgroup_types().empty()) {
1332 return HloSharding::Replicate(sharding.metadata());
1333 }
1334 HloSharding result =
1335 PartiallyReplicateTiledShardingOnAllDimsExcept(sharding, {});
1336 if (data_rank >= 0 && data_rank != result.TiledDataRank() &&
1337 !result.IsTileMaximal()) {
1338 std::vector<int64_t> new_tile_shape(data_rank, 1);
1339 for (int64_t i = result.TiledDataRank();
1340 i < result.tile_assignment().num_dimensions(); ++i) {
1341 new_tile_shape.push_back(result.tile_assignment().dim(i));
1342 }
1343 auto tile = result.tile_assignment();
1344 tile.Reshape(new_tile_shape);
1345 result = HloSharding::Subgroup(tile, result.subgroup_types());
1346 }
1347 return result;
1348 }
1349
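// Only dimensions whose tile size is 1 may be removed. For example, removing
// dim 0 from {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} is expected to
// give {devices=[2,2]0,1,2,3 last_tile_dim_replicate}.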
1350 HloSharding RemoveShapeDimensions(const HloSharding& sharding,
1351 absl::Span<const int64_t> dims_to_remove) {
1352 if (sharding.IsTileMaximal() || dims_to_remove.empty()) {
1353 return sharding;
1354 }
1355 std::vector<int64_t> new_tile_shape;
1356 new_tile_shape.reserve(sharding.tile_assignment().num_dimensions() -
1357 dims_to_remove.size());
1358 for (int64_t i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
1359 if (absl::c_linear_search(dims_to_remove, i)) {
1360 CHECK_EQ(sharding.tile_assignment().dim(i), 1);
1361 } else {
1362 new_tile_shape.push_back(sharding.tile_assignment().dim(i));
1363 }
1364 }
1365 auto new_tile = sharding.tile_assignment();
1366 new_tile.Reshape(new_tile_shape);
1367 return sharding.ReplicateOnLastTileDim()
1368 ? HloSharding::PartialTile(new_tile, sharding.metadata())
1369 : HloSharding::Subgroup(new_tile, sharding.subgroup_types(),
1370 sharding.metadata());
1371 }
1372
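// src_to_tgt/tgt_to_src map dimensions between the source and target shapes,
// with -1 marking a dimension that has no counterpart (collapsed in the source,
// newly added in the target). A collapsed source dimension that is actually
// sharded cannot be represented in the target, so std::nullopt is returned in
// that case.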
1373 std::optional<HloSharding> TransposeShardingWithCollapsedDims(
1374 const HloSharding& source, absl::Span<int64_t const> src_to_tgt,
1375 absl::Span<int64_t const> tgt_to_src) {
1376 if (source.IsTileMaximal()) {
1377 return source;
1378 }
1379 if (src_to_tgt.size() < source.tile_assignment().num_dimensions()) {
1380 // Add missing subgroup dims.
1381 std::vector<int64_t> new_src_to_tgt(src_to_tgt.begin(), src_to_tgt.end());
1382 std::vector<int64_t> new_tgt_to_src(tgt_to_src.begin(), tgt_to_src.end());
1383 for (int64_t i = 0;
1384 i < source.tile_assignment().num_dimensions() - src_to_tgt.size();
1385 ++i) {
1386 new_src_to_tgt.push_back(tgt_to_src.size() + i);
1387 new_tgt_to_src.push_back(src_to_tgt.size() + i);
1388 }
1389 return TransposeShardingWithCollapsedDims(source, new_src_to_tgt,
1390 new_tgt_to_src);
1391 }
1392 std::vector<int64_t> tgt_dims_skipping_new(tgt_to_src.size(), -1);
1393 int64_t skipped_tgt_dims = 0;
1394 int64_t src_non_subgroup_dims =
1395 src_to_tgt.size() - source.subgroup_types().size();
1396 int64_t tgt_non_subgroup_dims =
1397 tgt_to_src.size() - source.subgroup_types().size();
1398 for (int64_t i = 0; i < tgt_to_src.size(); ++i) {
1399 if (tgt_to_src[i] < 0) {
1400 CHECK_LT(i, tgt_non_subgroup_dims)
1401 << "Sharding transpose should not remove subgroup dims.";
1402 skipped_tgt_dims++;
1403 } else {
1404 tgt_dims_skipping_new[i] = i - skipped_tgt_dims;
1405 }
1406 }
1407 int64_t skipped_src_dims = absl::c_count(src_to_tgt, -1);
1408 std::vector<int64_t> perm(src_to_tgt.size());
1409 for (int64_t i = 0; i < src_non_subgroup_dims; ++i) {
1410 if (src_to_tgt[i] < 0) {
1411 if (source.tile_assignment().dim(i) > 1) {
1412 return std::nullopt;
1413 }
1414 perm[src_non_subgroup_dims - skipped_src_dims] = i;
1415 skipped_src_dims--;
1416 } else {
1417 perm[tgt_dims_skipping_new[src_to_tgt[i]]] = i;
1418 }
1419 }
1420 skipped_src_dims = absl::c_count(src_to_tgt, -1);
1421 for (int64_t i = src_non_subgroup_dims; i < src_to_tgt.size(); ++i) {
1422 CHECK_GE(src_to_tgt[i], tgt_non_subgroup_dims)
1423 << "Sharding transpose should not move subgroup dims before data dims.";
1424 perm[src_to_tgt[i] - skipped_tgt_dims + skipped_src_dims] = i;
1425 }
1426 auto tgt_sharding = hlo_sharding_util::TransposeSharding(source, perm);
1427 auto reshape_tiles = tgt_sharding.tile_assignment();
1428 std::vector<int64_t> tgt_tiles(tgt_to_src.size(), 1);
1429 for (int64_t i = 0; i < tgt_tiles.size(); ++i) {
1430 if (tgt_to_src[i] >= 0) {
1431 int64_t dim = tgt_dims_skipping_new[i];
1432 if (i >= tgt_non_subgroup_dims) {
1433 dim += skipped_src_dims;
1434 }
1435 tgt_tiles[i] = reshape_tiles.dim(dim);
1436 }
1437 }
1438 reshape_tiles.Reshape(tgt_tiles);
1439 return source.ReplicateOnLastTileDim()
1440 ? HloSharding::PartialTile(reshape_tiles, source.metadata())
1441 : HloSharding::Subgroup(reshape_tiles, source.subgroup_types(),
1442 source.metadata());
1443 }
1444
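// Returns the dimension along which `maybe_iota` behaves like an iota, if any.
// Besides kIota itself, this recognizes S32 constants whose values equal one of
// their indices (e.g. a constant [[0,1,2],[0,1,2]] is an iota along dim 1) and
// broadcasts of such values, mapping the dimension through the broadcast.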
1445 std::optional<int64_t> GetDimensionForIota(const HloInstruction* maybe_iota) {
1446 if (auto* iota = DynCast<HloIotaInstruction>(maybe_iota)) {
1447 return iota->iota_dimension();
1448 }
1449
1450 if (maybe_iota->shape().element_type() != S32) {
1451 return std::nullopt;
1452 }
1453 if (maybe_iota->IsConstant()) {
1454 std::vector<bool> is_iota_dim(maybe_iota->shape().rank(), true);
1455 maybe_iota->literal().EachCell<int32_t>(
1456 [&](absl::Span<const int64_t> indices, int32_t val) {
1457 for (int64_t i = 0; i < indices.size(); ++i) {
1458 if (val != indices[i]) {
1459 is_iota_dim[i] = false;
1460 }
1461 }
1462 });
1463 for (int64_t i = 0; i < is_iota_dim.size(); ++i) {
1464 if (is_iota_dim[i] && maybe_iota->shape().dimensions(i) > 1) {
1465 return i;
1466 }
1467 }
1468 return std::nullopt;
1469 }
1470
1471 if (maybe_iota->opcode() == HloOpcode::kBroadcast) {
1472 auto operand_dim = GetDimensionForIota(maybe_iota->operand(0));
1473 if (operand_dim) {
1474 return maybe_iota->dimensions(*operand_dim);
1475 }
1476 return std::nullopt;
1477 }
1478 return std::nullopt;
1479 }
1480
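// Sketch of the result: for a gather whose indices concatenate an iota with
// other index components along the index vector dim, and whose slice size is 1
// on the operand dimension that iota addresses, the iota's dimension becomes an
// entry in indices_parallel_dims and the mapped start_index_map entry becomes
// the matching entry in operand_parallel_dims.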
1481 std::optional<GatherParallelDims> GetGatherBatchParallelDims(
1482 const HloInstruction& hlo) {
1483 const auto& dnums = hlo.gather_dimension_numbers();
1484 int64_t index_dim = dnums.index_vector_dim();
1485 // Try to identify whether there is a dimension in the indices that increases
1486 // monotonically as an iota along a certain dimension. That would mean the
1487 // accesses along the corresponding operand dimension indexed by this index
1488 // are parallelizable, so we can shard the operand (and the index/output)
1489 // across that dimension.
1490 // For example the pattern:
1491 // %iota.1 = iota()
1492 // %indices = concatenate(..., %iota.1, ...)
1493 // ... = gather(..., %indices)
1494 // is common for tf.reverse_sequence and would match this case.
1495 absl::InlinedVector<const HloIotaInstruction*, 4> iotas;
1496 const HloInstruction* indices = hlo.operand(1);
1497 const int num_indices = dnums.start_index_map_size();
1498 std::vector<int64_t> index_parallel_in_dim(num_indices, -1);
1499 // Handle cases where we concatenate pieces of the indices one at a time.
1500 if (indices->opcode() == HloOpcode::kConcatenate &&
1501 indices->concatenate_dimension() == index_dim) {
1502 int concatenated_dims = 0;
1503 for (int i = 0; i < indices->operand_count(); ++i) {
1504 const HloInstruction* op = indices->operand(i);
1505 const int64_t num_indices_from_element =
1506 op->shape().dimensions_size() > index_dim
1507 ? op->shape().dimensions(index_dim)
1508 : 1;
1509 if (std::optional<int64_t> maybe_iota_dim = GetDimensionForIota(op)) {
1510 if (*maybe_iota_dim != index_dim) {
1511 for (int j = 0; j < num_indices_from_element; ++j) {
1512 index_parallel_in_dim[concatenated_dims + j] = *maybe_iota_dim;
1513 }
1514 }
1515 }
1516 concatenated_dims += num_indices_from_element;
1517 }
1518 } else if (std::optional<int64_t> maybe_iota_dim =
1519 GetDimensionForIota(indices)) {
1520 if (*maybe_iota_dim != index_dim) {
1521 // This is a case of a single iota with index_dim being out of bounds.
1522 const int64_t num_indices_from_element =
1523 indices->shape().dimensions_size() > index_dim
1524 ? indices->shape().dimensions(index_dim)
1525 : 1;
1526 index_parallel_in_dim.assign(num_indices_from_element, *maybe_iota_dim);
1527 }
1528 }
1529 absl::InlinedVector<int64_t, 1> indices_parallel_dims;
1530 absl::InlinedVector<int64_t, 1> operand_parallel_dims;
1531 // Map the parallelizable dimension from the iota to the dimensions of the
1532 // output and the operand. These dimensions are interconnected, but they may
1533 // sit at different positions in the index and operand shapes, because the
1534 // position of the index dimension in the operand is determined by
1535 // start_index_map.
1536 for (int i = 0; i < index_parallel_in_dim.size(); ++i) {
1537 int index_parallel_dim = index_parallel_in_dim[i];
1538 if (index_parallel_dim == -1) {
1539 continue;
1540 }
1541 if (absl::c_linear_search(indices_parallel_dims, index_parallel_dim)) {
1542 return std::nullopt;
1543 }
1544 // Considered parallel only if the slice is of size 1 over the operand.
1545 if (hlo.gather_slice_sizes()[dnums.start_index_map(i)] == 1) {
1546 indices_parallel_dims.push_back(index_parallel_dim);
1547 operand_parallel_dims.push_back(dnums.start_index_map(i));
1548 } else {
1549 index_parallel_in_dim[i] = -1;
1550 }
1551 }
1552 absl::c_sort(indices_parallel_dims);
1553 if (!indices_parallel_dims.empty()) {
1554 return GatherParallelDims{indices_parallel_dims, operand_parallel_dims,
1555 index_parallel_in_dim};
1556 }
1557 return std::nullopt;
1558 }
1559
1560 absl::InlinedVector<int64_t, 1> GatherParallelOutputDims(
1561 const HloInstruction& gather, const GatherParallelDims& parallel_dim) {
1562 absl::InlinedVector<int64_t, 1> output_parallel_dims;
1563 auto indices_parallel_dims = parallel_dim.indices_parallel_dims;
1564 const Shape gather_shape = gather.shape();
1565 auto dnums = gather.gather_dimension_numbers();
1566 for (int i = 0, idx_dim = 0; i < gather_shape.dimensions_size(); ++i) {
1567 if (absl::c_linear_search(dnums.offset_dims(), i)) {
1568 continue;
1569 }
1570 const int index_dim =
1571 idx_dim < dnums.index_vector_dim() ? idx_dim : idx_dim + 1;
1572 if (absl::c_binary_search(indices_parallel_dims, index_dim)) {
1573 output_parallel_dims.push_back(i);
1574 }
1575 ++idx_dim;
1576 }
1577 return output_parallel_dims;
1578 }
1579
1580 absl::InlinedVector<int64_t, 1> GatherOutputAlignedOperandParallelDims(
1581 const HloInstruction& gather, const GatherParallelDims& parallel_dims) {
1582 absl::InlinedVector<int64_t, 1> operand_parallel_dim_to_output(
1583 parallel_dims.operand_parallel_dims.size(), -1);
1584 auto dnums = gather.gather_dimension_numbers();
1585 CHECK_LE(parallel_dims.indices_parallel_dims.size(),
1586 parallel_dims.operand_parallel_dims.size());
1587 for (int i = 0; i < parallel_dims.index_parallel_in_dim.size(); ++i) {
1588 // The batch dimension of the indices that corresponds to this index
1589 // dimension.
1590 const int64_t index_parallel_dim = parallel_dims.index_parallel_in_dim[i];
1591 // Skip index dimensions that are not parallel.
1592 if (index_parallel_dim == -1) {
1593 continue;
1594 }
1595 // This is small so just look linearly. Populate the operand parallel
1596 // dimensions based on the order of the index batch dims (which is the same
1597 // order as the output).
1598 for (int j = 0; j < parallel_dims.indices_parallel_dims.size(); ++j) {
1599 if (parallel_dims.indices_parallel_dims[j] == index_parallel_dim) {
1600 const int64_t operand_parallel_dim = dnums.start_index_map(i);
1601 if (operand_parallel_dim_to_output[j] == -1) {
1602 operand_parallel_dim_to_output[j] = operand_parallel_dim;
1603 }
1604 break;
1605 }
1606 }
1607 }
1608 return operand_parallel_dim_to_output;
1609 }
1610
1611 std::string GroupedSharding::ToString() const {
1612 auto result = absl::StrCat("group dims: ", absl::StrJoin(group_dims, ","), "\n");
1613 absl::StrAppend(&result, "group dim sizes: ",
1614 absl::StrJoin(group_dim_sizes, ","), "\n");
1615 absl::StrAppend(&result, "data rank: ", data_rank, "\n");
1616 absl::StrAppend(&result, "subgroup manual: ", subgroup_manual, "\n");
1617 absl::StrAppend(&result, "device groups:\n");
1618 for (auto& device_group : device_groups) {
1619 absl::StrAppend(&result, "\t", absl::StrJoin(device_group, ","), "\n");
1620 }
1621 absl::StrAppend(&result, "inner sharding: ", sharding.ToString());
1622 return result;
1623 }
1624
1625 GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
1626 absl::Span<const int64_t> group_dims,
1627 bool subgroup_manual) {
1628 std::vector<int64_t> group_dim_shards(group_dims.size(), 1);
1629 return GroupShardingOnDims(sharding, group_dims, group_dim_shards,
1630 subgroup_manual);
1631 }
1632
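// Example of the grouping below: grouping {devices=[2,2]0,1,2,3} on dim 0
// should produce device_groups {{0,1},{2,3}}, group_dim_sizes {2}, and an
// inner sharding of roughly {devices=[1,2]0,1} describing how each group
// shards its portion of the data.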
1633 GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
1634 absl::Span<const int64_t> group_dims,
1635 absl::Span<const int64_t> group_dim_shards,
1636 bool subgroup_manual) {
1637 CHECK(!sharding.IsTileMaximal());
1638 std::vector<int64_t> grouped_tiling_dims =
1639 sharding.tile_assignment().dimensions();
1640 std::vector<int64_t> group_dim_sizes(group_dims.size());
1641 for (int64_t i = 0; i < group_dims.size(); ++i) {
1642 CHECK_EQ(grouped_tiling_dims[group_dims[i]] % group_dim_shards[i], 0);
1643 group_dim_sizes[i] =
1644 grouped_tiling_dims[group_dims[i]] / group_dim_shards[i];
1645 grouped_tiling_dims[group_dims[i]] = group_dim_shards[i];
1646 }
1647
1648 std::vector<std::vector<int64_t>> device_groups(Product(group_dim_sizes));
1649 sharding.tile_assignment().Each([&](absl::Span<const int64_t> indices,
1650 int64_t device) {
1651 int64_t group_id = 0;
1652 for (int64_t i = 0; i < group_dims.size(); ++i) {
1653 group_id *=
1654 sharding.tile_assignment().dim(group_dims[i]) / group_dim_shards[i];
1655 group_id += indices[group_dims[i]] / group_dim_shards[i];
1656 }
1657 device_groups[group_id].push_back(device);
1658 });
1659 auto grouped = GroupedSharding(
1660 std::move(device_groups),
1661 std::vector<int64_t>(group_dims.begin(), group_dims.end()),
1662 std::move(group_dim_sizes), sharding.tile_assignment().num_dimensions(),
1663 HloSharding::Replicate(), subgroup_manual);
1664 if (sharding.ReplicateOnLastTileDim()) {
1665 grouped.data_rank--;
1666 }
1667 if (sharding.IsManualSubgroup()) {
1668 grouped.data_rank -= sharding.subgroup_types().size();
1669 }
1670 if (Product(grouped_tiling_dims) == 1 ||
1671 (sharding.ReplicateOnLastTileDim() &&
1672 Product(grouped_tiling_dims) == grouped_tiling_dims.back())) {
1673 return grouped;
1674 }
1675 if (sharding.IsManualSubgroup()) {
1676 int64_t tile_dimensions = sharding.tile_assignment().num_dimensions();
1677 int64_t subgroup_size = sharding.subgroup_types().size();
1678 int64_t rank = tile_dimensions - subgroup_size;
1679 int num_dims_erase = 0;
1680 for (int i = 0; i < subgroup_size; i++) {
1681 if (sharding.subgroup_types()[i] == OpSharding::MANUAL) {
1682 grouped_tiling_dims.erase(grouped_tiling_dims.begin() + i + rank -
1683 num_dims_erase);
1684 num_dims_erase++;
1685 }
1686 }
1687 }
1688 if (sharding.ReplicateOnLastTileDim() && grouped_tiling_dims.back() == 1) {
1689 grouped_tiling_dims.pop_back();
1690 }
1691 Array<int64_t> grouped_tiling(grouped_tiling_dims);
1692 grouped_tiling.FillIota(0);
1693 grouped.sharding =
1694 sharding.ReplicateOnLastTileDim() &&
1695 grouped_tiling_dims.size() ==
1696 sharding.tile_assignment().num_dimensions()
1697 ? HloSharding::PartialTile(grouped_tiling, sharding.metadata())
1698 : HloSharding::Tile(grouped_tiling, sharding.metadata());
1699 return grouped;
1700 }
1701
1702 GroupedSharding GetManualSubgroupSharding(const HloSharding& sharding) {
1703 CHECK(sharding.IsManualSubgroup());
1704 int64_t tile_dimensions = sharding.tile_assignment().num_dimensions();
1705 int64_t subgroup_size = sharding.subgroup_types().size();
1706 int64_t rank = tile_dimensions - subgroup_size;
1707 std::vector<int64_t> group_dims;
1708 bool last_tile_dim_replicate = false;
1709
1710 for (int64_t i = 0; i < subgroup_size; i++) {
1711 if (sharding.subgroup_types()[i] == OpSharding::MANUAL) {
1712 group_dims.push_back(rank + i);
1713 } else if (sharding.subgroup_types()[i] == OpSharding::REPLICATED) {
1714 last_tile_dim_replicate = true;
1715 }
1716 }
1717
1718 GroupedSharding group_sharding =
1719 GroupShardingOnDims(sharding, group_dims, /*subgroup_manual=*/true);
1720
1721 if (last_tile_dim_replicate ||
1722 group_sharding.sharding.tile_assignment().num_dimensions() > rank) {
1723 group_sharding.sharding = HloSharding::PartialTile(
1724 group_sharding.sharding.tile_assignment(), sharding.metadata());
1725 }
1726 return group_sharding;
1727 }
1728
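// UngroupSharding is intended to invert GroupShardingOnDims: e.g. ungrouping
// device_groups {{0,1},{2,3}} with group_dims {0}, group_dim_sizes {2}, and
// inner sharding {devices=[1,2]0,1} should reconstruct {devices=[2,2]0,1,2,3}.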
1729 HloSharding UngroupSharding(const GroupedSharding& grouped_sharding) {
1730 std::vector<int64_t> tiling_dims;
1731 bool partial_sharding = false;
1732 std::vector<OpSharding::Type> subgroup_types;
1733 Array<int64_t> grouped_tiling = grouped_sharding.sharding.tile_assignment();
1734 if (grouped_sharding.sharding.IsTileMaximal()) {
1735 tiling_dims = std::vector<int64_t>(grouped_sharding.data_rank, 1);
1736 if (grouped_sharding.device_groups[0].size() != 1 ||
1737 absl::c_linear_search(grouped_sharding.group_dims,
1738 tiling_dims.size())) {
1739 // This is partial sharding.
1740 tiling_dims.push_back(grouped_sharding.device_groups[0].size());
1741 partial_sharding = true;
1742 }
1743 grouped_tiling = Array<int64_t>(tiling_dims);
1744 grouped_tiling.FillIota(0);
1745 }
1746
1747 // Handles subgroup manual first.
1748 if (grouped_sharding.subgroup_manual) {
1749 partial_sharding = grouped_sharding.sharding.ReplicateOnLastTileDim() ||
1750 grouped_sharding.sharding.IsReplicated();
1751 int64_t subgroup_dim_size = grouped_sharding.group_dims.size();
1752 if (partial_sharding) {
1753 subgroup_dim_size++;
1754 }
1755 subgroup_types = std::vector<OpSharding::Type>(subgroup_dim_size,
1756 OpSharding::REPLICATED);
1757 if (!grouped_sharding.sharding.IsTileMaximal()) {
1758 tiling_dims = grouped_sharding.sharding.tile_assignment().dimensions();
1759 }
1760 for (int i = 0; i < grouped_sharding.group_dims.size(); i++) {
1761 subgroup_types[grouped_sharding.group_dims[i] -
1762 grouped_sharding.data_rank] = OpSharding::MANUAL;
1763 tiling_dims.insert(tiling_dims.begin() + grouped_sharding.group_dims[i],
1764 1);
1765 }
1766 grouped_tiling.Reshape(tiling_dims);
1767 } else if (!grouped_sharding.sharding.IsTileMaximal()) {
1768 // Handles the tiled case, restoring a partial-replication dim if it was
1769 // grouped away.
1769 partial_sharding = grouped_sharding.sharding.ReplicateOnLastTileDim();
1770 tiling_dims = grouped_sharding.sharding.tile_assignment().dimensions();
1771 if (absl::c_linear_search(grouped_sharding.group_dims,
1772 tiling_dims.size())) {
1773 tiling_dims.push_back(1);
1774 grouped_tiling.Reshape(tiling_dims);
1775 partial_sharding = true;
1776 }
1777 }
1778
1779 // Update group dim sizes.
1780 for (int64_t i = 0; i < grouped_sharding.group_dims.size(); ++i) {
1781 int64_t dim = grouped_sharding.group_dims[i];
1782 tiling_dims[dim] *= grouped_sharding.group_dim_sizes[i];
1783 }
1784 Array<int64_t> tiling(tiling_dims);
1785 grouped_tiling.Each([&](absl::Span<const int64_t> indices, int64_t device) {
1786 std::vector<int64_t> ungrouped_inds(indices.begin(), indices.end());
1787 for (int64_t g = 0; g < grouped_sharding.device_groups.size(); ++g) {
1788 int64_t remaining_group_index = g;
1789 for (int64_t i = grouped_sharding.group_dims.size() - 1; i >= 0; --i) {
1790 int64_t dim = grouped_sharding.group_dims[i];
1791 int64_t groups_in_this_dim = grouped_sharding.group_dim_sizes[i];
1792 ungrouped_inds[dim] = (remaining_group_index % groups_in_this_dim) *
1793 grouped_tiling.dim(dim) +
1794 indices[dim];
1795 remaining_group_index /= groups_in_this_dim;
1796 }
1797 tiling(ungrouped_inds) = grouped_sharding.device_groups[g][device];
1798 }
1799 });
1800
1801 if (grouped_sharding.subgroup_manual) {
1802 return HloSharding::Subgroup(tiling, subgroup_types,
1803 grouped_sharding.sharding.metadata());
1804 }
1805 return partial_sharding ? HloSharding::PartialTile(tiling)
1806 : HloSharding::Tile(tiling);
1807 }
1808
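// For example, device groups {{0,1},{2,3}} and {{2,3},{0,1}} contain the same
// groups in a different order, so they are expected to match only when
// ignore_group_order is true.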
1809 bool DeviceGroupsAreMatch(GroupedSharding& lhs, GroupedSharding& rhs,
1810 bool ignore_group_order) {
1811 if (lhs.device_groups.size() != rhs.device_groups.size()) {
1812 return false;
1813 }
1814
1815 bool matching_groups = true;
1816 absl::flat_hash_map<int64_t, int64_t> device_to_ref_group;
1817 for (int64_t g = 0; g < lhs.device_groups.size(); ++g) {
1818 for (int64_t device : lhs.device_groups[g]) {
1819 device_to_ref_group[device] = g;
1820 }
1821 }
1822 auto unique_ref_dev_group =
1823 [&](absl::Span<const int64_t> devices) -> int64_t {
1824 int64_t ref_g = -1;
1825 for (int64_t device : devices) {
1826 if (ref_g == -1) {
1827 ref_g = device_to_ref_group[device];
1828 } else if (ref_g != device_to_ref_group[device]) {
1829 return -1;
1830 }
1831 }
1832 return ref_g;
1833 };
1834 for (int64_t g = 0; g < rhs.device_groups.size(); ++g) {
1835 int64_t ref_g = unique_ref_dev_group(rhs.device_groups[g]);
1836 if (ref_g < 0 || (!ignore_group_order && g != ref_g)) {
1837 matching_groups = false;
1838 break;
1839 }
1840 }
1841
1842 return matching_groups;
1843 }
1844
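// SplitShardingDimension and MergeShardingDimension below are intended to be
// inverses of each other: splitting tile dim 0 of {devices=[4]0,1,2,3} with
// new_dim_size=2 should give a [2,2] tile assignment (roughly
// {devices=[2,2]0,2,1,3}), and merging dims 0 and 1 of that result should
// recover the original sharding.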
1845 HloSharding SplitShardingDimension(const HloSharding& sharding,
1846 int64_t dimension, int64_t new_dim_size) {
1847 CHECK_GT(sharding.TiledDataRank(), dimension);
1848 CHECK_EQ(sharding.tile_assignment().dim(dimension) % new_dim_size, 0)
1849 << "dim size " << new_dim_size;
1850 auto new_tile_assignment = sharding.tile_assignment();
1851 std::vector<int64_t> dimensions = new_tile_assignment.dimensions();
1852 int64_t current_dimension = dimensions[dimension];
1853 dimensions.insert(dimensions.begin() + dimension + 1,
1854 current_dimension / new_dim_size);
1855 dimensions[dimension] = new_dim_size;
1856 new_tile_assignment.Reshape(dimensions);
1857 auto new_sharding = sharding.ReplicateOnLastTileDim()
1858 ? HloSharding::PartialTile(new_tile_assignment)
1859 : HloSharding::Subgroup(new_tile_assignment,
1860 sharding.subgroup_types());
1861 std::vector<int64_t> permutation(new_sharding.tile_assignment().dimensions());
1862 absl::c_iota(permutation, 0);
1863 std::swap(permutation[dimension], permutation[dimension + 1]);
1864 return TransposeSharding(new_sharding, permutation);
1865 }
1866
1867 HloSharding MergeShardingDimension(const HloSharding& sharding,
1868 int64_t dimension) {
1869 CHECK_GT(sharding.TiledDataRank(), dimension);
1870 std::vector<int64_t> permutation(sharding.tile_assignment().dimensions());
1871 absl::c_iota(permutation, 0);
1872 std::swap(permutation[dimension], permutation[dimension + 1]);
1873 auto transposed_sharding = TransposeSharding(sharding, permutation);
1874 auto new_tile_assignment = transposed_sharding.tile_assignment();
1875 std::vector<int64_t> dimensions = new_tile_assignment.dimensions();
1876 dimensions[dimension] *= dimensions[dimension + 1];
1877 dimensions.erase(dimensions.begin() + dimension + 1);
1878 new_tile_assignment.Reshape(dimensions);
1879 return sharding.ReplicateOnLastTileDim()
1880 ? HloSharding::PartialTile(new_tile_assignment)
1881 : HloSharding::Subgroup(new_tile_assignment,
1882 sharding.subgroup_types());
1883 }
1884
1885 } // namespace hlo_sharding_util
1886 } // namespace xla
1887