xref: /aosp_15_r20/external/pytorch/torchgen/api/types/signatures.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3from dataclasses import dataclass
4from typing import Iterator, Sequence, TYPE_CHECKING
5
6from torchgen.api.types.types_base import Binding, CType, Expr
7
8
9if TYPE_CHECKING:
10    from torchgen.model import (
11        BackendIndex,
12        FunctionSchema,
13        NativeFunction,
14        NativeFunctionsGroup,
15        NativeFunctionsViewGroup,
16    )
17
18
@dataclass(frozen=True)
class CppSignature:
    """
    A CppSignature represents a single overload in the C++ API.  For
    any given function schema, there may be multiple CppSignatures
    corresponding to it, based on how we desugar to C++.  See also
    CppSignatureGroup.
    """

    # The schema this signature is derived from
    func: FunctionSchema

    # Is this a C++ signature for a method, i.e. Tensor::my_op(...)?
    method: bool

    # Is this a faithful C++ signature (i.e. following the JIT schema) or a convenience API
    # (i.e. with a potential TensorOptions argument and out arguments in the front)
    faithful: bool

    # Is this a symint C++ signature.  For BC reasons, functions that take
    # SymInts still present as int64_t in C++, and the SymInt variant is
    # offered at a different overload name
    #
    # NB: If a function RETURNS a SymInt, this is ALWAYS false
    symint: bool

    # The set of C++ arguments which should not have defaults applied to them
    cpp_no_default_args: set[str]

    # Is this a fallback C++ binding?  Fallback bindings are enabled by
    # manual_cpp_binding: True and are alternate, non-public API that
    # lets manual C++ binding implementors access the binding that would
    # have been automatically generated
    fallback_binding: bool = False

    def arguments(self) -> Sequence[Binding]:
        """Return the unpacked argument structure of this signature.

        Information about which arguments are semantically related to each
        other is discarded.
        """
        return cpp.arguments(
            self.func.arguments,
            faithful=self.faithful,
            symint=self.symint,
            method=self.method,
            cpp_no_default_args=self.cpp_no_default_args,
        )

    def name(self, *, suppress_symint_suffix: bool = False) -> str:
        """Return the C++ name of this overload.

        Args:
            suppress_symint_suffix: when True, ``cpp.name`` is asked for the
                non-symint overload name even if this is a symint signature.

        Fallback bindings get a ``__dispatch_`` prefix.
        """
        n = cpp.name(
            self.func,
            faithful_name_for_out_overloads=self.faithful,
            symint_overload=False if suppress_symint_suffix else self.symint,
        )
        if self.fallback_binding:
            n = f"__dispatch_{n}"
        return n

    def returns_type(self) -> CType:
        """Return the C++ return type of this signature.

        Provides the same accessor that DispatcherSignature and
        NativeSignature expose, and is the single place the return type is
        computed for decl/defn/ptr_type/type.
        """
        return cpp.returns_type(self.func.returns, symint=self.symint)

    def _render(
        self,
        cpp_args: list[str],
        *,
        name: str | None,
        prefix: str,
        is_redispatching_fn: bool,
        suppress_symint_suffix: bool,
    ) -> str:
        """Assemble ``<return type> <name>(<args>)`` from pre-rendered arguments.

        Shared by ``decl`` and ``defn`` so the two renderings cannot drift apart.
        """
        if is_redispatching_fn:
            # Redispatching wrappers take the DispatchKeySet as a leading argument.
            cpp_args = ["c10::DispatchKeySet dispatchKeySet", *cpp_args]
        if name is None:
            # An explicit ``name`` is used verbatim; ``prefix`` only applies to
            # the computed name.
            name = prefix + self.name(suppress_symint_suffix=suppress_symint_suffix)
        return f"{self.returns_type().cpp_type()} {name}({', '.join(cpp_args)})"

    def decl(
        self,
        *,
        name: str | None = None,
        prefix: str = "",
        is_redispatching_fn: bool = False,
        suppress_symint_suffix: bool = False,
    ) -> str:
        """Render the C++ declaration for this signature.

        Each argument is rendered with ``Binding.decl()``.

        Args:
            name: function name to use verbatim; when None, ``prefix`` plus
                ``self.name(...)`` is used.
            prefix: prepended to the computed name (ignored if ``name`` is given).
            is_redispatching_fn: prepend a ``c10::DispatchKeySet dispatchKeySet``
                parameter.
            suppress_symint_suffix: forwarded to ``self.name``.
        """
        return self._render(
            [a.decl() for a in self.arguments()],
            name=name,
            prefix=prefix,
            is_redispatching_fn=is_redispatching_fn,
            suppress_symint_suffix=suppress_symint_suffix,
        )

    def defn(
        self,
        *,
        name: str | None = None,
        prefix: str = "",
        is_redispatching_fn: bool = False,
        suppress_symint_suffix: bool = False,
    ) -> str:
        """Render the C++ definition header for this signature (no body/braces).

        Each argument is rendered with ``Binding.defn()``.  Parameters behave
        as in ``decl``.  ``suppress_symint_suffix`` defaults to False, which
        preserves the previous behavior.  It lets callers produce a definition
        matching a declaration rendered with ``decl(suppress_symint_suffix=True)``.
        """
        return self._render(
            [a.defn() for a in self.arguments()],
            name=name,
            prefix=prefix,
            is_redispatching_fn=is_redispatching_fn,
            suppress_symint_suffix=suppress_symint_suffix,
        )

    def ptr_type(self) -> str:
        """Return the C++ function-pointer type, e.g. ``int (*)(bool)``."""
        args_types_str = ", ".join(a.type for a in self.arguments())
        return f"{self.returns_type().cpp_type()} (*)({args_types_str})"

    def type(self) -> str:
        """Return the C++ function type, e.g. ``int (bool)``."""
        args_types_str = ", ".join(a.type for a in self.arguments())
        return f"{self.returns_type().cpp_type()} ({args_types_str})"
124
125
@dataclass(frozen=True)
class CppSignatureGroup:
    """All CppSignatures generated for a single FunctionSchema.

    ``signature`` is the regular, user-visible overload.  ``faithful_signature``
    is present only when the schema has TensorOptions or out arguments, i.e.
    when the convenience and faithful overloads differ.  The ``symint_*``
    variants are present only when the schema has SymInts.
    """

    func: FunctionSchema
    signature: CppSignature
    faithful_signature: CppSignature | None
    symint_signature: CppSignature | None
    symint_faithful_signature: CppSignature | None

    def most_faithful_signature(self) -> CppSignature:
        """Prefer the faithful signature, falling back to the regular one."""
        return self.faithful_signature or self.signature

    def signatures(self, *, symint: bool = True) -> Iterator[CppSignature]:
        """Yield the regular signature, then every optional one that exists.

        The symint variants are skipped entirely when ``symint`` is False.
        """
        yield self.signature
        optional = [self.faithful_signature]
        if symint:
            optional.extend((self.symint_signature, self.symint_faithful_signature))
        yield from (candidate for candidate in optional if candidate)

    @staticmethod
    def from_native_function(
        f: NativeFunction, *, method: bool, fallback_binding: bool = False
    ) -> CppSignatureGroup:
        """Build every CppSignature variant for ``f``."""
        func = f.func
        # A separate faithful overload only exists when desugaring changes the
        # C++ argument list (TensorOptions grouping or out-argument reordering).
        needs_faithful = (
            func.arguments.tensor_options is not None or len(func.arguments.out) > 0
        )

        def build_pair(
            use_symint: bool,
        ) -> tuple[CppSignature, CppSignature | None]:
            def build(is_faithful: bool) -> CppSignature:
                return CppSignature(
                    func=func,
                    method=method,
                    faithful=is_faithful,
                    symint=use_symint,
                    cpp_no_default_args=f.cpp_no_default_args,
                    fallback_binding=fallback_binding,
                )

            faithful_variant = build(True) if needs_faithful else None
            return build(False), faithful_variant

        regular, regular_faithful = build_pair(False)
        symint_regular, symint_faithful = (
            build_pair(True) if func.has_symint() else (None, None)
        )

        return CppSignatureGroup(
            func=func,
            signature=regular,
            faithful_signature=regular_faithful,
            symint_signature=symint_regular,
            symint_faithful_signature=symint_faithful,
        )
189        )
190
191
@dataclass(frozen=True)
class DispatcherSignature:
    """Signature of an operator as rendered by the ``dispatcher`` API module."""

    # The schema this signature is derived from
    func: FunctionSchema

    # Arbitrary prefix prepended to the signature name.  Codegen that emits
    # wrappers around kernels uses it to avoid naming collisions.
    prefix: str = ""

    # Forwarded as ``symint=`` to dispatcher.arguments / dispatcher.returns_type.
    symint: bool = True

    def arguments(self) -> list[Binding]:
        return dispatcher.arguments(self.func, symint=self.symint)

    def name(self) -> str:
        return f"{self.prefix}{dispatcher.name(self.func)}"

    def _render(self, params: list[str], name: str | None) -> str:
        """Join return type, name (default: ``self.name()``) and parameters."""
        resolved = self.name() if name is None else name
        return f"{self.returns_type().cpp_type()} {resolved}({', '.join(params)})"

    def decl(self, name: str | None = None) -> str:
        """Render a declaration; arguments use ``Binding.decl()``."""
        return self._render([binding.decl() for binding in self.arguments()], name)

    def defn(
        self, name: str | None = None, *, is_redispatching_fn: bool = False
    ) -> str:
        """Render a definition header; arguments use ``Binding.defn()``.

        Redispatching functions get a leading ``c10::DispatchKeySet`` parameter.
        """
        params = [binding.defn() for binding in self.arguments()]
        if is_redispatching_fn:
            params.insert(0, "c10::DispatchKeySet dispatchKeySet")
        return self._render(params, name)

    def exprs(self) -> list[Expr]:
        """Expose each argument binding as an expression referring to itself."""
        return [Expr(binding.name, binding.nctype) for binding in self.arguments()]

    def returns_type(self) -> CType:
        return dispatcher.returns_type(self.func.returns, symint=self.symint)

    def _arg_types(self) -> str:
        """Comma-separated C++ argument types, without names or defaults."""
        return ", ".join(binding.type for binding in self.arguments())

    def ptr_type(self) -> str:
        """Return the C++ function-pointer type, e.g. ``int (*)(bool)``."""
        arg_types = self._arg_types()
        return f"{self.returns_type().cpp_type()} (*)({arg_types})"

    def type(self) -> str:
        """Return the C++ function type, e.g. ``int (bool)``."""
        arg_types = self._arg_types()
        return f"{self.returns_type().cpp_type()} ({arg_types})"

    @staticmethod
    def from_schema(
        func: FunctionSchema, *, prefix: str = "", symint: bool = True
    ) -> DispatcherSignature:
        return DispatcherSignature(func=func, prefix=prefix, symint=symint)
247
248
@dataclass(frozen=True)
class NativeSignature:
    """Kernel signature following the ``native`` API (used by in-tree backends
    in ``kernel_signature``)."""

    # The schema this signature is derived from
    func: FunctionSchema

    # Forwarded as ``symint=`` to native.arguments / native.returns_type.
    symint: bool

    # Prepended to the generated name.
    prefix: str = ""

    def name(self) -> str:
        return f"{self.prefix}{native.name(self.func)}"

    def decl(self, name: str | None = None) -> str:
        """Render a declaration; arguments use ``Binding.decl()``."""
        params = ", ".join(binding.decl() for binding in self.arguments())
        label = self.name() if name is None else name
        return f"{self.returns_type().cpp_type()} {label}({params})"

    def defn(self, name: str | None = None) -> str:
        """Render a definition header; arguments use ``Binding.defn()``."""
        params = ", ".join(binding.defn() for binding in self.arguments())
        label = self.name() if name is None else name
        return f"{self.returns_type().cpp_type()} {label}({params})"

    def ptr_type(self) -> str:
        # don't include defaults in type signature!
        params = ", ".join(binding.defn() for binding in self.arguments())
        return f"{self.returns_type().cpp_type()} (*)({params})"

    def arguments(self) -> list[Binding]:
        return native.arguments(self.func, symint=self.symint)

    def returns_type(self) -> CType:
        return native.returns_type(self.func.returns, symint=self.symint)

    def dispatcher_exprs(self) -> list[Expr]:
        """Translate this signature's arguments into dispatcher-API expressions."""
        target = dispatcher.arguments(self.func)
        return translate.translate(self.arguments(), target, method=False)
288
289
@dataclass(frozen=True)
class ViewInverseSignature:
    """Static declaration of the reverse function for the view in ``g``."""

    g: NativeFunctionsViewGroup

    def name(self) -> str:
        return functionalization.reverse_name(self.g.view, include_namespace=False)

    def decl(self) -> str:
        """Render ``static <ret> <name>(<args>);`` using the reverse inner arguments."""
        view_func = self.g.view.func
        ret = functionalization.returns_type(view_func)
        params = ", ".join(
            binding.decl()
            for binding in functionalization.inner_arguments(view_func, is_reverse=True)
        )
        return f"static {ret.cpp_type()} {self.name()}({params});"
306
307
@dataclass(frozen=True)
class FunctionalizationLambda:
    """Renders a C++ lambda for the forward or reverse direction of the view in ``g``.

    ``captures`` computes the lambda's capture list from the enclosing
    kernel's arguments.  ``inner_call`` renders the call statement in the
    lambda body.
    """

    g: NativeFunctionsViewGroup

    # are we generating the forward lambda (False) or the reverse lambda (True)?
    is_reverse: bool

    def captures(self) -> list[Expr]:
        """Return the expressions captured by the lambda."""
        view_func = self.g.view.func
        # The lambda lives inside a kernel following the dispatcher API, so its
        # outer context is the dispatcher arguments.  The "reapply views" TLS
        # (read when the functionalization kernel executed) and the
        # inverse-return-mode binding are plumbed in as well.
        enclosing = [
            *dispatcher.arguments(view_func),
            functionalization.reapply_views_binding,
            functionalization.inverse_return_mode_binding,
        ]
        wanted = functionalization.capture_arguments(
            view_func, is_reverse=self.is_reverse
        )
        # Expensive conversions are allowed so that some reference types
        # (IntArrayRef) can be captured as value types (vector<int64_t>).
        return translate.translate(
            enclosing, wanted, method=False, allow_expensive_conversions=True
        )

    def decl(self) -> str:
        """Render ``[captures](params) -> ret`` (the lambda header, no body)."""
        ret = functionalization.returns_type(self.g.view.func).cpp_type()
        captured = ", ".join(f"{e.type.name} = {e.expr}" for e in self.captures())
        params = ", ".join(
            binding.decl()
            for binding in functionalization.outer_arguments(is_reverse=self.is_reverse)
        )
        return f"[{captured}]({params}) -> {ret}"

    def inner_call(self, *, reapply_views: bool | None = None) -> str:
        """Render the call statement placed inside the lambda body.

        For the forward direction, when ``inner_call_index`` reports an index,
        the call result is subscripted with it.
        """
        callee = functionalization.name(
            self.g,
            is_reverse=self.is_reverse,
            include_namespace=True,
            reapply_views=reapply_views,
        )

        # Everything visible inside the lambda: its parameters plus its captures.
        visible = functionalization.outer_arguments(
            is_reverse=self.is_reverse
        ) + functionalization.capture_arguments(
            self.g.view.func, is_reverse=self.is_reverse
        )

        view_copy = self.g.view_copy
        assert view_copy is not None
        targets = functionalization.inner_arguments(
            view_copy.func, is_reverse=self.is_reverse
        )
        index = functionalization.inner_call_index(view_copy.func)
        args = ", ".join(
            e.expr for e in translate.translate(visible, targets, method=False)
        )

        statement = f"{callee}({args})"
        if index is not None and not self.is_reverse:
            statement += f"[{index.name}]"
        return statement + ";"

    @staticmethod
    def from_func(
        g: NativeFunctionsViewGroup, *, is_reverse: bool
    ) -> FunctionalizationLambda:
        return FunctionalizationLambda(g=g, is_reverse=is_reverse)
376
377
@dataclass(frozen=True)
class StructuredImplSignature:
    """Renders the ``TORCH_IMPL_FUNC(<name>)(<args>)`` header of a structured kernel."""

    g: NativeFunctionsGroup

    # Default name placed inside TORCH_IMPL_FUNC(...).
    name: str

    def defn(self, name: str | None = None) -> str:
        """Render the structured-kernel implementation header.

        Args:
            name: overrides ``self.name`` when given.  Previously this argument
                was accepted but silently ignored.  It is now honored, matching
                every other signature's ``defn``.  Omitting it keeps the old output.
        """
        impl_name = self.name if name is None else name
        args_str = ", ".join(a.defn() for a in self.arguments())
        return f"TORCH_IMPL_FUNC({impl_name})({args_str})"

    def arguments(self) -> list[Binding]:
        return structured.impl_arguments(self.g)
389
390
391# Helper functions
392
393
def kernel_signature(
    f: NativeFunction, backend_index: BackendIndex, *, prefix: str = ""
) -> NativeSignature | DispatcherSignature:
    """Return the signature a backend kernel for ``f`` must implement.

    The SymInt variant is chosen iff the backend's kernel metadata for ``f``
    reports ``supports_symint()``.  That requires ``f``'s schema to contain a
    SymInt, which is asserted.
    """
    # Note [External Backends Follow Dispatcher API]
    # Kernel signatures for in-tree backends follow the "native" API,
    # while kernels for out-of-tree backends follow the dispatcher API.
    # See the comments in `native.py` for details, but historically there have been
    # some small differences in schema convention between them and the Dispatcher API.
    # Any differences that require translating between the two will results in a runtime cost,
    # so we'd like to keep the differences as small as possible.
    # With external backends, we'd like to enforce that they write their kernels with schemas
    # that match the Dispatcher API directly, if they can.
    metadata = backend_index.get_kernel(f)
    use_symint = False if metadata is None else metadata.supports_symint()
    if use_symint:
        assert (
            f.func.has_symint()
        ), f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema"
    if not backend_index.external:
        return NativeSignature(f.func, prefix=prefix, symint=use_symint)
    return DispatcherSignature.from_schema(f.func, prefix=prefix, symint=use_symint)
416
417
418# Functions only, no types
419from torchgen.api import (
420    cpp,
421    dispatcher,
422    functionalization,
423    native,
424    structured,
425    translate,
426)
427