xref: /aosp_15_r20/external/pytorch/torch/multiprocessing/queue.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport io
3*da0073e9SAndroid Build Coastguard Workerimport multiprocessing.queues
4*da0073e9SAndroid Build Coastguard Workerimport pickle
5*da0073e9SAndroid Build Coastguard Workerfrom multiprocessing.reduction import ForkingPickler
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker
class ConnectionWrapper:
    """Proxy class for _multiprocessing.Connection which uses ForkingPickler for object serialization.

    ``send``/``recv`` move whole Python objects as single byte messages. Any
    other attribute (``send_bytes``, ``recv_bytes``, ``poll``, ``close``, ...)
    is looked up on the wrapped connection ``conn``.
    """

    def __init__(self, conn):
        self.conn = conn

    def send(self, obj):
        """Serialize ``obj`` with ``ForkingPickler`` and send it as one message."""
        buf = io.BytesIO()
        # ForkingPickler applies the reducers registered with
        # multiprocessing.reduction on top of the normal pickle protocol.
        ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj)
        self.send_bytes(buf.getvalue())

    def recv(self):
        """Receive one message and return the unpickled object.

        NOTE: unpickling can execute arbitrary code, so the peer on the other
        end of the connection must be trusted.
        """
        buf = self.recv_bytes()
        return pickle.loads(buf)

    def __getattr__(self, name):
        """Forward lookups of attributes not found on the wrapper to ``conn``.

        Raises:
            AttributeError: if ``conn`` has not been set on this instance, or
                if the wrapped connection has no attribute ``name``.
        """
        # __getattr__ only runs after normal lookup fails. Read "conn" straight
        # from the instance dict: on an instance created without __init__
        # (e.g. during copy/unpickling), ``self.conn`` would re-enter
        # __getattr__ and recurse forever.
        try:
            conn = self.__dict__["conn"]
        except KeyError:
            # Report the attribute that was actually requested.
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            ) from None
        return getattr(conn, name)
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker
class Queue(multiprocessing.queues.Queue):
    """``multiprocessing.queues.Queue`` whose pipe ends are wrapped in ``ConnectionWrapper``.

    All constructor arguments are forwarded unchanged to the base class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Re-wrap the raw connections the base constructor just created.
        read_end = ConnectionWrapper(self._reader)
        write_end = ConnectionWrapper(self._writer)
        self._reader: ConnectionWrapper = read_end
        self._writer: ConnectionWrapper = write_end
        # Object-level send/receive that go through the wrapper's pickling.
        # NOTE(review): the Python 3 base Queue may not read _send/_recv (it
        # appears to use _send_bytes/_recv_bytes); kept as-is — verify.
        self._send = write_end.send
        self._recv = read_end.recv
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Worker
class SimpleQueue(multiprocessing.queues.SimpleQueue):
    """``multiprocessing.queues.SimpleQueue`` whose pipe ends get wrapped in ``ConnectionWrapper``."""

    def _make_methods(self):
        """Wrap ``_reader``/``_writer`` in ``ConnectionWrapper``, then run the base hook if any.

        The wrapping is skipped when ``_reader`` is already wrapped, so calling
        this more than once does not stack wrappers.
        """
        if not isinstance(self._reader, ConnectionWrapper):
            self._reader: ConnectionWrapper = ConnectionWrapper(self._reader)
            self._writer: ConnectionWrapper = ConnectionWrapper(self._writer)
        # The base SimpleQueue may not define _make_methods at all (mypy flags
        # the plain super() call as "undefined in superclass"). Defer to it
        # only when present instead of raising AttributeError after having
        # already rewrapped the connections.
        parent_make_methods = getattr(super(), "_make_methods", None)
        if parent_make_methods is not None:
            parent_make_methods()
44