/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"

namespace xla {

enum class ReductionKind { SUM, PRODUCT, MIN, MAX };

// Attempts to match instruction to one of the possible cases for
// ReductionKind.
std::optional<ReductionKind> MatchReductionInstruction(
    const HloInstruction* hlo);

// Attempts to match computation to one of the possible cases in ReductionKind.
std::optional<ReductionKind> MatchReductionComputation(
    const HloComputation* computation);
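
// Illustrative usage (a sketch, not part of this header; `all_reduce` below is
// a hypothetical HloInstruction* whose to_apply() computation is a plain f32
// addition):
//
//   std::optional<ReductionKind> kind =
//       MatchReductionComputation(all_reduce->to_apply());
//   // `kind` would be expected to hold ReductionKind::SUM.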

// Figures out which IDs are participating in the collective subgroup.
// An empty `groups` indicates that all [0, total_participant_count) IDs
// are participating. Note that for CollectiveOpGroupMode::kFlattenedID,
// groups cannot be empty, so `total_participant_count` is an optional.
StatusOr<std::vector<int>> GetParticipatingIDs(
    int current_id, std::optional<int> total_participant_count,
    absl::Span<const ReplicaGroup> groups);
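
// For example (an illustrative sketch; `groups` below is assumed to encode the
// replica groups {0, 2} and {1, 3}):
//
//   TF_ASSIGN_OR_RETURN(
//       std::vector<int> ids,
//       GetParticipatingIDs(/*current_id=*/2, /*total_participant_count=*/4,
//                           groups));
//   // `ids` would be expected to contain {0, 2}, the group that includes
//   // current_id. With an empty `groups`, it would contain {0, 1, 2, 3}.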

// There are broadly 4 modes that collective communication ops use to describe
// which sets of devices are participating with a given device in the
// operation. These modes are determined by the values of channel_id (optional)
// and use_global_device_ids (optional). The modes are as follows:
//
// kCrossReplica:
//    implied by: no channel_id, use_global_device_ids = false, or
//                no channel_id, no use_global_device_ids:
//    replica_groups contain replica_id, group contains all replicas for the
//    current partition.
//
// kCrossPartition:
//    implied by: channel_id is set, no use_global_device_ids:
//    replica_groups contain partition_id, group contains all partitions for
//    the current replica.
//
// kCrossReplicaAndPartition:
//    implied by: channel_id is set, use_global_device_ids = false:
//    replica_groups contain replica_id, group contains all replicas for all
//    partitions (as opposed to just the current partition).
//
// kFlattenedID:
//    implied by: channel_id is set, use_global_device_ids = true:
//    replica_groups contain flattened-ids, group contains devices that are
//    listed in the flattened-id list.
//
// All other combinations are invalid.
//
// Since the actual value of channel_id does not matter, we use a bool argument
// `has_channel_id`, and optional<bool> for use_global_device_ids.
// Note that use_global_device_ids = true requires channel_id to be set as
// well. Additionally, if use_global_device_ids = true, replica groups cannot
// be empty (verified in the HLO verifier).
enum class CollectiveOpGroupMode {
  kCrossReplica,
  kCrossPartition,
  kCrossReplicaAndPartition,
  kFlattenedID,
};

absl::string_view CollectiveOpGroupModeToString(
    CollectiveOpGroupMode group_mode);

// Returns the group formation mode implied by (a) whether the operation has
// channel_id and (b) if it has use_global_device_ids and if yes, its value.
StatusOr<CollectiveOpGroupMode> GetCollectiveOpGroupMode(
    bool has_channel_id, std::optional<bool> use_global_device_ids);
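
// For example, a sketch of the mapping implied by the modes described above:
//
//   GetCollectiveOpGroupMode(/*has_channel_id=*/false,
//                            /*use_global_device_ids=*/std::nullopt)
//       => CollectiveOpGroupMode::kCrossReplica
//   GetCollectiveOpGroupMode(/*has_channel_id=*/true,
//                            /*use_global_device_ids=*/true)
//       => CollectiveOpGroupMode::kFlattenedID
//   GetCollectiveOpGroupMode(/*has_channel_id=*/false,
//                            /*use_global_device_ids=*/true)
//       => an error, since use_global_device_ids = true requires a channel_id.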

// Figures out subgroups of participating devices from given replica_groups and
// group_mode.
//
// Returns a list of participants, where each participant is a list of
// GlobalDeviceIds.
//
// For example:
//   device_assignment={{33, 34}, {44, 45}, {55, 56}}  3 replicas 2 partitions
//   group_mode=CollectiveOpGroupMode::kCrossReplica
//   replica_groups={{0}, {1, 2}}
//
//   This function returns {{33, 34}, {44, 45, 55, 56}}.
//   There are 2 subgroups of participating devices: {33, 34} and
//   {44, 45, 55, 56}.
StatusOr<std::vector<std::vector<GlobalDeviceId>>>
GetParticipatingDevicesGroups(const DeviceAssignment& device_assignment,
                              absl::Span<const ReplicaGroup> replica_groups,
                              CollectiveOpGroupMode group_mode);

// Figures out which devices are participating in the collective subgroup.
StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
    GlobalDeviceId device_id, const DeviceAssignment& device_assignment,
    absl::Span<const ReplicaGroup> replica_groups,
    CollectiveOpGroupMode group_mode);
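
// For example (an illustrative sketch, reusing the device_assignment above,
// i.e. {{33, 34}, {44, 45}, {55, 56}} with 3 replicas and 2 partitions):
//
//   // Device 44 is replica 1, partition 0. In kCrossPartition mode with empty
//   // replica_groups, its group contains all partitions of its replica:
//   GetParticipatingDevices(GlobalDeviceId(44), device_assignment,
//                           /*replica_groups=*/{},
//                           CollectiveOpGroupMode::kCrossPartition)
//       => would be expected to return {44, 45}.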

// Returns true if the two replica groups are orthogonal.
bool ReplicaGroupsOrthogonal(absl::Span<const ReplicaGroup> first,
                             absl::Span<const ReplicaGroup> second);
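
// For example (an illustrative sketch; "orthogonal" is assumed here to mean
// that `second` is the transpose of `first`, i.e. every group in `first`
// shares exactly one ID with every group in `second`):
//
//   first  = {{0, 1}, {2, 3}}
//   second = {{0, 2}, {1, 3}}
//   // ReplicaGroupsOrthogonal(first, second) would be expected to be true.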

// A custom call target that can be used to create a nop that can legally
// replace a collective op.
constexpr char kNopCustomCallTarget[] = "AllocateBuffer";

// Returns true if instruction is a collective op or a collective fusion.
bool IsCollective(const HloInstruction* instruction);

// Key that identifies a particular Rendezvous object in our global hashtable.
// This determines which calls to ExecuteOnStream communicate with each other.
// The rules are as follows.
//
// * Only ops with the same RunId can communicate with each other. (This is the
//   whole purpose of RunId).
//
// * Only ops with the same set of participating replicas can communicate with
//   each other. This is how we separate out different replica groups (e.g. a
//   single AllReduce HLO might do two reductions, between say GPUs {0,2} and
//   {1,3}).
//
// * Only ops with the same opcode can communicate with each other. At the
//   moment we only support kAllReduce, so we don't check for this explicitly.
//
// * For cross-module all-reduces (i.e. instr->channel_id().has_value()),
//   only ops with the same value for channel_id() can communicate with each
//   other.
//
// * For cross-replica (i.e. same-module) all-reduces (i.e.
//   !channel_id().has_value()), only ops from the same module (as
//   identified by its unique_id()) can communicate with each other.
//
struct RendezvousKey {
  enum CollectiveOpKind {
    kCrossModule,
    kCrossReplica,
  };

  explicit RendezvousKey(const RunId& run_id,
                         std::vector<GlobalDeviceId> global_devices,
                         int num_local_participants,
                         CollectiveOpKind collective_op_kind, int64_t op_id)
      : run_id(run_id),
        global_devices(std::move(global_devices)),
        num_local_participants(num_local_participants),
        collective_op_kind(collective_op_kind),
        op_id(op_id) {}

  template <typename H>
  friend H AbslHashValue(H h, const RendezvousKey& k) {
    return H::combine(std::move(h), k.run_id, k.global_devices,
                      k.num_local_participants, k.collective_op_kind, k.op_id);
  }
  friend bool operator==(const RendezvousKey& a, const RendezvousKey& b) {
    return a.run_id == b.run_id && a.global_devices == b.global_devices &&
           a.num_local_participants == b.num_local_participants &&
           a.collective_op_kind == b.collective_op_kind &&  //
           a.op_id == b.op_id;
  }
  friend bool operator!=(const RendezvousKey& a, const RendezvousKey& b) {
    return !(a == b);
  }

  absl::string_view CollectiveOpKindString() const {
    switch (collective_op_kind) {
      case kCrossModule:
        return "cross_module";
      case kCrossReplica:
        return "cross_replica";
    }
  }

  std::string ToString() const {
    return absl::StrFormat(
        "RendezvousKey{run_id=%s, global_devices=[%s], "
        "num_local_participants=%d, collective_op_kind=%s, op_id=%d}",
        run_id.ToString(), GlobalDeviceIdsToString(global_devices),
        num_local_participants, CollectiveOpKindString(), op_id);
  }

  RunId run_id;
  std::vector<GlobalDeviceId> global_devices;
  int num_local_participants;
  CollectiveOpKind collective_op_kind;
  int64_t op_id;
};

template <typename DescFn>
void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
                       const DescFn& desc_fn) {
  VLOG(3) << "Begin: " << desc_fn();
  const std::chrono::milliseconds timeout(5000);
  bool ok = counter->WaitFor(timeout);
  if (ok) {
    VLOG(3) << "Finished: " << desc_fn();
    return;
  }
  LOG(ERROR) << "This thread has been waiting for " << timeout.count()
             << "ms and may be stuck: " << desc_fn();
  counter->Wait();
  LOG(ERROR) << "Thread is unstuck! Warning above was a false-positive. "
                "Perhaps the timeout is too short: "
             << desc_fn();
}

// Participant data for each rendezvous.
struct ParticipantData {
  explicit ParticipantData(const RendezvousKey& rendezvous_key)
      : rendezvous_key(rendezvous_key) {}

  virtual ~ParticipantData() {}

  RendezvousKey rendezvous_key;

  virtual std::string ToString() const = 0;
};

// Encapsulates parameters to Rendezvous::SubmitParticipant.
struct AllReduceParticipantData : ParticipantData {
  AllReduceParticipantData(const RendezvousKey& rendezvous_key_p,
                           int64_t device_ordinal_p, se::Stream* stream_p)
      : ParticipantData(rendezvous_key_p),
        device_ordinal(device_ordinal_p),
        stream(stream_p) {}

  // TODO(b/125951860): We should vet that we're buffer allocating such that
  // source_buffer == destination_buffer if that avoids an NCCL copy (will
  // depend on how well the NCCL in-place implementation performs vs the
  // out-of-place implementation).
  struct Buffer {
    int64_t element_count;
    se::DeviceMemoryBase source_data;
    se::DeviceMemoryBase destination_data;
    PrimitiveType primitive_type;
  };
  int64_t device_ordinal;
  se::Stream* stream;
  std::vector<Buffer> buffers;

  ReductionKind reduction_kind;

  // For each local all-reduce participant, a (global ID, local device ordinal)
  // pair for the participant. Participants are in no particular order.
  std::vector<std::pair<GlobalDeviceId, int64_t>> local_devices;

  std::string ToString() const override {
    std::vector<std::string> buffer_strs;
    for (const Buffer& buffer : buffers) {
      buffer_strs.push_back(
          absl::StrFormat("{element_count=%d}", buffer.element_count));
    }
    return absl::StrFormat(
        "AllReduceParticipantData{buffers=[%s], rendezvous_key=%s, "
        "device_ordinal=%d, stream=%p}",
        absl::StrJoin(buffer_strs, ","), rendezvous_key.ToString(),
        device_ordinal, stream);
  }
};

// The set of threads that want to do a collective op together all pick the
// same Rendezvous object out of the global cache and call SubmitParticipant.
//
// The Rendezvous instance handles waiting for all threads to join, ensuring
// that a clique exists for the desired set of GPUs, etc.
//
// Rendezvous objects can only be used once.
//
// I: Participant data.
// O: Participant output.
template <typename I, typename O,
          typename =
              std::enable_if_t<std::is_base_of<ParticipantData, I>::value>>
class Rendezvous {
 public:
  virtual ~Rendezvous() {}
  explicit Rendezvous(const RendezvousKey& k) : key_(k) {}

  // Submit a participant to the rendezvous. We get the rendezvous from
  // `rendezvous_getter`, which we can then use to drop the existing reference.
  static StatusOr<O> SubmitParticipant(
      std::function<std::shared_ptr<Rendezvous<I, O>>()> rendezvous_getter,
      I participant) {
    std::shared_ptr<Rendezvous<I, O>> rendezvous = rendezvous_getter();
    TF_ASSIGN_OR_RETURN(auto p, rendezvous->SubmitParticipant(participant));

    // Drop our reference to the Rendezvous and wait for all other threads to
    // do the same. If we didn't do this, one of the threads could run past
    // this point, reenter ExecuteOnStream for another all-reduce, and attempt
    // to reuse the Rendezvous!
    //
    // An alternative way of accomplishing this goal would be to implement
    // RefcountingHashMap::erase() and call it during SubmitParticipant. But
    // erase() is deceptively complex to implement correctly.
    std::shared_ptr<tensorflow::BlockingCounter> blocking_counter = p.second;
    rendezvous.reset();
    blocking_counter->DecrementCount();
    xla::WaitAndLogIfStuck(blocking_counter.get(), [&] {
      return absl::StrFormat(
          "participant waiting for all threads to drop their reference to the "
          "rendezvous: %p",
          rendezvous.get());
    });
    return std::move(p.first);
  }

 protected:
  // Returns domain-specific output O and whether this replica is primary.
  virtual StatusOr<O> RunCollectiveOp(const I& participant) = 0;

  // Initialize the rendezvous by the first ("primary") thread which reaches
  // the barrier. Returns whether this thread is primary.
  bool InitializationBarrier() {
    absl::MutexLock lock(&mu_);
    if (!initialized_) {
      initialized_ = true;
      return true;
    }
    return false;
  }

  absl::Mutex mu_;

  bool initialized_ ABSL_GUARDED_BY(mu_) = false;

  std::vector<I> participants_ ABSL_GUARDED_BY(mu_);

 private:
  // Runs the all-reduce on the given thread. If successful, returns
  //  - a handle to the clique that was used, so that the caller may keep the
  //    clique alive if it chooses.
  //  - a BlockingCounter initialized to the number of participants, so that
  //    the caller can coordinate with the participants one last time if it
  //    chooses. This is useful for coordinating destruction of the Rendezvous.
  StatusOr<std::pair<O, std::shared_ptr<tensorflow::BlockingCounter>>>
  SubmitParticipant(const I& participant) {
    {
      absl::MutexLock lock(&mu_);
      CHECK(!initialized_);

      // Spot check for consistent replica counts among submitting threads.
      if (!participants_.empty() &&
          participants_.back().rendezvous_key != participant.rendezvous_key) {
        return InvalidArgument(
            "Mismatch among all-reduce participants. Expected same "
            "replica-count, element-count, and rendezvous-key but were %s and "
            "%s",
            participants_.back().ToString(), participant.ToString());
      }
      participants_.push_back(participant);
    }

    // Wait for all participants to arrive.
    all_participants_present_.DecrementCount();
    WaitAndLogIfStuck(&all_participants_present_, [&] {
      return absl::StrFormat(
          "participant %s waiting for all participants to arrive at "
          "rendezvous %s",
          participant.ToString(), key_.ToString());
    });

    TF_ASSIGN_OR_RETURN(O output, RunCollectiveOp(participant));
    return std::make_pair(std::move(output), returned_blocking_counter_);
  }

  const RendezvousKey key_;

  tensorflow::BlockingCounter all_participants_present_{
      key_.num_local_participants};

  // tensorflow::BlockingCounter returned by SubmitParticipant.
  std::shared_ptr<tensorflow::BlockingCounter> returned_blocking_counter_{
      std::make_shared<tensorflow::BlockingCounter>(
          key_.num_local_participants)};
};
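
// A minimal usage sketch (illustrative only; DummyOutput and DummyRendezvous
// below are hypothetical and not part of this header):
//
//   struct DummyOutput {};
//
//   class DummyRendezvous
//       : public Rendezvous<AllReduceParticipantData, DummyOutput> {
//    public:
//     explicit DummyRendezvous(const RendezvousKey& k) : Rendezvous(k) {}
//
//    protected:
//     StatusOr<DummyOutput> RunCollectiveOp(
//         const AllReduceParticipantData& participant) override {
//       // The primary thread (the first one to reach the barrier) would do
//       // any one-time setup here.
//       bool primary = InitializationBarrier();
//       (void)primary;
//       return DummyOutput{};
//     }
//   };
//
//   // Each participating thread then calls
//   //   DummyRendezvous::SubmitParticipant(rendezvous_getter, participant);
//   // where rendezvous_getter returns the shared rendezvous for the key.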

}  // end namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_