#include <torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp>

#ifdef USE_C10D_GLOO

#include <c10/core/Allocator.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/irange.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
#include <optional>
#include <stdexcept>
#include <utility>

namespace c10d {

namespace {
// A container for information about a particular collective, including op
// type and input tensors (if applicable).
struct CollectiveFingerPrint {
  // Current collective's operation type.
  OpType op_type_;
  // Number of input tensors.
  std::size_t num_tensors_{};
  // Input tensor data types.
  std::vector<int8_t> tensor_dtypes_;
  // Input tensor device types.
  std::vector<int8_t> tensor_device_types_;
  // Input tensor sizes.
  std::vector<std::vector<int64_t>> tensor_sizes_;
  // Sequence number of this collective on the wrapped process group.
  uint64_t sequence_number_;
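
  // For example, an allreduce at sequence number 10 over one float32 CUDA
  // tensor of shape [4, 4] would record (illustrative values):
  //   op_type_ = OpType::ALLREDUCE, num_tensors_ = 1,
  //   tensor_dtypes_ = {Float}, tensor_device_types_ = {CUDA},
  //   tensor_sizes_ = {{4, 4}}, sequence_number_ = 10.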

  CollectiveFingerPrint(
      OpType op_type,
      const std::vector<at::Tensor>& input_tensors,
      uint64_t sequence_number)
      : op_type_(op_type),
        num_tensors_(input_tensors.size()),
        sequence_number_(sequence_number) {
    tensor_dtypes_.reserve(num_tensors_);
    tensor_device_types_.reserve(num_tensors_);
    tensor_sizes_.reserve(num_tensors_);
    for (const at::Tensor& t : input_tensors) {
      tensor_dtypes_.push_back(static_cast<int8_t>(t.dtype().toScalarType()));
      tensor_device_types_.push_back(static_cast<int8_t>(t.device().type()));
      tensor_sizes_.push_back(t.sizes().vec());
    }
  }

  // Constructs a CollectiveFingerPrint from data recovered by
  // deserialize_fingerprint.
  CollectiveFingerPrint(
      OpType op_type,
      size_t num_tensors,
      std::vector<int8_t> tensor_dtypes,
      std::vector<int8_t> tensor_device_types,
      std::vector<std::vector<int64_t>> tensor_sizes,
      uint64_t sequence_number)
      : op_type_(op_type),
        num_tensors_(num_tensors),
        tensor_dtypes_(std::move(tensor_dtypes)),
        tensor_device_types_(std::move(tensor_device_types)),
        tensor_sizes_(std::move(tensor_sizes)),
        sequence_number_(sequence_number) {}

  // Logs collective information in case of a failure.
  friend std::ostream& operator<<(
      std::ostream& output,
      const CollectiveFingerPrint& collective_fingerprint);

  // Executes and verifies the collective fingerprint.
  void verify(c10::intrusive_ptr<Backend> backend) {
    at::Tensor serialized_tensor = serialize_fingerprint();
    std::vector<at::Tensor> inp{serialized_tensor};
    // First verify tensor shapes. This is needed because if e.g. tensor dim
    // does not match across processes, directly verifying tensors will result
    // in a crash during allgather, but we'd actually like to report a
    // description of the inconsistency. Since the input is just a 1D tensor,
    // its shape is a single int k_i, and we need to make sure k_i is
    // consistent across the whole world.
    std::vector<at::Tensor> sp = c10d::getTensorShapes(inp);
    verify_tensors(sp, backend);
    // Now verify consistency of the actual tensor contents.
    verify_tensors(inp, backend);
  }

  // Takes a serialized fingerprint from
  // CollectiveFingerPrint::serialize_fingerprint and deserializes it back to
  // a CollectiveFingerPrint struct.
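  // The read order below must exactly mirror the write order in
  // serialize_fingerprint().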
  CollectiveFingerPrint deserialize_fingerprint(
      const at::Tensor& serialized_tensor) {
    auto dtypes = std::vector<int8_t>();
    auto device_types = std::vector<int8_t>();
    auto sizes = std::vector<std::vector<int64_t>>();
    int index = 0;
    int64_t seq = 0;
    // 1. OpType
    auto optype = OpType(serialized_tensor[index].item<int>());
    index++;
    int num_tensors = 0;
    if (index < serialized_tensor.size(0)) {
      // 2. Sequence number
      seq = serialized_tensor[index].item<int64_t>();
      index++;
      // 3. Num tensors
      num_tensors = serialized_tensor[index].item<int>();
      index++;
      dtypes.reserve(num_tensors);
      device_types.reserve(num_tensors);
      sizes.reserve(num_tensors);

      // 4. Tensor dtypes
      for (int i = 0; i < num_tensors; i++) {
        dtypes.push_back(serialized_tensor[index].item<int8_t>());
        index++;
      }
      // 5. Device types
      for (int i = 0; i < num_tensors; i++) {
        device_types.push_back(serialized_tensor[index].item<int8_t>());
        index++;
      }
      // 6. Tensor shapes
      for (int i = 0; i < num_tensors; i++) {
        // 6a. Shape size
        int size = serialized_tensor[index].item<int>();
        index++;
        // 6b. Shape
        auto shapeVec = std::vector<int64_t>();
        shapeVec.reserve(size);
        for (int j = 0; j < size; j++) {
          shapeVec.push_back(serialized_tensor[index].item<int64_t>());
          index++;
        }
        sizes.push_back(std::move(shapeVec));
      }
    }
    return CollectiveFingerPrint(
        optype,
        num_tensors,
        std::move(dtypes),
        std::move(device_types),
        std::move(sizes),
        seq);
  }

 private:
  void verify_tensors(
      std::vector<at::Tensor>& tensors_to_verify,
      c10::intrusive_ptr<Backend>& backend) {
    // Create output tensor data structure to pass into allgather.
    std::vector<std::vector<at::Tensor>> output_tensors;
    // Output tensors: [<tensor 0 outputs>, <tensor 1 outputs>, ..., <tensor n
    // outputs>]
    output_tensors.reserve(tensors_to_verify.size());
    for (const auto& tensor_shape : tensors_to_verify) {
      // Each rank has its own output slot, e.g.
      // <tensor 0 outputs>: [<rank 0 tensor>, <rank 1 tensor>, ..., <rank n
      // tensor>]
      std::vector<at::Tensor> outputs;
      outputs.reserve(backend->getSize());
      for (const auto i : c10::irange(backend->getSize())) {
        std::ignore = i; // Suppress unused variable warning
        outputs.emplace_back(at::zeros_like(tensor_shape));
      }
      output_tensors.emplace_back(outputs);
    }
    // Allgather tensor shapes.
    backend->allgather(output_tensors, tensors_to_verify)->wait();
    // Verify equivalence.
    for (const auto i : c10::irange(output_tensors.size())) {
      const std::vector<at::Tensor>& gathered_tensors = output_tensors[i];
      const at::Tensor& reference_tensor = tensors_to_verify[i];
      for (const auto rank : c10::irange(gathered_tensors.size())) {
        const auto& rank_tensor = gathered_tensors[rank];
        if (!rank_tensor.equal(reference_tensor)) {
          CollectiveFingerPrint rank_fingerprint =
              deserialize_fingerprint(rank_tensor);
          std::stringstream ss;
          ss << "Detected mismatch between collectives on ranks. Rank "
             << backend->getRank() << " is running collective: " << *this
             << ", but Rank " << rank
             << " is running collective: " << rank_fingerprint << ".";
          auto diff_result = compute_collective_diff(rank_fingerprint);
          if (std::get<0>(diff_result)) {
            ss << std::get<1>(diff_result);
          }

          TORCH_CHECK(false, ss.str());
        }
      }
    }
  }
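
  // Illustrative example of the mismatch error raised above (exact
  // formatting may vary):
  //   Detected mismatch between collectives on ranks. Rank 0 is running
  //   collective: CollectiveFingerPrint(SequenceNumber=10, OpType=ALLREDUCE,
  //   TensorShape=[4, 4], TensorDtypes=Float, TensorDeviceTypes=cuda), but
  //   Rank 1 is running collective: CollectiveFingerPrint(SequenceNumber=10,
  //   OpType=REDUCE, ...).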

  // Returns the shape of the first input tensor as a list of strings. Note
  // that only tensor_sizes_[0] is reported.
  static std::vector<std::string> get_size_strs(
      const CollectiveFingerPrint& collective_fingerprint) {
    std::vector<std::string> size_strs;
    if (!collective_fingerprint.tensor_sizes_.empty()) {
      for (const auto& single_tensor_shape_num :
           collective_fingerprint.tensor_sizes_[0]) {
        size_strs.emplace_back(std::to_string(single_tensor_shape_num));
      }
    }
    return size_strs;
  }

  static std::vector<std::string> get_dtype_strs(
      const CollectiveFingerPrint& collective_fingerprint) {
    std::vector<std::string> dtype_strs;
    dtype_strs.reserve(collective_fingerprint.tensor_dtypes_.size());
    for (const auto& tensor_dtype : collective_fingerprint.tensor_dtypes_) {
      dtype_strs.emplace_back(
          c10::toString(static_cast<at::ScalarType>(tensor_dtype)));
    }
    return dtype_strs;
  }

  static std::vector<std::string> get_device_type_strs(
      const CollectiveFingerPrint& collective_fingerprint) {
    std::vector<std::string> device_type_strs;
    device_type_strs.reserve(
        collective_fingerprint.tensor_device_types_.size());
    for (const auto& tensor_device_type :
         collective_fingerprint.tensor_device_types_) {
      device_type_strs.emplace_back(
          c10::toString(static_cast<at::DeviceType>(tensor_device_type)));
    }
    return device_type_strs;
  }

  std::pair<bool, std::string> compute_collective_diff(
      CollectiveFingerPrint& other) {
    // Computes the difference between two collectives (sequence number,
    // tensor shapes, collective type, etc.) for easier understanding of how
    // mismatched collectives across ranks differ.
    bool found_diff = false;
    std::stringstream ss;
    ss << "Collectives differ in the following aspects: ";
    // Check sequence number.
    if (other.sequence_number_ != sequence_number_) {
      found_diff = true;
      ss << c10::str(
          "\t Sequence number: ",
          sequence_number_,
          " vs ",
          other.sequence_number_);
    }
    // Check op type.
    auto other_op = opTypeToString(other.op_type_);
    auto this_op = opTypeToString(op_type_);
    if (other_op != this_op) {
      found_diff = true;
      ss << c10::str("  Op type: ", this_op, " vs ", other_op);
    }

    auto check = [&ss, &found_diff](
                     const char* arg,
                     const std::vector<std::string>& other,
                     const std::vector<std::string>& curr) {
      if (other != curr) {
        found_diff = true;
        ss << c10::str("  Tensor ", arg, ": ", curr, " vs ", other);
      }
    };

    // Check tensor sizes.
    auto other_sizes = get_size_strs(other);
    auto this_sizes = get_size_strs(*this);
    check("shapes", other_sizes, this_sizes);

    // Check tensor dtypes.
    auto other_dtypes = get_dtype_strs(other);
    auto this_dtypes = get_dtype_strs(*this);
    check("dtypes", other_dtypes, this_dtypes);

    // Check tensor devices.
    auto other_devices = get_device_type_strs(other);
    auto this_devices = get_device_type_strs(*this);
    check("devices", other_devices, this_devices);

    return std::make_pair(found_diff, ss.str());
  }

  // Serializes the fingerprint (op type, sequence number, tensor count,
  // dtypes, device types, and shapes) into a 1D int64 tensor.
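  // Layout of the serialized tensor (for n input tensors):
  //   [op_type, sequence_number, num_tensors,
  //    dtype_0, ..., dtype_{n-1},
  //    device_0, ..., device_{n-1},
  //    ndims_0, dim_0_0, ..., ndims_1, dim_1_0, ...]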
  at::Tensor serialize_fingerprint() {
    auto data = std::make_unique<std::vector<int64_t>>();
    // 1. OpType
    data->push_back(static_cast<int64_t>(op_type_));
    // 2. Sequence number
    data->push_back(static_cast<int64_t>(sequence_number_));
    // 3. Num tensors
    data->push_back(static_cast<int64_t>(num_tensors_));
    // 4. Tensor dtypes
    for (const auto& type : tensor_dtypes_) {
      data->push_back(type);
    }
    // 5. Device types
    for (const auto& d : tensor_device_types_) {
      data->push_back(d);
    }
    // 6. Shapes
    for (const auto& sizes : tensor_sizes_) {
      data->push_back(static_cast<int64_t>(sizes.size()));
      for (const auto& s : sizes) {
        data->push_back(s);
      }
    }
    // Serialize data into a tensor.
    int64_t data_size = static_cast<int64_t>(data->size());
    // Need to release here and get the raw pointer due to C++ parameter
    // evaluation order.
    auto d = data.release();
    at::Tensor serialized_tensor =
        at::for_blob(d->data(), {data_size})
            .context(
                d,
                [](void* ctx) {
                  delete static_cast<std::vector<int64_t>*>(ctx);
                })
            .options(at::TensorOptions().dtype(at::kLong))
            .make_tensor();
    return serialized_tensor;
  }
};

std::ostream& operator<<(
    std::ostream& output,
    const CollectiveFingerPrint& collective_fingerprint) {
  std::string collectiveInfo;
  auto op_type_str = opTypeToString(collective_fingerprint.op_type_);
  if (collective_fingerprint.num_tensors_ != 0) {
    // Convert dtype and device type info to strings.
    std::vector<std::string> dtype_strs =
        CollectiveFingerPrint::get_dtype_strs(collective_fingerprint);
    std::vector<std::string> device_type_strs =
        CollectiveFingerPrint::get_device_type_strs(collective_fingerprint);
    std::vector<std::string> size_strs =
        CollectiveFingerPrint::get_size_strs(collective_fingerprint);

    collectiveInfo = c10::str(
        "CollectiveFingerPrint(",
        "SequenceNumber=",
        collective_fingerprint.sequence_number_,
        ", OpType=",
        op_type_str,
        ", TensorShape=[",
        c10::Join(", ", size_strs),
        "], TensorDtypes=",
        dtype_strs,
        ", TensorDeviceTypes=",
        device_type_strs,
        ")");
  } else {
    collectiveInfo = c10::str(
        "CollectiveFingerPrint(",
        "SequenceNumber=",
        collective_fingerprint.sequence_number_,
        ", OpType=",
        op_type_str,
        ")");
  }
  return output << collectiveInfo;
}
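
// For a collective with no input tensors (e.g. a barrier), the output above
// would look roughly like (illustrative):
//   CollectiveFingerPrint(SequenceNumber=12, OpType=BARRIER)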

// Returns true if all tensors in input_tensors have the same size as the
// first tensor.
bool check_same_size(const std::vector<at::Tensor>& input_tensors) {
  for (const auto& input_tensor : input_tensors) {
    if (!input_tensors[0].is_same_size(input_tensor)) {
      return false;
    }
  }
  return true;
}
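
// check_same_size is used by allgather and reduce_scatter below to skip
// fingerprint shape verification when per-rank tensor sizes differ, since
// those collectives permit uneven shapes.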

} // namespace

ProcessGroupWrapper::ProcessGroupWrapper(
    const c10::intrusive_ptr<Backend>& backend,
    c10::intrusive_ptr<Backend> glooBackend)
    : Backend(backend->getRank(), backend->getSize()),
      backend_(backend),
      glooBackend_(std::move(glooBackend)) {
  // Set the sequence number for the underlying process group.
  backend_->setSequenceNumberForGroup();
}
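
// A minimal construction sketch (hypothetical names; ProcessGroupGloo
// options are elided and constructor arguments may differ by version):
//   auto gloo = c10::make_intrusive<ProcessGroupGloo>(store, rank, size);
//   auto pg = c10::make_intrusive<ProcessGroupWrapper>(real_backend, gloo);
//   pg->allreduce(tensors); // checks run over gloo, then dispatch
// In practice this wrapper is typically enabled from Python by setting
// TORCH_DISTRIBUTED_DEBUG=DETAIL rather than constructed directly.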

const std::string ProcessGroupWrapper::getBackendName() const {
  return backend_->getBackendName();
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::broadcast(
    std::vector<at::Tensor>& data,
    const BroadcastOptions& opts) {
  runCollectiveChecks(OpType::BROADCAST, data);
  return backend_->broadcast(data, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::allreduce(
    std::vector<at::Tensor>& data,
    const AllreduceOptions& opts) {
  runCollectiveChecks(OpType::ALLREDUCE, data);
  return backend_->allreduce(data, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::allreduce_coalesced(
    std::vector<at::Tensor>& tensors,
    const AllreduceCoalescedOptions& opts) {
  // NOTE: We don't enforce shape checking for allreduce_coalesced because
  // the implementation itself does not enforce it, and we have tests that
  // use inconsistent shapes; see the Python implementation in
  // distributed_c10d for details.
  runCollectiveChecks(OpType::ALLREDUCE_COALESCED, {});
  return backend_->allreduce_coalesced(tensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::reduce(
    std::vector<at::Tensor>& tensors,
    const ReduceOptions& opts) {
  runCollectiveChecks(OpType::REDUCE, tensors);
  return backend_->reduce(tensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts) {
  if (check_same_size(outputTensors.back())) {
    runCollectiveChecks(OpType::ALLGATHER, inputTensors);
  } else {
    runCollectiveChecks(OpType::ALLGATHER, {});
  }
  return backend_->allgather(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::_allgather_base(
    at::Tensor& outputBuffer,
    at::Tensor& inputBuffer,
    const AllgatherOptions& opts) {
  std::vector<at::Tensor> inputTensors({inputBuffer});
  runCollectiveChecks(OpType::_ALLGATHER_BASE, inputTensors);
  return backend_->_allgather_base(outputBuffer, inputBuffer, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::allgather_coalesced(
    std::vector<std::vector<at::Tensor>>& outputTensorLists,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts) {
  // NOTE: We don't enforce shape checking for allgather_coalesced because
  // the implementation itself does not enforce it, and we have tests that
  // use inconsistent shapes; see the Python implementation in
  // distributed_c10d for details.
  runCollectiveChecks(OpType::ALLGATHER_COALESCED, {});
  return backend_->allgather_coalesced(outputTensorLists, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::gather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const GatherOptions& opts) {
  runCollectiveChecks(OpType::GATHER, inputTensors);
  return backend_->gather(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ScatterOptions& opts) {
  runCollectiveChecks(OpType::SCATTER, outputTensors);
  return backend_->scatter(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::reduce_scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts) {
  if (check_same_size(inputTensors.back())) {
    runCollectiveChecks(OpType::REDUCE_SCATTER, outputTensors);
  } else {
    runCollectiveChecks(OpType::REDUCE_SCATTER, {});
  }
  return backend_->reduce_scatter(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::alltoall_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& outputSplitSizes,
    std::vector<int64_t>& inputSplitSizes,
    const AllToAllOptions& opts) {
  // alltoall supports uneven splits, so don't enforce shape checking.
  runCollectiveChecks(OpType::ALLTOALL_BASE, {});
  return backend_->alltoall_base(
      outputTensor, inputTensor, outputSplitSizes, inputSplitSizes, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::alltoall(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllToAllOptions& opts) {
  // alltoall supports uneven splits, so don't enforce shape checking.
  runCollectiveChecks(OpType::ALLTOALL, {});
  return backend_->alltoall(outputTensors, inputTensors, opts);
}

void ProcessGroupWrapper::monitoredBarrier(
    const BarrierOptions& opts,
    bool waitAllRanks) {
  return backend_->monitoredBarrier(opts, waitAllRanks);
}

void ProcessGroupWrapper::setSequenceNumberForGroup() {
  // Set the underlying process group's sequence number if it is not set.
  if (backend_->getSequenceNumberForGroup() == 0) {
    backend_->setSequenceNumberForGroup();
  }
}

uint64_t ProcessGroupWrapper::getSequenceNumberForGroup() {
  return backend_->getSequenceNumberForGroup();
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::send(
    std::vector<at::Tensor>& tensors,
    int dstRank,
    int tag) {
  return backend_->send(tensors, dstRank, tag);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::recv(
    std::vector<at::Tensor>& tensors,
    int srcRank,
    int tag) {
  return backend_->recv(tensors, srcRank, tag);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::recvAnysource(
    std::vector<at::Tensor>& tensors,
    int tag) {
  return backend_->recvAnysource(tensors, tag);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::barrier(
    const BarrierOptions& opts) {
  runCollectiveChecks(OpType::BARRIER, {});
  return backend_->barrier(opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::_reduce_scatter_base(
    at::Tensor& outputBuffer,
    at::Tensor& inputBuffer,
    const ReduceScatterOptions& opts) {
  runCollectiveChecks(
      OpType::_REDUCE_SCATTER_BASE, {inputBuffer, outputBuffer});
  return backend_->_reduce_scatter_base(outputBuffer, inputBuffer, opts);
}

void ProcessGroupWrapper::startCoalescing() {
  return backend_->startCoalescing();
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::endCoalescing() {
  return backend_->endCoalescing();
}

c10::intrusive_ptr<Backend> ProcessGroupWrapper::getWrappedPg() const {
  return backend_;
}
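
// runCollectiveChecks performs two checks over the side-channel gloo backend
// before dispatching to the real backend: (1) a monitored barrier to surface
// ranks that fail to join the collective at all, and (2) an allgather-based
// comparison of each rank's CollectiveFingerPrint (see verify() above).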
void ProcessGroupWrapper::runCollectiveChecks(
    OpType op_type,
    const std::vector<at::Tensor>& tensors) {
  // First perform a monitored barrier to ensure all ranks can synchronize.
  c10d::BarrierOptions options;
  // TODO: we should use the wrapped backend_'s timeout here, but the C++
  // ProcessGroup API does not expose timeout.
  auto seq = getSequenceNumberForGroup();
  auto finger_print = CollectiveFingerPrint(op_type, tensors, seq);
  LOG(INFO) << "[Rank " << getRank() << "] "
            << "Running collective: " << finger_print;
  try {
    glooBackend_->monitoredBarrier(options, /* waitAllRanks */ true);
  } catch (const std::runtime_error& e) {
    // Attach collective info to the exception and re-raise.
    std::stringstream ss;
    ss << finger_print;
    auto collective_info = ss.str();
    auto err_msg = c10::str(
        "ProcessGroupWrapper: Monitored Barrier encountered error running collective: ",
        collective_info,
        ". Error: \n",
        e.what());
    TORCH_CHECK(false, err_msg);
  }
  // Will throw if an ill-formed collective is detected.
  finger_print.verify(glooBackend_);
}

} // namespace c10d

#endif // USE_C10D_GLOO