1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 #if __cplusplus < 201703L 11 #error "This header requires C++17" 12 #endif 13 14 #include <executorch/extension/kernel_util/type_list.h> 15 #include <cstdlib> 16 #include <memory> 17 #include <type_traits> 18 #include <typeinfo> 19 20 namespace executorch { 21 namespace extension { 22 // This extension has a lot of generic internal names like "size"; use a unique 23 // internal namespace to avoid conflicts with other extensions. 24 namespace kernel_util_internal { 25 26 // Check if a given type is a function 27 template <class T> 28 struct is_function_type : std::false_type {}; 29 template <class Result, class... Args> 30 struct is_function_type<Result(Args...)> : std::true_type {}; 31 template <class T> 32 using is_function_type_t = typename is_function_type<T>::type; 33 34 // A compile-time wrapper around a function pointer 35 template <class FuncType_, FuncType_* func_ptr_> 36 struct CompileTimeFunctionPointer final { 37 static_assert( 38 is_function_type<FuncType_>::value, 39 "EXECUTORCH_FN can only wrap function types."); 40 using FuncType = FuncType_; 41 42 static constexpr FuncType* func_ptr() { 43 return func_ptr_; 44 } 45 }; 46 47 // Check if a given type is a compile-time function pointer 48 template <class T> 49 struct is_compile_time_function_pointer : std::false_type {}; 50 template <class FuncType, FuncType* func_ptr> 51 struct is_compile_time_function_pointer< 52 CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {}; 53 54 #define EXECUTORCH_FN_TYPE(func) \ 55 ::executorch::extension::kernel_util_internal::CompileTimeFunctionPointer< \ 56 std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \ 57 func> 58 #define EXECUTORCH_FN(func) EXECUTORCH_FN_TYPE(func)() 59 60 /** 61 * strip_class: helper to remove the class type from pointers to `operator()`. 62 */ 63 template <typename T> 64 struct strip_class {}; 65 template <typename Class, typename Result, typename... Args> 66 struct strip_class<Result (Class::*)(Args...)> { 67 using type = Result(Args...); 68 }; 69 template <typename Class, typename Result, typename... Args> 70 struct strip_class<Result (Class::*)(Args...) const> { 71 using type = Result(Args...); 72 }; 73 template <typename T> 74 using strip_class_t = typename strip_class<T>::type; 75 76 /** 77 * Access information about result type or arguments from a function type. 78 * Example: 79 * using A = function_traits<int (float, double)>::return_type // A == int 80 * using A = function_traits<int (float, double)>::parameter_types::tuple_type 81 * // A == tuple<float, double> 82 */ 83 template <class Func> 84 struct function_traits { 85 static_assert( 86 !std::is_same<Func, Func>::value, 87 "In function_traits<Func>, Func must be a plain function type."); 88 }; 89 template <class Result, class... Args> 90 struct function_traits<Result(Args...)> { 91 using func_type = Result(Args...); 92 using return_type = Result; 93 using parameter_types = typelist<Args...>; 94 static constexpr auto number_of_parameters = sizeof...(Args); 95 }; 96 97 /** 98 * infer_function_traits: creates a `function_traits` type for a simple 99 * function (pointer) or functor (lambda/struct). Currently does not support 100 * class methods. 101 */ 102 template <typename Functor> 103 struct infer_function_traits { 104 using type = function_traits<strip_class_t<decltype(&Functor::operator())>>; 105 }; 106 template <typename Result, typename... Args> 107 struct infer_function_traits<Result (*)(Args...)> { 108 using type = function_traits<Result(Args...)>; 109 }; 110 template <typename Result, typename... Args> 111 struct infer_function_traits<Result(Args...)> { 112 using type = function_traits<Result(Args...)>; 113 }; 114 template <typename T> 115 using infer_function_traits_t = typename infer_function_traits<T>::type; 116 117 } // namespace kernel_util_internal 118 } // namespace extension 119 } // namespace executorch 120