// xref: /aosp_15_r20/external/executorch/runtime/core/exec_aten/util/test/tensor_util_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/platform/runtime.h>
#include <executorch/test/utils/DeathTest.h>

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

using namespace ::testing;
using exec_aten::ScalarType;
using exec_aten::Tensor;
using executorch::runtime::extract_scalar_tensor;
using executorch::runtime::testing::TensorFactory;

23 class TensorUtilTest : public ::testing::Test {
24  protected:
25   // Factories for tests to use. These will be torn down and recreated for each
26   // test case.
27   TensorFactory<ScalarType::Byte> tf_byte_;
28   TensorFactory<ScalarType::Int> tf_int_;
29   TensorFactory<ScalarType::Float> tf_float_;
30   TensorFactory<ScalarType::Double> tf_double_;
31   TensorFactory<ScalarType::Bool> tf_bool_;
32 
SetUp()33   void SetUp() override {
34     // As some of these tests cause ET_LOG to be called, the PAL must be
35     // initialized first by calling runtime_init();
36     executorch::runtime::runtime_init();
37   }
38 };
39 
TEST_F(TensorUtilTest, IdentityChecks) {
  // Comparing a tensor against itself must satisfy every check.
  Tensor only = tf_byte_.ones({2, 2});

  // Two-operand forms: shape, dtype, then shape-and-dtype.
  ET_CHECK_SAME_SHAPE2(only, only);
  ET_CHECK_SAME_DTYPE2(only, only);
  ET_CHECK_SAME_SHAPE_AND_DTYPE2(only, only);

  // Three-operand forms of the same checks.
  ET_CHECK_SAME_SHAPE3(only, only, only);
  ET_CHECK_SAME_DTYPE3(only, only, only);
  ET_CHECK_SAME_SHAPE_AND_DTYPE3(only, only, only);
}

TEST_F(TensorUtilTest, SameShapesDifferentDtypes) {
  // Three 2x2 tensors: identical shapes, pairwise-distinct dtypes.
  Tensor as_byte = tf_byte_.ones({2, 2});
  Tensor as_int = tf_int_.ones({2, 2});
  Tensor as_float = tf_float_.ones({2, 2});

  // Shape-only checks are satisfied.
  ET_CHECK_SAME_SHAPE2(as_byte, as_int);
  ET_CHECK_SAME_SHAPE3(as_byte, as_int, as_float);

  // Two-operand dtype-sensitive checks abort regardless of argument order.
  ET_EXPECT_DEATH(ET_CHECK_SAME_DTYPE2(as_byte, as_int), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_DTYPE2(as_int, as_byte), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE_AND_DTYPE2(as_byte, as_int), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE_AND_DTYPE2(as_int, as_byte), "");

  // Three-operand checks: the odd one out (`as_byte`) sits in each slot in
  // turn while the remaining two slots agree with each other.
  ET_EXPECT_DEATH(ET_CHECK_SAME_DTYPE3(as_byte, as_int, as_int), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_DTYPE3(as_int, as_byte, as_int), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_DTYPE3(as_int, as_int, as_byte), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE_AND_DTYPE3(as_byte, as_int, as_int), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE_AND_DTYPE3(as_int, as_byte, as_int), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE_AND_DTYPE3(as_int, as_int, as_byte), "");
}

TEST_F(TensorUtilTest, DifferentShapesSameDtypes) {
  // `row` and `square` share dtype (Int), rank (2) and element count (4) but
  // differ in shape; `square_twin` matches `square` in shape and dtype.
  Tensor row = tf_int_.ones({1, 4});
  Tensor square = tf_int_.ones({2, 2});
  Tensor square_twin = tf_int_.ones({2, 2});

  // Shape checks abort whenever `row` takes part, in any argument position.
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE2(row, square), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE2(square, row), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE3(row, square, square_twin), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE3(square, row, square_twin), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE3(square, square_twin, row), "");

  // Dtype checks ignore shape, so every ordering passes.
  ET_CHECK_SAME_DTYPE2(row, square);
  ET_CHECK_SAME_DTYPE2(square, row);
  ET_CHECK_SAME_DTYPE3(row, square, square_twin);
  ET_CHECK_SAME_DTYPE3(square, row, square_twin);
  ET_CHECK_SAME_DTYPE3(square, square_twin, row);

  // The combined shape-and-dtype checks still abort on the shape mismatch.
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE_AND_DTYPE2(row, square), "");
  ET_EXPECT_DEATH(ET_CHECK_SAME_SHAPE_AND_DTYPE2(square, row), "");
  ET_EXPECT_DEATH(
      ET_CHECK_SAME_SHAPE_AND_DTYPE3(row, square, square_twin), "");
  ET_EXPECT_DEATH(
      ET_CHECK_SAME_SHAPE_AND_DTYPE3(square, row, square_twin), "");
  ET_EXPECT_DEATH(
      ET_CHECK_SAME_SHAPE_AND_DTYPE3(square, square_twin, row), "");
}

TEST_F(TensorUtilTest, ZeroDimensionalTensor) {
  // An empty size list yields a tensor of rank zero.
  Tensor scalar = tf_int_.ones({});
  EXPECT_EQ(scalar.dim(), 0);

  // None of the checks may misbehave on a rank-zero tensor.
  ET_CHECK_SAME_SHAPE2(scalar, scalar);
  ET_CHECK_SAME_SHAPE3(scalar, scalar, scalar);
  ET_CHECK_SAME_DTYPE2(scalar, scalar);
  ET_CHECK_SAME_DTYPE3(scalar, scalar, scalar);
  ET_CHECK_SAME_SHAPE_AND_DTYPE2(scalar, scalar);
  ET_CHECK_SAME_SHAPE_AND_DTYPE3(scalar, scalar, scalar);
}

TEST_F(TensorUtilTest, EmptyTensor) {
  // A single zero-width dimension produces a tensor with no elements and no
  // storage.
  Tensor empty = tf_int_.ones({0});
  EXPECT_EQ(empty.nbytes(), 0);
  EXPECT_EQ(empty.numel(), 0);

  // None of the checks may misbehave on an element-less tensor.
  ET_CHECK_SAME_SHAPE2(empty, empty);
  ET_CHECK_SAME_SHAPE3(empty, empty, empty);
  ET_CHECK_SAME_DTYPE2(empty, empty);
  ET_CHECK_SAME_DTYPE3(empty, empty, empty);
  ET_CHECK_SAME_SHAPE_AND_DTYPE2(empty, empty);
  ET_CHECK_SAME_SHAPE_AND_DTYPE3(empty, empty, empty);
}

TEST_F(TensorUtilTest, GetLeadingDimsSmokeTest) {
  namespace rt = executorch::runtime;
  Tensor cube = tf_int_.ones({2, 3, 4});

  // getLeadingDims(t, d) is the product of the sizes of dims [0, d):
  //   d = 1 -> 2, d = 2 -> 2 * 3, d = 3 -> 2 * 3 * 4.
  EXPECT_EQ(rt::getLeadingDims(cube, 1), 2);
  EXPECT_EQ(rt::getLeadingDims(cube, 2), 6);
  EXPECT_EQ(rt::getLeadingDims(cube, 3), 24);
}

TEST_F(TensorUtilTest, GetLeadingDimsInputOutOfBoundDies) {
  namespace rt = executorch::runtime;
  Tensor cube = tf_int_.ones({2, 3, 4});

  // Accepted values lie in [0, cube.dim()] = [0, 3]; anything outside aborts.
  ET_EXPECT_DEATH(rt::getLeadingDims(cube, -2), "");
  ET_EXPECT_DEATH(rt::getLeadingDims(cube, -1), "");
  ET_EXPECT_DEATH(rt::getLeadingDims(cube, 4), "");
}

TEST_F(TensorUtilTest, GetTrailingDimsSmokeTest) {
  namespace rt = executorch::runtime;
  Tensor cube = tf_int_.ones({2, 3, 4});

  // getTrailingDims(t, d) is the product of the sizes of dims after d:
  //   d = 1 -> 4, d = 0 -> 3 * 4, d = -1 -> 2 * 3 * 4.
  EXPECT_EQ(rt::getTrailingDims(cube, 1), 4);
  EXPECT_EQ(rt::getTrailingDims(cube, 0), 12);
  EXPECT_EQ(rt::getTrailingDims(cube, -1), 24);
}

TEST_F(TensorUtilTest, GetTrailingDimsInputOutOfBoundDies) {
  // Create a tensor with some dimensions
  Tensor t = tf_int_.ones({2, 3, 4});

  // dim needs to be in the half-open range [-1, t.dim()), i.e. [-1, 3) here.
  // -2 is just below the lower bound; 3 (== t.dim()) is the first value past
  // the upper bound, and 4 is further out still.
  ET_EXPECT_DEATH(executorch::runtime::getTrailingDims(t, -2), "");
  ET_EXPECT_DEATH(executorch::runtime::getTrailingDims(t, 3), "");
  ET_EXPECT_DEATH(executorch::runtime::getTrailingDims(t, 4), "");
}

TEST_F(TensorUtilTest, ContiguousCheckSupported) {
  std::vector<int32_t> sizes = {1, 2, 3};
  std::vector<float> data = {1, 2, 3, 4, 5, 6};

  // No explicit strides: the factory's default layout, which the check below
  // expects to be contiguous.
  Tensor contiguous = tf_float_.make(sizes, data);

  // Same sizes and data, but explicit strides {2, 1, 2} that differ from the
  // row-major strides {6, 3, 1} for these sizes, so the layout is not
  // contiguous.
  Tensor non_contiguous = tf_float_.make(sizes, data, /*strides=*/{2, 1, 2});

  ET_CHECK_CONTIGUOUS(contiguous);
  ET_EXPECT_DEATH(ET_CHECK_CONTIGUOUS(non_contiguous), "");
}

TEST_F(TensorUtilTest,CheckSameContiguousStrideSupported)213 TEST_F(TensorUtilTest, CheckSameContiguousStrideSupported) {
214   // Tensors in the following list share same stride.
215   std::vector<Tensor> same_stride_tensor_list = {
216       tf_float_.ones(/*sizes=*/{1, 2, 3, 4}),
217       tf_byte_.ones(/*sizes=*/{4, 2, 3, 4}),
218       tf_int_.ones(/*sizes=*/{10, 2, 3, 4}),
219       tf_float_.make(
220           /*sizes=*/{0, 2, 3, 4}, /*data=*/{}, /*strides=*/{24, 12, 4, 1})};
221 
222   // different_stride = tensor(size=(0,2,3,4)).permute(0, 2, 3, 1)
223   // {0, 3, 4, 2}
224   // stride for (0, 2, 3, 4) with permute = (24, 1, 8, 2)
225   // So change stride from {24, 3, 1, 6} => {24, 1, 8, 2}
226   Tensor different_stride = tf_float_.make(
227       /*sizes=*/{0, 2, 3, 4}, /*data=*/{}, /*strides=*/{24, 1, 8, 2});
228 
229   // Any two tensors in `same_stride_tensor_list` have same strides. The two
230   // could contain duplicate tensors.
231   for (int i = 0; i < same_stride_tensor_list.size(); i++) {
232     for (int j = i; j < same_stride_tensor_list.size(); j++) {
233       auto ti = same_stride_tensor_list[i];
234       auto tj = same_stride_tensor_list[j];
235       ET_CHECK_SAME_STRIDES2(ti, tj);
236     }
237   }
238 
239   // Any tensor in `same_stride_tensor_list` shall not have same stride with
240   // `different_stride`.
241   for (int i = 0; i < same_stride_tensor_list.size(); i++) {
242     auto ti = same_stride_tensor_list[i];
243     ET_EXPECT_DEATH(ET_CHECK_SAME_STRIDES2(ti, different_stride), "");
244   }
245 
246   // Any three tensors in same_stride_tensor_list have same strides. The three
247   // could contain duplicate tensors.
248   for (size_t i = 0; i < same_stride_tensor_list.size(); i++) {
249     for (size_t j = i; j < same_stride_tensor_list.size(); j++) {
250       for (size_t k = j; k < same_stride_tensor_list.size(); k++) {
251         auto ti = same_stride_tensor_list[i];
252         auto tj = same_stride_tensor_list[j];
253         auto tk = same_stride_tensor_list[k];
254         ET_CHECK_SAME_STRIDES3(ti, tj, tk);
255       }
256     }
257   }
258 
259   // Any two tensors in same_stride_tensor_list shall not have same strides with
260   // `different_stride`. The two could contain duplicate tensors.
261   for (int i = 0; i < same_stride_tensor_list.size(); i++) {
262     for (int j = i; j < same_stride_tensor_list.size(); j++) {
263       auto ti = same_stride_tensor_list[i];
264       auto tj = same_stride_tensor_list[j];
265       ET_EXPECT_DEATH(ET_CHECK_SAME_STRIDES3(ti, tj, different_stride), "");
266     }
267   }
268 }
269 
TEST_F(TensorUtilTest, ExtractIntScalarTensorSmoke) {
  // A one-element Int tensor holding 1 must extract into every integral C type.
  Tensor one = tf_int_.ones({1});
  bool extracted;
#define EXPECT_EXTRACTS_INT_ONE_AS(ctype, unused_dtype) \
  ctype as_##ctype;                                     \
  extracted = extract_scalar_tensor(one, &as_##ctype);  \
  ASSERT_TRUE(extracted);                               \
  EXPECT_EQ(as_##ctype, 1);

  ET_FORALL_INT_TYPES(EXPECT_EXTRACTS_INT_ONE_AS);
#undef EXPECT_EXTRACTS_INT_ONE_AS
}

TEST_F(TensorUtilTest, ExtractFloatScalarTensorFloatingTypeSmoke) {
  // A one-element Float tensor holding 1 must extract into every
  // floating-point C type.
  Tensor one = tf_float_.ones({1});
  bool extracted;
#define EXPECT_EXTRACTS_FLOAT_ONE_AS(ctype, unused_dtype) \
  ctype as_##ctype;                                       \
  extracted = extract_scalar_tensor(one, &as_##ctype);    \
  ASSERT_TRUE(extracted);                                 \
  EXPECT_EQ(as_##ctype, 1.0);

  ET_FORALL_FLOAT_TYPES(EXPECT_EXTRACTS_FLOAT_ONE_AS);
#undef EXPECT_EXTRACTS_FLOAT_ONE_AS
}

TEST_F(TensorUtilTest, ExtractFloatScalarTensorIntegralTypeSmoke) {
  // An integral (Int) tensor holding 1 should extract into each floating-point
  // C type. The output types come from ET_FORALL_FLOAT_TYPES; integral outputs
  // from an Int tensor are already covered by ExtractIntScalarTensorSmoke.
  Tensor t = tf_int_.ones({1});
  bool ok;
#define CASE_FLOAT_DTYPE(ctype, unused)        \
  ctype out_##ctype;                           \
  ok = extract_scalar_tensor(t, &out_##ctype); \
  ASSERT_TRUE(ok);                             \
  EXPECT_EQ(out_##ctype, 1.0);

  ET_FORALL_FLOAT_TYPES(CASE_FLOAT_DTYPE);
#undef CASE_FLOAT_DTYPE
}

TEST_F(TensorUtilTest, ExtractBoolScalarTensorSmoke) {
  // A one-element Bool tensor filled by ones() extracts as `true`.
  Tensor truthy = tf_bool_.ones({1});
  bool value;
  const bool extracted = extract_scalar_tensor(truthy, &value);
  ASSERT_TRUE(extracted);
  EXPECT_EQ(value, true);
}

TEST_F(TensorUtilTest, FloatScalarTensorStressTests) {
  // Non-finite doubles are expected to extract successfully into a float and
  // keep their class; for infinities the sign must be preserved too.
  // `value` is initialized so a failed extraction never reads garbage.
  float value = 0.0f;
  bool ok = false;

  // Case: Positive Infinity
  Tensor t_pos_inf = tf_double_.make({1}, {INFINITY});
  ok = extract_scalar_tensor(t_pos_inf, &value);
  EXPECT_TRUE(ok);
  EXPECT_TRUE(std::isinf(value));
  EXPECT_FALSE(std::signbit(value));

  // Case: Negative Infinity (must not come back as +inf)
  Tensor t_neg_inf = tf_double_.make({1}, {-INFINITY});
  ok = extract_scalar_tensor(t_neg_inf, &value);
  EXPECT_TRUE(ok);
  EXPECT_TRUE(std::isinf(value));
  EXPECT_TRUE(std::signbit(value));

  // Case: Not a Number (NaN) - ex: sqrt(-1.0)
  Tensor t_nan = tf_double_.make({1}, {NAN});
  ok = extract_scalar_tensor(t_nan, &value);
  EXPECT_TRUE(ok);
  EXPECT_TRUE(std::isnan(value));
}

TEST_F(TensorUtilTest, IntScalarTensorNotIntegralTypeFails) {
  // Asking for an integer from a floating-point tensor is rejected, even though
  // the tensor holds exactly one element.
  Tensor float_one = tf_float_.ones({1});
  int64_t out;
  EXPECT_FALSE(extract_scalar_tensor(float_one, &out));
}

TEST_F(TensorUtilTest, FloatScalarTensorNotFloatingTypeFails) {
  // Asking for a floating-point value from a Bool tensor is rejected.
  Tensor bool_one = tf_bool_.ones({1});
  double out;
  EXPECT_FALSE(extract_scalar_tensor(bool_one, &out));
}

TEST_F(TensorUtilTest, IntTensorNotScalarFails) {
  // A 2x3 Int tensor holds six values, so it cannot be read as a scalar.
  Tensor matrix = tf_int_.ones({2, 3});
  int64_t out;
  EXPECT_FALSE(extract_scalar_tensor(matrix, &out));
}

TEST_F(TensorUtilTest, FloatTensorNotScalarFails) {
  // A 2x3 Float tensor holds six values, so it cannot be read as a scalar.
  Tensor matrix = tf_float_.ones({2, 3});
  double out;
  EXPECT_FALSE(extract_scalar_tensor(matrix, &out));
}

TEST_F(TensorUtilTest, IntTensorOutOfBoundFails) {
  // 256 does not fit in int8_t (range -128..127), so extraction is refused.
  Tensor too_big = tf_int_.make({1}, {256});
  int8_t narrow;
  EXPECT_FALSE(extract_scalar_tensor(too_big, &narrow));
}

TEST_F(TensorUtilTest, FloatTensorOutOfBoundFails) {
  // Finite doubles beyond float's range must be refused. `lowest()` (not
  // `min()`) is double's most negative finite value; `max()` its most positive.
  const double beyond_float_range[] = {
      std::numeric_limits<double>::lowest(),
      std::numeric_limits<double>::max(),
  };

  for (const double huge : beyond_float_range) {
    Tensor t = tf_double_.make({1}, {huge});
    float out;
    EXPECT_FALSE(extract_scalar_tensor(t, &out));
  }
}

TEST_F(TensorUtilTest, BoolScalarTensorNotBooleanTypeFails) {
  // Asking for a bool from an integral (Byte) tensor is rejected.
  Tensor byte_one = tf_byte_.ones({1});
  bool out;
  EXPECT_FALSE(extract_scalar_tensor(byte_one, &out));
}

TEST_F(TensorUtilTest, BoolTensorNotScalarFails) {
  // A 2x3 Bool tensor holds six values, so it cannot be read as a scalar.
  Tensor matrix = tf_bool_.ones({2, 3});
  bool out;
  EXPECT_FALSE(extract_scalar_tensor(matrix, &out));
}

//
// Tests for utility functions that check tensor attributes
//

TEST_F(TensorUtilTest, TensorIsRankTest) {
  using executorch::runtime::tensor_is_rank;
  Tensor rank3 = tf_float_.ones({2, 3, 5});

  // Only the true rank matches; neither 0 nor 5 (one of the sizes) does.
  EXPECT_TRUE(tensor_is_rank(rank3, 3));
  EXPECT_FALSE(tensor_is_rank(rank3, 0));
  EXPECT_FALSE(tensor_is_rank(rank3, 5));
}

TEST_F(TensorUtilTest, TensorHasDimTest) {
  using executorch::runtime::tensor_has_dim;
  Tensor rank3 = tf_float_.ones({2, 3, 5});

  // For a rank-3 tensor, both positive (0..2) and negative (-3..-1) indices
  // name a dimension.
  for (int64_t dim = -3; dim <= 2; ++dim) {
    EXPECT_TRUE(tensor_has_dim(rank3, dim)) << "dim=" << dim;
  }

  // One step past either end is rejected.
  EXPECT_FALSE(tensor_has_dim(rank3, -4));
  EXPECT_FALSE(tensor_has_dim(rank3, 3));
}

TEST_F(TensorUtilTest, TensorsHaveSameDtypeTest) {
  using executorch::runtime::tensors_have_same_dtype;
  // Three Float tensors (shapes deliberately vary) and one Int tensor.
  Tensor float_2x3 = tf_float_.ones({2, 3});
  Tensor float_2x3_too = tf_float_.ones({2, 3});
  Tensor float_3x3 = tf_float_.ones({3, 3});
  Tensor int_4x3 = tf_int_.ones({4, 3});

  // Shape is ignored; any Int operand among Floats breaks the match.
  EXPECT_TRUE(tensors_have_same_dtype(float_2x3, float_2x3_too));
  EXPECT_FALSE(tensors_have_same_dtype(float_2x3, int_4x3));
  EXPECT_TRUE(tensors_have_same_dtype(float_2x3, float_2x3_too, float_3x3));
  EXPECT_FALSE(tensors_have_same_dtype(float_2x3, float_2x3_too, int_4x3));
}

TEST_F(TensorUtilTest, TensorsHaveSameSizeAtDimTest) {
  using executorch::runtime::tensors_have_same_size_at_dims;
  // `ascending` and `descending` hold the same sizes in reverse order.
  Tensor ascending = tf_float_.ones({2, 3, 4, 5});
  Tensor descending = tf_float_.ones({5, 4, 3, 2});

  // Mirrored dims agree: size 2 vs 2, and size 3 vs 3.
  EXPECT_TRUE(tensors_have_same_size_at_dims(ascending, 0, descending, 3));
  EXPECT_TRUE(tensors_have_same_size_at_dims(ascending, 1, descending, 2));

  // Unequal sizes (3 vs 5, 4 vs 2) do not match, and a dim index past the end
  // of a 4-D tensor (4) is reported as a mismatch too.
  EXPECT_FALSE(tensors_have_same_size_at_dims(ascending, 1, descending, 0));
  EXPECT_FALSE(tensors_have_same_size_at_dims(ascending, 4, descending, 0));
  EXPECT_FALSE(tensors_have_same_size_at_dims(ascending, 2, descending, 3));
}

TEST_F(TensorUtilTest, TensorsHaveSameShapeTest) {
  using executorch::runtime::tensors_have_same_shape;
  // 2x3 tensors in three dtypes, and 3x2 tensors in two more.
  Tensor f32_2x3 = tf_float_.ones({2, 3});
  Tensor i32_2x3 = tf_int_.ones({2, 3});
  Tensor u8_2x3 = tf_byte_.ones({2, 3});
  Tensor f64_3x2 = tf_double_.ones({3, 2});
  Tensor bool_3x2 = tf_bool_.ones({3, 2});

  // Dtype is irrelevant; only sizes are compared.
  EXPECT_TRUE(tensors_have_same_shape(f32_2x3, i32_2x3));
  EXPECT_FALSE(tensors_have_same_shape(f32_2x3, f64_3x2));
  EXPECT_TRUE(tensors_have_same_shape(f64_3x2, bool_3x2));
  EXPECT_TRUE(tensors_have_same_shape(f32_2x3, i32_2x3, u8_2x3));
  EXPECT_FALSE(tensors_have_same_shape(f32_2x3, i32_2x3, f64_3x2));
  EXPECT_FALSE(tensors_have_same_shape(f32_2x3, f64_3x2, bool_3x2));

  // Single-element tensors of differing rank are expected to count as the
  // same shape.
  Tensor single_rank2 = tf_float_.ones({1, 1});
  Tensor single_rank1 = tf_double_.ones({1});
  Tensor single_rank4 = tf_int_.ones({1, 1, 1, 1});

  EXPECT_TRUE(tensors_have_same_shape(single_rank2, single_rank1));
  EXPECT_TRUE(
      tensors_have_same_shape(single_rank2, single_rank1, single_rank4));
}

TEST_F(TensorUtilTest, TensorsHaveSameShapeAndDtypeTest) {
  using executorch::runtime::tensors_have_same_shape_and_dtype;
  // Three identical 2x3 Float tensors, a 2x3 Double, and a 3x2 Float.
  Tensor f32_a = tf_float_.ones({2, 3});
  Tensor f32_b = tf_float_.ones({2, 3});
  Tensor f32_c = tf_float_.ones({2, 3});
  Tensor f64_same_shape = tf_double_.ones({2, 3});
  Tensor f32_transposed = tf_float_.ones({3, 2});

  // Both the dtype and the shape have to agree.
  EXPECT_TRUE(tensors_have_same_shape_and_dtype(f32_a, f32_b));
  EXPECT_FALSE(tensors_have_same_shape_and_dtype(f32_a, f64_same_shape));
  EXPECT_TRUE(tensors_have_same_shape_and_dtype(f32_a, f32_b, f32_c));
  EXPECT_FALSE(
      tensors_have_same_shape_and_dtype(f32_a, f32_b, f64_same_shape));
  EXPECT_FALSE(tensors_have_same_shape_and_dtype(
      f32_a, f64_same_shape, f32_transposed));

  // Single-element Float tensors of differing rank are expected to match.
  Tensor single_rank2 = tf_float_.ones({1, 1});
  Tensor single_rank1 = tf_float_.ones({1});
  Tensor single_rank4 = tf_float_.ones({1, 1, 1, 1});

  EXPECT_TRUE(tensors_have_same_shape_and_dtype(single_rank2, single_rank1));
  EXPECT_TRUE(tensors_have_same_shape_and_dtype(
      single_rank2, single_rank1, single_rank4));
}

TEST_F(TensorUtilTest, TensorsHaveSameStridesTest) {
  using executorch::runtime::tensors_have_same_strides;
  // Channels-last Float tensors with identical sizes; only fill values differ.
  Tensor channels_last_1 = tf_float_.full_channels_last({4, 5, 2, 3}, 1);
  Tensor channels_last_2 = tf_float_.full_channels_last({4, 5, 2, 3}, 2);
  Tensor channels_last_3 = tf_float_.full_channels_last({4, 5, 2, 3}, 3);
  // The same sizes in the factory's default layout, as Double and as Float.
  Tensor default_double = tf_double_.ones({4, 5, 2, 3});
  Tensor default_float = tf_float_.ones({4, 5, 2, 3});

  EXPECT_TRUE(tensors_have_same_strides(channels_last_1, channels_last_2));
  EXPECT_FALSE(tensors_have_same_strides(channels_last_1, default_double));
  EXPECT_TRUE(tensors_have_same_strides(
      channels_last_1, channels_last_2, channels_last_3));
  EXPECT_FALSE(tensors_have_same_strides(
      channels_last_1, channels_last_2, default_double));
  EXPECT_FALSE(tensors_have_same_strides(
      channels_last_1, default_double, default_float));
}

TEST_F(TensorUtilTest, TensorIsContiguous) {
  using executorch::runtime::tensor_is_contiguous;
  // Multi-element channels-last layout: not contiguous.
  Tensor a = tf_float_.full_channels_last({4, 5, 2, 3}, 1);
  // Default layout produced by ones(): contiguous.
  Tensor b = tf_float_.ones({4, 5, 2, 3});
  // Channels-last request, but every size is 1: still reported contiguous.
  Tensor c = tf_float_.full_channels_last({1, 1, 1, 1}, 1);
  // Zero-dimensional tensor (no per-dimension strides at all): contiguous.
  Tensor d = tf_float_.ones({});

  EXPECT_FALSE(tensor_is_contiguous(a));
  EXPECT_TRUE(tensor_is_contiguous(b));
  EXPECT_TRUE(tensor_is_contiguous(c));
  EXPECT_TRUE(tensor_is_contiguous(d));
}

TEST_F(TensorUtilTest, ResizeZeroDimTensor) {
  // Resizing a zero-dimensional tensor to an empty size list succeeds and
  // leaves it zero-dimensional.
  Tensor scalar = tf_float_.ones({});
  const auto status = executorch::runtime::resize_tensor(scalar, {});
  EXPECT_EQ(status, executorch::runtime::Error::Ok);
  EXPECT_EQ(scalar.dim(), 0);
}

TEST_F(TensorUtilTest, SameDimOrderContiguous) {
  using namespace torch::executor;
  // Byte, Int and Float tensors with identical sizes, all in the default dim
  // order [0, 1, 2, 3]; dtypes and contents differ.
  std::vector<int32_t> sizes = {3, 5, 2, 1};
  Tensor ones_u8 = tf_byte_.ones(sizes);
  Tensor zeros_i32 = tf_int_.zeros(sizes);
  Tensor tenths_f32 = tf_float_.full(sizes, 0.1);

  // Every argument ordering reports a match.
  EXPECT_TRUE(tensors_have_same_dim_order(ones_u8, zeros_i32));
  EXPECT_TRUE(tensors_have_same_dim_order(zeros_i32, ones_u8));
  EXPECT_TRUE(tensors_have_same_dim_order(ones_u8, zeros_i32, tenths_f32));
  EXPECT_TRUE(tensors_have_same_dim_order(zeros_i32, tenths_f32, ones_u8));
  EXPECT_TRUE(tensors_have_same_dim_order(tenths_f32, ones_u8, zeros_i32));
}

TEST_F(TensorUtilTest, SameDimOrderChannelsLast) {
  using namespace torch::executor;
  // Byte, Int and Float tensors with identical sizes, all built channels-last
  // (dim order [0, 2, 3, 1]); dtypes and contents differ.
  std::vector<int32_t> sizes = {3, 5, 2, 1};
  Tensor cl_u8 = tf_byte_.full_channels_last(sizes, 1);
  Tensor cl_i32 = tf_int_.full_channels_last(sizes, 0);
  Tensor cl_f32 = tf_float_.full_channels_last(sizes, 0.1);

  // Every argument ordering reports a match.
  EXPECT_TRUE(tensors_have_same_dim_order(cl_u8, cl_i32));
  EXPECT_TRUE(tensors_have_same_dim_order(cl_i32, cl_u8));
  EXPECT_TRUE(tensors_have_same_dim_order(cl_u8, cl_i32, cl_f32));
  EXPECT_TRUE(tensors_have_same_dim_order(cl_i32, cl_f32, cl_u8));
  EXPECT_TRUE(tensors_have_same_dim_order(cl_f32, cl_u8, cl_i32));
}

TEST_F(TensorUtilTest, SameShapesDifferentDimOrder) {
  using namespace torch::executor;
  // Three different tensors with the same shape but different dtypes and
  // contents, where b and c have the same dim order ([0, 2, 3, 1]) while a is
  // different ([0, 1, 2, 3]).
  std::vector<int32_t> sizes = {3, 5, 2, 1};
  Tensor a = tf_byte_.ones(sizes);
  Tensor b = tf_int_.full_channels_last(sizes, 0);
  Tensor c = tf_float_.full_channels_last(sizes, 0.1);

  // Not the same dim order. Check both argument positions.
  EXPECT_FALSE(tensors_have_same_dim_order(a, b));
  EXPECT_FALSE(tensors_have_same_dim_order(b, a));

  // Put the mismatching tensor `a` in each of the three positions (first,
  // middle, last) while the other two arguments agree with each other.
  EXPECT_FALSE(tensors_have_same_dim_order(a, b, c));
  EXPECT_FALSE(tensors_have_same_dim_order(b, a, c));
  EXPECT_FALSE(tensors_have_same_dim_order(c, b, a));
  // Also keep `a` first with the agreeing pair swapped.
  EXPECT_FALSE(tensors_have_same_dim_order(a, c, b));
}
