1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <functional>
21 #include <memory>
22 #include <numeric>
23 #include <optional>
24 #include <string>
25 #include <utility>
26 #include <vector>
27
28 #include "absl/algorithm/container.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "absl/container/inlined_vector.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/types/span.h"
33 #include "tensorflow/compiler/xla/comparison_util.h"
34 #include "tensorflow/compiler/xla/literal_util.h"
35 #include "tensorflow/compiler/xla/protobuf_util.h"
36 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
37 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
38 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
39 #include "tensorflow/compiler/xla/service/hlo_computation.h"
40 #include "tensorflow/compiler/xla/service/hlo_cse.h"
41 #include "tensorflow/compiler/xla/service/hlo_dce.h"
42 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
43 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
44 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
45 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
46 #include "tensorflow/compiler/xla/service/hlo_query.h"
47 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
48 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
49 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
50 #include "tensorflow/compiler/xla/service/shape_inference.h"
51 #include "tensorflow/compiler/xla/service/spmd/custom_call_handler.h"
52 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
53 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
54 #include "tensorflow/compiler/xla/shape_util.h"
55 #include "tensorflow/compiler/xla/util.h"
56 #include "tensorflow/compiler/xla/window_util.h"
57 #include "tensorflow/compiler/xla/xla_data.pb.h"
58 #include "tensorflow/core/platform/logging.h"
59
60 namespace xla {
61 namespace spmd {
62
63 namespace {
64 using hlo_sharding_util::GroupedSharding;
65 } // namespace
66
67 std::string SpmdLogger::MakeReport() {
68 std::string report;
69 absl::StrAppend(&report,
70 "\n\n***** SPMD memory during transformation *****\n");
71
72 std::sort(entries_.begin(), entries_.end(),
73 [](auto const& entry0, auto const& entry1) {
74 return entry0.first > entry1.first;
75 });
76 for (int64_t i = 0;
77 i < std::min<int64_t>(report_instruction_count_, entries_.size()); ++i) {
78 absl::StrAppend(
79 &report, "\n ",
80 tensorflow::strings::HumanReadableNumBytes(entries_[i].first), " : ",
81 entries_[i].second, "\n");
82 }
83
84 return report;
85 }
86
87 void SpmdLogger::RegisterLogEntry(HloInstruction* hlo,
88 const std::vector<HloInstruction*>& group) {
89 if (disabled_) {
90 return;
91 }
92 std::string report = hlo->ToString();
93 int64_t max_value = -1;
94 for (HloInstruction* inst : group) {
95 if (!inst->shape().IsArray()) {
96 continue;
97 }
98 max_value = std::max<int64_t>(max_value, ShapeSizeInBytes(inst->shape()));
99 absl::StrAppend(&report, " * ", inst->ToString(), "\n");
100 }
101 entries_.push_back(std::make_pair(max_value, report));
102 }
103
104 /* static */ std::string SpmdLogger::ReportBeforePartition(
105 const HloModule& module, int64_t report_instruction_count) {
106 std::string report;
107 absl::StrAppend(&report,
108 "\n\n***** SPMD memory usage before partition *****\n");
109 absl::StrAppend(&report, "\n ** Replicated instructions\n");
110 absl::StrAppend(&report, ReportMemoryUsage(
111 module,
112 [](const HloInstruction* hlo) {
113 return !hlo->has_sharding() ||
114 hlo->sharding().IsReplicated();
115 },
116 report_instruction_count));
117 absl::StrAppend(&report, "\n ** All instructions\n");
118 absl::StrAppend(&report,
119 ReportMemoryUsage(
120 module, [](const HloInstruction* hlo) { return true; },
121 report_instruction_count));
122 return report;
123 }
124
125 /* static */ std::string SpmdLogger::ReportAfterPartition(
126 const HloModule& module, int64_t report_instruction_count) {
127 std::string report;
128 absl::StrAppend(&report,
129 "\n\n***** SPMD memory usage after partition *****\n");
130 absl::StrAppend(&report,
131 ReportMemoryUsage(
132 module, [](const HloInstruction* hlo) { return true; },
133 report_instruction_count));
134 return report;
135 }
136
137 template <typename F>
138 /* static */ std::string SpmdLogger::ReportMemoryUsage(
139 const HloModule& module, const F& filter,
140 int64_t report_instruction_count) {
141 std::string report;
142 std::vector<HloInstruction*> instructions;
143 instructions.reserve(module.instruction_count());
144
145 for (auto computation : module.computations()) {
146 if (computation->IsFusionComputation()) {
147 continue;
148 }
149 for (auto hlo : computation->instructions()) {
150 if (!hlo->shape().IsArray() ||
151 ShapeUtil::IsEffectiveScalar(hlo->shape())) {
152 continue;
153 }
154 if (filter(hlo)) {
155 instructions.push_back(hlo);
156 }
157 }
158 }
159
160 const auto add_report = [&](std::vector<HloInstruction*>* insts) {
161 std::sort(insts->begin(), insts->end(),
162 [](const HloInstruction* inst0, const HloInstruction* inst1) {
163 return ShapeSizeInBytes(inst0->shape()) >
164 ShapeSizeInBytes(inst1->shape());
165 });
166 for (int64_t i = 0;
167 i < std::min<int64_t>(report_instruction_count, insts->size()); ++i) {
168 absl::StrAppend(&report, " ",
169 tensorflow::strings::HumanReadableNumBytes(
170 ShapeSizeInBytes((*insts)[i]->shape())),
171 " : ", (*insts)[i]->ToString(), "\n");
172 }
173 };
174
175 add_report(&instructions);
176 return report;
177 }
178
179 namespace {
180
181 // Clears all sharding attributes from instructions in the module. This must be
182 // called only after all SPMD transformations are complete.
183 Status ClearShardingAttributes(
184 HloModule* module,
185 const absl::flat_hash_set<absl::string_view>& execution_threads) {
186 for (HloComputation* computation : module->computations(execution_threads)) {
187 for (HloInstruction* hlo : computation->instructions()) {
188 // Keep sharding annotation on Infeed and entry parameters since they're
189 // used by HloReplicationAnalysis later (for ArCrsCombiner).
190 if (hlo->HasSideEffect()) {
191 continue;
192 }
193 if (hlo->opcode() == HloOpcode::kParameter &&
194 computation == module->entry_computation()) {
195 continue;
196 }
197 hlo->clear_sharding();
198 }
199 }
200 return OkStatus();
201 }
202
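// Returns the groups of partition ids that are replicas of each other along
// `replication_dims`. For example, with a [2,2] tile assignment {{0,1},{2,3}}
// and replication_dims={1}, the result is {{0,1},{2,3}}.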
203 std::vector<std::vector<int64_t>> GetPartitionGroupsForReplication(
204 const HloSharding& sharding, absl::Span<const int64_t> replication_dims) {
205 int64_t group_size = 1;
206 for (int64_t i : replication_dims) {
207 group_size *= sharding.tile_assignment().dim(i);
208 }
209 std::vector<std::vector<int64_t>> partition_groups(
210 sharding.tile_assignment().num_elements() / group_size);
211 sharding.tile_assignment().Each(
212 [&](absl::Span<const int64_t> indices, int64_t partition) {
213 int64_t group_id = 0;
214 for (int64_t i = 0; i < indices.size(); ++i) {
215 if (!absl::c_linear_search(replication_dims, i)) {
216 group_id *= sharding.tile_assignment().dim(i);
217 group_id += indices[i];
218 }
219 }
220 partition_groups[group_id].push_back(partition);
221 });
222 return partition_groups;
223 }
224
225 // Returns a sharding that is replicated on all the dimensions where the given
226 // window dimension is non-trivial.
227 HloSharding GetShardingReplicatedOnWindowedDimension(
228 const HloSharding& sharding, const Window& window) {
229 std::vector<int64_t> dimensions_to_replicate;
230 for (int i = 0; i < window.dimensions_size(); ++i) {
231 const WindowDimension& wd = window.dimensions(i);
232 if (window_util::IsTrivialWindowDimension(wd)) {
233 continue;
234 }
235 dimensions_to_replicate.push_back(i);
236 }
237 return hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
238 sharding, dimensions_to_replicate);
239 }
240
241 } // namespace
242
243 HloInstruction* SpmdBuilder::AddInstruction(
244 std::unique_ptr<HloInstruction> instruction) {
245 HloInstruction* hlo =
246 HloComputation::Builder::AddInstruction(std::move(instruction));
247 if (visiting_hlo_) {
248 hlo->set_metadata(visiting_hlo_->metadata());
249 instructions_[visiting_hlo_].push_back(hlo);
250 }
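// The rest of this function maintains broadcast_dims_: for each produced HLO,
// the set of dimensions along which its values are known to be identical
// (because they originate from a broadcast). This is propagated through
// elementwise ops, transposes, reshapes, slices/dynamic-slices, and pads
// below, so that downstream code can tell which dimensions of a derived value
// are uniform.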
251 if (hlo->opcode() == HloOpcode::kBroadcast) {
252 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
253 if (!absl::c_linear_search(hlo->dimensions(), i)) {
254 broadcast_dims_[hlo].insert(i);
255 }
256 }
257 }
258 if (hlo->IsElementwise() && hlo->operand_count() > 0 &&
259 // Copy can have a tuple result.
260 hlo->shape().IsArray()) {
261 absl::flat_hash_set<int64_t> broadcast_dims;
262 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
263 broadcast_dims.insert(i);
264 }
265 for (int64_t i = 0; i < hlo->operand_count(); ++i) {
266 auto it = broadcast_dims_.find(hlo->operand(i));
267 if (it == broadcast_dims_.end()) {
268 broadcast_dims.clear();
269 break;
270 }
271 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
272 if (!it->second.contains(i)) {
273 broadcast_dims.erase(i);
274 }
275 }
276 }
277 if (!broadcast_dims.empty()) {
278 broadcast_dims_[hlo] = std::move(broadcast_dims);
279 }
280 }
281 if (hlo->opcode() == HloOpcode::kTranspose) {
282 auto it = broadcast_dims_.find(hlo->operand(0));
283 if (it != broadcast_dims_.end()) {
284 absl::flat_hash_set<int64_t> xpose_broadcast_dims;
285 std::vector<int64_t> reverse_map(hlo->shape().rank());
286 for (int64_t i = 0; i < reverse_map.size(); ++i) {
287 reverse_map[hlo->dimensions(i)] = i;
288 }
289 for (int64_t dim : it->second) {
290 xpose_broadcast_dims.insert(reverse_map[dim]);
291 }
292 broadcast_dims_[hlo] = std::move(xpose_broadcast_dims);
293 }
294 }
295 if (hlo->opcode() == HloOpcode::kReshape &&
296 Product(hlo->shape().dimensions()) > 0) {
297 auto it = broadcast_dims_.find(hlo->operand(0));
298 if (it != broadcast_dims_.end()) {
299 absl::flat_hash_set<int64_t> reshape_broadcast_dims;
300 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
301 reshape_broadcast_dims.insert(i);
302 }
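// Walk the operand and result dimensions in parallel using two size stacks
// (major-most dimension on top). A result dimension remains a broadcast dim
// only if every operand dimension it was split from or merged with is a
// broadcast dim; any leftover mismatch clears the whole set.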
303 std::vector<int64_t> before_dim_size_stack;
304 std::vector<int64_t> after_dim_size_stack;
305 const int64_t operand0_rank = hlo->operand(0)->shape().rank();
306 const int64_t hlo_shape_rank = hlo->shape().rank();
307 before_dim_size_stack.reserve(operand0_rank);
308 after_dim_size_stack.reserve(hlo_shape_rank);
309 for (int64_t i = operand0_rank - 1; i >= 0; --i) {
310 before_dim_size_stack.push_back(hlo->operand(0)->shape().dimensions(i));
311 }
312 for (int64_t i = hlo_shape_rank - 1; i >= 0; --i) {
313 after_dim_size_stack.push_back(hlo->shape().dimensions(i));
314 }
315 while (!before_dim_size_stack.empty() && !after_dim_size_stack.empty()) {
316 int64_t before_size = before_dim_size_stack.back();
317 int64_t after_size = after_dim_size_stack.back();
318 int64_t current_before_dim =
319 hlo->operand(0)->shape().rank() - before_dim_size_stack.size();
320 int64_t current_after_dim =
321 hlo->shape().rank() - after_dim_size_stack.size();
322 before_dim_size_stack.pop_back();
323 after_dim_size_stack.pop_back();
324 if (!it->second.contains(current_before_dim)) {
325 reshape_broadcast_dims.erase(current_after_dim);
326 }
327 if (before_size == after_size) {
328 continue;
329 }
330 if (before_size % after_size == 0) {
331 // Split dim.
332 before_dim_size_stack.push_back(before_size / after_size);
333 } else if (after_size % before_size == 0) {
334 // Merge dim.
335 after_dim_size_stack.push_back(after_size / before_size);
336 } else {
337 // Other cases, mark all remaining dims as non-broadcast.
338 for (int64_t i = current_after_dim; i < hlo->shape().rank(); ++i) {
339 reshape_broadcast_dims.erase(i);
340 }
341 break;
342 }
343 }
344 if (!before_dim_size_stack.empty() || !after_dim_size_stack.empty()) {
345 reshape_broadcast_dims.clear();
346 }
347 if (!reshape_broadcast_dims.empty()) {
348 broadcast_dims_[hlo] = std::move(reshape_broadcast_dims);
349 }
350 }
351 }
352 if (hlo->opcode() == HloOpcode::kSlice ||
353 hlo->opcode() == HloOpcode::kDynamicSlice) {
354 auto it = broadcast_dims_.find(hlo->operand(0));
355 if (it != broadcast_dims_.end()) {
356 auto dims = it->second;
357 broadcast_dims_[hlo] = std::move(dims);
358 }
359 }
360 if (hlo->opcode() == HloOpcode::kPad) {
361 auto it = broadcast_dims_.find(hlo->operand(0));
362 if (it != broadcast_dims_.end()) {
363 absl::flat_hash_set<int64_t> pad_broadcast_dims;
364 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
365 const auto& dim = hlo->padding_config().dimensions(i);
366 if (dim.edge_padding_low() == 0 && dim.edge_padding_high() == 0 &&
367 dim.interior_padding() == 0 && it->second.contains(i)) {
368 pad_broadcast_dims.insert(i);
369 }
370 }
371 if (!pad_broadcast_dims.empty()) {
372 broadcast_dims_[hlo] = std::move(pad_broadcast_dims);
373 }
374 }
375 }
376 return hlo;
377 }
378
379 PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target,
380 std::optional<Literal> pad_value) {
381 if (sharding() == target) {
382 return *this;
383 }
384 auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
385 // Replace the existing reshard cache for target if we are sharding with a new
386 // padding value.
387 const bool replace_cache = pad_value.has_value();
388 const bool is_to_replicate =
389 hlo_->shape().IsArray() && target.NumTiles() < sharding().NumTiles();
390 const bool use_cache =
391 !is_to_replicate || state_.partitioner->options().cache_all_gather;
392 if (!replace_cache && use_cache) {
393 auto it = cache.find(target);
394 if (it != cache.end()) {
395 return it->second;
396 }
397 }
398
399 auto resharded = ReshardNoCache(target, std::move(pad_value));
400 // Update cache for resharded hlo.
401 {
402 auto& cache =
403 state_.reshard_cache->per_hlo_cache[resharded.hlo()].reshard_cache;
404 cache.insert_or_assign(sharding(), *this);
405 }
406 // Update cache for to-reshard hlo.
407 if (use_cache) {
408 // Get the cache again as it might be invalidated by the insertion above.
409 auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
410 auto [it, _] = cache.insert_or_assign(target, std::move(resharded));
411 return it->second;
412 }
413 return resharded;
414 }
415
416 PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target,
417 std::optional<Literal> pad_value,
418 bool allow_full_replication) {
419 VLOG(2) << "Resharding " << hlo_->ToString() << " from "
420 << hlo_->sharding().ToString() << " to " << target.ToString();
421 const Shape& shape = hlo_->shape();
422 if (shape.element_type() == TOKEN) {
423 return *this;
424 }
425 CHECK(shape.IsTuple() || !target.IsTuple());
426
427 // Tuple shape instructions may have non-tuple sharding, which means that the
428 // same sharding applies to all the leaves.
429 if (shape.IsTuple() && !target.IsTuple()) {
430 return Reshard(target.GetTupleSharding(shape).ValueOrDie());
431 }
432
433 // For a tuple shape, recursively apply Reshard to all the leaves and return
434 // a tuple instruction.
435 if (shape.IsTuple()) {
436 std::vector<HloInstruction*> elements;
437 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
438 auto subshape = ShapeUtil::GetTupleElementShape(shape, i);
439 auto element = state_.b->AddInstruction(
440 HloInstruction::CreateGetTupleElement(subshape, hlo(), i));
441 element->set_sharding(sharding().GetSubSharding(shape, {i}));
442 elements.push_back(
443 PartitionedHlo(
444 element, ShapeUtil::GetTupleElementShape(base_shape_, i), state_)
445 .Reshard(target.GetSubSharding(shape, {i}))
446 .hlo());
447 }
448 auto tuple =
449 state_.b->AddInstruction(HloInstruction::CreateTuple(elements));
450 tuple->set_sharding(target);
451 return PartitionedHlo(tuple, base_shape_, state_);
452 }
453
454 if (sharding() == target) {
455 return *this;
456 }
457
458 CHECK_EQ(target.IsManualSubgroup(), sharding().IsManualSubgroup());
459 if (sharding().IsManualSubgroup()) {
460 auto grouped = hlo_sharding_util::GetManualSubgroupSharding(sharding());
461 auto target_grouped = AlignGroupsWithIfCompatible(
462 hlo_sharding_util::GetManualSubgroupSharding(target), grouped);
463 CHECK(target_grouped.has_value())
464 << "Resharding target has incompatible sharding subgroups. From "
465 << sharding().ToString() << " to " << target.ToString();
466 HloSharding original_sharding = sharding();
467 hlo_->set_sharding(grouped.sharding);
468 HloInstruction* partitioned =
469 PartitionedHlo(hlo_, base_shape_,
470 CreatePerGroupPartitioningState(
471 state(), grouped.device_groups, state_.b))
472 .ReshardNoCache(target_grouped->sharding)
473 .hlo();
474 hlo_->set_sharding(original_sharding);
475 partitioned->set_sharding(target);
476 return PartitionedHlo(partitioned, base_shape_, state_);
477 }
478
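// Try resharding strategies roughly in order of increasing cost:
// collective-permute (pure device relabeling), all-to-all on a pair of
// dimensions, partial-replication-aware paths, and finally full replication
// followed by re-slicing as a fallback.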
479 if (CanReshardWithCollectivePermute(sharding(), target)) {
480 return ReshardWithCollectivePermute(target);
481 }
482
483 if (auto src_tgt_dims =
484 GetReshardAllToAllSourceTargetDims(sharding(), target)) {
485 return ReshardWithAllToAll(target, *src_tgt_dims);
486 }
487
488 if (!target.IsTileMaximal() && sharding().ReplicateOnLastTileDim()) {
489 auto try_reshard = ReshardFromPartialReplicateWithDynamicSlice(target);
490 if (try_reshard.has_value()) {
491 return try_reshard.value();
492 }
493 try_reshard = ReshardPartialReplicateWithAllToAll(target);
494 if (try_reshard.has_value()) {
495 return try_reshard.value();
496 }
497 }
498
499 if (!sharding().IsTileMaximal() && target.ReplicateOnLastTileDim()) {
500 auto try_reshard = ReshardToPartialReplicateWithAllGather(target);
501 if (try_reshard.has_value()) {
502 return try_reshard.value();
503 }
504 try_reshard = ReshardPartialReplicateWithAllToAll(target);
505 if (try_reshard.has_value()) {
506 return try_reshard.value();
507 }
508 }
509
510 // If not replicated yet, first replicate and then reshard to use one of the
511 // two implementations below.
512 if (!sharding().IsReplicated()) {
513 if (!target.IsReplicated()) {
514 auto reshard = TryComplexReshardHandling(target);
515 if (reshard.has_value()) {
516 return reshard.value();
517 }
518 if (!allow_full_replication) {
519 return *this;
520 }
521 LOG(ERROR)
522 << "[spmd] Involuntary full rematerialization. The compiled was "
523 "not able to go from sharding "
524 << sharding().ToString(/*include_metadata=*/true) << " to "
525 << target.ToString(/*include_metadata=*/true)
526 << " without doing a full rematerialization of the tensor. You "
527 "probably want to enrich the sharding annotations to prevent "
528 "this from happening.";
529 }
530 return Replicate().Reshard(target);
531 }
532
533 // 'Replicated' to 'SingleDevice'.
534 if (target.IsTileMaximal()) {
535 auto copy = state_.b->AddInstruction(
536 HloInstruction::CreateUnary(hlo_->shape(), HloOpcode::kCopy, hlo_));
537 copy->set_sharding(target);
538 return PartitionedHlo(copy, base_shape_, state_);
539 }
540
541 // 'Replicated' to partial replicated.
542 if (target.ReplicateOnLastTileDim()) {
543 std::vector<int64_t> group_dims(target.tile_assignment().num_dimensions() -
544 1);
545 std::iota(group_dims.begin(), group_dims.end(), 0);
546 auto target_grouped =
547 hlo_sharding_util::GroupShardingOnDims(target, group_dims);
548 auto partially_sharded = PerGroupSliceFromReplicated(
549 hlo_, state_.partition_id, target_grouped.device_groups, group_dims,
550 target_grouped.group_dim_sizes, state_.b);
551 partially_sharded->set_sharding(target);
552 return PartitionedHlo(partially_sharded, base_shape(), state_);
553 }
554
555 // 'Replicated' to 'Tiled'.
556 auto padded_hlo = PadBaseShapeBeforeUnevenTiledSharding(
557 hlo_, target, state_.b, std::move(pad_value));
558 auto shard_shape = MakePartitionedShape(shape, target);
559 auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
560 shard_shape, padded_hlo,
561 MakePartitionOffsets(shape, target, state_.partition_id, state_.b),
562 shard_shape.dimensions()));
563 slice->set_sharding(target);
564 return PartitionedHlo(slice, base_shape_, state_);
565 }
566
567 PartitionedHlo PartitionedHlo::PadWithValue(
568 HloInstruction* pad_value, absl::Span<const int64_t> left_padded_dims,
569 absl::Span<const int64_t> skipped_dims) const {
570 HloInstruction* result =
571 PadWithValueHlo(pad_value, left_padded_dims, skipped_dims);
572 if (hlo_ != result) {
573 result->set_sharding(hlo_->sharding());
574 }
575 return PartitionedHlo(result, base_shape_, state_);
576 }
577
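// Builds an HLO that replaces the out-of-range elements of an unevenly
// partitioned shard with pad_value, by comparing per-dimension iota+offset
// masks against the valid size of the base shape.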
578 HloInstruction* PartitionedHlo::PadWithValueHlo(
579 HloInstruction* pad_value, absl::Span<const int64_t> left_padded_dims,
580 absl::Span<const int64_t> skipped_dims) const {
581 const HloSharding& sharding = hlo_->sharding();
582 const Shape& shape = hlo_->shape();
583 CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
584 if (sharding.IsReplicated() || EvenlyPartitions(base_shape_, sharding)) {
585 return hlo_;
586 }
587 CHECK(!sharding.IsTileMaximal());
588 auto index_shape = ShapeUtil::ChangeElementType(shape, S32);
589 auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED);
590 auto get_mask_for_dim = [&](int64_t dim, HloInstruction* start_index) {
591 // Comparison: iota + start_index < valid_size
592 auto iota =
593 state_.b->AddInstruction(HloInstruction::CreateIota(index_shape, dim));
594 auto broadcast_start_index = state_.b->AddInstruction(
595 HloInstruction::CreateBroadcast(index_shape, start_index, {}));
596 auto index_in_full_shape =
597 state_.b->AddInstruction(HloInstruction::CreateBinary(
598 index_shape, HloOpcode::kAdd, iota, broadcast_start_index));
599 ComparisonDirection direction = ComparisonDirection::kLt;
600 int64_t index_limit = base_shape_.dimensions(dim);
601 if (absl::c_linear_search(left_padded_dims, dim)) {
602 direction = ComparisonDirection::kGe;
603 index_limit =
604 index_shape.dimensions(dim) * sharding.tile_assignment().dim(dim) -
605 index_limit;
606 }
607 auto limit = state_.b->AddInstruction(HloInstruction::CreateConstant(
608 LiteralUtil::CreateR0<int32_t>(index_limit)));
609 auto broadcast_limit = state_.b->AddInstruction(
610 HloInstruction::CreateBroadcast(index_shape, limit, {}));
611 return state_.b->AddInstruction(HloInstruction::CreateCompare(
612 mask_shape, index_in_full_shape, broadcast_limit, direction));
613 };
614
615 HloInstruction* mask = nullptr;
616 auto offsets = MakePartitionOffsets(base_shape_, sharding,
617 state_.partition_id, state_.b);
618 for (int64_t i = 0; i < shape.rank(); ++i) {
619 if (base_shape_.dimensions(i) % sharding.tile_assignment().dim(i) == 0 ||
620 absl::c_linear_search(skipped_dims, i)) {
621 continue;
622 }
623 if (mask == nullptr) {
624 mask = get_mask_for_dim(i, offsets[i]);
625 } else {
626 mask = state_.b->AddInstruction(
627 HloInstruction::CreateBinary(mask->shape(), HloOpcode::kAnd, mask,
628 get_mask_for_dim(i, offsets[i])));
629 }
630 }
631
632 if (mask == nullptr) {
633 return hlo_;
634 }
635
636 auto broadcast_pad_value = state_.b->AddInstruction(
637 HloInstruction::CreateBroadcast(shape, pad_value, {}));
638 return state_.b->AddInstruction(HloInstruction::CreateTernary(
639 shape, HloOpcode::kSelect, mask, hlo_, broadcast_pad_value));
640 }
641
642 PartitionedHlo PartitionedHlo::PadWithZero(
643 absl::Span<const int64_t> left_padded_dims,
644 absl::Span<const int64_t> skipped_dims) const {
645 auto zero = state_.b->AddInstruction(HloInstruction::CreateConstant(
646 LiteralUtil::Zero(hlo_->shape().element_type())));
647 return PadWithValue(zero, left_padded_dims, skipped_dims);
648 }
649
650 std::optional<PartitionedHlo::WindowedInputShardReturnValue>
651 PartitionedHlo::ReshardAsWindowedInput(const Window& window,
652 const HloSharding& target,
653 HloInstruction* pad_value,
654 bool mask_invalid_region) {
655 auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].window_reshard_cache;
656 for (auto& entry : cache) {
657 if (std::get<0>(entry) == target &&
658 protobuf_util::ProtobufEquals(std::get<1>(entry), window)) {
659 return std::get<2>(entry);
660 }
661 }
662 auto update_cache = [&](WindowedInputShardReturnValue result) {
663 cache.emplace_back(target, window, std::move(result));
664 return std::get<2>(cache.back());
665 };
666 VLOG(2) << "ReshardAsWindowedInput()\n"
667 << "\twindow:" << window_util::ToString(window)
668 << "\ttarget sharding:" << target.ToString();
669
670 CHECK(!target.IsTileMaximal());
671 auto partition_ordinals =
672 MakeTiledPartitionOrdinals(target, state_.partition_id, state_.b);
673 auto shard_shape = base_shape_;
674
675 std::vector<MultiplyAddDivideOffsetCalculation> start_on_padded_calculations(
676 base_shape_.rank());
677 std::vector<MultiplyAddDivideOffsetCalculation> limit_on_padded_calculations(
678 base_shape_.rank());
679 std::vector<HloInstruction*> dynamic_slice_offset_on_output(
680 base_shape_.rank(), nullptr);
681
682 Window shard_window = window;
683 Shape padded_shape = base_shape_;
684 std::vector<HloInstruction*> offsets_on_padded_shape(base_shape_.rank());
685 std::vector<int64_t> per_shard_window_counts(base_shape_.rank());
686 std::vector<int64_t> explicit_left_padding(base_shape_.rank());
687 // Track if any shards can be skipped.
688 std::vector<int64_t> trimmed_target_sharding_tile_shape(base_shape_.rank());
689 // There can be at most 2 ranges of skipped shards on a dimension: 1) on the
690 // right side, 2) in the middle. The following vector tracks the middle range
691 // (i.e., <start, size>). The leftmost shard must not be skipped because
692 // outputs are left-aligned.
693 std::vector<std::pair<int64_t, int64_t>> trimmed_target_sharding_middle_range(
694 base_shape_.rank(), std::pair<int64_t, int64_t>(-1, -1));
695 bool trimmed_shards = false;
696 std::vector<int64_t> dims_needs_pre_masking;
697 Shape halo_exchange_base_shape = base_shape_;
698 // If all useful input data are in a single shard, we can skip in-shard data
699 // (e.g., those that belong to negative padding) via a local slice.
700 bool trimmed_in_shard = false;
701 std::vector<int64_t> pre_halo_exchange_slice_starts(base_shape_.rank(), 0);
702 std::vector<int64_t> pre_halo_exchange_slice_limits(
703 hlo_->shape().dimensions().begin(), hlo_->shape().dimensions().end());
704 std::vector<bool> can_leave_dimension_partitioned(base_shape_.rank(), false);
705 for (int64_t i = 0; i < base_shape_.rank(); ++i) {
706 can_leave_dimension_partitioned[i] =
707 window_util::IsTrivialWindowDimension(window.dimensions(i));
708 }
709 for (int64_t i = 0; i < base_shape_.rank(); ++i) {
710 // Do not pad non-partitioned dimensions.
711 int64_t shard_count = target.tile_assignment().dim(i);
712 trimmed_target_sharding_tile_shape[i] = shard_count;
713 if (shard_count == 1 || can_leave_dimension_partitioned[i]) {
714 offsets_on_padded_shape[i] = state_.b->AddInstruction(
715 HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
716 shard_shape.set_dimensions(
717 i, CeilOfRatio(base_shape_.dimensions(i), shard_count));
718 continue;
719 }
720 const WindowDimension& wd = window.dimensions(i);
721 WindowDimension* swd = shard_window.mutable_dimensions(i);
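// Window geometry on this dimension: dilated_size is the span of one window,
// full_size is the padded and base-dilated input extent, and window_count is
// the number of window positions (output elements); each shard is then
// responsible for CeilOfRatio(window_count, shard_count) windows.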
722 const int64_t dilated_size = 1 + (wd.size() - 1) * wd.window_dilation();
723 const int64_t full_size =
724 1 + (base_shape_.dimensions(i) - 1) * wd.base_dilation() +
725 wd.padding_high() + wd.padding_low();
726 int64_t window_count = (full_size - dilated_size) / wd.stride() + 1;
727 per_shard_window_counts[i] = CeilOfRatio(window_count, shard_count);
728 // Find skippable shards on the right side. This could only happen when
729 // window_count < shard_count so that the right-most shard does not have any
730 // output.
731 int64_t input_shard_size = hlo_->shape().dimensions(i);
732 if (window_count < shard_count && wd.window_dilation() == 1 &&
733 wd.base_dilation() == 1) {
734 // Test if some shards do not have any useful input (all uneven padding or
735 // window negative padding).
736 int64_t useful_input_shards = CeilOfRatio(
737 base_shape_.dimensions(i) + wd.padding_high(), input_shard_size);
738 if (useful_input_shards < shard_count) {
739 shard_count = std::max<int64_t>(useful_input_shards, window_count);
740 trimmed_shards = true;
741 trimmed_target_sharding_tile_shape[i] = shard_count;
742 if (shard_count == 1) {
743 offsets_on_padded_shape[i] = state_.b->AddInstruction(
744 HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
745 swd->set_padding_high(base_shape_.dimensions(i) + wd.padding_high() -
746 hlo_->shape().dimensions(i));
747 continue;
748 }
749 // Make sure the halo-exchange base shape is evenly sharded on the new
750 // shard count.
751 halo_exchange_base_shape.set_dimensions(i,
752 input_shard_size * shard_count);
753 if (input_shard_size * shard_count > base_shape_.dimensions(i) &&
754 wd.padding_high() > 0) {
755 // The new shape has padding, so make sure it's masked.
756 dims_needs_pre_masking.push_back(i);
757 } else if (wd.padding_high() < 0 &&
758 full_size - wd.padding_low() < input_shard_size) {
759 // If the useful input is smaller than a shard, we treat the shard
760 // size as the useful input size and slice later.
761 input_shard_size = full_size - wd.padding_low();
762 halo_exchange_base_shape.set_dimensions(
763 i, input_shard_size * shard_count);
764 pre_halo_exchange_slice_limits[i] = input_shard_size;
765 trimmed_in_shard = true;
766 }
767 }
768 }
769
770 // We use explicit padding for full dilations, then use padding_low and
771 // padding_high on the sharded op for the remaining. padding_low and
772 // padding_high are now given initial values, which will be later updated if
773 // dilation is not 1.
774 explicit_left_padding[i] = wd.padding_low() / wd.base_dilation();
775 swd->set_padding_low(wd.padding_low() % wd.base_dilation());
776 swd->set_padding_high(0);
777
778 // Find potential skippable range in the middle. This could happen when only
779 // a few shards have outputs (window_count < shard_count), but there is a
780 // large negative left padding such that the start shard that has useful
781 // input does not have any outputs.
782 if (window_count < shard_count && wd.window_dilation() == 1 &&
783 wd.base_dilation() == 1) {
784 int64_t middle_empty_shards =
785 (-explicit_left_padding[i]) / input_shard_size - window_count;
786 if (middle_empty_shards > 0) {
787 shard_count -= middle_empty_shards;
788 CHECK_GT(shard_count, 1);
789 trimmed_target_sharding_middle_range[i].first = window_count;
790 trimmed_target_sharding_middle_range[i].second = middle_empty_shards;
791 trimmed_shards = true;
792 trimmed_target_sharding_tile_shape[i] = shard_count;
793 // Reduce negative padding.
794 explicit_left_padding[i] += middle_empty_shards * input_shard_size;
795 halo_exchange_base_shape.set_dimensions(i,
796 input_shard_size * shard_count);
797 HloInstruction* ordinal = partition_ordinals[i];
798 HloInstruction* left_count = CreateR0WithType<int32_t>(
799 ordinal->shape().element_type(), window_count, state_.b);
800 HloInstruction* on_left =
801 state_.b->AddInstruction(HloInstruction::CreateCompare(
802 ShapeUtil::ChangeElementType(ordinal->shape(), PRED), ordinal,
803 left_count, ComparisonDirection::kLt));
804 HloInstruction* right_ordinal =
805 state_.b->AddInstruction(HloInstruction::CreateBinary(
806 ordinal->shape(), HloOpcode::kSubtract, ordinal, left_count));
807 partition_ordinals[i] =
808 state_.b->AddInstruction(HloInstruction::CreateTernary(
809 partition_ordinals[i]->shape(), HloOpcode::kSelect, on_left,
810 partition_ordinals[i], right_ordinal));
811 if (-explicit_left_padding[i] > input_shard_size * (shard_count - 1)) {
812 // If all useful data is on the last shard, we can skip extra negative
813 // left padding.
814 int64_t skip_amount =
815 -explicit_left_padding[i] - input_shard_size * (shard_count - 1);
816 input_shard_size -= skip_amount;
817 explicit_left_padding[i] += skip_amount * shard_count;
818 pre_halo_exchange_slice_starts[i] = skip_amount;
819 trimmed_in_shard = true;
820 // We may have enabled a new skipping opportunity on the right side
821 // within the only shard that has useful input, because we excluded
822 // negative left padding regions this time.
823 if (full_size < input_shard_size) {
824 skip_amount = input_shard_size - full_size;
825 pre_halo_exchange_slice_limits[i] -= skip_amount;
826 explicit_left_padding[i] += skip_amount * (shard_count - 1);
827 input_shard_size = full_size;
828 }
829 halo_exchange_base_shape.set_dimensions(
830 i, input_shard_size * shard_count);
831 }
832 }
833 }
834 if (full_size < dilated_size) {
835 VLOG(2) << "Failed to reshard window operand because the window size is "
836 "larger than padded base size";
837 return std::nullopt;
838 }
839 if (wd.stride() != 1 &&
840 (wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() != 0) {
841 // TODO(yuanzx): Support this case.
842 VLOG(2) << "Failed to reshard window operand due to non-trivial dilation";
843 return std::nullopt;
844 }
845
846 // Calculation for the first element needed on the 'padded-but-not-dilated'
847 // shape. The start on the dilated shape could be a hole, so we add
848 // wd.base_dilation() - 1 to the constant term to skip the leading holes.
849 start_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation(
850 wd.stride() * per_shard_window_counts[i],
851 wd.base_dilation() - 1 - swd->padding_low(), wd.base_dilation());
852 int64_t dilated_shard_size =
853 wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
854 limit_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation(
855 wd.stride() * per_shard_window_counts[i],
856 dilated_shard_size + wd.base_dilation() - 1 - swd->padding_low(),
857 wd.base_dilation());
858
859 offsets_on_padded_shape[i] = start_on_padded_calculations[i].Calculate(
860 partition_ordinals[i], state_.b);
861
862 auto shard_size_function =
863 limit_on_padded_calculations[i] - start_on_padded_calculations[i];
864 int64_t max_shard_size = shard_size_function.MaxInRange(0, shard_count);
865 shard_shape.set_dimensions(i, max_shard_size);
866 padded_shape.set_dimensions(
867 i, limit_on_padded_calculations[i].Calculate(shard_count - 1));
868
869 // For base dilation, calculate the needed padding_low and padding_high, as
870 // well as the offset for the output if a dynamic slice is needed after the
871 // sharded op.
872 if (wd.base_dilation() != 1) {
873 // Returns the offset of a shard's first valid element in the dilated
874 // shard.
875 auto get_first_valid_element_offset_on_dilated_shard =
876 [&](int64_t shard_ordinal) {
877 return start_on_padded_calculations[i].Calculate(shard_ordinal) *
878 wd.base_dilation() +
879 swd->padding_low() -
880 wd.stride() * per_shard_window_counts[i] * shard_ordinal;
881 };
882 CHECK_EQ(get_first_valid_element_offset_on_dilated_shard(0),
883 swd->padding_low());
884
885 // Determine swd->padding_high.
886 for (int64_t shard_ordinal = 0; shard_ordinal < shard_count;
887 ++shard_ordinal) {
888 int64_t wanted_limit_on_dilated_shard =
889 wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
890 int64_t actual_limit_on_dilated_shard_without_pad_high =
891 get_first_valid_element_offset_on_dilated_shard(shard_ordinal) +
892 (max_shard_size - 1) * wd.base_dilation() + 1;
893 swd->set_padding_high(std::max<int64_t>(
894 swd->padding_high(),
895 wanted_limit_on_dilated_shard -
896 actual_limit_on_dilated_shard_without_pad_high));
897 }
898
899 // Determine swd->padding_low and output dynamic slice index.
900 if (wd.stride() == 1) {
901 int64_t max_pad_low =
902 get_first_valid_element_offset_on_dilated_shard(0);
903 bool all_same = true;
904 for (int64_t shard_ordinal = 1; shard_ordinal < shard_count;
905 ++shard_ordinal) {
906 int64_t start =
907 get_first_valid_element_offset_on_dilated_shard(shard_ordinal);
908 if (start != swd->padding_low()) {
909 all_same = false;
910 }
911 max_pad_low = std::max(max_pad_low, start);
912 }
913 if (!all_same) {
914 auto start_on_padded_input =
915 start_on_padded_calculations[i].Calculate(partition_ordinals[i],
916 state_.b);
917 // We will calculate
918 // max_pad_low - (first_window - required_first_window)
919 // which equals
920 // required_first_window - (first_window - max_pad_low)
921 auto first_window_minus_max_pad_low =
922 MultiplyAddDivideOffsetCalculation(
923 wd.base_dilation(), swd->padding_low() - max_pad_low, 1)
924 .Calculate(start_on_padded_input, state_.b);
925 auto required_first_window =
926 MultiplyAddDivideOffsetCalculation(per_shard_window_counts[i], 0,
927 1)
928 .Calculate(partition_ordinals[i], state_.b);
929 dynamic_slice_offset_on_output[i] =
930 state_.b->AddInstruction(HloInstruction::CreateBinary(
931 required_first_window->shape(), HloOpcode::kSubtract,
932 required_first_window, first_window_minus_max_pad_low));
933 }
934 swd->set_padding_low(max_pad_low);
935 } else {
936 if ((wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() !=
937 0) {
938 // General base dilation not yet implemented.
939 return std::nullopt;
940 }
941 // padding_low on all shards should equal the initially assigned
942 // swd->padding_low(), i.e., the padding_low() on the original window.
943 }
944 }
945 }
946
947 // Returns the output dynamic slice offset when needed, and std::nullopt
948 // otherwise.
949 auto get_dynamic_slice_offset_on_output_if_needed =
950 [&]() -> std::optional<std::vector<HloInstruction*>> {
951 if (absl::c_all_of(
952 dynamic_slice_offset_on_output,
953 [](HloInstruction* offset) { return offset == nullptr; })) {
954 return std::nullopt;
955 }
956 auto zero = state_.b->AddInstruction(
957 HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
958 for (int64_t i = 0; i < dynamic_slice_offset_on_output.size(); ++i) {
959 if (dynamic_slice_offset_on_output[i] == nullptr) {
960 dynamic_slice_offset_on_output[i] = zero;
961 }
962 }
963 return dynamic_slice_offset_on_output;
964 };
965
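// Fallback path when the input is (or can be treated as) replicated on all
// windowed dimensions: explicitly pad the full input and dynamic-slice each
// shard's window region, instead of doing halo exchanges.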
966 auto handle_all_windowed_dimensions_are_replicated = [&]() {
967 PaddingConfig padding_config;
968 auto pad_hlo_shape = padded_shape;
969 for (int64_t i = 0; i < base_shape_.rank(); ++i) {
970 auto padding_config_dim = padding_config.add_dimensions();
971 padding_config_dim->set_interior_padding(0);
972 // Do not pad non-partitioned dimensions or partitioned dimensions that
973 // are already sharded in a way where the windowed sharding matches
974 // the sharding we want.
975 if (target.tile_assignment().dim(i) == 1 ||
976 can_leave_dimension_partitioned[i]) {
977 padding_config_dim->set_edge_padding_low(0);
978 padding_config_dim->set_edge_padding_high(0);
979 pad_hlo_shape.set_dimensions(i, hlo_->shape().dimensions(i));
980 continue;
981 }
982 padding_config_dim->set_edge_padding_low(explicit_left_padding[i]);
983 padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) -
984 explicit_left_padding[i] -
985 base_shape_.dimensions(i));
986 }
987 auto padded_hlo =
988 ShapeUtil::Compatible(pad_hlo_shape, base_shape_)
989 ? hlo_
990 : state_.b->AddInstruction(HloInstruction::CreatePad(
991 pad_hlo_shape, hlo_, pad_value, padding_config));
992 auto sharded_input =
993 state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
994 shard_shape, padded_hlo, offsets_on_padded_shape,
995 shard_shape.dimensions()));
996 return update_cache(WindowedInputShardReturnValue{
997 sharded_input, shard_window,
998 get_dynamic_slice_offset_on_output_if_needed()});
999 };
1000
1001 auto sharding_with_non_windowed_dims_replicated =
1002 GetShardingReplicatedOnWindowedDimension(target, window);
1003 // If the current HLO is replicated or all windowed dimensions are replicated,
1004 // pad then slice. If the target sharding and current sharding are not the
1005 // same then give the halo exchange system a chance to run as it can skip
1006 // generating a dynamic slice.
1007 if (sharding().IsReplicated() ||
1008 (target != sharding() &&
1009 sharding_with_non_windowed_dims_replicated == sharding())) {
1010 return handle_all_windowed_dimensions_are_replicated();
1011 }
1012 if (target != sharding() &&
1013 sharding_with_non_windowed_dims_replicated != sharding()) {
1014 return Reshard(target).ReshardAsWindowedInput(window, target, pad_value);
1015 }
1016 if (Product(trimmed_target_sharding_tile_shape) == 1) {
1017 // The trimmed sharding may have just one shard left. We can simply return
1018 // hlo_ in this case.
1019 return update_cache(WindowedInputShardReturnValue{
1020 hlo_, shard_window, get_dynamic_slice_offset_on_output_if_needed()});
1021 }
1022 if (target.ReplicateOnLastTileDim()) {
1023 trimmed_target_sharding_tile_shape.push_back(
1024 target.tile_assignment().dimensions().back());
1025 }
1026 std::optional<HloSharding> trimmed_target;
1027 const HloSharding* halo_exchange_target = ⌖
1028 if (trimmed_shards) {
1029 // Remove devices on the right side.
1030 Array<int64_t> trimmed_devices(trimmed_target_sharding_tile_shape);
1031 trimmed_devices.Each([&](absl::Span<const int64_t> indices, int64_t* d) {
1032 std::vector<int64_t> target_indices(indices.begin(), indices.end());
1033 for (int64_t i = 0; i < base_shape_.rank(); ++i) {
1034 const auto& range = trimmed_target_sharding_middle_range[i];
1035 if (range.first >= 0 && indices[i] >= range.first) {
1036 target_indices[i] += range.second;
1037 }
1038 }
1039 *d = target.tile_assignment()(target_indices);
1040 });
1041 trimmed_target = target.ReplicateOnLastTileDim()
1042 ? HloSharding::PartialTile(trimmed_devices)
1043 : HloSharding::Tile(trimmed_devices);
1044 halo_exchange_target = &*trimmed_target;
1045 }
1046
1047 // Halo exchange.
1048 HloInstruction* visiting_hlo = hlo_;
1049
1050 if (!dims_needs_pre_masking.empty()) {
1051 std::vector<int64_t> skipped_dims;
1052 for (int dim = 0; dim < base_shape_.rank(); ++dim) {
1053 if (!absl::c_linear_search(dims_needs_pre_masking, dim)) {
1054 skipped_dims.push_back(dim);
1055 }
1056 }
1057 visiting_hlo = PadWithValueHlo(pad_value, /*left_padded_dims=*/{},
1058 /*skipped_dims=*/skipped_dims);
1059 }
1060
1061 // If we skipped unused data within a shard, we need to slice the input shard.
1062 if (trimmed_in_shard) {
1063 std::vector<int64_t> slice_sizes(halo_exchange_base_shape.rank());
1064 for (int64_t i = 0; i < slice_sizes.size(); ++i) {
1065 slice_sizes[i] =
1066 pre_halo_exchange_slice_limits[i] - pre_halo_exchange_slice_starts[i];
1067 }
1068 visiting_hlo = state_.b->AddInstruction(HloInstruction::CreateSlice(
1069 ShapeUtil::MakeShape(halo_exchange_base_shape.element_type(),
1070 slice_sizes),
1071 visiting_hlo,
1072 /*start_indices=*/pre_halo_exchange_slice_starts,
1073 /*limit_indices=*/pre_halo_exchange_slice_limits,
1074 /*strides=*/
1075 std::vector<int64_t>(halo_exchange_base_shape.rank(), 1)));
1076 }
1077
1078 std::vector<OffsetCalculation> left_halo_size_functions(base_shape_.rank());
1079 std::vector<OffsetCalculation> right_halo_size_functions(base_shape_.rank());
1080 // TODO(yuanzx): We are concatenating on each sharded dimension one at a time,
1081 // and in the second dimension (and beyond) we create halos by slicing the
1082 // concat in the previous dimension, which is not optimal. We should generate
1083 // halos by concatenating slices, instead of slicing concats.
1084 for (int dim = 0; dim < base_shape_.rank(); ++dim) {
1085 int64_t shard_count = halo_exchange_target->tile_assignment().dim(dim);
1086 if (shard_count == 1 || can_leave_dimension_partitioned[dim]) {
1087 continue;
1088 }
1089 int64_t input_shard_size =
1090 CeilOfRatio(halo_exchange_base_shape.dimensions(dim), shard_count);
1091
1092 // Left halo. The size of the halo is derived by subtracting the first read
1093 // element offset of the i'th partition from the limit of the (i-1)'th
1094 // partition.
1095 MultiplyAddDivideOffsetCalculation shard_limit_of_previous_on_padded(
1096 input_shard_size, explicit_left_padding[dim], 1);
1097 left_halo_size_functions[dim] =
1098 shard_limit_of_previous_on_padded - start_on_padded_calculations[dim];
1099
1100 // Right halo.
1101 MultiplyAddDivideOffsetCalculation shard_start_of_next_on_padded(
1102 input_shard_size, input_shard_size + explicit_left_padding[dim], 1);
1103 right_halo_size_functions[dim] =
1104 limit_on_padded_calculations[dim] - shard_start_of_next_on_padded;
1105
1106 auto resharded = ExchangeHaloAndGetValidData(
1107 visiting_hlo, halo_exchange_base_shape, left_halo_size_functions[dim],
1108 right_halo_size_functions[dim], explicit_left_padding[dim],
1109 padded_shape.dimensions(dim), shard_shape.dimensions(dim), dim,
1110 *halo_exchange_target, offsets_on_padded_shape[dim], pad_value,
1111 partition_ordinals[dim], state_.collective_ops_creator,
1112 state_.next_channel_id, state_.b, mask_invalid_region);
1113 if (!resharded) {
1114 VLOG(1) << "ReshardAsWindowedInput failed without replicate first: halo "
1115 "is beyond the neighbor.";
1116 // If we are already sharded in such a way that all windowed dimensions
1117 // are replicated then just handle it with pad + slice.
1118 if (sharding_with_non_windowed_dims_replicated == sharding()) {
1119 return handle_all_windowed_dimensions_are_replicated();
1120 }
1121 return Reshard(sharding_with_non_windowed_dims_replicated)
1122 .ReshardAsWindowedInput(window, target, pad_value);
1123 }
1124 visiting_hlo = *resharded;
1125 }
1126 return update_cache(WindowedInputShardReturnValue{
1127 visiting_hlo, shard_window,
1128 get_dynamic_slice_offset_on_output_if_needed()});
1129 }
1130
1131 PartitionedHlo PartitionedHlo::Replicate() {
1132 auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
1133 if (state_.partitioner->options().cache_all_gather) {
1134 for (auto& entry : cache) {
1135 if (entry.first.IsReplicated()) {
1136 return entry.second;
1137 }
1138 }
1139 }
1140 // Do not use a reference as the HLO's sharding can be temporarily replaced.
1141 const HloSharding sharding = hlo_->sharding();
1142 const Shape& shape = hlo_->shape();
1143 CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
1144
1145 if (sharding.IsReplicated()) {
1146 return *this;
1147 }
1148 for (auto& entry : cache) {
1149 if (entry.first.IsReplicated()) {
1150 return entry.second;
1151 }
1152 }
1153 auto update_cache = [&](PartitionedHlo resharded) {
1154 state_.reshard_cache->per_hlo_cache[resharded.hlo()]
1155 .reshard_cache.insert_or_assign(sharding, *this);
1156 // Get the cache again as it might be invalidated by the insertion above.
1157 auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
1158 if (state_.partitioner->options().cache_all_gather) {
1159 auto [it, _] = cache.insert_or_assign(HloSharding::Replicate(),
1160 std::move(resharded));
1161 return it->second;
1162 }
1163 return resharded;
1164 };
1165 // 'Single Device' to 'Replicated'.
1166 if (sharding.IsTileMaximal()) {
1167 return update_cache(Broadcast());
1168 }
1169
1170 // 'Tiled' to 'Replicated'.
1171 std::vector<int64_t> all_dims(shape.rank());
1172 std::iota(all_dims.begin(), all_dims.end(), 0);
1173 HloInstruction* result = ReplicatePartial(all_dims);
1174 result->set_sharding(HloSharding::Replicate());
1175 return update_cache(PartitionedHlo(result, base_shape_, state_));
1176 }
1177
1178 HloInstruction* PartitionedHlo::ReplicatePartial(
1179 absl::Span<const int64_t> dims) {
1180 CHECK(!sharding().IsTileMaximal());
1181 const Shape& shard_shape = hlo()->shape();
1182 Shape target_shape = shard_shape;
1183 Shape padded_target_shape = shard_shape;
1184 std::vector<int64_t> broadcast_dims;
1185 std::vector<int64_t> dus_ar_dims;
1186 std::vector<int64_t> ag_dims;
1187 // Find dimensions that can be replicated with Broadcast() (shard size 1) and
1188 // others that need all-gather. dus_ar_dims is a generalization of
1189 // broadcast_dims where the full size is less than half of allgather size, and
1190 // we will use dus->allreduce on them.
1191 for (int64_t i : dims) {
1192 int64_t partitions = sharding().tile_assignment().dim(i);
1193 if (partitions == 1) {
1194 continue;
1195 }
1196 target_shape.set_dimensions(i, base_shape().dimensions(i));
1197 if (target_shape.dimensions(i) == shard_shape.dimensions(i)) {
1198 broadcast_dims.push_back(i);
1199 } else if (target_shape.dimensions(i) <= partitions / 2) {
1200 dus_ar_dims.push_back(i);
1201 } else {
1202 padded_target_shape.set_dimensions(
1203 i, shard_shape.dimensions(i) * partitions);
1204 ag_dims.push_back(i);
1205 }
1206 }
1207
1208 HloInstruction* broadcast = hlo_;
1209 if (!broadcast_dims.empty()) {
1210 std::vector<int64_t> other_dims;
1211 for (int64_t i = 0; i < sharding().tile_assignment().num_dimensions();
1212 ++i) {
1213 if (!absl::c_linear_search(broadcast_dims, i)) {
1214 other_dims.push_back(i);
1215 }
1216 }
1217 HloSharding original_sharding = sharding();
1218 auto grouped =
1219 hlo_sharding_util::GroupShardingOnDims(original_sharding, other_dims);
1220 std::vector<int64_t> dev_indices(
1221 grouped.sharding.tile_assignment().num_dimensions(), 0);
1222 hlo_->set_sharding(HloSharding::AssignDevice(
1223 grouped.sharding.tile_assignment()(dev_indices)));
1224 auto per_group_partitioner_state = CreatePerGroupPartitioningState(
1225 state(), grouped.device_groups, state().b);
1226 auto partial_replicate_hlo =
1227 PartitionedHlo(hlo_, shard_shape, per_group_partitioner_state)
1228 .Broadcast();
1229 hlo_->set_sharding(original_sharding);
1230 partial_replicate_hlo.hlo()->clear_sharding();
1231 broadcast = partial_replicate_hlo.hlo();
1232 }
1233
1234 if (ag_dims.empty() && dus_ar_dims.empty()) {
1235 return broadcast;
1236 }
1237
1238 HloInstruction* result = nullptr;
1239 if (state_.collective_ops_creator.create_cross_partition_all_gather) {
1240 result = state_.partitioner->AllGatherShards(
1241 state_.b, broadcast, sharding(), state_.next_channel_id, ag_dims,
1242 state_.collective_ops_creator);
1243 }
1244 // dus_ar_dims may also end up containing dims where all-gather failed.
1245 if (result == nullptr) {
1246 for (int64_t dim : ag_dims) {
1247 dus_ar_dims.push_back(dim);
1248 }
1249 result = broadcast;
1250 }
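// For dus_ar_dims, reconstruct the full data without an all-gather: each
// partition dynamic-update-slices its shard into a zero tensor at its own
// offset, and an all-reduce (add) across those dims sums the contributions.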
1251 if (!dus_ar_dims.empty()) {
1252 auto zero = state_.b->AddInstruction(HloInstruction::CreateConstant(
1253 LiteralUtil::Zero(shard_shape.element_type())));
1254 std::vector<int64_t> masking_dims;
1255 for (int64_t dim : dus_ar_dims) {
1256 int64_t partitions = sharding().tile_assignment().dim(dim);
1257 if (base_shape().dimensions(dim) < partitions) {
1258 // DUS will be out-of-bound and offset will be clamped, so we need to
1259 // mask this dim with 0.
1260 masking_dims.push_back(dim);
1261 // Adjust the padded dim size. These can also be failed all-gather dims.
1262 padded_target_shape.set_dimensions(dim, base_shape().dimensions(dim));
1263 }
1264 }
1265 if (!masking_dims.empty()) {
1266 std::vector<int64_t> skipped_dims;
1267 for (int64_t i = 0; i < base_shape().rank(); ++i) {
1268 if (!absl::c_linear_search(masking_dims, i)) {
1269 skipped_dims.push_back(i);
1270 }
1271 }
1272 result->set_sharding(sharding());
1273 result = PartitionedHlo(result, padded_target_shape, state_)
1274 .PadWithValue(zero,
1275 /*left_padded_dims=*/{},
1276 /*skipped_dims=*/skipped_dims)
1277 .hlo();
1278 }
1279 auto zero_bcast = state_.b->AddInstruction(
1280 HloInstruction::CreateBroadcast(padded_target_shape, zero, {}));
1281 auto offsets = MakePartitionOffsets(
1282 padded_target_shape,
1283 hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept(
1284 sharding(), dus_ar_dims),
1285 state_.partition_id, state_.b, dus_ar_dims);
1286 auto dus =
1287 state_.b->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1288 padded_target_shape, zero_bcast, result, offsets));
1289 HloComputation* reduction =
1290 MakeBinaryAdd(shard_shape.element_type(), state_.module);
1291 result = state_.partitioner->AllReduceAlongShardingDims(
1292 state_.b, dus, sharding(), state_.next_channel_id, dus_ar_dims,
1293 state_.collective_ops_creator, reduction);
1294 }
1295 if (!ShapeUtil::Compatible(target_shape, padded_target_shape)) {
1296 std::vector<int64_t> start_indices(target_shape.rank(), 0);
1297 std::vector<int64_t> strides(target_shape.rank(), 1);
1298 result = state_.b->AddInstruction(
1299 HloInstruction::CreateSlice(target_shape, result, start_indices,
1300 target_shape.dimensions(), strides));
1301 }
1302 return result;
1303 }
1304
1305 std::optional<PartitionedHlo>
1306 PartitionedHlo::ReshardToPartialReplicateWithAllGather(
1307 const HloSharding& target) {
1308 if (!target.ReplicateOnLastTileDim()) {
1309 return std::nullopt;
1310 }
1311 // Tiled/partial replicate to partial replicate.
1312 // Get a sharding compatible with the target for resharding via replication.
1313 auto compatible_sharding =
1314 PartialReplicateReshardCompatibleSharding(target, sharding());
1315 if (!compatible_sharding.has_value()) {
1316 return std::nullopt;
1317 }
1318
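// Overall: (1) collective-permute to a device assignment compatible with the
// target, (2) halo-exchange so shard boundaries line up with the target
// shards, then (3) replicate within each group of devices that share a
// target shard.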
1319 const auto& temp_sharding = compatible_sharding.value();
1320 auto partitioned_hlo = *this;
1321 // Use collective permute to adjust device assignment if needed.
1322 if (CanReshardWithCollectivePermute(sharding(), temp_sharding)) {
1323 partitioned_hlo =
1324 partitioned_hlo.ReshardWithCollectivePermute(temp_sharding);
1325 }
1326
1327 // Get replicate dims and replicate factor of each dimensions.
1328 int64_t rank = hlo_->shape().rank();
1329 std::vector<int64_t> replicate_dims;
1330 std::vector<int64_t> replicate_factors;
1331 for (int64_t dim = 0; dim < rank; dim++) {
1332 int64_t replicate_factor = temp_sharding.tile_assignment().dim(dim) /
1333 target.tile_assignment().dim(dim);
1334 if (replicate_factor > 1) {
1335 replicate_dims.emplace_back(dim);
1336 replicate_factors.emplace_back(replicate_factor);
1337 }
1338 }
1339
1340 // Do left halo exchange if all-reduce directly will remove useful data
1341 // from the source.
1342 auto halo_exchange = TileToPartialReplicateHaloExchange(
1343 partitioned_hlo.hlo_, base_shape_, temp_sharding, target, replicate_dims,
1344 partitioned_hlo.state().collective_ops_creator,
1345 partitioned_hlo.state().next_channel_id,
1346 partitioned_hlo.state().partition_id, partitioned_hlo.state().b);
1347 if (!halo_exchange.has_value()) {
1348 return std::nullopt;
1349 }
1350 auto halo_exchange_hlo = halo_exchange.value();
1351 // Grouped on replicate dimensions.
1352 auto sharding_grouped = hlo_sharding_util::GroupShardingOnDims(
1353 temp_sharding, replicate_dims, replicate_factors);
1354 auto per_group_partitioner_state = CreatePerGroupPartitioningState(
1355 partitioned_hlo.state(), sharding_grouped.device_groups,
1356 partitioned_hlo.state().b);
1357 auto base_shape = MakePartitionedShape(base_shape_, target);
1358 // It's possible that halo_exchange_hlo == hlo.hlo().
1359 // Record the sharding of hlo here, and reset it before return.
1360 auto original_sharding = partitioned_hlo.sharding();
1361 halo_exchange_hlo->set_sharding(sharding_grouped.sharding);
1362 auto partial_replicate_hlo = PartitionedHlo(halo_exchange_hlo, base_shape,
1363 per_group_partitioner_state);
1364 HloInstruction* result =
1365 partial_replicate_hlo.ReplicatePartial(replicate_dims);
1366 partitioned_hlo.hlo()->set_sharding(original_sharding);
1367 result->set_sharding(target);
1368 return PartitionedHlo(result, base_shape_, partitioned_hlo.state());
1369 }
1370
1371 std::optional<PartitionedHlo>
1372 PartitionedHlo::ReshardFromPartialReplicateWithDynamicSlice(
1373 const HloSharding& target) {
1374 if (!sharding().ReplicateOnLastTileDim()) {
1375 return std::nullopt;
1376 }
1377
1378   // Get the temporary sharding target for going from partial replicate to the
1379   // target tile dims. target_compatible_sharding has the same tile_assignment
1380   // dimensions as the target and can be resharded to the target by collective
1381   // permute. target_compatible_sharding may have a different device assignment
1382   // than the target. sharding() can be resharded to target_compatible_sharding
1383   // by dynamic slice.
1384 auto target_compatible_sharding =
1385 PartialReplicateReshardCompatibleSharding(sharding(), target);
1386 // Reshard to target_compatible_sharding by dynamic slice.
1387 if (!target_compatible_sharding.has_value()) {
1388 return std::nullopt;
1389 }
1390 std::vector<int64_t> expand_tile_dims;
1391 std::vector<int64_t> tiling_dim_factors;
1392 int64_t rank = hlo_->shape().rank();
1393 tiling_dim_factors.reserve(target.tile_assignment().num_dimensions());
1394 const auto& temp_target_sharding = target_compatible_sharding.value();
1395 for (int64_t dim = 0; dim < rank; dim++) {
1396 if (temp_target_sharding.tile_assignment().dim(dim) >
1397 sharding().tile_assignment().dim(dim)) {
1398 expand_tile_dims.push_back(dim);
1399 }
1400 tiling_dim_factors.emplace_back(
1401 temp_target_sharding.tile_assignment().dim(dim) /
1402 sharding().tile_assignment().dim(dim));
1403 }
1404
1405   // Append a factor to tiling_dim_factors if the target is partially replicated.
1406 if (target.ReplicateOnLastTileDim()) {
1407 tiling_dim_factors.emplace_back(
1408 target.tile_assignment().dimensions().back());
1409 }
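  // Example (hypothetical shardings): going from
  // sharding={devices=[2,1,4] last_tile_dim_replicate} to
  // temp_target_sharding={devices=[2,4]} gives expand_tile_dims = {1} and
  // tiling_dim_factors = {1, 4}; if the target were also partially replicated,
  // its replication count would be appended to tiling_dim_factors.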
1410
1411 // 2. Get the padded_hlo, do right halo exchange if needed.
1412 auto padded_hlo = PadFromPartialReplicateShape(
1413 hlo_, base_shape_, sharding(), temp_target_sharding, expand_tile_dims,
1414 state_.collective_ops_creator, state_.next_channel_id,
1415 state_.partition_id, state_.b);
1416 if (!padded_hlo.has_value()) {
1417 return std::nullopt;
1418 }
1419   // 3. Slice out this partition's tile from the replicated data.
1420 auto shard_shape = MakePartitionedShape(base_shape_, temp_target_sharding);
1421   // Since we are just slicing, we can use the differences between the new and
1422   // old offsets in the full shape as the dynamic-slice offsets.
1423 auto padded_base_shape = shard_shape;
1424 for (int64_t i = 0; i < padded_base_shape.rank(); ++i) {
1425 padded_base_shape.set_dimensions(
1426 i, padded_base_shape.dimensions(i) *
1427 temp_target_sharding.tile_assignment().dim(i));
1428 }
1429 auto offsets = MakePartitionOffsets(padded_base_shape, temp_target_sharding,
1430 state_.partition_id, state_.b);
1431 auto old_offsets = MakePartitionOffsets(padded_base_shape, sharding(),
1432 state_.partition_id, state_.b);
1433 for (int64_t i = 0; i < offsets.size(); ++i) {
1434 offsets[i] = state_.b->AddInstruction(HloInstruction::CreateBinary(
1435 offsets[i]->shape(), HloOpcode::kSubtract, offsets[i], old_offsets[i]));
1436 }
1437 auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
1438 shard_shape, padded_hlo.value(), offsets, shard_shape.dimensions()));
1439 slice->set_sharding(temp_target_sharding);
1440 auto result = PartitionedHlo(slice, base_shape_, state_);
1441 // If temp_target_sharding's device assignment is different from target,
1442 // use collective permute to reshard.
1443 if (CanReshardWithCollectivePermute(temp_target_sharding, target)) {
1444 return result.ReshardWithCollectivePermute(target);
1445 }
1446 // If device assignment in temp_target_sharding and target are the same,
1447 // return result directly.
1448 return result;
1449 }
1450
1451 PartitionedHlo PartitionedHlo::Broadcast() const {
1452 const Shape& shape = hlo_->shape();
1453 const HloSharding& sharding = hlo_->sharding();
1454 CHECK(sharding.HasUniqueDevice());
1455 CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
1456
1457 auto src_core_id = state_.b->AddInstruction(HloInstruction::CreateConstant(
1458 LiteralUtil::CreateR0<uint32_t>(sharding.GetUniqueDevice())));
1459 Shape bcast_shape = ShapeUtil::ChangeElementType(shape, PRED);
1460 auto is_src_core = state_.b->AddInstruction(HloInstruction::CreateBroadcast(
1461 bcast_shape,
1462 state_.b->AddInstruction(HloInstruction::CreateCompare(
1463 ShapeUtil::MakeShape(PRED, {}), state_.partition_id, src_core_id,
1464 ComparisonDirection::kEq)),
1465 {}));
1466
1467 auto zero = state_.b->AddInstruction(
1468 HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
1469 auto zero_bcast = state_.b->AddInstruction(
1470 HloInstruction::CreateBroadcast(shape, zero, {}));
1471 auto operand = state_.b->AddInstruction(HloInstruction::CreateTernary(
1472 shape, HloOpcode::kSelect, is_src_core, hlo(), zero_bcast));
1473 HloComputation* reduction =
1474 MakeBinaryAdd(shape.element_type(), state_.module);
1475
1476 auto result = state_.collective_ops_creator.create_cross_partition_all_reduce(
1477 state_.b, operand, reduction, {}, NewChannel());
1478 result->set_sharding(HloSharding::Replicate());
1479 return PartitionedHlo(result, base_shape_, state_);
1480 }
1481
1482 PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
1483 const HloSharding& target,
1484 absl::Span<const std::pair<int64_t, int64_t>> source_target_dims) const {
1485 if (source_target_dims.empty()) {
1486 if (target == sharding()) {
1487 return *this;
1488 }
1489 // If the device order is different in the target, fix the order with
1490 // ReshardWithCollectivePermute.
1491 return ReshardWithCollectivePermute(target);
1492 }
1493
1494 // Swap one pair of dimensions.
1495 int64_t source_dim = source_target_dims[0].first;
1496 int64_t target_dim = source_target_dims[0].second;
1497 const int64_t group_size = sharding().tile_assignment().dim(source_dim) /
1498 sharding().tile_assignment().dim(target_dim);
1499
1500 auto temp_target_tile = sharding().tile_assignment();
1501 {
1502 std::vector<int64_t> reshape_tile_dims(temp_target_tile.num_dimensions() +
1503 2);
1504 int64_t i = 0;
1505 int64_t added_source_dim = -1;
1506 int64_t added_target_dim = -1;
1507 for (int64_t j = 0; j < temp_target_tile.num_dimensions(); ++j) {
1508 if (source_dim == j) {
1509 reshape_tile_dims[i] = temp_target_tile.dim(j) / group_size;
1510 reshape_tile_dims[++i] = group_size;
1511 added_source_dim = i;
1512 } else if (target_dim == j) {
1513 reshape_tile_dims[i] = temp_target_tile.dim(j);
1514 reshape_tile_dims[++i] = 1;
1515 added_target_dim = i;
1516 } else {
1517 reshape_tile_dims[i] = temp_target_tile.dim(j);
1518 }
1519 ++i;
1520 }
1521 temp_target_tile.Reshape(reshape_tile_dims);
1522 std::vector<int64_t> xpose_dims(temp_target_tile.num_dimensions());
1523 std::iota(xpose_dims.begin(), xpose_dims.end(), 0);
1524 xpose_dims[added_source_dim] = added_target_dim;
1525 xpose_dims[added_target_dim] = added_source_dim;
1526 temp_target_tile = hlo_sharding_util::TransposeSharding(
1527 HloSharding::Tile(temp_target_tile), xpose_dims)
1528 .tile_assignment();
1529 auto temp_target_tile_dims = sharding().tile_assignment().dimensions();
1530 temp_target_tile_dims[source_dim] =
1531 sharding().tile_assignment().dim(target_dim);
1532 temp_target_tile_dims[target_dim] =
1533 sharding().tile_assignment().dim(source_dim);
1534 temp_target_tile.Reshape(temp_target_tile_dims);
1535 }
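  // Example (hypothetical shardings): resharding {devices=[4,1]0,1,2,3} with
  // source_dim=0 and target_dim=1 gives group_size=4; the tile assignment is
  // reshaped to [1,4,1,1], the two added dimensions are transposed, and the
  // final reshape yields a [1,4] tile assignment, i.e. temp_target moves the
  // partitioning from dim 0 to dim 1 while keeping the device order.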
1536 auto temp_target = target.ReplicateOnLastTileDim()
1537 ? HloSharding::PartialTile(temp_target_tile)
1538 : HloSharding::Tile(temp_target_tile);
1539 auto padded_shape = hlo_->shape();
1540 padded_shape.set_dimensions(
1541 target_dim, RoundUpTo(padded_shape.dimensions(target_dim),
1542 temp_target.tile_assignment().dim(target_dim)));
1543 auto padded_hlo = PadToShape(hlo_, padded_shape, state_.b);
1544
1545 // The order of ids in the group must follow the temp_target sharding.
1546 std::vector<std::vector<int64_t>> groups(
1547 temp_target.tile_assignment().num_elements() / group_size);
1548 temp_target.tile_assignment().Each(
1549 [&](absl::Span<const int64_t> indices, int64_t device) {
1550 int64_t group_id = 0;
1551 for (int64_t dim = 0; dim < indices.size(); ++dim) {
1552 if (dim == target_dim) {
1553 group_id *= temp_target.tile_assignment().dim(dim) / group_size;
1554 group_id += indices[dim] / group_size;
1555 } else {
1556 group_id *= temp_target.tile_assignment().dim(dim);
1557 group_id += indices[dim];
1558 }
1559 }
1560 groups[group_id].push_back(device);
1561 });
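  // Example (hypothetical shardings): for temp_target={devices=[2,4]0,1,2,3,4,5,6,7}
  // with target_dim=1 and group_size=2, this builds four groups:
  // {0,1}, {2,3}, {4,5}, {6,7}.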
1562
1563 HloInstruction* result = nullptr;
1564
1565 // Split along the split dimension (target_dim) of the all-to-all
1566 // output.
1567 std::vector<int64_t> dimensions;
1568 const int64_t rank = base_shape_.rank();
1569 dimensions.reserve(rank + 1);
1570 for (int64_t i = 0; i < rank; ++i) {
1571 if (i == target_dim) {
1572 dimensions.push_back(group_size);
1573 dimensions.push_back(padded_hlo->shape().dimensions(i) / group_size);
1574 } else {
1575 dimensions.push_back(padded_hlo->shape().dimensions(i));
1576 }
1577 }
1578 VLOG(5) << "Target ata shape: "
1579 << ShapeUtil::MakeShape(base_shape_.element_type(), dimensions)
1580 .ToString();
1581 auto reshape = state_.b->AddInstruction(HloInstruction::CreateReshape(
1582 ShapeUtil::MakeShape(base_shape_.element_type(), dimensions),
1583 padded_hlo));
1584 // After the reshape, it is guaranteed to have at least 3 dimensions.
1585 auto all_to_all =
1586 state_.collective_ops_creator.create_cross_partition_all_to_all(
1587 state_.b, {reshape}, groups, (*state_.next_channel_id)++, target_dim);
1588
1589 // Reorder the split dimension of the reshape to be located in front of the
1590 // input partition dimension, so the two dimensions can be combined.
1591 int64_t new_source_dim =
1592 (target_dim < source_dim) ? source_dim + 1 : source_dim;
1593 std::vector<int64_t> permutation;
1594 for (int64_t i = 0; i < all_to_all->shape().rank(); ++i) {
1595 if (i == target_dim) {
1596 continue;
1597 }
1598 if (i == new_source_dim) {
1599 permutation.push_back(target_dim);
1600 }
1601 permutation.push_back(i);
1602 }
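  // Example (hypothetical dims): with target_dim=0 and source_dim=2 on a
  // rank-4 all-to-all output, new_source_dim=3 and permutation={1,2,0,3},
  // which moves the split dimension right before the source dimension.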
1603 auto transpose = state_.b->AddInstruction(HloInstruction::CreateTranspose(
1604 ShapeInference::InferTransposeShape(all_to_all->shape(), permutation)
1605 .ValueOrDie(),
1606 all_to_all, permutation));
1607
1608 // Combine the split dimension and the input partition dimension.
1609 auto new_shape = ShapeInference::InferAllToAllShape(
1610 padded_hlo->shape(), target_dim, source_dim, group_size)
1611 .ValueOrDie();
1612 result = state_.b->AddInstruction(
1613 HloInstruction::CreateReshape(new_shape, transpose));
1614
1615 const Shape result_shape = MakePartitionedShape(base_shape_, temp_target);
1616 if (result_shape != result->shape()) {
1617 result = state_.b->AddInstruction(HloInstruction::CreateSlice(
1618 result_shape, result, std::vector<int64_t>(result_shape.rank(), 0),
1619 result_shape.dimensions(),
1620 std::vector<int64_t>(result_shape.rank(), 1)));
1621 }
1622 result->set_sharding(temp_target);
1623 auto remaining_source_target_dims = source_target_dims;
1624 remaining_source_target_dims.remove_prefix(1);
1625 return PartitionedHlo(result, base_shape_, state_)
1626 .ReshardWithAllToAll(target, remaining_source_target_dims);
1627 }
1628
1629 namespace {
1630
1631 // Match a pattern like [..,X,..,Y] -> [..,X*Y,..,1] or [..,X,..,Y] ->
1632 // [..,1,..,X*Y].
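// For example (hypothetical shardings): source tile dims [4,2] and target tile
// dims [8,1] are the [X,Y] -> [X*Y,1] case with X=4 and Y=2, while source tile
// dims [4,2] and target tile dims [1,8] would be the [X,Y] -> [1,X*Y] case.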
1633 std::optional<std::pair<HloSharding, int>> PatternMatchReshape(
1634 const Shape& shape, const HloSharding& source, const HloSharding& target) {
1635 if (!source.IsTiled() || !target.IsTiled()) {
1636 return std::nullopt;
1637 }
1638 if (source.TiledDataRank() != target.TiledDataRank()) {
1639 return std::nullopt;
1640 }
1641 if ((source.HasPartialReplication() ^ target.HasPartialReplication()) ||
1642 (source.HasPartialReplication() &&
1643 source.tile_assignment().dimensions()[source.TiledDataRank()] !=
1644 target.tile_assignment().dimensions()[target.TiledDataRank()])) {
1645 return std::nullopt;
1646 }
1647 for (int i = 0; i < target.TiledDataRank(); ++i) {
1648 if (source.tile_assignment().dim(i) > target.tile_assignment().dim(i) &&
1649 target.tile_assignment().dim(i) == 1 &&
1650 (source.tile_assignment().dim(i) % target.tile_assignment().dim(i)) ==
1651 0) {
1652 const int64_t dimension_size =
1653 source.tile_assignment().dim(i) / target.tile_assignment().dim(i);
1654 for (int j = i - 1; j >= 0; --j) {
1655 if (target.tile_assignment().dim(j) == 1) {
1656 continue;
1657 }
1658 if (target.tile_assignment().dim(j) !=
1659 dimension_size * source.tile_assignment().dim(j)) {
1660 continue;
1661 }
1662         // Do not consider this candidate if it would require additional padding.
1663 if (shape.dimensions(j) % dimension_size != 0) {
1664 continue;
1665 }
1666 auto reshaped_sharding = hlo_sharding_util::SplitShardingDimension(
1667 source, j, source.tile_assignment().dim(j));
1668 std::vector<int64_t> permutation(reshaped_sharding.TiledDataRank(), 0);
1669 absl::c_iota(permutation, 0);
1670 std::swap(permutation[i + 1], permutation[j]);
1671 reshaped_sharding = hlo_sharding_util::TransposeSharding(
1672 reshaped_sharding, permutation);
1673 return std::make_pair(reshaped_sharding, j);
1674 }
1675 for (int j = i + 1; j < target.TiledDataRank(); ++j) {
1676 if (target.tile_assignment().dim(j) == 1) {
1677 continue;
1678 }
1679 if (target.tile_assignment().dim(j) !=
1680 dimension_size * source.tile_assignment().dim(j)) {
1681 continue;
1682 }
1683         // Do not consider this candidate if it would require additional padding.
1684 if (shape.dimensions(j) % dimension_size != 0) {
1685 continue;
1686 }
1687
1688 auto reshaped_sharding = hlo_sharding_util::SplitShardingDimension(
1689 source, j, source.tile_assignment().dim(j));
1690 std::vector<int64_t> permutation(reshaped_sharding.TiledDataRank(), 0);
1691 absl::c_iota(permutation, 0);
1692 VLOG(5) << "Reshaped sharding before: " << reshaped_sharding.ToString();
1693 std::swap(permutation[i], permutation[j]);
1694 reshaped_sharding = hlo_sharding_util::TransposeSharding(
1695 reshaped_sharding, permutation);
1696 VLOG(5) << "Reshaped sharding: " << reshaped_sharding.ToString();
1697 return std::make_pair(reshaped_sharding, j);
1698 }
1699 }
1700 }
1701 return std::nullopt;
1702 }
1703 // Match patterns like [..,X,..,Z,..,Y,..] -> [..,Y,..,Z,..,X]
1704 // last_tile_dim_replicate, where X gets replicated and Y also changes
1705 // position. We try to perform the replication, so we can match some other
1706 // targets instead.
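// For example (hypothetical shardings): with source={devices=[2,4]0,1,2,3,4,5,6,7}
// and a target whose trailing replication dimension has size 2, dim 0 of the
// source matches, and the source is partially replicated on dim 0, producing a
// [1,4,2] tile assignment with last_tile_dim_replicate.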
1707 std::optional<HloSharding> PatternMatchPartiallyReplicateDim(
1708 const HloSharding& source, const HloSharding& target) {
1709 if (!(!source.ReplicateOnLastTileDim() && target.ReplicateOnLastTileDim())) {
1710 return std::nullopt;
1711 }
1712 const int64_t target_replicated_dim = target.SubgroupReplicationDim();
1713 CHECK_NE(target_replicated_dim, -1) << "Expected replicated dim";
1714 for (int i = 0; i < source.tile_assignment().num_dimensions(); ++i) {
1715 if (source.tile_assignment().dim(i) !=
1716 target.tile_assignment().dim(target_replicated_dim)) {
1717 continue;
1718 }
1719 auto replicated_sharding =
1720 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(source, {i});
1721 return replicated_sharding;
1722 }
1723 return std::nullopt;
1724 }
1725
1726 // Helper to split a PartitionedHlo over a specific dimension.
1727 PartitionedHlo SplitReshapeHelper(PartitionedHlo to_reshape,
1728 int64_t dim_to_split, int64_t dim_size,
1729 const HloSharding& target_sharding) {
1730 Shape original_shape = to_reshape.hlo()->shape();
1731 std::vector<int64_t> shape_dim(original_shape.dimensions().begin(),
1732 original_shape.dimensions().end());
1733 shape_dim.insert(shape_dim.begin() + dim_to_split + 1, 1);
1734 std::vector<int64_t> base_shape_dim(
1735 to_reshape.base_shape().dimensions().begin(),
1736 to_reshape.base_shape().dimensions().end());
1737 base_shape_dim.insert(base_shape_dim.begin() + dim_to_split + 1, dim_size);
1738 base_shape_dim[dim_to_split] /= dim_size;
1739 Shape shape = ShapeUtil::MakeShape(original_shape.element_type(), shape_dim);
1740 HloInstruction* reshaped_instr = to_reshape.state().b->AddInstruction(
1741 HloInstruction::CreateReshape(shape, to_reshape.hlo()));
1742 reshaped_instr->set_sharding(target_sharding);
1743 return PartitionedHlo{
1744 reshaped_instr,
1745 ShapeUtil::MakeShape(to_reshape.base_shape().element_type(),
1746 base_shape_dim),
1747 to_reshape.state()};
1748 }
1749 // Merge a PartitionedHlo over a specific dimension.
1750 PartitionedHlo MergeReshapeHelper(PartitionedHlo to_reshape,
1751 int64_t dim_to_merge,
1752 const HloSharding& target_sharding) {
1753 Shape original_shape = to_reshape.hlo()->shape();
1754 std::vector<int64_t> shape_dim(original_shape.dimensions().begin(),
1755 original_shape.dimensions().end());
1756 shape_dim[dim_to_merge] *= shape_dim[dim_to_merge + 1];
1757 shape_dim.erase(shape_dim.begin() + dim_to_merge + 1);
1758 std::vector<int64_t> base_shape_dim(
1759 to_reshape.base_shape().dimensions().begin(),
1760 to_reshape.base_shape().dimensions().end());
1761 base_shape_dim[dim_to_merge] *= base_shape_dim[dim_to_merge + 1];
1762 base_shape_dim.erase(base_shape_dim.begin() + dim_to_merge + 1);
1763 Shape shape = ShapeUtil::MakeShape(original_shape.element_type(), shape_dim);
1764 HloInstruction* reshaped_instr = to_reshape.state().b->AddInstruction(
1765 HloInstruction::CreateReshape(shape, to_reshape.hlo()));
1766 reshaped_instr->set_sharding(target_sharding);
1767 return PartitionedHlo(
1768 reshaped_instr,
1769 ShapeUtil::MakeShape(original_shape.element_type(), base_shape_dim),
1770 to_reshape.state());
1771 }
1772
1773 } // namespace
1774
1775 std::optional<PartitionedHlo> PartitionedHlo::TryComplexReshardHandling(
1776 const HloSharding& target) {
1777 VLOG(5) << "Trying to split complicated reshard: " << sharding().ToString()
1778 << " to " << target.ToString();
1779 const bool is_source_partially_replicated =
1780 sharding().ReplicateOnLastTileDim();
1781 const bool is_target_partially_replicated = target.ReplicateOnLastTileDim();
1782 if (auto reshape =
1783 PatternMatchReshape(this->hlo()->shape(), sharding(), target)) {
1784 VLOG(5) << "Matched \"pattern_match_reshape()\": "
1785 << reshape->first.ToString();
1786 VLOG(5) << "Original shape: " << hlo()->shape().ToString();
1787 auto before_sharding = hlo_sharding_util::SplitShardingDimension(
1788 sharding(), reshape->second,
1789 sharding().tile_assignment().dim(reshape->second));
1790 PartitionedHlo reshaped = SplitReshapeHelper(
1791 *this, reshape->second,
1792 sharding().tile_assignment().dim(reshape->second), before_sharding);
1793 VLOG(5) << "Reshaped shape: " << reshaped.hlo()->shape().ToString();
1794 auto reshard = reshaped.ReshardNoCache(reshape->first,
1795 /*pad_value=*/std::nullopt,
1796 /*allow_full_replication=*/false);
1797 if (reshard.sharding() != reshape->first) {
1798 return std::nullopt;
1799 }
1800 auto reshaped_sharding = hlo_sharding_util::MergeShardingDimension(
1801 reshard.sharding(), reshape->second);
1802 reshaped = MergeReshapeHelper(reshard, reshape->second, reshaped_sharding);
1803 if (reshaped.sharding() != target) {
1804 reshaped = reshaped.ReshardNoCache(target, /*pad_value=*/std::nullopt,
1805 /*allow_full_replication=*/false);
1806 if (reshaped.sharding() != target) {
1807 return std::nullopt;
1808 }
1809 }
1810 return reshaped;
1811 }
1812 if (auto intermediate_target =
1813 PatternMatchPartiallyReplicateDim(sharding(), target)) {
1814 VLOG(5) << "Matched \"pattern_match_partially_replicate_dim()\": "
1815 << intermediate_target->ToString();
1816 auto intermediate_reshard = Reshard(*intermediate_target);
1817 auto final_reshard = intermediate_reshard.ReshardNoCache(
1818 target, /*pad_value=*/std::nullopt, /*allow_full_replication=*/false);
1819 if (final_reshard.sharding() != target) {
1820 return std::nullopt;
1821 }
1822 return final_reshard;
1823 }
1824 if (is_source_partially_replicated && !is_target_partially_replicated) {
1825 const int64_t partial_repl_amount =
1826 sharding().tile_assignment().dimensions().back();
1827 int64_t first_different_dimension = -1;
1828     // Try to match conditions like [..,X,..,Z,..,Y] last_tile_dim_replicate
1829     // to [..,Y,..,Z,..,X,..], where Y in the source is partially replicated,
1830     // but in the target it is not and some other dimension is moved or
1831     // modified. Try to remove the partial replication to simplify the step from
1832     // source to target sharding.
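    // For example (hypothetical shardings): going from
    // {devices=[1,2,2] last_tile_dim_replicate} to {devices=[2,2]}, dim 0 is
    // the first differing dimension, so it is swapped with the trailing
    // replication dimension before attempting the final reshard.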
1833 for (int64_t i = 0; i < target.tile_assignment().num_dimensions(); ++i) {
1834 if (target.tile_assignment().dim(i) !=
1835 sharding().tile_assignment().dim(i) &&
1836 sharding().tile_assignment().dim(i) == 1 &&
1837 target.tile_assignment().dim(i) % partial_repl_amount == 0) {
1838 first_different_dimension = i;
1839 break;
1840 }
1841 }
1842 if (first_different_dimension == -1) {
1843 return std::nullopt;
1844 }
1845 VLOG(5) << "Matched partially replicated to non partially replicated: "
1846 << sharding().ToString();
1847 std::vector<int64_t> transpose_dims(
1848 sharding().tile_assignment().num_dimensions(), 0);
1849 std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
1850 std::swap(transpose_dims[first_different_dimension], transpose_dims.back());
1851 auto intermediate_sharding =
1852 hlo_sharding_util::TransposeSharding(sharding(), transpose_dims);
1853 auto intermediate_reshard = Reshard(intermediate_sharding);
1854 auto reshard = intermediate_reshard.ReshardNoCache(
1855 target, /*pad_value=*/std::nullopt, /*allow_full_replication=*/false);
1856 if (reshard.sharding() != target) {
1857 return std::nullopt;
1858 }
1859 return reshard;
1860 }
1861 return std::nullopt;
1862 }
1863
1864 std::optional<PartitionedHlo>
1865 PartitionedHlo::ReshardPartialReplicateWithAllToAll(const HloSharding& target) {
1866 bool source_is_partial_replicate = sharding().ReplicateOnLastTileDim();
1867 const auto& partial_replicate_sharding =
1868 source_is_partial_replicate ? sharding() : target;
1869 // If neither the source nor the target is partial replicate, return null.
1870 if (!partial_replicate_sharding.ReplicateOnLastTileDim()) {
1871 return std::nullopt;
1872 }
1873 const auto& tile_sharding = source_is_partial_replicate ? target : sharding();
1874   // If both source and target are partially replicated, this should already be
1875   // supported by ReshardWithAllToAll.
1876 if (tile_sharding.ReplicateOnLastTileDim() || tile_sharding.IsTileMaximal()) {
1877 return std::nullopt;
1878 }
1879
1880 // Only support resharding from sharding={devices=[2,3]0,1,2,3,4,5}
1881 // to sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}, where
1882   // the last tile dim will be replicated first before the all-to-all.
1883 // Or resharding from
1884 // sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
1885 // to sharding={devices=[2,3]0,1,2,3,4,5}, where
1886 // the last tile dim will be sharded after all-to-all.
1887 const int num_replicas =
1888 partial_replicate_sharding.tile_assignment().dimensions().back();
1889 if (((tile_sharding.tile_assignment().num_dimensions() + 1) !=
1890 partial_replicate_sharding.tile_assignment().num_dimensions()) ||
1891 (partial_replicate_sharding.tile_assignment().dim(0) != 1)) {
1892 return std::nullopt;
1893 }
1894 int to_replicate_dim = -1;
1895 for (int i = tile_sharding.tile_assignment().num_dimensions() - 1; i >= 0;
1896 --i) {
1897 if (tile_sharding.tile_assignment().dim(i) > 1 &&
1898 (to_replicate_dim == -1)) {
1899 if (tile_sharding.tile_assignment().dim(i) != num_replicas) {
1900 return std::nullopt;
1901 }
1902 to_replicate_dim = i;
1903 }
1904
1905 if (tile_sharding.tile_assignment().dim(i) !=
1906 partial_replicate_sharding.tile_assignment().dim(i + 1)) {
1907 return std::nullopt;
1908 }
1909 }
1910
1911 if (to_replicate_dim == -1) {
1912 return std::nullopt;
1913 }
1914
1915   // Check that the core assignments of the source and the target are the same.
1916 auto reshape_tile_assignment = partial_replicate_sharding.tile_assignment();
1917 reshape_tile_assignment.Reshape(tile_sharding.tile_assignment().dimensions());
1918 if (reshape_tile_assignment != tile_sharding.tile_assignment()) {
1919 return std::nullopt;
1920 }
1921
1922 auto tmp_tile_assignment = tile_sharding.tile_assignment();
1923 auto tmp_tile_assignment_dimensions =
1924 tile_sharding.tile_assignment().dimensions();
1925 tmp_tile_assignment_dimensions[to_replicate_dim] = 1;
1926 tmp_tile_assignment_dimensions.push_back(num_replicas);
1927 tmp_tile_assignment.Reshape(tmp_tile_assignment_dimensions);
1928 auto tmp_partial_replicate_sharding =
1929 HloSharding::PartialTile(tmp_tile_assignment);
1930
1931 if (source_is_partial_replicate) {
1932 if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
1933 sharding(), tmp_partial_replicate_sharding)) {
1934 auto partitioned_hlo =
1935 ReshardWithAllToAll(tmp_partial_replicate_sharding, *src_tgt_dims);
1936 return partitioned_hlo.Reshard(target);
1937 }
1938 } else {
1939 auto partitioned_hlo = Reshard(tmp_partial_replicate_sharding);
1940
1941 if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
1942 partitioned_hlo.sharding(), target)) {
1943 return partitioned_hlo.ReshardWithAllToAll(target, *src_tgt_dims);
1944 }
1945 }
1946
1947 return std::nullopt;
1948 }
1949
1950 PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
1951 const HloSharding& target) const {
1952 CHECK(CanReshardWithCollectivePermute(sharding(), target))
1953 << sharding().ToString() << " to " << target.ToString();
1954 if (auto broadcast_dims = state_.b->BroadcastDimsForCreatedHlo(hlo())) {
1955 if (!(*broadcast_dims)->empty()) {
1956 // If hlo() has broadcast dims, check if data is already the same between
1957 // source/destination pairs.
1958 std::vector<int64_t> broadcast_dims_vector;
1959 for (int64_t i = 0; i < hlo()->shape().rank(); ++i) {
1960 if ((*broadcast_dims)->contains(i)) {
1961 broadcast_dims_vector.push_back(i);
1962 }
1963 }
1964 if (hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
1965 sharding(), broadcast_dims_vector) ==
1966 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
1967 target, broadcast_dims_vector)) {
1968 auto copy = state_.b->AddInstruction(HloInstruction::CreateUnary(
1969 hlo()->shape(), HloOpcode::kCopy, hlo()));
1970 copy->set_sharding(target);
1971 return PartitionedHlo(copy, base_shape_, state_);
1972 }
1973 }
1974 }
1975 std::vector<std::pair<int64_t, int64_t>> src_dst_pairs;
1976 sharding().tile_assignment().Each(
1977 [&](absl::Span<const int64_t> indices, int64_t src_device) {
1978 int64_t dst_device = target.tile_assignment()(indices);
1979 src_dst_pairs.emplace_back(src_device, dst_device);
1980 });
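  // Example (hypothetical shardings): permuting {devices=[2,2]0,1,2,3} to
  // {devices=[2,2]0,2,1,3} yields src_dst_pairs {0,0},{1,2},{2,1},{3,3}.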
1981 auto cp =
1982 state_.collective_ops_creator.create_cross_partition_collective_permute(
1983 state_.b, hlo(), src_dst_pairs, (*state_.next_channel_id)++);
1984 cp->set_sharding(target);
1985 return PartitionedHlo(cp, base_shape_, state_);
1986 }
1987
1988 SpmdPartitioningVisitor::SpmdPartitioningVisitor(
1989 HloComputation* computation, int64_t num_partitions, int64_t num_replicas,
1990 const SPMDCollectiveOpsCreator& collective_ops_creator,
1991 int64_t* next_channel_id, SpmdLogger* logger,
1992 SpmdPartitionerOptions options, SpmdPartitioner* partitioner)
1993 : changed_(false),
1994 module_(computation->parent()),
1995 num_partitions_(num_partitions),
1996 num_replicas_(num_replicas),
1997 collective_ops_creator_(collective_ops_creator),
1998 next_channel_id_(next_channel_id),
1999 b_(SpmdBuilder(computation->name() + "_spmd", /*hlo=*/nullptr)),
2000 partition_id_(collective_ops_creator_.create_partition_id(&b_)),
2001 logger_(logger),
2002 options_(std::move(options)),
2003 partitioner_(partitioner) {}
2004
2005 PartitionedHlo::PartitioningState
2006 SpmdPartitioningVisitor::MakePartitioningState() {
2007 PartitionedHlo::PartitioningState state;
2008 state.b = &b_;
2009 state.module = module_;
2010 state.num_replicas = num_replicas_;
2011 state.next_channel_id = next_channel_id_;
2012 state.reshard_cache = &reshard_cache_;
2013 state.partitioner = partitioner_;
2014 if (!device_groups_.empty()) {
2015 // Use the original collective creator and partition_id to call
2016 // CreatePerGroupPartitioningState(). Current collective_ops_creator_ and
2017 // partition_id_ have been rewritten to be subgrouped.
2018 state.collective_ops_creator = *visiting_collective_ops_creator_;
2019 state.partition_id = *visiting_partition_id_;
2020 return CreatePerGroupPartitioningState(state, device_groups_, &b_);
2021 } else {
2022 state.collective_ops_creator = collective_ops_creator_;
2023 state.partition_id = partition_id_;
2024 }
2025 return state;
2026 }
2027
2028 std::vector<ReplicaGroup> SpmdPartitioningVisitor::CreateReplicaGroups(
2029 std::vector<std::vector<int64_t>>& groups) {
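  // Example (hypothetical values): with num_replicas_=2, num_partitions_=4 and
  // groups={{0,1},{2,3}}, this produces flattened-id replica groups
  // {0,1},{2,3},{4,5},{6,7}.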
2030 std::vector<ReplicaGroup> device_groups;
2031 device_groups.reserve(groups.size() * num_replicas_);
2032 for (int64_t i = 0; i < num_replicas_; ++i) {
2033 for (const auto& group : groups) {
2034 device_groups.emplace_back();
2035 for (int64_t id : group) {
2036 device_groups.back().add_replica_ids(i * num_partitions_ + id);
2037 }
2038 }
2039 }
2040 return device_groups;
2041 }
2042
2043 Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) {
2044 if (hlo->HasSideEffect() && !hlo->sharding().HasUniqueDevice()) {
2045 return Unimplemented("Side-effect ops cannot be replicated: %s",
2046 hlo->ToString());
2047 }
2048
2049 if (hlo->IsElementwise() && hlo->operand_count() > 0) {
2050 return HandleElementwise(hlo);
2051 }
2052
2053 if (!hlo->sharding().IsTileMaximal()) {
2054 VLOG(1) << "Not partitioned in SPMD mode (DefaultAction):"
2055 << hlo->ToString();
2056 for (int64_t i = 0; i < hlo->operand_count(); ++i) {
2057 VLOG(1) << " operand " << i
2058 << " sharding:" << hlo->operand(i)->sharding().ToString();
2059 }
2060 }
2061
2062 HloSharding sharding = hlo->sharding().HasUniqueDevice()
2063 ? hlo->sharding()
2064 : HloSharding::Replicate();
2065 if (hlo->opcode() == HloOpcode::kSend || hlo->opcode() == HloOpcode::kRecv ||
2066 hlo->opcode() == HloOpcode::kRecvDone) {
2067 sharding = sharding.GetSubSharding(hlo->shape(), {0});
2068 }
2069
2070   // If the instruction cannot be partitioned, replicate the instruction unless
2071   // it has side effects.
2072 std::vector<HloInstruction*> new_operands;
2073 for (HloInstruction* operand : hlo->operands()) {
2074 new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
2075 }
2076 auto clone =
2077 b_.AddInstruction(hlo->CloneWithNewOperands(hlo->shape(), new_operands));
2078 clone->set_sharding(sharding);
2079 SetPartitionedHlo(hlo,
2080 PartitionedHlo(clone, hlo->shape(), MakePartitioningState())
2081 .Reshard(hlo->sharding()));
2082 return OkStatus();
2083 }
2084
2085 Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) {
2086 visiting_hlo_ = hlo;
2087 b_.set_visiting_hlo(hlo);
2088   // Temporarily replace manual sharding with one-device sharding so that the
2089   // partitioner will not change the HLOs.
2090 auto manual_to_onedevice = [&](HloOpcode opcode, const Shape& shape,
2091 const HloSharding& sharding) {
2092     // If a tuple's elements are all manual, then sharding.IsManual() == true,
2093     // so we test whether it is a tuple first.
2094 if (sharding.IsTuple()) {
2095 std::vector<HloSharding> subshardings = sharding.tuple_elements();
2096 for (HloSharding& subsharding : subshardings) {
2097 // Delay manual sharding substitution for CustomCalls.
2098 if (subsharding.IsManual() && opcode != HloOpcode::kCustomCall) {
2099 subsharding = HloSharding::AssignDevice(0);
2100 }
2101 }
2102 return HloSharding::Tuple(shape, subshardings);
2103 }
2104 // Delay manual sharding substitution for CustomCalls.
2105 if (sharding.IsManual() && opcode != HloOpcode::kCustomCall) {
2106 return HloSharding::AssignDevice(0);
2107 }
2108 return sharding;
2109 };
2110
2111 if (hlo->opcode() != HloOpcode::kConditional &&
2112 hlo->opcode() != HloOpcode::kTuple &&
2113 hlo->opcode() != HloOpcode::kGetTupleElement &&
2114 hlo->opcode() != HloOpcode::kParameter &&
2115 hlo->opcode() != HloOpcode::kWhile && hlo->opcode() != HloOpcode::kRng &&
2116 hlo->opcode() != HloOpcode::kAllReduce) {
2117 const bool has_manual_sharding =
2118 hlo->sharding().IsManual() ||
2119 (hlo->sharding().IsTuple() &&
2120 absl::c_any_of(
2121 hlo->sharding().tuple_elements(),
2122 [](const HloSharding& sharding) { return sharding.IsManual(); }));
2123 if (has_manual_sharding && !hlo->IsCustomCall("SPMDFullToShardShape")) {
2124 visiting_hlo_sharding_ = hlo->sharding();
2125 hlo->set_sharding(manual_to_onedevice(hlo->opcode(), hlo->shape(),
2126 *visiting_hlo_sharding_));
2127
2128 visiting_hlo_operand_shardings_.reserve(hlo->operand_count());
2129 for (HloInstruction* operand : hlo->unique_operands()) {
2130 visiting_hlo_operand_shardings_.push_back(operand->sharding());
2131 operand->set_sharding(manual_to_onedevice(
2132 hlo->opcode(), operand->shape(), operand->sharding()));
2133 GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
2134 }
2135 } else {
2136 const bool has_manual_subgroup =
2137 hlo->sharding().IsManualSubgroup() ||
2138 (hlo->sharding().IsTuple() &&
2139 absl::c_any_of(hlo->sharding().tuple_elements(),
2140 [](const HloSharding& sharding) {
2141 return sharding.IsManualSubgroup();
2142 }));
2143 if (has_manual_subgroup && !hlo->IsCustomCall("SPMDFullToShardShape")) {
2144 auto get_grouped_sharding =
2145 [&](const HloSharding& sharding, const Shape& shape,
2146 const GroupedSharding* ref =
2147 nullptr) -> StatusOr<GroupedSharding> {
2148 if (!sharding.IsTuple()) {
2149 GroupedSharding grouped =
2150 hlo_sharding_util::GetManualSubgroupSharding(sharding);
2151 if (ref != nullptr) {
2152 auto aligned =
2153 AlignGroupsWithIfCompatible(std::move(grouped), *ref);
2154 TF_RET_CHECK(aligned.has_value())
2155 << "Incompatible manual sharding at " << hlo->ToString();
2156 return *aligned;
2157 }
2158 return grouped;
2159 }
2160 std::vector<HloSharding> elements;
2161 elements.reserve(sharding.tuple_elements().size());
2162 CHECK(!sharding.tuple_elements().empty());
2163 GroupedSharding grouped0 =
2164 hlo_sharding_util::GetManualSubgroupSharding(
2165 sharding.tuple_elements()[0]);
2166 if (ref != nullptr) {
2167 auto aligned =
2168 AlignGroupsWithIfCompatible(std::move(grouped0), *ref);
2169 TF_RET_CHECK(aligned.has_value())
2170 << "Incompatible manual sharding at " << hlo->ToString();
2171 grouped0 = std::move(*aligned);
2172 }
2173 elements.push_back(std::move(grouped0.sharding));
2174 for (int64_t i = 1; i < sharding.tuple_elements().size(); ++i) {
2175 auto grouped_i = AlignGroupsWithIfCompatible(
2176 hlo_sharding_util::GetManualSubgroupSharding(
2177 sharding.tuple_elements()[i]),
2178 grouped0);
2179 TF_RET_CHECK(grouped_i.has_value())
2180 << "Incompatible manual sharding between tuple elements: "
2181 << hlo->ToString();
2182 elements.push_back(std::move(grouped_i->sharding));
2183 }
2184 grouped0.sharding = HloSharding::Tuple(shape, elements);
2185 return grouped0;
2186 };
2187 TF_ASSIGN_OR_RETURN(
2188 auto group_sharding,
2189 get_grouped_sharding(hlo->sharding(), hlo->shape()));
2190 // Update sharding.
2191 visiting_hlo_sharding_ = hlo->sharding();
2192 hlo->set_sharding(group_sharding.sharding);
2193 // Update device_groups and num_partitions.
2194 // Set device_groups_, visiting_partition_id_ and
2195 // visiting_collective_ops_creator_ before MakePartitioningState() which
2196 // uses them.
2197 device_groups_ = group_sharding.device_groups;
2198 visiting_num_partitions_ = num_partitions_;
2199 num_partitions_ = num_partitions_ / group_sharding.device_groups.size();
2200 visiting_partition_id_ = partition_id_;
2201 visiting_collective_ops_creator_ = std::move(collective_ops_creator_);
2202 auto grouped_state = MakePartitioningState();
2203 collective_ops_creator_ =
2204 std::move(grouped_state.collective_ops_creator);
2205 partition_id_ = grouped_state.partition_id;
2206
2207 // Update sharding for the operands.
2208 visiting_hlo_operand_shardings_.reserve(hlo->operand_count());
2209 visiting_state_.reserve(hlo->operand_count());
2210 for (HloInstruction* operand : hlo->unique_operands()) {
2211 visiting_hlo_operand_shardings_.push_back(operand->sharding());
2212 auto old_state = GetPartitionedHlo(operand).state();
2213 visiting_state_.push_back(old_state);
2214 if (operand->shape().IsArray() && operand->IsConstant() &&
2215 operand->shape().rank() == 0 &&
2216 !operand->sharding().IsManualSubgroup()) {
2217             // We allow scalar constants to be CSE'd between manual/auto
2218             // subgraphs, so such a constant may not have a manual subgroup.
2219 continue;
2220 }
2221 TF_ASSIGN_OR_RETURN(
2222 auto op_group_sharding,
2223 get_grouped_sharding(operand->sharding(), operand->shape(),
2224 &group_sharding));
2225 operand->set_sharding(op_group_sharding.sharding);
2226 GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
2227 auto group_state = CreatePerGroupPartitioningState(
2228 old_state, op_group_sharding.device_groups, &b_);
2229 GetPartitionedHlo(operand).set_state(group_state);
2230 }
2231 }
2232 }
2233 }
2234 return OkStatus();
2235 }
2236
2237 Status SpmdPartitioningVisitor::Postprocess(HloInstruction* hlo) {
2238 logger_->RegisterLogEntry(hlo, b_.derived_instructions(hlo));
2239 visiting_hlo_ = nullptr;
2240 b_.set_visiting_hlo(nullptr);
2241 // Revert fake one-device shardings for manually partitioned ops.
2242 if (visiting_hlo_sharding_) {
2243 hlo->set_sharding(*visiting_hlo_sharding_);
2244 GetPartitionedHlo(hlo).hlo()->set_sharding(*visiting_hlo_sharding_);
2245 int64_t i = 0;
2246 for (HloInstruction* operand : hlo->unique_operands()) {
2247 operand->set_sharding(visiting_hlo_operand_shardings_[i++]);
2248 GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
2249 }
2250 visiting_hlo_sharding_.reset();
2251 visiting_hlo_operand_shardings_.clear();
2252 }
2253
2254 if (!device_groups_.empty()) {
2255 device_groups_.clear();
2256 num_partitions_ = *visiting_num_partitions_;
2257 visiting_num_partitions_.reset();
2258 collective_ops_creator_ = *visiting_collective_ops_creator_;
2259 visiting_collective_ops_creator_.reset();
2260 partition_id_ = *visiting_partition_id_;
2261 visiting_partition_id_.reset();
2262 GetPartitionedHlo(hlo).set_state(MakePartitioningState());
2263 }
2264
2265 if (!visiting_state_.empty()) {
2266 int64_t i = 0;
2267 for (const HloInstruction* operand : hlo->unique_operands()) {
2268 GetPartitionedHlo(operand).set_state(visiting_state_[i++]);
2269 }
2270 visiting_state_.clear();
2271 }
2272
2273 return OkStatus();
2274 }
2275
2276 Status SpmdPartitioningVisitor::HandleElementwise(HloInstruction* hlo) {
2277 std::vector<HloInstruction*> new_operands;
2278 for (HloInstruction* operand : hlo->operands()) {
2279 new_operands.push_back(
2280 GetPartitionedHlo(operand).Reshard(hlo->sharding()).hlo());
2281 }
2282 SetPartitionedHlo(hlo, [&] {
2283 return b_.AddInstruction(hlo->CloneWithNewOperands(
2284 MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands));
2285 });
2286 return OkStatus();
2287 }
2288
2289 Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) {
2290 const HloSharding& sharding = hlo->sharding();
2291 if (sharding.IsTileMaximal()) {
2292 return DefaultAction(hlo);
2293 }
2294
2295 const Shape shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
2296 const int64_t dimension = hlo->concatenate_dimension();
2297 if (sharding.tile_assignment().dim(dimension) == 1) {
2298 std::vector<HloInstruction*> new_operands;
2299 for (HloInstruction* operand : hlo->operands()) {
2300 new_operands.push_back(
2301 GetPartitionedHlo(operand).Reshard(sharding).hlo());
2302 }
2303 SetPartitionedHlo(hlo, [&] {
2304 return b_.AddInstruction(
2305 hlo->CloneWithNewOperands(shard_shape, new_operands));
2306 });
2307 return OkStatus();
2308 }
2309
2310   // If the concatenate dimension is along one of the partitioned dimensions,
2311   // allocate the full output shape: each partition updates its owned region,
2312   // the partitions all-reduce, and then each slices out its own output region.
2313
2314   // temp_output_shape is the output shape where the concatenate dimension is
2315   // changed to the full dimension size, padded to a multiple of the shard count.
2316 auto temp_output_shape = MakePartitionedShape(hlo->shape(), sharding);
2317 auto last_operand_padded_shape =
2318 MakePartitionedShape(hlo->operands().back()->shape(), sharding);
2319   // If the last operand has more padding than temp_output, add extra padding to
2320   // avoid the dynamic-update-slice going out of bounds.
2321 int last_operand_padding =
2322 last_operand_padded_shape.dimensions(dimension) *
2323 sharding.tile_assignment().dim(dimension) -
2324 hlo->operands().back()->shape().dimensions(dimension);
2325 int temp_output_padding = temp_output_shape.dimensions(dimension) *
2326 sharding.tile_assignment().dim(dimension) -
2327 hlo->shape().dimensions(dimension);
2328 int padding_for_last_operand =
2329 last_operand_padding < temp_output_padding
2330 ? 0
2331 : last_operand_padding - temp_output_padding;
2332 temp_output_shape.set_dimensions(
2333 dimension, temp_output_shape.dimensions(dimension) *
2334 sharding.tile_assignment().dim(dimension) +
2335 padding_for_last_operand);
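  // Example (hypothetical shapes): concatenating operands of sizes 5 and 3
  // along the concat dimension over 2 partitions gives an output of size 8
  // (shard size 4) and a last-operand shard size of 2, so
  // last_operand_padding = 1, temp_output_padding = 0, and the temp output
  // dimension becomes 4 * 2 + 1 = 9 so the last dynamic-update-slice stays
  // in bounds.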
2336 auto temp_output = CreateZero(temp_output_shape, &b_);
2337
2338 // Offset of each operand along the concatenate dimension.
2339 int64_t offset = 0;
2340 auto state = MakePartitioningState();
2341 for (HloInstruction* operand : hlo->operands()) {
2342 auto spmd_operand = GetPartitionedHlo(operand).Reshard(sharding).hlo();
2343 std::vector<HloInstruction*> start_indices(
2344 hlo->shape().rank(), b_.AddInstruction(HloInstruction::CreateConstant(
2345 LiteralUtil::Zero(S32))));
2346 start_indices[dimension] =
2347 MultiplyAddDivideOffsetCalculation(
2348 spmd_operand->shape().dimensions(dimension), offset, 1)
2349 .Calculate(MakeTiledPartitionOrdinals(sharding, state.partition_id,
2350 &b_)[dimension],
2351 &b_);
2352 temp_output = b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
2353 temp_output_shape, temp_output, spmd_operand, start_indices));
2354 offset += operand->shape().dimensions(dimension);
2355 }
2356 std::vector<int64_t> non_concat_dims;
2357 non_concat_dims.reserve(hlo->shape().rank() - 1);
2358 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
2359 if (i != dimension) {
2360 non_concat_dims.push_back(i);
2361 }
2362 }
2363 auto grouped =
2364 hlo_sharding_util::GroupShardingOnDims(sharding, non_concat_dims);
2365 auto per_group_partitioner_state =
2366 CreatePerGroupPartitioningState(state, grouped.device_groups, &b_);
2367 auto all_reduce = per_group_partitioner_state.collective_ops_creator
2368 .create_cross_partition_all_reduce(
2369 &b_, temp_output,
2370 MakeBinaryAdd(hlo->shape().element_type(), module_),
2371 {}, NewChannel());
2372 SetPartitionedHlo(hlo, [&] {
2373 auto start_indices = MakeTiledPartitionOrdinals(
2374 grouped.sharding, per_group_partitioner_state.partition_id, &b_);
2375 start_indices[dimension] = MultiplyAddDivideOffsetCalculation(
2376 shard_shape.dimensions(dimension), 0, 1)
2377 .Calculate(start_indices[dimension], &b_);
2378 return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
2379 shard_shape, all_reduce, start_indices, shard_shape.dimensions()));
2380 });
2381
2382 return OkStatus();
2383 }
2384
2385 Status SpmdPartitioningVisitor::HandleSlice(HloInstruction* hlo) {
2386 const HloSharding& sharding = hlo->sharding();
2387 if (sharding.IsTileMaximal()) {
2388 return DefaultAction(hlo);
2389 }
2390
2391 auto operand = GetPartitionedHlo(hlo->operand(0)).Reshard(sharding);
2392
2393 const int64_t rank = hlo->shape().rank();
2394 // Create a window config to represent the slice.
2395 Window window;
2396 for (int64_t i = 0; i < rank; ++i) {
2397 WindowDimension* dim = window.add_dimensions();
2398 dim->set_size(1);
2399 dim->set_stride(hlo->slice_strides(i));
2400 dim->set_window_dilation(1);
2401 dim->set_window_reversal(false);
2402 dim->set_padding_low(-hlo->slice_starts(i));
2403 dim->set_padding_high(hlo->slice_limits(i) -
2404 operand.base_shape().dimensions(i));
2405 dim->set_base_dilation(1);
2406 }
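  // Example (hypothetical slice): slicing [2:10:2] of a dimension of size 16
  // becomes a window dimension with stride 2, padding_low = -2 and
  // padding_high = 10 - 16 = -6.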
2407
2408 auto reshard_operand = operand.ReshardAsWindowedInput(
2409 window, sharding,
2410 CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
2411 /*mask_invalid_region=*/false);
2412 if (!reshard_operand.has_value()) {
2413 return DefaultAction(hlo);
2414 }
2415 TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value());
2416 const Shape& operand_shape = reshard_operand->sharded_input->shape();
2417
2418 std::vector<int64_t> start_indices(rank);
2419 std::vector<int64_t> limit_indices(rank);
2420 const std::vector<int64_t>& strides = hlo->slice_strides();
2421 bool need_slice = false;
2422 for (int64_t i = 0; i < rank; ++i) {
2423 auto dim = reshard_operand->shard_window.dimensions(i);
2424 start_indices[i] = -dim.padding_low();
2425 limit_indices[i] = operand_shape.dimensions(i) + dim.padding_high();
2426 if (start_indices[i] != 0 || strides[i] != 1 ||
2427 limit_indices[i] != operand_shape.dimensions(i)) {
2428 need_slice = true;
2429 }
2430 }
2431
2432 SetPartitionedHlo(hlo, [&] {
2433 if (need_slice) {
2434 auto shard_shape = MakePartitionedShape(hlo->shape(), sharding);
2435 return b_.AddInstruction(HloInstruction::CreateSlice(
2436 shard_shape, reshard_operand->sharded_input, start_indices,
2437 limit_indices, strides));
2438 }
2439 auto data = reshard_operand->sharded_input;
2440 // Create a copy so that it will not share the resharding cache.
2441 return b_.AddInstruction(
2442 HloInstruction::CreateUnary(data->shape(), HloOpcode::kCopy, data));
2443 });
2444
2445 return OkStatus();
2446 }
2447
2448 Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) {
2449 HloSharding sharding = hlo->sharding();
2450 if (sharding.HasUniqueDevice()) {
2451 return DefaultAction(hlo);
2452 }
2453   // Special handling for sort in TopK when the first operand is partitioned at
2454   // the sort dimension.
2455 auto k = GetKValueInTopKWhenPartitionSortDim(hlo);
2456 if (k.has_value()) {
2457     // When the first operand is partitioned at the sort dimension:
2458 // 1. Partition sort computation to different partitions;
2459 // 2. Slice TopK value and index from different partitions;
2460 // 3. Gather and replicate value and index from different partitions,
2461 // the shape of replicated value and index will be
2462 // [batch_size, ..., partition_count * k, ...];
2463 // 4. Final sort uses replicated value and index from different partitions
2464 // as input.
2465     // GetTupleElement and Slice after the non-partitioned sort won't change
2466     // at this point, as HandleGetTupleElement and HandleSlice will update them.
2467 HloSortInstruction* sort = DynCast<HloSortInstruction>(hlo);
2468 const int64_t sort_dim = sort->sort_dimension();
2469 auto input = hlo->operand(0);
2470 auto index = hlo->operand(1);
2471 const HloSharding& input_sharding = input->sharding();
2472 const int64_t partition_count =
2473 input_sharding.tile_assignment().dim(sort_dim);
2474 const int64_t input_size = input->shape().dimensions(sort_dim);
2475 const auto element_type = input->shape().element_type();
2476 const auto index_type = index->shape().element_type();
2477
2478 // Partition and pad input and index.
2479 // Pad input with minimal value.
2480 auto partitioned_input = GetPartitionedHlo(input).PadWithValue(
2481 CreateFirstWithType(element_type, &b_));
2482 // Pad index with max value.
2483 auto partitioned_index =
2484 GetPartitionedHlo(index)
2485 .Reshard(input_sharding)
2486 .PadWithValue(CreateLastWithType(index_type, &b_));
2487
2488 // Each partition needs to do TopK separately, thus the base shape
2489 // becomes the padded shape.
2490 std::vector<int64_t> replicated_dimensions(
2491 input->shape().dimensions().begin(), input->shape().dimensions().end());
2492 replicated_dimensions[sort_dim] = RoundUpTo(input_size, partition_count);
2493 const Shape replicated_shape = ShapeUtil::MakeTupleShape(
2494 {ShapeUtil::MakeShape(element_type, replicated_dimensions),
2495 ShapeUtil::MakeShape(index_type, replicated_dimensions)});
2496
2497 // Partition original topk to different shards.
2498 auto topk_sharding =
2499 input_sharding.GetTupleSharding(replicated_shape).ValueOrDie();
2500 auto shard_shape = MakePartitionedShape(replicated_shape, topk_sharding);
2501 auto topk = b_.AddInstruction(hlo->CloneWithNewOperands(
2502 shard_shape, {partitioned_input.hlo(), partitioned_index.hlo()}));
2503
2504 // Get value from first sort.
2505 HloInstruction* value_gte =
2506 b_.AddInstruction(HloInstruction::CreateGetTupleElement(
2507 topk->shape().tuple_shapes(0), topk, 0));
2508 HloInstruction* index_gte =
2509 b_.AddInstruction(HloInstruction::CreateGetTupleElement(
2510 topk->shape().tuple_shapes(1), topk, 1));
2511
2512 // Slice top K value from the first partitioned sort.
2513 replicated_dimensions[sort_dim] = k.value() * partition_count;
2514 auto slice_input = SliceFirstK(value_gte, &b_, sort_dim, k.value());
2515 slice_input->set_sharding(input_sharding);
2516 PartitionedHlo partitioned_slice_input(
2517 slice_input, ShapeUtil::MakeShape(element_type, replicated_dimensions),
2518 MakePartitioningState());
2519 // Reshard value to be replicated.
2520 auto replicated_slice_input =
2521 partitioned_slice_input.Reshard(HloSharding::Replicate()).hlo();
2522
2523     // Slice top K index from the first partitioned sort.
2524 auto slice_index = SliceFirstK(index_gte, &b_, sort_dim, k.value());
2525 slice_index->set_sharding(input_sharding);
2526 PartitionedHlo partitioned_slice_index(
2527 slice_index, ShapeUtil::MakeShape(index_type, replicated_dimensions),
2528 MakePartitioningState());
2529     // Reshard index to be replicated.
2530 auto replicated_slice_index =
2531 partitioned_slice_index.Reshard(HloSharding::Replicate()).hlo();
2532
2533     // Create a replicated sort to do TopK; the input is the value and index
2534     // pairs from all the partitions.
2535 const Shape final_topk_shape = ShapeUtil::MakeTupleShape(
2536 {ShapeUtil::MakeShape(element_type, replicated_dimensions),
2537 ShapeUtil::MakeShape(index_type, replicated_dimensions)});
2538 HloInstruction* final_sort = b_.AddInstruction(HloInstruction::CreateSort(
2539 final_topk_shape, sort_dim,
2540 {replicated_slice_input, replicated_slice_index}, sort->to_apply(),
2541 sort->is_stable()));
2542 final_sort->set_sharding(HloSharding::Replicate()
2543 .GetTupleSharding(final_sort->shape())
2544 .ValueOrDie());
2545 PartitionedHlo replicated_sort(final_sort, final_sort->shape(),
2546 MakePartitioningState());
2547 SetPartitionedHlo(hlo, replicated_sort.Reshard(hlo->sharding()));
2548
2549 return OkStatus();
2550 }
2551
2552 if (hlo->shape().IsTuple()) {
2553 // Check that all elements are sharded in the same way.
2554 if (hlo->shape().tuple_shapes_size() == 0) {
2555 return DefaultAction(hlo);
2556 }
2557 sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
2558 for (int64_t i = 1; i < hlo->operand_count(); ++i) {
2559 if (sharding != hlo->sharding().GetSubSharding(hlo->shape(), {i})) {
2560 return DefaultAction(hlo);
2561 }
2562 }
2563 }
2564 if (sharding.IsTileMaximal()) {
2565 return DefaultAction(hlo);
2566 }
2567 for (int64_t dim : hlo->dimensions()) {
2568 if (sharding.tile_assignment().dim(dim) > 1) {
2569 return DefaultAction(hlo);
2570 }
2571 }
2572 // Reshard operands to the same as the output.
2573 std::vector<HloInstruction*> new_operands;
2574 for (HloInstruction* operand : hlo->operands()) {
2575 new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
2576 }
2577 SetPartitionedHlo(hlo, [&] {
2578 return b_.AddInstruction(hlo->CloneWithNewOperands(
2579 MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands));
2580 });
2581 return OkStatus();
2582 }
2583
2584 Status SpmdPartitioningVisitor::HandleTranspose(HloInstruction* hlo) {
2585 const HloSharding& sharding = hlo->sharding();
2586 if (sharding.IsTileMaximal()) {
2587 return DefaultAction(hlo);
2588 }
2589
2590 std::vector<int64_t> inverse_dimensions(hlo->shape().rank());
2591 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
2592 inverse_dimensions[hlo->dimensions(i)] = i;
2593 }
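  // Example (hypothetical permutation): for dimensions() = {2,0,1},
  // inverse_dimensions = {1,2,0}.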
2594 auto desired_operand_sharding =
2595 hlo_sharding_util::TransposeSharding(sharding, inverse_dimensions);
2596
2597 auto operand = GetPartitionedHlo(hlo->operand(0))
2598 .Reshard(desired_operand_sharding)
2599 .hlo();
2600 SetPartitionedHlo(hlo, [&] {
2601 return b_.AddInstruction(hlo->CloneWithNewOperands(
2602 MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand}));
2603 });
2604 return OkStatus();
2605 }
2606
2607 Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
2608 const HloSharding& sharding = hlo->sharding();
2609 if (sharding.IsTileMaximal()) {
2610 return DefaultAction(hlo);
2611 }
2612
2613 auto operand = GetPartitionedHlo(hlo->operand(0));
2614   // To get the aligned sharding for the operand, the output shape is treated as
2615   // the source and the operand shape as the target.
2616 std::optional<HloSharding> desired_operand_sharding =
2617 hlo_sharding_util::ReshapeSharding(hlo->shape(), hlo->operand(0)->shape(),
2618 hlo->sharding());
2619 // Use the desired operand sharding only if the number of tiles returned
2620 // matches the number of tiles in the output.
2621 if (desired_operand_sharding.has_value() &&
2622 hlo->sharding().NumTiles() == desired_operand_sharding->NumTiles()) {
2623 auto operand_hlo = operand.Reshard(*desired_operand_sharding).hlo();
2624 SetPartitionedHlo(hlo, [&] {
2625 return b_.AddInstruction(hlo->CloneWithNewOperands(
2626 MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand_hlo}));
2627 });
2628 return OkStatus();
2629 }
2630 std::optional<HloSharding> desired_output_sharding =
2631 hlo_sharding_util::ReshapeSharding(hlo->operand(0)->shape(), hlo->shape(),
2632 operand.sharding());
2633 if (desired_output_sharding.has_value()) {
2634 auto reshape = b_.AddInstruction(hlo->CloneWithNewOperands(
2635 MakePartitionedShape(hlo->shape(), *desired_output_sharding),
2636 {operand.hlo()}));
2637 reshape->set_sharding(*desired_output_sharding);
2638 SetPartitionedHlo(hlo, [&] {
2639 return PartitionedHlo(reshape, hlo->shape(), MakePartitioningState())
2640 .Reshard(sharding)
2641 .hlo();
2642 });
2643 return OkStatus();
2644 }
2645
2646   // Check that the operand sharding and the output sharding are both tiled or
2647   // both partially replicated with the same replication count.
2648 if (operand.sharding().ReplicateOnLastTileDim() !=
2649 sharding.ReplicateOnLastTileDim() ||
2650 (sharding.ReplicateOnLastTileDim() &&
2651 (operand.sharding().tile_assignment().dimensions().back() !=
2652 sharding.tile_assignment().dimensions().back()))) {
2653 return DefaultAction(hlo);
2654 }
2655
2656   // Try to use halo exchange for certain split-dim/merge-dims cases.
2657 // ReshapeSharding failed in these cases probably due to uneven partitioning,
2658 // where halo exchange could help. Specifically we check the following
2659 // conditions to detect supported cases:
2660 // 1) Both input and output are partitioned on one dimension.
2661 // 2) The combined size of the dimensions before the partitioned dimension is
2662 // the same on input and output. This means we don't need to consider the
2663 // major dimensions.
2664 // 3) Let A = the input size on the partitioned dimension, and
2665 // B = the output size on the partitioned dimension; then
2666 // either A % B == 0 (split dim) or B % A == 0 (merge dims).
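// Illustrative example (shapes invented for this comment): reshaping
// f32[6,10] tiled 4 ways on dim 1 into f32[6,5,2] tiled 4 ways on dim 1
// satisfies all three conditions: each side has one tiled dimension, the
// major dims before it have combined size 6 on both sides, and A = 10,
// B = 5 with A % B == 0, so the split-dim path below applies.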
2667 auto maybe_input_sharded_dim = UniqueTiledDim(operand.sharding());
2668 auto maybe_output_sharded_dim = UniqueTiledDim(sharding);
2669 if (!maybe_input_sharded_dim || !maybe_output_sharded_dim) {
2670 return DefaultAction(hlo);
2671 }
2672 int64_t input_sharded_dim = *maybe_input_sharded_dim;
2673 int64_t output_sharded_dim = *maybe_output_sharded_dim;
2674 // Check that the major dims before the sharded dim have the same total size
2675 // for input and output.
2676 int64_t input_major_dims_size = 1;
2677 for (int64_t i = 0; i < input_sharded_dim; ++i) {
2678 input_major_dims_size *= operand.base_shape().dimensions(i);
2679 }
2680 int64_t output_major_dims_size = 1;
2681 for (int64_t i = 0; i < output_sharded_dim; ++i) {
2682 output_major_dims_size *= hlo->shape().dimensions(i);
2683 }
2684 if (input_major_dims_size != output_major_dims_size) {
2685 return DefaultAction(hlo);
2686 }
2687 // Fix potential device ordering mismatch in tile assignment.
2688 Array<int64_t> new_input_tile_assignment = sharding.tile_assignment();
2689 new_input_tile_assignment.Reshape(
2690 operand.sharding().tile_assignment().dimensions());
2691 auto aligned_sharding =
2692 sharding.ReplicateOnLastTileDim()
2693 ? HloSharding::PartialTile(new_input_tile_assignment)
2694 : HloSharding::Tile(new_input_tile_assignment);
2695 operand = operand.Reshard(aligned_sharding);
2696 auto replication_count = sharding.ReplicateOnLastTileDim()
2697 ? sharding.tile_assignment().dimensions().back()
2698 : 1;
2699
2700 int64_t input_dim_size = operand.base_shape().dimensions(input_sharded_dim);
2701 int64_t output_dim_size = hlo->shape().dimensions(output_sharded_dim);
2702 auto input_shard_shape =
2703 MakePartitionedShape(operand.base_shape(), operand.sharding());
2704 auto output_shard_shape = MakePartitionedShape(hlo->shape(), sharding);
2705 if (input_dim_size % output_dim_size == 0) {
2706 // Split dim.
2707 int64_t split_factor = input_dim_size / output_dim_size;
2708 int64_t output_shard_size =
2709 output_shard_shape.dimensions(output_sharded_dim);
2710 // Use halo exchange to fix misaligned data.
2711 Window window;
2712 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
2713 WindowDimension* dim = window.add_dimensions();
2714 dim->set_size(1);
2715 dim->set_stride(1);
2716 dim->set_window_dilation(1);
2717 dim->set_window_reversal(false);
2718 dim->set_base_dilation(1);
2719 dim->set_padding_low(0);
2720 if (i == input_sharded_dim) {
2721 dim->set_padding_high(output_shard_size * split_factor *
2722 num_partitions_ / replication_count -
2723 input_dim_size);
2724 } else {
2725 dim->set_padding_high(0);
2726 }
2727 }
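// Continuing the illustrative [6,10] -> [6,5,2] example with 4 partitions
// and no partial replication: split_factor = 10 / 5 = 2 and
// output_shard_size = ceil(5 / 4) = 2, so padding_high =
// 2 * 2 * 4 / 1 - 10 = 6; the windowed reshard below effectively pads the
// sharded dim from 10 to 16 so every partition holds exactly 4 elements
// before the local reshape.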
2728
2729 auto reshard_operand = operand.ReshardAsWindowedInput(
2730 window, operand.sharding(),
2731 CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
2732 /*mask_invalid_region=*/false);
2733 if (!reshard_operand.has_value()) {
2734 return DefaultAction(hlo);
2735 }
2736 TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value());
2737 CHECK_EQ(
2738 reshard_operand->sharded_input->shape().dimensions(input_sharded_dim),
2739 output_shard_size * split_factor);
2740 SetPartitionedHlo(hlo, [&] {
2741 // Do a local reshape.
2742 return b_.AddInstruction(HloInstruction::CreateReshape(
2743 output_shard_shape, reshard_operand->sharded_input));
2744 });
2745 return OkStatus();
2746 } else if (output_dim_size % input_dim_size == 0) {
2747 // Merge dims.
2748 int64_t merge_factor = output_dim_size / input_dim_size;
2749 // First reshape locally. (The sharded dimension could include padded data.)
2750 auto tmp_shard_shape = output_shard_shape;
2751 tmp_shard_shape.set_dimensions(
2752 output_sharded_dim,
2753 input_shard_shape.dimensions(input_sharded_dim) * merge_factor);
2754 auto tmp_reshape = b_.AddInstruction(
2755 HloInstruction::CreateReshape(tmp_shard_shape, operand.hlo()));
2756 tmp_reshape->set_sharding(hlo->sharding());
2757 auto tmp_full_shape = tmp_shard_shape;
2758 tmp_full_shape.set_dimensions(
2759 output_sharded_dim, tmp_shard_shape.dimensions(output_sharded_dim) *
2760 num_partitions_ / replication_count);
2761 auto tmp_output =
2762 PartitionedHlo(tmp_reshape, tmp_full_shape, MakePartitioningState());
2763
2764 // Use halo exchange to fix misaligned data.
2765 Window window;
2766 for (int64_t i = 0; i < tmp_shard_shape.rank(); ++i) {
2767 WindowDimension* dim = window.add_dimensions();
2768 dim->set_size(1);
2769 dim->set_stride(1);
2770 dim->set_window_dilation(1);
2771 dim->set_window_reversal(false);
2772 dim->set_base_dilation(1);
2773 dim->set_padding_low(0);
2774 if (i == output_sharded_dim) {
2775 dim->set_padding_high(output_dim_size -
2776 tmp_shard_shape.dimensions(output_sharded_dim) *
2777 num_partitions_ / replication_count);
2778 } else {
2779 dim->set_padding_high(0);
2780 }
2781 }
2782
2783 auto reshard_output = tmp_output.ReshardAsWindowedInput(
2784 window, sharding,
2785 CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
2786 /*mask_invalid_region=*/false);
2787 if (!reshard_output.has_value()) {
2788 return DefaultAction(hlo);
2789 }
2790 TF_RET_CHECK(!reshard_output->dynamic_slice_index_on_output.has_value());
2791 CHECK_EQ(
2792 reshard_output->sharded_input->shape().dimensions(output_sharded_dim),
2793 output_shard_shape.dimensions(output_sharded_dim));
2794 SetPartitionedHlo(hlo, [&] { return reshard_output->sharded_input; });
2795 return OkStatus();
2796 }
2797 return DefaultAction(hlo);
2798 }
2799
2800 Status SpmdPartitioningVisitor::HandleIota(HloInstruction* hlo) {
2801 const HloSharding& sharding = hlo->sharding();
2802 if (sharding.IsTileMaximal()) {
2803 return DefaultAction(hlo);
2804 }
2805
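// For example (illustrative sizes): an iota of length 16 on a dimension
// split 4 ways produces a local iota of length 4 per shard; the code below
// then adds partition_ordinal * 4, so partition p yields
// {4p, 4p + 1, 4p + 2, 4p + 3}.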
2806 SetPartitionedHlo(hlo, [&] {
2807 int64_t dimension = Cast<HloIotaInstruction>(hlo)->iota_dimension();
2808 auto iota = b_.AddInstruction(HloInstruction::CreateIota(
2809 MakePartitionedShape(hlo->shape(), sharding), dimension));
2810
2811 if (sharding.tile_assignment().dim(dimension) > 1) {
2812 auto partition_ordinals = MakeTiledPartitionOrdinals(
2813 sharding, MakePartitioningState().partition_id, &b_);
2814 auto multiplier = b_.AddInstruction(HloInstruction::CreateConstant(
2815 LiteralUtil::CreateR0<int32_t>(iota->shape().dimensions(dimension))));
2816 auto offset = b_.AddInstruction(HloInstruction::CreateBinary(
2817 ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply,
2818 partition_ordinals[dimension], multiplier));
2819 if (iota->shape().element_type() != S32) {
2820 offset = b_.AddInstruction(HloInstruction::CreateConvert(
2821 ShapeUtil::MakeShape(iota->shape().element_type(), {}), offset));
2822 }
2823 auto broadcast = b_.AddInstruction(
2824 HloInstruction::CreateBroadcast(iota->shape(), offset, {}));
2825 return b_.AddInstruction(HloInstruction::CreateBinary(
2826 iota->shape(), HloOpcode::kAdd, iota, broadcast));
2827 }
2828
2829 return iota;
2830 });
2831
2832 return OkStatus();
2833 }
2834
2835 Status SpmdPartitioningVisitor::HandleSingleDevice(const HloInstruction* hlo) {
2836 TF_RET_CHECK(hlo->sharding().HasUniqueDevice());
2837 int64_t device = hlo->sharding().GetUniqueDevice();
2838 const HloSharding sharding = HloSharding::AssignDevice(device);
2839
2840 std::vector<HloInstruction*> operands;
2841 std::vector<const Shape*> operand_shapes;
2842 const auto& old_operands = hlo->operands();
2843 const auto old_operands_size = old_operands.size();
2844 operands.reserve(old_operands_size);
2845 operand_shapes.reserve(old_operands_size);
2846 for (const HloInstruction* operand : old_operands) {
2847 operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
2848 operand_shapes.push_back(&operand->shape());
2849 }
2850 auto operand = b_.AddInstruction(HloInstruction::CreateTuple(operands));
2851 auto operand_shape = ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes);
2852
2853 auto on_device = b_.AddInstruction(
2854 HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32_t>(device)));
2855 auto pred = b_.AddInstruction(HloInstruction::CreateCompare(
2856 ShapeUtil::MakeShape(PRED, {}), MakePartitioningState().partition_id,
2857 on_device, ComparisonDirection::kEq));
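// Every partition evaluates the conditional built below: only the partition
// whose id equals `device` runs the original instruction (the true branch);
// all other partitions produce zeros of the same shape.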
2858
2859 SpmdBuilder true_b("true_computation", visiting_hlo_);
2860 HloComputation* true_computation;
2861 {
2862 auto param = true_b.AddInstruction(HloInstruction::CreateParameter(
2863 /*parameter_number=*/0, operand_shape, "true_branch_param"));
2864 std::vector<HloInstruction*> new_operands;
2865 for (int64_t i = 0; i < operands.size(); ++i) {
2866 new_operands.push_back(true_b.AddInstruction(
2867 HloInstruction::CreateGetTupleElement(*operand_shapes[i], param, i)));
2868 }
2869 auto root = true_b.AddInstruction(
2870 hlo->CloneWithNewOperands(hlo->shape(), new_operands));
2871 true_computation = module_->AddEmbeddedComputation(true_b.Build(root));
2872 }
2873
2874 SpmdBuilder false_b("false_computation", visiting_hlo_);
2875 HloComputation* false_computation;
2876 {
2877 false_b.AddInstruction(HloInstruction::CreateParameter(
2878 /*parameter_number=*/0, operand_shape, "false_branch_param"));
2879 auto root = CreateZero(hlo->shape(), &false_b);
2880 false_computation = module_->AddEmbeddedComputation(false_b.Build(root));
2881 }
2882
2883 SetPartitionedHlo(hlo, [&]() {
2884 return b_.AddInstruction(HloInstruction::CreateConditional(
2885 hlo->shape(), pred, operand, true_computation, operand,
2886 false_computation));
2887 });
2888 return OkStatus();
2889 }
2890
2891 Status SpmdPartitioningVisitor::HandleAllReduce(HloInstruction* hlo) {
2892 if (hlo->IsCrossReplicaAllReduce() && hlo->operand_count() == 1) {
2893 return HandleElementwise(hlo);
2894 }
2895 if (hlo->channel_id()) {
2896 TF_RET_CHECK(hlo->operand_count() == 1)
2897 << "SPMD partitioner supports only single-operand allreduce in manual "
2898 "partitioning mode.";
2899 if (hlo->sharding().IsManual()) {
2900 return HandleElementwise(hlo);
2901 }
2902 TF_RET_CHECK(hlo->sharding().IsManualSubgroup())
2903 << "Cross-partition allreduce must be in (partial) manual partitioning "
2904 "mode.";
2905 auto* ar = Cast<HloAllReduceInstruction>(hlo);
2906 TF_RET_CHECK(ar->use_global_device_ids())
2907 << "Cross-partition allreduce in partial manual partitioning mode must "
2908 "use global device IDs.";
2909 absl::flat_hash_map<int64_t, int64_t> partition_to_group_id;
2910 hlo->sharding().tile_assignment().Each(
2911 [&](absl::Span<const int64_t> indices, int64_t partition) {
2912 int64_t group_id = 0;
2913 for (int64_t i = 0; i < indices.size(); ++i) {
2914 if (i == hlo->sharding().SubgroupManualDim()) {
2915 continue;
2916 }
2917 group_id *= hlo->sharding().tile_assignment().dim(i);
2918 group_id += indices[i];
2919 }
2920 partition_to_group_id[partition] = group_id;
2921 });
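// For example (illustrative assignment): with a 2x2 tile assignment and
// SubgroupManualDim() == 1, the loop above skips dimension 1, so the two
// partitions in each row of the assignment share a group id; the check
// below rejects replica groups that mix partitions from different rows.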
2922 for (const auto& group : ar->replica_groups()) {
2923 int64_t first_partition = group.replica_ids(0) % num_partitions_;
2924 for (int64_t device : group.replica_ids()) {
2925 int64_t partition = device % num_partitions_;
2926 if (partition_to_group_id[partition] !=
2927 partition_to_group_id[first_partition]) {
2928 return InvalidArgumentStrCat(
2929 "Manual all-reduce across devices that belong to different "
2930 "manual subgroups: ",
2931 ar->ToString());
2932 }
2933 }
2934 }
2935 return HandleElementwise(hlo);
2936 }
2937 return DefaultAction(hlo);
2938 }
2939
2940 Status SpmdPartitioningVisitor::HandleBroadcast(HloInstruction* hlo) {
2941 if (hlo->sharding().IsTileMaximal()) {
2942 return DefaultAction(hlo);
2943 }
2944
2945 auto& operand = GetPartitionedHlo(hlo->operand(0));
2946
2947 // Tiled output.
2948 std::vector<int64_t> new_dims;
2949 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
2950 if (!absl::c_linear_search(hlo->dimensions(), i)) {
2951 new_dims.push_back(i);
2952 }
2953 }
2954 auto desired_input_sharding = hlo_sharding_util::RemoveShapeDimensions(
2955 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(hlo->sharding(),
2956 new_dims),
2957 new_dims);
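// For example (illustrative shapes): broadcasting f32[8] to f32[4,8] with
// dimensions() = {1} and a 2x2 tiled output sharding gives new_dims = {0};
// the output sharding is, roughly, made replicated along dim 0 and that dim
// is dropped, leaving an operand sharding that tiles the operand's single
// dimension 2 ways with partial replication across the remaining partitions.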
2958 auto input = operand.Reshard(desired_input_sharding).hlo();
2959 auto output_shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
2960 SetPartitionedHlo(hlo, [&] {
2961 return b_.AddInstruction(
2962 hlo->CloneWithNewOperands(output_shard_shape, {input}));
2963 });
2964 return OkStatus();
2965 }
2966
2967 Status SpmdPartitioningVisitor::HandleConstant(HloInstruction* hlo) {
2968 const Literal& literal = hlo->literal();
2969 if (literal.shape().IsTuple() ||
2970 (!hlo->sharding().IsTileMaximal() &&
2971 (!EvenlyPartitions(hlo->shape(), hlo->sharding()) ||
2972 !literal.IsAllFirst()))) {
2973 return DefaultAction(hlo);
2974 }
2975
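// The fast path below applies, e.g., to a splat constant of shape f32[8]
// tiled 4 ways: every element equals the first, so each partition can simply
// materialize a [2]-shaped slice of the literal starting at offset 0.
// (Example shapes are illustrative only.)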
2976 SetPartitionedHlo(hlo, [&]() {
2977 auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
2978 std::vector<int64_t> start_indices(hlo->shape().rank(), 0);
2979 auto constant = b_.AddInstruction(HloInstruction::CreateConstant(
2980 literal.Slice(start_indices, shard_shape.dimensions())));
2981 *constant->mutable_shape() = shard_shape;
2982 return constant;
2983 });
2984 return OkStatus();
2985 }
2986
2987 Status SpmdPartitioningVisitor::HandleDynamicSlice(HloInstruction* hlo) {
2988 if (hlo->sharding().IsTileMaximal()) {
2989 return DefaultAction(hlo);
2990 }
2991 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
2992 if (hlo->sharding().tile_assignment().dim(i) != 1 &&
2993 hlo->dynamic_slice_sizes()[i] !=
2994 hlo->operand(0)->shape().dimensions(i)) {
2995 // We currently do not partition the sliced dimensions.
2996 return DefaultAction(hlo);
2997 }
2998 }
2999 std::vector<HloInstruction*> new_indices(hlo->shape().rank());
3000 auto new_input =
3001 GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo();
3002 for (int64_t i = 0; i < new_indices.size(); ++i) {
3003 if (hlo->dynamic_slice_sizes()[i] ==
3004 hlo->operand(0)->shape().dimensions(i)) {
3005 // Trivial slice dim: index must be clamped to 0.
3006 new_indices[i] = CreateZero(hlo->operand(i + 1)->shape(), &b_);
3007 continue;
3008 }
3009 // Replicate the indices.
3010 new_indices[i] = GetPartitionedHlo(hlo->operand(i + 1))
3011 .Reshard(HloSharding::Replicate())
3012 .hlo();
3013 }
3014 SetPartitionedHlo(hlo, [&]() {
3015 auto partitioned_shape =
3016 MakePartitionedShape(hlo->shape(), hlo->sharding());
3017 return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
3018 partitioned_shape, new_input, new_indices,
3019 partitioned_shape.dimensions()));
3020 });
3021 return OkStatus();
3022 }
3023
3024 Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(HloInstruction* hlo) {
3025 if (hlo->sharding().IsTileMaximal()) {
3026 return DefaultAction(hlo);
3027 }
3028
3029 std::vector<int64_t> partitioned_slice_dims;
3030 std::vector<int64_t> slice_dims;
3031 std::vector<int64_t> partitioned_non_slice_dims;
3032 std::vector<int64_t> partitioned_slice_offsets;
3033 bool any_non_constant_sliced_dim = false;
3034 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
3035 if (hlo->operand(1)->shape().dimensions(i) != hlo->shape().dimensions(i)) {
3036 slice_dims.push_back(i);
3037 int64_t slice_size = hlo->operand(1)->shape().dimensions(i);
3038 if (hlo->sharding().tile_assignment().dim(i) != 1) {
3039 if (!hlo->operand(i + 2)->IsConstant() && slice_size != 1) {
3040 any_non_constant_sliced_dim = true;
3041 continue;
3042 }
3043 partitioned_slice_dims.push_back(i);
3044 // Set partitioned_slice_offsets to -1 when slice_size is 1.
3045 if (slice_size == 1) {
3046 partitioned_slice_offsets.push_back(-1);
3047 } else {
3048 partitioned_slice_offsets.push_back(
3049 hlo->operand(i + 2)->literal().Get<int>({}));
3050 }
3051 }
3052 } else if (hlo->sharding().tile_assignment().dim(i) != 1) {
3053 partitioned_non_slice_dims.push_back(i);
3054 }
3055 }
3056 auto handle_with_replicate_slice_dims = [&]() {
3057 HloSharding replicated_sharding =
3058 hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept(
3059 hlo->operand(0)->sharding(), partitioned_non_slice_dims);
3060 auto base = GetPartitionedHlo(hlo->operand(0)).Reshard(replicated_sharding);
3061 auto operand =
3062 GetPartitionedHlo(hlo->operand(1)).Reshard(replicated_sharding);
3063 std::vector<HloInstruction*> new_indices(hlo->shape().rank());
3064 for (int64_t i = 0; i < new_indices.size(); ++i) {
3065 // Replicate the indices.
3066 new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
3067 .Reshard(HloSharding::Replicate())
3068 .hlo();
3069 }
3070 auto dus = b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
3071 base.hlo()->shape(), base.hlo(), operand.hlo(), new_indices));
3072 dus->set_sharding(replicated_sharding);
3073 SetPartitionedHlo(hlo, PartitionedHlo(dus, base.base_shape(), base.state())
3074 .Reshard(hlo->sharding()));
3075 };
3076 if (any_non_constant_sliced_dim) {
3077 if (partitioned_non_slice_dims.empty()) {
3078 return DefaultAction(hlo);
3079 }
3080 handle_with_replicate_slice_dims();
3081 return OkStatus();
3082 }
3083
3084 // Handle the case where a slice dimension is partitioned.
3085 if (!partitioned_slice_dims.empty()) {
3086 auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) {
3087 return b_.AddInstruction(std::move(to_add));
3088 };
3089 std::vector<HloInstruction*> new_indices(hlo->shape().rank());
3090 for (int64_t i = 0; i < new_indices.size(); ++i) {
3091 if (hlo->operand(1)->shape().dimensions(i) ==
3092 hlo->shape().dimensions(i)) {
3093 new_indices[i] = CreateZero(hlo->operand(i + 2)->shape(), &b_);
3094 continue;
3095 }
3096 // Replicate the indices.
3097 new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
3098 .Reshard(HloSharding::Replicate())
3099 .hlo();
3100 }
3101
3102 // Get partitioned input.
3103 const auto& dus_sharding = hlo->sharding();
3104 const auto& partitioned_input =
3105 GetPartitionedHlo(hlo->operand(0)).Reshard(dus_sharding).hlo();
3106
3107 // Get the replicated update.
3108 auto update_sharding = HloSharding::Replicate();
3109 if (!partitioned_non_slice_dims.empty()) {
3110 // Partially replicate the update if non-slice dims are partitioned.
3111 update_sharding =
3112 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(dus_sharding,
3113 slice_dims);
3114 }
3115
3116 // TODO(wangtao): use collective permute for sharded update.
3117 HloInstruction* replicate_update =
3118 GetPartitionedHlo(hlo->operand(1)).Reshard(update_sharding).hlo();
3119
3120 const auto& update_shape = replicate_update->shape();
3121 const auto& partitioned_shape = partitioned_input->shape();
3122 auto partition_ordinals = MakeTiledPartitionOrdinals(
3123 hlo->sharding(), MakePartitioningState().partition_id, &b_);
3124 HloInstruction* all_dims_within_partition = add_hlo(
3125 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
3126
3127 for (int i = 0; i < partitioned_slice_dims.size(); ++i) {
3128 int dim = partitioned_slice_dims[i];
3129 // Calculate per partition size.
3130 const int64_t per_partition_size = partitioned_shape.dimensions(dim);
3131
3132 // Only an update that falls within a single partition is supported.
3133 // This check is skipped when the slice size is 1, in which case
3134 // partitioned_slice_offsets[i] is -1.
3135 if ((partitioned_slice_offsets[i] != -1) &&
3136 (partitioned_slice_offsets[i] / per_partition_size) !=
3137 ((partitioned_slice_offsets[i] + update_shape.dimensions(dim) -
3138 1) /
3139 per_partition_size)) {
3140 handle_with_replicate_slice_dims();
3141 return OkStatus();
3142 }
3143
3144 // within_partition = (offset >= partition_id * per_partition_size) &&
3145 // (offset < (partition_id + 1) * per_partition_size)
3146 const Shape& compare_shape =
3147 ShapeUtil::ChangeElementType(partition_id_->shape(), PRED);
3148 auto per_partition_size_hlo = add_hlo(HloInstruction::CreateConstant(
3149 LiteralUtil::CreateR0<int>(per_partition_size)));
3150 const Shape& offset_shape = per_partition_size_hlo->shape();
3151 auto partition_offset = add_hlo(HloInstruction::CreateBinary(
3152 offset_shape, HloOpcode::kMultiply, partition_ordinals[dim],
3153 per_partition_size_hlo));
3154 // offset >= partition_id * per_partition_size
3155 auto offset_ge = add_hlo(HloInstruction::CreateCompare(
3156 compare_shape, new_indices[dim], partition_offset,
3157 ComparisonDirection::kGe));
3158 // offset < (partition_id + 1) * per_partition_size
3159 auto offset_lt = add_hlo(HloInstruction::CreateCompare(
3160 compare_shape, new_indices[dim],
3161 add_hlo(HloInstruction::CreateBinary(
3162 offset_shape, HloOpcode::kMultiply,
3163 add_hlo(HloInstruction::CreateBinary(
3164 offset_shape, HloOpcode::kAdd, partition_ordinals[dim],
3165 add_hlo(HloInstruction::CreateConstant(
3166 LiteralUtil::CreateR0<int>(1))))),
3167 per_partition_size_hlo)),
3168 ComparisonDirection::kLt));
3169 auto update_within_partition = add_hlo(HloInstruction::CreateBinary(
3170 compare_shape, HloOpcode::kAnd, offset_ge, offset_lt));
3171
3172 all_dims_within_partition = add_hlo(HloInstruction::CreateBinary(
3173 compare_shape, HloOpcode::kAnd, all_dims_within_partition,
3174 update_within_partition));
3175
3176 // Calculate offset.
3177 // slice dim offset =
3178 // within_partition ?
3179 // offset - partition_id * per_partition_size : 0
3180 new_indices[dim] = add_hlo(HloInstruction::CreateTernary(
3181 new_indices[dim]->shape(), HloOpcode::kSelect,
3182 update_within_partition,
3183 add_hlo(HloInstruction::CreateBinary(
3184 new_indices[dim]->shape(), HloOpcode::kSubtract, new_indices[dim],
3185 partition_offset)),
3186 add_hlo(
3187 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)))));
3188 }
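// Worked example with illustrative numbers: per_partition_size = 4 and a
// constant offset of 6 on the sharded dim. Only the partition with
// ordinal 1 (covering [4, 8)) sees within_partition == true and rewrites its
// index to 6 - 4 = 2; all other partitions use index 0, and the final select
// below keeps their original data unchanged.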
3189
3190 // Create dynamic update slice.
3191 auto dus = add_hlo(HloInstruction::CreateDynamicUpdateSlice(
3192 partitioned_shape, partitioned_input, replicate_update, new_indices));
3193 SetPartitionedHlo(hlo, [&]() {
3194 // Select if update is needed.
3195 return add_hlo(HloInstruction::CreateTernary(
3196 dus->shape(), HloOpcode::kSelect,
3197 add_hlo(HloInstruction::CreateBroadcast(
3198 ShapeUtil::ChangeElementType(dus->shape(), PRED),
3199 all_dims_within_partition, {})),
3200 dus, partitioned_input));
3201 });
3202 return OkStatus();
3203 }
3204
3205 // Partition non-slice dims only.
3206 std::vector<HloInstruction*> new_indices(hlo->shape().rank());
3207 auto new_input =
3208 GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo();
3209 auto new_update =
3210 GetPartitionedHlo(hlo->operand(1)).Reshard(hlo->sharding()).hlo();
3211 for (int64_t i = 0; i < new_indices.size(); ++i) {
3212 if (hlo->operand(1)->shape().dimensions(i) == hlo->shape().dimensions(i)) {
3213 new_indices[i] = CreateZero(hlo->operand(i + 2)->shape(), &b_);
3214 continue;
3215 }
3216 // Replicate the indices.
3217 new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
3218 .Reshard(HloSharding::Replicate())
3219 .hlo();
3220 }
3221 SetPartitionedHlo(hlo, [&]() {
3222 auto partitioned_shape =
3223 MakePartitionedShape(hlo->shape(), hlo->sharding());
3224 return b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
3225 partitioned_shape, new_input, new_update, new_indices));
3226 });
3227 return OkStatus();
3228 }
3229
3230 Status SpmdPartitioningVisitor::HandleGetTupleElement(HloInstruction* hlo) {
3231 const auto& tuple = GetPartitionedHlo(hlo->operand(0));
3232 auto gte = b_.AddInstruction(HloInstruction::CreateGetTupleElement(
3233 ShapeUtil::GetTupleElementShape(tuple.hlo()->shape(), hlo->tuple_index()),
3234 tuple.hlo(), hlo->tuple_index()));
3235 const auto source_sharding =
3236 tuple.sharding().GetSubSharding(tuple.base_shape(), {hlo->tuple_index()});
3237 gte->set_sharding(source_sharding);
3238 PartitionedHlo source_partitioned_gte(
3239 gte, tuple.base_shape().tuple_shapes(hlo->tuple_index()),
3240 MakePartitioningState());
3241 source_partitioned_gte = source_partitioned_gte.Reshard(hlo->sharding());
3242 SetPartitionedHlo(hlo, source_partitioned_gte);
3243 return OkStatus();
3244 }
3245
3246 Status SpmdPartitioningVisitor::HandleInfeed(HloInstruction* hlo) {
3247 const Shape& shape = ShapeUtil::GetTupleElementShape(hlo->shape(), 0);
3248 auto token = GetPartitionedHlo(hlo->operand(0)).hlo();
3249 if (ShapeUtil::GetLeafCount(shape) == 0) {
3250 // TODO(b/155819021): HloSharding has issues with tuple-shaped sharding: it
3251 // requires one element for an empty tuple, but leaf-count number of
3252 // elements for a non-empty tuple. So if it has a nested empty tuple, we
3253 // cannot invoke GetSubSharding() since it expects a sharding for the empty
3254 // tuple. This is a workaround for that case.
3255 SetPartitionedHlo(hlo, [&]() {
3256 return b_.AddInstruction(
3257 HloInstruction::CreateInfeed(shape, token, hlo->infeed_config()));
3258 });
3259 return OkStatus();
3260 }
3261 auto sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
3262 auto shard_shape = MakePartitionedShape(shape, sharding);
3263 if (EvenlyPartitions(shape, sharding)) {
3264 SetPartitionedHlo(hlo, [&]() {
3265 return b_.AddInstruction(HloInstruction::CreateInfeed(
3266 shard_shape, token, hlo->infeed_config()));
3267 });
3268 return OkStatus();
3269 }
3270
3271 if (hlo->sharding().HasUniqueDevice()) {
3272 return HandleSingleDevice(hlo);
3273 }
3274
3275 // Create a branch for each unique partitioned shape.
3276 std::vector<Shape> per_branch_partitioned_shapes;
3277 std::vector<int32_t> conditional_branch_indices(num_partitions_);
3278 for (int64_t i = 0; i < num_partitions_; ++i) {
3279 auto partitioned_shape =
3280 MakeNonPaddedShapeForGivenPartition(shape, sharding, i);
3281 int64_t matching_existing_index = 0;
3282 for (; matching_existing_index < per_branch_partitioned_shapes.size();
3283 ++matching_existing_index) {
3284 if (ShapeUtil::Compatible(
3285 partitioned_shape,
3286 per_branch_partitioned_shapes[matching_existing_index])) {
3287 break;
3288 }
3289 }
3290 if (matching_existing_index < per_branch_partitioned_shapes.size()) {
3291 conditional_branch_indices[i] = matching_existing_index;
3292 } else {
3293 conditional_branch_indices[i] = per_branch_partitioned_shapes.size();
3294 per_branch_partitioned_shapes.push_back(std::move(partitioned_shape));
3295 }
3296 }
3297
3298 HloInstruction* branch_index;
3299 auto state = MakePartitioningState();
3300 if (per_branch_partitioned_shapes.size() == num_partitions_) {
3301 // Use partition ID as the branch index if each partition has its own
3302 // branch.
3303 branch_index = state.partition_id;
3304 // PartitionId's output is U32 but conditional requires S32.
3305 if (branch_index->shape().element_type() != S32) {
3306 branch_index = b_.AddInstruction(HloInstruction::CreateConvert(
3307 ShapeUtil::ChangeElementType(branch_index->shape(), S32),
3308 branch_index));
3309 }
3310 } else {
3311 // Otherwise, use a constant table to look up the branch index.
3312 auto branch_index_table = b_.AddInstruction(HloInstruction::CreateConstant(
3313 LiteralUtil::CreateR1<int32_t>(conditional_branch_indices)));
3314 branch_index = b_.AddInstruction(HloInstruction::CreateDynamicSlice(
3315 ShapeUtil::MakeShape(S32, {1}), branch_index_table,
3316 {state.partition_id}, {1}));
3317 branch_index = b_.AddInstruction(HloInstruction::CreateReshape(
3318 ShapeUtil::MakeShape(S32, {}), branch_index));
3319 }
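// For example (illustrative): with 4 partitions and a dimension of size 10
// split 4 ways, partitions 0-2 receive shards of size 3 and partition 3 a
// shard of size 1, so per_branch_partitioned_shapes has 2 entries and
// conditional_branch_indices is {0, 0, 0, 1}; the constant-table lookup
// above then maps each partition id to its branch.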
3320
3321 std::vector<HloComputation*> branches(per_branch_partitioned_shapes.size());
3322 for (int64_t i = 0; i < branches.size(); ++i) {
3323 SpmdBuilder branch_b(absl::StrCat("infeed_branch_", i), visiting_hlo_);
3324 auto param = branch_b.AddInstruction(HloInstruction::CreateParameter(
3325 /*parameter_number=*/0, token->shape(), "infeed_token_param"));
3326 auto infeed = branch_b.AddInstruction(HloInstruction::CreateInfeed(
3327 per_branch_partitioned_shapes[i], param, hlo->infeed_config()));
3328 if (!ShapeUtil::Compatible(per_branch_partitioned_shapes[i], shard_shape)) {
3329 std::function<HloInstruction*(const ShapeIndex&, HloInstruction*)>
3330 pad_infeed = [&](const ShapeIndex& index,
3331 HloInstruction* infeed_element) -> HloInstruction* {
3332 if (index == ShapeIndex({1})) {
3333 // Token.
3334 return infeed_element;
3335 }
3336 const Shape& element_shape =
3337 ShapeUtil::GetSubshape(infeed->shape(), index);
3338 if (element_shape.IsTuple() && element_shape.tuple_shapes_size() > 0) {
3339 std::vector<HloInstruction*> padded_elements(
3340 element_shape.tuple_shapes_size());
3341 for (int64_t i = 0; i < padded_elements.size(); ++i) {
3342 auto sub_index = index;
3343 sub_index.push_back(i);
3344 padded_elements[i] = pad_infeed(
3345 sub_index,
3346 branch_b.AddInstruction(HloInstruction::CreateGetTupleElement(
3347 ShapeUtil::GetSubshape(element_shape, {i}), infeed_element,
3348 i)));
3349 }
3350 return branch_b.AddInstruction(
3351 HloInstruction::CreateTuple(padded_elements));
3352 }
3353 const Shape& pad_shape = ShapeUtil::GetSubshape(
3354 shard_shape, ShapeIndexView(index).subspan(1));
3355 if (ShapeUtil::Compatible(element_shape, pad_shape)) {
3356 return infeed_element;
3357 }
3358 if (element_shape.IsArray()) {
3359 CHECK(pad_shape.IsArray());
3360 return PadToShape(infeed_element, pad_shape, &branch_b);
3361 }
3362 CHECK(element_shape.IsTuple());
3363 CHECK(element_shape.tuple_shapes().empty());
3364 return CreateZero(pad_shape, &branch_b);
3365 };
3366 pad_infeed({}, infeed);
3367 }
3368 branches[i] = module_->AddEmbeddedComputation(branch_b.Build());
3369 }
3370 SetPartitionedHlo(hlo, [&]() {
3371 return b_.AddInstruction(HloInstruction::CreateConditional(
3372 ShapeUtil::MakeTupleShape({shard_shape, token->shape()}), branch_index,
3373 branches, std::vector<HloInstruction*>(branches.size(), token)));
3374 });
3375 return OkStatus();
3376 }
3377
3378 Status SpmdPartitioningVisitor::HandlePad(HloInstruction* hlo) {
3379 if (hlo->sharding().IsTileMaximal()) {
3380 return DefaultAction(hlo);
3381 }
3382 auto lhs = GetPartitionedHlo(hlo->operand(0));
3383 // Create a window config to represent the pad.
3384 Window window;
3385 bool needs_masking = false;
3386 const bool pad_value_is_zero =
3387 hlo->operand(1)->IsConstant() && hlo->operand(1)->literal().IsZero({});
3388 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
3389 const auto& pd = hlo->padding_config().dimensions(i);
3390 WindowDimension* dim = window.add_dimensions();
3391 dim->set_size(1);
3392 dim->set_stride(1);
3393 dim->set_window_dilation(1);
3394 dim->set_window_reversal(false);
3395 dim->set_padding_low(pd.edge_padding_low());
3396 dim->set_padding_high(pd.edge_padding_high());
3397 dim->set_base_dilation(pd.interior_padding() + 1);
3398 const int64_t shard_count = hlo->sharding().tile_assignment().dim(i);
3399 // Masking is needed only if the padding value is non-zero or the operand is
3400 // unevenly partitioned, since halo exchange fills 0 in the collective-permute
3401 // result for non-destination cores.
3402 needs_masking |=
3403 shard_count > 1 &&
3404 (pd.edge_padding_low() > 0 || pd.edge_padding_high() > 0 ||
3405 pd.interior_padding() > 0) &&
3406 (!pad_value_is_zero ||
3407 hlo->operand(0)->shape().dimensions(i) % shard_count != 0);
3408 }
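// For example (illustrative): padding a dimension of size 10 that is split
// 4 ways, or padding with a non-zero value, requires masking because the
// halo exchange writes zeros into regions that should hold the pad value;
// zero padding on an evenly partitioned dimension needs no masking.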
3409
3410 auto replicated_rhs = GetPartitionedHlo(hlo->operand(1))
3411 .Reshard(HloSharding::Replicate())
3412 .hlo();
3413 auto reshard_operand =
3414 lhs.ReshardAsWindowedInput(window, hlo->sharding(), replicated_rhs,
3415 /*mask_invalid_region=*/needs_masking);
3416 if (!reshard_operand.has_value()) {
3417 return DefaultAction(hlo);
3418 }
3419 PaddingConfig sharded_padding_config;
3420 bool need_pad = false;
3421 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
3422 auto dim = sharded_padding_config.add_dimensions();
3423 const auto& wd = reshard_operand->shard_window.dimensions(i);
3424 dim->set_edge_padding_low(wd.padding_low());
3425 dim->set_edge_padding_high(wd.padding_high());
3426 dim->set_interior_padding(wd.base_dilation() - 1);
3427 if (wd.padding_low() != 0 || wd.padding_high() != 0 ||
3428 wd.base_dilation() != 1) {
3429 need_pad = true;
3430 }
3431 }
3432 auto sharded_pad = reshard_operand->sharded_input;
3433 if (need_pad) {
3434 TF_ASSIGN_OR_RETURN(auto sharded_pad_shape,
3435 ShapeInference::InferPadShape(sharded_pad->shape(),
3436 replicated_rhs->shape(),
3437 sharded_padding_config));
3438 sharded_pad = b_.AddInstruction(hlo->CreatePad(sharded_pad_shape,
3439 sharded_pad, replicated_rhs,
3440 sharded_padding_config));
3441 }
3442
3443 SetPartitionedHlo(hlo, [&]() {
3444 if (!reshard_operand->dynamic_slice_index_on_output) {
3445 return sharded_pad;
3446 }
3447 auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
3448 return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
3449 shard_shape, sharded_pad,
3450 *reshard_operand->dynamic_slice_index_on_output,
3451 shard_shape.dimensions()));
3452 });
3453 return OkStatus();
3454 }
3455
3456 Status SpmdPartitioningVisitor::HandleParameter(HloInstruction* hlo) {
3457 SetPartitionedHlo(hlo, [&]() {
3458 auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
3459 auto new_param = b_.AddInstruction(HloInstruction::CreateParameter(
3460 hlo->parameter_number(), shard_shape, "param"));
3461 if (hlo->parameter_replicated_at_leaf_buffers()) {
3462 new_param->set_parameter_replicated_at_leaf_buffers(
3463 *hlo->parameter_replicated_at_leaf_buffers());
3464 }
3465 return new_param;
3466 });
3467 return OkStatus();
3468 }
3469
3470 Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) {
3471 if (hlo->sharding().HasUniqueDevice()) {
3472 return DefaultAction(hlo);
3473 }
3474 int64_t input_count = 1;
3475 auto per_input_sharding = hlo->sharding();
3476 if (hlo->shape().IsTuple()) {
3477 input_count = hlo->shape().tuple_shapes_size();
3478 CHECK_GT(input_count, 0);
3479 per_input_sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
3480 }
3481
3482 std::vector<PartitionedHlo> inputs;
3483 std::vector<HloInstruction*> inits;
3484 std::vector<int64_t> preserved_dims;
3485 for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
3486 if (!absl::c_linear_search(hlo->dimensions(), i)) {
3487 preserved_dims.push_back(i);
3488 }
3489 }
3490
3491 for (int64_t operand_id = 0; operand_id < input_count; ++operand_id) {
3492 inits.push_back(GetPartitionedHlo(hlo->operand(operand_id + input_count))
3493 .Reshard(HloSharding::Replicate())
3494 .hlo());
3495 inputs.push_back(GetPartitionedHlo(hlo->operand(operand_id)));
3496 if (operand_id > 0) {
3497 // Make sure all operands are sharded in the same way.
3498 inputs.back() = inputs.back().Reshard(inputs[0].sharding());
3499 }
3500 if (!inputs[0].sharding().IsTileMaximal()) {
3501 inputs.back() =
3502 inputs.back().PadWithValue(inits[operand_id], /*left_padded_dims=*/{},
3503 /*skipped_dims=*/preserved_dims);
3504 }
3505 }
3506
3507 std::vector<const Shape*> new_operand_shapes(input_count * 2);
3508 for (int64_t i = 0; i < input_count; ++i) {
3509 new_operand_shapes[i] = &inputs[i].hlo()->shape();
3510 new_operand_shapes[i + input_count] = &inits[i]->shape();
3511 }
3512 // Create the shard shape of the reduce result.
3513 TF_ASSIGN_OR_RETURN(
3514 auto reduce_shape,
3515 ShapeInference::InferReduceShape(new_operand_shapes, hlo->dimensions(),
3516 hlo->to_apply()->ComputeProgramShape()));
3517
3518 std::vector<HloInstruction*> input_hlos(input_count);
3519 for (int64_t i = 0; i < input_count; ++i) {
3520 input_hlos[i] = inputs[i].hlo();
3521 }
3522 auto local_reduce = b_.AddInstruction(HloInstruction::CreateReduce(
3523 reduce_shape, input_hlos, inits, hlo->dimensions(), hlo->to_apply()));
3524
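// The local reduce above only combines elements owned by this partition.
// When a reduced dimension is sharded, the lambda below also combines the
// per-partition partial results: with an all-reduce along the sharding dims
// for array-shaped results, or by replicating the partials per group and
// reducing them again for tuple-shaped results.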
3525 SetPartitionedHlo(hlo, [&]() {
3526 HloInstruction* reduce = local_reduce;
3527 const bool reduce_sharded_dimension =
3528 !inputs[0].sharding().IsTileMaximal() &&
3529 absl::c_any_of(hlo->dimensions(), [&](int64_t i) {
3530 return inputs[0].sharding().tile_assignment().dim(i) > 1;
3531 });
3532 if (reduce_sharded_dimension) {
3533 if (inputs[0].sharding().ReplicateOnLastTileDim()) {
3534 preserved_dims.push_back(inputs[0].base_shape().rank());
3535 }
3536 if (local_reduce->shape().IsArray()) {
3537 reduce = partitioner_->AllReduceAlongShardingDims(
3538 &b_, local_reduce, inputs[0].sharding(), next_channel_id_,
3539 hlo->dimensions(), collective_ops_creator_, hlo->to_apply());
3540 } else {
3541 auto grouped = hlo_sharding_util::GroupShardingOnDims(
3542 inputs[0].sharding(), preserved_dims);
3543 auto grouped_state = CreatePerGroupPartitioningState(
3544 inputs[0].state(), grouped.device_groups, &b_);
3545 std::vector<HloInstruction*> all_gathered_partial_results(input_count);
3546 for (int64_t i = 0; i < input_count; ++i) {
3547 auto gte = b_.AddInstruction(HloInstruction::CreateGetTupleElement(
3548 ShapeUtil::GetTupleElementShape(reduce_shape, i), local_reduce,
3549 i));
3550 auto expanded_shape = input_hlos[i]->shape();
3551 auto all_gather_shape = input_hlos[i]->shape();
3552 for (int64_t dim : hlo->dimensions()) {
3553 expanded_shape.set_dimensions(dim, 1);
3554 all_gather_shape.set_dimensions(
3555 dim, inputs[0].sharding().tile_assignment().dim(dim));
3556 }
3557 auto reshape = b_.AddInstruction(
3558 HloInstruction::CreateReshape(expanded_shape, gte));
3559 // Replicate per group.
3560 reshape->set_sharding(grouped.sharding);
3561 all_gathered_partial_results[i] =
3562 PartitionedHlo(reshape, all_gather_shape, grouped_state)
3563 .Replicate()
3564 .hlo();
3565 }
3566 reduce = b_.AddInstruction(HloInstruction::CreateReduce(
3567 reduce_shape, all_gathered_partial_results, inits,
3568 hlo->dimensions(), hlo->to_apply()));
3569 }
3570 }
3571 auto sharding = hlo_sharding_util::RemoveShapeDimensions(
3572 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
3573 inputs[0].sharding(), hlo->dimensions()),
3574 hlo->dimensions());
3575 if (local_reduce->shape().IsArray()) {
3576 reduce->set_sharding(sharding);
3577 } else {
3578 reduce->set_sharding(HloSharding::Tuple(
3579 reduce->shape(), std::vector<HloSharding>(input_count, sharding)));
3580 }
3581 return PartitionedHlo(reduce, hlo->shape(), MakePartitioningState())
3582 .Reshard(hlo->sharding())
3583 .hlo();
3584 });
3585 return OkStatus();
3586 }
3587
3588 Status SpmdPartitioningVisitor::HandleReverse(HloInstruction* hlo) {
3589 auto reverse = Cast<HloReverseInstruction>(hlo);
3590 if (reverse->sharding().IsTileMaximal()) {
3591 return DefaultAction(hlo);
3592 }
3593 auto operand = GetPartitionedHlo(reverse->operand(0))
3594 .Reshard(hlo_sharding_util::ReverseSharding(
3595 reverse->sharding(), reverse->dimensions()));
3596 auto left_padded_operand =
3597 HaloExchangeToPadOnLeft(operand, reverse->dimensions());
3598 if (!left_padded_operand) {
3599 return DefaultAction(hlo);
3600 }
3601 SetPartitionedHlo(hlo, [&] {
3602 return b_.AddInstruction(hlo->CloneWithNewOperands(
3603 left_padded_operand->shape(), {left_padded_operand}));
3604 });
3605 return OkStatus();
3606 }
3607
3608 Status SpmdPartitioningVisitor::HandleWhile(HloInstruction* hlo) {
3609 const HloSharding& sharding = hlo->sharding();
3610
3611 // Shardings for the body parameter, body root, and cond parameter must be
3612 // the same, and the condition root must be replicated so that all partitions
3613 // follow the same control flow.
3614 hlo->while_condition()->parameter_instruction(0)->set_sharding(sharding);
3615 hlo->while_body()->parameter_instruction(0)->set_sharding(sharding);
3616 const HloSharding& cond_root_sharding =
3617 hlo->while_condition()->root_instruction()->sharding();
3618 TF_RETURN_IF_ERROR(partitioner_
3619 ->PartitionComputation(hlo->while_condition(),
3620 cond_root_sharding.IsManual()
3621 ? cond_root_sharding
3622 : HloSharding::Replicate(),
3623 next_channel_id_, logger_)
3624 .status());
3625 TF_RETURN_IF_ERROR(partitioner_
3626 ->PartitionComputation(hlo->while_body(), sharding,
3627 next_channel_id_, logger_)
3628 .status());
3629 SetPartitionedHlo(hlo, [&] {
3630 return b_.AddInstruction(HloInstruction::CreateWhile(
3631 MakePartitionedShape(hlo->shape(), sharding), hlo->while_condition(),
3632 hlo->while_body(),
3633 GetPartitionedHlo(hlo->operand(0)).Reshard(sharding).hlo()));
3634 });
3635 return OkStatus();
3636 }
3637
3638 Status SpmdPartitioningVisitor::HandleConditional(HloInstruction* hlo) {
3639 std::vector<HloInstruction*> branch_args;
3640 for (int64_t i = 0; i < hlo->branch_count(); ++i) {
3641 HloComputation* computation = hlo->branch_computation(i);
3642
3643 // Shardings of the branch computation parameter and its argument must be
3644 // the same.
3645 computation->parameter_instruction(0)->set_sharding(
3646 hlo->operand(i + 1)->sharding());
3647 branch_args.push_back(GetPartitionedHlo(hlo->operand(i + 1)).hlo());
3648 }
3649
3650 // The root of the branch computations must follow the sharding of the
3651 // conditional instruction.
3652 for (int64_t i = 0; i < hlo->branch_count(); ++i) {
3653 HloComputation* computation = hlo->branch_computation(i);
3654 TF_RETURN_IF_ERROR(partitioner_
3655 ->PartitionComputation(computation, hlo->sharding(),
3656 next_channel_id_, logger_)
3657 .status());
3658 }
3659 SetPartitionedHlo(hlo, [&] {
3660 HloInstruction* cond = GetPartitionedHlo(hlo->operand(0)).hlo();
3661 if (!hlo->operand(0)->sharding().IsManual()) {
3662 // We replicate the predicate of the conditional (the first operand) so
3663 // that all partitions follow the same control flow.
3664 cond = GetPartitionedHlo(hlo->operand(0))
3665 .Reshard(HloSharding::Replicate())
3666 .hlo();
3667 }
3668 return b_.AddInstruction(HloInstruction::CreateConditional(
3669 MakePartitionedShape(hlo->shape(), hlo->sharding()), cond,
3670 hlo->called_computations(), branch_args));
3671 });
3672 return OkStatus();
3673 }
3674
3675 Status SpmdPartitioningVisitor::HandleOptimizationBarrier(HloInstruction* hlo) {
3676 return HandleElementwise(hlo);
3677 }
3678
3679 Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) {
3680 if (hlo->sharding().HasUniqueDevice()) {
3681 return HandleSingleDevice(hlo);
3682 }
3683
3684 const auto& sharding = hlo->sharding();
3685 const Shape& shape = hlo->operand(0)->shape();
3686 auto partitioned_operand =
3687 GetPartitionedHlo(hlo->operand(0)).Reshard(sharding);
3688 const auto& shard_shape = partitioned_operand.hlo()->shape();
3689 const auto& operand = partitioned_operand.hlo();
3690 auto token = GetPartitionedHlo(hlo->operand(1)).hlo();
3691
3692 if (EvenlyPartitions(shape, sharding)) {
3693 Shape outfeed_shape = operand->shape();
3694 TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(hlo->outfeed_shape(),
3695 &outfeed_shape));
3696 SetPartitionedHlo(hlo, [&]() {
3697 return b_.AddInstruction(HloInstruction::CreateOutfeed(
3698 outfeed_shape, operand, token, hlo->outfeed_config()));
3699 });
3700 return OkStatus();
3701 }
3702
3703 // Create a branch for each unique partitioned shape.
3704 std::vector<Shape> per_branch_partitioned_shapes;
3705 std::vector<int32_t> conditional_branch_indices(num_partitions_);
3706 for (int64_t i = 0; i < num_partitions_; ++i) {
3707 auto partitioned_shape =
3708 MakeNonPaddedShapeForGivenPartition(shape, sharding, i);
3709 int64_t matching_existing_index = 0;
3710 for (; matching_existing_index < per_branch_partitioned_shapes.size();
3711 ++matching_existing_index) {
3712 if (ShapeUtil::Compatible(
3713 partitioned_shape,
3714 per_branch_partitioned_shapes[matching_existing_index])) {
3715 break;
3716 }
3717 }
3718 if (matching_existing_index < per_branch_partitioned_shapes.size()) {
3719 conditional_branch_indices[i] = matching_existing_index;
3720 } else {
3721 conditional_branch_indices[i] = per_branch_partitioned_shapes.size();
3722 per_branch_partitioned_shapes.push_back(std::move(partitioned_shape));
3723 }
3724 }
3725
3726 // Get branch index for this partition.
3727 HloInstruction* branch_index;
3728 auto state = MakePartitioningState();
3729 if (per_branch_partitioned_shapes.size() == num_partitions_) {
3730 // Use partition ID as the branch index if each partition has its own
3731 // branch.
3732 branch_index = state.partition_id;
3733 // PartitionId's output is U32 but conditional requires S32.
3734 if (branch_index->shape().element_type() != S32) {
3735 branch_index = b_.AddInstruction(HloInstruction::CreateConvert(
3736 ShapeUtil::ChangeElementType(branch_index->shape(), S32),
3737 branch_index));
3738 }
3739 } else {
3740 // Otherwise, use a constant table to look up the branch index.
3741 auto branch_index_table = b_.AddInstruction(HloInstruction::CreateConstant(
3742 LiteralUtil::CreateR1<int32_t>(conditional_branch_indices)));
3743 branch_index = b_.AddInstruction(HloInstruction::CreateDynamicSlice(
3744 ShapeUtil::MakeShape(S32, {1}), branch_index_table, {partition_id_},
3745 {1}));
3746 branch_index = b_.AddInstruction(HloInstruction::CreateReshape(
3747 ShapeUtil::MakeShape(S32, {}), branch_index));
3748 }
3749
3750 // Create conditional for the outfeed.
3751 std::vector<HloComputation*> branches(per_branch_partitioned_shapes.size());
3752 for (int64_t i = 0; i < branches.size(); ++i) {
3753 SpmdBuilder branch_b(absl::StrCat("outfeed_branch_", i), visiting_hlo_);
3754 // Create tuple param within the branch.
3755 auto param = branch_b.AddInstruction(HloInstruction::CreateParameter(
3756 /*parameter_number=*/0,
3757 ShapeUtil::MakeTupleShape({operand->shape(), token->shape()}),
3758 "outfeed_token_param"));
3759 auto outfeed_data = branch_b.AddInstruction(
3760 HloInstruction::CreateGetTupleElement(operand->shape(), param, 0));
3761 auto outfeed_token = branch_b.AddInstruction(
3762 HloInstruction::CreateGetTupleElement(token->shape(), param, 1));
3763 if (!ShapeUtil::Compatible(per_branch_partitioned_shapes[i], shard_shape)) {
3764 std::function<HloInstruction*(const ShapeIndex&, HloInstruction*)>
3765 slice_outfeed =
3766 [&](const ShapeIndex& index,
3767 HloInstruction* outfeed_operand) -> HloInstruction* {
3768 // Get outfeed element shape.
3769 const Shape& element_shape =
3770 ShapeUtil::GetSubshape(outfeed_data->shape(), index);
3771 // Recursively call slice_outfeed for tuple shapes.
3772 if (element_shape.IsTuple() && element_shape.tuple_shapes_size() > 0) {
3773 std::vector<HloInstruction*> slice_elements(
3774 element_shape.tuple_shapes_size());
3775 for (int64_t i = 0; i < slice_elements.size(); ++i) {
3776 auto sub_index = index;
3777 sub_index.push_back(i);
3778 slice_elements[i] = slice_outfeed(
3779 sub_index,
3780 branch_b.AddInstruction(HloInstruction::CreateGetTupleElement(
3781 ShapeUtil::GetSubshape(element_shape, {i}), outfeed_operand,
3782 i)));
3783 }
3784 return branch_b.AddInstruction(
3785 HloInstruction::CreateTuple(slice_elements));
3786 }
3787 // Get the slice shape.
3788 const Shape& slice_shape = ShapeUtil::GetSubshape(
3789 per_branch_partitioned_shapes[i], ShapeIndexView(index));
3790 if (ShapeUtil::Compatible(element_shape, slice_shape)) {
3791 return outfeed_operand;
3792 }
3793 // Slice out useful data.
3794 if (element_shape.IsArray()) {
3795 CHECK(slice_shape.IsArray());
3796 std::vector<int64_t> start_indices(slice_shape.rank(), 0);
3797 std::vector<int64_t> slice_strides(slice_shape.rank(), 1);
3798 return branch_b.AddInstruction(HloInstruction::CreateSlice(
3799 slice_shape, outfeed_operand, start_indices,
3800 slice_shape.dimensions(), slice_strides));
3801 }
3802 CHECK(element_shape.IsTuple());
3803 CHECK(element_shape.tuple_shapes().empty());
3804 return outfeed_operand;
3805 };
3806 outfeed_data = slice_outfeed({}, outfeed_data);
3807 }
3808 TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
3809 hlo->outfeed_shape(), &per_branch_partitioned_shapes[i]));
3810 branch_b.AddInstruction(HloInstruction::CreateOutfeed(
3811 per_branch_partitioned_shapes[i], outfeed_data, outfeed_token,
3812 hlo->outfeed_config()));
3813 branches[i] = module_->AddEmbeddedComputation(branch_b.Build());
3814 }
3815 SetPartitionedHlo(hlo, [&]() {
3816 return b_.AddInstruction(HloInstruction::CreateConditional(
3817 token->shape(), branch_index, branches,
3818 std::vector<HloInstruction*>(
3819 branches.size(),
3820 b_.AddInstruction(HloInstruction::CreateTuple({operand, token})))));
3821 });
3822 return OkStatus();
3823 }
3824
3825 Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) {
3826 if (hlo->sharding().HasUniqueDevice()) {
3827 return HandleSingleDevice(hlo);
3828 }
3829 auto clone_from_original = [&](const HloSharding& shared_sharding) {
3830 std::vector<HloInstruction*> new_operands;
3831 for (int64_t i = 0; i < hlo->operand_count(); ++i) {
3832 new_operands.push_back(
3833 GetPartitionedHlo(hlo->operand(i)).Reshard(shared_sharding).hlo());
3834 }
3835 auto clone = b_.AddInstruction(
3836 hlo->CloneWithNewOperands(hlo->shape(), new_operands));
3837 clone->set_sharding(shared_sharding);
3838 return clone;
3839 };
3840
3841 if (hlo->sharding().IsManual()) {
3842 SetPartitionedHlo(hlo,
3843 [&] { return clone_from_original(hlo->sharding()); });
3844 return OkStatus();
3845 }
3846
3847 if (hlo->sharding().IsReplicated()) {
3848 SetPartitionedHlo(hlo, [&] {
3849 // Run on a single device (0) and distribute the data to all other cores.
3850 auto clone = clone_from_original(HloSharding::AssignDevice(0));
3851 return PartitionedHlo(clone, hlo->shape(), MakePartitioningState())
3852 .Reshard(HloSharding::Replicate())
3853 .hlo();
3854 });
3855 return OkStatus();
3856 }
3857
3858 TF_RET_CHECK(!hlo->sharding().IsTileMaximal());
3859 // Replicate the operands and run partitioned Rng on all devices.
3860 std::vector<HloInstruction*> new_operands;
3861 for (int64_t i = 0; i < hlo->operand_count(); ++i) {
3862 new_operands.push_back(GetPartitionedHlo(hlo->operand(i))
3863 .Reshard(HloSharding::Replicate())
3864 .hlo());
3865 }
3866
3867 if (!hlo->sharding().ReplicateOnLastTileDim()) {
3868 SetPartitionedHlo(hlo, [&] {
3869 return b_.AddInstruction(HloInstruction::CreateRng(
3870 MakePartitionedShape(hlo->shape(), hlo->sharding()),
3871 hlo->random_distribution(), new_operands));
3872 });
3873 } else {
3874 std::vector<int64_t> group_dims(
3875 hlo->sharding().tile_assignment().num_dimensions() - 1);
3876 std::iota(group_dims.begin(), group_dims.end(), 0);
3877 auto sharding_grouped =
3878 hlo_sharding_util::GroupShardingOnDims(hlo->sharding(), group_dims);
3879 auto per_group_state = CreatePerGroupPartitioningState(
3880 MakePartitioningState(), sharding_grouped.device_groups, &b_);
3881 auto rng = b_.AddInstruction(HloInstruction::CreateRng(
3882 MakePartitionedShape(hlo->shape(), hlo->sharding()),
3883 hlo->random_distribution(), new_operands));
3884 rng->set_sharding(HloSharding::AssignDevice(0));
3885 SetPartitionedHlo(hlo, [&]() {
3886 return PartitionedHlo(rng, rng->shape(), per_group_state)
3887 .Replicate()
3888 .hlo();
3889 });
3890 }
3891 return OkStatus();
3892 }
3893
3894 Status SpmdPartitioningVisitor::HandleReduceWindow(HloInstruction* hlo) {
3895 if (hlo->sharding().IsTileMaximal()) {
3896 return DefaultAction(hlo);
3897 }
3898 HloReduceWindowInstruction* reduce_window =
3899 Cast<HloReduceWindowInstruction>(hlo);
3900 absl::Span<HloInstruction* const> input_arrays = reduce_window->inputs();
3901 absl::Span<HloInstruction* const> init_values = reduce_window->init_values();
3902 int64_t input_idx = 0;
3903 absl::InlinedVector<PartitionedHlo::WindowedInputShardReturnValue, 2>
3904 sharded_results;
3905 absl::InlinedVector<const Shape*, 2> sharded_input_shapes,
3906 replicated_init_shapes;
3907 absl::InlinedVector<HloInstruction*, 2> sharded_inputs, replicated_inits;
3908 for (const HloInstruction* input_array : input_arrays) {
3909 PartitionedHlo& operand = GetPartitionedHlo(input_array);
3910 // Replicate init
3911 PartitionedHlo replicated_init = GetPartitionedHlo(init_values[input_idx++])
3912 .Reshard(HloSharding::Replicate());
3913 auto resharded_operand_and_window = operand.ReshardAsWindowedInput(
3914 hlo->window(), hlo->sharding(), replicated_init.hlo());
3915 if (!resharded_operand_and_window.has_value()) {
3916 return DefaultAction(hlo);
3917 }
3918 sharded_results.push_back(resharded_operand_and_window.value());
3919 sharded_inputs.push_back(resharded_operand_and_window->sharded_input);
3920 sharded_input_shapes.push_back(&sharded_inputs.back()->shape());
3921 replicated_inits.push_back(replicated_init.hlo());
3922 replicated_init_shapes.push_back(&replicated_inits.back()->shape());
3923 }
3924 TF_ASSIGN_OR_RETURN(Shape sharded_rw_shape,
3925 ShapeInference::InferReduceWindowShape(
3926 sharded_input_shapes, replicated_init_shapes,
3927 sharded_results[0].shard_window,
3928 hlo->to_apply()->ComputeProgramShape()));
3929 HloSharding result_sharding =
3930 (hlo->shape().IsTuple())
3931 ? hlo->sharding().GetTupleSharding(hlo->shape()).ValueOrDie()
3932 : hlo->sharding();
3933 Shape shard_shape = MakePartitionedShape(hlo->shape(), result_sharding);
3934 if (shard_shape.has_layout()) {
3935 *sharded_rw_shape.mutable_layout() = shard_shape.layout();
3936 }
3937 SetPartitionedHlo(hlo, [&]() {
3938 HloInstruction* sharded_rw =
3939 b_.AddInstruction(HloInstruction::CreateReduceWindow(
3940 sharded_rw_shape, sharded_inputs, replicated_inits,
3941 sharded_results[0].shard_window, hlo->to_apply()));
3942 if (!sharded_results[0].dynamic_slice_index_on_output.has_value()) {
3943 CHECK(ShapeUtil::Compatible(shard_shape, sharded_rw->shape()))
3944 << shard_shape << " vs " << sharded_rw->shape() << "\n";
3945 return sharded_rw;
3946 }
3947 return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
3948 shard_shape, sharded_rw,
3949 *sharded_results[0].dynamic_slice_index_on_output,
3950 shard_shape.dimensions()));
3951 });
3952 return OkStatus();
3953 }
3954
3955 Status SpmdPartitioningVisitor::HandleSelectAndScatter(HloInstruction* hlo) {
3956 if (hlo->sharding().IsTileMaximal()) {
3957 return DefaultAction(hlo);
3958 }
3959 auto operand = GetPartitionedHlo(hlo->operand(0));
3960 auto source = GetPartitionedHlo(hlo->mutable_operand(1));
3961 if (hlo->sharding() != operand.sharding()) {
3962 operand = operand.Reshard(hlo->sharding());
3963 }
3964 if (hlo->sharding() != source.sharding()) {
3965 source = source.Reshard(hlo->sharding());
3966 }
3967
3968 // For F32 and BF16 types, we can use NaN padding to work around the issue
3969 // with low/high padding, since comparison will return false with NaN input.
3970 if (hlo->shape().element_type() != F32 &&
3971 hlo->shape().element_type() != BF16) {
3972 return DefaultAction(hlo);
3973 }
3974
3975 auto select = hlo->called_computations()[0];
3976 auto select_root = select->root_instruction();
3977 if (select_root->opcode() != HloOpcode::kCompare ||
3978 select_root->operand(0)->opcode() != HloOpcode::kParameter ||
3979 select_root->operand(1)->opcode() != HloOpcode::kParameter ||
3980 select_root->operand(0)->parameter_number() +
3981 select_root->operand(1)->parameter_number() !=
3982 1) {
3983 return DefaultAction(hlo);
3984 }
3985
3986 float float_pad_value;
3987 if (select_root->comparison_direction() == ComparisonDirection::kGe ||
3988 select_root->comparison_direction() == ComparisonDirection::kGt) {
3989 if (select_root->operand(0)->parameter_number() == 0) {
3990 float_pad_value = -std::numeric_limits<float>::infinity();
3991 } else {
3992 float_pad_value = std::numeric_limits<float>::infinity();
3993 }
3994 } else if (select_root->comparison_direction() == ComparisonDirection::kLe ||
3995 select_root->comparison_direction() == ComparisonDirection::kLt) {
3996 if (select_root->operand(0)->parameter_number() == 0) {
3997 float_pad_value = std::numeric_limits<float>::infinity();
3998 } else {
3999 float_pad_value = -std::numeric_limits<float>::infinity();
4000 }
4001 } else {
4002 return DefaultAction(hlo);
4003 }
4004
4005 auto pad_value = b_.AddInstruction(HloInstruction::CreateConstant(
4006 hlo->shape().element_type() == BF16
4007 ? LiteralUtil::CreateR0<bfloat16>(
4008 static_cast<bfloat16>(float_pad_value))
4009 : LiteralUtil::CreateR0<float>(float_pad_value)));
4010
4011 // Replicate init
4012 auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(2))
4013 .Reshard(HloSharding::Replicate());
4014
4015 auto state = MakePartitioningState();
4016 auto partition_ordinals =
4017 MakeTiledPartitionOrdinals(hlo->sharding(), state.partition_id, &b_);
4018
4019 // The first window for each dimension that overlaps with the shard area.
4020 std::vector<MultiplyAddDivideOffsetCalculation> first_window(
4021 hlo->shape().rank());
4022 // The first window for each dimension that goes beyond the shard area.
4023 std::vector<MultiplyAddDivideOffsetCalculation> limit_window(
4024 hlo->shape().rank());
4025 std::vector<OffsetCalculation> data_left_halo_sizes(hlo->shape().rank());
4026 std::vector<OffsetCalculation> data_right_halo_sizes(hlo->shape().rank());
4027 std::vector<OffsetCalculation> source_left_halo_sizes(hlo->shape().rank());
4028 std::vector<OffsetCalculation> source_right_halo_sizes(hlo->shape().rank());
4029 auto unpadded_data_shard_shape =
4030 MakePartitionedShape(hlo->shape(), hlo->sharding());
4031 auto unpadded_source_shard_shape =
4032 MakePartitionedShape(hlo->operand(1)->shape(), hlo->sharding());
4033 auto source_shard_hlo = source.hlo();
4034 auto data_shard_hlo = operand.hlo();
4035 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
4036 int64_t shard_count = hlo->sharding().tile_assignment().dim(i);
4037 if (shard_count == 1) {
4038 continue;
4039 }
4040 // If stride > window_size, there will be gaps between windows. These gaps
4041 // will also exist in the output, so we keep them during halo exchange.
4042 //
4043 // TODO(yuanzx): This could introduce overhead if partitions start at
4044 // different offsets in a gap.
4045 auto wd = hlo->window().dimensions(i);
4046 if (wd.stride() > wd.size()) {
4047 wd.set_size(wd.stride());
4048 }
4049 // shard_size * i < stride * k - pad_low + window_size =>
4050 // k > (shard_size * i + pad_low - window_size) / stride =>
4051 // first_k == (shard_size * i + pad_low - window_size + stride) / stride
4052 first_window[i] = MultiplyAddDivideOffsetCalculation(
4053 unpadded_data_shard_shape.dimensions(i),
4054 wd.padding_low() - wd.size() + wd.stride(), wd.stride());
4055 // shard_size * (i + 1) <= stride * k - pad_low =>
4056 // k >= (shard_size * i + shard_size + pad_low) / stride =>
4057 // limit_k == (shard_size * i + shard_size + pad_low + stride - 1) /
4058 // stride
4059 limit_window[i] = MultiplyAddDivideOffsetCalculation(
4060 unpadded_data_shard_shape.dimensions(i),
4061 unpadded_data_shard_shape.dimensions(i) + wd.padding_low() +
4062 wd.stride() - 1,
4063 wd.stride());
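// Illustrative example (hypothetical numbers): with shard_size = 8, window
// size = 3, stride = 2 and pad_low = 1, window k covers [2k - 1, 2k + 2), so
// for shard p:
//   first_window = (8p + 1 - 3 + 2) / 2 = 4p
//   limit_window = (8p + 8 + 1 + 2 - 1) / 2 = 4p + 5
// e.g. shard 1 starts at window 4, and shard 0 ends before window 5.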
4064 source_left_halo_sizes[i] =
4065 MultiplyAddDivideOffsetCalculation(
4066 unpadded_source_shard_shape.dimensions(i), 0, 1) -
4067 first_window[i];
4068 source_right_halo_sizes[i] =
4069 limit_window[i] - MultiplyAddDivideOffsetCalculation(
4070 unpadded_source_shard_shape.dimensions(i),
4071 unpadded_source_shard_shape.dimensions(i), 1);
4072 data_left_halo_sizes[i] =
4073 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
4074 unpadded_data_shard_shape.dimensions(i), wd.padding_low(), 1)) -
4075 OffsetCalculation(
4076 HloOpcode::kMultiply, first_window[i],
4077 MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1));
4078 data_right_halo_sizes[i] =
4079 OffsetCalculation(
4080 HloOpcode::kMultiply, limit_window[i],
4081 MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1)) -
4082 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
4083 unpadded_data_shard_shape.dimensions(i),
4084 unpadded_data_shard_shape.dimensions(i) + wd.stride() +
4085 wd.padding_low() - wd.size(),
4086 1));
4087
4088 int64_t max_windows =
4089 (limit_window[i] - first_window[i]).MaxInRange(0, shard_count);
4090 auto first_window_hlo =
4091 first_window[i].Calculate(partition_ordinals[i], &b_);
4092 // Padding on the source is filled with the init value so it does not change
4093 // the data in overlapping windows.
4094 auto resharded_source = ExchangeHaloAndGetValidData(
4095 source_shard_hlo, source.base_shape(), source_left_halo_sizes[i],
4096 source_right_halo_sizes[i], 0,
4097 limit_window[i].Calculate(shard_count - 1), max_windows, i,
4098 hlo->sharding(), first_window_hlo, replicated_init.hlo(),
4099 partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_);
4100 if (!resharded_source) {
4101 return DefaultAction(hlo);
4102 }
4103 source_shard_hlo = *resharded_source;
4104
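// In padded-data coordinates, this shard's data starts where its first window
// starts (first_window * stride) and spans
// (max_windows - 1) * stride + window_size elements.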
4105 auto offset_start_in_data =
4106 MultiplyAddDivideOffsetCalculation(wd.stride(), 0, 1)
4107 .Calculate(first_window_hlo, &b_);
4108 int64_t padded_data_size =
4109 (limit_window[i].Calculate(shard_count - 1) - 1) * wd.stride() +
4110 wd.size();
4111 int64_t data_shard_size = (max_windows - 1) * wd.stride() + wd.size();
4112 auto resharded_data = ExchangeHaloAndGetValidData(
4113 data_shard_hlo, operand.base_shape(), data_left_halo_sizes[i],
4114 data_right_halo_sizes[i], wd.padding_low(), padded_data_size,
4115 data_shard_size, i, hlo->sharding(), offset_start_in_data, pad_value,
4116 partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_);
4117 if (!resharded_data) {
4118 return DefaultAction(hlo);
4119 }
4120 data_shard_hlo = *resharded_data;
4121 }
4122
4123 Window window_on_shard = hlo->window();
4124 for (int64_t i = 0; i < window_on_shard.dimensions_size(); ++i) {
4125 int64_t shard_count = hlo->sharding().tile_assignment().dim(i);
4126 if (shard_count == 1) {
4127 continue;
4128 }
4129 auto reshard_wd = window_on_shard.mutable_dimensions(i);
4130 // The shards are already explicitly padded.
4131 reshard_wd->set_padding_low(0);
4132 reshard_wd->set_padding_high(0);
4133 }
4134
4135 auto sharded_select_and_scatter =
4136 b_.AddInstruction(HloInstruction::CreateSelectAndScatter(
4137 data_shard_hlo->shape(), data_shard_hlo, select, window_on_shard,
4138 source_shard_hlo, replicated_init.hlo(),
4139 hlo->called_computations()[1]));
4140 SetPartitionedHlo(hlo, [&]() {
4141 auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
4142 if (ShapeUtil::Compatible(sharded_select_and_scatter->shape(),
4143 shard_shape)) {
4144 return sharded_select_and_scatter;
4145 }
4146 auto zero = b_.AddInstruction(
4147 HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
4148 std::vector<HloInstruction*> slice_offsets(shard_shape.rank(), zero);
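// For each halo-exchanged dimension, slice this shard's region out of the
// sharded result: the offset is the shard's left halo size, except on shard 0
// where it is the original low padding (the two coincide when the computed
// left halo on shard 0 equals pad_low).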
4149 for (int64_t i = 0; i < window_on_shard.dimensions_size(); ++i) {
4150 if (hlo->sharding().tile_assignment().dim(i) == 1) {
4151 continue;
4152 }
4153 int64_t pad_low = hlo->window().dimensions(i).padding_low();
4154 auto left_halo_size =
4155 data_left_halo_sizes[i].Calculate(partition_ordinals[i], &b_);
4156 if (data_left_halo_sizes[i].Calculate(0) == pad_low) {
4157 slice_offsets[i] = left_halo_size;
4158 } else {
4159 auto is_shard0 = b_.AddInstruction(HloInstruction::CreateCompare(
4160 ShapeUtil::MakeShape(PRED, {}), zero, partition_ordinals[i],
4161 ComparisonDirection::kEq));
4162 auto pad_low_hlo = b_.AddInstruction(HloInstruction::CreateConstant(
4163 LiteralUtil::CreateR0<int32_t>(pad_low)));
4164 slice_offsets[i] = b_.AddInstruction(HloInstruction::CreateTernary(
4165 zero->shape(), HloOpcode::kSelect, is_shard0, pad_low_hlo,
4166 left_halo_size));
4167 }
4168 }
4169 return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
4170 shard_shape, sharded_select_and_scatter, slice_offsets,
4171 shard_shape.dimensions()));
4172 });
4173 return OkStatus();
4174 }
4175
4176 Status SpmdPartitioningVisitor::HandleTuple(HloInstruction* hlo) {
4177 std::vector<HloInstruction*> new_operands;
4178 for (int64_t i = 0; i < hlo->operand_count(); ++i) {
4179 new_operands.push_back(
4180 GetPartitionedHlo(hlo->operand(i))
4181 .Reshard(hlo->sharding().GetSubSharding(hlo->shape(), {i}))
4182 .hlo());
4183 }
4184 SetPartitionedHlo(hlo, [&]() {
4185 return b_.AddInstruction(HloInstruction::CreateTuple(new_operands));
4186 });
4187 return OkStatus();
4188 }
4189
4190 StatusOr<bool> SpmdPartitioningVisitor::DoPartition(
4191 HloComputation* computation, const HloSharding& root_sharding,
4192 const SpmdPartitionerOptions& options) {
4193 VLOG(2) << "Partitioning computation " << computation->name() << " for "
4194 << num_replicas_ << " replicas and " << num_partitions_
4195 << " partitions";
4196 TF_RETURN_IF_ERROR(computation->Accept(this));
4197
4198 HloModule* module = computation->parent();
4199 auto new_root =
4200 GetPartitionedHlo(computation->root_instruction()).Reshard(root_sharding);
4201 auto new_computation =
4202 module->AddEmbeddedComputation(b_.Build(new_root.hlo()));
4203 TF_RETURN_IF_ERROR(
4204 DoCodeMotionForWindowedDotGeneralLoops(new_computation, options));
4205
4206 // Replace the original computation with the new SPMD computation.
4207 absl::flat_hash_map<HloComputation*, HloComputation*> replacement;
4208 replacement[computation] = new_computation;
4209 module->ReplaceComputations(replacement);
4210 return changed_;
4211 }
4212
4213 Status SpmdPartitioningVisitor::HandlePartitionId(HloInstruction* hlo) {
4214 return Unimplemented(
4215 "PartitionId instruction is not supported for SPMD partitioning since "
4216 "the meaning is ambiguous -- whether the instruction is replicated or "
4217 "the data is replicated, and if the latter which data is replicated.");
4218 }
4219
4220 SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64_t num_partitions,
4221 int64_t num_replicas) {
4222 return {
4223 [](SpmdBuilder* b) {
4224 return b->AddInstruction(HloInstruction::CreatePartitionId());
4225 },
4226 [num_replicas, num_partitions](
4227 SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction,
4228 const std::vector<std::vector<int64_t>>& partition_subgroups,
4229 int64_t channel_id) {
4230 if (partition_subgroups.size() <= 1) {
4231 std::vector<ReplicaGroup> groups(num_replicas);
4232 // TODO(yuanzx): Unify subgroup definition with AllToAll.
4233 for (int64_t i = 0; i < num_replicas; ++i) {
4234 groups[i].add_replica_ids(i);
4235 }
4236 return b->AddInstruction(HloInstruction::CreateAllReduce(
4237 operand->shape(), {operand}, reduction, groups,
4238 /*constrain_layout=*/false, channel_id,
4239 /*use_global_device_ids=*/false));
4240 }
4241
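// Build groups over global device ids, where (replica r, partition p) maps to
// device id r * num_partitions + p. For example, with 2 replicas and 4
// partitions, the partition subgroup {0, 1} expands to groups {0, 1} and
// {4, 5}.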
4242 std::vector<ReplicaGroup> device_groups;
4243 device_groups.reserve(partition_subgroups.size() * num_replicas);
4244 for (int64_t i = 0; i < num_replicas; ++i) {
4245 for (const auto& pgroup : partition_subgroups) {
4246 device_groups.emplace_back();
4247 for (int64_t pid : pgroup) {
4248 device_groups.back().add_replica_ids(i * num_partitions + pid);
4249 }
4250 }
4251 }
4252 return b->AddInstruction(HloInstruction::CreateAllReduce(
4253 operand->shape(), {operand}, reduction, device_groups,
4254 /*constrain_layout=*/false, channel_id,
4255 /*use_global_device_ids=*/true));
4256 },
4257 [num_partitions](SpmdBuilder* b, HloInstruction* operand,
4258 std::vector<std::pair<int64_t, int64_t>>& src_dst_pairs,
4259 int64_t channel_id) {
4260 // Optimize trivial collective permutes.
4261 if (src_dst_pairs.empty()) {
4262 // If the src/dst pairs are empty, then the collective permute just
4263 // initializes the output to zero.
4264 return CreateZero(operand->shape(), b);
4265 } else {
4266 // A collective-permute is a copy if all pairs are "identity" and
4267 // all partitions are listed.
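// For example, with 4 partitions, {{0,0}, {1,1}, {2,2}, {3,3}} is such an
// identity permutation.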
4268 bool is_copy =
4269 src_dst_pairs.size() == num_partitions &&
4270 absl::c_all_of(src_dst_pairs,
4271 [](const std::pair<int64_t, int64_t>& pair) {
4272 return pair.first == pair.second;
4273 });
4274 if (is_copy) {
4275 return operand;
4276 } else {
4277 return b->AddInstruction(HloInstruction::CreateCollectivePermute(
4278 operand->shape(), operand, src_dst_pairs, channel_id));
4279 }
4280 }
4281 },
4282 [](SpmdBuilder* b, absl::Span<HloInstruction* const> operands,
4283 const std::vector<std::vector<int64_t>>& partition_subgroups,
4284 int64_t channel_id, std::optional<int64_t> split_dimension) {
4285 std::vector<Shape> shapes(operands.size(), operands[0]->shape());
4286 const Shape output_shape = (shapes.size() == 1)
4287 ? shapes[0]
4288 : ShapeUtil::MakeTupleShape(shapes);
4289 std::vector<ReplicaGroup> groups(partition_subgroups.size());
4290 for (int64_t i = 0; i < groups.size(); ++i) {
4291 for (int64_t id : partition_subgroups[i]) {
4292 groups[i].add_replica_ids(id);
4293 }
4294 }
4295 return b->AddInstruction(HloInstruction::CreateAllToAll(
4296 output_shape, operands, groups,
4297 /*constrain_layout=*/false, channel_id, split_dimension));
4298 },
4299 [num_replicas, num_partitions](
4300 SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape,
4301 const std::vector<std::vector<int64_t>>& partition_subgroups,
4302 int64_t channel_id, int64_t all_gather_dimension) {
4303 std::vector<ReplicaGroup> device_groups;
4304 device_groups.reserve(partition_subgroups.size() * num_replicas);
4305 for (int64_t i = 0; i < num_replicas; ++i) {
4306 for (const auto& pgroup : partition_subgroups) {
4307 device_groups.emplace_back();
4308 for (int64_t pid : pgroup) {
4309 device_groups.back().add_replica_ids(i * num_partitions + pid);
4310 }
4311 }
4312 }
4313 return b->AddInstruction(HloInstruction::CreateAllGather(
4314 ag_shape, {operand}, all_gather_dimension, device_groups,
4315 /*constrain_layout=*/false, channel_id,
4316 /*use_global_device_ids=*/true));
4317 },
4318 };
4319 }
4320
4321 SpmdPartitioner::SpmdPartitioner(int64_t num_partitions, int64_t num_replicas,
4322 SpmdPartitionerOptions options)
4323 : SpmdPartitioner(
4324 num_partitions, num_replicas, std::move(options),
4325 GetDefaultCollectiveOpsCreator(num_partitions, num_replicas)) {}
4326
4327 HloInstruction* SpmdPartitioner::AllGatherShards(
4328 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
4329 int64_t* next_channel_id, absl::Span<const int64_t> selected_dims,
4330 const SPMDCollectiveOpsCreator& collectives_creator) {
4331 return AllGatherShardsInternal(b, operand, sharding, next_channel_id,
4332 selected_dims, collectives_creator,
4333 /*per_dim_ag=*/true);
4334 }
4335
4336 HloInstruction* SpmdPartitioner::AllGatherShardsInternal(
4337 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
4338 int64_t* next_channel_id, absl::Span<const int64_t> selected_dims,
4339 const SPMDCollectiveOpsCreator& collectives_creator, bool per_dim_ag) {
4340 if (selected_dims.empty()) {
4341 return operand;
4342 }
4343 CHECK(!sharding.IsTileMaximal());
4344 if (per_dim_ag || selected_dims.size() == 1) {
4345 HloInstruction* result = operand;
4346 Shape result_shape = operand->shape();
4347 for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
4348 if (sharding.tile_assignment().dim(*it) == 1) {
4349 continue;
4350 }
4351 auto partition_subgroups =
4352 GetPartitionGroupsForReplication(sharding, {*it});
4353 result_shape.set_dimensions(
4354 *it, result_shape.dimensions(*it) * partition_subgroups[0].size());
4355 result = collectives_creator.create_cross_partition_all_gather(
4356 b, result, result_shape, partition_subgroups, (*next_channel_id)++,
4357 /*all_gather_dimension=*/*it);
4358 }
4359 return result;
4360 }
4361 std::vector<int64_t> shape;
4362 shape.push_back(1);
4363 for (int64_t dim : operand->shape().dimensions()) {
4364 shape.push_back(dim);
4365 }
4366 // Add one leading dimension to gather all partitions.
4367 auto reshape = b->AddInstruction(HloInstruction::CreateReshape(
4368 ShapeUtil::MakeShape(operand->shape().element_type(), shape), operand));
4369 HloInstruction* result = reshape;
4370 auto partition_subgroups =
4371 GetPartitionGroupsForReplication(sharding, selected_dims);
4372 shape[0] *= partition_subgroups[0].size();
4373 result = collectives_creator.create_cross_partition_all_gather(
4374 b, result, ShapeUtil::MakeShape(operand->shape().element_type(), shape),
4375 partition_subgroups, (*next_channel_id)++,
4376 /*all_gather_dimension=*/0);
4377 // If n > 1 dimensions are partitioned, split the leading dimension into n dimensions.
4378 std::vector<int64_t> tiled_dims;
4379 for (int64_t i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
4380 if (sharding.tile_assignment().dim(i) > 1 &&
4381 absl::c_linear_search(selected_dims, i)) {
4382 tiled_dims.push_back(i);
4383 }
4384 }
4385 if (tiled_dims.size() > 1) {
4386 std::vector<int64_t> split_dim_shape;
4387 split_dim_shape.reserve(tiled_dims.size() + operand->shape().rank());
4388 for (int64_t i : tiled_dims) {
4389 split_dim_shape.push_back(sharding.tile_assignment().dim(i));
4390 }
4391 for (int64_t dim : operand->shape().dimensions()) {
4392 split_dim_shape.push_back(dim);
4393 }
4394 result = b->AddInstruction(HloInstruction::CreateReshape(
4395 ShapeUtil::MakeShape(operand->shape().element_type(), split_dim_shape),
4396 result));
4397 }
4398 // Transpose the gathered dimensions so that each is placed next to its
4399 // corresponding partitioned dimension.
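// Illustrative example (hypothetical shapes): if an [A, B] shard is tiled
// 2-way on both selected dimensions, the all-gather above produced [4, A, B],
// the split produced [2, 2, A, B], and the transpose below uses permutation
// {0, 2, 1, 3} to get [2, A, 2, B], which the final reshape turns into
// [2A, 2B].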
4400 std::vector<int64_t> xpose_permutation(result->shape().rank());
4401 int64_t split_dims_added = 0;
4402 for (int64_t i = 0; i < xpose_permutation.size(); ++i) {
4403 if (sharding.tile_assignment().dim(i - split_dims_added) == 1 ||
4404 !absl::c_linear_search(selected_dims, i - split_dims_added)) {
4405 xpose_permutation[i] = i + tiled_dims.size() - split_dims_added;
4406 } else {
4407 xpose_permutation[i] = split_dims_added;
4408 xpose_permutation[i + 1] = i + tiled_dims.size() - split_dims_added;
4409 split_dims_added++;
4410 i++;
4411 }
4412 }
4413 result = b->AddInstruction(HloInstruction::CreateTranspose(
4414 ShapeInference::InferTransposeShape(result->shape(), xpose_permutation)
4415 .ValueOrDie(),
4416 result, xpose_permutation));
4417 // Reshape to the desired shape.
4418 auto ag_shape = operand->shape();
4419 for (int64_t i : tiled_dims) {
4420 ag_shape.set_dimensions(
4421 i, ag_shape.dimensions(i) * sharding.tile_assignment().dim(i));
4422 }
4423 result = b->AddInstruction(HloInstruction::CreateReshape(ag_shape, result));
4424 return result;
4425 }
4426
4427 HloInstruction* SpmdPartitioner::AllReduceAlongShardingDims(
4428 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
4429 int64_t* next_channel_id, absl::Span<const int64_t> selected_dims,
4430 const SPMDCollectiveOpsCreator& collectives_creator,
4431 HloComputation* reduction) {
4432 return AllReduceAlongShardingDimsInternal(
4433 b, operand, sharding, next_channel_id, selected_dims, collectives_creator,
4434 reduction, /*per_dim_ar=*/true);
4435 }
4436
4437 HloInstruction* SpmdPartitioner::AllReduceAlongShardingDimsInternal(
4438 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
4439 int64_t* next_channel_id, absl::Span<const int64_t> selected_dims,
4440 const SPMDCollectiveOpsCreator& collectives_creator,
4441 HloComputation* reduction, bool per_dim_ar) {
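// When per_dim_ar is false, issue a single all-reduce over the full
// replication group; otherwise issue one all-reduce per selected dimension
// with smaller groups.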
4442 if (!per_dim_ar) {
4443 auto partition_subgroups =
4444 GetPartitionGroupsForReplication(sharding, selected_dims);
4445 return collectives_creator.create_cross_partition_all_reduce(
4446 b, operand, reduction, partition_subgroups, (*next_channel_id)++);
4447 }
4448 auto result = operand;
4449 for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
4450 if (sharding.tile_assignment().dim(*it) == 1) {
4451 continue;
4452 }
4453 auto partition_subgroups =
4454 GetPartitionGroupsForReplication(sharding, {*it});
4455 result = collectives_creator.create_cross_partition_all_reduce(
4456 b, result, reduction, partition_subgroups, (*next_channel_id)++);
4457 }
4458 return result;
4459 }
4460
4461 StatusOr<bool> SpmdPartitioner::PartitionComputation(
4462 HloComputation* computation, const HloSharding& root_sharding,
4463 int64_t* next_channel_id, SpmdLogger* logger) {
4464 auto visitor =
4465 CreateVisitor(computation, num_partitions_, num_replicas_,
4466 collective_ops_creator_, next_channel_id, logger, options_);
4467 return visitor->DoPartition(computation, root_sharding, options_);
4468 }
4469
4470 std::unique_ptr<SpmdPartitioningVisitor> SpmdPartitioner::CreateVisitor(
4471 HloComputation* computation, int64_t num_partitions, int64_t num_replicas,
4472 const SPMDCollectiveOpsCreator& collective_ops_creator,
4473 int64_t* next_channel_id, SpmdLogger* logger,
4474 SpmdPartitionerOptions options) {
4475 return std::make_unique<SpmdPartitioningVisitor>(
4476 computation, num_partitions, num_replicas, collective_ops_creator,
4477 next_channel_id, logger, std::move(options), this);
4478 }
4479
4480 StatusOr<bool> SpmdPartitioner::Run(
4481 HloModule* module,
4482 const absl::flat_hash_set<absl::string_view>& execution_threads) {
4483 TF_RETURN_IF_ERROR(PreprocessSharding(module, execution_threads));
4484 TF_RETURN_IF_ERROR(PreprocessHlos(module, execution_threads));
4485
4486 XLA_VLOG_LINES(1, SpmdLogger::ReportBeforePartition(
4487 *module, options_.report_instruction_count));
4488
4489 // Add the parameters' and output's shardings to the module.
4490 std::vector<HloSharding> entry_params_shardings;
4491 const auto num_parameters = module->entry_computation()->num_parameters();
4492 entry_params_shardings.reserve(num_parameters);
4493 for (int64_t i = 0; i < num_parameters; ++i) {
4494 auto param = module->entry_computation()->parameter_instruction(i);
4495 CHECK(param->has_sharding()) << "Missing sharding in entry parameter " << i;
4496 entry_params_shardings.push_back(param->sharding());
4497 }
4498 module->set_spmd_parameters_shardings(entry_params_shardings);
4499 auto entry_root = module->entry_computation()->root_instruction();
4500 CHECK(entry_root->has_sharding()) << "Missing sharding in entry root.";
4501 module->set_spmd_output_sharding(entry_root->sharding());
4502
4503 FlattenCallGraph flatten;
4504 TF_ASSIGN_OR_RETURN(auto changed, flatten.Run(module));
4505
4506 SpmdLogger logger(options_.report_instruction_count,
4507 /*disabled=*/!VLOG_IS_ON(1));
4508 auto program_shape = module->entry_computation()->ComputeProgramShape();
4509 int64_t next_channel_id = hlo_query::NextChannelId(*module);
4510 // Copy the root sharding since the partitioner visitor may temporarily change
4511 // the sharding to work around manual sharding.
4512 HloSharding root_sharding = entry_root->sharding();
4513 TF_ASSIGN_OR_RETURN(
4514 bool partition_changed,
4515 PartitionComputation(module->entry_computation(), root_sharding,
4516 &next_channel_id, &logger));
4517 changed |= partition_changed;
4518
4519 // For the entry computation, make sure that the root instruction and the
4520 // parameters preserve their original shapes.
4521 auto new_program_shape = module->entry_computation()->ComputeProgramShape();
4522 if (!options_.allow_module_signature_change) {
4523 TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
4524 program_shape.result(), new_program_shape.result()))
4525 << "Result shape changed for the entry computation";
4526 TF_RET_CHECK(program_shape.parameters_size() ==
4527 new_program_shape.parameters_size())
4528 << "Parameter count changed for the entry computation";
4529 for (int64_t i = 0; i < program_shape.parameters_size(); ++i) {
4530 TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
4531 program_shape.parameters(i), new_program_shape.parameters(i)))
4532 << "Parameter shape changed for the entry computation";
4533 }
4534 } else {
4535 // Fix up invalid tiling in the entry computation layout.
4536 auto update_shape = [this](Shape* subshape, const xla::ShapeIndex& index) {
4537 if (subshape->IsArray()) {
4538 UpdateLayout(subshape);
4539 }
4540 };
4541 const auto& old_entry_layout = module->entry_computation_layout();
4542 // Shapes can change, but the layouts should remain the same.
4543 for (int64_t i = 0; i < new_program_shape.parameters_size(); ++i) {
4544 TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
4545 old_entry_layout.parameter_shape(i),
4546 new_program_shape.mutable_parameters(i)));
4547 ShapeUtil::ForEachMutableSubshape(new_program_shape.mutable_parameters(i),
4548 update_shape);
4549 }
4550 TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
4551 old_entry_layout.result_shape(), new_program_shape.mutable_result()));
4552 ShapeUtil::ForEachMutableSubshape(new_program_shape.mutable_result(),
4553 update_shape);
4554
4555 HloModuleConfig config = module->config();
4556 *config.mutable_entry_computation_layout() =
4557 ComputationLayout(new_program_shape, /*ignore_layouts=*/false);
4558 module->set_config(config);
4559 }
4560
4561 XLA_VLOG_LINES(1, SpmdLogger::ReportAfterPartition(
4562 *module, options_.report_instruction_count));
4563 XLA_VLOG_LINES(1, logger.MakeReport());
4564
4565 if (changed) {
4566 HloPassPipeline pass("spmd-cleanup");
4567 pass.AddPass<HloDCE>(/*remove_cross_partition_collective_ops=*/true);
4568 pass.AddPass<TupleSimplifier>();
4569 pass.AddPass<HloDCE>(/*remove_cross_partition_collective_ops=*/true);
4570 pass.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
4571 pass.AddPass<FlattenCallGraph>();
4572 TF_RETURN_IF_ERROR(pass.Run(module, execution_threads).status());
4573 }
4574
4575 TF_RETURN_IF_ERROR(ClearShardingAttributes(module, execution_threads));
4576 return changed;
4577 }
4578
4579 Status SpmdPartitioner::PreprocessSharding(
4580 HloModule* module,
4581 const absl::flat_hash_set<absl::string_view>& execution_threads) {
4582 for (HloComputation* computation : module->computations(execution_threads)) {
4583 for (HloInstruction* hlo : computation->instructions()) {
4584 if (hlo->HasSideEffectNoRecurse() && hlo->opcode() != HloOpcode::kRng) {
4585 TF_RET_CHECK(hlo->has_sharding())
4586 << "Side-effect HLO must have sharding: " << hlo->ToString();
4587 TF_RET_CHECK(!HasReplicatedSharding(hlo->sharding()) ||
4588 CanSideEffectingHaveReplicatedSharding(hlo))
4589 << "side-effect HLO cannot have a replicated sharding: "
4590 << hlo->ToString();
4591 }
4592
4593 // For unassigned HLOs, annotate with replicated sharding.
4594 //
4595 // Among side-effecting ops, only Rng is allowed to omit the annotation.
4596 // In that case, we currently force it to run on core 0, since we don't
4597 // support partitioning or replicating the Rng op (the values depend on
4598 // the seed provided to each device).
4599 //
4600 // TODO(hyouklee): Should we also convert single-device shardings (without
4601 // side-effects) into replicated?
4602 if (!hlo->has_sharding()) {
4603 if (hlo->opcode() == HloOpcode::kRng) {
4604 hlo->set_sharding(HloSharding::AssignDevice(0));
4605 } else {
4606 hlo->set_sharding(
4607 HloSharding::Single(hlo->shape(), HloSharding::Replicate()));
4608 }
4609 } else if (!hlo->sharding().IsTileMaximal() &&
4610 !hlo->sharding().IsManual()) {
4611 std::vector<int64_t> available(num_partitions_);
4612 std::iota(available.begin(), available.end(), 0);
4613 TF_RET_CHECK(num_partitions_ == hlo_sharding_util::DevicesForSharding(
4614 hlo->sharding(), available)
4615 .size())
4616 << "num_partitions:" << num_partitions_ << "\n"
4617 << "SPMD partitioner only supports tile sharding that includes all "
4618 "partitions. If you didn't add this sharding annotation in the "
4619 "model, please file a bug to XLA team.\n"
4620 << hlo->ToString();
4621 }
4622 }
4623 }
4624
4625 // Entry computation's parameter and root sharding must be either all
4626 // replicated or all on a single device.
4627 if (!options_.allow_module_signature_change) {
4628 const HloComputation* entry = module->entry_computation();
4629 TF_RET_CHECK(entry->root_instruction()->has_sharding());
4630 const HloSharding& root_sharding = entry->root_instruction()->sharding();
4631 if (!root_sharding.UniqueDevice().has_value()) {
4632 if (root_sharding.IsTuple()) {
4633 TF_RET_CHECK(absl::c_all_of(root_sharding.tuple_elements(),
4634 [](const HloSharding& s) {
4635 return s.IsReplicated() || s.IsManual();
4636 }))
4637 << "Unsupported entry root sharding: " << root_sharding.ToString();
4638
4639 } else {
4640 TF_RET_CHECK(root_sharding.IsReplicated() || root_sharding.IsManual())
4641 << "Unsupported entry root sharding: " << root_sharding.ToString();
4642 }
4643 }
4644
4645 for (const HloInstruction* param : entry->parameter_instructions()) {
4646 TF_RET_CHECK(param->has_sharding());
4647 TF_RET_CHECK(param->sharding().IsReplicated() ||
4648 param->sharding().UniqueDevice().has_value())
4649 << "Unsupported entry parameter sharding:"
4650 << param->sharding().ToString();
4651 }
4652 }
4653
4654 return OkStatus();
4655 }
4656
4657 Status SpmdPartitioner::PreprocessHlos(
4658 HloModule* module,
4659 const absl::flat_hash_set<absl::string_view>& execution_threads) {
4660 auto skip_copy_operands = [](HloInstruction* operand,
4661 bool check_single_use =
4662 true) -> HloInstruction* {
4663 while (operand->user_count() == 1 &&
4664 operand->opcode() == HloOpcode::kCopy) {
4665 operand = operand->mutable_operand(0);
4666 }
4667 if (check_single_use && operand->user_count() != 1) {
4668 return nullptr;
4669 }
4670 return operand;
4671 };
4672
4673 for (HloComputation* computation : module->computations(execution_threads)) {
4674 for (HloInstruction* hlo : computation->MakeInstructionPostOrder()) {
4675 if (hlo->sharding().IsTileMaximal() || hlo->sharding().IsManual()) {
4676 // No need to optimize for tile-maximal or manual sharding.
4677 continue;
4678 }
4679
4680 if (hlo->opcode() == HloOpcode::kSlice) {
4681 HloInstruction* operand = skip_copy_operands(hlo->mutable_operand(0));
4682 if (operand == nullptr || operand->sharding() != hlo->sharding()) {
4683 continue;
4684 }
4685
4686 // Merge pad->slice to avoid multiple halo exchanges.
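// For example, pad(edge_padding_low=3) followed by slice(start=2, stride=1)
// merges into a single pad(edge_padding_low=1) on that dimension.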
4687 if (operand->opcode() == HloOpcode::kPad) {
4688 std::optional<PaddingConfig> merged_padding =
4689 operand->padding_config();
4690 bool may_have_multi_halo_exchanges = false;
4691 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
4692 const auto& dim = operand->padding_config().dimensions(i);
4693 if (dim.interior_padding() != 0 || hlo->slice_strides(i) != 1) {
4694 merged_padding = std::nullopt;
4695 break;
4696 }
4697 if (hlo->sharding().tile_assignment().dim(i) != 1 &&
4698 (dim.edge_padding_low() != 0 || dim.edge_padding_high() != 0) &&
4699 hlo->shape().dimensions(i) != operand->shape().dimensions(i)) {
4700 // There are padding, slicing, and sharding on this dim.
4701 may_have_multi_halo_exchanges = true;
4702 }
4703
4704 auto* merged_dim = merged_padding->mutable_dimensions(i);
4705 merged_dim->set_edge_padding_low(dim.edge_padding_low() -
4706 hlo->slice_starts(i));
4707 merged_dim->set_edge_padding_high(hlo->slice_limits(i) -
4708 operand->shape().dimensions(i));
4709 }
4710 if (merged_padding.has_value() && may_have_multi_halo_exchanges) {
4711 // Rewrite to a single Pad.
4712 HloInstruction* new_pad =
4713 computation->AddInstruction(HloInstruction::CreatePad(
4714 hlo->shape(), operand->mutable_operand(0),
4715 operand->mutable_operand(1), *merged_padding));
4716 new_pad->set_metadata(operand->metadata());
4717 new_pad->set_sharding(hlo->sharding());
4718 TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_pad));
4719 TF_RETURN_IF_ERROR(
4720 computation->RemoveInstructionAndUnusedOperands(hlo));
4721 }
4722 }
4723 }
4724 if (hlo->opcode() == HloOpcode::kConcatenate) {
4725 const int64_t dim = hlo->concatenate_dimension();
4726 if (hlo->sharding().tile_assignment().dim(dim) == 1) {
4727 continue;
4728 }
4729 if (hlo->operand_count() == 2) {
4730 // Find a pattern of "rotate right on one dimension":
4731 // concat(slice(input), slice(input)).
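// For example, concat(x[n-k:n], x[0:n-k]) along dim rotates x right by k.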
4732 HloInstruction* lhs = skip_copy_operands(hlo->mutable_operand(0));
4733 HloInstruction* rhs = skip_copy_operands(hlo->mutable_operand(1));
4734 if (lhs == nullptr || rhs == nullptr) {
4735 continue;
4736 }
4737 const int64_t amount = FindRotateRightPattern(hlo, lhs, rhs);
4738 if (amount < 0) {
4739 continue;
4740 }
4741 HloInstruction* to_rotate = lhs->mutable_operand(0);
4742 HloInstruction* rotate = computation->AddInstruction(
4743 CreateCustomCallSPMDInternal_RotateRight(to_rotate, dim, amount));
4744 rotate->set_metadata(hlo->metadata());
4745 rotate->set_sharding(hlo->sharding());
4746 TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(rotate));
4747 TF_RETURN_IF_ERROR(
4748 computation->RemoveInstructionAndUnusedOperands(hlo));
4749 } else if (hlo->operand_count() == 3) {
4750 // Find the pattern for "pad with wrap": concat(slice(x), x, slice(x)),
4751 // where all involved values have the same sharding.
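// For example, with x of size 5 along dim, concat(x[3:5], x, x[0:3]) is x
// circularly padded by 2 elements on the low side and 3 on the high side.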
4752 HloInstruction* lhs = skip_copy_operands(hlo->mutable_operand(0));
4753 HloInstruction* mid = skip_copy_operands(hlo->mutable_operand(1),
4754 /*check_single_use=*/false);
4755 HloInstruction* rhs = skip_copy_operands(hlo->mutable_operand(2));
4756 std::optional<PadWithWrapPattern> pad_pattern =
4757 FindPadWithWrapPattern(hlo, lhs, mid, rhs);
4758 if (!pad_pattern) {
4759 continue;
4760 }
4761
4762 // Since concatenate requires all operands to have the same size along the
4763 // non-concat dimensions, the lhs/rhs operands must be slicing along the
4764 // concat dimension.
4765
4766 // Step 1: Pad the mid operand to the final size. The low padding is the
4767 // size of the lhs shape, and the high padding is the size of the rhs shape.
4768 PaddingConfig padding_config =
4769 MakeNoPaddingConfig(hlo->shape().rank());
4770 auto* padding_config_dim = padding_config.mutable_dimensions(dim);
4771 const int64_t low_pad = lhs->shape().dimensions(dim);
4772 const int64_t high_pad = rhs->shape().dimensions(dim);
4773 padding_config_dim->set_edge_padding_low(low_pad);
4774 padding_config_dim->set_edge_padding_high(high_pad);
4775 HloInstruction* zero =
4776 computation->AddInstruction(HloInstruction::CreateConstant(
4777 LiteralUtil::Zero(hlo->shape().element_type())));
4778 zero->set_sharding(HloSharding::Replicate());
4779 HloInstruction* pad =
4780 computation->AddInstruction(HloInstruction::CreatePad(
4781 hlo->shape(), mid, zero, padding_config));
4782 pad->set_metadata(hlo->metadata());
4783 pad->set_sharding(hlo->sharding());
4784
4785 // Step 2: rotate the padded value so that the lhs slice aligns with the
4786 // low end of the padded shape.
4787 // padded_operand = low_pad | mid | high_pad.
4788 // slice_start in padded_operand = lhs->slice_start + low_pad.
4789 // Rotate left by (lhs->slice_start + low_pad)
4790 // i.e., rotate right = padded_size - (lhs_slice_start + low_pad).
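// Continuing the example above (x of size 5, low_pad = 2 from x[3:5],
// high_pad = 3 from x[0:3], padded_size = 10): rotate_lhs_amount =
// 10 - (3 + 2) = 5, which moves x[3], x[4] from padded positions 5, 6 to
// positions 0, 1.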
4791 const int64_t padded_size = hlo->shape().dimensions(dim);
4792 const int64_t rotate_lhs_amount =
4793 padded_size - (pad_pattern->lhs_slice_start + low_pad);
4794 HloInstruction* rotate_lhs = computation->AddInstruction(
4795 CreateCustomCallSPMDInternal_RotateRight(pad, dim,
4796 rotate_lhs_amount));
4797 rotate_lhs->set_metadata(hlo->metadata());
4798 rotate_lhs->set_sharding(hlo->sharding());
4799
4800 auto apply_modifiers =
4801 [&](HloInstruction* inst,
4802 const std::vector<const HloInstruction*>& modifiers) {
4803 // Apply the modifiers in the reverse order.
4804 for (auto it = modifiers.crbegin(), end = modifiers.crend();
4805 it != end; ++it) {
4806 const HloInstruction* modifier = *it;
4807 // The new shape has the same element type as the modifier, but the
4808 // dimensions of inst.
4809 Shape new_shape = ShapeUtil::ChangeElementType(
4810 inst->shape(), modifier->shape().element_type());
4811 inst = computation->AddInstruction(
4812 modifier->CloneWithNewOperands(new_shape, {inst}));
4813 }
4814 return inst;
4815 };
4816 rotate_lhs = apply_modifiers(rotate_lhs, pad_pattern->lhs_modifiers);
4817
4818 // Step 3: rotate the padded value so that the rhs slice aligns with the
4819 // high end of the padded shape.
4820 // padded_operand = low_pad | mid | high_pad.
4821 // slice_start in padded_operand = rhs->slice_start + low_pad.
4822 // slice_end in padded_operand = rhs->slice_start + low_pad +
4823 // high_pad; Rotate right by padded_size - (rhs->slice_start +
4824 // low_pad + high_pad)
4825 const int64_t rotate_rhs_amount =
4826 padded_size - (pad_pattern->rhs_slice_start + low_pad + high_pad);
4827 HloInstruction* rotate_rhs = computation->AddInstruction(
4828 CreateCustomCallSPMDInternal_RotateRight(pad, dim,
4829 rotate_rhs_amount));
4830 rotate_rhs->set_metadata(hlo->metadata());
4831 rotate_rhs->set_sharding(hlo->sharding());
4832 rotate_rhs = apply_modifiers(rotate_rhs, pad_pattern->rhs_modifiers);
4833
4834 // Now merge the 3 results using appropriate selects.
4835 const Shape iota_shape =
4836 ShapeUtil::ChangeElementType(hlo->shape(), U32);
4837 HloInstruction* iota = computation->AddInstruction(
4838 HloInstruction::CreateIota(iota_shape, dim));
4839 iota->set_metadata(hlo->metadata());
4840 iota->set_sharding(hlo->sharding());
4841
4842 struct SelectSpec {
4843 int64_t limit;
4844 HloInstruction* hlo;
4845 Comparison::Direction cmp;
4846 };
4847 const std::array<SelectSpec, 2> selects = {
4848 {// All elements < low_pad come from rotate_lhs.
4849 {low_pad, rotate_lhs, Comparison::Direction::kLt},
4850 // All elements >= padded_size - high_pad come from rotate_rhs
4851 {padded_size - high_pad, rotate_rhs,
4852 Comparison::Direction::kGe}}};
4853
4854 Shape pred_shape = ShapeUtil::ChangeElementType(hlo->shape(), PRED);
4855
4856 HloInstruction* merged = pad;
4857 for (const SelectSpec& select_spec : selects) {
4858 HloInstruction* limit =
4859 computation->AddInstruction(HloInstruction::CreateConstant(
4860 LiteralUtil::CreateR0<uint32_t>(select_spec.limit)));
4861 limit->set_sharding(HloSharding::Replicate());
4862 HloInstruction* limit_bcast = computation->AddInstruction(
4863 HloInstruction::CreateBroadcast(iota_shape, limit, {}));
4864 limit_bcast->set_metadata(hlo->metadata());
4865 limit_bcast->set_sharding(hlo->sharding());
4866 HloInstruction* compare =
4867 computation->AddInstruction(HloInstruction::CreateCompare(
4868 pred_shape, iota, limit_bcast, select_spec.cmp));
4869 compare->set_metadata(hlo->metadata());
4870 compare->set_sharding(hlo->sharding());
4871 merged = computation->AddInstruction(HloInstruction::CreateTernary(
4872 hlo->shape(), HloOpcode::kSelect, compare, select_spec.hlo,
4873 merged));
4874 merged->set_metadata(hlo->metadata());
4875 merged->set_sharding(hlo->sharding());
4876 }
4877
4878 TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(merged));
4879 TF_RETURN_IF_ERROR(
4880 computation->RemoveInstructionAndUnusedOperands(hlo));
4881 }
4882 }
4883 }
4884 }
4885 return OkStatus();
4886 }
4887
4888 } // namespace spmd
4889 } // namespace xla
4890