1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
17 
18 #include <algorithm>
19 #include <iostream>
20 #include <iterator>
21 #include <map>
22 #include <memory>
23 #include <optional>
24 #include <set>
25 #include <string>
26 #include <utility>
27 #include <vector>
28 
29 #include "absl/algorithm/container.h"
30 #include "absl/container/flat_hash_set.h"
31 #include "tensorflow/compiler/xla/array.h"
32 #include "tensorflow/compiler/xla/literal_util.h"
33 #include "tensorflow/compiler/xla/protobuf_util.h"
34 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
35 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
36 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
37 #include "tensorflow/compiler/xla/util.h"
38 #include "tensorflow/compiler/xla/xla_data.pb.h"
39 
40 namespace xla {
41 namespace hlo_sharding_util {
42 
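// The ordering implemented below, in short: between two tiled shardings the
// one with more tiles is more specific; a tiled sharding is more specific than
// a tile-maximal one; a non-replicated sharding is more specific than a
// replicated one; manual shardings are never considered more specific.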
43 bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) {
44   CHECK_EQ(lhs.IsTuple(), rhs.IsTuple()) << lhs << " <> " << rhs;
45   if (lhs.IsTuple()) {
46     // For tuples we consider lhs to have a better sharding if none of the
47     // elements are worse and at least one element is better than in the rhs
48     // sharding.
49     const auto& lhs_shardings = lhs.tuple_elements();
50     const auto& rhs_shardings = rhs.tuple_elements();
51     CHECK_EQ(lhs_shardings.size(), rhs_shardings.size());
52     bool is_better = false;
53     for (int64_t i = 0; i < lhs_shardings.size(); ++i) {
54       if (IsShardingMoreSpecific(rhs_shardings[i], lhs_shardings[i])) {
55         return false;
56       }
57       if (IsShardingMoreSpecific(lhs_shardings[i], rhs_shardings[i])) {
58         is_better = true;
59       }
60     }
61     return is_better;
62   }
63   if (lhs.IsManual() || rhs.IsManual()) {
64     return false;
65   }
66   if (!rhs.IsTileMaximal()) {
67     return lhs.NumTiles() > rhs.NumTiles();
68   } else if (!rhs.IsReplicated()) {
69     // If we are not replicated then only tiled (not tile maximal) shardings
70     // can improve us.
71     return !lhs.IsTileMaximal();
72   } else {
73     // If we are replicated then any non-replicated sharding can improve us.
74     return !lhs.IsReplicated();
75   }
76 }
77 
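// Combines `old` into `to_merge` when both are partially replicated over the
// same number of devices and their tilings are compatible; returns true if the
// resulting `to_merge` is more specific than `old`.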
78 bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
79                    bool may_combine_partial_sharding) {
80   if (old.IsTuple()) {
81     CHECK(to_merge->IsTuple());
82     bool changed = false;
83     for (int64_t i = 0; i < old.tuple_elements().size(); ++i) {
84       changed |=
85           MergeSharding(old.tuple_elements()[i], &to_merge->tuple_elements()[i],
86                         may_combine_partial_sharding);
87     }
88     return changed;
89   }
90   if (!may_combine_partial_sharding || !old.HasPartialReplication() ||
91       !to_merge->HasPartialReplication() ||
92       old.tile_assignment().num_elements() !=
93           to_merge->tile_assignment().num_elements()) {
94     return IsShardingMoreSpecific(*to_merge, old);
95   }
96 
97   if (MergeShardingIfCompatible(
98           old,
99           /*minimum_tiles=*/std::max(old.NumTiles(), to_merge->NumTiles()) + 1,
100           to_merge)) {
101     return true;
102   }
103   return IsShardingMoreSpecific(*to_merge, old);
104 }
105 
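// Merges `to_merge` into `dst` when their tile assignments are compatible
// (tiled dimensions agree wherever both are partitioned and the replication
// groups intersect), requiring in the general case at least `minimum_tiles`
// tiles in the merged tiling. Returns true and updates `dst` on success;
// otherwise returns false and leaves `dst` unchanged.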
106 bool MergeShardingIfCompatible(const HloSharding& to_merge,
107                                int64_t minimum_tiles, HloSharding* dst) {
108   if (to_merge.IsTileMaximal()) {
109     return false;
110   }
111   if (dst->IsTileMaximal()) {
112     *dst = to_merge;
113     return true;
114   }
115   if (!dst->HasPartialReplication()) {
116     return false;
117   }
118   // Combine the tile dimension sizes from dst and to_merge.
119   int64_t num_devices = to_merge.tile_assignment().num_elements();
120   std::vector<int64_t> merged_tile_dims;
121   merged_tile_dims.reserve(dst->tile_assignment().num_dimensions());
122   for (int64_t i = 0; i < to_merge.TiledDataRank(); ++i) {
123     int64_t dst_dim = dst->tile_assignment().dim(i);
124     int64_t merge_dim = to_merge.tile_assignment().dim(i);
125     if (dst_dim == 1) {
126       merged_tile_dims.push_back(merge_dim);
127     } else if (merge_dim == 1) {
128       merged_tile_dims.push_back(dst_dim);
129     } else if (dst_dim == merge_dim) {
130       merged_tile_dims.push_back(dst_dim);
131     } else {
132       return false;
133     }
134   }
135   const int64_t num_tiles = Product(merged_tile_dims);
136   if (num_devices % num_tiles != 0 || num_tiles < minimum_tiles) {
137     return false;
138   }
139   int64_t to_merge_man_dim = to_merge.SubgroupManualDim();
140   int64_t dst_man_dim = dst->SubgroupManualDim();
141   if (to_merge_man_dim >= 0) {
142     if (dst_man_dim < 0) {
143       return false;
144     }
145     int64_t man_group_size = to_merge.tile_assignment().dim(to_merge_man_dim);
146     if (man_group_size != dst->tile_assignment().dim(dst_man_dim)) {
147       return false;
148     }
149     merged_tile_dims.push_back(man_group_size);
150   }
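  // The devices left over after the tiled (and manual) dimensions form the
  // trailing replication dimension of the merged tile assignment.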
151   int64_t replication = num_devices / Product(merged_tile_dims);
152   merged_tile_dims.push_back(replication);
153   Array<int64_t> merged_tile(merged_tile_dims);
154   // Maps from replication group ID to sorted members.
155   absl::flat_hash_map<int64_t, std::set<int64_t>> merge_group_members;
156   absl::flat_hash_map<int64_t, std::set<int64_t>> dst_group_members;
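  // Maps a set of tile indices to a linearized replication-group id over the
  // tiled data dimensions (and the manual subgroup dimension, if any),
  // ignoring the replication dimension itself.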
157   auto get_group_index = [&](absl::Span<const int64_t> tile_indices,
158                              const HloSharding& sharding, int64_t manual_dim) {
159     int64_t group_id = 0;
160     for (int64_t i = 0; i < to_merge.TiledDataRank(); ++i) {
161       group_id *= sharding.tile_assignment().dim(i);
162       group_id += tile_indices[i];
163     }
164     if (manual_dim >= 0) {
165       group_id *= sharding.tile_assignment().dim(manual_dim);
166       group_id += tile_indices[manual_dim];
167     }
168     return group_id;
169   };
170   to_merge.tile_assignment().Each([&](absl::Span<const int64_t> indices,
171                                       int64_t device) {
172     merge_group_members[get_group_index(indices, to_merge, to_merge_man_dim)]
173         .insert(device);
174   });
175   dst->tile_assignment().Each(
176       [&](absl::Span<const int64_t> indices, int64_t device) {
177         dst_group_members[get_group_index(indices, *dst, dst_man_dim)].insert(
178             device);
179       });
180   // Try to find the intersection of to_merge and dst replication groups, in
181   // order to determine the merged tile assignment.
182   Status compatible = merged_tile.EachStatus(
183       [&](absl::Span<const int64_t> indices, int64_t* device) {
184         std::vector<int64_t> to_merge_index(
185             to_merge.tile_assignment().num_dimensions());
186         std::vector<int64_t> dst_index(dst->tile_assignment().num_dimensions());
187         for (int64_t i = 0; i < to_merge.TiledDataRank(); ++i) {
188           if (to_merge.tile_assignment().dim(i) == 1) {
189             to_merge_index[i] = 0;
190           } else {
191             to_merge_index[i] = indices[i];
192           }
193           if (dst->tile_assignment().dim(i) == 1) {
194             dst_index[i] = 0;
195           } else {
196             dst_index[i] = indices[i];
197           }
198         }
199         if (to_merge_man_dim >= 0) {
200           to_merge_index[to_merge_man_dim] = indices[to_merge.TiledDataRank()];
201           dst_index[dst_man_dim] = indices[to_merge.TiledDataRank()];
202         }
203         if (to_merge.HasPartialReplication()) {
204           to_merge_index[to_merge.SubgroupReplicationDim()] = indices.back();
205         }
206         dst_index[dst->SubgroupReplicationDim()] = indices.back();
207         int64_t to_merge_group_id =
208             get_group_index(to_merge_index, to_merge, to_merge_man_dim);
209         int64_t dst_group_id = get_group_index(dst_index, *dst, dst_man_dim);
210         if (merge_group_members[to_merge_group_id].empty() ||
211             dst_group_members[dst_group_id].empty()) {
212           return InvalidArgument("Not compatible");
213         }
214 
215         int64_t smallest_to_merge =
216             *merge_group_members[to_merge_group_id].begin();
217         int64_t smallest_dst = *dst_group_members[dst_group_id].begin();
218         if (smallest_to_merge < smallest_dst) {
219           if (merge_group_members[to_merge_group_id].count(smallest_dst) == 0) {
220             return InvalidArgument("Not compatible");
221           }
222           *device = smallest_dst;
223         } else {
224           if (dst_group_members[dst_group_id].count(smallest_to_merge) == 0) {
225             return InvalidArgument("Not compatible");
226           }
227           *device = smallest_to_merge;
228         }
229         merge_group_members[to_merge_group_id].erase(*device);
230         dst_group_members[dst_group_id].erase(*device);
231         return OkStatus();
232       });
233   if (!compatible.ok()) {
234     return false;
235   }
236   std::vector<OpMetadata> merged_metadata(std::move(dst->metadata()));
237   merged_metadata.reserve(merged_metadata.size() + to_merge.metadata().size());
238   const absl::flat_hash_set<OpMetadata, protobuf_util::ProtobufHashWrapper,
239                             protobuf_util::ProtobufEqualsWrapper>
240       metadata_set(merged_metadata.begin(), merged_metadata.end());
241   absl::c_copy_if(to_merge.metadata(), std::back_inserter(merged_metadata),
242                   [&metadata_set](const OpMetadata& data) {
243                     return !ContainsKey(metadata_set, data);
244                   });
245   std::vector<OpSharding::Type> subgroup_types;
246   if (to_merge_man_dim >= 0) {
247     subgroup_types.push_back(OpSharding::MANUAL);
248   }
249   subgroup_types.push_back(OpSharding::REPLICATED);
250   *dst = HloSharding::Subgroup(merged_tile, subgroup_types, merged_metadata);
251   return true;
252 }
253 
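// Returns the device with the highest occurrence count in `device_map`,
// optionally reporting that count through `top_count`, or std::nullopt if no
// device has a positive count.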
254 std::optional<int64_t> SelectDominantDevice(
255     const std::map<int64_t, int64_t>& device_map, int64_t* top_count) {
256   int64_t device = 0;
257   int64_t count = 0;
258   for (auto& it : device_map) {
259     if (it.second > count) {
260       count = it.second;
261       device = it.first;
262     }
263   }
264   if (top_count != nullptr) {
265     *top_count = count;
266   }
267   return count > 0 ? std::optional<int64_t>(device) : std::optional<int64_t>();
268 }
269 
270 void AssignComputationDevice(HloComputation* computation, int64_t device) {
271   VLOG(4) << "Assigning device " << device << " to " << computation->name()
272           << " computation";
273   for (HloInstruction* instruction : computation->instructions()) {
274     if (!instruction->has_sharding()) {
275       VLOG(4) << "Assigning device " << device << " to " << instruction->name();
276       instruction->set_device_sharding(device);
277     }
278   }
279 }
280 
281 std::optional<int64_t> GetMostOccurringDevice(
282     absl::Span<HloInstruction* const> instructions) {
283   std::map<int64_t, int64_t> device_map;
284   for (HloInstruction* instruction : instructions) {
285     if (instruction->has_sharding()) {
286       for (auto& it : instruction->sharding().UsedDevices(nullptr)) {
287         // The UsedDevices() API returns a map<device, occurrence_count>.
288         device_map[it.first] += it.second;
289       }
290     }
291   }
292   return SelectDominantDevice(device_map, nullptr);
293 }
294 
295 std::optional<int64_t> GetDominantDevice(
296     absl::Span<HloComputation* const> computations, double dominant_factor) {
297   int64_t instruction_count = 0;
298   std::map<int64_t, int64_t> device_map;
299   for (HloComputation* computation : computations) {
300     for (HloInstruction* instruction : computation->instructions()) {
301       int64_t count = 1;
302       if (instruction->has_sharding()) {
303         for (auto& it : instruction->sharding().UsedDevices(&count)) {
304           // The UsedDevices() API returns a map<device, occurrence_count>.
305           device_map[it.first] += it.second;
306         }
307       }
308       instruction_count += count;
309     }
310   }
311   int64_t count;
312   std::optional<int64_t> device = SelectDominantDevice(device_map, &count);
313   std::optional<int64_t> dominant_device;
314   if (device) {
315     double factor =
316         static_cast<double>(count) / static_cast<double>(instruction_count);
317     if (factor >= dominant_factor) {
318       dominant_device = device;
319     }
320   }
321   return dominant_device;
322 }
323 
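// Note: this permutes the tile assignment dimensions of `sharding` according
// to `dimensions`; when `dimensions` covers only the data dimensions, the
// subgroup (manual/replication) dimensions are carried along unchanged.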
324 HloSharding TransposeSharding(const HloSharding& sharding,
325                               absl::Span<const int64_t> dimensions) {
326   if (sharding.IsTileMaximal()) {
327     return sharding;
328   }
329   DimensionVector perm_dimensions(dimensions.begin(), dimensions.end());
330   // Add subgroup dims if missing.
331   if (sharding.TiledDataRank() == dimensions.size()) {
332     for (int64_t i = sharding.TiledDataRank();
333          i < sharding.tile_assignment().num_dimensions(); ++i) {
334       perm_dimensions.push_back(i);
335     }
336   } else {
337     CHECK_EQ(sharding.tile_assignment().num_dimensions(), dimensions.size());
338   }
339   Array<int64_t> tile_assignment = sharding.tile_assignment();
340   tile_assignment.TransposeDimensions(perm_dimensions);
341   if (!sharding.ReplicateOnLastTileDim()) {
342     std::vector<OpSharding::Type> subgroup_types;
343     for (int64_t i = sharding.TiledDataRank(); i < perm_dimensions.size();
344          ++i) {
345       int64_t src_i = perm_dimensions[i] - sharding.TiledDataRank();
346       subgroup_types.push_back(sharding.subgroup_types()[src_i]);
347     }
348     return HloSharding::Subgroup(tile_assignment, subgroup_types,
349                                  sharding.metadata());
350   } else {
351     return HloSharding::PartialTile(tile_assignment, sharding.metadata());
352   }
353 }
354 
355 std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
356                                            const Shape& target_shape,
357                                            const HloSharding& sharding) {
358   if (sharding.IsTileMaximal()) {
359     return sharding;
360   }
361 
362   // In case of a tiled sharding, the reshaped sharding will be valid if the
363   // reshape is composed of the following operations:
364   // * Adding or removing dimensions with size 1.
365   // * Merging consecutive dimensions where only the most major is sharded.
366   // * Splitting a dimension to consecutive dimensions.
367   // * Any reshaping of unsharded dimensions.
368   // Note that merge and split can happen consecutively on the same dimension,
369   // e.g., f32[1024,256,1024] to f32[128,2048,1024] can be viewed as 1024
370   // being split into 128 and 8, with the 8 then merged into 256. We use stacks
371   // to make supporting such cases easy.
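  // For example (illustrative): reshaping f32[16,8] sharded as
  // {devices=[4,1]0,1,2,3} to f32[128] keeps the 4-way sharding on the merged
  // dimension, while the same reshape starting from {devices=[1,4]0,1,2,3}
  // has no equivalent target sharding and yields std::nullopt.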
372   const Shape tile_shape = sharding.TileShape(source_shape);
373   std::vector<int64_t> target_tile_assignment_dimensions;
374   std::vector<int64_t> source_dims_stack(source_shape.rank());
375   std::vector<int64_t> target_dims_stack(target_shape.rank());
376   std::vector<int64_t> sharding_tile_dims_stack(source_shape.rank());
377   int64_t added_to_partially_replicated = 1;
378   for (int64_t i = 0; i < source_shape.rank(); ++i) {
379     source_dims_stack[i] = source_shape.dimensions(source_shape.rank() - 1 - i);
380     sharding_tile_dims_stack[i] =
381         sharding.tile_assignment().dim(source_shape.rank() - 1 - i);
382   }
383   for (int64_t i = 0; i < target_shape.rank(); ++i) {
384     target_dims_stack[i] = target_shape.dimensions(target_shape.rank() - 1 - i);
385   }
386   while (!source_dims_stack.empty() || !target_dims_stack.empty()) {
387     if (target_dims_stack.empty()) {
388       if (Product(sharding_tile_dims_stack) != 1) {
389         return std::nullopt;
390       }
391       break;
392     }
393     int64_t s_size = 1;
394     int64_t t_size = 1;
395     int64_t s_partitions = 1;
396     if (!source_dims_stack.empty()) {
397       s_size = source_dims_stack.back();
398       source_dims_stack.pop_back();
399       s_partitions = sharding_tile_dims_stack.back();
400       sharding_tile_dims_stack.pop_back();
401     }
402     t_size = target_dims_stack.back();
403     target_dims_stack.pop_back();
404     if (s_partitions * Product(sharding_tile_dims_stack) == 1) {
405       // No more partitions left.
406       target_tile_assignment_dimensions.push_back(1);
407       continue;
408     }
409     if (s_size == t_size) {
410       // Same dimension.
411       target_tile_assignment_dimensions.push_back(s_partitions);
412     } else if (t_size == 1) {
413       // Trivial dimension added.
414       target_tile_assignment_dimensions.push_back(1);
415       source_dims_stack.push_back(s_size);
416       sharding_tile_dims_stack.push_back(s_partitions);
417     } else if (s_size == 1) {
418       // Trivial dimension removed.
419       if (s_partitions != 1) {
420         added_to_partially_replicated *= s_partitions;
421       }
422       target_dims_stack.push_back(t_size);
423     } else if (s_size > t_size) {
424       // Dimension split.
425       if (s_size % t_size != 0 || s_size % s_partitions != 0) {
426         return std::nullopt;
427       }
428       if (t_size % s_partitions == 0) {
429         target_tile_assignment_dimensions.push_back(s_partitions);
430         // We have part of the s_size unprocessed, so put it back on the stack.
431         source_dims_stack.push_back(s_size / t_size);
432         sharding_tile_dims_stack.push_back(1);
433       } else if (s_partitions % t_size == 0) {
434         target_tile_assignment_dimensions.push_back(t_size);
435         // We have part of the s_size unprocessed, so put it back on the stack.
436         source_dims_stack.push_back(s_size / t_size);
437         sharding_tile_dims_stack.push_back(s_partitions / t_size);
438       } else {
439         return std::nullopt;
440       }
441     } else {
442       // Dimension merge. Also merge the source dimension with the next, and
443       // process it next time.
444       if (s_size % s_partitions != 0) {
445         return std::nullopt;
446       }
447       CHECK(!source_dims_stack.empty());
448       if (sharding_tile_dims_stack.back() != 1 && s_size != s_partitions) {
449         // If the next dimension to combine is sharded, we require the
450         // current dimension's shard size to be 1. Otherwise, the new shard
451         // would be non-contiguous.
452         return std::nullopt;
453       }
454       source_dims_stack.back() *= s_size;
455       sharding_tile_dims_stack.back() *= s_partitions;
456       target_dims_stack.push_back(t_size);
457     }
458   }
459   if (Product(target_tile_assignment_dimensions) == 1) {
460     return std::nullopt;
461   }
462   Array<int64_t> new_tile_assignment = sharding.tile_assignment();
463   for (int64_t i = sharding.TiledDataRank();
464        i < sharding.tile_assignment().num_dimensions(); ++i) {
465     target_tile_assignment_dimensions.push_back(
466         sharding.tile_assignment().dim(i));
467   }
468 
469   auto subgroup_types = sharding.subgroup_types();
470   // If we accumulated extra partitions from removed dimensions, add them to
471   // the partially replicated dimension of the tiling.
472   if (added_to_partially_replicated > 1) {
473     if (sharding.ReplicateOnLastTileDim()) {
474       target_tile_assignment_dimensions.back() *= added_to_partially_replicated;
475     } else {
476       target_tile_assignment_dimensions.push_back(
477           added_to_partially_replicated);
478     }
479   }
480   // If subgroup_types does not already end with REPLICATED as a sharding
481   // type, add it.
482   if ((sharding.ReplicateOnLastTileDim() ||
483        added_to_partially_replicated > 1) &&
484       (subgroup_types.empty() ||
485        subgroup_types.back() != OpSharding::REPLICATED)) {
486     subgroup_types.push_back(OpSharding::REPLICATED);
487   }
488   new_tile_assignment.Reshape(target_tile_assignment_dimensions);
489   return HloSharding::Subgroup(new_tile_assignment, subgroup_types,
490                                sharding.metadata());
491 }
492 
493 HloSharding ReverseSharding(const HloSharding& sharding,
494                             absl::Span<const int64_t> dimensions) {
495   if (sharding.IsTileMaximal() || dimensions.empty()) {
496     return sharding;
497   }
498 
499   Array<int64_t> new_tile_assignment(sharding.tile_assignment().dimensions());
500   new_tile_assignment.Each(
501       [&](absl::Span<const int64_t> indices, int64_t* device) {
502         std::vector<int64_t> original_indices(indices.begin(), indices.end());
503         for (int64_t d : dimensions) {
504           original_indices[d] =
505               new_tile_assignment.dim(d) - 1 - original_indices[d];
506         }
507         *device = sharding.tile_assignment()(original_indices);
508       });
509   return sharding.ReplicateOnLastTileDim()
510              ? HloSharding::PartialTile(new_tile_assignment,
511                                         sharding.metadata())
512              : HloSharding::Subgroup(new_tile_assignment,
513                                      sharding.subgroup_types(),
514                                      sharding.metadata());
515 }
516 
517 HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64_t dim,
518                                    absl::Span<const int64_t> dims) {
519   CHECK(!sharding.IsTuple() && !sharding.IsTileMaximal());
520   CHECK_NE(absl::c_find(dims, dim), dims.end()) << "dim is not in dims";
521   // We optimize the tile assignment on the single dimension dim so as to
522   // minimize communication among devices caused by the reshard:
523   // +---+---+               +---+---+              +-+-+-+-+
524   // |   |   |               |   0   |              | | | | |
525   // | 0 | 1 |               +-------+              | | | | |
526   // |   |   |  reshape on   |   1   |  reshape on  | | | | |
527   // +---+---+   dim 0  =>   +-------+   dim 1  =>  |0|2|1|3|
528   // |   |   |               |   2   |              | | | | |
529   // | 2 | 3 |               +-------+              | | | | |
530   // |   |   |               |   3   |              | | | | |
531   // +---+---+               +---+---+              +-+-+-+-+
532 
533   std::vector<int64_t> tile_dims(sharding.tile_assignment().num_dimensions(),
534                                  1);
535   // Handle the ignored dimensions.
536   std::vector<int64_t> ignore_sizes;
537   int64_t ignore_size = 1;
538   for (int64_t i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
539     if (absl::c_find(dims, i) == dims.end()) {
540       int64_t size = sharding.tile_assignment().dim(i);
541       ignore_sizes.push_back(size);
542       tile_dims[i] = size;
543       ignore_size *= size;
544     }
545   }
546 
547   using Buckets = std::vector<std::vector<int64_t>>;
548   Array<Buckets> buckets(ignore_sizes,
549                          Buckets(sharding.tile_assignment().dim(dim)));
550   sharding.tile_assignment().Each(
551       [&](absl::Span<const int64_t> index, int64_t device) {
552         std::vector<int64_t> ignore_index;
553         for (int64_t i = 0; i < index.size(); ++i) {
554           if (absl::c_find(dims, i) == dims.end()) {
555             ignore_index.push_back(index[i]);
556           }
557         }
558         buckets(ignore_index)[index[dim]].push_back(device);
559       });
560   std::vector<int64_t> devices;
561   buckets.Each([&](absl::Span<const int64_t> index, const Buckets& buckets) {
562     for (auto& bucket : buckets) {
563       devices.insert(devices.end(), bucket.begin(), bucket.end());
564     }
565   });
566   tile_dims[dim] = devices.size() / ignore_size;
567   Array<int64_t> tile_assignment(tile_dims);
568   tile_assignment.SetValues(devices);
569   return HloSharding::Tile(tile_assignment, sharding.metadata());
570 }
571 
572 bool ContainsTileSharding(const HloModule& module) {
573   for (const HloComputation* computation : module.computations()) {
574     for (const HloInstruction* instruction : computation->instructions()) {
575       if (instruction->has_sharding() &&
576           !instruction->sharding().IsTileMaximal()) {
577         return true;
578       }
579     }
580   }
581   return false;
582 }
583 
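// Derives an output sharding for a gather from the sharding of its indices:
// non-offset output dimensions inherit the partitioning of the corresponding
// index dimensions, offset dimensions are left unpartitioned, and the result
// falls back to replication if the tile counts cannot be matched up.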
584 HloSharding GatherOutputSharding(const HloSharding& index_sharding,
585                                  const HloInstruction* hlo) {
586   if (index_sharding.IsTileMaximal()) {
587     return index_sharding;
588   }
589 
590   const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
591   std::vector<int64_t> output_tile_assignment_dims;
592   const int64_t rank = hlo->shape().rank(),
593                 num_dimensions =
594                     index_sharding.tile_assignment().num_dimensions();
595   output_tile_assignment_dims.reserve(rank + num_dimensions);
596   for (int64_t i = 0, index_dim = 0; i < rank; ++i) {
597     if (absl::c_binary_search(dnums.offset_dims(), i)) {
598       output_tile_assignment_dims.push_back(1);
599     } else {
600       const int64_t new_tile_dimension =
601           index_dim >= dnums.index_vector_dim() ? index_dim + 1 : index_dim;
602       output_tile_assignment_dims.push_back(
603           index_sharding.tile_assignment().dim(new_tile_dimension));
604       ++index_dim;
605     }
606   }
607 
608   for (int64_t i = index_sharding.TiledDataRank(); i < num_dimensions; ++i) {
609     output_tile_assignment_dims.push_back(
610         index_sharding.tile_assignment().dim(i));
611   }
612 
613   Array<int64_t> new_tile_assignment = index_sharding.tile_assignment();
614   if (new_tile_assignment.num_elements() !=
615       Product(output_tile_assignment_dims)) {
616     return HloSharding::Replicate(index_sharding.metadata());
617   }
618   new_tile_assignment.Reshape(output_tile_assignment_dims);
619   return index_sharding.ReplicateOnLastTileDim()
620              ? HloSharding::PartialTile(new_tile_assignment,
621                                         index_sharding.metadata())
622              : HloSharding::Subgroup(new_tile_assignment,
623                                      index_sharding.subgroup_types(),
624                                      index_sharding.metadata());
625 }
626 
627 HloSharding GatherIndexSharding(const HloSharding& output_sharding,
628                                 const HloInstruction* hlo) {
629   CHECK(hlo->opcode() == HloOpcode::kGather);
630   if (output_sharding.IsTileMaximal()) {
631     return output_sharding;
632   }
633 
634   const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
635   std::vector<int64_t> index_tile_assignment_dims;
636   // Relevant output dims have shardings passed to the index.
637   std::vector<int64_t> relevant_output_dims;
638   for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
639     if (!absl::c_binary_search(dnums.offset_dims(), i)) {
640       index_tile_assignment_dims.push_back(
641           output_sharding.tile_assignment().dim(i));
642       relevant_output_dims.push_back(i);
643     }
644   }
645   int64_t index_rank = hlo->operand(1)->shape().rank();
646 
647   // Indices sharding on `index_vector_dim` is not supported yet.
648   if (index_rank > index_tile_assignment_dims.size()) {
649     index_tile_assignment_dims.insert(
650         index_tile_assignment_dims.begin() + dnums.index_vector_dim(), 1);
651   }
652 
653   if (Product(index_tile_assignment_dims) == 1) {
654     return HloSharding::Replicate(output_sharding.metadata());
655   }
656   HloSharding relevant_output_sharding =
657       PartiallyReplicateTiledShardingOnAllDimsExcept(output_sharding,
658                                                      relevant_output_dims);
659   if (relevant_output_sharding.IsTileMaximal()) {
660     return relevant_output_sharding;
661   }
662   for (int64_t i = relevant_output_sharding.TiledDataRank();
663        i < relevant_output_sharding.tile_assignment().num_dimensions(); ++i) {
664     index_tile_assignment_dims.push_back(
665         relevant_output_sharding.tile_assignment().dim(i));
666   }
667 
668   Array<int64_t> new_tile_assignment =
669       relevant_output_sharding.tile_assignment();
670   new_tile_assignment.Reshape(index_tile_assignment_dims);
671   return relevant_output_sharding.ReplicateOnLastTileDim()
672              ? HloSharding::PartialTile(new_tile_assignment,
673                                         output_sharding.metadata())
674              : HloSharding::Subgroup(new_tile_assignment,
675                                      relevant_output_sharding.subgroup_types(),
676                                      output_sharding.metadata());
677 }
678 
679 HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) {
680   if (hlo.sharding().IsTileMaximal()) {
681     return hlo.sharding();
682   }
683 
684   const GatherDimensionNumbers& dnums = hlo.gather_dimension_numbers();
685   std::vector<int64_t> tile_assignment_dims(hlo.shape().rank());
686   int64_t num_elements = 1;
687   for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
688     if (!absl::c_binary_search(dnums.offset_dims(), i)) {
689       tile_assignment_dims[i] = hlo.sharding().tile_assignment().dim(i);
690       num_elements *= hlo.sharding().tile_assignment().dim(i);
691     } else {
692       tile_assignment_dims[i] = 1;
693     }
694   }
695   if (num_elements == hlo.sharding().tile_assignment().num_elements()) {
696     // Output sharding is only on non offset dimensions. We use output sharding
697     // to shard this gather op directly.
698     return hlo.sharding();
699   }
700 
701   if (num_elements == 1) {
702     // Output sharding is only on offset dimensions. We do not shard this gather
703     // op. Return a tile maximal sharding with the first device in output
704     // sharding tile assignment.
705     return HloSharding::AssignDevice(*hlo.sharding().tile_assignment().begin(),
706                                      hlo.sharding().metadata());
707   }
708 
709   // Output sharding is on both offset and non offset dimensions. We shard the
710   // gather op only on non offset dimensions.
711   // For example:
712   // - the gather op has sharding [2,2]{0,1,2,3},
713   // - first dimension is non offset dimension,
714   // - second dimension is offset dimension,
715   // Then the result sharding will be [2,1]{0,2}.
716   std::vector<int64_t> slice_starts(hlo.shape().rank(), 0LL),
717       slice_limits(hlo.shape().rank());
718   for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
719     if (!absl::c_binary_search(dnums.offset_dims(), i)) {
720       slice_limits[i] = hlo.sharding().tile_assignment().dim(i);
721     } else {
722       slice_limits[i] = 1;
723     }
724   }
725   Array<int64_t> tile_assignment =
726       hlo.sharding().tile_assignment().Slice(slice_starts, slice_limits);
727   return HloSharding::Tile(tile_assignment, hlo.sharding().metadata());
728 }
729 
730 HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
731                                  const HloScatterInstruction* scatter) {
732   if (data_sharding.IsTileMaximal()) {
733     return data_sharding;
734   }
735 
736   const ScatterDimensionNumbers& dnums = scatter->scatter_dimension_numbers();
737   std::vector<int64_t> index_tile_assignment_dims;
738   std::vector<int64_t> relevant_data_dims;
739   for (int64_t i = 0; i < scatter->scatter_updates()[0]->shape().rank(); ++i) {
740     if (!absl::c_binary_search(dnums.update_window_dims(), i)) {
741       index_tile_assignment_dims.push_back(
742           data_sharding.tile_assignment().dim(i));
743       relevant_data_dims.push_back(i);
744     }
745   }
746   // Indices sharding on `index_vector_dim` is not supported yet.
747   if (index_tile_assignment_dims.size() <
748       scatter->scatter_indices()->shape().rank()) {
749     index_tile_assignment_dims.insert(
750         index_tile_assignment_dims.begin() + dnums.index_vector_dim(), 1);
751   }
752   HloSharding relevant_data_sharding =
753       PartiallyReplicateTiledShardingOnAllDimsExcept(data_sharding,
754                                                      relevant_data_dims);
755   if (relevant_data_sharding.IsTileMaximal()) {
756     return relevant_data_sharding;
757   }
758   for (int64_t i = relevant_data_sharding.TiledDataRank();
759        i < relevant_data_sharding.tile_assignment().num_dimensions(); ++i) {
760     index_tile_assignment_dims.push_back(
761         relevant_data_sharding.tile_assignment().dim(i));
762   }
763   auto new_tile_assignment = relevant_data_sharding.tile_assignment();
764   new_tile_assignment.Reshape(index_tile_assignment_dims);
765   return relevant_data_sharding.ReplicateOnLastTileDim()
766              ? HloSharding::PartialTile(new_tile_assignment,
767                                         data_sharding.metadata())
768              : HloSharding::Subgroup(new_tile_assignment,
769                                      relevant_data_sharding.subgroup_types(),
770                                      data_sharding.metadata());
771 }
772 
773 HloSharding ScatterDataSharding(const HloSharding& index_sharding,
774                                 const HloScatterInstruction* scatter) {
775   if (index_sharding.IsTileMaximal()) {
776     return index_sharding;
777   }
778 
779   const ScatterDimensionNumbers& dnums = scatter->scatter_dimension_numbers();
780   std::vector<int64_t> data_tile_assignment_dims;
781   std::vector<int64_t> relevant_index_dims;
782   const int64_t rank = scatter->scatter_updates()[0]->shape().rank();
783   data_tile_assignment_dims.reserve(rank);
784   for (int64_t i = 0, index_dim = 0; i < rank; ++i) {
785     if (absl::c_binary_search(dnums.update_window_dims(), i)) {
786       data_tile_assignment_dims.push_back(1);
787     } else {
788       data_tile_assignment_dims.push_back(
789           index_sharding.tile_assignment().dim(index_dim));
790       relevant_index_dims.push_back(index_dim);
791       index_dim++;
792     }
793   }
794   auto relevant_index_sharding = PartiallyReplicateTiledShardingOnAllDimsExcept(
795       index_sharding, relevant_index_dims);
796   if (relevant_index_sharding.IsTileMaximal()) {
797     return relevant_index_sharding;
798   }
799   for (int64_t i = relevant_index_sharding.TiledDataRank();
800        i < relevant_index_sharding.tile_assignment().num_dimensions(); ++i) {
801     data_tile_assignment_dims.push_back(
802         relevant_index_sharding.tile_assignment().dim(i));
803   }
804   Array<int64_t> new_tile_assignment =
805       relevant_index_sharding.tile_assignment();
806   new_tile_assignment.Reshape(data_tile_assignment_dims);
807   return relevant_index_sharding.ReplicateOnLastTileDim()
808              ? HloSharding::PartialTile(new_tile_assignment,
809                                         index_sharding.metadata())
810              : HloSharding::Subgroup(new_tile_assignment,
811                                      relevant_index_sharding.subgroup_types(),
812                                      index_sharding.metadata());
813 }
814 
815 HloSharding ScatterEffectiveIndexSharding(
816     const HloSharding& index_sharding, const HloScatterInstruction& scatter) {
817   if (index_sharding.IsTileMaximal()) {
818     return index_sharding;
819   }
820 
821   // Only shard on the first "number of scatter_window_dims" dimensions.
822   const ScatterDimensionNumbers& dnums = scatter.scatter_dimension_numbers();
823   int64_t num_elements = 1;
824   int64_t index_dim = 0;
825   for (int64_t i = 0; i < scatter.shape().rank(); ++i) {
826     if (absl::c_binary_search(dnums.inserted_window_dims(), i)) {
827       num_elements *= index_sharding.tile_assignment().dim(index_dim);
828       index_dim++;
829     }
830   }
831   if (num_elements == index_sharding.tile_assignment().num_elements()) {
832     // Index sharding is only on scatter_window_dims. We use this index sharding
833     // directly.
834     return index_sharding;
835   }
836 
837   // Index sharding is only on update_window_dims. We do not shard this scatter
838   // op. Return a tile maximal sharding with the first device in index sharding
839   // tile assignment.
840   if (num_elements == 1) {
841     return HloSharding::AssignDevice(*index_sharding.tile_assignment().begin(),
842                                      index_sharding.metadata());
843   }
844 
845   const int64_t index_rank = scatter.scatter_indices()->shape().rank();
846   std::vector<int64_t> slice_starts(index_rank, 0LL), slice_limits(index_rank);
847   for (int64_t i = 0; i < index_rank; ++i) {
848     if (i < index_dim) {
849       slice_limits[i] = index_sharding.tile_assignment().dim(i);
850     } else {
851       slice_limits[i] = 1;
852     }
853   }
854   Array<int64_t> tile_assignment =
855       index_sharding.tile_assignment().Slice(slice_starts, slice_limits);
856   return HloSharding::Tile(tile_assignment, index_sharding.metadata());
857 }
858 
859 HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
860                                          const HloScatterInstruction& scatter) {
861   if (data_sharding.IsTileMaximal()) {
862     return data_sharding;
863   }
864 
865   const ScatterDimensionNumbers& dnums = scatter.scatter_dimension_numbers();
866   const int64_t data_rank = scatter.scatter_updates()[0]->shape().rank();
867   std::vector<int64_t> tile_assignment_dims(data_rank, 1LL);
868   int64_t num_elements = 1;
869   for (int64_t i = 0; i < scatter.shape().rank(); ++i) {
870     if (absl::c_binary_search(dnums.inserted_window_dims(), i)) {
871       CHECK_LT(i, data_rank);
872       tile_assignment_dims[i] = data_sharding.tile_assignment().dim(i);
873       num_elements *= data_sharding.tile_assignment().dim(i);
874     }
875   }
876   if (num_elements == data_sharding.tile_assignment().num_elements()) {
877     // Data sharding is only on scatter_window_dims. We use this data sharding
878     // directly.
879     return data_sharding;
880   }
881 
882   if (num_elements == 1) {
883     // Data sharding is only on update_window_dims. We do not shard this
884     // scatter op. Return a tile maximal sharding with the first device in
885     // data sharding tile assignment.
886     return HloSharding::AssignDevice(*data_sharding.tile_assignment().begin(),
887                                      data_sharding.metadata());
888   }
889 
890   // Data sharding is on both update_window_dims and scatter_window_dims. We
891   // shard the scatter op only on scatter_window_dims. For example:
892   // - the scatter data has sharding [2,2]{0,1,2,3},
893   // - first dimension is scatter_window_dims,
894   // - second dimension is update_window_dims,
895   // Then the result sharding will be [2,1]{0,2}.
896   std::vector<int64_t> slice_starts(data_rank, 0LL);
897   Array<int64_t> tile_assignment =
898       data_sharding.tile_assignment().Slice(slice_starts, tile_assignment_dims);
899   return HloSharding::Tile(tile_assignment, data_sharding.metadata());
900 }
901 
902 namespace {
903 
904 // If partitioning of the operand only happens in passthrough dimensions
905 // (offset dimensions in the gather output (or scatter update) that have the
906 // same size as the corresponding operand dimensions), returns the output (or
907 // update) sharding obtained by passing through the operand sharding.
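// For example (illustrative): if operand dimension i is partitioned 2 ways,
// is not a collapsed/inserted dimension, and the slice covers the whole
// dimension (slice_size[i] equals the operand dimension size), then the
// corresponding offset (or window) dimension of the output (or update) becomes
// partitioned 2 ways as well.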
908 std::optional<HloSharding> PassthroughOperandToGatherOutputOrScatterUpdate(
909     const Shape& operand_shape, const HloSharding& operand_sharding,
910     const Shape& update_or_gather_shape,
911     absl::Span<const int64_t> collapsed_or_inserted_dims,
912     absl::Span<const int64_t> index_map,
913     absl::Span<const int64_t> offset_or_window_dims,
914     absl::Span<const int64_t> slice_size) {
915   if (operand_sharding.IsTileMaximal()) {
916     return operand_sharding;
917   }
918   std::vector<int64_t> passthrough_tile(update_or_gather_shape.rank(), 1);
919   int64_t collapsed = 0;
920   for (int64_t i = 0; i < operand_shape.rank(); ++i) {
921     int64_t dim_partitions = operand_sharding.tile_assignment().dim(i);
922     if (absl::c_linear_search(collapsed_or_inserted_dims, i)) {
923       if (dim_partitions > 1) {
924         return std::nullopt;
925       }
926       collapsed++;
927       continue;
928     }
929     if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) {
930       return std::nullopt;
931     }
932     int64_t offset_dim = offset_or_window_dims[i - collapsed];
933     if (i - collapsed > 0 &&
934         offset_dim < offset_or_window_dims[i - collapsed - 1]) {
935       // Output offsets are transposed, we do not support this case.
936       return std::nullopt;
937     }
938     passthrough_tile[offset_dim] = dim_partitions;
939   }
940   for (int64_t i = operand_sharding.TiledDataRank();
941        i < operand_sharding.tile_assignment().num_dimensions(); ++i) {
942     passthrough_tile.push_back(operand_sharding.tile_assignment().dim(i));
943   }
944   Array<int64_t> tile_assignment = operand_sharding.tile_assignment();
945   tile_assignment.Reshape(passthrough_tile);
946   return operand_sharding.ReplicateOnLastTileDim()
947              ? HloSharding::PartialTile(tile_assignment,
948                                         operand_sharding.metadata())
949              : HloSharding::Subgroup(tile_assignment,
950                                      operand_sharding.subgroup_types(),
951                                      operand_sharding.metadata());
952 }
953 
954 // Inverse of PassthroughOperandToGatherOutputOrScatterUpdate.
955 std::optional<HloSharding> PassthroughGatherOutputOrScatterUpdateToOperand(
956     const Shape& operand_shape, const HloSharding& update_or_gather_sharding,
957     absl::Span<const int64_t> collapsed_or_inserted_dims,
958     absl::Span<const int64_t> index_map,
959     absl::Span<const int64_t> offset_or_window_dims,
960     absl::Span<const int64_t> slice_size) {
961   if (update_or_gather_sharding.IsTileMaximal()) {
962     return update_or_gather_sharding;
963   }
964   std::vector<int64_t> passthrough_tile(operand_shape.rank(), 1);
965   int64_t collapsed = 0;
966   // Relevant dims have shardings passed to the operand.
967   std::vector<int64_t> relevant_update_or_gather_dims;
968   for (int64_t i = 0; i < operand_shape.rank(); ++i) {
969     if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
970         absl::c_linear_search(index_map, i)) {
971       collapsed++;
972       continue;
973     }
974     int64_t offset_dim = offset_or_window_dims[i - collapsed];
975     int64_t dim_partitions =
976         update_or_gather_sharding.tile_assignment().dim(offset_dim);
977     if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) {
978       return std::nullopt;
979     }
980     if (i - collapsed > 0 &&
981         offset_dim < offset_or_window_dims[i - collapsed - 1]) {
982       // Output offsets are transposed, we do not support this case.
983       return std::nullopt;
984     }
985     relevant_update_or_gather_dims.push_back(offset_dim);
986     passthrough_tile[i] = dim_partitions;
987   }
988 
989   HloSharding relevant_sharding =
990       PartiallyReplicateTiledShardingOnAllDimsExcept(
991           update_or_gather_sharding, relevant_update_or_gather_dims);
992   if (relevant_sharding.IsTileMaximal()) {
993     return relevant_sharding;
994   }
995   for (int64_t i = relevant_sharding.TiledDataRank();
996        i < relevant_sharding.tile_assignment().num_dimensions(); ++i) {
997     passthrough_tile.push_back(relevant_sharding.tile_assignment().dim(i));
998   }
999   Array<int64_t> tile_assignment = relevant_sharding.tile_assignment();
1000   tile_assignment.Reshape(passthrough_tile);
1001   return relevant_sharding.ReplicateOnLastTileDim()
1002              ? HloSharding::PartialTile(tile_assignment,
1003                                         update_or_gather_sharding.metadata())
1004              : HloSharding::Subgroup(tile_assignment,
1005                                      relevant_sharding.subgroup_types(),
1006                                      update_or_gather_sharding.metadata());
1007 }
1008 
1009 // Collect data operand sharding for a gather with parallel dimensions from
1010 // the output.
1011 std::optional<HloSharding> GatherParallelDataOperandSharding(
1012     const HloSharding& output_sharding, const HloInstruction& gather,
1013     const GatherParallelDims& parallel_dims) {
1014   if (output_sharding.IsTileMaximal()) {
1015     return output_sharding;
1016   }
1017   auto output_parallel_dims = GatherParallelOutputDims(gather, parallel_dims);
1018   auto output_aligned_operand_parallel_dims =
1019       GatherOutputAlignedOperandParallelDims(gather, parallel_dims);
1020   const Shape gather_shape = gather.shape();
1021   CHECK_EQ(output_parallel_dims.size(),
1022            output_aligned_operand_parallel_dims.size());
1023   std::vector<int64_t> operand_tile_assignment(
1024       gather.operand(0)->shape().rank(), 1);
1025   std::vector<int64_t> relevant_output_dims;
1026   for (int i = 0, parallel_idx = 0; i < gather_shape.rank(); ++i) {
1027     if (parallel_idx >= output_parallel_dims.size() ||
1028         output_parallel_dims[parallel_idx] != i) {
1029       continue;
1030     }
1031     const int64_t operand_dim =
1032         output_aligned_operand_parallel_dims[parallel_idx++];
1033     operand_tile_assignment[operand_dim] =
1034         output_sharding.tile_assignment().dim(i);
1035     relevant_output_dims.push_back(i);
1036   }
1037   HloSharding relevant_output_sharding =
1038       PartiallyReplicateTiledShardingOnAllDimsExcept(output_sharding,
1039                                                      relevant_output_dims);
1040   if (relevant_output_sharding.IsTileMaximal()) {
1041     return std::move(relevant_output_sharding);
1042   }
1043 
1044   for (int64_t i = relevant_output_sharding.TiledDataRank();
1045        i < relevant_output_sharding.tile_assignment().num_dimensions(); ++i) {
1046     operand_tile_assignment.push_back(
1047         relevant_output_sharding.tile_assignment().dim(i));
1048   }
1049   Array<int64_t> tile_assignment = relevant_output_sharding.tile_assignment();
1050   tile_assignment.Reshape(operand_tile_assignment);
1051   return relevant_output_sharding.ReplicateOnLastTileDim()
1052              ? HloSharding::PartialTile(tile_assignment,
1053                                         output_sharding.metadata())
1054              : HloSharding::Subgroup(tile_assignment,
1055                                      relevant_output_sharding.subgroup_types(),
1056                                      output_sharding.metadata());
1057 }
1058 
1059 }  // namespace
1060 
1061 std::optional<HloSharding> GatherOutputShardingFromDataOperand(
1062     const HloSharding& data_operand_sharding, const HloInstruction& hlo,
1063     absl::Span<const int64_t> slice_sizes, const Shape& output_shape,
1064     const Shape& operand_shape) {
1065   const auto& dnums = hlo.gather_dimension_numbers();
1066   std::vector<int64_t> collapsed_slice_dims(
1067       dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end());
1068   std::vector<int64_t> start_index_map(dnums.start_index_map().begin(),
1069                                        dnums.start_index_map().end());
1070   std::vector<int64_t> offset_dims(dnums.offset_dims().begin(),
1071                                    dnums.offset_dims().end());
1072   return PassthroughOperandToGatherOutputOrScatterUpdate(
1073       operand_shape, data_operand_sharding, output_shape, collapsed_slice_dims,
1074       start_index_map, offset_dims, slice_sizes);
1075 }
1076 
1077 std::optional<HloSharding> GatherDataOperandShardingFromOutput(
1078     const HloSharding& output_sharding, const HloInstruction& hlo) {
1079   const auto& dnums = hlo.gather_dimension_numbers();
1080   std::vector<int64_t> collapsed_slice_dims(
1081       dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end());
1082   std::vector<int64_t> start_index_map(dnums.start_index_map().begin(),
1083                                        dnums.start_index_map().end());
1084   std::vector<int64_t> offset_dims(dnums.offset_dims().begin(),
1085                                    dnums.offset_dims().end());
1086 
1087   std::optional<HloSharding> parallel_sharding;
1088   auto parallel_dims = GetGatherBatchParallelDims(hlo);
1089   if (parallel_dims) {
1090     // Prioritize the parallel sharding, since this matches the behavior of
1091     // the spmd_partitioner.
1092     parallel_sharding =
1093         GatherParallelDataOperandSharding(output_sharding, hlo, *parallel_dims);
1094   }
1095   std::optional<HloSharding> passthrough_sharding =
1096       PassthroughGatherOutputOrScatterUpdateToOperand(
1097           hlo.operand(0)->shape(), output_sharding, collapsed_slice_dims,
1098           start_index_map, offset_dims, hlo.gather_slice_sizes());
1099   // Try to merge the two shardings or return the one that is present if only
1100   // one of the two is.
1101   if (!passthrough_sharding) {
1102     return parallel_sharding;
1103   }
1104   if (!parallel_sharding) {
1105     return passthrough_sharding;
1106   }
1107   if (MergeSharding(*parallel_sharding, &*passthrough_sharding,
1108                     /*may_combine_partial_sharding=*/true)) {
1109     return passthrough_sharding;
1110   }
1111   if (MergeSharding(*passthrough_sharding, &*parallel_sharding,
1112                     /*may_combine_partial_sharding=*/true)) {
1113     return parallel_sharding;
1114   }
1115   return parallel_sharding;
1116 }
1117 
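// For each operand dimension, returns the slice size written by the scatter:
// 1 for inserted window dimensions, and the size of the corresponding
// update_window_dims dimension of the update shape otherwise.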
1118 std::vector<int64_t> GetScatterSliceSize(const Shape& operand_shape,
1119                                          const Shape& update_shape,
1120                                          const ScatterDimensionNumbers& dnums) {
1121   std::vector<int64_t> slice_size(operand_shape.rank(), 1);
1122   int64_t num_update_window_dims = 0;
1123   for (int64_t i = 0; i < operand_shape.rank(); ++i) {
1124     if (absl::c_linear_search(dnums.inserted_window_dims(), i)) {
1125       continue;
1126     }
1127     slice_size[i] = update_shape.dimensions(
1128         dnums.update_window_dims(num_update_window_dims++));
1129   }
1130   return slice_size;
1131 }
1132 
1133 std::optional<HloSharding> ScatterOutputShardingFromUpdate(
1134     const HloSharding& update_sharding, const HloScatterInstruction& scatter) {
1135   const auto& dnums = scatter.scatter_dimension_numbers();
1136   std::vector<int64_t> inserted_window_dims(
1137       dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end());
1138   std::vector<int64_t> scatter_dims_to_operand_dims(
1139       dnums.scatter_dims_to_operand_dims().begin(),
1140       dnums.scatter_dims_to_operand_dims().end());
1141   std::vector<int64_t> update_window_dims(dnums.update_window_dims().begin(),
1142                                           dnums.update_window_dims().end());
1143   std::vector<int64_t> slice_size =
1144       GetScatterSliceSize(scatter.scatter_operands()[0]->shape(),
1145                           scatter.scatter_updates()[0]->shape(), dnums);
1146   return PassthroughGatherOutputOrScatterUpdateToOperand(
1147       scatter.shape(), update_sharding, inserted_window_dims,
1148       scatter_dims_to_operand_dims, update_window_dims, slice_size);
1149 }
1150 
1151 std::optional<HloSharding> ScatterUpdateShardingFromOutput(
1152     const HloSharding& per_output_sharding,
1153     const HloScatterInstruction& scatter) {
1154   const auto& dnums = scatter.scatter_dimension_numbers();
1155   std::vector<int64_t> inserted_window_dims(
1156       dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end());
1157   std::vector<int64_t> scatter_dims_to_operand_dims(
1158       dnums.scatter_dims_to_operand_dims().begin(),
1159       dnums.scatter_dims_to_operand_dims().end());
1160   std::vector<int64_t> update_window_dims(dnums.update_window_dims().begin(),
1161                                           dnums.update_window_dims().end());
1162   std::vector<int64_t> slice_size =
1163       GetScatterSliceSize(scatter.scatter_operands()[0]->shape(),
1164                           scatter.scatter_updates()[0]->shape(), dnums);
1165   return PassthroughOperandToGatherOutputOrScatterUpdate(
1166       scatter.scatter_operands()[0]->shape(), per_output_sharding,
1167       scatter.scatter_updates()[0]->shape(), inserted_window_dims,
1168       scatter_dims_to_operand_dims, update_window_dims, slice_size);
1169 }
1170 
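// For example, an add- or or-reduce scatter uses 0 as its identity value, a
// multiply- or and-reduce uses 1, a max-reduce uses the minimum value of the
// element type, and a min-reduce uses the maximum value.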
1171 StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>>
1172 IdentityValueAndHloOpcodeForScatterReduceComputation(
1173     const HloScatterInstruction& scatter) {
1174   auto computation = scatter.to_apply();
1175   // We only handle computations with 2 parameters and only 1 calculation.
1176   if (computation->instruction_count() != 3) {
1177     return Status(
1178         tensorflow::error::Code::INVALID_ARGUMENT,
1179         "Expected scatter reduce computation with 2 parameters and only 1 "
1180         "calculation");
1181   }
1182 
1183   auto root_instruction = computation->root_instruction();
1184   if (root_instruction->opcode() == HloOpcode::kAdd ||
1185       root_instruction->opcode() == HloOpcode::kOr) {
1186     return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::Zero(
1187                               scatter.shape().element_type())),
1188                           root_instruction->opcode());
1189   } else if (root_instruction->opcode() == HloOpcode::kMultiply ||
1190              root_instruction->opcode() == HloOpcode::kAnd) {
1191     return std::make_pair(HloInstruction::CreateConstant(
1192                               LiteralUtil::One(scatter.shape().element_type())),
1193                           root_instruction->opcode());
1194   } else if (root_instruction->opcode() == HloOpcode::kMaximum) {
1195     return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MinValue(
1196                               scatter.shape().element_type())),
1197                           root_instruction->opcode());
1198   } else if (root_instruction->opcode() == HloOpcode::kMinimum) {
1199     return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MaxValue(
1200                               scatter.shape().element_type())),
1201                           root_instruction->opcode());
1202   }
1203 
1204   return Status(tensorflow::error::Code::INVALID_ARGUMENT,
1205                 "Expected scatter reduce computation whose root is "
1206                 "add/or/multiply/and/min/max");
1207 }
1208 
1209 namespace {
1210 
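// Collects into `used` the devices referenced by `sharding`: every available
// (non-reserved) device for a replicated sharding, and the devices in the tile
// assignment otherwise. Tuple shardings are handled recursively.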
1211 void DevicesForShardingInternal(
1212     const HloSharding& sharding,
1213     const absl::flat_hash_set<int64_t>& available_devices,
1214     absl::flat_hash_set<int64_t>* used) {
1215   if (sharding.IsTuple()) {
1216     for (const auto& subsharding : sharding.tuple_elements()) {
1217       DevicesForShardingInternal(subsharding, available_devices, used);
1218     }
1219     return;
1220   }
1221 
1222   if (sharding.IsReplicated()) {
1223     for (int64_t device : available_devices) {
1224       if (!HloSharding::IsReservedDevice(device)) {
1225         used->insert(device);
1226       }
1227     }
1228     return;
1229   }
1230 
1231   DCHECK(std::all_of(
1232       sharding.tile_assignment().begin(), sharding.tile_assignment().end(),
1233       [&](int64_t device) { return available_devices.contains(device); }));
1234   sharding.tile_assignment().Each(
1235       [&](absl::Span<const int64_t> /*indices*/, int64_t device) {
1236         used->insert(device);
1237       });
1238 }
1239 
1240 }  // namespace
1241 
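// Returns the subset of `available_devices` that `sharding` actually uses, in
// the order in which they appear in `available_devices`.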
1242 std::vector<int64_t> DevicesForSharding(
1243     const HloSharding& sharding, absl::Span<const int64_t> available_devices) {
1244   absl::flat_hash_set<int64_t> available_set;
1245   for (int64_t device : available_devices) {
1246     available_set.insert(device);
1247   }
1248   absl::flat_hash_set<int64_t> used_set;
1249   DevicesForShardingInternal(sharding, available_set, &used_set);
1250   std::vector<int64_t> devices;
1251   for (int64_t device : available_devices) {
1252     if (used_set.contains(device)) {
1253       devices.push_back(device);
1254     }
1255   }
1256   return devices;
1257 }
1258 
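// Replaces the tiling along `dims_to_replicate` with replication, keeping the
// remaining dimensions tiled. E.g. (illustrative) replicating dim 1 of a [2,2]
// tiled sharding yields a [2,1,2] tile assignment whose last dimension is the
// replication subgroup.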
1259 HloSharding PartiallyReplicateTiledShardingOnDims(
1260     const HloSharding& sharding, absl::Span<const int64_t> dims_to_replicate) {
1261   if (sharding.IsTileMaximal()) {
1262     return sharding;
1263   }
1264   int64_t group_count = 1;
1265   std::vector<int64_t> valid_dims_to_replicate;
1266   for (int64_t dim : dims_to_replicate) {
1267     if (dim >= sharding.TiledDataRank()) {
1268       continue;
1269     }
1270     valid_dims_to_replicate.push_back(dim);
1271     group_count *= sharding.tile_assignment().dim(dim);
1272   }
1273   if (group_count == 1) {
1274     return sharding;
1275   }
1276   if (group_count == sharding.NumTiles() && sharding.subgroup_types().empty()) {
1277     return HloSharding::Replicate(sharding.metadata());
1278   }
1279   std::vector<int64_t> dim_permutation(sharding.TiledDataRank());
1280   absl::c_iota(dim_permutation, 0);
1281   absl::c_stable_sort(dim_permutation, [&](const int64_t a, const int64_t b) {
1282     return absl::c_linear_search(valid_dims_to_replicate, a) <
1283            absl::c_linear_search(valid_dims_to_replicate, b);
1284   });
1285   auto new_tile =
1286       TransposeSharding(sharding, dim_permutation).tile_assignment();
1287   std::vector<int64_t> new_tile_shape(
1288       sharding.tile_assignment().dimensions().begin(),
1289       sharding.tile_assignment().dimensions().end());
1290   for (int64_t dim : valid_dims_to_replicate) {
1291     new_tile_shape[dim] = 1;
1292   }
1293   if (sharding.ReplicateOnLastTileDim()) {
1294     new_tile_shape.back() *= group_count;
1295     new_tile.Reshape(new_tile_shape);
1296     return HloSharding::PartialTile(new_tile, sharding.metadata());
1297   } else {
1298     new_tile_shape.insert(new_tile_shape.begin() + sharding.TiledDataRank(),
1299                           group_count);
1300     new_tile.Reshape(new_tile_shape);
1301     std::vector<OpSharding::Type> subgroup_types;
1302     subgroup_types.push_back(OpSharding::REPLICATED);
1303     for (OpSharding::Type type : sharding.subgroup_types()) {
1304       subgroup_types.push_back(type);
1305     }
1306     return HloSharding::Subgroup(new_tile, subgroup_types, sharding.metadata());
1307   }
1308 }
1309 
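// Convenience wrapper that partially replicates every data dimension except
// those listed in `dims_to_keep`.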
1310 HloSharding PartiallyReplicateTiledShardingOnAllDimsExcept(
1311     const HloSharding& sharding, absl::Span<const int64_t> dims_to_keep) {
1312   if (sharding.IsTileMaximal()) {
1313     return sharding;
1314   }
1315   std::vector<int64_t> dims_to_replicate(sharding.TiledDataRank());
1316   absl::c_iota(dims_to_replicate, 0);
1317 
1318   dims_to_replicate.erase(
1319       std::remove_if(
1320           dims_to_replicate.begin(), dims_to_replicate.end(),
1321           [&](int64_t i) { return absl::c_linear_search(dims_to_keep, i); }),
1322       dims_to_replicate.end());
1323   return PartiallyReplicateTiledShardingOnDims(sharding, dims_to_replicate);
1324 }
1325 
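// Replicates all data dimensions while preserving any subgroup dimensions
// (e.g. manual). If `data_rank` is non-negative, the tile assignment is
// reshaped so that its data portion has exactly `data_rank` dimensions.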
1326 HloSharding ReplicateAllDataDims(const HloSharding& sharding,
1327                                  int64_t data_rank) {
1328   if (sharding.IsManual()) {
1329     return sharding;
1330   }
1331   if (sharding.subgroup_types().empty()) {
1332     return HloSharding::Replicate(sharding.metadata());
1333   }
1334   HloSharding result =
1335       PartiallyReplicateTiledShardingOnAllDimsExcept(sharding, {});
1336   if (data_rank >= 0 && data_rank != result.TiledDataRank() &&
1337       !result.IsTileMaximal()) {
1338     std::vector<int64_t> new_tile_shape(data_rank, 1);
1339     for (int64_t i = result.TiledDataRank();
1340          i < result.tile_assignment().num_dimensions(); ++i) {
1341       new_tile_shape.push_back(result.tile_assignment().dim(i));
1342     }
1343     auto tile = result.tile_assignment();
1344     tile.Reshape(new_tile_shape);
1345     result = HloSharding::Subgroup(tile, result.subgroup_types());
1346   }
1347   return result;
1348 }
1349 
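// Drops the given dimensions from the sharding's tile assignment; each removed
// dimension must have a tile count of 1.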
1350 HloSharding RemoveShapeDimensions(const HloSharding& sharding,
1351                                   absl::Span<const int64_t> dims_to_remove) {
1352   if (sharding.IsTileMaximal() || dims_to_remove.empty()) {
1353     return sharding;
1354   }
1355   std::vector<int64_t> new_tile_shape;
1356   new_tile_shape.reserve(sharding.tile_assignment().num_dimensions() -
1357                          dims_to_remove.size());
1358   for (int64_t i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
1359     if (absl::c_linear_search(dims_to_remove, i)) {
1360       CHECK_EQ(sharding.tile_assignment().dim(i), 1);
1361     } else {
1362       new_tile_shape.push_back(sharding.tile_assignment().dim(i));
1363     }
1364   }
1365   auto new_tile = sharding.tile_assignment();
1366   new_tile.Reshape(new_tile_shape);
1367   return sharding.ReplicateOnLastTileDim()
1368              ? HloSharding::PartialTile(new_tile, sharding.metadata())
1369              : HloSharding::Subgroup(new_tile, sharding.subgroup_types(),
1370                                      sharding.metadata());
1371 }
1372 
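// Transposes `source` according to a source<->target dimension mapping where
// -1 marks a source dimension that is collapsed or a target dimension that is
// newly added. Returns std::nullopt if a collapsed source dimension is tiled.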
1373 std::optional<HloSharding> TransposeShardingWithCollapsedDims(
1374     const HloSharding& source, absl::Span<int64_t const> src_to_tgt,
1375     absl::Span<int64_t const> tgt_to_src) {
1376   if (source.IsTileMaximal()) {
1377     return source;
1378   }
1379   if (src_to_tgt.size() < source.tile_assignment().num_dimensions()) {
1380     // Add missing subgroup dims.
1381     std::vector<int64_t> new_src_to_tgt(src_to_tgt.begin(), src_to_tgt.end());
1382     std::vector<int64_t> new_tgt_to_src(tgt_to_src.begin(), tgt_to_src.end());
1383     for (int64_t i = 0;
1384          i < source.tile_assignment().num_dimensions() - src_to_tgt.size();
1385          ++i) {
1386       new_src_to_tgt.push_back(tgt_to_src.size() + i);
1387       new_tgt_to_src.push_back(src_to_tgt.size() + i);
1388     }
1389     return TransposeShardingWithCollapsedDims(source, new_src_to_tgt,
1390                                               new_tgt_to_src);
1391   }
1392   std::vector<int64_t> tgt_dims_skipping_new(tgt_to_src.size(), -1);
1393   int64_t skipped_tgt_dims = 0;
1394   int64_t src_non_subgroup_dims =
1395       src_to_tgt.size() - source.subgroup_types().size();
1396   int64_t tgt_non_subgroup_dims =
1397       tgt_to_src.size() - source.subgroup_types().size();
1398   for (int64_t i = 0; i < tgt_to_src.size(); ++i) {
1399     if (tgt_to_src[i] < 0) {
1400       CHECK_LT(i, tgt_non_subgroup_dims)
1401           << "Sharding transpose should not remove subgroup dims.";
1402       skipped_tgt_dims++;
1403     } else {
1404       tgt_dims_skipping_new[i] = i - skipped_tgt_dims;
1405     }
1406   }
1407   int64_t skipped_src_dims = absl::c_count(src_to_tgt, -1);
1408   std::vector<int64_t> perm(src_to_tgt.size());
1409   for (int64_t i = 0; i < src_non_subgroup_dims; ++i) {
1410     if (src_to_tgt[i] < 0) {
1411       if (source.tile_assignment().dim(i) > 1) {
1412         return std::nullopt;
1413       }
1414       perm[src_non_subgroup_dims - skipped_src_dims] = i;
1415       skipped_src_dims--;
1416     } else {
1417       perm[tgt_dims_skipping_new[src_to_tgt[i]]] = i;
1418     }
1419   }
1420   skipped_src_dims = absl::c_count(src_to_tgt, -1);
1421   for (int64_t i = src_non_subgroup_dims; i < src_to_tgt.size(); ++i) {
1422     CHECK_GE(src_to_tgt[i], tgt_non_subgroup_dims)
1423         << "Sharding transpose should not move subgroup dims before data dims.";
1424     perm[src_to_tgt[i] - skipped_tgt_dims + skipped_src_dims] = i;
1425   }
1426   auto tgt_sharding = hlo_sharding_util::TransposeSharding(source, perm);
1427   auto reshape_tiles = tgt_sharding.tile_assignment();
1428   std::vector<int64_t> tgt_tiles(tgt_to_src.size(), 1);
1429   for (int64_t i = 0; i < tgt_tiles.size(); ++i) {
1430     if (tgt_to_src[i] >= 0) {
1431       int64_t dim = tgt_dims_skipping_new[i];
1432       if (i >= tgt_non_subgroup_dims) {
1433         dim += skipped_src_dims;
1434       }
1435       tgt_tiles[i] = reshape_tiles.dim(dim);
1436     }
1437   }
1438   reshape_tiles.Reshape(tgt_tiles);
1439   return source.ReplicateOnLastTileDim()
1440              ? HloSharding::PartialTile(reshape_tiles, source.metadata())
1441              : HloSharding::Subgroup(reshape_tiles, source.subgroup_types(),
1442                                      source.metadata());
1443 }
1444 
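// Returns the dimension along which `maybe_iota` behaves like an iota: an
// actual HloIotaInstruction, an S32 constant whose values equal the index
// along some non-degenerate dimension, or a broadcast of such a value.
// Returns std::nullopt otherwise.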
1445 std::optional<int64_t> GetDimensionForIota(const HloInstruction* maybe_iota) {
1446   if (auto* iota = DynCast<HloIotaInstruction>(maybe_iota)) {
1447     return iota->iota_dimension();
1448   }
1449 
1450   if (maybe_iota->shape().element_type() != S32) {
1451     return std::nullopt;
1452   }
1453   if (maybe_iota->IsConstant()) {
1454     std::vector<bool> is_iota_dim(maybe_iota->shape().rank(), true);
1455     maybe_iota->literal().EachCell<int32_t>(
1456         [&](absl::Span<const int64_t> indices, int32_t val) {
1457           for (int64_t i = 0; i < indices.size(); ++i) {
1458             if (val != indices[i]) {
1459               is_iota_dim[i] = false;
1460             }
1461           }
1462         });
1463     for (int64_t i = 0; i < is_iota_dim.size(); ++i) {
1464       if (is_iota_dim[i] && maybe_iota->shape().dimensions(i) > 1) {
1465         return i;
1466       }
1467     }
1468     return std::nullopt;
1469   }
1470 
1471   if (maybe_iota->opcode() == HloOpcode::kBroadcast) {
1472     auto operand_dim = GetDimensionForIota(maybe_iota->operand(0));
1473     if (operand_dim) {
1474       return maybe_iota->dimensions(*operand_dim);
1475     }
1476     return std::nullopt;
1477   }
1478   return std::nullopt;
1479 }
1480 
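// Detects gather index components that are produced by an iota (directly or
// via a concatenate along the index vector dimension) with slice size 1, so
// the gather is batch-parallel along the corresponding dimensions.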
1481 std::optional<GatherParallelDims> GetGatherBatchParallelDims(
1482     const HloInstruction& hlo) {
1483   const auto& dnums = hlo.gather_dimension_numbers();
1484   int64_t index_dim = dnums.index_vector_dim();
1485   // Try to identify whether there is a dimension in the indices that is
1486   // monotonically increasing via an Iota along a certain dimension. If so, the
1487   // accesses along the corresponding operand dimension (given by this index)
1488   // are parallelizable, and we can shard the operand (and the indices/output)
1489   // across that dimension.
1490   // For example the pattern:
1491   //   %iota.1 = iota()
1492   //   %indices = concatenate(..., %iota.1, ...)
1493   //   ... = gather(..., %indices)
1494   // is common for tf.reverse_sequence and would match this case.
1495   absl::InlinedVector<const HloIotaInstruction*, 4> iotas;
1496   const HloInstruction* indices = hlo.operand(1);
1497   const int num_indices = dnums.start_index_map_size();
1498   std::vector<int64_t> index_parallel_in_dim(num_indices, -1);
1499   // Handle cases where we concatenate pieces of the indices one at a time.
1500   if (indices->opcode() == HloOpcode::kConcatenate &&
1501       indices->concatenate_dimension() == index_dim) {
1502     int concatenated_dims = 0;
1503     for (int i = 0; i < indices->operand_count(); ++i) {
1504       const HloInstruction* op = indices->operand(i);
1505       const int64_t num_indices_from_element =
1506           op->shape().dimensions_size() > index_dim
1507               ? op->shape().dimensions(index_dim)
1508               : 1;
1509       if (std::optional<int64_t> maybe_iota_dim = GetDimensionForIota(op)) {
1510         if (*maybe_iota_dim != index_dim) {
1511           for (int j = 0; j < num_indices_from_element; ++j) {
1512             index_parallel_in_dim[concatenated_dims + j] = *maybe_iota_dim;
1513           }
1514         }
1515       }
1516       concatenated_dims += num_indices_from_element;
1517     }
1518   } else if (std::optional<int64_t> maybe_iota_dim =
1519                  GetDimensionForIota(indices)) {
1520     if (*maybe_iota_dim != index_dim) {
1521       // This is a case of a single iota with index_dim being out of bounds.
1522       const int64_t num_indices_from_element =
1523           indices->shape().dimensions_size() > index_dim
1524               ? indices->shape().dimensions(index_dim)
1525               : 1;
1526       index_parallel_in_dim.assign(num_indices_from_element, *maybe_iota_dim);
1527     }
1528   }
1529   absl::InlinedVector<int64_t, 1> indices_parallel_dims;
1530   absl::InlinedVector<int64_t, 1> operand_parallel_dims;
1531   // Map the parallelizable dimensions from the iota to the dimensions of the
1532   // output and the operand. These dimensions correspond to each other, but
1533   // they may sit at different positions in the operand and index shapes,
1534   // because the operand dimension matched by an index component is determined
1535   // by start_index_map.
1536   for (int i = 0; i < index_parallel_in_dim.size(); ++i) {
1537     int index_parallel_dim = index_parallel_in_dim[i];
1538     if (index_parallel_dim == -1) {
1539       continue;
1540     }
1541     if (absl::c_linear_search(indices_parallel_dims, index_parallel_dim)) {
1542       return std::nullopt;
1543     }
1544     // Considered parallel only if the slice is of size 1 over the operand.
1545     if (hlo.gather_slice_sizes()[dnums.start_index_map(i)] == 1) {
1546       indices_parallel_dims.push_back(index_parallel_dim);
1547       operand_parallel_dims.push_back(dnums.start_index_map(i));
1548     } else {
1549       index_parallel_in_dim[i] = -1;
1550     }
1551   }
1552   absl::c_sort(indices_parallel_dims);
1553   if (!indices_parallel_dims.empty()) {
1554     return GatherParallelDims{indices_parallel_dims, operand_parallel_dims,
1555                               index_parallel_in_dim};
1556   }
1557   return std::nullopt;
1558 }
1559 
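// Maps the parallel index dimensions in `parallel_dim` to the corresponding
// (non-offset) dimensions of the gather output.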
1560 absl::InlinedVector<int64_t, 1> GatherParallelOutputDims(
1561     const HloInstruction& gather, const GatherParallelDims& parallel_dim) {
1562   absl::InlinedVector<int64_t, 1> output_parallel_dims;
1563   auto indices_parallel_dims = parallel_dim.indices_parallel_dims;
1564   const Shape gather_shape = gather.shape();
1565   auto dnums = gather.gather_dimension_numbers();
1566   for (int i = 0, idx_dim = 0; i < gather_shape.dimensions_size(); ++i) {
1567     if (absl::c_linear_search(dnums.offset_dims(), i)) {
1568       continue;
1569     }
1570     const int index_dim =
1571         idx_dim < dnums.index_vector_dim() ? idx_dim : idx_dim + 1;
1572     if (absl::c_binary_search(indices_parallel_dims, index_dim)) {
1573       output_parallel_dims.push_back(i);
1574     }
1575     ++idx_dim;
1576   }
1577   return output_parallel_dims;
1578 }
1579 
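// Returns the operand parallel dimensions reordered to match the order of the
// parallel dimensions in the gather indices (and therefore in the output).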
1580 absl::InlinedVector<int64_t, 1> GatherOutputAlignedOperandParallelDims(
1581     const HloInstruction& gather, const GatherParallelDims& parallel_dims) {
1582   absl::InlinedVector<int64_t, 1> operand_parallel_dim_to_output(
1583       parallel_dims.operand_parallel_dims.size(), -1);
1584   auto dnums = gather.gather_dimension_numbers();
1585   CHECK_LE(parallel_dims.indices_parallel_dims.size(),
1586            parallel_dims.operand_parallel_dims.size());
1587   for (int i = 0; i < parallel_dims.index_parallel_in_dim.size(); ++i) {
1588     // This is the equivalent batch dimension of the indices that corresponds
1589     // to this index dimension.
1590     const int64_t index_parallel_dim = parallel_dims.index_parallel_in_dim[i];
1591     // If this index dimension is not parallel, skip it.
1592     if (index_parallel_dim == -1) {
1593       continue;
1594     }
1595     // This is small so just look linearly. Populate the operand parallel
1596     // dimensions based on the order of the index batch dims (which is the same
1597     // order as the output).
1598     for (int j = 0; j < parallel_dims.indices_parallel_dims.size(); ++j) {
1599       if (parallel_dims.indices_parallel_dims[j] == index_parallel_dim) {
1600         const int64_t operand_parallel_dim = dnums.start_index_map(i);
1601         if (operand_parallel_dim_to_output[j] == -1) {
1602           operand_parallel_dim_to_output[j] = operand_parallel_dim;
1603         }
1604         break;
1605       }
1606     }
1607   }
1608   return operand_parallel_dim_to_output;
1609 }
1610 
1611 std::string GroupedSharding::ToString() const {
1612   auto result = absl::StrCat("dims: ", absl::StrJoin(group_dims, ","), "\n",
1613                              "group dim sizes: ",
1614                              absl::StrJoin(group_dim_sizes, ","), "\n");
1615   absl::StrAppend(&result, "data rank: ", data_rank, "\n");
1616   absl::StrAppend(&result, "subgroup manual: ", subgroup_manual, "\n");
1617   absl::StrAppend(&result, "device_groups:\n");
1618   for (auto& device_group : device_groups) {
1619     absl::StrAppend(&result, "\t", absl::StrJoin(device_group, ","), "\n");
1620   }
1621   absl::StrAppend(&result, "inner sharding: ", sharding.ToString());
1622   return result;
1623 }
1624 
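// Groups a tiled sharding along `group_dims`, producing one device group per
// combination of tiles on those dimensions; this overload uses a group shard
// size of 1 for every grouped dimension. E.g. (illustrative) grouping a [2,2]
// tiling of devices 0,1,2,3 on dim 0 yields groups {0,1} and {2,3} with a
// per-group tiling of shape [1,2].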
1625 GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
1626                                     absl::Span<const int64_t> group_dims,
1627                                     bool subgroup_manual) {
1628   std::vector<int64_t> group_dim_shards(group_dims.size(), 1);
1629   return GroupShardingOnDims(sharding, group_dims, group_dim_shards,
1630                              subgroup_manual);
1631 }
1632 
1633 GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
1634                                     absl::Span<const int64_t> group_dims,
1635                                     absl::Span<const int64_t> group_dim_shards,
1636                                     bool subgroup_manual) {
1637   CHECK(!sharding.IsTileMaximal());
1638   std::vector<int64_t> grouped_tiling_dims =
1639       sharding.tile_assignment().dimensions();
1640   std::vector<int64_t> group_dim_sizes(group_dims.size());
1641   for (int64_t i = 0; i < group_dims.size(); ++i) {
1642     CHECK_EQ(grouped_tiling_dims[group_dims[i]] % group_dim_shards[i], 0);
1643     group_dim_sizes[i] =
1644         grouped_tiling_dims[group_dims[i]] / group_dim_shards[i];
1645     grouped_tiling_dims[group_dims[i]] = group_dim_shards[i];
1646   }
1647 
1648   std::vector<std::vector<int64_t>> device_groups(Product(group_dim_sizes));
1649   sharding.tile_assignment().Each([&](absl::Span<const int64_t> indices,
1650                                       int64_t device) {
1651     int64_t group_id = 0;
1652     for (int64_t i = 0; i < group_dims.size(); ++i) {
1653       group_id *=
1654           sharding.tile_assignment().dim(group_dims[i]) / group_dim_shards[i];
1655       group_id += indices[group_dims[i]] / group_dim_shards[i];
1656     }
1657     device_groups[group_id].push_back(device);
1658   });
1659   auto grouped = GroupedSharding(
1660       std::move(device_groups),
1661       std::vector<int64_t>(group_dims.begin(), group_dims.end()),
1662       std::move(group_dim_sizes), sharding.tile_assignment().num_dimensions(),
1663       HloSharding::Replicate(), subgroup_manual);
1664   if (sharding.ReplicateOnLastTileDim()) {
1665     grouped.data_rank--;
1666   }
1667   if (sharding.IsManualSubgroup()) {
1668     grouped.data_rank -= sharding.subgroup_types().size();
1669   }
1670   if (Product(grouped_tiling_dims) == 1 ||
1671       (sharding.ReplicateOnLastTileDim() &&
1672        Product(grouped_tiling_dims) == grouped_tiling_dims.back())) {
1673     return grouped;
1674   }
1675   if (sharding.IsManualSubgroup()) {
1676     int64_t tile_dimensions = sharding.tile_assignment().num_dimensions();
1677     int64_t subgroup_size = sharding.subgroup_types().size();
1678     int64_t rank = tile_dimensions - subgroup_size;
1679     int num_dims_erase = 0;
1680     for (int i = 0; i < subgroup_size; i++) {
1681       if (sharding.subgroup_types()[i] == OpSharding::MANUAL) {
1682         grouped_tiling_dims.erase(grouped_tiling_dims.begin() + i + rank -
1683                                   num_dims_erase);
1684         num_dims_erase++;
1685       }
1686     }
1687   }
1688   if (sharding.ReplicateOnLastTileDim() && grouped_tiling_dims.back() == 1) {
1689     grouped_tiling_dims.pop_back();
1690   }
1691   Array<int64_t> grouped_tiling(grouped_tiling_dims);
1692   grouped_tiling.FillIota(0);
1693   grouped.sharding =
1694       sharding.ReplicateOnLastTileDim() &&
1695               grouped_tiling_dims.size() ==
1696                   sharding.tile_assignment().num_dimensions()
1697           ? HloSharding::PartialTile(grouped_tiling, sharding.metadata())
1698           : HloSharding::Tile(grouped_tiling, sharding.metadata());
1699   return grouped;
1700 }
1701 
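// Groups a subgroup-manual sharding on its MANUAL subgroup dimensions so that
// the per-group sharding only carries the data (and replication) dimensions.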
1702 GroupedSharding GetManualSubgroupSharding(const HloSharding& sharding) {
1703   CHECK(sharding.IsManualSubgroup());
1704   int64_t tile_dimensions = sharding.tile_assignment().num_dimensions();
1705   int64_t subgroup_size = sharding.subgroup_types().size();
1706   int64_t rank = tile_dimensions - subgroup_size;
1707   std::vector<int64_t> group_dims;
1708   bool last_tile_dim_replicate = false;
1709 
1710   for (int64_t i = 0; i < subgroup_size; i++) {
1711     if (sharding.subgroup_types()[i] == OpSharding::MANUAL) {
1712       group_dims.push_back(rank + i);
1713     } else if (sharding.subgroup_types()[i] == OpSharding::REPLICATED) {
1714       last_tile_dim_replicate = true;
1715     }
1716   }
1717 
1718   GroupedSharding group_sharding =
1719       GroupShardingOnDims(sharding, group_dims, /*subgroup_manual=*/true);
1720 
1721   if (last_tile_dim_replicate ||
1722       group_sharding.sharding.tile_assignment().num_dimensions() > rank) {
1723     group_sharding.sharding = HloSharding::PartialTile(
1724         group_sharding.sharding.tile_assignment(), sharding.metadata());
1725   }
1726   return group_sharding;
1727 }
1728 
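// Reconstructs the full HloSharding from a GroupedSharding by expanding the
// per-group tile assignment with the device groups along the original group
// dimensions; this is the inverse of GroupShardingOnDims.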
1729 HloSharding UngroupSharding(const GroupedSharding& grouped_sharding) {
1730   std::vector<int64_t> tiling_dims;
1731   bool partial_sharding = false;
1732   std::vector<OpSharding::Type> subgroup_types;
1733   Array<int64_t> grouped_tiling = grouped_sharding.sharding.tile_assignment();
1734   if (grouped_sharding.sharding.IsTileMaximal()) {
1735     tiling_dims = std::vector<int64_t>(grouped_sharding.data_rank, 1);
1736     if (grouped_sharding.device_groups[0].size() != 1 ||
1737         absl::c_linear_search(grouped_sharding.group_dims,
1738                               tiling_dims.size())) {
1739       // This is partial sharding.
1740       tiling_dims.push_back(grouped_sharding.device_groups[0].size());
1741       partial_sharding = true;
1742     }
1743     grouped_tiling = Array<int64_t>(tiling_dims);
1744     grouped_tiling.FillIota(0);
1745   }
1746 
1747   // Handle the subgroup-manual case first.
1748   if (grouped_sharding.subgroup_manual) {
1749     partial_sharding = grouped_sharding.sharding.ReplicateOnLastTileDim() ||
1750                        grouped_sharding.sharding.IsReplicated();
1751     int64_t subgroup_dim_size = grouped_sharding.group_dims.size();
1752     if (partial_sharding) {
1753       subgroup_dim_size++;
1754     }
1755     subgroup_types = std::vector<OpSharding::Type>(subgroup_dim_size,
1756                                                    OpSharding::REPLICATED);
1757     if (!grouped_sharding.sharding.IsTileMaximal()) {
1758       tiling_dims = grouped_sharding.sharding.tile_assignment().dimensions();
1759     }
1760     for (int i = 0; i < grouped_sharding.group_dims.size(); i++) {
1761       subgroup_types[grouped_sharding.group_dims[i] -
1762                      grouped_sharding.data_rank] = OpSharding::MANUAL;
1763       tiling_dims.insert(tiling_dims.begin() + grouped_sharding.group_dims[i],
1764                          1);
1765     }
1766     grouped_tiling.Reshape(tiling_dims);
1767   } else if (!grouped_sharding.sharding.IsTileMaximal()) {
1768     // Handle tiled shardings, restoring partial replication if present.
1769     partial_sharding = grouped_sharding.sharding.ReplicateOnLastTileDim();
1770     tiling_dims = grouped_sharding.sharding.tile_assignment().dimensions();
1771     if (absl::c_linear_search(grouped_sharding.group_dims,
1772                               tiling_dims.size())) {
1773       tiling_dims.push_back(1);
1774       grouped_tiling.Reshape(tiling_dims);
1775       partial_sharding = true;
1776     }
1777   }
1778 
1779   // Update group dim sizes.
1780   for (int64_t i = 0; i < grouped_sharding.group_dims.size(); ++i) {
1781     int64_t dim = grouped_sharding.group_dims[i];
1782     tiling_dims[dim] *= grouped_sharding.group_dim_sizes[i];
1783   }
1784   Array<int64_t> tiling(tiling_dims);
1785   grouped_tiling.Each([&](absl::Span<const int64_t> indices, int64_t device) {
1786     std::vector<int64_t> ungrouped_inds(indices.begin(), indices.end());
1787     for (int64_t g = 0; g < grouped_sharding.device_groups.size(); ++g) {
1788       int64_t remaining_group_index = g;
1789       for (int64_t i = grouped_sharding.group_dims.size() - 1; i >= 0; --i) {
1790         int64_t dim = grouped_sharding.group_dims[i];
1791         int64_t groups_in_this_dim = grouped_sharding.group_dim_sizes[i];
1792         ungrouped_inds[dim] = (remaining_group_index % groups_in_this_dim) *
1793                                   grouped_tiling.dim(dim) +
1794                               indices[dim];
1795         remaining_group_index /= groups_in_this_dim;
1796       }
1797       tiling(ungrouped_inds) = grouped_sharding.device_groups[g][device];
1798     }
1799   });
1800 
1801   if (grouped_sharding.subgroup_manual) {
1802     return HloSharding::Subgroup(tiling, subgroup_types,
1803                                  grouped_sharding.sharding.metadata());
1804   }
1805   return partial_sharding ? HloSharding::PartialTile(tiling)
1806                           : HloSharding::Tile(tiling);
1807 }
1808 
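// Returns true if `lhs` and `rhs` partition the devices into the same groups;
// unless `ignore_group_order` is set, the groups must also appear in the same
// order.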
1809 bool DeviceGroupsAreMatch(GroupedSharding& lhs, GroupedSharding& rhs,
1810                           bool ignore_group_order) {
1811   if (lhs.device_groups.size() != rhs.device_groups.size()) {
1812     return false;
1813   }
1814 
1815   bool matching_groups = true;
1816   absl::flat_hash_map<int64_t, int64_t> device_to_ref_group;
1817   for (int64_t g = 0; g < lhs.device_groups.size(); ++g) {
1818     for (int64_t device : lhs.device_groups[g]) {
1819       device_to_ref_group[device] = g;
1820     }
1821   }
1822   auto unique_ref_dev_group =
1823       [&](absl::Span<const int64_t> devices) -> int64_t {
1824     int64_t ref_g = -1;
1825     for (int64_t device : devices) {
1826       if (ref_g == -1) {
1827         ref_g = device_to_ref_group[device];
1828       } else if (ref_g != device_to_ref_group[device]) {
1829         return -1;
1830       }
1831     }
1832     return ref_g;
1833   };
1834   for (int64_t g = 0; g < rhs.device_groups.size(); ++g) {
1835     int64_t ref_g = unique_ref_dev_group(rhs.device_groups[g]);
1836     if (ref_g < 0 || (!ignore_group_order && g != ref_g)) {
1837       matching_groups = false;
1838       break;
1839     }
1840   }
1841 
1842   return matching_groups;
1843 }
1844 
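// Splits tile dimension `dimension` into two adjacent dimensions, mirroring a
// reshape that splits the corresponding shape dimension; `new_dim_size` must
// divide the current tile count on that dimension.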
1845 HloSharding SplitShardingDimension(const HloSharding& sharding,
1846                                    int64_t dimension, int64_t new_dim_size) {
1847   CHECK_GT(sharding.TiledDataRank(), dimension);
1848   CHECK_EQ(sharding.tile_assignment().dim(dimension) % new_dim_size, 0)
1849       << "dim size " << new_dim_size;
1850   auto new_tile_assignment = sharding.tile_assignment();
1851   std::vector<int64_t> dimensions = new_tile_assignment.dimensions();
1852   int64_t current_dimension = dimensions[dimension];
1853   dimensions.insert(dimensions.begin() + dimension + 1,
1854                     current_dimension / new_dim_size);
1855   dimensions[dimension] = new_dim_size;
1856   new_tile_assignment.Reshape(dimensions);
1857   auto new_sharding = sharding.ReplicateOnLastTileDim()
1858                           ? HloSharding::PartialTile(new_tile_assignment)
1859                           : HloSharding::Subgroup(new_tile_assignment,
1860                                                   sharding.subgroup_types());
1861   std::vector<int64_t> permutation(new_sharding.tile_assignment().dimensions());
1862   absl::c_iota(permutation, 0);
1863   std::swap(permutation[dimension], permutation[dimension + 1]);
1864   return TransposeSharding(new_sharding, permutation);
1865 }
1866 
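// Merges tile dimensions `dimension` and `dimension + 1` into one, mirroring a
// reshape that collapses the corresponding shape dimensions; the inverse of
// SplitShardingDimension.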
1867 HloSharding MergeShardingDimension(const HloSharding& sharding,
1868                                    int64_t dimension) {
1869   CHECK_GT(sharding.TiledDataRank(), dimension);
1870   std::vector<int64_t> permutation(sharding.tile_assignment().dimensions());
1871   absl::c_iota(permutation, 0);
1872   std::swap(permutation[dimension], permutation[dimension + 1]);
1873   auto transposed_sharding = TransposeSharding(sharding, permutation);
1874   auto new_tile_assignment = transposed_sharding.tile_assignment();
1875   std::vector<int64_t> dimensions = new_tile_assignment.dimensions();
1876   dimensions[dimension] *= dimensions[dimension + 1];
1877   dimensions.erase(dimensions.begin() + dimension + 1);
1878   new_tile_assignment.Reshape(dimensions);
1879   return sharding.ReplicateOnLastTileDim()
1880              ? HloSharding::PartialTile(new_tile_assignment)
1881              : HloSharding::Subgroup(new_tile_assignment,
1882                                      sharding.subgroup_types());
1883 }
1884 
1885 }  // namespace hlo_sharding_util
1886 }  // namespace xla
1887