/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/ar_crs_combiner.h"

#include <algorithm>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/hlo_replication_analysis.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {
namespace {

// In SPMD mode, if there is a cross-replica all-reduce that produces the same
// value for all partitions, replaces it with a global all-reduce and then
// divides by the number of partitions. Depending on the topology and the
// backend's implementation of all-reduce, this may give better performance.
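// The rewrite preserves the result: if every partition holds the same
// per-replica value x_r, the global all-reduce computes
// sum_r sum_p x_r = partition_count * sum_r x_r, so dividing by
// partition_count recovers the original cross-replica sum.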
StatusOr<bool> ReplaceReplicatedAllReduce(HloModule* module,
                                          int64_t partition_count) {
  TF_ASSIGN_OR_RETURN(
      auto replication_analysis,
      HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true));

  bool changed = false;
  int64_t next_channel = hlo_query::NextChannelId(*module);
  for (auto computation : module->computations()) {
    for (auto instruction : computation->instructions()) {
      if (auto ar = DynCast<HloAllReduceInstruction>(instruction)) {
        const Shape& shape = ar->shape();
        if (ar->channel_id()) {
          continue;
        }
        if (ar->replica_groups().size() > 1) {
          continue;
        }
        if (shape.IsTuple() || shape.element_type() != F32) {
          continue;
        }
        // We would need a cost model for the target, but in general we want
        // to rewrite only if the replica count in the original op was large.
        if (module->config().replica_count() < 8 * partition_count) {
          continue;
        }
        if (replication_analysis->HloInstructionIsReplicatedAt(ar, {})) {
          VLOG(2) << "Replaced replicated all-reduce: " << ar->ToString();
          ar->set_channel_id(next_channel++);
          auto divisor =
              computation->AddInstruction(HloInstruction::CreateConstant(
                  LiteralUtil::CreateR0<float>(partition_count)));
          auto bcast = computation->AddInstruction(
              HloInstruction::CreateBroadcast(shape, divisor, {}));
          auto div = computation->AddInstruction(HloInstruction::CreateBinary(
              ar->shape(), HloOpcode::kDivide, ar, bcast));
          TF_RETURN_IF_ERROR(ar->ReplaceAllUsesWith(div));
          changed = true;
        }
      }
    }
  }
  return changed;
}

// Returns true if the given instruction (which must be a cross-partition
// all-reduce) has a ReplicaGroup config that can be combined with a
// cross-replica all-reduce. We currently restrict this to groups where all
// partitions in each replica belong to the same group.
bool HasCombinableReplicaGroup(HloInstruction* hlo, int64_t num_partitions) {
  auto all_reduce = Cast<HloAllReduceInstruction>(hlo);
  auto replica_groups = all_reduce->replica_groups();
  const int64_t replica_count = hlo->GetModule()->config().replica_count();
  CHECK(all_reduce->IsCrossModuleAllReduce());

  if (all_reduce->use_global_device_ids()) {
    if (replica_groups.size() != replica_count) {
      return false;
    }
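    // With use_global_device_ids, a flattened id is
    // replica_id * num_partitions + partition_id, so in the loop below the
    // division recovers the replica and the modulo recovers the partition:
    // each group must contain exactly the num_partitions devices of one
    // replica.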
    for (const auto& group : replica_groups) {
      if (group.replica_ids_size() != num_partitions) {
        return false;
      }
      absl::flat_hash_set<int64_t> partition_ids;
      int64_t replica_id = group.replica_ids(0) / num_partitions;
      for (int64_t i = 0; i < num_partitions; ++i) {
        if (group.replica_ids(i) / num_partitions != replica_id) {
          return false;
        }
        partition_ids.insert(group.replica_ids(i) % num_partitions);
      }
      if (partition_ids.size() != num_partitions) {
        return false;
      }
    }
    return true;
  }

  return replica_groups.size() == replica_count;
}

}  // namespace

namespace m = match;

// Checks if the argument instruction is an AllReduce, followed by a certain
// sequence of instructions and then a CRS. It must be possible to move
// the AR past each instruction in the sequence.
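// A hypothetical HLO sketch of the pattern (names illustrative only):
//   %ar  = f32[2,4] all-reduce(%x), channel_id=1    // cross-partition AR
//   %cvt = bf16[2,4] convert(%ar)                   // movable instruction
//   %crs = bf16[2,4] all-reduce(%cvt)               // cross-replica CRS
// Here the AR can be moved past the convert and merged with the CRS.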
std::optional<ArCrsCombiner::ArCrsPair> ArCrsCombiner::MatchesArCrsPattern(
    HloInstruction* instruction) {
  auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool {
    if (instruction->user_count() != 1) {
      return false;
    }
    switch (instruction->opcode()) {
      case HloOpcode::kBitcast:
      case HloOpcode::kTranspose:
      case HloOpcode::kReshape:
        return true;
      case HloOpcode::kConvert:
        // Can be moved across if input and output are both floating point or
        // both integral (e.g. S32<->U32 or F32<->BF16).
        return ShapeUtil::ElementIsFloating(instruction->shape()) ==
               ShapeUtil::ElementIsFloating(instruction->operand(0)->shape());
      case HloOpcode::kAdd:
      case HloOpcode::kSubtract:
      case HloOpcode::kMultiply:
        // Only supported for floating point operands.
        return ShapeUtil::ElementIsFloating(instruction->shape());
      default:
        return false;
    }
  };

  auto computation_is_addition = [](HloComputation* c) {
    return c->instruction_count() == 3 &&
           Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter()));
  };

  // We only support combining cross-partition all-reduce where each replica
  // belongs to its own group, since the later cross-replica all-reduce
  // combines along the replica dimension.
  if (instruction->IsCrossModuleAllReduce() &&
      HasCombinableReplicaGroup(instruction, num_spatial_partitions_) &&
      computation_is_addition(instruction->called_computations()[0]) &&
      instruction->user_count() == 1) {
    auto next = instruction->users()[0];
    int64_t distance = 1;
    while (!next->IsCrossReplicaAllReduce()) {
      if (can_ar_move_past_instruction(next)) {
        next = next->users()[0];
      } else {
        return std::nullopt;
      }
      ++distance;
    }
    if (!Cast<HloAllReduceInstruction>(next)->IsNoop() &&
        computation_is_addition(next->called_computations()[0])) {
      ArCrsPair pair(instruction, next, distance);
      VLOG(2) << "ArCrsPair matching pattern: " << pair.ToString();
      return pair;
    }
  }
  return std::nullopt;
}

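// If `instruction` is a parameter of a computation whose single caller is a
// kWhile instruction, returns that kWhile; otherwise returns nullopt.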
std::optional<HloInstruction*> ArCrsCombiner::WhileFromBodyParameter(
    HloInstruction* instruction) {
  CHECK_EQ(HloOpcode::kParameter, instruction->opcode());
  HloComputation* computation = instruction->parent();
  auto caller_instructions = call_graph_->GetComputationCallers(computation);
  if (caller_instructions.size() == 1) {
    auto caller_instruction = caller_instructions[0];
    if (caller_instruction->opcode() == HloOpcode::kWhile) {
      return caller_instruction;
    }
  }
  return std::nullopt;
}

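// Like WhileFromBodyParameter above, but looks for a kConditional caller
// instead of a kWhile.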
std::optional<HloInstruction*> ArCrsCombiner::ConditionalFromBodyParameter(
    HloInstruction* instruction) {
  CHECK_EQ(HloOpcode::kParameter, instruction->opcode());
  HloComputation* computation = instruction->parent();
  auto caller_instructions = call_graph_->GetComputationCallers(computation);
  if (caller_instructions.size() == 1) {
    auto caller_instruction = caller_instructions[0];
    if (caller_instruction->opcode() == HloOpcode::kConditional) {
      return caller_instruction;
    }
  }
  return std::nullopt;
}

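// Returns the kTuple instructions that may produce the value of
// `instruction`, tracing through domains, get-tuple-elements, parameters of
// while and conditional bodies, and while/conditional results. Returns
// nullopt if the data flow cannot be fully traced; `visited` keeps the
// traversal from looping on cyclic paths (e.g. through while bodies).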
std::optional<std::vector<HloInstruction*>> ArCrsCombiner::GetAllTuples(
    HloInstruction* instruction,
    absl::flat_hash_set<HloInstruction*>* visited) {
  if (visited->find(instruction) != visited->end()) {
    return std::vector<HloInstruction*>();
  }
  visited->insert(instruction);

  switch (instruction->opcode()) {
    case HloOpcode::kTuple: {
      return std::vector<HloInstruction*>({instruction});
    }
    case HloOpcode::kDomain: {
      return GetAllTuples(instruction->operands()[0], visited);
    }
    case HloOpcode::kParameter: {
      auto maybe_while = WhileFromBodyParameter(instruction);
      if (maybe_while) {
        auto while_instr = *maybe_while;
        auto init_tuples = GetAllTuples(while_instr->while_init(), visited);
        auto body_tuples = GetAllTuples(
            while_instr->while_body()->root_instruction(), visited);
        if (!init_tuples || !body_tuples) {
          return std::nullopt;
        }
        auto result = *init_tuples;
        result.insert(result.end(), body_tuples->begin(), body_tuples->end());
        return result;
      }
      auto maybe_conditional = ConditionalFromBodyParameter(instruction);
      if (maybe_conditional) {
        auto cond_instr = *maybe_conditional;
        std::vector<HloInstruction*> tuples;
        for (int64_t i = 0; i < cond_instr->branch_computations().size(); ++i) {
          if (cond_instr->branch_computation(i)->parameter_instruction(0) ==
              instruction) {
            // If the same computation is used for more than one branch of the
            // conditional, we collect the arguments that flow to the
            // computation from all branches.
            auto branch_tuples =
                GetAllTuples(cond_instr->mutable_operand(i + 1), visited);
            if (!branch_tuples) {
              return std::nullopt;
            }
            tuples.insert(tuples.end(), branch_tuples->begin(),
                          branch_tuples->end());
          }
        }
        return tuples;
      }
      return std::nullopt;
    }
    case HloOpcode::kGetTupleElement: {
      std::vector<HloInstruction*> result_tuples;
      auto tuples = GetAllTuples(instruction->operands()[0], visited);
      if (!tuples) {
        return std::nullopt;
      }
      for (auto tuple : *tuples) {
        auto tmp_tuples = GetAllTuples(
            tuple->mutable_operand(instruction->tuple_index()), visited);
        if (!tmp_tuples) {
          return std::nullopt;
        }
        result_tuples.insert(result_tuples.end(), tmp_tuples->begin(),
                             tmp_tuples->end());
      }
      return result_tuples;
    }
    case HloOpcode::kConditional: {
      std::vector<HloInstruction*> result_tuples;
      const auto& branch_computations = instruction->branch_computations();
      result_tuples.reserve(branch_computations.size());
      for (HloComputation* body : branch_computations) {
        if (body->root_instruction()->opcode() != HloOpcode::kTuple) {
          return std::nullopt;
        }
        result_tuples.push_back(body->root_instruction());
      }
      return result_tuples;
    }
    case HloOpcode::kWhile: {
      auto init_tuples = GetAllTuples(instruction->while_init(), visited);
      auto body_tuples =
          GetAllTuples(instruction->while_body()->root_instruction(), visited);
      if (!init_tuples || !body_tuples) {
        return std::nullopt;
      }
      auto result = *init_tuples;
      result.insert(result.end(), body_tuples->begin(), body_tuples->end());
      return result;
    }
    default:
      return std::nullopt;
  }
}

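// Returns true if elements i1 and i2 of every tuple that can flow into
// `tuple_shaped_instruction` provably compute the same value.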
bool ArCrsCombiner::TupleElementsComputeSameValue(
    HloInstruction* tuple_shaped_instruction, int64_t i1, int64_t i2,
    absl::flat_hash_map<int64_t, int64_t>* visited_pairs) {
  absl::flat_hash_set<HloInstruction*> visited;
  auto tuples = GetAllTuples(tuple_shaped_instruction, &visited);
  if (!tuples) {
    return false;
  }
  for (auto tuple : *tuples) {
    CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
    if (!InstructionsComputeSameValue(tuple->mutable_operand(i1),
                                      tuple->mutable_operand(i2),
                                      visited_pairs)) {
      return false;
    }
  }
  return true;
}

/* static */
bool ArCrsCombiner::TestInstructionsComputeSameValue(HloInstruction* i1,
                                                     HloInstruction* i2) {
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/false);
  auto module = i1->parent()->parent();
  CHECK_EQ(module, i2->parent()->parent());
  combiner.call_graph_ = CallGraph::Build(module);
  absl::flat_hash_map<int64_t, int64_t> visited_pairs;
  return combiner.InstructionsComputeSameValue(i1, i2, &visited_pairs);
}

bool ArCrsCombiner::InstructionsComputeSameValue(
    HloInstruction* i1, HloInstruction* i2,
    absl::flat_hash_map<int64_t, int64_t>* visited_pairs) {
  if (i1 == i2) {
    return true;
  }
  auto uid1 = i1->unique_id();
  auto uid2 = i2->unique_id();
  auto min_uid = std::min(uid1, uid2);
  auto max_uid = std::max(uid1, uid2);
  auto it = visited_pairs->find(min_uid);
  if (it != visited_pairs->end() && max_uid == it->second) {
    return true;
  }
  auto opcode1 = i1->opcode();
  auto operands1 = i1->operands();
  if (opcode1 != i2->opcode() || operands1.size() != i2->operands().size()) {
    return false;
  }
  auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
    return *a == *b;
  };
  // Two MPMD AllReduces are identical if they have the same channel_id. Their
  // operands don't have to be identical.
  auto eq_operands = [](const HloInstruction*, const HloInstruction*) {
    return true;
  };
  if (i1->IsCrossModuleAllReduce()) {
    return i1->Identical(*i2, eq_operands, eq_computations,
                         /*layout_sensitive=*/false);
  }
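  // Memoize the pair (keyed by the smaller unique id) before recursing into
  // the operands. Besides caching, this terminates the recursion when the
  // comparison cycles back to a pair that is already being checked, e.g. via
  // tuples flowing through while loops.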
  visited_pairs->emplace(min_uid, max_uid);
  for (int i = 0; i < operands1.size(); ++i) {
    auto operand1 = operands1[i];
    auto operand2 = i2->operands()[i];
    if (!InstructionsComputeSameValue(operand1, operand2, visited_pairs)) {
      return false;
    }
  }
  if (opcode1 == HloOpcode::kParameter) {
    // In the general case, we don't try to prove equality of parameters.
    // We only try in the context of get-tuple-element
    // (see TupleElementsComputeSameValue).
    return false;
  }
  if (opcode1 == HloOpcode::kGetTupleElement) {
    return i1->tuple_index() == i2->tuple_index() ||
           TupleElementsComputeSameValue(operands1[0], i1->tuple_index(),
                                         i2->tuple_index(), visited_pairs);
  }
  // Don't require the operands to be identical: Identical() can return false
  // for operands that compute the same value without being structurally
  // identical, which is not what we want here. The operands have already
  // been checked with InstructionsComputeSameValue above.
  auto eq_instructions = [](const HloInstruction* i1,
                            const HloInstruction* i2) -> bool { return true; };
  return i1->Identical(*i2, eq_instructions, eq_computations,
                       /*layout_sensitive=*/false);
}

void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
  // Say that two or more ARs lead to the same CRS: (AR1, CRS), (AR2, CRS),
  // ... , (ARn, CRS).
  // If as we traverse the HLO graph we start tracking the pair (AR2, CRS),
  // and later find that AR1's distance from the CRS is longer, we discard
  // AR2 and start tracking AR1. We put the discarded ids in this set, in
  // order to skip processing of short paths when we encounter the other ARs
  // that have the same id as AR2.
  absl::flat_hash_set<int64_t> discarded_ar_ids;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      auto maybe_pair = MatchesArCrsPattern(instruction);
      if (maybe_pair) {
        auto pair = *maybe_pair;
        int64_t ar_id = *(instruction->channel_id());
        if (discarded_ar_ids.find(ar_id) != discarded_ar_ids.end()) {
          continue;
        }
        auto it = crs_reserved_map_.find(pair.crs);
        if (it != crs_reserved_map_.end()) {
          auto prev_ar_id = it->second;
          // Since there is another AR paired with CRS,
          // all_reduce_map_[prev_ar_id] should exist, but
          // all_reduce_map_[ar_id] shouldn't.
          CHECK(all_reduce_map_.find(ar_id) == all_reduce_map_.end());
          CHECK_NE(prev_ar_id, ar_id);
          auto prev_pair = all_reduce_map_[prev_ar_id].back();
          int64_t prev_distance = prev_pair.distance;
          if (prev_distance < pair.distance) {
            // The current AR's distance to the CRS is longer than the
            // previously tracked AR's, so we discard the previous AR.
            VLOG(2) << "Replacing ArCrsPair: " << prev_pair.ToString()
                    << " with ArCrsPair: " << pair.ToString();
            all_reduce_map_.erase(prev_ar_id);
            discarded_ar_ids.insert(prev_ar_id);
            all_reduce_map_[ar_id].push_back(pair);
            crs_reserved_map_[pair.crs] = ar_id;
          } else {
            // Discard the current AR id because we are keeping the previously
            // tracked AR.
            discarded_ar_ids.insert(ar_id);
          }
        } else {
          if (all_reduce_map_.find(ar_id) != all_reduce_map_.end()) {
            int64_t prev_distance = all_reduce_map_[ar_id].back().distance;
            CHECK_EQ(prev_distance, pair.distance)
                << "All ARs with the same AR ID must have the same distance "
                   "from the corresponding CRSs. Found: "
                << prev_distance << " and " << pair.distance;
          }
          all_reduce_map_[ar_id].push_back(pair);
          crs_reserved_map_[pair.crs] = ar_id;
        }
      }
    }
  }
}

Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsMPMD() {
  for (auto it = all_reduce_map_.begin(); it != all_reduce_map_.end();) {
    auto copy_it = it++;  // Advance `it` before invalidation from erase.
    auto channel_id = copy_it->first;
    VLOG(2)
        << "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: "
        << channel_id << "\n";
    auto pairs_vec = copy_it->second;
    TF_RET_CHECK(pairs_vec.size() == num_spatial_partitions_);
    auto instr_0 = pairs_vec[0].ar;
    for (int i = 1; i < pairs_vec.size(); ++i) {
      auto instr_i = pairs_vec[i].ar;
      auto next_0 = instr_0->users()[0];
      auto next_i = instr_i->users()[0];
      absl::flat_hash_map<int64_t, int64_t> visited_pairs;
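      // Walk the two AR->CRS paths in lockstep, requiring each pair of
      // intermediate instructions to provably compute the same value.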
      while (true) {
        if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) {
          all_reduce_map_.erase(copy_it);
          VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased AllReduce "
                     "channel id: "
                  << channel_id << "\n";
          break;
        }
        if (next_0->IsCrossReplicaAllReduce()) {
          break;
        }
        next_0 = next_0->users()[0];
        next_i = next_i->users()[0];
      }
    }
  }
  return OkStatus();
}

Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsSPMD(
    HloModule* module) {
  // For SPMD mode, use HloReplicationAnalysis to figure out HLO value
  // equivalence across partitions.
  TF_ASSIGN_OR_RETURN(
      auto replication_analysis,
      HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true));

  for (auto it = all_reduce_map_.begin(); it != all_reduce_map_.end();) {
    auto copy_it = it++;  // Advance `it` before invalidation from erase.
    auto channel_id = copy_it->first;
    VLOG(2)
        << "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: "
        << channel_id << "\n";
    auto pairs_vec = copy_it->second;
    TF_RET_CHECK(pairs_vec.size() == 1);
    auto instr = pairs_vec[0].ar;
    auto next = instr->users()[0];
    while (true) {
      // The patterns we detect in ArCrsCombiner::MatchesArCrsPattern()
      // guarantee that the HLO produces an array.
      TF_RET_CHECK(next->shape().IsArray());
      if (!replication_analysis->HloInstructionIsReplicatedAt(next, {})) {
        all_reduce_map_.erase(copy_it);
        VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased AllReduce "
                   "channel id: "
                << channel_id << "\n";
        break;
      }
      if (next->IsCrossReplicaAllReduce()) {
        break;
      }
      next = next->users()[0];
    }
  }
  return OkStatus();
}

StatusOr<bool> ArCrsCombiner::RewriteGraph() {
  if (all_reduce_map_.empty()) {
    return false;
  }
  for (const auto& it : all_reduce_map_) {
    auto pairs_vec = it.second;
    for (auto pair : pairs_vec) {
      auto all_reduce = pair.ar;
      auto parent_computation = all_reduce->parent();
      auto channel_id = all_reduce->channel_id();
      auto prev = all_reduce->mutable_operand(0);
      auto next = all_reduce->users()[0];
      TF_CHECK_OK(all_reduce->ReplaceUseWith(next, prev));
      TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce));
      while (!next->IsCrossReplicaAllReduce()) {
        switch (next->opcode()) {
          case HloOpcode::kBitcast:
          case HloOpcode::kTranspose:
          case HloOpcode::kReshape:
          case HloOpcode::kConvert:
          case HloOpcode::kMultiply:
            break;
          case HloOpcode::kAdd:
          case HloOpcode::kSubtract: {
            auto other_operand = (next->operands()[0] == prev)
                                     ? next->operands()[1]
                                     : next->operands()[0];
            // To move the AR past the addition/subtraction, we need to divide
            // other_operand by the number of spatial partitions, except if
            // other_operand is a cross-module AR, which can be eliminated.
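            // Concretely (a sketch, where `c` stands for the value of
            // other_operand, provably equal across the N partitions): the
            // combined all-reduce would sum N copies of c instead of one, so
            // dividing c by N beforehand keeps the result unchanged.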
            if (other_operand->IsCrossModuleAllReduce() &&
                other_operand->user_count() == 1) {
              TF_CHECK_OK(other_operand->ReplaceAllUsesWith(
                  other_operand->mutable_operand(0)));
            } else {
              auto shape = other_operand->shape();
              Literal lit(shape);
              lit.PopulateWithValue<float>(num_spatial_partitions_);
              auto divisor = parent_computation->AddInstruction(
                  HloInstruction::CreateConstant(lit.Clone()));
              auto division = parent_computation->AddInstruction(
                  HloInstruction::CreateBinary(shape, HloOpcode::kDivide,
                                               other_operand, divisor));
              TF_CHECK_OK(other_operand->ReplaceUseWith(next, division));
            }
            break;
          }
          default:
            LOG(FATAL) << "Unexpected instruction: " << next->ToShortString();
        }
        prev = next;
        next = next->users()[0];
      }
      // The AllReduce and the CRS are combined into an all-core AllReduce.
      //
      // Note that we can just reuse the ReplicaGroup config of the
      // cross-replica all-reduce, since we already checked that the
      // cross-partition all-reduce is always across all partitions
      // (HasCombinableReplicaGroup). We would need to combine ReplicaGroup
      // configs using global ids here if we relaxed that restriction.
      next->set_channel_id(channel_id);
    }
  }
  return true;
}

StatusOr<bool> ArCrsCombiner::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
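  // Overall flow: (1) find AR->...->CRS chains, (2) drop groups that cannot
  // be proven equal across partitions, (3) merge each remaining AR into its
  // CRS, and (4) in SPMD mode, additionally rewrite replicated cross-replica
  // all-reduces as global ones.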
  call_graph_ = CallGraph::Build(module);

  GroupAllReducesById(module);

  if (spmd_partition_) {
    TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsSPMD(module));
  } else {
    TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsMPMD());
  }

  TF_ASSIGN_OR_RETURN(auto changed, RewriteGraph());

  if (module->config().replica_count() > 1 && spmd_partition_) {
    TF_ASSIGN_OR_RETURN(auto replaced, ReplaceReplicatedAllReduce(
                                           module, num_spatial_partitions_));
    changed |= replaced;
  }

  return changed;
}

}  // namespace xla