xref: /aosp_15_r20/external/pytorch/c10/util/TypeTraits.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <functional>
4 #include <type_traits>
5 
6 namespace c10::guts {
7 
/**
 * is_equality_comparable<T> is std::true_type iff the expression `a == b`
 * compiles for two non-const lvalues `a` and `b` of type T, and
 * std::false_type otherwise. Only well-formedness is checked; the result type
 * of `==` is not inspected.
 */
template <typename T, typename Sfinae = void>
struct is_equality_comparable : std::false_type {};

// Selected only when the equality expression below is well-formed; otherwise
// substitution fails quietly and the primary template (false) is used.
template <typename T>
struct is_equality_comparable<
    T,
    std::void_t<decltype(std::declval<T&>() == std::declval<T&>())>>
    : std::true_type {};

// Shorthand for the std::bool_constant produced by the trait.
template <typename T>
using is_equality_comparable_t = typename is_equality_comparable<T>::type;
21 
/**
 * is_hashable<T> is std::true_type iff a default-constructed std::hash<T> can
 * be invoked on a non-const lvalue of type T, and std::false_type otherwise.
 * Types without a std::hash specialization get a "disabled" std::hash that is
 * not default constructible, which makes the probe below fail.
 */
template <typename T, typename Sfinae = void>
struct is_hashable : std::false_type {};

template <typename T>
struct is_hashable<
    T,
    std::void_t<decltype(std::hash<T>()(std::declval<T&>()))>>
    : std::true_type {};

template <typename T>
using is_hashable_t = typename is_hashable<T>::type;
32 
/**
 * is_function_type<T> is std::true_type iff T is a plain function type of the
 * form `Result(Args...)`. Pointers and references to functions, member function
 * pointers, and function types that carry `noexcept`, cv-/ref-qualifiers or a
 * C-style `...` are all reported as std::false_type.
 */
template <typename T>
struct is_function_type : std::false_type {};

template <typename Ret, typename... Params>
struct is_function_type<Ret(Params...)> : std::true_type {};

template <typename T>
using is_function_type_t = typename is_function_type<T>::type;
43 
/**
 * is_instantiation_of<Template, T> is std::true_type iff T is a specialization
 * of the class template Template (e.g. vector<int> is an instantiation of
 * vector), and std::false_type otherwise. Template must take only type
 * parameters; templates with non-type parameters (e.g. std::array) cannot be
 * passed.
 * Example:
 *    is_instantiation_of_t<vector, vector<int>> // true
 *    is_instantiation_of_t<pair, pair<int, string>> // true
 *    is_instantiation_of_t<vector, pair<int, string>> // false
 */
template <template <class...> class Template, class Candidate>
struct is_instantiation_of : std::false_type {};

template <template <class...> class Template, class... Params>
struct is_instantiation_of<Template, Template<Params...>> : std::true_type {};

template <template <class...> class Template, class Candidate>
using is_instantiation_of_t =
    typename is_instantiation_of<Template, Candidate>::type;
57 
namespace detail {
/**
 * strip_class: maps a pointer-to-member-function type (typically the type of
 * `&Functor::operator()`) to the plain function type with the same result and
 * parameter types, e.g. `int (Foo::*)(float) const` -> `int(float)`.
 *
 * Since C++17 `noexcept` is part of a function's type, so `noexcept` member
 * functions need their own specializations. Without them `strip_class_t` is
 * ill-formed for `noexcept` call operators, and `is_functor` rejects e.g.
 * `[](int a) noexcept { return a; }`. The result intentionally drops
 * `noexcept`, so `type` is always a plain `Result(Args...)`.
 *
 * volatile- and ref-qualified member functions are not handled.
 */
template <typename T>
struct strip_class {};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...)> {
  using type = Result(Args...);
};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...) const> {
  using type = Result(Args...);
};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...) noexcept> {
  using type = Result(Args...);
};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...) const noexcept> {
  using type = Result(Args...);
};
template <typename T>
using strip_class_t = typename strip_class<T>::type;
} // namespace detail

/**
 * Evaluates to true_type iff the given class is a Functor, i.e. it has a single
 * non-template `operator()` (so `&Functor::operator()` is well-formed) whose
 * type detail::strip_class understands. Overloaded and generic (`auto`
 * parameter) call operators are not detected.
 */
template <class Functor, class Enable = void>
struct is_functor : std::false_type {};
// strip_class_t only ever names a plain `Result(Args...)` function type, so
// its being well-formed is exactly the condition we need.
template <class Functor>
struct is_functor<
    Functor,
    std::void_t<detail::strip_class_t<decltype(&Functor::operator())>>>
    : std::true_type {};

/**
 * is_stateless_lambda<T> is true iff std::decay_t<T> is a functor (see
 * is_functor) that is implicitly convertible to a function pointer with the
 * signature of its call operator. For lambdas this holds exactly when the
 * lambda has no lambda-capture, not even an unused capture-default such as
 * `[&]`.
 * Example:
 *  auto stateless_lambda = [] (int a) {return a;};
 *  is_stateless_lambda<decltype(stateless_lambda)> // true
 *  auto stateful_lambda = [&] (int a) {return a;};
 *  is_stateless_lambda<decltype(stateful_lambda)> // false
 */
namespace detail {
// Only reachable if is_functor accepts a call-operator type that none of the
// specializations below handle, so they must mirror strip_class's cases.
template <class LambdaType, class FuncType>
struct is_stateless_lambda__ final {
  static_assert(
      !std::is_same_v<LambdaType, LambdaType>,
      "Base case shouldn't be hit");
};
// implementation idea: According to the C++ standard, captureless lambdas are
// convertible to function pointers
template <class LambdaType, class C, class Result, class... Args>
struct is_stateless_lambda__<LambdaType, Result (C::*)(Args...) const>
    : std::is_convertible<LambdaType, Result (*)(Args...)> {};
template <class LambdaType, class C, class Result, class... Args>
struct is_stateless_lambda__<LambdaType, Result (C::*)(Args...)>
    : std::is_convertible<LambdaType, Result (*)(Args...)> {};
// For a noexcept call operator, the closure's conversion function yields a
// pointer to a noexcept function, so probe for exactly that type.
template <class LambdaType, class C, class Result, class... Args>
struct is_stateless_lambda__<LambdaType, Result (C::*)(Args...) const noexcept>
    : std::is_convertible<LambdaType, Result (*)(Args...) noexcept> {};
template <class LambdaType, class C, class Result, class... Args>
struct is_stateless_lambda__<LambdaType, Result (C::*)(Args...) noexcept>
    : std::is_convertible<LambdaType, Result (*)(Args...) noexcept> {};

// case where LambdaType is not even a functor
template <class LambdaType, class Enable = void>
struct is_stateless_lambda_ final : std::false_type {};
// case where LambdaType is a functor
template <class LambdaType>
struct is_stateless_lambda_<
    LambdaType,
    std::enable_if_t<is_functor<LambdaType>::value>>
    : is_stateless_lambda__<LambdaType, decltype(&LambdaType::operator())> {};
} // namespace detail
template <class T>
using is_stateless_lambda = detail::is_stateless_lambda_<std::decay_t<T>>;
128 
/**
 * is_type_condition<C> is std::true_type iff the unary template C looks like a
 * type predicate: C<int> has a member named `value` whose declared type is
 * bool, ignoring const/volatile (as with `static constexpr bool value`).
 * Only the instantiation C<int> is probed.
 * Example:
 *   is_type_condition<std::is_reference>  // true
 *   is_type_condition<std::add_pointer>   // false (no ::value)
 */
template <template <class> class Condition, class Sfinae = void>
struct is_type_condition : std::false_type {};

template <template <class> class Condition>
struct is_type_condition<
    Condition,
    std::enable_if_t<std::is_same<
        std::remove_cv_t<decltype(Condition<int>::value)>,
        bool>::value>>
    : std::true_type {};
142 
/**
 * is_fundamental<T> is true_type iff T is a fundamental type (that is, an
 * arithmetic type, void, or std::nullptr_t). It derives from
 * std::is_fundamental<T> and therefore gives the same answer.
 *
 * Example:
 *   is_fundamental<int>   // true
 *   is_fundamental<int*>  // false
 *
 * We define our own wrapper to work around an MSVC bug. See
 * https://github.com/pytorch/pytorch/issues/30932 for details.
 */
template <class T>
struct is_fundamental : std::is_fundamental<T> {};
151 } // namespace c10::guts
152