# mypy: allow-untyped-defs
from functools import partial

import torch
from torch.futures import Future

from . import functions, rpc_async
from .constants import UNSET_RPC_TIMEOUT


def _local_invoke(rref, func_name, args, kwargs):
    """Call method ``func_name`` on the local value held by ``rref``.

    This is the function shipped to ``rref.owner()`` by ``_invoke_rpc``.
    """
    return getattr(rref.local_value(), func_name)(*args, **kwargs)


@functions.async_execution
def _local_invoke_async_execution(rref, func_name, args, kwargs):
    """Same as ``_local_invoke``, but decorated with ``functions.async_execution``.

    ``_invoke_rpc`` ships this variant instead of ``_local_invoke`` when the
    target method carries the ``_wrapped_async_rpc_function`` attribute.
    """
    return getattr(rref.local_value(), func_name)(*args, **kwargs)


def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs):
    """Invoke method ``func_name`` on the owner of ``rref`` through ``rpc_api``.

    The RRef's type is fetched first (non-blocking) so we can check whether the
    target method was marked for async execution; that decides which local
    invoker is sent to the owner.

    Args:
        rref: the RRef whose owner runs the method.
        rpc_api: RPC entry point, called as
            ``rpc_api(to, func, args=..., timeout=...)``.
        func_name: name of the method to call on the RRef's local value.
        timeout: forwarded to both the type lookup and the RPC call.
        *args, **kwargs: forwarded to the remote method.

    Returns:
        If ``rpc_api`` is ``rpc_async``: a ``Future`` completed with the remote
        method's result, or with the exception raised anywhere along the way.
        Otherwise: whatever ``rpc_api`` returns.
    """

    def _rref_type_cont(rref_fut):
        rref_type = rref_fut.value()

        _invoke_func = _local_invoke
        # Bypass ScriptModules when checking for async function attribute.
        bypass_type = issubclass(rref_type, torch.jit.ScriptModule) or issubclass(
            rref_type, torch._C.ScriptModule
        )
        if not bypass_type:
            func = getattr(rref_type, func_name)
            if hasattr(func, "_wrapped_async_rpc_function"):
                _invoke_func = _local_invoke_async_execution

        return rpc_api(
            rref.owner(),
            _invoke_func,
            args=(rref, func_name, args, kwargs),
            timeout=timeout,
        )

    rref_fut = rref._get_type(timeout=timeout, blocking=False)

    if rpc_api is not rpc_async:
        # Any API other than rpc_async: wait for the type here, then issue the
        # call directly and hand back rpc_api's own return value.
        rref_fut.wait()
        return _rref_type_cont(rref_fut)

    # rpc_async returns a Future[T] for the method's result. Chaining
    # _rref_type_cont via ``then`` would produce a Future[Future[T]], so we
    # instead return a fresh Future completed with the inner call's outcome.
    result: Future = Future()

    def _complete_op(fut):
        try:
            result.set_result(fut.value())
        except BaseException as ex:
            result.set_exception(ex)

    def _wrap_rref_type_cont(fut):
        # Failures in the type lookup, the attribute check, or issuing the RPC
        # must reach the caller through ``result``, not vanish in a callback.
        try:
            _rref_type_cont(fut).then(_complete_op)
        except BaseException as ex:
            result.set_exception(ex)

    rref_fut.then(_wrap_rref_type_cont)
    return result


# This class manages proxied RPC API calls for RRefs. It is entirely used from
# C++ (see python_rpc_handler.cpp).
class RRefProxy:
    """Proxy whose unknown attribute lookups become RPC calls on an RRef's owner.

    ``proxy.some_method(*args, **kwargs)`` is forwarded to ``_invoke_rpc`` with
    the stored RRef, RPC API and timeout.
    """

    def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT):
        self.rref = rref
        self.rpc_api = rpc_api
        self.rpc_timeout = timeout

    def __getattr__(self, func_name):
        # __getattr__ only runs for names not found normally. If one of our own
        # fields is missing, __init__ never ran (e.g. an instance created via
        # __new__ by copy/pickle). Proxying it would recurse forever through
        # ``self.rref`` below, so report a plain missing attribute instead.
        if func_name in ("rref", "rpc_api", "rpc_timeout"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {func_name!r}"
            )
        return partial(
            _invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout
        )