/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_

#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/int_type.h"

// Common place for all collective thunks to include nccl/rccl headers.
#if TENSORFLOW_USE_ROCM
#include "rocm/include/rccl/rccl.h"
#else
#include "third_party/nccl/nccl.h"
#endif

namespace xla {
namespace gpu {

// Returns the NCCL reduction op corresponding to the given XLA reduction
// kind.
ncclRedOp_t ToNcclReduction(ReductionKind kind);

// Returns the NCCL data type corresponding to the given XLA primitive type,
// together with a multiplier to apply to the element count (e.g. complex
// types are transferred as the underlying real type with twice the count).
StatusOr<std::pair<ncclDataType_t, int>> ToNcclDataTypeAndCountMultiplier(
    PrimitiveType element_type);

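// Example (a sketch; the C64 mapping shown is an assumption, based on complex
// types being transferred as pairs of the underlying real type):
//
//   TF_ASSIGN_OR_RETURN(auto type_and_multiplier,
//                       ToNcclDataTypeAndCountMultiplier(C64));
//   // type_and_multiplier.first  == ncclFloat  (on-the-wire element type)
//   // type_and_multiplier.second == 2          (scale the element count by 2)
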
// Returns true if the NCCL_COMM_ID environment variable is set, i.e. the
// NCCL configuration is provided globally rather than negotiated per clique.
bool IsGlobalNcclConfig();

// Returns true if NCCL_LAUNCH_MODE is set to "PARALLEL".
bool IsNcclLaunchModeParallel();

// Converts an ncclResult_t into a Status; errors are attributed to `expr`
// evaluated at `file`:`line`.
Status ToStatus(ncclResult_t s, const char* file, int64_t line,
                const char* expr);

// Macros to return or warn on CUDA/NCCL errors.  (The same macro works for both
// NCCL and CUDA errors.)
//
// It's tempting to say these macros belong in an XLA header somewhere, but in
// practice we don't do much direct-to-CUDA-API stuff outside of this file.
#define XLA_CUDA_STATUS(expr) \
  xla::gpu::ToStatus(expr, __FILE__, __LINE__, #expr)

#define XLA_CUDA_RETURN_IF_ERROR(expr) \
  do {                                 \
    Status s = XLA_CUDA_STATUS(expr);  \
    if (!s.ok()) {                     \
      return s;                        \
    }                                  \
  } while (0)

#define XLA_CUDA_WARN_IF_ERROR(expr)  \
  do {                                \
    Status s = XLA_CUDA_STATUS(expr); \
    if (!s.ok()) {                    \
      LOG(ERROR) << s.ToString();     \
    }                                 \
  } while (0)

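// Example usage (a sketch; `send_buffer`, `recv_buffer`, `element_count`,
// `comm`, and `stream` are hypothetical placeholders):
//
//   Status RunAllReduce(const void* send_buffer, void* recv_buffer,
//                       size_t element_count, ncclComm_t comm,
//                       cudaStream_t stream) {
//     XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer,
//                                            element_count, ncclFloat,
//                                            ncclSum, comm, stream));
//     return Status::OK();
//   }
//
//   // On cleanup paths, prefer warning over returning:
//   XLA_CUDA_WARN_IF_ERROR(ncclCommDestroy(comm));
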
// Returns the number of participants that are local to this host. If
// `local_devices` is null, all participants are considered local.
size_t GetNumLocalParticipants(
    const std::vector<GlobalDeviceId>& participants,
    const std::vector<GlobalDeviceId>* local_devices);  // may be null

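// Example (a sketch with hypothetical device ids):
//
//   std::vector<GlobalDeviceId> participants = {
//       GlobalDeviceId(0), GlobalDeviceId(1), GlobalDeviceId(4)};
//   std::vector<GlobalDeviceId> local_devices = {GlobalDeviceId(0),
//                                                GlobalDeviceId(1)};
//   size_t n = GetNumLocalParticipants(participants, &local_devices);  // == 2
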
// Returns `unique_id_callback` if it is non-null; otherwise falls back to a
// default callback, which is only valid when all participants are local
// (`is_local`) or the NCCL configuration is global (IsGlobalNcclConfig()).
StatusOr<const NcclUniqueIdCallback*> GetNcclUniqueIdCallback(
    const NcclUniqueIdCallback* unique_id_callback,  // may be null
    bool is_local);

// Represents a type that requires mutually exclusive access.
template <typename T>
class Lockable {
 public:
  // RAII type that will release the exclusive lock when it is destroyed.
  using Lock = std::unique_ptr<T, std::function<void(T*)>>;

  Lockable() = default;
  explicit Lockable(T value) : value_(std::move(value)) {}
  Lockable(const Lockable&) = delete;
  Lockable(Lockable&&) = delete;
  Lockable& operator=(const Lockable&) = delete;
  Lockable& operator=(Lockable&&) = delete;

  Lock Acquire() {
    absl::MutexLock lock(&mutex_);
    mutex_.Await(absl::Condition(&is_unlocked_));
    is_unlocked_ = false;

    return {&value_, [this](T*) {
              absl::MutexLock lock(&mutex_);
              CHECK(!is_unlocked_);
              is_unlocked_ = true;
            }};
  }

 private:
  T value_;
  absl::Mutex mutex_;
  bool is_unlocked_ ABSL_GUARDED_BY(mutex_) = true;
};

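// Example (a minimal sketch):
//
//   Lockable<int> counter(0);
//   {
//     Lockable<int>::Lock lock = counter.Acquire();  // blocks if already held
//     ++*lock;  // exclusive access to the wrapped value
//   }  // `lock` is destroyed here and the value becomes available again
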
// Strongly-typed integer identifying a collective operation.
TF_LIB_GTL_DEFINE_INT_TYPE(OpId, int64_t);

struct NcclComm : public Lockable<ncclComm_t> {
  NcclComm() : Lockable(nullptr) {}
};

// Acquires an exclusive handle on the NCCL communicator for the given clique
// of `participants`, initializing it on first use via `unique_id_callback`.
// The communicator is released when the returned Lock is destroyed.
StatusOr<NcclComm::Lock> AcquireNcclComm(
    RunId run_id, OpId op_id, std::vector<GlobalDeviceId> participants,
    size_t num_local_participants,
    const NcclUniqueIdCallback& unique_id_callback, int rank);

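// Example (a sketch; `participants`, `id_callback`, and `rank` are
// hypothetical values computed elsewhere by the collective thunk):
//
//   TF_ASSIGN_OR_RETURN(
//       NcclComm::Lock comm,
//       AcquireNcclComm(run_id, op_id, std::move(participants),
//                       num_local_participants, *id_callback, rank));
//   XLA_CUDA_RETURN_IF_ERROR(
//       ncclAllReduce(send_buffer, recv_buffer, element_count, ncclFloat,
//                     ncclSum, *comm, stream));
//   // The communicator lease is released when `comm` goes out of scope.
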
}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_