1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>

#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

#include <executorch/runtime/core/exec_aten/exec_aten.h>

#include <gtest/gtest.h>
16
17 using executorch::runtime::dim_order_to_stride;
18 using executorch::runtime::Error;
19 using executorch::runtime::is_channels_last_dim_order;
20 using executorch::runtime::is_contiguous_dim_order;
21 using executorch::runtime::stride_to_dim_order;
22
23 namespace {
check_strides_eq(exec_aten::ArrayRef<exec_aten::StridesType> strides_a,exec_aten::ArrayRef<exec_aten::StridesType> strides_b)24 void check_strides_eq(
25 exec_aten::ArrayRef<exec_aten::StridesType> strides_a,
26 exec_aten::ArrayRef<exec_aten::StridesType> strides_b) {
27 for (int32_t i = 0; i < strides_a.size(); ++i) {
28 EXPECT_EQ(strides_a[i], strides_b[i]);
29 }
30 }
31
check_dim_order_eq(exec_aten::ArrayRef<exec_aten::DimOrderType> dim_order_a,exec_aten::ArrayRef<exec_aten::DimOrderType> dim_order_b)32 void check_dim_order_eq(
33 exec_aten::ArrayRef<exec_aten::DimOrderType> dim_order_a,
34 exec_aten::ArrayRef<exec_aten::DimOrderType> dim_order_b) {
35 for (int32_t i = 0; i < dim_order_a.size(); ++i) {
36 EXPECT_EQ(dim_order_a[i], dim_order_b[i]);
37 }
38 }
39 } // namespace
40
// Exercises dim_order_to_stride() for ranks 1 through 5 and for tensors with
// a zero-sized dimension. Every case passes `sizes` and `dim_order`, expects
// Error::Ok, and compares the written strides with hand-computed values.
//
// In all expectations below the last dim listed in the dim order has stride
// 1, and each earlier dim in the order has the stride of the next dim
// multiplied by that next dim's size.
TEST(DimOrderUtilTest, DimOrderToStride) {
  // ---- 1-D -----------------------------------------------------------------
  // Sizes {5}, dim order {0} -> strides {1}.
  exec_aten::SizesType sizes_1[1] = {5};
  exec_aten::SizesType dim_order_1[1] = {0};
  exec_aten::SizesType strides_1[1] = {0};
  exec_aten::SizesType expected_strides_1[1] = {1};
  auto error = dim_order_to_stride(sizes_1, dim_order_1, 1, strides_1);
  EXPECT_EQ(error, Error::Ok);
  check_strides_eq({strides_1, 1}, {expected_strides_1, 1});

  // ---- 2-D -----------------------------------------------------------------
  // Sizes {2, 5}, dim order {0, 1} -> strides {5, 1}.
  exec_aten::SizesType sizes_2[2] = {2, 5};
  exec_aten::SizesType dim_order_2[2] = {0, 1};
  exec_aten::SizesType strides_2[2] = {0, 0};
  exec_aten::SizesType expected_strides_2[2] = {5, 1};
  error = dim_order_to_stride(sizes_2, dim_order_2, 2, strides_2);
  EXPECT_EQ(error, Error::Ok);
  check_strides_eq({strides_2, 2}, {expected_strides_2, 2});

  // Sizes {2, 5}, dim order {1, 0} -> strides {1, 2}.
  dim_order_2[0] = 1;
  dim_order_2[1] = 0;
  expected_strides_2[0] = 1;
  expected_strides_2[1] = 2;
  error = dim_order_to_stride(sizes_2, dim_order_2, 2, strides_2);
  EXPECT_EQ(error, Error::Ok);
  check_strides_eq({strides_2, 2}, {expected_strides_2, 2});

  // ---- 3-D -----------------------------------------------------------------
  // Sizes {2, 5, 7}, dim order {0, 1, 2} -> strides {35, 7, 1}.
  exec_aten::SizesType sizes_3[3] = {2, 5, 7};
  exec_aten::SizesType dim_order_3[3] = {0, 1, 2};
  exec_aten::SizesType strides_3[3] = {0, 0, 0};
  exec_aten::SizesType expected_strides_3[3] = {35, 7, 1};
  error = dim_order_to_stride(sizes_3, dim_order_3, 3, strides_3);
  EXPECT_EQ(error, Error::Ok);
  check_strides_eq({strides_3, 3}, {expected_strides_3, 3});

  // Sizes {2, 5, 7}, dim order {0, 2, 1} -> strides {35, 1, 5}.
  dim_order_3[0] = 0;
  dim_order_3[1] = 2;
  dim_order_3[2] = 1;
  expected_strides_3[0] = 35;
  expected_strides_3[1] = 1;
  expected_strides_3[2] = 5;
  error = dim_order_to_stride(sizes_3, dim_order_3, 3, strides_3);
  EXPECT_EQ(error, Error::Ok);
  check_strides_eq({strides_3, 3}, {expected_strides_3, 3});

  // Sizes {2, 5, 7}, dim order {1, 2, 0} -> strides {1, 14, 2}.
  dim_order_3[0] = 1;
  dim_order_3[1] = 2;
  dim_order_3[2] = 0;
  expected_strides_3[0] = 1;
  expected_strides_3[1] = 14;
  expected_strides_3[2] = 2;
  error = dim_order_to_stride(sizes_3, dim_order_3, 3, strides_3);
  EXPECT_EQ(error, Error::Ok);
  check_strides_eq({strides_3, 3}, {expected_strides_3, 3});

  // ---- 4-D -----------------------------------------------------------------
  // Sizes {2, 5, 7, 8}, dim order {0, 1, 2, 3} -> strides {280, 56, 8, 1}.
  exec_aten::SizesType sizes_4[4] = {2, 5, 7, 8};
  exec_aten::SizesType dim_order_4[4] = {0, 1, 2, 3};
  exec_aten::SizesType strides_4[4] = {0, 0, 0, 0};
  exec_aten::SizesType expected_strides_4[4] = {280, 56, 8, 1};
  error = dim_order_to_stride(sizes_4, dim_order_4, 4, strides_4);
  EXPECT_EQ(error, Error::Ok);
  check_strides_eq({strides_4, 4}, {expected_strides_4, 4});

  // Sizes {2, 5, 7, 8}, dim order {0, 2, 3, 1} -> strides {280, 1, 40, 5}.
  dim_order_4[0] = 0;
  dim_order_4[1] = 2;
  dim_order_4[2] = 3;
  dim_order_4[3] = 1;
  expected_strides_4[0] = 280;
  expected_strides_4[1] = 1;
  expected_strides_4[2] = 40;
  expected_strides_4[3] = 5;
  error = dim_order_to_stride(sizes_4, dim_order_4, 4, strides_4);
  EXPECT_EQ(error, Error::Ok);
  check_strides_eq({strides_4, 4}, {expected_strides_4, 4});

  // Sizes {2, 5, 7, 8}, dim order {3, 1, 2, 0} -> strides {1, 14, 2, 70}.
  dim_order_4[0] = 3;
  dim_order_4[1] = 1;
  dim_order_4[2] = 2;
  dim_order_4[3] = 0;
  expected_strides_4[0] = 1;
  expected_strides_4[1] = 14;
  expected_strides_4[2] = 2;
  expected_strides_4[3] = 70;
  error = dim_order_to_stride(sizes_4, dim_order_4, 4, strides_4);
  EXPECT_EQ(error, Error::Ok);
  check_strides_eq({strides_4, 4}, {expected_strides_4, 4});

  // ---- 5-D -----------------------------------------------------------------
  // Sizes {2, 5, 7, 8, 9}, dim order {0, 1, 2, 3, 4}
  //   -> strides {2520, 504, 72, 9, 1}.
  exec_aten::SizesType sizes_5[5] = {2, 5, 7, 8, 9};
  exec_aten::SizesType dim_order_5[5] = {0, 1, 2, 3, 4};
  exec_aten::SizesType strides_5[5] = {0, 0, 0, 0, 0};
  exec_aten::SizesType expected_strides_5[5] = {2520, 504, 72, 9, 1};
  error = dim_order_to_stride(sizes_5, dim_order_5, 5, strides_5);
  EXPECT_EQ(error, Error::Ok);
  check_strides_eq({strides_5, 5}, {expected_strides_5, 5});

  // Sizes {2, 5, 7, 8, 9}, dim order {0, 2, 3, 4, 1}
  //   -> strides {2520, 1, 360, 45, 5}.
  dim_order_5[0] = 0;
  dim_order_5[1] = 2;
  dim_order_5[2] = 3;
  dim_order_5[3] = 4;
  dim_order_5[4] = 1;
  expected_strides_5[0] = 2520;
  expected_strides_5[1] = 1;
  expected_strides_5[2] = 360;
  expected_strides_5[3] = 45;
  expected_strides_5[4] = 5;
  error = dim_order_to_stride(sizes_5, dim_order_5, 5, strides_5);
  EXPECT_EQ(error, Error::Ok);
  check_strides_eq({strides_5, 5}, {expected_strides_5, 5});

  // Sizes {2, 5, 7, 8, 9}, dim order {4, 2, 0, 3, 1}
  //   -> strides {40, 1, 80, 5, 560}.
  dim_order_5[0] = 4;
  dim_order_5[1] = 2;
  dim_order_5[2] = 0;
  dim_order_5[3] = 3;
  dim_order_5[4] = 1;
  expected_strides_5[0] = 40;
  expected_strides_5[1] = 1;
  expected_strides_5[2] = 80;
  expected_strides_5[3] = 5;
  expected_strides_5[4] = 560;
  error = dim_order_to_stride(sizes_5, dim_order_5, 5, strides_5);
  EXPECT_EQ(error, Error::Ok);
  check_strides_eq({strides_5, 5}, {expected_strides_5, 5});

  // ---- Zero-sized dims -----------------------------------------------------
  // The expected strides in this section are the values obtained when the
  // zero-sized dim contributes a factor of 1 (not 0) to the outer strides.

  // Sizes {2, 5, 0}, dim order {0, 1, 2} -> strides {5, 1, 1}.
  exec_aten::SizesType sizes_3_zero[3] = {2, 5, 0};
  exec_aten::SizesType dim_order_3_zero[3] = {0, 1, 2};
  exec_aten::SizesType strides_3_zero[3] = {0, 0, 0};
  exec_aten::SizesType expected_strides_3_zero[3] = {5, 1, 1};
  error =
      dim_order_to_stride(sizes_3_zero, dim_order_3_zero, 3, strides_3_zero);
  EXPECT_EQ(error, Error::Ok);
  check_strides_eq({strides_3_zero, 3}, {expected_strides_3_zero, 3});

  // Sizes {2, 5, 0}, dim order {0, 2, 1} (sizes in that order: {2, 0, 5})
  //   -> strides {5, 1, 5}.
  dim_order_3_zero[0] = 0;
  dim_order_3_zero[1] = 2;
  dim_order_3_zero[2] = 1;
  expected_strides_3_zero[0] = 5;
  expected_strides_3_zero[1] = 1;
  expected_strides_3_zero[2] = 5;
  error =
      dim_order_to_stride(sizes_3_zero, dim_order_3_zero, 3, strides_3_zero);
  EXPECT_EQ(error, Error::Ok);
  check_strides_eq({strides_3_zero, 3}, {expected_strides_3_zero, 3});

  // Sizes {2, 5, 0}, dim order {2, 0, 1} (sizes in that order: {0, 2, 5})
  //   -> strides {5, 1, 10}.
  dim_order_3_zero[0] = 2;
  dim_order_3_zero[1] = 0;
  dim_order_3_zero[2] = 1;
  expected_strides_3_zero[0] = 5;
  expected_strides_3_zero[1] = 1;
  expected_strides_3_zero[2] = 10;
  error =
      dim_order_to_stride(sizes_3_zero, dim_order_3_zero, 3, strides_3_zero);
  EXPECT_EQ(error, Error::Ok);
  check_strides_eq({strides_3_zero, 3}, {expected_strides_3_zero, 3});
}
205
// stride_to_dim_order() on strides {5, 1, 15}: the expected order lists the
// dims from largest stride to smallest, i.e. dim 2 (15), dim 0 (5), dim 1 (1).
TEST(DimOrderUtilTest, StrideToDimOrder) {
  exec_aten::SizesType input_strides[3] = {5, 1, 15};
  exec_aten::DimOrderType wanted_order[3] = {2, 0, 1};

  // Output buffer starts zeroed and is filled in by the call.
  exec_aten::DimOrderType computed_order[3] = {0, 0, 0};
  const auto status = stride_to_dim_order(input_strides, 3, computed_order);

  EXPECT_EQ(status, Error::Ok);
  check_dim_order_eq(computed_order, wanted_order);
}
217
// stride_to_dim_order() when two dims share a stride: dims 2 and 3 both have
// stride 1, and the expected result keeps them in ascending index order.
TEST(DimOrderUtilTest, StrideToDimOrderSameStrides) {
  exec_aten::SizesType input_strides[4] = {4, 3, 1, 1};
  exec_aten::DimOrderType wanted_order[4] = {0, 1, 2, 3};

  // Output buffer starts zeroed and is filled in by the call.
  exec_aten::DimOrderType computed_order[4] = {0, 0, 0, 0};
  const auto status = stride_to_dim_order(input_strides, 4, computed_order);

  EXPECT_EQ(status, Error::Ok);
  check_dim_order_eq(computed_order, wanted_order);
}
228
// The identity dim order {0, 1, ..., rank - 1} must be reported as contiguous
// for every rank from 1 through 6.
TEST(DimOrderUtilTest, IsDefaultDimOrderTest) {
  for (int rank = 1; rank <= 6; ++rank) {
    std::vector<exec_aten::DimOrderType> identity(rank);
    std::iota(identity.begin(), identity.end(), 0);

    const auto* order = identity.data();
    const auto dims = identity.size();

    EXPECT_TRUE(is_contiguous_dim_order(order, dims));

    // The same order must never be classified as channels-last.
    EXPECT_FALSE(is_channels_last_dim_order(order, dims));
  }
}
241
// Dim orders that differ from the identity order must not be reported as
// contiguous. Both variants are checked for ranks 3 through 7.
TEST(DimOrderUtilTest, IsDefaultDimOrderFailCasesTest) {
  // Variant 1: identity order with the first two entries exchanged,
  // e.g. {1, 0, 2} for rank 3.
  for (int rank = 3; rank <= 7; ++rank) {
    std::vector<exec_aten::DimOrderType> swapped(rank);
    std::iota(swapped.begin(), swapped.end(), 0);
    std::swap(swapped[0], swapped[1]);

    const auto dims = swapped.size();
    EXPECT_FALSE(is_contiguous_dim_order(swapped.data(), dims));
  }

  // Variant 2: identity order shifted by one with wrap-around,
  // e.g. {1, 2, 0} for rank 3.
  for (int rank = 3; rank <= 7; ++rank) {
    std::vector<exec_aten::DimOrderType> shifted(rank);
    // Fill with 1, 2, ..., rank, then wrap the final entry back to 0.
    std::iota(shifted.begin(), shifted.end(), 1);
    shifted.back() = 0;

    const auto dims = shifted.size();
    EXPECT_FALSE(is_contiguous_dim_order(shifted.data(), dims));
  }
}
262
// Dim orders {0, 2, 3, 1} (4-D) and {0, 2, 3, 4, 1} (5-D) must be recognised
// as channels-last and, consequently, not as contiguous.
TEST(DimOrderUtilTest, IsChannelsLastDimOrderTest) {
  // 4-D case.
  exec_aten::DimOrderType channels_last_4d[4] = {0, 2, 3, 1};
  EXPECT_TRUE(is_channels_last_dim_order(channels_last_4d, 4));

  // 5-D case.
  exec_aten::DimOrderType channels_last_5d[5] = {0, 2, 3, 4, 1};
  EXPECT_TRUE(is_channels_last_dim_order(channels_last_5d, 5));

  // Neither order is the contiguous (identity) dim order.
  EXPECT_FALSE(is_contiguous_dim_order(channels_last_4d, 4));
  EXPECT_FALSE(is_contiguous_dim_order(channels_last_5d, 5));
}
274
// Inputs that must NOT be classified as channels-last: dim orders whose rank
// is neither 4 nor 5, and 4-D / 5-D orders that are some other permutation.
TEST(DimOrderUtilTest, IsChannelsLastDimOrderFailCasesTest) {
  // Ranks other than 4 and 5 are expected to return false.
  // The 3-D array is sized to its three entries so the array extent matches
  // the `dims` argument passed below.
  exec_aten::DimOrderType dim_order_3d[3] = {1, 2, 0};
  exec_aten::DimOrderType dim_order_6d[6] = {0, 2, 3, 4, 5, 1};

  EXPECT_FALSE(is_channels_last_dim_order(dim_order_3d, 3));
  EXPECT_FALSE(is_channels_last_dim_order(dim_order_6d, 6));

  // Correct rank, but the permutation is not the channels-last one
  // ({0, 2, 3, 1} for 4-D, {0, 2, 3, 4, 1} for 5-D).
  exec_aten::DimOrderType dim_order_4d[4] = {0, 3, 2, 1};
  exec_aten::DimOrderType dim_order_5d[5] = {4, 3, 2, 0, 1};

  EXPECT_FALSE(is_channels_last_dim_order(dim_order_4d, 4));
  EXPECT_FALSE(is_channels_last_dim_order(dim_order_5d, 5));
}
289