xref: /aosp_15_r20/external/pytorch/torchgen/model.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport dataclasses
4*da0073e9SAndroid Build Coastguard Workerimport itertools
5*da0073e9SAndroid Build Coastguard Workerimport re
6*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass
7*da0073e9SAndroid Build Coastguard Workerfrom enum import auto, Enum
8*da0073e9SAndroid Build Coastguard Workerfrom typing import Callable, Iterator, Sequence
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerfrom torchgen.utils import assert_never, NamespaceHelper, OrderedSet
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
14*da0073e9SAndroid Build Coastguard Worker#
15*da0073e9SAndroid Build Coastguard Worker#                           DATA MODEL
16*da0073e9SAndroid Build Coastguard Worker#
17*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
18*da0073e9SAndroid Build Coastguard Worker#
19*da0073e9SAndroid Build Coastguard Worker# Some general principles for our data model.
20*da0073e9SAndroid Build Coastguard Worker#
# - Stop using C++ data types as the internal data representation
#   format.  Instead, the internal data structures are centered
#   around JIT schema representation.  This avoids a big problem
#   with the old codegen where we read in all the types from
#   native_functions.yaml and then immediately had to retranslate
#   them into C++ types.
27*da0073e9SAndroid Build Coastguard Worker#
28*da0073e9SAndroid Build Coastguard Worker# - More semantic data representation.  Instead of representing
29*da0073e9SAndroid Build Coastguard Worker#   everything as dicts and strings, we define dataclasses for
30*da0073e9SAndroid Build Coastguard Worker#   every interesting entity the code generation has to deal with.
31*da0073e9SAndroid Build Coastguard Worker#   These dataclasses have strong semantic invariants: for example,
32*da0073e9SAndroid Build Coastguard Worker#   we generally require them to roundtrip losslessly into the
33*da0073e9SAndroid Build Coastguard Worker#   form they were parsed from.  These structures are immutable
34*da0073e9SAndroid Build Coastguard Worker#   and you're expected to populate information once during
35*da0073e9SAndroid Build Coastguard Worker#   construction.
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Worker
@dataclass(frozen=True)
class Location:
    """A position in a source file, kept around to produce better error reports.

    Instances are immutable and render as ``<file>:<line>``.
    """

    file: str
    line: int

    def __str__(self) -> str:
        return "{}:{}".format(self.file, self.line)
46*da0073e9SAndroid Build Coastguard Worker
47*da0073e9SAndroid Build Coastguard Worker
class Variant(Enum):
    """Allowed entries of the 'variants' field in native_functions.yaml."""

    # Values spelled out explicitly; they match what auto() would assign.
    function = 1
    method = 2
52*da0073e9SAndroid Build Coastguard Worker
53*da0073e9SAndroid Build Coastguard Worker
# Default kernel namespace (a C++ namespace).
DEFAULT_KERNEL_NAMESPACE = "at::native"

# Backend components.  Each one is combined with every entry of
# FUNCTIONALITY_KEYS (functionality + backend) to form a per-backend
# DispatchKey member name such as "SparseCUDA".
# NOTE: Keep the list in sync with `DispatchKey` in c10/core/DispatchKey.h
BACKEND_COMPONENTS = [
    "CPU",
    "CUDA",
    "HIP",
    "XLA",
    "MTIA",
    "MPS",
    "IPU",
    "XPU",
    "HPU",
    "VE",
    "Lazy",
    "Meta",
    "PrivateUse1",
    "PrivateUse2",
    "PrivateUse3",
]

# Functionality prefixes.  The empty prefix yields the bare backend name
# (e.g. "CPU").
FUNCTIONALITY_KEYS = [
    "",
    "Quantized",
    "Sparse",
    "SparseCsr",
    "NestedTensor",
    "Autograd",
]

# Dispatch keys that may be used in derivatives.yaml.
# AutogradFunctionality and AutogradOther are deliberately omitted for now.
AUTOGRAD_KEYS = ["AutogradNestedTensor"]
AUTOGRAD_KEYS.extend(f"Autograd{backend}" for backend in BACKEND_COMPONENTS)

FRAGMENT_NAMESPACES = {"quantized", "quantized_decomposed"}
75*da0073e9SAndroid Build Coastguard Worker
76*da0073e9SAndroid Build Coastguard Worker
77*da0073e9SAndroid Build Coastguard Worker# This doesn't have to be in sync with the header, it only needs to contain
78*da0073e9SAndroid Build Coastguard Worker# entries that we actually use in the codegen or want pyi entries for
79*da0073e9SAndroid Build Coastguard Workerclass DispatchKey(Enum):
80*da0073e9SAndroid Build Coastguard Worker    Undefined = 0
81*da0073e9SAndroid Build Coastguard Worker    CatchAll = Undefined
82*da0073e9SAndroid Build Coastguard Worker
83*da0073e9SAndroid Build Coastguard Worker    FPGA = auto()
84*da0073e9SAndroid Build Coastguard Worker    MAIA = auto()
85*da0073e9SAndroid Build Coastguard Worker    Vulkan = auto()
86*da0073e9SAndroid Build Coastguard Worker    Metal = auto()
87*da0073e9SAndroid Build Coastguard Worker    MKLDNN = auto()
88*da0073e9SAndroid Build Coastguard Worker    OpenGL = auto()
89*da0073e9SAndroid Build Coastguard Worker    OpenCL = auto()
90*da0073e9SAndroid Build Coastguard Worker    IDEEP = auto()
91*da0073e9SAndroid Build Coastguard Worker    CustomRNGKeyId = auto()
92*da0073e9SAndroid Build Coastguard Worker    MkldnnCPU = auto()
93*da0073e9SAndroid Build Coastguard Worker    Sparse = auto()
94*da0073e9SAndroid Build Coastguard Worker    SparseCsr = auto()
95*da0073e9SAndroid Build Coastguard Worker    NestedTensor = auto()
96*da0073e9SAndroid Build Coastguard Worker    Dense = auto()
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker    PythonTLSSnapshot = auto()
99*da0073e9SAndroid Build Coastguard Worker    PreDispatch = auto()
100*da0073e9SAndroid Build Coastguard Worker    PythonDispatcher = auto()
101*da0073e9SAndroid Build Coastguard Worker    Python = auto()
102*da0073e9SAndroid Build Coastguard Worker    FuncTorchDynamicLayerBackMode = auto()
103*da0073e9SAndroid Build Coastguard Worker    ZeroTensor = auto()
104*da0073e9SAndroid Build Coastguard Worker    Conjugate = auto()
105*da0073e9SAndroid Build Coastguard Worker    Negative = auto()
106*da0073e9SAndroid Build Coastguard Worker    BackendSelect = auto()
107*da0073e9SAndroid Build Coastguard Worker    Named = auto()
108*da0073e9SAndroid Build Coastguard Worker    AutogradOther = auto()
109*da0073e9SAndroid Build Coastguard Worker    AutogradFunctionality = auto()
110*da0073e9SAndroid Build Coastguard Worker    AutogradNestedTensor = auto()
111*da0073e9SAndroid Build Coastguard Worker    Tracer = auto()
112*da0073e9SAndroid Build Coastguard Worker    Autocast = auto()
113*da0073e9SAndroid Build Coastguard Worker    AutocastCPU = auto()
114*da0073e9SAndroid Build Coastguard Worker    AutocastCUDA = auto()
115*da0073e9SAndroid Build Coastguard Worker    Batched = auto()
116*da0073e9SAndroid Build Coastguard Worker    VmapMode = auto()
117*da0073e9SAndroid Build Coastguard Worker    FuncTorchGradWrapper = auto()
118*da0073e9SAndroid Build Coastguard Worker    FuncTorchBatched = auto()
119*da0073e9SAndroid Build Coastguard Worker    BatchedNestedTensor = auto()
120*da0073e9SAndroid Build Coastguard Worker    FuncTorchVmapMode = auto()
121*da0073e9SAndroid Build Coastguard Worker    FuncTorchDynamicLayerFrontMode = auto()
122*da0073e9SAndroid Build Coastguard Worker    Functionalize = auto()
123*da0073e9SAndroid Build Coastguard Worker    TESTING_ONLY_GenericWrapper = auto()
124*da0073e9SAndroid Build Coastguard Worker    TESTING_ONLY_GenericMode = auto()
125*da0073e9SAndroid Build Coastguard Worker
126*da0073e9SAndroid Build Coastguard Worker    ADInplaceOrView = auto()
127*da0073e9SAndroid Build Coastguard Worker    Autograd = auto()
128*da0073e9SAndroid Build Coastguard Worker    CompositeImplicitAutograd = auto()
129*da0073e9SAndroid Build Coastguard Worker    CompositeImplicitAutogradNestedTensor = auto()
130*da0073e9SAndroid Build Coastguard Worker    CompositeExplicitAutograd = auto()
131*da0073e9SAndroid Build Coastguard Worker    CompositeExplicitAutogradNonFunctional = auto()
132*da0073e9SAndroid Build Coastguard Worker    FuncTorchBatchedDecomposition = auto()
133*da0073e9SAndroid Build Coastguard Worker
134*da0073e9SAndroid Build Coastguard Worker    # BEGIN autogenerated
135*da0073e9SAndroid Build Coastguard Worker    CPU = auto()
136*da0073e9SAndroid Build Coastguard Worker    CUDA = auto()
137*da0073e9SAndroid Build Coastguard Worker    HIP = auto()
138*da0073e9SAndroid Build Coastguard Worker    XLA = auto()
139*da0073e9SAndroid Build Coastguard Worker    MTIA = auto()
140*da0073e9SAndroid Build Coastguard Worker    MPS = auto()
141*da0073e9SAndroid Build Coastguard Worker    IPU = auto()
142*da0073e9SAndroid Build Coastguard Worker    XPU = auto()
143*da0073e9SAndroid Build Coastguard Worker    HPU = auto()
144*da0073e9SAndroid Build Coastguard Worker    VE = auto()
145*da0073e9SAndroid Build Coastguard Worker    Lazy = auto()
146*da0073e9SAndroid Build Coastguard Worker    Meta = auto()
147*da0073e9SAndroid Build Coastguard Worker    PrivateUse1 = auto()
148*da0073e9SAndroid Build Coastguard Worker    PrivateUse2 = auto()
149*da0073e9SAndroid Build Coastguard Worker    PrivateUse3 = auto()
150*da0073e9SAndroid Build Coastguard Worker    QuantizedCPU = auto()
151*da0073e9SAndroid Build Coastguard Worker    QuantizedCUDA = auto()
152*da0073e9SAndroid Build Coastguard Worker    QuantizedHIP = auto()
153*da0073e9SAndroid Build Coastguard Worker    QuantizedXLA = auto()
154*da0073e9SAndroid Build Coastguard Worker    QuantizedMTIA = auto()
155*da0073e9SAndroid Build Coastguard Worker    QuantizedMPS = auto()
156*da0073e9SAndroid Build Coastguard Worker    QuantizedIPU = auto()
157*da0073e9SAndroid Build Coastguard Worker    QuantizedXPU = auto()
158*da0073e9SAndroid Build Coastguard Worker    QuantizedHPU = auto()
159*da0073e9SAndroid Build Coastguard Worker    QuantizedVE = auto()
160*da0073e9SAndroid Build Coastguard Worker    QuantizedLazy = auto()
161*da0073e9SAndroid Build Coastguard Worker    QuantizedMeta = auto()
162*da0073e9SAndroid Build Coastguard Worker    QuantizedPrivateUse1 = auto()
163*da0073e9SAndroid Build Coastguard Worker    QuantizedPrivateUse2 = auto()
164*da0073e9SAndroid Build Coastguard Worker    QuantizedPrivateUse3 = auto()
165*da0073e9SAndroid Build Coastguard Worker    SparseCPU = auto()
166*da0073e9SAndroid Build Coastguard Worker    SparseCUDA = auto()
167*da0073e9SAndroid Build Coastguard Worker    SparseHIP = auto()
168*da0073e9SAndroid Build Coastguard Worker    SparseXLA = auto()
169*da0073e9SAndroid Build Coastguard Worker    SparseMTIA = auto()
170*da0073e9SAndroid Build Coastguard Worker    SparseMPS = auto()
171*da0073e9SAndroid Build Coastguard Worker    SparseIPU = auto()
172*da0073e9SAndroid Build Coastguard Worker    SparseXPU = auto()
173*da0073e9SAndroid Build Coastguard Worker    SparseHPU = auto()
174*da0073e9SAndroid Build Coastguard Worker    SparseVE = auto()
175*da0073e9SAndroid Build Coastguard Worker    SparseLazy = auto()
176*da0073e9SAndroid Build Coastguard Worker    SparseMeta = auto()
177*da0073e9SAndroid Build Coastguard Worker    SparsePrivateUse1 = auto()
178*da0073e9SAndroid Build Coastguard Worker    SparsePrivateUse2 = auto()
179*da0073e9SAndroid Build Coastguard Worker    SparsePrivateUse3 = auto()
180*da0073e9SAndroid Build Coastguard Worker    SparseCsrCPU = auto()
181*da0073e9SAndroid Build Coastguard Worker    SparseCsrCUDA = auto()
182*da0073e9SAndroid Build Coastguard Worker    SparseCsrHIP = auto()
183*da0073e9SAndroid Build Coastguard Worker    SparseCsrXLA = auto()
184*da0073e9SAndroid Build Coastguard Worker    SparseCsrMTIA = auto()
185*da0073e9SAndroid Build Coastguard Worker    SparseCsrMPS = auto()
186*da0073e9SAndroid Build Coastguard Worker    SparseCsrIPU = auto()
187*da0073e9SAndroid Build Coastguard Worker    SparseCsrXPU = auto()
188*da0073e9SAndroid Build Coastguard Worker    SparseCsrHPU = auto()
189*da0073e9SAndroid Build Coastguard Worker    SparseCsrVE = auto()
190*da0073e9SAndroid Build Coastguard Worker    SparseCsrLazy = auto()
191*da0073e9SAndroid Build Coastguard Worker    SparseCsrMeta = auto()
192*da0073e9SAndroid Build Coastguard Worker    SparseCsrPrivateUse1 = auto()
193*da0073e9SAndroid Build Coastguard Worker    SparseCsrPrivateUse2 = auto()
194*da0073e9SAndroid Build Coastguard Worker    SparseCsrPrivateUse3 = auto()
195*da0073e9SAndroid Build Coastguard Worker    NestedTensorCPU = auto()
196*da0073e9SAndroid Build Coastguard Worker    NestedTensorCUDA = auto()
197*da0073e9SAndroid Build Coastguard Worker    NestedTensorHIP = auto()
198*da0073e9SAndroid Build Coastguard Worker    NestedTensorXLA = auto()
199*da0073e9SAndroid Build Coastguard Worker    NestedTensorMTIA = auto()
200*da0073e9SAndroid Build Coastguard Worker    NestedTensorMPS = auto()
201*da0073e9SAndroid Build Coastguard Worker    NestedTensorIPU = auto()
202*da0073e9SAndroid Build Coastguard Worker    NestedTensorXPU = auto()
203*da0073e9SAndroid Build Coastguard Worker    NestedTensorHPU = auto()
204*da0073e9SAndroid Build Coastguard Worker    NestedTensorVE = auto()
205*da0073e9SAndroid Build Coastguard Worker    NestedTensorLazy = auto()
206*da0073e9SAndroid Build Coastguard Worker    NestedTensorMeta = auto()
207*da0073e9SAndroid Build Coastguard Worker    NestedTensorPrivateUse1 = auto()
208*da0073e9SAndroid Build Coastguard Worker    NestedTensorPrivateUse2 = auto()
209*da0073e9SAndroid Build Coastguard Worker    NestedTensorPrivateUse3 = auto()
210*da0073e9SAndroid Build Coastguard Worker    AutogradCPU = auto()
211*da0073e9SAndroid Build Coastguard Worker    AutogradCUDA = auto()
212*da0073e9SAndroid Build Coastguard Worker    AutogradHIP = auto()
213*da0073e9SAndroid Build Coastguard Worker    AutogradXLA = auto()
214*da0073e9SAndroid Build Coastguard Worker    AutogradMTIA = auto()
215*da0073e9SAndroid Build Coastguard Worker    AutogradMPS = auto()
216*da0073e9SAndroid Build Coastguard Worker    AutogradIPU = auto()
217*da0073e9SAndroid Build Coastguard Worker    AutogradXPU = auto()
218*da0073e9SAndroid Build Coastguard Worker    AutogradHPU = auto()
219*da0073e9SAndroid Build Coastguard Worker    AutogradVE = auto()
220*da0073e9SAndroid Build Coastguard Worker    AutogradLazy = auto()
221*da0073e9SAndroid Build Coastguard Worker    AutogradMeta = auto()
222*da0073e9SAndroid Build Coastguard Worker    AutogradPrivateUse1 = auto()
223*da0073e9SAndroid Build Coastguard Worker    AutogradPrivateUse2 = auto()
224*da0073e9SAndroid Build Coastguard Worker    AutogradPrivateUse3 = auto()
225*da0073e9SAndroid Build Coastguard Worker    # END autogenerated
226*da0073e9SAndroid Build Coastguard Worker
227*da0073e9SAndroid Build Coastguard Worker    def __str__(self) -> str:
228*da0073e9SAndroid Build Coastguard Worker        return self.name
229*da0073e9SAndroid Build Coastguard Worker
230*da0073e9SAndroid Build Coastguard Worker    def lower(self) -> str:
231*da0073e9SAndroid Build Coastguard Worker        return str(self).lower()
232*da0073e9SAndroid Build Coastguard Worker
233*da0073e9SAndroid Build Coastguard Worker    @staticmethod
234*da0073e9SAndroid Build Coastguard Worker    def parse(value: str) -> DispatchKey:
235*da0073e9SAndroid Build Coastguard Worker        for k, v in DispatchKey.__members__.items():
236*da0073e9SAndroid Build Coastguard Worker            if k == value:
237*da0073e9SAndroid Build Coastguard Worker                return v
238*da0073e9SAndroid Build Coastguard Worker        raise AssertionError(f"unknown dispatch key {value}")
239*da0073e9SAndroid Build Coastguard Worker
240*da0073e9SAndroid Build Coastguard Worker
241*da0073e9SAndroid Build Coastguard Workerclass _TorchDispatchModeKey(Enum):
242*da0073e9SAndroid Build Coastguard Worker    FAKE = auto()
243*da0073e9SAndroid Build Coastguard Worker    PROXY = auto()
244*da0073e9SAndroid Build Coastguard Worker    FUNCTIONAL = auto()
245*da0073e9SAndroid Build Coastguard Worker
246*da0073e9SAndroid Build Coastguard Worker
def codegen_per_backend_entries() -> str:
    """Render the per-backend DispatchKey member declarations.

    Emits one line per (functionality, backend) pair, iterating
    FUNCTIONALITY_KEYS in the outer position and BACKEND_COMPONENTS in the
    inner one.  Each line has the form ``    <Functionality><Backend> = auto()``
    and lines are joined with newlines.
    """
    return "\n".join(
        f"    {functionality}{backend} = auto()"
        for functionality in FUNCTIONALITY_KEYS
        for backend in BACKEND_COMPONENTS
    )
253*da0073e9SAndroid Build Coastguard Worker
254*da0073e9SAndroid Build Coastguard Worker
def _check_per_backend_entries() -> None:
    """Verify that DispatchKey has a member for every functionality/backend pair.

    For each pair in FUNCTIONALITY_KEYS x BACKEND_COMPONENTS, the
    concatenated name (e.g. "SparseCUDA") must be an attribute of
    DispatchKey.  If any are missing, the expected autogenerated listing
    (from codegen_per_backend_entries) is printed to stdout and a
    RuntimeError naming every missing key is raised.

    Raises:
        RuntimeError: if one or more per-backend keys are missing.
    """
    missing = [
        fk + bc
        for fk in FUNCTIONALITY_KEYS
        for bc in BACKEND_COMPONENTS
        if not hasattr(DispatchKey, fk + bc)
    ]
    if missing:
        r = codegen_per_backend_entries()
        print(r)
        # With a single missing key this message is identical to the old one;
        # with several, all of them are reported at once.
        missing_names = ", ".join(missing)
        raise RuntimeError(
            f"Missing {missing_names} from DispatchKey enum.  Here is the autogenerated list we expect to have:\n\n{r}"
        )


# Run the check at import time.  Doing it inside a function keeps the loop
# variables out of the module namespace (they used to leak as `fk`/`bc`).
_check_per_backend_entries()
263*da0073e9SAndroid Build Coastguard Worker
264*da0073e9SAndroid Build Coastguard Worker
# Dispatch keys for which structured kernel generation is supported
# (consulted by is_structured_dispatch_key).
STRUCTURED_DISPATCH_KEYS = {
    DispatchKey.MPS,
    DispatchKey.CUDA,
    DispatchKey.CPU,
    DispatchKey.XPU,
}

# Dispatch keys used for ufunc codegen (consulted by is_ufunc_dispatch_key).
UFUNC_DISPATCH_KEYS = {
    DispatchKey.CUDA,
    DispatchKey.CPU,
}

# Supported dispatch keys, as an ordered list (not a set).
dispatch_keys = [
    DispatchKey.CPU,
    DispatchKey.SparseCPU,
    DispatchKey.SparseCsrCPU,
    DispatchKey.MkldnnCPU,
    DispatchKey.CUDA,
    DispatchKey.MPS,
    DispatchKey.XPU,
    DispatchKey.SparseCUDA,
    DispatchKey.SparseCsrCUDA,
    DispatchKey.QuantizedCPU,
    DispatchKey.QuantizedCUDA,
    DispatchKey.CompositeImplicitAutograd,
    DispatchKey.CompositeImplicitAutogradNestedTensor,
    DispatchKey.CompositeExplicitAutograd,
    DispatchKey.CompositeExplicitAutogradNonFunctional,
    DispatchKey.NestedTensorCPU,
    DispatchKey.NestedTensorCUDA,
    # Meta is a magic key: it is automatically generated for structured
    # kernels
    DispatchKey.Meta,
    DispatchKey.SparseMeta,
    DispatchKey.SparseCsrMeta,
    DispatchKey.QuantizedMeta,
    DispatchKey.NestedTensorMeta,
    DispatchKey.ZeroTensor,
]
301*da0073e9SAndroid Build Coastguard Worker
302*da0073e9SAndroid Build Coastguard Worker
def is_generic_dispatch_key(dk: DispatchKey) -> bool:
    """Return whether ``dk`` is a dispatch key that "supports all backends".

    These are the four Composite* keys; they are code-generated slightly
    differently than backend-specific keys.
    """
    composite_keys = {
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
    }
    return dk in composite_keys
312*da0073e9SAndroid Build Coastguard Worker
313*da0073e9SAndroid Build Coastguard Worker
def is_cuda_dispatch_key(dk: DispatchKey) -> bool:
    """Return whether ``dk`` is a CUDA-specific dispatch key.

    Covers CUDA and its Quantized, Sparse, SparseCsr, NestedTensor and
    Autograd variants.
    """
    cuda_keys = {
        DispatchKey.CUDA,
        DispatchKey.QuantizedCUDA,
        DispatchKey.SparseCUDA,
        DispatchKey.SparseCsrCUDA,
        DispatchKey.NestedTensorCUDA,
        DispatchKey.AutogradCUDA,
    }
    return dk in cuda_keys
324*da0073e9SAndroid Build Coastguard Worker
325*da0073e9SAndroid Build Coastguard Worker
def is_xpu_dispatch_key(dk: DispatchKey) -> bool:
    """Return whether ``dk`` is an XPU-specific dispatch key.

    Covers XPU and its Quantized, Sparse, SparseCsr, NestedTensor and
    Autograd variants.
    """
    xpu_keys = {
        DispatchKey.XPU,
        DispatchKey.QuantizedXPU,
        DispatchKey.SparseXPU,
        DispatchKey.SparseCsrXPU,
        DispatchKey.NestedTensorXPU,
        DispatchKey.AutogradXPU,
    }
    return dk in xpu_keys
336*da0073e9SAndroid Build Coastguard Worker
337*da0073e9SAndroid Build Coastguard Worker
def is_structured_dispatch_key(dk: DispatchKey) -> bool:
    """Return whether structured kernel generation is supported for ``dk``.

    Only the keys listed in STRUCTURED_DISPATCH_KEYS qualify; every other
    key uses old-style kernel generation.
    """
    return dk in STRUCTURED_DISPATCH_KEYS
342*da0073e9SAndroid Build Coastguard Worker
343*da0073e9SAndroid Build Coastguard Worker
def is_ufunc_dispatch_key(dk: DispatchKey) -> bool:
    """Return whether ``dk`` is one of the UFUNC_DISPATCH_KEYS.

    As currently defined, UFUNC_DISPATCH_KEYS holds only CPU and CUDA.  That
    is a strict subset of STRUCTURED_DISPATCH_KEYS, which also includes MPS
    and XPU, so the two sets do NOT coincide.
    """
    return dk in UFUNC_DISPATCH_KEYS
347*da0073e9SAndroid Build Coastguard Worker
348*da0073e9SAndroid Build Coastguard Worker
# This is oddly named ScalarType and not DType for symmetry with C++
class ScalarType(Enum):
    """Tensor element types (dtypes), named ScalarType for symmetry with C++."""

    Byte = auto()
    Char = auto()
    Short = auto()
    Int = auto()
    Long = auto()
    Half = auto()
    Float = auto()
    Double = auto()
    ComplexHalf = auto()
    ComplexFloat = auto()
    ComplexDouble = auto()
    Bool = auto()
    BFloat16 = auto()
    Float8_e5m2 = auto()
    Float8_e5m2fnuz = auto()
    Float8_e4m3fn = auto()
    Float8_e4m3fnuz = auto()

    def __str__(self) -> str:
        return self.name

    @staticmethod
    def maybe_parse(value: str) -> ScalarType | None:
        """Return the member named exactly ``value``, or None if there is none."""
        return ScalarType.__members__.get(value)

    @staticmethod
    def parse(value: str) -> ScalarType:
        """Return the member named exactly ``value``.

        Raises:
            AssertionError: if ``value`` does not name a ScalarType.
        """
        mb_r = ScalarType.maybe_parse(value)
        # Raise explicitly instead of using an `assert` statement: asserts are
        # stripped under `python -O`, which would make this silently return
        # None.  This also matches DispatchKey.parse / UfuncKey.parse.
        if mb_r is None:
            raise AssertionError(f"unknown dtype {value}")
        return mb_r

    @staticmethod
    def parse_set(values: str) -> OrderedSet[ScalarType]:
        """Parse a ", "-separated list of dtype names and dtype-class names.

        Each entry is either a key of DTYPE_CLASSES (e.g. "Floating"), which
        contributes every dtype in that class, or a single ScalarType name.

        Raises:
            AssertionError: if an entry is neither a dtype class nor a dtype.
        """
        dtypes: OrderedSet[ScalarType] = OrderedSet()
        for value in values.split(", "):
            if value in DTYPE_CLASSES:
                dtypes.update(DTYPE_CLASSES[value])
            else:
                dtypes.add(ScalarType.parse(value))
        return dtypes
394*da0073e9SAndroid Build Coastguard Worker
395*da0073e9SAndroid Build Coastguard Worker
# Named groups of dtypes.  ScalarType.parse_set() expands any entry that is
# a key of this dict into all of the group's dtypes.
DTYPE_CLASSES: dict[str, OrderedSet[ScalarType]] = {
    # NB: Integral doesn't include boolean
    "Integral": OrderedSet(
        [
            ScalarType.Byte,
            ScalarType.Char,
            ScalarType.Int,
            ScalarType.Long,
            ScalarType.Short,
        ]
    ),
    # NB: Floating doesn't include low precision types
    "Floating": OrderedSet([ScalarType.Float, ScalarType.Double]),
    "Complex": OrderedSet([ScalarType.ComplexFloat, ScalarType.ComplexDouble]),
}

# Composite groups are unions of the base groups above.  Note that "All"
# covers only Integral and Floating; it does not include Complex.
DTYPE_CLASSES["All"] = DTYPE_CLASSES["Integral"] | DTYPE_CLASSES["Floating"]
DTYPE_CLASSES["AllAndComplex"] = DTYPE_CLASSES["All"] | DTYPE_CLASSES["Complex"]
DTYPE_CLASSES["FloatingAndComplex"] = (
    DTYPE_CLASSES["Floating"] | DTYPE_CLASSES["Complex"]
)
417*da0073e9SAndroid Build Coastguard Worker
418*da0073e9SAndroid Build Coastguard Worker
# Represents the valid entries for ufunc_inner_loop in native_functions.yaml.
# NB: if you add a new UfuncKey, you will need to teach torchgen.dest.ufunc
# how to process it.  Most logic ignores keys it doesn't understand, so a new
# key will be silently ignored until you hook in logic to deal with it.
class UfuncKey(Enum):
    """Valid keys of the ufunc_inner_loop field in native_functions.yaml."""

    # Low-level keys: each represents exactly one particular instantiation
    # of the kernel produced by codegen.
    CUDAFunctor = auto()
    CUDAFunctorOnOther = auto()
    CUDAFunctorOnSelf = auto()

    CPUScalar = auto()
    CPUVector = auto()

    # User-facing keys that implicitly "fill in" the low-level keys.
    ScalarOnly = auto()  # CUDA*, CPUScalar
    Generic = auto()  # CUDA*, CPU*

    def __str__(self) -> str:
        return self.name

    @staticmethod
    def parse(value: str) -> UfuncKey:
        """Return the member named exactly ``value``.

        Raises:
            AssertionError: if ``value`` does not name a UfuncKey.
        """
        match = next(
            (member for name, member in UfuncKey.__members__.items() if name == value),
            None,
        )
        if match is None:
            raise AssertionError(f"unknown ufunc key {value}")
        return match
447*da0073e9SAndroid Build Coastguard Worker
448*da0073e9SAndroid Build Coastguard Worker
class DeviceCheckType(Enum):
    """Controls the device check emitted for an operator's generated code.

    Chosen per operator by the ``device_check`` entry in native_functions.yaml,
    which is looked up by member *name*.
    """

    # Emit no device check.
    NoCheck = 0
    # Check that devices match exactly; NativeFunction.from_yaml falls back to
    # this when ``device_check`` is not given.
    ExactSame = 1
452*da0073e9SAndroid Build Coastguard Worker
453*da0073e9SAndroid Build Coastguard Worker
class ViewSchemaKind(Enum):
    """Classifies a function schema by its view/aliasing behavior.

    NOTE(review): the per-member descriptions below are inferred from the
    member names; confirm them against the code that consumes ViewSchemaKind.
    """

    # Presumably: the output aliases an input (an ordinary view op).
    aliasing = auto()
    # Presumably: an in-place op that aliases/modifies its input's view.
    aliasing_inplace = auto()
    # Presumably: a variant whose output does not alias any input.
    non_aliasing = auto()
458*da0073e9SAndroid Build Coastguard Worker
459*da0073e9SAndroid Build Coastguard Worker
460*da0073e9SAndroid Build Coastguard Worker# The basic input to the code generation is native_functions.yaml.
461*da0073e9SAndroid Build Coastguard Worker# The name "native", BTW, comes from the distinction between native
462*da0073e9SAndroid Build Coastguard Worker# functions and legacy TH functions.  The legacy TH functions are gone,
463*da0073e9SAndroid Build Coastguard Worker# but the "native" descriptor has stuck.
464*da0073e9SAndroid Build Coastguard Worker#
465*da0073e9SAndroid Build Coastguard Worker# NativeFunction models a single entry in native_functions.yaml.  Its
466*da0073e9SAndroid Build Coastguard Worker# fields roughly correspond to what you would see in the YAML itself,
467*da0073e9SAndroid Build Coastguard Worker# but after canonicalization and parsing has occurred.
468*da0073e9SAndroid Build Coastguard Worker#
# You can see some of the overall design patterns for how we set up
# dataclasses in this class, but we will defer a complete discussion
# of this to FunctionSchema.
472*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
473*da0073e9SAndroid Build Coastguard Workerclass NativeFunction:
474*da0073e9SAndroid Build Coastguard Worker    # The namespace for this operator. For example, if we have "at::add"
475*da0073e9SAndroid Build Coastguard Worker    # then the namespace would be "at". This enables ops to be registered
476*da0073e9SAndroid Build Coastguard Worker    # through the same DSL with a custom namespace. If not specified, the
477*da0073e9SAndroid Build Coastguard Worker    # default namespace would be "at".
478*da0073e9SAndroid Build Coastguard Worker    namespace: str
479*da0073e9SAndroid Build Coastguard Worker
480*da0073e9SAndroid Build Coastguard Worker    # The function schema of the operator in question.  This schema
481*da0073e9SAndroid Build Coastguard Worker    # has been parsed; see FunctionSchema for more about its structure.
    # (FunctionSchema is a forward reference to a type defined later in
    # the file; this is legal because the module uses
    # `from __future__ import annotations`.  I opted for this ordering of
    # the classes for expository clarity.)
485*da0073e9SAndroid Build Coastguard Worker    func: FunctionSchema
486*da0073e9SAndroid Build Coastguard Worker
    # Whether or not to pass mutable tensor arguments by const reference,
    # just like regular (non-mutable) tensor arguments
489*da0073e9SAndroid Build Coastguard Worker    use_const_ref_for_mutable_tensors: bool
490*da0073e9SAndroid Build Coastguard Worker
    # Whether or not to automatically generate a DeviceGuard (defaults to
    # True; set it to False in the YAML to omit the guard)
492*da0073e9SAndroid Build Coastguard Worker    device_guard: bool
493*da0073e9SAndroid Build Coastguard Worker
494*da0073e9SAndroid Build Coastguard Worker    # How to emit automatic generation of device check
495*da0073e9SAndroid Build Coastguard Worker    device_check: DeviceCheckType
496*da0073e9SAndroid Build Coastguard Worker
497*da0073e9SAndroid Build Coastguard Worker    # What python module to put the function in
498*da0073e9SAndroid Build Coastguard Worker    python_module: str | None
499*da0073e9SAndroid Build Coastguard Worker
500*da0073e9SAndroid Build Coastguard Worker    # TODO: figure out what this does
501*da0073e9SAndroid Build Coastguard Worker    category_override: str | None
502*da0073e9SAndroid Build Coastguard Worker
503*da0073e9SAndroid Build Coastguard Worker    # If no variants are specified in native_functions.yaml, this is
504*da0073e9SAndroid Build Coastguard Worker    # assumed to be {'function'}.
505*da0073e9SAndroid Build Coastguard Worker    variants: set[Variant]
506*da0073e9SAndroid Build Coastguard Worker
507*da0073e9SAndroid Build Coastguard Worker    # Whether or not we should skip generating registrations for
508*da0073e9SAndroid Build Coastguard Worker    # this kernel.  This is a bit of a double-edged sword, as manual
509*da0073e9SAndroid Build Coastguard Worker    # registrations don't participate in codegen-based selective build!
510*da0073e9SAndroid Build Coastguard Worker    manual_kernel_registration: bool
511*da0073e9SAndroid Build Coastguard Worker
512*da0073e9SAndroid Build Coastguard Worker    # Whether or not to skip generating TensorMethod/Functions bindings
513*da0073e9SAndroid Build Coastguard Worker    # for this kernel.  Technically, this doesn't actually skip generating
514*da0073e9SAndroid Build Coastguard Worker    # the binding; instead, the binding gets generated to __dispatch_{funcname}
515*da0073e9SAndroid Build Coastguard Worker    # so you can make use of the normal binding if you need it.
516*da0073e9SAndroid Build Coastguard Worker    manual_cpp_binding: bool
517*da0073e9SAndroid Build Coastguard Worker
    # The location in the YAML file where this native function entry was
    # defined.  This is for conveniently reporting error messages!
520*da0073e9SAndroid Build Coastguard Worker    loc: Location
521*da0073e9SAndroid Build Coastguard Worker
522*da0073e9SAndroid Build Coastguard Worker    # A list of operators that are expected to be auto-generated for this NativeFunction.
523*da0073e9SAndroid Build Coastguard Worker    # Note: This list isn't actually directly used by the codegen to generate anything.
524*da0073e9SAndroid Build Coastguard Worker    # Instead, the codegen figures out what operators to generate purely based off of
525*da0073e9SAndroid Build Coastguard Worker    # function schema, and uses the autogen declarations to error check.
    # We expect every NativeFunction that gets auto-generated to be explicitly
    # called out in native_functions.yaml
528*da0073e9SAndroid Build Coastguard Worker    autogen: list[OperatorName]
529*da0073e9SAndroid Build Coastguard Worker
530*da0073e9SAndroid Build Coastguard Worker    # If non-empty, this kernel is subject to ufunc codegen.
531*da0073e9SAndroid Build Coastguard Worker    # Sorted by ufunc_key
532*da0073e9SAndroid Build Coastguard Worker    ufunc_inner_loop: dict[UfuncKey, UfuncInnerLoop]
533*da0073e9SAndroid Build Coastguard Worker
    # Whether or not this out function is a "structured kernel".  Structured
535*da0073e9SAndroid Build Coastguard Worker    # kernels are defined a little differently from normal kernels; in
536*da0073e9SAndroid Build Coastguard Worker    # particular, their shape checking logic is defined separately from
537*da0073e9SAndroid Build Coastguard Worker    # the kernel.  Only out functions can be structured; other functions
538*da0073e9SAndroid Build Coastguard Worker    # delegate to the out function using the structured_delegate keyword.
539*da0073e9SAndroid Build Coastguard Worker    # Every structured kernel must have at least an out and a functional
540*da0073e9SAndroid Build Coastguard Worker    # variant.
541*da0073e9SAndroid Build Coastguard Worker    structured: bool
542*da0073e9SAndroid Build Coastguard Worker
543*da0073e9SAndroid Build Coastguard Worker    # Whether or not this non-out function is a structured kernel, defined
544*da0073e9SAndroid Build Coastguard Worker    # in terms of the out kernel referenced by the string here.
545*da0073e9SAndroid Build Coastguard Worker    structured_delegate: OperatorName | None
546*da0073e9SAndroid Build Coastguard Worker
547*da0073e9SAndroid Build Coastguard Worker    # Only valid for structured kernels.  Specifies alternative of what
548*da0073e9SAndroid Build Coastguard Worker    # to inherit from when defining the meta class for the structured
549*da0073e9SAndroid Build Coastguard Worker    # operator.  This will usually be TensorIteratorBase.  This also
550*da0073e9SAndroid Build Coastguard Worker    # changes the semantics of set_output to call the parent class.
551*da0073e9SAndroid Build Coastguard Worker    structured_inherits: str | None
552*da0073e9SAndroid Build Coastguard Worker
553*da0073e9SAndroid Build Coastguard Worker    # Structured kernels can declare elements as "precomputed". These elements
554*da0073e9SAndroid Build Coastguard Worker    # are returned by the meta function in one struct and passed to the impl
555*da0073e9SAndroid Build Coastguard Worker    # function in lieu of certain kernel arguments that these precomputed
556*da0073e9SAndroid Build Coastguard Worker    # elements supersede. Information about the names and types of these
557*da0073e9SAndroid Build Coastguard Worker    # precomputed elements and how they correspond to kernel arguments is stored
558*da0073e9SAndroid Build Coastguard Worker    # in this member, if applicable.
559*da0073e9SAndroid Build Coastguard Worker    precomputed: Precompute | None
560*da0073e9SAndroid Build Coastguard Worker
    # Argument names whose default values should be excluded from the C++ interface.
562*da0073e9SAndroid Build Coastguard Worker    # Intended for resolving overload ambiguities between signatures.
563*da0073e9SAndroid Build Coastguard Worker    cpp_no_default_args: set[str]
564*da0073e9SAndroid Build Coastguard Worker
565*da0073e9SAndroid Build Coastguard Worker    # Note [Abstract ATen methods]
566*da0073e9SAndroid Build Coastguard Worker    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
567*da0073e9SAndroid Build Coastguard Worker    # An abstract ATen method is one whose dispatch differs between
568*da0073e9SAndroid Build Coastguard Worker    # types.  These are implemented in derived types (with a
569*da0073e9SAndroid Build Coastguard Worker    # standard (throwing) definition in Type).  A concrete ATen
570*da0073e9SAndroid Build Coastguard Worker    # method is one which has the same dispatch for all types;
571*da0073e9SAndroid Build Coastguard Worker    # we just implement it in the base Type.  This is exposed
572*da0073e9SAndroid Build Coastguard Worker    # in Declarations.yaml via a field named 'abstract'.
573*da0073e9SAndroid Build Coastguard Worker    is_abstract: bool
574*da0073e9SAndroid Build Coastguard Worker
575*da0073e9SAndroid Build Coastguard Worker    # Whether or not the NativeFunction contains a backend-agnostic kernel
576*da0073e9SAndroid Build Coastguard Worker    has_composite_implicit_autograd_kernel: bool
577*da0073e9SAndroid Build Coastguard Worker    has_composite_implicit_autograd_nested_tensor_kernel: bool
578*da0073e9SAndroid Build Coastguard Worker    has_composite_explicit_autograd_kernel: bool
579*da0073e9SAndroid Build Coastguard Worker    has_composite_explicit_autograd_non_functional_kernel: bool
580*da0073e9SAndroid Build Coastguard Worker
    # Tags are used to describe semantic information about (groups of) operators
    # that isn't easily inferable directly from the operator's schema.
583*da0073e9SAndroid Build Coastguard Worker    tags: set[str]
584*da0073e9SAndroid Build Coastguard Worker
585*da0073e9SAndroid Build Coastguard Worker    # NB: The benefit of defining a dataclass is that we automatically get
586*da0073e9SAndroid Build Coastguard Worker    # a constructor defined for all the fields we specify.  No need
587*da0073e9SAndroid Build Coastguard Worker    # to explicitly write it out.
588*da0073e9SAndroid Build Coastguard Worker
    # We parse both the NativeFunction and the backend-specific information about it, which is stored in a corresponding BackendIndex.
590*da0073e9SAndroid Build Coastguard Worker    @staticmethod
591*da0073e9SAndroid Build Coastguard Worker    def from_yaml(
592*da0073e9SAndroid Build Coastguard Worker        ei: dict[str, object],
593*da0073e9SAndroid Build Coastguard Worker        loc: Location,
594*da0073e9SAndroid Build Coastguard Worker        valid_tags: set[str],
595*da0073e9SAndroid Build Coastguard Worker        ignore_keys: set[DispatchKey] | None = None,
596*da0073e9SAndroid Build Coastguard Worker    ) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
597*da0073e9SAndroid Build Coastguard Worker        """
598*da0073e9SAndroid Build Coastguard Worker        Parse a NativeFunction from a dictionary as directly parsed
599*da0073e9SAndroid Build Coastguard Worker        from native_functions.yaml
600*da0073e9SAndroid Build Coastguard Worker        """
601*da0073e9SAndroid Build Coastguard Worker        e = ei.copy()
602*da0073e9SAndroid Build Coastguard Worker
603*da0073e9SAndroid Build Coastguard Worker        funcs = e.pop("func")
604*da0073e9SAndroid Build Coastguard Worker        assert isinstance(funcs, str), f"not a str: {funcs}"
605*da0073e9SAndroid Build Coastguard Worker        # only support one level of namespace. E.g., aten::add
606*da0073e9SAndroid Build Coastguard Worker        namespace_helper = NamespaceHelper.from_namespaced_entity(
607*da0073e9SAndroid Build Coastguard Worker            namespaced_entity=funcs, max_level=1
608*da0073e9SAndroid Build Coastguard Worker        )
609*da0073e9SAndroid Build Coastguard Worker        namespace = namespace_helper.get_cpp_namespace(default="aten")
610*da0073e9SAndroid Build Coastguard Worker        func = FunctionSchema.parse(namespace_helper.entity_name)
611*da0073e9SAndroid Build Coastguard Worker
612*da0073e9SAndroid Build Coastguard Worker        cpp_no_default_args_list = e.pop("cpp_no_default_args", [])
613*da0073e9SAndroid Build Coastguard Worker        assert isinstance(cpp_no_default_args_list, list)
614*da0073e9SAndroid Build Coastguard Worker        cpp_no_default_args = set(cpp_no_default_args_list)
615*da0073e9SAndroid Build Coastguard Worker
616*da0073e9SAndroid Build Coastguard Worker        use_const_ref_for_mutable_tensors = e.pop(
617*da0073e9SAndroid Build Coastguard Worker            "use_const_ref_for_mutable_tensors", False
618*da0073e9SAndroid Build Coastguard Worker        )
619*da0073e9SAndroid Build Coastguard Worker        assert isinstance(use_const_ref_for_mutable_tensors, bool)
620*da0073e9SAndroid Build Coastguard Worker
621*da0073e9SAndroid Build Coastguard Worker        variants_s = e.pop("variants", "function")
622*da0073e9SAndroid Build Coastguard Worker        assert isinstance(variants_s, str)
623*da0073e9SAndroid Build Coastguard Worker        variants: set[Variant] = set()
624*da0073e9SAndroid Build Coastguard Worker        for v in variants_s.split(", "):
625*da0073e9SAndroid Build Coastguard Worker            if v == "function":
626*da0073e9SAndroid Build Coastguard Worker                variants.add(Variant.function)
627*da0073e9SAndroid Build Coastguard Worker            elif v == "method":
628*da0073e9SAndroid Build Coastguard Worker                variants.add(Variant.method)
629*da0073e9SAndroid Build Coastguard Worker            else:
630*da0073e9SAndroid Build Coastguard Worker                raise AssertionError(f"illegal variant {v}")
631*da0073e9SAndroid Build Coastguard Worker
632*da0073e9SAndroid Build Coastguard Worker        manual_kernel_registration = e.pop("manual_kernel_registration", False)
633*da0073e9SAndroid Build Coastguard Worker        assert isinstance(
634*da0073e9SAndroid Build Coastguard Worker            manual_kernel_registration, bool
635*da0073e9SAndroid Build Coastguard Worker        ), f"not a bool: {manual_kernel_registration}"
636*da0073e9SAndroid Build Coastguard Worker
637*da0073e9SAndroid Build Coastguard Worker        manual_cpp_binding = e.pop("manual_cpp_binding", False)
638*da0073e9SAndroid Build Coastguard Worker        assert isinstance(manual_cpp_binding, bool), f"not a bool: {manual_cpp_binding}"
639*da0073e9SAndroid Build Coastguard Worker
640*da0073e9SAndroid Build Coastguard Worker        device_guard = e.pop("device_guard", True)
641*da0073e9SAndroid Build Coastguard Worker        assert isinstance(device_guard, bool), f"not a bool: {device_guard}"
642*da0073e9SAndroid Build Coastguard Worker
643*da0073e9SAndroid Build Coastguard Worker        device_check_s = e.pop("device_check", None)
644*da0073e9SAndroid Build Coastguard Worker        assert device_check_s is None or isinstance(
645*da0073e9SAndroid Build Coastguard Worker            device_check_s, str
646*da0073e9SAndroid Build Coastguard Worker        ), f"not a str: {device_check_s}"
647*da0073e9SAndroid Build Coastguard Worker        assert (
648*da0073e9SAndroid Build Coastguard Worker            device_check_s is None or device_check_s in DeviceCheckType.__members__
649*da0073e9SAndroid Build Coastguard Worker        ), f"illegal device_check: {device_check_s}"
650*da0073e9SAndroid Build Coastguard Worker        device_check: DeviceCheckType
651*da0073e9SAndroid Build Coastguard Worker        if device_check_s is None:
652*da0073e9SAndroid Build Coastguard Worker            device_check = DeviceCheckType.ExactSame
653*da0073e9SAndroid Build Coastguard Worker        else:
654*da0073e9SAndroid Build Coastguard Worker            device_check = DeviceCheckType[device_check_s]
655*da0073e9SAndroid Build Coastguard Worker
656*da0073e9SAndroid Build Coastguard Worker        structured = e.pop("structured", False)
657*da0073e9SAndroid Build Coastguard Worker        assert isinstance(structured, bool), f"not a bool: {structured}"
658*da0073e9SAndroid Build Coastguard Worker
659*da0073e9SAndroid Build Coastguard Worker        structured_delegate_s = e.pop("structured_delegate", None)
660*da0073e9SAndroid Build Coastguard Worker        assert structured_delegate_s is None or isinstance(
661*da0073e9SAndroid Build Coastguard Worker            structured_delegate_s, str
662*da0073e9SAndroid Build Coastguard Worker        ), f"not a str: {structured_delegate_s}"
663*da0073e9SAndroid Build Coastguard Worker        assert structured_delegate_s is None or "::" not in structured_delegate_s, (
664*da0073e9SAndroid Build Coastguard Worker            "namespace is not supported in structured delegate,"
665*da0073e9SAndroid Build Coastguard Worker            " using the same namespace as the native function"
666*da0073e9SAndroid Build Coastguard Worker        )
667*da0073e9SAndroid Build Coastguard Worker        structured_delegate: OperatorName | None = None
668*da0073e9SAndroid Build Coastguard Worker        if structured_delegate_s is not None:
669*da0073e9SAndroid Build Coastguard Worker            structured_delegate = OperatorName.parse(structured_delegate_s)
670*da0073e9SAndroid Build Coastguard Worker
671*da0073e9SAndroid Build Coastguard Worker        structured_inherits = e.pop("structured_inherits", None)
672*da0073e9SAndroid Build Coastguard Worker        assert structured_inherits is None or isinstance(
673*da0073e9SAndroid Build Coastguard Worker            structured_inherits, str
674*da0073e9SAndroid Build Coastguard Worker        ), f"not a str: {structured_inherits}"
675*da0073e9SAndroid Build Coastguard Worker        assert structured_inherits is None or "::" not in structured_inherits, (
676*da0073e9SAndroid Build Coastguard Worker            "namespace is not supported in structured inherits,"
677*da0073e9SAndroid Build Coastguard Worker            " using the same namespace as the native function"
678*da0073e9SAndroid Build Coastguard Worker        )
679*da0073e9SAndroid Build Coastguard Worker
680*da0073e9SAndroid Build Coastguard Worker        python_module = e.pop("python_module", None)
681*da0073e9SAndroid Build Coastguard Worker        assert python_module is None or isinstance(
682*da0073e9SAndroid Build Coastguard Worker            python_module, str
683*da0073e9SAndroid Build Coastguard Worker        ), f"not a str: {python_module}"
684*da0073e9SAndroid Build Coastguard Worker        assert (
685*da0073e9SAndroid Build Coastguard Worker            python_module is None or Variant.method not in variants
686*da0073e9SAndroid Build Coastguard Worker        ), "functions in modules cannot be methods"
687*da0073e9SAndroid Build Coastguard Worker
688*da0073e9SAndroid Build Coastguard Worker        category_override = e.pop("category_override", None)
689*da0073e9SAndroid Build Coastguard Worker        assert category_override is None or isinstance(
690*da0073e9SAndroid Build Coastguard Worker            category_override, str
691*da0073e9SAndroid Build Coastguard Worker        ), f"not a str: {category_override}"
692*da0073e9SAndroid Build Coastguard Worker
693*da0073e9SAndroid Build Coastguard Worker        precomputed_dict = e.pop("precomputed", None)
694*da0073e9SAndroid Build Coastguard Worker        assert precomputed_dict is None or structured is True
695*da0073e9SAndroid Build Coastguard Worker        precomputed = Precompute.parse(precomputed_dict) if precomputed_dict else None
696*da0073e9SAndroid Build Coastguard Worker
697*da0073e9SAndroid Build Coastguard Worker        tags_inp = e.pop("tags", [])
698*da0073e9SAndroid Build Coastguard Worker        if isinstance(tags_inp, str):
699*da0073e9SAndroid Build Coastguard Worker            tags_inp = [tags_inp]
700*da0073e9SAndroid Build Coastguard Worker        assert isinstance(tags_inp, list)
701*da0073e9SAndroid Build Coastguard Worker
702*da0073e9SAndroid Build Coastguard Worker        # All aten ops generated by torchgen receive the pt2_compliant tag.
703*da0073e9SAndroid Build Coastguard Worker        if namespace == "aten" and "pt2_compliant_tag" in valid_tags:
704*da0073e9SAndroid Build Coastguard Worker            tags_inp.append("pt2_compliant_tag")
705*da0073e9SAndroid Build Coastguard Worker
706*da0073e9SAndroid Build Coastguard Worker        tags: set[str] = set()
707*da0073e9SAndroid Build Coastguard Worker        for t in tags_inp:
708*da0073e9SAndroid Build Coastguard Worker            assert len(valid_tags) > 0
709*da0073e9SAndroid Build Coastguard Worker            # TODO: verify that the tag is valid and has an entry in tags.yaml
710*da0073e9SAndroid Build Coastguard Worker            if t in valid_tags:
711*da0073e9SAndroid Build Coastguard Worker                tags.add(t)
712*da0073e9SAndroid Build Coastguard Worker            else:
713*da0073e9SAndroid Build Coastguard Worker                raise AssertionError(f"illegal tag {t}")
714*da0073e9SAndroid Build Coastguard Worker
715*da0073e9SAndroid Build Coastguard Worker        from torchgen.api import cpp
716*da0073e9SAndroid Build Coastguard Worker
717*da0073e9SAndroid Build Coastguard Worker        raw_dispatch = e.pop("dispatch", None)
718*da0073e9SAndroid Build Coastguard Worker        assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
719*da0073e9SAndroid Build Coastguard Worker        dispatch: dict[DispatchKey, BackendMetadata] = {}
720*da0073e9SAndroid Build Coastguard Worker        num_dispatch_keys: int = 0
721*da0073e9SAndroid Build Coastguard Worker        if raw_dispatch is not None:
722*da0073e9SAndroid Build Coastguard Worker            assert not manual_kernel_registration, (
723*da0073e9SAndroid Build Coastguard Worker                "cannot specify both manual_kernel_registration and dispatch; with "
724*da0073e9SAndroid Build Coastguard Worker                "manual registration, dispatch has no effect!"
725*da0073e9SAndroid Build Coastguard Worker            )
726*da0073e9SAndroid Build Coastguard Worker            redundant_composite_implicit_autograd = False
727*da0073e9SAndroid Build Coastguard Worker            for ks, v in raw_dispatch.items():
728*da0073e9SAndroid Build Coastguard Worker                if ks == "__line__":
729*da0073e9SAndroid Build Coastguard Worker                    continue  # not worth tracking line numbers for dispatch entries
730*da0073e9SAndroid Build Coastguard Worker                assert isinstance(
731*da0073e9SAndroid Build Coastguard Worker                    ks, str
732*da0073e9SAndroid Build Coastguard Worker                ), f"illegal dispatch key '{ks}' in {raw_dispatch}"
733*da0073e9SAndroid Build Coastguard Worker                assert isinstance(
734*da0073e9SAndroid Build Coastguard Worker                    v, str
735*da0073e9SAndroid Build Coastguard Worker                ), f"illegal dispatch value '{v}' in {raw_dispatch}"
736*da0073e9SAndroid Build Coastguard Worker                for k in ks.split(","):
737*da0073e9SAndroid Build Coastguard Worker                    dispatch_key = DispatchKey.parse(k.strip())
738*da0073e9SAndroid Build Coastguard Worker                    num_dispatch_keys += 1
739*da0073e9SAndroid Build Coastguard Worker
740*da0073e9SAndroid Build Coastguard Worker                    if ignore_keys and dispatch_key in ignore_keys:
741*da0073e9SAndroid Build Coastguard Worker                        continue
742*da0073e9SAndroid Build Coastguard Worker                    assert dispatch_key in dispatch_keys, (
743*da0073e9SAndroid Build Coastguard Worker                        f"Dispatch key {dispatch_key} of kernel {v} "
744*da0073e9SAndroid Build Coastguard Worker                        "is not a supported dispatch key."
745*da0073e9SAndroid Build Coastguard Worker                    )
746*da0073e9SAndroid Build Coastguard Worker                    # We only allow at most 3 levels of namespace for kernels.
747*da0073e9SAndroid Build Coastguard Worker                    # We will append "native" to a custom kernel namespace.
748*da0073e9SAndroid Build Coastguard Worker                    namespace_helper = NamespaceHelper.from_namespaced_entity(
749*da0073e9SAndroid Build Coastguard Worker                        v, max_level=3
750*da0073e9SAndroid Build Coastguard Worker                    )
751*da0073e9SAndroid Build Coastguard Worker                    kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
752*da0073e9SAndroid Build Coastguard Worker                    # Why is 'structured' included? External backends (e.g.
753*da0073e9SAndroid Build Coastguard Worker                    # XLA) opt into which ops are structured independently
754*da0073e9SAndroid Build Coastguard Worker                    # of which in-tree ops are structured
755*da0073e9SAndroid Build Coastguard Worker                    dispatch[dispatch_key] = BackendMetadata(
756*da0073e9SAndroid Build Coastguard Worker                        kernel=namespace_helper.entity_name,
757*da0073e9SAndroid Build Coastguard Worker                        structured=structured
758*da0073e9SAndroid Build Coastguard Worker                        and is_structured_dispatch_key(dispatch_key),
759*da0073e9SAndroid Build Coastguard Worker                        cpp_namespace=(kernel_namespace + "::native"),
760*da0073e9SAndroid Build Coastguard Worker                    )
761*da0073e9SAndroid Build Coastguard Worker                    if (
762*da0073e9SAndroid Build Coastguard Worker                        dispatch_key is DispatchKey.CompositeImplicitAutograd
763*da0073e9SAndroid Build Coastguard Worker                        and v == cpp.name(func)
764*da0073e9SAndroid Build Coastguard Worker                    ):
765*da0073e9SAndroid Build Coastguard Worker                        redundant_composite_implicit_autograd = True
766*da0073e9SAndroid Build Coastguard Worker
            # We count every dispatch key, including ones skipped via ignore_keys
            # (num_dispatch_keys is incremented before that check), so that a
            # dispatch table whose backend keys are all ignored, leaving only
            # CompositeImplicitAutograd, is not treated as redundant.
770*da0073e9SAndroid Build Coastguard Worker            assert not (
771*da0073e9SAndroid Build Coastguard Worker                num_dispatch_keys == 1 and redundant_composite_implicit_autograd
772*da0073e9SAndroid Build Coastguard Worker            ), (
773*da0073e9SAndroid Build Coastguard Worker                "unnecessary dispatch table for this function; just delete the dispatch "
774*da0073e9SAndroid Build Coastguard Worker                "key entirely"
775*da0073e9SAndroid Build Coastguard Worker            )
776*da0073e9SAndroid Build Coastguard Worker            # if a function is a structured delegate, deleting the dispatch
777*da0073e9SAndroid Build Coastguard Worker            # table is NOT semantics preserving
778*da0073e9SAndroid Build Coastguard Worker            assert (
779*da0073e9SAndroid Build Coastguard Worker                structured_delegate
780*da0073e9SAndroid Build Coastguard Worker                or dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
781*da0073e9SAndroid Build Coastguard Worker                or dispatch[DispatchKey.CompositeImplicitAutograd].supports_symint()
782*da0073e9SAndroid Build Coastguard Worker                or num_dispatch_keys != 1
783*da0073e9SAndroid Build Coastguard Worker            ), (
784*da0073e9SAndroid Build Coastguard Worker                f"unexpected name for singleton CompositeImplicitAutograd dispatch entry: expected {cpp.name(func)} "
785*da0073e9SAndroid Build Coastguard Worker                f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}.  Rename your implementation to the expected "
786*da0073e9SAndroid Build Coastguard Worker                "name, then delete the dispatch table"
787*da0073e9SAndroid Build Coastguard Worker            )
788*da0073e9SAndroid Build Coastguard Worker        elif not structured and structured_delegate is None:
789*da0073e9SAndroid Build Coastguard Worker            name = str(func.name.name)
790*da0073e9SAndroid Build Coastguard Worker            assert not (
791*da0073e9SAndroid Build Coastguard Worker                name.startswith("new_")
792*da0073e9SAndroid Build Coastguard Worker                or name.endswith("_like")
793*da0073e9SAndroid Build Coastguard Worker                # TODO: maybe it's better to test the return
794*da0073e9SAndroid Build Coastguard Worker                or (
795*da0073e9SAndroid Build Coastguard Worker                    func.arguments.tensor_options
796*da0073e9SAndroid Build Coastguard Worker                    and not func.arguments.has_tensor_arg()
797*da0073e9SAndroid Build Coastguard Worker                )
798*da0073e9SAndroid Build Coastguard Worker            ), (
799*da0073e9SAndroid Build Coastguard Worker                f"expected {name} to have a CompositeExplicitAutograd "
800*da0073e9SAndroid Build Coastguard Worker                "dispatch entry, but there was no dispatch table.  Factory functions "
801*da0073e9SAndroid Build Coastguard Worker                "should not have implicit dispatch as they should not be decomposed "
802*da0073e9SAndroid Build Coastguard Worker                "for __torch_dispatch__"
803*da0073e9SAndroid Build Coastguard Worker            )
804*da0073e9SAndroid Build Coastguard Worker            dispatch[DispatchKey.CompositeImplicitAutograd] = BackendMetadata(
805*da0073e9SAndroid Build Coastguard Worker                cpp.name(func), structured=False, cpp_namespace=DEFAULT_KERNEL_NAMESPACE
806*da0073e9SAndroid Build Coastguard Worker            )
807*da0073e9SAndroid Build Coastguard Worker
808*da0073e9SAndroid Build Coastguard Worker        composites_in_dispatch = [
809*da0073e9SAndroid Build Coastguard Worker            d
810*da0073e9SAndroid Build Coastguard Worker            for d in dispatch
811*da0073e9SAndroid Build Coastguard Worker            if d == DispatchKey.CompositeExplicitAutograd
812*da0073e9SAndroid Build Coastguard Worker            or d == DispatchKey.CompositeExplicitAutogradNonFunctional
813*da0073e9SAndroid Build Coastguard Worker            or d == DispatchKey.CompositeImplicitAutograd
814*da0073e9SAndroid Build Coastguard Worker            or d == DispatchKey.CompositeImplicitAutogradNestedTensor
815*da0073e9SAndroid Build Coastguard Worker        ]
816*da0073e9SAndroid Build Coastguard Worker
817*da0073e9SAndroid Build Coastguard Worker        assert len(composites_in_dispatch) <= 1 or (
818*da0073e9SAndroid Build Coastguard Worker            len(composites_in_dispatch) == 2
819*da0073e9SAndroid Build Coastguard Worker            and (
820*da0073e9SAndroid Build Coastguard Worker                DispatchKey.CompositeExplicitAutogradNonFunctional
821*da0073e9SAndroid Build Coastguard Worker                not in composites_in_dispatch
822*da0073e9SAndroid Build Coastguard Worker            )
823*da0073e9SAndroid Build Coastguard Worker            and (
824*da0073e9SAndroid Build Coastguard Worker                DispatchKey.CompositeImplicitAutogradNestedTensor
825*da0073e9SAndroid Build Coastguard Worker                in composites_in_dispatch
826*da0073e9SAndroid Build Coastguard Worker            )
827*da0073e9SAndroid Build Coastguard Worker        ), (
828*da0073e9SAndroid Build Coastguard Worker            "cannot specify more than one of CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional, "
829*da0073e9SAndroid Build Coastguard Worker            "or CompositeImplicitAutograd on a single kernel; each "
830*da0073e9SAndroid Build Coastguard Worker            "strictly subsumes the other.  If you wanted to provide an explicit autograd "
831*da0073e9SAndroid Build Coastguard Worker            "implementation, specify CompositeExplicitAutograd; otherwise specify CompositeImplicitAutograd only"
832*da0073e9SAndroid Build Coastguard Worker        )
833*da0073e9SAndroid Build Coastguard Worker
834*da0073e9SAndroid Build Coastguard Worker        autogen_str = e.pop("autogen", "")
835*da0073e9SAndroid Build Coastguard Worker        assert isinstance(autogen_str, str)
836*da0073e9SAndroid Build Coastguard Worker        autogen = (
837*da0073e9SAndroid Build Coastguard Worker            []
838*da0073e9SAndroid Build Coastguard Worker            if autogen_str == ""
839*da0073e9SAndroid Build Coastguard Worker            else [OperatorName.parse(x) for x in autogen_str.split(", ")]
840*da0073e9SAndroid Build Coastguard Worker        )
841*da0073e9SAndroid Build Coastguard Worker
842*da0073e9SAndroid Build Coastguard Worker        raw_ufunc_inner_loop = e.pop("ufunc_inner_loop", {})
843*da0073e9SAndroid Build Coastguard Worker        ufunc_inner_loop = {}
844*da0073e9SAndroid Build Coastguard Worker        if isinstance(raw_ufunc_inner_loop, str):
845*da0073e9SAndroid Build Coastguard Worker            ufunc_inner_loop[UfuncKey.Generic] = UfuncInnerLoop.parse(
846*da0073e9SAndroid Build Coastguard Worker                raw_ufunc_inner_loop, UfuncKey.Generic
847*da0073e9SAndroid Build Coastguard Worker            )
848*da0073e9SAndroid Build Coastguard Worker        elif isinstance(raw_ufunc_inner_loop, dict):
849*da0073e9SAndroid Build Coastguard Worker            for k, vo in raw_ufunc_inner_loop.items():
850*da0073e9SAndroid Build Coastguard Worker                if k == "__line__":
851*da0073e9SAndroid Build Coastguard Worker                    continue
852*da0073e9SAndroid Build Coastguard Worker                assert isinstance(k, str), f"ufunc_inner_loop key is not a str: {k}"
853*da0073e9SAndroid Build Coastguard Worker                assert isinstance(vo, str), f"ufunc_inner_loop value is not a str: {v}"
854*da0073e9SAndroid Build Coastguard Worker                ufunc_key = UfuncKey.parse(k)
855*da0073e9SAndroid Build Coastguard Worker                ufunc_inner_loop[ufunc_key] = UfuncInnerLoop.parse(vo, ufunc_key)
856*da0073e9SAndroid Build Coastguard Worker        else:
857*da0073e9SAndroid Build Coastguard Worker            raise AssertionError(
858*da0073e9SAndroid Build Coastguard Worker                f"ufunc_inner_loop not str or dict: {raw_ufunc_inner_loop}"
859*da0073e9SAndroid Build Coastguard Worker            )
860*da0073e9SAndroid Build Coastguard Worker        # Program the BackendIndex for the implicit dispatch entry from ufunc
861*da0073e9SAndroid Build Coastguard Worker        if ufunc_inner_loop:
862*da0073e9SAndroid Build Coastguard Worker            assert structured, "ufunc must be structured"
863*da0073e9SAndroid Build Coastguard Worker
864*da0073e9SAndroid Build Coastguard Worker            # Delay import ufunc here to avoid circular import issue
865*da0073e9SAndroid Build Coastguard Worker            # See: https://github.com/pytorch/pytorch/issues/81294
866*da0073e9SAndroid Build Coastguard Worker            import torchgen.api.ufunc as ufunc
867*da0073e9SAndroid Build Coastguard Worker
868*da0073e9SAndroid Build Coastguard Worker            for dispatch_key in UFUNC_DISPATCH_KEYS:
869*da0073e9SAndroid Build Coastguard Worker                assert (
870*da0073e9SAndroid Build Coastguard Worker                    dispatch_key not in dispatch
871*da0073e9SAndroid Build Coastguard Worker                ), f"ufunc should not have explicit dispatch entry for {dispatch_key}"
872*da0073e9SAndroid Build Coastguard Worker                dispatch[dispatch_key] = BackendMetadata(
873*da0073e9SAndroid Build Coastguard Worker                    kernel=ufunc.schema_kernel_name(func, dispatch_key),
874*da0073e9SAndroid Build Coastguard Worker                    structured=True,
875*da0073e9SAndroid Build Coastguard Worker                    cpp_namespace=DEFAULT_KERNEL_NAMESPACE,
876*da0073e9SAndroid Build Coastguard Worker                )
877*da0073e9SAndroid Build Coastguard Worker
878*da0073e9SAndroid Build Coastguard Worker        if structured_delegate:
879*da0073e9SAndroid Build Coastguard Worker            # Structured functions MUST have a dispatch table
880*da0073e9SAndroid Build Coastguard Worker            is_abstract = True
881*da0073e9SAndroid Build Coastguard Worker        else:
882*da0073e9SAndroid Build Coastguard Worker            is_abstract = (
883*da0073e9SAndroid Build Coastguard Worker                dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
884*da0073e9SAndroid Build Coastguard Worker                and dispatch.keys()
885*da0073e9SAndroid Build Coastguard Worker                != {DispatchKey.CompositeImplicitAutogradNestedTensor}
886*da0073e9SAndroid Build Coastguard Worker                and dispatch.keys()
887*da0073e9SAndroid Build Coastguard Worker                != {
888*da0073e9SAndroid Build Coastguard Worker                    DispatchKey.CompositeImplicitAutograd,
889*da0073e9SAndroid Build Coastguard Worker                    DispatchKey.CompositeImplicitAutogradNestedTensor,
890*da0073e9SAndroid Build Coastguard Worker                }
891*da0073e9SAndroid Build Coastguard Worker            )
892*da0073e9SAndroid Build Coastguard Worker
893*da0073e9SAndroid Build Coastguard Worker        has_composite_implicit_autograd_kernel = (
894*da0073e9SAndroid Build Coastguard Worker            DispatchKey.CompositeImplicitAutograd in dispatch
895*da0073e9SAndroid Build Coastguard Worker        )
896*da0073e9SAndroid Build Coastguard Worker        has_composite_implicit_autograd_nested_tensor_kernel = (
897*da0073e9SAndroid Build Coastguard Worker            DispatchKey.CompositeImplicitAutogradNestedTensor in dispatch
898*da0073e9SAndroid Build Coastguard Worker        )
899*da0073e9SAndroid Build Coastguard Worker        has_composite_explicit_autograd_kernel = (
900*da0073e9SAndroid Build Coastguard Worker            DispatchKey.CompositeExplicitAutograd in dispatch
901*da0073e9SAndroid Build Coastguard Worker        )
902*da0073e9SAndroid Build Coastguard Worker        has_composite_explicit_autograd_non_functional_kernel = (
903*da0073e9SAndroid Build Coastguard Worker            DispatchKey.CompositeExplicitAutogradNonFunctional in dispatch
904*da0073e9SAndroid Build Coastguard Worker        )
905*da0073e9SAndroid Build Coastguard Worker
906*da0073e9SAndroid Build Coastguard Worker        # We aren't going to store dispatch metadata inline in NativeFunctions;
907*da0073e9SAndroid Build Coastguard Worker        # instead it is separately indexed by backend (so other backends can
908*da0073e9SAndroid Build Coastguard Worker        # add more dispatch entries after the fact).  Reindex the individual
909*da0073e9SAndroid Build Coastguard Worker        # metadata by OperatorName!
910*da0073e9SAndroid Build Coastguard Worker        backend_metadata = {k: {func.name: v} for k, v in dispatch.items()}
911*da0073e9SAndroid Build Coastguard Worker
912*da0073e9SAndroid Build Coastguard Worker        # don't care if it exists or not; make it easier to use this function
913*da0073e9SAndroid Build Coastguard Worker        # with other yaml parsers that aren't setting __line__ in the dict
914*da0073e9SAndroid Build Coastguard Worker        e.pop("__line__", None)
915*da0073e9SAndroid Build Coastguard Worker        assert not e, f"leftover entries: {e}"
916*da0073e9SAndroid Build Coastguard Worker
917*da0073e9SAndroid Build Coastguard Worker        # Asserts that we can't do in post_init, because they rely on backend-specific info
918*da0073e9SAndroid Build Coastguard Worker        if structured_delegate is not None:
919*da0073e9SAndroid Build Coastguard Worker            for key in STRUCTURED_DISPATCH_KEYS:
920*da0073e9SAndroid Build Coastguard Worker                assert key not in dispatch, (
921*da0073e9SAndroid Build Coastguard Worker                    f"if structured_delegate, then must not have {key} in dispatch dictionary "
922*da0073e9SAndroid Build Coastguard Worker                    "(it is delegated!)"
923*da0073e9SAndroid Build Coastguard Worker                )
924*da0073e9SAndroid Build Coastguard Worker
925*da0073e9SAndroid Build Coastguard Worker        return (
926*da0073e9SAndroid Build Coastguard Worker            NativeFunction(
927*da0073e9SAndroid Build Coastguard Worker                func=func,
928*da0073e9SAndroid Build Coastguard Worker                use_const_ref_for_mutable_tensors=use_const_ref_for_mutable_tensors,
929*da0073e9SAndroid Build Coastguard Worker                variants=variants,
930*da0073e9SAndroid Build Coastguard Worker                structured=structured,
931*da0073e9SAndroid Build Coastguard Worker                structured_delegate=structured_delegate,
932*da0073e9SAndroid Build Coastguard Worker                structured_inherits=structured_inherits,
933*da0073e9SAndroid Build Coastguard Worker                precomputed=precomputed,
934*da0073e9SAndroid Build Coastguard Worker                autogen=autogen,
935*da0073e9SAndroid Build Coastguard Worker                ufunc_inner_loop=ufunc_inner_loop,
936*da0073e9SAndroid Build Coastguard Worker                manual_kernel_registration=manual_kernel_registration,
937*da0073e9SAndroid Build Coastguard Worker                manual_cpp_binding=manual_cpp_binding,
938*da0073e9SAndroid Build Coastguard Worker                python_module=python_module,
939*da0073e9SAndroid Build Coastguard Worker                category_override=category_override,
940*da0073e9SAndroid Build Coastguard Worker                device_guard=device_guard,
941*da0073e9SAndroid Build Coastguard Worker                device_check=device_check,
942*da0073e9SAndroid Build Coastguard Worker                loc=loc,
943*da0073e9SAndroid Build Coastguard Worker                cpp_no_default_args=cpp_no_default_args,
944*da0073e9SAndroid Build Coastguard Worker                is_abstract=is_abstract,
945*da0073e9SAndroid Build Coastguard Worker                has_composite_implicit_autograd_kernel=has_composite_implicit_autograd_kernel,
946*da0073e9SAndroid Build Coastguard Worker                has_composite_implicit_autograd_nested_tensor_kernel=has_composite_implicit_autograd_nested_tensor_kernel,
947*da0073e9SAndroid Build Coastguard Worker                has_composite_explicit_autograd_kernel=has_composite_explicit_autograd_kernel,
948*da0073e9SAndroid Build Coastguard Worker                has_composite_explicit_autograd_non_functional_kernel=has_composite_explicit_autograd_non_functional_kernel,
949*da0073e9SAndroid Build Coastguard Worker                tags=tags,
950*da0073e9SAndroid Build Coastguard Worker                namespace=namespace,
951*da0073e9SAndroid Build Coastguard Worker            ),
952*da0073e9SAndroid Build Coastguard Worker            backend_metadata,
953*da0073e9SAndroid Build Coastguard Worker        )
954*da0073e9SAndroid Build Coastguard Worker
955*da0073e9SAndroid Build Coastguard Worker    def validate_unstructured(self) -> None:
956*da0073e9SAndroid Build Coastguard Worker        # TODO: probably better to accumulate these errors and report them all
957*da0073e9SAndroid Build Coastguard Worker        # at once
958*da0073e9SAndroid Build Coastguard Worker        assert not self.structured, (
959*da0073e9SAndroid Build Coastguard Worker            "This function is structured, but there was "
960*da0073e9SAndroid Build Coastguard Worker            "no valid functional variant of it."
961*da0073e9SAndroid Build Coastguard Worker        )
962*da0073e9SAndroid Build Coastguard Worker        assert self.structured_delegate, (
963*da0073e9SAndroid Build Coastguard Worker            "This function delegates to another structured out function, "
964*da0073e9SAndroid Build Coastguard Worker            "but no valid function was found (the delegate may not exist, or it has the wrong type)"
965*da0073e9SAndroid Build Coastguard Worker        )
966*da0073e9SAndroid Build Coastguard Worker
967*da0073e9SAndroid Build Coastguard Worker    # __post_init__ functions in dataclasses can be used to do extra
968*da0073e9SAndroid Build Coastguard Worker    # validation after construction.
969*da0073e9SAndroid Build Coastguard Worker    #
970*da0073e9SAndroid Build Coastguard Worker    # Notice that we don't do any type validation here.  In fact, we
971*da0073e9SAndroid Build Coastguard Worker    # rely exclusively on mypy to check if you've done types correctly!
972*da0073e9SAndroid Build Coastguard Worker    # Validation is for nontrivial invariants that cannot be (conveniently)
973*da0073e9SAndroid Build Coastguard Worker    # encoded in the type system.
974*da0073e9SAndroid Build Coastguard Worker    def __post_init__(self) -> None:
975*da0073e9SAndroid Build Coastguard Worker        if self.func.arguments.out:
976*da0073e9SAndroid Build Coastguard Worker            assert self.variants == {Variant.function}, (
977*da0073e9SAndroid Build Coastguard Worker                "Native functions with out arguments MUST "
978*da0073e9SAndroid Build Coastguard Worker                "be declared with only function variant; e.g., variants: function; "
979*da0073e9SAndroid Build Coastguard Worker                "otherwise you will tickle a Python argument binding bug "
980*da0073e9SAndroid Build Coastguard Worker                "(which usually manifests itself as the result variable being undefined.)"
981*da0073e9SAndroid Build Coastguard Worker            )
982*da0073e9SAndroid Build Coastguard Worker        if self.structured:
983*da0073e9SAndroid Build Coastguard Worker            assert self.func.kind() == SchemaKind.out, (
984*da0073e9SAndroid Build Coastguard Worker                "Put structured field on the out= "
985*da0073e9SAndroid Build Coastguard Worker                "variant of a function; did you mean structured_delegate?"
986*da0073e9SAndroid Build Coastguard Worker            )
987*da0073e9SAndroid Build Coastguard Worker            assert (
988*da0073e9SAndroid Build Coastguard Worker                self.device_guard
989*da0073e9SAndroid Build Coastguard Worker            ), "device_guard: False is not respected by structured kernels"
990*da0073e9SAndroid Build Coastguard Worker        if self.structured_delegate:
991*da0073e9SAndroid Build Coastguard Worker            assert self.func.kind() != SchemaKind.out, (
992*da0073e9SAndroid Build Coastguard Worker                "structured_delegate field not allowed "
993*da0073e9SAndroid Build Coastguard Worker                "on out= functions; did you mean structured?"
994*da0073e9SAndroid Build Coastguard Worker            )
995*da0073e9SAndroid Build Coastguard Worker            assert (
996*da0073e9SAndroid Build Coastguard Worker                self.device_guard
997*da0073e9SAndroid Build Coastguard Worker            ), "device_guard: False is not respected by structured kernels"
998*da0073e9SAndroid Build Coastguard Worker        # Technically, with the asserts above, this assert is impossible to
999*da0073e9SAndroid Build Coastguard Worker        # happen
1000*da0073e9SAndroid Build Coastguard Worker        assert not (
1001*da0073e9SAndroid Build Coastguard Worker            self.structured and self.structured_delegate
1002*da0073e9SAndroid Build Coastguard Worker        ), "Cannot have both structured and structured_delegate on function"
1003*da0073e9SAndroid Build Coastguard Worker        defaulted_arguments = {
1004*da0073e9SAndroid Build Coastguard Worker            a.name for a in self.func.schema_order_arguments() if a.default is not None
1005*da0073e9SAndroid Build Coastguard Worker        }
1006*da0073e9SAndroid Build Coastguard Worker        invalid_args = set.difference(self.cpp_no_default_args, defaulted_arguments)
1007*da0073e9SAndroid Build Coastguard Worker        assert len(invalid_args) == 0, f"Invalid cpp_no_default_args: {invalid_args}"
1008*da0073e9SAndroid Build Coastguard Worker        if self.structured_inherits is not None:
1009*da0073e9SAndroid Build Coastguard Worker            assert (
1010*da0073e9SAndroid Build Coastguard Worker                self.structured
1011*da0073e9SAndroid Build Coastguard Worker            ), "structured_inherits must also imply structured: True"
1012*da0073e9SAndroid Build Coastguard Worker        if str(self.func.name).startswith("_foreach"):
1013*da0073e9SAndroid Build Coastguard Worker            assert self.device_check == DeviceCheckType.NoCheck, (
1014*da0073e9SAndroid Build Coastguard Worker                "foreach kernels fall back to slow path when tensor are on different devices, "
1015*da0073e9SAndroid Build Coastguard Worker                "device_check not allowed to be enabled"
1016*da0073e9SAndroid Build Coastguard Worker            )
1017*da0073e9SAndroid Build Coastguard Worker
1018*da0073e9SAndroid Build Coastguard Worker        # NB: if your function accidentally has rand/dropout/... in its name
1019*da0073e9SAndroid Build Coastguard Worker        # but is not actually random, feel free to amend this to special case
1020*da0073e9SAndroid Build Coastguard Worker        if (
1021*da0073e9SAndroid Build Coastguard Worker            "rand" in str(self.func.name)
1022*da0073e9SAndroid Build Coastguard Worker            or (
1023*da0073e9SAndroid Build Coastguard Worker                (
1024*da0073e9SAndroid Build Coastguard Worker                    "dropout" in str(self.func.name)
1025*da0073e9SAndroid Build Coastguard Worker                    or any(
1026*da0073e9SAndroid Build Coastguard Worker                        "dropout" in arg.name for arg in self.func.arguments.flat_all
1027*da0073e9SAndroid Build Coastguard Worker                    )
1028*da0073e9SAndroid Build Coastguard Worker                )
1029*da0073e9SAndroid Build Coastguard Worker                # Backwards of dropout is typically deterministic
1030*da0073e9SAndroid Build Coastguard Worker                and "backward" not in str(self.func.name)
1031*da0073e9SAndroid Build Coastguard Worker                and str(self.func.name.name) not in ["_cudnn_init_dropout_state"]
1032*da0073e9SAndroid Build Coastguard Worker            )
1033*da0073e9SAndroid Build Coastguard Worker            or self.func.arguments.has_generator_arg()
1034*da0073e9SAndroid Build Coastguard Worker        ):
1035*da0073e9SAndroid Build Coastguard Worker            assert "nondeterministic_seeded" in self.tags, str(self.func.name)
1036*da0073e9SAndroid Build Coastguard Worker
1037*da0073e9SAndroid Build Coastguard Worker    @property
1038*da0073e9SAndroid Build Coastguard Worker    def has_composite_kernel(self) -> bool:
1039*da0073e9SAndroid Build Coastguard Worker        return (
1040*da0073e9SAndroid Build Coastguard Worker            self.has_composite_implicit_autograd_kernel
1041*da0073e9SAndroid Build Coastguard Worker            or self.has_composite_explicit_autograd_kernel
1042*da0073e9SAndroid Build Coastguard Worker            or self.has_composite_explicit_autograd_non_functional_kernel
1043*da0073e9SAndroid Build Coastguard Worker        ) or (
1044*da0073e9SAndroid Build Coastguard Worker            self.has_composite_implicit_autograd_kernel
1045*da0073e9SAndroid Build Coastguard Worker            and self.has_composite_implicit_autograd_nested_tensor_kernel
1046*da0073e9SAndroid Build Coastguard Worker        )
1047*da0073e9SAndroid Build Coastguard Worker
1048*da0073e9SAndroid Build Coastguard Worker    @property
1049*da0073e9SAndroid Build Coastguard Worker    def is_view_op(self) -> bool:
1050*da0073e9SAndroid Build Coastguard Worker        rets = self.func.returns
1051*da0073e9SAndroid Build Coastguard Worker        is_non_mutating_view = len(rets) > 0 and any(
1052*da0073e9SAndroid Build Coastguard Worker            r.annotation is not None and not r.annotation.is_write for r in rets
1053*da0073e9SAndroid Build Coastguard Worker        )
1054*da0073e9SAndroid Build Coastguard Worker        # See Note [resize_ in Functionalization] for more dtails
1055*da0073e9SAndroid Build Coastguard Worker        is_inplace_view = (
1056*da0073e9SAndroid Build Coastguard Worker            "inplace_view" in self.tags
1057*da0073e9SAndroid Build Coastguard Worker            and str(self.func.name) != "resize_"
1058*da0073e9SAndroid Build Coastguard Worker            and str(self.func.name) != "resize_as_"
1059*da0073e9SAndroid Build Coastguard Worker        )
1060*da0073e9SAndroid Build Coastguard Worker        is_wildcard_view = any(
1061*da0073e9SAndroid Build Coastguard Worker            inp.annotation is not None and "*" in inp.annotation.alias_set_after
1062*da0073e9SAndroid Build Coastguard Worker            for inp in self.func.schema_order_arguments()
1063*da0073e9SAndroid Build Coastguard Worker        )
1064*da0073e9SAndroid Build Coastguard Worker        return is_non_mutating_view or is_inplace_view or is_wildcard_view
1065*da0073e9SAndroid Build Coastguard Worker
1066*da0073e9SAndroid Build Coastguard Worker    @property
1067*da0073e9SAndroid Build Coastguard Worker    def view_schema_kind(self) -> ViewSchemaKind:
1068*da0073e9SAndroid Build Coastguard Worker        if self.is_view_op and self.func.name.name.inplace:
1069*da0073e9SAndroid Build Coastguard Worker            assert "inplace_view" in self.tags
1070*da0073e9SAndroid Build Coastguard Worker            return ViewSchemaKind.aliasing_inplace
1071*da0073e9SAndroid Build Coastguard Worker        if self.is_view_op:
1072*da0073e9SAndroid Build Coastguard Worker            return ViewSchemaKind.aliasing
1073*da0073e9SAndroid Build Coastguard Worker        else:
1074*da0073e9SAndroid Build Coastguard Worker            return ViewSchemaKind.non_aliasing
1075*da0073e9SAndroid Build Coastguard Worker
1076*da0073e9SAndroid Build Coastguard Worker    @property
1077*da0073e9SAndroid Build Coastguard Worker    def root_name(self) -> str:
1078*da0073e9SAndroid Build Coastguard Worker        return self.func.name.name.base
1079*da0073e9SAndroid Build Coastguard Worker
1080*da0073e9SAndroid Build Coastguard Worker    @property
1081*da0073e9SAndroid Build Coastguard Worker    def part_of_structured_group(self) -> bool:
1082*da0073e9SAndroid Build Coastguard Worker        return self.structured or self.structured_delegate is not None
1083*da0073e9SAndroid Build Coastguard Worker
1084*da0073e9SAndroid Build Coastguard Worker
1085*da0073e9SAndroid Build Coastguard Workerclass SchemaKind(Enum):
1086*da0073e9SAndroid Build Coastguard Worker    functional = auto()
1087*da0073e9SAndroid Build Coastguard Worker    inplace = auto()
1088*da0073e9SAndroid Build Coastguard Worker    out = auto()
1089*da0073e9SAndroid Build Coastguard Worker    mutable = auto()
1090*da0073e9SAndroid Build Coastguard Worker    scratch = auto()
1091*da0073e9SAndroid Build Coastguard Worker
1092*da0073e9SAndroid Build Coastguard Worker
1093*da0073e9SAndroid Build Coastguard Worker# A structured kernel is guaranteed to have a functional and out variant, and
1094*da0073e9SAndroid Build Coastguard Worker# optionally an inplace variant.
1095*da0073e9SAndroid Build Coastguard Worker#
1096*da0073e9SAndroid Build Coastguard Worker# NB: we create NativeFunctionsGroup *even if* the function is not
1097*da0073e9SAndroid Build Coastguard Worker# actually annotated structured.  Test the structured boolean to see if it
1098*da0073e9SAndroid Build Coastguard Worker# actually is structured or not.
1099*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
1100*da0073e9SAndroid Build Coastguard Workerclass NativeFunctionsGroup:
1101*da0073e9SAndroid Build Coastguard Worker    functional: NativeFunction
1102*da0073e9SAndroid Build Coastguard Worker    inplace: NativeFunction | None
1103*da0073e9SAndroid Build Coastguard Worker    mutable: NativeFunction | None
1104*da0073e9SAndroid Build Coastguard Worker    out: NativeFunction
1105*da0073e9SAndroid Build Coastguard Worker
1106*da0073e9SAndroid Build Coastguard Worker    @property
1107*da0073e9SAndroid Build Coastguard Worker    def structured(self) -> bool:
1108*da0073e9SAndroid Build Coastguard Worker        # Whether or not the operator has a meta() function. This information is backend-agnostic.
1109*da0073e9SAndroid Build Coastguard Worker        return self.out.structured
1110*da0073e9SAndroid Build Coastguard Worker
1111*da0073e9SAndroid Build Coastguard Worker    def __post_init__(self) -> None:
1112*da0073e9SAndroid Build Coastguard Worker        test_sig: FunctionSchema = self.functional.func.signature()
1113*da0073e9SAndroid Build Coastguard Worker        for f in self.functions():
1114*da0073e9SAndroid Build Coastguard Worker            if test_sig != f.func.signature():
1115*da0073e9SAndroid Build Coastguard Worker                raise AssertionError(
1116*da0073e9SAndroid Build Coastguard Worker                    "NativeFunctionsGroup constructed from two NativeFunctions "
1117*da0073e9SAndroid Build Coastguard Worker                    f"that don't have matching signatures: {test_sig} != {f.func.signature()}"
1118*da0073e9SAndroid Build Coastguard Worker                )
1119*da0073e9SAndroid Build Coastguard Worker
1120*da0073e9SAndroid Build Coastguard Worker            if self.structured != f.part_of_structured_group:
1121*da0073e9SAndroid Build Coastguard Worker                raise AssertionError(
1122*da0073e9SAndroid Build Coastguard Worker                    "NativeFunctionsGroup constructed from structured and unstructured "
1123*da0073e9SAndroid Build Coastguard Worker                    f"functions: {self.out.func.name} and {f.func.name}"
1124*da0073e9SAndroid Build Coastguard Worker                )
1125*da0073e9SAndroid Build Coastguard Worker        assert self.functional.func.kind() == SchemaKind.functional
1126*da0073e9SAndroid Build Coastguard Worker        assert self.out.func.kind() == SchemaKind.out
1127*da0073e9SAndroid Build Coastguard Worker        assert self.functional.namespace == self.out.namespace
1128*da0073e9SAndroid Build Coastguard Worker        if self.inplace is not None:
1129*da0073e9SAndroid Build Coastguard Worker            assert self.inplace.func.kind() == SchemaKind.inplace
1130*da0073e9SAndroid Build Coastguard Worker            assert self.inplace.namespace == self.functional.namespace
1131*da0073e9SAndroid Build Coastguard Worker
1132*da0073e9SAndroid Build Coastguard Worker        if self.mutable is not None:
1133*da0073e9SAndroid Build Coastguard Worker            assert self.mutable.func.kind() == SchemaKind.mutable
1134*da0073e9SAndroid Build Coastguard Worker            assert self.mutable.namespace == self.functional.namespace
1135*da0073e9SAndroid Build Coastguard Worker            # See Note [Overload Ambiguity With Functional Variants]
1136*da0073e9SAndroid Build Coastguard Worker            assert self.functional.func.name.name.functional_overload
1137*da0073e9SAndroid Build Coastguard Worker
1138*da0073e9SAndroid Build Coastguard Worker        if self.structured:
1139*da0073e9SAndroid Build Coastguard Worker            # For now, structured composite kernels are not supported (need some
1140*da0073e9SAndroid Build Coastguard Worker            # design work to figure out how to make the composite case work)
1141*da0073e9SAndroid Build Coastguard Worker            assert (
1142*da0073e9SAndroid Build Coastguard Worker                not self.out.has_composite_implicit_autograd_kernel
1143*da0073e9SAndroid Build Coastguard Worker                and not self.out.has_composite_implicit_autograd_nested_tensor_kernel
1144*da0073e9SAndroid Build Coastguard Worker            )
1145*da0073e9SAndroid Build Coastguard Worker
1146*da0073e9SAndroid Build Coastguard Worker            assert self.functional.structured_delegate == self.out.func.name, (
1147*da0073e9SAndroid Build Coastguard Worker                f"{self.functional.func.name} delegates to {self.functional.structured_delegate} "
1148*da0073e9SAndroid Build Coastguard Worker                f"but its actual delegate is {self.out.func.name}"
1149*da0073e9SAndroid Build Coastguard Worker            )
1150*da0073e9SAndroid Build Coastguard Worker            if self.inplace is not None:
1151*da0073e9SAndroid Build Coastguard Worker                assert self.inplace.structured_delegate == self.out.func.name
1152*da0073e9SAndroid Build Coastguard Worker
1153*da0073e9SAndroid Build Coastguard Worker        generated_fns = sorted(
1154*da0073e9SAndroid Build Coastguard Worker            [str(f.func.name) for f in self.functions() if "generated" in f.tags]
1155*da0073e9SAndroid Build Coastguard Worker        )
1156*da0073e9SAndroid Build Coastguard Worker        generated_fns_str = ", ".join(str(x) for x in generated_fns)
1157*da0073e9SAndroid Build Coastguard Worker        expected_generated_fns: set[str] = set()
1158*da0073e9SAndroid Build Coastguard Worker        for f in self.functions():
1159*da0073e9SAndroid Build Coastguard Worker            expected_generated_fns.update(str(op) for op in f.autogen)
1160*da0073e9SAndroid Build Coastguard Worker        expected_generated_fns_str = ", ".join(
1161*da0073e9SAndroid Build Coastguard Worker            str(x) for x in sorted(expected_generated_fns)
1162*da0073e9SAndroid Build Coastguard Worker        )
1163*da0073e9SAndroid Build Coastguard Worker        if len(expected_generated_fns) == 0 and len(generated_fns) > 0:
1164*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
1165*da0073e9SAndroid Build Coastguard Worker                f"The codegen expects to be able to generate '{generated_fns_str}'."
1166*da0073e9SAndroid Build Coastguard Worker                " In order to generate them however, we expect them to be called out explicitly in the yaml."
1167*da0073e9SAndroid Build Coastguard Worker                f" Please add an 'autogen: {generated_fns_str}' line to the entry for {str(f.func.name)}"
1168*da0073e9SAndroid Build Coastguard Worker            )
1169*da0073e9SAndroid Build Coastguard Worker        if expected_generated_fns_str != generated_fns_str:
1170*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
1171*da0073e9SAndroid Build Coastguard Worker                f"The codegen expects to be able to generate '{generated_fns_str}'."
1172*da0073e9SAndroid Build Coastguard Worker                f" To do so, it expects a line: 'autogen: {generated_fns_str}'."
1173*da0073e9SAndroid Build Coastguard Worker                f" Instead, it found 'autogen: {expected_generated_fns_str}'"
1174*da0073e9SAndroid Build Coastguard Worker            )
1175*da0073e9SAndroid Build Coastguard Worker
1176*da0073e9SAndroid Build Coastguard Worker    def signature(self) -> FunctionSchema:
1177*da0073e9SAndroid Build Coastguard Worker        return self.out.func.signature()
1178*da0073e9SAndroid Build Coastguard Worker
1179*da0073e9SAndroid Build Coastguard Worker    def functions(self) -> Iterator[NativeFunction]:
1180*da0073e9SAndroid Build Coastguard Worker        yield self.functional
1181*da0073e9SAndroid Build Coastguard Worker        yield self.out
1182*da0073e9SAndroid Build Coastguard Worker        if self.inplace is not None:
1183*da0073e9SAndroid Build Coastguard Worker            yield self.inplace
1184*da0073e9SAndroid Build Coastguard Worker        if self.mutable is not None:
1185*da0073e9SAndroid Build Coastguard Worker            yield self.mutable
1186*da0073e9SAndroid Build Coastguard Worker
1187*da0073e9SAndroid Build Coastguard Worker    @property
1188*da0073e9SAndroid Build Coastguard Worker    def root_name(self) -> str:
1189*da0073e9SAndroid Build Coastguard Worker        return self.functional.root_name
1190*da0073e9SAndroid Build Coastguard Worker
1191*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1192*da0073e9SAndroid Build Coastguard Worker    def from_dict(d: dict[SchemaKind, NativeFunction]) -> NativeFunctionsGroup | None:
1193*da0073e9SAndroid Build Coastguard Worker        assert d
1194*da0073e9SAndroid Build Coastguard Worker        if len(d) == 1:
1195*da0073e9SAndroid Build Coastguard Worker            return None
1196*da0073e9SAndroid Build Coastguard Worker        d = dict(d)  # non-destructive updates please
1197*da0073e9SAndroid Build Coastguard Worker        functional = d.pop(SchemaKind.functional, None)
1198*da0073e9SAndroid Build Coastguard Worker        inplace = d.pop(SchemaKind.inplace, None)
1199*da0073e9SAndroid Build Coastguard Worker        mutable = d.pop(SchemaKind.mutable, None)
1200*da0073e9SAndroid Build Coastguard Worker        out = d.pop(SchemaKind.out, None)
1201*da0073e9SAndroid Build Coastguard Worker        assert not d
1202*da0073e9SAndroid Build Coastguard Worker        assert functional is not None
1203*da0073e9SAndroid Build Coastguard Worker        # There are a few operators which only have functional/inplace variants;
1204*da0073e9SAndroid Build Coastguard Worker        # these don't count as structured for our purposes here
1205*da0073e9SAndroid Build Coastguard Worker        if out is None:
1206*da0073e9SAndroid Build Coastguard Worker            return None
1207*da0073e9SAndroid Build Coastguard Worker        # assuming all variants have the same namespace
1208*da0073e9SAndroid Build Coastguard Worker        return NativeFunctionsGroup(
1209*da0073e9SAndroid Build Coastguard Worker            functional=functional,
1210*da0073e9SAndroid Build Coastguard Worker            inplace=inplace,
1211*da0073e9SAndroid Build Coastguard Worker            mutable=mutable,
1212*da0073e9SAndroid Build Coastguard Worker            out=out,
1213*da0073e9SAndroid Build Coastguard Worker        )
1214*da0073e9SAndroid Build Coastguard Worker
1215*da0073e9SAndroid Build Coastguard Worker
1216*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
1217*da0073e9SAndroid Build Coastguard Workerclass BackendMetadata:
1218*da0073e9SAndroid Build Coastguard Worker    # The name of the backend kernel, for a given operator
1219*da0073e9SAndroid Build Coastguard Worker    # for in-tree backends. These names come directly from the 'dispatch" field
1220*da0073e9SAndroid Build Coastguard Worker    # in native_functions.yaml. The dispatch entry is optional; in that
1221*da0073e9SAndroid Build Coastguard Worker    # case, that is equivalent to having written:
1222*da0073e9SAndroid Build Coastguard Worker    #
1223*da0073e9SAndroid Build Coastguard Worker    #   dispatch:
1224*da0073e9SAndroid Build Coastguard Worker    #       CompositeImplicitAutograd: $operator_name
1225*da0073e9SAndroid Build Coastguard Worker    kernel: str
1226*da0073e9SAndroid Build Coastguard Worker    # Whether or not the operator has a structured kernel implemented, for this particular backend.
1227*da0073e9SAndroid Build Coastguard Worker    # For in-tree backends, they all have the same value for structured- this is listed
1228*da0073e9SAndroid Build Coastguard Worker    # in native_functions.yaml.
1229*da0073e9SAndroid Build Coastguard Worker    # However, external backends like XLA can indendently toggle which ops are structured.
1230*da0073e9SAndroid Build Coastguard Worker    structured: bool
1231*da0073e9SAndroid Build Coastguard Worker
1232*da0073e9SAndroid Build Coastguard Worker    # The namespace for kernels, default value: DEFAULT_KERNEL_NAMESPACE
1233*da0073e9SAndroid Build Coastguard Worker    cpp_namespace: str
1234*da0073e9SAndroid Build Coastguard Worker
1235*da0073e9SAndroid Build Coastguard Worker    def supports_symint(self) -> bool:
1236*da0073e9SAndroid Build Coastguard Worker        return "_symint" in self.kernel
1237*da0073e9SAndroid Build Coastguard Worker
1238*da0073e9SAndroid Build Coastguard Worker
1239*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
1240*da0073e9SAndroid Build Coastguard Workerclass UfuncInnerLoop:
1241*da0073e9SAndroid Build Coastguard Worker    name: str
1242*da0073e9SAndroid Build Coastguard Worker    supported_dtypes: OrderedSet[ScalarType]
1243*da0073e9SAndroid Build Coastguard Worker    # key is stored here because it affects the semantics of name,
1244*da0073e9SAndroid Build Coastguard Worker    # so its helpful to have them together for further processing
1245*da0073e9SAndroid Build Coastguard Worker    ufunc_key: UfuncKey
1246*da0073e9SAndroid Build Coastguard Worker
1247*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1248*da0073e9SAndroid Build Coastguard Worker    def parse(value: str, ufunc_key: UfuncKey) -> UfuncInnerLoop:
1249*da0073e9SAndroid Build Coastguard Worker        name, supported_dtypes_str = value.split(" ", 1)
1250*da0073e9SAndroid Build Coastguard Worker        assert supported_dtypes_str[0] == "("
1251*da0073e9SAndroid Build Coastguard Worker        assert supported_dtypes_str[-1] == ")"
1252*da0073e9SAndroid Build Coastguard Worker        supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
1253*da0073e9SAndroid Build Coastguard Worker        for k in supported_dtypes_str[1:-1].split(", "):
1254*da0073e9SAndroid Build Coastguard Worker            supported_dtypes |= ScalarType.parse_set(k)
1255*da0073e9SAndroid Build Coastguard Worker        return UfuncInnerLoop(
1256*da0073e9SAndroid Build Coastguard Worker            name=name, supported_dtypes=supported_dtypes, ufunc_key=ufunc_key
1257*da0073e9SAndroid Build Coastguard Worker        )
1258*da0073e9SAndroid Build Coastguard Worker
1259*da0073e9SAndroid Build Coastguard Worker
1260*da0073e9SAndroid Build Coastguard Worker# BackendIndex represents a backend.
1261*da0073e9SAndroid Build Coastguard Worker# The BackendIndex encodes per-operator information that is potentially different
1262*da0073e9SAndroid Build Coastguard Worker# for each backend. The most obvious example is the name of the kernel
1263*da0073e9SAndroid Build Coastguard Worker# (the 'dispatch' entry in native_functions.yaml).
1264*da0073e9SAndroid Build Coastguard Worker# However, there can be other examples of different backends having different information.
1265*da0073e9SAndroid Build Coastguard Worker# External backends can choose to opt their kernels to be structured independently from in-tree backends,
1266*da0073e9SAndroid Build Coastguard Worker# which means that this information isn't inherently tied to a NativeFunction- it's different per backend.
1267*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
1268*da0073e9SAndroid Build Coastguard Workerclass BackendIndex:
1269*da0073e9SAndroid Build Coastguard Worker    dispatch_key: DispatchKey
1270*da0073e9SAndroid Build Coastguard Worker    # Mainly important for structured kernels, this determines which variant in the operator group is used to implement the others.
1271*da0073e9SAndroid Build Coastguard Worker    # All in-tree ops use out kernels, while XLA uses functional kernels.
1272*da0073e9SAndroid Build Coastguard Worker    use_out_as_primary: bool
1273*da0073e9SAndroid Build Coastguard Worker    # Whether the backend requires a device guard, and device checks.
1274*da0073e9SAndroid Build Coastguard Worker    # For in-tree backends, this is currently just CUDA/HIP
1275*da0073e9SAndroid Build Coastguard Worker    # For out-of-tree backends, this is currently just Intel XPU
1276*da0073e9SAndroid Build Coastguard Worker    device_guard: bool
1277*da0073e9SAndroid Build Coastguard Worker    # Whether the backend is in-tree (CPU/CUDA) or out-of-tree (XLA)
1278*da0073e9SAndroid Build Coastguard Worker    external: bool
1279*da0073e9SAndroid Build Coastguard Worker    # Other backend-specific information that is on a per-operator basis
1280*da0073e9SAndroid Build Coastguard Worker    index: dict[OperatorName, BackendMetadata]
1281*da0073e9SAndroid Build Coastguard Worker
1282*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1283*da0073e9SAndroid Build Coastguard Worker    def grow_index(
1284*da0073e9SAndroid Build Coastguard Worker        parent_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
1285*da0073e9SAndroid Build Coastguard Worker        child_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
1286*da0073e9SAndroid Build Coastguard Worker    ) -> None:
1287*da0073e9SAndroid Build Coastguard Worker        for k, v in child_index.items():
1288*da0073e9SAndroid Build Coastguard Worker            for op_name, metadata in v.items():
1289*da0073e9SAndroid Build Coastguard Worker                assert (
1290*da0073e9SAndroid Build Coastguard Worker                    op_name not in parent_index[k]
1291*da0073e9SAndroid Build Coastguard Worker                ), f"duplicate operator {op_name} for dispatch key {k}"
1292*da0073e9SAndroid Build Coastguard Worker                parent_index[k][op_name] = metadata
1293*da0073e9SAndroid Build Coastguard Worker
1294*da0073e9SAndroid Build Coastguard Worker    def primary(self, g: NativeFunctionsGroup) -> NativeFunction:
1295*da0073e9SAndroid Build Coastguard Worker        if self.use_out_as_primary:
1296*da0073e9SAndroid Build Coastguard Worker            return g.out
1297*da0073e9SAndroid Build Coastguard Worker        else:
1298*da0073e9SAndroid Build Coastguard Worker            return g.functional
1299*da0073e9SAndroid Build Coastguard Worker
1300*da0073e9SAndroid Build Coastguard Worker    def has_kernel(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
1301*da0073e9SAndroid Build Coastguard Worker        m = self.get_kernel(g)
1302*da0073e9SAndroid Build Coastguard Worker        return m is not None
1303*da0073e9SAndroid Build Coastguard Worker
1304*da0073e9SAndroid Build Coastguard Worker    def get_kernel(
1305*da0073e9SAndroid Build Coastguard Worker        self, g: NativeFunction | NativeFunctionsGroup
1306*da0073e9SAndroid Build Coastguard Worker    ) -> BackendMetadata | None:
1307*da0073e9SAndroid Build Coastguard Worker        if isinstance(g, NativeFunction):
1308*da0073e9SAndroid Build Coastguard Worker            f = g
1309*da0073e9SAndroid Build Coastguard Worker        elif isinstance(g, NativeFunctionsGroup):
1310*da0073e9SAndroid Build Coastguard Worker            f = self.primary(g)
1311*da0073e9SAndroid Build Coastguard Worker        else:
1312*da0073e9SAndroid Build Coastguard Worker            assert_never(g)
1313*da0073e9SAndroid Build Coastguard Worker        if f.func.name not in self.index:
1314*da0073e9SAndroid Build Coastguard Worker            return None
1315*da0073e9SAndroid Build Coastguard Worker        return self.index[f.func.name]
1316*da0073e9SAndroid Build Coastguard Worker
1317*da0073e9SAndroid Build Coastguard Worker    def native_function_class_name(self) -> str | None:
1318*da0073e9SAndroid Build Coastguard Worker        if self.external:
1319*da0073e9SAndroid Build Coastguard Worker            return f"{str(self.dispatch_key)}NativeFunctions"
1320*da0073e9SAndroid Build Coastguard Worker        else:
1321*da0073e9SAndroid Build Coastguard Worker            # TODO: This discrepancy isn't required; we could also generated
1322*da0073e9SAndroid Build Coastguard Worker            # a class for in-tree kernels. It'll just require carefully
1323*da0073e9SAndroid Build Coastguard Worker            # updating every kernel definition + callsite of every in-tree aten kernel.
1324*da0073e9SAndroid Build Coastguard Worker            return None
1325*da0073e9SAndroid Build Coastguard Worker
1326*da0073e9SAndroid Build Coastguard Worker
1327*da0073e9SAndroid Build Coastguard Worker# The function schema is undoubtedly the most important data structure
1328*da0073e9SAndroid Build Coastguard Worker# in all of the codegen, as it defines the type signature for operators,
1329*da0073e9SAndroid Build Coastguard Worker# and most of the code generation we do is type directed (e.g., look at
1330*da0073e9SAndroid Build Coastguard Worker# the types, decide what to do.  Think about how we code generate
1331*da0073e9SAndroid Build Coastguard Worker# C++ function stubs!)
1332*da0073e9SAndroid Build Coastguard Worker#
1333*da0073e9SAndroid Build Coastguard Worker# We will also see in this class the general structure for how we model
1334*da0073e9SAndroid Build Coastguard Worker# data in this code generation.  A few notable properties to point out
1335*da0073e9SAndroid Build Coastguard Worker# ahead of time:
1336*da0073e9SAndroid Build Coastguard Worker#
1337*da0073e9SAndroid Build Coastguard Worker#   - These dataclasses are a *lossless* representation of the strings
1338*da0073e9SAndroid Build Coastguard Worker#     they are parsed from.  In fact, we assert that given the
1339*da0073e9SAndroid Build Coastguard Worker#     information stored in the dataclass, we can exactly reconstruct
1340*da0073e9SAndroid Build Coastguard Worker#     the string we parsed from (and assert this inside the parse
1341*da0073e9SAndroid Build Coastguard Worker#     definition).  There are a few reasons for this:
1342*da0073e9SAndroid Build Coastguard Worker#
1343*da0073e9SAndroid Build Coastguard Worker#       - If you find that it is difficult to reconstruct the string
1344*da0073e9SAndroid Build Coastguard Worker#         given a dataclass, that is a clue that you are data
1345*da0073e9SAndroid Build Coastguard Worker#         representation is wrong.
1346*da0073e9SAndroid Build Coastguard Worker#
1347*da0073e9SAndroid Build Coastguard Worker#       - It helps ensure that all relevant information is present
1348*da0073e9SAndroid Build Coastguard Worker#         in the dataclass, so that downstream users aren't tempted
1349*da0073e9SAndroid Build Coastguard Worker#         to reparse the original string to get some information
1350*da0073e9SAndroid Build Coastguard Worker#         that was omitted.
1351*da0073e9SAndroid Build Coastguard Worker#
1352*da0073e9SAndroid Build Coastguard Worker#       - It forces you to represent the data in-memory in the same way
1353*da0073e9SAndroid Build Coastguard Worker#         it is recorded textually, which makes the dataclasses easier
1354*da0073e9SAndroid Build Coastguard Worker#         to understand for someone who is familiar with the
1355*da0073e9SAndroid Build Coastguard Worker#         textual format.  (As a tradeoff, it means you have to model
1356*da0073e9SAndroid Build Coastguard Worker#         the syntax, even when it is inconvenient.  But maybe that means
1357*da0073e9SAndroid Build Coastguard Worker#         the syntax is bad!)  If you don't understand the internal
1358*da0073e9SAndroid Build Coastguard Worker#         representation, go look at the printing code to see how
1359*da0073e9SAndroid Build Coastguard Worker#         it maps onto the surface syntax!
1360*da0073e9SAndroid Build Coastguard Worker#
1361*da0073e9SAndroid Build Coastguard Worker#       - It makes it easy to test the parsing code, as parsing code
1362*da0073e9SAndroid Build Coastguard Worker#         that is inconsistent with the string code will fail early
1363*da0073e9SAndroid Build Coastguard Worker#         and loudly.  (As a tradeoff, it makes the parsing code a bit
1364*da0073e9SAndroid Build Coastguard Worker#         brittle (in particular, with trivial whitespace changes you
1365*da0073e9SAndroid Build Coastguard Worker#         are likely to trigger an assert error).
1366*da0073e9SAndroid Build Coastguard Worker#
1367*da0073e9SAndroid Build Coastguard Worker#     In general, try to make the __str__ code as simple as possible
1368*da0073e9SAndroid Build Coastguard Worker#     (even at the cost of more complex parsing logic.)  Additionally,
1369*da0073e9SAndroid Build Coastguard Worker#     try to minimize redundancy in data representation.  (Precomputed
1370*da0073e9SAndroid Build Coastguard Worker#     fields are OK though: they are defined as a simple function on
1371*da0073e9SAndroid Build Coastguard Worker#     the canonical representation in question.)
1372*da0073e9SAndroid Build Coastguard Worker#
1373*da0073e9SAndroid Build Coastguard Worker#   - These dataclasses are all frozen; once constructed their
1374*da0073e9SAndroid Build Coastguard Worker#     values never change.  This makes it easy to tell where any
1375*da0073e9SAndroid Build Coastguard Worker#     given data came from: just look to the constructor.  As a
1376*da0073e9SAndroid Build Coastguard Worker#     tradeoff, you can't easily "decorate" a schema with extra
1377*da0073e9SAndroid Build Coastguard Worker#     information from a post-facto analysis.  We impose this
1378*da0073e9SAndroid Build Coastguard Worker#     restriction to make these structures more understandable.
1379*da0073e9SAndroid Build Coastguard Worker#
1380*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
1381*da0073e9SAndroid Build Coastguard Workerclass FunctionSchema:
1382*da0073e9SAndroid Build Coastguard Worker    # The name of the operator this function schema describes.
1383*da0073e9SAndroid Build Coastguard Worker    name: OperatorName
1384*da0073e9SAndroid Build Coastguard Worker
1385*da0073e9SAndroid Build Coastguard Worker    arguments: Arguments
1386*da0073e9SAndroid Build Coastguard Worker
1387*da0073e9SAndroid Build Coastguard Worker    # TODO: Need to handle collisions with argument names at some point
1388*da0073e9SAndroid Build Coastguard Worker    returns: tuple[Return, ...]
1389*da0073e9SAndroid Build Coastguard Worker
1390*da0073e9SAndroid Build Coastguard Worker    @property
1391*da0073e9SAndroid Build Coastguard Worker    def is_mutable(self) -> bool:
1392*da0073e9SAndroid Build Coastguard Worker        def is_write(arg: Argument) -> bool:
1393*da0073e9SAndroid Build Coastguard Worker            if arg.annotation is None:
1394*da0073e9SAndroid Build Coastguard Worker                return False
1395*da0073e9SAndroid Build Coastguard Worker            return arg.annotation.is_write
1396*da0073e9SAndroid Build Coastguard Worker
1397*da0073e9SAndroid Build Coastguard Worker        # Corresponds to torch._C._FunctionSchema.is_mutable
1398*da0073e9SAndroid Build Coastguard Worker        # See aten/src/ATen/core/function_schema.h (keep these in sync)
1399*da0073e9SAndroid Build Coastguard Worker        return any(is_write(a) for a in self.arguments.flat_all)
1400*da0073e9SAndroid Build Coastguard Worker
1401*da0073e9SAndroid Build Coastguard Worker    def schema_order_arguments(self) -> Iterator[Argument]:
1402*da0073e9SAndroid Build Coastguard Worker        return itertools.chain(
1403*da0073e9SAndroid Build Coastguard Worker            self.arguments.flat_positional,
1404*da0073e9SAndroid Build Coastguard Worker            self.arguments.flat_kwarg_only,
1405*da0073e9SAndroid Build Coastguard Worker            self.arguments.out,
1406*da0073e9SAndroid Build Coastguard Worker        )
1407*da0073e9SAndroid Build Coastguard Worker
1408*da0073e9SAndroid Build Coastguard Worker    decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")
1409*da0073e9SAndroid Build Coastguard Worker
1410*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1411*da0073e9SAndroid Build Coastguard Worker    def parse(func: str) -> FunctionSchema:
1412*da0073e9SAndroid Build Coastguard Worker        # We should probably get a proper parser here
1413*da0073e9SAndroid Build Coastguard Worker        decls = FunctionSchema.decl_re.findall(func)
1414*da0073e9SAndroid Build Coastguard Worker        assert len(decls) == 1, f"Invalid function schema: {func}"
1415*da0073e9SAndroid Build Coastguard Worker        ops, args, return_decl = decls[0]
1416*da0073e9SAndroid Build Coastguard Worker        name = OperatorName.parse(ops)
1417*da0073e9SAndroid Build Coastguard Worker        arguments = Arguments.parse(args)
1418*da0073e9SAndroid Build Coastguard Worker        returns = parse_returns(return_decl)
1419*da0073e9SAndroid Build Coastguard Worker        r = FunctionSchema(name=name, arguments=arguments, returns=returns)
1420*da0073e9SAndroid Build Coastguard Worker        assert str(r) == func, f"{str(r)} != {func}"
1421*da0073e9SAndroid Build Coastguard Worker        return r
1422*da0073e9SAndroid Build Coastguard Worker
1423*da0073e9SAndroid Build Coastguard Worker    def returns_are_aliased(self) -> bool:
1424*da0073e9SAndroid Build Coastguard Worker        # We assert earlier that schemas can't have a mix of aliased and non-aliased returns
1425*da0073e9SAndroid Build Coastguard Worker        return any(
1426*da0073e9SAndroid Build Coastguard Worker            r
1427*da0073e9SAndroid Build Coastguard Worker            for r in self.returns
1428*da0073e9SAndroid Build Coastguard Worker            if r.annotation is not None and r.annotation.is_write
1429*da0073e9SAndroid Build Coastguard Worker        )
1430*da0073e9SAndroid Build Coastguard Worker
1431*da0073e9SAndroid Build Coastguard Worker    def __post_init__(self) -> None:
1432*da0073e9SAndroid Build Coastguard Worker        for arg, ret in zip(self.arguments.out, self.returns):
1433*da0073e9SAndroid Build Coastguard Worker            assert arg.annotation == ret.annotation, (
1434*da0073e9SAndroid Build Coastguard Worker                "Out arguments must have matching return Tensor; furthermore, "
1435*da0073e9SAndroid Build Coastguard Worker                "the ith-argument needs to correspond to the ith return"
1436*da0073e9SAndroid Build Coastguard Worker            )
1437*da0073e9SAndroid Build Coastguard Worker        # We also enforce that if you have any mutable, positional args, then they are not returned.
1438*da0073e9SAndroid Build Coastguard Worker        # This makes it easier to group these functions properly with their functional/out= counterparts.
1439*da0073e9SAndroid Build Coastguard Worker        for a in self.arguments.post_self_positional_mutable:
1440*da0073e9SAndroid Build Coastguard Worker            assert not any(
1441*da0073e9SAndroid Build Coastguard Worker                a.annotation == r.annotation for r in self.returns
1442*da0073e9SAndroid Build Coastguard Worker            ), f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}"
1443*da0073e9SAndroid Build Coastguard Worker        # Invariant: we expect out arguments to appear as keyword arguments in the schema.
1444*da0073e9SAndroid Build Coastguard Worker        # This means that all mutable returns should be aliased to a keyword argument
1445*da0073e9SAndroid Build Coastguard Worker        # (except for "self", which we explicitly don't treat as an out argument because of its use in methods)
1446*da0073e9SAndroid Build Coastguard Worker        # See Note [is_out_fn]
1447*da0073e9SAndroid Build Coastguard Worker        out_and_self = list(self.arguments.out) + [
1448*da0073e9SAndroid Build Coastguard Worker            arg for arg in self.arguments.flat_positional if arg.name == "self"
1449*da0073e9SAndroid Build Coastguard Worker        ]
1450*da0073e9SAndroid Build Coastguard Worker        mutable_returns = [
1451*da0073e9SAndroid Build Coastguard Worker            ret
1452*da0073e9SAndroid Build Coastguard Worker            for ret in self.returns
1453*da0073e9SAndroid Build Coastguard Worker            if ret.annotation is not None and ret.annotation.is_write
1454*da0073e9SAndroid Build Coastguard Worker        ]
1455*da0073e9SAndroid Build Coastguard Worker        immutable_returns = [
1456*da0073e9SAndroid Build Coastguard Worker            ret
1457*da0073e9SAndroid Build Coastguard Worker            for ret in self.returns
1458*da0073e9SAndroid Build Coastguard Worker            if ret.annotation is None or not ret.annotation.is_write
1459*da0073e9SAndroid Build Coastguard Worker        ]
1460*da0073e9SAndroid Build Coastguard Worker        # Some assertions: We don't want any functions with a return type of "-> (Tensor(a!), Tensor)",
1461*da0073e9SAndroid Build Coastguard Worker        # because:
1462*da0073e9SAndroid Build Coastguard Worker        # (1) It's more annoying to handle properly
1463*da0073e9SAndroid Build Coastguard Worker        # (2) It's unnecessary - you can't method-chain on the first (mutated) output because it's part of a tuple.
1464*da0073e9SAndroid Build Coastguard Worker        # Instead, we expect the (a!) argument to not be returned.
1465*da0073e9SAndroid Build Coastguard Worker        assert (
1466*da0073e9SAndroid Build Coastguard Worker            len(mutable_returns) == 0 or len(immutable_returns) == 0
1467*da0073e9SAndroid Build Coastguard Worker        ), f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}"
1468*da0073e9SAndroid Build Coastguard Worker        for ret in mutable_returns:
1469*da0073e9SAndroid Build Coastguard Worker            assert any(ret.annotation == arg.annotation for arg in out_and_self), (
1470*da0073e9SAndroid Build Coastguard Worker                'All mutable returns must be aliased either to a keyword argument, or to "self". '
1471*da0073e9SAndroid Build Coastguard Worker                "Did you forget to mark an out argument as keyword-only?"
1472*da0073e9SAndroid Build Coastguard Worker            )
1473*da0073e9SAndroid Build Coastguard Worker        if self.arguments.out:
1474*da0073e9SAndroid Build Coastguard Worker            # out= ops that return their mutable inputs are only really useful for method chaining.
1475*da0073e9SAndroid Build Coastguard Worker            # And method chaining is only really useful if the thing you're returning is a plain Tensor.
1476*da0073e9SAndroid Build Coastguard Worker            # So ideally, we'd enforce that out= ops with a single plain mutable tensor should return the tensor,
1477*da0073e9SAndroid Build Coastguard Worker            # and all other types of out= op schemas should return void.
1478*da0073e9SAndroid Build Coastguard Worker            # There are a bunch of existing out= ops that return tuples of tensors though, so we're stuck with allowing that.
1479*da0073e9SAndroid Build Coastguard Worker            if any(a.type != BaseType(BaseTy.Tensor) for a in self.arguments.out):
1480*da0073e9SAndroid Build Coastguard Worker                assert (
1481*da0073e9SAndroid Build Coastguard Worker                    len(self.returns) == 0
1482*da0073e9SAndroid Build Coastguard Worker                ), "out= ops that accept tensor lists as out arguments "
1483*da0073e9SAndroid Build Coastguard Worker                "are expected to have no return type (since you can't do method chaining on them)"
1484*da0073e9SAndroid Build Coastguard Worker            else:
1485*da0073e9SAndroid Build Coastguard Worker                # mutable keyword arguments whose name has _scratch_ prefix are
1486*da0073e9SAndroid Build Coastguard Worker                # scratch tensors for memory planning and should not be returned
1487*da0073e9SAndroid Build Coastguard Worker                assert len(
1488*da0073e9SAndroid Build Coastguard Worker                    [
1489*da0073e9SAndroid Build Coastguard Worker                        arg
1490*da0073e9SAndroid Build Coastguard Worker                        for arg in self.arguments.out
1491*da0073e9SAndroid Build Coastguard Worker                        if not arg.name.startswith("_scratch_")
1492*da0073e9SAndroid Build Coastguard Worker                    ]
1493*da0073e9SAndroid Build Coastguard Worker                ) == len(
1494*da0073e9SAndroid Build Coastguard Worker                    self.returns
1495*da0073e9SAndroid Build Coastguard Worker                ), "Must return as many arguments as there are out arguments, or no return at all"
1496*da0073e9SAndroid Build Coastguard Worker
1497*da0073e9SAndroid Build Coastguard Worker        if self.name.name.inplace:
1498*da0073e9SAndroid Build Coastguard Worker            self_a = self.arguments.self_arg
1499*da0073e9SAndroid Build Coastguard Worker            assert (
1500*da0073e9SAndroid Build Coastguard Worker                self_a
1501*da0073e9SAndroid Build Coastguard Worker                and self_a.argument.annotation
1502*da0073e9SAndroid Build Coastguard Worker                and self_a.argument.annotation.is_write
1503*da0073e9SAndroid Build Coastguard Worker            )
1504*da0073e9SAndroid Build Coastguard Worker            if self_a.argument.type == BaseType(BaseTy.Tensor):
1505*da0073e9SAndroid Build Coastguard Worker                # All inplace ops with an ordinary `Tensor self` argument should return self,
1506*da0073e9SAndroid Build Coastguard Worker                # to allow for method chaining.
1507*da0073e9SAndroid Build Coastguard Worker                assert (
1508*da0073e9SAndroid Build Coastguard Worker                    len(self.returns) == 1
1509*da0073e9SAndroid Build Coastguard Worker                    and self.returns[0].annotation == self_a.argument.annotation
1510*da0073e9SAndroid Build Coastguard Worker                )
1511*da0073e9SAndroid Build Coastguard Worker            else:
1512*da0073e9SAndroid Build Coastguard Worker                # You can't method chain on non-tensor self arguments though (like a List[Tensor])
1513*da0073e9SAndroid Build Coastguard Worker                # so in all other cases we expect the return type to be none.
1514*da0073e9SAndroid Build Coastguard Worker                assert len(self.returns) == 0
1515*da0073e9SAndroid Build Coastguard Worker
1516*da0073e9SAndroid Build Coastguard Worker        if self.arguments.tensor_options is not None:
1517*da0073e9SAndroid Build Coastguard Worker            assert self.kind() == SchemaKind.functional, (
1518*da0073e9SAndroid Build Coastguard Worker                "Found an operator that is not functional or out variant, but has tensor options arguments."
1519*da0073e9SAndroid Build Coastguard Worker                "This is not allowed- tensor options arguments are only allowed for factory functions."
1520*da0073e9SAndroid Build Coastguard Worker                f"schema: {str(self)}"
1521*da0073e9SAndroid Build Coastguard Worker            )
1522*da0073e9SAndroid Build Coastguard Worker        if self.is_functional_fn():
1523*da0073e9SAndroid Build Coastguard Worker            assert self.kind() == SchemaKind.functional, (
1524*da0073e9SAndroid Build Coastguard Worker                "Found an operator that is not functional, but its overload contains the string 'functional'."
1525*da0073e9SAndroid Build Coastguard Worker                "This is a special keyword in the codegen, please use a different overload name."
1526*da0073e9SAndroid Build Coastguard Worker                f"schema: {str(self)}"
1527*da0073e9SAndroid Build Coastguard Worker            )
1528*da0073e9SAndroid Build Coastguard Worker
1529*da0073e9SAndroid Build Coastguard Worker    def is_functional_fn(self) -> bool:
1530*da0073e9SAndroid Build Coastguard Worker        return "functional" in self.name.overload_name
1531*da0073e9SAndroid Build Coastguard Worker
1532*da0073e9SAndroid Build Coastguard Worker    def is_out_fn(self) -> bool:
1533*da0073e9SAndroid Build Coastguard Worker        # Note [is_out_fn]
1534*da0073e9SAndroid Build Coastguard Worker        #
1535*da0073e9SAndroid Build Coastguard Worker        # out functions are the variants which take an explicit out= argument
1536*da0073e9SAndroid Build Coastguard Worker        # to populate into.  We need to know if a schema corresponds to an
1537*da0073e9SAndroid Build Coastguard Worker        # out function for several reasons:
1538*da0073e9SAndroid Build Coastguard Worker        #
1539*da0073e9SAndroid Build Coastguard Worker        #   - They codegen differently in C++ API
1540*da0073e9SAndroid Build Coastguard Worker        #       - codegen to at::add_out rather than at::add
1541*da0073e9SAndroid Build Coastguard Worker        #       - out argument is moved to front of C++ argument list
1542*da0073e9SAndroid Build Coastguard Worker        #
1543*da0073e9SAndroid Build Coastguard Worker        # out functions are DEFINED to be any function with a keyword-only
1544*da0073e9SAndroid Build Coastguard Worker        # argument that is mutable.  In principle, this could lead to a
1545*da0073e9SAndroid Build Coastguard Worker        # false positive if you define a function that mutates a
1546*da0073e9SAndroid Build Coastguard Worker        # kwarg only argument, but this isn't the "true" output of this
1547*da0073e9SAndroid Build Coastguard Worker        # function.  A more robust definition that would work in this
1548*da0073e9SAndroid Build Coastguard Worker        # case would also look at:
1549*da0073e9SAndroid Build Coastguard Worker        #
1550*da0073e9SAndroid Build Coastguard Worker        #   - The output types.  Out functions take in the arguments
1551*da0073e9SAndroid Build Coastguard Worker        #     they mutate and then return them again; this is sort
1552*da0073e9SAndroid Build Coastguard Worker        #     of "definitionally" what makes something an out function.
1553*da0073e9SAndroid Build Coastguard Worker        #     Historically, we DO check this for consistency.
1554*da0073e9SAndroid Build Coastguard Worker        #   - Correspondence with pure variant.  An out function
1555*da0073e9SAndroid Build Coastguard Worker        #     should have a signature equivalent to its pure variant,
1556*da0073e9SAndroid Build Coastguard Worker        #     but just with extra kwargs for the output elements.  This
1557*da0073e9SAndroid Build Coastguard Worker        #     is difficult to actually check for and historically
1558*da0073e9SAndroid Build Coastguard Worker        #     we only do this check in tools/
1559*da0073e9SAndroid Build Coastguard Worker        return bool(self.arguments.out)
1560*da0073e9SAndroid Build Coastguard Worker
1561*da0073e9SAndroid Build Coastguard Worker    def kind(self) -> SchemaKind:
1562*da0073e9SAndroid Build Coastguard Worker        """
1563*da0073e9SAndroid Build Coastguard Worker        What kind of schema is this?  A functional schema is one
1564*da0073e9SAndroid Build Coastguard Worker        that returns a newly allocated output; an inplace schema
1565*da0073e9SAndroid Build Coastguard Worker        modifies the self argument inplace; an out schema writes
1566*da0073e9SAndroid Build Coastguard Worker        the result into an explicitly provided out argument.
1567*da0073e9SAndroid Build Coastguard Worker        """
1568*da0073e9SAndroid Build Coastguard Worker        is_out = bool(self.arguments.out)
1569*da0073e9SAndroid Build Coastguard Worker        is_scratch = bool(
1570*da0073e9SAndroid Build Coastguard Worker            [arg for arg in self.arguments.out if arg.name.startswith("_scratch_")]
1571*da0073e9SAndroid Build Coastguard Worker        )
1572*da0073e9SAndroid Build Coastguard Worker        is_inplace = self.name.name.inplace
1573*da0073e9SAndroid Build Coastguard Worker        is_mutable = any(
1574*da0073e9SAndroid Build Coastguard Worker            a.annotation is not None and a.annotation.is_write
1575*da0073e9SAndroid Build Coastguard Worker            for a in self.arguments.post_self_positional
1576*da0073e9SAndroid Build Coastguard Worker        )
1577*da0073e9SAndroid Build Coastguard Worker        assert not (is_out and is_inplace)
1578*da0073e9SAndroid Build Coastguard Worker        # out= and inplace schemas can also have post_self_positional mutable args,
1579*da0073e9SAndroid Build Coastguard Worker        # but we give precedence to out= and inplace when deciding the schema kind.
1580*da0073e9SAndroid Build Coastguard Worker        # Tradeoff: we probably don't want to have to teach codegen that looks at inplace ops
1581*da0073e9SAndroid Build Coastguard Worker        # to also worry about mutable post_self_positional arguments,
1582*da0073e9SAndroid Build Coastguard Worker        # but it seems like a much bigger lift to classify them has having a new schema kind.
1583*da0073e9SAndroid Build Coastguard Worker        # The number of ops that fit in this strange category is small enough that
1584*da0073e9SAndroid Build Coastguard Worker        # we can probably manually write code for them instead of forcing the codegen to handle them.
1585*da0073e9SAndroid Build Coastguard Worker        if is_inplace:
1586*da0073e9SAndroid Build Coastguard Worker            return SchemaKind.inplace
1587*da0073e9SAndroid Build Coastguard Worker        elif is_scratch:
1588*da0073e9SAndroid Build Coastguard Worker            assert (
1589*da0073e9SAndroid Build Coastguard Worker                is_out
1590*da0073e9SAndroid Build Coastguard Worker            ), "invariant: all scratch operators are expected to be out= operators too"
1591*da0073e9SAndroid Build Coastguard Worker            return SchemaKind.scratch
1592*da0073e9SAndroid Build Coastguard Worker        elif is_out:
1593*da0073e9SAndroid Build Coastguard Worker            assert (
1594*da0073e9SAndroid Build Coastguard Worker                not is_scratch
1595*da0073e9SAndroid Build Coastguard Worker            ), "We should not categorize a scratch op as an out variant. Check if the order of if statements are expected!"
1596*da0073e9SAndroid Build Coastguard Worker            return SchemaKind.out
1597*da0073e9SAndroid Build Coastguard Worker        elif is_mutable:
1598*da0073e9SAndroid Build Coastguard Worker            return SchemaKind.mutable
1599*da0073e9SAndroid Build Coastguard Worker        else:
1600*da0073e9SAndroid Build Coastguard Worker            return SchemaKind.functional
1601*da0073e9SAndroid Build Coastguard Worker
1602*da0073e9SAndroid Build Coastguard Worker    # For every return:
1603*da0073e9SAndroid Build Coastguard Worker    # - If the return aliases an input, we return the input name
1604*da0073e9SAndroid Build Coastguard Worker    # - Otherwise, we return None.
1605*da0073e9SAndroid Build Coastguard Worker    # If return names were enforced to be consistent with aliasing information, then we wouldn't need this.
1606*da0073e9SAndroid Build Coastguard Worker    def aliased_return_names(self) -> list[str | None]:
1607*da0073e9SAndroid Build Coastguard Worker        outs: list[str | None] = []
1608*da0073e9SAndroid Build Coastguard Worker        for r in self.returns:
1609*da0073e9SAndroid Build Coastguard Worker            aliased_args = [
1610*da0073e9SAndroid Build Coastguard Worker                a
1611*da0073e9SAndroid Build Coastguard Worker                for a in self.arguments.flat_all
1612*da0073e9SAndroid Build Coastguard Worker                if a.annotation is not None and a.annotation == r.annotation
1613*da0073e9SAndroid Build Coastguard Worker            ]
1614*da0073e9SAndroid Build Coastguard Worker            if len(aliased_args) == 0:
1615*da0073e9SAndroid Build Coastguard Worker                outs.append(None)
1616*da0073e9SAndroid Build Coastguard Worker            elif len(aliased_args) == 1:
1617*da0073e9SAndroid Build Coastguard Worker                outs.append(aliased_args[0].name)
1618*da0073e9SAndroid Build Coastguard Worker            else:
1619*da0073e9SAndroid Build Coastguard Worker                aliased_names = ", ".join(a.name for a in aliased_args)
1620*da0073e9SAndroid Build Coastguard Worker                raise AssertionError(
1621*da0073e9SAndroid Build Coastguard Worker                    f"Found a return ({r.name})that aliases multiple inputs ({aliased_names})"
1622*da0073e9SAndroid Build Coastguard Worker                )
1623*da0073e9SAndroid Build Coastguard Worker        return outs
1624*da0073e9SAndroid Build Coastguard Worker
1625*da0073e9SAndroid Build Coastguard Worker    def signature(
1626*da0073e9SAndroid Build Coastguard Worker        self,
1627*da0073e9SAndroid Build Coastguard Worker        *,
1628*da0073e9SAndroid Build Coastguard Worker        strip_default: bool = False,
1629*da0073e9SAndroid Build Coastguard Worker        strip_view_copy_name: bool = False,
1630*da0073e9SAndroid Build Coastguard Worker        keep_return_names: bool = False,
1631*da0073e9SAndroid Build Coastguard Worker    ) -> FunctionSchema:
1632*da0073e9SAndroid Build Coastguard Worker        """
1633*da0073e9SAndroid Build Coastguard Worker                Certain schemas are 'related', in that they are simply
1634*da0073e9SAndroid Build Coastguard Worker                inplace/out/functional versions of the same function.  This method
1635*da0073e9SAndroid Build Coastguard Worker                factors these schemas into the "core" functional signature which
1636*da0073e9SAndroid Build Coastguard Worker                is equal across all versions.
1637*da0073e9SAndroid Build Coastguard Worker
1638*da0073e9SAndroid Build Coastguard Worker                Here is what normalization happens to the schema to convert
1639*da0073e9SAndroid Build Coastguard Worker                it to a signature:
1640*da0073e9SAndroid Build Coastguard Worker                - The overload name is stripped (name is retained, since
1641*da0073e9SAndroid Build Coastguard Worker                  it expresses semantic content about what the function does)
1642*da0073e9SAndroid Build Coastguard Worker                - Inplace is set False
1643*da0073e9SAndroid Build Coastguard Worker                - Out arguments are stripped
1644*da0073e9SAndroid Build Coastguard Worker                - Mutable post_self_positional args are converted to returns
1645*da0073e9SAndroid Build Coastguard Worker                - Mutability annotations are stripped  (this is sound
1646*da0073e9SAndroid Build Coastguard Worker                  because you cannot overload on mutability annotation)
1647*da0073e9SAndroid Build Coastguard Worker                - Return names are stripped since they are not overloadable and
1648*da0073e9SAndroid Build Coastguard Worker                  some variants have return names but some not
1649*da0073e9SAndroid Build Coastguard Worker                - TensorOptions are dropped
1650*da0073e9SAndroid Build Coastguard Worker                  because out= variants of factory functions don't include them
1651*da0073e9SAndroid Build Coastguard Worker                  (and we want to be able to pair up factory functions with their out variants)
1652*da0073e9SAndroid Build Coastguard Worker
1653*da0073e9SAndroid Build Coastguard Worker                Finally, we want to be able to pair up related "view" and their
1654*da0073e9SAndroid Build Coastguard Worker                corresponding "view_copy" operators. We do this by optionally
1655*da0073e9SAndroid Build Coastguard Worker                stripping the trailing "_copy" from the base name.
1656*da0073e9SAndroid Build Coastguard Worker
1657*da0073e9SAndroid Build Coastguard Worker                Example of a mutable op before and after:
1658*da0073e9SAndroid Build Coastguard Worker
1659*da0073e9SAndroid Build Coastguard Worker                f.func (Mutable operator):
1660*da0073e9SAndroid Build Coastguard Worker        _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask)  # noqa: B950
1661*da0073e9SAndroid Build Coastguard Worker
1662*da0073e9SAndroid Build Coastguard Worker                f.func (Corresponding functional operator):
1663*da0073e9SAndroid Build Coastguard Worker        _fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out)  # noqa: B950
1664*da0073e9SAndroid Build Coastguard Worker
1665*da0073e9SAndroid Build Coastguard Worker                f.func.signature() output:
1666*da0073e9SAndroid Build Coastguard Worker        _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)  # noqa: B950
1667*da0073e9SAndroid Build Coastguard Worker        """
1668*da0073e9SAndroid Build Coastguard Worker
1669*da0073e9SAndroid Build Coastguard Worker        def strip_ret_annotation(r: Return) -> Return:
1670*da0073e9SAndroid Build Coastguard Worker            return Return(
1671*da0073e9SAndroid Build Coastguard Worker                name=r.name if keep_return_names else None,
1672*da0073e9SAndroid Build Coastguard Worker                type=r.type,
1673*da0073e9SAndroid Build Coastguard Worker                annotation=None,
1674*da0073e9SAndroid Build Coastguard Worker            )
1675*da0073e9SAndroid Build Coastguard Worker
1676*da0073e9SAndroid Build Coastguard Worker        base_name = self.name.name.base
1677*da0073e9SAndroid Build Coastguard Worker        if strip_view_copy_name:
1678*da0073e9SAndroid Build Coastguard Worker            if base_name.endswith("_copy"):
1679*da0073e9SAndroid Build Coastguard Worker                base_name = base_name.replace("_copy", "")
1680*da0073e9SAndroid Build Coastguard Worker            elif base_name.endswith("_scatter"):
1681*da0073e9SAndroid Build Coastguard Worker                base_name = base_name.replace("scatter", "inverse")
1682*da0073e9SAndroid Build Coastguard Worker
1683*da0073e9SAndroid Build Coastguard Worker        # find mutable inputs that are not originally returned, and convert them to returns
1684*da0073e9SAndroid Build Coastguard Worker        returns_from_mutable_inputs = tuple(
1685*da0073e9SAndroid Build Coastguard Worker            # When we're grouping functions we strip the return names,
1686*da0073e9SAndroid Build Coastguard Worker            # but when we're generating the actual functional variants then we follow
1687*da0073e9SAndroid Build Coastguard Worker            # a convention for what to name the returns
1688*da0073e9SAndroid Build Coastguard Worker            Return(
1689*da0073e9SAndroid Build Coastguard Worker                name=f"{a.name}_out" if keep_return_names else None,
1690*da0073e9SAndroid Build Coastguard Worker                type=a.type,
1691*da0073e9SAndroid Build Coastguard Worker                annotation=None,
1692*da0073e9SAndroid Build Coastguard Worker            )
1693*da0073e9SAndroid Build Coastguard Worker            for a in itertools.chain(
1694*da0073e9SAndroid Build Coastguard Worker                # Order is important here (otherwise e.g. inplace with mutable args
1695*da0073e9SAndroid Build Coastguard Worker                # and out= with mutable args won't have the same signature)
1696*da0073e9SAndroid Build Coastguard Worker                [self.arguments.self_arg.argument]
1697*da0073e9SAndroid Build Coastguard Worker                if self.arguments.self_arg is not None
1698*da0073e9SAndroid Build Coastguard Worker                else [],
1699*da0073e9SAndroid Build Coastguard Worker                self.arguments.out,
1700*da0073e9SAndroid Build Coastguard Worker                self.arguments.post_self_positional,
1701*da0073e9SAndroid Build Coastguard Worker            )
1702*da0073e9SAndroid Build Coastguard Worker            if a.annotation is not None
1703*da0073e9SAndroid Build Coastguard Worker            and a.annotation.is_write
1704*da0073e9SAndroid Build Coastguard Worker            and not any(a.annotation == r.annotation for r in self.returns)
1705*da0073e9SAndroid Build Coastguard Worker        )
1706*da0073e9SAndroid Build Coastguard Worker        original_returns = tuple(map(strip_ret_annotation, self.returns))
1707*da0073e9SAndroid Build Coastguard Worker        # Ordering is important here. We expect the "mutable input" returns to come last.
1708*da0073e9SAndroid Build Coastguard Worker        returns = original_returns + returns_from_mutable_inputs
1709*da0073e9SAndroid Build Coastguard Worker
1710*da0073e9SAndroid Build Coastguard Worker        args_sig = self.arguments.signature(strip_default=strip_default)
1711*da0073e9SAndroid Build Coastguard Worker        # See Note [bernoulli.p schema]
1712*da0073e9SAndroid Build Coastguard Worker        if str(self.name) == "bernoulli.p":
1713*da0073e9SAndroid Build Coastguard Worker            args_sig = Arguments.parse(str(args_sig).replace("float p", "float p=0.5"))
1714*da0073e9SAndroid Build Coastguard Worker
1715*da0073e9SAndroid Build Coastguard Worker        return FunctionSchema(
1716*da0073e9SAndroid Build Coastguard Worker            name=OperatorName(
1717*da0073e9SAndroid Build Coastguard Worker                name=BaseOperatorName(
1718*da0073e9SAndroid Build Coastguard Worker                    base=base_name,
1719*da0073e9SAndroid Build Coastguard Worker                    inplace=False,
1720*da0073e9SAndroid Build Coastguard Worker                    dunder_method=self.name.name.dunder_method,
1721*da0073e9SAndroid Build Coastguard Worker                ),
1722*da0073e9SAndroid Build Coastguard Worker                overload_name="",  # stripped
1723*da0073e9SAndroid Build Coastguard Worker            ),
1724*da0073e9SAndroid Build Coastguard Worker            arguments=args_sig,
1725*da0073e9SAndroid Build Coastguard Worker            returns=returns,
1726*da0073e9SAndroid Build Coastguard Worker        )
1727*da0073e9SAndroid Build Coastguard Worker
1728*da0073e9SAndroid Build Coastguard Worker    def view_signature(self) -> FunctionSchema:
1729*da0073e9SAndroid Build Coastguard Worker        return self.signature(strip_view_copy_name=True)
1730*da0073e9SAndroid Build Coastguard Worker
1731*da0073e9SAndroid Build Coastguard Worker    def with_name(self, name: OperatorName) -> FunctionSchema:
1732*da0073e9SAndroid Build Coastguard Worker        return FunctionSchema(
1733*da0073e9SAndroid Build Coastguard Worker            name=name,
1734*da0073e9SAndroid Build Coastguard Worker            arguments=self.arguments,
1735*da0073e9SAndroid Build Coastguard Worker            returns=self.returns,
1736*da0073e9SAndroid Build Coastguard Worker        )
1737*da0073e9SAndroid Build Coastguard Worker
1738*da0073e9SAndroid Build Coastguard Worker    @property
1739*da0073e9SAndroid Build Coastguard Worker    def modifies_arguments(self) -> bool:
1740*da0073e9SAndroid Build Coastguard Worker        return self.kind() in [SchemaKind.inplace, SchemaKind.out, SchemaKind.mutable]
1741*da0073e9SAndroid Build Coastguard Worker
1742*da0073e9SAndroid Build Coastguard Worker    def has_symint(self) -> bool:
1743*da0073e9SAndroid Build Coastguard Worker        return self.arguments.has_symint_arg()
1744*da0073e9SAndroid Build Coastguard Worker
1745*da0073e9SAndroid Build Coastguard Worker    def __str__(self) -> str:
1746*da0073e9SAndroid Build Coastguard Worker        all_arguments_str = str(self.arguments)
1747*da0073e9SAndroid Build Coastguard Worker        if len(self.returns) == 1:
1748*da0073e9SAndroid Build Coastguard Worker            returns = str(self.returns[0])  # omit parentheses
1749*da0073e9SAndroid Build Coastguard Worker        else:
1750*da0073e9SAndroid Build Coastguard Worker            returns = "(" + ", ".join(map(str, self.returns)) + ")"
1751*da0073e9SAndroid Build Coastguard Worker        return f"{self.name}({all_arguments_str}) -> {returns}"
1752*da0073e9SAndroid Build Coastguard Worker
1753*da0073e9SAndroid Build Coastguard Worker
1754*da0073e9SAndroid Build Coastguard Worker# Here is the rest of the data model, described more briefly.
1755*da0073e9SAndroid Build Coastguard Worker
1756*da0073e9SAndroid Build Coastguard Worker
1757*da0073e9SAndroid Build Coastguard Worker# Simplified version for what actually shows up in built-ins.
1758*da0073e9SAndroid Build Coastguard Worker# Look at alias_info.h for expanded syntax.  If you need the structure,
1759*da0073e9SAndroid Build Coastguard Worker# you also need to make this structure recursive so it can be lined
1760*da0073e9SAndroid Build Coastguard Worker# up with the type components too.  For primitives this isn't really
1761*da0073e9SAndroid Build Coastguard Worker# necessary
1762*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
1763*da0073e9SAndroid Build Coastguard Workerclass Annotation:
1764*da0073e9SAndroid Build Coastguard Worker    # Typically only has one element.  Not actually a set so
1765*da0073e9SAndroid Build Coastguard Worker    # we can conveniently assume it is canonically ordered
1766*da0073e9SAndroid Build Coastguard Worker    alias_set: tuple[str, ...]
1767*da0073e9SAndroid Build Coastguard Worker    is_write: bool
1768*da0073e9SAndroid Build Coastguard Worker    alias_set_after: tuple[str, ...]
1769*da0073e9SAndroid Build Coastguard Worker
1770*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1771*da0073e9SAndroid Build Coastguard Worker    def parse(ann: str) -> Annotation:
1772*da0073e9SAndroid Build Coastguard Worker        # TODO: implement a proper parser if this gets more ugly
1773*da0073e9SAndroid Build Coastguard Worker        # Regex Explanation:
1774*da0073e9SAndroid Build Coastguard Worker        # Example: "a! -> a|b"
1775*da0073e9SAndroid Build Coastguard Worker        # Group #1: alias before optional '|', required. Matches the first
1776*da0073e9SAndroid Build Coastguard Worker        #   character 'a' in the example
1777*da0073e9SAndroid Build Coastguard Worker        # Group #2: optional alias set after optional '|', matches empty string
1778*da0073e9SAndroid Build Coastguard Worker        #   in the example
1779*da0073e9SAndroid Build Coastguard Worker        # Group #3: optional "is write" flag, matches '!' in the example.
1780*da0073e9SAndroid Build Coastguard Worker        # Group #4: optional section containing arrow, matches " -> a|b" in the
1781*da0073e9SAndroid Build Coastguard Worker        #   example.
1782*da0073e9SAndroid Build Coastguard Worker        # Group #5: optional alias after set, supports wildcard, matches "a|b"
1783*da0073e9SAndroid Build Coastguard Worker        #   in the example.
1784*da0073e9SAndroid Build Coastguard Worker        # Group #6: optional sub-section of alias after set, matches "|b" in the
1785*da0073e9SAndroid Build Coastguard Worker        #   example.
1786*da0073e9SAndroid Build Coastguard Worker        m = re.match(r"^([a-z])(\|[a-z])*(!?)( -> (\*|[a-z](\|[a-z])*))?$", ann)
1787*da0073e9SAndroid Build Coastguard Worker
1788*da0073e9SAndroid Build Coastguard Worker        assert m is not None, f"unrecognized alias annotation {ann}"
1789*da0073e9SAndroid Build Coastguard Worker        before_alias = m.group(1) + (m.group(2) if m.group(2) else "")
1790*da0073e9SAndroid Build Coastguard Worker        alias_set = tuple(before_alias.split("|"))
1791*da0073e9SAndroid Build Coastguard Worker        is_write = m.group(3) == "!"
1792*da0073e9SAndroid Build Coastguard Worker        assert not (
1793*da0073e9SAndroid Build Coastguard Worker            is_write and len(alias_set) > 1
1794*da0073e9SAndroid Build Coastguard Worker        ), f"alias set larger than 1 is not mutable, got {ann} instead."
1795*da0073e9SAndroid Build Coastguard Worker        after_set = tuple(m.group(5).split("|")) if m.group(5) else ()
1796*da0073e9SAndroid Build Coastguard Worker        assert not (
1797*da0073e9SAndroid Build Coastguard Worker            len(before_alias) > 1 and len(after_set) > 1
1798*da0073e9SAndroid Build Coastguard Worker        ), f"before alias set and after alias set cannot be larger than 1 at the same time, got {ann} instead."
1799*da0073e9SAndroid Build Coastguard Worker        r = Annotation(
1800*da0073e9SAndroid Build Coastguard Worker            alias_set=alias_set, is_write=is_write, alias_set_after=after_set
1801*da0073e9SAndroid Build Coastguard Worker        )
1802*da0073e9SAndroid Build Coastguard Worker        assert str(r) == ann, f"{r} != {ann}"
1803*da0073e9SAndroid Build Coastguard Worker        return r
1804*da0073e9SAndroid Build Coastguard Worker
1805*da0073e9SAndroid Build Coastguard Worker    def __str__(self) -> str:
1806*da0073e9SAndroid Build Coastguard Worker        alias_set = "|".join(self.alias_set)
1807*da0073e9SAndroid Build Coastguard Worker        if self.is_write:
1808*da0073e9SAndroid Build Coastguard Worker            alias_set = f"{alias_set}!"
1809*da0073e9SAndroid Build Coastguard Worker        alias_set_after = "|".join(self.alias_set_after)
1810*da0073e9SAndroid Build Coastguard Worker        if alias_set_after:
1811*da0073e9SAndroid Build Coastguard Worker            alias_set = f'{alias_set}{" -> "}{alias_set_after}'
1812*da0073e9SAndroid Build Coastguard Worker        return alias_set
1813*da0073e9SAndroid Build Coastguard Worker
1814*da0073e9SAndroid Build Coastguard Worker
1815*da0073e9SAndroid Build Coastguard Worker# The base class for the type system.  This is also loosely modeled
1816*da0073e9SAndroid Build Coastguard Worker# off of jit_type.h, but we've simplified the hierarchy to focus
1817*da0073e9SAndroid Build Coastguard Worker# in on the aspects of the type system that matter for code generation
1818*da0073e9SAndroid Build Coastguard Worker# (for example, there's no SingleElementType subclass anymore).
1819*da0073e9SAndroid Build Coastguard Worker# You never actually construct a Type; usually it's going to be one
1820*da0073e9SAndroid Build Coastguard Worker# of the subclasses.  If Python had ADTs this would be one!
1821*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
1822*da0073e9SAndroid Build Coastguard Workerclass Type:
1823*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1824*da0073e9SAndroid Build Coastguard Worker    def parse(t: str) -> Type:
1825*da0073e9SAndroid Build Coastguard Worker        r = Type._parse(t)
1826*da0073e9SAndroid Build Coastguard Worker        assert str(r) == t, f"{r} != {t}"
1827*da0073e9SAndroid Build Coastguard Worker        return r
1828*da0073e9SAndroid Build Coastguard Worker
1829*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1830*da0073e9SAndroid Build Coastguard Worker    def _parse(t: str) -> Type:
1831*da0073e9SAndroid Build Coastguard Worker        m = re.match(r"^(.+)\?$", t)
1832*da0073e9SAndroid Build Coastguard Worker        if m is not None:
1833*da0073e9SAndroid Build Coastguard Worker            return OptionalType(Type.parse(m.group(1)))
1834*da0073e9SAndroid Build Coastguard Worker        m = re.match(r"^(.+)\[([0-9]+)?\]$", t)
1835*da0073e9SAndroid Build Coastguard Worker        if m is not None:
1836*da0073e9SAndroid Build Coastguard Worker            size = int(m.group(2)) if m.group(2) is not None else None
1837*da0073e9SAndroid Build Coastguard Worker            return ListType(elem=Type.parse(m.group(1)), size=size)
1838*da0073e9SAndroid Build Coastguard Worker
1839*da0073e9SAndroid Build Coastguard Worker        # '__torch__.torch.classes.' is the prefix for custom class
1840*da0073e9SAndroid Build Coastguard Worker        m = re.match(r"^__torch__\.torch\.classes\.([a-zA-Z0-9_.]+)$", t)
1841*da0073e9SAndroid Build Coastguard Worker        if m is not None:
1842*da0073e9SAndroid Build Coastguard Worker            return CustomClassType(m.group(1))
1843*da0073e9SAndroid Build Coastguard Worker        try:
1844*da0073e9SAndroid Build Coastguard Worker            return BaseType(BaseTy[t])
1845*da0073e9SAndroid Build Coastguard Worker        except KeyError as e:
1846*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(f"unrecognized type {t}") from e
1847*da0073e9SAndroid Build Coastguard Worker
1848*da0073e9SAndroid Build Coastguard Worker    def __str__(self) -> str:
1849*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
1850*da0073e9SAndroid Build Coastguard Worker
1851*da0073e9SAndroid Build Coastguard Worker    # WARNING: These concepts are not very well-defined.  For example,
1852*da0073e9SAndroid Build Coastguard Worker    # is "int?" nullable? How about "int?[]".  They are defined
1853*da0073e9SAndroid Build Coastguard Worker    # so we can conveniently generate legacy Declarations.yaml but
1854*da0073e9SAndroid Build Coastguard Worker    # really we should probably just remove these at some point
1855*da0073e9SAndroid Build Coastguard Worker
1856*da0073e9SAndroid Build Coastguard Worker    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
1857*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
1858*da0073e9SAndroid Build Coastguard Worker
1859*da0073e9SAndroid Build Coastguard Worker    def is_tensor_like(self) -> bool:
1860*da0073e9SAndroid Build Coastguard Worker        return self.is_base_ty_like(BaseTy.Tensor)
1861*da0073e9SAndroid Build Coastguard Worker
1862*da0073e9SAndroid Build Coastguard Worker    def is_generator_like(self) -> bool:
1863*da0073e9SAndroid Build Coastguard Worker        return self.is_base_ty_like(BaseTy.Generator)
1864*da0073e9SAndroid Build Coastguard Worker
1865*da0073e9SAndroid Build Coastguard Worker    def is_symint_like(self) -> bool:
1866*da0073e9SAndroid Build Coastguard Worker        return self.is_base_ty_like(BaseTy.SymInt)
1867*da0073e9SAndroid Build Coastguard Worker
1868*da0073e9SAndroid Build Coastguard Worker    def is_nullable(self) -> bool:
1869*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
1870*da0073e9SAndroid Build Coastguard Worker
1871*da0073e9SAndroid Build Coastguard Worker    def is_list_like(self) -> ListType | None:
1872*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
1873*da0073e9SAndroid Build Coastguard Worker
1874*da0073e9SAndroid Build Coastguard Worker
1875*da0073e9SAndroid Build Coastguard Worker# Base types are simple, atomic types with no further structure
1876*da0073e9SAndroid Build Coastguard Workerclass BaseTy(Enum):
1877*da0073e9SAndroid Build Coastguard Worker    Generator = auto()
1878*da0073e9SAndroid Build Coastguard Worker    ScalarType = auto()
1879*da0073e9SAndroid Build Coastguard Worker    Tensor = auto()
1880*da0073e9SAndroid Build Coastguard Worker    int = auto()
1881*da0073e9SAndroid Build Coastguard Worker    Dimname = auto()
1882*da0073e9SAndroid Build Coastguard Worker    DimVector = auto()
1883*da0073e9SAndroid Build Coastguard Worker    float = auto()
1884*da0073e9SAndroid Build Coastguard Worker    str = auto()
1885*da0073e9SAndroid Build Coastguard Worker    bool = auto()
1886*da0073e9SAndroid Build Coastguard Worker    Layout = auto()
1887*da0073e9SAndroid Build Coastguard Worker    Device = auto()
1888*da0073e9SAndroid Build Coastguard Worker    DeviceIndex = auto()
1889*da0073e9SAndroid Build Coastguard Worker    Scalar = auto()
1890*da0073e9SAndroid Build Coastguard Worker    MemoryFormat = auto()
1891*da0073e9SAndroid Build Coastguard Worker    QScheme = auto()
1892*da0073e9SAndroid Build Coastguard Worker    Storage = auto()
1893*da0073e9SAndroid Build Coastguard Worker    Stream = auto()
1894*da0073e9SAndroid Build Coastguard Worker    SymInt = auto()
1895*da0073e9SAndroid Build Coastguard Worker    SymBool = auto()
1896*da0073e9SAndroid Build Coastguard Worker    ConstQuantizerPtr = auto()  # TODO: rename
1897*da0073e9SAndroid Build Coastguard Worker    GraphModule = auto()
1898*da0073e9SAndroid Build Coastguard Worker
1899*da0073e9SAndroid Build Coastguard Worker
1900*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
1901*da0073e9SAndroid Build Coastguard Workerclass BaseType(Type):
1902*da0073e9SAndroid Build Coastguard Worker    name: BaseTy
1903*da0073e9SAndroid Build Coastguard Worker
1904*da0073e9SAndroid Build Coastguard Worker    def __str__(self) -> str:
1905*da0073e9SAndroid Build Coastguard Worker        return f"{self.name.name}"
1906*da0073e9SAndroid Build Coastguard Worker
1907*da0073e9SAndroid Build Coastguard Worker    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
1908*da0073e9SAndroid Build Coastguard Worker        return self.name == base_ty
1909*da0073e9SAndroid Build Coastguard Worker
1910*da0073e9SAndroid Build Coastguard Worker    def is_nullable(self) -> bool:
1911*da0073e9SAndroid Build Coastguard Worker        return False
1912*da0073e9SAndroid Build Coastguard Worker
1913*da0073e9SAndroid Build Coastguard Worker    def is_list_like(self) -> ListType | None:
1914*da0073e9SAndroid Build Coastguard Worker        return None
1915*da0073e9SAndroid Build Coastguard Worker
1916*da0073e9SAndroid Build Coastguard Worker    def is_symint_like(self) -> bool:
1917*da0073e9SAndroid Build Coastguard Worker        return self.name == BaseTy.SymInt
1918*da0073e9SAndroid Build Coastguard Worker
1919*da0073e9SAndroid Build Coastguard Worker
1920*da0073e9SAndroid Build Coastguard Worker# Optional types may be specified, or may also be validly given None
1921*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
1922*da0073e9SAndroid Build Coastguard Workerclass OptionalType(Type):
1923*da0073e9SAndroid Build Coastguard Worker    elem: Type
1924*da0073e9SAndroid Build Coastguard Worker
1925*da0073e9SAndroid Build Coastguard Worker    def __str__(self) -> str:
1926*da0073e9SAndroid Build Coastguard Worker        return f"{self.elem}?"
1927*da0073e9SAndroid Build Coastguard Worker
1928*da0073e9SAndroid Build Coastguard Worker    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
1929*da0073e9SAndroid Build Coastguard Worker        return self.elem.is_base_ty_like(base_ty)
1930*da0073e9SAndroid Build Coastguard Worker
1931*da0073e9SAndroid Build Coastguard Worker    def is_symint_like(self) -> bool:
1932*da0073e9SAndroid Build Coastguard Worker        return self.elem.is_symint_like()
1933*da0073e9SAndroid Build Coastguard Worker
1934*da0073e9SAndroid Build Coastguard Worker    def is_nullable(self) -> bool:
1935*da0073e9SAndroid Build Coastguard Worker        return True
1936*da0073e9SAndroid Build Coastguard Worker
1937*da0073e9SAndroid Build Coastguard Worker    def is_list_like(self) -> ListType | None:
1938*da0073e9SAndroid Build Coastguard Worker        return self.elem.is_list_like()
1939*da0073e9SAndroid Build Coastguard Worker
1940*da0073e9SAndroid Build Coastguard Worker
1941*da0073e9SAndroid Build Coastguard Worker# A type representing a PyTorch custom class
1942*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
1943*da0073e9SAndroid Build Coastguard Workerclass CustomClassType(Type):
1944*da0073e9SAndroid Build Coastguard Worker    class_name: str
1945*da0073e9SAndroid Build Coastguard Worker
1946*da0073e9SAndroid Build Coastguard Worker    def __str__(self) -> str:
1947*da0073e9SAndroid Build Coastguard Worker        """
1948*da0073e9SAndroid Build Coastguard Worker        Return the class name will prefix __torch__.torch.classes
1949*da0073e9SAndroid Build Coastguard Worker        """
1950*da0073e9SAndroid Build Coastguard Worker        return f"__torch__.torch.classes.{self.class_name}"
1951*da0073e9SAndroid Build Coastguard Worker
1952*da0073e9SAndroid Build Coastguard Worker    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
1953*da0073e9SAndroid Build Coastguard Worker        return False
1954*da0073e9SAndroid Build Coastguard Worker
1955*da0073e9SAndroid Build Coastguard Worker    def is_symint_like(self) -> bool:
1956*da0073e9SAndroid Build Coastguard Worker        return False
1957*da0073e9SAndroid Build Coastguard Worker
1958*da0073e9SAndroid Build Coastguard Worker    def is_nullable(self) -> bool:
1959*da0073e9SAndroid Build Coastguard Worker        """
1960*da0073e9SAndroid Build Coastguard Worker        Assume a custom class is not nullable.
1961*da0073e9SAndroid Build Coastguard Worker        """
1962*da0073e9SAndroid Build Coastguard Worker        return False
1963*da0073e9SAndroid Build Coastguard Worker
1964*da0073e9SAndroid Build Coastguard Worker    def is_list_like(self) -> ListType | None:
1965*da0073e9SAndroid Build Coastguard Worker        return None
1966*da0073e9SAndroid Build Coastguard Worker
1967*da0073e9SAndroid Build Coastguard Worker
1968*da0073e9SAndroid Build Coastguard Worker# List types specify that we may have multiples of an element.  We
1969*da0073e9SAndroid Build Coastguard Worker# also support explicit sizes on list types, but these have
1970*da0073e9SAndroid Build Coastguard Worker# some nontrivial semantics!  (However, for C++ API purposes, explicit
1971*da0073e9SAndroid Build Coastguard Worker# sizes are mostly erased from the type system.)
1972*da0073e9SAndroid Build Coastguard Worker#
1973*da0073e9SAndroid Build Coastguard Worker# DANGER WILL ROBINSON: C++ elaboration depends on elem type; e.g.,
1974*da0073e9SAndroid Build Coastguard Worker# int[] elaborates differently than bool[3]!
1975*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
1976*da0073e9SAndroid Build Coastguard Workerclass ListType(Type):
1977*da0073e9SAndroid Build Coastguard Worker    elem: Type
1978*da0073e9SAndroid Build Coastguard Worker    size: int | None
1979*da0073e9SAndroid Build Coastguard Worker
1980*da0073e9SAndroid Build Coastguard Worker    def __str__(self) -> str:
1981*da0073e9SAndroid Build Coastguard Worker        size = f"{self.size}" if self.size else ""
1982*da0073e9SAndroid Build Coastguard Worker        return f"{self.elem}[{size}]"
1983*da0073e9SAndroid Build Coastguard Worker
1984*da0073e9SAndroid Build Coastguard Worker    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
1985*da0073e9SAndroid Build Coastguard Worker        return self.elem.is_base_ty_like(base_ty)
1986*da0073e9SAndroid Build Coastguard Worker
1987*da0073e9SAndroid Build Coastguard Worker    def is_symint_like(self) -> bool:
1988*da0073e9SAndroid Build Coastguard Worker        return self.elem.is_symint_like()
1989*da0073e9SAndroid Build Coastguard Worker
1990*da0073e9SAndroid Build Coastguard Worker    def is_nullable(self) -> bool:
1991*da0073e9SAndroid Build Coastguard Worker        return self.elem.is_nullable()
1992*da0073e9SAndroid Build Coastguard Worker
1993*da0073e9SAndroid Build Coastguard Worker    def is_list_like(self) -> ListType | None:
1994*da0073e9SAndroid Build Coastguard Worker        return self
1995*da0073e9SAndroid Build Coastguard Worker
1996*da0073e9SAndroid Build Coastguard Worker
1997*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
1998*da0073e9SAndroid Build Coastguard Workerclass Argument:
1999*da0073e9SAndroid Build Coastguard Worker    # NB: I didn't put kwarg_only as a boolean field here, unlike
2000*da0073e9SAndroid Build Coastguard Worker    # c10::Argument, so that printing works correctly
2001*da0073e9SAndroid Build Coastguard Worker
2002*da0073e9SAndroid Build Coastguard Worker    name: str
2003*da0073e9SAndroid Build Coastguard Worker    type: Type
2004*da0073e9SAndroid Build Coastguard Worker    default: str | None
2005*da0073e9SAndroid Build Coastguard Worker
2006*da0073e9SAndroid Build Coastguard Worker    # The semantics of the annotation field are a little strange.
2007*da0073e9SAndroid Build Coastguard Worker    #
2008*da0073e9SAndroid Build Coastguard Worker    # Alias annotations parametrize Tensors (since Tensors are the only things
2009*da0073e9SAndroid Build Coastguard Worker    # that can alias.)  This motivates why I write Tensor(a!)?  (and not, for
2010*da0073e9SAndroid Build Coastguard Worker    # example, Tensor?(a!)), because the (a!) describes aliasing on the tensor,
2011*da0073e9SAndroid Build Coastguard Worker    # which may be optional (i.e., the alias annotation should bind first to
2012*da0073e9SAndroid Build Coastguard Worker    # Tensor, before the optional postfix annotation).
2013*da0073e9SAndroid Build Coastguard Worker    #
2014*da0073e9SAndroid Build Coastguard Worker    # However, despite being a property of Tensor, we (and c10::Argument)
2015*da0073e9SAndroid Build Coastguard Worker    # store the annotation at the top level of the Argument, rather than
2016*da0073e9SAndroid Build Coastguard Worker    # inside the embedded Tensor type.  In the C++ version of this
2017*da0073e9SAndroid Build Coastguard Worker    # class, we then go through great lengths to mimic the type
2018*da0073e9SAndroid Build Coastguard Worker    # structure in the annotation structure so we can correlate
2019*da0073e9SAndroid Build Coastguard Worker    # annotations with types.
2020*da0073e9SAndroid Build Coastguard Worker    #
2021*da0073e9SAndroid Build Coastguard Worker    # Now, it turns out, in all applications in code generation, the
2022*da0073e9SAndroid Build Coastguard Worker    # structure of annotated types is very simple.  So we just hard
2023*da0073e9SAndroid Build Coastguard Worker    # code it here.  But if we ever do get anything more complex, this
2024*da0073e9SAndroid Build Coastguard Worker    # model will have to change!
2025*da0073e9SAndroid Build Coastguard Worker    annotation: Annotation | None
2026*da0073e9SAndroid Build Coastguard Worker
2027*da0073e9SAndroid Build Coastguard Worker    @property
2028*da0073e9SAndroid Build Coastguard Worker    def alias_info(self) -> Annotation | None:
2029*da0073e9SAndroid Build Coastguard Worker        return self.annotation
2030*da0073e9SAndroid Build Coastguard Worker
2031*da0073e9SAndroid Build Coastguard Worker    @staticmethod
2032*da0073e9SAndroid Build Coastguard Worker    def parse(arg: str) -> Argument:
2033*da0073e9SAndroid Build Coastguard Worker        name: str
2034*da0073e9SAndroid Build Coastguard Worker        default: str | None
2035*da0073e9SAndroid Build Coastguard Worker        assert " " in arg, f"illegal argument '{arg}'"
2036*da0073e9SAndroid Build Coastguard Worker        if "=" in arg:
2037*da0073e9SAndroid Build Coastguard Worker            assert arg.count("=") == 1, f"illegal argument with default value: '{arg}'"
2038*da0073e9SAndroid Build Coastguard Worker            type_and_annot_and_name, default = arg.split("=")
2039*da0073e9SAndroid Build Coastguard Worker            type_and_annot, name = type_and_annot_and_name.rsplit(" ", 1)
2040*da0073e9SAndroid Build Coastguard Worker            name_and_default = f"{name}={default}"
2041*da0073e9SAndroid Build Coastguard Worker        else:
2042*da0073e9SAndroid Build Coastguard Worker            type_and_annot, name_and_default = arg.rsplit(" ", 1)
2043*da0073e9SAndroid Build Coastguard Worker            name = name_and_default
2044*da0073e9SAndroid Build Coastguard Worker            default = None
2045*da0073e9SAndroid Build Coastguard Worker        # TODO: deduplicate annotation matching with Return
2046*da0073e9SAndroid Build Coastguard Worker        match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
2047*da0073e9SAndroid Build Coastguard Worker        annotation: Annotation | None
2048*da0073e9SAndroid Build Coastguard Worker        if match:
2049*da0073e9SAndroid Build Coastguard Worker            # If you update this, make sure the __str__ still works too
2050*da0073e9SAndroid Build Coastguard Worker            assert match.group(2) in [
2051*da0073e9SAndroid Build Coastguard Worker                "",
2052*da0073e9SAndroid Build Coastguard Worker                "?",
2053*da0073e9SAndroid Build Coastguard Worker                "[]",
2054*da0073e9SAndroid Build Coastguard Worker            ], "unrecognized alias analysis form with Tensor"
2055*da0073e9SAndroid Build Coastguard Worker            type_s = "Tensor" + match.group(2)
2056*da0073e9SAndroid Build Coastguard Worker            annotation = Annotation.parse(match.group(1))
2057*da0073e9SAndroid Build Coastguard Worker        else:
2058*da0073e9SAndroid Build Coastguard Worker            type_s = type_and_annot
2059*da0073e9SAndroid Build Coastguard Worker            annotation = None
2060*da0073e9SAndroid Build Coastguard Worker        type = Type.parse(type_s)
2061*da0073e9SAndroid Build Coastguard Worker        r = Argument(
2062*da0073e9SAndroid Build Coastguard Worker            name=name,
2063*da0073e9SAndroid Build Coastguard Worker            type=type,
2064*da0073e9SAndroid Build Coastguard Worker            default=default,
2065*da0073e9SAndroid Build Coastguard Worker            annotation=annotation,
2066*da0073e9SAndroid Build Coastguard Worker        )
2067*da0073e9SAndroid Build Coastguard Worker        assert str(r) == arg, f"{str(r)} != {arg}"
2068*da0073e9SAndroid Build Coastguard Worker        return r
2069*da0073e9SAndroid Build Coastguard Worker
2070*da0073e9SAndroid Build Coastguard Worker    @property
2071*da0073e9SAndroid Build Coastguard Worker    def is_write(self) -> bool:
2072*da0073e9SAndroid Build Coastguard Worker        return self.annotation is not None and self.annotation.is_write
2073*da0073e9SAndroid Build Coastguard Worker
2074*da0073e9SAndroid Build Coastguard Worker    def __str__(self) -> str:
2075*da0073e9SAndroid Build Coastguard Worker        type = f"{self.type}"
2076*da0073e9SAndroid Build Coastguard Worker        if self.annotation:
2077*da0073e9SAndroid Build Coastguard Worker            assert type in ["Tensor", "Tensor?", "Tensor[]"]
2078*da0073e9SAndroid Build Coastguard Worker            type = type.replace("Tensor", f"Tensor({self.annotation})")
2079*da0073e9SAndroid Build Coastguard Worker        if self.name is None:
2080*da0073e9SAndroid Build Coastguard Worker            return type
2081*da0073e9SAndroid Build Coastguard Worker        else:
2082*da0073e9SAndroid Build Coastguard Worker            mb_default = ""
2083*da0073e9SAndroid Build Coastguard Worker            if self.default:
2084*da0073e9SAndroid Build Coastguard Worker                mb_default = f"={self.default}"
2085*da0073e9SAndroid Build Coastguard Worker            return f"{type} {self.name}{mb_default}"
2086*da0073e9SAndroid Build Coastguard Worker
2087*da0073e9SAndroid Build Coastguard Worker
2088*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
2089*da0073e9SAndroid Build Coastguard Workerclass Return:
2090*da0073e9SAndroid Build Coastguard Worker    name: str | None
2091*da0073e9SAndroid Build Coastguard Worker    type: Type
2092*da0073e9SAndroid Build Coastguard Worker    annotation: Annotation | None
2093*da0073e9SAndroid Build Coastguard Worker
2094*da0073e9SAndroid Build Coastguard Worker    @property
2095*da0073e9SAndroid Build Coastguard Worker    def alias_info(self) -> Annotation | None:
2096*da0073e9SAndroid Build Coastguard Worker        return self.annotation
2097*da0073e9SAndroid Build Coastguard Worker
2098*da0073e9SAndroid Build Coastguard Worker    @staticmethod
2099*da0073e9SAndroid Build Coastguard Worker    def parse(arg: str) -> Return:
2100*da0073e9SAndroid Build Coastguard Worker        name: str | None
2101*da0073e9SAndroid Build Coastguard Worker        if " " in arg:
2102*da0073e9SAndroid Build Coastguard Worker            type_and_annot, name = arg.rsplit(" ", 1)
2103*da0073e9SAndroid Build Coastguard Worker        else:
2104*da0073e9SAndroid Build Coastguard Worker            type_and_annot = arg
2105*da0073e9SAndroid Build Coastguard Worker            name = None
2106*da0073e9SAndroid Build Coastguard Worker        match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
2107*da0073e9SAndroid Build Coastguard Worker        annotation: Annotation | None
2108*da0073e9SAndroid Build Coastguard Worker        if match:
2109*da0073e9SAndroid Build Coastguard Worker            # If you update this, make sure the __str__ still works too
2110*da0073e9SAndroid Build Coastguard Worker            assert match.group(2) in [
2111*da0073e9SAndroid Build Coastguard Worker                "",
2112*da0073e9SAndroid Build Coastguard Worker                "?",
2113*da0073e9SAndroid Build Coastguard Worker                "[]",
2114*da0073e9SAndroid Build Coastguard Worker            ], "unrecognized alias analysis form with Tensor"
2115*da0073e9SAndroid Build Coastguard Worker            type_s = "Tensor" + match.group(2)
2116*da0073e9SAndroid Build Coastguard Worker            annotation = Annotation.parse(match.group(1))
2117*da0073e9SAndroid Build Coastguard Worker        else:
2118*da0073e9SAndroid Build Coastguard Worker            type_s = type_and_annot
2119*da0073e9SAndroid Build Coastguard Worker            annotation = None
2120*da0073e9SAndroid Build Coastguard Worker        type = Type.parse(type_s)
2121*da0073e9SAndroid Build Coastguard Worker        r = Return(
2122*da0073e9SAndroid Build Coastguard Worker            name=name,
2123*da0073e9SAndroid Build Coastguard Worker            type=type,
2124*da0073e9SAndroid Build Coastguard Worker            annotation=annotation,
2125*da0073e9SAndroid Build Coastguard Worker        )
2126*da0073e9SAndroid Build Coastguard Worker        assert str(r) == arg, f"{str(r)} != {arg}"
2127*da0073e9SAndroid Build Coastguard Worker        return r
2128*da0073e9SAndroid Build Coastguard Worker
2129*da0073e9SAndroid Build Coastguard Worker    @property
2130*da0073e9SAndroid Build Coastguard Worker    def is_write(self) -> bool:
2131*da0073e9SAndroid Build Coastguard Worker        return self.annotation is not None and self.annotation.is_write
2132*da0073e9SAndroid Build Coastguard Worker
2133*da0073e9SAndroid Build Coastguard Worker    def __str__(self) -> str:
2134*da0073e9SAndroid Build Coastguard Worker        type = f"{self.type}"
2135*da0073e9SAndroid Build Coastguard Worker        if self.annotation:
2136*da0073e9SAndroid Build Coastguard Worker            assert type in ["Tensor", "Tensor?", "Tensor[]"]
2137*da0073e9SAndroid Build Coastguard Worker            type = type.replace("Tensor", f"Tensor({self.annotation})")
2138*da0073e9SAndroid Build Coastguard Worker        if self.name is None:
2139*da0073e9SAndroid Build Coastguard Worker            return type
2140*da0073e9SAndroid Build Coastguard Worker        else:
2141*da0073e9SAndroid Build Coastguard Worker            return f"{type} {self.name}"
2142*da0073e9SAndroid Build Coastguard Worker
2143*da0073e9SAndroid Build Coastguard Worker
2144*da0073e9SAndroid Build Coastguard Worker# Represents the self argument for functions that may be methods
2145*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
2146*da0073e9SAndroid Build Coastguard Workerclass SelfArgument:
2147*da0073e9SAndroid Build Coastguard Worker    argument: Argument
2148*da0073e9SAndroid Build Coastguard Worker
2149*da0073e9SAndroid Build Coastguard Worker
2150*da0073e9SAndroid Build Coastguard Worker# Bundle of arguments that represent a TensorOptions.  This is mostly
2151*da0073e9SAndroid Build Coastguard Worker# relevant for the public C++ API but we bake it into the core data
2152*da0073e9SAndroid Build Coastguard Worker# model because other APIs often have to interact with it
2153*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
2154*da0073e9SAndroid Build Coastguard Workerclass TensorOptionsArguments:
2155*da0073e9SAndroid Build Coastguard Worker    dtype: Argument
2156*da0073e9SAndroid Build Coastguard Worker    layout: Argument
2157*da0073e9SAndroid Build Coastguard Worker    device: Argument
2158*da0073e9SAndroid Build Coastguard Worker    pin_memory: Argument
2159*da0073e9SAndroid Build Coastguard Worker
2160*da0073e9SAndroid Build Coastguard Worker    def all(self) -> Sequence[Argument]:
2161*da0073e9SAndroid Build Coastguard Worker        return [self.dtype, self.layout, self.device, self.pin_memory]
2162*da0073e9SAndroid Build Coastguard Worker
2163*da0073e9SAndroid Build Coastguard Worker
2164*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True)
2165*da0073e9SAndroid Build Coastguard Workerclass Arguments:
2166*da0073e9SAndroid Build Coastguard Worker    # pre_self_positional is usually empty, but is notably non-empty
2167*da0073e9SAndroid Build Coastguard Worker    # for where.self, where the condition argument comes before the
2168*da0073e9SAndroid Build Coastguard Worker    # self argument
2169*da0073e9SAndroid Build Coastguard Worker    pre_self_positional: tuple[Argument, ...]
2170*da0073e9SAndroid Build Coastguard Worker    self_arg: SelfArgument | None
2171*da0073e9SAndroid Build Coastguard Worker    post_self_positional: tuple[Argument, ...]
2172*da0073e9SAndroid Build Coastguard Worker
2173*da0073e9SAndroid Build Coastguard Worker    pre_tensor_options_kwarg_only: tuple[Argument, ...]
2174*da0073e9SAndroid Build Coastguard Worker    tensor_options: TensorOptionsArguments | None
2175*da0073e9SAndroid Build Coastguard Worker    # post_tensor_options is typically memory format, which should be
2176*da0073e9SAndroid Build Coastguard Worker    # part of tensor options but isn't right now, and is usually
2177*da0073e9SAndroid Build Coastguard Worker    # placed after the tensor options arguments
2178*da0073e9SAndroid Build Coastguard Worker    post_tensor_options_kwarg_only: tuple[Argument, ...]
2179*da0073e9SAndroid Build Coastguard Worker
2180*da0073e9SAndroid Build Coastguard Worker    # Unlike in the previous codegen, we have factored out 'out' arguments
2181*da0073e9SAndroid Build Coastguard Worker    # in the canonical representation, removing them from kwarg
2182*da0073e9SAndroid Build Coastguard Worker    # arguments.  This choice is justified by numerous downstream
2183*da0073e9SAndroid Build Coastguard Worker    # transformations which treat out arguments specially; additionally,
2184*da0073e9SAndroid Build Coastguard Worker    # you can see that canonicity is not violated!
2185*da0073e9SAndroid Build Coastguard Worker    out: tuple[Argument, ...]  # these are also kwarg-only
2186*da0073e9SAndroid Build Coastguard Worker
2187*da0073e9SAndroid Build Coastguard Worker    @property
2188*da0073e9SAndroid Build Coastguard Worker    def flat_non_out(self) -> Sequence[Argument]:
2189*da0073e9SAndroid Build Coastguard Worker        ret: list[Argument] = []
2190*da0073e9SAndroid Build Coastguard Worker        ret.extend(self.flat_positional)
2191*da0073e9SAndroid Build Coastguard Worker        ret.extend(self.flat_kwarg_only)
2192*da0073e9SAndroid Build Coastguard Worker        return ret
2193*da0073e9SAndroid Build Coastguard Worker
2194*da0073e9SAndroid Build Coastguard Worker    @property
2195*da0073e9SAndroid Build Coastguard Worker    def flat_positional(self) -> Sequence[Argument]:
2196*da0073e9SAndroid Build Coastguard Worker        ret: list[Argument] = []
2197*da0073e9SAndroid Build Coastguard Worker        ret.extend(self.pre_self_positional)
2198*da0073e9SAndroid Build Coastguard Worker        if self.self_arg is not None:
2199*da0073e9SAndroid Build Coastguard Worker            ret.append(self.self_arg.argument)
2200*da0073e9SAndroid Build Coastguard Worker        ret.extend(self.post_self_positional)
2201*da0073e9SAndroid Build Coastguard Worker        return ret
2202*da0073e9SAndroid Build Coastguard Worker
2203*da0073e9SAndroid Build Coastguard Worker    @property
2204*da0073e9SAndroid Build Coastguard Worker    def post_self_positional_mutable(self) -> Sequence[Argument]:
2205*da0073e9SAndroid Build Coastguard Worker        return [a for a in self.post_self_positional if a.is_write]
2206*da0073e9SAndroid Build Coastguard Worker
2207*da0073e9SAndroid Build Coastguard Worker    # NB: doesn't contain out arguments
2208*da0073e9SAndroid Build Coastguard Worker    @property
2209*da0073e9SAndroid Build Coastguard Worker    def flat_kwarg_only(self) -> Sequence[Argument]:
2210*da0073e9SAndroid Build Coastguard Worker        ret: list[Argument] = []
2211*da0073e9SAndroid Build Coastguard Worker        ret.extend(self.pre_tensor_options_kwarg_only)
2212*da0073e9SAndroid Build Coastguard Worker        if self.tensor_options is not None:
2213*da0073e9SAndroid Build Coastguard Worker            ret.extend(self.tensor_options.all())
2214*da0073e9SAndroid Build Coastguard Worker        ret.extend(self.post_tensor_options_kwarg_only)
2215*da0073e9SAndroid Build Coastguard Worker        return ret
2216*da0073e9SAndroid Build Coastguard Worker
2217*da0073e9SAndroid Build Coastguard Worker    @property
2218*da0073e9SAndroid Build Coastguard Worker    def flat_all(self) -> Sequence[Argument]:
2219*da0073e9SAndroid Build Coastguard Worker        ret: list[Argument] = []
2220*da0073e9SAndroid Build Coastguard Worker        ret.extend(self.flat_positional)
2221*da0073e9SAndroid Build Coastguard Worker        ret.extend(self.flat_kwarg_only)
2222*da0073e9SAndroid Build Coastguard Worker        ret.extend(self.out)
2223*da0073e9SAndroid Build Coastguard Worker        return ret
2224*da0073e9SAndroid Build Coastguard Worker
2225*da0073e9SAndroid Build Coastguard Worker    @property
2226*da0073e9SAndroid Build Coastguard Worker    def non_out(
2227*da0073e9SAndroid Build Coastguard Worker        self,
2228*da0073e9SAndroid Build Coastguard Worker    ) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]:
2229*da0073e9SAndroid Build Coastguard Worker        ret: list[Argument | SelfArgument | TensorOptionsArguments] = []
2230*da0073e9SAndroid Build Coastguard Worker        ret.extend(self.positional)
2231*da0073e9SAndroid Build Coastguard Worker        ret.extend(self.kwarg_only)
2232*da0073e9SAndroid Build Coastguard Worker        return ret
2233*da0073e9SAndroid Build Coastguard Worker
2234*da0073e9SAndroid Build Coastguard Worker    @property
2235*da0073e9SAndroid Build Coastguard Worker    def positional(self) -> Sequence[Argument | SelfArgument]:
2236*da0073e9SAndroid Build Coastguard Worker        ret: list[Argument | SelfArgument] = []
2237*da0073e9SAndroid Build Coastguard Worker        ret.extend(self.pre_self_positional)
2238*da0073e9SAndroid Build Coastguard Worker        if self.self_arg is not None:
2239*da0073e9SAndroid Build Coastguard Worker            ret.append(self.self_arg)
2240*da0073e9SAndroid Build Coastguard Worker        ret.extend(self.post_self_positional)
2241*da0073e9SAndroid Build Coastguard Worker        return ret
2242*da0073e9SAndroid Build Coastguard Worker
2243*da0073e9SAndroid Build Coastguard Worker    @property
2244*da0073e9SAndroid Build Coastguard Worker    def kwarg_only(self) -> Sequence[Argument | TensorOptionsArguments]:
2245*da0073e9SAndroid Build Coastguard Worker        ret: list[Argument | TensorOptionsArguments] = []
2246*da0073e9SAndroid Build Coastguard Worker        ret.extend(self.pre_tensor_options_kwarg_only)
2247*da0073e9SAndroid Build Coastguard Worker        if self.tensor_options is not None:
2248*da0073e9SAndroid Build Coastguard Worker            ret.append(self.tensor_options)
2249*da0073e9SAndroid Build Coastguard Worker        ret.extend(self.post_tensor_options_kwarg_only)
2250*da0073e9SAndroid Build Coastguard Worker        return ret
2251*da0073e9SAndroid Build Coastguard Worker
2252*da0073e9SAndroid Build Coastguard Worker    @property
2253*da0073e9SAndroid Build Coastguard Worker    def all(self) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]:
2254*da0073e9SAndroid Build Coastguard Worker        ret: list[Argument | SelfArgument | TensorOptionsArguments] = []
2255*da0073e9SAndroid Build Coastguard Worker        ret.extend(self.positional)
2256*da0073e9SAndroid Build Coastguard Worker        ret.extend(self.kwarg_only)
2257*da0073e9SAndroid Build Coastguard Worker        ret.extend(self.out)
2258*da0073e9SAndroid Build Coastguard Worker        return ret
2259*da0073e9SAndroid Build Coastguard Worker
2260*da0073e9SAndroid Build Coastguard Worker    def mutable_arg_names(self) -> list[str]:
2261*da0073e9SAndroid Build Coastguard Worker        return [
2262*da0073e9SAndroid Build Coastguard Worker            a.name
2263*da0073e9SAndroid Build Coastguard Worker            for a in self.flat_all
2264*da0073e9SAndroid Build Coastguard Worker            if a.annotation is not None and a.annotation.is_write
2265*da0073e9SAndroid Build Coastguard Worker        ]
2266*da0073e9SAndroid Build Coastguard Worker
2267*da0073e9SAndroid Build Coastguard Worker    def has_tensor_arg(self) -> bool:
2268*da0073e9SAndroid Build Coastguard Worker        return any(a.type.is_tensor_like() for a in self.flat_non_out)
2269*da0073e9SAndroid Build Coastguard Worker
2270*da0073e9SAndroid Build Coastguard Worker    def has_symint_arg(self) -> bool:
2271*da0073e9SAndroid Build Coastguard Worker        return any(a.type.is_symint_like() for a in self.flat_non_out)
2272*da0073e9SAndroid Build Coastguard Worker
2273*da0073e9SAndroid Build Coastguard Worker    def has_generator_arg(self) -> bool:
2274*da0073e9SAndroid Build Coastguard Worker        return any(a.type.is_generator_like() for a in self.flat_non_out)
2275*da0073e9SAndroid Build Coastguard Worker
2276*da0073e9SAndroid Build Coastguard Worker    def signature(self, *, strip_default: bool = False) -> Arguments:
2277*da0073e9SAndroid Build Coastguard Worker        # dataclasses.replace could be used here, but it is less
2278*da0073e9SAndroid Build Coastguard Worker        # type safe so for now I've opted to type everything out
2279*da0073e9SAndroid Build Coastguard Worker        def strip_arg_annotation(a: Argument) -> Argument:
2280*da0073e9SAndroid Build Coastguard Worker            return Argument(
2281*da0073e9SAndroid Build Coastguard Worker                name=a.name,
2282*da0073e9SAndroid Build Coastguard Worker                type=a.type,
2283*da0073e9SAndroid Build Coastguard Worker                default=a.default if not strip_default else None,
2284*da0073e9SAndroid Build Coastguard Worker                annotation=None,
2285*da0073e9SAndroid Build Coastguard Worker            )
2286*da0073e9SAndroid Build Coastguard Worker
2287*da0073e9SAndroid Build Coastguard Worker        return Arguments(
2288*da0073e9SAndroid Build Coastguard Worker            pre_self_positional=tuple(
2289*da0073e9SAndroid Build Coastguard Worker                map(strip_arg_annotation, self.pre_self_positional)
2290*da0073e9SAndroid Build Coastguard Worker            ),
2291*da0073e9SAndroid Build Coastguard Worker            self_arg=SelfArgument(strip_arg_annotation(self.self_arg.argument))
2292*da0073e9SAndroid Build Coastguard Worker            if self.self_arg is not None
2293*da0073e9SAndroid Build Coastguard Worker            else None,
2294*da0073e9SAndroid Build Coastguard Worker            post_self_positional=tuple(
2295*da0073e9SAndroid Build Coastguard Worker                map(strip_arg_annotation, self.post_self_positional)
2296*da0073e9SAndroid Build Coastguard Worker            ),
2297*da0073e9SAndroid Build Coastguard Worker            # Since TensorOptions are dropped, the post_tensor_options_kwargs are
2298*da0073e9SAndroid Build Coastguard Worker            # converted to pre_tensor_options_kwargs
2299*da0073e9SAndroid Build Coastguard Worker            pre_tensor_options_kwarg_only=tuple(
2300*da0073e9SAndroid Build Coastguard Worker                map(strip_arg_annotation, self.pre_tensor_options_kwarg_only)
2301*da0073e9SAndroid Build Coastguard Worker            )
2302*da0073e9SAndroid Build Coastguard Worker            + tuple(map(strip_arg_annotation, self.post_tensor_options_kwarg_only)),
2303*da0073e9SAndroid Build Coastguard Worker            # TensorOptions are dropped in signature,
2304*da0073e9SAndroid Build Coastguard Worker            # so we can pair factory functions with their out= variants.
2305*da0073e9SAndroid Build Coastguard Worker            tensor_options=None,
2306*da0073e9SAndroid Build Coastguard Worker            post_tensor_options_kwarg_only=(),
2307*da0073e9SAndroid Build Coastguard Worker            # out arguments are dropped in signature
2308*da0073e9SAndroid Build Coastguard Worker            out=(),
2309*da0073e9SAndroid Build Coastguard Worker        )
2310*da0073e9SAndroid Build Coastguard Worker
2311*da0073e9SAndroid Build Coastguard Worker    def remove_self_annotation(self) -> Arguments:
2312*da0073e9SAndroid Build Coastguard Worker        assert self.self_arg is not None
2313*da0073e9SAndroid Build Coastguard Worker        return dataclasses.replace(
2314*da0073e9SAndroid Build Coastguard Worker            self,
2315*da0073e9SAndroid Build Coastguard Worker            self_arg=SelfArgument(
2316*da0073e9SAndroid Build Coastguard Worker                dataclasses.replace(self.self_arg.argument, annotation=None)
2317*da0073e9SAndroid Build Coastguard Worker            ),
2318*da0073e9SAndroid Build Coastguard Worker        )
2319*da0073e9SAndroid Build Coastguard Worker
2320*da0073e9SAndroid Build Coastguard Worker    def with_out_args(self, outs: list[Argument]) -> Arguments:
2321*da0073e9SAndroid Build Coastguard Worker        assert len(self.out) == 0
2322*da0073e9SAndroid Build Coastguard Worker        return dataclasses.replace(
2323*da0073e9SAndroid Build Coastguard Worker            self,
2324*da0073e9SAndroid Build Coastguard Worker            out=tuple(outs),
2325*da0073e9SAndroid Build Coastguard Worker        )
2326*da0073e9SAndroid Build Coastguard Worker
2327*da0073e9SAndroid Build Coastguard Worker    @staticmethod
2328*da0073e9SAndroid Build Coastguard Worker    def _preparse(args: str) -> tuple[list[Argument], list[Argument], list[Argument]]:
2329*da0073e9SAndroid Build Coastguard Worker        positional: list[Argument] = []
2330*da0073e9SAndroid Build Coastguard Worker        kwarg_only: list[Argument] = []
2331*da0073e9SAndroid Build Coastguard Worker        out: list[Argument] = []
2332*da0073e9SAndroid Build Coastguard Worker        arguments_acc = positional
2333*da0073e9SAndroid Build Coastguard Worker
2334*da0073e9SAndroid Build Coastguard Worker        # TODO: Use a real parser here; this will get bamboozled
2335*da0073e9SAndroid Build Coastguard Worker        # by signatures that contain things like std::array<bool, 2> (note the space)
2336*da0073e9SAndroid Build Coastguard Worker        for arg in args.split(", "):
2337*da0073e9SAndroid Build Coastguard Worker            if not arg:
2338*da0073e9SAndroid Build Coastguard Worker                continue
2339*da0073e9SAndroid Build Coastguard Worker            if arg == "*":
2340*da0073e9SAndroid Build Coastguard Worker                assert (
2341*da0073e9SAndroid Build Coastguard Worker                    arguments_acc is positional
2342*da0073e9SAndroid Build Coastguard Worker                ), "invalid syntax: kwarg-only specifier * can only occur once"
2343*da0073e9SAndroid Build Coastguard Worker                arguments_acc = kwarg_only
2344*da0073e9SAndroid Build Coastguard Worker                continue
2345*da0073e9SAndroid Build Coastguard Worker            parg = Argument.parse(arg)
2346*da0073e9SAndroid Build Coastguard Worker            # Currently, we rely directly on the invariant that there are NO
2347*da0073e9SAndroid Build Coastguard Worker            # kwarg-only mutating arguments.  If you want to relax this,
2348*da0073e9SAndroid Build Coastguard Worker            # we will need a more semantic way of matching that takes
2349*da0073e9SAndroid Build Coastguard Worker            # into account return arguments.  In that case, you will have
2350*da0073e9SAndroid Build Coastguard Worker            # to manage out computation a level up, in FunctionSchema.  See Note
2351*da0073e9SAndroid Build Coastguard Worker            # [is_out_fn]
2352*da0073e9SAndroid Build Coastguard Worker            if parg.annotation is not None and parg.annotation.is_write:
2353*da0073e9SAndroid Build Coastguard Worker                if arguments_acc is positional:
2354*da0073e9SAndroid Build Coastguard Worker                    pass  # do nothing
2355*da0073e9SAndroid Build Coastguard Worker                elif arguments_acc is kwarg_only:
2356*da0073e9SAndroid Build Coastguard Worker                    arguments_acc = out
2357*da0073e9SAndroid Build Coastguard Worker            else:
2358*da0073e9SAndroid Build Coastguard Worker                assert arguments_acc is not out
2359*da0073e9SAndroid Build Coastguard Worker            arguments_acc.append(parg)
2360*da0073e9SAndroid Build Coastguard Worker
2361*da0073e9SAndroid Build Coastguard Worker        return positional, kwarg_only, out
2362*da0073e9SAndroid Build Coastguard Worker
2363*da0073e9SAndroid Build Coastguard Worker    @staticmethod
2364*da0073e9SAndroid Build Coastguard Worker    def parse(args: str) -> Arguments:
2365*da0073e9SAndroid Build Coastguard Worker        """
2366*da0073e9SAndroid Build Coastguard Worker        Input: 'int x, int y, int z'
2367*da0073e9SAndroid Build Coastguard Worker        """
2368*da0073e9SAndroid Build Coastguard Worker
2369*da0073e9SAndroid Build Coastguard Worker        # We do this in two phases.  First we parse into three
2370*da0073e9SAndroid Build Coastguard Worker        # main categories: positional, kwarg_only, out.
2371*da0073e9SAndroid Build Coastguard Worker        # Then, we reparse positional and kwarg_only to separate
2372*da0073e9SAndroid Build Coastguard Worker        # out the self argument and tensor options arguments.
2373*da0073e9SAndroid Build Coastguard Worker
2374*da0073e9SAndroid Build Coastguard Worker        positional, kwarg_only, out = Arguments._preparse(args)
2375*da0073e9SAndroid Build Coastguard Worker
2376*da0073e9SAndroid Build Coastguard Worker        # Split self argument
2377*da0073e9SAndroid Build Coastguard Worker        self_ix = None
2378*da0073e9SAndroid Build Coastguard Worker        for i, a in enumerate(positional):
2379*da0073e9SAndroid Build Coastguard Worker            if a.name == "self":
2380*da0073e9SAndroid Build Coastguard Worker                self_ix = i
2381*da0073e9SAndroid Build Coastguard Worker                break
2382*da0073e9SAndroid Build Coastguard Worker        pre_self_positional: list[Argument]
2383*da0073e9SAndroid Build Coastguard Worker        self_arg: SelfArgument | None
2384*da0073e9SAndroid Build Coastguard Worker        post_self_positional: list[Argument]
2385*da0073e9SAndroid Build Coastguard Worker        if self_ix is not None:
2386*da0073e9SAndroid Build Coastguard Worker            pre_self_positional = positional[:self_ix]
2387*da0073e9SAndroid Build Coastguard Worker            self_arg = SelfArgument(positional[self_ix])
2388*da0073e9SAndroid Build Coastguard Worker            post_self_positional = positional[self_ix + 1 :]
2389*da0073e9SAndroid Build Coastguard Worker        else:
2390*da0073e9SAndroid Build Coastguard Worker            pre_self_positional = []
2391*da0073e9SAndroid Build Coastguard Worker            self_arg = None
2392*da0073e9SAndroid Build Coastguard Worker            post_self_positional = positional
2393*da0073e9SAndroid Build Coastguard Worker
2394*da0073e9SAndroid Build Coastguard Worker        # Group tensor options arguments
2395*da0073e9SAndroid Build Coastguard Worker        pre_tensor_options_kwarg_only: list[Argument] = []
2396*da0073e9SAndroid Build Coastguard Worker        tensor_options: TensorOptionsArguments | None = None
2397*da0073e9SAndroid Build Coastguard Worker        post_tensor_options_kwarg_only: list[Argument] = []
2398*da0073e9SAndroid Build Coastguard Worker        kwarg_only_acc = pre_tensor_options_kwarg_only
2399*da0073e9SAndroid Build Coastguard Worker
2400*da0073e9SAndroid Build Coastguard Worker        def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
2401*da0073e9SAndroid Build Coastguard Worker            return lambda a: a.name == name and a.type in [ty, OptionalType(ty)]
2402*da0073e9SAndroid Build Coastguard Worker
2403*da0073e9SAndroid Build Coastguard Worker        predicates = [  # order matters
2404*da0073e9SAndroid Build Coastguard Worker            pred("dtype", Type.parse("ScalarType")),
2405*da0073e9SAndroid Build Coastguard Worker            pred("layout", Type.parse("Layout")),
2406*da0073e9SAndroid Build Coastguard Worker            pred("device", Type.parse("Device")),
2407*da0073e9SAndroid Build Coastguard Worker            pred("pin_memory", Type.parse("bool")),
2408*da0073e9SAndroid Build Coastguard Worker        ]
2409*da0073e9SAndroid Build Coastguard Worker
2410*da0073e9SAndroid Build Coastguard Worker        i = 0
2411*da0073e9SAndroid Build Coastguard Worker        while i < len(kwarg_only):
2412*da0073e9SAndroid Build Coastguard Worker            # If there is enough space...
2413*da0073e9SAndroid Build Coastguard Worker            if i <= len(kwarg_only) - len(predicates):
2414*da0073e9SAndroid Build Coastguard Worker                # And the next len(predicates) arguments look like TensorOptions arguments
2415*da0073e9SAndroid Build Coastguard Worker                if all(
2416*da0073e9SAndroid Build Coastguard Worker                    p(a)
2417*da0073e9SAndroid Build Coastguard Worker                    for p, a in zip(predicates, kwarg_only[i : i + len(predicates)])
2418*da0073e9SAndroid Build Coastguard Worker                ):
2419*da0073e9SAndroid Build Coastguard Worker                    assert kwarg_only_acc is pre_tensor_options_kwarg_only
2420*da0073e9SAndroid Build Coastguard Worker                    # Group them together as one argument
2421*da0073e9SAndroid Build Coastguard Worker                    tensor_options = TensorOptionsArguments(
2422*da0073e9SAndroid Build Coastguard Worker                        dtype=kwarg_only[i],
2423*da0073e9SAndroid Build Coastguard Worker                        layout=kwarg_only[i + 1],
2424*da0073e9SAndroid Build Coastguard Worker                        device=kwarg_only[i + 2],
2425*da0073e9SAndroid Build Coastguard Worker                        pin_memory=kwarg_only[i + 3],
2426*da0073e9SAndroid Build Coastguard Worker                    )
2427*da0073e9SAndroid Build Coastguard Worker                    i += len(predicates)
2428*da0073e9SAndroid Build Coastguard Worker                    kwarg_only_acc = post_tensor_options_kwarg_only
2429*da0073e9SAndroid Build Coastguard Worker                    continue
2430*da0073e9SAndroid Build Coastguard Worker            kwarg_only_acc.append(kwarg_only[i])
2431*da0073e9SAndroid Build Coastguard Worker            i += 1
2432*da0073e9SAndroid Build Coastguard Worker
2433*da0073e9SAndroid Build Coastguard Worker        return Arguments(
2434*da0073e9SAndroid Build Coastguard Worker            pre_self_positional=tuple(pre_self_positional),
2435*da0073e9SAndroid Build Coastguard Worker            self_arg=self_arg,
2436*da0073e9SAndroid Build Coastguard Worker            post_self_positional=tuple(post_self_positional),
2437*da0073e9SAndroid Build Coastguard Worker            pre_tensor_options_kwarg_only=tuple(pre_tensor_options_kwarg_only),
2438*da0073e9SAndroid Build Coastguard Worker            tensor_options=tensor_options,
2439*da0073e9SAndroid Build Coastguard Worker            post_tensor_options_kwarg_only=tuple(post_tensor_options_kwarg_only),
2440*da0073e9SAndroid Build Coastguard Worker            out=tuple(out),
2441*da0073e9SAndroid Build Coastguard Worker        )
2442*da0073e9SAndroid Build Coastguard Worker
2443*da0073e9SAndroid Build Coastguard Worker    def __str__(self) -> str:
2444*da0073e9SAndroid Build Coastguard Worker        all_arguments: list[str] = []
2445*da0073e9SAndroid Build Coastguard Worker        all_arguments.extend(map(str, self.flat_positional))
2446*da0073e9SAndroid Build Coastguard Worker        if self.flat_kwarg_only or self.out:
2447*da0073e9SAndroid Build Coastguard Worker            all_arguments.append("*")
2448*da0073e9SAndroid Build Coastguard Worker        all_arguments.extend(map(str, self.flat_kwarg_only))
2449*da0073e9SAndroid Build Coastguard Worker        all_arguments.extend(map(str, self.out))
2450*da0073e9SAndroid Build Coastguard Worker        return ", ".join(all_arguments)
2451*da0073e9SAndroid Build Coastguard Worker
2452*da0073e9SAndroid Build Coastguard Worker    def __post_init__(self) -> None:
2453*da0073e9SAndroid Build Coastguard Worker        # TODO: These invariants are weirdly asymmetric?
2454*da0073e9SAndroid Build Coastguard Worker        # TODO: Fancier types?
2455*da0073e9SAndroid Build Coastguard Worker        if self.self_arg is None:
2456*da0073e9SAndroid Build Coastguard Worker            assert not self.pre_self_positional
2457*da0073e9SAndroid Build Coastguard Worker        if self.tensor_options is None:
2458*da0073e9SAndroid Build Coastguard Worker            assert not self.post_tensor_options_kwarg_only
2459*da0073e9SAndroid Build Coastguard Worker
2460*da0073e9SAndroid Build Coastguard Worker        # We don't allow any of the following to have argument annotations,
2461*da0073e9SAndroid Build Coastguard Worker        # to keep things simple.
2462*da0073e9SAndroid Build Coastguard Worker        mutable_pre_self_positionals = [
2463*da0073e9SAndroid Build Coastguard Worker            a
2464*da0073e9SAndroid Build Coastguard Worker            for a in self.pre_self_positional
2465*da0073e9SAndroid Build Coastguard Worker            if a.annotation is not None and a.annotation.is_write
2466*da0073e9SAndroid Build Coastguard Worker        ]
2467*da0073e9SAndroid Build Coastguard Worker        assert (
2468*da0073e9SAndroid Build Coastguard Worker            len(mutable_pre_self_positionals) == 0
2469*da0073e9SAndroid Build Coastguard Worker        ), "mutable pre_self_positional arguments are not currently supported in the schema"
2470*da0073e9SAndroid Build Coastguard Worker
2471*da0073e9SAndroid Build Coastguard Worker
# Names whose __iXXX__ dunder form denotes an inplace (augmented assignment)
# operator, e.g. __iadd__.
# Taken from https://www.python.org/dev/peps/pep-0203/#new-methods
# NB: PyTorch hasn't actually implemented all of these
AUGMENTED_ASSIGNMENT_NAMES = (
    "add sub mul div mod pow lshift rshift and xor or"
).split()


# A BaseOperatorName is an operator name without its overload name.  It is
# not stored as a plain string.  Instead it records the semantic bits derived
# from the string:
#   * whether the op is inplace (add_ / __iadd__);
#   * whether it is a double-underscore method (__add__);
#   * whether it is the "_functional" variant described in the note below.
@dataclass(frozen=True)
class BaseOperatorName:
    base: str
    inplace: bool
    dunder_method: bool
    # Note [Overload Ambiguity With Functional Variants]
    # A handful of operators have both a "mutable" and a "functional" variant.
    # (native_batch_norm is a good example, although this isn't the case today).
    # For those operators, the mutable and functional variant take in the same set of
    # arguments, but have different alias annotations.
    # this makes it ambiguous when you try to resolve an OverloadPacket into an overload,
    # given a set of input arguments.
    #
    # So instead of making the "functional" variant in this case a real overload, e.g:
    #   native_batch_norm (mutable variant)
    #   native_batch_norm.functional (functional variant)
    # we make it a new base operator,
    #   native_batch_norm_functional (functional variant)
    #
    # In an ideal world, we would probably invert this so the operators were:
    #   native_batch_norm.mutable (mutable variant)
    #   native_batch_norm (functional variant)
    #
    # Doing that is BC-breaking though, so we're stuck with the above modeling.
    functional_overload: bool = False

    @staticmethod
    def parse(op: str) -> BaseOperatorName:
        """
        Parse an operator name (without overload) into a ``BaseOperatorName``.

        Recognized forms:
          * ``__foo__`` (``foo`` containing no underscores): dunder method.
            ``__ifoo__`` is inplace when ``foo`` is in
            AUGMENTED_ASSIGNMENT_NAMES.
          * ``foo_``: inplace.
          * ``foo_functional``: functional variant.  It may not be combined
            with the dunder or inplace forms.
          * ``foo``: plain.

        Raises AssertionError in any of these cases:
          * the name is empty;
          * the name ends in ``_out``;
          * it is a non-augmented dunder whose base starts with ``i``;
          * the result does not round-trip back to ``op`` via ``str``.
        """
        assert op != ""
        assert not op.endswith("_out"), (
            "_out suffix is reserved and not permitted for operator names; "
            "did you mean to specify an out overload name instead?"
        )

        dunder = re.match(r"^__([^_]+)__$", op)
        if dunder is None:
            dunder_method = False
            inplace = op.endswith("_")
            base = op[:-1] if inplace else op
        else:
            dunder_method = True
            base = dunder.group(1)
            inplace = base in [f"i{n}" for n in AUGMENTED_ASSIGNMENT_NAMES]
            if inplace:
                base = base[1:]
            else:
                # temporary, this is not intrinsically true but
                # has been historically true for dunder methods
                # we support  (but, if we ever got, say, __int__, this would
                # be wrong!)
                assert base[0] != "i"

        # See Note [Overload Ambiguity With Functional Variants]
        functional_suffix = "_functional"
        functional_overload = base.endswith(functional_suffix)
        if functional_overload:
            base = base[: -len(functional_suffix)]
            # Dunder methods and inplace forms are banned for ops that have a
            # functional + mutable variant (like native_batch_norm); supporting
            # them seems complicated and unnecessary.
            assert not dunder_method and not inplace

        result = BaseOperatorName(
            base=base,
            inplace=inplace,
            dunder_method=dunder_method,
            functional_overload=functional_overload,
        )
        assert str(result) == op, f"{str(result)} != {op}"
        return result

    def __str__(self) -> str:
        if self.dunder_method:
            prefix = "i" if self.inplace else ""
            return f"__{prefix}{self.base}__"
        # "_" takes precedence over "_functional" if both flags are set.
        if self.inplace:
            suffix = "_"
        elif self.functional_overload:
            suffix = "_functional"
        else:
            suffix = ""
        return f"{self.base}{suffix}"
2585*da0073e9SAndroid Build Coastguard Worker
2586*da0073e9SAndroid Build Coastguard Worker
# An OperatorName is a BaseOperatorName together with its (typically not
# user visible) overload string; it prints as "name" or "name.overload".
@dataclass(frozen=True)
class OperatorName:
    name: BaseOperatorName
    overload_name: str

    @staticmethod
    def parse(op_name: str) -> OperatorName:
        """Parse ``"name"`` or ``"name.overload"``; the first ``.`` is the separator."""
        base_text, _, overload_name = op_name.partition(".")
        parsed = OperatorName(
            name=BaseOperatorName.parse(base_text), overload_name=overload_name
        )
        assert str(parsed) == op_name, f"{str(parsed)} != {op_name}"
        return parsed

    def __str__(self) -> str:
        if not self.overload_name:
            return f"{self.name}"
        return f"{self.name}.{self.overload_name}"

    # NB: This must be synchronized with the naming scheme in
    # aten/src/ATen/templates/Operators.h
    # Given a function schema "aten::op.overload(...)":
    #   no overload name   -> f"{op}"
    #   with overload name -> f"{op}_{overload}"
    def unambiguous_name(self) -> str:
        if not self.overload_name:
            return f"{self.name}"
        return f"{self.name}_{self.overload_name}"

    def _plain_base_name(self) -> BaseOperatorName:
        # Rebuilds the base name with inplace=False, keeping base and
        # dunder_method.  functional_overload is not forwarded, so it takes its
        # default (False): a "_functional" name loses that suffix here.
        return BaseOperatorName(
            base=self.name.base,
            inplace=False,
            dunder_method=self.name.dunder_method,
        )

    def remove_inplace(self) -> OperatorName:
        """Return this name without the inplace marker; the overload is kept."""
        return OperatorName(
            name=self._plain_base_name(), overload_name=self.overload_name
        )

    def with_overload(self, overload: str) -> OperatorName:
        """Return a non-inplace version of this name with ``overload`` as its overload."""
        return OperatorName(name=self._plain_base_name(), overload_name=overload)
2641*da0073e9SAndroid Build Coastguard Worker
2642*da0073e9SAndroid Build Coastguard Worker
def gets_generated_out_inplace_wrapper(
    f: NativeFunction, g: NativeFunctionsGroup, b: BackendIndex
) -> bool:
    """
    Decide whether backend ``b`` should get a generated wrapper for ``f``.

    This holds only when all three conditions are met:
      * ``f`` is not the functional variant;
      * ``b`` has no kernel registered for ``f`` itself;
      * ``b`` does have a kernel for the group's functional variant
        ``g.functional``.

    NOTE(review): presumably the generated wrapper is built on top of that
    functional kernel (inferred from the function name); confirm at the call
    sites.
    """
    if f.func.kind() is SchemaKind.functional:
        return False
    if b.has_kernel(f):
        return False
    return b.has_kernel(g.functional)
2651*da0073e9SAndroid Build Coastguard Worker
2652*da0073e9SAndroid Build Coastguard Worker
2653*da0073e9SAndroid Build Coastguard Worker# NativeFunction objects that are views (f.is_view_op returns True)
2654*da0073e9SAndroid Build Coastguard Worker# are added into a `NativeFunctionsViewGroup`, which we can use to
2655*da0073e9SAndroid Build Coastguard Worker# easily access the generated (optional) view_copy NativeFunction.
2656*da0073e9SAndroid Build Coastguard Worker# It's convenient to group them together, so we pair them up in NativeFunctionsViewGroup.
2657*da0073e9SAndroid Build Coastguard Worker# See Note [Codegen'd {view}_copy Operators]
2658*da0073e9SAndroid Build Coastguard Worker#
2659*da0073e9SAndroid Build Coastguard Worker# One property of this representation is that in order for a view-like op to be part of
2660*da0073e9SAndroid Build Coastguard Worker# a NativeFunctionsViewGroup, the "aliasing" version of that view op must exist.
2661*da0073e9SAndroid Build Coastguard Worker# There's one case where that doesn't happen: we have a non-aliasing `narrow_copy.out` op,
2662*da0073e9SAndroid Build Coastguard Worker# but don't have corresponding aliasing `narrow.out` op.
2663*da0073e9SAndroid Build Coastguard Worker# This means that `narrow_copy.out` won't appear as a NativeFunctionsViewGroup.
@dataclass(frozen=True)
class NativeFunctionsViewGroup:
    """Group a view op with its optional {view}_copy and inplace-view variants.

    ``__post_init__`` validates the group:
      * ``view`` must be a view op;
      * if ``view_copy`` is absent, ``view`` must be an op that does not get a
        generated copy variant;
      * if ``view_copy`` is present, its base name must end in ``_copy`` or
        ``_scatter``, its signature must match ``view``'s (with the view_copy
        name stripped), and it must carry the ``view_copy`` tag;
      * if ``view_inplace`` is present, its signature must match ``view``'s, and
        it must have a CompositeImplicitAutograd (resp. ...NestedTensor) kernel
        whenever ``view`` has one.
    """

    view: NativeFunction
    # Note: the {view}_copy operator is optional because we currently don't generate copy variants
    # for all view ops. Notably, we don't generate them for CompositeImplicitAutograd views
    # (we already get them "for free" through decomposition)
    view_copy: NativeFunction | None
    # view_inplace ops are also optional, but every view_inplace op should have out-of-place variant.
    view_inplace: NativeFunction | None

    def __post_init__(self) -> None:
        assert self.view.is_view_op, (
            f"{str(self.view.func.name)} is not a view op, so it cannot be the"
            " `view` member of a NativeFunctionsViewGroup."
        )
        if self.view_copy is None:
            assert not gets_generated_view_copy(self.view), (
                f"{str(self.view.func.name)} appears to be a new operator that aliases its inputs."
                " The codegen expects you to add a corresponding operator to native_functions.yaml:"
                f" {get_view_copy_name(self.view)!s}."
                " See Note [view_copy NativeFunctions] for details."
            )
        else:
            assert self.view_copy.func.name.name.base.endswith(("_copy", "_scatter")), (
                f"{str(self.view_copy.func.name)} is grouped as the copy variant of"
                f" {str(self.view.func.name)}, but its base name does not end in"
                " '_copy' or '_scatter'."
            )
            assert self.view.func.signature() == self.view_copy.func.signature(
                strip_view_copy_name=True,
            ), (
                f"{str(self.view_copy.func.name)} must have the same signature as"
                f" {str(self.view.func.name)} (after stripping the view_copy name)."
            )
            # The check is on the view_copy op's tags, so report those (not the view's).
            assert "view_copy" in self.view_copy.tags, (
                f"{str(self.view_copy.func.name)} appears to be a view_copy operator. The codegen expects"
                " view_copy operators to be annotated with the 'view_copy' tag in native_functions.yaml"
                f" (found tags: {self.view_copy.tags})."
                " See Note [view_copy NativeFunction] for details."
            )

        if self.view_inplace is not None:
            assert self.view.func.signature() == self.view_inplace.func.signature(), (
                f"{str(self.view_inplace.func.name)} must have the same signature as"
                f" {str(self.view.func.name)}."
            )
            # Composite-ness must be consistent so that `composite` below can
            # answer for the whole group by looking only at `view`.
            if self.view.has_composite_implicit_autograd_kernel:
                assert self.view_inplace.has_composite_implicit_autograd_kernel, (
                    f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
                    " both have CompositeImplicitAutograd kernels, or both not have composite kernels."
                )
            if self.view.has_composite_implicit_autograd_nested_tensor_kernel:
                assert (
                    self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel
                ), (
                    f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
                    " both have CompositeImplicitAutogradNestedTensor kernels, or both not have composite kernels."
                )

    def functions(self, *, include_copy: bool = True) -> Iterator[NativeFunction]:
        """Yield the group's members: view, then view_inplace (if present),
        then view_copy (if present and ``include_copy`` is true)."""
        yield self.view
        if self.view_inplace is not None:
            yield self.view_inplace
        if self.view_copy is not None and include_copy:
            yield self.view_copy

    @property
    def root_name(self) -> str:
        """Root name of the group, taken from the view op."""
        return self.view.root_name

    @property
    def composite(self) -> bool:
        # We currently assert that the "group" is consistent.
        # If the view op is composite, then its view_inplace op is too.
        return self.view.has_composite_implicit_autograd_kernel
2727*da0073e9SAndroid Build Coastguard Worker
2728*da0073e9SAndroid Build Coastguard Worker
def gets_generated_view_copy(f: NativeFunction) -> bool:
    """Decide whether the codegen auto-generates a {view}_copy variant of ``f``.

    A copy variant is generated only when ``f``:
      * is an aliasing (view) operator;
      * has no CompositeImplicitAutograd kernel (such ops can simply decompose
        into base view ops, so they need no copy variant);
      * is not tagged ``inplace_view``;
      * does not have a base name ending in ``_inverse``. Those ops are assumed
        to have hand-written copy variants (e.g. slice_inverse() pairs with
        slice_scatter()). Generating them would need slightly different
        codegen, and hand-writing the few kernels keeps codegen simpler.
    """
    # Conditions are evaluated lazily, in the same order as the checks above.
    return bool(f.is_view_op) and (
        not f.has_composite_implicit_autograd_kernel
        and "inplace_view" not in f.tags
        and not f.func.name.name.base.endswith("_inverse")
    )
2748*da0073e9SAndroid Build Coastguard Worker
2749*da0073e9SAndroid Build Coastguard Worker
# Given a NativeFunction that corresponds to a view op,
# returns the OperatorName of the corresponding "copy" variant of the op.
def get_view_copy_name(f: NativeFunction) -> OperatorName:
    """Build the OperatorName of the {view}_copy variant of view op ``f``.

    The result's base name is ``f``'s base name with a ``_copy`` suffix. It is
    never inplace, and it inherits ``f``'s dunder-method flag and overload name.

    As a sanity check, ``f`` must be an op that gets a generated view_copy
    variant (see gets_generated_view_copy()). The exception is ops whose copy
    variant is written directly in native_functions.yaml: currently only
    ``narrow``, whose ``narrow_copy`` already exists there.
    """
    op_name = f.func.name
    ops_with_hand_written_copy_variant = ("narrow",)
    if str(op_name) not in ops_with_hand_written_copy_variant:
        assert gets_generated_view_copy(f)

    return OperatorName(
        name=BaseOperatorName(
            base=f"{op_name.name.base}_copy",
            inplace=False,
            dunder_method=op_name.name.dunder_method,
        ),
        overload_name=op_name.overload_name,
    )
2771*da0073e9SAndroid Build Coastguard Worker
2772*da0073e9SAndroid Build Coastguard Worker
2773*da0073e9SAndroid Build Coastguard Worker# Helper functions for parsing argument lists (both inputs and returns)
2774*da0073e9SAndroid Build Coastguard Worker
2775*da0073e9SAndroid Build Coastguard Worker
def parse_returns(return_decl: str) -> tuple[Return, ...]:
    """
    Parse the return-type portion of a schema into a tuple of Returns.

    '()' denotes "no returns" and yields an empty tuple ``()``.
    Otherwise one pair of enclosing parentheses is stripped if present, and the
    remainder is split on ', '; each piece is parsed with Return.parse.
    For example, '(Tensor a, Tensor b)' yields two Returns, and 'Tensor' yields one.
    """
    if return_decl == "()":
        return ()
    body = return_decl
    # Indexing (rather than startswith/endswith) is deliberate: it keeps the
    # original behavior of raising IndexError on an empty string.
    if (return_decl[0], return_decl[-1]) == ("(", ")"):
        body = return_decl[1:-1]
    return tuple(map(Return.parse, body.split(", ")))
2786*da0073e9SAndroid Build Coastguard Worker
2787*da0073e9SAndroid Build Coastguard Worker
# A Precompute instance consists of a map from kernel argument name
# to the list of Argument instances that should replace that
# kernel argument in the impl function.
@dataclass(frozen=True)
class Precompute:
    """Parsed list of precomputed-parameter declarations.

    ``replace`` maps a kernel argument name to the precomputed Arguments that
    supersede it. ``add`` holds precomputed Arguments that are added without
    replacing any kernel argument.
    """

    # A map from kernel argument name -> a list of precomputed
    # elements that replaces/supersedes it.
    replace: dict[str, list[Argument]]
    # List of precomputed args added without replacement
    add: list[Argument]

    @staticmethod
    def parse(src: object) -> Precompute:
        """Parse a list of precomputed-parameter declaration strings.

        ``src`` is a list of strings of the format::

            {kernel param name} -> {replacement decl}[, {replacement decl}, ...]
            [{add decl}[, {add decl}, ...]]

        The last line is optional. It holds the parameters added without
        replacement, and is recognized by the absence of `` -> ``. Every other
        line says which precomputed elements replace which kernel argument.
        An empty list parses to a Precompute with no replacements and no
        additions.

        Raises AssertionError if ``src`` is not a list; if a replacement line
        is not a str or lacks `` -> ``; if a kernel param name contains a space;
        or if ``to_list()`` does not reproduce the replacement lines exactly.
        """
        assert isinstance(src, list)

        add_args: list[Argument] = []
        # Guard on `src` so an empty list does not crash on src[-1].
        if src and " -> " not in src[-1]:
            add_list = src[-1].split(",")
            add_args = [Argument.parse(name.strip()) for name in add_list]
            src = src[:-1]

        replace: dict[str, list[Argument]] = {}
        for raw_replace_item in src:
            assert isinstance(raw_replace_item, str)
            assert " -> " in raw_replace_item, (
                "precomputed parameters without replacement"
                " are allowed only in the last line"
            )

            arg, with_list_raw = raw_replace_item.split(" -> ")
            assert (
                " " not in arg
            ), f"illegal kernel param name '{arg}' in precomputed parameters"
            with_list = with_list_raw.split(",")
            with_list_args = [Argument.parse(name.strip()) for name in with_list]
            replace[arg] = with_list_args

        r = Precompute(replace=replace, add=add_args)
        # Round-trip check: the replacement lines must already be written
        # exactly as to_list() renders them (e.g. ", "-separated).
        assert r.to_list() == src, f"r.to_list() != src: {r.to_list()!r} != {src!r}"
        return r

    def __post_init__(self) -> None:
        # the template parameters are upper so if these are the
        # same then it is ambiguous
        for a in itertools.chain(self.add, *self.replace.values()):
            assert a.name.upper() != a.name, (
                f"precomputed parameter name '{a.name}' is unchanged by upper-casing,"
                " which makes it ambiguous with its upper-cased template parameter"
            )

    def to_list(self) -> list[str]:
        """Render ``replace`` as ``"{kernel param} -> {decl}, {decl}, ..."`` lines.

        NOTE: entries in ``add`` are not included in the output. parse()
        compares this output only against the replacement lines.
        """
        return [
            f"{kernel_param} -> {', '.join(str(param) for param in replacement_params)}"
            for kernel_param, replacement_params in self.replace.items()
        ]
2852