from typing import Any

import torch
import enum

from torch._C import _from_dlpack
from torch._C import _to_dlpack as to_dlpack


class DLDeviceType(enum.IntEnum):
    """Device-type codes from the DLPack specification.

    Being an ``IntEnum``, members compare equal to the plain integer device
    type found at index 0 of an object's ``__dlpack_device__()`` result.
    """

    # Enums as in DLPack specification (aten/src/ATen/dlpack.h).
    # Values are plain ints: a trailing comma (``kDLCPU = 1,``) would make
    # each value a tuple literal.
    kDLCPU = 1
    kDLGPU = 2
    kDLCPUPinned = 3
    kDLOpenCL = 4
    kDLVulkan = 7
    kDLMetal = 8
    kDLVPI = 9
    kDLROCM = 10
    kDLExtDev = 12
    kDLOneAPI = 14


torch._C._add_docstr(to_dlpack, r"""to_dlpack(tensor) -> PyCapsule

Returns an opaque object (a "DLPack capsule") representing the tensor.

.. note::
  ``to_dlpack`` is a legacy DLPack interface. The capsule it returns
  cannot be used for anything in Python other than using it as input to
  ``from_dlpack``. The more idiomatic use of DLPack is to call
  ``from_dlpack`` directly on the tensor object - this works when that
  object has a ``__dlpack__`` method, which PyTorch and most other
  libraries now provide.

.. warning::
  Only call ``from_dlpack`` once per capsule produced with ``to_dlpack``.
  Behavior when a capsule is consumed multiple times is undefined.

Args:
    tensor: a tensor to be exported

The DLPack capsule shares the tensor's memory.
""")


# TODO: add a typing.Protocol to be able to tell Mypy that only objects with
# __dlpack__ and __dlpack_device__ methods are accepted.
def from_dlpack(ext_tensor: Any) -> 'torch.Tensor':
    """from_dlpack(ext_tensor) -> Tensor

    Converts a tensor from an external library into a ``torch.Tensor``.

    The returned PyTorch tensor will share the memory with the input tensor
    (which may have come from another library). Note that in-place operations
    will therefore also affect the data of the input tensor. This may lead to
    unexpected issues (e.g., other libraries may have read-only flags or
    immutable data structures), so the user should only do this if they know
    for sure that this is fine.

    Args:
        ext_tensor (object with ``__dlpack__`` attribute, or a DLPack capsule):
            The tensor or DLPack capsule to convert.

            If ``ext_tensor`` is a tensor (or ndarray) object, it must support
            the ``__dlpack__`` protocol (i.e., have a ``ext_tensor.__dlpack__``
            method). Otherwise ``ext_tensor`` may be a DLPack capsule, which is
            an opaque ``PyCapsule`` instance, typically produced by a
            ``to_dlpack`` function or method.

    Returns:
        torch.Tensor: a tensor sharing memory with ``ext_tensor``.

    Examples::

        >>> import torch.utils.dlpack
        >>> t = torch.arange(4)

        # Convert a tensor directly (supported in PyTorch >= 1.10)
        >>> t2 = torch.from_dlpack(t)
        >>> t2[:2] = -1  # show that memory is shared
        >>> t2
        tensor([-1, -1,  2,  3])
        >>> t
        tensor([-1, -1,  2,  3])

        # The old-style DLPack usage, with an intermediate capsule object
        >>> capsule = torch.utils.dlpack.to_dlpack(t)
        >>> capsule
        <capsule object "dltensor" at ...>
        >>> t3 = torch.from_dlpack(capsule)
        >>> t3
        tensor([-1, -1,  2,  3])
        >>> t3[0] = -9  # now we're sharing memory between 3 tensors
        >>> t3
        tensor([-9, -1,  2,  3])
        >>> t2
        tensor([-9, -1,  2,  3])
        >>> t
        tensor([-9, -1,  2,  3])

    """
    if hasattr(ext_tensor, '__dlpack__'):
        device = ext_tensor.__dlpack_device__()
        # If the device is CUDA or ROCm, we need to pass the current stream
        # (device[1] is used as the CUDA device index).
        if device[0] in (DLDeviceType.kDLGPU, DLDeviceType.kDLROCM):
            stream = torch.cuda.current_stream(f'cuda:{device[1]}')
            # ``cuda_stream`` is the pointer to the stream; it is a public
            # attribute, but it is not documented.
            # The array API specifies that the default legacy stream must be
            # passed with a value of 1 for CUDA:
            # https://data-apis.org/array-api/latest/API_specification/array_object.html?dlpack-self-stream-none#dlpack-self-stream-none
            is_cuda = device[0] == DLDeviceType.kDLGPU
            # PyTorch does not use per-thread default streams (PTDS) by
            # default, so a null CUDA stream is passed as the legacy stream (1).
            stream_ptr = 1 if is_cuda and stream.cuda_stream == 0 else stream.cuda_stream
            dlpack = ext_tensor.__dlpack__(stream=stream_ptr)
        else:
            dlpack = ext_tensor.__dlpack__()
    else:
        # Legacy path: no ``__dlpack__`` method, so ``ext_tensor`` is expected
        # to already be a DLPack capsule (e.g. produced by ``to_dlpack``).
        dlpack = ext_tensor
    return _from_dlpack(dlpack)