xref: /aosp_15_r20/external/tensorflow/tensorflow/core/nccl/nccl_manager.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_NCCL_NCCL_MANAGER_H_
#define TENSORFLOW_CORE_NCCL_NCCL_MANAGER_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <vector>

// TODO(rmlarsen): Get rid of this workaround. "gpu_assert" is defined when
// setting EIGEN_USE_THREADS. But when defining EIGEN_USE_THREADS here,
// incAtomic and other CUDA specific symbols are no longer recognized.
#ifndef gpu_assert
#define gpu_assert(x)
#endif

#include "absl/container/flat_hash_map.h"
#if GOOGLE_CUDA
#include "third_party/nccl/nccl.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/rccl/rccl.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#endif
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace tensorflow {

44 // NCCL manager is used to make the asynchronous communicator calls and to
45 // manage the per-device streams used for communication.
46 //
47 // See nccl_ops.cc for example usage, including description of memory
48 // management and stream synchronization.
49 class NcclManager {
50  public:
51   typedef std::function<void(Status)> DoneCallback;
52   NcclManager();
53   ~NcclManager();
54 
55   static NcclManager* instance();
56 
57 #if TENSORFLOW_USE_ROCM
58   static int instance_count;
59 #endif
60 
61   // Calls `ncclGetUniqueId` and returns the id as a string.  The returned value
62   // may be shared with other participants on different nodes and passed in to
63   // multi-node collective invocations.
64   string GenerateCommunicatorKey();
65 
66   // A participant in a Collective.
67   struct Participant {
ParticipantParticipant68     Participant(se::StreamExecutor* executor, se::Stream* tensor_stream,
69                 const DeviceBase::AcceleratorDeviceInfo* info,
70                 const Tensor* input, Tensor* output, int global_rank,
71                 DoneCallback done_callback)
72         : executor(executor),
73           tensor_stream(tensor_stream),
74           event_mgr(info->event_mgr),
75           gpu_device_id(info->gpu_id),
76 #if TENSORFLOW_USE_ROCM
77           context(static_cast<GPUDeviceContext*>(info->default_context)),
78 #endif
79           input(input),
80           output(output),
81           global_rank(global_rank),
82           done_callback(std::move(done_callback)),
83           root(false) {
84       DCHECK(executor != nullptr);
85       DCHECK(event_mgr != nullptr);
86       DCHECK(tensor_stream != nullptr);
87     }
88 
89     // StreamExecutor for the device. Expected to be live for process lifetime.
90     se::StreamExecutor* const executor = nullptr;
91 
92     // `tensor_stream` is the stream that should be waited on to ensure
93     // `input`'s data is available on the GPU for the communication stream to
94     // access. It is also the stream that will use the produced data;
95     // `done_callback` is not called until the next kernel launched on `stream`
96     // would see the data. Owned by the caller, who must keep it live until
97     // `done_callback` is called.
98     se::Stream* const tensor_stream;
99 
100     // EventMgr which polls on executor.
101     // Owned by the caller, who must keep it live until `done_callback` is
102     // called.
103     EventMgr* const event_mgr;
104 
105     const int gpu_device_id;
106 
107 #if TENSORFLOW_USE_ROCM
108     GPUDeviceContext* const context;
109 #endif
110 
111     // Owned by the caller, who must keep it live until `done_callback` is
112     // called. Is NULL for participants that only receive data.
113     const Tensor* input;
114 
115     // Owned by the caller, who must keep it live until `done_callback` is
116     // called. Is NULL for participants that only send data.
117     Tensor* output;
118 
119     // Rank across all devices and all nodes.
120     // `global_rank` is not required for single-node collectives.
121     const int global_rank;
122 
123     // The callback which is called at the completion of the NCCL operation.
124     // When called, `output` has been set to the result of the operation. (note:
125     // the stream may not yet have been synced)
126     DoneCallback done_callback;
127 
128     // True if this is the root of the collective, e.g. source of broadcast.
129     bool root;
130   };
131 
132   // Data that provides context for the collective operation, including the
133   // operation key, number of participants, and communicator key.
134   struct Context {
ContextContext135     Context(const string& collective_key, int num_local_devices,
136             int num_global_devices, const string& communicator_key,
137             int source_rank)
138         : collective_key(collective_key),
139           num_local_devices(num_local_devices),
140           num_global_devices(num_global_devices),
141           communicator_key(communicator_key),
142           source_rank(source_rank) {}
143 
144     // Unique key for this collective instance
145     const string& collective_key;
146 
147     // Devices local to this node
148     int num_local_devices;
149 
150     // Devices across all nodes
151     int num_global_devices;
152 
153     // In order to use NCCL across nodes, the callee first has to generate a
154     // `communicator_key` via `GenerateCommunicatorKey()` function and share
155     // this with all the other nodes.  Each node should pass in this
156     // `communicator_key` to the `NcclManager` functions.
157     // `communicator_key` is not required for single-node collectives and can be
158     // empty.
159     const string& communicator_key;
160 
161     // Rank of broadcast source.
162     int source_rank;
163   };
164 
165   // Adds one participant to an all-reduce.
166   void AddToAllReduce(std::unique_ptr<Participant> participant,
167                       const Context& context, ncclRedOp_t reduction_op);
168 
169   // Adds one participant to an all-gather.
170   void AddToAllGather(std::unique_ptr<Participant> participant,
171                       const Context& context);
172 
173   // AddBroadcastSend and AddBroadcastRecv combine to send data from one sender
174   // to all receivers.
175   void AddBroadcastSend(std::unique_ptr<Participant> participant,
176                         const Context& context);
177   void AddBroadcastRecv(std::unique_ptr<Participant> participant,
178                         const Context& context);
179 
180   // AddReduceSend and AddReduceRecv combine to send data from all senders
181   // to one receiver.
182   void AddReduceSend(std::unique_ptr<Participant> participant,
183                      const Context& context, ncclRedOp_t reduction_op);
184   void AddReduceRecv(std::unique_ptr<Participant> participant,
185                      const Context& context, ncclRedOp_t reduction_op);
186 
187   // Signals that the `Collective` corresponding to `key` is ready to launch
188   // across all nodes participating in this multi-node collective operation.
189   //
190   // This should only be called for multi-node collectives; single-node
191   // collectives are implicitly ready when all participants have called Add*
192   // function.
193   void SignalMultiNodeReady(const string& collective_key);
194 
195   // Aborts all collectives. After abortion, no further collectives can be
196   // launched with this NcclManager.
197   void StartAbort(const Status& s);
198 
199   // Resets a previously aborted NcclManager, making it available for future
200   // collectives.
201   void Reset();
202 
203  private:
204   enum CollectiveType {
205     kAllReduce = 1,
206     kBroadcast = 2,
207     kReduce = 3,
208     kAllGather = 4,
209   };
210   struct Collective;
211   struct Communicator;
212   struct CommunicatorMember;
213   struct NcclStream;
214 
215   // Gets the `Communicator` object that will be used to enqueue NCCL kernels
216   // for `collective`, and returns it via `communicator`.
217   //
218   // This may involve creating CUDA streams and NCCL initialization.  If a NCCL
219   // or CUDA error occurs in the process, this returns an INTERNAL error with
220   // the corresponding NCCL/CUDA error string.
221   Status GetCommunicator(Collective* collective, Communicator** communicator);
222 
223   // Adds a participant device to the local `Collective` instance corresponding
224   // to `collective_key`.  Launches the `Collective` if it is ready, which it
225   // checks by calling `CheckReady()`.  Also performs consistency and sanity
226   // checks before launching.
227   void AddParticipant(std::unique_ptr<Participant> participant,
228                       const Context& context, CollectiveType collective_type,
229                       ncclRedOp_t reduction_op);
230 
231   // If `collective` is ready to run, removes it from the `collectives_` map and
232   // returns true.  Otherwise returns false.
233   // Assumes `collective_key` corresponds to `collective`.
234   //
235   // A collective is ready to run when all local participants have called Add*
236   // function, and the collective is signalled globally ready via
237   // `SetMultiNodeReady`.
238   bool CheckReady(const string& collective_key, Collective* collective)
239       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
240 
241   // Run <collective>.  This calls takes ownership of <collective>.
242   void RunCollective(Collective* collective);
243   void LoopKernelLaunches(NcclStream* stream);
244 
245   mutex mu_;
246 
247   // Maps key to collectives currently being assembled or run.
248   absl::flat_hash_map<string, Collective*> collectives_ TF_GUARDED_BY(mu_);
249 
250   // Maps a device to the communication streams that make up its collective.
251   // This is used to share the stream across different communicators that
252   // include the same device.
253   absl::flat_hash_map<se::StreamExecutor*, std::vector<NcclStream*>>
254       device_to_comm_streams_ TF_GUARDED_BY(mu_);
255 
256   std::vector<std::unique_ptr<Communicator>> communicators_ TF_GUARDED_BY(mu_);
257 
258   Status status_ TF_GUARDED_BY(mu_);
259 
260   TF_DISALLOW_COPY_AND_ASSIGN(NcclManager);
261 };

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_NCCL_NCCL_MANAGER_H_