1from __future__ import annotations
2
3import dataclasses
4import itertools
5import re
6from dataclasses import dataclass
7from enum import auto, Enum
8from typing import Callable, Iterator, Sequence
9
10from torchgen.utils import assert_never, NamespaceHelper, OrderedSet
11
12
13# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
14#
15#                           DATA MODEL
16#
17# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
18#
19# Some general principles for our data model.
20#
21# - Stop using C++ data types as the internal data representation
22#   format.  Instead, the internal data structures are centered
23#   around JIT schema representation.  This avoids a big problem
24#   with the old codegen where we read in all the types from
25#   native_functions.yaml and then immediately had to retranslate
26#   them into C++ types.
27#
28# - More semantic data representation.  Instead of representing
29#   everything as dicts and strings, we define dataclasses for
30#   every interesting entity the code generation has to deal with.
31#   These dataclasses have strong semantic invariants: for example,
32#   we generally require them to roundtrip losslessly into the
33#   form they were parsed from.  These structures are immutable
34#   and you're expected to populate information once during
35#   construction.
36
37
38# Represent a source location; used for better error reporting
39@dataclass(frozen=True)
40class Location:
41    file: str
42    line: int
43
44    def __str__(self) -> str:
45        return f"{self.file}:{self.line}"
46
47
48# Valid values of the 'variants' field in native_functions.yaml
49class Variant(Enum):
50    function = auto()
51    method = auto()
52
53
54# Default kernel namespace
55DEFAULT_KERNEL_NAMESPACE = "at::native"
56
57# NOTE: Keep the list in sync with `DispatchKey` in c10/core/DispatchKey.h
58BACKEND_COMPONENTS = "CPU CUDA HIP XLA MTIA MPS IPU XPU HPU VE Lazy Meta PrivateUse1 PrivateUse2 PrivateUse3".split()
59FUNCTIONALITY_KEYS = [
60    "",
61    "Quantized",
62    "Sparse",
63    "SparseCsr",
64    "NestedTensor",
65    "Autograd",
66]
67
68# This list guards dispatches that can be used in derivatives.yaml
69# For now we omit AutogradFunctionality and AutogradOther
70AUTOGRAD_KEYS = ["AutogradNestedTensor"] + [
71    "Autograd" + component for component in BACKEND_COMPONENTS
72]
73
74FRAGMENT_NAMESPACES = {"quantized", "quantized_decomposed"}
75
76
77# This doesn't have to be in sync with the header; it only needs to contain
78# entries that we actually use in the codegen or want pyi entries for
79class DispatchKey(Enum):
80    Undefined = 0
81    CatchAll = Undefined
82
83    FPGA = auto()
84    MAIA = auto()
85    Vulkan = auto()
86    Metal = auto()
87    MKLDNN = auto()
88    OpenGL = auto()
89    OpenCL = auto()
90    IDEEP = auto()
91    CustomRNGKeyId = auto()
92    MkldnnCPU = auto()
93    Sparse = auto()
94    SparseCsr = auto()
95    NestedTensor = auto()
96    Dense = auto()
97
98    PythonTLSSnapshot = auto()
99    PreDispatch = auto()
100    PythonDispatcher = auto()
101    Python = auto()
102    FuncTorchDynamicLayerBackMode = auto()
103    ZeroTensor = auto()
104    Conjugate = auto()
105    Negative = auto()
106    BackendSelect = auto()
107    Named = auto()
108    AutogradOther = auto()
109    AutogradFunctionality = auto()
110    AutogradNestedTensor = auto()
111    Tracer = auto()
112    Autocast = auto()
113    AutocastCPU = auto()
114    AutocastCUDA = auto()
115    Batched = auto()
116    VmapMode = auto()
117    FuncTorchGradWrapper = auto()
118    FuncTorchBatched = auto()
119    BatchedNestedTensor = auto()
120    FuncTorchVmapMode = auto()
121    FuncTorchDynamicLayerFrontMode = auto()
122    Functionalize = auto()
123    TESTING_ONLY_GenericWrapper = auto()
124    TESTING_ONLY_GenericMode = auto()
125
126    ADInplaceOrView = auto()
127    Autograd = auto()
128    CompositeImplicitAutograd = auto()
129    CompositeImplicitAutogradNestedTensor = auto()
130    CompositeExplicitAutograd = auto()
131    CompositeExplicitAutogradNonFunctional = auto()
132    FuncTorchBatchedDecomposition = auto()
133
134    # BEGIN autogenerated
135    CPU = auto()
136    CUDA = auto()
137    HIP = auto()
138    XLA = auto()
139    MTIA = auto()
140    MPS = auto()
141    IPU = auto()
142    XPU = auto()
143    HPU = auto()
144    VE = auto()
145    Lazy = auto()
146    Meta = auto()
147    PrivateUse1 = auto()
148    PrivateUse2 = auto()
149    PrivateUse3 = auto()
150    QuantizedCPU = auto()
151    QuantizedCUDA = auto()
152    QuantizedHIP = auto()
153    QuantizedXLA = auto()
154    QuantizedMTIA = auto()
155    QuantizedMPS = auto()
156    QuantizedIPU = auto()
157    QuantizedXPU = auto()
158    QuantizedHPU = auto()
159    QuantizedVE = auto()
160    QuantizedLazy = auto()
161    QuantizedMeta = auto()
162    QuantizedPrivateUse1 = auto()
163    QuantizedPrivateUse2 = auto()
164    QuantizedPrivateUse3 = auto()
165    SparseCPU = auto()
166    SparseCUDA = auto()
167    SparseHIP = auto()
168    SparseXLA = auto()
169    SparseMTIA = auto()
170    SparseMPS = auto()
171    SparseIPU = auto()
172    SparseXPU = auto()
173    SparseHPU = auto()
174    SparseVE = auto()
175    SparseLazy = auto()
176    SparseMeta = auto()
177    SparsePrivateUse1 = auto()
178    SparsePrivateUse2 = auto()
179    SparsePrivateUse3 = auto()
180    SparseCsrCPU = auto()
181    SparseCsrCUDA = auto()
182    SparseCsrHIP = auto()
183    SparseCsrXLA = auto()
184    SparseCsrMTIA = auto()
185    SparseCsrMPS = auto()
186    SparseCsrIPU = auto()
187    SparseCsrXPU = auto()
188    SparseCsrHPU = auto()
189    SparseCsrVE = auto()
190    SparseCsrLazy = auto()
191    SparseCsrMeta = auto()
192    SparseCsrPrivateUse1 = auto()
193    SparseCsrPrivateUse2 = auto()
194    SparseCsrPrivateUse3 = auto()
195    NestedTensorCPU = auto()
196    NestedTensorCUDA = auto()
197    NestedTensorHIP = auto()
198    NestedTensorXLA = auto()
199    NestedTensorMTIA = auto()
200    NestedTensorMPS = auto()
201    NestedTensorIPU = auto()
202    NestedTensorXPU = auto()
203    NestedTensorHPU = auto()
204    NestedTensorVE = auto()
205    NestedTensorLazy = auto()
206    NestedTensorMeta = auto()
207    NestedTensorPrivateUse1 = auto()
208    NestedTensorPrivateUse2 = auto()
209    NestedTensorPrivateUse3 = auto()
210    AutogradCPU = auto()
211    AutogradCUDA = auto()
212    AutogradHIP = auto()
213    AutogradXLA = auto()
214    AutogradMTIA = auto()
215    AutogradMPS = auto()
216    AutogradIPU = auto()
217    AutogradXPU = auto()
218    AutogradHPU = auto()
219    AutogradVE = auto()
220    AutogradLazy = auto()
221    AutogradMeta = auto()
222    AutogradPrivateUse1 = auto()
223    AutogradPrivateUse2 = auto()
224    AutogradPrivateUse3 = auto()
225    # END autogenerated
226
227    def __str__(self) -> str:
228        return self.name
229
230    def lower(self) -> str:
231        return str(self).lower()
232
233    @staticmethod
234    def parse(value: str) -> DispatchKey:
235        for k, v in DispatchKey.__members__.items():
236            if k == value:
237                return v
238        raise AssertionError(f"unknown dispatch key {value}")
239
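# For illustration: DispatchKey.parse("CompositeImplicitAutograd") returns
# DispatchKey.CompositeImplicitAutograd, whose str() is the key name again and
# whose .lower() is "compositeimplicitautograd"; parsing a name that is not an
# enum member raises an AssertionError.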
240
241class _TorchDispatchModeKey(Enum):
242    FAKE = auto()
243    PROXY = auto()
244    FUNCTIONAL = auto()
245
246
247def codegen_per_backend_entries() -> str:
248    r = []
249    for fk in FUNCTIONALITY_KEYS:
250        for bc in BACKEND_COMPONENTS:
251            r.append(f"    {fk}{bc} = auto()")
252    return "\n".join(r)
253
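# For illustration, the FUNCTIONALITY_KEYS x BACKEND_COMPONENTS cross product
# above expands to entries such as "CPU = auto()", "QuantizedCPU = auto()",
# "SparseCUDA = auto()" and "AutogradPrivateUse3 = auto()"; the check below
# verifies that every such name is actually present in the DispatchKey enum.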
254
255for fk in FUNCTIONALITY_KEYS:
256    for bc in BACKEND_COMPONENTS:
257        if not hasattr(DispatchKey, fk + bc):
258            r = codegen_per_backend_entries()
259            print(r)
260            raise RuntimeError(
261                f"Missing {fk}{bc} from DispatchKey enum.  Here is the autogenerated list we expect to have:\n\n{r}"
262            )
263
264
265STRUCTURED_DISPATCH_KEYS = {
266    DispatchKey.MPS,
267    DispatchKey.CUDA,
268    DispatchKey.CPU,
269    DispatchKey.XPU,
270}
271UFUNC_DISPATCH_KEYS = {DispatchKey.CUDA, DispatchKey.CPU}
272
273# Set of supported dispatch keys
274dispatch_keys = [
275    DispatchKey.CPU,
276    DispatchKey.SparseCPU,
277    DispatchKey.SparseCsrCPU,
278    DispatchKey.MkldnnCPU,
279    DispatchKey.CUDA,
280    DispatchKey.MPS,
281    DispatchKey.XPU,
282    DispatchKey.SparseCUDA,
283    DispatchKey.SparseCsrCUDA,
284    DispatchKey.QuantizedCPU,
285    DispatchKey.QuantizedCUDA,
286    DispatchKey.CompositeImplicitAutograd,
287    DispatchKey.CompositeImplicitAutogradNestedTensor,
288    DispatchKey.CompositeExplicitAutograd,
289    DispatchKey.CompositeExplicitAutogradNonFunctional,
290    DispatchKey.NestedTensorCPU,
291    DispatchKey.NestedTensorCUDA,
292    # Meta is a magic key: it is automatically generated for structured
293    # kernels
294    DispatchKey.Meta,
295    DispatchKey.SparseMeta,
296    DispatchKey.SparseCsrMeta,
297    DispatchKey.QuantizedMeta,
298    DispatchKey.NestedTensorMeta,
299    DispatchKey.ZeroTensor,
300]
301
302
303# Dispatch keys that "support all backends".  These codegen slightly differently
304# than backend-specific keys.
305def is_generic_dispatch_key(dk: DispatchKey) -> bool:
306    return dk in {
307        DispatchKey.CompositeExplicitAutograd,
308        DispatchKey.CompositeExplicitAutogradNonFunctional,
309        DispatchKey.CompositeImplicitAutograd,
310        DispatchKey.CompositeImplicitAutogradNestedTensor,
311    }
312
313
314# CUDA specific dispatch keys
315def is_cuda_dispatch_key(dk: DispatchKey) -> bool:
316    return dk in {
317        DispatchKey.CUDA,
318        DispatchKey.QuantizedCUDA,
319        DispatchKey.SparseCUDA,
320        DispatchKey.SparseCsrCUDA,
321        DispatchKey.NestedTensorCUDA,
322        DispatchKey.AutogradCUDA,
323    }
324
325
326# XPU specific dispatch keys
327def is_xpu_dispatch_key(dk: DispatchKey) -> bool:
328    return dk in {
329        DispatchKey.XPU,
330        DispatchKey.QuantizedXPU,
331        DispatchKey.SparseXPU,
332        DispatchKey.SparseCsrXPU,
333        DispatchKey.NestedTensorXPU,
334        DispatchKey.AutogradXPU,
335    }
336
337
338# Structured kernel generation is only supported for certain key types;
339# otherwise use old-style
340def is_structured_dispatch_key(dk: DispatchKey) -> bool:
341    return dk in STRUCTURED_DISPATCH_KEYS
342
343
344def is_ufunc_dispatch_key(dk: DispatchKey) -> bool:
345    # For now, ufunc dispatch keys coincide with structured keys
346    return dk in UFUNC_DISPATCH_KEYS
347
348
349# This is oddly named ScalarType and not DType for symmetry with C++
350class ScalarType(Enum):
351    Byte = auto()
352    Char = auto()
353    Short = auto()
354    Int = auto()
355    Long = auto()
356    Half = auto()
357    Float = auto()
358    Double = auto()
359    ComplexHalf = auto()
360    ComplexFloat = auto()
361    ComplexDouble = auto()
362    Bool = auto()
363    BFloat16 = auto()
364    Float8_e5m2 = auto()
365    Float8_e5m2fnuz = auto()
366    Float8_e4m3fn = auto()
367    Float8_e4m3fnuz = auto()
368
369    def __str__(self) -> str:
370        return self.name
371
372    @staticmethod
373    def maybe_parse(value: str) -> ScalarType | None:
374        for k, v in ScalarType.__members__.items():
375            if k == value:
376                return v
377        return None
378
379    @staticmethod
380    def parse(value: str) -> ScalarType:
381        mb_r = ScalarType.maybe_parse(value)
382        assert mb_r is not None, f"unknown dtype {value}"
383        return mb_r
384
385    @staticmethod
386    def parse_set(values: str) -> OrderedSet[ScalarType]:
387        dtypes: OrderedSet[ScalarType] = OrderedSet()
388        for value in values.split(", "):
389            if value in DTYPE_CLASSES:
390                dtypes.update(DTYPE_CLASSES[value])
391            else:
392                dtypes.add(ScalarType.parse(value))
393        return dtypes
394
395
396DTYPE_CLASSES: dict[str, OrderedSet[ScalarType]] = {}
397# NB: Integral doesn't include boolean
398DTYPE_CLASSES["Integral"] = OrderedSet(
399    [
400        ScalarType.Byte,
401        ScalarType.Char,
402        ScalarType.Int,
403        ScalarType.Long,
404        ScalarType.Short,
405    ]
406)
407# NB: Floating doesn't include low precision types
408DTYPE_CLASSES["Floating"] = OrderedSet([ScalarType.Float, ScalarType.Double])
409DTYPE_CLASSES["Complex"] = OrderedSet(
410    [ScalarType.ComplexFloat, ScalarType.ComplexDouble]
411)
412DTYPE_CLASSES["All"] = DTYPE_CLASSES["Integral"] | DTYPE_CLASSES["Floating"]
413DTYPE_CLASSES["AllAndComplex"] = DTYPE_CLASSES["All"] | DTYPE_CLASSES["Complex"]
414DTYPE_CLASSES["FloatingAndComplex"] = (
415    DTYPE_CLASSES["Floating"] | DTYPE_CLASSES["Complex"]
416)
417
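# As an illustrative example, ScalarType.parse_set("Float, Half") yields
# {Float, Half}, while ScalarType.parse_set("FloatingAndComplex, Half") expands
# the dtype class name first and yields
# {Float, Double, ComplexFloat, ComplexDouble, Half}.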
418
419# Represents the valid entries for ufunc_inner_loop in native_functions.yaml.
420# NB: if you add a new UfuncKey, you will need to teach torchgen.dest.ufunc how
421# to process it.  Most logic will ignore keys it doesn't understand, so your
422# new key will get silently ignored until you hook in logic to deal with it.
423class UfuncKey(Enum):
424    # These are low level keys that represent exactly one particular
425    # instantiation of the kernel produced by codegen
426    CUDAFunctor = auto()
427    CUDAFunctorOnOther = auto()
428    CUDAFunctorOnSelf = auto()
429
430    CPUScalar = auto()
431    CPUVector = auto()
432
433    # These are the ones users will usually specify, and
434    # implicitly "fill in" the low level keys
435    ScalarOnly = auto()  # CUDA*, CPUScalar
436    Generic = auto()  # CUDA*, CPU*
437
438    def __str__(self) -> str:
439        return self.name
440
441    @staticmethod
442    def parse(value: str) -> UfuncKey:
443        for k, v in UfuncKey.__members__.items():
444            if k == value:
445                return v
446        raise AssertionError(f"unknown ufunc key {value}")
447
448
449class DeviceCheckType(Enum):
450    NoCheck = 0
451    ExactSame = 1
452
453
454class ViewSchemaKind(Enum):
455    aliasing = auto()
456    aliasing_inplace = auto()
457    non_aliasing = auto()
458
459
460# The basic input to the code generation is native_functions.yaml.
461# The name "native", BTW, comes from the distinction between native
462# functions and legacy TH functions.  The legacy TH functions are gone,
463# but the "native" descriptor has stuck.
464#
465# NativeFunction models a single entry in native_functions.yaml.  Its
466# fields roughly correspond to what you would see in the YAML itself,
467# but after canonicalization and parsing has occurred.
468#
469# You can see some of the overall design patterns for how we set up
470# dataclasses in this class, but we will defer a complete discussion
471# of this to FunctionSchema.
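# As a rough, illustrative sketch (not copied verbatim from the yaml), a single
# native_functions.yaml entry looks something like:
#
#   - func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
#     variants: function, method
#     structured_delegate: add.out
#     dispatch:
#       SparseCPU, SparseCUDA: add_sparse
#
# from_yaml below turns such an entry into a NativeFunction plus the
# per-backend kernel metadata parsed out of the 'dispatch' block.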
472@dataclass(frozen=True)
473class NativeFunction:
474    # The namespace for this operator. For example, if we have "at::add"
475    # then the namespace would be "at". This enables ops to be registered
476    # through the same DSL with a custom namespace. If not specified, the
477    # default namespace would be "at".
478    namespace: str
479
480    # The function schema of the operator in question.  This schema
481    # has been parsed; see FunctionSchema for more about its structure.
482    # (This type is quoted as we are forward referencing a type
483    # defined later in the file.  I opted for this ordering of the
484    # classes for expository clarity.)
485    func: FunctionSchema
486
487    # Whether or not to generate mutable tensor arguments like regular
488    # ones
489    use_const_ref_for_mutable_tensors: bool
490
491    # Whether or not to omit automatic generation of a DeviceGuard
492    device_guard: bool
493
494    # How to emit automatic generation of device check
495    device_check: DeviceCheckType
496
497    # What python module to put the function in
498    python_module: str | None
499
500    # TODO: figure out what this does
501    category_override: str | None
502
503    # If no variants are specified in native_functions.yaml, this is
504    # assumed to be {'function'}.
505    variants: set[Variant]
506
507    # Whether or not we should skip generating registrations for
508    # this kernel.  This is a bit of a double-edged sword, as manual
509    # registrations don't participate in codegen-based selective build!
510    manual_kernel_registration: bool
511
512    # Whether or not to skip generating TensorMethod/Functions bindings
513    # for this kernel.  Technically, this doesn't actually skip generating
514    # the binding; instead, the binding gets generated to __dispatch_{funcname}
515    # so you can make use of the normal binding if you need it.
516    manual_cpp_binding: bool
517
518    # The location in the YAML file where this native function entry was
519    # defined.  This is for conveniently reporting error messages!
520    loc: Location
521
522    # A list of operators that are expected to be auto-generated for this NativeFunction.
523    # Note: This list isn't actually directly used by the codegen to generate anything.
524    # Instead, the codegen figures out what operators to generate purely based off of
525    # function schema, and uses the autogen declarations to error check.
526    # We expect every NativeFunction that gets auto-generated to be explicitly called out
527    # in native_functions.yaml.
528    autogen: list[OperatorName]
529
530    # If non-empty, this kernel is subject to ufunc codegen.
531    # Sorted by ufunc_key
532    ufunc_inner_loop: dict[UfuncKey, UfuncInnerLoop]
533
534    # Whether or not this out function is a "structured kernel".  Structured
535    # kernels are defined a little differently from normal kernels; in
536    # particular, their shape checking logic is defined separately from
537    # the kernel.  Only out functions can be structured; other functions
538    # delegate to the out function using the structured_delegate keyword.
539    # Every structured kernel must have at least an out and a functional
540    # variant.
541    structured: bool
542
543    # Whether or not this non-out function is a structured kernel, defined
544    # in terms of the out kernel referenced by the string here.
545    structured_delegate: OperatorName | None
546
547    # Only valid for structured kernels.  Specifies an alternative class
548    # to inherit from when defining the meta class for the structured
549    # operator.  This will usually be TensorIteratorBase.  This also
550    # changes the semantics of set_output to call the parent class.
551    structured_inherits: str | None
552
553    # Structured kernels can declare elements as "precomputed". These elements
554    # are returned by the meta function in one struct and passed to the impl
555    # function in lieu of certain kernel arguments that these precomputed
556    # elements supersede. Information about the names and types of these
557    # precomputed elements and how they correspond to kernel arguments is stored
558    # in this member, if applicable.
559    precomputed: Precompute | None
560
561    # Argument names whose default should be excluded from the C++ interface.
562    # Intended for resolving overload ambiguities between signatures.
563    cpp_no_default_args: set[str]
564
565    # Note [Abstract ATen methods]
566    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
567    # An abstract ATen method is one whose dispatch differs between
568    # types.  These are implemented in derived types (with a
569    # standard (throwing) definition in Type).  A concrete ATen
570    # method is one which has the same dispatch for all types;
571    # we just implement it in the base Type.  This is exposed
572    # in Declarations.yaml via a field named 'abstract'.
573    is_abstract: bool
574
575    # Whether or not the NativeFunction contains a backend-agnostic kernel
576    has_composite_implicit_autograd_kernel: bool
577    has_composite_implicit_autograd_nested_tensor_kernel: bool
578    has_composite_explicit_autograd_kernel: bool
579    has_composite_explicit_autograd_non_functional_kernel: bool
580
581    # Tags are used to describe semantic information about (groups of) operators
582    # that isn't easily inferable directly from the operator's schema.
583    tags: set[str]
584
585    # NB: The benefit of defining a dataclass is that we automatically get
586    # a constructor defined for all the fields we specify.  No need
587    # to explicitly write it out.
588
589    # We parse both the NativeFunction and backend-specific information about it, which is stored in a corresponding BackendIndex.
590    @staticmethod
591    def from_yaml(
592        ei: dict[str, object],
593        loc: Location,
594        valid_tags: set[str],
595        ignore_keys: set[DispatchKey] | None = None,
596    ) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
597        """
598        Parse a NativeFunction from a dictionary as directly parsed
599        from native_functions.yaml
600        """
601        e = ei.copy()
602
603        funcs = e.pop("func")
604        assert isinstance(funcs, str), f"not a str: {funcs}"
605        # only support one level of namespace. E.g., aten::add
606        namespace_helper = NamespaceHelper.from_namespaced_entity(
607            namespaced_entity=funcs, max_level=1
608        )
609        namespace = namespace_helper.get_cpp_namespace(default="aten")
610        func = FunctionSchema.parse(namespace_helper.entity_name)
611
612        cpp_no_default_args_list = e.pop("cpp_no_default_args", [])
613        assert isinstance(cpp_no_default_args_list, list)
614        cpp_no_default_args = set(cpp_no_default_args_list)
615
616        use_const_ref_for_mutable_tensors = e.pop(
617            "use_const_ref_for_mutable_tensors", False
618        )
619        assert isinstance(use_const_ref_for_mutable_tensors, bool)
620
621        variants_s = e.pop("variants", "function")
622        assert isinstance(variants_s, str)
623        variants: set[Variant] = set()
624        for v in variants_s.split(", "):
625            if v == "function":
626                variants.add(Variant.function)
627            elif v == "method":
628                variants.add(Variant.method)
629            else:
630                raise AssertionError(f"illegal variant {v}")
631
632        manual_kernel_registration = e.pop("manual_kernel_registration", False)
633        assert isinstance(
634            manual_kernel_registration, bool
635        ), f"not a bool: {manual_kernel_registration}"
636
637        manual_cpp_binding = e.pop("manual_cpp_binding", False)
638        assert isinstance(manual_cpp_binding, bool), f"not a bool: {manual_cpp_binding}"
639
640        device_guard = e.pop("device_guard", True)
641        assert isinstance(device_guard, bool), f"not a bool: {device_guard}"
642
643        device_check_s = e.pop("device_check", None)
644        assert device_check_s is None or isinstance(
645            device_check_s, str
646        ), f"not a str: {device_check_s}"
647        assert (
648            device_check_s is None or device_check_s in DeviceCheckType.__members__
649        ), f"illegal device_check: {device_check_s}"
650        device_check: DeviceCheckType
651        if device_check_s is None:
652            device_check = DeviceCheckType.ExactSame
653        else:
654            device_check = DeviceCheckType[device_check_s]
655
656        structured = e.pop("structured", False)
657        assert isinstance(structured, bool), f"not a bool: {structured}"
658
659        structured_delegate_s = e.pop("structured_delegate", None)
660        assert structured_delegate_s is None or isinstance(
661            structured_delegate_s, str
662        ), f"not a str: {structured_delegate_s}"
663        assert structured_delegate_s is None or "::" not in structured_delegate_s, (
664            "namespace is not supported in structured delegate,"
665            " using the same namespace as the native function"
666        )
667        structured_delegate: OperatorName | None = None
668        if structured_delegate_s is not None:
669            structured_delegate = OperatorName.parse(structured_delegate_s)
670
671        structured_inherits = e.pop("structured_inherits", None)
672        assert structured_inherits is None or isinstance(
673            structured_inherits, str
674        ), f"not a str: {structured_inherits}"
675        assert structured_inherits is None or "::" not in structured_inherits, (
676            "namespace is not supported in structured inherits,"
677            " using the same namespace as the native function"
678        )
679
680        python_module = e.pop("python_module", None)
681        assert python_module is None or isinstance(
682            python_module, str
683        ), f"not a str: {python_module}"
684        assert (
685            python_module is None or Variant.method not in variants
686        ), "functions in modules cannot be methods"
687
688        category_override = e.pop("category_override", None)
689        assert category_override is None or isinstance(
690            category_override, str
691        ), f"not a str: {category_override}"
692
693        precomputed_dict = e.pop("precomputed", None)
694        assert precomputed_dict is None or structured is True
695        precomputed = Precompute.parse(precomputed_dict) if precomputed_dict else None
696
697        tags_inp = e.pop("tags", [])
698        if isinstance(tags_inp, str):
699            tags_inp = [tags_inp]
700        assert isinstance(tags_inp, list)
701
702        # All aten ops generated by torchgen receive the pt2_compliant tag.
703        if namespace == "aten" and "pt2_compliant_tag" in valid_tags:
704            tags_inp.append("pt2_compliant_tag")
705
706        tags: set[str] = set()
707        for t in tags_inp:
708            assert len(valid_tags) > 0
709            # TODO: verify that the tag is valid and has an entry in tags.yaml
710            if t in valid_tags:
711                tags.add(t)
712            else:
713                raise AssertionError(f"illegal tag {t}")
714
715        from torchgen.api import cpp
716
717        raw_dispatch = e.pop("dispatch", None)
718        assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
719        dispatch: dict[DispatchKey, BackendMetadata] = {}
720        num_dispatch_keys: int = 0
721        if raw_dispatch is not None:
722            assert not manual_kernel_registration, (
723                "cannot specify both manual_kernel_registration and dispatch; with "
724                "manual registration, dispatch has no effect!"
725            )
726            redundant_composite_implicit_autograd = False
727            for ks, v in raw_dispatch.items():
728                if ks == "__line__":
729                    continue  # not worth tracking line numbers for dispatch entries
730                assert isinstance(
731                    ks, str
732                ), f"illegal dispatch key '{ks}' in {raw_dispatch}"
733                assert isinstance(
734                    v, str
735                ), f"illegal dispatch value '{v}' in {raw_dispatch}"
736                for k in ks.split(","):
737                    dispatch_key = DispatchKey.parse(k.strip())
738                    num_dispatch_keys += 1
739
740                    if ignore_keys and dispatch_key in ignore_keys:
741                        continue
742                    assert dispatch_key in dispatch_keys, (
743                        f"Dispatch key {dispatch_key} of kernel {v} "
744                        "is not a supported dispatch key."
745                    )
746                    # We only allow at most 3 levels of namespace for kernels.
747                    # We will append "native" to a custom kernel namespace.
748                    namespace_helper = NamespaceHelper.from_namespaced_entity(
749                        v, max_level=3
750                    )
751                    kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
752                    # Why is 'structured' included? External backends (e.g.
753                    # XLA) opt into which ops are structured independently
754                    # of which in-tree ops are structured
755                    dispatch[dispatch_key] = BackendMetadata(
756                        kernel=namespace_helper.entity_name,
757                        structured=structured
758                        and is_structured_dispatch_key(dispatch_key),
759                        cpp_namespace=(kernel_namespace + "::native"),
760                    )
761                    if (
762                        dispatch_key is DispatchKey.CompositeImplicitAutograd
763                        and v == cpp.name(func)
764                    ):
765                        redundant_composite_implicit_autograd = True
766
767            # We track the total number of dispatch keys (including ignored ones) so that a dispatch
768            # table whose backend keys were all ignored, leaving only the necessarily kept
769            # CompositeImplicitAutograd entry, is not treated as redundant.
770            assert not (
771                num_dispatch_keys == 1 and redundant_composite_implicit_autograd
772            ), (
773                "unnecessary dispatch table for this function; just delete the dispatch "
774                "key entirely"
775            )
776            # if a function is a structured delegate, deleting the dispatch
777            # table is NOT semantics preserving
778            assert (
779                structured_delegate
780                or dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
781                or dispatch[DispatchKey.CompositeImplicitAutograd].supports_symint()
782                or num_dispatch_keys != 1
783            ), (
784                f"unexpected name for singleton CompositeImplicitAutograd dispatch entry: expected {cpp.name(func)} "
785                f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}.  Rename your implementation to the expected "
786                "name, then delete the dispatch table"
787            )
788        elif not structured and structured_delegate is None:
789            name = str(func.name.name)
790            assert not (
791                name.startswith("new_")
792                or name.endswith("_like")
793                # TODO: maybe it's better to test the return
794                or (
795                    func.arguments.tensor_options
796                    and not func.arguments.has_tensor_arg()
797                )
798            ), (
799                f"expected {name} to have a CompositeExplicitAutograd "
800                "dispatch entry, but there was no dispatch table.  Factory functions "
801                "should not have implicit dispatch as they should not be decomposed "
802                "for __torch_dispatch__"
803            )
804            dispatch[DispatchKey.CompositeImplicitAutograd] = BackendMetadata(
805                cpp.name(func), structured=False, cpp_namespace=DEFAULT_KERNEL_NAMESPACE
806            )
807
808        composites_in_dispatch = [
809            d
810            for d in dispatch
811            if d == DispatchKey.CompositeExplicitAutograd
812            or d == DispatchKey.CompositeExplicitAutogradNonFunctional
813            or d == DispatchKey.CompositeImplicitAutograd
814            or d == DispatchKey.CompositeImplicitAutogradNestedTensor
815        ]
816
817        assert len(composites_in_dispatch) <= 1 or (
818            len(composites_in_dispatch) == 2
819            and (
820                DispatchKey.CompositeExplicitAutogradNonFunctional
821                not in composites_in_dispatch
822            )
823            and (
824                DispatchKey.CompositeImplicitAutogradNestedTensor
825                in composites_in_dispatch
826            )
827        ), (
828            "cannot specify more than one of CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional, "
829            "or CompositeImplicitAutograd on a single kernel; each "
830            "strictly subsumes the other.  If you wanted to provide an explicit autograd "
831            "implementation, specify CompositeExplicitAutograd; otherwise specify CompositeImplicitAutograd only"
832        )
833
834        autogen_str = e.pop("autogen", "")
835        assert isinstance(autogen_str, str)
836        autogen = (
837            []
838            if autogen_str == ""
839            else [OperatorName.parse(x) for x in autogen_str.split(", ")]
840        )
841
842        raw_ufunc_inner_loop = e.pop("ufunc_inner_loop", {})
843        ufunc_inner_loop = {}
844        if isinstance(raw_ufunc_inner_loop, str):
845            ufunc_inner_loop[UfuncKey.Generic] = UfuncInnerLoop.parse(
846                raw_ufunc_inner_loop, UfuncKey.Generic
847            )
848        elif isinstance(raw_ufunc_inner_loop, dict):
849            for k, vo in raw_ufunc_inner_loop.items():
850                if k == "__line__":
851                    continue
852                assert isinstance(k, str), f"ufunc_inner_loop key is not a str: {k}"
853                assert isinstance(vo, str), f"ufunc_inner_loop value is not a str: {vo}"
854                ufunc_key = UfuncKey.parse(k)
855                ufunc_inner_loop[ufunc_key] = UfuncInnerLoop.parse(vo, ufunc_key)
856        else:
857            raise AssertionError(
858                f"ufunc_inner_loop not str or dict: {raw_ufunc_inner_loop}"
859            )
860        # Program the BackendIndex for the implicit dispatch entry from ufunc
861        if ufunc_inner_loop:
862            assert structured, "ufunc must be structured"
863
864            # Delay import ufunc here to avoid circular import issue
865            # See: https://github.com/pytorch/pytorch/issues/81294
866            import torchgen.api.ufunc as ufunc
867
868            for dispatch_key in UFUNC_DISPATCH_KEYS:
869                assert (
870                    dispatch_key not in dispatch
871                ), f"ufunc should not have explicit dispatch entry for {dispatch_key}"
872                dispatch[dispatch_key] = BackendMetadata(
873                    kernel=ufunc.schema_kernel_name(func, dispatch_key),
874                    structured=True,
875                    cpp_namespace=DEFAULT_KERNEL_NAMESPACE,
876                )
877
878        if structured_delegate:
879            # Structured functions MUST have a dispatch table
880            is_abstract = True
881        else:
882            is_abstract = (
883                dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
884                and dispatch.keys()
885                != {DispatchKey.CompositeImplicitAutogradNestedTensor}
886                and dispatch.keys()
887                != {
888                    DispatchKey.CompositeImplicitAutograd,
889                    DispatchKey.CompositeImplicitAutogradNestedTensor,
890                }
891            )
892
893        has_composite_implicit_autograd_kernel = (
894            DispatchKey.CompositeImplicitAutograd in dispatch
895        )
896        has_composite_implicit_autograd_nested_tensor_kernel = (
897            DispatchKey.CompositeImplicitAutogradNestedTensor in dispatch
898        )
899        has_composite_explicit_autograd_kernel = (
900            DispatchKey.CompositeExplicitAutograd in dispatch
901        )
902        has_composite_explicit_autograd_non_functional_kernel = (
903            DispatchKey.CompositeExplicitAutogradNonFunctional in dispatch
904        )
905
906        # We aren't going to store dispatch metadata inline in NativeFunctions;
907        # instead it is separately indexed by backend (so other backends can
908        # add more dispatch entries after the fact).  Reindex the individual
909        # metadata by OperatorName!
910        backend_metadata = {k: {func.name: v} for k, v in dispatch.items()}
911
912        # don't care if it exists or not; make it easier to use this function
913        # with other yaml parsers that aren't setting __line__ in the dict
914        e.pop("__line__", None)
915        assert not e, f"leftover entries: {e}"
916
917        # Asserts that we can't do in post_init, because they rely on backend-specific info
918        if structured_delegate is not None:
919            for key in STRUCTURED_DISPATCH_KEYS:
920                assert key not in dispatch, (
921                    f"if structured_delegate, then must not have {key} in dispatch dictionary "
922                    "(it is delegated!)"
923                )
924
925        return (
926            NativeFunction(
927                func=func,
928                use_const_ref_for_mutable_tensors=use_const_ref_for_mutable_tensors,
929                variants=variants,
930                structured=structured,
931                structured_delegate=structured_delegate,
932                structured_inherits=structured_inherits,
933                precomputed=precomputed,
934                autogen=autogen,
935                ufunc_inner_loop=ufunc_inner_loop,
936                manual_kernel_registration=manual_kernel_registration,
937                manual_cpp_binding=manual_cpp_binding,
938                python_module=python_module,
939                category_override=category_override,
940                device_guard=device_guard,
941                device_check=device_check,
942                loc=loc,
943                cpp_no_default_args=cpp_no_default_args,
944                is_abstract=is_abstract,
945                has_composite_implicit_autograd_kernel=has_composite_implicit_autograd_kernel,
946                has_composite_implicit_autograd_nested_tensor_kernel=has_composite_implicit_autograd_nested_tensor_kernel,
947                has_composite_explicit_autograd_kernel=has_composite_explicit_autograd_kernel,
948                has_composite_explicit_autograd_non_functional_kernel=has_composite_explicit_autograd_non_functional_kernel,
949                tags=tags,
950                namespace=namespace,
951            ),
952            backend_metadata,
953        )
954
955    def validate_unstructured(self) -> None:
956        # TODO: probably better to accumulate these errors and report them all
957        # at once
958        assert not self.structured, (
959            "This function is structured, but there was "
960            "no valid functional variant of it."
961        )
962        assert self.structured_delegate, (
963            "This function delegates to another structured out function, "
964            "but no valid function was found (the delegate may not exist, or it has the wrong type)"
965        )
966
967    # __post_init__ functions in dataclasses can be used to do extra
968    # validation after construction.
969    #
970    # Notice that we don't do any type validation here.  In fact, we
971    # rely exclusively on mypy to check if you've done types correctly!
972    # Validation is for nontrivial invariants that cannot be (conveniently)
973    # encoded in the type system.
974    def __post_init__(self) -> None:
975        if self.func.arguments.out:
976            assert self.variants == {Variant.function}, (
977                "Native functions with out arguments MUST "
978                "be declared with only function variant; e.g., variants: function; "
979                "otherwise you will tickle a Python argument binding bug "
980                "(which usually manifests itself as the result variable being undefined.)"
981            )
982        if self.structured:
983            assert self.func.kind() == SchemaKind.out, (
984                "Put structured field on the out= "
985                "variant of a function; did you mean structured_delegate?"
986            )
987            assert (
988                self.device_guard
989            ), "device_guard: False is not respected by structured kernels"
990        if self.structured_delegate:
991            assert self.func.kind() != SchemaKind.out, (
992                "structured_delegate field not allowed "
993                "on out= functions; did you mean structured?"
994            )
995            assert (
996                self.device_guard
997            ), "device_guard: False is not respected by structured kernels"
998        # Technically, with the asserts above, this assert can never
999        # fire
1000        assert not (
1001            self.structured and self.structured_delegate
1002        ), "Cannot have both structured and structured_delegate on function"
1003        defaulted_arguments = {
1004            a.name for a in self.func.schema_order_arguments() if a.default is not None
1005        }
1006        invalid_args = set.difference(self.cpp_no_default_args, defaulted_arguments)
1007        assert len(invalid_args) == 0, f"Invalid cpp_no_default_args: {invalid_args}"
1008        if self.structured_inherits is not None:
1009            assert (
1010                self.structured
1011            ), "structured_inherits must also imply structured: True"
1012        if str(self.func.name).startswith("_foreach"):
1013            assert self.device_check == DeviceCheckType.NoCheck, (
1014                "foreach kernels fall back to slow path when tensors are on different devices, "
1015                "device_check not allowed to be enabled"
1016            )
1017
1018        # NB: if your function accidentally has rand/dropout/... in its name
1019        # but is not actually random, feel free to amend this to special-case it
1020        if (
1021            "rand" in str(self.func.name)
1022            or (
1023                (
1024                    "dropout" in str(self.func.name)
1025                    or any(
1026                        "dropout" in arg.name for arg in self.func.arguments.flat_all
1027                    )
1028                )
1029                # Backwards of dropout is typically deterministic
1030                and "backward" not in str(self.func.name)
1031                and str(self.func.name.name) not in ["_cudnn_init_dropout_state"]
1032            )
1033            or self.func.arguments.has_generator_arg()
1034        ):
1035            assert "nondeterministic_seeded" in self.tags, str(self.func.name)
1036
1037    @property
1038    def has_composite_kernel(self) -> bool:
1039        return (
1040            self.has_composite_implicit_autograd_kernel
1041            or self.has_composite_explicit_autograd_kernel
1042            or self.has_composite_explicit_autograd_non_functional_kernel
1043        ) or (
1044            self.has_composite_implicit_autograd_kernel
1045            and self.has_composite_implicit_autograd_nested_tensor_kernel
1046        )
1047
1048    @property
1049    def is_view_op(self) -> bool:
1050        rets = self.func.returns
1051        is_non_mutating_view = len(rets) > 0 and any(
1052            r.annotation is not None and not r.annotation.is_write for r in rets
1053        )
1054        # See Note [resize_ in Functionalization] for more details
1055        is_inplace_view = (
1056            "inplace_view" in self.tags
1057            and str(self.func.name) != "resize_"
1058            and str(self.func.name) != "resize_as_"
1059        )
1060        is_wildcard_view = any(
1061            inp.annotation is not None and "*" in inp.annotation.alias_set_after
1062            for inp in self.func.schema_order_arguments()
1063        )
1064        return is_non_mutating_view or is_inplace_view or is_wildcard_view
1065
1066    @property
1067    def view_schema_kind(self) -> ViewSchemaKind:
1068        if self.is_view_op and self.func.name.name.inplace:
1069            assert "inplace_view" in self.tags
1070            return ViewSchemaKind.aliasing_inplace
1071        if self.is_view_op:
1072            return ViewSchemaKind.aliasing
1073        else:
1074            return ViewSchemaKind.non_aliasing
1075
1076    @property
1077    def root_name(self) -> str:
1078        return self.func.name.name.base
1079
1080    @property
1081    def part_of_structured_group(self) -> bool:
1082        return self.structured or self.structured_delegate is not None
1083
1084
1085class SchemaKind(Enum):
1086    functional = auto()
1087    inplace = auto()
1088    out = auto()
1089    mutable = auto()
1090    scratch = auto()
1091
1092
1093# A structured kernel is guaranteed to have a functional and out variant, and
1094# optionally an inplace variant.
1095#
1096# NB: we create NativeFunctionsGroup *even if* the function is not
1097# actually annotated structured.  Test the structured boolean to see if it
1098# actually is structured or not.
1099@dataclass(frozen=True)
1100class NativeFunctionsGroup:
1101    functional: NativeFunction
1102    inplace: NativeFunction | None
1103    mutable: NativeFunction | None
1104    out: NativeFunction
1105
1106    @property
1107    def structured(self) -> bool:
1108        # Whether or not the operator has a meta() function. This information is backend-agnostic.
1109        return self.out.structured
1110
1111    def __post_init__(self) -> None:
1112        test_sig: FunctionSchema = self.functional.func.signature()
1113        for f in self.functions():
1114            if test_sig != f.func.signature():
1115                raise AssertionError(
1116                    "NativeFunctionsGroup constructed from two NativeFunctions "
1117                    f"that don't have matching signatures: {test_sig} != {f.func.signature()}"
1118                )
1119
1120            if self.structured != f.part_of_structured_group:
1121                raise AssertionError(
1122                    "NativeFunctionsGroup constructed from structured and unstructured "
1123                    f"functions: {self.out.func.name} and {f.func.name}"
1124                )
1125        assert self.functional.func.kind() == SchemaKind.functional
1126        assert self.out.func.kind() == SchemaKind.out
1127        assert self.functional.namespace == self.out.namespace
1128        if self.inplace is not None:
1129            assert self.inplace.func.kind() == SchemaKind.inplace
1130            assert self.inplace.namespace == self.functional.namespace
1131
1132        if self.mutable is not None:
1133            assert self.mutable.func.kind() == SchemaKind.mutable
1134            assert self.mutable.namespace == self.functional.namespace
1135            # See Note [Overload Ambiguity With Functional Variants]
1136            assert self.functional.func.name.name.functional_overload
1137
1138        if self.structured:
1139            # For now, structured composite kernels are not supported (need some
1140            # design work to figure out how to make the composite case work)
1141            assert (
1142                not self.out.has_composite_implicit_autograd_kernel
1143                and not self.out.has_composite_implicit_autograd_nested_tensor_kernel
1144            )
1145
1146            assert self.functional.structured_delegate == self.out.func.name, (
1147                f"{self.functional.func.name} delegates to {self.functional.structured_delegate} "
1148                f"but its actual delegate is {self.out.func.name}"
1149            )
1150            if self.inplace is not None:
1151                assert self.inplace.structured_delegate == self.out.func.name
1152
1153        generated_fns = sorted(
1154            [str(f.func.name) for f in self.functions() if "generated" in f.tags]
1155        )
1156        generated_fns_str = ", ".join(str(x) for x in generated_fns)
1157        expected_generated_fns: set[str] = set()
1158        for f in self.functions():
1159            expected_generated_fns.update(str(op) for op in f.autogen)
1160        expected_generated_fns_str = ", ".join(
1161            str(x) for x in sorted(expected_generated_fns)
1162        )
1163        if len(expected_generated_fns) == 0 and len(generated_fns) > 0:
1164            raise RuntimeError(
1165                f"The codegen expects to be able to generate '{generated_fns_str}'."
1166                " In order to generate them however, we expect them to be called out explicitly in the yaml."
1167                f" Please add an 'autogen: {generated_fns_str}' line to the entry for {str(f.func.name)}"
1168            )
1169        if expected_generated_fns_str != generated_fns_str:
1170            raise RuntimeError(
1171                f"The codegen expects to be able to generate '{generated_fns_str}'."
1172                f" To do so, it expects a line: 'autogen: {generated_fns_str}'."
1173                f" Instead, it found 'autogen: {expected_generated_fns_str}'"
1174            )
1175
1176    def signature(self) -> FunctionSchema:
1177        return self.out.func.signature()
1178
1179    def functions(self) -> Iterator[NativeFunction]:
1180        yield self.functional
1181        yield self.out
1182        if self.inplace is not None:
1183            yield self.inplace
1184        if self.mutable is not None:
1185            yield self.mutable
1186
1187    @property
1188    def root_name(self) -> str:
1189        return self.functional.root_name
1190
1191    @staticmethod
1192    def from_dict(d: dict[SchemaKind, NativeFunction]) -> NativeFunctionsGroup | None:
1193        assert d
1194        if len(d) == 1:
1195            return None
1196        d = dict(d)  # non-destructive updates please
1197        functional = d.pop(SchemaKind.functional, None)
1198        inplace = d.pop(SchemaKind.inplace, None)
1199        mutable = d.pop(SchemaKind.mutable, None)
1200        out = d.pop(SchemaKind.out, None)
1201        assert not d
1202        assert functional is not None
1203        # There are a few operators which only have functional/inplace variants;
1204        # these don't count as structured for our purposes here
1205        if out is None:
1206            return None
1207        # assuming all variants have the same namespace
1208        return NativeFunctionsGroup(
1209            functional=functional,
1210            inplace=inplace,
1211            mutable=mutable,
1212            out=out,
1213        )
1214
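# Illustrative usage of NativeFunctionsGroup.from_dict above (add_f and add_out
# are hypothetical NativeFunctions): from_dict({SchemaKind.functional: add_f,
# SchemaKind.out: add_out}) returns a group with inplace=None and mutable=None,
# while a dict containing only a functional variant (no out= entry) returns None.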
1215
1216@dataclass(frozen=True)
1217class BackendMetadata:
1218    # The name of the backend kernel, for a given operator
1219    # for in-tree backends. These names come directly from the 'dispatch' field
1220    # in native_functions.yaml. The dispatch entry is optional; in that
1221    # case, that is equivalent to having written:
1222    #
1223    #   dispatch:
1224    #       CompositeImplicitAutograd: $operator_name
1225    kernel: str
1226    # Whether or not the operator has a structured kernel implemented, for this particular backend.
1227    # For in-tree backends, they all have the same value for structured; this is listed
1228    # in native_functions.yaml.
1229    # However, external backends like XLA can independently toggle which ops are structured.
1230    structured: bool
1231
1232    # The namespace for kernels, default value: DEFAULT_KERNEL_NAMESPACE
1233    cpp_namespace: str
1234
1235    def supports_symint(self) -> bool:
1236        return "_symint" in self.kernel
1237
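# For illustration (kernel names here are hypothetical): a BackendMetadata whose
# kernel is "add_out_symint" reports supports_symint() == True, whereas one
# whose kernel is "add_out" does not.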
1238
1239@dataclass(frozen=True)
1240class UfuncInnerLoop:
1241    name: str
1242    supported_dtypes: OrderedSet[ScalarType]
1243    # key is stored here because it affects the semantics of name,
1244    # so it's helpful to have them together for further processing
1245    ufunc_key: UfuncKey
1246
1247    @staticmethod
1248    def parse(value: str, ufunc_key: UfuncKey) -> UfuncInnerLoop:
1249        name, supported_dtypes_str = value.split(" ", 1)
1250        assert supported_dtypes_str[0] == "("
1251        assert supported_dtypes_str[-1] == ")"
1252        supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
1253        for k in supported_dtypes_str[1:-1].split(", "):
1254            supported_dtypes |= ScalarType.parse_set(k)
1255        return UfuncInnerLoop(
1256            name=name, supported_dtypes=supported_dtypes, ufunc_key=ufunc_key
1257        )
1258
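# For example (the kernel name here is only illustrative), parsing the string
# "mul (AllAndComplex, BFloat16, Half)" with ufunc_key=UfuncKey.Generic yields
# name="mul" and a supported_dtypes set containing every integral, floating and
# complex dtype plus BFloat16 and Half.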
1259
1260# BackendIndex represents a backend.
1261# The BackendIndex encodes per-operator information that is potentially different
1262# for each backend. The most obvious example is the name of the kernel
1263# (the 'dispatch' entry in native_functions.yaml).
1264# However, there can be other examples of different backends having different information.
1265# External backends can choose to opt their kernels into being structured independently of in-tree backends,
1266# which means that this information isn't inherently tied to a NativeFunction; it's different per backend.
1267@dataclass(frozen=True)
1268class BackendIndex:
1269    dispatch_key: DispatchKey
1270    # Mainly important for structured kernels, this determines which variant in the operator group is used to implement the others.
1271    # All in-tree ops use out kernels, while XLA uses functional kernels.
1272    use_out_as_primary: bool
1273    # Whether the backend requires a device guard, and device checks.
1274    # For in-tree backends, this is currently just CUDA/HIP
1275    # For out-of-tree backends, this is currently just Intel XPU
1276    device_guard: bool
1277    # Whether the backend is in-tree (CPU/CUDA) or out-of-tree (XLA)
1278    external: bool
1279    # Other backend-specific information that is on a per-operator basis
1280    index: dict[OperatorName, BackendMetadata]
1281
1282    @staticmethod
1283    def grow_index(
1284        parent_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
1285        child_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
1286    ) -> None:
1287        for k, v in child_index.items():
1288            for op_name, metadata in v.items():
1289                assert (
1290                    op_name not in parent_index[k]
1291                ), f"duplicate operator {op_name} for dispatch key {k}"
1292                parent_index[k][op_name] = metadata
1293
1294    def primary(self, g: NativeFunctionsGroup) -> NativeFunction:
1295        if self.use_out_as_primary:
1296            return g.out
1297        else:
1298            return g.functional
1299
1300    def has_kernel(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
1301        m = self.get_kernel(g)
1302        return m is not None
1303
1304    def get_kernel(
1305        self, g: NativeFunction | NativeFunctionsGroup
1306    ) -> BackendMetadata | None:
1307        if isinstance(g, NativeFunction):
1308            f = g
1309        elif isinstance(g, NativeFunctionsGroup):
1310            f = self.primary(g)
1311        else:
1312            assert_never(g)
1313        if f.func.name not in self.index:
1314            return None
1315        return self.index[f.func.name]
1316
1317    def native_function_class_name(self) -> str | None:
1318        if self.external:
1319            return f"{str(self.dispatch_key)}NativeFunctions"
1320        else:
1321            # TODO: This discrepancy isn't required; we could also generate
1322            # a class for in-tree kernels. It'll just require carefully
1323            # updating every kernel definition + callsite of every in-tree aten kernel.
1324            return None
1325
1326
1327# The function schema is undoubtedly the most important data structure
1328# in all of the codegen, as it defines the type signature for operators,
1329# and most of the code generation we do is type directed (e.g., look at
1330# the types, decide what to do.  Think about how we code generate
1331# C++ function stubs!)
1332#
1333# We will also see in this class the general structure for how we model
1334# data in this code generation.  A few notable properties to point out
1335# ahead of time:
1336#
1337#   - These dataclasses are a *lossless* representation of the strings
1338#     they are parsed from.  In fact, we assert that given the
1339#     information stored in the dataclass, we can exactly reconstruct
1340#     the string we parsed from (and assert this inside the parse
1341#     definition).  There are a few reasons for this:
1342#
1343#       - If you find that it is difficult to reconstruct the string
1344#         given a dataclass, that is a clue that your data
1345#         representation is wrong.
1346#
1347#       - It helps ensure that all relevant information is present
1348#         in the dataclass, so that downstream users aren't tempted
1349#         to reparse the original string to get some information
1350#         that was omitted.
1351#
1352#       - It forces you to represent the data in-memory in the same way
1353#         it is recorded textually, which makes the dataclasses easier
1354#         to understand for someone who is familiar with the
1355#         textual format.  (As a tradeoff, it means you have to model
1356#         the syntax, even when it is inconvenient.  But maybe that means
1357#         the syntax is bad!)  If you don't understand the internal
1358#         representation, go look at the printing code to see how
1359#         it maps onto the surface syntax!
1360#
1361#       - It makes it easy to test the parsing code, as parsing code
1362#         that is inconsistent with the stringifying (__str__) code will fail early
1363#         and loudly.  (As a tradeoff, it makes the parsing code a bit
1364#         brittle: in particular, with trivial whitespace changes you
1365#         are likely to trigger an assert error.)
1366#
1367#     In general, try to make the __str__ code as simple as possible
1368#     (even at the cost of more complex parsing logic.)  Additionally,
1369#     try to minimize redundancy in data representation.  (Precomputed
1370#     fields are OK though: they are defined as a simple function on
1371#     the canonical representation in question.)
1372#
1373#   - These dataclasses are all frozen; once constructed their
1374#     values never change.  This makes it easy to tell where any
1375#     given data came from: just look to the constructor.  As a
1376#     tradeoff, you can't easily "decorate" a schema with extra
1377#     information from a post-facto analysis.  We impose this
1378#     restriction to make these structures more understandable.
1379#
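# As an illustrative sketch of the roundtrip property (not executed here; the
# schema string below is assumed to match the add.Tensor entry in
# native_functions.yaml):
#
#   s = FunctionSchema.parse(
#       "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
#   )
#   assert str(s) == "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
#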
1380@dataclass(frozen=True)
1381class FunctionSchema:
1382    # The name of the operator this function schema describes.
1383    name: OperatorName
1384
1385    arguments: Arguments
1386
1387    # TODO: Need to handle collisions with argument names at some point
1388    returns: tuple[Return, ...]
1389
1390    @property
1391    def is_mutable(self) -> bool:
1392        def is_write(arg: Argument) -> bool:
1393            if arg.annotation is None:
1394                return False
1395            return arg.annotation.is_write
1396
1397        # Corresponds to torch._C._FunctionSchema.is_mutable
1398        # See aten/src/ATen/core/function_schema.h (keep these in sync)
1399        return any(is_write(a) for a in self.arguments.flat_all)
1400
1401    def schema_order_arguments(self) -> Iterator[Argument]:
1402        return itertools.chain(
1403            self.arguments.flat_positional,
1404            self.arguments.flat_kwarg_only,
1405            self.arguments.out,
1406        )
1407
1408    decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")
1409
1410    @staticmethod
1411    def parse(func: str) -> FunctionSchema:
1412        # We should probably get a proper parser here
1413        decls = FunctionSchema.decl_re.findall(func)
1414        assert len(decls) == 1, f"Invalid function schema: {func}"
1415        ops, args, return_decl = decls[0]
1416        name = OperatorName.parse(ops)
1417        arguments = Arguments.parse(args)
1418        returns = parse_returns(return_decl)
1419        r = FunctionSchema(name=name, arguments=arguments, returns=returns)
1420        assert str(r) == func, f"{str(r)} != {func}"
1421        return r
1422
1423    def returns_are_aliased(self) -> bool:
1424        # We assert earlier that schemas can't have a mix of aliased and non-aliased returns
1425        return any(
1426            r
1427            for r in self.returns
1428            if r.annotation is not None and r.annotation.is_write
1429        )
1430
1431    def __post_init__(self) -> None:
1432        for arg, ret in zip(self.arguments.out, self.returns):
1433            assert arg.annotation == ret.annotation, (
1434                "Out arguments must have matching return annotations; furthermore, "
1435                "the i-th out argument needs to correspond to the i-th return"
1436            )
1437        # We also enforce that if you have any mutable, positional args, then they are not returned.
1438        # This makes it easier to group these functions properly with their functional/out= counterparts.
1439        for a in self.arguments.post_self_positional_mutable:
1440            assert not any(
1441                a.annotation == r.annotation for r in self.returns
1442            ), f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}"
1443        # Invariant: we expect out arguments to appear as keyword arguments in the schema.
1444        # This means that all mutable returns should be aliased to a keyword argument
1445        # (except for "self", which we explicitly don't treat as an out argument because of its use in methods)
1446        # See Note [is_out_fn]
1447        out_and_self = list(self.arguments.out) + [
1448            arg for arg in self.arguments.flat_positional if arg.name == "self"
1449        ]
1450        mutable_returns = [
1451            ret
1452            for ret in self.returns
1453            if ret.annotation is not None and ret.annotation.is_write
1454        ]
1455        immutable_returns = [
1456            ret
1457            for ret in self.returns
1458            if ret.annotation is None or not ret.annotation.is_write
1459        ]
1460        # Some assertions: We don't want any functions with a return type of "-> (Tensor(a!), Tensor)",
1461        # because:
1462        # (1) It's more annoying to handle properly
1463        # (2) It's unnecessary - you can't method-chain on the first (mutated) output because it's part of a tuple.
1464        # Instead, we expect the (a!) argument to not be returned.
1465        assert (
1466            len(mutable_returns) == 0 or len(immutable_returns) == 0
1467        ), f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}"
1468        for ret in mutable_returns:
1469            assert any(ret.annotation == arg.annotation for arg in out_and_self), (
1470                'All mutable returns must be aliased either to a keyword argument, or to "self". '
1471                "Did you forget to mark an out argument as keyword-only?"
1472            )
1473        if self.arguments.out:
1474            # out= ops that return their mutable inputs are only really useful for method chaining.
1475            # And method chaining is only really useful if the thing you're returning is a plain Tensor.
1476            # So ideally, we'd enforce that out= ops with a single plain mutable tensor should return the tensor,
1477            # and all other types of out= op schemas should return void.
1478            # There are a bunch of existing out= ops that return tuples of tensors though, so we're stuck with allowing that.
1479            if any(a.type != BaseType(BaseTy.Tensor) for a in self.arguments.out):
1480                assert len(self.returns) == 0, (
1481                    "out= ops that accept tensor lists as out arguments "
1482                    "are expected to have no return type (since you can't do method chaining on them)"
1483                )
1484            else:
1485                # mutable keyword arguments whose name has a _scratch_ prefix are
1486                # scratch tensors for memory planning and should not be returned
1487                assert len(
1488                    [
1489                        arg
1490                        for arg in self.arguments.out
1491                        if not arg.name.startswith("_scratch_")
1492                    ]
1493                ) == len(
1494                    self.returns
1495                ), "Must return as many arguments as there are out arguments, or no return at all"
1496
1497        if self.name.name.inplace:
1498            self_a = self.arguments.self_arg
1499            assert (
1500                self_a
1501                and self_a.argument.annotation
1502                and self_a.argument.annotation.is_write
1503            )
1504            if self_a.argument.type == BaseType(BaseTy.Tensor):
1505                # All inplace ops with an ordinary `Tensor self` argument should return self,
1506                # to allow for method chaining.
1507                assert (
1508                    len(self.returns) == 1
1509                    and self.returns[0].annotation == self_a.argument.annotation
1510                )
1511            else:
1512                # You can't method chain on non-tensor self arguments though (like a List[Tensor])
1513                # so in all other cases we expect the return type to be none.
1514                assert len(self.returns) == 0
1515
1516        if self.arguments.tensor_options is not None:
1517            assert self.kind() == SchemaKind.functional, (
1518                "Found an operator that is not a functional or out variant, but has tensor options arguments. "
1519                "This is not allowed: tensor options arguments are only allowed for factory functions. "
1520                f"schema: {str(self)}"
1521            )
1522        if self.is_functional_fn():
1523            assert self.kind() == SchemaKind.functional, (
1524                "Found an operator that is not functional, but its overload name contains the string 'functional'. "
1525                "This is a special keyword in the codegen; please use a different overload name. "
1526                f"schema: {str(self)}"
1527            )
1528
1529    def is_functional_fn(self) -> bool:
1530        return "functional" in self.name.overload_name
1531
1532    def is_out_fn(self) -> bool:
1533        # Note [is_out_fn]
1534        #
1535        # out functions are the variants which take an explicit out= argument
1536        # to populate into.  We need to know if a schema corresponds to an
1537        # out function for several reasons:
1538        #
1539        #   - They codegen differently in C++ API
1540        #       - codegen to at::add_out rather than at::add
1541        #       - out argument is moved to front of C++ argument list
1542        #
1543        # out functions are DEFINED to be any function with a keyword-only
1544        # argument that is mutable.  In principle, this could lead to a
1545        # false positive if you define a function that mutates a
1546        # kwarg only argument, but this isn't the "true" output of this
1547        # function.  A more robust definition that would work in this
1548        # case would also look at:
1549        #
1550        #   - The output types.  Out functions take in the arguments
1551        #     they mutate and then return them again; this is sort
1552        #     of "definitionally" what makes something an out function.
1553        #     Historically, we DO check this for consistency.
1554        #   - Correspondence with pure variant.  An out function
1555        #     should have a signature equivalent to its pure variant,
1556        #     but just with extra kwargs for the output elements.  This
1557        #     is difficult to actually check for and historically
1558        #     we only do this check in tools/
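        #
        # Illustrative example (schemas assumed to mirror native_functions.yaml):
        # "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> (Tensor(a!))"
        # is an out function, while the inplace
        # "add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)"
        # is not, because its mutable argument is positional rather than keyword-only.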
1559        return bool(self.arguments.out)
1560
1561    def kind(self) -> SchemaKind:
1562        """
1563        What kind of schema is this?  A functional schema is one
1564        that returns a newly allocated output; an inplace schema
1565        modifies the self argument inplace; an out schema writes
1566        the result into an explicitly provided out argument.
1567        """
1568        is_out = bool(self.arguments.out)
1569        is_scratch = bool(
1570            [arg for arg in self.arguments.out if arg.name.startswith("_scratch_")]
1571        )
1572        is_inplace = self.name.name.inplace
1573        is_mutable = any(
1574            a.annotation is not None and a.annotation.is_write
1575            for a in self.arguments.post_self_positional
1576        )
1577        assert not (is_out and is_inplace)
1578        # out= and inplace schemas can also have post_self_positional mutable args,
1579        # but we give precedence to out= and inplace when deciding the schema kind.
1580        # Tradeoff: we probably don't want to have to teach codegen that looks at inplace ops
1581        # to also worry about mutable post_self_positional arguments,
1582        # but it seems like a much bigger lift to classify them as having a new schema kind.
1583        # The number of ops that fit in this strange category is small enough that
1584        # we can probably manually write code for them instead of forcing the codegen to handle them.
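        #
        # Illustrative classification: "add_.Tensor(...)" is SchemaKind.inplace,
        # "add.out(...)" is SchemaKind.out, and "add.Tensor(...)" is
        # SchemaKind.functional.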
1585        if is_inplace:
1586            return SchemaKind.inplace
1587        elif is_scratch:
1588            assert (
1589                is_out
1590            ), "invariant: all scratch operators are expected to be out= operators too"
1591            return SchemaKind.scratch
1592        elif is_out:
1593            assert (
1594                not is_scratch
1595            ), "We should not categorize a scratch op as an out variant. Check if the order of the if statements is as expected!"
1596            return SchemaKind.out
1597        elif is_mutable:
1598            return SchemaKind.mutable
1599        else:
1600            return SchemaKind.functional
1601
1602    # For every return:
1603    # - If the return aliases an input, we return the input name
1604    # - Otherwise, we return None.
1605    # If return names were enforced to be consistent with aliasing information, then we wouldn't need this.
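    # Illustrative example: for "add.out(..., Tensor(a!) out) -> (Tensor(a!))"
    # this returns ["out"], while for the functional "add.Tensor(...) -> Tensor"
    # it returns [None].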
1606    def aliased_return_names(self) -> list[str | None]:
1607        outs: list[str | None] = []
1608        for r in self.returns:
1609            aliased_args = [
1610                a
1611                for a in self.arguments.flat_all
1612                if a.annotation is not None and a.annotation == r.annotation
1613            ]
1614            if len(aliased_args) == 0:
1615                outs.append(None)
1616            elif len(aliased_args) == 1:
1617                outs.append(aliased_args[0].name)
1618            else:
1619                aliased_names = ", ".join(a.name for a in aliased_args)
1620                raise AssertionError(
1621                    f"Found a return ({r.name}) that aliases multiple inputs ({aliased_names})"
1622                )
1623        return outs
1624
1625    def signature(
1626        self,
1627        *,
1628        strip_default: bool = False,
1629        strip_view_copy_name: bool = False,
1630        keep_return_names: bool = False,
1631    ) -> FunctionSchema:
1632        """
1633        Certain schemas are 'related', in that they are simply
1634        inplace/out/functional versions of the same function.  This method
1635        factors these schemas into the "core" functional signature which
1636        is equal across all versions.
1637
1638        Here is the normalization that is applied to the schema to convert
1639        it to a signature:
1640        - The overload name is stripped (name is retained, since
1641          it expresses semantic content about what the function does)
1642        - Inplace is set False
1643        - Out arguments are stripped
1644        - Mutable post_self_positional args are converted to returns
1645        - Mutability annotations are stripped  (this is sound
1646          because you cannot overload on mutability annotation)
1647        - Return names are stripped since they are not overloadable and
1648          some variants have return names but some do not
1649        - TensorOptions are dropped
1650          because out= variants of factory functions don't include them
1651          (and we want to be able to pair up factory functions with their out variants)
1652
1653        Finally, we want to be able to pair up related "view" operators and their
1654        corresponding "view_copy" operators. We do this by optionally
1655        stripping the trailing "_copy" from the base name.
1656
1657        Example of a mutable op before and after:
1658
1659        f.func (Mutable operator):
1660        _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask)  # noqa: B950
1661
1662        f.func (Corresponding functional operator):
1663        _fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out)  # noqa: B950
1664
1665        f.func.signature() output:
1666        _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)  # noqa: B950
1667        """
1668
1669        def strip_ret_annotation(r: Return) -> Return:
1670            return Return(
1671                name=r.name if keep_return_names else None,
1672                type=r.type,
1673                annotation=None,
1674            )
1675
1676        base_name = self.name.name.base
1677        if strip_view_copy_name:
1678            if base_name.endswith("_copy"):
1679                base_name = base_name.replace("_copy", "")
1680            elif base_name.endswith("_scatter"):
1681                base_name = base_name.replace("scatter", "inverse")
1682
1683        # find mutable inputs that are not originally returned, and convert them to returns
1684        returns_from_mutable_inputs = tuple(
1685            # When we're grouping functions we strip the return names,
1686            # but when we're generating the actual functional variants then we follow
1687            # a convention for what to name the returns
1688            Return(
1689                name=f"{a.name}_out" if keep_return_names else None,
1690                type=a.type,
1691                annotation=None,
1692            )
1693            for a in itertools.chain(
1694                # Order is important here (otherwise e.g. inplace with mutable args
1695                # and out= with mutable args won't have the same signature)
1696                [self.arguments.self_arg.argument]
1697                if self.arguments.self_arg is not None
1698                else [],
1699                self.arguments.out,
1700                self.arguments.post_self_positional,
1701            )
1702            if a.annotation is not None
1703            and a.annotation.is_write
1704            and not any(a.annotation == r.annotation for r in self.returns)
1705        )
1706        original_returns = tuple(map(strip_ret_annotation, self.returns))
1707        # Ordering is important here. We expect the "mutable input" returns to come last.
1708        returns = original_returns + returns_from_mutable_inputs
1709
1710        args_sig = self.arguments.signature(strip_default=strip_default)
1711        # See Note [bernoulli.p schema]
1712        if str(self.name) == "bernoulli.p":
1713            args_sig = Arguments.parse(str(args_sig).replace("float p", "float p=0.5"))
1714
1715        return FunctionSchema(
1716            name=OperatorName(
1717                name=BaseOperatorName(
1718                    base=base_name,
1719                    inplace=False,
1720                    dunder_method=self.name.name.dunder_method,
1721                ),
1722                overload_name="",  # stripped
1723            ),
1724            arguments=args_sig,
1725            returns=returns,
1726        )
1727
1728    def view_signature(self) -> FunctionSchema:
1729        return self.signature(strip_view_copy_name=True)
1730
1731    def with_name(self, name: OperatorName) -> FunctionSchema:
1732        return FunctionSchema(
1733            name=name,
1734            arguments=self.arguments,
1735            returns=self.returns,
1736        )
1737
1738    @property
1739    def modifies_arguments(self) -> bool:
1740        return self.kind() in [SchemaKind.inplace, SchemaKind.out, SchemaKind.mutable]
1741
1742    def has_symint(self) -> bool:
1743        return self.arguments.has_symint_arg()
1744
1745    def __str__(self) -> str:
1746        all_arguments_str = str(self.arguments)
1747        if len(self.returns) == 1:
1748            returns = str(self.returns[0])  # omit parentheses
1749        else:
1750            returns = "(" + ", ".join(map(str, self.returns)) + ")"
1751        return f"{self.name}({all_arguments_str}) -> {returns}"
1752
1753
1754# Here is the rest of the data model, described more briefly.
1755
1756
1757# Simplified version for what actually shows up in built-ins.
1758# Look at alias_info.h for expanded syntax.  If you need the structure,
1759# you also need to make this structure recursive so it can be lined
1760# up with the type components too.  For primitives this isn't really
1761# necessary
1762@dataclass(frozen=True)
1763class Annotation:
1764    # Typically only has one element.  Not actually a set so
1765    # we can conveniently assume it is canonically ordered
1766    alias_set: tuple[str, ...]
1767    is_write: bool
1768    alias_set_after: tuple[str, ...]
1769
1770    @staticmethod
1771    def parse(ann: str) -> Annotation:
1772        # TODO: implement a proper parser if this gets more ugly
1773        # Regex Explanation:
1774        # Example: "a! -> a|b"
1775        # Group #1: alias before optional '|', required. Matches the first
1776        #   character 'a' in the example
1777        # Group #2: optional alias set after optional '|', matches empty string
1778        #   in the example
1779        # Group #3: optional "is write" flag, matches '!' in the example.
1780        # Group #4: optional section containing arrow, matches " -> a|b" in the
1781        #   example.
1782        # Group #5: optional alias set after the arrow; supports wildcard, matches "a|b"
1783        #   in the example.
1784        # Group #6: optional sub-section of the alias set after the arrow, matches "|b" in the
1785        #   example.
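        #
        # Illustrative results (a sketch of what the parse below produces):
        # "a!" yields alias_set=("a",), is_write=True, alias_set_after=();
        # "a -> *" yields alias_set=("a",), is_write=False, alias_set_after=("*",).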
1786        m = re.match(r"^([a-z])(\|[a-z])*(!?)( -> (\*|[a-z](\|[a-z])*))?$", ann)
1787
1788        assert m is not None, f"unrecognized alias annotation {ann}"
1789        before_alias = m.group(1) + (m.group(2) if m.group(2) else "")
1790        alias_set = tuple(before_alias.split("|"))
1791        is_write = m.group(3) == "!"
1792        assert not (
1793            is_write and len(alias_set) > 1
1794        ), f"an alias set larger than 1 cannot be marked mutable, got {ann} instead."
1795        after_set = tuple(m.group(5).split("|")) if m.group(5) else ()
1796        assert not (
1797            len(before_alias) > 1 and len(after_set) > 1
1798        ), f"before alias set and after alias set cannot be larger than 1 at the same time, got {ann} instead."
1799        r = Annotation(
1800            alias_set=alias_set, is_write=is_write, alias_set_after=after_set
1801        )
1802        assert str(r) == ann, f"{r} != {ann}"
1803        return r
1804
1805    def __str__(self) -> str:
1806        alias_set = "|".join(self.alias_set)
1807        if self.is_write:
1808            alias_set = f"{alias_set}!"
1809        alias_set_after = "|".join(self.alias_set_after)
1810        if alias_set_after:
1811            alias_set = f'{alias_set}{" -> "}{alias_set_after}'
1812        return alias_set
1813
1814
1815# The base class for the type system.  This is also loosely modeled
1816# off of jit_type.h, but we've simplified the hierarchy to focus
1817# in on the aspects of the type system that matter for code generation
1818# (for example, there's no SingleElementType subclass anymore).
1819# You never actually construct a Type; usually it's going to be one
1820# of the subclasses.  If Python had ADTs this would be one!
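#
# Illustrative examples of the parsing below: Type.parse("Tensor?") yields
# OptionalType(BaseType(BaseTy.Tensor)), and Type.parse("int[2]") yields
# ListType(elem=BaseType(BaseTy.int), size=2).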
1821@dataclass(frozen=True)
1822class Type:
1823    @staticmethod
1824    def parse(t: str) -> Type:
1825        r = Type._parse(t)
1826        assert str(r) == t, f"{r} != {t}"
1827        return r
1828
1829    @staticmethod
1830    def _parse(t: str) -> Type:
1831        m = re.match(r"^(.+)\?$", t)
1832        if m is not None:
1833            return OptionalType(Type.parse(m.group(1)))
1834        m = re.match(r"^(.+)\[([0-9]+)?\]$", t)
1835        if m is not None:
1836            size = int(m.group(2)) if m.group(2) is not None else None
1837            return ListType(elem=Type.parse(m.group(1)), size=size)
1838
1839        # '__torch__.torch.classes.' is the prefix for custom class
1840        m = re.match(r"^__torch__\.torch\.classes\.([a-zA-Z0-9_.]+)$", t)
1841        if m is not None:
1842            return CustomClassType(m.group(1))
1843        try:
1844            return BaseType(BaseTy[t])
1845        except KeyError as e:
1846            raise RuntimeError(f"unrecognized type {t}") from e
1847
1848    def __str__(self) -> str:
1849        raise NotImplementedError
1850
1851    # WARNING: These concepts are not very well-defined.  For example,
1852    # is "int?" nullable? How about "int?[]"?  They are defined
1853    # so we can conveniently generate legacy Declarations.yaml but
1854    # really we should probably just remove these at some point
1855
1856    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
1857        raise NotImplementedError
1858
1859    def is_tensor_like(self) -> bool:
1860        return self.is_base_ty_like(BaseTy.Tensor)
1861
1862    def is_generator_like(self) -> bool:
1863        return self.is_base_ty_like(BaseTy.Generator)
1864
1865    def is_symint_like(self) -> bool:
1866        return self.is_base_ty_like(BaseTy.SymInt)
1867
1868    def is_nullable(self) -> bool:
1869        raise NotImplementedError
1870
1871    def is_list_like(self) -> ListType | None:
1872        raise NotImplementedError
1873
1874
1875# Base types are simple, atomic types with no further structure
1876class BaseTy(Enum):
1877    Generator = auto()
1878    ScalarType = auto()
1879    Tensor = auto()
1880    int = auto()
1881    Dimname = auto()
1882    DimVector = auto()
1883    float = auto()
1884    str = auto()
1885    bool = auto()
1886    Layout = auto()
1887    Device = auto()
1888    DeviceIndex = auto()
1889    Scalar = auto()
1890    MemoryFormat = auto()
1891    QScheme = auto()
1892    Storage = auto()
1893    Stream = auto()
1894    SymInt = auto()
1895    SymBool = auto()
1896    ConstQuantizerPtr = auto()  # TODO: rename
1897    GraphModule = auto()
1898
1899
1900@dataclass(frozen=True)
1901class BaseType(Type):
1902    name: BaseTy
1903
1904    def __str__(self) -> str:
1905        return f"{self.name.name}"
1906
1907    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
1908        return self.name == base_ty
1909
1910    def is_nullable(self) -> bool:
1911        return False
1912
1913    def is_list_like(self) -> ListType | None:
1914        return None
1915
1916    def is_symint_like(self) -> bool:
1917        return self.name == BaseTy.SymInt
1918
1919
1920# Optional types may be specified, or may also be validly given None
1921@dataclass(frozen=True)
1922class OptionalType(Type):
1923    elem: Type
1924
1925    def __str__(self) -> str:
1926        return f"{self.elem}?"
1927
1928    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
1929        return self.elem.is_base_ty_like(base_ty)
1930
1931    def is_symint_like(self) -> bool:
1932        return self.elem.is_symint_like()
1933
1934    def is_nullable(self) -> bool:
1935        return True
1936
1937    def is_list_like(self) -> ListType | None:
1938        return self.elem.is_list_like()
1939
1940
1941# A type representing a PyTorch custom class
1942@dataclass(frozen=True)
1943class CustomClassType(Type):
1944    class_name: str
1945
1946    def __str__(self) -> str:
1947        """
1948        Return the class name, prefixed with __torch__.torch.classes.
1949        """
1950        return f"__torch__.torch.classes.{self.class_name}"
1951
1952    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
1953        return False
1954
1955    def is_symint_like(self) -> bool:
1956        return False
1957
1958    def is_nullable(self) -> bool:
1959        """
1960        Assume a custom class is not nullable.
1961        """
1962        return False
1963
1964    def is_list_like(self) -> ListType | None:
1965        return None
1966
1967
1968# List types specify that we may have multiples of an element.  We
1969# also support explicit sizes on list types, but these have
1970# some nontrivial semantics!  (However, for C++ API purposes, explicit
1971# sizes are mostly erased from the type system.)
1972#
1973# DANGER WILL ROBINSON: C++ elaboration depends on elem type; e.g.,
1974# int[] elaborates differently than bool[3]!
1975@dataclass(frozen=True)
1976class ListType(Type):
1977    elem: Type
1978    size: int | None
1979
1980    def __str__(self) -> str:
1981        size = f"{self.size}" if self.size else ""
1982        return f"{self.elem}[{size}]"
1983
1984    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
1985        return self.elem.is_base_ty_like(base_ty)
1986
1987    def is_symint_like(self) -> bool:
1988        return self.elem.is_symint_like()
1989
1990    def is_nullable(self) -> bool:
1991        return self.elem.is_nullable()
1992
1993    def is_list_like(self) -> ListType | None:
1994        return self
1995
1996
1997@dataclass(frozen=True)
1998class Argument:
1999    # NB: I didn't put kwarg_only as a boolean field here, unlike
2000    # c10::Argument, so that printing works correctly
2001
2002    name: str
2003    type: Type
2004    default: str | None
2005
2006    # The semantics of the annotation field are a little strange.
2007    #
2008    # Alias annotations parametrize Tensors (since Tensors are the only things
2009    # that can alias.)  This motivates why I write Tensor(a!)?  (and not, for
2010    # example, Tensor?(a!)), because the (a!) describes aliasing on the tensor,
2011    # which may be optional (i.e., the alias annotation should bind first to
2012    # Tensor, before the optional postfix annotation).
2013    #
2014    # However, despite being a property of Tensor, we (and c10::Argument)
2015    # store the annotation at the top level of the Argument, rather than
2016    # inside the embedded Tensor type.  In the C++ version of this
2017    # class, we then go to great lengths to mimic the type
2018    # structure in the annotation structure so we can correlate
2019    # annotations with types.
2020    #
2021    # Now, it turns out, in all applications in code generation, the
2022    # structure of annotated types is very simple.  So we just hard
2023    # code it here.  But if we ever do get anything more complex, this
2024    # model will have to change!
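    #
    # Illustrative example: the argument string "Tensor(a!)? out" parses to
    # type "Tensor?" with annotation "a!" (alias_set=("a",), is_write=True).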
2025    annotation: Annotation | None
2026
2027    @property
2028    def alias_info(self) -> Annotation | None:
2029        return self.annotation
2030
2031    @staticmethod
2032    def parse(arg: str) -> Argument:
2033        name: str
2034        default: str | None
2035        assert " " in arg, f"illegal argument '{arg}'"
2036        if "=" in arg:
2037            assert arg.count("=") == 1, f"illegal argument with default value: '{arg}'"
2038            type_and_annot_and_name, default = arg.split("=")
2039            type_and_annot, name = type_and_annot_and_name.rsplit(" ", 1)
2040            name_and_default = f"{name}={default}"
2041        else:
2042            type_and_annot, name_and_default = arg.rsplit(" ", 1)
2043            name = name_and_default
2044            default = None
2045        # TODO: deduplicate annotation matching with Return
2046        match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
2047        annotation: Annotation | None
2048        if match:
2049            # If you update this, make sure the __str__ still works too
2050            assert match.group(2) in [
2051                "",
2052                "?",
2053                "[]",
2054            ], "unrecognized alias analysis form with Tensor"
2055            type_s = "Tensor" + match.group(2)
2056            annotation = Annotation.parse(match.group(1))
2057        else:
2058            type_s = type_and_annot
2059            annotation = None
2060        type = Type.parse(type_s)
2061        r = Argument(
2062            name=name,
2063            type=type,
2064            default=default,
2065            annotation=annotation,
2066        )
2067        assert str(r) == arg, f"{str(r)} != {arg}"
2068        return r
2069
2070    @property
2071    def is_write(self) -> bool:
2072        return self.annotation is not None and self.annotation.is_write
2073
2074    def __str__(self) -> str:
2075        type = f"{self.type}"
2076        if self.annotation:
2077            assert type in ["Tensor", "Tensor?", "Tensor[]"]
2078            type = type.replace("Tensor", f"Tensor({self.annotation})")
2079        if self.name is None:
2080            return type
2081        else:
2082            mb_default = ""
2083            if self.default:
2084                mb_default = f"={self.default}"
2085            return f"{type} {self.name}{mb_default}"
2086
2087
2088@dataclass(frozen=True)
2089class Return:
2090    name: str | None
2091    type: Type
2092    annotation: Annotation | None
2093
2094    @property
2095    def alias_info(self) -> Annotation | None:
2096        return self.annotation
2097
2098    @staticmethod
2099    def parse(arg: str) -> Return:
2100        name: str | None
2101        if " " in arg:
2102            type_and_annot, name = arg.rsplit(" ", 1)
2103        else:
2104            type_and_annot = arg
2105            name = None
2106        match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
2107        annotation: Annotation | None
2108        if match:
2109            # If you update this, make sure the __str__ still works too
2110            assert match.group(2) in [
2111                "",
2112                "?",
2113                "[]",
2114            ], "unrecognized alias analysis form with Tensor"
2115            type_s = "Tensor" + match.group(2)
2116            annotation = Annotation.parse(match.group(1))
2117        else:
2118            type_s = type_and_annot
2119            annotation = None
2120        type = Type.parse(type_s)
2121        r = Return(
2122            name=name,
2123            type=type,
2124            annotation=annotation,
2125        )
2126        assert str(r) == arg, f"{str(r)} != {arg}"
2127        return r
2128
2129    @property
2130    def is_write(self) -> bool:
2131        return self.annotation is not None and self.annotation.is_write
2132
2133    def __str__(self) -> str:
2134        type = f"{self.type}"
2135        if self.annotation:
2136            assert type in ["Tensor", "Tensor?", "Tensor[]"]
2137            type = type.replace("Tensor", f"Tensor({self.annotation})")
2138        if self.name is None:
2139            return type
2140        else:
2141            return f"{type} {self.name}"
2142
2143
2144# Represents the self argument for functions that may be methods
2145@dataclass(frozen=True)
2146class SelfArgument:
2147    argument: Argument
2148
2149
2150# Bundle of arguments that represent a TensorOptions.  This is mostly
2151# relevant for the public C++ API but we bake it into the core data
2152# model because other APIs often have to interact with it
2153@dataclass(frozen=True)
2154class TensorOptionsArguments:
2155    dtype: Argument
2156    layout: Argument
2157    device: Argument
2158    pin_memory: Argument
2159
2160    def all(self) -> Sequence[Argument]:
2161        return [self.dtype, self.layout, self.device, self.pin_memory]
2162
2163
2164@dataclass(frozen=True)
2165class Arguments:
2166    # pre_self_positional is usually empty, but is notably non-empty
2167    # for where.self, where the condition argument comes before the
2168    # self argument
2169    pre_self_positional: tuple[Argument, ...]
2170    self_arg: SelfArgument | None
2171    post_self_positional: tuple[Argument, ...]
2172
2173    pre_tensor_options_kwarg_only: tuple[Argument, ...]
2174    tensor_options: TensorOptionsArguments | None
2175    # post_tensor_options is typically memory format, which should be
2176    # part of tensor options but isn't right now, and is usually
2177    # placed after the tensor options arguments
2178    post_tensor_options_kwarg_only: tuple[Argument, ...]
2179
2180    # Unlike in the previous codegen, we have factored out 'out' arguments
2181    # in the canonical representation, removing them from kwarg
2182    # arguments.  This choice is justified by numerous downstream
2183    # transformations which treat out arguments specially; additionally,
2184    # you can see that canonicity is not violated!
2185    out: tuple[Argument, ...]  # these are also kwarg-only
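    #
    # Illustrative grouping (assuming the standard add.out schema): for
    # "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out)",
    # self_arg holds `self`, post_self_positional is (`other`,),
    # pre_tensor_options_kwarg_only is (`alpha`,), and out is (`out`,).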
2186
2187    @property
2188    def flat_non_out(self) -> Sequence[Argument]:
2189        ret: list[Argument] = []
2190        ret.extend(self.flat_positional)
2191        ret.extend(self.flat_kwarg_only)
2192        return ret
2193
2194    @property
2195    def flat_positional(self) -> Sequence[Argument]:
2196        ret: list[Argument] = []
2197        ret.extend(self.pre_self_positional)
2198        if self.self_arg is not None:
2199            ret.append(self.self_arg.argument)
2200        ret.extend(self.post_self_positional)
2201        return ret
2202
2203    @property
2204    def post_self_positional_mutable(self) -> Sequence[Argument]:
2205        return [a for a in self.post_self_positional if a.is_write]
2206
2207    # NB: doesn't contain out arguments
2208    @property
2209    def flat_kwarg_only(self) -> Sequence[Argument]:
2210        ret: list[Argument] = []
2211        ret.extend(self.pre_tensor_options_kwarg_only)
2212        if self.tensor_options is not None:
2213            ret.extend(self.tensor_options.all())
2214        ret.extend(self.post_tensor_options_kwarg_only)
2215        return ret
2216
2217    @property
2218    def flat_all(self) -> Sequence[Argument]:
2219        ret: list[Argument] = []
2220        ret.extend(self.flat_positional)
2221        ret.extend(self.flat_kwarg_only)
2222        ret.extend(self.out)
2223        return ret
2224
2225    @property
2226    def non_out(
2227        self,
2228    ) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]:
2229        ret: list[Argument | SelfArgument | TensorOptionsArguments] = []
2230        ret.extend(self.positional)
2231        ret.extend(self.kwarg_only)
2232        return ret
2233
2234    @property
2235    def positional(self) -> Sequence[Argument | SelfArgument]:
2236        ret: list[Argument | SelfArgument] = []
2237        ret.extend(self.pre_self_positional)
2238        if self.self_arg is not None:
2239            ret.append(self.self_arg)
2240        ret.extend(self.post_self_positional)
2241        return ret
2242
2243    @property
2244    def kwarg_only(self) -> Sequence[Argument | TensorOptionsArguments]:
2245        ret: list[Argument | TensorOptionsArguments] = []
2246        ret.extend(self.pre_tensor_options_kwarg_only)
2247        if self.tensor_options is not None:
2248            ret.append(self.tensor_options)
2249        ret.extend(self.post_tensor_options_kwarg_only)
2250        return ret
2251
2252    @property
2253    def all(self) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]:
2254        ret: list[Argument | SelfArgument | TensorOptionsArguments] = []
2255        ret.extend(self.positional)
2256        ret.extend(self.kwarg_only)
2257        ret.extend(self.out)
2258        return ret
2259
2260    def mutable_arg_names(self) -> list[str]:
2261        return [
2262            a.name
2263            for a in self.flat_all
2264            if a.annotation is not None and a.annotation.is_write
2265        ]
2266
2267    def has_tensor_arg(self) -> bool:
2268        return any(a.type.is_tensor_like() for a in self.flat_non_out)
2269
2270    def has_symint_arg(self) -> bool:
2271        return any(a.type.is_symint_like() for a in self.flat_non_out)
2272
2273    def has_generator_arg(self) -> bool:
2274        return any(a.type.is_generator_like() for a in self.flat_non_out)
2275
2276    def signature(self, *, strip_default: bool = False) -> Arguments:
2277        # dataclasses.replace could be used here, but it is less
2278        # type safe so for now I've opted to type everything out
2279        def strip_arg_annotation(a: Argument) -> Argument:
2280            return Argument(
2281                name=a.name,
2282                type=a.type,
2283                default=a.default if not strip_default else None,
2284                annotation=None,
2285            )
2286
2287        return Arguments(
2288            pre_self_positional=tuple(
2289                map(strip_arg_annotation, self.pre_self_positional)
2290            ),
2291            self_arg=SelfArgument(strip_arg_annotation(self.self_arg.argument))
2292            if self.self_arg is not None
2293            else None,
2294            post_self_positional=tuple(
2295                map(strip_arg_annotation, self.post_self_positional)
2296            ),
2297            # Since TensorOptions are dropped, the post_tensor_options_kwargs are
2298            # converted to pre_tensor_options_kwargs
2299            pre_tensor_options_kwarg_only=tuple(
2300                map(strip_arg_annotation, self.pre_tensor_options_kwarg_only)
2301            )
2302            + tuple(map(strip_arg_annotation, self.post_tensor_options_kwarg_only)),
2303            # TensorOptions are dropped in signature,
2304            # so we can pair factory functions with their out= variants.
2305            tensor_options=None,
2306            post_tensor_options_kwarg_only=(),
2307            # out arguments are dropped in signature
2308            out=(),
2309        )
2310
2311    def remove_self_annotation(self) -> Arguments:
2312        assert self.self_arg is not None
2313        return dataclasses.replace(
2314            self,
2315            self_arg=SelfArgument(
2316                dataclasses.replace(self.self_arg.argument, annotation=None)
2317            ),
2318        )
2319
2320    def with_out_args(self, outs: list[Argument]) -> Arguments:
2321        assert len(self.out) == 0
2322        return dataclasses.replace(
2323            self,
2324            out=tuple(outs),
2325        )
2326
2327    @staticmethod
2328    def _preparse(args: str) -> tuple[list[Argument], list[Argument], list[Argument]]:
2329        positional: list[Argument] = []
2330        kwarg_only: list[Argument] = []
2331        out: list[Argument] = []
2332        arguments_acc = positional
2333
2334        # TODO: Use a real parser here; this will get bamboozled
2335        # by signatures that contain things like std::array<bool, 2> (note the space)
2336        for arg in args.split(", "):
2337            if not arg:
2338                continue
2339            if arg == "*":
2340                assert (
2341                    arguments_acc is positional
2342                ), "invalid syntax: kwarg-only specifier * can only occur once"
2343                arguments_acc = kwarg_only
2344                continue
2345            parg = Argument.parse(arg)
2346            # Currently, we rely directly on the invariant that there are NO
2347            # kwarg-only mutating arguments.  If you want to relax this,
2348            # we will need a more semantic way of matching that takes
2349            # into account return arguments.  In that case, you will have
2350            # to manage out computation a level up, in FunctionSchema.  See Note
2351            # [is_out_fn]
2352            if parg.annotation is not None and parg.annotation.is_write:
2353                if arguments_acc is positional:
2354                    pass  # do nothing
2355                elif arguments_acc is kwarg_only:
2356                    arguments_acc = out
2357            else:
2358                assert arguments_acc is not out
2359            arguments_acc.append(parg)
2360
2361        return positional, kwarg_only, out
2362
2363    @staticmethod
2364    def parse(args: str) -> Arguments:
2365        """
2366        Input: 'int x, int y, int z'
2367        """
2368
2369        # We do this in two phases.  First we parse into three
2370        # main categories: positional, kwarg_only, out.
2371        # Then, we reparse positional and kwarg_only to separate
2372        # out the self argument and tensor options arguments.
2373
2374        positional, kwarg_only, out = Arguments._preparse(args)
2375
2376        # Split self argument
2377        self_ix = None
2378        for i, a in enumerate(positional):
2379            if a.name == "self":
2380                self_ix = i
2381                break
2382        pre_self_positional: list[Argument]
2383        self_arg: SelfArgument | None
2384        post_self_positional: list[Argument]
2385        if self_ix is not None:
2386            pre_self_positional = positional[:self_ix]
2387            self_arg = SelfArgument(positional[self_ix])
2388            post_self_positional = positional[self_ix + 1 :]
2389        else:
2390            pre_self_positional = []
2391            self_arg = None
2392            post_self_positional = positional
2393
2394        # Group tensor options arguments
2395        pre_tensor_options_kwarg_only: list[Argument] = []
2396        tensor_options: TensorOptionsArguments | None = None
2397        post_tensor_options_kwarg_only: list[Argument] = []
2398        kwarg_only_acc = pre_tensor_options_kwarg_only
2399
2400        def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
2401            return lambda a: a.name == name and a.type in [ty, OptionalType(ty)]
2402
2403        predicates = [  # order matters
2404            pred("dtype", Type.parse("ScalarType")),
2405            pred("layout", Type.parse("Layout")),
2406            pred("device", Type.parse("Device")),
2407            pred("pin_memory", Type.parse("bool")),
2408        ]
2409
2410        i = 0
2411        while i < len(kwarg_only):
2412            # If there is enough space...
2413            if i <= len(kwarg_only) - len(predicates):
2414                # And the next len(predicates) arguments look like TensorOptions arguments
2415                if all(
2416                    p(a)
2417                    for p, a in zip(predicates, kwarg_only[i : i + len(predicates)])
2418                ):
2419                    assert kwarg_only_acc is pre_tensor_options_kwarg_only
2420                    # Group them together as one argument
2421                    tensor_options = TensorOptionsArguments(
2422                        dtype=kwarg_only[i],
2423                        layout=kwarg_only[i + 1],
2424                        device=kwarg_only[i + 2],
2425                        pin_memory=kwarg_only[i + 3],
2426                    )
2427                    i += len(predicates)
2428                    kwarg_only_acc = post_tensor_options_kwarg_only
2429                    continue
2430            kwarg_only_acc.append(kwarg_only[i])
2431            i += 1
2432
2433        return Arguments(
2434            pre_self_positional=tuple(pre_self_positional),
2435            self_arg=self_arg,
2436            post_self_positional=tuple(post_self_positional),
2437            pre_tensor_options_kwarg_only=tuple(pre_tensor_options_kwarg_only),
2438            tensor_options=tensor_options,
2439            post_tensor_options_kwarg_only=tuple(post_tensor_options_kwarg_only),
2440            out=tuple(out),
2441        )
2442
2443    def __str__(self) -> str:
2444        all_arguments: list[str] = []
2445        all_arguments.extend(map(str, self.flat_positional))
2446        if self.flat_kwarg_only or self.out:
2447            all_arguments.append("*")
2448        all_arguments.extend(map(str, self.flat_kwarg_only))
2449        all_arguments.extend(map(str, self.out))
2450        return ", ".join(all_arguments)
2451
2452    def __post_init__(self) -> None:
2453        # TODO: These invariants are weirdly asymmetric?
2454        # TODO: Fancier types?
2455        if self.self_arg is None:
2456            assert not self.pre_self_positional
2457        if self.tensor_options is None:
2458            assert not self.post_tensor_options_kwarg_only
2459
2460        # We don't allow pre_self_positional arguments to have mutable (write)
2461        # annotations, to keep things simple.
2462        mutable_pre_self_positionals = [
2463            a
2464            for a in self.pre_self_positional
2465            if a.annotation is not None and a.annotation.is_write
2466        ]
2467        assert (
2468            len(mutable_pre_self_positionals) == 0
2469        ), "mutable pre_self_positional arguments are not currently supported in the schema"
2470
2471
2472# Names that can validly appear as __iXXX__, indicating inplace operations.
2473# Taken from https://www.python.org/dev/peps/pep-0203/#new-methods
2474# NB: PyTorch hasn't actually implemented all of these
2475AUGMENTED_ASSIGNMENT_NAMES = [
2476    "add",
2477    "sub",
2478    "mul",
2479    "div",
2480    "mod",
2481    "pow",
2482    "lshift",
2483    "rshift",
2484    "and",
2485    "xor",
2486    "or",
2487]
2488
2489
2490# A BaseOperatorName is what we think of as the operator name, without
2491# the overload name.  Unusually, we don't represent this as just a
2492# string; instead, we directly represent a few important semantic
2493# bits of information we derive from the string: namely whether
2494# or not it's inplace (add_) and whether or not it's a double-underscore
2495# method (__add__)
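#
# Illustrative examples of the parsing below: "add_" yields base="add",
# inplace=True, dunder_method=False, while "__iadd__" yields base="add",
# inplace=True, dunder_method=True.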
2496@dataclass(frozen=True)
2497class BaseOperatorName:
2498    base: str
2499    inplace: bool
2500    dunder_method: bool
2501    # Note [Overload Ambiguity With Functional Variants]
2502    # A handful of operators have both a "mutable" and a "functional" variant.
2503    # (native_batch_norm is a good example, although this isn't the case today).
2504    # For those operators, the mutable and functional variant take in the same set of
2505    # arguments, but have different alias annotations.
2506    # this makes it ambiguous when you try to resolve an OverloadPacket into an overload,
2507    # given a set of input arguments.
2508    #
2509    # So instead of making the "functional" variant in this case a real overload, e.g:
2510    #   native_batch_norm (mutable variant)
2511    #   native_batch_norm.functional (functional variant)
2512    # we make it a new base operator,
2513    #   native_batch_norm_functional (functional variant)
2514    #
2515    # In an ideal world, we would probably invert this so the operators were:
2516    #   native_batch_norm.mutable (mutable variant)
2517    #   native_batch_norm (functional variant)
2518    #
2519    # Doing that is BC-breaking though, so we're stuck with the above modeling.
2520    functional_overload: bool = False
2521
2522    @staticmethod
2523    def parse(op: str) -> BaseOperatorName:
2524        assert op != ""
2525        assert not op.endswith("_out"), (
2526            "_out suffix is reserved and not permitted for operator names; "
2527            "did you mean to specify an out overload name instead?"
2528        )
2529        m = re.match(r"^__([^_]+)__$", op)
2530        if m is not None:
2531            dunder_method = True
2532            base = m.group(1)
2533            if any(base == f"i{n}" for n in AUGMENTED_ASSIGNMENT_NAMES):
2534                inplace = True
2535                base = base[1:]
2536            else:
2537                inplace = False
2538                # temporary, this is not intrinsically true but
2539                # has been historically true for dunder methods
2540                # we support  (but, if we ever got, say, __int__, this would
2541                # be wrong!)
2542                assert base[0] != "i"
2543        else:
2544            dunder_method = False
2545            base = op
2546            if base[-1] == "_":
2547                inplace = True
2548                base = base[:-1]
2549            else:
2550                inplace = False
2551
2552        # See Note [Overload Ambiguity With Functional Variants]
2553        functional_suffix = "_functional"
2554        if base.endswith(functional_suffix):
2555            functional_overload = True
2556            base = base[: -len(functional_suffix)]
2557            # This seems complicated and unnecessary, so for now we ban dunder methods
2558            # on ops that have a functional + mutable variant (like native_batch_norm).
2559            assert not dunder_method and not inplace
2560        else:
2561            functional_overload = False
2562
2563        r = BaseOperatorName(
2564            base=base,
2565            inplace=inplace,
2566            dunder_method=dunder_method,
2567            functional_overload=functional_overload,
2568        )
2569        assert str(r) == op, f"{str(r)} != {op}"
2570        return r
2571
2572    def __str__(self) -> str:
2573        if self.dunder_method:
2574            i = "i" if self.inplace else ""
2575            return f"__{i}{self.base}__"
2576        else:
2577            i = (
2578                "_"
2579                if self.inplace
2580                else "_functional"
2581                if self.functional_overload
2582                else ""
2583            )
2584            return f"{self.base}{i}"
2585
2586
2587# Operator name is the base operator name along with the (typically not
2588# user visible) overload string.
2589@dataclass(frozen=True)
2590class OperatorName:
2591    name: BaseOperatorName
2592    overload_name: str
2593
2594    @staticmethod
2595    def parse(op_name: str) -> OperatorName:
2596        if "." in op_name:
2597            name, overload_name = op_name.split(".", 1)
2598        else:
2599            name = op_name
2600            overload_name = ""
2601        r = OperatorName(name=BaseOperatorName.parse(name), overload_name=overload_name)
2602        assert str(r) == op_name, f"{str(r)} != {op_name}"
2603        return r
2604
2605    def __str__(self) -> str:
2606        if self.overload_name:
2607            return f"{self.name}.{self.overload_name}"
2608        else:
2609            return f"{self.name}"
2610
2611    # NB: This must be synchronized with the naming scheme in
2612    # aten/src/ATen/templates/Operators.h
2613    # Given a function schema "aten::op.overload(...)",
2614    # If there is no overload name, this returns f"{op}"
2615    # If there is an overload name, this returns f"{op}_{overload}"
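    # Illustrative examples: "add.Tensor" becomes "add_Tensor", while "relu"
    # (which has no overload name) stays "relu".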
2616    def unambiguous_name(self) -> str:
2617        if self.overload_name:
2618            return f"{self.name}_{self.overload_name}"
2619        else:
2620            return f"{self.name}"
2621
2622    def remove_inplace(self) -> OperatorName:
2623        return OperatorName(
2624            name=BaseOperatorName(
2625                base=self.name.base,
2626                inplace=False,
2627                dunder_method=self.name.dunder_method,
2628            ),
2629            overload_name=self.overload_name,
2630        )
2631
2632    def with_overload(self, overload: str) -> OperatorName:
2633        return OperatorName(
2634            name=BaseOperatorName(
2635                base=self.name.base,
2636                inplace=False,
2637                dunder_method=self.name.dunder_method,
2638            ),
2639            overload_name=overload,
2640        )
2641
2642
2643def gets_generated_out_inplace_wrapper(
2644    f: NativeFunction, g: NativeFunctionsGroup, b: BackendIndex
2645) -> bool:
2646    return (
2647        f.func.kind() is not SchemaKind.functional
2648        and not b.has_kernel(f)
2649        and b.has_kernel(g.functional)
2650    )
2651
2652
2653# NativeFunction objects that are views (f.is_view_op returns True)
2654# are added into a `NativeFunctionsViewGroup`, which we can use to
2655# easily access the generated (optional) view_copy NativeFunction.
2656# It's convenient to group them together, so we pair them up in NativeFunctionsViewGroup.
2657# See Note [Codegen'd {view}_copy Operators]
2658#
2659# One property of this representation is that in order for a view-like op to be part of
2660# a NativeFunctionsViewGroup, the "aliasing" version of that view op must exist.
2661# There's one case where that doesn't happen: we have a non-aliasing `narrow_copy.out` op,
2662# but don't have a corresponding aliasing `narrow.out` op.
2663# This means that `narrow_copy.out` won't be grouped into a NativeFunctionsViewGroup.
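#
# As a concrete illustration (a sketch; assuming the usual ATen ops): for the
# aliasing `as_strided` view op, the group would hold `as_strided` as `view`,
# the generated `as_strided_copy` as `view_copy`, and the in-place
# `as_strided_` as `view_inplace`.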
2664@dataclass(frozen=True)
2665class NativeFunctionsViewGroup:
2666    view: NativeFunction
2667    # Note: the {view}_copy operator is optional because we currently don't generate copy variants
2668    # for all view ops. Notably, we don't generate them for CompositeImplicitAutograd views
2669    # (we already get them "for free" through decomposition)
2670    view_copy: NativeFunction | None
2671    # view_inplace ops are also optional, but every view_inplace op should have an out-of-place variant.
2672    view_inplace: NativeFunction | None
2673
2674    def __post_init__(self) -> None:
2675        assert self.view.is_view_op
2676        if self.view_copy is None:
2677            assert not gets_generated_view_copy(self.view), (
2678                f"{str(self.view.func.name)} appears to be a new operator that aliases its inputs."
2679                " The codegen expects you to add a corresponding operator to native_functions.yaml:"
2680                f" {get_view_copy_name(self.view)!s}."
2681                " See Note [view_copy NativeFunctions] for details."
2682            )
2683        else:
2684            assert self.view_copy.func.name.name.base.endswith(("_copy", "_scatter"))
2685            assert self.view.func.signature() == self.view_copy.func.signature(
2686                strip_view_copy_name=True,
2687            )
2688            assert "view_copy" in self.view_copy.tags, (
2689                f"{str(self.view_copy.func.name)} appears to be a view_copy operator. The codegen expects"
2690                " view_copy operators to be annotated with the 'view_copy' tag in native_functions.yaml."
2691                " See Note [view_copy NativeFunctions] for details."
2692            )
2693        if self.view_inplace is not None:
2694            assert self.view.func.signature() == self.view_inplace.func.signature()
2695
2696        if self.view.has_composite_implicit_autograd_kernel:
2697            if self.view_inplace is not None:
2698                assert self.view_inplace.has_composite_implicit_autograd_kernel, (
2699                    f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
2700                    " both have CompositeImplicitAutograd kernels, or both not have composite kernels."
2701                )
2702        if self.view.has_composite_implicit_autograd_nested_tensor_kernel:
2703            if self.view_inplace is not None:
2704                assert (
2705                    self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel
2706                ), (
2707                    f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
2708                    " both have CompositeImplicitAutogradNestedTensor kernels, or both not have composite kernels."
2709                )
2710
2711    def functions(self, *, include_copy: bool = True) -> Iterator[NativeFunction]:
2712        yield self.view
2713        if self.view_inplace is not None:
2714            yield self.view_inplace
2715        if self.view_copy is not None and include_copy:
2716            yield self.view_copy
2717
2718    @property
2719    def root_name(self) -> str:
2720        return self.view.root_name
2721
2722    @property
2723    def composite(self) -> bool:
2724        # We currently assert that the "group" is consistent.
2725        # If the view op is composite, then its view_inplace op is too.
2726        return self.view.has_composite_implicit_autograd_kernel
2727
2728
2729def gets_generated_view_copy(f: NativeFunction) -> bool:
2730    # Only aliasing (view) operators get a copy variant.
2731    if not f.is_view_op:
2732        return False
2733    # We don't need to bother generating copy variants for CompositeImplicitAutograd ops,
2734    # because we can let them decompose into base view ops.
2735    if f.has_composite_implicit_autograd_kernel:
2736        return False
2737    # We also don't need to generate copy variants for inplace views.
2738    if "inplace_view" in f.tags:
2739        return False
2740    # Assume ops ending in _inverse have manually-defined copy variants
2741    # (e.g. slice_inverse() has the copy variant slice_scatter()).
2742    # We -could- probably generate these as well, but the codegen will be
2743    # slightly different, and hand-writing these few kernels keeps codegen
2744    # complexity lower.
2745    if f.func.name.name.base.endswith("_inverse"):
2746        return False
2747    return True
2748
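# To illustrate the rules above (a sketch; the named ops are only examples):
# a plain, non-composite view op such as `as_strided` gets a generated
# `as_strided_copy`; a CompositeImplicitAutograd view decomposes instead of
# getting a copy variant; an op tagged `inplace_view` (e.g. `transpose_`)
# gets none; and `slice_inverse` is skipped because its copy variant,
# `slice_scatter`, is hand-written.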
2749
2750# Given a NativeFunction that corresponds to a view op,
2751# returns the OperatorName of the corresponding "copy" variant of the op.
2752def get_view_copy_name(f: NativeFunction) -> OperatorName:
2753    # Right now, when asking for a view op's corresponding "view_copy" name,
2754    # we assert for sanity that the op is allowed to have a generated view_copy variant
2755    # (we can do this because gets_generated_view_copy() tells us which ops get a generated view_copy op).
2756    # However, narrow_copy() already exists as an op directly in native_functions.yaml.
2757    # I'm hardcoding narrow_copy here for now to maintain the assert,
2758    # but we could also just get rid of the assert.
2759    list_of_ops_with_explicit_view_copy_operators = ["narrow"]
2760    if str(f.func.name) not in list_of_ops_with_explicit_view_copy_operators:
2761        assert gets_generated_view_copy(f)
2762
2763    base_name = f"{f.func.name.name.base}_copy"
2764    view_copy_name = OperatorName(
2765        name=BaseOperatorName(
2766            base=base_name, inplace=False, dunder_method=f.func.name.name.dunder_method
2767        ),
2768        overload_name=f.func.name.overload_name,
2769    )
2770    return view_copy_name
2771
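# For example (illustrative only): for the `expand` view op, the helper above
# yields the OperatorName `expand_copy`. The overload name is preserved, so a
# hypothetical view op `foo.out` would map to `foo_copy.out`.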
2772
2773# Helper functions for parsing argument lists (both inputs and returns)
2774
2775
2776def parse_returns(return_decl: str) -> tuple[Return, ...]:
2777    """
2778    Input: '()'
2779    Output: ()
2780    """
2781    if return_decl == "()":
2782        return ()
2783    if return_decl[0] == "(" and return_decl[-1] == ")":
2784        return_decl = return_decl[1:-1]
2785    return tuple(Return.parse(arg) for arg in return_decl.split(", "))
2786
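# A rough illustration of accepted return declarations (a sketch):
#
#   parse_returns("()")                # ()
#   parse_returns("Tensor")            # a 1-tuple containing one Return
#   parse_returns("(Tensor, Tensor)")  # a 2-tuple, one Return per element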
2787
2788# A Precompute instance consists of a map from kernel argument name
2789# to the list of Argument instances that should replace that
2790# kernel argument in the impl function.
2791@dataclass(frozen=True)
2792class Precompute:
2793    # A map from kernel argument name -> a list of precomputed
2794    # elements that replaces/supersedes it.
2795    replace: dict[str, list[Argument]]
2796    # List of precomputed args added without replacement
2797    add: list[Argument]
2798
2799    @staticmethod
2800    def parse(src: object) -> Precompute:
2801        assert isinstance(src, list)
2802
2803        # src is a list of strings of the format:
2804        #   {kernel param name} -> {replacement decl}[, {replacement decl}, ...]
2805        #   [{add decl}[, {add decl}, ...]]
2806        # The last line is optional and contains the precomputed parameters that are
2807        # added without replacement.
2808        # The other lines are parsed to determine which precomputed elements
2809        # should replace which kernel arguments.
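        #
        # For illustration (a hypothetical entry, not taken from any yaml file),
        # a src such as
        #
        #   ["kernel_size -> int kH, int kW",
        #    "int batch_size"]
        #
        # replaces the kernel argument `kernel_size` with the precomputed
        # arguments `kH` and `kW`, and additionally adds `batch_size` without
        # replacing anything.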
2810        add_args = []
2811        if " -> " not in src[-1]:
2812            add_list = src[-1].split(",")
2813            add_args = [Argument.parse(name.strip()) for name in add_list]
2814            src = src[:-1]
2815
2816        replace = {}
2817        for raw_replace_item in src:
2818            assert isinstance(raw_replace_item, str)
2819            assert " -> " in raw_replace_item, (
2820                "precomputed parameters without replacement"
2821                " are allowed only in the last line"
2822            )
2823
2824            arg, with_list_raw = raw_replace_item.split(" -> ")
2825            assert (
2826                " " not in arg
2827            ), f"illegal kernel param name '{arg}' in precomputed parameters"
2828            with_list = with_list_raw.split(",")
2829            with_list_args = [Argument.parse(name.strip()) for name in with_list]
2830            replace[arg] = with_list_args
2831
2832        r = Precompute(replace=replace, add=add_args)
2833        assert r.to_list() == src, "r.to_list() != src"
2834        return r
2835
2836    def __post_init__(self) -> None:
2837        # Template parameter names are all-uppercase, so an all-uppercase
2838        # argument name here would be ambiguous with them.
2839        for a in self.add:
2840            assert a.name.upper() != a.name
2841        for args in self.replace.values():
2842            for a in args:
2843                assert a.name.upper() != a.name
2844
2845    def to_list(self) -> list[str]:
2846        replace_list = []
2847        for kernel_param, replacement_params in self.replace.items():
2848            replacements = ", ".join(str(param) for param in replacement_params)
2849            replace_list.append(f"{kernel_param} -> {replacements}")
2850
2851        return replace_list
2852