1 #include <torch/csrc/python_headers.h>
2
3 #include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
4 #include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h>
5 #include <torch/csrc/distributed/rpc/py_rref.h>
6 #include <torch/csrc/distributed/rpc/python_functions.h>
7 #include <torch/csrc/distributed/rpc/python_rpc_handler.h>
8 #include <torch/csrc/distributed/rpc/request_callback_impl.h>
9 #include <torch/csrc/distributed/rpc/rpc_agent.h>
10 #include <torch/csrc/distributed/rpc/rref_context.h>
11 #include <torch/csrc/distributed/rpc/tensorpipe_agent.h>
12 #include <torch/csrc/distributed/rpc/torchscript_functions.h>
13 #include <torch/csrc/distributed/rpc/types.h>
14 #include <torch/csrc/jit/python/pybind_utils.h>
15 #include <torch/csrc/utils/object_ptr.h>
16 #include <torch/csrc/utils/pybind.h>
17 #include <torch/types.h>
18
19 #include <pybind11/chrono.h>
20 #include <pybind11/operators.h>
21
22 namespace torch::distributed::rpc {
23
24 namespace {
25
// Default timeout (100 s) used by `_delete_all_user_and_unforked_owner_rrefs`
// when the caller does not supply one; generous because graceful shutdown may
// need to wait for in-flight RRef deletion messages across all workers.
constexpr std::chrono::milliseconds kDeleteAllUsersTimeout(100000);

// Shorthand for binding a C++ type whose Python-side holder is a
// std::shared_ptr (the RPC classes below are all shared-ownership).
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
30
// Defines the Python bindings for the distributed RPC subsystem. Invoked once
// from torch._C's module initialization via the `_rpc_init` method-table
// entry below. Creates the `torch._C._distributed_rpc` submodule and
// registers on it the backend-options classes, WorkerInfo, RpcAgent, PyRRef,
// the TensorPipe agent (when built with USE_TENSORPIPE), and the free
// functions used by the torch.distributed.rpc Python front end.
// Returns Py_True on success; throws python_error if a prerequisite Python
// module fails to import.
PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
  // Ensure the Python-side package is importable before defining any
  // bindings that the package will re-export.
  auto rpc_module =
      THPObjectPtr(PyImport_ImportModule("torch.distributed.rpc"));
  if (!rpc_module) {
    throw python_error();
  }

  auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
  if (!torch_C_module) {
    throw python_error();
  }

  // All bindings live on the `_distributed_rpc` submodule of torch._C.
  auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
  auto m =
      torch_C_m.def_submodule("_distributed_rpc", "distributed rpc bindings");

  auto module = py::handle(m).cast<py::module>();

  // Abstract base options class shared by all RPC backends; concrete
  // backends (e.g. TensorPipe below) derive their options from it.
  auto rpcBackendOptions =
      shared_ptr_class_<RpcBackendOptions>(
          module,
          "RpcBackendOptions",
          R"(An abstract structure encapsulating the options passed into the RPC
          backend. An instance of this class can be passed in to
          :meth:`~torch.distributed.rpc.init_rpc` in order to initialize RPC
          with specific configurations, such as the RPC timeout and
          ``init_method`` to be used. )")
          .def(py::init<>())
          .def(
              py::init<float, std::string>(),
              py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
              py::arg("init_method") = kDefaultInitMethod)
          .def_readwrite(
              "rpc_timeout",
              &RpcBackendOptions::rpcTimeoutSeconds,
              R"(A float indicating the timeout to use for all
              RPCs. If an RPC does not complete in this timeframe, it will
              complete with an exception indicating that it has timed out.)")
          .def_readwrite(
              "init_method",
              &RpcBackendOptions::initMethod,
              R"(URL specifying how to initialize the process group.
              Default is ``env://``)");

  // The following C++ constants need to be cast so they can be used from
  // python.
  module.attr("_DEFAULT_RPC_TIMEOUT_SEC") = py::cast(kDefaultRpcTimeoutSeconds);
  module.attr("_UNSET_RPC_TIMEOUT") = py::cast(kUnsetRpcTimeout);
  module.attr("_DEFAULT_INIT_METHOD") = py::cast(kDefaultInitMethod);

  // Lightweight (name, id) identity of a worker; hashable, comparable, and
  // picklable so it can travel inside RPC messages.
  auto workerInfo =
      shared_ptr_class_<WorkerInfo>(
          module,
          "WorkerInfo",
          R"(A structure that encapsulates information of a worker in the system.
          Contains the name and ID of the worker. This class is not meant to
          be constructed directly, rather, an instance can be retrieved
          through :meth:`~torch.distributed.rpc.get_worker_info` and the
          result can be passed in to functions such as
          :meth:`~torch.distributed.rpc.rpc_sync`, :meth:`~torch.distributed.rpc.rpc_async`,
          :meth:`~torch.distributed.rpc.remote` to avoid copying a string on
          every invocation.)")
          .def(
              py::init<std::string, worker_id_t>(),
              py::arg("name"),
              py::arg("id"))
          .def_readonly(
              "name", &WorkerInfo::name_, R"(The name of the worker.)")
          .def_readonly(
              "id",
              &WorkerInfo::id_,
              R"(Globally unique id to identify the worker.)")
          .def("__eq__", &WorkerInfo::operator==, py::is_operator())
          // pybind11 suggests the syntax .def(hash(py::self)), with the
          // unqualified "hash" function call. However the
          // argument-dependent lookup for the function "hash" doesn't get
          // triggered in this context because it conflicts with the struct
          // c10::hash, so we need to use the qualified name
          // py::detail::hash, which unfortunately is in a detail namespace.
          .def(py::detail::hash(py::self)) // NOLINT
          .def(
              "__repr__",
              [](const WorkerInfo& workerInfo) {
                std::ostringstream os;
                os << workerInfo;
                return os.str();
              })
          .def(py::pickle(
              /* __getstate__ */
              [](const WorkerInfo& workerInfo) {
                return py::make_tuple(workerInfo.name_, workerInfo.id_);
              },
              /* __setstate__ */
              [](py::tuple t) {
                TORCH_CHECK(t.size() == 2, "Invalid WorkerInfo state.");

                WorkerInfo info(
                    t[0].cast<std::string>(), t[1].cast<worker_id_t>());
                return info;
              }));

  // Abstract agent interface; most methods release the GIL because they can
  // block on network/collective operations.
  auto rpcAgent =
      shared_ptr_class_<RpcAgent>(module, "RpcAgent")
          .def(
              "join",
              &RpcAgent::join,
              py::call_guard<py::gil_scoped_release>(),
              py::arg("shutdown") = false,
              py::arg("timeout") = 0)
          .def(
              "sync", &RpcAgent::sync, py::call_guard<py::gil_scoped_release>())
          .def(
              "shutdown",
              &RpcAgent::shutdown,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_worker_info",
              (const WorkerInfo& (RpcAgent::*)(void) const) &
                  RpcAgent::getWorkerInfo,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_worker_info",
              (const WorkerInfo& (RpcAgent::*)(const std::string&) const) &
                  RpcAgent::getWorkerInfo,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_worker_infos",
              &RpcAgent::getWorkerInfos,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_get_device_map",
              &RpcAgent::getDeviceMap,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_debug_info",
              &RpcAgent::getDebugInfo,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_metrics",
              &RpcAgent::getMetrics,
              py::call_guard<py::gil_scoped_release>());

  // Python-facing RRef: a reference to a value living on a (possibly remote)
  // worker, with helpers to fetch the value, build RPC proxies, and hook
  // into distributed autograd and the profiler.
  auto pyRRef =
      shared_ptr_class_<PyRRef>(module, "PyRRef", R"(
          A class encapsulating a reference to a value of some type on a remote
          worker. This handle will keep the referenced remote value alive on the
          worker. A ``UserRRef`` will be deleted when 1) no references to it in
          both the application code and in the local RRef context, or 2) the
          application has called a graceful shutdown. Invoking methods on a
          deleted RRef leads to undefined behaviors. RRef implementation only
          offers best-effort error detection, and applications should not use
          ``UserRRefs`` after ``rpc.shutdown()``.

          .. warning::
              RRefs can only be serialized and deserialized by the RPC module.
              Serializing and deserializing RRefs without RPC (e.g., Python
              pickle, torch :meth:`~torch.save` / :meth:`~torch.load`,
              JIT :meth:`~torch.jit.save` / :meth:`~torch.jit.load`, etc.) will
              lead to errors.

          Args:
              value (object): The value to be wrapped by this RRef.
              type_hint (Type, optional): Python type that should be passed to
                  ``TorchScript`` compiler as type hint for ``value``.

          Example::
              Following examples skip RPC initialization and shutdown code
              for simplicity. Refer to RPC docs for those details.

              1. Create an RRef using rpc.remote

              >>> import torch
              >>> import torch.distributed.rpc as rpc
              >>> rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
              >>> # get a copy of value from the RRef
              >>> x = rref.to_here()

              2. Create an RRef from a local object

              >>> import torch
              >>> from torch.distributed.rpc import RRef
              >>> x = torch.zeros(2, 2)
              >>> rref = RRef(x)

              3. Share an RRef with other workers

              >>> # On both worker0 and worker1:
              >>> def f(rref):
              >>>   return rref.to_here() + 1

              >>> # On worker0:
              >>> import torch
              >>> import torch.distributed.rpc as rpc
              >>> from torch.distributed.rpc import RRef
              >>> rref = RRef(torch.zeros(2, 2))
              >>> # the following RPC shares the rref with worker1, reference
              >>> # count is automatically updated.
              >>> rpc.rpc_sync("worker1", f, args=(rref,))
          )")
          .def(
              py::init<const py::object&, const py::object&>(),
              py::arg("value"),
              py::arg("type_hint") = py::none())
          .def(
              // not releasing GIL here to avoid context switch on getters
              "is_owner",
              &PyRRef::isOwner,
              R"(
                  Returns whether or not the current node is the owner of this
                  ``RRef``.
              )")
          .def(
              "confirmed_by_owner",
              &PyRRef::confirmedByOwner,
              R"(
                  Returns whether this ``RRef`` has been confirmed by the owner.
                  ``OwnerRRef`` always returns true, while ``UserRRef`` only
                  returns true when the owner knowns about this ``UserRRef``.
              )")
          .def(
              // not releasing GIL here to avoid context switch on getters
              "owner",
              &PyRRef::owner,
              R"(
                  Returns worker information of the node that owns this ``RRef``.
              )")
          .def(
              // not releasing GIL here to avoid context switch on getters
              "owner_name",
              &PyRRef::ownerName,
              R"(
                  Returns worker name of the node that owns this ``RRef``.
              )")
          .def(
              "to_here",
              &PyRRef::toHere,
              py::arg("timeout") = py::cast(kUnsetRpcTimeout),
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Blocking call that copies the value of the RRef from the owner
                  to the local node and returns it. If the current node is the
                  owner, returns a reference to the local value.

                  Args:
                      timeout (float, optional): Timeout for ``to_here``. If
                          the call does not complete within this timeframe, an
                          exception indicating so will be raised. If this
                          argument is not provided, the default RPC timeout
                          (60s) will be used.
              )")
          .def(
              "local_value",
              &PyRRef::localValue,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  If the current node is the owner, returns a reference to the
                  local value. Otherwise, throws an exception.
              )")
          .def(
              "rpc_sync",
              [](const PyRRef& self, float timeoutSeconds) {
                return self.createRRefProxy(
                    RRefProxyType::RPC_SYNC, timeoutSeconds);
              },
              py::arg("timeout") = kUnsetRpcTimeout,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Create a helper proxy to easily launch an ``rpc_sync`` using
                  the owner of the RRef as the destination to run functions on
                  the object referenced by this RRef. More specifically,
                  ``rref.rpc_sync().func_name(*args, **kwargs)`` is the same as
                  the following:

                  >>> def run(rref, func_name, args, kwargs):
                  >>>   return getattr(rref.local_value(), func_name)(*args, **kwargs)
                  >>>
                  >>> rpc.rpc_sync(rref.owner(), run, args=(rref, func_name, args, kwargs))

                  Args:
                      timeout (float, optional): Timeout for ``rref.rpc_sync()``.
                          If the call does not complete within this timeframe, an
                          exception indicating so will be raised. If this argument
                          is not provided, the default RPC timeout will be used.

                  Example::
                      >>> from torch.distributed import rpc
                      >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
                      >>> rref.rpc_sync().size()  # returns torch.Size([2, 2])
                      >>> rref.rpc_sync().view(1, 4)  # returns tensor([[1., 1., 1., 1.]])
              )")
          .def(
              "rpc_async",
              [](const PyRRef& self, float timeoutSeconds) {
                return self.createRRefProxy(
                    RRefProxyType::RPC_ASYNC, timeoutSeconds);
              },
              py::arg("timeout") = kUnsetRpcTimeout,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Create a helper proxy to easily launch an ``rpc_async`` using
                  the owner of the RRef as the destination to run functions on
                  the object referenced by this RRef. More specifically,
                  ``rref.rpc_async().func_name(*args, **kwargs)`` is the same as
                  the following:

                  >>> def run(rref, func_name, args, kwargs):
                  >>>   return getattr(rref.local_value(), func_name)(*args, **kwargs)
                  >>>
                  >>> rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs))

                  Args:
                      timeout (float, optional): Timeout for ``rref.rpc_async()``.
                          If the call does not complete within this timeframe, an
                          exception indicating so will be raised. If this argument
                          is not provided, the default RPC timeout will be used.

                  Example::
                      >>> from torch.distributed import rpc
                      >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
                      >>> rref.rpc_async().size().wait()  # returns torch.Size([2, 2])
                      >>> rref.rpc_async().view(1, 4).wait()  # returns tensor([[1., 1., 1., 1.]])
              )")
          .def(
              "remote",
              [](const PyRRef& self, float timeoutSeconds) {
                return self.createRRefProxy(
                    RRefProxyType::REMOTE, timeoutSeconds);
              },
              py::arg("timeout") = kUnsetRpcTimeout,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Create a helper proxy to easily launch a ``remote`` using
                  the owner of the RRef as the destination to run functions on
                  the object referenced by this RRef. More specifically,
                  ``rref.remote().func_name(*args, **kwargs)`` is the same as
                  the following:

                  >>> def run(rref, func_name, args, kwargs):
                  >>>   return getattr(rref.local_value(), func_name)(*args, **kwargs)
                  >>>
                  >>> rpc.remote(rref.owner(), run, args=(rref, func_name, args, kwargs))

                  Args:
                      timeout (float, optional): Timeout for ``rref.remote()``. If
                          the creation of this :class:`~torch.distributed.rpc.RRef`
                          is not successfully completed within the timeout, then the
                          next time there is an attempt to use the RRef
                          (such as ``to_here``), a timeout will be raised. If not
                          provided, the default RPC timeout will be used. Please see
                          ``rpc.remote()`` for specific timeout semantics for
                          :class:`~torch.distributed.rpc.RRef`.

                  Example::
                      >>> from torch.distributed import rpc
                      >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
                      >>> rref.remote().size().to_here()  # returns torch.Size([2, 2])
                      >>> rref.remote().view(1, 4).to_here()  # returns tensor([[1., 1., 1., 1.]])
              )")
          .def(
              py::pickle(
                  /* __getstate__ */
                  [](const PyRRef& /* unused */) {
                    TORCH_CHECK(
                        false,
                        "Can not pickle rref in python pickler, rref can only be "
                        "pickled when using RPC");
                    // Note that this return has no meaning since we always
                    // throw, it's only here to satisfy Pybind API's
                    // requirement.
                    return py::make_tuple();
                  },
                  /* __setstate__ */
                  [](py::tuple /* unused */) { // NOLINT
                    TORCH_CHECK(
                        false,
                        "Can not unpickle rref in python pickler, rref can only be "
                        "unpickled when using RPC");
                    // Note that this return has no meaning since we always
                    // throw, it's only here to satisfy PyBind's API
                    // requirement.
                    return PyRRef(
                        py::cast<py::none>(Py_None),
                        py::cast<py::none>(Py_None));
                  }),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_serialize",
              &PyRRef::pickle,
              py::call_guard<py::gil_scoped_release>())
          .def_static(
              "_deserialize",
              &PyRRef::unpickle,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_get_type",
              // Intentionally not releasing GIL, as most accesses just
              // retrieve cached type py::object
              &PyRRef::getRRefType,
              py::arg("timeout") = kUnsetRpcTimeout,
              py::arg("blocking") = true,
              R"(
                  If ``blocking=True``, returns the type of the data object
                  referenced by this ``RRef``. On the owner, this is same as
                  ``type(rref.local_value())``. Otherwise, returns a future to
                  this result. On a user, this will trigger an RPC to fetch the
                  ``type`` object from the owner. After this function is run
                  once, the ``type`` object is cached by the ``RRef``, and
                  subsequent invocations no longer trigger RPC. Note that this is
                  true regardless of the ``blocking`` argument of subsequent
                  calls.

                  Args:
                      rref (torch.distributed.rpc.RRef): The RRef to get type of.
                      timeout (float, optional): Timeout, in seconds for
                          ``_get_type``. If the call does not complete within
                          this timeframe, an exception indicating so will be
                          raised. If this argument is not provided, the default
                          RPC timeout will be used.
                      blocking (bool, optional): Whether to synchronously wait on
                          the RPC triggered by the first call and return the
                          type. If ``False``, will return a future. Default is
                          ``True``.
              )")
          .def(
              "_get_future",
              [](const PyRRef& self) {
                return std::make_shared<jit::PythonFutureWrapper>(
                    self.getFuture());
              },
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Returns the future that corresponds to the creation of this RRef
                  on the remote node. This is for internal use cases such as profiling
                  only.
              )")
          .def(
              "_get_profiling_future",
              [](const PyRRef& self) {
                return std::make_shared<jit::PythonFutureWrapper>(
                    self.getProfilingFuture());
              },
              py::call_guard<py::gil_scoped_acquire>(),
              R"(
                  Returns future that completes when the profiling event corresponding
                  to the creation of this RRef on the remote node has been recorded.
              )")
          .def(
              "_set_profiling_future",
              [](PyRRef& self,
                 const std::shared_ptr<jit::PythonFutureWrapper>&
                     wrappedFuture) {
                self.setProfilingFuture(wrappedFuture->fut);
              },
              py::call_guard<py::gil_scoped_acquire>(),
              R"(
                  Set future that is completed when the profiling event corresponding
                  to the creation of this RRef on the remote node has been recorded.
              )")
          .def(
              "backward",
              [](PyRRef& self,
                 int64_t dist_autograd_ctx_id,
                 bool retain_graph) {
                self.backward(dist_autograd_ctx_id, retain_graph);
              },
              py::arg("dist_autograd_ctx_id") = -1,
              py::arg("retain_graph") = false,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Runs the backward pass using the RRef as the root of the
                  backward pass. If ``dist_autograd_ctx_id`` is provided,
                  we perform a distributed backward pass using the provided
                  ctx_id starting from the owner of the RRef. In this case,
                  :meth:`~torch.distributed.autograd.get_gradients` should be
                  used to retrieve the gradients. If ``dist_autograd_ctx_id``
                  is ``None``, it is assumed that this is a local autograd graph
                  and we only perform a local backward pass. In the local case,
                  the node calling this API has to be the owner of the RRef.
                  The value of the RRef is expected to be a scalar Tensor.

                  Args:
                      dist_autograd_ctx_id (int, optional): The distributed
                          autograd context id for which we should retrieve the
                          gradients (default: -1).
                      retain_graph(bool, optional): If ``False``, the graph used to
                          compute the grad will be freed. Note that in nearly all
                          cases setting this option to ``True`` is not needed and
                          often can be worked around in a much more efficient way.
                          Usually, you need to set this to ``True`` to run backward
                          multiple times (default: False).

                  Example::
                      >>> import torch.distributed.autograd as dist_autograd
                      >>> with dist_autograd.context() as context_id:
                      >>>     rref.backward(context_id)
              )")
          // not releasing GIL to avoid context switch
          .def("__repr__", &PyRRef::str);

#ifdef USE_TENSORPIPE

  // Base class: torch.distributed.rpc.RpcBackendOptions.
  py::class_<TensorPipeRpcBackendOptions>(
      module, "_TensorPipeRpcBackendOptionsBase", rpcBackendOptions)
      .def(
          py::init<
              int,
              std::optional<std::vector<std::string>>,
              std::optional<std::vector<std::string>>,
              float,
              std::string,
              std::unordered_map<std::string, DeviceMap>,
              std::vector<c10::Device>>(),
          py::arg("num_worker_threads") = kDefaultNumWorkerThreads,
          py::arg("_transports") = std::optional<std::vector<std::string>>(),
          py::arg("_channels") = std::optional<std::vector<std::string>>(),
          py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
          py::arg("init_method") = kDefaultInitMethod,
          py::arg("device_maps") = std::unordered_map<std::string, DeviceMap>(),
          py::arg("devices") = std::vector<c10::Device>())
      .def_readwrite(
          "num_worker_threads",
          &TensorPipeRpcBackendOptions::numWorkerThreads,
          R"(
              The number of threads in the thread-pool used by
              :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
              requests.
          )")
      .def_readwrite(
          "device_maps",
          &TensorPipeRpcBackendOptions::deviceMaps,
          R"(The device map locations.)")
      .def_readwrite(
          "devices",
          &TensorPipeRpcBackendOptions::devices,
          R"(All devices used by the local agent.)")
      .def("_set_device_map", &TensorPipeRpcBackendOptions::setDeviceMap);

  module.attr("_DEFAULT_NUM_WORKER_THREADS") =
      py::cast(kDefaultNumWorkerThreads);

  // Concrete agent built on TensorPipe; the custom deleter releases the GIL
  // while the agent (and its worker threads) are torn down.
  shared_ptr_class_<TensorPipeAgent>(module, "TensorPipeAgent", rpcAgent)
      .def(
          py::init(
              [](const c10::intrusive_ptr<::c10d::Store>& store,
                 std::string selfName,
                 worker_id_t selfId,
                 std::optional<int> worldSize,
                 TensorPipeRpcBackendOptions opts,
                 std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
                 std::vector<c10::Device> devices) {
                return std::shared_ptr<TensorPipeAgent>(
                    new TensorPipeAgent(
                        store,
                        std::move(selfName),
                        selfId,
                        worldSize,
                        std::move(opts),
                        std::move(reverseDeviceMaps),
                        std::move(devices),
                        std::make_unique<RequestCallbackImpl>()),
                    impl::destroy_without_gil<TensorPipeAgent>);
              }),
          py::arg("store"),
          py::arg("name"),
          py::arg("rank"),
          py::arg("world_size"),
          py::arg("rpc_backend_options"),
          py::arg("reverse_device_maps"),
          py::arg("devices"))
      .def(
          "join",
          &TensorPipeAgent::join,
          py::call_guard<py::gil_scoped_release>(),
          py::arg("shutdown") = false,
          py::arg("timeout") = 0)
      .def(
          "shutdown",
          &TensorPipeAgent::shutdown,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (TensorPipeAgent::*)(void) const) &
              RpcAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (TensorPipeAgent::*)(const std::string&) const) &
              TensorPipeAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id) const) &
              TensorPipeAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_infos",
          (std::vector<WorkerInfo>(TensorPipeAgent::*)() const) &
              TensorPipeAgent::getWorkerInfos,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_device_map",
          (DeviceMap(TensorPipeAgent::*)(const WorkerInfo& dst) const) &
              TensorPipeAgent::getDeviceMap,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_backend_options",
          &TensorPipeAgent::getBackendOptions,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_update_group_membership",
          &TensorPipeAgent::updateGroupMembership,
          py::call_guard<py::gil_scoped_release>())
      .def_readonly("is_static_group", &TensorPipeAgent::isStaticGroup_)
      .def_property_readonly("store", &TensorPipeAgent::getStore);

#endif // USE_TENSORPIPE

  // Free functions manipulating the process-global current RPC agent.
  module.def("_is_current_rpc_agent_set", &RpcAgent::isCurrentRpcAgentSet);

  module.def("_get_current_rpc_agent", &RpcAgent::getCurrentRpcAgent);

  module.def(
      "_set_and_start_rpc_agent",
      [](const std::shared_ptr<RpcAgent>& rpcAgent) {
        RpcAgent::setCurrentRpcAgent(rpcAgent);
        // Initializing typeResolver inside RpcAgent constructor will make
        // RpcAgent have python dependency. To avoid RpcAgent to have python
        // dependency, setTypeResolver() here.
        std::shared_ptr<TypeResolver> typeResolver =
            std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) {
              auto typePtr = PythonRpcHandler::getInstance().parseTypeFromStr(
                  qn.qualifiedName());
              return c10::StrongTypePtr(
                  PythonRpcHandler::getInstance().jitCompilationUnit(),
                  std::move(typePtr));
            });
        rpcAgent->setTypeResolver(typeResolver);
        rpcAgent->start();
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_reset_current_rpc_agent",
      []() { RpcAgent::setCurrentRpcAgent(nullptr); },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_delete_all_user_and_unforked_owner_rrefs",
      [](std::chrono::milliseconds timeoutMillis) {
        RRefContext::getInstance().delAllUsersAndUnforkedOwners(timeoutMillis);
      },
      py::arg("timeout") = kDeleteAllUsersTimeout,
      py::call_guard<py::gil_scoped_release>());

  module.def("_destroy_rref_context", [](bool ignoreRRefLeak) {
    // NB: do not release GIL in the function. The destroyInstance() method
    // returns a list of deleted OwnerRRefs that hold py::object instances.
    // Clearing those OwnerRRefs are likely to trigger Python deref, which
    // requires GIL.
    RRefContext::getInstance().destroyInstance(ignoreRRefLeak).clear();
  });

  module.def("_rref_context_get_debug_info", []() {
    return RRefContext::getInstance().getDebugInfo();
  });

  module.def(
      "_cleanup_python_rpc_handler",
      []() { PythonRpcHandler::getInstance().cleanup(); },
      py::call_guard<py::gil_scoped_release>());

  // RPC / remote invocation entry points used by the Python front end. The
  // builtin variants keep the GIL (they manipulate py::args/kwargs); the
  // pickled-UDF and TorchScript variants release it.
  module.def(
      "_invoke_rpc_builtin",
      [](const WorkerInfo& dst,
         const std::string& opName,
         const float rpcTimeoutSeconds,
         const py::args& args,
         const py::kwargs& kwargs) {
        return std::make_shared<jit::PythonFutureWrapper>(
            pyRpcBuiltin(dst, opName, args, kwargs, rpcTimeoutSeconds));
      },
      py::call_guard<py::gil_scoped_acquire>());

  module.def(
      "_invoke_rpc_python_udf",
      [](const WorkerInfo& dst,
         std::string& pickledPythonUDF,
         std::vector<torch::Tensor>& tensors,
         const float rpcTimeoutSeconds,
         const bool isAsyncExecution) {
        return std::make_shared<jit::PythonFutureWrapper>(pyRpcPythonUdf(
            dst,
            pickledPythonUDF,
            tensors,
            rpcTimeoutSeconds,
            isAsyncExecution));
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_rpc_torchscript",
      [](const std::string& dstWorkerName,
         const std::string& qualifiedNameStr,
         const py::tuple& argsTuple,
         const py::dict& kwargsDict,
         const float rpcTimeoutSeconds,
         const bool isAsyncExecution) {
        return std::make_shared<jit::PythonFutureWrapper>(pyRpcTorchscript(
            dstWorkerName,
            qualifiedNameStr,
            argsTuple,
            kwargsDict,
            rpcTimeoutSeconds,
            isAsyncExecution));
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_remote_builtin",
      &pyRemoteBuiltin,
      py::call_guard<py::gil_scoped_acquire>());

  module.def(
      "_invoke_remote_python_udf",
      &pyRemotePythonUdf,
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_remote_torchscript",
      &pyRemoteTorchscript,
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "get_rpc_timeout",
      []() {
        // Agent stores the timeout in milliseconds; convert to seconds for
        // the Python API.
        return RpcAgent::getCurrentRpcAgent()->getRpcTimeout().count() /
            kSecToMsConversion;
      },
      R"(
          Retrieve the default timeout for all RPCs that was set during RPC initialization.
          The returned value will be in seconds.
          Returns:
            ``float`` indicating the RPC timeout in seconds.
      )");

  module.def(
      "enable_gil_profiling",
      [](bool flag) {
        RpcAgent::getCurrentRpcAgent()->enableGILProfiling(flag);
      },
      R"(
    Set whether GIL wait times should be enabled or not. This incurs a slight
    overhead cost. Default is disabled for performance reasons.

    Args:
        flag (bool): True to set GIL profiling, False to disable.
      )");

  module.def(
      "_set_rpc_timeout",
      [](const float rpcTimeoutSeconds) {
        auto rpcTimeout = std::chrono::milliseconds(
            static_cast<int>(rpcTimeoutSeconds * kSecToMsConversion));
        RpcAgent::getCurrentRpcAgent()->setRpcTimeout(rpcTimeout);
      },
      R"(
          Set the default timeout for all RPCs. The input unit is expected to be
          in seconds. If an RPC is not completed within this time, an exception
          indicating it has timed out will be raised. To control timeout for
          specific RPCs, a timeout parameter can be passed into
          :meth:`~torch.distributed.rpc.rpc_sync` and
          :meth:`~torch.distributed.rpc.rpc_async`.

          Args:
            rpcTimeoutSeconds (float): Timeout value in seconds.
      )");

  // Server-side process-global profiler controls.
  module.def(
      "_enable_server_process_global_profiler",
      &profiler::processglobal::enableServer);
  module.def(
      "_disable_server_process_global_profiler",
      &profiler::processglobal::disableServer);

  module.def("_set_profiler_node_id", &at::RecordFunction::setDefaultNodeId);

  // RemoteProfilerManager is a singleton; py::nodelete prevents Python from
  // ever destroying the C++ instance.
  py::class_<
      RemoteProfilerManager,
      std::unique_ptr<RemoteProfilerManager, py::nodelete>>(
      module, "RemoteProfilerManager")
      .def("set_current_profiling_key", [](const std::string& key) {
        auto& inst = RemoteProfilerManager::getInstance();
        inst.setCurrentKey(key);
      });

  module.def(
      "_enable_jit_rref_pickle",
      &enableJitRRefPickle,
      R"(
        Allows ``torch.jit.save`` to save a ``torch.jit.ScriptModule`` with
        pickled RRefs out of RPC contexts.


        .. warning::
            This is dangerous. If the module contains RRefs, the pickled
            result must be sent over RPC and get unpickled on the receiving side
            to restore the module. Otherwise, there will be RRef leaks, which
            can potentially lead to program hang. When using this API, it is
            applications responsibility to make sure that the above assumption
            always holds.
      )");
  module.def("_disable_jit_rref_pickle", &disableJitRRefPickle);

  Py_RETURN_TRUE;
}
847
848 } // namespace
849
// Method table handed to torch._C's module initializer; `_rpc_init` is the
// single entry point that registers all of the bindings above. The table is
// terminated by the all-null sentinel entry, per the CPython convention.
static PyMethodDef methods[] = { // NOLINT
    {"_rpc_init", rpc_init, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};
853
python_functions()854 PyMethodDef* python_functions() {
855 return methods;
856 }
857
858 } // namespace torch::distributed::rpc
859