1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
16
17 #include <gtest/gtest.h>
18 #include <sys/types.h>
19
20 using namespace ::testing;
21 using exec_aten::ArrayRef;
22 using exec_aten::ScalarType;
23 using exec_aten::Tensor;
24 using torch::executor::testing::TensorFactory;
25
26 class OpSelectScatterOutTest : public OperatorTest {
27 protected:
op_select_scatter_out(const Tensor & self,const Tensor & src,int64_t dim,int64_t index,Tensor & out)28 Tensor& op_select_scatter_out(
29 const Tensor& self,
30 const Tensor& src,
31 int64_t dim,
32 int64_t index,
33 Tensor& out) {
34 return torch::executor::aten::select_scatter_outf(
35 context_, self, src, dim, index, out);
36 }
37
38 template <class CTYPE, exec_aten::ScalarType DTYPE>
test_dtype()39 void test_dtype() {
40 TensorFactory<DTYPE> tf;
41
42 // Using the following tensors, inserting a tensor of either ones or zeros
43 // into the appropriate selected slice should result in a tensor of all ones
44 // or all zeros.
45
46 // clang-format off
47 Tensor x = tf.make(
48 {3, 2, 4},
49 {
50 // all ones below are from x,
51 // and all zeros are from y.
52 // [0, :, :]
53 1, 1, 1, 1, // [0, 0, :]
54 0, 0, 0, 0, // [0, 1, :]
55
56 // [1, :, :]
57 1, 1, 1, 1, // [1, 0, :]
58 0, 0, 0, 0, // [1, 1, :]
59
60 // [2, :, :]
61 1, 1, 1, 1, // [2, 0, :]
62 0, 0, 0, 0, // [2, 1, :]
63 });
64 // clang-format on
65
66 // clang-format off
67 Tensor src_ones = tf.make(
68 {3, 4},
69 {
70 // [:, :]
71 1, 1, 1, 1, // [0, :]
72 1, 1, 1, 1, // [1, :]
73 1, 1, 1, 1, // [2, :]
74 });
75 // clang-format on
76
77 // clang-format off
78 Tensor src_zeros = tf.make(
79 {3, 4},
80 {
81 // [:, :]
82 0, 0, 0, 0, // [0, :]
83 0, 0, 0, 0, // [1, :]
84 0, 0, 0, 0, // [2, :]
85 });
86 // clang-format on
87
88 // Expected outs should be all ones or all zeros depending on which src
89 // tensor is used.
90
91 Tensor out_0 = tf.zeros({3, 2, 4});
92 Tensor out_1 = tf.ones({3, 2, 4});
93 Tensor ret_0 =
94 op_select_scatter_out(x, src_zeros, /*dim=*/1, /*index=*/0, out_0);
95 Tensor ret_1 =
96 op_select_scatter_out(x, src_ones, /*dim=*/1, /*index=*/1, out_1);
97
98 EXPECT_TENSOR_EQ(ret_0, out_0);
99 EXPECT_TENSOR_EQ(ret_1, out_1);
100
101 EXPECT_TENSOR_EQ(ret_0, tf.zeros({3, 2, 4}));
102 EXPECT_TENSOR_EQ(ret_1, tf.ones({3, 2, 4}));
103 }
104
105 // Run the test by selecting Tensor x on given dim and all available indexes
106 // on that dimension
run_test_cases(const Tensor & x,const Tensor & src,ssize_t dim,const std::vector<Tensor> & expected)107 void run_test_cases(
108 const Tensor& x,
109 const Tensor& src,
110 ssize_t dim,
111 const std::vector<Tensor>& expected) {
112 // Generated out tensor sharing same size and dtype with expected tensor
113 TensorFactory<ScalarType::Double> tf;
114
115 const std::vector<int32_t> out_size(
116 expected[0].sizes().begin(), expected[0].sizes().end());
117 Tensor out = tf.zeros(out_size);
118
119 for (ssize_t idx = 0; idx < x.size(dim); idx++) {
120 // Should always return the provided out Tensor.
121 // The ret shall meet the expectation.
122 Tensor ret = op_select_scatter_out(x, src, dim, idx, out);
123 EXPECT_TENSOR_EQ(out, ret);
124 EXPECT_TENSOR_EQ(out, expected[idx]);
125
126 ret =
127 op_select_scatter_out(x, src, dim, /*index=*/idx - x.size(dim), out);
128 EXPECT_TENSOR_EQ(out, ret);
129 EXPECT_TENSOR_EQ(out, expected[idx]);
130 }
131 }
132
133 /* %python
134 import torch
135 torch.manual_seed(0)
136 x = torch.randint(10, (2, 3, 2))
137 y = torch.randint(10, (3, 2))
138 dim = 0
139 index = 1
140 res = torch.select_scatter(x, y, dim, index)
141 op = "op_select_scatter_out"
142 opt_extra_params = f"""{dim}, {index},"""
143 out_args = "out_shape, dynamism"
144 dtype = "ScalarType::Int"
145 check = "EXPECT_TENSOR_CLOSE" */
146
test_dynamic_shape(const std::vector<int32_t> & out_shape,enum torch::executor::TensorShapeDynamism dynamism)147 void test_dynamic_shape(
148 const std::vector<int32_t>& out_shape,
149 enum torch::executor::TensorShapeDynamism dynamism) {
150 /* %python
151 %rewrite(binary_op) */
152
153 TensorFactory<ScalarType::Int> tf;
154
155 Tensor x = tf.make({2, 3, 2}, {4, 9, 3, 0, 3, 9, 7, 3, 7, 3, 1, 6});
156 Tensor y = tf.make({3, 2}, {6, 9, 8, 6, 6, 8});
157 Tensor expected = tf.make({2, 3, 2}, {4, 9, 3, 0, 3, 9, 6, 9, 8, 6, 6, 8});
158
159 Tensor out = tf.zeros(out_shape, dynamism);
160 op_select_scatter_out(x, y, 0, 1, out);
161 EXPECT_TENSOR_CLOSE(out, expected);
162 }
163 };
164
// Scatters src into every index along the front (0th) dimension of x.
TEST_F(OpSelectScatterOutTest, SelectFrontDimAllIndexes) {
  TensorFactory<ScalarType::Double> tf;

  // clang-format off
  Tensor x = tf.make(
    {2, 3, 4},
    {
      // [0, :, :]
      1.,  2.,  3.,  4., // [0, 0, :]
      5.,  6.,  7.,  8., // [0, 1, :]
      9., 10., 11., 12., // [0, 2, :]

      // [1, :, :]
     -1., -2., -3., -4., // [1, 0, :]
     -5., -6., -7., -8., // [1, 1, :]
     -9.,-10.,-11.,-12., // [1, 2, :]
    });
  // clang-format on

  // src must have the shape of x.select(0, i), i.e. {3, 4}.
  // clang-format off
  Tensor src = tf.make(
    {3, 4},
    {
      1., 4., 1., 4., // [0, :]
      1., 4., 1., 4., // [1, :]
      1., 4., 1., 4., // [2, :]
    });
  // clang-format on

  // The select_scatter output always has the same shape as the input.
  const std::vector<int32_t> out_size = {2, 3, 4};

  // clang-format off
  std::vector<Tensor> expected_rets = {
    // index 0: x with slice x[0, :, :] overwritten by src.
    tf.make(
      out_size,
      {
        // [0, :, :]
        1.,  4.,  1.,  4., // [0, 0, :]
        1.,  4.,  1.,  4., // [0, 1, :]
        1.,  4.,  1.,  4., // [0, 2, :]

        // [1, :, :]
       -1., -2., -3., -4., // [1, 0, :]
       -5., -6., -7., -8., // [1, 1, :]
       -9.,-10.,-11.,-12., // [1, 2, :]
      }),

    // index 1: x with slice x[1, :, :] overwritten by src.
    tf.make(
      out_size,
      {
        // [0, :, :]
        1.,  2.,  3.,  4., // [0, 0, :]
        5.,  6.,  7.,  8., // [0, 1, :]
        9., 10., 11., 12., // [0, 2, :]

        // [1, :, :]
        1.,  4.,  1.,  4., // [1, 0, :]
        1.,  4.,  1.,  4., // [1, 1, :]
        1.,  4.,  1.,  4., // [1, 2, :]
      })
  };
  // clang-format on

  run_test_cases(x, src, /*dim=*/0, expected_rets);
}
241
// Scatters src into every index along the middle (1st) dimension of x.
TEST_F(OpSelectScatterOutTest, SelectMiddleDimAllIndexes) {
  TensorFactory<ScalarType::Double> tf;

  // clang-format off
  Tensor x = tf.make(
    {2, 3, 4},
    {
      // [0, :, :]
      1.,  2.,  3.,  4., // [0, 0, :]
      5.,  6.,  7.,  8., // [0, 1, :]
      9., 10., 11., 12., // [0, 2, :]

      // [1, :, :]
     -1., -2., -3., -4., // [1, 0, :]
     -5., -6., -7., -8., // [1, 1, :]
     -9.,-10.,-11.,-12., // [1, 2, :]
    });
  // clang-format on

  // src must have the shape of x.select(1, i), i.e. {2, 4}.
  // clang-format off
  Tensor src = tf.make(
    {2, 4},
    {
      1., 4., 1., 4., // [0, :]
      1., 4., 1., 4., // [1, :]
    });
  // clang-format on

  // The select_scatter output always has the same shape as the input.
  const std::vector<int32_t> out_size = {2, 3, 4};

  // clang-format off
  std::vector<Tensor> expected_rets = {
    // index 0: x with slice x[:, 0, :] overwritten by src.
    tf.make(
      out_size,
      {
        // [0, :, :]
        1.,  4.,  1.,  4., // [0, 0, :]
        5.,  6.,  7.,  8., // [0, 1, :]
        9., 10., 11., 12., // [0, 2, :]

        // [1, :, :]
        1.,  4.,  1.,  4., // [1, 0, :]
       -5., -6., -7., -8., // [1, 1, :]
       -9.,-10.,-11.,-12., // [1, 2, :]
      }),
    // index 1: x with slice x[:, 1, :] overwritten by src.
    tf.make(
      out_size,
      {
        // [0, :, :]
        1.,  2.,  3.,  4., // [0, 0, :]
        1.,  4.,  1.,  4., // [0, 1, :]
        9., 10., 11., 12., // [0, 2, :]

        // [1, :, :]
       -1., -2., -3., -4., // [1, 0, :]
        1.,  4.,  1.,  4., // [1, 1, :]
       -9.,-10.,-11.,-12., // [1, 2, :]
      }),
    // index 2: x with slice x[:, 2, :] overwritten by src.
    tf.make(
      out_size,
      {
        // [0, :, :]
        1.,  2.,  3.,  4., // [0, 0, :]
        5.,  6.,  7.,  8., // [0, 1, :]
        1.,  4.,  1.,  4., // [0, 2, :]

        // [1, :, :]
       -1., -2., -3., -4., // [1, 0, :]
       -5., -6., -7., -8., // [1, 1, :]
        1.,  4.,  1.,  4., // [1, 2, :]
      })
  };
  // clang-format on

  run_test_cases(x, src, /*dim=*/1, expected_rets);
}
331
// Scatters src into every index along the last (2nd) dimension of x.
TEST_F(OpSelectScatterOutTest, SelectEndDimAllIndexes) {
  TensorFactory<ScalarType::Double> tf;

  // clang-format off
  Tensor x = tf.make(
    {2, 3, 4},
    {
      // [0, :, :]
      1.,  2.,  3.,  4., // [0, 0, :]
      5.,  6.,  7.,  8., // [0, 1, :]
      9., 10., 11., 12., // [0, 2, :]

      // [1, :, :]
     -1., -2., -3., -4., // [1, 0, :]
     -5., -6., -7., -8., // [1, 1, :]
     -9.,-10.,-11.,-12., // [1, 2, :]
    });
  // clang-format on

  // src must have the shape of x.select(2, i), i.e. {2, 3}.
  // clang-format off
  Tensor src = tf.make(
    {2, 3},
    {
      1., 4., 1., // [0, :]
      1., 4., 1., // [1, :]
    });
  // clang-format on

  // The select_scatter output always has the same shape as the input.
  const std::vector<int32_t> out_size = {2, 3, 4};

  // clang-format off
  std::vector<Tensor> expected_rets = {
    // index 0: x with column x[:, :, 0] overwritten by src.
    tf.make(
      out_size,
      {
        // [0, :, :]
        1.,  2.,  3.,  4., // [0, 0, :]
        4.,  6.,  7.,  8., // [0, 1, :]
        1., 10., 11., 12., // [0, 2, :]

        // [1, :, :]
        1., -2., -3., -4., // [1, 0, :]
        4., -6., -7., -8., // [1, 1, :]
        1.,-10.,-11.,-12., // [1, 2, :]
      }),
    // index 1: x with column x[:, :, 1] overwritten by src.
    tf.make(
      out_size,
      {
        // [0, :, :]
        1.,  1.,  3.,  4., // [0, 0, :]
        5.,  4.,  7.,  8., // [0, 1, :]
        9.,  1., 11., 12., // [0, 2, :]

        // [1, :, :]
       -1.,  1., -3., -4., // [1, 0, :]
       -5.,  4., -7., -8., // [1, 1, :]
       -9.,  1.,-11.,-12., // [1, 2, :]
      }),
    // index 2: x with column x[:, :, 2] overwritten by src.
    tf.make(
      out_size,
      {
        // [0, :, :]
        1.,  2.,  1.,  4., // [0, 0, :]
        5.,  6.,  4.,  8., // [0, 1, :]
        9., 10.,  1., 12., // [0, 2, :]

        // [1, :, :]
       -1., -2.,  1., -4., // [1, 0, :]
       -5., -6.,  4., -8., // [1, 1, :]
       -9.,-10.,  1.,-12., // [1, 2, :]
      }),
    // index 3: x with column x[:, :, 3] overwritten by src.
    tf.make(
      out_size,
      {
        // [0, :, :]
        1.,  2.,  3.,  1., // [0, 0, :]
        5.,  6.,  7.,  4., // [0, 1, :]
        9., 10., 11.,  1., // [0, 2, :]

        // [1, :, :]
       -1., -2., -3.,  1., // [1, 0, :]
       -5., -6., -7.,  4., // [1, 1, :]
       -9.,-10.,-11.,  1., // [1, 2, :]
      })
  };
  // clang-format on

  run_test_cases(x, src, /*dim=*/2, expected_rets);
}
436
437 #ifndef USE_ATEN_LIB
438 // Same test as above, but this time the output size is slightly off
TEST_F(OpSelectScatterOutTest,OutputDynamicShape)439 TEST_F(OpSelectScatterOutTest, OutputDynamicShape) {
440 TensorFactory<ScalarType::Double> tf;
441
442 // clang-format off
443 Tensor x = tf.make(
444 {2, 3, 4},
445 {
446 // [0, :, :]
447 1., 2., 3., 4., // [0, 0, :]
448 5., 6., 7., 8., // [0, 1, :]
449 9., 10., 11., 12., // [0, 2, :]
450
451 // [1, :, :]
452 -1., -2., -3., -4., // [1, 0, :]
453 -5., -6., -7., -8., // [1, 1, :]
454 -9., -10., -11., -12., // [1, 2, :]
455 });
456 // clang-format on
457
458 // clang-format off
459 Tensor src = tf.make(
460 {2, 3},
461 {
462 // [0, :, :]
463 1., 4., 1., // [0, 0, :]
464 1., 4., 1., // [0, 1, :]
465 });
466 // clang-format on
467
468 // In this case, the output starts off with a different shape than is
469 // expected. We are checking to see that dynamic shape support is working
470 // correctly and that the output will be resized to the correct shape inside
471 // the kernel.
472 const std::vector<int32_t> out_size = {2, 6, 2};
473 const std::vector<int32_t> actual_out_size = {2, 3, 4};
474
475 Tensor out =
476 tf.zeros(out_size, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
477
478 // clang-format off
479 Tensor expected_ret = tf.make(
480 actual_out_size,
481 {
482 // [0, :, :]
483 1., 2., 3., 4., // [0, 0, :]
484 4., 6., 7., 8., // [0, 1, :]
485 1., 10., 11., 12., // [0, 2, :]
486
487 // [1, :, :]
488 1., -2., -3., -4., // [1, 0, :]
489 4., -6., -7., -8., // [1, 1, :]
490 1., -10., -11., -12., // [1, 2, :]
491 });
492 // clang-format on
493
494 Tensor ret = op_select_scatter_out(x, src, 2, 0, out);
495 EXPECT_TENSOR_EQ(out, ret);
496 EXPECT_TENSOR_EQ(out, expected_ret);
497 }
498 #endif
499
500 /// A generic smoke test that works for any dtype that supports ones() and
501 /// zeros().
TEST_F(OpSelectScatterOutTest,AllDtypesSupported)502 TEST_F(OpSelectScatterOutTest, AllDtypesSupported) {
503 #define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
504 ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
505 #undef TEST_ENTRY
506 // TODO: Also add tests for half, complex, quantized, and other types. Easiest
507 // way to do that would be to make TensorFactory support zeros() and ones()
508 // for those types.
509 }
510
//////////////////////////////////////////////////////////////////////////////
// The following tests focus on empty-size tensors and empty tensors.
// Terminology used below:
//   empty-size tensor: its sizes are [] (a 0-D tensor), but it still holds
//                      one element (e.g. tensor(5)).
//   empty tensor: its sizes are not [] and at least one dimension has size
//                 zero, so it holds no data (e.g. ones(1, 0, 2, 3)).
517
518 // This test focuses on the support for empty tensor (dim() > 0) input and empty
519 // tensor output
TEST_F(OpSelectScatterOutTest,EmptyTensorNonZeroNDimsInputSupported)520 TEST_F(OpSelectScatterOutTest, EmptyTensorNonZeroNDimsInputSupported) {
521 TensorFactory<ScalarType::Int> tf;
522
523 // Using empty tensors as input.
524 Tensor x = tf.make({3, 0, 10, 3}, {});
525 EXPECT_EQ(x.numel(), 0);
526
527 // src tensor whose shape is appropriate to place in dim(2) of x
528 Tensor src = tf.make({3, 0, 3}, {});
529
530 // Output whose shape is equal to the input shape
531 Tensor out = tf.make({3, 0, 10, 3}, {});
532 EXPECT_EQ(out.numel(), 0);
533
534 Tensor ret = op_select_scatter_out(x, src, /*dim=*/2, /*index=*/3, out);
535 EXPECT_EQ(ret.numel(), 0);
536 // Success if it doesn't assert on the weird-shaped empty input and the
537 // ret is still a empty array
538 }
539
540 // Apply select on dim() == 0 empty tensor input and empty tensor output
TEST_F(OpSelectScatterOutTest,EmptyTensorZeroNDimsInputDies)541 TEST_F(OpSelectScatterOutTest, EmptyTensorZeroNDimsInputDies) {
542 TensorFactory<ScalarType::Int> tf;
543
544 // Using empty tensors as input.
545 Tensor x = tf.make({0}, {});
546 EXPECT_EQ(x.numel(), 0);
547
548 // Using empty src tensor
549 Tensor src = tf.make({0}, {});
550 EXPECT_EQ(src.numel(), 0);
551
552 // Output whose shape is equal to the input shape
553 Tensor out = tf.make({}, {0});
554 EXPECT_EQ(out.numel(), 1);
555
556 // Expected failure when slicing on the dimension with length 0 since no space
557 // on the dimension could be sliced. (out of bound error)
558 ET_EXPECT_KERNEL_FAILURE(
559 context_, op_select_scatter_out(x, src, /*dim=*/0, /*index=*/0, out));
560 }
561 ///////////////////////////////////////////////////////////////////////
562
// Every dim outside [-x.dim(), x.dim()) must be rejected.
TEST_F(OpSelectScatterOutTest, DimOutOfBoundDies) {
  TensorFactory<ScalarType::Int> tf;

  Tensor x = tf.ones({1, 1, 1});
  Tensor src = tf.ones({1, 1});

  Tensor out = tf.zeros({1, 1, 1});

  // x is 3-D, so each of these dims lies outside the valid range [-3, 2].
  for (const int32_t bad_dim : {3, 4, 5, -4, -5, -6}) {
    ET_EXPECT_KERNEL_FAILURE(
        context_, op_select_scatter_out(x, src, bad_dim, /*index=*/0, out));
  }
}
578
// Every index outside [-x.size(dim), x.size(dim)) must be rejected.
TEST_F(OpSelectScatterOutTest, IndexOutOfBoundDies) {
  TensorFactory<ScalarType::Int> tf;

  Tensor x = tf.ones({1, 1, 1});
  Tensor src = tf.ones({1, 1});

  Tensor out = tf.zeros({1, 1, 1});

  // x.size(0) == 1, so only indices -1 and 0 are valid along dim 0.
  // 1 and -2 are the first out-of-range values on each side, where
  // off-by-one bugs would show up; the others are further out.
  const std::vector<int32_t> invalid_indices = {
      1, 2, 3, 4, 5, -2, -3, -4, -5, -6};
  for (ssize_t idx : invalid_indices) {
    ET_EXPECT_KERNEL_FAILURE(
        context_, op_select_scatter_out(x, src, /*dim=*/0, idx, out));
  }
}
594
// An out tensor whose dtype differs from the input's must be rejected.
TEST_F(OpSelectScatterOutTest, MismatchedDtypesDies) {
  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Float> float_factory;

  // Input and src are both Int, and src fits x.select(0, i).
  Tensor x = int_factory.zeros({1, 2, 2});
  Tensor src = int_factory.zeros({2, 2});

  // out has a compatible shape but a Float dtype.
  Tensor out = float_factory.ones({1, 2, 2});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_select_scatter_out(x, src, /*dim=*/0, /*index=*/0, out));
}
607
// src with the right numel but one dimension too few must be rejected.
TEST_F(OpSelectScatterOutTest, SrcMatchNumelLackDimAtEndDies) {
  TensorFactory<ScalarType::Int> tf;

  // x.select(0, i) has shape {2, 2, 1}.
  Tensor x = tf.zeros({1, 2, 2, 1});

  // Same dtype and numel as that slice, but 2-D instead of 3-D.
  // src.dim() must always be exactly one lower than x.dim().
  Tensor src = tf.zeros({2, 2});

  Tensor out = tf.ones({1, 2, 2, 1});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_select_scatter_out(x, src, /*dim=*/0, /*index=*/0, out));
}
620
// src with the right numel but an extra leading dimension must be rejected.
TEST_F(OpSelectScatterOutTest, SrcMatchNumelExtraDimAtFrontDies) {
  TensorFactory<ScalarType::Int> tf;

  // x.select(0, i) has shape {2}.
  Tensor x = tf.zeros({2, 2});

  // Same dtype and numel as that slice, but 2-D instead of 1-D.
  // src.dim() must always be exactly one lower than x.dim().
  Tensor src = tf.zeros({1, 2});

  Tensor out = tf.ones({2, 2});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_select_scatter_out(x, src, /*dim=*/0, /*index=*/0, out));
}
633
// src with the right rank but a wrong dimension size must be rejected.
TEST_F(OpSelectScatterOutTest, SrcSizeMismatchDimDies) {
  TensorFactory<ScalarType::Int> tf;

  Tensor x = tf.zeros({2, 4, 7, 5});

  // Selecting on dim 2 of x gives a {2, 4, 5} slice, so this {2, 4, 7}
  // src has the correct rank but the wrong last size.
  Tensor src = tf.zeros({2, 4, 7});

  Tensor out = tf.zeros({2, 4, 7, 5});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_select_scatter_out(x, src, /*dim=*/2, /*index=*/3, out));
}
647
// The bounded-dynamic out tensor starts at exactly the expected shape.
TEST_F(OpSelectScatterOutTest, DynamicShapeUpperBoundSameAsExpected) {
  const std::vector<int32_t> upper_bound = {2, 3, 2};
  test_dynamic_shape(
      upper_bound, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
}
652
// The bounded-dynamic out tensor starts larger than the expected shape, so
// the kernel must shrink it. Skipped when output resizing is unsupported.
TEST_F(OpSelectScatterOutTest, DynamicShapeUpperBoundLargerThanExpected) {
  const bool resize_supported =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!resize_supported) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }
  const std::vector<int32_t> upper_bound = {10, 10, 10};
  test_dynamic_shape(
      upper_bound, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
}
660
// The unbounded-dynamic out tensor starts smaller than the expected shape,
// so the kernel must grow it. Skipped when output resizing is unsupported.
TEST_F(OpSelectScatterOutTest, DynamicShapeUnbound) {
  const bool resize_supported =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!resize_supported) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }
  const std::vector<int32_t> initial_shape = {1, 1, 1};
  test_dynamic_shape(
      initial_shape, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
}
668