xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/custom_call_handler.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace spmd {

namespace {
using hlo_sharding_util::GroupedSharding;
}  // namespace

std::string SpmdLogger::MakeReport() {
  std::string report;
  absl::StrAppend(&report,
                  "\n\n***** SPMD memory during transformation *****\n");

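  // Report the largest entries first.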
  std::sort(entries_.begin(), entries_.end(),
            [](auto const& entry0, auto const& entry1) {
              return entry0.first > entry1.first;
            });
  for (int64_t i = 0;
       i < std::min<int64_t>(report_instruction_count_, entries_.size()); ++i) {
    absl::StrAppend(
        &report, "\n  ",
        tensorflow::strings::HumanReadableNumBytes(entries_[i].first), " : ",
        entries_[i].second, "\n");
  }

  return report;
}

void SpmdLogger::RegisterLogEntry(HloInstruction* hlo,
                                  const std::vector<HloInstruction*>& group) {
  if (disabled_) {
    return;
  }
  std::string report = hlo->ToString();
  int64_t max_value = -1;
  for (HloInstruction* inst : group) {
    if (!inst->shape().IsArray()) {
      continue;
    }
    max_value = std::max<int64_t>(max_value, ShapeSizeInBytes(inst->shape()));
    absl::StrAppend(&report, "     * ", inst->ToString(), "\n");
  }
  entries_.push_back(std::make_pair(max_value, report));
}

/* static */ std::string SpmdLogger::ReportBeforePartition(
    const HloModule& module, int64_t report_instruction_count) {
  std::string report;
  absl::StrAppend(&report,
                  "\n\n***** SPMD memory usage before partition *****\n");
  absl::StrAppend(&report, "\n  ** Replicated instructions\n");
  absl::StrAppend(&report, ReportMemoryUsage(
                               module,
                               [](const HloInstruction* hlo) {
                                 return !hlo->has_sharding() ||
                                        hlo->sharding().IsReplicated();
                               },
                               report_instruction_count));
  absl::StrAppend(&report, "\n  ** All instructions\n");
  absl::StrAppend(&report,
                  ReportMemoryUsage(
                      module, [](const HloInstruction* hlo) { return true; },
                      report_instruction_count));
  return report;
}

/* static */ std::string SpmdLogger::ReportAfterPartition(
    const HloModule& module, int64_t report_instruction_count) {
  std::string report;
  absl::StrAppend(&report,
                  "\n\n***** SPMD memory usage after partition *****\n");
  absl::StrAppend(&report,
                  ReportMemoryUsage(
                      module, [](const HloInstruction* hlo) { return true; },
                      report_instruction_count));
  return report;
}

template <typename F>
/* static */ std::string SpmdLogger::ReportMemoryUsage(
    const HloModule& module, const F& filter,
    int64_t report_instruction_count) {
  std::string report;
  std::vector<HloInstruction*> instructions;
  instructions.reserve(module.instruction_count());

  for (auto computation : module.computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    for (auto hlo : computation->instructions()) {
      if (!hlo->shape().IsArray() ||
          ShapeUtil::IsEffectiveScalar(hlo->shape())) {
        continue;
      }
      if (filter(hlo)) {
        instructions.push_back(hlo);
      }
    }
  }

  const auto add_report = [&](std::vector<HloInstruction*>* insts) {
    std::sort(insts->begin(), insts->end(),
              [](const HloInstruction* inst0, const HloInstruction* inst1) {
                return ShapeSizeInBytes(inst0->shape()) >
                       ShapeSizeInBytes(inst1->shape());
              });
    for (int64_t i = 0;
         i < std::min<int64_t>(report_instruction_count, insts->size()); ++i) {
      absl::StrAppend(&report, "  ",
                      tensorflow::strings::HumanReadableNumBytes(
                          ShapeSizeInBytes((*insts)[i]->shape())),
                      " : ", (*insts)[i]->ToString(), "\n");
    }
  };

  add_report(&instructions);
  return report;
}

namespace {

// Clears all sharding attributes from instructions in the module. This must be
// called only after all SPMD transformations are complete.
Status ClearShardingAttributes(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  for (HloComputation* computation : module->computations(execution_threads)) {
    for (HloInstruction* hlo : computation->instructions()) {
      // Keep sharding annotation on Infeed and entry parameters since they're
      // used by HloReplicationAnalysis later (for ArCrsCombiner).
      if (hlo->HasSideEffect()) {
        continue;
      }
      if (hlo->opcode() == HloOpcode::kParameter &&
          computation == module->entry_computation()) {
        continue;
      }
      hlo->clear_sharding();
    }
  }
  return OkStatus();
}

std::vector<std::vector<int64_t>> GetPartitionGroupsForReplication(
    const HloSharding& sharding, absl::Span<const int64_t> replication_dims) {
  int64_t group_size = 1;
  for (int64_t i : replication_dims) {
    group_size *= sharding.tile_assignment().dim(i);
  }
  std::vector<std::vector<int64_t>> partition_groups(
      sharding.tile_assignment().num_elements() / group_size);
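  // A partition's group is keyed by the row-major linearization of its tile
  // indices over the non-replicated dimensions. E.g., with a 2x2 tile
  // assignment [[0, 1], [2, 3]] and replication_dims = {1}, the groups are
  // {0, 1} and {2, 3}.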
  sharding.tile_assignment().Each(
      [&](absl::Span<const int64_t> indices, int64_t partition) {
        int64_t group_id = 0;
        for (int64_t i = 0; i < indices.size(); ++i) {
          if (!absl::c_linear_search(replication_dims, i)) {
            group_id *= sharding.tile_assignment().dim(i);
            group_id += indices[i];
          }
        }
        partition_groups[group_id].push_back(partition);
      });
  return partition_groups;
}

// Returns a sharding that is replicated on all the dimensions where the given
// window is not trivial.
HloSharding GetShardingReplicatedOnWindowedDimension(
    const HloSharding& sharding, const Window& window) {
  std::vector<int64_t> dimensions_to_replicate;
  for (int i = 0; i < window.dimensions_size(); ++i) {
    const WindowDimension& wd = window.dimensions(i);
    if (window_util::IsTrivialWindowDimension(wd)) {
      continue;
    }
    dimensions_to_replicate.push_back(i);
  }
  return hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
      sharding, dimensions_to_replicate);
}

}  // namespace

HloInstruction* SpmdBuilder::AddInstruction(
    std::unique_ptr<HloInstruction> instruction) {
  HloInstruction* hlo =
      HloComputation::Builder::AddInstruction(std::move(instruction));
  if (visiting_hlo_) {
    hlo->set_metadata(visiting_hlo_->metadata());
    instructions_[visiting_hlo_].push_back(hlo);
  }
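  // The rest of this function maintains broadcast_dims_: for each HLO, the
  // set of output dimensions along which its values are known to be identical
  // because they originate from a broadcast. This is propagated through
  // elementwise ops, transpose, reshape, slice/dynamic-slice, and pad, and is
  // used by later partitioning decisions.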
  if (hlo->opcode() == HloOpcode::kBroadcast) {
    for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
      if (!absl::c_linear_search(hlo->dimensions(), i)) {
        broadcast_dims_[hlo].insert(i);
      }
    }
  }
  if (hlo->IsElementwise() && hlo->operand_count() > 0 &&
      // Copy can have a tuple result.
      hlo->shape().IsArray()) {
    absl::flat_hash_set<int64_t> broadcast_dims;
    for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
      broadcast_dims.insert(i);
    }
    for (int64_t i = 0; i < hlo->operand_count(); ++i) {
      auto it = broadcast_dims_.find(hlo->operand(i));
      if (it == broadcast_dims_.end()) {
        broadcast_dims.clear();
        break;
      }
      for (int64_t j = 0; j < hlo->shape().rank(); ++j) {
        if (!it->second.contains(j)) {
          broadcast_dims.erase(j);
        }
      }
    }
    if (!broadcast_dims.empty()) {
      broadcast_dims_[hlo] = std::move(broadcast_dims);
    }
  }
  if (hlo->opcode() == HloOpcode::kTranspose) {
    auto it = broadcast_dims_.find(hlo->operand(0));
    if (it != broadcast_dims_.end()) {
      absl::flat_hash_set<int64_t> xpose_broadcast_dims;
      std::vector<int64_t> reverse_map(hlo->shape().rank());
      for (int64_t i = 0; i < reverse_map.size(); ++i) {
        reverse_map[hlo->dimensions(i)] = i;
      }
      for (int64_t dim : it->second) {
        xpose_broadcast_dims.insert(reverse_map[dim]);
      }
      broadcast_dims_[hlo] = std::move(xpose_broadcast_dims);
    }
  }
  if (hlo->opcode() == HloOpcode::kReshape &&
      Product(hlo->shape().dimensions()) > 0) {
    auto it = broadcast_dims_.find(hlo->operand(0));
    if (it != broadcast_dims_.end()) {
      absl::flat_hash_set<int64_t> reshape_broadcast_dims;
      for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
        reshape_broadcast_dims.insert(i);
      }
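      // Match input and output dimensions of the reshape by walking two
      // stacks of dimension sizes, starting from the lowest-index dims. Equal
      // sizes pair up directly; otherwise the larger side is split so that
      // its remainder can be matched later. An output dim keeps its broadcast
      // status only if every input dim it derives from is a broadcast dim.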
      std::vector<int64_t> before_dim_size_stack;
      std::vector<int64_t> after_dim_size_stack;
      const int64_t operand0_rank = hlo->operand(0)->shape().rank();
      const int64_t hlo_shape_rank = hlo->shape().rank();
      before_dim_size_stack.reserve(operand0_rank);
      after_dim_size_stack.reserve(hlo_shape_rank);
      for (int64_t i = operand0_rank - 1; i >= 0; --i) {
        before_dim_size_stack.push_back(hlo->operand(0)->shape().dimensions(i));
      }
      for (int64_t i = hlo_shape_rank - 1; i >= 0; --i) {
        after_dim_size_stack.push_back(hlo->shape().dimensions(i));
      }
      while (!before_dim_size_stack.empty() && !after_dim_size_stack.empty()) {
        int64_t before_size = before_dim_size_stack.back();
        int64_t after_size = after_dim_size_stack.back();
        int64_t current_before_dim =
            hlo->operand(0)->shape().rank() - before_dim_size_stack.size();
        int64_t current_after_dim =
            hlo->shape().rank() - after_dim_size_stack.size();
        before_dim_size_stack.pop_back();
        after_dim_size_stack.pop_back();
        if (!it->second.contains(current_before_dim)) {
          reshape_broadcast_dims.erase(current_after_dim);
        }
        if (before_size == after_size) {
          continue;
        }
        if (before_size % after_size == 0) {
          // Split dim.
          before_dim_size_stack.push_back(before_size / after_size);
        } else if (after_size % before_size == 0) {
          // Merge dim.
          after_dim_size_stack.push_back(after_size / before_size);
        } else {
          // Other cases, mark all remaining dims as non-broadcast.
          for (int64_t i = current_after_dim; i < hlo->shape().rank(); ++i) {
            reshape_broadcast_dims.erase(i);
          }
          break;
        }
      }
      if (!before_dim_size_stack.empty() || !after_dim_size_stack.empty()) {
        reshape_broadcast_dims.clear();
      }
      if (!reshape_broadcast_dims.empty()) {
        broadcast_dims_[hlo] = std::move(reshape_broadcast_dims);
      }
    }
  }
  if (hlo->opcode() == HloOpcode::kSlice ||
      hlo->opcode() == HloOpcode::kDynamicSlice) {
    auto it = broadcast_dims_.find(hlo->operand(0));
    if (it != broadcast_dims_.end()) {
      auto dims = it->second;
      broadcast_dims_[hlo] = std::move(dims);
    }
  }
  if (hlo->opcode() == HloOpcode::kPad) {
    auto it = broadcast_dims_.find(hlo->operand(0));
    if (it != broadcast_dims_.end()) {
      absl::flat_hash_set<int64_t> pad_broadcast_dims;
      for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
        const auto& dim = hlo->padding_config().dimensions(i);
        if (dim.edge_padding_low() == 0 && dim.edge_padding_high() == 0 &&
            dim.interior_padding() == 0 && it->second.contains(i)) {
          pad_broadcast_dims.insert(i);
        }
      }
      if (!pad_broadcast_dims.empty()) {
        broadcast_dims_[hlo] = std::move(pad_broadcast_dims);
      }
    }
  }
  return hlo;
}

PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target,
                                       std::optional<Literal> pad_value) {
  if (sharding() == target) {
    return *this;
  }
  auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
  // Replace the existing reshard cache for the target if we are resharding
  // with a new padding value.
  const bool replace_cache = pad_value.has_value();
  const bool is_to_replicate =
      hlo_->shape().IsArray() && target.NumTiles() < sharding().NumTiles();
  const bool use_cache =
      !is_to_replicate || state_.partitioner->options().cache_all_gather;
  if (!replace_cache && use_cache) {
    auto it = cache.find(target);
    if (it != cache.end()) {
      return it->second;
    }
  }

  auto resharded = ReshardNoCache(target, std::move(pad_value));
  // Update cache for resharded hlo.
  {
    auto& cache =
        state_.reshard_cache->per_hlo_cache[resharded.hlo()].reshard_cache;
    cache.insert_or_assign(sharding(), *this);
  }
  // Update cache for to-reshard hlo.
  if (use_cache) {
    // Get the cache again as it might be invalidated by the insertion above.
    auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
    auto [it, _] = cache.insert_or_assign(target, std::move(resharded));
    return it->second;
  }
  return resharded;
}

PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target,
                                              std::optional<Literal> pad_value,
                                              bool allow_full_replication) {
  VLOG(2) << "Resharding " << hlo_->ToString() << " from "
          << hlo_->sharding().ToString() << " to " << target.ToString();
  const Shape& shape = hlo_->shape();
  if (shape.element_type() == TOKEN) {
    return *this;
  }
  CHECK(shape.IsTuple() || !target.IsTuple());

  // Tuple shape instructions may have non-tuple sharding, which means that the
  // same sharding applies to all the leaves.
  if (shape.IsTuple() && !target.IsTuple()) {
    return Reshard(target.GetTupleSharding(shape).ValueOrDie());
  }

  // For a tuple shape, recursively apply Reshard to all the leaves and return
  // a tuple instruction.
  if (shape.IsTuple()) {
    std::vector<HloInstruction*> elements;
    for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      auto subshape = ShapeUtil::GetTupleElementShape(shape, i);
      auto element = state_.b->AddInstruction(
          HloInstruction::CreateGetTupleElement(subshape, hlo(), i));
      element->set_sharding(sharding().GetSubSharding(shape, {i}));
      elements.push_back(
          PartitionedHlo(
              element, ShapeUtil::GetTupleElementShape(base_shape_, i), state_)
              .Reshard(target.GetSubSharding(shape, {i}))
              .hlo());
    }
    auto tuple =
        state_.b->AddInstruction(HloInstruction::CreateTuple(elements));
    tuple->set_sharding(target);
    return PartitionedHlo(tuple, base_shape_, state_);
  }

  if (sharding() == target) {
    return *this;
  }

  CHECK_EQ(target.IsManualSubgroup(), sharding().IsManualSubgroup());
  if (sharding().IsManualSubgroup()) {
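    // Reshard the data dimensions within each manual subgroup independently:
    // temporarily replace the sharding with its per-group view, reshard using
    // a per-group partitioning state, then restore the original sharding and
    // attach the full target sharding to the result.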
    auto grouped = hlo_sharding_util::GetManualSubgroupSharding(sharding());
    auto target_grouped = AlignGroupsWithIfCompatible(
        hlo_sharding_util::GetManualSubgroupSharding(target), grouped);
    CHECK(target_grouped.has_value())
        << "Resharding target has incompatible sharding subgroups. From "
        << sharding().ToString() << " to " << target.ToString();
    HloSharding original_sharding = sharding();
    hlo_->set_sharding(grouped.sharding);
    HloInstruction* partitioned =
        PartitionedHlo(hlo_, base_shape_,
                       CreatePerGroupPartitioningState(
                           state(), grouped.device_groups, state_.b))
            .ReshardNoCache(target_grouped->sharding)
            .hlo();
    hlo_->set_sharding(original_sharding);
    partitioned->set_sharding(target);
    return PartitionedHlo(partitioned, base_shape_, state_);
  }

  if (CanReshardWithCollectivePermute(sharding(), target)) {
    return ReshardWithCollectivePermute(target);
  }

  if (auto src_tgt_dims =
          GetReshardAllToAllSourceTargetDims(sharding(), target)) {
    return ReshardWithAllToAll(target, *src_tgt_dims);
  }

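  // If either the source or the target sharding is partially replicated, try
  // cheaper reshards (dynamic-slice, all-gather, or all-to-all based) before
  // falling back to the full replication path below.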
  if (!target.IsTileMaximal() && sharding().ReplicateOnLastTileDim()) {
    auto try_reshard = ReshardFromPartialReplicateWithDynamicSlice(target);
    if (try_reshard.has_value()) {
      return try_reshard.value();
    }
    try_reshard = ReshardPartialReplicateWithAllToAll(target);
    if (try_reshard.has_value()) {
      return try_reshard.value();
    }
  }

  if (!sharding().IsTileMaximal() && target.ReplicateOnLastTileDim()) {
    auto try_reshard = ReshardToPartialReplicateWithAllGather(target);
    if (try_reshard.has_value()) {
      return try_reshard.value();
    }
    try_reshard = ReshardPartialReplicateWithAllToAll(target);
    if (try_reshard.has_value()) {
      return try_reshard.value();
    }
  }

  // If not replicated yet, first replicate and then reshard to use one of the
  // two implementations below.
  if (!sharding().IsReplicated()) {
    if (!target.IsReplicated()) {
      auto reshard = TryComplexReshardHandling(target);
      if (reshard.has_value()) {
        return reshard.value();
      }
      if (!allow_full_replication) {
        return *this;
      }
      LOG(ERROR)
          << "[spmd] Involuntary full rematerialization. The compiler was "
             "not able to go from sharding "
          << sharding().ToString(/*include_metadata=*/true) << " to "
          << target.ToString(/*include_metadata=*/true)
          << " without doing a full rematerialization of the tensor. You "
             "probably want to enrich the sharding annotations to prevent "
             "this from happening.";
    }
    return Replicate().Reshard(target);
  }

  // 'Replicated' to 'SingleDevice'.
  if (target.IsTileMaximal()) {
    auto copy = state_.b->AddInstruction(
        HloInstruction::CreateUnary(hlo_->shape(), HloOpcode::kCopy, hlo_));
    copy->set_sharding(target);
    return PartitionedHlo(copy, base_shape_, state_);
  }

  // 'Replicated' to partially replicated.
  if (target.ReplicateOnLastTileDim()) {
    std::vector<int64_t> group_dims(target.tile_assignment().num_dimensions() -
                                    1);
    std::iota(group_dims.begin(), group_dims.end(), 0);
    auto target_grouped =
        hlo_sharding_util::GroupShardingOnDims(target, group_dims);
    auto partially_sharded = PerGroupSliceFromReplicated(
        hlo_, state_.partition_id, target_grouped.device_groups, group_dims,
        target_grouped.group_dim_sizes, state_.b);
    partially_sharded->set_sharding(target);
    return PartitionedHlo(partially_sharded, base_shape(), state_);
  }

  // 'Replicated' to 'Tiled'.
  auto padded_hlo = PadBaseShapeBeforeUnevenTiledSharding(
      hlo_, target, state_.b, std::move(pad_value));
  auto shard_shape = MakePartitionedShape(shape, target);
  auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
      shard_shape, padded_hlo,
      MakePartitionOffsets(shape, target, state_.partition_id, state_.b),
      shard_shape.dimensions()));
  slice->set_sharding(target);
  return PartitionedHlo(slice, base_shape_, state_);
}

PartitionedHlo PartitionedHlo::PadWithValue(
    HloInstruction* pad_value, absl::Span<const int64_t> left_padded_dims,
    absl::Span<const int64_t> skipped_dims) const {
  HloInstruction* result =
      PadWithValueHlo(pad_value, left_padded_dims, skipped_dims);
  if (hlo_ != result) {
    result->set_sharding(hlo_->sharding());
  }
  return PartitionedHlo(result, base_shape_, state_);
}

HloInstruction* PartitionedHlo::PadWithValueHlo(
    HloInstruction* pad_value, absl::Span<const int64_t> left_padded_dims,
    absl::Span<const int64_t> skipped_dims) const {
  const HloSharding& sharding = hlo_->sharding();
  const Shape& shape = hlo_->shape();
  CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
  if (sharding.IsReplicated() || EvenlyPartitions(base_shape_, sharding)) {
    return hlo_;
  }
  CHECK(!sharding.IsTileMaximal());
  auto index_shape = ShapeUtil::ChangeElementType(shape, S32);
  auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED);
  auto get_mask_for_dim = [&](int64_t dim, HloInstruction* start_index) {
    // Comparison: iota + start_index < valid_size
    auto iota =
        state_.b->AddInstruction(HloInstruction::CreateIota(index_shape, dim));
    auto broadcast_start_index = state_.b->AddInstruction(
        HloInstruction::CreateBroadcast(index_shape, start_index, {}));
    auto index_in_full_shape =
        state_.b->AddInstruction(HloInstruction::CreateBinary(
            index_shape, HloOpcode::kAdd, iota, broadcast_start_index));
    ComparisonDirection direction = ComparisonDirection::kLt;
    int64_t index_limit = base_shape_.dimensions(dim);
    if (absl::c_linear_search(left_padded_dims, dim)) {
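      // For left-padded dimensions the padding occupies the low indices, so
      // flip the comparison: an element is valid iff
      //   iota + start_index >= total_padded_size - base_size.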
      direction = ComparisonDirection::kGe;
      index_limit =
          index_shape.dimensions(dim) * sharding.tile_assignment().dim(dim) -
          index_limit;
    }
    auto limit = state_.b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR0<int32_t>(index_limit)));
    auto broadcast_limit = state_.b->AddInstruction(
        HloInstruction::CreateBroadcast(index_shape, limit, {}));
    return state_.b->AddInstruction(HloInstruction::CreateCompare(
        mask_shape, index_in_full_shape, broadcast_limit, direction));
  };

  HloInstruction* mask = nullptr;
  auto offsets = MakePartitionOffsets(base_shape_, sharding,
                                      state_.partition_id, state_.b);
  for (int64_t i = 0; i < shape.rank(); ++i) {
    if (base_shape_.dimensions(i) % sharding.tile_assignment().dim(i) == 0 ||
        absl::c_linear_search(skipped_dims, i)) {
      continue;
    }
    if (mask == nullptr) {
      mask = get_mask_for_dim(i, offsets[i]);
    } else {
      mask = state_.b->AddInstruction(
          HloInstruction::CreateBinary(mask->shape(), HloOpcode::kAnd, mask,
                                       get_mask_for_dim(i, offsets[i])));
    }
  }

  if (mask == nullptr) {
    return hlo_;
  }

  auto broadcast_pad_value = state_.b->AddInstruction(
      HloInstruction::CreateBroadcast(shape, pad_value, {}));
  return state_.b->AddInstruction(HloInstruction::CreateTernary(
      shape, HloOpcode::kSelect, mask, hlo_, broadcast_pad_value));
}

PartitionedHlo PartitionedHlo::PadWithZero(
    absl::Span<const int64_t> left_padded_dims,
    absl::Span<const int64_t> skipped_dims) const {
  auto zero = state_.b->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(hlo_->shape().element_type())));
  return PadWithValue(zero, left_padded_dims, skipped_dims);
}

std::optional<PartitionedHlo::WindowedInputShardReturnValue>
PartitionedHlo::ReshardAsWindowedInput(const Window& window,
                                       const HloSharding& target,
                                       HloInstruction* pad_value,
                                       bool mask_invalid_region) {
  auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].window_reshard_cache;
  for (auto& entry : cache) {
    if (std::get<0>(entry) == target &&
        protobuf_util::ProtobufEquals(std::get<1>(entry), window)) {
      return std::get<2>(entry);
    }
  }
  auto update_cache = [&](WindowedInputShardReturnValue result) {
    cache.emplace_back(target, window, std::move(result));
    return std::get<2>(cache.back());
  };
  VLOG(2) << "ReshardAsWindowedInput()\n"
          << "\twindow:" << window_util::ToString(window)
          << "\ttarget sharding:" << target.ToString();

  CHECK(!target.IsTileMaximal());
  auto partition_ordinals =
      MakeTiledPartitionOrdinals(target, state_.partition_id, state_.b);
  auto shard_shape = base_shape_;

  std::vector<MultiplyAddDivideOffsetCalculation> start_on_padded_calculations(
      base_shape_.rank());
  std::vector<MultiplyAddDivideOffsetCalculation> limit_on_padded_calculations(
      base_shape_.rank());
  std::vector<HloInstruction*> dynamic_slice_offset_on_output(
      base_shape_.rank(), nullptr);

  Window shard_window = window;
  Shape padded_shape = base_shape_;
  std::vector<HloInstruction*> offsets_on_padded_shape(base_shape_.rank());
  std::vector<int64_t> per_shard_window_counts(base_shape_.rank());
  std::vector<int64_t> explicit_left_padding(base_shape_.rank());
  // Track if any shards can be skipped.
  std::vector<int64_t> trimmed_target_sharding_tile_shape(base_shape_.rank());
  // There can be at most 2 ranges of skipped shards on a dimension: 1) on the
  // right side, 2) in the middle. The following vector tracks the middle range
  // (i.e., <start, size>). The leftmost shard must not be skipped because
  // outputs are left-aligned.
  std::vector<std::pair<int64_t, int64_t>> trimmed_target_sharding_middle_range(
      base_shape_.rank(), std::pair<int64_t, int64_t>(-1, -1));
  bool trimmed_shards = false;
  std::vector<int64_t> dims_needs_pre_masking;
  Shape halo_exchange_base_shape = base_shape_;
  // If all useful input data are in a single shard, we can skip in-shard data
  // (e.g., those that belong to negative padding) via a local slice.
  bool trimmed_in_shard = false;
  std::vector<int64_t> pre_halo_exchange_slice_starts(base_shape_.rank(), 0);
  std::vector<int64_t> pre_halo_exchange_slice_limits(
      hlo_->shape().dimensions().begin(), hlo_->shape().dimensions().end());
  std::vector<bool> can_leave_dimension_partitioned(base_shape_.rank(), false);
  for (int64_t i = 0; i < base_shape_.rank(); ++i) {
    can_leave_dimension_partitioned[i] =
        window_util::IsTrivialWindowDimension(window.dimensions(i));
  }
  for (int64_t i = 0; i < base_shape_.rank(); ++i) {
    // Do not pad non-partitioned dimensions.
    int64_t shard_count = target.tile_assignment().dim(i);
    trimmed_target_sharding_tile_shape[i] = shard_count;
    if (shard_count == 1 || can_leave_dimension_partitioned[i]) {
      offsets_on_padded_shape[i] = state_.b->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
      shard_shape.set_dimensions(
          i, CeilOfRatio(base_shape_.dimensions(i), shard_count));
      continue;
    }
    const WindowDimension& wd = window.dimensions(i);
    WindowDimension* swd = shard_window.mutable_dimensions(i);
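    // Sizes on the base-dilated, padded input: dilated_size is the extent of
    // one (window-dilated) window, full_size is the extent of the whole
    // padded input, and window_count is the number of windows (i.e., output
    // elements) along this dimension.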
    const int64_t dilated_size = 1 + (wd.size() - 1) * wd.window_dilation();
    const int64_t full_size =
        1 + (base_shape_.dimensions(i) - 1) * wd.base_dilation() +
        wd.padding_high() + wd.padding_low();
    int64_t window_count = (full_size - dilated_size) / wd.stride() + 1;
    per_shard_window_counts[i] = CeilOfRatio(window_count, shard_count);
    // Find skippable shards on the right side. This could only happen when
    // window_count < shard_count so that the right-most shard does not have any
    // output.
    int64_t input_shard_size = hlo_->shape().dimensions(i);
    if (window_count < shard_count && wd.window_dilation() == 1 &&
        wd.base_dilation() == 1) {
      // Test if some shards do not have any useful input (all uneven padding or
      // window negative padding).
      int64_t useful_input_shards = CeilOfRatio(
          base_shape_.dimensions(i) + wd.padding_high(), input_shard_size);
      if (useful_input_shards < shard_count) {
        shard_count = std::max<int64_t>(useful_input_shards, window_count);
        trimmed_shards = true;
        trimmed_target_sharding_tile_shape[i] = shard_count;
        if (shard_count == 1) {
          offsets_on_padded_shape[i] = state_.b->AddInstruction(
              HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
          swd->set_padding_high(base_shape_.dimensions(i) + wd.padding_high() -
                                hlo_->shape().dimensions(i));
          continue;
        }
        // Make sure the halo-exchange base shape is evenly sharded on the new
        // shard count.
        halo_exchange_base_shape.set_dimensions(i,
                                                input_shard_size * shard_count);
        if (input_shard_size * shard_count > base_shape_.dimensions(i) &&
            wd.padding_high() > 0) {
          // The new shape has paddings, make sure it's masked.
          dims_needs_pre_masking.push_back(i);
        } else if (wd.padding_high() < 0 &&
                   full_size - wd.padding_low() < input_shard_size) {
          // If the useful input is smaller than a shard, we treat the shard
          // size as the useful input size and slice later.
          input_shard_size = full_size - wd.padding_low();
          halo_exchange_base_shape.set_dimensions(
              i, input_shard_size * shard_count);
          pre_halo_exchange_slice_limits[i] = input_shard_size;
          trimmed_in_shard = true;
        }
      }
    }

    // We use explicit padding for full dilations, then use padding_low and
    // padding_high on the sharded op for the remaining. padding_low and
    // padding_high are now given initial values, which will be later updated if
    // dilation is not 1.
    explicit_left_padding[i] = wd.padding_low() / wd.base_dilation();
    swd->set_padding_low(wd.padding_low() % wd.base_dilation());
    swd->set_padding_high(0);

    // Find potential skippable range in the middle. This could happen when only
    // a few shards have outputs (window_count < shard_count), but there is a
    // large negative left padding such that the start shard that has useful
    // input does not have any outputs.
    if (window_count < shard_count && wd.window_dilation() == 1 &&
        wd.base_dilation() == 1) {
      int64_t middle_empty_shards =
          (-explicit_left_padding[i]) / input_shard_size - window_count;
      if (middle_empty_shards > 0) {
        shard_count -= middle_empty_shards;
        CHECK_GT(shard_count, 1);
        trimmed_target_sharding_middle_range[i].first = window_count;
        trimmed_target_sharding_middle_range[i].second = middle_empty_shards;
        trimmed_shards = true;
        trimmed_target_sharding_tile_shape[i] = shard_count;
        // Reduce negative padding.
        explicit_left_padding[i] += middle_empty_shards * input_shard_size;
        halo_exchange_base_shape.set_dimensions(i,
                                                input_shard_size * shard_count);
        HloInstruction* ordinal = partition_ordinals[i];
        HloInstruction* left_count = CreateR0WithType<int32_t>(
            ordinal->shape().element_type(), window_count, state_.b);
        HloInstruction* on_left =
            state_.b->AddInstruction(HloInstruction::CreateCompare(
                ShapeUtil::ChangeElementType(ordinal->shape(), PRED), ordinal,
                left_count, ComparisonDirection::kLt));
        HloInstruction* right_ordinal =
            state_.b->AddInstruction(HloInstruction::CreateBinary(
                ordinal->shape(), HloOpcode::kSubtract, ordinal, left_count));
        partition_ordinals[i] =
            state_.b->AddInstruction(HloInstruction::CreateTernary(
                partition_ordinals[i]->shape(), HloOpcode::kSelect, on_left,
                partition_ordinals[i], right_ordinal));
        if (-explicit_left_padding[i] > input_shard_size * (shard_count - 1)) {
          // If all useful data is on the last shard, we can skip extra negative
          // left padding.
          int64_t skip_amount =
              -explicit_left_padding[i] - input_shard_size * (shard_count - 1);
          input_shard_size -= skip_amount;
          explicit_left_padding[i] += skip_amount * shard_count;
          pre_halo_exchange_slice_starts[i] = skip_amount;
          trimmed_in_shard = true;
          // We may have enabled a new skipping opportunity on the right side
          // within the only shard that has useful input, because we excluded
          // negative left padding regions this time.
          if (full_size < input_shard_size) {
            skip_amount = input_shard_size - full_size;
            pre_halo_exchange_slice_limits[i] -= skip_amount;
            explicit_left_padding[i] += skip_amount * (shard_count - 1);
            input_shard_size = full_size;
          }
          halo_exchange_base_shape.set_dimensions(
              i, input_shard_size * shard_count);
        }
      }
    }
    if (full_size < dilated_size) {
      VLOG(2) << "Failed to reshard window operand because the window size is "
                 "larger than padded base size";
      return std::nullopt;
    }
    if (wd.stride() != 1 &&
        (wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() != 0) {
      // TODO(yuanzx): Support this case.
      VLOG(2) << "Failed to reshard window operand due to non-trivial dilation";
      return std::nullopt;
    }

    // Calculation for the first element needed on the 'padded-but-not-dilated'
    // shape. The start on the dilated shape could be a hole, so we add
    // wd.base_dilation() - 1 to the constant term to skip the leading holes.
    start_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation(
        wd.stride() * per_shard_window_counts[i],
        wd.base_dilation() - 1 - swd->padding_low(), wd.base_dilation());
    int64_t dilated_shard_size =
        wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
    limit_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation(
        wd.stride() * per_shard_window_counts[i],
        dilated_shard_size + wd.base_dilation() - 1 - swd->padding_low(),
        wd.base_dilation());

    offsets_on_padded_shape[i] = start_on_padded_calculations[i].Calculate(
        partition_ordinals[i], state_.b);

    auto shard_size_function =
        limit_on_padded_calculations[i] - start_on_padded_calculations[i];
    int64_t max_shard_size = shard_size_function.MaxInRange(0, shard_count);
    shard_shape.set_dimensions(i, max_shard_size);
    padded_shape.set_dimensions(
        i, limit_on_padded_calculations[i].Calculate(shard_count - 1));

    // For base dilation, calculate the needed padding_low and padding_high, as
    // well as the offset for the output if a dynamic slice is needed after the
    // sharded op.
    if (wd.base_dilation() != 1) {
      // Returns the offset of a shard's first valid element in the dilated
      // shard.
      auto get_first_valid_element_offset_on_dilated_shard =
          [&](int64_t shard_ordinal) {
            return start_on_padded_calculations[i].Calculate(shard_ordinal) *
                       wd.base_dilation() +
                   swd->padding_low() -
                   wd.stride() * per_shard_window_counts[i] * shard_ordinal;
          };
      CHECK_EQ(get_first_valid_element_offset_on_dilated_shard(0),
               swd->padding_low());

      // Determine swd->padding_high.
      for (int64_t shard_ordinal = 0; shard_ordinal < shard_count;
           ++shard_ordinal) {
        int64_t wanted_limit_on_dilated_shard =
            wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
        int64_t actual_limit_on_dilated_shard_without_pad_high =
            get_first_valid_element_offset_on_dilated_shard(shard_ordinal) +
            (max_shard_size - 1) * wd.base_dilation() + 1;
        swd->set_padding_high(std::max<int64_t>(
            swd->padding_high(),
            wanted_limit_on_dilated_shard -
                actual_limit_on_dilated_shard_without_pad_high));
      }

      // Determine swd->padding_low and output dynamic slice index.
      if (wd.stride() == 1) {
        int64_t max_pad_low =
            get_first_valid_element_offset_on_dilated_shard(0);
        bool all_same = true;
        for (int64_t shard_ordinal = 1; shard_ordinal < shard_count;
             ++shard_ordinal) {
          int64_t start =
              get_first_valid_element_offset_on_dilated_shard(shard_ordinal);
          if (start != swd->padding_low()) {
            all_same = false;
          }
          max_pad_low = std::max(max_pad_low, start);
        }
        if (!all_same) {
          auto start_on_padded_input =
              start_on_padded_calculations[i].Calculate(partition_ordinals[i],
                                                        state_.b);
          // We will calculate
          //   max_pad_low - (first_window - required_first_window)
          // which equals
          //   required_first_window - (first_window - max_pad_low)
          auto first_window_minus_max_pad_low =
              MultiplyAddDivideOffsetCalculation(
                  wd.base_dilation(), swd->padding_low() - max_pad_low, 1)
                  .Calculate(start_on_padded_input, state_.b);
          auto required_first_window =
              MultiplyAddDivideOffsetCalculation(per_shard_window_counts[i], 0,
                                                 1)
                  .Calculate(partition_ordinals[i], state_.b);
          dynamic_slice_offset_on_output[i] =
              state_.b->AddInstruction(HloInstruction::CreateBinary(
                  required_first_window->shape(), HloOpcode::kSubtract,
                  required_first_window, first_window_minus_max_pad_low));
        }
        swd->set_padding_low(max_pad_low);
      } else {
        if ((wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() !=
            0) {
          // General base dilation not yet implemented.
          return std::nullopt;
        }
        // padding_low on all shards should equal the initially assigned
        // swd->padding_low(), i.e., the padding_low() on the original window.
      }
    }
  }

  // Returns the output dynamic slice offset when needed, and std::nullopt
  // otherwise.
  auto get_dynamic_slice_offset_on_output_if_needed =
      [&]() -> std::optional<std::vector<HloInstruction*>> {
    if (absl::c_all_of(
            dynamic_slice_offset_on_output,
            [](HloInstruction* offset) { return offset == nullptr; })) {
      return std::nullopt;
    }
    auto zero = state_.b->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
    for (int64_t i = 0; i < dynamic_slice_offset_on_output.size(); ++i) {
      if (dynamic_slice_offset_on_output[i] == nullptr) {
        dynamic_slice_offset_on_output[i] = zero;
      }
    }
    return dynamic_slice_offset_on_output;
  };

  auto handle_all_windowed_dimensions_are_replicated = [&]() {
    PaddingConfig padding_config;
    auto pad_hlo_shape = padded_shape;
    for (int64_t i = 0; i < base_shape_.rank(); ++i) {
      auto padding_config_dim = padding_config.add_dimensions();
      padding_config_dim->set_interior_padding(0);
      // Do not pad non-partitioned dimensions or partitioned dimensions that
      // are already sharded in a way where the windowed sharding matches the
      // sharding we want.
      if (target.tile_assignment().dim(i) == 1 ||
          can_leave_dimension_partitioned[i]) {
        padding_config_dim->set_edge_padding_low(0);
        padding_config_dim->set_edge_padding_high(0);
        pad_hlo_shape.set_dimensions(i, hlo_->shape().dimensions(i));
        continue;
      }
      padding_config_dim->set_edge_padding_low(explicit_left_padding[i]);
      padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) -
                                                explicit_left_padding[i] -
                                                base_shape_.dimensions(i));
    }
    auto padded_hlo =
        ShapeUtil::Compatible(pad_hlo_shape, base_shape_)
            ? hlo_
            : state_.b->AddInstruction(HloInstruction::CreatePad(
                  pad_hlo_shape, hlo_, pad_value, padding_config));
    auto sharded_input =
        state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
            shard_shape, padded_hlo, offsets_on_padded_shape,
            shard_shape.dimensions()));
    return update_cache(WindowedInputShardReturnValue{
        sharded_input, shard_window,
        get_dynamic_slice_offset_on_output_if_needed()});
  };

  auto sharding_with_non_windowed_dims_replicated =
      GetShardingReplicatedOnWindowedDimension(target, window);
  // If the current HLO is replicated or all windowed dimensions are
  // replicated, pad then slice. If the target sharding and current sharding
  // are not the same then give the halo exchange system a chance to run as it
  // can skip generating a dynamic slice.
  if (sharding().IsReplicated() ||
      (target != sharding() &&
       sharding_with_non_windowed_dims_replicated == sharding())) {
    return handle_all_windowed_dimensions_are_replicated();
  }
  if (target != sharding() &&
      sharding_with_non_windowed_dims_replicated != sharding()) {
    return Reshard(target).ReshardAsWindowedInput(window, target, pad_value);
  }
  if (Product(trimmed_target_sharding_tile_shape) == 1) {
    // The trimmed sharding may have just one shard left. We can simply return
    // hlo_ in this case.
    return update_cache(WindowedInputShardReturnValue{
        hlo_, shard_window, get_dynamic_slice_offset_on_output_if_needed()});
  }
  if (target.ReplicateOnLastTileDim()) {
    trimmed_target_sharding_tile_shape.push_back(
        target.tile_assignment().dimensions().back());
  }
  std::optional<HloSharding> trimmed_target;
  const HloSharding* halo_exchange_target = &target;
  if (trimmed_shards) {
    // Remove devices on the right side.
    Array<int64_t> trimmed_devices(trimmed_target_sharding_tile_shape);
    trimmed_devices.Each([&](absl::Span<const int64_t> indices, int64_t* d) {
      std::vector<int64_t> target_indices(indices.begin(), indices.end());
      for (int64_t i = 0; i < base_shape_.rank(); ++i) {
        const auto& range = trimmed_target_sharding_middle_range[i];
        if (range.first >= 0 && indices[i] >= range.first) {
          target_indices[i] += range.second;
        }
      }
      *d = target.tile_assignment()(target_indices);
    });
    trimmed_target = target.ReplicateOnLastTileDim()
                         ? HloSharding::PartialTile(trimmed_devices)
                         : HloSharding::Tile(trimmed_devices);
    halo_exchange_target = &*trimmed_target;
  }

  // Halo exchange.
  HloInstruction* visiting_hlo = hlo_;

  if (!dims_needs_pre_masking.empty()) {
    std::vector<int64_t> skipped_dims;
    for (int dim = 0; dim < base_shape_.rank(); ++dim) {
      if (!absl::c_linear_search(dims_needs_pre_masking, dim)) {
        skipped_dims.push_back(dim);
      }
    }
    visiting_hlo = PadWithValueHlo(pad_value, /*left_padded_dims=*/{},
                                   /*skipped_dims=*/skipped_dims);
  }

  // If we skipped unused data within a shard, we need to slice the input shard.
  if (trimmed_in_shard) {
    std::vector<int64_t> slice_sizes(halo_exchange_base_shape.rank());
    for (int64_t i = 0; i < slice_sizes.size(); ++i) {
      slice_sizes[i] =
          pre_halo_exchange_slice_limits[i] - pre_halo_exchange_slice_starts[i];
    }
    visiting_hlo = state_.b->AddInstruction(HloInstruction::CreateSlice(
        ShapeUtil::MakeShape(halo_exchange_base_shape.element_type(),
                             slice_sizes),
        visiting_hlo,
        /*start_indices=*/pre_halo_exchange_slice_starts,
        /*limit_indices=*/pre_halo_exchange_slice_limits,
        /*strides=*/
        std::vector<int64_t>(halo_exchange_base_shape.rank(), 1)));
  }

  std::vector<OffsetCalculation> left_halo_size_functions(base_shape_.rank());
  std::vector<OffsetCalculation> right_halo_size_functions(base_shape_.rank());
  // TODO(yuanzx): We are concatenating on each sharded dimension one at a
  // time, and in the second dimension (and beyond) we create halos by slicing
  // the concat in the previous dimension, which is not optimal. We should
  // generate halos by concatenating slices, instead of slicing concats.
  for (int dim = 0; dim < base_shape_.rank(); ++dim) {
    int64_t shard_count = halo_exchange_target->tile_assignment().dim(dim);
    if (shard_count == 1 || can_leave_dimension_partitioned[dim]) {
      continue;
    }
    int64_t input_shard_size =
        CeilOfRatio(halo_exchange_base_shape.dimensions(dim), shard_count);

    // Left halo. The size of the halo is derived by subtracting the first read
    // element offset of the i'th partition from the limit of the (i-1)'th
    // partition.
    MultiplyAddDivideOffsetCalculation shard_limit_of_previous_on_padded(
        input_shard_size, explicit_left_padding[dim], 1);
    left_halo_size_functions[dim] =
        shard_limit_of_previous_on_padded - start_on_padded_calculations[dim];

    // Right halo.
    MultiplyAddDivideOffsetCalculation shard_start_of_next_on_padded(
        input_shard_size, input_shard_size + explicit_left_padding[dim], 1);
    right_halo_size_functions[dim] =
        limit_on_padded_calculations[dim] - shard_start_of_next_on_padded;

    auto resharded = ExchangeHaloAndGetValidData(
        visiting_hlo, halo_exchange_base_shape, left_halo_size_functions[dim],
        right_halo_size_functions[dim], explicit_left_padding[dim],
        padded_shape.dimensions(dim), shard_shape.dimensions(dim), dim,
        *halo_exchange_target, offsets_on_padded_shape[dim], pad_value,
        partition_ordinals[dim], state_.collective_ops_creator,
        state_.next_channel_id, state_.b, mask_invalid_region);
    if (!resharded) {
      VLOG(1) << "ReshardAsWindowedInput failed without replicate first: halo "
                 "is beyond the neighbor.";
      // If we are already sharded in such a way that all windowed dimensions
      // are replicated then just handle it with pad + slice.
      if (sharding_with_non_windowed_dims_replicated == sharding()) {
        return handle_all_windowed_dimensions_are_replicated();
      }
      return Reshard(sharding_with_non_windowed_dims_replicated)
          .ReshardAsWindowedInput(window, target, pad_value);
    }
    visiting_hlo = *resharded;
  }
  return update_cache(WindowedInputShardReturnValue{
      visiting_hlo, shard_window,
      get_dynamic_slice_offset_on_output_if_needed()});
}

PartitionedHlo PartitionedHlo::Replicate() {
  auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
  if (state_.partitioner->options().cache_all_gather) {
    for (auto& entry : cache) {
      if (entry.first.IsReplicated()) {
        return entry.second;
      }
    }
  }
  // Do not use a reference as the HLO's sharding can be temporarily replaced.
  const HloSharding sharding = hlo_->sharding();
  const Shape& shape = hlo_->shape();
  CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);

  if (sharding.IsReplicated()) {
    return *this;
  }
  for (auto& entry : cache) {
    if (entry.first.IsReplicated()) {
      return entry.second;
    }
  }
  auto update_cache = [&](PartitionedHlo resharded) {
    state_.reshard_cache->per_hlo_cache[resharded.hlo()]
        .reshard_cache.insert_or_assign(sharding, *this);
    // Get the cache again as it might be invalidated by the insertion above.
    auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
    if (state_.partitioner->options().cache_all_gather) {
      auto [it, _] = cache.insert_or_assign(HloSharding::Replicate(),
                                            std::move(resharded));
      return it->second;
    }
    return resharded;
  };
  // 'Single Device' to 'Replicated'.
  if (sharding.IsTileMaximal()) {
    return update_cache(Broadcast());
  }

  // 'Tiled' to 'Replicated'.
  std::vector<int64_t> all_dims(shape.rank());
  std::iota(all_dims.begin(), all_dims.end(), 0);
  HloInstruction* result = ReplicatePartial(all_dims);
  result->set_sharding(HloSharding::Replicate());
  return update_cache(PartitionedHlo(result, base_shape_, state_));
}

HloInstruction* PartitionedHlo::ReplicatePartial(
    absl::Span<const int64_t> dims) {
  CHECK(!sharding().IsTileMaximal());
  const Shape& shard_shape = hlo()->shape();
  Shape target_shape = shard_shape;
  Shape padded_target_shape = shard_shape;
  std::vector<int64_t> broadcast_dims;
  std::vector<int64_t> dus_ar_dims;
  std::vector<int64_t> ag_dims;
1187   // Find dimensions that can be replicated with Broadcast() (shard size 1) and
1188   // others that need all-gather. dus_ar_dims is a generalization of
1189   // broadcast_dims for dimensions whose full size is less than half of the
1190   // all-gather size; we use dynamic-update-slice + all-reduce on them.
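  // Illustrative example (added commentary, hypothetical numbers): with 8
  // partitions on a dimension and shard size 1, an all-gather would produce 8
  // elements even if the true dimension size is only 3; since 3 <= 8 / 2, it
  // is cheaper to write the shard into a zero buffer of the full size with
  // dynamic-update-slice and all-reduce the result, so the dimension goes to
  // dus_ar_dims rather than ag_dims.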
1191   for (int64_t i : dims) {
1192     int64_t partitions = sharding().tile_assignment().dim(i);
1193     if (partitions == 1) {
1194       continue;
1195     }
1196     target_shape.set_dimensions(i, base_shape().dimensions(i));
1197     if (target_shape.dimensions(i) == shard_shape.dimensions(i)) {
1198       broadcast_dims.push_back(i);
1199     } else if (target_shape.dimensions(i) <= partitions / 2) {
1200       dus_ar_dims.push_back(i);
1201     } else {
1202       padded_target_shape.set_dimensions(
1203           i, shard_shape.dimensions(i) * partitions);
1204       ag_dims.push_back(i);
1205     }
1206   }
1207 
1208   HloInstruction* broadcast = hlo_;
1209   if (!broadcast_dims.empty()) {
1210     std::vector<int64_t> other_dims;
1211     for (int64_t i = 0; i < sharding().tile_assignment().num_dimensions();
1212          ++i) {
1213       if (!absl::c_linear_search(broadcast_dims, i)) {
1214         other_dims.push_back(i);
1215       }
1216     }
1217     HloSharding original_sharding = sharding();
1218     auto grouped =
1219         hlo_sharding_util::GroupShardingOnDims(original_sharding, other_dims);
1220     std::vector<int64_t> dev_indices(
1221         grouped.sharding.tile_assignment().num_dimensions(), 0);
1222     hlo_->set_sharding(HloSharding::AssignDevice(
1223         grouped.sharding.tile_assignment()(dev_indices)));
1224     auto per_group_partitioner_state = CreatePerGroupPartitioningState(
1225         state(), grouped.device_groups, state().b);
1226     auto partial_replicate_hlo =
1227         PartitionedHlo(hlo_, shard_shape, per_group_partitioner_state)
1228             .Broadcast();
1229     hlo_->set_sharding(original_sharding);
1230     partial_replicate_hlo.hlo()->clear_sharding();
1231     broadcast = partial_replicate_hlo.hlo();
1232   }
1233 
1234   if (ag_dims.empty() && dus_ar_dims.empty()) {
1235     return broadcast;
1236   }
1237 
1238   HloInstruction* result = nullptr;
1239   if (state_.collective_ops_creator.create_cross_partition_all_gather) {
1240     result = state_.partitioner->AllGatherShards(
1241         state_.b, broadcast, sharding(), state_.next_channel_id, ag_dims,
1242         state_.collective_ops_creator);
1243   }
1244   // dus_ar_dims may also receive the dims on which all-gather failed.
1245   if (result == nullptr) {
1246     for (int64_t dim : ag_dims) {
1247       dus_ar_dims.push_back(dim);
1248     }
1249     result = broadcast;
1250   }
1251   if (!dus_ar_dims.empty()) {
1252     auto zero = state_.b->AddInstruction(HloInstruction::CreateConstant(
1253         LiteralUtil::Zero(shard_shape.element_type())));
1254     std::vector<int64_t> masking_dims;
1255     for (int64_t dim : dus_ar_dims) {
1256       int64_t partitions = sharding().tile_assignment().dim(dim);
1257       if (base_shape().dimensions(dim) < partitions) {
1258         // DUS will be out of bounds and the offset will be clamped, so we
1259         // need to mask this dim with 0.
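        // Illustrative example (added commentary, hypothetical numbers): with
        // 4 partitions and base size 3, partition 3's offset of 3 is clamped
        // to 2, so partitions 2 and 3 would both write offset 2; zeroing the
        // invalid shard data keeps the all-reduce sum below correct.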
1260         masking_dims.push_back(dim);
1261         // Adjust the padded dim size. This may also be a failed all-gather dim.
1262         padded_target_shape.set_dimensions(dim, base_shape().dimensions(dim));
1263       }
1264     }
1265     if (!masking_dims.empty()) {
1266       std::vector<int64_t> skipped_dims;
1267       for (int64_t i = 0; i < base_shape().rank(); ++i) {
1268         if (!absl::c_linear_search(masking_dims, i)) {
1269           skipped_dims.push_back(i);
1270         }
1271       }
1272       result->set_sharding(sharding());
1273       result = PartitionedHlo(result, padded_target_shape, state_)
1274                    .PadWithValue(zero,
1275                                  /*left_padded_dims=*/{},
1276                                  /*skipped_dims=*/skipped_dims)
1277                    .hlo();
1278     }
1279     auto zero_bcast = state_.b->AddInstruction(
1280         HloInstruction::CreateBroadcast(padded_target_shape, zero, {}));
1281     auto offsets = MakePartitionOffsets(
1282         padded_target_shape,
1283         hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept(
1284             sharding(), dus_ar_dims),
1285         state_.partition_id, state_.b, dus_ar_dims);
1286     auto dus =
1287         state_.b->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1288             padded_target_shape, zero_bcast, result, offsets));
1289     HloComputation* reduction =
1290         MakeBinaryAdd(shard_shape.element_type(), state_.module);
1291     result = state_.partitioner->AllReduceAlongShardingDims(
1292         state_.b, dus, sharding(), state_.next_channel_id, dus_ar_dims,
1293         state_.collective_ops_creator, reduction);
1294   }
1295   if (!ShapeUtil::Compatible(target_shape, padded_target_shape)) {
1296     std::vector<int64_t> start_indices(target_shape.rank(), 0);
1297     std::vector<int64_t> strides(target_shape.rank(), 1);
1298     result = state_.b->AddInstruction(
1299         HloInstruction::CreateSlice(target_shape, result, start_indices,
1300                                     target_shape.dimensions(), strides));
1301   }
1302   return result;
1303 }
1304 
1305 std::optional<PartitionedHlo>
1306 PartitionedHlo::ReshardToPartialReplicateWithAllGather(
1307     const HloSharding& target) {
1308   if (!target.ReplicateOnLastTileDim()) {
1309     return std::nullopt;
1310   }
1311   // Tiled/partial replicate to partial replicate.
1312   // Get the compatible sharding to the target with resharding by all-reduce.
1313   auto compatible_sharding =
1314       PartialReplicateReshardCompatibleSharding(target, sharding());
1315   if (!compatible_sharding.has_value()) {
1316     return std::nullopt;
1317   }
1318 
1319   const auto& temp_sharding = compatible_sharding.value();
1320   auto partitioned_hlo = *this;
1321   // Use collective permute to adjust device assignment if needed.
1322   if (CanReshardWithCollectivePermute(sharding(), temp_sharding)) {
1323     partitioned_hlo =
1324         partitioned_hlo.ReshardWithCollectivePermute(temp_sharding);
1325   }
1326 
1327   // Get the replicate dims and the replicate factor of each dimension.
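  // Illustrative example (added commentary, hypothetical shardings): going
  // from a temp_sharding tile assignment of [4,2] to a target of [2,2,2]
  // last_tile_dim_replicate gives replicate_factor 4 / 2 = 2 on dim 0, so
  // replicate_dims = {0} and replicate_factors = {2}.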
1328   int64_t rank = hlo_->shape().rank();
1329   std::vector<int64_t> replicate_dims;
1330   std::vector<int64_t> replicate_factors;
1331   for (int64_t dim = 0; dim < rank; dim++) {
1332     int64_t replicate_factor = temp_sharding.tile_assignment().dim(dim) /
1333                                target.tile_assignment().dim(dim);
1334     if (replicate_factor > 1) {
1335       replicate_dims.emplace_back(dim);
1336       replicate_factors.emplace_back(replicate_factor);
1337     }
1338   }
1339 
1340   // Do a left halo exchange if a direct all-reduce would remove useful data
1341   // from the source.
1342   auto halo_exchange = TileToPartialReplicateHaloExchange(
1343       partitioned_hlo.hlo_, base_shape_, temp_sharding, target, replicate_dims,
1344       partitioned_hlo.state().collective_ops_creator,
1345       partitioned_hlo.state().next_channel_id,
1346       partitioned_hlo.state().partition_id, partitioned_hlo.state().b);
1347   if (!halo_exchange.has_value()) {
1348     return std::nullopt;
1349   }
1350   auto halo_exchange_hlo = halo_exchange.value();
1351   // Grouped on replicate dimensions.
1352   auto sharding_grouped = hlo_sharding_util::GroupShardingOnDims(
1353       temp_sharding, replicate_dims, replicate_factors);
1354   auto per_group_partitioner_state = CreatePerGroupPartitioningState(
1355       partitioned_hlo.state(), sharding_grouped.device_groups,
1356       partitioned_hlo.state().b);
1357   auto base_shape = MakePartitionedShape(base_shape_, target);
1358   // It's possible that halo_exchange_hlo == partitioned_hlo.hlo().
1359   // Record the sharding of the hlo here, and reset it before returning.
1360   auto original_sharding = partitioned_hlo.sharding();
1361   halo_exchange_hlo->set_sharding(sharding_grouped.sharding);
1362   auto partial_replicate_hlo = PartitionedHlo(halo_exchange_hlo, base_shape,
1363                                               per_group_partitioner_state);
1364   HloInstruction* result =
1365       partial_replicate_hlo.ReplicatePartial(replicate_dims);
1366   partitioned_hlo.hlo()->set_sharding(original_sharding);
1367   result->set_sharding(target);
1368   return PartitionedHlo(result, base_shape_, partitioned_hlo.state());
1369 }
1370 
1371 std::optional<PartitionedHlo>
1372 PartitionedHlo::ReshardFromPartialReplicateWithDynamicSlice(
1373     const HloSharding& target) {
1374   if (!sharding().ReplicateOnLastTileDim()) {
1375     return std::nullopt;
1376   }
1377 
1378   // 1. Get the temp sharding target from partial replicate to target tile
1379   // dims. target_compatible_sharding has the same tile_assignment dimensions
1380   // as the target and can reshard to the target by collective permute.
1381   // target_compatible_sharding could have a different device assignment than
1382   // the target. sharding() can reshard to target_compatible_sharding by
1383   // dynamic slice.
1384   auto target_compatible_sharding =
1385       PartialReplicateReshardCompatibleSharding(sharding(), target);
1386   // Reshard to target_compatible_sharding by dynamic slice.
1387   if (!target_compatible_sharding.has_value()) {
1388     return std::nullopt;
1389   }
1390   std::vector<int64_t> expand_tile_dims;
1391   std::vector<int64_t> tiling_dim_factors;
1392   int64_t rank = hlo_->shape().rank();
1393   tiling_dim_factors.reserve(target.tile_assignment().num_dimensions());
1394   const auto& temp_target_sharding = target_compatible_sharding.value();
1395   for (int64_t dim = 0; dim < rank; dim++) {
1396     if (temp_target_sharding.tile_assignment().dim(dim) >
1397         sharding().tile_assignment().dim(dim)) {
1398       expand_tile_dims.push_back(dim);
1399     }
1400     tiling_dim_factors.emplace_back(
1401         temp_target_sharding.tile_assignment().dim(dim) /
1402         sharding().tile_assignment().dim(dim));
1403   }
1404 
1405   // Add another dimension in tiling_dim_factors if target is partial replicate.
1406   if (target.ReplicateOnLastTileDim()) {
1407     tiling_dim_factors.emplace_back(
1408         target.tile_assignment().dimensions().back());
1409   }
1410 
1411   // 2. Get the padded_hlo, do right halo exchange if needed.
1412   auto padded_hlo = PadFromPartialReplicateShape(
1413       hlo_, base_shape_, sharding(), temp_target_sharding, expand_tile_dims,
1414       state_.collective_ops_creator, state_.next_channel_id,
1415       state_.partition_id, state_.b);
1416   if (!padded_hlo.has_value()) {
1417     return std::nullopt;
1418   }
1419   // 3. Slice out the tile from the replicated ones.
1420   auto shard_shape = MakePartitionedShape(base_shape_, temp_target_sharding);
1421   // Since we are just slicing, we can use the differences between the new and
1422   // old offsets in the full shape as the dynamic-slice offsets.
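  // Illustrative example (added commentary, hypothetical numbers): if this
  // partition owned offset 0 in the padded full shape under sharding() and
  // owns offset 4 under temp_target_sharding, the dynamic-slice start into
  // its local data is 4 - 0 = 4.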
1423   auto padded_base_shape = shard_shape;
1424   for (int64_t i = 0; i < padded_base_shape.rank(); ++i) {
1425     padded_base_shape.set_dimensions(
1426         i, padded_base_shape.dimensions(i) *
1427                temp_target_sharding.tile_assignment().dim(i));
1428   }
1429   auto offsets = MakePartitionOffsets(padded_base_shape, temp_target_sharding,
1430                                       state_.partition_id, state_.b);
1431   auto old_offsets = MakePartitionOffsets(padded_base_shape, sharding(),
1432                                           state_.partition_id, state_.b);
1433   for (int64_t i = 0; i < offsets.size(); ++i) {
1434     offsets[i] = state_.b->AddInstruction(HloInstruction::CreateBinary(
1435         offsets[i]->shape(), HloOpcode::kSubtract, offsets[i], old_offsets[i]));
1436   }
1437   auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
1438       shard_shape, padded_hlo.value(), offsets, shard_shape.dimensions()));
1439   slice->set_sharding(temp_target_sharding);
1440   auto result = PartitionedHlo(slice, base_shape_, state_);
1441   // If temp_target_sharding's device assignment is different from target,
1442   // use collective permute to reshard.
1443   if (CanReshardWithCollectivePermute(temp_target_sharding, target)) {
1444     return result.ReshardWithCollectivePermute(target);
1445   }
1446   // If device assignment in temp_target_sharding and target are the same,
1447   // return result directly.
1448   return result;
1449 }
1450 
1451 PartitionedHlo PartitionedHlo::Broadcast() const {
1452   const Shape& shape = hlo_->shape();
1453   const HloSharding& sharding = hlo_->sharding();
1454   CHECK(sharding.HasUniqueDevice());
1455   CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
1456 
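  // How the broadcast works (added commentary): each partition contributes
  // its local data if it is the source device and zeros otherwise, so the
  // cross-partition all-reduce below reproduces the source data everywhere.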
1457   auto src_core_id = state_.b->AddInstruction(HloInstruction::CreateConstant(
1458       LiteralUtil::CreateR0<uint32_t>(sharding.GetUniqueDevice())));
1459   Shape bcast_shape = ShapeUtil::ChangeElementType(shape, PRED);
1460   auto is_src_core = state_.b->AddInstruction(HloInstruction::CreateBroadcast(
1461       bcast_shape,
1462       state_.b->AddInstruction(HloInstruction::CreateCompare(
1463           ShapeUtil::MakeShape(PRED, {}), state_.partition_id, src_core_id,
1464           ComparisonDirection::kEq)),
1465       {}));
1466 
1467   auto zero = state_.b->AddInstruction(
1468       HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
1469   auto zero_bcast = state_.b->AddInstruction(
1470       HloInstruction::CreateBroadcast(shape, zero, {}));
1471   auto operand = state_.b->AddInstruction(HloInstruction::CreateTernary(
1472       shape, HloOpcode::kSelect, is_src_core, hlo(), zero_bcast));
1473   HloComputation* reduction =
1474       MakeBinaryAdd(shape.element_type(), state_.module);
1475 
1476   auto result = state_.collective_ops_creator.create_cross_partition_all_reduce(
1477       state_.b, operand, reduction, {}, NewChannel());
1478   result->set_sharding(HloSharding::Replicate());
1479   return PartitionedHlo(result, base_shape_, state_);
1480 }
1481 
1482 PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
1483     const HloSharding& target,
1484     absl::Span<const std::pair<int64_t, int64_t>> source_target_dims) const {
1485   if (source_target_dims.empty()) {
1486     if (target == sharding()) {
1487       return *this;
1488     }
1489     // If the device order is different in the target, fix the order with
1490     // ReshardWithCollectivePermute.
1491     return ReshardWithCollectivePermute(target);
1492   }
1493 
1494   // Swap one pair of dimensions.
1495   int64_t source_dim = source_target_dims[0].first;
1496   int64_t target_dim = source_target_dims[0].second;
1497   const int64_t group_size = sharding().tile_assignment().dim(source_dim) /
1498                              sharding().tile_assignment().dim(target_dim);
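  // Illustrative example (added commentary, hypothetical numbers): resharding
  // a [4,2] tile assignment with source_dim = 0 and target_dim = 1 gives
  // group_size = 4 / 2 = 2, and the temp target tile assignment becomes
  // [2,4], i.e. the partition counts of the two dimensions are swapped.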
1499 
1500   auto temp_target_tile = sharding().tile_assignment();
1501   {
1502     std::vector<int64_t> reshape_tile_dims(temp_target_tile.num_dimensions() +
1503                                            2);
1504     int64_t i = 0;
1505     int64_t added_source_dim = -1;
1506     int64_t added_target_dim = -1;
1507     for (int64_t j = 0; j < temp_target_tile.num_dimensions(); ++j) {
1508       if (source_dim == j) {
1509         reshape_tile_dims[i] = temp_target_tile.dim(j) / group_size;
1510         reshape_tile_dims[++i] = group_size;
1511         added_source_dim = i;
1512       } else if (target_dim == j) {
1513         reshape_tile_dims[i] = temp_target_tile.dim(j);
1514         reshape_tile_dims[++i] = 1;
1515         added_target_dim = i;
1516       } else {
1517         reshape_tile_dims[i] = temp_target_tile.dim(j);
1518       }
1519       ++i;
1520     }
1521     temp_target_tile.Reshape(reshape_tile_dims);
1522     std::vector<int64_t> xpose_dims(temp_target_tile.num_dimensions());
1523     std::iota(xpose_dims.begin(), xpose_dims.end(), 0);
1524     xpose_dims[added_source_dim] = added_target_dim;
1525     xpose_dims[added_target_dim] = added_source_dim;
1526     temp_target_tile = hlo_sharding_util::TransposeSharding(
1527                            HloSharding::Tile(temp_target_tile), xpose_dims)
1528                            .tile_assignment();
1529     auto temp_target_tile_dims = sharding().tile_assignment().dimensions();
1530     temp_target_tile_dims[source_dim] =
1531         sharding().tile_assignment().dim(target_dim);
1532     temp_target_tile_dims[target_dim] =
1533         sharding().tile_assignment().dim(source_dim);
1534     temp_target_tile.Reshape(temp_target_tile_dims);
1535   }
1536   auto temp_target = target.ReplicateOnLastTileDim()
1537                          ? HloSharding::PartialTile(temp_target_tile)
1538                          : HloSharding::Tile(temp_target_tile);
1539   auto padded_shape = hlo_->shape();
1540   padded_shape.set_dimensions(
1541       target_dim, RoundUpTo(padded_shape.dimensions(target_dim),
1542                             temp_target.tile_assignment().dim(target_dim)));
1543   auto padded_hlo = PadToShape(hlo_, padded_shape, state_.b);
1544 
1545   // The order of ids in the group must follow the temp_target sharding.
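  // Illustrative example (added commentary, hypothetical numbers): for a
  // temp_target tile assignment of shape [2,4] with target_dim = 1 and
  // group_size = 2, device (i, j) lands in group 2 * i + j / 2, so devices
  // (0,0),(0,1) form group 0 and devices (1,2),(1,3) form group 3.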
1546   std::vector<std::vector<int64_t>> groups(
1547       temp_target.tile_assignment().num_elements() / group_size);
1548   temp_target.tile_assignment().Each(
1549       [&](absl::Span<const int64_t> indices, int64_t device) {
1550         int64_t group_id = 0;
1551         for (int64_t dim = 0; dim < indices.size(); ++dim) {
1552           if (dim == target_dim) {
1553             group_id *= temp_target.tile_assignment().dim(dim) / group_size;
1554             group_id += indices[dim] / group_size;
1555           } else {
1556             group_id *= temp_target.tile_assignment().dim(dim);
1557             group_id += indices[dim];
1558           }
1559         }
1560         groups[group_id].push_back(device);
1561       });
1562 
1563   HloInstruction* result = nullptr;
1564 
1565   // Split along the split dimension (target_dim) of the all-to-all
1566   // output.
1567   std::vector<int64_t> dimensions;
1568   const int64_t rank = base_shape_.rank();
1569   dimensions.reserve(rank + 1);
1570   for (int64_t i = 0; i < rank; ++i) {
1571     if (i == target_dim) {
1572       dimensions.push_back(group_size);
1573       dimensions.push_back(padded_hlo->shape().dimensions(i) / group_size);
1574     } else {
1575       dimensions.push_back(padded_hlo->shape().dimensions(i));
1576     }
1577   }
1578   VLOG(5) << "Target ata shape: "
1579           << ShapeUtil::MakeShape(base_shape_.element_type(), dimensions)
1580                  .ToString();
1581   auto reshape = state_.b->AddInstruction(HloInstruction::CreateReshape(
1582       ShapeUtil::MakeShape(base_shape_.element_type(), dimensions),
1583       padded_hlo));
1584   // After the reshape, it is guaranteed to have at least 3 dimensions.
1585   auto all_to_all =
1586       state_.collective_ops_creator.create_cross_partition_all_to_all(
1587           state_.b, {reshape}, groups, (*state_.next_channel_id)++, target_dim);
1588 
1589   // Reorder the split dimension of the reshape to be located in front of the
1590   // input partition dimension, so the two dimensions can be combined.
1591   int64_t new_source_dim =
1592       (target_dim < source_dim) ? source_dim + 1 : source_dim;
1593   std::vector<int64_t> permutation;
1594   for (int64_t i = 0; i < all_to_all->shape().rank(); ++i) {
1595     if (i == target_dim) {
1596       continue;
1597     }
1598     if (i == new_source_dim) {
1599       permutation.push_back(target_dim);
1600     }
1601     permutation.push_back(i);
1602   }
1603   auto transpose = state_.b->AddInstruction(HloInstruction::CreateTranspose(
1604       ShapeInference::InferTransposeShape(all_to_all->shape(), permutation)
1605           .ValueOrDie(),
1606       all_to_all, permutation));
1607 
1608   // Combine the split dimension and the input partition dimension.
1609   auto new_shape = ShapeInference::InferAllToAllShape(
1610                        padded_hlo->shape(), target_dim, source_dim, group_size)
1611                        .ValueOrDie();
1612   result = state_.b->AddInstruction(
1613       HloInstruction::CreateReshape(new_shape, transpose));
1614 
1615   const Shape result_shape = MakePartitionedShape(base_shape_, temp_target);
1616   if (result_shape != result->shape()) {
1617     result = state_.b->AddInstruction(HloInstruction::CreateSlice(
1618         result_shape, result, std::vector<int64_t>(result_shape.rank(), 0),
1619         result_shape.dimensions(),
1620         std::vector<int64_t>(result_shape.rank(), 1)));
1621   }
1622   result->set_sharding(temp_target);
1623   auto remaining_source_target_dims = source_target_dims;
1624   remaining_source_target_dims.remove_prefix(1);
1625   return PartitionedHlo(result, base_shape_, state_)
1626       .ReshardWithAllToAll(target, remaining_source_target_dims);
1627 }
1628 
1629 namespace {
1630 
1631 // Matching a pattern like [..,X,..,Y] -> [..,X*Y,..,1] or [..,X,..,Y] ->
1632 // [..,1,..,X*Y].
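// Illustrative example (added commentary, hypothetical shardings): source
// tile assignment [2,4] and target [8,1] match with i = 1 (dimension_size =
// 4 / 1 = 4) and j = 0 (8 == 4 * 2), so dimension 0 is split and transposed
// and the reshard is attempted on the reshaped value.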
1633 std::optional<std::pair<HloSharding, int>> PatternMatchReshape(
1634     const Shape& shape, const HloSharding& source, const HloSharding& target) {
1635   if (!source.IsTiled() || !target.IsTiled()) {
1636     return std::nullopt;
1637   }
1638   if (source.TiledDataRank() != target.TiledDataRank()) {
1639     return std::nullopt;
1640   }
1641   if ((source.HasPartialReplication() ^ target.HasPartialReplication()) ||
1642       (source.HasPartialReplication() &&
1643        source.tile_assignment().dimensions()[source.TiledDataRank()] !=
1644            target.tile_assignment().dimensions()[target.TiledDataRank()])) {
1645     return std::nullopt;
1646   }
1647   for (int i = 0; i < target.TiledDataRank(); ++i) {
1648     if (source.tile_assignment().dim(i) > target.tile_assignment().dim(i) &&
1649         target.tile_assignment().dim(i) == 1 &&
1650         (source.tile_assignment().dim(i) % target.tile_assignment().dim(i)) ==
1651             0) {
1652       const int64_t dimension_size =
1653           source.tile_assignment().dim(i) / target.tile_assignment().dim(i);
1654       for (int j = i - 1; j >= 0; --j) {
1655         if (target.tile_assignment().dim(j) == 1) {
1656           continue;
1657         }
1658         if (target.tile_assignment().dim(j) !=
1659             dimension_size * source.tile_assignment().dim(j)) {
1660           continue;
1661         }
1662         // Do not consider if it requires additional padding.
1663         if (shape.dimensions(j) % dimension_size != 0) {
1664           continue;
1665         }
1666         auto reshaped_sharding = hlo_sharding_util::SplitShardingDimension(
1667             source, j, source.tile_assignment().dim(j));
1668         std::vector<int64_t> permutation(reshaped_sharding.TiledDataRank(), 0);
1669         absl::c_iota(permutation, 0);
1670         std::swap(permutation[i + 1], permutation[j]);
1671         reshaped_sharding = hlo_sharding_util::TransposeSharding(
1672             reshaped_sharding, permutation);
1673         return std::make_pair(reshaped_sharding, j);
1674       }
1675       for (int j = i + 1; j < target.TiledDataRank(); ++j) {
1676         if (target.tile_assignment().dim(j) == 1) {
1677           continue;
1678         }
1679         if (target.tile_assignment().dim(j) !=
1680             dimension_size * source.tile_assignment().dim(j)) {
1681           continue;
1682         }
1683         // Do not consider if it requires additional padding.
1684         if (shape.dimensions(j) % dimension_size != 0) {
1685           continue;
1686         }
1687 
1688         auto reshaped_sharding = hlo_sharding_util::SplitShardingDimension(
1689             source, j, source.tile_assignment().dim(j));
1690         std::vector<int64_t> permutation(reshaped_sharding.TiledDataRank(), 0);
1691         absl::c_iota(permutation, 0);
1692         VLOG(5) << "Reshaped sharding before: " << reshaped_sharding.ToString();
1693         std::swap(permutation[i], permutation[j]);
1694         reshaped_sharding = hlo_sharding_util::TransposeSharding(
1695             reshaped_sharding, permutation);
1696         VLOG(5) << "Reshaped sharding: " << reshaped_sharding.ToString();
1697         return std::make_pair(reshaped_sharding, j);
1698       }
1699     }
1700   }
1701   return std::nullopt;
1702 }
1703 // Match patterns like [..,X,..,Z,..,Y,..] -> [..,Y,..,Z,..,X]
1704 // last_tile_dim_replicate, where X gets replicated and Y also changes
1705 // position. We try to perform the replication, so we can match some other
1706 // targets instead.
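// Illustrative example (added commentary, hypothetical shardings): for source
// [2,4] and target [1,4,2] last_tile_dim_replicate, dim 0 of the source has
// size 2, which matches the target's replicated dim, so the source is
// partially replicated on dim 0 as the intermediate sharding.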
1707 std::optional<HloSharding> PatternMatchPartiallyReplicateDim(
1708     const HloSharding& source, const HloSharding& target) {
1709   if (!(!source.ReplicateOnLastTileDim() && target.ReplicateOnLastTileDim())) {
1710     return std::nullopt;
1711   }
1712   const int64_t target_replicated_dim = target.SubgroupReplicationDim();
1713   CHECK_NE(target_replicated_dim, -1) << "Expected replicated dim";
1714   for (int i = 0; i < source.tile_assignment().num_dimensions(); ++i) {
1715     if (source.tile_assignment().dim(i) !=
1716         target.tile_assignment().dim(target_replicated_dim)) {
1717       continue;
1718     }
1719     auto replicated_sharding =
1720         hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(source, {i});
1721     return replicated_sharding;
1722   }
1723   return std::nullopt;
1724 }
1725 
1726 // Helper to split a PartitionedHlo over a specific dimension.
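// Illustrative example (added commentary, hypothetical shapes): splitting
// dim 0 of base shape [8,16] with dim_size = 4 yields base shape [2,4,16],
// while the local shard shape only gains a size-1 dimension at position 1.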
1727 PartitionedHlo SplitReshapeHelper(PartitionedHlo to_reshape,
1728                                   int64_t dim_to_split, int64_t dim_size,
1729                                   const HloSharding& target_sharding) {
1730   Shape original_shape = to_reshape.hlo()->shape();
1731   std::vector<int64_t> shape_dim(original_shape.dimensions().begin(),
1732                                  original_shape.dimensions().end());
1733   shape_dim.insert(shape_dim.begin() + dim_to_split + 1, 1);
1734   std::vector<int64_t> base_shape_dim(
1735       to_reshape.base_shape().dimensions().begin(),
1736       to_reshape.base_shape().dimensions().end());
1737   base_shape_dim.insert(base_shape_dim.begin() + dim_to_split + 1, dim_size);
1738   base_shape_dim[dim_to_split] /= dim_size;
1739   Shape shape = ShapeUtil::MakeShape(original_shape.element_type(), shape_dim);
1740   HloInstruction* reshaped_instr = to_reshape.state().b->AddInstruction(
1741       HloInstruction::CreateReshape(shape, to_reshape.hlo()));
1742   reshaped_instr->set_sharding(target_sharding);
1743   return PartitionedHlo{
1744       reshaped_instr,
1745       ShapeUtil::MakeShape(to_reshape.base_shape().element_type(),
1746                            base_shape_dim),
1747       to_reshape.state()};
1748 }
1749 // Merge a PartitionedHlo over a specific dimension.
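// Illustrative example (added commentary, hypothetical shapes): merging dim 0
// of base shape [2,4,16] gives [8,16]; the local shard shape is merged the
// same way.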
1750 PartitionedHlo MergeReshapeHelper(PartitionedHlo to_reshape,
1751                                   int64_t dim_to_merge,
1752                                   const HloSharding& target_sharding) {
1753   Shape original_shape = to_reshape.hlo()->shape();
1754   std::vector<int64_t> shape_dim(original_shape.dimensions().begin(),
1755                                  original_shape.dimensions().end());
1756   shape_dim[dim_to_merge] *= shape_dim[dim_to_merge + 1];
1757   shape_dim.erase(shape_dim.begin() + dim_to_merge + 1);
1758   std::vector<int64_t> base_shape_dim(
1759       to_reshape.base_shape().dimensions().begin(),
1760       to_reshape.base_shape().dimensions().end());
1761   base_shape_dim[dim_to_merge] *= base_shape_dim[dim_to_merge + 1];
1762   base_shape_dim.erase(base_shape_dim.begin() + dim_to_merge + 1);
1763   Shape shape = ShapeUtil::MakeShape(original_shape.element_type(), shape_dim);
1764   HloInstruction* reshaped_instr = to_reshape.state().b->AddInstruction(
1765       HloInstruction::CreateReshape(shape, to_reshape.hlo()));
1766   reshaped_instr->set_sharding(target_sharding);
1767   return PartitionedHlo(
1768       reshaped_instr,
1769       ShapeUtil::MakeShape(original_shape.element_type(), base_shape_dim),
1770       to_reshape.state());
1771 }
1772 
1773 }  // namespace
1774 
1775 std::optional<PartitionedHlo> PartitionedHlo::TryComplexReshardHandling(
1776     const HloSharding& target) {
1777   VLOG(5) << "Trying to split complicated reshard: " << sharding().ToString()
1778           << " to " << target.ToString();
1779   const bool is_source_partially_replicated =
1780       sharding().ReplicateOnLastTileDim();
1781   const bool is_target_partially_replicated = target.ReplicateOnLastTileDim();
1782   if (auto reshape =
1783           PatternMatchReshape(this->hlo()->shape(), sharding(), target)) {
1784     VLOG(5) << "Matched \"pattern_match_reshape()\": "
1785             << reshape->first.ToString();
1786     VLOG(5) << "Original shape: " << hlo()->shape().ToString();
1787     auto before_sharding = hlo_sharding_util::SplitShardingDimension(
1788         sharding(), reshape->second,
1789         sharding().tile_assignment().dim(reshape->second));
1790     PartitionedHlo reshaped = SplitReshapeHelper(
1791         *this, reshape->second,
1792         sharding().tile_assignment().dim(reshape->second), before_sharding);
1793     VLOG(5) << "Reshaped shape: " << reshaped.hlo()->shape().ToString();
1794     auto reshard = reshaped.ReshardNoCache(reshape->first,
1795                                            /*pad_value=*/std::nullopt,
1796                                            /*allow_full_replication=*/false);
1797     if (reshard.sharding() != reshape->first) {
1798       return std::nullopt;
1799     }
1800     auto reshaped_sharding = hlo_sharding_util::MergeShardingDimension(
1801         reshard.sharding(), reshape->second);
1802     reshaped = MergeReshapeHelper(reshard, reshape->second, reshaped_sharding);
1803     if (reshaped.sharding() != target) {
1804       reshaped = reshaped.ReshardNoCache(target, /*pad_value=*/std::nullopt,
1805                                          /*allow_full_replication=*/false);
1806       if (reshaped.sharding() != target) {
1807         return std::nullopt;
1808       }
1809     }
1810     return reshaped;
1811   }
1812   if (auto intermediate_target =
1813           PatternMatchPartiallyReplicateDim(sharding(), target)) {
1814     VLOG(5) << "Matched \"pattern_match_partially_replicate_dim()\": "
1815             << intermediate_target->ToString();
1816     auto intermediate_reshard = Reshard(*intermediate_target);
1817     auto final_reshard = intermediate_reshard.ReshardNoCache(
1818         target, /*pad_value=*/std::nullopt, /*allow_full_replication=*/false);
1819     if (final_reshard.sharding() != target) {
1820       return std::nullopt;
1821     }
1822     return final_reshard;
1823   }
1824   if (is_source_partially_replicated && !is_target_partially_replicated) {
1825     const int64_t partial_repl_amount =
1826         sharding().tile_assignment().dimensions().back();
1827     int64_t first_different_dimension = -1;
1828     // Trying to match conditions like [..,X,..,Z,..,Y] last_tile_dim_replicate
1829     // to [..,Y,..,Z,..,X,..], where Y in the source is partially replicated,
1830     // but in the target it is not and some other dimension got moved or
1831     // modified. Try to remove the partial replication to simplify the step from
1832     // source to target sharding.
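    // Illustrative example (added commentary, hypothetical shardings): from
    // [1,2,2] last_tile_dim_replicate to [2,2], dim 0 differs (1 vs 2) and
    // 2 % 2 == 0, so dim 0 is swapped with the replicated last dim to form an
    // intermediate sharding before resharding to the target.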
1833     for (int64_t i = 0; i < target.tile_assignment().num_dimensions(); ++i) {
1834       if (target.tile_assignment().dim(i) !=
1835               sharding().tile_assignment().dim(i) &&
1836           sharding().tile_assignment().dim(i) == 1 &&
1837           target.tile_assignment().dim(i) % partial_repl_amount == 0) {
1838         first_different_dimension = i;
1839         break;
1840       }
1841     }
1842     if (first_different_dimension == -1) {
1843       return std::nullopt;
1844     }
1845     VLOG(5) << "Matched partially replicated to non partially replicated: "
1846             << sharding().ToString();
1847     std::vector<int64_t> transpose_dims(
1848         sharding().tile_assignment().num_dimensions(), 0);
1849     std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
1850     std::swap(transpose_dims[first_different_dimension], transpose_dims.back());
1851     auto intermediate_sharding =
1852         hlo_sharding_util::TransposeSharding(sharding(), transpose_dims);
1853     auto intermediate_reshard = Reshard(intermediate_sharding);
1854     auto reshard = intermediate_reshard.ReshardNoCache(
1855         target, /*pad_value=*/std::nullopt, /*allow_full_replication=*/false);
1856     if (reshard.sharding() != target) {
1857       return std::nullopt;
1858     }
1859     return reshard;
1860   }
1861   return std::nullopt;
1862 }
1863 
1864 std::optional<PartitionedHlo>
1865 PartitionedHlo::ReshardPartialReplicateWithAllToAll(const HloSharding& target) {
1866   bool source_is_partial_replicate = sharding().ReplicateOnLastTileDim();
1867   const auto& partial_replicate_sharding =
1868       source_is_partial_replicate ? sharding() : target;
1869   // If neither the source nor the target is partial replicate, return null.
1870   if (!partial_replicate_sharding.ReplicateOnLastTileDim()) {
1871     return std::nullopt;
1872   }
1873   const auto& tile_sharding = source_is_partial_replicate ? target : sharding();
1874   // If both source and target are partially replicated, this should already
1875   // be supported by Reshard with AllToAll.
1876   if (tile_sharding.ReplicateOnLastTileDim() || tile_sharding.IsTileMaximal()) {
1877     return std::nullopt;
1878   }
1879 
1880   // Only support resharding from sharding={devices=[2,3]0,1,2,3,4,5}
1881   // to sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}, where
1882   // the last tile dim will be replicate first before all-to-all.
1883   // Or resharding from
1884   // sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
1885   // to sharding={devices=[2,3]0,1,2,3,4,5}, where
1886   // the last tile dim will be sharded after all-to-all.
1887   const int num_replicas =
1888       partial_replicate_sharding.tile_assignment().dimensions().back();
1889   if (((tile_sharding.tile_assignment().num_dimensions() + 1) !=
1890        partial_replicate_sharding.tile_assignment().num_dimensions()) ||
1891       (partial_replicate_sharding.tile_assignment().dim(0) != 1)) {
1892     return std::nullopt;
1893   }
1894   int to_replicate_dim = -1;
1895   for (int i = tile_sharding.tile_assignment().num_dimensions() - 1; i >= 0;
1896        --i) {
1897     if (tile_sharding.tile_assignment().dim(i) > 1 &&
1898         (to_replicate_dim == -1)) {
1899       if (tile_sharding.tile_assignment().dim(i) != num_replicas) {
1900         return std::nullopt;
1901       }
1902       to_replicate_dim = i;
1903     }
1904 
1905     if (tile_sharding.tile_assignment().dim(i) !=
1906         partial_replicate_sharding.tile_assignment().dim(i + 1)) {
1907       return std::nullopt;
1908     }
1909   }
1910 
1911   if (to_replicate_dim == -1) {
1912     return std::nullopt;
1913   }
1914 
1915   // Check if core assignments for source and the target are the same.
1916   auto reshape_tile_assignment = partial_replicate_sharding.tile_assignment();
1917   reshape_tile_assignment.Reshape(tile_sharding.tile_assignment().dimensions());
1918   if (reshape_tile_assignment != tile_sharding.tile_assignment()) {
1919     return std::nullopt;
1920   }
1921 
1922   auto tmp_tile_assignment = tile_sharding.tile_assignment();
1923   auto tmp_tile_assignment_dimensions =
1924       tile_sharding.tile_assignment().dimensions();
1925   tmp_tile_assignment_dimensions[to_replicate_dim] = 1;
1926   tmp_tile_assignment_dimensions.push_back(num_replicas);
1927   tmp_tile_assignment.Reshape(tmp_tile_assignment_dimensions);
1928   auto tmp_partial_replicate_sharding =
1929       HloSharding::PartialTile(tmp_tile_assignment);
1930 
1931   if (source_is_partial_replicate) {
1932     if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
1933             sharding(), tmp_partial_replicate_sharding)) {
1934       auto partitioned_hlo =
1935           ReshardWithAllToAll(tmp_partial_replicate_sharding, *src_tgt_dims);
1936       return partitioned_hlo.Reshard(target);
1937     }
1938   } else {
1939     auto partitioned_hlo = Reshard(tmp_partial_replicate_sharding);
1940 
1941     if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
1942             partitioned_hlo.sharding(), target)) {
1943       return partitioned_hlo.ReshardWithAllToAll(target, *src_tgt_dims);
1944     }
1945   }
1946 
1947   return std::nullopt;
1948 }
1949 
1950 PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
1951     const HloSharding& target) const {
1952   CHECK(CanReshardWithCollectivePermute(sharding(), target))
1953       << sharding().ToString() << " to " << target.ToString();
1954   if (auto broadcast_dims = state_.b->BroadcastDimsForCreatedHlo(hlo())) {
1955     if (!(*broadcast_dims)->empty()) {
1956       // If hlo() has broadcast dims, check if data is already the same between
1957       // source/destination pairs.
1958       std::vector<int64_t> broadcast_dims_vector;
1959       for (int64_t i = 0; i < hlo()->shape().rank(); ++i) {
1960         if ((*broadcast_dims)->contains(i)) {
1961           broadcast_dims_vector.push_back(i);
1962         }
1963       }
1964       if (hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
1965               sharding(), broadcast_dims_vector) ==
1966           hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
1967               target, broadcast_dims_vector)) {
1968         auto copy = state_.b->AddInstruction(HloInstruction::CreateUnary(
1969             hlo()->shape(), HloOpcode::kCopy, hlo()));
1970         copy->set_sharding(target);
1971         return PartitionedHlo(copy, base_shape_, state_);
1972       }
1973     }
1974   }
1975   std::vector<std::pair<int64_t, int64_t>> src_dst_pairs;
1976   sharding().tile_assignment().Each(
1977       [&](absl::Span<const int64_t> indices, int64_t src_device) {
1978         int64_t dst_device = target.tile_assignment()(indices);
1979         src_dst_pairs.emplace_back(src_device, dst_device);
1980       });
1981   auto cp =
1982       state_.collective_ops_creator.create_cross_partition_collective_permute(
1983           state_.b, hlo(), src_dst_pairs, (*state_.next_channel_id)++);
1984   cp->set_sharding(target);
1985   return PartitionedHlo(cp, base_shape_, state_);
1986 }
1987 
1988 SpmdPartitioningVisitor::SpmdPartitioningVisitor(
1989     HloComputation* computation, int64_t num_partitions, int64_t num_replicas,
1990     const SPMDCollectiveOpsCreator& collective_ops_creator,
1991     int64_t* next_channel_id, SpmdLogger* logger,
1992     SpmdPartitionerOptions options, SpmdPartitioner* partitioner)
1993     : changed_(false),
1994       module_(computation->parent()),
1995       num_partitions_(num_partitions),
1996       num_replicas_(num_replicas),
1997       collective_ops_creator_(collective_ops_creator),
1998       next_channel_id_(next_channel_id),
1999       b_(SpmdBuilder(computation->name() + "_spmd", /*hlo=*/nullptr)),
2000       partition_id_(collective_ops_creator_.create_partition_id(&b_)),
2001       logger_(logger),
2002       options_(std::move(options)),
2003       partitioner_(partitioner) {}
2004 
2005 PartitionedHlo::PartitioningState
2006 SpmdPartitioningVisitor::MakePartitioningState() {
2007   PartitionedHlo::PartitioningState state;
2008   state.b = &b_;
2009   state.module = module_;
2010   state.num_replicas = num_replicas_;
2011   state.next_channel_id = next_channel_id_;
2012   state.reshard_cache = &reshard_cache_;
2013   state.partitioner = partitioner_;
2014   if (!device_groups_.empty()) {
2015     // Use the original collective creator and partition_id to call
2016     // CreatePerGroupPartitioningState(). Current collective_ops_creator_ and
2017     // partition_id_ have been rewritten to be subgrouped.
2018     state.collective_ops_creator = *visiting_collective_ops_creator_;
2019     state.partition_id = *visiting_partition_id_;
2020     return CreatePerGroupPartitioningState(state, device_groups_, &b_);
2021   } else {
2022     state.collective_ops_creator = collective_ops_creator_;
2023     state.partition_id = partition_id_;
2024   }
2025   return state;
2026 }
2027 
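// A note on CreateReplicaGroups (added commentary): partition groups are
// expanded across replicas using global ids. Illustrative example with
// hypothetical values: num_replicas_ = 2, num_partitions_ = 4 and
// groups = {{0,1},{2,3}} produce replica groups {0,1},{2,3},{4,5},{6,7}.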
2028 std::vector<ReplicaGroup> SpmdPartitioningVisitor::CreateReplicaGroups(
2029     std::vector<std::vector<int64_t>>& groups) {
2030   std::vector<ReplicaGroup> device_groups;
2031   device_groups.reserve(groups.size() * num_replicas_);
2032   for (int64_t i = 0; i < num_replicas_; ++i) {
2033     for (const auto& group : groups) {
2034       device_groups.emplace_back();
2035       for (int64_t id : group) {
2036         device_groups.back().add_replica_ids(i * num_partitions_ + id);
2037       }
2038     }
2039   }
2040   return device_groups;
2041 }
2042 
2043 Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) {
2044   if (hlo->HasSideEffect() && !hlo->sharding().HasUniqueDevice()) {
2045     return Unimplemented("Side-effect ops cannot be replicated: %s",
2046                          hlo->ToString());
2047   }
2048 
2049   if (hlo->IsElementwise() && hlo->operand_count() > 0) {
2050     return HandleElementwise(hlo);
2051   }
2052 
2053   if (!hlo->sharding().IsTileMaximal()) {
2054     VLOG(1) << "Not partitioned in SPMD mode (DefaultAction):"
2055             << hlo->ToString();
2056     for (int64_t i = 0; i < hlo->operand_count(); ++i) {
2057       VLOG(1) << "  operand " << i
2058               << " sharding:" << hlo->operand(i)->sharding().ToString();
2059     }
2060   }
2061 
2062   HloSharding sharding = hlo->sharding().HasUniqueDevice()
2063                              ? hlo->sharding()
2064                              : HloSharding::Replicate();
2065   if (hlo->opcode() == HloOpcode::kSend || hlo->opcode() == HloOpcode::kRecv ||
2066       hlo->opcode() == HloOpcode::kRecvDone) {
2067     sharding = sharding.GetSubSharding(hlo->shape(), {0});
2068   }
2069 
2070   // If the instruction cannot be partitioned, replicate it unless it has a
2071   // side effect.
2072   std::vector<HloInstruction*> new_operands;
2073   for (HloInstruction* operand : hlo->operands()) {
2074     new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
2075   }
2076   auto clone =
2077       b_.AddInstruction(hlo->CloneWithNewOperands(hlo->shape(), new_operands));
2078   clone->set_sharding(sharding);
2079   SetPartitionedHlo(hlo,
2080                     PartitionedHlo(clone, hlo->shape(), MakePartitioningState())
2081                         .Reshard(hlo->sharding()));
2082   return OkStatus();
2083 }
2084 
2085 Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) {
2086   visiting_hlo_ = hlo;
2087   b_.set_visiting_hlo(hlo);
2088   // Temporarily replace manual sharding to one-device sharding so that the
2089   // partitioner will not change the HLOs.
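  // For example (added commentary): an op with sharding {manual} is treated
  // as if it were assigned to device 0, so its shapes are left untouched;
  // Postprocess() restores the original sharding afterwards.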
2090   auto manual_to_onedevice = [&](HloOpcode opcode, const Shape& shape,
2091                                  const HloSharding& sharding) {
2092     // If a tuple's elements are all manual, then sharding.IsManual() == true,
2093     // so we test whether it is a tuple first.
2094     if (sharding.IsTuple()) {
2095       std::vector<HloSharding> subshardings = sharding.tuple_elements();
2096       for (HloSharding& subsharding : subshardings) {
2097         // Delay manual sharding substitution for CustomCalls.
2098         if (subsharding.IsManual() && opcode != HloOpcode::kCustomCall) {
2099           subsharding = HloSharding::AssignDevice(0);
2100         }
2101       }
2102       return HloSharding::Tuple(shape, subshardings);
2103     }
2104     // Delay manual sharding substitution for CustomCalls.
2105     if (sharding.IsManual() && opcode != HloOpcode::kCustomCall) {
2106       return HloSharding::AssignDevice(0);
2107     }
2108     return sharding;
2109   };
2110 
2111   if (hlo->opcode() != HloOpcode::kConditional &&
2112       hlo->opcode() != HloOpcode::kTuple &&
2113       hlo->opcode() != HloOpcode::kGetTupleElement &&
2114       hlo->opcode() != HloOpcode::kParameter &&
2115       hlo->opcode() != HloOpcode::kWhile && hlo->opcode() != HloOpcode::kRng &&
2116       hlo->opcode() != HloOpcode::kAllReduce) {
2117     const bool has_manual_sharding =
2118         hlo->sharding().IsManual() ||
2119         (hlo->sharding().IsTuple() &&
2120          absl::c_any_of(
2121              hlo->sharding().tuple_elements(),
2122              [](const HloSharding& sharding) { return sharding.IsManual(); }));
2123     if (has_manual_sharding && !hlo->IsCustomCall("SPMDFullToShardShape")) {
2124       visiting_hlo_sharding_ = hlo->sharding();
2125       hlo->set_sharding(manual_to_onedevice(hlo->opcode(), hlo->shape(),
2126                                             *visiting_hlo_sharding_));
2127 
2128       visiting_hlo_operand_shardings_.reserve(hlo->operand_count());
2129       for (HloInstruction* operand : hlo->unique_operands()) {
2130         visiting_hlo_operand_shardings_.push_back(operand->sharding());
2131         operand->set_sharding(manual_to_onedevice(
2132             hlo->opcode(), operand->shape(), operand->sharding()));
2133         GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
2134       }
2135     } else {
2136       const bool has_manual_subgroup =
2137           hlo->sharding().IsManualSubgroup() ||
2138           (hlo->sharding().IsTuple() &&
2139            absl::c_any_of(hlo->sharding().tuple_elements(),
2140                           [](const HloSharding& sharding) {
2141                             return sharding.IsManualSubgroup();
2142                           }));
2143       if (has_manual_subgroup && !hlo->IsCustomCall("SPMDFullToShardShape")) {
2144         auto get_grouped_sharding =
2145             [&](const HloSharding& sharding, const Shape& shape,
2146                 const GroupedSharding* ref =
2147                     nullptr) -> StatusOr<GroupedSharding> {
2148           if (!sharding.IsTuple()) {
2149             GroupedSharding grouped =
2150                 hlo_sharding_util::GetManualSubgroupSharding(sharding);
2151             if (ref != nullptr) {
2152               auto aligned =
2153                   AlignGroupsWithIfCompatible(std::move(grouped), *ref);
2154               TF_RET_CHECK(aligned.has_value())
2155                   << "Incompatible manual sharding at " << hlo->ToString();
2156               return *aligned;
2157             }
2158             return grouped;
2159           }
2160           std::vector<HloSharding> elements;
2161           elements.reserve(sharding.tuple_elements().size());
2162           CHECK(!sharding.tuple_elements().empty());
2163           GroupedSharding grouped0 =
2164               hlo_sharding_util::GetManualSubgroupSharding(
2165                   sharding.tuple_elements()[0]);
2166           if (ref != nullptr) {
2167             auto aligned =
2168                 AlignGroupsWithIfCompatible(std::move(grouped0), *ref);
2169             TF_RET_CHECK(aligned.has_value())
2170                 << "Incompatible manual sharding at " << hlo->ToString();
2171             grouped0 = std::move(*aligned);
2172           }
2173           elements.push_back(std::move(grouped0.sharding));
2174           for (int64_t i = 1; i < sharding.tuple_elements().size(); ++i) {
2175             auto grouped_i = AlignGroupsWithIfCompatible(
2176                 hlo_sharding_util::GetManualSubgroupSharding(
2177                     sharding.tuple_elements()[i]),
2178                 grouped0);
2179             TF_RET_CHECK(grouped_i.has_value())
2180                 << "Incompatible manual sharding between tuple elements: "
2181                 << hlo->ToString();
2182             elements.push_back(std::move(grouped_i->sharding));
2183           }
2184           grouped0.sharding = HloSharding::Tuple(shape, elements);
2185           return grouped0;
2186         };
2187         TF_ASSIGN_OR_RETURN(
2188             auto group_sharding,
2189             get_grouped_sharding(hlo->sharding(), hlo->shape()));
2190         // Update sharding.
2191         visiting_hlo_sharding_ = hlo->sharding();
2192         hlo->set_sharding(group_sharding.sharding);
2193         // Update device_groups and num_partitions.
2194         // Set device_groups_, visiting_partition_id_ and
2195         // visiting_collective_ops_creator_ before MakePartitioningState() which
2196         // uses them.
2197         device_groups_ = group_sharding.device_groups;
2198         visiting_num_partitions_ = num_partitions_;
2199         num_partitions_ = num_partitions_ / group_sharding.device_groups.size();
2200         visiting_partition_id_ = partition_id_;
2201         visiting_collective_ops_creator_ = std::move(collective_ops_creator_);
2202         auto grouped_state = MakePartitioningState();
2203         collective_ops_creator_ =
2204             std::move(grouped_state.collective_ops_creator);
2205         partition_id_ = grouped_state.partition_id;
2206 
2207         // Update sharding for the operands.
2208         visiting_hlo_operand_shardings_.reserve(hlo->operand_count());
2209         visiting_state_.reserve(hlo->operand_count());
2210         for (HloInstruction* operand : hlo->unique_operands()) {
2211           visiting_hlo_operand_shardings_.push_back(operand->sharding());
2212           auto old_state = GetPartitionedHlo(operand).state();
2213           visiting_state_.push_back(old_state);
2214           if (operand->shape().IsArray() && operand->IsConstant() &&
2215               operand->shape().rank() == 0 &&
2216               !operand->sharding().IsManualSubgroup()) {
2217             // We allow scalar constants to be CSE'ed between manual/auto
2218             // subgraphs, so such a constant may not have a manual subgroup.
2219             continue;
2220           }
2221           TF_ASSIGN_OR_RETURN(
2222               auto op_group_sharding,
2223               get_grouped_sharding(operand->sharding(), operand->shape(),
2224                                    &group_sharding));
2225           operand->set_sharding(op_group_sharding.sharding);
2226           GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
2227           auto group_state = CreatePerGroupPartitioningState(
2228               old_state, op_group_sharding.device_groups, &b_);
2229           GetPartitionedHlo(operand).set_state(group_state);
2230         }
2231       }
2232     }
2233   }
2234   return OkStatus();
2235 }
2236 
2237 Status SpmdPartitioningVisitor::Postprocess(HloInstruction* hlo) {
2238   logger_->RegisterLogEntry(hlo, b_.derived_instructions(hlo));
2239   visiting_hlo_ = nullptr;
2240   b_.set_visiting_hlo(nullptr);
2241   // Revert fake one-device shardings for manually partitioned ops.
2242   if (visiting_hlo_sharding_) {
2243     hlo->set_sharding(*visiting_hlo_sharding_);
2244     GetPartitionedHlo(hlo).hlo()->set_sharding(*visiting_hlo_sharding_);
2245     int64_t i = 0;
2246     for (HloInstruction* operand : hlo->unique_operands()) {
2247       operand->set_sharding(visiting_hlo_operand_shardings_[i++]);
2248       GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
2249     }
2250     visiting_hlo_sharding_.reset();
2251     visiting_hlo_operand_shardings_.clear();
2252   }
2253 
2254   if (!device_groups_.empty()) {
2255     device_groups_.clear();
2256     num_partitions_ = *visiting_num_partitions_;
2257     visiting_num_partitions_.reset();
2258     collective_ops_creator_ = *visiting_collective_ops_creator_;
2259     visiting_collective_ops_creator_.reset();
2260     partition_id_ = *visiting_partition_id_;
2261     visiting_partition_id_.reset();
2262     GetPartitionedHlo(hlo).set_state(MakePartitioningState());
2263   }
2264 
2265   if (!visiting_state_.empty()) {
2266     int64_t i = 0;
2267     for (const HloInstruction* operand : hlo->unique_operands()) {
2268       GetPartitionedHlo(operand).set_state(visiting_state_[i++]);
2269     }
2270     visiting_state_.clear();
2271   }
2272 
2273   return OkStatus();
2274 }
2275 
2276 Status SpmdPartitioningVisitor::HandleElementwise(HloInstruction* hlo) {
2277   std::vector<HloInstruction*> new_operands;
2278   for (HloInstruction* operand : hlo->operands()) {
2279     new_operands.push_back(
2280         GetPartitionedHlo(operand).Reshard(hlo->sharding()).hlo());
2281   }
2282   SetPartitionedHlo(hlo, [&] {
2283     return b_.AddInstruction(hlo->CloneWithNewOperands(
2284         MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands));
2285   });
2286   return OkStatus();
2287 }
2288 
2289 Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) {
2290   const HloSharding& sharding = hlo->sharding();
2291   if (sharding.IsTileMaximal()) {
2292     return DefaultAction(hlo);
2293   }
2294 
2295   const Shape shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
2296   const int64_t dimension = hlo->concatenate_dimension();
2297   if (sharding.tile_assignment().dim(dimension) == 1) {
2298     std::vector<HloInstruction*> new_operands;
2299     for (HloInstruction* operand : hlo->operands()) {
2300       new_operands.push_back(
2301           GetPartitionedHlo(operand).Reshard(sharding).hlo());
2302     }
2303     SetPartitionedHlo(hlo, [&] {
2304       return b_.AddInstruction(
2305           hlo->CloneWithNewOperands(shard_shape, new_operands));
2306     });
2307     return OkStatus();
2308   }
2309 
2310   // If the concatenate dimension is along one of the partitioned dimensions,
2311   // allocate the full output shape, each partition updates its owned region,
2312   // all-reduce across partitions, and then slice its output region.
2313 
2314   // temp_output_shape is the output shape where the concatenate dimension
2315   // is changed to the full (and padded to shard count) dimension size.
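  // Illustrative example (added commentary, hypothetical numbers):
  // concatenating two [4] operands into [8] over 2 partitions along that
  // dimension, each partition writes its 2-element shard of each operand into
  // a zero-filled [8] buffer at its global offset, the buffers are summed by
  // the all-reduce, and each partition finally slices out its own 4-element
  // shard of the result.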
2316   auto temp_output_shape = MakePartitionedShape(hlo->shape(), sharding);
2317   auto last_operand_padded_shape =
2318       MakePartitionedShape(hlo->operands().back()->shape(), sharding);
2319   // If the last operand has more padding than the temp_output padding, we need
2320   // to add extra padding to avoid an out-of-bounds dynamic-update-slice.
2321   int last_operand_padding =
2322       last_operand_padded_shape.dimensions(dimension) *
2323           sharding.tile_assignment().dim(dimension) -
2324       hlo->operands().back()->shape().dimensions(dimension);
2325   int temp_output_padding = temp_output_shape.dimensions(dimension) *
2326                                 sharding.tile_assignment().dim(dimension) -
2327                             hlo->shape().dimensions(dimension);
2328   int padding_for_last_operand =
2329       last_operand_padding < temp_output_padding
2330           ? 0
2331           : last_operand_padding - temp_output_padding;
2332   temp_output_shape.set_dimensions(
2333       dimension, temp_output_shape.dimensions(dimension) *
2334                          sharding.tile_assignment().dim(dimension) +
2335                      padding_for_last_operand);
  auto temp_output = CreateZero(temp_output_shape, &b_);

  // Offset of each operand along the concatenate dimension.
  int64_t offset = 0;
  auto state = MakePartitioningState();
  for (HloInstruction* operand : hlo->operands()) {
    auto spmd_operand = GetPartitionedHlo(operand).Reshard(sharding).hlo();
    std::vector<HloInstruction*> start_indices(
        hlo->shape().rank(), b_.AddInstruction(HloInstruction::CreateConstant(
                                 LiteralUtil::Zero(S32))));
    start_indices[dimension] =
        MultiplyAddDivideOffsetCalculation(
            spmd_operand->shape().dimensions(dimension), offset, 1)
            .Calculate(MakeTiledPartitionOrdinals(sharding, state.partition_id,
                                                  &b_)[dimension],
                       &b_);
    temp_output = b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
        temp_output_shape, temp_output, spmd_operand, start_indices));
    offset += operand->shape().dimensions(dimension);
  }
  std::vector<int64_t> non_concat_dims;
  non_concat_dims.reserve(hlo->shape().rank() - 1);
  for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
    if (i != dimension) {
      non_concat_dims.push_back(i);
    }
  }
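  // The partial buffers only need to be combined across partitions that
  // differ along the concatenate dimension, so group the sharding on the
  // non-concat dimensions and all-reduce within each group.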
  auto grouped =
      hlo_sharding_util::GroupShardingOnDims(sharding, non_concat_dims);
  auto per_group_partitioner_state =
      CreatePerGroupPartitioningState(state, grouped.device_groups, &b_);
  auto all_reduce = per_group_partitioner_state.collective_ops_creator
                        .create_cross_partition_all_reduce(
                            &b_, temp_output,
                            MakeBinaryAdd(hlo->shape().element_type(), module_),
                            {}, NewChannel());
  SetPartitionedHlo(hlo, [&] {
    auto start_indices = MakeTiledPartitionOrdinals(
        grouped.sharding, per_group_partitioner_state.partition_id, &b_);
    start_indices[dimension] = MultiplyAddDivideOffsetCalculation(
                                   shard_shape.dimensions(dimension), 0, 1)
                                   .Calculate(start_indices[dimension], &b_);
    return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
        shard_shape, all_reduce, start_indices, shard_shape.dimensions()));
  });

  return OkStatus();
}

Status SpmdPartitioningVisitor::HandleSlice(HloInstruction* hlo) {
  const HloSharding& sharding = hlo->sharding();
  if (sharding.IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  auto operand = GetPartitionedHlo(hlo->operand(0)).Reshard(sharding);

  const int64_t rank = hlo->shape().rank();
  // Create a window config to represent the slice.
  Window window;
  for (int64_t i = 0; i < rank; ++i) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(1);
    dim->set_stride(hlo->slice_strides(i));
    dim->set_window_dilation(1);
    dim->set_window_reversal(false);
    dim->set_padding_low(-hlo->slice_starts(i));
    dim->set_padding_high(hlo->slice_limits(i) -
                          operand.base_shape().dimensions(i));
    dim->set_base_dilation(1);
  }
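  // How the encoding works (illustrative): a slice [2:8:2] of a size-10
  // dimension becomes a size-1 window with stride 2, padding_low -2, and
  // padding_high 8 - 10 = -2; the negative padding trims the out-of-slice
  // elements, so the windowed reshard hands each partition exactly the data
  // its output shard needs.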

  auto reshard_operand = operand.ReshardAsWindowedInput(
      window, sharding,
      CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
      /*mask_invalid_region=*/false);
  if (!reshard_operand.has_value()) {
    return DefaultAction(hlo);
  }
  TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value());
  const Shape& operand_shape = reshard_operand->sharded_input->shape();

  std::vector<int64_t> start_indices(rank);
  std::vector<int64_t> limit_indices(rank);
  const std::vector<int64_t>& strides = hlo->slice_strides();
  bool need_slice = false;
  for (int64_t i = 0; i < rank; ++i) {
    auto dim = reshard_operand->shard_window.dimensions(i);
    start_indices[i] = -dim.padding_low();
    limit_indices[i] = operand_shape.dimensions(i) + dim.padding_high();
    if (start_indices[i] != 0 || strides[i] != 1 ||
        limit_indices[i] != operand_shape.dimensions(i)) {
      need_slice = true;
    }
  }

  SetPartitionedHlo(hlo, [&] {
    if (need_slice) {
      auto shard_shape = MakePartitionedShape(hlo->shape(), sharding);
      return b_.AddInstruction(HloInstruction::CreateSlice(
          shard_shape, reshard_operand->sharded_input, start_indices,
          limit_indices, strides));
    }
    auto data = reshard_operand->sharded_input;
    // Create a copy so that it will not share the resharding cache.
    return b_.AddInstruction(
        HloInstruction::CreateUnary(data->shape(), HloOpcode::kCopy, data));
  });

  return OkStatus();
}

Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) {
  HloSharding sharding = hlo->sharding();
  if (sharding.HasUniqueDevice()) {
    return DefaultAction(hlo);
  }
  // Special handling for sort in TopK when the first operand is partitioned
  // at the sort dimension.
  auto k = GetKValueInTopKWhenPartitionSortDim(hlo);
  if (k.has_value()) {
    // When the first operand is partitioned at the sort dimension:
    //   1. Partition the sort computation across partitions;
    //   2. Slice the TopK value and index out of each partition;
    //   3. Gather and replicate the values and indices from all partitions;
    //      the shape of the replicated value and index will be
    //      [batch_size, ..., partition_count * k, ...];
    //   4. A final sort uses the replicated values and indices from all
    //      partitions as input.
    // GetTupleElement and Slice after the non-partitioned sort won't change
    // at this point, as HandleGetTupleElement and HandleSlice will update them.
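    // Illustrative example: with 4 partitions, k = 8, and an input of shape
    // [batch, 1024] sharded 4 ways on the sort dimension, each partition
    // sorts its [batch, 256] shard and slices its local top 8; the slices
    // are replicated into [batch, 32], and the final sort selects the
    // global top-k from those 32 candidates.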
    HloSortInstruction* sort = DynCast<HloSortInstruction>(hlo);
    const int64_t sort_dim = sort->sort_dimension();
    auto input = hlo->operand(0);
    auto index = hlo->operand(1);
    const HloSharding& input_sharding = input->sharding();
    const int64_t partition_count =
        input_sharding.tile_assignment().dim(sort_dim);
    const int64_t input_size = input->shape().dimensions(sort_dim);
    const auto element_type = input->shape().element_type();
    const auto index_type = index->shape().element_type();

    // Partition and pad input and index.
    // Pad input with the minimal value.
    auto partitioned_input = GetPartitionedHlo(input).PadWithValue(
        CreateFirstWithType(element_type, &b_));
    // Pad index with the max value.
    auto partitioned_index =
        GetPartitionedHlo(index)
            .Reshard(input_sharding)
            .PadWithValue(CreateLastWithType(index_type, &b_));

    // Each partition needs to do TopK separately, thus the base shape
    // becomes the padded shape.
    std::vector<int64_t> replicated_dimensions(
        input->shape().dimensions().begin(), input->shape().dimensions().end());
    replicated_dimensions[sort_dim] = RoundUpTo(input_size, partition_count);
    const Shape replicated_shape = ShapeUtil::MakeTupleShape(
        {ShapeUtil::MakeShape(element_type, replicated_dimensions),
         ShapeUtil::MakeShape(index_type, replicated_dimensions)});

    // Partition the original topk to different shards.
    auto topk_sharding =
        input_sharding.GetTupleSharding(replicated_shape).ValueOrDie();
    auto shard_shape = MakePartitionedShape(replicated_shape, topk_sharding);
    auto topk = b_.AddInstruction(hlo->CloneWithNewOperands(
        shard_shape, {partitioned_input.hlo(), partitioned_index.hlo()}));

    // Get the value and index from the first sort.
    HloInstruction* value_gte =
        b_.AddInstruction(HloInstruction::CreateGetTupleElement(
            topk->shape().tuple_shapes(0), topk, 0));
    HloInstruction* index_gte =
        b_.AddInstruction(HloInstruction::CreateGetTupleElement(
            topk->shape().tuple_shapes(1), topk, 1));

    // Slice the top K values from the first partitioned sort.
    replicated_dimensions[sort_dim] = k.value() * partition_count;
    auto slice_input = SliceFirstK(value_gte, &b_, sort_dim, k.value());
    slice_input->set_sharding(input_sharding);
    PartitionedHlo partitioned_slice_input(
        slice_input, ShapeUtil::MakeShape(element_type, replicated_dimensions),
        MakePartitioningState());
    // Reshard the values to be replicated.
    auto replicated_slice_input =
        partitioned_slice_input.Reshard(HloSharding::Replicate()).hlo();

    // Slice the top K indices from the first partitioned sort.
    auto slice_index = SliceFirstK(index_gte, &b_, sort_dim, k.value());
    slice_index->set_sharding(input_sharding);
    PartitionedHlo partitioned_slice_index(
        slice_index, ShapeUtil::MakeShape(index_type, replicated_dimensions),
        MakePartitioningState());
    // Reshard the indices to be replicated.
    auto replicated_slice_index =
        partitioned_slice_index.Reshard(HloSharding::Replicate()).hlo();

    // Create a replicated sort to do the TopK; the input is the value and
    // index pairs from all the partitions.
    const Shape final_topk_shape = ShapeUtil::MakeTupleShape(
        {ShapeUtil::MakeShape(element_type, replicated_dimensions),
         ShapeUtil::MakeShape(index_type, replicated_dimensions)});
    HloInstruction* final_sort = b_.AddInstruction(HloInstruction::CreateSort(
        final_topk_shape, sort_dim,
        {replicated_slice_input, replicated_slice_index}, sort->to_apply(),
        sort->is_stable()));
    final_sort->set_sharding(HloSharding::Replicate()
                                 .GetTupleSharding(final_sort->shape())
                                 .ValueOrDie());
    PartitionedHlo replicated_sort(final_sort, final_sort->shape(),
                                   MakePartitioningState());
    SetPartitionedHlo(hlo, replicated_sort.Reshard(hlo->sharding()));

    return OkStatus();
  }

  if (hlo->shape().IsTuple()) {
    // Check that all elements are sharded in the same way.
    if (hlo->shape().tuple_shapes_size() == 0) {
      return DefaultAction(hlo);
    }
    sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
    for (int64_t i = 1; i < hlo->operand_count(); ++i) {
      if (sharding != hlo->sharding().GetSubSharding(hlo->shape(), {i})) {
        return DefaultAction(hlo);
      }
    }
  }
  if (sharding.IsTileMaximal()) {
    return DefaultAction(hlo);
  }
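  // Fall back if any sort dimension is partitioned: a shard-local sort along
  // a partitioned dimension would not produce a globally sorted order.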
  for (int64_t dim : hlo->dimensions()) {
    if (sharding.tile_assignment().dim(dim) > 1) {
      return DefaultAction(hlo);
    }
  }
  // Reshard operands to the same sharding as the output.
  std::vector<HloInstruction*> new_operands;
  for (HloInstruction* operand : hlo->operands()) {
    new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
  }
  SetPartitionedHlo(hlo, [&] {
    return b_.AddInstruction(hlo->CloneWithNewOperands(
        MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands));
  });
  return OkStatus();
}

Status SpmdPartitioningVisitor::HandleTranspose(HloInstruction* hlo) {
  const HloSharding& sharding = hlo->sharding();
  if (sharding.IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  std::vector<int64_t> inverse_dimensions(hlo->shape().rank());
  for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
    inverse_dimensions[hlo->dimensions(i)] = i;
  }
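  // Illustrative example: for a transpose with dimensions {2, 0, 1} the
  // inverse permutation is {1, 2, 0}; transposing the output sharding by it
  // yields the operand sharding under which the transpose is a purely local
  // operation.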
  auto desired_operand_sharding =
      hlo_sharding_util::TransposeSharding(sharding, inverse_dimensions);

  auto operand = GetPartitionedHlo(hlo->operand(0))
                     .Reshard(desired_operand_sharding)
                     .hlo();
  SetPartitionedHlo(hlo, [&] {
    return b_.AddInstruction(hlo->CloneWithNewOperands(
        MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand}));
  });
  return OkStatus();
}

Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
  const HloSharding& sharding = hlo->sharding();
  if (sharding.IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  auto operand = GetPartitionedHlo(hlo->operand(0));
  // The output shape is the source and the operand shape is the target to get
  // the aligned sharding for the operand.
  std::optional<HloSharding> desired_operand_sharding =
      hlo_sharding_util::ReshapeSharding(hlo->shape(), hlo->operand(0)->shape(),
                                         hlo->sharding());
  // Use the desired operand sharding only if the number of tiles returned
  // matches the number of tiles in the output.
  if (desired_operand_sharding.has_value() &&
      hlo->sharding().NumTiles() == desired_operand_sharding->NumTiles()) {
    auto operand_hlo = operand.Reshard(*desired_operand_sharding).hlo();
    SetPartitionedHlo(hlo, [&] {
      return b_.AddInstruction(hlo->CloneWithNewOperands(
          MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand_hlo}));
    });
    return OkStatus();
  }
  std::optional<HloSharding> desired_output_sharding =
      hlo_sharding_util::ReshapeSharding(hlo->operand(0)->shape(), hlo->shape(),
                                         operand.sharding());
  if (desired_output_sharding.has_value()) {
    auto reshape = b_.AddInstruction(hlo->CloneWithNewOperands(
        MakePartitionedShape(hlo->shape(), *desired_output_sharding),
        {operand.hlo()}));
    reshape->set_sharding(*desired_output_sharding);
    SetPartitionedHlo(hlo, [&] {
      return PartitionedHlo(reshape, hlo->shape(), MakePartitioningState())
          .Reshard(sharding)
          .hlo();
    });
    return OkStatus();
  }

  // Check that the operand sharding and the output sharding are either both
  // tiled or both partially replicated; if both are partially replicated,
  // their replication counts must match.
  if (operand.sharding().ReplicateOnLastTileDim() !=
          sharding.ReplicateOnLastTileDim() ||
      (sharding.ReplicateOnLastTileDim() &&
       (operand.sharding().tile_assignment().dimensions().back() !=
        sharding.tile_assignment().dimensions().back()))) {
    return DefaultAction(hlo);
  }

  // Try to use halo exchange for certain split-dim/merge-dims cases.
  // ReshapeSharding failed in these cases, probably due to uneven
  // partitioning, where halo exchange could help. Specifically, we check the
  // following conditions to detect supported cases:
  // 1) Both input and output are partitioned on one dimension.
  // 2) The combined size of the dimensions before the partitioned dimension
  //    is the same on input and output, so we don't need to consider the
  //    major dimensions.
  // 3) Let A = the input size on the partitioned dimension, and
  //        B = the output size on the partitioned dimension; then
  //    either A % B == 0 (split dim) or B % A == 0 (merge dims).
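  // Illustrative example: reshaping [8] sharded 2 ways into [4, 2] is a
  // split-dim case (A = 8, B = 4, A % B == 0), while reshaping [4, 2] back
  // into [8] is a merge-dims case (A = 4, B = 8, B % A == 0).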
  auto maybe_input_sharded_dim = UniqueTiledDim(operand.sharding());
  auto maybe_output_sharded_dim = UniqueTiledDim(sharding);
  if (!maybe_input_sharded_dim || !maybe_output_sharded_dim) {
    return DefaultAction(hlo);
  }
  int64_t input_sharded_dim = *maybe_input_sharded_dim;
  int64_t output_sharded_dim = *maybe_output_sharded_dim;
  // Check that the major dims before the sharded dim have the same total size
  // for input and output.
  int64_t input_major_dims_size = 1;
  for (int64_t i = 0; i < input_sharded_dim; ++i) {
    input_major_dims_size *= operand.base_shape().dimensions(i);
  }
  int64_t output_major_dims_size = 1;
  for (int64_t i = 0; i < output_sharded_dim; ++i) {
    output_major_dims_size *= hlo->shape().dimensions(i);
  }
  if (input_major_dims_size != output_major_dims_size) {
    return DefaultAction(hlo);
  }
  // Fix potential device ordering mismatch in tile assignment.
  Array<int64_t> new_input_tile_assignment = sharding.tile_assignment();
  new_input_tile_assignment.Reshape(
      operand.sharding().tile_assignment().dimensions());
  auto aligned_sharding =
      sharding.ReplicateOnLastTileDim()
          ? HloSharding::PartialTile(new_input_tile_assignment)
          : HloSharding::Tile(new_input_tile_assignment);
  operand = operand.Reshard(aligned_sharding);
  auto replication_count = sharding.ReplicateOnLastTileDim()
                               ? sharding.tile_assignment().dimensions().back()
                               : 1;

  int64_t input_dim_size = operand.base_shape().dimensions(input_sharded_dim);
  int64_t output_dim_size = hlo->shape().dimensions(output_sharded_dim);
  auto input_shard_shape =
      MakePartitionedShape(operand.base_shape(), operand.sharding());
  auto output_shard_shape = MakePartitionedShape(hlo->shape(), sharding);
  if (input_dim_size % output_dim_size == 0) {
    // Split dim.
    int64_t split_factor = input_dim_size / output_dim_size;
    int64_t output_shard_size =
        output_shard_shape.dimensions(output_sharded_dim);
    // Use halo exchange to fix misaligned data.
    Window window;
    for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
      WindowDimension* dim = window.add_dimensions();
      dim->set_size(1);
      dim->set_stride(1);
      dim->set_window_dilation(1);
      dim->set_window_reversal(false);
      dim->set_base_dilation(1);
      dim->set_padding_low(0);
      if (i == input_sharded_dim) {
        dim->set_padding_high(output_shard_size * split_factor *
                                  num_partitions_ / replication_count -
                              input_dim_size);
      } else {
        dim->set_padding_high(0);
      }
    }

    auto reshard_operand = operand.ReshardAsWindowedInput(
        window, operand.sharding(),
        CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
        /*mask_invalid_region=*/false);
    if (!reshard_operand.has_value()) {
      return DefaultAction(hlo);
    }
    TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value());
    CHECK_EQ(
        reshard_operand->sharded_input->shape().dimensions(input_sharded_dim),
        output_shard_size * split_factor);
    SetPartitionedHlo(hlo, [&] {
      // Do a local reshape.
      return b_.AddInstruction(HloInstruction::CreateReshape(
          output_shard_shape, reshard_operand->sharded_input));
    });
    return OkStatus();
  } else if (output_dim_size % input_dim_size == 0) {
    // Merge dims.
    int64_t merge_factor = output_dim_size / input_dim_size;
    // First reshape locally. (The sharded dimension could include padded
    // data.)
    auto tmp_shard_shape = output_shard_shape;
    tmp_shard_shape.set_dimensions(
        output_sharded_dim,
        input_shard_shape.dimensions(input_sharded_dim) * merge_factor);
    auto tmp_reshape = b_.AddInstruction(
        HloInstruction::CreateReshape(tmp_shard_shape, operand.hlo()));
    tmp_reshape->set_sharding(hlo->sharding());
    auto tmp_full_shape = tmp_shard_shape;
    tmp_full_shape.set_dimensions(
        output_sharded_dim, tmp_shard_shape.dimensions(output_sharded_dim) *
                                num_partitions_ / replication_count);
    auto tmp_output =
        PartitionedHlo(tmp_reshape, tmp_full_shape, MakePartitioningState());

    // Use halo exchange to fix misaligned data.
    Window window;
    for (int64_t i = 0; i < tmp_shard_shape.rank(); ++i) {
      WindowDimension* dim = window.add_dimensions();
      dim->set_size(1);
      dim->set_stride(1);
      dim->set_window_dilation(1);
      dim->set_window_reversal(false);
      dim->set_base_dilation(1);
      dim->set_padding_low(0);
      if (i == output_sharded_dim) {
        dim->set_padding_high(output_dim_size -
                              tmp_shard_shape.dimensions(output_sharded_dim) *
                                  num_partitions_ / replication_count);
      } else {
        dim->set_padding_high(0);
      }
    }

    auto reshard_output = tmp_output.ReshardAsWindowedInput(
        window, sharding,
        CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
        /*mask_invalid_region=*/false);
    if (!reshard_output.has_value()) {
      return DefaultAction(hlo);
    }
    TF_RET_CHECK(!reshard_output->dynamic_slice_index_on_output.has_value());
    CHECK_EQ(
        reshard_output->sharded_input->shape().dimensions(output_sharded_dim),
        output_shard_shape.dimensions(output_sharded_dim));
    SetPartitionedHlo(hlo, [&] { return reshard_output->sharded_input; });
    return OkStatus();
  }
  return DefaultAction(hlo);
}

Status SpmdPartitioningVisitor::HandleIota(HloInstruction* hlo) {
  const HloSharding& sharding = hlo->sharding();
  if (sharding.IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  SetPartitionedHlo(hlo, [&] {
    int64_t dimension = Cast<HloIotaInstruction>(hlo)->iota_dimension();
    auto iota = b_.AddInstruction(HloInstruction::CreateIota(
        MakePartitionedShape(hlo->shape(), sharding), dimension));

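    // If the iota dimension is partitioned, each shard computes a local iota
    // and adds partition_ordinal * shard_size. Illustrative example: an iota
    // of size 16 on 4 partitions gives each shard [0, 1, 2, 3] plus an
    // offset of 4 * ordinal.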
    if (sharding.tile_assignment().dim(dimension) > 1) {
      auto partition_ordinals = MakeTiledPartitionOrdinals(
          sharding, MakePartitioningState().partition_id, &b_);
      auto multiplier = b_.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<int32_t>(iota->shape().dimensions(dimension))));
      auto offset = b_.AddInstruction(HloInstruction::CreateBinary(
          ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply,
          partition_ordinals[dimension], multiplier));
      if (iota->shape().element_type() != S32) {
        offset = b_.AddInstruction(HloInstruction::CreateConvert(
            ShapeUtil::MakeShape(iota->shape().element_type(), {}), offset));
      }
      auto broadcast = b_.AddInstruction(
          HloInstruction::CreateBroadcast(iota->shape(), offset, {}));
      return b_.AddInstruction(HloInstruction::CreateBinary(
          iota->shape(), HloOpcode::kAdd, iota, broadcast));
    }

    return iota;
  });

  return OkStatus();
}

Status SpmdPartitioningVisitor::HandleSingleDevice(const HloInstruction* hlo) {
  TF_RET_CHECK(hlo->sharding().HasUniqueDevice());
  int64_t device = hlo->sharding().GetUniqueDevice();
  const HloSharding sharding = HloSharding::AssignDevice(device);

  std::vector<HloInstruction*> operands;
  std::vector<const Shape*> operand_shapes;
  const auto& old_operands = hlo->operands();
  const auto old_operands_size = old_operands.size();
  operands.reserve(old_operands_size);
  operand_shapes.reserve(old_operands_size);
  for (const HloInstruction* operand : old_operands) {
    operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
    operand_shapes.push_back(&operand->shape());
  }
  auto operand = b_.AddInstruction(HloInstruction::CreateTuple(operands));
  auto operand_shape = ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes);

  auto on_device = b_.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32_t>(device)));
  auto pred = b_.AddInstruction(HloInstruction::CreateCompare(
      ShapeUtil::MakeShape(PRED, {}), MakePartitioningState().partition_id,
      on_device, ComparisonDirection::kEq));
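  // Every partition evaluates a conditional on (partition_id == device): the
  // owning partition runs the original op on the gathered operands, while
  // all other partitions produce zeros of the same shape.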

  SpmdBuilder true_b("true_computation", visiting_hlo_);
  HloComputation* true_computation;
  {
    auto param = true_b.AddInstruction(HloInstruction::CreateParameter(
        /*parameter_number=*/0, operand_shape, "true_branch_param"));
    std::vector<HloInstruction*> new_operands;
    for (int64_t i = 0; i < operands.size(); ++i) {
      new_operands.push_back(true_b.AddInstruction(
          HloInstruction::CreateGetTupleElement(*operand_shapes[i], param, i)));
    }
    auto root = true_b.AddInstruction(
        hlo->CloneWithNewOperands(hlo->shape(), new_operands));
    true_computation = module_->AddEmbeddedComputation(true_b.Build(root));
  }

  SpmdBuilder false_b("false_computation", visiting_hlo_);
  HloComputation* false_computation;
  {
    false_b.AddInstruction(HloInstruction::CreateParameter(
        /*parameter_number=*/0, operand_shape, "false_branch_param"));
    auto root = CreateZero(hlo->shape(), &false_b);
    false_computation = module_->AddEmbeddedComputation(false_b.Build(root));
  }

  SetPartitionedHlo(hlo, [&]() {
    return b_.AddInstruction(HloInstruction::CreateConditional(
        hlo->shape(), pred, operand, true_computation, operand,
        false_computation));
  });
  return OkStatus();
}

Status SpmdPartitioningVisitor::HandleAllReduce(HloInstruction* hlo) {
  if (hlo->IsCrossReplicaAllReduce() && hlo->operand_count() == 1) {
    return HandleElementwise(hlo);
  }
  if (hlo->channel_id()) {
    TF_RET_CHECK(hlo->operand_count() == 1)
        << "SPMD partitioner supports only single-operand allreduce in manual "
           "partitioning mode.";
    if (hlo->sharding().IsManual()) {
      return HandleElementwise(hlo);
    }
    TF_RET_CHECK(hlo->sharding().IsManualSubgroup())
        << "Cross-partition allreduce must be in (partial) manual partitioning "
           "mode.";
    auto* ar = Cast<HloAllReduceInstruction>(hlo);
    TF_RET_CHECK(ar->use_global_device_ids())
        << "Cross-partition allreduce in partial manual partitioning mode must "
           "use global device IDs.";
    absl::flat_hash_map<int64_t, int64_t> partition_to_group_id;
    hlo->sharding().tile_assignment().Each(
        [&](absl::Span<const int64_t> indices, int64_t partition) {
          int64_t group_id = 0;
          for (int64_t i = 0; i < indices.size(); ++i) {
            if (i == hlo->sharding().SubgroupManualDim()) {
              continue;
            }
            group_id *= hlo->sharding().tile_assignment().dim(i);
            group_id += indices[i];
          }
          partition_to_group_id[partition] = group_id;
        });
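    // group_id linearizes each device's tile-assignment indices over every
    // dimension except the manual subgroup dimension, so two partitions
    // share a group_id exactly when they belong to the same manual subgroup.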
    for (const auto& group : ar->replica_groups()) {
      int64_t first_partition = group.replica_ids(0) % num_partitions_;
      for (int64_t device : group.replica_ids()) {
        int64_t partition = device % num_partitions_;
        if (partition_to_group_id[partition] !=
            partition_to_group_id[first_partition]) {
          return InvalidArgumentStrCat(
              "Manual all-reduce across devices that belong to different "
              "manual subgroups: ",
              ar->ToString());
        }
      }
    }
    return HandleElementwise(hlo);
  }
  return DefaultAction(hlo);
}

Status SpmdPartitioningVisitor::HandleBroadcast(HloInstruction* hlo) {
  if (hlo->sharding().IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  auto& operand = GetPartitionedHlo(hlo->operand(0));

  // Tiled output.
  std::vector<int64_t> new_dims;
  for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
    if (!absl::c_linear_search(hlo->dimensions(), i)) {
      new_dims.push_back(i);
    }
  }
  auto desired_input_sharding = hlo_sharding_util::RemoveShapeDimensions(
      hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(hlo->sharding(),
                                                               new_dims),
      new_dims);
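  // Illustrative example: broadcasting [8] to [4, 8] with output sharding
  // {devices=[2,2]}: the broadcast dimension 0 is partially replicated and
  // then removed, leaving a 2-way (partially replicated) sharding on the
  // operand's only dimension, so the broadcast itself is shard-local.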
  auto input = operand.Reshard(desired_input_sharding).hlo();
  auto output_shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
  SetPartitionedHlo(hlo, [&] {
    return b_.AddInstruction(
        hlo->CloneWithNewOperands(output_shard_shape, {input}));
  });
  return OkStatus();
}

Status SpmdPartitioningVisitor::HandleConstant(HloInstruction* hlo) {
  const Literal& literal = hlo->literal();
  if (literal.shape().IsTuple() ||
      (!hlo->sharding().IsTileMaximal() &&
       (!EvenlyPartitions(hlo->shape(), hlo->sharding()) ||
        !literal.IsAllFirst()))) {
    return DefaultAction(hlo);
  }
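  // Slicing at all-zero start indices is valid here because the tiled case
  // is only reached when the literal's elements all equal its first element
  // (IsAllFirst), so the origin slice matches every shard's contents.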

  SetPartitionedHlo(hlo, [&]() {
    auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
    std::vector<int64_t> start_indices(hlo->shape().rank(), 0);
    auto constant = b_.AddInstruction(HloInstruction::CreateConstant(
        literal.Slice(start_indices, shard_shape.dimensions())));
    *constant->mutable_shape() = shard_shape;
    return constant;
  });
  return OkStatus();
}

Status SpmdPartitioningVisitor::HandleDynamicSlice(HloInstruction* hlo) {
  if (hlo->sharding().IsTileMaximal()) {
    return DefaultAction(hlo);
  }
  for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
    if (hlo->sharding().tile_assignment().dim(i) != 1 &&
        hlo->dynamic_slice_sizes()[i] !=
            hlo->operand(0)->shape().dimensions(i)) {
      // We currently do not partition the sliced dimensions.
      return DefaultAction(hlo);
    }
  }
  std::vector<HloInstruction*> new_indices(hlo->shape().rank());
  auto new_input =
      GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo();
  for (int64_t i = 0; i < new_indices.size(); ++i) {
    if (hlo->dynamic_slice_sizes()[i] ==
        hlo->operand(0)->shape().dimensions(i)) {
      // Trivial slice dim: the index must be clamped to 0.
      new_indices[i] = CreateZero(hlo->operand(i + 1)->shape(), &b_);
      continue;
    }
    // Replicate the indices.
    new_indices[i] = GetPartitionedHlo(hlo->operand(i + 1))
                         .Reshard(HloSharding::Replicate())
                         .hlo();
  }
  SetPartitionedHlo(hlo, [&]() {
    auto partitioned_shape =
        MakePartitionedShape(hlo->shape(), hlo->sharding());
    return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
        partitioned_shape, new_input, new_indices,
        partitioned_shape.dimensions()));
  });
  return OkStatus();
}

Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(HloInstruction* hlo) {
  if (hlo->sharding().IsTileMaximal()) {
    return DefaultAction(hlo);
  }

  std::vector<int64_t> partitioned_slice_dims;
  std::vector<int64_t> slice_dims;
  std::vector<int64_t> partitioned_non_slice_dims;
  std::vector<int64_t> partitioned_slice_offsets;
  bool any_non_constant_sliced_dim = false;
  for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
    if (hlo->operand(1)->shape().dimensions(i) != hlo->shape().dimensions(i)) {
      slice_dims.push_back(i);
      int64_t slice_size = hlo->operand(1)->shape().dimensions(i);
      if (hlo->sharding().tile_assignment().dim(i) != 1) {
        if (!hlo->operand(i + 2)->IsConstant() && slice_size != 1) {
          any_non_constant_sliced_dim = true;
          continue;
        }
        partitioned_slice_dims.push_back(i);
        // Set partitioned_slice_offsets to -1 when slice_size is 1.
        if (slice_size == 1) {
          partitioned_slice_offsets.push_back(-1);
        } else {
          partitioned_slice_offsets.push_back(
              hlo->operand(i + 2)->literal().Get<int>({}));
        }
      }
    } else if (hlo->sharding().tile_assignment().dim(i) != 1) {
      partitioned_non_slice_dims.push_back(i);
    }
  }
  auto handle_with_replicate_slice_dims = [&]() {
    HloSharding replicated_sharding =
        hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept(
            hlo->operand(0)->sharding(), partitioned_non_slice_dims);
    auto base = GetPartitionedHlo(hlo->operand(0)).Reshard(replicated_sharding);
    auto operand =
        GetPartitionedHlo(hlo->operand(1)).Reshard(replicated_sharding);
    std::vector<HloInstruction*> new_indices(hlo->shape().rank());
    for (int64_t i = 0; i < new_indices.size(); ++i) {
      // Replicate the indices.
      new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
                           .Reshard(HloSharding::Replicate())
                           .hlo();
    }
    auto dus = b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
        base.hlo()->shape(), base.hlo(), operand.hlo(), new_indices));
    dus->set_sharding(replicated_sharding);
    SetPartitionedHlo(hlo, PartitionedHlo(dus, base.base_shape(), base.state())
                               .Reshard(hlo->sharding()));
  };
  if (any_non_constant_sliced_dim) {
    if (partitioned_non_slice_dims.empty()) {
      return DefaultAction(hlo);
    }
    handle_with_replicate_slice_dims();
    return OkStatus();
  }

  // Handle the case where a partitioned dimension is sliced.
  if (!partitioned_slice_dims.empty()) {
    auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) {
      return b_.AddInstruction(std::move(to_add));
    };
    std::vector<HloInstruction*> new_indices(hlo->shape().rank());
    for (int64_t i = 0; i < new_indices.size(); ++i) {
      if (hlo->operand(1)->shape().dimensions(i) ==
          hlo->shape().dimensions(i)) {
        new_indices[i] = CreateZero(hlo->operand(i + 2)->shape(), &b_);
        continue;
      }
      // Replicate the indices.
      new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
                           .Reshard(HloSharding::Replicate())
                           .hlo();
    }

    // Get the partitioned input.
    const auto& dus_sharding = hlo->sharding();
    const auto& partitioned_input =
        GetPartitionedHlo(hlo->operand(0)).Reshard(dus_sharding).hlo();

    // Get the replicated update.
    auto update_sharding = HloSharding::Replicate();
    if (!partitioned_non_slice_dims.empty()) {
      // Partially replicate the update if the non-slice dims are partitioned.
      update_sharding =
          hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(dus_sharding,
                                                                   slice_dims);
    }

    // TODO(wangtao): use collective permute for sharded update.
    HloInstruction* replicate_update =
        GetPartitionedHlo(hlo->operand(1)).Reshard(update_sharding).hlo();

    const auto& update_shape = replicate_update->shape();
    const auto& partitioned_shape = partitioned_input->shape();
    auto partition_ordinals = MakeTiledPartitionOrdinals(
        hlo->sharding(), MakePartitioningState().partition_id, &b_);
    HloInstruction* all_dims_within_partition = add_hlo(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));

    for (int i = 0; i < partitioned_slice_dims.size(); ++i) {
      int dim = partitioned_slice_dims[i];
      // Calculate the per-partition size.
      const int64_t per_partition_size = partitioned_shape.dimensions(dim);

      // Only an update that falls within a single partition is supported.
      // This check is skipped when the slice size is 1, in which case
      // partitioned_slice_offsets[i] is -1.
      if ((partitioned_slice_offsets[i] != -1) &&
          (partitioned_slice_offsets[i] / per_partition_size) !=
              ((partitioned_slice_offsets[i] + update_shape.dimensions(dim) -
                1) /
               per_partition_size)) {
        handle_with_replicate_slice_dims();
        return OkStatus();
      }

      // within_partition = (offset >= partition_id * per_partition_size) &&
      //                    (offset < (partition_id + 1) * per_partition_size)
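      // Illustrative example: per_partition_size = 4 and offset = 6 means
      // only the partition with ordinal 1 (covering indices 4..7) performs
      // the update, using local start index 6 - 4 = 2; all other partitions
      // keep their data unchanged via the select below.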
      const Shape& compare_shape =
          ShapeUtil::ChangeElementType(partition_id_->shape(), PRED);
      auto per_partition_size_hlo = add_hlo(HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<int>(per_partition_size)));
      const Shape& offset_shape = per_partition_size_hlo->shape();
      auto partition_offset = add_hlo(HloInstruction::CreateBinary(
          offset_shape, HloOpcode::kMultiply, partition_ordinals[dim],
          per_partition_size_hlo));
      // offset >= partition_id * per_partition_size
      auto offset_ge = add_hlo(HloInstruction::CreateCompare(
          compare_shape, new_indices[dim], partition_offset,
          ComparisonDirection::kGe));
      // offset < (partition_id + 1) * per_partition_size
      auto offset_lt = add_hlo(HloInstruction::CreateCompare(
          compare_shape, new_indices[dim],
          add_hlo(HloInstruction::CreateBinary(
              offset_shape, HloOpcode::kMultiply,
              add_hlo(HloInstruction::CreateBinary(
                  offset_shape, HloOpcode::kAdd, partition_ordinals[dim],
                  add_hlo(HloInstruction::CreateConstant(
                      LiteralUtil::CreateR0<int>(1))))),
              per_partition_size_hlo)),
          ComparisonDirection::kLt));
      auto update_within_partition = add_hlo(HloInstruction::CreateBinary(
          compare_shape, HloOpcode::kAnd, offset_ge, offset_lt));

      all_dims_within_partition = add_hlo(HloInstruction::CreateBinary(
          compare_shape, HloOpcode::kAnd, all_dims_within_partition,
          update_within_partition));

      // Calculate the offset:
      // slice dim offset =
      //  within_partition ?
      //  offset - partition_id * per_partition_size : 0
      new_indices[dim] = add_hlo(HloInstruction::CreateTernary(
          new_indices[dim]->shape(), HloOpcode::kSelect,
          update_within_partition,
          add_hlo(HloInstruction::CreateBinary(
              new_indices[dim]->shape(), HloOpcode::kSubtract, new_indices[dim],
              partition_offset)),
          add_hlo(
              HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)))));
    }

    // Create the dynamic update slice.
    auto dus = add_hlo(HloInstruction::CreateDynamicUpdateSlice(
        partitioned_shape, partitioned_input, replicate_update, new_indices));
    SetPartitionedHlo(hlo, [&]() {
      // Select if the update is needed.
      return add_hlo(HloInstruction::CreateTernary(
          dus->shape(), HloOpcode::kSelect,
          add_hlo(HloInstruction::CreateBroadcast(
              ShapeUtil::ChangeElementType(dus->shape(), PRED),
              all_dims_within_partition, {})),
          dus, partitioned_input));
    });
    return OkStatus();
  }

  // Partition the non-slice dims only.
  std::vector<HloInstruction*> new_indices(hlo->shape().rank());
  auto new_input =
      GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo();
  auto new_update =
      GetPartitionedHlo(hlo->operand(1)).Reshard(hlo->sharding()).hlo();
  for (int64_t i = 0; i < new_indices.size(); ++i) {
    if (hlo->operand(1)->shape().dimensions(i) == hlo->shape().dimensions(i)) {
      new_indices[i] = CreateZero(hlo->operand(i + 2)->shape(), &b_);
      continue;
    }
    // Replicate the indices.
    new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
                         .Reshard(HloSharding::Replicate())
                         .hlo();
  }
  SetPartitionedHlo(hlo, [&]() {
    auto partitioned_shape =
        MakePartitionedShape(hlo->shape(), hlo->sharding());
    return b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
        partitioned_shape, new_input, new_update, new_indices));
  });
  return OkStatus();
}

Status SpmdPartitioningVisitor::HandleGetTupleElement(HloInstruction* hlo) {
  const auto& tuple = GetPartitionedHlo(hlo->operand(0));
  auto gte = b_.AddInstruction(HloInstruction::CreateGetTupleElement(
      ShapeUtil::GetTupleElementShape(tuple.hlo()->shape(), hlo->tuple_index()),
      tuple.hlo(), hlo->tuple_index()));
  const auto source_sharding =
      tuple.sharding().GetSubSharding(tuple.base_shape(), {hlo->tuple_index()});
  gte->set_sharding(source_sharding);
  PartitionedHlo source_partitioned_gte(
      gte, tuple.base_shape().tuple_shapes(hlo->tuple_index()),
      MakePartitioningState());
  source_partitioned_gte = source_partitioned_gte.Reshard(hlo->sharding());
  SetPartitionedHlo(hlo, source_partitioned_gte);
  return OkStatus();
}

Status SpmdPartitioningVisitor::HandleInfeed(HloInstruction* hlo) {
  const Shape& shape = ShapeUtil::GetTupleElementShape(hlo->shape(), 0);
  auto token = GetPartitionedHlo(hlo->operand(0)).hlo();
  if (ShapeUtil::GetLeafCount(shape) == 0) {
    // TODO(b/155819021): HloSharding has issues with tuple-shaped sharding: it
    // requires one element for an empty tuple, but leaf-count number of
    // elements for a non-empty tuple. So if it has a nested empty tuple, we
    // cannot invoke GetSubSharding() since it expects a sharding for the empty
    // tuple. This is a workaround for that case.
    SetPartitionedHlo(hlo, [&]() {
      return b_.AddInstruction(
          HloInstruction::CreateInfeed(shape, token, hlo->infeed_config()));
    });
    return OkStatus();
  }
  auto sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
  auto shard_shape = MakePartitionedShape(shape, sharding);
  if (EvenlyPartitions(shape, sharding)) {
    SetPartitionedHlo(hlo, [&]() {
      return b_.AddInstruction(HloInstruction::CreateInfeed(
          shard_shape, token, hlo->infeed_config()));
    });
    return OkStatus();
  }

  if (hlo->sharding().HasUniqueDevice()) {
    return HandleSingleDevice(hlo);
  }

  // Create a branch for each unique partitioned shape.
  std::vector<Shape> per_branch_partitioned_shapes;
  std::vector<int32_t> conditional_branch_indices(num_partitions_);
  for (int64_t i = 0; i < num_partitions_; ++i) {
    auto partitioned_shape =
        MakeNonPaddedShapeForGivenPartition(shape, sharding, i);
    int64_t matching_existing_index = 0;
    for (; matching_existing_index < per_branch_partitioned_shapes.size();
         ++matching_existing_index) {
      if (ShapeUtil::Compatible(
              partitioned_shape,
              per_branch_partitioned_shapes[matching_existing_index])) {
        break;
      }
    }
    if (matching_existing_index < per_branch_partitioned_shapes.size()) {
      conditional_branch_indices[i] = matching_existing_index;
    } else {
      conditional_branch_indices[i] = per_branch_partitioned_shapes.size();
      per_branch_partitioned_shapes.push_back(std::move(partitioned_shape));
    }
  }
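  // Illustrative example: a size-5 array split across 2 partitions has
  // non-padded shard shapes [3] and [2], producing two conditional branches;
  // partitions with compatible shard shapes share a branch.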

  HloInstruction* branch_index;
  auto state = MakePartitioningState();
  if (per_branch_partitioned_shapes.size() == num_partitions_) {
    // Use the partition ID as the branch index if each partition has its own
    // branch.
    branch_index = state.partition_id;
    // PartitionId's output is U32, but conditional requires S32.
    if (branch_index->shape().element_type() != S32) {
      branch_index = b_.AddInstruction(HloInstruction::CreateConvert(
          ShapeUtil::ChangeElementType(branch_index->shape(), S32),
          branch_index));
    }
  } else {
    // Otherwise, use a constant table to look up the branch index.
    auto branch_index_table = b_.AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR1<int32_t>(conditional_branch_indices)));
    branch_index = b_.AddInstruction(HloInstruction::CreateDynamicSlice(
        ShapeUtil::MakeShape(S32, {1}), branch_index_table,
        {state.partition_id}, {1}));
    branch_index = b_.AddInstruction(HloInstruction::CreateReshape(
        ShapeUtil::MakeShape(S32, {}), branch_index));
  }

  std::vector<HloComputation*> branches(per_branch_partitioned_shapes.size());
  for (int64_t i = 0; i < branches.size(); ++i) {
    SpmdBuilder branch_b(absl::StrCat("infeed_branch_", i), visiting_hlo_);
    auto param = branch_b.AddInstruction(HloInstruction::CreateParameter(
        /*parameter_number=*/0, token->shape(), "infeed_token_param"));
    auto infeed = branch_b.AddInstruction(HloInstruction::CreateInfeed(
        per_branch_partitioned_shapes[i], param, hlo->infeed_config()));
    if (!ShapeUtil::Compatible(per_branch_partitioned_shapes[i], shard_shape)) {
      std::function<HloInstruction*(const ShapeIndex&, HloInstruction*)>
          pad_infeed = [&](const ShapeIndex& index,
                           HloInstruction* infeed_element) -> HloInstruction* {
        if (index == ShapeIndex({1})) {
          // Token.
          return infeed_element;
        }
        const Shape& element_shape =
            ShapeUtil::GetSubshape(infeed->shape(), index);
        if (element_shape.IsTuple() && element_shape.tuple_shapes_size() > 0) {
          std::vector<HloInstruction*> padded_elements(
              element_shape.tuple_shapes_size());
          for (int64_t i = 0; i < padded_elements.size(); ++i) {
            auto sub_index = index;
            sub_index.push_back(i);
            padded_elements[i] = pad_infeed(
                sub_index,
                branch_b.AddInstruction(HloInstruction::CreateGetTupleElement(
                    ShapeUtil::GetSubshape(element_shape, {i}), infeed_element,
                    i)));
          }
          return branch_b.AddInstruction(
              HloInstruction::CreateTuple(padded_elements));
        }
        const Shape& pad_shape = ShapeUtil::GetSubshape(
            shard_shape, ShapeIndexView(index).subspan(1));
        if (ShapeUtil::Compatible(element_shape, pad_shape)) {
          return infeed_element;
        }
        if (element_shape.IsArray()) {
          CHECK(pad_shape.IsArray());
          return PadToShape(infeed_element, pad_shape, &branch_b);
        }
        CHECK(element_shape.IsTuple());
        CHECK(element_shape.tuple_shapes().empty());
        return CreateZero(pad_shape, &branch_b);
      };
      pad_infeed({}, infeed);
    }
    branches[i] = module_->AddEmbeddedComputation(branch_b.Build());
  }
  SetPartitionedHlo(hlo, [&]() {
    return b_.AddInstruction(HloInstruction::CreateConditional(
        ShapeUtil::MakeTupleShape({shard_shape, token->shape()}), branch_index,
        branches, std::vector<HloInstruction*>(branches.size(), token)));
  });
  return OkStatus();
}

Status SpmdPartitioningVisitor::HandlePad(HloInstruction* hlo) {
  if (hlo->sharding().IsTileMaximal()) {
    return DefaultAction(hlo);
  }
  auto lhs = GetPartitionedHlo(hlo->operand(0));
  // Create a window config to represent the pad.
  Window window;
  bool needs_masking = false;
  const bool pad_value_is_zero =
      hlo->operand(1)->IsConstant() && hlo->operand(1)->literal().IsZero({});
  for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
    const auto& pd = hlo->padding_config().dimensions(i);
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(1);
    dim->set_stride(1);
    dim->set_window_dilation(1);
    dim->set_window_reversal(false);
    dim->set_padding_low(pd.edge_padding_low());
    dim->set_padding_high(pd.edge_padding_high());
    dim->set_base_dilation(pd.interior_padding() + 1);
    const int64_t shard_count = hlo->sharding().tile_assignment().dim(i);
    // Masking is needed only if there is a non-zero padding value or the
    // operand is unevenly partitioned. Halo exchange fills 0 in the collective
    // permute result for non-destination cores.
    needs_masking |=
        shard_count > 1 &&
        (pd.edge_padding_low() > 0 || pd.edge_padding_high() > 0 ||
         pd.interior_padding() > 0) &&
        (!pad_value_is_zero ||
         hlo->operand(0)->shape().dimensions(i) % shard_count != 0);
  }
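  // Note on the encoding (illustrative): edge padding maps directly to window
  // padding, and interior padding of n maps to base_dilation n + 1; e.g., a
  // pad with interior_padding = 1 becomes a window with base_dilation = 2.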

  auto replicated_rhs = GetPartitionedHlo(hlo->operand(1))
                            .Reshard(HloSharding::Replicate())
                            .hlo();
  auto reshard_operand =
      lhs.ReshardAsWindowedInput(window, hlo->sharding(), replicated_rhs,
                                 /*mask_invalid_region=*/needs_masking);
  if (!reshard_operand.has_value()) {
    return DefaultAction(hlo);
  }
  PaddingConfig sharded_padding_config;
  bool need_pad = false;
  for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
    auto dim = sharded_padding_config.add_dimensions();
    const auto& wd = reshard_operand->shard_window.dimensions(i);
    dim->set_edge_padding_low(wd.padding_low());
    dim->set_edge_padding_high(wd.padding_high());
    dim->set_interior_padding(wd.base_dilation() - 1);
    if (wd.padding_low() != 0 || wd.padding_high() != 0 ||
        wd.base_dilation() != 1) {
      need_pad = true;
    }
  }
  auto sharded_pad = reshard_operand->sharded_input;
  if (need_pad) {
    TF_ASSIGN_OR_RETURN(auto sharded_pad_shape,
                        ShapeInference::InferPadShape(sharded_pad->shape(),
                                                      replicated_rhs->shape(),
                                                      sharded_padding_config));
    sharded_pad = b_.AddInstruction(hlo->CreatePad(sharded_pad_shape,
                                                   sharded_pad, replicated_rhs,
                                                   sharded_padding_config));
  }

  SetPartitionedHlo(hlo, [&]() {
    if (!reshard_operand->dynamic_slice_index_on_output) {
      return sharded_pad;
    }
    auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
    return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
        shard_shape, sharded_pad,
        *reshard_operand->dynamic_slice_index_on_output,
        shard_shape.dimensions()));
  });
  return OkStatus();
}

Status SpmdPartitioningVisitor::HandleParameter(HloInstruction* hlo) {
  SetPartitionedHlo(hlo, [&]() {
    auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
    auto new_param = b_.AddInstruction(HloInstruction::CreateParameter(
        hlo->parameter_number(), shard_shape, "param"));
    if (hlo->parameter_replicated_at_leaf_buffers()) {
      new_param->set_parameter_replicated_at_leaf_buffers(
          *hlo->parameter_replicated_at_leaf_buffers());
    }
    return new_param;
  });
  return OkStatus();
}

Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) {
  if (hlo->sharding().HasUniqueDevice()) {
    return DefaultAction(hlo);
  }
  int64_t input_count = 1;
  auto per_input_sharding = hlo->sharding();
  if (hlo->shape().IsTuple()) {
    input_count = hlo->shape().tuple_shapes_size();
    CHECK_GT(input_count, 0);
    per_input_sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
  }

  std::vector<PartitionedHlo> inputs;
  std::vector<HloInstruction*> inits;
  std::vector<int64_t> preserved_dims;
  for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
    if (!absl::c_linear_search(hlo->dimensions(), i)) {
      preserved_dims.push_back(i);
    }
  }

  for (int64_t operand_id = 0; operand_id < input_count; ++operand_id) {
    inits.push_back(GetPartitionedHlo(hlo->operand(operand_id + input_count))
                        .Reshard(HloSharding::Replicate())
                        .hlo());
    inputs.push_back(GetPartitionedHlo(hlo->operand(operand_id)));
    if (operand_id > 0) {
      // Make sure all operands are sharded in the same way.
      inputs.back() = inputs.back().Reshard(inputs[0].sharding());
    }
    if (!inputs[0].sharding().IsTileMaximal()) {
      inputs.back() =
          inputs.back().PadWithValue(inits[operand_id], /*left_padded_dims=*/{},
                                     /*skipped_dims=*/preserved_dims);
    }
  }

  std::vector<const Shape*> new_operand_shapes(input_count * 2);
  for (int64_t i = 0; i < input_count; ++i) {
    new_operand_shapes[i] = &inputs[i].hlo()->shape();
    new_operand_shapes[i + input_count] = &inits[i]->shape();
  }
  // Create the shard shape of the reduce result.
  TF_ASSIGN_OR_RETURN(
      auto reduce_shape,
      ShapeInference::InferReduceShape(new_operand_shapes, hlo->dimensions(),
                                       hlo->to_apply()->ComputeProgramShape()));

  std::vector<HloInstruction*> input_hlos(input_count);
  for (int64_t i = 0; i < input_count; ++i) {
    input_hlos[i] = inputs[i].hlo();
  }
  auto local_reduce = b_.AddInstruction(HloInstruction::CreateReduce(
      reduce_shape, input_hlos, inits, hlo->dimensions(), hlo->to_apply()));

  SetPartitionedHlo(hlo, [&]() {
    HloInstruction* reduce = local_reduce;
    const bool reduce_sharded_dimension =
        !inputs[0].sharding().IsTileMaximal() &&
        absl::c_any_of(hlo->dimensions(), [&](int64_t i) {
          return inputs[0].sharding().tile_assignment().dim(i) > 1;
        });
3532     if (reduce_sharded_dimension) {
3533       if (inputs[0].sharding().ReplicateOnLastTileDim()) {
3534         preserved_dims.push_back(inputs[0].base_shape().rank());
3535       }
3536       if (local_reduce->shape().IsArray()) {
3537         reduce = partitioner_->AllReduceAlongShardingDims(
3538             &b_, local_reduce, inputs[0].sharding(), next_channel_id_,
3539             hlo->dimensions(), collective_ops_creator_, hlo->to_apply());
3540       } else {
3541         auto grouped = hlo_sharding_util::GroupShardingOnDims(
3542             inputs[0].sharding(), preserved_dims);
3543         auto grouped_state = CreatePerGroupPartitioningState(
3544             inputs[0].state(), grouped.device_groups, &b_);
3545         std::vector<HloInstruction*> all_gathered_partial_results(input_count);
3546         for (int64_t i = 0; i < input_count; ++i) {
3547           auto gte = b_.AddInstruction(HloInstruction::CreateGetTupleElement(
3548               ShapeUtil::GetTupleElementShape(reduce_shape, i), local_reduce,
3549               i));
3550           auto expanded_shape = input_hlos[i]->shape();
3551           auto all_gather_shape = input_hlos[i]->shape();
3552           for (int64_t dim : hlo->dimensions()) {
3553             expanded_shape.set_dimensions(dim, 1);
3554             all_gather_shape.set_dimensions(
3555                 dim, inputs[0].sharding().tile_assignment().dim(dim));
3556           }
3557           auto reshape = b_.AddInstruction(
3558               HloInstruction::CreateReshape(expanded_shape, gte));
3559           // Replicate per group.
3560           reshape->set_sharding(grouped.sharding);
3561           all_gathered_partial_results[i] =
3562               PartitionedHlo(reshape, all_gather_shape, grouped_state)
3563                   .Replicate()
3564                   .hlo();
3565         }
3566         reduce = b_.AddInstruction(HloInstruction::CreateReduce(
3567             reduce_shape, all_gathered_partial_results, inits,
3568             hlo->dimensions(), hlo->to_apply()));
3569       }
3570     }
3571     auto sharding = hlo_sharding_util::RemoveShapeDimensions(
3572         hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
3573             inputs[0].sharding(), hlo->dimensions()),
3574         hlo->dimensions());
3575     if (local_reduce->shape().IsArray()) {
3576       reduce->set_sharding(sharding);
3577     } else {
3578       reduce->set_sharding(HloSharding::Tuple(
3579           reduce->shape(), std::vector<HloSharding>(input_count, sharding)));
3580     }
3581     return PartitionedHlo(reduce, hlo->shape(), MakePartitioningState())
3582         .Reshard(hlo->sharding())
3583         .hlo();
3584   });
3585   return OkStatus();
3586 }
3587 
3588 Status SpmdPartitioningVisitor::HandleReverse(HloInstruction* hlo) {
3589   auto reverse = Cast<HloReverseInstruction>(hlo);
3590   if (reverse->sharding().IsTileMaximal()) {
3591     return DefaultAction(hlo);
3592   }
3593   auto operand = GetPartitionedHlo(reverse->operand(0))
3594                      .Reshard(hlo_sharding_util::ReverseSharding(
3595                          reverse->sharding(), reverse->dimensions()));
3596   auto left_padded_operand =
3597       HaloExchangeToPadOnLeft(operand, reverse->dimensions());
3598   if (!left_padded_operand) {
3599     return DefaultAction(hlo);
3600   }
3601   SetPartitionedHlo(hlo, [&] {
3602     return b_.AddInstruction(hlo->CloneWithNewOperands(
3603         left_padded_operand->shape(), {left_padded_operand}));
3604   });
3605   return OkStatus();
3606 }
3607 
3608 Status SpmdPartitioningVisitor::HandleWhile(HloInstruction* hlo) {
3609   const HloSharding& sharding = hlo->sharding();
3610 
3611   // Shardings for the body parameter, body root, and cond parameter must be
3612   // the same, and the condition root must be replicated so that all partitions
3613   // follow the same control flow.
3614   hlo->while_condition()->parameter_instruction(0)->set_sharding(sharding);
3615   hlo->while_body()->parameter_instruction(0)->set_sharding(sharding);
3616   const HloSharding& cond_root_sharding =
3617       hlo->while_condition()->root_instruction()->sharding();
3618   TF_RETURN_IF_ERROR(partitioner_
3619                          ->PartitionComputation(hlo->while_condition(),
3620                                                 cond_root_sharding.IsManual()
3621                                                     ? cond_root_sharding
3622                                                     : HloSharding::Replicate(),
3623                                                 next_channel_id_, logger_)
3624                          .status());
3625   TF_RETURN_IF_ERROR(partitioner_
3626                          ->PartitionComputation(hlo->while_body(), sharding,
3627                                                 next_channel_id_, logger_)
3628                          .status());
3629   SetPartitionedHlo(hlo, [&] {
3630     return b_.AddInstruction(HloInstruction::CreateWhile(
3631         MakePartitionedShape(hlo->shape(), sharding), hlo->while_condition(),
3632         hlo->while_body(),
3633         GetPartitionedHlo(hlo->operand(0)).Reshard(sharding).hlo()));
3634   });
3635   return OkStatus();
3636 }
3637 
3638 Status SpmdPartitioningVisitor::HandleConditional(HloInstruction* hlo) {
3639   std::vector<HloInstruction*> branch_args;
3640   for (int64_t i = 0; i < hlo->branch_count(); ++i) {
3641     HloComputation* computation = hlo->branch_computation(i);
3642 
3643     // Shardings of the branch computation parameter and its argument must be
3644     // the same.
3645     computation->parameter_instruction(0)->set_sharding(
3646         hlo->operand(i + 1)->sharding());
3647     branch_args.push_back(GetPartitionedHlo(hlo->operand(i + 1)).hlo());
3648   }
3649 
3650   // The root of the branch computations must follow the sharding of the
3651   // conditional instruction.
3652   for (int64_t i = 0; i < hlo->branch_count(); ++i) {
3653     HloComputation* computation = hlo->branch_computation(i);
3654     TF_RETURN_IF_ERROR(partitioner_
3655                            ->PartitionComputation(computation, hlo->sharding(),
3656                                                   next_channel_id_, logger_)
3657                            .status());
3658   }
3659   SetPartitionedHlo(hlo, [&] {
3660     HloInstruction* cond = GetPartitionedHlo(hlo->operand(0)).hlo();
3661     if (!hlo->operand(0)->sharding().IsManual()) {
3662       // We replicate the predicate of the conditional (the first operand) so
3663       // that all partitions follow the same control flow.
3664       cond = GetPartitionedHlo(hlo->operand(0))
3665                  .Reshard(HloSharding::Replicate())
3666                  .hlo();
3667     }
3668     return b_.AddInstruction(HloInstruction::CreateConditional(
3669         MakePartitionedShape(hlo->shape(), hlo->sharding()), cond,
3670         hlo->called_computations(), branch_args));
3671   });
3672   return OkStatus();
3673 }
3674 
3675 Status SpmdPartitioningVisitor::HandleOptimizationBarrier(HloInstruction* hlo) {
3676   return HandleElementwise(hlo);
3677 }
3678 
3679 Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) {
3680   if (hlo->sharding().HasUniqueDevice()) {
3681     return HandleSingleDevice(hlo);
3682   }
3683 
3684   const auto& sharding = hlo->sharding();
3685   const Shape& shape = hlo->operand(0)->shape();
3686   auto partitioned_operand =
3687       GetPartitionedHlo(hlo->operand(0)).Reshard(sharding);
3688   const auto& shard_shape = partitioned_operand.hlo()->shape();
3689   const auto& operand = partitioned_operand.hlo();
3690   auto token = GetPartitionedHlo(hlo->operand(1)).hlo();
3691 
3692   if (EvenlyPartitions(shape, sharding)) {
3693     Shape outfeed_shape = operand->shape();
3694     TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(hlo->outfeed_shape(),
3695                                                            &outfeed_shape));
3696     SetPartitionedHlo(hlo, [&]() {
3697       return b_.AddInstruction(HloInstruction::CreateOutfeed(
3698           outfeed_shape, operand, token, hlo->outfeed_config()));
3699     });
3700     return OkStatus();
3701   }
3702 
3703   // Create a branch for each unique partitioned shape.
3704   std::vector<Shape> per_branch_partitioned_shapes;
3705   std::vector<int32_t> conditional_branch_indices(num_partitions_);
3706   for (int64_t i = 0; i < num_partitions_; ++i) {
3707     auto partitioned_shape =
3708         MakeNonPaddedShapeForGivenPartition(shape, sharding, i);
3709     int64_t matching_existing_index = 0;
3710     for (; matching_existing_index < per_branch_partitioned_shapes.size();
3711          ++matching_existing_index) {
3712       if (ShapeUtil::Compatible(
3713               partitioned_shape,
3714               per_branch_partitioned_shapes[matching_existing_index])) {
3715         break;
3716       }
3717     }
3718     if (matching_existing_index < per_branch_partitioned_shapes.size()) {
3719       conditional_branch_indices[i] = matching_existing_index;
3720     } else {
3721       conditional_branch_indices[i] = per_branch_partitioned_shapes.size();
3722       per_branch_partitioned_shapes.push_back(std::move(partitioned_shape));
3723     }
3724   }
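  // Worked example (editorial, hypothetical values): outfeeding f32[5] tiled
  // across 4 partitions yields non-padded shard shapes [2], [2], [1], [0], so
  // per_branch_partitioned_shapes = {[2], [1], [0]} and
  // conditional_branch_indices = {0, 0, 1, 2}; only 3 branches are built
  // instead of 4.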
3725 
3726   // Get branch index for this partition.
3727   HloInstruction* branch_index;
3728   auto state = MakePartitioningState();
3729   if (per_branch_partitioned_shapes.size() == num_partitions_) {
3730     // Use partition ID as the branch index if each partition has its own
3731     // branch.
3732     branch_index = state.partition_id;
3733     // PartitionId's output is U32 but conditional requires S32.
3734     if (branch_index->shape().element_type() != S32) {
3735       branch_index = b_.AddInstruction(HloInstruction::CreateConvert(
3736           ShapeUtil::ChangeElementType(branch_index->shape(), S32),
3737           branch_index));
3738     }
3739   } else {
3740     // Otherwise, use a constant table to look up the branch index.
3741     auto branch_index_table = b_.AddInstruction(HloInstruction::CreateConstant(
3742         LiteralUtil::CreateR1<int32_t>(conditional_branch_indices)));
3743     branch_index = b_.AddInstruction(HloInstruction::CreateDynamicSlice(
3744         ShapeUtil::MakeShape(S32, {1}), branch_index_table, {partition_id_},
3745         {1}));
3746     branch_index = b_.AddInstruction(HloInstruction::CreateReshape(
3747         ShapeUtil::MakeShape(S32, {}), branch_index));
3748   }
3749 
3750   // Create conditional for the outfeed.
3751   std::vector<HloComputation*> branches(per_branch_partitioned_shapes.size());
3752   for (int64_t i = 0; i < branches.size(); ++i) {
3753     SpmdBuilder branch_b(absl::StrCat("outfeed_branch_", i), visiting_hlo_);
3754     // Create tuple param within the branch.
3755     auto param = branch_b.AddInstruction(HloInstruction::CreateParameter(
3756         /*parameter_number=*/0,
3757         ShapeUtil::MakeTupleShape({operand->shape(), token->shape()}),
3758         "outfeed_token_param"));
3759     auto outfeed_data = branch_b.AddInstruction(
3760         HloInstruction::CreateGetTupleElement(operand->shape(), param, 0));
3761     auto outfeed_token = branch_b.AddInstruction(
3762         HloInstruction::CreateGetTupleElement(token->shape(), param, 1));
3763     if (!ShapeUtil::Compatible(per_branch_partitioned_shapes[i], shard_shape)) {
3764       std::function<HloInstruction*(const ShapeIndex&, HloInstruction*)>
3765           slice_outfeed =
3766               [&](const ShapeIndex& index,
3767                   HloInstruction* outfeed_operand) -> HloInstruction* {
3768         // Get outfeed element shape.
3769         const Shape& element_shape =
3770             ShapeUtil::GetSubshape(outfeed_data->shape(), index);
3771         // Recursively call slice_outfeed for tuple shapes.
3772         if (element_shape.IsTuple() && element_shape.tuple_shapes_size() > 0) {
3773           std::vector<HloInstruction*> slice_elements(
3774               element_shape.tuple_shapes_size());
3775           for (int64_t i = 0; i < slice_elements.size(); ++i) {
3776             auto sub_index = index;
3777             sub_index.push_back(i);
3778             slice_elements[i] = slice_outfeed(
3779                 sub_index,
3780                 branch_b.AddInstruction(HloInstruction::CreateGetTupleElement(
3781                     ShapeUtil::GetSubshape(element_shape, {i}), outfeed_operand,
3782                     i)));
3783           }
3784           return branch_b.AddInstruction(
3785               HloInstruction::CreateTuple(slice_elements));
3786         }
3787         // Get the slice shape.
3788         const Shape& slice_shape = ShapeUtil::GetSubshape(
3789             per_branch_partitioned_shapes[i], ShapeIndexView(index));
3790         if (ShapeUtil::Compatible(element_shape, slice_shape)) {
3791           return outfeed_operand;
3792         }
3793         // Slice out useful data.
3794         if (element_shape.IsArray()) {
3795           CHECK(slice_shape.IsArray());
3796           std::vector<int64_t> start_indices(slice_shape.rank(), 0);
3797           std::vector<int64_t> slice_strides(slice_shape.rank(), 1);
3798           return branch_b.AddInstruction(HloInstruction::CreateSlice(
3799               slice_shape, outfeed_operand, start_indices,
3800               slice_shape.dimensions(), slice_strides));
3801         }
3802         CHECK(element_shape.IsTuple());
3803         CHECK(element_shape.tuple_shapes().empty());
3804         return outfeed_operand;
3805       };
3806       outfeed_data = slice_outfeed({}, outfeed_data);
3807     }
3808     TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
3809         hlo->outfeed_shape(), &per_branch_partitioned_shapes[i]));
3810     branch_b.AddInstruction(HloInstruction::CreateOutfeed(
3811         per_branch_partitioned_shapes[i], outfeed_data, outfeed_token,
3812         hlo->outfeed_config()));
3813     branches[i] = module_->AddEmbeddedComputation(branch_b.Build());
3814   }
3815   SetPartitionedHlo(hlo, [&]() {
3816     return b_.AddInstruction(HloInstruction::CreateConditional(
3817         token->shape(), branch_index, branches,
3818         std::vector<HloInstruction*>(
3819             branches.size(),
3820             b_.AddInstruction(HloInstruction::CreateTuple({operand, token})))));
3821   });
3822   return OkStatus();
3823 }
3824 
3825 Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) {
3826   if (hlo->sharding().HasUniqueDevice()) {
3827     return HandleSingleDevice(hlo);
3828   }
3829   auto clone_from_original = [&](const HloSharding& shared_sharding) {
3830     std::vector<HloInstruction*> new_operands;
3831     for (int64_t i = 0; i < hlo->operand_count(); ++i) {
3832       new_operands.push_back(
3833           GetPartitionedHlo(hlo->operand(i)).Reshard(shared_sharding).hlo());
3834     }
3835     auto clone = b_.AddInstruction(
3836         hlo->CloneWithNewOperands(hlo->shape(), new_operands));
3837     clone->set_sharding(shared_sharding);
3838     return clone;
3839   };
3840 
3841   if (hlo->sharding().IsManual()) {
3842     SetPartitionedHlo(hlo,
3843                       [&] { return clone_from_original(hlo->sharding()); });
3844     return OkStatus();
3845   }
3846 
3847   if (hlo->sharding().IsReplicated()) {
3848     SetPartitionedHlo(hlo, [&] {
3849       // Run on a single device (0) and distribute the data to all other cores.
3850       auto clone = clone_from_original(HloSharding::AssignDevice(0));
3851       return PartitionedHlo(clone, hlo->shape(), MakePartitioningState())
3852           .Reshard(HloSharding::Replicate())
3853           .hlo();
3854     });
3855     return OkStatus();
3856   }
3857 
3858   TF_RET_CHECK(!hlo->sharding().IsTileMaximal());
3859   // Replicate the operands and run partitioned Rng on all devices.
3860   std::vector<HloInstruction*> new_operands;
3861   for (int64_t i = 0; i < hlo->operand_count(); ++i) {
3862     new_operands.push_back(GetPartitionedHlo(hlo->operand(i))
3863                                .Reshard(HloSharding::Replicate())
3864                                .hlo());
3865   }
3866 
3867   if (!hlo->sharding().ReplicateOnLastTileDim()) {
3868     SetPartitionedHlo(hlo, [&] {
3869       return b_.AddInstruction(HloInstruction::CreateRng(
3870           MakePartitionedShape(hlo->shape(), hlo->sharding()),
3871           hlo->random_distribution(), new_operands));
3872     });
3873   } else {
3874     std::vector<int64_t> group_dims(
3875         hlo->sharding().tile_assignment().num_dimensions() - 1);
3876     std::iota(group_dims.begin(), group_dims.end(), 0);
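    // Editorial example (assumed tiling): with tile_assignment dims [2, 2]
    // where the last dim replicates, group_dims = {0}, so the two devices
    // that differ only in the last tile dim form one group. The Rng below
    // then runs once per group and Replicate() broadcasts within the group,
    // keeping the replicas of each shard bitwise-identical.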
3877     auto sharding_grouped =
3878         hlo_sharding_util::GroupShardingOnDims(hlo->sharding(), group_dims);
3879     auto per_group_state = CreatePerGroupPartitioningState(
3880         MakePartitioningState(), sharding_grouped.device_groups, &b_);
3881     auto rng = b_.AddInstruction(HloInstruction::CreateRng(
3882         MakePartitionedShape(hlo->shape(), hlo->sharding()),
3883         hlo->random_distribution(), new_operands));
3884     rng->set_sharding(HloSharding::AssignDevice(0));
3885     SetPartitionedHlo(hlo, [&]() {
3886       return PartitionedHlo(rng, rng->shape(), per_group_state)
3887           .Replicate()
3888           .hlo();
3889     });
3890   }
3891   return OkStatus();
3892 }
3893 
3894 Status SpmdPartitioningVisitor::HandleReduceWindow(HloInstruction* hlo) {
3895   if (hlo->sharding().IsTileMaximal()) {
3896     return DefaultAction(hlo);
3897   }
3898   HloReduceWindowInstruction* reduce_window =
3899       Cast<HloReduceWindowInstruction>(hlo);
3900   absl::Span<HloInstruction* const> input_arrays = reduce_window->inputs();
3901   absl::Span<HloInstruction* const> init_values = reduce_window->init_values();
3902   int64_t input_idx = 0;
3903   absl::InlinedVector<PartitionedHlo::WindowedInputShardReturnValue, 2>
3904       sharded_results;
3905   absl::InlinedVector<const Shape*, 2> sharded_input_shapes,
3906       replicated_init_shapes;
3907   absl::InlinedVector<HloInstruction*, 2> sharded_inputs, replicated_inits;
3908   for (const HloInstruction* input_array : input_arrays) {
3909     PartitionedHlo& operand = GetPartitionedHlo(input_array);
3910     // Replicate init
3911     PartitionedHlo replicated_init = GetPartitionedHlo(init_values[input_idx++])
3912                                          .Reshard(HloSharding::Replicate());
3913     auto resharded_operand_and_window = operand.ReshardAsWindowedInput(
3914         hlo->window(), hlo->sharding(), replicated_init.hlo());
3915     if (!resharded_operand_and_window.has_value()) {
3916       return DefaultAction(hlo);
3917     }
3918     sharded_results.push_back(resharded_operand_and_window.value());
3919     sharded_inputs.push_back(resharded_operand_and_window->sharded_input);
3920     sharded_input_shapes.push_back(&sharded_inputs.back()->shape());
3921     replicated_inits.push_back(replicated_init.hlo());
3922     replicated_init_shapes.push_back(&replicated_inits.back()->shape());
3923   }
3924   TF_ASSIGN_OR_RETURN(Shape sharded_rw_shape,
3925                       ShapeInference::InferReduceWindowShape(
3926                           sharded_input_shapes, replicated_init_shapes,
3927                           sharded_results[0].shard_window,
3928                           hlo->to_apply()->ComputeProgramShape()));
3929   HloSharding result_sharding =
3930       (hlo->shape().IsTuple())
3931           ? hlo->sharding().GetTupleSharding(hlo->shape()).ValueOrDie()
3932           : hlo->sharding();
3933   Shape shard_shape = MakePartitionedShape(hlo->shape(), result_sharding);
3934   if (shard_shape.has_layout()) {
3935     *sharded_rw_shape.mutable_layout() = shard_shape.layout();
3936   }
3937   SetPartitionedHlo(hlo, [&]() {
3938     HloInstruction* sharded_rw =
3939         b_.AddInstruction(HloInstruction::CreateReduceWindow(
3940             sharded_rw_shape, sharded_inputs, replicated_inits,
3941             sharded_results[0].shard_window, hlo->to_apply()));
3942     if (!sharded_results[0].dynamic_slice_index_on_output.has_value()) {
3943       CHECK(ShapeUtil::Compatible(shard_shape, sharded_rw->shape()))
3944           << shard_shape << " vs " << sharded_rw->shape() << "\n";
3945       return sharded_rw;
3946     }
3947     return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
3948         shard_shape, sharded_rw,
3949         *sharded_results[0].dynamic_slice_index_on_output,
3950         shard_shape.dimensions()));
3951   });
3952   return OkStatus();
3953 }
3954 
3955 Status SpmdPartitioningVisitor::HandleSelectAndScatter(HloInstruction* hlo) {
3956   if (hlo->sharding().IsTileMaximal()) {
3957     return DefaultAction(hlo);
3958   }
3959   auto operand = GetPartitionedHlo(hlo->operand(0));
3960   auto source = GetPartitionedHlo(hlo->mutable_operand(1));
3961   if (hlo->sharding() != operand.sharding()) {
3962     operand = operand.Reshard(hlo->sharding());
3963   }
3964   if (hlo->sharding() != source.sharding()) {
3965     source = source.Reshard(hlo->sharding());
3966   }
3967 
3968   // For F32 and BF16 types, we can use NaN padding to work around the issue
3969   // with low/high padding, since comparison returns false for NaN inputs.
3970   if (hlo->shape().element_type() != F32 &&
3971       hlo->shape().element_type() != BF16) {
3972     return DefaultAction(hlo);
3973   }
3974 
3975   auto select = hlo->called_computations()[0];
3976   auto select_root = select->root_instruction();
3977   if (select_root->opcode() != HloOpcode::kCompare ||
3978       select_root->operand(0)->opcode() != HloOpcode::kParameter ||
3979       select_root->operand(1)->opcode() != HloOpcode::kParameter ||
3980       select_root->operand(0)->parameter_number() +
3981               select_root->operand(1)->parameter_number() !=
3982           1) {
3983     return DefaultAction(hlo);
3984   }
3985 
3986   float float_pad_value;
3987   if (select_root->comparison_direction() == ComparisonDirection::kGe ||
3988       select_root->comparison_direction() == ComparisonDirection::kGt) {
3989     if (select_root->operand(0)->parameter_number() == 0) {
3990       float_pad_value = -std::numeric_limits<float>::infinity();
3991     } else {
3992       float_pad_value = std::numeric_limits<float>::infinity();
3993     }
3994   } else if (select_root->comparison_direction() == ComparisonDirection::kLe ||
3995              select_root->comparison_direction() == ComparisonDirection::kLt) {
3996     if (select_root->operand(0)->parameter_number() == 0) {
3997       float_pad_value = std::numeric_limits<float>::infinity();
3998     } else {
3999       float_pad_value = -std::numeric_limits<float>::infinity();
4000     }
4001   } else {
4002     return DefaultAction(hlo);
4003   }
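  // Editorial example: for a max-pool style select computation
  // ge(param0, param1), float_pad_value is -inf, so padded or halo elements
  // can never win the selection and scattered source values land only on real
  // data; a min-pool style le/lt select uses +inf for the same reason.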
4004 
4005   auto pad_value = b_.AddInstruction(HloInstruction::CreateConstant(
4006       hlo->shape().element_type() == BF16
4007           ? LiteralUtil::CreateR0<bfloat16>(
4008                 static_cast<bfloat16>(float_pad_value))
4009           : LiteralUtil::CreateR0<float>(float_pad_value)));
4010 
4011   // Replicate init
4012   auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(2))
4013                              .Reshard(HloSharding::Replicate());
4014 
4015   auto state = MakePartitioningState();
4016   auto partition_ordinals =
4017       MakeTiledPartitionOrdinals(hlo->sharding(), state.partition_id, &b_);
4018 
4019   // The first window for each dimension that overlaps with the shard area.
4020   std::vector<MultiplyAddDivideOffsetCalculation> first_window(
4021       hlo->shape().rank());
4022   // The first window for each dimension that goes beyond the shard area.
4023   std::vector<MultiplyAddDivideOffsetCalculation> limit_window(
4024       hlo->shape().rank());
4025   std::vector<OffsetCalculation> data_left_halo_sizes(hlo->shape().rank());
4026   std::vector<OffsetCalculation> data_right_halo_sizes(hlo->shape().rank());
4027   std::vector<OffsetCalculation> source_left_halo_sizes(hlo->shape().rank());
4028   std::vector<OffsetCalculation> source_right_halo_sizes(hlo->shape().rank());
4029   auto unpadded_data_shard_shape =
4030       MakePartitionedShape(hlo->shape(), hlo->sharding());
4031   auto unpadded_source_shard_shape =
4032       MakePartitionedShape(hlo->operand(1)->shape(), hlo->sharding());
4033   auto source_shard_hlo = source.hlo();
4034   auto data_shard_hlo = operand.hlo();
4035   for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
4036     int64_t shard_count = hlo->sharding().tile_assignment().dim(i);
4037     if (shard_count == 1) {
4038       continue;
4039     }
4040     // If stride > window_size, there will be gaps between windows. These gaps
4041     // will also exist in the output, so we keep them during halo exchange.
4042     //
4043     // TODO(yuanzx): This could introduce overhead if partitions start at
4044     // different offsets in a gap.
4045     auto wd = hlo->window().dimensions(i);
4046     if (wd.stride() > wd.size()) {
4047       wd.set_size(wd.stride());
4048     }
4049     // shard_size * i < stride * k - pad_low + window_size  =>
4050     //   k > (shard_size * i + pad_low - window_size) / stride  =>
4051     //   first_k == (shard_size * i + pad_low - window_size + stride) / stride
4052     first_window[i] = MultiplyAddDivideOffsetCalculation(
4053         unpadded_data_shard_shape.dimensions(i),
4054         wd.padding_low() - wd.size() + wd.stride(), wd.stride());
4055     // shard_size * (i + 1) <= stride * k - pad_low  =>
4056     //   k >= (shard_size * i + shard_size + pad_low) / stride  =>
4057     //   limit_k == (shard_size * i + shard_size + pad_low + stride - 1) /
4058     //     stride
4059     limit_window[i] = MultiplyAddDivideOffsetCalculation(
4060         unpadded_data_shard_shape.dimensions(i),
4061         unpadded_data_shard_shape.dimensions(i) + wd.padding_low() +
4062             wd.stride() - 1,
4063         wd.stride());
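    // Worked example (editorial, assumed values): with per-shard data size 4,
    // window size 2, stride 2, and zero low padding, first_window(i) =
    // (4*i - 2 + 2) / 2 = 2*i and limit_window(i) = (4*i + 4 + 2 - 1) / 2 =
    // 2*i + 2, so shard i owns exactly windows [2*i, 2*i + 2) and the halo
    // sizes computed below all evaluate to zero.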
4064     source_left_halo_sizes[i] =
4065         MultiplyAddDivideOffsetCalculation(
4066             unpadded_source_shard_shape.dimensions(i), 0, 1) -
4067         first_window[i];
4068     source_right_halo_sizes[i] =
4069         limit_window[i] - MultiplyAddDivideOffsetCalculation(
4070                               unpadded_source_shard_shape.dimensions(i),
4071                               unpadded_source_shard_shape.dimensions(i), 1);
4072     data_left_halo_sizes[i] =
4073         OffsetCalculation(MultiplyAddDivideOffsetCalculation(
4074             unpadded_data_shard_shape.dimensions(i), wd.padding_low(), 1)) -
4075         OffsetCalculation(
4076             HloOpcode::kMultiply, first_window[i],
4077             MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1));
4078     data_right_halo_sizes[i] =
4079         OffsetCalculation(
4080             HloOpcode::kMultiply, limit_window[i],
4081             MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1)) -
4082         OffsetCalculation(MultiplyAddDivideOffsetCalculation(
4083             unpadded_data_shard_shape.dimensions(i),
4084             unpadded_data_shard_shape.dimensions(i) + wd.stride() +
4085                 wd.padding_low() - wd.size(),
4086             1));
4087 
4088     int64_t max_windows =
4089         (limit_window[i] - first_window[i]).MaxInRange(0, shard_count);
4090     auto first_window_hlo =
4091         first_window[i].Calculate(partition_ordinals[i], &b_);
4092     // Padding on the source is filled with the init value so it does not
4093     // change the data on overlapping windows.
4094     auto resharded_source = ExchangeHaloAndGetValidData(
4095         source_shard_hlo, source.base_shape(), source_left_halo_sizes[i],
4096         source_right_halo_sizes[i], 0,
4097         limit_window[i].Calculate(shard_count - 1), max_windows, i,
4098         hlo->sharding(), first_window_hlo, replicated_init.hlo(),
4099         partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_);
4100     if (!resharded_source) {
4101       return DefaultAction(hlo);
4102     }
4103     source_shard_hlo = *resharded_source;
4104 
4105     auto offset_start_in_data =
4106         MultiplyAddDivideOffsetCalculation(wd.stride(), 0, 1)
4107             .Calculate(first_window_hlo, &b_);
4108     int64_t padded_data_size =
4109         (limit_window[i].Calculate(shard_count - 1) - 1) * wd.stride() +
4110         wd.size();
4111     int64_t data_shard_size = (max_windows - 1) * wd.stride() + wd.size();
4112     auto resharded_data = ExchangeHaloAndGetValidData(
4113         data_shard_hlo, operand.base_shape(), data_left_halo_sizes[i],
4114         data_right_halo_sizes[i], wd.padding_low(), padded_data_size,
4115         data_shard_size, i, hlo->sharding(), offset_start_in_data, pad_value,
4116         partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_);
4117     if (!resharded_data) {
4118       return DefaultAction(hlo);
4119     }
4120     data_shard_hlo = *resharded_data;
4121   }
4122 
4123   Window window_on_shard = hlo->window();
4124   for (int64_t i = 0; i < window_on_shard.dimensions_size(); ++i) {
4125     int64_t shard_count = hlo->sharding().tile_assignment().dim(i);
4126     if (shard_count == 1) {
4127       continue;
4128     }
4129     auto reshard_wd = window_on_shard.mutable_dimensions(i);
4130     // The shards are already explicitly padded.
4131     reshard_wd->set_padding_low(0);
4132     reshard_wd->set_padding_high(0);
4133   }
4134 
4135   auto sharded_select_and_scatter =
4136       b_.AddInstruction(HloInstruction::CreateSelectAndScatter(
4137           data_shard_hlo->shape(), data_shard_hlo, select, window_on_shard,
4138           source_shard_hlo, replicated_init.hlo(),
4139           hlo->called_computations()[1]));
4140   SetPartitionedHlo(hlo, [&]() {
4141     auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
4142     if (ShapeUtil::Compatible(sharded_select_and_scatter->shape(),
4143                               shard_shape)) {
4144       return sharded_select_and_scatter;
4145     }
4146     auto zero = b_.AddInstruction(
4147         HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
4148     std::vector<HloInstruction*> slice_offsets(shard_shape.rank(), zero);
4149     for (int64_t i = 0; i < window_on_shard.dimensions_size(); ++i) {
4150       if (hlo->sharding().tile_assignment().dim(i) == 1) {
4151         continue;
4152       }
4153       int64_t pad_low = hlo->window().dimensions(i).padding_low();
4154       auto left_halo_size =
4155           data_left_halo_sizes[i].Calculate(partition_ordinals[i], &b_);
4156       if (data_left_halo_sizes[i].Calculate(0) == pad_low) {
4157         slice_offsets[i] = left_halo_size;
4158       } else {
4159         auto is_shard0 = b_.AddInstruction(HloInstruction::CreateCompare(
4160             ShapeUtil::MakeShape(PRED, {}), zero, partition_ordinals[i],
4161             ComparisonDirection::kEq));
4162         auto pad_low_hlo = b_.AddInstruction(HloInstruction::CreateConstant(
4163             LiteralUtil::CreateR0<int32_t>(pad_low)));
4164         slice_offsets[i] = b_.AddInstruction(HloInstruction::CreateTernary(
4165             zero->shape(), HloOpcode::kSelect, is_shard0, pad_low_hlo,
4166             left_halo_size));
4167       }
4168     }
4169     return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
4170         shard_shape, sharded_select_and_scatter, slice_offsets,
4171         shard_shape.dimensions()));
4172   });
4173   return OkStatus();
4174 }
4175 
4176 Status SpmdPartitioningVisitor::HandleTuple(HloInstruction* hlo) {
4177   std::vector<HloInstruction*> new_operands;
4178   for (int64_t i = 0; i < hlo->operand_count(); ++i) {
4179     new_operands.push_back(
4180         GetPartitionedHlo(hlo->operand(i))
4181             .Reshard(hlo->sharding().GetSubSharding(hlo->shape(), {i}))
4182             .hlo());
4183   }
4184   SetPartitionedHlo(hlo, [&]() {
4185     return b_.AddInstruction(HloInstruction::CreateTuple(new_operands));
4186   });
4187   return OkStatus();
4188 }
4189 
4190 StatusOr<bool> SpmdPartitioningVisitor::DoPartition(
4191     HloComputation* computation, const HloSharding& root_sharding,
4192     const SpmdPartitionerOptions& options) {
4193   VLOG(2) << "Partitioning computation " << computation->name() << " for "
4194           << num_replicas_ << " replicas and " << num_partitions_
4195           << " partitions";
4196   TF_RETURN_IF_ERROR(computation->Accept(this));
4197 
4198   HloModule* module = computation->parent();
4199   auto new_root =
4200       GetPartitionedHlo(computation->root_instruction()).Reshard(root_sharding);
4201   auto new_computation =
4202       module->AddEmbeddedComputation(b_.Build(new_root.hlo()));
4203   TF_RETURN_IF_ERROR(
4204       DoCodeMotionForWindowedDotGeneralLoops(new_computation, options));
4205 
4206   // Replace the original computation with the new SPMD computation.
4207   absl::flat_hash_map<HloComputation*, HloComputation*> replacement;
4208   replacement[computation] = new_computation;
4209   module->ReplaceComputations(replacement);
4210   return changed_;
4211 }
4212 
4213 Status SpmdPartitioningVisitor::HandlePartitionId(HloInstruction* hlo) {
4214   return Unimplemented(
4215       "PartitionId instruction is not supported for SPMD partitioning since "
4216       "the meaning is ambiguous -- whether the instruction is replicated or "
4217       "the data is replicated, and if the latter which data is replicated.");
4218 }
4219 
4220 SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64_t num_partitions,
4221                                                         int64_t num_replicas) {
4222   return {
4223       [](SpmdBuilder* b) {
4224         return b->AddInstruction(HloInstruction::CreatePartitionId());
4225       },
4226       [num_replicas, num_partitions](
4227           SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction,
4228           const std::vector<std::vector<int64_t>>& partition_subgroups,
4229           int64_t channel_id) {
4230         if (partition_subgroups.size() <= 1) {
4231           std::vector<ReplicaGroup> groups(num_replicas);
4232           // TODO(yuanzx): Unify subgroup definition with AllToAll.
4233           for (int64_t i = 0; i < num_replicas; ++i) {
4234             groups[i].add_replica_ids(i);
4235           }
4236           return b->AddInstruction(HloInstruction::CreateAllReduce(
4237               operand->shape(), {operand}, reduction, groups,
4238               /*constrain_layout=*/false, channel_id,
4239               /*use_global_device_ids=*/false));
4240         }
4241 
4242         std::vector<ReplicaGroup> device_groups;
4243         device_groups.reserve(partition_subgroups.size() * num_replicas);
4244         for (int64_t i = 0; i < num_replicas; ++i) {
4245           for (const auto& pgroup : partition_subgroups) {
4246             device_groups.emplace_back();
4247             for (int64_t pid : pgroup) {
4248               device_groups.back().add_replica_ids(i * num_partitions + pid);
4249             }
4250           }
4251         }
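        // Editorial example (hypothetical values): with num_replicas = 2,
        // num_partitions = 4, and partition_subgroups = {{0, 1}, {2, 3}}, the
        // groups built above are {0, 1}, {2, 3}, {4, 5}, {6, 7} in flattened
        // ids (replica * num_partitions + partition), which is why
        // use_global_device_ids is set below.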
4252         return b->AddInstruction(HloInstruction::CreateAllReduce(
4253             operand->shape(), {operand}, reduction, device_groups,
4254             /*constrain_layout=*/false, channel_id,
4255             /*use_global_device_ids=*/true));
4256       },
4257       [num_partitions](SpmdBuilder* b, HloInstruction* operand,
4258                        std::vector<std::pair<int64_t, int64_t>>& src_dst_pairs,
4259                        int64_t channel_id) {
4260         /* optimize trivial collective permute */
4261         if (src_dst_pairs.empty()) {
4262           // If the src/dst pairs are empty, then the collective permute just
4263           // initializes the output to zero.
4264           return CreateZero(operand->shape(), b);
4265         } else {
4266           // A collective-permute is a copy if all pairs are "identity" and
4267           // all partitions are listed.
4268           bool is_copy =
4269               src_dst_pairs.size() == num_partitions &&
4270               absl::c_all_of(src_dst_pairs,
4271                              [](const std::pair<int64_t, int64_t>& pair) {
4272                                return pair.first == pair.second;
4273                              });
4274           if (is_copy) {
4275             return operand;
4276           } else {
4277             return b->AddInstruction(HloInstruction::CreateCollectivePermute(
4278                 operand->shape(), operand, src_dst_pairs, channel_id));
4279           }
4280         }
4281       },
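      // Editorial example for the collective permute above (hypothetical
      // values): with 4 partitions, src_dst_pairs {{0,0},{1,1},{2,2},{3,3}}
      // is recognized as a copy and returns the operand directly, while a
      // rotation such as {{0,1},{1,2},{2,3},{3,0}} still emits a
      // collective-permute instruction.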
4282       [](SpmdBuilder* b, absl::Span<HloInstruction* const> operands,
4283          const std::vector<std::vector<int64_t>>& partition_subgroups,
4284          int64_t channel_id, std::optional<int64_t> split_dimension) {
4285         std::vector<Shape> shapes(operands.size(), operands[0]->shape());
4286         const Shape output_shape = (shapes.size() == 1)
4287                                        ? shapes[0]
4288                                        : ShapeUtil::MakeTupleShape(shapes);
4289         std::vector<ReplicaGroup> groups(partition_subgroups.size());
4290         for (int64_t i = 0; i < groups.size(); ++i) {
4291           for (int64_t id : partition_subgroups[i]) {
4292             groups[i].add_replica_ids(id);
4293           }
4294         }
4295         return b->AddInstruction(HloInstruction::CreateAllToAll(
4296             output_shape, operands, groups,
4297             /*constrain_layout=*/false, channel_id, split_dimension));
4298       },
4299       [num_replicas, num_partitions](
4300           SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape,
4301           const std::vector<std::vector<int64_t>>& partition_subgroups,
4302           int64_t channel_id, int64_t all_gather_dimension) {
4303         std::vector<ReplicaGroup> device_groups;
4304         device_groups.reserve(partition_subgroups.size() * num_replicas);
4305         for (int64_t i = 0; i < num_replicas; ++i) {
4306           for (const auto& pgroup : partition_subgroups) {
4307             device_groups.emplace_back();
4308             for (int64_t pid : pgroup) {
4309               device_groups.back().add_replica_ids(i * num_partitions + pid);
4310             }
4311           }
4312         }
4313         return b->AddInstruction(HloInstruction::CreateAllGather(
4314             ag_shape, {operand}, all_gather_dimension, device_groups,
4315             /*constrain_layout=*/false, channel_id,
4316             /*use_global_device_ids=*/true));
4317       },
4318   };
4319 }
4320 
4321 SpmdPartitioner::SpmdPartitioner(int64_t num_partitions, int64_t num_replicas,
4322                                  SpmdPartitionerOptions options)
4323     : SpmdPartitioner(
4324           num_partitions, num_replicas, std::move(options),
4325           GetDefaultCollectiveOpsCreator(num_partitions, num_replicas)) {}
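// Usage sketch (editorial, illustrative values only; all names below are
// defined in this file or are standard XLA macros):
//
//   SpmdPartitionerOptions options;
//   options.allow_module_signature_change = true;  // assumed configuration
//   SpmdPartitioner partitioner(/*num_partitions=*/8, /*num_replicas=*/1,
//                               options);
//   TF_ASSIGN_OR_RETURN(bool changed,
//                       partitioner.Run(module, /*execution_threads=*/{}));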
4326 
4327 HloInstruction* SpmdPartitioner::AllGatherShards(
4328     SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
4329     int64_t* next_channel_id, absl::Span<const int64_t> selected_dims,
4330     const SPMDCollectiveOpsCreator& collectives_creator) {
4331   return AllGatherShardsInternal(b, operand, sharding, next_channel_id,
4332                                  selected_dims, collectives_creator,
4333                                  /*per_dim_ag=*/true);
4334 }
4335 
4336 HloInstruction* SpmdPartitioner::AllGatherShardsInternal(
4337     SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
4338     int64_t* next_channel_id, absl::Span<const int64_t> selected_dims,
4339     const SPMDCollectiveOpsCreator& collectives_creator, bool per_dim_ag) {
4340   if (selected_dims.empty()) {
4341     return operand;
4342   }
4343   CHECK(!sharding.IsTileMaximal());
4344   if (per_dim_ag || selected_dims.size() == 1) {
4345     HloInstruction* result = operand;
4346     Shape result_shape = operand->shape();
4347     for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
4348       if (sharding.tile_assignment().dim(*it) == 1) {
4349         continue;
4350       }
4351       auto partition_subgroups =
4352           GetPartitionGroupsForReplication(sharding, {*it});
4353       result_shape.set_dimensions(
4354           *it, result_shape.dimensions(*it) * partition_subgroups[0].size());
4355       result = collectives_creator.create_cross_partition_all_gather(
4356           b, result, result_shape, partition_subgroups, (*next_channel_id)++,
4357           /*all_gather_dimension=*/*it);
4358     }
4359     return result;
4360   }
4361   std::vector<int64_t> shape;
4362   shape.push_back(1);
4363   for (int64_t dim : operand->shape().dimensions()) {
4364     shape.push_back(dim);
4365   }
4366   // Add one leading dimension to gather all partitions.
4367   auto reshape = b->AddInstruction(HloInstruction::CreateReshape(
4368       ShapeUtil::MakeShape(operand->shape().element_type(), shape), operand));
4369   HloInstruction* result = reshape;
4370   auto partition_subgroups =
4371       GetPartitionGroupsForReplication(sharding, selected_dims);
4372   shape[0] *= partition_subgroups[0].size();
4373   result = collectives_creator.create_cross_partition_all_gather(
4374       b, result, ShapeUtil::MakeShape(operand->shape().element_type(), shape),
4375       partition_subgroups, (*next_channel_id)++,
4376       /*all_gather_dimension=*/0);
4377   // If n > 1 dimensions are partitioned, split the leading dimension into n.
4378   std::vector<int64_t> tiled_dims;
4379   for (int64_t i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
4380     if (sharding.tile_assignment().dim(i) > 1 &&
4381         absl::c_linear_search(selected_dims, i)) {
4382       tiled_dims.push_back(i);
4383     }
4384   }
4385   if (tiled_dims.size() > 1) {
4386     std::vector<int64_t> split_dim_shape;
4387     split_dim_shape.reserve(tiled_dims.size() + operand->shape().rank());
4388     for (int64_t i : tiled_dims) {
4389       split_dim_shape.push_back(sharding.tile_assignment().dim(i));
4390     }
4391     for (int64_t dim : operand->shape().dimensions()) {
4392       split_dim_shape.push_back(dim);
4393     }
4394     result = b->AddInstruction(HloInstruction::CreateReshape(
4395         ShapeUtil::MakeShape(operand->shape().element_type(), split_dim_shape),
4396         result));
4397   }
4398   // Transpose so that each gathered dimension is next to its corresponding
4399   // partitioned dimension.
4400   std::vector<int64_t> xpose_permutation(result->shape().rank());
4401   int64_t split_dims_added = 0;
4402   for (int64_t i = 0; i < xpose_permutation.size(); ++i) {
4403     if (sharding.tile_assignment().dim(i - split_dims_added) == 1 ||
4404         !absl::c_linear_search(selected_dims, i - split_dims_added)) {
4405       xpose_permutation[i] = i + tiled_dims.size() - split_dims_added;
4406     } else {
4407       xpose_permutation[i] = split_dims_added;
4408       xpose_permutation[i + 1] = i + tiled_dims.size() - split_dims_added;
4409       split_dims_added++;
4410       i++;
4411     }
4412   }
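  // Worked example (editorial): for an operand of shape [a, b] tiled [2, 2]
  // on selected_dims {0, 1}, the earlier reshape yields [2, 2, a, b], the
  // loop above produces xpose_permutation = {0, 2, 1, 3} giving shape
  // [2, a, 2, b], and the final reshape below merges each pair into [2a, 2b].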
4413   result = b->AddInstruction(HloInstruction::CreateTranspose(
4414       ShapeInference::InferTransposeShape(result->shape(), xpose_permutation)
4415           .ValueOrDie(),
4416       result, xpose_permutation));
4417   // Reshape to the desired shape.
4418   auto ag_shape = operand->shape();
4419   for (int64_t i : tiled_dims) {
4420     ag_shape.set_dimensions(
4421         i, ag_shape.dimensions(i) * sharding.tile_assignment().dim(i));
4422   }
4423   result = b->AddInstruction(HloInstruction::CreateReshape(ag_shape, result));
4424   return result;
4425 }
4426 
4427 HloInstruction* SpmdPartitioner::AllReduceAlongShardingDims(
4428     SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
4429     int64_t* next_channel_id, absl::Span<const int64_t> selected_dims,
4430     const SPMDCollectiveOpsCreator& collectives_creator,
4431     HloComputation* reduction) {
4432   return AllReduceAlongShardingDimsInternal(
4433       b, operand, sharding, next_channel_id, selected_dims, collectives_creator,
4434       reduction, /*per_dim_ar=*/true);
4435 }
4436 
4437 HloInstruction* SpmdPartitioner::AllReduceAlongShardingDimsInternal(
4438     SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
4439     int64_t* next_channel_id, absl::Span<const int64_t> selected_dims,
4440     const SPMDCollectiveOpsCreator& collectives_creator,
4441     HloComputation* reduction, bool per_dim_ar) {
4442   if (!per_dim_ar) {
4443     auto partition_subgroups =
4444         GetPartitionGroupsForReplication(sharding, selected_dims);
4445     return collectives_creator.create_cross_partition_all_reduce(
4446         b, operand, reduction, partition_subgroups, (*next_channel_id)++);
4447   }
4448   auto result = operand;
4449   for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
4450     if (sharding.tile_assignment().dim(*it) == 1) {
4451       continue;
4452     }
4453     auto partition_subgroups =
4454         GetPartitionGroupsForReplication(sharding, {*it});
4455     result = collectives_creator.create_cross_partition_all_reduce(
4456         b, result, reduction, partition_subgroups, (*next_channel_id)++);
4457   }
4458   return result;
4459 }
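// Editorial note: with a [2, 2] tile assignment and selected_dims {0, 1},
// per_dim_ar = true emits two all-reduces over groups of 2 partitions each
// (one per selected dimension, in reverse order), while per_dim_ar = false
// emits a single all-reduce over groups of 4. The results agree assuming the
// reduction computation is associative and commutative.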
4460 
4461 StatusOr<bool> SpmdPartitioner::PartitionComputation(
4462     HloComputation* computation, const HloSharding& root_sharding,
4463     int64_t* next_channel_id, SpmdLogger* logger) {
4464   auto visitor =
4465       CreateVisitor(computation, num_partitions_, num_replicas_,
4466                     collective_ops_creator_, next_channel_id, logger, options_);
4467   return visitor->DoPartition(computation, root_sharding, options_);
4468 }
4469 
4470 std::unique_ptr<SpmdPartitioningVisitor> SpmdPartitioner::CreateVisitor(
4471     HloComputation* computation, int64_t num_partitions, int64_t num_replicas,
4472     const SPMDCollectiveOpsCreator& collective_ops_creator,
4473     int64_t* next_channel_id, SpmdLogger* logger,
4474     SpmdPartitionerOptions options) {
4475   return std::make_unique<SpmdPartitioningVisitor>(
4476       computation, num_partitions, num_replicas, collective_ops_creator,
4477       next_channel_id, logger, std::move(options), this);
4478 }
4479 
4480 StatusOr<bool> SpmdPartitioner::Run(
4481     HloModule* module,
4482     const absl::flat_hash_set<absl::string_view>& execution_threads) {
4483   TF_RETURN_IF_ERROR(PreprocessSharding(module, execution_threads));
4484   TF_RETURN_IF_ERROR(PreprocessHlos(module, execution_threads));
4485 
4486   XLA_VLOG_LINES(1, SpmdLogger::ReportBeforePartition(
4487                         *module, options_.report_instruction_count));
4488 
4489   // Add the parameters' and output's shardings to the module.
4490   std::vector<HloSharding> entry_params_shardings;
4491   const auto num_parameters = module->entry_computation()->num_parameters();
4492   entry_params_shardings.reserve(num_parameters);
4493   for (int64_t i = 0; i < num_parameters; ++i) {
4494     auto param = module->entry_computation()->parameter_instruction(i);
4495     CHECK(param->has_sharding()) << "Missing sharding in entry parameter " << i;
4496     entry_params_shardings.push_back(param->sharding());
4497   }
4498   module->set_spmd_parameters_shardings(entry_params_shardings);
4499   auto entry_root = module->entry_computation()->root_instruction();
4500   CHECK(entry_root->has_sharding()) << "Missing sharding in entry root.";
4501   module->set_spmd_output_sharding(entry_root->sharding());
4502 
4503   FlattenCallGraph flatten;
4504   TF_ASSIGN_OR_RETURN(auto changed, flatten.Run(module));
4505 
4506   SpmdLogger logger(options_.report_instruction_count,
4507                     /*disabled=*/!VLOG_IS_ON(1));
4508   auto program_shape = module->entry_computation()->ComputeProgramShape();
4509   int64_t next_channel_id = hlo_query::NextChannelId(*module);
4510   // Copy the root sharding since the partitioner visitor may temporarily change
4511   // the sharding to work around manual sharding.
4512   HloSharding root_sharding = entry_root->sharding();
4513   TF_ASSIGN_OR_RETURN(
4514       bool partition_changed,
4515       PartitionComputation(module->entry_computation(), root_sharding,
4516                            &next_channel_id, &logger));
4517   changed |= partition_changed;
4518 
4519   // For the entry computation, make sure that the root instruction and the
4520   // parameters preserve their signatures.
4521   auto new_program_shape = module->entry_computation()->ComputeProgramShape();
4522   if (!options_.allow_module_signature_change) {
4523     TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
4524         program_shape.result(), new_program_shape.result()))
4525         << "Result shape changed for the entry computation";
4526     TF_RET_CHECK(program_shape.parameters_size() ==
4527                  new_program_shape.parameters_size())
4528         << "Parameter count changed for the entry computation";
4529     for (int64_t i = 0; i < program_shape.parameters_size(); ++i) {
4530       TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
4531           program_shape.parameters(i), new_program_shape.parameters(i)))
4532           << "Parameter shape changed for the entry computation";
4533     }
4534   } else {
4535     // Fix up some bad tiling in entry computation layout.
4536     auto update_shape = [this](Shape* subshape, const xla::ShapeIndex& index) {
4537       if (subshape->IsArray()) {
4538         UpdateLayout(subshape);
4539       }
4540     };
4541     const auto& old_entry_layout = module->entry_computation_layout();
4542     // Shapes can change but the layout should still remain the same.
4543     for (int64_t i = 0; i < new_program_shape.parameters_size(); ++i) {
4544       TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
4545           old_entry_layout.parameter_shape(i),
4546           new_program_shape.mutable_parameters(i)));
4547       ShapeUtil::ForEachMutableSubshape(new_program_shape.mutable_parameters(i),
4548                                         update_shape);
4549     }
4550     TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
4551         old_entry_layout.result_shape(), new_program_shape.mutable_result()));
4552     ShapeUtil::ForEachMutableSubshape(new_program_shape.mutable_result(),
4553                                       update_shape);
4554 
4555     HloModuleConfig config = module->config();
4556     *config.mutable_entry_computation_layout() =
4557         ComputationLayout(new_program_shape, /*ignore_layouts=*/false);
4558     module->set_config(config);
4559   }
4560 
4561   XLA_VLOG_LINES(1, SpmdLogger::ReportAfterPartition(
4562                         *module, options_.report_instruction_count));
4563   XLA_VLOG_LINES(1, logger.MakeReport());
4564 
4565   if (changed) {
4566     HloPassPipeline pass("spmd-cleanup");
4567     pass.AddPass<HloDCE>(/*remove_cross_partition_collective_ops=*/true);
4568     pass.AddPass<TupleSimplifier>();
4569     pass.AddPass<HloDCE>(/*remove_cross_partition_collective_ops=*/true);
4570     pass.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
4571     pass.AddPass<FlattenCallGraph>();
4572     TF_RETURN_IF_ERROR(pass.Run(module, execution_threads).status());
4573   }
4574 
4575   TF_RETURN_IF_ERROR(ClearShardingAttributes(module, execution_threads));
4576   return changed;
4577 }
4578 
4579 Status SpmdPartitioner::PreprocessSharding(
4580     HloModule* module,
4581     const absl::flat_hash_set<absl::string_view>& execution_threads) {
4582   for (HloComputation* computation : module->computations(execution_threads)) {
4583     for (HloInstruction* hlo : computation->instructions()) {
4584       if (hlo->HasSideEffectNoRecurse() && hlo->opcode() != HloOpcode::kRng) {
4585         TF_RET_CHECK(hlo->has_sharding())
4586             << "Side-effect HLO must have sharding: " << hlo->ToString();
4587         TF_RET_CHECK(!HasReplicatedSharding(hlo->sharding()) ||
4588                      CanSideEffectingHaveReplicatedSharding(hlo))
4589             << "side-effect HLO cannot have a replicated sharding: "
4590             << hlo->ToString();
4591       }
4592 
4593       // For unassigned HLOs, annotate with replicated sharding.
4594       //
4595       // Among side-effecting ops, only Rng is allowed to omit the annotation.
4596       // In that case, we currently force it to run on core 0, since we don't
4597       // support partitioning or replicating the Rng op (the values depend on
4598       // the seed provided to each device).
4599       //
4600       // TODO(hyouklee): Should we also convert single-device shardings (without
4601       // side-effects) into replicated?
4602       if (!hlo->has_sharding()) {
4603         if (hlo->opcode() == HloOpcode::kRng) {
4604           hlo->set_sharding(HloSharding::AssignDevice(0));
4605         } else {
4606           hlo->set_sharding(
4607               HloSharding::Single(hlo->shape(), HloSharding::Replicate()));
4608         }
4609       } else if (!hlo->sharding().IsTileMaximal() &&
4610                  !hlo->sharding().IsManual()) {
4611         std::vector<int64_t> available(num_partitions_);
4612         std::iota(available.begin(), available.end(), 0);
4613         TF_RET_CHECK(num_partitions_ == hlo_sharding_util::DevicesForSharding(
4614                                             hlo->sharding(), available)
4615                                             .size())
4616             << "num_partitions:" << num_partitions_ << "\n"
4617             << "SPMD partitioner only supports tile sharding that includes all "
4618                "partitions. If you didn't add this sharding annotation in the "
4619                "model, please file a bug with the XLA team.\n"
4620             << hlo->ToString();
4621       }
4622     }
4623   }
4624 
4625   // Entry computation's parameter and root sharding must be either all
4626   // replicated or all on a single device.
  if (!options_.allow_module_signature_change) {
    const HloComputation* entry = module->entry_computation();
    TF_RET_CHECK(entry->root_instruction()->has_sharding());
    const HloSharding& root_sharding = entry->root_instruction()->sharding();
    if (!root_sharding.UniqueDevice().has_value()) {
      if (root_sharding.IsTuple()) {
        TF_RET_CHECK(absl::c_all_of(root_sharding.tuple_elements(),
                                    [](const HloSharding& s) {
                                      return s.IsReplicated() || s.IsManual();
                                    }))
            << "Unsupported entry root sharding: " << root_sharding.ToString();

      } else {
        TF_RET_CHECK(root_sharding.IsReplicated() || root_sharding.IsManual())
            << "Unsupported entry root sharding: " << root_sharding.ToString();
      }
    }

    for (const HloInstruction* param : entry->parameter_instructions()) {
      TF_RET_CHECK(param->has_sharding());
      TF_RET_CHECK(param->sharding().IsReplicated() ||
                   param->sharding().UniqueDevice().has_value())
          << "Unsupported entry parameter sharding: "
          << param->sharding().ToString();
    }
  }

  return OkStatus();
}

Status SpmdPartitioner::PreprocessHlos(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
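  // Walks through a chain of single-use kCopy instructions and returns the
  // first non-copy ancestor. If check_single_use is set and that ancestor has
  // more than one user, returns nullptr instead.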
  auto skip_copy_operands = [](HloInstruction* operand,
                               bool check_single_use =
                                   true) -> HloInstruction* {
    while (operand->user_count() == 1 &&
           operand->opcode() == HloOpcode::kCopy) {
      operand = operand->mutable_operand(0);
    }
    if (check_single_use && operand->user_count() != 1) {
      return nullptr;
    }
    return operand;
  };

  for (HloComputation* computation : module->computations(execution_threads)) {
    for (HloInstruction* hlo : computation->MakeInstructionPostOrder()) {
      if (hlo->sharding().IsTileMaximal() || hlo->sharding().IsManual()) {
        // No need to optimize for tile-maximal or manual sharding.
        continue;
      }

      if (hlo->opcode() == HloOpcode::kSlice) {
        HloInstruction* operand = skip_copy_operands(hlo->mutable_operand(0));
        if (operand == nullptr || operand->sharding() != hlo->sharding()) {
          continue;
        }

        // Merge pad->slice to avoid multiple halo exchanges.
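        // With unit strides and no interior padding, slice([s, l)) of
        // pad(x, low, high) is itself a pad of x with
        //   edge_padding_low  = low - s
        //   edge_padding_high = l - padded_size + high
        // where negative padding trims elements from x.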
        if (operand->opcode() == HloOpcode::kPad) {
          std::optional<PaddingConfig> merged_padding =
              operand->padding_config();
          bool may_have_multi_halo_exchanges = false;
          for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
            const auto& dim = operand->padding_config().dimensions(i);
            if (dim.interior_padding() != 0 || hlo->slice_strides(i) != 1) {
              merged_padding = std::nullopt;
              break;
            }
            if (hlo->sharding().tile_assignment().dim(i) != 1 &&
                (dim.edge_padding_low() != 0 || dim.edge_padding_high() != 0) &&
                hlo->shape().dimensions(i) != operand->shape().dimensions(i)) {
              // This dim has padding, slicing, and sharding.
              may_have_multi_halo_exchanges = true;
            }

            auto* merged_dim = merged_padding->mutable_dimensions(i);
            merged_dim->set_edge_padding_low(dim.edge_padding_low() -
                                             hlo->slice_starts(i));
            merged_dim->set_edge_padding_high(hlo->slice_limits(i) -
                                              operand->shape().dimensions(i) +
                                              dim.edge_padding_high());
          }
          if (merged_padding.has_value() && may_have_multi_halo_exchanges) {
            // Rewrite to a single Pad.
            HloInstruction* new_pad =
                computation->AddInstruction(HloInstruction::CreatePad(
                    hlo->shape(), operand->mutable_operand(0),
                    operand->mutable_operand(1), *merged_padding));
            new_pad->set_metadata(operand->metadata());
            new_pad->set_sharding(hlo->sharding());
            TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_pad));
            TF_RETURN_IF_ERROR(
                computation->RemoveInstructionAndUnusedOperands(hlo));
          }
        }
      }
      if (hlo->opcode() == HloOpcode::kConcatenate) {
        const int64_t dim = hlo->concatenate_dimension();
        if (hlo->sharding().tile_assignment().dim(dim) == 1) {
          continue;
        }
        if (hlo->operand_count() == 2) {
          // Find a pattern of "rotate right on one dimension":
          // concat(slice(input), slice(input)).
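          // With x of size n along `dim`, a rotate-right by k is
          //   concat(slice(x, [n - k, n)), slice(x, [0, n - k))),
          // i.e. the last k elements move to the front.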
          HloInstruction* lhs = skip_copy_operands(hlo->mutable_operand(0));
          HloInstruction* rhs = skip_copy_operands(hlo->mutable_operand(1));
          if (lhs == nullptr || rhs == nullptr) {
            continue;
          }
          const int64_t amount = FindRotateRightPattern(hlo, lhs, rhs);
          if (amount < 0) {
            continue;
          }
          HloInstruction* to_rotate = lhs->mutable_operand(0);
          HloInstruction* rotate = computation->AddInstruction(
              CreateCustomCallSPMDInternal_RotateRight(to_rotate, dim, amount));
          rotate->set_metadata(hlo->metadata());
          rotate->set_sharding(hlo->sharding());
          TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(rotate));
          TF_RETURN_IF_ERROR(
              computation->RemoveInstructionAndUnusedOperands(hlo));
        } else if (hlo->operand_count() == 3) {
          // Find the pattern for "pad with wrap": concat(slice(x), x, slice(x))
          // All involved values must have the same sharding.
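          // For example, concat(slice(x, [n - l, n)), x, slice(x, [0, r)))
          // pads x along `dim` with l wrapped elements on the low side and r
          // wrapped elements on the high side.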
          HloInstruction* lhs = skip_copy_operands(hlo->mutable_operand(0));
          HloInstruction* mid = skip_copy_operands(hlo->mutable_operand(1),
                                                   /*check_single_use=*/false);
          HloInstruction* rhs = skip_copy_operands(hlo->mutable_operand(2));
          std::optional<PadWithWrapPattern> pad_pattern =
              FindPadWithWrapPattern(hlo, lhs, mid, rhs);
          if (!pad_pattern) {
            continue;
          }

          // Since concat requires all operands to have the same size along the
          // non-concat dimensions, the lhs/rhs slices must be slicing along
          // the concat dimension.

          // Step 1: Pad the mid operand to the final size. The low padding is
          // the size of the lhs shape, and the high padding is the size of the
          // rhs shape.
          PaddingConfig padding_config =
              MakeNoPaddingConfig(hlo->shape().rank());
          auto* padding_config_dim = padding_config.mutable_dimensions(dim);
          const int64_t low_pad = lhs->shape().dimensions(dim);
          const int64_t high_pad = rhs->shape().dimensions(dim);
          padding_config_dim->set_edge_padding_low(low_pad);
          padding_config_dim->set_edge_padding_high(high_pad);
          HloInstruction* zero =
              computation->AddInstruction(HloInstruction::CreateConstant(
                  LiteralUtil::Zero(hlo->shape().element_type())));
          zero->set_sharding(HloSharding::Replicate());
          HloInstruction* pad =
              computation->AddInstruction(HloInstruction::CreatePad(
                  hlo->shape(), mid, zero, padding_config));
          pad->set_metadata(hlo->metadata());
          pad->set_sharding(hlo->sharding());

          // Step 2: rotate the padded value so that the lhs slice aligns with
          // the low end of the padded shape.
          //  padded_operand = low_pad | mid | high_pad.
          //  slice_start in padded_operand = lhs->slice_start + low_pad.
          //  Rotate left by (lhs->slice_start + low_pad),
          //  i.e., rotate right by padded_size - (lhs_slice_start + low_pad).
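          //  E.g., with low_pad = 2, high_pad = 1, padded_size = 8, and
          //  lhs_slice_start = 3, the lhs data starts at offset 3 + 2 = 5 in
          //  the padded value, so rotating right by 8 - 5 = 3 brings it to
          //  offset 0.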
          const int64_t padded_size = hlo->shape().dimensions(dim);
          const int64_t rotate_lhs_amount =
              padded_size - (pad_pattern->lhs_slice_start + low_pad);
          HloInstruction* rotate_lhs = computation->AddInstruction(
              CreateCustomCallSPMDInternal_RotateRight(pad, dim,
                                                       rotate_lhs_amount));
          rotate_lhs->set_metadata(hlo->metadata());
          rotate_lhs->set_sharding(hlo->sharding());

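          // Reapplies any element-wise ops (e.g. copies or converts) that
          // FindPadWithWrapPattern skipped while matching the slices, so the
          // rewritten graph computes the same values.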
          auto apply_modifiers =
              [&](HloInstruction* inst,
                  const std::vector<const HloInstruction*>& modifiers) {
                // Apply the modifiers in the reverse order.
                for (auto it = modifiers.crbegin(), end = modifiers.crend();
                     it != end; ++it) {
                  const HloInstruction* modifier = *it;
                  // The new shape keeps inst's dimensions but takes the
                  // modifier's element type.
                  Shape new_shape = ShapeUtil::ChangeElementType(
                      inst->shape(), modifier->shape().element_type());
                  inst = computation->AddInstruction(
                      modifier->CloneWithNewOperands(new_shape, {inst}));
                }
                return inst;
              };
          rotate_lhs = apply_modifiers(rotate_lhs, pad_pattern->lhs_modifiers);

          // Step 3: rotate the padded value so that the rhs slice aligns with
          // the high end of the padded shape.
          //  padded_operand = low_pad | mid | high_pad.
          //  slice_start in padded_operand = rhs->slice_start + low_pad.
          //  slice_end in padded_operand = rhs->slice_start + low_pad +
          //  high_pad.
          //  Rotate right by padded_size - (rhs->slice_start + low_pad +
          //  high_pad).
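          //  Continuing the example above with rhs_slice_start = 0: rotate
          //  right by 8 - (0 + 2 + 1) = 5, which places the rhs data in the
          //  last high_pad positions of the padded value.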
          const int64_t rotate_rhs_amount =
              padded_size - (pad_pattern->rhs_slice_start + low_pad + high_pad);
          HloInstruction* rotate_rhs = computation->AddInstruction(
              CreateCustomCallSPMDInternal_RotateRight(pad, dim,
                                                       rotate_rhs_amount));
          rotate_rhs->set_metadata(hlo->metadata());
          rotate_rhs->set_sharding(hlo->sharding());
          rotate_rhs = apply_modifiers(rotate_rhs, pad_pattern->rhs_modifiers);

          // Now merge the 3 results using appropriate selects.
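          // Positions [0, low_pad) come from rotate_lhs, positions
          // [padded_size - high_pad, padded_size) come from rotate_rhs, and
          // everything in between keeps the padded mid value.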
          const Shape iota_shape =
              ShapeUtil::ChangeElementType(hlo->shape(), U32);
          HloInstruction* iota = computation->AddInstruction(
              HloInstruction::CreateIota(iota_shape, dim));
          iota->set_metadata(hlo->metadata());
          iota->set_sharding(hlo->sharding());

          struct SelectSpec {
            int64_t limit;
            HloInstruction* hlo;
            Comparison::Direction cmp;
          };
          const std::array<SelectSpec, 2> selects = {
              {// All elements < low_pad come from rotate_lhs.
               {low_pad, rotate_lhs, Comparison::Direction::kLt},
               // All elements >= padded_size - high_pad come from rotate_rhs.
               {padded_size - high_pad, rotate_rhs,
                Comparison::Direction::kGe}}};

          Shape pred_shape = ShapeUtil::ChangeElementType(hlo->shape(), PRED);

          HloInstruction* merged = pad;
          for (const SelectSpec& select_spec : selects) {
            HloInstruction* limit =
                computation->AddInstruction(HloInstruction::CreateConstant(
                    LiteralUtil::CreateR0<uint32_t>(select_spec.limit)));
            limit->set_sharding(HloSharding::Replicate());
            HloInstruction* limit_bcast = computation->AddInstruction(
                HloInstruction::CreateBroadcast(iota_shape, limit, {}));
            limit_bcast->set_metadata(hlo->metadata());
            limit_bcast->set_sharding(hlo->sharding());
            HloInstruction* compare =
                computation->AddInstruction(HloInstruction::CreateCompare(
                    pred_shape, iota, limit_bcast, select_spec.cmp));
            compare->set_metadata(hlo->metadata());
            compare->set_sharding(hlo->sharding());
            merged = computation->AddInstruction(HloInstruction::CreateTernary(
                hlo->shape(), HloOpcode::kSelect, compare, select_spec.hlo,
                merged));
            merged->set_metadata(hlo->metadata());
            merged->set_sharding(hlo->sharding());
          }

          TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(merged));
          TF_RETURN_IF_ERROR(
              computation->RemoveInstructionAndUnusedOperands(hlo));
        }
      }
    }
  }
  return OkStatus();
}

}  // namespace spmd
}  // namespace xla