xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/boxing/KernelFunction.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/ATen_fwd.h>
4 #include <ATen/core/boxing/BoxedKernel.h>
5 #include <ATen/core/stack.h>
6 #include <c10/core/DispatchKeySet.h>
7 #include <c10/util/intrusive_ptr.h>
8 #include <c10/util/TypeList.h>
9 #include <type_traits>
10 
11 namespace c10 {
12 
13 using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack to the c10 namespace.
14 
15 class OperatorHandle;
16 struct OperatorKernel;
17 class KernelFunction;
18 
19 template <typename T>
20 using has_symint =
21   std::disjunction<
22     std::is_same<c10::SymInt, T>,
23     std::is_same<c10::SymIntArrayRef, T>,
24     std::is_same<at::OptionalSymIntArrayRef, T>,
25     std::is_same<std::optional<c10::SymInt>, T>
26   >;
27 
28 template <typename T>
29 struct remove_symint {
30   using type = T;
31 };
32 
33 template <>
34 struct remove_symint<c10::SymInt> {
35   using type = int64_t;
36 };
37 
38 template <>
39 struct remove_symint<at::OptionalSymIntArrayRef> {
40   using type = OptionalIntArrayRef;
41 };
42 
43 template <>
44 struct remove_symint<c10::SymIntArrayRef> {
45   using type = c10::IntArrayRef;
46 };
47 
48 template <>
49 struct remove_symint<std::optional<c10::SymInt>> {
50   using type = std::optional<int64_t>;
51 };
52 
53 
54 template <bool symint, typename T>
55 struct maybe_keep_symint final {};
56 
57 template <typename T>
58 struct maybe_keep_symint<true, T> { using type = T; };
59 
60 template <typename T>
61 struct maybe_keep_symint<false, T> { using type = typename remove_symint<T>::type; };
62 
63 template <typename T>
64 using fn_has_symint = typename guts::typelist::true_for_any_type<
65   has_symint,
66   typename guts::infer_function_traits<T>::type::parameter_types
67 >;
68 
69 template <typename T>
70 struct fn_remove_symint;
71 
72 template <typename Ret, typename... Args>
73 struct fn_remove_symint<Ret(Args...)> {
74   using type = Ret(typename remove_symint<Args>::type...);
75 };
76 
77 /**
78  * KernelFunction is similar to std::function but stores a kernel function.
79  * You can create a KernelFunction from a boxed or unboxed function/functor/lambda
80  * and call it in a boxed or unboxed way. If the way it was created doesn't
81  * match the way it was called, it will do boxing or unboxing as necessary.
82  */
83 class TORCH_API KernelFunction final {
84 public:
85   using InternalBoxedKernelFunction = BoxedKernel::InternalBoxedKernelFunction;
86   using BoxedKernelFunction = BoxedKernel::BoxedKernelFunction;
87   using BoxedKernelFunction_withDispatchKeys = BoxedKernel::BoxedKernelFunction_withDispatchKeys;
88 
89   KernelFunction();
90 
91   // Fast path for dispatch to allow not touching the boxed kernel in
92   // the common case where unboxed is available.
93   bool isValidUnboxed() const;
94   bool isValidSymUnboxed() const;
95   bool isValid() const;
96   bool isFallthrough() const;
97 
98   /**
99    * Call the function in a boxed way.
100    * If the kernel function was created with an unboxed function,
101    * this will call an unboxing wrapper which then calls into that
102    * unboxed function.
103    *
104    * Example:
105    *
106    * > void boxed_func(OperatorKernel*, Stack* stack) {...}
107    * > KernelFunction func = KernelFunction::makeFromBoxedFunction(&boxed_func);
108    * > Tensor result = func.callBoxed(stack);
109    *
110    * Or, with an unboxed implementation:
111    *
112    * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
113    * >      [] (Tensor a, bool b) -> Tensor {...});
114    * > Tensor result = func.callBoxed(stack);
115    */
116   void callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const;
117 
118   /**
119    * Call the function in an unboxed way.
120    * If the kernel function was created with a boxed function,
121    * this will box all inputs and then call into that boxed function.
122    *
123    * Note that this doesn't work for all types yet.
124    *
125    * Example:
126    *
127    * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
128    * >      [] (Tensor a, bool b) -> Tensor {...});
129    * > Tensor result = func.call<Tensor, Tensor, bool>(tensor1, true);
130    *
131    * Or, with a boxed implementation:
132    *
133    * > void boxed_func(OperatorKernel*, Stack* stack) {...}
134    * > KernelFunction func = KernelFunction::makeFromBoxedFunction(&boxed_func);
135    * > Tensor result = func.call<Tensor, Tensor, bool>(tensor1, true);
136    */
137   template<class Return, class... Args>
138   Return call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const;
139 
140   /**
141    * Create a KernelFunction from a BoxedKernel.
142    */
143   static KernelFunction makeFromBoxedKernel(BoxedKernel boxed_fn);
144 
145   /**
146    * Create a KernelFunction from a boxed function.
147    *
148    * Example:
149    *
150    * > void boxed_func(OperatorKernel*, Stack* stack) {...}
151    * > KernelFunction func = KernelFunction::makeFromBoxedFunction<&boxed_func>();
152    */
153   template<BoxedKernelFunction* func>
154   static KernelFunction makeFromBoxedFunction();
155 
156   /**
157    * TODO: This will only be useful if we write a backend fallback that plumbs dispatch keys (currently there are none)
158    * See Note [Plumbing Keys Through The Dispatcher] for details.
159    */
160   template<BoxedKernelFunction_withDispatchKeys* func>
161   static KernelFunction makeFromBoxedFunction();
162 
163   /**
164    * Create a KernelFunction from an unboxed functor.
165    *
166    * Example:
167    *
168    * > class MyFunctor final : public c10::OperatorKernel {
169    * >   public:
170    * >     Tensor operator()(Tensor a, Tensor b) {...}
171    * > };
172    * > KernelFunction func = KernelFunction::makeFromUnboxedFunctor<MyFunctor>(std::make_unique<MyFunctor>());
173    */
174   template<bool AllowLegacyTypes = false, class KernelFunctor>
175   static KernelFunction makeFromUnboxedFunctor(std::unique_ptr<OperatorKernel> kernelFunctor);
176 
177   /**
178    * Create a KernelFunction from a boxed functor.
179    *
180    * Example:
181    *
182    * > class MyFunctor final : public c10::OperatorKernel {
183    * >   public:
184    * >     void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...}
185    * > };
186    * > KernelFunction func = KernelFunction::makeFromBoxedFunctor(std::make_unique<MyFunctor>());
187    */
188   template<class KernelFunctor>
189   static KernelFunction makeFromBoxedFunctor(std::unique_ptr<KernelFunctor> kernelFunctor);
190 
191   /**
192    * Create a KernelFunction from an unboxed function.
193    * This is usually better than KernelFunction::makeFromUnboxedRuntimeFunction
194    * because knowing the function pointer as a template argument (i.e. at
195    * compile time) allows the compiler to inline the function into its
196    * unboxing wrapper and yields better performance when calling the function.
197    *
198    * Example:
199    *
200    * > Tensor unboxed_func(Tensor a, Tensor b) {...}
201    * > KernelFunction func = KernelFunction::makeFromUnboxedFunction<decltype(unboxed_func), &unboxed_func>();
202    */
203   template<class FuncPtr, bool AllowLegacyTypes = false>
204   static KernelFunction makeFromUnboxedFunction(FuncPtr);
205 
206   /**
207    * Create a KernelFunction from an unboxed function.
208    * KernelFunction::makeFromUnboxedFunction is usually a better choice than
209    * this if you know the function pointer at compile time, see doc comment
210    * there for an explanation.
211    *
212    * Example:
213    *
214    * > Tensor unboxed_func(Tensor a, Tensor b) {...}
215    * > KernelFunction func = KernelFunction::makeFromUnboxedRuntimeFunction(&unboxed_func);
216    */
217   template<bool AllowLegacyTypes = false, class FuncType>
218   static KernelFunction makeFromUnboxedRuntimeFunction(FuncType* func);
219 
220   static KernelFunction makeFallthrough();
221   static KernelFunction makeAmbiguousAutogradOther();
222   static KernelFunction makeNamedNotSupported();
223 
224   /**
225    * Create a KernelFunction from an unboxed lambda.
226    *
227    * Example:
228    *
229    * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
230    * >      [] (Tensor a, bool b) -> Tensor {...});
231    */
232   template<bool AllowLegacyTypes = false, class Lambda>
233   static std::enable_if_t<guts::is_stateless_lambda<std::decay_t<Lambda>>::value, KernelFunction> makeFromUnboxedLambda(Lambda&& lambda);
234   template<bool AllowLegacyTypes = false, class Lambda>
235   static std::enable_if_t<!guts::is_stateless_lambda<std::decay_t<Lambda>>::value, KernelFunction> makeFromUnboxedLambda(Lambda&& lambda);
236 
237   std::string dumpState() const;
238   // For testing internal invariants only
239   bool _equalsBoxedAndUnboxed(const KernelFunction&) const;
240 
241 private:
242 
243   explicit KernelFunction(
244       std::unique_ptr<OperatorKernel> functor,
245       InternalBoxedKernelFunction* boxed_kernel_func,
246       void* unboxed_kernel_func,
247       void* sym_unboxed_kernel_func);
248   explicit KernelFunction(
249       BoxedKernel boxed_fn,
250       void* unboxed_kernel_func,
251       void* sym_unboxed_kernel_func);
252 
253   BoxedKernel boxed_kernel_func_;
254   void* unboxed_kernel_func_;
255   void* sym_unboxed_kernel_func_;
256 };
257 
258 }
259 
260 #include <ATen/core/boxing/KernelFunction_impl.h>
261