/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/nccl/collective_communicator.h"

#include "tensorflow/core/framework/cancellation.h"

#if TENSORFLOW_USE_NCCL && (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)

#include "absl/memory/memory.h"
#include "tensorflow/core/nccl/nccl_manager.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {

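// Implementation of NcclCommunicatorInterface that owns an NcclManager and
// forwards collective launches and aborts to it.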
class NcclCommunicator : public NcclCommunicatorInterface {
 public:
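  // Delegates communicator key generation to the owned NcclManager. The key
  // is expected to come back to Enqueue() below as
  // `col_params->group.runtime_details.communicator_key`.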
  string GenerateCommunicatorKey() override {
    return nccl_manager_.GenerateCommunicatorKey();
  }

  void Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
               StatusCallback done) override;

  void StartAbort(const Status& s) override;

 private:
  NcclManager nccl_manager_;
};

namespace {
Status ReductionOp(const string& merge_op, ncclRedOp_t* reduction_op) {
  if (merge_op == "Add") {
    *reduction_op = ncclSum;
    return OkStatus();
  } else if (merge_op == "Mul") {
    *reduction_op = ncclProd;
    return OkStatus();
  } else if (merge_op == "Maximum") {
    *reduction_op = ncclMax;
    return OkStatus();
  } else if (merge_op == "Minimum") {
    *reduction_op = ncclMin;
    return OkStatus();
  } else {
    return errors::Internal(
        "Expected merge_op to be in [Add, Mul, Maximum, Minimum], found ",
        merge_op);
  }
}
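// For example, ReductionOp("Maximum", &op) sets `op` to ncclMax and returns
// OkStatus(), while an unsupported merge op (say, a hypothetical "Mean")
// falls through to the Internal error above.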

string NcclCollectiveKey(const string& exec_key, int step_id) {
  return strings::StrCat(exec_key, ":", step_id);
}
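// For example, NcclCollectiveKey("collective_key_1", 7) returns
// "collective_key_1:7", scoping the key to a single step.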
}  // namespace

std::unique_ptr<NcclCommunicatorInterface> MaybeCreateNcclCommunicator(
    const ConfigProto& config) {
  // Skip creating a NcclCommunicator if there are 0 GPUs configured.
  const auto& device_count = config.device_count();
  auto item = device_count.find("GPU");
  if (item != device_count.end() && item->second == 0) {
    return nullptr;
  }
  return absl::make_unique<NcclCommunicator>();
}
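// Illustrative usage (a sketch, not part of this translation unit): pinning
// the GPU device count to zero in the session config makes this factory
// return nullptr, i.e. no NCCL communicator is created.
//
//   ConfigProto config;
//   (*config.mutable_device_count())["GPU"] = 0;
//   std::unique_ptr<NcclCommunicatorInterface> comm =
//       MaybeCreateNcclCommunicator(config);  // comm == nullptr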

void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
                               StatusCallback done) {
  const CollectiveParams* col_params = col_ctx->col_params.get();
  const int num_global_devices = col_params->group.group_size;
  const int num_local_devices = col_params->group.num_devices_per_task.at(
      col_params->group.members[col_params->default_rank].task);
  const string nccl_collective_key =
      NcclCollectiveKey(col_ctx->exec_key, col_ctx->step_id);
  auto* compute_stream = col_ctx->op_ctx->op_device_context()->stream();
  auto* gpu_info =
      col_ctx->op_ctx->device()->tensorflow_accelerator_device_info();
  auto participant = absl::make_unique<NcclManager::Participant>(
      compute_stream->parent(), compute_stream, gpu_info, col_ctx->input,
      col_ctx->output, col_ctx->col_params->default_rank,
      /*done_callback=*/nullptr);
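  // Wire up completion and cancellation: if the op is cancellable, register a
  // callback that aborts and resets the NcclManager on cancellation, and have
  // the participant's done callback deregister it (without blocking) before
  // reporting the final status.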
  CancellationManager* cancel_mgr = col_ctx->op_ctx->cancellation_manager();
  if (cancel_mgr == nullptr) {
    participant->done_callback = std::move(done);
  } else {
    CancellationToken cancel_token = cancel_mgr->get_cancellation_token();
    bool already_cancelled =
        !cancel_mgr->RegisterCallback(cancel_token, [this]() {
          nccl_manager_.StartAbort(errors::Cancelled("op cancelled"));
          nccl_manager_.Reset();
        });
    if (already_cancelled) {
      done(errors::Cancelled("op cancelled"));
      return;
    }
    participant->done_callback = [cancel_mgr, cancel_token,
                                  done = std::move(done)](const Status& s) {
      // Do not block on deregistration since this can be invoked by
      // NcclManager::StartAbort() in the cancellation callback.
      cancel_mgr->TryDeregisterCallback(cancel_token);
      done(s);
    };
  }
  NcclManager::Context context(
      nccl_collective_key, num_local_devices, num_global_devices,
      col_params->group.runtime_details.communicator_key,
      col_params->source_rank);
  VLOG(1) << "NcclCommunicator::Enqueue type " << col_params->instance.type
          << " num_tasks " << col_params->group.num_tasks << " current task "
          << col_params->group.members[col_params->default_rank].task
          << " num local devices " << num_local_devices
          << " num global devices " << num_global_devices << " device "
          << col_ctx->device_name << " instance "
          << col_params->instance.instance_key;
  // `AddTo*` performs consistency checks for the NCCL call and enqueues the
  // `Participant` struct locally.  When all local participants with this
  // `nccl_collective_key` have called `AddToAllReduce` and
  // `SignalMultiNodeReady`, all devices at this worker are ready to process
  // this NCCL op.
  //
  // The `NcclManager` uses a dedicated CUDA stream for NCCL kernels.  At this
  // point, it synchronizes the NCCL stream with the compute stream, and then
  // enqueues the NCCL kernel on the NCCL stream.
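  // Hand the participant off to the NcclManager call that matches the
  // collective type; every branch consumes `participant` and reports errors
  // through its done_callback.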
  switch (col_params->instance.type) {
    case REDUCTION_COLLECTIVE: {
      ncclRedOp_t reduction_op;
      Status s =
          ReductionOp(col_params->merge_op->type_string(), &reduction_op);
      if (!s.ok()) {
        participant->done_callback(s);
        return;
      }
      nccl_manager_.AddToAllReduce(std::move(participant), context,
                                   reduction_op);
      break;
    }
    case GATHER_COLLECTIVE: {
      nccl_manager_.AddToAllGather(std::move(participant), context);
      break;
    }
    case BROADCAST_COLLECTIVE: {
      if (col_params->is_source) {
        nccl_manager_.AddBroadcastSend(std::move(participant), context);
      } else {
        nccl_manager_.AddBroadcastRecv(std::move(participant), context);
      }
      break;
    }
    default: {
      participant->done_callback(errors::Internal("Unexpected CollectiveType ",
                                                  col_params->instance.type));
      return;
    }
  }
  // NOTE(ayushd): We need to synchronize NCCL launches across nodes to prevent
  // deadlocks.  In the current implementation, we define a deterministic
  // sequential launch order between potentially concurrent collective instances
  // by introducing control information during static graph analysis in
  // graph/collective_order.cc.  This can take the form of either explicit
  // control edges or a `wait_for` attribute on the collective op.
  //
  // The other end of the design spectrum would have a distinguished node
  // dynamically signal the next collective to launch to all other participants.
  // This has a higher degree of runtime coordination, but it may be able to
  // achieve better performance if the (arbitrary) static execution order
  // assigned in the first approach turns out not to be good from a scheduling
  // perspective.  For example, consider a graph in which c1, c2, and c3 are
  // three concurrent collective instances, and the static ordering assigns
  // c1 -> c2 -> c3.  In practice, it could turn out that c3 is always ready to
  // execute before c1 or c2.
  {
    // `WaitForDependencies` may block if the collective instances on which this
    // op depends have not yet launched.  When this function returns, this op is
    // ready to go.
    profiler::TraceMe activity("WaitForDependencies",
                               profiler::TraceMeLevel::kInfo);
    col_ctx->col_exec->WaitForDependencies(*col_params);
    nccl_manager_.SignalMultiNodeReady(nccl_collective_key);
  }
  {
    // When all devices at this worker have called `SignalMultiNodeReady`, the
    // `NcclManager` will enqueue the NCCL kernel on the NCCL stream.  Thus the
    // implementation of `UnblockDependencies` keeps track of the number of
    // devices that have launched.
    profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
    col_ctx->col_exec->UnblockDependencies(*col_params);
  }
}

void NcclCommunicator::StartAbort(const Status& s) {
  nccl_manager_.StartAbort(s);
}

}  // namespace tensorflow

#else
namespace tensorflow {
std::unique_ptr<NcclCommunicatorInterface> MaybeCreateNcclCommunicator(
    const ConfigProto& config) {
  return nullptr;
}
}  // namespace tensorflow
#endif  // TENSORFLOW_USE_NCCL && (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)