xref: /aosp_15_r20/external/executorch/kernels/test/op_slice_scatter_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
16 
17 #include <gtest/gtest.h>
18 
19 using namespace ::testing;
20 using exec_aten::ArrayRef;
21 using exec_aten::optional;
22 using exec_aten::ScalarType;
23 using exec_aten::Tensor;
24 using torch::executor::testing::TensorFactory;
25 
26 class OpSliceScatterTensorOutTest : public OperatorTest {
27  protected:
op_slice_scatter_out(const Tensor & self,const Tensor & src,int64_t dim,optional<int64_t> start,optional<int64_t> end,int64_t step,Tensor & out)28   Tensor& op_slice_scatter_out(
29       const Tensor& self,
30       const Tensor& src,
31       int64_t dim,
32       optional<int64_t> start,
33       optional<int64_t> end,
34       int64_t step,
35       Tensor& out) {
36     return torch::executor::aten::slice_scatter_outf(
37         context_, self, src, dim, start, end, step, out);
38   }
39 
40   template <class CTYPE, exec_aten::ScalarType DTYPE>
test_dtype()41   void test_dtype() {
42     TensorFactory<DTYPE> tf;
43 
44     // clang-format off
45     Tensor input = tf.make(
46       /*sizes=*/{3, 4},
47       /*data=*/{
48         1,   2,   3,   4, // [0, :]
49         5,   6,   7,   8, // [1, :]
50         9,  10,  11,  12, // [2, :]
51       });
52 
53     // op_slice_scatter_out(input, src, /*dim=*/0, /*start=*/0, /*end=*/2, /*step=*/1, out),
54     // src shape should equal to input[0:2:1, :]
55     Tensor src = tf.make(
56       /*sizes=*/{2, 4},
57       /*data=*/{
58         5,   6,   7,   8, // [0, :]
59         1,   2,   3,   4, // [1, :]
60       });
61     Tensor expect_ret = tf.make(
62       /*sizes=*/{3, 4},
63       /*data=*/{
64         5,   6,   7,   8, // [0, :]
65         1,   2,   3,   4, // [1, :]
66         9,  10,  11,  12, // [2, :]
67       });
68     // clang-format on
69 
70     Tensor out = tf.zeros({3, 4});
71     Tensor ret = op_slice_scatter_out(
72         input, src, /*dim=*/0, /*start=*/0, /*end=*/2, /*step=*/1, out);
73 
74     EXPECT_TENSOR_EQ(out, ret);
75     EXPECT_TENSOR_EQ(ret, expect_ret);
76   }
77 };
78 
TEST_F(OpSliceScatterTensorOutTest,LegalDimSupported)79 TEST_F(OpSliceScatterTensorOutTest, LegalDimSupported) {
80   TensorFactory<ScalarType::Double> tf;
81 
82   // clang-format off
83   Tensor input = tf.make(
84     /*sizes=*/{2, 3, 4},
85     /*data=*/{
86       // [0, :, :]
87        1.,   2.,   3.,   4., // [0, 0, :]
88        5.,   6.,   7.,   8., // [0, 1, :]
89        9.,  10.,  11.,  12., // [0, 2, :]
90 
91       // [1, :, :]
92       -1.,  -2.,  -3.,  -4., // [1, 0, :]
93       -5.,  -6.,  -7.,  -8., // [1, 1, :]
94       -9., -10., -11., -12., // [1, 2, :]
95     });
96   // clang-format on
97 
98   // clang-format off
99   // The size of the src tensor should follow these rules:
100   // - src.size(i) shall equal input.size(i) if i != dim,
101   // - src.size(i) shall equal num_values if i == dim
102   //   The definition of num_values could be found at https://fburl.com/code/mnnxkowm
103 
104   // op_slice_scatter_out(input, src, /*dim=*/0, /*start=*/0, /*end=*/1, /*step=*/1, out),
105   // src shape should equal to input[0:1:1,:, :]
106   Tensor src_dim_0 = tf.make(
107     /*sizes=*/{1, 3, 4},
108     /*data=*/{
109       8.,   7.,   6.,   5., // [1, :]
110       4.,   3.,   2.,   1., // [0, :]
111       1.,  14.,  18.,  19., // [2, :]
112     });
113   Tensor expected_dim_0 = tf.make(
114     /*sizes=*/{2, 3, 4},
115     /*data=*/{
116       // [0, :, :]
117        8.,   7.,   6.,   5., // [0, 1, :]
118        4.,   3.,   2.,   1., // [0, 0, :]
119        1.,  14.,  18.,  19., // [0, 2, :]
120 
121       // [1, :, :]
122       -1.,  -2.,  -3.,  -4., // [1, 0, :]
123       -5.,  -6.,  -7.,  -8., // [1, 1, :]
124       -9., -10., -11., -12., // [1, 2, :]
125     });
126   // op_slice_scatter_out(input, src, /*dim=*/1, /*start=*/0, /*end=*/1, /*step=*/1, out),
127   // src shape should equal to input[:,0:1:1, :]
128   Tensor src_dim_1 = tf.make(
129     /*sizes=*/{2, 1, 4},
130     /*data=*/{
131        4.,   3.,   2.,   1., // [0, :, :]
132       -4.,  -3.,  -2.,  -1., // [1, :, :]
133     });
134   Tensor expected_dim_1 = tf.make(
135     /*sizes=*/{2, 3, 4},
136     /*data=*/{
137       // [0, :, :]
138        4.,   3.,   2.,   1., // [0, 0, :]
139        5.,   6.,   7.,   8., // [0, 1, :]
140        9.,  10.,  11.,  12., // [0, 2, :]
141 
142       // [1, :, :]
143       -4.,  -3.,  -2.,  -1., // [1, 0, :]
144       -5.,  -6.,  -7.,  -8., // [1, 1, :]
145       -9., -10., -11., -12., // [1, 2, :]
146     });
147   // op_slice_scatter_out(input, src, /*dim=*/2, /*start=*/0, /*end=*/1, /*step=*/1, out),
148   // src shape should equal to input[:,:, 0:1:1]
149   Tensor src_dim_2 = tf.make(
150     /*sizes=*/{2, 3, 1},
151     /*data=*/{
152        7.,   1.,   6., // [0, :, :]
153       -5.,  -9.,  -2., // [1, :, :]
154     });
155   Tensor expected_dim_2 = tf.make(
156     /*sizes=*/{2, 3, 4},
157     /*data=*/{
158       // [0, :, :]
159        7.,   2.,   3.,   4., // [0, 0, :]
160        1.,   6.,   7.,   8., // [0, 1, :]
161        6.,  10.,  11.,  12., // [0, 2, :]
162 
163       // [1, :, :]
164       -5.,  -2.,  -3.,  -4., // [1, 0, :]
165       -9.,  -6.,  -7.,  -8., // [1, 1, :]
166       -2., -10., -11., -12., // [1, 2, :]
167     });
168   // clang-format on
169   std::vector<Tensor> src_tensors = {
170       // Source tensor for dim=-3
171       src_dim_0,
172       // Source tensor for dim=-2
173       src_dim_1,
174       // Source tensor for dim=-1
175       src_dim_2,
176       // Source tensor for dim=0
177       src_dim_0,
178       // Source tensor for dim=1
179       src_dim_1,
180       // Source tensor for dim=2
181       src_dim_2,
182   };
183   std::vector<Tensor> expected_rets = {
184       // Ground truth for dim=-3
185       expected_dim_0,
186       // Ground truth for dim=-2
187       expected_dim_1,
188       // Ground truth for dim=-1
189       expected_dim_2,
190       // Ground truth for dim=0
191       expected_dim_0,
192       // Ground truth for dim=1
193       expected_dim_1,
194       // Ground truth for dim=2
195       expected_dim_2,
196   };
197 
198   for (int64_t dim = -3; dim < 3; dim++) {
199     int64_t testcase_idx = dim + 3;
200     auto src = src_tensors[testcase_idx];
201     auto expected_ret = expected_rets[testcase_idx];
202 
203     Tensor out = tf.zeros_like(expected_ret);
204 
205     // Slice input on dim with start=0, end = 0 and step = 1
206     // Should always return the provided out Tensor.
207     // The ret shall meet the expectation.
208     Tensor ret = op_slice_scatter_out(
209         input, src, dim, /*start=*/0, /*end=*/1, /*step=*/1, out);
210     EXPECT_TENSOR_EQ(out, ret);
211     EXPECT_TENSOR_EQ(ret, expected_rets[testcase_idx]);
212   }
213 }
214 
TEST_F(OpSliceScatterTensorOutTest,AllStartValsSupported)215 TEST_F(OpSliceScatterTensorOutTest, AllStartValsSupported) {
216   TensorFactory<ScalarType::Double> tf;
217 
218   // clang-format off
219   Tensor input = tf.make(
220     /*sizes=*/{2, 3, 4},
221     /*data=*/{
222       // [0, :, :]
223        1.,   2.,   3.,   4., // [0, 0, :]
224        5.,   6.,   7.,   8., // [0, 1, :]
225        9.,  10.,  11.,  12., // [0, 2, :]
226 
227       // [1, :, :]
228       -1.,  -2.,  -3.,  -4., // [1, 0, :]
229       -5.,  -6.,  -7.,  -8., // [1, 1, :]
230       -9., -10., -11., -12., // [1, 2, :]
231     });
232   // clang-format on
233 
234   // clang-format off
235   // Set the end large enough to hold any start
236 
237   // The size of the src tensor should follow these rules:
238   // - src.size(i) shall equal input.size(i) if i != dim,
239   // - src.size(i) shall equal num_values if i == dim
240   //   The definition of num_values could be found at https://fburl.com/code/mnnxkowm
241 
242   // op_slice_scatter_out(input, src, /*dim=*/1, /*start=*/ <= 0, /*end=*/10, /*step=*/1, out),
243   // src shape shall equal to input[:,0:3:1, :]
244   Tensor src_start_0_or_below = tf.make(
245     /*sizes=*/{2, 3, 4},
246     /*data=*/{
247       // [0, :, :]
248       -1.,  -2.,  -3.,  -4., // [0, 0, :]
249       -5.,  -6.,  -7.,  -8., // [0, 1, :]
250       -9., -10., -11., -12., // [0, 2, :]
251 
252       // [1, :, :]
253        1.,   2.,   3.,   4., // [1, 0, :]
254        5.,   6.,   7.,   8., // [1, 1, :]
255        9.,  10.,  11.,  12., // [1, 2, :]
256     });
257   Tensor expected_start_0_or_below = tf.make(
258     /*sizes=*/{2, 3, 4},
259     /*data=*/{
260       // [0, :, :]
261       -1.,  -2.,  -3.,  -4., // [0, 0, :]
262       -5.,  -6.,  -7.,  -8., // [0, 1, :]
263       -9., -10., -11., -12., // [0, 2, :]
264 
265       // [1, :, :]
266        1.,   2.,   3.,   4., // [1, 0, :]
267        5.,   6.,   7.,   8., // [1, 1, :]
268        9.,  10.,  11.,  12., // [1, 2, :]
269     });
270   // op_slice_scatter_out(input, src, /*dim=*/1, /*start=*/1, /*end=*/10, /*step=*/1, out),
271   // src shape shall equal to input[:,1:3:1, :]
272   Tensor src_start_1 = tf.make(
273     /*sizes=*/{2, 2, 4},
274     /*data=*/{
275       // [0, :, :]
276       -9., -10., -11., -12., // [0, 1, :]
277       -5.,  -6.,  -7.,  -8., // [0, 0, :]
278 
279       // [1, :, :]
280        9.,  10.,  11.,  12., // [1, 1, :]
281        5.,   6.,   7.,   8., // [1, 0, :]
282     });
283   Tensor expected_start_1 = tf.make(
284     /*sizes=*/{2, 3, 4},
285     /*data=*/{
286       // [0, :, :]
287        1.,   2.,   3.,   4., // [0, 0, :]
288       -9., -10., -11., -12., // [0, 1, :]
289       -5.,  -6.,  -7.,  -8., // [0, 2, :]
290 
291       // [1, :, :]
292       -1.,  -2.,  -3.,  -4., // [1, 0, :]
293        9.,  10.,  11.,  12., // [1, 1, :]
294        5.,   6.,   7.,   8., // [1, 0, :]
295     });
296   // op_slice_scatter_out(input, src, /*dim=*/1, /*start=*/2, /*end=*/10, /*step=*/1, out),
297   // src shape shall equal to input[:,2:3:1, :] = input
298   Tensor src_start_2 = tf.make(
299     /*sizes=*/{2, 1, 4},
300     /*data=*/{
301        1.,  19.,  18.,  17., // [0, 0, :]
302       -1., -19., -18., -17., // [1, 0, :]
303     });
304   Tensor expected_start_2 = tf.make(
305     /*sizes=*/{2, 3, 4},
306     /*data=*/{
307       // [0, :, :]
308        1.,   2.,   3.,   4., // [0, 0, :]
309        5.,   6.,   7.,   8., // [0, 1, :]
310        1.,  19.,  18.,  17., // [0, 2, :]
311 
312       // [1, :, :]
313       -1.,  -2.,  -3.,  -4., // [1, 0, :]
314       -5.,  -6.,  -7.,  -8., // [1, 1, :]
315       -1., -19., -18., -17., // [1, 2, :]
316     });
317   // op_slice_scatter_out(input, src, /*dim=*/1, /*start=*/ > input.size(1) = 2, /*end=*/10, /*step=*/1, out),
318   // src_shape shall equal to input[:, 3:3:1, :], which is an empty tensor
319   Tensor src_start_3_or_above = tf.make({2, 0, 4}, {});
320   Tensor expected_start_3_or_above = tf.make(
321     /*sizes=*/{2, 3, 4},
322     /*data=*/{
323       // [0, :, :]
324       1.,   2.,   3.,   4., // [0, 0, :]
325       5.,   6.,   7.,   8., // [0, 1, :]
326       9.,  10.,  11.,  12., // [0, 2, :]
327 
328       // [1, :, :]
329       -1.,  -2.,  -3.,  -4., // [1, 0, :]
330       -5.,  -6.,  -7.,  -8., // [1, 1, :]
331       -9., -10., -11., -12., // [1, 2, :]
332     });
333   // clang-format on
334 
335   std::vector<Tensor> src_tensors = {// start = -3
336                                      src_start_0_or_below,
337                                      // start = -2
338                                      src_start_1,
339                                      // start = -1
340                                      src_start_2,
341                                      // start = 0
342                                      src_start_0_or_below,
343                                      // start = 1
344                                      src_start_1,
345                                      // start = 2
346                                      src_start_2,
347                                      // start = 3
348                                      src_start_3_or_above};
349   std::vector<Tensor> expected_rets = {// start = -3
350                                        expected_start_0_or_below,
351                                        // start = -2
352                                        expected_start_1,
353                                        // start = -1
354                                        expected_start_2,
355                                        // start = 0
356                                        expected_start_0_or_below,
357                                        // start = 1
358                                        expected_start_1,
359                                        // start = 2
360                                        expected_start_2,
361                                        // start = 3
362                                        expected_start_3_or_above};
363 
364   // In this test, we maintain dim and step as 1 and 1, also set the end
365   // large enough to hold any start
366   int64_t dim = 1;
367   int64_t end = 10;
368   int64_t step = 1;
369   for (int64_t start = -3; start < 4; start++) {
370     int64_t testcase_idx = start + 3;
371     auto src = src_tensors[testcase_idx];
372     auto expected_ret = expected_rets[testcase_idx];
373     Tensor out = tf.zeros_like(expected_ret);
374 
375     // Should always return the provided out Tensor.
376     // The ret shall meet the expectation.
377     Tensor ret = op_slice_scatter_out(input, src, dim, start, end, step, out);
378     EXPECT_TENSOR_EQ(out, ret);
379     EXPECT_TENSOR_EQ(ret, expected_ret);
380   }
381 }
382 
TEST_F(OpSliceScatterTensorOutTest,AllEndValsSupported)383 TEST_F(OpSliceScatterTensorOutTest, AllEndValsSupported) {
384   TensorFactory<ScalarType::Double> tf;
385 
386   // clang-format off
387   Tensor input = tf.make(
388     /*sizes=*/{2, 3, 4},
389     /*data=*/{
390       // [0, :, :]
391        1.,   2.,   3.,   4., // [0, 0, :]
392        5.,   6.,   7.,   8., // [0, 1, :]
393        9.,  10.,  11.,  12., // [0, 2, :]
394 
395       // [1, :, :]
396       -1.,  -2.,  -3.,  -4., // [1, 0, :]
397       -5.,  -6.,  -7.,  -8., // [1, 1, :]
398       -9., -10., -11., -12., // [1, 2, :]
399     });
400 
401   // The size of expected output tensor should follow these rules:
402   // - output.size(i) shall equal input.size(i) if i != dim,
403   // - output.size(i) shall equal num_values if i == dim
404   //   The definition of num_values could be found at https://fburl.com/code/mnnxkowm
405 
406   // op_slice_scatter_out(input, src, /*dim=*/1, /*start=*/0, /*end=*/ <= 0, /*step=*/1, out),
407   // src shape should equal input[:,0:0:1, :], which should be an empty tensor
408   Tensor src_end_0_or_below = tf.make({2, 0, 4}, {});
409   Tensor expected_end_0_or_below = tf.make(
410     /*sizes=*/{2, 3, 4},
411     /*data=*/{
412       // [0, :, :]
413        1.,   2.,   3.,   4., // [0, 0, :]
414        5.,   6.,   7.,   8., // [0, 1, :]
415        9.,  10.,  11.,  12., // [0, 2, :]
416 
417       // [1, :, :]
418       -1.,  -2.,  -3.,  -4., // [1, 0, :]
419       -5.,  -6.,  -7.,  -8., // [1, 1, :]
420       -9., -10., -11., -12., // [1, 2, :]
421     });
422 
423   // op_slice_scatter_out(input, src, /*dim=*/1, /*start=*/0, /*end=*/1, /*step=*/1, out),
424   // src shape should equal to input[:,0:1:1, :]
425   Tensor src_end_1 = tf.make(
426     /*sizes=*/{2, 1, 4},
427     /*data=*/{
428       -4.,  -3.,  -2.,  -1., // [0, :, :]
429        4.,   3.,   2.,   1., // [1, :, :]
430     });
431   Tensor expected_end_1 = tf.make(
432     /*sizes=*/{2, 3, 4},
433     /*data=*/{
434       // [0, :, :]
435       -4.,  -3.,  -2.,  -1., // [0, 0, :]
436        5.,   6.,   7.,   8., // [0, 1, :]
437        9.,  10.,  11.,  12., // [0, 2, :]
438 
439       // [1, :, :]
440        4.,   3.,   2.,   1., // [1, 0, :]
441       -5.,  -6.,  -7.,  -8., // [1, 1, :]
442       -9., -10., -11., -12., // [1, 2, :]
443     });
444 
445   // op_slice_scatter_out(input, src, /*dim=*/1, /*start=*/0, /*end=*/2, /*step=*/1, out),
446   // src shape should equal input[:,0:2:1, :]
447   Tensor src_end_2 = tf.make(
448     /*sizes=*/{2, 2, 4},
449     /*data=*/{
450       // [0, :, :]
451       -8.,  -7.,  -6.,  -5., // [0, 0, :]
452       -4.,  -3.,  -2.,  -1., // [0, :, :]
453 
454       // [1, :, :]
455        8.,   7.,   6.,   5., // [1, 0, :]
456        4.,   3.,   2.,   1., // [1, 1, :]
457     });
458   Tensor expected_end_2 = tf.make(
459     /*sizes=*/{2, 3, 4},
460     /*data=*/{
461       // [0, :, :]
462       -8.,  -7.,  -6.,  -5., // [0, 0, :]
463       -4.,  -3.,  -2.,  -1., // [0, 1, :]
464        9.,  10.,  11.,  12., // [0, 2, :]
465 
466       // [1, :, :]
467        8.,   7.,   6.,   5., // [1, 0, :]
468        4.,   3.,   2.,   1., // [1, 1, :]
469       -9., -10., -11., -12., // [1, 2, :]
470     });
471   // op_slice_scatter_out(input, src, /*dim=*/1, /*start=*/0, /*end=*/ >= 3, /*step=*/1, out),
472   // src shape should equal input[:,0:3:1, :] = input for any end >= 3
473   Tensor src_end_3_or_above = tf.make(
474     /*sizes=*/{2, 3, 4},
475     /*data=*/{
476       // [0, :, :]
477       -1.,  -2.,  -3.,  -4., // [0, 0, :]
478       -5.,  -6.,  -7.,  -8., // [0, 1, :]
479       -9., -10., -11., -12., // [0, 2, :]
480 
481       // [1, :, :]
482        1.,   2.,   3.,   4., // [1, 0, :]
483        5.,   6.,   7.,   8., // [1, 1, :]
484        9.,  10.,  11.,  12., // [1, 2, :]
485     });
486   Tensor expected_end_3_or_above = tf.make(
487     /*sizes=*/{2, 3, 4},
488     /*data=*/{
489       // [0, :, :]
490       -1.,  -2.,  -3.,  -4., // [0, 0, :]
491       -5.,  -6.,  -7.,  -8., // [0, 1, :]
492       -9., -10., -11., -12., // [0, 2, :]
493 
494       // [1, :, :]
495        1.,   2.,   3.,   4., // [1, 0, :]
496        5.,   6.,   7.,   8., // [1, 1, :]
497        9.,  10.,  11.,  12., // [1, 2, :]
498     });
499   // clang-format on
500 
501   std::vector<Tensor> src_tensors = {// end = -3
502                                      src_end_0_or_below,
503                                      // end = -2
504                                      src_end_1,
505                                      // end = -1
506                                      src_end_2,
507                                      // end = 0
508                                      src_end_0_or_below,
509                                      // end = 1
510                                      src_end_1,
511                                      // end = 2
512                                      src_end_2,
513                                      // end = 3
514                                      src_end_3_or_above};
515 
516   std::vector<Tensor> expected_rets = {// end = -3
517                                        expected_end_0_or_below,
518                                        // end = -2
519                                        expected_end_1,
520                                        // end = -1
521                                        expected_end_2,
522                                        // end = 0
523                                        expected_end_0_or_below,
524                                        // end = 1
525                                        expected_end_1,
526                                        // end = 2
527                                        expected_end_2,
528                                        // end = 3
529                                        expected_end_3_or_above};
530 
531   int64_t dim = 1;
532   int64_t start = 0;
533   int64_t step = 1;
534   for (int64_t end = -3; end < 4; end++) {
535     int64_t testcase_idx = end + 3;
536 
537     auto src = src_tensors[testcase_idx];
538     auto expected_ret = expected_rets[testcase_idx];
539     Tensor out = tf.zeros_like(expected_ret);
540 
541     // Should always return the provided out Tensor.
542     // The ret shall meet the expectation.
543     Tensor ret = op_slice_scatter_out(input, src, dim, start, end, step, out);
544     EXPECT_TENSOR_EQ(out, ret);
545     EXPECT_TENSOR_EQ(ret, expected_ret);
546   }
547 }
548 
TEST_F(OpSliceScatterTensorOutTest,LegalStepsSupported)549 TEST_F(OpSliceScatterTensorOutTest, LegalStepsSupported) {
550   TensorFactory<ScalarType::Double> tf;
551 
552   // clang-format off
553   Tensor input = tf.make(
554     /*sizes=*/{2, 3, 4},
555     /*data=*/{
556       // [0, :, :]
557        1.,   2.,   3.,   4., // [0, 0, :]
558        5.,   6.,   7.,   8., // [0, 1, :]
559        9.,  10.,  11.,  12., // [0, 2, :]
560 
561       // [1, :, :]
562       -1.,  -2.,  -3.,  -4., // [1, 0, :]
563       -5.,  -6.,  -7.,  -8., // [1, 1, :]
564       -9., -10., -11., -12., // [1, 2, :]
565     });
566 
567   // Set the end large enough to hold any step
568 
569   // Expected ret for op_slice_scatter_out(input, src, /*dim=*/1, /*start=*/0, /*end=*/10, /*step=*/1, out),
570   // src shape should equal to input[:,0:3:1, :]
571   Tensor src_0 = tf.make(
572     /*sizes=*/{2, 3, 4},
573     /*data=*/{
574       // [0, :, :]
575       -1.,  -2.,  -3.,  -4., // [0, 0, :]
576       -5.,  -6.,  -7.,  -8., // [0, 1, :]
577       -9., -10., -11., -12., // [0, 2, :]
578 
579       // [1, :, :]
580        1.,   2.,   3.,   4., // [1, 0, :]
581        5.,   6.,   7.,   8., // [1, 1, :]
582        9.,  10.,  11.,  12., // [1, 2, :]
583     });
584   Tensor expected_0 = tf.make(
585     /*sizes=*/{2, 3, 4},
586     /*data=*/{
587       // [0, :, :]
588       -1.,  -2.,  -3.,  -4., // [0, 0, :]
589       -5.,  -6.,  -7.,  -8., // [0, 1, :]
590       -9., -10., -11., -12., // [0, 2, :]
591 
592       // [1, :, :]
593        1.,   2.,   3.,   4., // [1, 0, :]
594        5.,   6.,   7.,   8., // [1, 1, :]
595        9.,  10.,  11.,  12., // [1, 2, :]
596     });
597   // Expected ret for op_slice_scatter_out(input, src, /*dim=*/1, /*start=*/0, /*end=*/10, /*step=*/2, out),
598   // src shape should equal to input[:,0:3:2, :]
599   Tensor src_1 = tf.make(
600     /*sizes=*/{2, 2, 4},
601     /*data=*/{
602       // [0, :, :]
603       -1.,  -2.,  -3.,  -4., // [0, 0, :]
604       -9., -10., -11., -12., // [0, 1, :]
605 
606       // [1, :, :]
607        1.,   2.,   3.,   4., // [1, 0, :]
608        9.,  10.,  11.,  12., // [1, 1, :]
609     });
610   Tensor expected_1 = tf.make(
611     /*sizes=*/{2, 3, 4},
612     /*data=*/{
613       // [0, :, :]
614       -1.,  -2.,  -3.,  -4., // [0, 0, :]
615        5.,   6.,   7.,   8., // [0, 1, :]
616       -9., -10., -11., -12., // [0, 2, :]
617 
618       // [1, :, :]
619        1.,   2.,   3.,   4., // [1, 0, :]
620       -5.,  -6.,  -7.,  -8., // [1, 1, :]
621        9.,  10.,  11.,  12., // [1, 2, :]
622     });
623   // Expected ret for op_slice_scatter_out(input, src, /*dim=*/1, /*start=*/0, /*end=*/10, /*step=*/3, out),
624   // src shape should equal to input[:,0:3:3, :] = input
625   Tensor src_2 = tf.make(
626     /*sizes=*/{2, 1, 4},
627     /*data=*/{
628       -1.,  -2.,  -3.,  -4., // [0, 0, :]
629        1.,   2.,   3.,   4., // [1, 0, :]
630     });
631   Tensor expected_2 = tf.make(
632     /*sizes=*/{2, 3, 4},
633     /*data=*/{
634       // [0, :, :]
635       -1.,  -2.,  -3.,  -4., // [0, 0, :]
636        5.,   6.,   7.,   8., // [0, 1, :]
637        9.,  10.,  11.,  12., // [0, 2, :]
638 
639       // [1, :, :]
640        1.,   2.,   3.,   4., // [1, 0, :]
641       -5.,  -6.,  -7.,  -8., // [1, 1, :]
642       -9., -10., -11., -12., // [1, 2, :]
643     });
644   // clang-format on
645 
646   std::vector<Tensor> src_tensors = {src_0, src_1, src_2};
647   std::vector<Tensor> expected_rets = {expected_0, expected_1, expected_2};
648 
649   // In this test, we maintain start and dim as 0 and 1, also set the
650   // end large enough to hold any step
651   int64_t start = 0;
652   int64_t dim = 1;
653   int64_t end = 10;
654   for (int64_t step = 1; step < 4; step++) {
655     int64_t testcase_idx = step - 1;
656 
657     auto src = src_tensors[testcase_idx];
658     auto expected_ret = expected_rets[testcase_idx];
659     Tensor out = tf.zeros_like(expected_ret);
660 
661     // Should always return the provided out Tensor.
662     // The ret shall meet the expectation.
663     Tensor ret = op_slice_scatter_out(input, src, dim, start, end, step, out);
664     EXPECT_TENSOR_EQ(out, ret);
665     EXPECT_TENSOR_EQ(ret, expected_ret);
666   }
667 }
668 
669 /// A generic smoke test that works for any dtype that supports ones() and
670 /// zeros().
TEST_F(OpSliceScatterTensorOutTest,AllRealDtypesSupported)671 TEST_F(OpSliceScatterTensorOutTest, AllRealDtypesSupported) {
672 #define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
673   ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
674 #undef TEST_ENTRY
675   // TODO: Also add tests for half, complex, quantized, and other types. Easiest
676   // way to do that would be to make TensorFactory support zeros() and ones()
677   // for those types.
678 }
679 
// Scatters into an input with a zero-sized dimension along every valid dim.
// The result has no elements, so it must equal the (empty) input.
TEST_F(OpSliceScatterTensorOutTest, EmptyInputSupported) {
  TensorFactory<ScalarType::Int> tf;

  Tensor input = tf.ones({1, 0, 1});
  Tensor src = tf.zeros({1, 0, 1});
  Tensor out = tf.zeros({1, 0, 1});

  Tensor expect = tf.ones({1, 0, 1});

  // Every valid (non-negative) dim value. The previous condition
  // `dim > input.dim()` was false on entry, so the loop never ran.
  for (int64_t dim = 0; dim < input.dim(); dim++) {
    Tensor ret = op_slice_scatter_out(
        input, src, dim, /*start=*/0, /*end=*/1, /*step=*/1, out);
    EXPECT_TENSOR_EQ(ret, out);

    // All operations in this test share same ground truth
    EXPECT_TENSOR_EQ(ret, expect);
  }
}
699 
// Zero-dimensional tensors (empty size list) have no dim 0 to slice, so the
// kernel must reject the call for every end value tried.
TEST_F(OpSliceScatterTensorOutTest, EmptySizeInputDies) {
  TensorFactory<ScalarType::Int> tf;

  Tensor scalar_input = tf.ones({});
  Tensor scalar_src = tf.ones({});
  Tensor scalar_out = tf.ones({});

  for (const int64_t end_val : {0, 1}) {
    ET_EXPECT_KERNEL_FAILURE(
        context_,
        op_slice_scatter_out(
            scalar_input,
            scalar_src,
            /*dim=*/0,
            /*start=*/0,
            /*end=*/end_val,
            /*step=*/1,
            scalar_out));
  }
}
717 
// Steps must be strictly positive: -2, -1 and 0 must all be rejected.
TEST_F(OpSliceScatterTensorOutTest, NonPostiveStepsDies) {
  TensorFactory<ScalarType::Int> tf;

  Tensor self_tensor = tf.ones({1, 1, 1});
  Tensor src_tensor = tf.zeros({1, 1, 1});
  Tensor dest = tf.zeros({1, 1, 1});

  for (int64_t bad_step = -2; bad_step <= 0; ++bad_step) {
    ET_EXPECT_KERNEL_FAILURE(
        context_,
        op_slice_scatter_out(
            self_tensor,
            src_tensor,
            /*dim=*/0,
            /*start=*/0,
            /*end=*/1,
            /*step=*/bad_step,
            dest));
  }
}
734 
// For a rank-3 tensor the legal dims are [-3, 2]; probe values just past each
// end of that range and expect the kernel to reject every one.
TEST_F(OpSliceScatterTensorOutTest, DimOutOfBoundDies) {
  TensorFactory<ScalarType::Int> tf;

  Tensor self_tensor = tf.ones({1, 1, 1});
  Tensor src_tensor = tf.zeros({1, 1, 1});
  Tensor dest = tf.zeros({1, 1, 1});

  for (const int64_t bad_dim : {3, 4, 5, -4, -5, -6}) {
    ET_EXPECT_KERNEL_FAILURE(
        context_,
        op_slice_scatter_out(
            self_tensor,
            src_tensor,
            bad_dim,
            /*start=*/0,
            /*end=*/1,
            /*step=*/1,
            dest));
  }
}
751 
// An out tensor whose shape fits but whose dtype (Float) differs from the
// inputs (Int) must be rejected.
TEST_F(OpSliceScatterTensorOutTest, MismatchedOutDtypesDies) {
  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Float> float_factory;

  Tensor int_self = int_factory.zeros({1, 2, 2});
  Tensor int_src = int_factory.zeros({1, 2, 2});
  Tensor float_dest = float_factory.ones({1, 2, 2});

  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_slice_scatter_out(
          int_self,
          int_src,
          /*dim=*/0,
          /*start=*/0,
          /*end=*/1,
          /*step=*/1,
          float_dest));
}
766 
// An out tensor missing input's trailing dimension must be rejected by the
// portable kernel; skipped under ATen.
TEST_F(OpSliceScatterTensorOutTest, OutSizeMismatchDimDies) {
  const bool running_aten =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (running_aten) {
    GTEST_SKIP() << "ATen kernel can handle out with mismatched dimensions";
  }

  TensorFactory<ScalarType::Int> tf;
  Tensor self_tensor = tf.zeros({2, 4, 7, 5});
  Tensor src_tensor = tf.zeros({2, 4, 7, 5});
  // Rank 3 here, whereas input's shape {2, 4, 7, 5} is what out should have.
  Tensor short_dest = tf.zeros({2, 4, 7});

  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_slice_scatter_out(
          self_tensor,
          src_tensor,
          /*dim=*/0,
          /*start=*/0,
          /*end=*/2,
          /*step=*/1,
          short_dest));
}
784 
// A src tensor whose rank does not match the selected slice of input must be
// rejected by the portable kernel; skipped under ATen.
TEST_F(OpSliceScatterTensorOutTest, SrcSizeMismatchDimDies) {
  // NOTE(review): this skip reason mentions `out`, but the malformed tensor in
  // this test is `src`; confirm the ATen rationale applies to src as well.
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel can handle out with mismatched dimensions";
  }
  TensorFactory<ScalarType::Int> tf;

  Tensor input = tf.zeros({2, 4, 7, 5});
  // Should be {2, 4, 7, 5} to match input[0:2:1, ...]; the missing trailing
  // dimension is what the kernel must reject.
  Tensor src = tf.zeros({2, 4, 7});

  // out correctly matches input's shape, so only src is malformed.
  Tensor out = tf.zeros({2, 4, 7, 5});

  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_slice_scatter_out(
          input, src, /*dim=*/0, /*start=*/0, /*end=*/2, /*step=*/1, out));
}
802 
// Omitting start (nullopt) with end = size(0) should select the whole of
// dim 0, so the all-ones src replaces the all-zeros input entirely.
TEST_F(OpSliceScatterTensorOutTest, DefaultStartValSupported) {
  TensorFactory<ScalarType::Int> tf;

  Tensor zeros_input = tf.zeros({2, 4, 7, 5});
  Tensor ones_src = tf.ones({2, 4, 7, 5});
  Tensor all_ones = tf.ones({2, 4, 7, 5});
  Tensor dest = tf.zeros({2, 4, 7, 5});

  Tensor result = op_slice_scatter_out(
      zeros_input,
      ones_src,
      /*dim=*/0,
      /*start=*/exec_aten::nullopt,
      /*end=*/2,
      /*step=*/1,
      dest);

  EXPECT_TENSOR_EQ(result, dest);
  EXPECT_TENSOR_EQ(result, all_ones);
}
823 
// Omitting end (nullopt) with start = 0 should select the whole of dim 0, so
// the all-ones src replaces the all-zeros input entirely.
TEST_F(OpSliceScatterTensorOutTest, DefaultEndValSupported) {
  TensorFactory<ScalarType::Int> tf;

  Tensor zeros_input = tf.zeros({2, 4, 7, 5});
  Tensor ones_src = tf.ones({2, 4, 7, 5});
  Tensor all_ones = tf.ones({2, 4, 7, 5});
  Tensor dest = tf.zeros({2, 4, 7, 5});

  Tensor result = op_slice_scatter_out(
      zeros_input,
      ones_src,
      /*dim=*/0,
      /*start=*/0,
      /*end=*/exec_aten::nullopt,
      /*step=*/1,
      dest);

  EXPECT_TENSOR_EQ(result, dest);
  EXPECT_TENSOR_EQ(result, all_ones);
}
844 
// out starts as {1, 2, 8} (same element count as input) with DYNAMIC_BOUND
// dynamism. The test expects the kernel to reshape it to input's {1, 4, 4}
// and fill it with src.
TEST_F(OpSliceScatterTensorOutTest, DynamicShapeTest) {
  TensorFactory<ScalarType::Int> tf;

  Tensor zeros_input = tf.zeros({1, 4, 4});
  Tensor ones_src = tf.ones({1, 4, 4});
  Tensor expected = tf.ones({1, 4, 4});

  Tensor resizable_out =
      tf.zeros({1, 2, 8}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);

  Tensor result = op_slice_scatter_out(
      zeros_input,
      ones_src,
      /*dim=*/0,
      /*start=*/0,
      /*end=*/exec_aten::nullopt,
      /*step=*/1,
      resizable_out);

  EXPECT_TENSOR_EQ(result, resizable_out);
  EXPECT_TENSOR_EQ(result, expected);
}
866 
// An end of INT64_MAX (far beyond size(1) == 1) must still select the whole
// of dim 1, so the all-ones src replaces the all-zeros input.
TEST_F(OpSliceScatterTensorOutTest, LargeEndValue) {
  TensorFactory<ScalarType::Int> tf;

  const int64_t huge_end = 9223372036854775807;

  Tensor zeros_input = tf.zeros({1, 1, 2, 5, 3, 3});
  Tensor ones_src = tf.ones({1, 1, 2, 5, 3, 3});
  Tensor expected = tf.ones({1, 1, 2, 5, 3, 3});
  Tensor dest = tf.zeros({1, 1, 2, 5, 3, 3});

  Tensor result = op_slice_scatter_out(
      zeros_input,
      ones_src,
      /*dim=*/1,
      /*start=*/0,
      /*end=*/huge_end,
      /*step=*/1,
      dest);

  EXPECT_TENSOR_EQ(result, dest);
  EXPECT_TENSOR_EQ(result, expected);
}
887