xref: /aosp_15_r20/external/pytorch/aten/src/ATen/DLConvertor.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/DLConvertor.h>
2 #include <ATen/Functions.h>
3 
4 using namespace std;
5 namespace at {
6 
getDLDataType(const Tensor & t)7 DLDataType getDLDataType(const Tensor& t) {
8   DLDataType dtype;
9   dtype.lanes = 1;
10   dtype.bits = t.element_size() * 8;
11   switch (t.scalar_type()) {
12     case ScalarType::UInt1:
13     case ScalarType::UInt2:
14     case ScalarType::UInt3:
15     case ScalarType::UInt4:
16     case ScalarType::UInt5:
17     case ScalarType::UInt6:
18     case ScalarType::UInt7:
19     case ScalarType::Byte:
20     case ScalarType::UInt16:
21     case ScalarType::UInt32:
22     case ScalarType::UInt64:
23       dtype.code = DLDataTypeCode::kDLUInt;
24       break;
25     case ScalarType::Char:
26       dtype.code = DLDataTypeCode::kDLInt;
27       break;
28     // NOLINTNEXTLINE(bugprone-branch-clone)
29     case ScalarType::Double:
30       dtype.code = DLDataTypeCode::kDLFloat;
31       break;
32     case ScalarType::Float:
33       dtype.code = DLDataTypeCode::kDLFloat;
34       break;
35     // NOLINTNEXTLINE(bugprone-branch-clone)
36     case ScalarType::Int:
37       dtype.code = DLDataTypeCode::kDLInt;
38       break;
39     case ScalarType::Long:
40       dtype.code = DLDataTypeCode::kDLInt;
41       break;
42     case ScalarType::Short:
43       dtype.code = DLDataTypeCode::kDLInt;
44       break;
45     case ScalarType::Half:
46       dtype.code = DLDataTypeCode::kDLFloat;
47       break;
48     case ScalarType::Bool:
49       dtype.code = DLDataTypeCode::kDLBool;
50       break;
51     case ScalarType::ComplexHalf:
52       dtype.code = DLDataTypeCode::kDLComplex;
53       break;
54     case ScalarType::ComplexFloat:
55       dtype.code = DLDataTypeCode::kDLComplex;
56       break;
57     case ScalarType::ComplexDouble:
58       dtype.code = DLDataTypeCode::kDLComplex;
59       break;
60     case ScalarType::BFloat16:
61       dtype.code = DLDataTypeCode::kDLBfloat;
62       break;
63     case ScalarType::Float8_e5m2:
64     case ScalarType::Float8_e5m2fnuz:
65     case ScalarType::Float8_e4m3fn:
66     case ScalarType::Float8_e4m3fnuz:
67       TORCH_CHECK(false, "float8 types are not supported by dlpack");
68       break;
69     case ScalarType::QInt8:
70     case ScalarType::QUInt8:
71     case ScalarType::QInt32:
72     case ScalarType::QUInt4x2:
73     case ScalarType::QUInt2x4:
74       TORCH_CHECK(false, "QUInt/QInt types are not supported by dlpack");
75       break;
76     case ScalarType::Bits1x8:
77     case ScalarType::Bits2x4:
78     case ScalarType::Bits4x2:
79     case ScalarType::Bits8:
80     case ScalarType::Bits16:
81       TORCH_CHECK(false, "Bit types are not supported by dlpack");
82       break;
83     case ScalarType::Undefined:
84       TORCH_CHECK(false, "Undefined is not a valid ScalarType");
85     case ScalarType::NumOptions:
86       TORCH_CHECK(false, "NumOptions is not a valid ScalarType");
87   }
88   return dtype;
89 }
90 
getDLDevice(const Tensor & tensor,c10::DeviceIndex device_id)91 static DLDevice getDLDevice(const Tensor& tensor, c10::DeviceIndex device_id) {
92   DLDevice ctx;
93   ctx.device_id = static_cast<int32_t>(device_id);
94   switch (tensor.device().type()) {
95     case DeviceType::CPU:
96       ctx.device_type = DLDeviceType::kDLCPU;
97       break;
98     case DeviceType::CUDA:
99 #ifdef USE_ROCM
100       // ROCM, if enabled will look like cuda to PyTorch
101       // while everyone else should see HIP
102       ctx.device_type = DLDeviceType::kDLROCM;
103 #else
104       ctx.device_type = DLDeviceType::kDLCUDA;
105 #endif
106       break;
107     case DeviceType::OPENCL:
108       ctx.device_type = DLDeviceType::kDLOpenCL;
109       break;
110     case DeviceType::HIP:
111       ctx.device_type = DLDeviceType::kDLROCM;
112       break;
113     case DeviceType::XPU:
114       ctx.device_type = DLDeviceType::kDLOneAPI;
115       ctx.device_id =
116           at::detail::getXPUHooks().getGlobalIdxFromDevice(tensor.device());
117       break;
118     case DeviceType::MAIA:
119       ctx.device_type = DLDeviceType::kDLMAIA;
120       break;
121     default:
122       TORCH_CHECK(false, "Cannot pack tensors on " + tensor.device().str());
123   }
124   return ctx;
125 }
126 
getATenDevice(const DLDevice & ctx,void * data)127 static Device getATenDevice(const DLDevice& ctx, void* data) {
128   switch (ctx.device_type) {
129     case DLDeviceType::kDLCPU:
130       return at::Device(DeviceType::CPU);
131 #ifndef USE_ROCM
132     // if we are compiled under HIP, we cannot do cuda
133     case DLDeviceType::kDLCUDA:
134       return at::Device(DeviceType::CUDA, static_cast<c10::DeviceIndex>(ctx.device_id));
135 #endif
136     case DLDeviceType::kDLOpenCL:
137       return at::Device(DeviceType::OPENCL, static_cast<c10::DeviceIndex>(ctx.device_id));
138     case DLDeviceType::kDLROCM:
139 #ifdef USE_ROCM
140       // this looks funny, we need to return CUDA here to masquerade
141       return at::Device(DeviceType::CUDA, static_cast<c10::DeviceIndex>(ctx.device_id));
142 #else
143       return at::Device(DeviceType::HIP, static_cast<c10::DeviceIndex>(ctx.device_id));
144 #endif
145     case DLDeviceType::kDLOneAPI:
146       return at::detail::getXPUHooks().getDeviceFromPtr(data);
147     case DLDeviceType::kDLMAIA:
148       return at::Device(DeviceType::MAIA, static_cast<c10::DeviceIndex>(ctx.device_id));
149     default:
150       TORCH_CHECK(
151           false, "Unsupported device_type: ", std::to_string(ctx.device_type));
152   }
153 }
154 
toScalarType(const DLDataType & dtype)155 ScalarType toScalarType(const DLDataType& dtype) {
156   ScalarType stype = ScalarType::Undefined;
157   TORCH_CHECK(dtype.lanes == 1, "ATen does not support lanes != 1");
158   switch (dtype.code) {
159     case DLDataTypeCode::kDLUInt:
160       switch (dtype.bits) {
161         case 8:
162           stype = ScalarType::Byte;
163           break;
164         case 16:
165           stype = ScalarType::UInt16;
166           break;
167         case 32:
168           stype = ScalarType::UInt32;
169           break;
170         case 64:
171           stype = ScalarType::UInt64;
172           break;
173         default:
174           TORCH_CHECK(
175               false, "Unsupported kUInt bits ", std::to_string(dtype.bits));
176       }
177       break;
178     case DLDataTypeCode::kDLInt:
179       switch (dtype.bits) {
180         case 8:
181           stype = ScalarType::Char;
182           break;
183         case 16:
184           stype = ScalarType::Short;
185           break;
186         case 32:
187           stype = ScalarType::Int;
188           break;
189         case 64:
190           stype = ScalarType::Long;
191           break;
192         default:
193           TORCH_CHECK(
194               false, "Unsupported kInt bits ", std::to_string(dtype.bits));
195       }
196       break;
197     case DLDataTypeCode::kDLFloat:
198       switch (dtype.bits) {
199         case 16:
200           stype = ScalarType::Half;
201           break;
202         case 32:
203           stype = ScalarType::Float;
204           break;
205         case 64:
206           stype = ScalarType::Double;
207           break;
208         default:
209           TORCH_CHECK(
210               false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
211       }
212       break;
213     case DLDataTypeCode::kDLBfloat:
214       switch (dtype.bits) {
215         case 16:
216           stype = ScalarType::BFloat16;
217           break;
218         default:
219           TORCH_CHECK(
220               false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
221       }
222       break;
223     case DLDataTypeCode::kDLComplex:
224       switch (dtype.bits) {
225         case 32:
226           stype = ScalarType::ComplexHalf;
227           break;
228         case 64:
229           stype = ScalarType::ComplexFloat;
230           break;
231         case 128:
232           stype = ScalarType::ComplexDouble;
233           break;
234         default:
235           TORCH_CHECK(
236               false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
237       }
238       break;
239     case DLDataTypeCode::kDLBool:
240       switch (dtype.bits) {
241         case 8:
242           stype = ScalarType::Bool;
243           break;
244         default:
245           TORCH_CHECK(
246               false, "Unsupported kDLBool bits ", std::to_string(dtype.bits));
247       }
248       break;
249     default:
250       TORCH_CHECK(false, "Unsupported code ", std::to_string(dtype.code));
251   }
252   return stype;
253 }
254 
// Pairs the owning ATen tensor with the DLPack struct handed to consumers:
// `handle` keeps the storage alive until `deleter` (below) destroys this
// object via `tensor.manager_ctx`.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct ATenDLMTensor {
  Tensor handle;
  DLManagedTensor tensor;
};
260 
// DLManagedTensor deleter installed by toDLPack(): reclaims the
// ATenDLMTensor stored in manager_ctx, which also drops the owning
// Tensor reference.
static void deleter(DLManagedTensor* arg) {
  delete static_cast<ATenDLMTensor*>(arg->manager_ctx);
}
264 
265 // This function returns a shared_ptr to memory managed DLpack tensor
266 // constructed out of ATen tensor
toDLPack(const Tensor & src)267 DLManagedTensor* toDLPack(const Tensor& src) {
268   // create a new tensor with possibly normalized strides
269   // gh-83069
270   auto shape = src.sizes();
271   auto strides = src.strides().vec();
272   for (int i = 0; i < src.dim(); i++) {
273     if (shape[i] < 2) {
274       strides[i] = 1;
275     }
276   }
277 
278   auto view = src.as_strided(shape, strides, src.storage_offset());
279   ATenDLMTensor* atDLMTensor(new ATenDLMTensor);
280   atDLMTensor->handle = view;
281   atDLMTensor->tensor.manager_ctx = atDLMTensor;
282   atDLMTensor->tensor.deleter = &deleter;
283   atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
284   c10::DeviceIndex device_id = 0;
285   if (src.is_cuda()) {
286     device_id = src.get_device();
287   }
288   atDLMTensor->tensor.dl_tensor.device = getDLDevice(src, device_id);
289   atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
290   atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
291   atDLMTensor->tensor.dl_tensor.shape = view.sizes().data();
292   atDLMTensor->tensor.dl_tensor.strides = view.strides().data();
293   atDLMTensor->tensor.dl_tensor.byte_offset = 0;
294   return &(atDLMTensor->tensor);
295 }
296 
fromDLPack(DLManagedTensor * src)297 Tensor fromDLPack(DLManagedTensor* src) {
298   auto deleter = [src](void* self [[maybe_unused]]) {
299     if (src->deleter) {
300       src->deleter(src);
301     }
302   };
303   return fromDLPack(src, std::move(deleter));
304 }
305 
fromDLPack(DLManagedTensor * src,std::function<void (void *)> deleter)306 Tensor fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter) {
307   Device device = getATenDevice(src->dl_tensor.device, src->dl_tensor.data);
308   ScalarType stype = toScalarType(src->dl_tensor.dtype);
309   if (!src->dl_tensor.strides) {
310     return at::from_blob(
311         src->dl_tensor.data,
312         IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
313         std::move(deleter),
314         at::device(device).dtype(stype),
315         {device});
316   }
317   return at::from_blob(
318       src->dl_tensor.data,
319       IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
320       IntArrayRef(src->dl_tensor.strides, src->dl_tensor.ndim),
321       deleter,
322       at::device(device).dtype(stype),
323       {device});
324 }
325 } // namespace at
326