/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_NCCL_NCCL_MANAGER_H_
#define TENSORFLOW_CORE_NCCL_NCCL_MANAGER_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <vector>

// TODO(rmlarsen): Get rid of this workaround. "gpu_assert" is defined when
// setting EIGEN_USE_THREADS. But when defining EIGEN_USE_THREADS here,
// incAtomic and other CUDA specific symbols are no longer recognized.
#ifndef gpu_assert
#define gpu_assert(x)
#endif

#include "absl/container/flat_hash_map.h"
#if GOOGLE_CUDA
#include "third_party/nccl/nccl.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/rccl/rccl.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#endif
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace tensorflow {

// NCCL manager is used to make the asynchronous communicator calls and to
// manage the per-device streams used for communication.
//
// See nccl_ops.cc for example usage, including description of memory
// management and stream synchronization.
class NcclManager {
 public:
  typedef std::function<void(Status)> DoneCallback;
  NcclManager();
  ~NcclManager();

  static NcclManager* instance();

#if TENSORFLOW_USE_ROCM
  static int instance_count;
#endif

  // Calls `ncclGetUniqueId` and returns the id as a string. The returned
  // value may be shared with other participants on different nodes and passed
  // in to multi-node collective invocations.
  string GenerateCommunicatorKey();
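
  // A multi-node collective requires every node to pass the same communicator
  // key. A minimal sketch of how the key might be distributed, assuming some
  // out-of-band transport (`BroadcastKeyToWorkers` and `ReceiveKeyFromLeader`
  // are hypothetical helpers, not part of this API):
  //
  //   NcclManager* mgr = NcclManager::instance();
  //   string communicator_key;
  //   if (is_lead_node) {
  //     communicator_key = mgr->GenerateCommunicatorKey();
  //     BroadcastKeyToWorkers(communicator_key);  // hypothetical transport
  //   } else {
  //     communicator_key = ReceiveKeyFromLeader();  // hypothetical transport
  //   }
  //
  // Every node then passes the same `communicator_key` in the `Context` of
  // each multi-node collective (see `Context::communicator_key` below).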

  // A participant in a Collective.
  struct Participant {
    Participant(se::StreamExecutor* executor, se::Stream* tensor_stream,
                const DeviceBase::AcceleratorDeviceInfo* info,
                const Tensor* input, Tensor* output, int global_rank,
                DoneCallback done_callback)
        : executor(executor),
          tensor_stream(tensor_stream),
          event_mgr(info->event_mgr),
          gpu_device_id(info->gpu_id),
#if TENSORFLOW_USE_ROCM
          context(static_cast<GPUDeviceContext*>(info->default_context)),
#endif
          input(input),
          output(output),
          global_rank(global_rank),
          done_callback(std::move(done_callback)),
          root(false) {
      DCHECK(executor != nullptr);
      DCHECK(event_mgr != nullptr);
      DCHECK(tensor_stream != nullptr);
    }

    // StreamExecutor for the device. Expected to be live for the process
    // lifetime.
    se::StreamExecutor* const executor = nullptr;

    // `tensor_stream` is the stream that should be waited on to ensure
    // `input`'s data is available on the GPU for the communication stream to
    // access. It is also the stream that will use the produced data;
    // `done_callback` is not called until the next kernel launched on
    // `tensor_stream` would see the data. Owned by the caller, who must keep
    // it live until `done_callback` is called.
    se::Stream* const tensor_stream;

    // EventMgr which polls on executor.
    // Owned by the caller, who must keep it live until `done_callback` is
    // called.
    EventMgr* const event_mgr;

    const int gpu_device_id;

#if TENSORFLOW_USE_ROCM
    GPUDeviceContext* const context;
#endif

    // Owned by the caller, who must keep it live until `done_callback` is
    // called. Is NULL for participants that only receive data.
    const Tensor* input;

    // Owned by the caller, who must keep it live until `done_callback` is
    // called. Is NULL for participants that only send data.
    Tensor* output;

    // Rank across all devices and all nodes.
    // `global_rank` is not required for single-node collectives.
    const int global_rank;

    // The callback which is called at the completion of the NCCL operation.
    // When called, `output` has been set to the result of the operation.
    // (Note: the stream may not yet have been synced.)
    DoneCallback done_callback;

    // True if this is the root of the collective, e.g. the source of a
    // broadcast.
    bool root;
  };

  // Data that provides context for the collective operation, including the
  // operation key, number of participants, and communicator key.
  struct Context {
    Context(const string& collective_key, int num_local_devices,
            int num_global_devices, const string& communicator_key,
            int source_rank)
        : collective_key(collective_key),
          num_local_devices(num_local_devices),
          num_global_devices(num_global_devices),
          communicator_key(communicator_key),
          source_rank(source_rank) {}

    // Unique key for this collective instance.
    const string& collective_key;

    // Devices local to this node.
    int num_local_devices;

    // Devices across all nodes.
    int num_global_devices;

    // In order to use NCCL across nodes, the caller first has to generate a
    // `communicator_key` via the `GenerateCommunicatorKey()` function and
    // share it with all the other nodes. Each node should pass this
    // `communicator_key` to the `NcclManager` functions.
    // `communicator_key` is not required for single-node collectives and can
    // be empty.
    const string& communicator_key;

    // Rank of the broadcast source.
    int source_rank;
  };

  // Adds one participant to an all-reduce.
  void AddToAllReduce(std::unique_ptr<Participant> participant,
                      const Context& context, ncclRedOp_t reduction_op);
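
  // A minimal sketch of a single-node all-reduce across `num_devices` local
  // GPUs. It assumes each device's AcceleratorDeviceInfo (`info`), compute
  // `stream`, and `input`/`output` tensors are already available; the
  // collective key shown is illustrative and must be unique per collective
  // instance. The collective launches once all `num_devices` participants
  // have been added:
  //
  //   for (int i = 0; i < num_devices; ++i) {
  //     auto participant = std::make_unique<NcclManager::Participant>(
  //         stream[i]->parent(), stream[i], info[i], &input[i], &output[i],
  //         /*global_rank=*/-1, [](Status s) { /* handle completion */ });
  //     NcclManager::instance()->AddToAllReduce(
  //         std::move(participant),
  //         {/*collective_key=*/"allreduce_step_0",
  //          /*num_local_devices=*/num_devices,
  //          /*num_global_devices=*/num_devices, /*communicator_key=*/"",
  //          /*source_rank=*/-1},
  //         ncclSum);
  //   }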

  // Adds one participant to an all-gather.
  void AddToAllGather(std::unique_ptr<Participant> participant,
                      const Context& context);

  // AddBroadcastSend and AddBroadcastRecv combine to send data from one
  // sender to all receivers.
  void AddBroadcastSend(std::unique_ptr<Participant> participant,
                        const Context& context);
  void AddBroadcastRecv(std::unique_ptr<Participant> participant,
                        const Context& context);

  // AddReduceSend and AddReduceRecv combine to send data from all senders
  // to one receiver.
  void AddReduceSend(std::unique_ptr<Participant> participant,
                     const Context& context, ncclRedOp_t reduction_op);
  void AddReduceRecv(std::unique_ptr<Participant> participant,
                     const Context& context, ncclRedOp_t reduction_op);

  // Signals that the `Collective` corresponding to `collective_key` is ready
  // to launch across all nodes participating in this multi-node collective
  // operation.
  //
  // This should only be called for multi-node collectives; single-node
  // collectives are implicitly ready when all participants have called an
  // Add* function.
  void SignalMultiNodeReady(const string& collective_key);

  // Aborts all collectives. Once aborted, no further collectives can be
  // launched with this NcclManager.
  void StartAbort(const Status& s);

  // Resets a previously aborted NcclManager, making it available for future
  // collectives.
  void Reset();

 private:
  enum CollectiveType {
    kAllReduce = 1,
    kBroadcast = 2,
    kReduce = 3,
    kAllGather = 4,
  };
  struct Collective;
  struct Communicator;
  struct CommunicatorMember;
  struct NcclStream;

  // Gets the `Communicator` object that will be used to enqueue NCCL kernels
  // for `collective`, and returns it via `communicator`.
  //
  // This may involve creating CUDA streams and NCCL initialization. If an
  // NCCL or CUDA error occurs in the process, this returns an INTERNAL error
  // with the corresponding NCCL/CUDA error string.
  Status GetCommunicator(Collective* collective, Communicator** communicator);

  // Adds a participant device to the local `Collective` instance
  // corresponding to `collective_key`. Launches the `Collective` if it is
  // ready, which it checks by calling `CheckReady()`. Also performs
  // consistency and sanity checks before launching.
  void AddParticipant(std::unique_ptr<Participant> participant,
                      const Context& context, CollectiveType collective_type,
                      ncclRedOp_t reduction_op);

  // If `collective` is ready to run, removes it from the `collectives_` map
  // and returns true. Otherwise returns false.
  // Assumes `collective_key` corresponds to `collective`.
  //
  // A collective is ready to run when all local participants have called an
  // Add* function, and the collective has been signalled globally ready via
  // `SignalMultiNodeReady`.
  bool CheckReady(const string& collective_key, Collective* collective)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Runs `collective`. This call takes ownership of `collective`.
  void RunCollective(Collective* collective);
  void LoopKernelLaunches(NcclStream* stream);

  mutex mu_;

  // Maps a collective key to the collectives currently being assembled or
  // run.
  absl::flat_hash_map<string, Collective*> collectives_ TF_GUARDED_BY(mu_);

  // Maps a device to the communication streams that make up its collective.
  // This is used to share the stream across different communicators that
  // include the same device.
  absl::flat_hash_map<se::StreamExecutor*, std::vector<NcclStream*>>
      device_to_comm_streams_ TF_GUARDED_BY(mu_);

  std::vector<std::unique_ptr<Communicator>> communicators_ TF_GUARDED_BY(mu_);

  Status status_ TF_GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(NcclManager);
};

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_NCCL_NCCL_MANAGER_H_