/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_

#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// When the HLO graph contains a cross-module AllReduce (CMAR; N separate
// AllReduce ops that share the same channel_id for MPMD partitioning, or 1
// AllReduce op for SPMD partitioning), followed by some simple linear
// operations, followed by a cross-replica AllReduce (CRAR, also known as
// cross-replica sum, or CRS), we can combine the CMAR and the CRAR into a
// single, efficient AllReduce that fully utilizes the interconnect bandwidth.
//
// Such sequences appear in spatially partitioned models (either MPMD or SPMD).
// This pass must run right after spatial partitioning, when the code is still
// in a single HLO module.
//
// The steps are:
// 1) Find CMARs followed by simple ops followed by CRARs.
// 2) Group CMARs by channel_id. They must all be rewritten. For SPMD
//    partitioning, there will only be a single CMAR for each channel_id.
// 3) Prove that the CMAR patterns in each core produce the same result.
// 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the
//    other operand by the number of spatial partitions (see the sketch below).
// 5) Turn the CRAR into an all-core AllReduce.
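//
// For intuition on steps 4 and 5, here is a small worked sum (a sketch only;
// the names X_i and Z are illustrative, with X_i the CMAR operand on
// partition i, Z an operand added after the CMAR and assumed identical on
// every partition, and N = num_spatial_partitions). Before the rewrite,
// every partition holds
//
//   CMAR(X) + Z = (X_1 + ... + X_N) + Z
//
// and feeds that value to the CRAR. After the rewrite, partition i instead
// feeds X_i + Z/N to the all-core AllReduce; summing the N per-partition
// contributions gives (X_1 + ... + X_N) + N * (Z/N), i.e. the same value,
// now produced by a single reduction over all cores.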
//
// The pass also handles the case where multiple CMARs lead to the same CRAR,
// and eliminates all CMARs. This graph:
//
//        Y
//        |
//  X   CMAR_2   Z
//  |      \    /
// CMAR_1     +
//    \     /
//       +
//       |
//     CRAR
//
// gets rewritten to:
//
//           Z   num_partitions
//            \  /
//       Y    div
//        \   /
//    X     +
//     \   /
//       +
//       |
//  all-core AR
//
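// Example of registering the pass in a compilation pipeline (a minimal
// sketch, not taken from any particular backend; the pipeline name and
// partition count are illustrative):
//
//   HloPassPipeline pipeline("post-spatial-partitioning");
//   pipeline.AddPass<ArCrsCombiner>(/*num_spatial_partitions=*/8,
//                                   /*spmd_partition=*/true);
//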
class ArCrsCombiner : public HloModulePass {
 public:
  ArCrsCombiner(int num_spatial_partitions, bool spmd_partition)
      : num_spatial_partitions_(num_spatial_partitions),
        spmd_partition_(spmd_partition) {}
  absl::string_view name() const override { return "ar-crs-combiner"; }
  using HloPassInterface::Run;
  StatusOr<bool> Run(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads) override;

  // Helper method to allow testing of InstructionsComputeSameValue.
  static bool TestInstructionsComputeSameValue(HloInstruction* i1,
                                               HloInstruction* i2);

 private:
  // We use this struct because multiple ARs can be paired with the same CRS.
  // In that case, we want to select the AR that is furthest from the CRS,
  // because that makes it easier to eliminate all ARs during RewriteGraph.
  struct ArCrsPair {
    HloInstruction* ar;
    HloInstruction* crs;
    // The length of the path from AR to CRS in the HLO graph.
    int64_t distance;

    ArCrsPair(HloInstruction* all_reduce, HloInstruction* cross_replica_sum,
              int64_t dist)
        : ar(all_reduce), crs(cross_replica_sum), distance(dist) {}

    std::string ToString() {
      std::vector<std::string> pieces;
      pieces.push_back("(");
      HloInstruction* instruction = ar;
      // Walk from the AR towards the CRS, following the first user at each
      // step, and record the names of the instructions along the way.
      while (instruction != crs) {
        pieces.push_back(instruction->name());
        pieces.push_back(",");
        instruction = instruction->users()[0];
      }
      pieces.push_back(instruction->name());
      pieces.push_back(")[id:");
      pieces.push_back(std::to_string(*(ar->channel_id())));
      pieces.push_back(",dist:");
      pieces.push_back(std::to_string(distance));
      pieces.push_back("]");
      return absl::StrJoin(pieces, "");
    }
  };

  // If `instruction` is a cross-module AllReduce that starts an
  // AR -> (simple ops) -> CRS pattern, returns the matched ArCrsPair;
  // otherwise returns std::nullopt.
  std::optional<ArCrsCombiner::ArCrsPair> MatchesArCrsPattern(
      HloInstruction* instruction);

  // If the passed instruction is a while parameter, and the while body is only
  // called by a single while instruction, return the while instruction.
  std::optional<HloInstruction*> WhileFromBodyParameter(
      HloInstruction* instruction);

  // If the passed instruction is a parameter in one of the branch computations,
  // and the branch body is only called by a single instruction, return the
  // conditional instruction.
  std::optional<HloInstruction*> ConditionalFromBodyParameter(
      HloInstruction* instruction);

  // If all instructions that flow to "instruction" are tuples, returns them as
  // a vector; otherwise returns std::nullopt. Returns an empty vector if the
  // instruction is already in the visited set.
  std::optional<std::vector<HloInstruction*>> GetAllTuples(
      HloInstruction* instruction,
      absl::flat_hash_set<HloInstruction*>* visited);

  // Checks whether two different elements in the same tuple compute the same
  // value.
  bool TupleElementsComputeSameValue(
      HloInstruction* tuple_shaped_instruction, int64_t i1, int64_t i2,
      absl::flat_hash_map<int64_t, int64_t>* visited_pairs);

  // Returns whether the instructions i1 and i2 can be shown to evaluate to the
  // same value. Handling WHILE requires recursion, which may cause us to visit
  // the same instruction again. To avoid infinite loops, we pass a cache of
  // visited instruction pairs.
  bool InstructionsComputeSameValue(
      HloInstruction* i1, HloInstruction* i2,
      absl::flat_hash_map<int64_t, int64_t>* visited_pairs);

  // Populates all_reduce_map_.
  void GroupAllReducesById(HloModule* module);

  // Looks at each AllReduce group in all_reduce_map_, and keeps only the
  // groups for which it's safe to move the AllReduce later in the HLO graph.
  Status KeepProvablyEqualInstructionGroupsMPMD();

  // Same as above, but runs on an SPMD-partitioned module instead of an
  // MPMD-partitioned one.
  Status KeepProvablyEqualInstructionGroupsSPMD(HloModule* module);

  // Performs the graph rewrite that eliminates the early AllReduce and turns
  // the later CRS into an all-core AllReduce.
  StatusOr<bool> RewriteGraph();

  int num_spatial_partitions_;

  // Run this combiner pass assuming the input module is SPMD partitioned
  // (as opposed to MPMD partitioned).
  //
  // The main difference between the two w.r.t. this pass is that in MPMD mode
  // there are N all-reduce ops for each channel, whereas in SPMD mode there is
  // only 1 for each channel. Also, in SPMD mode we use HloReplicationAnalysis
  // for the HLO equivalence check.
  bool spmd_partition_;

  // Map from all-reduce IDs to the AR/CRS pairs.
  absl::flat_hash_map<int64_t, std::vector<ArCrsPair>> all_reduce_map_;

  // Map from a CRS instruction to the all-reduce ID of the AR paired with the
  // CRS. Sometimes, several ARs in the code could be paired with the same CRS.
  // We use this map to pick a single AR/CRS path to rewrite.
  absl::flat_hash_map<HloInstruction*, int64_t> crs_reserved_map_;

  std::unique_ptr<CallGraph> call_graph_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_