from __future__ import annotations

from dataclasses import dataclass

import torchgen.api.types as api_types
from torchgen.api import cpp, structured
from torchgen.api.types import (
    ArgName,
    BaseCppType,
    BaseCType,
    Binding,
    ConstRefCType,
    CType,
    NamedCType,
    scalarT,
)
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    DispatchKey,
    FunctionSchema,
    NativeFunctionsGroup,
    Type,
)


def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str:
    assert func.is_out_fn(), "ufunc.kernel_name should only be invoked on out schemas"
    return f"ufunc_{func.name.name}_{dispatch_key}"


def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
    return schema_kernel_name(g.out.func, dispatch_key)
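# E.g. (illustrative): for the out variant of add,
# "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out)",
# kernel_name with DispatchKey.CPU would produce "ufunc_add_CPU".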


# Tensors are omitted (they are stored in the TensorIterator); everything else
# is passed along.  (Technically, we could pass tensors along too; it just
# wastes argument registers.)
#
# NB: used for CPU only
def dispatchstub_type(t: Type, *, binds: ArgName) -> NamedCType | None:
    # Dispatch stubs take SymInt arguments as plain ints (hence symint=False)
    r = cpp.valuetype_type(t, binds=binds, symint=False)
    if r is not None:
        return r

    if t == BaseType(BaseTy.Scalar):
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
    elif t == BaseType(BaseTy.Tensor):
        return None
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
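# E.g. (illustrative): "Tensor self" is dropped, while "Scalar alpha" becomes
# the "const at::Scalar &" parameter seen in stub signatures like
# structured_binary_fn_alpha at the bottom of this file.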


def opmath_type(scalar_t: BaseCppType) -> BaseCppType:
    if scalar_t == api_types.scalar_t:
        return api_types.opmath_t
    raise NotImplementedError
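# NB: scalar_t and opmath_t here are symbolic placeholder types; the generated
# code defines opmath_t as at::opmath_type<scalar_t> (e.g., float when
# scalar_t is at::Half), as in the CUDAFunctorOnSelf_add example further down.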


# NB: Tensors in the constructor are stored as opmath_t, not scalar_t:
# a Tensor in the constructor means it's a scalar tensor being partially
# applied, so it may be higher precision than scalar_t, and we want to
# compute in that higher precision
#
# NB: CUDA only
def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
    r = cpp.valuetype_type(t, binds=binds, symint=False)
    if r is not None:
        return r

    if t == BaseType(BaseTy.Scalar):
        return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
    elif t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
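# E.g. (illustrative): in CUDAFunctorOnSelf_add below, both the scalar tensor
# "other" and the "Scalar alpha" argument surface as opmath_t ctor parameters.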


# Only Tensors ever get passed directly to operator()
#
# NB: CUDA only
# (Actually, this works for CPU too)
def ufunctor_apply_type(
    t: Type, *, binds: ArgName, scalar_t: BaseCppType
) -> NamedCType:
    if t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, BaseCType(scalar_t))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
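# E.g. (illustrative): this yields the "scalar_t self" parameter of
# "__device__ scalar_t operator()(scalar_t self)" in the example below.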


# The actual ufunc template function the user writes.  Everything here
# is done in the computation type.  compute_t is opmath_t on CUDA and
# scalar_t on CPU
def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
    r = cpp.valuetype_type(t, binds=binds, symint=False)
    if r is not None:
        return r

    if t == BaseType(BaseTy.Scalar):
        return NamedCType(binds, compute_t)
    elif t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, compute_t)
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
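# E.g. (illustrative): with compute_t bound to T, "Tensor self", "Tensor
# other", and "Scalar alpha" all become plain T parameters, matching the
# "T self, T other, T alpha" signature of ufunc::add further down.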


def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
    return Binding(
        nctype=ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t),
        name=a.name,
        default=None,
        argument=a,
    )


def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
    return Binding(
        nctype=ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t),
        name=a.name,
        default=None,
        argument=a,
    )


def ufunc_argument(a: Argument, compute_t: CType) -> Binding:
    return Binding(
        nctype=ufunc_type(a.type, binds=a.name, compute_t=compute_t),
        name=a.name,
        default=None,
        argument=a,
    )


@dataclass(frozen=True)
class UfunctorBindings:
    ctor: list[Binding]
    apply: list[Binding]


# ufunctors are a CUDA-only concept representing functors that take some of
# their arguments in a host-side constructor and the rest in the device-side
# apply.  E.g.,
#
# template <typename scalar_t>
# struct CUDAFunctorOnSelf_add {
#   using opmath_t = at::opmath_type<scalar_t>;
#   opmath_t other_;
#   opmath_t alpha_;
#   CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {}
#   __device__ scalar_t operator()(scalar_t self) {
#     return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
#   }
# };
#
# The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers
# to the operator() definition.
def ufunctor_arguments(
    g: NativeFunctionsGroup, *, scalar_tensor_idx: int | None, scalar_t: BaseCppType
) -> UfunctorBindings:
    ctor = []
    apply = []
    for a in g.functional.func.arguments.flat_non_out:
        if a.type.is_tensor_like():
            if scalar_tensor_idx == 0:
                # this is the designated scalar tensor; capture it in the ctor
                ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
                scalar_tensor_idx = None
            else:
                if scalar_tensor_idx is not None:
                    scalar_tensor_idx -= 1
                apply.append(ufunctor_apply_argument(a, scalar_t=scalar_t))
        else:
            ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
    assert scalar_tensor_idx is None
    return UfunctorBindings(ctor=ctor, apply=apply)
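# E.g. (illustrative): CUDAFunctorOnSelf_add above corresponds to
# scalar_tensor_idx=1: "self" (tensor #0) lands in apply, "other" (tensor #1)
# is captured in the ctor as the scalar tensor, and the non-tensor "alpha"
# goes to the ctor as well.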


# ufuncs are the inner-loop template functions that you write (e.g., in
# ufunc/add.h) which do the actual computation in question.  E.g.,
#
# template <typename T>
# C10_HOST_DEVICE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ {
#   return self + alpha * other;
# }
#
# In this file, we refer to T as compute_t, which is bound by the caller
def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Binding]:
    return [
        ufunc_argument(a, compute_t=compute_t)
        for a in g.functional.func.arguments.flat_non_out
    ]
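# E.g. (illustrative): for add this yields bindings for self, other, and
# alpha, all typed compute_t, matching the ufunc::add signature above.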


# Stubs are the DispatchStub trampolines that CPU kernels use to get to their
# vectorized versions.  E.g.,
#
# using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
# DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
def stub_arguments(g: NativeFunctionsGroup) -> list[Binding]:
    # stubs drop all tensor arguments (they are implicit in the TensorIterator
    # argument) and keep everything else
    return [
        r
        for a in g.out.func.arguments.flat_non_out
        if not a.type.is_tensor_like()
        for r in structured.argument(a)
    ]
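# E.g. (illustrative): for add.out, "self", "other", and "out" are dropped;
# only "alpha" survives, producing the "const Scalar& alpha" parameter of
# structured_binary_fn_alpha above.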
210