xref: /aosp_15_r20/external/pytorch/torch/distributed/rpc/internal.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import collections
3import copyreg
4import io
5import pickle
6import sys
7import threading
8import traceback
9from enum import Enum
10
11import torch
12import torch.distributed as dist
13from torch._C._distributed_rpc import _get_current_rpc_agent
14
15
__all__ = ["RPCExecMode", "serialize", "deserialize", "PythonUDF", "RemoteException"]

# Thread local tensor tables to store tensors while pickling torch.Tensor
# objects. serialize() fills `send_tables`; deserialize() reads `recv_tables`.
_thread_local_tensor_tables = threading.local()
# Stdlib pickler/unpickler classes used by _InternalRPCPickler below.
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
23
24
class RPCExecMode(Enum):
    """Execution mode of an RPC/RRef call; the string values are embedded in
    profiling keys (see ``_build_rpc_profiling_key``)."""

    SYNC = "sync"
    ASYNC = "async"
    ASYNC_JIT = "async_jit"
    REMOTE = "remote"
30
31
class _InternalRPCPickler:
    r"""
    This class provides serialize() and deserialize() interfaces to serialize
    data to be "binary string + tensor table" format
    So for RPC python UDF function and args, non tensor data will be serialized
    into regular binary string, tensor data will be put into thread local tensor
    tables, this serialization format is consistent with builtin operator and args
    using JIT pickler. This format will make tensor handling in C++ much easier,
    e.g. attach tensor to distributed autograd graph in C++
    """

    def __init__(self):
        # Ignore type error because dispatch_table is defined in third-party package
        self._dispatch_table = copyreg.dispatch_table.copy()  # type: ignore[attr-defined]
        self._dispatch_table[torch.Tensor] = self._tensor_reducer
        # Used for registering customized picklers.
        self._class_reducer_dict = {}

    def _register_reducer(self, obj_class, reducer):
        # For the same class, only register the reducer once.
        if obj_class not in self._class_reducer_dict:
            self._class_reducer_dict[obj_class] = reducer

    @classmethod
    def _tensor_receiver(cls, tensor_index):
        # Inverse of _tensor_reducer: look the tensor up by index in the
        # thread-local receive table installed by deserialize().
        global _thread_local_tensor_tables
        return _thread_local_tensor_tables.recv_tables[tensor_index]

    def _tensor_reducer(self, tensor):
        # Instead of pickling tensor bytes, append the tensor to the
        # thread-local send table and pickle only its index.
        global _thread_local_tensor_tables
        _thread_local_tensor_tables.send_tables.append(tensor)
        tensor_index = len(_thread_local_tensor_tables.send_tables) - 1
        return (_InternalRPCPickler._tensor_receiver, (tensor_index,))

    @classmethod
    def _py_rref_receiver(cls, rref_fork_data):
        # Rebuild a PyRRef on the receiving side from its fork data.
        return dist.rpc.PyRRef._deserialize(rref_fork_data)

    def _py_rref_reducer(self, py_rref):
        # Reduce a PyRRef to its fork data plus the matching receiver.
        rref_fork_data = py_rref._serialize()
        return (_InternalRPCPickler._py_rref_receiver, (rref_fork_data,))

    def _rref_reducer(self, rref):
        return self._py_rref_reducer(rref)

    @classmethod
    def _script_module_receiver(cls, script_module_serialized):
        """
        Given a serialized representation of a ScriptModule created with torch.jit.save,
        loads and returns the ScriptModule.
        """
        f = io.BytesIO(script_module_serialized)
        m = torch.jit.load(f)
        return m

    def _script_module_reducer(self, script_module):
        """
        Serializes a ScriptModule.
        """
        f = io.BytesIO()
        torch.jit.save(script_module, f)
        return (_InternalRPCPickler._script_module_receiver, (f.getvalue(),))

    def serialize(self, obj):
        r"""
        Serialize non tensor data into binary string, tensor data into
        tensor table.

        Returns:
            Tuple of (pickled binary string, list of tensors it references).
        """
        f = io.BytesIO()
        p = _pickler(f)
        p.dispatch_table = self._dispatch_table

        # rpc api could accept user picklers inheriting from _InternalRPCPickler to serialize rref,
        # user picklers could have different initialization function from _InternalRPCPickler,
        # but all the user picklers should call serialize() and use _rref_reducer to pickle rref
        # in python. also, when _internal_rpc_pickler is imported to rpc/api.py, rpc.RRef is not
        # compiled yet, it is not good place to access rpc.RRef inside _InternalRPCPickler constructor,
        # so putting rref's dispatch table here
        #
        # The return value of a `rpc.remote(..)` call is type of `rpc.PyRRef`.
        # The deserialized RRef object on an RPC receiver side is type of `rpc.PyRRef`.
        # Ignore type error because dispatch_table is defined in third-party package
        p.dispatch_table[dist.rpc.PyRRef] = self._py_rref_reducer  # type: ignore[index]
        # An RRef created locally by RRef Python constructor is type of `rpc.RRef`.
        # Ignore type error because dispatch_table is defined in third-party package
        p.dispatch_table[dist.rpc.RRef] = self._rref_reducer  # type: ignore[index]

        # Add dispatch pickling for ScriptModule or its subclass.
        if isinstance(obj, torch.jit.ScriptModule):
            # Ignore type error because dispatch_table is defined in third-party package
            p.dispatch_table[obj.__class__] = self._script_module_reducer  # type: ignore[index]

        # Install customized picklers.
        for obj_class, reducer in self._class_reducer_dict.items():
            p.dispatch_table[obj_class] = reducer  # type: ignore[index]

        # save _thread_local_tensor_tables.send_tables if it is in nested call
        global _thread_local_tensor_tables
        old_send_tables = getattr(_thread_local_tensor_tables, "send_tables", None)
        _thread_local_tensor_tables.send_tables = []

        try:
            p.dump(obj)
            tensors = _thread_local_tensor_tables.send_tables
        finally:
            # Restore _thread_local_tensor_tables.send_tables if returning
            # from a nested call, otherwise clean up the table. Doing this in
            # a `finally` ensures an exception in dump() cannot leave a stale
            # table behind and corrupt a subsequent serialization.
            if old_send_tables is not None:
                _thread_local_tensor_tables.send_tables = old_send_tables
            else:
                del _thread_local_tensor_tables.send_tables

        return (f.getvalue(), tensors)

    def deserialize(self, binary_data, tensor_table):
        r"""
        Deserialize binary string + tensor table to original obj
        """
        # save _thread_local_tensor_tables.recv_tables if it is in nested call
        global _thread_local_tensor_tables
        old_recv_tables = getattr(_thread_local_tensor_tables, "recv_tables", None)
        _thread_local_tensor_tables.recv_tables = tensor_table

        try:
            try:
                unpickler = _unpickler(io.BytesIO(binary_data))
                ret = unpickler.load()
            except AttributeError as e:
                # Occurs when function is not found on module/class during
                # unpickling.
                except_str = (
                    str(e)
                    + """ Default RPC pickler does not serialize
            function code. Ensure that UDFs are defined on both caller and
            callee modules."""
                )
                ret = AttributeError(except_str)
                # Ensure the stack trace gets preserved
                ret.__cause__ = e
        finally:
            # Restore _thread_local_tensor_tables.recv_tables if returning
            # from a nested call, otherwise clean up the table. Doing this in
            # a `finally` ensures any exception other than AttributeError from
            # load() cannot leave a stale receive table behind.
            if old_recv_tables is not None:
                _thread_local_tensor_tables.recv_tables = old_recv_tables
            else:
                del _thread_local_tensor_tables.recv_tables

        return ret
184
185
# Create _internal_rpc_pickler only once to initialize _dispatch_table only once
_internal_rpc_pickler = _InternalRPCPickler()
188
189
def serialize(obj):
    """Serialize ``obj`` with the shared module-level pickler.

    Returns a tuple of (binary string, tensor table) as produced by
    ``_InternalRPCPickler.serialize``.
    """
    return _internal_rpc_pickler.serialize(obj)
192
193
def deserialize(binary_data, tensor_table):
    """Deserialize ``binary_data`` + ``tensor_table`` back into the original
    object using the shared module-level pickler."""
    return _internal_rpc_pickler.deserialize(binary_data, tensor_table)
196
197
198def _run_function(python_udf):
199    r"""
200    This function is exclusively called from C++.
201    See ``torch/csrc/distributed/rpc/python_rpc_handler.cpp``.
202
203    Runs a Python UDF and returns its return value.
204    Wraps any exception in ``RemoteException`` if the function raises.
205    """
206    try:
207        if isinstance(python_udf, AttributeError):
208            raise python_udf
209        result = python_udf.func(*python_udf.args, **python_udf.kwargs)
210    except Exception as e:
211        # except str = exception info + traceback string
212        except_str = (
213            f"On {_get_current_rpc_agent().get_worker_info()}:\n"
214            f"{repr(e)}\n{traceback.format_exc()}"
215        )
216        print(except_str, file=sys.stderr)
217        result = RemoteException(except_str, type(e))
218    return result
219
220
def _handle_exception(result):
    """Re-raise locally an exception that a remote worker reported.

    No-op unless ``result`` is a ``RemoteException``; otherwise reconstructs
    the original exception type from the escaped message and raises it.
    """
    if not isinstance(result, RemoteException):
        return
    exception_msg = result.msg.encode("utf-8").decode("unicode_escape")
    # We wrap exception re-creation here in case some exception classes
    # cannot be constructed directly from a string.
    try:
        exc = result.exception_type(exception_msg)
    except BaseException as e:
        raise RuntimeError(  # noqa: B904
            f"Failed to create original exception type. Error msg was {str(e)}"
            f" Original exception on remote side was {exception_msg}"
        ) from e
    if exc is not None:
        raise exc
237
238
239def _build_rpc_profiling_key(
240    exec_type, func_name, current_worker_name, dst_worker_name
241):
242    """
243    Builds the key that RPC calls are profiled with using the autograd profiler.
244    This will be the name of the corresponding Event recorded in the profiler.
245
246    Args:
247        exec_type (RPCExecMode): Type of RPC/RRef call
248        func_name (str): Name of function being profiled.
249        current_worker_name (str): Name of current worker.
250        dst_worker_name (str): Name of the destination worker.
251
252    Returns:
253        String representing profiling key
254    """
255    profile_key = (
256        f"rpc_{exec_type.value}#{func_name}({current_worker_name} -> {dst_worker_name})"
257    )
258    return profile_key
259
260
261def _start_record_function(exec_type, func_name, current_worker_name, dest_worker_name):
262    """
263    This function should be called from RPC/RRef functions to create a
264    RecordFunction object for profiling. This function also runs the before
265    callbacks that start the profiling, though the user is responsible for
266    running the appropriate callbacks when the function to be profiled finishes.
267
268    Args:
269        exec_type (RPCExecMode): Type of RPC/RRef call
270        func_name (str): Name of function being profiled.
271        current_worker_name (str): Name of current worker.
272        dest_worker_name (str): Name of the destination worker.
273
274    Returns:
275        An instance of `torch.autograd._RecordFunction`.
276    """
277    assert torch.autograd._profiler_enabled(), "Autograd profiler should be enabled."
278    profile_key = f"rpc_{exec_type.value}#{str(func_name)}({current_worker_name} -> {dest_worker_name})"
279    rf = torch.autograd._RecordFunction()  # type: ignore[attr-defined]
280    torch.autograd._run_before_callbacks(rf, profile_key)  # type: ignore[attr-defined]
281    return rf
282
283
# (func, args, kwargs) triple describing a Python UDF for _run_function to call.
PythonUDF = collections.namedtuple("PythonUDF", ["func", "args", "kwargs"])
# Carries a remote failure back to the caller: ``msg`` is the formatted
# error + traceback string and ``exception_type`` the original exception class
# (see _run_function / _handle_exception).
RemoteException = collections.namedtuple("RemoteException", ["msg", "exception_type"])
286