1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport dataclasses 4*da0073e9SAndroid Build Coastguard Workerimport itertools 5*da0073e9SAndroid Build Coastguard Workerimport re 6*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass 7*da0073e9SAndroid Build Coastguard Workerfrom enum import auto, Enum 8*da0073e9SAndroid Build Coastguard Workerfrom typing import Callable, Iterator, Sequence 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Workerfrom torchgen.utils import assert_never, NamespaceHelper, OrderedSet 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 14*da0073e9SAndroid Build Coastguard Worker# 15*da0073e9SAndroid Build Coastguard Worker# DATA MODEL 16*da0073e9SAndroid Build Coastguard Worker# 17*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 18*da0073e9SAndroid Build Coastguard Worker# 19*da0073e9SAndroid Build Coastguard Worker# Some general principles for our data model. 20*da0073e9SAndroid Build Coastguard Worker# 21*da0073e9SAndroid Build Coastguard Worker# - Stop using C++ data types as the internal data representation 22*da0073e9SAndroid Build Coastguard Worker# format. Instead, the internal data structures are centered 23*da0073e9SAndroid Build Coastguard Worker# around JIT schema representation. This avoid a big problem 24*da0073e9SAndroid Build Coastguard Worker# with the old codegen where we read in all the types from 25*da0073e9SAndroid Build Coastguard Worker# native_functions.yaml and then immediately had to retranslate 26*da0073e9SAndroid Build Coastguard Worker# them into C++ types. 27*da0073e9SAndroid Build Coastguard Worker# 28*da0073e9SAndroid Build Coastguard Worker# - More semantic data representation. Instead of representing 29*da0073e9SAndroid Build Coastguard Worker# everything as dicts and strings, we define dataclasses for 30*da0073e9SAndroid Build Coastguard Worker# every interesting entity the code generation has to deal with. 31*da0073e9SAndroid Build Coastguard Worker# These dataclasses have strong semantic invariants: for example, 32*da0073e9SAndroid Build Coastguard Worker# we generally require them to roundtrip losslessly into the 33*da0073e9SAndroid Build Coastguard Worker# form they were parsed from. These structures are immutable 34*da0073e9SAndroid Build Coastguard Worker# and you're expected to populate information once during 35*da0073e9SAndroid Build Coastguard Worker# construction. 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker# Represent a source location; used for better error reporting 39*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 40*da0073e9SAndroid Build Coastguard Workerclass Location: 41*da0073e9SAndroid Build Coastguard Worker file: str 42*da0073e9SAndroid Build Coastguard Worker line: int 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker def __str__(self) -> str: 45*da0073e9SAndroid Build Coastguard Worker return f"{self.file}:{self.line}" 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Worker# Valid values of the 'variants' field in native_functions.yaml 49*da0073e9SAndroid Build Coastguard Workerclass Variant(Enum): 50*da0073e9SAndroid Build Coastguard Worker function = auto() 51*da0073e9SAndroid Build Coastguard Worker method = auto() 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Worker 54*da0073e9SAndroid Build Coastguard Worker# Default kernel namespace 55*da0073e9SAndroid Build Coastguard WorkerDEFAULT_KERNEL_NAMESPACE = "at::native" 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker# NOTE: Keep the list in sync with `DispatchKey` in c10/core/DispatchKey.h 58*da0073e9SAndroid Build Coastguard WorkerBACKEND_COMPONENTS = "CPU CUDA HIP XLA MTIA MPS IPU XPU HPU VE Lazy Meta PrivateUse1 PrivateUse2 PrivateUse3".split() 59*da0073e9SAndroid Build Coastguard WorkerFUNCTIONALITY_KEYS = [ 60*da0073e9SAndroid Build Coastguard Worker "", 61*da0073e9SAndroid Build Coastguard Worker "Quantized", 62*da0073e9SAndroid Build Coastguard Worker "Sparse", 63*da0073e9SAndroid Build Coastguard Worker "SparseCsr", 64*da0073e9SAndroid Build Coastguard Worker "NestedTensor", 65*da0073e9SAndroid Build Coastguard Worker "Autograd", 66*da0073e9SAndroid Build Coastguard Worker] 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Worker# This list guards dispatches that can be used in derivatives.yaml 69*da0073e9SAndroid Build Coastguard Worker# For now we omit AutogradFunctionality and AutogradOther 70*da0073e9SAndroid Build Coastguard WorkerAUTOGRAD_KEYS = ["AutogradNestedTensor"] + [ 71*da0073e9SAndroid Build Coastguard Worker "Autograd" + component for component in BACKEND_COMPONENTS 72*da0073e9SAndroid Build Coastguard Worker] 73*da0073e9SAndroid Build Coastguard Worker 74*da0073e9SAndroid Build Coastguard WorkerFRAGMENT_NAMESPACES = {"quantized", "quantized_decomposed"} 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker# This doesn't have to be in sync with the header, it only needs to contain 78*da0073e9SAndroid Build Coastguard Worker# entries that we actually use in the codegen or want pyi entries for 79*da0073e9SAndroid Build Coastguard Workerclass DispatchKey(Enum): 80*da0073e9SAndroid Build Coastguard Worker Undefined = 0 81*da0073e9SAndroid Build Coastguard Worker CatchAll = Undefined 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard Worker FPGA = auto() 84*da0073e9SAndroid Build Coastguard Worker MAIA = auto() 85*da0073e9SAndroid Build Coastguard Worker Vulkan = auto() 86*da0073e9SAndroid Build Coastguard Worker Metal = auto() 87*da0073e9SAndroid Build Coastguard Worker MKLDNN = auto() 88*da0073e9SAndroid Build Coastguard Worker OpenGL = auto() 89*da0073e9SAndroid Build Coastguard Worker OpenCL = auto() 90*da0073e9SAndroid Build Coastguard Worker IDEEP = auto() 91*da0073e9SAndroid Build Coastguard Worker CustomRNGKeyId = auto() 92*da0073e9SAndroid Build Coastguard Worker MkldnnCPU = auto() 93*da0073e9SAndroid Build Coastguard Worker Sparse = auto() 94*da0073e9SAndroid Build Coastguard Worker SparseCsr = auto() 95*da0073e9SAndroid Build Coastguard Worker NestedTensor = auto() 96*da0073e9SAndroid Build Coastguard Worker Dense = auto() 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker PythonTLSSnapshot = auto() 99*da0073e9SAndroid Build Coastguard Worker PreDispatch = auto() 100*da0073e9SAndroid Build Coastguard Worker PythonDispatcher = auto() 101*da0073e9SAndroid Build Coastguard Worker Python = auto() 102*da0073e9SAndroid Build Coastguard Worker FuncTorchDynamicLayerBackMode = auto() 103*da0073e9SAndroid Build Coastguard Worker ZeroTensor = auto() 104*da0073e9SAndroid Build Coastguard Worker Conjugate = auto() 105*da0073e9SAndroid Build Coastguard Worker Negative = auto() 106*da0073e9SAndroid Build Coastguard Worker BackendSelect = auto() 107*da0073e9SAndroid Build Coastguard Worker Named = auto() 108*da0073e9SAndroid Build Coastguard Worker AutogradOther = auto() 109*da0073e9SAndroid Build Coastguard Worker AutogradFunctionality = auto() 110*da0073e9SAndroid Build Coastguard Worker AutogradNestedTensor = auto() 111*da0073e9SAndroid Build Coastguard Worker Tracer = auto() 112*da0073e9SAndroid Build Coastguard Worker Autocast = auto() 113*da0073e9SAndroid Build Coastguard Worker AutocastCPU = auto() 114*da0073e9SAndroid Build Coastguard Worker AutocastCUDA = auto() 115*da0073e9SAndroid Build Coastguard Worker Batched = auto() 116*da0073e9SAndroid Build Coastguard Worker VmapMode = auto() 117*da0073e9SAndroid Build Coastguard Worker FuncTorchGradWrapper = auto() 118*da0073e9SAndroid Build Coastguard Worker FuncTorchBatched = auto() 119*da0073e9SAndroid Build Coastguard Worker BatchedNestedTensor = auto() 120*da0073e9SAndroid Build Coastguard Worker FuncTorchVmapMode = auto() 121*da0073e9SAndroid Build Coastguard Worker FuncTorchDynamicLayerFrontMode = auto() 122*da0073e9SAndroid Build Coastguard Worker Functionalize = auto() 123*da0073e9SAndroid Build Coastguard Worker TESTING_ONLY_GenericWrapper = auto() 124*da0073e9SAndroid Build Coastguard Worker TESTING_ONLY_GenericMode = auto() 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker ADInplaceOrView = auto() 127*da0073e9SAndroid Build Coastguard Worker Autograd = auto() 128*da0073e9SAndroid Build Coastguard Worker CompositeImplicitAutograd = auto() 129*da0073e9SAndroid Build Coastguard Worker CompositeImplicitAutogradNestedTensor = auto() 130*da0073e9SAndroid Build Coastguard Worker CompositeExplicitAutograd = auto() 131*da0073e9SAndroid Build Coastguard Worker CompositeExplicitAutogradNonFunctional = auto() 132*da0073e9SAndroid Build Coastguard Worker FuncTorchBatchedDecomposition = auto() 133*da0073e9SAndroid Build Coastguard Worker 134*da0073e9SAndroid Build Coastguard Worker # BEGIN autogenerated 135*da0073e9SAndroid Build Coastguard Worker CPU = auto() 136*da0073e9SAndroid Build Coastguard Worker CUDA = auto() 137*da0073e9SAndroid Build Coastguard Worker HIP = auto() 138*da0073e9SAndroid Build Coastguard Worker XLA = auto() 139*da0073e9SAndroid Build Coastguard Worker MTIA = auto() 140*da0073e9SAndroid Build Coastguard Worker MPS = auto() 141*da0073e9SAndroid Build Coastguard Worker IPU = auto() 142*da0073e9SAndroid Build Coastguard Worker XPU = auto() 143*da0073e9SAndroid Build Coastguard Worker HPU = auto() 144*da0073e9SAndroid Build Coastguard Worker VE = auto() 145*da0073e9SAndroid Build Coastguard Worker Lazy = auto() 146*da0073e9SAndroid Build Coastguard Worker Meta = auto() 147*da0073e9SAndroid Build Coastguard Worker PrivateUse1 = auto() 148*da0073e9SAndroid Build Coastguard Worker PrivateUse2 = auto() 149*da0073e9SAndroid Build Coastguard Worker PrivateUse3 = auto() 150*da0073e9SAndroid Build Coastguard Worker QuantizedCPU = auto() 151*da0073e9SAndroid Build Coastguard Worker QuantizedCUDA = auto() 152*da0073e9SAndroid Build Coastguard Worker QuantizedHIP = auto() 153*da0073e9SAndroid Build Coastguard Worker QuantizedXLA = auto() 154*da0073e9SAndroid Build Coastguard Worker QuantizedMTIA = auto() 155*da0073e9SAndroid Build Coastguard Worker QuantizedMPS = auto() 156*da0073e9SAndroid Build Coastguard Worker QuantizedIPU = auto() 157*da0073e9SAndroid Build Coastguard Worker QuantizedXPU = auto() 158*da0073e9SAndroid Build Coastguard Worker QuantizedHPU = auto() 159*da0073e9SAndroid Build Coastguard Worker QuantizedVE = auto() 160*da0073e9SAndroid Build Coastguard Worker QuantizedLazy = auto() 161*da0073e9SAndroid Build Coastguard Worker QuantizedMeta = auto() 162*da0073e9SAndroid Build Coastguard Worker QuantizedPrivateUse1 = auto() 163*da0073e9SAndroid Build Coastguard Worker QuantizedPrivateUse2 = auto() 164*da0073e9SAndroid Build Coastguard Worker QuantizedPrivateUse3 = auto() 165*da0073e9SAndroid Build Coastguard Worker SparseCPU = auto() 166*da0073e9SAndroid Build Coastguard Worker SparseCUDA = auto() 167*da0073e9SAndroid Build Coastguard Worker SparseHIP = auto() 168*da0073e9SAndroid Build Coastguard Worker SparseXLA = auto() 169*da0073e9SAndroid Build Coastguard Worker SparseMTIA = auto() 170*da0073e9SAndroid Build Coastguard Worker SparseMPS = auto() 171*da0073e9SAndroid Build Coastguard Worker SparseIPU = auto() 172*da0073e9SAndroid Build Coastguard Worker SparseXPU = auto() 173*da0073e9SAndroid Build Coastguard Worker SparseHPU = auto() 174*da0073e9SAndroid Build Coastguard Worker SparseVE = auto() 175*da0073e9SAndroid Build Coastguard Worker SparseLazy = auto() 176*da0073e9SAndroid Build Coastguard Worker SparseMeta = auto() 177*da0073e9SAndroid Build Coastguard Worker SparsePrivateUse1 = auto() 178*da0073e9SAndroid Build Coastguard Worker SparsePrivateUse2 = auto() 179*da0073e9SAndroid Build Coastguard Worker SparsePrivateUse3 = auto() 180*da0073e9SAndroid Build Coastguard Worker SparseCsrCPU = auto() 181*da0073e9SAndroid Build Coastguard Worker SparseCsrCUDA = auto() 182*da0073e9SAndroid Build Coastguard Worker SparseCsrHIP = auto() 183*da0073e9SAndroid Build Coastguard Worker SparseCsrXLA = auto() 184*da0073e9SAndroid Build Coastguard Worker SparseCsrMTIA = auto() 185*da0073e9SAndroid Build Coastguard Worker SparseCsrMPS = auto() 186*da0073e9SAndroid Build Coastguard Worker SparseCsrIPU = auto() 187*da0073e9SAndroid Build Coastguard Worker SparseCsrXPU = auto() 188*da0073e9SAndroid Build Coastguard Worker SparseCsrHPU = auto() 189*da0073e9SAndroid Build Coastguard Worker SparseCsrVE = auto() 190*da0073e9SAndroid Build Coastguard Worker SparseCsrLazy = auto() 191*da0073e9SAndroid Build Coastguard Worker SparseCsrMeta = auto() 192*da0073e9SAndroid Build Coastguard Worker SparseCsrPrivateUse1 = auto() 193*da0073e9SAndroid Build Coastguard Worker SparseCsrPrivateUse2 = auto() 194*da0073e9SAndroid Build Coastguard Worker SparseCsrPrivateUse3 = auto() 195*da0073e9SAndroid Build Coastguard Worker NestedTensorCPU = auto() 196*da0073e9SAndroid Build Coastguard Worker NestedTensorCUDA = auto() 197*da0073e9SAndroid Build Coastguard Worker NestedTensorHIP = auto() 198*da0073e9SAndroid Build Coastguard Worker NestedTensorXLA = auto() 199*da0073e9SAndroid Build Coastguard Worker NestedTensorMTIA = auto() 200*da0073e9SAndroid Build Coastguard Worker NestedTensorMPS = auto() 201*da0073e9SAndroid Build Coastguard Worker NestedTensorIPU = auto() 202*da0073e9SAndroid Build Coastguard Worker NestedTensorXPU = auto() 203*da0073e9SAndroid Build Coastguard Worker NestedTensorHPU = auto() 204*da0073e9SAndroid Build Coastguard Worker NestedTensorVE = auto() 205*da0073e9SAndroid Build Coastguard Worker NestedTensorLazy = auto() 206*da0073e9SAndroid Build Coastguard Worker NestedTensorMeta = auto() 207*da0073e9SAndroid Build Coastguard Worker NestedTensorPrivateUse1 = auto() 208*da0073e9SAndroid Build Coastguard Worker NestedTensorPrivateUse2 = auto() 209*da0073e9SAndroid Build Coastguard Worker NestedTensorPrivateUse3 = auto() 210*da0073e9SAndroid Build Coastguard Worker AutogradCPU = auto() 211*da0073e9SAndroid Build Coastguard Worker AutogradCUDA = auto() 212*da0073e9SAndroid Build Coastguard Worker AutogradHIP = auto() 213*da0073e9SAndroid Build Coastguard Worker AutogradXLA = auto() 214*da0073e9SAndroid Build Coastguard Worker AutogradMTIA = auto() 215*da0073e9SAndroid Build Coastguard Worker AutogradMPS = auto() 216*da0073e9SAndroid Build Coastguard Worker AutogradIPU = auto() 217*da0073e9SAndroid Build Coastguard Worker AutogradXPU = auto() 218*da0073e9SAndroid Build Coastguard Worker AutogradHPU = auto() 219*da0073e9SAndroid Build Coastguard Worker AutogradVE = auto() 220*da0073e9SAndroid Build Coastguard Worker AutogradLazy = auto() 221*da0073e9SAndroid Build Coastguard Worker AutogradMeta = auto() 222*da0073e9SAndroid Build Coastguard Worker AutogradPrivateUse1 = auto() 223*da0073e9SAndroid Build Coastguard Worker AutogradPrivateUse2 = auto() 224*da0073e9SAndroid Build Coastguard Worker AutogradPrivateUse3 = auto() 225*da0073e9SAndroid Build Coastguard Worker # END autogenerated 226*da0073e9SAndroid Build Coastguard Worker 227*da0073e9SAndroid Build Coastguard Worker def __str__(self) -> str: 228*da0073e9SAndroid Build Coastguard Worker return self.name 229*da0073e9SAndroid Build Coastguard Worker 230*da0073e9SAndroid Build Coastguard Worker def lower(self) -> str: 231*da0073e9SAndroid Build Coastguard Worker return str(self).lower() 232*da0073e9SAndroid Build Coastguard Worker 233*da0073e9SAndroid Build Coastguard Worker @staticmethod 234*da0073e9SAndroid Build Coastguard Worker def parse(value: str) -> DispatchKey: 235*da0073e9SAndroid Build Coastguard Worker for k, v in DispatchKey.__members__.items(): 236*da0073e9SAndroid Build Coastguard Worker if k == value: 237*da0073e9SAndroid Build Coastguard Worker return v 238*da0073e9SAndroid Build Coastguard Worker raise AssertionError(f"unknown dispatch key {value}") 239*da0073e9SAndroid Build Coastguard Worker 240*da0073e9SAndroid Build Coastguard Worker 241*da0073e9SAndroid Build Coastguard Workerclass _TorchDispatchModeKey(Enum): 242*da0073e9SAndroid Build Coastguard Worker FAKE = auto() 243*da0073e9SAndroid Build Coastguard Worker PROXY = auto() 244*da0073e9SAndroid Build Coastguard Worker FUNCTIONAL = auto() 245*da0073e9SAndroid Build Coastguard Worker 246*da0073e9SAndroid Build Coastguard Worker 247*da0073e9SAndroid Build Coastguard Workerdef codegen_per_backend_entries() -> str: 248*da0073e9SAndroid Build Coastguard Worker r = [] 249*da0073e9SAndroid Build Coastguard Worker for fk in FUNCTIONALITY_KEYS: 250*da0073e9SAndroid Build Coastguard Worker for bc in BACKEND_COMPONENTS: 251*da0073e9SAndroid Build Coastguard Worker r.append(f" {fk}{bc} = auto()") 252*da0073e9SAndroid Build Coastguard Worker return "\n".join(r) 253*da0073e9SAndroid Build Coastguard Worker 254*da0073e9SAndroid Build Coastguard Worker 255*da0073e9SAndroid Build Coastguard Workerfor fk in FUNCTIONALITY_KEYS: 256*da0073e9SAndroid Build Coastguard Worker for bc in BACKEND_COMPONENTS: 257*da0073e9SAndroid Build Coastguard Worker if not hasattr(DispatchKey, fk + bc): 258*da0073e9SAndroid Build Coastguard Worker r = codegen_per_backend_entries() 259*da0073e9SAndroid Build Coastguard Worker print(r) 260*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 261*da0073e9SAndroid Build Coastguard Worker f"Missing {fk}{bc} from DispatchKey enum. Here is the autogenerated list we expect to have:\n\n{r}" 262*da0073e9SAndroid Build Coastguard Worker ) 263*da0073e9SAndroid Build Coastguard Worker 264*da0073e9SAndroid Build Coastguard Worker 265*da0073e9SAndroid Build Coastguard WorkerSTRUCTURED_DISPATCH_KEYS = { 266*da0073e9SAndroid Build Coastguard Worker DispatchKey.MPS, 267*da0073e9SAndroid Build Coastguard Worker DispatchKey.CUDA, 268*da0073e9SAndroid Build Coastguard Worker DispatchKey.CPU, 269*da0073e9SAndroid Build Coastguard Worker DispatchKey.XPU, 270*da0073e9SAndroid Build Coastguard Worker} 271*da0073e9SAndroid Build Coastguard WorkerUFUNC_DISPATCH_KEYS = {DispatchKey.CUDA, DispatchKey.CPU} 272*da0073e9SAndroid Build Coastguard Worker 273*da0073e9SAndroid Build Coastguard Worker# Set of supported dispatch keys 274*da0073e9SAndroid Build Coastguard Workerdispatch_keys = [ 275*da0073e9SAndroid Build Coastguard Worker DispatchKey.CPU, 276*da0073e9SAndroid Build Coastguard Worker DispatchKey.SparseCPU, 277*da0073e9SAndroid Build Coastguard Worker DispatchKey.SparseCsrCPU, 278*da0073e9SAndroid Build Coastguard Worker DispatchKey.MkldnnCPU, 279*da0073e9SAndroid Build Coastguard Worker DispatchKey.CUDA, 280*da0073e9SAndroid Build Coastguard Worker DispatchKey.MPS, 281*da0073e9SAndroid Build Coastguard Worker DispatchKey.XPU, 282*da0073e9SAndroid Build Coastguard Worker DispatchKey.SparseCUDA, 283*da0073e9SAndroid Build Coastguard Worker DispatchKey.SparseCsrCUDA, 284*da0073e9SAndroid Build Coastguard Worker DispatchKey.QuantizedCPU, 285*da0073e9SAndroid Build Coastguard Worker DispatchKey.QuantizedCUDA, 286*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeImplicitAutograd, 287*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeImplicitAutogradNestedTensor, 288*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeExplicitAutograd, 289*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeExplicitAutogradNonFunctional, 290*da0073e9SAndroid Build Coastguard Worker DispatchKey.NestedTensorCPU, 291*da0073e9SAndroid Build Coastguard Worker DispatchKey.NestedTensorCUDA, 292*da0073e9SAndroid Build Coastguard Worker # Meta is a magic key: it is automatically generated for structured 293*da0073e9SAndroid Build Coastguard Worker # kernels 294*da0073e9SAndroid Build Coastguard Worker DispatchKey.Meta, 295*da0073e9SAndroid Build Coastguard Worker DispatchKey.SparseMeta, 296*da0073e9SAndroid Build Coastguard Worker DispatchKey.SparseCsrMeta, 297*da0073e9SAndroid Build Coastguard Worker DispatchKey.QuantizedMeta, 298*da0073e9SAndroid Build Coastguard Worker DispatchKey.NestedTensorMeta, 299*da0073e9SAndroid Build Coastguard Worker DispatchKey.ZeroTensor, 300*da0073e9SAndroid Build Coastguard Worker] 301*da0073e9SAndroid Build Coastguard Worker 302*da0073e9SAndroid Build Coastguard Worker 303*da0073e9SAndroid Build Coastguard Worker# Dispatch keys that "support all backends". These codegen slightly differently 304*da0073e9SAndroid Build Coastguard Worker# then backend specific keys. 305*da0073e9SAndroid Build Coastguard Workerdef is_generic_dispatch_key(dk: DispatchKey) -> bool: 306*da0073e9SAndroid Build Coastguard Worker return dk in { 307*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeExplicitAutograd, 308*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeExplicitAutogradNonFunctional, 309*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeImplicitAutograd, 310*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeImplicitAutogradNestedTensor, 311*da0073e9SAndroid Build Coastguard Worker } 312*da0073e9SAndroid Build Coastguard Worker 313*da0073e9SAndroid Build Coastguard Worker 314*da0073e9SAndroid Build Coastguard Worker# CUDA specific dispatch keys 315*da0073e9SAndroid Build Coastguard Workerdef is_cuda_dispatch_key(dk: DispatchKey) -> bool: 316*da0073e9SAndroid Build Coastguard Worker return dk in { 317*da0073e9SAndroid Build Coastguard Worker DispatchKey.CUDA, 318*da0073e9SAndroid Build Coastguard Worker DispatchKey.QuantizedCUDA, 319*da0073e9SAndroid Build Coastguard Worker DispatchKey.SparseCUDA, 320*da0073e9SAndroid Build Coastguard Worker DispatchKey.SparseCsrCUDA, 321*da0073e9SAndroid Build Coastguard Worker DispatchKey.NestedTensorCUDA, 322*da0073e9SAndroid Build Coastguard Worker DispatchKey.AutogradCUDA, 323*da0073e9SAndroid Build Coastguard Worker } 324*da0073e9SAndroid Build Coastguard Worker 325*da0073e9SAndroid Build Coastguard Worker 326*da0073e9SAndroid Build Coastguard Worker# XPU specific dispatcy keys 327*da0073e9SAndroid Build Coastguard Workerdef is_xpu_dispatch_key(dk: DispatchKey) -> bool: 328*da0073e9SAndroid Build Coastguard Worker return dk in { 329*da0073e9SAndroid Build Coastguard Worker DispatchKey.XPU, 330*da0073e9SAndroid Build Coastguard Worker DispatchKey.QuantizedXPU, 331*da0073e9SAndroid Build Coastguard Worker DispatchKey.SparseXPU, 332*da0073e9SAndroid Build Coastguard Worker DispatchKey.SparseCsrXPU, 333*da0073e9SAndroid Build Coastguard Worker DispatchKey.NestedTensorXPU, 334*da0073e9SAndroid Build Coastguard Worker DispatchKey.AutogradXPU, 335*da0073e9SAndroid Build Coastguard Worker } 336*da0073e9SAndroid Build Coastguard Worker 337*da0073e9SAndroid Build Coastguard Worker 338*da0073e9SAndroid Build Coastguard Worker# Structured kernel generation is only supported for certain key types; 339*da0073e9SAndroid Build Coastguard Worker# otherwise use old-style 340*da0073e9SAndroid Build Coastguard Workerdef is_structured_dispatch_key(dk: DispatchKey) -> bool: 341*da0073e9SAndroid Build Coastguard Worker return dk in STRUCTURED_DISPATCH_KEYS 342*da0073e9SAndroid Build Coastguard Worker 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Workerdef is_ufunc_dispatch_key(dk: DispatchKey) -> bool: 345*da0073e9SAndroid Build Coastguard Worker # For now, ufunc dispatch keys coincide with structured keys 346*da0073e9SAndroid Build Coastguard Worker return dk in UFUNC_DISPATCH_KEYS 347*da0073e9SAndroid Build Coastguard Worker 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker# This is oddly named ScalarType and not DType for symmetry with C++ 350*da0073e9SAndroid Build Coastguard Workerclass ScalarType(Enum): 351*da0073e9SAndroid Build Coastguard Worker Byte = auto() 352*da0073e9SAndroid Build Coastguard Worker Char = auto() 353*da0073e9SAndroid Build Coastguard Worker Short = auto() 354*da0073e9SAndroid Build Coastguard Worker Int = auto() 355*da0073e9SAndroid Build Coastguard Worker Long = auto() 356*da0073e9SAndroid Build Coastguard Worker Half = auto() 357*da0073e9SAndroid Build Coastguard Worker Float = auto() 358*da0073e9SAndroid Build Coastguard Worker Double = auto() 359*da0073e9SAndroid Build Coastguard Worker ComplexHalf = auto() 360*da0073e9SAndroid Build Coastguard Worker ComplexFloat = auto() 361*da0073e9SAndroid Build Coastguard Worker ComplexDouble = auto() 362*da0073e9SAndroid Build Coastguard Worker Bool = auto() 363*da0073e9SAndroid Build Coastguard Worker BFloat16 = auto() 364*da0073e9SAndroid Build Coastguard Worker Float8_e5m2 = auto() 365*da0073e9SAndroid Build Coastguard Worker Float8_e5m2fnuz = auto() 366*da0073e9SAndroid Build Coastguard Worker Float8_e4m3fn = auto() 367*da0073e9SAndroid Build Coastguard Worker Float8_e4m3fnuz = auto() 368*da0073e9SAndroid Build Coastguard Worker 369*da0073e9SAndroid Build Coastguard Worker def __str__(self) -> str: 370*da0073e9SAndroid Build Coastguard Worker return self.name 371*da0073e9SAndroid Build Coastguard Worker 372*da0073e9SAndroid Build Coastguard Worker @staticmethod 373*da0073e9SAndroid Build Coastguard Worker def maybe_parse(value: str) -> ScalarType | None: 374*da0073e9SAndroid Build Coastguard Worker for k, v in ScalarType.__members__.items(): 375*da0073e9SAndroid Build Coastguard Worker if k == value: 376*da0073e9SAndroid Build Coastguard Worker return v 377*da0073e9SAndroid Build Coastguard Worker return None 378*da0073e9SAndroid Build Coastguard Worker 379*da0073e9SAndroid Build Coastguard Worker @staticmethod 380*da0073e9SAndroid Build Coastguard Worker def parse(value: str) -> ScalarType: 381*da0073e9SAndroid Build Coastguard Worker mb_r = ScalarType.maybe_parse(value) 382*da0073e9SAndroid Build Coastguard Worker assert mb_r is not None, f"unknown dtype {value}" 383*da0073e9SAndroid Build Coastguard Worker return mb_r 384*da0073e9SAndroid Build Coastguard Worker 385*da0073e9SAndroid Build Coastguard Worker @staticmethod 386*da0073e9SAndroid Build Coastguard Worker def parse_set(values: str) -> OrderedSet[ScalarType]: 387*da0073e9SAndroid Build Coastguard Worker dtypes: OrderedSet[ScalarType] = OrderedSet() 388*da0073e9SAndroid Build Coastguard Worker for value in values.split(", "): 389*da0073e9SAndroid Build Coastguard Worker if value in DTYPE_CLASSES: 390*da0073e9SAndroid Build Coastguard Worker dtypes.update(DTYPE_CLASSES[value]) 391*da0073e9SAndroid Build Coastguard Worker else: 392*da0073e9SAndroid Build Coastguard Worker dtypes.add(ScalarType.parse(value)) 393*da0073e9SAndroid Build Coastguard Worker return dtypes 394*da0073e9SAndroid Build Coastguard Worker 395*da0073e9SAndroid Build Coastguard Worker 396*da0073e9SAndroid Build Coastguard WorkerDTYPE_CLASSES: dict[str, OrderedSet[ScalarType]] = {} 397*da0073e9SAndroid Build Coastguard Worker# NB: Integral doesn't include boolean 398*da0073e9SAndroid Build Coastguard WorkerDTYPE_CLASSES["Integral"] = OrderedSet( 399*da0073e9SAndroid Build Coastguard Worker [ 400*da0073e9SAndroid Build Coastguard Worker ScalarType.Byte, 401*da0073e9SAndroid Build Coastguard Worker ScalarType.Char, 402*da0073e9SAndroid Build Coastguard Worker ScalarType.Int, 403*da0073e9SAndroid Build Coastguard Worker ScalarType.Long, 404*da0073e9SAndroid Build Coastguard Worker ScalarType.Short, 405*da0073e9SAndroid Build Coastguard Worker ] 406*da0073e9SAndroid Build Coastguard Worker) 407*da0073e9SAndroid Build Coastguard Worker# NB: Floating doesn't include low precision types 408*da0073e9SAndroid Build Coastguard WorkerDTYPE_CLASSES["Floating"] = OrderedSet([ScalarType.Float, ScalarType.Double]) 409*da0073e9SAndroid Build Coastguard WorkerDTYPE_CLASSES["Complex"] = OrderedSet( 410*da0073e9SAndroid Build Coastguard Worker [ScalarType.ComplexFloat, ScalarType.ComplexDouble] 411*da0073e9SAndroid Build Coastguard Worker) 412*da0073e9SAndroid Build Coastguard WorkerDTYPE_CLASSES["All"] = DTYPE_CLASSES["Integral"] | DTYPE_CLASSES["Floating"] 413*da0073e9SAndroid Build Coastguard WorkerDTYPE_CLASSES["AllAndComplex"] = DTYPE_CLASSES["All"] | DTYPE_CLASSES["Complex"] 414*da0073e9SAndroid Build Coastguard WorkerDTYPE_CLASSES["FloatingAndComplex"] = ( 415*da0073e9SAndroid Build Coastguard Worker DTYPE_CLASSES["Floating"] | DTYPE_CLASSES["Complex"] 416*da0073e9SAndroid Build Coastguard Worker) 417*da0073e9SAndroid Build Coastguard Worker 418*da0073e9SAndroid Build Coastguard Worker 419*da0073e9SAndroid Build Coastguard Worker# Represents the valid entries for ufunc_inner_loop in native_functions.yaml. 420*da0073e9SAndroid Build Coastguard Worker# NB: if you add a new UfuncKey, you will teach torchgen.dest.ufunc how 421*da0073e9SAndroid Build Coastguard Worker# to process it. Most logic will ignore keys they don't understand, so your 422*da0073e9SAndroid Build Coastguard Worker# new key will get silently ignored until you hook in logic to deal with it. 423*da0073e9SAndroid Build Coastguard Workerclass UfuncKey(Enum): 424*da0073e9SAndroid Build Coastguard Worker # These are low level keys that represent exactly one particular 425*da0073e9SAndroid Build Coastguard Worker # instantiation of the kernel produced by codegen 426*da0073e9SAndroid Build Coastguard Worker CUDAFunctor = auto() 427*da0073e9SAndroid Build Coastguard Worker CUDAFunctorOnOther = auto() 428*da0073e9SAndroid Build Coastguard Worker CUDAFunctorOnSelf = auto() 429*da0073e9SAndroid Build Coastguard Worker 430*da0073e9SAndroid Build Coastguard Worker CPUScalar = auto() 431*da0073e9SAndroid Build Coastguard Worker CPUVector = auto() 432*da0073e9SAndroid Build Coastguard Worker 433*da0073e9SAndroid Build Coastguard Worker # These are the ones users will usually specify, and 434*da0073e9SAndroid Build Coastguard Worker # implicitly "fill in" the low level keys 435*da0073e9SAndroid Build Coastguard Worker ScalarOnly = auto() # CUDA*, CPUScalar 436*da0073e9SAndroid Build Coastguard Worker Generic = auto() # CUDA*, CPU* 437*da0073e9SAndroid Build Coastguard Worker 438*da0073e9SAndroid Build Coastguard Worker def __str__(self) -> str: 439*da0073e9SAndroid Build Coastguard Worker return self.name 440*da0073e9SAndroid Build Coastguard Worker 441*da0073e9SAndroid Build Coastguard Worker @staticmethod 442*da0073e9SAndroid Build Coastguard Worker def parse(value: str) -> UfuncKey: 443*da0073e9SAndroid Build Coastguard Worker for k, v in UfuncKey.__members__.items(): 444*da0073e9SAndroid Build Coastguard Worker if k == value: 445*da0073e9SAndroid Build Coastguard Worker return v 446*da0073e9SAndroid Build Coastguard Worker raise AssertionError(f"unknown ufunc key {value}") 447*da0073e9SAndroid Build Coastguard Worker 448*da0073e9SAndroid Build Coastguard Worker 449*da0073e9SAndroid Build Coastguard Workerclass DeviceCheckType(Enum): 450*da0073e9SAndroid Build Coastguard Worker NoCheck = 0 451*da0073e9SAndroid Build Coastguard Worker ExactSame = 1 452*da0073e9SAndroid Build Coastguard Worker 453*da0073e9SAndroid Build Coastguard Worker 454*da0073e9SAndroid Build Coastguard Workerclass ViewSchemaKind(Enum): 455*da0073e9SAndroid Build Coastguard Worker aliasing = auto() 456*da0073e9SAndroid Build Coastguard Worker aliasing_inplace = auto() 457*da0073e9SAndroid Build Coastguard Worker non_aliasing = auto() 458*da0073e9SAndroid Build Coastguard Worker 459*da0073e9SAndroid Build Coastguard Worker 460*da0073e9SAndroid Build Coastguard Worker# The basic input to the code generation is native_functions.yaml. 461*da0073e9SAndroid Build Coastguard Worker# The name "native", BTW, comes from the distinction between native 462*da0073e9SAndroid Build Coastguard Worker# functions and legacy TH functions. The legacy TH functions are gone, 463*da0073e9SAndroid Build Coastguard Worker# but the "native" descriptor has stuck. 464*da0073e9SAndroid Build Coastguard Worker# 465*da0073e9SAndroid Build Coastguard Worker# NativeFunction models a single entry in native_functions.yaml. Its 466*da0073e9SAndroid Build Coastguard Worker# fields roughly correspond to what you would see in the YAML itself, 467*da0073e9SAndroid Build Coastguard Worker# but after canonicalization and parsing has occurred. 468*da0073e9SAndroid Build Coastguard Worker# 469*da0073e9SAndroid Build Coastguard Worker# You can see some of the overall design patterns for how we setup 470*da0073e9SAndroid Build Coastguard Worker# dataclasses in this class, but we will defer a complete discussion 471*da0073e9SAndroid Build Coastguard Worker# of this at FunctionSchema. 472*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 473*da0073e9SAndroid Build Coastguard Workerclass NativeFunction: 474*da0073e9SAndroid Build Coastguard Worker # The namespace for this operator. For example, if we have "at::add" 475*da0073e9SAndroid Build Coastguard Worker # then the namespace would be "at". This enables ops to be registered 476*da0073e9SAndroid Build Coastguard Worker # through the same DSL with a custom namespace. If not specified, the 477*da0073e9SAndroid Build Coastguard Worker # default namespace would be "at". 478*da0073e9SAndroid Build Coastguard Worker namespace: str 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker # The function schema of the operator in question. This schema 481*da0073e9SAndroid Build Coastguard Worker # has been parsed; see FunctionSchema for more about its structure. 482*da0073e9SAndroid Build Coastguard Worker # (This type is quoted as we are forward referencing a type 483*da0073e9SAndroid Build Coastguard Worker # defined later in the file. I opted for this ordering of the 484*da0073e9SAndroid Build Coastguard Worker # classes for expository clarity.) 485*da0073e9SAndroid Build Coastguard Worker func: FunctionSchema 486*da0073e9SAndroid Build Coastguard Worker 487*da0073e9SAndroid Build Coastguard Worker # Whether or not to generate mutable tensor arguments like regular 488*da0073e9SAndroid Build Coastguard Worker # ones 489*da0073e9SAndroid Build Coastguard Worker use_const_ref_for_mutable_tensors: bool 490*da0073e9SAndroid Build Coastguard Worker 491*da0073e9SAndroid Build Coastguard Worker # Whether or not to omit automatic generation of a DeviceGuard 492*da0073e9SAndroid Build Coastguard Worker device_guard: bool 493*da0073e9SAndroid Build Coastguard Worker 494*da0073e9SAndroid Build Coastguard Worker # How to emit automatic generation of device check 495*da0073e9SAndroid Build Coastguard Worker device_check: DeviceCheckType 496*da0073e9SAndroid Build Coastguard Worker 497*da0073e9SAndroid Build Coastguard Worker # What python module to put the function in 498*da0073e9SAndroid Build Coastguard Worker python_module: str | None 499*da0073e9SAndroid Build Coastguard Worker 500*da0073e9SAndroid Build Coastguard Worker # TODO: figure out what this does 501*da0073e9SAndroid Build Coastguard Worker category_override: str | None 502*da0073e9SAndroid Build Coastguard Worker 503*da0073e9SAndroid Build Coastguard Worker # If no variants are specified in native_functions.yaml, this is 504*da0073e9SAndroid Build Coastguard Worker # assumed to be {'function'}. 505*da0073e9SAndroid Build Coastguard Worker variants: set[Variant] 506*da0073e9SAndroid Build Coastguard Worker 507*da0073e9SAndroid Build Coastguard Worker # Whether or not we should skip generating registrations for 508*da0073e9SAndroid Build Coastguard Worker # this kernel. This is a bit of a double-edged sword, as manual 509*da0073e9SAndroid Build Coastguard Worker # registrations don't participate in codegen-based selective build! 510*da0073e9SAndroid Build Coastguard Worker manual_kernel_registration: bool 511*da0073e9SAndroid Build Coastguard Worker 512*da0073e9SAndroid Build Coastguard Worker # Whether or not to skip generating TensorMethod/Functions bindings 513*da0073e9SAndroid Build Coastguard Worker # for this kernel. Technically, this doesn't actually skip generating 514*da0073e9SAndroid Build Coastguard Worker # the binding; instead, the binding gets generated to __dispatch_{funcname} 515*da0073e9SAndroid Build Coastguard Worker # so you can make use of the normal binding if you need it. 516*da0073e9SAndroid Build Coastguard Worker manual_cpp_binding: bool 517*da0073e9SAndroid Build Coastguard Worker 518*da0073e9SAndroid Build Coastguard Worker # The location in the YAML file were this native function entry was 519*da0073e9SAndroid Build Coastguard Worker # defined. This is for conveniently reporting error messages! 520*da0073e9SAndroid Build Coastguard Worker loc: Location 521*da0073e9SAndroid Build Coastguard Worker 522*da0073e9SAndroid Build Coastguard Worker # A list of operators that are expected to be auto-generated for this NativeFunction. 523*da0073e9SAndroid Build Coastguard Worker # Note: This list isn't actually directly used by the codegen to generate anything. 524*da0073e9SAndroid Build Coastguard Worker # Instead, the codegen figures out what operators to generate purely based off of 525*da0073e9SAndroid Build Coastguard Worker # function schema, and uses the autogen declarations to error check. 526*da0073e9SAndroid Build Coastguard Worker # We expect every NativeFunction that gets auto-generated be explicitly called out 527*da0073e9SAndroid Build Coastguard Worker # in native_functions.yaml 528*da0073e9SAndroid Build Coastguard Worker autogen: list[OperatorName] 529*da0073e9SAndroid Build Coastguard Worker 530*da0073e9SAndroid Build Coastguard Worker # If non-empty, this kernel is subject to ufunc codegen. 531*da0073e9SAndroid Build Coastguard Worker # Sorted by ufunc_key 532*da0073e9SAndroid Build Coastguard Worker ufunc_inner_loop: dict[UfuncKey, UfuncInnerLoop] 533*da0073e9SAndroid Build Coastguard Worker 534*da0073e9SAndroid Build Coastguard Worker # Whether or not this out functions is a "structured kernel". Structured 535*da0073e9SAndroid Build Coastguard Worker # kernels are defined a little differently from normal kernels; in 536*da0073e9SAndroid Build Coastguard Worker # particular, their shape checking logic is defined separately from 537*da0073e9SAndroid Build Coastguard Worker # the kernel. Only out functions can be structured; other functions 538*da0073e9SAndroid Build Coastguard Worker # delegate to the out function using the structured_delegate keyword. 539*da0073e9SAndroid Build Coastguard Worker # Every structured kernel must have at least an out and a functional 540*da0073e9SAndroid Build Coastguard Worker # variant. 541*da0073e9SAndroid Build Coastguard Worker structured: bool 542*da0073e9SAndroid Build Coastguard Worker 543*da0073e9SAndroid Build Coastguard Worker # Whether or not this non-out function is a structured kernel, defined 544*da0073e9SAndroid Build Coastguard Worker # in terms of the out kernel referenced by the string here. 545*da0073e9SAndroid Build Coastguard Worker structured_delegate: OperatorName | None 546*da0073e9SAndroid Build Coastguard Worker 547*da0073e9SAndroid Build Coastguard Worker # Only valid for structured kernels. Specifies alternative of what 548*da0073e9SAndroid Build Coastguard Worker # to inherit from when defining the meta class for the structured 549*da0073e9SAndroid Build Coastguard Worker # operator. This will usually be TensorIteratorBase. This also 550*da0073e9SAndroid Build Coastguard Worker # changes the semantics of set_output to call the parent class. 551*da0073e9SAndroid Build Coastguard Worker structured_inherits: str | None 552*da0073e9SAndroid Build Coastguard Worker 553*da0073e9SAndroid Build Coastguard Worker # Structured kernels can declare elements as "precomputed". These elements 554*da0073e9SAndroid Build Coastguard Worker # are returned by the meta function in one struct and passed to the impl 555*da0073e9SAndroid Build Coastguard Worker # function in lieu of certain kernel arguments that these precomputed 556*da0073e9SAndroid Build Coastguard Worker # elements supersede. Information about the names and types of these 557*da0073e9SAndroid Build Coastguard Worker # precomputed elements and how they correspond to kernel arguments is stored 558*da0073e9SAndroid Build Coastguard Worker # in this member, if applicable. 559*da0073e9SAndroid Build Coastguard Worker precomputed: Precompute | None 560*da0073e9SAndroid Build Coastguard Worker 561*da0073e9SAndroid Build Coastguard Worker # Argument names whose default should be excluded from the C++ interface. 562*da0073e9SAndroid Build Coastguard Worker # Intended for resolving overload ambiguities between signatures. 563*da0073e9SAndroid Build Coastguard Worker cpp_no_default_args: set[str] 564*da0073e9SAndroid Build Coastguard Worker 565*da0073e9SAndroid Build Coastguard Worker # Note [Abstract ATen methods] 566*da0073e9SAndroid Build Coastguard Worker # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 567*da0073e9SAndroid Build Coastguard Worker # An abstract ATen method is one whose dispatch differs between 568*da0073e9SAndroid Build Coastguard Worker # types. These are implemented in derived types (with a 569*da0073e9SAndroid Build Coastguard Worker # standard (throwing) definition in Type). A concrete ATen 570*da0073e9SAndroid Build Coastguard Worker # method is one which has the same dispatch for all types; 571*da0073e9SAndroid Build Coastguard Worker # we just implement it in the base Type. This is exposed 572*da0073e9SAndroid Build Coastguard Worker # in Declarations.yaml via a field named 'abstract'. 573*da0073e9SAndroid Build Coastguard Worker is_abstract: bool 574*da0073e9SAndroid Build Coastguard Worker 575*da0073e9SAndroid Build Coastguard Worker # Whether or not the NativeFunction contains a backend-agnostic kernel 576*da0073e9SAndroid Build Coastguard Worker has_composite_implicit_autograd_kernel: bool 577*da0073e9SAndroid Build Coastguard Worker has_composite_implicit_autograd_nested_tensor_kernel: bool 578*da0073e9SAndroid Build Coastguard Worker has_composite_explicit_autograd_kernel: bool 579*da0073e9SAndroid Build Coastguard Worker has_composite_explicit_autograd_non_functional_kernel: bool 580*da0073e9SAndroid Build Coastguard Worker 581*da0073e9SAndroid Build Coastguard Worker # Tags are used to describe semantic information about (groups of) operators, 582*da0073e9SAndroid Build Coastguard Worker # That aren't easily inferrable directly from the operator's schema. 583*da0073e9SAndroid Build Coastguard Worker tags: set[str] 584*da0073e9SAndroid Build Coastguard Worker 585*da0073e9SAndroid Build Coastguard Worker # NB: The benefit of defining a dataclass is that we automatically get 586*da0073e9SAndroid Build Coastguard Worker # a constructor defined for all the fields we specify. No need 587*da0073e9SAndroid Build Coastguard Worker # to explicitly write it out. 588*da0073e9SAndroid Build Coastguard Worker 589*da0073e9SAndroid Build Coastguard Worker # We parse both the NativeFunction + backend-specific information about it, which it stored in a corresponding BackendIndex. 590*da0073e9SAndroid Build Coastguard Worker @staticmethod 591*da0073e9SAndroid Build Coastguard Worker def from_yaml( 592*da0073e9SAndroid Build Coastguard Worker ei: dict[str, object], 593*da0073e9SAndroid Build Coastguard Worker loc: Location, 594*da0073e9SAndroid Build Coastguard Worker valid_tags: set[str], 595*da0073e9SAndroid Build Coastguard Worker ignore_keys: set[DispatchKey] | None = None, 596*da0073e9SAndroid Build Coastguard Worker ) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]: 597*da0073e9SAndroid Build Coastguard Worker """ 598*da0073e9SAndroid Build Coastguard Worker Parse a NativeFunction from a dictionary as directly parsed 599*da0073e9SAndroid Build Coastguard Worker from native_functions.yaml 600*da0073e9SAndroid Build Coastguard Worker """ 601*da0073e9SAndroid Build Coastguard Worker e = ei.copy() 602*da0073e9SAndroid Build Coastguard Worker 603*da0073e9SAndroid Build Coastguard Worker funcs = e.pop("func") 604*da0073e9SAndroid Build Coastguard Worker assert isinstance(funcs, str), f"not a str: {funcs}" 605*da0073e9SAndroid Build Coastguard Worker # only support one level of namespace. E.g., aten::add 606*da0073e9SAndroid Build Coastguard Worker namespace_helper = NamespaceHelper.from_namespaced_entity( 607*da0073e9SAndroid Build Coastguard Worker namespaced_entity=funcs, max_level=1 608*da0073e9SAndroid Build Coastguard Worker ) 609*da0073e9SAndroid Build Coastguard Worker namespace = namespace_helper.get_cpp_namespace(default="aten") 610*da0073e9SAndroid Build Coastguard Worker func = FunctionSchema.parse(namespace_helper.entity_name) 611*da0073e9SAndroid Build Coastguard Worker 612*da0073e9SAndroid Build Coastguard Worker cpp_no_default_args_list = e.pop("cpp_no_default_args", []) 613*da0073e9SAndroid Build Coastguard Worker assert isinstance(cpp_no_default_args_list, list) 614*da0073e9SAndroid Build Coastguard Worker cpp_no_default_args = set(cpp_no_default_args_list) 615*da0073e9SAndroid Build Coastguard Worker 616*da0073e9SAndroid Build Coastguard Worker use_const_ref_for_mutable_tensors = e.pop( 617*da0073e9SAndroid Build Coastguard Worker "use_const_ref_for_mutable_tensors", False 618*da0073e9SAndroid Build Coastguard Worker ) 619*da0073e9SAndroid Build Coastguard Worker assert isinstance(use_const_ref_for_mutable_tensors, bool) 620*da0073e9SAndroid Build Coastguard Worker 621*da0073e9SAndroid Build Coastguard Worker variants_s = e.pop("variants", "function") 622*da0073e9SAndroid Build Coastguard Worker assert isinstance(variants_s, str) 623*da0073e9SAndroid Build Coastguard Worker variants: set[Variant] = set() 624*da0073e9SAndroid Build Coastguard Worker for v in variants_s.split(", "): 625*da0073e9SAndroid Build Coastguard Worker if v == "function": 626*da0073e9SAndroid Build Coastguard Worker variants.add(Variant.function) 627*da0073e9SAndroid Build Coastguard Worker elif v == "method": 628*da0073e9SAndroid Build Coastguard Worker variants.add(Variant.method) 629*da0073e9SAndroid Build Coastguard Worker else: 630*da0073e9SAndroid Build Coastguard Worker raise AssertionError(f"illegal variant {v}") 631*da0073e9SAndroid Build Coastguard Worker 632*da0073e9SAndroid Build Coastguard Worker manual_kernel_registration = e.pop("manual_kernel_registration", False) 633*da0073e9SAndroid Build Coastguard Worker assert isinstance( 634*da0073e9SAndroid Build Coastguard Worker manual_kernel_registration, bool 635*da0073e9SAndroid Build Coastguard Worker ), f"not a bool: {manual_kernel_registration}" 636*da0073e9SAndroid Build Coastguard Worker 637*da0073e9SAndroid Build Coastguard Worker manual_cpp_binding = e.pop("manual_cpp_binding", False) 638*da0073e9SAndroid Build Coastguard Worker assert isinstance(manual_cpp_binding, bool), f"not a bool: {manual_cpp_binding}" 639*da0073e9SAndroid Build Coastguard Worker 640*da0073e9SAndroid Build Coastguard Worker device_guard = e.pop("device_guard", True) 641*da0073e9SAndroid Build Coastguard Worker assert isinstance(device_guard, bool), f"not a bool: {device_guard}" 642*da0073e9SAndroid Build Coastguard Worker 643*da0073e9SAndroid Build Coastguard Worker device_check_s = e.pop("device_check", None) 644*da0073e9SAndroid Build Coastguard Worker assert device_check_s is None or isinstance( 645*da0073e9SAndroid Build Coastguard Worker device_check_s, str 646*da0073e9SAndroid Build Coastguard Worker ), f"not a str: {device_check_s}" 647*da0073e9SAndroid Build Coastguard Worker assert ( 648*da0073e9SAndroid Build Coastguard Worker device_check_s is None or device_check_s in DeviceCheckType.__members__ 649*da0073e9SAndroid Build Coastguard Worker ), f"illegal device_check: {device_check_s}" 650*da0073e9SAndroid Build Coastguard Worker device_check: DeviceCheckType 651*da0073e9SAndroid Build Coastguard Worker if device_check_s is None: 652*da0073e9SAndroid Build Coastguard Worker device_check = DeviceCheckType.ExactSame 653*da0073e9SAndroid Build Coastguard Worker else: 654*da0073e9SAndroid Build Coastguard Worker device_check = DeviceCheckType[device_check_s] 655*da0073e9SAndroid Build Coastguard Worker 656*da0073e9SAndroid Build Coastguard Worker structured = e.pop("structured", False) 657*da0073e9SAndroid Build Coastguard Worker assert isinstance(structured, bool), f"not a bool: {structured}" 658*da0073e9SAndroid Build Coastguard Worker 659*da0073e9SAndroid Build Coastguard Worker structured_delegate_s = e.pop("structured_delegate", None) 660*da0073e9SAndroid Build Coastguard Worker assert structured_delegate_s is None or isinstance( 661*da0073e9SAndroid Build Coastguard Worker structured_delegate_s, str 662*da0073e9SAndroid Build Coastguard Worker ), f"not a str: {structured_delegate_s}" 663*da0073e9SAndroid Build Coastguard Worker assert structured_delegate_s is None or "::" not in structured_delegate_s, ( 664*da0073e9SAndroid Build Coastguard Worker "namespace is not supported in structured delegate," 665*da0073e9SAndroid Build Coastguard Worker " using the same namespace as the native function" 666*da0073e9SAndroid Build Coastguard Worker ) 667*da0073e9SAndroid Build Coastguard Worker structured_delegate: OperatorName | None = None 668*da0073e9SAndroid Build Coastguard Worker if structured_delegate_s is not None: 669*da0073e9SAndroid Build Coastguard Worker structured_delegate = OperatorName.parse(structured_delegate_s) 670*da0073e9SAndroid Build Coastguard Worker 671*da0073e9SAndroid Build Coastguard Worker structured_inherits = e.pop("structured_inherits", None) 672*da0073e9SAndroid Build Coastguard Worker assert structured_inherits is None or isinstance( 673*da0073e9SAndroid Build Coastguard Worker structured_inherits, str 674*da0073e9SAndroid Build Coastguard Worker ), f"not a str: {structured_inherits}" 675*da0073e9SAndroid Build Coastguard Worker assert structured_inherits is None or "::" not in structured_inherits, ( 676*da0073e9SAndroid Build Coastguard Worker "namespace is not supported in structured inherits," 677*da0073e9SAndroid Build Coastguard Worker " using the same namespace as the native function" 678*da0073e9SAndroid Build Coastguard Worker ) 679*da0073e9SAndroid Build Coastguard Worker 680*da0073e9SAndroid Build Coastguard Worker python_module = e.pop("python_module", None) 681*da0073e9SAndroid Build Coastguard Worker assert python_module is None or isinstance( 682*da0073e9SAndroid Build Coastguard Worker python_module, str 683*da0073e9SAndroid Build Coastguard Worker ), f"not a str: {python_module}" 684*da0073e9SAndroid Build Coastguard Worker assert ( 685*da0073e9SAndroid Build Coastguard Worker python_module is None or Variant.method not in variants 686*da0073e9SAndroid Build Coastguard Worker ), "functions in modules cannot be methods" 687*da0073e9SAndroid Build Coastguard Worker 688*da0073e9SAndroid Build Coastguard Worker category_override = e.pop("category_override", None) 689*da0073e9SAndroid Build Coastguard Worker assert category_override is None or isinstance( 690*da0073e9SAndroid Build Coastguard Worker category_override, str 691*da0073e9SAndroid Build Coastguard Worker ), f"not a str: {category_override}" 692*da0073e9SAndroid Build Coastguard Worker 693*da0073e9SAndroid Build Coastguard Worker precomputed_dict = e.pop("precomputed", None) 694*da0073e9SAndroid Build Coastguard Worker assert precomputed_dict is None or structured is True 695*da0073e9SAndroid Build Coastguard Worker precomputed = Precompute.parse(precomputed_dict) if precomputed_dict else None 696*da0073e9SAndroid Build Coastguard Worker 697*da0073e9SAndroid Build Coastguard Worker tags_inp = e.pop("tags", []) 698*da0073e9SAndroid Build Coastguard Worker if isinstance(tags_inp, str): 699*da0073e9SAndroid Build Coastguard Worker tags_inp = [tags_inp] 700*da0073e9SAndroid Build Coastguard Worker assert isinstance(tags_inp, list) 701*da0073e9SAndroid Build Coastguard Worker 702*da0073e9SAndroid Build Coastguard Worker # All aten ops generated by torchgen receive the pt2_compliant tag. 703*da0073e9SAndroid Build Coastguard Worker if namespace == "aten" and "pt2_compliant_tag" in valid_tags: 704*da0073e9SAndroid Build Coastguard Worker tags_inp.append("pt2_compliant_tag") 705*da0073e9SAndroid Build Coastguard Worker 706*da0073e9SAndroid Build Coastguard Worker tags: set[str] = set() 707*da0073e9SAndroid Build Coastguard Worker for t in tags_inp: 708*da0073e9SAndroid Build Coastguard Worker assert len(valid_tags) > 0 709*da0073e9SAndroid Build Coastguard Worker # TODO: verify that the tag is valid and has an entry in tags.yaml 710*da0073e9SAndroid Build Coastguard Worker if t in valid_tags: 711*da0073e9SAndroid Build Coastguard Worker tags.add(t) 712*da0073e9SAndroid Build Coastguard Worker else: 713*da0073e9SAndroid Build Coastguard Worker raise AssertionError(f"illegal tag {t}") 714*da0073e9SAndroid Build Coastguard Worker 715*da0073e9SAndroid Build Coastguard Worker from torchgen.api import cpp 716*da0073e9SAndroid Build Coastguard Worker 717*da0073e9SAndroid Build Coastguard Worker raw_dispatch = e.pop("dispatch", None) 718*da0073e9SAndroid Build Coastguard Worker assert raw_dispatch is None or isinstance(raw_dispatch, dict), e 719*da0073e9SAndroid Build Coastguard Worker dispatch: dict[DispatchKey, BackendMetadata] = {} 720*da0073e9SAndroid Build Coastguard Worker num_dispatch_keys: int = 0 721*da0073e9SAndroid Build Coastguard Worker if raw_dispatch is not None: 722*da0073e9SAndroid Build Coastguard Worker assert not manual_kernel_registration, ( 723*da0073e9SAndroid Build Coastguard Worker "cannot specify both manual_kernel_registration and dispatch; with " 724*da0073e9SAndroid Build Coastguard Worker "manual registration, dispatch has no effect!" 725*da0073e9SAndroid Build Coastguard Worker ) 726*da0073e9SAndroid Build Coastguard Worker redundant_composite_implicit_autograd = False 727*da0073e9SAndroid Build Coastguard Worker for ks, v in raw_dispatch.items(): 728*da0073e9SAndroid Build Coastguard Worker if ks == "__line__": 729*da0073e9SAndroid Build Coastguard Worker continue # not worth tracking line numbers for dispatch entries 730*da0073e9SAndroid Build Coastguard Worker assert isinstance( 731*da0073e9SAndroid Build Coastguard Worker ks, str 732*da0073e9SAndroid Build Coastguard Worker ), f"illegal dispatch key '{ks}' in {raw_dispatch}" 733*da0073e9SAndroid Build Coastguard Worker assert isinstance( 734*da0073e9SAndroid Build Coastguard Worker v, str 735*da0073e9SAndroid Build Coastguard Worker ), f"illegal dispatch value '{v}' in {raw_dispatch}" 736*da0073e9SAndroid Build Coastguard Worker for k in ks.split(","): 737*da0073e9SAndroid Build Coastguard Worker dispatch_key = DispatchKey.parse(k.strip()) 738*da0073e9SAndroid Build Coastguard Worker num_dispatch_keys += 1 739*da0073e9SAndroid Build Coastguard Worker 740*da0073e9SAndroid Build Coastguard Worker if ignore_keys and dispatch_key in ignore_keys: 741*da0073e9SAndroid Build Coastguard Worker continue 742*da0073e9SAndroid Build Coastguard Worker assert dispatch_key in dispatch_keys, ( 743*da0073e9SAndroid Build Coastguard Worker f"Dispatch key {dispatch_key} of kernel {v} " 744*da0073e9SAndroid Build Coastguard Worker "is not a supported dispatch key." 745*da0073e9SAndroid Build Coastguard Worker ) 746*da0073e9SAndroid Build Coastguard Worker # We only allow at most 3 levels of namespace for kernels. 747*da0073e9SAndroid Build Coastguard Worker # We will append "native" to a custom kernel namespace. 748*da0073e9SAndroid Build Coastguard Worker namespace_helper = NamespaceHelper.from_namespaced_entity( 749*da0073e9SAndroid Build Coastguard Worker v, max_level=3 750*da0073e9SAndroid Build Coastguard Worker ) 751*da0073e9SAndroid Build Coastguard Worker kernel_namespace = namespace_helper.get_cpp_namespace(default="at") 752*da0073e9SAndroid Build Coastguard Worker # Why is 'structured' included? External backends (e.g. 753*da0073e9SAndroid Build Coastguard Worker # XLA) opt into which ops are structured independently 754*da0073e9SAndroid Build Coastguard Worker # of which in-tree ops are structured 755*da0073e9SAndroid Build Coastguard Worker dispatch[dispatch_key] = BackendMetadata( 756*da0073e9SAndroid Build Coastguard Worker kernel=namespace_helper.entity_name, 757*da0073e9SAndroid Build Coastguard Worker structured=structured 758*da0073e9SAndroid Build Coastguard Worker and is_structured_dispatch_key(dispatch_key), 759*da0073e9SAndroid Build Coastguard Worker cpp_namespace=(kernel_namespace + "::native"), 760*da0073e9SAndroid Build Coastguard Worker ) 761*da0073e9SAndroid Build Coastguard Worker if ( 762*da0073e9SAndroid Build Coastguard Worker dispatch_key is DispatchKey.CompositeImplicitAutograd 763*da0073e9SAndroid Build Coastguard Worker and v == cpp.name(func) 764*da0073e9SAndroid Build Coastguard Worker ): 765*da0073e9SAndroid Build Coastguard Worker redundant_composite_implicit_autograd = True 766*da0073e9SAndroid Build Coastguard Worker 767*da0073e9SAndroid Build Coastguard Worker # We count the number of dispatch keys which have not been ignored to prevent a dispatch table 768*da0073e9SAndroid Build Coastguard Worker # in which all backend keys are ignored but necessarily kept, remaining compositeimplicit, 769*da0073e9SAndroid Build Coastguard Worker # from being treated as redundant. 770*da0073e9SAndroid Build Coastguard Worker assert not ( 771*da0073e9SAndroid Build Coastguard Worker num_dispatch_keys == 1 and redundant_composite_implicit_autograd 772*da0073e9SAndroid Build Coastguard Worker ), ( 773*da0073e9SAndroid Build Coastguard Worker "unnecessary dispatch table for this function; just delete the dispatch " 774*da0073e9SAndroid Build Coastguard Worker "key entirely" 775*da0073e9SAndroid Build Coastguard Worker ) 776*da0073e9SAndroid Build Coastguard Worker # if a function is a structured delegate, deleting the dispatch 777*da0073e9SAndroid Build Coastguard Worker # table is NOT semantics preserving 778*da0073e9SAndroid Build Coastguard Worker assert ( 779*da0073e9SAndroid Build Coastguard Worker structured_delegate 780*da0073e9SAndroid Build Coastguard Worker or dispatch.keys() != {DispatchKey.CompositeImplicitAutograd} 781*da0073e9SAndroid Build Coastguard Worker or dispatch[DispatchKey.CompositeImplicitAutograd].supports_symint() 782*da0073e9SAndroid Build Coastguard Worker or num_dispatch_keys != 1 783*da0073e9SAndroid Build Coastguard Worker ), ( 784*da0073e9SAndroid Build Coastguard Worker f"unexpected name for singleton CompositeImplicitAutograd dispatch entry: expected {cpp.name(func)} " 785*da0073e9SAndroid Build Coastguard Worker f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}. Rename your implementation to the expected " 786*da0073e9SAndroid Build Coastguard Worker "name, then delete the dispatch table" 787*da0073e9SAndroid Build Coastguard Worker ) 788*da0073e9SAndroid Build Coastguard Worker elif not structured and structured_delegate is None: 789*da0073e9SAndroid Build Coastguard Worker name = str(func.name.name) 790*da0073e9SAndroid Build Coastguard Worker assert not ( 791*da0073e9SAndroid Build Coastguard Worker name.startswith("new_") 792*da0073e9SAndroid Build Coastguard Worker or name.endswith("_like") 793*da0073e9SAndroid Build Coastguard Worker # TODO: maybe it's better to test the return 794*da0073e9SAndroid Build Coastguard Worker or ( 795*da0073e9SAndroid Build Coastguard Worker func.arguments.tensor_options 796*da0073e9SAndroid Build Coastguard Worker and not func.arguments.has_tensor_arg() 797*da0073e9SAndroid Build Coastguard Worker ) 798*da0073e9SAndroid Build Coastguard Worker ), ( 799*da0073e9SAndroid Build Coastguard Worker f"expected {name} to have a CompositeExplicitAutograd " 800*da0073e9SAndroid Build Coastguard Worker "dispatch entry, but there was no dispatch table. Factory functions " 801*da0073e9SAndroid Build Coastguard Worker "should not have implicit dispatch as they should not be decomposed " 802*da0073e9SAndroid Build Coastguard Worker "for __torch_dispatch__" 803*da0073e9SAndroid Build Coastguard Worker ) 804*da0073e9SAndroid Build Coastguard Worker dispatch[DispatchKey.CompositeImplicitAutograd] = BackendMetadata( 805*da0073e9SAndroid Build Coastguard Worker cpp.name(func), structured=False, cpp_namespace=DEFAULT_KERNEL_NAMESPACE 806*da0073e9SAndroid Build Coastguard Worker ) 807*da0073e9SAndroid Build Coastguard Worker 808*da0073e9SAndroid Build Coastguard Worker composites_in_dispatch = [ 809*da0073e9SAndroid Build Coastguard Worker d 810*da0073e9SAndroid Build Coastguard Worker for d in dispatch 811*da0073e9SAndroid Build Coastguard Worker if d == DispatchKey.CompositeExplicitAutograd 812*da0073e9SAndroid Build Coastguard Worker or d == DispatchKey.CompositeExplicitAutogradNonFunctional 813*da0073e9SAndroid Build Coastguard Worker or d == DispatchKey.CompositeImplicitAutograd 814*da0073e9SAndroid Build Coastguard Worker or d == DispatchKey.CompositeImplicitAutogradNestedTensor 815*da0073e9SAndroid Build Coastguard Worker ] 816*da0073e9SAndroid Build Coastguard Worker 817*da0073e9SAndroid Build Coastguard Worker assert len(composites_in_dispatch) <= 1 or ( 818*da0073e9SAndroid Build Coastguard Worker len(composites_in_dispatch) == 2 819*da0073e9SAndroid Build Coastguard Worker and ( 820*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeExplicitAutogradNonFunctional 821*da0073e9SAndroid Build Coastguard Worker not in composites_in_dispatch 822*da0073e9SAndroid Build Coastguard Worker ) 823*da0073e9SAndroid Build Coastguard Worker and ( 824*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeImplicitAutogradNestedTensor 825*da0073e9SAndroid Build Coastguard Worker in composites_in_dispatch 826*da0073e9SAndroid Build Coastguard Worker ) 827*da0073e9SAndroid Build Coastguard Worker ), ( 828*da0073e9SAndroid Build Coastguard Worker "cannot specify more than one of CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional, " 829*da0073e9SAndroid Build Coastguard Worker "or CompositeImplicitAutograd on a single kernel; each " 830*da0073e9SAndroid Build Coastguard Worker "strictly subsumes the other. If you wanted to provide an explicit autograd " 831*da0073e9SAndroid Build Coastguard Worker "implementation, specify CompositeExplicitAutograd; otherwise specify CompositeImplicitAutograd only" 832*da0073e9SAndroid Build Coastguard Worker ) 833*da0073e9SAndroid Build Coastguard Worker 834*da0073e9SAndroid Build Coastguard Worker autogen_str = e.pop("autogen", "") 835*da0073e9SAndroid Build Coastguard Worker assert isinstance(autogen_str, str) 836*da0073e9SAndroid Build Coastguard Worker autogen = ( 837*da0073e9SAndroid Build Coastguard Worker [] 838*da0073e9SAndroid Build Coastguard Worker if autogen_str == "" 839*da0073e9SAndroid Build Coastguard Worker else [OperatorName.parse(x) for x in autogen_str.split(", ")] 840*da0073e9SAndroid Build Coastguard Worker ) 841*da0073e9SAndroid Build Coastguard Worker 842*da0073e9SAndroid Build Coastguard Worker raw_ufunc_inner_loop = e.pop("ufunc_inner_loop", {}) 843*da0073e9SAndroid Build Coastguard Worker ufunc_inner_loop = {} 844*da0073e9SAndroid Build Coastguard Worker if isinstance(raw_ufunc_inner_loop, str): 845*da0073e9SAndroid Build Coastguard Worker ufunc_inner_loop[UfuncKey.Generic] = UfuncInnerLoop.parse( 846*da0073e9SAndroid Build Coastguard Worker raw_ufunc_inner_loop, UfuncKey.Generic 847*da0073e9SAndroid Build Coastguard Worker ) 848*da0073e9SAndroid Build Coastguard Worker elif isinstance(raw_ufunc_inner_loop, dict): 849*da0073e9SAndroid Build Coastguard Worker for k, vo in raw_ufunc_inner_loop.items(): 850*da0073e9SAndroid Build Coastguard Worker if k == "__line__": 851*da0073e9SAndroid Build Coastguard Worker continue 852*da0073e9SAndroid Build Coastguard Worker assert isinstance(k, str), f"ufunc_inner_loop key is not a str: {k}" 853*da0073e9SAndroid Build Coastguard Worker assert isinstance(vo, str), f"ufunc_inner_loop value is not a str: {v}" 854*da0073e9SAndroid Build Coastguard Worker ufunc_key = UfuncKey.parse(k) 855*da0073e9SAndroid Build Coastguard Worker ufunc_inner_loop[ufunc_key] = UfuncInnerLoop.parse(vo, ufunc_key) 856*da0073e9SAndroid Build Coastguard Worker else: 857*da0073e9SAndroid Build Coastguard Worker raise AssertionError( 858*da0073e9SAndroid Build Coastguard Worker f"ufunc_inner_loop not str or dict: {raw_ufunc_inner_loop}" 859*da0073e9SAndroid Build Coastguard Worker ) 860*da0073e9SAndroid Build Coastguard Worker # Program the BackendIndex for the implicit dispatch entry from ufunc 861*da0073e9SAndroid Build Coastguard Worker if ufunc_inner_loop: 862*da0073e9SAndroid Build Coastguard Worker assert structured, "ufunc must be structured" 863*da0073e9SAndroid Build Coastguard Worker 864*da0073e9SAndroid Build Coastguard Worker # Delay import ufunc here to avoid circular import issue 865*da0073e9SAndroid Build Coastguard Worker # See: https://github.com/pytorch/pytorch/issues/81294 866*da0073e9SAndroid Build Coastguard Worker import torchgen.api.ufunc as ufunc 867*da0073e9SAndroid Build Coastguard Worker 868*da0073e9SAndroid Build Coastguard Worker for dispatch_key in UFUNC_DISPATCH_KEYS: 869*da0073e9SAndroid Build Coastguard Worker assert ( 870*da0073e9SAndroid Build Coastguard Worker dispatch_key not in dispatch 871*da0073e9SAndroid Build Coastguard Worker ), f"ufunc should not have explicit dispatch entry for {dispatch_key}" 872*da0073e9SAndroid Build Coastguard Worker dispatch[dispatch_key] = BackendMetadata( 873*da0073e9SAndroid Build Coastguard Worker kernel=ufunc.schema_kernel_name(func, dispatch_key), 874*da0073e9SAndroid Build Coastguard Worker structured=True, 875*da0073e9SAndroid Build Coastguard Worker cpp_namespace=DEFAULT_KERNEL_NAMESPACE, 876*da0073e9SAndroid Build Coastguard Worker ) 877*da0073e9SAndroid Build Coastguard Worker 878*da0073e9SAndroid Build Coastguard Worker if structured_delegate: 879*da0073e9SAndroid Build Coastguard Worker # Structured functions MUST have a dispatch table 880*da0073e9SAndroid Build Coastguard Worker is_abstract = True 881*da0073e9SAndroid Build Coastguard Worker else: 882*da0073e9SAndroid Build Coastguard Worker is_abstract = ( 883*da0073e9SAndroid Build Coastguard Worker dispatch.keys() != {DispatchKey.CompositeImplicitAutograd} 884*da0073e9SAndroid Build Coastguard Worker and dispatch.keys() 885*da0073e9SAndroid Build Coastguard Worker != {DispatchKey.CompositeImplicitAutogradNestedTensor} 886*da0073e9SAndroid Build Coastguard Worker and dispatch.keys() 887*da0073e9SAndroid Build Coastguard Worker != { 888*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeImplicitAutograd, 889*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeImplicitAutogradNestedTensor, 890*da0073e9SAndroid Build Coastguard Worker } 891*da0073e9SAndroid Build Coastguard Worker ) 892*da0073e9SAndroid Build Coastguard Worker 893*da0073e9SAndroid Build Coastguard Worker has_composite_implicit_autograd_kernel = ( 894*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeImplicitAutograd in dispatch 895*da0073e9SAndroid Build Coastguard Worker ) 896*da0073e9SAndroid Build Coastguard Worker has_composite_implicit_autograd_nested_tensor_kernel = ( 897*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeImplicitAutogradNestedTensor in dispatch 898*da0073e9SAndroid Build Coastguard Worker ) 899*da0073e9SAndroid Build Coastguard Worker has_composite_explicit_autograd_kernel = ( 900*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeExplicitAutograd in dispatch 901*da0073e9SAndroid Build Coastguard Worker ) 902*da0073e9SAndroid Build Coastguard Worker has_composite_explicit_autograd_non_functional_kernel = ( 903*da0073e9SAndroid Build Coastguard Worker DispatchKey.CompositeExplicitAutogradNonFunctional in dispatch 904*da0073e9SAndroid Build Coastguard Worker ) 905*da0073e9SAndroid Build Coastguard Worker 906*da0073e9SAndroid Build Coastguard Worker # We aren't going to store dispatch metadata inline in NativeFunctions; 907*da0073e9SAndroid Build Coastguard Worker # instead it is separately indexed by backend (so other backends can 908*da0073e9SAndroid Build Coastguard Worker # add more dispatch entries after the fact). Reindex the individual 909*da0073e9SAndroid Build Coastguard Worker # metadata by OperatorName! 910*da0073e9SAndroid Build Coastguard Worker backend_metadata = {k: {func.name: v} for k, v in dispatch.items()} 911*da0073e9SAndroid Build Coastguard Worker 912*da0073e9SAndroid Build Coastguard Worker # don't care if it exists or not; make it easier to use this function 913*da0073e9SAndroid Build Coastguard Worker # with other yaml parsers that aren't setting __line__ in the dict 914*da0073e9SAndroid Build Coastguard Worker e.pop("__line__", None) 915*da0073e9SAndroid Build Coastguard Worker assert not e, f"leftover entries: {e}" 916*da0073e9SAndroid Build Coastguard Worker 917*da0073e9SAndroid Build Coastguard Worker # Asserts that we can't do in post_init, because they rely on backend-specific info 918*da0073e9SAndroid Build Coastguard Worker if structured_delegate is not None: 919*da0073e9SAndroid Build Coastguard Worker for key in STRUCTURED_DISPATCH_KEYS: 920*da0073e9SAndroid Build Coastguard Worker assert key not in dispatch, ( 921*da0073e9SAndroid Build Coastguard Worker f"if structured_delegate, then must not have {key} in dispatch dictionary " 922*da0073e9SAndroid Build Coastguard Worker "(it is delegated!)" 923*da0073e9SAndroid Build Coastguard Worker ) 924*da0073e9SAndroid Build Coastguard Worker 925*da0073e9SAndroid Build Coastguard Worker return ( 926*da0073e9SAndroid Build Coastguard Worker NativeFunction( 927*da0073e9SAndroid Build Coastguard Worker func=func, 928*da0073e9SAndroid Build Coastguard Worker use_const_ref_for_mutable_tensors=use_const_ref_for_mutable_tensors, 929*da0073e9SAndroid Build Coastguard Worker variants=variants, 930*da0073e9SAndroid Build Coastguard Worker structured=structured, 931*da0073e9SAndroid Build Coastguard Worker structured_delegate=structured_delegate, 932*da0073e9SAndroid Build Coastguard Worker structured_inherits=structured_inherits, 933*da0073e9SAndroid Build Coastguard Worker precomputed=precomputed, 934*da0073e9SAndroid Build Coastguard Worker autogen=autogen, 935*da0073e9SAndroid Build Coastguard Worker ufunc_inner_loop=ufunc_inner_loop, 936*da0073e9SAndroid Build Coastguard Worker manual_kernel_registration=manual_kernel_registration, 937*da0073e9SAndroid Build Coastguard Worker manual_cpp_binding=manual_cpp_binding, 938*da0073e9SAndroid Build Coastguard Worker python_module=python_module, 939*da0073e9SAndroid Build Coastguard Worker category_override=category_override, 940*da0073e9SAndroid Build Coastguard Worker device_guard=device_guard, 941*da0073e9SAndroid Build Coastguard Worker device_check=device_check, 942*da0073e9SAndroid Build Coastguard Worker loc=loc, 943*da0073e9SAndroid Build Coastguard Worker cpp_no_default_args=cpp_no_default_args, 944*da0073e9SAndroid Build Coastguard Worker is_abstract=is_abstract, 945*da0073e9SAndroid Build Coastguard Worker has_composite_implicit_autograd_kernel=has_composite_implicit_autograd_kernel, 946*da0073e9SAndroid Build Coastguard Worker has_composite_implicit_autograd_nested_tensor_kernel=has_composite_implicit_autograd_nested_tensor_kernel, 947*da0073e9SAndroid Build Coastguard Worker has_composite_explicit_autograd_kernel=has_composite_explicit_autograd_kernel, 948*da0073e9SAndroid Build Coastguard Worker has_composite_explicit_autograd_non_functional_kernel=has_composite_explicit_autograd_non_functional_kernel, 949*da0073e9SAndroid Build Coastguard Worker tags=tags, 950*da0073e9SAndroid Build Coastguard Worker namespace=namespace, 951*da0073e9SAndroid Build Coastguard Worker ), 952*da0073e9SAndroid Build Coastguard Worker backend_metadata, 953*da0073e9SAndroid Build Coastguard Worker ) 954*da0073e9SAndroid Build Coastguard Worker 955*da0073e9SAndroid Build Coastguard Worker def validate_unstructured(self) -> None: 956*da0073e9SAndroid Build Coastguard Worker # TODO: probably better to accumulate these errors and report them all 957*da0073e9SAndroid Build Coastguard Worker # at once 958*da0073e9SAndroid Build Coastguard Worker assert not self.structured, ( 959*da0073e9SAndroid Build Coastguard Worker "This function is structured, but there was " 960*da0073e9SAndroid Build Coastguard Worker "no valid functional variant of it." 961*da0073e9SAndroid Build Coastguard Worker ) 962*da0073e9SAndroid Build Coastguard Worker assert self.structured_delegate, ( 963*da0073e9SAndroid Build Coastguard Worker "This function delegates to another structured out function, " 964*da0073e9SAndroid Build Coastguard Worker "but no valid function was found (the delegate may not exist, or it has the wrong type)" 965*da0073e9SAndroid Build Coastguard Worker ) 966*da0073e9SAndroid Build Coastguard Worker 967*da0073e9SAndroid Build Coastguard Worker # __post_init__ functions in dataclasses can be used to do extra 968*da0073e9SAndroid Build Coastguard Worker # validation after construction. 969*da0073e9SAndroid Build Coastguard Worker # 970*da0073e9SAndroid Build Coastguard Worker # Notice that we don't do any type validation here. In fact, we 971*da0073e9SAndroid Build Coastguard Worker # rely exclusively on mypy to check if you've done types correctly! 972*da0073e9SAndroid Build Coastguard Worker # Validation is for nontrivial invariants that cannot be (conveniently) 973*da0073e9SAndroid Build Coastguard Worker # encoded in the type system. 974*da0073e9SAndroid Build Coastguard Worker def __post_init__(self) -> None: 975*da0073e9SAndroid Build Coastguard Worker if self.func.arguments.out: 976*da0073e9SAndroid Build Coastguard Worker assert self.variants == {Variant.function}, ( 977*da0073e9SAndroid Build Coastguard Worker "Native functions with out arguments MUST " 978*da0073e9SAndroid Build Coastguard Worker "be declared with only function variant; e.g., variants: function; " 979*da0073e9SAndroid Build Coastguard Worker "otherwise you will tickle a Python argument binding bug " 980*da0073e9SAndroid Build Coastguard Worker "(which usually manifests itself as the result variable being undefined.)" 981*da0073e9SAndroid Build Coastguard Worker ) 982*da0073e9SAndroid Build Coastguard Worker if self.structured: 983*da0073e9SAndroid Build Coastguard Worker assert self.func.kind() == SchemaKind.out, ( 984*da0073e9SAndroid Build Coastguard Worker "Put structured field on the out= " 985*da0073e9SAndroid Build Coastguard Worker "variant of a function; did you mean structured_delegate?" 986*da0073e9SAndroid Build Coastguard Worker ) 987*da0073e9SAndroid Build Coastguard Worker assert ( 988*da0073e9SAndroid Build Coastguard Worker self.device_guard 989*da0073e9SAndroid Build Coastguard Worker ), "device_guard: False is not respected by structured kernels" 990*da0073e9SAndroid Build Coastguard Worker if self.structured_delegate: 991*da0073e9SAndroid Build Coastguard Worker assert self.func.kind() != SchemaKind.out, ( 992*da0073e9SAndroid Build Coastguard Worker "structured_delegate field not allowed " 993*da0073e9SAndroid Build Coastguard Worker "on out= functions; did you mean structured?" 994*da0073e9SAndroid Build Coastguard Worker ) 995*da0073e9SAndroid Build Coastguard Worker assert ( 996*da0073e9SAndroid Build Coastguard Worker self.device_guard 997*da0073e9SAndroid Build Coastguard Worker ), "device_guard: False is not respected by structured kernels" 998*da0073e9SAndroid Build Coastguard Worker # Technically, with the asserts above, this assert is impossible to 999*da0073e9SAndroid Build Coastguard Worker # happen 1000*da0073e9SAndroid Build Coastguard Worker assert not ( 1001*da0073e9SAndroid Build Coastguard Worker self.structured and self.structured_delegate 1002*da0073e9SAndroid Build Coastguard Worker ), "Cannot have both structured and structured_delegate on function" 1003*da0073e9SAndroid Build Coastguard Worker defaulted_arguments = { 1004*da0073e9SAndroid Build Coastguard Worker a.name for a in self.func.schema_order_arguments() if a.default is not None 1005*da0073e9SAndroid Build Coastguard Worker } 1006*da0073e9SAndroid Build Coastguard Worker invalid_args = set.difference(self.cpp_no_default_args, defaulted_arguments) 1007*da0073e9SAndroid Build Coastguard Worker assert len(invalid_args) == 0, f"Invalid cpp_no_default_args: {invalid_args}" 1008*da0073e9SAndroid Build Coastguard Worker if self.structured_inherits is not None: 1009*da0073e9SAndroid Build Coastguard Worker assert ( 1010*da0073e9SAndroid Build Coastguard Worker self.structured 1011*da0073e9SAndroid Build Coastguard Worker ), "structured_inherits must also imply structured: True" 1012*da0073e9SAndroid Build Coastguard Worker if str(self.func.name).startswith("_foreach"): 1013*da0073e9SAndroid Build Coastguard Worker assert self.device_check == DeviceCheckType.NoCheck, ( 1014*da0073e9SAndroid Build Coastguard Worker "foreach kernels fall back to slow path when tensor are on different devices, " 1015*da0073e9SAndroid Build Coastguard Worker "device_check not allowed to be enabled" 1016*da0073e9SAndroid Build Coastguard Worker ) 1017*da0073e9SAndroid Build Coastguard Worker 1018*da0073e9SAndroid Build Coastguard Worker # NB: if your function accidentally has rand/dropout/... in its name 1019*da0073e9SAndroid Build Coastguard Worker # but is not actually random, feel free to amend this to special case 1020*da0073e9SAndroid Build Coastguard Worker if ( 1021*da0073e9SAndroid Build Coastguard Worker "rand" in str(self.func.name) 1022*da0073e9SAndroid Build Coastguard Worker or ( 1023*da0073e9SAndroid Build Coastguard Worker ( 1024*da0073e9SAndroid Build Coastguard Worker "dropout" in str(self.func.name) 1025*da0073e9SAndroid Build Coastguard Worker or any( 1026*da0073e9SAndroid Build Coastguard Worker "dropout" in arg.name for arg in self.func.arguments.flat_all 1027*da0073e9SAndroid Build Coastguard Worker ) 1028*da0073e9SAndroid Build Coastguard Worker ) 1029*da0073e9SAndroid Build Coastguard Worker # Backwards of dropout is typically deterministic 1030*da0073e9SAndroid Build Coastguard Worker and "backward" not in str(self.func.name) 1031*da0073e9SAndroid Build Coastguard Worker and str(self.func.name.name) not in ["_cudnn_init_dropout_state"] 1032*da0073e9SAndroid Build Coastguard Worker ) 1033*da0073e9SAndroid Build Coastguard Worker or self.func.arguments.has_generator_arg() 1034*da0073e9SAndroid Build Coastguard Worker ): 1035*da0073e9SAndroid Build Coastguard Worker assert "nondeterministic_seeded" in self.tags, str(self.func.name) 1036*da0073e9SAndroid Build Coastguard Worker 1037*da0073e9SAndroid Build Coastguard Worker @property 1038*da0073e9SAndroid Build Coastguard Worker def has_composite_kernel(self) -> bool: 1039*da0073e9SAndroid Build Coastguard Worker return ( 1040*da0073e9SAndroid Build Coastguard Worker self.has_composite_implicit_autograd_kernel 1041*da0073e9SAndroid Build Coastguard Worker or self.has_composite_explicit_autograd_kernel 1042*da0073e9SAndroid Build Coastguard Worker or self.has_composite_explicit_autograd_non_functional_kernel 1043*da0073e9SAndroid Build Coastguard Worker ) or ( 1044*da0073e9SAndroid Build Coastguard Worker self.has_composite_implicit_autograd_kernel 1045*da0073e9SAndroid Build Coastguard Worker and self.has_composite_implicit_autograd_nested_tensor_kernel 1046*da0073e9SAndroid Build Coastguard Worker ) 1047*da0073e9SAndroid Build Coastguard Worker 1048*da0073e9SAndroid Build Coastguard Worker @property 1049*da0073e9SAndroid Build Coastguard Worker def is_view_op(self) -> bool: 1050*da0073e9SAndroid Build Coastguard Worker rets = self.func.returns 1051*da0073e9SAndroid Build Coastguard Worker is_non_mutating_view = len(rets) > 0 and any( 1052*da0073e9SAndroid Build Coastguard Worker r.annotation is not None and not r.annotation.is_write for r in rets 1053*da0073e9SAndroid Build Coastguard Worker ) 1054*da0073e9SAndroid Build Coastguard Worker # See Note [resize_ in Functionalization] for more dtails 1055*da0073e9SAndroid Build Coastguard Worker is_inplace_view = ( 1056*da0073e9SAndroid Build Coastguard Worker "inplace_view" in self.tags 1057*da0073e9SAndroid Build Coastguard Worker and str(self.func.name) != "resize_" 1058*da0073e9SAndroid Build Coastguard Worker and str(self.func.name) != "resize_as_" 1059*da0073e9SAndroid Build Coastguard Worker ) 1060*da0073e9SAndroid Build Coastguard Worker is_wildcard_view = any( 1061*da0073e9SAndroid Build Coastguard Worker inp.annotation is not None and "*" in inp.annotation.alias_set_after 1062*da0073e9SAndroid Build Coastguard Worker for inp in self.func.schema_order_arguments() 1063*da0073e9SAndroid Build Coastguard Worker ) 1064*da0073e9SAndroid Build Coastguard Worker return is_non_mutating_view or is_inplace_view or is_wildcard_view 1065*da0073e9SAndroid Build Coastguard Worker 1066*da0073e9SAndroid Build Coastguard Worker @property 1067*da0073e9SAndroid Build Coastguard Worker def view_schema_kind(self) -> ViewSchemaKind: 1068*da0073e9SAndroid Build Coastguard Worker if self.is_view_op and self.func.name.name.inplace: 1069*da0073e9SAndroid Build Coastguard Worker assert "inplace_view" in self.tags 1070*da0073e9SAndroid Build Coastguard Worker return ViewSchemaKind.aliasing_inplace 1071*da0073e9SAndroid Build Coastguard Worker if self.is_view_op: 1072*da0073e9SAndroid Build Coastguard Worker return ViewSchemaKind.aliasing 1073*da0073e9SAndroid Build Coastguard Worker else: 1074*da0073e9SAndroid Build Coastguard Worker return ViewSchemaKind.non_aliasing 1075*da0073e9SAndroid Build Coastguard Worker 1076*da0073e9SAndroid Build Coastguard Worker @property 1077*da0073e9SAndroid Build Coastguard Worker def root_name(self) -> str: 1078*da0073e9SAndroid Build Coastguard Worker return self.func.name.name.base 1079*da0073e9SAndroid Build Coastguard Worker 1080*da0073e9SAndroid Build Coastguard Worker @property 1081*da0073e9SAndroid Build Coastguard Worker def part_of_structured_group(self) -> bool: 1082*da0073e9SAndroid Build Coastguard Worker return self.structured or self.structured_delegate is not None 1083*da0073e9SAndroid Build Coastguard Worker 1084*da0073e9SAndroid Build Coastguard Worker 1085*da0073e9SAndroid Build Coastguard Workerclass SchemaKind(Enum): 1086*da0073e9SAndroid Build Coastguard Worker functional = auto() 1087*da0073e9SAndroid Build Coastguard Worker inplace = auto() 1088*da0073e9SAndroid Build Coastguard Worker out = auto() 1089*da0073e9SAndroid Build Coastguard Worker mutable = auto() 1090*da0073e9SAndroid Build Coastguard Worker scratch = auto() 1091*da0073e9SAndroid Build Coastguard Worker 1092*da0073e9SAndroid Build Coastguard Worker 1093*da0073e9SAndroid Build Coastguard Worker# A structured kernel is guaranteed to have a functional and out variant, and 1094*da0073e9SAndroid Build Coastguard Worker# optionally an inplace variant. 1095*da0073e9SAndroid Build Coastguard Worker# 1096*da0073e9SAndroid Build Coastguard Worker# NB: we create NativeFunctionsGroup *even if* the function is not 1097*da0073e9SAndroid Build Coastguard Worker# actually annotated structured. Test the structured boolean to see if it 1098*da0073e9SAndroid Build Coastguard Worker# actually is structured or not. 1099*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 1100*da0073e9SAndroid Build Coastguard Workerclass NativeFunctionsGroup: 1101*da0073e9SAndroid Build Coastguard Worker functional: NativeFunction 1102*da0073e9SAndroid Build Coastguard Worker inplace: NativeFunction | None 1103*da0073e9SAndroid Build Coastguard Worker mutable: NativeFunction | None 1104*da0073e9SAndroid Build Coastguard Worker out: NativeFunction 1105*da0073e9SAndroid Build Coastguard Worker 1106*da0073e9SAndroid Build Coastguard Worker @property 1107*da0073e9SAndroid Build Coastguard Worker def structured(self) -> bool: 1108*da0073e9SAndroid Build Coastguard Worker # Whether or not the operator has a meta() function. This information is backend-agnostic. 1109*da0073e9SAndroid Build Coastguard Worker return self.out.structured 1110*da0073e9SAndroid Build Coastguard Worker 1111*da0073e9SAndroid Build Coastguard Worker def __post_init__(self) -> None: 1112*da0073e9SAndroid Build Coastguard Worker test_sig: FunctionSchema = self.functional.func.signature() 1113*da0073e9SAndroid Build Coastguard Worker for f in self.functions(): 1114*da0073e9SAndroid Build Coastguard Worker if test_sig != f.func.signature(): 1115*da0073e9SAndroid Build Coastguard Worker raise AssertionError( 1116*da0073e9SAndroid Build Coastguard Worker "NativeFunctionsGroup constructed from two NativeFunctions " 1117*da0073e9SAndroid Build Coastguard Worker f"that don't have matching signatures: {test_sig} != {f.func.signature()}" 1118*da0073e9SAndroid Build Coastguard Worker ) 1119*da0073e9SAndroid Build Coastguard Worker 1120*da0073e9SAndroid Build Coastguard Worker if self.structured != f.part_of_structured_group: 1121*da0073e9SAndroid Build Coastguard Worker raise AssertionError( 1122*da0073e9SAndroid Build Coastguard Worker "NativeFunctionsGroup constructed from structured and unstructured " 1123*da0073e9SAndroid Build Coastguard Worker f"functions: {self.out.func.name} and {f.func.name}" 1124*da0073e9SAndroid Build Coastguard Worker ) 1125*da0073e9SAndroid Build Coastguard Worker assert self.functional.func.kind() == SchemaKind.functional 1126*da0073e9SAndroid Build Coastguard Worker assert self.out.func.kind() == SchemaKind.out 1127*da0073e9SAndroid Build Coastguard Worker assert self.functional.namespace == self.out.namespace 1128*da0073e9SAndroid Build Coastguard Worker if self.inplace is not None: 1129*da0073e9SAndroid Build Coastguard Worker assert self.inplace.func.kind() == SchemaKind.inplace 1130*da0073e9SAndroid Build Coastguard Worker assert self.inplace.namespace == self.functional.namespace 1131*da0073e9SAndroid Build Coastguard Worker 1132*da0073e9SAndroid Build Coastguard Worker if self.mutable is not None: 1133*da0073e9SAndroid Build Coastguard Worker assert self.mutable.func.kind() == SchemaKind.mutable 1134*da0073e9SAndroid Build Coastguard Worker assert self.mutable.namespace == self.functional.namespace 1135*da0073e9SAndroid Build Coastguard Worker # See Note [Overload Ambiguity With Functional Variants] 1136*da0073e9SAndroid Build Coastguard Worker assert self.functional.func.name.name.functional_overload 1137*da0073e9SAndroid Build Coastguard Worker 1138*da0073e9SAndroid Build Coastguard Worker if self.structured: 1139*da0073e9SAndroid Build Coastguard Worker # For now, structured composite kernels are not supported (need some 1140*da0073e9SAndroid Build Coastguard Worker # design work to figure out how to make the composite case work) 1141*da0073e9SAndroid Build Coastguard Worker assert ( 1142*da0073e9SAndroid Build Coastguard Worker not self.out.has_composite_implicit_autograd_kernel 1143*da0073e9SAndroid Build Coastguard Worker and not self.out.has_composite_implicit_autograd_nested_tensor_kernel 1144*da0073e9SAndroid Build Coastguard Worker ) 1145*da0073e9SAndroid Build Coastguard Worker 1146*da0073e9SAndroid Build Coastguard Worker assert self.functional.structured_delegate == self.out.func.name, ( 1147*da0073e9SAndroid Build Coastguard Worker f"{self.functional.func.name} delegates to {self.functional.structured_delegate} " 1148*da0073e9SAndroid Build Coastguard Worker f"but its actual delegate is {self.out.func.name}" 1149*da0073e9SAndroid Build Coastguard Worker ) 1150*da0073e9SAndroid Build Coastguard Worker if self.inplace is not None: 1151*da0073e9SAndroid Build Coastguard Worker assert self.inplace.structured_delegate == self.out.func.name 1152*da0073e9SAndroid Build Coastguard Worker 1153*da0073e9SAndroid Build Coastguard Worker generated_fns = sorted( 1154*da0073e9SAndroid Build Coastguard Worker [str(f.func.name) for f in self.functions() if "generated" in f.tags] 1155*da0073e9SAndroid Build Coastguard Worker ) 1156*da0073e9SAndroid Build Coastguard Worker generated_fns_str = ", ".join(str(x) for x in generated_fns) 1157*da0073e9SAndroid Build Coastguard Worker expected_generated_fns: set[str] = set() 1158*da0073e9SAndroid Build Coastguard Worker for f in self.functions(): 1159*da0073e9SAndroid Build Coastguard Worker expected_generated_fns.update(str(op) for op in f.autogen) 1160*da0073e9SAndroid Build Coastguard Worker expected_generated_fns_str = ", ".join( 1161*da0073e9SAndroid Build Coastguard Worker str(x) for x in sorted(expected_generated_fns) 1162*da0073e9SAndroid Build Coastguard Worker ) 1163*da0073e9SAndroid Build Coastguard Worker if len(expected_generated_fns) == 0 and len(generated_fns) > 0: 1164*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 1165*da0073e9SAndroid Build Coastguard Worker f"The codegen expects to be able to generate '{generated_fns_str}'." 1166*da0073e9SAndroid Build Coastguard Worker " In order to generate them however, we expect them to be called out explicitly in the yaml." 1167*da0073e9SAndroid Build Coastguard Worker f" Please add an 'autogen: {generated_fns_str}' line to the entry for {str(f.func.name)}" 1168*da0073e9SAndroid Build Coastguard Worker ) 1169*da0073e9SAndroid Build Coastguard Worker if expected_generated_fns_str != generated_fns_str: 1170*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 1171*da0073e9SAndroid Build Coastguard Worker f"The codegen expects to be able to generate '{generated_fns_str}'." 1172*da0073e9SAndroid Build Coastguard Worker f" To do so, it expects a line: 'autogen: {generated_fns_str}'." 1173*da0073e9SAndroid Build Coastguard Worker f" Instead, it found 'autogen: {expected_generated_fns_str}'" 1174*da0073e9SAndroid Build Coastguard Worker ) 1175*da0073e9SAndroid Build Coastguard Worker 1176*da0073e9SAndroid Build Coastguard Worker def signature(self) -> FunctionSchema: 1177*da0073e9SAndroid Build Coastguard Worker return self.out.func.signature() 1178*da0073e9SAndroid Build Coastguard Worker 1179*da0073e9SAndroid Build Coastguard Worker def functions(self) -> Iterator[NativeFunction]: 1180*da0073e9SAndroid Build Coastguard Worker yield self.functional 1181*da0073e9SAndroid Build Coastguard Worker yield self.out 1182*da0073e9SAndroid Build Coastguard Worker if self.inplace is not None: 1183*da0073e9SAndroid Build Coastguard Worker yield self.inplace 1184*da0073e9SAndroid Build Coastguard Worker if self.mutable is not None: 1185*da0073e9SAndroid Build Coastguard Worker yield self.mutable 1186*da0073e9SAndroid Build Coastguard Worker 1187*da0073e9SAndroid Build Coastguard Worker @property 1188*da0073e9SAndroid Build Coastguard Worker def root_name(self) -> str: 1189*da0073e9SAndroid Build Coastguard Worker return self.functional.root_name 1190*da0073e9SAndroid Build Coastguard Worker 1191*da0073e9SAndroid Build Coastguard Worker @staticmethod 1192*da0073e9SAndroid Build Coastguard Worker def from_dict(d: dict[SchemaKind, NativeFunction]) -> NativeFunctionsGroup | None: 1193*da0073e9SAndroid Build Coastguard Worker assert d 1194*da0073e9SAndroid Build Coastguard Worker if len(d) == 1: 1195*da0073e9SAndroid Build Coastguard Worker return None 1196*da0073e9SAndroid Build Coastguard Worker d = dict(d) # non-destructive updates please 1197*da0073e9SAndroid Build Coastguard Worker functional = d.pop(SchemaKind.functional, None) 1198*da0073e9SAndroid Build Coastguard Worker inplace = d.pop(SchemaKind.inplace, None) 1199*da0073e9SAndroid Build Coastguard Worker mutable = d.pop(SchemaKind.mutable, None) 1200*da0073e9SAndroid Build Coastguard Worker out = d.pop(SchemaKind.out, None) 1201*da0073e9SAndroid Build Coastguard Worker assert not d 1202*da0073e9SAndroid Build Coastguard Worker assert functional is not None 1203*da0073e9SAndroid Build Coastguard Worker # There are a few operators which only have functional/inplace variants; 1204*da0073e9SAndroid Build Coastguard Worker # these don't count as structured for our purposes here 1205*da0073e9SAndroid Build Coastguard Worker if out is None: 1206*da0073e9SAndroid Build Coastguard Worker return None 1207*da0073e9SAndroid Build Coastguard Worker # assuming all variants have the same namespace 1208*da0073e9SAndroid Build Coastguard Worker return NativeFunctionsGroup( 1209*da0073e9SAndroid Build Coastguard Worker functional=functional, 1210*da0073e9SAndroid Build Coastguard Worker inplace=inplace, 1211*da0073e9SAndroid Build Coastguard Worker mutable=mutable, 1212*da0073e9SAndroid Build Coastguard Worker out=out, 1213*da0073e9SAndroid Build Coastguard Worker ) 1214*da0073e9SAndroid Build Coastguard Worker 1215*da0073e9SAndroid Build Coastguard Worker 1216*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 1217*da0073e9SAndroid Build Coastguard Workerclass BackendMetadata: 1218*da0073e9SAndroid Build Coastguard Worker # The name of the backend kernel, for a given operator 1219*da0073e9SAndroid Build Coastguard Worker # for in-tree backends. These names come directly from the 'dispatch" field 1220*da0073e9SAndroid Build Coastguard Worker # in native_functions.yaml. The dispatch entry is optional; in that 1221*da0073e9SAndroid Build Coastguard Worker # case, that is equivalent to having written: 1222*da0073e9SAndroid Build Coastguard Worker # 1223*da0073e9SAndroid Build Coastguard Worker # dispatch: 1224*da0073e9SAndroid Build Coastguard Worker # CompositeImplicitAutograd: $operator_name 1225*da0073e9SAndroid Build Coastguard Worker kernel: str 1226*da0073e9SAndroid Build Coastguard Worker # Whether or not the operator has a structured kernel implemented, for this particular backend. 1227*da0073e9SAndroid Build Coastguard Worker # For in-tree backends, they all have the same value for structured- this is listed 1228*da0073e9SAndroid Build Coastguard Worker # in native_functions.yaml. 1229*da0073e9SAndroid Build Coastguard Worker # However, external backends like XLA can indendently toggle which ops are structured. 1230*da0073e9SAndroid Build Coastguard Worker structured: bool 1231*da0073e9SAndroid Build Coastguard Worker 1232*da0073e9SAndroid Build Coastguard Worker # The namespace for kernels, default value: DEFAULT_KERNEL_NAMESPACE 1233*da0073e9SAndroid Build Coastguard Worker cpp_namespace: str 1234*da0073e9SAndroid Build Coastguard Worker 1235*da0073e9SAndroid Build Coastguard Worker def supports_symint(self) -> bool: 1236*da0073e9SAndroid Build Coastguard Worker return "_symint" in self.kernel 1237*da0073e9SAndroid Build Coastguard Worker 1238*da0073e9SAndroid Build Coastguard Worker 1239*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 1240*da0073e9SAndroid Build Coastguard Workerclass UfuncInnerLoop: 1241*da0073e9SAndroid Build Coastguard Worker name: str 1242*da0073e9SAndroid Build Coastguard Worker supported_dtypes: OrderedSet[ScalarType] 1243*da0073e9SAndroid Build Coastguard Worker # key is stored here because it affects the semantics of name, 1244*da0073e9SAndroid Build Coastguard Worker # so its helpful to have them together for further processing 1245*da0073e9SAndroid Build Coastguard Worker ufunc_key: UfuncKey 1246*da0073e9SAndroid Build Coastguard Worker 1247*da0073e9SAndroid Build Coastguard Worker @staticmethod 1248*da0073e9SAndroid Build Coastguard Worker def parse(value: str, ufunc_key: UfuncKey) -> UfuncInnerLoop: 1249*da0073e9SAndroid Build Coastguard Worker name, supported_dtypes_str = value.split(" ", 1) 1250*da0073e9SAndroid Build Coastguard Worker assert supported_dtypes_str[0] == "(" 1251*da0073e9SAndroid Build Coastguard Worker assert supported_dtypes_str[-1] == ")" 1252*da0073e9SAndroid Build Coastguard Worker supported_dtypes: OrderedSet[ScalarType] = OrderedSet() 1253*da0073e9SAndroid Build Coastguard Worker for k in supported_dtypes_str[1:-1].split(", "): 1254*da0073e9SAndroid Build Coastguard Worker supported_dtypes |= ScalarType.parse_set(k) 1255*da0073e9SAndroid Build Coastguard Worker return UfuncInnerLoop( 1256*da0073e9SAndroid Build Coastguard Worker name=name, supported_dtypes=supported_dtypes, ufunc_key=ufunc_key 1257*da0073e9SAndroid Build Coastguard Worker ) 1258*da0073e9SAndroid Build Coastguard Worker 1259*da0073e9SAndroid Build Coastguard Worker 1260*da0073e9SAndroid Build Coastguard Worker# BackendIndex represents a backend. 1261*da0073e9SAndroid Build Coastguard Worker# The BackendIndex encodes per-operator information that is potentially different 1262*da0073e9SAndroid Build Coastguard Worker# for each backend. The most obvious example is the name of the kernel 1263*da0073e9SAndroid Build Coastguard Worker# (the 'dispatch' entry in native_functions.yaml). 1264*da0073e9SAndroid Build Coastguard Worker# However, there can be other examples of different backends having different information. 1265*da0073e9SAndroid Build Coastguard Worker# External backends can choose to opt their kernels to be structured independently from in-tree backends, 1266*da0073e9SAndroid Build Coastguard Worker# which means that this information isn't inherently tied to a NativeFunction- it's different per backend. 1267*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 1268*da0073e9SAndroid Build Coastguard Workerclass BackendIndex: 1269*da0073e9SAndroid Build Coastguard Worker dispatch_key: DispatchKey 1270*da0073e9SAndroid Build Coastguard Worker # Mainly important for structured kernels, this determines which variant in the operator group is used to implement the others. 1271*da0073e9SAndroid Build Coastguard Worker # All in-tree ops use out kernels, while XLA uses functional kernels. 1272*da0073e9SAndroid Build Coastguard Worker use_out_as_primary: bool 1273*da0073e9SAndroid Build Coastguard Worker # Whether the backend requires a device guard, and device checks. 1274*da0073e9SAndroid Build Coastguard Worker # For in-tree backends, this is currently just CUDA/HIP 1275*da0073e9SAndroid Build Coastguard Worker # For out-of-tree backends, this is currently just Intel XPU 1276*da0073e9SAndroid Build Coastguard Worker device_guard: bool 1277*da0073e9SAndroid Build Coastguard Worker # Whether the backend is in-tree (CPU/CUDA) or out-of-tree (XLA) 1278*da0073e9SAndroid Build Coastguard Worker external: bool 1279*da0073e9SAndroid Build Coastguard Worker # Other backend-specific information that is on a per-operator basis 1280*da0073e9SAndroid Build Coastguard Worker index: dict[OperatorName, BackendMetadata] 1281*da0073e9SAndroid Build Coastguard Worker 1282*da0073e9SAndroid Build Coastguard Worker @staticmethod 1283*da0073e9SAndroid Build Coastguard Worker def grow_index( 1284*da0073e9SAndroid Build Coastguard Worker parent_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]], 1285*da0073e9SAndroid Build Coastguard Worker child_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]], 1286*da0073e9SAndroid Build Coastguard Worker ) -> None: 1287*da0073e9SAndroid Build Coastguard Worker for k, v in child_index.items(): 1288*da0073e9SAndroid Build Coastguard Worker for op_name, metadata in v.items(): 1289*da0073e9SAndroid Build Coastguard Worker assert ( 1290*da0073e9SAndroid Build Coastguard Worker op_name not in parent_index[k] 1291*da0073e9SAndroid Build Coastguard Worker ), f"duplicate operator {op_name} for dispatch key {k}" 1292*da0073e9SAndroid Build Coastguard Worker parent_index[k][op_name] = metadata 1293*da0073e9SAndroid Build Coastguard Worker 1294*da0073e9SAndroid Build Coastguard Worker def primary(self, g: NativeFunctionsGroup) -> NativeFunction: 1295*da0073e9SAndroid Build Coastguard Worker if self.use_out_as_primary: 1296*da0073e9SAndroid Build Coastguard Worker return g.out 1297*da0073e9SAndroid Build Coastguard Worker else: 1298*da0073e9SAndroid Build Coastguard Worker return g.functional 1299*da0073e9SAndroid Build Coastguard Worker 1300*da0073e9SAndroid Build Coastguard Worker def has_kernel(self, g: NativeFunction | NativeFunctionsGroup) -> bool: 1301*da0073e9SAndroid Build Coastguard Worker m = self.get_kernel(g) 1302*da0073e9SAndroid Build Coastguard Worker return m is not None 1303*da0073e9SAndroid Build Coastguard Worker 1304*da0073e9SAndroid Build Coastguard Worker def get_kernel( 1305*da0073e9SAndroid Build Coastguard Worker self, g: NativeFunction | NativeFunctionsGroup 1306*da0073e9SAndroid Build Coastguard Worker ) -> BackendMetadata | None: 1307*da0073e9SAndroid Build Coastguard Worker if isinstance(g, NativeFunction): 1308*da0073e9SAndroid Build Coastguard Worker f = g 1309*da0073e9SAndroid Build Coastguard Worker elif isinstance(g, NativeFunctionsGroup): 1310*da0073e9SAndroid Build Coastguard Worker f = self.primary(g) 1311*da0073e9SAndroid Build Coastguard Worker else: 1312*da0073e9SAndroid Build Coastguard Worker assert_never(g) 1313*da0073e9SAndroid Build Coastguard Worker if f.func.name not in self.index: 1314*da0073e9SAndroid Build Coastguard Worker return None 1315*da0073e9SAndroid Build Coastguard Worker return self.index[f.func.name] 1316*da0073e9SAndroid Build Coastguard Worker 1317*da0073e9SAndroid Build Coastguard Worker def native_function_class_name(self) -> str | None: 1318*da0073e9SAndroid Build Coastguard Worker if self.external: 1319*da0073e9SAndroid Build Coastguard Worker return f"{str(self.dispatch_key)}NativeFunctions" 1320*da0073e9SAndroid Build Coastguard Worker else: 1321*da0073e9SAndroid Build Coastguard Worker # TODO: This discrepancy isn't required; we could also generated 1322*da0073e9SAndroid Build Coastguard Worker # a class for in-tree kernels. It'll just require carefully 1323*da0073e9SAndroid Build Coastguard Worker # updating every kernel definition + callsite of every in-tree aten kernel. 1324*da0073e9SAndroid Build Coastguard Worker return None 1325*da0073e9SAndroid Build Coastguard Worker 1326*da0073e9SAndroid Build Coastguard Worker 1327*da0073e9SAndroid Build Coastguard Worker# The function schema is undoubtedly the most important data structure 1328*da0073e9SAndroid Build Coastguard Worker# in all of the codegen, as it defines the type signature for operators, 1329*da0073e9SAndroid Build Coastguard Worker# and most of the code generation we do is type directed (e.g., look at 1330*da0073e9SAndroid Build Coastguard Worker# the types, decide what to do. Think about how we code generate 1331*da0073e9SAndroid Build Coastguard Worker# C++ function stubs!) 1332*da0073e9SAndroid Build Coastguard Worker# 1333*da0073e9SAndroid Build Coastguard Worker# We will also see in this class the general structure for how we model 1334*da0073e9SAndroid Build Coastguard Worker# data in this code generation. A few notable properties to point out 1335*da0073e9SAndroid Build Coastguard Worker# ahead of time: 1336*da0073e9SAndroid Build Coastguard Worker# 1337*da0073e9SAndroid Build Coastguard Worker# - These dataclasses are a *lossless* representation of the strings 1338*da0073e9SAndroid Build Coastguard Worker# they are parsed from. In fact, we assert that given the 1339*da0073e9SAndroid Build Coastguard Worker# information stored in the dataclass, we can exactly reconstruct 1340*da0073e9SAndroid Build Coastguard Worker# the string we parsed from (and assert this inside the parse 1341*da0073e9SAndroid Build Coastguard Worker# definition). There are a few reasons for this: 1342*da0073e9SAndroid Build Coastguard Worker# 1343*da0073e9SAndroid Build Coastguard Worker# - If you find that it is difficult to reconstruct the string 1344*da0073e9SAndroid Build Coastguard Worker# given a dataclass, that is a clue that you are data 1345*da0073e9SAndroid Build Coastguard Worker# representation is wrong. 1346*da0073e9SAndroid Build Coastguard Worker# 1347*da0073e9SAndroid Build Coastguard Worker# - It helps ensure that all relevant information is present 1348*da0073e9SAndroid Build Coastguard Worker# in the dataclass, so that downstream users aren't tempted 1349*da0073e9SAndroid Build Coastguard Worker# to reparse the original string to get some information 1350*da0073e9SAndroid Build Coastguard Worker# that was omitted. 1351*da0073e9SAndroid Build Coastguard Worker# 1352*da0073e9SAndroid Build Coastguard Worker# - It forces you to represent the data in-memory in the same way 1353*da0073e9SAndroid Build Coastguard Worker# it is recorded textually, which makes the dataclasses easier 1354*da0073e9SAndroid Build Coastguard Worker# to understand for someone who is familiar with the 1355*da0073e9SAndroid Build Coastguard Worker# textual format. (As a tradeoff, it means you have to model 1356*da0073e9SAndroid Build Coastguard Worker# the syntax, even when it is inconvenient. But maybe that means 1357*da0073e9SAndroid Build Coastguard Worker# the syntax is bad!) If you don't understand the internal 1358*da0073e9SAndroid Build Coastguard Worker# representation, go look at the printing code to see how 1359*da0073e9SAndroid Build Coastguard Worker# it maps onto the surface syntax! 1360*da0073e9SAndroid Build Coastguard Worker# 1361*da0073e9SAndroid Build Coastguard Worker# - It makes it easy to test the parsing code, as parsing code 1362*da0073e9SAndroid Build Coastguard Worker# that is inconsistent with the string code will fail early 1363*da0073e9SAndroid Build Coastguard Worker# and loudly. (As a tradeoff, it makes the parsing code a bit 1364*da0073e9SAndroid Build Coastguard Worker# brittle (in particular, with trivial whitespace changes you 1365*da0073e9SAndroid Build Coastguard Worker# are likely to trigger an assert error). 1366*da0073e9SAndroid Build Coastguard Worker# 1367*da0073e9SAndroid Build Coastguard Worker# In general, try to make the __str__ code as simple as possible 1368*da0073e9SAndroid Build Coastguard Worker# (even at the cost of more complex parsing logic.) Additionally, 1369*da0073e9SAndroid Build Coastguard Worker# try to minimize redundancy in data representation. (Precomputed 1370*da0073e9SAndroid Build Coastguard Worker# fields are OK though: they are defined as a simple function on 1371*da0073e9SAndroid Build Coastguard Worker# the canonical representation in question.) 1372*da0073e9SAndroid Build Coastguard Worker# 1373*da0073e9SAndroid Build Coastguard Worker# - These dataclasses are all frozen; once constructed their 1374*da0073e9SAndroid Build Coastguard Worker# values never change. This makes it easy to tell where any 1375*da0073e9SAndroid Build Coastguard Worker# given data came from: just look to the constructor. As a 1376*da0073e9SAndroid Build Coastguard Worker# tradeoff, you can't easily "decorate" a schema with extra 1377*da0073e9SAndroid Build Coastguard Worker# information from a post-facto analysis. We impose this 1378*da0073e9SAndroid Build Coastguard Worker# restriction to make these structures more understandable. 1379*da0073e9SAndroid Build Coastguard Worker# 1380*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 1381*da0073e9SAndroid Build Coastguard Workerclass FunctionSchema: 1382*da0073e9SAndroid Build Coastguard Worker # The name of the operator this function schema describes. 1383*da0073e9SAndroid Build Coastguard Worker name: OperatorName 1384*da0073e9SAndroid Build Coastguard Worker 1385*da0073e9SAndroid Build Coastguard Worker arguments: Arguments 1386*da0073e9SAndroid Build Coastguard Worker 1387*da0073e9SAndroid Build Coastguard Worker # TODO: Need to handle collisions with argument names at some point 1388*da0073e9SAndroid Build Coastguard Worker returns: tuple[Return, ...] 1389*da0073e9SAndroid Build Coastguard Worker 1390*da0073e9SAndroid Build Coastguard Worker @property 1391*da0073e9SAndroid Build Coastguard Worker def is_mutable(self) -> bool: 1392*da0073e9SAndroid Build Coastguard Worker def is_write(arg: Argument) -> bool: 1393*da0073e9SAndroid Build Coastguard Worker if arg.annotation is None: 1394*da0073e9SAndroid Build Coastguard Worker return False 1395*da0073e9SAndroid Build Coastguard Worker return arg.annotation.is_write 1396*da0073e9SAndroid Build Coastguard Worker 1397*da0073e9SAndroid Build Coastguard Worker # Corresponds to torch._C._FunctionSchema.is_mutable 1398*da0073e9SAndroid Build Coastguard Worker # See aten/src/ATen/core/function_schema.h (keep these in sync) 1399*da0073e9SAndroid Build Coastguard Worker return any(is_write(a) for a in self.arguments.flat_all) 1400*da0073e9SAndroid Build Coastguard Worker 1401*da0073e9SAndroid Build Coastguard Worker def schema_order_arguments(self) -> Iterator[Argument]: 1402*da0073e9SAndroid Build Coastguard Worker return itertools.chain( 1403*da0073e9SAndroid Build Coastguard Worker self.arguments.flat_positional, 1404*da0073e9SAndroid Build Coastguard Worker self.arguments.flat_kwarg_only, 1405*da0073e9SAndroid Build Coastguard Worker self.arguments.out, 1406*da0073e9SAndroid Build Coastguard Worker ) 1407*da0073e9SAndroid Build Coastguard Worker 1408*da0073e9SAndroid Build Coastguard Worker decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)") 1409*da0073e9SAndroid Build Coastguard Worker 1410*da0073e9SAndroid Build Coastguard Worker @staticmethod 1411*da0073e9SAndroid Build Coastguard Worker def parse(func: str) -> FunctionSchema: 1412*da0073e9SAndroid Build Coastguard Worker # We should probably get a proper parser here 1413*da0073e9SAndroid Build Coastguard Worker decls = FunctionSchema.decl_re.findall(func) 1414*da0073e9SAndroid Build Coastguard Worker assert len(decls) == 1, f"Invalid function schema: {func}" 1415*da0073e9SAndroid Build Coastguard Worker ops, args, return_decl = decls[0] 1416*da0073e9SAndroid Build Coastguard Worker name = OperatorName.parse(ops) 1417*da0073e9SAndroid Build Coastguard Worker arguments = Arguments.parse(args) 1418*da0073e9SAndroid Build Coastguard Worker returns = parse_returns(return_decl) 1419*da0073e9SAndroid Build Coastguard Worker r = FunctionSchema(name=name, arguments=arguments, returns=returns) 1420*da0073e9SAndroid Build Coastguard Worker assert str(r) == func, f"{str(r)} != {func}" 1421*da0073e9SAndroid Build Coastguard Worker return r 1422*da0073e9SAndroid Build Coastguard Worker 1423*da0073e9SAndroid Build Coastguard Worker def returns_are_aliased(self) -> bool: 1424*da0073e9SAndroid Build Coastguard Worker # We assert earlier that schemas can't have a mix of aliased and non-aliased returns 1425*da0073e9SAndroid Build Coastguard Worker return any( 1426*da0073e9SAndroid Build Coastguard Worker r 1427*da0073e9SAndroid Build Coastguard Worker for r in self.returns 1428*da0073e9SAndroid Build Coastguard Worker if r.annotation is not None and r.annotation.is_write 1429*da0073e9SAndroid Build Coastguard Worker ) 1430*da0073e9SAndroid Build Coastguard Worker 1431*da0073e9SAndroid Build Coastguard Worker def __post_init__(self) -> None: 1432*da0073e9SAndroid Build Coastguard Worker for arg, ret in zip(self.arguments.out, self.returns): 1433*da0073e9SAndroid Build Coastguard Worker assert arg.annotation == ret.annotation, ( 1434*da0073e9SAndroid Build Coastguard Worker "Out arguments must have matching return Tensor; furthermore, " 1435*da0073e9SAndroid Build Coastguard Worker "the ith-argument needs to correspond to the ith return" 1436*da0073e9SAndroid Build Coastguard Worker ) 1437*da0073e9SAndroid Build Coastguard Worker # We also enforce that if you have any mutable, positional args, then they are not returned. 1438*da0073e9SAndroid Build Coastguard Worker # This makes it easier to group these functions properly with their functional/out= counterparts. 1439*da0073e9SAndroid Build Coastguard Worker for a in self.arguments.post_self_positional_mutable: 1440*da0073e9SAndroid Build Coastguard Worker assert not any( 1441*da0073e9SAndroid Build Coastguard Worker a.annotation == r.annotation for r in self.returns 1442*da0073e9SAndroid Build Coastguard Worker ), f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}" 1443*da0073e9SAndroid Build Coastguard Worker # Invariant: we expect out arguments to appear as keyword arguments in the schema. 1444*da0073e9SAndroid Build Coastguard Worker # This means that all mutable returns should be aliased to a keyword argument 1445*da0073e9SAndroid Build Coastguard Worker # (except for "self", which we explicitly don't treat as an out argument because of its use in methods) 1446*da0073e9SAndroid Build Coastguard Worker # See Note [is_out_fn] 1447*da0073e9SAndroid Build Coastguard Worker out_and_self = list(self.arguments.out) + [ 1448*da0073e9SAndroid Build Coastguard Worker arg for arg in self.arguments.flat_positional if arg.name == "self" 1449*da0073e9SAndroid Build Coastguard Worker ] 1450*da0073e9SAndroid Build Coastguard Worker mutable_returns = [ 1451*da0073e9SAndroid Build Coastguard Worker ret 1452*da0073e9SAndroid Build Coastguard Worker for ret in self.returns 1453*da0073e9SAndroid Build Coastguard Worker if ret.annotation is not None and ret.annotation.is_write 1454*da0073e9SAndroid Build Coastguard Worker ] 1455*da0073e9SAndroid Build Coastguard Worker immutable_returns = [ 1456*da0073e9SAndroid Build Coastguard Worker ret 1457*da0073e9SAndroid Build Coastguard Worker for ret in self.returns 1458*da0073e9SAndroid Build Coastguard Worker if ret.annotation is None or not ret.annotation.is_write 1459*da0073e9SAndroid Build Coastguard Worker ] 1460*da0073e9SAndroid Build Coastguard Worker # Some assertions: We don't want any functions with a return type of "-> (Tensor(a!), Tensor)", 1461*da0073e9SAndroid Build Coastguard Worker # because: 1462*da0073e9SAndroid Build Coastguard Worker # (1) It's more annoying to handle properly 1463*da0073e9SAndroid Build Coastguard Worker # (2) It's unnecessary - you can't method-chain on the first (mutated) output because it's part of a tuple. 1464*da0073e9SAndroid Build Coastguard Worker # Instead, we expect the (a!) argument to not be returned. 1465*da0073e9SAndroid Build Coastguard Worker assert ( 1466*da0073e9SAndroid Build Coastguard Worker len(mutable_returns) == 0 or len(immutable_returns) == 0 1467*da0073e9SAndroid Build Coastguard Worker ), f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}" 1468*da0073e9SAndroid Build Coastguard Worker for ret in mutable_returns: 1469*da0073e9SAndroid Build Coastguard Worker assert any(ret.annotation == arg.annotation for arg in out_and_self), ( 1470*da0073e9SAndroid Build Coastguard Worker 'All mutable returns must be aliased either to a keyword argument, or to "self". ' 1471*da0073e9SAndroid Build Coastguard Worker "Did you forget to mark an out argument as keyword-only?" 1472*da0073e9SAndroid Build Coastguard Worker ) 1473*da0073e9SAndroid Build Coastguard Worker if self.arguments.out: 1474*da0073e9SAndroid Build Coastguard Worker # out= ops that return their mutable inputs are only really useful for method chaining. 1475*da0073e9SAndroid Build Coastguard Worker # And method chaining is only really useful if the thing you're returning is a plain Tensor. 1476*da0073e9SAndroid Build Coastguard Worker # So ideally, we'd enforce that out= ops with a single plain mutable tensor should return the tensor, 1477*da0073e9SAndroid Build Coastguard Worker # and all other types of out= op schemas should return void. 1478*da0073e9SAndroid Build Coastguard Worker # There are a bunch of existing out= ops that return tuples of tensors though, so we're stuck with allowing that. 1479*da0073e9SAndroid Build Coastguard Worker if any(a.type != BaseType(BaseTy.Tensor) for a in self.arguments.out): 1480*da0073e9SAndroid Build Coastguard Worker assert ( 1481*da0073e9SAndroid Build Coastguard Worker len(self.returns) == 0 1482*da0073e9SAndroid Build Coastguard Worker ), "out= ops that accept tensor lists as out arguments " 1483*da0073e9SAndroid Build Coastguard Worker "are expected to have no return type (since you can't do method chaining on them)" 1484*da0073e9SAndroid Build Coastguard Worker else: 1485*da0073e9SAndroid Build Coastguard Worker # mutable keyword arguments whose name has _scratch_ prefix are 1486*da0073e9SAndroid Build Coastguard Worker # scratch tensors for memory planning and should not be returned 1487*da0073e9SAndroid Build Coastguard Worker assert len( 1488*da0073e9SAndroid Build Coastguard Worker [ 1489*da0073e9SAndroid Build Coastguard Worker arg 1490*da0073e9SAndroid Build Coastguard Worker for arg in self.arguments.out 1491*da0073e9SAndroid Build Coastguard Worker if not arg.name.startswith("_scratch_") 1492*da0073e9SAndroid Build Coastguard Worker ] 1493*da0073e9SAndroid Build Coastguard Worker ) == len( 1494*da0073e9SAndroid Build Coastguard Worker self.returns 1495*da0073e9SAndroid Build Coastguard Worker ), "Must return as many arguments as there are out arguments, or no return at all" 1496*da0073e9SAndroid Build Coastguard Worker 1497*da0073e9SAndroid Build Coastguard Worker if self.name.name.inplace: 1498*da0073e9SAndroid Build Coastguard Worker self_a = self.arguments.self_arg 1499*da0073e9SAndroid Build Coastguard Worker assert ( 1500*da0073e9SAndroid Build Coastguard Worker self_a 1501*da0073e9SAndroid Build Coastguard Worker and self_a.argument.annotation 1502*da0073e9SAndroid Build Coastguard Worker and self_a.argument.annotation.is_write 1503*da0073e9SAndroid Build Coastguard Worker ) 1504*da0073e9SAndroid Build Coastguard Worker if self_a.argument.type == BaseType(BaseTy.Tensor): 1505*da0073e9SAndroid Build Coastguard Worker # All inplace ops with an ordinary `Tensor self` argument should return self, 1506*da0073e9SAndroid Build Coastguard Worker # to allow for method chaining. 1507*da0073e9SAndroid Build Coastguard Worker assert ( 1508*da0073e9SAndroid Build Coastguard Worker len(self.returns) == 1 1509*da0073e9SAndroid Build Coastguard Worker and self.returns[0].annotation == self_a.argument.annotation 1510*da0073e9SAndroid Build Coastguard Worker ) 1511*da0073e9SAndroid Build Coastguard Worker else: 1512*da0073e9SAndroid Build Coastguard Worker # You can't method chain on non-tensor self arguments though (like a List[Tensor]) 1513*da0073e9SAndroid Build Coastguard Worker # so in all other cases we expect the return type to be none. 1514*da0073e9SAndroid Build Coastguard Worker assert len(self.returns) == 0 1515*da0073e9SAndroid Build Coastguard Worker 1516*da0073e9SAndroid Build Coastguard Worker if self.arguments.tensor_options is not None: 1517*da0073e9SAndroid Build Coastguard Worker assert self.kind() == SchemaKind.functional, ( 1518*da0073e9SAndroid Build Coastguard Worker "Found an operator that is not functional or out variant, but has tensor options arguments." 1519*da0073e9SAndroid Build Coastguard Worker "This is not allowed- tensor options arguments are only allowed for factory functions." 1520*da0073e9SAndroid Build Coastguard Worker f"schema: {str(self)}" 1521*da0073e9SAndroid Build Coastguard Worker ) 1522*da0073e9SAndroid Build Coastguard Worker if self.is_functional_fn(): 1523*da0073e9SAndroid Build Coastguard Worker assert self.kind() == SchemaKind.functional, ( 1524*da0073e9SAndroid Build Coastguard Worker "Found an operator that is not functional, but its overload contains the string 'functional'." 1525*da0073e9SAndroid Build Coastguard Worker "This is a special keyword in the codegen, please use a different overload name." 1526*da0073e9SAndroid Build Coastguard Worker f"schema: {str(self)}" 1527*da0073e9SAndroid Build Coastguard Worker ) 1528*da0073e9SAndroid Build Coastguard Worker 1529*da0073e9SAndroid Build Coastguard Worker def is_functional_fn(self) -> bool: 1530*da0073e9SAndroid Build Coastguard Worker return "functional" in self.name.overload_name 1531*da0073e9SAndroid Build Coastguard Worker 1532*da0073e9SAndroid Build Coastguard Worker def is_out_fn(self) -> bool: 1533*da0073e9SAndroid Build Coastguard Worker # Note [is_out_fn] 1534*da0073e9SAndroid Build Coastguard Worker # 1535*da0073e9SAndroid Build Coastguard Worker # out functions are the variants which take an explicit out= argument 1536*da0073e9SAndroid Build Coastguard Worker # to populate into. We need to know if a schema corresponds to an 1537*da0073e9SAndroid Build Coastguard Worker # out function for several reasons: 1538*da0073e9SAndroid Build Coastguard Worker # 1539*da0073e9SAndroid Build Coastguard Worker # - They codegen differently in C++ API 1540*da0073e9SAndroid Build Coastguard Worker # - codegen to at::add_out rather than at::add 1541*da0073e9SAndroid Build Coastguard Worker # - out argument is moved to front of C++ argument list 1542*da0073e9SAndroid Build Coastguard Worker # 1543*da0073e9SAndroid Build Coastguard Worker # out functions are DEFINED to be any function with a keyword-only 1544*da0073e9SAndroid Build Coastguard Worker # argument that is mutable. In principle, this could lead to a 1545*da0073e9SAndroid Build Coastguard Worker # false positive if you define a function that mutates a 1546*da0073e9SAndroid Build Coastguard Worker # kwarg only argument, but this isn't the "true" output of this 1547*da0073e9SAndroid Build Coastguard Worker # function. A more robust definition that would work in this 1548*da0073e9SAndroid Build Coastguard Worker # case would also look at: 1549*da0073e9SAndroid Build Coastguard Worker # 1550*da0073e9SAndroid Build Coastguard Worker # - The output types. Out functions take in the arguments 1551*da0073e9SAndroid Build Coastguard Worker # they mutate and then return them again; this is sort 1552*da0073e9SAndroid Build Coastguard Worker # of "definitionally" what makes something an out function. 1553*da0073e9SAndroid Build Coastguard Worker # Historically, we DO check this for consistency. 1554*da0073e9SAndroid Build Coastguard Worker # - Correspondence with pure variant. An out function 1555*da0073e9SAndroid Build Coastguard Worker # should have a signature equivalent to its pure variant, 1556*da0073e9SAndroid Build Coastguard Worker # but just with extra kwargs for the output elements. This 1557*da0073e9SAndroid Build Coastguard Worker # is difficult to actually check for and historically 1558*da0073e9SAndroid Build Coastguard Worker # we only do this check in tools/ 1559*da0073e9SAndroid Build Coastguard Worker return bool(self.arguments.out) 1560*da0073e9SAndroid Build Coastguard Worker 1561*da0073e9SAndroid Build Coastguard Worker def kind(self) -> SchemaKind: 1562*da0073e9SAndroid Build Coastguard Worker """ 1563*da0073e9SAndroid Build Coastguard Worker What kind of schema is this? A functional schema is one 1564*da0073e9SAndroid Build Coastguard Worker that returns a newly allocated output; an inplace schema 1565*da0073e9SAndroid Build Coastguard Worker modifies the self argument inplace; an out schema writes 1566*da0073e9SAndroid Build Coastguard Worker the result into an explicitly provided out argument. 1567*da0073e9SAndroid Build Coastguard Worker """ 1568*da0073e9SAndroid Build Coastguard Worker is_out = bool(self.arguments.out) 1569*da0073e9SAndroid Build Coastguard Worker is_scratch = bool( 1570*da0073e9SAndroid Build Coastguard Worker [arg for arg in self.arguments.out if arg.name.startswith("_scratch_")] 1571*da0073e9SAndroid Build Coastguard Worker ) 1572*da0073e9SAndroid Build Coastguard Worker is_inplace = self.name.name.inplace 1573*da0073e9SAndroid Build Coastguard Worker is_mutable = any( 1574*da0073e9SAndroid Build Coastguard Worker a.annotation is not None and a.annotation.is_write 1575*da0073e9SAndroid Build Coastguard Worker for a in self.arguments.post_self_positional 1576*da0073e9SAndroid Build Coastguard Worker ) 1577*da0073e9SAndroid Build Coastguard Worker assert not (is_out and is_inplace) 1578*da0073e9SAndroid Build Coastguard Worker # out= and inplace schemas can also have post_self_positional mutable args, 1579*da0073e9SAndroid Build Coastguard Worker # but we give precedence to out= and inplace when deciding the schema kind. 1580*da0073e9SAndroid Build Coastguard Worker # Tradeoff: we probably don't want to have to teach codegen that looks at inplace ops 1581*da0073e9SAndroid Build Coastguard Worker # to also worry about mutable post_self_positional arguments, 1582*da0073e9SAndroid Build Coastguard Worker # but it seems like a much bigger lift to classify them has having a new schema kind. 1583*da0073e9SAndroid Build Coastguard Worker # The number of ops that fit in this strange category is small enough that 1584*da0073e9SAndroid Build Coastguard Worker # we can probably manually write code for them instead of forcing the codegen to handle them. 1585*da0073e9SAndroid Build Coastguard Worker if is_inplace: 1586*da0073e9SAndroid Build Coastguard Worker return SchemaKind.inplace 1587*da0073e9SAndroid Build Coastguard Worker elif is_scratch: 1588*da0073e9SAndroid Build Coastguard Worker assert ( 1589*da0073e9SAndroid Build Coastguard Worker is_out 1590*da0073e9SAndroid Build Coastguard Worker ), "invariant: all scratch operators are expected to be out= operators too" 1591*da0073e9SAndroid Build Coastguard Worker return SchemaKind.scratch 1592*da0073e9SAndroid Build Coastguard Worker elif is_out: 1593*da0073e9SAndroid Build Coastguard Worker assert ( 1594*da0073e9SAndroid Build Coastguard Worker not is_scratch 1595*da0073e9SAndroid Build Coastguard Worker ), "We should not categorize a scratch op as an out variant. Check if the order of if statements are expected!" 1596*da0073e9SAndroid Build Coastguard Worker return SchemaKind.out 1597*da0073e9SAndroid Build Coastguard Worker elif is_mutable: 1598*da0073e9SAndroid Build Coastguard Worker return SchemaKind.mutable 1599*da0073e9SAndroid Build Coastguard Worker else: 1600*da0073e9SAndroid Build Coastguard Worker return SchemaKind.functional 1601*da0073e9SAndroid Build Coastguard Worker 1602*da0073e9SAndroid Build Coastguard Worker # For every return: 1603*da0073e9SAndroid Build Coastguard Worker # - If the return aliases an input, we return the input name 1604*da0073e9SAndroid Build Coastguard Worker # - Otherwise, we return None. 1605*da0073e9SAndroid Build Coastguard Worker # If return names were enforced to be consistent with aliasing information, then we wouldn't need this. 1606*da0073e9SAndroid Build Coastguard Worker def aliased_return_names(self) -> list[str | None]: 1607*da0073e9SAndroid Build Coastguard Worker outs: list[str | None] = [] 1608*da0073e9SAndroid Build Coastguard Worker for r in self.returns: 1609*da0073e9SAndroid Build Coastguard Worker aliased_args = [ 1610*da0073e9SAndroid Build Coastguard Worker a 1611*da0073e9SAndroid Build Coastguard Worker for a in self.arguments.flat_all 1612*da0073e9SAndroid Build Coastguard Worker if a.annotation is not None and a.annotation == r.annotation 1613*da0073e9SAndroid Build Coastguard Worker ] 1614*da0073e9SAndroid Build Coastguard Worker if len(aliased_args) == 0: 1615*da0073e9SAndroid Build Coastguard Worker outs.append(None) 1616*da0073e9SAndroid Build Coastguard Worker elif len(aliased_args) == 1: 1617*da0073e9SAndroid Build Coastguard Worker outs.append(aliased_args[0].name) 1618*da0073e9SAndroid Build Coastguard Worker else: 1619*da0073e9SAndroid Build Coastguard Worker aliased_names = ", ".join(a.name for a in aliased_args) 1620*da0073e9SAndroid Build Coastguard Worker raise AssertionError( 1621*da0073e9SAndroid Build Coastguard Worker f"Found a return ({r.name})that aliases multiple inputs ({aliased_names})" 1622*da0073e9SAndroid Build Coastguard Worker ) 1623*da0073e9SAndroid Build Coastguard Worker return outs 1624*da0073e9SAndroid Build Coastguard Worker 1625*da0073e9SAndroid Build Coastguard Worker def signature( 1626*da0073e9SAndroid Build Coastguard Worker self, 1627*da0073e9SAndroid Build Coastguard Worker *, 1628*da0073e9SAndroid Build Coastguard Worker strip_default: bool = False, 1629*da0073e9SAndroid Build Coastguard Worker strip_view_copy_name: bool = False, 1630*da0073e9SAndroid Build Coastguard Worker keep_return_names: bool = False, 1631*da0073e9SAndroid Build Coastguard Worker ) -> FunctionSchema: 1632*da0073e9SAndroid Build Coastguard Worker """ 1633*da0073e9SAndroid Build Coastguard Worker Certain schemas are 'related', in that they are simply 1634*da0073e9SAndroid Build Coastguard Worker inplace/out/functional versions of the same function. This method 1635*da0073e9SAndroid Build Coastguard Worker factors these schemas into the "core" functional signature which 1636*da0073e9SAndroid Build Coastguard Worker is equal across all versions. 1637*da0073e9SAndroid Build Coastguard Worker 1638*da0073e9SAndroid Build Coastguard Worker Here is what normalization happens to the schema to convert 1639*da0073e9SAndroid Build Coastguard Worker it to a signature: 1640*da0073e9SAndroid Build Coastguard Worker - The overload name is stripped (name is retained, since 1641*da0073e9SAndroid Build Coastguard Worker it expresses semantic content about what the function does) 1642*da0073e9SAndroid Build Coastguard Worker - Inplace is set False 1643*da0073e9SAndroid Build Coastguard Worker - Out arguments are stripped 1644*da0073e9SAndroid Build Coastguard Worker - Mutable post_self_positional args are converted to returns 1645*da0073e9SAndroid Build Coastguard Worker - Mutability annotations are stripped (this is sound 1646*da0073e9SAndroid Build Coastguard Worker because you cannot overload on mutability annotation) 1647*da0073e9SAndroid Build Coastguard Worker - Return names are stripped since they are not overloadable and 1648*da0073e9SAndroid Build Coastguard Worker some variants have return names but some not 1649*da0073e9SAndroid Build Coastguard Worker - TensorOptions are dropped 1650*da0073e9SAndroid Build Coastguard Worker because out= variants of factory functions don't include them 1651*da0073e9SAndroid Build Coastguard Worker (and we want to be able to pair up factory functions with their out variants) 1652*da0073e9SAndroid Build Coastguard Worker 1653*da0073e9SAndroid Build Coastguard Worker Finally, we want to be able to pair up related "view" and their 1654*da0073e9SAndroid Build Coastguard Worker corresponding "view_copy" operators. We do this by optionally 1655*da0073e9SAndroid Build Coastguard Worker stripping the trailing "_copy" from the base name. 1656*da0073e9SAndroid Build Coastguard Worker 1657*da0073e9SAndroid Build Coastguard Worker Example of a mutable op before and after: 1658*da0073e9SAndroid Build Coastguard Worker 1659*da0073e9SAndroid Build Coastguard Worker f.func (Mutable operator): 1660*da0073e9SAndroid Build Coastguard Worker _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # noqa: B950 1661*da0073e9SAndroid Build Coastguard Worker 1662*da0073e9SAndroid Build Coastguard Worker f.func (Corresponding functional operator): 1663*da0073e9SAndroid Build Coastguard Worker _fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out) # noqa: B950 1664*da0073e9SAndroid Build Coastguard Worker 1665*da0073e9SAndroid Build Coastguard Worker f.func.signature() output: 1666*da0073e9SAndroid Build Coastguard Worker _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) # noqa: B950 1667*da0073e9SAndroid Build Coastguard Worker """ 1668*da0073e9SAndroid Build Coastguard Worker 1669*da0073e9SAndroid Build Coastguard Worker def strip_ret_annotation(r: Return) -> Return: 1670*da0073e9SAndroid Build Coastguard Worker return Return( 1671*da0073e9SAndroid Build Coastguard Worker name=r.name if keep_return_names else None, 1672*da0073e9SAndroid Build Coastguard Worker type=r.type, 1673*da0073e9SAndroid Build Coastguard Worker annotation=None, 1674*da0073e9SAndroid Build Coastguard Worker ) 1675*da0073e9SAndroid Build Coastguard Worker 1676*da0073e9SAndroid Build Coastguard Worker base_name = self.name.name.base 1677*da0073e9SAndroid Build Coastguard Worker if strip_view_copy_name: 1678*da0073e9SAndroid Build Coastguard Worker if base_name.endswith("_copy"): 1679*da0073e9SAndroid Build Coastguard Worker base_name = base_name.replace("_copy", "") 1680*da0073e9SAndroid Build Coastguard Worker elif base_name.endswith("_scatter"): 1681*da0073e9SAndroid Build Coastguard Worker base_name = base_name.replace("scatter", "inverse") 1682*da0073e9SAndroid Build Coastguard Worker 1683*da0073e9SAndroid Build Coastguard Worker # find mutable inputs that are not originally returned, and convert them to returns 1684*da0073e9SAndroid Build Coastguard Worker returns_from_mutable_inputs = tuple( 1685*da0073e9SAndroid Build Coastguard Worker # When we're grouping functions we strip the return names, 1686*da0073e9SAndroid Build Coastguard Worker # but when we're generating the actual functional variants then we follow 1687*da0073e9SAndroid Build Coastguard Worker # a convention for what to name the returns 1688*da0073e9SAndroid Build Coastguard Worker Return( 1689*da0073e9SAndroid Build Coastguard Worker name=f"{a.name}_out" if keep_return_names else None, 1690*da0073e9SAndroid Build Coastguard Worker type=a.type, 1691*da0073e9SAndroid Build Coastguard Worker annotation=None, 1692*da0073e9SAndroid Build Coastguard Worker ) 1693*da0073e9SAndroid Build Coastguard Worker for a in itertools.chain( 1694*da0073e9SAndroid Build Coastguard Worker # Order is important here (otherwise e.g. inplace with mutable args 1695*da0073e9SAndroid Build Coastguard Worker # and out= with mutable args won't have the same signature) 1696*da0073e9SAndroid Build Coastguard Worker [self.arguments.self_arg.argument] 1697*da0073e9SAndroid Build Coastguard Worker if self.arguments.self_arg is not None 1698*da0073e9SAndroid Build Coastguard Worker else [], 1699*da0073e9SAndroid Build Coastguard Worker self.arguments.out, 1700*da0073e9SAndroid Build Coastguard Worker self.arguments.post_self_positional, 1701*da0073e9SAndroid Build Coastguard Worker ) 1702*da0073e9SAndroid Build Coastguard Worker if a.annotation is not None 1703*da0073e9SAndroid Build Coastguard Worker and a.annotation.is_write 1704*da0073e9SAndroid Build Coastguard Worker and not any(a.annotation == r.annotation for r in self.returns) 1705*da0073e9SAndroid Build Coastguard Worker ) 1706*da0073e9SAndroid Build Coastguard Worker original_returns = tuple(map(strip_ret_annotation, self.returns)) 1707*da0073e9SAndroid Build Coastguard Worker # Ordering is important here. We expect the "mutable input" returns to come last. 1708*da0073e9SAndroid Build Coastguard Worker returns = original_returns + returns_from_mutable_inputs 1709*da0073e9SAndroid Build Coastguard Worker 1710*da0073e9SAndroid Build Coastguard Worker args_sig = self.arguments.signature(strip_default=strip_default) 1711*da0073e9SAndroid Build Coastguard Worker # See Note [bernoulli.p schema] 1712*da0073e9SAndroid Build Coastguard Worker if str(self.name) == "bernoulli.p": 1713*da0073e9SAndroid Build Coastguard Worker args_sig = Arguments.parse(str(args_sig).replace("float p", "float p=0.5")) 1714*da0073e9SAndroid Build Coastguard Worker 1715*da0073e9SAndroid Build Coastguard Worker return FunctionSchema( 1716*da0073e9SAndroid Build Coastguard Worker name=OperatorName( 1717*da0073e9SAndroid Build Coastguard Worker name=BaseOperatorName( 1718*da0073e9SAndroid Build Coastguard Worker base=base_name, 1719*da0073e9SAndroid Build Coastguard Worker inplace=False, 1720*da0073e9SAndroid Build Coastguard Worker dunder_method=self.name.name.dunder_method, 1721*da0073e9SAndroid Build Coastguard Worker ), 1722*da0073e9SAndroid Build Coastguard Worker overload_name="", # stripped 1723*da0073e9SAndroid Build Coastguard Worker ), 1724*da0073e9SAndroid Build Coastguard Worker arguments=args_sig, 1725*da0073e9SAndroid Build Coastguard Worker returns=returns, 1726*da0073e9SAndroid Build Coastguard Worker ) 1727*da0073e9SAndroid Build Coastguard Worker 1728*da0073e9SAndroid Build Coastguard Worker def view_signature(self) -> FunctionSchema: 1729*da0073e9SAndroid Build Coastguard Worker return self.signature(strip_view_copy_name=True) 1730*da0073e9SAndroid Build Coastguard Worker 1731*da0073e9SAndroid Build Coastguard Worker def with_name(self, name: OperatorName) -> FunctionSchema: 1732*da0073e9SAndroid Build Coastguard Worker return FunctionSchema( 1733*da0073e9SAndroid Build Coastguard Worker name=name, 1734*da0073e9SAndroid Build Coastguard Worker arguments=self.arguments, 1735*da0073e9SAndroid Build Coastguard Worker returns=self.returns, 1736*da0073e9SAndroid Build Coastguard Worker ) 1737*da0073e9SAndroid Build Coastguard Worker 1738*da0073e9SAndroid Build Coastguard Worker @property 1739*da0073e9SAndroid Build Coastguard Worker def modifies_arguments(self) -> bool: 1740*da0073e9SAndroid Build Coastguard Worker return self.kind() in [SchemaKind.inplace, SchemaKind.out, SchemaKind.mutable] 1741*da0073e9SAndroid Build Coastguard Worker 1742*da0073e9SAndroid Build Coastguard Worker def has_symint(self) -> bool: 1743*da0073e9SAndroid Build Coastguard Worker return self.arguments.has_symint_arg() 1744*da0073e9SAndroid Build Coastguard Worker 1745*da0073e9SAndroid Build Coastguard Worker def __str__(self) -> str: 1746*da0073e9SAndroid Build Coastguard Worker all_arguments_str = str(self.arguments) 1747*da0073e9SAndroid Build Coastguard Worker if len(self.returns) == 1: 1748*da0073e9SAndroid Build Coastguard Worker returns = str(self.returns[0]) # omit parentheses 1749*da0073e9SAndroid Build Coastguard Worker else: 1750*da0073e9SAndroid Build Coastguard Worker returns = "(" + ", ".join(map(str, self.returns)) + ")" 1751*da0073e9SAndroid Build Coastguard Worker return f"{self.name}({all_arguments_str}) -> {returns}" 1752*da0073e9SAndroid Build Coastguard Worker 1753*da0073e9SAndroid Build Coastguard Worker 1754*da0073e9SAndroid Build Coastguard Worker# Here is the rest of the data model, described more briefly. 1755*da0073e9SAndroid Build Coastguard Worker 1756*da0073e9SAndroid Build Coastguard Worker 1757*da0073e9SAndroid Build Coastguard Worker# Simplified version for what actually shows up in built-ins. 1758*da0073e9SAndroid Build Coastguard Worker# Look at alias_info.h for expanded syntax. If you need the structure, 1759*da0073e9SAndroid Build Coastguard Worker# you also need to make this structure recursive so it can be lined 1760*da0073e9SAndroid Build Coastguard Worker# up with the type components too. For primitives this isn't really 1761*da0073e9SAndroid Build Coastguard Worker# necessary 1762*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 1763*da0073e9SAndroid Build Coastguard Workerclass Annotation: 1764*da0073e9SAndroid Build Coastguard Worker # Typically only has one element. Not actually a set so 1765*da0073e9SAndroid Build Coastguard Worker # we can conveniently assume it is canonically ordered 1766*da0073e9SAndroid Build Coastguard Worker alias_set: tuple[str, ...] 1767*da0073e9SAndroid Build Coastguard Worker is_write: bool 1768*da0073e9SAndroid Build Coastguard Worker alias_set_after: tuple[str, ...] 1769*da0073e9SAndroid Build Coastguard Worker 1770*da0073e9SAndroid Build Coastguard Worker @staticmethod 1771*da0073e9SAndroid Build Coastguard Worker def parse(ann: str) -> Annotation: 1772*da0073e9SAndroid Build Coastguard Worker # TODO: implement a proper parser if this gets more ugly 1773*da0073e9SAndroid Build Coastguard Worker # Regex Explanation: 1774*da0073e9SAndroid Build Coastguard Worker # Example: "a! -> a|b" 1775*da0073e9SAndroid Build Coastguard Worker # Group #1: alias before optional '|', required. Matches the first 1776*da0073e9SAndroid Build Coastguard Worker # character 'a' in the example 1777*da0073e9SAndroid Build Coastguard Worker # Group #2: optional alias set after optional '|', matches empty string 1778*da0073e9SAndroid Build Coastguard Worker # in the example 1779*da0073e9SAndroid Build Coastguard Worker # Group #3: optional "is write" flag, matches '!' in the example. 1780*da0073e9SAndroid Build Coastguard Worker # Group #4: optional section containing arrow, matches " -> a|b" in the 1781*da0073e9SAndroid Build Coastguard Worker # example. 1782*da0073e9SAndroid Build Coastguard Worker # Group #5: optional alias after set, supports wildcard, matches "a|b" 1783*da0073e9SAndroid Build Coastguard Worker # in the example. 1784*da0073e9SAndroid Build Coastguard Worker # Group #6: optional sub-section of alias after set, matches "|b" in the 1785*da0073e9SAndroid Build Coastguard Worker # example. 1786*da0073e9SAndroid Build Coastguard Worker m = re.match(r"^([a-z])(\|[a-z])*(!?)( -> (\*|[a-z](\|[a-z])*))?$", ann) 1787*da0073e9SAndroid Build Coastguard Worker 1788*da0073e9SAndroid Build Coastguard Worker assert m is not None, f"unrecognized alias annotation {ann}" 1789*da0073e9SAndroid Build Coastguard Worker before_alias = m.group(1) + (m.group(2) if m.group(2) else "") 1790*da0073e9SAndroid Build Coastguard Worker alias_set = tuple(before_alias.split("|")) 1791*da0073e9SAndroid Build Coastguard Worker is_write = m.group(3) == "!" 1792*da0073e9SAndroid Build Coastguard Worker assert not ( 1793*da0073e9SAndroid Build Coastguard Worker is_write and len(alias_set) > 1 1794*da0073e9SAndroid Build Coastguard Worker ), f"alias set larger than 1 is not mutable, got {ann} instead." 1795*da0073e9SAndroid Build Coastguard Worker after_set = tuple(m.group(5).split("|")) if m.group(5) else () 1796*da0073e9SAndroid Build Coastguard Worker assert not ( 1797*da0073e9SAndroid Build Coastguard Worker len(before_alias) > 1 and len(after_set) > 1 1798*da0073e9SAndroid Build Coastguard Worker ), f"before alias set and after alias set cannot be larger than 1 at the same time, got {ann} instead." 1799*da0073e9SAndroid Build Coastguard Worker r = Annotation( 1800*da0073e9SAndroid Build Coastguard Worker alias_set=alias_set, is_write=is_write, alias_set_after=after_set 1801*da0073e9SAndroid Build Coastguard Worker ) 1802*da0073e9SAndroid Build Coastguard Worker assert str(r) == ann, f"{r} != {ann}" 1803*da0073e9SAndroid Build Coastguard Worker return r 1804*da0073e9SAndroid Build Coastguard Worker 1805*da0073e9SAndroid Build Coastguard Worker def __str__(self) -> str: 1806*da0073e9SAndroid Build Coastguard Worker alias_set = "|".join(self.alias_set) 1807*da0073e9SAndroid Build Coastguard Worker if self.is_write: 1808*da0073e9SAndroid Build Coastguard Worker alias_set = f"{alias_set}!" 1809*da0073e9SAndroid Build Coastguard Worker alias_set_after = "|".join(self.alias_set_after) 1810*da0073e9SAndroid Build Coastguard Worker if alias_set_after: 1811*da0073e9SAndroid Build Coastguard Worker alias_set = f'{alias_set}{" -> "}{alias_set_after}' 1812*da0073e9SAndroid Build Coastguard Worker return alias_set 1813*da0073e9SAndroid Build Coastguard Worker 1814*da0073e9SAndroid Build Coastguard Worker 1815*da0073e9SAndroid Build Coastguard Worker# The base class for the type system. This is also loosely modeled 1816*da0073e9SAndroid Build Coastguard Worker# off of jit_type.h, but we've simplified the hierarchy to focus 1817*da0073e9SAndroid Build Coastguard Worker# in on the aspects of the type system that matter for code generation 1818*da0073e9SAndroid Build Coastguard Worker# (for example, there's no SingleElementType subclass anymore). 1819*da0073e9SAndroid Build Coastguard Worker# You never actually construct a Type; usually it's going to be one 1820*da0073e9SAndroid Build Coastguard Worker# of the subclasses. If Python had ADTs this would be one! 1821*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 1822*da0073e9SAndroid Build Coastguard Workerclass Type: 1823*da0073e9SAndroid Build Coastguard Worker @staticmethod 1824*da0073e9SAndroid Build Coastguard Worker def parse(t: str) -> Type: 1825*da0073e9SAndroid Build Coastguard Worker r = Type._parse(t) 1826*da0073e9SAndroid Build Coastguard Worker assert str(r) == t, f"{r} != {t}" 1827*da0073e9SAndroid Build Coastguard Worker return r 1828*da0073e9SAndroid Build Coastguard Worker 1829*da0073e9SAndroid Build Coastguard Worker @staticmethod 1830*da0073e9SAndroid Build Coastguard Worker def _parse(t: str) -> Type: 1831*da0073e9SAndroid Build Coastguard Worker m = re.match(r"^(.+)\?$", t) 1832*da0073e9SAndroid Build Coastguard Worker if m is not None: 1833*da0073e9SAndroid Build Coastguard Worker return OptionalType(Type.parse(m.group(1))) 1834*da0073e9SAndroid Build Coastguard Worker m = re.match(r"^(.+)\[([0-9]+)?\]$", t) 1835*da0073e9SAndroid Build Coastguard Worker if m is not None: 1836*da0073e9SAndroid Build Coastguard Worker size = int(m.group(2)) if m.group(2) is not None else None 1837*da0073e9SAndroid Build Coastguard Worker return ListType(elem=Type.parse(m.group(1)), size=size) 1838*da0073e9SAndroid Build Coastguard Worker 1839*da0073e9SAndroid Build Coastguard Worker # '__torch__.torch.classes.' is the prefix for custom class 1840*da0073e9SAndroid Build Coastguard Worker m = re.match(r"^__torch__\.torch\.classes\.([a-zA-Z0-9_.]+)$", t) 1841*da0073e9SAndroid Build Coastguard Worker if m is not None: 1842*da0073e9SAndroid Build Coastguard Worker return CustomClassType(m.group(1)) 1843*da0073e9SAndroid Build Coastguard Worker try: 1844*da0073e9SAndroid Build Coastguard Worker return BaseType(BaseTy[t]) 1845*da0073e9SAndroid Build Coastguard Worker except KeyError as e: 1846*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"unrecognized type {t}") from e 1847*da0073e9SAndroid Build Coastguard Worker 1848*da0073e9SAndroid Build Coastguard Worker def __str__(self) -> str: 1849*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1850*da0073e9SAndroid Build Coastguard Worker 1851*da0073e9SAndroid Build Coastguard Worker # WARNING: These concepts are not very well-defined. For example, 1852*da0073e9SAndroid Build Coastguard Worker # is "int?" nullable? How about "int?[]". They are defined 1853*da0073e9SAndroid Build Coastguard Worker # so we can conveniently generate legacy Declarations.yaml but 1854*da0073e9SAndroid Build Coastguard Worker # really we should probably just remove these at some point 1855*da0073e9SAndroid Build Coastguard Worker 1856*da0073e9SAndroid Build Coastguard Worker def is_base_ty_like(self, base_ty: BaseTy) -> bool: 1857*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1858*da0073e9SAndroid Build Coastguard Worker 1859*da0073e9SAndroid Build Coastguard Worker def is_tensor_like(self) -> bool: 1860*da0073e9SAndroid Build Coastguard Worker return self.is_base_ty_like(BaseTy.Tensor) 1861*da0073e9SAndroid Build Coastguard Worker 1862*da0073e9SAndroid Build Coastguard Worker def is_generator_like(self) -> bool: 1863*da0073e9SAndroid Build Coastguard Worker return self.is_base_ty_like(BaseTy.Generator) 1864*da0073e9SAndroid Build Coastguard Worker 1865*da0073e9SAndroid Build Coastguard Worker def is_symint_like(self) -> bool: 1866*da0073e9SAndroid Build Coastguard Worker return self.is_base_ty_like(BaseTy.SymInt) 1867*da0073e9SAndroid Build Coastguard Worker 1868*da0073e9SAndroid Build Coastguard Worker def is_nullable(self) -> bool: 1869*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1870*da0073e9SAndroid Build Coastguard Worker 1871*da0073e9SAndroid Build Coastguard Worker def is_list_like(self) -> ListType | None: 1872*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 1873*da0073e9SAndroid Build Coastguard Worker 1874*da0073e9SAndroid Build Coastguard Worker 1875*da0073e9SAndroid Build Coastguard Worker# Base types are simple, atomic types with no further structure 1876*da0073e9SAndroid Build Coastguard Workerclass BaseTy(Enum): 1877*da0073e9SAndroid Build Coastguard Worker Generator = auto() 1878*da0073e9SAndroid Build Coastguard Worker ScalarType = auto() 1879*da0073e9SAndroid Build Coastguard Worker Tensor = auto() 1880*da0073e9SAndroid Build Coastguard Worker int = auto() 1881*da0073e9SAndroid Build Coastguard Worker Dimname = auto() 1882*da0073e9SAndroid Build Coastguard Worker DimVector = auto() 1883*da0073e9SAndroid Build Coastguard Worker float = auto() 1884*da0073e9SAndroid Build Coastguard Worker str = auto() 1885*da0073e9SAndroid Build Coastguard Worker bool = auto() 1886*da0073e9SAndroid Build Coastguard Worker Layout = auto() 1887*da0073e9SAndroid Build Coastguard Worker Device = auto() 1888*da0073e9SAndroid Build Coastguard Worker DeviceIndex = auto() 1889*da0073e9SAndroid Build Coastguard Worker Scalar = auto() 1890*da0073e9SAndroid Build Coastguard Worker MemoryFormat = auto() 1891*da0073e9SAndroid Build Coastguard Worker QScheme = auto() 1892*da0073e9SAndroid Build Coastguard Worker Storage = auto() 1893*da0073e9SAndroid Build Coastguard Worker Stream = auto() 1894*da0073e9SAndroid Build Coastguard Worker SymInt = auto() 1895*da0073e9SAndroid Build Coastguard Worker SymBool = auto() 1896*da0073e9SAndroid Build Coastguard Worker ConstQuantizerPtr = auto() # TODO: rename 1897*da0073e9SAndroid Build Coastguard Worker GraphModule = auto() 1898*da0073e9SAndroid Build Coastguard Worker 1899*da0073e9SAndroid Build Coastguard Worker 1900*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 1901*da0073e9SAndroid Build Coastguard Workerclass BaseType(Type): 1902*da0073e9SAndroid Build Coastguard Worker name: BaseTy 1903*da0073e9SAndroid Build Coastguard Worker 1904*da0073e9SAndroid Build Coastguard Worker def __str__(self) -> str: 1905*da0073e9SAndroid Build Coastguard Worker return f"{self.name.name}" 1906*da0073e9SAndroid Build Coastguard Worker 1907*da0073e9SAndroid Build Coastguard Worker def is_base_ty_like(self, base_ty: BaseTy) -> bool: 1908*da0073e9SAndroid Build Coastguard Worker return self.name == base_ty 1909*da0073e9SAndroid Build Coastguard Worker 1910*da0073e9SAndroid Build Coastguard Worker def is_nullable(self) -> bool: 1911*da0073e9SAndroid Build Coastguard Worker return False 1912*da0073e9SAndroid Build Coastguard Worker 1913*da0073e9SAndroid Build Coastguard Worker def is_list_like(self) -> ListType | None: 1914*da0073e9SAndroid Build Coastguard Worker return None 1915*da0073e9SAndroid Build Coastguard Worker 1916*da0073e9SAndroid Build Coastguard Worker def is_symint_like(self) -> bool: 1917*da0073e9SAndroid Build Coastguard Worker return self.name == BaseTy.SymInt 1918*da0073e9SAndroid Build Coastguard Worker 1919*da0073e9SAndroid Build Coastguard Worker 1920*da0073e9SAndroid Build Coastguard Worker# Optional types may be specified, or may also be validly given None 1921*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 1922*da0073e9SAndroid Build Coastguard Workerclass OptionalType(Type): 1923*da0073e9SAndroid Build Coastguard Worker elem: Type 1924*da0073e9SAndroid Build Coastguard Worker 1925*da0073e9SAndroid Build Coastguard Worker def __str__(self) -> str: 1926*da0073e9SAndroid Build Coastguard Worker return f"{self.elem}?" 1927*da0073e9SAndroid Build Coastguard Worker 1928*da0073e9SAndroid Build Coastguard Worker def is_base_ty_like(self, base_ty: BaseTy) -> bool: 1929*da0073e9SAndroid Build Coastguard Worker return self.elem.is_base_ty_like(base_ty) 1930*da0073e9SAndroid Build Coastguard Worker 1931*da0073e9SAndroid Build Coastguard Worker def is_symint_like(self) -> bool: 1932*da0073e9SAndroid Build Coastguard Worker return self.elem.is_symint_like() 1933*da0073e9SAndroid Build Coastguard Worker 1934*da0073e9SAndroid Build Coastguard Worker def is_nullable(self) -> bool: 1935*da0073e9SAndroid Build Coastguard Worker return True 1936*da0073e9SAndroid Build Coastguard Worker 1937*da0073e9SAndroid Build Coastguard Worker def is_list_like(self) -> ListType | None: 1938*da0073e9SAndroid Build Coastguard Worker return self.elem.is_list_like() 1939*da0073e9SAndroid Build Coastguard Worker 1940*da0073e9SAndroid Build Coastguard Worker 1941*da0073e9SAndroid Build Coastguard Worker# A type representing a PyTorch custom class 1942*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 1943*da0073e9SAndroid Build Coastguard Workerclass CustomClassType(Type): 1944*da0073e9SAndroid Build Coastguard Worker class_name: str 1945*da0073e9SAndroid Build Coastguard Worker 1946*da0073e9SAndroid Build Coastguard Worker def __str__(self) -> str: 1947*da0073e9SAndroid Build Coastguard Worker """ 1948*da0073e9SAndroid Build Coastguard Worker Return the class name will prefix __torch__.torch.classes 1949*da0073e9SAndroid Build Coastguard Worker """ 1950*da0073e9SAndroid Build Coastguard Worker return f"__torch__.torch.classes.{self.class_name}" 1951*da0073e9SAndroid Build Coastguard Worker 1952*da0073e9SAndroid Build Coastguard Worker def is_base_ty_like(self, base_ty: BaseTy) -> bool: 1953*da0073e9SAndroid Build Coastguard Worker return False 1954*da0073e9SAndroid Build Coastguard Worker 1955*da0073e9SAndroid Build Coastguard Worker def is_symint_like(self) -> bool: 1956*da0073e9SAndroid Build Coastguard Worker return False 1957*da0073e9SAndroid Build Coastguard Worker 1958*da0073e9SAndroid Build Coastguard Worker def is_nullable(self) -> bool: 1959*da0073e9SAndroid Build Coastguard Worker """ 1960*da0073e9SAndroid Build Coastguard Worker Assume a custom class is not nullable. 1961*da0073e9SAndroid Build Coastguard Worker """ 1962*da0073e9SAndroid Build Coastguard Worker return False 1963*da0073e9SAndroid Build Coastguard Worker 1964*da0073e9SAndroid Build Coastguard Worker def is_list_like(self) -> ListType | None: 1965*da0073e9SAndroid Build Coastguard Worker return None 1966*da0073e9SAndroid Build Coastguard Worker 1967*da0073e9SAndroid Build Coastguard Worker 1968*da0073e9SAndroid Build Coastguard Worker# List types specify that we may have multiples of an element. We 1969*da0073e9SAndroid Build Coastguard Worker# also support explicit sizes on list types, but these have 1970*da0073e9SAndroid Build Coastguard Worker# some nontrivial semantics! (However, for C++ API purposes, explicit 1971*da0073e9SAndroid Build Coastguard Worker# sizes are mostly erased from the type system.) 1972*da0073e9SAndroid Build Coastguard Worker# 1973*da0073e9SAndroid Build Coastguard Worker# DANGER WILL ROBINSON: C++ elaboration depends on elem type; e.g., 1974*da0073e9SAndroid Build Coastguard Worker# int[] elaborates differently than bool[3]! 1975*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 1976*da0073e9SAndroid Build Coastguard Workerclass ListType(Type): 1977*da0073e9SAndroid Build Coastguard Worker elem: Type 1978*da0073e9SAndroid Build Coastguard Worker size: int | None 1979*da0073e9SAndroid Build Coastguard Worker 1980*da0073e9SAndroid Build Coastguard Worker def __str__(self) -> str: 1981*da0073e9SAndroid Build Coastguard Worker size = f"{self.size}" if self.size else "" 1982*da0073e9SAndroid Build Coastguard Worker return f"{self.elem}[{size}]" 1983*da0073e9SAndroid Build Coastguard Worker 1984*da0073e9SAndroid Build Coastguard Worker def is_base_ty_like(self, base_ty: BaseTy) -> bool: 1985*da0073e9SAndroid Build Coastguard Worker return self.elem.is_base_ty_like(base_ty) 1986*da0073e9SAndroid Build Coastguard Worker 1987*da0073e9SAndroid Build Coastguard Worker def is_symint_like(self) -> bool: 1988*da0073e9SAndroid Build Coastguard Worker return self.elem.is_symint_like() 1989*da0073e9SAndroid Build Coastguard Worker 1990*da0073e9SAndroid Build Coastguard Worker def is_nullable(self) -> bool: 1991*da0073e9SAndroid Build Coastguard Worker return self.elem.is_nullable() 1992*da0073e9SAndroid Build Coastguard Worker 1993*da0073e9SAndroid Build Coastguard Worker def is_list_like(self) -> ListType | None: 1994*da0073e9SAndroid Build Coastguard Worker return self 1995*da0073e9SAndroid Build Coastguard Worker 1996*da0073e9SAndroid Build Coastguard Worker 1997*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 1998*da0073e9SAndroid Build Coastguard Workerclass Argument: 1999*da0073e9SAndroid Build Coastguard Worker # NB: I didn't put kwarg_only as a boolean field here, unlike 2000*da0073e9SAndroid Build Coastguard Worker # c10::Argument, so that printing works correctly 2001*da0073e9SAndroid Build Coastguard Worker 2002*da0073e9SAndroid Build Coastguard Worker name: str 2003*da0073e9SAndroid Build Coastguard Worker type: Type 2004*da0073e9SAndroid Build Coastguard Worker default: str | None 2005*da0073e9SAndroid Build Coastguard Worker 2006*da0073e9SAndroid Build Coastguard Worker # The semantics of the annotation field are a little strange. 2007*da0073e9SAndroid Build Coastguard Worker # 2008*da0073e9SAndroid Build Coastguard Worker # Alias annotations parametrize Tensors (since Tensors are the only things 2009*da0073e9SAndroid Build Coastguard Worker # that can alias.) This motivates why I write Tensor(a!)? (and not, for 2010*da0073e9SAndroid Build Coastguard Worker # example, Tensor?(a!)), because the (a!) describes aliasing on the tensor, 2011*da0073e9SAndroid Build Coastguard Worker # which may be optional (i.e., the alias annotation should bind first to 2012*da0073e9SAndroid Build Coastguard Worker # Tensor, before the optional postfix annotation). 2013*da0073e9SAndroid Build Coastguard Worker # 2014*da0073e9SAndroid Build Coastguard Worker # However, despite being a property of Tensor, we (and c10::Argument) 2015*da0073e9SAndroid Build Coastguard Worker # store the annotation at the top level of the Argument, rather than 2016*da0073e9SAndroid Build Coastguard Worker # inside the embedded Tensor type. In the C++ version of this 2017*da0073e9SAndroid Build Coastguard Worker # class, we then go through great lengths to mimic the type 2018*da0073e9SAndroid Build Coastguard Worker # structure in the annotation structure so we can correlate 2019*da0073e9SAndroid Build Coastguard Worker # annotations with types. 2020*da0073e9SAndroid Build Coastguard Worker # 2021*da0073e9SAndroid Build Coastguard Worker # Now, it turns out, in all applications in code generation, the 2022*da0073e9SAndroid Build Coastguard Worker # structure of annotated types is very simple. So we just hard 2023*da0073e9SAndroid Build Coastguard Worker # code it here. But if we ever do get anything more complex, this 2024*da0073e9SAndroid Build Coastguard Worker # model will have to change! 2025*da0073e9SAndroid Build Coastguard Worker annotation: Annotation | None 2026*da0073e9SAndroid Build Coastguard Worker 2027*da0073e9SAndroid Build Coastguard Worker @property 2028*da0073e9SAndroid Build Coastguard Worker def alias_info(self) -> Annotation | None: 2029*da0073e9SAndroid Build Coastguard Worker return self.annotation 2030*da0073e9SAndroid Build Coastguard Worker 2031*da0073e9SAndroid Build Coastguard Worker @staticmethod 2032*da0073e9SAndroid Build Coastguard Worker def parse(arg: str) -> Argument: 2033*da0073e9SAndroid Build Coastguard Worker name: str 2034*da0073e9SAndroid Build Coastguard Worker default: str | None 2035*da0073e9SAndroid Build Coastguard Worker assert " " in arg, f"illegal argument '{arg}'" 2036*da0073e9SAndroid Build Coastguard Worker if "=" in arg: 2037*da0073e9SAndroid Build Coastguard Worker assert arg.count("=") == 1, f"illegal argument with default value: '{arg}'" 2038*da0073e9SAndroid Build Coastguard Worker type_and_annot_and_name, default = arg.split("=") 2039*da0073e9SAndroid Build Coastguard Worker type_and_annot, name = type_and_annot_and_name.rsplit(" ", 1) 2040*da0073e9SAndroid Build Coastguard Worker name_and_default = f"{name}={default}" 2041*da0073e9SAndroid Build Coastguard Worker else: 2042*da0073e9SAndroid Build Coastguard Worker type_and_annot, name_and_default = arg.rsplit(" ", 1) 2043*da0073e9SAndroid Build Coastguard Worker name = name_and_default 2044*da0073e9SAndroid Build Coastguard Worker default = None 2045*da0073e9SAndroid Build Coastguard Worker # TODO: deduplicate annotation matching with Return 2046*da0073e9SAndroid Build Coastguard Worker match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot) 2047*da0073e9SAndroid Build Coastguard Worker annotation: Annotation | None 2048*da0073e9SAndroid Build Coastguard Worker if match: 2049*da0073e9SAndroid Build Coastguard Worker # If you update this, make sure the __str__ still works too 2050*da0073e9SAndroid Build Coastguard Worker assert match.group(2) in [ 2051*da0073e9SAndroid Build Coastguard Worker "", 2052*da0073e9SAndroid Build Coastguard Worker "?", 2053*da0073e9SAndroid Build Coastguard Worker "[]", 2054*da0073e9SAndroid Build Coastguard Worker ], "unrecognized alias analysis form with Tensor" 2055*da0073e9SAndroid Build Coastguard Worker type_s = "Tensor" + match.group(2) 2056*da0073e9SAndroid Build Coastguard Worker annotation = Annotation.parse(match.group(1)) 2057*da0073e9SAndroid Build Coastguard Worker else: 2058*da0073e9SAndroid Build Coastguard Worker type_s = type_and_annot 2059*da0073e9SAndroid Build Coastguard Worker annotation = None 2060*da0073e9SAndroid Build Coastguard Worker type = Type.parse(type_s) 2061*da0073e9SAndroid Build Coastguard Worker r = Argument( 2062*da0073e9SAndroid Build Coastguard Worker name=name, 2063*da0073e9SAndroid Build Coastguard Worker type=type, 2064*da0073e9SAndroid Build Coastguard Worker default=default, 2065*da0073e9SAndroid Build Coastguard Worker annotation=annotation, 2066*da0073e9SAndroid Build Coastguard Worker ) 2067*da0073e9SAndroid Build Coastguard Worker assert str(r) == arg, f"{str(r)} != {arg}" 2068*da0073e9SAndroid Build Coastguard Worker return r 2069*da0073e9SAndroid Build Coastguard Worker 2070*da0073e9SAndroid Build Coastguard Worker @property 2071*da0073e9SAndroid Build Coastguard Worker def is_write(self) -> bool: 2072*da0073e9SAndroid Build Coastguard Worker return self.annotation is not None and self.annotation.is_write 2073*da0073e9SAndroid Build Coastguard Worker 2074*da0073e9SAndroid Build Coastguard Worker def __str__(self) -> str: 2075*da0073e9SAndroid Build Coastguard Worker type = f"{self.type}" 2076*da0073e9SAndroid Build Coastguard Worker if self.annotation: 2077*da0073e9SAndroid Build Coastguard Worker assert type in ["Tensor", "Tensor?", "Tensor[]"] 2078*da0073e9SAndroid Build Coastguard Worker type = type.replace("Tensor", f"Tensor({self.annotation})") 2079*da0073e9SAndroid Build Coastguard Worker if self.name is None: 2080*da0073e9SAndroid Build Coastguard Worker return type 2081*da0073e9SAndroid Build Coastguard Worker else: 2082*da0073e9SAndroid Build Coastguard Worker mb_default = "" 2083*da0073e9SAndroid Build Coastguard Worker if self.default: 2084*da0073e9SAndroid Build Coastguard Worker mb_default = f"={self.default}" 2085*da0073e9SAndroid Build Coastguard Worker return f"{type} {self.name}{mb_default}" 2086*da0073e9SAndroid Build Coastguard Worker 2087*da0073e9SAndroid Build Coastguard Worker 2088*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 2089*da0073e9SAndroid Build Coastguard Workerclass Return: 2090*da0073e9SAndroid Build Coastguard Worker name: str | None 2091*da0073e9SAndroid Build Coastguard Worker type: Type 2092*da0073e9SAndroid Build Coastguard Worker annotation: Annotation | None 2093*da0073e9SAndroid Build Coastguard Worker 2094*da0073e9SAndroid Build Coastguard Worker @property 2095*da0073e9SAndroid Build Coastguard Worker def alias_info(self) -> Annotation | None: 2096*da0073e9SAndroid Build Coastguard Worker return self.annotation 2097*da0073e9SAndroid Build Coastguard Worker 2098*da0073e9SAndroid Build Coastguard Worker @staticmethod 2099*da0073e9SAndroid Build Coastguard Worker def parse(arg: str) -> Return: 2100*da0073e9SAndroid Build Coastguard Worker name: str | None 2101*da0073e9SAndroid Build Coastguard Worker if " " in arg: 2102*da0073e9SAndroid Build Coastguard Worker type_and_annot, name = arg.rsplit(" ", 1) 2103*da0073e9SAndroid Build Coastguard Worker else: 2104*da0073e9SAndroid Build Coastguard Worker type_and_annot = arg 2105*da0073e9SAndroid Build Coastguard Worker name = None 2106*da0073e9SAndroid Build Coastguard Worker match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot) 2107*da0073e9SAndroid Build Coastguard Worker annotation: Annotation | None 2108*da0073e9SAndroid Build Coastguard Worker if match: 2109*da0073e9SAndroid Build Coastguard Worker # If you update this, make sure the __str__ still works too 2110*da0073e9SAndroid Build Coastguard Worker assert match.group(2) in [ 2111*da0073e9SAndroid Build Coastguard Worker "", 2112*da0073e9SAndroid Build Coastguard Worker "?", 2113*da0073e9SAndroid Build Coastguard Worker "[]", 2114*da0073e9SAndroid Build Coastguard Worker ], "unrecognized alias analysis form with Tensor" 2115*da0073e9SAndroid Build Coastguard Worker type_s = "Tensor" + match.group(2) 2116*da0073e9SAndroid Build Coastguard Worker annotation = Annotation.parse(match.group(1)) 2117*da0073e9SAndroid Build Coastguard Worker else: 2118*da0073e9SAndroid Build Coastguard Worker type_s = type_and_annot 2119*da0073e9SAndroid Build Coastguard Worker annotation = None 2120*da0073e9SAndroid Build Coastguard Worker type = Type.parse(type_s) 2121*da0073e9SAndroid Build Coastguard Worker r = Return( 2122*da0073e9SAndroid Build Coastguard Worker name=name, 2123*da0073e9SAndroid Build Coastguard Worker type=type, 2124*da0073e9SAndroid Build Coastguard Worker annotation=annotation, 2125*da0073e9SAndroid Build Coastguard Worker ) 2126*da0073e9SAndroid Build Coastguard Worker assert str(r) == arg, f"{str(r)} != {arg}" 2127*da0073e9SAndroid Build Coastguard Worker return r 2128*da0073e9SAndroid Build Coastguard Worker 2129*da0073e9SAndroid Build Coastguard Worker @property 2130*da0073e9SAndroid Build Coastguard Worker def is_write(self) -> bool: 2131*da0073e9SAndroid Build Coastguard Worker return self.annotation is not None and self.annotation.is_write 2132*da0073e9SAndroid Build Coastguard Worker 2133*da0073e9SAndroid Build Coastguard Worker def __str__(self) -> str: 2134*da0073e9SAndroid Build Coastguard Worker type = f"{self.type}" 2135*da0073e9SAndroid Build Coastguard Worker if self.annotation: 2136*da0073e9SAndroid Build Coastguard Worker assert type in ["Tensor", "Tensor?", "Tensor[]"] 2137*da0073e9SAndroid Build Coastguard Worker type = type.replace("Tensor", f"Tensor({self.annotation})") 2138*da0073e9SAndroid Build Coastguard Worker if self.name is None: 2139*da0073e9SAndroid Build Coastguard Worker return type 2140*da0073e9SAndroid Build Coastguard Worker else: 2141*da0073e9SAndroid Build Coastguard Worker return f"{type} {self.name}" 2142*da0073e9SAndroid Build Coastguard Worker 2143*da0073e9SAndroid Build Coastguard Worker 2144*da0073e9SAndroid Build Coastguard Worker# Represents the self argument for functions that may be methods 2145*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 2146*da0073e9SAndroid Build Coastguard Workerclass SelfArgument: 2147*da0073e9SAndroid Build Coastguard Worker argument: Argument 2148*da0073e9SAndroid Build Coastguard Worker 2149*da0073e9SAndroid Build Coastguard Worker 2150*da0073e9SAndroid Build Coastguard Worker# Bundle of arguments that represent a TensorOptions. This is mostly 2151*da0073e9SAndroid Build Coastguard Worker# relevant for the public C++ API but we bake it into the core data 2152*da0073e9SAndroid Build Coastguard Worker# model because other APIs often have to interact with it 2153*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 2154*da0073e9SAndroid Build Coastguard Workerclass TensorOptionsArguments: 2155*da0073e9SAndroid Build Coastguard Worker dtype: Argument 2156*da0073e9SAndroid Build Coastguard Worker layout: Argument 2157*da0073e9SAndroid Build Coastguard Worker device: Argument 2158*da0073e9SAndroid Build Coastguard Worker pin_memory: Argument 2159*da0073e9SAndroid Build Coastguard Worker 2160*da0073e9SAndroid Build Coastguard Worker def all(self) -> Sequence[Argument]: 2161*da0073e9SAndroid Build Coastguard Worker return [self.dtype, self.layout, self.device, self.pin_memory] 2162*da0073e9SAndroid Build Coastguard Worker 2163*da0073e9SAndroid Build Coastguard Worker 2164*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 2165*da0073e9SAndroid Build Coastguard Workerclass Arguments: 2166*da0073e9SAndroid Build Coastguard Worker # pre_self_positional is usually empty, but is notably non-empty 2167*da0073e9SAndroid Build Coastguard Worker # for where.self, where the condition argument comes before the 2168*da0073e9SAndroid Build Coastguard Worker # self argument 2169*da0073e9SAndroid Build Coastguard Worker pre_self_positional: tuple[Argument, ...] 2170*da0073e9SAndroid Build Coastguard Worker self_arg: SelfArgument | None 2171*da0073e9SAndroid Build Coastguard Worker post_self_positional: tuple[Argument, ...] 2172*da0073e9SAndroid Build Coastguard Worker 2173*da0073e9SAndroid Build Coastguard Worker pre_tensor_options_kwarg_only: tuple[Argument, ...] 2174*da0073e9SAndroid Build Coastguard Worker tensor_options: TensorOptionsArguments | None 2175*da0073e9SAndroid Build Coastguard Worker # post_tensor_options is typically memory format, which should be 2176*da0073e9SAndroid Build Coastguard Worker # part of tensor options but isn't right now, and is usually 2177*da0073e9SAndroid Build Coastguard Worker # placed after the tensor options arguments 2178*da0073e9SAndroid Build Coastguard Worker post_tensor_options_kwarg_only: tuple[Argument, ...] 2179*da0073e9SAndroid Build Coastguard Worker 2180*da0073e9SAndroid Build Coastguard Worker # Unlike in the previous codegen, we have factored out 'out' arguments 2181*da0073e9SAndroid Build Coastguard Worker # in the canonical representation, removing them from kwarg 2182*da0073e9SAndroid Build Coastguard Worker # arguments. This choice is justified by numerous downstream 2183*da0073e9SAndroid Build Coastguard Worker # transformations which treat out arguments specially; additionally, 2184*da0073e9SAndroid Build Coastguard Worker # you can see that canonicity is not violated! 2185*da0073e9SAndroid Build Coastguard Worker out: tuple[Argument, ...] # these are also kwarg-only 2186*da0073e9SAndroid Build Coastguard Worker 2187*da0073e9SAndroid Build Coastguard Worker @property 2188*da0073e9SAndroid Build Coastguard Worker def flat_non_out(self) -> Sequence[Argument]: 2189*da0073e9SAndroid Build Coastguard Worker ret: list[Argument] = [] 2190*da0073e9SAndroid Build Coastguard Worker ret.extend(self.flat_positional) 2191*da0073e9SAndroid Build Coastguard Worker ret.extend(self.flat_kwarg_only) 2192*da0073e9SAndroid Build Coastguard Worker return ret 2193*da0073e9SAndroid Build Coastguard Worker 2194*da0073e9SAndroid Build Coastguard Worker @property 2195*da0073e9SAndroid Build Coastguard Worker def flat_positional(self) -> Sequence[Argument]: 2196*da0073e9SAndroid Build Coastguard Worker ret: list[Argument] = [] 2197*da0073e9SAndroid Build Coastguard Worker ret.extend(self.pre_self_positional) 2198*da0073e9SAndroid Build Coastguard Worker if self.self_arg is not None: 2199*da0073e9SAndroid Build Coastguard Worker ret.append(self.self_arg.argument) 2200*da0073e9SAndroid Build Coastguard Worker ret.extend(self.post_self_positional) 2201*da0073e9SAndroid Build Coastguard Worker return ret 2202*da0073e9SAndroid Build Coastguard Worker 2203*da0073e9SAndroid Build Coastguard Worker @property 2204*da0073e9SAndroid Build Coastguard Worker def post_self_positional_mutable(self) -> Sequence[Argument]: 2205*da0073e9SAndroid Build Coastguard Worker return [a for a in self.post_self_positional if a.is_write] 2206*da0073e9SAndroid Build Coastguard Worker 2207*da0073e9SAndroid Build Coastguard Worker # NB: doesn't contain out arguments 2208*da0073e9SAndroid Build Coastguard Worker @property 2209*da0073e9SAndroid Build Coastguard Worker def flat_kwarg_only(self) -> Sequence[Argument]: 2210*da0073e9SAndroid Build Coastguard Worker ret: list[Argument] = [] 2211*da0073e9SAndroid Build Coastguard Worker ret.extend(self.pre_tensor_options_kwarg_only) 2212*da0073e9SAndroid Build Coastguard Worker if self.tensor_options is not None: 2213*da0073e9SAndroid Build Coastguard Worker ret.extend(self.tensor_options.all()) 2214*da0073e9SAndroid Build Coastguard Worker ret.extend(self.post_tensor_options_kwarg_only) 2215*da0073e9SAndroid Build Coastguard Worker return ret 2216*da0073e9SAndroid Build Coastguard Worker 2217*da0073e9SAndroid Build Coastguard Worker @property 2218*da0073e9SAndroid Build Coastguard Worker def flat_all(self) -> Sequence[Argument]: 2219*da0073e9SAndroid Build Coastguard Worker ret: list[Argument] = [] 2220*da0073e9SAndroid Build Coastguard Worker ret.extend(self.flat_positional) 2221*da0073e9SAndroid Build Coastguard Worker ret.extend(self.flat_kwarg_only) 2222*da0073e9SAndroid Build Coastguard Worker ret.extend(self.out) 2223*da0073e9SAndroid Build Coastguard Worker return ret 2224*da0073e9SAndroid Build Coastguard Worker 2225*da0073e9SAndroid Build Coastguard Worker @property 2226*da0073e9SAndroid Build Coastguard Worker def non_out( 2227*da0073e9SAndroid Build Coastguard Worker self, 2228*da0073e9SAndroid Build Coastguard Worker ) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]: 2229*da0073e9SAndroid Build Coastguard Worker ret: list[Argument | SelfArgument | TensorOptionsArguments] = [] 2230*da0073e9SAndroid Build Coastguard Worker ret.extend(self.positional) 2231*da0073e9SAndroid Build Coastguard Worker ret.extend(self.kwarg_only) 2232*da0073e9SAndroid Build Coastguard Worker return ret 2233*da0073e9SAndroid Build Coastguard Worker 2234*da0073e9SAndroid Build Coastguard Worker @property 2235*da0073e9SAndroid Build Coastguard Worker def positional(self) -> Sequence[Argument | SelfArgument]: 2236*da0073e9SAndroid Build Coastguard Worker ret: list[Argument | SelfArgument] = [] 2237*da0073e9SAndroid Build Coastguard Worker ret.extend(self.pre_self_positional) 2238*da0073e9SAndroid Build Coastguard Worker if self.self_arg is not None: 2239*da0073e9SAndroid Build Coastguard Worker ret.append(self.self_arg) 2240*da0073e9SAndroid Build Coastguard Worker ret.extend(self.post_self_positional) 2241*da0073e9SAndroid Build Coastguard Worker return ret 2242*da0073e9SAndroid Build Coastguard Worker 2243*da0073e9SAndroid Build Coastguard Worker @property 2244*da0073e9SAndroid Build Coastguard Worker def kwarg_only(self) -> Sequence[Argument | TensorOptionsArguments]: 2245*da0073e9SAndroid Build Coastguard Worker ret: list[Argument | TensorOptionsArguments] = [] 2246*da0073e9SAndroid Build Coastguard Worker ret.extend(self.pre_tensor_options_kwarg_only) 2247*da0073e9SAndroid Build Coastguard Worker if self.tensor_options is not None: 2248*da0073e9SAndroid Build Coastguard Worker ret.append(self.tensor_options) 2249*da0073e9SAndroid Build Coastguard Worker ret.extend(self.post_tensor_options_kwarg_only) 2250*da0073e9SAndroid Build Coastguard Worker return ret 2251*da0073e9SAndroid Build Coastguard Worker 2252*da0073e9SAndroid Build Coastguard Worker @property 2253*da0073e9SAndroid Build Coastguard Worker def all(self) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]: 2254*da0073e9SAndroid Build Coastguard Worker ret: list[Argument | SelfArgument | TensorOptionsArguments] = [] 2255*da0073e9SAndroid Build Coastguard Worker ret.extend(self.positional) 2256*da0073e9SAndroid Build Coastguard Worker ret.extend(self.kwarg_only) 2257*da0073e9SAndroid Build Coastguard Worker ret.extend(self.out) 2258*da0073e9SAndroid Build Coastguard Worker return ret 2259*da0073e9SAndroid Build Coastguard Worker 2260*da0073e9SAndroid Build Coastguard Worker def mutable_arg_names(self) -> list[str]: 2261*da0073e9SAndroid Build Coastguard Worker return [ 2262*da0073e9SAndroid Build Coastguard Worker a.name 2263*da0073e9SAndroid Build Coastguard Worker for a in self.flat_all 2264*da0073e9SAndroid Build Coastguard Worker if a.annotation is not None and a.annotation.is_write 2265*da0073e9SAndroid Build Coastguard Worker ] 2266*da0073e9SAndroid Build Coastguard Worker 2267*da0073e9SAndroid Build Coastguard Worker def has_tensor_arg(self) -> bool: 2268*da0073e9SAndroid Build Coastguard Worker return any(a.type.is_tensor_like() for a in self.flat_non_out) 2269*da0073e9SAndroid Build Coastguard Worker 2270*da0073e9SAndroid Build Coastguard Worker def has_symint_arg(self) -> bool: 2271*da0073e9SAndroid Build Coastguard Worker return any(a.type.is_symint_like() for a in self.flat_non_out) 2272*da0073e9SAndroid Build Coastguard Worker 2273*da0073e9SAndroid Build Coastguard Worker def has_generator_arg(self) -> bool: 2274*da0073e9SAndroid Build Coastguard Worker return any(a.type.is_generator_like() for a in self.flat_non_out) 2275*da0073e9SAndroid Build Coastguard Worker 2276*da0073e9SAndroid Build Coastguard Worker def signature(self, *, strip_default: bool = False) -> Arguments: 2277*da0073e9SAndroid Build Coastguard Worker # dataclasses.replace could be used here, but it is less 2278*da0073e9SAndroid Build Coastguard Worker # type safe so for now I've opted to type everything out 2279*da0073e9SAndroid Build Coastguard Worker def strip_arg_annotation(a: Argument) -> Argument: 2280*da0073e9SAndroid Build Coastguard Worker return Argument( 2281*da0073e9SAndroid Build Coastguard Worker name=a.name, 2282*da0073e9SAndroid Build Coastguard Worker type=a.type, 2283*da0073e9SAndroid Build Coastguard Worker default=a.default if not strip_default else None, 2284*da0073e9SAndroid Build Coastguard Worker annotation=None, 2285*da0073e9SAndroid Build Coastguard Worker ) 2286*da0073e9SAndroid Build Coastguard Worker 2287*da0073e9SAndroid Build Coastguard Worker return Arguments( 2288*da0073e9SAndroid Build Coastguard Worker pre_self_positional=tuple( 2289*da0073e9SAndroid Build Coastguard Worker map(strip_arg_annotation, self.pre_self_positional) 2290*da0073e9SAndroid Build Coastguard Worker ), 2291*da0073e9SAndroid Build Coastguard Worker self_arg=SelfArgument(strip_arg_annotation(self.self_arg.argument)) 2292*da0073e9SAndroid Build Coastguard Worker if self.self_arg is not None 2293*da0073e9SAndroid Build Coastguard Worker else None, 2294*da0073e9SAndroid Build Coastguard Worker post_self_positional=tuple( 2295*da0073e9SAndroid Build Coastguard Worker map(strip_arg_annotation, self.post_self_positional) 2296*da0073e9SAndroid Build Coastguard Worker ), 2297*da0073e9SAndroid Build Coastguard Worker # Since TensorOptions are dropped, the post_tensor_options_kwargs are 2298*da0073e9SAndroid Build Coastguard Worker # converted to pre_tensor_options_kwargs 2299*da0073e9SAndroid Build Coastguard Worker pre_tensor_options_kwarg_only=tuple( 2300*da0073e9SAndroid Build Coastguard Worker map(strip_arg_annotation, self.pre_tensor_options_kwarg_only) 2301*da0073e9SAndroid Build Coastguard Worker ) 2302*da0073e9SAndroid Build Coastguard Worker + tuple(map(strip_arg_annotation, self.post_tensor_options_kwarg_only)), 2303*da0073e9SAndroid Build Coastguard Worker # TensorOptions are dropped in signature, 2304*da0073e9SAndroid Build Coastguard Worker # so we can pair factory functions with their out= variants. 2305*da0073e9SAndroid Build Coastguard Worker tensor_options=None, 2306*da0073e9SAndroid Build Coastguard Worker post_tensor_options_kwarg_only=(), 2307*da0073e9SAndroid Build Coastguard Worker # out arguments are dropped in signature 2308*da0073e9SAndroid Build Coastguard Worker out=(), 2309*da0073e9SAndroid Build Coastguard Worker ) 2310*da0073e9SAndroid Build Coastguard Worker 2311*da0073e9SAndroid Build Coastguard Worker def remove_self_annotation(self) -> Arguments: 2312*da0073e9SAndroid Build Coastguard Worker assert self.self_arg is not None 2313*da0073e9SAndroid Build Coastguard Worker return dataclasses.replace( 2314*da0073e9SAndroid Build Coastguard Worker self, 2315*da0073e9SAndroid Build Coastguard Worker self_arg=SelfArgument( 2316*da0073e9SAndroid Build Coastguard Worker dataclasses.replace(self.self_arg.argument, annotation=None) 2317*da0073e9SAndroid Build Coastguard Worker ), 2318*da0073e9SAndroid Build Coastguard Worker ) 2319*da0073e9SAndroid Build Coastguard Worker 2320*da0073e9SAndroid Build Coastguard Worker def with_out_args(self, outs: list[Argument]) -> Arguments: 2321*da0073e9SAndroid Build Coastguard Worker assert len(self.out) == 0 2322*da0073e9SAndroid Build Coastguard Worker return dataclasses.replace( 2323*da0073e9SAndroid Build Coastguard Worker self, 2324*da0073e9SAndroid Build Coastguard Worker out=tuple(outs), 2325*da0073e9SAndroid Build Coastguard Worker ) 2326*da0073e9SAndroid Build Coastguard Worker 2327*da0073e9SAndroid Build Coastguard Worker @staticmethod 2328*da0073e9SAndroid Build Coastguard Worker def _preparse(args: str) -> tuple[list[Argument], list[Argument], list[Argument]]: 2329*da0073e9SAndroid Build Coastguard Worker positional: list[Argument] = [] 2330*da0073e9SAndroid Build Coastguard Worker kwarg_only: list[Argument] = [] 2331*da0073e9SAndroid Build Coastguard Worker out: list[Argument] = [] 2332*da0073e9SAndroid Build Coastguard Worker arguments_acc = positional 2333*da0073e9SAndroid Build Coastguard Worker 2334*da0073e9SAndroid Build Coastguard Worker # TODO: Use a real parser here; this will get bamboozled 2335*da0073e9SAndroid Build Coastguard Worker # by signatures that contain things like std::array<bool, 2> (note the space) 2336*da0073e9SAndroid Build Coastguard Worker for arg in args.split(", "): 2337*da0073e9SAndroid Build Coastguard Worker if not arg: 2338*da0073e9SAndroid Build Coastguard Worker continue 2339*da0073e9SAndroid Build Coastguard Worker if arg == "*": 2340*da0073e9SAndroid Build Coastguard Worker assert ( 2341*da0073e9SAndroid Build Coastguard Worker arguments_acc is positional 2342*da0073e9SAndroid Build Coastguard Worker ), "invalid syntax: kwarg-only specifier * can only occur once" 2343*da0073e9SAndroid Build Coastguard Worker arguments_acc = kwarg_only 2344*da0073e9SAndroid Build Coastguard Worker continue 2345*da0073e9SAndroid Build Coastguard Worker parg = Argument.parse(arg) 2346*da0073e9SAndroid Build Coastguard Worker # Currently, we rely directly on the invariant that there are NO 2347*da0073e9SAndroid Build Coastguard Worker # kwarg-only mutating arguments. If you want to relax this, 2348*da0073e9SAndroid Build Coastguard Worker # we will need a more semantic way of matching that takes 2349*da0073e9SAndroid Build Coastguard Worker # into account return arguments. In that case, you will have 2350*da0073e9SAndroid Build Coastguard Worker # to manage out computation a level up, in FunctionSchema. See Note 2351*da0073e9SAndroid Build Coastguard Worker # [is_out_fn] 2352*da0073e9SAndroid Build Coastguard Worker if parg.annotation is not None and parg.annotation.is_write: 2353*da0073e9SAndroid Build Coastguard Worker if arguments_acc is positional: 2354*da0073e9SAndroid Build Coastguard Worker pass # do nothing 2355*da0073e9SAndroid Build Coastguard Worker elif arguments_acc is kwarg_only: 2356*da0073e9SAndroid Build Coastguard Worker arguments_acc = out 2357*da0073e9SAndroid Build Coastguard Worker else: 2358*da0073e9SAndroid Build Coastguard Worker assert arguments_acc is not out 2359*da0073e9SAndroid Build Coastguard Worker arguments_acc.append(parg) 2360*da0073e9SAndroid Build Coastguard Worker 2361*da0073e9SAndroid Build Coastguard Worker return positional, kwarg_only, out 2362*da0073e9SAndroid Build Coastguard Worker 2363*da0073e9SAndroid Build Coastguard Worker @staticmethod 2364*da0073e9SAndroid Build Coastguard Worker def parse(args: str) -> Arguments: 2365*da0073e9SAndroid Build Coastguard Worker """ 2366*da0073e9SAndroid Build Coastguard Worker Input: 'int x, int y, int z' 2367*da0073e9SAndroid Build Coastguard Worker """ 2368*da0073e9SAndroid Build Coastguard Worker 2369*da0073e9SAndroid Build Coastguard Worker # We do this in two phases. First we parse into three 2370*da0073e9SAndroid Build Coastguard Worker # main categories: positional, kwarg_only, out. 2371*da0073e9SAndroid Build Coastguard Worker # Then, we reparse positional and kwarg_only to separate 2372*da0073e9SAndroid Build Coastguard Worker # out the self argument and tensor options arguments. 2373*da0073e9SAndroid Build Coastguard Worker 2374*da0073e9SAndroid Build Coastguard Worker positional, kwarg_only, out = Arguments._preparse(args) 2375*da0073e9SAndroid Build Coastguard Worker 2376*da0073e9SAndroid Build Coastguard Worker # Split self argument 2377*da0073e9SAndroid Build Coastguard Worker self_ix = None 2378*da0073e9SAndroid Build Coastguard Worker for i, a in enumerate(positional): 2379*da0073e9SAndroid Build Coastguard Worker if a.name == "self": 2380*da0073e9SAndroid Build Coastguard Worker self_ix = i 2381*da0073e9SAndroid Build Coastguard Worker break 2382*da0073e9SAndroid Build Coastguard Worker pre_self_positional: list[Argument] 2383*da0073e9SAndroid Build Coastguard Worker self_arg: SelfArgument | None 2384*da0073e9SAndroid Build Coastguard Worker post_self_positional: list[Argument] 2385*da0073e9SAndroid Build Coastguard Worker if self_ix is not None: 2386*da0073e9SAndroid Build Coastguard Worker pre_self_positional = positional[:self_ix] 2387*da0073e9SAndroid Build Coastguard Worker self_arg = SelfArgument(positional[self_ix]) 2388*da0073e9SAndroid Build Coastguard Worker post_self_positional = positional[self_ix + 1 :] 2389*da0073e9SAndroid Build Coastguard Worker else: 2390*da0073e9SAndroid Build Coastguard Worker pre_self_positional = [] 2391*da0073e9SAndroid Build Coastguard Worker self_arg = None 2392*da0073e9SAndroid Build Coastguard Worker post_self_positional = positional 2393*da0073e9SAndroid Build Coastguard Worker 2394*da0073e9SAndroid Build Coastguard Worker # Group tensor options arguments 2395*da0073e9SAndroid Build Coastguard Worker pre_tensor_options_kwarg_only: list[Argument] = [] 2396*da0073e9SAndroid Build Coastguard Worker tensor_options: TensorOptionsArguments | None = None 2397*da0073e9SAndroid Build Coastguard Worker post_tensor_options_kwarg_only: list[Argument] = [] 2398*da0073e9SAndroid Build Coastguard Worker kwarg_only_acc = pre_tensor_options_kwarg_only 2399*da0073e9SAndroid Build Coastguard Worker 2400*da0073e9SAndroid Build Coastguard Worker def pred(name: str, ty: Type) -> Callable[[Argument], bool]: 2401*da0073e9SAndroid Build Coastguard Worker return lambda a: a.name == name and a.type in [ty, OptionalType(ty)] 2402*da0073e9SAndroid Build Coastguard Worker 2403*da0073e9SAndroid Build Coastguard Worker predicates = [ # order matters 2404*da0073e9SAndroid Build Coastguard Worker pred("dtype", Type.parse("ScalarType")), 2405*da0073e9SAndroid Build Coastguard Worker pred("layout", Type.parse("Layout")), 2406*da0073e9SAndroid Build Coastguard Worker pred("device", Type.parse("Device")), 2407*da0073e9SAndroid Build Coastguard Worker pred("pin_memory", Type.parse("bool")), 2408*da0073e9SAndroid Build Coastguard Worker ] 2409*da0073e9SAndroid Build Coastguard Worker 2410*da0073e9SAndroid Build Coastguard Worker i = 0 2411*da0073e9SAndroid Build Coastguard Worker while i < len(kwarg_only): 2412*da0073e9SAndroid Build Coastguard Worker # If there is enough space... 2413*da0073e9SAndroid Build Coastguard Worker if i <= len(kwarg_only) - len(predicates): 2414*da0073e9SAndroid Build Coastguard Worker # And the next len(predicates) arguments look like TensorOptions arguments 2415*da0073e9SAndroid Build Coastguard Worker if all( 2416*da0073e9SAndroid Build Coastguard Worker p(a) 2417*da0073e9SAndroid Build Coastguard Worker for p, a in zip(predicates, kwarg_only[i : i + len(predicates)]) 2418*da0073e9SAndroid Build Coastguard Worker ): 2419*da0073e9SAndroid Build Coastguard Worker assert kwarg_only_acc is pre_tensor_options_kwarg_only 2420*da0073e9SAndroid Build Coastguard Worker # Group them together as one argument 2421*da0073e9SAndroid Build Coastguard Worker tensor_options = TensorOptionsArguments( 2422*da0073e9SAndroid Build Coastguard Worker dtype=kwarg_only[i], 2423*da0073e9SAndroid Build Coastguard Worker layout=kwarg_only[i + 1], 2424*da0073e9SAndroid Build Coastguard Worker device=kwarg_only[i + 2], 2425*da0073e9SAndroid Build Coastguard Worker pin_memory=kwarg_only[i + 3], 2426*da0073e9SAndroid Build Coastguard Worker ) 2427*da0073e9SAndroid Build Coastguard Worker i += len(predicates) 2428*da0073e9SAndroid Build Coastguard Worker kwarg_only_acc = post_tensor_options_kwarg_only 2429*da0073e9SAndroid Build Coastguard Worker continue 2430*da0073e9SAndroid Build Coastguard Worker kwarg_only_acc.append(kwarg_only[i]) 2431*da0073e9SAndroid Build Coastguard Worker i += 1 2432*da0073e9SAndroid Build Coastguard Worker 2433*da0073e9SAndroid Build Coastguard Worker return Arguments( 2434*da0073e9SAndroid Build Coastguard Worker pre_self_positional=tuple(pre_self_positional), 2435*da0073e9SAndroid Build Coastguard Worker self_arg=self_arg, 2436*da0073e9SAndroid Build Coastguard Worker post_self_positional=tuple(post_self_positional), 2437*da0073e9SAndroid Build Coastguard Worker pre_tensor_options_kwarg_only=tuple(pre_tensor_options_kwarg_only), 2438*da0073e9SAndroid Build Coastguard Worker tensor_options=tensor_options, 2439*da0073e9SAndroid Build Coastguard Worker post_tensor_options_kwarg_only=tuple(post_tensor_options_kwarg_only), 2440*da0073e9SAndroid Build Coastguard Worker out=tuple(out), 2441*da0073e9SAndroid Build Coastguard Worker ) 2442*da0073e9SAndroid Build Coastguard Worker 2443*da0073e9SAndroid Build Coastguard Worker def __str__(self) -> str: 2444*da0073e9SAndroid Build Coastguard Worker all_arguments: list[str] = [] 2445*da0073e9SAndroid Build Coastguard Worker all_arguments.extend(map(str, self.flat_positional)) 2446*da0073e9SAndroid Build Coastguard Worker if self.flat_kwarg_only or self.out: 2447*da0073e9SAndroid Build Coastguard Worker all_arguments.append("*") 2448*da0073e9SAndroid Build Coastguard Worker all_arguments.extend(map(str, self.flat_kwarg_only)) 2449*da0073e9SAndroid Build Coastguard Worker all_arguments.extend(map(str, self.out)) 2450*da0073e9SAndroid Build Coastguard Worker return ", ".join(all_arguments) 2451*da0073e9SAndroid Build Coastguard Worker 2452*da0073e9SAndroid Build Coastguard Worker def __post_init__(self) -> None: 2453*da0073e9SAndroid Build Coastguard Worker # TODO: These invariants are weirdly asymmetric? 2454*da0073e9SAndroid Build Coastguard Worker # TODO: Fancier types? 2455*da0073e9SAndroid Build Coastguard Worker if self.self_arg is None: 2456*da0073e9SAndroid Build Coastguard Worker assert not self.pre_self_positional 2457*da0073e9SAndroid Build Coastguard Worker if self.tensor_options is None: 2458*da0073e9SAndroid Build Coastguard Worker assert not self.post_tensor_options_kwarg_only 2459*da0073e9SAndroid Build Coastguard Worker 2460*da0073e9SAndroid Build Coastguard Worker # We don't allow any of the following to have argument annotations, 2461*da0073e9SAndroid Build Coastguard Worker # to keep things simple. 2462*da0073e9SAndroid Build Coastguard Worker mutable_pre_self_positionals = [ 2463*da0073e9SAndroid Build Coastguard Worker a 2464*da0073e9SAndroid Build Coastguard Worker for a in self.pre_self_positional 2465*da0073e9SAndroid Build Coastguard Worker if a.annotation is not None and a.annotation.is_write 2466*da0073e9SAndroid Build Coastguard Worker ] 2467*da0073e9SAndroid Build Coastguard Worker assert ( 2468*da0073e9SAndroid Build Coastguard Worker len(mutable_pre_self_positionals) == 0 2469*da0073e9SAndroid Build Coastguard Worker ), "mutable pre_self_positional arguments are not currently supported in the schema" 2470*da0073e9SAndroid Build Coastguard Worker 2471*da0073e9SAndroid Build Coastguard Worker 2472*da0073e9SAndroid Build Coastguard Worker# Names that validly are __iXXX__ indicating inplace operations. 2473*da0073e9SAndroid Build Coastguard Worker# Taken from https://www.python.org/dev/peps/pep-0203/#new-methods 2474*da0073e9SAndroid Build Coastguard Worker# NB: PyTorch hasn't actually implemented all of these 2475*da0073e9SAndroid Build Coastguard WorkerAUGMENTED_ASSIGNMENT_NAMES = [ 2476*da0073e9SAndroid Build Coastguard Worker "add", 2477*da0073e9SAndroid Build Coastguard Worker "sub", 2478*da0073e9SAndroid Build Coastguard Worker "mul", 2479*da0073e9SAndroid Build Coastguard Worker "div", 2480*da0073e9SAndroid Build Coastguard Worker "mod", 2481*da0073e9SAndroid Build Coastguard Worker "pow", 2482*da0073e9SAndroid Build Coastguard Worker "lshift", 2483*da0073e9SAndroid Build Coastguard Worker "rshift", 2484*da0073e9SAndroid Build Coastguard Worker "and", 2485*da0073e9SAndroid Build Coastguard Worker "xor", 2486*da0073e9SAndroid Build Coastguard Worker "or", 2487*da0073e9SAndroid Build Coastguard Worker] 2488*da0073e9SAndroid Build Coastguard Worker 2489*da0073e9SAndroid Build Coastguard Worker 2490*da0073e9SAndroid Build Coastguard Worker# A BaseOperatorName is what we think of the operator name, without 2491*da0073e9SAndroid Build Coastguard Worker# the overload name. Unusually, we don't represent this as just a 2492*da0073e9SAndroid Build Coastguard Worker# string; instead, we directly represent a few important semantic 2493*da0073e9SAndroid Build Coastguard Worker# bits of information we derive from the string: namely whether 2494*da0073e9SAndroid Build Coastguard Worker# or not it's inplace (add_) and whether or not it's a double-underscore 2495*da0073e9SAndroid Build Coastguard Worker# method (__add__) 2496*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 2497*da0073e9SAndroid Build Coastguard Workerclass BaseOperatorName: 2498*da0073e9SAndroid Build Coastguard Worker base: str 2499*da0073e9SAndroid Build Coastguard Worker inplace: bool 2500*da0073e9SAndroid Build Coastguard Worker dunder_method: bool 2501*da0073e9SAndroid Build Coastguard Worker # Note [Overload Ambiguity With Functional Variants] 2502*da0073e9SAndroid Build Coastguard Worker # A handful of operators have both a "mutable" and a "functional" variant. 2503*da0073e9SAndroid Build Coastguard Worker # (native_batch_norm is a good example, although this isn't the case today). 2504*da0073e9SAndroid Build Coastguard Worker # For those operators, the mutable and functional variant take in the same set of 2505*da0073e9SAndroid Build Coastguard Worker # arguments, but have different alias annotations. 2506*da0073e9SAndroid Build Coastguard Worker # this makes it ambiguous when you try to resolve an OverloadPacket into an overload, 2507*da0073e9SAndroid Build Coastguard Worker # given a set of input arguments. 2508*da0073e9SAndroid Build Coastguard Worker # 2509*da0073e9SAndroid Build Coastguard Worker # So instead of making the "functional" variant in this case a real overload, e.g: 2510*da0073e9SAndroid Build Coastguard Worker # native_batch_norm (mutable variant) 2511*da0073e9SAndroid Build Coastguard Worker # native_batch_norm.functional (functional variant) 2512*da0073e9SAndroid Build Coastguard Worker # we make it a new base operator, 2513*da0073e9SAndroid Build Coastguard Worker # native_batch_norm_functional (functional variant) 2514*da0073e9SAndroid Build Coastguard Worker # 2515*da0073e9SAndroid Build Coastguard Worker # In an ideal world, we would probably invert this so the operators were: 2516*da0073e9SAndroid Build Coastguard Worker # native_batch_norm.mutable (mutable variant) 2517*da0073e9SAndroid Build Coastguard Worker # native_batch_norm (functional variant) 2518*da0073e9SAndroid Build Coastguard Worker # 2519*da0073e9SAndroid Build Coastguard Worker # Doing that is BC-breaking though, so we're stuck with the above modeling. 2520*da0073e9SAndroid Build Coastguard Worker functional_overload: bool = False 2521*da0073e9SAndroid Build Coastguard Worker 2522*da0073e9SAndroid Build Coastguard Worker @staticmethod 2523*da0073e9SAndroid Build Coastguard Worker def parse(op: str) -> BaseOperatorName: 2524*da0073e9SAndroid Build Coastguard Worker assert op != "" 2525*da0073e9SAndroid Build Coastguard Worker assert not op.endswith("_out"), ( 2526*da0073e9SAndroid Build Coastguard Worker "_out suffix is reserved and not permitted for operator names; " 2527*da0073e9SAndroid Build Coastguard Worker "did you mean to specify an out overload name instead?" 2528*da0073e9SAndroid Build Coastguard Worker ) 2529*da0073e9SAndroid Build Coastguard Worker m = re.match(r"^__([^_]+)__$", op) 2530*da0073e9SAndroid Build Coastguard Worker if m is not None: 2531*da0073e9SAndroid Build Coastguard Worker dunder_method = True 2532*da0073e9SAndroid Build Coastguard Worker base = m.group(1) 2533*da0073e9SAndroid Build Coastguard Worker if any(base == f"i{n}" for n in AUGMENTED_ASSIGNMENT_NAMES): 2534*da0073e9SAndroid Build Coastguard Worker inplace = True 2535*da0073e9SAndroid Build Coastguard Worker base = base[1:] 2536*da0073e9SAndroid Build Coastguard Worker else: 2537*da0073e9SAndroid Build Coastguard Worker inplace = False 2538*da0073e9SAndroid Build Coastguard Worker # temporary, this is not intrinsically true but 2539*da0073e9SAndroid Build Coastguard Worker # has been historically true for dunder methods 2540*da0073e9SAndroid Build Coastguard Worker # we support (but, if we ever got, say, __int__, this would 2541*da0073e9SAndroid Build Coastguard Worker # be wrong!) 2542*da0073e9SAndroid Build Coastguard Worker assert base[0] != "i" 2543*da0073e9SAndroid Build Coastguard Worker else: 2544*da0073e9SAndroid Build Coastguard Worker dunder_method = False 2545*da0073e9SAndroid Build Coastguard Worker base = op 2546*da0073e9SAndroid Build Coastguard Worker if base[-1] == "_": 2547*da0073e9SAndroid Build Coastguard Worker inplace = True 2548*da0073e9SAndroid Build Coastguard Worker base = base[:-1] 2549*da0073e9SAndroid Build Coastguard Worker else: 2550*da0073e9SAndroid Build Coastguard Worker inplace = False 2551*da0073e9SAndroid Build Coastguard Worker 2552*da0073e9SAndroid Build Coastguard Worker # See Note [Overload Ambiguity With Functional Variants] 2553*da0073e9SAndroid Build Coastguard Worker functional_suffix = "_functional" 2554*da0073e9SAndroid Build Coastguard Worker if base.endswith(functional_suffix): 2555*da0073e9SAndroid Build Coastguard Worker functional_overload = True 2556*da0073e9SAndroid Build Coastguard Worker base = base[: -len(functional_suffix)] 2557*da0073e9SAndroid Build Coastguard Worker # This seems complicated and unnecessary, so banning dunder methods 2558*da0073e9SAndroid Build Coastguard Worker # for now on ops that have a functional + mutable variant (like native_batch_norm). 2559*da0073e9SAndroid Build Coastguard Worker assert not dunder_method and not inplace 2560*da0073e9SAndroid Build Coastguard Worker else: 2561*da0073e9SAndroid Build Coastguard Worker functional_overload = False 2562*da0073e9SAndroid Build Coastguard Worker 2563*da0073e9SAndroid Build Coastguard Worker r = BaseOperatorName( 2564*da0073e9SAndroid Build Coastguard Worker base=base, 2565*da0073e9SAndroid Build Coastguard Worker inplace=inplace, 2566*da0073e9SAndroid Build Coastguard Worker dunder_method=dunder_method, 2567*da0073e9SAndroid Build Coastguard Worker functional_overload=functional_overload, 2568*da0073e9SAndroid Build Coastguard Worker ) 2569*da0073e9SAndroid Build Coastguard Worker assert str(r) == op, f"{str(r)} != {op}" 2570*da0073e9SAndroid Build Coastguard Worker return r 2571*da0073e9SAndroid Build Coastguard Worker 2572*da0073e9SAndroid Build Coastguard Worker def __str__(self) -> str: 2573*da0073e9SAndroid Build Coastguard Worker if self.dunder_method: 2574*da0073e9SAndroid Build Coastguard Worker i = "i" if self.inplace else "" 2575*da0073e9SAndroid Build Coastguard Worker return f"__{i}{self.base}__" 2576*da0073e9SAndroid Build Coastguard Worker else: 2577*da0073e9SAndroid Build Coastguard Worker i = ( 2578*da0073e9SAndroid Build Coastguard Worker "_" 2579*da0073e9SAndroid Build Coastguard Worker if self.inplace 2580*da0073e9SAndroid Build Coastguard Worker else "_functional" 2581*da0073e9SAndroid Build Coastguard Worker if self.functional_overload 2582*da0073e9SAndroid Build Coastguard Worker else "" 2583*da0073e9SAndroid Build Coastguard Worker ) 2584*da0073e9SAndroid Build Coastguard Worker return f"{self.base}{i}" 2585*da0073e9SAndroid Build Coastguard Worker 2586*da0073e9SAndroid Build Coastguard Worker 2587*da0073e9SAndroid Build Coastguard Worker# Operator name is the base operator name along with the (typically not 2588*da0073e9SAndroid Build Coastguard Worker# user visible) overload string. 2589*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 2590*da0073e9SAndroid Build Coastguard Workerclass OperatorName: 2591*da0073e9SAndroid Build Coastguard Worker name: BaseOperatorName 2592*da0073e9SAndroid Build Coastguard Worker overload_name: str 2593*da0073e9SAndroid Build Coastguard Worker 2594*da0073e9SAndroid Build Coastguard Worker @staticmethod 2595*da0073e9SAndroid Build Coastguard Worker def parse(op_name: str) -> OperatorName: 2596*da0073e9SAndroid Build Coastguard Worker if "." in op_name: 2597*da0073e9SAndroid Build Coastguard Worker name, overload_name = op_name.split(".", 1) 2598*da0073e9SAndroid Build Coastguard Worker else: 2599*da0073e9SAndroid Build Coastguard Worker name = op_name 2600*da0073e9SAndroid Build Coastguard Worker overload_name = "" 2601*da0073e9SAndroid Build Coastguard Worker r = OperatorName(name=BaseOperatorName.parse(name), overload_name=overload_name) 2602*da0073e9SAndroid Build Coastguard Worker assert str(r) == op_name, f"{str(r)} != {op_name}" 2603*da0073e9SAndroid Build Coastguard Worker return r 2604*da0073e9SAndroid Build Coastguard Worker 2605*da0073e9SAndroid Build Coastguard Worker def __str__(self) -> str: 2606*da0073e9SAndroid Build Coastguard Worker if self.overload_name: 2607*da0073e9SAndroid Build Coastguard Worker return f"{self.name}.{self.overload_name}" 2608*da0073e9SAndroid Build Coastguard Worker else: 2609*da0073e9SAndroid Build Coastguard Worker return f"{self.name}" 2610*da0073e9SAndroid Build Coastguard Worker 2611*da0073e9SAndroid Build Coastguard Worker # NB: This must be synchronized with the naming scheme in 2612*da0073e9SAndroid Build Coastguard Worker # aten/src/ATen/templates/Operators.h 2613*da0073e9SAndroid Build Coastguard Worker # Given a function schema "aten::op.overload(...)", 2614*da0073e9SAndroid Build Coastguard Worker # If there is no overload name, this returns f"{op}" 2615*da0073e9SAndroid Build Coastguard Worker # If there is an overload name, this returns f"{op}_{overload}" 2616*da0073e9SAndroid Build Coastguard Worker def unambiguous_name(self) -> str: 2617*da0073e9SAndroid Build Coastguard Worker if self.overload_name: 2618*da0073e9SAndroid Build Coastguard Worker return f"{self.name}_{self.overload_name}" 2619*da0073e9SAndroid Build Coastguard Worker else: 2620*da0073e9SAndroid Build Coastguard Worker return f"{self.name}" 2621*da0073e9SAndroid Build Coastguard Worker 2622*da0073e9SAndroid Build Coastguard Worker def remove_inplace(self) -> OperatorName: 2623*da0073e9SAndroid Build Coastguard Worker return OperatorName( 2624*da0073e9SAndroid Build Coastguard Worker name=BaseOperatorName( 2625*da0073e9SAndroid Build Coastguard Worker base=self.name.base, 2626*da0073e9SAndroid Build Coastguard Worker inplace=False, 2627*da0073e9SAndroid Build Coastguard Worker dunder_method=self.name.dunder_method, 2628*da0073e9SAndroid Build Coastguard Worker ), 2629*da0073e9SAndroid Build Coastguard Worker overload_name=self.overload_name, 2630*da0073e9SAndroid Build Coastguard Worker ) 2631*da0073e9SAndroid Build Coastguard Worker 2632*da0073e9SAndroid Build Coastguard Worker def with_overload(self, overload: str) -> OperatorName: 2633*da0073e9SAndroid Build Coastguard Worker return OperatorName( 2634*da0073e9SAndroid Build Coastguard Worker name=BaseOperatorName( 2635*da0073e9SAndroid Build Coastguard Worker base=self.name.base, 2636*da0073e9SAndroid Build Coastguard Worker inplace=False, 2637*da0073e9SAndroid Build Coastguard Worker dunder_method=self.name.dunder_method, 2638*da0073e9SAndroid Build Coastguard Worker ), 2639*da0073e9SAndroid Build Coastguard Worker overload_name=overload, 2640*da0073e9SAndroid Build Coastguard Worker ) 2641*da0073e9SAndroid Build Coastguard Worker 2642*da0073e9SAndroid Build Coastguard Worker 2643*da0073e9SAndroid Build Coastguard Workerdef gets_generated_out_inplace_wrapper( 2644*da0073e9SAndroid Build Coastguard Worker f: NativeFunction, g: NativeFunctionsGroup, b: BackendIndex 2645*da0073e9SAndroid Build Coastguard Worker) -> bool: 2646*da0073e9SAndroid Build Coastguard Worker return ( 2647*da0073e9SAndroid Build Coastguard Worker f.func.kind() is not SchemaKind.functional 2648*da0073e9SAndroid Build Coastguard Worker and not b.has_kernel(f) 2649*da0073e9SAndroid Build Coastguard Worker and b.has_kernel(g.functional) 2650*da0073e9SAndroid Build Coastguard Worker ) 2651*da0073e9SAndroid Build Coastguard Worker 2652*da0073e9SAndroid Build Coastguard Worker 2653*da0073e9SAndroid Build Coastguard Worker# NativeFunction objects that are views (f.is_view_op returns True) 2654*da0073e9SAndroid Build Coastguard Worker# are added into a `NativeFunctionsViewGroup`, which we can use to 2655*da0073e9SAndroid Build Coastguard Worker# easily access the generated (optional) view_copy NativeFunction. 2656*da0073e9SAndroid Build Coastguard Worker# It's convenient to group them together, so we pair them up in NativeFunctionsViewGroup. 2657*da0073e9SAndroid Build Coastguard Worker# See Note [Codegen'd {view}_copy Operators] 2658*da0073e9SAndroid Build Coastguard Worker# 2659*da0073e9SAndroid Build Coastguard Worker# One property of this representation is that in order for a view-like op to be part of 2660*da0073e9SAndroid Build Coastguard Worker# a NativeFunctionsViewGroup, the "aliasing" version of that view op must exist. 2661*da0073e9SAndroid Build Coastguard Worker# There's one case where that doesn't happen: we have a non-aliasing `narrow_copy.out` op, 2662*da0073e9SAndroid Build Coastguard Worker# but don't have corresponding aliasing `narrow.out` op. 2663*da0073e9SAndroid Build Coastguard Worker# This means that `narrow_copy.out` won't appear as a NativeFunctionsViewGroup. 2664*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 2665*da0073e9SAndroid Build Coastguard Workerclass NativeFunctionsViewGroup: 2666*da0073e9SAndroid Build Coastguard Worker view: NativeFunction 2667*da0073e9SAndroid Build Coastguard Worker # Note: the {view}_copy operator is optional because we currently don't generate copy variants 2668*da0073e9SAndroid Build Coastguard Worker # for all view ops. Notably, we don't generate them for CompositeImplicitAutograd views 2669*da0073e9SAndroid Build Coastguard Worker # (we already get them "for free" through decomposition) 2670*da0073e9SAndroid Build Coastguard Worker view_copy: NativeFunction | None 2671*da0073e9SAndroid Build Coastguard Worker # view_inplace ops are also optional, but every view_inplace op should have out-of-place variant. 2672*da0073e9SAndroid Build Coastguard Worker view_inplace: NativeFunction | None 2673*da0073e9SAndroid Build Coastguard Worker 2674*da0073e9SAndroid Build Coastguard Worker def __post_init__(self) -> None: 2675*da0073e9SAndroid Build Coastguard Worker assert self.view.is_view_op 2676*da0073e9SAndroid Build Coastguard Worker if self.view_copy is None: 2677*da0073e9SAndroid Build Coastguard Worker assert not gets_generated_view_copy(self.view), ( 2678*da0073e9SAndroid Build Coastguard Worker f"{str(self.view.func.name)} appears to be a new operator that aliases its inputs." 2679*da0073e9SAndroid Build Coastguard Worker " The codegen expects you to add a corresponding operator to native_functions.yaml:" 2680*da0073e9SAndroid Build Coastguard Worker f" {get_view_copy_name(self.view)!s}." 2681*da0073e9SAndroid Build Coastguard Worker " See Note [view_copy NativeFunctions] for details." 2682*da0073e9SAndroid Build Coastguard Worker ) 2683*da0073e9SAndroid Build Coastguard Worker else: 2684*da0073e9SAndroid Build Coastguard Worker assert self.view_copy.func.name.name.base.endswith(("_copy", "_scatter")) 2685*da0073e9SAndroid Build Coastguard Worker assert self.view.func.signature() == self.view_copy.func.signature( 2686*da0073e9SAndroid Build Coastguard Worker strip_view_copy_name=True, 2687*da0073e9SAndroid Build Coastguard Worker ) 2688*da0073e9SAndroid Build Coastguard Worker assert "view_copy" in self.view_copy.tags, ( 2689*da0073e9SAndroid Build Coastguard Worker f"{str(self.view_copy.func.name), str(self.view.tags)} appears to be a view_copy operator. The codegen expects" 2690*da0073e9SAndroid Build Coastguard Worker " view_copy operators to be annotated with the 'view_copy' tag in native_functions.yaml." 2691*da0073e9SAndroid Build Coastguard Worker " See Note [view_copy NativeFunction] for details." 2692*da0073e9SAndroid Build Coastguard Worker ) 2693*da0073e9SAndroid Build Coastguard Worker if self.view_inplace is not None: 2694*da0073e9SAndroid Build Coastguard Worker assert self.view.func.signature() == self.view_inplace.func.signature() 2695*da0073e9SAndroid Build Coastguard Worker 2696*da0073e9SAndroid Build Coastguard Worker if self.view.has_composite_implicit_autograd_kernel: 2697*da0073e9SAndroid Build Coastguard Worker if self.view_inplace is not None: 2698*da0073e9SAndroid Build Coastguard Worker assert self.view_inplace.has_composite_implicit_autograd_kernel, ( 2699*da0073e9SAndroid Build Coastguard Worker f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either" 2700*da0073e9SAndroid Build Coastguard Worker " both have CompositeImplicitAutograd kernels, or both not have composite kernels." 2701*da0073e9SAndroid Build Coastguard Worker ) 2702*da0073e9SAndroid Build Coastguard Worker if self.view.has_composite_implicit_autograd_nested_tensor_kernel: 2703*da0073e9SAndroid Build Coastguard Worker if self.view_inplace is not None: 2704*da0073e9SAndroid Build Coastguard Worker assert ( 2705*da0073e9SAndroid Build Coastguard Worker self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel 2706*da0073e9SAndroid Build Coastguard Worker ), ( 2707*da0073e9SAndroid Build Coastguard Worker f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either" 2708*da0073e9SAndroid Build Coastguard Worker " both have CompositeImplicitAutogradNestedTensor kernels, or both not have composite kernels." 2709*da0073e9SAndroid Build Coastguard Worker ) 2710*da0073e9SAndroid Build Coastguard Worker 2711*da0073e9SAndroid Build Coastguard Worker def functions(self, *, include_copy: bool = True) -> Iterator[NativeFunction]: 2712*da0073e9SAndroid Build Coastguard Worker yield self.view 2713*da0073e9SAndroid Build Coastguard Worker if self.view_inplace is not None: 2714*da0073e9SAndroid Build Coastguard Worker yield self.view_inplace 2715*da0073e9SAndroid Build Coastguard Worker if self.view_copy is not None and include_copy: 2716*da0073e9SAndroid Build Coastguard Worker yield self.view_copy 2717*da0073e9SAndroid Build Coastguard Worker 2718*da0073e9SAndroid Build Coastguard Worker @property 2719*da0073e9SAndroid Build Coastguard Worker def root_name(self) -> str: 2720*da0073e9SAndroid Build Coastguard Worker return self.view.root_name 2721*da0073e9SAndroid Build Coastguard Worker 2722*da0073e9SAndroid Build Coastguard Worker @property 2723*da0073e9SAndroid Build Coastguard Worker def composite(self) -> bool: 2724*da0073e9SAndroid Build Coastguard Worker # We currently assert that the "group" is consistent. 2725*da0073e9SAndroid Build Coastguard Worker # If the view op is composite, then its view_inplace op is too. 2726*da0073e9SAndroid Build Coastguard Worker return self.view.has_composite_implicit_autograd_kernel 2727*da0073e9SAndroid Build Coastguard Worker 2728*da0073e9SAndroid Build Coastguard Worker 2729*da0073e9SAndroid Build Coastguard Workerdef gets_generated_view_copy(f: NativeFunction) -> bool: 2730*da0073e9SAndroid Build Coastguard Worker # Only aliasing (view) operators get a copy variant. 2731*da0073e9SAndroid Build Coastguard Worker if not f.is_view_op: 2732*da0073e9SAndroid Build Coastguard Worker return False 2733*da0073e9SAndroid Build Coastguard Worker # We don't need to bother generating copy variants for CompositeImplicitAutograd ops, 2734*da0073e9SAndroid Build Coastguard Worker # because we can let them decompose into base view ops. 2735*da0073e9SAndroid Build Coastguard Worker if f.has_composite_implicit_autograd_kernel: 2736*da0073e9SAndroid Build Coastguard Worker return False 2737*da0073e9SAndroid Build Coastguard Worker # We also don't need to generate copy variants for inplace views. 2738*da0073e9SAndroid Build Coastguard Worker if "inplace_view" in f.tags: 2739*da0073e9SAndroid Build Coastguard Worker return False 2740*da0073e9SAndroid Build Coastguard Worker # Assume ops ending in _inverse have manually-defined copy variants 2741*da0073e9SAndroid Build Coastguard Worker # (e.g. slice_inverse() has the copy variant slice_scatter()). 2742*da0073e9SAndroid Build Coastguard Worker # We -could- probably generate these as well, but the codegen will be 2743*da0073e9SAndroid Build Coastguard Worker # slightly different, and hand-writing these few kernels keeps codegen 2744*da0073e9SAndroid Build Coastguard Worker # complexity lower. 2745*da0073e9SAndroid Build Coastguard Worker if f.func.name.name.base.endswith("_inverse"): 2746*da0073e9SAndroid Build Coastguard Worker return False 2747*da0073e9SAndroid Build Coastguard Worker return True 2748*da0073e9SAndroid Build Coastguard Worker 2749*da0073e9SAndroid Build Coastguard Worker 2750*da0073e9SAndroid Build Coastguard Worker# Given a NativeFunction that corresponds to a view op, 2751*da0073e9SAndroid Build Coastguard Worker# returns the OperatorName of the corresponding "copy" variant of the op. 2752*da0073e9SAndroid Build Coastguard Workerdef get_view_copy_name(f: NativeFunction) -> OperatorName: 2753*da0073e9SAndroid Build Coastguard Worker # Right now, when asking for a view op's corresponding "view_copy" name 2754*da0073e9SAndroid Build Coastguard Worker # we assert for sanity that the op is allowed to have a generated view_copy variant. 2755*da0073e9SAndroid Build Coastguard Worker # (We can do this because "gets_generated_view_copy()" tell us which ops get a generated view_copy op). 2756*da0073e9SAndroid Build Coastguard Worker # However, narrow_copy() already exists as an op directly in native_functions.yaml. 2757*da0073e9SAndroid Build Coastguard Worker # I'm hardcoding narrow_copy here for now to maintain the assert, 2758*da0073e9SAndroid Build Coastguard Worker # But we could also just get rid of the assert. 2759*da0073e9SAndroid Build Coastguard Worker list_of_ops_with_explicit_view_copy_operators = ["narrow"] 2760*da0073e9SAndroid Build Coastguard Worker if str(f.func.name) not in list_of_ops_with_explicit_view_copy_operators: 2761*da0073e9SAndroid Build Coastguard Worker assert gets_generated_view_copy(f) 2762*da0073e9SAndroid Build Coastguard Worker 2763*da0073e9SAndroid Build Coastguard Worker base_name = f"{f.func.name.name.base}_copy" 2764*da0073e9SAndroid Build Coastguard Worker view_copy_name = OperatorName( 2765*da0073e9SAndroid Build Coastguard Worker name=BaseOperatorName( 2766*da0073e9SAndroid Build Coastguard Worker base=base_name, inplace=False, dunder_method=f.func.name.name.dunder_method 2767*da0073e9SAndroid Build Coastguard Worker ), 2768*da0073e9SAndroid Build Coastguard Worker overload_name=f.func.name.overload_name, 2769*da0073e9SAndroid Build Coastguard Worker ) 2770*da0073e9SAndroid Build Coastguard Worker return view_copy_name 2771*da0073e9SAndroid Build Coastguard Worker 2772*da0073e9SAndroid Build Coastguard Worker 2773*da0073e9SAndroid Build Coastguard Worker# Helper functions for parsing argument lists (both inputs and returns) 2774*da0073e9SAndroid Build Coastguard Worker 2775*da0073e9SAndroid Build Coastguard Worker 2776*da0073e9SAndroid Build Coastguard Workerdef parse_returns(return_decl: str) -> tuple[Return, ...]: 2777*da0073e9SAndroid Build Coastguard Worker """ 2778*da0073e9SAndroid Build Coastguard Worker Input: '()' 2779*da0073e9SAndroid Build Coastguard Worker Output: [] 2780*da0073e9SAndroid Build Coastguard Worker """ 2781*da0073e9SAndroid Build Coastguard Worker if return_decl == "()": 2782*da0073e9SAndroid Build Coastguard Worker return () 2783*da0073e9SAndroid Build Coastguard Worker if return_decl[0] == "(" and return_decl[-1] == ")": 2784*da0073e9SAndroid Build Coastguard Worker return_decl = return_decl[1:-1] 2785*da0073e9SAndroid Build Coastguard Worker return tuple(Return.parse(arg) for arg in return_decl.split(", ")) 2786*da0073e9SAndroid Build Coastguard Worker 2787*da0073e9SAndroid Build Coastguard Worker 2788*da0073e9SAndroid Build Coastguard Worker# A Precompute instance consists of a map from kernel argument name 2789*da0073e9SAndroid Build Coastguard Worker# to the list of Argument instances that should replace that 2790*da0073e9SAndroid Build Coastguard Worker# kernel argument in the impl function. 2791*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 2792*da0073e9SAndroid Build Coastguard Workerclass Precompute: 2793*da0073e9SAndroid Build Coastguard Worker # A map from kernel argument name -> a list of precomputed 2794*da0073e9SAndroid Build Coastguard Worker # elements that replaces/supersedes it. 2795*da0073e9SAndroid Build Coastguard Worker replace: dict[str, list[Argument]] 2796*da0073e9SAndroid Build Coastguard Worker # List of precomputed args added without replacement 2797*da0073e9SAndroid Build Coastguard Worker add: list[Argument] 2798*da0073e9SAndroid Build Coastguard Worker 2799*da0073e9SAndroid Build Coastguard Worker @staticmethod 2800*da0073e9SAndroid Build Coastguard Worker def parse(src: object) -> Precompute: 2801*da0073e9SAndroid Build Coastguard Worker assert isinstance(src, list) 2802*da0073e9SAndroid Build Coastguard Worker 2803*da0073e9SAndroid Build Coastguard Worker # src is a list of strings of the format: 2804*da0073e9SAndroid Build Coastguard Worker # {kernel param name} -> {replacement decl}[, {replacement decl}, ...] 2805*da0073e9SAndroid Build Coastguard Worker # [{add decl}[, {add decl}, ...]] 2806*da0073e9SAndroid Build Coastguard Worker # The last line is optional and contains the precomputed parameters that are 2807*da0073e9SAndroid Build Coastguard Worker # added without replacement. 2808*da0073e9SAndroid Build Coastguard Worker # The other lines are parsed to get the names of which precomputed elements 2809*da0073e9SAndroid Build Coastguard Worker # should replace which kernel arguments. 2810*da0073e9SAndroid Build Coastguard Worker add_args = [] 2811*da0073e9SAndroid Build Coastguard Worker if " -> " not in src[-1]: 2812*da0073e9SAndroid Build Coastguard Worker add_list = src[-1].split(",") 2813*da0073e9SAndroid Build Coastguard Worker add_args = [Argument.parse(name.strip()) for name in add_list] 2814*da0073e9SAndroid Build Coastguard Worker src = src[:-1] 2815*da0073e9SAndroid Build Coastguard Worker 2816*da0073e9SAndroid Build Coastguard Worker replace = {} 2817*da0073e9SAndroid Build Coastguard Worker for raw_replace_item in src: 2818*da0073e9SAndroid Build Coastguard Worker assert isinstance(raw_replace_item, str) 2819*da0073e9SAndroid Build Coastguard Worker assert " -> " in raw_replace_item, ( 2820*da0073e9SAndroid Build Coastguard Worker "precomputed parameters without replacement" 2821*da0073e9SAndroid Build Coastguard Worker " are allowed only in the last line" 2822*da0073e9SAndroid Build Coastguard Worker ) 2823*da0073e9SAndroid Build Coastguard Worker 2824*da0073e9SAndroid Build Coastguard Worker arg, with_list_raw = raw_replace_item.split(" -> ") 2825*da0073e9SAndroid Build Coastguard Worker assert ( 2826*da0073e9SAndroid Build Coastguard Worker " " not in arg 2827*da0073e9SAndroid Build Coastguard Worker ), f"illegal kernel param name '{arg}' in precomputed parameters'" 2828*da0073e9SAndroid Build Coastguard Worker with_list = with_list_raw.split(",") 2829*da0073e9SAndroid Build Coastguard Worker with_list_args = [Argument.parse(name.strip()) for name in with_list] 2830*da0073e9SAndroid Build Coastguard Worker replace[arg] = with_list_args 2831*da0073e9SAndroid Build Coastguard Worker 2832*da0073e9SAndroid Build Coastguard Worker r = Precompute(replace=replace, add=add_args) 2833*da0073e9SAndroid Build Coastguard Worker assert r.to_list() == src, "r.to_list() != src" 2834*da0073e9SAndroid Build Coastguard Worker return r 2835*da0073e9SAndroid Build Coastguard Worker 2836*da0073e9SAndroid Build Coastguard Worker def __post_init__(self) -> None: 2837*da0073e9SAndroid Build Coastguard Worker # the template parameters are upper so if these are the 2838*da0073e9SAndroid Build Coastguard Worker # same then it is ambiguous 2839*da0073e9SAndroid Build Coastguard Worker for a in self.add: 2840*da0073e9SAndroid Build Coastguard Worker assert a.name.upper() != a.name 2841*da0073e9SAndroid Build Coastguard Worker for args in self.replace.values(): 2842*da0073e9SAndroid Build Coastguard Worker for a in args: 2843*da0073e9SAndroid Build Coastguard Worker assert a.name.upper() != a.name 2844*da0073e9SAndroid Build Coastguard Worker 2845*da0073e9SAndroid Build Coastguard Worker def to_list(self) -> list[str]: 2846*da0073e9SAndroid Build Coastguard Worker replace_list = [] 2847*da0073e9SAndroid Build Coastguard Worker for kernel_param, replacement_params in self.replace.items(): 2848*da0073e9SAndroid Build Coastguard Worker replacements = ", ".join(str(param) for param in replacement_params) 2849*da0073e9SAndroid Build Coastguard Worker replace_list.append(f"{kernel_param} -> {replacements}") 2850*da0073e9SAndroid Build Coastguard Worker 2851*da0073e9SAndroid Build Coastguard Worker return replace_list 2852