/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
17
18 #include <memory>
19 #include <string_view>
20 #include <utility>
21
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/strings/str_format.h"
24 #include "absl/synchronization/notification.h"
25 #include "absl/time/time.h"
26 #include "tensorflow/compiler/xla/debug_options_flags.h"
27 #include "tensorflow/compiler/xla/service/global_device_id.h"
28 #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
29 #include "tensorflow/compiler/xla/service/rendezvous.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/core/platform/env.h"
33
namespace xla {
namespace gpu {

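// Returns true if the NCCL_COMM_ID environment variable is set, in which case
// the NCCL configuration is assumed to be global across hosts.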
bool IsGlobalNcclConfig() {
  static const bool global_nccl_config = std::getenv("NCCL_COMM_ID") != nullptr;
  return global_nccl_config;
}

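// Returns true if the NCCL_LAUNCH_MODE environment variable is set to
// "PARALLEL".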
bool IsNcclLaunchModeParallel() {
  static const bool is_launch_mode_parallel = []() {
    const char* launch_mode = std::getenv("NCCL_LAUNCH_MODE");
    return launch_mode && std::string_view(launch_mode) == "PARALLEL";
  }();
  return is_launch_mode_parallel;
}

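// Converts an ncclResult_t into a Status, recording the file, line and
// expression that produced the failing result.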
Status ToStatus(ncclResult_t s, const char* file, int64_t line,
                const char* expr) {
  if (s == ncclSuccess) {
    return OkStatus();
  }
  return tensorflow::errors::Internal(
      absl::StrFormat("%s:%d: NCCL operation %s failed: %s", file, line, expr,
                      ncclGetErrorString(s)));
}

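// Maps an XLA ReductionKind to the corresponding NCCL reduction operation.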
ncclRedOp_t ToNcclReduction(ReductionKind kind) {
  switch (kind) {
    case ReductionKind::SUM:
      return ncclSum;
    case ReductionKind::PRODUCT:
      return ncclProd;
    case ReductionKind::MIN:
      return ncclMin;
    case ReductionKind::MAX:
      return ncclMax;
  }
}

namespace {

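// Maps an XLA primitive type to the NCCL data type used on the wire. NCCL has
// no complex types, so complex values are transferred as twice as many reals
// of the matching precision (see ToNcclDataTypeAndCountMultiplier below).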
StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType element_type) {
  switch (element_type) {
    case S8:
      return ncclInt8;
    case PRED:
    case U8:
      return ncclUint8;
    case S32:
      return ncclInt32;
    case U32:
      return ncclUint32;
    case S64:
      return ncclInt64;
    case U64:
      return ncclUint64;
    case F16:
      return ncclFloat16;
    case F32:
    case C64:
      return ncclFloat32;
    case F64:
    case C128:
      return ncclFloat64;
#if defined(__CUDA_BF16_TYPES_EXIST__)
    case BF16:
      return ncclBfloat16;
#endif
    default:
      return tensorflow::errors::InvalidArgument(absl::StrFormat(
          "Unsupported data type: %s", PrimitiveType_Name(element_type)));
  }
}

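// Deserializes an ncclUniqueId from its string representation, as produced by
// LocalNcclUniqueIdCallback below or by a client-provided callback.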
StatusOr<ncclUniqueId> ToNcclUniqueId(const std::string& id_str) {
  static_assert(sizeof(ncclUniqueId) == NCCL_UNIQUE_ID_BYTES,
                "ncclUniqueId size mismatch");

  TF_RET_CHECK(id_str.size() == NCCL_UNIQUE_ID_BYTES);
  ncclUniqueId id;
  absl::c_copy(id_str, id.internal);
  return id;
}

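// Single-host fallback: obtains a fresh unique id from NCCL directly, with no
// cross-host exchange.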
StatusOr<std::string> LocalNcclUniqueIdCallback(const NcclCliqueKey&) {
  ncclUniqueId id;
  XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&id));
  return std::string(id.internal, NCCL_UNIQUE_ID_BYTES);
}

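// State shared by all communicators of one clique (one NcclCliqueKey);
// 'run_id' records the most recent run that acquired the clique.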
struct NcclCliqueState {
  ncclUniqueId unique_id;
  int64_t run_id = -1;

  // mu guards ready, status, and communicators during initialization.
  // Once 'ready' has been notified, the communicators may be accessed without
  // synchronization.
  absl::Mutex mu;
  absl::Notification ready;
  Status status;
  absl::flat_hash_map<int, std::unique_ptr<NcclComm>> communicators;
};

using NcclClique = Lockable<NcclCliqueState>;

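// Acquires a lock on the clique for the given key. All
// 'num_local_participants' threads rendezvous here, and one of them runs the
// initializer below, which obtains (or creates) the clique's NCCL unique id
// and takes the lock on behalf of the whole group.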
std::shared_ptr<StatusOr<NcclClique::Lock>> AcquireNcclClique(
    RunId run_id, OpId op_id, NcclCliqueKey clique_key,
    const NcclUniqueIdCallback& unique_id_callback,
    size_t num_local_participants) {
  static auto& cliques = *new ThreadSafeMap<NcclCliqueKey, NcclClique>;

  auto rendezvous_key = std::make_tuple(run_id, op_id, std::move(clique_key));

  int64_t terminate_timeout = xla::GetDebugOptionsFromFlags()
                                  .xla_gpu_nccl_termination_timeout_seconds();

  return RendezvousSingle<StatusOr<NcclClique::Lock>>(
      rendezvous_key, num_local_participants,
      [&]() -> StatusOr<NcclClique::Lock> {
        const NcclCliqueKey& clique_key = std::get<2>(rendezvous_key);
        NcclClique::Lock clique = cliques[clique_key].Acquire();
        if (clique->run_id < 0) {
          TF_ASSIGN_OR_RETURN(std::string id, unique_id_callback(clique_key));
          TF_ASSIGN_OR_RETURN(clique->unique_id, ToNcclUniqueId(id));
        }
        // If multiple executables are running simultaneously while using
        // multiple hosts, it is possible that different executables could
        // acquire the same clique on different hosts. We protect against this
        // by checking that the run ID increases monotonically.
        bool is_local = clique_key.devices().size() == num_local_participants;
        TF_RET_CHECK(is_local || (run_id.ToInt() >= clique->run_id));
        clique->run_id = run_id.ToInt();
        return clique;
      },
      /*warn_stuck_timeout=*/absl::Seconds(10),
      (terminate_timeout >= 0) ? absl::Seconds(terminate_timeout)
                               : absl::InfiniteDuration());
}

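// Checks a communicator for asynchronous NCCL errors and aborts it if one is
// found, so that in-flight operations fail fast instead of hanging.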
void CheckNcclAsyncError(NcclComm& lockable_comm) {
  ncclComm_t comm = *lockable_comm.Acquire();
  if (comm == nullptr) return;

  Status status = [comm] {
    ncclResult_t async_err;
    XLA_CUDA_RETURN_IF_ERROR(ncclCommGetAsyncError(comm, &async_err));
    if (async_err != ncclSuccess) {
      LOG(ERROR) << "Aborting communicator: " << comm
                 << " due to async NCCL error: "
                 << ncclGetErrorString(async_err);
      XLA_CUDA_RETURN_IF_ERROR(ncclCommAbort(comm));
    }
    return XLA_CUDA_STATUS(async_err);
  }();

  if (!status.ok()) LOG(ERROR) << status.ToString();
}

}  // namespace

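// Returns the NCCL data type for 'element_type' together with a multiplier to
// apply to the element count. NCCL has no complex types, so for example a
// buffer of n C64 values is passed to NCCL as 2*n ncclFloat32 values:
//
//   TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
//                       ToNcclDataTypeAndCountMultiplier(C64));
//   // dtype_and_multiplier == std::pair{ncclFloat32, 2}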
StatusOr<std::pair<ncclDataType_t, int>> ToNcclDataTypeAndCountMultiplier(
    PrimitiveType element_type) {
  TF_ASSIGN_OR_RETURN(ncclDataType_t dtype, ToNcclDataType(element_type));
  bool is_complex = primitive_util::IsComplexType(element_type);
  return std::make_pair(dtype, is_complex ? 2 : 1);
}

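// Counts how many of 'participants' are local to this host; if no local
// device list is given, all participants are considered local.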
size_t GetNumLocalParticipants(
    const std::vector<GlobalDeviceId>& participants,
    const std::vector<GlobalDeviceId>* local_devices) {
  if (local_devices == nullptr) return participants.size();

  return absl::c_count_if(participants, [&](const GlobalDeviceId& device_id) {
    return absl::c_linear_search(*local_devices, device_id);
  });
}

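// Returns the client-provided unique-id callback if there is one. Otherwise
// falls back to asking NCCL directly, which is only valid when all
// participants are local or NCCL is configured globally via NCCL_COMM_ID.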
StatusOr<const NcclUniqueIdCallback*> GetNcclUniqueIdCallback(
    const NcclUniqueIdCallback* unique_id_callback, bool is_local) {
  if (unique_id_callback != nullptr) return unique_id_callback;

  TF_RET_CHECK(is_local || IsGlobalNcclConfig())
      << "If non-local devices are taking part in a collective API on "
         "GPU, the nccl_unique_id_callback must be provided by the client.";

  static NcclUniqueIdCallback local_callback(LocalNcclUniqueIdCallback);
  return &local_callback;
}

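// Acquires an initialized communicator for 'rank' within the clique formed by
// 'participants'. The first group of threads to acquire a clique initializes
// its communicators with ncclCommInitRank and blocks until every local rank
// has finished initializing; later acquisitions reuse the cached
// communicators.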
StatusOr<NcclComm::Lock> AcquireNcclComm(
    RunId run_id, OpId op_id, std::vector<GlobalDeviceId> participants,
    size_t num_local_participants,
    const NcclUniqueIdCallback& unique_id_callback, int rank) {
  // Ensure that this group of threads has exclusive access to the clique to
  // prevent threads from different groups locking communicators in the clique.
  NcclCliqueKey clique_key(std::move(participants));
  std::shared_ptr<StatusOr<NcclClique::Lock>> clique = AcquireNcclClique(
      run_id, op_id, clique_key, unique_id_callback, num_local_participants);

  if (!clique->ok()) return clique->status();
  NcclCliqueState& state = ***clique;

  struct AllCommunicators {
    absl::Mutex mu;
    std::vector<NcclComm*> communicators ABSL_GUARDED_BY(mu);
  };
  static auto& all_communicators = *new AllCommunicators;

  // Launch a thread that periodically checks all NCCL communicators for
  // asynchronous errors. If an asynchronous error is observed, the
  // communicator is aborted and an error message logged.
  static auto check_async_error_thread =
      tensorflow::Env::Default()->StartThread(
          tensorflow::ThreadOptions(), "nccl_async_error_thread", [&] {
            while (true) {
              absl::SleepFor(absl::Seconds(30));
              absl::MutexLock lock(&all_communicators.mu);
              for (NcclComm* comm : all_communicators.communicators) {
                CheckNcclAsyncError(*comm);
              }
            }
          });
  (void)check_async_error_thread;  // Silence unused variable warning.

  NcclComm::Lock comm;
  if (state.ready.HasBeenNotified()) {
    comm = state.communicators[rank]->Acquire();
  } else {
    auto comm_ptr = std::make_unique<NcclComm>();
    comm = comm_ptr->Acquire();
    int nranks = clique_key.devices().size();
    const ncclUniqueId& id = state.unique_id;
    Status status =
        XLA_CUDA_STATUS(ncclCommInitRank(comm.get(), nranks, id, rank));

    // Add the communicator to the all_communicators list.
    {
      absl::MutexLock lock(&all_communicators.mu);
      all_communicators.communicators.push_back(comm_ptr.get());
    }

    absl::MutexLock lock(&state.mu);
    state.status.Update(status);
    state.communicators[rank] = std::move(comm_ptr);

    // Wait for all communicators to initialize before allowing any progress.
    // Otherwise we may get deadlocks, because ncclCommInitRank may allocate,
    // which may block on the completion of device activity on a peer device,
    // which may depend on the completion of this collective if we do not have
    // a barrier to prevent it.
    auto all_initialized = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state.mu) {
      return state.communicators.size() == num_local_participants;
    };
    state.mu.Await(absl::Condition(&all_initialized));
    status = state.status;
    if (!state.ready.HasBeenNotified()) {
      state.ready.Notify();
    }
  }
  if (!state.status.ok()) {
    return state.status;
  }
  return comm;
}

}  // namespace gpu
}  // namespace xla