// Source: aosp_15_r20 external/pytorch/aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <type_traits>
#include <utility>

#include <c10/util/TypeTraits.h>
4 
5 namespace c10 {
6 
7 namespace impl {
8   namespace detail {
9     template<class FuncType, class ReturnType, class ParameterList> class WrapFunctionIntoRuntimeFunctor_ {};
10     template<class FuncType, class ReturnType, class... Parameters>
11     class WrapFunctionIntoRuntimeFunctor_<FuncType, ReturnType, guts::typelist::typelist<Parameters...>> final : public c10::OperatorKernel {
12     public:
13       template<class FuncType_>
WrapFunctionIntoRuntimeFunctor_(FuncType_ && kernel_func)14       explicit WrapFunctionIntoRuntimeFunctor_(FuncType_&& kernel_func)
15       : kernel_func_(std::forward<FuncType_>(kernel_func)) {}
16 
operator()17       decltype(auto) operator()(Parameters... args) {
18         return kernel_func_(std::forward<Parameters>(args)...);
19       }
20 
21     private:
22       FuncType kernel_func_;
23     };
24   }
25 
26   // WrapFunctionIntoRuntimeFunctor: Wraps any runtime functor into a functor that
27   // inherits from c10::OperatorKernel, so it can be used as a c10 kernel.
28   // This can, for example, be used for lambdas, functors or even function pointers.
29   // In the case of function pointers, since it is a runtime function pointer,
30   // there is an overhead for calling it whenever the kernel is invoked.
31   template<class FuncType>
32   using WrapFunctionIntoRuntimeFunctor = detail::WrapFunctionIntoRuntimeFunctor_<
33       FuncType,
34       typename guts::infer_function_traits_t<FuncType>::return_type,
35       typename guts::infer_function_traits_t<FuncType>::parameter_types
36   >;
37 }
38 
39 }
40