#include <ATen/DLConvertor.h>
#include <ATen/Functions.h>

using namespace std;
namespace at {

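// Translates an ATen scalar type into the corresponding DLPack DLDataType
// (type code, bit width, lane count). Types with no DLPack equivalent
// (float8, quantized, and bit types) raise an error instead.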
DLDataType getDLDataType(const Tensor& t) {
  DLDataType dtype;
  dtype.lanes = 1;
  dtype.bits = t.element_size() * 8;
  switch (t.scalar_type()) {
    case ScalarType::UInt1:
    case ScalarType::UInt2:
    case ScalarType::UInt3:
    case ScalarType::UInt4:
    case ScalarType::UInt5:
    case ScalarType::UInt6:
    case ScalarType::UInt7:
    case ScalarType::Byte:
    case ScalarType::UInt16:
    case ScalarType::UInt32:
    case ScalarType::UInt64:
      dtype.code = DLDataTypeCode::kDLUInt;
      break;
    case ScalarType::Char:
      dtype.code = DLDataTypeCode::kDLInt;
      break;
    // NOLINTNEXTLINE(bugprone-branch-clone)
    case ScalarType::Double:
      dtype.code = DLDataTypeCode::kDLFloat;
      break;
    case ScalarType::Float:
      dtype.code = DLDataTypeCode::kDLFloat;
      break;
    // NOLINTNEXTLINE(bugprone-branch-clone)
    case ScalarType::Int:
      dtype.code = DLDataTypeCode::kDLInt;
      break;
    case ScalarType::Long:
      dtype.code = DLDataTypeCode::kDLInt;
      break;
    case ScalarType::Short:
      dtype.code = DLDataTypeCode::kDLInt;
      break;
    case ScalarType::Half:
      dtype.code = DLDataTypeCode::kDLFloat;
      break;
    case ScalarType::Bool:
      dtype.code = DLDataTypeCode::kDLBool;
      break;
    case ScalarType::ComplexHalf:
      dtype.code = DLDataTypeCode::kDLComplex;
      break;
    case ScalarType::ComplexFloat:
      dtype.code = DLDataTypeCode::kDLComplex;
      break;
    case ScalarType::ComplexDouble:
      dtype.code = DLDataTypeCode::kDLComplex;
      break;
    case ScalarType::BFloat16:
      dtype.code = DLDataTypeCode::kDLBfloat;
      break;
    case ScalarType::Float8_e5m2:
    case ScalarType::Float8_e5m2fnuz:
    case ScalarType::Float8_e4m3fn:
    case ScalarType::Float8_e4m3fnuz:
      TORCH_CHECK(false, "float8 types are not supported by dlpack");
      break;
    case ScalarType::QInt8:
    case ScalarType::QUInt8:
    case ScalarType::QInt32:
    case ScalarType::QUInt4x2:
    case ScalarType::QUInt2x4:
      TORCH_CHECK(false, "QUInt/QInt types are not supported by dlpack");
      break;
    case ScalarType::Bits1x8:
    case ScalarType::Bits2x4:
    case ScalarType::Bits4x2:
    case ScalarType::Bits8:
    case ScalarType::Bits16:
      TORCH_CHECK(false, "Bit types are not supported by dlpack");
      break;
    case ScalarType::Undefined:
      TORCH_CHECK(false, "Undefined is not a valid ScalarType");
    case ScalarType::NumOptions:
      TORCH_CHECK(false, "NumOptions is not a valid ScalarType");
  }
  return dtype;
}

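// Builds the DLPack DLDevice for a tensor. The device_id is the one passed
// in, except for XPU, where the global device index is looked up instead.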
static DLDevice getDLDevice(const Tensor& tensor, c10::DeviceIndex device_id) {
  DLDevice ctx;
  ctx.device_id = static_cast<int32_t>(device_id);
  switch (tensor.device().type()) {
    case DeviceType::CPU:
      ctx.device_type = DLDeviceType::kDLCPU;
      break;
    case DeviceType::CUDA:
#ifdef USE_ROCM
      // With ROCm enabled, the device looks like CUDA inside PyTorch,
      // but DLPack consumers should see it as ROCm.
      ctx.device_type = DLDeviceType::kDLROCM;
#else
      ctx.device_type = DLDeviceType::kDLCUDA;
#endif
      break;
    case DeviceType::OPENCL:
      ctx.device_type = DLDeviceType::kDLOpenCL;
      break;
    case DeviceType::HIP:
      ctx.device_type = DLDeviceType::kDLROCM;
      break;
    case DeviceType::XPU:
      ctx.device_type = DLDeviceType::kDLOneAPI;
      ctx.device_id =
          at::detail::getXPUHooks().getGlobalIdxFromDevice(tensor.device());
      break;
    case DeviceType::MAIA:
      ctx.device_type = DLDeviceType::kDLMAIA;
      break;
    default:
      TORCH_CHECK(false, "Cannot pack tensors on " + tensor.device().str());
  }
  return ctx;
}

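// Maps a DLPack DLDevice back to an ATen Device. Under ROCm the mapping is
// inverted on purpose: kDLROCM becomes CUDA, since HIP devices masquerade
// as CUDA inside PyTorch.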
static Device getATenDevice(const DLDevice& ctx, void* data) {
  switch (ctx.device_type) {
    case DLDeviceType::kDLCPU:
      return at::Device(DeviceType::CPU);
#ifndef USE_ROCM
    // CUDA devices are unavailable when the build targets ROCm.
    case DLDeviceType::kDLCUDA:
      return at::Device(
          DeviceType::CUDA, static_cast<c10::DeviceIndex>(ctx.device_id));
#endif
    case DLDeviceType::kDLOpenCL:
      return at::Device(
          DeviceType::OPENCL, static_cast<c10::DeviceIndex>(ctx.device_id));
    case DLDeviceType::kDLROCM:
#ifdef USE_ROCM
      // This looks odd, but under ROCm we must return CUDA so the device
      // keeps masquerading as CUDA inside PyTorch.
      return at::Device(
          DeviceType::CUDA, static_cast<c10::DeviceIndex>(ctx.device_id));
#else
      return at::Device(
          DeviceType::HIP, static_cast<c10::DeviceIndex>(ctx.device_id));
#endif
    case DLDeviceType::kDLOneAPI:
      return at::detail::getXPUHooks().getDeviceFromPtr(data);
    case DLDeviceType::kDLMAIA:
      return at::Device(
          DeviceType::MAIA, static_cast<c10::DeviceIndex>(ctx.device_id));
    default:
      TORCH_CHECK(
          false, "Unsupported device_type: ", std::to_string(ctx.device_type));
  }
}

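// Translates a DLPack DLDataType into the matching ATen ScalarType,
// rejecting lane counts other than 1 and unsupported bit widths.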
ScalarType toScalarType(const DLDataType& dtype) {
  ScalarType stype = ScalarType::Undefined;
  TORCH_CHECK(dtype.lanes == 1, "ATen does not support lanes != 1");
  switch (dtype.code) {
    case DLDataTypeCode::kDLUInt:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Byte;
          break;
        case 16:
          stype = ScalarType::UInt16;
          break;
        case 32:
          stype = ScalarType::UInt32;
          break;
        case 64:
          stype = ScalarType::UInt64;
          break;
        default:
          TORCH_CHECK(
              false, "Unsupported kUInt bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLInt:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Char;
          break;
        case 16:
          stype = ScalarType::Short;
          break;
        case 32:
          stype = ScalarType::Int;
          break;
        case 64:
          stype = ScalarType::Long;
          break;
        default:
          TORCH_CHECK(
              false, "Unsupported kInt bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLFloat:
      switch (dtype.bits) {
        case 16:
          stype = ScalarType::Half;
          break;
        case 32:
          stype = ScalarType::Float;
          break;
        case 64:
          stype = ScalarType::Double;
          break;
        default:
          TORCH_CHECK(
              false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLBfloat:
      switch (dtype.bits) {
        case 16:
          stype = ScalarType::BFloat16;
          break;
        default:
          TORCH_CHECK(
              false, "Unsupported kDLBfloat bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLComplex:
      switch (dtype.bits) {
        case 32:
          stype = ScalarType::ComplexHalf;
          break;
        case 64:
          stype = ScalarType::ComplexFloat;
          break;
        case 128:
          stype = ScalarType::ComplexDouble;
          break;
        default:
          TORCH_CHECK(
              false, "Unsupported kDLComplex bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLBool:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Bool;
          break;
        default:
          TORCH_CHECK(
              false, "Unsupported kDLBool bits ", std::to_string(dtype.bits));
      }
      break;
    default:
      TORCH_CHECK(false, "Unsupported code ", std::to_string(dtype.code));
  }
  return stype;
}

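// Owns the source ATen tensor on behalf of an exported DLManagedTensor:
// `handle` keeps the storage alive until the consumer runs the deleter.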
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct ATenDLMTensor {
  Tensor handle;
  DLManagedTensor tensor;
};

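// Deleter installed on exported tensors: frees the ATenDLMTensor wrapper,
// which drops the Tensor reference held in `handle`.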
static void deleter(DLManagedTensor* arg) {
  delete static_cast<ATenDLMTensor*>(arg->manager_ctx);
}

// Exports an ATen tensor as a heap-allocated DLManagedTensor. The caller
// takes ownership and must eventually invoke the embedded deleter (by DLPack
// convention, the consumer of the capsule runs it when done).
DLManagedTensor* toDLPack(const Tensor& src) {
  // Create a view with normalized strides: dimensions of size 0 or 1 get
  // stride 1 so that strict DLPack consumers accept the tensor (gh-83069).
  auto shape = src.sizes();
  auto strides = src.strides().vec();
  for (int64_t i = 0; i < src.dim(); i++) {
    if (shape[i] < 2) {
      strides[i] = 1;
    }
  }

  auto view = src.as_strided(shape, strides, src.storage_offset());
  ATenDLMTensor* atDLMTensor(new ATenDLMTensor);
  atDLMTensor->handle = view;
  atDLMTensor->tensor.manager_ctx = atDLMTensor;
  atDLMTensor->tensor.deleter = &deleter;
  atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
  c10::DeviceIndex device_id = 0;
  if (src.is_cuda()) {
    device_id = src.get_device();
  }
  atDLMTensor->tensor.dl_tensor.device = getDLDevice(src, device_id);
  atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
  atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
  atDLMTensor->tensor.dl_tensor.shape = view.sizes().data();
  atDLMTensor->tensor.dl_tensor.strides = view.strides().data();
  atDLMTensor->tensor.dl_tensor.byte_offset = 0;
  return &(atDLMTensor->tensor);
}

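// A sketch of the intended hand-off (illustrative, not part of this file):
// the consumer runs the deleter once it is done with the buffer, so the
// producer only cleans up when the capsule was never consumed:
//
//   DLManagedTensor* dlm = at::toDLPack(tensor);
//   // `handed_off` is a hypothetical flag for whether a consumer took it.
//   if (!handed_off && dlm->deleter) {
//     dlm->deleter(dlm);
//   }

// Wraps a DLManagedTensor in an ATen Tensor that borrows the DLPack buffer;
// `src->deleter` runs once the last reference to the returned tensor dies.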
Tensor fromDLPack(DLManagedTensor* src) {
  auto deleter = [src](void* self [[maybe_unused]]) {
    if (src->deleter) {
      src->deleter(src);
    }
  };
  return fromDLPack(src, std::move(deleter));
}

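// Overload taking a custom deleter, which from_blob runs when the last
// reference to the returned tensor goes away.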
Tensor fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter) {
  Device device = getATenDevice(src->dl_tensor.device, src->dl_tensor.data);
  ScalarType stype = toScalarType(src->dl_tensor.dtype);
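  // A null strides pointer signals a row-major contiguous layout in DLPack,
  // so only the sizes are forwarded to from_blob.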
  if (!src->dl_tensor.strides) {
    return at::from_blob(
        src->dl_tensor.data,
        IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
        std::move(deleter),
        at::device(device).dtype(stype),
        {device});
  }
  return at::from_blob(
      src->dl_tensor.data,
      IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
      IntArrayRef(src->dl_tensor.strides, src->dl_tensor.ndim),
      deleter,
      at::device(device).dtype(stype),
      {device});
}
} // namespace at
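// Illustrative round trip under the APIs above (a sketch, not part of this
// file): the imported tensor aliases the exporter's storage.
//
//   at::Tensor t = at::rand({2, 3});
//   DLManagedTensor* dlm = at::toDLPack(t);
//   at::Tensor t2 = at::fromDLPack(dlm); // t2 shares t's memory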