1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_ 18 19 #include <memory> 20 #include <optional> 21 #include <set> 22 #include <string> 23 #include <vector> 24 25 #include "absl/container/flat_hash_map.h" 26 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" 27 #include "tensorflow/compiler/xla/service/hlo_computation.h" 28 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 29 #include "tensorflow/compiler/xla/service/hlo_module.h" 30 #include "tensorflow/compiler/xla/status.h" 31 #include "tensorflow/compiler/xla/statusor.h" 32 #include "tensorflow/core/lib/core/status.h" 33 34 namespace xla { 35 36 // Class for bookkeeping the information on the given modules, in particular on 37 // the interaction between computations. 38 // 39 // Companion instructions are one piece of information collected as we build the 40 // metadata. For example, for each While instruction, companion instructions 41 // refer to a set of While instructions in other computations that communicate 42 // with each other. 43 // In the example below with 3 modules, {While_0, While_2, While_5}, {While_1, 44 // While_4}, {While_3, While_6} are companion sets. 45 // 46 // <Module 0> <Module 1> <Module 2> 47 // While_0() { While_2() { While_5() { 48 // While_1() { Send(0) } While_3() { Send(1) } While_6() { Recv(1) } 49 // } While_4() { Recv(0) } 50 // } 51 // 52 // Each instruction can belong to at most one companion set: While_0 and While_5 53 // are in the same set even though they don't communicate with each other, 54 // because they both communicate with While_2. 55 // 56 // A send and the matching recv must both have the same level of nesting of 57 // companion instructions. 58 // 59 // Companion instructions are used to detect cycles in the graph and also for 60 // global scheduling. 61 class HloModuleGroupMetadata { 62 public: 63 // The kind of companion computation a given instruction can be within. 64 enum class ComputationKind { 65 kInvalid, 66 kWhileCondition, 67 kWhileBody, 68 kConditionalBranch, 69 kCallFunction, 70 }; 71 72 // Tracks the instruction mapped to a given computation, and the computation 73 // kind. 74 // For example, a body computation of a while instruction, will generate a 75 // TrackedInstruction with instruction being the while instruction, and 76 // kind being ComputationKind::kWhileBody. 77 class TrackedInstruction { 78 public: 79 TrackedInstruction() = default; 80 TrackedInstruction(HloInstruction* instruction, ComputationKind kind, 81 int index = -1) instruction_(instruction)82 : instruction_(instruction), kind_(kind), index_(index) {} 83 84 bool operator==(const TrackedInstruction& rhs) const { 85 return instruction_->opcode() == rhs.instruction_->opcode() && 86 kind_ == rhs.kind_ && index_ == rhs.index_; 87 } 88 bool operator!=(const TrackedInstruction& rhs) const { 89 return !operator==(rhs); 90 } 91 instruction()92 HloInstruction* instruction() const { return instruction_; } 93 94 std::string ToString() const; 95 96 private: 97 HloInstruction* instruction_ = nullptr; 98 ComputationKind kind_ = ComputationKind::kInvalid; 99 int index_ = -1; 100 }; 101 102 // Represents a channel and the instructions that form the channel. 103 struct Channel { 104 int64_t id = -1; 105 HloInstruction* send = nullptr; 106 HloInstruction* recv = nullptr; 107 HloInstruction* send_done = nullptr; 108 HloInstruction* recv_done = nullptr; 109 }; 110 HloModuleGroupMetadata(absl::Span<HloModule * const> modules)111 explicit HloModuleGroupMetadata(absl::Span<HloModule* const> modules) 112 : modules_(modules.begin(), modules.end()) {} 113 114 ~HloModuleGroupMetadata() = default; 115 116 // Build and return the metadata for the given modules. 117 static StatusOr<std::unique_ptr<HloModuleGroupMetadata>> Build( 118 absl::Span<HloModule* const> modules); 119 120 // Returns true if the instruction is one of the 4 channel instructions (Send, 121 // Recv, SendDone, RecvDone). 122 bool IsChannelInstruction(const HloInstruction* instruction) const; 123 124 // Returns true if the instruction is a companion instruction. See the class 125 // comment above on companion instructions. 126 bool IsCompanionInstruction(HloInstruction* hlo) const; 127 128 // Returns true if the instruction is either a channel instruction, a 129 // cross-module all-reduce instruction, or a companion instruction. 130 bool InstructionCommunicates(HloInstruction* hlo) const; 131 132 // Returns the Channel instance for the given channel id. 133 const Channel& GetChannel(int64_t channel_id) const; 134 135 // Returns if the given channel id exists in metadata. 136 bool HasChannel(int64_t channel_id) const; 137 138 // Returns the all-reduce instructions with the same channel_id. 139 const std::vector<HloInstruction*>& GetAllReduceGroup( 140 int64_t channel_id) const; 141 142 // Returns the computation that contains the peer channel instructions for 143 // the given instruction. 144 // 145 // Precondition: IsChannelInstruction(instruction) is true. 146 HloComputation* PeerComputation(const HloInstruction* instruction) const; 147 148 // Returns the path of the nested companion instructions, in terms of HLO 149 // instructions. The path goes from inner to outer companions. 150 // The returned path does not include the input hlo instruction, in case it 151 // is a companion instruction. 152 std::vector<TrackedInstruction> GetCompanionsPath( 153 const HloInstruction* hlo) const; 154 155 // Checks whether two companion paths (as returned by the GetCompanionsPath() 156 // API) are compatible. The two paths are compatible if the sequence of 157 // opcodes, and the companion kinds, of the two paths matches. 158 bool CheckCompanionPathsCompatibility( 159 const std::vector<TrackedInstruction>& path0, 160 const std::vector<TrackedInstruction>& path1) const; 161 162 // Returns the unique integer for each module. The returned id is the index of 163 // the module in the module vector. 164 int64_t GetModuleId(const HloModule* module) const; 165 166 // Retrieves the device an instruction is assigned to. Either from the 167 // sharding information, or from the ordinal of the module the instruction 168 // is in. 169 std::optional<int64_t> GetInstructionDevice( 170 const HloInstruction& instruction) const; 171 172 // Returns the number of modules for devices (excluding the host module). 173 int64_t GetDeviceModulesCount() const; 174 175 // Returns the companion set for the given instruction, including the 176 // instruction itself. 177 // 178 // Precondition: IsCompanionWhile(instruction) is true. Companions(const HloInstruction * instruction)179 const std::vector<HloInstruction*>& Companions( 180 const HloInstruction* instruction) const { 181 CHECK(companion_set_index_.contains(instruction)); 182 return companion_set(companion_set_index_.at(instruction)); 183 } 184 185 // Returns the companion set at the given index. companion_set(int64_t index)186 const std::vector<HloInstruction*>& companion_set(int64_t index) const { 187 CHECK_LT(index, companion_sets_.size()); 188 return *companion_sets_[index]; 189 } 190 191 // Returns the companion set index of the given instruction. companion_set_index(HloInstruction * instruction)192 int64_t companion_set_index(HloInstruction* instruction) const { 193 return companion_set_index_.at(instruction); 194 } 195 196 // Returns the list of all companion sets in the HLO module group. Each 197 // returned set contains at least one HloInstruction. 198 const std::vector<std::unique_ptr<std::vector<HloInstruction*>>>& companion_sets()199 companion_sets() const { 200 return companion_sets_; 201 } 202 203 // Returns all channels in the module group. channels()204 const std::vector<Channel>& channels() const { return channels_; } 205 206 // Returns the maximum channel id used in the module group. max_channel_id()207 int64_t max_channel_id() const { return max_channel_id_; } 208 alias_analysis(HloModule * module)209 HloAliasAnalysis* alias_analysis(HloModule* module) const { 210 return alias_analyses_.at(module).get(); 211 } 212 213 private: 214 Status Build(); 215 216 // Record all channel instructions, cross-module AllReduce instructions, and 217 // While/Conditional/Call instructions. 218 Status RecordInstructions(); 219 220 // Verifies the given HloModules are well-formed and follow the specification, 221 // in particular with respect to using channel instructions. 222 // 223 // * Each channel has all 4 instructions (Send, Recv, SendDone, RecvDone). 224 // * The shape of channel instructions match. 225 // * The nest level of channel instructions match. 226 // * Channel instructions are used in allowed computations, i.e., in the 227 // entry computation of the module or condition/body of While computations. 228 Status VerifyChannelInstructions(); 229 230 // Adds metadata that the given two instructions are companions. 231 Status AddCompanion(HloInstruction* instruction1, 232 HloInstruction* instruction2); 233 234 // Checks whether a communicating instruction is placed in a valid position 235 // within the graph. 236 Status CheckCommunicatingInstruction(HloInstruction* instruction) const; 237 238 // Performs a consistency check on the companion sets built for the input 239 // modules. Checks that each instruction in a companion set is in a different 240 // module/device. 241 Status VerifyCompanionSets() const; 242 243 // Retrieves a pointer to the stored TrackedInstruction associated with a 244 // tracked computation, or nullptr in case such computation is not tracked. GetTrackedInstruction(const HloComputation * computation)245 const TrackedInstruction* GetTrackedInstruction( 246 const HloComputation* computation) const { 247 auto it = tracked_instructions_.find(computation); 248 return it != tracked_instructions_.end() ? &it->second : nullptr; 249 } 250 251 // Dump all the collected module group statistics to the logs. 252 void DumpCollectedStats() const; 253 254 // List of all companion instructions sets in the module. 255 std::vector<std::unique_ptr<std::vector<HloInstruction*>>> companion_sets_; 256 257 // Map from each companion while instruction to the index into companion_set_. 258 absl::flat_hash_map<const HloInstruction*, int64_t> companion_set_index_; 259 260 // Map from computation to the instruction using it (a kWhile, kConditional). 261 absl::flat_hash_map<const HloComputation*, TrackedInstruction> 262 tracked_instructions_; 263 264 // Maps tracked instructions (kWhile, kConditional, kCall, ...) to the set of 265 // communicating instructions within the proper called computation(s). 266 absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>> 267 tracked_instructions_comms_; 268 269 // All channels in the module. 270 std::vector<Channel> channels_; 271 272 // Map from channel ids to the index in channels_. 273 absl::flat_hash_map<int64_t, int64_t> channel_id_map_; 274 275 // Map from all-reduce ids to the all reduce instructions. 276 absl::flat_hash_map<int64_t, std::vector<HloInstruction*>> all_reduce_map_; 277 278 // The maximum channel id used in the module group. 279 int64_t max_channel_id_ = -1; 280 281 // The modules that this metadata was built from. 282 const std::vector<HloModule*> modules_; 283 284 absl::flat_hash_map<HloModule*, std::unique_ptr<HloAliasAnalysis>> 285 alias_analyses_; 286 }; 287 288 } // namespace xla 289 290 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_ 291