from __future__ import annotations

import textwrap
from dataclasses import dataclass
from typing import Sequence

from torchgen.api.types import DispatcherSignature
from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
from torchgen.context import method_with_native_function
from torchgen.model import (
    Argument,
    BackendIndex,
    BaseTy,
    BaseType,
    DispatchKey,
    FunctionSchema,
    ListType,
    NativeFunction,
    NativeFunctionsGroup,
    OperatorName,
    OptionalType,
    Type,
)
from torchgen.utils import mapMaybe


base_type_to_c_type = {
    BaseTy.Tensor: "AtenTensorHandle",
    BaseTy.bool: "int32_t",  # Use int to pass bool
    BaseTy.int: "int64_t",
    BaseTy.SymInt: "int64_t",  # Inductor-generated code won't see a SymInt
    BaseTy.Scalar: "double",  # Use double to pass both integer and floating point
    BaseTy.float: "double",  # TODO: how about other floating point types?
    BaseTy.str: "const char*",
    BaseTy.DeviceIndex: "int32_t",
    BaseTy.Layout: "int32_t",  # Represent enum as int
    BaseTy.MemoryFormat: "int32_t",  # Represent enum as int
    BaseTy.ScalarType: "int32_t",  # Represent enum as int
    BaseTy.Generator: "AtenGeneratorHandle",
}

base_type_to_aten_type = {
    BaseTy.Tensor: "at::Tensor",
    BaseTy.bool: "bool",
    BaseTy.int: "int64_t",
    BaseTy.SymInt: "c10::SymInt",
    BaseTy.Scalar: "c10::Scalar",
    BaseTy.float: "double",
    BaseTy.str: "c10::string_view",
    BaseTy.DeviceIndex: "c10::DeviceIndex",
    BaseTy.Layout: "c10::Layout",
    BaseTy.MemoryFormat: "c10::MemoryFormat",
    BaseTy.ScalarType: "c10::ScalarType",
    BaseTy.Generator: "at::Generator",
}

base_type_to_callsite_expr = {
    BaseTy.Tensor: "*tensor_handle_to_tensor_pointer",
    BaseTy.bool: "",
    BaseTy.int: "",
    BaseTy.SymInt: "",
    BaseTy.Scalar: "",
    BaseTy.float: "",
    BaseTy.str: "",
    BaseTy.DeviceIndex: "static_cast<c10::DeviceIndex>",
    BaseTy.Layout: "static_cast<c10::Layout>",
    BaseTy.MemoryFormat: "static_cast<c10::MemoryFormat>",
    BaseTy.ScalarType: "static_cast<c10::ScalarType>",
    BaseTy.Generator: "*generator_handle_to_generator_pointer",
}


# convert args to C types, names in declarations, and expressions in function bodies
def convert_arg_type_and_name(typ: Type, name: str) -> tuple[list[str], list[str], list[str], list[str]]:  # type: ignore[return]
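    """Convert one schema argument into its C shim representation.

    Returns four parallel lists: C argument types, C argument names, the
    corresponding ATen C++ types, and the callsite expressions that rebuild
    the ATen values inside the shim body. A single schema argument may expand
    into several C arguments (a list becomes pointer + length, a Device
    becomes device_type + device_index). For example, an optional Tensor
    argument named "x" becomes the C argument "AtenTensorHandle* x" with the
    callsite expression "pointer_to_optional<at::Tensor>(x)".
    """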
    if isinstance(typ, BaseType):
        if typ.name in base_type_to_c_type:
            return (
                [base_type_to_c_type[typ.name]],
                [name],
                [base_type_to_aten_type[typ.name]],
                [
                    f"{base_type_to_callsite_expr[typ.name]}({name})"
                    if base_type_to_callsite_expr[typ.name]
                    else name
                ],
            )
        elif typ.name == BaseTy.Device:
            return (
                ["int32_t", "int32_t"],
                [name, name + "_index_"],
                ["c10::Device"],
                [
                    f"c10::Device(static_cast<c10::DeviceType>({name}), static_cast<c10::DeviceIndex>({name}_index_))"
                ],
            )
        else:
            # TODO: BaseTy.Dimname, etc.
            raise NotImplementedError(f"TODO: add support for arg type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        c_types, names, aten_types, callsite_exprs = convert_arg_type_and_name(
            typ.elem, name
        )
        j = 0  # index for names
        new_aten_types = []
        new_callsite_exprs = []
        for aten_type in aten_types:
            # Use pointer to denote optional type
            c_types[j] = c_types[j] + "*"
            if aten_type.startswith("c10::ArrayRef<"):
                # ArrayRef is passed as pointer + size, but no need to add "*" to the size argument
                new_aten_types.append(f"::std::optional<{aten_type}>")
                base_type = aten_type[len("c10::ArrayRef<") : -1]
                new_callsite_exprs.append(
                    f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j+1]})"
                )
                j += 2
            elif aten_type == "c10::Device":
                # Device is passed as device_type + device_index
                new_aten_types.append("::std::optional<c10::Device>")
                new_callsite_exprs.append(
                    f"pointer_to_optional_device({names[j]}, {names[j+1]})"
                )
                j += 2
            else:
                new_aten_types.append(f"::std::optional<{aten_type}>")
                new_callsite_exprs.append(
                    f"pointer_to_optional<{aten_type}>({names[j]})"
                )
                j += 1

        return (
            c_types,
            names,
            new_aten_types,
            new_callsite_exprs,
        )
    elif isinstance(typ, ListType):
        # Need to explicitly pass the list as pointer + length
        c_types, names, aten_types, _ = convert_arg_type_and_name(typ.elem, name)
        assert len(c_types) == 1, "ListType with unsupported element type " + repr(typ)

        # The list content should never be modified
        c_types[0] = f"const {c_types[0]}*"
        c_types.append("int64_t")
        name = names[0]
        names.append(name + "_len_")

        atype = aten_types[0]
        callsite_exprs = []
        if atype == "bool":
            # no converter from std::vector<bool> to c10::ArrayRef<bool>
            # construct std::array<bool, N> instead
            assert typ.size is not None
            callsite_exprs.append(f"pointer_to_list<{typ.size}>({name})")
        elif atype == "::std::optional<at::Tensor>":
            # convert from std::vector<::std::optional<at::Tensor>> to c10::List<::std::optional<at::Tensor>>
            callsite_exprs.append(
                f"c10::List<{atype}>(c10::ArrayRef<{atype}>(pointer_to_list<{atype}>({name}, {name}_len_)))"
            )
        else:
            callsite_exprs.append(f"pointer_to_list<{atype}>({name}, {name}_len_)")

        aten_types = [f"c10::ArrayRef<{t}>" for t in aten_types]
        return (
            c_types,
            names,
            aten_types,
            callsite_exprs,
        )


def zip_type_and_name(types: list[str], names: list[str]) -> list[str]:
    return [typ + " " + name for typ, name in zip(types, names)]


# Generate argument declarations and callsite expressions
def gen_arguments(flat_arguments: Sequence[Argument]) -> tuple[list[str], list[str]]:
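    """Generate C argument declarations and callsite expressions for a flat argument list.

    Returns a pair: the "type name" declaration strings for the shim
    signature, and the expressions used to pass the corresponding values to
    the backend call.
    """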
    types = []
    new_names = []
    callsite_exprs = []
    for arg in flat_arguments:
        new_types, names, _, new_callsite_exprs = convert_arg_type_and_name(
            arg.type, arg.name
        )
        types.extend(new_types)
        new_names.extend(names)
        callsite_exprs.extend(new_callsite_exprs)
    return zip_type_and_name(types, new_names), callsite_exprs


# Return values are passed out as pointer arguments because all the C shim functions
# are expected to return AOTITorchError.
# Generate returns as declarations and callsite expressions
def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]:
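    """Generate return-value out-parameters and the assignments that fill them.

    Each schema return becomes a pointer argument (ret0, ret1, ...) because
    the shim itself returns AOTITorchError. For the listed attention and
    convolution_backward ops the return pointers may be null, so those
    assignments are guarded by a null check.
    """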
    types = []
    names = []
    for idx, ret in enumerate(schema.returns):
        names.append(f"ret{idx}")
        if isinstance(ret.type, BaseType) and ret.type.name in base_type_to_c_type:
            types.append(base_type_to_c_type[ret.type.name] + "*")
        else:
            raise NotImplementedError(
                f"TODO: add support for return type {repr(ret.type)}"
            )

    def convert_return(typ: BaseType, val: str) -> str:
        if typ.name == BaseTy.Tensor:
            return f"new_tensor_handle(std::move({val}));"
        elif typ.name == BaseTy.SymInt:
            return f"{val}.expect_int()"
        elif typ.name == BaseTy.Scalar:
            return f"{val}.toDouble()"
        else:
            return val

    ret_pointer_can_be_null = False
    unambiguous_name = schema.name.unambiguous_name()
    for name in [
        "_scaled_dot_product_flash_attention",
        "_scaled_dot_product_efficient_attention",
        "_scaled_dot_product_cudnn_attention",
        "convolution_backward",
    ]:
        if name in unambiguous_name:
            ret_pointer_can_be_null = True
            break

    callsite_exprs: list[str] = []
    for idx, ret in enumerate(schema.returns):
        tmp = "tmp_result" if len(names) == 1 else f"std::get<{idx}>(tmp_result)"
        assert isinstance(ret.type, BaseType)
        rval = convert_return(ret.type, tmp)
        if ret_pointer_can_be_null:
            callsite_exprs.append(f"if ({names[idx]}) {{ *{names[idx]} = {rval}; }}")
        else:
            callsite_exprs.append(f"*{names[idx]} = {rval};")

    return zip_type_and_name(types, names), callsite_exprs


# gen.py generates the header first and then the source, so cache the results here to avoid duplicate work
declaration_definition_cache: dict[tuple[str, str, str], tuple[str, str]] = {}


def gen_declaration_and_definition(
    schema: FunctionSchema, device: str, backend_call: str
) -> tuple[str, str]:
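    """Generate the C shim declaration and definition strings for one operator.

    The declaration has the form
    AOTITorchError aoti_torch_{device}_{func_name}(...); the definition wraps
    the backend call in AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE and writes
    the results into the return out-parameters. Results are cached by
    (func_name, device, backend_call) because gen.py generates the header and
    the source in separate passes.
    """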
    func_name = schema.name.unambiguous_name()

    global declaration_definition_cache
    if (func_name, device, backend_call) in declaration_definition_cache:
        return declaration_definition_cache[(func_name, device, backend_call)]

    if schema.is_out_fn():
        # out_variant has out arguments in the front, and it's ok to ignore return values
        # because C shim functions only return AOTITorchError
        args, callsite_exprs = gen_arguments(
            [*schema.arguments.out, *schema.arguments.flat_non_out]
        )
        ret_assignments: list[str] = []
    else:
        args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
        # ignore return values for inplace ops
        ret_declarations, ret_assignments = (
            ([], []) if schema.name.name.inplace else gen_returns(schema)
        )
        args.extend(ret_declarations)

    declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"

    tmp_result = "auto tmp_result = " if ret_assignments else ""
    ret_assignments_str = "\n" + "\n".join(ret_assignments) if ret_assignments else ""
    definition = f"""
{declaration} {{
    AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{
        {tmp_result}{backend_call}(
{textwrap.indent(', '.join(callsite_exprs), "            ")}
        );{textwrap.indent(ret_assignments_str, "        ")}
    }});
}}
"""
    declaration_definition_cache[(func_name, device, backend_call)] = (
        declaration,
        definition,
    )
    return declaration, definition


def gen_static_dispatch_backend_call_signature(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
) -> CppSignature:
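    """Select the CppSignature to use for the static backend call.

    Note that sig is recomputed from f.func here; the SymInt variant of the
    C++ signature is chosen when the dispatcher signature is symint and the
    schema actually contains SymInt arguments.
    """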
    sig = DispatcherSignature.from_schema(f.func)
    cpp_sigs = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    if sig.symint and f.func.has_symint():
        cpp_sig = cpp_sigs.symint_signature
    else:
        cpp_sig = cpp_sigs.signature
    assert cpp_sig is not None
    return cpp_sig


def gen_static_dispatch_backend_call(
    f: NativeFunction,
    backend_index: BackendIndex,
) -> str:
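    """Return the statically dispatched C++ callee, i.e. at::<dispatch_key>::<name>."""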
    sig = DispatcherSignature.from_schema(f.func)
    cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
    return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"


def get_backend_index_for_aoti(
    func: NativeFunction,
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
) -> BackendIndex | None:
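    """Return the BackendIndex whose kernel the C shim should call, or None.

    Prefers a kernel registered for dispatch_key (directly, or via the op's
    structured delegate group), then falls back to CompositeExplicitAutograd,
    CompositeExplicitAutogradNonFunctional, and finally
    CompositeImplicitAutograd.
    """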
    backend_index = None
    if backend_indices[dispatch_key].has_kernel(func) or (
        func.structured_delegate is not None
        and func.structured_delegate in func_group_mapping
        and backend_indices[dispatch_key].has_kernel(
            func_group_mapping[func.structured_delegate]
        )
    ):
        backend_index = backend_indices[dispatch_key]
    elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(func):
        # We need to create C shim wrappers for CompositeExplicitAutograd kernels
        backend_index = backend_indices[DispatchKey.CompositeExplicitAutograd]
    elif backend_indices[DispatchKey.CompositeExplicitAutogradNonFunctional].has_kernel(
        func
    ):
        # We need to create C shim wrappers for CompositeExplicitAutogradNonFunctional kernels
        backend_index = backend_indices[
            DispatchKey.CompositeExplicitAutogradNonFunctional
        ]
    elif backend_indices[DispatchKey.CompositeImplicitAutograd].has_kernel(func):
        backend_index = backend_indices[DispatchKey.CompositeImplicitAutograd]

    return backend_index


def get_header_for_aoti(
    func: NativeFunction,
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
) -> str | None:
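    """Return the per-operator dispatch header to #include, or None if no kernel was found."""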
    backend_index = get_backend_index_for_aoti(
        func, func_group_mapping, dispatch_key, backend_indices
    )
    return (
        None
        if backend_index is None
        else f"#include <ATen/ops/{func.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
    )


def get_fallback_op_name(func: NativeFunction) -> str:
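    """Return the op name as "namespace.name.overload", using "default" when there is no overload name."""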
    return (
        f"{func.namespace}.{func.func.name.name}.{func.func.name.overload_name}"
        if func.func.name.overload_name
        else f"{func.namespace}.{func.func.name.name}.default"
    )


def gen_c_shim(
    func: NativeFunction,
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
    header: bool,
) -> str | None:
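    """Generate the C shim declaration (header=True) or definition (header=False) for one op.

    Returns None when no usable kernel exists for the dispatch key, or when an
    argument/return type is not yet supported (NotImplementedError is swallowed).
    """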
    backend_index = get_backend_index_for_aoti(
        func, func_group_mapping, dispatch_key, backend_indices
    )
    if backend_index is None:
        return None

    schema = func.func
    device = dispatch_key.lower()
    backend_call = gen_static_dispatch_backend_call(
        func,
        backend_index,
    )

    try:
        if header:
            declaration, _ = gen_declaration_and_definition(
                schema, device, backend_call
            )
            return f"AOTI_TORCH_EXPORT {declaration};"
        else:
            _, definition = gen_declaration_and_definition(schema, device, backend_call)
            return definition

    except NotImplementedError:
        return None


@dataclass(frozen=True)
class ShimGenerator:
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup]
    dispatch_key: DispatchKey
    backend_indices: dict[DispatchKey, BackendIndex]
    header: bool  # True to generate .h and False to generate .cpp

    @method_with_native_function
    def __call__(
        self,
        func: NativeFunction,
    ) -> str | None:
        result = gen_c_shim(
            func,
            self.func_group_mapping,
            self.dispatch_key,
            self.backend_indices,
            self.header,
        )
        return result


def gen_aoti_c_shim(
    native_functions: Sequence[NativeFunction],
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
    header: bool,
    includes: str = "",
) -> str:
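    """Assemble the complete C shim header or source file for a dispatch key.

    Runs ShimGenerator over native_functions (ops without a usable kernel or
    with unsupported types are skipped) and wraps the result with the
    autogenerated-file warning plus either the extern "C" header boilerplate
    or the source-file includes.
    """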
    body = "\n".join(
        list(
            mapMaybe(
                ShimGenerator(
                    func_group_mapping, dispatch_key, backend_indices, header
                ),
                native_functions,
            )
        )
    )
    device = dispatch_key.lower()

    warning = """
// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details"""

    if header:
        return f"""
{warning}

#pragma once

#include <torch/csrc/inductor/aoti_torch/c/shim.h>

#ifdef __cplusplus
extern "C" {{
#endif

{body}

#ifdef __cplusplus
}} // extern "C"
#endif
"""

    else:
        return f"""
{warning}

#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{device}.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/{str(dispatch_key)}Functions.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/CompositeImplicitAutogradFunctions.h>
#else
{includes}
#endif

using namespace torch::aot_inductor;

{body}"""