# mypy: allow-untyped-defs
import functools


def async_execution(fn):
    r"""
    Mark ``fn`` as a function that returns a :class:`~torch.futures.Future`
    and may therefore be executed asynchronously on the RPC callee.

    The decorated function promises that its return value is a
    :class:`~torch.futures.Future`. On the callee, RPC takes that
    :class:`~torch.futures.Future` and installs the remaining processing steps
    as a callback on it. Once the :class:`~torch.futures.Future` completes,
    the callback reads its value and sends that value back as the RPC
    response. Consequently, the returned :class:`~torch.futures.Future` lives
    only on the callee and is never sent through RPC. Use this decorator when
    the execution of the wrapped function (``fn``) has to pause and resume,
    e.g., because it contains :meth:`~torch.distributed.rpc.rpc_async` or
    waits for other signals.

    .. note:: To enable asynchronous execution, applications must pass the
        function object returned by this decorator to RPC APIs. When RPC
        detects the attributes installed by this decorator, it knows that the
        function returns a ``Future`` object and handles it accordingly.
        This does not mean that this decorator has to be the outermost one
        when defining a function. For example, when combined with
        ``@staticmethod`` or ``@classmethod``,
        ``@rpc.functions.async_execution`` needs to be the inner decorator so
        that the target function is recognized as a static or class function.
        The target function can still execute asynchronously because, when
        accessed, the static or class method preserves the attributes
        installed by ``@rpc.functions.async_execution``.

    Args:
        fn: the callable to mark. It is expected to return a
            :class:`~torch.futures.Future`.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``fn`` and returns its result unchanged. The wrapper carries the
        metadata of ``fn`` (via :func:`functools.update_wrapper`) and an
        attribute ``_wrapped_async_rpc_function`` that refers to ``fn``.

    Example::
        The returned :class:`~torch.futures.Future` object can come from
        :meth:`~torch.distributed.rpc.rpc_async`,
        :meth:`~torch.futures.Future.then`, or the
        :class:`~torch.futures.Future` constructor. The example below
        directly uses the :class:`~torch.futures.Future` returned by
        :meth:`~torch.futures.Future.then`.

        >>> from torch.distributed import rpc
        >>>
        >>> # omitting setup and shutdown RPC
        >>>
        >>> # On all workers
        >>> @rpc.functions.async_execution
        >>> def async_add_chained(to, x, y, z):
        >>>     # This function runs on "worker1" and returns immediately when
        >>>     # the callback is installed through the `then(cb)` API. In the
        >>>     # meantime, the `rpc_async` to "worker2" can run concurrently.
        >>>     # When the return value of that `rpc_async` arrives at
        >>>     # "worker1", "worker1" will run the lambda function accordingly
        >>>     # and set the value for the previously returned `Future`, which
        >>>     # will then trigger RPC to send the result back to "worker0".
        >>>     return rpc.rpc_async(to, torch.add, args=(x, y)).then(
        >>>         lambda fut: fut.wait() + z
        >>>     )
        >>>
        >>> # On worker0
        >>> # xdoctest: +SKIP
        >>> ret = rpc.rpc_sync(
        >>>     "worker1",
        >>>     async_add_chained,
        >>>     args=("worker2", torch.ones(2), 1, 1)
        >>> )
        >>> print(ret)  # prints tensor([3., 3.])

        When combined with TorchScript decorators, this decorator must be the
        outermost one.

        >>> from torch import Tensor
        >>> from torch.futures import Future
        >>> from torch.distributed import rpc
        >>>
        >>> # omitting setup and shutdown RPC
        >>>
        >>> # On all workers
        >>> @torch.jit.script
        >>> def script_add(x: Tensor, y: Tensor) -> Tensor:
        >>>     return x + y
        >>>
        >>> @rpc.functions.async_execution
        >>> @torch.jit.script
        >>> def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
        >>>     return rpc.rpc_async(to, script_add, (x, y))
        >>>
        >>> # On worker0
        >>> ret = rpc.rpc_sync(
        >>>     "worker1",
        >>>     async_add,
        >>>     args=("worker2", torch.ones(2), 1)
        >>> )
        >>> print(ret)  # prints tensor([2., 2.])

        When combined with a static or class method, this decorator must be
        the inner one.

        >>> from torch.distributed import rpc
        >>>
        >>> # omitting setup and shutdown RPC
        >>>
        >>> # On all workers
        >>> class AsyncExecutionClass:
        >>>
        >>>     @staticmethod
        >>>     @rpc.functions.async_execution
        >>>     def static_async_add(to, x, y, z):
        >>>         return rpc.rpc_async(to, torch.add, args=(x, y)).then(
        >>>             lambda fut: fut.wait() + z
        >>>         )
        >>>
        >>>     @classmethod
        >>>     @rpc.functions.async_execution
        >>>     def class_async_add(cls, to, x, y, z):
        >>>         ret_fut = torch.futures.Future()
        >>>         rpc.rpc_async(to, torch.add, args=(x, y)).then(
        >>>             lambda fut: ret_fut.set_result(fut.wait() + z)
        >>>         )
        >>>         return ret_fut
        >>>
        >>>     @rpc.functions.async_execution
        >>>     def bound_async_add(self, to, x, y, z):
        >>>         return rpc.rpc_async(to, torch.add, args=(x, y)).then(
        >>>             lambda fut: fut.wait() + z
        >>>         )
        >>>
        >>> # On worker0
        >>> ret = rpc.rpc_sync(
        >>>     "worker1",
        >>>     AsyncExecutionClass.static_async_add,
        >>>     args=("worker2", torch.ones(2), 1, 2)
        >>> )
        >>> print(ret)  # prints tensor([4., 4.])
        >>>
        >>> ret = rpc.rpc_sync(
        >>>     "worker1",
        >>>     AsyncExecutionClass.class_async_add,
        >>>     args=("worker2", torch.ones(2), 1, 2)
        >>> )
        >>> print(ret)  # prints tensor([4., 4.])

        This decorator also works with RRef helpers, i.e.,
        :meth:`torch.distributed.rpc.RRef.rpc_sync`,
        :meth:`torch.distributed.rpc.RRef.rpc_async`, and
        :meth:`torch.distributed.rpc.RRef.remote`.

        >>> from torch.distributed import rpc
        >>>
        >>> # reuse the AsyncExecutionClass class above
        >>> rref = rpc.remote("worker1", AsyncExecutionClass)
        >>> ret = rref.rpc_sync().static_async_add("worker2", torch.ones(2), 1, 2)
        >>> print(ret)  # prints tensor([4., 4.])
        >>>
        >>> rref = rpc.remote("worker1", AsyncExecutionClass)
        >>> ret = rref.rpc_async().static_async_add("worker2", torch.ones(2), 1, 2).wait()
        >>> print(ret)  # prints tensor([4., 4.])
        >>>
        >>> rref = rpc.remote("worker1", AsyncExecutionClass)
        >>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here()
        >>> print(ret)  # prints tensor([4., 4.])
    """

    # Plain pass-through: the wrapper exists only to carry the marker
    # attribute below without mutating ``fn`` itself.
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    # Copy ``fn``'s metadata (name, docstring, __dict__, ...) onto the wrapper
    # and record ``fn`` as ``__wrapped__``.
    functools.update_wrapper(wrapper, fn)

    # Static checkers cannot declare attributes on function objects
    # (mypy#2087), so the marker is installed through ``setattr``.
    setattr(wrapper, "_wrapped_async_rpc_function", fn)
    return wrapper