#include <torch/csrc/python_headers.h>

#include <c10/util/intrusive_ptr.h>
#include <c10/util/string_view.h>
#include <torch/csrc/distributed/c10d/FileStore.hpp>
#include <torch/csrc/distributed/c10d/Functional.hpp>
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp>
#include <torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp>
#include <torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp>
#include <vector>
#ifndef _WIN32
#include <torch/csrc/distributed/c10d/HashStore.hpp>
#endif
#include <torch/csrc/distributed/c10d/FakeProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/PyProcessGroup.hpp>

#ifdef USE_C10D_GLOO
#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp>
#endif

#ifdef USE_C10D_NCCL
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
#endif

#ifdef USE_C10D_MPI
#include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>
#endif

#ifdef USE_C10D_UCC
#include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
#endif

#include <fmt/format.h>
#include <pybind11/chrono.h>
#include <torch/csrc/distributed/c10d/DMAConnectivity.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>

#include <torch/csrc/distributed/c10d/comm.hpp>
#include <torch/csrc/distributed/c10d/debug.h>
#include <torch/csrc/distributed/c10d/logger.hpp>
#include <torch/csrc/distributed/c10d/reducer.hpp>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/distributed/c10d/python_comm_hook.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>

#include <torch/custom_class.h>

namespace {

#ifdef USE_C10D_NCCL

bool acquire_gil() {
  // Basically, if this function can acquire the GIL, it will return quickly;
  // if not, it will hang forever. The idea is to call this from a thread
  // wrapped in a future, and then check the future after a timeout, to
  // determine whether we're facing GIL contention.
  if (Py_IsInitialized()) {
    pybind11::gil_scoped_acquire gil;
    return true;
  }

  // If we end up here, it's probably still a "pass" from the perspective of
  // checking whether Python is stuck, but currently we don't check the return
  // value of this function anyway; we just check whether it returned quickly
  // vs. timing out. Taking a long time is the main sign of trouble. A fast
  // return with either true or false is OK from the perspective of debugging
  // Python hangs.
  return false;
}
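
// A minimal sketch (illustrative only, not how c10d necessarily wires this
// up) of how a checker like this can be consumed: run it on a helper thread
// and treat a timeout as evidence of GIL contention.
//
//   auto fut = std::async(std::launch::async, [] { return acquire_gil(); });
//   if (fut.wait_for(std::chrono::seconds(5)) == std::future_status::timeout) {
//     // The checker thread is likely blocked waiting for the GIL.
//   }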

bool registerGilChecker() {
  c10d::get_gil_checker() = &acquire_gil;
  return true;
}

static bool registered = registerGilChecker();
#endif // USE_C10D_NCCL

// Wrapper to ensure GIL is released before destructing ProcessGroupGloo
// TODO: move this somewhere more generally useful
template <typename T>
class IntrusivePtrNoGilDestructor {
  c10::intrusive_ptr<T> impl_{};

 public:
  IntrusivePtrNoGilDestructor() = default;
  IntrusivePtrNoGilDestructor(const IntrusivePtrNoGilDestructor&) = default;
  IntrusivePtrNoGilDestructor(IntrusivePtrNoGilDestructor&&) = default;
  IntrusivePtrNoGilDestructor& operator=(const IntrusivePtrNoGilDestructor&) =
      default;
  IntrusivePtrNoGilDestructor& operator=(IntrusivePtrNoGilDestructor&&) =
      default;
  /* implicit */ IntrusivePtrNoGilDestructor(c10::intrusive_ptr<T> impl)
      : impl_(std::move(impl)) {}
  // This ctor is very important; see
  // https://github.com/pybind/pybind11/issues/2957
  explicit IntrusivePtrNoGilDestructor(T* impl)
      : impl_(c10::intrusive_ptr<T>::unsafe_steal_from_new(impl)) {}
  ~IntrusivePtrNoGilDestructor() {
    if (impl_) {
      if (PyGILState_Check()) {
        pybind11::gil_scoped_release release;
        impl_.reset();
      } else {
        impl_.reset();
      }
    }
  }
  T& operator*() const noexcept {
    return *impl_;
  }
  T* operator->() const noexcept {
    return impl_.get();
  }
  C10_NODISCARD T* get() const noexcept {
    return impl_.get();
  }
  void reset() noexcept {
    impl_.reset();
  }
  operator bool() const noexcept {
    return impl_;
  }
};

} // anonymous namespace

PYBIND11_DECLARE_HOLDER_TYPE(T, IntrusivePtrNoGilDestructor<T>, true);
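
// Note: the third argument (`true`) tells pybind11 that this holder type can
// always be safely constructed from a raw pointer, which matches the
// intrusive refcounting used by c10::intrusive_ptr.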

namespace torch::distributed::c10d {

namespace {

py::bytes toPyBytes(const std::vector<uint8_t>& data) {
  return py::bytes(reinterpret_cast<const char*>(data.data()), data.size());
}

std::vector<py::bytes> toPyBytes(
    const std::vector<std::vector<uint8_t>>& data) {
  std::vector<py::bytes> out;
  out.reserve(data.size());
  for (const std::vector<uint8_t>& data_ : data) {
    out.emplace_back(reinterpret_cast<const char*>(data_.data()), data_.size());
  }
  return out;
}

std::vector<uint8_t> toVec8(const std::string& data) {
  std::vector<uint8_t> out{data.begin(), data.end()};
  return out;
}

std::vector<std::vector<uint8_t>> toVec8(const std::vector<std::string>& data) {
  std::vector<std::vector<uint8_t>> out;
  out.reserve(data.size());
  for (auto& data_ : data) {
    out.emplace_back(toVec8(data_));
  }
  return out;
}

template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

constexpr auto kDeprecationWarning =
    "{} API is being deprecated, please ping "
    "https://github.com/pytorch/pytorch/issues/46291 "
    "if you see this warning";
template <typename T>
using intrusive_ptr_class_ = py::class_<T, c10::intrusive_ptr<T>>;

template <typename T>
using intrusive_ptr_no_gil_destructor_class_ =
    py::class_<T, IntrusivePtrNoGilDestructor<T>>;
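
// Usage sketch (hypothetical binding, for illustration only; the `backend`
// base class handle is assumed): classes whose destructor may block, e.g.
// joining background threads, can be bound through this alias so that
// teardown never runs while holding the GIL:
//
//   intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupGloo>(
//       module, "ProcessGroupGloo", backend);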

// PythonStore is a pybind11 trampoline class to allow a Python
// class to inherit from c10d.Store and implement its interface.
class PythonStore : public ::c10d::Store {
 public:
  using ::c10d::Store::Store;

  // Note: this function manually calls the Python-side overload
  // for this function instead of using the PYBIND11_OVERLOAD_XYZ
  // macros. This is done so that we can call the Python-side
  // function with a std::string instead of a std::vector<uint8_t>.
  void set(const std::string& key, const std::vector<uint8_t>& value) override {
    pybind11::gil_scoped_acquire gil;
    pybind11::function fn =
        pybind11::get_overload(static_cast<const ::c10d::Store*>(this), "set");
    TORCH_INTERNAL_ASSERT(fn, "Not implemented.");
    // Call function with a py::bytes object for the value.
    fn(key, toPyBytes(value));
  }

  // Note: this function manually calls the Python-side overload
  // for this function instead of using the PYBIND11_OVERLOAD_XYZ
  // macros. This is done so that the Python-side function can
  // return a py::bytes instead of a std::vector<uint8_t>.
  std::vector<uint8_t> get(const std::string& key) override {
    pybind11::gil_scoped_acquire gil;
    pybind11::function fn =
        pybind11::get_overload(static_cast<const ::c10d::Store*>(this), "get");
    TORCH_INTERNAL_ASSERT(fn, "Not implemented.");
    // Cast return value from Python to py::bytes, then implicitly
    // convert that to a std::string, so that we can construct a
    // std::vector<uint8_t>. There is no API for directly accessing
    // the contents of the py::bytes object.
    std::string str = pybind11::cast<py::bytes>(fn(key));
    return toVec8(str);
  }

  // Note: this function manually calls the Python-side overload
  // for this function instead of using the PYBIND11_OVERLOAD_XYZ
  // macros. This is done so that the Python-side function can
  // return a py::bytes instead of a std::vector<uint8_t>.
  std::vector<uint8_t> compareSet(
      const std::string& key,
      const std::vector<uint8_t>& expectedValue,
      const std::vector<uint8_t>& desiredValue) override {
    pybind11::gil_scoped_acquire gil;
    pybind11::function fn = pybind11::get_overload(
        static_cast<const ::c10d::Store*>(this), "compare_set");
    TORCH_INTERNAL_ASSERT(fn, "Not implemented.");
    // Cast return value from Python to py::bytes, then implicitly
    // convert that to a std::string, so that we can construct a
    // std::vector<uint8_t>. There is no API for directly accessing
    // the contents of the py::bytes object.
    std::string str = pybind11::cast<py::bytes>(
        fn(key, toPyBytes(expectedValue), toPyBytes(desiredValue)));
    return toVec8(str);
  }

  int64_t add(const std::string& key, int64_t value) override {
    PYBIND11_OVERLOAD_PURE(int64_t, ::c10d::Store, add, key, value);
  }

  int64_t getNumKeys() override {
    PYBIND11_OVERLOAD_PURE(int64_t, ::c10d::Store, getNumKeys);
  }

  bool deleteKey(const std::string& key) override {
    PYBIND11_OVERLOAD_PURE(bool, ::c10d::Store, deleteKey, key);
  }

  bool check(const std::vector<std::string>& keys) override {
    PYBIND11_OVERLOAD_PURE(bool, ::c10d::Store, check, keys);
  }

  void wait(const std::vector<std::string>& keys) override {
    PYBIND11_OVERLOAD_PURE(void, ::c10d::Store, wait, keys);
  }

  void wait(
      const std::vector<std::string>& keys,
      const std::chrono::milliseconds& timeout) override {
    PYBIND11_OVERLOAD_PURE(void, ::c10d::Store, wait, keys, timeout);
  }

  // Note: this function manually calls the Python-side overload
  // for this function instead of using the PYBIND11_OVERLOAD_XYZ
  // macros. This is done so that we can call the Python-side
  // function with a std::string instead of a std::vector<uint8_t>.
  void append(const std::string& key, const std::vector<uint8_t>& value)
      override {
    pybind11::gil_scoped_acquire gil;
    pybind11::function fn = pybind11::get_overload(
        static_cast<const ::c10d::Store*>(this), "append");
    if (!fn) {
      return Store::append(key, value);
    }
    // Call function with a py::bytes object for the value.
    fn(key, toPyBytes(value));
  }

  std::vector<std::vector<uint8_t>> multiGet(
      const std::vector<std::string>& keys) override {
    pybind11::gil_scoped_acquire gil;
    pybind11::function fn = pybind11::get_overload(
        static_cast<const ::c10d::Store*>(this), "multi_get");
    if (!fn) {
      return Store::multiGet(keys);
    }
    std::vector<std::string> py_list =
        pybind11::cast<std::vector<std::string>>(fn(keys));
    std::vector<std::vector<uint8_t>> res;
    res.reserve(py_list.size());

    for (auto& str : py_list) {
      res.emplace_back(str.begin(), str.end());
    }

    return res;
  }

  void multiSet(
      const std::vector<std::string>& keys,
      const std::vector<std::vector<uint8_t>>& values) override {
    pybind11::gil_scoped_acquire gil;
    pybind11::function fn = pybind11::get_overload(
        static_cast<const ::c10d::Store*>(this), "multi_set");
    if (!fn) {
      return Store::multiSet(keys, values);
    }

    fn(keys, toPyBytes(values));
  }

  bool hasExtendedApi() const override {
    PYBIND11_OVERLOAD_NAME(
        bool, ::c10d::Store, "has_extended_api", hasExtendedApi);
  }
};
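
// Rough Python-side picture of what this trampoline enables (illustrative
// sketch, not a complete subclass):
//
//   class MyStore(torch._C._distributed_c10d.Store):
//       def set(self, key, value): ...   # called by the C++ override above
//       def get(self, key): ...          # must return bytes
//
// Methods without a Python override (e.g. append, multi_get, multi_set) fall
// back to the C++ base implementations, as the `if (!fn)` branches above show.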

class PythonRequest : public ::c10d::control_plane::Request {
 public:
  const std::string& body() const override {
    PYBIND11_OVERRIDE_PURE(
        const std::string&, ::c10d::control_plane::Request, body);
  }

  const std::multimap<std::string, std::string>& params() const override {
    using MultiMap = const std::multimap<std::string, std::string>&;
    PYBIND11_OVERRIDE_PURE(MultiMap, ::c10d::control_plane::Request, params);
  }
};
class PythonResponse : public ::c10d::control_plane::Response {
 public:
  void setContent(std::string&& content, const std::string& content_type)
      override {
    PYBIND11_OVERRIDE_PURE_NAME(
        void,
        ::c10d::control_plane::Response,
        "set_content",
        setContent,
        content,
        content_type);
  }
  void setStatus(int status) override {
    PYBIND11_OVERRIDE_PURE_NAME(
        void, ::c10d::control_plane::Response, "set_status", setStatus, status);
  }
};

// Called from DDP's Python API to create a c10d Python comm hook object.
// The input state and callable comm_hook are Python objects. It later calls
// the register_comm_hook function of the input reducer to register the hook.
void _register_comm_hook(
    ::c10d::Reducer& reducer,
    py::object state,
    py::object comm_hook) {
  reducer.register_comm_hook(std::make_unique<::c10d::PythonCommHook>(
      std::move(state), std::move(comm_hook)));
}
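
// Illustrative Python-side entry point (the public wrapper is
// torch.nn.parallel.DistributedDataParallel.register_comm_hook):
//
//   def my_hook(state, bucket):  # returns a torch.futures.Future
//       ...
//   ddp_model.register_comm_hook(state=None, hook=my_hook)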

// Called from DDP's Python API to create a c10d C++ comm hook.
// The input is an enum hook type. It later calls the
// register_builtin_comm_hook function of the input reducer to set the hook
// type.
void _register_builtin_comm_hook(
    ::c10d::Reducer& reducer,
    ::c10d::BuiltinCommHookType comm_hook_type) {
  reducer.register_builtin_comm_hook(comm_hook_type);
}

// Customize the metaclass of ::c10d::ReduceOp for backward compatibility.
// https://github.com/pytorch/pytorch/pull/84243 changed ::c10d::ReduceOp from
// an enum to a struct, sacrificing some of the Python built-in function
// support such as `isinstance` (see
// https://github.com/pytorch/pytorch/issues/87191) and `copy` (see
// https://github.com/pytorch/pytorch/pull/87303#discussion_r1002879700).
// Below, we define a custom `isinstance` in CPython/pybind11
// (`reduceopmeta___instancecheck__`) and modify the default metaclass of
// pybind11 (`GetReduceOpMetaclass`) so that
// `isinstance(torch.distributed.ReduceOp.SUM, torch.distributed.ReduceOp)`
// returns :obj:`True` as if `ReduceOp` were an enum.
// Ref:
// - https://docs.python.org/3/extending/newtypes_tutorial.html
// - https://docs.python.org/3/c-api/typeobj.html?highlight=tp_methods
// - https://github.com/pybind/pybind11/issues/2696
static PyObject* reduceopmeta___instancecheck__(
    PyObject* self,
    PyObject* args) {
  if (Py_TYPE(self) == Py_TYPE(args)) {
    Py_RETURN_TRUE;
  }
  if (c10::string_view(args->ob_type->tp_name).find("RedOpType") !=
      c10::string_view::npos) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}
// NOLINTNEXTLINE(*c-arrays)
static PyMethodDef reduceopmeta_methods[] = {
    {"__instancecheck__",
     (PyCFunction)reduceopmeta___instancecheck__,
     METH_O,
     "Custom `__instancecheck__` for ReduceOp"},
    {nullptr, nullptr}};
PyTypeObject* GetReduceOpMetaclass() {
  static auto* metaclass = [] {
    PyTypeObject* base_metaclass =
        pybind11::detail::get_internals().default_metaclass;
    // NOLINTNEXTLINE(*c-arrays)
    PyType_Slot slots[] = {
        {Py_tp_base, base_metaclass},
        {Py_tp_methods, reduceopmeta_methods},
        {0},
    };
    PyType_Spec spec = {};
    spec.name = "torch._C._distributed_c10d._ReduceOpMeta";
    // NOLINTNEXTLINE(*-narrowing-conversions)
    spec.basicsize = base_metaclass->tp_basicsize;
    spec.flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    spec.slots = slots;
    PyTypeObject* metaclass = (PyTypeObject*)PyType_FromSpec(&spec);
    if (!metaclass)
      throw py::error_already_set();
    return metaclass;
  }();
  return metaclass;
}
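
// Net effect on the Python side, per the note above:
//
//   isinstance(torch.distributed.ReduceOp.SUM, torch.distributed.ReduceOp)
//   # -> True, as if ReduceOp were still an enum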

PyObject* c10d_init(PyObject* _unused, PyObject* noargs) {
  C10_LOG_API_USAGE_ONCE("c10d.python.import");

  auto c10d_module = THPObjectPtr(PyImport_ImportModule("torch.distributed"));
  if (!c10d_module) {
    throw python_error();
  }

  auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
  if (!torch_C_module) {
    throw python_error();
  }

  auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
  auto m =
      torch_C_m.def_submodule("_distributed_c10d", "distributed c10d bindings");

  auto module = py::handle(m).cast<py::module>();

  module
      .def(
          "_register_comm_hook",
          &_register_comm_hook,
          py::arg("reducer"),
          py::arg("state"),
          py::arg("comm_hook"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_register_builtin_comm_hook",
          &_register_builtin_comm_hook,
          py::arg("reducer"),
          py::arg("comm_hook_type"));

  shared_ptr_class_<::c10d::GradBucket>(
      module,
      "GradBucket",
      R"(
This class mainly passes a flattened gradient tensor
(returned by :meth:`~torch.distributed.GradBucket.buffer`)
to DDP communication hook.
This tensor can be further decomposed into a list of per-parameter tensors within this bucket
(returned by :meth:`~torch.distributed.GradBucket.get_per_parameter_tensors`)
to apply layer-wise operations.
)")
      .def(
          "index",
          &::c10d::GradBucket::getIndex,
          py::call_guard<py::gil_scoped_release>(),
          R"(
.. warning::
    Since the buckets are rebuilt after the first iteration, one should not
    rely on the indices at the beginning of training.

Returns:
    The index of a bucket that stores gradients of a few contiguous layers.
    All the gradients are bucketized.
)")
      .def(
          "buffer",
          &::c10d::GradBucket::getBuffer,
          py::call_guard<py::gil_scoped_release>(),
          R"(
Returns:
    A flattened 1D ``torch.Tensor`` buffer,
    which can be further decomposed into a list of per-parameter tensors within this bucket.
)")
      .def(
          "gradients",
          &::c10d::GradBucket::getGradients,
          py::call_guard<py::gil_scoped_release>(),
          R"(
Returns:
    A list of ``torch.Tensor``. Each tensor in the list corresponds to a gradient.
)")
      .def(
          "parameters",
          &::c10d::GradBucket::getParameters,
          py::call_guard<py::gil_scoped_release>(),
          R"(
Returns:
    A list of ``torch.Tensor``. Each tensor in the list corresponds to a model
    parameter.
)")
      .def(
          "is_last",
          &::c10d::GradBucket::isLast,
          py::call_guard<py::gil_scoped_release>(),
          R"(
Returns:
    Whether this bucket is the last bucket to allreduce in an iteration.
    This also means that this bucket corresponds to the first few layers in the forward pass.
)")
      .def(
          "set_buffer",
          &::c10d::GradBucket::setBuffer,
          py::arg("buffer"),
          py::call_guard<py::gil_scoped_release>(),
          R"(
Replaces the tensor in the bucket with the input tensor buffer.
)");

  py::enum_<::c10d::BuiltinCommHookType>(module, "BuiltinCommHookType", R"(
An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_COMPRESS``.)")
      .value("ALLREDUCE", ::c10d::BuiltinCommHookType::ALLREDUCE)
      .value("FP16_COMPRESS", ::c10d::BuiltinCommHookType::FP16_COMPRESS);

  shared_ptr_class_<::c10d::Reducer>(module, "Reducer")
      .def(
          py::init<
              std::vector<at::Tensor>,
              std::vector<std::vector<size_t>>,
              std::vector<size_t>,
              c10::intrusive_ptr<::c10d::ProcessGroup>,
              std::vector<bool>,
              int64_t,
              bool,
              bool,
              std::unordered_map<size_t, std::string>,
              int64_t>(),
          py::arg("params"),
          py::arg("bucket_indices"),
          py::arg("per_bucket_size_limits"),
          py::arg("process_group"),
          py::arg("expect_sparse_gradients") = std::vector<bool>(),
          py::arg("bucket_bytes_cap") = ::c10d::kDefaultBucketBytesCap,
          py::arg("find_unused_parameters") = false,
          py::arg("gradient_as_bucket_view") = false,
          py::arg("param_to_name_mapping") =
              std::unordered_map<size_t, std::string>(),
          py::arg("first_bucket_bytes_cap") = ::c10d::kDefaultFirstBucketBytes,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "prepare_for_forward",
          &::c10d::Reducer::prepare_for_forward,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "prepare_for_backward",
          &::c10d::Reducer::prepare_for_backward,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "prepare_for_backward",
          [](::c10d::Reducer& reducer, const at::Tensor& output) -> void {
            reducer.prepare_for_backward({output});
          },
          py::call_guard<py::gil_scoped_release>())
      .def("get_backward_stats", &::c10d::Reducer::get_backward_stats)
      .def(
          "_install_post_backward_futures",
          [](::c10d::Reducer& reducer,
             const std::vector<std::shared_ptr<jit::PythonFutureWrapper>>&
                 futs) {
            c10::List<c10::intrusive_ptr<c10::ivalue::Future>> futures(
                c10::FutureType::create(c10::TensorType::get()));
            for (const auto& fut : futs) {
              futures.push_back(fut->fut);
            }
            reducer.install_futures(futures);
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_rebuild_buckets",
          &::c10d::Reducer::rebuild_buckets,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_zeros_like_grad_buckets",
          [](::c10d::Reducer& reducer) {
            return reducer.get_grad_buckets(/* return_zero_tensors */ true);
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_optimizer_in_backward",
          [](::c10d::Reducer& reducer) { reducer.set_optimizer_in_backward(); },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_sparse_metadata",
          &::c10d::Reducer::setSparseMetadata,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_mixed_precision_param_dtype",
          [](::c10d::Reducer& reducer, py::object data_type_obj) {
            auto scalar_type =
                reinterpret_cast<THPDtype*>(data_type_obj.ptr())->scalar_type;
            reducer.set_mixed_precision_param_dtype(scalar_type);
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_push_all_rebuilt_params",
          &::c10d::Reducer::push_rebuilt_params_for_all_indices,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_forward_pass_work_handle",
          &::c10d::Reducer::set_forward_pass_work_handle,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_local_used_map", &::c10d::Reducer::get_local_used_map_on_device)
      .def(
          "_set_ddp_runtime_logging_sample_rate",
          &::c10d::Reducer::set_ddp_runtime_logging_sample_rate,
          py::arg("sample_rate"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_static_graph",
          &::c10d::Reducer::set_static_graph,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_ddp_graph_static",
          &::c10d::Reducer::ddp_graph_static,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_delay_all_reduce",
          &::c10d::Reducer::delay_all_reduce,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_run_comm_hook",
          [](::c10d::Reducer& reducer, ::c10d::GradBucket& bucket)
              -> std::shared_ptr<jit::PythonFutureWrapper> {
            c10::intrusive_ptr<c10::ivalue::Future> fut =
                reducer.run_comm_hook(bucket);
            return std::make_shared<jit::PythonFutureWrapper>(fut);
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_run_allreduce_hook",
          [](::c10d::Reducer& reducer, ::c10d::GradBucket& bucket)
              -> std::shared_ptr<jit::PythonFutureWrapper> {
            c10::intrusive_ptr<c10::ivalue::Future> fut =
                reducer.run_allreduce_hook(bucket);
            return std::make_shared<jit::PythonFutureWrapper>(fut);
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_autograd_hook",
          [](::c10d::Reducer& reducer, int index) -> void {
            reducer.autograd_hook(index);
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "set_logger",
          [](::c10d::Reducer& reducer,
             const std::shared_ptr<::c10d::Logger>& logger) {
            std::weak_ptr<::c10d::Logger> logger_weakref = logger;
            reducer.set_logger(logger_weakref);
          })
      .def(
          "_remove_autograd_hooks",
          [](::c10d::Reducer& reducer) { reducer.remove_autograd_hooks(); },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_check_reducer_finalized",
          [](::c10d::Reducer& reducer) { return reducer.check_finalized(); },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_reset_state",
          [](::c10d::Reducer& reducer) { return reducer.reset_state(); },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_update_process_group",
          [](::c10d::Reducer& reducer,
             c10::intrusive_ptr<::c10d::ProcessGroup> new_process_group) {
            return reducer.update_process_group(std::move(new_process_group));
          },
          py::call_guard<py::gil_scoped_release>());

  shared_ptr_class_<::c10d::Logger>(module, "Logger")
      .def(
          py::init<std::shared_ptr<::c10d::Reducer>>(),
          py::arg("reducer"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "set_construction_data_and_log",
          &::c10d::Logger::set_construction_data_and_log,
          py::arg("module_name"),
          py::arg("device_ids"),
          py::arg("output_device"),
          py::arg("broadcast_buffers"),
          py::arg("has_sync_bn"),
          py::arg("static_graph"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "set_runtime_stats_and_log",
          &::c10d::Logger::set_runtime_stats_and_log,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "set_error_and_log",
          [](::c10d::Logger& logger, const std::string& error) {
            logger.set_error_and_log(error);
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_ddp_logging_data",
          &::c10d::Logger::get_ddp_logging_data,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_comm_hook_name",
          &::c10d::Logger::set_comm_hook,
          py::arg("comm_hook"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_uneven_input_join",
          &::c10d::Logger::set_uneven_input_join,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_static_graph",
          &::c10d::Logger::set_static_graph,
          py::call_guard<py::gil_scoped_release>());

  py::enum_<::c10d::DebugLevel>(module, "DebugLevel", R"(
An enum whose values correspond to different debug levels of the
torch.distributed package. Currently supporting OFF, INFO, and DETAIL,
which can be set via the ``TORCH_DISTRIBUTED_DEBUG`` environment variable
or via the ``set_debug_level()`` function.
)")
      .value("OFF", ::c10d::DebugLevel::Off)
      .value("INFO", ::c10d::DebugLevel::Info)
      .value("DETAIL", ::c10d::DebugLevel::Detail);

  module
      .def(
          "get_debug_level",
          ::c10d::debug_level,
          R"(Gets the debug level of the torch.distributed package.)")
      .def(
          "set_debug_level",
          ::c10d::setDebugLevel,
          R"(Sets the debug level of the torch.distributed package.)")
      .def(
          "set_debug_level_from_env",
          ::c10d::setDebugLevelFromEnvironment,
          R"(Sets the debug level of the torch.distributed package from the
``TORCH_DISTRIBUTED_DEBUG`` environment variable.)");

  // TODO(crcrpar): Hardening `ReduceOp`.
  // While keeping most op types as enum values,
  // making `PREMUL_SUM` callable, i.e., allowing for
  // `ReduceOp.PREMUL_SUM(scale)` might be better as per @wanchaol.
  // https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types
  py::class_<::c10d::ReduceOp> reduce_op(
      module, "ReduceOp", py::metaclass((PyObject*)GetReduceOpMetaclass()), R"(
An enum-like class for available reduction operations: ``SUM``, ``PRODUCT``,
``MIN``, ``MAX``, ``BAND``, ``BOR``, ``BXOR``, and ``PREMUL_SUM``.

``BAND``, ``BOR``, and ``BXOR`` reductions are not available when
using the ``NCCL`` backend.

``AVG`` divides values by the world size before summing across ranks.
``AVG`` is only available with the ``NCCL`` backend,
and only for NCCL versions 2.10 or later.

``PREMUL_SUM`` multiplies inputs by a given scalar locally before reduction.
``PREMUL_SUM`` is only available with the ``NCCL`` backend,
and only for NCCL versions 2.11 or later. Users are supposed to
use ``torch.distributed._make_nccl_premul_sum``.

Additionally, ``MAX``, ``MIN`` and ``PRODUCT`` are not supported for complex tensors.

The values of this class can be accessed as attributes, e.g., ``ReduceOp.SUM``.
They are used in specifying strategies for reduction collectives, e.g.,
:func:`reduce`.

This class does not support the ``__members__`` property.)");

  reduce_op.def(py::init<::c10d::ReduceOp::RedOpType>())
      .def_readwrite("op", &::c10d::ReduceOp::op_);
  // The following are for backward compatibility.
  // Since c10d::ReduceOp had been an `enum class`, users could compare and
  // take the hash of a `::c10d::ReduceOp`. To avoid losing this
  // functionality, we define some member methods.
  reduce_op
      // todo(crcrpar): Support `RedOpType == ReduceOp`.
      .def(
          // This calls `operator==(const ReduceOp::RedOpType)`
          "__eq__",
          [](const ::c10d::ReduceOp& self,
             const ::c10d::ReduceOp::RedOpType& other) {
            return self == other;
          })
      .def(
          // This calls `operator==(const ReduceOp)` for the future support of
          // `PREMUL_SUM` comparison
          "__eq__",
          [](const ::c10d::ReduceOp& self, const ::c10d::ReduceOp& other) {
            return self == other;
          })
      .def(
          // With the above custom `__eq__`'s, we have to manually support the
          // other types.
          "__eq__",
          // NOLINTNEXTLINE(performance-unnecessary-value-param)
          [](const ::c10d::ReduceOp& self, py::object) { return false; })
      .def(
          "__hash__",
          [](const ::c10d::ReduceOp& self) {
            return static_cast<uint8_t>(self.op_);
          })
      .def(
          "__copy__",
          [](const ::c10d::ReduceOp& self) { return ::c10d::ReduceOp(self); })
      .def(
          "__deepcopy__",
          [](const ::c10d::ReduceOp& self, const py::dict& memo) {
            return ::c10d::ReduceOp(self);
          })
      .def(py::pickle(
          [](const ::c10d::ReduceOp& r) {
            // __getstate__
            if (r.op_ != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) {
              return py::make_tuple(r.op_, py::none());
            }
            TORCH_CHECK(r.supplement_.defined(), "Invalid PREMUL_SUM ReduceOp");
            const auto* preMulSupplement =
                reinterpret_cast<::c10d::NCCLPreMulSumSupplement*>(
                    r.supplement_.get());
            if (!preMulSupplement->tensor_factor.defined()) {
              return py::make_tuple(r.op_, preMulSupplement->double_factor);
            } else {
              return py::make_tuple(r.op_, preMulSupplement->tensor_factor);
            }
          },
          [](const py::tuple& t) {
            // __setstate__
            TORCH_CHECK(t.size() == 2, "Invalid state");
            const auto op =
                static_cast<::c10d::ReduceOp::RedOpType>(t[0].cast<uint8_t>());
            if (op != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) {
              return ::c10d::ReduceOp(op);
            }
            const auto preMulSupplement_factor = t[1];
            if (py::isinstance<py::float_>(preMulSupplement_factor)) {
              return ::c10d::makeNCCLPreMulSum(t[1].cast<double>());
            } else {
              return ::c10d::makeNCCLPreMulSum(t[1].cast<at::Tensor>());
            }
          }));
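
// Illustrative Python behavior enabled by the methods above (sketch):
//
//   op = torch.distributed._make_nccl_premul_sum(2.0)  # PREMUL_SUM, float factor
//   op2 = copy.deepcopy(op)                 # via __copy__/__deepcopy__
//   op3 = pickle.loads(pickle.dumps(op))    # via the py::pickle pair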

  py::enum_<::c10d::ReduceOp::RedOpType>(reduce_op, "RedOpType")
      .value("SUM", ::c10d::ReduceOp::RedOpType::SUM)
      .value("AVG", ::c10d::ReduceOp::RedOpType::AVG)
      .value("PRODUCT", ::c10d::ReduceOp::RedOpType::PRODUCT)
      .value("MIN", ::c10d::ReduceOp::RedOpType::MIN)
      .value("MAX", ::c10d::ReduceOp::RedOpType::MAX)
      .value("BAND", ::c10d::ReduceOp::RedOpType::BAND)
      .value("BOR", ::c10d::ReduceOp::RedOpType::BOR)
      .value("BXOR", ::c10d::ReduceOp::RedOpType::BXOR)
      .value("PREMUL_SUM", ::c10d::ReduceOp::RedOpType::PREMUL_SUM)
      .export_values();

  // note(crcrpar): This could be removed because users will not pass
  // `RedOpType` to reduce collective ops.
  // Ref: [Implicit
  // conversions](https://pybind11.readthedocs.io/en/stable/advanced/classes.html#implicit-conversions)
  // Let us skip the explicit construction of `c10d::ReduceOp` from
  // `c10d::ReduceOp::RedOpType` in Python.
  py::implicitly_convertible<::c10d::ReduceOp::RedOpType, ::c10d::ReduceOp>();

  module
      .def(
          "_make_nccl_premul_sum",
          &::c10d::makeNCCLPreMulSum<double>,
          py::arg("factor").noconvert(),
          py::return_value_policy::copy, // seems safest
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_make_nccl_premul_sum",
          &::c10d::makeNCCLPreMulSum<at::Tensor>,
          py::arg("factor").noconvert(),
          py::return_value_policy::copy, // seems safest
          py::call_guard<py::gil_scoped_release>());

  module.def(
      "_set_thread_isolation_mode",
      &::c10d::set_thread_isolation_mode,
      py::arg("enable"));

  // Bindings for GroupRegistry.hpp
  //
  // Register a process group in the native registry. Process groups registered
  // via `_register_process_group` can be resolved from both Python and C++.
  module.def(
      "_register_process_group",
      [](const std::string& group_name,
         c10::intrusive_ptr<::c10d::ProcessGroup> group) {
        ::c10d::register_process_group(group_name, std::move(group));
      },
      py::arg("group_name"),
      py::arg("group"));

  // Resolve a process group from the native registry
  module.def(
      "_resolve_process_group",
      [](const std::string& group_name) {
        return ::c10d::resolve_process_group(group_name);
      },
      py::arg("group_name"));
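
// Sketch of the registry round trip (illustrative; "my_group" is an arbitrary
// name chosen for the example):
//
//   torch._C._distributed_c10d._register_process_group("my_group", pg)
//   pg = torch._C._distributed_c10d._resolve_process_group("my_group")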

  module.def(
      "_register_work",
      [](const at::Tensor& tensor,
         const c10::intrusive_ptr<::c10d::Work>& work) {
        dynamic_cast<::c10d::PyProcessGroup::PyWork*>(work.get())
            ->ref_py_object();
        ::c10d::register_work(tensor, std::move(work));
      },
      py::arg("tensor"),
      py::arg("work"));

  // Remove a group from the native registry
  module.def(
      "_unregister_process_group",
      [](const std::string& group_name) {
        return ::c10d::unregister_process_group(group_name);
      },
      py::arg("group_name"));

  // Remove all process groups from the native registry
  module.def("_unregister_all_process_groups", []() {
    return ::c10d::unregister_all_process_groups();
  });

  py::class_<::c10d::BroadcastOptions>(module, "BroadcastOptions")
      .def(py::init<>())
      .def_readwrite("rootRank", &::c10d::BroadcastOptions::rootRank)
      .def_readwrite("rootTensor", &::c10d::BroadcastOptions::rootTensor)
      .def_readwrite("timeout", &::c10d::BroadcastOptions::timeout)
      .def_readwrite("asyncOp", &::c10d::BroadcastOptions::asyncOp);

  py::class_<::c10d::AllreduceOptions>(module, "AllreduceOptions")
      .def(py::init<>())
      .def_readwrite("reduceOp", &::c10d::AllreduceOptions::reduceOp)
      .def_readwrite("timeout", &::c10d::AllreduceOptions::timeout);

  py::class_<::c10d::AllreduceCoalescedOptions>(
      module, "AllreduceCoalescedOptions")
      .def(py::init<>())
      .def_readwrite("reduceOp", &::c10d::AllreduceCoalescedOptions::reduceOp)
      .def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout);

  py::class_<::c10d::ReduceOptions>(module, "ReduceOptions")
      .def(py::init<>())
      .def_readwrite("reduceOp", &::c10d::ReduceOptions::reduceOp)
      .def_readwrite("rootRank", &::c10d::ReduceOptions::rootRank)
      .def_readwrite("rootTensor", &::c10d::ReduceOptions::rootTensor)
      .def_readwrite("timeout", &::c10d::ReduceOptions::timeout);

  py::class_<::c10d::AllgatherOptions>(module, "AllgatherOptions")
      .def(py::init<>())
      .def_readwrite("timeout", &::c10d::AllgatherOptions::timeout)
      .def_readwrite("asyncOp", &::c10d::AllgatherOptions::asyncOp);

  py::class_<::c10d::GatherOptions>(module, "GatherOptions")
      .def(py::init<>())
      .def_readwrite("rootRank", &::c10d::GatherOptions::rootRank)
      .def_readwrite("timeout", &::c10d::GatherOptions::timeout);

  py::class_<::c10d::ScatterOptions>(module, "ScatterOptions")
      .def(py::init<>())
      .def_readwrite("rootRank", &::c10d::ScatterOptions::rootRank)
      .def_readwrite("timeout", &::c10d::ScatterOptions::timeout)
      .def_readwrite("asyncOp", &::c10d::ScatterOptions::asyncOp);

  py::class_<::c10d::ReduceScatterOptions>(module, "ReduceScatterOptions")
      .def(py::init<>())
      .def_readwrite("reduceOp", &::c10d::ReduceScatterOptions::reduceOp)
      .def_readwrite("timeout", &::c10d::ReduceScatterOptions::timeout)
      .def_readwrite("asyncOp", &::c10d::ReduceScatterOptions::asyncOp);

  py::class_<::c10d::BarrierOptions>(module, "BarrierOptions")
      .def(py::init<>())
      .def_readwrite("device_ids", &::c10d::BarrierOptions::device_ids)
      .def_readwrite("timeout", &::c10d::BarrierOptions::timeout)
      .def_readwrite("device", &::c10d::BarrierOptions::device);

  py::class_<::c10d::AllToAllOptions>(module, "AllToAllOptions")
      .def(py::init<>())
      .def_readwrite("timeout", &::c10d::AllToAllOptions::timeout);

  py::class_<::c10d::DistributedBackendOptions>(
      module, "_DistributedBackendOptions")
      .def(py::init<>())
      .def_readwrite("store", &::c10d::DistributedBackendOptions::store)
      .def_readwrite(
          "group_rank", &::c10d::DistributedBackendOptions::group_rank)
      .def_readwrite(
          "group_size", &::c10d::DistributedBackendOptions::group_size)
      .def_readwrite("timeout", &::c10d::DistributedBackendOptions::timeout)
      .def_readwrite("group_id", &::c10d::DistributedBackendOptions::group_id)
      .def_readwrite(
          "global_ranks_in_group",
          &::c10d::DistributedBackendOptions::global_ranks_in_group);

  py::class_<
      ::c10d::DMAConnectivity,
      c10::intrusive_ptr<::c10d::DMAConnectivity>>(module, "_DMAConnectivity")
      .def_readonly("device_type", &::c10d::DMAConnectivity::device_type)
      .def_readonly(
          "connection_type", &::c10d::DMAConnectivity::connection_type)
      .def_readonly("matrix", &::c10d::DMAConnectivity::matrix);

  module.def("_detect_dma_connectivity", ::c10d::detect_dma_connectivity);

  using SymmetricMemory = ::c10d::symmetric_memory::SymmetricMemory;
  py::class_<SymmetricMemory, c10::intrusive_ptr<SymmetricMemory>>(
      module, "_SymmetricMemory")
      .def_static("set_group_info", &::c10d::symmetric_memory::set_group_info)
      .def_static(
          "empty_strided_p2p",
          ::c10d::symmetric_memory::empty_strided_p2p,
          py::arg("size"),
          py::arg("stride"),
          py::arg("dtype"),
          py::arg("device"),
          py::arg("group_name"),
          py::arg("alloc_id") = py::none())
      .def_static("rendezvous", &::c10d::symmetric_memory::rendezvous)
      .def_static(
          "get_symmetric_memory",
          &::c10d::symmetric_memory::get_symmetric_memory)
      .def_static(
          "has_multicast_support",
          &::c10d::symmetric_memory::has_multicast_support)
      .def_property_readonly("rank", &SymmetricMemory::get_rank)
      .def_property_readonly("world_size", &SymmetricMemory::get_world_size)
      .def_property_readonly(
          "buffer_ptrs_dev",
          [](const c10::intrusive_ptr<SymmetricMemory>& symm_mem) {
            return reinterpret_cast<uintptr_t>(symm_mem->get_buffer_ptrs_dev());
          })
      .def_property_readonly(
          "signal_pad_ptrs_dev",
          [](const c10::intrusive_ptr<SymmetricMemory>& symm_mem) {
            return reinterpret_cast<uintptr_t>(
                symm_mem->get_signal_pad_ptrs_dev());
          })
      .def_property_readonly("buffer_size", &SymmetricMemory::get_buffer_size)
      .def_property_readonly(
          "signal_pad_size", &SymmetricMemory::get_signal_pad_size)
      .def(
          "get_buffer",
          &SymmetricMemory::get_buffer,
          py::arg("rank"),
          py::arg("sizes"),
          py::arg("dtype"),
          py::arg("storage_offset") = 0)
      .def("barrier", &SymmetricMemory::barrier, py::arg("channel") = 0)
      .def(
          "put_signal",
          &SymmetricMemory::put_signal,
          py::arg("dst_rank"),
          py::arg("channel") = 0)
      .def(
          "wait_signal",
          &SymmetricMemory::wait_signal,
          py::arg("src_rank"),
          py::arg("channel") = 0);

  auto store =
      py::class_<::c10d::Store, c10::intrusive_ptr<::c10d::Store>, PythonStore>(
          module,
          "Store",
          R"(
Base class for all store implementations, such as the 3 provided by PyTorch
distributed: :class:`~torch.distributed.TCPStore`, :class:`~torch.distributed.FileStore`,
and :class:`~torch.distributed.HashStore`.
)")
          // Default constructor.
          .def(py::init<>())
          // Convert from std::string to std::vector<uint8>.
          .def(
              "set",
              [](::c10d::Store& store,
                 const std::string& key,
                 const std::string& value) { store.set(key, toVec8(value)); },
              py::call_guard<py::gil_scoped_release>(),
              R"(
Inserts the key-value pair into the store based on the supplied ``key`` and
``value``. If ``key`` already exists in the store, it will overwrite the old
value with the new supplied ``value``.

Arguments:
    key (str): The key to be added to the store.
    value (str): The value associated with ``key`` to be added to the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.set("first_key", "first_value")
    >>> # Should return "first_value"
    >>> store.get("first_key")
)")
          .def(
              "compare_set",
              [](::c10d::Store& store,
                 const std::string& key,
                 const std::string& expected_value,
                 const std::string& desired_value) -> py::bytes {
                auto value = [&]() {
                  py::gil_scoped_release guard;
                  return store.compareSet(
                      key, toVec8(expected_value), toVec8(desired_value));
                }();
                return toPyBytes(value);
              },
              R"(
Inserts the key-value pair into the store based on the supplied ``key`` and
performs comparison between ``expected_value`` and ``desired_value`` before inserting. ``desired_value``
will only be set if ``expected_value`` for the ``key`` already exists in the store or if ``expected_value``
is an empty string.

Arguments:
    key (str): The key to be checked in the store.
    expected_value (str): The value associated with ``key`` to be checked before insertion.
    desired_value (str): The value associated with ``key`` to be added to the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.set("key", "first_value")
    >>> store.compare_set("key", "first_value", "second_value")
    >>> # Should return "second_value"
    >>> store.get("key")
)")
          // Convert from std::vector<uint8_t> to py::bytes.
          // The returned value is not guaranteed to be valid UTF-8.
          .def(
              "get",
              [](::c10d::Store& store, const std::string& key) -> py::bytes {
                auto value = [&]() {
                  py::gil_scoped_release guard;
                  return store.get(key);
                }();
                return toPyBytes(value);
              },
              R"(
Retrieves the value associated with the given ``key`` in the store. If ``key`` is not
present in the store, the function will wait for ``timeout``, which is defined
when initializing the store, before throwing an exception.

Arguments:
    key (str): The function will return the value associated with this key.

Returns:
    Value associated with ``key`` if ``key`` is in the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.set("first_key", "first_value")
    >>> # Should return "first_value"
    >>> store.get("first_key")
)")
          .def(
              "add",
              &::c10d::Store::add,
              py::call_guard<py::gil_scoped_release>(),
              R"(
The first call to add for a given ``key`` creates a counter associated
with ``key`` in the store, initialized to ``amount``. Subsequent calls to add
with the same ``key`` increment the counter by the specified ``amount``.
Calling :meth:`~torch.distributed.store.add` with a key that has already
been set in the store by :meth:`~torch.distributed.store.set` will result
in an exception.

Arguments:
    key (str): The key in the store whose counter will be incremented.
    amount (int): The quantity by which the counter will be incremented.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, other store types can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.add("first_key", 1)
    >>> store.add("first_key", 6)
    >>> # Should return 7
    >>> store.get("first_key")
)")
          .def(
              "check",
              &::c10d::Store::check,
              py::call_guard<py::gil_scoped_release>(),
              R"(
Checks whether each of the given ``keys`` has a value stored in the store. This
call returns immediately in normal cases, but can still deadlock in some edge
cases, e.g., calling ``check`` after the TCPStore has been destroyed. Call
:meth:`~torch.distributed.store.check` with the list of keys whose presence in
the store you want to query.

Arguments:
    keys (list[str]): The keys to query for presence in the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, other store types can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.add("first_key", 1)
    >>> # Should return True
    >>> store.check(["first_key"])
)")
          .def(
              "delete_key",
              &::c10d::Store::deleteKey,
              py::call_guard<py::gil_scoped_release>(),
              R"(
Deletes the key-value pair associated with ``key`` from the store. Returns
`True` if the key was successfully deleted, and `False` if it was not.

.. warning::
    The ``delete_key`` API is only supported by the :class:`~torch.distributed.TCPStore` and :class:`~torch.distributed.HashStore`. Using this API
    with the :class:`~torch.distributed.FileStore` will result in an exception.

Arguments:
    key (str): The key to be deleted from the store

Returns:
    `True` if ``key`` was deleted, otherwise `False`.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, HashStore can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.set("first_key", "first_value")
    >>> # This should return True
    >>> store.delete_key("first_key")
    >>> # This should return False
    >>> store.delete_key("bad_key")
)")
          .def(
              "num_keys",
              &::c10d::Store::getNumKeys,
              py::call_guard<py::gil_scoped_release>(),
              R"(
Returns the number of keys set in the store. Note that this number will typically
be one greater than the number of keys added by :meth:`~torch.distributed.store.set`
and :meth:`~torch.distributed.store.add` since one key is used to coordinate all
the workers using the store.

.. warning::
    When used with the :class:`~torch.distributed.FileStore`, ``num_keys`` returns the number of keys written to the underlying file. If the store is destructed and another store is created with the same file, the original keys will be retained.

Returns:
    The number of keys present in the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, other store types can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.set("first_key", "first_value")
    >>> # This should return 2
    >>> store.num_keys()
)")
          .def(
              "set_timeout",
              &::c10d::Store::setTimeout,
              py::call_guard<py::gil_scoped_release>(),
              R"(
Sets the store's default timeout. This timeout is used during initialization and in
:meth:`~torch.distributed.store.wait` and :meth:`~torch.distributed.store.get`.

Arguments:
    timeout (timedelta): timeout to be set in the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, other store types can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.set_timeout(timedelta(seconds=10))
    >>> # This will throw an exception after 10 seconds
    >>> store.wait(["bad_key"])
)")
          .def(
              "wait",
              [](::c10d::Store& store, const std::vector<std::string>& keys) {
                store.wait(keys);
              },
              py::call_guard<py::gil_scoped_release>(),
              R"(
Waits for each key in ``keys`` to be added to the store. If not all keys are
set before the ``timeout`` (set during store initialization), then ``wait``
will throw an exception.

Arguments:
    keys (list): List of keys on which to wait until they are set in the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, other store types can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> # This will throw an exception after 30 seconds
    >>> store.wait(["bad_key"])
)")
          .def(
              "wait",
              [](::c10d::Store& store,
                 const std::vector<std::string>& keys,
                 const std::chrono::milliseconds& timeout) {
                store.wait(keys, timeout);
              },
              py::call_guard<py::gil_scoped_release>(),
              R"(
Waits for each key in ``keys`` to be added to the store, and throws an exception
if the keys have not been set by the supplied ``timeout``.

Arguments:
    keys (list): List of keys on which to wait until they are set in the store.
    timeout (timedelta): Time to wait for the keys to be added before throwing an exception.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, other store types can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> # This will throw an exception after 10 seconds
    >>> store.wait(["bad_key"], timedelta(seconds=10))
)")
          .def_property_readonly(
              "timeout",
              &::c10d::Store::getTimeout,
              R"(Gets the timeout of the store.)")
          .def(
              "append",
              [](::c10d::Store& store,
                 const std::string& key,
                 const std::string& value) {
                store.append(key, toVec8(value));
              },
              py::call_guard<py::gil_scoped_release>(),
              R"(
Appends the key-value pair into the store based on the supplied ``key`` and
``value``. If ``key`` does not exist in the store, it will be created.

Arguments:
    key (str): The key to be appended to the store.
    value (str): The value associated with ``key`` to be added to the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.append("first_key", "po")
    >>> store.append("first_key", "tato")
    >>> # Should return "potato"
    >>> store.get("first_key")
)")
          .def(
              "multi_get",
              [](::c10d::Store& store, const std::vector<std::string>& keys) {
                auto values = [&]() {
                  py::gil_scoped_release guard;
                  return store.multiGet(keys);
                }();
                return toPyBytes(values);
              },
              R"(
Retrieves all values in ``keys``. If any key in ``keys`` is not
present in the store, the function will wait for ``timeout``.

Arguments:
    keys (List[str]): The keys to be retrieved from the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.set("first_key", "po")
    >>> store.set("second_key", "tato")
    >>> # Should return [b"po", b"tato"]
    >>> store.multi_get(["first_key", "second_key"])
)")
          .def(
              "multi_set",
              [](::c10d::Store& store,
                 const std::vector<std::string>& keys,
                 const std::vector<std::string>& values) {
                store.multiSet(keys, toVec8(values));
              },
              py::call_guard<py::gil_scoped_release>(),
              R"(
Inserts a list of key-value pairs into the store based on the supplied ``keys`` and ``values``.

Arguments:
    keys (List[str]): The keys to insert.
    values (List[str]): The values to insert.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.multi_set(["first_key", "second_key"], ["po", "tato"])
    >>> # Should return b"po"
    >>> store.get("first_key")
)")
          .def(
              "has_extended_api",
              &::c10d::Store::hasExtendedApi,
              R"(Returns true if the store supports extended operations.)");

  intrusive_ptr_class_<::c10d::FileStore>(
      module,
      "FileStore",
      store,
      R"(
A store implementation that uses a file to store the underlying key-value pairs.

Arguments:
    file_name (str): path of the file in which to store the key-value pairs
    world_size (int, optional): The total number of processes using the store. Default is -1 (a negative value indicates a non-fixed number of store users).

Example::
    >>> import torch.distributed as dist
    >>> store1 = dist.FileStore("/tmp/filestore", 2)
    >>> store2 = dist.FileStore("/tmp/filestore", 2)
    >>> # Use any of the store methods from either the client or server after initialization
    >>> store1.set("first_key", "first_value")
    >>> store2.get("first_key")

)")
      .def(
          py::init<const std::string&, int>(),
          py::arg("file_name"),
          py::arg("world_size") = -1)
      .def_property_readonly(
          "path",
          &::c10d::FileStore::getPath,
          R"(Gets the path of the file used by FileStore to store key-value pairs.)");

#ifndef _WIN32
  intrusive_ptr_class_<::c10d::HashStore>(
      module,
      "HashStore",
      store,
      R"(
A thread-safe store implementation based on an underlying hashmap. This store can be used
within the same process (for example, by other threads), but cannot be used across processes.

Example::
    >>> import torch.distributed as dist
    >>> store = dist.HashStore()
    >>> # store can be used from other threads
    >>> # Use any of the store methods after initialization
    >>> store.set("first_key", "first_value")
)")
      .def(py::init<>());
#endif

  intrusive_ptr_class_<::c10d::TCPStore>(
      module,
      "TCPStore",
      store,
      R"(
A TCP-based distributed key-value store implementation. The server store holds
the data, while the client stores can connect to the server store over TCP and
perform actions such as :meth:`~torch.distributed.store.set` to insert a key-value
pair, :meth:`~torch.distributed.store.get` to retrieve a key-value pair, etc. There
should always be one server store initialized because the client store(s) will wait for
the server to establish a connection.

Arguments:
    host_name (str): The hostname or IP Address the server store should run on.
    port (int): The port on which the server store should listen for incoming requests.
    world_size (int, optional): The total number of store users (number of clients + 1 for the server). Default is None (None indicates a non-fixed number of store users).
    is_master (bool, optional): True when initializing the server store and False for client stores. Default is False.
    timeout (timedelta, optional): Timeout used by the store during initialization and for methods such as :meth:`~torch.distributed.store.get` and :meth:`~torch.distributed.store.wait`. Default is timedelta(seconds=300).
    wait_for_workers (bool, optional): Whether to wait for all the workers to connect with the server store. This is only applicable when world_size is a fixed value. Default is True.
    multi_tenant (bool, optional): If True, all ``TCPStore`` instances in the current process with the same host/port will use the same underlying ``TCPServer``. Default is False.
    master_listen_fd (int, optional): If specified, the underlying ``TCPServer`` will listen on this file descriptor, which must be a socket already bound to ``port``. Useful to avoid port assignment races in some scenarios. Default is None (meaning the server creates a new socket and attempts to bind it to ``port``).
    use_libuv (bool, optional): If True, use libuv for the ``TCPServer`` backend. Default is True.
Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Run on process 1 (server)
    >>> server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30))
    >>> # Run on process 2 (client)
    >>> client_store = dist.TCPStore("127.0.0.1", 1234, 2, False)
    >>> # Use any of the store methods from either the client or server after initialization
    >>> server_store.set("first_key", "first_value")
    >>> client_store.get("first_key")
)")
      .def(
          py::init([](const std::string& host,
                      uint16_t port,
                      std::optional<int> worldSize,
                      bool isServer,
                      std::chrono::milliseconds timeout,
                      bool waitWorkers,
                      bool multiTenant,
                      std::optional<int> masterListenFd,
                      bool useLibUV) {
            std::optional<std::size_t> numWorkers = std::nullopt;
            if (worldSize.has_value() && worldSize.value() > -1) {
              numWorkers = static_cast<std::size_t>(worldSize.value());
            }

            ::c10d::TCPStoreOptions opts{
                port,
                isServer,
                numWorkers,
                waitWorkers,
                timeout,
                multiTenant,
                masterListenFd,
                useLibUV};

            return c10::make_intrusive<::c10d::TCPStore>(host, opts);
          }),
          py::arg("host_name"),
          py::arg("port"),
          py::arg("world_size") = py::none(),
          // using noconvert() requires this argument to be True or False,
          // preventing accidental implicit conversion to bool
1547 py::arg("is_master").noconvert() = false,
1548 py::arg("timeout") =
1549 std::chrono::milliseconds(::c10d::Store::kDefaultTimeout),
1550 py::arg("wait_for_workers") = true,
1551 py::arg("multi_tenant") = false,
1552 py::arg("master_listen_fd") = py::none(),
1553 py::arg("use_libuv") = true,
1554 py::call_guard<py::gil_scoped_release>())
1555 .def_property_readonly(
1556 "host",
1557 &::c10d::TCPStore::getHost,
1558 R"(Gets the hostname on which the store listens for requests.)")
1559 .def_property_readonly(
1560 "port",
1561 &::c10d::TCPStore::getPort,
1562 R"(Gets the port number on which the store listens for requests.)")
1563 .def_property_readonly(
1564 "libuvBackend",
1565 &::c10d::TCPStore::isLibUvBackend,
1566 R"(Returns True if it's using the libuv backend.)")
1567 .def(
1568 "__repr__",
1569 &::c10d::TCPStore::repr,
1570 R"(Returns a string representation of the TCPStore.)",
1571 py::call_guard<py::gil_scoped_release>());
1572
1573 intrusive_ptr_class_<::c10d::PrefixStore>(
1574 module,
1575 "PrefixStore",
1576 store,
1577 R"(
1578 A wrapper around any of the 3 key-value stores (:class:`~torch.distributed.TCPStore`,
1579 :class:`~torch.distributed.FileStore`, and :class:`~torch.distributed.HashStore`)
1580 that adds a prefix to each key inserted to the store.
1581
1582 Arguments:
1583 prefix (str): The prefix string that is prepended to each key before being inserted into the store.
1584 store (torch.distributed.store): A store object that forms the underlying key-value store.
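
Example::
    >>> # A minimal sketch: keys written through the PrefixStore land in
    >>> # the underlying store under the prefix (joined with ``/``).
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> underlying = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store = dist.PrefixStore("worker0", underlying)
    >>> store.set("first_key", "first_value")
    >>> # Should return b"first_value"
    >>> underlying.get("worker0/first_key")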
1585 )")
1586 .def(py::init<const std::string&, c10::intrusive_ptr<::c10d::Store>>())
1587 .def_property_readonly(
1588 "underlying_store",
1589 &::c10d::PrefixStore::getUnderlyingStore,
1590 R"(Gets the underlying store object that PrefixStore wraps around.)")
1591 .def_property_readonly(
1592 "_underlying_non_prefix_store",
1593 &::c10d::PrefixStore::getUnderlyingNonPrefixStore,
R"(Recursively unwraps all layers of PrefixStore wrapping and returns the innermost store.)");
1595
1596 using namespace std::chrono_literals;
1597
1598 auto collectives =
1599 py::class_<
1600 ::c10d::ControlCollectives,
1601 c10::intrusive_ptr<::c10d::ControlCollectives>>(
1602 module,
1603 "_ControlCollectives",
1604 R"(
1605 Base class for all ControlCollectives implementations.
1606 )")
1607 .def(
1608 "barrier",
1609 &::c10d::ControlCollectives::barrier,
1610 py::arg("key"),
1611 py::arg("timeout") = 5min,
1612 py::arg("block") = true,
1613 py::call_guard<py::gil_scoped_release>(),
1614 R"(
1615 Blocks until all workers have entered this function.
1616
1617 Arguments:
1618 key (str): The unique key used to identify this operation.
1619 timeout (duration): The timeout for this operation.
    block (bool): whether to block this worker while waiting on the result of the barrier.
1621 )")
1622 .def(
1623 "all_sum",
1624 &::c10d::ControlCollectives::allSum,
1625 py::arg("key"),
1626 py::arg("data"),
1627 py::arg("timeout") = 5min,
1628 py::call_guard<py::gil_scoped_release>(),
1629 R"(
1630 Computes a sum across all workers and returns the final value.
1631
1632 Arguments:
1633 key (str): The unique key used to identify this operation.
1634 data (int): The data to sum.
1635 timeout (duration): The timeout for this operation.
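
Example::
    >>> # A minimal sketch, assuming ``collectives`` is a ControlCollectives
    >>> # instance shared by all workers (see ``_StoreCollectives`` below).
    >>> # Each worker contributes 1, so the result equals the worker count.
    >>> total = collectives.all_sum("ready_count", 1)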
1636 )")
1637 .def(
1638 "broadcast_send",
1639 [](::c10d::ControlCollectives& collectives,
1640 const std::string& key,
1641 const std::string& data,
1642 std::chrono::milliseconds timeout = 5min) {
1643 collectives.broadcastSend(key, toVec8(data), timeout);
1644 },
1645 py::arg("key"),
1646 py::arg("data"),
1647 py::arg("timeout") = 5min,
1648 py::call_guard<py::gil_scoped_release>(),
1649 R"(
Sends data to all other workers. Must only be called from one worker.
1651
1652 Arguments:
1653 key (str): The unique key used to identify this operation.
1654 data (str): The data to send.
1655 timeout (duration): The timeout for this operation.
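
Example::
    >>> # A minimal sketch, assuming ``collectives`` is shared by all workers
    >>> # and ``rank`` is this worker's rank: rank 0 sends, the rest receive.
    >>> if rank == 0:
    >>>     collectives.broadcast_send("config", "payload")
    >>> else:
    >>>     payload = collectives.broadcast_recv("config")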
1656 )")
1657 .def(
1658 "broadcast_recv",
1659 [](::c10d::ControlCollectives& collectives,
1660 const std::string& key,
1661 std::chrono::milliseconds timeout = 5min) {
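            // Release the GIL only around the blocking receive; the
            // toPyBytes conversion below must run with the GIL held
            // since it creates Python objects.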
1662 auto out = [&]() {
1663 py::gil_scoped_release guard;
1664 return collectives.broadcastRecv(key, timeout);
1665 }();
1666 return toPyBytes(out);
1667 },
1668 py::arg("key"),
1669 py::arg("timeout") = 5min,
1670 R"(
Receives data broadcast from one worker.
1672
1673 Arguments:
1674 key (str): The unique key used to identify this operation.
1675 timeout (duration): The timeout for this operation.
1676 )")
1677 .def(
1678 "gather_send",
1679 [](::c10d::ControlCollectives& collectives,
1680 const std::string& key,
1681 const std::string& data,
1682 std::chrono::milliseconds timeout = 5min) {
1683 collectives.gatherSend(key, toVec8(data), timeout);
1684 },
1685 py::arg("key"),
1686 py::arg("data"),
1687 py::arg("timeout") = 5min,
1688 py::call_guard<py::gil_scoped_release>(),
1689 R"(
1690 Sends data to one other worker.
1691
1692 Arguments:
1693 key (str): The unique key used to identify this operation.
1694 data (str): The data to send.
1695 timeout (duration): The timeout for this operation.
1696 )")
1697 .def(
1698 "gather_recv",
1699 [](::c10d::ControlCollectives& collectives,
1700 const std::string& key,
1701 const std::string& data,
1702 std::chrono::milliseconds timeout = 5min) {
1703 auto out = [&]() {
1704 py::gil_scoped_release guard;
1705 return collectives.gatherRecv(key, toVec8(data), timeout);
1706 }();
1707 return toPyBytes(out);
1708 },
1709 py::arg("key"),
1710 py::arg("data"),
1711 py::arg("timeout") = 5min,
1712 R"(
Receives data gathered from all other workers. Must only be called by one worker.

Arguments:
    key (str): The unique key used to identify this operation.
    data (str): This worker's own contribution to the gather.
    timeout (duration): The timeout for this operation.
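
Example::
    >>> # A minimal sketch, assuming ``collectives`` is shared by all workers
    >>> # and ``my_data`` is this worker's payload string: every rank sends,
    >>> # only rank 0 collects the full list.
    >>> if rank == 0:
    >>>     gathered = collectives.gather_recv("stats", my_data)
    >>> else:
    >>>     collectives.gather_send("stats", my_data)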
1718 )")
1719
1720 .def(
1721 "scatter_send",
1722 [](::c10d::ControlCollectives& collectives,
1723 const std::string& key,
1724 const std::vector<std::string>& data,
1725 std::chrono::milliseconds timeout = 5min) {
1726 auto out = [&]() {
1727 py::gil_scoped_release guard;
1728 return collectives.scatterSend(key, toVec8(data), timeout);
1729 }();
1730 return toPyBytes(out);
1731 },
1732 py::arg("key"),
1733 py::arg("data"),
1734 py::arg("timeout") = 5min,
1735 R"(
Sends rank-specific data to all other workers.

Arguments:
    key (str): The unique key used to identify this operation.
    data (List[str]): The data to send, one entry per rank.
    timeout (duration): The timeout for this operation.
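
Example::
    >>> # A minimal sketch, assuming ``collectives`` is shared by all workers:
    >>> # rank 0 provides one entry per rank; other ranks receive their own.
    >>> if rank == 0:
    >>>     collectives.scatter_send("shards", [f"shard{r}" for r in range(world_size)])
    >>> else:
    >>>     shard = collectives.scatter_recv("shards")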
1742 )")
1743 .def(
1744 "scatter_recv",
1745 [](::c10d::ControlCollectives& collectives,
1746 const std::string& key,
1747 std::chrono::milliseconds timeout = 5min) {
1748 auto out = [&]() {
1749 py::gil_scoped_release guard;
1750 return collectives.scatterRecv(key, timeout);
1751 }();
1752 return toPyBytes(out);
1753 },
1754 py::arg("key"),
1755 py::arg("timeout") = 5min,
1756 R"(
Receives rank-specific data from one worker.
1758
1759 Arguments:
1760 key (str): The unique key used to identify this operation.
1761 timeout (duration): The timeout for this operation.
1762 )")
1763
1764 .def(
1765 "all_gather",
1766 [](::c10d::ControlCollectives& collectives,
1767 const std::string& key,
1768 const std::string& data,
1769 std::chrono::milliseconds timeout = 5min) {
1770 auto out = [&]() {
1771 py::gil_scoped_release guard;
1772 return collectives.allGather(key, toVec8(data), timeout);
1773 }();
1774 return toPyBytes(out);
1775 },
1776 py::arg("key"),
1777 py::arg("data"),
1778 py::arg("timeout") = 5min,
1779 R"(
1780 Sends data to all workers and receives data from all other workers.
1781
1782 Arguments:
1783 key (str): The unique key used to identify this operation.
1784 data (str): The data to send.
1785 timeout (duration): The timeout for this operation.
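
Example::
    >>> # A minimal sketch, assuming ``collectives`` is shared by all workers:
    >>> # every rank contributes its hostname and receives everyone's.
    >>> import socket
    >>> hostnames = collectives.all_gather("hostnames", socket.gethostname())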
1786 )");
1787
1788 intrusive_ptr_class_<::c10d::StoreCollectives>(
1789 module,
1790 "_StoreCollectives",
1791 collectives,
1792 R"(
1793 An implementation of ControlCollectives that uses the provided store as the underlying
1794 communication mechanism.
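
Example::
    >>> # A minimal sketch, assuming ``store``, ``rank`` and ``world_size``
    >>> # are set up consistently on every participating worker.
    >>> import torch.distributed as dist
    >>> collectives = dist._StoreCollectives(store, rank, world_size)
    >>> collectives.barrier("init_done")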
1795 )")
1796 .def(
1797 py::init<c10::intrusive_ptr<::c10d::Store>, int, int>(),
1798 py::arg("store"),
1799 py::arg("rank"),
1800 py::arg("world_size"));
1801
1802 auto processGroup =
1803 py::class_<
1804 ::c10d::ProcessGroup,
1805 c10::intrusive_ptr<::c10d::ProcessGroup>,
1806 ::c10d::PyProcessGroup>(module, "ProcessGroup")
1807 .def(py::init<int, int>())
1808 .def(
1809 py::init<
1810 const c10::intrusive_ptr<::c10d::Store>&,
1811 int,
1812 int,
1813 c10::intrusive_ptr<::c10d::ProcessGroup::Options>>(),
1814 py::call_guard<py::gil_scoped_release>())
1815 .def("rank", &::c10d::ProcessGroup::getRank)
1816 .def("size", &::c10d::ProcessGroup::getSize)
1817 .def("name", &::c10d::ProcessGroup::getBackendName)
1818 .def("_id", &::c10d::ProcessGroup::getID)
1819 .def(
1820 "_backend_id",
1821 &::c10d::ProcessGroup::getBackendID,
1822 py::arg("backend_type"))
1823 .def_property_readonly("options", &::c10d::ProcessGroup::getOptions)
1824 .def(
1825 "broadcast",
1826 &::c10d::ProcessGroup::broadcast,
1827 py::arg("tensors"),
1828 py::arg("opts") = ::c10d::BroadcastOptions(),
1829 py::call_guard<py::gil_scoped_release>())
1830 .def(
1831 "broadcast",
1832 [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
1833 at::Tensor& x,
1834 int rootRank) {
1835 ::c10d::BroadcastOptions opts;
1836 opts.rootRank = rootRank;
1837 std::vector<at::Tensor> tensors = {x};
1838 return self->broadcast(tensors, opts);
1839 },
1840 py::arg("tensor"),
1841 py::arg("root"),
1842 py::call_guard<py::gil_scoped_release>())
1843 .def(
1844 "allreduce",
1845 &::c10d::ProcessGroup::allreduce,
1846 py::arg("tensors"),
1847 py::arg("opts") = ::c10d::AllreduceOptions(),
1848 py::call_guard<py::gil_scoped_release>())
1849 .def(
1850 "allreduce",
1851 [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
1852 std::vector<at::Tensor>& xs,
1853 const ::c10d::ReduceOp& op) {
1854 ::c10d::AllreduceOptions opts;
1855 opts.reduceOp = op;
1856 return self->allreduce(xs, opts);
1857 },
1858 py::arg("tensors"),
1859 py::arg("op") = ::c10d::ReduceOp::SUM,
1860 py::call_guard<py::gil_scoped_release>())
1861
1862 .def(
1863 "allreduce",
1864 [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
1865 at::Tensor& x,
1866 const ::c10d::ReduceOp& op) {
1867 ::c10d::AllreduceOptions opts;
1868 opts.reduceOp = op;
1869 std::vector<at::Tensor> xs = {x};
1870 return self->allreduce(xs, opts);
1871 },
1872 py::arg("tensor"),
1873 py::arg("op") = ::c10d::ReduceOp::SUM,
1874 py::call_guard<py::gil_scoped_release>())
1875 .def(
1876 "allreduce_coalesced",
1877 &::c10d::ProcessGroup::allreduce_coalesced,
1878 py::arg("tensors"),
1879 py::arg("opts") = ::c10d::AllreduceCoalescedOptions(),
1880 py::call_guard<py::gil_scoped_release>())
1881
1882 .def(
1883 "reduce",
1884 &::c10d::ProcessGroup::reduce,
1885 py::arg("tensors"),
1886 py::arg("opts") = ::c10d::ReduceOptions(),
1887 py::call_guard<py::gil_scoped_release>())
1888
1889 .def(
1890 "reduce",
1891 [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
1892 at::Tensor& x,
1893 int rootRank,
1894 const ::c10d::ReduceOp& op) {
1895 ::c10d::ReduceOptions opts;
1896 opts.reduceOp = op;
1897 opts.rootRank = rootRank;
1898 std::vector<at::Tensor> xs = {x};
1899 return self->reduce(xs, opts);
1900 },
1901 py::arg("tensor"),
1902 py::arg("root"),
1903 py::arg("op") = ::c10d::ReduceOp::SUM,
1904 py::call_guard<py::gil_scoped_release>())
1905 .def(
1906 "allgather",
1907 &::c10d::ProcessGroup::allgather,
1908 py::arg("output_tensors"),
1909 py::arg("input_tensors"),
1910 py::arg("opts") = ::c10d::AllgatherOptions(),
1911 py::call_guard<py::gil_scoped_release>())
1912 .def(
1913 "allgather",
1914 [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
1915 std::vector<at::Tensor>& output,
1916 at::Tensor& input) {
1917 std::vector<std::vector<at::Tensor>> outputs = {output};
1918 std::vector<at::Tensor> inputs = {input};
1919 return self->allgather(
1920 outputs, inputs, ::c10d::AllgatherOptions());
1921 },
1922 py::arg("output_tensors"),
1923 py::arg("input_tensor"),
1924 py::call_guard<py::gil_scoped_release>())
1925 .def(
1926 "_allgather_base",
1927 &::c10d::ProcessGroup::_allgather_base,
1928 py::arg("output"),
1929 py::arg("input"),
1930 py::arg("opts") = ::c10d::AllgatherOptions(),
1931 py::call_guard<py::gil_scoped_release>())
1932 .def(
1933 "allgather_coalesced",
1934 &::c10d::ProcessGroup::allgather_coalesced,
1935 py::arg("output_lists"),
1936 py::arg("input_list"),
1937 py::arg("opts") = ::c10d::AllgatherOptions(),
1938 py::call_guard<py::gil_scoped_release>())
1939 .def(
1940 "allgather_into_tensor_coalesced",
1941 &::c10d::ProcessGroup::allgather_into_tensor_coalesced,
1942 py::arg("outputs"),
1943 py::arg("inputs"),
1944 py::arg("opts") = ::c10d::AllgatherOptions(),
1945 py::call_guard<py::gil_scoped_release>())
1946 .def(
1947 "gather",
1948 &::c10d::ProcessGroup::gather,
1949 py::arg("output_tensors"),
1950 py::arg("input_tensors"),
1951 py::arg("opts") = ::c10d::GatherOptions(),
1952 py::call_guard<py::gil_scoped_release>())
1953
1954 .def(
1955 "gather",
1956 [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
1957 std::vector<at::Tensor>& output,
1958 at::Tensor& input,
1959 int rootRank) {
1960 ::c10d::GatherOptions opts;
1961 opts.rootRank = rootRank;
1962 std::vector<std::vector<at::Tensor>> outputs = {output};
1963 std::vector<at::Tensor> inputs = {input};
1964 return self->gather(outputs, inputs, opts);
1965 },
1966 py::arg("output_tensors"),
1967 py::arg("input_tensor"),
1968 py::arg("root"),
1969 py::call_guard<py::gil_scoped_release>())
1970 .def(
1971 "scatter",
1972 &::c10d::ProcessGroup::scatter,
1973 py::arg("output_tensors"),
1974 py::arg("input_tensors"),
1975 py::arg("opts") = ::c10d::ScatterOptions(),
1976 py::call_guard<py::gil_scoped_release>())
1977 .def(
1978 "scatter",
1979 [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
1980 at::Tensor& output,
1981 std::vector<at::Tensor>& input,
1982 int rootRank) {
1983 ::c10d::ScatterOptions opts;
1984 opts.rootRank = rootRank;
1985 std::vector<std::vector<at::Tensor>> inputs = {input};
1986 std::vector<at::Tensor> outputs = {output};
1987 return self->scatter(outputs, inputs, opts);
1988 },
1989 py::arg("output_tensor"),
1990 py::arg("input_tensors"),
1991 py::arg("root"),
1992 py::call_guard<py::gil_scoped_release>())
1993 .def(
1994 "reduce_scatter",
1995 &::c10d::ProcessGroup::reduce_scatter,
1996 py::arg("output_tensors"),
1997 py::arg("input_tensors"),
1998 py::arg("opts") = ::c10d::ReduceScatterOptions(),
1999 py::call_guard<py::gil_scoped_release>())
2000 .def(
2001 "reduce_scatter",
2002 [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
2003 at::Tensor& output,
2004 std::vector<at::Tensor>& input,
2005 const ::c10d::ReduceOp& op) {
2006 std::vector<at::Tensor> outputs = {output};
2007 std::vector<std::vector<at::Tensor>> inputs = {input};
2008 ::c10d::ReduceScatterOptions opts;
2009 opts.reduceOp = op;
2010 return self->reduce_scatter(outputs, inputs, opts);
2011 },
2012 py::arg("output"),
2013 py::arg("input"),
2014 py::arg("op") = ::c10d::ReduceOp::SUM,
2015 py::call_guard<py::gil_scoped_release>())
2016 .def(
2017 "_reduce_scatter_base",
2018 &::c10d::ProcessGroup::_reduce_scatter_base,
2019 py::arg("outputTensor"),
2020 py::arg("inputTensor"),
2021 py::arg("opts") = ::c10d::ReduceScatterOptions(),
2022 py::call_guard<py::gil_scoped_release>())
2023 .def(
2024 "reduce_scatter_tensor_coalesced",
2025 &::c10d::ProcessGroup::reduce_scatter_tensor_coalesced,
2026 py::arg("outputs"),
2027 py::arg("inputs"),
2028 py::arg("opts") = ::c10d::ReduceScatterOptions(),
2029 py::call_guard<py::gil_scoped_release>())
2030 .def(
2031 "alltoall_base",
2032 &::c10d::ProcessGroup::alltoall_base,
2033 py::arg("output"),
2034 py::arg("input"),
2035 py::arg("output_split_sizes"),
2036 py::arg("input_split_sizes"),
2037 py::arg("opts") = ::c10d::AllToAllOptions(),
2038 py::call_guard<py::gil_scoped_release>())
2039 .def(
2040 "alltoall",
2041 &::c10d::ProcessGroup::alltoall,
2042 py::arg("output_tensors"),
2043 py::arg("input_tensors"),
2044 py::arg("opts") = ::c10d::AllToAllOptions(),
2045 py::call_guard<py::gil_scoped_release>())
2046 .def(
2047 "send",
2048 &::c10d::ProcessGroup::send,
2049 py::arg("tensors"),
2050 py::arg("dstRank"),
2051 py::arg("tag"),
2052 py::call_guard<py::gil_scoped_release>())
2053 .def(
2054 "recv",
2055 &::c10d::ProcessGroup::recv,
2056 py::arg("tensors"),
2057 py::arg("srcRank"),
2058 py::arg("tag"),
2059 py::call_guard<py::gil_scoped_release>())
2060 .def(
2061 "recv_anysource",
2062 &::c10d::ProcessGroup::recvAnysource,
2063 py::call_guard<py::gil_scoped_release>())
2064 .def(
2065 "barrier",
2066 &::c10d::ProcessGroup::barrier,
2067 py::arg("opts") = ::c10d::BarrierOptions(),
2068 py::call_guard<py::gil_scoped_release>())
2069 .def(
2070 "_set_sequence_number_for_group",
2071 &::c10d::ProcessGroup::setSequenceNumberForGroup,
2072 py::call_guard<py::gil_scoped_release>())
2073 .def(
2074 "_get_sequence_number_for_group",
2075 &::c10d::ProcessGroup::getSequenceNumberForGroup,
2076 py::call_guard<py::gil_scoped_release>())
2077 .def(
2078 "monitored_barrier",
2079 [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
2080 const std::chrono::milliseconds& timeout,
2081 bool waitAllRanks) {
2082 ::c10d::BarrierOptions opts;
2083 opts.timeout = timeout;
2084 return self->monitoredBarrier(opts, waitAllRanks);
2085 },
2086 py::arg("timeout") = ::c10d::kUnsetTimeout,
2087 py::arg("wait_all_ranks") = false,
2088 py::call_guard<py::gil_scoped_release>())
2089 .def_property_readonly(
2090 "_device_types", &::c10d::ProcessGroup::getDeviceTypes)
2091 .def(
2092 "_get_backend_name",
2093 &::c10d::ProcessGroup::getBackendName,
2094 py::call_guard<py::gil_scoped_release>())
2095 .def(
2096 "_start_coalescing",
2097 [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
2098 const c10::Device& device) {
2099 self->startCoalescing(device.type());
2100 },
2101 py::arg("device_type"),
2102 py::call_guard<py::gil_scoped_release>())
2103 .def(
2104 "_end_coalescing",
2105 [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
2106 const c10::Device& device) {
2107 return self->endCoalescing(device.type());
2108 },
2109 py::arg("device_type"),
2110 py::call_guard<py::gil_scoped_release>())
2111 .def(
2112 "_register_backend",
2113 [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
2114 const c10::Device& device,
2115 const ::c10d::ProcessGroup::BackendType& backendType,
2116 const std::optional<c10::intrusive_ptr<::c10d::Backend>>&
2117 backend) {
2118 self->setBackend(device.type(), backendType, backend);
2119 },
2120 py::arg("device"),
2121 py::arg("backend_type"),
2122 py::arg("backend") =
2123 std::optional<c10::intrusive_ptr<::c10d::Backend>>(),
2124 py::call_guard<py::gil_scoped_release>())
2125 .def(
2126 "_get_backend",
2127 [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
2128 const c10::Device& device) {
2129 return self->getBackend(device.type());
2130 },
2131 py::arg("device"),
2132 py::call_guard<py::gil_scoped_release>())
2133 .def(
2134 "_register_on_completion_hook",
2135 [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
2136 py::object hook) {
                // We need to wrap the py::object hook with a wrapper that
                // holds the GIL before dereferencing the py::object. This
                // needs to happen here rather than in the ProcessGroup
                // backend implementations, since the latter cannot depend
                // on Python-related libraries.
2142 self->registerOnCompletionHook(
2143 [hookWrapper = ::c10d::PythonOnCompletionHook(std::move(
2144 hook))](std::shared_ptr<::c10d::WorkInfo> workInfo) {
2145 hookWrapper(workInfo);
2146 });
2147 },
2148 py::arg("hook"),
              // Intentionally holding the GIL as we move the hook py::object.
              // This should be OK as registering a hook is cheap.
2151 py::call_guard<py::gil_scoped_acquire>(),
2152 R"(
2153 Register a hook function which is fired on every ``ProcessGroup::Work`` completion.
2154 The hook must have the following signature:
2155
2156 >>> def hook(work_info: torch._C._distributed_c10d.WorkInfo) -> None:
2157 >>> # custom code
2158 >>> # work_info.op_type: type of collective of this work
2159 >>> # work_info.seq: sequence number of collective of this work
2160 >>> # work_info.time_started: system time when user code called this collective
2161 >>> # work_info.time_finished: system time when the watchdog thread detected
>>> #     completion of this work. Note that there can be delays between the
2163 >>> # actual completion time and the detection time.
2164 >>> # work_info.active_duration: duration of this collective measured by CUDAEvents
2165 >>> # which can accurately represent the duration between when the collective
2166 >>> # is launched and when the collective completes.
2167
2168 .. warning ::
    This only works for the NCCL backend for now. All hooks are fired on the C++ watchdog
    thread. Firing the Python hook and acquiring the GIL requires the Python interpreter to
    be alive. Therefore, users need to make sure they call ``destroy_process_group(pg)`` on
    every active ProcessGroup ``pg`` before exiting.
2173
2174 .. warning ::
    Note that the ``Work`` object passed to the hook is a partially copied version without
    the output objects, so accessing the output tensors from ``Work`` will not work.
2177
2178
2179 Arguments:
2180 hook (Callable): hook function.
2181 )")
2182 .def(
2183 "_wait_for_pending_works",
2184 &::c10d::ProcessGroup::waitForPendingWorks,
2185 py::call_guard<py::gil_scoped_release>())
2186 .def(
2187 "_has_hooks",
2188 &::c10d::ProcessGroup::hasHooks,
2189 py::call_guard<py::gil_scoped_acquire>())
2190 .def(
2191 "_enable_collectives_timing",
2192 &::c10d::ProcessGroup::enableCollectivesTiming,
2193 py::call_guard<py::gil_scoped_acquire>(),
              "Enable timing of collectives by all backends. This might incur additional overhead.")
2195 .def(
2196 "_set_group_name",
2197 &::c10d::ProcessGroup::setGroupName,
2198 py::call_guard<py::gil_scoped_acquire>(),
2199 "Sets the process group name. This is an internal C10D method, do not use.")
2200 .def_property_readonly(
2201 "group_name",
2202 &::c10d::ProcessGroup::getGroupName,
              "Gets this process group's name; it is unique across the cluster.")
2204 .def(
2205 "_set_group_desc",
2206 &::c10d::ProcessGroup::setGroupDesc,
2207 py::call_guard<py::gil_scoped_acquire>(),
2208 "Sets the process group description. This is an internal C10D method, do not use.")
2209 .def_property_readonly(
2210 "group_desc",
2211 &::c10d::ProcessGroup::getGroupDesc,
              "Gets this process group's description.")
2213 .def_property(
2214 "bound_device_id",
2215 &::c10d::ProcessGroup::getBoundDeviceId,
2216 &::c10d::ProcessGroup::setBoundDeviceId)
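          // boxed()/unbox() round-trip the ProcessGroup through its
          // TorchScript custom-class IValue representation so instances can
          // cross the Python/TorchScript boundary.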
2217 .def("boxed", [](c10::intrusive_ptr<::c10d::ProcessGroup> self) {
2218 return torch::jit::toPyObject(c10::IValue(std::move(self)));
2219 })
2220 .def_static("unbox", [](py::object obj) {
2221 auto typePtr = torch::getCustomClass("__torch__.torch.classes.c10d.ProcessGroup");
2222 auto ivalue = torch::jit::toIValue(std::move(obj), typePtr);
2223 return ivalue.toCustomClass<::c10d::ProcessGroup>();
2224 });
2225
2226 py::enum_<::c10d::ProcessGroup::BackendType>(processGroup, "BackendType")
2227 .value("UNDEFINED", ::c10d::ProcessGroup::BackendType::UNDEFINED)
2228 .value("GLOO", ::c10d::ProcessGroup::BackendType::GLOO)
2229 .value("NCCL", ::c10d::ProcessGroup::BackendType::NCCL)
2230 .value("UCC", ::c10d::ProcessGroup::BackendType::UCC)
2231 .value("MPI", ::c10d::ProcessGroup::BackendType::MPI)
2232 .value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM)
2233 .export_values();
2234
2235 // base ProcessGroup::Options binding
2236 auto processGroupOptions =
2237 intrusive_ptr_class_<::c10d::ProcessGroup::Options>(
2238 processGroup,
2239 "Options",
2240 R"(
Base class for all process group options implementations, such as the NCCL
options :class:`~torch.distributed.ProcessGroupNCCL.Options`.
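
Example::
    >>> # A minimal sketch: construct the base options with a backend name
    >>> # and an operation timeout.
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> opts = dist.ProcessGroup.Options(backend="gloo", timeout=timedelta(seconds=60))
    >>> # Should return "gloo"
    >>> opts.backend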
2243 )")
2244 .def(
2245 py::init([](const std::string& backend,
2246 const std::chrono::milliseconds& timeout) {
2247 return c10::make_intrusive<::c10d::ProcessGroup::Options>(
2248 backend, timeout);
2249 }),
2250 py::arg("backend"),
2251 py::arg("timeout") = kProcessGroupDefaultTimeout,
2252 py::call_guard<py::gil_scoped_release>())
2253 .def_readonly("backend", &::c10d::ProcessGroup::Options::backend)
2254 .def_readwrite("_timeout", &::c10d::ProcessGroup::Options::timeout);
2255
  // TODO: The definitions below handle direct instantiation of ProcessGroup
  // subclasses (e.g. dist.ProcessGroupGloo). This is not supported and should
  // be removed once all tests have transitioned.
2259 auto backend =
2260 py::class_<::c10d::Backend, c10::intrusive_ptr<::c10d::Backend>>(
2261 module, "Backend")
2262 .def("rank", &::c10d::Backend::getRank)
2263 .def("size", &::c10d::Backend::getSize)
2264 .def("name", &::c10d::Backend::getBackendName)
2265 .def_property_readonly(
2266 "supports_splitting",
2267 &::c10d::Backend::supportsSplitting,
              "Tests whether the backend supports splitting.")
2269 .def(
2270 "broadcast",
2271 &::c10d::Backend::broadcast,
2272 py::arg("tensors"),
2273 py::arg("opts") = ::c10d::BroadcastOptions(),
2274 py::call_guard<py::gil_scoped_release>())
2275 .def(
2276 "broadcast",
2277 [](const c10::intrusive_ptr<::c10d::Backend>& self,
2278 at::Tensor& x,
2279 int rootRank) {
2280 ::c10d::BroadcastOptions opts;
2281 opts.rootRank = rootRank;
2282 std::vector<at::Tensor> xs = {x};
2283 return self->broadcast(xs, opts);
2284 },
2285 py::arg("tensor"),
2286 py::arg("root"),
2287 py::call_guard<py::gil_scoped_release>())
2288 .def(
2289 "allreduce",
2290 &::c10d::Backend::allreduce,
2291 py::arg("tensors"),
2292 py::arg("opts") = ::c10d::AllreduceOptions(),
2293 py::call_guard<py::gil_scoped_release>())
2294 .def(
2295 "allreduce",
2296 [](const c10::intrusive_ptr<::c10d::Backend>& self,
2297 std::vector<at::Tensor>& xs,
2298 const ::c10d::ReduceOp& op) {
2299 ::c10d::AllreduceOptions opts;
2300 opts.reduceOp = op;
2301 return self->allreduce(xs, opts);
2302 },
2303 py::arg("tensors"),
2304 py::arg("op") = ::c10d::ReduceOp::SUM,
2305 py::call_guard<py::gil_scoped_release>())
2306 .def(
2307 "allreduce",
2308 [](const c10::intrusive_ptr<::c10d::Backend>& self,
2309 at::Tensor& x,
2310 const ::c10d::ReduceOp& op) {
2311 ::c10d::AllreduceOptions opts;
2312 opts.reduceOp = op;
2313 std::vector<at::Tensor> xs = {x};
2314 return self->allreduce(xs, opts);
2315 },
2316 py::arg("tensor"),
2317 py::arg("op") = ::c10d::ReduceOp::SUM,
2318 py::call_guard<py::gil_scoped_release>())
2319 .def(
2320 "allreduce_coalesced",
2321 &::c10d::Backend::allreduce_coalesced,
2322 py::arg("tensors"),
2323 py::arg("opts") = ::c10d::AllreduceCoalescedOptions(),
2324 py::call_guard<py::gil_scoped_release>())
2325 .def(
2326 "reduce",
2327 &::c10d::Backend::reduce,
2328 py::arg("tensors"),
2329 py::arg("opts") = ::c10d::ReduceOptions(),
2330 py::call_guard<py::gil_scoped_release>())
2331 .def(
2332 "reduce",
2333 [](const c10::intrusive_ptr<::c10d::Backend>& self,
2334 at::Tensor& x,
2335 int rootRank,
2336 const ::c10d::ReduceOp& op) {
2337 ::c10d::ReduceOptions opts;
2338 opts.reduceOp = op;
2339 opts.rootRank = rootRank;
2340 std::vector<at::Tensor> xs = {x};
2341 return self->reduce(xs, opts);
2342 },
2343 py::arg("tensor"),
2344 py::arg("root"),
2345 py::arg("op") = ::c10d::ReduceOp::SUM,
2346 py::call_guard<py::gil_scoped_release>())
2347 .def(
2348 "allgather",
2349 &::c10d::Backend::allgather,
2350 py::arg("output_tensors"),
2351 py::arg("input_tensors"),
2352 py::arg("opts") = ::c10d::AllgatherOptions(),
2353 py::call_guard<py::gil_scoped_release>())
2354 .def(
2355 "_allgather_base",
2356 &::c10d::Backend::_allgather_base,
2357 py::arg("output"),
2358 py::arg("input"),
2359 py::arg("opts") = ::c10d::AllgatherOptions(),
2360 py::call_guard<py::gil_scoped_release>())
2361 .def(
2362 "allgather",
2363 [](const c10::intrusive_ptr<::c10d::Backend>& self,
2364 std::vector<at::Tensor>& output,
2365 at::Tensor& input) {
2366 std::vector<std::vector<at::Tensor>> outputs = {output};
2367 std::vector<at::Tensor> inputs = {input};
2368 return self->allgather(
2369 outputs, inputs, ::c10d::AllgatherOptions());
2370 },
2371 py::arg("output_tensors"),
2372 py::arg("input_tensor"),
2373 py::call_guard<py::gil_scoped_release>())
2374 .def(
2375 "allgather_coalesced",
2376 &::c10d::Backend::allgather_coalesced,
2377 py::arg("output_lists"),
2378 py::arg("input_list"),
2379 py::arg("opts") = ::c10d::AllgatherOptions(),
2380 py::call_guard<py::gil_scoped_release>())
2381 .def(
2382 "gather",
2383 &::c10d::Backend::gather,
2384 py::arg("output_tensors"),
2385 py::arg("input_tensors"),
2386 py::arg("opts") = ::c10d::GatherOptions(),
2387 py::call_guard<py::gil_scoped_release>())
2388 .def(
2389 "gather",
2390 [](const c10::intrusive_ptr<::c10d::Backend>& self,
2391 std::vector<at::Tensor>& output,
2392 at::Tensor& input,
2393 int rootRank) {
2394 ::c10d::GatherOptions opts;
2395 opts.rootRank = rootRank;
2396 std::vector<std::vector<at::Tensor>> outputs = {output};
2397 std::vector<at::Tensor> inputs = {input};
2398 return self->gather(outputs, inputs, opts);
2399 },
2400 py::arg("output_tensors"),
2401 py::arg("input_tensor"),
2402 py::arg("root"),
2403 py::call_guard<py::gil_scoped_release>())
2404 .def(
2405 "scatter",
2406 &::c10d::Backend::scatter,
2407 py::arg("output_tensors"),
2408 py::arg("input_tensors"),
2409 py::arg("opts") = ::c10d::ScatterOptions(),
2410 py::call_guard<py::gil_scoped_release>())
2411 .def(
2412 "scatter",
2413 [](const c10::intrusive_ptr<::c10d::Backend>& self,
2414 at::Tensor& output,
2415 std::vector<at::Tensor>& input,
2416 int rootRank) {
2417 ::c10d::ScatterOptions opts;
2418 opts.rootRank = rootRank;
2419 std::vector<std::vector<at::Tensor>> inputs = {input};
2420 std::vector<at::Tensor> outputs = {output};
2421 return self->scatter(outputs, inputs, opts);
2422 },
2423 py::arg("output_tensor"),
2424 py::arg("input_tensors"),
2425 py::arg("root"),
2426 py::call_guard<py::gil_scoped_release>())
2427 .def(
2428 "reduce_scatter",
2429 &::c10d::Backend::reduce_scatter,
2430 py::arg("output_tensors"),
2431 py::arg("input_tensors"),
2432 py::arg("opts") = ::c10d::ReduceScatterOptions(),
2433 py::call_guard<py::gil_scoped_release>())
2434 .def(
2435 "reduce_scatter",
2436 [](const c10::intrusive_ptr<::c10d::Backend>& self,
2437 at::Tensor& output,
2438 std::vector<at::Tensor>& input,
2439 const ::c10d::ReduceOp& op) {
2440 std::vector<at::Tensor> outputs = {output};
2441 std::vector<std::vector<at::Tensor>> inputs = {input};
2442 ::c10d::ReduceScatterOptions opts;
2443 opts.reduceOp = op;
2444 return self->reduce_scatter(outputs, inputs, opts);
2445 },
2446 py::arg("output_tensors"),
2447 py::arg("input_tensor"),
2448 py::arg("op") = ::c10d::ReduceOp::SUM,
2449 py::call_guard<py::gil_scoped_release>())
2450 .def(
2451 "_reduce_scatter_base",
2452 &::c10d::Backend::_reduce_scatter_base,
2453 py::arg("outputTensor"),
2454 py::arg("inputTensor"),
2455 py::arg("opts") = ::c10d::ReduceScatterOptions(),
2456 py::call_guard<py::gil_scoped_release>())
2457 .def(
2458 "alltoall_base",
2459 &::c10d::Backend::alltoall_base,
2460 py::arg("output_tensor"),
2461 py::arg("input_tensor"),
2462 py::arg("output_split_sizes"),
2463 py::arg("input_split_sizes"),
2464 py::arg("opts") = ::c10d::AllToAllOptions(),
2465 py::call_guard<py::gil_scoped_release>())
2466 .def(
2467 "alltoall_base",
2468 [](::c10d::Backend& self,
2469 at::Tensor& output,
2470 at::Tensor& input,
2471 std::vector<int64_t> outputSplitSizes,
2472 std::vector<int64_t> inputSplitSizes) {
2473 return self.alltoall_base(
2474 output,
2475 input,
2476 outputSplitSizes,
2477 inputSplitSizes,
2478 ::c10d::AllToAllOptions());
2479 },
2480 py::arg("output"),
2481 py::arg("input"),
2482 py::arg("output_split_sizes"),
2483 py::arg("input_split_sizes"),
2484 py::call_guard<py::gil_scoped_release>())
2485 .def(
2486 "alltoall",
2487 &::c10d::Backend::alltoall,
2488 py::arg("output_tensor"),
2489 py::arg("input_tensor"),
2490 py::arg("opts") = ::c10d::AllToAllOptions(),
2491 py::call_guard<py::gil_scoped_release>())
2492 .def(
2493 "send",
2494 &::c10d::Backend::send,
2495 py::arg("tensors"),
2496 py::arg("dstRank"),
2497 py::arg("tag"),
2498 py::call_guard<py::gil_scoped_release>())
2499 .def(
2500 "recv",
2501 &::c10d::Backend::recv,
2502 py::arg("tensors"),
2503 py::arg("srcRank"),
2504 py::arg("tag"),
2505 py::call_guard<py::gil_scoped_release>())
2506 .def(
2507 "recv_anysource",
2508 &::c10d::Backend::recvAnysource,
2509 py::call_guard<py::gil_scoped_release>())
2510 .def(
2511 "barrier",
2512 [](const c10::intrusive_ptr<::c10d::Backend>& self,
2513 const ::c10d::BarrierOptions& opts) {
2514 return self->barrier(opts);
2515 },
2516 py::arg("opts") = ::c10d::BarrierOptions(),
2517 py::call_guard<py::gil_scoped_release>())
2518 .def(
2519 "_set_sequence_number_for_group",
2520 &::c10d::Backend::setSequenceNumberForGroup,
2521 py::call_guard<py::gil_scoped_release>())
2522 .def(
2523 "_get_sequence_number_for_group",
2524 &::c10d::Backend::getSequenceNumberForGroup,
2525 py::call_guard<py::gil_scoped_release>())
2526 .def(
2527 "monitored_barrier",
2528 [](const c10::intrusive_ptr<::c10d::Backend>& self,
2529 const std::chrono::milliseconds& timeout,
2530 bool waitAllRanks) {
2531 ::c10d::BarrierOptions opts;
2532 opts.timeout = timeout;
2533 return self->monitoredBarrier(opts, waitAllRanks);
2534 },
2535 py::arg("timeout") = ::c10d::kUnsetTimeout,
2536 py::arg("wait_all_ranks") = false,
2537 py::call_guard<py::gil_scoped_release>())
2538 .def(
2539 "eager_connect_single_device",
2540 &::c10d::Backend::eagerConnectSingleDevice,
2541 py::call_guard<py::gil_scoped_release>())
2542 .def(
2543 "_get_backend_name",
2544 &::c10d::Backend::getBackendName,
2545 py::call_guard<py::gil_scoped_release>())
2546 .def(
2547 "_start_coalescing",
2548 &::c10d::Backend::startCoalescing,
2549 py::call_guard<py::gil_scoped_release>())
2550 .def(
2551 "_end_coalescing",
2552 &::c10d::Backend::endCoalescing,
2553 py::call_guard<py::gil_scoped_release>());
2554
2555 #ifdef USE_C10D_GLOO
2556 static const std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME";
2557
2558 auto processGroupGloo =
2559 intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupGloo>(
2560 module, "ProcessGroupGloo", backend);
2561
2562 // NOLINTNEXTLINE(bugprone-unused-raii)
2563 shared_ptr_class_<::gloo::transport::Device>(processGroupGloo, "Device");
2564
2565 intrusive_ptr_class_<::c10d::ProcessGroupGloo::Options>(
2566 processGroupGloo, "_Options", processGroupOptions)
2567 .def(py::init<>())
2568 .def_readwrite("_devices", &::c10d::ProcessGroupGloo::Options::devices)
2569 .def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads);
2570
2571 processGroupGloo
2572 .def_static(
2573 "create_device",
2574 [](const std::string& hostname, const std::string& interface)
2575 -> std::shared_ptr<::gloo::transport::Device> {
2576 if (!hostname.empty()) {
2577 return ::c10d::ProcessGroupGloo::createDeviceForHostname(
2578 hostname);
2579 }
2580 if (!interface.empty()) {
2581 return ::c10d::ProcessGroupGloo::createDeviceForInterface(
2582 interface);
2583 }
2584 throw std::invalid_argument(
2585 "Specify either `hostname` or `interface` argument.");
2586 },
2587 py::arg("hostname") = "",
2588 py::arg("interface") = "")
2589 .def_static(
2590 "create_default_device",
2591 &::c10d::ProcessGroupGloo::createDefaultDevice);
2592
2593 processGroupGloo
2594 .def(
2595 py::init<
2596 const c10::intrusive_ptr<::c10d::Store>&,
2597 int,
2598 int,
2599 c10::intrusive_ptr<::c10d::ProcessGroupGloo::Options>>(),
2600 py::call_guard<py::gil_scoped_release>())
2601 .def(
2602 py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
2603 int rank,
2604 int size,
2605 std::chrono::milliseconds timeout) {
2606 auto options = ::c10d::ProcessGroupGloo::Options::create();
2607
2608 // Use interfaces listed in "GLOO_SOCKET_IFNAME", if set.
2609 char* ifnameEnv = getenv(GLOO_SOCKET_IFNAME_ENV.c_str());
2610 if (ifnameEnv && strlen(ifnameEnv) > 1) {
2611 for (const auto& iface : ::c10d::split(',', ifnameEnv)) {
2612 options->devices.push_back(
2613 ::c10d::ProcessGroupGloo::createDeviceForInterface(iface));
2614 }
2615 } else {
2616 // If no hostname is specified, this function looks up
2617 // the machine's hostname and returns a device instance
2618 // associated with the address that the hostname resolves to.
2619 options->devices.push_back(
2620 ::c10d::ProcessGroupGloo::createDefaultDevice());
2621 }
2622
2623 options->timeout = timeout;
2624 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
2625 options->threads = options->devices.size() * 2;
2626 return c10::make_intrusive<::c10d::ProcessGroupGloo>(
2627 store, rank, size, options);
2628 }),
2629 py::arg("store"),
2630 py::arg("rank"),
2631 py::arg("size"),
2632 py::arg("timeout") = kProcessGroupDefaultTimeout,
2633 py::call_guard<py::gil_scoped_release>())
2634 .def(
2635 "_set_default_timeout",
2636 [](const c10::intrusive_ptr<::c10d::ProcessGroupGloo>& self,
2637 std::chrono::milliseconds timeout) {
2638 self->getOptions()->timeout = timeout;
2639 },
2640 py::arg("timeout"),
2641 py::call_guard<py::gil_scoped_release>())
2642 .def_property_readonly("options", &::c10d::ProcessGroupGloo::getOptions);
2643
  // ProcessGroupWrapper is a wrapper process group that includes a helper Gloo
  // process group. It can be used to validate collective calls across processes
  // by checking the op type and input tensor shapes.
2647 auto processGroupWrapper =
2648 intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupWrapper>(
2649 module, "_ProcessGroupWrapper", backend)
2650 .def(
2651 py::init(
2652 [](const c10::intrusive_ptr<::c10d::Backend>& backend,
2653 const c10::intrusive_ptr<::c10d::Backend>& gloo_backend) {
2654 return c10::make_intrusive<::c10d::ProcessGroupWrapper>(
2655 backend, gloo_backend);
2656 }),
2657 py::arg("backend"),
2658 py::arg("gloo_backend"),
2659 py::call_guard<py::gil_scoped_release>())
2660 .def_property_readonly(
2661 "wrapped_pg", &::c10d::ProcessGroupWrapper::getWrappedPg);
2662 #endif
2663
2664 #ifdef USE_C10D_NCCL
2665 auto processGroupNCCL =
2666 intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupNCCL>(
2667 module, "ProcessGroupNCCL", backend)
2668 .def(
2669 py::init<
2670 const c10::intrusive_ptr<::c10d::Store>&,
2671 int,
2672 int,
2673 c10::intrusive_ptr<::c10d::ProcessGroupNCCL::Options>>(),
2674 py::call_guard<py::gil_scoped_release>())
2675 .def(
2676 py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
2677 int rank,
2678 int size,
2679 const std::chrono::milliseconds& timeout) {
2680 auto options = ::c10d::ProcessGroupNCCL::Options::create();
2681 options->is_high_priority_stream = false;
2682 options->timeout = timeout;
2683 return c10::make_intrusive<::c10d::ProcessGroupNCCL>(
2684 store, rank, size, options);
2685 }),
2686 py::arg("store"),
2687 py::arg("rank"),
2688 py::arg("size"),
2689 py::arg("timeout") = ::c10d::kProcessGroupNCCLDefaultTimeout,
2690 py::call_guard<py::gil_scoped_release>())
2691 .def(
2692 "_shutdown",
2693 [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) {
2694 return self->shutdown();
2695 },
2696 py::call_guard<py::gil_scoped_release>())
2697 .def("_group_start", &::c10d::ProcessGroupNCCL::groupStart)
2698 .def("_group_end", &::c10d::ProcessGroupNCCL::groupEnd)
2699 .def(
2700 "comm_split_count",
2701 &::c10d::ProcessGroupNCCL::getCommSplitCounter)
2702 .def(
2703 "_set_default_timeout",
2704 [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self,
2705 std::chrono::milliseconds timeout) {
2706 self->getOptions()->timeout = timeout;
2707 },
2708 py::arg("timeout"),
2709 py::call_guard<py::gil_scoped_release>())
2710 .def(
2711 "_add_ephemeral_timeout",
2712 [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self,
2713 const std::chrono::milliseconds& timeout) {
2714 self->addEphemeralTimeout(timeout);
2715 },
2716 py::arg("timeout"))
2717 .def(
2718 "_verify_work_timeout",
2719 [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self,
2720 const c10::intrusive_ptr<::c10d::Work> work,
2721 const std::chrono::milliseconds& timeout) {
2722 return self->verifyWorkTimeoutForTest(work, timeout);
2723 },
2724 py::arg("work"),
2725 py::arg("timeout"))
2726 .def_property_readonly(
2727 "options", &::c10d::ProcessGroupNCCL::getOptions)
2728 .def_property_readonly("uid", &::c10d::ProcessGroupNCCL::getUid)
2729 .def_property(
2730 "bound_device_id",
2731 &::c10d::ProcessGroupNCCL::getBoundDeviceId,
2732 &::c10d::ProcessGroupNCCL::setBoundDeviceId)
2733 .def(
2734 "perform_nocolor_split",
2735 &::c10d::ProcessGroupNCCL::performNocolorSplit);
2736
2737 module.def(
2738 "_get_intra_node_comm_usage_counter",
2739 &::c10d::intra_node_comm::getIntraNodeCommUsageCounter);
2740
2741 using IntraNodeComm = ::c10d::intra_node_comm::IntraNodeComm;
2742 py::class_<IntraNodeComm, c10::intrusive_ptr<IntraNodeComm>>(
2743 module, "_IntraNodeComm")
2744 .def(
2745 py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
2746 size_t rank,
2747 size_t world_size,
2748 std::optional<size_t> buffer_size) {
2749 auto comm = c10::make_intrusive<IntraNodeComm>(
2750 store, rank, world_size, buffer_size);
2751 if (!comm->rendezvous()) {
2752 throw std::runtime_error("IntraNodeComm::rendezvous failed");
2753 }
2754 return comm;
2755 }),
2756 py::arg("store"),
2757 py::arg("rank"),
2758 py::arg("world_size"),
2759 py::arg("buffer_size") = std::nullopt)
2760 .def("barrier", &IntraNodeComm::barrier, py::arg("ranks") = py::none());
2761
2762 #ifdef NCCL_HAS_COMM_CTA_CGA
2763 py::class_<ncclConfig_t>(
2764 processGroupNCCL,
2765 "NCCLConfig",
2766 R"(
2767 ncclConfig_t data type for configuring NCCL communicators.
2768 See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
2769 for details.
2770 )")
2771 .def(py::init<>())
2772 .def_readwrite("blocking", &ncclConfig_t::blocking)
2773 .def_readwrite("cga_cluster_size", &ncclConfig_t::cgaClusterSize)
2774 .def_readwrite("min_ctas", &ncclConfig_t::minCTAs)
2775 .def_readwrite("max_ctas", &ncclConfig_t::maxCTAs)
2776 #ifdef NCCL_HAS_COMM_SPLIT
2777 .def_readwrite("split_share", &ncclConfig_t::splitShare)
2778 #endif
2779 .def_property(
2780 "net_name",
2781 [](const ncclConfig_t& self) { return self.netName; },
2782 // Note: NCCL calls free on the netName pointer
2783 // when destroying the communicator. So memory
2784 // shouldn't leak because of allocation in strdup.
2785 [](ncclConfig_t& self, const char* tmp) {
2786 self.netName = strdup(tmp);
2787 });
2788 #endif
2789
2790 intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>(
2791 processGroupNCCL,
2792 "Options",
2793 processGroupOptions,
2794 R"(
2795 ProcessGroup options for the NCCL backend
2796
2797 Arguments:
    is_high_priority_stream (bool, optional): flag to enable/disable the
        process group picking up high-priority CUDA streams. This lets the
        CUDA driver prioritize NCCL kernels when there are compute kernels
        waiting. Default is False.
2802
2803 Attributes:
    config (NCCLConfig): configures NCCL communicators (only available for
2805 builds using NCCL 2.17+). This can be used to improve
2806 communication-computation overlap for NCCL kernels by tuning
2807 available parameters in the config. See
2808 https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
2809 for details.
2810
2811 Example::
2812 >>> import torch.distributed as dist
2813 >>>
2814 >>> nccl_options = dist.ProcessGroupNCCL.Options(is_high_priority_stream=True)
2815 >>> # For builds using NCCL 2.17+, configure communicators
2816 >>> nccl_options.config.cga_cluster_size = 2
2817 >>> nccl_options.config.max_ctas = 4
2818 >>> nccl_options.config.min_ctas = 2
2819 >>> nccl_options.config.split_share = 1
>>> # initialize an NCCL process group with the options just created
2821 >>> dist.init_process_group("nccl", pg_options=nccl_options)
2822 )")
2823 .def(py::init<bool>(), py::arg("is_high_priority_stream") = false)
2824 #ifdef NCCL_HAS_COMM_CTA_CGA
2825 .def_readwrite("config", &::c10d::ProcessGroupNCCL::Options::config)
2826 #endif
2827 .def_readwrite(
2828 "is_high_priority_stream",
2829 &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream)
2830 .def_readwrite(
2831 "split_from", &::c10d::ProcessGroupNCCL::Options::split_from)
2832 .def_readwrite(
2833 "split_color", &::c10d::ProcessGroupNCCL::Options::split_color)
2834 .def_readwrite(
2835 "global_ranks_in_group",
2836 &::c10d::ProcessGroupNCCL::Options::global_ranks_in_group)
2837 .def_readwrite(
2838 "group_name", &::c10d::ProcessGroupNCCL::Options::group_name);
2839 #endif
2840
2841 #ifdef USE_C10D_MPI
2842 auto processGroupMPI =
2843 intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupMPI>(
2844 module, "ProcessGroupMPI", backend);
2845
  // Define a static create function instead of a constructor, because this
  // function may return null. This happens if this process is not part of
  // a subgroup that is to be created.
2849 processGroupMPI.def_static(
2850 "create",
2851 [](std::vector<int> ranks) {
2852 return ::c10d::ProcessGroupMPI::createProcessGroupMPI(std::move(ranks));
2853 },
2854 py::call_guard<py::gil_scoped_release>());
2855 #endif
2856
2857 #ifdef USE_C10D_UCC
2858 auto processGroupUCC =
2859 intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupUCC>(
2860 module, "ProcessGroupUCC", backend)
2861 .def(
2862 py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
2863 int rank,
2864 int size,
2865 const std::chrono::milliseconds& timeout) {
2866 return c10::make_intrusive<::c10d::ProcessGroupUCC>(
2867 store, rank, size, timeout);
2868 }),
2869 py::arg("store"),
2870 py::arg("rank"),
2871 py::arg("size"),
2872 py::arg("timeout") = kProcessGroupDefaultTimeout,
2873 py::call_guard<py::gil_scoped_release>());
2874 #endif
2875
2876 py::enum_<::c10d::OpType>(module, "OpType")
2877 .value("BROADCAST", ::c10d::OpType::BROADCAST)
2878 .value("ALLREDUCE", ::c10d::OpType::ALLREDUCE)
2879 .value("ALLREDUCE_COALESCED", ::c10d::OpType::ALLREDUCE_COALESCED)
2880 .value("REDUCE", ::c10d::OpType::REDUCE)
2881 .value("ALLGATHER", ::c10d::OpType::ALLGATHER)
2882 .value("_ALLGATHER_BASE", ::c10d::OpType::_ALLGATHER_BASE)
2883 .value("ALLGATHER_COALESCED", ::c10d::OpType::ALLGATHER_COALESCED)
2884 .value("GATHER", ::c10d::OpType::GATHER)
2885 .value("SCATTER", ::c10d::OpType::SCATTER)
2886 .value("REDUCE_SCATTER", ::c10d::OpType::REDUCE_SCATTER)
2887 .value("ALLTOALL_BASE", ::c10d::OpType::ALLTOALL_BASE)
2888 .value("ALLTOALL", ::c10d::OpType::ALLTOALL)
2889 .value("SEND", ::c10d::OpType::SEND)
2890 .value("RECV", ::c10d::OpType::RECV)
2891 .value("RECVANYSOURCE", ::c10d::OpType::RECVANYSOURCE)
2892 .value("BARRIER", ::c10d::OpType::BARRIER)
2893 .value("_REDUCE_SCATTER_BASE", ::c10d::OpType::_REDUCE_SCATTER_BASE)
2894 .value("COALESCED", ::c10d::OpType::COALESCED)
2895 .value("_ALLREDUCE_SPARSE", ::c10d::OpType::_ALLREDUCE_SPARSE)
2896 .value("UNKNOWN", ::c10d::OpType::UNKNOWN);
2897
2898 py::class_<::c10d::WorkInfo, std::shared_ptr<::c10d::WorkInfo>>(
2899 module, "WorkInfo")
2900 .def_readonly("op_type", &::c10d::WorkInfo::opType)
2901 .def_readonly("seq", &::c10d::WorkInfo::seq)
2902 .def_readonly("time_started", &::c10d::WorkInfo::timeStarted)
2903 .def_readonly("time_finished", &::c10d::WorkInfo::timeFinished)
2904 .def_readonly("active_duration", &::c10d::WorkInfo::activeDuration);
2905
2906 py::class_<
2907 ::c10d::Work,
2908 c10::intrusive_ptr<::c10d::Work>,
2909 ::c10d::PyProcessGroup::PyWork>(module, "Work", R"(
2910 A `Work` object represents the handle to a pending asynchronous operation in
2911 PyTorch's distributed package. It is returned by non-blocking collective operations,
2912 such as `dist.all_reduce(tensor, async_op=True)`.
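
Example::
    >>> # A minimal sketch (assumes an initialized default process group
    >>> # and a ``tensor`` on the appropriate device).
    >>> work = dist.all_reduce(tensor, async_op=True)
    >>> work.wait()
    >>> # ``tensor`` now holds the reduced result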
2913 )")
2914 .def(py::init<>())
2915 .def("is_completed", &::c10d::Work::isCompleted)
2916 .def(
2917 "is_success",
2918 [](::c10d::Work& work) -> bool {
2919 TORCH_WARN_ONCE(
2920 fmt::format(kDeprecationWarning, "Work::is_success"));
2921 return work.isSuccess();
2922 })
2923 .def(
2924 "exception",
2925 [](::c10d::Work& work) -> std::exception_ptr {
2926 TORCH_WARN_ONCE(
2927 fmt::format(kDeprecationWarning, "Work::exception"));
2928 return work.exception();
2929 })
2930 .def(
2931 "source_rank",
2932 [](::c10d::Work& work) -> int {
2933 TORCH_WARN_ONCE(
2934 fmt::format(kDeprecationWarning, "Work::source_rank"));
2935 return work.sourceRank();
2936 })
2937 .def("_source_rank", &::c10d::Work::sourceRank)
2938 .def(
2939 "result",
2940 [](::c10d::Work& work) -> std::vector<at::Tensor> {
2941 // Deprecation reason:
2942 // Work.result() returns a vector of tensors. This signature is
2943 // problematic as some collectives may just return one tensor
2944 // (e.g all-reduce), while some others may return multiple
2945 // tensors (e.g. all-gather).
2946 // Deprecating work.result() would
2947 // also allow us to remove the `outputs_` field in the Work
2948 // class, avoiding an "artificial" reference to the tensors,
2949 // which could potentially hold up the tensors' memory.
2950 TORCH_WARN_ONCE(fmt::format(kDeprecationWarning, "Work::result"));
2951 return work.result();
2952 })
2953 .def(
2954 "synchronize",
2955 [](::c10d::Work& work) -> void {
2956 TORCH_WARN_ONCE(
2957 fmt::format(kDeprecationWarning, "Work::synchronize"));
2958 work.synchronize();
2959 })
2960 .def(
2961 "wait",
2962 &::c10d::Work::wait,
2963 py::arg("timeout") = kNoTimeout,
2964 py::call_guard<py::gil_scoped_release>())
2965 .def(
2966 "get_future",
2967 [](::c10d::Work& work) -> std::shared_ptr<jit::PythonFutureWrapper> {
2968 return std::make_shared<jit::PythonFutureWrapper>(work.getFuture());
2969 },
2970 R"(
2971 Returns:
2972 A ``torch.futures.Future`` object which is associated with the completion of
2973 the ``Work``. As an example, a future object can be retrieved
2974 by ``fut = process_group.allreduce(tensors).get_future()``.
2975
2976 Example::
    Below is an example of a simple allreduce DDP communication hook that uses the
    ``get_future`` API to retrieve a Future associated with the completion of
    ``allreduce``.
2980
>>> def allreduce(process_group: dist.ProcessGroup, bucket: dist.GradBucket) -> torch.futures.Future:
2982 >>> group_to_use = process_group if process_group is not None else torch.distributed.group.WORLD
2983 >>> tensor = bucket.buffer().div_(group_to_use.size())
2984 >>> return torch.distributed.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
2985 >>> ddp_model.register_comm_hook(state=None, hook=allreduce)
2986
2987 .. warning ::
    The ``get_future`` API supports the NCCL backend, and partially the GLOO and MPI backends
    (no support for peer-to-peer operations like send/recv); it returns a ``torch.futures.Future``.
2990
    In the example above, the ``allreduce`` work will be done on the GPU using the NCCL backend,
    and ``fut.wait()`` will return after synchronizing the appropriate NCCL streams
    with PyTorch's current device streams, ensuring asynchronous CUDA execution
    without waiting for the entire operation to complete on the GPU. Note that
2995 ``CUDAFuture`` does not support ``TORCH_NCCL_BLOCKING_WAIT`` flag or NCCL's ``barrier()``.
2996 In addition, if a callback function was added by ``fut.then()``, it will wait until
2997 ``WorkNCCL``'s NCCL streams synchronize with ``ProcessGroupNCCL``'s dedicated callback
2998 stream and invoke the callback inline after running the callback on the callback stream.
2999 ``fut.then()`` will return another ``CUDAFuture`` that holds the return value of the
3000 callback and a ``CUDAEvent`` that recorded the callback stream.
3001
3002 1. For CPU work, ``fut.done()`` returns true when work has been completed and value()
3003 tensors are ready.
2. For GPU work, ``fut.done()`` returns true only when the operation has been enqueued.
3. For mixed CPU-GPU work (e.g. sending GPU tensors with GLOO), ``fut.done()`` returns
   true when tensors have arrived on the respective nodes, but are not yet necessarily
   synchronized on the respective GPUs (similar to GPU work).
3008 )")
3009 .def(
3010 "_get_op_type",
3011 [](::c10d::Work& work) -> int {
3012 return static_cast<int>(work.retrieveOpType());
3013 })
3014 .def(
3015 "_get_duration",
3016 &::c10d::Work::getDuration,
3017 py::call_guard<py::gil_scoped_release>(),
3018 R"(
3019 Returns:
3020 Duration of the corresponding collective communication.
3021
3022 .. warning ::
    This API only works for the NCCL backend for now, and requires the
    TORCH_NCCL_ENABLE_TIMING environment variable to be set.
3025 )")
3026 .def(
3027 "boxed",
3028 [](c10::intrusive_ptr<::c10d::Work> self) {
3029 return torch::jit::toPyObject(c10::IValue(std::move(self)));
3030 })
3031 .def_static("unbox", [](py::object obj) {
3032 auto typePtr =
3033 torch::getCustomClass("__torch__.torch.classes.c10d.Work");
3034 auto ivalue = torch::jit::toIValue(std::move(obj), typePtr);
3035 return ivalue.toCustomClass<::c10d::Work>();
3036 });
3037
3038 auto fakeProcessGroup =
3039 intrusive_ptr_no_gil_destructor_class_<::c10d::FakeProcessGroup>(
3040 module, "FakeProcessGroup", backend)
3041 .def(py::init([](int rank, int size) {
3042 return c10::make_intrusive<::c10d::FakeProcessGroup>(rank, size);
3043 }));
3044
3045 py::class_<c10::DDPLoggingData>(module, "DDPLoggingData")
3046 .def(py::init<>())
3047 .def_readwrite("strs_map", &c10::DDPLoggingData::strs_map)
3048 .def_readwrite("ints_map", &c10::DDPLoggingData::ints_map);
3049
  module.def(
      "_compute_bucket_assignment_by_size",
      [](const std::vector<at::Tensor>& tensors,
         const std::vector<size_t>& bucket_size_limits,
         const std::vector<bool>& expect_sparse_gradient,
         const std::vector<int64_t>& tensor_indices,
         const std::optional<std::shared_ptr<::c10d::Logger>>& logger) {
        if (logger.has_value()) {
          std::weak_ptr<::c10d::Logger> logger_weakref = logger.value();
          return ::c10d::compute_bucket_assignment_by_size(
              tensors,
              bucket_size_limits,
              expect_sparse_gradient,
              tensor_indices,
              {logger_weakref});
        } else {
          return ::c10d::compute_bucket_assignment_by_size(
              tensors,
              bucket_size_limits,
              expect_sparse_gradient,
              tensor_indices,
              {});
        }
      },
      py::arg("tensors"),
      py::arg("bucket_size"),
      py::arg("expect_sparse_gradient") = std::vector<bool>(),
      py::arg("tensor_indices") = std::vector<int64_t>(),
      py::arg("logger") = std::optional<std::shared_ptr<::c10d::Logger>>{},
      py::call_guard<py::gil_scoped_release>());

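  // Cross-rank sanity check run when DDP is constructed: verifies that every
  // process sees model parameters with matching sizes before training starts,
  // surfacing mismatches early via the optional logger.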
  module.def(
      "_verify_params_across_processes",
      [](const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group,
         const std::vector<at::Tensor>& params,
         const std::optional<std::shared_ptr<::c10d::Logger>>& logger) {
        if (logger.has_value()) {
          std::weak_ptr<::c10d::Logger> logger_weakref = logger.value();
          verify_params_across_processes(
              process_group, params, {logger_weakref});
        } else {
          verify_params_across_processes(process_group, params, {});
        }
      },
      py::arg("process_group"),
      py::arg("params"),
      py::arg("logger") = std::optional<std::shared_ptr<::c10d::Logger>>{},
      py::call_guard<py::gil_scoped_release>());

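  // Used by DDP to sync model state: broadcasts the given tensors from the
  // source rank, packing them into coalesced buffers of at most buffer_size
  // bytes to reduce the number of broadcast calls.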
  module.def(
      "_broadcast_coalesced",
      // Define a lambda such that the pybind11 prototype can take a std::vector
      // for the tensor list argument, but still pass it to the underlying
      // function as a c10::ArrayRef.
      [](const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group,
         const std::vector<at::Tensor>& tensors,
         size_t buffer_size,
         int rank) {
        broadcast_coalesced(process_group, tensors, buffer_size, rank);
      },
      py::arg("process_group"),
      py::arg("tensors"),
      py::arg("buffer_size"),
      // The rank to broadcast the tensors from; it acts as the source of
      // truth.
      py::arg("src") = 0,
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_test_python_store",
      // Define a function that takes a c10d store and runs a few tests.
      // This is used by the PythonStore tests, which we cannot test from the
      // Python side of the world. Calling Python functions on a Python object
      // completely bypasses pybind11. We need to test that the overloaded
      // functions call into Python and behave like we expect.
      [](c10::intrusive_ptr<::c10d::Store> store) {
        auto add = [&store](const std::string& key, int64_t value) {
          store->add(key, value);
        };

        auto set = [&store](const std::string& key, const std::string& value) {
          store->set(key, value);
        };

        auto get = [&store](const std::string& key) {
          auto value = store->get(key);
          return std::string(value.begin(), value.end());
        };

        add("key", 1);
        add("key", 2);
        add("key", 3);
        set("key0", "value0");
        add("key3", 1);
        set("key1", "value1");
        add("key3", 2);
        set("key2", "value2");
        add("key3", 3);
        add("key3", 4);
        add("key3", 3);
        add("key3", 2);
        // "key" accumulated 1 + 2 + 3 == 6; "key3" accumulated
        // 1 + 2 + 3 + 4 + 3 + 2 == 15; the set() keys must be untouched.
        TORCH_CHECK(get("key") == "6", "assertion failed");
        TORCH_CHECK(get("key0") == "value0", "assertion failed");
        TORCH_CHECK(get("key1") == "value1", "assertion failed");
        TORCH_CHECK(get("key2") == "value2", "assertion failed");
        TORCH_CHECK(get("key3") == "15", "assertion failed");
      },
      py::call_guard<py::gil_scoped_release>());

  module.attr("_DEFAULT_FIRST_BUCKET_BYTES") = ::c10d::kDefaultFirstBucketBytes;
  module.attr("_DEFAULT_PG_TIMEOUT") = py::cast(kProcessGroupDefaultTimeout);
#ifdef USE_C10D_NCCL
  module.attr("_DEFAULT_PG_NCCL_TIMEOUT") =
      py::cast(::c10d::kProcessGroupNCCLDefaultTimeout);
#endif
  module.attr("_DEFAULT_NO_TIMEOUT") = py::cast(kNoTimeout);

  module.def(
      "_set_global_rank",
      [](int64_t rank) { c10::SetGlobalRank(rank); },
      py::arg("rank"),
      R"(
        Arguments:
            rank(int): The rank of the default process group.
        Informs the C++ runtime what the default process group (a strictly Python
        notion) is. This mostly ensures that C++ log messages are prefixed with
        rank information. This is not meant to be called manually; it is
        called by _update_default_pg.
      )");

  module.def(
      "_create_work_from_future",
      [](const std::shared_ptr<jit::PythonFutureWrapper>& future) {
        return ::c10d::Work::create_from_future(future->fut);
      },
      py::arg("future"),
      R"(
        Arguments:
            future(torch.futures.Future): The future to wrap.
        Returns:
            A ``Work`` object which is associated with the completion of
            the ``torch.futures.Future``.
        This is the preferred way of constructing Work objects when writing a custom ProcessGroup
        in Python.
        Example::
            >>> class SingleRankProcessGroup(torch.distributed.ProcessGroup):
            >>>     def broadcast(self, tensor_list, opts):
            >>>         fut = torch.futures.Future()
            >>>         fut.set_result(tensor_list)
            >>>         return torch._C._distributed_c10d._create_work_from_future(fut)
        .. warning ::
            This API is experimental and subject to change.
            The returned Work object has multiple limitations:
            - synchronize() does nothing. Use ``torch.futures.Future`` based synchronization.
            - wait() ignores the timeout argument.
            - sourceRank() raises.
            - abort() raises.
            The provided Future object result must be a Tensor or a list of Tensors.
      )");

#ifdef USE_C10D_NCCL
  module.def(
      "_hash_tensors",
      [](const std::vector<at::Tensor>& tensors) {
        return ::c10d::hashTensors(tensors);
      },
      py::arg("tensors"),
      R"(
        Arguments:
            tensors(List[torch.Tensor]): List of tensors we want to hash.
      )");
  module.def(
      "_dump_nccl_trace_json",
      [](std::optional<bool> includeCollectives,
         std::optional<bool> onlyActive) {
        return py::bytes(::c10d::dump_nccl_trace_json(
            includeCollectives.value_or(true), onlyActive.value_or(false)));
      },
      py::arg("includeCollectives") = std::optional<bool>(),
      py::arg("onlyActive") = std::optional<bool>(),
      R"(
        Arguments:
            includeCollectives(bool, optional): Whether to include collective work traces. Default is True.
            onlyActive(bool, optional): Whether to only include active collective work traces. Default is False.
        Returns:
            Stringified JSON work traces.
            Default settings return everything, i.e. the trace contains NCCL comm dumps and collective traces.
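
        Example (a minimal sketch; assumes an NCCL process group has been
        initialized so the flight recorder has something to dump)::
            >>> import json
            >>> trace = json.loads(torch._C._distributed_c10d._dump_nccl_trace_json())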
3246 )");
  module.def(
      "_dump_nccl_trace",
      [](std::optional<bool> includeCollectives,
         std::optional<bool> includeStackTraces,
         std::optional<bool> onlyActive) {
        return py::bytes(::c10d::dump_nccl_trace(
            includeCollectives.value_or(true),
            includeStackTraces.value_or(true),
            onlyActive.value_or(false)));
      },
      py::arg("includeCollectives") = std::optional<bool>(),
      py::arg("includeStackTraces") = std::optional<bool>(),
      py::arg("onlyActive") = std::optional<bool>(),
      R"(
        Arguments:
            includeCollectives(bool, optional): Whether to include collective work traces. Default is True.
            includeStackTraces(bool, optional): Whether to include stacktraces in the collective work traces. Default is True.
            onlyActive(bool, optional): Whether to only include active collective work traces. Default is False.
        Returns:
            Stringified pickle work traces.
            Default settings return everything, i.e. the trace contains NCCL comm dumps and collective traces.
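
        Example (a minimal sketch; the returned bytes are a pickled structure,
        assuming an initialized NCCL process group)::
            >>> import pickle
            >>> trace = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())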
3268 )");
3269 #endif
3270
  intrusive_ptr_class_<::c10d::control_plane::WorkerServer>(
      module, "_WorkerServer", R"(
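        See c10d::control_plane::WorkerServer for docs.

        Example (a minimal sketch; assumes ``host_or_file`` is interpreted as a
        Unix domain socket path when ``port`` is left at its default)::
            >>> server = torch._C._distributed_c10d._WorkerServer("/tmp/worker.sock")
            >>> server.shutdown()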
3273 )")
3274 .def(
3275 py::init([](const std::string& hostOrFile, int port) {
3276 return c10::make_intrusive<::c10d::control_plane::WorkerServer>(
3277 hostOrFile, port);
3278 }),
3279 py::arg("host_or_file"),
3280 py::arg("port") = -1)
3281 .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown);
3282
  module.def(
      "_get_handler",
      [](const std::string& name) -> py::cpp_function {
        return py::cpp_function(
            ::c10d::control_plane::getHandler(name),
            py::arg("request"),
            py::arg("response"),
            py::call_guard<py::gil_scoped_release>());
      },
      py::arg("name"),
      R"(
        Returns the handler with the specified name.
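
        Example (a minimal sketch; "ping" is assumed to be the name of a
        registered handler, see ``_get_handler_names`` for the actual list)::
            >>> handler = torch._C._distributed_c10d._get_handler("ping")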
3295 )");
3296
  module.def(
      "_get_handler_names",
      &::c10d::control_plane::getHandlerNames,
      R"(
        Returns the names of all handlers.
      )",
      py::call_guard<py::gil_scoped_release>());

  py::class_<::c10d::control_plane::Request, PythonRequest>(
      module,
      "_Request",
      R"(
        See c10d::control_plane::Request for docs.
      )")
      // Default constructor.
      .def(py::init<>())
      .def("body", &::c10d::control_plane::Request::body)
      .def("params", &::c10d::control_plane::Request::params);

  py::class_<
      ::c10d::control_plane::Response,
      std::shared_ptr<::c10d::control_plane::Response>,
      PythonResponse>(
      module,
      "_Response",
      R"(
        See c10d::control_plane::Response for docs.
      )")
      // Default constructor.
      .def(py::init<>())
      .def(
          "set_content",
          &::c10d::control_plane::Response::setContent,
          py::arg("content"),
          py::arg("content_type"))
      .def(
          "set_status",
          &::c10d::control_plane::Response::setStatus,
          py::arg("status"));

  Py_RETURN_TRUE;
}

#undef PROCESS_GROUP_DEPRECATION_WARNING

} // namespace

// c10d methods on torch._C
static PyMethodDef methods[] = { // NOLINT
    {"_c10d_init", c10d_init, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};

PyMethodDef* python_functions() {
  return methods;
}

} // namespace torch::distributed::c10d