/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_sharding.h"

#include <algorithm>
#include <iterator>
#include <map>
#include <numeric>
#include <optional>
#include <ostream>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/compiler/xla/service/hlo_op_metadata.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

using absl::StrCat;
using absl::StrJoin;

HloSharding HloSharding::AssignDevice(int64_t device_id,
                                      absl::Span<const OpMetadata> metadata) {
  return HloSharding(device_id, metadata);
}

HloSharding HloSharding::Tile1D(const Shape& input_shape, int64_t num_tiles,
                                absl::Span<const OpMetadata> metadata) {
  CHECK_EQ(1, input_shape.rank());
  CHECK_GT(num_tiles, 1);
  std::vector<int64_t> dimensions(1, num_tiles);
  Array<int64_t> assignment(dimensions);
  std::iota(assignment.begin(), assignment.end(), 0);
  return HloSharding(assignment, /*replicate_on_last_tile_dim=*/false,
                     metadata);
}
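
// Example (illustrative sketch, not exercised in this file): Tile1D() assigns
// devices 0..num_tiles-1 to consecutive slices of the single dimension. It
// relies on ShapeUtil helpers and the factory defaults declared in
// hlo_sharding.h.
//
//   Shape shape = ShapeUtil::MakeShape(F32, {16});
//   HloSharding s = HloSharding::Tile1D(shape, /*num_tiles=*/4);
//   // Expected: s.ToString() == "{devices=[4]0,1,2,3}"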

HloSharding HloSharding::PartialTile(
    const Array<int64_t>& group_tile_assignment,
    absl::Span<const absl::Span<const int64_t>> replication_groups,
    absl::Span<const OpMetadata> metadata) {
  CHECK_EQ(group_tile_assignment.num_elements(), replication_groups.size());
  if (replication_groups.size() == 1) {
    return Replicate(metadata);
  }
  auto new_tile_dims = group_tile_assignment.dimensions();
  new_tile_dims.push_back(replication_groups[0].size());
  auto new_tile_assignment = Array<int64_t>(new_tile_dims);
  new_tile_assignment.Each(
      [&](absl::Span<const int64_t> indices, int64_t* device) {
        std::vector<int64_t> group_index(indices.begin(), indices.end());
        group_index.pop_back();
        int64_t group = group_tile_assignment(group_index);
        *device = replication_groups[group][indices.back()];
      });
  return PartialTile(new_tile_assignment, metadata);
}
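
// Example (illustrative sketch, not exercised in this file): two tile groups
// with two devices each are expanded into an extra trailing tile dimension
// before delegating to the overload below, which then sorts the devices
// inside each replication group. Array construction mirrors the pattern used
// in Tile1D() above.
//
//   std::vector<int64_t> group_dims = {2};
//   Array<int64_t> group_tiles(group_dims);                // shape [2]
//   std::iota(group_tiles.begin(), group_tiles.end(), 0);  // groups 0 and 1
//   std::vector<int64_t> g0 = {1, 0};
//   std::vector<int64_t> g1 = {3, 2};
//   std::vector<absl::Span<const int64_t>> groups = {g0, g1};
//   HloSharding s = HloSharding::PartialTile(group_tiles, groups);
//   // Expected: s.ToString() ==
//   //     "{devices=[2,2]0,1,2,3 last_tile_dim_replicate}"
//   // (the unsorted groups {1,0} and {3,2} get sorted by the overload below).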

HloSharding HloSharding::PartialTile(
    const Array<int64_t>& tile_assignment_last_dim_replicate,
    absl::Span<const OpMetadata> metadata) {
  if (tile_assignment_last_dim_replicate.num_dimensions() == 1 ||
      tile_assignment_last_dim_replicate.dimensions().back() ==
          tile_assignment_last_dim_replicate.num_elements()) {
    return Replicate(metadata);
  }
  if (tile_assignment_last_dim_replicate.dimensions().back() == 1) {
    auto new_tile_dims = tile_assignment_last_dim_replicate.dimensions();
    new_tile_dims.pop_back();
    auto fully_tiled = tile_assignment_last_dim_replicate;
    fully_tiled.Reshape(new_tile_dims);
    return HloSharding(fully_tiled, /*replicate_on_last_tile_dim=*/false,
                       metadata);
  }
  std::vector<std::set<int64_t>> sorted_groups(
      tile_assignment_last_dim_replicate.num_elements() /
      tile_assignment_last_dim_replicate.dimensions().back());
  auto get_group_id = [&](absl::Span<const int64_t> indices) {
    int64_t group_id = 0;
    for (int64_t i = 0; i < indices.size() - 1; ++i) {
      group_id *= tile_assignment_last_dim_replicate.dim(i);
      group_id += indices[i];
    }
    return group_id;
  };
  tile_assignment_last_dim_replicate.Each(
      [&](absl::Span<const int64_t> indices, const int64_t device) {
        sorted_groups[get_group_id(indices)].insert(device);
      });
  Array<int64_t> sorted_tile(tile_assignment_last_dim_replicate.dimensions());
  sorted_tile.Each([&](absl::Span<const int64_t> indices, int64_t* device) {
    const int64_t group_id = get_group_id(indices);
    auto begin = sorted_groups[group_id].begin();
    *device = *begin;
    sorted_groups[group_id].erase(begin);
  });
  return HloSharding(sorted_tile, /*replicate_on_last_tile_dim=*/true,
                     metadata);
}
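
// Example (illustrative sketch, not exercised in this file) of the
// canonicalizations above: a last dimension covering every device collapses
// to Replicate(), a last dimension of size 1 collapses to a plain tiled
// sharding, and otherwise the devices within each replication group are
// sorted.
//
//   std::vector<int64_t> dims = {2, 2};
//   Array<int64_t> tiles(dims);
//   tiles(0, 0) = 3; tiles(0, 1) = 2;
//   tiles(1, 0) = 1; tiles(1, 1) = 0;
//   HloSharding s = HloSharding::PartialTile(tiles);
//   // Expected: s.ToString() ==
//   //     "{devices=[2,2]2,3,0,1 last_tile_dim_replicate}"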

HloSharding HloSharding::Subgroup(
    const Array<int64_t>& tile_assignment,
    absl::Span<const OpSharding::Type> subgroup_types,
    absl::Span<const OpMetadata> metadata) {
  if (subgroup_types.empty()) {
    return HloSharding(tile_assignment, /*replicate_on_last_tile_dim=*/false,
                       metadata);
  }
  // If there is only one type of subgrouping and there is no tiling on data
  // dimensions, it can be canonicalized to a simple manual/replicated sharding.
  if (absl::c_all_of(
          subgroup_types,
          [&](const OpSharding::Type t) { return t == subgroup_types[0]; }) &&
      Product(absl::Span<const int64_t>(tile_assignment.dimensions())
                  .subspan(0, tile_assignment.num_dimensions() -
                                  subgroup_types.size())) == 1) {
    if (subgroup_types[0] == OpSharding::MANUAL) {
      return Manual(metadata);
    }
    if (subgroup_types[0] == OpSharding::REPLICATED) {
      return Replicate(metadata);
    }
  }
  // Normalize the subgroups:
  //   - Remove trivial dims of size 1.
  //   - Merge dims of the same type.
  //   - Sort types.
  int64_t data_dims = tile_assignment.num_dimensions() - subgroup_types.size();
  std::vector<int64_t> perm(data_dims);
  std::iota(perm.begin(), perm.end(), 0);
  // Make sure the replicate dims are at the end so that we can leverage
  // PartialTile() to sort the elements.
  struct CmpTypeReplicateLast {
    bool operator()(OpSharding::Type a, OpSharding::Type b) const {
      if (a == b) {
        return false;
      }
      if (a == OpSharding::REPLICATED) {
        return false;
      }
      if (b == OpSharding::REPLICATED) {
        return true;
      }
      return a < b;
    }
  };
  std::map<OpSharding::Type, std::vector<int64_t>, CmpTypeReplicateLast>
      type_to_dims;
  bool needs_merging = false;
  for (int64_t i = 0; i < subgroup_types.size(); ++i) {
    if (tile_assignment.dim(i + data_dims) == 1) {
      needs_merging = true;
      continue;
    }
    auto& dims = type_to_dims[subgroup_types[i]];
    needs_merging |= !dims.empty();
    dims.push_back(i + data_dims);
  }
  needs_merging |= type_to_dims.size() > 1;
  auto create_sharding = [](const Array<int64_t> tiles,
                            absl::Span<const OpSharding::Type> types,
                            absl::Span<const OpMetadata> metadata) {
    if (types.size() == 1 && types.back() == OpSharding::REPLICATED) {
      // Normalize to partial tile.
      return PartialTile(tiles, metadata);
    }
    if (types.size() == 1 && types.back() == OpSharding::MANUAL &&
        tiles.num_elements() == tiles.dimensions().back()) {
      // Normalize to manual.
      return Manual(metadata);
    }
    if (!types.empty() && types.back() == OpSharding::REPLICATED) {
      // If the last type is REPLICATED, we first create a partially replicated
      // sharding without other subgroups so that the elements are sorted. Then
      // we fix the subgroup types.
      HloSharding sharding = PartialTile(tiles, metadata);
      sharding.replicate_on_last_tile_dim_ = false;
      for (const OpSharding::Type type : types) {
        sharding.subgroup_types_.push_back(type);
      }
      return sharding;
    }
    return HloSharding(tiles, types, metadata);
  };
  if (needs_merging) {
    auto data_tile_shape =
        absl::Span<const int64_t>(tile_assignment.dimensions())
            .subspan(0, data_dims);
    std::vector<int64_t> merged_shape(data_tile_shape.begin(),
                                      data_tile_shape.end());
    std::vector<int64_t> transposed_shape = merged_shape;
    std::vector<OpSharding::Type> merged_types;
    for (const auto& type_dims : type_to_dims) {
      int64_t dim_size = 1;
      for (int64_t dim : type_dims.second) {
        perm.push_back(dim);
        dim_size *= tile_assignment.dim(dim);
        transposed_shape.push_back(tile_assignment.dim(dim));
      }
      merged_shape.push_back(dim_size);
      merged_types.push_back(type_dims.first);
    }
    Array<int64_t> new_tiles(transposed_shape);
    new_tiles.Each([&](absl::Span<const int64_t> indices, int64_t* value) {
      std::vector<int64_t> src_indices(tile_assignment.num_dimensions(), 0);
      for (int64_t i = 0; i < indices.size(); ++i) {
        src_indices[perm[i]] = indices[i];
      }
      *value = tile_assignment(src_indices);
    });
    new_tiles.Reshape(merged_shape);
    return create_sharding(new_tiles, merged_types, metadata);
  }
  return create_sharding(tile_assignment, subgroup_types, metadata);
}
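
// Example (illustrative sketch, not exercised in this file): two REPLICATED
// subgroup dimensions of size 2 behind a data dimension of size 2 should be
// merged into a single replication dimension and normalized to a partially
// replicated sharding.
//
//   std::vector<int64_t> dims = {2, 2, 2};  // 1 data dim + 2 subgroup dims
//   Array<int64_t> tiles(dims);
//   std::iota(tiles.begin(), tiles.end(), 0);
//   HloSharding s = HloSharding::Subgroup(
//       tiles, {OpSharding::REPLICATED, OpSharding::REPLICATED});
//   // Expected: s.ToString() ==
//   //     "{devices=[2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}"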

HloSharding HloSharding::Tuple(const ShapeTree<HloSharding>& sub_shardings) {
  std::vector<HloSharding> flattened_list;
  flattened_list.reserve(sub_shardings.leaf_count());
  for (const auto& index_to_sharding : sub_shardings.leaves()) {
    flattened_list.push_back(index_to_sharding.second);
  }
  if (flattened_list.empty()) {
    // Empty tuple sharding ends up having no leaves, but we want to allow
    // empty tuple HLO instruction results to have sharding, so we fetch the
    // root ({}) sharding value from the ShapeTree.
    // A ShapeTree created with ShapeTree<HloSharding>(shape, init) will have
    // init as value at its root.
    flattened_list.push_back(sub_shardings.element(ShapeIndex({})));
  }
  return HloSharding(flattened_list);
}

HloSharding HloSharding::Tuple(const Shape& tuple_shape,
                               absl::Span<const HloSharding> shardings) {
  CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape);
  for (auto& sharding : shardings) {
    CHECK(!sharding.IsTuple())
        << sharding.ToString() << ShapeUtil::HumanString(tuple_shape);
  }
  std::vector<HloSharding> flattened_list(shardings.begin(), shardings.end());
  CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape))
      << "Flat list has " << flattened_list.size() << ", required "
      << RequiredLeaves(tuple_shape);
  return HloSharding(flattened_list);
}

HloSharding HloSharding::SingleTuple(const Shape& tuple_shape,
                                     const HloSharding& sharding) {
  CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape);
  CHECK(!sharding.IsTuple()) << sharding.ToString();
  int64_t leaf_count = RequiredLeaves(tuple_shape);
  std::vector<HloSharding> flattened_list;
  flattened_list.resize(leaf_count, sharding);
  return HloSharding(flattened_list);
}

HloSharding HloSharding::Single(const Shape& shape,
                                const HloSharding& sharding) {
  return shape.IsTuple() ? SingleTuple(shape, sharding) : sharding;
}

std::string HloSharding::ToString(bool include_metadata) const {
  if (IsTuple()) {
    CHECK(metadata_.empty());
    std::string result = "{";
    for (int i = 0; i < tuple_elements_.size(); ++i) {
      const HloSharding& element = tuple_elements_[i];
      if (i != 0) {
        absl::StrAppend(&result, ", ");
        if (i % 5 == 0) {
          absl::StrAppend(&result, "/*index=", i, "*/");
        }
      }
      absl::StrAppend(&result, element.ToString(include_metadata));
    }
    absl::StrAppend(&result, "}");
    return result;
  }

  std::string metadata;
  if (include_metadata) {
    if (metadata_.size() == 1) {
      metadata =
          StrCat(" metadata={", OpMetadataToString(metadata_.front()), "}");
    } else if (metadata_.size() > 1) {
      std::vector<std::string> metadata_strings;
      metadata_strings.reserve(metadata_.size());
      for (const auto& single_metadata : metadata_) {
        metadata_strings.push_back(
            StrCat("{", OpMetadataToString(single_metadata), "}"));
      }
      metadata = StrCat(" metadata={", StrJoin(metadata_strings, ", "), "}");
    }
  }

  std::string last_tile_dims;
  if (!subgroup_types_.empty()) {
    auto op_sharding_type_to_string = [](OpSharding::Type type) {
      switch (type) {
        case OpSharding::MANUAL:
          return "manual";
        case OpSharding::MAXIMAL:
          return "maximal";
        case OpSharding::REPLICATED:
          return "replicated";
        default:
          return "error_type.";
      }
    };
    std::vector<std::string> sharding_type_strings;
    sharding_type_strings.reserve(subgroup_types_.size());
    for (const auto& single_sharding_type : subgroup_types_) {
      sharding_type_strings.push_back(
          op_sharding_type_to_string(single_sharding_type));
    }
    last_tile_dims =
        StrCat(" last_tile_dims={", StrJoin(sharding_type_strings, ", "), "}");
  }

  if (replicated_) {
    return StrCat("{replicated", metadata, "}");
  }

  if (manual_) {
    return StrCat("{manual", metadata, "}");
  }
  if (maximal_) {
    return StrCat(
        "{maximal device=", static_cast<int64_t>(*tile_assignment_.begin()),
        metadata, "}");
  }
  return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), "]",
                StrJoin(tile_assignment_, ","),
                replicate_on_last_tile_dim_ ? " last_tile_dim_replicate" : "",
                last_tile_dims, metadata, "}");
}
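
// Example (illustrative sketch, not exercised in this file) of the string
// forms produced above, using the Tile() and AssignDevice() factories from
// hlo_sharding.h:
//
//   std::vector<int64_t> dims = {2, 2};
//   Array<int64_t> tiles(dims);
//   std::iota(tiles.begin(), tiles.end(), 0);
//   HloSharding tiled = HloSharding::Tile(tiles);
//   HloSharding maximal = HloSharding::AssignDevice(3);
//   // Expected: tiled.ToString()   == "{devices=[2,2]0,1,2,3}"
//   // Expected: maximal.ToString() == "{maximal device=3}"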

bool HloSharding::UsesDevice(int64_t device) const {
  if (IsTuple()) {
    return absl::c_any_of(tuple_elements_, [&](const HloSharding& s) {
      return s.UsesDevice(device);
    });
  }
  const auto& devices = tile_assignment_;
  return replicated_ || manual_ || absl::c_linear_search(devices, device);
}

std::map<int64_t, int64_t> HloSharding::UsedDevices(int64_t* count) const {
  int64_t element_count = 1;
  std::map<int64_t, int64_t> device_map;
  if (IsTuple()) {
    for (auto& tuple_element_sharding : tuple_elements()) {
      auto unique_device = tuple_element_sharding.UniqueDevice();
      if (unique_device) {
        device_map[*unique_device] += 1;
      }
    }
    element_count = tuple_elements().size();
  } else {
    auto unique_device = UniqueDevice();
    if (unique_device) {
      device_map[*unique_device] += 1;
    }
  }
  if (count != nullptr) {
    *count = element_count;
  }
  return device_map;
}

std::vector<int64_t> HloSharding::TileIndexForDevice(int64_t device) const {
  CHECK(!maximal_);
  CHECK(!manual_);
  CHECK(!IsTuple());
  std::vector<int64_t> ret_index;
  tile_assignment_.Each([&](absl::Span<const int64_t> index, int64_t d) {
    if (d == device) {
      ret_index = {index.begin(), index.end()};
    }
  });
  CHECK(!ret_index.empty());
  ret_index.resize(TiledDataRank());
  return ret_index;
}

int64_t HloSharding::DeviceForTileIndex(absl::Span<const int64_t> index) const {
  CHECK(!replicated_);
  CHECK(!manual_);
  CHECK(!IsTuple());
  if (maximal_) {
    return *tile_assignment_.begin();
  }
  if (index.size() == TiledDataRank() &&
      index.size() < tile_assignment_.num_dimensions()) {
    std::vector<int64_t> first_subgroup_index(index.begin(), index.end());
    for (int64_t i = 0; i < tile_assignment_.num_dimensions() - index.size();
         ++i) {
      first_subgroup_index.push_back(0);
    }
    return tile_assignment_(first_subgroup_index);
  }
  return tile_assignment_(index);
}
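
// Example (illustrative sketch, not exercised in this file): for a partially
// replicated sharding, a tile index covering only the data dimensions is
// padded with zeros, so the first device of the replication subgroup is
// returned.
//
//   std::vector<int64_t> dims = {2, 2};  // [data dim, replication dim]
//   Array<int64_t> tiles(dims);
//   std::iota(tiles.begin(), tiles.end(), 0);
//   HloSharding s = HloSharding::PartialTile(tiles);
//   // Expected: s.DeviceForTileIndex({1}) == 2, the first device of the
//   // replication group {2, 3}.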

std::vector<int64_t> HloSharding::TileOffsetForDevice(const Shape& shape,
                                                      int64_t device) const {
  CHECK(!IsTuple());
  CHECK(!manual_);

  if (maximal_) {
    return std::vector<int64_t>(shape.dimensions_size(), 0);
  }
  CHECK_EQ(shape.dimensions_size(), TiledDataRank());
  std::vector<int64_t> index = TileIndexForDevice(device);
  for (int64_t i = 0; i < index.size(); ++i) {
    const int64_t shape_dim = shape.dimensions(i);
    index[i] = std::min(
        index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim);
  }
  return index;
}

std::vector<int64_t> HloSharding::TileLimitForDevice(const Shape& shape,
                                                     int64_t device) const {
  CHECK(!IsTuple());
  CHECK(!manual_);

  if (maximal_) {
    return std::vector<int64_t>(shape.dimensions().begin(),
                                shape.dimensions().end());
  }

  CHECK_EQ(shape.dimensions_size(), TiledDataRank());
  std::vector<int64_t> index = TileIndexForDevice(device);
  for (int64_t i = 0; i < index.size(); ++i) {
    const int64_t shape_dim = shape.dimensions(i);
    index[i] = std::min(
        (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)),
        shape_dim);
  }
  return index;
}
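
// Example (illustrative sketch, not exercised in this file): with a dimension
// of size 5 split over 2 tiles, each tile nominally covers
// CeilOfRatio(5, 2) == 3 elements and the limit is clamped to the shape, so
// the last device gets the smaller remainder tile.
//
//   std::vector<int64_t> dims = {2};
//   Array<int64_t> tiles(dims);
//   std::iota(tiles.begin(), tiles.end(), 0);
//   HloSharding s = HloSharding::Tile(tiles);
//   Shape shape = ShapeUtil::MakeShape(F32, {5});
//   // Expected: s.TileOffsetForDevice(shape, 1) returns {3}
//   // Expected: s.TileLimitForDevice(shape, 1) returns {5}  (clamped from 6)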

int64_t HloSharding::RequiredLeaves(const Shape& shape) {
  // Empty tuples (with arbitrary nesting) have no leaf nodes as far as
  // ShapeUtil and ShapeTree are concerned, but they do have a single
  // tuple_elements_ entry since we want to allow empty tuple results to
  // have sharding.
  const int64_t leaf_count = ShapeUtil::GetLeafCount(shape);
  return (leaf_count == 0) ? 1 : leaf_count;
}

Status HloSharding::CheckLeafCount(const Shape& shape) const {
  int64_t leaf_count = ShapeUtil::GetLeafCount(shape);
  if (leaf_count == 0 && tuple_elements_.size() == 1) {
    // Allow (but don't require) empty tuples to have a single sharding
    return OkStatus();
  }
  TF_RET_CHECK(leaf_count == tuple_elements_.size())
      << "Shape " << ShapeUtil::HumanString(shape) << " has " << leaf_count
      << " leaf nodes while this sharding has " << tuple_elements_.size();
  return OkStatus();
}

StatusOr<ShapeTree<HloSharding>> HloSharding::AsShapeTree(
    const Shape& shape) const {
  if (IsTuple()) {
    ShapeTree<HloSharding> result(shape, HloSharding::Replicate());
    TF_RETURN_IF_ERROR(CheckLeafCount(shape));
    auto it = tuple_elements_.begin();
    for (auto& index_to_sharding : result.leaves()) {
      index_to_sharding.second = *it++;
    }
    if (ShapeUtil::IsEmptyTuple(shape)) {
      // Empty tuples have no leaves, but we want to assign them a sharding
      // anyway, so we use the root element sharding.
      *result.mutable_element(ShapeIndex({})) = *it;
    }
    return std::move(result);
  } else {
    return ShapeTree<HloSharding>(shape, *this);
  }
}

StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
  if (IsTuple()) {
    TF_RETURN_IF_ERROR(CheckLeafCount(shape));
    return *this;
  }
  return Tuple(ShapeTree<HloSharding>(shape, *this));
}

std::optional<int64_t> HloSharding::UniqueDevice() const {
  if (IsTuple()) {
    if (tuple_elements_.empty()) {
      return std::nullopt;
    }
    std::optional<int64_t> unique_device;
    for (auto& tuple_sharding : tuple_elements_) {
      auto device = tuple_sharding.UniqueDevice();
      if (!device || (unique_device && *device != *unique_device)) {
        return std::nullopt;
      }
      unique_device = device;
    }
    return unique_device;
  }
  if (!replicated_ && maximal_) {
    return static_cast<int64_t>(*tile_assignment_.begin());
  }
  return std::nullopt;
}

int64_t HloSharding::GetUniqueDevice() const {
  auto device = UniqueDevice();
  CHECK(device) << "Sharding does not have a unique device: " << *this;
  return *device;
}

Status HloSharding::ValidateTuple(const Shape& shape,
                                  int64_t num_devices) const {
  if (!shape.IsTuple()) {
    return tensorflow::errors::InvalidArgument(
        StrCat("Sharding is tuple-shaped but validation shape is not."));
  }
  TF_RETURN_IF_ERROR(CheckLeafCount(shape));
  if (ShapeUtil::GetLeafCount(shape) == 0 && tuple_elements_.empty()) {
    // Empty tuples are allowed to not have sharding
    return OkStatus();
  }

  // Now we've validated the number of tuple elements, it's safe to request a
  // shape tree.
  ShapeTree<HloSharding> shape_tree = GetAsShapeTree(shape);
  for (const auto& index_to_sharding : shape_tree.leaves()) {
    Status status = index_to_sharding.second.ValidateNonTuple(
        ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices);
    if (!status.ok()) {
      tensorflow::errors::AppendToMessage(
          &status, StrCat("Note: While validating sharding tuple element ",
                          index_to_sharding.first.ToString(), " which is ",
                          index_to_sharding.second.ToString()));
      return status;
    }
  }
  return OkStatus();
}

Status HloSharding::Validate(const Shape& shape, int64_t num_devices) const {
  if (shape.IsToken()) {
    return OkStatus();
  }
  Status status = IsTuple() ? ValidateTuple(shape, num_devices)
                            : ValidateNonTuple(shape, num_devices);
  if (!status.ok()) {
    tensorflow::errors::AppendToMessage(
        &status, StrCat("Note: While validating sharding ", ToString(),
                        " against shape ", ShapeUtil::HumanString(shape)));
  }
  return status;
}

Status HloSharding::ValidateNonTuple(const Shape& shape,
                                     int64_t num_devices) const {
  if (shape.IsTuple()) {
    return tensorflow::errors::InvalidArgument(
        StrCat("Validation shape is a tuple but sharding is not."));
  }
  if (replicated_) {
    return OkStatus();
  }

  // All tile assignments must be less than the number of available cores and
  // unique.
  Status status = OkStatus();
  absl::flat_hash_set<int64_t> seen_cores;
  tile_assignment_.Each([&](absl::Span<const int64_t> indices, int32_t core) {
    // Don't overwrite a bad status, so we report the first error.
    if (status.ok()) {
      if (core >= num_devices) {
        status = tensorflow::errors::InvalidArgument(
            StrCat("core ", core, " >= ", num_devices, " in tile assignment"));
      } else if (seen_cores.contains(core)) {
        status = tensorflow::errors::InvalidArgument(
            StrCat("core ", core, " is not unique in tile assignment"));
      }
      seen_cores.insert(core);
    }
  });
  if (!status.ok()) {
    return status;
  }

  if (IsTileMaximal() || IsManual()) {
    return OkStatus();
  }

  // The tile assignment tensor must have the same rank as the input, or input
  // rank + 1 for replicate_on_last_tile_dim_.
  if (shape.rank() + (replicate_on_last_tile_dim_ ? 1 : 0) +
          subgroup_types_.size() !=
      tile_assignment_.num_dimensions()) {
    return tensorflow::errors::InvalidArgument(
        "Number of tile assignment dimensions is different to the input rank. "
        "sharding=",
        ToString(), ", input_shape=", ShapeUtil::HumanString(shape));
  }

  // The correct constructor has to be used to create tile maximal shardings.
  if (tile_assignment_.num_elements() == 1) {
    return tensorflow::errors::InvalidArgument(
        "Tile assignment only contains a single device. If a replicated "
        "sharding was intended, use HloSharding::Replicated(). If a device "
        "placement was intended, use HloSharding::AssignDevice()");
  }
  return OkStatus();
}
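
// Example (illustrative sketch, not exercised in this file): validation
// rejects tile assignments that reference out-of-range devices or whose rank
// does not match the validated shape.
//
//   std::vector<int64_t> dims = {2, 2};
//   Array<int64_t> tiles(dims);
//   std::iota(tiles.begin(), tiles.end(), 0);
//   HloSharding s = HloSharding::Tile(tiles);
//   Shape matrix = ShapeUtil::MakeShape(F32, {8, 8});
//   Shape rank1 = ShapeUtil::MakeShape(F32, {8});
//   // Expected: s.Validate(matrix, /*num_devices=*/4).ok() is true.
//   // Expected: s.Validate(matrix, /*num_devices=*/2) fails (devices 2 and 3
//   //           are out of range).
//   // Expected: s.Validate(rank1, /*num_devices=*/4) fails (rank mismatch).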

/*static*/ StatusOr<HloSharding> HloSharding::FromProto(
    const OpSharding& proto) {
  std::vector<OpMetadata> metadata(proto.metadata().begin(),
                                   proto.metadata().end());
  std::vector<int> subgroup_types_int(proto.last_tile_dims().begin(),
                                      proto.last_tile_dims().end());
  std::vector<OpSharding::Type> subgroup_types;
  absl::c_transform(
      subgroup_types_int, std::back_inserter(subgroup_types),
      [](const int type) { return static_cast<OpSharding::Type>(type); });
  if (proto.type() == OpSharding::TUPLE) {
    TF_RET_CHECK(metadata.empty())
        << "Tuple sharding is expected to have no metadata.";
    std::vector<HloSharding> tuple_shardings;
    tuple_shardings.reserve(proto.tuple_shardings().size());
    for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) {
      TF_ASSIGN_OR_RETURN(HloSharding sharding,
                          HloSharding::FromProto(tuple_sharding_proto));
      tuple_shardings.push_back(sharding);
    }
    return HloSharding(tuple_shardings);
  } else if (proto.type() == OpSharding::REPLICATED) {
    return Replicate(metadata);
  } else if (proto.type() == OpSharding::MANUAL) {
    return Manual(metadata);
  } else if (proto.tile_assignment_devices().size() == 1) {
    return HloSharding(proto.tile_assignment_devices(0), metadata);
  }

  TF_RET_CHECK(proto.type() != OpSharding::MAXIMAL)
      << "Maximal sharding is expected to have single device assignment, but "
      << proto.tile_assignment_devices().size() << " devices were provided.";

  TF_RET_CHECK(proto.tile_assignment_devices().size() > 1);
  TF_RET_CHECK(!proto.tile_assignment_dimensions().empty());

  // RE: the product of tile assignment tensor dimensions must be
  // equal to tile_assignment_devices.size().
  int64_t product_of_dimensions = 1;
  for (auto dimension : proto.tile_assignment_dimensions()) {
    TF_RET_CHECK(dimension > 0);
    product_of_dimensions =
        MultiplyWithoutOverflow(product_of_dimensions, dimension);
    TF_RET_CHECK(product_of_dimensions > 0);
  }
  TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size());

  // Some versions of gcc cannot infer the TileAssignment constructor from a
  // braced initializer-list, so create one manually.
  std::vector<int64_t> devices(proto.tile_assignment_devices().begin(),
                               proto.tile_assignment_devices().end());
  Array<int64_t> tile_assignment(
      std::vector<int64_t>(proto.tile_assignment_dimensions().begin(),
                           proto.tile_assignment_dimensions().end()));
  std::copy(proto.tile_assignment_devices().begin(),
            proto.tile_assignment_devices().end(), tile_assignment.begin());
  if (!subgroup_types.empty()) {
    TF_RET_CHECK(!proto.replicate_on_last_tile_dim());
    return Subgroup(tile_assignment, subgroup_types, metadata);
  }
  return proto.replicate_on_last_tile_dim()
             ? PartialTile(tile_assignment, metadata)
             : HloSharding(tile_assignment,
                           /*replicate_on_last_tile_dim=*/false, metadata);
}
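
// Example (illustrative sketch, not exercised in this file): an OpSharding
// proto of type OTHER with dimensions [2,2] and devices 0..3 should parse
// into the corresponding tiled HloSharding; setting replicate_on_last_tile_dim
// instead routes through PartialTile().
//
//   OpSharding proto;
//   proto.set_type(OpSharding::OTHER);
//   proto.add_tile_assignment_dimensions(2);
//   proto.add_tile_assignment_dimensions(2);
//   for (int64_t d = 0; d < 4; ++d) proto.add_tile_assignment_devices(d);
//   StatusOr<HloSharding> s = HloSharding::FromProto(proto);
//   // Expected: s.ok() and s->ToString() == "{devices=[2,2]0,1,2,3}"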

OpSharding HloSharding::ToProto() const {
  OpSharding result;

  if (IsTuple()) {
    CHECK(metadata_.empty());
    for (const HloSharding& element : tuple_elements_) {
      *result.add_tuple_shardings() = element.ToProto();
    }
    result.set_type(OpSharding::TUPLE);
    return result;
  }

  result.mutable_metadata()->Reserve(metadata_.size());
  for (const auto& metadata : metadata_) {
    *result.add_metadata() = metadata;
  }

  for (int64_t dim : tile_assignment_.dimensions()) {
    result.add_tile_assignment_dimensions(dim);
  }
  for (auto device : tile_assignment_) {
    result.add_tile_assignment_devices(device);
  }
  if (IsReplicated()) {
    result.set_type(OpSharding::REPLICATED);
    result.clear_tile_assignment_dimensions();
  } else if (IsTileMaximal()) {
    result.set_type(OpSharding::MAXIMAL);
  } else if (IsManual()) {
    result.set_type(OpSharding::MANUAL);
    result.clear_tile_assignment_dimensions();
  } else {
    result.set_type(OpSharding::OTHER);
    result.set_replicate_on_last_tile_dim(ReplicateOnLastTileDim());
    for (auto type : subgroup_types_) {
      result.add_last_tile_dims(type);
    }
  }
  return result;
}

Shape HloSharding::TileShape(const Shape& shape) const {
  if (IsTileMaximal() || IsManual()) {
    return shape;
  }
  Shape result_shape = shape;
  for (int64_t i = 0; i < TiledDataRank(); ++i) {
    result_shape.set_dimensions(
        i, CeilOfRatio<int64_t>(shape.dimensions(i), tile_assignment_.dim(i)));
  }
  return result_shape;
}

Shape HloSharding::TileShape(const Shape& shape, int64_t device) const {
  if (IsTileMaximal() || IsManual()) {
    return shape;
  }

  std::vector<int64_t> index = TileIndexForDevice(device);
  Shape result_shape = shape;
  for (int64_t i = 0; i < index.size(); ++i) {
    const int64_t shape_dim = shape.dimensions(i);
    int64_t offset = std::min(
        index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim);
    int64_t limit = std::min(
        (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)),
        shape_dim);
    result_shape.set_dimensions(i, limit - offset);
  }
  return result_shape;
}
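
// Example (illustrative sketch, not exercised in this file): with a dimension
// of size 5 split across 2 tiles, the per-device tile shapes are uneven; the
// first device gets the full ceil-sized tile and the last device the
// remainder.
//
//   std::vector<int64_t> dims = {2};
//   Array<int64_t> tiles(dims);
//   std::iota(tiles.begin(), tiles.end(), 0);
//   HloSharding s = HloSharding::Tile(tiles);
//   Shape shape = ShapeUtil::MakeShape(F32, {5});
//   // Expected: s.TileShape(shape) has dimensions {3}              (ceil(5/2))
//   // Expected: s.TileShape(shape, /*device=*/0) has dimensions {3}
//   // Expected: s.TileShape(shape, /*device=*/1) has dimensions {2}  (5 - 3)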

int64_t HloSharding::NumTiles() const {
  if (IsTileMaximal()) {
    return 1;
  }
  CHECK(!IsManual());
  return Product(absl::Span<const int64_t>(tile_assignment_.dimensions())
                     .subspan(0, TiledDataRank()));
}

int64_t HloSharding::NumTiles(absl::Span<const int64_t> dims) const {
  if (IsTileMaximal()) {
    return 1;
  }
  CHECK(!IsManual());
  CHECK(!ReplicateOnLastTileDim() ||
        !absl::c_linear_search(dims, tile_assignment().num_dimensions() - 1));
  int64_t num_tiles = 1;
  for (auto d : dims) {
    CHECK(d < tile_assignment().num_dimensions());
    num_tiles *= tile_assignment().dim(d);
  }
  return num_tiles;
}

HloSharding HloSharding::GetSubSharding(const Shape& shape,
                                        const ShapeIndex& index) const {
  CHECK(IsTuple());
  int64_t sharding_index = 0;
  const Shape* sub_shape = &shape;
  for (int64_t idx : index) {
    for (int64_t i = 0; i < idx; ++i) {
      sharding_index +=
          ShapeUtil::GetLeafCount(ShapeUtil::GetSubshape(*sub_shape, {i}));
    }
    sub_shape = &ShapeUtil::GetSubshape(*sub_shape, {idx});
  }
  if (sub_shape->IsTuple()) {
    auto begin_it = tuple_elements_.begin() + sharding_index;
    std::vector<HloSharding> sub_shardings(
        begin_it, begin_it + ShapeUtil::GetLeafCount(*sub_shape));
    return HloSharding::Tuple(*sub_shape, sub_shardings);
  } else {
    return tuple_elements_[sharding_index];
  }
}
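
// Example (illustrative sketch, not exercised in this file): for a tuple
// sharding over (f32[4], (f32[4], f32[4])), the flattened leaf index is
// computed by walking the tuple shape, so ShapeIndex {1, 0} selects the
// second leaf overall. ShapeUtil::MakeTupleShape/MakeShape come from
// shape_util.h.
//
//   Shape inner = ShapeUtil::MakeTupleShape(
//       {ShapeUtil::MakeShape(F32, {4}), ShapeUtil::MakeShape(F32, {4})});
//   Shape tuple = ShapeUtil::MakeTupleShape(
//       {ShapeUtil::MakeShape(F32, {4}), inner});
//   HloSharding s = HloSharding::Tuple(
//       tuple, {HloSharding::Replicate(), HloSharding::AssignDevice(0),
//               HloSharding::AssignDevice(1)});
//   // Expected: s.GetSubSharding(tuple, {1, 0}) == HloSharding::AssignDevice(0)
//   // Expected: s.GetSubSharding(tuple, {1}) is the tuple sharding
//   //           {AssignDevice(0), AssignDevice(1)} for the inner tuple.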

std::optional<HloSharding> HloSharding::ExtractSingleSharding() const {
  if (!IsTuple()) {
    return *this;
  }
  if (tuple_elements_.empty()) {
    return std::nullopt;
  }
  for (int64_t i = 1; i < tuple_elements_.size(); ++i) {
    if (tuple_elements_[0] != tuple_elements_[i]) {
      return std::nullopt;
    }
  }
  return tuple_elements_.front();
}

HloSharding HloSharding::WithMetadata(absl::Span<const OpMetadata> metadata,
                                      bool overwrite) const {
  auto assign_metadata = [&](HloSharding& sharding) {
    if (sharding.metadata_.empty() || overwrite) {
      sharding.metadata_.assign(metadata.begin(), metadata.end());
    }
  };

  HloSharding sharding = *this;
  if (sharding.IsTuple()) {
    for (HloSharding& sub_sharding : sharding.tuple_elements()) {
      assign_metadata(sub_sharding);
    }
  } else {
    assign_metadata(sharding);
  }
  return sharding;
}

HloSharding HloSharding::WithoutMetadata() const {
  HloSharding sharding = *this;
  sharding.metadata_.clear();
  for (HloSharding& sub_sharding : sharding.tuple_elements()) {
    sub_sharding.metadata_.clear();
  }
  return sharding;
}

std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) {
  out << sharding.ToString();
  return out;
}

}  // namespace xla