xref: /aosp_15_r20/external/executorch/kernels/test/op_select_scatter_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
16 
17 #include <gtest/gtest.h>
18 #include <sys/types.h>
19 
20 using namespace ::testing;
21 using exec_aten::ArrayRef;
22 using exec_aten::ScalarType;
23 using exec_aten::Tensor;
24 using torch::executor::testing::TensorFactory;
25 
26 class OpSelectScatterOutTest : public OperatorTest {
27  protected:
op_select_scatter_out(const Tensor & self,const Tensor & src,int64_t dim,int64_t index,Tensor & out)28   Tensor& op_select_scatter_out(
29       const Tensor& self,
30       const Tensor& src,
31       int64_t dim,
32       int64_t index,
33       Tensor& out) {
34     return torch::executor::aten::select_scatter_outf(
35         context_, self, src, dim, index, out);
36   }
37 
38   template <class CTYPE, exec_aten::ScalarType DTYPE>
test_dtype()39   void test_dtype() {
40     TensorFactory<DTYPE> tf;
41 
42     // Using the following tensors, inserting a tensor of either ones or zeros
43     // into the appropriate selected slice should result in a tensor of all ones
44     // or all zeros.
45 
46     // clang-format off
47     Tensor x = tf.make(
48         {3, 2, 4},
49         {
50           // all ones below are from x,
51           // and all zeros are from y.
52           // [0, :, :]
53           1, 1, 1, 1, // [0, 0, :]
54           0, 0, 0, 0, // [0, 1, :]
55 
56           // [1, :, :]
57           1, 1, 1, 1, // [1, 0, :]
58           0, 0, 0, 0, // [1, 1, :]
59 
60           // [2, :, :]
61           1, 1, 1, 1, // [2, 0, :]
62           0, 0, 0, 0, // [2, 1, :]
63         });
64     // clang-format on
65 
66     // clang-format off
67     Tensor src_ones = tf.make(
68         {3, 4},
69         {
70             // [:, :]
71             1,  1,  1,  1, // [0, :]
72             1,  1,  1,  1, // [1, :]
73             1,  1,  1,  1, // [2, :]
74         });
75     // clang-format on
76 
77     // clang-format off
78     Tensor src_zeros = tf.make(
79         {3, 4},
80         {
81             // [:, :]
82             0,  0,  0,  0, // [0, :]
83             0,  0,  0,  0, // [1, :]
84             0,  0,  0,  0, // [2, :]
85         });
86     // clang-format on
87 
88     // Expected outs should be all ones or all zeros depending on which src
89     // tensor is used.
90 
91     Tensor out_0 = tf.zeros({3, 2, 4});
92     Tensor out_1 = tf.ones({3, 2, 4});
93     Tensor ret_0 =
94         op_select_scatter_out(x, src_zeros, /*dim=*/1, /*index=*/0, out_0);
95     Tensor ret_1 =
96         op_select_scatter_out(x, src_ones, /*dim=*/1, /*index=*/1, out_1);
97 
98     EXPECT_TENSOR_EQ(ret_0, out_0);
99     EXPECT_TENSOR_EQ(ret_1, out_1);
100 
101     EXPECT_TENSOR_EQ(ret_0, tf.zeros({3, 2, 4}));
102     EXPECT_TENSOR_EQ(ret_1, tf.ones({3, 2, 4}));
103   }
104 
105   // Run the test by selecting Tensor x on given dim and all available indexes
106   // on that dimension
run_test_cases(const Tensor & x,const Tensor & src,ssize_t dim,const std::vector<Tensor> & expected)107   void run_test_cases(
108       const Tensor& x,
109       const Tensor& src,
110       ssize_t dim,
111       const std::vector<Tensor>& expected) {
112     // Generated out tensor sharing same size and dtype with expected tensor
113     TensorFactory<ScalarType::Double> tf;
114 
115     const std::vector<int32_t> out_size(
116         expected[0].sizes().begin(), expected[0].sizes().end());
117     Tensor out = tf.zeros(out_size);
118 
119     for (ssize_t idx = 0; idx < x.size(dim); idx++) {
120       // Should always return the provided out Tensor.
121       // The ret shall meet the expectation.
122       Tensor ret = op_select_scatter_out(x, src, dim, idx, out);
123       EXPECT_TENSOR_EQ(out, ret);
124       EXPECT_TENSOR_EQ(out, expected[idx]);
125 
126       ret =
127           op_select_scatter_out(x, src, dim, /*index=*/idx - x.size(dim), out);
128       EXPECT_TENSOR_EQ(out, ret);
129       EXPECT_TENSOR_EQ(out, expected[idx]);
130     }
131   }
132 
133   /* %python
134   import torch
135   torch.manual_seed(0)
136   x = torch.randint(10, (2, 3, 2))
137   y = torch.randint(10, (3, 2))
138   dim = 0
139   index = 1
140   res = torch.select_scatter(x, y, dim, index)
141   op = "op_select_scatter_out"
142   opt_extra_params = f"""{dim}, {index},"""
143   out_args = "out_shape, dynamism"
144   dtype = "ScalarType::Int"
145   check = "EXPECT_TENSOR_CLOSE" */
146 
test_dynamic_shape(const std::vector<int32_t> & out_shape,enum torch::executor::TensorShapeDynamism dynamism)147   void test_dynamic_shape(
148       const std::vector<int32_t>& out_shape,
149       enum torch::executor::TensorShapeDynamism dynamism) {
150     /* %python
151     %rewrite(binary_op) */
152 
153     TensorFactory<ScalarType::Int> tf;
154 
155     Tensor x = tf.make({2, 3, 2}, {4, 9, 3, 0, 3, 9, 7, 3, 7, 3, 1, 6});
156     Tensor y = tf.make({3, 2}, {6, 9, 8, 6, 6, 8});
157     Tensor expected = tf.make({2, 3, 2}, {4, 9, 3, 0, 3, 9, 6, 9, 8, 6, 6, 8});
158 
159     Tensor out = tf.zeros(out_shape, dynamism);
160     op_select_scatter_out(x, y, 0, 1, out);
161     EXPECT_TENSOR_CLOSE(out, expected);
162   }
163 };
164 
TEST_F(OpSelectScatterOutTest,SelectFrontDimAllIndexes)165 TEST_F(OpSelectScatterOutTest, SelectFrontDimAllIndexes) {
166   TensorFactory<ScalarType::Double> tf;
167 
168   // clang-format off
169   Tensor x = tf.make(
170       {2, 3, 4},
171       {
172           // [0, :, :]
173           1.,   2.,   3.,   4., // [0, 0, :]
174           5.,   6.,   7.,   8., // [0, 1, :]
175           9.,  10.,  11.,  12., // [0, 2, :]
176 
177           // [1, :, :]
178          -1.,  -2.,  -3.,  -4., // [1, 0, :]
179          -5.,  -6.,  -7.,  -8., // [1, 1, :]
180          -9., -10., -11., -12., // [1, 2, :]
181       });
182   // clang-format on
183 
184   // clang-format off
185   Tensor src = tf.make(
186       {3, 4},
187       {
188           // [0, :, :]
189           1.,  4.,  1.,  4., // [0, 0, :]
190           1.,  4.,  1.,  4., // [0, 1, :]
191           1.,  4.,  1.,  4., // [0, 2, :]
192       });
193   // clang-format on
194 
195   // Try to select the tensor from the input front (0th dimension)
196   // The size of output tensor should follow these rules:
197   // - output.size(i) shall equal input.size(i) if i < dim,
198   // - output.size(i) shall equal input.size(i+1) if i >= dim
199   const std::vector<int32_t> out_size = {2, 3, 4};
200 
201   Tensor out = tf.zeros(out_size);
202 
203   // clang-format off
204   std::vector<Tensor> expected_rets = {
205     // Expected result when choosing from the 0th dimension and 0th index
206     // The result should equal x[0,:, :]
207     tf.make(
208       out_size,
209       {
210          // [0, :, :]
211          1.,  4.,  1.,  4., // [0, 0, :]
212          1.,  4.,  1.,  4., // [0, 1, :]
213          1.,  4.,  1.,  4., // [0, 2, :]
214 
215          // [1, :, :]
216         -1.,  -2.,  -3.,  -4., // [1, 0, :]
217         -5.,  -6.,  -7.,  -8., // [1, 1, :]
218         -9., -10., -11., -12., // [1, 2, :]
219       }),
220 
221     // Expected result when choosing from the 0th dimension and 1st index
222     // The result should euqal x[1, :, :]
223     tf.make(
224       out_size,
225       {
226         // [0, :, :]
227         1.,   2.,   3.,   4., // [0, 0, :]
228         5.,   6.,   7.,   8., // [0, 1, :]
229         9.,  10.,  11.,  12., // [0, 2, :]
230 
231         // [1, :, :]
232         1.,  4.,  1.,  4., // [1, 0, :]
233         1.,  4.,  1.,  4., // [1, 1, :]
234         1.,  4.,  1.,  4., // [1, 2, :]
235       })
236   };
237   // clang-format on
238 
239   run_test_cases(x, src, /*dim=*/0, expected_rets);
240 }
241 
TEST_F(OpSelectScatterOutTest,SelectMiddleDimAllIndexes)242 TEST_F(OpSelectScatterOutTest, SelectMiddleDimAllIndexes) {
243   TensorFactory<ScalarType::Double> tf;
244 
245   // clang-format off
246   Tensor x = tf.make(
247       {2, 3, 4},
248       {
249           // [0, :, :]
250           1.,   2.,   3.,   4., // [0, 0, :]
251           5.,   6.,   7.,   8., // [0, 1, :]
252           9.,  10.,  11.,  12., // [0, 2, :]
253 
254           // [1, :, :]
255          -1.,  -2.,  -3.,  -4., // [1, 0, :]
256          -5.,  -6.,  -7.,  -8., // [1, 1, :]
257          -9., -10., -11., -12., // [1, 2, :]
258       });
259   // clang-format on
260 
261   // clang-format off
262   Tensor src = tf.make(
263       {2, 4},
264       {
265           // [0, :, :]
266           1.,  4.,  1.,  4., // [0, 0, :]
267           1.,  4.,  1.,  4., // [0, 2, :]
268       });
269   // clang-format on
270 
271   // Try to select the tensor from the input front (0th dimension)
272   // The size of output tensor should follow these rules:
273   // - output.size(i) shall equal input.size(i) if i < dim,
274   // - output.size(i) shall equal input.size(i+1) if i >= dim
275   const std::vector<int32_t> out_size = {2, 3, 4};
276 
277   Tensor out = tf.zeros(out_size);
278 
279   // clang-format off
280   std::vector<Tensor> expected_rets = {
281     // Expected result when choosing from the 1st dimension and 0th index
282     // The result should equal x[:,0, :]
283     tf.make(
284       out_size,
285       {
286          // [0, :, :]
287          1.,   4.,   1.,   4., // [0, 0, :]
288          5.,   6.,   7.,   8., // [0, 1, :]
289          9.,  10.,  11.,  12., // [0, 2, :]
290 
291          // [1, :, :]
292          1.,   4.,   1.,   4., // [1, 0, :]
293         -5.,  -6.,  -7.,  -8., // [1, 1, :]
294         -9., -10., -11., -12., // [1, 2, :]
295       }),
296     // Expected result when choosing from the 1st dimension and 1st index
297     // The result should equal x[:, 1, :]
298     tf.make(
299       out_size,
300       {
301          // [0, :, :]
302          1.,   2.,   3.,   4., // [0, 0, :]
303          1.,   4.,   1.,   4., // [0, 1, :]
304          9.,  10.,  11.,  12., // [0, 2, :]
305 
306          // [1, :, :]
307         -1.,  -2.,  -3.,  -4., // [1, 0, :]
308          1.,   4.,   1.,   4., // [1, 1, :]
309         -9., -10., -11., -12., // [1, 2, :]
310       }),
311     // Expected result when choosing from the 1st dimension and 2th index
312     // The result should equal x[:,2, :]
313     tf.make(
314       out_size,
315       {
316          // [0, :, :]
317          1.,   2.,   3.,   4., // [0, 0, :]
318          5.,   6.,   7.,   8., // [0, 1, :]
319          1.,   4.,   1.,   4., // [0, 2, :]
320 
321          // [1, :, :]
322         -1.,  -2.,  -3.,  -4., // [1, 0, :]
323         -5.,  -6.,  -7.,  -8., // [1, 1, :]
324          1.,   4.,   1.,   4., // [1, 2, :]
325       })
326   };
327   // clang-format on
328 
329   run_test_cases(x, src, /*dim=*/1, expected_rets);
330 }
331 
TEST_F(OpSelectScatterOutTest,SelectEndDimAllIndexes)332 TEST_F(OpSelectScatterOutTest, SelectEndDimAllIndexes) {
333   TensorFactory<ScalarType::Double> tf;
334 
335   // clang-format off
336   Tensor x = tf.make(
337     {2, 3, 4},
338     {
339         // [0, :, :]
340         1.,   2.,   3.,   4., // [0, 0, :]
341         5.,   6.,   7.,   8., // [0, 1, :]
342         9.,  10.,  11.,  12., // [0, 2, :]
343 
344         // [1, :, :]
345        -1.,  -2.,  -3.,  -4., // [1, 0, :]
346        -5.,  -6.,  -7.,  -8., // [1, 1, :]
347        -9., -10., -11., -12., // [1, 2, :]
348     });
349   // clang-format on
350 
351   // clang-format off
352   Tensor src = tf.make(
353     {2, 3},
354     {
355         // [0, :, :]
356         1.,  4.,  1., // [0, 0, :]
357         1.,  4.,  1., // [0, 1, :]
358     });
359   // clang-format on
360 
361   // Try to select the tensor from the input front (0th dimension)
362   // The size of output tensor should follow these rules:
363   // - output.size(i) shall equal input.size(i) if i < dim,
364   // - output.size(i) shall equal input.size(i+1) if i >= dim
365   const std::vector<int32_t> out_size = {2, 3, 4};
366 
367   Tensor out = tf.zeros(out_size);
368 
369   // clang-format off
370   std::vector<Tensor> expected_rets = {
371     // Expected result when choosing from the 2nd dimension and 0th index
372     // The result should equal x[:,:, 0] (a.k.a 0th column of x data layout)
373     tf.make(
374       out_size,
375       {
376         // [0, :, :]
377         1.,   2.,   3.,   4., // [0, 0, :]
378         4.,   6.,   7.,   8., // [0, 1, :]
379         1.,  10.,  11.,  12., // [0, 2, :]
380 
381         // [1, :, :]
382         1.,  -2.,  -3.,  -4., // [1, 0, :]
383         4.,  -6.,  -7.,  -8., // [1, 1, :]
384         1., -10., -11., -12., // [1, 2, :]
385       }),
386     // Expected result when choosing from the 2nd dimension and 1st index
387     // The result should equal x[:,:, 1] (a.k.a 1st column of x data layout)
388     tf.make(
389       out_size,
390       {
391          // [0, :, :]
392          1.,  1.,   3.,   4., // [0, 0, :]
393          5.,  4.,   7.,   8., // [0, 1, :]
394          9.,  1.,  11.,  12., // [0, 2, :]
395 
396          // [1, :, :]
397         -1.,  1.,  -3.,  -4., // [1, 0, :]
398         -5.,  4.,  -7.,  -8., // [1, 1, :]
399         -9.,  1., -11., -12., // [1, 2, :]
400       }),
401     // Expected result when choosing from the 2nd dimension and 2nd index
402     // The result should equal x[:,:, 2] (a.k.a 2nd column of x data layout)
403     tf.make(
404       out_size,
405       {
406          // [0, :, :]
407          1.,   2.,  1.,   4., // [0, 0, :]
408          5.,   6.,  4.,   8., // [0, 1, :]
409          9.,  10.,  1.,  12., // [0, 2, :]
410 
411          // [1, :, :]
412         -1.,  -2.,  1.,  -4., // [1, 0, :]
413         -5.,  -6.,  4.,  -8., // [1, 1, :]
414         -9., -10.,  1., -12., // [1, 2, :]
415       }),
416     // Expected result when choosing from the 2nd dimension and 3rd index
417     // The result should equal x[:,:, 3] (a.k.a 3rd column of x data layout)
418     tf.make(
419       out_size,
420       {
421          // [0, :, :]
422          1.,   2.,   3.,  1., // [0, 0, :]
423          5.,   6.,   7.,  4., // [0, 1, :]
424          9.,  10.,  11.,  1., // [0, 2, :]
425 
426          // [1, :, :]
427         -1.,  -2.,  -3.,  1., // [1, 0, :]
428         -5.,  -6.,  -7.,  4., // [1, 1, :]
429         -9., -10., -11.,  1., // [1, 2, :]
430       })
431   };
432   // clang-format on
433 
434   run_test_cases(x, src, /*dim=*/2, expected_rets);
435 }
436 
437 #ifndef USE_ATEN_LIB
438 // Same test as above, but this time the output size is slightly off
TEST_F(OpSelectScatterOutTest,OutputDynamicShape)439 TEST_F(OpSelectScatterOutTest, OutputDynamicShape) {
440   TensorFactory<ScalarType::Double> tf;
441 
442   // clang-format off
443   Tensor x = tf.make(
444     {2, 3, 4},
445     {
446         // [0, :, :]
447         1.,   2.,   3.,   4., // [0, 0, :]
448         5.,   6.,   7.,   8., // [0, 1, :]
449         9.,  10.,  11.,  12., // [0, 2, :]
450 
451         // [1, :, :]
452        -1.,  -2.,  -3.,  -4., // [1, 0, :]
453        -5.,  -6.,  -7.,  -8., // [1, 1, :]
454        -9., -10., -11., -12., // [1, 2, :]
455     });
456   // clang-format on
457 
458   // clang-format off
459   Tensor src = tf.make(
460     {2, 3},
461     {
462         // [0, :, :]
463         1.,  4.,  1., // [0, 0, :]
464         1.,  4.,  1., // [0, 1, :]
465     });
466   // clang-format on
467 
468   // In this case, the output starts off with a different shape than is
469   // expected. We are checking to see that dynamic shape support is working
470   // correctly and that the output will be resized to the correct shape inside
471   // the kernel.
472   const std::vector<int32_t> out_size = {2, 6, 2};
473   const std::vector<int32_t> actual_out_size = {2, 3, 4};
474 
475   Tensor out =
476       tf.zeros(out_size, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
477 
478   // clang-format off
479   Tensor expected_ret = tf.make(
480     actual_out_size,
481     {
482       // [0, :, :]
483       1.,   2.,   3.,   4., // [0, 0, :]
484       4.,   6.,   7.,   8., // [0, 1, :]
485       1.,  10.,  11.,  12., // [0, 2, :]
486 
487       // [1, :, :]
488       1.,  -2.,  -3.,  -4., // [1, 0, :]
489       4.,  -6.,  -7.,  -8., // [1, 1, :]
490       1., -10., -11., -12., // [1, 2, :]
491     });
492   // clang-format on
493 
494   Tensor ret = op_select_scatter_out(x, src, 2, 0, out);
495   EXPECT_TENSOR_EQ(out, ret);
496   EXPECT_TENSOR_EQ(out, expected_ret);
497 }
498 #endif
499 
500 /// A generic smoke test that works for any dtype that supports ones() and
501 /// zeros().
TEST_F(OpSelectScatterOutTest,AllDtypesSupported)502 TEST_F(OpSelectScatterOutTest, AllDtypesSupported) {
503 #define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
504   ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
505 #undef TEST_ENTRY
506   // TODO: Also add tests for half, complex, quantized, and other types. Easiest
507   // way to do that would be to make TensorFactory support zeros() and ones()
508   // for those types.
509 }
510 
511 //////////////////////////////////////////////////////////////////////////////
// The following tests focus on empty-size tensors and empty tensors.
// Terminology used here:
// - empty-size tensor: its size is [] but it does hold data
//   (e.g. tensor(5)).
// - empty tensor: its size is not [], at least one dimension has size zero,
//   and it holds no data (e.g. ones(1, 0, 2, 3)).
517 
518 // This test focuses on the support for empty tensor (dim() > 0) input and empty
519 // tensor output
TEST_F(OpSelectScatterOutTest,EmptyTensorNonZeroNDimsInputSupported)520 TEST_F(OpSelectScatterOutTest, EmptyTensorNonZeroNDimsInputSupported) {
521   TensorFactory<ScalarType::Int> tf;
522 
523   // Using empty tensors as input.
524   Tensor x = tf.make({3, 0, 10, 3}, {});
525   EXPECT_EQ(x.numel(), 0);
526 
527   // src tensor whose shape is appropriate to place in dim(2) of x
528   Tensor src = tf.make({3, 0, 3}, {});
529 
530   // Output whose shape is equal to the input shape
531   Tensor out = tf.make({3, 0, 10, 3}, {});
532   EXPECT_EQ(out.numel(), 0);
533 
534   Tensor ret = op_select_scatter_out(x, src, /*dim=*/2, /*index=*/3, out);
535   EXPECT_EQ(ret.numel(), 0);
536   // Success if it doesn't assert on the weird-shaped empty input and the
537   // ret is still a empty array
538 }
539 
540 // Apply select on dim() == 0 empty tensor input and empty tensor output
TEST_F(OpSelectScatterOutTest,EmptyTensorZeroNDimsInputDies)541 TEST_F(OpSelectScatterOutTest, EmptyTensorZeroNDimsInputDies) {
542   TensorFactory<ScalarType::Int> tf;
543 
544   // Using empty tensors as input.
545   Tensor x = tf.make({0}, {});
546   EXPECT_EQ(x.numel(), 0);
547 
548   // Using empty src tensor
549   Tensor src = tf.make({0}, {});
550   EXPECT_EQ(src.numel(), 0);
551 
552   // Output whose shape is equal to the input shape
553   Tensor out = tf.make({}, {0});
554   EXPECT_EQ(out.numel(), 1);
555 
556   // Expected failure when slicing on the dimension with length 0 since no space
557   // on the dimension could be sliced. (out of bound error)
558   ET_EXPECT_KERNEL_FAILURE(
559       context_, op_select_scatter_out(x, src, /*dim=*/0, /*index=*/0, out));
560 }
561 ///////////////////////////////////////////////////////////////////////
562 
// Any dim outside the valid range of a 3-D input must be rejected.
TEST_F(OpSelectScatterOutTest, DimOutOfBoundDies) {
  TensorFactory<ScalarType::Int> tf;

  Tensor input = tf.ones({1, 1, 1});
  Tensor slice = tf.ones({1, 1});
  Tensor result_buf = tf.zeros({1, 1, 1});

  // Values on both sides of the valid range for a tensor with three
  // dimensions.
  for (const int64_t bad_dim : {3, 4, 5, -4, -5, -6}) {
    ET_EXPECT_KERNEL_FAILURE(
        context_,
        op_select_scatter_out(input, slice, bad_dim, /*index=*/0, result_buf));
  }
}
578 
// Any index outside the valid range of the selected dimension must be
// rejected.
TEST_F(OpSelectScatterOutTest, IndexOutOfBoundDies) {
  TensorFactory<ScalarType::Int> tf;

  Tensor x = tf.ones({1, 1, 1});
  Tensor src = tf.ones({1, 1});

  Tensor out = tf.zeros({1, 1, 1});

  // Some invalid index values. Dim 0 has size 1, so only 0 and -1 are valid;
  // the list starts at the nearest invalid values on each side (1 and -2) so
  // that an off-by-one in the bounds check is caught.
  const std::vector<int32_t> invalid_indices = {
      1, 2, 3, 4, 5, -2, -3, -4, -5, -6};
  for (ssize_t idx : invalid_indices) {
    ET_EXPECT_KERNEL_FAILURE(
        context_, op_select_scatter_out(x, src, /*dim=*/0, idx, out));
  }
}
594 
// An output whose dtype differs from the input's dtype must be rejected.
TEST_F(OpSelectScatterOutTest, MismatchedDtypesDies) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Int> int_factory;

  Tensor input = int_factory.zeros({1, 2, 2});
  Tensor slice = int_factory.zeros({2, 2});

  // Same sizes as the input, but Float instead of Int.
  Tensor wrong_dtype_out = float_factory.ones({1, 2, 2});

  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_select_scatter_out(
          input, slice, /*dim=*/0, /*index=*/0, wrong_dtype_out));
}
607 
// src with the right element count but missing the trailing dimension must be
// rejected.
TEST_F(OpSelectScatterOutTest, SrcMatchNumelLackDimAtEndDies) {
  TensorFactory<ScalarType::Int> tf;

  Tensor input = tf.zeros({1, 2, 2, 1});
  Tensor result_buf = tf.ones({1, 2, 2, 1});

  // Same dtype and numel as the selected slice of the input, but the wrong
  // sizes: src.dim() should always be exactly one lower than the input's
  // dim().
  Tensor short_src = tf.zeros({2, 2});

  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_select_scatter_out(
          input, short_src, /*dim=*/0, /*index=*/0, result_buf));
}
620 
// src with the right element count but an extra leading dimension must be
// rejected.
TEST_F(OpSelectScatterOutTest, SrcMatchNumelExtraDimAtFrontDies) {
  TensorFactory<ScalarType::Int> tf;

  Tensor input = tf.zeros({2, 2});
  Tensor result_buf = tf.ones({2, 2});

  // Same dtype and numel as the selected slice of the input, but the wrong
  // sizes: src.dim() should always be exactly one lower than the input's
  // dim().
  Tensor padded_src = tf.zeros({1, 2});

  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_select_scatter_out(
          input, padded_src, /*dim=*/0, /*index=*/0, result_buf));
}
633 
// src with the right rank but the wrong sizes for the selected slice must be
// rejected.
TEST_F(OpSelectScatterOutTest, SrcSizeMismatchDimDies) {
  TensorFactory<ScalarType::Int> tf;

  Tensor input = tf.zeros({2, 4, 7, 5});
  Tensor result_buf = tf.zeros({2, 4, 7, 5});

  // A slice of the input taken along dim 2 has sizes {2, 4, 5}; this src uses
  // {2, 4, 7} instead.
  Tensor wrong_size_src = tf.zeros({2, 4, 7});

  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_select_scatter_out(
          input, wrong_size_src, /*dim=*/2, /*index=*/3, result_buf));
}
647 
// Bounded-dynamic output created with exactly the expected result shape.
TEST_F(OpSelectScatterOutTest, DynamicShapeUpperBoundSameAsExpected) {
  const std::vector<int32_t> bound_shape = {2, 3, 2};
  test_dynamic_shape(
      bound_shape, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
}
652 
// Bounded-dynamic output created larger than the expected result shape.
// Skipped when the build reports no output-resize support.
TEST_F(OpSelectScatterOutTest, DynamicShapeUpperBoundLargerThanExpected) {
  const bool can_resize_output =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!can_resize_output) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }

  const std::vector<int32_t> bound_shape = {10, 10, 10};
  test_dynamic_shape(
      bound_shape, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
}
660 
// Unbounded-dynamic output created smaller than the expected result shape.
// Skipped when the build reports no output-resize support.
TEST_F(OpSelectScatterOutTest, DynamicShapeUnbound) {
  const bool can_resize_output =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!can_resize_output) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }

  const std::vector<int32_t> initial_shape = {1, 1, 1};
  test_dynamic_shape(
      initial_shape, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
}
668