xref: /aosp_15_r20/external/executorch/kernels/test/op_narrow_copy_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/runtime/core/exec_aten/exec_aten.h>
12 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
14 
15 #include <gtest/gtest.h>
16 
17 using namespace ::testing;
18 using exec_aten::ArrayRef;
19 using exec_aten::optional;
20 using exec_aten::ScalarType;
21 using exec_aten::Tensor;
22 using torch::executor::testing::TensorFactory;
23 
24 class OpNarrowCopyOutTest : public OperatorTest {
25  protected:
op_narrow_copy_out(const Tensor & in,int64_t dim,int64_t start,int64_t length,Tensor & out)26   Tensor& op_narrow_copy_out(
27       const Tensor& in,
28       int64_t dim,
29       int64_t start,
30       int64_t length,
31       Tensor& out) {
32     return torch::executor::aten::narrow_copy_outf(
33         context_, in, dim, start, length, out);
34   }
35 
36   template <class CTYPE, exec_aten::ScalarType DTYPE>
test_dtype()37   void test_dtype() {
38     TensorFactory<DTYPE> tf;
39 
40     // clang-format off
41     Tensor input = tf.make(
42       /*sizes=*/{3, 4},
43       /*data=*/{
44         1,   2,   3,   4, // [0, :]
45         5,   6,   7,   8, // [1, :]
46         9,  10,  11,  12, // [2, :]
47       });
48 
49     Tensor expected = tf.make(
50       /*sizes=*/{2, 4},
51       /*data=*/{
52         1,   2,   3,   4, // [0, :]
53         5,   6,   7,   8, // [1, :]
54       });
55     // clang-format on
56 
57     Tensor out = tf.zeros({2, 4});
58     Tensor ret =
59         op_narrow_copy_out(input, /*dim=*/0, /*start=*/0, /*length=*/2, out);
60 
61     EXPECT_TENSOR_EQ(out, ret);
62     EXPECT_TENSOR_EQ(out, expected);
63   }
64 };
65 
// Runs the shared 3x4 -> 2x4 narrowing scenario once for every real dtype,
// plus Bool.
TEST_F(OpNarrowCopyOutTest, AllDtypesSupported) {
#define RUN_NARROW_COPY_DTYPE_CASE(ctype, dtype) \
  test_dtype<ctype, ScalarType::dtype>();
  ET_FORALL_REAL_TYPES_AND(Bool, RUN_NARROW_COPY_DTYPE_CASE);
#undef RUN_NARROW_COPY_DTYPE_CASE
}
71 
// Narrowing an input that holds zero elements must succeed along every dim,
// and the result must equal an empty {1, 0, 1} tensor.
TEST_F(OpNarrowCopyOutTest, EmptyInputSupported) {
  TensorFactory<ScalarType::Int> tf;

  Tensor input = tf.ones({1, 0, 1});
  Tensor out = tf.zeros({1, 0, 1});
  Tensor expect = tf.ones({1, 0, 1});

  // Each case starts at 0 and takes the full extent of `dim`, so the output
  // keeps the input's {1, 0, 1} shape.
  const struct {
    int64_t dim;
    int64_t length;
  } kCases[] = {{0, 1}, {1, 0}, {2, 1}};

  for (const auto& c : kCases) {
    Tensor ret = op_narrow_copy_out(input, c.dim, /*start=*/0, c.length, out);
    EXPECT_TENSOR_EQ(ret, out);
    EXPECT_TENSOR_EQ(ret, expect);
  }
}
93 
// A zero-length narrow along dim 1 must succeed and yield a {2, 0} tensor,
// for both a positive and a negative in-range start.
TEST_F(OpNarrowCopyOutTest, ZeroLengthSupported) {
  TensorFactory<ScalarType::Int> tf;

  Tensor input = tf.ones({2, 3});
  Tensor out = tf.ones({2, 0});
  Tensor expect = tf.ones({2, 0});

  // -1 presumably counts back from the end of dim 1 (size 3).
  const int64_t kStarts[] = {1, -1};
  for (const int64_t start : kStarts) {
    Tensor ret =
        op_narrow_copy_out(input, /*dim=*/1, start, /*length=*/0, out);
    EXPECT_TENSOR_EQ(ret, out);
    EXPECT_TENSOR_EQ(ret, expect);
  }
}
111 
// A 0-dim input has no dimension that dim=0 could refer to, so the kernel
// must reject the call.
TEST_F(OpNarrowCopyOutTest, ZeroDimInputDies) {
  TensorFactory<ScalarType::Int> tf;

  Tensor input = tf.ones({});
  Tensor out = tf.ones({});

  // The operation shall die whatever the start and length are.
  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_narrow_copy_out(input, /*dim=*/0, /*start=*/0, /*length=*/0, out));
  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_narrow_copy_out(input, /*dim=*/0, /*start=*/1, /*length=*/1, out));
}
126 
// An out-of-range `start` must be rejected.
//
// Each call gets an `out` shaped exactly like the result it would produce if
// `start` were valid. That way the expected failure can only come from the
// start check, not from an output-shape mismatch that would mask a missing
// start check.
TEST_F(OpNarrowCopyOutTest, InvalidStart) {
  TensorFactory<ScalarType::Int> tf;

  Tensor input = tf.ones({2, 3});

  // dim 0 has size 2; start=-3 is below -size. A length-0 narrow of dim 0
  // would produce {0, 3}.
  Tensor out_dim0 = tf.ones({0, 3});
  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_narrow_copy_out(
          input, /*dim=*/0, /*start=*/-3, /*length=*/0, out_dim0));

  // dim 1 has size 3; start=4 is past the end. A length-0 narrow of dim 1
  // would produce {2, 0}.
  Tensor out_dim1 = tf.ones({2, 0});
  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_narrow_copy_out(
          input, /*dim=*/1, /*start=*/4, /*length=*/0, out_dim1));
}
140 
// A `start`/`length` pair that runs past the end of `dim` must be rejected,
// even when `start` alone is in range.
//
// Each call gets an `out` shaped like the result the requested length would
// imply. The expected failure is then attributable to the start+length check
// rather than to an output-shape mismatch.
TEST_F(OpNarrowCopyOutTest, InvalidStartLengthCombination) {
  TensorFactory<ScalarType::Int> tf;

  Tensor input = tf.ones({2, 3});

  // dim 0 has size 2; start=0 with length=3 overruns it. The implied output
  // shape is {3, 3}.
  Tensor out_dim0 = tf.ones({3, 3});
  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_narrow_copy_out(
          input, /*dim=*/0, /*start=*/0, /*length=*/3, out_dim0));

  // dim 1 has size 3; start=-1 (the last column) with length=2 overruns it.
  // The implied output shape is {2, 2}.
  Tensor out_dim1 = tf.ones({2, 2});
  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_narrow_copy_out(
          input, /*dim=*/1, /*start=*/-1, /*length=*/2, out_dim1));
}
154 
// Any negative `length` must make the kernel fail.
TEST_F(OpNarrowCopyOutTest, NegativeLengthDies) {
  TensorFactory<ScalarType::Int> tf;

  Tensor input = tf.ones({1, 1, 1});
  Tensor out = tf.zeros({1, 1, 1});

  // Visits -3, -2, -1 in that order.
  for (int64_t length = -3; length <= -1; ++length) {
    ET_EXPECT_KERNEL_FAILURE(
        context_,
        op_narrow_copy_out(input, /*dim=*/0, /*start=*/0, length, out));
  }
}
170 
// A `dim` outside the input's rank must make the kernel fail.
TEST_F(OpNarrowCopyOutTest, DimOutOfBoundDies) {
  TensorFactory<ScalarType::Int> tf;

  Tensor input = tf.ones({1, 1, 1});
  Tensor out = tf.zeros({1, 1, 1});

  // For this rank-3 input the valid dims are assumed to be [-3, 2]; every
  // value below lies outside that range.
  const int64_t kBadDims[] = {3, 4, 5, -4, -5, -6};
  for (const int64_t bad_dim : kBadDims) {
    ET_EXPECT_KERNEL_FAILURE(
        context_,
        op_narrow_copy_out(input, bad_dim, /*start=*/0, /*length=*/1, out));
  }
}
185 
// The kernel must reject an `out` whose dtype differs from the input's.
TEST_F(OpNarrowCopyOutTest, MismatchedDtypesDies) {
  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Float> float_factory;

  Tensor input = int_factory.zeros({1, 2, 2});

  // {1, 2, 2} is exactly the shape a length-1 narrow along dim 0 produces, so
  // the dtype is the only thing wrong with `out`.
  Tensor out = float_factory.ones({1, 2, 2});

  ET_EXPECT_KERNEL_FAILURE(
      context_,
      op_narrow_copy_out(input, /*dim=*/0, /*start=*/0, /*length=*/1, out));
}
198