#include <c10/core/DeviceType.h>
#include <c10/core/GradMode.h>
#include <c10/core/Layout.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/mkldnn_tensor.h>
#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h>
#include <torch/csrc/inductor/aoti_torch/proxy_executor.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <torch/csrc/inductor/inductor_ops.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <iostream>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else

#include <ATen/ops/_addmm_activation.h>
#include <ATen/ops/_embedding_bag.h>
#include <ATen/ops/_fft_c2c.h>
#include <ATen/ops/_scaled_dot_product_efficient_attention.h>
#include <ATen/ops/_scaled_dot_product_flash_attention.h>
#include <ATen/ops/_scaled_mm.h>
#include <ATen/ops/_wrapped_linear_prepack.h>
#include <ATen/ops/_wrapped_quantized_linear_prepacked.h>
#include <ATen/ops/addmm.h>
#include <ATen/ops/as_strided.h>
#include <ATen/ops/bmm.h>
#include <ATen/ops/convolution.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation.h>
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/index_put.h>
#include <ATen/ops/mm.h>
#include <ATen/ops/nonzero.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/scatter.h>
#include <ATen/ops/scatter_reduce.h>
#include <ATen/ops/view_as_real_ops.h>
#include <ATen/ops/view_ops.h>

#endif

#if __has_include("filesystem")
#include <filesystem>
namespace fs = std::filesystem;
#else
#include <experimental/filesystem>
namespace fs = std::experimental::filesystem;
#endif

#ifndef _WIN32
#include <limits.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#endif

// HACK for failed builds in ARVR, where it cannot find these symbols within
// std::experimental::filesystem
namespace {
std::string get_current_path() {
#if __has_include("filesystem") && !defined(__linux__)
  return fs::current_path().string();
#else
  char currentPath[PATH_MAX];
  if (getcwd(currentPath, sizeof(currentPath)) != nullptr) {
    return std::string(currentPath);
  } else {
    throw std::runtime_error("Failed to get current path");
  }
#endif
}

bool file_exists(std::string& path) {
#if __has_include("filesystem") && !defined(__linux__)
  return fs::exists(path);
#else
  struct stat rc;
  return lstat(path.c_str(), &rc) == 0;
#endif
}

bool create_directories(const std::string& path) {
#if __has_include("filesystem") && !defined(__linux__)
  return fs::create_directories(path);
#else
  if (mkdir(path.c_str(), 0777) == -1) {
    throw std::runtime_error("Failed to create directory");
  }
  return true;
#endif
}
} // namespace

using namespace torch::aot_inductor;

namespace {
static c10::Device c10_device(int32_t device_type, int32_t device_index) {
  if (device_type == aoti_torch_device_type_cpu()) {
    return c10::Device(static_cast<c10::DeviceType>(device_type));
  } else {
    return c10::Device(
        static_cast<c10::DeviceType>(device_type),
        static_cast<c10::DeviceIndex>(device_index));
  }
}
} // namespace

const int AOTI_TORCH_MAX_NUMEL_TO_PRINT = 64;

int32_t aoti_torch_device_type_cpu() {
  return (int32_t)c10::DeviceType::CPU;
}

int32_t aoti_torch_device_type_cuda() {
  return (int32_t)c10::DeviceType::CUDA;
}

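// AOTI_TORCH_DTYPE_IMPL expands to one aoti_torch_dtype_<dtype>() function per
// supported dtype, returning the matching c10::ScalarType value as an int32_t
// so generated AOTInductor code can refer to dtypes through the C ABI.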
#define AOTI_TORCH_DTYPE_IMPL(dtype, stype) \
  int32_t aoti_torch_dtype_##dtype() {      \
    return (int32_t)c10::ScalarType::stype; \
  }

AOTI_TORCH_DTYPE_IMPL(float8_e5m2, Float8_e5m2)
AOTI_TORCH_DTYPE_IMPL(float8_e4m3fn, Float8_e4m3fn)
AOTI_TORCH_DTYPE_IMPL(float8_e5m2fnuz, Float8_e5m2fnuz)
AOTI_TORCH_DTYPE_IMPL(float8_e4m3fnuz, Float8_e4m3fnuz)
AOTI_TORCH_DTYPE_IMPL(bfloat16, BFloat16)
AOTI_TORCH_DTYPE_IMPL(float16, Half)
AOTI_TORCH_DTYPE_IMPL(float32, Float)
AOTI_TORCH_DTYPE_IMPL(float64, Double)
AOTI_TORCH_DTYPE_IMPL(uint8, Byte)
AOTI_TORCH_DTYPE_IMPL(uint16, UInt16)
AOTI_TORCH_DTYPE_IMPL(uint32, UInt32)
AOTI_TORCH_DTYPE_IMPL(uint64, UInt64)
AOTI_TORCH_DTYPE_IMPL(int8, Char)
AOTI_TORCH_DTYPE_IMPL(int16, Short)
AOTI_TORCH_DTYPE_IMPL(int32, Int)
AOTI_TORCH_DTYPE_IMPL(int64, Long)
AOTI_TORCH_DTYPE_IMPL(bool, Bool)
AOTI_TORCH_DTYPE_IMPL(complex32, ComplexHalf)
AOTI_TORCH_DTYPE_IMPL(complex64, ComplexFloat)
AOTI_TORCH_DTYPE_IMPL(complex128, ComplexDouble)
#undef AOTI_TORCH_DTYPE_IMPL

int32_t aoti_torch_layout_strided() {
  return (int32_t)at::kStrided;
}

int32_t aoti_torch_layout__mkldnn() {
  return (int32_t)at::kMkldnn;
}

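// AOTI_TORCH_ITEM_IMPL expands to aoti_torch_item_<dtype>(), which reads the
// single element of the given tensor into *ret_value and converts any thrown
// exception into an AOTITorchError code.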
#define AOTI_TORCH_ITEM_IMPL(dtype, ctype)                     \
  AOTITorchError aoti_torch_item_##dtype(                      \
      AtenTensorHandle tensor, ctype* ret_value) {             \
    AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({               \
      at::Tensor* t = tensor_handle_to_tensor_pointer(tensor); \
      *ret_value = t->item().to<ctype>();                      \
    });                                                        \
  }

AOTI_TORCH_ITEM_IMPL(float16, c10::Half)
AOTI_TORCH_ITEM_IMPL(float32, float)
AOTI_TORCH_ITEM_IMPL(float64, double)
AOTI_TORCH_ITEM_IMPL(uint8, uint8_t)
AOTI_TORCH_ITEM_IMPL(uint16, uint16_t)
AOTI_TORCH_ITEM_IMPL(uint32, uint32_t)
AOTI_TORCH_ITEM_IMPL(uint64, uint64_t)
AOTI_TORCH_ITEM_IMPL(int8, int8_t)
AOTI_TORCH_ITEM_IMPL(int16, int16_t)
AOTI_TORCH_ITEM_IMPL(int32, int32_t)
AOTI_TORCH_ITEM_IMPL(int64, int64_t)
AOTI_TORCH_ITEM_IMPL(bool, bool)
AOTI_TORCH_ITEM_IMPL(bfloat16, c10::BFloat16)
AOTI_TORCH_ITEM_IMPL(complex64, c10::complex<float>)
#undef AOTI_TORCH_ITEM_IMPL

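// AOTI_TORCH_SCALAR_TO_TENSOR_IMPL expands to
// aoti_torch_scalar_to_tensor_<dtype>(), which wraps a C scalar into a new
// 0-dim tensor of the given ScalarType and returns an owning handle through
// *ret_new_tensor.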
#define AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(dtype, ctype, ttype)                  \
  AOTITorchError aoti_torch_scalar_to_tensor_##dtype(                          \
      ctype value, AtenTensorHandle* ret_new_tensor) {                         \
    AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({                               \
      *ret_new_tensor =                                                        \
          new_tensor_handle(at::scalar_tensor(value, c10::ScalarType::ttype)); \
    });                                                                        \
  }

AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(float32, float, Float)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(float64, double, Double)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(uint8, uint8_t, Byte)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(uint16, uint16_t, UInt16)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(uint32, uint32_t, UInt32)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(uint64, uint64_t, UInt64)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(int8, int8_t, Char)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(int16, int16_t, Short)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(int32, int32_t, Int)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(int64, int64_t, Long)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(bool, bool, Bool)
AOTI_TORCH_SCALAR_TO_TENSOR_IMPL(complex64, c10::complex<float>, ComplexFloat)
#undef AOTI_TORCH_SCALAR_TO_TENSOR_IMPL

bool aoti_torch_grad_mode_is_enabled() {
  return c10::GradMode::is_enabled();
}

void aoti_torch_grad_mode_set_enabled(bool enabled) {
  return c10::GradMode::set_enabled(enabled);
}

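// Deletes the at::Tensor object behind an owning handle, dropping its
// reference to the underlying storage.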
AOTITorchError aoti_torch_delete_tensor_object(AtenTensorHandle tensor) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    delete t;
  });
}

AOTITorchError aoti_torch_get_data_ptr(
    AtenTensorHandle tensor,
    void** ret_data_ptr) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    if (t->is_mkldnn()) {
      *ret_data_ptr = data_ptr_from_mkldnn(t);
    } else {
      *ret_data_ptr = t->data_ptr();
    }
  });
}

AOTITorchError aoti_torch_get_storage_size(
    AtenTensorHandle tensor,
    int64_t* ret_size) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    *ret_size = t->storage().nbytes();
  });
}

AOTITorchError aoti_torch_get_dim(AtenTensorHandle tensor, int64_t* ret_dim) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    *ret_dim = t->dim();
  });
}

AOTITorchError aoti_torch_get_numel(
    AtenTensorHandle tensor,
    int64_t* ret_numel) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    *ret_numel = t->numel();
  });
}

AOTITorchError aoti_torch_get_storage_numel(
    AtenTensorHandle tensor,
    int64_t* ret_numel) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    TORCH_INTERNAL_ASSERT(t->has_storage());
    auto dtype_size = t->dtype().itemsize();
    size_t nbytes = t->storage().nbytes();
    TORCH_INTERNAL_ASSERT(nbytes % dtype_size == 0);
    auto numel = nbytes / dtype_size;
    *ret_numel = numel;
  });
}

AOTITorchError aoti_torch_get_sizes(
    AtenTensorHandle tensor,
    int64_t** ret_sizes) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
    *ret_sizes = const_cast<int64_t*>(t->sizes().data());
  });
}

AOTITorchError aoti_torch_get_size(
    AtenTensorHandle tensor,
    int64_t d,
    int64_t* ret_size) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    *ret_size = t->size(d);
  });
}

AOTITorchError aoti_torch_get_strides(
    AtenTensorHandle tensor,
    int64_t** ret_strides) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
    *ret_strides = const_cast<int64_t*>(t->strides().data());
  });
}

AOTITorchError aoti_torch_get_stride(
    AtenTensorHandle tensor,
    int64_t d,
    int64_t* ret_stride) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    *ret_stride = t->stride(d);
  });
}

AOTITorchError aoti_torch_get_dtype(
    AtenTensorHandle tensor,
    int32_t* ret_dtype) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    *ret_dtype = static_cast<int32_t>(t->scalar_type());
  });
}

AOTITorchError aoti_torch_get_device_type(
    AtenTensorHandle tensor,
    int32_t* ret_device_type) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    *ret_device_type = static_cast<int32_t>(t->device().type());
  });
}

AOTITorchError aoti_torch_get_device_index(
    AtenTensorHandle tensor,
    int32_t* ret_device_index) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    *ret_device_index = static_cast<int16_t>(t->device().index());
  });
}

AOTITorchError aoti_torch_get_storage_offset(
    AtenTensorHandle tensor,
    int64_t* ret_storage_offset) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    *ret_storage_offset = t->storage_offset();
  });
}

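// aoti_torch__reinterpret_tensor creates a view over the input tensor's
// existing storage with new sizes, strides, and an additional storage offset;
// no data is copied.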
AOTITorchError aoti_torch__reinterpret_tensor(
    AtenTensorHandle self,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t offset_increment,
    AtenTensorHandle* ret_new_tensor) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    c10::IntArrayRef sizes(sizes_ptr, ndim);
    c10::IntArrayRef strides(strides_ptr, ndim);
    *ret_new_tensor = new_tensor_handle(torch::inductor::_reinterpret_tensor(
        *self_tensor, sizes, strides, offset_increment));
  });
}

// TODO: implement a more efficient version instead of calling into aten
AOTITorchError aoti_torch_empty_strided(
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    AtenTensorHandle* ret_new_tensor) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    c10::IntArrayRef sizes(sizes_ptr, ndim);
    c10::IntArrayRef strides(strides_ptr, ndim);
    if (c10::DeviceType(device_type) == c10::DeviceType::CPU) {
      *ret_new_tensor = new_tensor_handle(at::detail::empty_strided_cpu(
          sizes, strides, static_cast<c10::ScalarType>(dtype)));
    } else {
      c10::Device device = c10_device(device_type, device_index);
      c10::TensorOptions options = c10::TensorOptions().device(device).dtype(
          static_cast<c10::ScalarType>(dtype));
      *ret_new_tensor =
          new_tensor_handle(at::empty_strided(sizes, strides, options));
    }
  });
}

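// aoti_torch_create_tensor_from_blob wraps caller-owned memory in an
// at::Tensor without copying or taking ownership of the data; when data is
// null (a 0-size tensor) it falls back to allocating an empty strided tensor.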
AOTITorchError aoti_torch_create_tensor_from_blob(
    void* data,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t storage_offset,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    AtenTensorHandle* ret_new_tensor) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    c10::IntArrayRef sizes(sizes_ptr, ndim);
    c10::IntArrayRef strides(strides_ptr, ndim);
    c10::Device device = c10_device(device_type, device_index);
    c10::TensorOptions options = c10::TensorOptions().device(device).dtype(
        static_cast<c10::ScalarType>(dtype));
    *ret_new_tensor = new_tensor_handle(
        // data == nullptr can happen for a 0-size tensor
        (data != nullptr) ? at::for_blob(data, sizes)
                                .strides(strides)
                                .storage_offset(storage_offset)
                                .options(options)
                                .make_tensor()
                          : at::empty_strided(sizes, strides, options));
  });
}

AOTITorchError aoti_torch_create_tensor_from_blob_v2(
    void* data,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t storage_offset,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    AtenTensorHandle* ret_new_tensor,
    int32_t layout,
    const uint8_t* opaque_metadata,
    int64_t opaque_metadata_size) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    if (layout == static_cast<int32_t>(at::kMkldnn)) {
      c10::IntArrayRef sizes(sizes_ptr, ndim);
      c10::IntArrayRef strides(strides_ptr, ndim);
      c10::Device device = c10_device(device_type, device_index);
      // Get an mkldnn tensor wrapped by a torch Tensor (OpaqueTensorImpl),
      // which is used by later mkldnn ops.
      *ret_new_tensor = new_tensor_handle(mkldnn_tensor_from_data_ptr(
          data,
          sizes,
          static_cast<c10::ScalarType>(dtype),
          device,
          opaque_metadata,
          opaque_metadata_size));
    } else {
      aoti_torch_create_tensor_from_blob(
          data,
          ndim,
          sizes_ptr,
          strides_ptr,
          storage_offset,
          dtype,
          device_type,
          device_index,
          ret_new_tensor);
    }
  });
}

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__embedding_bag(
    AtenTensorHandle weight,
    AtenTensorHandle indices,
    AtenTensorHandle offsets,
    int32_t scale_grad_by_freq,
    int32_t mode,
    int32_t sparse,
    AtenTensorHandle per_sample_weights, // optional argument
    int32_t include_last_offset,
    int32_t padding_idx,
    AtenTensorHandle* ret0, // returns new reference
    AtenTensorHandle* ret1, // returns new reference
    AtenTensorHandle* ret2, // returns new reference
    AtenTensorHandle* ret3 // returns new reference
) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto [r0, r1, r2, r3] = at::_embedding_bag(
        *tensor_handle_to_tensor_pointer(weight),
        *tensor_handle_to_tensor_pointer(indices),
        *tensor_handle_to_tensor_pointer(offsets),
        scale_grad_by_freq,
        mode,
        sparse,
        pointer_to_optional(
            tensor_handle_to_tensor_pointer(per_sample_weights)),
        include_last_offset,
        padding_idx);

    *ret0 = new_tensor_handle(std::move(r0));
    *ret1 = new_tensor_handle(std::move(r1));
    *ret2 = new_tensor_handle(std::move(r2));
    *ret3 = new_tensor_handle(std::move(r3));
  });
}

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__fft_c2c(
    AtenTensorHandle self,
    const int64_t* dim_ptr,
    int64_t dim_size,
    int64_t normalization,
    int32_t forward,
    AtenTensorHandle* ret // returns new reference
) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto dim = c10::IntArrayRef(dim_ptr, dim_size);
    *ret = new_tensor_handle(at::_fft_c2c(
        *tensor_handle_to_tensor_pointer(self), dim, normalization, forward));
  });
}

AOTITorchError aoti_torch__scaled_dot_product_flash_attention_v2(
    AtenTensorHandle query,
    AtenTensorHandle key,
    AtenTensorHandle value,
    double dropout_p,
    int is_causal,
    int return_debug_mask,
    double* scale, // optional argument
    AtenTensorHandle* ret0, // returns new reference
    AtenTensorHandle* ret1, // returns new reference
    AtenTensorHandle* ret2, // returns new reference
    AtenTensorHandle* ret3, // returns new reference
    int64_t* ret4,
    int64_t* ret5,
    AtenTensorHandle* ret6, // returns new reference
    AtenTensorHandle* ret7, // returns new reference
    AtenTensorHandle* ret8 // returns new reference
) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* query_tensor = tensor_handle_to_tensor_pointer(query);
    at::Tensor* key_tensor = tensor_handle_to_tensor_pointer(key);
    at::Tensor* value_tensor = tensor_handle_to_tensor_pointer(value);
    auto optional_scale = pointer_to_optional(scale);
    auto [r0, r1, r2, r3, r4, r5, r6, r7, r8] =
        at::_scaled_dot_product_flash_attention(
            *query_tensor,
            *key_tensor,
            *value_tensor,
            dropout_p,
            is_causal,
            return_debug_mask,
            optional_scale);

    *ret0 = new_tensor_handle(std::move(r0));
    *ret1 = new_tensor_handle(std::move(r1));
    // ret2 and ret3 may be null
    if (ret2) {
      *ret2 = new_tensor_handle(std::move(r2));
    }
    if (ret3) {
      *ret3 = new_tensor_handle(std::move(r3));
    }
    *ret4 = r4.expect_int();
    *ret5 = r5.expect_int();
    *ret6 = new_tensor_handle(std::move(r6));
    *ret7 = new_tensor_handle(std::move(r7));
    *ret8 = new_tensor_handle(std::move(r8));
  });
}

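// The v1 entry point keeps the original ABI, in which scale is passed by
// value; it forwards to the v2 version with the address of that value so the
// optional scale is always provided.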
AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
    AtenTensorHandle query,
    AtenTensorHandle key,
    AtenTensorHandle value,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    double scale,
    AtenTensorHandle* ret0, // returns new reference
    AtenTensorHandle* ret1, // returns new reference
    AtenTensorHandle* ret2, // returns new reference
    AtenTensorHandle* ret3, // returns new reference
    int64_t* ret4,
    int64_t* ret5,
    AtenTensorHandle* ret6, // returns new reference
    AtenTensorHandle* ret7, // returns new reference
    AtenTensorHandle* ret8 // returns new reference
) {
  return aoti_torch__scaled_dot_product_flash_attention_v2(
      query,
      key,
      value,
      dropout_p,
      is_causal,
      return_debug_mask,
      &scale,
      ret0,
      ret1,
      ret2,
      ret3,
      ret4,
      ret5,
      ret6,
      ret7,
      ret8);
}

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch__scaled_dot_product_efficient_attention(
    AtenTensorHandle query,
    AtenTensorHandle key,
    AtenTensorHandle value,
    AtenTensorHandle attn_bias, // optional argument
    int compute_log_sumexp,
    double dropout_p,
    int is_causal,
    double* scale, // optional argument
    AtenTensorHandle* ret0, // returns new reference
    AtenTensorHandle* ret1, // returns new reference
    AtenTensorHandle* ret2, // returns new reference
    AtenTensorHandle* ret3 // returns new reference
) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* query_tensor = tensor_handle_to_tensor_pointer(query);
    at::Tensor* key_tensor = tensor_handle_to_tensor_pointer(key);
    at::Tensor* value_tensor = tensor_handle_to_tensor_pointer(value);
    auto optional_attn_bias =
        pointer_to_optional(tensor_handle_to_tensor_pointer(attn_bias));
    auto optional_scale = pointer_to_optional(scale);
    auto [r0, r1, r2, r3] = at::_scaled_dot_product_efficient_attention(
        *query_tensor,
        *key_tensor,
        *value_tensor,
        optional_attn_bias,
        compute_log_sumexp,
        dropout_p,
        is_causal,
        optional_scale);
    *ret0 = new_tensor_handle(std::move(r0));
    *ret1 = new_tensor_handle(std::move(r1));
    *ret2 = new_tensor_handle(std::move(r2));
    *ret3 = new_tensor_handle(std::move(r3));
  });
}

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_convolution(
    AtenTensorHandle input,
    AtenTensorHandle weight,
    AtenTensorHandle bias, // optional argument
    const int64_t* stride_ptr,
    int64_t stride_size,
    const int64_t* padding_ptr,
    int64_t padding_size,
    const int64_t* dilation_ptr,
    int64_t dilation_size,
    int transposed,
    const int64_t* output_padding_ptr,
    int64_t output_padding_size,
    int64_t groups,
    AtenTensorHandle* out // returns new reference
) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* input_tensor = tensor_handle_to_tensor_pointer(input);
    at::Tensor* weight_tensor = tensor_handle_to_tensor_pointer(weight);
    at::Tensor* bias_tensor = tensor_handle_to_tensor_pointer(bias);
    auto optional_bias = pointer_to_optional(bias_tensor);
    c10::IntArrayRef stride(stride_ptr, stride_size);
    c10::IntArrayRef padding(padding_ptr, padding_size);
    c10::IntArrayRef dilation(dilation_ptr, dilation_size);
    c10::IntArrayRef output_padding(output_padding_ptr, output_padding_size);

    *out = new_tensor_handle(at::convolution(
        *input_tensor,
        *weight_tensor,
        optional_bias,
        stride,
        padding,
        dilation,
        static_cast<bool>(transposed),
        output_padding,
        groups));
  });
}

AOTITorchError aoti_torch_new_uninitialized_tensor(AtenTensorHandle* ret) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* out_tensor = new at::Tensor();
    *ret = tensor_pointer_to_tensor_handle(out_tensor);
  });
}

AOTITorchError aoti_torch__scaled_mm(
    AtenTensorHandle self,
    AtenTensorHandle mat2,
    AtenTensorHandle bias,
    int32_t* out_dtype,
    AtenTensorHandle scale_a,
    AtenTensorHandle scale_b,
    AtenTensorHandle scale_result,
    int8_t use_fast_accum,
    AtenTensorHandle* ret0,
    AtenTensorHandle* ret1) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    at::Tensor* mat2_tensor = tensor_handle_to_tensor_pointer(mat2);
    at::Tensor* bias_tensor = tensor_handle_to_tensor_pointer(bias);
    at::Tensor* scale_a_tensor = tensor_handle_to_tensor_pointer(scale_a);
    at::Tensor* scale_b_tensor = tensor_handle_to_tensor_pointer(scale_b);
    at::Tensor* scale_result_tensor =
        tensor_handle_to_tensor_pointer(scale_result);
    auto r0 = at::_scaled_mm(
        *self_tensor,
        *mat2_tensor,
        *scale_a_tensor,
        *scale_b_tensor,
        pointer_to_optional(bias_tensor),
        pointer_to_optional(scale_result_tensor),
        pointer_to_optional<c10::ScalarType>(out_dtype),
        use_fast_accum);
    *ret0 = new_tensor_handle(std::move(r0));
  });
}

AOTITorchError aoti_torch__scaled_mm_v2(
    AtenTensorHandle self,
    AtenTensorHandle mat2,
    AtenTensorHandle scale_a,
    AtenTensorHandle scale_b,
    AtenTensorHandle bias,
    AtenTensorHandle scale_result,
    int32_t* out_dtype,
    int8_t use_fast_accum,
    AtenTensorHandle* ret0) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    at::Tensor* mat2_tensor = tensor_handle_to_tensor_pointer(mat2);
    at::Tensor* bias_tensor = tensor_handle_to_tensor_pointer(bias);
    at::Tensor* scale_a_tensor = tensor_handle_to_tensor_pointer(scale_a);
    at::Tensor* scale_b_tensor = tensor_handle_to_tensor_pointer(scale_b);
    at::Tensor* scale_result_tensor =
        tensor_handle_to_tensor_pointer(scale_result);
    auto r0 = at::_scaled_mm(
        *self_tensor,
        *mat2_tensor,
        *scale_a_tensor,
        *scale_b_tensor,
        pointer_to_optional(bias_tensor),
        pointer_to_optional(scale_result_tensor),
        pointer_to_optional<c10::ScalarType>(out_dtype),
        use_fast_accum);
    *ret0 = new_tensor_handle(std::move(r0));
  });
}

// TODO: implement a more efficient version instead of calling into aten
AOTITorchError aoti_torch_tensor_copy_(
    AtenTensorHandle src,
    AtenTensorHandle dst) {
  return aoti_torch_copy_(dst, src, /*non_blocking=*/0);
}

AOTITorchError aoti_torch_assign_tensors(
    AtenTensorHandle src,
    AtenTensorHandle dst) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* src_tensor = tensor_handle_to_tensor_pointer(src);
    at::Tensor* dst_tensor = tensor_handle_to_tensor_pointer(dst);
    *dst_tensor = *src_tensor;
  });
}

AOTITorchError aoti_torch_assign_tensors_out(
    AtenTensorHandle src,
    AtenTensorHandle* ret_dst) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* src_tensor_ptr = tensor_handle_to_tensor_pointer(src);
    at::Tensor dst_tensor = *src_tensor_ptr;
    *ret_dst = new_tensor_handle(std::move(dst_tensor));
  });
}

AOTITorchError aoti_torch_clone(AtenTensorHandle self, AtenTensorHandle* ret) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    *ret = new_tensor_handle(self_tensor->clone());
  });
}

// TODO: implement a more efficient version instead of calling into aten
AOTITorchError aoti_torch_addmm_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    AtenTensorHandle mat1,
    AtenTensorHandle mat2,
    float beta,
    float alpha) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* out_tensor = tensor_handle_to_tensor_pointer(out);
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    at::Tensor* mat1_tensor = tensor_handle_to_tensor_pointer(mat1);
    at::Tensor* mat2_tensor = tensor_handle_to_tensor_pointer(mat2);
    at::addmm_out(
        *out_tensor, *self_tensor, *mat1_tensor, *mat2_tensor, beta, alpha);
  });
}

// TODO: implement a more efficient version instead of calling into aten
AOTITorchError aoti_torch_bmm_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    AtenTensorHandle mat2) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* out_tensor = tensor_handle_to_tensor_pointer(out);
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    at::Tensor* mat2_tensor = tensor_handle_to_tensor_pointer(mat2);
    at::bmm_out(*out_tensor, *self_tensor, *mat2_tensor);
  });
}

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_copy_(
    AtenTensorHandle self,
    AtenTensorHandle src,
    int32_t non_blocking) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    tensor_handle_to_tensor_pointer(self)->copy_(
        *tensor_handle_to_tensor_pointer(src), non_blocking);
  });
}

// TODO: implement a more efficient version instead of calling into aten
AOTITorchError aoti_torch_mm_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    AtenTensorHandle mat2) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* out_tensor = tensor_handle_to_tensor_pointer(out);
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    at::Tensor* mat2_tensor = tensor_handle_to_tensor_pointer(mat2);
    at::mm_out(*out_tensor, *self_tensor, *mat2_tensor);
  });
}

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__mm_plus_mm_out(
    AtenTensorHandle out,
    AtenTensorHandle a,
    AtenTensorHandle b,
    AtenTensorHandle c,
    AtenTensorHandle d) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* out_tensor = tensor_handle_to_tensor_pointer(out);
    at::Tensor* a_tensor = tensor_handle_to_tensor_pointer(a);
    at::Tensor* b_tensor = tensor_handle_to_tensor_pointer(b);
    at::Tensor* c_tensor = tensor_handle_to_tensor_pointer(c);
    at::Tensor* d_tensor = tensor_handle_to_tensor_pointer(d);
    torch::inductor::_mm_plus_mm_out(
        *out_tensor, *a_tensor, *b_tensor, *c_tensor, *d_tensor);
  });
}

AOTITorchError aoti_torch_cpu_wrapped_fbgemm_pack_gemm_matrix_fp16(
    AtenTensorHandle weight,
    AtenTensorHandle* out) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* weight_tensor = tensor_handle_to_tensor_pointer(weight);

    *out = new_tensor_handle(at::fbgemm_pack_gemm_matrix_fp16(*weight_tensor));
  });
}

AOTITorchError aoti_torch_cpu__wrapped_linear_prepack(
    AtenTensorHandle weight,
    AtenTensorHandle weight_scale,
    AtenTensorHandle weight_zero_point,
    AtenTensorHandle bias,
    AtenTensorHandle* out) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* weight_tensor = tensor_handle_to_tensor_pointer(weight);
    at::Tensor* weight_scale_tensor =
        tensor_handle_to_tensor_pointer(weight_scale);
    at::Tensor* weight_zero_point_tensor =
        tensor_handle_to_tensor_pointer(weight_zero_point);
    at::Tensor* bias_tensor = tensor_handle_to_tensor_pointer(bias);

    *out = new_tensor_handle(at::_wrapped_linear_prepack(
        *weight_tensor,
        *weight_scale_tensor,
        *weight_zero_point_tensor,
        *bias_tensor));
  });
}

AOTITorchError aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight(
    AtenTensorHandle input,
    AtenTensorHandle weight,
    AtenTensorHandle bias,
    int64_t out_channel,
    AtenTensorHandle* out) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* input_tensor = tensor_handle_to_tensor_pointer(input);
    at::Tensor* weight_tensor = tensor_handle_to_tensor_pointer(weight);
    at::Tensor* bias_tensor = tensor_handle_to_tensor_pointer(bias);

    *out = new_tensor_handle(at::fbgemm_linear_fp16_weight_fp32_activation(
        *input_tensor, *weight_tensor, *bias_tensor));
  });
}

AOTITorchError aoti_torch_cpu__wrapped_quantized_linear_prepacked(
    AtenTensorHandle input,
    AtenTensorHandle input_scale,
    AtenTensorHandle input_zero_point,
    AtenTensorHandle weight,
    AtenTensorHandle out_scale,
    AtenTensorHandle out_zeropoint,
    int64_t out_channel,
    AtenTensorHandle* out) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* input_tensor = tensor_handle_to_tensor_pointer(input);
    at::Tensor* input_scale_tensor =
        tensor_handle_to_tensor_pointer(input_scale);
    at::Tensor* input_zero_point_tensor =
        tensor_handle_to_tensor_pointer(input_zero_point);
    at::Tensor* weight_tensor = tensor_handle_to_tensor_pointer(weight);
    at::Tensor* out_scale_tensor = tensor_handle_to_tensor_pointer(out_scale);
    at::Tensor* out_zeropoint_tensor =
        tensor_handle_to_tensor_pointer(out_zeropoint);
    *out = new_tensor_handle(at::_wrapped_quantized_linear_prepacked(
        *input_tensor,
        *input_scale_tensor,
        *input_zero_point_tensor,
        *weight_tensor,
        *out_scale_tensor,
        *out_zeropoint_tensor,
        out_channel));
  });
}

AOTITorchError aoti_torch_nonzero(
    AtenTensorHandle self,
    AtenTensorHandle* out) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    *out = new_tensor_handle(at::nonzero(*self_tensor));
  });
}

AOTITorchError aoti_torch_repeat_interleave_Tensor(
    AtenTensorHandle repeats,
    int64_t* output_size,
    AtenTensorHandle* out) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* repeats_tensor = tensor_handle_to_tensor_pointer(repeats);
    *out = new_tensor_handle(at::_ops::repeat_interleave_Tensor::call(
        *repeats_tensor, pointer_to_optional<c10::SymInt>(output_size)));
  });
}

// Function to check existence of inf and NaN
AOTITorchError aoti_torch_check_inf_and_nan(
    const char* tensor_name,
    AtenTensorHandle tensor) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* check_tensor = tensor_handle_to_tensor_pointer(tensor);

    assert_inf_and_nan(tensor_name, *check_tensor);
  });
}

AOTITorchError aoti_torch_scatter_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    int64_t dim,
    AtenTensorHandle index,
    AtenTensorHandle src) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* out_tensor = tensor_handle_to_tensor_pointer(out);
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    at::Tensor* index_tensor = tensor_handle_to_tensor_pointer(index);
    at::Tensor* src_tensor = tensor_handle_to_tensor_pointer(src);
    at::scatter_out(*out_tensor, *self_tensor, dim, *index_tensor, *src_tensor);
  });
}

AOTITorchError aoti_torch_scatter_reduce_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    int64_t dim,
    AtenTensorHandle index,
    AtenTensorHandle src,
    const char* reduce,
    int32_t include_self) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* out_tensor = tensor_handle_to_tensor_pointer(out);
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    at::Tensor* index_tensor = tensor_handle_to_tensor_pointer(index);
    at::Tensor* src_tensor = tensor_handle_to_tensor_pointer(src);
    at::scatter_reduce_out(
        *out_tensor,
        *self_tensor,
        dim,
        *index_tensor,
        *src_tensor,
        reduce,
        (bool)include_self);
  });
}

AOTITorchError aoti_torch_index_put_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    const AtenTensorHandle* indices,
    const uint32_t num_indices,
    const AtenTensorHandle values,
    bool accumulate) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    c10::List<std::optional<at::Tensor>> indices_;
    indices_.reserve(num_indices);
    for (size_t i = 0; i < num_indices; i++) {
      indices_.emplace_back(
          pointer_to_optional(tensor_handle_to_tensor_pointer(indices[i])));
    }
    at::Tensor* out_tensor = tensor_handle_to_tensor_pointer(out);
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    at::Tensor* values_tensor = tensor_handle_to_tensor_pointer(values);
    at::index_put_out(
        *out_tensor, *self_tensor, indices_, *values_tensor, accumulate);
  });
}

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_as_real(
    AtenTensorHandle self,
    AtenTensorHandle* ret // returns new reference
) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    *ret = new_tensor_handle(
        at::_ops::view_as_real::call(*tensor_handle_to_tensor_pointer(self)));
  });
}

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_dtype(
    AtenTensorHandle self,
    int32_t dtype,
    AtenTensorHandle* ret // returns new reference
) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    *ret = new_tensor_handle(at::_ops::view_dtype::call(
        *self_tensor, static_cast<c10::ScalarType>(dtype)));
  });
}

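// Debugging helper: pickles the tensor behind `self` and writes it to
// ./tmp/aoti_torch/<launch_prefix>_<kernel_name>_<tensor_name>_<device>.pt,
// creating the directory if needed. Compiled out on C10_MOBILE builds.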
AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle(
    AtenTensorHandle self,
    const char* tensor_name,
    const char* launch_prefix,
    const char* kernel_name) {
  at::Tensor* t = tensor_handle_to_tensor_pointer(self);
#ifndef C10_MOBILE
  // Save the tensor to a temporary .pt file that can later be loaded with
  // torch.load
  std::string cwd = get_current_path();
  std::string tmp_folder = cwd + "/tmp/aoti_torch/";
  if (!file_exists(tmp_folder)) {
    std::cout
        << "aoti_torch_save_tensor_handle: Path does not exist, creating it..."
        << tmp_folder << std::endl;

    if (!create_directories(tmp_folder)) {
      std::cout << "aoti_torch_save_tensor_handle: Error creating directory: "
                << tmp_folder << std::endl;
      return;
    }
  }
  std::string tensor_filepath_to_save = tmp_folder + launch_prefix + "_" +
      kernel_name + "_" + tensor_name + "_" + t->device().str() + ".pt";

  auto bytes = torch::jit::pickle_save(c10::IValue(*t));
  std::ofstream fout(tensor_filepath_to_save, std::ios::out | std::ios::binary);
  fout.write(bytes.data(), bytes.size());
  fout.close();

  std::cout << "aoti_torch_save_tensor_handle: Saved tensor to "
            << tensor_filepath_to_save << std::endl;
#endif // !defined(C10_MOBILE)
}

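// Debugging helper: prints the full tensor when it has at most
// AOTI_TORCH_MAX_NUMEL_TO_PRINT elements, followed by summary statistics
// (numel, dtype, mean, min/max for non-complex types, device, sizes, strides,
// layout, contiguity, requires_grad).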
AOTI_TORCH_EXPORT void aoti_torch_print_tensor_handle(
    AtenTensorHandle self,
    const char* msg) {
  at::Tensor* t = tensor_handle_to_tensor_pointer(self);

  // Display message
  std::cout << "[";
  if (msg) {
    std::cout << "  " << msg;
  }
  std::cout << "  "
            << "]:" << std::endl;

  // Print exact tensor values for small size tensors
  const int64_t numel = t->numel();
  if (numel <= AOTI_TORCH_MAX_NUMEL_TO_PRINT) {
    std::cout << *t << "\n";
  }

  // Print summary stats of the tensor
  std::cout << "Number of elements: " << numel << std::endl;
  std::cout << "Dtype: " << t->dtype() << std::endl;
  if (numel > 0) {
    // torch/aten `mean()` function only supports float and complex dtypes
    // See:
    // https://github.com/pytorch/pytorch/blob/a0e062c6f1a03ec93e87413e42c4d0b336518131/aten/src/ATen/native/ReduceOps.cpp#L304-L309
    auto mean_value = [t](at::ScalarType dtype) {
      return t->to(dtype).mean().item();
    };
    bool is_complex_type =
        at::isComplexType(at::typeMetaToScalarType(t->dtype()));
    at::ScalarType float_dtype =
        is_complex_type ? at::kComplexFloat : at::kFloat;
    std::cout << "Mean value: " << mean_value(float_dtype) << std::endl;
    if (!is_complex_type) {
      // "min_all_cuda" (and its max counterpart) is not implemented for the
      // 'ComplexFloat' type, so min/max values are only printed for
      // non-complex tensors. For the rare complex-dtype case, print the whole
      // tensor instead.
      std::cout << "Min value: " << t->min().item<float>() << std::endl;
      std::cout << "Max value: " << t->max().item<float>() << std::endl;
    }
  }
  std::cout << "Device: " << t->device() << std::endl;
  std::cout << "Size: " << t->sizes() << std::endl;
  std::cout << "Stride: " << t->strides() << std::endl;
  std::cout << "Layout: " << t->layout() << std::endl;
  std::cout << "Is contiguous: " << t->is_contiguous() << std::endl;
  std::cout << "Requires grad: " << t->requires_grad() << std::endl;

  std::cout << std::endl;
}

// ProxyExecutor
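// aoti_torch_proxy_executor_call_function dispatches the call for one extern
// kernel node to the ProxyExecutor, passing the flattened int and tensor
// arguments produced by the generated wrapper code.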
AOTITorchError aoti_torch_proxy_executor_call_function(
    AOTIProxyExecutorHandle proxy_executor,
    int extern_node_index,
    int num_ints,
    int64_t* flatten_int_args,
    int num_tensors,
    AtenTensorHandle* flatten_tensor_args) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    ProxyExecutor* executor = reinterpret_cast<ProxyExecutor*>(proxy_executor);
    executor->call_function(
        extern_node_index,
        num_ints,
        flatten_int_args,
        num_tensors,
        flatten_tensor_args);
  });
}

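// aoti_torch_check is the C-ABI counterpart of TORCH_CHECK: when cond is
// false it raises a c10::Error carrying the provided function name, file,
// line, and message.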
void aoti_torch_check(
    bool cond,
    const char* func,
    const char* file,
    uint32_t line,
    const char* msg) {
  if (C10_UNLIKELY_OR_CONST(!cond)) {
    ::c10::detail::torchCheckFail(func, file, line, msg);
  }
}

AOTITorchError aoti_torch__alloc_from_pool(
    AtenTensorHandle self,
    int64_t offset_bytes,
    int32_t dtype,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    AtenTensorHandle* ret_new_tensor) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
    c10::IntArrayRef sizes(sizes_ptr, ndim);
    c10::IntArrayRef strides(strides_ptr, ndim);
    *ret_new_tensor = new_tensor_handle(torch::inductor::_alloc_from_pool(
        *self_tensor,
        offset_bytes,
        static_cast<c10::ScalarType>(dtype),
        sizes,
        strides));
  });
}