/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"

namespace xla {

enum class ReductionKind { SUM, PRODUCT, MIN, MAX };

// Attempts to match the instruction to one of the possible cases for
// ReductionKind.
std::optional<ReductionKind> MatchReductionInstruction(
    const HloInstruction* hlo);

// Attempts to match the computation to one of the possible cases of
// ReductionKind.
std::optional<ReductionKind> MatchReductionComputation(
    const HloComputation* computation);
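
// Example (illustrative): a reduction computation whose root is add(x, y) over
// its two parameters matches ReductionKind::SUM, while one rooted at maximum
// matches ReductionKind::MAX; a computation that matches none of the cases
// yields std::nullopt.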

// Figures out which IDs are participating in the collective subgroup.
// An empty `groups` indicates that all [0, total_participant_count) IDs
// are participating. Note that for CollectiveOpGroupMode::kFlattenedID,
// groups cannot be empty, which is why `total_participant_count` is optional.
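//
// Example (illustrative): with groups={{0}, {1, 2}} and current_id=2, this
// returns {1, 2}; with empty groups and total_participant_count=4, it returns
// {0, 1, 2, 3}.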
StatusOr<std::vector<int>> GetParticipatingIDs(
    int current_id, std::optional<int> total_participant_count,
    absl::Span<const ReplicaGroup> groups);

// There are broadly 4 modes that collective communication ops use to describe
// which sets of devices are participating with a given device in the operation.
// These modes are determined by the values of channel_id (optional) and
// use_global_device_ids (optional). The modes are as follows:
//
// kCrossReplica:
//    implied by: no channel_id, use_global_device_ids = false, or
//                no channel_id, no use_global_device_ids:
//    replica_groups contain replica_id, group contains all replicas for the
//    current partition.
//
// kCrossPartition:
//    implied by: channel_id is set, no use_global_device_ids:
//    replica_groups contain partition_id, group contains all partitions for the
//    current replica.
//
// kCrossReplicaAndPartition:
//    implied by: channel_id is set, use_global_device_ids = false:
//    replica_groups contain replica_id, group contains all replicas for all
//    partitions (as opposed to just the current partition).
//
// kFlattenedID:
//    implied by: channel_id is set, use_global_device_ids = true:
//    replica_groups contain flattened-ids, group contains devices that are
//    listed in the flattened-id list.
//
// All other combinations are invalid.
//
// Since the actual value of channel_id does not matter, we use a bool argument
// `has_channel_id`, and optional<bool> for use_global_device_ids.
// Note that use_global_device_ids = true requires channel_id to be set as
// well. Additionally, if use_global_device_ids = true, replica groups cannot
// be empty (verified in the HLO verifier).
enum class CollectiveOpGroupMode {
  kCrossReplica,
  kCrossPartition,
  kCrossReplicaAndPartition,
  kFlattenedID,
};

absl::string_view CollectiveOpGroupModeToString(
    CollectiveOpGroupMode group_mode);

// Returns the group formation mode implied by (a) whether the operation has a
// channel_id and (b) whether use_global_device_ids is present and, if so, its
// value.
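//
// Example (illustrative): an all-reduce that sets channel_id and
// use_global_device_ids = true resolves to kFlattenedID:
//
//   StatusOr<CollectiveOpGroupMode> mode =
//       GetCollectiveOpGroupMode(/*has_channel_id=*/true,
//                                /*use_global_device_ids=*/true);
//   // *mode == CollectiveOpGroupMode::kFlattenedID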
StatusOr<CollectiveOpGroupMode> GetCollectiveOpGroupMode(
    bool has_channel_id, std::optional<bool> use_global_device_ids);

// Figures out subgroups of participating devices from the given replica_groups
// and group_mode.
//
// Returns a list of participant groups, where each group is a list of
// GlobalDeviceIds.
//
// For example:
//   device_assignment={{33, 34}, {44, 45}, {55, 56}}  (3 replicas, 2 partitions)
//   group_mode=CollectiveOpGroupMode::kCrossReplica
//   replica_groups={{0}, {1, 2}}
//
//   This function returns {{33, 34}, {44, 45, 55, 56}}: there are 2 subgroups
//   of participating devices, {33, 34} and {44, 45, 55, 56}.
StatusOr<std::vector<std::vector<GlobalDeviceId>>>
GetParticipatingDevicesGroups(const DeviceAssignment& device_assignment,
                              absl::Span<const ReplicaGroup> replica_groups,
                              CollectiveOpGroupMode group_mode);

// Figures out which devices are participating in the collective subgroup.
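//
// Example (illustrative, reusing the device_assignment above): for
// device_id=44 (replica 1, partition 0), group_mode=kCrossReplica, and
// replica_groups={{0}, {1, 2}}, the participating devices are {44, 55},
// i.e. replicas 1 and 2 within the current partition, consistent with the
// kCrossReplica description above.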
StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
    GlobalDeviceId device_id, const DeviceAssignment& device_assignment,
    absl::Span<const ReplicaGroup> replica_groups,
    CollectiveOpGroupMode group_mode);

// Returns true if the two lists of replica groups are orthogonal.
bool ReplicaGroupsOrthogonal(absl::Span<const ReplicaGroup> first,
                             absl::Span<const ReplicaGroup> second);

// A custom call target that can be used to create a nop that can legally
// replace a collective op.
constexpr char kNopCustomCallTarget[] = "AllocateBuffer";

// Returns true if instruction is a collective op or a collective fusion.
bool IsCollective(const HloInstruction* instruction);

// Key that identifies a particular Rendezvous object in our global hashtable.
// This determines which calls to ExecuteOnStream communicate with each other.
// The rules are as follows.
//
// * Only ops with the same RunId can communicate with each other. (This is the
//   whole purpose of RunId).
//
// * Only ops with the same set of participating replicas can communicate with
//   each other.  This is how we separate out different replica groups (e.g. a
//   single AllReduce HLO might do two reductions, between say GPUs {0,2} and
//   {1,3}).
//
// * Only ops with the same opcode can communicate with each other.  At the
//   moment we only support kAllReduce, so we don't check for this explicitly.
//
// * For cross-module all-reduces (i.e. instr->channel_id().has_value()),
//   only ops with the same value for channel_id() can communicate with each
//   other.
//
// * For cross-replica (i.e. same-module) all-reduces (i.e.
//   !channel_id().has_value()), only ops from the same module (as
//   identified by its unique_id()) can communicate with each other.
//
struct RendezvousKey {
  enum CollectiveOpKind {
    kCrossModule,
    kCrossReplica,
  };

  explicit RendezvousKey(const RunId& run_id,
                         std::vector<GlobalDeviceId> global_devices,
                         int num_local_participants,
                         CollectiveOpKind collective_op_kind, int64_t op_id)
      : run_id(run_id),
        global_devices(std::move(global_devices)),
        num_local_participants(num_local_participants),
        collective_op_kind(collective_op_kind),
        op_id(op_id) {}

  template <typename H>
  friend H AbslHashValue(H h, const RendezvousKey& k) {
    return H::combine(std::move(h), k.run_id, k.global_devices,
                      k.num_local_participants, k.collective_op_kind, k.op_id);
  }
  friend bool operator==(const RendezvousKey& a, const RendezvousKey& b) {
    return a.run_id == b.run_id && a.global_devices == b.global_devices &&
           a.num_local_participants == b.num_local_participants &&
           a.collective_op_kind == b.collective_op_kind &&  //
           a.op_id == b.op_id;
  }
  friend bool operator!=(const RendezvousKey& a, const RendezvousKey& b) {
    return !(a == b);
  }

  absl::string_view CollectiveOpKindString() const {
    switch (collective_op_kind) {
      case kCrossModule:
        return "cross_module";
      case kCrossReplica:
        return "cross_replica";
    }
  }

  std::string ToString() const {
    return absl::StrFormat(
        "RendezvousKey{run_id=%s, global_devices=[%s], "
        "num_local_participants=%d, collective_op_kind=%s, op_id=%d}",
        run_id.ToString(), GlobalDeviceIdsToString(global_devices),
        num_local_participants, CollectiveOpKindString(), op_id);
  }

  RunId run_id;
  std::vector<GlobalDeviceId> global_devices;
  int num_local_participants;
  CollectiveOpKind collective_op_kind;
  int64_t op_id;
};

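// Waits on `counter`, logging an error message derived from `desc_fn` if the
// wait exceeds a fixed timeout.
//
// Example (illustrative; `num_participants` is a placeholder):
//
//   tensorflow::BlockingCounter counter(num_participants);
//   // ... each participating thread calls counter.DecrementCount() ...
//   WaitAndLogIfStuck(&counter, [&] {
//     return absl::StrFormat("waiting for %d participants to arrive",
//                            num_participants);
//   });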
template <typename DescFn>
void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
                       const DescFn& desc_fn) {
  VLOG(3) << "Begin: " << desc_fn();
  const std::chrono::milliseconds timeout(5000);
  bool ok = counter->WaitFor(timeout);
  if (ok) {
    VLOG(3) << "Finished: " << desc_fn();
    return;
  }
  LOG(ERROR) << "This thread has been waiting for " << timeout.count()
             << "ms and may be stuck: " << desc_fn();
  counter->Wait();
  LOG(ERROR) << "Thread is unstuck!  Warning above was a false-positive.  "
                "Perhaps the timeout is too short: "
             << desc_fn();
}

// Participant data for each rendezvous.
struct ParticipantData {
  explicit ParticipantData(const RendezvousKey& rendezvous_key)
      : rendezvous_key(rendezvous_key) {}

  virtual ~ParticipantData() {}

  RendezvousKey rendezvous_key;

  virtual std::string ToString() const = 0;
};

// Encapsulates parameters to Rendezvous::SubmitParticipant.
struct AllReduceParticipantData : ParticipantData {
  AllReduceParticipantData(const RendezvousKey& rendezvous_key_p,
                           int64_t device_ordinal_p, se::Stream* stream_p)
      : ParticipantData(rendezvous_key_p),
        device_ordinal(device_ordinal_p),
        stream(stream_p) {}

  // TODO(b/125951860): We should vet that we're buffer allocating such that
  // source_buffer == destination_buffer if that avoids a NCCL copy (will depend
  // on how well the NCCL in-place implementation performs vs the out-of-place
  // implementation).
  struct Buffer {
    int64_t element_count;
    se::DeviceMemoryBase source_data;
    se::DeviceMemoryBase destination_data;
    PrimitiveType primitive_type;
  };
  int64_t device_ordinal;
  se::Stream* stream;
  std::vector<Buffer> buffers;

  ReductionKind reduction_kind;

  // A (global ID, local device ordinal) pair for each local all-reduce
  // participant. Participants are in no particular order.
  std::vector<std::pair<GlobalDeviceId, int64_t>> local_devices;

  std::string ToString() const override {
    std::vector<std::string> buffer_strs;
    for (const Buffer& buffer : buffers) {
      buffer_strs.push_back(
          absl::StrFormat("{element_count=%d}", buffer.element_count));
    }
    return absl::StrFormat(
        "AllReduceParticipantData{buffers=[%s], rendezvous_key=%s, "
        "device_ordinal=%d, stream=%p}",
        absl::StrJoin(buffer_strs, ","), rendezvous_key.ToString(),
        device_ordinal, stream);
  }
};
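
// Example (illustrative; the rendezvous key, stream, device memory handles,
// and sizes are placeholders):
//
//   AllReduceParticipantData participant(rendezvous_key, /*device_ordinal=*/0,
//                                        stream);
//   participant.reduction_kind = ReductionKind::SUM;
//   participant.buffers.push_back({/*element_count=*/1024, source_mem,
//                                  destination_mem, /*primitive_type=*/F32});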

// The set of threads that want to do a collective op together all pick the same
// Rendezvous object out of the global cache and call SubmitParticipant.
//
// The Rendezvous instance handles waiting for all threads to join, ensuring
// that a clique exists for the desired set of GPUs, etc.
//
// Rendezvous objects can only be used once.
//
// I: Participant data.
// O: Participant output.
template <typename I, typename O,
          typename =
              std::enable_if_t<std::is_base_of<ParticipantData, I>::value>>
class Rendezvous {
 public:
  virtual ~Rendezvous() {}
  explicit Rendezvous(const RendezvousKey& k) : key_(k) {}

  // Submit a participant to the rendezvous. We get the rendezvous from
  // `rendezvous_getter`, which we can then use to drop the existing reference.
  static StatusOr<O> SubmitParticipant(
      std::function<std::shared_ptr<Rendezvous<I, O>>()> rendezvous_getter,
      I participant) {
    std::shared_ptr<Rendezvous<I, O>> rendezvous = rendezvous_getter();
    TF_ASSIGN_OR_RETURN(auto p, rendezvous->SubmitParticipant(participant));

    // Drop our reference to the Rendezvous and wait for all other threads to do
    // the same.  If we didn't do this, one of the threads could run past this
    // point, reenter ExecuteOnStream for another all-reduce, and attempt to
    // reuse the Rendezvous!
    //
    // An alternative way of accomplishing this goal would be to implement
    // RefcountingHashMap::erase() and call it during SubmitParticipant.  But
    // erase() is deceptively complex to implement correctly.
    std::shared_ptr<tensorflow::BlockingCounter> blocking_counter = p.second;
    rendezvous.reset();
    blocking_counter->DecrementCount();
    xla::WaitAndLogIfStuck(blocking_counter.get(), [&] {
      return absl::StrFormat(
          "participant waiting for all threads to drop their reference to the "
          "rendezvous: %p",
          rendezvous.get());
    });
    return std::move(p.first);
  }

 protected:
  // Runs the collective op for this participant and returns the
  // domain-specific output O.
  virtual StatusOr<O> RunCollectiveOp(const I& participant) = 0;

  // Initializes the rendezvous; performed by the first ("primary") thread that
  // reaches the barrier. Returns whether this thread is primary.
  bool InitializationBarrier() {
    absl::MutexLock lock(&mu_);
    if (!initialized_) {
      initialized_ = true;
      return true;
    }
    return false;
  }

  absl::Mutex mu_;

  bool initialized_ ABSL_GUARDED_BY(mu_) = false;

  std::vector<I> participants_ ABSL_GUARDED_BY(mu_);

 private:
  // Runs the collective op on the calling thread.  If successful, returns
  //  - the domain-specific output O (for all-reduce, e.g., a handle to the
  //    clique that was used, so that the caller may keep the clique alive if
  //    it chooses), and
  //  - a BlockingCounter initialized to the number of participants, so that
  //    the caller can coordinate with the participants one last time if it
  //    chooses.  This is useful for coordinating destruction of the Rendezvous.
  StatusOr<std::pair<O, std::shared_ptr<tensorflow::BlockingCounter>>>
  SubmitParticipant(const I& participant) {
    {
      absl::MutexLock lock(&mu_);
      CHECK(!initialized_);

      // Spot check for consistent replica counts among submitting threads.
      if (!participants_.empty() &&
          participants_.back().rendezvous_key != participant.rendezvous_key) {
        return InvalidArgument(
            "Mismatch among all-reduce participants.  Expected same "
            "replica-count, element-count, and rendezvous-key but were %s and "
            "%s",
            participants_.back().ToString(), participant.ToString());
      }
      participants_.push_back(participant);
    }

    // Wait for all participants to arrive.
    all_participants_present_.DecrementCount();
    WaitAndLogIfStuck(&all_participants_present_, [&] {
      return absl::StrFormat(
          "participant %s waiting for all participants to arrive at rendezvous "
          "%s",
          participant.ToString(), key_.ToString());
    });

    TF_ASSIGN_OR_RETURN(O output, RunCollectiveOp(participant));
    return std::make_pair(std::move(output), returned_blocking_counter_);
  }

  const RendezvousKey key_;

  tensorflow::BlockingCounter all_participants_present_{
      key_.num_local_participants};

  // tensorflow::BlockingCounter returned by SubmitParticipant.
  std::shared_ptr<tensorflow::BlockingCounter> returned_blocking_counter_{
      std::make_shared<tensorflow::BlockingCounter>(
          key_.num_local_participants)};
};
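
// Usage sketch (illustrative; `MyParticipant`, `MyOutput`, `MyRendezvous`, and
// `GetOrCreateRendezvous` are hypothetical names, not part of this header):
//
//   class MyRendezvous : public Rendezvous<MyParticipant, MyOutput> {
//    public:
//     using Rendezvous<MyParticipant, MyOutput>::Rendezvous;
//
//    protected:
//     StatusOr<MyOutput> RunCollectiveOp(const MyParticipant& p) override {
//       const bool primary = InitializationBarrier();
//       // One-time setup (e.g. establishing the communication clique) would
//       // be guarded by `primary`; per-participant work runs on every thread.
//       return MyOutput{};
//     }
//   };
//
//   // Each participating thread submits its own participant; the static
//   // SubmitParticipant blocks until all num_local_participants threads have
//   // arrived and then returns this participant's output.
//   StatusOr<MyOutput> result = MyRendezvous::SubmitParticipant(
//       [&] { return GetOrCreateRendezvous(rendezvous_key); }, participant);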

}  // end namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_