#include <torch/csrc/python_headers.h>

#include <c10/util/intrusive_ptr.h>
#include <c10/util/string_view.h>
#include <torch/csrc/distributed/c10d/FileStore.hpp>
#include <torch/csrc/distributed/c10d/Functional.hpp>
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp>
#include <torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp>
#include <torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp>
#include <vector>
#ifndef _WIN32
#include <torch/csrc/distributed/c10d/HashStore.hpp>
#endif
#include <torch/csrc/distributed/c10d/FakeProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/PyProcessGroup.hpp>

#ifdef USE_C10D_GLOO
#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp>
#endif

#ifdef USE_C10D_NCCL
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
#endif

#ifdef USE_C10D_MPI
#include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>
#endif

#ifdef USE_C10D_UCC
#include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
#endif

#include <fmt/format.h>
#include <pybind11/chrono.h>
#include <torch/csrc/distributed/c10d/DMAConnectivity.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>

#include <torch/csrc/distributed/c10d/comm.hpp>
#include <torch/csrc/distributed/c10d/debug.h>
#include <torch/csrc/distributed/c10d/logger.hpp>
#include <torch/csrc/distributed/c10d/reducer.hpp>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/distributed/c10d/python_comm_hook.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>

#include <torch/custom_class.h>

namespace {

#ifdef USE_C10D_NCCL

bool acquire_gil() {
  // Basically, if this function can acquire the GIL, it will return quickly.
  // If not, it will hang forever.  The idea is to call this from a thread
  // wrapped in a future, and then check the future after a timeout, to
  // determine whether we're facing GIL contention.
  if (Py_IsInitialized()) {
    pybind11::gil_scoped_acquire gil;
    return true;
  }

  // If we end up here, it's probably still a "pass" from the perspective of
  // checking whether Python is stuck, but currently we don't check the return
  // value of this function anyway; we just check whether it returned quickly
  // vs timing out.  Taking a long time is the main sign of trouble.  A fast
  // return with either true or false is OK from the perspective of debugging
  // Python hangs.
  return false;
}

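// registerGilChecker() below installs acquire_gil() as the global GIL checker
// via c10d::get_gil_checker(); it is presumably what lets ProcessGroupNCCL's
// watchdog distinguish "collective is hung" from "Python is stuck on the GIL"
// when it diagnoses timeouts.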
bool registerGilChecker() {
  c10d::get_gil_checker() = &acquire_gil;
  return true;
}

static bool registered = registerGilChecker();
#endif // USE_C10D_NCCL

// Wrapper to ensure GIL is released before destructing ProcessGroupGloo
// TODO: move this somewhere more generally useful
template <typename T>
class IntrusivePtrNoGilDestructor {
  c10::intrusive_ptr<T> impl_{};

 public:
  IntrusivePtrNoGilDestructor() = default;
  IntrusivePtrNoGilDestructor(const IntrusivePtrNoGilDestructor&) = default;
  IntrusivePtrNoGilDestructor(IntrusivePtrNoGilDestructor&&) = default;
  IntrusivePtrNoGilDestructor& operator=(const IntrusivePtrNoGilDestructor&) =
      default;
  IntrusivePtrNoGilDestructor& operator=(IntrusivePtrNoGilDestructor&&) =
      default;
  /* implicit */ IntrusivePtrNoGilDestructor(c10::intrusive_ptr<T> impl)
      : impl_(std::move(impl)) {}
  // This ctor is very important; see
  // https://github.com/pybind/pybind11/issues/2957
  explicit IntrusivePtrNoGilDestructor(T* impl)
      : impl_(c10::intrusive_ptr<T>::unsafe_steal_from_new(impl)) {}
  ~IntrusivePtrNoGilDestructor() {
    if (impl_) {
      if (PyGILState_Check()) {
        pybind11::gil_scoped_release release;
        impl_.reset();
      } else {
        impl_.reset();
      }
    }
  }
  T& operator*() const noexcept {
    return *impl_;
  }
  T* operator->() const noexcept {
    return impl_.get();
  }
  C10_NODISCARD T* get() const noexcept {
    return impl_.get();
  }
  void reset() noexcept {
    impl_.reset();
  }
  operator bool() const noexcept {
    return impl_;
  }
};

} // anonymous namespace

PYBIND11_DECLARE_HOLDER_TYPE(T, IntrusivePtrNoGilDestructor<T>, true);
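
// The `true` argument tells pybind11 that this holder can always be
// constructed from a raw pointer, which is safe here because the explicit
// T* constructor above steals the reference created by `new`. Classes bound
// with this holder (via intrusive_ptr_no_gil_destructor_class_ below) get
// their destructor run without the GIL held, so a ProcessGroup teardown that
// blocks (e.g. joining worker threads) cannot deadlock with Python.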

namespace torch::distributed::c10d {

namespace {

py::bytes toPyBytes(const std::vector<uint8_t>& data) {
  return py::bytes(reinterpret_cast<const char*>(data.data()), data.size());
}

std::vector<py::bytes> toPyBytes(
    const std::vector<std::vector<uint8_t>>& data) {
  std::vector<py::bytes> out;
  out.reserve(data.size());
  for (const std::vector<uint8_t>& data_ : data) {
    out.emplace_back(reinterpret_cast<const char*>(data_.data()), data_.size());
  }
  return out;
}

std::vector<uint8_t> toVec8(const std::string& data) {
  std::vector<uint8_t> out{data.begin(), data.end()};
  return out;
}

std::vector<std::vector<uint8_t>> toVec8(const std::vector<std::string>& data) {
  std::vector<std::vector<uint8_t>> out;
  out.reserve(data.size());
  for (auto& data_ : data) {
    out.emplace_back(toVec8(data_));
  }
  return out;
}
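
// These helpers bridge the Store's C++ byte-vector API and the Python-facing
// bytes/str API: values cross the boundary as py::bytes (not str), since
// stored values are not guaranteed to be valid UTF-8.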

template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

constexpr auto kDeprecationWarning =
    "{} API is being deprecated, please ping "
    "https://github.com/pytorch/pytorch/issues/46291 "
    "if you see this warning";
template <typename T>
using intrusive_ptr_class_ = py::class_<T, c10::intrusive_ptr<T>>;

template <typename T>
using intrusive_ptr_no_gil_destructor_class_ =
    py::class_<T, IntrusivePtrNoGilDestructor<T>>;

// PythonStore is a pybind11 trampoline class to allow a Python
// class to inherit from c10d.Store and implement its interface.
class PythonStore : public ::c10d::Store {
 public:
  using ::c10d::Store::Store;

  // Note: this function manually calls the Python-side overload
  // for this function instead of using the PYBIND11_OVERLOAD_XYZ
  // macros. This is done so that we can call the Python-side
  // function with a std::string instead of a std::vector<uint8_t>.
  void set(const std::string& key, const std::vector<uint8_t>& value) override {
    pybind11::gil_scoped_acquire gil;
    pybind11::function fn =
        pybind11::get_overload(static_cast<const ::c10d::Store*>(this), "set");
    TORCH_INTERNAL_ASSERT(fn, "Not implemented.");
    // Call function with a py::bytes object for the value.
    fn(key, toPyBytes(value));
  }

  // Note: this function manually calls the Python-side overload
  // for this function instead of using the PYBIND11_OVERLOAD_XYZ
  // macros. This is done so that the Python-side function can
  // return a py::bytes instead of a std::vector<uint8_t>.
  std::vector<uint8_t> get(const std::string& key) override {
    pybind11::gil_scoped_acquire gil;
    pybind11::function fn =
        pybind11::get_overload(static_cast<const ::c10d::Store*>(this), "get");
    TORCH_INTERNAL_ASSERT(fn, "Not implemented.");
    // Cast return value from Python to py::bytes, then implicitly
    // convert that to a std::string, so that we can construct a
    // std::vector<uint8_t>. There is no API for directly accessing
    // the contents of the py::bytes object.
    std::string str = pybind11::cast<py::bytes>(fn(key));
    return toVec8(str);
  }

  // Note: this function manually calls the Python-side overload
  // for this function instead of using the PYBIND11_OVERLOAD_XYZ
  // macros. This is done so that the Python-side function can
  // return a py::bytes instead of a std::vector<uint8_t>.
  std::vector<uint8_t> compareSet(
      const std::string& key,
      const std::vector<uint8_t>& expectedValue,
      const std::vector<uint8_t>& desiredValue) override {
    pybind11::gil_scoped_acquire gil;
    pybind11::function fn = pybind11::get_overload(
        static_cast<const ::c10d::Store*>(this), "compare_set");
    TORCH_INTERNAL_ASSERT(fn, "Not implemented.");
    // Cast return value from Python to py::bytes, then implicitly
    // convert that to a std::string, so that we can construct a
    // std::vector<uint8_t>. There is no API for directly accessing
    // the contents of the py::bytes object.
    std::string str = pybind11::cast<py::bytes>(
        fn(key, toPyBytes(expectedValue), toPyBytes(desiredValue)));
    return toVec8(str);
  }

  int64_t add(const std::string& key, int64_t value) override {
    PYBIND11_OVERLOAD_PURE(int64_t, ::c10d::Store, add, key, value);
  }

  int64_t getNumKeys() override {
    PYBIND11_OVERLOAD_PURE(int64_t, ::c10d::Store, getNumKeys);
  }

  bool deleteKey(const std::string& key) override {
    PYBIND11_OVERLOAD_PURE(bool, ::c10d::Store, deleteKey, key);
  }

  bool check(const std::vector<std::string>& keys) override {
    PYBIND11_OVERLOAD_PURE(bool, ::c10d::Store, check, keys);
  }

  void wait(const std::vector<std::string>& keys) override {
    PYBIND11_OVERLOAD_PURE(void, ::c10d::Store, wait, keys);
  }

  void wait(
      const std::vector<std::string>& keys,
      const std::chrono::milliseconds& timeout) override {
    PYBIND11_OVERLOAD_PURE(void, ::c10d::Store, wait, keys, timeout);
  }

  // Note: this function manually calls the Python-side overload
  // for this function instead of using the PYBIND11_OVERLOAD_XYZ
  // macros. This is done so that we can call the Python-side
  // function with a std::string instead of a std::vector<uint8_t>.
  void append(const std::string& key, const std::vector<uint8_t>& value)
      override {
    pybind11::gil_scoped_acquire gil;
    pybind11::function fn = pybind11::get_overload(
        static_cast<const ::c10d::Store*>(this), "append");
    if (!fn) {
      return Store::append(key, value);
    }
    // Call function with a py::bytes object for the value.
    fn(key, toPyBytes(value));
  }

  std::vector<std::vector<uint8_t>> multiGet(
      const std::vector<std::string>& keys) override {
    pybind11::gil_scoped_acquire gil;
    pybind11::function fn = pybind11::get_overload(
        static_cast<const ::c10d::Store*>(this), "multi_get");
    if (!fn) {
      return Store::multiGet(keys);
    }
    std::vector<std::string> py_list =
        pybind11::cast<std::vector<std::string>>(fn(keys));
    std::vector<std::vector<uint8_t>> res;
    res.reserve(py_list.size());

    for (auto& str : py_list) {
      res.emplace_back(str.begin(), str.end());
    }

    return res;
  }

  void multiSet(
      const std::vector<std::string>& keys,
      const std::vector<std::vector<uint8_t>>& values) override {
    pybind11::gil_scoped_acquire gil;
    pybind11::function fn = pybind11::get_overload(
        static_cast<const ::c10d::Store*>(this), "multi_set");
    if (!fn) {
      return Store::multiSet(keys, values);
    }

    fn(keys, toPyBytes(values));
  }

  bool hasExtendedApi() const override {
    PYBIND11_OVERLOAD_NAME(
        bool, ::c10d::Store, "has_extended_api", hasExtendedApi);
  }
};
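
// A minimal sketch of the Python side this trampoline enables (``DictStore``
// is a hypothetical name, not part of the codebase). Note that overridden
// methods receive values as ``bytes`` and ``get`` must return ``bytes``:
//
//   >>> import torch.distributed as dist
//   >>> class DictStore(dist.Store):
//   ...     def __init__(self):
//   ...         super().__init__()
//   ...         self._data = {}
//   ...     def set(self, key, value):
//   ...         self._data[key] = value
//   ...     def get(self, key):
//   ...         return self._data[key]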

class PythonRequest : public ::c10d::control_plane::Request {
 public:
  const std::string& body() const override {
    PYBIND11_OVERRIDE_PURE(
        const std::string&, ::c10d::control_plane::Request, body);
  }

  const std::multimap<std::string, std::string>& params() const override {
    using MultiMap = const std::multimap<std::string, std::string>&;
    PYBIND11_OVERRIDE_PURE(MultiMap, ::c10d::control_plane::Request, params);
  }
};
class PythonResponse : public ::c10d::control_plane::Response {
 public:
  void setContent(std::string&& content, const std::string& content_type)
      override {
    PYBIND11_OVERRIDE_PURE_NAME(
        void,
        ::c10d::control_plane::Response,
        "set_content",
        setContent,
        content,
        content_type);
  }
  void setStatus(int status) override {
    PYBIND11_OVERRIDE_PURE_NAME(
        void, ::c10d::control_plane::Response, "set_status", setStatus, status);
  }
};

// Called from DDP's Python API to create a c10d Python comm hook object.
// The input state and callable comm_hook are Python objects. It then calls
// the register_comm_hook function of the given reducer to register the hook.
void _register_comm_hook(
    ::c10d::Reducer& reducer,
    py::object state,
    py::object comm_hook) {
  reducer.register_comm_hook(std::make_unique<::c10d::PythonCommHook>(
      std::move(state), std::move(comm_hook)));
}

// Called from DDP's Python API to create a c10d C++ comm hook.
// The input is an enum hook type. It then calls the register_builtin_comm_hook
// function of the given reducer to set the hook type.
void _register_builtin_comm_hook(
    ::c10d::Reducer& reducer,
    ::c10d::BuiltinCommHookType comm_hook_type) {
  reducer.register_builtin_comm_hook(comm_hook_type);
}
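
// For reference, a comm hook reaches this code through DDP's public
// ``register_comm_hook`` API. A minimal sketch (assuming ``ddp_model`` is an
// existing DistributedDataParallel instance):
//
//   >>> import torch.distributed as dist
//   >>> def allreduce_avg_hook(state, bucket):
//   ...     fut = dist.all_reduce(bucket.buffer(), async_op=True).get_future()
//   ...     return fut.then(
//   ...         lambda f: f.value()[0].div_(dist.get_world_size()))
//   >>> ddp_model.register_comm_hook(state=None, hook=allreduce_avg_hook)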

// Customize the metaclass of ::c10d::ReduceOp for backward compatibility.
// https://github.com/pytorch/pytorch/pull/84243 changed ::c10d::ReduceOp to a
// struct from an enum, sacrificing some of the Python built-in function
// support such as `isinstance` (see
// https://github.com/pytorch/pytorch/issues/87191) and `copy` (see
// https://github.com/pytorch/pytorch/pull/87303#discussion_r1002879700).
// Below, we define a custom `isinstance` in CPython/pybind11
// (`reduceopmeta___instancecheck__`) and modify the default metaclass of
// pybind11 (`GetReduceOpMetaclass`) so that
// `isinstance(torch.distributed.ReduceOp.SUM, torch.distributed.ReduceOp)`
// returns :obj:`True` as if `ReduceOp` were an enum.
// Ref:
//   - https://docs.python.org/3/extending/newtypes_tutorial.html
//   - https://docs.python.org/3/c-api/typeobj.html?highlight=tp_methods
//   - https://github.com/pybind/pybind11/issues/2696
static PyObject* reduceopmeta___instancecheck__(
    PyObject* self,
    PyObject* args) {
  if (Py_TYPE(self) == Py_TYPE(args)) {
    Py_RETURN_TRUE;
  }
  if (c10::string_view(args->ob_type->tp_name).find("RedOpType") !=
      c10::string_view::npos) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}
// NOLINTNEXTLINE(*c-arrays)
static PyMethodDef reduceopmeta_methods[] = {
    {"__instancecheck__",
     (PyCFunction)reduceopmeta___instancecheck__,
     METH_O,
     "Custom `__instancecheck__` for ReduceOp"},
    {nullptr, nullptr}};
PyTypeObject* GetReduceOpMetaclass() {
  static auto* metaclass = [] {
    PyTypeObject* base_metaclass =
        pybind11::detail::get_internals().default_metaclass;
    // NOLINTNEXTLINE(*c-arrays)
    PyType_Slot slots[] = {
        {Py_tp_base, base_metaclass},
        {Py_tp_methods, reduceopmeta_methods},
        {0},
    };
    PyType_Spec spec = {};
    spec.name = "torch._C._distributed_c10d._ReduceOpMeta";
    // NOLINTNEXTLINE(*-narrowing-conversions)
    spec.basicsize = base_metaclass->tp_basicsize;
    spec.flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    spec.slots = slots;
    PyTypeObject* metaclass = (PyTypeObject*)PyType_FromSpec(&spec);
    if (!metaclass)
      throw py::error_already_set();
    return metaclass;
  }();
  return metaclass;
}
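
// The behavior the custom metaclass restores, as seen from Python:
//
//   >>> import torch.distributed as dist
//   >>> isinstance(dist.ReduceOp.SUM, dist.ReduceOp)
//   True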

PyObject* c10d_init(PyObject* _unused, PyObject* noargs) {
  C10_LOG_API_USAGE_ONCE("c10d.python.import");

  auto c10d_module = THPObjectPtr(PyImport_ImportModule("torch.distributed"));
  if (!c10d_module) {
    throw python_error();
  }

  auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
  if (!torch_C_module) {
    throw python_error();
  }

  auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
  auto m =
      torch_C_m.def_submodule("_distributed_c10d", "distributed c10d bindings");

  auto module = py::handle(m).cast<py::module>();

  module
      .def(
          "_register_comm_hook",
          &_register_comm_hook,
          py::arg("reducer"),
          py::arg("state"),
          py::arg("comm_hook"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_register_builtin_comm_hook",
          &_register_builtin_comm_hook,
          py::arg("reducer"),
          py::arg("comm_hook_type"));

  shared_ptr_class_<::c10d::GradBucket>(
      module,
      "GradBucket",
      R"(
This class mainly passes a flattened gradient tensor
(returned by :meth:`~torch.distributed.GradBucket.buffer`)
to a DDP communication hook.
This tensor can be further decomposed into a list of per-parameter tensors within this bucket
(returned by :meth:`~torch.distributed.GradBucket.get_per_parameter_tensors`)
to apply layer-wise operations.
)")
      .def(
          "index",
          &::c10d::GradBucket::getIndex,
          py::call_guard<py::gil_scoped_release>(),
          R"(
.. warning::
    Since the buckets are rebuilt after the first iteration, one should not rely on the indices obtained at the beginning of training.

Returns:
    The index of a bucket that stores gradients of a few contiguous layers.
    All the gradients are bucketized.
)")
      .def(
          "buffer",
          &::c10d::GradBucket::getBuffer,
          py::call_guard<py::gil_scoped_release>(),
          R"(
Returns:
    A flattened 1D ``torch.Tensor`` buffer,
    which can be further decomposed into a list of per-parameter tensors within this bucket.
)")
      .def(
          "gradients",
          &::c10d::GradBucket::getGradients,
          py::call_guard<py::gil_scoped_release>(),
          R"(
Returns:
    A list of ``torch.Tensor``. Each tensor in the list corresponds to a gradient.
)")
      .def(
          "parameters",
          &::c10d::GradBucket::getParameters,
          py::call_guard<py::gil_scoped_release>(),
          R"(
Returns:
    A list of ``torch.Tensor``. Each tensor in the list corresponds to a model
    parameter.
)")
      .def(
          "is_last",
          &::c10d::GradBucket::isLast,
          py::call_guard<py::gil_scoped_release>(),
          R"(
Returns:
    Whether this bucket is the last bucket to allreduce in an iteration.
    This also means that this bucket corresponds to the first few layers in the forward pass.
)")
      .def(
          "set_buffer",
          &::c10d::GradBucket::setBuffer,
          py::arg("buffer"),
          py::call_guard<py::gil_scoped_release>(),
          R"(
Replaces the tensor in the bucket with the input tensor buffer.
)");

  py::enum_<::c10d::BuiltinCommHookType>(module, "BuiltinCommHookType", R"(
An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_COMPRESS``.)")
      .value("ALLREDUCE", ::c10d::BuiltinCommHookType::ALLREDUCE)
      .value("FP16_COMPRESS", ::c10d::BuiltinCommHookType::FP16_COMPRESS);

  shared_ptr_class_<::c10d::Reducer>(module, "Reducer")
      .def(
          py::init<
              std::vector<at::Tensor>,
              std::vector<std::vector<size_t>>,
              std::vector<size_t>,
              c10::intrusive_ptr<::c10d::ProcessGroup>,
              std::vector<bool>,
              int64_t,
              bool,
              bool,
              std::unordered_map<size_t, std::string>,
              int64_t>(),
          py::arg("params"),
          py::arg("bucket_indices"),
          py::arg("per_bucket_size_limits"),
          py::arg("process_group"),
          py::arg("expect_sparse_gradients") = std::vector<bool>(),
          py::arg("bucket_bytes_cap") = ::c10d::kDefaultBucketBytesCap,
          py::arg("find_unused_parameters") = false,
          py::arg("gradient_as_bucket_view") = false,
          py::arg("param_to_name_mapping") =
              std::unordered_map<size_t, std::string>(),
          py::arg("first_bucket_bytes_cap") = ::c10d::kDefaultFirstBucketBytes,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "prepare_for_forward",
          &::c10d::Reducer::prepare_for_forward,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "prepare_for_backward",
          &::c10d::Reducer::prepare_for_backward,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "prepare_for_backward",
          [](::c10d::Reducer& reducer, const at::Tensor& output) -> void {
            reducer.prepare_for_backward({output});
          },
          py::call_guard<py::gil_scoped_release>())
      .def("get_backward_stats", &::c10d::Reducer::get_backward_stats)
      .def(
          "_install_post_backward_futures",
          [](::c10d::Reducer& reducer,
             const std::vector<std::shared_ptr<jit::PythonFutureWrapper>>&
                 futs) {
            c10::List<c10::intrusive_ptr<c10::ivalue::Future>> futures(
                c10::FutureType::create(c10::TensorType::get()));
            for (const auto& fut : futs) {
              futures.push_back(fut->fut);
            }
            reducer.install_futures(futures);
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_rebuild_buckets",
          &::c10d::Reducer::rebuild_buckets,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_zeros_like_grad_buckets",
          [](::c10d::Reducer& reducer) {
            return reducer.get_grad_buckets(/* return_zero_tensors */ true);
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_optimizer_in_backward",
          [](::c10d::Reducer& reducer) { reducer.set_optimizer_in_backward(); },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_sparse_metadata",
          &::c10d::Reducer::setSparseMetadata,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_mixed_precision_param_dtype",
          [](::c10d::Reducer& reducer, py::object data_type_obj) {
            auto scalar_type =
                reinterpret_cast<THPDtype*>(data_type_obj.ptr())->scalar_type;
            reducer.set_mixed_precision_param_dtype(scalar_type);
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_push_all_rebuilt_params",
          &::c10d::Reducer::push_rebuilt_params_for_all_indices,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_forward_pass_work_handle",
          &::c10d::Reducer::set_forward_pass_work_handle,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_local_used_map", &::c10d::Reducer::get_local_used_map_on_device)
      .def(
          "_set_ddp_runtime_logging_sample_rate",
          &::c10d::Reducer::set_ddp_runtime_logging_sample_rate,
          py::arg("sample_rate"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_static_graph",
          &::c10d::Reducer::set_static_graph,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_ddp_graph_static",
          &::c10d::Reducer::ddp_graph_static,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_delay_all_reduce",
          &::c10d::Reducer::delay_all_reduce,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_run_comm_hook",
          [](::c10d::Reducer& reducer, ::c10d::GradBucket& bucket)
              -> std::shared_ptr<jit::PythonFutureWrapper> {
            c10::intrusive_ptr<c10::ivalue::Future> fut =
                reducer.run_comm_hook(bucket);
            return std::make_shared<jit::PythonFutureWrapper>(fut);
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_run_allreduce_hook",
          [](::c10d::Reducer& reducer, ::c10d::GradBucket& bucket)
              -> std::shared_ptr<jit::PythonFutureWrapper> {
            c10::intrusive_ptr<c10::ivalue::Future> fut =
                reducer.run_allreduce_hook(bucket);
            return std::make_shared<jit::PythonFutureWrapper>(fut);
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_autograd_hook",
          [](::c10d::Reducer& reducer, int index) -> void {
            reducer.autograd_hook(index);
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "set_logger",
          [](::c10d::Reducer& reducer,
             const std::shared_ptr<::c10d::Logger>& logger) {
            std::weak_ptr<::c10d::Logger> logger_weakref = logger;
            reducer.set_logger(logger_weakref);
          })
      .def(
          "_remove_autograd_hooks",
          [](::c10d::Reducer& reducer) { reducer.remove_autograd_hooks(); },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_check_reducer_finalized",
          [](::c10d::Reducer& reducer) { return reducer.check_finalized(); },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_reset_state",
          [](::c10d::Reducer& reducer) { return reducer.reset_state(); },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_update_process_group",
          [](::c10d::Reducer& reducer,
             c10::intrusive_ptr<::c10d::ProcessGroup> new_process_group) {
            return reducer.update_process_group(std::move(new_process_group));
          },
          py::call_guard<py::gil_scoped_release>());

  shared_ptr_class_<::c10d::Logger>(module, "Logger")
      .def(
          py::init<std::shared_ptr<::c10d::Reducer>>(),
          py::arg("reducer"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "set_construction_data_and_log",
          &::c10d::Logger::set_construction_data_and_log,
          py::arg("module_name"),
          py::arg("device_ids"),
          py::arg("output_device"),
          py::arg("broadcast_buffers"),
          py::arg("has_sync_bn"),
          py::arg("static_graph"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "set_runtime_stats_and_log",
          &::c10d::Logger::set_runtime_stats_and_log,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "set_error_and_log",
          [](::c10d::Logger& logger, const std::string& error) {
            logger.set_error_and_log(error);
          },
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_ddp_logging_data",
          &::c10d::Logger::get_ddp_logging_data,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_comm_hook_name",
          &::c10d::Logger::set_comm_hook,
          py::arg("comm_hook"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_uneven_input_join",
          &::c10d::Logger::set_uneven_input_join,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_set_static_graph",
          &::c10d::Logger::set_static_graph,
          py::call_guard<py::gil_scoped_release>());

  py::enum_<::c10d::DebugLevel>(module, "DebugLevel", R"(
      An enum whose values correspond to different debug levels of the
      torch.distributed package. Currently supporting OFF, INFO, and DETAIL,
      which can be set via the ``TORCH_DISTRIBUTED_DEBUG`` environment variable
      or via the ``set_debug_level()`` function.
  )")
      .value("OFF", ::c10d::DebugLevel::Off)
      .value("INFO", ::c10d::DebugLevel::Info)
      .value("DETAIL", ::c10d::DebugLevel::Detail);

  module
      .def(
          "get_debug_level",
          ::c10d::debug_level,
          R"(Gets the debug level of the torch.distributed package.)")
      .def(
          "set_debug_level",
          ::c10d::setDebugLevel,
          R"(Sets the debug level of the torch.distributed package.)")
      .def(
          "set_debug_level_from_env",
          ::c10d::setDebugLevelFromEnvironment,
          R"(Sets the debug level of the torch.distributed package from the
          ``TORCH_DISTRIBUTED_DEBUG`` environment variable.)");
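
  // A sketch of typical usage (assuming the usual re-exports at the
  // ``torch.distributed`` top level):
  //
  //   >>> import torch.distributed as dist
  //   >>> dist.set_debug_level(dist.DebugLevel.DETAIL)
  //   >>> dist.get_debug_level() == dist.DebugLevel.DETAIL
  //   True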

  // TODO(crcrpar): Hardening `ReduceOp`.
  //    While keeping most op types as enum value,
  //    making `PREMUL_SUM` callable, i.e., allowing for
  //    `ReduceOp.PREMUL_SUM(scale)` might be better as per @wanchaol.
  // https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types
  py::class_<::c10d::ReduceOp> reduce_op(
      module, "ReduceOp", py::metaclass((PyObject*)GetReduceOpMetaclass()), R"(
An enum-like class for available reduction operations: ``SUM``, ``PRODUCT``,
``MIN``, ``MAX``, ``BAND``, ``BOR``, ``BXOR``, and ``PREMUL_SUM``.

``BAND``, ``BOR``, and ``BXOR`` reductions are not available when
using the ``NCCL`` backend.

``AVG`` divides values by the world size before summing across ranks.
``AVG`` is only available with the ``NCCL`` backend,
and only for NCCL versions 2.10 or later.

``PREMUL_SUM`` multiplies inputs by a given scalar locally before reduction.
``PREMUL_SUM`` is only available with the ``NCCL`` backend,
and only available for NCCL versions 2.11 or later. Users are supposed to
use ``torch.distributed._make_nccl_premul_sum``.

Additionally, ``MAX``, ``MIN`` and ``PRODUCT`` are not supported for complex tensors.

The values of this class can be accessed as attributes, e.g., ``ReduceOp.SUM``.
They are used in specifying strategies for reduction collectives, e.g.,
:func:`reduce`.

This class does not support the ``__members__`` property.)");

  reduce_op.def(py::init<::c10d::ReduceOp::RedOpType>())
      .def_readwrite("op", &::c10d::ReduceOp::op_);
  // The following are for some kind of backward compatibility.
  // Since c10d::ReduceOp had been an `enum class`, users could compare and
  // take the hash of a `::c10d::ReduceOp`. To avoid losing this functionality,
  // here I define some member methods.
  reduce_op
      // todo(crcrpar): Support `RedOpType == ReduceOp`.
      .def(
          // This calls `operator==(const ReduceOp::RedOpType)`
          "__eq__",
          [](const ::c10d::ReduceOp& self,
             const ::c10d::ReduceOp::RedOpType& other) {
            return self == other;
          })
      .def(
          // This calls `operator==(const ReduceOp)` for the future support of
          // `PREMUL_SUM` comparison
          "__eq__",
          [](const ::c10d::ReduceOp& self, const ::c10d::ReduceOp& other) {
            return self == other;
          })
      .def(
          // With the above custom `__eq__`'s, I have to manually support the
          // other types.
          "__eq__",
          // NOLINTNEXTLINE(performance-unnecessary-value-param)
          [](const ::c10d::ReduceOp& self, py::object) { return false; })
      .def(
          "__hash__",
          [](const ::c10d::ReduceOp& self) {
            return static_cast<uint8_t>(self.op_);
          })
      .def(
          "__copy__",
          [](const ::c10d::ReduceOp& self) { return ::c10d::ReduceOp(self); })
      .def(
          "__deepcopy__",
          [](const ::c10d::ReduceOp& self, const py::dict& memo) {
            return ::c10d::ReduceOp(self);
          })
      .def(py::pickle(
          [](const ::c10d::ReduceOp& r) {
            // __getstate__
            if (r.op_ != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) {
              return py::make_tuple(r.op_, py::none());
            }
            TORCH_CHECK(r.supplement_.defined(), "Invalid PREMUL_SUM ReduceOp");
            const auto* preMulSupplement =
                reinterpret_cast<::c10d::NCCLPreMulSumSupplement*>(
                    r.supplement_.get());
            if (!preMulSupplement->tensor_factor.defined()) {
              return py::make_tuple(r.op_, preMulSupplement->double_factor);
            } else {
              return py::make_tuple(r.op_, preMulSupplement->tensor_factor);
            }
          },
          [](const py::tuple& t) {
            // __setstate__
            TORCH_CHECK(t.size() == 2, "Invalid state");
            const auto op =
                static_cast<::c10d::ReduceOp::RedOpType>(t[0].cast<uint8_t>());
            if (op != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) {
              return ::c10d::ReduceOp(op);
            }
            const auto preMulSupplement_factor = t[1];
            if (py::isinstance<py::float_>(preMulSupplement_factor)) {
              return ::c10d::makeNCCLPreMulSum(t[1].cast<double>());
            } else {
              return ::c10d::makeNCCLPreMulSum(t[1].cast<at::Tensor>());
            }
          }));

  py::enum_<::c10d::ReduceOp::RedOpType>(reduce_op, "RedOpType")
      .value("SUM", ::c10d::ReduceOp::RedOpType::SUM)
      .value("AVG", ::c10d::ReduceOp::RedOpType::AVG)
      .value("PRODUCT", ::c10d::ReduceOp::RedOpType::PRODUCT)
      .value("MIN", ::c10d::ReduceOp::RedOpType::MIN)
      .value("MAX", ::c10d::ReduceOp::RedOpType::MAX)
      .value("BAND", ::c10d::ReduceOp::RedOpType::BAND)
      .value("BOR", ::c10d::ReduceOp::RedOpType::BOR)
      .value("BXOR", ::c10d::ReduceOp::RedOpType::BXOR)
      .value("PREMUL_SUM", ::c10d::ReduceOp::RedOpType::PREMUL_SUM)
      .export_values();

  // note(crcrpar): This could be removed because users will not pass
  // `RedOpType` to reduce collective ops.
  // Ref: [Implicit conversions](https://pybind11.readthedocs.io/en/stable/advanced/classes.html#implicit-conversions)
  // This lets us skip the explicit construction of `c10d::ReduceOp` from
  // `c10d::ReduceOp::RedOpType` in Python.
  py::implicitly_convertible<::c10d::ReduceOp::RedOpType, ::c10d::ReduceOp>();
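
  // What the compatibility shims above preserve, seen from Python (a sketch;
  // behavior as implied by the bindings above):
  //
  //   >>> import copy
  //   >>> import torch.distributed as dist
  //   >>> op = dist.ReduceOp(dist.ReduceOp.SUM)  # explicit construction
  //   >>> op == dist.ReduceOp.SUM                # RedOpType overload of __eq__
  //   True
  //   >>> op == "SUM"                            # other types compare False
  //   False
  //   >>> copy.deepcopy(op).op == op.op          # __deepcopy__ keeps the op
  //   True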

  module
      .def(
          "_make_nccl_premul_sum",
          &::c10d::makeNCCLPreMulSum<double>,
          py::arg("factor").noconvert(),
          py::return_value_policy::copy, // seems safest
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_make_nccl_premul_sum",
          &::c10d::makeNCCLPreMulSum<at::Tensor>,
          py::arg("factor").noconvert(),
          py::return_value_policy::copy, // seems safest
          py::call_guard<py::gil_scoped_release>());
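
  // Per the ReduceOp docstring above, the intended entry point is
  // ``torch.distributed._make_nccl_premul_sum`` (NCCL >= 2.11), e.g.:
  //
  //   >>> import torch.distributed as dist
  //   >>> premul_sum = dist._make_nccl_premul_sum(0.5)  # scale inputs by 0.5
  //   >>> # premul_sum can then be passed as the op of a reduction collective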

  module.def(
      "_set_thread_isolation_mode",
      &::c10d::set_thread_isolation_mode,
      py::arg("enable"));

  // Bindings for GroupRegistry.hpp
  //
  // Register a process group in the native registry. Process groups registered
  // via `_register_process_group` can be resolved from both Python and C++.
  module.def(
      "_register_process_group",
      [](const std::string& group_name,
         c10::intrusive_ptr<::c10d::ProcessGroup> group) {
        ::c10d::register_process_group(group_name, std::move(group));
      },
      py::arg("group_name"),
      py::arg("group"));

  // Resolve a process group from the native registry
  module.def(
      "_resolve_process_group",
      [](const std::string& group_name) {
        return ::c10d::resolve_process_group(group_name);
      },
      py::arg("group_name"));

  module.def(
      "_register_work",
      [](const at::Tensor& tensor,
         const c10::intrusive_ptr<::c10d::Work>& work) {
        dynamic_cast<::c10d::PyProcessGroup::PyWork*>(work.get())
            ->ref_py_object();
        ::c10d::register_work(tensor, std::move(work));
      },
      py::arg("tensor"),
      py::arg("work"));

  // Remove a group from the native registry
  module.def(
      "_unregister_process_group",
      [](const std::string& group_name) {
        return ::c10d::unregister_process_group(group_name);
      },
      py::arg("group_name"));

  // Remove all process groups from the native registry
  module.def("_unregister_all_process_groups", []() {
    return ::c10d::unregister_all_process_groups();
  });
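
  // A sketch of the registry round trip from Python (assuming ``pg`` is an
  // existing ProcessGroup; these are private APIs):
  //
  //   >>> from torch._C._distributed_c10d import (
  //   ...     _register_process_group, _resolve_process_group,
  //   ...     _unregister_process_group)
  //   >>> _register_process_group("my_group", pg)
  //   >>> _resolve_process_group("my_group")  # now visible to Python and C++
  //   >>> _unregister_process_group("my_group")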

  py::class_<::c10d::BroadcastOptions>(module, "BroadcastOptions")
      .def(py::init<>())
      .def_readwrite("rootRank", &::c10d::BroadcastOptions::rootRank)
      .def_readwrite("rootTensor", &::c10d::BroadcastOptions::rootTensor)
      .def_readwrite("timeout", &::c10d::BroadcastOptions::timeout)
      .def_readwrite("asyncOp", &::c10d::BroadcastOptions::asyncOp);

  py::class_<::c10d::AllreduceOptions>(module, "AllreduceOptions")
      .def(py::init<>())
      .def_readwrite("reduceOp", &::c10d::AllreduceOptions::reduceOp)
      .def_readwrite("timeout", &::c10d::AllreduceOptions::timeout);

  py::class_<::c10d::AllreduceCoalescedOptions>(
      module, "AllreduceCoalescedOptions")
      .def(py::init<>())
      .def_readwrite("reduceOp", &::c10d::AllreduceCoalescedOptions::reduceOp)
      .def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout);

  py::class_<::c10d::ReduceOptions>(module, "ReduceOptions")
      .def(py::init<>())
      .def_readwrite("reduceOp", &::c10d::ReduceOptions::reduceOp)
      .def_readwrite("rootRank", &::c10d::ReduceOptions::rootRank)
      .def_readwrite("rootTensor", &::c10d::ReduceOptions::rootTensor)
      .def_readwrite("timeout", &::c10d::ReduceOptions::timeout);

  py::class_<::c10d::AllgatherOptions>(module, "AllgatherOptions")
      .def(py::init<>())
      .def_readwrite("timeout", &::c10d::AllgatherOptions::timeout)
      .def_readwrite("asyncOp", &::c10d::AllgatherOptions::asyncOp);

  py::class_<::c10d::GatherOptions>(module, "GatherOptions")
      .def(py::init<>())
      .def_readwrite("rootRank", &::c10d::GatherOptions::rootRank)
      .def_readwrite("timeout", &::c10d::GatherOptions::timeout);

  py::class_<::c10d::ScatterOptions>(module, "ScatterOptions")
      .def(py::init<>())
      .def_readwrite("rootRank", &::c10d::ScatterOptions::rootRank)
      .def_readwrite("timeout", &::c10d::ScatterOptions::timeout)
      .def_readwrite("asyncOp", &::c10d::ScatterOptions::asyncOp);

  py::class_<::c10d::ReduceScatterOptions>(module, "ReduceScatterOptions")
      .def(py::init<>())
      .def_readwrite("reduceOp", &::c10d::ReduceScatterOptions::reduceOp)
      .def_readwrite("timeout", &::c10d::ReduceScatterOptions::timeout)
      .def_readwrite("asyncOp", &::c10d::ReduceScatterOptions::asyncOp);

  py::class_<::c10d::BarrierOptions>(module, "BarrierOptions")
      .def(py::init<>())
      .def_readwrite("device_ids", &::c10d::BarrierOptions::device_ids)
      .def_readwrite("timeout", &::c10d::BarrierOptions::timeout)
      .def_readwrite("device", &::c10d::BarrierOptions::device);

  py::class_<::c10d::AllToAllOptions>(module, "AllToAllOptions")
      .def(py::init<>())
      .def_readwrite("timeout", &::c10d::AllToAllOptions::timeout);
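
  // These option structs are plain data holders passed to the corresponding
  // ProcessGroup collectives. A sketch (assuming ``pg`` is a ProcessGroup and
  // ``t`` a tensor; the classes are importable from torch._C._distributed_c10d):
  //
  //   >>> from datetime import timedelta
  //   >>> import torch.distributed as dist
  //   >>> from torch._C._distributed_c10d import AllreduceOptions
  //   >>> opts = AllreduceOptions()
  //   >>> opts.reduceOp = dist.ReduceOp.SUM
  //   >>> opts.timeout = timedelta(seconds=60)
  //   >>> work = pg.allreduce([t], opts)
  //   >>> work.wait()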

  py::class_<::c10d::DistributedBackendOptions>(
      module, "_DistributedBackendOptions")
      .def(py::init<>())
      .def_readwrite("store", &::c10d::DistributedBackendOptions::store)
      .def_readwrite(
          "group_rank", &::c10d::DistributedBackendOptions::group_rank)
      .def_readwrite(
          "group_size", &::c10d::DistributedBackendOptions::group_size)
      .def_readwrite("timeout", &::c10d::DistributedBackendOptions::timeout)
      .def_readwrite("group_id", &::c10d::DistributedBackendOptions::group_id)
      .def_readwrite(
          "global_ranks_in_group",
          &::c10d::DistributedBackendOptions::global_ranks_in_group);

  py::class_<
      ::c10d::DMAConnectivity,
      c10::intrusive_ptr<::c10d::DMAConnectivity>>(module, "_DMAConnectivity")
      .def_readonly("device_type", &::c10d::DMAConnectivity::device_type)
      .def_readonly(
          "connection_type", &::c10d::DMAConnectivity::connection_type)
      .def_readonly("matrix", &::c10d::DMAConnectivity::matrix);

  module.def("_detect_dma_connectivity", ::c10d::detect_dma_connectivity);

  using SymmetricMemory = ::c10d::symmetric_memory::SymmetricMemory;
  py::class_<SymmetricMemory, c10::intrusive_ptr<SymmetricMemory>>(
      module, "_SymmetricMemory")
      .def_static("set_group_info", &::c10d::symmetric_memory::set_group_info)
      .def_static(
          "empty_strided_p2p",
          ::c10d::symmetric_memory::empty_strided_p2p,
          py::arg("size"),
          py::arg("stride"),
          py::arg("dtype"),
          py::arg("device"),
          py::arg("group_name"),
          py::arg("alloc_id") = py::none())
      .def_static("rendezvous", &::c10d::symmetric_memory::rendezvous)
      .def_static(
          "get_symmetric_memory",
          &::c10d::symmetric_memory::get_symmetric_memory)
      .def_static(
          "has_multicast_support",
          &::c10d::symmetric_memory::has_multicast_support)
      .def_property_readonly("rank", &SymmetricMemory::get_rank)
      .def_property_readonly("world_size", &SymmetricMemory::get_world_size)
      .def_property_readonly(
          "buffer_ptrs_dev",
          [](const c10::intrusive_ptr<SymmetricMemory>& symm_mem) {
            return reinterpret_cast<uintptr_t>(symm_mem->get_buffer_ptrs_dev());
          })
      .def_property_readonly(
          "signal_pad_ptrs_dev",
          [](const c10::intrusive_ptr<SymmetricMemory>& symm_mem) {
            return reinterpret_cast<uintptr_t>(
                symm_mem->get_signal_pad_ptrs_dev());
          })
      .def_property_readonly("buffer_size", &SymmetricMemory::get_buffer_size)
      .def_property_readonly(
          "signal_pad_size", &SymmetricMemory::get_signal_pad_size)
      .def(
          "get_buffer",
          &SymmetricMemory::get_buffer,
          py::arg("rank"),
          py::arg("sizes"),
          py::arg("dtype"),
          py::arg("storage_offset") = 0)
      .def("barrier", &SymmetricMemory::barrier, py::arg("channel") = 0)
      .def(
          "put_signal",
          &SymmetricMemory::put_signal,
          py::arg("dst_rank"),
          py::arg("channel") = 0)
      .def(
          "wait_signal",
          &SymmetricMemory::wait_signal,
          py::arg("src_rank"),
          py::arg("channel") = 0);

  auto store =
      py::class_<::c10d::Store, c10::intrusive_ptr<::c10d::Store>, PythonStore>(
          module,
          "Store",
          R"(
Base class for all store implementations, such as the 3 provided by PyTorch
distributed: :class:`~torch.distributed.TCPStore`, :class:`~torch.distributed.FileStore`,
and :class:`~torch.distributed.HashStore`.
)")
          // Default constructor.
          .def(py::init<>())
          // Convert from std::string to std::vector<uint8>.
          .def(
              "set",
              [](::c10d::Store& store,
                 const std::string& key,
                 const std::string& value) { store.set(key, toVec8(value)); },
              py::call_guard<py::gil_scoped_release>(),
              R"(
Inserts the key-value pair into the store based on the supplied ``key`` and
``value``. If ``key`` already exists in the store, it will overwrite the old
value with the new supplied ``value``.

Arguments:
    key (str): The key to be added to the store.
    value (str): The value associated with ``key`` to be added to the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.set("first_key", "first_value")
    >>> # Should return "first_value"
    >>> store.get("first_key")
)")
          .def(
              "compare_set",
              [](::c10d::Store& store,
                 const std::string& key,
                 const std::string& expected_value,
                 const std::string& desired_value) -> py::bytes {
                auto value = [&]() {
                  py::gil_scoped_release guard;
                  return store.compareSet(
                      key, toVec8(expected_value), toVec8(desired_value));
                }();
                return toPyBytes(value);
              },
              R"(
Inserts the key-value pair into the store based on the supplied ``key`` and
performs comparison between ``expected_value`` and ``desired_value`` before inserting. ``desired_value``
will only be set if ``expected_value`` for the ``key`` already exists in the store or if ``expected_value``
is an empty string.

Arguments:
    key (str): The key to be checked in the store.
    expected_value (str): The value associated with ``key`` to be checked before insertion.
    desired_value (str): The value associated with ``key`` to be added to the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.set("key", "first_value")
    >>> store.compare_set("key", "first_value", "second_value")
    >>> # Should return "second_value"
    >>> store.get("key")
)")
          // Convert from std::vector<uint8_t> to py::bytes.
          // The returned value is not guaranteed to be valid UTF-8.
          .def(
              "get",
              [](::c10d::Store& store, const std::string& key) -> py::bytes {
                auto value = [&]() {
                  py::gil_scoped_release guard;
                  return store.get(key);
                }();
                return toPyBytes(value);
              },
              R"(
Retrieves the value associated with the given ``key`` in the store. If ``key`` is not
present in the store, the function will wait for ``timeout``, which is defined
when initializing the store, before throwing an exception.

Arguments:
    key (str): The function will return the value associated with this key.

Returns:
    Value associated with ``key`` if ``key`` is in the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.set("first_key", "first_value")
    >>> # Should return "first_value"
    >>> store.get("first_key")
)")
          .def(
              "add",
              &::c10d::Store::add,
              py::call_guard<py::gil_scoped_release>(),
              R"(
The first call to add for a given ``key`` creates a counter associated
with ``key`` in the store, initialized to ``amount``. Subsequent calls to add
with the same ``key`` increment the counter by the specified ``amount``.
Calling :meth:`~torch.distributed.store.add` with a key that has already
been set in the store by :meth:`~torch.distributed.store.set` will result
in an exception.

Arguments:
    key (str): The key in the store whose counter will be incremented.
    amount (int): The quantity by which the counter will be incremented.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, other store types can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.add("first_key", 1)
    >>> store.add("first_key", 6)
    >>> # Should return 7
    >>> store.get("first_key")
)")
          .def(
              "check",
              &::c10d::Store::check,
              py::call_guard<py::gil_scoped_release>(),
              R"(
Checks whether each key in the given list of ``keys`` has a value stored in
the store. This call returns immediately in normal cases, but can still
deadlock in some edge cases, e.g., calling check after the TCPStore has been
destroyed.

Arguments:
    keys (list[str]): The keys whose presence in the store is to be queried.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, other store types can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.add("first_key", 1)
    >>> # Should return true
    >>> store.check(["first_key"])
)")
          .def(
              "delete_key",
              &::c10d::Store::deleteKey,
              py::call_guard<py::gil_scoped_release>(),
              R"(
Deletes the key-value pair associated with ``key`` from the store. Returns
`true` if the key was successfully deleted, and `false` if it was not.

.. warning::
    The ``delete_key`` API is only supported by the :class:`~torch.distributed.TCPStore` and :class:`~torch.distributed.HashStore`. Using this API
    with the :class:`~torch.distributed.FileStore` will result in an exception.

Arguments:
    key (str): The key to be deleted from the store

Returns:
    `True` if ``key`` was deleted, otherwise `False`.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, HashStore can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.set("first_key", "first_value")
    >>> # This should return true
    >>> store.delete_key("first_key")
    >>> # This should return false
    >>> store.delete_key("bad_key")
)")
          .def(
              "num_keys",
              &::c10d::Store::getNumKeys,
              py::call_guard<py::gil_scoped_release>(),
              R"(
Returns the number of keys set in the store. Note that this number will typically
be one greater than the number of keys added by :meth:`~torch.distributed.store.set`
and :meth:`~torch.distributed.store.add` since one key is used to coordinate all
the workers using the store.

.. warning::
    When used with the :class:`~torch.distributed.FileStore`, ``num_keys`` returns
    the number of keys written to the underlying file. If the store is destructed
    and another store is created with the same file, the original keys will be
    retained.

Returns:
    The number of keys present in the store.

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, other store types can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> store.set("first_key", "first_value")
    >>> # This should return 2
    >>> store.num_keys()
)")
1286           .def(
1287               "set_timeout",
1288               &::c10d::Store::setTimeout,
1289               py::call_guard<py::gil_scoped_release>(),
1290               R"(
1291 Sets the store's default timeout. This timeout is used during initialization and in
1292 :meth:`~torch.distributed.store.wait` and :meth:`~torch.distributed.store.get`.
1293 
1294 Arguments:
1295     timeout (timedelta): timeout to be set in the store.
1296 
1297 Example::
1298     >>> import torch.distributed as dist
1299     >>> from datetime import timedelta
1300     >>> # Using TCPStore as an example, other store types can also be used
1301     >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
1302     >>> store.set_timeout(timedelta(seconds=10))
1303     >>> # This will throw an exception after 10 seconds
1304     >>> store.wait(["bad_key"])
1305 )")
1306           .def(
1307               "wait",
1308               [](::c10d::Store& store, const std::vector<std::string>& keys) {
1309                 store.wait(keys);
1310               },
1311               py::call_guard<py::gil_scoped_release>(),
1312               R"(
1313 Waits for each key in ``keys`` to be added to the store. If not all keys are
1314 set before the ``timeout`` (set during store initialization), then ``wait``
1315 will throw an exception.
1316 
1317 Arguments:
1318     keys (list): List of keys on which to wait until they are set in the store.
1319 
1320 Example::
1321     >>> import torch.distributed as dist
1322     >>> from datetime import timedelta
1323     >>> # Using TCPStore as an example, other store types can also be used
1324     >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
1325     >>> # This will throw an exception after 30 seconds
1326     >>> store.wait(["bad_key"])
1327 )")
1328           .def(
1329               "wait",
1330               [](::c10d::Store& store,
1331                  const std::vector<std::string>& keys,
1332                  const std::chrono::milliseconds& timeout) {
1333                 store.wait(keys, timeout);
1334               },
1335               py::call_guard<py::gil_scoped_release>(),
1336               R"(
1337 Waits for each key in ``keys`` to be added to the store, and throws an exception
1338 if the keys have not been set by the supplied ``timeout``.
1339 
1340 Arguments:
1341     keys (list): List of keys on which to wait until they are set in the store.
1342     timeout (timedelta): Time to wait for the keys to be added before throwing an exception.
1343 
1344 Example::
1345     >>> import torch.distributed as dist
1346     >>> from datetime import timedelta
1347     >>> # Using TCPStore as an example, other store types can also be used
1348     >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
1349     >>> # This will throw an exception after 10 seconds
1350     >>> store.wait(["bad_key"], timedelta(seconds=10))
1351 )")
1352           .def_property_readonly(
1353               "timeout",
1354               &::c10d::Store::getTimeout,
1355               R"(Gets the timeout of the store.)")
1356           .def(
1357               "append",
1358               [](::c10d::Store& store,
1359                  const std::string& key,
1360                  const std::string& value) {
1361                 store.append(key, toVec8(value));
1362               },
1363               py::call_guard<py::gil_scoped_release>(),
1364               R"(
1365 Appends ``value`` to the value currently associated with ``key`` in the store.
1366 If ``key`` does not exist in the store, it is created with ``value``.
1367 
1368 Arguments:
1369     key (str): The key whose stored value will be appended to.
1370     value (str): The value to append to the value already stored under ``key``.
1371 
1372 Example::
1373     >>> import torch.distributed as dist
1374     >>> from datetime import timedelta
1375     >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
1376     >>> store.append("first_key", "po")
1377     >>> store.append("first_key", "tato")
1378     >>> # Should return "potato"
1379     >>> store.get("first_key")
1380 )")
1381           .def(
1382               "multi_get",
1383               [](::c10d::Store& store, const std::vector<std::string>& keys) {
1384                 auto values = [&]() {
1385                   py::gil_scoped_release guard;
1386                   return store.multiGet(keys);
1387                 }();
1388                 return toPyBytes(values);
1389               },
1390               R"(
1391 Retrieves all values for the given ``keys``. If any key in ``keys`` is not
1392 present in the store, the function waits up to ``timeout`` for it to be set.
1393 
1394 Arguments:
1395     keys (List[str]): The keys to be retrieved from the store.
1396 
1397 Example::
1398     >>> import torch.distributed as dist
1399     >>> from datetime import timedelta
1400     >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
1401     >>> store.set("first_key", "po")
1402     >>> store.set("second_key", "tato")
1403     >>> # Should return [b"po", b"tato"]
1404     >>> store.multi_get(["first_key", "second_key"])
1405 )")
1406           .def(
1407               "multi_set",
1408               [](::c10d::Store& store,
1409                  const std::vector<std::string>& keys,
1410                  const std::vector<std::string>& values) {
1411                 store.multiSet(keys, toVec8(values));
1412               },
1413               py::call_guard<py::gil_scoped_release>(),
1414               R"(
1415 Inserts multiple key-value pairs into the store, pairing each entry in ``keys`` with the corresponding entry in ``values``.
1416 
1417 Arguments:
1418     keys (List[str]): The keys to insert.
1419     values (List[str]): The values to insert.
1420 
1421 Example::
1422     >>> import torch.distributed as dist
1423     >>> from datetime import timedelta
1424     >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
1425     >>> store.multi_set(["first_key", "second_key"], ["po", "tato"])
1426     >>> # Should return b"po"
1427     >>> store.get("first_key")
1428 )")
1429           .def(
1430               "has_extended_api",
1431               &::c10d::Store::hasExtendedApi,
1432               R"(Returns true if the store supports extended operations.)");
1433 
1434   intrusive_ptr_class_<::c10d::FileStore>(
1435       module,
1436       "FileStore",
1437       store,
1438       R"(
1439 A store implementation that uses a file to store the underlying key-value pairs.
1440 
1441 Arguments:
1442     file_name (str): path of the file in which to store the key-value pairs
1443     world_size (int, optional): The total number of processes using the store. Default is -1 (a negative value indicates a non-fixed number of store users).
1444 
1445 Example::
1446     >>> import torch.distributed as dist
1447     >>> store1 = dist.FileStore("/tmp/filestore", 2)
1448     >>> store2 = dist.FileStore("/tmp/filestore", 2)
1449     >>> # Use any of the store methods from either store instance after initialization
1450     >>> store1.set("first_key", "first_value")
1451     >>> store2.get("first_key")
1452 
1453       )")
1454       .def(
1455           py::init<const std::string&, int>(),
1456           py::arg("file_name"),
1457           py::arg("world_size") = -1)
1458       .def_property_readonly(
1459           "path",
1460           &::c10d::FileStore::getPath,
1461           R"(Gets the path of the file used by FileStore to store key-value pairs.)");
1462 
1463 #ifndef _WIN32
1464   intrusive_ptr_class_<::c10d::HashStore>(
1465       module,
1466       "HashStore",
1467       store,
1468       R"(
1469 A thread-safe store implementation based on an underlying hashmap. This store can be used
1470 within the same process (for example, by other threads), but cannot be used across processes.
1471 
1472 Example::
1473     >>> import torch.distributed as dist
1474     >>> store = dist.HashStore()
1475     >>> # store can be used from other threads
1476     >>> # Use any of the store methods after initialization
1477     >>> store.set("first_key", "first_value")
1478       )")
1479       .def(py::init<>());
1480 #endif
1481 
1482   intrusive_ptr_class_<::c10d::TCPStore>(
1483       module,
1484       "TCPStore",
1485       store,
1486       R"(
1487 A TCP-based distributed key-value store implementation. The server store holds
1488 the data, while the client stores can connect to the server store over TCP and
1489 perform actions such as :meth:`~torch.distributed.store.set` to insert a key-value
1490 pair, :meth:`~torch.distributed.store.get` to retrieve a key-value pair, etc. There
1491 should always be exactly one server store initialized, because the client store(s) will
1492 wait for the server to establish a connection.
1493 
1494 Arguments:
1495     host_name (str): The hostname or IP Address the server store should run on.
1496     port (int): The port on which the server store should listen for incoming requests.
1497     world_size (int, optional): The total number of store users (number of clients + 1 for the server). Default is None (None indicates a non-fixed number of store users).
1498     is_master (bool, optional): True when initializing the server store and False for client stores. Default is False.
1499     timeout (timedelta, optional): Timeout used by the store during initialization and for methods such as :meth:`~torch.distributed.store.get` and :meth:`~torch.distributed.store.wait`. Default is timedelta(seconds=300).
1500     wait_for_workers (bool, optional): Whether to wait for all the workers to connect with the server store. This is only applicable when world_size is a fixed value. Default is True.
1501     multi_tenant (bool, optional): If True, all ``TCPStore`` instances in the current process with the same host/port will use the same underlying ``TCPServer``. Default is False.
1502     master_listen_fd (int, optional): If specified, the underlying ``TCPServer`` will listen on this file descriptor, which must be a socket already bound to ``port``. Useful to avoid port assignment races in some scenarios. Default is None (meaning the server creates a new socket and attempts to bind it to ``port``).
1503     use_libuv (bool, optional): If True, use libuv for the ``TCPServer`` backend. Default is True.

1504 Example::
1505     >>> import torch.distributed as dist
1506     >>> from datetime import timedelta
1507     >>> # Run on process 1 (server)
1508     >>> server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30))
1509     >>> # Run on process 2 (client)
1510     >>> client_store = dist.TCPStore("127.0.0.1", 1234, 2, False)
1511     >>> # Use any of the store methods from either the client or server after initialization
1512     >>> server_store.set("first_key", "first_value")
1513     >>> client_store.get("first_key")
1514       )")
1515       .def(
1516           py::init([](const std::string& host,
1517                       uint16_t port,
1518                       std::optional<int> worldSize,
1519                       bool isServer,
1520                       std::chrono::milliseconds timeout,
1521                       bool waitWorkers,
1522                       bool multiTenant,
1523                       std::optional<int> masterListenFd,
1524                       bool useLibUV) {
1525             std::optional<std::size_t> numWorkers = std::nullopt;
1526             if (worldSize.has_value() && worldSize.value() > -1) {
1527               numWorkers = static_cast<std::size_t>(worldSize.value());
1528             }
1529 
1530             ::c10d::TCPStoreOptions opts{
1531                 port,
1532                 isServer,
1533                 numWorkers,
1534                 waitWorkers,
1535                 timeout,
1536                 multiTenant,
1537                 masterListenFd,
1538                 useLibUV};
1539 
1540             return c10::make_intrusive<::c10d::TCPStore>(host, opts);
1541           }),
1542           py::arg("host_name"),
1543           py::arg("port"),
1544           py::arg("world_size") = py::none(),
1545           // using noconvert() requires this argument to be True or False,
1546           // preventing accidental implicit conversion to bool
1547           py::arg("is_master").noconvert() = false,
1548           py::arg("timeout") =
1549               std::chrono::milliseconds(::c10d::Store::kDefaultTimeout),
1550           py::arg("wait_for_workers") = true,
1551           py::arg("multi_tenant") = false,
1552           py::arg("master_listen_fd") = py::none(),
1553           py::arg("use_libuv") = true,
1554           py::call_guard<py::gil_scoped_release>())
1555       .def_property_readonly(
1556           "host",
1557           &::c10d::TCPStore::getHost,
1558           R"(Gets the hostname on which the store listens for requests.)")
1559       .def_property_readonly(
1560           "port",
1561           &::c10d::TCPStore::getPort,
1562           R"(Gets the port number on which the store listens for requests.)")
1563       .def_property_readonly(
1564           "libuvBackend",
1565           &::c10d::TCPStore::isLibUvBackend,
1566           R"(Returns True if it's using the libuv backend.)")
1567       .def(
1568           "__repr__",
1569           &::c10d::TCPStore::repr,
1570           R"(Returns a string representation of the TCPStore.)",
1571           py::call_guard<py::gil_scoped_release>());
1572 
1573   intrusive_ptr_class_<::c10d::PrefixStore>(
1574       module,
1575       "PrefixStore",
1576       store,
1577       R"(
1578 A wrapper around any of the 3 key-value stores (:class:`~torch.distributed.TCPStore`,
1579 :class:`~torch.distributed.FileStore`, and :class:`~torch.distributed.HashStore`)
1580 that adds a prefix to each key inserted into the store.
1581 
1582 Arguments:
1583     prefix (str): The prefix string that is prepended to each key before being inserted into the store.
1584     store (torch.distributed.store): A store object that forms the underlying key-value store.
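
A minimal usage sketch (any of the supported stores can be wrapped; a
:class:`~torch.distributed.HashStore` is used here for brevity):

Example::
    >>> import torch.distributed as dist
    >>> store = dist.HashStore()
    >>> prefix_store = dist.PrefixStore("prefix", store)
    >>> # Keys set through the wrapper are namespaced under "prefix"
    >>> prefix_store.set("first_key", "first_value")
    >>> # Reads through the wrapper resolve the prefixed key transparently
    >>> prefix_store.get("first_key")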
1585       )")
1586       .def(py::init<const std::string&, c10::intrusive_ptr<::c10d::Store>>())
1587       .def_property_readonly(
1588           "underlying_store",
1589           &::c10d::PrefixStore::getUnderlyingStore,
1590           R"(Gets the underlying store object that PrefixStore wraps around.)")
1591       .def_property_readonly(
1592           "_underlying_non_prefix_store",
1593           &::c10d::PrefixStore::getUnderlyingNonPrefixStore,
1594           R"(Recursively to get the store before layers of wrapping with PrefixStore.)");
1595 
1596   using namespace std::chrono_literals;
1597 
1598   auto collectives =
1599       py::class_<
1600           ::c10d::ControlCollectives,
1601           c10::intrusive_ptr<::c10d::ControlCollectives>>(
1602           module,
1603           "_ControlCollectives",
1604           R"(
1605 Base class for all ControlCollectives implementations.
1606 )")
1607           .def(
1608               "barrier",
1609               &::c10d::ControlCollectives::barrier,
1610               py::arg("key"),
1611               py::arg("timeout") = 5min,
1612               py::arg("block") = true,
1613               py::call_guard<py::gil_scoped_release>(),
1614               R"(
1615 Blocks until all workers have entered this function.
1616 
1617 Arguments:
1618     key (str): The unique key used to identify this operation.
1619     timeout (duration): The timeout for this operation.
1620     block (bool): Whether to block this worker while waiting on the results of the barrier.
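
A minimal single-worker sketch (this assumes ``_StoreCollectives`` below is
exported as ``torch.distributed._StoreCollectives``; the exact import path may differ):

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> collectives = dist._StoreCollectives(store, 0, 1)
    >>> # With a world size of 1, the barrier returns immediately
    >>> collectives.barrier("init", timedelta(seconds=30), True)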
1621 )")
1622           .def(
1623               "all_sum",
1624               &::c10d::ControlCollectives::allSum,
1625               py::arg("key"),
1626               py::arg("data"),
1627               py::arg("timeout") = 5min,
1628               py::call_guard<py::gil_scoped_release>(),
1629               R"(
1630 Computes a sum across all workers and returns the final value.
1631 
1632 Arguments:
1633     key (str): The unique key used to identify this operation.
1634     data (int): The data to sum.
1635     timeout (duration): The timeout for this operation.
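
A minimal single-worker sketch (``collectives`` is assumed to be a
``_StoreCollectives`` instance, constructed as in the ``barrier`` example):

Example::
    >>> from datetime import timedelta
    >>> # Every worker contributes its value; with one worker the sum is just 1
    >>> collectives.all_sum("step", 1, timedelta(seconds=30))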
1636 )")
1637           .def(
1638               "broadcast_send",
1639               [](::c10d::ControlCollectives& collectives,
1640                  const std::string& key,
1641                  const std::string& data,
1642                  std::chrono::milliseconds timeout = 5min) {
1643                 collectives.broadcastSend(key, toVec8(data), timeout);
1644               },
1645               py::arg("key"),
1646               py::arg("data"),
1647               py::arg("timeout") = 5min,
1648               py::call_guard<py::gil_scoped_release>(),
1649               R"(
1650 Sends data to all other workers. Must only be called from one worker.
1651 
1652 Arguments:
1653     key (str): The unique key used to identify this operation.
1654     data (str): The data to send.
1655     timeout (duration): The timeout for this operation.
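
A minimal sketch for a multi-worker job (``collectives`` is assumed to be a
``_StoreCollectives`` instance, constructed as in the ``barrier`` example):

Example::
    >>> from datetime import timedelta
    >>> # Run on exactly one worker, e.g. rank 0
    >>> collectives.broadcast_send("config", "payload", timedelta(seconds=30))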
1656 )")
1657           .def(
1658               "broadcast_recv",
1659               [](::c10d::ControlCollectives& collectives,
1660                  const std::string& key,
1661                  std::chrono::milliseconds timeout = 5min) {
1662                 auto out = [&]() {
1663                   py::gil_scoped_release guard;
1664                   return collectives.broadcastRecv(key, timeout);
1665                 }();
1666                 return toPyBytes(out);
1667               },
1668               py::arg("key"),
1669               py::arg("timeout") = 5min,
1670               R"(
1671 Receives data broadcast from one worker.
1672 
1673 Arguments:
1674     key (str): The unique key used to identify this operation.
1675     timeout (duration): The timeout for this operation.
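
The receiving side of the ``broadcast_send`` sketch above (``collectives`` is
assumed to be a ``_StoreCollectives`` instance):

Example::
    >>> from datetime import timedelta
    >>> # Run on every worker except the sender; returns the payload as bytes
    >>> collectives.broadcast_recv("config", timedelta(seconds=30))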
1676 )")
1677           .def(
1678               "gather_send",
1679               [](::c10d::ControlCollectives& collectives,
1680                  const std::string& key,
1681                  const std::string& data,
1682                  std::chrono::milliseconds timeout = 5min) {
1683                 collectives.gatherSend(key, toVec8(data), timeout);
1684               },
1685               py::arg("key"),
1686               py::arg("data"),
1687               py::arg("timeout") = 5min,
1688               py::call_guard<py::gil_scoped_release>(),
1689               R"(
1690 Sends data to one other worker.
1691 
1692 Arguments:
1693     key (str): The unique key used to identify this operation.
1694     data (str): The data to send.
1695     timeout (duration): The timeout for this operation.
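
A minimal sketch (``collectives`` is assumed to be a ``_StoreCollectives``
instance, constructed as in the ``barrier`` example):

Example::
    >>> from datetime import timedelta
    >>> # Run on every worker except the one calling ``gather_recv``
    >>> collectives.gather_send("metrics", "rank_payload", timedelta(seconds=30))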
1696 )")
1697           .def(
1698               "gather_recv",
1699               [](::c10d::ControlCollectives& collectives,
1700                  const std::string& key,
1701                  const std::string& data,
1702                  std::chrono::milliseconds timeout = 5min) {
1703                 auto out = [&]() {
1704                   py::gil_scoped_release guard;
1705                   return collectives.gatherRecv(key, toVec8(data), timeout);
1706                 }();
1707                 return toPyBytes(out);
1708               },
1709               py::arg("key"),
1710               py::arg("data"),
1711               py::arg("timeout") = 5min,
1712               R"(
1713 Receives data gathered from all workers; this worker's own ``data`` is included. Must only be called by one worker.
1714 
1715 Arguments:
1716     key (str): The unique key used to identify this operation.
1717     timeout (duration): The timeout for this operation.
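
The collecting side of the ``gather_send`` sketch above (``collectives`` is
assumed to be a ``_StoreCollectives`` instance, as before):

Example::
    >>> from datetime import timedelta
    >>> # Run on exactly one worker; returns the data gathered from all workers
    >>> collectives.gather_recv("metrics", "own_payload", timedelta(seconds=30))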
1718 )")
1719 
1720           .def(
1721               "scatter_send",
1722               [](::c10d::ControlCollectives& collectives,
1723                  const std::string& key,
1724                  const std::vector<std::string>& data,
1725                  std::chrono::milliseconds timeout = 5min) {
1726                 auto out = [&]() {
1727                   py::gil_scoped_release guard;
1728                   return collectives.scatterSend(key, toVec8(data), timeout);
1729                 }();
1730                 return toPyBytes(out);
1731               },
1732               py::arg("key"),
1733               py::arg("data"),
1734               py::arg("timeout") = 5min,
1735               R"(
1736 Sends rank-specific data to all other workers.
1737 
1738 Arguments:
1739     key (str): The unique key used to identify this operation.
1740     data (List[str]): The per-rank data to send, one entry per worker.
1741     timeout (duration): The timeout for this operation.
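
A minimal sketch for a two-worker job (``collectives`` is assumed to be a
``_StoreCollectives`` instance, constructed as in the ``barrier`` example):

Example::
    >>> from datetime import timedelta
    >>> # Run on exactly one worker; ``data`` holds one entry per worker
    >>> collectives.scatter_send("shards", ["shard0", "shard1"], timedelta(seconds=30))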
1742 )")
1743           .def(
1744               "scatter_recv",
1745               [](::c10d::ControlCollectives& collectives,
1746                  const std::string& key,
1747                  std::chrono::milliseconds timeout = 5min) {
1748                 auto out = [&]() {
1749                   py::gil_scoped_release guard;
1750                   return collectives.scatterRecv(key, timeout);
1751                 }();
1752                 return toPyBytes(out);
1753               },
1754               py::arg("key"),
1755               py::arg("timeout") = 5min,
1756               R"(
1757 Receives rank-specific data from one worker.
1758 
1759 Arguments:
1760     key (str): The unique key used to identify this operation.
1761     timeout (duration): The timeout for this operation.
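
The receiving side of the ``scatter_send`` sketch above (``collectives`` is
assumed to be a ``_StoreCollectives`` instance, as before):

Example::
    >>> from datetime import timedelta
    >>> # Run on every worker except the sender; returns this rank's entry as bytes
    >>> collectives.scatter_recv("shards", timedelta(seconds=30))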
1762 )")
1763 
1764           .def(
1765               "all_gather",
1766               [](::c10d::ControlCollectives& collectives,
1767                  const std::string& key,
1768                  const std::string& data,
1769                  std::chrono::milliseconds timeout = 5min) {
1770                 auto out = [&]() {
1771                   py::gil_scoped_release guard;
1772                   return collectives.allGather(key, toVec8(data), timeout);
1773                 }();
1774                 return toPyBytes(out);
1775               },
1776               py::arg("key"),
1777               py::arg("data"),
1778               py::arg("timeout") = 5min,
1779               R"(
1780 Sends data to all workers and receives data from all other workers.
1781 
1782 Arguments:
1783     key (str): The unique key used to identify this operation.
1784     data (str): The data to send.
1785     timeout (duration): The timeout for this operation.
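
A minimal single-worker sketch (``collectives`` is assumed to be a
``_StoreCollectives`` instance, constructed as in the ``barrier`` example):

Example::
    >>> from datetime import timedelta
    >>> # Returns the data contributed by every worker, including this one
    >>> collectives.all_gather("ranks", "worker_payload", timedelta(seconds=30))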
1786 )");
1787 
1788   intrusive_ptr_class_<::c10d::StoreCollectives>(
1789       module,
1790       "_StoreCollectives",
1791       collectives,
1792       R"(
1793 An implementation of ControlCollectives that uses the provided store as the underlying
1794 communication mechanism.
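
A minimal construction sketch (this assumes the class is exported as
``torch.distributed._StoreCollectives``; the exact import path may differ):

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> collectives = dist._StoreCollectives(store, rank=0, world_size=1)
    >>> collectives.barrier("init", timedelta(seconds=30), True)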
1795       )")
1796       .def(
1797           py::init<c10::intrusive_ptr<::c10d::Store>, int, int>(),
1798           py::arg("store"),
1799           py::arg("rank"),
1800           py::arg("world_size"));
1801 
1802   auto processGroup =
1803       py::class_<
1804           ::c10d::ProcessGroup,
1805           c10::intrusive_ptr<::c10d::ProcessGroup>,
1806           ::c10d::PyProcessGroup>(module, "ProcessGroup")
1807           .def(py::init<int, int>())
1808           .def(
1809               py::init<
1810                   const c10::intrusive_ptr<::c10d::Store>&,
1811                   int,
1812                   int,
1813                   c10::intrusive_ptr<::c10d::ProcessGroup::Options>>(),
1814               py::call_guard<py::gil_scoped_release>())
1815           .def("rank", &::c10d::ProcessGroup::getRank)
1816           .def("size", &::c10d::ProcessGroup::getSize)
1817           .def("name", &::c10d::ProcessGroup::getBackendName)
1818           .def("_id", &::c10d::ProcessGroup::getID)
1819           .def(
1820               "_backend_id",
1821               &::c10d::ProcessGroup::getBackendID,
1822               py::arg("backend_type"))
1823           .def_property_readonly("options", &::c10d::ProcessGroup::getOptions)
1824           .def(
1825               "broadcast",
1826               &::c10d::ProcessGroup::broadcast,
1827               py::arg("tensors"),
1828               py::arg("opts") = ::c10d::BroadcastOptions(),
1829               py::call_guard<py::gil_scoped_release>())
1830           .def(
1831               "broadcast",
1832               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
1833                  at::Tensor& x,
1834                  int rootRank) {
1835                 ::c10d::BroadcastOptions opts;
1836                 opts.rootRank = rootRank;
1837                 std::vector<at::Tensor> tensors = {x};
1838                 return self->broadcast(tensors, opts);
1839               },
1840               py::arg("tensor"),
1841               py::arg("root"),
1842               py::call_guard<py::gil_scoped_release>())
1843           .def(
1844               "allreduce",
1845               &::c10d::ProcessGroup::allreduce,
1846               py::arg("tensors"),
1847               py::arg("opts") = ::c10d::AllreduceOptions(),
1848               py::call_guard<py::gil_scoped_release>())
1849           .def(
1850               "allreduce",
1851               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
1852                  std::vector<at::Tensor>& xs,
1853                  const ::c10d::ReduceOp& op) {
1854                 ::c10d::AllreduceOptions opts;
1855                 opts.reduceOp = op;
1856                 return self->allreduce(xs, opts);
1857               },
1858               py::arg("tensors"),
1859               py::arg("op") = ::c10d::ReduceOp::SUM,
1860               py::call_guard<py::gil_scoped_release>())
1861 
1862           .def(
1863               "allreduce",
1864               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
1865                  at::Tensor& x,
1866                  const ::c10d::ReduceOp& op) {
1867                 ::c10d::AllreduceOptions opts;
1868                 opts.reduceOp = op;
1869                 std::vector<at::Tensor> xs = {x};
1870                 return self->allreduce(xs, opts);
1871               },
1872               py::arg("tensor"),
1873               py::arg("op") = ::c10d::ReduceOp::SUM,
1874               py::call_guard<py::gil_scoped_release>())
1875           .def(
1876               "allreduce_coalesced",
1877               &::c10d::ProcessGroup::allreduce_coalesced,
1878               py::arg("tensors"),
1879               py::arg("opts") = ::c10d::AllreduceCoalescedOptions(),
1880               py::call_guard<py::gil_scoped_release>())
1881 
1882           .def(
1883               "reduce",
1884               &::c10d::ProcessGroup::reduce,
1885               py::arg("tensors"),
1886               py::arg("opts") = ::c10d::ReduceOptions(),
1887               py::call_guard<py::gil_scoped_release>())
1888 
1889           .def(
1890               "reduce",
1891               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
1892                  at::Tensor& x,
1893                  int rootRank,
1894                  const ::c10d::ReduceOp& op) {
1895                 ::c10d::ReduceOptions opts;
1896                 opts.reduceOp = op;
1897                 opts.rootRank = rootRank;
1898                 std::vector<at::Tensor> xs = {x};
1899                 return self->reduce(xs, opts);
1900               },
1901               py::arg("tensor"),
1902               py::arg("root"),
1903               py::arg("op") = ::c10d::ReduceOp::SUM,
1904               py::call_guard<py::gil_scoped_release>())
1905           .def(
1906               "allgather",
1907               &::c10d::ProcessGroup::allgather,
1908               py::arg("output_tensors"),
1909               py::arg("input_tensors"),
1910               py::arg("opts") = ::c10d::AllgatherOptions(),
1911               py::call_guard<py::gil_scoped_release>())
1912           .def(
1913               "allgather",
1914               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
1915                  std::vector<at::Tensor>& output,
1916                  at::Tensor& input) {
1917                 std::vector<std::vector<at::Tensor>> outputs = {output};
1918                 std::vector<at::Tensor> inputs = {input};
1919                 return self->allgather(
1920                     outputs, inputs, ::c10d::AllgatherOptions());
1921               },
1922               py::arg("output_tensors"),
1923               py::arg("input_tensor"),
1924               py::call_guard<py::gil_scoped_release>())
1925           .def(
1926               "_allgather_base",
1927               &::c10d::ProcessGroup::_allgather_base,
1928               py::arg("output"),
1929               py::arg("input"),
1930               py::arg("opts") = ::c10d::AllgatherOptions(),
1931               py::call_guard<py::gil_scoped_release>())
1932           .def(
1933               "allgather_coalesced",
1934               &::c10d::ProcessGroup::allgather_coalesced,
1935               py::arg("output_lists"),
1936               py::arg("input_list"),
1937               py::arg("opts") = ::c10d::AllgatherOptions(),
1938               py::call_guard<py::gil_scoped_release>())
1939           .def(
1940               "allgather_into_tensor_coalesced",
1941               &::c10d::ProcessGroup::allgather_into_tensor_coalesced,
1942               py::arg("outputs"),
1943               py::arg("inputs"),
1944               py::arg("opts") = ::c10d::AllgatherOptions(),
1945               py::call_guard<py::gil_scoped_release>())
1946           .def(
1947               "gather",
1948               &::c10d::ProcessGroup::gather,
1949               py::arg("output_tensors"),
1950               py::arg("input_tensors"),
1951               py::arg("opts") = ::c10d::GatherOptions(),
1952               py::call_guard<py::gil_scoped_release>())
1953 
1954           .def(
1955               "gather",
1956               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
1957                  std::vector<at::Tensor>& output,
1958                  at::Tensor& input,
1959                  int rootRank) {
1960                 ::c10d::GatherOptions opts;
1961                 opts.rootRank = rootRank;
1962                 std::vector<std::vector<at::Tensor>> outputs = {output};
1963                 std::vector<at::Tensor> inputs = {input};
1964                 return self->gather(outputs, inputs, opts);
1965               },
1966               py::arg("output_tensors"),
1967               py::arg("input_tensor"),
1968               py::arg("root"),
1969               py::call_guard<py::gil_scoped_release>())
1970           .def(
1971               "scatter",
1972               &::c10d::ProcessGroup::scatter,
1973               py::arg("output_tensors"),
1974               py::arg("input_tensors"),
1975               py::arg("opts") = ::c10d::ScatterOptions(),
1976               py::call_guard<py::gil_scoped_release>())
1977           .def(
1978               "scatter",
1979               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
1980                  at::Tensor& output,
1981                  std::vector<at::Tensor>& input,
1982                  int rootRank) {
1983                 ::c10d::ScatterOptions opts;
1984                 opts.rootRank = rootRank;
1985                 std::vector<std::vector<at::Tensor>> inputs = {input};
1986                 std::vector<at::Tensor> outputs = {output};
1987                 return self->scatter(outputs, inputs, opts);
1988               },
1989               py::arg("output_tensor"),
1990               py::arg("input_tensors"),
1991               py::arg("root"),
1992               py::call_guard<py::gil_scoped_release>())
1993           .def(
1994               "reduce_scatter",
1995               &::c10d::ProcessGroup::reduce_scatter,
1996               py::arg("output_tensors"),
1997               py::arg("input_tensors"),
1998               py::arg("opts") = ::c10d::ReduceScatterOptions(),
1999               py::call_guard<py::gil_scoped_release>())
2000           .def(
2001               "reduce_scatter",
2002               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
2003                  at::Tensor& output,
2004                  std::vector<at::Tensor>& input,
2005                  const ::c10d::ReduceOp& op) {
2006                 std::vector<at::Tensor> outputs = {output};
2007                 std::vector<std::vector<at::Tensor>> inputs = {input};
2008                 ::c10d::ReduceScatterOptions opts;
2009                 opts.reduceOp = op;
2010                 return self->reduce_scatter(outputs, inputs, opts);
2011               },
2012               py::arg("output"),
2013               py::arg("input"),
2014               py::arg("op") = ::c10d::ReduceOp::SUM,
2015               py::call_guard<py::gil_scoped_release>())
2016           .def(
2017               "_reduce_scatter_base",
2018               &::c10d::ProcessGroup::_reduce_scatter_base,
2019               py::arg("outputTensor"),
2020               py::arg("inputTensor"),
2021               py::arg("opts") = ::c10d::ReduceScatterOptions(),
2022               py::call_guard<py::gil_scoped_release>())
2023           .def(
2024               "reduce_scatter_tensor_coalesced",
2025               &::c10d::ProcessGroup::reduce_scatter_tensor_coalesced,
2026               py::arg("outputs"),
2027               py::arg("inputs"),
2028               py::arg("opts") = ::c10d::ReduceScatterOptions(),
2029               py::call_guard<py::gil_scoped_release>())
2030           .def(
2031               "alltoall_base",
2032               &::c10d::ProcessGroup::alltoall_base,
2033               py::arg("output"),
2034               py::arg("input"),
2035               py::arg("output_split_sizes"),
2036               py::arg("input_split_sizes"),
2037               py::arg("opts") = ::c10d::AllToAllOptions(),
2038               py::call_guard<py::gil_scoped_release>())
2039           .def(
2040               "alltoall",
2041               &::c10d::ProcessGroup::alltoall,
2042               py::arg("output_tensors"),
2043               py::arg("input_tensors"),
2044               py::arg("opts") = ::c10d::AllToAllOptions(),
2045               py::call_guard<py::gil_scoped_release>())
2046           .def(
2047               "send",
2048               &::c10d::ProcessGroup::send,
2049               py::arg("tensors"),
2050               py::arg("dstRank"),
2051               py::arg("tag"),
2052               py::call_guard<py::gil_scoped_release>())
2053           .def(
2054               "recv",
2055               &::c10d::ProcessGroup::recv,
2056               py::arg("tensors"),
2057               py::arg("srcRank"),
2058               py::arg("tag"),
2059               py::call_guard<py::gil_scoped_release>())
2060           .def(
2061               "recv_anysource",
2062               &::c10d::ProcessGroup::recvAnysource,
2063               py::call_guard<py::gil_scoped_release>())
2064           .def(
2065               "barrier",
2066               &::c10d::ProcessGroup::barrier,
2067               py::arg("opts") = ::c10d::BarrierOptions(),
2068               py::call_guard<py::gil_scoped_release>())
2069           .def(
2070               "_set_sequence_number_for_group",
2071               &::c10d::ProcessGroup::setSequenceNumberForGroup,
2072               py::call_guard<py::gil_scoped_release>())
2073           .def(
2074               "_get_sequence_number_for_group",
2075               &::c10d::ProcessGroup::getSequenceNumberForGroup,
2076               py::call_guard<py::gil_scoped_release>())
2077           .def(
2078               "monitored_barrier",
2079               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
2080                  const std::chrono::milliseconds& timeout,
2081                  bool waitAllRanks) {
2082                 ::c10d::BarrierOptions opts;
2083                 opts.timeout = timeout;
2084                 return self->monitoredBarrier(opts, waitAllRanks);
2085               },
2086               py::arg("timeout") = ::c10d::kUnsetTimeout,
2087               py::arg("wait_all_ranks") = false,
2088               py::call_guard<py::gil_scoped_release>())
2089           .def_property_readonly(
2090               "_device_types", &::c10d::ProcessGroup::getDeviceTypes)
2091           .def(
2092               "_get_backend_name",
2093               &::c10d::ProcessGroup::getBackendName,
2094               py::call_guard<py::gil_scoped_release>())
2095           .def(
2096               "_start_coalescing",
2097               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
2098                  const c10::Device& device) {
2099                 self->startCoalescing(device.type());
2100               },
2101               py::arg("device_type"),
2102               py::call_guard<py::gil_scoped_release>())
2103           .def(
2104               "_end_coalescing",
2105               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
2106                  const c10::Device& device) {
2107                 return self->endCoalescing(device.type());
2108               },
2109               py::arg("device_type"),
2110               py::call_guard<py::gil_scoped_release>())
2111           .def(
2112               "_register_backend",
2113               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
2114                  const c10::Device& device,
2115                  const ::c10d::ProcessGroup::BackendType& backendType,
2116                  const std::optional<c10::intrusive_ptr<::c10d::Backend>>&
2117                      backend) {
2118                 self->setBackend(device.type(), backendType, backend);
2119               },
2120               py::arg("device"),
2121               py::arg("backend_type"),
2122               py::arg("backend") =
2123                   std::optional<c10::intrusive_ptr<::c10d::Backend>>(),
2124               py::call_guard<py::gil_scoped_release>())
2125           .def(
2126               "_get_backend",
2127               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
2128                  const c10::Device& device) {
2129                 return self->getBackend(device.type());
2130               },
2131               py::arg("device"),
2132               py::call_guard<py::gil_scoped_release>())
2133           .def(
2134               "_register_on_completion_hook",
2135               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
2136                  py::object hook) {
2137                 // We need to wrap a py::object hook with a wrapper to hold
2138                 // GIL before dereferencing the py::object.
2139                 // This needs to happen here instead of in ProcessGroup
2140                 // backend implementations and the latter cannot depend on
2141                 // python-related libs.
2142                 self->registerOnCompletionHook(
2143                     [hookWrapper = ::c10d::PythonOnCompletionHook(std::move(
2144                          hook))](std::shared_ptr<::c10d::WorkInfo> workInfo) {
2145                       hookWrapper(workInfo);
2146                     });
2147               },
2148               py::arg("hook"),
2149           // Intentionally holding the GIL as we move the hook py::object. This
2150           // should be OK since registering a hook is cheap.
2151               py::call_guard<py::gil_scoped_acquire>(),
2152               R"(
2153 Register a hook function which is fired on every ``ProcessGroup::Work`` completion.
2154 The hook must have the following signature:
2155 
2156 >>> def hook(work_info: torch._C._distributed_c10d.WorkInfo) -> None:
2157 >>>     # custom code
2158 >>>     # work_info.op_type: type of collective of this work
2159 >>>     # work_info.seq: sequence number of collective of this work
2160 >>>     # work_info.time_started: system time when user code called this collective
2161 >>>     # work_info.time_finished: system time when the watchdog thread detected
2162 >>>     #     completion of this work. Note that there can be delays between the
2163 >>>     #     actual completion time and the detection time.
2164 >>>     # work_info.active_duration: duration of this collective measured by CUDAEvents
2165 >>>     #     which can accurately represent the duration between when the collective
2166 >>>     #     is launched and when the collective completes.
2167 
2168 .. warning ::
2169     This only works for the NCCL backend for now. All hooks are fired on the C++ watchdog
2170     thread. Firing the Python hook and acquiring the GIL requires the Python interpreter to
2171     be alive. Therefore, users need to make sure they call ``destroy_process_group(pg)`` on
2172     every active ProcessGroup ``pg`` before exiting.
2173 
2174 .. warning ::
2175     Note that the ``Work`` object passed to the hook is a partially copied version without
2176     the output objects, so accessing the output tensors from ``Work`` will not work.
2177 
2178 
2179 Arguments:
2180     hook (Callable): hook function.
2181               )")
2182           .def(
2183               "_wait_for_pending_works",
2184               &::c10d::ProcessGroup::waitForPendingWorks,
2185               py::call_guard<py::gil_scoped_release>())
2186           .def(
2187               "_has_hooks",
2188               &::c10d::ProcessGroup::hasHooks,
2189               py::call_guard<py::gil_scoped_acquire>())
2190           .def(
2191               "_enable_collectives_timing",
2192               &::c10d::ProcessGroup::enableCollectivesTiming,
2193               py::call_guard<py::gil_scoped_acquire>(),
2194               "Enable timing of collectives by all backends. This might incur in additional overhead.")
2195           .def(
2196               "_set_group_name",
2197               &::c10d::ProcessGroup::setGroupName,
2198               py::call_guard<py::gil_scoped_acquire>(),
2199               "Sets the process group name. This is an internal C10D method, do not use.")
2200           .def_property_readonly(
2201               "group_name",
2202               &::c10d::ProcessGroup::getGroupName,
2203               "(Gets this process group name. It's cluster unique)")
2204           .def(
2205               "_set_group_desc",
2206               &::c10d::ProcessGroup::setGroupDesc,
2207               py::call_guard<py::gil_scoped_acquire>(),
2208               "Sets the process group description. This is an internal C10D method, do not use.")
2209           .def_property_readonly(
2210               "group_desc",
2211               &::c10d::ProcessGroup::getGroupDesc,
2212               "Gets this process group description")
2213           .def_property(
2214               "bound_device_id",
2215               &::c10d::ProcessGroup::getBoundDeviceId,
2216               &::c10d::ProcessGroup::setBoundDeviceId)
2217           .def("boxed", [](c10::intrusive_ptr<::c10d::ProcessGroup> self) {
2218             return torch::jit::toPyObject(c10::IValue(std::move(self)));
2219           })
2220           .def_static("unbox", [](py::object obj) {
2221               auto typePtr = torch::getCustomClass("__torch__.torch.classes.c10d.ProcessGroup");
2222               auto ivalue = torch::jit::toIValue(std::move(obj), typePtr);
2223               return ivalue.toCustomClass<::c10d::ProcessGroup>();
2224           });
2225 
2226   py::enum_<::c10d::ProcessGroup::BackendType>(processGroup, "BackendType")
2227       .value("UNDEFINED", ::c10d::ProcessGroup::BackendType::UNDEFINED)
2228       .value("GLOO", ::c10d::ProcessGroup::BackendType::GLOO)
2229       .value("NCCL", ::c10d::ProcessGroup::BackendType::NCCL)
2230       .value("UCC", ::c10d::ProcessGroup::BackendType::UCC)
2231       .value("MPI", ::c10d::ProcessGroup::BackendType::MPI)
2232       .value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM)
2233       .export_values();
2234 
2235   // base ProcessGroup::Options binding
2236   auto processGroupOptions =
2237       intrusive_ptr_class_<::c10d::ProcessGroup::Options>(
2238           processGroup,
2239           "Options",
2240           R"(
2241 Base class for all process group options implementations, such as the NCCL
2242 options :class:`~torch.distributed.ProcessGroupNCCL.Options`.
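
A minimal construction sketch (this assumes the class is reachable from Python
as ``torch.distributed.ProcessGroup.Options``; only the base fields are shown):

Example::
    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> opts = dist.ProcessGroup.Options(backend="gloo", timeout=timedelta(minutes=5))
    >>> opts.backend
    'gloo'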
2243 )")
2244           .def(
2245               py::init([](const std::string& backend,
2246                           const std::chrono::milliseconds& timeout) {
2247                 return c10::make_intrusive<::c10d::ProcessGroup::Options>(
2248                     backend, timeout);
2249               }),
2250               py::arg("backend"),
2251               py::arg("timeout") = kProcessGroupDefaultTimeout,
2252               py::call_guard<py::gil_scoped_release>())
2253           .def_readonly("backend", &::c10d::ProcessGroup::Options::backend)
2254           .def_readwrite("_timeout", &::c10d::ProcessGroup::Options::timeout);
2255 
2256   // TODO: The collection definitions handle direct instantiation of
2257   // ProcessGroup subclasses (e.g. dist.ProcessGroupGloo). This is not supported
2258   // and should be removed once all tests are transitioned.
2259   auto backend =
2260       py::class_<::c10d::Backend, c10::intrusive_ptr<::c10d::Backend>>(
2261           module, "Backend")
2262           .def("rank", &::c10d::Backend::getRank)
2263           .def("size", &::c10d::Backend::getSize)
2264           .def("name", &::c10d::Backend::getBackendName)
2265           .def_property_readonly(
2266               "supports_splitting",
2267               &::c10d::Backend::supportsSplitting,
2268               "(test whether the backend supports splitting)")
2269           .def(
2270               "broadcast",
2271               &::c10d::Backend::broadcast,
2272               py::arg("tensors"),
2273               py::arg("opts") = ::c10d::BroadcastOptions(),
2274               py::call_guard<py::gil_scoped_release>())
2275           .def(
2276               "broadcast",
2277               [](const c10::intrusive_ptr<::c10d::Backend>& self,
2278                  at::Tensor& x,
2279                  int rootRank) {
2280                 ::c10d::BroadcastOptions opts;
2281                 opts.rootRank = rootRank;
2282                 std::vector<at::Tensor> xs = {x};
2283                 return self->broadcast(xs, opts);
2284               },
2285               py::arg("tensor"),
2286               py::arg("root"),
2287               py::call_guard<py::gil_scoped_release>())
2288           .def(
2289               "allreduce",
2290               &::c10d::Backend::allreduce,
2291               py::arg("tensors"),
2292               py::arg("opts") = ::c10d::AllreduceOptions(),
2293               py::call_guard<py::gil_scoped_release>())
2294           .def(
2295               "allreduce",
2296               [](const c10::intrusive_ptr<::c10d::Backend>& self,
2297                  std::vector<at::Tensor>& xs,
2298                  const ::c10d::ReduceOp& op) {
2299                 ::c10d::AllreduceOptions opts;
2300                 opts.reduceOp = op;
2301                 return self->allreduce(xs, opts);
2302               },
2303               py::arg("tensors"),
2304               py::arg("op") = ::c10d::ReduceOp::SUM,
2305               py::call_guard<py::gil_scoped_release>())
2306           .def(
2307               "allreduce",
2308               [](const c10::intrusive_ptr<::c10d::Backend>& self,
2309                  at::Tensor& x,
2310                  const ::c10d::ReduceOp& op) {
2311                 ::c10d::AllreduceOptions opts;
2312                 opts.reduceOp = op;
2313                 std::vector<at::Tensor> xs = {x};
2314                 return self->allreduce(xs, opts);
2315               },
2316               py::arg("tensor"),
2317               py::arg("op") = ::c10d::ReduceOp::SUM,
2318               py::call_guard<py::gil_scoped_release>())
2319           .def(
2320               "allreduce_coalesced",
2321               &::c10d::Backend::allreduce_coalesced,
2322               py::arg("tensors"),
2323               py::arg("opts") = ::c10d::AllreduceCoalescedOptions(),
2324               py::call_guard<py::gil_scoped_release>())
2325           .def(
2326               "reduce",
2327               &::c10d::Backend::reduce,
2328               py::arg("tensors"),
2329               py::arg("opts") = ::c10d::ReduceOptions(),
2330               py::call_guard<py::gil_scoped_release>())
2331           .def(
2332               "reduce",
2333               [](const c10::intrusive_ptr<::c10d::Backend>& self,
2334                  at::Tensor& x,
2335                  int rootRank,
2336                  const ::c10d::ReduceOp& op) {
2337                 ::c10d::ReduceOptions opts;
2338                 opts.reduceOp = op;
2339                 opts.rootRank = rootRank;
2340                 std::vector<at::Tensor> xs = {x};
2341                 return self->reduce(xs, opts);
2342               },
2343               py::arg("tensor"),
2344               py::arg("root"),
2345               py::arg("op") = ::c10d::ReduceOp::SUM,
2346               py::call_guard<py::gil_scoped_release>())
2347           .def(
2348               "allgather",
2349               &::c10d::Backend::allgather,
2350               py::arg("output_tensors"),
2351               py::arg("input_tensors"),
2352               py::arg("opts") = ::c10d::AllgatherOptions(),
2353               py::call_guard<py::gil_scoped_release>())
2354           .def(
2355               "_allgather_base",
2356               &::c10d::Backend::_allgather_base,
2357               py::arg("output"),
2358               py::arg("input"),
2359               py::arg("opts") = ::c10d::AllgatherOptions(),
2360               py::call_guard<py::gil_scoped_release>())
2361           .def(
2362               "allgather",
2363               [](const c10::intrusive_ptr<::c10d::Backend>& self,
2364                  std::vector<at::Tensor>& output,
2365                  at::Tensor& input) {
2366                 std::vector<std::vector<at::Tensor>> outputs = {output};
2367                 std::vector<at::Tensor> inputs = {input};
2368                 return self->allgather(
2369                     outputs, inputs, ::c10d::AllgatherOptions());
2370               },
2371               py::arg("output_tensors"),
2372               py::arg("input_tensor"),
2373               py::call_guard<py::gil_scoped_release>())
2374           .def(
2375               "allgather_coalesced",
2376               &::c10d::Backend::allgather_coalesced,
2377               py::arg("output_lists"),
2378               py::arg("input_list"),
2379               py::arg("opts") = ::c10d::AllgatherOptions(),
2380               py::call_guard<py::gil_scoped_release>())
2381           .def(
2382               "gather",
2383               &::c10d::Backend::gather,
2384               py::arg("output_tensors"),
2385               py::arg("input_tensors"),
2386               py::arg("opts") = ::c10d::GatherOptions(),
2387               py::call_guard<py::gil_scoped_release>())
2388           .def(
2389               "gather",
2390               [](const c10::intrusive_ptr<::c10d::Backend>& self,
2391                  std::vector<at::Tensor>& output,
2392                  at::Tensor& input,
2393                  int rootRank) {
2394                 ::c10d::GatherOptions opts;
2395                 opts.rootRank = rootRank;
2396                 std::vector<std::vector<at::Tensor>> outputs = {output};
2397                 std::vector<at::Tensor> inputs = {input};
2398                 return self->gather(outputs, inputs, opts);
2399               },
2400               py::arg("output_tensors"),
2401               py::arg("input_tensor"),
2402               py::arg("root"),
2403               py::call_guard<py::gil_scoped_release>())
2404           .def(
2405               "scatter",
2406               &::c10d::Backend::scatter,
2407               py::arg("output_tensors"),
2408               py::arg("input_tensors"),
2409               py::arg("opts") = ::c10d::ScatterOptions(),
2410               py::call_guard<py::gil_scoped_release>())
2411           .def(
2412               "scatter",
2413               [](const c10::intrusive_ptr<::c10d::Backend>& self,
2414                  at::Tensor& output,
2415                  std::vector<at::Tensor>& input,
2416                  int rootRank) {
2417                 ::c10d::ScatterOptions opts;
2418                 opts.rootRank = rootRank;
2419                 std::vector<std::vector<at::Tensor>> inputs = {input};
2420                 std::vector<at::Tensor> outputs = {output};
2421                 return self->scatter(outputs, inputs, opts);
2422               },
2423               py::arg("output_tensor"),
2424               py::arg("input_tensors"),
2425               py::arg("root"),
2426               py::call_guard<py::gil_scoped_release>())
2427           .def(
2428               "reduce_scatter",
2429               &::c10d::Backend::reduce_scatter,
2430               py::arg("output_tensors"),
2431               py::arg("input_tensors"),
2432               py::arg("opts") = ::c10d::ReduceScatterOptions(),
2433               py::call_guard<py::gil_scoped_release>())
2434           .def(
2435               "reduce_scatter",
2436               [](const c10::intrusive_ptr<::c10d::Backend>& self,
2437                  at::Tensor& output,
2438                  std::vector<at::Tensor>& input,
2439                  const ::c10d::ReduceOp& op) {
2440                 std::vector<at::Tensor> outputs = {output};
2441                 std::vector<std::vector<at::Tensor>> inputs = {input};
2442                 ::c10d::ReduceScatterOptions opts;
2443                 opts.reduceOp = op;
2444                 return self->reduce_scatter(outputs, inputs, opts);
2445               },
2446               py::arg("output_tensors"),
2447               py::arg("input_tensor"),
2448               py::arg("op") = ::c10d::ReduceOp::SUM,
2449               py::call_guard<py::gil_scoped_release>())
2450           .def(
2451               "_reduce_scatter_base",
2452               &::c10d::Backend::_reduce_scatter_base,
2453               py::arg("outputTensor"),
2454               py::arg("inputTensor"),
2455               py::arg("opts") = ::c10d::ReduceScatterOptions(),
2456               py::call_guard<py::gil_scoped_release>())
2457           .def(
2458               "alltoall_base",
2459               &::c10d::Backend::alltoall_base,
2460               py::arg("output_tensor"),
2461               py::arg("input_tensor"),
2462               py::arg("output_split_sizes"),
2463               py::arg("input_split_sizes"),
2464               py::arg("opts") = ::c10d::AllToAllOptions(),
2465               py::call_guard<py::gil_scoped_release>())
2466           .def(
2467               "alltoall_base",
2468               [](::c10d::Backend& self,
2469                  at::Tensor& output,
2470                  at::Tensor& input,
2471                  std::vector<int64_t> outputSplitSizes,
2472                  std::vector<int64_t> inputSplitSizes) {
2473                 return self.alltoall_base(
2474                     output,
2475                     input,
2476                     outputSplitSizes,
2477                     inputSplitSizes,
2478                     ::c10d::AllToAllOptions());
2479               },
2480               py::arg("output"),
2481               py::arg("input"),
2482               py::arg("output_split_sizes"),
2483               py::arg("input_split_sizes"),
2484               py::call_guard<py::gil_scoped_release>())
2485           .def(
2486               "alltoall",
2487               &::c10d::Backend::alltoall,
2488               py::arg("output_tensor"),
2489               py::arg("input_tensor"),
2490               py::arg("opts") = ::c10d::AllToAllOptions(),
2491               py::call_guard<py::gil_scoped_release>())
2492           .def(
2493               "send",
2494               &::c10d::Backend::send,
2495               py::arg("tensors"),
2496               py::arg("dstRank"),
2497               py::arg("tag"),
2498               py::call_guard<py::gil_scoped_release>())
2499           .def(
2500               "recv",
2501               &::c10d::Backend::recv,
2502               py::arg("tensors"),
2503               py::arg("srcRank"),
2504               py::arg("tag"),
2505               py::call_guard<py::gil_scoped_release>())
2506           .def(
2507               "recv_anysource",
2508               &::c10d::Backend::recvAnysource,
2509               py::call_guard<py::gil_scoped_release>())
2510           .def(
2511               "barrier",
2512               [](const c10::intrusive_ptr<::c10d::Backend>& self,
2513                  const ::c10d::BarrierOptions& opts) {
2514                 return self->barrier(opts);
2515               },
2516               py::arg("opts") = ::c10d::BarrierOptions(),
2517               py::call_guard<py::gil_scoped_release>())
2518           .def(
2519               "_set_sequence_number_for_group",
2520               &::c10d::Backend::setSequenceNumberForGroup,
2521               py::call_guard<py::gil_scoped_release>())
2522           .def(
2523               "_get_sequence_number_for_group",
2524               &::c10d::Backend::getSequenceNumberForGroup,
2525               py::call_guard<py::gil_scoped_release>())
2526           .def(
2527               "monitored_barrier",
2528               [](const c10::intrusive_ptr<::c10d::Backend>& self,
2529                  const std::chrono::milliseconds& timeout,
2530                  bool waitAllRanks) {
2531                 ::c10d::BarrierOptions opts;
2532                 opts.timeout = timeout;
2533                 return self->monitoredBarrier(opts, waitAllRanks);
2534               },
2535               py::arg("timeout") = ::c10d::kUnsetTimeout,
2536               py::arg("wait_all_ranks") = false,
2537               py::call_guard<py::gil_scoped_release>())
2538           .def(
2539               "eager_connect_single_device",
2540               &::c10d::Backend::eagerConnectSingleDevice,
2541               py::call_guard<py::gil_scoped_release>())
2542           .def(
2543               "_get_backend_name",
2544               &::c10d::Backend::getBackendName,
2545               py::call_guard<py::gil_scoped_release>())
2546           .def(
2547               "_start_coalescing",
2548               &::c10d::Backend::startCoalescing,
2549               py::call_guard<py::gil_scoped_release>())
2550           .def(
2551               "_end_coalescing",
2552               &::c10d::Backend::endCoalescing,
2553               py::call_guard<py::gil_scoped_release>());
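
  // A hedged usage sketch for the monitored_barrier binding above: the public
  // torch.distributed.monitored_barrier wrapper routes here for Gloo-backed
  // groups (the timeout value is illustrative):
  //   >>> import torch.distributed as dist
  //   >>> from datetime import timedelta
  //   >>> dist.monitored_barrier(timeout=timedelta(seconds=30), wait_all_ranks=True)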
2554 
2555 #ifdef USE_C10D_GLOO
2556   static const std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME";
2557 
2558   auto processGroupGloo =
2559       intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupGloo>(
2560           module, "ProcessGroupGloo", backend);
2561 
2562   // NOLINTNEXTLINE(bugprone-unused-raii)
2563   shared_ptr_class_<::gloo::transport::Device>(processGroupGloo, "Device");
2564 
2565   intrusive_ptr_class_<::c10d::ProcessGroupGloo::Options>(
2566       processGroupGloo, "_Options", processGroupOptions)
2567       .def(py::init<>())
2568       .def_readwrite("_devices", &::c10d::ProcessGroupGloo::Options::devices)
2569       .def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads);
2570 
2571   processGroupGloo
2572       .def_static(
2573           "create_device",
2574           [](const std::string& hostname, const std::string& interface)
2575               -> std::shared_ptr<::gloo::transport::Device> {
2576             if (!hostname.empty()) {
2577               return ::c10d::ProcessGroupGloo::createDeviceForHostname(
2578                   hostname);
2579             }
2580             if (!interface.empty()) {
2581               return ::c10d::ProcessGroupGloo::createDeviceForInterface(
2582                   interface);
2583             }
2584             throw std::invalid_argument(
2585                 "Specify either `hostname` or `interface` argument.");
2586           },
2587           py::arg("hostname") = "",
2588           py::arg("interface") = "")
2589       .def_static(
2590           "create_default_device",
2591           &::c10d::ProcessGroupGloo::createDefaultDevice);
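
  // A hedged Python sketch of the static device helpers bound above: pass
  // exactly one of `hostname` or `interface` to create_device (the interface
  // name below is illustrative, and ProcessGroupGloo is only re-exported on
  // Gloo-enabled builds):
  //   >>> from torch.distributed import ProcessGroupGloo
  //   >>> dev = ProcessGroupGloo.create_device(interface="eth0")
  //   >>> default_dev = ProcessGroupGloo.create_default_device()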
2592 
2593   processGroupGloo
2594       .def(
2595           py::init<
2596               const c10::intrusive_ptr<::c10d::Store>&,
2597               int,
2598               int,
2599               c10::intrusive_ptr<::c10d::ProcessGroupGloo::Options>>(),
2600           py::call_guard<py::gil_scoped_release>())
2601       .def(
2602           py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
2603                       int rank,
2604                       int size,
2605                       std::chrono::milliseconds timeout) {
2606             auto options = ::c10d::ProcessGroupGloo::Options::create();
2607 
2608             // Use interfaces listed in "GLOO_SOCKET_IFNAME", if set.
2609             char* ifnameEnv = getenv(GLOO_SOCKET_IFNAME_ENV.c_str());
2610             if (ifnameEnv && strlen(ifnameEnv) > 1) {
2611               for (const auto& iface : ::c10d::split(',', ifnameEnv)) {
2612                 options->devices.push_back(
2613                     ::c10d::ProcessGroupGloo::createDeviceForInterface(iface));
2614               }
2615             } else {
2616               // Otherwise, fall back to the default device: look up the
2617               // machine's hostname and use a device instance associated
2618               // with the address that the hostname resolves to.
2619               options->devices.push_back(
2620                   ::c10d::ProcessGroupGloo::createDefaultDevice());
2621             }
2622 
2623             options->timeout = timeout;
2624             // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
2625             options->threads = options->devices.size() * 2;
2626             return c10::make_intrusive<::c10d::ProcessGroupGloo>(
2627                 store, rank, size, options);
2628           }),
2629           py::arg("store"),
2630           py::arg("rank"),
2631           py::arg("size"),
2632           py::arg("timeout") = kProcessGroupDefaultTimeout,
2633           py::call_guard<py::gil_scoped_release>())
2634       .def(
2635           "_set_default_timeout",
2636           [](const c10::intrusive_ptr<::c10d::ProcessGroupGloo>& self,
2637              std::chrono::milliseconds timeout) {
2638             self->getOptions()->timeout = timeout;
2639           },
2640           py::arg("timeout"),
2641           py::call_guard<py::gil_scoped_release>())
2642       .def_property_readonly("options", &::c10d::ProcessGroupGloo::getOptions);
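
  // A hedged construction sketch for the bindings above, run on each of the
  // two ranks (the FileStore path and world size are illustrative; if
  // GLOO_SOCKET_IFNAME is set, the init lambda above picks the listed
  // interfaces instead of the default device):
  //   >>> import torch
  //   >>> import torch.distributed as dist
  //   >>> store = dist.FileStore("/tmp/gloo_store", 2)
  //   >>> pg = dist.ProcessGroupGloo(store, rank=0, size=2)
  //   >>> pg.allreduce([torch.ones(4)]).wait()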
2643 
2644   // ProcessGroupWrapper is a wrapper process group that includes a helper Gloo
2645   // process group. It can be used to validate collective calls across processes by
2646   // checking the op type and input tensor shapes.
2647   auto processGroupWrapper =
2648       intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupWrapper>(
2649           module, "_ProcessGroupWrapper", backend)
2650           .def(
2651               py::init(
2652                   [](const c10::intrusive_ptr<::c10d::Backend>& backend,
2653                      const c10::intrusive_ptr<::c10d::Backend>& gloo_backend) {
2654                     return c10::make_intrusive<::c10d::ProcessGroupWrapper>(
2655                         backend, gloo_backend);
2656                   }),
2657               py::arg("backend"),
2658               py::arg("gloo_backend"),
2659               py::call_guard<py::gil_scoped_release>())
2660           .def_property_readonly(
2661               "wrapped_pg", &::c10d::ProcessGroupWrapper::getWrappedPg);
2662 #endif
2663 
2664 #ifdef USE_C10D_NCCL
2665   auto processGroupNCCL =
2666       intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupNCCL>(
2667           module, "ProcessGroupNCCL", backend)
2668           .def(
2669               py::init<
2670                   const c10::intrusive_ptr<::c10d::Store>&,
2671                   int,
2672                   int,
2673                   c10::intrusive_ptr<::c10d::ProcessGroupNCCL::Options>>(),
2674               py::call_guard<py::gil_scoped_release>())
2675           .def(
2676               py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
2677                           int rank,
2678                           int size,
2679                           const std::chrono::milliseconds& timeout) {
2680                 auto options = ::c10d::ProcessGroupNCCL::Options::create();
2681                 options->is_high_priority_stream = false;
2682                 options->timeout = timeout;
2683                 return c10::make_intrusive<::c10d::ProcessGroupNCCL>(
2684                     store, rank, size, options);
2685               }),
2686               py::arg("store"),
2687               py::arg("rank"),
2688               py::arg("size"),
2689               py::arg("timeout") = ::c10d::kProcessGroupNCCLDefaultTimeout,
2690               py::call_guard<py::gil_scoped_release>())
2691           .def(
2692               "_shutdown",
2693               [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) {
2694                 return self->shutdown();
2695               },
2696               py::call_guard<py::gil_scoped_release>())
2697           .def("_group_start", &::c10d::ProcessGroupNCCL::groupStart)
2698           .def("_group_end", &::c10d::ProcessGroupNCCL::groupEnd)
2699           .def(
2700               "comm_split_count",
2701               &::c10d::ProcessGroupNCCL::getCommSplitCounter)
2702           .def(
2703               "_set_default_timeout",
2704               [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self,
2705                  std::chrono::milliseconds timeout) {
2706                 self->getOptions()->timeout = timeout;
2707               },
2708               py::arg("timeout"),
2709               py::call_guard<py::gil_scoped_release>())
2710           .def(
2711               "_add_ephemeral_timeout",
2712               [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self,
2713                  const std::chrono::milliseconds& timeout) {
2714                 self->addEphemeralTimeout(timeout);
2715               },
2716               py::arg("timeout"))
2717           .def(
2718               "_verify_work_timeout",
2719               [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self,
2720                  const c10::intrusive_ptr<::c10d::Work> work,
2721                  const std::chrono::milliseconds& timeout) {
2722                 return self->verifyWorkTimeoutForTest(work, timeout);
2723               },
2724               py::arg("work"),
2725               py::arg("timeout"))
2726           .def_property_readonly(
2727               "options", &::c10d::ProcessGroupNCCL::getOptions)
2728           .def_property_readonly("uid", &::c10d::ProcessGroupNCCL::getUid)
2729           .def_property(
2730               "bound_device_id",
2731               &::c10d::ProcessGroupNCCL::getBoundDeviceId,
2732               &::c10d::ProcessGroupNCCL::setBoundDeviceId)
2733           .def(
2734               "perform_nocolor_split",
2735               &::c10d::ProcessGroupNCCL::performNocolorSplit);
2736 
2737   module.def(
2738       "_get_intra_node_comm_usage_counter",
2739       &::c10d::intra_node_comm::getIntraNodeCommUsageCounter);
2740 
2741   using IntraNodeComm = ::c10d::intra_node_comm::IntraNodeComm;
2742   py::class_<IntraNodeComm, c10::intrusive_ptr<IntraNodeComm>>(
2743       module, "_IntraNodeComm")
2744       .def(
2745           py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
2746                       size_t rank,
2747                       size_t world_size,
2748                       std::optional<size_t> buffer_size) {
2749             auto comm = c10::make_intrusive<IntraNodeComm>(
2750                 store, rank, world_size, buffer_size);
2751             if (!comm->rendezvous()) {
2752               throw std::runtime_error("IntraNodeComm::rendezvous failed");
2753             }
2754             return comm;
2755           }),
2756           py::arg("store"),
2757           py::arg("rank"),
2758           py::arg("world_size"),
2759           py::arg("buffer_size") = std::nullopt)
2760       .def("barrier", &IntraNodeComm::barrier, py::arg("ranks") = py::none());
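
  // A hedged usage sketch for the private _IntraNodeComm binding above; all
  // participating ranks on the node must construct it collectively, since the
  // init lambda performs a rendezvous (store, rank and world_size are
  // whatever the caller already has):
  //   >>> comm = torch._C._distributed_c10d._IntraNodeComm(store, rank, world_size)
  //   >>> comm.barrier()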
2761 
2762 #ifdef NCCL_HAS_COMM_CTA_CGA
2763   py::class_<ncclConfig_t>(
2764       processGroupNCCL,
2765       "NCCLConfig",
2766       R"(
2767 ncclConfig_t data type for configuring NCCL communicators.
2768 See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
2769 for details.
2770 )")
2771       .def(py::init<>())
2772       .def_readwrite("blocking", &ncclConfig_t::blocking)
2773       .def_readwrite("cga_cluster_size", &ncclConfig_t::cgaClusterSize)
2774       .def_readwrite("min_ctas", &ncclConfig_t::minCTAs)
2775       .def_readwrite("max_ctas", &ncclConfig_t::maxCTAs)
2776 #ifdef NCCL_HAS_COMM_SPLIT
2777       .def_readwrite("split_share", &ncclConfig_t::splitShare)
2778 #endif
2779       .def_property(
2780           "net_name",
2781           [](const ncclConfig_t& self) { return self.netName; },
2782           // Note: NCCL calls free on the netName pointer
2783           // when destroying the communicator. So memory
2784           // shouldn't leak because of allocation in strdup.
2785           [](ncclConfig_t& self, const char* tmp) {
2786             self.netName = strdup(tmp);
2787           });
2788 #endif
2789 
2790   intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>(
2791       processGroupNCCL,
2792       "Options",
2793       processGroupOptions,
2794       R"(
2795 ProcessGroup options for the NCCL backend
2796 
2797 Arguments:
2798     is_high_priority_stream (bool, optional): flag to enable/disable the
2799             process group picking up high-priority CUDA streams. It lets the
2800             CUDA driver prioritize NCCL kernels when there are compute
2801             kernels waiting. Default is False.
2802 
2803 Attributes:
2804     config (NCCLConfig): configures NCCL communicators (only available for
2805             builds using NCCL 2.17+). This can be used to improve
2806             communication-computation overlap for NCCL kernels by tuning
2807             available parameters in the config. See
2808             https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
2809             for details.
2810 
2811 Example::
2812     >>> import torch.distributed as dist
2813     >>>
2814     >>> nccl_options = dist.ProcessGroupNCCL.Options(is_high_priority_stream=True)
2815     >>> # For builds using NCCL 2.17+, configure communicators
2816     >>> nccl_options.config.cga_cluster_size = 2
2817     >>> nccl_options.config.max_ctas = 4
2818     >>> nccl_options.config.min_ctas = 2
2819     >>> nccl_options.config.split_share = 1
2820     >>> # initialize an NCCL process group with the options just created
2821     >>> dist.init_process_group("nccl", pg_options=nccl_options)
2822       )")
2823       .def(py::init<bool>(), py::arg("is_high_priority_stream") = false)
2824 #ifdef NCCL_HAS_COMM_CTA_CGA
2825       .def_readwrite("config", &::c10d::ProcessGroupNCCL::Options::config)
2826 #endif
2827       .def_readwrite(
2828           "is_high_priority_stream",
2829           &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream)
2830       .def_readwrite(
2831           "split_from", &::c10d::ProcessGroupNCCL::Options::split_from)
2832       .def_readwrite(
2833           "split_color", &::c10d::ProcessGroupNCCL::Options::split_color)
2834       .def_readwrite(
2835           "global_ranks_in_group",
2836           &::c10d::ProcessGroupNCCL::Options::global_ranks_in_group)
2837       .def_readwrite(
2838           "group_name", &::c10d::ProcessGroupNCCL::Options::group_name);
2839 #endif
2840 
2841 #ifdef USE_C10D_MPI
2842   auto processGroupMPI =
2843       intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupMPI>(
2844           module, "ProcessGroupMPI", backend);
2845 
2846   // Define a static create function instead of a constructor, because
2847   // this function may return null. This happens if this process is not
2848   // part of the sub-group that is to be created.
2849   processGroupMPI.def_static(
2850       "create",
2851       [](std::vector<int> ranks) {
2852         return ::c10d::ProcessGroupMPI::createProcessGroupMPI(std::move(ranks));
2853       },
2854       py::call_guard<py::gil_scoped_release>());
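
  // A hedged Python sketch, assuming the MPI backend's convention that an
  // empty rank list selects all ranks; create() returns None on processes
  // that are not part of the requested sub-group, as noted above:
  //   >>> pg = torch._C._distributed_c10d.ProcessGroupMPI.create([])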
2855 #endif
2856 
2857 #ifdef USE_C10D_UCC
2858   auto processGroupUCC =
2859       intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupUCC>(
2860           module, "ProcessGroupUCC", backend)
2861           .def(
2862               py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
2863                           int rank,
2864                           int size,
2865                           const std::chrono::milliseconds& timeout) {
2866                 return c10::make_intrusive<::c10d::ProcessGroupUCC>(
2867                     store, rank, size, timeout);
2868               }),
2869               py::arg("store"),
2870               py::arg("rank"),
2871               py::arg("size"),
2872               py::arg("timeout") = kProcessGroupDefaultTimeout,
2873               py::call_guard<py::gil_scoped_release>());
2874 #endif
2875 
2876   py::enum_<::c10d::OpType>(module, "OpType")
2877       .value("BROADCAST", ::c10d::OpType::BROADCAST)
2878       .value("ALLREDUCE", ::c10d::OpType::ALLREDUCE)
2879       .value("ALLREDUCE_COALESCED", ::c10d::OpType::ALLREDUCE_COALESCED)
2880       .value("REDUCE", ::c10d::OpType::REDUCE)
2881       .value("ALLGATHER", ::c10d::OpType::ALLGATHER)
2882       .value("_ALLGATHER_BASE", ::c10d::OpType::_ALLGATHER_BASE)
2883       .value("ALLGATHER_COALESCED", ::c10d::OpType::ALLGATHER_COALESCED)
2884       .value("GATHER", ::c10d::OpType::GATHER)
2885       .value("SCATTER", ::c10d::OpType::SCATTER)
2886       .value("REDUCE_SCATTER", ::c10d::OpType::REDUCE_SCATTER)
2887       .value("ALLTOALL_BASE", ::c10d::OpType::ALLTOALL_BASE)
2888       .value("ALLTOALL", ::c10d::OpType::ALLTOALL)
2889       .value("SEND", ::c10d::OpType::SEND)
2890       .value("RECV", ::c10d::OpType::RECV)
2891       .value("RECVANYSOURCE", ::c10d::OpType::RECVANYSOURCE)
2892       .value("BARRIER", ::c10d::OpType::BARRIER)
2893       .value("_REDUCE_SCATTER_BASE", ::c10d::OpType::_REDUCE_SCATTER_BASE)
2894       .value("COALESCED", ::c10d::OpType::COALESCED)
2895       .value("_ALLREDUCE_SPARSE", ::c10d::OpType::_ALLREDUCE_SPARSE)
2896       .value("UNKNOWN", ::c10d::OpType::UNKNOWN);
2897 
2898   py::class_<::c10d::WorkInfo, std::shared_ptr<::c10d::WorkInfo>>(
2899       module, "WorkInfo")
2900       .def_readonly("op_type", &::c10d::WorkInfo::opType)
2901       .def_readonly("seq", &::c10d::WorkInfo::seq)
2902       .def_readonly("time_started", &::c10d::WorkInfo::timeStarted)
2903       .def_readonly("time_finished", &::c10d::WorkInfo::timeFinished)
2904       .def_readonly("active_duration", &::c10d::WorkInfo::activeDuration);
2905 
2906   py::class_<
2907       ::c10d::Work,
2908       c10::intrusive_ptr<::c10d::Work>,
2909       ::c10d::PyProcessGroup::PyWork>(module, "Work", R"(
2910 A ``Work`` object represents the handle to a pending asynchronous operation in
2911 PyTorch's distributed package. It is returned by non-blocking collective operations,
2912 such as ``dist.all_reduce(tensor, async_op=True)``.
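
Example::
    >>> # a minimal sketch, assuming torch.distributed (as ``dist``) has an
    >>> # initialized default process group and ``tensor`` is a local tensor
    >>> work = dist.all_reduce(tensor, async_op=True)
    >>> work.wait()  # blocks until the collective has completed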
2913 )")
2914       .def(py::init<>())
2915       .def("is_completed", &::c10d::Work::isCompleted)
2916       .def(
2917           "is_success",
2918           [](::c10d::Work& work) -> bool {
2919             TORCH_WARN_ONCE(
2920                 fmt::format(kDeprecationWarning, "Work::is_success"));
2921             return work.isSuccess();
2922           })
2923       .def(
2924           "exception",
2925           [](::c10d::Work& work) -> std::exception_ptr {
2926             TORCH_WARN_ONCE(
2927                 fmt::format(kDeprecationWarning, "Work::exception"));
2928             return work.exception();
2929           })
2930       .def(
2931           "source_rank",
2932           [](::c10d::Work& work) -> int {
2933             TORCH_WARN_ONCE(
2934                 fmt::format(kDeprecationWarning, "Work::source_rank"));
2935             return work.sourceRank();
2936           })
2937       .def("_source_rank", &::c10d::Work::sourceRank)
2938       .def(
2939           "result",
2940           [](::c10d::Work& work) -> std::vector<at::Tensor> {
2941             // Deprecation reason:
2942             // Work.result() returns a vector of tensors. This signature is
2943             // problematic as some collectives may just return one tensor
2944             // (e.g. all-reduce), while some others may return multiple
2945             // tensors (e.g. all-gather).
2946             // Deprecating work.result() would
2947             // also allow us to remove the `outputs_` field in the Work
2948             // class, avoiding an "artificial" reference to the tensors,
2949             // which could potentially hold up the tensors' memory.
2950             TORCH_WARN_ONCE(fmt::format(kDeprecationWarning, "Work::result"));
2951             return work.result();
2952           })
2953       .def(
2954           "synchronize",
2955           [](::c10d::Work& work) -> void {
2956             TORCH_WARN_ONCE(
2957                 fmt::format(kDeprecationWarning, "Work::synchronize"));
2958             work.synchronize();
2959           })
2960       .def(
2961           "wait",
2962           &::c10d::Work::wait,
2963           py::arg("timeout") = kNoTimeout,
2964           py::call_guard<py::gil_scoped_release>())
2965       .def(
2966           "get_future",
2967           [](::c10d::Work& work) -> std::shared_ptr<jit::PythonFutureWrapper> {
2968             return std::make_shared<jit::PythonFutureWrapper>(work.getFuture());
2969           },
2970           R"(
2971             Returns:
2972                 A ``torch.futures.Future`` object which is associated with the completion of
2973                 the ``Work``. As an example, a future object can be retrieved
2974                 by ``fut = process_group.allreduce(tensors).get_future()``.
2975 
2976             Example::
2977                 Below is an example of a simple allreduce DDP communication hook that uses
2978                 ``get_future`` API to retrieve a Future associated with the completion of
2979                 ``allreduce``.
2980 
2981                 >>> def allreduce(process_group: dist.ProcessGroup, bucket: dist.GradBucket) -> torch.futures.Future:
2982                 >>>     group_to_use = process_group if process_group is not None else torch.distributed.group.WORLD
2983                 >>>     tensor = bucket.buffer().div_(group_to_use.size())
2984                 >>>     return torch.distributed.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
2985                 >>> ddp_model.register_comm_hook(state=None, hook=allreduce)
2986 
2987             .. warning ::
2988                 The ``get_future`` API supports the NCCL backend, and partially the GLOO and MPI
2989                 backends (no support for peer-to-peer operations like send/recv); it returns a ``torch.futures.Future``.
2990 
2991                 In the example above, the ``allreduce`` work will be done on GPU using the NCCL backend;
2992                 ``fut.wait()`` will return after synchronizing the appropriate NCCL streams
2993                 with PyTorch's current device streams, so that we get asynchronous CUDA
2994                 execution without waiting for the entire operation to complete on the GPU. Note that
2995                 ``CUDAFuture`` does not support the ``TORCH_NCCL_BLOCKING_WAIT`` flag or NCCL's ``barrier()``.
2996                 In addition, if a callback function was added by ``fut.then()``, it will wait until
2997                 ``WorkNCCL``'s NCCL streams synchronize with ``ProcessGroupNCCL``'s dedicated callback
2998                 stream and invoke the callback inline after running the callback on the callback stream.
2999                 ``fut.then()`` will return another ``CUDAFuture`` that holds the return value of the
3000                 callback and a ``CUDAEvent`` that recorded the callback stream.
3001 
3002                     1. For CPU work, ``fut.done()`` returns true when work has been completed and value()
3003                        tensors are ready.
3004                     2. For GPU work, ``fut.done()`` returns true only when the operation has been enqueued.
3005                     3. For mixed CPU-GPU work (e.g. sending GPU tensors with GLOO), ``fut.done()`` returns
3006                        true when tensors have arrived on the respective nodes, but not necessarily been
3007                        synced on the respective GPUs yet (similar to GPU work).
3008            )")
3009       .def(
3010           "_get_op_type",
3011           [](::c10d::Work& work) -> int {
3012             return static_cast<int>(work.retrieveOpType());
3013           })
3014       .def(
3015           "_get_duration",
3016           &::c10d::Work::getDuration,
3017           py::call_guard<py::gil_scoped_release>(),
3018           R"(
3019               Returns:
3020                   Duration of the corresponding collective communication.
3021 
3022               .. warning ::
3023                   This API currently only works for the NCCL backend, and requires the
3024                   TORCH_NCCL_ENABLE_TIMING environment variable to be set.
3025             )")
3026       .def(
3027           "boxed",
3028           [](c10::intrusive_ptr<::c10d::Work> self) {
3029             return torch::jit::toPyObject(c10::IValue(std::move(self)));
3030           })
3031       .def_static("unbox", [](py::object obj) {
3032         auto typePtr =
3033             torch::getCustomClass("__torch__.torch.classes.c10d.Work");
3034         auto ivalue = torch::jit::toIValue(std::move(obj), typePtr);
3035         return ivalue.toCustomClass<::c10d::Work>();
3036       });
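
  // A hedged round-trip sketch for the boxed()/unbox() pair above, which
  // moves a Work handle through a TorchScript custom-class object:
  //   >>> boxed = work.boxed()
  //   >>> same_work = torch._C._distributed_c10d.Work.unbox(boxed)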
3037 
3038   auto fakeProcessGroup =
3039       intrusive_ptr_no_gil_destructor_class_<::c10d::FakeProcessGroup>(
3040           module, "FakeProcessGroup", backend)
3041           .def(py::init([](int rank, int size) {
3042             return c10::make_intrusive<::c10d::FakeProcessGroup>(rank, size);
3043           }));
3044 
3045   py::class_<c10::DDPLoggingData>(module, "DDPLoggingData")
3046       .def(py::init<>())
3047       .def_readwrite("strs_map", &c10::DDPLoggingData::strs_map)
3048       .def_readwrite("ints_map", &c10::DDPLoggingData::ints_map);
3049 
3050   module.def(
3051       "_compute_bucket_assignment_by_size",
3052       [](const std::vector<at::Tensor>& tensors,
3053          const std::vector<size_t>& bucket_size_limits,
3054          const std::vector<bool>& expect_sparse_gradient,
3055          const std::vector<int64_t>& tensor_indices,
3056          const std::optional<std::shared_ptr<::c10d::Logger>>& logger) {
3057         if (logger.has_value()) {
3058           std::weak_ptr<::c10d::Logger> logger_weakref = logger.value();
3059           return ::c10d::compute_bucket_assignment_by_size(
3060               tensors,
3061               bucket_size_limits,
3062               expect_sparse_gradient,
3063               tensor_indices,
3064               {logger_weakref});
3065         } else {
3066           return ::c10d::compute_bucket_assignment_by_size(
3067               tensors,
3068               bucket_size_limits,
3069               expect_sparse_gradient,
3070               tensor_indices,
3071               {});
3072         }
3073       },
3074       py::arg("tensors"),
3075       py::arg("bucket_size"),
3076       py::arg("expect_sparse_gradient") = std::vector<bool>(),
3077       py::arg("tensor_indices") = std::vector<int64_t>(),
3078       py::arg("logger") = std::optional<std::shared_ptr<::c10d::Logger>>{},
3079       py::call_guard<py::gil_scoped_release>());
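
  // A hedged sketch of the private helper bound above (the tensors and the
  // 1 KiB bucket limit are illustrative); it returns the per-bucket tensor
  // indices together with the size limit applied to each bucket:
  //   >>> buckets, limits = torch._C._distributed_c10d._compute_bucket_assignment_by_size(
  //   ...     [torch.ones(8), torch.ones(16)], [1024])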
3080 
3081   module.def(
3082       "_verify_params_across_processes",
3083       [](const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group,
3084          const std::vector<at::Tensor>& params,
3085          const std::optional<std::shared_ptr<::c10d::Logger>>& logger) {
3086         if (logger.has_value()) {
3087           std::weak_ptr<::c10d::Logger> logger_weakref = logger.value();
3088           verify_params_across_processes(
3089               process_group, params, {logger_weakref});
3090         } else {
3091           verify_params_across_processes(process_group, params, {});
3092         }
3093       },
3094       py::arg("process_group"),
3095       py::arg("params"),
3096       py::arg("logger") = std::optional<std::shared_ptr<::c10d::Logger>>{},
3097       py::call_guard<py::gil_scoped_release>());
3098 
3099   module.def(
3100       "_broadcast_coalesced",
3101       // Define a lambda such that the pybind11 prototype can take a std::vector
3102       // for the tensor list argument, but still pass it to the underlying
3103       // function as a c10::ArrayRef.
3104       [](const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group,
3105          const std::vector<at::Tensor>& tensors,
3106          size_t buffer_size,
3107          int rank) {
3108         broadcast_coalesced(process_group, tensors, buffer_size, rank);
3109       },
3110       py::arg("process_group"),
3111       py::arg("tensors"),
3112       py::arg("buffer_size"),
3113       // The source of truth rank to broadcast the tensors from.
3114       py::arg("src") = 0,
3115       py::call_guard<py::gil_scoped_release>());
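
  // A hedged sketch of the private helper bound above (`pg` and `tensors` are
  // whatever the caller already has; the 256 MiB buffer size and source rank
  // 0 are illustrative):
  //   >>> torch._C._distributed_c10d._broadcast_coalesced(
  //   ...     pg, tensors, 256 * 1024 * 1024, 0)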
3116 
3117   module.def(
3118       "_test_python_store",
3119       // Define a function that takes a c10d store and runs a few tests.
3120       // This is used by the PythonStore tests, which we cannot test from the
3121       // Python side of the world. Calling Python functions on a Python object
3122       // completely bypasses pybind11. We need to test that the overloaded
3123       // functions call into Python and behave like we expect.
3124       [](c10::intrusive_ptr<::c10d::Store> store) {
3125         auto add = [&store](const std::string& key, int64_t value) {
3126           store->add(key, value);
3127         };
3128 
3129         auto set = [&store](const std::string& key, const std::string& value) {
3130           store->set(key, value);
3131         };
3132 
3133         auto get = [&store](const std::string& key) {
3134           auto value = store->get(key);
3135           return std::string(value.begin(), value.end());
3136         };
3137 
3138         add("key", 1);
3139         add("key", 2);
3140         add("key", 3);
3141         set("key0", "value0");
3142         add("key3", 1);
3143         set("key1", "value1");
3144         add("key3", 2);
3145         set("key2", "value2");
3146         add("key3", 3);
3147         add("key3", 4);
3148         add("key3", 3);
3149         add("key3", 2);
3150         if (get("key") != "6") {
3151           TORCH_CHECK(false, "assertion failed");
3152         }
3153         if (get("key0") != "value0") {
3154           TORCH_CHECK(false, "assertion failed");
3155         }
3156         if (get("key1") != "value1") {
3157           TORCH_CHECK(false, "assertion failed");
3158         }
3159         if (get("key2") != "value2") {
3160           TORCH_CHECK(false, "assertion failed");
3161         }
3162         if (get("key3") != "15") {
3163           TORCH_CHECK(false, "assertion failed");
3164         }
3165       },
3166       py::call_guard<py::gil_scoped_release>());
3167 
3168   module.attr("_DEFAULT_FIRST_BUCKET_BYTES") = ::c10d::kDefaultFirstBucketBytes;
3169   module.attr("_DEFAULT_PG_TIMEOUT") = py::cast(kProcessGroupDefaultTimeout);
3170 #ifdef USE_C10D_NCCL
3171   module.attr("_DEFAULT_PG_NCCL_TIMEOUT") =
3172       py::cast(::c10d::kProcessGroupNCCLDefaultTimeout);
3173 #endif
3174   module.attr("_DEFAULT_NO_TIMEOUT") = py::cast(kNoTimeout);
3175 
3176   module.def(
3177       "_set_global_rank",
3178       [](int64_t rank) { c10::SetGlobalRank(rank); },
3179       py::arg("rank"),
3180       R"(
3181         Arguments:
3182           rank (int): The rank of this process in the default process group
3183         Informs the C++ runtime of this process's rank in the default process
3184         group (a strictly Python notion).  This mostly ensures that C++ log
3185         messages are prefixed with rank information.  This is not meant to be
3186         called manually; it is called by _update_default_pg.
3187       )");
3188 
3189   module.def(
3190       "_create_work_from_future",
3191       [](const std::shared_ptr<jit::PythonFutureWrapper>& future) {
3192         return ::c10d::Work::create_from_future(future->fut);
3193       },
3194       py::arg("future"),
3195       R"(
3196         Arguments:
3197             future(torch.futures.Future): The future to wrap.
3198         Returns:
3199             A ``Work`` object which is associated with the completion of
3200             the ``torch.futures.Future``.
3201         This is the preferred way of constructing Work objects when writing a custom ProcessGroup
3202         in python.
3203         in Python.
3204             >>> class SingleRankProcessGroup(torch.distributed.ProcessGroup):
3205             >>>     def broadcast(self, tensor_list, opts):
3206             >>>         fut = torch.futures.Future()
3207             >>>         fut.set_result(tensor_list)
3208             >>>         return torch._C._distributed_c10d._create_work_from_future(fut)
3209         .. warning ::
3210             This API is experimental and subject to change.
3211             The returned Work object has multiple limitations:
3212             - synchronize() does nothing. Use ``torch.futures.Future``-based synchronization.
3213             - wait() ignores the timeout argument.
3214             - sourceRank() raises.
3215             - abort() raises.
3216             The provided Future object result must be a Tensor or a list of Tensors.
3217            )");
3218 
3219 #ifdef USE_C10D_NCCL
3220   module.def(
3221       "_hash_tensors",
3222       [](const std::vector<at::Tensor>& tensors) {
3223         return ::c10d::hashTensors(tensors);
3224       },
3225       py::arg("tensors"),
3226       R"(
3227         Arguments:
3228           tensors (List[torch.Tensor]): List of tensors we want to hash.
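
        Example::
            >>> # a minimal sketch (NCCL builds only, per the surrounding #ifdef)
            >>> torch._C._distributed_c10d._hash_tensors([torch.ones(4)])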
3229       )");
3230   module.def(
3231       "_dump_nccl_trace_json",
3232       [](std::optional<bool> includeCollectives,
3233          std::optional<bool> onlyActive) {
3234         return py::bytes(::c10d::dump_nccl_trace_json(
3235             includeCollectives.value_or(true), onlyActive.value_or(false)));
3236       },
3237       py::arg("includeCollectives") = std::optional<bool>(),
3238       py::arg("onlyActive") = std::optional<bool>(),
3239       R"(
3240       Arguments:
3241             includeCollectives (bool, optional): Whether to include collective work traces. Default is True.
3242             onlyActive (bool, optional): Whether to only include active collective work traces. Default is False.
3243       Returns:
3244             Stringified JSON work traces.
3245             Default settings return everything, i.e. they contain NCCL comm dumps and collective traces.
3246       )");
3247   module.def(
3248       "_dump_nccl_trace",
3249       [](std::optional<bool> includeCollectives,
3250          std::optional<bool> includeStackTraces,
3251          std::optional<bool> onlyActive) {
3252         return py::bytes(::c10d::dump_nccl_trace(
3253             includeCollectives.value_or(true),
3254             includeStackTraces.value_or(true),
3255             onlyActive.value_or(false)));
3256       },
3257       py::arg("includeCollectives") = std::optional<bool>(),
3258       py::arg("includeStackTraces") = std::optional<bool>(),
3259       py::arg("onlyActive") = std::optional<bool>(),
3260       R"(
3261         Arguments:
3262             includeCollectives (bool, optional): Whether to include collective work traces. Default is True.
3263             includeStackTraces (bool, optional): Whether to include stacktraces in the collective work traces. Default is True.
3264             onlyActive (bool, optional): Whether to only include active collective work traces. Default is False.
3265         Returns:
3266             Stringified pickle work traces.
3267             Default settings return everything, i.e. they contain NCCL comm dumps and collective traces.
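
        Example::
            >>> # a minimal sketch; the returned bytes unpickle into Python objects
            >>> import pickle
            >>> trace = pickle.loads(
            ...     torch._C._distributed_c10d._dump_nccl_trace(onlyActive=True))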
3268       )");
3269 #endif
3270 
3271   intrusive_ptr_class_<::c10d::control_plane::WorkerServer>(
3272       module, "_WorkerServer", R"(
      See c10d::control_plane::WorkerServer for docs.
3273 )")
3274       .def(
3275           py::init([](const std::string& hostOrFile, int port) {
3276             return c10::make_intrusive<::c10d::control_plane::WorkerServer>(
3277                 hostOrFile, port);
3278           }),
3279           py::arg("host_or_file"),
3280           py::arg("port") = -1)
3281       .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown);
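
  // A hedged lifecycle sketch for the _WorkerServer binding above (the
  // address and port are illustrative; per the parameter name, `host_or_file`
  // may instead be a filesystem path for a local socket):
  //   >>> server = torch._C._distributed_c10d._WorkerServer("127.0.0.1", 1234)
  //   >>> # ... control-plane handlers serve while the job runs ...
  //   >>> server.shutdown()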
3282 
3283   module.def(
3284       "_get_handler",
3285       [](const std::string& name) -> py::cpp_function {
3286         return py::cpp_function(
3287             ::c10d::control_plane::getHandler(name),
3288             py::arg("request"),
3289             py::arg("response"),
3290             py::call_guard<py::gil_scoped_release>());
3291       },
3292       py::arg("name"),
3293       R"(
3294       Returns the handler with the specified name.
3295     )");
3296 
3297   module.def(
3298       "_get_handler_names",
3299       &::c10d::control_plane::getHandlerNames,
3300       R"(
3301       Returns the names of all handlers.
3302     )",
3303       py::call_guard<py::gil_scoped_release>());
3304 
3305   py::class_<::c10d::control_plane::Request, PythonRequest>(
3306       module,
3307       "_Request",
3308       R"(
3309       See c10d::control_plane::Request for docs.
3310 )")
3311       // Default constructor.
3312       .def(py::init<>())
3313       .def("body", &::c10d::control_plane::Request::body)
3314       .def("params", &::c10d::control_plane::Request::params);
3315 
3316   py::class_<
3317       ::c10d::control_plane::Response,
3318       std::shared_ptr<::c10d::control_plane::Response>,
3319       PythonResponse>(
3320       module,
3321       "_Response",
3322       R"(
3323       See c10d::control_plane::Response for docs.
3324 )")
3325       // Default constructor.
3326       .def(py::init<>())
3327       .def(
3328           "set_content",
3329           &::c10d::control_plane::Response::setContent,
3330           py::arg("content"),
3331           py::arg("content_type"))
3332       .def(
3333           "set_status",
3334           &::c10d::control_plane::Response::setStatus,
3335           py::arg("status"));
3336 
3337   Py_RETURN_TRUE;
3338 }
3339 
3340 #undef PROCESS_GROUP_DEPRECATION_WARNING
3341 
3342 } // namespace
3343 
3344 // c10d methods on torch._C
3345 static PyMethodDef methods[] = { // NOLINT
3346     {"_c10d_init", c10d_init, METH_NOARGS, nullptr},
3347     {nullptr, nullptr, 0, nullptr}};
3348 
3349 PyMethodDef* python_functions() {
3350   return methods;
3351 }
3352 
3353 } // namespace torch::distributed::c10d
3354