1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15
16 #include <gtest/gtest.h>
17 #include <cmath>
18
19 using namespace ::testing;
20 using exec_aten::Scalar;
21 using exec_aten::ScalarType;
22 using exec_aten::Tensor;
23 using torch::executor::testing::TensorFactory;
24
25 class OpScatterSrcOutTest : public OperatorTest {
26 protected:
op_scatter_src_out(const Tensor & self,int64_t dim,const Tensor & index,const Tensor & src,Tensor & out)27 Tensor& op_scatter_src_out(
28 const Tensor& self,
29 int64_t dim,
30 const Tensor& index,
31 const Tensor& src,
32 Tensor& out) {
33 return torch::executor::aten::scatter_outf(
34 context_, self, dim, index, src, out);
35 }
36
37 // Common testing for the operator
38 template <ScalarType DATA_DTYPE>
test_scatter_src_out()39 void test_scatter_src_out() {
40 TensorFactory<ScalarType::Long> tf_index;
41 TensorFactory<DATA_DTYPE> tf_data;
42 const std::vector<int32_t> sizes = {3, 5};
43 // clang-format off
44 Tensor src = tf_data.make(
45 /*sizes=*/{2, 5},
46 {
47 1, 2, 3, 4, 5,
48 6, 7, 8, 9, 10
49 });
50 // clang-format on
51 Tensor in = tf_data.zeros(sizes);
52 Tensor out = tf_data.zeros(sizes);
53 // clang-format off
54 Tensor index = tf_index.make(
55 /*sizes=*/{2, 3},
56 {
57 0, 1, 2,
58 0, 1, 2
59 });
60 // clang-format on
61
62 // Valid input should give the expected output
63 op_scatter_src_out(in, 0, index, src, out);
64 // clang-format off
65 EXPECT_TENSOR_EQ(
66 out, tf_data.make(
67 sizes,
68 {
69 6, 0, 0, 0, 0,
70 0, 7, 0, 0, 0,
71 0, 0, 8, 0, 0
72 }));
73 // clang-format on
74
75 // Valid input should give the expected output
76 op_scatter_src_out(in, 1, index, src, out);
77 // clang-format off
78 EXPECT_TENSOR_EQ(
79 out, tf_data.make(sizes,
80 {
81 1, 2, 3, 0, 0,
82 6, 7, 8, 0, 0,
83 0, 0, 0, 0, 0
84 }));
85
86 src = tf_data.make(
87 /*sizes=*/{2, 3, 3},
88 {
89 // [0, :, :]
90 1, 2, 3,
91 4, 5, 6,
92 7, 8, 9,
93
94 // [1, :, :]
95 10, 11, 12,
96 13, 14, 15,
97 16, 17, 18
98 });
99 // clang-format on
100 in = tf_data.ones(/*sizes=*/{2, 3, 3});
101 out = tf_data.zeros(/*sizes=*/{2, 3, 3});
102 // clang-format off
103 index = tf_index.make(
104 /*sizes=*/{1, 3, 2},
105 {
106 0, 1,
107 1, 2,
108 0, 2
109 });
110 // clang-format on
111
112 op_scatter_src_out(in, 1, index, src, out);
113 // clang-format off
114 EXPECT_TENSOR_EQ(
115 out,
116 tf_data.make(
117 /*sizes=*/{2, 3, 3},
118 {
119 // [0, :, :]
120 7, 1, 1,
121 4, 2, 1,
122 1, 8, 1,
123
124 // [1, :, :]
125 1, 1, 1,
126 1, 1, 1,
127 1, 1, 1
128 }));
129 // clang-format on
130
131 out = tf_data.zeros(/*sizes=*/{2, 3, 3});
132 op_scatter_src_out(in, 2, index, src, out);
133 // clang-format off
134 EXPECT_TENSOR_EQ(
135 out,
136 tf_data.make(
137 /*sizes=*/{2, 3, 3},
138 {
139 // [0, :, :]
140 1, 2, 1,
141 1, 4, 5,
142 7, 1, 8,
143
144 // [1, :, :]
145 1, 1, 1,
146 1, 1, 1,
147 1, 1, 1
148 }));
149 // clang-format on
150 }
151
152 // Invalid dimensions
153 template <ScalarType DATA_DTYPE>
test_scatter_src_out_invalid_dim()154 void test_scatter_src_out_invalid_dim() {
155 TensorFactory<ScalarType::Long> tf_index;
156 TensorFactory<DATA_DTYPE> tf_data;
157 const std::vector<int32_t> sizes = {3, 5};
158 // clang-format off
159 Tensor src = tf_data.make(/*sizes=*/{2, 5},
160 {
161 1, 2, 3, 4, 5,
162 6, 7, 8, 9, 10
163 });
164 Tensor index = tf_index.make(/*sizes=*/{2, 3},
165 {
166 0, 1, 2,
167 0, 1, 2
168 });
169 // clang-format on
170 Tensor self = tf_data.zeros(sizes);
171 Tensor out = tf_data.zeros(sizes);
172
173 // Invalid dim should die
174 ET_EXPECT_KERNEL_FAILURE(
175 context_, op_scatter_src_out(self, -3, index, src, out));
176 ET_EXPECT_KERNEL_FAILURE(
177 context_, op_scatter_src_out(self, 2, index, src, out));
178
179 // Self, index and src hsould have same number of dimensions
180 src = tf_data.zeros(/*sizes=*/{2, 2, 2});
181 ET_EXPECT_KERNEL_FAILURE(
182 context_, op_scatter_src_out(self, 0, index, src, out));
183
184 src = tf_data.zeros(/*sizes=*/{5, 5});
185 index = tf_index.zeros(/*sizes=*/{2, 2, 2});
186 ET_EXPECT_KERNEL_FAILURE(
187 context_, op_scatter_src_out(self, 0, index, src, out));
188
189 // Size of dimension of index should be smaller than the size of that
190 // dimension of src
191 index = tf_index.zeros(/*sizes=*/{4, 6});
192 ET_EXPECT_KERNEL_FAILURE(
193 context_, op_scatter_src_out(self, 0, index, src, out));
194
195 // Size of dimension of index should be smaller than the size of that
196 // dimension of self if dimension != dim
197 index = tf_index.zeros(/*sizes=*/{4, 5});
198 ET_EXPECT_KERNEL_FAILURE(
199 context_, op_scatter_src_out(self, 1, index, src, out));
200
201 // Index out of bound for self in dim
202 index = tf_index.make(/*sizes=*/{2, 3}, {0, 1, 3, 0, 1, 3});
203 ET_EXPECT_KERNEL_FAILURE(
204 context_, op_scatter_src_out(self, 0, index, src, out));
205 }
206 };
207
208 class OpScatterValueOutTest : public OperatorTest {
209 protected:
op_scatter_value_out(const Tensor & self,int64_t dim,const Tensor & index,const Scalar & value,Tensor & out)210 Tensor& op_scatter_value_out(
211 const Tensor& self,
212 int64_t dim,
213 const Tensor& index,
214 const Scalar& value,
215 Tensor& out) {
216 return torch::executor::aten::scatter_outf(
217 context_, self, dim, index, value, out);
218 }
219
220 // Common testing for the operator
221 template <ScalarType DATA_DTYPE>
test_scatter_value_out()222 void test_scatter_value_out() {
223 TensorFactory<ScalarType::Long> tf_index;
224 TensorFactory<DATA_DTYPE> tf_data;
225
226 const Scalar& value = 1;
227
228 const std::vector<int32_t> sizes = {3, 5};
229 Tensor self = tf_data.zeros(sizes);
230 Tensor out = tf_data.zeros(sizes);
231 Tensor index = tf_index.make({2, 3}, {0, 1, 2, 0, 1, 2});
232
233 op_scatter_value_out(self, 0, index, value, out);
234 // clang-format off
235 EXPECT_TENSOR_EQ(
236 out, tf_data.make(
237 sizes,
238 {
239 1, 0, 0, 0, 0,
240 0, 1, 0, 0, 0,
241 0, 0, 1, 0, 0
242 }));
243 // clang-format on
244
245 op_scatter_value_out(self, 1, index, value, out);
246 // clang-format off
247 EXPECT_TENSOR_EQ(
248 out, tf_data.make(sizes,
249 {
250 1, 1, 1, 0, 0,
251 1, 1, 1, 0, 0,
252 0, 0, 0, 0, 0
253 }));
254
255 const Scalar& value2 = 2;
256 self = tf_data.ones(/*sizes=*/{2, 3, 3});
257 out = tf_data.zeros(/*sizes=*/{2, 3, 3});
258 // clang-format off
259 index = tf_index.make(
260 /*sizes=*/{1, 3, 2},
261 {
262 0, 1,
263 1, 2,
264 0, 2
265 });
266 // clang-format on
267
268 op_scatter_value_out(self, 1, index, value2, out);
269 // clang-format off
270 EXPECT_TENSOR_EQ(
271 out,
272 tf_data.make(
273 /*sizes=*/{2, 3, 3},
274 {
275 // [0, :, :]
276 2, 1, 1,
277 2, 2, 1,
278 1, 2, 1,
279
280 // [1, :, :]
281 1, 1, 1,
282 1, 1, 1,
283 1, 1, 1
284 }));
285 // clang-format on
286
287 out = tf_data.zeros(/*sizes=*/{2, 3, 3});
288 op_scatter_value_out(self, 2, index, value2, out);
289 // clang-format off
290 EXPECT_TENSOR_EQ(
291 out,
292 tf_data.make(
293 /*sizes=*/{2, 3, 3},
294 {
295 // [0, :, :]
296 2, 2, 1,
297 1, 2, 2,
298 2, 1, 2,
299
300 // [1, :, :]
301 1, 1, 1,
302 1, 1, 1,
303 1, 1, 1
304 }));
305 // clang-format on
306 }
307
308 // Invalid dimensions
309 template <ScalarType DATA_DTYPE>
test_scatter_value_out_invalid_dim()310 void test_scatter_value_out_invalid_dim() {
311 TensorFactory<ScalarType::Long> tf_index;
312 TensorFactory<DATA_DTYPE> tf_data;
313 // clang-format off
314 Tensor self = tf_data.make(/*sizes=*/{2, 5},
315 {
316 1, 2, 3, 4, 5,
317 6, 7, 8, 9, 10
318 });
319 const std::vector<int32_t> sizes = {2, 3};
320 Tensor index = tf_index.make(sizes,
321 {
322 0, 1, 0,
323 1, 0, 1,
324 });
325 // clang-format on
326 const Scalar& value = 1;
327 Tensor out = tf_data.zeros(sizes);
328
329 // Invalid dim should die
330 ET_EXPECT_KERNEL_FAILURE(
331 context_, op_scatter_value_out(self, -3, index, value, out));
332 ET_EXPECT_KERNEL_FAILURE(
333 context_, op_scatter_value_out(self, 2, index, value, out));
334
335 // Self and index hsould have same number of dimensions
336 index = tf_index.zeros(/*sizes=*/{2, 2, 2});
337 ET_EXPECT_KERNEL_FAILURE(
338 context_, op_scatter_value_out(self, 0, index, value, out));
339
340 // Size of dimension of index should be smaller than the size of that
341 // dimension of self if dimension != dim
342 index = tf_index.zeros(/*sizes=*/{3, 5});
343 ET_EXPECT_KERNEL_FAILURE(
344 context_, op_scatter_value_out(self, 1, index, value, out));
345
346 // Index out of bound for self in dim
347 index = tf_index.make(/*sizes=*/{2, 3}, {0, 1, 2, 0, 1, 2});
348 ET_EXPECT_KERNEL_FAILURE(
349 context_, op_scatter_value_out(self, 0, index, value, out));
350 }
351
test_dynamic_shape(const std::vector<int32_t> & out_shape,enum torch::executor::TensorShapeDynamism dynamism)352 void test_dynamic_shape(
353 const std::vector<int32_t>& out_shape,
354 enum torch::executor::TensorShapeDynamism dynamism) {
355 TensorFactory<ScalarType::Int> tf;
356 TensorFactory<ScalarType::Long> tf_index;
357
358 Tensor input = tf.ones({2, 3, 4});
359 Tensor index = tf_index.zeros({2, 3, 4});
360 const Scalar& value = 1;
361 Tensor expected = tf.ones({2, 3, 4});
362 Tensor out = tf.zeros(out_shape, dynamism);
363
364 op_scatter_value_out(input, 2, index, value, out);
365 EXPECT_TENSOR_EQ(out, expected);
366 }
367 };
368
// Runs the shared tensor-source happy-path checks once for every
// (ctype, dtype) pair that ET_FORALL_REAL_TYPES expands to.
TEST_F(OpScatterSrcOutTest, AllValidInputOutputSupport) {
#define RUN_SCATTER_SRC_FOR_DTYPE(CTYPE, DTYPE) \
  test_scatter_src_out<ScalarType::DTYPE>();
  ET_FORALL_REAL_TYPES(RUN_SCATTER_SRC_FOR_DTYPE);
#undef RUN_SCATTER_SRC_FOR_DTYPE
}
374
// Runs the shared tensor-source failure-path checks once for every
// (ctype, dtype) pair that ET_FORALL_REAL_TYPES expands to.
TEST_F(OpScatterSrcOutTest, InvalidDimensionsDies) {
#define RUN_SCATTER_SRC_INVALID_FOR_DTYPE(CTYPE, DTYPE) \
  test_scatter_src_out_invalid_dim<ScalarType::DTYPE>();
  ET_FORALL_REAL_TYPES(RUN_SCATTER_SRC_INVALID_FOR_DTYPE);
#undef RUN_SCATTER_SRC_INVALID_FOR_DTYPE
}
381
// Runs the shared scalar-value happy-path checks once for every
// (ctype, dtype) pair that ET_FORALL_REAL_TYPES expands to.
TEST_F(OpScatterValueOutTest, AllValidInputOutputSupport) {
#define RUN_SCATTER_VALUE_FOR_DTYPE(CTYPE, DTYPE) \
  test_scatter_value_out<ScalarType::DTYPE>();
  ET_FORALL_REAL_TYPES(RUN_SCATTER_VALUE_FOR_DTYPE);
#undef RUN_SCATTER_VALUE_FOR_DTYPE
}
387
// Scatters +infinity into a Float tensor that already contains NaN and
// infinite entries, and checks that untouched cells keep their special values.
TEST_F(OpScatterValueOutTest, InfinityAndNANTest) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  const Scalar fill_value = INFINITY;
  Tensor index = index_factory.make({2, 3}, {0, 1, 0, 1, 0, 1});
  // clang-format off
  Tensor self = float_factory.make(
      /*sizes=*/{2, 5},
      {
        0.0, -INFINITY,       NAN,      2.33, NAN,
        NAN,  INFINITY, -INFINITY, -INFINITY, 2.33
      });
  // Columns 0..2 are hit in both rows by the index; columns 3..4 are not
  // addressed and keep the input values.
  Tensor expected = float_factory.make(
      /*sizes=*/{2, 5},
      {
        INFINITY, INFINITY, INFINITY,      2.33, NAN,
        INFINITY, INFINITY, INFINITY, -INFINITY, 2.33
      });
  // clang-format on
  Tensor out = float_factory.zeros({2, 5});

  op_scatter_value_out(self, 0, index, fill_value, out);
  EXPECT_TENSOR_CLOSE(out, expected);
}
415
// Runs the shared scalar-value failure-path checks once for every
// (ctype, dtype) pair that ET_FORALL_REAL_TYPES expands to.
TEST_F(OpScatterValueOutTest, InvalidDimensionsDies) {
#define RUN_SCATTER_VALUE_INVALID_FOR_DTYPE(CTYPE, DTYPE) \
  test_scatter_value_out_invalid_dim<ScalarType::DTYPE>();
  ET_FORALL_REAL_TYPES(RUN_SCATTER_VALUE_INVALID_FOR_DTYPE);
#undef RUN_SCATTER_VALUE_INVALID_FOR_DTYPE
}
422
// Checks that the kernel rejects (1) an index tensor whose dtype is not Long
// and (2) an output tensor whose dtype differs from `self`.
//
// The output is given the same shape as `self` in both cases so that the only
// thing wrong with each call is the dtype being tested; an output of a
// different shape could make the call fail for an unrelated reason.
TEST_F(OpScatterValueOutTest, MismatchedInputDtypesDies) {
  TensorFactory<ScalarType::Byte> tf_byte;
  TensorFactory<ScalarType::Char> tf_char;
  TensorFactory<ScalarType::Long> tf_long;

  const std::vector<int32_t> self_sizes = {2, 5};
  const std::vector<int32_t> index_sizes = {2, 3};
  const Scalar& value = 5;

  // Types other than long for index should die
  Tensor self = tf_char.make(self_sizes, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
  Tensor index = tf_byte.make(index_sizes, {0, 1, 0, 0, 1, 0});
  Tensor out = tf_char.zeros(self_sizes);
  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_value_out(self, 0, index, value, out));

  // Mismatched dtype of self and out should die
  self = tf_byte.make(self_sizes, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
  index = tf_long.make(index_sizes, {0, 1, 0, 1, 0, 1});
  out = tf_char.zeros(self_sizes);
  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_value_out(self, 0, index, value, out));
}
445
// DYNAMIC_BOUND output allocated with exactly the {2, 3, 4} shape that
// test_dynamic_shape expects as the result.
TEST_F(OpScatterValueOutTest, DynamicShapeUpperBoundSameAsExpected) {
  const std::vector<int32_t> out_shape = {2, 3, 4};
  const auto dynamism = torch::executor::TensorShapeDynamism::DYNAMIC_BOUND;
  test_dynamic_shape(out_shape, dynamism);
}
450
// DYNAMIC_BOUND output allocated as {10, 10, 10}, i.e. larger in every
// dimension than the {2, 3, 4} result that test_dynamic_shape expects.
TEST_F(OpScatterValueOutTest, DynamicShapeUpperBoundLargerThanExpected) {
  const std::vector<int32_t> out_shape = {10, 10, 10};
  const auto dynamism = torch::executor::TensorShapeDynamism::DYNAMIC_BOUND;
  test_dynamic_shape(out_shape, dynamism);
}
455
// DYNAMIC_UNBOUND output allocated as {1, 1, 1}, smaller than the {2, 3, 4}
// result that test_dynamic_shape expects. Skipped when the SupportedFeatures
// singleton reports that output resizing is unavailable.
TEST_F(OpScatterValueOutTest, DynamicShapeUnbound) {
  const bool can_resize_output =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!can_resize_output) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }
  const std::vector<int32_t> out_shape = {1, 1, 1};
  test_dynamic_shape(
      out_shape, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
}
463
// An index tensor with zero elements (shape {2, 0, 3}) is accepted even though
// its rank differs from the 2-D input; the output is expected to equal the
// one-filled input.
TEST_F(OpScatterValueOutTest, EmptyIndex) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  const Scalar fill_value = 5;
  Tensor empty_index = index_factory.zeros({2, 0, 3});
  Tensor self = float_factory.ones({2, 5});
  Tensor out = float_factory.zeros({2, 5});

  op_scatter_value_out(self, 0, empty_index, fill_value, out);

  Tensor expected = float_factory.ones({2, 5});
  EXPECT_TENSOR_CLOSE(out, expected);
}
475
// Zero-dimensional input, index and output: the single element is expected to
// be replaced by the scattered value.
TEST_F(OpScatterValueOutTest, ValidZeroDim) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  const Scalar fill_value = 5;
  Tensor scalar_index = index_factory.zeros({});
  Tensor self = float_factory.make({}, {3.14});
  Tensor out = float_factory.zeros({});

  op_scatter_value_out(self, 0, scalar_index, fill_value, out);

  Tensor expected = float_factory.make({}, {5});
  EXPECT_TENSOR_CLOSE(out, expected);
}
487
// A zero-dimensional input combined with a 2-D index is expected to be
// rejected by the kernel.
TEST_F(OpScatterValueOutTest, InvalidZeroDimInput) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  const Scalar fill_value = 5;
  Tensor index_2d = index_factory.make({2, 3}, {0, 0, 0, 0, 0, 0});
  Tensor self = float_factory.ones({});
  Tensor out = float_factory.zeros({});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_value_out(self, 0, index_2d, fill_value, out));
}
499
// A 2-D input combined with a zero-dimensional index is expected to be
// rejected by the kernel.
TEST_F(OpScatterValueOutTest, InvalidZeroDimIndex) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  const Scalar fill_value = 5;
  Tensor scalar_index = index_factory.make({}, {2});
  Tensor self = float_factory.make({2, 3}, {1, 2, 3, 4, 5, 6});
  Tensor out = float_factory.zeros({2, 3});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_value_out(self, 1, scalar_index, fill_value, out));
}
511
// Zero-dimensional input with a 1-D index whose entries are all 0: the call is
// expected to succeed and leave the scattered value in the output.
TEST_F(OpScatterValueOutTest, ValidZeroDimInputAndOneDimIndex) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  const Scalar fill_value = 5;
  Tensor index_1d = index_factory.make({3}, {0, 0, 0});
  Tensor self = float_factory.make({}, {3.14});
  // Starts from a value unrelated to both the input and the expectation.
  Tensor out = float_factory.make({}, {2.71});

  op_scatter_value_out(self, 0, index_1d, fill_value, out);

  Tensor expected = float_factory.make({}, {5});
  EXPECT_TENSOR_CLOSE(out, expected);
}
523
// 1-D input with a zero-dimensional index holding 2: only element 2 is
// expected to be replaced by the scattered value.
TEST_F(OpScatterValueOutTest, ValidOneDimInputAndZeroDimIndex) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  const Scalar fill_value = 5;
  Tensor scalar_index = index_factory.make({}, {2});
  Tensor self = float_factory.make({3}, {10, 20, 30});
  // Starts from values unrelated to both the input and the expectation.
  Tensor out = float_factory.make({3}, {1729, 1729, 1729});

  op_scatter_value_out(self, 0, scalar_index, fill_value, out);

  Tensor expected = float_factory.make({3}, {10, 20, 5});
  EXPECT_TENSOR_CLOSE(out, expected);
}
535
// Zero-dimensional input with a 1-D index holding 10, 100 and 1000: the call
// is expected to be rejected by the kernel.
TEST_F(OpScatterValueOutTest, InvalidZeroDimInputAndOneDimIndex) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  const Scalar fill_value = 5;
  Tensor index_1d = index_factory.make({3}, {10, 100, 1000});
  Tensor self = float_factory.make({}, {3.14});
  Tensor out = float_factory.make({}, {2.71});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_value_out(self, 0, index_1d, fill_value, out));
}
547
// 1-D input of size 3 with a zero-dimensional index holding 100: the call is
// expected to be rejected by the kernel.
TEST_F(OpScatterValueOutTest, InvalidOneDimInputAndZeroDimIndex) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  const Scalar fill_value = 5;
  Tensor scalar_index = index_factory.make({}, {100});
  Tensor self = float_factory.make({3}, {10, 20, 30});
  Tensor out = float_factory.make({3}, {1729, 1729, 1729});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_value_out(self, 0, scalar_index, fill_value, out));
}
559
// An index tensor with zero elements (shape {2, 0, 3}) is accepted even though
// index and src ({1, 1, 4}) are 3-D while the input is 2-D; the output is
// expected to equal the one-filled input.
TEST_F(OpScatterSrcOutTest, EmptyIndex) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  Tensor empty_index = index_factory.zeros({2, 0, 3});
  Tensor src = float_factory.ones({1, 1, 4});
  Tensor self = float_factory.ones({2, 5});
  Tensor out = float_factory.zeros({2, 5});

  op_scatter_src_out(self, 0, empty_index, src, out);

  Tensor expected = float_factory.ones({2, 5});
  EXPECT_TENSOR_CLOSE(out, expected);
}
571
// Zero-dimensional input, index, source and output: the single element is
// expected to be replaced by the source value.
TEST_F(OpScatterSrcOutTest, ValidZeroDim) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  Tensor scalar_index = index_factory.zeros({});
  Tensor src = float_factory.make({}, {5});
  Tensor self = float_factory.make({}, {3.14});
  Tensor out = float_factory.zeros({});

  op_scatter_src_out(self, 0, scalar_index, src, out);

  Tensor expected = float_factory.make({}, {5});
  EXPECT_TENSOR_CLOSE(out, expected);
}
583
// A zero-dimensional input and source combined with a 2-D index is expected to
// be rejected by the kernel.
TEST_F(OpScatterSrcOutTest, InvalidZeroDimInput) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  Tensor index_2d = index_factory.make({2, 3}, {0, 0, 0, 0, 0, 0});
  Tensor src = float_factory.make({}, {5});
  Tensor self = float_factory.ones({});
  Tensor out = float_factory.zeros({});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_src_out(self, 0, index_2d, src, out));
}
595
// A 2-D input combined with a zero-dimensional index and source is expected to
// be rejected by the kernel.
TEST_F(OpScatterSrcOutTest, InvalidZeroDimIndex) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  Tensor scalar_index = index_factory.make({}, {2});
  Tensor src = float_factory.make({}, {5});
  Tensor self = float_factory.make({2, 3}, {1, 2, 3, 4, 5, 6});
  Tensor out = float_factory.zeros({2, 3});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_src_out(self, 1, scalar_index, src, out));
}
607
// Zero-dimensional input with a 1-D index (all zeros) and a 1-D source (all
// fives): the call is expected to succeed and leave 5 in the output.
TEST_F(OpScatterSrcOutTest, ValidZeroDimInputAndOneDimIndex) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  Tensor index_1d = index_factory.make({3}, {0, 0, 0});
  Tensor src = float_factory.make({3}, {5, 5, 5});
  Tensor self = float_factory.make({}, {3.14});
  // Starts from a value unrelated to both the input and the expectation.
  Tensor out = float_factory.make({}, {2.71});

  op_scatter_src_out(self, 0, index_1d, src, out);

  Tensor expected = float_factory.make({}, {5});
  EXPECT_TENSOR_CLOSE(out, expected);
}
619
// 1-D input with a zero-dimensional index holding 2 and a zero-dimensional
// source holding 5: only element 2 is expected to be replaced.
TEST_F(OpScatterSrcOutTest, ValidOneDimInputAndZeroDimIndex) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  Tensor scalar_index = index_factory.make({}, {2});
  Tensor src = float_factory.make({}, {5});
  Tensor self = float_factory.make({3}, {10, 20, 30});
  // Starts from values unrelated to both the input and the expectation.
  Tensor out = float_factory.make({3}, {1729, 1729, 1729});

  op_scatter_src_out(self, 0, scalar_index, src, out);

  Tensor expected = float_factory.make({3}, {10, 20, 5});
  EXPECT_TENSOR_CLOSE(out, expected);
}
631
// Zero-dimensional input and source with a 1-D index holding 10, 100 and 1000:
// the call is expected to be rejected by the kernel.
TEST_F(OpScatterSrcOutTest, InvalidZeroDimInputAndOneDimIndex) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  Tensor index_1d = index_factory.make({3}, {10, 100, 1000});
  Tensor src = float_factory.make({}, {5});
  Tensor self = float_factory.make({}, {3.14});
  Tensor out = float_factory.make({}, {2.71});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_src_out(self, 0, index_1d, src, out));
}
643
// 1-D input of size 3 with a zero-dimensional index holding 100: the call is
// expected to be rejected by the kernel.
TEST_F(OpScatterSrcOutTest, InvalidOneDimInputAndZeroDimIndex) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  Tensor scalar_index = index_factory.make({}, {100});
  Tensor src = float_factory.make({}, {5});
  Tensor self = float_factory.make({3}, {10, 20, 30});
  Tensor out = float_factory.make({3}, {1729, 1729, 1729});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_src_out(self, 0, scalar_index, src, out));
}
655