/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"

#include <chrono>  // NOLINT (required by TF interfaces)
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"

namespace xla {
namespace gpu {

// This file runs collective ops (i.e. ops that communicate between multiple
// GPUs) using NCCL.
//
// Here's a high-level overview of how running an op works.
//
// - Multiple threads call ExecuteOnStream.
// - All threads that "go together" (i.e. are participating in the "same"
//   collective op) choose the same Rendezvous object from a global map.
// - Once all threads have arrived at the Rendezvous, we know exactly which
//   GPUs are participating in the op, so we get or create a NcclClique
//   containing those GPUs.
// - We perform the NCCL operation using the clique.
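//
// As a condensed illustration of that flow (a simplified excerpt of what
// ExecuteOnStream further down in this file actually does; the bare
// `replica_groups`, `group_mode`, and `op_id` names stand in for the
// corresponding config() fields):
//
//   TF_ASSIGN_OR_RETURN(NcclComm::Lock comm,
//                       LockNcclComm(params.nccl_params, replica_groups,
//                                    group_mode, op_id));
//   TF_RETURN_IF_ERROR(RunNcclCollective(params, *comm));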

NcclCollectiveConfig::NcclCollectiveConfig() = default;
NcclCollectiveConfig::NcclCollectiveConfig(NcclCollectiveConfig&&) = default;
NcclCollectiveConfig::~NcclCollectiveConfig() = default;
NcclCollectiveConfig& NcclCollectiveConfig::operator=(NcclCollectiveConfig&&) =
    default;

// Returns true if the collective communication operation is degenerate because
// all the groups formed by the operation are singleton. A given op can be
// degenerate under several conditions, corresponding to the modes supported
// in GetParticipatingDevices().
// 1. no channel id, use_global_device_ids = false:
//    degenerate if replica_groups are singleton, or groups empty and
//    replica_count == 1.
// 2. channel_id is set, use_global_device_ids = false:
//    degenerate if replica_groups are singleton and num_partitions == 1,
//    or groups empty and num_replicas == 1 && num_partitions == 1.
// 3. channel_id is set, use_global_device_ids = true (flattened-ids):
//    degenerate if replica_groups are singleton (groups cannot be empty).
// 4. no channel_id, no use_global_device_ids:
//    identical to 1.
// 5. channel_id is set, no use_global_device_ids:
//    degenerate if replica_groups are singleton or groups empty and
//    num_partitions == 1 (since replica groups contain partition ids).
//
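// For example, in kCrossReplica mode with replica_count == 4 and
// partition_count == 1 (illustrative numbers), replica_groups of
// {{0}, {1}, {2}, {3}} are all singleton, so each device only "communicates"
// with itself and the op is degenerate; groups {{0, 1}, {2, 3}} are not.
// A hypothetical caller-side check might look like:
//
//   if (config.IsDegenerate(/*replica_count=*/4, /*partition_count=*/1)) {
//     // No cross-device communication is required for this op.
//   }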
bool NcclCollectiveConfig::IsDegenerate(int64_t replica_count,
                                        int64_t partition_count) const {
  bool groups_empty = replica_groups.empty();

  // check if all replica_groups are singleton. If not, then the operation is
  // not degenerate.
  bool all_groups_singleton =
      !groups_empty &&
      absl::c_all_of(replica_groups, [](const ReplicaGroup& group) {
        return group.replica_ids_size() == 1;
      });

  switch (group_mode) {
    case CollectiveOpGroupMode::kCrossReplica:
      return all_groups_singleton || (groups_empty && replica_count == 1);
    case CollectiveOpGroupMode::kCrossPartition:
      return all_groups_singleton || (groups_empty && partition_count == 1);
    case CollectiveOpGroupMode::kCrossReplicaAndPartition:
      return (all_groups_singleton && partition_count == 1) ||
             (groups_empty && replica_count == 1 && partition_count == 1);
    case CollectiveOpGroupMode::kFlattenedID:
      CHECK(!groups_empty)
          << "replica groups cannot be empty if use_global_device_ids = true";
      return all_groups_singleton;
    default:
      CHECK(0) << "Invalid collective op mode";
      return false;
  }
}

/* static */ bool NcclCollectiveThunk::NcclIsEnabled() {
#if XLA_ENABLE_XCCL
  return true;
#else
  return false;
#endif
}

#if XLA_ENABLE_XCCL
StatusOr<NcclComm::Lock> LockNcclComm(
    const NcclExecuteParams& params,
    const std::vector<ReplicaGroup>& replica_groups,
    CollectiveOpGroupMode group_mode, int64_t op_id) {
  TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
                      params.GetGlobalDeviceId());

  TF_ASSIGN_OR_RETURN(
      std::vector<GlobalDeviceId> participants,
      GetParticipatingDevices(global_device_id, *params.device_assn,
                              replica_groups, group_mode));

  if (IsGlobalNcclConfig() &&
      (participants.size() != params.device_assn->replica_count())) {
    return InvalidArgument(
        "Partial replica groups are not allowed when using NCCL_COMM_ID "
        "environment configuration.");
  }

  auto it = absl::c_find(participants, global_device_id);
  TF_RET_CHECK(it != participants.end());
  int rank = it - participants.begin();

  size_t num_local_participants = GetNumLocalParticipants(
      participants, /*local_devices=*/params.gpu_global_device_ids);

  bool is_local = participants.size() == num_local_participants;
  TF_ASSIGN_OR_RETURN(
      const NcclUniqueIdCallback* unique_id_callback,
      GetNcclUniqueIdCallback(params.nccl_unique_id_callback, is_local));

  se::StreamExecutor* executor = params.stream->parent();
  se::gpu::ScopedActivateExecutorContext scoped_context(executor);

  return AcquireNcclComm(params.run_id, OpId(op_id), std::move(participants),
                         num_local_participants, *unique_id_callback, rank);
}
#endif  // XLA_ENABLE_XCCL

StatusOr<std::vector<DeviceBufferPair>> ConvertToDeviceBuffers(
    const Thunk::ExecuteParams& params,
    const std::vector<NcclCollectiveThunk::Buffer>& buffers,
    const std::vector<PrimitiveType>& element_types) {
  if (buffers.size() != element_types.size())
    return FailedPrecondition("Mismatch in operand buffer counts.");

  std::vector<DeviceBufferPair> device_buffers;
  device_buffers.reserve(buffers.size());
  for (int i = 0; i < buffers.size(); ++i) {
    device_buffers.emplace_back(DeviceBufferPair{
        element_types[i], buffers[i].element_count,
        params.buffer_allocations->GetDeviceAddress(buffers[i].source_buffer),
        params.buffer_allocations->GetDeviceAddress(
            buffers[i].destination_buffer)});
  }
  return device_buffers;
}

Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
#if XLA_ENABLE_XCCL
  VLOG(1) << absl::StreamFormat("Starting %s.", Thunk::KindToString(kind()));
  TF_ASSIGN_OR_RETURN(NcclComm::Lock comm,
                      LockNcclComm(params.nccl_params, config().replica_groups,
                                   config().group_mode, config().op_id));

  TF_RETURN_IF_ERROR(RunNcclCollective(params, *comm));

  // Block host on the first call to ensure that all devices have allocated the
  // required buffers for their communicators before allowing any device to
  // continue enqueuing operations. Otherwise, the allocations can cause
  // deadlock in the CUDA driver (b/215649390).
  if (first_call_to_execute_) {
    TF_RETURN_IF_ERROR(params.stream->BlockHostUntilDone());
    first_call_to_execute_ = false;
  }
  return OkStatus();
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

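// Builds a short human-readable identifier for the device running this op.
// For example, assuming replica 0, partition 1, global device ID 1 and CUDA
// device ordinal 0 (purely illustrative values), the returned string would be
//   "(r0, p1) : GlobalID 1, ord 0"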
std::string NcclCollectiveThunk::GetDeviceString(
    const NcclExecuteParams& nccl_params) {
  int device_ordinal = nccl_params.stream->parent()->device_ordinal();
  GlobalDeviceId global_device_id =
      nccl_params.GetGlobalDeviceId().ValueOrDie();
  DeviceAssignment::LogicalID logical_id =
      nccl_params.device_assn->LogicalIdForDevice(global_device_id)
          .ValueOrDie();
  return absl::StrFormat("(r%d, p%d) : GlobalID %d, ord %d",
                         logical_id.replica_id, logical_id.computation_id,
                         global_device_id.value(), device_ordinal);
}

bool IsTypeSupportedByNccl(PrimitiveType element_type) {
  switch (element_type) {
    case S8:
    case PRED:
    case U8:
    case S32:
    case U32:
    case S64:
    case U64:
    case F16:
    case F32:
    case F64:
#if defined(__CUDA_BF16_TYPES_EXIST__)
    case BF16:
#endif
    case C64:
    case C128:
      return true;
    default:
      return false;
  }
}

}  // namespace gpu
}  // namespace xla