/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_replication_analysis.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

// Determines whether an HLO instruction is replicated at index based on
// current knowledge in hlo_replication.
HloReplicationAnalysis::HloReplication
HloReplicationAnalysis::DetermineHloInstructionIsReplicated(
    const HloInstruction* hlo, const ShapeIndex& index,
    bool cross_partition_spmd,
    const absl::flat_hash_map<const HloInstruction*, ShapeTree<HloReplication>>&
        hlo_replication,
    bool support_partial_replication) {
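  // Merges the replication states of an instruction's operands; an operand
  // with no recorded state is conservatively treated as unique on all devices.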
  const auto merge_operand_replication = [&hlo_replication](
                                             const HloInstruction* inst) {
    HloReplication replication = HloReplication::ReplicatedOnAllDevices();
    for (auto operand : inst->operands()) {
      auto operand_it = hlo_replication.find(operand);
      if (operand_it == hlo_replication.end()) {
        replication = replication.Merge(HloReplication::UniqueOnAllDevices());
      } else {
        replication = replication.Merge(operand_it->second.element({}));
      }
    }
    return replication;
  };

  if (hlo->opcode() == HloOpcode::kAllReduce ||
      hlo->opcode() == HloOpcode::kAllGather) {
    // All-reduce/all-gather returns the same values across partitions/replicas
    // as long as its operands are replicated.
    HloReplication replication = merge_operand_replication(hlo);
    if (replication.IsReplicatedOnAllDevices()) {
      return replication;
    }
    if (!hlo->channel_id().has_value()) {
      // This is cross-replica-only.
      if (cross_partition_spmd) {
        return replication;
      }
      if (hlo->replica_groups().empty() || hlo->replica_groups().size() == 1) {
        return HloReplication::ReplicatedOnAllDevices();
      }
      if (support_partial_replication) {
        std::vector<absl::Span<const int64_t>> device_sets;
        for (const ReplicaGroup& replica_group : hlo->replica_groups()) {
          device_sets.push_back(replica_group.replica_ids());
        }
        return HloReplication::PartiallyReplicated(device_sets);
      } else {
        return HloReplication::UniqueOnAllDevices();
      }
    } else {
      bool global_id;
      if (hlo->opcode() == HloOpcode::kAllReduce) {
        global_id = Cast<HloAllReduceInstruction>(hlo)->use_global_device_ids();
      } else {
        global_id = Cast<HloAllGatherInstruction>(hlo)->use_global_device_ids();
      }
      if (global_id) {
        bool replicated_across_partitions = true;
        bool replicated_across_replicas = true;
        const int64_t num_partitions =
            hlo->GetModule()->config().num_partitions();
        for (const auto& group : hlo->replica_groups()) {
          absl::flat_hash_set<int64_t> visited_partitions;
          absl::flat_hash_set<int64_t> visited_replicas;
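          // With use_global_device_ids, each id in a replica group is a
          // flattened global device id:
          // replica_id * num_partitions + partition_id.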
          for (int64_t id : group.replica_ids()) {
            int64_t rid = id / num_partitions;
            int64_t pid = id % num_partitions;
            visited_partitions.insert(pid);
            visited_replicas.insert(rid);
          }
          replicated_across_partitions &=
              visited_partitions.size() == num_partitions;
          replicated_across_replicas &=
              visited_replicas.size() ==
              hlo->GetModule()->config().replica_count();
        }
        if ((cross_partition_spmd && replicated_across_partitions) ||
            (!cross_partition_spmd && replicated_across_replicas)) {
          return HloReplication::ReplicatedOnAllDevices();
        } else {
          return HloReplication::UniqueOnAllDevices();
        }
      }
      if (cross_partition_spmd) {
        return HloReplication::ReplicatedOnAllDevices();
      }
      if (hlo->replica_groups().empty() || hlo->replica_groups().size() == 1) {
        return HloReplication::ReplicatedOnAllDevices();
      } else {
        return HloReplication::UniqueOnAllDevices();
      }
    }
  }
  if (hlo->HasSideEffectNoRecurse()) {
    return HloReplication::UniqueOnAllDevices();
  }
  if (hlo->opcode() == HloOpcode::kReplicaId) {
    // ReplicaId returns the same value for all partitions in each replica.
    return cross_partition_spmd ? HloReplication::ReplicatedOnAllDevices()
                                : HloReplication::UniqueOnAllDevices();
  }
  if (hlo->opcode() == HloOpcode::kPartitionId) {
    // PartitionId returns the same value for all replicas in each partition.
    return cross_partition_spmd ? HloReplication::UniqueOnAllDevices()
                                : HloReplication::ReplicatedOnAllDevices();
  }
  auto it = hlo_replication.find(hlo);
  if (hlo->opcode() == HloOpcode::kParameter) {
    // Parameters should have been processed.
    CHECK(it != hlo_replication.end());
    return it->second.element(index);
  }
  if (it != hlo_replication.end() &&
      it->second.element(index).IsUniqueOnAllDevices()) {
    // The HLO is already marked as non-replicated.
    return it->second.element(index);
  }

  if (hlo->opcode() == HloOpcode::kConstant) {
    return HloReplication::ReplicatedOnAllDevices();
  }

  if (hlo->opcode() == HloOpcode::kCustomCall &&
      (hlo->custom_call_target() == "X64SplitLow" ||
       hlo->custom_call_target() == "X64SplitHigh" ||
       hlo->custom_call_target() == "X64Combine")) {
    return merge_operand_replication(hlo);
  }

  // Pattern-match and process cases where the HLO is partially replicated.
  if (support_partial_replication) {
    // Below is a very specific pattern to match the SPMD pipeline case.
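    // A size-1 dynamic-slice of an S32 constant indexed by the partition-id
    // (or replica-id) yields the same value on every device that reads an
    // equal constant element, so devices can be grouped by that value.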
    if (hlo->opcode() == HloOpcode::kDynamicSlice) {
      const HloInstruction* ds_buffer = hlo->operand(0);
      if (hlo->dynamic_slice_sizes().size() == 1 &&
          hlo->dynamic_slice_sizes()[0] == 1 &&
          ds_buffer->opcode() == HloOpcode::kConstant &&
          ds_buffer->shape().rank() == 1 &&
          ds_buffer->shape().element_type() == PrimitiveType::S32 &&
          ((cross_partition_spmd &&
            hlo->operand(1)->opcode() == HloOpcode::kPartitionId) ||
           (!cross_partition_spmd &&
            hlo->operand(1)->opcode() == HloOpcode::kReplicaId))) {
        const HloModule* hlo_module = hlo->GetModule();
        int64_t num_devices = cross_partition_spmd
                                  ? hlo_module->config().num_partitions()
                                  : hlo_module->config().replica_count();
        absl::flat_hash_map<int64_t, std::vector<int64_t>> value_to_device_set;
        for (int64_t device_id = 0; device_id < num_devices; ++device_id) {
          std::optional<int64_t> value =
              ds_buffer->literal().GetIntegralAsS64({device_id});
          value_to_device_set[*value].push_back(device_id);
        }
        std::vector<absl::Span<const int64_t>> device_sets;
        for (const auto& value_and_device_set : value_to_device_set) {
          device_sets.push_back(
              absl::Span<const int64_t>(value_and_device_set.second));
        }
        return HloReplication::PartiallyReplicated(device_sets);
      }
    }
  }

  if (hlo->IsElementwise() ||                             //
      hlo->opcode() == HloOpcode::kConcatenate ||         //
      hlo->opcode() == HloOpcode::kConvolution ||         //
      hlo->opcode() == HloOpcode::kDot ||                 //
      hlo->opcode() == HloOpcode::kReduce ||              //
      hlo->opcode() == HloOpcode::kBroadcast ||           //
      hlo->opcode() == HloOpcode::kTranspose ||           //
      hlo->opcode() == HloOpcode::kReshape ||             //
      hlo->opcode() == HloOpcode::kBitcast ||             //
      hlo->opcode() == HloOpcode::kReverse ||             //
      hlo->opcode() == HloOpcode::kGather ||              //
      hlo->opcode() == HloOpcode::kScatter ||             //
      hlo->opcode() == HloOpcode::kIota ||                //
      hlo->opcode() == HloOpcode::kPad ||                 //
      hlo->opcode() == HloOpcode::kSlice ||               //
      hlo->opcode() == HloOpcode::kDynamicSlice ||        //
      hlo->opcode() == HloOpcode::kDynamicUpdateSlice ||  //
      hlo->opcode() == HloOpcode::kReduceWindow ||        //
      hlo->opcode() == HloOpcode::kCopy) {
    return merge_operand_replication(hlo);
  }
  return HloReplication::UniqueOnAllDevices();
}

bool HloReplicationAnalysis::ComputeHloReplicationOnComputation(
    const HloComputation* computation, bool mark_everything_not_replicated) {
  bool changed = false;
  for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
    // Assigns the shape tree to dest if dest doesn't have one yet, or combines
    // it with the existing one by merging the elements. Returns whether
    // anything was updated.
    auto assign_or_combine_shapetree =
        [&](ShapeTree<HloReplication>&& to_combine,
            const HloInstruction* dest) {
          auto it = hlo_replication_.find(dest);
          if (it == hlo_replication_.end()) {
            hlo_replication_[dest] = std::move(to_combine);
            return true;
          }
          bool updated = false;
          it->second.ForEachMutableElement(
              [&](const ShapeIndex& index, HloReplication* element) {
                HloReplication new_replication =
                    element->Merge(to_combine.element(index));
                if (!element->Equal(new_replication)) {
                  *element = std::move(new_replication);
                  updated = true;
                }
              });
          return updated;
        };
    // Assigns or combines source's shape tree to dest. Returns whether
    // anything was updated.
    auto propagate_shapetree = [&](const HloInstruction* source,
                                   const HloInstruction* dest) {
      auto source_it = hlo_replication_.find(source);
      if (source_it == hlo_replication_.end()) {
        return false;
      }
      return assign_or_combine_shapetree(
          ShapeTree<HloReplication>(source_it->second), dest);
    };
    // For the opcodes below that we handle specially, we don't need to
    // explicitly check mark_everything_not_replicated, because if it is set,
    // the operands will already have been marked as not replicated.
    if (inst->opcode() == HloOpcode::kWhile) {
      // Since the while body's input and output alias each other, we need to
      // run it multiple times until a fixed point is reached.
      while (true) {
        // First, propagate the input's and body root's shape trees to the
        // parameters of the body and condition.
        bool updated = propagate_shapetree(
            inst->operand(0),
            inst->while_condition()->parameter_instruction(0));
        updated |= propagate_shapetree(
            inst->while_body()->root_instruction(),
            inst->while_condition()->parameter_instruction(0));
        updated |= propagate_shapetree(
            inst->operand(0), inst->while_body()->parameter_instruction(0));
        updated |=
            propagate_shapetree(inst->while_body()->root_instruction(),
                                inst->while_body()->parameter_instruction(0));
        // Compute the condition.
        updated |= ComputeHloReplicationOnComputation(
            inst->while_condition(), mark_everything_not_replicated);
        // Compute the body. If the condition is not replicated, the while body
        // should be different across replicas.
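        // Exception: loops that the caller has asserted run the same number of
        // iterations on every device (`loops_known_with_same_iterations_`).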
        if (!ContainsKey(loops_known_with_same_iterations_, inst) &&
            !hlo_replication_[inst->while_condition()->root_instruction()]
                 .element({})
                 .IsReplicatedOnAllDevices()) {
          updated |= ComputeHloReplicationOnComputation(
              inst->while_body(), /*mark_everything_not_replicated=*/true);
        } else {
          updated |= ComputeHloReplicationOnComputation(
              inst->while_body(), mark_everything_not_replicated);
        }
        if (!updated) {
          break;
        }
        changed = true;
      }
      // Propagate the input's and body root's shape trees to the while HLO.
      changed |= propagate_shapetree(inst->operand(0), inst);
      changed |=
          propagate_shapetree(inst->while_body()->root_instruction(), inst);
    } else if (inst->opcode() == HloOpcode::kCall ||
               inst->opcode() == HloOpcode::kFusion) {
      auto called = inst->called_computations().front();
      for (int64_t i = 0; i < inst->operand_count(); ++i) {
        changed |= propagate_shapetree(inst->operand(i),
                                       called->parameter_instruction(i));
      }
      changed |= ComputeHloReplicationOnComputation(
          called, mark_everything_not_replicated);
      changed |= propagate_shapetree(called->root_instruction(), inst);
    } else if (inst->opcode() == HloOpcode::kConditional) {
      // Propagate inputs' shape trees to the called computations' parameters.
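      // Operand 0 is the branch selector; operand i + 1 feeds the parameter of
      // branch computation i.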
      for (int64_t i = 0; i < inst->called_computations().size(); ++i) {
        changed |= propagate_shapetree(
            inst->operand(i + 1),
            inst->called_computations()[i]->parameter_instruction(0));
      }
      // If the condition is not replicated, the conditional result should be
      // different across replicas.
      if (!hlo_replication_[inst->operand(0)]
               .element({})
               .IsReplicatedOnAllDevices()) {
        for (auto called : inst->called_computations()) {
          changed |= ComputeHloReplicationOnComputation(
              called,
              /*mark_everything_not_replicated=*/true);
        }
        changed |= assign_or_combine_shapetree(
            ShapeTree<HloReplication>(inst->shape(),
                                      HloReplication::UniqueOnAllDevices()),
            inst);
      } else {
        for (auto called : inst->called_computations()) {
          changed |= ComputeHloReplicationOnComputation(
              called, mark_everything_not_replicated);
          changed |= propagate_shapetree(called->root_instruction(), inst);
        }
      }
    } else if (inst->opcode() == HloOpcode::kTuple) {
      ShapeTree<HloReplication> shape_tree(
          inst->shape(), HloReplication::ReplicatedOnAllDevices());
      for (int64_t i = 0; i < inst->operand_count(); ++i) {
        shape_tree.CopySubtreeFrom(hlo_replication_[inst->operand(i)], {}, {i});
      }
      changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
    } else if (inst->opcode() == HloOpcode::kGetTupleElement) {
      ShapeTree<HloReplication> shape_tree(
          inst->shape(), HloReplication::ReplicatedOnAllDevices());
      shape_tree.CopySubtreeFrom(hlo_replication_[inst->operand(0)],
                                 {inst->tuple_index()}, {});
      changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
    } else if (inst->opcode() == HloOpcode::kInfeed && cross_partition_spmd_) {
      ShapeTree<HloReplication> shape_tree(
          inst->shape(), HloReplication::UniqueOnAllDevices());
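      // In SPMD mode an infeed result is replicated only at indices whose
      // sharding is marked replicated; without a sharding annotation it is
      // treated as unique on all partitions.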
      if (inst->has_sharding()) {
        auto sharding = inst->sharding().GetAsShapeTree(inst->shape());
        shape_tree.ForEachMutableElement(
            [&sharding](const ShapeIndex& index, HloReplication* data) {
              *data = sharding.element(index).IsReplicated()
                          ? HloReplication::ReplicatedOnAllDevices()
                          : HloReplication::UniqueOnAllDevices();
            });
      }
      changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
    } else {
      if (mark_everything_not_replicated) {
        changed |= assign_or_combine_shapetree(
            ShapeTree<HloReplication>(inst->shape(),
                                      HloReplication::UniqueOnAllDevices()),
            inst);
      } else {
        ShapeTree<HloReplication> shape_tree(
            inst->shape(), HloReplication::ReplicatedOnAllDevices());
        ShapeUtil::ForEachSubshape(
            inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
              *shape_tree.mutable_element(index) =
                  DetermineHloInstructionIsReplicated(
                      inst, index, cross_partition_spmd_, hlo_replication_,
                      support_partial_replication_);
              return Status::OK();
            });
        changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
      }
    }
  }
  return changed;
}

void HloReplicationAnalysis::ComputeHloReplication() {
  // Seed the replication state of the entry parameters according to user
  // annotations. Replicated modules read from
  // `parameter_replicated_at_leaf_buffers` whereas SPMD partitioned modules
  // read from HloSharding attributes.
  auto entry = module_->entry_computation();
  for (int i = 0; i < entry->num_parameters(); ++i) {
    auto param = entry->parameter_instruction(i);
    ShapeTree<HloReplication> shape_tree(param->shape(),
                                         HloReplication::UniqueOnAllDevices());
    const auto& replication = param->parameter_replicated_at_leaf_buffers();
    int leaf_index = 0;
    ShapeUtil::ForEachSubshape(
        param->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
          if (!ShapeUtil::IsLeafIndex(param->shape(), index)) {
            return OkStatus();
          }
          if (cross_partition_spmd_ && param->has_sharding()) {
            // In cross-partition spmd mode, set parameter replication status
            // based on the parameter's sharding.
            TF_ASSIGN_OR_RETURN(auto sharding_tree,
                                param->sharding().AsShapeTree(param->shape()));
            *shape_tree.mutable_element(index) =
                sharding_tree.element(index).IsReplicated()
                    ? HloReplication::ReplicatedOnAllDevices()
                    : HloReplication::UniqueOnAllDevices();
          }
          if (replication) {
            // If parameter replication status has been set explicitly, use
            // that instead.
            if (!cross_partition_spmd_ && replication->at(leaf_index)) {
              // Set parameter replication status for replicas in
              // non-cross-partition spmd mode.
              *shape_tree.mutable_element(index) =
                  HloReplication::ReplicatedOnAllDevices();
            }
            if (cross_partition_spmd_ && !replication->at(leaf_index)) {
              // Set parameter replication status for partitions in
              // cross-partition spmd mode.
              *shape_tree.mutable_element(index) =
                  HloReplication::UniqueOnAllDevices();
            }
            ++leaf_index;
          }
          return OkStatus();
        });
    hlo_replication_[param] = std::move(shape_tree);
  }
  ComputeHloReplicationOnComputation(entry,
                                     /*mark_everything_not_replicated=*/false);
}

bool HloReplicationAnalysis::HloInstructionIsReplicatedAt(
    const HloInstruction* inst, const ShapeIndex& index) const {
  auto it = hlo_replication_.find(inst);
  if (it == hlo_replication_.end()) {
    return false;
  }
  return it->second.element(index).IsReplicatedOnAllDevices();
}

bool HloReplicationAnalysis::HloInstructionIsReplicatedAt(
    const HloInstruction* inst, const ShapeIndex& index,
    absl::Span<const ReplicaGroup> replica_groups) const {
  auto it = hlo_replication_.find(inst);
  if (it == hlo_replication_.end()) {
    return false;
  }
  VLOG(5) << "HloInstructionIsReplicatedAt is called on " << inst->name()
          << ", index: " << index.ToString()
          << ", replication: " << it->second.element(index).ToString();
  if (replica_groups.empty()) {
    return it->second.element(index).IsReplicatedOnAllDevices();
  }
  if (it->second.element(index).IsReplicatedOnAllDevices()) {
    return true;
  }
  if (it->second.element(index).IsUniqueOnAllDevices()) {
    return false;
  }
  for (const ReplicaGroup& replica_group : replica_groups) {
    if (!it->second.element(index).IsReplicatedWithinSubgroup(
            replica_group.replica_ids())) {
      return false;
    }
  }
  return true;
}

/* static */ StatusOr<std::unique_ptr<HloReplicationAnalysis>>
HloReplicationAnalysis::Run(const HloModule* module,
                            bool cross_partition_spmd) {
  const absl::flat_hash_set<const HloInstruction*> empty;
  return Run(module, cross_partition_spmd, &empty);
}
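
// Example usage (a minimal sketch; `module` is assumed to point to a valid,
// verified HloModule owned by the caller):
//
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<HloReplicationAnalysis> analysis,
//       HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true));
//   bool root_replicated = analysis->HloInstructionIsReplicatedAt(
//       module->entry_computation()->root_instruction(), /*index=*/{});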

/* static */ StatusOr<std::unique_ptr<HloReplicationAnalysis>>
HloReplicationAnalysis::Run(const HloModule* module, bool cross_partition_spmd,
                            const absl::flat_hash_set<const HloInstruction*>*
                                loops_known_with_same_iterations) {
  auto analysis = absl::WrapUnique(new HloReplicationAnalysis(
      module, cross_partition_spmd, loops_known_with_same_iterations,
      /*support_partial_replication=*/false));
  analysis->ComputeHloReplication();
  return analysis;
}

/* static */ StatusOr<std::unique_ptr<HloReplicationAnalysis>>
HloReplicationAnalysis::RunWithPartialReplication(const HloModule* module,
                                                  bool cross_partition_spmd) {
  const absl::flat_hash_set<const HloInstruction*> empty;
  auto analysis = absl::WrapUnique(
      new HloReplicationAnalysis(module, cross_partition_spmd, &empty,
                                 /*support_partial_replication=*/true));
  analysis->ComputeHloReplication();
  return analysis;
}

HloReplicationAnalysis::HloReplication::HloReplication()
    : state_(State::kReplicatedOnAllDevices) {}

HloReplicationAnalysis::HloReplication::HloReplication(
    HloReplicationAnalysis::HloReplication::State state,
    absl::Span<const int64_t> device_set_root)
    : state_(state),
      device_set_root_(device_set_root.begin(), device_set_root.end()) {
  CHECK(state == State::kPartiallyReplicated || device_set_root_.empty());
}

HloReplicationAnalysis::HloReplication
HloReplicationAnalysis::HloReplication::ReplicatedOnAllDevices() {
  return HloReplication(State::kReplicatedOnAllDevices, {});
}

HloReplicationAnalysis::HloReplication
HloReplicationAnalysis::HloReplication::UniqueOnAllDevices() {
  return HloReplication(State::kUniqueOnAllDevices, {});
}

HloReplicationAnalysis::HloReplication
HloReplicationAnalysis::HloReplication::PartiallyReplicated(
    absl::Span<const absl::Span<const int64_t>> device_sets) {
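  // Each device's replication set is represented by its smallest member:
  // device_set_root[d] == device_set_root[d'] iff devices d and d' hold the
  // same data.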
  int64_t max_device_id = 0;
  for (const absl::Span<const int64_t>& device_set : device_sets) {
    for (int64_t device_id : device_set) {
      max_device_id = std::max(max_device_id, device_id);
    }
  }
  std::vector<int64_t> device_set_root;
  device_set_root.resize(max_device_id + 1);
  for (const absl::Span<const int64_t>& device_set : device_sets) {
    int64_t min_device_id = *absl::c_min_element(device_set);
    for (int64_t device_id : device_set) {
      device_set_root[device_id] = min_device_id;
    }
  }
  return HloReplication(State::kPartiallyReplicated, device_set_root);
}

HloReplicationAnalysis::HloReplication
HloReplicationAnalysis::HloReplication::Merge(
    const HloReplication& other) const {
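  // The merged state describes a value that depends on both inputs: it is
  // replicated across a set of devices only if both inputs are replicated
  // across that set.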
  switch (state_) {
    case State::kReplicatedOnAllDevices:
      return other;
    case State::kUniqueOnAllDevices:
      return *this;
    case State::kPartiallyReplicated: {
      switch (other.state_) {
        case State::kReplicatedOnAllDevices:
          return *this;
        case State::kUniqueOnAllDevices:
          return other;
        case State::kPartiallyReplicated: {
          absl::flat_hash_map<int64_t, std::vector<int64_t>>
              value_to_device_set;
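          // Encode the pair (this set root, other set root) as a single key so
          // that two devices fall into the same merged set only if they agree
          // in both replication patterns.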
          size_t num_devices = device_set_root_.size();
          for (int64_t device_id = 0; device_id < num_devices; ++device_id) {
            int64_t new_value = device_set_root_[device_id] * num_devices +
                                other.device_set_root_[device_id];
            value_to_device_set[new_value].push_back(device_id);
          }
          CHECK_LE(value_to_device_set.size(), num_devices);
          if (value_to_device_set.size() == 1) {
            return ReplicatedOnAllDevices();
          } else if (value_to_device_set.size() < num_devices) {
            std::vector<absl::Span<const int64_t>> device_sets;
            for (const auto& value_and_device_set : value_to_device_set) {
              device_sets.push_back(
                  absl::Span<const int64_t>(value_and_device_set.second));
            }
            return PartiallyReplicated(device_sets);
          } else {
            return UniqueOnAllDevices();
          }
        }
      }
    }
  }
}

bool HloReplicationAnalysis::HloReplication::Equal(
    const HloReplication& other) const {
  if (state_ != other.state_) {
    return false;
  }
  return absl::c_equal(device_set_root_, other.device_set_root_);
}

bool HloReplicationAnalysis::HloReplication::IsReplicatedOnAllDevices() const {
  return state_ == State::kReplicatedOnAllDevices;
}

bool HloReplicationAnalysis::HloReplication::IsUniqueOnAllDevices() const {
  return state_ == State::kUniqueOnAllDevices;
}

bool HloReplicationAnalysis::HloReplication::IsReplicatedWithinSubgroup(
    absl::Span<const int64_t> device_ids) const {
  if (device_ids.empty()) return true;
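  // Only meaningful in the kPartiallyReplicated state, where device_set_root_
  // has one entry per device; callers handle the fully replicated and fully
  // unique states before reaching here.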
  return absl::c_all_of(device_ids, [this, &device_ids](int64_t device_id) {
    return device_set_root_[device_id] == device_set_root_[device_ids.front()];
  });
}

std::string HloReplicationAnalysis::HloReplication::ToString() const {
  switch (state_) {
    case State::kReplicatedOnAllDevices:
      return "ReplicatedOnAllDevices";
    case State::kUniqueOnAllDevices:
      return "UniqueOnAllDevices";
    case State::kPartiallyReplicated:
      return absl::StrCat("PartiallyReplicated{",
                          absl::StrJoin(device_set_root_, ","), "}");
  }
}

}  // namespace xla