/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/ring_alg.h"

#include <stdlib.h>

#include <atomic>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"

// Set to true for more readable debug-mode log messages.
#define READABLE_KEYS false
// A ring algorithm exchanges chunks of a tensor between devices.  The chunk
// size depends on the number of subdivisions specified in the algorithm.  If
// the user does not specify the number of subdivisions, we may infer it
// dynamically so that the resulting chunk size does not exceed
// kMaxChunkSizeBytes, empirically set at 4 MiB.
constexpr size_t kMaxChunkSizeBytes = (4 * 1024 * 1024);
// kMaxSubdivsPerDeviceDefault gives an upper bound on the number of
// subdivisions dynamically generated when the user does not provide the
// parameter through the collectives API.  A reasonable value would be a small
// multiple of the number of NICs adjacent to each device.
constexpr int kMaxSubdivsPerDeviceDefault = 2;

namespace tensorflow {
namespace {
// Each CollectiveOp implementation is free to define its own
// BufRendezvous key format.  This function produces the key used by
// RingAlg instances.  Note that the exec_key will differentiate between
// different instances; consequently we don't need to further differentiate
// between subclasses of RingAlg.
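// For example (illustrative values), with READABLE_KEYS == false a key looks
// like "<exec_key>:1:5:2", i.e. exec_key:pass:section:source_rank, while with
// READABLE_KEYS == true the same key reads
// "<name>(<exec_key>):pass(1):section(5):srcrank(2)".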
string RingAlgBufKey(const string& name, const string& exec_key, int pass,
                     int section, int source_rank) {
  if (READABLE_KEYS) {
    return strings::StrCat(name, "(", exec_key, "):pass(", pass, "):section(",
                           section, "):srcrank(", source_rank, ")");
  } else {
    // TODO(b/78352018): Try out some kind of denser encoding, e.g. 128 bit
    // hash.
    return strings::StrCat(exec_key, ":", pass, ":", section, ":", source_rank);
  }
}

}  // namespace

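// PCQueue is a simple blocking producer/consumer queue of RingField pointers:
// Enqueue appends an item and wakes one waiter if any are blocked; Dequeue
// blocks until an item is available, then pops it from the front.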
void RingAlg::PCQueue::Enqueue(RingField* rf) {
  mutex_lock l(pcq_mu_);
  deque_.push_back(rf);
  if (waiter_count_ > 0) {
    cv_.notify_one();
  }
}

RingAlg::RingField* RingAlg::PCQueue::Dequeue() {
  mutex_lock l(pcq_mu_);
  if (deque_.empty()) {
    ++waiter_count_;
    while (deque_.empty()) {
      cv_.wait(l);
    }
    --waiter_count_;
  }
  RingField* rf = deque_.front();
  deque_.pop_front();
  return rf;
}

RingAlg::RingAlg(CollectiveType type, const string& name)
    : type_(type),
      name_(name),
      col_ctx_(nullptr),
      col_params_(nullptr),
      done_(nullptr),
      group_size_(-1),
      num_subdivs_(-1) {}

namespace {
Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) {
  // This function generates subdiv_offsets, which is expected to be empty
  // when this is called.
  DCHECK(col_params->instance.impl_details.subdiv_offsets.empty());

  if (col_params->instance.impl_details.max_subdivs_per_device == -1) {
    col_params->instance.impl_details.subdiv_offsets = {0};
    VLOG(2) << "Limiting to 1 subdivision as max_subdivs_per_device == -1";
    return OkStatus();
  }

  if (col_params->instance.shape.num_elements() == 0) {
    return errors::Internal("shape in CollectiveParams should be non-empty");
  }
  const int kAvgDevPerTask =
      col_params->group.group_size / col_params->group.num_tasks;
  const int max_subdivs_per_device =
      (col_params->instance.impl_details.max_subdivs_per_device > 0)
          ? col_params->instance.impl_details.max_subdivs_per_device
          : kMaxSubdivsPerDeviceDefault;
  const int kMaxNumSubdivs = max_subdivs_per_device * kAvgDevPerTask;
  if (kMaxNumSubdivs <= 0) {
    return errors::Internal("Unexpected kMaxNumSubdivs ", kMaxNumSubdivs,
                            " in ",
                            col_params->instance.impl_details.collective_name);
  }
  // NOTE(ayushd): If no subdiv_offsets have been specified, dynamically add
  // as many offsets as needed so that the size of tensor chunks <=
  // kMaxChunkSizeBytes.  Empirically, chunks that are too small or too large
  // lead to worse performance.
  int num_subdivs = 0;
  const size_t tensor_size = col_params->instance.shape.num_elements() *
                             DataTypeSize(col_params->instance.data_type);
  size_t chunk_size;
  do {
    ++num_subdivs;
    int num_chunks = col_params->group.group_size * num_subdivs;
    chunk_size = tensor_size / num_chunks;
    VLOG(2) << "num_subdivs " << num_subdivs << " num_chunks " << num_chunks
            << " chunk_size " << chunk_size;
  } while (chunk_size > kMaxChunkSizeBytes && num_subdivs < kMaxNumSubdivs);
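  // Example (illustrative numbers): a 64 MiB tensor with group_size 4 yields
  // 16 MiB chunks with one subdivision, so the loop above advances to
  // num_subdivs = 4, giving 16 chunks of 4 MiB each (assuming kMaxNumSubdivs
  // permits it).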
  if (num_subdivs <= 0) {
    return errors::Internal("Unexpected num_subdivs ", num_subdivs, " in ",
                            col_params->instance.impl_details.collective_name);
  }

  int subdiv_stride = kAvgDevPerTask / num_subdivs;
  if (subdiv_stride == 0) subdiv_stride = 1;
  col_params->instance.impl_details.subdiv_offsets.reserve(num_subdivs);
  for (int sdi = 0; sdi < num_subdivs; ++sdi) {
    int subdiv_offset = subdiv_stride * sdi;
    if (sdi % 2 == 1) subdiv_offset *= -1;
    col_params->instance.impl_details.subdiv_offsets.push_back(subdiv_offset);
  }
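  // Odd-indexed offsets are negated; a negative offset later reverses the
  // local device ordering (see InitializeCollectiveParams).  For example
  // (illustrative), subdiv_stride == 2 and num_subdivs == 4 produce offsets
  // {0, -2, 4, -6}.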

  if (VLOG_IS_ON(2)) {
    string subdiv_buf;
    for (const int subdiv_offset :
         col_params->instance.impl_details.subdiv_offsets) {
      strings::StrAppend(&subdiv_buf, " ", subdiv_offset);
    }
    VLOG(2) << "Dynamically generated " << num_subdivs
            << " subdiv_offsets:" << subdiv_buf << " tensor_size "
            << tensor_size << " chunk_size " << chunk_size;
  }

  return OkStatus();
}
}  // namespace

Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) {
  const string& device_name =
      col_params->group.members[col_params->default_rank].device.name();
  // Each subdiv permutation is a ring formed by rotating each
  // single-task subsequence of devices by an offset.  This makes most
  // sense when each task has the same number of devices, but we can't
  // depend on that being the case, so we compute something that works
  // regardless.

  // Start by counting the devices in each task.
  // Precondition: device_names must be sorted so that all devices in
  // the same task are adjacent.
  std::vector<int> dev_per_task;
  const string* prior_task_name = &col_params->group.members[0].task;
  int dev_count = 1;
  for (int di = 1; di < col_params->group.group_size; ++di) {
    if (col_params->group.members[di].task != *prior_task_name) {
      dev_per_task.push_back(dev_count);
      dev_count = 1;
      prior_task_name = &col_params->group.members[di].task;
    } else {
      ++dev_count;
    }
  }
  dev_per_task.push_back(dev_count);
  DCHECK_EQ(col_params->group.num_tasks, dev_per_task.size());
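  // For example (illustrative), a group of 8 devices spread evenly across 2
  // tasks produces dev_per_task = {4, 4}.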

  if (col_params->instance.impl_details.subdiv_offsets.empty()) {
    TF_RETURN_IF_ERROR(GenerateSubdivsInCollectiveParams(col_params));
  }

  // Generate a ring permutation for each requested offset.
  VLOG(2) << "Setting up perms for col_params " << col_params
          << " subdiv_permutations "
          << &col_params->instance.impl_details.subdiv_permutations;
  col_params->instance.impl_details.subdiv_permutations.resize(
      col_params->instance.impl_details.subdiv_offsets.size());
  col_params->subdiv_rank.resize(
      col_params->instance.impl_details.subdiv_offsets.size(), -1);
  for (int sdi = 0;
       sdi < col_params->instance.impl_details.subdiv_offsets.size(); ++sdi) {
    std::vector<int>& perm =
        col_params->instance.impl_details.subdiv_permutations[sdi];
    DCHECK_EQ(perm.size(), 0);
    int offset = col_params->instance.impl_details.subdiv_offsets[sdi];
    // A negative subdivision offset is interpreted as follows:
    //  1. Reverse the local device ordering.
    //  2. Begin the subdivision at abs(offset) in the reversed ordering.
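    // For example (illustrative): with dev_per_task[ti] == 4, offset == 1
    // yields the per-task ordering {1, 2, 3, 0}, while offset == -1 yields
    // {2, 1, 0, 3}.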
    bool reverse = false;
    if (offset < 0) {
      offset = abs(offset);
      reverse = true;
    }
    int prior_dev_count = 0;  // sum over prior worker device counts
    for (int ti = 0; ti < col_params->group.num_tasks; ++ti) {
      for (int di = 0; di < dev_per_task[ti]; ++di) {
        int di_offset = (di + offset) % dev_per_task[ti];
        int offset_di =
            reverse ? (dev_per_task[ti] - (di_offset + 1)) : di_offset;
        // Device index in global subdivision permutation.
        int permuted_di = prior_dev_count + offset_di;
        int rank = static_cast<int>(perm.size());
        perm.push_back(permuted_di);
        if (col_params->group.members[permuted_di].device.name() ==
            device_name) {
          DCHECK_EQ(permuted_di, col_params->default_rank);
          col_params->subdiv_rank[sdi] = rank;
        }
      }
      prior_dev_count += dev_per_task[ti];
    }
    DCHECK_EQ(col_params->group.group_size, perm.size());
  }

  VLOG(2) << collective_util::SubdivPermDebugString(*col_params);
  return OkStatus();
}

257 
Status RingAlg::InitializeCollectiveContext(
    std::shared_ptr<CollectiveContext> col_ctx) {
  DCHECK(col_ctx->dev_mgr);
  col_ctx_ = col_ctx;
  col_params_ = col_ctx->col_params.get();
  return collective_util::InitializeDeviceAndLocality(
      col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
      &col_ctx->device_locality);
}

string RingAlg::TensorDebugString(const Tensor& tensor) {
  const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info =
      col_ctx_->op_ctx->device()->tensorflow_accelerator_device_info();
  if (accelerator_device_info) {
    Tensor cpu_tensor(tensor.dtype(), tensor.shape());
    Status st =
        accelerator_device_info->default_context->CopyDeviceTensorToCPUSync(
            &tensor, "" /*tensor_name*/, col_ctx_->device, &cpu_tensor);
    DCHECK(st.ok());
    return cpu_tensor.SummarizeValue(64);
  } else {
    return tensor.SummarizeValue(64);
  }
}

void RingAlg::StartAbort(const Status& s) {
  // In abort mode we stop issuing additional ProvideBuf
  // and ConsumeBuf calls, but we need to wait for all of the
  // outstanding callbacks to be invoked before quitting.
  bool abort_started = false;
  {
    mutex_lock l(status_mu_);
    if (status_.ok()) {
      LOG(ERROR) << "Aborting Ring" << name_ << " with " << s;
      abort_started = true;
      status_.Update(s);
    }
  }
  // If this is the initial entry to abort mode and it's not a cancellation,
  // then invoke StartAbort on the CollectiveExecutor that invoked us.  That
  // should start cancellation on all of the outstanding CollectiveRemoteAccess
  // actions.  If it's a cancellation, all pending send/recv ops should be
  // cancelled as well, so there's no need to abort.
  if (abort_started) {
    if (col_ctx_->op_ctx->cancellation_manager() == nullptr ||
        (!col_ctx_->op_ctx->cancellation_manager()->IsCancelled() &&
         !col_ctx_->op_ctx->cancellation_manager()->IsCancelling())) {
      col_ctx_->col_exec->StartAbort(s);
    }
  }
}

void RingAlg::Finish(bool ok) {
  if (ok) {
    // Recover the output from the adaptor.
    ca_->ConsumeFinalValue(col_ctx_->output);
  }
  Status s;
  {
    mutex_lock l(status_mu_);
    s = status_;
  }
  rfv_.clear();  // Give up Refs on output tensor.
  done_(s);
}

// At the beginning of the algorithm initialize a RingField struct for
// every independent field of the tensor.
void RingAlg::InitRingField(RingField* rf, int chunk_idx, int subdiv_idx,
                            int field_idx) {
  // Note on field indexing: There are group_size_ devices in the
  // instance, implying the same number of chunks per tensor, where a
  // chunk is the unit of data transferred in a time step.  However, if
  // a device can simultaneously send data by 2 or more independent
  // channels we can speed up the transfer by subdividing chunks and
  // processing multiple subdivisions at once.  So the actual number
  // of RingFields is group_size_ * num_subdivs_.
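  // For example (illustrative), with group_size_ == 4 and num_subdivs_ == 2
  // there are 8 RingFields, and field_idx == chunk_idx * 2 + subdiv_idx.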
  DCHECK_EQ(field_idx, (chunk_idx * num_subdivs_) + subdiv_idx);
  rf->chunk_idx = chunk_idx;
  rf->subdiv_idx = subdiv_idx;
  rf->sc_idx = field_idx;
  rf->rank = col_params_->subdiv_rank[subdiv_idx];
  rf->second_pass = false;
  rf->action = RF_INIT;
  // Recv from the device with preceding rank within the subdivision.
  int recv_from_rank = (rf->rank + (group_size_ - 1)) % group_size_;
  int send_to_rank = (rf->rank + 1) % group_size_;
  rf->recv_dev_idx = col_params_->instance.impl_details
                         .subdiv_permutations[subdiv_idx][recv_from_rank];
  int send_dev_idx = col_params_->instance.impl_details
                         .subdiv_permutations[subdiv_idx][send_to_rank];
  rf->recv_is_remote = !col_params_->group.members[rf->recv_dev_idx].is_local;
  rf->send_is_remote = !col_params_->group.members[send_dev_idx].is_local;
  if (ca_->ChunkBytes(rf->sc_idx) > 0) {
    // In pass 0 we skip Recv when rank = chunk_idx
    rf->do_recv = (rf->chunk_idx != rf->rank);
    // In pass 0 we skip Send when rank = chunk_idx-1
    rf->do_send =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
  }
  rf->is_final =
      (rf->rank == ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
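  // In other words, in pass 0 the field with chunk_idx c is not received at
  // subdiv rank c, is not sent onward from rank (c - 1) mod group_size_, and
  // is marked final at that same rank.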
  if (rf->do_send || rf->do_recv) {
    rf->chunk = ca_->ChunkAlias(rf->sc_idx);
  }
  VLOG(2) << this << " InitRingField " << rf->DebugString() << " chunk "
          << ca_->TBounds(rf->chunk);
}

// When a RingField transitions from the first pass to the second, recompute
// the do_send and do_recv values.
void RingAlg::AdvanceToSecondPass(RingField* rf) {
  VLOG(3) << "IncrRingField old value " << rf->DebugString();
  DCHECK(!rf->second_pass);
  rf->second_pass = true;
  rf->action = RF_INIT;
  if (ca_->ChunkBytes(rf->sc_idx) > 0) {
    // In pass 1 the send/no-send boundary moves down 1 place.
    rf->do_recv =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
    rf->do_send =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 2)) % group_size_));
  }
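  // For example (illustrative), with group_size_ == 4 and chunk_idx == 2:
  // in pass 1 rank 1 no longer receives this field and rank 0 no longer
  // sends it; it becomes final at rank 0.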
  rf->is_final =
      (rf->rank == ((rf->chunk_idx + (group_size_ - 2)) % group_size_));
  VLOG(3) << "IncrRingField new value " << rf->DebugString();
}

string RingAlg::RingField::DebugString() const {
  string rv = strings::StrCat("RingField rank=", rank, " chunk_idx=", chunk_idx,
                              " subdiv=", subdiv_idx, " sc_idx=", sc_idx,
                              " action=", action);
  strings::StrAppend(&rv, " pass=", second_pass);
  strings::StrAppend(&rv, " do_send=", do_send, " do_recv=", do_recv,
                     " is_final=", is_final, " recv_is_remote=", recv_is_remote,
                     " recv_dev_idx=", recv_dev_idx, " sc_idx=", sc_idx);
  return rv;
}

void RingAlg::DispatchSend(RingField* rf, const StatusCallback& done) {
  DCHECK(rf->do_send);
  string send_buf_key = RingAlgBufKey(name_, col_ctx_->exec_key,
                                      rf->second_pass, rf->sc_idx, rf->rank);
  VLOG(3) << "DispatchSend rank=" << col_params_->default_rank << " send key "
          << send_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " sc_idx "
          << rf->sc_idx;
  int send_to_rank = (rf->rank + 1) % group_size_;
  int send_to_dev_idx = col_params_->instance.impl_details
                            .subdiv_permutations[rf->subdiv_idx][send_to_rank];
  col_ctx_->col_exec->remote_access()->PostToPeer(
      col_params_->group.members[send_to_dev_idx].device.name(),
      col_params_->group.members[send_to_dev_idx].task, send_buf_key,
      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
      col_ctx_->op_ctx->output_alloc_attr(0), &rf->chunk,
      col_ctx_->device_locality, col_ctx_->op_ctx->cancellation_manager(),
      done);
}

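// The recv key constructed below names the sending peer's rank, i.e.
// (rank - 1) mod group_size_, so it matches the key produced by DispatchSend
// on the upstream device in the ring.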
void RingAlg::DispatchRecv(RingField* rf, const StatusCallback& done) {
  DCHECK(rf->do_recv);
  string recv_buf_key =
      RingAlgBufKey(name_, col_ctx_->exec_key, rf->second_pass, rf->sc_idx,
                    (rf->rank + (group_size_ - 1)) % group_size_);
  VLOG(3) << "DispatchRecv rank=" << col_params_->default_rank << " recv key "
          << recv_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " into "
          << ((col_params_->merge_op != nullptr) ? "tmp_chunk" : "chunk");
  Tensor* dst_tensor = (!rf->second_pass && (col_params_->merge_op != nullptr))
                           ? &rf->tmp_chunk
                           : &rf->chunk;
  col_ctx_->col_exec->remote_access()->RecvFromPeer(
      col_params_->group.members[rf->recv_dev_idx].device.name(),
      col_params_->group.members[rf->recv_dev_idx].task,
      col_params_->group.members[rf->recv_dev_idx].is_local, recv_buf_key,
      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
      col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
      col_ctx_->device_locality, rf->subdiv_idx,
      col_ctx_->op_ctx->cancellation_manager(), done);
}

string RingAlg::FieldState() {
  string s = strings::StrCat(
      "Ring", name_, " ", strings::Hex(reinterpret_cast<uint64>(this)),
      " exec ", col_ctx_->exec_key, " step_id=", col_ctx_->step_id,
      " state of all ", rfv_.size(), " fields:");
  for (int i = 0; i < rfv_.size(); ++i) {
    s.append("\n");
    s.append(rfv_[i].DebugString());
  }
  return s;
}

}  // namespace tensorflow