# mypy: allow-untyped-defs
import collections
import contextlib
import copyreg
import io
import pickle
import sys
import threading
import traceback
from enum import Enum

import torch
import torch.distributed as dist
from torch._C._distributed_rpc import _get_current_rpc_agent


__all__ = ["RPCExecMode", "serialize", "deserialize", "PythonUDF", "RemoteException"]

# Thread local tensor tables to store tensors while pickling torch.Tensor
# objects
_thread_local_tensor_tables = threading.local()
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler

# Sentinel meaning "this thread had no table bound before we installed ours".
_NO_TABLE = object()


@contextlib.contextmanager
def _bind_thread_local_table(name, table):
    r"""
    Bind ``_thread_local_tensor_tables.<name>`` to ``table`` for the duration
    of the ``with`` body and yield ``table``.

    On exit the previous binding is restored, or the attribute is removed if
    there was none. This also happens when the body raises, so nested
    serialize()/deserialize() calls on one thread stay isolated from each
    other. It also means a failed call cannot leave a stale table behind for
    the next call on the same thread.
    """
    previous = getattr(_thread_local_tensor_tables, name, _NO_TABLE)
    setattr(_thread_local_tensor_tables, name, table)
    try:
        yield table
    finally:
        if previous is _NO_TABLE:
            delattr(_thread_local_tensor_tables, name)
        else:
            setattr(_thread_local_tensor_tables, name, previous)


class RPCExecMode(Enum):
    SYNC = "sync"
    ASYNC = "async"
    ASYNC_JIT = "async_jit"
    REMOTE = "remote"


class _InternalRPCPickler:
    r"""
    This class provides serialize() and deserialize() interfaces to serialize
    data to be "binary string + tensor table" format
    So for RPC python UDF function and args, non tensor data will be serialized
    into regular binary string, tensor data will be put into thread local tensor
    tables, this serialization format is consistent with builtin operator and args
    using JIT pickler. This format will make tensor handling in C++ much easier,
    e.g. attach tensor to distributed autograd graph in C++
    """

    def __init__(self):
        # Ignore type error because dispatch_table is defined in third-party package
        self._dispatch_table = copyreg.dispatch_table.copy()  # type: ignore[attr-defined]
        self._dispatch_table[torch.Tensor] = self._tensor_reducer
        # Used for registering customized picklers.
        self._class_reducer_dict = {}

    def _register_reducer(self, obj_class, reducer):
        # For the same class, only register the reducer once: the first
        # registration wins and later ones are ignored.
        self._class_reducer_dict.setdefault(obj_class, reducer)

    @classmethod
    def _tensor_receiver(cls, tensor_index):
        # Called while unpickling; deserialize() binds recv_tables first.
        return _thread_local_tensor_tables.recv_tables[tensor_index]

    def _tensor_reducer(self, tensor):
        # Called while pickling; serialize() binds send_tables first. The
        # tensor is moved out of band and only its table index is pickled.
        send_tables = _thread_local_tensor_tables.send_tables
        send_tables.append(tensor)
        return (_InternalRPCPickler._tensor_receiver, (len(send_tables) - 1,))

    @classmethod
    def _py_rref_receiver(cls, rref_fork_data):
        return dist.rpc.PyRRef._deserialize(rref_fork_data)

    def _py_rref_reducer(self, py_rref):
        rref_fork_data = py_rref._serialize()
        return (_InternalRPCPickler._py_rref_receiver, (rref_fork_data,))

    def _rref_reducer(self, rref):
        return self._py_rref_reducer(rref)

    @classmethod
    def _script_module_receiver(cls, script_module_serialized):
        """
        Given a serialized representation of a ScriptModule created with torch.jit.save,
        loads and returns the ScriptModule.
        """
        f = io.BytesIO(script_module_serialized)
        m = torch.jit.load(f)
        return m

    def _script_module_reducer(self, script_module):
        """
        Serializes a ScriptModule.
        """
        f = io.BytesIO()
        torch.jit.save(script_module, f)
        return (_InternalRPCPickler._script_module_receiver, (f.getvalue(),))

    def serialize(self, obj):
        r"""
        Serialize non tensor data into binary string, tensor data into
        tensor table

        Returns:
            A ``(bytes, tensors)`` tuple. ``tensors`` is the list of every
            ``torch.Tensor`` met while pickling ``obj``, in pickling order.
            The bytes refer to those tensors by index into this list.

        Note: the reducers below are written into ``self._dispatch_table``,
        which is shared by every call (it is handed to the pickler by
        reference), so registrations persist across calls.
        """
        table = self._dispatch_table

        # rpc api could accept user picklers inheriting from _InternalRPCPickler to serialize rref,
        # user picklers could have different initialization function from _InternalRPCPickler,
        # but all the user picklers should call serialize() and use _rref_reducer to pickle rref
        # in python. also, when _internal_rpc_pickler is imported to rpc/api.py, rpc.RRef is not
        # compiled yet, it is not good place to access rpc.RRef inside _InternalRPCPickler constructor,
        # so putting rref's dispatch table here
        #
        # The return value of a `rpc.remote(..)` call is type of `rpc.PyRRef`.
        # The deserialized RRef object on an RPC receiver side is type of `rpc.PyRRef`.
        table[dist.rpc.PyRRef] = self._py_rref_reducer
        # An RRef created locally by RRef Python constructor is type of `rpc.RRef`.
        table[dist.rpc.RRef] = self._rref_reducer

        # Add dispatch pickling for ScriptModule or its subclass. dispatch_table
        # is keyed by exact type, so the concrete subclass must be registered.
        if isinstance(obj, torch.jit.ScriptModule):
            table[obj.__class__] = self._script_module_reducer

        # Install customized picklers (applied last so they take precedence).
        table.update(self._class_reducer_dict)

        f = io.BytesIO()
        p = _pickler(f)
        p.dispatch_table = table

        # Use a fresh send table for this call. Any outer (nested) table is
        # restored afterwards, even if pickling raises.
        with _bind_thread_local_table("send_tables", []) as tensors:
            p.dump(obj)

        return (f.getvalue(), tensors)

    def deserialize(self, binary_data, tensor_table):
        r"""
        Deserialize binary string + tensor table to original obj

        If a pickled function/class cannot be found on this side, the
        resulting ``AttributeError`` is *returned* (not raised) with an
        explanatory message and the original error as ``__cause__``.

        SECURITY NOTE: this runs ``pickle`` on bytes received from an RPC
        peer, and unpickling can execute arbitrary code. Only use it with
        trusted peers.
        """
        try:
            # Expose tensor_table to _tensor_receiver for this call only; any
            # outer (nested) table is restored afterwards, even on error.
            with _bind_thread_local_table("recv_tables", tensor_table):
                ret = _unpickler(io.BytesIO(binary_data)).load()
        except AttributeError as e:
            # Occurs when function is not found on module/class during
            # unpickling.
            except_str = (
                str(e)
                + """ Default RPC pickler does not serialize
            function code. Ensure that UDFs are defined on both caller and
            callee modules."""
            )
            ret = AttributeError(except_str)
            # Ensure the stack trace gets preserved
            ret.__cause__ = e

        return ret


# Create _internal_rpc_pickler only once to initialize _dispatch_table only once
_internal_rpc_pickler = _InternalRPCPickler()


def serialize(obj):
    return _internal_rpc_pickler.serialize(obj)


def deserialize(binary_data, tensor_table):
    return _internal_rpc_pickler.deserialize(binary_data, tensor_table)


def _run_function(python_udf):
    r"""
    This function is exclusively called from C++.
    See ``torch/csrc/distributed/rpc/python_rpc_handler.cpp``.

    Runs a Python UDF and returns its return value.
    Wraps any exception in ``RemoteException`` if the function raises.
    """
    try:
        # deserialize() returns (rather than raises) an AttributeError when
        # the UDF could not be resolved; surface it as a remote exception.
        if isinstance(python_udf, AttributeError):
            raise python_udf
        result = python_udf.func(*python_udf.args, **python_udf.kwargs)
    except Exception as e:
        # except str = exception info + traceback string
        except_str = (
            f"On {_get_current_rpc_agent().get_worker_info()}:\n"
            f"{repr(e)}\n{traceback.format_exc()}"
        )
        print(except_str, file=sys.stderr)
        result = RemoteException(except_str, type(e))
    return result


def _handle_exception(result):
    r"""
    Re-raise the remote error carried by a ``RemoteException``.

    If ``result`` is a ``RemoteException``, an instance of its
    ``exception_type`` is raised with the decoded message. If that exception
    type cannot be built from a single string, a ``RuntimeError`` is raised
    instead. Any other ``result`` is ignored.
    """
    if not isinstance(result, RemoteException):
        return
    # NOTE(review): "unicode_escape" turns escape sequences in the message
    # into real characters, but it decodes the UTF-8 bytes as Latin-1, so
    # non-ASCII text in the message is mangled. Kept as-is for compatibility.
    exception_msg = result.msg.encode("utf-8").decode("unicode_escape")
    # We wrap exception re-creation here in case some exception classes
    # cannot be constructed directly from a string.
    try:
        exc = result.exception_type(exception_msg)
    except BaseException as e:
        raise RuntimeError(
            f"Failed to create original exception type. Error msg was {str(e)}"
            f" Original exception on remote side was {exception_msg}"
        ) from e

    if exc is not None:
        raise exc


def _build_rpc_profiling_key(
    exec_type, func_name, current_worker_name, dst_worker_name
):
    """
    Builds the key that RPC calls are profiled with using the autograd profiler.
    This will be the name of the corresponding Event recorded in the profiler.

    Args:
        exec_type (RPCExecMode): Type of RPC/RRef call
        func_name (str): Name of function being profiled.
        current_worker_name (str): Name of current worker.
        dst_worker_name (str): Name of the destination worker.

    Returns:
        String representing profiling key
    """
    profile_key = (
        f"rpc_{exec_type.value}#{func_name}({current_worker_name} -> {dst_worker_name})"
    )
    return profile_key


def _start_record_function(exec_type, func_name, current_worker_name, dest_worker_name):
    """
    This function should be called from RPC/RRef functions to create a
    RecordFunction object for profiling. This function also runs the before
    callbacks that start the profiling, though the user is responsible for
    running the appropriate callbacks when the function to be profiled finishes.

    Args:
        exec_type (RPCExecMode): Type of RPC/RRef call
        func_name (str): Name of function being profiled.
        current_worker_name (str): Name of current worker.
        dest_worker_name (str): Name of the destination worker.

    Returns:
        An instance of `torch.autograd._RecordFunction`.
    """
    assert torch.autograd._profiler_enabled(), "Autograd profiler should be enabled."
    # Reuse the shared key format so profiled RPC events are named consistently.
    profile_key = _build_rpc_profiling_key(
        exec_type, str(func_name), current_worker_name, dest_worker_name
    )
    rf = torch.autograd._RecordFunction()  # type: ignore[attr-defined]
    torch.autograd._run_before_callbacks(rf, profile_key)  # type: ignore[attr-defined]
    return rf


PythonUDF = collections.namedtuple("PythonUDF", ["func", "args", "kwargs"])
RemoteException = collections.namedtuple("RemoteException", ["msg", "exception_type"])