xref: /aosp_15_r20/external/pytorch/c10/test/util/TypeIndex_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/Metaprogramming.h>
2 #include <c10/util/TypeIndex.h>
3 #include <gtest/gtest.h>
4 
5 using c10::string_view;
6 using c10::util::get_fully_qualified_type_name;
7 using c10::util::get_type_index;
8 
9 // NOLINTBEGIN(modernize-unary-static-assert)
10 namespace {
11 
12 static_assert(get_type_index<int>() == get_type_index<int>(), "");
13 static_assert(get_type_index<float>() == get_type_index<float>(), "");
14 static_assert(get_type_index<int>() != get_type_index<float>(), "");
15 static_assert(
16     get_type_index<int(double, double)>() ==
17         get_type_index<int(double, double)>(),
18     "");
19 static_assert(
20     get_type_index<int(double, double)>() != get_type_index<int(double)>(),
21     "");
22 static_assert(
23     get_type_index<int(double, double)>() ==
24         get_type_index<int (*)(double, double)>(),
25     "");
26 static_assert(
27     get_type_index<std::function<int(double, double)>>() ==
28         get_type_index<std::function<int(double, double)>>(),
29     "");
30 static_assert(
31     get_type_index<std::function<int(double, double)>>() !=
32         get_type_index<std::function<int(double)>>(),
33     "");
34 
35 static_assert(get_type_index<int>() == get_type_index<int&>(), "");
36 static_assert(get_type_index<int>() == get_type_index<int&&>(), "");
37 static_assert(get_type_index<int>() == get_type_index<const int&>(), "");
38 static_assert(get_type_index<int>() == get_type_index<const int>(), "");
39 static_assert(get_type_index<const int>() == get_type_index<int&>(), "");
40 static_assert(get_type_index<int>() != get_type_index<int*>(), "");
41 static_assert(get_type_index<int*>() != get_type_index<int**>(), "");
42 static_assert(
43     get_type_index<int(double&, double)>() !=
44         get_type_index<int(double, double)>(),
45     "");
46 
47 struct Dummy final {};
48 struct Functor final {
49   int64_t operator()(uint32_t, Dummy&&, const Dummy&) const;
50 };
51 static_assert(
52     get_type_index<int64_t(uint32_t, Dummy&&, const Dummy&)>() ==
53         get_type_index<
54             c10::guts::infer_function_traits_t<Functor>::func_type>(),
55     "");
56 
57 namespace test_top_level_name {
58 #if C10_TYPENAME_SUPPORTS_CONSTEXPR
59 static_assert(
60     string_view::npos != get_fully_qualified_type_name<Dummy>().find("Dummy"),
61     "");
62 #endif
TEST(TypeIndex,TopLevelName)63 TEST(TypeIndex, TopLevelName) {
64   EXPECT_NE(
65       string_view::npos, get_fully_qualified_type_name<Dummy>().find("Dummy"));
66 }
67 } // namespace test_top_level_name
68 
69 namespace test_nested_name {
70 struct Dummy final {};
71 
72 #if C10_TYPENAME_SUPPORTS_CONSTEXPR
73 static_assert(
74     string_view::npos !=
75         get_fully_qualified_type_name<Dummy>().find("test_nested_name::Dummy"),
76     "");
77 #endif
TEST(TypeIndex,NestedName)78 TEST(TypeIndex, NestedName) {
79   EXPECT_NE(
80       string_view::npos,
81       get_fully_qualified_type_name<Dummy>().find("test_nested_name::Dummy"));
82 }
83 } // namespace test_nested_name
84 
85 namespace test_type_template_parameter {
86 template <class T>
87 struct Outer final {};
88 struct Inner final {};
89 
90 #if C10_TYPENAME_SUPPORTS_CONSTEXPR
91 static_assert(
92     string_view::npos !=
93         get_fully_qualified_type_name<Outer<Inner>>().find(
94             "test_type_template_parameter::Outer"),
95     "");
96 static_assert(
97     string_view::npos !=
98         get_fully_qualified_type_name<Outer<Inner>>().find(
99             "test_type_template_parameter::Inner"),
100     "");
101 #endif
TEST(TypeIndex,TypeTemplateParameter)102 TEST(TypeIndex, TypeTemplateParameter) {
103   EXPECT_NE(
104       string_view::npos,
105       get_fully_qualified_type_name<Outer<Inner>>().find(
106           "test_type_template_parameter::Outer"));
107   EXPECT_NE(
108       string_view::npos,
109       get_fully_qualified_type_name<Outer<Inner>>().find(
110           "test_type_template_parameter::Inner"));
111 }
112 } // namespace test_type_template_parameter
113 
114 namespace test_nontype_template_parameter {
115 template <size_t N>
116 struct Class final {};
117 
118 #if C10_TYPENAME_SUPPORTS_CONSTEXPR
119 static_assert(
120     string_view::npos !=
121         get_fully_qualified_type_name<Class<38474355>>().find("38474355"),
122     "");
123 #endif
TEST(TypeIndex,NonTypeTemplateParameter)124 TEST(TypeIndex, NonTypeTemplateParameter) {
125   EXPECT_NE(
126       string_view::npos,
127       get_fully_qualified_type_name<Class<38474355>>().find("38474355"));
128 }
129 } // namespace test_nontype_template_parameter
130 
131 namespace test_type_computations_are_resolved {
132 template <class T>
133 struct Type final {
134   using type = const T*;
135 };
136 
137 #if C10_TYPENAME_SUPPORTS_CONSTEXPR
138 static_assert(
139     string_view::npos !=
140         get_fully_qualified_type_name<typename Type<int>::type>().find("int"),
141     "");
142 static_assert(
143     string_view::npos !=
144         get_fully_qualified_type_name<typename Type<int>::type>().find("*"),
145     "");
146 
147 // but with remove_pointer applied, there is no '*' in the type name anymore
148 static_assert(
149     string_view::npos ==
150         get_fully_qualified_type_name<
151             typename std::remove_pointer<typename Type<int>::type>::type>()
152             .find("*"),
153     "");
154 #endif
TEST(TypeIndex,TypeComputationsAreResolved)155 TEST(TypeIndex, TypeComputationsAreResolved) {
156   EXPECT_NE(
157       string_view::npos,
158       get_fully_qualified_type_name<typename Type<int>::type>().find("int"));
159   EXPECT_NE(
160       string_view::npos,
161       get_fully_qualified_type_name<typename Type<int>::type>().find("*"));
162   // but with remove_pointer applied, there is no '*' in the type name anymore
163   EXPECT_EQ(
164       string_view::npos,
165       get_fully_qualified_type_name<
166           typename std::remove_pointer<typename Type<int>::type>::type>()
167           .find("*"));
168 }
169 
170 struct Functor final {
171   std::string operator()(int64_t a, const Type<int>& b) const;
172 };
173 #if C10_TYPENAME_SUPPORTS_CONSTEXPR
174 static_assert(
175     // NOLINTNEXTLINE(misc-redundant-expression)
176     get_fully_qualified_type_name<std::string(int64_t, const Type<int>&)>() ==
177         get_fully_qualified_type_name<
178             typename c10::guts::infer_function_traits_t<Functor>::func_type>(),
179     "");
180 #endif
TEST(TypeIndex,FunctionTypeComputationsAreResolved)181 TEST(TypeIndex, FunctionTypeComputationsAreResolved) {
182   EXPECT_EQ(
183       get_fully_qualified_type_name<std::string(int64_t, const Type<int>&)>(),
184       get_fully_qualified_type_name<
185           typename c10::guts::infer_function_traits_t<Functor>::func_type>());
186 }
187 } // namespace test_type_computations_are_resolved
188 
189 namespace test_function_arguments_and_returns {
190 class Dummy final {};
191 
192 #if C10_TYPENAME_SUPPORTS_CONSTEXPR
193 static_assert(
194     string_view::npos !=
195         get_fully_qualified_type_name<Dummy(int)>().find(
196             "test_function_arguments_and_returns::Dummy"),
197     "");
198 static_assert(
199     string_view::npos !=
200         get_fully_qualified_type_name<void(Dummy)>().find(
201             "test_function_arguments_and_returns::Dummy"),
202     "");
203 #endif
TEST(TypeIndex,FunctionArgumentsAndReturns)204 TEST(TypeIndex, FunctionArgumentsAndReturns) {
205   EXPECT_NE(
206       string_view::npos,
207       get_fully_qualified_type_name<Dummy(int)>().find(
208           "test_function_arguments_and_returns::Dummy"));
209   EXPECT_NE(
210       string_view::npos,
211       get_fully_qualified_type_name<void(Dummy)>().find(
212           "test_function_arguments_and_returns::Dummy"));
213 }
214 } // namespace test_function_arguments_and_returns
215 } // namespace
216 // NOLINTEND(modernize-unary-static-assert)
217