from __future__ import annotations

from dataclasses import dataclass

import torchgen.api.types as api_types
from torchgen.api import cpp, structured
from torchgen.api.types import (
    ArgName,
    BaseCppType,
    BaseCType,
    Binding,
    ConstRefCType,
    CType,
    NamedCType,
    scalarT,
)
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    DispatchKey,
    FunctionSchema,
    NativeFunctionsGroup,
    Type,
)


def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str:
    assert func.is_out_fn(), "ufunc.kernel_name should only be invoked on out schemas"
    return f"ufunc_{func.name.name}_{dispatch_key}"


def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
    return schema_kernel_name(g.out.func, dispatch_key)


# Tensors are omitted (they are stored in the TensorIterator); everything else
# is passed along.  (Technically, we could pass tensors along too; it would
# just waste argument registers.)
#
# NB: used for CPU only
def dispatchstub_type(t: Type, *, binds: ArgName) -> NamedCType | None:
    # Dispatch stubs are always plain ints (hence symint=False)
    r = cpp.valuetype_type(t, binds=binds, symint=False)
    if r is not None:
        return r

    if t == BaseType(BaseTy.Scalar):
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
    elif t == BaseType(BaseTy.Tensor):
        return None
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")


def opmath_type(scalar_t: BaseCppType) -> BaseCppType:
    if scalar_t == api_types.scalar_t:
        return api_types.opmath_t
    raise NotImplementedError


# NB: Tensor arguments in the constructor are stored as opmath_t, not
# scalar_t: a Tensor in the constructor is a scalar tensor that the functor
# was partially applied to, so it can be higher precision than scalar_t, and
# we want to compute in that higher precision.
#
# NB: CUDA only
def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
    r = cpp.valuetype_type(t, binds=binds, symint=False)
    if r is not None:
        return r

    if t == BaseType(BaseTy.Scalar):
        return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
    elif t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")


# Only Tensors ever get passed directly to operator()
#
# NB: CUDA only
# (Actually, this works for CPU too)
def ufunctor_apply_type(
    t: Type, *, binds: ArgName, scalar_t: BaseCppType
) -> NamedCType:
    if t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, BaseCType(scalar_t))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")

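# To make the ctor/apply distinction concrete, a sketch (argument names are
# illustrative, not prescribed) of how the same Tensor type is typed
# differently in the two positions when scalar_t is the templated placeholder
# api_types.scalar_t:
#
#   ufunctor_ctor_type(BaseType(BaseTy.Tensor), binds="other", scalar_t=api_types.scalar_t)
#       -> NamedCType("other", BaseCType(api_types.opmath_t))   # opmath_t other
#   ufunctor_apply_type(BaseType(BaseTy.Tensor), binds="self", scalar_t=api_types.scalar_t)
#       -> NamedCType("self", BaseCType(api_types.scalar_t))    # scalar_t self
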
# The actual ufunc template function the user writes.  Everything here is
# done in the computation type: compute_t is opmath_t on CUDA and scalar_t
# on CPU.
def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
    r = cpp.valuetype_type(t, binds=binds, symint=False)
    if r is not None:
        return r

    if t == BaseType(BaseTy.Scalar):
        return NamedCType(binds, compute_t)
    elif t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, compute_t)
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")


def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
    return Binding(
        nctype=ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t),
        name=a.name,
        default=None,
        argument=a,
    )


def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
    return Binding(
        nctype=ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t),
        name=a.name,
        default=None,
        argument=a,
    )


def ufunc_argument(a: Argument, compute_t: CType) -> Binding:
    return Binding(
        nctype=ufunc_type(a.type, binds=a.name, compute_t=compute_t),
        name=a.name,
        default=None,
        argument=a,
    )


@dataclass(frozen=True)
class UfunctorBindings:
    ctor: list[Binding]
    apply: list[Binding]


# ufunctors are a CUDA-only concept representing functors that take some of
# their arguments in a host-side constructor and the rest in the device-side
# operator().  E.g.,
#
# template <typename scalar_t>
# struct CUDAFunctorOnSelf_add {
#   using opmath_t = at::opmath_type<scalar_t>;
#   opmath_t other_;
#   opmath_t alpha_;
#   CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
#       : other_(other), alpha_(alpha) {}
#   __device__ scalar_t operator()(scalar_t self) {
#     return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
#   }
# };
#
# The ctor bindings refer to the constructor CUDAFunctorOnSelf_add, while the
# apply bindings refer to the operator() definition.
def ufunctor_arguments(
    g: NativeFunctionsGroup, *, scalar_tensor_idx: int | None, scalar_t: BaseCppType
) -> UfunctorBindings:
    ctor = []
    apply = []
    for a in g.functional.func.arguments.flat_non_out:
        if a.type.is_tensor_like():
            if scalar_tensor_idx == 0:
                # This is the designated scalar tensor: it goes in the ctor
                # (stored as opmath_t), not in apply
                ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
                scalar_tensor_idx = None
            else:
                if scalar_tensor_idx is not None:
                    scalar_tensor_idx -= 1
                apply.append(ufunctor_apply_argument(a, scalar_t=scalar_t))
        else:
            ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
    assert scalar_tensor_idx is None
    return UfunctorBindings(ctor=ctor, apply=apply)


# ufuncs are the inner-loop template functions (written by hand in, e.g.,
# ufunc/add.h) that do the actual computation.  E.g.,
#
# template <typename T>
# C10_HOST_DEVICE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ {
#   return self + alpha * other;
# }
#
# In this file, we refer to T as compute_t, which is bound by the caller.
def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Binding]:
    return [
        ufunc_argument(a, compute_t=compute_t)
        for a in g.functional.func.arguments.flat_non_out
    ]

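# An illustrative sketch (not actual torchgen output) of ufunctor_arguments
# for add: with scalar_tensor_idx=1 (i.e., `other` is the partially applied
# scalar tensor) and scalar_t = api_types.scalar_t, the split comes out
# roughly as
#
#   ctor  = [opmath_t other, opmath_t alpha]
#   apply = [scalar_t self]
#
# which is exactly the argument split of CUDAFunctorOnSelf_add above.
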
# Stubs are the DispatchStub trampolines that CPU kernels use to get to their
# vectorized versions.  E.g.,
#
# using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
# DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
def stub_arguments(g: NativeFunctionsGroup) -> list[Binding]:
    # Stubs drop all tensor arguments (they are implicit in the
    # TensorIterator argument) and keep everything else
    return [
        r
        for a in g.out.func.arguments.flat_non_out
        if not a.type.is_tensor_like()
        for r in structured.argument(a)
    ]
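

# An illustrative sketch tying stub_arguments to the typedef above: for
# add.out, the only non-tensor flat argument is alpha, so the bindings come
# out roughly as
#
#   stub_arguments(g) -> [const at::Scalar & alpha]
#
# i.e., exactly the trailing parameter of structured_binary_fn_alpha.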