/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_

#include <cstdint>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/int_type.h"
#include "tensorflow/core/platform/logging.h"

// Common place for all collective thunks to include nccl/rccl headers.
#if TENSORFLOW_USE_ROCM
#include "rocm/include/rccl/rccl.h"
#else
#include "third_party/nccl/nccl.h"
#endif

namespace xla {
namespace gpu {

ncclRedOp_t ToNcclReduction(ReductionKind kind);
StatusOr<std::pair<ncclDataType_t, int>> ToNcclDataTypeAndCountMultiplier(
    PrimitiveType element_type);

bool IsGlobalNcclConfig();
bool IsNcclLaunchModeParallel();

Status ToStatus(ncclResult_t s, const char* file, int64_t line,
                const char* expr);

// Macros to return or warn on CUDA/NCCL errors. (The same macro works for both
// NCCL and CUDA errors.)
//
// It's tempting to say these macros belong in an XLA header somewhere, but in
// practice we don't do much direct-to-CUDA-API stuff outside of this file.
#define XLA_CUDA_STATUS(expr) \
  xla::gpu::ToStatus(expr, __FILE__, __LINE__, #expr)

#define XLA_CUDA_RETURN_IF_ERROR(expr) \
  do {                                 \
    Status s = XLA_CUDA_STATUS(expr);  \
    if (!s.ok()) {                     \
      return s;                        \
    }                                  \
  } while (0)

#define XLA_CUDA_WARN_IF_ERROR(expr)  \
  do {                                \
    Status s = XLA_CUDA_STATUS(expr); \
    if (!s.ok()) {                    \
      LOG(ERROR) << s.ToString();     \
    }                                 \
  } while (0)

size_t GetNumLocalParticipants(
    const std::vector<GlobalDeviceId>& participants,
    const std::vector<GlobalDeviceId>* local_devices);  // may be null

StatusOr<const NcclUniqueIdCallback*> GetNcclUniqueIdCallback(
    const NcclUniqueIdCallback* unique_id_callback,  // may be null
    bool is_local);

// Represents a type that requires mutually exclusive access.
template <typename T>
class Lockable {
 public:
  // RAII type that will release the exclusive lock when it is destroyed.
  using Lock = std::unique_ptr<T, std::function<void(T*)>>;

  Lockable() = default;
  explicit Lockable(T value) : value_(std::move(value)) {}
  Lockable(const Lockable&) = delete;
  Lockable(Lockable&&) = delete;
  Lockable& operator=(const Lockable&) = delete;
  Lockable& operator=(Lockable&&) = delete;

  Lock Acquire() {
    absl::MutexLock lock(&mutex_);
    mutex_.Await(absl::Condition(&is_unlocked_));
    is_unlocked_ = false;

    return {&value_, [this](T*) {
              absl::MutexLock lock(&mutex_);
              CHECK(!is_unlocked_);
              is_unlocked_ = true;
            }};
  }

 private:
  T value_;
  absl::Mutex mutex_;
  bool is_unlocked_ ABSL_GUARDED_BY(mutex_) = true;
};

TF_LIB_GTL_DEFINE_INT_TYPE(OpId, int64_t);

struct NcclComm : public Lockable<ncclComm_t> {
  NcclComm() : Lockable(nullptr) {}
};

StatusOr<NcclComm::Lock> AcquireNcclComm(
    RunId run_id, OpId op_id, std::vector<GlobalDeviceId> participants,
    size_t num_local_participants,
    const NcclUniqueIdCallback& unique_id_callback, int rank);
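
// A minimal usage sketch (illustrative only, not part of this header's API):
// a caller acquires exclusive access to the communicator via the RAII Lock,
// then issues NCCL calls through the error macros above. Names such as
// `run_id`, `op_id`, `participants`, `num_local_participants`, `id_callback`,
// `rank`, the buffers, `count`, and `stream` are assumed to be supplied by
// the calling collective thunk.
//
//   TF_ASSIGN_OR_RETURN(
//       NcclComm::Lock comm,
//       AcquireNcclComm(run_id, op_id, participants, num_local_participants,
//                       id_callback, rank));
//   XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer, count,
//                                          ncclFloat, ncclSum, *comm,
//                                          stream));
//   // `comm` releases the communicator when it goes out of scope, allowing
//   // the next Acquire() to proceed.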

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_UTILS_H_