/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_

#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// When the HLO graph contains a cross-module AllReduce (CMAR; N separate
// AllReduce ops that share the same channel_id for MPMD partitioning, or 1
// AllReduce op for SPMD partitioning), followed by some simple linear
// operations, followed by a cross-replica AllReduce (CRAR, also known as
// cross-replica sum, or CRS), we can combine the CMAR and the CRAR to use an
// efficient AllReduce implementation that fully utilizes the interconnect
// bandwidth.
//
// Such sequences appear in spatially partitioned models (either MPMD or SPMD).
// This pass must run right after spatial partitioning, when the code is still
// in a single HLO module.
//
// The steps are:
//   1) Find CMARs followed by simple ops followed by CRARs.
//   2) Group CMARs by channel_id. They must all be rewritten. For SPMD
//      partitioning, there will only be a single CMAR for each channel_id.
//   3) Prove that the CMAR patterns in each core produce the same result.
//   4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the
//      other operand by the number of spatial partitions.
//   5) Turn the CRAR into an all-core AllReduce.
//
// The pass also handles the case where multiple CMARs lead to the same CRAR,
// and eliminates all CMARs. This graph:
//
//        Y
//        |
//  X   CMAR_2   Z
//  |      \    /
// CMAR_1    +
//     \    /
//       +
//       |
//      CRAR
//
// gets rewritten to:
//
//           Z   num_partitions
//            \  /
//       Y    div
//        \   /
//    X    +
//     \  /
//      +
//      |
//  all-core AR
//
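// For illustration, a minimal sketch of how the pass might be scheduled; the
// pipeline name, partition count, and surrounding setup are placeholders, not
// part of this header:
//
//   HloPassPipeline pipeline("post-spatial-partitioning");
//   pipeline.AddPass<ArCrsCombiner>(/*num_spatial_partitions=*/4,
//                                   /*spmd_partition=*/true);
//   TF_RETURN_IF_ERROR(pipeline.Run(module).status());
//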
class ArCrsCombiner : public HloModulePass {
 public:
  ArCrsCombiner(int num_spatial_partitions, bool spmd_partition)
      : num_spatial_partitions_(num_spatial_partitions),
        spmd_partition_(spmd_partition) {}

  absl::string_view name() const override { return "ar-crs-combiner"; }

  using HloPassInterface::Run;
  StatusOr<bool> Run(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads) override;

  // Helper method to allow testing of InstructionsComputeSameValue.
  static bool TestInstructionsComputeSameValue(HloInstruction* i1,
                                               HloInstruction* i2);

 private:
  // We use this struct because multiple ARs could be paired with the same CRS.
  // In that case, we want to select the AR that is furthest from the CRS,
  // because that makes it easier to eliminate all ARs during RewriteGraph.
  struct ArCrsPair {
    HloInstruction* ar;
    HloInstruction* crs;
    // The length of the path from the AR to the CRS in the HLO graph.
    int64_t distance;

    ArCrsPair(HloInstruction* all_reduce, HloInstruction* cross_replica_sum,
              int64_t dist)
        : ar(all_reduce), crs(cross_replica_sum), distance(dist) {}

    std::string ToString() {
      std::vector<std::string> pieces;
      pieces.push_back("(");
      HloInstruction* instruction = ar;
      while (instruction != crs) {
        pieces.push_back(instruction->name());
        pieces.push_back(",");
        instruction = instruction->users()[0];
      }
      pieces.push_back(instruction->name());
      pieces.push_back(")[id:");
      pieces.push_back(std::to_string(*(ar->channel_id())));
      pieces.push_back(",dist:");
      pieces.push_back(std::to_string(distance));
      pieces.push_back("]");
      return absl::StrJoin(pieces, "");
    }
  };
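
  // For illustration only (hypothetical instruction names, not tied to any
  // real module): for the chain  ar -> add -> crs  with channel_id 5, a pair
  // constructed as ArCrsPair(ar, crs, /*dist=*/2) prints via ToString() as
  //
  //   (ar,add,crs)[id:5,dist:2]
  //
  // and, per the comment above, would be preferred over a competing pair with
  // a smaller distance that targets the same crs.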

  std::optional<ArCrsCombiner::ArCrsPair> MatchesArCrsPattern(
      HloInstruction* instruction);

  // If the passed instruction is a while parameter, and the while body is only
  // called by a single while instruction, return the while instruction.
  std::optional<HloInstruction*> WhileFromBodyParameter(
      HloInstruction* instruction);

  // If the passed instruction is a parameter in one of the branch
  // computations, and the branch body is only called by a single instruction,
  // return the conditional instruction.
  std::optional<HloInstruction*> ConditionalFromBodyParameter(
      HloInstruction* instruction);

  // Returns a vector of tuple instructions.
  // If all instructions that flow to "instruction" are tuples, return them.
  // Otherwise, return std::nullopt. Returns an empty vector if the instruction
  // is already in the visited set.
  std::optional<std::vector<HloInstruction*>> GetAllTuples(
      HloInstruction* instruction,
      absl::flat_hash_set<HloInstruction*>* visited);

  // Checks whether two different elements in the same tuple compute the same
  // value.
  bool TupleElementsComputeSameValue(
      HloInstruction* tuple_shaped_instruction, int64_t i1, int64_t i2,
      absl::flat_hash_map<int64_t, int64_t>* visited_pairs);

  // Returns whether the instructions i1 and i2 can be shown to evaluate to the
  // same value. Handling WHILE requires recursion, which may cause us to visit
  // the same instruction again. To avoid infinite loops, we pass a cache of
  // visited instruction pairs.
  bool InstructionsComputeSameValue(
      HloInstruction* i1, HloInstruction* i2,
      absl::flat_hash_map<int64_t, int64_t>* visited_pairs);

  // Populates all_reduce_map_.
  void GroupAllReducesById(HloModule* module);

  // Looks at each AllReduce group in all_reduce_map_, and keeps only the
  // groups for which it's safe to move the AllReduce later in the HLO graph.
  Status KeepProvablyEqualInstructionGroupsMPMD();

  // Same as above, but runs on an SPMD partitioned module instead of an MPMD
  // partitioned one.
  Status KeepProvablyEqualInstructionGroupsSPMD(HloModule* module);

  // Performs the graph rewrite that eliminates the early AllReduce and turns
  // the later CRS into an AllReduce.
  StatusOr<bool> RewriteGraph();

  int num_spatial_partitions_;

  // Run this combiner pass assuming the input module is an SPMD partitioned
  // module (as opposed to MPMD partitioned).
  //
  // The main difference between the two with respect to this pass is that
  // there are N all-reduce ops for each channel in MPMD mode, whereas there is
  // only 1 for each channel in SPMD mode. Also, we use HloReplicationAnalysis
  // for the HLO equivalence check in SPMD mode.
  bool spmd_partition_;

  // Map from all-reduce ids to the AR/CRS pairs.
  absl::flat_hash_map<int64_t, std::vector<ArCrsPair>> all_reduce_map_;

  // Map from a CRS instruction to the all-reduce ID of the AR paired with the
  // CRS. Sometimes, several ARs in the code could be paired with the same CRS.
  // We use this map to pick a single AR/CRS path to rewrite.
  absl::flat_hash_map<HloInstruction*, int64_t> crs_reserved_map_;

  std::unique_ptr<CallGraph> call_graph_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_