xref: /aosp_15_r20/external/pytorch/torch/distributed/rpc/rref_proxy.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from functools import partial
3
4import torch
5from torch.futures import Future
6
7from . import functions, rpc_async
8from .constants import UNSET_RPC_TIMEOUT
9
10
11def _local_invoke(rref, func_name, args, kwargs):
12    return getattr(rref.local_value(), func_name)(*args, **kwargs)
13
14
15@functions.async_execution
16def _local_invoke_async_execution(rref, func_name, args, kwargs):
17    return getattr(rref.local_value(), func_name)(*args, **kwargs)
18
19
def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs):
    """Run method ``func_name`` on the object owned by ``rref`` via ``rpc_api``.

    The type of the referenced object is fetched first, using
    ``rref._get_type(..., blocking=False)``. That type decides which local
    helper the owner executes:

    * ``_local_invoke_async_execution`` if the method found on the type
      carries a ``_wrapped_async_rpc_function`` attribute (presumably set by
      ``functions.async_execution``);
    * ``_local_invoke`` otherwise, and always for ScriptModule types, which
      are not probed.

    Args:
        rref: the RRef whose owner executes the call.
        rpc_api: the RPC entry point used for the call. When it is
            ``rpc_async`` the result is chained without blocking; any other
            value waits for the type lookup first.
        func_name: name of the method to call on the referenced object.
        timeout: forwarded to both the type lookup and ``rpc_api``.
        *args, **kwargs: forwarded to the remote method.

    Returns:
        For ``rpc_async``: a ``Future`` completed with the remote call's value,
        or with the exception raised anywhere along the chain.
        Otherwise: whatever ``rpc_api`` returns.
    """

    def _dispatch_for_type(type_fut):
        target_type = type_fut.value()

        # ScriptModules are exempt from the async-function attribute probe.
        is_script_module = issubclass(
            target_type, torch.jit.ScriptModule
        ) or issubclass(target_type, torch._C.ScriptModule)

        wants_async_execution = False
        if not is_script_module:
            method = getattr(target_type, func_name)
            wants_async_execution = hasattr(method, "_wrapped_async_rpc_function")

        invoker = (
            _local_invoke_async_execution if wants_async_execution else _local_invoke
        )
        return rpc_api(
            rref.owner(),
            invoker,
            args=(rref, func_name, args, kwargs),
            timeout=timeout,
        )

    type_fut = rref._get_type(timeout=timeout, blocking=False)

    # Blocking APIs: wait for the type, then dispatch right away.
    if rpc_api != rpc_async:
        type_fut.wait()
        return _dispatch_for_type(type_fut)

    # rpc_async already returns a Future[T]. Dispatching from inside
    # ``type_fut.then`` would produce a Future[Future[T]], so a fresh Future is
    # returned instead. It is completed with the inner call's outcome.
    outcome: Future = Future()

    def _forward_result(call_fut):
        try:
            outcome.set_result(call_fut.value())
        except BaseException as err:
            outcome.set_exception(err)

    def _on_type_ready(type_ready_fut):
        try:
            _dispatch_for_type(type_ready_fut).then(_forward_result)
        except BaseException as err:
            # e.g. a failed type lookup, a missing method, or rpc_api raising
            # before producing a Future.
            outcome.set_exception(err)

    type_fut.then(_on_type_ready)
    return outcome
67
68
69# This class manages proxied RPC API calls for RRefs. It is entirely used from
70# C++ (see python_rpc_handler.cpp).
71class RRefProxy:
72    def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT):
73        self.rref = rref
74        self.rpc_api = rpc_api
75        self.rpc_timeout = timeout
76
77    def __getattr__(self, func_name):
78        return partial(
79            _invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout
80        )
81