xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_module_group_metadata.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_
18 
19 #include <memory>
20 #include <optional>
21 #include <set>
22 #include <string>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_map.h"
26 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/status.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/core/lib/core/status.h"
33 
34 namespace xla {
35 
36 // Class for bookkeeping the information on the given modules, in particular on
37 // the interaction between computations.
38 //
39 // Companion instructions are one piece of information collected as we build the
40 // metadata. For example, for each While instruction, companion instructions
41 // refer to a set of While instructions in other computations that communicate
42 // with each other.
43 // In the example below with 3 modules, {While_0, While_2, While_5}, {While_1,
44 // While_4}, {While_3, While_6} are companion sets.
45 //
46 // <Module 0>               <Module 1>                 <Module 2>
47 // While_0() {              While_2() {                While_5() {
48 //   While_1() { Send(0) }    While_3() { Send(1) }      While_6() { Recv(1) }
49 // }                          While_4() { Recv(0) }
50 //                          }
51 //
52 // Each instruction can belong to at most one companion set: While_0 and While_5
53 // are in the same set even though they don't communicate with each other,
54 // because they both communicate with While_2.
55 //
56 // A send and the matching recv must both have the same level of nesting of
57 // companion instructions.
58 //
59 // Companion instructions are used to detect cycles in the graph and also for
60 // global scheduling.
61 class HloModuleGroupMetadata {
62  public:
63   // The kind of companion computation a given instruction can be within.
64   enum class ComputationKind {
65     kInvalid,
66     kWhileCondition,
67     kWhileBody,
68     kConditionalBranch,
69     kCallFunction,
70   };
71 
72   // Tracks the instruction mapped to a given computation, and the computation
73   // kind.
74   // For example, a body computation of a while instruction, will generate a
75   // TrackedInstruction with instruction being the while instruction, and
76   // kind being ComputationKind::kWhileBody.
77   class TrackedInstruction {
78    public:
79     TrackedInstruction() = default;
80     TrackedInstruction(HloInstruction* instruction, ComputationKind kind,
81                        int index = -1)
instruction_(instruction)82         : instruction_(instruction), kind_(kind), index_(index) {}
83 
84     bool operator==(const TrackedInstruction& rhs) const {
85       return instruction_->opcode() == rhs.instruction_->opcode() &&
86              kind_ == rhs.kind_ && index_ == rhs.index_;
87     }
88     bool operator!=(const TrackedInstruction& rhs) const {
89       return !operator==(rhs);
90     }
91 
instruction()92     HloInstruction* instruction() const { return instruction_; }
93 
94     std::string ToString() const;
95 
96    private:
97     HloInstruction* instruction_ = nullptr;
98     ComputationKind kind_ = ComputationKind::kInvalid;
99     int index_ = -1;
100   };
101 
102   // Represents a channel and the instructions that form the channel.
103   struct Channel {
104     int64_t id = -1;
105     HloInstruction* send = nullptr;
106     HloInstruction* recv = nullptr;
107     HloInstruction* send_done = nullptr;
108     HloInstruction* recv_done = nullptr;
109   };
110 
HloModuleGroupMetadata(absl::Span<HloModule * const> modules)111   explicit HloModuleGroupMetadata(absl::Span<HloModule* const> modules)
112       : modules_(modules.begin(), modules.end()) {}
113 
114   ~HloModuleGroupMetadata() = default;
115 
116   // Build and return the metadata for the given modules.
117   static StatusOr<std::unique_ptr<HloModuleGroupMetadata>> Build(
118       absl::Span<HloModule* const> modules);
119 
120   // Returns true if the instruction is one of the 4 channel instructions (Send,
121   // Recv, SendDone, RecvDone).
122   bool IsChannelInstruction(const HloInstruction* instruction) const;
123 
124   // Returns true if the instruction is a companion instruction. See the class
125   // comment above on companion instructions.
126   bool IsCompanionInstruction(HloInstruction* hlo) const;
127 
128   // Returns true if the instruction is either a channel instruction, a
129   // cross-module all-reduce instruction, or a companion instruction.
130   bool InstructionCommunicates(HloInstruction* hlo) const;
131 
132   // Returns the Channel instance for the given channel id.
133   const Channel& GetChannel(int64_t channel_id) const;
134 
135   // Returns if the given channel id exists in metadata.
136   bool HasChannel(int64_t channel_id) const;
137 
138   // Returns the all-reduce instructions with the same channel_id.
139   const std::vector<HloInstruction*>& GetAllReduceGroup(
140       int64_t channel_id) const;
141 
142   // Returns the computation that contains the peer channel instructions for
143   // the given instruction.
144   //
145   // Precondition: IsChannelInstruction(instruction) is true.
146   HloComputation* PeerComputation(const HloInstruction* instruction) const;
147 
148   // Returns the path of the nested companion instructions, in terms of HLO
149   // instructions. The path goes from inner to outer companions.
150   // The returned path does not include the input hlo instruction, in case it
151   // is a companion instruction.
152   std::vector<TrackedInstruction> GetCompanionsPath(
153       const HloInstruction* hlo) const;
154 
155   // Checks whether two companion paths (as returned by the GetCompanionsPath()
156   // API) are compatible. The two paths are compatible if the sequence of
157   // opcodes, and the companion kinds, of the two paths matches.
158   bool CheckCompanionPathsCompatibility(
159       const std::vector<TrackedInstruction>& path0,
160       const std::vector<TrackedInstruction>& path1) const;
161 
162   // Returns the unique integer for each module. The returned id is the index of
163   // the module in the module vector.
164   int64_t GetModuleId(const HloModule* module) const;
165 
166   // Retrieves the device an instruction is assigned to. Either from the
167   // sharding information, or from the ordinal of the module the instruction
168   // is in.
169   std::optional<int64_t> GetInstructionDevice(
170       const HloInstruction& instruction) const;
171 
172   // Returns the number of modules for devices (excluding the host module).
173   int64_t GetDeviceModulesCount() const;
174 
175   // Returns the companion set for the given instruction, including the
176   // instruction itself.
177   //
178   // Precondition: IsCompanionWhile(instruction) is true.
Companions(const HloInstruction * instruction)179   const std::vector<HloInstruction*>& Companions(
180       const HloInstruction* instruction) const {
181     CHECK(companion_set_index_.contains(instruction));
182     return companion_set(companion_set_index_.at(instruction));
183   }
184 
185   // Returns the companion set at the given index.
companion_set(int64_t index)186   const std::vector<HloInstruction*>& companion_set(int64_t index) const {
187     CHECK_LT(index, companion_sets_.size());
188     return *companion_sets_[index];
189   }
190 
191   // Returns the companion set index of the given instruction.
companion_set_index(HloInstruction * instruction)192   int64_t companion_set_index(HloInstruction* instruction) const {
193     return companion_set_index_.at(instruction);
194   }
195 
196   // Returns the list of all companion sets in the HLO module group. Each
197   // returned set contains at least one HloInstruction.
198   const std::vector<std::unique_ptr<std::vector<HloInstruction*>>>&
companion_sets()199   companion_sets() const {
200     return companion_sets_;
201   }
202 
203   // Returns all channels in the module group.
channels()204   const std::vector<Channel>& channels() const { return channels_; }
205 
206   // Returns the maximum channel id used in the module group.
max_channel_id()207   int64_t max_channel_id() const { return max_channel_id_; }
208 
alias_analysis(HloModule * module)209   HloAliasAnalysis* alias_analysis(HloModule* module) const {
210     return alias_analyses_.at(module).get();
211   }
212 
213  private:
214   Status Build();
215 
216   // Record all channel instructions, cross-module AllReduce instructions, and
217   // While/Conditional/Call instructions.
218   Status RecordInstructions();
219 
220   // Verifies the given HloModules are well-formed and follow the specification,
221   // in particular with respect to using channel instructions.
222   //
223   // * Each channel has all 4 instructions (Send, Recv, SendDone, RecvDone).
224   // * The shape of channel instructions match.
225   // * The nest level of channel instructions match.
226   // * Channel instructions are used in allowed computations, i.e., in the
227   //   entry computation of the module or condition/body of While computations.
228   Status VerifyChannelInstructions();
229 
230   // Adds metadata that the given two instructions are companions.
231   Status AddCompanion(HloInstruction* instruction1,
232                       HloInstruction* instruction2);
233 
234   // Checks whether a communicating instruction is placed in a valid position
235   // within the graph.
236   Status CheckCommunicatingInstruction(HloInstruction* instruction) const;
237 
238   // Performs a consistency check on the companion sets built for the input
239   // modules. Checks that each instruction in a companion set is in a different
240   // module/device.
241   Status VerifyCompanionSets() const;
242 
243   // Retrieves a pointer to the stored TrackedInstruction associated with a
244   // tracked computation, or nullptr in case such computation is not tracked.
GetTrackedInstruction(const HloComputation * computation)245   const TrackedInstruction* GetTrackedInstruction(
246       const HloComputation* computation) const {
247     auto it = tracked_instructions_.find(computation);
248     return it != tracked_instructions_.end() ? &it->second : nullptr;
249   }
250 
251   // Dump all the collected module group statistics to the logs.
252   void DumpCollectedStats() const;
253 
254   // List of all companion instructions sets in the module.
255   std::vector<std::unique_ptr<std::vector<HloInstruction*>>> companion_sets_;
256 
257   // Map from each companion while instruction to the index into companion_set_.
258   absl::flat_hash_map<const HloInstruction*, int64_t> companion_set_index_;
259 
260   // Map from computation to the instruction using it (a kWhile, kConditional).
261   absl::flat_hash_map<const HloComputation*, TrackedInstruction>
262       tracked_instructions_;
263 
264   // Maps tracked instructions (kWhile, kConditional, kCall, ...) to the set of
265   // communicating instructions within the proper called computation(s).
266   absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>>
267       tracked_instructions_comms_;
268 
269   // All channels in the module.
270   std::vector<Channel> channels_;
271 
272   // Map from channel ids to the index in channels_.
273   absl::flat_hash_map<int64_t, int64_t> channel_id_map_;
274 
275   // Map from all-reduce ids to the all reduce instructions.
276   absl::flat_hash_map<int64_t, std::vector<HloInstruction*>> all_reduce_map_;
277 
278   // The maximum channel id used in the module group.
279   int64_t max_channel_id_ = -1;
280 
281   // The modules that this metadata was built from.
282   const std::vector<HloModule*> modules_;
283 
284   absl::flat_hash_map<HloModule*, std::unique_ptr<HloAliasAnalysis>>
285       alias_analyses_;
286 };
287 
288 }  // namespace xla
289 
290 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_
291