xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
17 
18 #include <memory>
19 #include <string_view>
20 #include <utility>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/strings/str_format.h"
24 #include "absl/synchronization/notification.h"
25 #include "absl/time/time.h"
26 #include "tensorflow/compiler/xla/debug_options_flags.h"
27 #include "tensorflow/compiler/xla/service/global_device_id.h"
28 #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
29 #include "tensorflow/compiler/xla/service/rendezvous.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/core/platform/env.h"
33 
34 namespace xla {
35 namespace gpu {
36 
// Reports whether the NCCL_COMM_ID environment variable is set.
//
// The environment is inspected only on the first call; the answer is cached
// in a function-local static, so later changes to the environment are not
// observed.
bool IsGlobalNcclConfig() {
  static const bool has_comm_id = [] {
    const char* comm_id = std::getenv("NCCL_COMM_ID");
    return comm_id != nullptr;
  }();
  return has_comm_id;
}
41 
// Reports whether the NCCL_LAUNCH_MODE environment variable is set to exactly
// "PARALLEL" (case-sensitive).
//
// The environment is inspected only on the first call; the answer is cached
// in a function-local static for all subsequent calls.
bool IsNcclLaunchModeParallel() {
  static const bool parallel_launch = [] {
    const char* mode = std::getenv("NCCL_LAUNCH_MODE");
    if (mode == nullptr) return false;
    return std::string_view(mode) == "PARALLEL";
  }();
  return parallel_launch;
}
49 
ToStatus(ncclResult_t s,const char * file,int64_t line,const char * expr)50 Status ToStatus(ncclResult_t s, const char* file, int64_t line,
51                 const char* expr) {
52   if (s == ncclSuccess) {
53     return OkStatus();
54   }
55   return tensorflow::errors::Internal(
56       absl::StrFormat("%s:%d: NCCL operation %s failed: %s", file, line, expr,
57                       ncclGetErrorString(s)));
58 }
59 
ToNcclReduction(ReductionKind kind)60 ncclRedOp_t ToNcclReduction(ReductionKind kind) {
61   switch (kind) {
62     case ReductionKind::SUM:
63       return ncclSum;
64     case ReductionKind::PRODUCT:
65       return ncclProd;
66     case ReductionKind::MIN:
67       return ncclMin;
68     case ReductionKind::MAX:
69       return ncclMax;
70   }
71 }
72 
73 namespace {
74 
ToNcclDataType(PrimitiveType element_type)75 StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType element_type) {
76   switch (element_type) {
77     case S8:
78       return ncclInt8;
79     case PRED:
80     case U8:
81       return ncclUint8;
82     case S32:
83       return ncclInt32;
84     case U32:
85       return ncclUint32;
86     case S64:
87       return ncclInt64;
88     case U64:
89       return ncclUint64;
90     case F16:
91       return ncclFloat16;
92     case F32:
93     case C64:
94       return ncclFloat32;
95     case F64:
96     case C128:
97       return ncclFloat64;
98 #if defined(__CUDA_BF16_TYPES_EXIST__)
99     case BF16:
100       return ncclBfloat16;
101 #endif
102     default:
103       return tensorflow::errors::InvalidArgument(absl::StrFormat(
104           "Unsupported data type: %s", PrimitiveType_Name(element_type)));
105   }
106 }
107 
ToNcclUniqueId(const std::string & id_str)108 StatusOr<ncclUniqueId> ToNcclUniqueId(const std::string& id_str) {
109   static_assert(sizeof(ncclUniqueId) == NCCL_UNIQUE_ID_BYTES,
110                 "NCCL_UNIQUE_ID_BYTES");
111 
112   TF_RET_CHECK(id_str.size() == NCCL_UNIQUE_ID_BYTES);
113   ncclUniqueId id;
114   absl::c_copy(id_str, id.internal);
115   return id;
116 }
117 
LocalNcclUniqueIdCallback(const NcclCliqueKey &)118 StatusOr<std::string> LocalNcclUniqueIdCallback(const NcclCliqueKey&) {
119   ncclUniqueId id;
120   XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&id));
121   return std::string(id.internal, NCCL_UNIQUE_ID_BYTES);
122 }
123 
124 struct NcclCliqueState {
125   ncclUniqueId unique_id;
126   int64_t run_id = -1;
127 
128   // mu guards ready, status, and communicators during initialization.
129   // Once 'ready' has been notified, the communicators may be accessed without
130   // synchronization.
131   absl::Mutex mu;
132   absl::Notification ready;
133   Status status;
134   absl::flat_hash_map<int, std::unique_ptr<NcclComm>> communicators;
135 };
136 
137 using NcclClique = Lockable<NcclCliqueState>;
138 
AcquireNcclClique(RunId run_id,OpId op_id,NcclCliqueKey clique_key,const NcclUniqueIdCallback & unique_id_callback,size_t num_local_participants)139 std::shared_ptr<StatusOr<NcclClique::Lock>> AcquireNcclClique(
140     RunId run_id, OpId op_id, NcclCliqueKey clique_key,
141     const NcclUniqueIdCallback& unique_id_callback,
142     size_t num_local_participants) {
143   static auto& cliques = *new ThreadSafeMap<NcclCliqueKey, NcclClique>;
144 
145   auto rendezvous_key = std::make_tuple(run_id, op_id, std::move(clique_key));
146 
147   int64_t terminate_timeout = xla::GetDebugOptionsFromFlags()
148                                   .xla_gpu_nccl_termination_timeout_seconds();
149 
150   return RendezvousSingle<StatusOr<NcclClique::Lock>>(
151       rendezvous_key, num_local_participants,
152       [&]() -> StatusOr<NcclClique::Lock> {
153         const NcclCliqueKey& clique_key = std::get<2>(rendezvous_key);
154         NcclClique::Lock clique = cliques[clique_key].Acquire();
155         if (clique->run_id < 0) {
156           TF_ASSIGN_OR_RETURN(std::string id, unique_id_callback(clique_key));
157           TF_ASSIGN_OR_RETURN(clique->unique_id, ToNcclUniqueId(id));
158         }
159         // If multiple executable are running simultaneously while using
160         // multiple hosts, it is possible that different executables could
161         // acquire the same clique on different hosts. We protect against this
162         // by checking that the run ID increases monotonically.
163         bool is_local = clique_key.devices().size() == num_local_participants;
164         TF_RET_CHECK(is_local || (run_id.ToInt() >= clique->run_id));
165         clique->run_id = run_id.ToInt();
166         return clique;
167       },
168       /*warn_stuck_timeout=*/absl::Seconds(10),
169       (terminate_timeout >= 0) ? absl::Seconds(terminate_timeout)
170                                : absl::InfiniteDuration());
171 }
172 
CheckNcclAsyncError(NcclComm & lockable_comm)173 void CheckNcclAsyncError(NcclComm& lockable_comm) {
174   ncclComm_t comm = *lockable_comm.Acquire();
175   if (comm == nullptr) return;
176 
177   Status status = [comm] {
178     ncclResult_t async_err;
179     XLA_CUDA_RETURN_IF_ERROR(ncclCommGetAsyncError(comm, &async_err));
180     if (async_err != ncclSuccess) {
181       LOG(ERROR) << "Aborting communicator: " << comm
182                  << " due to async NCCL error: "
183                  << ncclGetErrorString(async_err);
184       XLA_CUDA_RETURN_IF_ERROR(ncclCommAbort(comm));
185     }
186     return XLA_CUDA_STATUS(async_err);
187   }();
188 
189   if (!status.ok()) LOG(ERROR) << status.ToString();
190 }
191 
192 }  // namespace
193 
ToNcclDataTypeAndCountMultiplier(PrimitiveType element_type)194 StatusOr<std::pair<ncclDataType_t, int>> ToNcclDataTypeAndCountMultiplier(
195     PrimitiveType element_type) {
196   TF_ASSIGN_OR_RETURN(ncclDataType_t dtype, ToNcclDataType(element_type));
197   bool is_complex = primitive_util::IsComplexType(element_type);
198   return std::make_pair(dtype, is_complex ? 2 : 1);
199 }
200 
GetNumLocalParticipants(const std::vector<GlobalDeviceId> & participants,const std::vector<GlobalDeviceId> * local_devices)201 size_t GetNumLocalParticipants(
202     const std::vector<GlobalDeviceId>& participants,
203     const std::vector<GlobalDeviceId>* local_devices) {
204   if (local_devices == nullptr) return participants.size();
205 
206   return absl::c_count_if(participants, [&](const GlobalDeviceId& device_id) {
207     return absl::c_linear_search(*local_devices, device_id);
208   });
209 }
210 
GetNcclUniqueIdCallback(const NcclUniqueIdCallback * unique_id_callback,bool is_local)211 StatusOr<const NcclUniqueIdCallback*> GetNcclUniqueIdCallback(
212     const NcclUniqueIdCallback* unique_id_callback, bool is_local) {
213   if (unique_id_callback != nullptr) return unique_id_callback;
214 
215   TF_RET_CHECK(is_local || IsGlobalNcclConfig())
216       << "If non-local devices are taking part of a collective API on "
217          "GPU, the nccl_unique_id_callback must be provided by the client.";
218 
219   static NcclUniqueIdCallback local_callback(LocalNcclUniqueIdCallback);
220   return &local_callback;
221 }
222 
AcquireNcclComm(RunId run_id,OpId op_id,std::vector<GlobalDeviceId> participants,size_t num_local_participants,const NcclUniqueIdCallback & unique_id_callback,int rank)223 StatusOr<NcclComm::Lock> AcquireNcclComm(
224     RunId run_id, OpId op_id, std::vector<GlobalDeviceId> participants,
225     size_t num_local_participants,
226     const NcclUniqueIdCallback& unique_id_callback, int rank) {
227   // Ensure that this group of threads have exclusive access to the clique to
228   // prevent threads from different groups locking communicators in the clique.
229   NcclCliqueKey clique_key(std::move(participants));
230   std::shared_ptr<StatusOr<NcclClique::Lock>> clique = AcquireNcclClique(
231       run_id, op_id, clique_key, unique_id_callback, num_local_participants);
232 
233   if (!clique->ok()) return clique->status();
234   NcclCliqueState& state = ***clique;
235 
236   struct AllCommunicators {
237     absl::Mutex mu;
238     std::vector<NcclComm*> communicators ABSL_GUARDED_BY(mu);
239   };
240   static auto& all_communicators = *new AllCommunicators;
241 
242   // Launch a thread that periodically checks all NCCL communicators for
243   // asynchronous errors. If an asynchronous error is observed, the communicator
244   // is aborted and an error message logged.
245   static auto check_async_error_thread =
246       tensorflow::Env::Default()->StartThread(
247           tensorflow::ThreadOptions(), "nccl_async_error_thread", [&] {
248             while (true) {
249               absl::SleepFor(absl::Seconds(30));
250               absl::MutexLock lock(&all_communicators.mu);
251               for (NcclComm* comm : all_communicators.communicators) {
252                 CheckNcclAsyncError(*comm);
253               }
254             }
255           });
256   (void)check_async_error_thread;  // Silence unused variable warning.
257 
258   NcclComm::Lock comm;
259   if (state.ready.HasBeenNotified()) {
260     comm = state.communicators[rank]->Acquire();
261   } else {
262     auto comm_ptr = std::make_unique<NcclComm>();
263     comm = comm_ptr->Acquire();
264     int nranks = clique_key.devices().size();
265     const ncclUniqueId& id = state.unique_id;
266     Status status =
267         XLA_CUDA_STATUS(ncclCommInitRank(comm.get(), nranks, id, rank));
268 
269     // Add the communicator to the all_communicators list.
270     {
271       absl::MutexLock lock(&all_communicators.mu);
272       all_communicators.communicators.push_back(comm_ptr.get());
273     }
274 
275     absl::MutexLock lock(&state.mu);
276     state.status.Update(status);
277     state.communicators[rank] = std::move(comm_ptr);
278 
279     // Wait for all communicators to initialize before allowing any progress.
280     // Otherwise we may get deadlocks, because ncclCommInitRank may allocate,
281     // which may block on the completion of device activity on a peer device,
282     // which may depend on the completion of this collective if we do not have a
283     // barrier to prevent it.
284     auto all_initialized = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state.mu) {
285       return state.communicators.size() == num_local_participants;
286     };
287     state.mu.Await(absl::Condition(&all_initialized));
288     status = state.status;
289     if (!state.ready.HasBeenNotified()) {
290       state.ready.Notify();
291     }
292   }
293   if (!state.status.ok()) {
294     return state.status;
295   }
296   return comm;
297 }
298 }  // namespace gpu
299 }  // namespace xla
300