1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
16
17 #include <gtest/gtest.h>
18
19 using namespace ::testing;
20 using exec_aten::ArrayRef;
21 using exec_aten::optional;
22 using exec_aten::ScalarType;
23 using exec_aten::Tensor;
24 using torch::executor::testing::TensorFactory;
25
26 class OpSliceScatterTensorOutTest : public OperatorTest {
27 protected:
op_slice_scatter_out(const Tensor & self,const Tensor & src,int64_t dim,optional<int64_t> start,optional<int64_t> end,int64_t step,Tensor & out)28 Tensor& op_slice_scatter_out(
29 const Tensor& self,
30 const Tensor& src,
31 int64_t dim,
32 optional<int64_t> start,
33 optional<int64_t> end,
34 int64_t step,
35 Tensor& out) {
36 return torch::executor::aten::slice_scatter_outf(
37 context_, self, src, dim, start, end, step, out);
38 }
39
40 template <class CTYPE, exec_aten::ScalarType DTYPE>
test_dtype()41 void test_dtype() {
42 TensorFactory<DTYPE> tf;
43
44 // clang-format off
45 Tensor input = tf.make(
46 /*sizes=*/{3, 4},
47 /*data=*/{
48 1, 2, 3, 4, // [0, :]
49 5, 6, 7, 8, // [1, :]
50 9, 10, 11, 12, // [2, :]
51 });
52
53 // op_slice_scatter_out(input, src, /*dim=*/0, /*start=*/0, /*end=*/2, /*step=*/1, out),
54 // src shape should equal to input[0:2:1, :]
55 Tensor src = tf.make(
56 /*sizes=*/{2, 4},
57 /*data=*/{
58 5, 6, 7, 8, // [0, :]
59 1, 2, 3, 4, // [1, :]
60 });
61 Tensor expect_ret = tf.make(
62 /*sizes=*/{3, 4},
63 /*data=*/{
64 5, 6, 7, 8, // [0, :]
65 1, 2, 3, 4, // [1, :]
66 9, 10, 11, 12, // [2, :]
67 });
68 // clang-format on
69
70 Tensor out = tf.zeros({3, 4});
71 Tensor ret = op_slice_scatter_out(
72 input, src, /*dim=*/0, /*start=*/0, /*end=*/2, /*step=*/1, out);
73
74 EXPECT_TENSOR_EQ(out, ret);
75 EXPECT_TENSOR_EQ(ret, expect_ret);
76 }
77 };
78
/// Scatters a one-element-wide slice (start=0, end=1, step=1) along every
/// legal dim of a {2, 3, 4} tensor, using both the negative (-3..-1) and the
/// non-negative (0..2) spelling of each dim.
TEST_F(OpSliceScatterTensorOutTest, LegalDimSupported) {
  TensorFactory<ScalarType::Double> tf;

  // clang-format off
  Tensor input = tf.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :]
      1.,   2.,   3.,   4., // [0, 0, :]
      5.,   6.,   7.,   8., // [0, 1, :]
      9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
     -1.,  -2.,  -3.,  -4., // [1, 0, :]
     -5.,  -6.,  -7.,  -8., // [1, 1, :]
     -9., -10., -11., -12., // [1, 2, :]
    });
  // clang-format on

  // clang-format off
  // The size of the src tensor should follow these rules:
  // - src.size(i) shall equal input.size(i) if i != dim,
  // - src.size(i) shall equal num_values if i == dim, where num_values is the
  //   number of indices selected by start:end:step along dim.

  // op_slice_scatter_out(input, src, /*dim=*/0, /*start=*/0, /*end=*/1, /*step=*/1, out),
  // src shape should equal to input[0:1:1, :, :]
  Tensor src_dim_0 = tf.make(
    /*sizes=*/{1, 3, 4},
    /*data=*/{
      8.,   7.,   6.,   5., // [0, 0, :]
      4.,   3.,   2.,   1., // [0, 1, :]
      1.,  14.,  18.,  19., // [0, 2, :]
    });
  Tensor expected_dim_0 = tf.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :] -- replaced by src_dim_0
      8.,   7.,   6.,   5., // [0, 0, :]
      4.,   3.,   2.,   1., // [0, 1, :]
      1.,  14.,  18.,  19., // [0, 2, :]

      // [1, :, :] -- unchanged
     -1.,  -2.,  -3.,  -4., // [1, 0, :]
     -5.,  -6.,  -7.,  -8., // [1, 1, :]
     -9., -10., -11., -12., // [1, 2, :]
    });
  // op_slice_scatter_out(input, src, /*dim=*/1, /*start=*/0, /*end=*/1, /*step=*/1, out),
  // src shape should equal to input[:, 0:1:1, :]
  Tensor src_dim_1 = tf.make(
    /*sizes=*/{2, 1, 4},
    /*data=*/{
      4.,   3.,   2.,   1., // [0, 0, :]
     -4.,  -3.,  -2.,  -1., // [1, 0, :]
    });
  Tensor expected_dim_1 = tf.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :]
      4.,   3.,   2.,   1., // [0, 0, :]
      5.,   6.,   7.,   8., // [0, 1, :]
      9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
     -4.,  -3.,  -2.,  -1., // [1, 0, :]
     -5.,  -6.,  -7.,  -8., // [1, 1, :]
     -9., -10., -11., -12., // [1, 2, :]
    });
  // op_slice_scatter_out(input, src, /*dim=*/2, /*start=*/0, /*end=*/1, /*step=*/1, out),
  // src shape should equal to input[:, :, 0:1:1]
  Tensor src_dim_2 = tf.make(
    /*sizes=*/{2, 3, 1},
    /*data=*/{
      7.,   1.,   6., // [0, :, 0]
     -5.,  -9.,  -2., // [1, :, 0]
    });
  Tensor expected_dim_2 = tf.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :]
      7.,   2.,   3.,   4., // [0, 0, :]
      1.,   6.,   7.,   8., // [0, 1, :]
      6.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
     -5.,  -2.,  -3.,  -4., // [1, 0, :]
     -9.,  -6.,  -7.,  -8., // [1, 1, :]
     -2., -10., -11., -12., // [1, 2, :]
    });
  // clang-format on
  std::vector<Tensor> src_tensors = {
      // Source tensor for dim=-3
      src_dim_0,
      // Source tensor for dim=-2
      src_dim_1,
      // Source tensor for dim=-1
      src_dim_2,
      // Source tensor for dim=0
      src_dim_0,
      // Source tensor for dim=1
      src_dim_1,
      // Source tensor for dim=2
      src_dim_2,
  };
  std::vector<Tensor> expected_rets = {
      // Ground truth for dim=-3
      expected_dim_0,
      // Ground truth for dim=-2
      expected_dim_1,
      // Ground truth for dim=-1
      expected_dim_2,
      // Ground truth for dim=0
      expected_dim_0,
      // Ground truth for dim=1
      expected_dim_1,
      // Ground truth for dim=2
      expected_dim_2,
  };

  for (int64_t dim = -3; dim < 3; dim++) {
    // dim runs over [-3, 3), so shifting by 3 yields the table index.
    const int64_t testcase_idx = dim + 3;
    const Tensor& src = src_tensors[testcase_idx];
    const Tensor& expected_ret = expected_rets[testcase_idx];

    Tensor out = tf.zeros_like(expected_ret);

    // Scatter src into the slice of input selected on dim with start = 0,
    // end = 1 and step = 1.
    // Should always return the provided out Tensor.
    // The ret shall meet the expectation.
    Tensor ret = op_slice_scatter_out(
        input, src, dim, /*start=*/0, /*end=*/1, /*step=*/1, out);
    EXPECT_TENSOR_EQ(out, ret);
    EXPECT_TENSOR_EQ(ret, expected_ret);
  }
}
214
/// Sweeps `start` over [-3, 3] on dim 1 of a {2, 3, 4} tensor with end=10 and
/// step=1, checking the scattered result for every value.
TEST_F(OpSliceScatterTensorOutTest, AllStartValsSupported) {
  TensorFactory<ScalarType::Double> factory;

  // clang-format off
  const Tensor base = factory.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :]
      1.,   2.,   3.,   4., // [0, 0, :]
      5.,   6.,   7.,   8., // [0, 1, :]
      9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
     -1.,  -2.,  -3.,  -4., // [1, 0, :]
     -5.,  -6.,  -7.,  -8., // [1, 1, :]
     -9., -10., -11., -12., // [1, 2, :]
    });
  // clang-format on

  // clang-format off
  // Shape rule for every src below:
  // - src.size(i) == base.size(i) for i != dim,
  // - src.size(dim) == number of indices selected by start:end:step.
  // `end` is fixed at 10, which is past the end of dim 1 for any start.

  // start <= 0 (with end=10, step=1): the slice is base[:, 0:3:1, :], i.e. the
  // whole tensor, so the result is src itself.
  const Tensor src_whole = factory.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :]
     -1.,  -2.,  -3.,  -4., // [0, 0, :]
     -5.,  -6.,  -7.,  -8., // [0, 1, :]
     -9., -10., -11., -12., // [0, 2, :]

      // [1, :, :]
      1.,   2.,   3.,   4., // [1, 0, :]
      5.,   6.,   7.,   8., // [1, 1, :]
      9.,  10.,  11.,  12., // [1, 2, :]
    });
  const Tensor want_whole = factory.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :]
     -1.,  -2.,  -3.,  -4., // [0, 0, :]
     -5.,  -6.,  -7.,  -8., // [0, 1, :]
     -9., -10., -11., -12., // [0, 2, :]

      // [1, :, :]
      1.,   2.,   3.,   4., // [1, 0, :]
      5.,   6.,   7.,   8., // [1, 1, :]
      9.,  10.,  11.,  12., // [1, 2, :]
    });

  // start = 1: the slice is base[:, 1:3:1, :].
  const Tensor src_from_1 = factory.make(
    /*sizes=*/{2, 2, 4},
    /*data=*/{
      // [0, :, :]
     -9., -10., -11., -12., // [0, 0, :]
     -5.,  -6.,  -7.,  -8., // [0, 1, :]

      // [1, :, :]
      9.,  10.,  11.,  12., // [1, 0, :]
      5.,   6.,   7.,   8., // [1, 1, :]
    });
  const Tensor want_from_1 = factory.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :]
      1.,   2.,   3.,   4., // [0, 0, :]
     -9., -10., -11., -12., // [0, 1, :]
     -5.,  -6.,  -7.,  -8., // [0, 2, :]

      // [1, :, :]
     -1.,  -2.,  -3.,  -4., // [1, 0, :]
      9.,  10.,  11.,  12., // [1, 1, :]
      5.,   6.,   7.,   8., // [1, 2, :]
    });

  // start = 2: the slice is base[:, 2:3:1, :], a single row per batch.
  const Tensor src_from_2 = factory.make(
    /*sizes=*/{2, 1, 4},
    /*data=*/{
      1.,  19.,  18.,  17., // [0, 0, :]
     -1., -19., -18., -17., // [1, 0, :]
    });
  const Tensor want_from_2 = factory.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :]
      1.,   2.,   3.,   4., // [0, 0, :]
      5.,   6.,   7.,   8., // [0, 1, :]
      1.,  19.,  18.,  17., // [0, 2, :]

      // [1, :, :]
     -1.,  -2.,  -3.,  -4., // [1, 0, :]
     -5.,  -6.,  -7.,  -8., // [1, 1, :]
     -1., -19., -18., -17., // [1, 2, :]
    });

  // start >= base.size(1) = 3: the slice is base[:, 3:3:1, :], which is empty,
  // so src is empty and the result equals base.
  const Tensor src_empty = factory.make({2, 0, 4}, {});
  const Tensor want_unchanged = factory.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :]
      1.,   2.,   3.,   4., // [0, 0, :]
      5.,   6.,   7.,   8., // [0, 1, :]
      9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
     -1.,  -2.,  -3.,  -4., // [1, 0, :]
     -5.,  -6.,  -7.,  -8., // [1, 1, :]
     -9., -10., -11., -12., // [1, 2, :]
    });
  // clang-format on

  // Entry i of both tables belongs to start = i - 3.
  const std::vector<Tensor> src_by_start = {
      src_whole, // start = -3
      src_from_1, // start = -2
      src_from_2, // start = -1
      src_whole, // start = 0
      src_from_1, // start = 1
      src_from_2, // start = 2
      src_empty, // start = 3
  };
  const std::vector<Tensor> want_by_start = {
      want_whole, // start = -3
      want_from_1, // start = -2
      want_from_2, // start = -1
      want_whole, // start = 0
      want_from_1, // start = 1
      want_from_2, // start = 2
      want_unchanged, // start = 3
  };

  // dim and step stay at 1; end is large enough to hold any start.
  const int64_t dim = 1;
  const int64_t end = 10;
  const int64_t step = 1;
  for (size_t case_idx = 0; case_idx < src_by_start.size(); ++case_idx) {
    const int64_t start = static_cast<int64_t>(case_idx) - 3;
    const Tensor& src = src_by_start[case_idx];
    const Tensor& want = want_by_start[case_idx];
    Tensor out = factory.zeros_like(want);

    // The call must hand back the provided out tensor, holding the expected
    // values.
    const Tensor& got =
        op_slice_scatter_out(base, src, dim, start, end, step, out);
    EXPECT_TENSOR_EQ(out, got);
    EXPECT_TENSOR_EQ(got, want);
  }
}
382
/// Sweeps `end` over [-3, 3] on dim 1 of a {2, 3, 4} tensor with start=0 and
/// step=1, checking the scattered result for every value.
TEST_F(OpSliceScatterTensorOutTest, AllEndValsSupported) {
  TensorFactory<ScalarType::Double> factory;

  // clang-format off
  const Tensor base = factory.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :]
      1.,   2.,   3.,   4., // [0, 0, :]
      5.,   6.,   7.,   8., // [0, 1, :]
      9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
     -1.,  -2.,  -3.,  -4., // [1, 0, :]
     -5.,  -6.,  -7.,  -8., // [1, 1, :]
     -9., -10., -11., -12., // [1, 2, :]
    });

  // Shape rule for every src below:
  // - src.size(i) == base.size(i) for i != dim,
  // - src.size(dim) == number of indices selected by start:end:step.

  // end <= 0 (with start=0, step=1): the slice is base[:, 0:0:1, :], which is
  // empty, so src is empty and the result equals base.
  const Tensor src_empty = factory.make({2, 0, 4}, {});
  const Tensor want_unchanged = factory.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :]
      1.,   2.,   3.,   4., // [0, 0, :]
      5.,   6.,   7.,   8., // [0, 1, :]
      9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
     -1.,  -2.,  -3.,  -4., // [1, 0, :]
     -5.,  -6.,  -7.,  -8., // [1, 1, :]
     -9., -10., -11., -12., // [1, 2, :]
    });

  // end = 1: the slice is base[:, 0:1:1, :].
  const Tensor src_to_1 = factory.make(
    /*sizes=*/{2, 1, 4},
    /*data=*/{
     -4.,  -3.,  -2.,  -1., // [0, 0, :]
      4.,   3.,   2.,   1., // [1, 0, :]
    });
  const Tensor want_to_1 = factory.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :]
     -4.,  -3.,  -2.,  -1., // [0, 0, :]
      5.,   6.,   7.,   8., // [0, 1, :]
      9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
      4.,   3.,   2.,   1., // [1, 0, :]
     -5.,  -6.,  -7.,  -8., // [1, 1, :]
     -9., -10., -11., -12., // [1, 2, :]
    });

  // end = 2: the slice is base[:, 0:2:1, :].
  const Tensor src_to_2 = factory.make(
    /*sizes=*/{2, 2, 4},
    /*data=*/{
      // [0, :, :]
     -8.,  -7.,  -6.,  -5., // [0, 0, :]
     -4.,  -3.,  -2.,  -1., // [0, 1, :]

      // [1, :, :]
      8.,   7.,   6.,   5., // [1, 0, :]
      4.,   3.,   2.,   1., // [1, 1, :]
    });
  const Tensor want_to_2 = factory.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :]
     -8.,  -7.,  -6.,  -5., // [0, 0, :]
     -4.,  -3.,  -2.,  -1., // [0, 1, :]
      9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
      8.,   7.,   6.,   5., // [1, 0, :]
      4.,   3.,   2.,   1., // [1, 1, :]
     -9., -10., -11., -12., // [1, 2, :]
    });

  // end >= 3: the slice is base[:, 0:3:1, :], i.e. the whole tensor, so the
  // result is src itself.
  const Tensor src_whole = factory.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :]
     -1.,  -2.,  -3.,  -4., // [0, 0, :]
     -5.,  -6.,  -7.,  -8., // [0, 1, :]
     -9., -10., -11., -12., // [0, 2, :]

      // [1, :, :]
      1.,   2.,   3.,   4., // [1, 0, :]
      5.,   6.,   7.,   8., // [1, 1, :]
      9.,  10.,  11.,  12., // [1, 2, :]
    });
  const Tensor want_whole = factory.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :]
     -1.,  -2.,  -3.,  -4., // [0, 0, :]
     -5.,  -6.,  -7.,  -8., // [0, 1, :]
     -9., -10., -11., -12., // [0, 2, :]

      // [1, :, :]
      1.,   2.,   3.,   4., // [1, 0, :]
      5.,   6.,   7.,   8., // [1, 1, :]
      9.,  10.,  11.,  12., // [1, 2, :]
    });
  // clang-format on

  // Entry i of both tables belongs to end = i - 3.
  const std::vector<Tensor> src_by_end = {
      src_empty, // end = -3
      src_to_1, // end = -2
      src_to_2, // end = -1
      src_empty, // end = 0
      src_to_1, // end = 1
      src_to_2, // end = 2
      src_whole, // end = 3
  };
  const std::vector<Tensor> want_by_end = {
      want_unchanged, // end = -3
      want_to_1, // end = -2
      want_to_2, // end = -1
      want_unchanged, // end = 0
      want_to_1, // end = 1
      want_to_2, // end = 2
      want_whole, // end = 3
  };

  const int64_t dim = 1;
  const int64_t start = 0;
  const int64_t step = 1;
  for (size_t case_idx = 0; case_idx < src_by_end.size(); ++case_idx) {
    const int64_t end = static_cast<int64_t>(case_idx) - 3;
    const Tensor& src = src_by_end[case_idx];
    const Tensor& want = want_by_end[case_idx];
    Tensor out = factory.zeros_like(want);

    // The call must hand back the provided out tensor, holding the expected
    // values.
    const Tensor& got =
        op_slice_scatter_out(base, src, dim, start, end, step, out);
    EXPECT_TENSOR_EQ(out, got);
    EXPECT_TENSOR_EQ(got, want);
  }
}
548
/// Sweeps `step` over {1, 2, 3} on dim 1 of a {2, 3, 4} tensor with start=0
/// and end=10, checking the scattered result for every value.
TEST_F(OpSliceScatterTensorOutTest, LegalStepsSupported) {
  TensorFactory<ScalarType::Double> factory;

  // clang-format off
  const Tensor base = factory.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :]
      1.,   2.,   3.,   4., // [0, 0, :]
      5.,   6.,   7.,   8., // [0, 1, :]
      9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
     -1.,  -2.,  -3.,  -4., // [1, 0, :]
     -5.,  -6.,  -7.,  -8., // [1, 1, :]
     -9., -10., -11., -12., // [1, 2, :]
    });

  // `end` is fixed at 10, which is past the end of dim 1 for any step.

  // step = 1: the slice is base[:, 0:3:1, :], i.e. every row, so the result
  // is src itself.
  const Tensor src_step_1 = factory.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :]
     -1.,  -2.,  -3.,  -4., // [0, 0, :]
     -5.,  -6.,  -7.,  -8., // [0, 1, :]
     -9., -10., -11., -12., // [0, 2, :]

      // [1, :, :]
      1.,   2.,   3.,   4., // [1, 0, :]
      5.,   6.,   7.,   8., // [1, 1, :]
      9.,  10.,  11.,  12., // [1, 2, :]
    });
  const Tensor want_step_1 = factory.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :]
     -1.,  -2.,  -3.,  -4., // [0, 0, :]
     -5.,  -6.,  -7.,  -8., // [0, 1, :]
     -9., -10., -11., -12., // [0, 2, :]

      // [1, :, :]
      1.,   2.,   3.,   4., // [1, 0, :]
      5.,   6.,   7.,   8., // [1, 1, :]
      9.,  10.,  11.,  12., // [1, 2, :]
    });

  // step = 2: the slice is base[:, 0:3:2, :], i.e. rows 0 and 2.
  const Tensor src_step_2 = factory.make(
    /*sizes=*/{2, 2, 4},
    /*data=*/{
      // [0, :, :]
     -1.,  -2.,  -3.,  -4., // [0, 0, :]
     -9., -10., -11., -12., // [0, 1, :]

      // [1, :, :]
      1.,   2.,   3.,   4., // [1, 0, :]
      9.,  10.,  11.,  12., // [1, 1, :]
    });
  const Tensor want_step_2 = factory.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :]
     -1.,  -2.,  -3.,  -4., // [0, 0, :]
      5.,   6.,   7.,   8., // [0, 1, :]
     -9., -10., -11., -12., // [0, 2, :]

      // [1, :, :]
      1.,   2.,   3.,   4., // [1, 0, :]
     -5.,  -6.,  -7.,  -8., // [1, 1, :]
      9.,  10.,  11.,  12., // [1, 2, :]
    });

  // step = 3: the slice is base[:, 0:3:3, :], i.e. row 0 only.
  const Tensor src_step_3 = factory.make(
    /*sizes=*/{2, 1, 4},
    /*data=*/{
     -1.,  -2.,  -3.,  -4., // [0, 0, :]
      1.,   2.,   3.,   4., // [1, 0, :]
    });
  const Tensor want_step_3 = factory.make(
    /*sizes=*/{2, 3, 4},
    /*data=*/{
      // [0, :, :]
     -1.,  -2.,  -3.,  -4., // [0, 0, :]
      5.,   6.,   7.,   8., // [0, 1, :]
      9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
      1.,   2.,   3.,   4., // [1, 0, :]
     -5.,  -6.,  -7.,  -8., // [1, 1, :]
     -9., -10., -11., -12., // [1, 2, :]
    });
  // clang-format on

  // Entry i of both tables belongs to step = i + 1.
  const std::vector<Tensor> src_by_step = {src_step_1, src_step_2, src_step_3};
  const std::vector<Tensor> want_by_step = {
      want_step_1, want_step_2, want_step_3};

  // start and dim stay at 0 and 1; end is large enough to hold any step.
  const int64_t start = 0;
  const int64_t dim = 1;
  const int64_t end = 10;
  for (size_t case_idx = 0; case_idx < src_by_step.size(); ++case_idx) {
    const int64_t step = static_cast<int64_t>(case_idx) + 1;
    const Tensor& src = src_by_step[case_idx];
    const Tensor& want = want_by_step[case_idx];
    Tensor out = factory.zeros_like(want);

    // The call must hand back the provided out tensor, holding the expected
    // values.
    const Tensor& got =
        op_slice_scatter_out(base, src, dim, start, end, step, out);
    EXPECT_TENSOR_EQ(out, got);
    EXPECT_TENSOR_EQ(got, want);
  }
}
668
669 /// A generic smoke test that works for any dtype that supports ones() and
670 /// zeros().
TEST_F(OpSliceScatterTensorOutTest,AllRealDtypesSupported)671 TEST_F(OpSliceScatterTensorOutTest, AllRealDtypesSupported) {
672 #define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
673 ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
674 #undef TEST_ENTRY
675 // TODO: Also add tests for half, complex, quantized, and other types. Easiest
676 // way to do that would be to make TensorFactory support zeros() and ones()
677 // for those types.
678 }
679
/// Scattering into a tensor with zero elements ({1, 0, 1}) must succeed along
/// every non-negative dim and leave an (empty) result equal to the input.
TEST_F(OpSliceScatterTensorOutTest, EmptyInputSupported) {
  TensorFactory<ScalarType::Int> tf;

  Tensor input = tf.ones({1, 0, 1});
  Tensor src = tf.zeros({1, 0, 1});
  Tensor out = tf.zeros({1, 0, 1});

  Tensor expect = tf.ones({1, 0, 1});

  // Every valid non-negative dim. The previous condition (`dim > input.dim()`)
  // was false on the first check, so the kernel was never invoked.
  // With start=0 and end=1 the selected slice has shape {1, 0, 1} for each
  // dim (dim 1 has size 0, so its slice is empty), which matches src.
  for (int64_t dim = 0; dim < input.dim(); dim++) {
    Tensor ret = op_slice_scatter_out(
        input, src, dim, /*start=*/0, /*end=*/1, /*step=*/1, out);
    EXPECT_TENSOR_EQ(ret, out);

    // All operations in this test share same ground truth
    EXPECT_TENSOR_EQ(ret, expect);
  }
}
699
/// A zero-dimensional tensor has no dim to slice, so the kernel is expected
/// to fail regardless of the `end` value.
TEST_F(OpSliceScatterTensorOutTest, EmptySizeInputDies) {
  TensorFactory<ScalarType::Int> factory;

  Tensor scalar_input = factory.ones({});
  Tensor scalar_src = factory.ones({});
  Tensor scalar_out = factory.ones({});

  // Same call twice, first with end=0 and then with end=1.
  constexpr int64_t kEnds[] = {0, 1};
  for (const int64_t end : kEnds) {
    ET_EXPECT_KERNEL_FAILURE(
        context_,
        op_slice_scatter_out(
            scalar_input,
            scalar_src,
            /*dim=*/0,
            /*start=*/0,
            end,
            /*step=*/1,
            scalar_out));
  }
}
717
/// The kernel is expected to fail for every step value that is zero or
/// negative.
TEST_F(OpSliceScatterTensorOutTest, NonPostiveStepsDies) {
  TensorFactory<ScalarType::Int> factory;

  Tensor base = factory.ones({1, 1, 1});
  Tensor replacement = factory.zeros({1, 1, 1});
  Tensor out = factory.zeros({1, 1, 1});

  // Steps that are not strictly positive.
  constexpr int64_t kBadSteps[] = {-2, -1, 0};
  for (const int64_t bad_step : kBadSteps) {
    ET_EXPECT_KERNEL_FAILURE(
        context_,
        op_slice_scatter_out(
            base,
            replacement,
            /*dim=*/0,
            /*start=*/0,
            /*end=*/1,
            /*step=*/bad_step,
            out));
  }
}
734
/// For a 3-d tensor the legal dims are [-3, 2]; the kernel is expected to
/// fail for values on either side of that range.
TEST_F(OpSliceScatterTensorOutTest, DimOutOfBoundDies) {
  TensorFactory<ScalarType::Int> factory;

  Tensor base = factory.ones({1, 1, 1});
  Tensor replacement = factory.zeros({1, 1, 1});
  Tensor out = factory.zeros({1, 1, 1});

  // Three dims above the legal range followed by three below it.
  constexpr int64_t kBadDims[] = {3, 4, 5, -4, -5, -6};
  for (const int64_t bad_dim : kBadDims) {
    ET_EXPECT_KERNEL_FAILURE(
        context_,
        op_slice_scatter_out(
            base,
            replacement,
            bad_dim,
            /*start=*/0,
            /*end=*/1,
            /*step=*/1,
            out));
  }
}
751
/// The kernel is expected to fail when `out` has a different dtype (Float)
/// than the input and src (Int), even though the shapes agree.
TEST_F(OpSliceScatterTensorOutTest, MismatchedOutDtypesDies) {
  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Float> float_factory;

  Tensor base = int_factory.zeros({1, 2, 2});
  Tensor replacement = int_factory.zeros({1, 2, 2});

  // Same sizes as the input, but the wrong element type.
  Tensor float_out = float_factory.ones({1, 2, 2});

  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_slice_scatter_out(
          base,
          replacement,
          /*dim=*/0,
          /*start=*/0,
          /*end=*/1,
          /*step=*/1,
          float_out));
}
766
/// The kernel is expected to fail when `out` has fewer dimensions than the
/// input. Skipped when running against the ATen kernel.
TEST_F(OpSliceScatterTensorOutTest, OutSizeMismatchDimDies) {
  const bool running_aten =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (running_aten) {
    GTEST_SKIP() << "ATen kernel can handle out with mismatched dimensions";
  }
  TensorFactory<ScalarType::Int> factory;

  Tensor base = factory.zeros({2, 4, 7, 5});
  Tensor replacement = factory.zeros({2, 4, 7, 5});

  // A matching out would be {2, 4, 7, 5}; this one lacks the last dimension.
  Tensor short_out = factory.zeros({2, 4, 7});

  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_slice_scatter_out(
          base,
          replacement,
          /*dim=*/0,
          /*start=*/0,
          /*end=*/2,
          /*step=*/1,
          short_out));
}
784
/// The kernel is expected to fail when `src` has fewer dimensions than the
/// slice it is scattered into. Skipped when running against the ATen kernel.
TEST_F(OpSliceScatterTensorOutTest, SrcSizeMismatchDimDies) {
  const bool running_aten =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (running_aten) {
    // NOTE(review): this message talks about `out`, while the mismatched
    // tensor in this test is `src` -- confirm the intended wording.
    GTEST_SKIP() << "ATen kernel can handle out with mismatched dimensions";
  }
  TensorFactory<ScalarType::Int> factory;

  Tensor base = factory.zeros({2, 4, 7, 5});

  // A matching src would be {2, 4, 7, 5}; this one lacks the last dimension.
  Tensor short_src = factory.zeros({2, 4, 7});

  Tensor out = factory.zeros({2, 4, 7, 5});

  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_slice_scatter_out(
          base,
          short_src,
          /*dim=*/0,
          /*start=*/0,
          /*end=*/2,
          /*step=*/1,
          out));
}
802
/// Passing nullopt as `start` with end=2 on a dim of size 2 is expected to
/// overwrite the whole all-zero input with the all-one src.
TEST_F(OpSliceScatterTensorOutTest, DefaultStartValSupported) {
  TensorFactory<ScalarType::Int> factory;

  const Tensor base = factory.zeros({2, 4, 7, 5});
  const Tensor replacement = factory.ones({2, 4, 7, 5});

  Tensor out = factory.zeros({2, 4, 7, 5});
  const Tensor want = factory.ones({2, 4, 7, 5});

  const optional<int64_t> no_start = exec_aten::nullopt;
  const Tensor& got = op_slice_scatter_out(
      base, replacement, /*dim=*/0, no_start, /*end=*/2, /*step=*/1, out);

  EXPECT_TENSOR_EQ(got, out);
  EXPECT_TENSOR_EQ(got, want);
}
823
/// Passing nullopt as `end` with start=0 is expected to overwrite the whole
/// all-zero input with the all-one src.
TEST_F(OpSliceScatterTensorOutTest, DefaultEndValSupported) {
  TensorFactory<ScalarType::Int> factory;

  const Tensor base = factory.zeros({2, 4, 7, 5});
  const Tensor replacement = factory.ones({2, 4, 7, 5});

  Tensor out = factory.zeros({2, 4, 7, 5});
  const Tensor want = factory.ones({2, 4, 7, 5});

  const optional<int64_t> no_end = exec_aten::nullopt;
  const Tensor& got = op_slice_scatter_out(
      base, replacement, /*dim=*/0, /*start=*/0, no_end, /*step=*/1, out);

  EXPECT_TENSOR_EQ(got, out);
  EXPECT_TENSOR_EQ(got, want);
}
844
/// `out` is created as {1, 2, 8} with DYNAMIC_BOUND dynamism; after the call
/// it is expected to compare equal to a {1, 4, 4} tensor of ones.
TEST_F(OpSliceScatterTensorOutTest, DynamicShapeTest) {
  TensorFactory<ScalarType::Int> factory;

  const Tensor base = factory.zeros({1, 4, 4});
  const Tensor replacement = factory.ones({1, 4, 4});

  // Same element count as the input but a different initial shape.
  Tensor resizable_out = factory.zeros(
      {1, 2, 8}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
  const Tensor want = factory.ones({1, 4, 4});

  const optional<int64_t> no_end = exec_aten::nullopt;
  const Tensor& got = op_slice_scatter_out(
      base,
      replacement,
      /*dim=*/0,
      /*start=*/0,
      no_end,
      /*step=*/1,
      resizable_out);

  EXPECT_TENSOR_EQ(got, resizable_out);
  EXPECT_TENSOR_EQ(got, want);
}
866
/// An `end` equal to the largest int64_t value is expected to behave like
/// "up to the end of the dim": the whole all-zero input is overwritten by the
/// all-one src.
TEST_F(OpSliceScatterTensorOutTest, LargeEndValue) {
  TensorFactory<ScalarType::Int> factory;

  const Tensor base = factory.zeros({1, 1, 2, 5, 3, 3});
  const Tensor replacement = factory.ones({1, 1, 2, 5, 3, 3});

  Tensor out = factory.zeros({1, 1, 2, 5, 3, 3});
  const Tensor want = factory.ones({1, 1, 2, 5, 3, 3});

  // 2^63 - 1, the maximum value representable in int64_t.
  const int64_t huge_end = 9223372036854775807;
  const Tensor& got = op_slice_scatter_out(
      base, replacement, /*dim=*/1, /*start=*/0, huge_end, /*step=*/1, out);

  EXPECT_TENSOR_EQ(got, out);
  EXPECT_TENSOR_EQ(got, want);
}
887