xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_replication_analysis.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/hlo_replication_analysis.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
38 
39 namespace xla {
40 
41 // Determines whether an HLO instruction is replicated at index based on current
42 // knowledge in hlo_replication.
43 HloReplicationAnalysis::HloReplication
DetermineHloInstructionIsReplicated(const HloInstruction * hlo,const ShapeIndex & index,bool cross_partition_spmd,const absl::flat_hash_map<const HloInstruction *,ShapeTree<HloReplication>> & hlo_replication,bool support_partial_replication)44 HloReplicationAnalysis::DetermineHloInstructionIsReplicated(
45     const HloInstruction* hlo, const ShapeIndex& index,
46     bool cross_partition_spmd,
47     const absl::flat_hash_map<const HloInstruction*, ShapeTree<HloReplication>>&
48         hlo_replication,
49     bool support_partial_replication) {
50   const auto merge_operand_replication = [&hlo_replication](
51                                              const HloInstruction* inst) {
52     HloReplication replication = HloReplication::ReplicatedOnAllDevices();
53     for (auto operand : inst->operands()) {
54       auto operand_it = hlo_replication.find(operand);
55       if (operand_it == hlo_replication.end()) {
56         replication = replication.Merge(HloReplication::UniqueOnAllDevices());
57       } else {
58         replication = replication.Merge(operand_it->second.element({}));
59       }
60     }
61     return replication;
62   };
63 
64   if (hlo->opcode() == HloOpcode::kAllReduce ||
65       hlo->opcode() == HloOpcode::kAllGather) {
66     // All-reduce/all-gather returns same values across partitions/replicas as
67     // long as its operands are replicated.
68     HloReplication replication = merge_operand_replication(hlo);
69     if (replication.IsReplicatedOnAllDevices()) {
70       return replication;
71     }
72     if (!hlo->channel_id().has_value()) {
73       // This is cross-replica-only.
74       if (cross_partition_spmd) {
75         return replication;
76       }
77       if (hlo->replica_groups().empty() || hlo->replica_groups().size() == 1) {
78         return HloReplication::ReplicatedOnAllDevices();
79       }
80       if (support_partial_replication) {
81         std::vector<absl::Span<const int64_t>> device_sets;
82         for (const ReplicaGroup& replica_group : hlo->replica_groups()) {
83           device_sets.push_back(replica_group.replica_ids());
84         }
85         return HloReplication::PartiallyReplicated(device_sets);
86       } else {
87         return HloReplication::UniqueOnAllDevices();
88       }
89     } else {
90       bool global_id;
91       if (hlo->opcode() == HloOpcode::kAllReduce) {
92         global_id = Cast<HloAllReduceInstruction>(hlo)->use_global_device_ids();
93       } else {
94         global_id = Cast<HloAllGatherInstruction>(hlo)->use_global_device_ids();
95       }
96       if (global_id) {
97         bool replicated_across_partitions = true;
98         bool replicated_across_replicas = true;
99         const int64_t num_partitions =
100             hlo->GetModule()->config().num_partitions();
101         for (const auto& group : hlo->replica_groups()) {
102           absl::flat_hash_set<int64_t> visited_partitions;
103           absl::flat_hash_set<int64_t> visited_replicas;
104           for (int64_t id : group.replica_ids()) {
105             int64_t rid = id / num_partitions;
106             int64_t pid = id % num_partitions;
107             visited_partitions.insert(pid);
108             visited_replicas.insert(rid);
109           }
110           replicated_across_partitions &=
111               visited_partitions.size() == num_partitions;
112           replicated_across_replicas &=
113               visited_replicas.size() ==
114               hlo->GetModule()->config().replica_count();
115         }
116         if ((cross_partition_spmd && replicated_across_partitions) ||
117             (!cross_partition_spmd && replicated_across_replicas)) {
118           return HloReplication::ReplicatedOnAllDevices();
119         } else {
120           return HloReplication::UniqueOnAllDevices();
121         }
122       }
123       if (cross_partition_spmd) {
124         return HloReplication::ReplicatedOnAllDevices();
125       }
126       if (hlo->replica_groups().empty() || hlo->replica_groups().size() == 1) {
127         return HloReplication::ReplicatedOnAllDevices();
128       } else {
129         return HloReplication::UniqueOnAllDevices();
130       }
131     }
132   }
133   if (hlo->HasSideEffectNoRecurse()) {
134     return HloReplication::UniqueOnAllDevices();
135   }
136   if (hlo->opcode() == HloOpcode::kReplicaId) {
137     // ReplicaId returns the same value for all partitions in each replica.
138     return cross_partition_spmd ? HloReplication::ReplicatedOnAllDevices()
139                                 : HloReplication::UniqueOnAllDevices();
140   }
141   if (hlo->opcode() == HloOpcode::kPartitionId) {
142     // PartitionId returns the same value for all replicas in each partition.
143     return cross_partition_spmd ? HloReplication::UniqueOnAllDevices()
144                                 : HloReplication::ReplicatedOnAllDevices();
145   }
146   auto it = hlo_replication.find(hlo);
147   if (hlo->opcode() == HloOpcode::kParameter) {
148     // Parameters should have been processed.
149     CHECK(it != hlo_replication.end());
150     return it->second.element(index);
151   }
152   if (it != hlo_replication.end() &&
153       it->second.element(index).IsUniqueOnAllDevices()) {
154     // The HLO is already marked as non-replicated.
155     return it->second.element(index);
156   }
157 
158   if (hlo->opcode() == HloOpcode::kConstant) {
159     return HloReplication::ReplicatedOnAllDevices();
160   }
161 
162   if (hlo->opcode() == HloOpcode::kCustomCall &&
163       (hlo->custom_call_target() == "X64SplitLow" ||
164        hlo->custom_call_target() == "X64SplitHigh" ||
165        hlo->custom_call_target() == "X64Combine")) {
166     return merge_operand_replication(hlo);
167   }
168 
169   // Pattern-match and process cases where the HLO is partially replicated.
170   if (support_partial_replication) {
171     // Below is a very specific pattern to match the SPMD pipeline case.
172     if (hlo->opcode() == HloOpcode::kDynamicSlice) {
173       const HloInstruction* ds_buffer = hlo->operand(0);
174       if (hlo->dynamic_slice_sizes().size() == 1 &&
175           hlo->dynamic_slice_sizes()[0] == 1 &&
176           ds_buffer->opcode() == HloOpcode::kConstant &&
177           ds_buffer->shape().rank() == 1 &&
178           ds_buffer->shape().element_type() == PrimitiveType::S32 &&
179           ((cross_partition_spmd &&
180             hlo->operand(1)->opcode() == HloOpcode::kPartitionId) ||
181            (!cross_partition_spmd &&
182             hlo->operand(1)->opcode() == HloOpcode::kReplicaId))) {
183         const HloModule* hlo_module = hlo->GetModule();
184         int64_t num_devices = cross_partition_spmd
185                                   ? hlo_module->config().num_partitions()
186                                   : hlo_module->config().replica_count();
187         absl::flat_hash_map<int64_t, std::vector<int64_t>> value_to_device_set;
188         for (int64_t device_id = 0; device_id < num_devices; ++device_id) {
189           std::optional<int64_t> value =
190               ds_buffer->literal().GetIntegralAsS64({device_id});
191           value_to_device_set[*value].push_back(device_id);
192         }
193         std::vector<absl::Span<const int64_t>> device_sets;
194         for (const auto& value_and_device_set : value_to_device_set) {
195           device_sets.push_back(
196               absl::Span<const int64_t>(value_and_device_set.second));
197         }
198         return HloReplication::PartiallyReplicated(device_sets);
199       }
200     }
201   }
202 
203   if (hlo->IsElementwise() ||                             //
204       hlo->opcode() == HloOpcode::kConcatenate ||         //
205       hlo->opcode() == HloOpcode::kConvolution ||         //
206       hlo->opcode() == HloOpcode::kDot ||                 //
207       hlo->opcode() == HloOpcode::kReduce ||              //
208       hlo->opcode() == HloOpcode::kBroadcast ||           //
209       hlo->opcode() == HloOpcode::kTranspose ||           //
210       hlo->opcode() == HloOpcode::kReshape ||             //
211       hlo->opcode() == HloOpcode::kBitcast ||             //
212       hlo->opcode() == HloOpcode::kReverse ||             //
213       hlo->opcode() == HloOpcode::kGather ||              //
214       hlo->opcode() == HloOpcode::kScatter ||             //
215       hlo->opcode() == HloOpcode::kIota ||                //
216       hlo->opcode() == HloOpcode::kPad ||                 //
217       hlo->opcode() == HloOpcode::kSlice ||               //
218       hlo->opcode() == HloOpcode::kDynamicSlice ||        //
219       hlo->opcode() == HloOpcode::kDynamicUpdateSlice ||  //
220       hlo->opcode() == HloOpcode::kReduceWindow ||        //
221       hlo->opcode() == HloOpcode::kCopy) {
222     return merge_operand_replication(hlo);
223   }
224   return HloReplication::UniqueOnAllDevices();
225 }
226 
ComputeHloReplicationOnComputation(const HloComputation * computation,bool mark_everything_not_replicated)227 bool HloReplicationAnalysis::ComputeHloReplicationOnComputation(
228     const HloComputation* computation, bool mark_everything_not_replicated) {
229   bool changed = false;
230   for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
231     // Assigns the shape tree to dest if dest doesn't have one yet, or combines
232     // it with the existing one by and'ing them. Returns if anything is updated.
233     auto assign_or_combine_shapetree =
234         [&](ShapeTree<HloReplication>&& to_combine,
235             const HloInstruction* dest) {
236           auto it = hlo_replication_.find(dest);
237           if (it == hlo_replication_.end()) {
238             hlo_replication_[dest] = std::move(to_combine);
239             return true;
240           }
241           bool updated = false;
242           it->second.ForEachMutableElement(
243               [&](const ShapeIndex& index, HloReplication* element) {
244                 HloReplication new_replication =
245                     element->Merge(to_combine.element(index));
246                 if (!element->Equal(new_replication)) {
247                   *element = std::move(new_replication);
248                   updated = true;
249                 }
250               });
251           return updated;
252         };
253     // Assigns or combines source's shape tree to dest. Returns if anything is
254     // updated.
255     auto propagate_shapetree = [&](const HloInstruction* source,
256                                    const HloInstruction* dest) {
257       auto source_it = hlo_replication_.find(source);
258       if (source_it == hlo_replication_.end()) {
259         return false;
260       }
261       return assign_or_combine_shapetree(
262           ShapeTree<HloReplication>(source_it->second), dest);
263     };
264     // For the opcodes below that we do special handling, we don't need to
265     // explicitly check mark_everything_not_replicated because if it is set, the
266     // operands should already be marked as not replicated.
267     if (inst->opcode() == HloOpcode::kWhile) {
268       // Since while body's input and output alias each other, we need to run it
269       // multiple times until a fixed point is reached.
270       while (true) {
271         // First, propagate the input's and body root's shape trees to the
272         // parameters of the body and condition.
273         bool updated = propagate_shapetree(
274             inst->operand(0),
275             inst->while_condition()->parameter_instruction(0));
276         updated |= propagate_shapetree(
277             inst->while_body()->root_instruction(),
278             inst->while_condition()->parameter_instruction(0));
279         updated |= propagate_shapetree(
280             inst->operand(0), inst->while_body()->parameter_instruction(0));
281         updated |=
282             propagate_shapetree(inst->while_body()->root_instruction(),
283                                 inst->while_body()->parameter_instruction(0));
284         // Compute the condition.
285         updated |= ComputeHloReplicationOnComputation(
286             inst->while_condition(), mark_everything_not_replicated);
287         // Compute the body. If the condition is not replicated, the while body
288         // should be different across replicas.
289         if (!ContainsKey(loops_known_with_same_iterations_, inst) &&
290             !hlo_replication_[inst->while_condition()->root_instruction()]
291                  .element({})
292                  .IsReplicatedOnAllDevices()) {
293           updated |= ComputeHloReplicationOnComputation(
294               inst->while_body(), /*mark_everything_not_replicated=*/true);
295         } else {
296           updated |= ComputeHloReplicationOnComputation(
297               inst->while_body(), mark_everything_not_replicated);
298         }
299         if (!updated) {
300           break;
301         }
302         changed = true;
303       }
304       // Propagate the input's and body root's shape trees to the while HLO.
305       changed |= propagate_shapetree(inst->operand(0), inst);
306       changed |=
307           propagate_shapetree(inst->while_body()->root_instruction(), inst);
308     } else if (inst->opcode() == HloOpcode::kCall ||
309                inst->opcode() == HloOpcode::kFusion) {
310       auto called = inst->called_computations().front();
311       for (int64_t i = 0; i < inst->operand_count(); ++i) {
312         changed |= propagate_shapetree(inst->operand(i),
313                                        called->parameter_instruction(i));
314       }
315       changed |= ComputeHloReplicationOnComputation(
316           called, mark_everything_not_replicated);
317       changed |= propagate_shapetree(called->root_instruction(), inst);
318     } else if (inst->opcode() == HloOpcode::kConditional) {
319       // Propagate inputs' shape trees to the called computations' parameters.
320       for (int64_t i = 0; i < inst->called_computations().size(); ++i) {
321         changed |= propagate_shapetree(
322             inst->operand(i + 1),
323             inst->called_computations()[i]->parameter_instruction(0));
324       }
325       // If the condition is not replicated, the conditional result should be
326       // different across replicas.
327       if (!hlo_replication_[inst->operand(0)]
328                .element({})
329                .IsReplicatedOnAllDevices()) {
330         for (auto called : inst->called_computations()) {
331           changed |= ComputeHloReplicationOnComputation(
332               called,
333               /*mark_everything_not_replicated=*/true);
334         }
335         changed |= assign_or_combine_shapetree(
336             ShapeTree<HloReplication>(inst->shape(),
337                                       HloReplication::UniqueOnAllDevices()),
338             inst);
339       } else {
340         for (auto called : inst->called_computations()) {
341           changed |= ComputeHloReplicationOnComputation(
342               called, mark_everything_not_replicated);
343           changed |= propagate_shapetree(called->root_instruction(), inst);
344         }
345       }
346     } else if (inst->opcode() == HloOpcode::kTuple) {
347       ShapeTree<HloReplication> shape_tree(
348           inst->shape(), HloReplication::ReplicatedOnAllDevices());
349       for (int64_t i = 0; i < inst->operand_count(); ++i) {
350         shape_tree.CopySubtreeFrom(hlo_replication_[inst->operand(i)], {}, {i});
351       }
352       changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
353     } else if (inst->opcode() == HloOpcode::kGetTupleElement) {
354       ShapeTree<HloReplication> shape_tree(
355           inst->shape(), HloReplication::ReplicatedOnAllDevices());
356       shape_tree.CopySubtreeFrom(hlo_replication_[inst->operand(0)],
357                                  {inst->tuple_index()}, {});
358       changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
359     } else if (inst->opcode() == HloOpcode::kInfeed && cross_partition_spmd_) {
360       ShapeTree<HloReplication> shape_tree(
361           inst->shape(), HloReplication::UniqueOnAllDevices());
362       if (inst->has_sharding()) {
363         auto sharding = inst->sharding().GetAsShapeTree(inst->shape());
364         shape_tree.ForEachMutableElement(
365             [&sharding](const ShapeIndex& index, HloReplication* data) {
366               *data = sharding.element(index).IsReplicated()
367                           ? HloReplication::ReplicatedOnAllDevices()
368                           : HloReplication::UniqueOnAllDevices();
369             });
370       }
371       changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
372     } else {
373       if (mark_everything_not_replicated) {
374         changed |= assign_or_combine_shapetree(
375             ShapeTree<HloReplication>(inst->shape(),
376                                       HloReplication::UniqueOnAllDevices()),
377             inst);
378       } else {
379         ShapeTree<HloReplication> shape_tree(
380             inst->shape(), HloReplication::ReplicatedOnAllDevices());
381         ShapeUtil::ForEachSubshape(
382             inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
383               *shape_tree.mutable_element(index) =
384                   DetermineHloInstructionIsReplicated(
385                       inst, index, cross_partition_spmd_, hlo_replication_,
386                       support_partial_replication_);
387               return Status::OK();
388             });
389         changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
390       }
391     }
392   }
393   return changed;
394 }
395 
ComputeHloReplication()396 void HloReplicationAnalysis::ComputeHloReplication() {
397   // Add entry parameters to the above sets according to user annotation.
398   // Replicated modules read from `parameter_replicated_at_leaf_buffers` whereas
399   // SPMD partitioned modules read from HloSharding attributes.
400   auto entry = module_->entry_computation();
401   for (int i = 0; i < entry->num_parameters(); ++i) {
402     auto param = entry->parameter_instruction(i);
403     ShapeTree<HloReplication> shape_tree(param->shape(),
404                                          HloReplication::UniqueOnAllDevices());
405     const auto& replication = param->parameter_replicated_at_leaf_buffers();
406     int leaf_index = 0;
407     ShapeUtil::ForEachSubshape(
408         param->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
409           if (!ShapeUtil::IsLeafIndex(param->shape(), index)) {
410             return OkStatus();
411           }
412           if (cross_partition_spmd_ && param->has_sharding()) {
413             // In cross-partition spmd mode, set parameter replication status
414             // based on the parameter's sharding.
415             TF_ASSIGN_OR_RETURN(auto sharding_tree,
416                                 param->sharding().AsShapeTree(param->shape()));
417             *shape_tree.mutable_element(index) =
418                 sharding_tree.element(index).IsReplicated()
419                     ? HloReplication::ReplicatedOnAllDevices()
420                     : HloReplication::UniqueOnAllDevices();
421           }
422           if (replication) {
423             // If parameter replication status has been set explicitly, use that
424             // instead.
425             if (!cross_partition_spmd_ && replication->at(leaf_index)) {
426               // Setting parameter replication status for replicas in
427               // non cross-partition spmd mode.
428               *shape_tree.mutable_element(index) =
429                   HloReplication::ReplicatedOnAllDevices();
430             }
431             if (cross_partition_spmd_ && !replication->at(leaf_index)) {
432               // Setting paramemter replication status for partitions in
433               // cross-partition spmd mode.
434               *shape_tree.mutable_element(index) =
435                   HloReplication::UniqueOnAllDevices();
436             }
437             ++leaf_index;
438           }
439           return OkStatus();
440         });
441     hlo_replication_[param] = std::move(shape_tree);
442   }
443   ComputeHloReplicationOnComputation(entry,
444                                      /*mark_everything_not_replicated=*/false);
445 }
446 
HloInstructionIsReplicatedAt(const HloInstruction * inst,const ShapeIndex & index) const447 bool HloReplicationAnalysis::HloInstructionIsReplicatedAt(
448     const HloInstruction* inst, const ShapeIndex& index) const {
449   auto it = hlo_replication_.find(inst);
450   if (it == hlo_replication_.end()) {
451     return false;
452   }
453   return it->second.element(index).IsReplicatedOnAllDevices();
454 }
455 
HloInstructionIsReplicatedAt(const HloInstruction * inst,const ShapeIndex & index,absl::Span<const ReplicaGroup> replica_groups) const456 bool HloReplicationAnalysis::HloInstructionIsReplicatedAt(
457     const HloInstruction* inst, const ShapeIndex& index,
458     absl::Span<const ReplicaGroup> replica_groups) const {
459   auto it = hlo_replication_.find(inst);
460   if (it == hlo_replication_.end()) {
461     return false;
462   }
463   VLOG(5) << "HloInstructionIsReplicatedAt is called on " << inst->name()
464           << ", index: " << index.ToString()
465           << ", replication: " << it->second.element(index).ToString();
466   if (replica_groups.empty()) {
467     return it->second.element(index).IsReplicatedOnAllDevices();
468   }
469   if (it->second.element(index).IsReplicatedOnAllDevices()) {
470     return true;
471   }
472   if (it->second.element(index).IsUniqueOnAllDevices()) {
473     return false;
474   }
475   for (const ReplicaGroup& replica_group : replica_groups) {
476     if (!it->second.element(index).IsReplicatedWithinSubgroup(
477             replica_group.replica_ids())) {
478       return false;
479     }
480   }
481   return true;
482 }
483 
484 /* static */ StatusOr<std::unique_ptr<HloReplicationAnalysis>>
Run(const HloModule * module,bool cross_partition_spmd)485 HloReplicationAnalysis::Run(const HloModule* module,
486                             bool cross_partition_spmd) {
487   const absl::flat_hash_set<const HloInstruction*> empty;
488   return Run(module, cross_partition_spmd, &empty);
489 }
490 
491 /* static */ StatusOr<std::unique_ptr<HloReplicationAnalysis>>
Run(const HloModule * module,bool cross_partition_spmd,const absl::flat_hash_set<const HloInstruction * > * loops_known_with_same_iterations)492 HloReplicationAnalysis::Run(const HloModule* module, bool cross_partition_spmd,
493                             const absl::flat_hash_set<const HloInstruction*>*
494                                 loops_known_with_same_iterations) {
495   auto analysis = absl::WrapUnique(new HloReplicationAnalysis(
496       module, cross_partition_spmd, loops_known_with_same_iterations,
497       /*support_partial_replication=*/false));
498   analysis->ComputeHloReplication();
499   return analysis;
500 }
501 
502 /* static */ StatusOr<std::unique_ptr<HloReplicationAnalysis>>
RunWithPartialReplication(const HloModule * module,bool cross_partition_spmd)503 HloReplicationAnalysis::RunWithPartialReplication(const HloModule* module,
504                                                   bool cross_partition_spmd) {
505   const absl::flat_hash_set<const HloInstruction*> empty;
506   auto analysis = absl::WrapUnique(
507       new HloReplicationAnalysis(module, cross_partition_spmd, &empty,
508                                  /*support_partial_replication=*/true));
509   analysis->ComputeHloReplication();
510   return analysis;
511 }
512 
HloReplication()513 HloReplicationAnalysis::HloReplication::HloReplication()
514     : state_(State::kReplicatedOnAllDevices) {}
515 
HloReplication(HloReplicationAnalysis::HloReplication::State state,absl::Span<const int64_t> device_set_root)516 HloReplicationAnalysis::HloReplication::HloReplication(
517     HloReplicationAnalysis::HloReplication::State state,
518     absl::Span<const int64_t> device_set_root)
519     : state_(state),
520       device_set_root_(device_set_root.begin(), device_set_root.end()) {
521   CHECK(state == State::kPartiallyReplicated || device_set_root_.empty());
522 }
523 
524 HloReplicationAnalysis::HloReplication
ReplicatedOnAllDevices()525 HloReplicationAnalysis::HloReplication::ReplicatedOnAllDevices() {
526   return HloReplication(State::kReplicatedOnAllDevices, {});
527 }
528 
529 HloReplicationAnalysis::HloReplication
UniqueOnAllDevices()530 HloReplicationAnalysis::HloReplication::UniqueOnAllDevices() {
531   return HloReplication(State::kUniqueOnAllDevices, {});
532 }
533 
534 HloReplicationAnalysis::HloReplication
PartiallyReplicated(absl::Span<const absl::Span<const int64_t>> device_sets)535 HloReplicationAnalysis::HloReplication::PartiallyReplicated(
536     absl::Span<const absl::Span<const int64_t>> device_sets) {
537   int64_t max_device_id = 0;
538   for (const absl::Span<const int64_t>& device_set : device_sets) {
539     for (int64_t device_id : device_set) {
540       max_device_id = std::max(max_device_id, device_id);
541     }
542   }
543   std::vector<int64_t> device_set_root;
544   device_set_root.resize(max_device_id + 1);
545   for (const absl::Span<const int64_t>& device_set : device_sets) {
546     int64_t min_device_id = *absl::c_min_element(device_set);
547     for (int64_t device_id : device_set) {
548       device_set_root[device_id] = min_device_id;
549     }
550   }
551   return HloReplication(State::kPartiallyReplicated, device_set_root);
552 }
553 
554 HloReplicationAnalysis::HloReplication
Merge(const HloReplication & other) const555 HloReplicationAnalysis::HloReplication::Merge(
556     const HloReplication& other) const {
557   switch (state_) {
558     case State::kReplicatedOnAllDevices:
559       return other;
560     case State::kUniqueOnAllDevices:
561       return *this;
562     case State::kPartiallyReplicated: {
563       switch (other.state_) {
564         case State::kReplicatedOnAllDevices:
565           return *this;
566         case State::kUniqueOnAllDevices:
567           return other;
568         case State::kPartiallyReplicated: {
569           absl::flat_hash_map<int64_t, std::vector<int64_t>>
570               value_to_device_set;
571           size_t num_devices = device_set_root_.size();
572           for (int64_t device_id = 0; device_id < num_devices; ++device_id) {
573             int64_t new_value = device_set_root_[device_id] * num_devices +
574                                 other.device_set_root_[device_id];
575             value_to_device_set[new_value].push_back(device_id);
576           }
577           CHECK_LE(value_to_device_set.size(), num_devices);
578           if (value_to_device_set.size() == 1) {
579             return ReplicatedOnAllDevices();
580           } else if (value_to_device_set.size() < num_devices) {
581             std::vector<absl::Span<const int64_t>> device_sets;
582             for (const auto& value_and_device_set : value_to_device_set) {
583               device_sets.push_back(
584                   absl::Span<const int64_t>(value_and_device_set.second));
585             }
586             return PartiallyReplicated(device_sets);
587           } else {
588             return UniqueOnAllDevices();
589           }
590         }
591       }
592     }
593   }
594 }
595 
Equal(const HloReplication & other) const596 bool HloReplicationAnalysis::HloReplication::Equal(
597     const HloReplication& other) const {
598   if (state_ != other.state_) {
599     return false;
600   }
601   return absl::c_equal(device_set_root_, other.device_set_root_);
602 }
603 
IsReplicatedOnAllDevices() const604 bool HloReplicationAnalysis::HloReplication::IsReplicatedOnAllDevices() const {
605   return state_ == State::kReplicatedOnAllDevices;
606 }
607 
IsUniqueOnAllDevices() const608 bool HloReplicationAnalysis::HloReplication::IsUniqueOnAllDevices() const {
609   return state_ == State::kUniqueOnAllDevices;
610 }
611 
IsReplicatedWithinSubgroup(absl::Span<const int64_t> device_ids) const612 bool HloReplicationAnalysis::HloReplication::IsReplicatedWithinSubgroup(
613     absl::Span<const int64_t> device_ids) const {
614   if (device_ids.empty()) return true;
615   return absl::c_all_of(device_ids, [this, &device_ids](int device_id) {
616     return device_set_root_[device_id] == device_set_root_[device_ids.front()];
617   });
618 }
619 
ToString() const620 std::string HloReplicationAnalysis::HloReplication::ToString() const {
621   switch (state_) {
622     case State::kReplicatedOnAllDevices:
623       return "ReplicatedOnAllDevices";
624     case State::kUniqueOnAllDevices:
625       return "UniqueOnAllDevices";
626     case State::kPartiallyReplicated:
627       return absl::StrCat("PartiallyReplicated{",
628                           absl::StrJoin(device_set_root_, ","), "}");
629   }
630 }
631 
632 }  // namespace xla
633