// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/DLConvertor.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <functional>

#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <ATen/dlpack.h>
6 
7 // this convertor will:
8 // 1) take a Tensor object and wrap it in the DLPack tensor
9 // 2) take a dlpack tensor and convert it to the ATen Tensor
10 
11 namespace at {
12 
13 TORCH_API ScalarType toScalarType(const DLDataType& dtype);
14 TORCH_API DLManagedTensor* toDLPack(const Tensor& src);
15 TORCH_API Tensor fromDLPack(DLManagedTensor* src);
16 C10_DEPRECATED_MESSAGE("Please migrate to a non-const variant")
fromDLPack(const DLManagedTensor * src)17 inline Tensor fromDLPack(const DLManagedTensor* src) {
18   return fromDLPack(const_cast<DLManagedTensor*>(src));
19 }
20 TORCH_API Tensor
21 fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter);
22 TORCH_API DLDataType getDLDataType(const Tensor& t);
23 TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id);
24 
25 } // namespace at
26