# xref: /aosp_15_r20/external/pytorch/torch/_inductor/codegen/cpp_wrapper_cpu.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import functools
import math
import os
import sys
from itertools import count
from typing import Dict, List, Optional, Tuple

import sympy
from sympy import Expr

import torch
import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
import torch._ops
from torch._inductor.codegen.debug_utils import IntermediateValueDebuggingLevel
from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes

from .. import config, ir
from ..utils import _align, ALIGN_BYTES, cache_on_self, sympy_product
from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .common import IndentedBuffer
from .cpp_utils import (
    cexpr,
    DEVICE_TO_ATEN,
    DTYPE_TO_ATEN,
    DTYPE_TO_CPP,
    LAYOUT_TO_ATEN,
)
from .wrapper import EnterSubgraphLine, ExitSubgraphLine, WrapperCodeGen


class CppWrapperCpu(WrapperCodeGen):
    """
    Generates cpp wrapper for running on CPU and calls cpp kernels
    """

    def __init__(self):
        if not hasattr(self, "device"):
            self.device = "cpu"
        super().__init__()
        self.declare = "auto "
        self.declare_maybe_reference = "decltype(auto) "
        self.ending = ";"
        self.open_bracket = "{"
        self.closed_bracket = "}"
        self.comment = "//"
        self.namespace = "at::"
        self.none_str = "nullptr" if config.abi_compatible else "at::Tensor()"
        self.extern_call_ops = set()
        self.size = "sizes()"
        self.stride = "strides()"
        self.cuda = False
        self.supports_intermediate_hooks = False
        self.outputs_need_copy = set()
        self.kernel_callsite_id = count()
        self.var_array_id = (
            count()
        )  # for different types of local array variable declarations
        self.declared_var_array_vars = set()
        self.int_array_id = count()  # for int array local variable declarations
        self.declared_int_array_vars = set()
        self.tmp_tensor_id = count()  # for tmp tensor local variable declarations
        self.arg_var_id = count()
        self.used_cached_devices = set()
        self.used_cached_dtypes = set()
        self.used_cached_layouts = set()
        self.cached_output_id = count()
        self.scalar_to_tensor_id = count()
        self.custom_op_wrapper_loaded = False
        self.expr_printer = cexpr

    def generate_kernel_call(
        self,
        kernel_name: str,
        call_args,
        grid=None,
        device_index=None,
        cuda=True,
        triton=True,
        arg_types=None,
        raw_args=None,
        grid_fn: str = "grid",
        triton_meta=None,
        autotune_configs=None,
        grid_extra_kwargs="",
    ):
        """
        Generates kernel call code.

        cuda: Defines whether the backend is GPU. Otherwise the backend is CPU.

        triton: Defines whether the GPU backend uses Triton for codegen.
                Otherwise it uses the CUDA language for codegen.
                Only valid when cuda == True.
        """
        if cuda:
            return super().generate_kernel_call(
                kernel_name,
                call_args,
                grid,
                device_index,
                cuda,
                triton,
                arg_types,
                raw_args,
                grid_fn,
                triton_meta,
                autotune_configs,
                grid_extra_kwargs,
            )
        else:
            if config.abi_compatible:
                assert arg_types is not None and len(call_args) == len(
                    arg_types
                ), "Mismatch call_args and arg_types in generate_kernel_call"
                new_args = []
                for idx, arg in enumerate(call_args):
                    if "*" in arg_types[idx]:
                        var_name = f"var_{next(self.arg_var_id)}"
                        self.writeline(
                            f"auto* {var_name} = get_data_ptr_wrapper({arg});"
                        )
                        new_args.append(f"({arg_types[idx]})({var_name})")
                    else:
                        # arg is a scalar
                        new_args.append(arg)
                self.writeline(self.wrap_kernel_call(kernel_name, new_args))
            else:
                self.writeline(self.wrap_kernel_call(kernel_name, call_args))
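        # Illustrative sketch (not from the upstream source): in the CPU,
        # ABI-compatible branch above, call_args = ["buf0", "arg0_1", "s0"] with
        # arg_types = ["float*", "const float*", "int64_t"] would emit roughly:
        #   auto* var_0 = get_data_ptr_wrapper(buf0);
        #   auto* var_1 = get_data_ptr_wrapper(arg0_1);
        #   kernel_name((float*)(var_0), (const float*)(var_1), s0);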

    def write_constant(self, name, hashed):
        # include a hash so our code cache gives different constants different files
        self.header.writeline(f"// {name} {hashed}")

    def write_header(self):
        if V.graph.is_const_graph:
            # We do not write a header for the constant graph; it will be written by the main module.
            return

        if V.graph.aot_mode:
            for header_cpp_file in ("interface.cpp", "implementation.cpp"):
                with open(
                    os.path.join(
                        os.path.dirname(__file__), "aoti_runtime", header_cpp_file
                    )
                ) as f:
                    self.header.splice(f.read())
        else:
            self.header.splice(
                """
                import torch
                from torch._inductor.codecache import CppWrapperCodeCache

                cpp_wrapper_src = (
                '''
                """
            )

        if config.abi_compatible:
            self.header.splice(
                f"#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{self.device}.h>"
            )
            self.header.splice(
                """
                #include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h>
                #include <torch/csrc/inductor/aoti_runtime/thread_local.h>
                #include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h>
                """
            )
            if V.graph.aot_mode:
                self.header.splice(
                    """
                    #include <torch/csrc/inductor/aoti_runtime/model.h>
                    """
                )
        else:
            self.header.splice(
                """
                #include <ATen/ATen.h>
                #include <ATen/core/dispatch/Dispatcher.h>
                #include <ATen/native/BinaryOps.h>
                #include <torch/csrc/inductor/aoti_runtime/utils.h>
                #include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
                #include <torch/csrc/inductor/aoti_torch/utils.h>
                #include <torch/csrc/inductor/inductor_ops.h>
                #include <torch/types.h>
                #include <ATen/ops/bernoulli_native.h>

                #define reinterpret_tensor torch::inductor::_reinterpret_tensor
                #define alloc_from_pool torch::inductor::_alloc_from_pool
                """
            )
        enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
            "linux",
            "win32",
        ]
        if config.profiler_mark_wrapper_call or enable_kernel_profile:
            self.header.splice("#include <ATen/record_function.h>")

        self.header.splice("typedef at::Half half;")
        self.header.splice("typedef at::BFloat16 bfloat16;")
        self.header.splice("#include <c10/util/generic_math.h>")

        if not V.graph.aot_mode:
            self.header.splice(
                """
                #include <pybind11/pybind11.h>

                namespace py = pybind11;
                using namespace torch::aot_inductor;

                class RAIIPyObject {
                public:
                    RAIIPyObject() : obj_(nullptr) {}
                    RAIIPyObject(PyObject* obj) : obj_(obj) {}
                    ~RAIIPyObject() {
                        Py_XDECREF(obj_);
                    }
                    RAIIPyObject& operator=(const RAIIPyObject& other) {
                        if (this != &other) {
                            Py_XDECREF(obj_);
                            obj_ = other.obj_;
                            Py_XINCREF(obj_);
                        }
                        return *this;
                    }
                    operator PyObject*() {
                        return obj_;
                    }
                    PyObject* get() {
                        return obj_;
                    }
                private:
                    PyObject* obj_;
                };
                """
            )

        # Round up to the nearest multiple of ALIGN_BYTES
        # ALIGN_BYTES must be a power of 2
        self.header.splice(
            f"""
            [[maybe_unused]] static int64_t align(int64_t nbytes) {{
              return (nbytes + {ALIGN_BYTES} - 1) & -{ALIGN_BYTES};
            }}
            """
        )
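        # Worked example of the bit trick above (illustrative, assuming
        # ALIGN_BYTES == 64): align(1) -> 64, align(64) -> 64, align(65) -> 128.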

    @functools.lru_cache(None)  # noqa: B019
    def include_extra_header(self, header: str):
        # This is needed for cpp to python dtype conversion
        self.header.splice(f"#include <{header}>")

    def mark_output_type(self):
        # mark output type to unwrap tensor back to python scalar
        from ..ir import ShapeAsConstantBuffer

        output_is_tensor = {}
        for idx, x in enumerate(V.graph.graph_outputs):
            if isinstance(x, ShapeAsConstantBuffer):
                output_is_tensor[idx] = False
            else:
                output_is_tensor[idx] = True

        self.output_is_tensor = output_is_tensor
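        # Illustrative example (assumption, not from the source): for a graph
        # returning (a_tensor, a_symbolic_int), output_is_tensor would be
        # {0: True, 1: False}; generate_end() later uses the False entry to call
        # .item() on that output in the generated Python wrapper.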

    def write_prefix(self):
        if V.graph.is_const_graph:
            # We do not write a prefix for the constant graph; it will be written by the main module.
            return

        if V.graph.aot_mode:
            self.prefix.writeline("namespace torch {")
            self.prefix.writeline("namespace aot_inductor {")

    def write_input_output_info(
        self,
        info_kind: str,
        idx: int,
        name: str,
    ):
        self.prefix.writeline(f"""{info_kind}[{idx}].name = "{name}";""")

    @staticmethod
    def get_input_cpp_type(input):
        assert config.use_minimal_arrayref_interface

        if isinstance(input, sympy.Expr):
            from ..graph import may_get_constant_buffer_dtype

            dtype = may_get_constant_buffer_dtype(input)
            assert dtype is not None, f"Failed to get the dtype of sympy.Expr: {input}"
            return DTYPE_TO_CPP[dtype]
        return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>"
    def generate_input_output_runtime_checks(self):
        # In debug_compile mode, we generate checks to ensure that the dtype/shape/stride of
        # each real input/output tensor matches the ones provided at compile time via the
        # sample inputs/outputs.
        def gen_check(handle_kind, idx, name, tensor):
            self.prefix.writeline(f"auto {name} = {handle_kind}[{idx}];")
            self.codegen_tensor_dtype_var_decl(self.prefix, name)
            expected_dtype_name = DTYPE_TO_ATEN[tensor.dtype]
            dtype_str = str(tensor.dtype).split(".")[-1]
            self.prefix.splice(
                f"""
                    int32_t {name}_expected_dtype = aoti_torch_dtype_{dtype_str}();
                    if ({name}_expected_dtype != {name}_dtype) {{
                        std::stringstream ss;
                        ss << "{handle_kind}[{idx}]: unmatched dtype, "
                           << "expected: " << {name}_expected_dtype << "({expected_dtype_name}), "
                           << "but got: " << {name}_dtype << "\\n";
                        throw std::runtime_error(ss.str());
                    }}
                """
            )
            self.codegen_input_size_var_decl(self.prefix, name)
            for dim_idx, d in enumerate(tensor.get_size()):
                if isinstance(d, (int, sympy.Integer)):
                    self.prefix.splice(
                        f"""
                            if ({d} != {name}_size[{dim_idx}]) {{
                                std::stringstream ss;
                                ss << "{handle_kind}[{idx}]: unmatched dim value at {dim_idx}, "
                                   << "expected: {d}, " << "but got: " << {name}_size[{dim_idx}]
                                   << "\\n";
                                throw std::runtime_error(ss.str());
                            }}
                        """
                    )
                else:
                    from torch.utils._sympy.value_ranges import bound_sympy

                    sym_range = bound_sympy(d, V.graph.sizevars.shape_env.var_to_range)
                    if not math.isinf(sym_range.lower):
                        self.prefix.splice(
                            f"""
                                if ({name}_size[{dim_idx}] < {sym_range.lower}) {{
                                    std::stringstream ss;
                                    ss << "{handle_kind}[{idx}]: dim value is too small at {dim_idx}, "
                                       << "expected it to be >= {sym_range.lower}, " << "but got: "
                                       << {name}_size[{dim_idx}] << "\\n";
                                    throw std::runtime_error(ss.str());
                                }}
                            """
                        )
                    if not math.isinf(sym_range.upper):
                        self.prefix.splice(
                            f"""
                                if ({name}_size[{dim_idx}] > {sym_range.upper}) {{
                                    std::stringstream ss;
                                    ss << "{handle_kind}[{idx}]: dim value is too large at {dim_idx}, "
                                       << "expected to be <= {sym_range.upper}, " << "but got: "
                                       << {name}_size[{dim_idx}] << "\\n";
                                    throw std::runtime_error(ss.str());
                                }}
                            """
                        )

            self.codegen_input_stride_var_decl(self.prefix, name)
            for stride_idx, s in enumerate(tensor.get_stride()):
                if not isinstance(s, (int, sympy.Integer)):
                    continue
                self.prefix.splice(
                    f"""
                        if ({s} != {name}_stride[{stride_idx}]) {{
                            std::stringstream ss;
                            ss << "{handle_kind}[{idx}]: unmatched stride value at {stride_idx}, "
                               << "expected: {s}, " << "but got: " << {name}_stride[{stride_idx}]
                               << "\\n";
                            throw std::runtime_error(ss.str());
                        }}
                    """
                )

        # force noinline to avoid any potential compilation slowdown due to aggressive
        # inline done by the host compiler
        self.prefix.splice(
            """
            AOTI_NOINLINE static void __check_inputs_outputs(
                AtenTensorHandle* input_handles,
                AtenTensorHandle* output_handles) {
            """
        )
        with self.prefix.indent():
            for idx, (name, tensor) in enumerate(V.graph.graph_inputs.items()):
                gen_check("input_handles", idx, name, tensor)
        self.prefix.writeline("}")

    def write_wrapper_decl(self):
        inputs_len = len(V.graph.graph_inputs.keys())
        if V.graph.aot_mode:
            if config.use_minimal_arrayref_interface and not V.graph.is_const_graph:
                input_cpp_types = ", ".join(
                    f"{CppWrapperCpu.get_input_cpp_type(x)}"
                    for x in V.graph.graph_inputs.values()
                )
                output_arrayref_types = ", ".join(
                    f"ArrayRefTensor<{DTYPE_TO_CPP[x.get_dtype()]}>"
                    for x in V.graph.graph_outputs
                )

                self.prefix.splice(
                    f"""
                    using AOTInductorModelInputs = std::tuple<{input_cpp_types}>;
                    using AOTInductorModelOutputs = std::tuple<{output_arrayref_types}>;
                    """
                )

            if V.graph.const_module:
                self.header.splice(V.graph.const_module.wrapper_code.header)
                self.prefix.splice(V.graph.const_code)

            if V.graph.is_const_graph:
                self.prefix.splice(
                    """
                    void AOTInductorModel::_const_run_impl(
                        std::vector<AtenTensorHandle>& output_handles,
                        DeviceStreamType stream,
                        AOTIProxyExecutorHandle proxy_executor
                    ) {
                    """
                )
            else:
                if not config.aot_inductor.use_runtime_constant_folding:
                    # If we do not split the constant graph, we'll just create
                    # an empty implementation when wrapping the main module.
                    self.prefix.splice(
                        """
                        void AOTInductorModel::_const_run_impl(
                            std::vector<AtenTensorHandle>& output_handles,
                            DeviceStreamType stream,
                            AOTIProxyExecutorHandle proxy_executor
                        ) {}

                        """
                    )

                run_impl_proto = """
                    void AOTInductorModel::run_impl(
                        AtenTensorHandle*
                            input_handles, // array of input AtenTensorHandle; handles
                                            // are stolen; the array itself is borrowed
                        AtenTensorHandle*
                            output_handles, // array for writing output AtenTensorHandle; handles
                                            // will be stolen by the caller; the array itself is
                                            // borrowed
                        DeviceStreamType stream,
                        AOTIProxyExecutorHandle proxy_executor
                    ) {
                    """
                # Since we are removing non-abi-compatible mode, let's generate
                # runtime checks only for abi_compatible mode to avoid extra branches.
                if config.aot_inductor.debug_compile and config.abi_compatible:
                    self.generate_input_output_runtime_checks()
                    run_impl_proto += """
                        __check_inputs_outputs(input_handles, output_handles);
                    """
                if config.use_minimal_arrayref_interface:
                    self.prefix.splice(
                        """
                        template <>
                        AOTInductorModelOutputs AOTInductorModel::run_impl_minimal_arrayref_interface<
                          AOTInductorModelInputs, AOTInductorModelOutputs>(
                            const AOTInductorModelInputs& inputs,
                            DeviceStreamType stream,
                            AOTIProxyExecutorHandle proxy_executor
                        ) {
                        """
                    )
                    self.suffix.splice(run_impl_proto)
                    self.suffix.splice(
                        """
                            AOTInductorModelInputs inputs;
                            convert_handles_to_inputs(input_handles, inputs);
                            auto outputs = run_impl_minimal_arrayref_interface<AOTInductorModelInputs, AOTInductorModelOutputs>(
                                inputs, stream, proxy_executor);
                            // NOTE: outputs is full of ArrayRef to thread_local storage. If in the future we need this
                            // interface to perform well for a DSO using the minimal arrayref interface, all we need
                            // to do is provide ThreadLocalCachedTensor for each one!
                            convert_outputs_to_handles(outputs, output_handles);
                        }
                    """
                    )

                    self.suffix.splice(
                        """
                        extern "C" AOTIRuntimeError AOTInductorModelRunMinimalArrayrefInterface(
                            AOTInductorModelHandle model_handle,
                            const AOTInductorModelInputs& inputs,
                            AOTInductorModelOutputs& outputs) {
                          auto model = reinterpret_cast<torch::aot_inductor::AOTInductorModel*>(model_handle);
                          CONVERT_EXCEPTION_TO_ERROR_CODE({
                              outputs = model->run_impl_minimal_arrayref_interface<AOTInductorModelInputs, AOTInductorModelOutputs>(
                                  inputs,
                                  (torch::aot_inductor::DeviceStreamType)nullptr,
                                  nullptr);
                          })
                        }
                    """
                    )
                else:
                    self.prefix.splice(run_impl_proto)
        else:
            # cpp entry function for JIT with cpp wrapper
            self.prefix.splice(
                """
                void inductor_entry_impl(
                    AtenTensorHandle*
                        input_handles, // array of input AtenTensorHandle; handles
                                        // are stolen; the array itself is borrowed
                    AtenTensorHandle*
                        output_handles  // array for writing output AtenTensorHandle; handles
                                        // will be stolen by the caller; the array itself is
                                        // borrowed)
                ) {
                """
            )
        with self.prefix.indent():
            # assign inputs and outputs in both cases so the later codegen can be simplified
            if not config.use_minimal_arrayref_interface:
                if not V.graph.is_const_graph:
                    if V.graph.aot_mode:
                        num_args = len(V.graph.graph_inputs)
                    else:
                        # Weights are promoted in the JIT mode
                        num_args = len(V.graph.graph_inputs) + len(V.graph.constants)
                        # Release the GIL to support running inference from multiple instances (in different threads of the same process).
                        self.prefix.splice("py::gil_scoped_release release;")

                    if config.abi_compatible:
                        self.prefix.splice(
                            f"""
                                auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args});
                            """
                        )
                    else:
                        # This looks dumb, but can avoid creating two versions of code in the AOTInductor runtime.
                        self.prefix.splice(
                            f"""
                                auto inputs = alloc_tensors_by_stealing_from_handles(input_handles, {num_args});
                            """
                        )

            if inputs_len != 0:
                for idx, input_key in enumerate(V.graph.graph_inputs.keys()):
                    if config.use_minimal_arrayref_interface:
                        self.prefix.writeline(
                            f"auto {input_key} = std::get<{idx}>(inputs);"
                        )
                        continue
                    # unwrap input tensor back to scalar
                    if isinstance(V.graph.graph_inputs[input_key], sympy.Expr):
                        from ..graph import may_get_constant_buffer_dtype

                        dtype = may_get_constant_buffer_dtype(
                            V.graph.graph_inputs[input_key]  # type: ignore[arg-type]
                        )
                        assert (
                            dtype is not None
                        ), "Failed to get the dtype of the sympy.Expr"
                        cpp_dtype = DTYPE_TO_CPP[dtype]
                        if config.abi_compatible:
                            self.codegen_tensor_item(
                                dtype, f"inputs[{idx}]", input_key, self.prefix
                            )
                        else:
                            self.prefix.writeline(
                                f"{cpp_dtype} {input_key} = inputs[{idx}].item<{cpp_dtype}>();"
                            )
                    else:
                        self.prefix.writeline(
                            f"auto {input_key} = std::move(inputs[{idx}]);"
                        )

            assert all(
                isinstance(v, torch.Tensor) for v in list(V.graph.constants.values())
            ), "Expect all constants to be Tensor"
            for idx, constants_key in enumerate(V.graph.constants.keys()):
                if V.graph.aot_mode:
                    # Weights are stored in constants_ and owned by RAIIAtenTensorHandle there.
                    # Don't call std::move here because it will cause constants_ to lose the ownership.
                    if config.abi_compatible:
                        self.prefix.writeline(
                            f"""auto {constants_key} = constants_->at({idx});"""
                        )
                    else:
                        self.prefix.writeline(
                            f"auto {constants_key} = *tensor_handle_to_tensor_pointer("
                            + f"""constants_->at({idx}));"""
                        )
                else:
                    # Append constants as inputs to the graph
                    constants_idx = inputs_len + idx
                    if config.abi_compatible:
                        self.prefix.writeline(
                            f"auto {constants_key} = std::move(inputs[{constants_idx}]);"
                        )
                    else:
                        self.prefix.writeline(
                            f"auto {constants_key} = inputs[{constants_idx}];"
                        )

            self.codegen_inputs(self.prefix, V.graph.graph_inputs)

            if V.graph.aot_mode:
                if not V.graph.is_const_graph:
                    if config.use_minimal_arrayref_interface:
                        # TODO: input shape checking for regular tensor interface as well?
                        self.codegen_input_numel_asserts()
                    else:
                        self.prefix.writeline("inputs.clear();")
                self.prefix.writeline(
                    "auto& kernels = static_cast<AOTInductorModelKernels&>(*this->kernels_.get());"
                )

    def codegen_input_numel_asserts(self):
        for name, buf in V.graph.graph_inputs.items():
            if isinstance(buf, sympy.Expr):
                continue

            # Comparing strides for a 0-size tensor is tricky; ignore them for now.
            if sympy_product(buf.get_size()) == 0:
                continue
            numel = buf.get_numel()
            self.prefix.writeline(f"assert_numel({name}, {numel});")

    def codegen_tensor_dtype_var_decl(self, code: IndentedBuffer, name):
        if config.abi_compatible:
            code.writeline(f"int32_t {name}_dtype;")
            code.writeline(
                "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype"
                f"({name}, &{name}_dtype));"
            )
        else:
            # Note that there is no corresponding method in WrapperCodeGen,
            # since this method is only used for assertions in the AOTI
            # cpp wrapper code.
            code.writeline(f"auto {name}_dtype = {name}.dtype();")

    def codegen_input_size_var_decl(self, code: IndentedBuffer, name):
        if config.abi_compatible:
            code.writeline(f"int64_t* {name}_size;")
            code.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes({name}, &{name}_size));"
            )
        else:
            super().codegen_input_size_var_decl(code, name)

    def codegen_input_stride_var_decl(self, code: IndentedBuffer, name):
        if config.abi_compatible:
            code.writeline(f"int64_t* {name}_stride;")
            code.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides({name}, &{name}_stride));"
            )
        else:
            super().codegen_input_stride_var_decl(code, name)

    def codegen_model_kernels(self):
        self.prefix.writeline("namespace {")
        self.prefix.writeline(
            "class AOTInductorModelKernels : public AOTInductorModelKernelsBase {"
        )
        self.prefix.writeline("  public:")
        declare_kernel = set(self.src_to_kernel.values())
        declare_kernel.update(
            entry[0] for entry in self.user_defined_kernel_cache.values()
        )
        if V.graph.const_module:
            declare_kernel.update(
                V.graph.const_module.wrapper_code.src_to_kernel.values()
            )
        for kernel in sorted(declare_kernel):
            self.prefix.writeline(
                maybe_hipify_code_wrapper(f"    CUfunction {kernel}{{nullptr}};")
            )
        self.prefix.writeline("};")
        self.prefix.writeline("}  // namespace")

    def codegen_model_constructor(self):
        """
        // Generated code example
        AOTInductorModel::AOTInductorModel()
            : AOTInductorModelBase(4, 1) {
        inputs_info_[0].name = "input0";
        inputs_info_[0].dtype = "torch.float16";
        ...
        constants_info_[0].name = "L__self___weight";
        constants_info_[0].dtype = at::kFloat;
        constants_info_[0].offset = 0;
        constants_info_[0].data_size = 8192;
        constants_info_[0].shape = {64, 32};
        constants_info_[0].stride = {32, 1};
        ...
        outputs_info_[0].name = "output0";
        outputs_info_[0].dtype = "torch.float16";
        }
        """

        num_inputs = len(V.graph.graph_inputs)
        num_outputs = len(V.graph.graph_outputs)
        num_constants = len(V.graph.constants)
        self.prefix.splice(
            f"""
            AOTInductorModel::AOTInductorModel(std::shared_ptr<ConstantMap> constants_map,
                                               std::shared_ptr<std::vector<ConstantHandle>> constants_array,
                                               const std::string& device_str,
                                               std::optional<std::string> cubin_dir)
                : AOTInductorModelBase({num_inputs}, {num_outputs}, {num_constants}, device_str, cubin_dir) {{
            """
        )

        with self.prefix.indent():
            for idx, (name, inp) in enumerate(V.graph.graph_inputs.items()):
                assert not isinstance(
                    inp, sympy.Expr
                ), f"input {name=} cannot be symbolic"
                self.write_input_output_info("inputs_info_", idx, name)

            all_cuda = all(
                V.graph.get_original_value_of_constant(name).is_cuda
                for name in V.graph.constants.keys()
                if name not in V.graph.folded_constants
            )
            for idx, name in enumerate(V.graph.constants.keys()):
                tensor = V.graph.get_original_value_of_constant(name)
                assert isinstance(tensor, torch.Tensor)
                self.prefix.writeline(f"""constants_info_[{idx}].name = "{name}";""")
                self.prefix.writeline(
                    f"constants_info_[{idx}].dtype = static_cast<int32_t>({self.codegen_dtype(tensor.dtype)});"
                )
                self.prefix.writeline(
                    f"constants_info_[{idx}].offset = {tensor.storage_offset()};"
                )

                # If the constants to serialize contain CPU tensors, we always align data_size to 64.
                # When loading the constants, the valid data depends on the size,
                # not the data_size, so there is no correctness issue.
                data_size = (
                    torch.ops.mkldnn._nbytes(tensor)
                    if tensor.is_mkldnn
                    else tensor.untyped_storage().nbytes()
                )
                self.prefix.writeline(
                    f"constants_info_[{idx}].data_size = {data_size if all_cuda else _align(data_size)};"
                )

                from_folded = "true" if name in V.graph.folded_constants else "false"
                self.prefix.writeline(
                    f"constants_info_[{idx}].from_folded = {from_folded};"
                )

                size_str = ", ".join([str(s) for s in tensor.size()])
                self.prefix.writeline(f"constants_info_[{idx}].shape = {{{size_str}}};")

                stride_str = ", ".join([str(s) for s in tensor.stride()])
                self.prefix.writeline(
                    f"constants_info_[{idx}].stride = {{{stride_str}}};"
                )
                self.prefix.writeline(
                    f"constants_info_[{idx}].layout = static_cast<int32_t>({self.codegen_layout(tensor.layout)});"
                )

                if tensor.is_mkldnn:
                    opaque_metadata_tensor = torch.ops.mkldnn._get_mkldnn_serialized_md(
                        tensor
                    )
                    assert (
                        opaque_metadata_tensor.dim() == 1
                    ), "Expect opaque_metadata_tensor to be 1-D"

                    opaque_metadata_list = opaque_metadata_tensor.tolist()
                    opaque_metadata_str = self.codegen_shape_tuple(opaque_metadata_list)
                    self.prefix.writeline(
                        f"constants_info_[{idx}].opaque_metadata = {opaque_metadata_str};"
                    )
                if name in V.graph.dynamo_flat_name_to_original_fqn:
                    original_fqn = V.graph.dynamo_flat_name_to_original_fqn.get(
                        name, name
                    )
                elif name in V.graph.allocated_constant_name:
                    original_fqn = V.graph.allocated_constant_name[name]
                else:
                    raise AssertionError("original_fqn must be set for constant")
                self.prefix.writeline(
                    f"""constants_info_[{idx}].original_fqn = "{original_fqn}";"""
                )
            self.prefix.writeline("update_constants_map(std::move(constants_map));")
            self.prefix.writeline("update_constants_array(std::move(constants_array));")

            def escape_string(x):
                return (
                    x.replace("\\", "\\\\")
                    .replace('"', '\\"')
                    .replace("\n", "\\n")
                    .replace("\t", "\\t")
                )
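            # For example, escape_string('a"b') returns 'a\\"b' and a newline
            # becomes "\\n", so the serialized specs below stay valid C++ string
            # literals.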

            self.prefix.writeline(
                f'in_spec_ = "{escape_string(config.aot_inductor.serialized_in_spec)}";'
            )
            self.prefix.writeline(
                f'out_spec_ = "{escape_string(config.aot_inductor.serialized_out_spec)}";'
            )

            for idx, output in enumerate(V.graph.graph_outputs):
                assert not isinstance(
                    output, sympy.Expr
                ), f"output {name=} cannot be symbolic"
                name = f"output{idx}"
                self.write_input_output_info("outputs_info_", idx, name)

            self.prefix.writeline(
                "this->kernels_ = std::make_unique<AOTInductorModelKernels>();"
            )

        self.prefix.writeline("}")

    def codegen_const_run_driver(self):
        """
        // Generated code example
        std::unordered_map<std::string, AtenTensorHandle> AOTInductorModel::const_run_impl(
            DeviceStreamType stream,
            AOTIProxyExecutorHandle proxy_executor,
            bool initialization
        ) {
            std::unordered_map<std::string, AtenTensorHandle> folded_constants_map;
            std::vector<AtenTensorHandle> output_handles;
            // build up output_handles over here.
            _const_run_impl(output_handles, stream, proxy_executor);
            // build up folded_constants_map
            return folded_constants_map;
        }
        """

        self.prefix.splice(
            """
            std::unordered_map<std::string, AtenTensorHandle> AOTInductorModel::const_run_impl(
                DeviceStreamType stream,
                AOTIProxyExecutorHandle proxy_executor,
                bool initialization
            ) {
            """
        )
        if not config.aot_inductor.use_runtime_constant_folding:
            self.prefix.splice(
                """
                    if (!initialization) {
                        std::cerr << "[WARNING] Calling constant_folding in model, but compiled with config: "
                                  << "aot_inductor.use_runtime_constant_folding=False\\n";
                    }
                    return {};
                }
                """
            )
            return

        with self.prefix.indent():
            # This maps the constant folding graph's output index to (constant index, name).
            const_index_mapping: List[Optional[Tuple[int, str]]] = [None] * len(
                V.graph.const_output_index
            )
            for idx, (name, _) in enumerate(V.graph.constants.items()):
                if name in V.graph.const_output_index:
                    const_index_mapping[V.graph.const_output_index[name]] = (idx, name)  # type: ignore[call-overload]
            assert (
                None not in const_index_mapping
            ), "Not all constants were mapped for the constant folding graph."

            self.prefix.writeline(
                f"""
                std::unordered_map<std::string, AtenTensorHandle> folded_constants_map;
                folded_constants_map.reserve({len(const_index_mapping)});
                std::vector<AtenTensorHandle> output_handles({len(const_index_mapping)});
                """
            )

            self.prefix.splice(
                """
                // The assignment of output_handles to constants below is not used directly.
                // It's only used to record the correspondence between handles and constants.
                """
            )

            for output_idx, (const_idx, _) in enumerate(const_index_mapping):  # type: ignore[misc]
                self.prefix.writeline(
                    f"output_handles[{output_idx}] = constants_->at({const_idx});"
                )

            self.prefix.writeline(
                "_const_run_impl(output_handles, stream, proxy_executor);"
            )

            for output_idx, (_, const_name) in enumerate(const_index_mapping):  # type: ignore[misc]
                self.prefix.writeline(
                    f'folded_constants_map["{const_name}"] = output_handles[{output_idx}];'
                )
            self.prefix.writeline("return folded_constants_map;")

        self.prefix.writeline("}")

    def generate(self, is_inference):
        if V.graph.aot_mode and not V.graph.is_const_graph:
            self.codegen_model_kernels()
            self.codegen_model_constructor()
            self.codegen_const_run_driver()
        self.write_wrapper_decl()
        return super().generate(is_inference)

    def finalize_prefix(self):
        cached_dtypes_buffer = IndentedBuffer()
        if config.abi_compatible:
            for dtype in self.used_cached_dtypes:
                cached_dtypes_buffer.writeline(f"CACHE_TORCH_DTYPE({dtype});")
            for device in self.used_cached_devices:
                cached_dtypes_buffer.writeline(f"CACHE_TORCH_DEVICE({device});")
            for layout in self.used_cached_layouts:
                cached_dtypes_buffer.writeline(f"CACHE_TORCH_LAYOUT({layout});")
        cached_dtypes_buffer.splice(self.prefix)
        self.prefix = cached_dtypes_buffer

    def define_kernel(
        self, name: str, kernel: str, metadata: Optional[str] = None, cuda=False
    ):
        self.header.splice(f"\n{kernel}\n")

    def codegen_scalar_to_tensor(self, output: str):
        name = f"scalar_to_tensor_{next(self.scalar_to_tensor_id)}"
        self.wrapper_call.writeline(
            f"RAIIAtenTensorHandle {name} = scalar_to_tensor_handle({output});"
        )
        return name

    def codegen_tensor_item(
        self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None
    ):
        assert (
            config.abi_compatible
        ), "codegen_tensor_item is only used for the ABI-compatible mode"
        dtype_str = str(dtype).split(".")[-1]
        writer = indented_buffer or self

        if dtype == torch.float16 or dtype == torch.bfloat16:
            scalar_tmp = f"{scalar}_tmp"
            writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar_tmp};")

            # need convert_arrayref_tensor_to_tensor for ArrayRefTensors
            tensor = f"convert_arrayref_tensor_to_tensor({tensor})"

            writer.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));"
            )
            writer.writeline(f"float {scalar} = float({scalar_tmp});")
        else:
            writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};")

            # need convert_arrayref_tensor_to_tensor for ArrayRefTensors
            tensor = f"convert_arrayref_tensor_to_tensor({tensor})"

            writer.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));"
            )
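        # Illustrative output (sketch): codegen_tensor_item(torch.float32, "buf0", "s0")
        # would emit roughly
        #   float s0;
        #   AOTI_TORCH_ERROR_CODE_CHECK(
        #       aoti_torch_item_float32(convert_arrayref_tensor_to_tensor(buf0), &s0));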

    @cache_on_self
    def get_output_refs(self):
        return [
            f"torch::tensor({x.codegen_reference(self.wrapper_call)})"
            if isinstance(x, ir.ShapeAsConstantBuffer) and not config.abi_compatible
            else x.codegen_reference(self.wrapper_call)
            for x in V.graph.graph_outputs
        ]

    def generate_return(self, output_refs: List[str]):
        cst_names = V.graph.constants.keys()
        arr_iface = (
            not V.graph.is_const_graph and config.use_minimal_arrayref_interface
        )  # For brevity.

        def use_thread_local_cached_output_tensor(idx, output):
            cached_output_name = f"cached_output_{next(self.cached_output_id)}"
            cache_type = "Array" if arr_iface else "Tensor"
            self.wrapper_call.writeline(
                f"thread_local ThreadLocalCachedOutput{cache_type}<std::decay_t<decltype({output})>> "
                f"{cached_output_name}({output});"
            )
            if arr_iface:
                self.wrapper_call.writeline(
                    f"{cached_output_name}.copy_data_from({output});"
                )
                output_entry = f"std::get<{idx}>(output_arrayref_tensors)"
                element_type = f"std::decay_t<decltype({output_entry}.data()[0])>"
                self.wrapper_call.writeline(
                    f"{output_entry} = {cached_output_name}.arrayref_tensor<{element_type}>();"
                )
            else:
                self.wrapper_call.writeline(
                    f"{cached_output_name}.copy_data_from({output});"
                )
                self.wrapper_call.writeline(
                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&output_handles[{idx}]));"
                )
                self.wrapper_call.writeline(
                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors({cached_output_name}.tensor(), "
                    f"output_handles[{idx}]));"
                )

        if arr_iface:
            self.wrapper_call.writeline(
                "AOTInductorModelOutputs output_arrayref_tensors;"
            )

        output2idx: Dict[str, int] = {}
        for idx, output in enumerate(output_refs):
            if output == self.none_str:
                continue

            is_constant_buffer = output in cst_names
            output_buffer = V.graph.graph_outputs[idx]
            if isinstance(output_buffer, ir.BaseView):
                output_storage = output_buffer.unwrap_view()
                if isinstance(output_storage.data, ir.ConstantBuffer):
                    is_constant_buffer = True

            if config.abi_compatible:
                if isinstance(output_buffer, ir.ShapeAsConstantBuffer):
                    # Need to wrap scalar into tensor as the main function returns a vector of tensors
                    output_tensor = self.codegen_scalar_to_tensor(output)
                    self.wrapper_call.writeline(
                        f"output_handles[{idx}] = {output_tensor}.release();"
                    )
                    continue

                output_is_tensor_handle_expr = (
                    f"std::is_same_v<std::decay_t<decltype({output})>,"
                    "RAIIAtenTensorHandle> || "
                    f"std::is_same_v<std::decay_t<decltype({output})>,"
                    "AtenTensorHandle> || "
                    f"std::is_same_v<std::decay_t<decltype({output})>,"
                    "ConstantHandle>"
                )
                self.wrapper_call.writeline(
                    f"if constexpr ({output_is_tensor_handle_expr}) {{"
                )
                with self.wrapper_call.indent():
                    if arr_iface:
                        cached_output_name = (
                            f"cached_output_{next(self.cached_output_id)}"
                        )
                        output_value_type = f"std::decay_t<decltype(std::get<{idx}>(output_arrayref_tensors).data()[0])>"
                        self.wrapper_call.writeline(
                            f"thread_local RAIIAtenTensorHandle {cached_output_name};"
                        )
                        if is_constant_buffer:
                            # NOTE(return_constant): In some rare cases where we return
                            # a constant, we have to return a copy of this constant,
                            # because (1) constants are not owned by the Model instance and
                            # (2) constants remain the same across inference runs, assuming
                            # they are not updated at runtime. Basically, we cannot release
                            # or transfer the ownership of any original constant to the user.
                            self.wrapper_call.writeline(
                                f"AtenTensorHandle {cached_output_name}_tmp;"
                            )
                            self.wrapper_call.writeline(
                                f"aoti_torch_clone({output}, &{cached_output_name}_tmp);"
                            )
                            self.wrapper_call.writeline(
                                f"{cached_output_name} = {cached_output_name}_tmp;"
                            )
                        else:
                            self.wrapper_call.writeline(
                                f"{cached_output_name} = {output}.release();"
                            )
                        self.wrapper_call.writeline(
                            f"convert_handle_to_arrayref_tensor({cached_output_name}, "
                            f"std::get<{idx}>(output_arrayref_tensors));"
                        )
                    else:
                        if is_constant_buffer:
                            # See NOTE(return_constant) above.
                            self.wrapper_call.writeline(
                                f"aoti_torch_clone({output}, &output_handles[{idx}]);"
                            )
                        else:
                            if output in output2idx:
                                src_idx = output2idx[output]
                                self.wrapper_call.writeline(
                                    f"output_handles[{idx}] = output_handles[{src_idx}];"
                                )
                            else:
                                self.wrapper_call.writeline(
                                    f"output_handles[{idx}] = {output}.release();"
                                )
                self.wrapper_call.writeline("} else {")
                with self.wrapper_call.indent():
                    use_thread_local_cached_output_tensor(idx, output)
                self.wrapper_call.writeline("}")

            else:
                assert (
                    not arr_iface
                ), "minimal ArrayRef interface is only supported in ABI-compatible mode"
                if is_constant_buffer:
                    output_expr = f"{output}.clone()"
                    # See NOTE(return_constant) above.
                else:
                    output_expr = output
                self.wrapper_call.writeline(
                    f"output_handles[{idx}] = reinterpret_cast<AtenTensorHandle>("
                    + f"new at::Tensor({output_expr}));"
                )

            if output not in output2idx:
                output2idx[output] = idx
        if arr_iface:
            self.wrapper_call.writeline("return output_arrayref_tensors;")

    def generate_before_suffix(self, result):
        if not V.graph.is_const_graph:
            if V.graph.aot_mode:
                result.writeline("} // AOTInductorModel::run_impl")
            else:
                result.writeline("} // inductor_entry_impl")

    def generate_end(self, result):
        if V.graph.aot_mode:
            if V.graph.is_const_graph:
                result.writeline("} // AOTInductorModel::_const_run_impl")
            else:
                result.writeline("} // namespace aot_inductor")
                result.writeline("} // namespace torch")
            return

        # cpp entry function for JIT with cpp wrapper
        result.writeline("'''\n)")
        result.splice(
            f"""
            inductor_entry = CppWrapperCodeCache.load_pybinding(
                ["std::vector<AtenTensorHandle>"], cpp_wrapper_src, {self.cuda}, {len(V.graph.graph_outputs)})
            """
        )

        wrapper_body = "input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]"
        if V.graph.constants:
            # Append constants to the input args for the cpp wrapper.
            # The Python wrapper gets the values directly inside the wrapper call,
            # as global variables passed when calling exec(code, mod.__dict__, mod.__dict__).
            # For the cpp wrapper, we need to pass these Python values to the inductor_entry_impl function explicitly.
1150            assert all(
1151                isinstance(v, torch.Tensor) for v in list(V.graph.constants.values())
1152            ), "Expect all constants to be Tensor"
1153            constants_str = f"[{', '.join(V.graph.constants.keys())}]"
1154            wrapper_body += f"""
1155                    constants_tensor = {constants_str}
1156                    input_tensors.extend(constants_tensor)
1157            """
1158        # Convert vector of at::Tensor to vector of AtenTensorHandle.
1159        # If we pass at::Tensor, the compilation will be too slow.
1160        wrapper_body += """
1161                    input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors)
1162        """
1163        # Release the inputs for memory reuse.
1164        wrapper_body += """
1165                    args.clear()
1166        """
1167
1168        # unwrap output tensor back to python scalar
1169        if all(x for x in self.output_is_tensor.values()):
1170            # If no ShapeAsConstantBuffer in the output, directly return the output as tensors
1171            outputs_str = "output_tensors"
1172        else:
1173            outputs = [
1174                f"output_tensors[{i}]"
1175                if self.output_is_tensor[i]
1176                else f"output_tensors[{i}].item()"
1177                for i in range(len(V.graph.graph_outputs))
1178            ]
1179            outputs_str = f"[{', '.join(outputs)}]"
1180        wrapper_body += f"""
1181                    output_handles = f(input_handles)
1182                    output_tensors = torch._C._aoti.alloc_tensors_by_stealing_from_void_ptrs(output_handles)
1183                    return {outputs_str}
1184        """
1185
1186        # Wrap the func to support setting result._boxed_call = True
1187        result.splice(
1188            f"""
1189            def _wrap_func(f):
1190                def g(args):
1191                    {wrapper_body}
1192                return g
1193
1194            call = _wrap_func(inductor_entry)
1195            """
1196        )
1197
1198    def get_c_shim_func_name(self, kernel):
1199        if not config.abi_compatible or kernel.startswith("aoti_torch_"):
1200            return kernel
1201
1202        assert "::" in kernel, "Cpp kernel name: " + kernel + " does not contain '::'"
1203        kernel_tokens = kernel.split("::")
1204        kernel_suffix = kernel_tokens[-1]
1205        if kernel_suffix == "call":
1206            kernel_suffix = kernel_tokens[-2]
1207
1208        shim_fn = f"aoti_torch_{self.device}_{kernel_suffix}"
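        # Illustrative example (assuming self.device == "cpu"): "aten::nonzero" and
        # "at::ops::nonzero::call" both map to "aoti_torch_cpu_nonzero".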
1209        return shim_fn
1210
1211    def generate_c_shim_extern_kernel_call(self, kernel, args):
1212        # In the abi_compatible mode, we call fallback aten ops through a C shim layer.
1213        # Set self.allow_stack_allocation to False because the exchange between
1214        # ArrayRefTensor and at::Tensor is still fragile.
1215        self.allow_stack_allocation = False
1216
1217        wrapped_args = []
1218
1219        args_to_print_or_save = None
1220        debug_printer_manager = V.graph.wrapper_code.debug_printer
1221        if (
1222            debug_printer_manager.debug_printer_level
1223            != IntermediateValueDebuggingLevel.OFF
1224        ):
1225            args_to_print_or_save = []
1226
1227        for x in args:
1228            pieces = x.split(", ")
1229            for piece in pieces:
1230                # We only really *need* convert_arrayref_tensor_to_tensor for
1231                # ArrayRefTensors. The code flowing into here uses `0` for nullptr,
1232                # which convert_arrayref_tensor_to_tensor would blindly coerce to int,
1233                # so just avoid wrapping integers.
1234                # Using name matching to find tensor args is hacky, but fixing all the
1235                # ArrayRefTensor issues is not a priority for now.
1236                if isinstance(piece, str) and piece.startswith(
1237                    ("buf", "arg", "wrap_with_raii_handle_if_needed")
1238                ):
1239                    # TODO: The current way to find a 'tensor' type arg is hacky also as mentioned above
1240                    # Find a more reliable way to detect tensor kernel args for extern kernel calls
1241                    if (
1242                        debug_printer_manager.debug_printer_level
1243                        != IntermediateValueDebuggingLevel.OFF
1244                    ):
1245                        if piece.startswith(("buf", "arg")):
1246                            args_to_print_or_save.append(piece)
1247                    piece = f"convert_arrayref_tensor_to_tensor({piece})"
1248                wrapped_args.append(piece)
1249
1250        debug_printer_manager.set_printer_args(
1251            args_to_print_or_save, kernel, None, None
1252        )
1253        with debug_printer_manager:
1254            shim_fn = self.get_c_shim_func_name(kernel)
1255            self.writeline(
1256                f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(wrapped_args)}));"
1257            )
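        # A generated call looks roughly like (illustrative; kernel and buffer names are examples):
        #   AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_<kernel>(
        #       convert_arrayref_tensor_to_tensor(buf1),
        #       convert_arrayref_tensor_to_tensor(arg0_1), ...));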
1258
1259    def generate_c_shim_extern_kernel_alloc(self, extern_kernel, args):
1260        # registered output buffer name
1261        name = extern_kernel.name
1262        output_handle_name = f"{name}_handle"
1263        self.writeline(f"AtenTensorHandle {output_handle_name};")
1264        output_arg = f"&{output_handle_name}"
1265        self.generate_c_shim_extern_kernel_call(
1266            extern_kernel.get_kernel_name(), args + [output_arg]
1267        )
1268        self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});")
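        # Roughly, this emits (illustrative; "buf0" is an example output name):
        #   AtenTensorHandle buf0_handle;
        #   AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_<kernel>(..., &buf0_handle));
        #   RAIIAtenTensorHandle buf0(buf0_handle);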
1269
1270    def generate_extern_kernel_alloc(self, extern_kernel, args):
1271        if config.abi_compatible:
1272            self.generate_c_shim_extern_kernel_alloc(extern_kernel, args)
1273        else:
1274            super().generate_extern_kernel_alloc(extern_kernel, args)
1275
1276    def generate_c_shim_fallback_kernel(self, fallback_kernel, args):
1277        output_args = []
1278        output_raii_handles = []
1279        output_name_base = fallback_kernel.get_name()
1280        for idx, output in enumerate(fallback_kernel.outputs):
1281            if isinstance(output, ir.MultiOutput):
1282                # TODO: handle integer output (e.g., as in attention)
1283                name = f"{output.get_name()}"
1284                output_handle_name = f"{name}_handle"
1285                if output.indices:
1286                    assert (
1287                        output.indices[0][1] == idx
1288                    ), f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}"
1289                self.writeline(f"AtenTensorHandle {output_handle_name};")
1290                output_args.append(f"&{output_handle_name}")
1291                output_raii_handles.append(
1292                    f"RAIIAtenTensorHandle {name}({output_handle_name});"
1293                )
1294            elif isinstance(output, int):
1295                output_name = f"{output_name_base}_{idx}"
1296                self.writeline(f"int64_t {output_name} = {output};")
1297                output_args.append(f"&{output_name}")
1298            elif isinstance(output, sympy.Symbol):
1299                output_name = f"{output_name_base}_{idx}"
1300                self.writeline(f"auto {output_name} = {output};")
1301                output_args.append(f"&{output_name}")
1302            elif output is None:
1303                output_args.append("nullptr")
1304            else:
1305                raise NotImplementedError(f"unsupported type of {output=}")
1306        args = args + output_args
1307        self.generate_c_shim_extern_kernel_call(fallback_kernel.cpp_kernel_name, args)
1308        for raii_handle in output_raii_handles:
1309            self.writeline(raii_handle)
1310
1311    def generate_fallback_kernel(self, fallback_kernel, args):
1312        if config.abi_compatible:
1313            self.generate_c_shim_fallback_kernel(fallback_kernel, args)
1314        else:
1315            super().generate_fallback_kernel(fallback_kernel, args)
1316
1317    def generate_extern_kernel_out(
1318        self, kernel: str, out: str, out_view: Optional[str], args: List[str]
1319    ):
1320        if out_view:
1321            out_name = f"{out}_as_strided"
1322            self.writeline(f"auto {out_name} = {out_view};")
1323            args.insert(0, out_name)
1324        else:
1325            args.insert(0, out)
1326
1327        if config.abi_compatible:
1328            self.generate_c_shim_extern_kernel_call(kernel, args)
1329        else:
1330            # TODO: add debug printing info for non-abi compatible mode extern kernel call
1331            self.writeline(self.wrap_kernel_call(kernel, args))
1332
1333    def generate_scatter_fallback(
1334        self,
1335        output,
1336        inputs,
1337        cpp_kernel_name,
1338        python_kernel_name,
1339        src_is_tensor,
1340        reduce,
1341        kwargs,
1342    ):
1343        # No stack allocation when there is a fallback op
1344        self.allow_stack_allocation = False
1345
1346        if config.abi_compatible:
1347            # call the ABI shim function instead of the ATen one
1348            cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name)
1349            # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
1350            # TODO: consider removing "_out" and adding missing inplace variants to fallback_ops.py
1351            inputs_wrapped = [
1352                f"convert_arrayref_tensor_to_tensor({x})"
1353                if isinstance(x, str)
1354                else str(x)
1355                for x in inputs
1356            ]
1357            line = f"{cpp_kernel_name}(convert_arrayref_tensor_to_tensor({output}), {','.join(inputs_wrapped)}"
1358        else:
1359            line = f"{cpp_kernel_name}({','.join(map(str, inputs))}"
1360
1361        if python_kernel_name.startswith("aten.scatter_reduce"):
1362            line += f", {','.join(kwargs)}"
1363        else:
1364            if src_is_tensor:
1365                if reduce:
1366                    line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
1367            else:
1368                assert (
1369                    reduce is None
1370                ), "Expect reduce to be None for aten.scatter_ with scalar src"
1371        line += ");"
1372        self.writeline(line)
1373
1374    def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
1375        # No stack allocation when there is a fallback op
1376        self.allow_stack_allocation = False
1377
1378        # TODO: update aoti_torch_index_put_out in ir.py to use autogen out version
1379        if config.abi_compatible:
1380            # See the comment in codegen_reinterpret_view about why having something like
1381            # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the corresponding
1382            # tensor to be prematurely deallocated, hence the std::vector().data() trick here.
1383            indices_str = (
1384                "std::vector<AtenTensorHandle>{"
1385                + (
1386                    ", ".join(
1387                        [f"convert_arrayref_tensor_to_tensor({ind})" for ind in indices]
1388                    )
1389                )
1390                + "}.data()"
1391            )
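            # e.g. (illustrative) for two index tensors this yields:
            #   std::vector<AtenTensorHandle>{convert_arrayref_tensor_to_tensor(buf1),
            #       convert_arrayref_tensor_to_tensor(buf2)}.data()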
1392            args = [
1393                f"convert_arrayref_tensor_to_tensor({x})",
1394                indices_str,
1395                str(len(indices)),
1396                f"convert_arrayref_tensor_to_tensor({values})",
1397                accumulate,
1398            ]
1399            args.insert(
1400                0, f"convert_arrayref_tensor_to_tensor({x})"
1401            )  # set x as the output tensor, this fallback mutates x.
1402        else:
1403            indices_str = (
1404                f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}"
1405            )
1406            args = [x, indices_str, values, accumulate]
1407            args.insert(0, x)  # set x as the output tensor, this fallback mutates x.
1408
1409        self.writeline(self.wrap_kernel_call(kernel, args))
1410
1411    def add_benchmark_harness(self, output):
1412        if V.graph.aot_mode:
1413            return
1414        super().add_benchmark_harness(output)
1415
1416    def codegen_sizevar(self, x: Expr) -> str:
1417        return self.expr_printer(V.graph.sizevars.simplify(x))
1418
1419    def codegen_tuple_access(self, basename: str, name: str, index: str) -> str:
1420        if config.abi_compatible:
1421            # in the abi_compatible mode, outputs are returned via arguments
1422            return name
1423        else:
1424            return f"std::get<{index}>({basename})"
1425
1426    def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
1427        parts = list(map(self.codegen_sizevar, shape))
1428        if len(parts) == 0:
1429            return "{}"
1430        if len(parts) == 1:
1431            return f"{{{parts[0]}, }}"
1432        return f"{{{', '.join(parts)}}}"
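        # e.g. (illustrative): () -> "{}", (s0,) -> "{s0, }", (s0, s1) -> "{s0, s1}"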
1433
1434    def codegen_dynamic_scalar(self, node):
1435        (data,) = (t.codegen_reference() for t in node.inputs)
1436        if config.abi_compatible:
1437            self.codegen_tensor_item(
1438                node.inputs[0].get_dtype(), data, f"{node.sym}_raw"
1439            )
1440        else:
1441            convert_type = DTYPE_TO_ATEN[node.inputs[0].get_dtype()].replace(
1442                "at::k", "to"
1443            )
1444            self.writeline(f"auto {node.sym}_raw = {data}.item().{convert_type}();")
1445
1446        if len(node.keypath) == 0:
1447            self.writeline(f"auto {node.sym} = {node.sym}_raw;")
1448        elif len(node.keypath) == 1 and isinstance(node.keypath[0], ConvertIntKey):
1449            self.writeline(f"int64_t {node.sym} = {node.sym}_raw ? 1 : 0;")
1450        elif len(node.keypath) == 1 and isinstance(node.keypath[0], DivideByKey):
1451            # TODO: assert divisibility here
1452            self.writeline(
1453                f"int64_t {node.sym} = {node.sym}_raw / {node.keypath[0].divisor};"
1454            )
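            # e.g. (illustrative) for node.sym == u0 and a divisor of 4 this emits:
            #   int64_t u0 = u0_raw / 4;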
1455        else:
1456            raise AssertionError(f"unrecognized keypath {node.keypath}")
1457
1458        # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
1459        self.unbacked_symbol_decls.add(str(node.sym))
1460
1461    def can_stack_allocate_buffer(self, buffer):
1462        return (
1463            self.allow_stack_allocation
1464            and buffer.get_device().type == "cpu"
1465            and self.can_prove_buffer_has_static_shape(buffer)
1466            and ir.is_contiguous_strides_for_shape(
1467                buffer.get_stride(), buffer.get_size()
1468            )
1469        )
1470
1471    def make_buffer_free(self, buffer):
1472        return (
1473            ""
1474            if isinstance(buffer.get_layout(), ir.MultiOutputLayout)
1475            or (V.graph.aot_mode and buffer.get_name() in self.stack_allocated_buffers)
1476            or (
1477                config.use_minimal_arrayref_interface
1478                and V.graph.aot_mode
1479                and buffer.get_name() in V.graph.graph_inputs
1480            )
1481            else f"{buffer.get_name()}.reset();"
1482        )
1483
1484    def make_free_by_names(self, names_to_del: List[str]):
1485        return " ".join(f"{name}.reset();" for name in names_to_del)
1486
1487    def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str):
1488        if config.abi_compatible:
1489            return f"auto {new_name} = std::move({old_name});  // reuse"
1490        else:
1491            return super().codegen_exact_buffer_reuse(old_name, new_name, del_line)
1492
1493    def generate_profiler_mark_wrapper_call(self, stack):
1494        self.wrapper_call.writeline(
1495            'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef<c10::IValue>());'
1496        )
1497
1498    def write_triton_header_once(self):
1499        pass
1500
1501    def generate_start_graph(self):
1502        pass
1503
1504    def generate_end_graph(self):
1505        pass
1506
1507    def generate_inf_and_nan_checker(self, nodes):
1508        for buf in nodes.get_names():
1509            # TODO: Add buf name directly into check_inf_and_nan.
1510            self.writeline(
1511                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan({buf}));"
1512            )
1513
1514    def codegen_device(self, device):
1515        if config.abi_compatible:
1516            self.used_cached_devices.add(device.type)
1517            return f"cached_torch_device_type_{device.type}, {device.index if device.index else 0}"
1518        else:
1519            return (
1520                f"c10::Device({DEVICE_TO_ATEN[device.type]}, {device.index})"
1521                if device.index is not None
1522                else f"{DEVICE_TO_ATEN[device.type]}"
1523            )
1524
1525    def codegen_dtype(self, dtype):
1526        if config.abi_compatible:
1527            dtype_str = str(dtype).split(".")[-1]
1528            self.used_cached_dtypes.add(dtype_str)
1529            return f"cached_torch_dtype_{dtype_str}"
1530        else:
1531            return DTYPE_TO_ATEN[dtype]
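        # e.g. (illustrative) torch.float32 -> "cached_torch_dtype_float32" in
        # ABI-compatible mode; the cached constant itself is declared elsewhere
        # based on self.used_cached_dtypes.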
1532
1533    def codegen_layout(self, layout):
1534        if config.abi_compatible:
1535            layout_str = str(layout).split(".")[-1]
1536            self.used_cached_layouts.add(layout_str)
1537            return f"cached_torch_layout_{layout_str}"
1538        else:
1539            return LAYOUT_TO_ATEN[layout]
1540
1541    @functools.lru_cache(None)  # noqa: B019
1542    def codegen_int_array_var(
1543        self,
1544        int_array: str,
1545        writer=None,
1546        known_statically=False,
1547        graph=None,  # for per-graph caching
1548    ):
1549        # This is used for size/stride declaration
1550        # Because the memory planning is done in two passes (see the implementation
1551        # of self.generate), the writeline behavior is different in the two passes.
1552        # As a result, the emitted int array declarations may appear in a later
1553        # position of the generated code, so the second pass codegen should not
1554        # reuse int array declarations generated in the first pass
1555        if writer is None:
1556            # The first pass codegen uses `self` as the writer
1557            writer = self
1558
1559        var = f"int_array_{next(self.int_array_id)}"
1560        ctype = "int64_t"
1561        if var not in self.declared_int_array_vars:
1562            self.declared_int_array_vars.add(var)
1563            if known_statically:
1564                writer.writeline(f"static constexpr {ctype} {var}[] = {int_array};")
1565            else:
1566                writer.writeline(f"const {ctype} {var}[] = {int_array};")
1567        return var
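        # The emitted declaration has the form (illustrative):
        #   static constexpr int64_t int_array_0[] = {...};  // known_statically
        #   const int64_t int_array_1[] = {...};             // dynamic shapes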
1568
1569    def make_buffer_allocation(self, buffer):
1570        return self.make_allocation(
1571            buffer.get_name(),
1572            buffer.get_device(),
1573            buffer.get_dtype(),
1574            buffer.get_size(),
1575            buffer.get_stride(),
1576            buffer if self.can_stack_allocate_buffer(buffer) else None,
1577        )
1578
1579    def make_allocation(
1580        self, name, device, dtype, shape, stride, buffer_if_can_stack_allocate=None
1581    ):
1582        orig_stride = stride
1583        device_str = self.codegen_device(device)
1584        dtype_code = self.codegen_dtype(dtype)
1585        size = self.codegen_shape_tuple(shape)
1586        stride = self.codegen_shape_tuple(orig_stride)
1587        if config.abi_compatible:
1588            size_array_var = self.codegen_int_array_var(
1589                size,
1590                self.wrapper_call,
1591                known_statically=self.is_statically_known_list_of_ints(shape),
1592                graph=self.get_codegened_graph(),
1593            )
1594            stride_array_var = self.codegen_int_array_var(
1595                stride,
1596                self.wrapper_call,
1597                known_statically=self.is_statically_known_list_of_ints(orig_stride),
1598                graph=self.get_codegened_graph(),
1599            )
1600            device_type, device_id = device_str.split(",")
1601            device_idx = "this->device_idx_" if V.graph.aot_mode else device_id
1602            if buffer_if_can_stack_allocate is not None:
1603                self.stack_allocated_buffers[name] = buffer_if_can_stack_allocate
1604                cpp_type = DTYPE_TO_CPP[dtype]
1605                numel = buffer_if_can_stack_allocate.get_numel()
1606                # Note: we don't zero storage because empty_strided doesn't zero either.
1607                self.wrapper_call.writeline(f"{cpp_type} {name}_storage[{numel}];")
1608                args = [
1609                    f"{name}_storage",
1610                    size_array_var,
1611                    stride_array_var,
1612                    device_type,
1613                    device_idx,
1614                ]
1615                return f"ArrayRefTensor<{cpp_type}> {name}({', '.join(args)});"
1616
1617            args = [
1618                str(len(shape)),
1619                size_array_var,
1620                stride_array_var,
1621                dtype_code,
1622                device_type,
1623                device_idx,
1624                f"&{name}_handle",
1625            ]
1626
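            # The emitted allocation looks roughly like (illustrative; names, dtype,
            # and device are examples):
            #   AtenTensorHandle buf0_handle;
            #   AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(
            #       2, int_array_0, int_array_1, cached_torch_dtype_float32,
            #       cached_torch_device_type_cpu, 0, &buf0_handle));
            #   RAIIAtenTensorHandle buf0(buf0_handle);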
1627            self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;")
1628            self.wrapper_call.writeline(
1629                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));"
1630            )
1631
1632            return f"RAIIAtenTensorHandle {name}({name}_handle);"
1633
1634        if V.graph.aot_mode and device_str.startswith("c10::Device("):
1635            tensor_device = f"{device_str.split(',')[0]}, this->device_idx_)"
1636        else:
1637            tensor_device = device_str
1638
1639        if device.type == "cpu":
1640            return f"at::Tensor {name} = at::detail::empty_strided_cpu({size}, {stride}, {dtype_code});"
1641        if device.type == "cuda":
1642            return (
1643                f"at::Tensor {name} = at::detail::empty_strided_cuda("
1644                f"{size}, {stride}, {dtype_code}, c10::DeviceType::CUDA);"
1645            )
1646        return (
1647            f"{self.declare}{name} = {self.namespace}empty_strided("
1648            f"{size}, {stride}, at::TensorOptions({tensor_device}).dtype({dtype_code})){self.ending}"
1649        )
1650
1651    def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str:
1652        if config.abi_compatible:
1653            size = self.codegen_shape_tuple(shape)
1654            stride = self.codegen_shape_tuple(stride)
1655            tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}"
1656            args = [
1657                name,
1658                self.expr_printer(offset),  # bytes not numel
1659                self.codegen_dtype(dtype),
1660                str(len(shape)),
1661                self.codegen_int_array_var(
1662                    size, self.wrapper_call, graph=self.get_codegened_graph()
1663                ),
1664                self.codegen_int_array_var(
1665                    stride, self.wrapper_call, graph=self.get_codegened_graph()
1666                ),
1667                f"&{tmp_name}",
1668            ]
1669            self.wrapper_call.writeline(f"AtenTensorHandle {tmp_name};")
1670            self.wrapper_call.writeline(
1671                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));"
1672            )
1673            return f"RAIIAtenTensorHandle({tmp_name})"
1674
1675        return "alloc_from_pool({})".format(
1676            ", ".join(
1677                [
1678                    name,
1679                    self.expr_printer(offset),  # bytes not numel
1680                    self.codegen_dtype(dtype),
1681                    self.codegen_shape_tuple(shape),
1682                    self.codegen_shape_tuple(stride),
1683                ]
1684            )
1685        )
1686
1687    def codegen_reinterpret_view(
1688        self, data, size_list, stride_list, offset, writer, dtype=None
1689    ) -> str:
1690        dim = str(len(size_list))
1691        original_offset = offset
1692        size = self.codegen_shape_tuple(size_list)
1693        stride = self.codegen_shape_tuple(stride_list)
1694        offset = self.codegen_sizevar(offset)
1695        call_strs = []
1696        if config.abi_compatible:
1697            final_tmp_name = None
1698            final_tmp_name_is_RAIIAtenTensorHandle = False
1699
1700            def create_reinterpret_call():
1701                tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}"
1702                args = [
1703                    f"{data.get_name()}",
1704                    dim,
1705                    self.codegen_int_array_var(
1706                        size,
1707                        writer,
1708                        known_statically=self.is_statically_known_list_of_ints(
1709                            size_list
1710                        ),
1711                        graph=self.get_codegened_graph(),
1712                    ),
1713                    self.codegen_int_array_var(
1714                        stride,
1715                        writer,
1716                        known_statically=self.is_statically_known_list_of_ints(
1717                            stride_list
1718                        ),
1719                        graph=self.get_codegened_graph(),
1720                    ),
1721                    offset,
1722                ]
1723                call_str = (
1724                    f"auto {tmp_name} = reinterpret_tensor_wrapper({', '.join(args)});"
1725                )
1726                return tmp_name, call_str
1727
1728            def create_dtypeview_call(reinterpret_call):
1729                tmp_AtenTensorHandle = (
1730                    f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}"
1731                )
1732                call_strs = [f"AtenTensorHandle {tmp_AtenTensorHandle};"]
1733                dtype_name = str(dtype).split(".")[-1]
1734                device_name = "cuda" if data.layout.device.type == "cuda" else "cpu"
1735                get_dtype_function = f"aoti_torch_dtype_{dtype_name}"
1736                dtypeview_function = f"aoti_torch_{device_name}_view_dtype"
1737                call_strs.append(
1738                    f"AOTI_TORCH_ERROR_CODE_CHECK({dtypeview_function}"
1739                    f"({reinterpret_call}, {get_dtype_function}(), &{tmp_AtenTensorHandle}));"
1740                )
1741                tmp_RAIIAtenTensorHandle = (
1742                    f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}_handle"
1743                )
1744                call_strs.append(
1745                    f"RAIIAtenTensorHandle {tmp_RAIIAtenTensorHandle}({tmp_AtenTensorHandle});"
1746                )
1747                return tmp_RAIIAtenTensorHandle, call_strs
1748
1749            if (
1750                size_list == data.layout.size
1751                and stride_list == data.layout.stride
1752                and original_offset == data.layout.offset
1753            ):
1754                # pure dtypeview
1755                if dtype is not None and dtype != data.dtype:
1756                    tmp_output_name, tmp_call_strs = create_dtypeview_call(
1757                        data.get_name()
1758                    )
1759                    call_strs.extend(tmp_call_strs)
1760                    final_tmp_name = tmp_output_name
1761                    final_tmp_name_is_RAIIAtenTensorHandle = True
1762                else:
1763                    return f"{data.get_name()}"
1764            else:
1765                # first create the reinterpret view
1766                final_tmp_name, reinterpret_call = create_reinterpret_call()
1767                call_strs.append(reinterpret_call)
1768
1769                if dtype is not None and dtype != data.dtype:
1770                    # wrap it with dtypeview
1771                    final_tmp_name, tmp_call_strs = create_dtypeview_call(
1772                        reinterpret_call
1773                    )
1774                    call_strs.extend(tmp_call_strs)
1775            # Because the memory planning is done in two passes (see the implementation
1776            # of self.generate), the writeline behavior is different in the two passes.
1777            if writer is None:
1778                writer = self
1779            writer.writelines(call_strs)
1780            if (
1781                self.can_stack_allocate_buffer(data)
1782                and self.is_statically_known_list_of_ints(size_list)
1783                and self.is_statically_known_list_of_ints(stride_list)
1784                and ir.is_contiguous_strides_for_shape(stride_list, size_list)
1785            ):
1786                return final_tmp_name
1787
1788            # NB, the return handle here represents a temporary tensor, which will be automatically
1789            # released.
1790            # Here's a sample usage in the cpp wrapper code:
1791            # ```
1792            # AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_addmm_out(
1793            #     buf1,
1794            #     arg1_1,
1795            #     RAIIAtenTensorHandle(tmp_tensor_handle_0),
1796            #     buf0,
1797            #     1L,
1798            #     1L));
1799            # ```
1800            # RAIIAtenTensorHandle(tmp_tensor_handle_0) will be released after the call to addmm_out.
1801            # This could be problematic when it's used in a different pattern, for example:
1802            # ```
1803            # AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6};
1804            # aoti_torch_proxy_executor_call_function(..., tensor_args);
1805            # ```
1806            # RAIIAtenTensorHandle(tmp_tensor_handle_2) will be invalid when it's used in the latter
1807            # kernel call.
1808            #
1809            # This is solved by updating the proxy_executor invocation to
1810            # ```
1811            # aoti_torch_proxy_executor_call_function(...,
1812            #     std::vector<AtenTensorHandle>{
1813            #         RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6
1814            #     }.data()
1815            # );
1816            # ```
1817            if not final_tmp_name_is_RAIIAtenTensorHandle:
1818                return f"wrap_with_raii_handle_if_needed({final_tmp_name})"
1819            else:
1820                return final_tmp_name
1821        else:
1822            args = [data.get_name(), size, stride, offset]
1823            return f"reinterpret_tensor({', '.join(args)})"
1824
1825    def codegen_device_copy(self, src, dst):
1826        if config.abi_compatible:
1827            # aoti_torch_tensor_copy_ takes AtenTensorHandle as input,
1828            # while stack-allocation results in ArrayRefTensor
1829            # so disable stack allocation here
1830            self.allow_stack_allocation = False
1831            self.writeline(
1832                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_tensor_copy_(expensive_copy_to_tensor_if_needed({src}), {dst}));"
1833            )
1834        else:
1835            self.writeline(f"{dst}.copy_({src});")
1836
1837    def codegen_multi_output(self, name, value):
1838        # in the abi_compatible mode, outputs are retrieved by passing
1839        # output pointers, so we skip its codegen here.
1840        if not config.abi_compatible:
1841            super().codegen_multi_output(name, value)
1842
1843    def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
1844        for inner_input, outer_input in zip(subgraph.graph.graph_inputs, outer_inputs):
1845            if config.abi_compatible:
1846                # in ABI-compatible mode, we copy the underlying at::Tensor of the conditional
1847                # input (outer_input) into another at::Tensor to be used as a subgraph input
1848                # (inner_input) in the nested scope. we can't std::move here, as the codegened
1849                # outer input may be an expression / rvalue (e.g., reinterpret_view(x)), so we
1850                # can't necessarily std::move it back to the origin (x).
1851                self.writeline(f"AtenTensorHandle {inner_input}_handle;")
1852                self.writeline(
1853                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({outer_input}, &{inner_input}_handle));"
1854                )
1855                self.writeline(
1856                    f"RAIIAtenTensorHandle {inner_input}({inner_input}_handle);"
1857                )
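                # e.g. (illustrative) for inner_input "buf3" and outer_input "buf1":
                #   AtenTensorHandle buf3_handle;
                #   AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out(buf1, &buf3_handle));
                #   RAIIAtenTensorHandle buf3(buf3_handle);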
1858            else:
1859                self.writeline(
1860                    f"{self.declare}{inner_input} = {outer_input}{self.ending}"
1861                )
1862
1863    def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs):
1864        for inner_output, outer_output in zip(
1865            subgraph.graph.graph_outputs, outer_outputs
1866        ):
1867            src = inner_output.codegen_reference()
1868            if config.abi_compatible:
1869                # in ABI-compatible mode, we need to std::move subgraph output (inner_output)
1870                # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy
1871                # constructor is deleted.
1872                src = f"std::move({src})"
1873                # in case the outer_output carried a value
1874                # before (e.g., in the while_loop codegen)
1875                self.writeline(f"{outer_output}.reset();")
1876            self.writeline(f"{outer_output} = {src}{self.ending}")
1877
1878    def codegen_conditional(self, conditional):
1879        name = conditional.get_name()
1880        outer_inputs = [f"{buf.codegen_reference()}" for buf in conditional.operands]
1881        if config.abi_compatible:
1882            outer_outputs = []
1883            for out in conditional.outputs:
1884                # in ABI-compatible mode, ir.MultiOutput is not codegened,
1885                # hence pre-declare output variables directly and separately
1886                self.writeline(f"RAIIAtenTensorHandle {out.get_name()};")
1887                outer_outputs.append(out.get_name())
1888
1889            if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer):
1890                # in ABI-compatible mode, we need to use the ABI shim function
1891                # to extract a C++ bool from the underlying scalar bool Tensor
1892                predicate = f"{conditional.predicate.get_name()}_scalar"
1893                self.codegen_tensor_item(
1894                    torch.bool,
1895                    conditional.predicate.codegen_reference(),
1896                    predicate,
1897                )
1898            else:
1899                # the predicate is not a Tensor: SymBool or Python bool
1900                predicate = conditional.predicate.codegen_reference()
1901        else:
1902            # in non-ABI-compatible mode, we can codegen the conditional outputs
1903            # as array of at::Tensor instances, as the ir.MultiOutput is codegened
1904            outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]
1905            self.writeline(f"at::Tensor {name}[{len(conditional.outputs)}];")
1906            predicate = f"{conditional.predicate.codegen_reference()}"
1907            if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer):
1908                # move the Tensor predicate to host
1909                predicate = f"{predicate}.item<bool>()"
1910
1911        self.writeline(f"if ({predicate}) {{")
1912        self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph))
1913        self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs)
1914        self.writeline(ExitSubgraphLine(self))
1915        self.writeline("} else {")
1916        self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph))
1917        self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs)
1918        self.writeline(ExitSubgraphLine(self))
1919        self.writeline("}")
1920
1921    def codegen_while_loop(self, while_loop):
1922        name = while_loop.get_name()
1923        outer_carried_inputs = [
1924            buf.codegen_reference() for buf in while_loop.carried_inputs
1925        ]
1926        outer_additional_inputs = [
1927            buf.codegen_reference() for buf in while_loop.additional_inputs
1928        ]
1929        cond_result_name = f"{name}_cond_result"
1930
1931        if config.abi_compatible:
1932            self.writeline(f"RAIIAtenTensorHandle {cond_result_name};")
1933
1934            cond_outer_inputs = []
1935            for inp, out in zip(outer_carried_inputs, while_loop.outputs):
1936                # in ABI-compatible mode, the carried inputs are codegened
1937                # as buffers outside the while loop and set to the initial
1938                # values. At the end of each while_loop iteration, they
1939                # will be assigned the carried values.
1940                out_name = out.get_name()
1941                self.writeline(f"AtenTensorHandle {out_name}_handle;")
1942                self.writeline(
1943                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({inp}, &{out_name}_handle));"
1944                )
1945                self.writeline(f"RAIIAtenTensorHandle {out_name}({out_name}_handle);")
1946                cond_outer_inputs.append(out_name)
1947
1948            # additional inputs will be assigned within the while_loop
1949            # iteration directly from the corresponding outer graph buffers
1950            cond_outer_inputs.extend(outer_additional_inputs)
1951        else:
1952            self.writeline(f"at::Tensor {cond_result_name};")
1953            self.writeline(f"at::Tensor {name}[{len(outer_carried_inputs)}];")
1954            for i, inp in enumerate(outer_carried_inputs):
1955                # set the initial state before the loop
1956                self.writeline(f"{name}[{i}] = {inp};")
1957
1958            cond_outer_inputs = [
1959                *[f"{name}[{i}]" for i in range(len(outer_carried_inputs))],
1960                *outer_additional_inputs,
1961            ]
1962
1963        cond_outer_outputs = [cond_result_name]
1964        body_outer_inputs = list(cond_outer_inputs)
1965        body_outer_outputs = body_outer_inputs[: len(outer_carried_inputs)]
1966
1967        self.writeline("while (1) {")
1968        self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph))
1969        self.codegen_subgraph(
1970            while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
1971        )
1972
1973        if config.abi_compatible:
1974            cond_result = f"{cond_result_name}_scalar"
1975            self.codegen_tensor_item(torch.bool, cond_result_name, cond_result)
1976        else:
1977            cond_result = f"{cond_result_name}.item<bool>()"
1978        self.writeline(f"if (!{cond_result}) break;")
1979
1980        self.writeline(ExitSubgraphLine(self))
1981        self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
1982        self.codegen_subgraph(
1983            while_loop.body_subgraph, body_outer_inputs, body_outer_outputs
1984        )
1985        self.writeline(ExitSubgraphLine(self))
1986        self.writeline("}")
1987
1988    def generate_extern_kernel_args_decl_if_needed(
1989        self, op_overload, raw_args, output_args
1990    ):
1991        arg_types = [x.real_type for x in op_overload._schema.arguments]
1992        return_types = [x.type for x in op_overload._schema.returns]
1993
1994        new_tensor_args = []
1995        new_int_args = []
1996
1997        def fill_args(arg, arg_type):
1998            static_arg_types = (
1999                torch.FloatType,
2000                torch.BoolType,
2001                torch.StringType,
2002                torch.Type,
2003                torch.DeviceObjType,
2004            )
2005            inductor_tensor_buffers = (
2006                ir.Buffer,
2007                ir.ReinterpretView,
2008            )
2009
2010            if isinstance(arg_type, torch.TensorType):
2011                assert isinstance(arg, inductor_tensor_buffers), f"got {type(arg)}"
2012                new_tensor_args.append(f"{arg.codegen_reference()}")
2013            elif isinstance(arg_type, torch.IntType):
2014                # int
2015                new_int_args.append(str(arg))
2016            elif isinstance(arg_type, torch.SymIntType):
2017                # SymInt
2018                expr = arg.node.expr if isinstance(arg, torch.SymInt) else arg
2019                new_int_args.append(self.expr_printer(expr))
2020            elif isinstance(arg_type, torch.NumberType):
2021                # Scalar (int, float, or bool)
2022                assert isinstance(arg, (int, float, bool))
2023                # Only treat int Scalar as dynamic
2024                if isinstance(arg, int):
2025                    new_int_args.append(str(arg))
2026            elif isinstance(arg_type, torch.ListType):
2027                assert isinstance(arg, (list, tuple))
2028
2029                # List[Tensor]
2030                if isinstance(arg_type.getElementType(), torch.TensorType):
2031                    new_tensor_args.extend([f"{a.codegen_reference()}" for a in arg])
2032                # List[Optional[Tensor]]
2033                elif isinstance(
2034                    arg_type.getElementType(), torch.OptionalType
2035                ) and isinstance(
2036                    arg_type.getElementType().getElementType(), torch.TensorType
2037                ):
2038                    new_tensor_args.extend(
2039                        [f"{a.codegen_reference()}" for a in arg if a is not None]
2040                    )
2041                # List[int]
2042                elif isinstance(arg_type.getElementType(), torch.IntType):
2043                    new_int_args.extend([str(a) for a in arg])
2044                # List[SymInt]
2045                elif isinstance(arg_type.getElementType(), torch.SymIntType):
2046                    expressions = [
2047                        a.node.expr if isinstance(a, torch.SymInt) else a for a in arg
2048                    ]
2049                    new_int_args.extend(
2050                        [self.expr_printer(expr) for expr in expressions]
2051                    )
2052                # List[Scalar]
2053                elif isinstance(arg_type.getElementType(), torch.NumberType):
2054                    # Only treat int Scalar as dynamic
2055                    is_int_type = [isinstance(a, int) for a in arg]
2056                    if any(is_int_type):
2057                        assert all(
2058                            is_int_type
2059                        ), "AOTInductor only supports int scalars of the same type"
2060                        new_int_args.extend([str(a) for a in arg])
2061                else:
2062                    assert isinstance(
2063                        arg_type.getElementType(), static_arg_types  # type: ignore[arg-type]
2064                    ), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
2065            else:
2066                assert isinstance(
2067                    arg_type, static_arg_types  # type: ignore[arg-type]
2068                ), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
2069
2070        for arg, arg_type in zip(raw_args, arg_types):
2071            if arg is not None:
2072                if isinstance(arg_type, torch.OptionalType):
2073                    fill_args(arg, arg_type.getElementType())
2074                else:
2075                    fill_args(arg, arg_type)
2076
2077        def fill_output_arg(arg, return_type):
2078            if isinstance(return_type, torch.TensorType):
2079                self.writeline(f"AtenTensorHandle {arg}_handle;  // output buffer")
2080                self.writeline(
2081                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));"
2082                )
2083                self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);")
2084                new_tensor_args.append(f"{arg}")
2085            elif isinstance(return_type, torch.SymIntType):
2086                raise NotImplementedError("NYI support for return type: SymInt")
2087            elif isinstance(return_type, torch.ListType) and isinstance(
2088                return_type.getElementType(), torch.SymIntType
2089            ):
2090                raise NotImplementedError("NYI support for return type: List[SymInt]")
2091            else:
2092                raise AssertionError(f"Unsupported return type found: {return_type}")
2093
2094        # TODO: Only support tensor(s) returns for now, SymInt is not implemented yet
2095        for return_type in return_types:
2096            if isinstance(return_type, (torch.TensorType)):
2097                pass
2098            elif isinstance(return_type, torch.OptionalType):
2099                assert isinstance(return_type.getElementType(), torch.TensorType)
2100            elif isinstance(return_type, torch.ListType):
2101                assert isinstance(return_type.getElementType(), torch.TensorType)
2102            else:
2103                raise NotImplementedError(
2104                    f"return type {return_type} is not yet supported."
2105                )
2106
2107        for output_arg in output_args:
2108            assert output_arg is not None, "Optional return types are not yet supported"
2109            if isinstance(output_arg, (list, tuple)):
2110                for out in output_arg:
2111                    fill_output_arg(out, torch.TensorType.get())
2112            else:
2113                fill_output_arg(output_arg, torch.TensorType.get())
2114
2115        return new_tensor_args, new_int_args
2116
2117    def generate_extern_kernel_alloc_and_find_schema_if_needed(
2118        self,
2119        buf_name: str,
2120        python_kernel_name: str,
2121        cpp_kernel_name: str,
2122        codegen_args: List[str],
2123        cpp_op_schema: str,
2124        cpp_kernel_key: str,
2125        cpp_kernel_overload_name: str = "",
2126        op_overload: Optional[torch._ops.OpOverload] = None,
2127        raw_args=None,
2128        outputs=None,
2129    ):
2130        # No stack allocation when there is a fallback op
2131        self.allow_stack_allocation = False
2132
2133        def extract_output_name(out):
2134            if out is None:
2135                # Because out is not a MultiOutput, we assume the kernel returns a single output
2136                return [buf_name]
2137            elif isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)):
2138                return out.get_name()
2139            elif isinstance(out, (list, tuple)):
2140                return type(out)(extract_output_name(o) for o in out)
2141            else:
2142                raise AssertionError(f"Unexpected output: {type(out)}")
2143
2144        # output_args has the same pytree structure as outputs
2145        output_args = None
2146        if config.abi_compatible:
2147            output_args = extract_output_name(outputs)
2148            if isinstance(output_args, str):
2149                output_args = [output_args]
2150
2151        if V.graph.aot_mode and config.abi_compatible:
2152            assert op_overload is not None
2153            assert raw_args is not None
2154            assert outputs is not None
2155
2156            return self.generate_extern_kernel_alloc_and_find_schema_if_needed_with_proxy_executor(
2157                cpp_kernel_key,
2158                op_overload,
2159                raw_args,
2160                output_args,
2161            )
2162        else:
2163            return self.generate_extern_kernel_alloc_and_find_schema_if_needed_jit(
2164                buf_name,
2165                python_kernel_name,
2166                cpp_kernel_name,
2167                codegen_args,
2168                cpp_op_schema,
2169                cpp_kernel_key,
2170                cpp_kernel_overload_name,
2171                op_overload,
2172                raw_args,
2173                output_args,
2174            )
2175
2176    def generate_scoped_gil_acquire(self, declarations_before_scope, lines_in_scope):
2177        scoped_lines = IndentedBuffer()
2178        for declaration in declarations_before_scope:
2179            scoped_lines.writeline(declaration)
2180
2181        scoped_lines.writeline("{")
2182        with scoped_lines.indent():
2183            scoped_lines.writeline("py::gil_scoped_acquire acquire;")
2184            scoped_lines.writelines(lines_in_scope.split("\n"))
2185        scoped_lines.writeline("}")
2186        return scoped_lines._lines
2187
2188    def load_custom_op_wrapper(self):
2189        # TODO: need to support control flow
2190        if self.custom_op_wrapper_loaded:
2191            return
2192
2193        lines = """
2194RAIIPyObject codecache_module(PyImport_ImportModule("torch._inductor.codecache"));
2195if (codecache_module.get() == NULL) {
2196    throw std::runtime_error("Failed to load torch._inductor.codecache");
2197}
2198custom_op_wrapper = PyObject_GetAttrString(codecache_module, "custom_op_wrapper");
2199if (custom_op_wrapper.get() == NULL) {
2200    throw std::runtime_error("Failed to load torch._inductor.codecache.custom_op_wrapper");
2201}"""
2202
2203        declarations_before_scope = ["RAIIPyObject custom_op_wrapper;"]
2204        scope_gil_acquire = self.generate_scoped_gil_acquire(
2205            declarations_before_scope, lines
2206        )
2207        self.writelines(scope_gil_acquire)
2208
2209        self.custom_op_wrapper_loaded = True
2210
2211    def generate_py_arg(self, py_args_var, idx, raw_arg, arg_type):
2212        def generate_py_arg_inner(lines, raw_arg, arg_type):
2213            if raw_arg is None:
2214                # Py_None is a singleton, so we have to explicitly incref it here
2215                lines.append("Py_INCREF(Py_None);\n")
2216                return "Py_None"
2217            elif isinstance(arg_type, torch.TensorType):
2218                # Store AtenTensorHandle as void*
2219                base_handle = raw_arg.codegen_reference()
2220                (
2221                    tmp_raii_handle_var,
2222                    tmp_raii_handle_var_decl,
2223                ) = self.create_tmp_raii_handle_var(base_handle)
2224                if tmp_raii_handle_var:
2225                    lines.append(tmp_raii_handle_var_decl)
2226                    base_handle = tmp_raii_handle_var
2227                return f"PyCapsule_New(reinterpret_cast<void*>({base_handle}.get()), NULL, NULL)"
2228            elif isinstance(arg_type, torch.OptionalType):
2229                return generate_py_arg_inner(lines, raw_arg, arg_type.getElementType())
2230            elif isinstance(arg_type, torch.IntType):
2231                # int
2232                return f"PyLong_FromLongLong({raw_arg})"
2233            elif isinstance(arg_type, torch.SymIntType):
2234                # SymInt
2235                expr = (
2236                    raw_arg.node.expr if isinstance(raw_arg, torch.SymInt) else raw_arg
2237                )
2238                return f"PyLong_FromLongLong({self.expr_printer(expr)})"
2239            elif isinstance(arg_type, torch.FloatType):
2240                return f"PyFloat_FromDouble({raw_arg})"
2241            elif isinstance(arg_type, torch.BoolType):
2242                return f"PyBool_FromLong({1 if raw_arg else 0})"
2243            elif isinstance(arg_type, torch.StringType):
2244                return f'PyUnicode_FromString("{raw_arg}")'
2245            elif isinstance(arg_type, torch.NumberType):
2246                # Union[bool, int, float, complex]
2247                # torch/_prims_common/__init__.py
2248                if isinstance(raw_arg, int):
2249                    return f"PyLong_FromLongLong({raw_arg})"
2250                elif isinstance(raw_arg, float):
2251                    return f"PyFloat_FromDouble({raw_arg})"
2252                elif isinstance(raw_arg, bool):
2253                    return f"PyBool_FromLong({1 if raw_arg else 0})"
2254                elif isinstance(raw_arg, complex):
2255                    return f"PyComplex_FromDoubles({raw_arg.real}, {raw_arg.imag})"
2256                elif isinstance(raw_arg, torch.SymInt):
2257                    expr = raw_arg.node.expr
2258                    return f"PyLong_FromLongLong({self.expr_printer(expr)})"
2259                else:
2260                    raise NotImplementedError(
2261                        f"arg type {arg_type} with raw_arg {raw_arg}, {type(raw_arg)} is not yet supported by custom_op_wrapper"
2262                    )
2263            elif isinstance(raw_arg, torch.dtype):
2264                # dtype
2265                self.include_extra_header("torch/csrc/DynamicTypes.h")
2266                return f"Py_NewRef(torch::getTHPDtype(static_cast<c10::ScalarType>({self.codegen_dtype(raw_arg)})))"
2267            else:
2268                raise NotImplementedError(
2269                    f"arg type {arg_type} is not yet supported by custom_op_wrapper"
2270                )
2271
2272        lines = []
2273        if isinstance(arg_type, torch.ListType):
2274            assert isinstance(raw_arg, (list, tuple)), str(raw_arg) + " is not a list"
2275            lines.append(
2276                f"PyObject* {py_args_var}_{idx} = PyList_New({len(raw_arg)});\n"
2277            )
2278            for i, elem in enumerate(raw_arg):
2279                lines.append(
2280                    f"PyList_SetItem({py_args_var}_{idx}, {i}, {generate_py_arg_inner(lines, elem, arg_type.getElementType())});\n"
2281                )
2282            lines.append(
2283                f"PyTuple_SetItem({py_args_var}, {idx}, {py_args_var}_{idx});\n"
2284            )
2285        else:
2286            lines.append(
2287                f"PyTuple_SetItem({py_args_var}, {idx}, {generate_py_arg_inner(lines, raw_arg, arg_type)});\n"
2288            )
2289        return "".join(lines)
2290
2291    def generate_extern_kernel_alloc_and_find_schema_if_needed_jit(
2292        self,
2293        buf_name: str,
2294        python_kernel_name: str,
2295        cpp_kernel_name: str,
2296        codegen_args: List[str],
2297        cpp_op_schema: str,
2298        cpp_kernel_key: str,
2299        cpp_kernel_overload_name: str = "",
2300        op_overload: Optional[torch._ops.OpOverload] = None,
2301        raw_args=None,
2302        output_args: Optional[List[str]] = None,
2303    ):
2304        if not config.abi_compatible:
2305            # Will update this to use an OSS version ProxyExecutor
2306            if cpp_kernel_key not in self.extern_call_ops:
2307                self.writeline(
2308                    f"static auto op_{cpp_kernel_key} = c10::Dispatcher::singleton()"
2309                )
2310                self.writeline(
2311                    f'\t.findSchemaOrThrow("{cpp_kernel_name}", "{cpp_kernel_overload_name}")'
2312                )
2313                self.writeline(f"\t.typed<{cpp_op_schema}>();")
2314                self.extern_call_ops.add(cpp_kernel_key)
2315
2316            self.writeline(
2317                f"auto {buf_name} = op_{cpp_kernel_key}.call({', '.join(codegen_args)});"
2318            )
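            # Roughly (illustrative; the kernel name, schema, and buffer names are made up):
            #   static auto op_aten_foo = c10::Dispatcher::singleton()
            #       .findSchemaOrThrow("aten::foo", "")
            #       .typed<at::Tensor(const at::Tensor&)>();
            #   auto buf2 = op_aten_foo.call(arg0_1);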
2319        else:
2320            # In the JIT mode, because of the ABI-compatible requirement, we can't directly call
2321            # c10::Dispatcher to find the custom op and call it. Instead, we go back to Python
2322            # to invoke this custom op.
2323            self.load_custom_op_wrapper()
2324
2325            assert output_args is not None, "output_args should not be None"
2326            num_args = len(raw_args)
2327            py_args_var = f"py_args_{next(self.arg_var_id)}"
2328            # First arg is always the python op name
2329            lines = f"""
2330RAIIPyObject {py_args_var}(PyTuple_New({num_args+1}));
2331if ({py_args_var}.get() == NULL) {{
2332    throw std::runtime_error("PyTuple_New {py_args_var} failed");
2333}}
2334PyTuple_SetItem({py_args_var}, 0, PyUnicode_FromString("{python_kernel_name}"));
2335"""
2336
2337            assert op_overload is not None, "op_overload should not be None"
2338
2339            for idx, (raw_arg, schema_arg) in enumerate(
2340                zip(raw_args, op_overload._schema.arguments)
2341            ):
2342                lines += self.generate_py_arg(
2343                    py_args_var, idx + 1, raw_arg, schema_arg.real_type
2344                )
2345
2346            lines += f"""
2347// Call the custom op in Python
2348RAIIPyObject py_{buf_name}(PyObject_CallObject(custom_op_wrapper, {py_args_var}));
2349if (py_{buf_name}.get() == NULL) {{
2350    throw std::runtime_error("PyObject_CallObject {python_kernel_name} failed");
2351}}"""
2352
2353            if len(output_args) == 1:
2354                # result is a single tensor
2355                lines += f"""
2356{output_args[0]} = reinterpret_cast<AtenTensorHandle>(PyCapsule_GetPointer(py_{buf_name}.get(), NULL));"""
2357            else:
2358                # result is a tuple of tensors
2359                for idx, output_arg in enumerate(output_args):
2360                    lines += f"""
2361{output_arg} =
2362    reinterpret_cast<AtenTensorHandle>(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_name}.get(), {idx}), NULL));"""
2363
2364            declarations_before_scope = [
2365                f"RAIIAtenTensorHandle {output_arg};"
2366                for idx, output_arg in enumerate(output_args)
2367            ]
2368            scope_gil_acquire = self.generate_scoped_gil_acquire(
2369                declarations_before_scope, lines
2370            )
2371            self.writelines(scope_gil_acquire)
2372
2373    def generate_extern_kernel_alloc_and_find_schema_if_needed_with_proxy_executor(
2374        self,
2375        cpp_kernel_key,
2376        op_overload,
2377        raw_args,  # contains both args and flattened kwargs
2378        output_args: Optional[List[str]] = None,
2379    ):
        (
            tensor_call_args,
            int_call_args,
        ) = self.generate_extern_kernel_args_decl_if_needed(
            op_overload, raw_args, output_args
        )

        tensor_call_args_str = ", ".join(tensor_call_args)
        int_call_args_str = ", ".join(int_call_args)

        extern_kernel_node_index = len(V.graph.extern_kernel_nodes) - 1

        self.writeline(
            f"aoti_torch_proxy_executor_call_function(proxy_executor, "
            f"{extern_kernel_node_index}, "
            f"{len(int_call_args)}, "
            f"std::vector<int64_t>{{{int_call_args_str}}}.data(), "
            f"{len(tensor_call_args)}, "
            f"std::vector<AtenTensorHandle>{{{tensor_call_args_str}}}.data());"
        )

        self.extern_call_ops.add(cpp_kernel_key)

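    # The two hooks below are inherited from WrapperCodeGen; they are intentionally
    # no-ops for the C++ wrapper.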
    def generate_reset_kernel_saved_flags(self):
        pass

    def generate_save_uncompiled_kernels(self):
        pass

    def c_type_for_prim_type(self, val, type_) -> str:
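        """
        Map a torch schema type (plus, for NumberType, the concrete value) to the C type
        used for that argument in the ABI-compatible C shim.

        Rough examples, derived from the branches below: TensorType -> "AtenTensorHandle",
        IntType/SymIntType -> "int64_t", and Optional[T] -> pointer to the C type of T.
        """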
        assert (
            config.abi_compatible
        ), "c_type_for_prim_type is only used in ABI compatible mode"
        if isinstance(type_, torch.OptionalType):
            return f"{self.c_type_for_prim_type(val, type_.getElementType())}*"
        elif isinstance(type_, torch.TensorType):
            return "AtenTensorHandle"
        elif isinstance(type_, (torch.IntType, torch.SymIntType)):
            return "int64_t"
        elif isinstance(
            type_, (torch.BoolType, torch.SymBoolType, torch.EnumType)
        ) or repr(type_) in ("ScalarType", "Layout"):
            return "int32_t"
        elif isinstance(type_, torch.FloatType):
            return "double"
        elif isinstance(type_, torch.NumberType):
            if isinstance(val, bool):
                return "int32_t"
            elif isinstance(val, int):
                return "int64_t"
            elif isinstance(val, float):
                return "double"
            elif val is None:
                # This could happen when val is an optional value
                return "double"
            else:
                raise AssertionError(
                    f"Unexpected value for NumberType in c_type_for_prim_type: {val=}"
                )
        elif isinstance(type_, torch.StringType):
            return "const char*"
        else:
            raise AssertionError(f"Unexpected type in c_type_for_prim_type: {type_=}")

    def val_to_arg_str_for_prim_type(self, val, type_) -> str:
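        """
        Render a primitive Python value as the literal/expression string spelled into the
        generated C++ call.

        Sketch of the mapping below: bools become "1"/"0" in ABI-compatible mode (or
        "true"/"false" otherwise), ints get an "L"/"LL" suffix depending on the platform,
        IR buffers use their codegen reference, and symbolic values go through
        self.expr_printer.
        """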
        # TODO: type_ is not used yet; this is the first step of a refactor and will be updated later.
        if isinstance(val, bool):
            if config.abi_compatible:
                return "1" if val else "0"
            else:
                return "true" if val else "false"
        elif isinstance(val, int):
            # int64_t is long on Linux, but long long on macOS and Windows, hence the suffix
            return f"{val}LL" if sys.platform in ["darwin", "win32"] else f"{val}L"
        elif isinstance(val, str):
            return f'"{val}"'
        elif isinstance(
            val, (ir.Buffer, ir.ReinterpretView, ir.StorageBox, ir.TensorBox)
        ):
            return val.codegen_reference()
        elif isinstance(val, torch.device):
            return self.codegen_device(val)
        elif isinstance(val, torch.dtype):
            return self.codegen_dtype(val)
        elif isinstance(val, float) and val in [float("inf"), float("-inf")]:
            if val == float("inf"):
                return "std::numeric_limits<float>::infinity()"
            else:
                return "-std::numeric_limits<float>::infinity()"
        elif isinstance(val, (list, tuple)):
            # FIXME: This happens because type_ is not always properly set to torch.ListType
            return f"{{{', '.join(self.val_to_arg_str(x, None) for x in val)}}}"
        elif isinstance(val, SymTypes):
            return self.expr_printer(val.node.expr)
        elif isinstance(val, sympy.Expr):
            return self.expr_printer(val)
        else:
            return repr(val)

    def val_to_arg_str(self, val, type_=None) -> str:
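        """
        Convert a (value, schema type) pair into the argument string used at the call site,
        emitting any helper declarations through self.writeline along the way.

        Illustrative examples in ABI-compatible mode (variable names are hypothetical):
        - None for an optional scalar returns "0" (a null pointer spelled in plain C).
        - [2, 3] with type List[int] emits "const int64_t var_array_0[] = {2L, 3L};" and
          returns "var_array_0, 2", i.e. a pointer plus an explicit length.
        - 5 with type Optional[int] emits "int64_t var_1 = 5L;" and returns "&var_1".
        """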
        if val is None:
            # None needs special care: it represents either nullopt or an empty tensor
            if config.abi_compatible:
                if type_ is None or isinstance(type_, torch.OptionalType):
                    if type_ is not None and isinstance(
                        type_.getElementType(),
                        (
                            torch.ListType,
                            torch.TupleType,
                            torch.DeviceObjType,
                        ),
                    ):
                        return "0, 0"
                    else:
                        return "0"  # nullptr is not available in C
                elif isinstance(type_, torch.TensorType):
                    # create an empty tensor, the equivalent of at::Tensor()
                    var_name = f"var_{next(self.arg_var_id)}"
                    self.writeline(f"AtenTensorHandle {var_name}_handle;")
                    self.writeline(
                        f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{var_name}_handle));"
                    )
                    self.writeline(
                        f"RAIIAtenTensorHandle {var_name}({var_name}_handle);"
                    )
                    return var_name
                else:
                    raise AssertionError("Can not map None to a known data type")
            else:
                return "std::nullopt"

        if isinstance(type_, torch.OptionalType):
            element_type = type_.getElementType()
            if config.abi_compatible:
                if not isinstance(element_type, torch.TensorType):
                    var_name = f"var_{next(self.arg_var_id)}"
                    if isinstance(
                        element_type,
                        (torch.ListType, torch.TupleType, torch.DeviceObjType),
                    ):
                        # type_ is something like Optional[List] or Optional[Device]
                        arg_str = self.val_to_arg_str(val, element_type)
                        # For datatypes with auxiliary info, we need to hoist out the extra arguments.
                        # NOTE: This only works if there is one additional argument, though it can easily be generalized.
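                        # Illustrative example (hypothetical names): for Optional[List[int]] with
                        # val=[2, 3], the recursive call returns something like "var_array_0, 2";
                        # we keep "var_array_0" as the main value, emit "auto var_1 = var_array_0;",
                        # and return "&var_1, 2" for the call site.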
                        main_value, aux = arg_str.rsplit(", ", 1)
                        self.writeline(f"auto {var_name} = {main_value};")
                        return f"&{var_name}, {aux}"
                    else:
                        self.writeline(
                            f"{self.c_type_for_prim_type(val, element_type)} {var_name} = {self.val_to_arg_str(val, element_type)};"
                        )
                        return f"&{var_name}"
                else:
                    # type_ is Optional[Tensor]
                    # Similar to other data types, use a pointer to denote an optional tensor arg in the v2 C shim
                    base_handle = self.val_to_arg_str(val, element_type)
                    if config.use_minimal_arrayref_interface:
                        base_handle = (
                            f"convert_arrayref_tensor_to_tensor({base_handle})"
                        )
                    (
                        tmp_raii_handle_var,
                        tmp_raii_handle_var_decl,
                    ) = self.create_tmp_raii_handle_var(base_handle)
                    if tmp_raii_handle_var:
                        self.writeline(tmp_raii_handle_var_decl)
                        base_handle = tmp_raii_handle_var
                    var_name = f"var_{next(self.arg_var_id)}"
                    self.writeline(
                        f"AtenTensorHandle {var_name} = {base_handle}.get();"
                    )
                    return f"&{var_name}"
            else:
                return self.val_to_arg_str(val, element_type)

        elif isinstance(type_, torch.ListType):
            assert isinstance(
                val, (list, tuple)
            ), f"{val} does not match with arg type {type_}"
            element_type = type_.getElementType()
            if config.abi_compatible:
                var_name = f"var_array_{next(self.var_array_id)}"
                if len(val) == 0:
                    # Zero-size array is not supported in the C or C++ standard, so
                    # we declare a null pointer for it.
                    self.writeline(
                        f"const {self.c_type_for_prim_type(None, element_type)}* {var_name} = nullptr;"
                    )
                else:
                    result = f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}"
                    self.writeline(
                        f"const {self.c_type_for_prim_type(val[0], element_type)} {var_name}[] = {result};"
                    )
                # Need to pass the array length because we can't use std::vector
                return f"{var_name}, {len(val)}"
            else:
                return f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}"

        return self.val_to_arg_str_for_prim_type(val, type_)

    def create_tmp_raii_handle_var(self, base_handle):
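        """
        If base_handle is a temporary created by one of the wrapper calls checked below,
        return (var_name, declaration) so the caller can pin it in a named
        RAIIAtenTensorHandle; otherwise return ("", "").
        """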
        if base_handle.startswith(
            (
                "convert_arrayref_tensor_to_tensor",
                "wrap_with_raii_handle_if_needed",
            )
        ):
            # Both of these calls create a temp RAIIAtenTensorHandle, so we need to
            # explicitly store it. Otherwise, it will be destroyed before the fallback kernel call.
            tmp_var_name = f"var_{next(self.arg_var_id)}"
            return (
                tmp_var_name,
                f"RAIIAtenTensorHandle {tmp_var_name} = {base_handle};\n",
            )
        else:
            return "", ""
