import enum
from typing import Any

import torch
from torch._C import _from_dlpack
from torch._C import _to_dlpack as to_dlpack


class DLDeviceType(enum.IntEnum):
    # Enums as in the DLPack specification (aten/src/ATen/dlpack.h)
    kDLCPU = 1
    kDLGPU = 2
    kDLCPUPinned = 3
    kDLOpenCL = 4
    kDLVulkan = 7
    kDLMetal = 8
    kDLVPI = 9
    kDLROCM = 10
    kDLExtDev = 12
    kDLOneAPI = 14

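# For reference: ``Tensor.__dlpack_device__`` (used by ``from_dlpack`` below)
# returns one of these values paired with a device index; a CPU tensor, for
# example, reports ``(1, 0)``, i.e. ``(DLDeviceType.kDLCPU, 0)``.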
torch._C._add_docstr(to_dlpack, r"""to_dlpack(tensor) -> PyCapsule

Returns an opaque object (a "DLPack capsule") representing the tensor.

.. note::
  ``to_dlpack`` is a legacy DLPack interface. The capsule it returns
  cannot be used for anything in Python other than as input to
  ``from_dlpack``. The more idiomatic use of DLPack is to call
  ``from_dlpack`` directly on the tensor object - this works when the
  object has a ``__dlpack__`` method, which PyTorch and most other
  tensor libraries now provide.

.. warning::
  Only call ``from_dlpack`` once per capsule produced with ``to_dlpack``.
  Behavior when a capsule is consumed multiple times is undefined.

Args:
    tensor: a tensor to be exported

The DLPack capsule shares the tensor's memory.
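
Examples::

    >>> t = torch.arange(4)
    >>> capsule = torch.utils.dlpack.to_dlpack(t)

    # Consume the capsule (at most once!) to get a tensor sharing t's memory
    >>> t2 = torch.utils.dlpack.from_dlpack(capsule)
    >>> t2[0] = -1
    >>> t
    tensor([-1,  1,  2,  3])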
""")


# TODO: add a typing.Protocol to be able to tell Mypy that only objects with
# __dlpack__ and __dlpack_device__ methods are accepted.
def from_dlpack(ext_tensor: Any) -> 'torch.Tensor':
    """from_dlpack(ext_tensor) -> Tensor

    Converts a tensor from an external library into a ``torch.Tensor``.

    The returned PyTorch tensor will share the memory with the input tensor
    (which may have come from another library). Note that in-place operations
    will therefore also affect the data of the input tensor. This may lead to
    unexpected issues (e.g., other libraries may have read-only flags or
    immutable data structures), so only do this if you know for sure that
    sharing (and possibly mutating) the memory is safe.

    Args:
        ext_tensor (object with ``__dlpack__`` attribute, or a DLPack capsule):
            The tensor or DLPack capsule to convert.

            If ``ext_tensor`` is a tensor (or ndarray) object, it must support
            the ``__dlpack__`` protocol (i.e., have an ``ext_tensor.__dlpack__``
            method). Otherwise ``ext_tensor`` may be a DLPack capsule, which is
            an opaque ``PyCapsule`` instance, typically produced by a
            ``to_dlpack`` function or method.

    Examples::

        >>> import torch.utils.dlpack
        >>> t = torch.arange(4)

        # Convert a tensor directly (supported in PyTorch >= 1.10)
        >>> t2 = torch.from_dlpack(t)
        >>> t2[:2] = -1  # show that memory is shared
        >>> t2
        tensor([-1, -1,  2,  3])
        >>> t
        tensor([-1, -1,  2,  3])

        # The old-style DLPack usage, with an intermediate capsule object
        >>> capsule = torch.utils.dlpack.to_dlpack(t)
        >>> capsule
        <capsule object "dltensor" at ...>
        >>> t3 = torch.from_dlpack(capsule)
        >>> t3
        tensor([-1, -1,  2,  3])
        >>> t3[0] = -9  # now we're sharing memory between 3 tensors
        >>> t3
        tensor([-9, -1,  2,  3])
        >>> t2
        tensor([-9, -1,  2,  3])
        >>> t
        tensor([-9, -1,  2,  3])

    """
    if hasattr(ext_tensor, '__dlpack__'):
        device = ext_tensor.__dlpack_device__()
        # If the device is CUDA or ROCm, we need to pass the current
        # stream so the producer can synchronize with it
        if device[0] in (DLDeviceType.kDLGPU, DLDeviceType.kDLROCM):
            stream = torch.cuda.current_stream(f'cuda:{device[1]}')
            # cuda_stream is the raw pointer to the stream; it is a public
            # attribute, but it is not documented
            # The array API standard specifies that the legacy default stream
            # must be passed with a value of 1 for CUDA
            # https://data-apis.org/array-api/latest/API_specification/array_object.html?dlpack-self-stream-none#dlpack-self-stream-none
            is_cuda = device[0] == DLDeviceType.kDLGPU
            # Since PyTorch is not using PTDS (per-thread default streams) by
            # default, map the default stream (pointer 0) to the legacy value 1
            stream_ptr = 1 if is_cuda and stream.cuda_stream == 0 else stream.cuda_stream
            dlpack = ext_tensor.__dlpack__(stream=stream_ptr)
        else:
            dlpack = ext_tensor.__dlpack__()
    else:
        # Assume ext_tensor is an opaque DLPack capsule (the legacy interface)
        # and hand it straight to the converter
        dlpack = ext_tensor
    return _from_dlpack(dlpack)
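

# The block below is an editorial sketch, not part of this module's public
# API: a minimal, hypothetical producer implementing the ``__dlpack__``
# protocol that ``from_dlpack`` consumes above. ``_TensorWrapper`` simply
# delegates to a wrapped ``torch.Tensor``, which already provides
# ``__dlpack__`` and ``__dlpack_device__``.
if __name__ == "__main__":
    class _TensorWrapper:
        def __init__(self, tensor: 'torch.Tensor') -> None:
            self._tensor = tensor

        def __dlpack__(self, stream=None):
            # stream is the consumer's stream pointer (1 means the CUDA
            # legacy default stream per the array API standard); a real
            # producer would synchronize with it before exporting
            return self._tensor.__dlpack__(stream=stream)

        def __dlpack_device__(self):
            return self._tensor.__dlpack_device__()

    src = torch.arange(4)
    # A CPU tensor reports device type kDLCPU (1) with index 0
    assert src.__dlpack_device__()[0] == DLDeviceType.kDLCPU
    dst = from_dlpack(_TensorWrapper(src))
    dst[0] = -1  # the converted tensor shares src's memory
    assert src[0] == -1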