xref: /aosp_15_r20/external/executorch/kernels/test/op_scatter_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 
16 #include <gtest/gtest.h>
17 #include <cmath>
18 
19 using namespace ::testing;
20 using exec_aten::Scalar;
21 using exec_aten::ScalarType;
22 using exec_aten::Tensor;
23 using torch::executor::testing::TensorFactory;
24 
25 class OpScatterSrcOutTest : public OperatorTest {
26  protected:
op_scatter_src_out(const Tensor & self,int64_t dim,const Tensor & index,const Tensor & src,Tensor & out)27   Tensor& op_scatter_src_out(
28       const Tensor& self,
29       int64_t dim,
30       const Tensor& index,
31       const Tensor& src,
32       Tensor& out) {
33     return torch::executor::aten::scatter_outf(
34         context_, self, dim, index, src, out);
35   }
36 
37   // Common testing for the operator
38   template <ScalarType DATA_DTYPE>
test_scatter_src_out()39   void test_scatter_src_out() {
40     TensorFactory<ScalarType::Long> tf_index;
41     TensorFactory<DATA_DTYPE> tf_data;
42     const std::vector<int32_t> sizes = {3, 5};
43     // clang-format off
44     Tensor src = tf_data.make(
45       /*sizes=*/{2, 5},
46       {
47         1, 2, 3, 4, 5,
48         6, 7, 8, 9, 10
49       });
50     // clang-format on
51     Tensor in = tf_data.zeros(sizes);
52     Tensor out = tf_data.zeros(sizes);
53     // clang-format off
54     Tensor index = tf_index.make(
55       /*sizes=*/{2, 3},
56       {
57         0, 1, 2,
58         0, 1, 2
59       });
60     // clang-format on
61 
62     // Valid input should give the expected output
63     op_scatter_src_out(in, 0, index, src, out);
64     // clang-format off
65     EXPECT_TENSOR_EQ(
66         out, tf_data.make(
67           sizes,
68           {
69             6, 0, 0, 0, 0,
70             0, 7, 0, 0, 0,
71             0, 0, 8, 0, 0
72           }));
73     // clang-format on
74 
75     // Valid input should give the expected output
76     op_scatter_src_out(in, 1, index, src, out);
77     // clang-format off
78     EXPECT_TENSOR_EQ(
79         out, tf_data.make(sizes,
80         {
81           1, 2, 3, 0, 0,
82           6, 7, 8, 0, 0,
83           0, 0, 0, 0, 0
84         }));
85 
86     src = tf_data.make(
87         /*sizes=*/{2, 3, 3},
88         {
89           // [0, :, :]
90           1,  2,  3,
91           4,  5,  6,
92           7,  8,  9,
93 
94           // [1, :, :]
95           10, 11, 12,
96           13, 14, 15,
97           16, 17, 18
98         });
99     // clang-format on
100     in = tf_data.ones(/*sizes=*/{2, 3, 3});
101     out = tf_data.zeros(/*sizes=*/{2, 3, 3});
102     // clang-format off
103     index = tf_index.make(
104       /*sizes=*/{1, 3, 2},
105       {
106         0, 1,
107         1, 2,
108         0, 2
109       });
110     // clang-format on
111 
112     op_scatter_src_out(in, 1, index, src, out);
113     // clang-format off
114     EXPECT_TENSOR_EQ(
115         out,
116         tf_data.make(
117             /*sizes=*/{2, 3, 3},
118             {
119               // [0, :, :]
120               7, 1,  1,
121               4, 2,  1,
122               1, 8, 1,
123 
124               // [1, :, :]
125               1, 1,  1,
126               1, 1,  1,
127               1, 1,  1
128             }));
129     // clang-format on
130 
131     out = tf_data.zeros(/*sizes=*/{2, 3, 3});
132     op_scatter_src_out(in, 2, index, src, out);
133     // clang-format off
134     EXPECT_TENSOR_EQ(
135         out,
136         tf_data.make(
137             /*sizes=*/{2, 3, 3},
138             {
139               // [0, :, :]
140               1, 2, 1,
141               1, 4, 5,
142               7, 1, 8,
143 
144               // [1, :, :]
145               1, 1, 1,
146               1, 1, 1,
147               1, 1, 1
148             }));
149     // clang-format on
150   }
151 
152   // Invalid dimensions
153   template <ScalarType DATA_DTYPE>
test_scatter_src_out_invalid_dim()154   void test_scatter_src_out_invalid_dim() {
155     TensorFactory<ScalarType::Long> tf_index;
156     TensorFactory<DATA_DTYPE> tf_data;
157     const std::vector<int32_t> sizes = {3, 5};
158     // clang-format off
159     Tensor src = tf_data.make(/*sizes=*/{2, 5},
160       {
161         1, 2, 3, 4, 5,
162         6, 7, 8, 9, 10
163       });
164     Tensor index = tf_index.make(/*sizes=*/{2, 3},
165       {
166         0, 1, 2,
167         0, 1, 2
168       });
169     // clang-format on
170     Tensor self = tf_data.zeros(sizes);
171     Tensor out = tf_data.zeros(sizes);
172 
173     // Invalid dim should die
174     ET_EXPECT_KERNEL_FAILURE(
175         context_, op_scatter_src_out(self, -3, index, src, out));
176     ET_EXPECT_KERNEL_FAILURE(
177         context_, op_scatter_src_out(self, 2, index, src, out));
178 
179     // Self, index and src hsould have same number of dimensions
180     src = tf_data.zeros(/*sizes=*/{2, 2, 2});
181     ET_EXPECT_KERNEL_FAILURE(
182         context_, op_scatter_src_out(self, 0, index, src, out));
183 
184     src = tf_data.zeros(/*sizes=*/{5, 5});
185     index = tf_index.zeros(/*sizes=*/{2, 2, 2});
186     ET_EXPECT_KERNEL_FAILURE(
187         context_, op_scatter_src_out(self, 0, index, src, out));
188 
189     // Size of dimension of index should be smaller than the size of that
190     // dimension of src
191     index = tf_index.zeros(/*sizes=*/{4, 6});
192     ET_EXPECT_KERNEL_FAILURE(
193         context_, op_scatter_src_out(self, 0, index, src, out));
194 
195     // Size of dimension of index should be smaller than the size of that
196     // dimension of self if dimension != dim
197     index = tf_index.zeros(/*sizes=*/{4, 5});
198     ET_EXPECT_KERNEL_FAILURE(
199         context_, op_scatter_src_out(self, 1, index, src, out));
200 
201     // Index out of bound for self in dim
202     index = tf_index.make(/*sizes=*/{2, 3}, {0, 1, 3, 0, 1, 3});
203     ET_EXPECT_KERNEL_FAILURE(
204         context_, op_scatter_src_out(self, 0, index, src, out));
205   }
206 };
207 
208 class OpScatterValueOutTest : public OperatorTest {
209  protected:
op_scatter_value_out(const Tensor & self,int64_t dim,const Tensor & index,const Scalar & value,Tensor & out)210   Tensor& op_scatter_value_out(
211       const Tensor& self,
212       int64_t dim,
213       const Tensor& index,
214       const Scalar& value,
215       Tensor& out) {
216     return torch::executor::aten::scatter_outf(
217         context_, self, dim, index, value, out);
218   }
219 
220   // Common testing for the operator
221   template <ScalarType DATA_DTYPE>
test_scatter_value_out()222   void test_scatter_value_out() {
223     TensorFactory<ScalarType::Long> tf_index;
224     TensorFactory<DATA_DTYPE> tf_data;
225 
226     const Scalar& value = 1;
227 
228     const std::vector<int32_t> sizes = {3, 5};
229     Tensor self = tf_data.zeros(sizes);
230     Tensor out = tf_data.zeros(sizes);
231     Tensor index = tf_index.make({2, 3}, {0, 1, 2, 0, 1, 2});
232 
233     op_scatter_value_out(self, 0, index, value, out);
234     // clang-format off
235     EXPECT_TENSOR_EQ(
236         out, tf_data.make(
237           sizes,
238           {
239             1, 0, 0,  0, 0,
240             0, 1, 0,  0, 0,
241             0, 0, 1, 0, 0
242           }));
243     // clang-format on
244 
245     op_scatter_value_out(self, 1, index, value, out);
246     // clang-format off
247     EXPECT_TENSOR_EQ(
248         out, tf_data.make(sizes,
249         {
250           1, 1, 1, 0, 0,
251           1, 1, 1, 0, 0,
252           0, 0, 0, 0, 0
253         }));
254 
255     const Scalar& value2 = 2;
256     self = tf_data.ones(/*sizes=*/{2, 3, 3});
257     out = tf_data.zeros(/*sizes=*/{2, 3, 3});
258     // clang-format off
259     index = tf_index.make(
260       /*sizes=*/{1, 3, 2},
261       {
262         0, 1,
263         1, 2,
264         0, 2
265       });
266     // clang-format on
267 
268     op_scatter_value_out(self, 1, index, value2, out);
269     // clang-format off
270     EXPECT_TENSOR_EQ(
271         out,
272         tf_data.make(
273             /*sizes=*/{2, 3, 3},
274             {
275               // [0, :, :]
276               2, 1, 1,
277               2, 2, 1,
278               1, 2, 1,
279 
280               // [1, :, :]
281               1, 1, 1,
282               1, 1, 1,
283               1, 1, 1
284             }));
285     // clang-format on
286 
287     out = tf_data.zeros(/*sizes=*/{2, 3, 3});
288     op_scatter_value_out(self, 2, index, value2, out);
289     // clang-format off
290     EXPECT_TENSOR_EQ(
291         out,
292         tf_data.make(
293             /*sizes=*/{2, 3, 3},
294             {
295               // [0, :, :]
296               2, 2, 1,
297               1, 2, 2,
298               2, 1, 2,
299 
300               // [1, :, :]
301               1, 1, 1,
302               1, 1, 1,
303               1, 1, 1
304             }));
305     // clang-format on
306   }
307 
308   // Invalid dimensions
309   template <ScalarType DATA_DTYPE>
test_scatter_value_out_invalid_dim()310   void test_scatter_value_out_invalid_dim() {
311     TensorFactory<ScalarType::Long> tf_index;
312     TensorFactory<DATA_DTYPE> tf_data;
313     // clang-format off
314     Tensor self = tf_data.make(/*sizes=*/{2, 5},
315       {
316         1, 2, 3, 4, 5,
317         6, 7, 8, 9, 10
318       });
319     const std::vector<int32_t> sizes = {2, 3};
320     Tensor index = tf_index.make(sizes,
321       {
322         0, 1, 0,
323         1, 0, 1,
324       });
325     // clang-format on
326     const Scalar& value = 1;
327     Tensor out = tf_data.zeros(sizes);
328 
329     // Invalid dim should die
330     ET_EXPECT_KERNEL_FAILURE(
331         context_, op_scatter_value_out(self, -3, index, value, out));
332     ET_EXPECT_KERNEL_FAILURE(
333         context_, op_scatter_value_out(self, 2, index, value, out));
334 
335     // Self and index hsould have same number of dimensions
336     index = tf_index.zeros(/*sizes=*/{2, 2, 2});
337     ET_EXPECT_KERNEL_FAILURE(
338         context_, op_scatter_value_out(self, 0, index, value, out));
339 
340     // Size of dimension of index should be smaller than the size of that
341     // dimension of self if dimension != dim
342     index = tf_index.zeros(/*sizes=*/{3, 5});
343     ET_EXPECT_KERNEL_FAILURE(
344         context_, op_scatter_value_out(self, 1, index, value, out));
345 
346     // Index out of bound for self in dim
347     index = tf_index.make(/*sizes=*/{2, 3}, {0, 1, 2, 0, 1, 2});
348     ET_EXPECT_KERNEL_FAILURE(
349         context_, op_scatter_value_out(self, 0, index, value, out));
350   }
351 
test_dynamic_shape(const std::vector<int32_t> & out_shape,enum torch::executor::TensorShapeDynamism dynamism)352   void test_dynamic_shape(
353       const std::vector<int32_t>& out_shape,
354       enum torch::executor::TensorShapeDynamism dynamism) {
355     TensorFactory<ScalarType::Int> tf;
356     TensorFactory<ScalarType::Long> tf_index;
357 
358     Tensor input = tf.ones({2, 3, 4});
359     Tensor index = tf_index.zeros({2, 3, 4});
360     const Scalar& value = 1;
361     Tensor expected = tf.ones({2, 3, 4});
362     Tensor out = tf.zeros(out_shape, dynamism);
363 
364     op_scatter_value_out(input, 2, index, value, out);
365     EXPECT_TENSOR_EQ(out, expected);
366   }
367 };
368 
TEST_F(OpScatterSrcOutTest, AllValidInputOutputSupport) {
  // Instantiate the valid-input scatter.src_out checks once per real dtype.
#define RUN_SCATTER_SRC_CASE(cpp_type, dtype_name) \
  test_scatter_src_out<ScalarType::dtype_name>();
  ET_FORALL_REAL_TYPES(RUN_SCATTER_SRC_CASE);
#undef RUN_SCATTER_SRC_CASE
}
374 
TEST_F(OpScatterSrcOutTest, InvalidDimensionsDies) {
  // Instantiate the invalid-argument scatter.src_out checks once per real
  // dtype.
#define RUN_SCATTER_SRC_INVALID_CASE(cpp_type, dtype_name) \
  test_scatter_src_out_invalid_dim<ScalarType::dtype_name>();
  ET_FORALL_REAL_TYPES(RUN_SCATTER_SRC_INVALID_CASE);
#undef RUN_SCATTER_SRC_INVALID_CASE
}
381 
TEST_F(OpScatterValueOutTest, AllValidInputOutputSupport) {
  // Instantiate the valid-input scatter.value_out checks once per real dtype.
#define RUN_SCATTER_VALUE_CASE(cpp_type, dtype_name) \
  test_scatter_value_out<ScalarType::dtype_name>();
  ET_FORALL_REAL_TYPES(RUN_SCATTER_VALUE_CASE);
#undef RUN_SCATTER_VALUE_CASE
}
387 
TEST_F(OpScatterValueOutTest, InfinityAndNANTest) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  // Scatter +inf along dim 0 into the first three columns of both rows; the
  // untouched NaN / +-inf / finite entries of `self` must be preserved.
  // clang-format off
  Tensor self = tf_float.make(
      /*sizes=*/{2, 5},
      {
        0.0, -INFINITY,        NAN,      2.33, NAN,
        NAN,  INFINITY,  -INFINITY, -INFINITY, 2.33
      });
  const Tensor expected = tf_float.make(
      /*sizes=*/{2, 5},
      {
        INFINITY, INFINITY, INFINITY,      2.33, NAN,
        INFINITY, INFINITY, INFINITY, -INFINITY, 2.33
      });
  // clang-format on
  Tensor index = tf_long.make({2, 3}, {0, 1, 0, 1, 0, 1});
  const Scalar value = INFINITY;
  Tensor out = tf_float.zeros({2, 5});

  op_scatter_value_out(self, 0, index, value, out);
  EXPECT_TENSOR_CLOSE(out, expected);
}
415 
TEST_F(OpScatterValueOutTest, InvalidDimensionsDies) {
  // Instantiate the invalid-argument scatter.value_out checks once per real
  // dtype.
#define RUN_SCATTER_VALUE_INVALID_CASE(cpp_type, dtype_name) \
  test_scatter_value_out_invalid_dim<ScalarType::dtype_name>();
  ET_FORALL_REAL_TYPES(RUN_SCATTER_VALUE_INVALID_CASE);
#undef RUN_SCATTER_VALUE_INVALID_CASE
}
422 
TEST_F(OpScatterValueOutTest, MismatchedInputDtypesDies) {
  TensorFactory<ScalarType::Byte> tf_byte;
  TensorFactory<ScalarType::Char> tf_char;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor self = tf_char.make({2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
  const std::vector<int32_t> sizes = {2, 3};
  Tensor index = tf_byte.make(sizes, {0, 1, 0, 0, 1, 0});
  const Scalar& value = 5;
  // `out` always has self's shape ({2, 5}) so that each expected failure is
  // caused by the dtype under test rather than by an out-shape mismatch.
  Tensor out = tf_char.zeros({2, 5});

  // Types other than long for index should die
  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_value_out(self, 0, index, value, out));

  // Mismatched dtype of self and out should die
  self = tf_byte.make(/*sizes=*/{2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
  index = tf_long.make(sizes, {0, 1, 0, 1, 0, 1});
  out = tf_char.zeros({2, 5});
  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_value_out(self, 0, index, value, out));
}
445 
TEST_F(OpScatterValueOutTest, DynamicShapeUpperBoundSameAsExpected) {
  // The bounded output's capacity matches the expected {2, 3, 4} exactly.
  const std::vector<int32_t> bound_shape = {2, 3, 4};
  test_dynamic_shape(
      bound_shape, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
}
450 
TEST_F(OpScatterValueOutTest, DynamicShapeUpperBoundLargerThanExpected) {
  // The bounded output starts larger than needed and must shrink to fit.
  const std::vector<int32_t> bound_shape = {10, 10, 10};
  test_dynamic_shape(
      bound_shape, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
}
455 
TEST_F(OpScatterValueOutTest, DynamicShapeUnbound) {
  // Unbounded outputs need the kernel runtime to support output resizing.
  const bool can_resize_output =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!can_resize_output) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }

  const std::vector<int32_t> initial_shape = {1, 1, 1};
  test_dynamic_shape(
      initial_shape, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
}
463 
TEST_F(OpScatterValueOutTest, EmptyIndex) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  // An index with zero elements scatters nothing: out is a plain copy of self.
  Tensor self = tf_float.ones({2, 5});
  Tensor index = tf_long.zeros({2, 0, 3});
  const Scalar value = 5;
  Tensor out = tf_float.zeros({2, 5});
  const Tensor expected = tf_float.ones({2, 5});

  op_scatter_value_out(self, 0, index, value, out);
  EXPECT_TENSOR_CLOSE(out, expected);
}
475 
TEST_F(OpScatterValueOutTest, ValidZeroDim) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  // 0-dim self and 0-dim index 0: the single element is replaced by value.
  Tensor self = tf_float.make({}, {3.14});
  Tensor index = tf_long.zeros({});
  const Scalar value = 5;
  Tensor out = tf_float.zeros({});
  const Tensor expected = tf_float.make({}, {5});

  op_scatter_value_out(self, 0, index, value, out);
  EXPECT_TENSOR_CLOSE(out, expected);
}
487 
TEST_F(OpScatterValueOutTest, InvalidZeroDimInput) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  // A 2-D index cannot be paired with a 0-dim self.
  Tensor index = tf_long.make({2, 3}, {0, 0, 0, 0, 0, 0});
  Tensor self = tf_float.ones({});
  Tensor out = tf_float.zeros({});
  const Scalar value = 5;

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_value_out(self, 0, index, value, out));
}
499 
TEST_F(OpScatterValueOutTest, InvalidZeroDimIndex) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  // A 0-dim index cannot be paired with a 2-D self.
  Tensor index = tf_long.make({}, {2});
  Tensor self = tf_float.make({2, 3}, {1, 2, 3, 4, 5, 6});
  Tensor out = tf_float.zeros({2, 3});
  const Scalar value = 5;

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_value_out(self, 1, index, value, out));
}
511 
TEST_F(OpScatterValueOutTest, ValidZeroDimInputAndOneDimIndex) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  // A 1-D index of zeros may address a 0-dim self; every write hits the
  // single element.
  Tensor self = tf_float.make({}, {3.14});
  Tensor index = tf_long.make({3}, {0, 0, 0});
  const Scalar value = 5;
  Tensor out = tf_float.make({}, {2.71});
  const Tensor expected = tf_float.make({}, {5});

  op_scatter_value_out(self, 0, index, value, out);
  EXPECT_TENSOR_CLOSE(out, expected);
}
523 
TEST_F(OpScatterValueOutTest, ValidOneDimInputAndZeroDimIndex) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  // A 0-dim index selects one position of a 1-D self; only element 2 changes.
  Tensor self = tf_float.make({3}, {10, 20, 30});
  Tensor index = tf_long.make({}, {2});
  const Scalar value = 5;
  Tensor out = tf_float.make({3}, {1729, 1729, 1729});
  const Tensor expected = tf_float.make({3}, {10, 20, 5});

  op_scatter_value_out(self, 0, index, value, out);
  EXPECT_TENSOR_CLOSE(out, expected);
}
535 
TEST_F(OpScatterValueOutTest, InvalidZeroDimInputAndOneDimIndex) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  // Index values other than 0 are out of range for a 0-dim self.
  Tensor index = tf_long.make({3}, {10, 100, 1000});
  Tensor self = tf_float.make({}, {3.14});
  Tensor out = tf_float.make({}, {2.71});
  const Scalar value = 5;

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_value_out(self, 0, index, value, out));
}
547 
TEST_F(OpScatterValueOutTest, InvalidOneDimInputAndZeroDimIndex) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  // Index 100 is out of range for a self of length 3.
  Tensor index = tf_long.make({}, {100});
  Tensor self = tf_float.make({3}, {10, 20, 30});
  Tensor out = tf_float.make({3}, {1729, 1729, 1729});
  const Scalar value = 5;

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_value_out(self, 0, index, value, out));
}
559 
TEST_F(OpScatterSrcOutTest, EmptyIndex) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  // An index with zero elements scatters nothing: out is a plain copy of self.
  Tensor self = tf_float.ones({2, 5});
  Tensor index = tf_long.zeros({2, 0, 3});
  Tensor src = tf_float.ones({1, 1, 4});
  Tensor out = tf_float.zeros({2, 5});
  const Tensor expected = tf_float.ones({2, 5});

  op_scatter_src_out(self, 0, index, src, out);
  EXPECT_TENSOR_CLOSE(out, expected);
}
571 
TEST_F(OpScatterSrcOutTest, ValidZeroDim) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  // 0-dim self, index and src: the single element is taken from src.
  Tensor self = tf_float.make({}, {3.14});
  Tensor index = tf_long.zeros({});
  Tensor src = tf_float.make({}, {5});
  Tensor out = tf_float.zeros({});
  const Tensor expected = tf_float.make({}, {5});

  op_scatter_src_out(self, 0, index, src, out);
  EXPECT_TENSOR_CLOSE(out, expected);
}
583 
TEST_F(OpScatterSrcOutTest, InvalidZeroDimInput) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  // A 2-D index cannot be paired with a 0-dim self.
  Tensor index = tf_long.make({2, 3}, {0, 0, 0, 0, 0, 0});
  Tensor self = tf_float.ones({});
  Tensor src = tf_float.make({}, {5});
  Tensor out = tf_float.zeros({});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_src_out(self, 0, index, src, out));
}
595 
TEST_F(OpScatterSrcOutTest, InvalidZeroDimIndex) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  // A 0-dim index cannot be paired with a 2-D self.
  Tensor index = tf_long.make({}, {2});
  Tensor self = tf_float.make({2, 3}, {1, 2, 3, 4, 5, 6});
  Tensor src = tf_float.make({}, {5});
  Tensor out = tf_float.zeros({2, 3});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_src_out(self, 1, index, src, out));
}
607 
TEST_F(OpScatterSrcOutTest, ValidZeroDimInputAndOneDimIndex) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  // A 1-D index of zeros may address a 0-dim self; every write hits the
  // single element.
  Tensor self = tf_float.make({}, {3.14});
  Tensor index = tf_long.make({3}, {0, 0, 0});
  Tensor src = tf_float.make({3}, {5, 5, 5});
  Tensor out = tf_float.make({}, {2.71});
  const Tensor expected = tf_float.make({}, {5});

  op_scatter_src_out(self, 0, index, src, out);
  EXPECT_TENSOR_CLOSE(out, expected);
}
619 
TEST_F(OpScatterSrcOutTest, ValidOneDimInputAndZeroDimIndex) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  // A 0-dim index selects one position of a 1-D self; only element 2 changes.
  Tensor self = tf_float.make({3}, {10, 20, 30});
  Tensor index = tf_long.make({}, {2});
  Tensor src = tf_float.make({}, {5});
  Tensor out = tf_float.make({3}, {1729, 1729, 1729});
  const Tensor expected = tf_float.make({3}, {10, 20, 5});

  op_scatter_src_out(self, 0, index, src, out);
  EXPECT_TENSOR_CLOSE(out, expected);
}
631 
TEST_F(OpScatterSrcOutTest, InvalidZeroDimInputAndOneDimIndex) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  // Index values other than 0 are out of range for a 0-dim self.
  Tensor index = tf_long.make({3}, {10, 100, 1000});
  Tensor self = tf_float.make({}, {3.14});
  Tensor src = tf_float.make({}, {5});
  Tensor out = tf_float.make({}, {2.71});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_src_out(self, 0, index, src, out));
}
643 
TEST_F(OpScatterSrcOutTest, InvalidOneDimInputAndZeroDimIndex) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  // Index 100 is out of range for a self of length 3.
  Tensor index = tf_long.make({}, {100});
  Tensor self = tf_float.make({3}, {10, 20, 30});
  Tensor src = tf_float.make({}, {5});
  Tensor out = tf_float.make({3}, {1729, 1729, 1729});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_src_out(self, 0, index, src, out));
}
655