1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/runtime/core/exec_aten/exec_aten.h>
10 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
11 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
12 #include <executorch/runtime/platform/runtime.h>
13 #include <executorch/test/utils/DeathTest.h>
14 #include <cmath>
15 #include <limits>
16
17 using namespace ::testing;
18 using exec_aten::ScalarType;
19 using exec_aten::Tensor;
20 using executorch::runtime::extract_scalar_tensor;
21 using executorch::runtime::testing::TensorFactory;
22
23 class TensorUtilTest : public ::testing::Test {
24 protected:
25 // Factories for tests to use. These will be torn down and recreated for each
26 // test case.
27 TensorFactory<ScalarType::Byte> tf_byte_;
28 TensorFactory<ScalarType::Int> tf_int_;
29 TensorFactory<ScalarType::Float> tf_float_;
30 TensorFactory<ScalarType::Double> tf_double_;
31 TensorFactory<ScalarType::Bool> tf_bool_;
32
SetUp()33 void SetUp() override {
34 // As some of these tests cause ET_LOG to be called, the PAL must be
35 // initialized first by calling runtime_init();
36 executorch::runtime::runtime_init();
37 }
38 };
39
TEST_F(TensorUtilTest, IdentityChecks) {
  // Comparing a tensor with itself must satisfy every "same ..." check.
  Tensor tensor = tf_byte_.ones({2, 2});

  // Two-argument forms: shape, dtype, and both combined.
  ET_CHECK_SAME_SHAPE2(tensor, tensor);
  ET_CHECK_SAME_DTYPE2(tensor, tensor);
  ET_CHECK_SAME_SHAPE_AND_DTYPE2(tensor, tensor);

  // Three-argument forms of the same checks.
  ET_CHECK_SAME_SHAPE3(tensor, tensor, tensor);
  ET_CHECK_SAME_DTYPE3(tensor, tensor, tensor);
  ET_CHECK_SAME_SHAPE_AND_DTYPE3(tensor, tensor, tensor);
}
55
TEST_F(TensorUtilTest, SameShapesDifferentDtypes) {
  // Identically shaped {2, 2} tensors, each built by a factory of a different
  // dtype.
  Tensor bytes = tf_byte_.ones({2, 2});
  Tensor ints = tf_int_.ones({2, 2});
  Tensor floats = tf_float_.ones({2, 2});

  // Shape-only checks pass even though the dtypes differ.
  ET_CHECK_SAME_SHAPE2(bytes, ints);
  ET_CHECK_SAME_SHAPE3(bytes, ints, floats);

  // Dtype checks must abort. For the two-argument form try both argument
  // orders; for the three-argument form move the odd tensor through every
  // position while the remaining two agree.
  ET_EXPECT_DEATH(ET_CHECK_SAME_DTYPE2(bytes, ints), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_DTYPE2(ints, bytes), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_DTYPE3(bytes, ints, ints), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_DTYPE3(ints, bytes, ints), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_DTYPE3(ints, ints, bytes), "");

  // The combined shape-and-dtype checks must abort in the same situations.
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE_AND_DTYPE2(bytes, ints), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE_AND_DTYPE2(ints, bytes), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE_AND_DTYPE3(bytes, ints, ints), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE_AND_DTYPE3(ints, bytes, ints), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE_AND_DTYPE3(ints, ints, bytes), "");
}
80
TEST_F(TensorUtilTest, DifferentShapesSameDtypes) {
  // `row` and `square` agree on dtype, rank and element count, but their
  // sizes differ. `square_twin` matches `square` in both shape and dtype.
  Tensor row = tf_int_.ones({1, 4});
  Tensor square = tf_int_.ones({2, 2});
  Tensor square_twin = tf_int_.ones({2, 2});

  // The dtypes agree, so every dtype-only check passes.
  ET_CHECK_SAME_DTYPE2(row, square);
  ET_CHECK_SAME_DTYPE2(square, row);
  ET_CHECK_SAME_DTYPE3(row, square, square_twin);
  ET_CHECK_SAME_DTYPE3(square, row, square_twin);
  ET_CHECK_SAME_DTYPE3(square, square_twin, row);

  // Shape checks must abort: both argument orders for the pairwise form, and
  // the odd tensor in each position for the three-way form.
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE2(row, square), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE2(square, row), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE3(row, square, square_twin), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE3(square, row, square_twin), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE3(square, square_twin, row), "");

  // Since the shapes differ, the combined checks must abort as well.
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE_AND_DTYPE2(row, square), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE_AND_DTYPE2(square, row), "");
  ET_EXPECT_DEATH(
      ET_CHECK_SAME_SHAPE_AND_DTYPE3(row, square, square_twin), "");
  ET_EXPECT_DEATH(
      ET_CHECK_SAME_SHAPE_AND_DTYPE3(square, row, square_twin), "");
  ET_EXPECT_DEATH(
      ET_CHECK_SAME_SHAPE_AND_DTYPE3(square, square_twin, row), "");
}
112
TEST_F(TensorUtilTest, ZeroDimensionalTensor) {
  // An empty size list produces a tensor with no dimensions at all.
  Tensor zero_dim = tf_int_.ones({});
  EXPECT_EQ(zero_dim.dim(), 0);

  // None of the check macros may abort on a zero-dimensional tensor.
  ET_CHECK_SAME_SHAPE2(zero_dim, zero_dim);
  ET_CHECK_SAME_DTYPE2(zero_dim, zero_dim);
  ET_CHECK_SAME_SHAPE_AND_DTYPE2(zero_dim, zero_dim);
  ET_CHECK_SAME_SHAPE3(zero_dim, zero_dim, zero_dim);
  ET_CHECK_SAME_DTYPE3(zero_dim, zero_dim, zero_dim);
  ET_CHECK_SAME_SHAPE_AND_DTYPE3(zero_dim, zero_dim, zero_dim);
}
128
TEST_F(TensorUtilTest, EmptyTensor) {
  // A zero-width dimension yields a tensor that holds no elements.
  Tensor empty = tf_int_.ones({0});
  EXPECT_EQ(empty.nbytes(), 0);
  EXPECT_EQ(empty.numel(), 0);

  // None of the check macros may abort on an element-less tensor.
  ET_CHECK_SAME_SHAPE2(empty, empty);
  ET_CHECK_SAME_DTYPE2(empty, empty);
  ET_CHECK_SAME_SHAPE_AND_DTYPE2(empty, empty);
  ET_CHECK_SAME_SHAPE3(empty, empty, empty);
  ET_CHECK_SAME_DTYPE3(empty, empty, empty);
  ET_CHECK_SAME_SHAPE_AND_DTYPE3(empty, empty, empty);
}
145
TEST_F(TensorUtilTest, GetLeadingDimsSmokeTest) {
  using executorch::runtime::getLeadingDims;

  // Sizes are {2, 3, 4}; getLeadingDims(t, n) is expected to be the product
  // of the first n sizes.
  Tensor tensor = tf_int_.ones({2, 3, 4});

  EXPECT_EQ(getLeadingDims(tensor, 1), 2); // size(0)
  EXPECT_EQ(getLeadingDims(tensor, 2), 6); // size(0) * size(1)
  EXPECT_EQ(getLeadingDims(tensor, 3), 24); // size(0) * size(1) * size(2)
}
159
TEST_F(TensorUtilTest, GetLeadingDimsInputOutOfBoundDies) {
  using executorch::runtime::getLeadingDims;

  // A rank-3 tensor: the accepted `dim` range is [0, t.dim()], i.e. [0, 3].
  Tensor tensor = tf_int_.ones({2, 3, 4});

  // Negative values fall below the range ...
  ET_EXPECT_DEATH(getLeadingDims(tensor, -2), "");
  ET_EXPECT_DEATH(getLeadingDims(tensor, -1), "");
  // ... and 4 lies just past its upper end.
  ET_EXPECT_DEATH(getLeadingDims(tensor, 4), "");
}
169
TEST_F(TensorUtilTest, GetTrailingDimsSmokeTest) {
  using executorch::runtime::getTrailingDims;

  // Sizes are {2, 3, 4}; getTrailingDims(t, d) is expected to be the product
  // of the sizes that come after dimension d.
  Tensor tensor = tf_int_.ones({2, 3, 4});

  EXPECT_EQ(getTrailingDims(tensor, 1), 4); // size(2)
  EXPECT_EQ(getTrailingDims(tensor, 0), 12); // size(1) * size(2)
  EXPECT_EQ(getTrailingDims(tensor, -1), 24); // size(0) * size(1) * size(2)
}
183
TEST_F(TensorUtilTest, GetTrailingDimsInputOutOfBoundDies) {
  using executorch::runtime::getTrailingDims;

  // A rank-3 tensor. The valid `dim` range is documented here as
  // [-1, t.dim() - 1).
  // NOTE(review): dim == 2 is exercised neither here nor in the smoke test,
  // so the exact upper bound is not pinned down by these tests -- confirm it
  // against the implementation.
  Tensor tensor = tf_int_.ones({2, 3, 4});

  // Below the lower bound.
  ET_EXPECT_DEATH(getTrailingDims(tensor, -2), "");
  // At and beyond the tensor's rank.
  ET_EXPECT_DEATH(getTrailingDims(tensor, 3), "");
  ET_EXPECT_DEATH(getTrailingDims(tensor, 4), "");
}
193
TEST_F(TensorUtilTest, ContiguousCheckSupported) {
  const std::vector<int32_t> shape = {1, 2, 3};
  const std::vector<float> values = {1, 2, 3, 4, 5, 6};

  // With no explicit strides the factory's default layout is used, which
  // ET_CHECK_CONTIGUOUS must accept.
  Tensor t_contiguous = tf_float_.make(shape, values);
  ET_CHECK_CONTIGUOUS(t_contiguous);

  // Same sizes and data, but with hand-picked strides {2, 1, 2}, standing in
  // for a permuted view of the tensor. ET_CHECK_CONTIGUOUS must abort on it.
  Tensor t_incontiguous = tf_float_.make(shape, values, /*strides=*/{2, 1, 2});
  ET_EXPECT_DEATH(ET_CHECK_CONTIGUOUS(t_incontiguous), "");
}
212
TEST_F(TensorUtilTest, CheckSameContiguousStrideSupported) {
  // Every tensor in this list is expected to share the same strides: the
  // first three rely on the factory default, the last one spells out
  // {24, 12, 4, 1} explicitly.
  std::vector<Tensor> same_stride_tensor_list = {
      tf_float_.ones(/*sizes=*/{1, 2, 3, 4}),
      tf_byte_.ones(/*sizes=*/{4, 2, 3, 4}),
      tf_int_.ones(/*sizes=*/{10, 2, 3, 4}),
      tf_float_.make(
          /*sizes=*/{0, 2, 3, 4}, /*data=*/{}, /*strides=*/{24, 12, 4, 1})};

  // A tensor of sizes (0, 2, 3, 4) with the strides of a permuted layout,
  // {24, 1, 8, 2}, which must not match any tensor in the list above.
  Tensor different_stride = tf_float_.make(
      /*sizes=*/{0, 2, 3, 4}, /*data=*/{}, /*strides=*/{24, 1, 8, 2});

  // All index variables are size_t so that they compare against size()
  // without a signed/unsigned mismatch.
  const size_t num_tensors = same_stride_tensor_list.size();

  // Any pair from the list (a tensor may be paired with itself) passes the
  // two-way check.
  for (size_t i = 0; i < num_tensors; ++i) {
    for (size_t j = i; j < num_tensors; ++j) {
      const Tensor& ti = same_stride_tensor_list[i];
      const Tensor& tj = same_stride_tensor_list[j];
      ET_CHECK_SAME_STRIDES2(ti, tj);
    }
  }

  // Pairing any list tensor with `different_stride` must abort.
  for (size_t i = 0; i < num_tensors; ++i) {
    const Tensor& ti = same_stride_tensor_list[i];
    ET_EXPECT_DEATH(ET_CHECK_SAME_STRIDES2(ti, different_stride), "");
  }

  // Any triple from the list (duplicates allowed) passes the three-way check.
  for (size_t i = 0; i < num_tensors; ++i) {
    for (size_t j = i; j < num_tensors; ++j) {
      for (size_t k = j; k < num_tensors; ++k) {
        const Tensor& ti = same_stride_tensor_list[i];
        const Tensor& tj = same_stride_tensor_list[j];
        const Tensor& tk = same_stride_tensor_list[k];
        ET_CHECK_SAME_STRIDES3(ti, tj, tk);
      }
    }
  }

  // Any pair from the list (duplicates allowed) combined with
  // `different_stride` must abort the three-way check.
  for (size_t i = 0; i < num_tensors; ++i) {
    for (size_t j = i; j < num_tensors; ++j) {
      const Tensor& ti = same_stride_tensor_list[i];
      const Tensor& tj = same_stride_tensor_list[j];
      ET_EXPECT_DEATH(ET_CHECK_SAME_STRIDES3(ti, tj, different_stride), "");
    }
  }
}
269
TEST_F(TensorUtilTest, ExtractIntScalarTensorSmoke) {
  // A single-element Int tensor holding 1.
  Tensor t = tf_int_.ones({1});

// For one integral C type: extract the scalar into a variable of that type
// and verify both the success flag and the value. Each expansion lives in its
// own scope.
#define EXPECT_EXTRACTS_ONE(ctype, unused)                  \
  {                                                         \
    ctype out_##ctype;                                      \
    const bool ok = extract_scalar_tensor(t, &out_##ctype); \
    ASSERT_TRUE(ok);                                        \
    EXPECT_EQ(out_##ctype, 1);                              \
  }

  ET_FORALL_INT_TYPES(EXPECT_EXTRACTS_ONE);
#undef EXPECT_EXTRACTS_ONE
}
282
TEST_F(TensorUtilTest, ExtractFloatScalarTensorFloatingTypeSmoke) {
  // A single-element Float tensor holding 1.0.
  Tensor t = tf_float_.ones({1});

// For one floating-point C type: extract the scalar into a variable of that
// type and verify both the success flag and the value. Each expansion lives
// in its own scope.
#define EXPECT_EXTRACTS_ONE(ctype, unused)                  \
  {                                                         \
    ctype out_##ctype;                                      \
    const bool ok = extract_scalar_tensor(t, &out_##ctype); \
    ASSERT_TRUE(ok);                                        \
    EXPECT_EQ(out_##ctype, 1.0);                            \
  }

  ET_FORALL_FLOAT_TYPES(EXPECT_EXTRACTS_ONE);
#undef EXPECT_EXTRACTS_ONE
}
295
TEST_F(TensorUtilTest, ExtractFloatScalarTensorIntegralTypeSmoke) {
  // A single-element tensor of *integral* dtype holding 1.
  Tensor t = tf_int_.ones({1});
  bool ok;

// Extracts the scalar into a floating-point variable of type `ctype`.
// Iterating the floating-point types is what makes this test cover "float
// output from an integral tensor"; iterating the integral types would merely
// repeat ExtractIntScalarTensorSmoke.
// NOTE(review): this assumes extract_scalar_tensor accepts integral tensors
// for floating-point outputs, as the test name states -- confirm against
// tensor_util.h.
#define CASE_FLOAT_DTYPE(ctype, unused)        \
  ctype out_##ctype;                           \
  ok = extract_scalar_tensor(t, &out_##ctype); \
  ASSERT_TRUE(ok);                             \
  EXPECT_EQ(out_##ctype, 1.0);

  ET_FORALL_FLOAT_TYPES(CASE_FLOAT_DTYPE);
#undef CASE_FLOAT_DTYPE
}
308
TEST_F(TensorUtilTest, ExtractBoolScalarTensorSmoke) {
  // A single-element Bool tensor filled with ones, i.e. holding `true`.
  Tensor truthy = tf_bool_.ones({1});

  bool out;
  const bool ok = extract_scalar_tensor(truthy, &out);

  // Extraction must succeed before the value is inspected.
  ASSERT_TRUE(ok);
  EXPECT_TRUE(out);
}
317
TEST_F(TensorUtilTest, FloatScalarTensorStressTests) {
  // Extracts the non-finite doubles (+inf, -inf, NaN) into a float. Each case
  // uses its own initialized output variable so that a value left over from
  // a previous case can never satisfy a later expectation.

  // Case: Positive Infinity. The sign is verified as well as the infiniteness.
  {
    float value = 0.0f;
    Tensor t_pos_inf = tf_double_.make({1}, {INFINITY});
    const bool ok = extract_scalar_tensor(t_pos_inf, &value);
    EXPECT_TRUE(ok);
    EXPECT_TRUE(std::isinf(value));
    EXPECT_GT(value, 0.0f);
  }

  // Case: Negative Infinity. Checking the sign guards against -inf being
  // turned into +inf.
  {
    float value = 0.0f;
    Tensor t_neg_inf = tf_double_.make({1}, {-INFINITY});
    const bool ok = extract_scalar_tensor(t_neg_inf, &value);
    EXPECT_TRUE(ok);
    EXPECT_TRUE(std::isinf(value));
    EXPECT_LT(value, 0.0f);
  }

  // Case: Not a Number (NaN) - ex: sqrt(-1.0)
  {
    float value = 0.0f;
    Tensor t_nan = tf_double_.make({1}, {NAN});
    const bool ok = extract_scalar_tensor(t_nan, &value);
    EXPECT_TRUE(ok);
    EXPECT_TRUE(std::isnan(value));
  }
}
340
TEST_F(TensorUtilTest, IntScalarTensorNotIntegralTypeFails) {
  // The tensor has a floating-point dtype while the requested output is an
  // integer, so extraction must report failure.
  Tensor floating = tf_float_.ones({1});
  int64_t out;
  EXPECT_FALSE(extract_scalar_tensor(floating, &out));
}
349
TEST_F(TensorUtilTest, FloatScalarTensorNotFloatingTypeFails) {
  // The tensor has a boolean dtype while the requested output is a double, so
  // extraction must report failure.
  Tensor boolean = tf_bool_.ones({1});
  double out;
  EXPECT_FALSE(extract_scalar_tensor(boolean, &out));
}
357
TEST_F(TensorUtilTest, IntTensorNotScalarFails) {
  // A {2, 3} tensor has several dimensions and values, so it cannot be
  // extracted as a single integer.
  Tensor matrix = tf_int_.ones({2, 3});
  int64_t out;
  EXPECT_FALSE(extract_scalar_tensor(matrix, &out));
}
365
TEST_F(TensorUtilTest, FloatTensorNotScalarFails) {
  // A {2, 3} tensor has several dimensions and values, so it cannot be
  // extracted as a single floating-point number.
  Tensor matrix = tf_float_.ones({2, 3});
  double out;
  EXPECT_FALSE(extract_scalar_tensor(matrix, &out));
}
373
TEST_F(TensorUtilTest, IntTensorOutOfBoundFails) {
  // 256 does not fit into an `int8_t` (-128 to 127), so extraction must
  // report failure.
  Tensor too_big = tf_int_.make({1}, {256});
  int8_t out;
  EXPECT_FALSE(extract_scalar_tensor(too_big, &out));
}
381
TEST_F(TensorUtilTest, FloatTensorOutOfBoundFails) {
  // The two extremes of the double range; neither can be held by a float.
  // `lowest()` rather than `min()` is used to get double's largest negative
  // value.
  const double out_of_range_values[] = {
      std::numeric_limits<double>::lowest(),
      std::numeric_limits<double>::max(),
  };

  for (const double value : out_of_range_values) {
    Tensor t = tf_double_.make({1}, {value});
    float out;
    const bool ok = extract_scalar_tensor(t, &out);
    EXPECT_FALSE(ok);
  }
}
401
TEST_F(TensorUtilTest, BoolScalarTensorNotBooleanTypeFails) {
  // The tensor has an integral (Byte) dtype while the requested output is a
  // bool, so extraction must report failure.
  Tensor integral = tf_byte_.ones({1});
  bool out;
  EXPECT_FALSE(extract_scalar_tensor(integral, &out));
}
410
TEST_F(TensorUtilTest, BoolTensorNotScalarFails) {
  // A {2, 3} tensor has several dimensions and values, so it cannot be
  // extracted as a single bool.
  Tensor matrix = tf_bool_.ones({2, 3});
  bool out;
  EXPECT_FALSE(extract_scalar_tensor(matrix, &out));
}
418
419 //
420 // Tests for utility functions that check tensor attributes
421 //
422
TEST_F(TensorUtilTest, TensorIsRankTest) {
  using executorch::runtime::tensor_is_rank;

  // Three sizes, hence rank 3.
  Tensor rank3 = tf_float_.ones({2, 3, 5});

  // Only the actual rank matches; a smaller and a larger value do not.
  EXPECT_TRUE(tensor_is_rank(rank3, 3));
  EXPECT_FALSE(tensor_is_rank(rank3, 0));
  EXPECT_FALSE(tensor_is_rank(rank3, 5));
}
431
TEST_F(TensorUtilTest, TensorHasDimTest) {
  using executorch::runtime::tensor_has_dim;

  // A rank-3 tensor.
  Tensor rank3 = tf_float_.ones({2, 3, 5});

  // Non-negative indices 0..2 are accepted.
  EXPECT_TRUE(tensor_has_dim(rank3, 0));
  EXPECT_TRUE(tensor_has_dim(rank3, 1));
  EXPECT_TRUE(tensor_has_dim(rank3, 2));

  // Negative indices -1..-3 are accepted too.
  EXPECT_TRUE(tensor_has_dim(rank3, -1));
  EXPECT_TRUE(tensor_has_dim(rank3, -2));
  EXPECT_TRUE(tensor_has_dim(rank3, -3));

  // One step beyond either end is rejected.
  EXPECT_FALSE(tensor_has_dim(rank3, 3));
  EXPECT_FALSE(tensor_has_dim(rank3, -4));
}
446
TEST_F(TensorUtilTest, TensorsHaveSameDtypeTest) {
  using executorch::runtime::tensors_have_same_dtype;

  // Three Float tensors (one of them with a different shape) and a single
  // Int tensor.
  Tensor float_a = tf_float_.ones({2, 3});
  Tensor float_b = tf_float_.ones({2, 3});
  Tensor float_other_shape = tf_float_.ones({3, 3});
  Tensor int_tensor = tf_int_.ones({4, 3});

  // Matching dtypes compare equal, whatever the shapes are.
  EXPECT_TRUE(tensors_have_same_dtype(float_a, float_b));
  EXPECT_TRUE(tensors_have_same_dtype(float_a, float_b, float_other_shape));

  // Bringing in the Int tensor breaks the match.
  EXPECT_FALSE(tensors_have_same_dtype(float_a, int_tensor));
  EXPECT_FALSE(tensors_have_same_dtype(float_a, float_b, int_tensor));
}
459
TEST_F(TensorUtilTest, TensorsHaveSameSizeAtDimTest) {
  using executorch::runtime::tensors_have_same_size_at_dims;

  // Two rank-4 tensors whose size lists are each other's reverse.
  Tensor ascending = tf_float_.ones({2, 3, 4, 5});
  Tensor descending = tf_float_.ones({5, 4, 3, 2});

  // Mirrored dimensions have equal sizes (2 == 2 and 3 == 3).
  EXPECT_TRUE(tensors_have_same_size_at_dims(ascending, 0, descending, 3));
  EXPECT_TRUE(tensors_have_same_size_at_dims(ascending, 1, descending, 2));

  // Differing sizes (3 vs 5 and 4 vs 2) do not match.
  EXPECT_FALSE(tensors_have_same_size_at_dims(ascending, 1, descending, 0));
  EXPECT_FALSE(tensors_have_same_size_at_dims(ascending, 2, descending, 3));

  // Dimension 4 does not exist in a rank-4 tensor.
  EXPECT_FALSE(tensors_have_same_size_at_dims(ascending, 4, descending, 0));
}
471
TEST_F(TensorUtilTest, TensorsHaveSameShapeTest) {
  using executorch::runtime::tensors_have_same_shape;

  // Three {2, 3} tensors and two {3, 2} tensors, each of a different dtype.
  Tensor float_2x3 = tf_float_.ones({2, 3});
  Tensor int_2x3 = tf_int_.ones({2, 3});
  Tensor byte_2x3 = tf_byte_.ones({2, 3});
  Tensor double_3x2 = tf_double_.ones({3, 2});
  Tensor bool_3x2 = tf_bool_.ones({3, 2});

  // The dtype plays no role; only the sizes are compared.
  EXPECT_TRUE(tensors_have_same_shape(float_2x3, int_2x3));
  EXPECT_TRUE(tensors_have_same_shape(double_3x2, bool_3x2));
  EXPECT_TRUE(tensors_have_same_shape(float_2x3, int_2x3, byte_2x3));

  // Mixing {2, 3} with {3, 2} fails for pairs and for triples.
  EXPECT_FALSE(tensors_have_same_shape(float_2x3, double_3x2));
  EXPECT_FALSE(tensors_have_same_shape(float_2x3, int_2x3, double_3x2));
  EXPECT_FALSE(tensors_have_same_shape(float_2x3, double_3x2, bool_3x2));

  // Single-element tensors of different ranks are expected to count as
  // having the same shape.
  Tensor one_rank2 = tf_float_.ones({1, 1});
  Tensor one_rank1 = tf_double_.ones({1});
  Tensor one_rank4 = tf_int_.ones({1, 1, 1, 1});

  EXPECT_TRUE(tensors_have_same_shape(one_rank2, one_rank1));
  EXPECT_TRUE(tensors_have_same_shape(one_rank2, one_rank1, one_rank4));
}
494
TEST_F(TensorUtilTest, TensorsHaveSameShapeAndDtypeTest) {
  using executorch::runtime::tensors_have_same_shape_and_dtype;

  // Three identical Float {2, 3} tensors, one tensor that differs only in
  // dtype, and one that differs only in shape.
  Tensor float_a = tf_float_.ones({2, 3});
  Tensor float_b = tf_float_.ones({2, 3});
  Tensor float_c = tf_float_.ones({2, 3});
  Tensor other_dtype = tf_double_.ones({2, 3});
  Tensor other_shape = tf_float_.ones({3, 2});

  // Shape and dtype both agree.
  EXPECT_TRUE(tensors_have_same_shape_and_dtype(float_a, float_b));
  EXPECT_TRUE(tensors_have_same_shape_and_dtype(float_a, float_b, float_c));

  // A mismatch in either property makes the check fail.
  EXPECT_FALSE(tensors_have_same_shape_and_dtype(float_a, other_dtype));
  EXPECT_FALSE(
      tensors_have_same_shape_and_dtype(float_a, float_b, other_dtype));
  EXPECT_FALSE(
      tensors_have_same_shape_and_dtype(float_a, other_dtype, other_shape));

  // Single-element Float tensors of different ranks are expected to match.
  Tensor one_rank2 = tf_float_.ones({1, 1});
  Tensor one_rank1 = tf_float_.ones({1});
  Tensor one_rank4 = tf_float_.ones({1, 1, 1, 1});

  EXPECT_TRUE(tensors_have_same_shape_and_dtype(one_rank2, one_rank1));
  EXPECT_TRUE(
      tensors_have_same_shape_and_dtype(one_rank2, one_rank1, one_rank4));
}
516
TEST_F(TensorUtilTest, TensorsHaveSameStridesTest) {
  using executorch::runtime::tensors_have_same_strides;

  // Three channels-last tensors of equal sizes (the fill values differ) ...
  Tensor cl_a = tf_float_.full_channels_last({4, 5, 2, 3}, 1);
  Tensor cl_b = tf_float_.full_channels_last({4, 5, 2, 3}, 2);
  Tensor cl_c = tf_float_.full_channels_last({4, 5, 2, 3}, 3);
  // ... and two tensors of the same sizes built with the default layout.
  Tensor default_double = tf_double_.ones({4, 5, 2, 3});
  Tensor default_float = tf_float_.ones({4, 5, 2, 3});

  // The channels-last tensors agree with one another.
  EXPECT_TRUE(tensors_have_same_strides(cl_a, cl_b));
  EXPECT_TRUE(tensors_have_same_strides(cl_a, cl_b, cl_c));

  // Adding a default-layout tensor breaks the agreement.
  EXPECT_FALSE(tensors_have_same_strides(cl_a, default_double));
  EXPECT_FALSE(tensors_have_same_strides(cl_a, cl_b, default_double));
  EXPECT_FALSE(tensors_have_same_strides(cl_a, default_double, default_float));
}
531
TEST_F(TensorUtilTest, TensorIsContiguous) {
  using executorch::runtime::tensor_is_contiguous;

  // Channels-last layout over sizes {4, 5, 2, 3}: not contiguous.
  Tensor channels_last = tf_float_.full_channels_last({4, 5, 2, 3}, 1);
  EXPECT_FALSE(tensor_is_contiguous(channels_last));

  // The same sizes with the factory's default layout: contiguous.
  Tensor default_layout = tf_float_.ones({4, 5, 2, 3});
  EXPECT_TRUE(tensor_is_contiguous(default_layout));

  // Channels-last layout where every size is 1: still contiguous.
  Tensor all_ones_channels_last = tf_float_.full_channels_last({1, 1, 1, 1}, 1);
  EXPECT_TRUE(tensor_is_contiguous(all_ones_channels_last));

  // A zero-dimensional tensor: contiguous.
  Tensor zero_dim = tf_float_.ones({});
  EXPECT_TRUE(tensor_is_contiguous(zero_dim));
}
545
TEST_F(TensorUtilTest, ResizeZeroDimTensor) {
  // Resizing a zero-dimensional tensor to an empty size list must succeed
  // and leave the tensor zero-dimensional.
  Tensor zero_dim = tf_float_.ones({});

  const executorch::runtime::Error status =
      executorch::runtime::resize_tensor(zero_dim, {});

  EXPECT_EQ(status, executorch::runtime::Error::Ok);
  EXPECT_EQ(zero_dim.dim(), 0);
}
554
TEST_F(TensorUtilTest, SameDimOrderContiguous) {
  using namespace torch::executor;

  // All three tensors use the same sizes and the default dim order
  // ([0, 1, 2, 3]); only their dtypes and contents differ.
  const std::vector<int32_t> shape = {3, 5, 2, 1};
  Tensor bytes = tf_byte_.ones(shape);
  Tensor ints = tf_int_.zeros(shape);
  Tensor floats = tf_float_.full(shape, 0.1);

  // Pairwise comparison, in both argument orders.
  EXPECT_TRUE(tensors_have_same_dim_order(bytes, ints));
  EXPECT_TRUE(tensors_have_same_dim_order(ints, bytes));

  // Three-way comparison, with the arguments rotated.
  EXPECT_TRUE(tensors_have_same_dim_order(bytes, ints, floats));
  EXPECT_TRUE(tensors_have_same_dim_order(ints, floats, bytes));
  EXPECT_TRUE(tensors_have_same_dim_order(floats, bytes, ints));
}
571
TEST_F(TensorUtilTest, SameDimOrderChannelsLast) {
  using namespace torch::executor;

  // All three tensors use the same sizes and the channels-last dim order
  // ([0, 2, 3, 1]); only their dtypes and contents differ.
  const std::vector<int32_t> shape = {3, 5, 2, 1};
  Tensor bytes = tf_byte_.full_channels_last(shape, 1);
  Tensor ints = tf_int_.full_channels_last(shape, 0);
  Tensor floats = tf_float_.full_channels_last(shape, 0.1);

  // Pairwise comparison, in both argument orders.
  EXPECT_TRUE(tensors_have_same_dim_order(bytes, ints));
  EXPECT_TRUE(tensors_have_same_dim_order(ints, bytes));

  // Three-way comparison, with the arguments rotated.
  EXPECT_TRUE(tensors_have_same_dim_order(bytes, ints, floats));
  EXPECT_TRUE(tensors_have_same_dim_order(ints, floats, bytes));
  EXPECT_TRUE(tensors_have_same_dim_order(floats, bytes, ints));
}
588
TEST_F(TensorUtilTest, SameShapesDifferentDimOrder) {
  using namespace torch::executor;
  // Three different tensors with the same shape but different dtypes and
  // contents, where b and c have the same dim order ([0, 2, 3, 1]) while a is
  // different ([0, 1, 2, 3]).
  std::vector<int32_t> sizes = {3, 5, 2, 1};
  Tensor a = tf_byte_.ones(sizes);
  Tensor b = tf_int_.full_channels_last(sizes, 0);
  Tensor c = tf_float_.full_channels_last(sizes, 0.1);

  // Not the same dim order. Check both argument positions.
  EXPECT_FALSE(tensors_have_same_dim_order(a, b));
  EXPECT_FALSE(tensors_have_same_dim_order(b, a));

  // Test with the mismatching tensor (a) in all positions, where the other
  // two agree: first, middle, and last.
  EXPECT_FALSE(tensors_have_same_dim_order(a, b, c));
  EXPECT_FALSE(tensors_have_same_dim_order(b, a, c));
  EXPECT_FALSE(tensors_have_same_dim_order(c, b, a));

  // Also cover the agreeing tensors in swapped order behind the mismatch.
  EXPECT_FALSE(tensors_have_same_dim_order(a, c, b));
}
608