#include <c10/util/Exception.h>
#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>

#ifdef USE_C10D_GLOO

#include <torch/csrc/distributed/c10d/GlooDeviceFactory.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <chrono>
#include <cstring>
#include <exception>

#ifdef _WIN32
#include <gloo/common/win.h>
#include <winsock2.h>
#include <ws2tcpip.h>
#else
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>
#endif
#include <sys/types.h>

#include <type_traits>
#include <utility>

#include <gloo/allgather.h>
#include <gloo/allgatherv.h>
#include <gloo/allreduce.h>
#include <gloo/alltoall.h>
#include <gloo/alltoallv.h>
#include <gloo/barrier.h>
#include <gloo/broadcast.h>
#include <gloo/gather.h>
#include <gloo/reduce.h>
#include <gloo/scatter.h>

#include <ATen/ThreadLocalState.h>
#include <ATen/native/SparseTensorUtils.h>

#include <c10/util/StringUtil.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/irange.h>
#include <gloo/config.h>
#include <gloo/rendezvous/context.h>
#include <gloo/rendezvous/prefix_store.h>

#ifdef _WIN32
#define GENERATE_ALL_TYPES(type, func, ...)      \
  switch (type) {                                \
    case ::at::ScalarType::Float:                \
      func<float>(__VA_ARGS__);                  \
      break;                                     \
    case ::at::ScalarType::Double:               \
      func<double>(__VA_ARGS__);                 \
      break;                                     \
    case ::at::ScalarType::Half:                 \
      func<gloo::float16>(__VA_ARGS__);          \
      break;                                     \
    case ::at::ScalarType::BFloat16:             \
      func<c10::BFloat16>(__VA_ARGS__);          \
      break;                                     \
    case ::at::ScalarType::Char:                 \
      func<int8_t>(__VA_ARGS__);                 \
      break;                                     \
    case ::at::ScalarType::Byte:                 \
    case ::at::ScalarType::Bool:                 \
      func<uint8_t>(__VA_ARGS__);                \
      break;                                     \
    case ::at::ScalarType::Int:                  \
      func<int32_t>(__VA_ARGS__);                \
      break;                                     \
    case ::at::ScalarType::Long:                 \
      func<int64_t>(__VA_ARGS__);                \
      break;                                     \
    default:                                     \
      TORCH_CHECK(false, "Invalid scalar type"); \
  }

#define HOST_NAME_MAX 256
#else
#define GENERATE_ALL_TYPES(type, func, args...)  \
  switch (type) {                                \
    case ::at::ScalarType::Float:                \
      func<float>(args);                         \
      break;                                     \
    case ::at::ScalarType::Double:               \
      func<double>(args);                        \
      break;                                     \
    case ::at::ScalarType::Half:                 \
      func<gloo::float16>(args);                 \
      break;                                     \
    case ::at::ScalarType::BFloat16:             \
      func<c10::BFloat16>(args);                 \
      break;                                     \
    case ::at::ScalarType::Char:                 \
      func<int8_t>(args);                        \
      break;                                     \
    case ::at::ScalarType::Byte:                 \
    case ::at::ScalarType::Bool:                 \
      func<uint8_t>(args);                       \
      break;                                     \
    case ::at::ScalarType::Int:                  \
      func<int32_t>(args);                       \
      break;                                     \
    case ::at::ScalarType::Long:                 \
      func<int64_t>(args);                       \
      break;                                     \
    default:                                     \
      TORCH_CHECK(false, "Invalid scalar type"); \
  }
#endif
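// Illustrative note: GENERATE_ALL_TYPES expands to a switch over the runtime
// ScalarType and calls the matching template instantiation of `func`. For
// example, for a Float tensor,
//   GENERATE_ALL_TYPES(tensor.scalar_type(), setOutput, opts, tensor);
// ends up calling setOutput<float>(opts, tensor).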

namespace c10d {

namespace {

using steady_clock_time_point =
    std::chrono::time_point<std::chrono::steady_clock>;

std::chrono::milliseconds getRemainingTime(
    steady_clock_time_point startTime,
    const std::chrono::milliseconds& timeout,
    bool waitAllRanks) {
  if (waitAllRanks) {
    // See Note in monitoredBarrier
    return timeout;
  }
  auto elapsedTime = std::chrono::steady_clock::now() - startTime;
  auto remainingMillis = timeout -
      std::chrono::duration_cast<std::chrono::milliseconds>(elapsedTime);

  // If no more remaining time, return -1 to indicate to caller.
  if (remainingMillis.count() <= 0) {
    return std::chrono::milliseconds(-1);
  }

  return remainingMillis;
}

// Emits a LOG(ERROR) and throws using TORCH_CHECK with the given messages.
void logAndThrow(
    const std::string& logMessage,
    const std::string& errorMessage) {
  LOG(ERROR) << logMessage;
  TORCH_CHECK(false, errorMessage);
}

// For monitoredBarrier, checks the remaining time left to finish processing
// ranks and throws an error on timeout.
void checkRemainingTime(
    const std::chrono::milliseconds& monitoredBarrierTimeout,
    const std::chrono::milliseconds& remainingTime,
    const std::vector<int>& processedRanks,
    int currentRank) {
  const std::string kNoRemainingTimeError = c10::str(
      "Rank ",
      currentRank,
      " timed out in monitoredBarrier after ",
      monitoredBarrierTimeout.count(),
      " ms.");
  if (remainingTime.count() < 0) {
    std::string rankInfo;
    if (!processedRanks.empty()) {
      rankInfo = c10::str(
          "Successfully processed ranks: ", c10::Join(", ", processedRanks));
    } else {
      rankInfo = "No ranks successfully processed in monitoredBarrier.";
    }
    auto error = c10::str(kNoRemainingTimeError, "\n", rankInfo);
    logAndThrow(error, error);
  }
}

typedef void (*ReduceFunc)(void*, const void*, const void*, size_t);

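// Mapping from ReduceOp to a gloo reduction function. Two overloads are
// provided: the non-integral overload below rejects the bitwise ops
// (BAND/BOR/BXOR), while the integral overload further down maps them to
// the band/bor/bxor helpers defined in between.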
template <typename T, std::enable_if_t<!std::is_integral_v<T>, int> = 0>
ReduceFunc toFunction(const ReduceOp& r) {
  switch (r) {
    case ReduceOp::SUM:
      return ReduceFunc(&::gloo::sum<T>);
    case ReduceOp::PRODUCT:
      return ReduceFunc(&::gloo::product<T>);
    case ReduceOp::MIN:
      return ReduceFunc(&::gloo::min<T>);
    case ReduceOp::MAX:
      return ReduceFunc(&::gloo::max<T>);
    case ReduceOp::BAND:
      TORCH_CHECK(false, "Cannot use ReduceOp.BAND with non-integral dtype");
      break;
    case ReduceOp::BOR:
      TORCH_CHECK(false, "Cannot use ReduceOp.BOR with non-integral dtype");
      break;
    case ReduceOp::BXOR:
      TORCH_CHECK(false, "Cannot use ReduceOp.BXOR with non-integral dtype");
      break;
    case ReduceOp::AVG:
      TORCH_CHECK(false, "Cannot use ReduceOp.AVG with Gloo");
      break;
    case ReduceOp::PREMUL_SUM:
      TORCH_CHECK(false, "Cannot use ReduceOp.PREMUL_SUM with Gloo");
      break;
    case ReduceOp::UNUSED:
    default:
      break;
  }

  TORCH_CHECK(false, "Unhandled ReduceOp");
}

// Bitwise AND with SFINAE guard for integral types.
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
void band(void* c, const void* a, const void* b, size_t n) {
  auto tc = static_cast<T*>(c);
  auto ta = static_cast<const T*>(a);
  auto tb = static_cast<const T*>(b);
  for (const auto i : c10::irange(n)) {
    tc[i] = ta[i] & tb[i];
  }
}

// Bitwise OR with SFINAE guard for integral types.
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
void bor(void* c, const void* a, const void* b, size_t n) {
  auto tc = static_cast<T*>(c);
  auto ta = static_cast<const T*>(a);
  auto tb = static_cast<const T*>(b);
  for (const auto i : c10::irange(n)) {
    tc[i] = ta[i] | tb[i];
  }
}

// Bitwise XOR with SFINAE guard for integral types.
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
void bxor(void* c, const void* a, const void* b, size_t n) {
  auto tc = static_cast<T*>(c);
  auto ta = static_cast<const T*>(a);
  auto tb = static_cast<const T*>(b);
  for (const auto i : c10::irange(n)) {
    tc[i] = ta[i] ^ tb[i];
  }
}

template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
ReduceFunc toFunction(const ReduceOp& r) {
  switch (r) {
    case ReduceOp::SUM:
      return ReduceFunc(&::gloo::sum<T>);
    case ReduceOp::PRODUCT:
      return ReduceFunc(&::gloo::product<T>);
    case ReduceOp::MIN:
      return ReduceFunc(&::gloo::min<T>);
    case ReduceOp::MAX:
      return ReduceFunc(&::gloo::max<T>);
    case ReduceOp::BAND:
      return ReduceFunc(&band<T>);
    case ReduceOp::BOR:
      return ReduceFunc(&bor<T>);
    case ReduceOp::BXOR:
      return ReduceFunc(&bxor<T>);
    case ReduceOp::AVG:
      TORCH_CHECK(false, "Cannot use ReduceOp.AVG with Gloo");
      break;
    case ReduceOp::PREMUL_SUM:
      TORCH_CHECK(false, "Cannot use ReduceOp.PREMUL_SUM with Gloo");
      break;
    case ReduceOp::UNUSED:
    default:
      break;
  }

  TORCH_CHECK(false, "Unhandled ReduceOp");
}

template <typename T, typename O>
void setInputs(O& opts, std::vector<at::Tensor>& tensors) {
  opts.setInputs(getDataPointers<T>(tensors), tensors[0].numel());
}

template <typename T, typename O>
void setInput(O& opts, at::Tensor& tensor) {
  opts.setInput(getDataPointer<T>(tensor), tensor.numel());
}

template <typename T, typename O>
void setInput(O& opts, at::Tensor& tensor, std::vector<size_t>& counts) {
  opts.setInput(getDataPointer<T>(tensor), counts);
}

template <typename T, typename O>
void setInput(O& opts, at::Tensor& tensor, std::vector<int64_t>& counts) {
  opts.setInput(getDataPointer<T>(tensor), counts);
}

template <typename T, typename O>
void setOutputs(O& opts, std::vector<at::Tensor>& tensors) {
  opts.setOutputs(getDataPointers<T>(tensors), tensors[0].numel());
}

template <typename T, typename O>
void setOutput(O& opts, at::Tensor& tensor) {
  opts.setOutput(getDataPointer<T>(tensor), tensor.numel());
}

template <typename T, typename O>
void setOutput(O& opts, at::Tensor& tensor, std::vector<size_t>& counts) {
  opts.setOutput(getDataPointer<T>(tensor), counts);
}

template <typename T, typename O>
void setOutput(O& opts, at::Tensor& tensor, std::vector<int64_t>& counts) {
  opts.setOutput(getDataPointer<T>(tensor), counts);
}

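// pinnedLike returns a CPU tensor backed by the CUDA pinned-memory allocator
// with the same sizes and strides as the input, so the host<->device copies
// issued by the CUDA work classes below can run with non_blocking=true on
// the side streams.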
at::Tensor pinnedLike(at::Tensor& tensor) {
  auto* allocator = at::detail::getCUDAHooks().getPinnedMemoryAllocator();
  auto storage = c10::Storage(
      c10::Storage::use_byte_size_t(),
      static_cast<int64_t>(at::detail::computeStorageNbytes(
          tensor.sizes(), tensor.strides(), tensor.dtype().itemsize())),
      allocator,
      /*resizable=*/false);
  return at::empty({0}, tensor.options().device(at::kCPU))
      .set_(storage, 0, tensor.sizes(), tensor.strides());
}

// This function initializes a vector of CUDA streams, one for every
// tensor in the input tensor vector, and ensures that these streams are
// synchronized with the current default streams. This is needed so
// that new work on the new streams is serialized w.r.t. all operations
// on the tensors.
void initializeStreamsEvents(
    const std::vector<at::Tensor>& tensors,
    std::vector<c10::Stream>& streams,
    std::vector<c10::Event>& events) {
  streams.reserve(tensors.size());
  events.reserve(tensors.size());
  for (const auto i : c10::irange(tensors.size())) {
    c10::Device device = tensors[i].device();
    c10::impl::VirtualGuardImpl impl(device.type());
    // Record event on current stream
    events.emplace_back(device.type());
    events[i].record(impl.getStream(device));
    // Get a non-default stream to execute asynchronous CUDA operations
    // on for this device. This ensures that the default stream used
    // by the caller is not occupied by c10d related operations.
    streams.push_back(
        impl.getStreamFromGlobalPool(device, /*isHighPriority=*/true));
    // Ensure the new stream is synchronized with the current stream.
    events[i].block(streams[i]);

    // `tensors` are created on a different stream. Hence, they must record
    // new streams in this Work to prevent being freed before the Work finishes.
    if (tensors[i].is_sparse()) {
      if (tensors[i].is_coalesced()) {
        impl.recordDataPtrOnStream(
            tensors[i].indices().storage().data_ptr(), streams[i]);
        impl.recordDataPtrOnStream(
            tensors[i].values().storage().data_ptr(), streams[i]);
      } else {
        // We will need to coalesce first, which means new tensors will
        // be allocated on the streams we just allocated, and there
        // is no need to record them separately.
      }
    } else {
      impl.recordDataPtrOnStream(tensors[i].storage().data_ptr(), streams[i]);
    }
  }
}

// This function initializes a vector of CUDA streams, one per device,
// and ensures that these streams are synchronized with the current default
// streams. It is assumed that the tensors in the nested tensor vectors are
// on the same device.
void initializeStreamsEvents(
    std::vector<std::vector<at::Tensor>>& tensors,
    std::vector<c10::Stream>& streams,
    std::vector<c10::Event>& events) {
  // Ensure that the tensors in the nested tensor vectors are on the same
  // device.
  for (const auto& tensorgroup : tensors) {
    const auto device_id = tensorgroup[0].device().index();
    for (const auto& tensor : tensorgroup) {
      if (tensor.device().index() != device_id) {
        TORCH_CHECK(
            false,
            "tensors in the nested tensor vectors need to "
            "be on the same device");
      }
    }
  }

  streams.reserve(tensors.size());
  events.reserve(tensors.size());
  for (const auto i : c10::irange(tensors.size())) {
    c10::Device device = tensors[i][0].device();
    c10::impl::VirtualGuardImpl impl(device.type());
    // Record event on current stream
    events.emplace_back(device.type());
    events[i].record(impl.getStream(device));
    // Get a non-default stream to execute asynchronous CUDA operations
    // on for this output. This ensures that the default stream used
    // by the caller is not occupied by c10d related operations.
    streams.push_back(
        impl.getStreamFromGlobalPool(device, /*isHighPriority=*/true));
    // Ensure the new stream is synchronized with the current stream.
    events[i].block(streams[i]);

    for (at::Tensor& tensor : tensors[i]) {
      // `tensors` are created on a different stream. Hence, they must record
      // new streams in this Work to prevent being freed before the Work
      // finishes.
      impl.recordDataPtrOnStream(tensor.storage().data_ptr(), streams[i]);
    }
  }
}

const auto kLoopbackAddress = "127.0.0.1";

} // namespace

// static
void ProcessGroupGloo::AsyncWork::execute(
    const c10::intrusive_ptr<AsyncWork>& work) {
  if (work->recordFunctionBeforeCallback_) {
    work->recordFunctionBeforeCallback_();
  }
  try {
    work->run();
  } catch (...) {
    work->finishWorkGlooError(std::current_exception());
    return;
  }

  // FIXME: We need to call it here since Future completion requires all
  // the work to be synchronized to CUDA.
  work->synchronize();
  work->finishWorkGloo();
}

std::vector<at::Tensor> ProcessGroupGloo::AsyncWork::result() {
  TORCH_CHECK(
      isCompleted(),
      "Work needs to be completed before calling result(). "
      "Should call wait() before result().");
  TORCH_CHECK(
      outputTensors_.size() <= 1,
      "work result does not support list of lists, use .getFuture() and value()");
  return outputTensors_.empty() ? std::vector<at::Tensor>()
                                : outputTensors_.at(0);
}

c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupGloo::AsyncWork::
    getFuture() {
  return future_;
}

namespace {
c10::intrusive_ptr<c10::ivalue::Future> createFutureAsOutput(
    const std::vector<std::vector<at::Tensor>>& outputTensors) {
  if (outputTensors.size() > 1) {
    return c10::make_intrusive<c10::ivalue::Future>(
        c10::ListType::create(c10::ListType::create(c10::TensorType::get())));
  }
  return c10::make_intrusive<c10::ivalue::Future>(
      c10::ListType::create(c10::TensorType::get()));
}

void returnFutureWithOutput(
    c10::intrusive_ptr<c10::ivalue::Future>& future,
    const std::vector<std::vector<at::Tensor>>& outputTensors) {
  if (outputTensors.empty()) {
    future->markCompleted(c10::IValue(std::vector<at::Tensor>()));
    return;
  }
  if (outputTensors.size() > 1) {
    future->markCompleted(c10::IValue(outputTensors));
    return;
  }
  future->markCompleted(c10::IValue(outputTensors[0]));
}
} // namespace

inline void ProcessGroupGloo::AsyncWork::recordAsyncWorkProfilingInfo(
    const char* profilingTitle,
    const std::optional<std::vector<at::Tensor>>& inputTensors) {
  auto recordingFunction =
      std::make_shared<at::RecordFunction>(at::RecordScope::USER_SCOPE);
  if (recordingFunction->isActive()) {
    std::function<void()> before_handler =
        [inputTensors, profilingTitle, recordingFunction]() {
          // The work will be started and completed by different threads.
          recordingFunction->_setAsync();
          std::vector<c10::IValue> inputs;
          if (inputTensors) {
            inputs.reserve(inputTensors->size());
            for (const auto& tensor : *inputTensors) {
              inputs.emplace_back(tensor);
            }
          }
          recordingFunction->before(
              profilingTitle,
              c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()));
        };
    recordFunctionBeforeCallback_ =
        at::wrapPropagateTLSState(std::move(before_handler));
    std::function<void()> end_handler = [recordingFunction]() {
      recordingFunction->end();
    };
    recordFunctionEndCallback_ = at::wrapPropagateTLSState(end_handler);
  }
}

ProcessGroupGloo::AsyncWork::AsyncWork(
    std::vector<std::vector<at::Tensor>> outputTensors,
    OpType opType,
    uint64_t seq,
    const char* profilingTitle,
    const std::optional<std::vector<at::Tensor>>& inputTensors)
    // Profiler: Pass nullptr as profilingTitle to parent constructor to
    // replace default profiler implementation with async version that reports
    // correct timestamps for work that is asynchronously executed.
    : Work(-1, opType, nullptr, inputTensors),
      outputTensors_(std::move(outputTensors)),
      future_(createFutureAsOutput(outputTensors_)),
      seq_(seq) {
  if (profilingTitle != nullptr) {
    recordAsyncWorkProfilingInfo(profilingTitle, inputTensors);
  }
}

uint64_t ProcessGroupGloo::AsyncWork::getSequencenumber() const {
  return seq_;
}

void ProcessGroupGloo::AsyncWork::finishWorkGlooError(
    const std::exception_ptr& eptr) {
  future_->setError(eptr);
  finish(eptr);
}

void ProcessGroupGloo::AsyncWork::finishWorkGloo() {
  returnFutureWithOutput(future_, outputTensors_);
  finish();
}

ProcessGroupGloo::SendWork::SendWork(
    at::Tensor& tensor,
    std::unique_ptr<::gloo::transport::UnboundBuffer> buffer,
    uint64_t seq)
    : Work(
          -1,
          OpType::SEND,
          "gloo:send",
          std::optional<std::vector<at::Tensor>>({tensor})),
      tensor_(tensor),
      buffer_(std::move(buffer)),
      seq_(seq) {}

uint64_t ProcessGroupGloo::SendWork::getSequencenumber() const {
  return seq_;
}

bool ProcessGroupGloo::SendWork::wait(std::chrono::milliseconds timeout) {
  bool sendCompleted = false;
  std::exception_ptr exception{nullptr};
  try {
    if (timeout == kNoTimeout) {
      sendCompleted = buffer_->waitSend();
    } else {
      sendCompleted = buffer_->waitSend(timeout);
    }
  } catch (...) {
    exception = std::current_exception();
  }

  // Completes the Work object and throws the exception.
  finishAndThrow(exception);
  return sendCompleted;
}

void ProcessGroupGloo::SendWork::abort() {
  buffer_->abortWaitSend();
}

ProcessGroupGloo::RecvWork::RecvWork(
    at::Tensor& tensor,
    std::unique_ptr<::gloo::transport::UnboundBuffer> buffer,
    OpType opType,
    uint64_t seq,
    const char* profilingTitle)
    : Work(
          -1,
          opType,
          profilingTitle,
          std::optional<std::vector<at::Tensor>>({tensor})),
      tensor_(tensor),
      buffer_(std::move(buffer)),
      srcRank_(-1),
      seq_(seq) {}

uint64_t ProcessGroupGloo::RecvWork::getSequencenumber() const {
  return seq_;
}

int ProcessGroupGloo::RecvWork::sourceRank() const {
  std::lock_guard<std::mutex> lock(mutex_);
  return srcRank_;
}

bool ProcessGroupGloo::RecvWork::wait(std::chrono::milliseconds timeout) {
  bool recvCompleted = false;
  std::exception_ptr exception{nullptr};
  try {
    if (timeout == kNoTimeout) {
      recvCompleted = buffer_->waitRecv(&srcRank_);
    } else {
      recvCompleted = buffer_->waitRecv(&srcRank_, timeout);
    }
  } catch (...) {
    exception = std::current_exception();
  }

  // Completes the Work object and throws the exception.
  finishAndThrow(exception);
  return recvCompleted;
}

void ProcessGroupGloo::RecvWork::abort() {
  buffer_->abortWaitRecv();
}

ProcessGroupGloo::Options::Options(std::chrono::milliseconds timeout)
    : Backend::Options(GLOO_BACKEND_NAME, timeout), threads(2) {}

namespace {

void socketInitialize() {
#ifdef _WIN32
  ::gloo::init_winsock();
#endif
}

// Gloo assumes that this machine's hostname can always be resolved
// to an address. If it doesn't it throws a runtime error saying
// that it can't be resolved. Instead of catching it, we choose
// to proactively check if an address can be resolved, so we can
// gracefully fall back to an alternative if it doesn't.
bool doesHostnameResolveToUsableAddress(const std::string& hostname) {
  socketInitialize();
  struct addrinfo hints {};
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  struct addrinfo* result = nullptr;
  auto rv = getaddrinfo(hostname.c_str(), nullptr, &hints, &result);
  if (rv < 0) {
    return false;
  }
  struct addrinfo* rp = nullptr;
  for (rp = result; rp != nullptr; rp = rp->ai_next) {
    auto fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
    if (fd == -1) {
      continue;
    }
    rv = bind(fd, rp->ai_addr, rp->ai_addrlen);
#ifdef _WIN32
    closesocket(fd);
#else
    close(fd);
#endif
    if (rv == -1) {
      continue;
    }
    break;
  }
  freeaddrinfo(result);
  return rp != nullptr;
}

} // namespace

std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
    createDeviceForInterface(const std::string& interface_name) {
  return ::c10d::GlooDeviceFactory::makeDeviceForInterface(interface_name);
}

std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
    createDeviceForHostname(const std::string& hostname) {
  TORCH_CHECK(
      doesHostnameResolveToUsableAddress(hostname),
      "Cannot resolve ",
      hostname,
      " to a (local) address");
  return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname);
}

#if defined(__linux__) || defined(_WIN32)
std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
    createDefaultDevice() {
  // Use the hostname to resolve the network address to
  // use. Note: if the hostname does not resolve to an address (e.g.
  // because of misconfigured /etc/hosts file), this will not work.
  socketInitialize();
  std::array<char, HOST_NAME_MAX> hostname{};
  auto rv = gethostname(hostname.data(), HOST_NAME_MAX);
  if (rv != 0) {
    C10_THROW_ERROR(DistBackendError, std::strerror(errno));
  }

  // Use this machine's hostname if it resolves to an address.
  if (doesHostnameResolveToUsableAddress(hostname.data())) {
    return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname.data());
  }

  // Otherwise, use the loopback address.
  TORCH_WARN_ONCE(
      "Unable to resolve hostname to a (local) address. ",
      "Using the loopback address as fallback. ",
      "Manually set the network interface to bind to with GLOO_SOCKET_IFNAME.");
  return createDeviceForHostname(kLoopbackAddress);
}
#endif

#ifdef __APPLE__
std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
    createDefaultDevice() {
  // Use the hostname to resolve the network address to
  // use. Note: if the hostname does not resolve to an address (e.g.
  // because of misconfigured /etc/hosts file), this will not work.
  const auto hostNameMax = sysconf(_SC_HOST_NAME_MAX);
  auto hostname = std::unique_ptr<char[]>(new char[hostNameMax]);
  auto rv = gethostname(hostname.get(), hostNameMax);
  if (rv != 0) {
    C10_THROW_ERROR(DistBackendError, std::strerror(errno));
  }

  // Use this machine's hostname if it resolves to an address.
  if (doesHostnameResolveToUsableAddress(hostname.get())) {
    return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname.get());
  }

  // Otherwise, use the loopback address.
  TORCH_WARN_ONCE(
      "Unable to resolve hostname to a (local) address. ",
      "Using the loopback address as fallback. ",
      "Manually set the network interface to bind to with GLOO_SOCKET_IFNAME.");
  return createDeviceForHostname(kLoopbackAddress);
}
#endif

ProcessGroupGloo::ProcessGroupGloo(
    const c10::intrusive_ptr<Store>& store,
    int rank,
    int size,
    c10::intrusive_ptr<Options> options)
    : Backend(rank, size),
      store_(new GlooStore(store)),
      options_(std::move(options)),
      stop_(false),
      collectiveCounter_(0) {
  auto& devices = options_->devices;
  if (devices.empty()) {
    TORCH_CHECK(false, "No device(s) specified");
  }

  // Create and connect a context for every device.
  //
  // Note that the same device can be specified multiple times, either
  // the same object, or the same logical device as different objects.
  // Either mode is fine and only has performance implications.
  //
  // Using the same object multiple times means all contexts share a
  // single I/O thread. If you use different objects for the same
  // logical device they will have independent I/O threads. The latter
  // option is needed if you have a fast NIC that cannot be saturated
  // by a single I/O thread.
  //
  contexts_.reserve(options_->devices.size());
  for (const auto i : c10::irange(options_->devices.size())) {
    auto context = std::make_shared<::gloo::rendezvous::Context>(rank_, size_);
    auto store = ::gloo::rendezvous::PrefixStore(std::to_string(i), *store_);
    context->setTimeout(options_->timeout);
    try {
      context->connectFullMesh(store, options_->devices[i]);
    } catch (const std::runtime_error& e) {
      auto err = e.what();
      // TORCH_CHECK to print the cpp stacktrace.
      auto msg = c10::str("Gloo connectFullMesh failed with ", err);
      logAndThrow(msg, msg);
    }
    contexts_.push_back(std::move(context));
  }

  // Every worker thread stores the AsyncWork object it's currently
  // working on in the workInProgress_ vector. It must have size equal
  // to the number of workers such that they can simply index into it
  // using the worker index they are started with.
  workInProgress_.resize(options_->threads);

  threads_.resize(options_->threads);
  for (const auto i : c10::irange(threads_.size())) {
    threads_[i] = std::thread(&ProcessGroupGloo::runLoop, this, i);
  }

  init();
}

ProcessGroupGloo::~ProcessGroupGloo() {
  std::unique_lock<std::mutex> lock(workMutex_);
  workConsumeCV_.wait(lock, [&] { return workQueue_.empty(); });

  // Queue is empty, signal stop
  stop_ = true;

  // Release lock to allow threads to terminate
  lock.unlock();

  workProduceCV_.notify_all();

  // Wait for worker threads to terminate
  for (auto& thread : threads_) {
    thread.join();
  }
}

uint32_t ProcessGroupGloo::nextTag() {
  return collectiveCounter_++;
}

std::shared_ptr<::gloo::Context> ProcessGroupGloo::getContext(uint32_t tag) {
  return contexts_[tag % contexts_.size()];
}

void ProcessGroupGloo::runLoop(int workerIndex) {
  std::unique_lock<std::mutex> lock(workMutex_);

  while (!stop_) {
    if (workQueue_.empty()) {
      workProduceCV_.wait(lock);
      continue;
    }

    auto work = std::move(workQueue_.front());
    workQueue_.pop_front();
    workInProgress_[workerIndex] = work;
    lock.unlock();

    // Notify after releasing the lock so that the waiter
    // does not immediately block.
    workConsumeCV_.notify_one();

    AsyncWork::execute(work);
    lock.lock();
    workInProgress_[workerIndex].reset();
  }
}

void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
  std::unique_lock<std::mutex> lock(workMutex_);
  workQueue_.push_back(std::move(work));
  lock.unlock();

  // Notify after releasing the lock so that the waiter
  // does not immediately block.
  workProduceCV_.notify_one();
}

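// Work scheduling, in brief: each collective entry point below builds an
// AsyncWork subclass, calls enqueue(), and returns the work object to the
// caller. Worker threads started in the constructor sit in runLoop(), pop
// queued work, and run it via AsyncWork::execute(), which invokes run(),
// synchronize(), and completes the associated future.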
namespace {

class AsyncBroadcastWork : public ProcessGroupGloo::AsyncWork {
 public:
  AsyncBroadcastWork(
      const std::shared_ptr<gloo::Context>& context,
      std::vector<at::Tensor>& inputs,
      int rootRank,
      int rootTensor,
      uint32_t tag,
      uint64_t seq)
      : ProcessGroupGloo::AsyncWork(
            {inputs},
            OpType::BROADCAST,
            seq,
            "gloo:broadcast",
            inputs),
        context(context),
        inputs(inputs),
        rootRank(rootRank),
        rootTensor(rootTensor),
        tag(tag) {}

  std::shared_ptr<gloo::Context> context;
  std::vector<at::Tensor> inputs{};
  const int rootRank;
  const int rootTensor;
  const uint32_t tag;

  void broadcast(at::Tensor& tensor) {
    const auto& scalarType = tensor.scalar_type();
    gloo::BroadcastOptions opts(context);
    opts.setRoot(rootRank);
    opts.setTag(tag);
    GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensor);
    gloo::broadcast(opts);
  }

  void run() override {
    broadcast(inputs[rootTensor]);

    // Copy to non-root tensors
    for (const auto i : c10::irange(inputs.size())) {
      if (i == static_cast<size_t>(rootTensor)) {
        continue;
      }
      inputs[i].copy_(inputs[rootTensor]);
    }
  }
};

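// CUDA variant of the broadcast work: copy the root tensor to a pinned host
// tensor on a side stream, run the gloo broadcast on the host copy, then
// copy the result back into every CUDA input and record events so that
// synchronize() can block the caller's streams on those copies.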
class AsyncBroadcastCUDAWork : public AsyncBroadcastWork {
 public:
  AsyncBroadcastCUDAWork(
      const std::shared_ptr<gloo::Context>& context,
      std::vector<at::Tensor>& inputs,
      int rootRank,
      int rootTensor,
      uint32_t tag,
      uint64_t seq)
      : AsyncBroadcastWork(context, inputs, rootRank, rootTensor, tag, seq) {
    initializeStreamsEvents(inputs, streams, events);

    // Create pinned host side tensors.
    tmp = pinnedLike(inputs[rootTensor]);
    c10::OptionalStreamGuard guard;
    if (context->rank == rootRank) {
      guard.reset_stream(streams[rootTensor]);
      tmp.copy_(inputs[rootTensor], /* non_blocking */ true);
    }
  }

  void run() override {
    // Synchronize with copy operation if applicable.
    if (context->rank == rootRank) {
      streams[rootTensor].synchronize();
    }

    // Run broadcast on host side tensors.
    broadcast(tmp);

    // Kick off copy back to the CUDA tensors.
    c10::OptionalStreamGuard guard;
    for (const auto i : c10::irange(inputs.size())) {
      guard.reset_stream(streams[i]);
      inputs[i].copy_(tmp, /* non_blocking */ true);
      events[i].record(streams[i]);
    }
  }

  void synchronize() override {
    // Synchronize with the copy back to CUDA tensors.
    for (const auto i : c10::irange(inputs.size())) {
      c10::Device device = inputs[i].device();
      events[i].block(
          c10::impl::VirtualGuardImpl(device.type()).getStream(device));
    }
  }

  at::Tensor tmp;
  std::vector<c10::Stream> streams{};
  std::vector<c10::Event> events{};
};

} // namespace

c10::intrusive_ptr<Work> ProcessGroupGloo::broadcast(
    std::vector<at::Tensor>& inputs,
    const BroadcastOptions& opts) {
  static auto invalidArgument = [](const std::string& msg) {
    TORCH_CHECK(false, "ProcessGroupGloo::broadcast: " + msg);
  };

  assertRootRank(invalidArgument, opts.rootRank, size_);
  assertRootTensor(
      invalidArgument, opts.rootTensor, static_cast<int64_t>(inputs.size()));
  assertDense(invalidArgument, inputs);
  assertTypeAndSizesMatch(invalidArgument, inputs);

  const auto& device = inputs[0].device();
  switch (device.type()) {
    case at::kCPU:
      break;
    case at::kCUDA:
      // If the user gave us a CUDA tensor then CUDA must be loaded.
      TORCH_INTERNAL_ASSERT(at::hasCUDA());
      break;
    default:
      invalidArgument(c10::str("unsupported device type ", device.type()));
  }

  c10::intrusive_ptr<AsyncBroadcastWork> work;
  auto tag = nextTag();
  auto context = getContext(tag);
  ++seq_;
  if (device.type() == at::kCPU) {
    work = c10::make_intrusive<AsyncBroadcastWork>(
        std::move(context), inputs, opts.rootRank, opts.rootTensor, tag, seq_);
  } else if (device.type() == at::kCUDA) {
    work = c10::make_intrusive<AsyncBroadcastCUDAWork>(
        std::move(context), inputs, opts.rootRank, opts.rootTensor, tag, seq_);
  } else {
    TORCH_CHECK(false, "Invalid backend");
  }

  enqueue(work);
  return work;
}
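// Example caller-side usage (a hypothetical sketch, not code from this file):
//   std::vector<at::Tensor> tensors = {tensor};
//   auto work = pg->broadcast(tensors, BroadcastOptions{});
//   work->wait();  // or work->getFuture() for asynchronous completion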

namespace {

class AsyncAllreduceWork : public ProcessGroupGloo::AsyncWork {
 public:
  AsyncAllreduceWork(
      const std::shared_ptr<gloo::Context>& context,
      std::vector<at::Tensor>& inputs,
      ReduceOp reduceOp,
      uint32_t tag,
      uint64_t seq)
      : ProcessGroupGloo::AsyncWork(
            {inputs},
            OpType::ALLREDUCE,
            seq,
            "gloo:all_reduce",
            inputs),
        context(context),
        inputs(inputs),
        reduceOp(std::move(reduceOp)),
        tag(tag) {}

  std::shared_ptr<gloo::Context> context;
  std::vector<at::Tensor> inputs{};
  const ReduceOp reduceOp;
  const uint32_t tag;

  void allreduce(std::vector<at::Tensor>& tensors) {
    const auto& scalarType = tensors[0].scalar_type();
    gloo::AllreduceOptions opts(context);
    opts.setReduceFunction(getFunction(scalarType, reduceOp));
    opts.setTag(tag);
    GENERATE_ALL_TYPES(scalarType, setOutputs, opts, tensors);
    gloo::allreduce(opts);
  }

  void run() override {
    allreduce(inputs);
  }

  template <typename T>
  void getFunction(gloo::AllreduceOptions::Func& fn, const ReduceOp op) {
    fn = toFunction<T>(op);
  }

  gloo::AllreduceOptions::Func getFunction(
      const at::ScalarType& dtype,
      const ReduceOp& op) {
    gloo::AllreduceOptions::Func fn;
    GENERATE_ALL_TYPES(dtype, getFunction, fn, op);
    return fn;
  }
};

class AsyncAllreduceCoalescedWork : public AsyncAllreduceWork {
 public:
  AsyncAllreduceCoalescedWork(
      const std::shared_ptr<gloo::Context>& context,
      std::vector<at::Tensor>& inputs,
      ReduceOp reduceOp,
      uint32_t tag,
      uint64_t seq)
      : AsyncAllreduceWork(context, inputs, std::move(reduceOp), tag, seq) {}

  void run() override {
    allreduceCoalesced(inputs);
  }

 private:
  void allreduceCoalesced(std::vector<at::Tensor>& tensors) {
    // reduce coalesced, flattened tensors.
    at::Tensor coalescedTensor = flattenDenseTensors(tensors);
    std::vector<at::Tensor> allreduceInput = {coalescedTensor};
    allreduce(allreduceInput);

    // separate and reshape tensors.
    size_t offset = 0;
    for (at::Tensor& tensor : tensors) {
      const int64_t tensorNumel = tensor.numel();
      const c10::IntArrayRef tensorShape = tensor.sizes();
      tensor.copy_(coalescedTensor.slice(0, offset, offset + tensorNumel)
                       .view(tensorShape));
      offset += tensorNumel;
    }
  }
};

class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {
 public:
  AsyncSparseAllreduceWork(
      const std::shared_ptr<gloo::Context>& context,
      std::vector<at::Tensor>& inputs,
      uint32_t tag,
      uint64_t seq)
      : ProcessGroupGloo::AsyncWork(
            {inputs},
            OpType::_ALLREDUCE_SPARSE,
            seq,
            "gloo:sparse_all_reduce",
            inputs),
        context(context),
        inputs(inputs),
        tag(tag) {}

  std::shared_ptr<gloo::Context> context;
  std::vector<at::Tensor> inputs{};
  const uint32_t tag;

  // We share the dimensionality of the sparse tensors before collecting
  // their contents. We assume here that the maximum number of sparse
  // and dense dimensions is 4. This is stored in a contiguous piece of
  // memory so that we can easily run allgather on it.
  //
  // The layout of this memory is as follows:
  //
  //   - [0:4]: sparse dims
  //   - [4:8]: dense dims
  //   -   [8]: nnz
  //
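  // For example, a sparse tensor with sizes [10, 20, 5], sparse_dim() == 2,
  // dense_dim() == 1 and 7 non-zero entries is encoded as
  // [10, 20, 0, 0, 5, 0, 0, 0, 7]; unused slots stay zero because the
  // backing buffer is allocated with at::zeros.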
  class SparseTensorMetadata {
   public:
    static constexpr auto dim = 9;

    // Construct from an existing metadata tensor to facilitate structured
    // access to metadata from peers, after gathering it.
    explicit SparseTensorMetadata(at::Tensor metadata)
        : metadata_(std::move(metadata)),
          data_(metadata_.mutable_data_ptr<int64_t>()) {
      AT_ASSERT(metadata_.scalar_type() == at::kLong);
      AT_ASSERT(metadata_.dim() == 1);
      AT_ASSERT(metadata_.size(0) == dim);
    }

    // Populate the metadata.
    void populate_from_sparse_tensor(const at::Tensor& tensor) {
      const auto sparse_dim = tensor.sparse_dim();
      AT_ASSERT(sparse_dim <= 4);
      for (const auto i : c10::irange(4)) {
        if (i < sparse_dim) {
          data_[i] = tensor.size(i);
        }
      }
      const auto dense_dim = tensor.dense_dim();
      AT_ASSERT(dense_dim <= 4);
      for (const auto i : c10::irange(4)) {
        if (i < dense_dim) {
          data_[i + 4] = tensor.size(sparse_dim + i);
        }
      }
      data_[8] = tensor._nnz();
    }

    std::vector<int64_t> sizes() const {
      std::vector<int64_t> sizes;
      // Sparse sizes
      for (const auto i : c10::irange(4)) {
        if (data_[i] <= 0) {
          break;
        }
        sizes.push_back(data_[i]);
      }
      // Dense sizes
      for (const auto i : c10::irange(4, 8)) {
        if (data_[i] <= 0) {
          break;
        }
        sizes.push_back(data_[i]);
      }
      return sizes;
    }

    int64_t nnz() const {
      return data_[8];
    }

   protected:
    at::Tensor metadata_;
    int64_t* data_;
  };

  // Sparse allreduce is implemented with allgather on indices and values.
  // Every process then sums the resulting sparse tensors locally.
  // The nnz may differ across processes, so we first allgather the per-rank
  // metadata (including nnz) and then allgatherv the indices and values
  // with per-rank counts derived from it.
  at::Tensor allreduce(std::vector<at::Tensor>& tensors) {
    // TODO: This is a massive hack!  There is some confusion about
    // Variable/Tensor inside the body of this function.  Turning off
    // grad smooths over the confusion for now.  This fixes
    // test/test_c10d_gloo.py ProcessGroupGlooTest.test_sparse_allreduce_basics
    //
    // The correct fix is to stop allocating tensors that are not variables,
    // but to conveniently do this c10d must depend on torch not ATen
    at::AutoDispatchBelowAutograd guard;
    auto input = tensors[0];

    // Perform local reduction if we have multiple inputs.
    for (const auto i : c10::irange(1, tensors.size())) {
      input += tensors[i];
    }

    // Need to coalesce before we can access indices and values.
    input = input.coalesce();

    // Gather metadata information from all ranks.
    auto metadata = allgather_metadata(input);

    // Sanity check dimensionality across ranks.
    {
      const auto expected = metadata[context->rank].sizes();
      for (const auto i : c10::irange(context->size)) {
        if (i == context->rank) {
          continue;
        }
        const auto actual = metadata[i].sizes();
        TORCH_CHECK(actual == expected, "Sparse dimensions do not match");
      }
    }

    // Gather all indices and all values.
    auto indices = allgather_indices(input, metadata);
    auto values = allgather_values(input, metadata);

    // Perform global reduction.
    AT_ASSERT(static_cast<int>(indices.size()) == context->size);
    AT_ASSERT(static_cast<int>(values.size()) == context->size);
    auto output = at::sparse_coo_tensor(
        indices[0], values[0], input.sizes(), input.options());
    for (const auto i : c10::irange(1, context->size)) {
      output += at::sparse_coo_tensor(
          indices[i], values[i], input.sizes(), input.options());
    }

    // Coalesce for good measure.
    return output.coalesce();
  }

  void run() override {
    auto output = allreduce(inputs);

    // This copy is needed when we run a multi-gpu version of reduce (multiple
    // inputs per rank).
    for (const auto i : c10::irange(inputs.size())) {
      inputs[i].copy_(output);
    }
  }

 private:
  std::vector<SparseTensorMetadata> allgather_metadata(
      const at::Tensor& tensor) {
    auto buffer =
        at::zeros({context->size, SparseTensorMetadata::dim}, at::kLong);

    // Prepare metadata vector (1 entry per rank)
    std::vector<SparseTensorMetadata> metadata;
    metadata.reserve(context->size);
    for (const auto i : c10::irange(context->size)) {
      metadata.emplace_back(buffer.select(0, i));
    }

    // Populate data for this rank
    metadata[context->rank].populate_from_sparse_tensor(tensor);

    // Allgather metadata
    gloo::AllgatherOptions opts(context);
    opts.setOutput(buffer.mutable_data_ptr<int64_t>(), buffer.numel());
    opts.setTag(tag);
    gloo::allgather(opts);

    return metadata;
  }

  std::vector<at::Tensor> allgather_indices(
      const at::Tensor& tensor,
      const std::vector<SparseTensorMetadata>& metadata) {
    const auto sparseDim = tensor.sparse_dim();

    std::vector<size_t> counts(context->size);
    size_t totalSize = 0;
    for (const auto i : c10::irange(metadata.size())) {
      counts[i] = metadata[i].nnz() * sparseDim;
      totalSize += counts[i];
    }

    auto output = at::empty({static_cast<int64_t>(totalSize)}, at::kLong);

    // Tensors copied from CUDA may not be contiguous; get a contiguous
    // tensor before using its data_ptr.
    auto input = tensor.indices().contiguous();

    // Allgatherv indices.
    gloo::AllgathervOptions opts(context);
    opts.setInput(
        // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
        const_cast<int64_t*>(input.const_data_ptr<int64_t>()),
        input.numel());
    opts.setOutput(output.mutable_data_ptr<int64_t>(), counts);
    opts.setTag(tag);
    gloo::allgatherv(opts);

    // Compile indices tensor per rank.
    std::vector<at::Tensor> indices;
    indices.reserve(metadata.size());
    int64_t offset = 0;
    for (const auto& i : metadata) {
      const auto nnz = i.nnz();
      const auto numel = sparseDim * nnz;
      indices.push_back(
          output.narrow(0, offset, numel).reshape({sparseDim, nnz}));
      offset += numel;
    }

    return indices;
  }

  std::vector<at::Tensor> allgather_values(
      const at::Tensor& tensor,
      const std::vector<SparseTensorMetadata>& metadata) {
    // There are nnz #dense_dim()-dimensional tensors per rank.
    const auto valueShape = tensor.sizes().slice(tensor.sparse_dim());
    int64_t denseNumel = 1;
    for (auto dim : valueShape) {
      denseNumel *= dim;
    }

    std::vector<size_t> counts(context->size);
    int64_t totalSize = 0;
    for (const auto i : c10::irange(metadata.size())) {
      counts[i] = metadata[i].nnz() * denseNumel;
      totalSize += static_cast<int64_t>(counts[i]);
    }

    auto output = at::empty({totalSize}, tensor.scalar_type());

    // Allgatherv values.
    gloo::AllgathervOptions opts(context);
1350     // tensors copied from cuda may not be contiguous, get a contiguous
1351     // tensor before use its data_ptr
1352     at::Tensor valueTensor = tensor.values().contiguous();
1353     GENERATE_ALL_TYPES(valueTensor.scalar_type(), setInput, opts, valueTensor);
1354     GENERATE_ALL_TYPES(
1355         valueTensor.scalar_type(), setOutput, opts, output, counts);
1356     opts.setTag(tag);
1357     gloo::allgatherv(opts);
1358 
1359     // Compile values tensor per rank.
1360     std::vector<at::Tensor> values;
1361     values.reserve(metadata.size());
1362     int64_t offset = 0;
1363     for (const auto& i : metadata) {
1364       const auto nnz = i.nnz();
1365       const auto numel = denseNumel * nnz;
1366       auto tensorShape = std::vector<int64_t>({(int64_t)nnz});
1367       std::copy(
1368           valueShape.begin(),
1369           valueShape.end(),
1370           std::back_inserter(tensorShape));
1371       values.push_back(output.narrow(0, offset, numel).reshape(tensorShape));
1372       offset += numel;
1373     }
1374 
1375     return values;
1376   }
1377 };
1378 
1379 class AsyncAllreduceCUDAWork : public AsyncAllreduceWork {
1380  public:
AsyncAllreduceCUDAWork(const std::shared_ptr<gloo::Context> & context,std::vector<at::Tensor> & inputs,ReduceOp reduceOp,uint32_t tag,uint64_t seq)1381   AsyncAllreduceCUDAWork(
1382       const std::shared_ptr<gloo::Context>& context,
1383       std::vector<at::Tensor>& inputs,
1384       ReduceOp reduceOp,
1385       uint32_t tag,
1386       uint64_t seq)
1387       : AsyncAllreduceWork(context, inputs, std::move(reduceOp), tag, seq) {
1388     initializeStreamsEvents(inputs, streams, events);
1389 
1390     // Kick off copy from CUDA tensors to pinned CPU tensors.
1391     tmp.reserve(inputs.size());
1392     c10::OptionalStreamGuard guard;
1393     for (const auto i : c10::irange(inputs.size())) {
1394       guard.reset_stream(streams[i]);
1395       tmp.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true));
1396     }
1397   }
1398 
run()1399   void run() override {
1400     // Synchronize with copy operations.
1401     for (const auto i : c10::irange(inputs.size())) {
1402       streams[i].synchronize();
1403     }
1404 
1405     // Run allreduce on host side tensors.
1406     allreduce(tmp);
1407 
1408     c10::OptionalStreamGuard guard;
1409     for (const auto i : c10::irange(inputs.size())) {
1410       guard.reset_stream(streams[i]);
1411       inputs[i].copy_(tmp[i], /* non_blocking */ true);
1412       events[i].record(streams[i]);
1413     }
1414   }
1415 
synchronize()1416   void synchronize() override {
1417     // Synchronize with the copy back to CUDA tensors.
1418     for (const auto i : c10::irange(inputs.size())) {
1419       c10::Device device = inputs[i].device();
1420       events[i].block(
1421           c10::impl::VirtualGuardImpl(device.type()).getStream(device));
1422     }
1423   }
1424 
1425   std::vector<at::Tensor> tmp;
1426   std::vector<c10::Stream> streams{};
1427   std::vector<c10::Event> events{};
1428 };
1429 
1430 class AsyncSparseAllreduceCUDAWork : public AsyncSparseAllreduceWork {
1431  public:
AsyncSparseAllreduceCUDAWork(const std::shared_ptr<gloo::Context> & context,std::vector<at::Tensor> & inputs,uint32_t tag,uint64_t seq)1432   AsyncSparseAllreduceCUDAWork(
1433       const std::shared_ptr<gloo::Context>& context,
1434       std::vector<at::Tensor>& inputs,
1435       uint32_t tag,
1436       uint64_t seq)
1437       : AsyncSparseAllreduceWork(context, inputs, tag, seq) {
1438     initializeStreamsEvents(inputs, streams, events);
1439 
1440     // Kick off copy from CUDA tensors to CPU tensors.
1441     // Note that both coalescing the sparse tensor and copying it to CPU
1442     // memory must be performed asynchronously, or we block the caller.
1443     tmp.reserve(inputs.size());
1444     c10::OptionalStreamGuard guard;
1445     for (const auto i : c10::irange(inputs.size())) {
1446       guard.reset_stream(streams[i]);
1447       tmp.push_back(
1448           inputs[i].coalesce().to(at::DeviceType::CPU, /*non_blocking=*/true));
1449     }
1450   }
1451 
1452   void run() override {
1453     // Synchronize with copy operations.
1454     for (const auto i : c10::irange(inputs.size())) {
1455       streams[i].synchronize();
1456     }
1457 
1458     // Run allreduce on host side tensors.
1459     auto output = allreduce(tmp);
1460 
1461     // Kick off copy back to the CUDA tensors.
1462     c10::OptionalStreamGuard guard;
1463     for (const auto i : c10::irange(inputs.size())) {
1464       guard.reset_stream(streams[i]);
1465       inputs[i].copy_(output, /*non_blocking=*/true);
1466       events[i].record(streams[i]);
1467     }
1468   }
1469 
1470   void synchronize() override {
1471     // Synchronize with the copy back to CUDA tensors.
1472     for (const auto i : c10::irange(inputs.size())) {
1473       c10::Device device = inputs[i].device();
1474       events[i].block(
1475           c10::impl::VirtualGuardImpl(device.type()).getStream(device));
1476     }
1477   }
1478 
1479   std::vector<at::Tensor> tmp{};
1480   std::vector<c10::Stream> streams{};
1481   std::vector<c10::Event> events{};
1482 };
1483 
1484 } // namespace
1485 
1486 c10::intrusive_ptr<Work> ProcessGroupGloo::allreduce(
1487     std::vector<at::Tensor>& inputs,
1488     const AllreduceOptions& opts) {
1489   static auto invalidArgument = [](const std::string& msg) {
1490     TORCH_CHECK(false, "ProcessGroupGloo::allreduce: " + msg);
1491   };
1492 
1493   assertNonEmpty(invalidArgument, inputs);
1494   assertLayoutMatch(invalidArgument, inputs);
1495   assertTypeAndSizesMatch(invalidArgument, inputs);
1496 
1497   const auto& device = inputs[0].device();
1498   switch (device.type()) {
1499     case at::kCPU:
1500       break;
1501     case at::kCUDA:
1502       // If the user gave us a CUDA tensor then CUDA must be loaded.
1503       TORCH_INTERNAL_ASSERT(at::hasCUDA());
1504       break;
1505     default:
1506       invalidArgument(c10::str("unsupported device type ", device.type()));
1507   }
1508 
1509   const auto& layout = inputs[0].layout();
1510   if (layout == c10::kSparse && opts.reduceOp != ReduceOp::SUM) {
1511     invalidArgument(
1512         "unsupported reduction operation "
1513         "(allreduce of sparse tensors only works with ReduceOp.SUM)");
1514   }
1515 
1516   c10::intrusive_ptr<AsyncWork> work;
1517   auto tag = nextTag();
1518   auto context = getContext(tag);
1519   ++seq_;
1520   if (device.type() == at::kCPU) {
1521     if (layout == c10::kStrided) {
1522       work = c10::make_intrusive<AsyncAllreduceWork>(
1523           std::move(context), inputs, opts.reduceOp, tag, seq_);
1524     } else if (layout == c10::kSparse) {
1525       work = c10::make_intrusive<AsyncSparseAllreduceWork>(
1526           std::move(context), inputs, tag, seq_);
1527     } else {
1528       invalidArgument("unsupported layout");
1529     }
1530   } else if (device.type() == at::kCUDA) {
1531     if (layout == c10::kStrided) {
1532       work = c10::make_intrusive<AsyncAllreduceCUDAWork>(
1533           std::move(context), inputs, opts.reduceOp, tag, seq_);
1534     } else if (layout == c10::kSparse) {
1535       work = c10::make_intrusive<AsyncSparseAllreduceCUDAWork>(
1536           std::move(context), inputs, tag, seq_);
1537     } else {
1538       invalidArgument("unsupported layout");
1539     }
1540   } else {
1541     TORCH_CHECK(false, "Invalid backend");
1542   }
1543 
1544   enqueue(work);
1545   return work;
1546 }
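// Usage sketch (illustrative only, not part of this file): a CPU allreduce
// over a single dense tensor, assuming an already-constructed
// ProcessGroupGloo instance `pg`:
//
//   std::vector<at::Tensor> tensors = {at::ones({8})};
//   auto work = pg->allreduce(tensors, c10d::AllreduceOptions{});  // SUM by default
//   work->wait();  // tensors[0] now holds the element-wise sum across ranks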
1547 
1548 c10::intrusive_ptr<Work> ProcessGroupGloo::allreduce_sparse(
1549     std::vector<at::Tensor>& inputs,
1550     const AllreduceOptions& opts) {
1551   // allreduce_sparse calls into the default allreduce, which is
1552   // implemented by allgathering indices and values.
1553   // We do this because we do not have a native CUDA implementation.
1554   return allreduce(inputs, opts);
1555 }
1556 
1557 c10::intrusive_ptr<Work> ProcessGroupGloo::allreduce_coalesced(
1558     std::vector<at::Tensor>& tensors,
1559     const AllreduceCoalescedOptions& opts) {
1560   static auto invalidArgument = [](const std::string& msg) {
1561     TORCH_CHECK(false, "ProcessGroupGloo::allreduce_coalesced: " + msg);
1562   };
1563   assertNonEmpty(invalidArgument, tensors);
1564 
1565   // tensors will be flattened and concatenated (coalesced).
1566   // This means that input tensors must have the same device,
1567   // layout and type.
1568   assertLayoutMatch(invalidArgument, tensors);
1569   if (!std::all_of(tensors.begin(), tensors.end(), [&](at::Tensor& t) {
1570         return t.options().type_equal(tensors[0].options());
1571       })) {
1572     invalidArgument("tensors must all have the same type");
1573   }
1574   if (!std::all_of(tensors.begin(), tensors.end(), [&](at::Tensor& t) {
1575         return t.device() == tensors[0].device();
1576       })) {
1577     invalidArgument("tensors must all be on the same device");
1578   }
1579 
1580   const c10::Device& device = tensors[0].device();
1581   const c10::Layout& layout = tensors[0].layout();
1582 
1583   // invalid arguments are detected early here before any calls to nextTag()
1584   // which result in the collectiveCounter_ being incremented.
1585   switch (device.type()) {
1586     case c10::kCPU:
1587       break;
1588     default:
1589       invalidArgument(c10::str("unsupported device type ", device.type()));
1590   }
1591 
1592   switch (layout) {
1593     case c10::kStrided:
1594       break;
1595     default:
1596       invalidArgument("unsupported layout");
1597   }
1598 
1599   c10::intrusive_ptr<AsyncWork> work;
1600   const uint32_t tag = nextTag();
1601   std::shared_ptr<gloo::Context> context = getContext(tag);
1602   ++seq_;
1603   if (device.type() == c10::kCPU) {
1604     if (layout == c10::kStrided) {
1605       work = c10::make_intrusive<AsyncAllreduceCoalescedWork>(
1606           std::move(context), tensors, opts.reduceOp, tag, seq_);
1607     } else {
1608       invalidArgument("unsupported layout");
1609     }
1610   } else {
1611     TORCH_CHECK(false, "Invalid backend");
1612   }
1613   enqueue(work);
1614   return work;
1615 }
1616 
1617 namespace {
1618 
1619 class AsyncReduceWork : public ProcessGroupGloo::AsyncWork {
1620  public:
1621   AsyncReduceWork(
1622       const std::shared_ptr<gloo::Context>& context,
1623       std::vector<at::Tensor>& inputs,
1624       int rootRank,
1625       int rootTensor,
1626       ReduceOp reduceOp,
1627       uint32_t tag,
1628       uint64_t seq)
1629       : ProcessGroupGloo::AsyncWork(
1630             {inputs},
1631             OpType::REDUCE,
1632             seq,
1633             "gloo:reduce",
1634             inputs),
1635         context(context),
1636         inputs(inputs),
1637         rootRank(rootRank),
1638         rootTensor(rootTensor),
1639         reduceOp(std::move(reduceOp)),
1640         tag(tag) {}
1641 
1642   std::shared_ptr<gloo::Context> context;
1643   std::vector<at::Tensor> inputs{};
1644   const int rootRank;
1645   const int rootTensor;
1646   const ReduceOp reduceOp;
1647   const uint32_t tag;
1648 
1649   void reduce(std::vector<at::Tensor>& tensors) {
1650     const auto& scalarType = tensors[0].scalar_type();
1651     gloo::ReduceOptions opts(context);
1652     opts.setRoot(rootRank);
1653     opts.setTag(tag);
1654     opts.setReduceFunction(getFunction(scalarType, reduceOp));
1655     GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensors[0]);
1656     gloo::reduce(opts);
1657   }
1658 
1659   void run() override {
1660     reduce(inputs);
1661   }
1662 
1663  protected:
1664   template <typename T>
1665   void getFunction(gloo::ReduceOptions::Func& fn, const ReduceOp op) {
1666     fn = toFunction<T>(op);
1667   }
1668 
1669   gloo::ReduceOptions::Func getFunction(
1670       const at::ScalarType& dtype,
1671       const ReduceOp& op) {
1672     gloo::ReduceOptions::Func fn;
1673     GENERATE_ALL_TYPES(dtype, getFunction, fn, op);
1674     return fn;
1675   }
1676 };
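// Note on the two getFunction() overloads above: the dtype overload expands
// GENERATE_ALL_TYPES, which switches on the runtime scalar type and invokes
// the templated overload with the matching T, binding toFunction<T>(op) into
// a gloo::ReduceOptions::Func. The same macro-driven dispatch is used for
// setInput/setOutput throughout this file.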
1677 
1678 class AsyncReduceCUDAWork : public AsyncReduceWork {
1679  public:
1680   AsyncReduceCUDAWork(
1681       const std::shared_ptr<gloo::Context>& context,
1682       std::vector<at::Tensor>& inputs,
1683       int rootRank,
1684       int rootTensor,
1685       ReduceOp reduceOp,
1686       uint32_t tag,
1687       uint64_t seq)
1688       : AsyncReduceWork(
1689             context,
1690             inputs,
1691             rootRank,
1692             rootTensor,
1693             std::move(reduceOp),
1694             tag,
1695             seq) {
1696     initializeStreamsEvents(inputs, streams, events);
1697 
1698     // Kick off copy from CUDA tensors to pinned CPU tensors.
1699     tmp.reserve(inputs.size());
1700     c10::OptionalStreamGuard guard;
1701     for (const auto i : c10::irange(inputs.size())) {
1702       guard.reset_stream(streams[i]);
1703       tmp.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true));
1704     }
1705   }
1706 
1707   void run() override {
1708     // Synchronize with copy operations.
1709     for (const auto i : c10::irange(inputs.size())) {
1710       streams[i].synchronize();
1711     }
1712 
1713     // Run reduce on host side tensors.
1714     reduce(tmp);
1715 
1716     // Kick off copy back to the CUDA tensors.
1717     c10::OptionalStreamGuard guard;
1718     for (const auto i : c10::irange(inputs.size())) {
1719       guard.reset_stream(streams[i]);
1720       inputs[i].copy_(tmp[i], /* non_blocking */ true);
1721       events[i].record(streams[i]);
1722     }
1723   }
1724 
1725   void synchronize() override {
1726     // Synchronize with the copy back to CUDA tensors.
1727     for (const auto i : c10::irange(inputs.size())) {
1728       c10::Device device = inputs[i].device();
1729       events[i].block(
1730           c10::impl::VirtualGuardImpl(device.type()).getStream(device));
1731     }
1732   }
1733 
1734   std::vector<at::Tensor> tmp{};
1735   std::vector<c10::Stream> streams{};
1736   std::vector<c10::Event> events{};
1737 };
1738 
1739 } // namespace
1740 
1741 c10::intrusive_ptr<Work> ProcessGroupGloo::reduce(
1742     std::vector<at::Tensor>& inputs,
1743     const ReduceOptions& opts) {
1744   static auto invalidArgument = [](const std::string& msg) {
1745     TORCH_CHECK(false, "ProcessGroupGloo::reduce: " + msg);
1746   };
1747 
1748   assertRootRank(invalidArgument, opts.rootRank, size_);
1749   assertRootTensor(
1750       invalidArgument, opts.rootTensor, static_cast<int64_t>(inputs.size()));
1751   assertSingleElement(invalidArgument, inputs);
1752   assertDense(invalidArgument, inputs);
1753 
1754   const auto& device = inputs[0].device();
1755   switch (device.type()) {
1756     case at::kCPU:
1757       break;
1758     case at::kCUDA:
1759       // If the user gave us a CUDA tensor then CUDA must be loaded.
1760       TORCH_INTERNAL_ASSERT(at::hasCUDA());
1761       break;
1762     default:
1763       invalidArgument(c10::str("unsupported device type ", device.type()));
1764   }
1765 
1766   c10::intrusive_ptr<AsyncReduceWork> work;
1767   auto tag = nextTag();
1768   auto context = getContext(tag);
1769   ++seq_;
1770   if (device.type() == at::kCPU) {
1771     work = c10::make_intrusive<AsyncReduceWork>(
1772         std::move(context),
1773         inputs,
1774         opts.rootRank,
1775         opts.rootTensor,
1776         opts.reduceOp,
1777         tag,
1778         seq_);
1779   } else if (device.type() == at::kCUDA) {
1780     work = c10::make_intrusive<AsyncReduceCUDAWork>(
1781         std::move(context),
1782         inputs,
1783         opts.rootRank,
1784         opts.rootTensor,
1785         opts.reduceOp,
1786         tag,
1787         seq_);
1788   } else {
1789     TORCH_CHECK(false, "Invalid backend");
1790   }
1791   enqueue(work);
1792   return work;
1793 }
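// Usage sketch (illustrative only): reduce a single dense tensor onto rank 0,
// assuming an already-constructed ProcessGroupGloo instance `pg`:
//
//   std::vector<at::Tensor> tensors = {at::ones({8})};
//   c10d::ReduceOptions opts;
//   opts.rootRank = 0;  // only rank 0 ends up with the reduced result
//   pg->reduce(tensors, opts)->wait();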
1794 
1795 namespace {
1796 
1797 class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork {
1798  public:
1799   AsyncAllgatherWork(
1800       const std::shared_ptr<gloo::Context>& context,
1801       std::vector<std::vector<at::Tensor>>& outputs,
1802       std::vector<at::Tensor>& inputs,
1803       uint32_t tag,
1804       uint64_t seq)
1805       : ProcessGroupGloo::AsyncWork(
1806             outputs,
1807             OpType::ALLGATHER,
1808             seq,
1809             "gloo:all_gather",
1810             inputs),
1811         context(context),
1812         outputs(outputs),
1813         inputs(inputs),
1814         tag(tag) {}
1815 
1816   std::shared_ptr<gloo::Context> context;
1817   std::vector<std::vector<at::Tensor>> outputs{};
1818   std::vector<at::Tensor> inputs{};
1819   const uint32_t tag;
1820 
1821   void allgather(
1822       std::vector<std::vector<at::Tensor>>& outputs,
1823       std::vector<at::Tensor>& inputs) {
1824     const auto& scalarType = inputs[0].scalar_type();
1825     gloo::AllgatherOptions opts(context);
1826     opts.setTag(tag);
1827 
1828     // Use single flattened input tensor.
1829     at::Tensor flatInputTensor = flattenDenseTensors(inputs);
1830     GENERATE_ALL_TYPES(scalarType, setInput, opts, flatInputTensor);
1831 
1832     // Use single flat output tensor.
1833     // The first dimension corresponds to the index into outputs[N],
1834     // so copying into the actual output later is easy.
1835     at::Tensor flatOutputTensor = newLikeFlat(outputs[0]);
1836     GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor);
1837     gloo::allgather(opts);
1838 
1839     // Unflatten into output tensors.
1840     for (auto& outputgroup : outputs) {
1841       for (const auto j : c10::irange(outputgroup.size())) {
1842         outputgroup[j].copy_(flatOutputTensor[static_cast<int64_t>(j)]);
1843       }
1844     }
1845   }
1846 
1847   void run() override {
1848     allgather(outputs, inputs);
1849   }
1850 };
1851 
1852 // Note: current CUDA implementation holds the assumption that the
1853 // tensors in the nested output tensor vectors are on the same device.
1854 class AsyncAllgatherCUDAWork : public AsyncAllgatherWork {
1855  public:
1856   AsyncAllgatherCUDAWork(
1857       const std::shared_ptr<gloo::Context>& context,
1858       std::vector<std::vector<at::Tensor>>& outputs,
1859       std::vector<at::Tensor>& inputs,
1860       uint32_t tag,
1861       uint64_t seq)
1862       : AsyncAllgatherWork(context, outputs, inputs, tag, seq) {
1863     initializeStreamsEvents(inputs, inputStreams, inputEvents);
1864     initializeStreamsEvents(outputs, outputStreams, outputEvents);
1865 
1866     // Kick off copy from CUDA tensors to pinned CPU tensors.
1867     tmpInputs.reserve(inputs.size());
1868     c10::OptionalStreamGuard guard;
1869     for (const auto i : c10::irange(inputs.size())) {
1870       guard.reset_stream(inputStreams[i]);
1871       tmpInputs.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true));
1872     }
1873 
1874     tmpOutputs.resize(outputs.size());
1875     for (const auto i : c10::irange(outputs.size())) {
1876       tmpOutputs[i].reserve(outputs[i].size());
1877       for (const auto j : c10::irange(outputs[i].size())) {
1878         tmpOutputs[i].push_back(pinnedLike(outputs[i][j]));
1879       }
1880     }
1881   }
1882 
1883   void run() override {
1884     // Synchronize with copy operations.
1885     for (const auto i : c10::irange(inputs.size())) {
1886       inputStreams[i].synchronize();
1887     }
1888 
1889     for (const auto i : c10::irange(outputs.size())) {
1890       outputStreams[i].synchronize();
1891     }
1892 
1893     // Run allgather on host side tensors.
1894     allgather(tmpOutputs, tmpInputs);
1895 
1896     // Kick off copy back to the CUDA tensors.
1897     c10::OptionalStreamGuard guard;
1898     for (const auto i : c10::irange(outputs.size())) {
1899       guard.reset_stream(outputStreams[i]);
1900       for (const auto j : c10::irange(outputs[i].size())) {
1901         outputs[i][j].copy_(tmpOutputs[i][j], /* non_blocking */ true);
1902       }
1903       outputEvents[i].record(outputStreams[i]);
1904     }
1905   }
1906 
1907   void synchronize() override {
1908     // Synchronize with the copy back to CUDA tensors.
1909     for (const auto i : c10::irange(outputs.size())) {
1910       c10::Device device = outputs[i][0].device();
1911       outputEvents[i].block(
1912           c10::impl::VirtualGuardImpl(device.type()).getStream(device));
1913     }
1914   }
1915 
1916   std::vector<at::Tensor> tmpInputs{};
1917   std::vector<c10::Stream> inputStreams{};
1918   std::vector<c10::Event> inputEvents{};
1919 
1920   std::vector<std::vector<at::Tensor>> tmpOutputs{};
1921   std::vector<c10::Stream> outputStreams{};
1922   std::vector<c10::Event> outputEvents{};
1923 };
1924 
1925 // A work that takes a lambda on construction and calls it on wait.
1926 // It is useful for adding a continuation to another work, and/or
1927 // composing multiple works together.
1928 class LambdaWork : public Work {
1929  public:
1930   LambdaWork(std::function<void(void)> fn) : fn_(std::move(fn)) {}
1931 
1932   bool wait(std::chrono::milliseconds /* unused */) override {
1933     fn_();
1934     return true;
1935   }
1936 
1937  private:
1938   std::function<void(void)> fn_;
1939 };
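// For example, reduce_scatter_tensor_coalesced() below returns a LambdaWork
// whose wait() first waits on the underlying allreduce works and then copies
// each rank's chunk into the corresponding output tensor.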
1940 
1941 } // namespace
1942 
1943 c10::intrusive_ptr<Work> ProcessGroupGloo::_reduce_scatter_base(
1944     at::Tensor& outputTensor,
1945     at::Tensor& inputTensor,
1946     const ReduceScatterOptions& opts) {
1947   std::vector<at::Tensor> outputTensors = {outputTensor};
1948   std::vector<at::Tensor> inputTensors = {inputTensor};
1949   return reduce_scatter_tensor_coalesced(outputTensors, inputTensors, opts);
1950 }
1951 
1952 c10::intrusive_ptr<Work> ProcessGroupGloo::reduce_scatter_tensor_coalesced(
1953     std::vector<at::Tensor>& outputTensors,
1954     std::vector<at::Tensor>& inputTensors,
1955     const ReduceScatterOptions& opts) {
1956   if (outputTensors.size() != inputTensors.size()) {
1957     TORCH_CHECK(
1958         false, "requires input/output tensor lists to have the same length");
1959   }
1960   const auto rank = getRank();
1961   const auto worldSize = getSize();
1962   std::vector<at::Tensor> buffers;
1963   for (const auto i : c10::irange(inputTensors.size())) {
1964     auto inputShape = inputTensors[i].sizes().vec();
1965     auto outputShape = outputTensors[i].sizes().vec();
1966     TORCH_CHECK_EQ(outputTensors[i].dtype(), inputTensors[i].dtype());
1967     TORCH_CHECK_EQ(outputShape[0] * worldSize, inputShape[0]);
1968     for (size_t d = 1; d < outputShape.size(); ++d) {
1969       TORCH_CHECK_EQ(outputShape[d], inputShape[d]);
1970     }
1971     buffers.push_back(inputTensors[i].clone());
1972   }
1973   std::vector<c10::intrusive_ptr<Work>> works;
1974   for (const auto i : c10::irange(buffers.size())) {
1975     std::vector<at::Tensor> inp = {buffers[i]};
1976     AllreduceOptions arOpts;
1977     arOpts.reduceOp = opts.reduceOp;
1978     works.push_back(allreduce(inp, arOpts));
1979   }
1980   return c10::make_intrusive<LambdaWork>(
1981       [rank, worldSize, buffers, outputTensors, works = std::move(works)]() {
1982         for (const auto i : c10::irange(outputTensors.size())) {
1983           works[i]->wait();
1984           outputTensors[i].copy_(buffers[i].chunk(worldSize)[rank]);
1985         }
1986       });
1987 }
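// Shape example for the emulation above: with a world size of 4, an input of
// shape [8, 16] is cloned, allreduced in full, and then chunked along dim 0
// into four [2, 16] pieces; rank r copies chunk r into its output tensor.
// This trades extra communication (a full allreduce instead of a true
// reduce-scatter) for implementation simplicity.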
1988 
1989 c10::intrusive_ptr<Work> ProcessGroupGloo::_allgather_base(
1990     at::Tensor& output_tensor,
1991     at::Tensor& input_tensor,
1992     const AllgatherOptions& opts) {
1993   auto tensor_list = at::chunk(output_tensor, this->getSize(), 0);
1994   std::vector<std::vector<at::Tensor>> outputs = {tensor_list};
1995   std::vector<at::Tensor> inputs = {input_tensor};
1996   return this->allgather(outputs, inputs, opts);
1997 }
1998 // Note: current CUDA implementation holds the assumption that the
1999 // tensors in the nested output tensor vectors are on the same device.
2000 c10::intrusive_ptr<Work> ProcessGroupGloo::allgather(
2001     std::vector<std::vector<at::Tensor>>& outputs,
2002     std::vector<at::Tensor>& inputs,
2003     const AllgatherOptions& opts) {
2004   static auto invalidArgument = [](const std::string& msg) {
2005     TORCH_CHECK(false, "ProcessGroupGloo::allgather: " + msg);
2006   };
2007 
2008   if (inputs.empty()) {
2009     invalidArgument("requires non-empty input tensor list");
2010   }
2011 
2012   if (inputs.size() != outputs.size()) {
2013     invalidArgument(
2014         "requires input/output tensor lists to have the same length");
2015   }
2016 
2017   for (const auto i : c10::irange(outputs.size())) {
2018     const auto expected = inputs.size() * getSize();
2019     const auto actual = outputs[i].size();
2020     if (actual != expected) {
2021       invalidArgument(
2022           "invalid output tensor list at index " + std::to_string(i) +
2023           " (expected length " + std::to_string(expected) + ", got " +
2024           std::to_string(actual) + ")");
2025     }
2026   }
2027 
2028   assertDense(invalidArgument, inputs);
2029 
2030   // Expect all input/output tensors to have the same type and sizes
2031   const auto& options = inputs[0].options();
2032   const auto& sizes = inputs[0].sizes();
2033   assertTypeAndSizesMatch(invalidArgument, inputs, options, sizes);
2034   for (const auto& output : outputs) {
2035     assertTypeAndSizesMatch(invalidArgument, output, options, sizes);
2036   }
2037 
2038   const auto& device = inputs[0].device();
2039   switch (device.type()) {
2040     case at::kCPU:
2041       break;
2042     case at::kCUDA:
2043       // If the user gave us a CUDA tensor then CUDA must be loaded.
2044       TORCH_INTERNAL_ASSERT(at::hasCUDA());
2045       break;
2046     default:
2047       invalidArgument(c10::str("unsupported device type ", device.type()));
2048   }
2049 
2050   c10::intrusive_ptr<AsyncAllgatherWork> work;
2051   auto tag = nextTag();
2052   auto context = getContext(tag);
2053   ++seq_;
2054   if (device.type() == at::kCPU) {
2055     work = c10::make_intrusive<AsyncAllgatherWork>(
2056         std::move(context), outputs, inputs, tag, seq_);
2057   } else if (device.type() == at::kCUDA) {
2058     work = c10::make_intrusive<AsyncAllgatherCUDAWork>(
2059         std::move(context), outputs, inputs, tag, seq_);
2060   } else {
2061     TORCH_CHECK(false, "Invalid backend");
2062   }
2063   enqueue(work);
2064   return work;
2065 }
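// Usage sketch (illustrative only): gather one tensor from every rank,
// assuming an already-constructed ProcessGroupGloo instance `pg`:
//
//   std::vector<at::Tensor> input = {at::ones({4})};
//   std::vector<std::vector<at::Tensor>> outputs(1);
//   for (int r = 0; r < pg->getSize(); ++r) {
//     outputs[0].push_back(at::empty({4}));
//   }
//   pg->allgather(outputs, input, c10d::AllgatherOptions{})->wait();
//   // outputs[0][r] now holds rank r's input tensor.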
2066 
2067 namespace {
2068 
2069 class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork {
2070  public:
2071   AsyncAllgatherCoalescedWork(
2072       const std::shared_ptr<gloo::Context>& context,
2073       std::vector<std::vector<at::Tensor>>& output_lists,
2074       std::vector<at::Tensor>& input_list,
2075       uint32_t tag,
2076       uint64_t seq)
2077       : ProcessGroupGloo::AsyncWork(
2078             output_lists,
2079             OpType::ALLGATHER_COALESCED,
2080             seq,
2081             "gloo:all_gather",
2082             input_list),
2083         context(context),
2084         output_lists(output_lists),
2085         input_list(input_list),
2086         tag(tag) {}
2087 
2088   std::shared_ptr<gloo::Context> context;
2089   std::vector<std::vector<at::Tensor>> output_lists{};
2090   std::vector<at::Tensor> input_list{};
2091   const uint32_t tag;
2092 
2093   void allgather_coalesced() {
2094     assert(!output_lists.empty());
2095     assert(!output_lists[0].empty());
2096     assert(!input_list.empty());
2097 
2098     const auto& scalarType = input_list[0].scalar_type();
2099     gloo::AllgatherOptions opts(context);
2100     opts.setTag(tag);
2101 
2102     // Use single flattened input tensor.
2103     at::Tensor flatInputTensor = flattenDenseTensors(input_list);
2104     GENERATE_ALL_TYPES(scalarType, setInput, opts, flatInputTensor);
2105 
2106     // Compute total number of elements we need to allocate for all tensors
2107     // requested.
2108     int64_t output_numel = 0;
2109     for (const auto& t : output_lists[0]) {
2110       output_numel += t.numel();
2111     }
2112     output_numel *= static_cast<int64_t>(output_lists.size());
2113     // Use single flat output tensor.
2114     at::Tensor flatOutputTensor =
2115         at::empty({output_numel}, output_lists[0][0].options());
2116     GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor);
2117     gloo::allgather(opts);
2118 
2119     int64_t current_element = 0;
2120     for (auto& output_list : output_lists) {
2121       for (auto& output_tensor : output_list) {
2122         output_tensor.copy_(
2123             flatOutputTensor.narrow(0, current_element, output_tensor.numel())
2124                 .reshape(output_tensor.sizes()),
2125             true);
2126         current_element += output_tensor.numel();
2127       }
2128     }
2129   }
2130 
2131   void run() override {
2132     allgather_coalesced();
2133   }
2134 };
2135 
2136 } // namespace
2137 
2138 c10::intrusive_ptr<Work> ProcessGroupGloo::allgather_coalesced(
2139     std::vector<std::vector<at::Tensor>>& output_lists,
2140     std::vector<at::Tensor>& input_list,
2141     const AllgatherOptions& /* unused */) {
2142   static auto invalidArgument = [](const std::string& msg) {
2143     TORCH_CHECK(false, "ProcessGroupGloo::allgather_coalesced: " + msg);
2144   };
2145 
2146   if (input_list.empty()) {
2147     invalidArgument("requires non-empty input tensor list");
2148   }
2149 
2150   if (output_lists.size() != static_cast<size_t>(getSize())) {
2151     invalidArgument("output lists should be equal to world size");
2152   }
2153 
2154   assertSameDevice(invalidArgument, input_list);
2155 
2156   // Expect i'th tensor of each list from 'output_lists' match i'th tensor
2157   // from 'input_list' in type and size.
2158   for (const auto& output_list : output_lists) {
2159     if (output_list.size() != input_list.size()) {
2160       invalidArgument(
2161           "invalid output size: (expected length " +
2162           std::to_string(input_list.size()) + ", got " +
2163           std::to_string(output_list.size()) + ")");
2164     }
2165     for (const auto i : c10::irange(output_list.size())) {
2166       const auto expected = input_list[i].sizes();
2167       const auto actual = output_list[i].sizes();
2168       if (actual != expected) {
2169         invalidArgument(
2170             "invalid size of output tensor at index " + std::to_string(i) +
2171             " (expected length " + toString(expected) + ", got " +
2172             toString(actual) + ")");
2173       }
2174       if (!input_list[i].options().type_equal(output_list[i].options())) {
2175         invalidArgument(
2176             "invalid tensor type at index " + std::to_string(i) +
2177             " (expected " + input_list[i].toString() + ", got " +
2178             output_list[i].toString() + ")");
2179       }
2180     }
2181   }
2182 
2183   assertDense(invalidArgument, input_list);
2184 
2185   auto tag = nextTag();
2186   auto context = getContext(tag);
2187   ++seq_;
2188   auto work = c10::make_intrusive<AsyncAllgatherCoalescedWork>(
2189       std::move(context), output_lists, input_list, tag, seq_);
2190   enqueue(work);
2191   return work;
2192 }
2193 
2194 c10::intrusive_ptr<Work> ProcessGroupGloo::allgather_into_tensor_coalesced(
2195     std::vector<at::Tensor>& outputs,
2196     std::vector<at::Tensor>& inputs,
2197     const AllgatherOptions& opts) {
2198   TORCH_CHECK_EQ(outputs.size(), inputs.size());
2199   std::vector<std::vector<at::Tensor>> output_lists(getSize());
2200   for (auto& output : outputs) {
2201     auto chunks = output.chunk(getSize());
2202     for (const auto i : c10::irange(output_lists.size())) {
2203       output_lists[i].push_back(std::move(chunks[i]));
2204     }
2205   }
2206   return allgather_coalesced(output_lists, inputs, opts);
2207 }
2208 
2209 namespace {
2210 
2211 class AsyncGatherWork : public ProcessGroupGloo::AsyncWork {
2212  public:
2213   AsyncGatherWork(
2214       const std::shared_ptr<gloo::Context>& context,
2215       std::vector<std::vector<at::Tensor>>& outputs,
2216       std::vector<at::Tensor>& inputs,
2217       int root,
2218       uint32_t tag,
2219       uint64_t seq)
2220       : ProcessGroupGloo::AsyncWork(
2221             outputs,
2222             OpType::GATHER,
2223             seq,
2224             "gloo:gather",
2225             inputs),
2226         context(context),
2227         outputs(outputs),
2228         inputs(inputs),
2229         root(root),
2230         tag(tag) {}
2231 
2232   std::shared_ptr<gloo::Context> context;
2233   std::vector<std::vector<at::Tensor>> outputs{};
2234   std::vector<at::Tensor> inputs{};
2235   const int root;
2236   const uint32_t tag;
2237 
2238   void gather(
2239       std::vector<std::vector<at::Tensor>>& outputs,
2240       std::vector<at::Tensor>& inputs) {
2241     const auto scalarType = inputs[0].scalar_type();
2242     gloo::GatherOptions opts(context);
2243     opts.setRoot(root);
2244     opts.setTag(tag);
2245 
2246     // Set single temporary tensor on root process.
2247     // This is later scattered to the separate output tensors.
2248     at::Tensor flatOutputTensor;
2249     if (context->rank == root) {
2250       flatOutputTensor = newLikeFlat(outputs[0]);
2251       GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor);
2252     }
2253 
2254     // Set single input tensor on all processes.
2255     GENERATE_ALL_TYPES(scalarType, setInput, opts, inputs[0]);
2256     gloo::gather(opts);
2257 
2258     // Unflatten into output tensors on root process.
2259     if (context->rank == root) {
2260       for (const auto i : c10::irange(outputs[0].size())) {
2261         outputs[0][i].copy_(flatOutputTensor[static_cast<int64_t>(i)]);
2262       }
2263     }
2264   }
2265 
2266   void run() override {
2267     gather(outputs, inputs);
2268   }
2269 };
2270 
2271 // Note: current CUDA implementation holds the assumptions:
2272 //     - inputs.size() is 1
2273 //     - outputs.size() is 1
2274 //     - the size of the nested output tensor list is the world size,
2275 //       i.e., outputs[0].size() == world size
2276 class AsyncGatherCUDAWork : public AsyncGatherWork {
2277  public:
2278   AsyncGatherCUDAWork(
2279       const std::shared_ptr<gloo::Context>& context,
2280       std::vector<std::vector<at::Tensor>>& outputs,
2281       std::vector<at::Tensor>& inputs,
2282       int root,
2283       uint32_t tag,
2284       uint64_t seq)
2285       : AsyncGatherWork(context, outputs, inputs, root, tag, seq) {
2286     initializeStreamsEvents(inputs, inputStreams, inputEvents);
2287     initializeStreamsEvents(outputs, outputStreams, outputEvents);
2288 
2289     // Kick off copy from CUDA tensors to pinned CPU tensors.
2290     tmpInputs.reserve(inputs.size());
2291     c10::OptionalStreamGuard guard;
2292     for (const auto i : c10::irange(inputs.size())) {
2293       guard.reset_stream(inputStreams[i]);
2294       tmpInputs.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true));
2295     }
2296 
2297     tmpOutputs.resize(outputs.size());
2298     for (const auto i : c10::irange(outputs.size())) {
2299       tmpOutputs[i].reserve(outputs[i].size());
2300       for (const auto j : c10::irange(outputs[i].size())) {
2301         tmpOutputs[i].push_back(pinnedLike(outputs[i][j]));
2302       }
2303     }
2304   }
2305 
2306   void run() override {
2307     // Synchronize with copy operations.
2308     for (const auto i : c10::irange(inputs.size())) {
2309       inputStreams[i].synchronize();
2310     }
2311 
2312     for (const auto i : c10::irange(outputs.size())) {
2313       outputStreams[i].synchronize();
2314     }
2315 
2316     // Run gather on host side tensors.
2317     gather(tmpOutputs, tmpInputs);
2318 
2319     // Kick off copy back to the CUDA tensors.
2320     c10::OptionalStreamGuard guard;
2321     for (const auto i : c10::irange(outputs.size())) {
2322       guard.reset_stream(outputStreams[i]);
2323       for (const auto j : c10::irange(outputs[i].size())) {
2324         outputs[i][j].copy_(tmpOutputs[i][j], /* non_blocking */ true);
2325       }
2326       outputEvents[i].record(outputStreams[i]);
2327     }
2328   }
2329 
2330   void synchronize() override {
2331     // Synchronize with the copy back to CUDA tensors.
2332     for (const auto i : c10::irange(outputs.size())) {
2333       c10::Device device = outputs[i][0].device();
2334       outputEvents[i].block(
2335           c10::impl::VirtualGuardImpl(device.type()).getStream(device));
2336     }
2337   }
2338 
2339   std::vector<at::Tensor> tmpInputs{};
2340   std::vector<c10::Stream> inputStreams{};
2341   std::vector<c10::Event> inputEvents{};
2342 
2343   std::vector<std::vector<at::Tensor>> tmpOutputs{};
2344   std::vector<c10::Stream> outputStreams{};
2345   std::vector<c10::Event> outputEvents{};
2346 };
2347 
2348 } // namespace
2349 
2350 c10::intrusive_ptr<Work> ProcessGroupGloo::gather(
2351     std::vector<std::vector<at::Tensor>>& outputs,
2352     std::vector<at::Tensor>& inputs,
2353     const GatherOptions& opts) {
2354   static auto invalidArgument = [](const std::string& msg) {
2355     TORCH_CHECK(false, "ProcessGroupGloo::gather: " + msg);
2356   };
2357 
2358   assertRootRank(invalidArgument, opts.rootRank, size_);
2359   assertSingleElementInput(invalidArgument, inputs);
2360   assertDense(invalidArgument, inputs);
2361 
2362   if (getRank() == opts.rootRank) {
2363     if (outputs.size() != 1) {
2364       std::stringstream ss;
2365       ss << "requires a single-element output list containing a list with "
2366          << getSize() << " tensors.";
2367       invalidArgument(ss.str());
2368     } else if (outputs[0].size() != static_cast<size_t>(getSize())) {
2369       std::stringstream ss;
2370       ss << "Incorrect output list size " << outputs[0].size()
2371          << ". Output list size should be " << getSize()
2372          << ", same as size of the process group.";
2373       invalidArgument(ss.str());
2374     }
2375 
2376     const auto& options = inputs[0].options();
2377     const auto& sizes = inputs[0].sizes();
2378     assertTypeAndSizesMatch(invalidArgument, outputs[0], options, sizes);
2379   } else {
2380     if (!outputs.empty()) {
2381       invalidArgument("requires empty output on non-root");
2382     }
2383   }
2384 
2385   const auto& device = inputs[0].device();
2386   switch (device.type()) {
2387     case at::kCPU:
2388       break;
2389     case at::kCUDA:
2390       // If the user gave us a CUDA tensor then CUDA must be loaded.
2391       TORCH_INTERNAL_ASSERT(at::hasCUDA());
2392       break;
2393     default:
2394       invalidArgument(c10::str("unsupported device type ", device.type()));
2395   }
2396 
2397   c10::intrusive_ptr<AsyncGatherWork> work;
2398   auto tag = nextTag();
2399   auto context = getContext(tag);
2400   ++seq_;
2401   if (device.type() == at::kCPU) {
2402     work = c10::make_intrusive<AsyncGatherWork>(
2403         std::move(context), outputs, inputs, opts.rootRank, tag, seq_);
2404   } else if (device.type() == at::kCUDA) {
2405     work = c10::make_intrusive<AsyncGatherCUDAWork>(
2406         std::move(context), outputs, inputs, opts.rootRank, tag, seq_);
2407   } else {
2408     TORCH_CHECK(false, "Invalid backend");
2409   }
2410   enqueue(work);
2411   return work;
2412 }
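// Usage sketch (illustrative only): gather onto rank 0; non-root ranks must
// pass an empty output list. `pg` is an already-constructed ProcessGroupGloo.
//
//   std::vector<at::Tensor> input = {at::ones({4})};
//   std::vector<std::vector<at::Tensor>> outputs;
//   if (pg->getRank() == 0) {
//     outputs.emplace_back();
//     for (int r = 0; r < pg->getSize(); ++r) {
//       outputs[0].push_back(at::empty({4}));
//     }
//   }
//   c10d::GatherOptions opts;
//   opts.rootRank = 0;
//   pg->gather(outputs, input, opts)->wait();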
2413 
2414 namespace {
2415 
2416 class AsyncScatterWork : public ProcessGroupGloo::AsyncWork {
2417  public:
2418   AsyncScatterWork(
2419       const std::shared_ptr<gloo::Context>& context,
2420       std::vector<at::Tensor>& outputs,
2421       std::vector<std::vector<at::Tensor>>& inputs,
2422       int root,
2423       uint32_t tag,
2424       uint64_t seq)
2425       : ProcessGroupGloo::AsyncWork(
2426             {outputs},
2427             OpType::SCATTER,
2428             seq,
2429             "gloo:scatter",
2430             !inputs.empty() ? std::optional<std::vector<at::Tensor>>(inputs[0])
2431                             : std::nullopt),
2432         context(context),
2433         outputs(outputs),
2434         inputs(inputs),
2435         root(root),
2436         tag(tag) {}
2437 
2438   std::shared_ptr<gloo::Context> context;
2439   std::vector<at::Tensor> outputs{};
2440   std::vector<std::vector<at::Tensor>> inputs{};
2441   const int root;
2442   const uint32_t tag;
2443 
2444   void scatter(
2445       std::vector<at::Tensor>& outputs,
2446       std::vector<std::vector<at::Tensor>>& inputs) {
2447     const auto scalarType = outputs[0].scalar_type();
2448     gloo::ScatterOptions opts(context);
2449     opts.setRoot(root);
2450     opts.setTag(tag);
2451 
2452     // Set list of input tensors on root process
2453     if (context->rank == root) {
2454       GENERATE_ALL_TYPES(scalarType, setInputs, opts, inputs[0]);
2455     }
2456 
2457     // Set single output tensor on all processes
2458     GENERATE_ALL_TYPES(scalarType, setOutput, opts, outputs[0]);
2459     gloo::scatter(opts);
2460   }
2461 
2462   void run() override {
2463     scatter(outputs, inputs);
2464   }
2465 };
2466 
2467 class AsyncScatterCUDAWork : public AsyncScatterWork {
2468  public:
2469   AsyncScatterCUDAWork(
2470       const std::shared_ptr<gloo::Context>& context,
2471       std::vector<at::Tensor>& outputs,
2472       std::vector<std::vector<at::Tensor>>& inputs,
2473       int root,
2474       uint32_t tag,
2475       uint64_t seq)
2476       : AsyncScatterWork(context, outputs, inputs, root, tag, seq) {
2477     initializeStreamsEvents(inputs, inputStreams, inputEvents);
2478     initializeStreamsEvents(outputs, outputStreams, outputEvents);
2479 
2480     // Kick off copy from CUDA tensors to pinned CPU tensors.
2481     tmpInputs.resize(inputs.size());
2482     c10::OptionalStreamGuard guard;
2483     for (const auto i : c10::irange(inputs.size())) {
2484       guard.reset_stream(inputStreams[i]);
2485       tmpInputs[i].reserve(inputs[i].size());
2486       for (const auto j : c10::irange(inputs[i].size())) {
2487         tmpInputs[i].push_back(
2488             pinnedLike(inputs[i][j]).copy_(inputs[i][j], true));
2489       }
2490     }
2491 
2492     tmpOutputs.reserve(outputs.size());
2493     for (auto& output : outputs) {
2494       tmpOutputs.push_back(pinnedLike(output));
2495     }
2496   }
2497 
2498   void run() override {
2499     // Synchronize with copy operations.
2500     for (const auto i : c10::irange(inputs.size())) {
2501       inputStreams[i].synchronize();
2502     }
2503     for (const auto i : c10::irange(outputs.size())) {
2504       outputStreams[i].synchronize();
2505     }
2506 
2507     // Run scatter on host side tensors.
2508     scatter(tmpOutputs, tmpInputs);
2509 
2510     // Kick off copy back to the CUDA tensors.
2511     c10::OptionalStreamGuard guard;
2512     for (const auto i : c10::irange(outputs.size())) {
2513       guard.reset_stream(outputStreams[i]);
2514       outputs[i].copy_(tmpOutputs[i], /* non_blocking */ true);
2515       outputEvents[i].record(outputStreams[i]);
2516     }
2517   }
2518 
2519   void synchronize() override {
2520     // Synchronize with the copy back to CUDA tensors.
2521     for (const auto i : c10::irange(outputs.size())) {
2522       c10::Device device = outputs[i].device();
2523       outputEvents[i].block(
2524           c10::impl::VirtualGuardImpl(device.type()).getStream(device));
2525     }
2526   }
2527 
2528   std::vector<at::Tensor> tmpOutputs{};
2529   std::vector<c10::Stream> outputStreams{};
2530   std::vector<c10::Event> outputEvents{};
2531 
2532   std::vector<std::vector<at::Tensor>> tmpInputs{};
2533   std::vector<c10::Stream> inputStreams{};
2534   std::vector<c10::Event> inputEvents{};
2535 };
2536 
2537 } // namespace
2538 
2539 c10::intrusive_ptr<Work> ProcessGroupGloo::scatter(
2540     std::vector<at::Tensor>& outputs,
2541     std::vector<std::vector<at::Tensor>>& inputs,
2542     const ScatterOptions& opts) {
2543   static auto invalidArgument = [](const std::string& msg) {
2544     TORCH_CHECK(false, "ProcessGroupGloo::scatter: " + msg);
2545   };
2546 
2547   assertRootRank(invalidArgument, opts.rootRank, size_);
2548   assertSingleElementOutput(invalidArgument, outputs);
2549   assertDense(invalidArgument, outputs);
2550 
2551   if (getRank() == opts.rootRank) {
2552     if (inputs.size() != 1) {
2553       std::stringstream ss;
2554       ss << "requires a single-element input list containing a list with "
2555          << getSize() << " tensors";
2556       invalidArgument(ss.str());
2557     } else if (inputs[0].size() != static_cast<size_t>(getSize())) {
2558       std::stringstream ss;
2559       ss << "Incorrect input list size " << inputs[0].size()
2560          << ". Input list size should be " << getSize()
2561          << ", same as size of the process group.";
2562       invalidArgument(ss.str());
2563     }
2564     const auto& options = outputs[0].options();
2565     const auto& sizes = outputs[0].sizes();
2566     assertTypeAndSizesMatch(invalidArgument, inputs[0], options, sizes);
2567   } else {
2568     if (!inputs.empty()) {
2569       invalidArgument("requires empty input on non-root");
2570     }
2571   }
2572 
2573   const auto& device = outputs[0].device();
2574   switch (device.type()) {
2575     case at::kCPU:
2576       break;
2577     case at::kCUDA:
2578       // If the user gave us a CUDA tensor then CUDA must be loaded.
2579       TORCH_INTERNAL_ASSERT(at::hasCUDA());
2580       break;
2581     default:
2582       invalidArgument(c10::str("unsupported device type ", device.type()));
2583   }
2584 
2585   c10::intrusive_ptr<AsyncScatterWork> work;
2586   auto tag = nextTag();
2587   auto context = getContext(tag);
2588   ++seq_;
2589   if (device.type() == at::kCPU) {
2590     work = c10::make_intrusive<AsyncScatterWork>(
2591         std::move(context), outputs, inputs, opts.rootRank, tag, seq_);
2592   } else if (device.type() == at::kCUDA) {
2593     work = c10::make_intrusive<AsyncScatterCUDAWork>(
2594         std::move(context), outputs, inputs, opts.rootRank, tag, seq_);
2595   } else {
2596     TORCH_CHECK(false, "Invalid backend");
2597   }
2598   enqueue(work);
2599   return work;
2600 }
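// Usage sketch (illustrative only): scatter from rank 0; non-root ranks must
// pass an empty input list. `pg` is an already-constructed ProcessGroupGloo.
//
//   std::vector<at::Tensor> output = {at::empty({4})};
//   std::vector<std::vector<at::Tensor>> inputs;
//   if (pg->getRank() == 0) {
//     inputs.emplace_back();
//     for (int r = 0; r < pg->getSize(); ++r) {
//       inputs[0].push_back(at::full({4}, r, at::kFloat));
//     }
//   }
//   c10d::ScatterOptions opts;
//   opts.rootRank = 0;
//   pg->scatter(output, inputs, opts)->wait();  // rank r receives at::full({4}, r)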
2601 
2602 c10::intrusive_ptr<Work> ProcessGroupGloo::reduce_scatter(
2603     std::vector<at::Tensor>& outputs,
2604     std::vector<std::vector<at::Tensor>>& inputs,
2605     const ReduceScatterOptions& opts) {
2606   TORCH_CHECK(false, "ProcessGroupGloo does not support reduce_scatter");
2607 }
2608 
2609 namespace {
2610 
2611 class AsyncAlltoallWork : public ProcessGroupGloo::AsyncWork {
2612  public:
2613   AsyncAlltoallWork(
2614       const std::shared_ptr<gloo::Context>& context,
2615       at::Tensor& outputTensor,
2616       at::Tensor& inputTensor,
2617       std::vector<int64_t>& outputCounts,
2618       std::vector<int64_t>& inputCounts,
2619       uint32_t tag,
2620       uint64_t seq)
2621       : ProcessGroupGloo::AsyncWork(
2622             {{outputTensor}},
2623             OpType::ALLTOALL,
2624             seq,
2625             "gloo:all_to_all",
2626             std::optional<std::vector<at::Tensor>>({inputTensor})),
2627         context(context),
2628         outputTensor(outputTensor),
2629         inputTensor(inputTensor),
2630         outputCounts(std::move(outputCounts)),
2631         inputCounts(std::move(inputCounts)),
2632         tag(tag) {}
2633 
2634   std::shared_ptr<gloo::Context> context;
2635   at::Tensor outputTensor;
2636   at::Tensor inputTensor;
2637   std::vector<int64_t> outputCounts{};
2638   std::vector<int64_t> inputCounts{};
2639   const uint32_t tag;
2640 
2641   void alltoall(at::Tensor& outputTensor, at::Tensor& inputTensor) {
2642     const auto scalarType = outputTensor.scalar_type();
2643     if (outputCounts.empty() && inputCounts.empty()) {
2644       // Gloo alltoall
2645       gloo::AlltoallOptions opts(context);
2646       opts.setTag(tag);
2647       GENERATE_ALL_TYPES(scalarType, setInput, opts, inputTensor);
2648       GENERATE_ALL_TYPES(scalarType, setOutput, opts, outputTensor);
2649       gloo::alltoall(opts);
2650     } else {
2651       // Gloo alltoallv
2652       c10d::checkSplitSizes(inputCounts, inputTensor, context->size);
2653       c10d::checkSplitSizes(outputCounts, outputTensor, context->size);
2654       std::vector<int64_t> sendCounts(context->size);
2655       std::vector<int64_t> recvCounts(context->size);
2656       std::vector<int64_t> sendOffsets(context->size);
2657       std::vector<int64_t> recvOffsets(context->size);
2658       c10d::computeLengthsAndOffsets(
2659           inputCounts, inputTensor, &sendCounts, &sendOffsets);
2660       c10d::computeLengthsAndOffsets(
2661           outputCounts, outputTensor, &recvCounts, &recvOffsets);
2662       gloo::AlltoallvOptions opts(context);
2663       opts.setTag(tag);
2664       GENERATE_ALL_TYPES(scalarType, setInput, opts, inputTensor, sendCounts);
2665       GENERATE_ALL_TYPES(scalarType, setOutput, opts, outputTensor, recvCounts);
2666       gloo::alltoallv(opts);
2667     }
2668   }
2669 
2670   void run() override {
2671     alltoall(outputTensor, inputTensor);
2672   }
2673 };
2674 
2675 class AsyncAlltoallCUDAWork : public AsyncAlltoallWork {
2676  public:
2677   AsyncAlltoallCUDAWork(
2678       const std::shared_ptr<gloo::Context>& context,
2679       at::Tensor& outputTensor,
2680       at::Tensor& inputTensor,
2681       std::vector<int64_t>& outputCounts,
2682       std::vector<int64_t>& inputCounts,
2683       uint32_t tag,
2684       uint64_t seq)
2685       : AsyncAlltoallWork(
2686             context,
2687             outputTensor,
2688             inputTensor,
2689             outputCounts,
2690             inputCounts,
2691             tag,
2692             seq) {
2693     initializeStreamsEvents({inputTensor}, inputStreams, inputEvents);
2694     initializeStreamsEvents({outputTensor}, outputStreams, outputEvents);
2695 
2696     // Kick off copy from CUDA tensors to pinned CPU tensors.
2697     c10::OptionalStreamGuard guard;
2698     guard.reset_stream(inputStreams.front());
2699     cpuInput = pinnedLike(inputTensor).copy_(inputTensor, true);
2700 
2701     guard.reset_stream(outputStreams.front());
2702     cpuOutput = pinnedLike(outputTensor);
2703   }
2704 
2705   void run() override {
2706     // Synchronize with copy operations.
2707     inputStreams.front().synchronize();
2708     outputStreams.front().synchronize();
2709 
2710     // Run alltoall on host side tensors.
2711     alltoall(cpuOutput, cpuInput);
2712 
2713     // Kick off copy back to the CUDA tensors.
2714     c10::OptionalStreamGuard guard;
2715     guard.reset_stream(outputStreams.front());
2716     outputTensor.copy_(cpuOutput, /* non_blocking */ true);
2717     outputEvents.front().record(outputStreams.front());
2718   }
2719 
2720   void synchronize() override {
2721     // Synchronize with the copy back to CUDA tensors.
2722     c10::Device device = outputTensor.device();
2723     outputEvents.front().block(
2724         c10::impl::VirtualGuardImpl(device.type()).getStream(device));
2725   }
2726 
2727   at::Tensor cpuOutput;
2728   std::vector<c10::Stream> outputStreams{};
2729   std::vector<c10::Event> outputEvents{};
2730 
2731   at::Tensor cpuInput;
2732   std::vector<c10::Stream> inputStreams{};
2733   std::vector<c10::Event> inputEvents{};
2734 };
2735 
2736 } // namespace
2737 
2738 c10::intrusive_ptr<Work> ProcessGroupGloo::alltoall_base(
2739     at::Tensor& outputTensor,
2740     at::Tensor& inputTensor,
2741     std::vector<int64_t>& outputCounts,
2742     std::vector<int64_t>& inputCounts,
2743     const AllToAllOptions& /* unused */) {
2744   static auto invalidArgument = [](const std::string& msg) {
2745     TORCH_CHECK(false, "ProcessGroupGloo::alltoall_base: " + msg);
2746   };
2747 
2748   TORCH_CHECK(
2749       outputTensor.device() == inputTensor.device(),
2750       "output tensor and input tensor must be on the same type of device");
2751   assertDense(invalidArgument, {outputTensor});
2752   assertDense(invalidArgument, {inputTensor});
2753 
2754   const auto& device = outputTensor.device();
2755   c10::intrusive_ptr<AsyncAlltoallWork> work;
2756   auto tag = nextTag();
2757   auto context = getContext(tag);
2758   ++seq_;
2759 
2760   if (device.type() == at::kCPU) {
2761     work = c10::make_intrusive<AsyncAlltoallWork>(
2762         std::move(context),
2763         outputTensor,
2764         inputTensor,
2765         outputCounts,
2766         inputCounts,
2767         tag,
2768         seq_);
2769   } else if (device.type() == at::kCUDA) {
2770     work = c10::make_intrusive<AsyncAlltoallCUDAWork>(
2771         std::move(context),
2772         outputTensor,
2773         inputTensor,
2774         outputCounts,
2775         inputCounts,
2776         tag,
2777         seq_);
2778   } else {
2779     invalidArgument(c10::str("unsupported device type ", device.type()));
2780   }
2781   enqueue(work);
2782   return work;
2783 }
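// Usage sketch (illustrative only): an even all-to-all where each rank sends
// an equal slice to every peer (empty count vectors select the non-"v"
// gloo::alltoall path). `pg` is an already-constructed ProcessGroupGloo.
//
//   const int64_t n = pg->getSize();
//   at::Tensor input = at::arange(4 * n, at::kFloat);   // 4 elements per peer
//   at::Tensor output = at::empty({4 * n}, at::kFloat);
//   std::vector<int64_t> outputCounts;                  // empty => equal splits
//   std::vector<int64_t> inputCounts;
//   pg->alltoall_base(output, input, outputCounts, inputCounts,
//                     c10d::AllToAllOptions{})->wait();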
2784 
2785 static at::Tensor& checkSingleTensor(std::vector<at::Tensor>& tensors) {
2786   if (tensors.size() != 1) {
2787     TORCH_CHECK(false, "ProcessGroupGloo::send takes a single tensor");
2788   }
2789   auto& tensor = tensors[0];
2790   if (!tensor.is_contiguous()) {
2791     TORCH_CHECK(false, "input tensor has to be contiguous");
2792   }
2793   if (tensor.is_sparse()) {
2794     TORCH_CHECK(false, "input tensor has to be dense");
2795   }
2796   return tensor;
2797 }
2798 
2799 static uint32_t checkTag(int32_t tag) {
2800   TORCH_CHECK(tag >= 0, "Tag must be nonnegative");
2801   return (uint32_t)tag;
2802 }
2803 
2804 c10::intrusive_ptr<Work> ProcessGroupGloo::send(
2805     std::vector<at::Tensor>& tensors,
2806     int dstRank,
2807     int tag) {
2808   auto& tensor = checkSingleTensor(tensors);
2809   auto utag = checkTag(tag);
2810   auto ptr = tensor.const_data_ptr();
2811   auto size = tensor.numel() * tensor.element_size();
2812 
2813   // Construct unbound buffer.
2814   auto context = getContext(tag);
2815   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
2816   auto buf = context->createUnboundBuffer(const_cast<void*>(ptr), size);
2817   buf->send(dstRank, utag);
2818   ++seq_;
2819 
2820   // The work captures the tensor to prevent it being deallocated and
2821   // the unbound buffer to synchronize on completion of the send.
2822   return c10::make_intrusive<SendWork>(tensor, std::move(buf), seq_);
2823 }
2824 
2825 c10::intrusive_ptr<Work> ProcessGroupGloo::recv(
2826     std::vector<at::Tensor>& tensors,
2827     int srcRank,
2828     int tag) {
2829   auto& tensor = checkSingleTensor(tensors);
2830   auto utag = checkTag(tag);
2831   auto ptr = tensor.mutable_data_ptr();
2832   auto size = tensor.numel() * tensor.element_size();
2833 
2834   // Construct unbound buffer.
2835   auto context = getContext(tag);
2836   auto buf = context->createUnboundBuffer(ptr, size);
2837   buf->recv(srcRank, utag);
2838   ++seq_;
2839 
2840   // The work captures the tensor to prevent it being deallocated and
2841   // the unbound buffer to synchronize on completion of the recv.
2842   return c10::make_intrusive<RecvWork>(
2843       tensor, std::move(buf), OpType::RECV, seq_, "gloo:recv");
2844 }
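// Usage sketch (illustrative only): a point-to-point exchange between ranks 0
// and 1; the tag must match on both sides. `pg` is an already-constructed
// ProcessGroupGloo.
//
//   std::vector<at::Tensor> t = {at::ones({4})};
//   if (pg->getRank() == 0) {
//     pg->send(t, /*dstRank=*/1, /*tag=*/0)->wait();
//   } else if (pg->getRank() == 1) {
//     pg->recv(t, /*srcRank=*/0, /*tag=*/0)->wait();
//   }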
2845 
2846 c10::intrusive_ptr<Work> ProcessGroupGloo::recvAnysource(
2847     std::vector<at::Tensor>& tensors,
2848     int tag) {
2849   auto& tensor = checkSingleTensor(tensors);
2850   auto utag = checkTag(tag);
2851   auto ptr = tensor.mutable_data_ptr();
2852   auto size = tensor.numel() * tensor.element_size();
2853 
2854   // Construct unbound buffer.
2855   auto context = getContext(tag);
2856   auto buf = context->createUnboundBuffer(ptr, size);
2857 
2858   // Build list of ranks that this operation can recv from. In these
2859   // bindings we don't differentiate between ranks and can receive
2860   // from any other process in the group.
2861   std::vector<int> srcRanks;
2862   srcRanks.reserve(size_);
2863   for (const auto i : c10::irange(size_)) {
2864     srcRanks.push_back(i);
2865   }
2866 
2867   buf->recv(srcRanks, utag);
2868   ++seq_;
2869 
2870   // The work captures the tensor to prevent it being deallocated and
2871   // the unbound buffer to synchronize on completion of the recv.
2872   return c10::make_intrusive<RecvWork>(
2873       tensor,
2874       std::move(buf),
2875       OpType::RECVANYSOURCE,
2876       seq_,
2877       "gloo:recvAnySource");
2878 }
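
// Note: recvAnysource() lets the unbound buffer match whichever rank sends
// first with this tag; the RecvWork constructed above records the matched
// sender so it can be queried after completion. Illustrative sketch (assuming
// `pg` is a ProcessGroupGloo instance):
//
//   std::vector<at::Tensor> bufs{at::empty({4})};
//   auto work = pg->recvAnysource(bufs, /*tag=*/7);
//   work->wait();
//   int sender = work->sourceRank();  // rank whose send was matched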
2879 
2880 namespace {
2881 
2882 class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork {
2883  public:
2884   AsyncBarrierWork(
2885       const std::shared_ptr<gloo::Context>& context,
2886       std::vector<c10::weak_intrusive_ptr<AsyncWork>> priorWork,
2887       uint32_t tag,
2888       uint64_t seq)
2889       : ProcessGroupGloo::AsyncWork(
2890             {},
2891             OpType::BARRIER,
2892             seq,
2893             "gloo:barrier",
2894             std::nullopt),
2895         context(context),
2896         priorWork(std::move(priorWork)),
2897         tag(tag) {}
2898 
2899   std::shared_ptr<gloo::Context> context;
2900   std::vector<c10::weak_intrusive_ptr<AsyncWork>> priorWork{};
2901   const uint32_t tag;
2902 
2903   void run() override {
2904     // Wait on prior work to complete
2905     for (auto& weakWork : priorWork) {
2906       auto work = weakWork.lock();
2907       if (work) {
2908         work->wait();
2909       }
2910     }
2911 
2912     gloo::BarrierOptions opts(context);
2913     opts.setTag(tag);
2914     gloo::barrier(opts);
2915   }
2916 };
2917 
2918 } // namespace
2919 
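// barrier() snapshots all queued and in-progress work as weak references;
// AsyncBarrierWork::run() first waits on whatever of that prior work is still
// alive and only then issues the gloo barrier, so the barrier cannot complete
// while earlier asynchronous work from this process is outstanding.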
2920 c10::intrusive_ptr<Work> ProcessGroupGloo::barrier(const BarrierOptions& opts) {
2921   std::vector<c10::weak_intrusive_ptr<AsyncWork>> priorWork;
2922 
2923   // Snapshot all in progress and pending work as weak_ptr.
2924   // When executing a barrier, we need to ensure that all prior work
2925   // has completed before completing itself.
2926   {
2927     std::unique_lock<std::mutex> lock(workMutex_);
2928     priorWork.insert(
2929         priorWork.end(), workInProgress_.begin(), workInProgress_.end());
2930     priorWork.insert(priorWork.end(), workQueue_.begin(), workQueue_.end());
2931   }
2932 
2933   auto tag = nextTag();
2934   auto context = getContext(tag);
2935   ++seq_;
2936   auto work = c10::make_intrusive<AsyncBarrierWork>(
2937       std::move(context), std::move(priorWork), tag, seq_);
2938   enqueue(work);
2939   return work;
2940 }
2941 
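// monitoredBarrier() is a two-phase acknowledgement with rank 0 coordinating:
//   1. each non-zero rank sends its rank to rank 0 (tag t1) and then blocks on
//      a recv from rank 0 (tag t2);
//   2. rank 0 waits, under the barrier timeout, for the sends from all other
//      ranks and, only once all have arrived, replies on tag t2 to release them.
// Only rank 0 enforces the timeout, so failures are reported by rank 0 along
// with the rank(s) that never acknowledged.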
2942 void ProcessGroupGloo::monitoredBarrier(
2943     const BarrierOptions& opts,
2944     bool waitAllRanks) {
2945   C10_LOG_API_USAGE_ONCE("torch.distributed.monitored_barrier");
2946   // Use default timeout if no timeout was specified.
2947   auto monitoredBarrierTimeout =
2948       (opts.timeout == kUnsetTimeout) ? this->options_->timeout : opts.timeout;
2949   auto rank = this->getRank();
2950   auto t1 = nextTag();
2951   auto t2 = nextTag();
2952   std::vector<at::Tensor> commTensor = {at::tensor({rank})};
2953   // only enforce timeout on rank 0. This is so that other ranks aren't timed
2954   // out first, bringing down the job without reporting which rank timed out.
2955   if (rank != 0) {
2956     auto sendWork = send(commTensor, 0, static_cast<int>(t1));
2957     auto recvWork = recv(commTensor, 0, static_cast<int>(t2));
2958     try {
2959       sendWork->wait();
2960       recvWork->wait();
2961     } catch (const std::exception& e) {
2962       const std::string error = c10::str(
2963           "Rank ",
2964           rank,
2965           " successfully reached monitoredBarrier, but received errors while waiting",
2966           " for send/recv from rank 0. Please check rank 0 logs for faulty rank.");
2967       logAndThrow(
2968           error, c10::str(error, "\n Original exception: \n", e.what()));
2969     }
2970     return;
2971   }
2972   auto startTime = std::chrono::steady_clock::now();
2973   auto worldSize = this->getSize();
2974   // Mappings of rank to recvWork/sendWork respectively.
2975   std::map<int, c10::intrusive_ptr<Work>> recvWorkMap;
2976   std::map<int, c10::intrusive_ptr<Work>> sendWorkMap;
2977   // Kick off recvWork and wait to unblock sendWork->wait() from non-zero ranks.
2978   // Failed/hanging ranks will not ack this call, letting rank 0 know about the
2979   // failure.
2980   for (const auto dstRank : c10::irange(1, worldSize)) {
2981     recvWorkMap.emplace(
2982         dstRank, recv(commTensor, dstRank, static_cast<int>(t1)));
2983   }
2984 
2985   auto waitLoop = [&](const std::map<int, c10::intrusive_ptr<Work>>& works) {
2986     std::vector<int> processedRanks;
2987     for (auto& work : works) {
2988       bool rankResponded = false;
2989       try {
2990         // Note: if waitAllRanks=false, we recompute the time remaining in
2991         // barrier and use this recomputed time in wait(). However, if
2992         // waitAllRanks=true, we use the original timeout, since if we use
2993         // up the entire timeout waiting for response from rank n, then we
2994         // won't have any timeout left to query ranks beginning with n + 1.
2995         auto remainingTime =
2996             getRemainingTime(startTime, monitoredBarrierTimeout, waitAllRanks);
2997         if (!waitAllRanks) {
2998           checkRemainingTime(
2999               monitoredBarrierTimeout, remainingTime, processedRanks, rank);
3000         }
3001         work.second->wait(remainingTime);
3002         rankResponded = true;
3003       } catch (const std::exception& e) {
3004         const std::string error = c10::str(
3005             "[Rank 0]: Rank ",
3006             work.first,
3007             " failed to pass monitoredBarrier in ",
3008             monitoredBarrierTimeout.count(),
3009             " ms");
3010         if (waitAllRanks) {
3011           LOG(ERROR) << error;
3012         } else {
3013           logAndThrow(
3014               error, c10::str(error, "\n Original exception: \n", e.what()));
3015         }
3016       }
3017       if (rankResponded) {
3018         processedRanks.push_back(work.first);
3019       }
3020     }
3021     // If we are collecting all failed ranks, check whether any ranks have
3022     // failed to respond and, if so, throw.
3023     // Ensure all ranks from 1, ..., WORLD_SIZE - 1 have been successfully
3024     // processed.
3025     auto rankFailure =
3026         (processedRanks.size() != static_cast<size_t>(size_ - 1));
3027     if (waitAllRanks && rankFailure) {
3028       std::vector<int> failedRanks;
3029       for (const auto i : c10::irange(1, size_)) {
3030         if (std::find(processedRanks.begin(), processedRanks.end(), i) ==
3031             processedRanks.end()) {
3032           failedRanks.push_back(i);
3033         }
3034       }
3035 
3036       TORCH_INTERNAL_ASSERT(!failedRanks.empty());
3037       const std::string ranksStr = c10::Join(", ", failedRanks);
3038       const std::string error = c10::str(
3039           "[Rank 0]: Ranks ",
3040           ranksStr,
3041           " failed to pass monitoredBarrier in ",
3042           monitoredBarrierTimeout.count(),
3043           " ms");
3044       logAndThrow(error, error);
3045     }
3046   };
3047 
3048   waitLoop(recvWorkMap);
3049   // If we've reached here successfully, this means all ranks have acked in
3050   // monitoredBarrier. Unblock all ranks now by responding to their recv(). This
3051   // ensures that this is a true barrier in that all ranks exit it successfully
3052   // or none of them do.
3053   for (const auto dstRank : c10::irange(1, worldSize)) {
3054     sendWorkMap.emplace(
3055         dstRank, send(commTensor, dstRank, static_cast<int>(t2)));
3056   }
3057 
3058   waitLoop(sendWorkMap);
3059 }
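
// Note: this routine backs torch.distributed.monitored_barrier() on the Gloo
// backend; with wait_all_ranks=True rank 0 collects and reports every
// unresponsive rank rather than throwing on the first failure it encounters.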
3060 
3061 void ProcessGroupGloo::setSequenceNumberForGroup() {
3062 } // Gloo just starts sequence numbers at 0.
3063 
3064 uint64_t ProcessGroupGloo::getSequenceNumberForGroup() {
3065   return seq_;
3066 }
3067 
3068 void ProcessGroupGloo::enableCollectivesTiming() {
3069   // Nothing to do to enable timing
3070 }
3071 
3072 } // namespace c10d
3073 
3074 #endif // USE_C10D_GLOO
3075