#ifndef AOTI_TORCH_SHIM
#define AOTI_TORCH_SHIM

#include <stddef.h>
#include <stdint.h>

// This header defines a stable C API for certain ATen functionality in
// libtorch. The AOTInductor compiled model.so will only refer to this header
// instead of other headers from aten/c10, which means it will NOT be able to
// directly use any data structures or call functions from libtorch.
//
// What problems are we trying to solve here? Direct use of aten/c10 APIs
// means use of C++ APIs on a library that doesn't have any ABI compatibility
// guarantees. However, we want model.so to remain usable across updates
// to the PyTorch C++ libraries, which requires a stable ABI. By introducing
// a C shim layer, we can minimize the surface that will cause breakage. The
// corresponding software stack can be illustrated as follows:
//
// |--------------------------------|
// |     inference service code     |
// |--------------------------------|
// |            model.so            |
// |--------------|-----------------|
// |            <c shim>            |
// |           libtorch.so          |
// |--------------------------------|
//
// The general guidelines for the C API:
//
// - No exceptions; return an explicit error code to be checked at the call
//   site
// - Only pointers (AtenTensorHandle counts), integers, and floats in headers
//
// If you want to make changes to this header, you MUST MAINTAIN ABI
// compatibility. Typically, this means you will have to add a _v2 version
// of a function that you, e.g., want to add a new function parameter to, and
// maintain the old and new versions of the APIs until all old model.so
// files go out of use.

#ifdef __GNUC__
#define AOTI_TORCH_EXPORT __attribute__((__visibility__("default")))
#else // !__GNUC__
#ifdef _WIN32
// PyTorch2 doesn't currently work on Windows. Exporting these APIs can lead
// to symbol clashes at link time if libtorch is included in a DLL and a
// binary that depends on the DLL. As a short-term fix, we don't export the
// symbols. In the long term, this will need to be addressed when Windows is
// supported.
// #define AOTI_TORCH_EXPORT __declspec(dllexport)
#define AOTI_TORCH_EXPORT
#else // !_WIN32
#define AOTI_TORCH_EXPORT
#endif // _WIN32
#endif // __GNUC__

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/util/complex.h>

#ifdef __cplusplus
extern "C" {
#endif

// AtenTensorHandle represents an abstract notion of Tensor that can be passed
// between model.so and libtorch.so. The contents of the structure itself are
// private; model.so is not allowed to access any fields directly, it must go
// through functions defined in this ABI. Under the hood, this is represented
// as at::Tensor*, but we reserve the right to change this (and in fact, we
// probably should change it to at::TensorImpl* at least).
//
// An AtenTensorHandle can be owning (please check the API reference for exact
// ownership/borrow semantics). If you have an owning AtenTensorHandle in
// model.so, you are obligated to call aoti_torch_delete_tensor_object when
// you are done. You can use the helper C++ class RAIIAtenTensorHandle
// (see aot_runtime/model.h) to ensure the deallocator is called in RAII style
// (note that RAIIAtenTensorHandle is private to model.so, and never crosses
// the ABI boundary).
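//
// As an illustrative (non-normative) sketch of the error-code and ownership
// conventions described above, a caller in model.so that allocates a tensor
// through this shim might look roughly like the following; it assumes the
// helpers declared later in this header and elides real error handling:
//
//   int64_t sizes[2] = {4, 4};
//   int64_t strides[2] = {4, 1};
//   AtenTensorHandle buf = nullptr;
//   AOTITorchError err = aoti_torch_empty_strided(
//       /*ndim=*/2, sizes, strides,
//       aoti_torch_dtype_float32(),
//       aoti_torch_device_type_cpu(),
//       /*device_index=*/0,
//       &buf);
//   if (err != AOTI_TORCH_SUCCESS) { /* handle the failure */ }
//   // ... use buf ...
//   aoti_torch_delete_tensor_object(buf); // owning handle must be freed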
struct AtenTensorOpaque;
using AtenTensorHandle = AtenTensorOpaque*;

struct AtenGeneratorOpaque;
using AtenGeneratorHandle = AtenGeneratorOpaque*;

struct AOTIProxyExecutorOpaque;
using AOTIProxyExecutorHandle = AOTIProxyExecutorOpaque*;

using AOTITorchError = int32_t;
#define AOTI_TORCH_SUCCESS 0
#define AOTI_TORCH_FAILURE 1

// Getter functions for retrieving various constants from the runtime, which
// can subsequently be passed to other aoti_* functions. By hiding these
// behind functions, the precise value of device/dtype is NOT part of the
// ABI contract. (In practice, aten/c10 is pretty good about not renumbering
// these, so we probably could later switch to having these in the ABI, if
// desired for perf reasons.)
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cpu();
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cuda();

AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e4m3fn();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2fnuz();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e4m3fnuz();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_bfloat16();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float16();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float32();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float64();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_uint8();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_uint16();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_uint32();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_uint64();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int8();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int16();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int32();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int64();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_bool();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex32();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex64();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex128();

AOTI_TORCH_EXPORT int32_t aoti_torch_layout_strided();
AOTI_TORCH_EXPORT int32_t aoti_torch_layout__mkldnn();
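// Illustrative (non-normative) sketch: model.so is expected to obtain dtype,
// device, and layout constants from the getters above at runtime rather than
// hard-coding numeric values, e.g. when checking the dtype of an input tensor
// (aoti_torch_get_dtype is declared later in this header):
//
//   int32_t dtype = 0;
//   if (aoti_torch_get_dtype(tensor, &dtype) == AOTI_TORCH_SUCCESS &&
//       dtype == aoti_torch_dtype_bfloat16()) {
//     // tensor is bf16
//   }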
// Functions for converting a single-element tensor to a scalar value
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_float16(AtenTensorHandle tensor, c10::Half* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_float32(AtenTensorHandle tensor, float* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_float64(AtenTensorHandle tensor, double* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_uint8(AtenTensorHandle tensor, uint8_t* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_uint16(AtenTensorHandle tensor, uint16_t* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_uint32(AtenTensorHandle tensor, uint32_t* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_uint64(AtenTensorHandle tensor, uint64_t* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_int8(AtenTensorHandle tensor, int8_t* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_int16(AtenTensorHandle tensor, int16_t* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_int32(AtenTensorHandle tensor, int32_t* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_int64(AtenTensorHandle tensor, int64_t* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_bool(AtenTensorHandle tensor, bool* ret_value);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_item_bfloat16(AtenTensorHandle tensor, c10::BFloat16* ret_value);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_complex64(
    AtenTensorHandle tensor,
    c10::complex<float>* ret_value);

// Functions for wrapping a scalar value to a single-element tensor
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_float32(
    float value,
    AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_float64(
    double value,
    AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_uint8(
    uint8_t value,
    AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_uint16(
    uint16_t value,
    AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_uint32(
    uint32_t value,
    AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_uint64(
    uint64_t value,
    AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_int8(
    int8_t value,
    AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_int16(
    int16_t value,
    AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_int32(
    int32_t value,
    AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_int64(
    int64_t value,
    AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_scalar_to_tensor_bool(bool value, AtenTensorHandle* ret_new_tensor);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_complex64(
    c10::complex<float> value,
    AtenTensorHandle* ret_new_tensor);

AOTI_TORCH_EXPORT bool aoti_torch_grad_mode_is_enabled();
AOTI_TORCH_EXPORT void aoti_torch_grad_mode_set_enabled(bool enabled);

// Free the tensor object
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_delete_tensor_object(AtenTensorHandle tensor);

// Get a pointer to the underlying storage data
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_data_ptr(
    AtenTensorHandle tensor,
    void** ret_data_ptr // returns borrowed reference
);

// Get the nbytes of the underlying storage
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_storage_size(AtenTensorHandle tensor, int64_t* ret_size);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_dim(AtenTensorHandle tensor, int64_t* ret_dim);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_numel(AtenTensorHandle tensor, int64_t* ret_numel);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_storage_numel(AtenTensorHandle tensor, int64_t* ret_numel);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_sizes(
    AtenTensorHandle tensor,
    int64_t** ret_sizes // returns borrowed reference
);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_size(AtenTensorHandle tensor, int64_t d, int64_t* ret_size);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_strides(
    AtenTensorHandle tensor,
    int64_t** ret_strides // returns borrowed reference
);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_stride(AtenTensorHandle tensor, int64_t d, int64_t* ret_stride);
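// Illustrative (non-normative) note on the "borrowed reference" annotations
// above: the pointers written to ret_sizes / ret_strides / ret_data_ptr are
// owned by the tensor and remain valid only as long as the tensor is alive;
// the caller must not free them. A rough sketch:
//
//   int64_t dim = 0;
//   int64_t* sizes = nullptr;
//   aoti_torch_get_dim(tensor, &dim);
//   aoti_torch_get_sizes(tensor, &sizes); // borrowed; do not free
//   for (int64_t i = 0; i < dim; ++i) {
//     // sizes[i] is the size of dimension i
//   }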
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_dtype(AtenTensorHandle tensor, int32_t* ret_dtype);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_device_type(AtenTensorHandle tensor, int32_t* ret_device_type);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_device_index(AtenTensorHandle tensor, int32_t* ret_device_index);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_storage_offset(
    AtenTensorHandle tensor,
    int64_t* ret_storage_offset);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__alloc_from_pool(
    AtenTensorHandle self,
    int64_t offset_bytes,
    int32_t dtype,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    AtenTensorHandle* ret_new_tensor);

// This function will create a new tensor object and return its pointer
// through *ret_new_tensor. The caller is responsible for wrapping the tensor
// pointer with RAIIAtenTensorHandle, which will call
// aoti_torch_delete_tensor_object when going out of scope.
AOTI_TORCH_EXPORT AOTITorchError aoti_torch__reinterpret_tensor(
    AtenTensorHandle self,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t storage_offset,
    AtenTensorHandle* ret_new_tensor // returns new reference
);

// This function will create a new tensor object and return its pointer
// through *ret_new_tensor. The caller is responsible for wrapping the tensor
// pointer with RAIIAtenTensorHandle, which will call
// aoti_torch_delete_tensor_object when going out of scope.
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_empty_strided(
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    AtenTensorHandle* ret_new_tensor // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
    void* data,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t storage_offset,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    AtenTensorHandle* ret // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_v2(
    void* data,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t storage_offset,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    AtenTensorHandle* ret, // returns new reference
    int32_t layout,
    const uint8_t* opaque_metadata,
    int64_t opaque_metadata_size);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__embedding_bag(
    AtenTensorHandle weight,
    AtenTensorHandle indices,
    AtenTensorHandle offsets,
    int32_t scale_grad_by_freq,
    int32_t mode,
    int32_t sparse,
    AtenTensorHandle per_sample_weights, // optional argument
    int32_t include_last_offset,
    int32_t padding_idx,
    AtenTensorHandle* ret0, // returns new reference
    AtenTensorHandle* ret1, // returns new reference
    AtenTensorHandle* ret2, // returns new reference
    AtenTensorHandle* ret3 // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__fft_c2c(
    AtenTensorHandle self,
    const int64_t* dim_ptr,
    int64_t dim_size,
    int64_t normalization,
    int32_t forward,
    AtenTensorHandle* ret // returns new reference
);
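// Illustrative (non-normative) sketch of aoti_torch_create_tensor_from_blob,
// declared above: it wraps caller-provided memory in a new tensor handle.
// The handle itself is a new reference and must be deleted, and (assuming
// at::from_blob-style semantics) the underlying buffer remains owned by the
// caller and must outlive the tensor:
//
//   float data[6] = {0};
//   int64_t sizes[2] = {2, 3};
//   int64_t strides[2] = {3, 1};
//   AtenTensorHandle wrapped = nullptr;
//   aoti_torch_create_tensor_from_blob(
//       data, /*ndim=*/2, sizes, strides, /*storage_offset=*/0,
//       aoti_torch_dtype_float32(),
//       aoti_torch_device_type_cpu(),
//       /*device_index=*/0,
//       &wrapped);
//   // ... use wrapped; `data` must stay alive while it does ...
//   aoti_torch_delete_tensor_object(wrapped);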
// This version is deprecated. We will remove it later.
AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
    AtenTensorHandle query,
    AtenTensorHandle key,
    AtenTensorHandle value,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    double scale,
    AtenTensorHandle* ret0, // returns new reference
    AtenTensorHandle* ret1, // returns new reference
    AtenTensorHandle* ret2, // returns new reference
    AtenTensorHandle* ret3, // returns new reference
    int64_t* ret4,
    int64_t* ret5,
    AtenTensorHandle* ret6, // returns new reference
    AtenTensorHandle* ret7, // returns new reference
    AtenTensorHandle* ret8 // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch__scaled_dot_product_flash_attention_v2(
    AtenTensorHandle query,
    AtenTensorHandle key,
    AtenTensorHandle value,
    double dropout_p,
    int is_causal,
    int return_debug_mask,
    double* scale, // optional argument
    AtenTensorHandle* ret0, // returns new reference
    AtenTensorHandle* ret1, // returns new reference
    AtenTensorHandle* ret2, // returns new reference
    AtenTensorHandle* ret3, // returns new reference
    int64_t* ret4,
    int64_t* ret5,
    AtenTensorHandle* ret6, // returns new reference
    AtenTensorHandle* ret7, // returns new reference
    AtenTensorHandle* ret8 // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch__scaled_dot_product_efficient_attention(
    AtenTensorHandle query,
    AtenTensorHandle key,
    AtenTensorHandle value,
    AtenTensorHandle attn_bias, // optional argument
    int compute_log_sumexp,
    double dropout_p,
    int is_causal,
    double* scale, // optional argument
    AtenTensorHandle* ret0, // returns new reference
    AtenTensorHandle* ret1, // returns new reference
    AtenTensorHandle* ret2, // returns new reference
    AtenTensorHandle* ret3 // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm(
    AtenTensorHandle self,
    AtenTensorHandle mat2,
    AtenTensorHandle bias,
    int32_t* out_dtype,
    AtenTensorHandle scale_a,
    AtenTensorHandle scale_b,
    AtenTensorHandle scale_result,
    int8_t use_fast_accum,
    AtenTensorHandle* ret0,
    AtenTensorHandle* ret1);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm_v2(
    AtenTensorHandle self,
    AtenTensorHandle mat2,
    AtenTensorHandle scale_a,
    AtenTensorHandle scale_b,
    AtenTensorHandle bias,
    AtenTensorHandle scale_result,
    int32_t* out_dtype,
    int8_t use_fast_accum,
    AtenTensorHandle* ret0);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_convolution(
    AtenTensorHandle input,
    AtenTensorHandle weight,
    AtenTensorHandle bias, // optional argument
    const int64_t* stride_ptr,
    int64_t stride_size,
    const int64_t* padding_ptr,
    int64_t padding_size,
    const int64_t* dilation_ptr,
    int64_t dilation_size,
    int transposed,
    const int64_t* output_padding_ptr,
    int64_t output_padding_size,
    int64_t groups,
    AtenTensorHandle* ret // returns new reference
);

// This function will create a new uninitialized tensor object
// and its pointer is returned through *ret.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_new_uninitialized_tensor(AtenTensorHandle* ret);
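// Illustrative (non-normative) note on the "// optional argument" annotations
// used above: by convention in this shim, an optional scalar is passed as a
// pointer and an optional tensor as a handle, with a null pointer/handle
// presumably meaning "not provided" (i.e. the operator's default is used).
// For example, to call aoti_torch__scaled_dot_product_flash_attention_v2
// without an explicit scale, one would pass scale = nullptr.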
// WARNING: This will be deprecated. Use aoti_torch_copy_ instead.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_tensor_copy_(AtenTensorHandle src, AtenTensorHandle dst);

// Make the tensor referred to by dst an alias for the tensor referred
// to by src. The two tensors must still be deleted with
// aoti_torch_delete_tensor_object separately (or not) as before the call.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_assign_tensors(AtenTensorHandle src, AtenTensorHandle dst);

// Make a shallow copy of the tensor referred to by src and assign it to the
// handle pointed to by ret_dst. This is similar to the aoti_torch_assign_tensors
// function above, but creates and sets the ret_dst from within.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_assign_tensors_out(AtenTensorHandle src, AtenTensorHandle* ret_dst);

// This function will create a new tensor object and return its pointer
// through *ret. The caller is responsible for wrapping the tensor pointer
// with RAIIAtenTensorHandle, which will call aoti_torch_delete_tensor_object
// when going out of scope.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_clone(AtenTensorHandle self, AtenTensorHandle* ret);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_addmm_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    AtenTensorHandle mat1,
    AtenTensorHandle mat2,
    float beta,
    float alpha);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_bmm_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    AtenTensorHandle mat2);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_copy_(
    AtenTensorHandle self,
    AtenTensorHandle src,
    int32_t non_blocking);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mm_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    AtenTensorHandle mat2);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__mm_plus_mm_out(
    AtenTensorHandle out,
    AtenTensorHandle a,
    AtenTensorHandle b,
    AtenTensorHandle c,
    AtenTensorHandle d);

// This will soon be deprecated after ao_quantization is complete.
// Please refrain from using this or increasing callsites.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_cpu_wrapped_fbgemm_pack_gemm_matrix_fp16(
    AtenTensorHandle weight,
    AtenTensorHandle* out);

// This will soon be deprecated after ao_quantization is complete.
// Please refrain from using this or increasing callsites.
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__wrapped_linear_prepack(
    AtenTensorHandle weight,
    AtenTensorHandle weight_scale,
    AtenTensorHandle weight_zero_point,
    AtenTensorHandle bias,
    AtenTensorHandle* out);

// This will soon be deprecated after ao_quantization is complete.
// Please refrain from using this or increasing callsites.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight(
    AtenTensorHandle input,
    AtenTensorHandle weight,
    AtenTensorHandle bias,
    int64_t out_channel,
    AtenTensorHandle* out);
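// Illustrative (non-normative) sketch of the aoti_torch_assign_tensors*
// functions declared above: aoti_torch_assign_tensors requires an existing
// destination handle (e.g. from aoti_torch_new_uninitialized_tensor), while
// aoti_torch_assign_tensors_out creates the destination handle itself; in
// both cases each handle is still deleted independently:
//
//   AtenTensorHandle dst = nullptr;
//   aoti_torch_new_uninitialized_tensor(&dst);
//   aoti_torch_assign_tensors(src, dst); // dst now aliases src
//   // ... use dst ...
//   aoti_torch_delete_tensor_object(dst);
//   aoti_torch_delete_tensor_object(src);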
// This will soon be deprecated after ao_quantization is complete.
// Please refrain from using this or increasing callsites.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_cpu__wrapped_quantized_linear_prepacked(
    AtenTensorHandle input,
    AtenTensorHandle input_scale,
    AtenTensorHandle input_zero_point,
    AtenTensorHandle weight,
    AtenTensorHandle out_scale,
    AtenTensorHandle out_zeropoint,
    int64_t out_channel,
    AtenTensorHandle* out);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_nonzero(AtenTensorHandle self, AtenTensorHandle* out);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor(
    AtenTensorHandle repeats,
    int64_t* output_size,
    AtenTensorHandle* out);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_check_inf_and_nan(const char* tensor_name, AtenTensorHandle tensor);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scatter_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    int64_t dim,
    AtenTensorHandle index,
    AtenTensorHandle src);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scatter_reduce_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    int64_t dim,
    AtenTensorHandle index,
    AtenTensorHandle src,
    const char* reduce,
    int32_t include_self);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_index_put_out(
    AtenTensorHandle out,
    AtenTensorHandle self,
    const AtenTensorHandle* indices,
    const uint32_t num_indices,
    const AtenTensorHandle values,
    bool accumulate);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_as_real(
    AtenTensorHandle self,
    AtenTensorHandle* ret // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_dtype(
    AtenTensorHandle self,
    int32_t dtype,
    AtenTensorHandle* ret // returns new reference
);

AOTI_TORCH_EXPORT void aoti_torch_print_tensor_handle(
    AtenTensorHandle self,
    const char* msg);

// When the AOTI debug printer option is enabled, this function will be
// invoked to save the intermediate tensor via torch pickle for debugging
// purposes.
AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle(
    AtenTensorHandle self,
    const char* tensor_name,
    const char* launch_prefix,
    const char* kernel_name);

#ifdef USE_CUDA

struct CUDAGuardOpaque;
using CUDAGuardHandle = CUDAGuardOpaque*;

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_cuda_guard(
    int32_t device_index,
    CUDAGuardHandle* ret_guard // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_delete_cuda_guard(CUDAGuardHandle guard);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_cuda_guard_set_index(CUDAGuardHandle guard, int32_t device_index);

struct CUDAStreamGuardOpaque;
using CUDAStreamGuardHandle = CUDAStreamGuardOpaque*;

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_cuda_stream_guard(
    void* stream,
    int32_t device_index,
    CUDAStreamGuardHandle* ret_guard // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream);

#endif // USE_CUDA

// See `ProxyExecutor Design Note` in ir.py for more details
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_proxy_executor_call_function(
    AOTIProxyExecutorHandle proxy_executor,
    int extern_node_index,
    int num_ints,
    int64_t* flatten_int_args,
    int num_tensors,
    AtenTensorHandle* flatten_tensor_args);

AOTI_TORCH_EXPORT void aoti_torch_check(
    bool cond,
    const char* func,
    const char* file,
    uint32_t line,
    const char* msg);

#ifdef STRIP_ERROR_MESSAGES
#define AOTI_TORCH_CHECK(cond, ...)              \
  if (!(cond)) {                                 \
    aoti_torch_check(                            \
        false,                                   \
        __func__,                                \
        __FILE__,                                \
        static_cast<uint32_t>(__LINE__),         \
        TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); \
  }
#else
#define AOTI_TORCH_CHECK(cond, ...)                \
  if (!(cond)) {                                   \
    aoti_torch_check(                              \
        false,                                     \
        __func__,                                  \
        __FILE__,                                  \
        static_cast<uint32_t>(__LINE__),           \
        TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__)); \
  }
#endif

#ifdef __cplusplus
} // extern "C"

template <typename T>
int32_t aoti_torch_dtype() = delete;

#define DEFINE_DTYPE_SPECIALIZATION(ctype, typename) \
  template <>                                        \
  inline int32_t aoti_torch_dtype<ctype>() {         \
    return aoti_torch_dtype_##typename();            \
  }

namespace c10 {
struct BFloat16;
struct Half;
} // namespace c10

DEFINE_DTYPE_SPECIALIZATION(c10::BFloat16, bfloat16)
DEFINE_DTYPE_SPECIALIZATION(c10::Half, float16)
DEFINE_DTYPE_SPECIALIZATION(c10::complex<float>, complex64)
DEFINE_DTYPE_SPECIALIZATION(float, float32)
DEFINE_DTYPE_SPECIALIZATION(double, float64)
DEFINE_DTYPE_SPECIALIZATION(uint8_t, uint8)
DEFINE_DTYPE_SPECIALIZATION(int8_t, int8)
DEFINE_DTYPE_SPECIALIZATION(int16_t, int16)
DEFINE_DTYPE_SPECIALIZATION(int32_t, int32)
DEFINE_DTYPE_SPECIALIZATION(int64_t, int64)
DEFINE_DTYPE_SPECIALIZATION(bool, bool)

#endif

#endif // AOTI_TORCH_SHIM