xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"

#include <chrono>  // NOLINT (required by TF interfaces)
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"

namespace xla {
namespace gpu {

// This file runs collective ops (i.e. ops that communicate between multiple
// GPUs) using NCCL.
//
// Here's a high-level overview of how running an op works.
//
//  - Multiple threads call ExecuteOnStream.
//  - All threads that "go together" (i.e. are participating in the "same"
//    collective op) choose the same Rendezvous object from a global map.
//  - Once all threads have arrived at the Rendezvous, we know exactly which
//    GPUs are participating in the op, so we get or create a NcclClique
//    containing those GPUs.
//  - We perform the NCCL operation using the clique.
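//
// A rough sketch of that flow from a caller's point of view (illustrative
// only; `thunk` and `params` are hypothetical locals, not part of this file):
//
//   // Each participating device's launcher thread executes the same thunk:
//   TF_RETURN_IF_ERROR(thunk.ExecuteOnStream(params));
//   // Internally this acquires the NCCL communicator for the clique of
//   // participating GPUs and enqueues the collective on params.stream.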

NcclCollectiveConfig::NcclCollectiveConfig() = default;
NcclCollectiveConfig::NcclCollectiveConfig(NcclCollectiveConfig&&) = default;
NcclCollectiveConfig::~NcclCollectiveConfig() = default;
NcclCollectiveConfig& NcclCollectiveConfig::operator=(NcclCollectiveConfig&&) =
    default;

// Returns true if the collective communication operation is degenerate because
// all the groups formed by the operation are singletons. A given op can be
// degenerate under several conditions, corresponding to the modes supported
// in GetParticipatingDevices().
//   1. no channel_id, use_global_device_ids = false:
//         degenerate if replica_groups are singleton, or groups empty and
//         replica_count == 1.
//   2. channel_id is set, use_global_device_ids = false:
//         degenerate if replica_groups are singleton and num_partitions == 1,
//         or groups empty and num_replicas == 1 && num_partitions == 1.
//   3. channel_id is set, use_global_device_ids = true (flattened-ids):
//         degenerate if replica_groups are singleton (groups cannot be empty).
//   4. no channel_id, no use_global_device_ids:
//         identical to 1.
//   5. channel_id is set, no use_global_device_ids:
//         degenerate if replica_groups are singleton or groups empty and
//         num_partitions == 1 (since replica groups contain partition ids).
//
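// A minimal illustration of these rules (a sketch added for clarity, not part
// of the original comment; `config` is a hypothetical NcclCollectiveConfig):
//
//   ReplicaGroup group;
//   group.add_replica_ids(0);              // singleton group {0}
//   config.replica_groups = {group};
//   config.group_mode = CollectiveOpGroupMode::kCrossReplica;
//   config.IsDegenerate(/*replica_count=*/1, /*partition_count=*/1);  // true
//
// With a two-element group such as {0, 1}, the same call returns false.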
bool NcclCollectiveConfig::IsDegenerate(int64_t replica_count,
                                        int64_t partition_count) const {
  bool groups_empty = replica_groups.empty();

  // check if all replica_groups are singleton. If not, then the operation is
  // not degenerate.
  bool all_groups_singleton =
      !groups_empty &&
      absl::c_all_of(replica_groups, [](const ReplicaGroup& group) {
        return group.replica_ids_size() == 1;
      });

  switch (group_mode) {
    case CollectiveOpGroupMode::kCrossReplica:
      return all_groups_singleton || (groups_empty && replica_count == 1);
    case CollectiveOpGroupMode::kCrossPartition:
      return all_groups_singleton || (groups_empty && partition_count == 1);
    case CollectiveOpGroupMode::kCrossReplicaAndPartition:
      return (all_groups_singleton && partition_count == 1) ||
             (groups_empty && replica_count == 1 && partition_count == 1);
    case CollectiveOpGroupMode::kFlattenedID:
      CHECK(!groups_empty)
          << "replica groups cannot be empty if use_global_device_ids = true";
      return all_groups_singleton;
    default:
      CHECK(0) << "Invalid collective op mode";
      return false;
  }
}

/* static */ bool NcclCollectiveThunk::NcclIsEnabled() {
#if XLA_ENABLE_XCCL
  return true;
#else
  return false;
#endif
}

#if XLA_ENABLE_XCCL
StatusOr<NcclComm::Lock> LockNcclComm(
    const NcclExecuteParams& params,
    const std::vector<ReplicaGroup>& replica_groups,
    CollectiveOpGroupMode group_mode, int64_t op_id) {
  TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
                      params.GetGlobalDeviceId());

  TF_ASSIGN_OR_RETURN(
      std::vector<GlobalDeviceId> participants,
      GetParticipatingDevices(global_device_id, *params.device_assn,
                              replica_groups, group_mode));

  if (IsGlobalNcclConfig() &&
      (participants.size() != params.device_assn->replica_count())) {
    return InvalidArgument(
        "Partial replica groups are not allowed when using NCCL_COMM_ID "
        "environment configuration.");
  }

  auto it = absl::c_find(participants, global_device_id);
  TF_RET_CHECK(it != participants.end());
  int rank = it - participants.begin();

  size_t num_local_participants = GetNumLocalParticipants(
      participants, /*local_devices=*/params.gpu_global_device_ids);

  bool is_local = participants.size() == num_local_participants;
  TF_ASSIGN_OR_RETURN(
      const NcclUniqueIdCallback* unique_id_callback,
      GetNcclUniqueIdCallback(params.nccl_unique_id_callback, is_local));

  se::StreamExecutor* executor = params.stream->parent();
  se::gpu::ScopedActivateExecutorContext scoped_context(executor);

  return AcquireNcclComm(params.run_id, OpId(op_id), std::move(participants),
                         num_local_participants, *unique_id_callback, rank);
}
#endif  // XLA_ENABLE_XCCL

StatusOr<std::vector<DeviceBufferPair>> ConvertToDeviceBuffers(
    const Thunk::ExecuteParams& params,
    const std::vector<NcclCollectiveThunk::Buffer>& buffers,
    const std::vector<PrimitiveType>& element_types) {
  if (buffers.size() != element_types.size())
    return FailedPrecondition("Mismatch in operand buffer counts.");

  std::vector<DeviceBufferPair> device_buffers;
  device_buffers.reserve(buffers.size());
  for (int i = 0; i < buffers.size(); ++i) {
    device_buffers.emplace_back(DeviceBufferPair{
        element_types[i], buffers[i].element_count,
        params.buffer_allocations->GetDeviceAddress(buffers[i].source_buffer),
        params.buffer_allocations->GetDeviceAddress(
            buffers[i].destination_buffer)});
  }
  return device_buffers;
}

Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
#if XLA_ENABLE_XCCL
  VLOG(1) << absl::StreamFormat("Starting %s.", Thunk::KindToString(kind()));
  TF_ASSIGN_OR_RETURN(NcclComm::Lock comm,
                      LockNcclComm(params.nccl_params, config().replica_groups,
                                   config().group_mode, config().op_id));

  TF_RETURN_IF_ERROR(RunNcclCollective(params, *comm));

  // Block host on the first call to ensure that all devices have allocated the
  // required buffers for their communicators before allowing any device to
  // continue enqueuing operations. Otherwise, the allocations can cause
  // deadlock in the CUDA driver (b/215649390).
  if (first_call_to_execute_) {
    TF_RETURN_IF_ERROR(params.stream->BlockHostUntilDone());
    first_call_to_execute_ = false;
  }
  return OkStatus();
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

std::string NcclCollectiveThunk::GetDeviceString(
    const NcclExecuteParams& nccl_params) {
  int device_ordinal = nccl_params.stream->parent()->device_ordinal();
  GlobalDeviceId global_device_id =
      nccl_params.GetGlobalDeviceId().ValueOrDie();
  DeviceAssignment::LogicalID logical_id =
      nccl_params.device_assn->LogicalIdForDevice(global_device_id)
          .ValueOrDie();
  return absl::StrFormat("(r%d, p%d) : GlobalID %d, ord %d",
                         logical_id.replica_id, logical_id.computation_id,
                         global_device_id.value(), device_ordinal);
}

bool IsTypeSupportedByNccl(PrimitiveType element_type) {
  switch (element_type) {
    case S8:
    case PRED:
    case U8:
    case S32:
    case U32:
    case S64:
    case U64:
    case F16:
    case F32:
    case F64:
#if defined(__CUDA_BF16_TYPES_EXIST__)
    case BF16:
#endif
    case C64:
    case C128:
      return true;
    default:
      return false;
  }
}

}  // namespace gpu
}  // namespace xla