xref: /aosp_15_r20/external/pytorch/torchgen/api/cpp.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3from typing import Sequence
4
5from torchgen import local
6from torchgen.api.types import (
7    ArgName,
8    ArrayCType,
9    ArrayRefCType,
10    BaseCType,
11    BaseTypeToCppMapping,
12    Binding,
13    boolT,
14    ConstRefCType,
15    CType,
16    dimnameListT,
17    intArrayRefT,
18    iTensorListRefT,
19    ListCType,
20    longT,
21    MutRefCType,
22    NamedCType,
23    OptionalCType,
24    optionalIntArrayRefT,
25    optionalSymIntArrayRefT,
26    scalarT,
27    SpecialArgName,
28    symIntArrayRefT,
29    SymIntT,
30    tensorListT,
31    tensorOptionsT,
32    tensorT,
33    TupleCType,
34    VectorCType,
35    voidT,
36)
37from torchgen.model import (
38    Argument,
39    Arguments,
40    BaseTy,
41    BaseType,
42    FunctionSchema,
43    ListType,
44    NativeFunction,
45    OptionalType,
46    Return,
47    SelfArgument,
48    TensorOptionsArguments,
49    Type,
50)
51from torchgen.utils import assert_never
52
53
54# This file describes the translation of JIT schema to the public C++
55# API, which is what people use when they call functions like at::add.
56#
57# Prominent characteristics of the C++ API:
58#
59#   - dtype, layout, device and pin_memory are collected into
60#     a single C++ type TensorOptions  (the native functions API
61#     also has this, but tensor options is really most relevant
62#     for the C++ API; it makes calling kwarg factory functions
63#     pleasant)
64#
65#   - defaulting lives here (in fact, the dispatcher is completely
66#     oblivious of defaults!)
67#
68# BTW: policy on name collisions: we try not to have types with
69# collisions, but functions are fair game to collide
70
71
def name(
    func: FunctionSchema,
    *,
    faithful_name_for_out_overloads: bool = False,
    symint_overload: bool = False,
) -> str:
    """Build the C++ API function name for the schema ``func``.

    The result is the stringified base operator name followed by up to two
    suffixes, in this order:

    * ``_symint`` when ``symint_overload`` is set;
    * ``_outf`` or ``_out`` when ``func.is_out_fn()`` is true, chosen by
      ``faithful_name_for_out_overloads``.
    """
    base = str(func.name.name)
    symint_suffix = "_symint" if symint_overload else ""

    # Only out-variants carry an out suffix; the faithful flag picks which.
    out_suffix = ""
    if func.is_out_fn():
        out_suffix = "_outf" if faithful_name_for_out_overloads else "_out"

    return base + symint_suffix + out_suffix
88
89
90# Translation of "value types" in JIT schema to C++ API type.  Value
91# types look the same no matter if they are argument types or return
92# types.  Returns None if the type in question is not a value type.
def valuetype_type(
    t: Type,
    *,
    binds: ArgName,
    mutable: bool = True,
    remove_non_owning_ref_types: bool = False,
    symint: bool = False,
) -> NamedCType | None:
    """Translate a JIT "value type" into a named C++ type.

    Returns ``None`` when ``t`` is not handled here as a value type: Tensor,
    Scalar, any list whose element is not ``bool``, and optionals wrapping
    one of those.

    ``binds`` becomes the name of the resulting ``NamedCType``.  ``symint``
    chooses between ``SymIntT`` and ``longT`` for a ``SymInt`` base type.
    ``mutable`` is only passed along to the recursive call.  When
    ``remove_non_owning_ref_types`` is set, a ``str`` base type raises
    ``AssertionError``; note that the flag is not forwarded when recursing
    into the element of an optional.

    Raises ``AssertionError`` for a type that is none of BaseType,
    OptionalType or ListType.
    """
    if isinstance(t, BaseType):
        if t.name in (BaseTy.Tensor, BaseTy.Scalar):
            return None
        if str(t) == "SymInt":
            int_ctype = SymIntT if symint else longT
            return NamedCType(binds, BaseCType(int_ctype))
        if remove_non_owning_ref_types and t.name == BaseTy.str:
            raise AssertionError("string ref->value conversion: not implemented yet")
        # All other BaseType currently map directly to BaseCppTypes.
        return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))

    if isinstance(t, OptionalType):
        inner = valuetype_type(t.elem, binds=binds, mutable=mutable, symint=symint)
        if inner is None:
            return None
        return NamedCType(binds, OptionalCType(inner.type))

    if isinstance(t, ListType):
        # Only bool lists are value types; they need a fixed size.
        if str(t.elem) != "bool":
            return None
        assert t.size is not None
        return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))

    raise AssertionError(f"unrecognized type {repr(t)}")
129
130
131# Translation of types occurring in JIT arguments to a C++ argument type.
# If remove_non_owning_ref_types is set, we'll guarantee that the outputted CType is not a non-owning reference type.
133# For example, we'll return std::vector<int> instead of IntArrayRef.
134# See Note [translation from C++ reference to value types]
def argumenttype_type(
    t: Type,
    *,
    mutable: bool,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
    symint: bool = False,
) -> NamedCType:
    """Translate the JIT type of an argument into a named C++ argument type.

    Value types are delegated to ``valuetype_type``.  Everything else
    (tensors, scalars, optionals and lists of them) is mapped here.

    ``mutable`` selects a mutable tensor reference unless
    ``local.use_const_ref_for_mutable_tensors()`` is on.  ``symint`` selects
    the SymInt flavour of int-like types.  ``remove_non_owning_ref_types``
    switches ``int[]`` / ``SymInt[]`` to vector types; it is not forwarded
    when recursing into optional or list elements.

    Raises ``AssertionError`` for a base type that should have been handled
    as a value type, or for an unrecognized kind of type.
    """
    # If it's a value type, do the value type translation
    value_nctype = valuetype_type(
        t,
        binds=binds,
        mutable=mutable,
        symint=symint,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
    )
    if value_nctype is not None:
        return value_nctype

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        if t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        raise AssertionError(f"base type should have been value type {t}")

    if isinstance(t, OptionalType):
        opt_elem = str(t.elem)
        if opt_elem == "Tensor":
            if mutable and not local.use_const_ref_for_mutable_tensors():
                # TODO: fix this discrepancy
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            return NamedCType(
                binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
            )
        if opt_elem == "Scalar":
            return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
        if isinstance(t.elem, ListType):
            inner_elem = str(t.elem.elem)
            if inner_elem == "int":
                return NamedCType(binds, BaseCType(optionalIntArrayRefT))
            if inner_elem == "SymInt":
                opt_ref = optionalSymIntArrayRefT if symint else optionalIntArrayRefT
                return NamedCType(binds, BaseCType(opt_ref))
        # Generic optional: wrap whatever the element translates to.
        wrapped = argumenttype_type(
            t.elem, mutable=mutable, binds=binds, symint=symint
        )
        return NamedCType(binds, OptionalCType(wrapped.type))

    if isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        list_elem = str(t.elem)
        if list_elem == "int":
            if remove_non_owning_ref_types:
                return NamedCType(binds, VectorCType(BaseCType(longT)))
            return NamedCType(binds, BaseCType(intArrayRefT))
        if list_elem == "SymInt":
            if remove_non_owning_ref_types:
                owned_elem = SymIntT if symint else longT
                return NamedCType(binds, VectorCType(BaseCType(owned_elem)))
            array_ref = symIntArrayRefT if symint else intArrayRefT
            return NamedCType(binds, BaseCType(array_ref))
        if list_elem == "Tensor":
            if local.use_ilistref_for_tensor_lists():
                return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
            return NamedCType(binds, BaseCType(tensorListT))
        if list_elem == "Scalar":
            return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
        if list_elem == "Dimname":
            return NamedCType(binds, BaseCType(dimnameListT))
        if list_elem == "Tensor?":
            return NamedCType(
                binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
            )
        # Generic list: an ArrayRef of whatever the element translates to.
        wrapped = argumenttype_type(
            t.elem, mutable=mutable, binds=binds, symint=symint
        )
        return NamedCType(binds, ArrayRefCType(wrapped.type))

    raise AssertionError(f"unrecognized type {repr(t)}")
220
221
222# Translate a JIT argument into its C++ type
def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> NamedCType:
    """Return the named C++ type of the JIT argument ``a``.

    Thin wrapper over ``argumenttype_type`` that takes the JIT type from
    ``a.type`` and the mutability from ``a.is_write``.
    """
    jit_type = a.type
    is_mutable = a.is_write
    return argumenttype_type(
        jit_type,
        mutable=is_mutable,
        symint=symint,
        binds=binds,
    )
225
226
227# Translation of a (non-multi) return type from JIT to C++
228# N.B: returntype_type returns a CType, not a NamedCType.
229# This is mostly because of the mismatch between return types and return names.
230# e.g. a function with a return type of 'void' has 0 return names,
231# and a function with a return type of 'std::tuple' has >1 return name.
def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
    """Translate a single (non-multi) JIT return type into a C++ ``CType``.

    Value types go through ``valuetype_type``.  Beyond that, Tensor, Scalar,
    lists of returnable elements and ``Tensor?`` are supported.

    ``mutable`` makes a Tensor return a reference (const or mutable depending
    on ``local.use_const_ref_for_mutable_tensors()``) instead of a value.
    ``symint`` is accepted but IGNORED: the value-type translation is always
    run with ``symint=True``.

    Raises ``AssertionError`` for unsupported return types, mutable list
    returns, and fixed-size list returns.
    """
    # The binding name is a throwaway: only the CType is kept.
    # NB: symint is ALWAYS respected for return types.
    value_nctype = valuetype_type(
        t, binds="__placeholder__", mutable=mutable, symint=True
    )
    if value_nctype is not None:
        return value_nctype.type

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if not mutable:
                # Note [Tensor Copy Returns]
                # Currently, we use "Argument.is_write" to determine
                # whether or not Tensor return types should be copies or references.
                # If that ever changes, take a look at other locations of this note!
                return BaseCType(tensorT)
            if local.use_const_ref_for_mutable_tensors():
                return ConstRefCType(BaseCType(tensorT))
            return MutRefCType(BaseCType(tensorT))
        if t.name == BaseTy.Scalar:
            return BaseCType(scalarT)
    elif isinstance(t, ListType):
        assert (
            not mutable
        ), "Native functions should never return a mutable tensor list. They should return void."
        item_ctype = returntype_type(t.elem, mutable=False)
        assert t.size is None, f"fixed size list returns not supported: {t}"
        return VectorCType(item_ctype)
    elif isinstance(t, OptionalType):
        # The element is translated first, so any failure inside it surfaces
        # before the check that only ``Tensor?`` is accepted.
        item_ctype = returntype_type(t.elem, mutable=mutable)
        if str(t.elem) == "Tensor":
            return OptionalCType(item_ctype)

    raise AssertionError(f"unrecognized return type {t}")
268
269
270# Translation of a single return to its C++ type
def return_type(r: Return, *, symint: bool = False) -> CType:
    """Return the C++ type of the single JIT return ``r``.

    Thin wrapper over ``returntype_type`` that takes the JIT type from
    ``r.type`` and the mutability from ``r.is_write``.
    """
    jit_type = r.type
    is_mutable = r.is_write
    return returntype_type(jit_type, mutable=is_mutable, symint=symint)
273
274
275# Translation of a full (possibly multi) return from JIT to its C++ type
def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
    """Return the C++ type for a full (possibly multi) JIT return list.

    No returns map to ``void``, a single return maps to its own type, and
    several returns map to a tuple of the individual return types.
    """
    count = len(rs)
    if count == 0:
        return BaseCType(voidT)
    if count == 1:
        return return_type(rs[0], symint=symint)

    member_types = []
    for ret in rs:
        member_types.append(return_type(ret, symint=symint))
    return TupleCType(member_types)
283
284
def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
    """Compute the C++ names of the returns of ``f``, one per return.

    For each return, the first matching rule wins:

    1. in-place function: the (single) return is named ``self``;
    2. out function: the name of the out argument at the same position;
    3. explicitly named return: that name, or ``<name>_return`` when it
       collides with the name of any schema-order argument;
    4. otherwise ``fallback_name`` for a single return, or
       ``<fallback_name><i>`` (zero-indexed) for multiple returns.

    Raises ``AssertionError`` if an in-place function has more than one
    return.
    """
    schema = f.func
    names: list[str] = []
    for i, r in enumerate(schema.returns):
        # If we have an inplace function, the return argument is
        # implicitly named self.
        # TODO: Consider incorporating this into the data model
        if schema.name.name.inplace:
            assert i == 0, "illegal inplace function with multiple returns"
            ret_name = "self"
        # If this is an out function, the name is the name of the
        # corresponding out argument (r.name will get recorded
        # in field_name later.)
        elif schema.is_out_fn():
            ret_name = schema.arguments.out[i].name
        # If the return argument is explicitly named...
        elif r.name:
            # Out functions were already handled above, so only the
            # collision with an argument name needs checking here.
            name_conflict = any(
                r.name == a.name for a in schema.schema_order_arguments()
            )
            ret_name = f"{r.name}_return" if name_conflict else r.name
        # If there is no explicit name, use the fallback name, unless it's a
        # multi-return, in which case it's <fallback>0, <fallback>1, etc
        # (zero-indexed)
        else:
            ret_name = (
                fallback_name if len(schema.returns) == 1 else f"{fallback_name}{i}"
            )
        names.append(ret_name)
    return names
315
316
# Fixed translations from the textual spelling of a JIT schema default to the
# corresponding C++ expression.  default_expr consults this table as its final
# step; a default with no entry here is emitted unchanged.
JIT_TO_CPP_DEFAULT: dict[str, str] = {
    "False": "false",
    "True": "true",
    "None": "::std::nullopt",  # UGH this one is type directed
    "Mean": "at::Reduction::Mean",
    "[]": "{}",
    "contiguous_format": "c10::MemoryFormat::Contiguous",
    "long": "at::kLong",
}
326
327
328# Convert a JIT default into C++ expression representing the default
def default_expr(d: str, t: Type, *, symint: bool) -> str:
    """Convert the JIT default ``d`` of an argument of type ``t`` to C++.

    Rules, tried in order:

    * ``None`` for ``Tensor?`` becomes ``{}``;
    * a single-quoted default of a ``str`` type becomes a double-quoted C++
      literal (``\\'`` is unescaped, ``"`` is escaped, other backslash pairs
      are copied verbatim);
    * for optionals, ``None`` becomes ``::std::nullopt`` and anything else is
      converted as a default of the element type;
    * for lists, ``[...]`` becomes ``{...}``; with ``symint`` set, an
      all-digit default of a ``SymInt`` list becomes ``c10::SymInt(d)``; any
      other non-bracketed default raises ``ValueError`` unless the list has
      a fixed size;
    * everything else is looked up in ``JIT_TO_CPP_DEFAULT`` and otherwise
      returned unchanged.
    """
    if d == "None" and str(t) == "Tensor?":
        return "{}"

    if isinstance(t, BaseType) and t.name is BaseTy.str:
        # Schema allows single quotes but C++ needs double
        if len(d) >= 2 and d.startswith("'") and d.endswith("'"):
            pieces: list[str] = []
            closing = len(d) - 1  # index of the closing quote
            pos = 1
            while pos < closing:
                ch = d[pos]
                if ch == "\\":
                    # Escape sequence: drop the backslash in front of a
                    # single quote, keep every other pair as written.
                    pieces.append("'" if d[pos + 1] == "'" else d[pos : pos + 2])
                    pos += 2
                else:
                    pieces.append('\\"' if ch == '"' else ch)
                    pos += 1
            return '"' + "".join(pieces) + '"'

    if isinstance(t, OptionalType):
        if d == "None":
            return "::std::nullopt"
        return default_expr(d, t.elem, symint=symint)

    if isinstance(t, ListType):
        if d.startswith("[") and d.endswith("]"):
            return f"{{{d[1:-1]}}}"
        if symint and d.isdigit() and str(t.elem) == "SymInt":
            return f"c10::SymInt({d})"
        if t.size is None:
            # NOTE: Sized lists can have scalar defaults
            raise ValueError(f"Expected a list default '[...]' but found: '{d}'")

    return JIT_TO_CPP_DEFAULT.get(d, d)
369
370
371# Convert an argument into its C++ API form
372
373
def argument(
    a: Argument | TensorOptionsArguments | SelfArgument,
    *,
    cpp_no_default_args: set[str],
    method: bool,
    faithful: bool,
    symint: bool = False,
    has_tensor_options: bool,
) -> list[Binding]:
    """Convert one schema argument into its C++ API binding(s).

    * A plain ``Argument`` yields one binding.  Its default is translated
      with ``default_expr`` unless the argument is listed in
      ``cpp_no_default_args`` or has no default.  A ``memory_format``
      argument binds to
      ``SpecialArgName.possibly_redundant_memory_format`` when
      ``has_tensor_options`` is set.
    * ``TensorOptionsArguments`` yields the four scattered arguments
      (dtype, layout, device, pin_memory) when ``faithful``, otherwise a
      single ``options`` binding.
    * ``SelfArgument`` yields nothing when ``method`` is set, otherwise the
      bindings of the wrapped argument.
    """

    def expand(
        sub: Argument | TensorOptionsArguments | SelfArgument,
    ) -> list[Binding]:
        # Recurse with the same configuration as the outer call.
        return argument(
            sub,
            cpp_no_default_args=cpp_no_default_args,
            method=method,
            faithful=faithful,
            symint=symint,
            has_tensor_options=has_tensor_options,
        )

    if isinstance(a, Argument):
        binds: ArgName = (
            SpecialArgName.possibly_redundant_memory_format
            if a.name == "memory_format" and has_tensor_options
            else a.name
        )
        wants_default = a.name not in cpp_no_default_args and a.default is not None
        default: str | None = (
            default_expr(a.default, a.type, symint=symint) if wants_default else None
        )
        binding = Binding(
            nctype=argument_type(a, binds=binds, symint=symint),
            name=a.name,
            default=default,
            argument=a,
        )
        return [binding]

    if isinstance(a, TensorOptionsArguments):
        if faithful:
            scattered: list[Binding] = []
            for part in (a.dtype, a.layout, a.device, a.pin_memory):
                scattered.extend(expand(part))
            return scattered

        # Enforced by NativeFunction.__post_init__
        assert "options" not in cpp_no_default_args
        options_default: str | None = None
        if all(x.default == "None" for x in a.all()):
            options_default = "{}"
        elif a.dtype.default == "long":
            options_default = "at::kLong"  # TODO: this is wrong
        options_binding = Binding(
            nctype=NamedCType("options", BaseCType(tensorOptionsT)),
            name="options",
            default=options_default,
            argument=a,
        )
        return [options_binding]

    if isinstance(a, SelfArgument):
        # Caller is responsible for installing implicit this in context!
        return [] if method else expand(a.argument)

    assert_never(a)
444
445
def arguments(
    arguments: Arguments,
    *,
    faithful: bool,
    symint: bool = False,
    method: bool,
    cpp_no_default_args: set[str],
) -> list[Binding]:
    """Convert a whole argument list into its C++ API bindings.

    In the faithful signature the non-out arguments come first, followed by
    the out arguments, and every binding has its default stripped.
    Otherwise the out arguments come first and defaults are kept.  Each
    argument is expanded with ``argument``.
    """
    ordered: list[Argument | TensorOptionsArguments | SelfArgument] = (
        [*arguments.non_out, *arguments.out]
        if faithful
        else [*arguments.out, *arguments.non_out]
    )

    bindings: list[Binding] = []
    for a in ordered:
        expanded = argument(
            a,
            faithful=faithful,
            symint=symint,
            method=method,
            has_tensor_options=arguments.tensor_options is not None,
            cpp_no_default_args=cpp_no_default_args,
        )
        for b in expanded:
            bindings.append(b.no_default() if faithful else b)
    return bindings
473