xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/CompileTimeFunctionPointer.h>
4 
5 namespace c10 {
6 namespace impl {
7   namespace detail {
8     template<class FuncPtr, class ReturnType, class ParameterList> class WrapFunctionIntoFunctor_ {};
9     template<class FuncPtr, class ReturnType, class... Parameters>
10     class WrapFunctionIntoFunctor_<FuncPtr, ReturnType, guts::typelist::typelist<Parameters...>> final : public c10::OperatorKernel {
11     public:
decltype(auto)12       C10_ALWAYS_INLINE decltype(auto) operator()(Parameters... args) {
13         return (*FuncPtr::func_ptr())(std::forward<Parameters>(args)...);
14       }
15     };
16   }
17 
18   // WrapFunctionIntoFunctor: Wraps a compile time function pointer into a kernel functor.
19   // Since it is a compile time function pointer, many compilers can inline it
20   // into the wrapper and you don't get any performance overhead for wrapping.
21   template<class FuncPtr>
22   struct WrapFunctionIntoFunctor final {
23     static_assert(c10::is_compile_time_function_pointer<FuncPtr>::value, "WrapFunctionIntoFunctor can only wrap functions created with TORCH_FN.");
24     using type = detail::WrapFunctionIntoFunctor_<
25         FuncPtr,
26         typename guts::function_traits<typename FuncPtr::FuncType>::return_type,
27         typename guts::function_traits<typename FuncPtr::FuncType>::parameter_types
28     >;
29   };
30 }
31 
32 }
33