xref: /aosp_15_r20/external/pytorch/torchgen/api/types/types_base.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""
2Where should I add a new type? `types_base.py` vs `types.py`
3
4This file defines data model classes for torchgen typing system, as well as some base types such as int32_t.
5
6`types.py` defines ATen Tensor type and some c10 types, along with signatures that use these types.
7
8The difference between these two files is that `types_base.py` should be implementation-agnostic, meaning it shouldn't
9contain any type definition that is tied to a specific C++ library (e.g., ATen), so that it can be easily reused
10if we want to generate code for another C++ library.
11
12Add new types to `types.py` if these types are ATen/c10 related.
13Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
14"""
15
16from __future__ import annotations
17
18from abc import ABC, abstractmethod
19from dataclasses import dataclass
20from enum import auto, Enum
21from typing import TYPE_CHECKING, Union
22
23
24if TYPE_CHECKING:
25    from torchgen.model import Argument, SelfArgument, TensorOptionsArguments
26
27
28# An ArgName is just the str name of the argument in schema;
29# but in some special circumstances, we may add a little extra
30# context.  The Enum SpecialArgName covers all of these cases;
31# grep for their construction sites to see when they can occur.
32
33
class SpecialArgName(Enum):
    """Argument names that carry extra context beyond a plain schema string."""

    # Only member defined here; `auto()` gives it the value 1.
    possibly_redundant_memory_format = auto()


# An argument name is either the plain schema string or a SpecialArgName member.
ArgName = Union[str, SpecialArgName]
39
40
41# This class shouldn't be created directly; instead, use/create one of the singletons below.
@dataclass(frozen=True)
class BaseCppType:
    """A non-templated C++ type identified by an optional namespace and a name."""

    # Namespace qualifier; None or "" means the name is emitted unqualified.
    ns: str | None
    # Unqualified C++ type name.
    name: str

    def __str__(self) -> str:
        """Return ``ns::name``, or just ``name`` when there is no namespace."""
        qualifier = self.ns
        if qualifier:
            return f"{qualifier}::{self.name}"
        return self.name
51
52
53# The set of all non-templated, valid, fully-qualified names of C++ types that are used in the codegen.
54# Templated types get their own dataclass, mainly to make namespace parsing easier.
# Fixed-width integer types.
byteT = BaseCppType("", "uint8_t")
charT = BaseCppType("", "int8_t")
shortT = BaseCppType("", "int16_t")
# Calling this intT would be more symmetric, but it is easy to mix that up
# with the JIT int (which is int64_t in C++), so intT is intentionally left
# undefined to make such a mistake obvious.
int32T = BaseCppType("", "int32_t")
longT = BaseCppType("", "int64_t")
# Floating-point types.
floatT = BaseCppType("", "float")
doubleT = BaseCppType("", "double")
# Remaining built-in types.
boolT = BaseCppType("", "bool")
voidT = BaseCppType("", "void")
67
68
class CType(ABC):
    """Abstract interface shared by every C++ type representation in this module."""

    @abstractmethod
    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Return the C++ spelling of this type.

        `strip_ref` is keyword-only; how it is honored is up to each subclass.
        """
        raise NotImplementedError

    @abstractmethod
    def cpp_type_registration_declarations(self) -> str:
        """Return the spelling of this type used for registration declarations."""
        raise NotImplementedError

    @abstractmethod
    def remove_const_ref(self) -> CType:
        """Return this type with const/reference wrappers removed.

        Although abstract, the body returns `self`, so a subclass that
        delegates through `super()` receives the identity result.
        """
        return self
81
82
@dataclass(frozen=True)
class BaseCType(CType):
    """A CType that wraps a single non-templated BaseCppType."""

    type: BaseCppType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Return the string form of the wrapped type; `strip_ref` is not consulted."""
        return str(self.type)

    # For BC reasons, at:: namespaces must not show up in RegistrationDeclarations.yaml
    # TODO: Kill this when we eventually remove it!
    def cpp_type_registration_declarations(self) -> str:
        """Return the string form of the wrapped type with every ``at::`` removed."""
        qualified = str(self.type)
        return qualified.replace("at::", "")

    def remove_const_ref(self) -> CType:
        """Return `self`; there is no wrapper to strip at this level."""
        return self
97
98
@dataclass(frozen=True)
class ConstRefCType(CType):
    """A const reference (``const T &``) around another CType."""

    elem: CType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Return ``const <elem> &``, or the element's stripped spelling if `strip_ref`."""
        if not strip_ref:
            return f"const {self.elem.cpp_type()} &"
        # Forward the flag so the element can strip its own reference too.
        return self.elem.cpp_type(strip_ref=strip_ref)

    def cpp_type_registration_declarations(self) -> str:
        """Return ``const <elem> &`` using the element's registration spelling."""
        inner = self.elem.cpp_type_registration_declarations()
        return f"const {inner} &"

    def remove_const_ref(self) -> CType:
        """Drop this wrapper and continue stripping inside the element."""
        return self.elem.remove_const_ref()
113
114
@dataclass(frozen=True)
class VectorCType(CType):
    """A ``::std::vector`` whose element type is another CType."""

    elem: CType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Return ``::std::vector<elem>``."""
        # `strip_ref` is intentionally not forwarded to the element.
        inner = self.elem.cpp_type()
        return f"::std::vector<{inner}>"

    def cpp_type_registration_declarations(self) -> str:
        """Return ``::std::vector<elem>`` using the element's registration spelling."""
        inner = self.elem.cpp_type_registration_declarations()
        return f"::std::vector<{inner}>"

    def remove_const_ref(self) -> CType:
        """Return a new VectorCType around the element with const/ref removed."""
        stripped = self.elem.remove_const_ref()
        return VectorCType(stripped)
128
129
@dataclass(frozen=True)
class ArrayCType(CType):
    """A ``::std::array`` with an element CType and a fixed size."""

    elem: CType
    size: int

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Return ``::std::array<elem,size>``."""
        # `strip_ref` is intentionally not forwarded to the element.
        inner = self.elem.cpp_type()
        return f"::std::array<{inner},{self.size}>"

    def cpp_type_registration_declarations(self) -> str:
        """Return ``::std::array<elem,size>`` using the element's registration spelling."""
        inner = self.elem.cpp_type_registration_declarations()
        return f"::std::array<{inner},{self.size}>"

    def remove_const_ref(self) -> CType:
        """Return a new ArrayCType of the same size around the stripped element."""
        stripped = self.elem.remove_const_ref()
        return ArrayCType(stripped, self.size)
144
145
@dataclass(frozen=True)
class TupleCType(CType):
    """A ``::std::tuple`` over an ordered list of element CTypes."""

    # NOTE(review): a list field means hash() on an instance raises TypeError
    # even though the dataclass is frozen.
    elems: list[CType]

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Return ``::std::tuple<e1,e2,...>``."""
        # `strip_ref` is intentionally not forwarded to the elements.
        members = ",".join(e.cpp_type() for e in self.elems)
        return f"::std::tuple<{members}>"

    def cpp_type_registration_declarations(self) -> str:
        """Return ``::std::tuple<...>`` using each element's registration spelling."""
        members = ",".join(
            e.cpp_type_registration_declarations() for e in self.elems
        )
        return f"::std::tuple<{members}>"

    def remove_const_ref(self) -> CType:
        """Return a new TupleCType whose elements have const/ref removed."""
        stripped = [e.remove_const_ref() for e in self.elems]
        return TupleCType(stripped)
159
160
@dataclass(frozen=True)
class MutRefCType(CType):
    """A mutable reference (``T &``) around another CType."""

    elem: CType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Return ``<elem> &``, or the element's stripped spelling if `strip_ref`."""
        if not strip_ref:
            return f"{self.elem.cpp_type()} &"
        # Forward the flag so the element can strip its own reference too.
        return self.elem.cpp_type(strip_ref=strip_ref)

    def cpp_type_registration_declarations(self) -> str:
        """Return ``<elem> &`` using the element's registration spelling."""
        inner = self.elem.cpp_type_registration_declarations()
        return f"{inner} &"

    def remove_const_ref(self) -> CType:
        """Drop this wrapper and continue stripping inside the element."""
        return self.elem.remove_const_ref()
175
176
177# A NamedCType is short for Named C++ semantic type.  A NamedCType represents a C++ type, plus
178# semantic information about what it represents.  For example, consider the
179# argument "bool pin_memory"; its normal C++ type is "bool", but its C++
180# semantic type also keeps track that this represents a "pin_memory"; you can't
181# just use a random other boolean in a context where you need a "pin_memory"!
182#
183
184
@dataclass(frozen=True)
class NamedCType:
    """A C++ type paired with the argument name that gives it semantic meaning."""

    name: ArgName
    type: CType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Return the C++ spelling of the wrapped type, forwarding `strip_ref`."""
        return self.type.cpp_type(strip_ref=strip_ref)

    # For BC reasons, at:: namespaces must not show up in RegistrationDeclarations.yaml
    # TODO: Kill this when we eventually remove it!
    def cpp_type_registration_declarations(self) -> str:
        """Return the wrapped type's registration-declarations spelling."""
        return self.type.cpp_type_registration_declarations()

    def remove_const_ref(self) -> NamedCType:
        """Return a copy with the same name and const/ref removed from the type."""
        stripped = self.type.remove_const_ref()
        return NamedCType(name=self.name, type=stripped)

    def with_name(self, name: str) -> NamedCType:
        """Return a copy with the same type under a different name."""
        return NamedCType(name=name, type=self.type)
203
204
205# A binding represents any C++ binding site for a formal parameter.
206# We don't distinguish between binding sites for different APIs;
207# instead, all of the important distinctions are encoded in CType,
208# which you can use to figure out if a given Binding is appropriate
209# for use in another context.  (See torchgen.api.translate)
210
211
@dataclass(frozen=True)
class Binding:
    """A C++ binding site for a formal parameter: name, typed name, and argument."""

    name: str
    nctype: NamedCType
    argument: Argument | TensorOptionsArguments | SelfArgument
    # TODO: maybe don't represent default here
    default: str | None = None

    def rename(self, name: str) -> Binding:
        """Return a copy of this binding under `name`; other fields are kept."""
        return Binding(name, self.nctype, self.argument, self.default)

    @property
    def type(self) -> str:
        """C++ spelling of the binding's type, as reported by `nctype`."""
        return self.nctype.cpp_type()

    def no_default(self) -> Binding:
        """Return a copy of this binding with the default cleared."""
        return Binding(self.name, self.nctype, self.argument, None)

    def decl(self, *, func_ptr_cast: bool = False) -> str:
        """Return the declaration ``type name[=default]``.

        With `func_ptr_cast` set, only the type is returned.
        """
        # A function-pointer cast only needs to know the type.
        if func_ptr_cast:
            return f"{self.type}"
        suffix = "" if self.default is None else f"={self.default}"
        return f"{self.type} {self.name}{suffix}"

    # For BC reasons, at:: namespaces must not show up in RegistrationDeclarations.yaml
    # TODO: Kill this when we eventually remove it!
    def decl_registration_declarations(self) -> str:
        """Return ``type name[=default]`` using the registration-declarations type."""
        spelled = self.nctype.cpp_type_registration_declarations()
        suffix = "" if self.default is None else f"={self.default}"
        return f"{spelled} {self.name}{suffix}"

    def defn(self) -> str:
        """Return the definition form ``type name``, never with a default."""
        return f"{self.type} {self.name}"

    def with_name(self, name: str) -> Binding:
        """Return a copy of this binding under `name`; other fields are kept."""
        return Binding(name, self.nctype, self.argument, self.default)
267
268
269# An Expr is a C++ expression.  It has a C++ string representing its syntax,
270# as well as a CType saying what it provides.
271
272
@dataclass(frozen=True)
class Expr:
    """A C++ expression together with the named type it provides."""

    # C++ source text of the expression.
    expr: str
    # Named type describing what the expression provides.
    type: NamedCType
276    type: NamedCType
277