#ifndef AOTI_TORCH_SHIM
#define AOTI_TORCH_SHIM

#include <stddef.h>
#include <stdint.h>

// This header defines a stable C API for certain ATen functionality in
// libtorch. The AOTInductor compiled model.so will only refer to this header
// instead of other headers from aten/c10, which means it will NOT be able to
// directly use any data structures or call functions from libtorch.
//
// What problems are we trying to solve here?  Direct use of aten/c10 APIs
// means use of C++ APIs on a library that doesn't have any ABI compatibility
// guarantees.  However, we want model.so to remain usable across updates
// to the PyTorch C++ libraries, which requires a stable ABI.  By introducing
// a C shim layer, we can minimize the surface that will cause breakage. The
// corresponding software stack can be illustrated as follows:
//
// |--------------------------------|
// |     inference service code     |
// |--------------------------------|
// |           model.so             |
// |--------------|-----------------|
// |           <c shim>             |
// |          libtorch.so           |
// |--------------------------------|
//
// The general guidelines for the C API:
//
//  - No exceptions; return an explicit error code to be checked at the call
//    site
//  - Only pointers (AtenTensorHandle counts), integers, and floats in headers
//
// If you want to make changes to this header, you MUST MAINTAIN ABI
// compatibility.  Typically, this means you will have to add a _v2 version
// of any function whose signature you want to change (e.g., to add a new
// parameter), and maintain both the old and new versions of the API until
// all old model.so binaries go out of use.

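// A minimal sketch of the error-code convention (names are illustrative):
//
//   int32_t dtype = 0;
//   AOTITorchError err = aoti_torch_get_dtype(tensor_handle, &dtype);
//   if (err != AOTI_TORCH_SUCCESS) {
//     // handle or propagate the failure; the shim never throws
//   }
//
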
#ifdef __GNUC__
#define AOTI_TORCH_EXPORT __attribute__((__visibility__("default")))
#else // !__GNUC__
#ifdef _WIN32
// PyTorch2 doesn't currently work on Windows. Exporting these APIs can lead
// to symbol clashes at link time if libtorch is included in both a DLL and a
// binary that depends on that DLL. As a short-term fix, we don't export the
// symbols. In the long term, this will need to be addressed when Windows is
// supported.
// #define AOTI_TORCH_EXPORT __declspec(dllexport)
#define AOTI_TORCH_EXPORT
#else // !_WIN32
#define AOTI_TORCH_EXPORT
#endif // _WIN32
#endif // __GNUC__

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/util/complex.h>

#ifdef __cplusplus
extern "C" {
#endif

// AtenTensorHandle represents an abstract notion of Tensor that can be passed
// between model.so and libtorch.so.  The contents of the structure itself
// are private; model.so is not allowed to access any fields directly, it must
// go through functions defined in this ABI.  Under the hood, this is
// represented as at::Tensor*, but we reserve the right to change this (and in
// fact, we probably should change it to at::TensorImpl* at least).
//
// An AtenTensorHandle can be owning (please check the API reference for exact
// ownership/borrow semantics).  If you have an owning AtenTensorHandle
// in model.so, you are obligated to call aoti_torch_delete_tensor_object when
// you are done.  You can use the helper C++ class RAIIAtenTensorHandle
// (see aot_runtime/model.h) to ensure the deallocator is called in RAII style
// (note that RAIIAtenTensorHandle is private to model.so, and never crosses
// the ABI boundary).
struct AtenTensorOpaque;
using AtenTensorHandle = AtenTensorOpaque*;

struct AtenGeneratorOpaque;
using AtenGeneratorHandle = AtenGeneratorOpaque*;

struct AOTIProxyExecutorOpaque;
using AOTIProxyExecutorHandle = AOTIProxyExecutorOpaque*;

using AOTITorchError = int32_t;
#define AOTI_TORCH_SUCCESS 0
#define AOTI_TORCH_FAILURE 1

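// A minimal sketch of owning-handle lifetime management from the model.so
// side (RAIIAtenTensorHandle is the preferred wrapper; this shows the raw
// calls, with illustrative names):
//
//   AtenTensorHandle t = nullptr;
//   if (aoti_torch_new_uninitialized_tensor(&t) != AOTI_TORCH_SUCCESS) {
//     return AOTI_TORCH_FAILURE;
//   }
//   // ... t is owning: model.so must eventually release it ...
//   aoti_torch_delete_tensor_object(t);
//
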
// Getter functions for retrieving various constants from the runtime that
// can subsequently be passed to other aoti_* functions.  By hiding these
// behind functions, the precise value of device/dtype is NOT part of the
// ABI contract.  (In practice, aten/c10 is pretty good about not renumbering
// these, so we probably could later switch to having these in the ABI, if
// desired for perf reasons.)
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cpu();
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cuda();

AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e4m3fn();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2fnuz();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e4m3fnuz();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_bfloat16();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float16();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float32();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float64();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_uint8();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_uint16();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_uint32();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_uint64();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int8();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int16();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int32();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int64();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_bool();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex32();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex64();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex128();

AOTI_TORCH_EXPORT int32_t aoti_torch_layout_strided();
AOTI_TORCH_EXPORT int32_t aoti_torch_layout__mkldnn();

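// A minimal sketch of how the getters are meant to be used (names are
// illustrative): query the runtime for dtype/device codes instead of
// hard-coding numeric values into model.so:
//
//   AtenTensorHandle t = nullptr;
//   int64_t sizes[] = {2, 3};
//   int64_t strides[] = {3, 1};
//   aoti_torch_empty_strided(
//       2, sizes, strides,
//       aoti_torch_dtype_float32(),
//       aoti_torch_device_type_cpu(),
//       /*device_index=*/0,
//       &t);
//
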
// Functions for converting a single-element tensor to a scalar value
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_float16(AtenTensorHandle tensor, c10::Half* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_float32(AtenTensorHandle tensor, float* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_float64(AtenTensorHandle tensor, double* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_uint8(AtenTensorHandle tensor, uint8_t* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_uint16(AtenTensorHandle tensor, uint16_t* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_uint32(AtenTensorHandle tensor, uint32_t* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_uint64(AtenTensorHandle tensor, uint64_t* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_int8(AtenTensorHandle tensor, int8_t* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_int16(AtenTensorHandle tensor, int16_t* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_int32(AtenTensorHandle tensor, int32_t* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_int64(AtenTensorHandle tensor, int64_t* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_bool(AtenTensorHandle tensor, bool* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_bfloat16(AtenTensorHandle tensor, c10::BFloat16* ret_value);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_complex64(
    AtenTensorHandle tensor,
    c10::complex<float>* ret_value);

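// A minimal sketch of reading a scalar out of a single-element tensor
// (illustrative; assumes scalar_tensor holds one float32 element):
//
//   float v = 0.0f;
//   if (aoti_torch_item_float32(scalar_tensor, &v) != AOTI_TORCH_SUCCESS) {
//     // handle the failure
//   }
//
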
// Functions for wrapping a scalar value into a single-element tensor
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_float32(
    float value,
    AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_float64(
    double value,
    AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_uint8(
    uint8_t value,
    AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_uint16(
    uint16_t value,
    AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_uint32(
    uint32_t value,
    AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_uint64(
    uint64_t value,
    AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_int8(
    int8_t value,
    AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_int16(
    int16_t value,
    AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_int32(
    int32_t value,
    AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_int64(
    int64_t value,
    AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_scalar_to_tensor_bool(bool value, AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_complex64(
    c10::complex<float> value,
    AtenTensorHandle* ret_new_tensor);

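// The inverse direction, wrapping a scalar into a new (owning) single-element
// tensor (a minimal sketch; names are illustrative):
//
//   AtenTensorHandle one = nullptr;
//   aoti_torch_scalar_to_tensor_float32(1.0f, &one);
//   // ... use one ..., then release it:
//   aoti_torch_delete_tensor_object(one);
//
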
AOTI_TORCH_EXPORT bool aoti_torch_grad_mode_is_enabled();
AOTI_TORCH_EXPORT void aoti_torch_grad_mode_set_enabled(bool enabled);

// Free the tensor object
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_delete_tensor_object(AtenTensorHandle tensor);

// Get a pointer to the underlying storage data
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_data_ptr(
    AtenTensorHandle tensor,
    void** ret_data_ptr // returns borrowed reference
);

// Get the size in bytes (nbytes) of the underlying storage
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_storage_size(AtenTensorHandle tensor, int64_t* ret_size);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_dim(AtenTensorHandle tensor, int64_t* ret_dim);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_numel(AtenTensorHandle tensor, int64_t* ret_numel);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_storage_numel(AtenTensorHandle tensor, int64_t* ret_numel);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_sizes(
    AtenTensorHandle tensor,
    int64_t** ret_sizes // returns borrowed reference
);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_size(AtenTensorHandle tensor, int64_t d, int64_t* ret_size);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_strides(
    AtenTensorHandle tensor,
    int64_t** ret_strides // returns borrowed reference
);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_stride(AtenTensorHandle tensor, int64_t d, int64_t* ret_stride);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_dtype(AtenTensorHandle tensor, int32_t* ret_dtype);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_device_type(AtenTensorHandle tensor, int32_t* ret_device_type);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_device_index(AtenTensorHandle tensor, int32_t* ret_device_index);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_storage_offset(
    AtenTensorHandle tensor,
    int64_t* ret_storage_offset);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__alloc_from_pool(
    AtenTensorHandle self,
    int64_t offset_bytes,
    int32_t dtype,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    AtenTensorHandle* ret_new_tensor);

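// A minimal sketch of the metadata queries (illustrative; note that the
// sizes/strides arrays are borrowed and must not be freed):
//
//   int64_t ndim = 0, numel = 0;
//   aoti_torch_get_dim(t, &ndim);
//   aoti_torch_get_numel(t, &numel);
//   int64_t* sizes = nullptr; // borrowed reference
//   aoti_torch_get_sizes(t, &sizes);
//
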
// This function will create a new tensor object and its pointer is returned
// through *ret_new_tensor. The caller is responsible for wrapping the tensor
// pointer with RAIIAtenTensorHandle, which will call
// aoti_torch_delete_tensor_object when going out of scope.
AOTI_TORCH_EXPORT AOTITorchError aoti_torch__reinterpret_tensor(
    AtenTensorHandle self,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t storage_offset,
    AtenTensorHandle* ret_new_tensor // returns new reference
);

// This function will create a new tensor object and its pointer is returned
// through *ret_new_tensor. The caller is responsible for wrapping the tensor
// pointer with RAIIAtenTensorHandle, which will call
// aoti_torch_delete_tensor_object when going out of scope.
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_empty_strided(
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    AtenTensorHandle* ret_new_tensor // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
    void* data,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t storage_offset,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    AtenTensorHandle* ret // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_v2(
    void* data,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t storage_offset,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    AtenTensorHandle* ret, // returns new reference
    int32_t layout,
    const uint8_t* opaque_metadata,
    int64_t opaque_metadata_size);

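// A minimal sketch of wrapping an externally managed buffer (illustrative;
// as with at::from_blob, the returned tensor does not take ownership of the
// allocation, so `data` must outlive it):
//
//   float data[6] = {0};
//   int64_t sizes[] = {2, 3};
//   int64_t strides[] = {3, 1};
//   AtenTensorHandle view = nullptr;
//   aoti_torch_create_tensor_from_blob(
//       data, 2, sizes, strides, /*storage_offset=*/0,
//       aoti_torch_dtype_float32(), aoti_torch_device_type_cpu(),
//       /*device_index=*/0, &view);
//
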
AOTI_TORCH_EXPORT AOTITorchError aoti_torch__embedding_bag(
    AtenTensorHandle weight,
    AtenTensorHandle indices,
    AtenTensorHandle offsets,
    int32_t scale_grad_by_freq,
    int32_t mode,
    int32_t sparse,
    AtenTensorHandle per_sample_weights, // optional argument
    int32_t include_last_offset,
    int32_t padding_idx,
    AtenTensorHandle* ret0, // returns new reference
    AtenTensorHandle* ret1, // returns new reference
    AtenTensorHandle* ret2, // returns new reference
    AtenTensorHandle* ret3 // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__fft_c2c(
    AtenTensorHandle self,
    const int64_t* dim_ptr,
    int64_t dim_size,
    int64_t normalization,
    int32_t forward,
    AtenTensorHandle* ret // returns new reference
);

// This version is deprecated. We will remove it later.
AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
    AtenTensorHandle query,
    AtenTensorHandle key,
    AtenTensorHandle value,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    double scale,
    AtenTensorHandle* ret0, // returns new reference
    AtenTensorHandle* ret1, // returns new reference
    AtenTensorHandle* ret2, // returns new reference
    AtenTensorHandle* ret3, // returns new reference
    int64_t* ret4,
    int64_t* ret5,
    AtenTensorHandle* ret6, // returns new reference
    AtenTensorHandle* ret7, // returns new reference
    AtenTensorHandle* ret8 // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch__scaled_dot_product_flash_attention_v2(
    AtenTensorHandle query,
    AtenTensorHandle key,
    AtenTensorHandle value,
    double dropout_p,
    int is_causal,
    int return_debug_mask,
    double* scale, // optional argument
    AtenTensorHandle* ret0, // returns new reference
    AtenTensorHandle* ret1, // returns new reference
    AtenTensorHandle* ret2, // returns new reference
    AtenTensorHandle* ret3, // returns new reference
    int64_t* ret4,
    int64_t* ret5,
    AtenTensorHandle* ret6, // returns new reference
    AtenTensorHandle* ret7, // returns new reference
    AtenTensorHandle* ret8 // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch__scaled_dot_product_efficient_attention(
    AtenTensorHandle query,
    AtenTensorHandle key,
    AtenTensorHandle value,
    AtenTensorHandle attn_bias, // optional argument
    int compute_log_sumexp,
    double dropout_p,
    int is_causal,
    double* scale, // optional argument
    AtenTensorHandle* ret0, // returns new reference
    AtenTensorHandle* ret1, // returns new reference
    AtenTensorHandle* ret2, // returns new reference
    AtenTensorHandle* ret3 // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm(
    AtenTensorHandle self,
    AtenTensorHandle mat2,
    AtenTensorHandle bias,
    int32_t* out_dtype,
    AtenTensorHandle scale_a,
    AtenTensorHandle scale_b,
    AtenTensorHandle scale_result,
    int8_t use_fast_accum,
    AtenTensorHandle* ret0,
    AtenTensorHandle* ret1);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm_v2(
    AtenTensorHandle self,
    AtenTensorHandle mat2,
    AtenTensorHandle scale_a,
    AtenTensorHandle scale_b,
    AtenTensorHandle bias,
    AtenTensorHandle scale_result,
    int32_t* out_dtype,
    int8_t use_fast_accum,
    AtenTensorHandle* ret0);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_convolution(
    AtenTensorHandle input,
    AtenTensorHandle weight,
    AtenTensorHandle bias, // optional argument
    const int64_t* stride_ptr,
    int64_t stride_size,
    const int64_t* padding_ptr,
    int64_t padding_size,
    const int64_t* dilation_ptr,
    int64_t dilation_size,
    int transposed,
    const int64_t* output_padding_ptr,
    int64_t output_padding_size,
    int64_t groups,
    AtenTensorHandle* ret // returns new reference
);

// This function will create a new uninitialized tensor object
// and its pointer is returned through *ret.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_new_uninitialized_tensor(AtenTensorHandle* ret);

// WARNING: This will be deprecated. Use aoti_torch_copy_ instead.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_tensor_copy_(AtenTensorHandle src, AtenTensorHandle dst);

// Make the tensor referred to by dst an alias for the tensor referred
// to by src. The two tensors must still be deleted with
// aoti_torch_delete_tensor_object separately (or not) as before the call.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_assign_tensors(AtenTensorHandle src, AtenTensorHandle dst);

// Make a shallow copy of the tensor referred to by src and assign it to the
// handle returned through ret_dst. This is similar to the above
// aoti_torch_assign_tensors function, but creates and sets ret_dst
// from within.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_assign_tensors_out(AtenTensorHandle src, AtenTensorHandle* ret_dst);

// This function will create a new tensor object and its pointer is returned
// through *ret. The caller is responsible for wrapping the tensor pointer
// with RAIIAtenTensorHandle which will call aoti_torch_delete_tensor_object
// when going out of scope.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_clone(AtenTensorHandle self, AtenTensorHandle* ret);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_addmm_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    AtenTensorHandle mat1,
    AtenTensorHandle mat2,
    float beta,
    float alpha);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_bmm_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    AtenTensorHandle mat2);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_copy_(
    AtenTensorHandle self,
    AtenTensorHandle src,
    int32_t non_blocking);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mm_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    AtenTensorHandle mat2);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__mm_plus_mm_out(
    AtenTensorHandle out,
    AtenTensorHandle a,
    AtenTensorHandle b,
    AtenTensorHandle c,
    AtenTensorHandle d);

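// The out-variants above follow the usual ATen convention: the caller
// supplies an already-allocated `out` tensor of compatible shape/dtype and
// the kernel writes into it. A minimal sketch (illustrative):
//
//   // out = beta * self + alpha * (mat1 @ mat2)
//   aoti_torch_addmm_out(out, self, mat1, mat2, /*beta=*/1.0f, /*alpha=*/1.0f);
//
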
// This will soon be deprecated after ao_quantization is complete.
// Please refrain from using this or adding new call sites.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_cpu_wrapped_fbgemm_pack_gemm_matrix_fp16(
    AtenTensorHandle weight,
    AtenTensorHandle* out);

// This will soon be deprecated after ao_quantization is complete.
// Please refrain from using this or adding new call sites.
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__wrapped_linear_prepack(
    AtenTensorHandle weight,
    AtenTensorHandle weight_scale,
    AtenTensorHandle weight_zero_point,
    AtenTensorHandle bias,
    AtenTensorHandle* out);

// This will soon be deprecated after ao_quantization is complete.
// Please refrain from using this or adding new call sites.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight(
    AtenTensorHandle input,
    AtenTensorHandle weight,
    AtenTensorHandle bias,
    int64_t out_channel,
    AtenTensorHandle* out);

// This will soon be deprecated after ao_quantization is complete.
// Please refrain from using this or adding new call sites.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_cpu__wrapped_quantized_linear_prepacked(
    AtenTensorHandle input,
    AtenTensorHandle input_scale,
    AtenTensorHandle input_zero_point,
    AtenTensorHandle weight,
    AtenTensorHandle out_scale,
    AtenTensorHandle out_zeropoint,
    int64_t out_channel,
    AtenTensorHandle* out);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_nonzero(AtenTensorHandle self, AtenTensorHandle* out);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor(
    AtenTensorHandle repeats,
    int64_t* output_size,
    AtenTensorHandle* out);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_check_inf_and_nan(const char* tensor_name, AtenTensorHandle tensor);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scatter_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    int64_t dim,
    AtenTensorHandle index,
    AtenTensorHandle src);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scatter_reduce_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    int64_t dim,
    AtenTensorHandle index,
    AtenTensorHandle src,
    const char* reduce,
    int32_t include_self);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_index_put_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    const AtenTensorHandle* indices,
    const uint32_t num_indices,
    const AtenTensorHandle values,
    bool accumulate);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_as_real(
    AtenTensorHandle self,
    AtenTensorHandle* ret // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_dtype(
    AtenTensorHandle self,
    int32_t dtype,
    AtenTensorHandle* ret // returns new reference
);

AOTI_TORCH_EXPORT void aoti_torch_print_tensor_handle(
    AtenTensorHandle self,
    const char* msg);

// When the AOTI debug printer option is enabled, this function will be
// invoked to save the intermediate tensor (via torch pickle) for debugging
// purposes.
AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle(
    AtenTensorHandle self,
    const char* tensor_name,
    const char* launch_prefix,
    const char* kernel_name);

#ifdef USE_CUDA

struct CUDAGuardOpaque;
using CUDAGuardHandle = CUDAGuardOpaque*;

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_cuda_guard(
    int32_t device_index,
    CUDAGuardHandle* ret_guard // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_delete_cuda_guard(CUDAGuardHandle guard);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_cuda_guard_set_index(CUDAGuardHandle guard, int32_t device_index);

struct CUDAStreamGuardOpaque;
using CUDAStreamGuardHandle = CUDAStreamGuardOpaque*;

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_cuda_stream_guard(
    void* stream,
    int32_t device_index,
    CUDAStreamGuardHandle* ret_guard // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream);

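// A minimal sketch of guard usage (illustrative; create/delete must be
// paired, with the guard deleted to restore the previous device):
//
//   CUDAGuardHandle guard = nullptr;
//   aoti_torch_create_cuda_guard(/*device_index=*/0, &guard);
//   // ... launch work on device 0 ...
//   aoti_torch_delete_cuda_guard(guard);
//
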
#endif // USE_CUDA

// See `ProxyExecutor Design Note` in ir.py for more details
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_proxy_executor_call_function(
    AOTIProxyExecutorHandle proxy_executor,
    int extern_node_index,
    int num_ints,
    int64_t* flatten_int_args,
    int num_tensors,
    AtenTensorHandle* flatten_tensor_args);

AOTI_TORCH_EXPORT void aoti_torch_check(
    bool cond,
    const char* func,
    const char* file,
    uint32_t line,
    const char* msg);

#ifdef STRIP_ERROR_MESSAGES
#define AOTI_TORCH_CHECK(cond, ...)              \
  if (!(cond)) {                                 \
    aoti_torch_check(                            \
        false,                                   \
        __func__,                                \
        __FILE__,                                \
        static_cast<uint32_t>(__LINE__),         \
        TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); \
  }
#else
#define AOTI_TORCH_CHECK(cond, ...)                \
  if (!(cond)) {                                   \
    aoti_torch_check(                              \
        false,                                     \
        __func__,                                  \
        __FILE__,                                  \
        static_cast<uint32_t>(__LINE__),           \
        TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__)); \
  }
#endif

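// A minimal usage sketch (illustrative; the message arguments are forwarded
// to TORCH_CHECK_MSG and concatenated into the error string):
//
//   AOTI_TORCH_CHECK(ndim >= 0, "ndim must be non-negative, got ", ndim);
//
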
#ifdef __cplusplus
} // extern "C"

template <typename T>
int32_t aoti_torch_dtype() = delete;

#define DEFINE_DTYPE_SPECIALIZATION(ctype, typename) \
  template <>                                        \
  inline int32_t aoti_torch_dtype<ctype>() {         \
    return aoti_torch_dtype_##typename();            \
  }

namespace c10 {
struct BFloat16;
struct Half;
} // namespace c10

DEFINE_DTYPE_SPECIALIZATION(c10::BFloat16, bfloat16)
DEFINE_DTYPE_SPECIALIZATION(c10::Half, float16)
DEFINE_DTYPE_SPECIALIZATION(c10::complex<float>, complex64)
DEFINE_DTYPE_SPECIALIZATION(float, float32)
DEFINE_DTYPE_SPECIALIZATION(double, float64)
DEFINE_DTYPE_SPECIALIZATION(uint8_t, uint8)
DEFINE_DTYPE_SPECIALIZATION(int8_t, int8)
DEFINE_DTYPE_SPECIALIZATION(int16_t, int16)
DEFINE_DTYPE_SPECIALIZATION(int32_t, int32)
DEFINE_DTYPE_SPECIALIZATION(int64_t, int64)
DEFINE_DTYPE_SPECIALIZATION(bool, bool)

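// A minimal C++-side usage sketch (illustrative): map a C++ type to the
// runtime dtype code without hard-coding the numeric value:
//
//   int32_t dt = aoti_torch_dtype<float>(); // same as aoti_torch_dtype_float32()
//
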
#endif // __cplusplus

#endif // AOTI_TORCH_SHIM