// Source: /aosp_15_r20/external/pytorch/c10/test/util/TypeList_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/util/TypeList.h>
#include <gtest/gtest.h>

#include <cstddef>
#include <cstdint>
#include <memory>
#include <tuple>
#include <type_traits>

using namespace c10::guts::typelist;
// NOLINTBEGIN(modernize-unary-static-assert)
7 namespace test_size {
8 class MyClass {};
9 static_assert(0 == size<typelist<>>::value, "");
10 static_assert(1 == size<typelist<int>>::value, "");
11 static_assert(3 == size<typelist<int, float&, const MyClass&&>>::value, "");
12 } // namespace test_size
13 
14 namespace test_from_tuple {
15 class MyClass {};
16 static_assert(
17     std::is_same<
18         typelist<int, float&, const MyClass&&>,
19         from_tuple_t<std::tuple<int, float&, const MyClass&&>>>::value,
20     "");
21 static_assert(std::is_same<typelist<>, from_tuple_t<std::tuple<>>>::value, "");
22 } // namespace test_from_tuple
23 
24 namespace test_to_tuple {
25 class MyClass {};
26 static_assert(
27     std::is_same<
28         std::tuple<int, float&, const MyClass&&>,
29         to_tuple_t<typelist<int, float&, const MyClass&&>>>::value,
30     "");
31 static_assert(std::is_same<std::tuple<>, to_tuple_t<typelist<>>>::value, "");
32 } // namespace test_to_tuple
33 
34 namespace test_concat {
35 class MyClass {};
36 static_assert(std::is_same<typelist<>, concat_t<>>::value, "");
37 static_assert(std::is_same<typelist<>, concat_t<typelist<>>>::value, "");
38 static_assert(
39     std::is_same<typelist<>, concat_t<typelist<>, typelist<>>>::value,
40     "");
41 static_assert(std::is_same<typelist<int>, concat_t<typelist<int>>>::value, "");
42 static_assert(
43     std::is_same<typelist<int>, concat_t<typelist<int>, typelist<>>>::value,
44     "");
45 static_assert(
46     std::is_same<typelist<int>, concat_t<typelist<>, typelist<int>>>::value,
47     "");
48 static_assert(
49     std::is_same<
50         typelist<int>,
51         concat_t<typelist<>, typelist<int>, typelist<>>>::value,
52     "");
53 static_assert(
54     std::is_same<
55         typelist<int, float&>,
56         concat_t<typelist<int>, typelist<float&>>>::value,
57     "");
58 static_assert(
59     std::is_same<
60         typelist<int, float&>,
61         concat_t<typelist<>, typelist<int, float&>, typelist<>>>::value,
62     "");
63 static_assert(
64     std::is_same<
65         typelist<int, float&, const MyClass&&>,
66         concat_t<
67             typelist<>,
68             typelist<int, float&>,
69             typelist<const MyClass&&>>>::value,
70     "");
71 } // namespace test_concat
72 
73 namespace test_filter {
74 class MyClass {};
75 static_assert(
76     std::is_same<typelist<>, filter_t<std::is_reference, typelist<>>>::value,
77     "");
78 static_assert(
79     std::is_same<
80         typelist<>,
81         filter_t<std::is_reference, typelist<int, float, double, MyClass>>>::
82         value,
83     "");
84 static_assert(
85     std::is_same<
86         typelist<float&, const MyClass&&>,
87         filter_t<
88             std::is_reference,
89             typelist<int, float&, double, const MyClass&&>>>::value,
90     "");
91 } // namespace test_filter
92 
93 namespace test_count_if {
94 class MyClass final {};
95 static_assert(
96     count_if<
97         std::is_reference,
98         typelist<int, bool&, const MyClass&&, float, double>>::value == 2,
99     "");
100 static_assert(count_if<std::is_reference, typelist<int, bool>>::value == 0, "");
101 static_assert(count_if<std::is_reference, typelist<>>::value == 0, "");
102 } // namespace test_count_if
103 
104 namespace test_true_for_each_type {
105 template <class>
106 class Test;
107 class MyClass {};
108 static_assert(
109     all<std::is_reference,
110         typelist<int&, const float&&, const MyClass&>>::value,
111     "");
112 static_assert(
113     !all<std::is_reference, typelist<int&, const float, const MyClass&>>::value,
114     "");
115 static_assert(all<std::is_reference, typelist<>>::value, "");
116 } // namespace test_true_for_each_type
117 
118 namespace test_true_for_any_type {
119 template <class>
120 class Test;
121 class MyClass {};
122 static_assert(
123     true_for_any_type<
124         std::is_reference,
125         typelist<int&, const float&&, const MyClass&>>::value,
126     "");
127 static_assert(
128     true_for_any_type<
129         std::is_reference,
130         typelist<int&, const float, const MyClass&>>::value,
131     "");
132 static_assert(
133     !true_for_any_type<
134         std::is_reference,
135         typelist<int, const float, const MyClass>>::value,
136     "");
137 static_assert(!true_for_any_type<std::is_reference, typelist<>>::value, "");
138 } // namespace test_true_for_any_type
139 
140 namespace test_map {
141 class MyClass {};
142 static_assert(
143     std::is_same<typelist<>, map_t<std::add_lvalue_reference_t, typelist<>>>::
144         value,
145     "");
146 static_assert(
147     std::is_same<
148         typelist<int&>,
149         map_t<std::add_lvalue_reference_t, typelist<int>>>::value,
150     "");
151 static_assert(
152     std::is_same<
153         typelist<int&, double&, const MyClass&>,
154         map_t<
155             std::add_lvalue_reference_t,
156             typelist<int, double, const MyClass>>>::value,
157     "");
158 } // namespace test_map
159 
160 namespace test_head {
161 class MyClass {};
162 static_assert(std::is_same<int, head_t<typelist<int, double>>>::value, "");
163 static_assert(
164     std::is_same<const MyClass&, head_t<typelist<const MyClass&, double>>>::
165         value,
166     "");
167 static_assert(
168     std::is_same<MyClass&&, head_t<typelist<MyClass&&, MyClass>>>::value,
169     "");
170 static_assert(std::is_same<bool, head_t<typelist<bool>>>::value, "");
171 } // namespace test_head
172 
173 namespace test_head_with_default {
174 class MyClass {};
175 static_assert(
176     std::is_same<int, head_with_default_t<bool, typelist<int, double>>>::value,
177     "");
178 static_assert(
179     std::is_same<
180         const MyClass&,
181         head_with_default_t<bool, typelist<const MyClass&, double>>>::value,
182     "");
183 static_assert(
184     std::is_same<
185         MyClass&&,
186         head_with_default_t<bool, typelist<MyClass&&, MyClass>>>::value,
187     "");
188 static_assert(
189     std::is_same<int, head_with_default_t<bool, typelist<int>>>::value,
190     "");
191 static_assert(
192     std::is_same<bool, head_with_default_t<bool, typelist<>>>::value,
193     "");
194 } // namespace test_head_with_default
195 
196 namespace test_reverse {
197 class MyClass {};
198 static_assert(
199     std::is_same<
200         typelist<int, double, MyClass*, const MyClass&&>,
201         reverse_t<typelist<const MyClass&&, MyClass*, double, int>>>::value,
202     "");
203 static_assert(std::is_same<typelist<>, reverse_t<typelist<>>>::value, "");
204 } // namespace test_reverse
205 
206 namespace test_map_types_to_values {
207 struct map_to_size {
208   template <class T>
operator ()test_map_types_to_values::map_to_size209   constexpr size_t operator()(T) const {
210     return sizeof(typename T::type);
211   }
212 };
213 
TEST(TypeListTest,MapTypesToValues_sametype)214 TEST(TypeListTest, MapTypesToValues_sametype) {
215   auto sizes =
216       map_types_to_values<typelist<int64_t, bool, uint32_t>>(map_to_size());
217   std::tuple<size_t, size_t, size_t> expected(8, 1, 4);
218   static_assert(std::is_same<decltype(expected), decltype(sizes)>::value, "");
219   EXPECT_EQ(expected, sizes);
220 }
221 
222 struct map_make_shared {
223   template <class T>
operator ()test_map_types_to_values::map_make_shared224   std::shared_ptr<typename T::type> operator()(T) {
225     return std::make_shared<typename T::type>();
226   }
227 };
228 
TEST(TypeListTest,MapTypesToValues_differenttypes)229 TEST(TypeListTest, MapTypesToValues_differenttypes) {
230   auto shared_ptrs =
231       map_types_to_values<typelist<int, double>>(map_make_shared());
232   static_assert(
233       std::is_same<
234           std::tuple<std::shared_ptr<int>, std::shared_ptr<double>>,
235           decltype(shared_ptrs)>::value,
236       "");
237 }
238 
239 struct Class1 {
functest_map_types_to_values::Class1240   static int func() {
241     return 3;
242   }
243 };
244 struct Class2 {
functest_map_types_to_values::Class2245   static double func() {
246     return 2.0;
247   }
248 };
249 
250 struct mapper_call_func {
251   template <class T>
operator ()test_map_types_to_values::mapper_call_func252   decltype(auto) operator()(T) {
253     return T::type::func();
254   }
255 };
256 
TEST(TypeListTest,MapTypesToValues_members)257 TEST(TypeListTest, MapTypesToValues_members) {
258   auto result =
259       map_types_to_values<typelist<Class1, Class2>>(mapper_call_func());
260   std::tuple<int, double> expected(3, 2.0);
261   static_assert(std::is_same<decltype(expected), decltype(result)>::value, "");
262   EXPECT_EQ(expected, result);
263 }
264 
265 struct mapper_call_nonexistent_function {
266   template <class T>
operator ()test_map_types_to_values::mapper_call_nonexistent_function267   decltype(auto) operator()(T) {
268     return T::type::this_doesnt_exist();
269   }
270 };
271 
TEST(TypeListTest,MapTypesToValues_empty)272 TEST(TypeListTest, MapTypesToValues_empty) {
273   auto result =
274       map_types_to_values<typelist<>>(mapper_call_nonexistent_function());
275   std::tuple<> expected;
276   static_assert(std::is_same<decltype(expected), decltype(result)>::value, "");
277   EXPECT_EQ(expected, result);
278 }
279 } // namespace test_map_types_to_values
280 
281 namespace test_find_if {
282 static_assert(0 == find_if<typelist<char&>, std::is_reference>::value, "");
283 static_assert(
284     0 == find_if<typelist<char&, int, char&, int&>, std::is_reference>::value,
285     "");
286 static_assert(
287     2 == find_if<typelist<char, int, char&, int&>, std::is_reference>::value,
288     "");
289 static_assert(
290     3 == find_if<typelist<char, int, char, int&>, std::is_reference>::value,
291     "");
292 } // namespace test_find_if
293 
294 namespace test_contains {
295 static_assert(contains<typelist<double>, double>::value, "");
296 static_assert(contains<typelist<int, double>, double>::value, "");
297 static_assert(!contains<typelist<int, double>, float>::value, "");
298 static_assert(!contains<typelist<>, double>::value, "");
299 } // namespace test_contains
300 
301 namespace test_take {
302 static_assert(std::is_same<typelist<>, take_t<typelist<>, 0>>::value, "");
303 static_assert(
304     std::is_same<typelist<>, take_t<typelist<int64_t>, 0>>::value,
305     "");
306 static_assert(
307     std::is_same<typelist<int64_t>, take_t<typelist<int64_t>, 1>>::value,
308     "");
309 static_assert(
310     std::is_same<typelist<>, take_t<typelist<int64_t, int32_t>, 0>>::value,
311     "");
312 static_assert(
313     std::is_same<typelist<int64_t>, take_t<typelist<int64_t, int32_t>, 1>>::
314         value,
315     "");
316 static_assert(
317     std::is_same<
318         typelist<int64_t, int32_t>,
319         take_t<typelist<int64_t, int32_t>, 2>>::value,
320     "");
321 } // namespace test_take
322 
323 namespace test_drop {
324 static_assert(std::is_same<typelist<>, drop_t<typelist<>, 0>>::value, "");
325 static_assert(
326     std::is_same<typelist<int64_t>, drop_t<typelist<int64_t>, 0>>::value,
327     "");
328 static_assert(
329     std::is_same<typelist<>, drop_t<typelist<int64_t>, 1>>::value,
330     "");
331 static_assert(
332     std::is_same<
333         typelist<int64_t, int32_t>,
334         drop_t<typelist<int64_t, int32_t>, 0>>::value,
335     "");
336 static_assert(
337     std::is_same<typelist<int32_t>, drop_t<typelist<int64_t, int32_t>, 1>>::
338         value,
339     "");
340 static_assert(
341     std::is_same<typelist<>, drop_t<typelist<int64_t, int32_t>, 2>>::value,
342     "");
343 } // namespace test_drop
344 
345 namespace test_drop_if_nonempty {
346 static_assert(
347     std::is_same<typelist<>, drop_if_nonempty_t<typelist<>, 0>>::value,
348     "");
349 static_assert(
350     std::is_same<typelist<int64_t>, drop_if_nonempty_t<typelist<int64_t>, 0>>::
351         value,
352     "");
353 static_assert(
354     std::is_same<typelist<>, drop_if_nonempty_t<typelist<int64_t>, 1>>::value,
355     "");
356 static_assert(
357     std::is_same<
358         typelist<int64_t, int32_t>,
359         drop_if_nonempty_t<typelist<int64_t, int32_t>, 0>>::value,
360     "");
361 static_assert(
362     std::is_same<
363         typelist<int32_t>,
364         drop_if_nonempty_t<typelist<int64_t, int32_t>, 1>>::value,
365     "");
366 static_assert(
367     std::is_same<
368         typelist<>,
369         drop_if_nonempty_t<typelist<int64_t, int32_t>, 2>>::value,
370     "");
371 static_assert(
372     std::is_same<typelist<>, drop_if_nonempty_t<typelist<>, 1>>::value,
373     "");
374 static_assert(
375     std::is_same<
376         typelist<>,
377         drop_if_nonempty_t<typelist<int64_t, int32_t>, 3>>::value,
378     "");
379 } // namespace test_drop_if_nonempty
// NOLINTEND(modernize-unary-static-assert)