xref: /aosp_15_r20/external/executorch/kernels/test/op_gather_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 
16 #include <gtest/gtest.h>
17 #include <cmath>
18 
19 using namespace ::testing;
20 using exec_aten::Scalar;
21 using exec_aten::ScalarType;
22 using exec_aten::Tensor;
23 using torch::executor::testing::TensorFactory;
24 
25 class OpGatherOutTest : public OperatorTest {
26  protected:
op_gather_out(const Tensor & self,int64_t dim,const Tensor & index,bool sparse_grad,Tensor & out)27   Tensor& op_gather_out(
28       const Tensor& self,
29       int64_t dim,
30       const Tensor& index,
31       bool sparse_grad,
32       Tensor& out) {
33     return torch::executor::aten::gather_outf(
34         context_, self, dim, index, sparse_grad, out);
35   }
36 
37   // Common testing for the operator
38   template <ScalarType DATA_DTYPE>
test_gather_out()39   void test_gather_out() {
40     TensorFactory<ScalarType::Long> tf_index;
41     TensorFactory<DATA_DTYPE> tf_data;
42     const std::vector<int32_t> sizes = {2, 3};
43     // clang-format off
44     Tensor self = tf_data.make(
45       /*sizes=*/{2, 5},
46       {
47         1, 2, 3, 4, 5,
48         6, 7, 8, 9, 10
49       });
50     // clang-format on
51     Tensor out = tf_data.zeros(sizes);
52     // clang-format off
53     bool sparse_grad = false;
54     Tensor index = tf_index.make(sizes,
55       {
56         0, 1, 0,
57         1, 0, 1,
58       });
59     // clang-format on
60 
61     // Valid input should give the expected output
62     op_gather_out(self, 0, index, sparse_grad, out);
63     // clang-format off
64     EXPECT_TENSOR_EQ(
65         out, tf_data.make(
66           sizes,
67           {
68             1, 7, 3,
69             6, 2, 8,
70           }));
71     // clang-format on
72 
73     // Valid input should give the expected output
74     op_gather_out(self, 1, index, sparse_grad, out);
75     // clang-format off
76     EXPECT_TENSOR_EQ(
77         out, tf_data.make(sizes,
78         {
79           1, 2, 1,
80           7, 6, 7,
81         }));
82 
83     self = tf_data.make(
84         /*sizes=*/{2, 3, 3},
85         {
86           // [0, :, :]
87           1,  2,  3,
88           4,  5,  6,
89           7,  8,  9,
90 
91           // [1, :, :]
92           10, 11, 12,
93           13, 14, 15,
94           16, 17, 18
95         });
96     index = tf_index.make(
97       /*sizes=*/{1, 3, 2},
98       {
99         0, 1,
100         1, 2,
101         0, 2
102       });
103     // clang-format on
104     out = tf_data.zeros(/*sizes=*/{1, 3, 2});
105 
106     op_gather_out(self, 1, index, sparse_grad, out);
107     // clang-format off
108     EXPECT_TENSOR_EQ(
109         out,
110         tf_data.make(
111             /*sizes=*/{1, 3, 2},
112             {
113               1, 5,
114               4, 8,
115               1, 8,
116             }));
117     // clang-format on
118 
119     out = tf_data.zeros(/*sizes=*/{1, 3, 2});
120     op_gather_out(self, 2, index, sparse_grad, out);
121     // clang-format off
122     EXPECT_TENSOR_EQ(
123         out,
124         tf_data.make(
125             /*sizes=*/{1, 3, 2},
126             {
127               1, 2,
128               5, 6,
129               7, 9,
130             }));
131     // clang-format on
132   }
133 
134   // Invalid dimensions
135   template <ScalarType DATA_DTYPE>
test_gather_out_invalid_dim()136   void test_gather_out_invalid_dim() {
137     TensorFactory<ScalarType::Long> tf_index;
138     TensorFactory<DATA_DTYPE> tf_data;
139     // clang-format off
140     Tensor self = tf_data.make(/*sizes=*/{2, 5},
141       {
142         1, 2, 3, 4, 5,
143         6, 7, 8, 9, 10
144       });
145     const std::vector<int32_t> sizes = {2, 3};
146     Tensor index = tf_index.make(sizes,
147       {
148         0, 1, 0,
149         1, 0, 1,
150       });
151     // clang-format on
152     bool sparse_grad = false;
153     Tensor out = tf_data.zeros(sizes);
154 
155     // Invalid dim should die
156     ET_EXPECT_KERNEL_FAILURE(
157         context_, op_gather_out(self, -3, index, sparse_grad, out));
158     ET_EXPECT_KERNEL_FAILURE(
159         context_, op_gather_out(self, 2, index, sparse_grad, out));
160 
161     // Self and index hsould have same number of dimensions
162     index = tf_index.zeros(/*sizes=*/{2, 2, 2});
163     ET_EXPECT_KERNEL_FAILURE(
164         context_, op_gather_out(self, 0, index, sparse_grad, out));
165 
166     // Size of dimension of index should be smaller than the size of that
167     // dimension of self if dimension != dim
168     index = tf_index.zeros(/*sizes=*/{3, 5});
169     ET_EXPECT_KERNEL_FAILURE(
170         context_, op_gather_out(self, 1, index, sparse_grad, out));
171 
172     // Index out of bound for self in dim
173     index = tf_index.make(/*sizes=*/{2, 3}, {0, 1, 2, 0, 1, 2});
174     ET_EXPECT_KERNEL_FAILURE(
175         context_, op_gather_out(self, 0, index, sparse_grad, out));
176   }
177 
test_dynamic_shape(const std::vector<int32_t> & out_shape,enum torch::executor::TensorShapeDynamism dynamism)178   void test_dynamic_shape(
179       const std::vector<int32_t>& out_shape,
180       enum torch::executor::TensorShapeDynamism dynamism) {
181     TensorFactory<ScalarType::Int> tf;
182     TensorFactory<ScalarType::Long> tf_index;
183 
184     Tensor input = tf.ones({2, 3, 4});
185     Tensor index = tf_index.zeros({2, 3, 4});
186     bool sparse_grad = false;
187     Tensor expected = tf.ones({2, 3, 4});
188     Tensor out = tf.zeros(out_shape, dynamism);
189 
190     op_gather_out(input, 2, index, sparse_grad, out);
191     EXPECT_TENSOR_EQ(out, expected);
192   }
193 };
194 
// Runs the shared valid-input gather checks once per real dtype.
TEST_F(OpGatherOutTest, AllValidInputOutputSupport) {
#define RUN_VALID_GATHER_CASE(ctype, dtype_name) \
  test_gather_out<ScalarType::dtype_name>();
  ET_FORALL_REAL_TYPES(RUN_VALID_GATHER_CASE);
#undef RUN_VALID_GATHER_CASE
}
200 
// Non-finite values (+inf, -inf, NaN) must be gathered like any other value.
TEST_F(OpGatherOutTest, InfinityAndNANTest) {
  TensorFactory<ScalarType::Float> tf_data;
  TensorFactory<ScalarType::Long> tf_index;

  // clang-format off
  Tensor input = tf_data.make(
      /*sizes=*/{2, 5},
      {
        INFINITY, -INFINITY, NAN,       2.33, 3.14,
        NAN,      INFINITY,  -INFINITY, 3.14, 2.33
      });
  // clang-format on
  const std::vector<int32_t> result_sizes = {2, 3};
  Tensor gather_index = tf_index.make(result_sizes, {0, 1, 0, 1, 0, 1});
  Tensor result = tf_data.zeros(result_sizes);

  // Valid input should give the expected output
  op_gather_out(input, /*dim=*/0, gather_index, /*sparse_grad=*/false, result);

  // clang-format off
  Tensor expected = tf_data.make(
      result_sizes,
      {
        INFINITY, INFINITY,  NAN,
        NAN,      -INFINITY, -INFINITY,
      });
  // clang-format on
  EXPECT_TENSOR_CLOSE(result, expected);
}
229 
// Runs the shared argument-validation gather checks once per real dtype.
TEST_F(OpGatherOutTest, InvalidDimensionsDies) {
#define RUN_INVALID_GATHER_CASE(ctype, dtype_name) \
  test_gather_out_invalid_dim<ScalarType::dtype_name>();
  ET_FORALL_REAL_TYPES(RUN_INVALID_GATHER_CASE);
#undef RUN_INVALID_GATHER_CASE
}
236 
// gather.out must reject a non-Long index and a self/out dtype mismatch.
TEST_F(OpGatherOutTest, MismatchedInputDtypesDies) {
  TensorFactory<ScalarType::Byte> tf_byte;
  TensorFactory<ScalarType::Char> tf_char;
  TensorFactory<ScalarType::Long> tf_long;

  const std::vector<int32_t> out_sizes = {2, 3};
  const bool sparse = false;

  // Types other than long for index should die
  {
    Tensor char_input = tf_char.make({2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
    Tensor byte_index = tf_byte.make(out_sizes, {0, 1, 0, 0, 1, 0});
    Tensor char_out = tf_char.zeros(out_sizes);
    ET_EXPECT_KERNEL_FAILURE(
        context_, op_gather_out(char_input, 0, byte_index, sparse, char_out));
  }

  // Mismatched dtype of self and out should die
  {
    Tensor byte_input =
        tf_byte.make(/*sizes=*/{2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
    Tensor long_index = tf_long.make(out_sizes, {0, 1, 0, 1, 0, 1});
    Tensor char_out = tf_char.zeros(out_sizes);
    ET_EXPECT_KERNEL_FAILURE(
        context_, op_gather_out(byte_input, 0, long_index, sparse, char_out));
  }
}
259 
// Dynamic-bound output whose capacity exactly matches the result shape.
TEST_F(OpGatherOutTest, DynamicShapeUpperBoundSameAsExpected) {
  const auto bounded = torch::executor::TensorShapeDynamism::DYNAMIC_BOUND;
  test_dynamic_shape({2, 3, 4}, bounded);
}
264 
// Dynamic-bound output whose capacity is larger than the result shape.
TEST_F(OpGatherOutTest, DynamicShapeUpperBoundLargerThanExpected) {
  const auto bounded = torch::executor::TensorShapeDynamism::DYNAMIC_BOUND;
  test_dynamic_shape({10, 10, 10}, bounded);
}
269 
// Dynamic-unbound output starting smaller than the result; skipped when the
// kernel build does not support output resizing.
TEST_F(OpGatherOutTest, DynamicShapeUnbound) {
  const bool resize_supported =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!resize_supported) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }
  const auto unbounded = torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND;
  test_dynamic_shape({1, 1, 1}, unbounded);
}
277 
// An index containing a zero-sized dimension produces an empty output of the
// same shape. NOTE(review): the index here is 3-D while `self` is 2-D, and
// the call is still expected to succeed; presumably the rank check is skipped
// for empty indices. Confirm against the kernel.
TEST_F(OpGatherOutTest, EmptyIndex) {
  TensorFactory<ScalarType::Float> tf_data;
  TensorFactory<ScalarType::Long> tf_index;

  const std::vector<int32_t> empty_sizes = {2, 0, 3};
  Tensor input = tf_data.ones({2, 5});
  Tensor empty_index = tf_index.zeros(empty_sizes);
  Tensor result = tf_data.zeros(empty_sizes);

  op_gather_out(input, /*dim=*/0, empty_index, /*sparse_grad=*/false, result);

  Tensor expected = tf_data.zeros(empty_sizes);
  EXPECT_TENSOR_CLOSE(result, expected);
}
290 
// Gathering from a 0-D tensor with a 0-D index of value 0 copies the scalar.
TEST_F(OpGatherOutTest, ValidZeroDim) {
  TensorFactory<ScalarType::Float> tf_data;
  TensorFactory<ScalarType::Long> tf_index;

  Tensor scalar_input = tf_data.make({}, {3.14});
  Tensor scalar_index = tf_index.zeros({});
  Tensor scalar_out = tf_data.zeros({});

  op_gather_out(
      scalar_input, /*dim=*/0, scalar_index, /*sparse_grad=*/false, scalar_out);

  Tensor expected = tf_data.make({}, {3.14});
  EXPECT_TENSOR_CLOSE(scalar_out, expected);
}
302 
// A 0-D input combined with a 2-D index must be rejected.
TEST_F(OpGatherOutTest, InvalidZeroDimInput) {
  TensorFactory<ScalarType::Float> tf_data;
  TensorFactory<ScalarType::Long> tf_index;

  const std::vector<int32_t> index_sizes = {2, 3};
  Tensor scalar_input = tf_data.ones({});
  Tensor zero_index = tf_index.make(index_sizes, {0, 0, 0, 0, 0, 0});
  Tensor result = tf_data.zeros(index_sizes);

  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_gather_out(
          scalar_input, /*dim=*/0, zero_index, /*sparse_grad=*/false, result));
}
315 
// A 0-D index combined with a 2-D input must be rejected, even though the
// index value itself (2) is within range for dim 1.
TEST_F(OpGatherOutTest, InvalidZeroDimIndex) {
  TensorFactory<ScalarType::Float> tf_data;
  TensorFactory<ScalarType::Long> tf_index;

  const std::vector<int32_t> scalar_sizes = {};
  Tensor matrix_input = tf_data.make({2, 3}, {1, 2, 3, 4, 5, 6});
  Tensor scalar_index = tf_index.make(scalar_sizes, {2});
  Tensor result = tf_data.zeros(scalar_sizes);

  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_gather_out(
          matrix_input, /*dim=*/1, scalar_index, /*sparse_grad=*/false, result));
}
328 
// A 0-D input may be gathered with a 1-D index of zeros; each output element
// receives the scalar, overwriting the prior contents of `out`.
TEST_F(OpGatherOutTest, ValidZeroDimInputAndOneDimIndex) {
  TensorFactory<ScalarType::Float> tf_data;
  TensorFactory<ScalarType::Long> tf_index;

  const std::vector<int32_t> index_sizes = {3};
  Tensor scalar_input = tf_data.make({}, {3.14});
  Tensor zero_index = tf_index.make(index_sizes, {0, 0, 0});
  Tensor result = tf_data.make({3}, {2.71, 2.71, 2.71});

  op_gather_out(
      scalar_input, /*dim=*/0, zero_index, /*sparse_grad=*/false, result);

  Tensor expected = tf_data.make({3}, {3.14, 3.14, 3.14});
  EXPECT_TENSOR_CLOSE(result, expected);
}
341 
// A 1-D input may be gathered with a 0-D index; the result is the selected
// element as a 0-D tensor.
TEST_F(OpGatherOutTest, ValidOneDimInputAndZeroDimIndex) {
  TensorFactory<ScalarType::Float> tf_data;
  TensorFactory<ScalarType::Long> tf_index;

  const std::vector<int32_t> scalar_sizes = {};
  Tensor vector_input = tf_data.make({3}, {10, 20, 30});
  Tensor scalar_index = tf_index.make(scalar_sizes, {2});
  Tensor result = tf_data.make(scalar_sizes, {1729});

  op_gather_out(
      vector_input, /*dim=*/0, scalar_index, /*sparse_grad=*/false, result);

  Tensor expected = tf_data.make({}, {30});
  EXPECT_TENSOR_CLOSE(result, expected);
}
354 
// With a 0-D input, any nonzero index value is out of range and must fail.
TEST_F(OpGatherOutTest, InvalidZeroDimInputAndOneDimIndex) {
  TensorFactory<ScalarType::Float> tf_data;
  TensorFactory<ScalarType::Long> tf_index;

  const std::vector<int32_t> index_sizes = {3};
  Tensor scalar_input = tf_data.make({}, {3.14});
  Tensor bad_index = tf_index.make(index_sizes, {10, 100, 1000});
  Tensor result = tf_data.make({3}, {2.71, 2.71, 2.71});

  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_gather_out(
          scalar_input, /*dim=*/0, bad_index, /*sparse_grad=*/false, result));
}
367 
// A 0-D index whose value exceeds the 1-D input's length must fail.
TEST_F(OpGatherOutTest, InvalidOneDimInputAndZeroDimIndex) {
  TensorFactory<ScalarType::Float> tf_data;
  TensorFactory<ScalarType::Long> tf_index;

  const std::vector<int32_t> scalar_sizes = {};
  Tensor vector_input = tf_data.make({3}, {10, 20, 30});
  Tensor bad_index = tf_index.make(scalar_sizes, {100});
  Tensor result = tf_data.make(scalar_sizes, {1729});

  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_gather_out(
          vector_input, /*dim=*/0, bad_index, /*sparse_grad=*/false, result));
}
380