from __future__ import annotations

import dataclasses
import itertools
import re
from dataclasses import dataclass
from enum import auto, Enum
from typing import Callable, Iterator, Sequence

from torchgen.utils import assert_never, NamespaceHelper, OrderedSet


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                           DATA MODEL
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Some general principles for our data model.
#
# - Stop using C++ data types as the internal data representation
#   format.  Instead, the internal data structures are centered
#   around JIT schema representation.  This avoids a big problem
#   with the old codegen where we read in all the types from
#   native_functions.yaml and then immediately had to retranslate
#   them into C++ types.
#
# - More semantic data representation.  Instead of representing
#   everything as dicts and strings, we define dataclasses for
#   every interesting entity the code generation has to deal with.
#   These dataclasses have strong semantic invariants: for example,
#   we generally require them to roundtrip losslessly into the
#   form they were parsed from.  These structures are immutable
#   and you're expected to populate information once during
#   construction.


# Represent a source location; used for better error reporting
@dataclass(frozen=True)
class Location:
    file: str
    line: int

    def __str__(self) -> str:
        return f"{self.file}:{self.line}"


# Valid values of the 'variants' field in native_functions.yaml
class Variant(Enum):
    function = auto()
    method = auto()


# Default kernel namespace
DEFAULT_KERNEL_NAMESPACE = "at::native"

# NOTE: Keep the list in sync with `DispatchKey` in c10/core/DispatchKey.h
BACKEND_COMPONENTS = "CPU CUDA HIP XLA MTIA MPS IPU XPU HPU VE Lazy Meta PrivateUse1 PrivateUse2 PrivateUse3".split()
FUNCTIONALITY_KEYS = [
    "",
    "Quantized",
    "Sparse",
    "SparseCsr",
    "NestedTensor",
    "Autograd",
]

# This list guards dispatches that can be used in derivatives.yaml
# For now we omit AutogradFunctionality and AutogradOther
AUTOGRAD_KEYS = ["AutogradNestedTensor"] + [
    "Autograd" + component for component in BACKEND_COMPONENTS
]
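# A quick illustration (comment only, not used by the codegen): the per-backend
# dispatch key names are simply the cross product of the functionality prefixes
# above with the backend components, e.g.
#
#     >>> [fk + bc for fk in FUNCTIONALITY_KEYS for bc in BACKEND_COMPONENTS][:4]
#     ['CPU', 'CUDA', 'HIP', 'XLA']
#     >>> AUTOGRAD_KEYS[:3]
#     ['AutogradNestedTensor', 'AutogradCPU', 'AutogradCUDA']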
FRAGMENT_NAMESPACES = {"quantized", "quantized_decomposed"}


# This doesn't have to be in sync with the header, it only needs to contain
# entries that we actually use in the codegen or want pyi entries for
class DispatchKey(Enum):
    Undefined = 0
    CatchAll = Undefined

    FPGA = auto()
    MAIA = auto()
    Vulkan = auto()
    Metal = auto()
    MKLDNN = auto()
    OpenGL = auto()
    OpenCL = auto()
    IDEEP = auto()
    CustomRNGKeyId = auto()
    MkldnnCPU = auto()
    Sparse = auto()
    SparseCsr = auto()
    NestedTensor = auto()
    Dense = auto()

    PythonTLSSnapshot = auto()
    PreDispatch = auto()
    PythonDispatcher = auto()
    Python = auto()
    FuncTorchDynamicLayerBackMode = auto()
    ZeroTensor = auto()
    Conjugate = auto()
    Negative = auto()
    BackendSelect = auto()
    Named = auto()
    AutogradOther = auto()
    AutogradFunctionality = auto()
    AutogradNestedTensor = auto()
    Tracer = auto()
    Autocast = auto()
    AutocastCPU = auto()
    AutocastCUDA = auto()
    Batched = auto()
    VmapMode = auto()
    FuncTorchGradWrapper = auto()
    FuncTorchBatched = auto()
    BatchedNestedTensor = auto()
    FuncTorchVmapMode = auto()
    FuncTorchDynamicLayerFrontMode = auto()
    Functionalize = auto()
    TESTING_ONLY_GenericWrapper = auto()
    TESTING_ONLY_GenericMode = auto()

    ADInplaceOrView = auto()
    Autograd = auto()
    CompositeImplicitAutograd = auto()
    CompositeImplicitAutogradNestedTensor = auto()
    CompositeExplicitAutograd = auto()
    CompositeExplicitAutogradNonFunctional = auto()
    FuncTorchBatchedDecomposition = auto()

    # BEGIN autogenerated
    CPU = auto()
    CUDA = auto()
    HIP = auto()
    XLA = auto()
    MTIA = auto()
    MPS = auto()
    IPU = auto()
    XPU = auto()
    HPU = auto()
    VE = auto()
    Lazy = auto()
    Meta = auto()
    PrivateUse1 = auto()
    PrivateUse2 = auto()
    PrivateUse3 = auto()
    QuantizedCPU = auto()
    QuantizedCUDA = auto()
    QuantizedHIP = auto()
    QuantizedXLA = auto()
    QuantizedMTIA = auto()
    QuantizedMPS = auto()
    QuantizedIPU = auto()
    QuantizedXPU = auto()
    QuantizedHPU = auto()
    QuantizedVE = auto()
    QuantizedLazy = auto()
    QuantizedMeta = auto()
    QuantizedPrivateUse1 = auto()
    QuantizedPrivateUse2 = auto()
    QuantizedPrivateUse3 = auto()
    SparseCPU = auto()
    SparseCUDA = auto()
    SparseHIP = auto()
    SparseXLA = auto()
    SparseMTIA = auto()
    SparseMPS = auto()
    SparseIPU = auto()
    SparseXPU = auto()
    SparseHPU = auto()
    SparseVE = auto()
    SparseLazy = auto()
    SparseMeta = auto()
    SparsePrivateUse1 = auto()
    SparsePrivateUse2 = auto()
    SparsePrivateUse3 = auto()
    SparseCsrCPU = auto()
    SparseCsrCUDA = auto()
    SparseCsrHIP = auto()
    SparseCsrXLA = auto()
    SparseCsrMTIA = auto()
    SparseCsrMPS = auto()
    SparseCsrIPU = auto()
    SparseCsrXPU = auto()
    SparseCsrHPU = auto()
    SparseCsrVE = auto()
    SparseCsrLazy = auto()
    SparseCsrMeta = auto()
    SparseCsrPrivateUse1 = auto()
    SparseCsrPrivateUse2 = auto()
    SparseCsrPrivateUse3 = auto()
    NestedTensorCPU = auto()
    NestedTensorCUDA = auto()
    NestedTensorHIP = auto()
    NestedTensorXLA = auto()
    NestedTensorMTIA = auto()
    NestedTensorMPS = auto()
    NestedTensorIPU = auto()
    NestedTensorXPU = auto()
    NestedTensorHPU = auto()
    NestedTensorVE = auto()
    NestedTensorLazy = auto()
    NestedTensorMeta = auto()
    NestedTensorPrivateUse1 = auto()
    NestedTensorPrivateUse2 = auto()
    NestedTensorPrivateUse3 = auto()
    AutogradCPU = auto()
    AutogradCUDA = auto()
    AutogradHIP = auto()
    AutogradXLA = auto()
    AutogradMTIA = auto()
    AutogradMPS = auto()
    AutogradIPU = auto()
    AutogradXPU = auto()
    AutogradHPU = auto()
    AutogradVE = auto()
    AutogradLazy = auto()
    AutogradMeta = auto()
    AutogradPrivateUse1 = auto()
    AutogradPrivateUse2 = auto()
    AutogradPrivateUse3 = auto()
    # END autogenerated

    def __str__(self) -> str:
        return self.name

    def lower(self) -> str:
        return str(self).lower()

    @staticmethod
    def parse(value: str) -> DispatchKey:
        for k, v in DispatchKey.__members__.items():
            if k == value:
                return v
        raise AssertionError(f"unknown dispatch key {value}")


class _TorchDispatchModeKey(Enum):
    FAKE = auto()
    PROXY = auto()
    FUNCTIONAL = auto()


def codegen_per_backend_entries() -> str:
    r = []
    for fk in FUNCTIONALITY_KEYS:
        for bc in BACKEND_COMPONENTS:
            r.append(f"    {fk}{bc} = auto()")
    return "\n".join(r)
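# Illustrative only: the helper above reproduces the body of the
# "# BEGIN autogenerated" block, and DispatchKey.parse maps strings back to
# enum members, e.g.
#
#     >>> codegen_per_backend_entries().splitlines()[0].strip()
#     'CPU = auto()'
#     >>> DispatchKey.parse("QuantizedCPU") is DispatchKey.QuantizedCPU
#     True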
for fk in FUNCTIONALITY_KEYS:
    for bc in BACKEND_COMPONENTS:
        if not hasattr(DispatchKey, fk + bc):
            r = codegen_per_backend_entries()
            print(r)
            raise RuntimeError(
                f"Missing {fk}{bc} from DispatchKey enum. Here is the autogenerated list we expect to have:\n\n{r}"
            )


STRUCTURED_DISPATCH_KEYS = {
    DispatchKey.MPS,
    DispatchKey.CUDA,
    DispatchKey.CPU,
    DispatchKey.XPU,
}
UFUNC_DISPATCH_KEYS = {DispatchKey.CUDA, DispatchKey.CPU}

# Set of supported dispatch keys
dispatch_keys = [
    DispatchKey.CPU,
    DispatchKey.SparseCPU,
    DispatchKey.SparseCsrCPU,
    DispatchKey.MkldnnCPU,
    DispatchKey.CUDA,
    DispatchKey.MPS,
    DispatchKey.XPU,
    DispatchKey.SparseCUDA,
    DispatchKey.SparseCsrCUDA,
    DispatchKey.QuantizedCPU,
    DispatchKey.QuantizedCUDA,
    DispatchKey.CompositeImplicitAutograd,
    DispatchKey.CompositeImplicitAutogradNestedTensor,
    DispatchKey.CompositeExplicitAutograd,
    DispatchKey.CompositeExplicitAutogradNonFunctional,
    DispatchKey.NestedTensorCPU,
    DispatchKey.NestedTensorCUDA,
    # Meta is a magic key: it is automatically generated for structured
    # kernels
    DispatchKey.Meta,
    DispatchKey.SparseMeta,
    DispatchKey.SparseCsrMeta,
    DispatchKey.QuantizedMeta,
    DispatchKey.NestedTensorMeta,
    DispatchKey.ZeroTensor,
]


# Dispatch keys that "support all backends".  These codegen slightly
# differently than backend-specific keys.
def is_generic_dispatch_key(dk: DispatchKey) -> bool:
    return dk in {
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
    }


# CUDA specific dispatch keys
def is_cuda_dispatch_key(dk: DispatchKey) -> bool:
    return dk in {
        DispatchKey.CUDA,
        DispatchKey.QuantizedCUDA,
        DispatchKey.SparseCUDA,
        DispatchKey.SparseCsrCUDA,
        DispatchKey.NestedTensorCUDA,
        DispatchKey.AutogradCUDA,
    }


# XPU specific dispatch keys
def is_xpu_dispatch_key(dk: DispatchKey) -> bool:
    return dk in {
        DispatchKey.XPU,
        DispatchKey.QuantizedXPU,
        DispatchKey.SparseXPU,
        DispatchKey.SparseCsrXPU,
        DispatchKey.NestedTensorXPU,
        DispatchKey.AutogradXPU,
    }


# Structured kernel generation is only supported for certain key types;
# otherwise use old-style
def is_structured_dispatch_key(dk: DispatchKey) -> bool:
    return dk in STRUCTURED_DISPATCH_KEYS


def is_ufunc_dispatch_key(dk: DispatchKey) -> bool:
    # For now, ufunc dispatch keys coincide with structured keys
    return dk in UFUNC_DISPATCH_KEYS


# This is oddly named ScalarType and not DType for symmetry with C++
class ScalarType(Enum):
    Byte = auto()
    Char = auto()
    Short = auto()
    Int = auto()
    Long = auto()
    Half = auto()
    Float = auto()
    Double = auto()
    ComplexHalf = auto()
    ComplexFloat = auto()
    ComplexDouble = auto()
    Bool = auto()
    BFloat16 = auto()
    Float8_e5m2 = auto()
    Float8_e5m2fnuz = auto()
    Float8_e4m3fn = auto()
    Float8_e4m3fnuz = auto()

    def __str__(self) -> str:
        return self.name

    @staticmethod
    def maybe_parse(value: str) -> ScalarType | None:
        for k, v in ScalarType.__members__.items():
            if k == value:
                return v
        return None

    @staticmethod
    def parse(value: str) -> ScalarType:
        mb_r = ScalarType.maybe_parse(value)
        assert mb_r is not None, f"unknown dtype {value}"
        return mb_r

    @staticmethod
    def parse_set(values: str) -> OrderedSet[ScalarType]:
        dtypes: OrderedSet[ScalarType] = OrderedSet()
        for value in values.split(", "):
            if value in DTYPE_CLASSES:
                dtypes.update(DTYPE_CLASSES[value])
            else:
                dtypes.add(ScalarType.parse(value))
        return dtypes


DTYPE_CLASSES: dict[str, OrderedSet[ScalarType]] = {}
# NB: Integral doesn't include boolean
DTYPE_CLASSES["Integral"] = OrderedSet(
    [
        ScalarType.Byte,
        ScalarType.Char,
        ScalarType.Int,
        ScalarType.Long,
        ScalarType.Short,
    ]
)
# NB: Floating doesn't include low precision types
DTYPE_CLASSES["Floating"] = OrderedSet([ScalarType.Float, ScalarType.Double])
DTYPE_CLASSES["Complex"] = OrderedSet(
    [ScalarType.ComplexFloat, ScalarType.ComplexDouble]
)
DTYPE_CLASSES["All"] = DTYPE_CLASSES["Integral"] | DTYPE_CLASSES["Floating"]
DTYPE_CLASSES["AllAndComplex"] = DTYPE_CLASSES["All"] | DTYPE_CLASSES["Complex"]
DTYPE_CLASSES["FloatingAndComplex"] = (
    DTYPE_CLASSES["Floating"] | DTYPE_CLASSES["Complex"]
)
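# Illustrative only: parse_set accepts both bare dtype names and the class
# names registered in DTYPE_CLASSES, expanding the latter in order, e.g.
#
#     >>> [str(t) for t in ScalarType.parse_set("Floating, Half")]
#     ['Float', 'Double', 'Half']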
# Represents the valid entries for ufunc_inner_loop in native_functions.yaml.
# NB: if you add a new UfuncKey, you will need to teach torchgen.dest.ufunc
# how to process it.  Most logic will ignore keys it doesn't understand, so your
# new key will get silently ignored until you hook in logic to deal with it.
class UfuncKey(Enum):
    # These are low level keys that represent exactly one particular
    # instantiation of the kernel produced by codegen
    CUDAFunctor = auto()
    CUDAFunctorOnOther = auto()
    CUDAFunctorOnSelf = auto()

    CPUScalar = auto()
    CPUVector = auto()

    # These are the ones users will usually specify, and
    # implicitly "fill in" the low level keys
    ScalarOnly = auto()  # CUDA*, CPUScalar
    Generic = auto()  # CUDA*, CPU*

    def __str__(self) -> str:
        return self.name

    @staticmethod
    def parse(value: str) -> UfuncKey:
        for k, v in UfuncKey.__members__.items():
            if k == value:
                return v
        raise AssertionError(f"unknown ufunc key {value}")


class DeviceCheckType(Enum):
    NoCheck = 0
    ExactSame = 1


class ViewSchemaKind(Enum):
    aliasing = auto()
    aliasing_inplace = auto()
    non_aliasing = auto()


# The basic input to the code generation is native_functions.yaml.
# The name "native", BTW, comes from the distinction between native
# functions and legacy TH functions.  The legacy TH functions are gone,
# but the "native" descriptor has stuck.
#
# NativeFunction models a single entry in native_functions.yaml.  Its
# fields roughly correspond to what you would see in the YAML itself,
# but after canonicalization and parsing has occurred.
#
# You can see some of the overall design patterns for how we set up
# dataclasses in this class, but we will defer a complete discussion
# of this to FunctionSchema.
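# For orientation, a schematic native_functions.yaml entry (abridged and
# illustrative, not copied verbatim from the real file) exercising the fields
# parsed by NativeFunction.from_yaml below:
#
#     - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
#       device_check: NoCheck
#       structured: True
#       structured_inherits: TensorIteratorBase
#       dispatch:
#         CPU, CUDA: add_out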
@dataclass(frozen=True)
class NativeFunction:
    # The namespace for this operator.  For example, if we have "at::add"
    # then the namespace would be "at".  This enables ops to be registered
    # through the same DSL with a custom namespace.  If not specified, the
    # default namespace would be "at".
    namespace: str

    # The function schema of the operator in question.  This schema
    # has been parsed; see FunctionSchema for more about its structure.
    # (This type is quoted as we are forward referencing a type
    # defined later in the file.  I opted for this ordering of the
    # classes for expository clarity.)
    func: FunctionSchema

    # Whether or not to generate mutable tensor arguments like regular
    # ones
    use_const_ref_for_mutable_tensors: bool

    # Whether or not to omit automatic generation of a DeviceGuard
    device_guard: bool

    # How to emit automatic generation of device check
    device_check: DeviceCheckType

    # What python module to put the function in
    python_module: str | None

    # TODO: figure out what this does
    category_override: str | None

    # If no variants are specified in native_functions.yaml, this is
    # assumed to be {'function'}.
    variants: set[Variant]

    # Whether or not we should skip generating registrations for
    # this kernel.  This is a bit of a double-edged sword, as manual
    # registrations don't participate in codegen-based selective build!
    manual_kernel_registration: bool

    # Whether or not to skip generating TensorMethod/Functions bindings
    # for this kernel.  Technically, this doesn't actually skip generating
    # the binding; instead, the binding gets generated to __dispatch_{funcname}
    # so you can make use of the normal binding if you need it.
    manual_cpp_binding: bool

    # The location in the YAML file where this native function entry was
    # defined.  This is for conveniently reporting error messages!
    loc: Location

    # A list of operators that are expected to be auto-generated for this NativeFunction.
    # Note: This list isn't actually directly used by the codegen to generate anything.
    # Instead, the codegen figures out what operators to generate purely based off of
    # function schema, and uses the autogen declarations to error check.
    # We expect every NativeFunction that gets auto-generated to be explicitly called out
    # in native_functions.yaml
    autogen: list[OperatorName]

    # If non-empty, this kernel is subject to ufunc codegen.
    # Sorted by ufunc_key
    ufunc_inner_loop: dict[UfuncKey, UfuncInnerLoop]

    # Whether or not this out function is a "structured kernel".  Structured
    # kernels are defined a little differently from normal kernels; in
    # particular, their shape checking logic is defined separately from
    # the kernel.  Only out functions can be structured; other functions
    # delegate to the out function using the structured_delegate keyword.
    # Every structured kernel must have at least an out and a functional
    # variant.
    structured: bool

    # Whether or not this non-out function is a structured kernel, defined
    # in terms of the out kernel referenced by the string here.
    structured_delegate: OperatorName | None

    # Only valid for structured kernels.  Specifies alternative of what
    # to inherit from when defining the meta class for the structured
    # operator.  This will usually be TensorIteratorBase.  This also
    # changes the semantics of set_output to call the parent class.
    structured_inherits: str | None

    # Structured kernels can declare elements as "precomputed".  These elements
    # are returned by the meta function in one struct and passed to the impl
    # function in lieu of certain kernel arguments that these precomputed
    # elements supersede.  Information about the names and types of these
    # precomputed elements and how they correspond to kernel arguments is stored
    # in this member, if applicable.
    precomputed: Precompute | None

    # Argument names whose default should be excluded from the C++ interface.
    # Intended for resolving overload ambiguities between signatures.
    cpp_no_default_args: set[str]

    # Note [Abstract ATen methods]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # An abstract ATen method is one whose dispatch differs between
    # types.  These are implemented in derived types (with a
    # standard (throwing) definition in Type).  A concrete ATen
    # method is one which has the same dispatch for all types;
    # we just implement it in the base Type.  This is exposed
    # in Declarations.yaml via a field named 'abstract'.
    is_abstract: bool

    # Whether or not the NativeFunction contains a backend-agnostic kernel
    has_composite_implicit_autograd_kernel: bool
    has_composite_implicit_autograd_nested_tensor_kernel: bool
    has_composite_explicit_autograd_kernel: bool
    has_composite_explicit_autograd_non_functional_kernel: bool

    # Tags are used to describe semantic information about (groups of) operators
    # that isn't easily inferrable directly from the operator's schema.
    tags: set[str]

    # NB: The benefit of defining a dataclass is that we automatically get
    # a constructor defined for all the fields we specify.  No need
    # to explicitly write it out.

    # We parse both the NativeFunction and backend-specific information about it,
    # which is stored in a corresponding BackendIndex.
    @staticmethod
    def from_yaml(
        ei: dict[str, object],
        loc: Location,
        valid_tags: set[str],
        ignore_keys: set[DispatchKey] | None = None,
    ) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
        """
        Parse a NativeFunction from a dictionary as directly parsed
        from native_functions.yaml
        """
        e = ei.copy()

        funcs = e.pop("func")
        assert isinstance(funcs, str), f"not a str: {funcs}"
        # only support one level of namespace.  E.g., aten::add
        namespace_helper = NamespaceHelper.from_namespaced_entity(
            namespaced_entity=funcs, max_level=1
        )
        namespace = namespace_helper.get_cpp_namespace(default="aten")
        func = FunctionSchema.parse(namespace_helper.entity_name)

        cpp_no_default_args_list = e.pop("cpp_no_default_args", [])
        assert isinstance(cpp_no_default_args_list, list)
        cpp_no_default_args = set(cpp_no_default_args_list)

        use_const_ref_for_mutable_tensors = e.pop(
            "use_const_ref_for_mutable_tensors", False
        )
        assert isinstance(use_const_ref_for_mutable_tensors, bool)

        variants_s = e.pop("variants", "function")
        assert isinstance(variants_s, str)
        variants: set[Variant] = set()
        for v in variants_s.split(", "):
            if v == "function":
                variants.add(Variant.function)
            elif v == "method":
                variants.add(Variant.method)
            else:
                raise AssertionError(f"illegal variant {v}")

        manual_kernel_registration = e.pop("manual_kernel_registration", False)
        assert isinstance(
            manual_kernel_registration, bool
        ), f"not a bool: {manual_kernel_registration}"

        manual_cpp_binding = e.pop("manual_cpp_binding", False)
        assert isinstance(manual_cpp_binding, bool), f"not a bool: {manual_cpp_binding}"

        device_guard = e.pop("device_guard", True)
        assert isinstance(device_guard, bool), f"not a bool: {device_guard}"

        device_check_s = e.pop("device_check", None)
        assert device_check_s is None or isinstance(
            device_check_s, str
        ), f"not a str: {device_check_s}"
        assert (
            device_check_s is None or device_check_s in DeviceCheckType.__members__
        ), f"illegal device_check: {device_check_s}"
        device_check: DeviceCheckType
        if device_check_s is None:
            device_check = DeviceCheckType.ExactSame
        else:
            device_check = DeviceCheckType[device_check_s]

        structured = e.pop("structured", False)
        assert isinstance(structured, bool), f"not a bool: {structured}"

        structured_delegate_s = e.pop("structured_delegate", None)
        assert structured_delegate_s is None or isinstance(
            structured_delegate_s, str
        ), f"not a str: {structured_delegate_s}"
        assert structured_delegate_s is None or "::" not in structured_delegate_s, (
            "namespace is not supported in structured delegate,"
            " using the same namespace as the native function"
        )
        structured_delegate: OperatorName | None = None
        if structured_delegate_s is not None:
            structured_delegate = OperatorName.parse(structured_delegate_s)

        structured_inherits = e.pop("structured_inherits", None)
        assert structured_inherits is None or isinstance(
            structured_inherits, str
        ), f"not a str: {structured_inherits}"
        assert structured_inherits is None or "::" not in structured_inherits, (
            "namespace is not supported in structured inherits,"
            " using the same namespace as the native function"
        )

        python_module = e.pop("python_module", None)
        assert python_module is None or isinstance(
            python_module, str
        ), f"not a str: {python_module}"
        assert (
            python_module is None or Variant.method not in variants
        ), "functions in modules cannot be methods"

        category_override = e.pop("category_override", None)
        assert category_override is None or isinstance(
            category_override, str
        ), f"not a str: {category_override}"

        precomputed_dict = e.pop("precomputed", None)
        assert precomputed_dict is None or structured is True
        precomputed = Precompute.parse(precomputed_dict) if precomputed_dict else None

        tags_inp = e.pop("tags", [])
        if isinstance(tags_inp, str):
            tags_inp = [tags_inp]
        assert isinstance(tags_inp, list)

        # All aten ops generated by torchgen receive the pt2_compliant tag.
        if namespace == "aten" and "pt2_compliant_tag" in valid_tags:
            tags_inp.append("pt2_compliant_tag")

        tags: set[str] = set()
        for t in tags_inp:
            assert len(valid_tags) > 0
            # TODO: verify that the tag is valid and has an entry in tags.yaml
            if t in valid_tags:
                tags.add(t)
            else:
                raise AssertionError(f"illegal tag {t}")

        from torchgen.api import cpp

        raw_dispatch = e.pop("dispatch", None)
        assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
        dispatch: dict[DispatchKey, BackendMetadata] = {}
        num_dispatch_keys: int = 0
        if raw_dispatch is not None:
            assert not manual_kernel_registration, (
                "cannot specify both manual_kernel_registration and dispatch; with "
                "manual registration, dispatch has no effect!"
            )
            redundant_composite_implicit_autograd = False
            for ks, v in raw_dispatch.items():
                if ks == "__line__":
                    continue  # not worth tracking line numbers for dispatch entries
                assert isinstance(
                    ks, str
                ), f"illegal dispatch key '{ks}' in {raw_dispatch}"
                assert isinstance(
                    v, str
                ), f"illegal dispatch value '{v}' in {raw_dispatch}"
                for k in ks.split(","):
                    dispatch_key = DispatchKey.parse(k.strip())
                    num_dispatch_keys += 1

                    if ignore_keys and dispatch_key in ignore_keys:
                        continue
                    assert dispatch_key in dispatch_keys, (
                        f"Dispatch key {dispatch_key} of kernel {v} "
                        "is not a supported dispatch key."
                    )
                    # We only allow at most 3 levels of namespace for kernels.
                    # We will append "native" to a custom kernel namespace.
                    namespace_helper = NamespaceHelper.from_namespaced_entity(
                        v, max_level=3
                    )
                    kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
                    # Why is 'structured' included?  External backends (e.g.
                    # XLA) opt into which ops are structured independently
                    # of which in-tree ops are structured
                    dispatch[dispatch_key] = BackendMetadata(
                        kernel=namespace_helper.entity_name,
                        structured=structured
                        and is_structured_dispatch_key(dispatch_key),
                        cpp_namespace=(kernel_namespace + "::native"),
                    )
                    if (
                        dispatch_key is DispatchKey.CompositeImplicitAutograd
                        and v == cpp.name(func)
                    ):
                        redundant_composite_implicit_autograd = True

            # We count the number of dispatch keys which have not been ignored to prevent a dispatch table
            # in which all backend keys are ignored but the necessarily kept CompositeImplicitAutograd entry
            # from being treated as redundant.
            assert not (
                num_dispatch_keys == 1 and redundant_composite_implicit_autograd
            ), (
                "unnecessary dispatch table for this function; just delete the dispatch "
                "key entirely"
            )
            # if a function is a structured delegate, deleting the dispatch
            # table is NOT semantics preserving
            assert (
                structured_delegate
                or dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
                or dispatch[DispatchKey.CompositeImplicitAutograd].supports_symint()
                or num_dispatch_keys != 1
            ), (
                f"unexpected name for singleton CompositeImplicitAutograd dispatch entry: expected {cpp.name(func)} "
                f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}.  Rename your implementation to the expected "
                "name, then delete the dispatch table"
            )
        elif not structured and structured_delegate is None:
            name = str(func.name.name)
            assert not (
                name.startswith("new_")
                or name.endswith("_like")
                # TODO: maybe it's better to test the return
                or (
                    func.arguments.tensor_options
                    and not func.arguments.has_tensor_arg()
                )
            ), (
                f"expected {name} to have a CompositeExplicitAutograd "
                "dispatch entry, but there was no dispatch table.  Factory functions "
                "should not have implicit dispatch as they should not be decomposed "
                "for __torch_dispatch__"
            )
            dispatch[DispatchKey.CompositeImplicitAutograd] = BackendMetadata(
                cpp.name(func), structured=False, cpp_namespace=DEFAULT_KERNEL_NAMESPACE
            )

        composites_in_dispatch = [
            d
            for d in dispatch
            if d == DispatchKey.CompositeExplicitAutograd
            or d == DispatchKey.CompositeExplicitAutogradNonFunctional
            or d == DispatchKey.CompositeImplicitAutograd
            or d == DispatchKey.CompositeImplicitAutogradNestedTensor
        ]

        assert len(composites_in_dispatch) <= 1 or (
            len(composites_in_dispatch) == 2
            and (
                DispatchKey.CompositeExplicitAutogradNonFunctional
                not in composites_in_dispatch
            )
            and (
                DispatchKey.CompositeImplicitAutogradNestedTensor
                in composites_in_dispatch
            )
        ), (
            "cannot specify more than one of CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional, "
            "or CompositeImplicitAutograd on a single kernel; each "
            "strictly subsumes the other.  If you wanted to provide an explicit autograd "
            "implementation, specify CompositeExplicitAutograd; otherwise specify CompositeImplicitAutograd only"
        )

        autogen_str = e.pop("autogen", "")
        assert isinstance(autogen_str, str)
        autogen = (
            []
            if autogen_str == ""
            else [OperatorName.parse(x) for x in autogen_str.split(", ")]
        )

        raw_ufunc_inner_loop = e.pop("ufunc_inner_loop", {})
        ufunc_inner_loop = {}
        if isinstance(raw_ufunc_inner_loop, str):
            ufunc_inner_loop[UfuncKey.Generic] = UfuncInnerLoop.parse(
                raw_ufunc_inner_loop, UfuncKey.Generic
            )
        elif isinstance(raw_ufunc_inner_loop, dict):
            for k, vo in raw_ufunc_inner_loop.items():
                if k == "__line__":
                    continue
                assert isinstance(k, str), f"ufunc_inner_loop key is not a str: {k}"
                assert isinstance(vo, str), f"ufunc_inner_loop value is not a str: {vo}"
                ufunc_key = UfuncKey.parse(k)
                ufunc_inner_loop[ufunc_key] = UfuncInnerLoop.parse(vo, ufunc_key)
        else:
            raise AssertionError(
                f"ufunc_inner_loop not str or dict: {raw_ufunc_inner_loop}"
            )
        # Program the BackendIndex for the implicit dispatch entry from ufunc
        if ufunc_inner_loop:
            assert structured, "ufunc must be structured"

            # Delay import ufunc here to avoid circular import issue
            # See: https://github.com/pytorch/pytorch/issues/81294
            import torchgen.api.ufunc as ufunc

            for dispatch_key in UFUNC_DISPATCH_KEYS:
                assert (
                    dispatch_key not in dispatch
                ), f"ufunc should not have explicit dispatch entry for {dispatch_key}"
                dispatch[dispatch_key] = BackendMetadata(
                    kernel=ufunc.schema_kernel_name(func, dispatch_key),
                    structured=True,
                    cpp_namespace=DEFAULT_KERNEL_NAMESPACE,
                )

        if structured_delegate:
            # Structured functions MUST have a dispatch table
            is_abstract = True
        else:
            is_abstract = (
                dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
                and dispatch.keys()
                != {DispatchKey.CompositeImplicitAutogradNestedTensor}
                and dispatch.keys()
                != {
                    DispatchKey.CompositeImplicitAutograd,
                    DispatchKey.CompositeImplicitAutogradNestedTensor,
                }
            )

        has_composite_implicit_autograd_kernel = (
            DispatchKey.CompositeImplicitAutograd in dispatch
        )
        has_composite_implicit_autograd_nested_tensor_kernel = (
            DispatchKey.CompositeImplicitAutogradNestedTensor in dispatch
        )
        has_composite_explicit_autograd_kernel = (
            DispatchKey.CompositeExplicitAutograd in dispatch
        )
        has_composite_explicit_autograd_non_functional_kernel = (
            DispatchKey.CompositeExplicitAutogradNonFunctional in dispatch
        )

        # We aren't going to store dispatch metadata inline in NativeFunctions;
        # instead it is separately indexed by backend (so other backends can
        # add more dispatch entries after the fact).  Reindex the individual
        # metadata by OperatorName!
        backend_metadata = {k: {func.name: v} for k, v in dispatch.items()}

        # don't care if it exists or not; make it easier to use this function
        # with other yaml parsers that aren't setting __line__ in the dict
        e.pop("__line__", None)
        assert not e, f"leftover entries: {e}"

        # Asserts that we can't do in post_init, because they rely on backend-specific info
        if structured_delegate is not None:
            for key in STRUCTURED_DISPATCH_KEYS:
                assert key not in dispatch, (
                    f"if structured_delegate, then must not have {key} in dispatch dictionary "
                    "(it is delegated!)"
                )

        return (
            NativeFunction(
                func=func,
                use_const_ref_for_mutable_tensors=use_const_ref_for_mutable_tensors,
                variants=variants,
                structured=structured,
                structured_delegate=structured_delegate,
                structured_inherits=structured_inherits,
                precomputed=precomputed,
                autogen=autogen,
                ufunc_inner_loop=ufunc_inner_loop,
                manual_kernel_registration=manual_kernel_registration,
                manual_cpp_binding=manual_cpp_binding,
                python_module=python_module,
                category_override=category_override,
                device_guard=device_guard,
                device_check=device_check,
                loc=loc,
                cpp_no_default_args=cpp_no_default_args,
                is_abstract=is_abstract,
                has_composite_implicit_autograd_kernel=has_composite_implicit_autograd_kernel,
                has_composite_implicit_autograd_nested_tensor_kernel=has_composite_implicit_autograd_nested_tensor_kernel,
                has_composite_explicit_autograd_kernel=has_composite_explicit_autograd_kernel,
                has_composite_explicit_autograd_non_functional_kernel=has_composite_explicit_autograd_non_functional_kernel,
                tags=tags,
                namespace=namespace,
            ),
            backend_metadata,
        )

    def validate_unstructured(self) -> None:
        # TODO: probably better to accumulate these errors and report them all
        # at once
        assert not self.structured, (
            "This function is structured, but there was "
            "no valid functional variant of it."
        )
        assert self.structured_delegate, (
            "This function delegates to another structured out function, "
            "but no valid function was found (the delegate may not exist, or it has the wrong type)"
        )

    # __post_init__ functions in dataclasses can be used to do extra
    # validation after construction.
    #
    # Notice that we don't do any type validation here.  In fact, we
    # rely exclusively on mypy to check if you've done types correctly!
    # Validation is for nontrivial invariants that cannot be (conveniently)
    # encoded in the type system.
    def __post_init__(self) -> None:
        if self.func.arguments.out:
            assert self.variants == {Variant.function}, (
                "Native functions with out arguments MUST "
                "be declared with only function variant; e.g., variants: function; "
                "otherwise you will tickle a Python argument binding bug "
                "(which usually manifests itself as the result variable being undefined.)"
            )
        if self.structured:
            assert self.func.kind() == SchemaKind.out, (
                "Put structured field on the out= "
                "variant of a function; did you mean structured_delegate?"
            )
            assert (
                self.device_guard
            ), "device_guard: False is not respected by structured kernels"
        if self.structured_delegate:
            assert self.func.kind() != SchemaKind.out, (
                "structured_delegate field not allowed "
                "on out= functions; did you mean structured?"
            )
            assert (
                self.device_guard
            ), "device_guard: False is not respected by structured kernels"
        # Technically, with the asserts above, this assert is impossible to
        # trigger
        assert not (
            self.structured and self.structured_delegate
        ), "Cannot have both structured and structured_delegate on function"
        defaulted_arguments = {
            a.name for a in self.func.schema_order_arguments() if a.default is not None
        }
        invalid_args = set.difference(self.cpp_no_default_args, defaulted_arguments)
        assert len(invalid_args) == 0, f"Invalid cpp_no_default_args: {invalid_args}"
        if self.structured_inherits is not None:
            assert (
                self.structured
            ), "structured_inherits must also imply structured: True"
        if str(self.func.name).startswith("_foreach"):
            assert self.device_check == DeviceCheckType.NoCheck, (
                "foreach kernels fall back to slow path when tensors are on different devices, "
                "device_check not allowed to be enabled"
            )

        # NB: if your function accidentally has rand/dropout/... in its name
        # but is not actually random, feel free to amend this to special case
        if (
            "rand" in str(self.func.name)
            or (
                (
                    "dropout" in str(self.func.name)
                    or any(
                        "dropout" in arg.name for arg in self.func.arguments.flat_all
                    )
                )
                # Backwards of dropout is typically deterministic
                and "backward" not in str(self.func.name)
                and str(self.func.name.name) not in ["_cudnn_init_dropout_state"]
            )
            or self.func.arguments.has_generator_arg()
        ):
            assert "nondeterministic_seeded" in self.tags, str(self.func.name)

    @property
    def has_composite_kernel(self) -> bool:
        return (
            self.has_composite_implicit_autograd_kernel
            or self.has_composite_explicit_autograd_kernel
            or self.has_composite_explicit_autograd_non_functional_kernel
        ) or (
            self.has_composite_implicit_autograd_kernel
            and self.has_composite_implicit_autograd_nested_tensor_kernel
        )

    @property
    def is_view_op(self) -> bool:
        rets = self.func.returns
        is_non_mutating_view = len(rets) > 0 and any(
            r.annotation is not None and not r.annotation.is_write for r in rets
        )
        # See Note [resize_ in Functionalization] for more details
        is_inplace_view = (
            "inplace_view" in self.tags
            and str(self.func.name) != "resize_"
            and str(self.func.name) != "resize_as_"
        )
        is_wildcard_view = any(
            inp.annotation is not None and "*" in inp.annotation.alias_set_after
            for inp in self.func.schema_order_arguments()
        )
        return is_non_mutating_view or is_inplace_view or is_wildcard_view

    @property
    def view_schema_kind(self) -> ViewSchemaKind:
        if self.is_view_op and self.func.name.name.inplace:
            assert "inplace_view" in self.tags
            return ViewSchemaKind.aliasing_inplace
        if self.is_view_op:
            return ViewSchemaKind.aliasing
        else:
            return ViewSchemaKind.non_aliasing

    @property
    def root_name(self) -> str:
        return self.func.name.name.base

    @property
    def part_of_structured_group(self) -> bool:
        return self.structured or self.structured_delegate is not None


class SchemaKind(Enum):
    functional = auto()
    inplace = auto()
    out = auto()
    mutable = auto()
    scratch = auto()
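# Illustrative only: for a family of overloads such as
#
#     add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
#     add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
#     add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
#
# FunctionSchema.kind() (defined further below) classifies them as functional,
# inplace and out respectively; mutable covers ops whose positional (non-self)
# arguments are mutated, and scratch covers out= variants whose outputs carry
# the _scratch_ prefix.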
# A structured kernel is guaranteed to have a functional and out variant, and
# optionally an inplace variant.
#
# NB: we create NativeFunctionsGroup *even if* the function is not
# actually annotated structured.  Test the structured boolean to see if it
# actually is structured or not.
@dataclass(frozen=True)
class NativeFunctionsGroup:
    functional: NativeFunction
    inplace: NativeFunction | None
    mutable: NativeFunction | None
    out: NativeFunction

    @property
    def structured(self) -> bool:
        # Whether or not the operator has a meta() function.  This information is backend-agnostic.
        return self.out.structured

    def __post_init__(self) -> None:
        test_sig: FunctionSchema = self.functional.func.signature()
        for f in self.functions():
            if test_sig != f.func.signature():
                raise AssertionError(
                    "NativeFunctionsGroup constructed from two NativeFunctions "
                    f"that don't have matching signatures: {test_sig} != {f.func.signature()}"
                )

            if self.structured != f.part_of_structured_group:
                raise AssertionError(
                    "NativeFunctionsGroup constructed from structured and unstructured "
                    f"functions: {self.out.func.name} and {f.func.name}"
                )
        assert self.functional.func.kind() == SchemaKind.functional
        assert self.out.func.kind() == SchemaKind.out
        assert self.functional.namespace == self.out.namespace
        if self.inplace is not None:
            assert self.inplace.func.kind() == SchemaKind.inplace
            assert self.inplace.namespace == self.functional.namespace

        if self.mutable is not None:
            assert self.mutable.func.kind() == SchemaKind.mutable
            assert self.mutable.namespace == self.functional.namespace
            # See Note [Overload Ambiguity With Functional Variants]
            assert self.functional.func.name.name.functional_overload

        if self.structured:
            # For now, structured composite kernels are not supported (need some
            # design work to figure out how to make the composite case work)
            assert (
                not self.out.has_composite_implicit_autograd_kernel
                and not self.out.has_composite_implicit_autograd_nested_tensor_kernel
            )

            assert self.functional.structured_delegate == self.out.func.name, (
                f"{self.functional.func.name} delegates to {self.functional.structured_delegate} "
                f"but its actual delegate is {self.out.func.name}"
            )
            if self.inplace is not None:
                assert self.inplace.structured_delegate == self.out.func.name

        generated_fns = sorted(
            [str(f.func.name) for f in self.functions() if "generated" in f.tags]
        )
        generated_fns_str = ", ".join(str(x) for x in generated_fns)
        expected_generated_fns: set[str] = set()
        for f in self.functions():
            expected_generated_fns.update(str(op) for op in f.autogen)
        expected_generated_fns_str = ", ".join(
            str(x) for x in sorted(expected_generated_fns)
        )
        if len(expected_generated_fns) == 0 and len(generated_fns) > 0:
            raise RuntimeError(
                f"The codegen expects to be able to generate '{generated_fns_str}'."
                " In order to generate them however, we expect them to be called out explicitly in the yaml."
                f" Please add an 'autogen: {generated_fns_str}' line to the entry for {str(f.func.name)}"
            )
        if expected_generated_fns_str != generated_fns_str:
            raise RuntimeError(
                f"The codegen expects to be able to generate '{generated_fns_str}'."
                f" To do so, it expects a line: 'autogen: {generated_fns_str}'."
                f" Instead, it found 'autogen: {expected_generated_fns_str}'"
            )
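    # Illustrative only (schematic, not copied from the real yaml): an entry
    # that only defines an inplace variant, e.g.
    #
    #     - func: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
    #       autogen: fill.Scalar, fill.Scalar_out
    #
    # declares up front which counterparts the codegen is expected to generate;
    # __post_init__ above cross-checks that declaration against the
    # 'generated'-tagged functions actually present in the group.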
    def signature(self) -> FunctionSchema:
        return self.out.func.signature()

    def functions(self) -> Iterator[NativeFunction]:
        yield self.functional
        yield self.out
        if self.inplace is not None:
            yield self.inplace
        if self.mutable is not None:
            yield self.mutable

    @property
    def root_name(self) -> str:
        return self.functional.root_name

    @staticmethod
    def from_dict(d: dict[SchemaKind, NativeFunction]) -> NativeFunctionsGroup | None:
        assert d
        if len(d) == 1:
            return None
        d = dict(d)  # non-destructive updates please
        functional = d.pop(SchemaKind.functional, None)
        inplace = d.pop(SchemaKind.inplace, None)
        mutable = d.pop(SchemaKind.mutable, None)
        out = d.pop(SchemaKind.out, None)
        assert not d
        assert functional is not None
        # There are a few operators which only have functional/inplace variants;
        # these don't count as structured for our purposes here
        if out is None:
            return None
        # assuming all variants have the same namespace
        return NativeFunctionsGroup(
            functional=functional,
            inplace=inplace,
            mutable=mutable,
            out=out,
        )


@dataclass(frozen=True)
class BackendMetadata:
    # The name of the backend kernel, for a given operator
    # for in-tree backends.  These names come directly from the 'dispatch' field
    # in native_functions.yaml.  The dispatch entry is optional; in that
    # case, that is equivalent to having written:
    #
    #   dispatch:
    #       CompositeImplicitAutograd: $operator_name
    kernel: str
    # Whether or not the operator has a structured kernel implemented, for this particular backend.
    # For in-tree backends, they all have the same value for structured - this is listed
    # in native_functions.yaml.
    # However, external backends like XLA can independently toggle which ops are structured.
    structured: bool

    # The namespace for kernels, default value: DEFAULT_KERNEL_NAMESPACE
    cpp_namespace: str

    def supports_symint(self) -> bool:
        return "_symint" in self.kernel


@dataclass(frozen=True)
class UfuncInnerLoop:
    name: str
    supported_dtypes: OrderedSet[ScalarType]
    # key is stored here because it affects the semantics of name,
    # so it's helpful to have them together for further processing
    ufunc_key: UfuncKey

    @staticmethod
    def parse(value: str, ufunc_key: UfuncKey) -> UfuncInnerLoop:
        name, supported_dtypes_str = value.split(" ", 1)
        assert supported_dtypes_str[0] == "("
        assert supported_dtypes_str[-1] == ")"
        supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
        for k in supported_dtypes_str[1:-1].split(", "):
            supported_dtypes |= ScalarType.parse_set(k)
        return UfuncInnerLoop(
            name=name, supported_dtypes=supported_dtypes, ufunc_key=ufunc_key
        )
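# Illustrative only: a ufunc_inner_loop value is a kernel name followed by the
# parenthesized dtypes (or dtype classes) it supports, so parsing a string like
#
#     "add (AllAndComplex, BFloat16, Half)"
#
# yields UfuncInnerLoop(name="add", supported_dtypes=<the expanded dtype set>,
# ufunc_key=<the key the value was listed under>).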
# BackendIndex represents a backend.
# The BackendIndex encodes per-operator information that is potentially different
# for each backend.  The most obvious example is the name of the kernel
# (the 'dispatch' entry in native_functions.yaml).
# However, there can be other examples of different backends having different information.
# External backends can choose to opt their kernels to be structured independently from in-tree backends,
# which means that this information isn't inherently tied to a NativeFunction - it's different per backend.
@dataclass(frozen=True)
class BackendIndex:
    dispatch_key: DispatchKey
    # Mainly important for structured kernels, this determines which variant in the operator group is used to implement the others.
    # All in-tree ops use out kernels, while XLA uses functional kernels.
    use_out_as_primary: bool
    # Whether the backend requires a device guard, and device checks.
    # For in-tree backends, this is currently just CUDA/HIP
    # For out-of-tree backends, this is currently just Intel XPU
    device_guard: bool
    # Whether the backend is in-tree (CPU/CUDA) or out-of-tree (XLA)
    external: bool
    # Other backend-specific information that is on a per-operator basis
    index: dict[OperatorName, BackendMetadata]

    @staticmethod
    def grow_index(
        parent_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
        child_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
    ) -> None:
        for k, v in child_index.items():
            for op_name, metadata in v.items():
                assert (
                    op_name not in parent_index[k]
                ), f"duplicate operator {op_name} for dispatch key {k}"
                parent_index[k][op_name] = metadata

    def primary(self, g: NativeFunctionsGroup) -> NativeFunction:
        if self.use_out_as_primary:
            return g.out
        else:
            return g.functional

    def has_kernel(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
        m = self.get_kernel(g)
        return m is not None

    def get_kernel(
        self, g: NativeFunction | NativeFunctionsGroup
    ) -> BackendMetadata | None:
        if isinstance(g, NativeFunction):
            f = g
        elif isinstance(g, NativeFunctionsGroup):
            f = self.primary(g)
        else:
            assert_never(g)
        if f.func.name not in self.index:
            return None
        return self.index[f.func.name]

    def native_function_class_name(self) -> str | None:
        if self.external:
            return f"{str(self.dispatch_key)}NativeFunctions"
        else:
            # TODO: This discrepancy isn't required; we could also generate
            # a class for in-tree kernels.  It'll just require carefully
            # updating every kernel definition + callsite of every in-tree aten kernel.
            return None


# The function schema is undoubtedly the most important data structure
# in all of the codegen, as it defines the type signature for operators,
# and most of the code generation we do is type directed (e.g., look at
# the types, decide what to do.  Think about how we code generate
# C++ function stubs!)
#
# We will also see in this class the general structure for how we model
# data in this code generation.  A few notable properties to point out
# ahead of time:
#
#   - These dataclasses are a *lossless* representation of the strings
#     they are parsed from.  In fact, we assert that given the
#     information stored in the dataclass, we can exactly reconstruct
#     the string we parsed from (and assert this inside the parse
#     definition).  There are a few reasons for this:
#
#       - If you find that it is difficult to reconstruct the string
#         given a dataclass, that is a clue that your data
#         representation is wrong.
#
#       - It helps ensure that all relevant information is present
#         in the dataclass, so that downstream users aren't tempted
#         to reparse the original string to get some information
#         that was omitted.
#
#       - It forces you to represent the data in-memory in the same way
#         it is recorded textually, which makes the dataclasses easier
#         to understand for someone who is familiar with the
#         textual format.  (As a tradeoff, it means you have to model
#         the syntax, even when it is inconvenient.  But maybe that means
#         the syntax is bad!)  If you don't understand the internal
#         representation, go look at the printing code to see how
#         it maps onto the surface syntax!
#
#       - It makes it easy to test the parsing code, as parsing code
#         that is inconsistent with the stringifying code will fail
#         early and loudly.  (As a tradeoff, it makes the parsing code
#         a bit brittle; in particular, with trivial whitespace changes
#         you are likely to trigger an assert error.)
#
#     In general, try to make the __str__ code as simple as possible
#     (even at the cost of more complex parsing logic.)  Additionally,
#     try to minimize redundancy in data representation.  (Precomputed
#     fields are OK though: they are defined as a simple function on
#     the canonical representation in question.)
#
#   - These dataclasses are all frozen; once constructed their
#     values never change.  This makes it easy to tell where any
#     given data came from: just look to the constructor.  As a
#     tradeoff, you can't easily "decorate" a schema with extra
#     information from a post-facto analysis.  We impose this
#     restriction to make these structures more understandable.
#
@dataclass(frozen=True)
class FunctionSchema:
    # The name of the operator this function schema describes.
    name: OperatorName

    arguments: Arguments

    # TODO: Need to handle collisions with argument names at some point
    returns: tuple[Return, ...]
1389 1390 @property 1391 def is_mutable(self) -> bool: 1392 def is_write(arg: Argument) -> bool: 1393 if arg.annotation is None: 1394 return False 1395 return arg.annotation.is_write 1396 1397 # Corresponds to torch._C._FunctionSchema.is_mutable 1398 # See aten/src/ATen/core/function_schema.h (keep these in sync) 1399 return any(is_write(a) for a in self.arguments.flat_all) 1400 1401 def schema_order_arguments(self) -> Iterator[Argument]: 1402 return itertools.chain( 1403 self.arguments.flat_positional, 1404 self.arguments.flat_kwarg_only, 1405 self.arguments.out, 1406 ) 1407 1408 decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)") 1409 1410 @staticmethod 1411 def parse(func: str) -> FunctionSchema: 1412 # We should probably get a proper parser here 1413 decls = FunctionSchema.decl_re.findall(func) 1414 assert len(decls) == 1, f"Invalid function schema: {func}" 1415 ops, args, return_decl = decls[0] 1416 name = OperatorName.parse(ops) 1417 arguments = Arguments.parse(args) 1418 returns = parse_returns(return_decl) 1419 r = FunctionSchema(name=name, arguments=arguments, returns=returns) 1420 assert str(r) == func, f"{str(r)} != {func}" 1421 return r 1422 1423 def returns_are_aliased(self) -> bool: 1424 # We assert earlier that schemas can't have a mix of aliased and non-aliased returns 1425 return any( 1426 r 1427 for r in self.returns 1428 if r.annotation is not None and r.annotation.is_write 1429 ) 1430 1431 def __post_init__(self) -> None: 1432 for arg, ret in zip(self.arguments.out, self.returns): 1433 assert arg.annotation == ret.annotation, ( 1434 "Out arguments must have matching return Tensor; furthermore, " 1435 "the ith-argument needs to correspond to the ith return" 1436 ) 1437 # We also enforce that if you have any mutable, positional args, then they are not returned. 1438 # This makes it easier to group these functions properly with their functional/out= counterparts. 1439 for a in self.arguments.post_self_positional_mutable: 1440 assert not any( 1441 a.annotation == r.annotation for r in self.returns 1442 ), f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}" 1443 # Invariant: we expect out arguments to appear as keyword arguments in the schema. 1444 # This means that all mutable returns should be aliased to a keyword argument 1445 # (except for "self", which we explicitly don't treat as an out argument because of its use in methods) 1446 # See Note [is_out_fn] 1447 out_and_self = list(self.arguments.out) + [ 1448 arg for arg in self.arguments.flat_positional if arg.name == "self" 1449 ] 1450 mutable_returns = [ 1451 ret 1452 for ret in self.returns 1453 if ret.annotation is not None and ret.annotation.is_write 1454 ] 1455 immutable_returns = [ 1456 ret 1457 for ret in self.returns 1458 if ret.annotation is None or not ret.annotation.is_write 1459 ] 1460 # Some assertions: We don't want any functions with a return type of "-> (Tensor(a!), Tensor)", 1461 # because: 1462 # (1) It's more annoying to handle properly 1463 # (2) It's unnecessary - you can't method-chain on the first (mutated) output because it's part of a tuple. 1464 # Instead, we expect the (a!) argument to not be returned. 1465 assert ( 1466 len(mutable_returns) == 0 or len(immutable_returns) == 0 1467 ), f"NativeFunctions must have either only mutable returns, or only immutable returns. 
Found: {str(self)}" 1468 for ret in mutable_returns: 1469 assert any(ret.annotation == arg.annotation for arg in out_and_self), ( 1470 'All mutable returns must be aliased either to a keyword argument, or to "self". ' 1471 "Did you forget to mark an out argument as keyword-only?" 1472 ) 1473 if self.arguments.out: 1474 # out= ops that return their mutable inputs are only really useful for method chaining. 1475 # And method chaining is only really useful if the thing you're returning is a plain Tensor. 1476 # So ideally, we'd enforce that out= ops with a single plain mutable tensor should return the tensor, 1477 # and all other types of out= op schemas should return void. 1478 # There are a bunch of existing out= ops that return tuples of tensors though, so we're stuck with allowing that. 1479 if any(a.type != BaseType(BaseTy.Tensor) for a in self.arguments.out): 1480 assert ( 1481 len(self.returns) == 0 1482 ), "out= ops that accept tensor lists as out arguments " 1483 "are expected to have no return type (since you can't do method chaining on them)" 1484 else: 1485 # mutable keyword arguments whose name has _scratch_ prefix are 1486 # scratch tensors for memory planning and should not be returned 1487 assert len( 1488 [ 1489 arg 1490 for arg in self.arguments.out 1491 if not arg.name.startswith("_scratch_") 1492 ] 1493 ) == len( 1494 self.returns 1495 ), "Must return as many arguments as there are out arguments, or no return at all" 1496 1497 if self.name.name.inplace: 1498 self_a = self.arguments.self_arg 1499 assert ( 1500 self_a 1501 and self_a.argument.annotation 1502 and self_a.argument.annotation.is_write 1503 ) 1504 if self_a.argument.type == BaseType(BaseTy.Tensor): 1505 # All inplace ops with an ordinary `Tensor self` argument should return self, 1506 # to allow for method chaining. 1507 assert ( 1508 len(self.returns) == 1 1509 and self.returns[0].annotation == self_a.argument.annotation 1510 ) 1511 else: 1512 # You can't method chain on non-tensor self arguments though (like a List[Tensor]) 1513 # so in all other cases we expect the return type to be none. 1514 assert len(self.returns) == 0 1515 1516 if self.arguments.tensor_options is not None: 1517 assert self.kind() == SchemaKind.functional, ( 1518 "Found an operator that is not functional or out variant, but has tensor options arguments." 1519 "This is not allowed- tensor options arguments are only allowed for factory functions." 1520 f"schema: {str(self)}" 1521 ) 1522 if self.is_functional_fn(): 1523 assert self.kind() == SchemaKind.functional, ( 1524 "Found an operator that is not functional, but its overload contains the string 'functional'." 1525 "This is a special keyword in the codegen, please use a different overload name." 1526 f"schema: {str(self)}" 1527 ) 1528 1529 def is_functional_fn(self) -> bool: 1530 return "functional" in self.name.overload_name 1531 1532 def is_out_fn(self) -> bool: 1533 # Note [is_out_fn] 1534 # 1535 # out functions are the variants which take an explicit out= argument 1536 # to populate into. We need to know if a schema corresponds to an 1537 # out function for several reasons: 1538 # 1539 # - They codegen differently in C++ API 1540 # - codegen to at::add_out rather than at::add 1541 # - out argument is moved to front of C++ argument list 1542 # 1543 # out functions are DEFINED to be any function with a keyword-only 1544 # argument that is mutable. 
    def is_out_fn(self) -> bool:
        # Note [is_out_fn]
        #
        # out functions are the variants which take an explicit out= argument
        # to populate into. We need to know if a schema corresponds to an
        # out function for several reasons:
        #
        #   - They codegen differently in the C++ API
        #       - codegen to at::add_out rather than at::add
        #       - out argument is moved to front of C++ argument list
        #
        # out functions are DEFINED to be any function with a keyword-only
        # argument that is mutable. In principle, this could lead to a
        # false positive if you define a function that mutates a
        # kwarg-only argument, but this isn't the "true" output of this
        # function. A more robust definition that would work in this
        # case would also look at:
        #
        #   - The output types. Out functions take in the arguments
        #     they mutate and then return them again; this is sort
        #     of "definitionally" what makes something an out function.
        #     Historically, we DO check this for consistency.
        #   - Correspondence with the pure variant. An out function
        #     should have a signature equivalent to its pure variant,
        #     but just with extra kwargs for the output elements. This
        #     is difficult to actually check for, and historically
        #     we only do this check in tools/
        return bool(self.arguments.out)

    def kind(self) -> SchemaKind:
        """
        What kind of schema is this? A functional schema is one
        that returns a newly allocated output; an inplace schema
        modifies the self argument inplace; an out schema writes
        the result into an explicitly provided out argument.
        """
        is_out = bool(self.arguments.out)
        is_scratch = bool(
            [arg for arg in self.arguments.out if arg.name.startswith("_scratch_")]
        )
        is_inplace = self.name.name.inplace
        is_mutable = any(
            a.annotation is not None and a.annotation.is_write
            for a in self.arguments.post_self_positional
        )
        assert not (is_out and is_inplace)
        # out= and inplace schemas can also have post_self_positional mutable args,
        # but we give precedence to out= and inplace when deciding the schema kind.
        # Tradeoff: we probably don't want to have to teach codegen that looks at inplace ops
        # to also worry about mutable post_self_positional arguments,
        # but it seems like a much bigger lift to classify them as having a new schema kind.
        # The number of ops that fit in this strange category is small enough that
        # we can probably manually write code for them instead of forcing the codegen to handle them.
        if is_inplace:
            return SchemaKind.inplace
        elif is_scratch:
            assert (
                is_out
            ), "invariant: all scratch operators are expected to be out= operators too"
            return SchemaKind.scratch
        elif is_out:
            assert (
                not is_scratch
            ), "We should not categorize a scratch op as an out variant. Check whether the order of the if statements is as expected!"
            return SchemaKind.out
        elif is_mutable:
            return SchemaKind.mutable
        else:
            return SchemaKind.functional
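    # Illustrative sketch of how kind() classifies schemas (hypothetical
    # declarations, shown here purely as documentation):
    #
    #     "abs(Tensor self) -> Tensor"                              -> SchemaKind.functional
    #     "abs_(Tensor(a!) self) -> Tensor(a!)"                     -> SchemaKind.inplace
    #     "abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"   -> SchemaKind.out
    #     a schema with a mutable post-self positional argument     -> SchemaKind.mutable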
    # For every return:
    #   - If the return aliases an input, we return the input name
    #   - Otherwise, we return None.
    # If return names were enforced to be consistent with aliasing information, then we wouldn't need this.
    def aliased_return_names(self) -> list[str | None]:
        outs: list[str | None] = []
        for r in self.returns:
            aliased_args = [
                a
                for a in self.arguments.flat_all
                if a.annotation is not None and a.annotation == r.annotation
            ]
            if len(aliased_args) == 0:
                outs.append(None)
            elif len(aliased_args) == 1:
                outs.append(aliased_args[0].name)
            else:
                aliased_names = ", ".join(a.name for a in aliased_args)
                raise AssertionError(
                    f"Found a return ({r.name}) that aliases multiple inputs ({aliased_names})"
                )
        return outs

    def signature(
        self,
        *,
        strip_default: bool = False,
        strip_view_copy_name: bool = False,
        keep_return_names: bool = False,
    ) -> FunctionSchema:
        """
        Certain schemas are 'related', in that they are simply
        inplace/out/functional versions of the same function. This method
        factors these schemas into the "core" functional signature which
        is equal across all versions.

        Here is the normalization that happens to the schema to convert
        it to a signature:
        - The overload name is stripped (the name is retained, since
          it expresses semantic content about what the function does)
        - Inplace is set to False
        - Out arguments are stripped
        - Mutable post_self_positional args are converted to returns
        - Mutability annotations are stripped (this is sound
          because you cannot overload on mutability annotation)
        - Return names are stripped since they are not overloadable and
          some variants have return names while others do not
        - TensorOptions are dropped
          because out= variants of factory functions don't include them
          (and we want to be able to pair up factory functions with their out variants)

        Finally, we want to be able to pair up related "view" operators and their
        corresponding "view_copy" operators. We do this by optionally
        stripping the trailing "_copy" from the base name.

        Example of a mutable op before and after:

        f.func (Mutable operator):
        _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!)
zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # noqa: B950 1661 1662 f.func (Corresponding functional operator): 1663 _fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out) # noqa: B950 1664 1665 f.func.signature() output: 1666 _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) # noqa: B950 1667 """ 1668 1669 def strip_ret_annotation(r: Return) -> Return: 1670 return Return( 1671 name=r.name if keep_return_names else None, 1672 type=r.type, 1673 annotation=None, 1674 ) 1675 1676 base_name = self.name.name.base 1677 if strip_view_copy_name: 1678 if base_name.endswith("_copy"): 1679 base_name = base_name.replace("_copy", "") 1680 elif base_name.endswith("_scatter"): 1681 base_name = base_name.replace("scatter", "inverse") 1682 1683 # find mutable inputs that are not originally returned, and convert them to returns 1684 returns_from_mutable_inputs = tuple( 1685 # When we're grouping functions we strip the return names, 1686 # but when we're generating the actual functional variants then we follow 1687 # a convention for what to name the returns 1688 Return( 1689 name=f"{a.name}_out" if keep_return_names else None, 1690 type=a.type, 1691 annotation=None, 1692 ) 1693 for a in itertools.chain( 1694 # Order is important here (otherwise e.g. inplace with mutable args 1695 # and out= with mutable args won't have the same signature) 1696 [self.arguments.self_arg.argument] 1697 if self.arguments.self_arg is not None 1698 else [], 1699 self.arguments.out, 1700 self.arguments.post_self_positional, 1701 ) 1702 if a.annotation is not None 1703 and a.annotation.is_write 1704 and not any(a.annotation == r.annotation for r in self.returns) 1705 ) 1706 original_returns = tuple(map(strip_ret_annotation, self.returns)) 1707 # Ordering is important here. We expect the "mutable input" returns to come last. 
1708 returns = original_returns + returns_from_mutable_inputs 1709 1710 args_sig = self.arguments.signature(strip_default=strip_default) 1711 # See Note [bernoulli.p schema] 1712 if str(self.name) == "bernoulli.p": 1713 args_sig = Arguments.parse(str(args_sig).replace("float p", "float p=0.5")) 1714 1715 return FunctionSchema( 1716 name=OperatorName( 1717 name=BaseOperatorName( 1718 base=base_name, 1719 inplace=False, 1720 dunder_method=self.name.name.dunder_method, 1721 ), 1722 overload_name="", # stripped 1723 ), 1724 arguments=args_sig, 1725 returns=returns, 1726 ) 1727 1728 def view_signature(self) -> FunctionSchema: 1729 return self.signature(strip_view_copy_name=True) 1730 1731 def with_name(self, name: OperatorName) -> FunctionSchema: 1732 return FunctionSchema( 1733 name=name, 1734 arguments=self.arguments, 1735 returns=self.returns, 1736 ) 1737 1738 @property 1739 def modifies_arguments(self) -> bool: 1740 return self.kind() in [SchemaKind.inplace, SchemaKind.out, SchemaKind.mutable] 1741 1742 def has_symint(self) -> bool: 1743 return self.arguments.has_symint_arg() 1744 1745 def __str__(self) -> str: 1746 all_arguments_str = str(self.arguments) 1747 if len(self.returns) == 1: 1748 returns = str(self.returns[0]) # omit parentheses 1749 else: 1750 returns = "(" + ", ".join(map(str, self.returns)) + ")" 1751 return f"{self.name}({all_arguments_str}) -> {returns}" 1752 1753 1754# Here is the rest of the data model, described more briefly. 1755 1756 1757# Simplified version for what actually shows up in built-ins. 1758# Look at alias_info.h for expanded syntax. If you need the structure, 1759# you also need to make this structure recursive so it can be lined 1760# up with the type components too. For primitives this isn't really 1761# necessary 1762@dataclass(frozen=True) 1763class Annotation: 1764 # Typically only has one element. Not actually a set so 1765 # we can conveniently assume it is canonically ordered 1766 alias_set: tuple[str, ...] 1767 is_write: bool 1768 alias_set_after: tuple[str, ...] 1769 1770 @staticmethod 1771 def parse(ann: str) -> Annotation: 1772 # TODO: implement a proper parser if this gets more ugly 1773 # Regex Explanation: 1774 # Example: "a! -> a|b" 1775 # Group #1: alias before optional '|', required. Matches the first 1776 # character 'a' in the example 1777 # Group #2: optional alias set after optional '|', matches empty string 1778 # in the example 1779 # Group #3: optional "is write" flag, matches '!' in the example. 1780 # Group #4: optional section containing arrow, matches " -> a|b" in the 1781 # example. 1782 # Group #5: optional alias after set, supports wildcard, matches "a|b" 1783 # in the example. 1784 # Group #6: optional sub-section of alias after set, matches "|b" in the 1785 # example. 1786 m = re.match(r"^([a-z])(\|[a-z])*(!?)( -> (\*|[a-z](\|[a-z])*))?$", ann) 1787 1788 assert m is not None, f"unrecognized alias annotation {ann}" 1789 before_alias = m.group(1) + (m.group(2) if m.group(2) else "") 1790 alias_set = tuple(before_alias.split("|")) 1791 is_write = m.group(3) == "!" 1792 assert not ( 1793 is_write and len(alias_set) > 1 1794 ), f"alias set larger than 1 is not mutable, got {ann} instead." 1795 after_set = tuple(m.group(5).split("|")) if m.group(5) else () 1796 assert not ( 1797 len(before_alias) > 1 and len(after_set) > 1 1798 ), f"before alias set and after alias set cannot be larger than 1 at the same time, got {ann} instead." 
1799 r = Annotation( 1800 alias_set=alias_set, is_write=is_write, alias_set_after=after_set 1801 ) 1802 assert str(r) == ann, f"{r} != {ann}" 1803 return r 1804 1805 def __str__(self) -> str: 1806 alias_set = "|".join(self.alias_set) 1807 if self.is_write: 1808 alias_set = f"{alias_set}!" 1809 alias_set_after = "|".join(self.alias_set_after) 1810 if alias_set_after: 1811 alias_set = f'{alias_set}{" -> "}{alias_set_after}' 1812 return alias_set 1813 1814 1815# The base class for the type system. This is also loosely modeled 1816# off of jit_type.h, but we've simplified the hierarchy to focus 1817# in on the aspects of the type system that matter for code generation 1818# (for example, there's no SingleElementType subclass anymore). 1819# You never actually construct a Type; usually it's going to be one 1820# of the subclasses. If Python had ADTs this would be one! 1821@dataclass(frozen=True) 1822class Type: 1823 @staticmethod 1824 def parse(t: str) -> Type: 1825 r = Type._parse(t) 1826 assert str(r) == t, f"{r} != {t}" 1827 return r 1828 1829 @staticmethod 1830 def _parse(t: str) -> Type: 1831 m = re.match(r"^(.+)\?$", t) 1832 if m is not None: 1833 return OptionalType(Type.parse(m.group(1))) 1834 m = re.match(r"^(.+)\[([0-9]+)?\]$", t) 1835 if m is not None: 1836 size = int(m.group(2)) if m.group(2) is not None else None 1837 return ListType(elem=Type.parse(m.group(1)), size=size) 1838 1839 # '__torch__.torch.classes.' is the prefix for custom class 1840 m = re.match(r"^__torch__\.torch\.classes\.([a-zA-Z0-9_.]+)$", t) 1841 if m is not None: 1842 return CustomClassType(m.group(1)) 1843 try: 1844 return BaseType(BaseTy[t]) 1845 except KeyError as e: 1846 raise RuntimeError(f"unrecognized type {t}") from e 1847 1848 def __str__(self) -> str: 1849 raise NotImplementedError 1850 1851 # WARNING: These concepts are not very well-defined. For example, 1852 # is "int?" nullable? How about "int?[]". 
They are defined 1853 # so we can conveniently generate legacy Declarations.yaml but 1854 # really we should probably just remove these at some point 1855 1856 def is_base_ty_like(self, base_ty: BaseTy) -> bool: 1857 raise NotImplementedError 1858 1859 def is_tensor_like(self) -> bool: 1860 return self.is_base_ty_like(BaseTy.Tensor) 1861 1862 def is_generator_like(self) -> bool: 1863 return self.is_base_ty_like(BaseTy.Generator) 1864 1865 def is_symint_like(self) -> bool: 1866 return self.is_base_ty_like(BaseTy.SymInt) 1867 1868 def is_nullable(self) -> bool: 1869 raise NotImplementedError 1870 1871 def is_list_like(self) -> ListType | None: 1872 raise NotImplementedError 1873 1874 1875# Base types are simple, atomic types with no further structure 1876class BaseTy(Enum): 1877 Generator = auto() 1878 ScalarType = auto() 1879 Tensor = auto() 1880 int = auto() 1881 Dimname = auto() 1882 DimVector = auto() 1883 float = auto() 1884 str = auto() 1885 bool = auto() 1886 Layout = auto() 1887 Device = auto() 1888 DeviceIndex = auto() 1889 Scalar = auto() 1890 MemoryFormat = auto() 1891 QScheme = auto() 1892 Storage = auto() 1893 Stream = auto() 1894 SymInt = auto() 1895 SymBool = auto() 1896 ConstQuantizerPtr = auto() # TODO: rename 1897 GraphModule = auto() 1898 1899 1900@dataclass(frozen=True) 1901class BaseType(Type): 1902 name: BaseTy 1903 1904 def __str__(self) -> str: 1905 return f"{self.name.name}" 1906 1907 def is_base_ty_like(self, base_ty: BaseTy) -> bool: 1908 return self.name == base_ty 1909 1910 def is_nullable(self) -> bool: 1911 return False 1912 1913 def is_list_like(self) -> ListType | None: 1914 return None 1915 1916 def is_symint_like(self) -> bool: 1917 return self.name == BaseTy.SymInt 1918 1919 1920# Optional types may be specified, or may also be validly given None 1921@dataclass(frozen=True) 1922class OptionalType(Type): 1923 elem: Type 1924 1925 def __str__(self) -> str: 1926 return f"{self.elem}?" 1927 1928 def is_base_ty_like(self, base_ty: BaseTy) -> bool: 1929 return self.elem.is_base_ty_like(base_ty) 1930 1931 def is_symint_like(self) -> bool: 1932 return self.elem.is_symint_like() 1933 1934 def is_nullable(self) -> bool: 1935 return True 1936 1937 def is_list_like(self) -> ListType | None: 1938 return self.elem.is_list_like() 1939 1940 1941# A type representing a PyTorch custom class 1942@dataclass(frozen=True) 1943class CustomClassType(Type): 1944 class_name: str 1945 1946 def __str__(self) -> str: 1947 """ 1948 Return the class name will prefix __torch__.torch.classes 1949 """ 1950 return f"__torch__.torch.classes.{self.class_name}" 1951 1952 def is_base_ty_like(self, base_ty: BaseTy) -> bool: 1953 return False 1954 1955 def is_symint_like(self) -> bool: 1956 return False 1957 1958 def is_nullable(self) -> bool: 1959 """ 1960 Assume a custom class is not nullable. 1961 """ 1962 return False 1963 1964 def is_list_like(self) -> ListType | None: 1965 return None 1966 1967 1968# List types specify that we may have multiples of an element. We 1969# also support explicit sizes on list types, but these have 1970# some nontrivial semantics! (However, for C++ API purposes, explicit 1971# sizes are mostly erased from the type system.) 1972# 1973# DANGER WILL ROBINSON: C++ elaboration depends on elem type; e.g., 1974# int[] elaborates differently than bool[3]! 
1975@dataclass(frozen=True) 1976class ListType(Type): 1977 elem: Type 1978 size: int | None 1979 1980 def __str__(self) -> str: 1981 size = f"{self.size}" if self.size else "" 1982 return f"{self.elem}[{size}]" 1983 1984 def is_base_ty_like(self, base_ty: BaseTy) -> bool: 1985 return self.elem.is_base_ty_like(base_ty) 1986 1987 def is_symint_like(self) -> bool: 1988 return self.elem.is_symint_like() 1989 1990 def is_nullable(self) -> bool: 1991 return self.elem.is_nullable() 1992 1993 def is_list_like(self) -> ListType | None: 1994 return self 1995 1996 1997@dataclass(frozen=True) 1998class Argument: 1999 # NB: I didn't put kwarg_only as a boolean field here, unlike 2000 # c10::Argument, so that printing works correctly 2001 2002 name: str 2003 type: Type 2004 default: str | None 2005 2006 # The semantics of the annotation field are a little strange. 2007 # 2008 # Alias annotations parametrize Tensors (since Tensors are the only things 2009 # that can alias.) This motivates why I write Tensor(a!)? (and not, for 2010 # example, Tensor?(a!)), because the (a!) describes aliasing on the tensor, 2011 # which may be optional (i.e., the alias annotation should bind first to 2012 # Tensor, before the optional postfix annotation). 2013 # 2014 # However, despite being a property of Tensor, we (and c10::Argument) 2015 # store the annotation at the top level of the Argument, rather than 2016 # inside the embedded Tensor type. In the C++ version of this 2017 # class, we then go through great lengths to mimic the type 2018 # structure in the annotation structure so we can correlate 2019 # annotations with types. 2020 # 2021 # Now, it turns out, in all applications in code generation, the 2022 # structure of annotated types is very simple. So we just hard 2023 # code it here. But if we ever do get anything more complex, this 2024 # model will have to change! 
2025 annotation: Annotation | None 2026 2027 @property 2028 def alias_info(self) -> Annotation | None: 2029 return self.annotation 2030 2031 @staticmethod 2032 def parse(arg: str) -> Argument: 2033 name: str 2034 default: str | None 2035 assert " " in arg, f"illegal argument '{arg}'" 2036 if "=" in arg: 2037 assert arg.count("=") == 1, f"illegal argument with default value: '{arg}'" 2038 type_and_annot_and_name, default = arg.split("=") 2039 type_and_annot, name = type_and_annot_and_name.rsplit(" ", 1) 2040 name_and_default = f"{name}={default}" 2041 else: 2042 type_and_annot, name_and_default = arg.rsplit(" ", 1) 2043 name = name_and_default 2044 default = None 2045 # TODO: deduplicate annotation matching with Return 2046 match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot) 2047 annotation: Annotation | None 2048 if match: 2049 # If you update this, make sure the __str__ still works too 2050 assert match.group(2) in [ 2051 "", 2052 "?", 2053 "[]", 2054 ], "unrecognized alias analysis form with Tensor" 2055 type_s = "Tensor" + match.group(2) 2056 annotation = Annotation.parse(match.group(1)) 2057 else: 2058 type_s = type_and_annot 2059 annotation = None 2060 type = Type.parse(type_s) 2061 r = Argument( 2062 name=name, 2063 type=type, 2064 default=default, 2065 annotation=annotation, 2066 ) 2067 assert str(r) == arg, f"{str(r)} != {arg}" 2068 return r 2069 2070 @property 2071 def is_write(self) -> bool: 2072 return self.annotation is not None and self.annotation.is_write 2073 2074 def __str__(self) -> str: 2075 type = f"{self.type}" 2076 if self.annotation: 2077 assert type in ["Tensor", "Tensor?", "Tensor[]"] 2078 type = type.replace("Tensor", f"Tensor({self.annotation})") 2079 if self.name is None: 2080 return type 2081 else: 2082 mb_default = "" 2083 if self.default: 2084 mb_default = f"={self.default}" 2085 return f"{type} {self.name}{mb_default}" 2086 2087 2088@dataclass(frozen=True) 2089class Return: 2090 name: str | None 2091 type: Type 2092 annotation: Annotation | None 2093 2094 @property 2095 def alias_info(self) -> Annotation | None: 2096 return self.annotation 2097 2098 @staticmethod 2099 def parse(arg: str) -> Return: 2100 name: str | None 2101 if " " in arg: 2102 type_and_annot, name = arg.rsplit(" ", 1) 2103 else: 2104 type_and_annot = arg 2105 name = None 2106 match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot) 2107 annotation: Annotation | None 2108 if match: 2109 # If you update this, make sure the __str__ still works too 2110 assert match.group(2) in [ 2111 "", 2112 "?", 2113 "[]", 2114 ], "unrecognized alias analysis form with Tensor" 2115 type_s = "Tensor" + match.group(2) 2116 annotation = Annotation.parse(match.group(1)) 2117 else: 2118 type_s = type_and_annot 2119 annotation = None 2120 type = Type.parse(type_s) 2121 r = Return( 2122 name=name, 2123 type=type, 2124 annotation=annotation, 2125 ) 2126 assert str(r) == arg, f"{str(r)} != {arg}" 2127 return r 2128 2129 @property 2130 def is_write(self) -> bool: 2131 return self.annotation is not None and self.annotation.is_write 2132 2133 def __str__(self) -> str: 2134 type = f"{self.type}" 2135 if self.annotation: 2136 assert type in ["Tensor", "Tensor?", "Tensor[]"] 2137 type = type.replace("Tensor", f"Tensor({self.annotation})") 2138 if self.name is None: 2139 return type 2140 else: 2141 return f"{type} {self.name}" 2142 2143 2144# Represents the self argument for functions that may be methods 2145@dataclass(frozen=True) 2146class SelfArgument: 2147 argument: Argument 2148 2149 2150# Bundle of arguments that 
represent a TensorOptions. This is mostly 2151# relevant for the public C++ API but we bake it into the core data 2152# model because other APIs often have to interact with it 2153@dataclass(frozen=True) 2154class TensorOptionsArguments: 2155 dtype: Argument 2156 layout: Argument 2157 device: Argument 2158 pin_memory: Argument 2159 2160 def all(self) -> Sequence[Argument]: 2161 return [self.dtype, self.layout, self.device, self.pin_memory] 2162 2163 2164@dataclass(frozen=True) 2165class Arguments: 2166 # pre_self_positional is usually empty, but is notably non-empty 2167 # for where.self, where the condition argument comes before the 2168 # self argument 2169 pre_self_positional: tuple[Argument, ...] 2170 self_arg: SelfArgument | None 2171 post_self_positional: tuple[Argument, ...] 2172 2173 pre_tensor_options_kwarg_only: tuple[Argument, ...] 2174 tensor_options: TensorOptionsArguments | None 2175 # post_tensor_options is typically memory format, which should be 2176 # part of tensor options but isn't right now, and is usually 2177 # placed after the tensor options arguments 2178 post_tensor_options_kwarg_only: tuple[Argument, ...] 2179 2180 # Unlike in the previous codegen, we have factored out 'out' arguments 2181 # in the canonical representation, removing them from kwarg 2182 # arguments. This choice is justified by numerous downstream 2183 # transformations which treat out arguments specially; additionally, 2184 # you can see that canonicity is not violated! 2185 out: tuple[Argument, ...] # these are also kwarg-only 2186 2187 @property 2188 def flat_non_out(self) -> Sequence[Argument]: 2189 ret: list[Argument] = [] 2190 ret.extend(self.flat_positional) 2191 ret.extend(self.flat_kwarg_only) 2192 return ret 2193 2194 @property 2195 def flat_positional(self) -> Sequence[Argument]: 2196 ret: list[Argument] = [] 2197 ret.extend(self.pre_self_positional) 2198 if self.self_arg is not None: 2199 ret.append(self.self_arg.argument) 2200 ret.extend(self.post_self_positional) 2201 return ret 2202 2203 @property 2204 def post_self_positional_mutable(self) -> Sequence[Argument]: 2205 return [a for a in self.post_self_positional if a.is_write] 2206 2207 # NB: doesn't contain out arguments 2208 @property 2209 def flat_kwarg_only(self) -> Sequence[Argument]: 2210 ret: list[Argument] = [] 2211 ret.extend(self.pre_tensor_options_kwarg_only) 2212 if self.tensor_options is not None: 2213 ret.extend(self.tensor_options.all()) 2214 ret.extend(self.post_tensor_options_kwarg_only) 2215 return ret 2216 2217 @property 2218 def flat_all(self) -> Sequence[Argument]: 2219 ret: list[Argument] = [] 2220 ret.extend(self.flat_positional) 2221 ret.extend(self.flat_kwarg_only) 2222 ret.extend(self.out) 2223 return ret 2224 2225 @property 2226 def non_out( 2227 self, 2228 ) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]: 2229 ret: list[Argument | SelfArgument | TensorOptionsArguments] = [] 2230 ret.extend(self.positional) 2231 ret.extend(self.kwarg_only) 2232 return ret 2233 2234 @property 2235 def positional(self) -> Sequence[Argument | SelfArgument]: 2236 ret: list[Argument | SelfArgument] = [] 2237 ret.extend(self.pre_self_positional) 2238 if self.self_arg is not None: 2239 ret.append(self.self_arg) 2240 ret.extend(self.post_self_positional) 2241 return ret 2242 2243 @property 2244 def kwarg_only(self) -> Sequence[Argument | TensorOptionsArguments]: 2245 ret: list[Argument | TensorOptionsArguments] = [] 2246 ret.extend(self.pre_tensor_options_kwarg_only) 2247 if self.tensor_options is not None: 2248 
ret.append(self.tensor_options) 2249 ret.extend(self.post_tensor_options_kwarg_only) 2250 return ret 2251 2252 @property 2253 def all(self) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]: 2254 ret: list[Argument | SelfArgument | TensorOptionsArguments] = [] 2255 ret.extend(self.positional) 2256 ret.extend(self.kwarg_only) 2257 ret.extend(self.out) 2258 return ret 2259 2260 def mutable_arg_names(self) -> list[str]: 2261 return [ 2262 a.name 2263 for a in self.flat_all 2264 if a.annotation is not None and a.annotation.is_write 2265 ] 2266 2267 def has_tensor_arg(self) -> bool: 2268 return any(a.type.is_tensor_like() for a in self.flat_non_out) 2269 2270 def has_symint_arg(self) -> bool: 2271 return any(a.type.is_symint_like() for a in self.flat_non_out) 2272 2273 def has_generator_arg(self) -> bool: 2274 return any(a.type.is_generator_like() for a in self.flat_non_out) 2275 2276 def signature(self, *, strip_default: bool = False) -> Arguments: 2277 # dataclasses.replace could be used here, but it is less 2278 # type safe so for now I've opted to type everything out 2279 def strip_arg_annotation(a: Argument) -> Argument: 2280 return Argument( 2281 name=a.name, 2282 type=a.type, 2283 default=a.default if not strip_default else None, 2284 annotation=None, 2285 ) 2286 2287 return Arguments( 2288 pre_self_positional=tuple( 2289 map(strip_arg_annotation, self.pre_self_positional) 2290 ), 2291 self_arg=SelfArgument(strip_arg_annotation(self.self_arg.argument)) 2292 if self.self_arg is not None 2293 else None, 2294 post_self_positional=tuple( 2295 map(strip_arg_annotation, self.post_self_positional) 2296 ), 2297 # Since TensorOptions are dropped, the post_tensor_options_kwargs are 2298 # converted to pre_tensor_options_kwargs 2299 pre_tensor_options_kwarg_only=tuple( 2300 map(strip_arg_annotation, self.pre_tensor_options_kwarg_only) 2301 ) 2302 + tuple(map(strip_arg_annotation, self.post_tensor_options_kwarg_only)), 2303 # TensorOptions are dropped in signature, 2304 # so we can pair factory functions with their out= variants. 2305 tensor_options=None, 2306 post_tensor_options_kwarg_only=(), 2307 # out arguments are dropped in signature 2308 out=(), 2309 ) 2310 2311 def remove_self_annotation(self) -> Arguments: 2312 assert self.self_arg is not None 2313 return dataclasses.replace( 2314 self, 2315 self_arg=SelfArgument( 2316 dataclasses.replace(self.self_arg.argument, annotation=None) 2317 ), 2318 ) 2319 2320 def with_out_args(self, outs: list[Argument]) -> Arguments: 2321 assert len(self.out) == 0 2322 return dataclasses.replace( 2323 self, 2324 out=tuple(outs), 2325 ) 2326 2327 @staticmethod 2328 def _preparse(args: str) -> tuple[list[Argument], list[Argument], list[Argument]]: 2329 positional: list[Argument] = [] 2330 kwarg_only: list[Argument] = [] 2331 out: list[Argument] = [] 2332 arguments_acc = positional 2333 2334 # TODO: Use a real parser here; this will get bamboozled 2335 # by signatures that contain things like std::array<bool, 2> (note the space) 2336 for arg in args.split(", "): 2337 if not arg: 2338 continue 2339 if arg == "*": 2340 assert ( 2341 arguments_acc is positional 2342 ), "invalid syntax: kwarg-only specifier * can only occur once" 2343 arguments_acc = kwarg_only 2344 continue 2345 parg = Argument.parse(arg) 2346 # Currently, we rely directly on the invariant that there are NO 2347 # kwarg-only mutating arguments. If you want to relax this, 2348 # we will need a more semantic way of matching that takes 2349 # into account return arguments. 
In that case, you will have 2350 # to manage out computation a level up, in FunctionSchema. See Note 2351 # [is_out_fn] 2352 if parg.annotation is not None and parg.annotation.is_write: 2353 if arguments_acc is positional: 2354 pass # do nothing 2355 elif arguments_acc is kwarg_only: 2356 arguments_acc = out 2357 else: 2358 assert arguments_acc is not out 2359 arguments_acc.append(parg) 2360 2361 return positional, kwarg_only, out 2362 2363 @staticmethod 2364 def parse(args: str) -> Arguments: 2365 """ 2366 Input: 'int x, int y, int z' 2367 """ 2368 2369 # We do this in two phases. First we parse into three 2370 # main categories: positional, kwarg_only, out. 2371 # Then, we reparse positional and kwarg_only to separate 2372 # out the self argument and tensor options arguments. 2373 2374 positional, kwarg_only, out = Arguments._preparse(args) 2375 2376 # Split self argument 2377 self_ix = None 2378 for i, a in enumerate(positional): 2379 if a.name == "self": 2380 self_ix = i 2381 break 2382 pre_self_positional: list[Argument] 2383 self_arg: SelfArgument | None 2384 post_self_positional: list[Argument] 2385 if self_ix is not None: 2386 pre_self_positional = positional[:self_ix] 2387 self_arg = SelfArgument(positional[self_ix]) 2388 post_self_positional = positional[self_ix + 1 :] 2389 else: 2390 pre_self_positional = [] 2391 self_arg = None 2392 post_self_positional = positional 2393 2394 # Group tensor options arguments 2395 pre_tensor_options_kwarg_only: list[Argument] = [] 2396 tensor_options: TensorOptionsArguments | None = None 2397 post_tensor_options_kwarg_only: list[Argument] = [] 2398 kwarg_only_acc = pre_tensor_options_kwarg_only 2399 2400 def pred(name: str, ty: Type) -> Callable[[Argument], bool]: 2401 return lambda a: a.name == name and a.type in [ty, OptionalType(ty)] 2402 2403 predicates = [ # order matters 2404 pred("dtype", Type.parse("ScalarType")), 2405 pred("layout", Type.parse("Layout")), 2406 pred("device", Type.parse("Device")), 2407 pred("pin_memory", Type.parse("bool")), 2408 ] 2409 2410 i = 0 2411 while i < len(kwarg_only): 2412 # If there is enough space... 
2413 if i <= len(kwarg_only) - len(predicates): 2414 # And the next len(predicates) arguments look like TensorOptions arguments 2415 if all( 2416 p(a) 2417 for p, a in zip(predicates, kwarg_only[i : i + len(predicates)]) 2418 ): 2419 assert kwarg_only_acc is pre_tensor_options_kwarg_only 2420 # Group them together as one argument 2421 tensor_options = TensorOptionsArguments( 2422 dtype=kwarg_only[i], 2423 layout=kwarg_only[i + 1], 2424 device=kwarg_only[i + 2], 2425 pin_memory=kwarg_only[i + 3], 2426 ) 2427 i += len(predicates) 2428 kwarg_only_acc = post_tensor_options_kwarg_only 2429 continue 2430 kwarg_only_acc.append(kwarg_only[i]) 2431 i += 1 2432 2433 return Arguments( 2434 pre_self_positional=tuple(pre_self_positional), 2435 self_arg=self_arg, 2436 post_self_positional=tuple(post_self_positional), 2437 pre_tensor_options_kwarg_only=tuple(pre_tensor_options_kwarg_only), 2438 tensor_options=tensor_options, 2439 post_tensor_options_kwarg_only=tuple(post_tensor_options_kwarg_only), 2440 out=tuple(out), 2441 ) 2442 2443 def __str__(self) -> str: 2444 all_arguments: list[str] = [] 2445 all_arguments.extend(map(str, self.flat_positional)) 2446 if self.flat_kwarg_only or self.out: 2447 all_arguments.append("*") 2448 all_arguments.extend(map(str, self.flat_kwarg_only)) 2449 all_arguments.extend(map(str, self.out)) 2450 return ", ".join(all_arguments) 2451 2452 def __post_init__(self) -> None: 2453 # TODO: These invariants are weirdly asymmetric? 2454 # TODO: Fancier types? 2455 if self.self_arg is None: 2456 assert not self.pre_self_positional 2457 if self.tensor_options is None: 2458 assert not self.post_tensor_options_kwarg_only 2459 2460 # We don't allow any of the following to have argument annotations, 2461 # to keep things simple. 2462 mutable_pre_self_positionals = [ 2463 a 2464 for a in self.pre_self_positional 2465 if a.annotation is not None and a.annotation.is_write 2466 ] 2467 assert ( 2468 len(mutable_pre_self_positionals) == 0 2469 ), "mutable pre_self_positional arguments are not currently supported in the schema" 2470 2471 2472# Names that validly are __iXXX__ indicating inplace operations. 2473# Taken from https://www.python.org/dev/peps/pep-0203/#new-methods 2474# NB: PyTorch hasn't actually implemented all of these 2475AUGMENTED_ASSIGNMENT_NAMES = [ 2476 "add", 2477 "sub", 2478 "mul", 2479 "div", 2480 "mod", 2481 "pow", 2482 "lshift", 2483 "rshift", 2484 "and", 2485 "xor", 2486 "or", 2487] 2488 2489 2490# A BaseOperatorName is what we think of the operator name, without 2491# the overload name. Unusually, we don't represent this as just a 2492# string; instead, we directly represent a few important semantic 2493# bits of information we derive from the string: namely whether 2494# or not it's inplace (add_) and whether or not it's a double-underscore 2495# method (__add__) 2496@dataclass(frozen=True) 2497class BaseOperatorName: 2498 base: str 2499 inplace: bool 2500 dunder_method: bool 2501 # Note [Overload Ambiguity With Functional Variants] 2502 # A handful of operators have both a "mutable" and a "functional" variant. 2503 # (native_batch_norm is a good example, although this isn't the case today). 2504 # For those operators, the mutable and functional variant take in the same set of 2505 # arguments, but have different alias annotations. 2506 # this makes it ambiguous when you try to resolve an OverloadPacket into an overload, 2507 # given a set of input arguments. 
2508 # 2509 # So instead of making the "functional" variant in this case a real overload, e.g: 2510 # native_batch_norm (mutable variant) 2511 # native_batch_norm.functional (functional variant) 2512 # we make it a new base operator, 2513 # native_batch_norm_functional (functional variant) 2514 # 2515 # In an ideal world, we would probably invert this so the operators were: 2516 # native_batch_norm.mutable (mutable variant) 2517 # native_batch_norm (functional variant) 2518 # 2519 # Doing that is BC-breaking though, so we're stuck with the above modeling. 2520 functional_overload: bool = False 2521 2522 @staticmethod 2523 def parse(op: str) -> BaseOperatorName: 2524 assert op != "" 2525 assert not op.endswith("_out"), ( 2526 "_out suffix is reserved and not permitted for operator names; " 2527 "did you mean to specify an out overload name instead?" 2528 ) 2529 m = re.match(r"^__([^_]+)__$", op) 2530 if m is not None: 2531 dunder_method = True 2532 base = m.group(1) 2533 if any(base == f"i{n}" for n in AUGMENTED_ASSIGNMENT_NAMES): 2534 inplace = True 2535 base = base[1:] 2536 else: 2537 inplace = False 2538 # temporary, this is not intrinsically true but 2539 # has been historically true for dunder methods 2540 # we support (but, if we ever got, say, __int__, this would 2541 # be wrong!) 2542 assert base[0] != "i" 2543 else: 2544 dunder_method = False 2545 base = op 2546 if base[-1] == "_": 2547 inplace = True 2548 base = base[:-1] 2549 else: 2550 inplace = False 2551 2552 # See Note [Overload Ambiguity With Functional Variants] 2553 functional_suffix = "_functional" 2554 if base.endswith(functional_suffix): 2555 functional_overload = True 2556 base = base[: -len(functional_suffix)] 2557 # This seems complicated and unnecessary, so banning dunder methods 2558 # for now on ops that have a functional + mutable variant (like native_batch_norm). 2559 assert not dunder_method and not inplace 2560 else: 2561 functional_overload = False 2562 2563 r = BaseOperatorName( 2564 base=base, 2565 inplace=inplace, 2566 dunder_method=dunder_method, 2567 functional_overload=functional_overload, 2568 ) 2569 assert str(r) == op, f"{str(r)} != {op}" 2570 return r 2571 2572 def __str__(self) -> str: 2573 if self.dunder_method: 2574 i = "i" if self.inplace else "" 2575 return f"__{i}{self.base}__" 2576 else: 2577 i = ( 2578 "_" 2579 if self.inplace 2580 else "_functional" 2581 if self.functional_overload 2582 else "" 2583 ) 2584 return f"{self.base}{i}" 2585 2586 2587# Operator name is the base operator name along with the (typically not 2588# user visible) overload string. 2589@dataclass(frozen=True) 2590class OperatorName: 2591 name: BaseOperatorName 2592 overload_name: str 2593 2594 @staticmethod 2595 def parse(op_name: str) -> OperatorName: 2596 if "." 
in op_name: 2597 name, overload_name = op_name.split(".", 1) 2598 else: 2599 name = op_name 2600 overload_name = "" 2601 r = OperatorName(name=BaseOperatorName.parse(name), overload_name=overload_name) 2602 assert str(r) == op_name, f"{str(r)} != {op_name}" 2603 return r 2604 2605 def __str__(self) -> str: 2606 if self.overload_name: 2607 return f"{self.name}.{self.overload_name}" 2608 else: 2609 return f"{self.name}" 2610 2611 # NB: This must be synchronized with the naming scheme in 2612 # aten/src/ATen/templates/Operators.h 2613 # Given a function schema "aten::op.overload(...)", 2614 # If there is no overload name, this returns f"{op}" 2615 # If there is an overload name, this returns f"{op}_{overload}" 2616 def unambiguous_name(self) -> str: 2617 if self.overload_name: 2618 return f"{self.name}_{self.overload_name}" 2619 else: 2620 return f"{self.name}" 2621 2622 def remove_inplace(self) -> OperatorName: 2623 return OperatorName( 2624 name=BaseOperatorName( 2625 base=self.name.base, 2626 inplace=False, 2627 dunder_method=self.name.dunder_method, 2628 ), 2629 overload_name=self.overload_name, 2630 ) 2631 2632 def with_overload(self, overload: str) -> OperatorName: 2633 return OperatorName( 2634 name=BaseOperatorName( 2635 base=self.name.base, 2636 inplace=False, 2637 dunder_method=self.name.dunder_method, 2638 ), 2639 overload_name=overload, 2640 ) 2641 2642 2643def gets_generated_out_inplace_wrapper( 2644 f: NativeFunction, g: NativeFunctionsGroup, b: BackendIndex 2645) -> bool: 2646 return ( 2647 f.func.kind() is not SchemaKind.functional 2648 and not b.has_kernel(f) 2649 and b.has_kernel(g.functional) 2650 ) 2651 2652 2653# NativeFunction objects that are views (f.is_view_op returns True) 2654# are added into a `NativeFunctionsViewGroup`, which we can use to 2655# easily access the generated (optional) view_copy NativeFunction. 2656# It's convenient to group them together, so we pair them up in NativeFunctionsViewGroup. 2657# See Note [Codegen'd {view}_copy Operators] 2658# 2659# One property of this representation is that in order for a view-like op to be part of 2660# a NativeFunctionsViewGroup, the "aliasing" version of that view op must exist. 2661# There's one case where that doesn't happen: we have a non-aliasing `narrow_copy.out` op, 2662# but don't have corresponding aliasing `narrow.out` op. 2663# This means that `narrow_copy.out` won't appear as a NativeFunctionsViewGroup. 2664@dataclass(frozen=True) 2665class NativeFunctionsViewGroup: 2666 view: NativeFunction 2667 # Note: the {view}_copy operator is optional because we currently don't generate copy variants 2668 # for all view ops. Notably, we don't generate them for CompositeImplicitAutograd views 2669 # (we already get them "for free" through decomposition) 2670 view_copy: NativeFunction | None 2671 # view_inplace ops are also optional, but every view_inplace op should have out-of-place variant. 2672 view_inplace: NativeFunction | None 2673 2674 def __post_init__(self) -> None: 2675 assert self.view.is_view_op 2676 if self.view_copy is None: 2677 assert not gets_generated_view_copy(self.view), ( 2678 f"{str(self.view.func.name)} appears to be a new operator that aliases its inputs." 2679 " The codegen expects you to add a corresponding operator to native_functions.yaml:" 2680 f" {get_view_copy_name(self.view)!s}." 2681 " See Note [view_copy NativeFunctions] for details." 
2682 ) 2683 else: 2684 assert self.view_copy.func.name.name.base.endswith(("_copy", "_scatter")) 2685 assert self.view.func.signature() == self.view_copy.func.signature( 2686 strip_view_copy_name=True, 2687 ) 2688 assert "view_copy" in self.view_copy.tags, ( 2689 f"{str(self.view_copy.func.name), str(self.view.tags)} appears to be a view_copy operator. The codegen expects" 2690 " view_copy operators to be annotated with the 'view_copy' tag in native_functions.yaml." 2691 " See Note [view_copy NativeFunction] for details." 2692 ) 2693 if self.view_inplace is not None: 2694 assert self.view.func.signature() == self.view_inplace.func.signature() 2695 2696 if self.view.has_composite_implicit_autograd_kernel: 2697 if self.view_inplace is not None: 2698 assert self.view_inplace.has_composite_implicit_autograd_kernel, ( 2699 f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either" 2700 " both have CompositeImplicitAutograd kernels, or both not have composite kernels." 2701 ) 2702 if self.view.has_composite_implicit_autograd_nested_tensor_kernel: 2703 if self.view_inplace is not None: 2704 assert ( 2705 self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel 2706 ), ( 2707 f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either" 2708 " both have CompositeImplicitAutogradNestedTensor kernels, or both not have composite kernels." 2709 ) 2710 2711 def functions(self, *, include_copy: bool = True) -> Iterator[NativeFunction]: 2712 yield self.view 2713 if self.view_inplace is not None: 2714 yield self.view_inplace 2715 if self.view_copy is not None and include_copy: 2716 yield self.view_copy 2717 2718 @property 2719 def root_name(self) -> str: 2720 return self.view.root_name 2721 2722 @property 2723 def composite(self) -> bool: 2724 # We currently assert that the "group" is consistent. 2725 # If the view op is composite, then its view_inplace op is too. 2726 return self.view.has_composite_implicit_autograd_kernel 2727 2728 2729def gets_generated_view_copy(f: NativeFunction) -> bool: 2730 # Only aliasing (view) operators get a copy variant. 2731 if not f.is_view_op: 2732 return False 2733 # We don't need to bother generating copy variants for CompositeImplicitAutograd ops, 2734 # because we can let them decompose into base view ops. 2735 if f.has_composite_implicit_autograd_kernel: 2736 return False 2737 # We also don't need to generate copy variants for inplace views. 2738 if "inplace_view" in f.tags: 2739 return False 2740 # Assume ops ending in _inverse have manually-defined copy variants 2741 # (e.g. slice_inverse() has the copy variant slice_scatter()). 2742 # We -could- probably generate these as well, but the codegen will be 2743 # slightly different, and hand-writing these few kernels keeps codegen 2744 # complexity lower. 2745 if f.func.name.name.base.endswith("_inverse"): 2746 return False 2747 return True 2748 2749 2750# Given a NativeFunction that corresponds to a view op, 2751# returns the OperatorName of the corresponding "copy" variant of the op. 2752def get_view_copy_name(f: NativeFunction) -> OperatorName: 2753 # Right now, when asking for a view op's corresponding "view_copy" name 2754 # we assert for sanity that the op is allowed to have a generated view_copy variant. 2755 # (We can do this because "gets_generated_view_copy()" tell us which ops get a generated view_copy op). 2756 # However, narrow_copy() already exists as an op directly in native_functions.yaml. 
2757 # I'm hardcoding narrow_copy here for now to maintain the assert, 2758 # But we could also just get rid of the assert. 2759 list_of_ops_with_explicit_view_copy_operators = ["narrow"] 2760 if str(f.func.name) not in list_of_ops_with_explicit_view_copy_operators: 2761 assert gets_generated_view_copy(f) 2762 2763 base_name = f"{f.func.name.name.base}_copy" 2764 view_copy_name = OperatorName( 2765 name=BaseOperatorName( 2766 base=base_name, inplace=False, dunder_method=f.func.name.name.dunder_method 2767 ), 2768 overload_name=f.func.name.overload_name, 2769 ) 2770 return view_copy_name 2771 2772 2773# Helper functions for parsing argument lists (both inputs and returns) 2774 2775 2776def parse_returns(return_decl: str) -> tuple[Return, ...]: 2777 """ 2778 Input: '()' 2779 Output: [] 2780 """ 2781 if return_decl == "()": 2782 return () 2783 if return_decl[0] == "(" and return_decl[-1] == ")": 2784 return_decl = return_decl[1:-1] 2785 return tuple(Return.parse(arg) for arg in return_decl.split(", ")) 2786 2787 2788# A Precompute instance consists of a map from kernel argument name 2789# to the list of Argument instances that should replace that 2790# kernel argument in the impl function. 2791@dataclass(frozen=True) 2792class Precompute: 2793 # A map from kernel argument name -> a list of precomputed 2794 # elements that replaces/supersedes it. 2795 replace: dict[str, list[Argument]] 2796 # List of precomputed args added without replacement 2797 add: list[Argument] 2798 2799 @staticmethod 2800 def parse(src: object) -> Precompute: 2801 assert isinstance(src, list) 2802 2803 # src is a list of strings of the format: 2804 # {kernel param name} -> {replacement decl}[, {replacement decl}, ...] 2805 # [{add decl}[, {add decl}, ...]] 2806 # The last line is optional and contains the precomputed parameters that are 2807 # added without replacement. 2808 # The other lines are parsed to get the names of which precomputed elements 2809 # should replace which kernel arguments. 2810 add_args = [] 2811 if " -> " not in src[-1]: 2812 add_list = src[-1].split(",") 2813 add_args = [Argument.parse(name.strip()) for name in add_list] 2814 src = src[:-1] 2815 2816 replace = {} 2817 for raw_replace_item in src: 2818 assert isinstance(raw_replace_item, str) 2819 assert " -> " in raw_replace_item, ( 2820 "precomputed parameters without replacement" 2821 " are allowed only in the last line" 2822 ) 2823 2824 arg, with_list_raw = raw_replace_item.split(" -> ") 2825 assert ( 2826 " " not in arg 2827 ), f"illegal kernel param name '{arg}' in precomputed parameters'" 2828 with_list = with_list_raw.split(",") 2829 with_list_args = [Argument.parse(name.strip()) for name in with_list] 2830 replace[arg] = with_list_args 2831 2832 r = Precompute(replace=replace, add=add_args) 2833 assert r.to_list() == src, "r.to_list() != src" 2834 return r 2835 2836 def __post_init__(self) -> None: 2837 # the template parameters are upper so if these are the 2838 # same then it is ambiguous 2839 for a in self.add: 2840 assert a.name.upper() != a.name 2841 for args in self.replace.values(): 2842 for a in args: 2843 assert a.name.upper() != a.name 2844 2845 def to_list(self) -> list[str]: 2846 replace_list = [] 2847 for kernel_param, replacement_params in self.replace.items(): 2848 replacements = ", ".join(str(param) for param in replacement_params) 2849 replace_list.append(f"{kernel_param} -> {replacements}") 2850 2851 return replace_list 2852
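# Illustrative usage sketch (not part of the code generator): parse a schema
# string, classify it, and derive its core functional signature. The schema
# below is a hand-written example in native_functions.yaml syntax, used here
# purely for demonstration.
if __name__ == "__main__":
    example = FunctionSchema.parse(
        "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"
    )
    # Parsed schemas round-trip losslessly back to their string form.
    assert str(example) == (
        "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"
    )
    # The keyword-only, mutable out argument makes this an out-variant schema.
    assert example.kind() is SchemaKind.out
    # signature() strips the overload name, the out argument, and the
    # mutability annotations, yielding the shared "core" signature:
    #   add(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
    print(example.signature())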