xref: /aosp_15_r20/external/pytorch/aten/src/ATen/detail/FunctionTraits.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <cstddef>
4 #include <tuple>
5 
6 // Modified from https://stackoverflow.com/questions/7943525/is-it-possible-to-figure-out-the-parameter-type-and-return-type-of-a-lambda
7 
// Fallback: any class type with a single, non-overloaded, non-template
// operator() (e.g. a non-generic lambda or a simple functor). The traits are
// derived from the type of &T::operator(); the member-pointer specializations
// below reduce that to a plain function type.
// NOTE: generic lambdas and functors with an overloaded operator() are not
// supported, because &T::operator() does not name a single function for them.
template <typename T>
struct function_traits : public function_traits<decltype(&T::operator())> {
};

// Pointers to class members. This covers data members that are themselves
// functors. For example, in the following code:
//
//   template <typename func_t>
//   struct S {
//     func_t f;
//   };
//   template <typename func_t>
//   S<func_t> make_s(func_t f) {
//     return S<func_t>{f};
//   }
//
//   auto s = make_s([] (int, float) -> double { /* ... */ });
//
//   function_traits<decltype(&decltype(s)::f)> traits;
//
// It also matches pointers to non-const member functions that have no more
// specific specialization: R (C::*)(Args...) matches with T = R(Args...), so the
// traits describe R(Args...) WITHOUT the object argument (use invoke_traits if
// the object should appear as the first argument). This is what makes mutable
// lambdas work.
template <typename ClassType, typename T>
struct function_traits<T ClassType::*> : public function_traits<T> {
};

// Const member functions (e.g. the operator() of a non-mutable lambda). The
// class type is dropped; only the return and parameter types are kept.
template <typename ClassType, typename ReturnType, typename... Args>
struct function_traits<ReturnType(ClassType::*)(Args...) const> : public function_traits<ReturnType(Args...)> {
};

#if defined(__cpp_noexcept_function_type)
// Since C++17 noexcept is part of the function type, so `const noexcept` member
// functions (e.g. the operator() of a noexcept lambda) need their own
// specialization; otherwise they fall through to the generic member-pointer
// case above and end up in the fallback with a function type, which does not
// compile.
template <typename ClassType, typename ReturnType, typename... Args>
struct function_traits<ReturnType(ClassType::*)(Args...) const noexcept>
    : public function_traits<ReturnType(Args...)> {
};
#endif

// Reference and pointer types forward to the referenced / pointed-to type, so
// function_traits<F&>, function_traits<F&&> and function pointers all work.
template <typename T>
struct function_traits<T&> : public function_traits<T> {};
template <typename T>
struct function_traits<T&&> : public function_traits<T> {};
template <typename T>
struct function_traits<T*> : public function_traits<T> {};

// Free functions (plain function types). This is where the traits are actually
// defined; the specializations above forward here.
template <typename ReturnType, typename... Args>
struct function_traits<ReturnType(Args...)> {
  // arity is the number of arguments.
  enum { arity = sizeof...(Args) };

  using ArgsTuple = std::tuple<Args...>;
  using result_type = ReturnType;

  // arg<i>::type is the type of the i-th (0-based) argument, i.e. the i-th
  // element of ArgsTuple. An out-of-range i is a compile-time error.
  template <std::size_t i>
  struct arg
  {
      using type = typename std::tuple_element<i, ArgsTuple>::type;
  };
};

#if defined(__cpp_noexcept_function_type)
// noexcept function types (C++17) report the same traits as their
// non-noexcept counterparts. Through the member-pointer specialization this
// also covers non-const noexcept member functions (e.g. mutable noexcept
// lambdas).
template <typename ReturnType, typename... Args>
struct function_traits<ReturnType(Args...) noexcept>
    : public function_traits<ReturnType(Args...)> {
};
#endif

// Convenience wrappers naming the result type (and leading argument types) of
// a callable. They do not check traits::arity; they only name the types.
template <typename T>
struct nullary_function_traits {
  using traits = function_traits<T>;
  using result_type = typename traits::result_type;
};

template <typename T>
struct unary_function_traits {
  using traits = function_traits<T>;
  using result_type = typename traits::result_type;
  using arg1_t = typename traits::template arg<0>::type;
};

template <typename T>
struct binary_function_traits {
  using traits = function_traits<T>;
  using result_type = typename traits::result_type;
  using arg1_t = typename traits::template arg<0>::type;
  using arg2_t = typename traits::template arg<1>::type;
};


// Traits for calling with c10::guts::invoke, where member functions take the
// object as an explicit first argument: ClassType& for non-const member
// functions and const ClassType& for const ones. Every other callable is
// described exactly as by function_traits.
template <typename T>
struct invoke_traits : public function_traits<T>{
};

template <typename T>
struct invoke_traits<T&> : public invoke_traits<T>{
};

template <typename T>
struct invoke_traits<T&&> : public invoke_traits<T>{
};

template <typename ClassType, typename ReturnType, typename... Args>
struct invoke_traits<ReturnType(ClassType::*)(Args...)> :
  public function_traits<ReturnType(ClassType&, Args...)> {
};

template <typename ClassType, typename ReturnType, typename... Args>
struct invoke_traits<ReturnType(ClassType::*)(Args...) const> :
  public function_traits<ReturnType(const ClassType&, Args...)> {
};

#if defined(__cpp_noexcept_function_type)
// noexcept member functions (C++17) get the same treatment; without these they
// would fall through to the primary template above and lose the object
// argument.
template <typename ClassType, typename ReturnType, typename... Args>
struct invoke_traits<ReturnType(ClassType::*)(Args...) noexcept> :
  public function_traits<ReturnType(ClassType&, Args...)> {
};

template <typename ClassType, typename ReturnType, typename... Args>
struct invoke_traits<ReturnType(ClassType::*)(Args...) const noexcept> :
  public function_traits<ReturnType(const ClassType&, Args...)> {
};
#endif
103 };
104