xref: /aosp_15_r20/external/executorch/kernels/test/op_ne_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 #include <gtest/gtest.h>
16 
17 using namespace ::testing;
18 using exec_aten::Scalar;
19 using exec_aten::ScalarType;
20 using exec_aten::Tensor;
21 using executorch::runtime::KernelRuntimeContext;
22 using torch::executor::testing::TensorFactory;
23 
24 class OpNeTest : public OperatorTest {
25  protected:
op_ne_tensor_out(const Tensor & self,Tensor & other,Tensor & out)26   Tensor& op_ne_tensor_out(const Tensor& self, Tensor& other, Tensor& out) {
27     return torch::executor::aten::ne_outf(context_, self, other, out);
28   }
29 
30   template <class CTYPE, ScalarType DTYPE>
test_dtype()31   void test_dtype() {
32     TensorFactory<DTYPE> tf_input;
33     TensorFactory<ScalarType::Bool> tf_bool;
34     Tensor a = tf_input.make(/*sizes=*/{2, 2}, /*data=*/{2, 3, 2, 4});
35     Tensor b = tf_input.make({2, 2}, {2, 2, 2, 2});
36     Tensor out = tf_bool.zeros({2, 2});
37     KernelRuntimeContext context{};
38 
39     torch::executor::aten::ne_outf(context, a, b, out);
40     EXPECT_TENSOR_EQ(out, tf_bool.make({2, 2}, {false, true, false, true}));
41   }
42 };
43 
44 class OpNeScalarOutTest : public OperatorTest {
45  protected:
op_ne_scalar_out(const Tensor & self,Scalar & other,Tensor & out)46   Tensor& op_ne_scalar_out(const Tensor& self, Scalar& other, Tensor& out) {
47     return torch::executor::aten::ne_outf(context_, self, other, out);
48   }
49 
50   // Common testing for ne operator
51   template <ScalarType DTYPE>
test_ne_scalar_out()52   void test_ne_scalar_out() {
53     TensorFactory<DTYPE> tf;
54     TensorFactory<ScalarType::Bool> tf_out;
55 
56     const std::vector<int32_t> sizes = {2, 2};
57     // Destination for the ne
58     Tensor out = tf_out.ones(sizes);
59     Scalar other = 2;
60 
61     // Valid input should give the expected output
62     op_ne_scalar_out(tf.make(sizes, /*data=*/{2, 3, 2, 3}), other, out);
63     EXPECT_TENSOR_EQ(
64         out, tf_out.make(sizes, /*data=*/{false, true, false, true}));
65   }
66 
67   // Handle all output dtypes.
68   template <ScalarType OUTPUT_DTYPE>
test_ne_all_output_dtypes()69   void test_ne_all_output_dtypes() {
70     TensorFactory<ScalarType::Float> tf_float;
71     TensorFactory<OUTPUT_DTYPE> tf_out;
72 
73     const std::vector<int32_t> sizes = {2, 5};
74 
75     Tensor in = tf_float.ones(sizes);
76     Tensor out = tf_out.zeros(sizes);
77     Scalar other = 3;
78 
79     op_ne_scalar_out(in, other, out);
80     EXPECT_TENSOR_EQ(out, tf_out.ones(sizes));
81   }
82 };
83 
// Runs the Scalar comparison once per real input dtype; each run expects a
// Bool mask as the result.
TEST_F(OpNeScalarOutTest, AllRealInputBoolOutputSupport) {
#define RUN_NE_SCALAR_FOR_DTYPE(ctype, dtype) \
  test_ne_scalar_out<ScalarType::dtype>();
  ET_FORALL_REAL_TYPES(RUN_NE_SCALAR_FOR_DTYPE);
#undef RUN_NE_SCALAR_FOR_DTYPE
}
89 
// A Bool tensor compared against the integer Scalar 1: every true entry
// equals 1 (-> false) and every false entry differs (-> true), so the mask
// is the elementwise inverse of the input.
TEST_F(OpNeScalarOutTest, BoolInputDtype) {
  TensorFactory<ScalarType::Bool> tf_bool;
  const std::vector<int32_t> sizes = {2, 2};

  Scalar one = 1;
  Tensor input = tf_bool.make(sizes, /*data=*/{false, true, false, true});
  Tensor expected = tf_bool.make(sizes, /*data=*/{true, false, true, false});
  Tensor out = tf_bool.zeros(sizes);

  op_ne_scalar_out(input, one, out);
  EXPECT_TENSOR_EQ(out, expected);
}
102 
103 // Mismatched shape tests.
TEST_F(OpNeScalarOutTest,MismatchedShapesDies)104 TEST_F(OpNeScalarOutTest, MismatchedShapesDies) {
105   if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
106     GTEST_SKIP() << "ATen kernel can handle mismatched shapes";
107   }
108   TensorFactory<ScalarType::Int> tf_int;
109   TensorFactory<ScalarType::Bool> tf_bool;
110 
111   Tensor a = tf_int.ones(/*sizes=*/{4});
112   Tensor out = tf_bool.ones(/*sizes=*/{2, 2});
113   Scalar other = 3;
114 
115   ET_EXPECT_KERNEL_FAILURE(context_, op_ne_scalar_out(a, other, out));
116 }
117 
// Writes the comparison result into an output tensor of each real dtype.
TEST_F(OpNeScalarOutTest, AllRealOutputDTypesSupported) {
#define RUN_NE_FOR_OUTPUT_DTYPE(ctype, dtype) \
  test_ne_all_output_dtypes<ScalarType::dtype>();
  ET_FORALL_REAL_TYPES(RUN_NE_FOR_OUTPUT_DTYPE);
#undef RUN_NE_FOR_OUTPUT_DTYPE
}
123 
// Tensor-vs-tensor ne over every real input dtype (Bool is not part of
// ET_FORALL_REAL_TYPES).
TEST_F(OpNeTest, AllDtypesSupported) {
#define RUN_NE_TENSOR_FOR_DTYPE(ctype, dtype) \
  test_dtype<ctype, ScalarType::dtype>();
  ET_FORALL_REAL_TYPES(RUN_NE_TENSOR_FOR_DTYPE);
#undef RUN_NE_TENSOR_FOR_DTYPE
}
129 
/* %python
import torch
torch.manual_seed(0)
x = torch.randint(3, (3, 2))
res = torch.ne(x, 2)
op = "op_ne_scalar_out"
opt_setup_params = """
  Scalar other = 2;
"""
opt_extra_params = "other,"
dtype = "ScalarType::Int"
out_dtype = "ScalarType::Bool"
check = "EXPECT_TENSOR_EQ" */
143 
TEST_F(OpNeScalarOutTest, DynamicShapeUpperBoundSameAsExpected) {
  /* %python
  out_args = "{3, 2}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND"
  %rewrite(unary_op_out_dtype) */

  // `out` is DYNAMIC_BOUND with an upper bound equal to the result shape.
  const auto dynamism = torch::executor::TensorShapeDynamism::DYNAMIC_BOUND;

  TensorFactory<ScalarType::Int> tf_int;
  TensorFactory<ScalarType::Bool> tf_mask;

  Scalar rhs = 2;
  Tensor out = tf_mask.zeros({3, 2}, dynamism);

  Tensor x = tf_int.make({3, 2}, {2, 0, 2, 0, 1, 0});
  Tensor expected =
      tf_mask.make({3, 2}, {false, true, false, true, true, true});

  op_ne_scalar_out(x, rhs, out);
  EXPECT_TENSOR_EQ(out, expected);
}
162 
TEST_F(OpNeScalarOutTest, DynamicShapeUpperBoundLargerThanExpected) {
  /* %python
  out_args = "{10, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND"
  %rewrite(unary_op_out_dtype) */

  // `out` is allocated as {10, 10} DYNAMIC_BOUND, larger than the {3, 2}
  // result; the comparison below expects it to end up matching `expected`.
  const auto dynamism = torch::executor::TensorShapeDynamism::DYNAMIC_BOUND;

  TensorFactory<ScalarType::Int> tf_int;
  TensorFactory<ScalarType::Bool> tf_mask;

  Scalar rhs = 2;
  Tensor out = tf_mask.zeros({10, 10}, dynamism);

  Tensor x = tf_int.make({3, 2}, {2, 0, 2, 0, 1, 0});
  Tensor expected =
      tf_mask.make({3, 2}, {false, true, false, true, true, true});

  op_ne_scalar_out(x, rhs, out);
  EXPECT_TENSOR_EQ(out, expected);
}
181 
TEST_F(OpNeScalarOutTest, DynamicShapeUnbound) {
  const bool can_resize_output =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!can_resize_output) {
    GTEST_SKIP() << "Dynamic shape unbound not supported";
  }
  /* %python
  out_args = "{1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND"
  %rewrite(unary_op_out_dtype) */

  // `out` starts as {1, 1} DYNAMIC_UNBOUND, smaller than the {3, 2} result;
  // the comparison below expects it to end up matching `expected`.
  const auto dynamism = torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND;

  TensorFactory<ScalarType::Int> tf_int;
  TensorFactory<ScalarType::Bool> tf_mask;

  Scalar rhs = 2;
  Tensor out = tf_mask.zeros({1, 1}, dynamism);

  Tensor x = tf_int.make({3, 2}, {2, 0, 2, 0, 1, 0});
  Tensor expected =
      tf_mask.make({3, 2}, {false, true, false, true, true, true});

  op_ne_scalar_out(x, rhs, out);
  EXPECT_TENSOR_EQ(out, expected);
}
203