#include <c10/core/DeviceType.h>
#include <c10/core/GradMode.h>
#include <c10/core/Layout.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/mkldnn_tensor.h>
#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h>
#include <torch/csrc/inductor/aoti_torch/proxy_executor.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <torch/csrc/inductor/inductor_ops.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <iostream>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else

#include <ATen/ops/_addmm_activation.h>
#include <ATen/ops/_embedding_bag.h>
#include <ATen/ops/_fft_c2c.h>
#include <ATen/ops/_scaled_dot_product_efficient_attention.h>
#include <ATen/ops/_scaled_dot_product_flash_attention.h>
#include <ATen/ops/_scaled_mm.h>
#include <ATen/ops/_wrapped_linear_prepack.h>
#include <ATen/ops/_wrapped_quantized_linear_prepacked.h>
#include <ATen/ops/addmm.h>
#include <ATen/ops/as_strided.h>
#include <ATen/ops/bmm.h>
#include <ATen/ops/convolution.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation.h>
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/index_put.h>
#include <ATen/ops/mm.h>
#include <ATen/ops/nonzero.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/scatter.h>
#include <ATen/ops/scatter_reduce.h>
#include <ATen/ops/view_as_real_ops.h>
#include <ATen/ops/view_ops.h>

#endif

#if __has_include("filesystem")
#include <filesystem>
namespace fs = std::filesystem;
#else
#include <experimental/filesystem>
namespace fs = std::experimental::filesystem;
#endif

#ifndef _WIN32
#include <limits.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#endif

// HACK for failed builds in ARVR, where it cannot find these symbols within
// std::experimental::filesystem
namespace {
std::string get_current_path() {
#if __has_include("filesystem") && !defined(__linux__)
  return fs::current_path().string();
#else
  char currentPath[PATH_MAX];
  if (getcwd(currentPath, sizeof(currentPath)) != nullptr) {
    return std::string(currentPath);
  } else {
    throw std::runtime_error("Failed to get current path");
  }
#endif
}

bool file_exists(std::string& path) {
#if __has_include("filesystem") && !defined(__linux__)
  return fs::exists(path);
#else
  struct stat rc;
  return lstat(path.c_str(), &rc) == 0;
#endif
}

bool create_directories(const std::string& path) {
#if __has_include("filesystem") && !defined(__linux__)
  return fs::create_directories(path);
#else
  if (mkdir(path.c_str(), 0777) == -1) {
    throw std::runtime_error("Failed to create directory");
  }
  return true;
#endif
}
} // namespace

using namespace torch::aot_inductor;
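
// Note on error handling: most entry points below wrap their bodies in
// AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE (presumably defined in the
// included aoti_torch/utils.h), which is assumed to catch any C++ exception
// thrown by the body and convert it into an AOTITorchError return code, so
// exceptions never propagate across this C ABI boundary.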

namespace {
static c10::Device c10_device(int32_t device_type, int32_t device_index) {
  if (device_type == aoti_torch_device_type_cpu()) {
    return c10::Device(static_cast<c10::DeviceType>(device_type));
  } else {
    return c10::Device(
        static_cast<c10::DeviceType>(device_type),
        static_cast<c10::DeviceIndex>(device_index));
  }
}
} // namespace

const int AOTI_TORCH_MAX_NUMEL_TO_PRINT = 64;

int32_t aoti_torch_device_type_cpu() {
  return (int32_t)c10::DeviceType::CPU;
}

int32_t aoti_torch_device_type_cuda() {
  return (int32_t)c10::DeviceType::CUDA;
}

#define AOTI_TORCH_DTYPE_IMPL(dtype, stype) \
  int32_t aoti_torch_dtype_##dtype() {      \
    return (int32_t)c10::ScalarType::stype; \
  }
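// For example, AOTI_TORCH_DTYPE_IMPL(float32, Float) below expands to:
//   int32_t aoti_torch_dtype_float32() {
//     return (int32_t)c10::ScalarType::Float;
//   }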

AOTI_TORCH_DTYPE_IMPL(float8_e5m2, Float8_e5m2)
AOTI_TORCH_DTYPE_IMPL(float8_e4m3fn, Float8_e4m3fn)
AOTI_TORCH_DTYPE_IMPL(float8_e5m2fnuz, Float8_e5m2fnuz)
AOTI_TORCH_DTYPE_IMPL(float8_e4m3fnuz, Float8_e4m3fnuz)
AOTI_TORCH_DTYPE_IMPL(bfloat16, BFloat16)
AOTI_TORCH_DTYPE_IMPL(float16, Half)
AOTI_TORCH_DTYPE_IMPL(float32, Float)
AOTI_TORCH_DTYPE_IMPL(float64, Double)
AOTI_TORCH_DTYPE_IMPL(uint8, Byte)
AOTI_TORCH_DTYPE_IMPL(uint16, UInt16)
AOTI_TORCH_DTYPE_IMPL(uint32, UInt32)
AOTI_TORCH_DTYPE_IMPL(uint64, UInt64)
AOTI_TORCH_DTYPE_IMPL(int8, Char)
AOTI_TORCH_DTYPE_IMPL(int16, Short)
AOTI_TORCH_DTYPE_IMPL(int32, Int)
AOTI_TORCH_DTYPE_IMPL(int64, Long)
AOTI_TORCH_DTYPE_IMPL(bool, Bool)
AOTI_TORCH_DTYPE_IMPL(complex32, ComplexHalf)
AOTI_TORCH_DTYPE_IMPL(complex64, ComplexFloat)
AOTI_TORCH_DTYPE_IMPL(complex128, ComplexDouble)
#undef AOTI_TORCH_DTYPE_IMPL

int32_t aoti_torch_layout_strided() {
  return (int32_t)at::kStrided;
}

int32_t aoti_torch_layout__mkldnn() {
  return (int32_t)at::kMkldnn;
}

#define AOTI_TORCH_ITEM_IMPL(dtype, ctype)                     \
  AOTITorchError aoti_torch_item_##dtype(                      \
      AtenTensorHandle tensor, ctype* ret_value) {             \
    AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({               \
      at::Tensor* t = tensor_handle_to_tensor_pointer(tensor); \
      *ret_value = t->item().to<ctype>();                      \
    });                                                        \
  }
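// For example, AOTI_TORCH_ITEM_IMPL(int64, int64_t) below expands to
// aoti_torch_item_int64(tensor, ret_value), which copies the single element
// of `tensor` into *ret_value via at::Tensor::item().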

AOTI_TORCH_ITEM_IMPL(float16, c10::Half)
AOTI_TORCH_ITEM_IMPL(float32, float)
AOTI_TORCH_ITEM_IMPL(float64, double)
AOTI_TORCH_ITEM_IMPL(uint8, uint8_t)
AOTI_TORCH_ITEM_IMPL(uint16, uint16_t)
AOTI_TORCH_ITEM_IMPL(uint32, uint32_t)
AOTI_TORCH_ITEM_IMPL(uint64, uint64_t)
AOTI_TORCH_ITEM_IMPL(int8, int8_t)
AOTI_TORCH_ITEM_IMPL(int16, int16_t)
AOTI_TORCH_ITEM_IMPL(int32, int32_t)
AOTI_TORCH_ITEM_IMPL(int64, int64_t)
AOTI_TORCH_ITEM_IMPL(bool, bool)
AOTI_TORCH_ITEM_IMPL(bfloat16, c10::BFloat16)
AOTI_TORCH_ITEM_IMPL(complex64, c10::complex<float>)
#undef AOTI_TORCH_ITEM_IMPL

#define AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(dtype, ctype, ttype)                  \
  AOTITorchError aoti_torch_scalar_to_tensor_##dtype(                          \
      ctype value, AtenTensorHandle* ret_new_tensor) {                         \
    AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({                               \
      *ret_new_tensor =                                                        \
          new_tensor_handle(at::scalar_tensor(value, c10::ScalarType::ttype)); \
    });                                                                        \
  }

AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(float32, float, Float)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(float64, double, Double)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(uint8, uint8_t, Byte)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(uint16, uint16_t, UInt16)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(uint32, uint32_t, UInt32)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(uint64, uint64_t, UInt64)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(int8, int8_t, Char)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(int16, int16_t, Short)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(int32, int32_t, Int)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(int64, int64_t, Long)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(bool, bool, Bool)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(complex64, c10::complex<float>, ComplexFloat)
#undef AOTI_TORCH_SCALAR_TO_TENSOR_IMPL
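
// Illustrative usage from the caller's side (not part of this file): create a
// one-element int64 tensor from a C scalar, then release the handle when done.
//   AtenTensorHandle h = nullptr;
//   aoti_torch_scalar_to_tensor_int64(42, &h);
//   /* ... use h ... */
//   aoti_torch_delete_tensor_object(h);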

bool aoti_torch_grad_mode_is_enabled() {
  return c10::GradMode::is_enabled();
}

void aoti_torch_grad_mode_set_enabled(bool enabled) {
  return c10::GradMode::set_enabled(enabled);
}

AOTITorchError aoti_torch_delete_tensor_object(AtenTensorHandle tensor) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    delete t;
  });
}

AOTITorchError aoti_torch_get_data_ptr(
    AtenTensorHandle tensor,
    void** ret_data_ptr) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    if (t->is_mkldnn()) {
      *ret_data_ptr = data_ptr_from_mkldnn(t);
    } else {
      *ret_data_ptr = t->data_ptr();
    }
  });
}

AOTITorchError aoti_torch_get_storage_size(
    AtenTensorHandle tensor,
    int64_t* ret_size) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    *ret_size = t->storage().nbytes();
  });
}

AOTITorchError aoti_torch_get_dim(AtenTensorHandle tensor, int64_t* ret_dim) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    *ret_dim = t->dim();
  });
}

AOTITorchError aoti_torch_get_numel(
    AtenTensorHandle tensor,
    int64_t* ret_numel) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    *ret_numel = t->numel();
  });
}

AOTITorchError aoti_torch_get_storage_numel(
    AtenTensorHandle tensor,
    int64_t* ret_numel) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    TORCH_INTERNAL_ASSERT(t->has_storage());
    auto dtype_size = t->dtype().itemsize();
    size_t nbytes = t->storage().nbytes();
    TORCH_INTERNAL_ASSERT(nbytes % dtype_size == 0);
    auto numel = nbytes / dtype_size;
    *ret_numel = numel;
  });
}

AOTITorchError aoti_torch_get_sizes(
    AtenTensorHandle tensor,
    int64_t** ret_sizes) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
    *ret_sizes = const_cast<int64_t*>(t->sizes().data());
  });
}

AOTITorchError aoti_torch_get_size(
    AtenTensorHandle tensor,
    int64_t d,
    int64_t* ret_size) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    *ret_size = t->size(d);
  });
}

AOTITorchError aoti_torch_get_strides(
    AtenTensorHandle tensor,
    int64_t** ret_strides) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
    *ret_strides = const_cast<int64_t*>(t->strides().data());
  });
}

AOTITorchError aoti_torch_get_stride(
    AtenTensorHandle tensor,
    int64_t d,
    int64_t* ret_stride) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    *ret_stride = t->stride(d);
  });
}

AOTITorchError aoti_torch_get_dtype(
    AtenTensorHandle tensor,
    int32_t* ret_dtype) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    *ret_dtype = static_cast<int32_t>(t->scalar_type());
  });
}

AOTITorchError aoti_torch_get_device_type(
    AtenTensorHandle tensor,
    int32_t* ret_device_type) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    *ret_device_type = static_cast<int32_t>(t->device().type());
  });
}

AOTITorchError aoti_torch_get_device_index(
    AtenTensorHandle tensor,
    int32_t* ret_device_index) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    *ret_device_index = static_cast<int16_t>(t->device().index());
  });
}

AOTITorchError aoti_torch_get_storage_offset(
    AtenTensorHandle tensor,
    int64_t* ret_storage_offset) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    *ret_storage_offset = t->storage_offset();
  });
}

AOTITorchError aoti_torch__reinterpret_tensor(
    AtenTensorHandle self,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t offset_increment,
    AtenTensorHandle* ret_new_tensor) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    c10::IntArrayRef sizes(sizes_ptr, ndim);
    c10::IntArrayRef strides(strides_ptr, ndim);
    *ret_new_tensor = new_tensor_handle(torch::inductor::_reinterpret_tensor(
        *self_tensor, sizes, strides, offset_increment));
  });
}

// TODO: implement a more efficient version instead of calling into aten
AOTITorchError aoti_torch_empty_strided(
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    AtenTensorHandle* ret_new_tensor) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    c10::IntArrayRef sizes(sizes_ptr, ndim);
    c10::IntArrayRef strides(strides_ptr, ndim);
    if (c10::DeviceType(device_type) == c10::DeviceType::CPU) {
      *ret_new_tensor = new_tensor_handle(at::detail::empty_strided_cpu(
          sizes, strides, static_cast<c10::ScalarType>(dtype)));
    } else {
      c10::Device device = c10_device(device_type, device_index);
      c10::TensorOptions options = c10::TensorOptions().device(device).dtype(
          static_cast<c10::ScalarType>(dtype));
      *ret_new_tensor =
          new_tensor_handle(at::empty_strided(sizes, strides, options));
    }
  });
}

AOTITorchError aoti_torch_create_tensor_from_blob(
    void* data,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t storage_offset,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    AtenTensorHandle* ret_new_tensor) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    c10::IntArrayRef sizes(sizes_ptr, ndim);
    c10::IntArrayRef strides(strides_ptr, ndim);
    c10::Device device = c10_device(device_type, device_index);
    c10::TensorOptions options = c10::TensorOptions().device(device).dtype(
        static_cast<c10::ScalarType>(dtype));
    *ret_new_tensor = new_tensor_handle(
        // data == nullptr can happen for a 0-size tensor
        (data != nullptr) ? at::for_blob(data, sizes)
                                .strides(strides)
                                .storage_offset(storage_offset)
                                .options(options)
                                .make_tensor()
                          : at::empty_strided(sizes, strides, options));
  });
}

AOTITorchError aoti_torch_create_tensor_from_blob_v2(
    void* data,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t storage_offset,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    AtenTensorHandle* ret_new_tensor,
    int32_t layout,
    const uint8_t* opaque_metadata,
    int64_t opaque_metadata_size) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    if (layout == static_cast<int32_t>(at::kMkldnn)) {
      c10::IntArrayRef sizes(sizes_ptr, ndim);
      c10::IntArrayRef strides(strides_ptr, ndim);
      c10::Device device = c10_device(device_type, device_index);
      // Get an mkldnn tensor wrapped by a torch Tensor (OpaqueTensorImpl),
      // which is used by later mkldnn ops.
      *ret_new_tensor = new_tensor_handle(mkldnn_tensor_from_data_ptr(
          data,
          sizes,
          static_cast<c10::ScalarType>(dtype),
          device,
          opaque_metadata,
          opaque_metadata_size));
    } else {
      aoti_torch_create_tensor_from_blob(
          data,
          ndim,
          sizes_ptr,
          strides_ptr,
          storage_offset,
          dtype,
          device_type,
          device_index,
          ret_new_tensor);
    }
  });
}

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__embedding_bag(
    AtenTensorHandle weight,
    AtenTensorHandle indices,
    AtenTensorHandle offsets,
    int32_t scale_grad_by_freq,
    int32_t mode,
    int32_t sparse,
    AtenTensorHandle per_sample_weights, // optional argument
    int32_t include_last_offset,
    int32_t padding_idx,
    AtenTensorHandle* ret0, // returns new reference
    AtenTensorHandle* ret1, // returns new reference
    AtenTensorHandle* ret2, // returns new reference
    AtenTensorHandle* ret3 // returns new reference
) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto [r0, r1, r2, r3] = at::_embedding_bag(
        *tensor_handle_to_tensor_pointer(weight),
        *tensor_handle_to_tensor_pointer(indices),
        *tensor_handle_to_tensor_pointer(offsets),
        scale_grad_by_freq,
        mode,
        sparse,
        pointer_to_optional(
            tensor_handle_to_tensor_pointer(per_sample_weights)),
        include_last_offset,
        padding_idx);

    *ret0 = new_tensor_handle(std::move(r0));
    *ret1 = new_tensor_handle(std::move(r1));
    *ret2 = new_tensor_handle(std::move(r2));
    *ret3 = new_tensor_handle(std::move(r3));
  });
}

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__fft_c2c(
    AtenTensorHandle self,
    const int64_t* dim_ptr,
    int64_t dim_size,
    int64_t normalization,
    int32_t forward,
    AtenTensorHandle* ret // returns new reference
) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto dim = c10::IntArrayRef(dim_ptr, dim_size);
    *ret = new_tensor_handle(at::_fft_c2c(
        *tensor_handle_to_tensor_pointer(self), dim, normalization, forward));
  });
}

AOTITorchError aoti_torch__scaled_dot_product_flash_attention_v2(
    AtenTensorHandle query,
    AtenTensorHandle key,
    AtenTensorHandle value,
    double dropout_p,
    int is_causal,
    int return_debug_mask,
    double* scale, // optional argument
    AtenTensorHandle* ret0, // returns new reference
    AtenTensorHandle* ret1, // returns new reference
    AtenTensorHandle* ret2, // returns new reference
    AtenTensorHandle* ret3, // returns new reference
    int64_t* ret4,
    int64_t* ret5,
    AtenTensorHandle* ret6, // returns new reference
    AtenTensorHandle* ret7, // returns new reference
    AtenTensorHandle* ret8 // returns new reference
) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* query_tensor = tensor_handle_to_tensor_pointer(query);
    at::Tensor* key_tensor = tensor_handle_to_tensor_pointer(key);
    at::Tensor* value_tensor = tensor_handle_to_tensor_pointer(value);
    auto optional_scale = pointer_to_optional(scale);
    auto [r0, r1, r2, r3, r4, r5, r6, r7, r8] =
        at::_scaled_dot_product_flash_attention(
            *query_tensor,
            *key_tensor,
            *value_tensor,
            dropout_p,
            is_causal,
            return_debug_mask,
            optional_scale);

    *ret0 = new_tensor_handle(std::move(r0));
    *ret1 = new_tensor_handle(std::move(r1));
    // ret2 and ret3 may be null
    if (ret2) {
      *ret2 = new_tensor_handle(std::move(r2));
    }
    if (ret3) {
      *ret3 = new_tensor_handle(std::move(r3));
    }
    *ret4 = r4.expect_int();
    *ret5 = r5.expect_int();
    *ret6 = new_tensor_handle(std::move(r6));
    *ret7 = new_tensor_handle(std::move(r7));
    *ret8 = new_tensor_handle(std::move(r8));
  });
}

AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
    AtenTensorHandle query,
    AtenTensorHandle key,
    AtenTensorHandle value,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    double scale,
    AtenTensorHandle* ret0, // returns new reference
    AtenTensorHandle* ret1, // returns new reference
    AtenTensorHandle* ret2, // returns new reference
    AtenTensorHandle* ret3, // returns new reference
    int64_t* ret4,
    int64_t* ret5,
    AtenTensorHandle* ret6, // returns new reference
    AtenTensorHandle* ret7, // returns new reference
    AtenTensorHandle* ret8 // returns new reference
) {
  return aoti_torch__scaled_dot_product_flash_attention_v2(
      query,
      key,
      value,
      dropout_p,
      is_causal,
      return_debug_mask,
      &scale,
      ret0,
      ret1,
      ret2,
      ret3,
      ret4,
      ret5,
      ret6,
      ret7,
      ret8);
}

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch__scaled_dot_product_efficient_attention(
    AtenTensorHandle query,
    AtenTensorHandle key,
    AtenTensorHandle value,
    AtenTensorHandle attn_bias, // optional argument
    int compute_log_sumexp,
    double dropout_p,
    int is_causal,
    double* scale, // optional argument
    AtenTensorHandle* ret0, // returns new reference
    AtenTensorHandle* ret1, // returns new reference
    AtenTensorHandle* ret2, // returns new reference
    AtenTensorHandle* ret3 // returns new reference
) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* query_tensor = tensor_handle_to_tensor_pointer(query);
    at::Tensor* key_tensor = tensor_handle_to_tensor_pointer(key);
    at::Tensor* value_tensor = tensor_handle_to_tensor_pointer(value);
    auto optional_attn_bias =
        pointer_to_optional(tensor_handle_to_tensor_pointer(attn_bias));
    auto optional_scale = pointer_to_optional(scale);
    auto [r0, r1, r2, r3] = at::_scaled_dot_product_efficient_attention(
        *query_tensor,
        *key_tensor,
        *value_tensor,
        optional_attn_bias,
        compute_log_sumexp,
        dropout_p,
        is_causal,
        optional_scale);
    *ret0 = new_tensor_handle(std::move(r0));
    *ret1 = new_tensor_handle(std::move(r1));
    *ret2 = new_tensor_handle(std::move(r2));
    *ret3 = new_tensor_handle(std::move(r3));
  });
}

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_convolution(
    AtenTensorHandle input,
    AtenTensorHandle weight,
    AtenTensorHandle bias, // optional argument
    const int64_t* stride_ptr,
    int64_t stride_size,
    const int64_t* padding_ptr,
    int64_t padding_size,
    const int64_t* dilation_ptr,
    int64_t dilation_size,
    int transposed,
    const int64_t* output_padding_ptr,
    int64_t output_padding_size,
    int64_t groups,
    AtenTensorHandle* out // returns new reference
) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* input_tensor = tensor_handle_to_tensor_pointer(input);
    at::Tensor* weight_tensor = tensor_handle_to_tensor_pointer(weight);
    at::Tensor* bias_tensor = tensor_handle_to_tensor_pointer(bias);
    auto optional_bias = pointer_to_optional(bias_tensor);
    c10::IntArrayRef stride(stride_ptr, stride_size);
    c10::IntArrayRef padding(padding_ptr, padding_size);
    c10::IntArrayRef dilation(dilation_ptr, dilation_size);
    c10::IntArrayRef output_padding(output_padding_ptr, output_padding_size);

    *out = new_tensor_handle(at::convolution(
        *input_tensor,
        *weight_tensor,
        optional_bias,
        stride,
        padding,
        dilation,
        static_cast<bool>(transposed),
        output_padding,
        groups));
  });
}

AOTITorchError aoti_torch_new_uninitialized_tensor(AtenTensorHandle* ret) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* out_tensor = new at::Tensor();
    *ret = tensor_pointer_to_tensor_handle(out_tensor);
  });
}

AOTITorchError aoti_torch__scaled_mm(
    AtenTensorHandle self,
    AtenTensorHandle mat2,
    AtenTensorHandle bias,
    int32_t* out_dtype,
    AtenTensorHandle scale_a,
    AtenTensorHandle scale_b,
    AtenTensorHandle scale_result,
    int8_t use_fast_accum,
    AtenTensorHandle* ret0,
    AtenTensorHandle* ret1) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    at::Tensor* mat2_tensor = tensor_handle_to_tensor_pointer(mat2);
    at::Tensor* bias_tensor = tensor_handle_to_tensor_pointer(bias);
    at::Tensor* scale_a_tensor = tensor_handle_to_tensor_pointer(scale_a);
    at::Tensor* scale_b_tensor = tensor_handle_to_tensor_pointer(scale_b);
    at::Tensor* scale_result_tensor =
        tensor_handle_to_tensor_pointer(scale_result);
    auto r0 = at::_scaled_mm(
        *self_tensor,
        *mat2_tensor,
        *scale_a_tensor,
        *scale_b_tensor,
        pointer_to_optional(bias_tensor),
        pointer_to_optional(scale_result_tensor),
        pointer_to_optional<c10::ScalarType>(out_dtype),
        use_fast_accum);
    *ret0 = new_tensor_handle(std::move(r0));
  });
}

AOTITorchError aoti_torch__scaled_mm_v2(
    AtenTensorHandle self,
    AtenTensorHandle mat2,
    AtenTensorHandle scale_a,
    AtenTensorHandle scale_b,
    AtenTensorHandle bias,
    AtenTensorHandle scale_result,
    int32_t* out_dtype,
    int8_t use_fast_accum,
    AtenTensorHandle* ret0) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    at::Tensor* mat2_tensor = tensor_handle_to_tensor_pointer(mat2);
    at::Tensor* bias_tensor = tensor_handle_to_tensor_pointer(bias);
    at::Tensor* scale_a_tensor = tensor_handle_to_tensor_pointer(scale_a);
    at::Tensor* scale_b_tensor = tensor_handle_to_tensor_pointer(scale_b);
    at::Tensor* scale_result_tensor =
        tensor_handle_to_tensor_pointer(scale_result);
    auto r0 = at::_scaled_mm(
        *self_tensor,
        *mat2_tensor,
        *scale_a_tensor,
        *scale_b_tensor,
        pointer_to_optional(bias_tensor),
        pointer_to_optional(scale_result_tensor),
        pointer_to_optional<c10::ScalarType>(out_dtype),
        use_fast_accum);
    *ret0 = new_tensor_handle(std::move(r0));
  });
}

// TODO: implement a more efficient version instead of calling into aten
AOTITorchError aoti_torch_tensor_copy_(
    AtenTensorHandle src,
    AtenTensorHandle dst) {
  return aoti_torch_copy_(dst, src, /*non_blocking=*/0);
}

AOTITorchError aoti_torch_assign_tensors(
    AtenTensorHandle src,
    AtenTensorHandle dst) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* src_tensor = tensor_handle_to_tensor_pointer(src);
    at::Tensor* dst_tensor = tensor_handle_to_tensor_pointer(dst);
    *dst_tensor = *src_tensor;
  });
}

AOTITorchError aoti_torch_assign_tensors_out(
    AtenTensorHandle src,
    AtenTensorHandle* ret_dst) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* src_tensor_ptr = tensor_handle_to_tensor_pointer(src);
    at::Tensor dst_tensor = *src_tensor_ptr;
    *ret_dst = new_tensor_handle(std::move(dst_tensor));
  });
}

AOTITorchError aoti_torch_clone(AtenTensorHandle self, AtenTensorHandle* ret) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    *ret = new_tensor_handle(self_tensor->clone());
  });
}

// TODO: implement a more efficient version instead of calling into aten
AOTITorchError aoti_torch_addmm_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    AtenTensorHandle mat1,
    AtenTensorHandle mat2,
    float beta,
    float alpha) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* out_tensor = tensor_handle_to_tensor_pointer(out);
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    at::Tensor* mat1_tensor = tensor_handle_to_tensor_pointer(mat1);
    at::Tensor* mat2_tensor = tensor_handle_to_tensor_pointer(mat2);
    at::addmm_out(
        *out_tensor, *self_tensor, *mat1_tensor, *mat2_tensor, beta, alpha);
  });
}

// TODO: implement a more efficient version instead of calling into aten
AOTITorchError aoti_torch_bmm_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    AtenTensorHandle mat2) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* out_tensor = tensor_handle_to_tensor_pointer(out);
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    at::Tensor* mat2_tensor = tensor_handle_to_tensor_pointer(mat2);
    at::bmm_out(*out_tensor, *self_tensor, *mat2_tensor);
  });
}

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_copy_(
    AtenTensorHandle self,
    AtenTensorHandle src,
    int32_t non_blocking) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    tensor_handle_to_tensor_pointer(self)->copy_(
        *tensor_handle_to_tensor_pointer(src), non_blocking);
  });
}

// TODO: implement a more efficient version instead of calling into aten
AOTITorchError aoti_torch_mm_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    AtenTensorHandle mat2) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* out_tensor = tensor_handle_to_tensor_pointer(out);
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    at::Tensor* mat2_tensor = tensor_handle_to_tensor_pointer(mat2);
    at::mm_out(*out_tensor, *self_tensor, *mat2_tensor);
  });
}

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__mm_plus_mm_out(
    AtenTensorHandle out,
    AtenTensorHandle a,
    AtenTensorHandle b,
    AtenTensorHandle c,
    AtenTensorHandle d) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* out_tensor = tensor_handle_to_tensor_pointer(out);
    at::Tensor* a_tensor = tensor_handle_to_tensor_pointer(a);
    at::Tensor* b_tensor = tensor_handle_to_tensor_pointer(b);
    at::Tensor* c_tensor = tensor_handle_to_tensor_pointer(c);
    at::Tensor* d_tensor = tensor_handle_to_tensor_pointer(d);
    torch::inductor::_mm_plus_mm_out(
        *out_tensor, *a_tensor, *b_tensor, *c_tensor, *d_tensor);
  });
}

AOTITorchError aoti_torch_cpu_wrapped_fbgemm_pack_gemm_matrix_fp16(
    AtenTensorHandle weight,
    AtenTensorHandle* out) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* weight_tensor = tensor_handle_to_tensor_pointer(weight);

    *out = new_tensor_handle(at::fbgemm_pack_gemm_matrix_fp16(*weight_tensor));
  });
}

AOTITorchError aoti_torch_cpu__wrapped_linear_prepack(
    AtenTensorHandle weight,
    AtenTensorHandle weight_scale,
    AtenTensorHandle weight_zero_point,
    AtenTensorHandle bias,
    AtenTensorHandle* out) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* weight_tensor = tensor_handle_to_tensor_pointer(weight);
    at::Tensor* weight_scale_tensor =
        tensor_handle_to_tensor_pointer(weight_scale);
    at::Tensor* weight_zero_point_tensor =
        tensor_handle_to_tensor_pointer(weight_zero_point);
    at::Tensor* bias_tensor = tensor_handle_to_tensor_pointer(bias);

    *out = new_tensor_handle(at::_wrapped_linear_prepack(
        *weight_tensor,
        *weight_scale_tensor,
        *weight_zero_point_tensor,
        *bias_tensor));
  });
}

AOTITorchError aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight(
    AtenTensorHandle input,
    AtenTensorHandle weight,
    AtenTensorHandle bias,
    int64_t out_channel,
    AtenTensorHandle* out) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* input_tensor = tensor_handle_to_tensor_pointer(input);
    at::Tensor* weight_tensor = tensor_handle_to_tensor_pointer(weight);
    at::Tensor* bias_tensor = tensor_handle_to_tensor_pointer(bias);

    *out = new_tensor_handle(at::fbgemm_linear_fp16_weight_fp32_activation(
        *input_tensor, *weight_tensor, *bias_tensor));
  });
}

AOTITorchError aoti_torch_cpu__wrapped_quantized_linear_prepacked(
    AtenTensorHandle input,
    AtenTensorHandle input_scale,
    AtenTensorHandle input_zero_point,
    AtenTensorHandle weight,
    AtenTensorHandle out_scale,
    AtenTensorHandle out_zeropoint,
    int64_t out_channel,
    AtenTensorHandle* out) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* input_tensor = tensor_handle_to_tensor_pointer(input);
    at::Tensor* input_scale_tensor =
        tensor_handle_to_tensor_pointer(input_scale);
    at::Tensor* input_zero_point_tensor =
        tensor_handle_to_tensor_pointer(input_zero_point);
    at::Tensor* weight_tensor = tensor_handle_to_tensor_pointer(weight);
    at::Tensor* out_scale_tensor = tensor_handle_to_tensor_pointer(out_scale);
    at::Tensor* out_zeropoint_tensor =
        tensor_handle_to_tensor_pointer(out_zeropoint);
    *out = new_tensor_handle(at::_wrapped_quantized_linear_prepacked(
        *input_tensor,
        *input_scale_tensor,
        *input_zero_point_tensor,
        *weight_tensor,
        *out_scale_tensor,
        *out_zeropoint_tensor,
        out_channel));
  });
}

AOTITorchError aoti_torch_nonzero(
    AtenTensorHandle self,
    AtenTensorHandle* out) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    *out = new_tensor_handle(at::nonzero(*self_tensor));
  });
}

AOTITorchError aoti_torch_repeat_interleave_Tensor(
    AtenTensorHandle repeats,
    int64_t* output_size,
    AtenTensorHandle* out) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* repeats_tensor = tensor_handle_to_tensor_pointer(repeats);
    *out = new_tensor_handle(at::_ops::repeat_interleave_Tensor::call(
        *repeats_tensor, pointer_to_optional<c10::SymInt>(output_size)));
  });
}

// Function to check existence of inf and NaN
AOTITorchError aoti_torch_check_inf_and_nan(
    const char* tensor_name,
    AtenTensorHandle tensor) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* check_tensor = tensor_handle_to_tensor_pointer(tensor);

    assert_inf_and_nan(tensor_name, *check_tensor);
  });
}

AOTITorchError aoti_torch_scatter_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    int64_t dim,
    AtenTensorHandle index,
    AtenTensorHandle src) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* out_tensor = tensor_handle_to_tensor_pointer(out);
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    at::Tensor* index_tensor = tensor_handle_to_tensor_pointer(index);
    at::Tensor* src_tensor = tensor_handle_to_tensor_pointer(src);
    at::scatter_out(*out_tensor, *self_tensor, dim, *index_tensor, *src_tensor);
  });
}

AOTITorchError aoti_torch_scatter_reduce_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    int64_t dim,
    AtenTensorHandle index,
    AtenTensorHandle src,
    const char* reduce,
    int32_t include_self) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* out_tensor = tensor_handle_to_tensor_pointer(out);
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    at::Tensor* index_tensor = tensor_handle_to_tensor_pointer(index);
    at::Tensor* src_tensor = tensor_handle_to_tensor_pointer(src);
    at::scatter_reduce_out(
        *out_tensor,
        *self_tensor,
        dim,
        *index_tensor,
        *src_tensor,
        reduce,
        (bool)include_self);
  });
}

AOTITorchError aoti_torch_index_put_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    const AtenTensorHandle* indices,
    const uint32_t num_indices,
    const AtenTensorHandle values,
    bool accumulate) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    c10::List<std::optional<at::Tensor>> indices_;
    indices_.reserve(num_indices);
    for (size_t i = 0; i < num_indices; i++) {
      indices_.emplace_back(
          pointer_to_optional(tensor_handle_to_tensor_pointer(indices[i])));
    }
    at::Tensor* out_tensor = tensor_handle_to_tensor_pointer(out);
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    at::Tensor* values_tensor = tensor_handle_to_tensor_pointer(values);
    at::index_put_out(
        *out_tensor, *self_tensor, indices_, *values_tensor, accumulate);
  });
}

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_as_real(
    AtenTensorHandle self,
    AtenTensorHandle* ret // returns new reference
) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    *ret = new_tensor_handle(
        at::_ops::view_as_real::call(*tensor_handle_to_tensor_pointer(self)));
  });
}

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_dtype(
    AtenTensorHandle self,
    int32_t dtype,
    AtenTensorHandle* ret // returns new reference
) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    *ret = new_tensor_handle(at::_ops::view_dtype::call(
        *self_tensor, static_cast<c10::ScalarType>(dtype)));
  });
}

AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle(
    AtenTensorHandle self,
    const char* tensor_name,
    const char* launch_prefix,
    const char* kernel_name) {
  at::Tensor* t = tensor_handle_to_tensor_pointer(self);
#ifndef C10_MOBILE
  // Save the tensor to a tmp .pt file so it can be torch.load'ed later
  std::string cwd = get_current_path();
  std::string tmp_folder = cwd + "/tmp/aoti_torch/";
  if (!file_exists(tmp_folder)) {
    std::cout
        << "aoti_torch_save_tensor_handle: Path does not exist, creating it: "
        << tmp_folder << std::endl;

    if (!create_directories(tmp_folder)) {
      std::cout << "aoti_torch_save_tensor_handle: Error creating directory: "
                << tmp_folder << std::endl;
      return;
    }
  }
  std::string tensor_filepath_to_save = tmp_folder + launch_prefix + "_" +
      kernel_name + "_" + tensor_name + "_" + t->device().str() + ".pt";

  auto bytes = torch::jit::pickle_save(c10::IValue(*t));
  std::ofstream fout(tensor_filepath_to_save, std::ios::out | std::ios::binary);
  fout.write(bytes.data(), bytes.size());
  fout.close();

  std::cout << "aoti_torch_save_tensor_handle: Saved tensor to "
            << tensor_filepath_to_save << std::endl;
#endif // !defined(C10_MOBILE)
}

AOTI_TORCH_EXPORT void aoti_torch_print_tensor_handle(
    AtenTensorHandle self,
    const char* msg) {
  at::Tensor* t = tensor_handle_to_tensor_pointer(self);

  // Display message
  std::cout << "[";
  if (msg) {
    std::cout << " " << msg;
  }
  std::cout << " "
            << "]:" << std::endl;

  // Print exact tensor values for small size tensors
  const int64_t numel = t->numel();
  if (numel <= AOTI_TORCH_MAX_NUMEL_TO_PRINT) {
    std::cout << *t << "\n";
  }

  // Print summary stats of the tensor
  std::cout << "Number of elements: " << numel << std::endl;
  std::cout << "Dtype: " << t->dtype() << std::endl;
  if (numel > 0) {
    // torch/aten `mean()` function only supports float and complex dtypes
    // See:
    // https://github.com/pytorch/pytorch/blob/a0e062c6f1a03ec93e87413e42c4d0b336518131/aten/src/ATen/native/ReduceOps.cpp#L304-L309
    auto mean_value = [t](at::ScalarType dtype) {
      return t->to(dtype).mean().item();
    };
    bool is_complex_type =
        at::isComplexType(at::typeMetaToScalarType(t->dtype()));
    at::ScalarType float_dtype =
        is_complex_type ? at::kComplexFloat : at::kFloat;
    std::cout << "Mean value: " << mean_value(float_dtype) << std::endl;
    if (!is_complex_type) {
      // The "min_all_cuda" function is not implemented for the 'ComplexFloat'
      // type (and similarly for max), so skip printing min/max values for
      // complex-type tensors here. If a complex dtype is encountered (a rare
      // occasion), print the whole tensor instead.
      std::cout << "Min value: " << t->min().item<float>() << std::endl;
      std::cout << "Max value: " << t->max().item<float>() << std::endl;
    }
  }
  std::cout << "Device: " << t->device() << std::endl;
  std::cout << "Size: " << t->sizes() << std::endl;
  std::cout << "Stride: " << t->strides() << std::endl;
  std::cout << "Layout: " << t->layout() << std::endl;
  std::cout << "Is contiguous: " << t->is_contiguous() << std::endl;
  std::cout << "Requires grad: " << t->requires_grad() << std::endl;

  std::cout << std::endl;
}

// ProxyExecutor
AOTITorchError aoti_torch_proxy_executor_call_function(
    AOTIProxyExecutorHandle proxy_executor,
    int extern_node_index,
    int num_ints,
    int64_t* flatten_int_args,
    int num_tensors,
    AtenTensorHandle* flatten_tensor_args) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    ProxyExecutor* executor = reinterpret_cast<ProxyExecutor*>(proxy_executor);
    executor->call_function(
        extern_node_index,
        num_ints,
        flatten_int_args,
        num_tensors,
        flatten_tensor_args);
  });
}

void aoti_torch_check(
    bool cond,
    const char* func,
    const char* file,
    uint32_t line,
    const char* msg) {
  if (C10_UNLIKELY_OR_CONST(!cond)) {
    ::c10::detail::torchCheckFail(func, file, line, msg);
  }
}

AOTITorchError aoti_torch__alloc_from_pool(
    AtenTensorHandle self,
    int64_t offset_bytes,
    int32_t dtype,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    AtenTensorHandle* ret_new_tensor) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    c10::IntArrayRef sizes(sizes_ptr, ndim);
    c10::IntArrayRef strides(strides_ptr, ndim);
    *ret_new_tensor = new_tensor_handle(torch::inductor::_alloc_from_pool(
        *self_tensor,
        offset_bytes,
        static_cast<c10::ScalarType>(dtype),
        sizes,
        strides));
  });
}