1from __future__ import annotations 2 3from torchgen.api import cpp 4from torchgen.api.types import ( 5 ArgName, 6 ArrayRefCType, 7 BaseCType, 8 Binding, 9 ConstRefCType, 10 dimnameListT, 11 intArrayRefT, 12 iOptTensorListRefT, 13 iTensorListRefT, 14 NamedCType, 15 OptionalCType, 16 optionalIntArrayRefT, 17 optionalScalarRefT, 18 optionalTensorRefT, 19 scalarT, 20 tensorT, 21) 22from torchgen.model import ( 23 Argument, 24 BaseTy, 25 BaseType, 26 ListType, 27 NativeFunctionsGroup, 28 OptionalType, 29 SelfArgument, 30 TensorOptionsArguments, 31 Type, 32) 33from torchgen.utils import assert_never 34 35 36# This file describes the translation of JIT schema to the structured functions API. 37# This is similar to native API, but a number of historical problems with native 38# API have been fixed. 39 40 41# Translation of types occurring in JIT arguments to a C++ argument type. 42# NB: For now, mutable doesn't do anything; but it could if we make 43# some more nominal types 44def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType: 45 # If it's a value type, do the value type translation 46 # NB: structured kernels ALWAYS have symint off, since they involve actual 47 # kernels that require real ints. The one exception is the 48 # CompositeExplicitAutograd and the meta function (which could 49 # hypothetically be SymInt), but for simplicity we plan for these to just 50 # be handled in Python 51 r = cpp.valuetype_type(t, symint=False, binds=binds, mutable=mutable) 52 if r is not None: 53 return r 54 55 if isinstance(t, BaseType): 56 if t.name == BaseTy.Tensor: 57 return NamedCType(binds, ConstRefCType(BaseCType(tensorT))) 58 elif t.name == BaseTy.Scalar: 59 return NamedCType(binds, ConstRefCType(BaseCType(scalarT))) 60 else: 61 raise AssertionError(f"base type should have been value type {t}") 62 elif isinstance(t, OptionalType): 63 if t.elem == BaseType(BaseTy.Tensor): 64 return NamedCType(binds, BaseCType(optionalTensorRefT)) 65 elif t.elem == BaseType(BaseTy.Scalar): 66 return NamedCType(binds, BaseCType(optionalScalarRefT)) 67 elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int": 68 return NamedCType(binds, BaseCType(optionalIntArrayRefT)) 69 elem = argumenttype_type(t.elem, mutable=mutable, binds=binds) 70 return NamedCType(binds, OptionalCType(elem.type)) 71 elif isinstance(t, ListType): 72 if t.elem == BaseType(BaseTy.Tensor): 73 return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT))) 74 elif t.elem == OptionalType(BaseType(BaseTy.Tensor)): 75 return NamedCType(binds, BaseCType(iOptTensorListRefT)) 76 # TODO: delete these special cases; see torchgen.api.cpp--these 77 # must be changed in tandem, but there are problems; see 78 # https://github.com/pytorch/pytorch/pull/51485 79 elif str(t.elem) == "int": 80 return NamedCType(binds, BaseCType(intArrayRefT)) 81 elif str(t.elem) == "Dimname": 82 return NamedCType(binds, BaseCType(dimnameListT)) 83 elem = argumenttype_type(t.elem, mutable=mutable, binds=binds) 84 return NamedCType(binds, ArrayRefCType(elem.type)) 85 else: 86 raise AssertionError(f"unrecognized type {repr(t)}") 87 88 89def argument_type(a: Argument, *, binds: ArgName) -> NamedCType: 90 return argumenttype_type(a.type, mutable=a.is_write, binds=binds) 91 92 93# returns_type intentionally omitted, because structured kernels never "return"; 94# instead, they always indirectly report their outputs (in the case of a meta 95# function, by calling set_output; in the case of an impl function, by writing 96# directly into the provided out argument). 97 98 99# Structured kernels are never defaulted 100def argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Binding]: 101 if isinstance(a, Argument): 102 return [ 103 Binding( 104 nctype=argument_type(a, binds=a.name), 105 name=a.name, 106 default=None, 107 argument=a, 108 ) 109 ] 110 elif isinstance(a, SelfArgument): 111 return argument(a.argument) 112 elif isinstance(a, TensorOptionsArguments): 113 raise AssertionError("structured kernels don't support TensorOptions yet") 114 else: 115 assert_never(a) 116 117 118def impl_arguments(g: NativeFunctionsGroup) -> list[Binding]: 119 args: list[Argument | TensorOptionsArguments | SelfArgument] = [] 120 121 if g.out.precomputed: 122 # A list of parameters for the impl function with 123 # certain parameters replaced with precomputed counterparts 124 # as specified in native_functions.yaml. 125 non_out_args_replaced: list[ 126 Argument | TensorOptionsArguments | SelfArgument 127 ] = [] 128 for a in g.out.func.arguments.non_out: 129 if isinstance(a, Argument) and a.name in g.out.precomputed.replace: 130 # If a is in precompute.replace, append the parameters 131 # that should replace it onto non_out_args_replaced. 132 non_out_args_replaced.extend(g.out.precomputed.replace[a.name]) 133 else: 134 # If not, push a as it is. 135 non_out_args_replaced.append(a) 136 137 args.extend(non_out_args_replaced) 138 # g.out.precomputed.add is the list of parameters that are added 139 # without replacement after the non out args and just before the out args 140 args.extend(g.out.precomputed.add) 141 else: 142 args.extend(g.out.func.arguments.non_out) 143 144 args.extend(g.out.func.arguments.out) 145 return [r for arg in args for r in argument(arg)] 146 147 148def meta_arguments(g: NativeFunctionsGroup) -> list[Binding]: 149 args: list[Argument | TensorOptionsArguments | SelfArgument] = [] 150 args.extend(g.functional.func.arguments.non_out) 151 return [r for arg in args for r in argument(arg)] 152 153 154def out_arguments(g: NativeFunctionsGroup) -> list[Binding]: 155 args: list[Argument | TensorOptionsArguments | SelfArgument] = [] 156 args.extend(g.out.func.arguments.out) 157 return [r for arg in args for r in argument(arg)] 158