#include <torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp>

#ifdef USE_C10D_GLOO

#include <c10/core/Allocator.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/irange.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
#include <optional>
#include <stdexcept>
#include <utility>

namespace c10d {

namespace {
// A container for information about a particular collective, including its
// op type and input tensors (if applicable).
struct CollectiveFingerPrint {
  // Current collective's operation type.
  OpType op_type_;
  // Number of input tensors.
  std::size_t num_tensors_{};
  // Input tensor data types.
  std::vector<int8_t> tensor_dtypes_;
  // Input tensor device types.
  std::vector<int8_t> tensor_device_types_;
  // Input tensor sizes.
  std::vector<std::vector<int64_t>> tensor_sizes_;
  uint64_t sequence_number_;

  CollectiveFingerPrint(
      OpType op_type,
      const std::vector<at::Tensor>& input_tensors,
      uint64_t sequence_number)
      : op_type_(op_type),
        num_tensors_(input_tensors.size()),
        sequence_number_(sequence_number) {
    tensor_dtypes_.reserve(num_tensors_);
    tensor_device_types_.reserve(num_tensors_);
    tensor_sizes_.reserve(num_tensors_);
    for (const at::Tensor& t : input_tensors) {
      tensor_dtypes_.push_back(static_cast<int8_t>(t.dtype().toScalarType()));
      tensor_device_types_.push_back(static_cast<int8_t>(t.device().type()));
      tensor_sizes_.push_back(t.sizes().vec());
    }
  }

  // Constructor for data received from a deserialized fingerprint.
  CollectiveFingerPrint(
      OpType op_type,
      size_t num_tensors,
      std::vector<int8_t> tensor_dtypes,
      std::vector<int8_t> tensor_device_types,
      std::vector<std::vector<int64_t>> tensor_sizes,
      uint64_t sequence_number)
      : op_type_(op_type),
        num_tensors_(num_tensors),
        tensor_dtypes_(std::move(tensor_dtypes)),
        tensor_device_types_(std::move(tensor_device_types)),
        tensor_sizes_(std::move(tensor_sizes)),
        sequence_number_(sequence_number) {}

  // Logs collective information in case of a failure.
  friend std::ostream& operator<<(
      std::ostream& output,
      const CollectiveFingerPrint& collective_fingerprint);

  // Executes and verifies the collective fingerprint.
  void verify(c10::intrusive_ptr<Backend> backend) {
    at::Tensor serialized_tensor = serialize_fingerprint();
    std::vector<at::Tensor> inp{serialized_tensor};
    // First verify tensor shapes. This is needed because if e.g. the tensor
    // dim does not match across processes, directly verifying the tensors
    // would crash during the allgather, whereas we would rather report a
    // description of the inconsistency. Since the input is just a 1D tensor,
    // its shape is a single int k_i, and we need to make sure k_i is
    // consistent across the whole world.
    std::vector<at::Tensor> sp = c10d::getTensorShapes(inp);
    verify_tensors(sp, backend);
    // Now verify consistency of the actual tensor.
    verify_tensors(inp, backend);
  }

  // Takes a serialized fingerprint from
  // CollectiveFingerPrint::serialize_fingerprint and deserializes it back
  // into a CollectiveFingerPrint struct.
  CollectiveFingerPrint deserialize_fingerprint(
      const at::Tensor& serialized_tensor) {
    auto dtypes = std::vector<int8_t>();
    auto device_types = std::vector<int8_t>();
    auto sizes = std::vector<std::vector<int64_t>>();
    int index = 0;
    int64_t seq = 0;
    // 1. OpType
    auto optype = OpType(serialized_tensor[index].item<int>());
    index++;
    int num_tensors = 0;
    if (index < serialized_tensor.size(0)) {
      // Sequence number
      seq = serialized_tensor[index].item<int64_t>();
      index++;
      // 2. Num tensors
      num_tensors = serialized_tensor[index].item<int>();
      index++;
      dtypes.reserve(num_tensors);
      device_types.reserve(num_tensors);
      sizes.reserve(num_tensors);

      // 3. Tensor dtypes
      for (int i = 0; i < num_tensors; i++) {
        dtypes.push_back(serialized_tensor[index].item<int8_t>());
        index++;
      }
      // 4. Device types
      for (int i = 0; i < num_tensors; i++) {
        device_types.push_back(serialized_tensor[index].item<int8_t>());
        index++;
      }
      // 5. Tensor shapes
      for (int i = 0; i < num_tensors; i++) {
        // 5a. Shape size
        int size = serialized_tensor[index].item<int>();
        index++;
        // 5b. Shape
        auto shapeVec = std::vector<int64_t>();
        shapeVec.reserve(size);
        for (int j = 0; j < size; j++) {
          shapeVec.push_back(serialized_tensor[index].item<int64_t>());
          index++;
        }
        sizes.push_back(shapeVec);
      }
    }
    return CollectiveFingerPrint(
        optype, num_tensors, dtypes, device_types, sizes, seq);
  }
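
  // For illustration (a sketch, using symbolic enum values rather than their
  // numeric codes): an ALLREDUCE at sequence number 7 over a single CUDA
  // float tensor of shape [2, 3] serializes to the int64 sequence
  //   [OpType::ALLREDUCE, 7, 1, ScalarType::Float, DeviceType::CUDA, 2, 2, 3]
  // i.e. op type, sequence number, tensor count, per-tensor dtypes,
  // per-tensor device types, and then for each tensor its number of
  // dimensions followed by the dimensions themselves.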

 private:
  void verify_tensors(
      std::vector<at::Tensor>& tensors_to_verify,
      c10::intrusive_ptr<Backend>& backend) {
    // Create output tensor data structure to pass into allgather.
    std::vector<std::vector<at::Tensor>> output_tensors;
    // output tensors: [<tensor 0 outputs>, <tensor 1 outputs>, ..., <tensor n
    // outputs>]
    output_tensors.reserve(tensors_to_verify.size());
    for (const auto& tensor_shape : tensors_to_verify) {
      // Each rank has its own outputs shape, e.g.
      // <tensor 0 outputs>: [<rank 0 tensor>, <rank 1 tensor>, ..., <rank n
      // tensor>]
      std::vector<at::Tensor> outputs;
      outputs.reserve(backend->getSize());
      for (const auto i : c10::irange(backend->getSize())) {
        std::ignore = i; // Suppress unused variable warning.
        outputs.emplace_back(at::zeros_like(tensor_shape));
      }
      output_tensors.emplace_back(outputs);
    }
    // Allgather tensor shapes.
    backend->allgather(output_tensors, tensors_to_verify)->wait();
    // Verify equivalence.
    for (const auto i : c10::irange(output_tensors.size())) {
      const std::vector<at::Tensor> gathered_tensors = output_tensors[i];
      const at::Tensor reference_tensor = tensors_to_verify[i];
      for (const auto rank : c10::irange(gathered_tensors.size())) {
        const auto& rank_tensor = gathered_tensors[rank];
        if (!rank_tensor.equal(reference_tensor)) {
          CollectiveFingerPrint rank_fingerprint =
              deserialize_fingerprint(rank_tensor);
          std::stringstream ss;
          ss << "Detected mismatch between collectives on ranks. Rank "
             << backend->getRank() << " is running collective: " << *this
             << ", but Rank " << rank
             << " is running collective: " << rank_fingerprint << ".";
          auto diff_result = compute_collective_diff(rank_fingerprint);
          if (std::get<0>(diff_result)) {
            ss << std::get<1>(diff_result);
          }

          TORCH_CHECK(false, ss.str());
        }
      }
    }
  }
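
  // The error raised by verify_tensors reads roughly as follows (a sketch,
  // with illustrative values; exact formatting of the dtype/device lists
  // depends on how c10::str renders them):
  //   Detected mismatch between collectives on ranks. Rank 0 is running
  //   collective: CollectiveFingerPrint(SequenceNumber=7, OpType=ALLREDUCE,
  //   TensorShape=[2, 3], ...), but Rank 1 is running collective:
  //   CollectiveFingerPrint(SequenceNumber=7, OpType=BROADCAST, ...).
  //   Collectives differ in the following aspects: Op type: ALLREDUCE vs
  //   BROADCAST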

  static std::vector<std::string> get_size_strs(
      const CollectiveFingerPrint& collective_fingerprint) {
    std::vector<std::string> size_strs;
    if (!collective_fingerprint.tensor_sizes_.empty()) {
      for (const auto& single_tensor_shape_num :
           collective_fingerprint.tensor_sizes_[0]) {
        size_strs.emplace_back(std::to_string(single_tensor_shape_num));
      }
    }
    return size_strs;
  }

  static std::vector<std::string> get_dtype_strs(
      const CollectiveFingerPrint& collective_fingerprint) {
    std::vector<std::string> dtype_strs;
    dtype_strs.reserve(collective_fingerprint.tensor_dtypes_.size());
    for (const auto& tensor_dtype : collective_fingerprint.tensor_dtypes_) {
      dtype_strs.emplace_back(
          c10::toString(static_cast<at::ScalarType>(tensor_dtype)));
    }
    return dtype_strs;
  }

  static std::vector<std::string> get_device_type_strs(
      const CollectiveFingerPrint& collective_fingerprint) {
    std::vector<std::string> device_type_strs;
    device_type_strs.reserve(
        collective_fingerprint.tensor_device_types_.size());
    for (const auto& tensor_device_type :
         collective_fingerprint.tensor_device_types_) {
      device_type_strs.emplace_back(
          c10::toString(static_cast<at::DeviceType>(tensor_device_type)));
    }
    return device_type_strs;
  }

  std::pair<bool, std::string> compute_collective_diff(
      CollectiveFingerPrint& other) {
    // Computes the difference between two collectives (sequence number,
    // tensor shapes, collective type, etc.) to make it easier to understand
    // how mismatched collectives across ranks differ.
    bool found_diff = false;
    std::stringstream ss;
    ss << "Collectives differ in the following aspects: ";
    // Check sequence number.
    if (other.sequence_number_ != sequence_number_) {
      found_diff = true;
      ss << c10::str(
          "\t Sequence number: ",
          sequence_number_,
          " vs ",
          other.sequence_number_);
    }
    // Check op type.
    auto other_op = opTypeToString(other.op_type_);
    auto this_op = opTypeToString(op_type_);
    if (other_op != this_op) {
      found_diff = true;
      ss << c10::str(" Op type: ", this_op, " vs ", other_op);
    }

    auto check = [&ss, &found_diff](
                     const char* arg,
                     std::vector<std::string> other,
                     std::vector<std::string> curr) {
      if (other.size() != curr.size()) {
        found_diff = true;
        ss << c10::str(" Tensor ", arg, ": ", curr, " vs ", other);
        return;
      }
      for (size_t i = 0; i < other.size(); ++i) {
        if (other[i] != curr[i]) {
          found_diff = true;
          ss << c10::str(" Tensor ", arg, ": ", curr, " vs ", other);
          return;
        }
      }
    };

    // Check tensor sizes.
    auto other_sizes = get_size_strs(other);
    auto this_sizes = get_size_strs(*this);
    check("shapes", other_sizes, this_sizes);

    // Check tensor dtypes.
    auto other_dtypes = get_dtype_strs(other);
    auto this_dtypes = get_dtype_strs(*this);
    check("dtypes", other_dtypes, this_dtypes);

    // Check tensor devices.
    auto other_devices = get_device_type_strs(other);
    auto this_devices = get_device_type_strs(*this);
    check("devices", other_devices, this_devices);

    return std::make_pair(found_diff, ss.str());
  }

  // Serializes the information about the collective fingerprint (op type,
  // input shapes, data types, device types) into a tensor.
  at::Tensor serialize_fingerprint() {
    auto data = std::make_unique<std::vector<int64_t>>();
    // 1. OpType
    data->push_back(static_cast<int64_t>(op_type_));
    // Sequence number
    data->push_back(static_cast<int64_t>(sequence_number_));
    // 2. Num tensors
    data->push_back(static_cast<int64_t>(num_tensors_));
    // 3. Tensor dtypes
    for (const auto& type : tensor_dtypes_) {
      data->push_back(type);
    }
    // 4. Device types
    for (const auto& d : tensor_device_types_) {
      data->push_back(d);
    }
    // 5. Shapes
    for (const auto& sizes : tensor_sizes_) {
      data->push_back(static_cast<int64_t>(sizes.size()));
      for (const auto& s : sizes) {
        data->push_back(s);
      }
    }
    // Serialize data into a tensor.
    int64_t data_size = static_cast<int64_t>(data->size());
    // Need to release here and get the raw pointer due to C++ parameter
    // evaluation order.
    auto d = data.release();
    at::Tensor serialized_tensor =
        at::for_blob(d->data(), {data_size})
            .context(
                d,
                [](void* ctx) {
                  delete static_cast<std::vector<int64_t>*>(ctx);
                })
            .options(at::TensorOptions().dtype(at::kLong))
            .make_tensor();
    return serialized_tensor;
  }
};

std::ostream& operator<<(
    std::ostream& output,
    const CollectiveFingerPrint& collective_fingerprint) {
  std::string collectiveInfo;
  auto op_type_str = opTypeToString(collective_fingerprint.op_type_);
  if (collective_fingerprint.num_tensors_ != 0) {
    // Convert dtype and device type info to strings.
    std::vector<std::string> dtype_strs =
        CollectiveFingerPrint::get_dtype_strs(collective_fingerprint);
    std::vector<std::string> device_type_strs =
        CollectiveFingerPrint::get_device_type_strs(collective_fingerprint);
    std::vector<std::string> size_strs =
        CollectiveFingerPrint::get_size_strs(collective_fingerprint);

    collectiveInfo = c10::str(
        "CollectiveFingerPrint(",
        "SequenceNumber=",
        collective_fingerprint.sequence_number_,
        ", OpType=",
        op_type_str,
        ", TensorShape=[",
        c10::Join(", ", size_strs),
        "], TensorDtypes=",
        dtype_strs,
        ", TensorDeviceTypes=",
        device_type_strs,
        ")");
  } else {
    collectiveInfo = c10::str(
        "CollectiveFingerPrint(",
        "SequenceNumber=",
        collective_fingerprint.sequence_number_,
        ", OpType=",
        op_type_str,
        ")");
  }
  return output << collectiveInfo;
}

bool check_same_size(const std::vector<at::Tensor>& input_tensors) {
  for (const auto& input_tensor : input_tensors) {
    if (!input_tensors[0].is_same_size(input_tensor)) {
      return false;
    }
  }
  return true;
}
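
// check_same_size() is used below to decide whether shape checks are safe for
// allgather and reduce_scatter: when the per-rank tensors are not all the
// same size, the collective check runs with an empty tensor list, so only the
// op type and sequence number are compared across ranks.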

} // namespace

ProcessGroupWrapper::ProcessGroupWrapper(
    const c10::intrusive_ptr<Backend>& backend,
    c10::intrusive_ptr<Backend> glooBackend)
    : Backend(backend->getRank(), backend->getSize()),
      backend_(backend),
      glooBackend_(std::move(glooBackend)) {
  // Set the sequence number for the underlying process group.
  backend_->setSequenceNumberForGroup();
}

const std::string ProcessGroupWrapper::getBackendName() const {
  return backend_->getBackendName();
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::broadcast(
    std::vector<at::Tensor>& data,
    const BroadcastOptions& opts) {
  runCollectiveChecks(OpType::BROADCAST, data);
  return backend_->broadcast(data, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::allreduce(
    std::vector<at::Tensor>& data,
    const AllreduceOptions& opts) {
  runCollectiveChecks(OpType::ALLREDUCE, data);
  return backend_->allreduce(data, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::allreduce_coalesced(
    std::vector<at::Tensor>& tensors,
    const AllreduceCoalescedOptions& opts) {
  // NOTE: We don't enforce shape checking for allreduce_coalesced because
  // the implementation itself does not enforce it, and we have tests that
  // use inconsistent shapes; see the python implementation in
  // distributed_c10d for details.
  runCollectiveChecks(OpType::ALLREDUCE_COALESCED, {});
  return backend_->allreduce_coalesced(tensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::reduce(
    std::vector<at::Tensor>& tensors,
    const ReduceOptions& opts) {
  runCollectiveChecks(OpType::REDUCE, tensors);
  return backend_->reduce(tensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts) {
  if (check_same_size(outputTensors.back())) {
    runCollectiveChecks(OpType::ALLGATHER, inputTensors);
  } else {
    runCollectiveChecks(OpType::ALLGATHER, {});
  }
  return backend_->allgather(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::_allgather_base(
    at::Tensor& outputBuffer,
    at::Tensor& inputBuffer,
    const AllgatherOptions& opts) {
  std::vector<at::Tensor> inputTensors({inputBuffer});
  runCollectiveChecks(OpType::_ALLGATHER_BASE, inputTensors);
  return backend_->_allgather_base(outputBuffer, inputBuffer, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::allgather_coalesced(
    std::vector<std::vector<at::Tensor>>& outputTensorLists,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts) {
  // NOTE: We don't enforce shape checking for allgather_coalesced because
  // the implementation itself does not enforce it, and we have tests that
  // use inconsistent shapes; see the python implementation in
  // distributed_c10d for details.
  runCollectiveChecks(OpType::ALLGATHER_COALESCED, {});
  return backend_->allgather_coalesced(outputTensorLists, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::gather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const GatherOptions& opts) {
  runCollectiveChecks(OpType::GATHER, inputTensors);
  return backend_->gather(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ScatterOptions& opts) {
  runCollectiveChecks(OpType::SCATTER, outputTensors);
  return backend_->scatter(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::reduce_scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts) {
  if (check_same_size(inputTensors.back())) {
    runCollectiveChecks(OpType::REDUCE_SCATTER, outputTensors);
  } else {
    runCollectiveChecks(OpType::REDUCE_SCATTER, {});
  }
  return backend_->reduce_scatter(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::alltoall_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& outputSplitSizes,
    std::vector<int64_t>& inputSplitSizes,
    const AllToAllOptions& opts) {
  // alltoall supports uneven splits, so don't enforce shape checking.
  runCollectiveChecks(OpType::ALLTOALL_BASE, {});
  return backend_->alltoall_base(
      outputTensor, inputTensor, outputSplitSizes, inputSplitSizes, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::alltoall(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllToAllOptions& opts) {
  // alltoall supports uneven splits, so don't enforce shape checking.
  runCollectiveChecks(OpType::ALLTOALL, {});
  return backend_->alltoall(outputTensors, inputTensors, opts);
}

void ProcessGroupWrapper::monitoredBarrier(
    const BarrierOptions& opts,
    bool waitAllRanks) {
  return backend_->monitoredBarrier(opts, waitAllRanks);
}

void ProcessGroupWrapper::setSequenceNumberForGroup() {
  // Set the underlying pg's sequence number if it is not already set.
  if (backend_->getSequenceNumberForGroup() == 0) {
    // Set the sequence number for the underlying process group.
    backend_->setSequenceNumberForGroup();
  }
}

uint64_t ProcessGroupWrapper::getSequenceNumberForGroup() {
  return backend_->getSequenceNumberForGroup();
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::send(
    std::vector<at::Tensor>& tensors,
    int dstRank,
    int tag) {
  return backend_->send(tensors, dstRank, tag);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::recv(
    std::vector<at::Tensor>& tensors,
    int srcRank,
    int tag) {
  return backend_->recv(tensors, srcRank, tag);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::recvAnysource(
    std::vector<at::Tensor>& tensors,
    int tag) {
  return backend_->recvAnysource(tensors, tag);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::barrier(
    const BarrierOptions& opts) {
  runCollectiveChecks(OpType::BARRIER, {});
  return backend_->barrier(opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::_reduce_scatter_base(
    at::Tensor& outputBuffer,
    at::Tensor& inputBuffer,
    const ReduceScatterOptions& opts) {
  runCollectiveChecks(
      OpType::_REDUCE_SCATTER_BASE, {inputBuffer, outputBuffer});
  return backend_->_reduce_scatter_base(outputBuffer, inputBuffer, opts);
}

void ProcessGroupWrapper::startCoalescing() {
  return backend_->startCoalescing();
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::endCoalescing() {
  return backend_->endCoalescing();
}

c10::intrusive_ptr<Backend> ProcessGroupWrapper::getWrappedPg() const {
  return backend_;
}

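// Every checked collective above funnels through this routine before
// delegating to the wrapped backend: first a monitored barrier on the helper
// gloo backend, so a rank that never reaches the collective is reported
// instead of silently hanging, and then a fingerprint allgather so that op
// type, sequence number, dtypes, device types, and shapes are compared across
// ranks and any mismatch is raised as a descriptive error rather than a crash
// or hang inside the underlying collective.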
void ProcessGroupWrapper::runCollectiveChecks(
    OpType op_type,
    const std::vector<at::Tensor>& tensors) {
  // First perform a monitored barrier to ensure all ranks can synchronize.
  c10d::BarrierOptions options;
  // TODO: we should use the wrapped backend_'s timeout here, but the C++
  // ProcessGroup API does not expose timeout.
  auto seq = getSequenceNumberForGroup();
  auto finger_print = CollectiveFingerPrint(op_type, tensors, seq);
  LOG(INFO) << "[Rank " << getRank() << "] "
            << "Running collective: " << finger_print;
  try {
    glooBackend_->monitoredBarrier(options, /* waitAllRanks */ true);
  } catch (const std::runtime_error& e) {
    // Attach collective info to the exception and re-raise.
    std::stringstream ss;
    ss << finger_print;
    auto collective_info = ss.str();
    auto err_msg = c10::str(
        "ProcessGroupWrapper: Monitored Barrier encountered error running collective: ",
        collective_info,
        ". Error: \n",
        e.what());
    TORCH_CHECK(false, err_msg);
  }
  // Will throw if an ill-formed collective is detected.
  finger_print.verify(glooBackend_);
}

} // namespace c10d

#endif // USE_C10D_GLOO