/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/ar_crs_combiner.h"

#include <algorithm>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/hlo_replication_analysis.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {
namespace {

// In SPMD mode, if a cross-replica all-reduce produces the same value for
// all partitions, replaces it with a global all-reduce followed by a
// division by the number of partitions. Depending on the topology and the
// backend's implementation of the all-reduce, this may give better
// performance.
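//
// A sketch of the transformation on illustrative HLO (hypothetical names,
// assuming partition_count=2): the replicated cross-replica all-reduce
//   %ar = f32[8] all-reduce(%x), to_apply=add            // no channel_id
// becomes a global all-reduce whose result is divided by the partition count:
//   %ar = f32[8] all-reduce(%x), channel_id=1, to_apply=add
//   %c = f32[] constant(2)
//   %b = f32[8] broadcast(%c), dimensions={}
//   %result = f32[8] divide(%ar, %b)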
StatusOr<bool> ReplaceReplicatedAllReduce(HloModule* module,
                                          int64_t partition_count) {
  TF_ASSIGN_OR_RETURN(
      auto replication_analysis,
      HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true));

  bool changed = false;
  int64_t next_channel = hlo_query::NextChannelId(*module);
  for (auto computation : module->computations()) {
    for (auto instruction : computation->instructions()) {
      if (auto ar = DynCast<HloAllReduceInstruction>(instruction)) {
        const Shape& shape = ar->shape();
        if (ar->channel_id()) {
          continue;
        }
        if (ar->replica_groups().size() > 1) {
          continue;
        }
        if (shape.IsTuple() || shape.element_type() != F32) {
          continue;
        }
        // We would need a cost model for the target, but in general we want to
        // rewrite only if the replica count in the original op was large.
        if (module->config().replica_count() < 8 * partition_count) {
          continue;
        }
        if (replication_analysis->HloInstructionIsReplicatedAt(ar, {})) {
          VLOG(2) << "Replaced replicated all-reduce:" << ar->ToString();
          ar->set_channel_id(next_channel++);
          auto divisor =
              computation->AddInstruction(HloInstruction::CreateConstant(
                  LiteralUtil::CreateR0<float>(partition_count)));
          auto bcast = computation->AddInstruction(
              HloInstruction::CreateBroadcast(shape, divisor, {}));
          auto div = computation->AddInstruction(HloInstruction::CreateBinary(
              ar->shape(), HloOpcode::kDivide, ar, bcast));
          TF_RETURN_IF_ERROR(ar->ReplaceAllUsesWith(div));
          changed = true;
        }
      }
    }
  }
  return changed;
}

// Returns true if the given instruction (must be a cross-partition all-reduce)
// has a ReplicaGroup config that can be combined with cross-replica all-reduce.
// We currently restrict to those groups where all partitions in each replica
// belong to the same group.
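//
// For example (hypothetical config, num_partitions=2, replica_count=2): with
// use_global_device_ids, global id = replica_id * num_partitions +
// partition_id, so the groups {{0,1},{2,3}} are combinable, since each group
// contains exactly the partitions of one replica.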
bool HasCombinableReplicaGroup(HloInstruction* hlo, int64_t num_partitions) {
  auto all_reduce = Cast<HloAllReduceInstruction>(hlo);
  auto replica_groups = all_reduce->replica_groups();
  const int64_t replica_count = hlo->GetModule()->config().replica_count();
  CHECK(all_reduce->IsCrossModuleAllReduce());

  if (all_reduce->use_global_device_ids()) {
    if (replica_groups.size() != replica_count) {
      return false;
    }
    for (const auto& group : replica_groups) {
      if (group.replica_ids_size() != num_partitions) {
        return false;
      }
      absl::flat_hash_set<int64_t> partition_ids;
      int64_t replica_id = group.replica_ids(0) / num_partitions;
      for (int64_t i = 0; i < num_partitions; ++i) {
        if (group.replica_ids(i) / num_partitions != replica_id) {
          return false;
        }
        partition_ids.insert(group.replica_ids(i) % num_partitions);
      }
      if (partition_ids.size() != num_partitions) {
        return false;
      }
    }
    return true;
  }

  return replica_groups.size() == replica_count;
}

}  // namespace

namespace m = match;

// Checks if the argument instruction is an AllReduce, followed by a certain
// sequence of instructions and then a CRS. It must be possible to move
// the AR past each instruction in the sequence.
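//
// An illustrative match (hypothetical HLO): the pair (AR, CRS) below is
// matched with distance 2, since the AR has to move past the convert:
//   %ar  = bf16[8] all-reduce(%p), channel_id=1, to_apply=add  // cross-module
//   %cvt = f32[8] convert(%ar)
//   %crs = f32[8] all-reduce(%cvt), to_apply=add               // cross-replica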
std::optional<ArCrsCombiner::ArCrsPair> ArCrsCombiner::MatchesArCrsPattern(
    HloInstruction* instruction) {
  auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool {
    if (instruction->user_count() != 1) {
      return false;
    }
    switch (instruction->opcode()) {
      case HloOpcode::kBitcast:
      case HloOpcode::kTranspose:
      case HloOpcode::kReshape:
        return true;
      case HloOpcode::kConvert:
        // Can be moved across if both input and output are either floating
        // point or integer (e.g. S32<->U32 or F32<->BF16).
        return ShapeUtil::ElementIsFloating(instruction->shape()) ==
               ShapeUtil::ElementIsFloating(instruction->operand(0)->shape());
      case HloOpcode::kAdd:
      case HloOpcode::kSubtract:
      case HloOpcode::kMultiply:
        // Only supported for floating point operands.
        return ShapeUtil::ElementIsFloating(instruction->shape());
      default:
        return false;
    }
  };

  auto computation_is_addition = [](HloComputation* c) {
    return c->instruction_count() == 3 &&
           Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter()));
  };

  // We only support combining cross-partition all-reduces where each replica
  // belongs to its own group, since the later cross-replica all-reduce
  // combines along the replica dimension.
  if (instruction->IsCrossModuleAllReduce() &&
      HasCombinableReplicaGroup(instruction, num_spatial_partitions_) &&
      computation_is_addition(instruction->called_computations()[0]) &&
      instruction->user_count() == 1) {
    auto next = instruction->users()[0];
    int64_t distance = 1;
    while (!next->IsCrossReplicaAllReduce()) {
      if (can_ar_move_past_instruction(next)) {
        next = next->users()[0];
      } else {
        return std::nullopt;
      }
      ++distance;
    }
    if (!Cast<HloAllReduceInstruction>(next)->IsNoop() &&
        computation_is_addition(next->called_computations()[0])) {
      ArCrsPair pair(instruction, next, distance);
      VLOG(2) << "ArCrsPair matching pattern: " << pair.ToString();
      return pair;
    }
  }
  return std::nullopt;
}

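// If the computation that `instruction` is a parameter of is called by exactly
// one instruction and that caller is a kWhile, returns the while instruction;
// otherwise returns nullopt.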
std::optional<HloInstruction*> ArCrsCombiner::WhileFromBodyParameter(
    HloInstruction* instruction) {
  CHECK_EQ(HloOpcode::kParameter, instruction->opcode());
  HloComputation* computation = instruction->parent();
  auto caller_instructions = call_graph_->GetComputationCallers(computation);
  if (caller_instructions.size() == 1) {
    auto caller_instruction = caller_instructions[0];
    if (caller_instruction->opcode() == HloOpcode::kWhile) {
      return caller_instruction;
    }
  }
  return std::nullopt;
}

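// If the computation that `instruction` is a parameter of is called by exactly
// one instruction and that caller is a kConditional, returns the conditional
// instruction; otherwise returns nullopt.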
std::optional<HloInstruction*> ArCrsCombiner::ConditionalFromBodyParameter(
    HloInstruction* instruction) {
  CHECK_EQ(HloOpcode::kParameter, instruction->opcode());
  HloComputation* computation = instruction->parent();
  auto caller_instructions = call_graph_->GetComputationCallers(computation);
  if (caller_instructions.size() == 1) {
    auto caller_instruction = caller_instructions[0];
    if (caller_instruction->opcode() == HloOpcode::kConditional) {
      return caller_instruction;
    }
  }
  return std::nullopt;
}

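// Returns the kTuple instructions that can feed `instruction`'s value, looking
// through kDomain, kGetTupleElement, while loops, and conditionals. Returns
// nullopt if the value can originate from something other than a tuple.
// `visited` breaks the recursion on cycles such as while-loop back edges.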
std::optional<std::vector<HloInstruction*>> ArCrsCombiner::GetAllTuples(
    HloInstruction* instruction,
    absl::flat_hash_set<HloInstruction*>* visited) {
  if (visited->find(instruction) != visited->end()) {
    return std::vector<HloInstruction*>();
  }
  visited->insert(instruction);

  switch (instruction->opcode()) {
    case HloOpcode::kTuple: {
      return std::vector<HloInstruction*>({instruction});
    }
    case HloOpcode::kDomain: {
      return GetAllTuples(instruction->operands()[0], visited);
    }
    case HloOpcode::kParameter: {
      auto maybe_while = WhileFromBodyParameter(instruction);
      if (maybe_while) {
        auto while_instr = *maybe_while;
        auto init_tuples = GetAllTuples(while_instr->while_init(), visited);
        auto body_tuples = GetAllTuples(
            while_instr->while_body()->root_instruction(), visited);
        if (!init_tuples || !body_tuples) {
          return std::nullopt;
        }
        auto result = *init_tuples;
        result.insert(result.end(), body_tuples->begin(), body_tuples->end());
        return result;
      }
      auto maybe_conditional = ConditionalFromBodyParameter(instruction);
      if (maybe_conditional) {
        auto cond_instr = *maybe_conditional;
        std::vector<HloInstruction*> tuples;
        for (int64_t i = 0; i < cond_instr->branch_computations().size(); ++i) {
          if (cond_instr->branch_computation(i)->parameter_instruction(0) ==
              instruction) {
            // If the same computation is used for more than one branch of the
            // conditional, we collect the arguments that flow to the
            // computation from all branches.
            auto branch_tuples =
                GetAllTuples(cond_instr->mutable_operand(i + 1), visited);
            if (!branch_tuples) {
              return std::nullopt;
            }
            tuples.insert(tuples.end(), branch_tuples->begin(),
                          branch_tuples->end());
          }
        }
        return tuples;
      }
      return std::nullopt;
    }
    case HloOpcode::kGetTupleElement: {
      std::vector<HloInstruction*> result_tuples;
      auto tuples = GetAllTuples(instruction->operands()[0], visited);
      if (!tuples) {
        return std::nullopt;
      }
      for (auto tuple : *tuples) {
        auto tmp_tuples = GetAllTuples(
            tuple->mutable_operand(instruction->tuple_index()), visited);
        if (!tmp_tuples) {
          return std::nullopt;
        }
        result_tuples.insert(result_tuples.end(), tmp_tuples->begin(),
                             tmp_tuples->end());
      }
      return result_tuples;
    }
    case HloOpcode::kConditional: {
      std::vector<HloInstruction*> result_tuples;
      const auto& branch_computations = instruction->branch_computations();
      result_tuples.reserve(branch_computations.size());
      for (HloComputation* body : branch_computations) {
        if (body->root_instruction()->opcode() != HloOpcode::kTuple) {
          return std::nullopt;
        }
        result_tuples.push_back(body->root_instruction());
      }
      return result_tuples;
    }
    case HloOpcode::kWhile: {
      auto init_tuples = GetAllTuples(instruction->while_init(), visited);
      auto body_tuples =
          GetAllTuples(instruction->while_body()->root_instruction(), visited);
      if (!init_tuples || !body_tuples) {
        return std::nullopt;
      }
      auto result = *init_tuples;
      result.insert(result.end(), body_tuples->begin(), body_tuples->end());
      return result;
    }
    default:
      return std::nullopt;
  }
}

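// Returns true if, for every tuple that can flow into
// `tuple_shaped_instruction`, element i1 provably computes the same value as
// element i2.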
bool ArCrsCombiner::TupleElementsComputeSameValue(
    HloInstruction* tuple_shaped_instruction, int64_t i1, int64_t i2,
    absl::flat_hash_map<int64_t, int64_t>* visited_pairs) {
  absl::flat_hash_set<HloInstruction*> visited;
  auto tuples = GetAllTuples(tuple_shaped_instruction, &visited);
  if (!tuples) {
    return false;
  }
  for (auto tuple : *tuples) {
    CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
    if (!InstructionsComputeSameValue(tuple->mutable_operand(i1),
                                      tuple->mutable_operand(i2),
                                      visited_pairs)) {
      return false;
    }
  }
  return true;
}

/* static */
bool ArCrsCombiner::TestInstructionsComputeSameValue(HloInstruction* i1,
                                                     HloInstruction* i2) {
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/false);
  auto module = i1->parent()->parent();
  CHECK_EQ(module, i2->parent()->parent());
  combiner.call_graph_ = CallGraph::Build(module);
  absl::flat_hash_map<int64_t, int64_t> visited_pairs;
  return combiner.InstructionsComputeSameValue(i1, i2, &visited_pairs);
}

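// Returns true if i1 and i2 provably compute the same value. `visited_pairs`
// records pairs of unique ids that are already assumed equal, which both
// memoizes results and breaks cycles through while loops.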
bool ArCrsCombiner::InstructionsComputeSameValue(
    HloInstruction* i1, HloInstruction* i2,
    absl::flat_hash_map<int64_t, int64_t>* visited_pairs) {
  if (i1 == i2) {
    return true;
  }
  auto uid1 = i1->unique_id();
  auto uid2 = i2->unique_id();
  auto min_uid = std::min(uid1, uid2);
  auto max_uid = std::max(uid1, uid2);
  auto it = visited_pairs->find(min_uid);
  if (it != visited_pairs->end() && max_uid == it->second) {
    return true;
  }
  auto opcode1 = i1->opcode();
  auto operands1 = i1->operands();
  if (opcode1 != i2->opcode() || operands1.size() != i2->operands().size()) {
    return false;
  }
  auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
    return *a == *b;
  };
  // Two MPMD AllReduces are identical if they have the same channel_id. Their
  // operands don't have to be identical.
  auto eq_operands = [](const HloInstruction*, const HloInstruction*) {
    return true;
  };
  if (i1->IsCrossModuleAllReduce()) {
    return i1->Identical(*i2, eq_operands, eq_computations,
                         /*layout_sensitive=*/false);
  }
  visited_pairs->emplace(min_uid, max_uid);
  for (int i = 0; i < operands1.size(); ++i) {
    auto operand1 = operands1[i];
    auto operand2 = i2->operands()[i];
    if (!InstructionsComputeSameValue(operand1, operand2, visited_pairs)) {
      return false;
    }
  }
  if (opcode1 == HloOpcode::kParameter) {
    // In the general case, we don't try to prove equality of parameters.
    // We only try in the context of get-tuple-element
    // (see TupleElementsComputeSameValue).
    return false;
  }
  if (opcode1 == HloOpcode::kGetTupleElement) {
    return i1->tuple_index() == i2->tuple_index() ||
           TupleElementsComputeSameValue(operands1[0], i1->tuple_index(),
                                         i2->tuple_index(), visited_pairs);
  }
  // Don't check that the operands are identical, because Identical can
  // return false for instructions that compute the same value but are not
  // identical, which we don't want. We have checked the arguments with
  // InstructionsComputeSameValue earlier.
  auto eq_instructions = [](const HloInstruction* i1,
                            const HloInstruction* i2) -> bool { return true; };
  return i1->Identical(*i2, eq_instructions, eq_computations,
                       /*layout_sensitive=*/false);
}

void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
  // Say that two or more ARs lead to the same CRS: (AR1, CRS), (AR2, CRS),
  // ... , (ARn, CRS).
  // If, as we traverse the HLO graph, we start tracking the pair (AR2, CRS),
  // and later find that AR1's distance from the CRS is longer, we discard
  // AR2 and start tracking AR1. We put the discarded ids in this set, in
  // order to skip processing of the shorter paths when we encounter the
  // other ARs that have the same id as AR2.
  absl::flat_hash_set<int64_t> discarded_ar_ids;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      auto maybe_pair = MatchesArCrsPattern(instruction);
      if (maybe_pair) {
        auto pair = *maybe_pair;
        int64_t ar_id = *(instruction->channel_id());
        if (discarded_ar_ids.find(ar_id) != discarded_ar_ids.end()) {
          continue;
        }
        auto it = crs_reserved_map_.find(pair.crs);
        if (it != crs_reserved_map_.end()) {
          auto prev_ar_id = it->second;
          // Since there is another AR paired with this CRS,
          // all_reduce_map_[prev_ar_id] should exist, but
          // all_reduce_map_[ar_id] shouldn't.
          CHECK(all_reduce_map_.find(ar_id) == all_reduce_map_.end());
          CHECK_NE(prev_ar_id, ar_id);
          auto prev_pair = all_reduce_map_[prev_ar_id].back();
          int64_t prev_distance = prev_pair.distance;
          if (prev_distance < pair.distance) {
            // The current AR's distance to the CRS is longer than the
            // previously tracked AR's, so we discard the previous AR.
            VLOG(2) << "Replacing ArCrsPair: " << prev_pair.ToString()
                    << " with ArCrsPair: " << pair.ToString();
            all_reduce_map_.erase(prev_ar_id);
            discarded_ar_ids.insert(prev_ar_id);
            all_reduce_map_[ar_id].push_back(pair);
            crs_reserved_map_[pair.crs] = ar_id;
          } else {
            // Discard the current AR id because we are keeping the previously
            // tracked AR.
            discarded_ar_ids.insert(ar_id);
          }
        } else {
          if (all_reduce_map_.find(ar_id) != all_reduce_map_.end()) {
            int64_t prev_distance = all_reduce_map_[ar_id].back().distance;
            CHECK_EQ(prev_distance, pair.distance)
                << "All ARs with the same AR ID must have the same distance "
                   "from the corresponding CRSs. Found: "
                << prev_distance << " and " << pair.distance;
          }
          all_reduce_map_[ar_id].push_back(pair);
          crs_reserved_map_[pair.crs] = ar_id;
        }
      }
    }
  }
}

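// For MPMD mode: for each group, walks the path from every AR to its CRS in
// lockstep with the first AR's path, and erases the group unless the
// corresponding instructions provably compute the same value.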
Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsMPMD() {
  for (auto it = all_reduce_map_.begin(); it != all_reduce_map_.end();) {
    auto copy_it = it++;  // Advance `it` before invalidation from erase.
    auto channel_id = copy_it->first;
    VLOG(2)
        << "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: "
        << channel_id << "\n";
    auto pairs_vec = copy_it->second;
    TF_RET_CHECK(pairs_vec.size() == num_spatial_partitions_);
    auto instr_0 = pairs_vec[0].ar;
    bool erased = false;
    for (int i = 1; !erased && i < pairs_vec.size(); ++i) {
      auto instr_i = pairs_vec[i].ar;
      auto next_0 = instr_0->users()[0];
      auto next_i = instr_i->users()[0];
      absl::flat_hash_map<int64_t, int64_t> visited_pairs;
      while (true) {
        if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) {
          all_reduce_map_.erase(copy_it);
          // Stop comparing the remaining partitions: `copy_it` is no longer
          // valid, and erasing it a second time would be undefined behavior.
          erased = true;
          VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased AllReduce "
                     "channel id: "
                  << channel_id << "\n";
          break;
        }
        if (next_0->IsCrossReplicaAllReduce()) {
          break;
        }
        next_0 = next_0->users()[0];
        next_i = next_i->users()[0];
      }
    }
  }
  return OkStatus();
}

Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsSPMD(
    HloModule* module) {
  // For SPMD mode, use HloReplicationAnalysis to figure out HLO value
  // equivalence across partitions.
  TF_ASSIGN_OR_RETURN(
      auto replication_analysis,
      HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true));

  for (auto it = all_reduce_map_.begin(); it != all_reduce_map_.end();) {
    auto copy_it = it++;  // Advance `it` before invalidation from erase.
    auto channel_id = copy_it->first;
    VLOG(2)
        << "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: "
        << channel_id << "\n";
    auto pairs_vec = copy_it->second;
    TF_RET_CHECK(pairs_vec.size() == 1);
    auto instr = pairs_vec[0].ar;
    auto next = instr->users()[0];
    while (true) {
      // The patterns we detect in ArCrsCombiner::MatchesArCrsPattern()
      // guarantee that the HLO produces an array.
      TF_RET_CHECK(next->shape().IsArray());
      if (!replication_analysis->HloInstructionIsReplicatedAt(next, {})) {
        all_reduce_map_.erase(copy_it);
        VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased AllReduce "
                   "channel id: "
                << channel_id << "\n";
        break;
      }
      if (next->IsCrossReplicaAllReduce()) {
        break;
      }
      next = next->users()[0];
    }
  }
  return OkStatus();
}

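// Combines each surviving (AR, CRS) pair into a single all-core all-reduce.
// An illustrative rewrite on hypothetical HLO, with num_spatial_partitions_=2:
//   %ar  = f32[8] all-reduce(%p), channel_id=1, to_apply=add   // cross-module
//   %add = f32[8] add(%ar, %y)
//   %crs = f32[8] all-reduce(%add), to_apply=add               // cross-replica
// becomes
//   %two = f32[8] constant(...)       // filled with 2.0
//   %div = f32[8] divide(%y, %two)
//   %add = f32[8] add(%p, %div)
//   %crs = f32[8] all-reduce(%add), channel_id=1, to_apply=add // all-core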
StatusOr<bool> ArCrsCombiner::RewriteGraph() {
  if (all_reduce_map_.empty()) {
    return false;
  }
  for (const auto& it : all_reduce_map_) {
    auto pairs_vec = it.second;
    for (auto pair : pairs_vec) {
      auto all_reduce = pair.ar;
      auto parent_computation = all_reduce->parent();
      auto channel_id = all_reduce->channel_id();
      auto prev = all_reduce->mutable_operand(0);
      auto next = all_reduce->users()[0];
      TF_CHECK_OK(all_reduce->ReplaceUseWith(next, prev));
      TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce));
      while (!next->IsCrossReplicaAllReduce()) {
        switch (next->opcode()) {
          case HloOpcode::kBitcast:
          case HloOpcode::kTranspose:
          case HloOpcode::kReshape:
          case HloOpcode::kConvert:
          case HloOpcode::kMultiply:
            break;
          case HloOpcode::kAdd:
          case HloOpcode::kSubtract: {
            auto other_operand = (next->operands()[0] == prev)
                                     ? next->operands()[1]
                                     : next->operands()[0];
            // To move the AR past the addition/subtraction, we need to divide
            // other_operand by the number of spatial partitions, except if
            // other_operand is a cross-module AR, which can be eliminated.
            if (other_operand->IsCrossModuleAllReduce() &&
                other_operand->user_count() == 1) {
              TF_CHECK_OK(other_operand->ReplaceAllUsesWith(
                  other_operand->mutable_operand(0)));
            } else {
              auto shape = other_operand->shape();
              Literal lit(shape);
              lit.PopulateWithValue<float>(num_spatial_partitions_);
              auto divisor = parent_computation->AddInstruction(
                  HloInstruction::CreateConstant(lit.Clone()));
              auto division = parent_computation->AddInstruction(
                  HloInstruction::CreateBinary(shape, HloOpcode::kDivide,
                                               other_operand, divisor));
              TF_CHECK_OK(other_operand->ReplaceUseWith(next, division));
            }
            break;
          }
          default:
            LOG(FATAL) << "Unexpected instruction: " << next->ToShortString();
        }
        prev = next;
        next = next->users()[0];
      }
      // The AllReduce and the CRS are combined into an all-core AllReduce.
      //
      // Note that we can just reuse the ReplicaGroup config of the
      // cross-replica all-reduce, since we already checked that the
      // cross-partition all-reduce is always across all partitions
      // (HasCombinableReplicaGroup). We would need to combine ReplicaGroup
      // configs using global ids here if we relaxed that restriction.
      next->set_channel_id(channel_id);
    }
  }
  return true;
}

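// The pass: build the call graph, group matching (AR, CRS) pairs by channel
// id, drop the groups that cannot be proven equivalent across partitions
// (MPMD or SPMD proof, depending on the mode), rewrite the surviving pairs,
// and, in SPMD mode, additionally replace replicated cross-replica
// all-reduces with global all-reduces.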
StatusOr<bool> ArCrsCombiner::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  call_graph_ = CallGraph::Build(module);

  GroupAllReducesById(module);

  if (spmd_partition_) {
    TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsSPMD(module));
  } else {
    TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsMPMD());
  }

  TF_ASSIGN_OR_RETURN(auto changed, RewriteGraph());

  if (module->config().replica_count() > 1 && spmd_partition_) {
    TF_ASSIGN_OR_RETURN(auto replaced, ReplaceReplicatedAllReduce(
                                           module, num_spatial_partitions_));
    changed |= replaced;
  }

  return changed;
}

}  // namespace xla