# torch/multiprocessing/reductions.py
# mypy: allow-untyped-defs
import multiprocessing
import os
import threading
from multiprocessing.reduction import ForkingPickler
from multiprocessing.util import register_after_fork
from typing import Union

import torch
from torch._namedtensor_internals import check_serializing_named_tensor


try:
    # Early load resource_sharer to prevent a partially initialized instance
    # from being inherited in a forked child process. The reduce_storage method
    # requires this module indirectly through DupFd(). The built-in mp.Queue
    # class pickles arguments in a background thread which may overlap with the
    # fork.
    import multiprocessing.resource_sharer
except ImportError:
    pass


class StorageWeakRef:
    r"""A weak reference to a Storage.

    The cdata member is a Python number containing the integer representation of
    the Storage pointer.
    """

    __slots__ = ["cdata", "_free_weak_ref"]

    def __init__(self, storage):
        self.cdata = storage._weak_ref()
        # Save a direct reference to _free_weak_ref because the `torch` module
        # might be cleared during Python shutdown before this module is cleared.
        self._free_weak_ref = torch.Storage._free_weak_ref  # type: ignore[attr-defined]

    @classmethod
    def from_weakref(cls, cdata):
        instance = cls.__new__(cls)
        instance.cdata = cdata
        instance._free_weak_ref = torch.Storage._free_weak_ref  # type: ignore[attr-defined]
        return instance

    def expired(self):
        return torch.Storage._expired(self.cdata)  # type: ignore[attr-defined]

    def __del__(self):
        self._free_weak_ref(self.cdata)

    def __hash__(self):
        return self.cdata

    def __eq__(self, other):
        if id(self) == id(other):
            return True
        return self.cdata == other.cdata

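# Illustrative use of StorageWeakRef above (a sketch, not executed by this
# module): the weak ref does not keep the storage alive, which is what lets
# the shared_cache below drop entries once their storage has been freed.
#
#     t = torch.empty(4)
#     ref = StorageWeakRef(t._typed_storage())
#     ref.expired()   # False while `t` keeps the storage alive
#     del t
#     ref.expired()   # expected to be True once the storage is freed
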

class SharedCache(dict):
    """Dictionary from multiprocessing handles to StorageWeakRef."""

    def __init__(self) -> None:
        # free_dead_references() is called if the number of entries exceeds the
        # current limit. The limit scales with the number of remaining live objects.
        self.limit = 128
        # `fork` inherits lock state, so in case we fork when the lock is held,
        # we register a function to reset the lock to a new object to avoid
        # possible deadlocks, following the design of the Python multiprocessing
        # library.
        self._after_fork()
        register_after_fork(self, SharedCache._after_fork)

    def _after_fork(self):
        self.lock = threading.Lock()

    def get(self, key):
        with self.lock:
            return dict.get(self, key)

    def __setitem__(self, key, storage_ref):
        with self.lock:
            dict.__setitem__(self, key, storage_ref)
            if len(self) > self.limit:
                self.free_dead_references()

    def free_dead_references(self):
        live = 0
        for key, storage_ref in list(self.items()):
            if storage_ref.expired():
                del self[key]
            else:
                live += 1
        self.limit = max(128, live * 2)


# mapping from handles to StorageWeakRef objects
shared_cache = SharedCache()
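# Keys in shared_cache are whatever handle identifies the shared resource: an
# (st_ino, st_dev) pair from fd_id() for the "file_descriptor" strategy, a
# shared-memory filename for the "file_system" strategy, or a CUDA IPC handle
# (possibly paired with a byte offset) for CUDA storages. A sketch of the
# lookup pattern used by the rebuild_* functions below:
#
#     cached = storage_from_cache(torch.UntypedStorage, key)
#     if cached is None:
#         # map the shared memory, then remember it for later rebuilds
#         shared_cache[key] = StorageWeakRef(storage)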


def rebuild_event(device, handle):
    return torch.cuda.Event.from_ipc_handle(device, handle)


def reduce_event(event):
    handle = event.ipc_handle()
    return (rebuild_event, (event.device, handle))

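# Each reduce_* function in this file returns a (rebuild_fn, args) pair, which
# is the protocol ForkingPickler expects: the receiving process reconstructs
# the object by calling rebuild_fn(*args). A sketch for events (assuming a
# CUDA-capable build):
#
#     evt = torch.cuda.Event(interprocess=True)
#     fn, args = reduce_event(evt)     # -> (rebuild_event, (device, ipc_handle))
#     # the receiving process calls fn(*args), i.e. rebuild_event(device, handle),
#     # to reopen the event from its IPC handle
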

def rebuild_tensor(cls, storage, metadata):
    storage_offset, size, stride, requires_grad = metadata
    t = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
    if cls == torch.nn.parameter.Parameter:
        # We have to pass requires_grad into the constructor, rather than set it
        # as an attribute later, because integer tensors must have
        # requires_grad=False (otherwise the constructor raises an error)
        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
    else:
        t.requires_grad = requires_grad
    return t


def rebuild_meta_tensor(
    tensor_cls,
    tensor_size,
    tensor_stride,
    tensor_offset,
    dtype,
    storage_size_bytes,
    requires_grad,
):
    untyped_storage = torch.UntypedStorage(storage_size_bytes, device="meta")

    typed_storage = torch.TypedStorage(
        wrap_storage=untyped_storage, dtype=dtype, _internal=True
    )

    t = torch._utils._rebuild_tensor(
        typed_storage,
        tensor_offset,
        tensor_size,
        tensor_stride,
    )

    if tensor_cls == torch.nn.parameter.Parameter:
        # It is crucial for integer tensors to receive
        # requires_grad=False as an argument in the constructor
        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
    else:
        t.requires_grad = requires_grad

    return t


def rebuild_cuda_tensor(
    tensor_cls,
    tensor_size,
    tensor_stride,
    tensor_offset,
    storage_cls,
    dtype,
    storage_device,
    storage_handle,
    storage_size_bytes,
    storage_offset_bytes,
    requires_grad,
    ref_counter_handle,
    ref_counter_offset,
    event_handle,
    event_sync_required,
):
    # If storage_handle is None, storage points to nullptr.
    if storage_handle is None or storage_size_bytes == 0:
        storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True)
    else:
        storage = storage_from_cache(
            storage_cls, (storage_handle, storage_offset_bytes)
        )
        if storage is None:
            torch.cuda._lazy_init()
            storage = storage_cls._new_shared_cuda(
                storage_device,
                storage_handle,
                storage_size_bytes,
                storage_offset_bytes,
                ref_counter_handle,
                ref_counter_offset,
                event_handle,
                event_sync_required,
            )
            shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(
                storage
            )
        else:
            # We are already ref-counting this Storage, but the producer needs the new ref-counters to be released.
            storage_cls._release_ipc_counter(
                ref_counter_handle, ref_counter_offset, device=storage_device
            )

    _storage = (
        storage
        if isinstance(storage, torch.UntypedStorage)
        else storage._untyped_storage
    )

    t = torch._utils._rebuild_tensor(
        torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True),
        tensor_offset,
        tensor_size,
        tensor_stride,
    )

    if tensor_cls == torch.nn.parameter.Parameter:
        # It is crucial for integer tensors to receive
        # requires_grad=False as an argument in the constructor
        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
    else:
        t.requires_grad = requires_grad

    return t


def reduce_tensor(tensor):
    if tensor.requires_grad and not tensor.is_leaf:
        raise RuntimeError(
            "Cowardly refusing to serialize non-leaf tensor which requires_grad, "
            "since autograd does not support crossing process boundaries.  "
            "If you just want to transfer the data, call detach() on the tensor "
            "before serializing (e.g., putting it on the queue)."
        )

    check_serializing_named_tensor(tensor)
    torch.utils.hooks.warn_if_has_hooks(tensor)

    # Note [CUDA IPC and the caching allocator]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # When you send a CUDA tensor over IPC, you might expect that you will
    # get out the same storage from the other end.  However, the CUDA caching
    # allocator makes it difficult to preserve this invariant.  Consider
    # the following situation: a tensor of size 0x100 points to offset 0x20 of
    # a storage at 0xA100 of size 0x100.  (For simplicity, all of these
    # sizes are given in bytes).  HOWEVER, with the caching allocator, this storage
    # might be part of a larger cudaMalloc allocation 0xA000 of size 0x4000.
    #
    # When we want to send this CUDA tensor over IPC, we must send the
    # *entire* cudaMalloc allocation, i.e., the 0xA000 region, not just
    # the storage 0xA100 (because that is what CUDA supports).  So, on the
    # other end, there simply isn't any way to say, "Wait, you gave me
    # a bigger region (0xA000) than the one I wanted (0xA100)".
    #
    # OK, so if you sent the cudaMalloc allocation, can you just wrap that up as
    # one storage itself? No, because this cudaMalloc allocation might contain
    # storages of mixed types: float, bytes, double... If you make the entire
    # allocation a single storage of a type A, we'll hit an error when constructing
    # a tensor of type B on the storage.
    #
    # cudaIpcMemHandle is an identifier used to access the sender's cudaMalloc
    # allocation on the receiver side. However, cudaIpcMemHandles from each device
    # in a given process may only be opened by one context per device per other process.
    # If we open and close a memory handle multiple times in a process, CUDA is allowed
    # to give it a different address; similarly, once we close the memory, we're not
    # allowed to access it (and the storage/tensor built on top of it), even if it is
    # still live in the original process. Since we cannot map a cudaMalloc allocation
    # to a single storage in one go, we cache the device pointer for each
    # cudaIpcMemHandle on the C++ side so that storages of different types can be
    # reconstructed from it, while keeping the old ones alive.
    # See [https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html]
    #
    # This is fine, because all we need to do is save our position in the allocation,
    # and reconstruct storage and tensor from it.
    # 0xA000 ->  -------CUDA Allocation------
    #           |                            |
    #           |                            |
    #           |                            |
    #           |                            |
    # 0xA100 ->  --------storage1 begin------
    #           |                            |
    # 0xA120 ->  --------tensor1 begin ------
    #           |                            |
    #           |                            |
    #           |                            |
    #           |                            |
    #           |                            |
    # 0xA160 ->  --------tensor1 end---------
    #           |                            |
    #           |                            |
    #           |                            |
    # 0xA200 ->  --------storage1 end--------
    #           |                            |
    # 0xE000 ->  --------CUDA allocation-----
    #
    # To send tensor1, the following info is required from sender to receiver for
    # storage reconstruction.
    #   1. cudaIpcMemHandle of 0xA000 (which can be mapped to a basePtr in the
    #      receiver process). basePtr may not be exactly 0xA000 since it's a
    #      different process.
    #   2. offset (0xA100) of storage1 in the CUDA allocation.
    #   3. size of storage1 (0x100).
    #
    # On the receiver side:
    #   1. Get the devPtr of the MemHandle to access the memory, reconstruct a storage
    #      of the same type using (basePtr, offset, size).
    #   2. Reconstruct the tensor on top of the reconstructed storage:
    #   Tensor(size=0x040, offset=0x020, storage=Storage(data=basePtr+0xA100, size=0x0100))
    #
    # This strategy has a few implications:
    #
    # 1. When we serialize a CUDA tensor for IPC, we cannot do it all in one
    #    go (non-compositionally), and this requires us to maintain a global map
    #    memHandle -> devPtr for each process.
    #
    # 2. We MUST NOT let the new IPC tensor be resizable.  Originally, a resize
    #    of the storage beyond 0x100 would merely have caused us to do a
    #    reallocation.  You don't really want to do this, but if you did,
    #    all that would happen is that you would lose IPC sharing.  But if
    #    you do this in the new world, we will happily let you write out of
    #    bounds of your "allocation", clobbering unrelated data in the cached
    #    allocator block.  BAD!
    #
    # By the way, in old versions of PyTorch, we supported this situation
    # natively using a "storage view", which permitted multiple storages to be
    # views on each other.  But this was the *only* use of storage views, so we
    # eliminated it so that we could just use tensor views to implement the same
    # thing.
    #

    # TODO: Handle distinguishing between subclass and non-subclass versions of NT better
    # https://github.com/pytorch/pytorch/issues/110543
    from torch.nested._internal.nested_tensor import NestedTensor

    if tensor.is_nested and not isinstance(tensor, NestedTensor):
        return reduce_nested_tensor(tensor)

    if tensor.layout in {
        torch.sparse_coo,
        torch.sparse_csr,
        torch.sparse_bsr,
        torch.sparse_csc,
        torch.sparse_bsc,
    }:
        return reduce_sparse_tensor(tensor)

    storage = tensor._typed_storage()

    if storage._untyped_storage.device.type == "cuda":
        (
            device,
            handle,
            storage_size_bytes,
            storage_offset_bytes,
            ref_counter_handle,
            ref_counter_offset,
            event_handle,
            event_sync_required,
        ) = storage._share_cuda_()
        tensor_offset = tensor.storage_offset()
        shared_cache[handle] = StorageWeakRef(storage)
        # _backward_hooks purposely omitted here, see
        # Note [Don't serialize hooks]
        return (
            rebuild_cuda_tensor,
            (
                type(tensor),
                tensor.size(),
                tensor.stride(),
                tensor_offset,  # tensor offset in its storage
                type(storage),
                tensor.dtype,
                device,
                handle,  # identifies which CUDA allocation the storage is in
                storage_size_bytes,  # size (in bytes) of the storage
                storage_offset_bytes,  # offset (in bytes) of the storage in the CUDA allocation
                tensor.requires_grad,
                ref_counter_handle,
                ref_counter_offset,
                event_handle,
                event_sync_required,
            ),
        )
    elif storage._untyped_storage.device.type == "meta":
        return (
            rebuild_meta_tensor,
            (
                type(tensor),
                tensor.size(),
                tensor.stride(),
                tensor.storage_offset(),
                tensor.dtype,
                tensor.untyped_storage().size(),
                tensor.requires_grad,
            ),
        )

    # _backward_hooks purposely omitted here, see Note [Don't serialize hooks]
    metadata = (
        tensor.storage_offset(),
        tensor.size(),
        tensor.stride(),
        tensor.requires_grad,
    )
    return (rebuild_tensor, (type(tensor), storage, metadata))

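# Round-trip sketch for a plain CPU tensor (illustrative; in practice
# ForkingPickler drives these functions when a tensor is sent through a
# torch.multiprocessing queue, and it also reduces the storage in `args`):
#
#     x = torch.arange(6.0).reshape(2, 3)
#     fn, args = reduce_tensor(x)      # -> (rebuild_tensor, (cls, storage, metadata))
#     y = fn(*args)                    # shares x's storage when rebuilt in-process
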

def rebuild_nested_tensor(
    rebuild_buffer_func,
    rebuild_buffer_args,
    rebuild_sizes_func,
    rebuild_sizes_args,
    rebuild_strides_func,
    rebuild_strides_args,
    rebuild_offsets_func,
    rebuild_offsets_args,
):
    buffer = rebuild_buffer_func(*rebuild_buffer_args)
    sizes = rebuild_sizes_func(*rebuild_sizes_args)
    strides = rebuild_strides_func(*rebuild_strides_args)
    offsets = rebuild_offsets_func(*rebuild_offsets_args)
    return torch._nested_view_from_buffer_copy(buffer, sizes, strides, offsets)


def reduce_nested_tensor(nt):
    rebuild_buffer_func, rebuild_buffer_args = reduce_tensor(nt.values())
    rebuild_sizes_func, rebuild_sizes_args = reduce_tensor(nt._nested_tensor_size())
    rebuild_strides_func, rebuild_strides_args = reduce_tensor(
        nt._nested_tensor_strides()
    )
    rebuild_offsets_func, rebuild_offsets_args = reduce_tensor(
        nt._nested_tensor_storage_offsets()
    )

    return (
        rebuild_nested_tensor,
        (
            rebuild_buffer_func,
            rebuild_buffer_args,
            rebuild_sizes_func,
            rebuild_sizes_args,
            rebuild_strides_func,
            rebuild_strides_args,
            rebuild_offsets_func,
            rebuild_offsets_args,
        ),
    )


def rebuild_sparse_coo_tensor(
    rebuild_indices_func,
    rebuild_indices_args,
    rebuild_values_func,
    rebuild_values_args,
    shape,
    is_coalesced,
):
    indices = rebuild_indices_func(*rebuild_indices_args)
    values = rebuild_values_func(*rebuild_values_args)
    return torch.sparse_coo_tensor(indices, values, shape, is_coalesced=is_coalesced)


def rebuild_sparse_compressed_tensor(
    rebuild_compressed_indices_func,
    rebuild_compressed_indices_args,
    rebuild_plain_indices_func,
    rebuild_plain_indices_args,
    rebuild_values_func,
    rebuild_values_args,
    shape,
    layout,
):
    compressed_indices = rebuild_compressed_indices_func(
        *rebuild_compressed_indices_args
    )
    plain_indices = rebuild_plain_indices_func(*rebuild_plain_indices_args)
    values = rebuild_values_func(*rebuild_values_args)
    return torch.sparse_compressed_tensor(
        compressed_indices, plain_indices, values, shape, layout=layout
    )


def reduce_sparse_tensor(sparse):
    if sparse.layout is torch.sparse_coo:
        rebuild_indices_func, rebuild_indices_args = reduce_tensor(sparse._indices())
        rebuild_values_func, rebuild_values_args = reduce_tensor(sparse._values())
        return (
            rebuild_sparse_coo_tensor,
            (
                rebuild_indices_func,
                rebuild_indices_args,
                rebuild_values_func,
                rebuild_values_args,
                sparse.shape,
                sparse.is_coalesced(),
            ),
        )
    else:
        if sparse.layout in {torch.sparse_csr, torch.sparse_bsr}:
            compressed_indices = sparse.crow_indices()
            plain_indices = sparse.col_indices()
        elif sparse.layout in {torch.sparse_csc, torch.sparse_bsc}:
            compressed_indices = sparse.ccol_indices()
            plain_indices = sparse.row_indices()
        else:
            raise NotImplementedError(sparse.layout)
        (
            rebuild_compressed_indices_func,
            rebuild_compressed_indices_args,
        ) = reduce_tensor(compressed_indices)
        rebuild_plain_indices_func, rebuild_plain_indices_args = reduce_tensor(
            plain_indices
        )
        rebuild_values_func, rebuild_values_args = reduce_tensor(sparse.values())
        return (
            rebuild_sparse_compressed_tensor,
            (
                rebuild_compressed_indices_func,
                rebuild_compressed_indices_args,
                rebuild_plain_indices_func,
                rebuild_plain_indices_args,
                rebuild_values_func,
                rebuild_values_args,
                sparse.shape,
                sparse.layout,
            ),
        )

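# Both the nested- and sparse-tensor paths above follow the same recipe: reduce
# each constituent dense tensor with reduce_tensor() and record which rebuild
# function stitches the pieces back together. A sketch for a COO tensor:
#
#     sp = torch.eye(3).to_sparse().coalesce()
#     fn, args = reduce_sparse_tensor(sp)   # -> (rebuild_sparse_coo_tensor, (...))
#     sp2 = fn(*args)                       # rebuilt from the reduced indices/values
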

def fd_id(fd):
    # Returns a tuple which uniquely identifies a file descriptor. On macOS,
    # this doesn't work with shared memory handles, which is why we don't
    # support the "file_descriptor" sharing method on that platform.
    stat = os.fstat(fd)
    return (stat.st_ino, stat.st_dev)

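# fd_id() sketch: duplicated descriptors that refer to the same open file map
# to the same (inode, device) pair, so they hit the same shared_cache entry
# (the path below is purely illustrative):
#
#     fd1 = os.open("/tmp/example", os.O_CREAT | os.O_RDWR)
#     fd2 = os.dup(fd1)
#     fd_id(fd1) == fd_id(fd2)   # True
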

def storage_from_cache(cls, key):
    storage_ref = shared_cache.get(key)
    if storage_ref is None:
        return None
    return torch.UntypedStorage._new_with_weak_ptr(storage_ref.cdata)


def rebuild_storage_fd(cls, df, size):
    fd = df.detach()
    try:
        storage = storage_from_cache(cls, fd_id(fd))
        if storage is not None:
            return storage
        storage = cls._new_shared_fd_cpu(fd, size)
        shared_cache[fd_id(fd)] = StorageWeakRef(storage)
        return storage
    finally:
        os.close(fd)


def rebuild_storage_filename(cls, manager, handle, size, dtype=None):
    storage: Union[torch.TypedStorage, torch.UntypedStorage] = storage_from_cache(
        cls, handle
    )
    if storage is not None:
        return storage._shared_decref()
    if dtype is None:
        storage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size)
    else:
        byte_size = size * torch._utils._element_size(dtype)
        untyped_storage: torch.UntypedStorage = (
            torch.UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size)
        )
        storage = torch.TypedStorage(
            wrap_storage=untyped_storage, dtype=dtype, _internal=True
        )
    shared_cache[handle] = StorageWeakRef(storage)
    return storage._shared_decref()


def rebuild_storage_empty(cls):
    return cls()


def rebuild_typed_storage(storage, dtype):
    return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype, _internal=True)


# Used for torch.storage.TypedStorage
def reduce_typed_storage(storage):
    return (rebuild_typed_storage, (storage._untyped_storage, storage.dtype))


def rebuild_typed_storage_child(storage, storage_type):
    return storage_type(wrap_storage=storage, _internal=True)


# Used for child classes of torch.storage.TypedStorage, like torch.FloatStorage
def reduce_typed_storage_child(storage):
    return (rebuild_typed_storage_child, (storage._untyped_storage, type(storage)))


def reduce_storage(storage):
    from . import get_sharing_strategy

    if storage.is_cuda:
        raise RuntimeError(
            "Cannot pickle CUDA storage; try pickling a CUDA tensor instead"
        )
    elif storage.device.type == "meta":
        raise RuntimeError(
            "Cannot pickle meta storage; try pickling a meta tensor instead"
        )
    elif get_sharing_strategy() == "file_system":
        metadata = storage._share_filename_cpu_()
        cache_key = metadata[1]
        rebuild = rebuild_storage_filename
        if isinstance(storage, torch.TypedStorage):
            metadata += (storage.dtype,)
        storage._shared_incref()
    elif storage.size() == 0:
        # This is special-cased because empty storages
        # (with size 0) cannot be mmapped.
        return (rebuild_storage_empty, (type(storage),))
    else:
        fd, size = storage._share_fd_cpu_()
        df = multiprocessing.reduction.DupFd(fd)
        cache_key = fd_id(fd)
        metadata = (df, size)
        rebuild = rebuild_storage_fd  # type: ignore[assignment]

    shared_cache[cache_key] = StorageWeakRef(storage)
    return (rebuild, (type(storage),) + metadata)

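# Sketch of how a CPU storage is reduced under the "file_descriptor" strategy
# (the default on Linux); ForkingPickler normally drives this when a tensor or
# storage crosses a torch.multiprocessing queue:
#
#     s = torch.empty(8)._typed_storage()
#     rebuild, args = reduce_storage(s)   # -> (rebuild_storage_fd, (cls, DupFd, size))
#     # the receiving process calls rebuild(*args), which dups the fd, maps the
#     # same shared memory, and registers the result in shared_cache
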

def init_reductions():
    ForkingPickler.register(torch.cuda.Event, reduce_event)

    for t in torch._storage_classes:
        if t.__name__ == "UntypedStorage":
            ForkingPickler.register(t, reduce_storage)
        else:
            ForkingPickler.register(t, reduce_typed_storage_child)

    ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage)

    for t in torch._tensor_classes:
        ForkingPickler.register(t, reduce_tensor)

    # TODO: Maybe this should be in tensor_classes? :)
    ForkingPickler.register(torch.Tensor, reduce_tensor)

    from torch.nn.parameter import Parameter

    ForkingPickler.register(Parameter, reduce_tensor)
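

# Once init_reductions() has run (torch.multiprocessing calls it at import
# time), ForkingPickler serializes tensors, storages, and CUDA events with the
# reduce_* functions registered above, so sharing is as simple as:
#
#     import torch.multiprocessing as mp
#     q = mp.Queue()
#     q.put(torch.ones(3))   # pickled via reduce_tensor under the hood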
648