/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_sharding.h"

#include <algorithm>
#include <iterator>
#include <map>
#include <numeric>
#include <optional>
#include <ostream>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/compiler/xla/service/hlo_op_metadata.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

using absl::StrCat;
using absl::StrJoin;

HloSharding HloSharding::AssignDevice(int64_t device_id,
                                      absl::Span<const OpMetadata> metadata) {
  return HloSharding(device_id, metadata);
}

HloSharding HloSharding::Tile1D(const Shape& input_shape, int64_t num_tiles,
                                absl::Span<const OpMetadata> metadata) {
  CHECK_EQ(1, input_shape.rank());
  CHECK_GT(num_tiles, 1);
  std::vector<int64_t> dimensions(1, num_tiles);
  Array<int64_t> assignment(dimensions);
  std::iota(assignment.begin(), assignment.end(), 0);
  return HloSharding(assignment, /*replicate_on_last_tile_dim=*/false,
                     metadata);
}

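// Builds a partially replicated sharding from an explicit grouping: each
// element of `group_tile_assignment` names a replication group, and the
// devices of that group are appended as a trailing tile dimension. For
// example (illustrative), a group assignment of shape [2] with values {0, 1}
// and replication_groups {{0, 1}, {2, 3}} yields the tile assignment
// [[0, 1], [2, 3]] with the last dimension marked as replicated.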
HloSharding HloSharding::PartialTile(
    const Array<int64_t>& group_tile_assignment,
    absl::Span<const absl::Span<const int64_t>> replication_groups,
    absl::Span<const OpMetadata> metadata) {
  CHECK_EQ(group_tile_assignment.num_elements(), replication_groups.size());
  if (replication_groups.size() == 1) {
    return Replicate(metadata);
  }
  auto new_tile_dims = group_tile_assignment.dimensions();
  new_tile_dims.push_back(replication_groups[0].size());
  auto new_tile_assignment = Array<int64_t>(new_tile_dims);
  new_tile_assignment.Each(
      [&](absl::Span<const int64_t> indices, int64_t* device) {
        std::vector<int64_t> group_index(indices.begin(), indices.end());
        group_index.pop_back();
        int64_t group = group_tile_assignment(group_index);
        *device = replication_groups[group][indices.back()];
      });
  return PartialTile(new_tile_assignment, metadata);
}

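// Canonicalizes a tile assignment whose last dimension denotes replication:
// a single replication group (or a trivial last dimension) collapses to a
// replicated or fully tiled sharding, and otherwise the devices within each
// replication group are sorted so that equivalent shardings compare equal.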
HloSharding HloSharding::PartialTile(
    const Array<int64_t>& tile_assignment_last_dim_replicate,
    absl::Span<const OpMetadata> metadata) {
  if (tile_assignment_last_dim_replicate.num_dimensions() == 1 ||
      tile_assignment_last_dim_replicate.dimensions().back() ==
          tile_assignment_last_dim_replicate.num_elements()) {
    return Replicate(metadata);
  }
  if (tile_assignment_last_dim_replicate.dimensions().back() == 1) {
    auto new_tile_dims = tile_assignment_last_dim_replicate.dimensions();
    new_tile_dims.pop_back();
    auto fully_tiled = tile_assignment_last_dim_replicate;
    fully_tiled.Reshape(new_tile_dims);
    return HloSharding(fully_tiled, /*replicate_on_last_tile_dim=*/false,
                       metadata);
  }
  std::vector<std::set<int64_t>> sorted_groups(
      tile_assignment_last_dim_replicate.num_elements() /
      tile_assignment_last_dim_replicate.dimensions().back());
  auto get_group_id = [&](absl::Span<const int64_t> indices) {
    int64_t group_id = 0;
    for (int64_t i = 0; i < indices.size() - 1; ++i) {
      group_id *= tile_assignment_last_dim_replicate.dim(i);
      group_id += indices[i];
    }
    return group_id;
  };
  tile_assignment_last_dim_replicate.Each(
      [&](absl::Span<const int64_t> indices, const int64_t device) {
        sorted_groups[get_group_id(indices)].insert(device);
      });
  Array<int64_t> sorted_tile(tile_assignment_last_dim_replicate.dimensions());
  sorted_tile.Each([&](absl::Span<const int64_t> indices, int64_t* device) {
    const int64_t group_id = get_group_id(indices);
    auto begin = sorted_groups[group_id].begin();
    *device = *begin;
    sorted_groups[group_id].erase(begin);
  });
  return HloSharding(sorted_tile, /*replicate_on_last_tile_dim=*/true,
                     metadata);
}

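// Builds a sharding whose trailing tile dimensions are subgroups (e.g.
// REPLICATED or MANUAL) rather than data tiling. Degenerate inputs are
// canonicalized below: a single subgroup type with no data tiling becomes a
// plain manual/replicated sharding, and trivial or duplicate subgroup
// dimensions are merged before constructing the result.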
HloSharding HloSharding::Subgroup(
    const Array<int64_t>& tile_assignment,
    absl::Span<const OpSharding::Type> subgroup_types,
    absl::Span<const OpMetadata> metadata) {
  if (subgroup_types.empty()) {
    return HloSharding(tile_assignment, /*replicate_on_last_tile_dim=*/false,
                       metadata);
  }
  // If there is only one type of subgrouping and there is no tiling on data
  // dimensions, it can be canonicalized to a simple manual/replicated
  // sharding.
  if (absl::c_all_of(
          subgroup_types,
          [&](const OpSharding::Type t) { return t == subgroup_types[0]; }) &&
      Product(absl::Span<const int64_t>(tile_assignment.dimensions())
                  .subspan(0, tile_assignment.num_dimensions() -
                                  subgroup_types.size())) == 1) {
    if (subgroup_types[0] == OpSharding::MANUAL) {
      return Manual(metadata);
    }
    if (subgroup_types[0] == OpSharding::REPLICATED) {
      return Replicate(metadata);
    }
  }
  // Normalize the subgroup dimensions to simplify the representation:
  //  - Remove trivial dims of size 1.
  //  - Merge dims of the same type.
  //  - Sort types.
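  // For example (illustrative), a tile assignment of shape [1, 2, 2] with
  // subgroup_types {REPLICATED, MANUAL} is transposed so that the MANUAL dim
  // precedes the REPLICATED dim; keeping REPLICATED last lets PartialTile()
  // below sort the devices within each replication group.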
  int64_t data_dims = tile_assignment.num_dimensions() - subgroup_types.size();
  std::vector<int64_t> perm(data_dims);
  std::iota(perm.begin(), perm.end(), 0);
  // Make sure the replicate dims are at the end so that we can leverage
  // PartialTile() to sort the elements.
  struct CmpTypeReplicateLast {
    bool operator()(OpSharding::Type a, OpSharding::Type b) const {
      if (a == b) {
        return false;
      }
      if (a == OpSharding::REPLICATED) {
        return false;
      }
      if (b == OpSharding::REPLICATED) {
        return true;
      }
      return a < b;
    }
  };
  std::map<OpSharding::Type, std::vector<int64_t>, CmpTypeReplicateLast>
      type_to_dims;
  bool needs_merging = false;
  for (int64_t i = 0; i < subgroup_types.size(); ++i) {
    if (tile_assignment.dim(i + data_dims) == 1) {
      needs_merging = true;
      continue;
    }
    auto& dims = type_to_dims[subgroup_types[i]];
    needs_merging |= !dims.empty();
    dims.push_back(i + data_dims);
  }
  needs_merging |= type_to_dims.size() > 1;
  auto create_sharding = [](const Array<int64_t> tiles,
                            absl::Span<const OpSharding::Type> types,
                            absl::Span<const OpMetadata> metadata) {
    if (types.size() == 1 && types.back() == OpSharding::REPLICATED) {
      // Normalize to partial tile.
      return PartialTile(tiles, metadata);
    }
    if (types.size() == 1 && types.back() == OpSharding::MANUAL &&
        tiles.num_elements() == tiles.dimensions().back()) {
      // Normalize to manual.
      return Manual(metadata);
    }
    if (!types.empty() && types.back() == OpSharding::REPLICATED) {
      // If the last type is REPLICATED, we first create a partially replicated
      // sharding without other subgroups so that the elements are sorted. Then
      // we fix the subgroup types.
      HloSharding sharding = PartialTile(tiles, metadata);
      sharding.replicate_on_last_tile_dim_ = false;
      for (const OpSharding::Type type : types) {
        sharding.subgroup_types_.push_back(type);
      }
      return sharding;
    }
    return HloSharding(tiles, types, metadata);
  };
  if (needs_merging) {
    auto data_tile_shape =
        absl::Span<const int64_t>(tile_assignment.dimensions())
            .subspan(0, data_dims);
    std::vector<int64_t> merged_shape(data_tile_shape.begin(),
                                      data_tile_shape.end());
    std::vector<int64_t> transposed_shape = merged_shape;
    std::vector<OpSharding::Type> merged_types;
    for (const auto& type_dims : type_to_dims) {
      int64_t dim_size = 1;
      for (int64_t dim : type_dims.second) {
        perm.push_back(dim);
        dim_size *= tile_assignment.dim(dim);
        transposed_shape.push_back(tile_assignment.dim(dim));
      }
      merged_shape.push_back(dim_size);
      merged_types.push_back(type_dims.first);
    }
    Array<int64_t> new_tiles(transposed_shape);
    new_tiles.Each([&](absl::Span<const int64_t> indices, int64_t* value) {
      std::vector<int64_t> src_indices(tile_assignment.num_dimensions(), 0);
      for (int64_t i = 0; i < indices.size(); ++i) {
        src_indices[perm[i]] = indices[i];
      }
      *value = tile_assignment(src_indices);
    });
    new_tiles.Reshape(merged_shape);
    return create_sharding(new_tiles, merged_types, metadata);
  }
  return create_sharding(tile_assignment, subgroup_types, metadata);
}

HloSharding HloSharding::Tuple(const ShapeTree<HloSharding>& sub_shardings) {
  std::vector<HloSharding> flattened_list;
  flattened_list.reserve(sub_shardings.leaf_count());
  for (const auto& index_to_sharding : sub_shardings.leaves()) {
    flattened_list.push_back(index_to_sharding.second);
  }
  if (flattened_list.empty()) {
    // Empty tuple sharding ends up having no leaves, but we want to allow
    // empty tuple HLO instruction results to have sharding, so we fetch the
    // root ({}) sharding value from the ShapeTree.
    // A ShapeTree created with ShapeTree<HloSharding>(shape, init) will have
    // init as value at its root.
    flattened_list.push_back(sub_shardings.element(ShapeIndex({})));
  }
  return HloSharding(flattened_list);
}

HloSharding HloSharding::Tuple(const Shape& tuple_shape,
                               absl::Span<const HloSharding> shardings) {
  CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape);
  for (auto& sharding : shardings) {
    CHECK(!sharding.IsTuple())
        << sharding.ToString() << ShapeUtil::HumanString(tuple_shape);
  }
  std::vector<HloSharding> flattened_list(shardings.begin(), shardings.end());
  CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape))
      << "Flat list has " << flattened_list.size() << ", required "
      << RequiredLeaves(tuple_shape);
  return HloSharding(flattened_list);
}

HloSharding HloSharding::SingleTuple(const Shape& tuple_shape,
                                     const HloSharding& sharding) {
  CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape);
  CHECK(!sharding.IsTuple()) << sharding.ToString();
  int64_t leaf_count = RequiredLeaves(tuple_shape);
  std::vector<HloSharding> flattened_list;
  flattened_list.resize(leaf_count, sharding);
  return HloSharding(flattened_list);
}

HloSharding HloSharding::Single(const Shape& shape,
                                const HloSharding& sharding) {
  return shape.IsTuple() ? SingleTuple(shape, sharding) : sharding;
}

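// Renders the sharding in the textual form used by HLO, e.g. (illustrative):
//   {replicated}, {manual}, {maximal device=3},
//   {devices=[2,2]0,1,2,3}, or
//   {devices=[2,2]0,1,2,3 last_tile_dim_replicate}.
// Tuple shardings are printed as a brace-enclosed list of element shardings,
// with an /*index=N*/ hint every five elements.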
std::string HloSharding::ToString(bool include_metadata) const {
  if (IsTuple()) {
    CHECK(metadata_.empty());
    std::string result = "{";
    for (int i = 0; i < tuple_elements_.size(); ++i) {
      const HloSharding& element = tuple_elements_[i];
      if (i != 0) {
        absl::StrAppend(&result, ", ");
        if (i % 5 == 0) {
          absl::StrAppend(&result, "/*index=", i, "*/");
        }
      }
      absl::StrAppend(&result, element.ToString(include_metadata));
    }
    absl::StrAppend(&result, "}");
    return result;
  }

  std::string metadata;
  if (include_metadata) {
    if (metadata_.size() == 1) {
      metadata =
          StrCat(" metadata={", OpMetadataToString(metadata_.front()), "}");
    } else if (metadata_.size() > 1) {
      std::vector<std::string> metadata_strings;
      metadata_strings.reserve(metadata_.size());
      for (const auto& single_metadata : metadata_) {
        metadata_strings.push_back(
            StrCat("{", OpMetadataToString(single_metadata), "}"));
      }
      metadata = StrCat(" metadata={", StrJoin(metadata_strings, ", "), "}");
    }
  }

  std::string last_tile_dims;
  if (!subgroup_types_.empty()) {
    auto op_sharding_type_to_string = [](OpSharding::Type type) {
      switch (type) {
        case OpSharding::MANUAL:
          return "manual";
        case OpSharding::MAXIMAL:
          return "maximal";
        case OpSharding::REPLICATED:
          return "replicated";
        default:
          return "error_type";
      }
    };
    std::vector<std::string> sharding_type_strings;
    sharding_type_strings.reserve(subgroup_types_.size());
    for (const auto& single_sharding_type : subgroup_types_) {
      sharding_type_strings.push_back(
          op_sharding_type_to_string(single_sharding_type));
    }
    last_tile_dims =
        StrCat(" last_tile_dims={", StrJoin(sharding_type_strings, ", "), "}");
  }

  if (replicated_) {
    return StrCat("{replicated", metadata, "}");
  }

  if (manual_) {
    return StrCat("{manual", metadata, "}");
  }
  if (maximal_) {
    return StrCat(
        "{maximal device=", static_cast<int64_t>(*tile_assignment_.begin()),
        metadata, "}");
  }
  return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), "]",
                StrJoin(tile_assignment_, ","),
                replicate_on_last_tile_dim_ ? " last_tile_dim_replicate" : "",
                last_tile_dims, metadata, "}");
}

bool HloSharding::UsesDevice(int64_t device) const {
  if (IsTuple()) {
    return absl::c_any_of(tuple_elements_, [&](const HloSharding& s) {
      return s.UsesDevice(device);
    });
  }
  const auto& devices = tile_assignment_;
  return replicated_ || manual_ || absl::c_linear_search(devices, device);
}

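// Returns a map from device id to the number of (leaf) shardings in this
// sharding that resolve to exactly that device; *count receives the number of
// leaves that were considered (1 for a non-tuple sharding).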
std::map<int64_t, int64_t> HloSharding::UsedDevices(int64_t* count) const {
  int64_t element_count = 1;
  std::map<int64_t, int64_t> device_map;
  if (IsTuple()) {
    for (auto& tuple_element_sharding : tuple_elements()) {
      auto unique_device = tuple_element_sharding.UniqueDevice();
      if (unique_device) {
        device_map[*unique_device] += 1;
      }
    }
    element_count = tuple_elements().size();
  } else {
    auto unique_device = UniqueDevice();
    if (unique_device) {
      device_map[*unique_device] += 1;
    }
  }
  if (count != nullptr) {
    *count = element_count;
  }
  return device_map;
}

std::vector<int64_t> HloSharding::TileIndexForDevice(int64_t device) const {
  CHECK(!maximal_);
  CHECK(!manual_);
  CHECK(!IsTuple());
  std::vector<int64_t> ret_index;
  tile_assignment_.Each([&](absl::Span<const int64_t> index, int64_t d) {
    if (d == device) {
      ret_index = {index.begin(), index.end()};
    }
  });
  CHECK(!ret_index.empty());
  ret_index.resize(TiledDataRank());
  return ret_index;
}

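// Maps a tile index back to a device id. If the index covers only the data
// dimensions of a sharding that also has replication/subgroup dimensions,
// those trailing dimensions are filled with zeros, i.e. the first device of
// the corresponding subgroup is returned.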
int64_t HloSharding::DeviceForTileIndex(absl::Span<const int64_t> index) const {
  CHECK(!replicated_);
  CHECK(!manual_);
  CHECK(!IsTuple());
  if (maximal_) {
    return *tile_assignment_.begin();
  }
  if (index.size() == TiledDataRank() &&
      index.size() < tile_assignment_.num_dimensions()) {
    std::vector<int64_t> first_subgroup_index(index.begin(), index.end());
    for (int64_t i = 0; i < tile_assignment_.num_dimensions() - index.size();
         ++i) {
      first_subgroup_index.push_back(0);
    }
    return tile_assignment_(first_subgroup_index);
  }
  return tile_assignment_(index);
}

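// Each tile along dimension i spans ceil(dim / num_tiles) elements, clamped
// to the shape. TileOffsetForDevice returns the start of a device's tile and
// TileLimitForDevice (below) returns its exclusive end. For example
// (illustrative), a dimension of size 10 split across 3 tiles yields
// [offset, limit) ranges [0, 4), [4, 8) and [8, 10).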
std::vector<int64_t> HloSharding::TileOffsetForDevice(const Shape& shape,
                                                      int64_t device) const {
  CHECK(!IsTuple());
  CHECK(!manual_);

  if (maximal_) {
    return std::vector<int64_t>(shape.dimensions_size(), 0);
  }
  CHECK_EQ(shape.dimensions_size(), TiledDataRank());
  std::vector<int64_t> index = TileIndexForDevice(device);
  for (int64_t i = 0; i < index.size(); ++i) {
    const int64_t shape_dim = shape.dimensions(i);
    index[i] = std::min(
        index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim);
  }
  return index;
}

std::vector<int64_t> HloSharding::TileLimitForDevice(const Shape& shape,
                                                     int64_t device) const {
  CHECK(!IsTuple());
  CHECK(!manual_);

  if (maximal_) {
    return std::vector<int64_t>(shape.dimensions().begin(),
                                shape.dimensions().end());
  }

  CHECK_EQ(shape.dimensions_size(), TiledDataRank());
  std::vector<int64_t> index = TileIndexForDevice(device);
  for (int64_t i = 0; i < index.size(); ++i) {
    const int64_t shape_dim = shape.dimensions(i);
    index[i] = std::min(
        (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)),
        shape_dim);
  }
  return index;
}

int64_t HloSharding::RequiredLeaves(const Shape& shape) {
  // Empty tuples (with arbitrary nesting) have no leaf nodes as far as
  // ShapeUtil and ShapeTree are concerned, but they do have a single
  // tuple_elements_ entry since we want to allow empty tuple results to
  // have sharding.
  const int64_t leaf_count = ShapeUtil::GetLeafCount(shape);
  return (leaf_count == 0) ? 1 : leaf_count;
}

Status HloSharding::CheckLeafCount(const Shape& shape) const {
  int64_t leaf_count = ShapeUtil::GetLeafCount(shape);
  if (leaf_count == 0 && tuple_elements_.size() == 1) {
    // Allow (but don't require) empty tuples to have a single sharding.
    return OkStatus();
  }
  TF_RET_CHECK(leaf_count == tuple_elements_.size())
      << "Shape " << ShapeUtil::HumanString(shape) << " has " << leaf_count
      << " leaf nodes while this sharding has " << tuple_elements_.size();
  return OkStatus();
}

StatusOr<ShapeTree<HloSharding>> HloSharding::AsShapeTree(
    const Shape& shape) const {
  if (IsTuple()) {
    ShapeTree<HloSharding> result(shape, HloSharding::Replicate());
    TF_RETURN_IF_ERROR(CheckLeafCount(shape));
    auto it = tuple_elements_.begin();
    for (auto& index_to_sharding : result.leaves()) {
      index_to_sharding.second = *it++;
    }
    if (ShapeUtil::IsEmptyTuple(shape)) {
      // Empty tuples have no leaves, but we want to assign them a sharding
      // anyway, so we use the root element sharding.
      *result.mutable_element(ShapeIndex({})) = *it;
    }
    return std::move(result);
  } else {
    return ShapeTree<HloSharding>(shape, *this);
  }
}

StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
  if (IsTuple()) {
    TF_RETURN_IF_ERROR(CheckLeafCount(shape));
    return *this;
  }
  return Tuple(ShapeTree<HloSharding>(shape, *this));
}

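// Returns the device this sharding is bound to, if there is exactly one:
// a maximal (single-device) sharding, or a tuple whose elements all resolve
// to the same device. Otherwise returns std::nullopt.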
std::optional<int64_t> HloSharding::UniqueDevice() const {
  if (IsTuple()) {
    if (tuple_elements_.empty()) {
      return std::nullopt;
    }
    std::optional<int64_t> unique_device;
    for (auto& tuple_sharding : tuple_elements_) {
      auto device = tuple_sharding.UniqueDevice();
      if (!device || (unique_device && *device != *unique_device)) {
        return std::nullopt;
      }
      unique_device = device;
    }
    return unique_device;
  }
  if (!replicated_ && maximal_) {
    return static_cast<int64_t>(*tile_assignment_.begin());
  }
  return std::nullopt;
}

int64_t HloSharding::GetUniqueDevice() const {
  auto device = UniqueDevice();
  CHECK(device) << "Sharding does not have a unique device: " << *this;
  return *device;
}

Status HloSharding::ValidateTuple(const Shape& shape,
                                  int64_t num_devices) const {
  if (!shape.IsTuple()) {
    return tensorflow::errors::InvalidArgument(
        StrCat("Sharding is tuple-shaped but validation shape is not."));
  }
  TF_RETURN_IF_ERROR(CheckLeafCount(shape));
  if (ShapeUtil::GetLeafCount(shape) == 0 && tuple_elements_.empty()) {
    // Empty tuples are allowed to not have sharding.
    return OkStatus();
  }

  // Now that we've validated the number of tuple elements, it's safe to
  // request a shape tree.
  ShapeTree<HloSharding> shape_tree = GetAsShapeTree(shape);
  for (const auto& index_to_sharding : shape_tree.leaves()) {
    Status status = index_to_sharding.second.ValidateNonTuple(
        ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices);
    if (!status.ok()) {
      tensorflow::errors::AppendToMessage(
          &status, StrCat("Note: While validating sharding tuple element ",
                          index_to_sharding.first.ToString(), " which is ",
                          index_to_sharding.second.ToString()));
      return status;
    }
  }
  return OkStatus();
}

Status HloSharding::Validate(const Shape& shape, int64_t num_devices) const {
  if (shape.IsToken()) {
    return OkStatus();
  }
  Status status = IsTuple() ? ValidateTuple(shape, num_devices)
                            : ValidateNonTuple(shape, num_devices);
  if (!status.ok()) {
    tensorflow::errors::AppendToMessage(
        &status, StrCat("Note: While validating sharding ", ToString(),
                        " against shape ", ShapeUtil::HumanString(shape)));
  }
  return status;
}

Status HloSharding::ValidateNonTuple(const Shape& shape,
                                     int64_t num_devices) const {
  if (shape.IsTuple()) {
    return tensorflow::errors::InvalidArgument(
        StrCat("Validation shape is a tuple but sharding is not."));
  }
  if (replicated_) {
    return OkStatus();
  }

  // Every device id in the tile assignment must be smaller than the number of
  // available devices, and each id must be unique.
  Status status = OkStatus();
  absl::flat_hash_set<int64_t> seen_cores;
  tile_assignment_.Each([&](absl::Span<const int64_t> indices, int32_t core) {
    // Don't overwrite a bad status, so we report the first error.
    if (status.ok()) {
      if (core >= num_devices) {
        status = tensorflow::errors::InvalidArgument(
            StrCat("core ", core, " >= ", num_devices, " in tile assignment"));
      } else if (seen_cores.contains(core)) {
        status = tensorflow::errors::InvalidArgument(
            StrCat("core ", core, " is not unique in tile assignment"));
      }
      seen_cores.insert(core);
    }
  });
  if (!status.ok()) {
    return status;
  }

  if (IsTileMaximal() || IsManual()) {
    return OkStatus();
  }

  // The tile assignment tensor must have the same rank as the input, plus one
  // dimension if replicate_on_last_tile_dim_ is set and one dimension per
  // subgroup type.
  if (shape.rank() + (replicate_on_last_tile_dim_ ? 1 : 0) +
          subgroup_types_.size() !=
      tile_assignment_.num_dimensions()) {
    return tensorflow::errors::InvalidArgument(
        "Number of tile assignment dimensions is different from the input "
        "rank. sharding=",
        ToString(), ", input_shape=", ShapeUtil::HumanString(shape));
  }

  // The correct constructor has to be used to create tile maximal shardings.
  if (tile_assignment_.num_elements() == 1) {
    return tensorflow::errors::InvalidArgument(
        "Tile assignment only contains a single device. If a replicated "
        "sharding was intended, use HloSharding::Replicate(). If a device "
        "placement was intended, use HloSharding::AssignDevice()");
  }
  return OkStatus();
}

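// Reconstructs an HloSharding from its OpSharding proto form. This is the
// inverse of ToProto() below: tuple, replicated, manual and maximal protos
// map to the corresponding factories, and OTHER protos are rebuilt from the
// tile assignment dimensions/devices (with optional last-dim replication or
// subgroup types).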
/*static*/ StatusOr<HloSharding> HloSharding::FromProto(
    const OpSharding& proto) {
  std::vector<OpMetadata> metadata(proto.metadata().begin(),
                                   proto.metadata().end());
  std::vector<int> subgroup_types_int(proto.last_tile_dims().begin(),
                                      proto.last_tile_dims().end());
  std::vector<OpSharding::Type> subgroup_types;
  absl::c_transform(
      subgroup_types_int, std::back_inserter(subgroup_types),
      [](const int type) { return static_cast<OpSharding::Type>(type); });
  if (proto.type() == OpSharding::TUPLE) {
    TF_RET_CHECK(metadata.empty())
        << "Tuple sharding is expected to have no metadata.";
    std::vector<HloSharding> tuple_shardings;
    tuple_shardings.reserve(proto.tuple_shardings().size());
    for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) {
      TF_ASSIGN_OR_RETURN(HloSharding sharding,
                          HloSharding::FromProto(tuple_sharding_proto));
      tuple_shardings.push_back(sharding);
    }
    return HloSharding(tuple_shardings);
  } else if (proto.type() == OpSharding::REPLICATED) {
    return Replicate(metadata);
  } else if (proto.type() == OpSharding::MANUAL) {
    return Manual(metadata);
  } else if (proto.tile_assignment_devices().size() == 1) {
    return HloSharding(proto.tile_assignment_devices(0), metadata);
  }

  TF_RET_CHECK(proto.type() != OpSharding::MAXIMAL)
      << "Maximal sharding is expected to have a single device assignment, "
      << "but " << proto.tile_assignment_devices().size() << " were provided.";

  TF_RET_CHECK(proto.tile_assignment_devices().size() > 1);
  TF_RET_CHECK(!proto.tile_assignment_dimensions().empty());

  // The product of the tile assignment tensor dimensions must be equal to
  // tile_assignment_devices.size().
  int64_t product_of_dimensions = 1;
  for (auto dimension : proto.tile_assignment_dimensions()) {
    TF_RET_CHECK(dimension > 0);
    product_of_dimensions =
        MultiplyWithoutOverflow(product_of_dimensions, dimension);
    TF_RET_CHECK(product_of_dimensions > 0);
  }
  TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size());

  // Some versions of gcc cannot infer the Array<int64_t> constructor from a
  // braced initializer-list, so create the dimension vector manually.
  std::vector<int64_t> devices(proto.tile_assignment_devices().begin(),
                               proto.tile_assignment_devices().end());
  Array<int64_t> tile_assignment(
      std::vector<int64_t>(proto.tile_assignment_dimensions().begin(),
                           proto.tile_assignment_dimensions().end()));
  std::copy(devices.begin(), devices.end(), tile_assignment.begin());
  if (!subgroup_types.empty()) {
    TF_RET_CHECK(!proto.replicate_on_last_tile_dim());
    return Subgroup(tile_assignment, subgroup_types, metadata);
  }
  return proto.replicate_on_last_tile_dim()
             ? PartialTile(tile_assignment, metadata)
             : HloSharding(tile_assignment,
                           /*replicate_on_last_tile_dim=*/false, metadata);
}

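// Serializes this sharding to an OpSharding proto; FromProto() above accepts
// the result. Tuple shardings recurse into tuple_shardings, and non-tuple
// shardings encode their type, tile assignment, metadata and subgroup
// information.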
OpSharding HloSharding::ToProto() const {
  OpSharding result;

  if (IsTuple()) {
    CHECK(metadata_.empty());
    for (const HloSharding& element : tuple_elements_) {
      *result.add_tuple_shardings() = element.ToProto();
    }
    result.set_type(OpSharding::TUPLE);
    return result;
  }

  result.mutable_metadata()->Reserve(metadata_.size());
  for (const auto& metadata : metadata_) {
    *result.add_metadata() = metadata;
  }

  for (int64_t dim : tile_assignment_.dimensions()) {
    result.add_tile_assignment_dimensions(dim);
  }
  for (auto device : tile_assignment_) {
    result.add_tile_assignment_devices(device);
  }
  if (IsReplicated()) {
    result.set_type(OpSharding::REPLICATED);
    result.clear_tile_assignment_dimensions();
  } else if (IsTileMaximal()) {
    result.set_type(OpSharding::MAXIMAL);
  } else if (IsManual()) {
    result.set_type(OpSharding::MANUAL);
    result.clear_tile_assignment_dimensions();
  } else {
    result.set_type(OpSharding::OTHER);
    result.set_replicate_on_last_tile_dim(ReplicateOnLastTileDim());
    for (auto type : subgroup_types_) {
      result.add_last_tile_dims(type);
    }
  }
  return result;
}

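// Returns the shape of one tile of `shape` under this sharding; each tiled
// dimension becomes ceil(dim / num_tiles). The overload below returns the
// exact (possibly smaller) tile shape owned by a particular device, e.g.
// (illustrative) a dimension of size 10 split across 3 tiles gives tiles of
// size 4, 4 and 2.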
Shape HloSharding::TileShape(const Shape& shape) const {
  if (IsTileMaximal() || IsManual()) {
    return shape;
  }
  Shape result_shape = shape;
  for (int64_t i = 0; i < TiledDataRank(); ++i) {
    result_shape.set_dimensions(
        i, CeilOfRatio<int64_t>(shape.dimensions(i), tile_assignment_.dim(i)));
  }
  return result_shape;
}

Shape HloSharding::TileShape(const Shape& shape, int64_t device) const {
  if (IsTileMaximal() || IsManual()) {
    return shape;
  }

  std::vector<int64_t> index = TileIndexForDevice(device);
  Shape result_shape = shape;
  for (int64_t i = 0; i < index.size(); ++i) {
    const int64_t shape_dim = shape.dimensions(i);
    int64_t offset = std::min(
        index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim);
    int64_t limit = std::min(
        (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)),
        shape_dim);
    result_shape.set_dimensions(i, limit - offset);
  }
  return result_shape;
}

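// Returns the number of tiles over the data dimensions only; replication and
// subgroup dimensions of the tile assignment are excluded. The overload
// below restricts the product to the requested dimensions.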
int64_t HloSharding::NumTiles() const {
  if (IsTileMaximal()) {
    return 1;
  }
  CHECK(!IsManual());
  return Product(absl::Span<const int64_t>(tile_assignment_.dimensions())
                     .subspan(0, TiledDataRank()));
}

int64_t HloSharding::NumTiles(absl::Span<const int64_t> dims) const {
  if (IsTileMaximal()) {
    return 1;
  }
  CHECK(!IsManual());
  CHECK(!ReplicateOnLastTileDim() ||
        !absl::c_linear_search(dims, tile_assignment().num_dimensions() - 1));
  int64_t num_tiles = 1;
  for (auto d : dims) {
    CHECK(d < tile_assignment().num_dimensions());
    num_tiles *= tile_assignment().dim(d);
  }
  return num_tiles;
}

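// Extracts the sharding of the sub-shape at `index` from a tuple sharding.
// tuple_elements_ stores the leaf shardings in a flattened pre-order, so the
// walk below sums the leaf counts of the siblings preceding `index` at each
// nesting level to find the starting position of the requested sub-shape.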
HloSharding HloSharding::GetSubSharding(const Shape& shape,
                                        const ShapeIndex& index) const {
  CHECK(IsTuple());
  int64_t sharding_index = 0;
  const Shape* sub_shape = &shape;
  for (int64_t idx : index) {
    for (int64_t i = 0; i < idx; ++i) {
      sharding_index +=
          ShapeUtil::GetLeafCount(ShapeUtil::GetSubshape(*sub_shape, {i}));
    }
    sub_shape = &ShapeUtil::GetSubshape(*sub_shape, {idx});
  }
  if (sub_shape->IsTuple()) {
    auto begin_it = tuple_elements_.begin() + sharding_index;
    std::vector<HloSharding> sub_shardings(
        begin_it, begin_it + ShapeUtil::GetLeafCount(*sub_shape));
    return HloSharding::Tuple(*sub_shape, sub_shardings);
  } else {
    return tuple_elements_[sharding_index];
  }
}

std::optional<HloSharding> HloSharding::ExtractSingleSharding() const {
  if (!IsTuple()) {
    return *this;
  }
  if (tuple_elements_.empty()) {
    return std::nullopt;
  }
  for (int64_t i = 1; i < tuple_elements_.size(); ++i) {
    if (tuple_elements_[0] != tuple_elements_[i]) {
      return std::nullopt;
    }
  }
  return tuple_elements_.front();
}

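// Returns a copy of this sharding with `metadata` attached. Existing metadata
// is kept unless `overwrite` is true; for tuple shardings the metadata is
// applied to every element.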
HloSharding HloSharding::WithMetadata(absl::Span<const OpMetadata> metadata,
                                      bool overwrite) const {
  auto assign_metadata = [&](HloSharding& sharding) {
    if (sharding.metadata_.empty() || overwrite) {
      sharding.metadata_.assign(metadata.begin(), metadata.end());
    }
  };

  HloSharding sharding = *this;
  if (sharding.IsTuple()) {
    for (HloSharding& sub_sharding : sharding.tuple_elements()) {
      assign_metadata(sub_sharding);
    }
  } else {
    assign_metadata(sharding);
  }
  return sharding;
}

HloSharding HloSharding::WithoutMetadata() const {
  HloSharding sharding = *this;
  sharding.metadata_.clear();
  for (HloSharding& sub_sharding : sharding.tuple_elements()) {
    sub_sharding.metadata_.clear();
  }
  return sharding;
}

std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) {
  out << sharding.ToString();
  return out;
}

}  // namespace xla