xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/boxing/KernelFunction_impl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/boxing/impl/boxing.h>
2 #include <ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h>
3 #include <ATen/core/boxing/impl/WrapFunctionIntoFunctor.h>
4 #include <ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h>
5 
6 #include <c10/util/C++17.h>
7 #include <type_traits>
8 
9 namespace c10 {
10 
KernelFunction()11 inline KernelFunction::KernelFunction()
12     : boxed_kernel_func_()
13     , unboxed_kernel_func_(nullptr)
14     , sym_unboxed_kernel_func_(nullptr)
15 {}
16 
17 inline KernelFunction::KernelFunction(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr)
boxed_kernel_func_(std::move (functor),boxed_kernel_func)18   : boxed_kernel_func_(std::move(functor), boxed_kernel_func)
19   , unboxed_kernel_func_(unboxed_kernel_func)
20   , sym_unboxed_kernel_func_(sym_unboxed_kernel_func)
21 {}
22 
23 inline KernelFunction::KernelFunction(BoxedKernel boxed_fn, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr)
boxed_kernel_func_(std::move (boxed_fn))24   : boxed_kernel_func_(std::move(boxed_fn))
25   , unboxed_kernel_func_(unboxed_kernel_func)
26   , sym_unboxed_kernel_func_(sym_unboxed_kernel_func)
27 {}
28 
isValidUnboxed()29 inline bool KernelFunction::isValidUnboxed() const {
30   return unboxed_kernel_func_ != nullptr;
31 }
32 
isValidSymUnboxed()33 inline bool KernelFunction::isValidSymUnboxed() const {
34   return sym_unboxed_kernel_func_ != nullptr;
35 }
36 
isValid()37 inline bool KernelFunction::isValid() const {
38   return boxed_kernel_func_.isValid();
39 }
40 
isFallthrough()41 inline bool KernelFunction::isFallthrough() const {
42   return boxed_kernel_func_.isFallthrough();
43 }
44 
callBoxed(const OperatorHandle & opHandle,DispatchKeySet dispatchKeySet,Stack * stack)45 inline void KernelFunction::callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const {
46   boxed_kernel_func_.callBoxed(opHandle, dispatchKeySet, stack);
47 }
48 
49 template<class Return, class... Args>
callUnboxedKernelFunction(void * unboxed_kernel_func,OperatorKernel * functor,DispatchKeySet dispatchKeySet,Args &&...args)50 inline Return callUnboxedKernelFunction(void* unboxed_kernel_func, OperatorKernel* functor, DispatchKeySet dispatchKeySet, Args&&... args) {
51     using ActualSignature = Return (OperatorKernel*, DispatchKeySet, Args...);
52     ActualSignature* func = reinterpret_cast<ActualSignature*>(unboxed_kernel_func);
53     return (*func)(functor, dispatchKeySet, std::forward<Args>(args)...);
54 }
55 
56 // This template requires you to explicitly specify the argument you want to
57 // forward; it doesn't work if you try to deduce it
58 // NB: keep this in sync with cloneWithRealTypes in function_schema.cpp
59 
60 template <typename T>
unpackSymInt(T x)61 inline typename remove_symint<T>::type unpackSymInt(T x) { return x; }
62 
63 template <>
unpackSymInt(c10::SymInt x)64 inline typename remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) {
65   return x.guard_int(__FILE__, __LINE__);
66 }
67 
68 template <>
unpackSymInt(c10::SymIntArrayRef x)69 inline typename remove_symint<c10::SymIntArrayRef>::type unpackSymInt(c10::SymIntArrayRef x) {
70   return C10_AS_INTARRAYREF_SLOW(x);
71 }
72 
73 template <>
unpackSymInt(std::optional<c10::SymInt> x)74 inline typename remove_symint<std::optional<c10::SymInt>>::type unpackSymInt(std::optional<c10::SymInt> x) {
75   return x.has_value() ? std::make_optional(x->guard_int(__FILE__, __LINE__)) : std::nullopt;
76 }
77 
78 template <>
unpackSymInt(at::OptionalSymIntArrayRef x)79 inline typename remove_symint<at::OptionalSymIntArrayRef>::type unpackSymInt(at::OptionalSymIntArrayRef x) {
80   return x.has_value() ? std::make_optional(C10_AS_INTARRAYREF_SLOW(*x)) : std::nullopt;
81 }
82 
83 template<class Return, class... Args>
call(const OperatorHandle & opHandle,DispatchKeySet dispatchKeySet,Args...args)84 C10_ALWAYS_INLINE Return KernelFunction::call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const {
85     // note: Args above is intentionally not Args&&. We don't want perfect
86     // forwarding, which would require Args to be deduced, but instead we
87     // want callers to explicitly specify the Args.
88 
89     if constexpr (std::disjunction_v<has_symint<Args>...>) {
90       if (sym_unboxed_kernel_func_ != nullptr) {
91         auto *functor = boxed_kernel_func_.getFunctor();
92         return callUnboxedKernelFunction<Return, Args...>(
93             sym_unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...);
94       }
95 
96       if (unboxed_kernel_func_ != nullptr) {
97         auto *functor = boxed_kernel_func_.getFunctor();
98         return callUnboxedKernelFunction<Return, typename remove_symint<Args>::type...>(
99             unboxed_kernel_func_, functor, dispatchKeySet, unpackSymInt<Args>(args)...);
100       }
101     } else {
102       if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
103         auto *functor = boxed_kernel_func_.getFunctor();
104         return callUnboxedKernelFunction<Return, Args...>(
105             unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...);
106       }
107     }
108 
109     return impl::BoxedKernelWrapper<Return(Args...)>::call(
110         boxed_kernel_func_,
111         opHandle,
112         dispatchKeySet,
113         std::forward<Args>(args)...
114     );
115 }
116 
makeFromBoxedKernel(BoxedKernel boxed_fn)117 inline KernelFunction KernelFunction::makeFromBoxedKernel(BoxedKernel boxed_fn) {
118   return KernelFunction(std::move(boxed_fn), nullptr);  // no unboxed function pointer
119 }
120 
121 template<KernelFunction::BoxedKernelFunction* func>
makeFromBoxedFunction()122 inline KernelFunction KernelFunction::makeFromBoxedFunction() {
123   return KernelFunction::makeFromBoxedKernel(
124       BoxedKernel::makeFromFunction<func>());
125 }
126 
127 template<KernelFunction::BoxedKernelFunction_withDispatchKeys* func>
makeFromBoxedFunction()128 inline KernelFunction KernelFunction::makeFromBoxedFunction() {
129   return KernelFunction::makeFromBoxedKernel(
130       BoxedKernel::makeFromFunction<func>());
131 }
132 
makeFallthrough()133 inline KernelFunction KernelFunction::makeFallthrough() {
134   return KernelFunction::makeFromBoxedKernel(
135       BoxedKernel::makeFallthrough());
136 }
137 
makeAmbiguousAutogradOther()138 inline KernelFunction KernelFunction::makeAmbiguousAutogradOther() {
139   return KernelFunction::makeFromBoxedKernel(
140       BoxedKernel::makeAmbiguousAutogradOther());
141 }
142 
makeNamedNotSupported()143 inline KernelFunction KernelFunction::makeNamedNotSupported() {
144   return KernelFunction::makeFromBoxedKernel(
145       BoxedKernel::makeNamedNotSupported());
146 }
147 
148 template<bool AllowLegacyTypes, class KernelFunctor>
makeFromUnboxedFunctor(std::unique_ptr<OperatorKernel> kernelFunctor)149 inline KernelFunction KernelFunction::makeFromUnboxedFunctor(std::unique_ptr<OperatorKernel> kernelFunctor) {
150 #ifndef NDEBUG
151   // This assertion is costly for build time so it's debug-gated.
152     static_assert(guts::is_functor<KernelFunctor>::value, "Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor> but the argument is not a functor.");
153 #endif
154     static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
155 
156     auto* unboxed_fn = &impl::wrap_kernel_functor_unboxed<KernelFunctor>::call;
157     void* void_unboxed_fn = reinterpret_cast<void*>(unboxed_fn);
158     bool is_symint = fn_has_symint<decltype(unboxed_fn)>::value;
159     return KernelFunction(
160         std::move(kernelFunctor),
161         &impl::make_boxed_from_unboxed_functor<KernelFunctor, AllowLegacyTypes>::call,
162         is_symint ? nullptr : void_unboxed_fn,
163         is_symint ? void_unboxed_fn : nullptr
164     );
165 }
166 
167 template<class KernelFunctor>
makeFromBoxedFunctor(std::unique_ptr<KernelFunctor> kernelFunctor)168 inline KernelFunction KernelFunction::makeFromBoxedFunctor(std::unique_ptr<KernelFunctor> kernelFunctor) {
169   return KernelFunction::makeFromBoxedKernel(
170       BoxedKernel::makeFromFunctor(std::move(kernelFunctor)));
171 }
172 
173 template<class FuncPtr, bool AllowLegacyTypes>
makeFromUnboxedFunction(FuncPtr func_ptr)174 inline KernelFunction KernelFunction::makeFromUnboxedFunction(FuncPtr func_ptr) {
175     static_assert(is_compile_time_function_pointer<FuncPtr>::value, "Tried to call KernelFunction::makeFromUnboxedFunction with an invalid parameter. It must be a function pointer created with TORCH_FN.");
176     static_assert(!std::is_same<typename FuncPtr::FuncType, BoxedKernelFunction>::value, "Tried to call KernelFunction::makeFromUnboxedFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead.");
177     static_assert(FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr");
178 
179 #if !defined(C10_MOBILE)
180     (void)func_ptr; // Suppress unused variable warning
181     return makeFromUnboxedFunctor<AllowLegacyTypes, typename impl::WrapFunctionIntoFunctor<FuncPtr>::type>(
182         guts::make_unique_base<OperatorKernel, typename impl::WrapFunctionIntoFunctor<FuncPtr>::type>()
183     );
184 #else
185     // On mobile, we rather want to optimize for binary size than for performance,
186     // so let's not inline the kernel into the wrapper but use makeFromUnboxedRuntimeFunction
187     // instead.
188     return makeFromUnboxedRuntimeFunction(func_ptr.func_ptr());
189 #endif
190 }
191 
192 template<bool AllowLegacyTypes, class FuncType>
makeFromUnboxedRuntimeFunction(FuncType * func)193 inline KernelFunction KernelFunction::makeFromUnboxedRuntimeFunction(FuncType* func) {
194     static_assert(guts::is_function_type<FuncType>::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a non-function type.");
195     static_assert(!std::is_same<FuncType, BoxedKernelFunction>::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead.");
196     TORCH_INTERNAL_ASSERT(func != nullptr, "Kernel function cannot be nullptr");
197 
198     return makeFromUnboxedFunctor<AllowLegacyTypes, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>(
199         guts::make_unique_base<OperatorKernel, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>(func)
200     );
201 }
202 
203 template<bool AllowLegacyTypes, class Lambda>
makeFromUnboxedLambda(Lambda && lambda)204 inline std::enable_if_t<guts::is_stateless_lambda<std::decay_t<Lambda>>::value, KernelFunction> KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
205     static_assert(guts::is_functor<std::decay_t<Lambda>>::value, "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type.");
206 
207 #if !defined(C10_MOBILE)
208     return makeFromUnboxedFunctor<AllowLegacyTypes, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
209         guts::make_unique_base<OperatorKernel, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(std::forward<Lambda>(lambda))
210     );
211 #else
212     // On mobile, we rather want to optimize for binary size than for performance,
213     // so let's not inline the kernel into the wrapper but use makeFromUnboxedRuntimeFunction
214     // instead.
215     using FuncType = typename guts::infer_function_traits_t<std::decay_t<Lambda>>::func_type;
216     return makeFromUnboxedRuntimeFunction<AllowLegacyTypes, FuncType>(lambda);
217 #endif
218 }
219 
220 template<bool AllowLegacyTypes, class Lambda>
makeFromUnboxedLambda(Lambda && lambda)221 inline std::enable_if_t<!guts::is_stateless_lambda<std::decay_t<Lambda>>::value, KernelFunction> KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
222     static_assert(guts::is_functor<std::decay_t<Lambda>>::value, "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type.");
223 
224     return makeFromUnboxedFunctor<AllowLegacyTypes, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
225         guts::make_unique_base<OperatorKernel, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(std::forward<Lambda>(lambda))
226     );
227 }
228 
229 }
230