xref: /aosp_15_r20/external/pytorch/torch/distributed/rpc/functions.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import functools
3
4
def async_execution(fn):
    r"""
    A decorator for a function indicating that the return value of the function
    is guaranteed to be a :class:`~torch.futures.Future` object and this
    function can run asynchronously on the RPC callee. More specifically, the
    callee extracts the :class:`~torch.futures.Future` returned by the wrapped
    function and installs subsequent processing steps as a callback to that
    :class:`~torch.futures.Future`. The installed callback will read the value
    from the :class:`~torch.futures.Future` when completed and send the
    value back as the RPC response. That also means the returned
    :class:`~torch.futures.Future` only exists on the callee side and is never
    sent through RPC. This decorator is useful when the wrapped function's
    (``fn``) execution needs to pause and resume due to, e.g., containing
    :meth:`~torch.distributed.rpc.rpc_async` or waiting for other signals.

    .. note:: To enable asynchronous execution, applications must pass the
        function object returned by this decorator to RPC APIs. If RPC detected
        attributes installed by this decorator, it knows that this function
        returns a ``Future`` object and will handle that accordingly.
        However, this does not mean this decorator has to be the outermost one
        when defining a function. For example, when combined with
        ``@staticmethod`` or ``@classmethod``, ``@rpc.functions.async_execution``
        needs to be the inner decorator to allow the target function to be
        recognized as a static or class function. This target function can
        still execute asynchronously because, when accessed, the static or
        class method preserves attributes installed by
        ``@rpc.functions.async_execution``.


    Example::
        The returned :class:`~torch.futures.Future` object can come from
        :meth:`~torch.distributed.rpc.rpc_async`,
        :meth:`~torch.futures.Future.then`, or :class:`~torch.futures.Future`
        constructor. The example below shows directly using the
        :class:`~torch.futures.Future` returned by
        :meth:`~torch.futures.Future.then`.

        >>> from torch.distributed import rpc
        >>>
        >>> # omitting setup and shutdown RPC
        >>>
        >>> # On all workers
        >>> @rpc.functions.async_execution
        >>> def async_add_chained(to, x, y, z):
        >>>     # This function runs on "worker1" and returns immediately when
        >>>     # the callback is installed through the `then(cb)` API. In the
        >>>     # meantime, the `rpc_async` to "worker2" can run concurrently.
        >>>     # When the return value of that `rpc_async` arrives at
        >>>     # "worker1", "worker1" will run the lambda function accordingly
        >>>     # and set the value for the previously returned `Future`, which
        >>>     # will then trigger RPC to send the result back to "worker0".
        >>>     return rpc.rpc_async(to, torch.add, args=(x, y)).then(
        >>>         lambda fut: fut.wait() + z
        >>>     )
        >>>
        >>> # On worker0
        >>> # xdoctest: +SKIP
        >>> ret = rpc.rpc_sync(
        >>>     "worker1",
        >>>     async_add_chained,
        >>>     args=("worker2", torch.ones(2), 1, 1)
        >>> )
        >>> print(ret)  # prints tensor([3., 3.])

        When combined with TorchScript decorators, this decorator must be the
        outermost one.

        >>> from torch import Tensor
        >>> from torch.futures import Future
        >>> from torch.distributed import rpc
        >>>
        >>> # omitting setup and shutdown RPC
        >>>
        >>> # On all workers
        >>> @torch.jit.script
        >>> def script_add(x: Tensor, y: Tensor) -> Tensor:
        >>>     return x + y
        >>>
        >>> @rpc.functions.async_execution
        >>> @torch.jit.script
        >>> def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
        >>>     return rpc.rpc_async(to, script_add, (x, y))
        >>>
        >>> # On worker0
        >>> ret = rpc.rpc_sync(
        >>>     "worker1",
        >>>     async_add,
        >>>     args=("worker2", torch.ones(2), 1)
        >>> )
        >>> print(ret)  # prints tensor([2., 2.])

        When combined with static or class method, this decorator must be the
        inner one.

        >>> from torch.distributed import rpc
        >>>
        >>> # omitting setup and shutdown RPC
        >>>
        >>> # On all workers
        >>> class AsyncExecutionClass:
        >>>
        >>>     @staticmethod
        >>>     @rpc.functions.async_execution
        >>>     def static_async_add(to, x, y, z):
        >>>         return rpc.rpc_async(to, torch.add, args=(x, y)).then(
        >>>             lambda fut: fut.wait() + z
        >>>         )
        >>>
        >>>     @classmethod
        >>>     @rpc.functions.async_execution
        >>>     def class_async_add(cls, to, x, y, z):
        >>>         ret_fut = torch.futures.Future()
        >>>         rpc.rpc_async(to, torch.add, args=(x, y)).then(
        >>>             lambda fut: ret_fut.set_result(fut.wait() + z)
        >>>         )
        >>>         return ret_fut
        >>>
        >>>     @rpc.functions.async_execution
        >>>     def bound_async_add(self, to, x, y, z):
        >>>         return rpc.rpc_async(to, torch.add, args=(x, y)).then(
        >>>             lambda fut: fut.wait() + z
        >>>         )
        >>>
        >>> # On worker0
        >>> ret = rpc.rpc_sync(
        >>>     "worker1",
        >>>     AsyncExecutionClass.static_async_add,
        >>>     args=("worker2", torch.ones(2), 1, 2)
        >>> )
        >>> print(ret)  # prints tensor([4., 4.])
        >>>
        >>> ret = rpc.rpc_sync(
        >>>     "worker1",
        >>>     AsyncExecutionClass.class_async_add,
        >>>     args=("worker2", torch.ones(2), 1, 2)
        >>> )
        >>> print(ret)  # prints tensor([4., 4.])

        This decorator also works with RRef helpers, i.e.,
        :meth:`torch.distributed.rpc.RRef.rpc_sync`,
        :meth:`torch.distributed.rpc.RRef.rpc_async`, and
        :meth:`torch.distributed.rpc.RRef.remote`.

        >>> from torch.distributed import rpc
        >>>
        >>> # reuse the AsyncExecutionClass class above
        >>> rref = rpc.remote("worker1", AsyncExecutionClass)
        >>> ret = rref.rpc_sync().static_async_add("worker2", torch.ones(2), 1, 2)
        >>> print(ret)  # prints tensor([4., 4.])
        >>>
        >>> rref = rpc.remote("worker1", AsyncExecutionClass)
        >>> ret = rref.rpc_async().static_async_add("worker2", torch.ones(2), 1, 2).wait()
        >>> print(ret)  # prints tensor([4., 4.])
        >>>
        >>> rref = rpc.remote("worker1", AsyncExecutionClass)
        >>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here()
        >>> print(ret)  # prints tensor([4., 4.])
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Pure pass-through: the RPC layer inspects the attribute installed
        # below rather than changing how the function is called locally.
        return fn(*args, **kwargs)

    # Can't declare and use attributes of function objects (mypy#2087).
    # The RPC machinery detects this attribute to treat the return value as a
    # Future to be awaited on the callee side.
    wrapper._wrapped_async_rpc_function = fn  # type: ignore[attr-defined]
    return wrapper
170