xref: /aosp_15_r20/external/executorch/runtime/core/exec_aten/util/test/dim_order_util_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>

#include <numeric>
#include <utility>
#include <vector>

#include <executorch/runtime/core/exec_aten/exec_aten.h>

#include <gtest/gtest.h>

using executorch::runtime::dim_order_to_stride;
using executorch::runtime::Error;
using executorch::runtime::is_channels_last_dim_order;
using executorch::runtime::is_contiguous_dim_order;
using executorch::runtime::stride_to_dim_order;

23 namespace {
check_strides_eq(exec_aten::ArrayRef<exec_aten::StridesType> strides_a,exec_aten::ArrayRef<exec_aten::StridesType> strides_b)24 void check_strides_eq(
25     exec_aten::ArrayRef<exec_aten::StridesType> strides_a,
26     exec_aten::ArrayRef<exec_aten::StridesType> strides_b) {
27   for (int32_t i = 0; i < strides_a.size(); ++i) {
28     EXPECT_EQ(strides_a[i], strides_b[i]);
29   }
30 }
31 
check_dim_order_eq(exec_aten::ArrayRef<exec_aten::DimOrderType> dim_order_a,exec_aten::ArrayRef<exec_aten::DimOrderType> dim_order_b)32 void check_dim_order_eq(
33     exec_aten::ArrayRef<exec_aten::DimOrderType> dim_order_a,
34     exec_aten::ArrayRef<exec_aten::DimOrderType> dim_order_b) {
35   for (int32_t i = 0; i < dim_order_a.size(); ++i) {
36     EXPECT_EQ(dim_order_a[i], dim_order_b[i]);
37   }
38 }
39 } // namespace
40 
TEST(DimOrderUtilTest,DimOrderToStride)41 TEST(DimOrderUtilTest, DimOrderToStride) {
42   exec_aten::SizesType sizes_1[1] = {5};
43   exec_aten::SizesType dim_order_1[1] = {0};
44   exec_aten::SizesType strides_1[1] = {0};
45   exec_aten::SizesType expected_strides_1[1] = {1};
46   auto error = dim_order_to_stride(sizes_1, dim_order_1, 1, strides_1);
47   EXPECT_EQ(error, Error::Ok);
48   check_strides_eq({strides_1, 1}, {expected_strides_1, 1});
49 
50   exec_aten::SizesType sizes_2[2] = {2, 5};
51   exec_aten::SizesType dim_order_2[2] = {0, 1};
52   exec_aten::SizesType strides_2[2] = {0, 0};
53   exec_aten::SizesType expected_strides_2[2] = {5, 1};
54   error = dim_order_to_stride(sizes_2, dim_order_2, 2, strides_2);
55   EXPECT_EQ(error, Error::Ok);
56   check_strides_eq({strides_2, 2}, {expected_strides_2, 2});
57 
58   dim_order_2[0] = 1;
59   dim_order_2[1] = 0;
60   expected_strides_2[0] = 1;
61   expected_strides_2[1] = 2;
62   error = dim_order_to_stride(sizes_2, dim_order_2, 2, strides_2);
63   EXPECT_EQ(error, Error::Ok);
64   check_strides_eq({strides_2, 2}, {expected_strides_2, 2});
65 
66   exec_aten::SizesType sizes_3[3] = {2, 5, 7};
67   exec_aten::SizesType dim_order_3[3] = {0, 1, 2};
68   exec_aten::SizesType strides_3[3] = {0, 0, 0};
69   exec_aten::SizesType expected_strides_3[3] = {35, 7, 1};
70   error = dim_order_to_stride(sizes_3, dim_order_3, 3, strides_3);
71   EXPECT_EQ(error, Error::Ok);
72   check_strides_eq({strides_3, 3}, {expected_strides_3, 3});
73 
74   // {0, 2, 1}
75   dim_order_3[0] = 0, dim_order_3[1] = 2, dim_order_3[2] = 1;
76   // Expected stride {35, 1, 5}
77   expected_strides_3[0] = 35, expected_strides_3[1] = 1,
78   expected_strides_3[2] = 5;
79   error = dim_order_to_stride(sizes_3, dim_order_3, 3, strides_3);
80   EXPECT_EQ(error, Error::Ok);
81   check_strides_eq({strides_3, 3}, {expected_strides_3, 3});
82 
83   // {2, 5, 7}
84   // {1, 2, 0}
85   dim_order_3[0] = 1, dim_order_3[1] = 2, dim_order_3[2] = 0;
86   // Expected stride {35, 1, 5}
87   expected_strides_3[0] = 1, expected_strides_3[1] = 14,
88   expected_strides_3[2] = 2;
89   error = dim_order_to_stride(sizes_3, dim_order_3, 3, strides_3);
90   EXPECT_EQ(error, Error::Ok);
91   check_strides_eq({strides_3, 3}, {expected_strides_3, 3});
92 
93   exec_aten::SizesType sizes_4[4] = {2, 5, 7, 8};
94   exec_aten::SizesType dim_order_4[4] = {0, 1, 2, 3};
95   exec_aten::SizesType strides_4[4] = {0, 0, 0, 0};
96   exec_aten::SizesType expected_strides_4[4] = {280, 56, 8, 1};
97   error = dim_order_to_stride(sizes_4, dim_order_4, 4, strides_4);
98   EXPECT_EQ(error, Error::Ok);
99   check_strides_eq({strides_4, 4}, {expected_strides_4, 4});
100 
101   // {2, 5, 7, 8}
102   // {0, 2, 3, 1}
103   dim_order_4[0] = 0;
104   dim_order_4[1] = 2;
105   dim_order_4[2] = 3;
106   dim_order_4[3] = 1;
107   // Expected stride {280, 1, 40, 5}
108   expected_strides_4[0] = 280;
109   expected_strides_4[1] = 1;
110   expected_strides_4[2] = 40;
111   expected_strides_4[3] = 5;
112   error = dim_order_to_stride(sizes_4, dim_order_4, 4, strides_4);
113   EXPECT_EQ(error, Error::Ok);
114   check_strides_eq({strides_4, 4}, {expected_strides_4, 4});
115 
116   // {2, 5, 7, 8}
117   // {3, 1, 2, 0}
118   dim_order_4[0] = 3;
119   dim_order_4[1] = 1;
120   dim_order_4[2] = 2;
121   dim_order_4[3] = 0;
122   // Expected stride {1, 14, 2, 70}
123   expected_strides_4[0] = 1;
124   expected_strides_4[1] = 14;
125   expected_strides_4[2] = 2;
126   expected_strides_4[3] = 70;
127   error = dim_order_to_stride(sizes_4, dim_order_4, 4, strides_4);
128   EXPECT_EQ(error, Error::Ok);
129   check_strides_eq({strides_4, 4}, {expected_strides_4, 4});
130 
131   exec_aten::SizesType sizes_5[5] = {2, 5, 7, 8, 9};
132   exec_aten::SizesType dim_order_5[5] = {0, 1, 2, 3, 4};
133   exec_aten::SizesType strides_5[5] = {0, 0, 0, 0, 0};
134   exec_aten::SizesType expected_strides_5[5] = {2520, 504, 72, 9, 1};
135   error = dim_order_to_stride(sizes_5, dim_order_5, 5, strides_5);
136   EXPECT_EQ(error, Error::Ok);
137   check_strides_eq({strides_5, 5}, {expected_strides_5, 5});
138 
139   // {2, 5, 7, 8, 9}
140   // {0, 2, 3, 4, 1}
141   dim_order_5[0] = 0;
142   dim_order_5[1] = 2;
143   dim_order_5[2] = 3;
144   dim_order_5[3] = 4;
145   dim_order_5[4] = 1;
146   // Expected stride {2520, 1, 360, 45, 5}
147   expected_strides_5[0] = 2520;
148   expected_strides_5[1] = 1;
149   expected_strides_5[2] = 360;
150   expected_strides_5[3] = 45;
151   expected_strides_5[4] = 5;
152   error = dim_order_to_stride(sizes_5, dim_order_5, 5, strides_5);
153   EXPECT_EQ(error, Error::Ok);
154   check_strides_eq({strides_5, 5}, {expected_strides_5, 5});
155 
156   // {2, 5, 7, 8, 9}
157   // {4, 2, 0, 3, 1}
158   dim_order_5[0] = 4;
159   dim_order_5[1] = 2;
160   dim_order_5[2] = 0;
161   dim_order_5[3] = 3;
162   dim_order_5[4] = 1;
163   // Expected stride {40, 1, 80, 5, 560}
164   expected_strides_5[0] = 40;
165   expected_strides_5[1] = 1;
166   expected_strides_5[2] = 80;
167   expected_strides_5[3] = 5;
168   expected_strides_5[4] = 560;
169   error = dim_order_to_stride(sizes_5, dim_order_5, 5, strides_5);
170   EXPECT_EQ(error, Error::Ok);
171   check_strides_eq({strides_5, 5}, {expected_strides_5, 5});
172 
173   // Check 0 sized dims
174   exec_aten::SizesType sizes_3_zero[3] = {2, 5, 0};
175   exec_aten::SizesType dim_order_3_zero[3] = {0, 1, 2};
176   exec_aten::SizesType strides_3_zero[3] = {0, 0, 0};
177   exec_aten::SizesType expected_strides_3_zero[3] = {5, 1, 1};
178   error =
179       dim_order_to_stride(sizes_3_zero, dim_order_3_zero, 3, strides_3_zero);
180   EXPECT_EQ(error, Error::Ok);
181   check_strides_eq({strides_3_zero, 3}, {expected_strides_3_zero, 3});
182 
183   // {0, 2, 1}
184   // {2, 0, 5}
185   dim_order_3_zero[0] = 0, dim_order_3_zero[1] = 2, dim_order_3_zero[2] = 1;
186   // Expected stride {5, 5, 1}
187   expected_strides_3_zero[0] = 5, expected_strides_3_zero[1] = 1,
188   expected_strides_3_zero[2] = 5;
189   error =
190       dim_order_to_stride(sizes_3_zero, dim_order_3_zero, 3, strides_3_zero);
191   EXPECT_EQ(error, Error::Ok);
192   check_strides_eq({strides_3_zero, 3}, {expected_strides_3_zero, 3});
193 
194   // {2, 0, 1}
195   // {0, 2, 5}
196   dim_order_3_zero[0] = 2, dim_order_3_zero[1] = 0, dim_order_3_zero[2] = 1;
197   // Expected stride {10, 5, 1}
198   expected_strides_3_zero[0] = 5, expected_strides_3_zero[1] = 1,
199   expected_strides_3_zero[2] = 10;
200   error =
201       dim_order_to_stride(sizes_3_zero, dim_order_3_zero, 3, strides_3_zero);
202   EXPECT_EQ(error, Error::Ok);
203   check_strides_eq({strides_3_zero, 3}, {expected_strides_3_zero, 3});
204 }
205 
TEST(DimOrderUtilTest, StrideToDimOrder) {
  // Declared with the strides alias (not the sizes alias) so the test does not
  // rely on the two aliases happening to be the same type.
  exec_aten::StridesType strides[3] = {5, 1, 15};
  exec_aten::DimOrderType dim_order[3] = {0, 0, 0};

  const Error error = stride_to_dim_order(strides, 3, dim_order);
  EXPECT_EQ(error, Error::Ok);

  // Dims listed from largest to smallest stride: dim 2 (15), dim 0 (5),
  // dim 1 (1).
  exec_aten::DimOrderType expected_dim_order[3] = {2, 0, 1};
  check_dim_order_eq(dim_order, expected_dim_order);
}

TEST(DimOrderUtilTest, StrideToDimOrderSameStrides) {
  // Declared with the strides alias (not the sizes alias) so the test does not
  // rely on the two aliases happening to be the same type.
  exec_aten::StridesType strides[4] = {4, 3, 1, 1};
  exec_aten::DimOrderType dim_order[4] = {0, 0, 0, 0};

  const Error error = stride_to_dim_order(strides, 4, dim_order);
  EXPECT_EQ(error, Error::Ok);

  // Dims 2 and 3 share stride 1; the expected order keeps the lower dim index
  // (2) ahead of the higher one (3).
  exec_aten::DimOrderType expected_dim_order[4] = {0, 1, 2, 3};
  check_dim_order_eq(dim_order, expected_dim_order);
}

TEST(DimOrderUtilTest, IsDefaultDimOrderTest) {
  // For each rank in [1, 6], the identity permutation {0, 1, ..., rank - 1}
  // must be reported as contiguous.
  for (int rank = 1; rank <= 6; ++rank) {
    std::vector<exec_aten::DimOrderType> identity;
    identity.reserve(rank);
    for (int dim = 0; dim < rank; ++dim) {
      identity.push_back(static_cast<exec_aten::DimOrderType>(dim));
    }

    const size_t ndims = identity.size();
    EXPECT_TRUE(is_contiguous_dim_order(identity.data(), ndims));

    // The identity order must not also be classified as channels-last.
    EXPECT_FALSE(is_channels_last_dim_order(identity.data(), ndims));
  }
}

TEST(DimOrderUtilTest, IsDefaultDimOrderFailCasesTest) {
  for (int rank = 3; rank <= 7; ++rank) {
    // Identity order with its first two entries exchanged:
    // {1, 0, 2, ..., rank - 1}.
    std::vector<exec_aten::DimOrderType> swapped(rank);
    std::iota(swapped.begin(), swapped.end(), 0);
    std::swap(swapped[0], swapped[1]);
    EXPECT_FALSE(is_contiguous_dim_order(swapped.data(), swapped.size()));

    // Identity order rotated left by one position: {1, 2, ..., rank - 1, 0}.
    std::vector<exec_aten::DimOrderType> rotated(rank);
    for (int pos = 0; pos < rank; ++pos) {
      rotated[pos] = static_cast<exec_aten::DimOrderType>((pos + 1) % rank);
    }
    EXPECT_FALSE(is_contiguous_dim_order(rotated.data(), rotated.size()));
  }
}

TEST(DimOrderUtilTest, IsChannelsLastDimOrderTest) {
  // Orders with dim 0 outermost, dim 1 innermost, and the remaining dims in
  // ascending order between them.
  exec_aten::DimOrderType channels_last_4d[4] = {0, 2, 3, 1};
  exec_aten::DimOrderType channels_last_5d[5] = {0, 2, 3, 4, 1};

  // 4-D: recognized as channels-last and, consequently, not contiguous.
  EXPECT_TRUE(is_channels_last_dim_order(channels_last_4d, 4));
  EXPECT_FALSE(is_contiguous_dim_order(channels_last_4d, 4));

  // 5-D: same expectations as the 4-D case.
  EXPECT_TRUE(is_channels_last_dim_order(channels_last_5d, 5));
  EXPECT_FALSE(is_contiguous_dim_order(channels_last_5d, 5));
}

TEST(DimOrderUtilTest, IsChannelsLastDimOrderFailCasesTest) {
  // Dim orders whose rank is neither 4 nor 5 are rejected, including the
  // 6-D order built with the same "dim 1 innermost" pattern.
  // Array lengths match the rank passed to the function.
  exec_aten::DimOrderType dim_order_3d[3] = {1, 2, 0};
  exec_aten::DimOrderType dim_order_6d[6] = {0, 2, 3, 4, 5, 1};

  EXPECT_FALSE(is_channels_last_dim_order(dim_order_3d, 3));
  EXPECT_FALSE(is_channels_last_dim_order(dim_order_6d, 6));

  // 4-D and 5-D permutations that differ from the channels-last order.
  exec_aten::DimOrderType dim_order_4d[4] = {0, 3, 2, 1};
  exec_aten::DimOrderType dim_order_5d[5] = {4, 3, 2, 0, 1};

  EXPECT_FALSE(is_channels_last_dim_order(dim_order_4d, 4));
  EXPECT_FALSE(is_channels_last_dim_order(dim_order_5d, 5));
}
