# mypy: allow-untyped-defs
import multiprocessing
import os
import threading
from multiprocessing.reduction import ForkingPickler
from multiprocessing.util import register_after_fork
from typing import Union

import torch
from torch._namedtensor_internals import check_serializing_named_tensor


try:
    # Early load resource_sharer to prevent a partially initialized instance
    # from being inherited in a forked child process. The reduce_storage method
    # requires this module indirectly through DupFd(). The built-in mp.Queue
    # class pickles arguments in a background thread which may overlap with the
    # fork.
    import multiprocessing.resource_sharer
except ImportError:
    pass


class StorageWeakRef:
    r"""A weak reference to a Storage.

    The cdata member is a Python number containing the integer representation of
    the Storage pointer.
    """

    __slots__ = ["cdata", "_free_weak_ref"]

    def __init__(self, storage):
        self.cdata = storage._weak_ref()
        # Save a direct reference to _free_weak_ref because the `torch` module
        # might be cleared during Python shutdown before this module is cleared.
        self._free_weak_ref = torch.Storage._free_weak_ref  # type: ignore[attr-defined]

    @classmethod
    def from_weakref(cls, cdata):
        instance = cls.__new__(cls)
        instance.cdata = cdata
        instance._free_weak_ref = torch.Storage._free_weak_ref  # type: ignore[attr-defined]
        return instance

    def expired(self):
        return torch.Storage._expired(self.cdata)  # type: ignore[attr-defined]

    def __del__(self):
        self._free_weak_ref(self.cdata)

    def __hash__(self):
        return self.cdata

    def __eq__(self, other):
        if id(self) == id(other):
            return True
        return self.cdata == other.cdata

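
# Illustrative sketch (comments only, not executed): StorageWeakRef lets the
# shared_cache below observe a storage without keeping it alive. Variable names
# are hypothetical.
#
#   s = torch.UntypedStorage(16)
#   ref = StorageWeakRef(s)
#   ref.expired()   # False while `s` is alive
#   del s           # once the storage is freed, ref.expired() becomes True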

class SharedCache(dict):
    """Dictionary from multiprocessing handles to StorageWeakRef."""

    def __init__(self) -> None:
        # free_dead_references() is called if the cache size exceeds the current
        # limit. The limit scales with the number of remaining live objects.
        self.limit = 128
        # `fork` inherits lock state, so in case we fork when the lock is held,
        # we register a function to reset the lock to a new object to avoid
        # possible deadlocks, following the Python multiprocessing library design.
        self._after_fork()
        register_after_fork(self, SharedCache._after_fork)

    def _after_fork(self):
        self.lock = threading.Lock()

    def get(self, key):
        with self.lock:
            return dict.get(self, key)

    def __setitem__(self, key, storage_ref):
        with self.lock:
            dict.__setitem__(self, key, storage_ref)
            if len(self) > self.limit:
                self.free_dead_references()

    def free_dead_references(self):
        live = 0
        for key, storage_ref in list(self.items()):
            if storage_ref.expired():
                del self[key]
            else:
                live += 1
        self.limit = max(128, live * 2)


# mapping from handles to StorageWeakRef objects
shared_cache = SharedCache()
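# Illustrative sketch (comments only): producers record a weak reference under a
# key identifying the shared memory, and consumers look the key up before
# re-opening a handle. The key shown here is the one used by the
# file_descriptor sharing strategy.
#
#   shared_cache[fd_id(fd)] = StorageWeakRef(storage)   # after sharing a storage
#   cached = shared_cache.get(fd_id(fd))                # None, or a StorageWeakRef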


def rebuild_event(device, handle):
    return torch.cuda.Event.from_ipc_handle(device, handle)


def reduce_event(event):
    handle = event.ipc_handle()
    return (rebuild_event, (event.device, handle))


def rebuild_tensor(cls, storage, metadata):
    storage_offset, size, stride, requires_grad = metadata
    t = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
    if cls == torch.nn.parameter.Parameter:
        # We have to pass requires_grad into the constructor rather than set it
        # as an attribute later, because integer tensors must be constructed
        # with requires_grad=False (otherwise they raise an error).
        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
    else:
        t.requires_grad = requires_grad
    return t


def rebuild_meta_tensor(
    tensor_cls,
    tensor_size,
    tensor_stride,
    tensor_offset,
    dtype,
    storage_size_bytes,
    requires_grad,
):
    untyped_storage = torch.UntypedStorage(storage_size_bytes, device="meta")

    typed_storage = torch.TypedStorage(
        wrap_storage=untyped_storage, dtype=dtype, _internal=True
    )

    t = torch._utils._rebuild_tensor(
        typed_storage,
        tensor_offset,
        tensor_size,
        tensor_stride,
    )

    if tensor_cls == torch.nn.parameter.Parameter:
        # It is crucial for integer tensors to receive
        # requires_grad=False as a constructor argument.
        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
    else:
        t.requires_grad = requires_grad

    return t


def rebuild_cuda_tensor(
    tensor_cls,
    tensor_size,
    tensor_stride,
    tensor_offset,
    storage_cls,
    dtype,
    storage_device,
    storage_handle,
    storage_size_bytes,
    storage_offset_bytes,
    requires_grad,
    ref_counter_handle,
    ref_counter_offset,
    event_handle,
    event_sync_required,
):
    # If storage_handle is None, storage points to nullptr.
    if storage_handle is None or storage_size_bytes == 0:
        storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True)
    else:
        storage = storage_from_cache(
            storage_cls, (storage_handle, storage_offset_bytes)
        )
        if storage is None:
            torch.cuda._lazy_init()
            storage = storage_cls._new_shared_cuda(
                storage_device,
                storage_handle,
                storage_size_bytes,
                storage_offset_bytes,
                ref_counter_handle,
                ref_counter_offset,
                event_handle,
                event_sync_required,
            )
            shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(
                storage
            )
        else:
            # We are already ref-counting this Storage, but the producer needs the new ref-counter to be released.
            storage_cls._release_ipc_counter(
                ref_counter_handle, ref_counter_offset, device=storage_device
            )

    _storage = (
        storage
        if isinstance(storage, torch.UntypedStorage)
        else storage._untyped_storage
    )

    t = torch._utils._rebuild_tensor(
        torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True),
        tensor_offset,
        tensor_size,
        tensor_stride,
    )

    if tensor_cls == torch.nn.parameter.Parameter:
        # It is crucial for integer tensors to receive
        # requires_grad=False as a constructor argument.
        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
    else:
        t.requires_grad = requires_grad

    return t


def reduce_tensor(tensor):
    if tensor.requires_grad and not tensor.is_leaf:
        raise RuntimeError(
            "Cowardly refusing to serialize non-leaf tensor which requires_grad, "
            "since autograd does not support crossing process boundaries.  "
            "If you just want to transfer the data, call detach() on the tensor "
            "before serializing (e.g., putting it on the queue)."
        )

    check_serializing_named_tensor(tensor)
    torch.utils.hooks.warn_if_has_hooks(tensor)

    # Note [CUDA IPC and the caching allocator]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # When you send a CUDA tensor over IPC, you might expect that you will
    # get out the same storage from the other end.  However, the CUDA caching
    # allocator makes it difficult to preserve this invariant.  Consider
    # the following situation: a tensor of size 0x40 points to offset 0x20 of
    # a storage at 0xA100 of size 0x100.  (For simplicity, all of these
    # sizes are given in bytes).  HOWEVER, with the caching allocator, this storage
    # might be part of a larger cudaMalloc allocation 0xA000 of size 0x4000.
    #
    # When we want to send this CUDA tensor over IPC, we must send the
    # *entire* cudaMalloc allocation, i.e., the 0xA000 region, not just
    # the storage 0xA100 (because that is what CUDA supports).  So, on the
    # other end, there simply isn't any way to say, "Wait, you gave me
    # a bigger region (0xA000) than the one I wanted (0xA100)".
    #
    # OK, so if you sent the cudaMalloc allocation, can you just wrap that up as
    # one storage itself? No, because this cudaMalloc allocation might contain
    # storages of mixed types: float, bytes, double... If you make the entire
    # allocation a single storage of type A, we'll hit an error when constructing
    # a tensor of type B on the storage.
    #
    # cudaIpcMemHandle is an identifier used to access the sender's cudaMalloc allocation
    # on the receiver side. However, cudaIpcMemHandles from each device in a given process
    # may only be opened by one context per device per other process.
    # If we open and close a memory handle multiple times in a process, CUDA is allowed
    # to give it a different address; similarly, once we close the memory, we're not
    # allowed to access it (and the storage/tensor built on top of it), even if it is
    # still live in the original process. As we cannot map a cudaMalloc allocation
    # to a single storage in one go, this requires us to cache the device pointer for
    # each cudaIpcMemHandle on the C++ side so that storages of different types can be
    # reconstructed from it, while keeping the old ones alive.
    # See [https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html]
    #
    # This is fine, because all we need to do is save our position in the allocation
    # and reconstruct the storage and tensor from it.
    # 0xA000 ->  -------CUDA Allocation------
    #           |                            |
    #           |                            |
    #           |                            |
    #           |                            |
    # 0xA100 ->  --------storage1 begin------
    #           |                            |
    # 0xA120 ->  --------tensor1 begin ------
    #           |                            |
    #           |                            |
    #           |                            |
    #           |                            |
    #           |                            |
    # 0xA160 ->  --------tensor1 end---------
    #           |                            |
    #           |                            |
    #           |                            |
    # 0xA200 ->  --------storage1 end--------
    #           |                            |
    # 0xE000 ->  --------CUDA allocation-----
    #
    # To send tensor1, the following info is required from sender to receiver for
    # storage reconstruction:
    #   1. cudaIpcMemHandle of 0xA000 (which can be mapped to a basePtr in the
    #      receiver process). basePtr may not be exactly 0xA000 since it's a different process.
    #   2. offset (0x100) of storage1 within the CUDA allocation.
    #   3. size of storage1 (0x100).
    #
    # On the receiver side:
    #   1. Get the devPtr of the MemHandle to access the memory, reconstruct a storage
    #      of the same type using (basePtr, offset, size).
    #   2. We can reconstruct the tensor on top of the reconstructed storage:
    #   Tensor(size=0x040, offset=0x020, storage=Storage(data=basePtr+0x100, size=0x0100))
    #
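    # A compressed sketch of the receiver-side arithmetic above (illustrative
    # comments only; the real mapping happens in C++ via
    # storage_cls._new_shared_cuda below, and open_mem_handle is a hypothetical
    # stand-in for cudaIpcOpenMemHandle):
    #
    #   base_ptr = open_mem_handle(storage_handle)       # maps the whole 0x4000-byte allocation
    #   storage_data = base_ptr + storage_offset_bytes   # base_ptr + 0x100 -> storage1
    #   storage1 = Storage(data=storage_data, size=storage_size_bytes)  # 0x100 bytes
    #   tensor1 = Tensor(storage1, offset=0x020, size=0x040)            # view into storage1
    #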
    # This strategy has a few implications:
    #
    # 1. When we serialize a CUDA tensor for IPC, we cannot do it all in one
    #    go (non-compositionally), and this requires each process to maintain a
    #    global map memHandle -> devPtr.
    #
    # 2. We MUST NOT let the new IPC tensor be resizable.  Originally, a resize
    #    of the storage beyond 0x100 would merely have caused us to do a
    #    reallocation.  You don't really want to do this, but if you did,
    #    all that would happen is that you would lose IPC sharing.  But if
    #    you do this in the new world, we will happily let you write out of
    #    bounds of your "allocation", clobbering unrelated data in the cached
    #    allocator block.  BAD!
    #
    # By the way, in old versions of PyTorch, we supported this situation
    # natively using a "storage view", which permitted multiple storages to be
    # views on each other.  But this was the *only* use of storage views, so we
    # eliminated it so that we could just use tensor views to implement the same
    # thing.
    #

    # TODO: Handle distinguishing between subclass and non-subclass versions of NT better
    # https://github.com/pytorch/pytorch/issues/110543
    from torch.nested._internal.nested_tensor import NestedTensor

    if tensor.is_nested and not isinstance(tensor, NestedTensor):
        return reduce_nested_tensor(tensor)

    if tensor.layout in {
        torch.sparse_coo,
        torch.sparse_csr,
        torch.sparse_bsr,
        torch.sparse_csc,
        torch.sparse_bsc,
    }:
        return reduce_sparse_tensor(tensor)

    storage = tensor._typed_storage()

    if storage._untyped_storage.device.type == "cuda":
        (
            device,
            handle,
            storage_size_bytes,
            storage_offset_bytes,
            ref_counter_handle,
            ref_counter_offset,
            event_handle,
            event_sync_required,
        ) = storage._share_cuda_()
        tensor_offset = tensor.storage_offset()
        shared_cache[handle] = StorageWeakRef(storage)
        # _backward_hooks purposely omitted here, see
        # Note [Don't serialize hooks]
        return (
            rebuild_cuda_tensor,
            (
                type(tensor),
                tensor.size(),
                tensor.stride(),
                tensor_offset,  # tensor offset in its storage
                type(storage),
                tensor.dtype,
                device,
                handle,  # identifies which CUDA allocation the storage is in
                storage_size_bytes,  # size (in bytes) of the storage
                storage_offset_bytes,  # offset (in bytes) of the storage in the CUDA allocation
                tensor.requires_grad,
                ref_counter_handle,
                ref_counter_offset,
                event_handle,
                event_sync_required,
            ),
        )
    elif storage._untyped_storage.device.type == "meta":
        return (
            rebuild_meta_tensor,
            (
                type(tensor),
                tensor.size(),
                tensor.stride(),
                tensor.storage_offset(),
                tensor.dtype,
                tensor.untyped_storage().size(),
                tensor.requires_grad,
            ),
        )

    # _backward_hooks purposely omitted here, see Note [Don't serialize hooks]
    metadata = (
        tensor.storage_offset(),
        tensor.size(),
        tensor.stride(),
        tensor.requires_grad,
    )
    return (rebuild_tensor, (type(tensor), storage, metadata))

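
# Illustrative sketch (comments only): reduce_tensor is normally invoked by
# ForkingPickler when a tensor is put on a torch.multiprocessing queue, but the
# reduce/rebuild pair can also be exercised directly for a CPU tensor. Variable
# names are hypothetical.
#
#   t = torch.arange(4.0)
#   rebuild_fn, args = reduce_tensor(t)   # (rebuild_tensor, (type, storage, metadata))
#   t2 = rebuild_fn(*args)                # a tensor viewing the same storage as t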

def rebuild_nested_tensor(
    rebuild_buffer_func,
    rebuild_buffer_args,
    rebuild_sizes_func,
    rebuild_sizes_args,
    rebuild_strides_func,
    rebuild_strides_args,
    rebuild_offsets_func,
    rebuild_offsets_args,
):
    buffer = rebuild_buffer_func(*rebuild_buffer_args)
    sizes = rebuild_sizes_func(*rebuild_sizes_args)
    strides = rebuild_strides_func(*rebuild_strides_args)
    offsets = rebuild_offsets_func(*rebuild_offsets_args)
    return torch._nested_view_from_buffer_copy(buffer, sizes, strides, offsets)


def reduce_nested_tensor(nt):
    rebuild_buffer_func, rebuild_buffer_args = reduce_tensor(nt.values())
    rebuild_sizes_func, rebuild_sizes_args = reduce_tensor(nt._nested_tensor_size())
    rebuild_strides_func, rebuild_strides_args = reduce_tensor(
        nt._nested_tensor_strides()
    )
    rebuild_offsets_func, rebuild_offsets_args = reduce_tensor(
        nt._nested_tensor_storage_offsets()
    )

    return (
        rebuild_nested_tensor,
        (
            rebuild_buffer_func,
            rebuild_buffer_args,
            rebuild_sizes_func,
            rebuild_sizes_args,
            rebuild_strides_func,
            rebuild_strides_args,
            rebuild_offsets_func,
            rebuild_offsets_args,
        ),
    )


def rebuild_sparse_coo_tensor(
    rebuild_indices_func,
    rebuild_indices_args,
    rebuild_values_func,
    rebuild_values_args,
    shape,
    is_coalesced,
):
    indices = rebuild_indices_func(*rebuild_indices_args)
    values = rebuild_values_func(*rebuild_values_args)
    return torch.sparse_coo_tensor(indices, values, shape, is_coalesced=is_coalesced)


def rebuild_sparse_compressed_tensor(
    rebuild_compressed_indices_func,
    rebuild_compressed_indices_args,
    rebuild_plain_indices_func,
    rebuild_plain_indices_args,
    rebuild_values_func,
    rebuild_values_args,
    shape,
    layout,
):
    compressed_indices = rebuild_compressed_indices_func(
        *rebuild_compressed_indices_args
    )
    plain_indices = rebuild_plain_indices_func(*rebuild_plain_indices_args)
    values = rebuild_values_func(*rebuild_values_args)
    return torch.sparse_compressed_tensor(
        compressed_indices, plain_indices, values, shape, layout=layout
    )


def reduce_sparse_tensor(sparse):
    if sparse.layout is torch.sparse_coo:
        rebuild_indices_func, rebuild_indices_args = reduce_tensor(sparse._indices())
        rebuild_values_func, rebuild_values_args = reduce_tensor(sparse._values())
        return (
            rebuild_sparse_coo_tensor,
            (
                rebuild_indices_func,
                rebuild_indices_args,
                rebuild_values_func,
                rebuild_values_args,
                sparse.shape,
                sparse.is_coalesced(),
            ),
        )
    else:
        if sparse.layout in {torch.sparse_csr, torch.sparse_bsr}:
            compressed_indices = sparse.crow_indices()
            plain_indices = sparse.col_indices()
        elif sparse.layout in {torch.sparse_csc, torch.sparse_bsc}:
            compressed_indices = sparse.ccol_indices()
            plain_indices = sparse.row_indices()
        else:
            raise NotImplementedError(sparse.layout)
        (
            rebuild_compressed_indices_func,
            rebuild_compressed_indices_args,
        ) = reduce_tensor(compressed_indices)
        rebuild_plain_indices_func, rebuild_plain_indices_args = reduce_tensor(
            plain_indices
        )
        rebuild_values_func, rebuild_values_args = reduce_tensor(sparse.values())
        return (
            rebuild_sparse_compressed_tensor,
            (
                rebuild_compressed_indices_func,
                rebuild_compressed_indices_args,
                rebuild_plain_indices_func,
                rebuild_plain_indices_args,
                rebuild_values_func,
                rebuild_values_args,
                sparse.shape,
                sparse.layout,
            ),
        )

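
# Illustrative sketch (comments only): a sparse COO tensor is decomposed into an
# indices tensor and a values tensor, each reduced independently, and is put
# back together by rebuild_sparse_coo_tensor. Variable names are hypothetical.
#
#   s = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 2.0], (2, 2)).coalesce()
#   rebuild_fn, args = reduce_sparse_tensor(s)   # rebuild_fn is rebuild_sparse_coo_tensor
#   s2 = rebuild_fn(*args)                       # an equivalent sparse tensor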

def fd_id(fd):
    # Returns a tuple which uniquely identifies a file descriptor. On macOS,
    # this doesn't work with shared memory handles, which is why we don't
    # support the "file_descriptor" sharing method on that platform.
    stat = os.fstat(fd)
    return (stat.st_ino, stat.st_dev)


def storage_from_cache(cls, key):
    storage_ref = shared_cache.get(key)
    if storage_ref is None:
        return None
    return torch.UntypedStorage._new_with_weak_ptr(storage_ref.cdata)


def rebuild_storage_fd(cls, df, size):
    fd = df.detach()
    try:
        storage = storage_from_cache(cls, fd_id(fd))
        if storage is not None:
            return storage
        storage = cls._new_shared_fd_cpu(fd, size)
        shared_cache[fd_id(fd)] = StorageWeakRef(storage)
        return storage
    finally:
        os.close(fd)


def rebuild_storage_filename(cls, manager, handle, size, dtype=None):
    storage: Union[torch.TypedStorage, torch.UntypedStorage] = storage_from_cache(
        cls, handle
    )
    if storage is not None:
        return storage._shared_decref()
    if dtype is None:
        storage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size)
    else:
        byte_size = size * torch._utils._element_size(dtype)
        untyped_storage: torch.UntypedStorage = (
            torch.UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size)
        )
        storage = torch.TypedStorage(
            wrap_storage=untyped_storage, dtype=dtype, _internal=True
        )
    shared_cache[handle] = StorageWeakRef(storage)
    return storage._shared_decref()


def rebuild_storage_empty(cls):
    return cls()


def rebuild_typed_storage(storage, dtype):
    return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype, _internal=True)


# Use for torch.storage.TypedStorage
def reduce_typed_storage(storage):
    return (rebuild_typed_storage, (storage._untyped_storage, storage.dtype))


def rebuild_typed_storage_child(storage, storage_type):
    return storage_type(wrap_storage=storage, _internal=True)


# Use for child classes of torch.storage.TypedStorage, like torch.FloatStorage
def reduce_typed_storage_child(storage):
    return (rebuild_typed_storage_child, (storage._untyped_storage, type(storage)))


def reduce_storage(storage):
    from . import get_sharing_strategy

    if storage.is_cuda:
        raise RuntimeError(
            "Cannot pickle CUDA storage; try pickling a CUDA tensor instead"
        )
    elif storage.device.type == "meta":
        raise RuntimeError(
            "Cannot pickle meta storage; try pickling a meta tensor instead"
        )
    elif get_sharing_strategy() == "file_system":
        metadata = storage._share_filename_cpu_()
        cache_key = metadata[1]
        rebuild = rebuild_storage_filename
        if isinstance(storage, torch.TypedStorage):
            metadata += (storage.dtype,)
        storage._shared_incref()
    elif storage.size() == 0:
        # This is special cased because empty storages
        # (with size 0) cannot be mmapped.
        return (rebuild_storage_empty, (type(storage),))
    else:
        fd, size = storage._share_fd_cpu_()
        df = multiprocessing.reduction.DupFd(fd)
        cache_key = fd_id(fd)
        metadata = (df, size)
        rebuild = rebuild_storage_fd  # type: ignore[assignment]

    shared_cache[cache_key] = StorageWeakRef(storage)
    return (rebuild, (type(storage),) + metadata)

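
# Illustrative sketch (comments only): which branch reduce_storage takes for CPU
# storages is controlled by the process-wide sharing strategy:
#
#   torch.multiprocessing.set_sharing_strategy("file_system")      # -> rebuild_storage_filename
#   torch.multiprocessing.set_sharing_strategy("file_descriptor")  # -> rebuild_storage_fd
#
# "file_descriptor" is the default on Linux; macOS only supports "file_system"
# (see fd_id above).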

def init_reductions():
    ForkingPickler.register(torch.cuda.Event, reduce_event)

    for t in torch._storage_classes:
        if t.__name__ == "UntypedStorage":
            ForkingPickler.register(t, reduce_storage)
        else:
            ForkingPickler.register(t, reduce_typed_storage_child)

    ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage)

    for t in torch._tensor_classes:
        ForkingPickler.register(t, reduce_tensor)

    # TODO: Maybe this should be in tensor_classes? :)
    ForkingPickler.register(torch.Tensor, reduce_tensor)

    from torch.nn.parameter import Parameter

    ForkingPickler.register(Parameter, reduce_tensor)
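

# Illustrative sketch (comments only): torch.multiprocessing calls
# init_reductions() at import time, so tensors placed on its queues are shared
# through the reducers above rather than copied.
#
#   import torch.multiprocessing as mp
#
#   def worker(q):
#       t = q.get()    # rebuilt via rebuild_tensor / rebuild_cuda_tensor
#       t.add_(1)      # mutates the shared storage
#
#   if __name__ == "__main__":
#       q = mp.Queue()
#       p = mp.Process(target=worker, args=(q,))
#       p.start()
#       q.put(torch.zeros(3))   # reduced via reduce_tensor
#       p.join()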