# mypy: allow-untyped-defs
import functools
import math
import os
import sys
from itertools import count
from typing import Dict, List, Optional, Tuple

import sympy
from sympy import Expr

import torch
import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
import torch._ops
from torch._inductor.codegen.debug_utils import IntermediateValueDebuggingLevel
from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes

from .. import config, ir
from ..utils import _align, ALIGN_BYTES, cache_on_self, sympy_product
from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .common import IndentedBuffer
from .cpp_utils import (
    cexpr,
    DEVICE_TO_ATEN,
    DTYPE_TO_ATEN,
    DTYPE_TO_CPP,
    LAYOUT_TO_ATEN,
)
from .wrapper import EnterSubgraphLine, ExitSubgraphLine, WrapperCodeGen


class CppWrapperCpu(WrapperCodeGen):
    """
    Generates cpp wrapper for running on CPU and calls cpp kernels
    """

    def __init__(self):
        if not hasattr(self, "device"):
            self.device = "cpu"
        super().__init__()
        self.declare = "auto "
        self.declare_maybe_reference = "decltype(auto) "
        self.ending = ";"
        self.open_bracket = "{"
        self.closed_bracket = "}"
        self.comment = "//"
        self.namespace = "at::"
        self.none_str = "nullptr" if config.abi_compatible else "at::Tensor()"
        self.extern_call_ops = set()
        self.size = "sizes()"
        self.stride = "strides()"
        self.cuda = False
        self.supports_intermediate_hooks = False
        self.outputs_need_copy = set()
        self.kernel_callsite_id = count()
        self.var_array_id = (
            count()
        )  # for different types of local array variable declarations
        self.declared_var_array_vars = set()
        self.int_array_id = count()  # for int array local variable declarations
        self.declared_int_array_vars = set()
        self.tmp_tensor_id = count()  # for tmp tensor local variable declarations
        self.arg_var_id = count()
        self.used_cached_devices = set()
        self.used_cached_dtypes = set()
        self.used_cached_layouts = set()
        self.cached_output_id = count()
        self.scalar_to_tensor_id = count()
        self.custom_op_wrapper_loaded = False
        self.expr_printer = cexpr

    def generate_kernel_call(
        self,
        kernel_name: str,
        call_args,
        grid=None,
        device_index=None,
        cuda=True,
        triton=True,
        arg_types=None,
        raw_args=None,
        grid_fn: str = "grid",
        triton_meta=None,
        autotune_configs=None,
        grid_extra_kwargs="",
    ):
        """
        Generates kernel call code.

        cuda: Defines whether the backend is GPU. Otherwise the backend is CPU.

        triton: Defines whether the GPU backend uses Triton for codegen.
                Otherwise it uses the CUDA language for codegen.
                Only valid when cuda == True.
        """
        if cuda:
            return super().generate_kernel_call(
                kernel_name,
                call_args,
                grid,
                device_index,
                cuda,
                triton,
                arg_types,
                raw_args,
                grid_fn,
                triton_meta,
                autotune_configs,
                grid_extra_kwargs,
            )
        else:
            if config.abi_compatible:
                assert arg_types is not None and len(call_args) == len(
                    arg_types
                ), "Mismatch call_args and arg_types in generate_kernel_call"
                new_args = []
                for idx, arg in enumerate(call_args):
                    if "*" in arg_types[idx]:
                        var_name = f"var_{next(self.arg_var_id)}"
                        self.writeline(
                            f"auto* {var_name} = get_data_ptr_wrapper({arg});"
                        )
                        new_args.append(f"({arg_types[idx]})({var_name})")
                    else:
                        # arg is a scalar
                        new_args.append(arg)
                self.writeline(self.wrap_kernel_call(kernel_name, new_args))
            else:
                self.writeline(self.wrap_kernel_call(kernel_name, call_args))
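
    # Illustrative output of the non-CUDA, ABI-compatible branch above (buffer and
    # kernel names are made up, not taken from a real run): for a pointer-typed
    # argument it first emits
    #   auto* var_0 = get_data_ptr_wrapper(buf0);
    # and then the call line produced by wrap_kernel_call, e.g.
    #   cpp_fused_kernel_0((float*)(var_0), ks0);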
96 """ 97 if cuda: 98 return super().generate_kernel_call( 99 kernel_name, 100 call_args, 101 grid, 102 device_index, 103 cuda, 104 triton, 105 arg_types, 106 raw_args, 107 grid_fn, 108 triton_meta, 109 autotune_configs, 110 grid_extra_kwargs, 111 ) 112 else: 113 if config.abi_compatible: 114 assert arg_types is not None and len(call_args) == len( 115 arg_types 116 ), "Mismatch call_args and arg_types in generate_kernel_call" 117 new_args = [] 118 for idx, arg in enumerate(call_args): 119 if "*" in arg_types[idx]: 120 var_name = f"var_{next(self.arg_var_id)}" 121 self.writeline( 122 f"auto* {var_name} = get_data_ptr_wrapper({arg});" 123 ) 124 new_args.append(f"({arg_types[idx]})({var_name})") 125 else: 126 # arg is a scalar 127 new_args.append(arg) 128 self.writeline(self.wrap_kernel_call(kernel_name, new_args)) 129 else: 130 self.writeline(self.wrap_kernel_call(kernel_name, call_args)) 131 132 def write_constant(self, name, hashed): 133 # include a hash so our code cache gives different constants different files 134 self.header.writeline(f"// {name} {hashed}") 135 136 def write_header(self): 137 if V.graph.is_const_graph: 138 # We do not write header for constant graph, it will be written by main module. 139 return 140 141 if V.graph.aot_mode: 142 for header_cpp_file in ("interface.cpp", "implementation.cpp"): 143 with open( 144 os.path.join( 145 os.path.dirname(__file__), "aoti_runtime", header_cpp_file 146 ) 147 ) as f: 148 self.header.splice(f.read()) 149 else: 150 self.header.splice( 151 """ 152 import torch 153 from torch._inductor.codecache import CppWrapperCodeCache 154 155 cpp_wrapper_src = ( 156 ''' 157 """ 158 ) 159 160 if config.abi_compatible: 161 self.header.splice( 162 f"#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{self.device}.h>" 163 ) 164 self.header.splice( 165 """ 166 #include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h> 167 #include <torch/csrc/inductor/aoti_runtime/thread_local.h> 168 #include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h> 169 """ 170 ) 171 if V.graph.aot_mode: 172 self.header.splice( 173 """ 174 #include <torch/csrc/inductor/aoti_runtime/model.h> 175 """ 176 ) 177 else: 178 self.header.splice( 179 """ 180 #include <ATen/ATen.h> 181 #include <ATen/core/dispatch/Dispatcher.h> 182 #include <ATen/native/BinaryOps.h> 183 #include <torch/csrc/inductor/aoti_runtime/utils.h> 184 #include <torch/csrc/inductor/aoti_torch/tensor_converter.h> 185 #include <torch/csrc/inductor/aoti_torch/utils.h> 186 #include <torch/csrc/inductor/inductor_ops.h> 187 #include <torch/types.h> 188 #include <ATen/ops/bernoulli_native.h> 189 190 #define reinterpret_tensor torch::inductor::_reinterpret_tensor 191 #define alloc_from_pool torch::inductor::_alloc_from_pool 192 """ 193 ) 194 enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ 195 "linux", 196 "win32", 197 ] 198 if config.profiler_mark_wrapper_call or enable_kernel_profile: 199 self.header.splice("#include <ATen/record_function.h>") 200 201 self.header.splice("typedef at::Half half;") 202 self.header.splice("typedef at::BFloat16 bfloat16;") 203 self.header.splice("#include <c10/util/generic_math.h>") 204 205 if not V.graph.aot_mode: 206 self.header.splice( 207 """ 208 #include <pybind11/pybind11.h> 209 210 namespace py = pybind11; 211 using namespace torch::aot_inductor; 212 213 class RAIIPyObject { 214 public: 215 RAIIPyObject() : obj_(nullptr) {} 216 RAIIPyObject(PyObject* obj) : obj_(obj) {} 217 ~RAIIPyObject() { 218 Py_XDECREF(obj_); 219 } 220 RAIIPyObject& 

    @functools.lru_cache(None)  # noqa: B019
    def include_extra_header(self, header: str):
        # This is needed for cpp to python dtype conversion
        self.header.splice(f"#include <{header}>")

    def mark_output_type(self):
        # mark output type to unwrap tensor back to python scalar
        from ..ir import ShapeAsConstantBuffer

        output_is_tensor = {}
        for idx, x in enumerate(V.graph.graph_outputs):
            if isinstance(x, ShapeAsConstantBuffer):
                output_is_tensor[idx] = False
            else:
                output_is_tensor[idx] = True

        self.output_is_tensor = output_is_tensor

    def write_prefix(self):
        if V.graph.is_const_graph:
            # We do not write prefix for constant graph, it will be written by main module.
            return

        if V.graph.aot_mode:
            self.prefix.writeline("namespace torch {")
            self.prefix.writeline("namespace aot_inductor {")

    def write_input_output_info(
        self,
        info_kind: str,
        idx: int,
        name: str,
    ):
        self.prefix.writeline(f"""{info_kind}[{idx}].name = "{name}";""")

    @staticmethod
    def get_input_cpp_type(input):
        assert config.use_minimal_arrayref_interface

        if isinstance(input, sympy.Expr):
            from ..graph import may_get_constant_buffer_dtype

            dtype = may_get_constant_buffer_dtype(input)
            assert dtype is not None, f"Failed to get the dtype of sympy.Expr: {input}"
            return DTYPE_TO_CPP[dtype]
        return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>"
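
    # For illustration: with the minimal-arrayref interface enabled, a float32 tensor
    # input maps to "ArrayRefTensor<float>" above, while a symbolic int64 input
    # (a sympy.Expr) maps directly to its C type, "int64_t", via DTYPE_TO_CPP.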

    def generate_input_output_runtime_checks(self):
        # In debug_compile mode, we generate checks to ensure the dtype/shape/stride of each
        # real input/output tensor match ones provided at compile time via sample
        # input/output.
        def gen_check(handle_kind, idx, name, tensor):
            self.prefix.writeline(f"auto {name} = {handle_kind}[{idx}];")
            self.codegen_tensor_dtype_var_decl(self.prefix, name)
            expected_dtype_name = DTYPE_TO_ATEN[tensor.dtype]
            dtype_str = str(tensor.dtype).split(".")[-1]
            self.prefix.splice(
                f"""
                    int32_t {name}_expected_dtype = aoti_torch_dtype_{dtype_str}();
                    if ({name}_expected_dtype != {name}_dtype) {{
                        std::stringstream ss;
                        ss << "{handle_kind}[{idx}]: unmatched dtype, "
                           << "expected: " << {name}_expected_dtype << "({expected_dtype_name}), "
                           << "but got: " << {name}_dtype << "\\n";
                        throw std::runtime_error(ss.str());
                    }}
                """
            )
            self.codegen_input_size_var_decl(self.prefix, name)
            for dim_idx, d in enumerate(tensor.get_size()):
                if isinstance(d, (int, sympy.Integer)):
                    self.prefix.splice(
                        f"""
                            if ({d} != {name}_size[{dim_idx}]) {{
                                std::stringstream ss;
                                ss << "{handle_kind}[{idx}]: unmatched dim value at {dim_idx}, "
                                   << "expected: {d}, " << "but got: " << {name}_size[{dim_idx}]
                                   << "\\n";
                                throw std::runtime_error(ss.str());
                            }}
                        """
                    )
                else:
                    from torch.utils._sympy.value_ranges import bound_sympy

                    sym_range = bound_sympy(d, V.graph.sizevars.shape_env.var_to_range)
                    if not math.isinf(sym_range.lower):
                        self.prefix.splice(
                            f"""
                                if ({name}_size[{dim_idx}] < {sym_range.lower}) {{
                                    std::stringstream ss;
                                    ss << "{handle_kind}[{idx}]: dim value is too small at {dim_idx}, "
                                       << "expected it to be >= {sym_range.lower}, " << "but got: "
                                       << {name}_size[{dim_idx}] << "\\n";
                                    throw std::runtime_error(ss.str());
                                }}
                            """
                        )
                    if not math.isinf(sym_range.upper):
                        self.prefix.splice(
                            f"""
                                if ({name}_size[{dim_idx}] > {sym_range.upper}) {{
                                    std::stringstream ss;
                                    ss << "{handle_kind}[{idx}]: dim value is too large at {dim_idx}, "
                                       << "expected to be <= {sym_range.upper}, " << "but got: "
                                       << {name}_size[{dim_idx}] << "\\n";
                                    throw std::runtime_error(ss.str());
                                }}
                            """
                        )

            self.codegen_input_stride_var_decl(self.prefix, name)
            for stride_idx, s in enumerate(tensor.get_stride()):
                if not isinstance(s, (int, sympy.Integer)):
                    continue
                self.prefix.splice(
                    f"""
                        if ({s} != {name}_stride[{stride_idx}]) {{
                            std::stringstream ss;
                            ss << "{handle_kind}[{idx}]: unmatched stride value at {stride_idx}, "
                               << "expected: {s}, " << "but got: " << {name}_stride[{stride_idx}]
                               << "\\n";
                            throw std::runtime_error(ss.str());
                        }}
                    """
                )

        # force noinline to avoid any potential compilation slowdown due to aggressive
        # inline done by the host compiler
        self.prefix.splice(
            """
            AOTI_NOINLINE static void __check_inputs_outputs(
                AtenTensorHandle* input_handles,
                AtenTensorHandle* output_handles) {
            """
        )
        with self.prefix.indent():
            for idx, (name, tensor) in enumerate(V.graph.graph_inputs.items()):
                gen_check("input_handles", idx, name, tensor)
        self.prefix.writeline("}")
f""" 406 using AOTInductorModelInputs = std::tuple<{input_cpp_types}>; 407 using AOTInductorModelOutputs = std::tuple<{output_arrayref_types}>; 408 """ 409 ) 410 411 if V.graph.const_module: 412 self.header.splice(V.graph.const_module.wrapper_code.header) 413 self.prefix.splice(V.graph.const_code) 414 415 if V.graph.is_const_graph: 416 self.prefix.splice( 417 """ 418 void AOTInductorModel::_const_run_impl( 419 std::vector<AtenTensorHandle>& output_handles, 420 DeviceStreamType stream, 421 AOTIProxyExecutorHandle proxy_executor 422 ) { 423 """ 424 ) 425 else: 426 if not config.aot_inductor.use_runtime_constant_folding: 427 # If we do not split the constant graph, we'll just create 428 # an empty implementation when wrapping the main module. 429 self.prefix.splice( 430 """ 431 void AOTInductorModel::_const_run_impl( 432 std::vector<AtenTensorHandle>& output_handles, 433 DeviceStreamType stream, 434 AOTIProxyExecutorHandle proxy_executor 435 ) {} 436 437 """ 438 ) 439 440 run_impl_proto = """ 441 void AOTInductorModel::run_impl( 442 AtenTensorHandle* 443 input_handles, // array of input AtenTensorHandle; handles 444 // are stolen; the array itself is borrowed 445 AtenTensorHandle* 446 output_handles, // array for writing output AtenTensorHandle; handles 447 // will be stolen by the caller; the array itself is 448 // borrowed 449 DeviceStreamType stream, 450 AOTIProxyExecutorHandle proxy_executor 451 ) { 452 """ 453 # Since we are removing non-abi-compatible mode, let's generate 454 # runtime checks only for abi_compatible mode to avoid extra branches. 455 if config.aot_inductor.debug_compile and config.abi_compatible: 456 self.generate_input_output_runtime_checks() 457 run_impl_proto += """ 458 __check_inputs_outputs(input_handles, output_handles); 459 """ 460 if config.use_minimal_arrayref_interface: 461 self.prefix.splice( 462 """ 463 template <> 464 AOTInductorModelOutputs AOTInductorModel::run_impl_minimal_arrayref_interface< 465 AOTInductorModelInputs, AOTInductorModelOutputs>( 466 const AOTInductorModelInputs& inputs, 467 DeviceStreamType stream, 468 AOTIProxyExecutorHandle proxy_executor 469 ) { 470 """ 471 ) 472 self.suffix.splice(run_impl_proto) 473 self.suffix.splice( 474 """ 475 AOTInductorModelInputs inputs; 476 convert_handles_to_inputs(input_handles, inputs); 477 auto outputs = run_impl_minimal_arrayref_interface<AOTInductorModelInputs, AOTInductorModelOutputs>( 478 inputs, stream, proxy_executor); 479 // NOTE: outputs is full of ArrayRef to thread_local storage. If in the future we need this 480 // interface to perform well for a DSO using the minimal arrayref interface, all we need 481 // to do is provide ThreadLocalCachedTensor for each one! 
                        convert_outputs_to_handles(outputs, output_handles);
                    }
                    """
                    )

                    self.suffix.splice(
                        """
                        extern "C" AOTIRuntimeError AOTInductorModelRunMinimalArrayrefInterface(
                            AOTInductorModelHandle model_handle,
                            const AOTInductorModelInputs& inputs,
                            AOTInductorModelOutputs& outputs) {
                          auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
                          CONVERT_EXCEPTION_TO_ERROR_CODE({
                              outputs = model->run_impl_minimal_arrayref_interface<AOTInductorModelInputs, AOTInductorModelOutputs>(
                                  inputs,
                                  (torch::aot_inductor::DeviceStreamType)nullptr,
                                  nullptr);
                          })
                        }
                        """
                    )
                else:
                    self.prefix.splice(run_impl_proto)
        else:
            # cpp entry function for JIT with cpp wrapper
            self.prefix.splice(
                """
                void inductor_entry_impl(
                    AtenTensorHandle*
                        input_handles, // array of input AtenTensorHandle; handles
                                       // are stolen; the array itself is borrowed
                    AtenTensorHandle*
                        output_handles // array for writing output AtenTensorHandle; handles
                                       // will be stolen by the caller; the array itself is
                                       // borrowed)
                ) {
                """
            )
        with self.prefix.indent():
            # assign inputs and outputs in both cases so the later codegen can be simplified
            if not config.use_minimal_arrayref_interface:
                if not V.graph.is_const_graph:
                    if V.graph.aot_mode:
                        num_args = len(V.graph.graph_inputs)
                    else:
                        # Weights are promoted in the JIT mode
                        num_args = len(V.graph.graph_inputs) + len(V.graph.constants)
                        # release GIL to support multiple instances inference (in different threads of the same process)
                        self.prefix.splice("py::gil_scoped_release release;")

                    if config.abi_compatible:
                        self.prefix.splice(
                            f"""
                            auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args});
                            """
                        )
                    else:
                        # This looks dumb, but can avoid creating two versions of code in the AOTInductor runtime.
                        self.prefix.splice(
                            f"""
                            auto inputs = alloc_tensors_by_stealing_from_handles(input_handles, {num_args});
                            """
                        )

            if inputs_len != 0:
                for idx, input_key in enumerate(V.graph.graph_inputs.keys()):
                    if config.use_minimal_arrayref_interface:
                        self.prefix.writeline(
                            f"auto {input_key} = std::get<{idx}>(inputs);"
                        )
                        continue
                    # unwrap input tensor back to scalar
                    if isinstance(V.graph.graph_inputs[input_key], sympy.Expr):
                        from ..graph import may_get_constant_buffer_dtype

                        dtype = may_get_constant_buffer_dtype(
                            V.graph.graph_inputs[input_key]  # type: ignore[arg-type]
                        )
                        assert (
                            dtype is not None
                        ), "Fails to get the dtype of the sympy.Expr"
                        cpp_dtype = DTYPE_TO_CPP[dtype]
                        if config.abi_compatible:
                            self.codegen_tensor_item(
                                dtype, f"inputs[{idx}]", input_key, self.prefix
                            )
                        else:
                            self.prefix.writeline(
                                f"{cpp_dtype} {input_key} = inputs[{idx}].item<{cpp_dtype}>();"
                            )
                    else:
                        self.prefix.writeline(
                            f"auto {input_key} = std::move(inputs[{idx}]);"
                        )

            assert all(
                isinstance(v, torch.Tensor) for v in list(V.graph.constants.values())
            ), "Expect all constants to be Tensor"
            for idx, constants_key in enumerate(V.graph.constants.keys()):
                if V.graph.aot_mode:
                    # Weights are stored in constants_ and owned by RAIIAtenTensorHandle there.
                    # Don't call std::move here because it will cause constants_ to lose the ownership.
                    if config.abi_compatible:
                        self.prefix.writeline(
                            f"""auto {constants_key} = constants_->at({idx});"""
                        )
                    else:
                        self.prefix.writeline(
                            f"auto {constants_key} = *tensor_handle_to_tensor_pointer("
                            + f"""constants_->at({idx}));"""
                        )
                else:
                    # Append constants as inputs to the graph
                    constants_idx = inputs_len + idx
                    if config.abi_compatible:
                        self.prefix.writeline(
                            f"auto {constants_key} = std::move(inputs[{constants_idx}]);"
                        )
                    else:
                        self.prefix.writeline(
                            f"auto {constants_key} = inputs[{constants_idx}];"
                        )

            self.codegen_inputs(self.prefix, V.graph.graph_inputs)

            if V.graph.aot_mode:
                if not V.graph.is_const_graph:
                    if config.use_minimal_arrayref_interface:
                        # TODO: input shape checking for regular tensor interface as well?
                        self.codegen_input_numel_asserts()
                    else:
                        self.prefix.writeline("inputs.clear();")
                self.prefix.writeline(
                    "auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());"
                )

    def codegen_input_numel_asserts(self):
        for name, buf in V.graph.graph_inputs.items():
            if isinstance(buf, sympy.Expr):
                continue

            # comparing strides for 0 size tensor is tricky. Ignore them for now.
            if sympy_product(buf.get_size()) == 0:
                continue
            numel = buf.get_numel()
            self.prefix.writeline(f"assert_numel({name}, {numel});")

    def codegen_tensor_dtype_var_decl(self, code: IndentedBuffer, name):
        if config.abi_compatible:
            code.writeline(f"int32_t {name}_dtype;")
            code.writeline(
                "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype"
                f"({name}, &{name}_dtype));"
            )
        else:
            # Note that we don't have a corresponding class method from
            # the WrapperCodeGen since this method is used for asserting AOTI
            # cpp wrapper code.
            code.writeline(f"auto {name}_dtype = {name}.dtype();")

    def codegen_input_size_var_decl(self, code: IndentedBuffer, name):
        if config.abi_compatible:
            code.writeline(f"int64_t* {name}_size;")
            code.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes({name}, &{name}_size));"
            )
        else:
            super().codegen_input_size_var_decl(code, name)

    def codegen_input_stride_var_decl(self, code: IndentedBuffer, name):
        if config.abi_compatible:
            code.writeline(f"int64_t* {name}_stride;")
            code.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides({name}, &{name}_stride));"
            )
        else:
            super().codegen_input_stride_var_decl(code, name)

    def codegen_model_kernels(self):
        self.prefix.writeline("namespace {")
        self.prefix.writeline(
            "class AOTInductorModelKernels : public AOTInductorModelKernelsBase {"
        )
        self.prefix.writeline("  public:")
        declare_kernel = set(self.src_to_kernel.values())
        declare_kernel.update(
            entry[0] for entry in self.user_defined_kernel_cache.values()
        )
        if V.graph.const_module:
            declare_kernel.update(
                V.graph.const_module.wrapper_code.src_to_kernel.values()
            )
        for kernel in sorted(declare_kernel):
            self.prefix.writeline(
                maybe_hipify_code_wrapper(f"    CUfunction {kernel}{{nullptr}};")
            )
        self.prefix.writeline("};")
        self.prefix.writeline("}  // namespace")
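
    # The writelines above produce a declaration along these lines (the kernel name is
    # illustrative only):
    #   namespace {
    #   class AOTInductorModelKernels : public AOTInductorModelKernelsBase {
    #     public:
    #       CUfunction triton_poi_fused_add_0{nullptr};
    #   };
    #   }  // namespace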

    def codegen_model_constructor(self):
        """
        // Generated code example
        AOTInductorModel::AOTInductorModel()
            : AOTInductorModelBase(4, 1) {
        inputs_info_[0].name = "input0";
        inputs_info_[0].dtype = "torch.float16";
        ...
        constants_info_[0].name = "L__self___weight";
        constants_info_[0].dtype = at::kFloat;
        constants_info_[0].offset = 0;
        constants_info_[0].data_size = 8192;
        constants_info_[0].shape = {64, 32};
        constants_info_[0].stride = {32, 1};
        ...
        outputs_info_[0].name = "output0";
        outputs_info_[0].dtype = "torch.float16";
        }
        """

        num_inputs = len(V.graph.graph_inputs)
        num_outputs = len(V.graph.graph_outputs)
        num_constants = len(V.graph.constants)
        self.prefix.splice(
            f"""
            AOTInductorModel::AOTInductorModel(std::shared_ptr<ConstantMap> constants_map,
                                               std::shared_ptr<std::vector<ConstantHandle>> constants_array,
                                               const std::string& device_str,
                                               std::optional<std::string> cubin_dir)
                : AOTInductorModelBase({num_inputs}, {num_outputs}, {num_constants}, device_str, cubin_dir) {{
            """
        )

        with self.prefix.indent():
            for idx, (name, inp) in enumerate(V.graph.graph_inputs.items()):
                assert not isinstance(
                    inp, sympy.Expr
                ), f"input {name=} cannot be symbolic"
                self.write_input_output_info("inputs_info_", idx, name)

            all_cuda = all(
                V.graph.get_original_value_of_constant(name).is_cuda
                for name in V.graph.constants.keys()
                if name not in V.graph.folded_constants
            )
            for idx, name in enumerate(V.graph.constants.keys()):
                tensor = V.graph.get_original_value_of_constant(name)
                assert isinstance(tensor, torch.Tensor)
                self.prefix.writeline(f"""constants_info_[{idx}].name = "{name}";""")
                self.prefix.writeline(
                    f"constants_info_[{idx}].dtype = static_cast<int32_t>({self.codegen_dtype(tensor.dtype)});"
                )
                self.prefix.writeline(
                    f"constants_info_[{idx}].offset = {tensor.storage_offset()};"
                )

                # If the constants to serialize contain CPU tensors, we always align data_size to 64.
                # When loading the constants, the valid data depends on the size,
                # not the data_size, so there won't be a correctness issue.
                data_size = (
                    torch.ops.mkldnn._nbytes(tensor)
                    if tensor.is_mkldnn
                    else tensor.untyped_storage().nbytes()
                )
                self.prefix.writeline(
                    f"constants_info_[{idx}].data_size = {data_size if all_cuda else _align(data_size)};"
                )

                from_folded = "true" if name in V.graph.folded_constants else "false"
                self.prefix.writeline(
                    f"constants_info_[{idx}].from_folded = {from_folded};"
                )

                size_str = ", ".join([str(s) for s in tensor.size()])
                self.prefix.writeline(f"constants_info_[{idx}].shape = {{{size_str}}};")

                stride_str = ", ".join([str(s) for s in tensor.stride()])
                self.prefix.writeline(
                    f"constants_info_[{idx}].stride = {{{stride_str}}};"
                )
                self.prefix.writeline(
                    f"constants_info_[{idx}].layout = static_cast<int32_t>({self.codegen_layout(tensor.layout)});"
                )

                if tensor.is_mkldnn:
                    opaque_metadata_tensor = torch.ops.mkldnn._get_mkldnn_serialized_md(
                        tensor
                    )
                    assert (
                        opaque_metadata_tensor.dim() == 1
                    ), "Expect opaque_metadata_tensor to be 1-D"

                    opaque_metadata_list = opaque_metadata_tensor.tolist()
                    opaque_metadata_str = self.codegen_shape_tuple(opaque_metadata_list)
                    self.prefix.writeline(
                        f"constants_info_[{idx}].opaque_metadata = {opaque_metadata_str};"
                    )
                if name in V.graph.dynamo_flat_name_to_original_fqn:
                    original_fqn = V.graph.dynamo_flat_name_to_original_fqn.get(
                        name, name
                    )
                elif name in V.graph.allocated_constant_name:
                    original_fqn = V.graph.allocated_constant_name[name]
                else:
                    raise AssertionError("original_fqn must be set for constant")
                self.prefix.writeline(
                    f"""constants_info_[{idx}].original_fqn = "{original_fqn}";"""
                )
            self.prefix.writeline("update_constants_map(std::move(constants_map));")
            self.prefix.writeline("update_constants_array(std::move(constants_array));")

            def escape_string(x):
                return (
                    x.replace("\\", "\\\\")
                    .replace('"', '\\"')
                    .replace("\n", "\\n")
                    .replace("\t", "\\t")
                )

            self.prefix.writeline(
                f'in_spec_ = "{escape_string(config.aot_inductor.serialized_in_spec)}";'
            )
            self.prefix.writeline(
                f'out_spec_ = "{escape_string(config.aot_inductor.serialized_out_spec)}";'
            )

            for idx, output in enumerate(V.graph.graph_outputs):
                assert not isinstance(
                    output, sympy.Expr
                ), f"output {name=} cannot be symbolic"
                name = f"output{idx}"
                self.write_input_output_info("outputs_info_", idx, name)

            self.prefix.writeline(
                "this->kernels_ = std::make_unique<AOTInductorModelKernels>();"
            )

        self.prefix.writeline("}")

    def codegen_const_run_driver(self):
        """
        // Generated code example
        std::unordered_map<std::string, AtenTensorHandle> AOTInductorModel::const_run_impl(
            DeviceStreamType stream,
            AOTIProxyExecutorHandle proxy_executor,
            bool initialization
        ) {
            std::unordered_map<std::string, AtenTensorHandle> folded_constants_map;
            std::vector<AtenTensorHandle> output_handles;
            // build up output_handles over here.
            _const_run_impl(output_handles, stream, proxy_executor);
            // build up folded_constants_map
            return folded_constants_map;
        }
        """

        self.prefix.splice(
            """
            std::unordered_map<std::string, AtenTensorHandle> AOTInductorModel::const_run_impl(
                DeviceStreamType stream,
                AOTIProxyExecutorHandle proxy_executor,
                bool initialization
            ) {
            """
        )
        if not config.aot_inductor.use_runtime_constant_folding:
            self.prefix.splice(
                """
                    if (!initialization) {
                        std::cerr << "[WARNING] Calling constant_folding in model, but compiled with config: "
                                  << "aot_inductor.use_runtime_constant_folding=False\\n";
                    }
                    return {};
                }
                """
            )
            return

        with self.prefix.indent():
            # This is a mapping to the index of constant folding graph's output
            const_index_mapping: List[Optional[Tuple[int, str]]] = [None] * len(
                V.graph.const_output_index
            )
            for idx, (name, _) in enumerate(V.graph.constants.items()):
                if name in V.graph.const_output_index:
                    const_index_mapping[V.graph.const_output_index[name]] = (idx, name)  # type: ignore[call-overload]
            assert (
                None not in const_index_mapping
            ), "Not all constants get mapped for the constant folding graph."

            self.prefix.writeline(
                f"""
                std::unordered_map<std::string, AtenTensorHandle> folded_constants_map;
                folded_constants_map.reserve({len(const_index_mapping)});
                std::vector<AtenTensorHandle> output_handles({len(const_index_mapping)});
                """
            )

            self.prefix.splice(
                """
                // The below assignment of output_handles to constants is not used directly.
                // It's only used to memo the correspondence of handle and constants.
                """
            )

            for output_idx, (const_idx, _) in enumerate(const_index_mapping):  # type: ignore[misc]
                self.prefix.writeline(
                    f"output_handles[{output_idx}] = constants_->at({const_idx});"
                )

            self.prefix.writeline(
                "_const_run_impl(output_handles, stream, proxy_executor);"
            )

            for output_idx, (_, const_name) in enumerate(const_index_mapping):  # type: ignore[misc]
                self.prefix.writeline(
                    f'folded_constants_map["{const_name}"] = output_handles[{output_idx}];'
                )
            self.prefix.writeline("return folded_constants_map;")

        self.prefix.writeline("}")

    def generate(self, is_inference):
        if V.graph.aot_mode and not V.graph.is_const_graph:
            self.codegen_model_kernels()
            self.codegen_model_constructor()
            self.codegen_const_run_driver()
        self.write_wrapper_decl()
        return super().generate(is_inference)

    def finalize_prefix(self):
        cached_dtypes_buffer = IndentedBuffer()
        if config.abi_compatible:
            for dtype in self.used_cached_dtypes:
                cached_dtypes_buffer.writeline(f"CACHE_TORCH_DTYPE({dtype});")
            for device in self.used_cached_devices:
                cached_dtypes_buffer.writeline(f"CACHE_TORCH_DEVICE({device});")
            for layout in self.used_cached_layouts:
                cached_dtypes_buffer.writeline(f"CACHE_TORCH_LAYOUT({layout});")
        cached_dtypes_buffer.splice(self.prefix)
        self.prefix = cached_dtypes_buffer

    def define_kernel(
        self, name: str, kernel: str, metadata: Optional[str] = None, cuda=False
    ):
        self.header.splice(f"\n{kernel}\n")

    def codegen_scalar_to_tensor(self, output: str):
        name = f"scalar_to_tensor_{next(self.scalar_to_tensor_id)}"
        self.wrapper_call.writeline(
            f"RAIIAtenTensorHandle {name} = scalar_to_tensor_handle({output});"
        )
        return name

    def codegen_tensor_item(
        self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None
    ):
        assert (
            config.abi_compatible
        ), "codegen_tensor_item is only used for the ABI-compatible mode"
        dtype_str = str(dtype).split(".")[-1]
        writer = indented_buffer or self

        if dtype == torch.float16 or dtype == torch.bfloat16:
            scalar_tmp = f"{scalar}_tmp"
            writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar_tmp};")

            # need convert_arrayref_tensor_to_tensor for ArrayRefTensors
            tensor = f"convert_arrayref_tensor_to_tensor({tensor})"

            writer.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));"
            )
            writer.writeline(f"float {scalar} = float({scalar_tmp});")
        else:
            writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};")

            # need convert_arrayref_tensor_to_tensor for ArrayRefTensors
            tensor = f"convert_arrayref_tensor_to_tensor({tensor})"

            writer.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));"
            )
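
    # For a float32 scalar named "u0" read out of a buffer "buf0", the branch above
    # emits roughly (names illustrative):
    #   float u0;
    #   AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_float32(convert_arrayref_tensor_to_tensor(buf0), &u0));
    # fp16/bf16 values go through a "*_tmp" variable and are then widened to float.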

    @cache_on_self
    def get_output_refs(self):
        return [
            f"torch::tensor({x.codegen_reference(self.wrapper_call)})"
            if isinstance(x, ir.ShapeAsConstantBuffer) and not config.abi_compatible
            else x.codegen_reference(self.wrapper_call)
            for x in V.graph.graph_outputs
        ]

    def generate_return(self, output_refs: List[str]):
        cst_names = V.graph.constants.keys()
        arr_iface = (
            not V.graph.is_const_graph and config.use_minimal_arrayref_interface
        )  # For brevity.

        def use_thread_local_cached_output_tensor(idx, output):
            cached_output_name = f"cached_output_{next(self.cached_output_id)}"
            cache_type = "Array" if arr_iface else "Tensor"
            self.wrapper_call.writeline(
                f"thread_local ThreadLocalCachedOutput{cache_type}<std::decay_t<decltype({output})>> "
                f"{cached_output_name}({output});"
            )
            if arr_iface:
                self.wrapper_call.writeline(
                    f"{cached_output_name}.copy_data_from({output});"
                )
                output_entry = f"std::get<{idx}>(output_arrayref_tensors)"
                element_type = f"std::decay_t<decltype({output_entry}.data()[0])>"
                self.wrapper_call.writeline(
                    f"{output_entry} = {cached_output_name}.arrayref_tensor<{element_type}>();"
                )
            else:
                self.wrapper_call.writeline(
                    f"{cached_output_name}.copy_data_from({output});"
                )
                self.wrapper_call.writeline(
                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&output_handles[{idx}]));"
                )
                self.wrapper_call.writeline(
                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors({cached_output_name}.tensor(), "
                    f"output_handles[{idx}]));"
                )

        if arr_iface:
            self.wrapper_call.writeline(
                "AOTInductorModelOutputs output_arrayref_tensors;"
            )

        output2idx: Dict[str, int] = {}
        for idx, output in enumerate(output_refs):
            if output == self.none_str:
                continue

            is_constant_buffer = output in cst_names
            output_buffer = V.graph.graph_outputs[idx]
            if isinstance(output_buffer, ir.BaseView):
                output_storage = output_buffer.unwrap_view()
                if isinstance(output_storage.data, ir.ConstantBuffer):
                    is_constant_buffer = True

            if config.abi_compatible:
                if isinstance(output_buffer, ir.ShapeAsConstantBuffer):
                    # Need to wrap scalar into tensor as the main function returns a vector of tensors
                    output_tensor = self.codegen_scalar_to_tensor(output)
                    self.wrapper_call.writeline(
                        f"output_handles[{idx}] = {output_tensor}.release();"
                    )
                    continue
f"output_handles[{idx}] = {output_tensor}.release();" 1031 ) 1032 continue 1033 1034 output_is_tensor_handle_expr = ( 1035 f"std::is_same_v<std::decay_t<decltype({output})>," 1036 "RAIIAtenTensorHandle> || " 1037 f"std::is_same_v<std::decay_t<decltype({output})>," 1038 "AtenTensorHandle> || " 1039 f"std::is_same_v<std::decay_t<decltype({output})>," 1040 "ConstantHandle>" 1041 ) 1042 self.wrapper_call.writeline( 1043 f"if constexpr ({output_is_tensor_handle_expr}) {{" 1044 ) 1045 with self.wrapper_call.indent(): 1046 if arr_iface: 1047 cached_output_name = ( 1048 f"cached_output_{next(self.cached_output_id)}" 1049 ) 1050 output_value_type = f"std::decay_t<decltype(std::get<{idx}>(output_arrayref_tensors).data()[0])>" 1051 self.wrapper_call.writeline( 1052 f"thread_local RAIIAtenTensorHandle {cached_output_name};" 1053 ) 1054 if is_constant_buffer: 1055 # NOTE(return_constant): In some rare cases where we return 1056 # a constant, we have to return a copy of this constant, 1057 # because (1) constants are not owned by the Model instance 1058 # (2) constants remain the same cross inference runs, 1059 # assuming they are not updated at runtime Basically, we 1060 # cannot release or transfer the ownership of any original 1061 # constant to the user. 1062 self.wrapper_call.writeline( 1063 f"AtenTensorHandle {cached_output_name}_tmp;" 1064 ) 1065 self.wrapper_call.writeline( 1066 f"aoti_torch_clone({output}, &{cached_output_name}_tmp);" 1067 ) 1068 self.wrapper_call.writeline( 1069 f"{cached_output_name} = {cached_output_name}_tmp;" 1070 ) 1071 else: 1072 self.wrapper_call.writeline( 1073 f"{cached_output_name} = {output}.release();" 1074 ) 1075 self.wrapper_call.writeline( 1076 f"convert_handle_to_arrayref_tensor({cached_output_name}, " 1077 f"std::get<{idx}>(output_arrayref_tensors));" 1078 ) 1079 else: 1080 if is_constant_buffer: 1081 # See NOTE(return_constant) above. 1082 self.wrapper_call.writeline( 1083 f"aoti_torch_clone({output}, &output_handles[{idx}]);" 1084 ) 1085 else: 1086 if output in output2idx: 1087 src_idx = output2idx[output] 1088 self.wrapper_call.writeline( 1089 f"output_handles[{idx}] = output_handles[{src_idx}];" 1090 ) 1091 else: 1092 self.wrapper_call.writeline( 1093 f"output_handles[{idx}] = {output}.release();" 1094 ) 1095 self.wrapper_call.writeline("} else {") 1096 with self.wrapper_call.indent(): 1097 use_thread_local_cached_output_tensor(idx, output) 1098 self.wrapper_call.writeline("}") 1099 1100 else: 1101 assert ( 1102 not arr_iface 1103 ), "minimal ArrayRef interface is only supported in ABI-compatible mode" 1104 if is_constant_buffer: 1105 output_expr = f"{output}.clone()" 1106 # See NOTE(return_constant) above. 
                else:
                    output_expr = output
                self.wrapper_call.writeline(
                    f"output_handles[{idx}] = reinterpret_cast<AtenTensorHandle>("
                    + f"new at::Tensor({output_expr}));"
                )

            if output not in output2idx:
                output2idx[output] = idx
        if arr_iface:
            self.wrapper_call.writeline("return output_arrayref_tensors;")

    def generate_before_suffix(self, result):
        if not V.graph.is_const_graph:
            if V.graph.aot_mode:
                result.writeline("} // AOTInductorModel::run_impl")
            else:
                result.writeline("} // inductor_entry_impl")

    def generate_end(self, result):
        if V.graph.aot_mode:
            if V.graph.is_const_graph:
                result.writeline("} // AOTInductorModel::_const_run_impl")
            else:
                result.writeline("} // namespace aot_inductor")
                result.writeline("} // namespace torch")
            return

        # cpp entry function for JIT with cpp wrapper
        result.writeline("'''\n)")
        result.splice(
            f"""
            inductor_entry = CppWrapperCodeCache.load_pybinding(
                ["std::vector<AtenTensorHandle>"], cpp_wrapper_src, {self.cuda}, {len(V.graph.graph_outputs)})
            """
        )

        wrapper_body = "input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]"
        if V.graph.constants:
            # Append constants to the input args for cpp wrapper.
            # Python wrapper directly gets the value inside the wrapper call
            # as a global variable passed when calling exec(code, mod.__dict__, mod.__dict__).
            # For cpp wrapper, we need to pass this python value to the inductor_entry_impl function explicitly.
            assert all(
                isinstance(v, torch.Tensor) for v in list(V.graph.constants.values())
            ), "Expect all constants to be Tensor"
            constants_str = f"[{', '.join(V.graph.constants.keys())}]"
            wrapper_body += f"""
                    constants_tensor = {constants_str}
                    input_tensors.extend(constants_tensor)
            """
        # Convert vector of at::Tensor to vector of AtenTensorHandle.
        # If we pass at::Tensor, the compilation will be too slow.
        wrapper_body += """
                    input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors)
        """
        # Release the inputs for memory reuse.
1164 wrapper_body += """ 1165 args.clear() 1166 """ 1167 1168 # unwrap output tensor back to python scalar 1169 if all(x for x in self.output_is_tensor.values()): 1170 # If no ShapeAsConstantBuffer in the output, directly return the output as tensors 1171 outputs_str = "output_tensors" 1172 else: 1173 outputs = [ 1174 f"output_tensors[{i}]" 1175 if self.output_is_tensor[i] 1176 else f"output_tensors[{i}].item()" 1177 for i in range(len(V.graph.graph_outputs)) 1178 ] 1179 outputs_str = f"[{', '.join(outputs)}]" 1180 wrapper_body += f""" 1181 output_handles = f(input_handles) 1182 output_tensors = torch._C._aoti.alloc_tensors_by_stealing_from_void_ptrs(output_handles) 1183 return {outputs_str} 1184 """ 1185 1186 # Wrap the func to support setting result._boxed_call = True 1187 result.splice( 1188 f""" 1189 def _wrap_func(f): 1190 def g(args): 1191 {wrapper_body} 1192 return g 1193 1194 call = _wrap_func(inductor_entry) 1195 """ 1196 ) 1197 1198 def get_c_shim_func_name(self, kernel): 1199 if not config.abi_compatible or kernel.startswith("aoti_torch_"): 1200 return kernel 1201 1202 assert "::" in kernel, "Cpp kernel name: " + kernel + " does not contain '::'" 1203 kernel_tokens = kernel.split("::") 1204 kernel_suffix = kernel_tokens[-1] 1205 if kernel_suffix == "call": 1206 kernel_suffix = kernel_tokens[-2] 1207 1208 shim_fn = f"aoti_torch_{self.device}_{kernel_suffix}" 1209 return shim_fn 1210 1211 def generate_c_shim_extern_kernel_call(self, kernel, args): 1212 # In the abi_compatible mode, we call fallback aten ops through a C shim layer 1213 # Setting self.allow_stack_allocation to False because the exchange between 1214 # ArrayRefTensor and at::Tensor is still fragile. 1215 self.allow_stack_allocation = False 1216 1217 wrapped_args = [] 1218 1219 args_to_print_or_save = None 1220 debug_printer_manager = V.graph.wrapper_code.debug_printer 1221 if ( 1222 debug_printer_manager.debug_printer_level 1223 != IntermediateValueDebuggingLevel.OFF 1224 ): 1225 args_to_print_or_save = [] 1226 1227 for x in args: 1228 pieces = x.split(", ") 1229 for piece in pieces: 1230 # We only really *need* convert_arrayref_tensor_to_tensor for 1231 # ArrayRefTensors. The code flowing into here uses `0` for nullptr, 1232 # which convert_arrayref_tensor_to_tensor would blindly coerce to int, 1233 # so just avoid wrapping integers. 1234 # Name matching is to find tensor is hacky, but fixing all the 1235 # ArrayRefTensor issues is not a priority for now. 

    def generate_c_shim_extern_kernel_call(self, kernel, args):
        # In the abi_compatible mode, we call fallback aten ops through a C shim layer
        # Setting self.allow_stack_allocation to False because the exchange between
        # ArrayRefTensor and at::Tensor is still fragile.
        self.allow_stack_allocation = False

        wrapped_args = []

        args_to_print_or_save = None
        debug_printer_manager = V.graph.wrapper_code.debug_printer
        if (
            debug_printer_manager.debug_printer_level
            != IntermediateValueDebuggingLevel.OFF
        ):
            args_to_print_or_save = []

        for x in args:
            pieces = x.split(", ")
            for piece in pieces:
                # We only really *need* convert_arrayref_tensor_to_tensor for
                # ArrayRefTensors. The code flowing into here uses `0` for nullptr,
                # which convert_arrayref_tensor_to_tensor would blindly coerce to int,
                # so just avoid wrapping integers.
                # Matching on the name to find tensor args is hacky, but fixing all the
                # ArrayRefTensor issues is not a priority for now.
                if isinstance(piece, str) and piece.startswith(
                    ("buf", "arg", "wrap_with_raii_handle_if_needed")
                ):
                    # TODO: The current way to find a 'tensor' type arg is hacky also as mentioned above
                    # Find a more reliable way to detect tensor kernel args for extern kernel calls
                    if (
                        debug_printer_manager.debug_printer_level
                        != IntermediateValueDebuggingLevel.OFF
                    ):
                        if piece.startswith(("buf", "arg")):
                            args_to_print_or_save.append(piece)
                    piece = f"convert_arrayref_tensor_to_tensor({piece})"
                wrapped_args.append(piece)

        debug_printer_manager.set_printer_args(
            args_to_print_or_save, kernel, None, None
        )
        with debug_printer_manager:
            shim_fn = self.get_c_shim_func_name(kernel)
            self.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(wrapped_args)}));"
            )

    def generate_c_shim_extern_kernel_alloc(self, extern_kernel, args):
        # registered output buffer name
        name = extern_kernel.name
        output_handle_name = f"{name}_handle"
        self.writeline(f"AtenTensorHandle {output_handle_name};")
        output_arg = f"&{output_handle_name}"
        self.generate_c_shim_extern_kernel_call(
            extern_kernel.get_kernel_name(), args + [output_arg]
        )
        self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});")

    def generate_extern_kernel_alloc(self, extern_kernel, args):
        if config.abi_compatible:
            self.generate_c_shim_extern_kernel_alloc(extern_kernel, args)
        else:
            super().generate_extern_kernel_alloc(extern_kernel, args)

    def generate_c_shim_fallback_kernel(self, fallback_kernel, args):
        output_args = []
        output_raii_handles = []
        output_name_base = fallback_kernel.get_name()
        for idx, output in enumerate(fallback_kernel.outputs):
            if isinstance(output, ir.MultiOutput):
                # TODO: handle integer output (e.g., as in attention)
                name = f"{output.get_name()}"
                output_handle_name = f"{name}_handle"
                if output.indices:
                    assert (
                        output.indices[0][1] == idx
                    ), f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}"
                self.writeline(f"AtenTensorHandle {output_handle_name};")
                output_args.append(f"&{output_handle_name}")
                output_raii_handles.append(
                    f"RAIIAtenTensorHandle {name}({output_handle_name});"
                )
            elif isinstance(output, int):
                output_name = f"{output_name_base}_{idx}"
                self.writeline(f"int64_t {output_name} = {output};")
                output_args.append(f"&{output_name}")
            elif isinstance(output, sympy.Symbol):
                output_name = f"{output_name_base}_{idx}"
                self.writeline(f"auto {output_name} = {output};")
                output_args.append(f"&{output_name}")
            elif output is None:
                output_args.append("nullptr")
            else:
                raise NotImplementedError(f"unsupported type of {output=}")
        args = args + output_args
        self.generate_c_shim_extern_kernel_call(fallback_kernel.cpp_kernel_name, args)
        for raii_handle in output_raii_handles:
            self.writeline(raii_handle)

    def generate_fallback_kernel(self, fallback_kernel, args):
        if config.abi_compatible:
            self.generate_c_shim_fallback_kernel(fallback_kernel, args)
        else:
            super().generate_fallback_kernel(fallback_kernel, args)

    def generate_extern_kernel_out(
        self, kernel: str, out: str, out_view: Optional[str], args: List[str]
    ):
        if out_view:
            out_name = f"{out}_as_strided"
            self.writeline(f"auto {out_name} = {out_view};")
            args.insert(0, out_name)
        else:
            args.insert(0, out)

        if config.abi_compatible:
            self.generate_c_shim_extern_kernel_call(kernel, args)
        else:
            # TODO: add debug printing info for non-abi compatible mode extern kernel call
            self.writeline(self.wrap_kernel_call(kernel, args))

    def generate_scatter_fallback(
        self,
        output,
        inputs,
        cpp_kernel_name,
        python_kernel_name,
        src_is_tensor,
        reduce,
        kwargs,
    ):
        # No stack allocation when there is a fallback op
        self.allow_stack_allocation = False

        if config.abi_compatible:
            # call the ABI shim function instead of the ATen one
            cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name)
            # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
            cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
            inputs_wrapped = [
                f"convert_arrayref_tensor_to_tensor({x})"
                if isinstance(x, str)
                else str(x)
                for x in inputs
            ]
            line = f"{cpp_kernel_name}(convert_arrayref_tensor_to_tensor({output}), {','.join(inputs_wrapped)}"
        else:
            line = f"{cpp_kernel_name}({','.join(map(str, inputs))}"

        if python_kernel_name.startswith("aten.scatter_reduce"):
            line += f", {','.join(kwargs)}"
        else:
            if src_is_tensor:
                if reduce:
                    line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
            else:
                assert (
                    reduce is None
                ), "Expect reduce to be None for aten.scatter_ with scalar src"
        line += ");"
        self.writeline(line)

    def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
        # No stack allocation when there is a fallback op
        self.allow_stack_allocation = False

        # TODO: update aoti_torch_index_put_out in ir.py to use autogen out version
        if config.abi_compatible:
            # See the comment in codegen_reinterpret_view about why having something like
            # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the corresponding
            # tensor to be prematurely deallocated, thus the std::vector().data() trick here.
            indices_str = (
                "std::vector<AtenTensorHandle>{"
                + (
                    ", ".join(
                        [f"convert_arrayref_tensor_to_tensor({ind})" for ind in indices]
                    )
                )
                + "}.data()"
            )
            args = [
                f"convert_arrayref_tensor_to_tensor({x})",
                indices_str,
                str(len(indices)),
                f"convert_arrayref_tensor_to_tensor({values})",
                accumulate,
            ]
            args.insert(
                0, f"convert_arrayref_tensor_to_tensor({x})"
            )  # set x as the output tensor, this fallback mutates x.
        else:
            indices_str = (
                f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}"
            )
            args = [x, indices_str, values, accumulate]
            args.insert(0, x)  # set x as the output tensor, this fallback mutates x

        self.writeline(self.wrap_kernel_call(kernel, args))

    def add_benchmark_harness(self, output):
        if V.graph.aot_mode:
            return
        super().add_benchmark_harness(output)

    def codegen_sizevar(self, x: Expr) -> str:
        return self.expr_printer(V.graph.sizevars.simplify(x))

    def codegen_tuple_access(self, basename: str, name: str, index: str) -> str:
        if config.abi_compatible:
            # in the abi_compatible mode, outputs are returned via arguments
            return name
        else:
            return f"std::get<{index}>({basename})"

    def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
        parts = list(map(self.codegen_sizevar, shape))
        if len(parts) == 0:
            return "{}"
        if len(parts) == 1:
            return f"{{{parts[0]}, }}"
        return f"{{{', '.join(parts)}}}"
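
    # codegen_shape_tuple examples: () -> "{}", (s0,) -> "{s0, }" (the trailing comma keeps
    # it a one-element brace-init list), and (2, 3) -> "{2, 3}"; symbolic entries are
    # printed through codegen_sizevar / cexpr.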

    def codegen_dynamic_scalar(self, node):
        (data,) = (t.codegen_reference() for t in node.inputs)
        if config.abi_compatible:
            self.codegen_tensor_item(
                node.inputs[0].get_dtype(), data, f"{node.sym}_raw"
            )
        else:
            convert_type = DTYPE_TO_ATEN[node.inputs[0].get_dtype()].replace(
                "at::k", "to"
            )
            self.writeline(f"auto {node.sym}_raw = {data}.item().{convert_type}();")

        if len(node.keypath) == 0:
            self.writeline(f"auto {node.sym} = {node.sym}_raw;")
        elif len(node.keypath) == 1 and isinstance(node.keypath[0], ConvertIntKey):
            self.writeline(f"int64_t {node.sym} = {node.sym}_raw ? 1 : 0;")
        elif len(node.keypath) == 1 and isinstance(node.keypath[0], DivideByKey):
            # TODO: assert divisibility here
            self.writeline(
                f"int64_t {node.sym} = {node.sym}_raw / {node.keypath[0].divisor};"
            )
        else:
            raise AssertionError(f"unrecognized keypath {node.keypath}")

        # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
        self.unbacked_symbol_decls.add(str(node.sym))

    def can_stack_allocate_buffer(self, buffer):
        return (
            self.allow_stack_allocation
            and buffer.get_device().type == "cpu"
            and self.can_prove_buffer_has_static_shape(buffer)
            and ir.is_contiguous_strides_for_shape(
                buffer.get_stride(), buffer.get_size()
            )
        )

    def make_buffer_free(self, buffer):
        return (
            ""
            if isinstance(buffer.get_layout(), ir.MultiOutputLayout)
            or (V.graph.aot_mode and buffer.get_name() in self.stack_allocated_buffers)
            or (
                config.use_minimal_arrayref_interface
                and V.graph.aot_mode
                and buffer.get_name() in V.graph.graph_inputs
            )
            else f"{buffer.get_name()}.reset();"
        )

    def make_free_by_names(self, names_to_del: List[str]):
        return " ".join(f"{name}.reset();" for name in names_to_del)

    def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str):
        if config.abi_compatible:
            return f"auto {new_name} = std::move({old_name});  // reuse"
        else:
            return super().codegen_exact_buffer_reuse(old_name, new_name, del_line)

    def generate_profiler_mark_wrapper_call(self, stack):
        self.wrapper_call.writeline(
            'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef<c10::IValue>());'
        )

    def write_triton_header_once(self):
        pass

    def generate_start_graph(self):
        pass

    def generate_end_graph(self):
        pass

    def generate_inf_and_nan_checker(self, nodes):
        for buf in nodes.get_names():
            # TODO: Add buf name directly into check_inf_and_nan.
            self.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan({buf}));"
            )

    def codegen_device(self, device):
        if config.abi_compatible:
            self.used_cached_devices.add(device.type)
            return f"cached_torch_device_type_{device.type}, {device.index if device.index else 0}"
        else:
            return (
                f"c10::Device({DEVICE_TO_ATEN[device.type]}, {device.index})"
                if device.index is not None
                else f"{DEVICE_TO_ATEN[device.type]}"
            )

    def codegen_dtype(self, dtype):
        if config.abi_compatible:
            dtype_str = str(dtype).split(".")[-1]
            self.used_cached_dtypes.add(dtype_str)
            return f"cached_torch_dtype_{dtype_str}"
        else:
            return DTYPE_TO_ATEN[dtype]

    def codegen_layout(self, layout):
        if config.abi_compatible:
            layout_str = str(layout).split(".")[-1]
            self.used_cached_layouts.add(layout_str)
            return f"cached_torch_layout_{layout_str}"
        else:
            return LAYOUT_TO_ATEN[layout]
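
    # In ABI-compatible mode the three codegen_* helpers above return named cache slots,
    # e.g. codegen_dtype(torch.float32) -> "cached_torch_dtype_float32", and record the
    # use so that finalize_prefix() emits the matching CACHE_TORCH_DTYPE(float32); /
    # CACHE_TORCH_DEVICE(cpu); / CACHE_TORCH_LAYOUT(strided); declarations up front.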

    @functools.lru_cache(None)  # noqa: B019
    def codegen_int_array_var(
        self,
        int_array: str,
        writer=None,
        known_statically=False,
        graph=None,  # for per-graph caching
    ):
        # This is used for size/stride declaration
        # Because the memory planning is done in two passes (see the implementation
        # of self.generate), the writeline behavior is different in the two passes.
        # As a result, the emitted int array declarations may appear in a later
        # position of the generated code, so the second pass codegen should not
        # reuse int array declarations generated in the first pass
        if writer is None:
            # The first pass codegen uses `self` as the writer
            writer = self

        var = f"int_array_{next(self.int_array_id)}"
        ctype = "int64_t"
        if var not in self.declared_int_array_vars:
            self.declared_int_array_vars.add(var)
            if known_statically:
                writer.writeline(f"static constexpr {ctype} {var}[] = {int_array};")
            else:
                writer.writeline(f"const {ctype} {var}[] = {int_array};")
        return var
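
    # For an int_array string like "{8, 16}" this declares, e.g.,
    #   static constexpr int64_t int_array_0[] = {8, 16};
    # when the values are statically known, or a plain
    #   const int64_t int_array_1[] = {s0, 16};
    # when a symbolic size such as s0 (name illustrative) is involved.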

    def make_buffer_allocation(self, buffer):
        return self.make_allocation(
            buffer.get_name(),
            buffer.get_device(),
            buffer.get_dtype(),
            buffer.get_size(),
            buffer.get_stride(),
            buffer if self.can_stack_allocate_buffer(buffer) else None,
        )

    def make_allocation(
        self, name, device, dtype, shape, stride, buffer_if_can_stack_allocate=None
    ):
        orig_stride = stride
        device_str = self.codegen_device(device)
        dtype_code = self.codegen_dtype(dtype)
        size = self.codegen_shape_tuple(shape)
        stride = self.codegen_shape_tuple(orig_stride)
        if config.abi_compatible:
            size_array_var = self.codegen_int_array_var(
                size,
                self.wrapper_call,
                known_statically=self.is_statically_known_list_of_ints(shape),
                graph=self.get_codegened_graph(),
            )
            stride_array_var = self.codegen_int_array_var(
                stride,
                self.wrapper_call,
                known_statically=self.is_statically_known_list_of_ints(orig_stride),
                graph=self.get_codegened_graph(),
            )
            device_type, device_id = device_str.split(",")
            device_idx = "this->device_idx_" if V.graph.aot_mode else device_id
            if buffer_if_can_stack_allocate is not None:
                self.stack_allocated_buffers[name] = buffer_if_can_stack_allocate
                cpp_type = DTYPE_TO_CPP[dtype]
                numel = buffer_if_can_stack_allocate.get_numel()
                # Note: we don't zero storage because empty_strided doesn't zero either.
                self.wrapper_call.writeline(f"{cpp_type} {name}_storage[{numel}];")
                args = [
                    f"{name}_storage",
                    size_array_var,
                    stride_array_var,
                    device_type,
                    device_idx,
                ]
                return f"ArrayRefTensor<{cpp_type}> {name}({', '.join(args)});"

            args = [
                str(len(shape)),
                size_array_var,
                stride_array_var,
                dtype_code,
                device_type,
                device_idx,
                f"&{name}_handle",
            ]

            self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;")
            self.wrapper_call.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));"
            )

            return f"RAIIAtenTensorHandle {name}({name}_handle);"

        if V.graph.aot_mode and device_str.startswith("c10::Device("):
            tensor_device = f"{device_str.split(',')[0]}, this->device_idx_)"
        else:
            tensor_device = device_str

        if device.type == "cpu":
            return f"at::Tensor {name} = at::detail::empty_strided_cpu({size}, {stride}, {dtype_code});"
        if device.type == "cuda":
            return (
                f"at::Tensor {name} = at::detail::empty_strided_cuda("
                f"{size}, {stride}, {dtype_code}, c10::DeviceType::CUDA);"
            )
        return (
            f"{self.declare}{name} = {self.namespace}empty_strided("
            f"{size}, {stride}, at::TensorOptions({tensor_device}).dtype({dtype_code})){self.ending}"
        )
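
    # In ABI-compatible mode the heap-allocation path above emits code shaped like
    # (buffer name, rank, and arrays illustrative):
    #   AtenTensorHandle buf0_handle;
    #   AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(2, int_array_0, int_array_1,
    #       cached_torch_dtype_float32, cached_torch_device_type_cpu, 0, &buf0_handle));
    #   RAIIAtenTensorHandle buf0(buf0_handle);
    # while stack-allocatable CPU buffers instead get a local `float buf0_storage[N];`
    # wrapped in an ArrayRefTensor<float>.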

    def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str:
        if config.abi_compatible:
            size = self.codegen_shape_tuple(shape)
            stride = self.codegen_shape_tuple(stride)
            tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}"
            args = [
                name,
                self.expr_printer(offset),  # bytes not numel
                self.codegen_dtype(dtype),
                str(len(shape)),
                self.codegen_int_array_var(
                    size, self.wrapper_call, graph=self.get_codegened_graph()
                ),
                self.codegen_int_array_var(
                    stride, self.wrapper_call, graph=self.get_codegened_graph()
                ),
                f"&{tmp_name}",
            ]
            self.wrapper_call.writeline(f"AtenTensorHandle {tmp_name};")
            self.wrapper_call.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));"
            )
            return f"RAIIAtenTensorHandle({tmp_name})"

        return "alloc_from_pool({})".format(
            ", ".join(
                [
                    name,
                    self.expr_printer(offset),  # bytes not numel
                    self.codegen_dtype(dtype),
                    self.codegen_shape_tuple(shape),
                    self.codegen_shape_tuple(stride),
                ]
            )
        )

    def codegen_reinterpret_view(
        self, data, size_list, stride_list, offset, writer, dtype=None
    ) -> str:
        dim = str(len(size_list))
        original_offset = offset
        size = self.codegen_shape_tuple(size_list)
        stride = self.codegen_shape_tuple(stride_list)
        offset = self.codegen_sizevar(offset)
        call_strs = []
        if config.abi_compatible:
            final_tmp_name = None
            final_tmp_name_is_RAIIAtenTensorHandle = False

            def create_reinterpret_call():
                tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}"
                args = [
                    f"{data.get_name()}",
                    dim,
                    self.codegen_int_array_var(
                        size,
                        writer,
                        known_statically=self.is_statically_known_list_of_ints(
                            size_list
                        ),
                        graph=self.get_codegened_graph(),
                    ),
                    self.codegen_int_array_var(
                        stride,
                        writer,
                        known_statically=self.is_statically_known_list_of_ints(
                            stride_list
                        ),
                        graph=self.get_codegened_graph(),
                    ),
                    offset,
                ]
                call_str = (
                    f"auto {tmp_name} = reinterpret_tensor_wrapper({', '.join(args)});"
                )
                return tmp_name, call_str

            def create_dtypeview_call(reinterpret_call):
                tmp_AtenTensorHandle = (
                    f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}"
                )
                call_strs = [f"AtenTensorHandle {tmp_AtenTensorHandle};"]
                dtype_name = str(dtype).split(".")[-1]
                device_name = "cuda" if data.layout.device.type == "cuda" else "cpu"
                get_dtype_function = f"aoti_torch_dtype_{dtype_name}"
                dtypeview_function = f"aoti_torch_{device_name}_view_dtype"
                call_strs.append(
                    f"AOTI_TORCH_ERROR_CODE_CHECK({dtypeview_function}"
                    f"({reinterpret_call}, {get_dtype_function}(), &{tmp_AtenTensorHandle}));"
                )
                tmp_RAIIAtenTensorHandle = (
                    f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}_handle"
                )
                call_strs.append(
                    f"RAIIAtenTensorHandle {tmp_RAIIAtenTensorHandle}({tmp_AtenTensorHandle});"
                )
                return tmp_RAIIAtenTensorHandle, call_strs

            if (
                size_list == data.layout.size
                and stride_list == data.layout.stride
                and original_offset == data.layout.offset
            ):
                # pure dtype view
                if dtype is not None and dtype != data.dtype:
                    tmp_output_name, tmp_call_strs = create_dtypeview_call(
                        data.get_name()
                    )
                    call_strs.extend(tmp_call_strs)
                    final_tmp_name = tmp_output_name
                    final_tmp_name_is_RAIIAtenTensorHandle = True
                else:
                    return f"{data.get_name()}"
            else:
                # first create a reinterpret view
                final_tmp_name, reinterpret_call = create_reinterpret_call()
                call_strs.append(reinterpret_call)

                if dtype is not None and dtype != data.dtype:
                    # then wrap it in a dtype view
                    final_tmp_name, tmp_call_strs = create_dtypeview_call(
                        reinterpret_call
                    )
                    call_strs.extend(tmp_call_strs)

            # Because the memory planning is done in two passes (see the implementation
            # of self.generate), the writeline behavior is different in the two passes.
            if writer is None:
                writer = self
            writer.writelines(call_strs)
            if (
                self.can_stack_allocate_buffer(data)
                and self.is_statically_known_list_of_ints(size_list)
                and self.is_statically_known_list_of_ints(stride_list)
                and ir.is_contiguous_strides_for_shape(stride_list, size_list)
            ):
                return final_tmp_name

            # NB, the return handle here represents a temporary tensor, which will be automatically
            # released.
            # Here's a sample usage in the cpp wrapper code:
            # ```
            # aoti_torch_addmm_out(
            #     buf1,
            #     arg1_1,
            #     RAIIAtenTensorHandle(tmp_tensor_handle_0),
            #     buf0,
            #     1L,
            #     1L));
            # ```
            # RAIIAtenTensorHandle(tmp_tensor_handle_0) will be released after the call to addmm_out.
            # This could be problematic when it's used in a different pattern, for example:
            # ```
            # AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6};
            # aoti_torch_proxy_executor_call_function(..., tensor_args);
            # ```
            # RAIIAtenTensorHandle(tmp_tensor_handle_2) will be invalid when it's used in the latter
            # kernel call.
            #
            # This is solved by updating the proxy_executor invocation to
            # ```
            # aoti_torch_proxy_executor_call_function(...,
            #     std::vector<AtenTensorHandle>{
            #         RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6
            #     }.data()
            # );
            # ```
            if not final_tmp_name_is_RAIIAtenTensorHandle:
                return f"wrap_with_raii_handle_if_needed({final_tmp_name})"
            else:
                return final_tmp_name
        else:
            args = [data.get_name(), size, stride, offset]
            return f"reinterpret_tensor({', '.join(args)})"

    def codegen_device_copy(self, src, dst):
        if config.abi_compatible:
            # aoti_torch_tensor_copy_ takes AtenTensorHandle as input,
            # while stack-allocation results in ArrayRefTensor,
            # so disable stack allocation here
            self.allow_stack_allocation = False
            self.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_tensor_copy_(expensive_copy_to_tensor_if_needed({src}), {dst}));"
            )
        else:
            self.writeline(f"{dst}.copy_({src});")

    def codegen_multi_output(self, name, value):
        # in the abi_compatible mode, outputs are retrieved by passing
        # output pointers, so we skip its codegen here.
        if not config.abi_compatible:
            super().codegen_multi_output(name, value)

    def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
        for inner_input, outer_input in zip(subgraph.graph.graph_inputs, outer_inputs):
            if config.abi_compatible:
                # in ABI-compatible mode, we copy the underlying at::Tensor of the conditional
                # input (outer_input) into another at::Tensor to be used as a subgraph input
                # (inner_input) in the nested scope. we can't std::move here, as the codegened
                # outer input may be an expression / rvalue (e.g., reinterpret_view(x)), so we
                # can't necessarily std::move it back to the origin (x).
                self.writeline(f"AtenTensorHandle {inner_input}_handle;")
                self.writeline(
                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({outer_input}, &{inner_input}_handle));"
                )
                self.writeline(
                    f"RAIIAtenTensorHandle {inner_input}({inner_input}_handle);"
                )
            else:
                self.writeline(
                    f"{self.declare}{inner_input} = {outer_input}{self.ending}"
                )

    def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs):
        for inner_output, outer_output in zip(
            subgraph.graph.graph_outputs, outer_outputs
        ):
            src = inner_output.codegen_reference()
            if config.abi_compatible:
                # in ABI-compatible mode, we need to std::move the subgraph output (inner_output)
                # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy
                # constructor is deleted.
                src = f"std::move({src})"
                # in case the outer_output carried a value
                # before (e.g., in the while_loop codegen)
                self.writeline(f"{outer_output}.reset();")
            self.writeline(f"{outer_output} = {src}{self.ending}")
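
    # The prefix/suffix pair above brackets every subgraph call in ABI-compatible
    # mode. As a rough illustration (buffer names here are made up), a subgraph
    # input and output end up looking like:
    # ```
    # AtenTensorHandle subgraph_in_handle;
    # AOTI_TORCH_ERROR_CODE_CHECK(
    #     aoti_torch_assign_tensors_out(buf2, &subgraph_in_handle));
    # RAIIAtenTensorHandle subgraph_in(subgraph_in_handle);
    # ...
    # buf3.reset();
    # buf3 = std::move(subgraph_out);
    # ```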
1872 src = f"std::move({src})" 1873 # in case the outer_output carried a value 1874 # before (e.g., in the while_loop codegen) 1875 self.writeline(f"{outer_output}.reset();") 1876 self.writeline(f"{outer_output} = {src}{self.ending}") 1877 1878 def codegen_conditional(self, conditional): 1879 name = conditional.get_name() 1880 outer_inputs = [f"{buf.codegen_reference()}" for buf in conditional.operands] 1881 if config.abi_compatible: 1882 outer_outputs = [] 1883 for out in conditional.outputs: 1884 # in ABI-compatible mode, ir.MultiOutput is not codegened, 1885 # hence pre-declare output variables directly and separately 1886 self.writeline(f"RAIIAtenTensorHandle {out.get_name()};") 1887 outer_outputs.append(out.get_name()) 1888 1889 if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer): 1890 # in ABI-compatible mode, we need to use the ABI shim function 1891 # to extract a C++ bool from the unrelying scalar bool Tensor 1892 predicate = f"{conditional.predicate.get_name()}_scalar" 1893 self.codegen_tensor_item( 1894 torch.bool, 1895 conditional.predicate.codegen_reference(), 1896 predicate, 1897 ) 1898 else: 1899 # the predicate is not a Tensor: SymBool or Python bool 1900 predicate = conditional.predicate.codegen_reference() 1901 else: 1902 # in non-ABI-compatible mode, we can codegen the conditional outputs 1903 # as array of at::Tensor instances, as the ir.MultiOutput is codegened 1904 outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))] 1905 self.writeline(f"at::Tensor {name}[{len(conditional.outputs)}];") 1906 predicate = f"{conditional.predicate.codegen_reference()}" 1907 if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer): 1908 # move the Tensor predicate to host 1909 predicate = f"{predicate}.item<bool>()" 1910 1911 self.writeline(f"if ({predicate}) {{") 1912 self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph)) 1913 self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs) 1914 self.writeline(ExitSubgraphLine(self)) 1915 self.writeline("} else {") 1916 self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph)) 1917 self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs) 1918 self.writeline(ExitSubgraphLine(self)) 1919 self.writeline("}") 1920 1921 def codegen_while_loop(self, while_loop): 1922 name = while_loop.get_name() 1923 outer_carried_inputs = [ 1924 buf.codegen_reference() for buf in while_loop.carried_inputs 1925 ] 1926 outer_additional_inputs = [ 1927 buf.codegen_reference() for buf in while_loop.additional_inputs 1928 ] 1929 cond_result_name = f"{name}_cond_result" 1930 1931 if config.abi_compatible: 1932 self.writeline(f"RAIIAtenTensorHandle {cond_result_name};") 1933 1934 cond_outer_inputs = [] 1935 for inp, out in zip(outer_carried_inputs, while_loop.outputs): 1936 # in ABI-compatible mode, the carried inputs are codegened 1937 # as buffers outside the while loop and set to the initial 1938 # values. at the end of each while_loop iteration, they 1939 # will be assined the carried values. 

    def codegen_while_loop(self, while_loop):
        name = while_loop.get_name()
        outer_carried_inputs = [
            buf.codegen_reference() for buf in while_loop.carried_inputs
        ]
        outer_additional_inputs = [
            buf.codegen_reference() for buf in while_loop.additional_inputs
        ]
        cond_result_name = f"{name}_cond_result"

        if config.abi_compatible:
            self.writeline(f"RAIIAtenTensorHandle {cond_result_name};")

            cond_outer_inputs = []
            for inp, out in zip(outer_carried_inputs, while_loop.outputs):
                # in ABI-compatible mode, the carried inputs are codegened
                # as buffers outside the while loop and set to the initial
                # values. at the end of each while_loop iteration, they
                # will be assigned the carried values.
                out_name = out.get_name()
                self.writeline(f"AtenTensorHandle {out_name}_handle;")
                self.writeline(
                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({inp}, &{out_name}_handle));"
                )
                self.writeline(f"RAIIAtenTensorHandle {out_name}({out_name}_handle);")
                cond_outer_inputs.append(out_name)

            # additional inputs will be assigned within the while_loop
            # iteration directly from the corresponding outer graph buffers
            cond_outer_inputs.extend(outer_additional_inputs)
        else:
            self.writeline(f"at::Tensor {cond_result_name};")
            self.writeline(f"at::Tensor {name}[{len(outer_carried_inputs)}];")
            for i, inp in enumerate(outer_carried_inputs):
                # set the initial state before the loop
                self.writeline(f"{name}[{i}] = {inp};")

            cond_outer_inputs = [
                *[f"{name}[{i}]" for i in range(len(outer_carried_inputs))],
                *outer_additional_inputs,
            ]

        cond_outer_outputs = [cond_result_name]
        body_outer_inputs = list(cond_outer_inputs)
        body_outer_outputs = body_outer_inputs[: len(outer_carried_inputs)]

        self.writeline("while (1) {")
        self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph))
        self.codegen_subgraph(
            while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
        )

        if config.abi_compatible:
            cond_result = f"{cond_result_name}_scalar"
            self.codegen_tensor_item(torch.bool, cond_result_name, cond_result)
        else:
            cond_result = f"{cond_result_name}.item<bool>()"
        self.writeline(f"if (!{cond_result}) break;")

        self.writeline(ExitSubgraphLine(self))
        self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
        self.codegen_subgraph(
            while_loop.body_subgraph, body_outer_inputs, body_outer_outputs
        )
        self.writeline(ExitSubgraphLine(self))
        self.writeline("}")
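
    # codegen_while_loop lowers torch.while_loop to a plain C++ loop. The emitted
    # skeleton looks roughly like this (condition/body contents elided, names
    # illustrative):
    # ```
    # while (1) {
    #     // cond_subgraph writes into <name>_cond_result
    #     if (!<name>_cond_result_scalar) break;
    #     // body_subgraph updates the carried buffers in place
    # }
    # ```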

    def generate_extern_kernel_args_decl_if_needed(
        self, op_overload, raw_args, output_args
    ):
        arg_types = [x.real_type for x in op_overload._schema.arguments]
        return_types = [x.type for x in op_overload._schema.returns]

        new_tensor_args = []
        new_int_args = []

        def fill_args(arg, arg_type):
            static_arg_types = (
                torch.FloatType,
                torch.BoolType,
                torch.StringType,
                torch.Type,
                torch.DeviceObjType,
            )
            inductor_tensor_buffers = (
                ir.Buffer,
                ir.ReinterpretView,
            )

            if isinstance(arg_type, torch.TensorType):
                assert isinstance(arg, inductor_tensor_buffers), f"got {type(arg)}"
                new_tensor_args.append(f"{arg.codegen_reference()}")
            elif isinstance(arg_type, torch.IntType):
                # int
                new_int_args.append(str(arg))
            elif isinstance(arg_type, torch.SymIntType):
                # SymInt
                expr = arg.node.expr if isinstance(arg, torch.SymInt) else arg
                new_int_args.append(self.expr_printer(expr))
            elif isinstance(arg_type, torch.NumberType):
                # Scalar of type int
                assert isinstance(arg, (int, float, bool))
                # Only treat int Scalar as dynamic
                if isinstance(arg, int):
                    new_int_args.append(str(arg))
            elif isinstance(arg_type, torch.ListType):
                assert isinstance(arg, (list, tuple))

                # List[Tensor]
                if isinstance(arg_type.getElementType(), torch.TensorType):
                    new_tensor_args.extend([f"{a.codegen_reference()}" for a in arg])
                # List[Optional[Tensor]]
                elif isinstance(
                    arg_type.getElementType(), torch.OptionalType
                ) and isinstance(
                    arg_type.getElementType().getElementType(), torch.TensorType
                ):
                    new_tensor_args.extend(
                        [f"{a.codegen_reference()}" for a in arg if a is not None]
                    )
                # List[int]
                elif isinstance(arg_type.getElementType(), torch.IntType):
                    new_int_args.extend([str(a) for a in arg])
                # List[SymInt]
                elif isinstance(arg_type.getElementType(), torch.SymIntType):
                    expressions = [
                        a.node.expr if isinstance(a, torch.SymInt) else a for a in arg
                    ]
                    new_int_args.extend(
                        [self.expr_printer(expr) for expr in expressions]
                    )
                # List[Scalar]
                elif isinstance(arg_type.getElementType(), torch.NumberType):
                    # Only treat int Scalar as dynamic
                    is_int_type = [isinstance(a, int) for a in arg]
                    if any(is_int_type):
                        assert all(
                            is_int_type
                        ), "AOTInductor only supports int scalars of the same type"
                        new_int_args.extend([str(a) for a in arg])
                else:
                    assert isinstance(
                        arg_type.getElementType(), static_arg_types  # type: ignore[arg-type]
                    ), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
            else:
                assert isinstance(
                    arg_type, static_arg_types  # type: ignore[arg-type]
                ), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"

        for arg, arg_type in zip(raw_args, arg_types):
            if arg is not None:
                if isinstance(arg_type, torch.OptionalType):
                    fill_args(arg, arg_type.getElementType())
                else:
                    fill_args(arg, arg_type)

        def fill_output_arg(arg, return_type):
            if isinstance(return_type, torch.TensorType):
                self.writeline(f"AtenTensorHandle {arg}_handle; // output buffer")
                self.writeline(
                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));"
                )
                self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);")
                new_tensor_args.append(f"{arg}")
            elif isinstance(return_type, torch.SymIntType):
                raise NotImplementedError("NYI support for return type: SymInt")
            elif isinstance(return_type, torch.ListType) and isinstance(
                return_type.getElementType(), torch.SymIntType
            ):
                raise NotImplementedError("NYI support for return type: List[SymInt]")
            else:
                raise AssertionError(f"Unsupported return type found: {return_type}")

        # TODO: Only support tensor(s) returns for now, SymInt is not implemented yet
        for return_type in return_types:
            if isinstance(return_type, (torch.TensorType)):
                pass
            elif isinstance(return_type, torch.OptionalType):
                assert isinstance(return_type.getElementType(), torch.TensorType)
            elif isinstance(return_type, torch.ListType):
                assert isinstance(return_type.getElementType(), torch.TensorType)
            else:
                raise NotImplementedError(
                    f"return type {return_type} is not yet supported."
                )

        for output_arg in output_args:
            assert (
                output_arg is not None
            ), "Optional return types are not yet supported"
            if isinstance(output_arg, (list, tuple)):
                for out in output_arg:
                    fill_output_arg(out, torch.TensorType.get())
            else:
                fill_output_arg(output_arg, torch.TensorType.get())

        return new_tensor_args, new_int_args
2105 ) 2106 2107 for output_arg in output_args: 2108 assert output_arg is not None, "Optional return types are not yet supported" 2109 if isinstance(output_arg, (list, tuple)): 2110 for out in output_arg: 2111 fill_output_arg(out, torch.TensorType.get()) 2112 else: 2113 fill_output_arg(output_arg, torch.TensorType.get()) 2114 2115 return new_tensor_args, new_int_args 2116 2117 def generate_extern_kernel_alloc_and_find_schema_if_needed( 2118 self, 2119 buf_name: str, 2120 python_kernel_name: str, 2121 cpp_kernel_name: str, 2122 codegen_args: List[str], 2123 cpp_op_schema: str, 2124 cpp_kernel_key: str, 2125 cpp_kernel_overload_name: str = "", 2126 op_overload: Optional[torch._ops.OpOverload] = None, 2127 raw_args=None, 2128 outputs=None, 2129 ): 2130 # No stack allocation when there is a fallback op 2131 self.allow_stack_allocation = False 2132 2133 def extract_output_name(out): 2134 if out is None: 2135 # Because out is not a MultiOutput, we assume the kernel returns a single output 2136 return [buf_name] 2137 elif isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)): 2138 return out.get_name() 2139 elif isinstance(out, (list, tuple)): 2140 return type(out)(extract_output_name(o) for o in out) 2141 else: 2142 raise AssertionError(f"Unexpected output: {type(out)}") 2143 2144 # output_args has the same pytree structure as outputs 2145 output_args = None 2146 if config.abi_compatible: 2147 output_args = extract_output_name(outputs) 2148 if isinstance(output_args, str): 2149 output_args = [output_args] 2150 2151 if V.graph.aot_mode and config.abi_compatible: 2152 assert op_overload is not None 2153 assert raw_args is not None 2154 assert outputs is not None 2155 2156 return self.generate_extern_kernel_alloc_and_find_schema_if_needed_with_proxy_executor( 2157 cpp_kernel_key, 2158 op_overload, 2159 raw_args, 2160 output_args, 2161 ) 2162 else: 2163 return self.generate_extern_kernel_alloc_and_find_schema_if_needed_jit( 2164 buf_name, 2165 python_kernel_name, 2166 cpp_kernel_name, 2167 codegen_args, 2168 cpp_op_schema, 2169 cpp_kernel_key, 2170 cpp_kernel_overload_name, 2171 op_overload, 2172 raw_args, 2173 output_args, 2174 ) 2175 2176 def generate_scoped_gil_acquire(self, declarations_before_scope, lines_in_scope): 2177 scoped_lines = IndentedBuffer() 2178 for declaration in declarations_before_scope: 2179 scoped_lines.writeline(declaration) 2180 2181 scoped_lines.writeline("{") 2182 with scoped_lines.indent(): 2183 scoped_lines.writeline("py::gil_scoped_acquire acquire;") 2184 scoped_lines.writelines(lines_in_scope.split("\n")) 2185 scoped_lines.writelines("}") 2186 return scoped_lines._lines 2187 2188 def load_custom_op_wrapper(self): 2189 # TODO: need to support control flow 2190 if self.custom_op_wrapper_loaded: 2191 return 2192 2193 lines = """ 2194RAIIPyObject codecache_module(PyImport_ImportModule("torch._inductor.codecache")); 2195if (codecache_module.get() == NULL) { 2196 throw std::runtime_error("Failed to load torch._inductor.codecache"); 2197} 2198custom_op_wrapper = PyObject_GetAttrString(codecache_module, "custom_op_wrapper"); 2199if (custom_op_wrapper.get() == NULL) { 2200 throw std::runtime_error("Failed to load torch._inductor.codecache.custom_op_wrapper"); 2201}""" 2202 2203 declarations_before_scope = ["RAIIPyObject custom_op_wrapper;"] 2204 scope_gil_acquire = self.generate_scoped_gil_acquire( 2205 declarations_before_scope, lines 2206 ) 2207 self.writelines(scope_gil_acquire) 2208 2209 self.custom_op_wrapper_loaded = True 2210 2211 def generate_py_arg(self, py_args_var, 

    def generate_py_arg(self, py_args_var, idx, raw_arg, arg_type):
        def generate_py_arg_inner(lines, raw_arg, arg_type):
            if raw_arg is None:
                # Py_None is a singleton, so we have to explicitly incref it here
                lines.append("Py_INCREF(Py_None);\n")
                return "Py_None"
            elif isinstance(arg_type, torch.TensorType):
                # Store AtenTensorHandle as void*
                base_handle = raw_arg.codegen_reference()
                (
                    tmp_raii_handle_var,
                    tmp_raii_handle_var_decl,
                ) = self.create_tmp_raii_handle_var(base_handle)
                if tmp_raii_handle_var:
                    lines.append(tmp_raii_handle_var_decl)
                    base_handle = tmp_raii_handle_var
                return f"PyCapsule_New(reinterpret_cast<void*>({base_handle}.get()), NULL, NULL)"
            elif isinstance(arg_type, torch.OptionalType):
                return generate_py_arg_inner(lines, raw_arg, arg_type.getElementType())
            elif isinstance(arg_type, torch.IntType):
                # int
                return f"PyLong_FromLongLong({raw_arg})"
            elif isinstance(arg_type, torch.SymIntType):
                # SymInt
                expr = (
                    raw_arg.node.expr if isinstance(raw_arg, torch.SymInt) else raw_arg
                )
                return f"PyLong_FromLongLong({self.expr_printer(expr)})"
            elif isinstance(arg_type, torch.FloatType):
                return f"PyFloat_FromDouble({raw_arg})"
            elif isinstance(arg_type, torch.BoolType):
                return f"PyBool_FromLong({1 if raw_arg else 0})"
            elif isinstance(arg_type, torch.StringType):
                return f'PyUnicode_FromString("{raw_arg}")'
            elif isinstance(arg_type, torch.NumberType):
                # Union[bool, int, float, complex]
                # torch/_prims_common/__init__.py
                if isinstance(raw_arg, int):
                    return f"PyLong_FromLongLong({raw_arg})"
                elif isinstance(raw_arg, float):
                    return f"PyFloat_FromDouble({raw_arg})"
                elif isinstance(raw_arg, bool):
                    return f"PyBool_FromLong({1 if raw_arg else 0})"
                elif isinstance(raw_arg, complex):
                    return f"PyComplex_FromDoubles({raw_arg.real}, {raw_arg.imag})"
                elif isinstance(raw_arg, torch.SymInt):
                    expr = raw_arg.node.expr
                    return f"PyLong_FromLongLong({self.expr_printer(expr)})"
                else:
                    raise NotImplementedError(
                        f"arg type {arg_type} with raw_arg {raw_arg}, {type(raw_arg)} is not yet supported by custom_op_wrapper"
                    )
            elif isinstance(raw_arg, torch.dtype):
                # dtype
                self.include_extra_header("torch/csrc/DynamicTypes.h")
                return f"Py_NewRef(torch::getTHPDtype(static_cast<c10::ScalarType>({self.codegen_dtype(raw_arg)})))"
            else:
                raise NotImplementedError(
                    f"arg type {arg_type} is not yet supported by custom_op_wrapper"
                )

        lines = []
        if isinstance(arg_type, torch.ListType):
            assert isinstance(raw_arg, (list, tuple)), str(raw_arg) + " is not a list"
            lines.append(
                f"PyObject* {py_args_var}_{idx} = PyList_New({len(raw_arg)});\n"
            )
            for i, elem in enumerate(raw_arg):
                lines.append(
                    f"PyList_SetItem({py_args_var}_{idx}, {i}, {generate_py_arg_inner(lines, elem, arg_type.getElementType())});\n"
                )
            lines.append(
                f"PyTuple_SetItem({py_args_var}, {idx}, {py_args_var}_{idx});\n"
            )
        else:
            lines.append(
                f"PyTuple_SetItem({py_args_var}, {idx}, {generate_py_arg_inner(lines, raw_arg, arg_type)});\n"
            )
        return "".join(lines)

    def generate_extern_kernel_alloc_and_find_schema_if_needed_jit(
        self,
        buf_name: str,
        python_kernel_name: str,
        cpp_kernel_name: str,
        codegen_args: List[str],
        cpp_op_schema: str,
        cpp_kernel_key: str,
        cpp_kernel_overload_name: str = "",
        op_overload: Optional[torch._ops.OpOverload] = None,
        raw_args=None,
        output_args: Optional[List[str]] = None,
    ):
        if not config.abi_compatible:
            # Will update this to use an OSS version ProxyExecutor
            if cpp_kernel_key not in self.extern_call_ops:
                self.writeline(
                    f"static auto op_{cpp_kernel_key} = c10::Dispatcher::singleton()"
                )
                self.writeline(
                    f'\t.findSchemaOrThrow("{cpp_kernel_name}", "{cpp_kernel_overload_name}")'
                )
                self.writeline(f"\t.typed<{cpp_op_schema}>();")
                self.extern_call_ops.add(cpp_kernel_key)

            self.writeline(
                f"auto {buf_name} = op_{cpp_kernel_key}.call({', '.join(codegen_args)});"
            )
        else:
            # In the JIT mode, because of the ABI-compatible requirement, we can't directly call
            # c10::Dispatcher to find the custom op and call it. Instead, we go back to Python
            # to invoke this custom op.
            self.load_custom_op_wrapper()

            assert output_args is not None, "output_args should not be None"
            num_args = len(raw_args)
            py_args_var = f"py_args_{next(self.arg_var_id)}"
            # First arg is always the python op name
            lines = f"""
RAIIPyObject {py_args_var}(PyTuple_New({num_args+1}));
if ({py_args_var}.get() == NULL) {{
    throw std::runtime_error("PyTuple_New {py_args_var} failed");
}}
PyTuple_SetItem({py_args_var}, 0, PyUnicode_FromString("{python_kernel_name}"));
"""

            assert op_overload is not None, "op_overload should not be None"

            for idx, (raw_arg, schema_arg) in enumerate(
                zip(raw_args, op_overload._schema.arguments)
            ):
                lines += self.generate_py_arg(
                    py_args_var, idx + 1, raw_arg, schema_arg.real_type
                )

            lines += f"""
// Call the custom op in Python
RAIIPyObject py_{buf_name}(PyObject_CallObject(custom_op_wrapper, {py_args_var}));
if (py_{buf_name}.get() == NULL) {{
    throw std::runtime_error("PyObject_CallObject {python_kernel_name} failed");
}}"""

            if len(output_args) == 1:
                # result is a single tensor
                lines += f"""
{output_args[0]} = reinterpret_cast<AtenTensorHandle>(PyCapsule_GetPointer(py_{buf_name}.get(), NULL));"""
            else:
                # result is a tuple of tensors
                for idx, output_arg in enumerate(output_args):
                    lines += f"""
{output_arg} =
    reinterpret_cast<AtenTensorHandle>(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_name}.get(), {idx}), NULL));"""

            declarations_before_scope = [
                f"RAIIAtenTensorHandle {output_arg};"
                for idx, output_arg in enumerate(output_args)
            ]
            scope_gil_acquire = self.generate_scoped_gil_acquire(
                declarations_before_scope, lines
            )
            self.writelines(scope_gil_acquire)
2397 f"{len(tensor_call_args)}, " 2398 f"std::vector<AtenTensorHandle>{{{tensor_call_args_str}}}.data());" 2399 ) 2400 2401 self.extern_call_ops.add(cpp_kernel_key) 2402 2403 def generate_reset_kernel_saved_flags(self): 2404 pass 2405 2406 def generate_save_uncompiled_kernels(self): 2407 pass 2408 2409 def c_type_for_prim_type(self, val, type_) -> str: 2410 assert ( 2411 config.abi_compatible 2412 ), "c_type_for_prim_type is only used in ABI compatible mode" 2413 if isinstance(type_, torch.OptionalType): 2414 return f"{self.c_type_for_prim_type(val, type_.getElementType())}*" 2415 elif isinstance(type_, torch.TensorType): 2416 return "AtenTensorHandle" 2417 elif isinstance(type_, (torch.IntType, torch.SymIntType)): 2418 return "int64_t" 2419 elif isinstance( 2420 type_, (torch.BoolType, torch.SymBoolType, torch.EnumType) 2421 ) or repr(type_) in ("ScalarType", "Layout"): 2422 return "int32_t" 2423 elif isinstance(type_, torch.FloatType): 2424 return "double" 2425 elif isinstance(type_, torch.NumberType): 2426 if isinstance(val, bool): 2427 return "int32_t" 2428 elif isinstance(val, int): 2429 return "int64_t" 2430 elif isinstance(val, float): 2431 return "double" 2432 elif val is None: 2433 # This could happen when val is an optional value 2434 return "double" 2435 else: 2436 raise AssertionError( 2437 f"Unexpected type in c_type_for_prim_type: {type_=}" 2438 ) 2439 elif isinstance(type_, torch.StringType): 2440 return "const char*" 2441 else: 2442 raise AssertionError(f"Unexpected type in c_type_for_prim_type: {type_=}") 2443 2444 def val_to_arg_str_for_prim_type(self, val, type_) -> str: 2445 # TODO: not using type_ as the first step of refactoring. Will update this later. 2446 if isinstance(val, bool): 2447 if config.abi_compatible: 2448 return "1" if val else "0" 2449 else: 2450 return "true" if val else "false" 2451 elif isinstance(val, int): 2452 # uint64_t is long on Linux, but long long on MacOS and Windows 2453 return f"{val}LL" if sys.platform in ["darwin", "win32"] else f"{val}L" 2454 elif isinstance(val, str): 2455 return f'"{val}"' 2456 elif isinstance( 2457 val, (ir.Buffer, ir.ReinterpretView, ir.StorageBox, ir.TensorBox) 2458 ): 2459 return val.codegen_reference() 2460 elif isinstance(val, torch.device): 2461 return self.codegen_device(val) 2462 elif isinstance(val, torch.dtype): 2463 return self.codegen_dtype(val) 2464 elif isinstance(val, float) and val in [float("inf"), float("-inf")]: 2465 if val == float("inf"): 2466 return "std::numeric_limits<float>::infinity()" 2467 else: 2468 return "-std::numeric_limits<float>::infinity()" 2469 elif isinstance(val, (list, tuple)): 2470 # FIXME: This happens because type_ is not always properly set to torch.ListType 2471 return f"{{{', '.join(self.val_to_arg_str(x, None) for x in val)}}}" 2472 elif isinstance(val, SymTypes): 2473 return self.expr_printer(val.node.expr) 2474 elif isinstance(val, sympy.Expr): 2475 return self.expr_printer(val) 2476 else: 2477 return repr(val) 2478 2479 def val_to_arg_str(self, val, type_=None) -> str: 2480 if val is None: 2481 # None needs special care. 

    def val_to_arg_str(self, val, type_=None) -> str:
        if val is None:
            # None needs special care. It either represents nullopt or an empty tensor.
            if config.abi_compatible:
                if type_ is None or isinstance(type_, torch.OptionalType):
                    if type_ is not None and isinstance(
                        type_.getElementType(),
                        (
                            torch.ListType,
                            torch.TupleType,
                            torch.DeviceObjType,
                        ),
                    ):
                        return "0, 0"
                    else:
                        return "0"  # nullptr is not available in C
                elif isinstance(type_, torch.TensorType):
                    # create an empty tensor, the equivalent of at::Tensor()
                    var_name = f"var_{next(self.arg_var_id)}"
                    self.writeline(f"AtenTensorHandle {var_name}_handle;")
                    self.writeline(
                        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{var_name}_handle));"
                    )
                    self.writeline(
                        f"RAIIAtenTensorHandle {var_name}({var_name}_handle);"
                    )
                    return var_name
                else:
                    raise AssertionError("Cannot map None to a known data type")
            else:
                return "std::nullopt"

        if isinstance(type_, torch.OptionalType):
            element_type = type_.getElementType()
            if config.abi_compatible:
                if not isinstance(element_type, torch.TensorType):
                    var_name = f"var_{next(self.arg_var_id)}"
                    if isinstance(
                        element_type,
                        (torch.ListType, torch.TupleType, torch.DeviceObjType),
                    ):
                        # type_ is something like Optional[List] or Optional[Device]
                        arg_str = self.val_to_arg_str(val, element_type)
                        # For datatypes with auxiliary info, we need to hoist out the extra arguments.
                        # NOTE: This only works if there is one additional argument, though it can easily be generalized.
                        main_value, aux = arg_str.rsplit(", ")
                        self.writeline(f"auto {var_name} = {main_value};")
                        return f"&{var_name}, {aux}"
                    else:
                        self.writeline(
                            f"{self.c_type_for_prim_type(val, element_type)} {var_name} = {self.val_to_arg_str(val, element_type)};"
                        )
                        return f"&{var_name}"
                else:
                    # type_ is Optional[Tensor]
                    # Similar to other data types, use a pointer to denote an optional tensor arg in the v2 C shim
                    base_handle = self.val_to_arg_str(val, element_type)
                    if config.use_minimal_arrayref_interface:
                        base_handle = (
                            f"convert_arrayref_tensor_to_tensor({base_handle})"
                        )
                    (
                        tmp_raii_handle_var,
                        tmp_raii_handle_var_decl,
                    ) = self.create_tmp_raii_handle_var(base_handle)
                    if tmp_raii_handle_var:
                        self.writeline(tmp_raii_handle_var_decl)
                        base_handle = tmp_raii_handle_var
                    var_name = f"var_{next(self.arg_var_id)}"
                    self.writeline(
                        f"AtenTensorHandle {var_name} = {base_handle}.get();"
                    )
                    return f"&{var_name}"
            else:
                return self.val_to_arg_str(val, element_type)

        elif isinstance(type_, torch.ListType):
            assert isinstance(
                val, (list, tuple)
            ), f"{val} does not match with arg type {type_}"
            element_type = type_.getElementType()
            if config.abi_compatible:
                var_name = f"var_array_{next(self.var_array_id)}"
                if len(val) == 0:
                    # Zero-size arrays are not supported by the C or C++ standard, so
                    # we declare a null pointer for them.
                    self.writeline(
                        f"const {self.c_type_for_prim_type(None, element_type)}* {var_name} = nullptr;"
                    )
                else:
                    result = f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}"
                    self.writeline(
                        f"const {self.c_type_for_prim_type(val[0], element_type)} {var_name}[] = {result};"
                    )
                # Need to pass the array length because we can't use std::vector
                return f"{var_name}, {len(val)}"
            else:
                return f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}"

        return self.val_to_arg_str_for_prim_type(val, type_)

    def create_tmp_raii_handle_var(self, base_handle):
        if base_handle.startswith(
            (
                "convert_arrayref_tensor_to_tensor",
                "wrap_with_raii_handle_if_needed",
            )
        ):
            # wrap_with_raii_handle_if_needed creates a temp RAIIAtenTensorHandle, so we need to
            # explicitly store it. Otherwise, it will be destroyed before the fallback kernel call.
            tmp_var_name = f"var_{next(self.arg_var_id)}"
            return (
                tmp_var_name,
                f"RAIIAtenTensorHandle {tmp_var_name} = {base_handle};\n",
            )
        else:
            return "", ""