xref: /aosp_15_r20/external/executorch/kernels/test/op_where_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 
16 #include <gtest/gtest.h>
17 
18 using namespace ::testing;
19 using exec_aten::ScalarType;
20 using exec_aten::Tensor;
21 using torch::executor::testing::SupportedFeatures;
22 using torch::executor::testing::TensorFactory;
23 
24 class OpWhereOutTest : public OperatorTest {
25  protected:
op_where_self_out(const Tensor & condition,const Tensor & self,const Tensor & other,Tensor & out)26   Tensor& op_where_self_out(
27       const Tensor& condition,
28       const Tensor& self,
29       const Tensor& other,
30       Tensor& out) {
31     return torch::executor::aten::where_outf(
32         context_, condition, self, other, out);
33   }
34 
35   template <ScalarType DTYPE_A, ScalarType DTYPE_B, ScalarType DTYPE_OUT>
test_where()36   void test_where() {
37     if (DTYPE_OUT == ScalarType::Byte || DTYPE_OUT == ScalarType::Char) {
38       return;
39     }
40     TensorFactory<ScalarType::Bool> tf_condition;
41     TensorFactory<ScalarType::Byte> tf_condition_byte;
42     TensorFactory<DTYPE_A> tf_a;
43     TensorFactory<DTYPE_B> tf_b;
44     TensorFactory<DTYPE_OUT> tf_out;
45 
46     const std::vector<int32_t> condition_sizes = {12};
47     const std::vector<int32_t> sizes = {1, 12};
48 
49     Tensor out = tf_out.zeros(sizes);
50 
51     // clang-format off
52     std::vector<uint8_t> condition_data = {
53       false, true, false, true, true, false,
54       false, true, false, true, true, false
55     };
56     const auto a_tensor = tf_a.make(sizes, /*data=*/{  1,  2,  3,  4,  5,  6,  6,  5,  4,  3,  2,  1});
57     const auto b_tensor = tf_b.make(sizes, /*data=*/{  6,  5,  4,  3,  2,  1,  1,  2,  3,  4,  5,  6});
58     // clang-format on
59     op_where_self_out(
60         tf_condition.make(condition_sizes, /*data=*/condition_data),
61         a_tensor,
62         b_tensor,
63         out);
64 
65     auto expectedOut =
66         tf_out.make(sizes, /*data=*/{6, 2, 4, 4, 5, 1, 1, 5, 3, 3, 2, 6});
67     // Check that it matches the expected output.
68     EXPECT_TENSOR_CLOSE(out, expectedOut);
69 
70     op_where_self_out(
71         tf_condition_byte.make(condition_sizes, condition_data),
72         a_tensor,
73         b_tensor,
74         out);
75     EXPECT_TENSOR_CLOSE(out, expectedOut);
76   }
77 
78   template <ScalarType DTYPE_A, ScalarType DTYPE_B>
test_where_enumerate_out_types()79   void test_where_enumerate_out_types() {
80 #define ENUMERATE_TEST_ENTRY(ctype, dtype) \
81   test_where<DTYPE_A, DTYPE_B, ScalarType::dtype>();
82 
83     ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY)
84 
85 #undef ENUMERATE_TEST_ENTRY
86   }
87 
88   template <ScalarType DTYPE_A>
test_where_enumerate_b_types()89   void test_where_enumerate_b_types() {
90 #define ENUMERATE_TEST_ENTRY(ctype, dtype) \
91   test_where<DTYPE_A, ScalarType::dtype, DTYPE_A>();
92 
93     ET_FORALL_REALHBBF16_TYPES(ENUMERATE_TEST_ENTRY)
94 
95 #undef ENUMERATE_TEST_ENTRY
96   }
97 
test_dynamic_shape(const std::vector<int32_t> & out_shape,enum torch::executor::TensorShapeDynamism dynamism)98   void test_dynamic_shape(
99       const std::vector<int32_t>& out_shape,
100       enum torch::executor::TensorShapeDynamism dynamism) {
101     /* %python
102     %rewrite(where_template) */
103 
104     TensorFactory<ScalarType::Bool> tfBool;
105     TensorFactory<ScalarType::Float> tf;
106 
107     Tensor condition = tfBool.make(
108         {2, 3, 4}, {true,  false, true, true,  true,  false, false, true,
109                     false, true,  true, false, false, false, false, false,
110                     false, false, true, true,  false, false, true,  true});
111     Tensor input = tf.make(
112         {2, 3, 4},
113         {0.41940832138061523,  0.5529070496559143,   0.9527381062507629,
114          0.036164820194244385, 0.1852310299873352,   0.37341737747192383,
115          0.3051000237464905,   0.9320003986358643,   0.17591017484664917,
116          0.2698335647583008,   0.15067976713180542,  0.03171950578689575,
117          0.20812976360321045,  0.9297990202903748,   0.7231091856956482,
118          0.7423362731933594,   0.5262957811355591,   0.24365824460983276,
119          0.584592342376709,    0.033152639865875244, 0.13871687650680542,
120          0.242235004901886,    0.815468966960907,    0.793160617351532});
121     Tensor other = tf.make(
122         {2, 3, 4},
123         {0.2782524824142456,  0.48195880651474,   0.8197803497314453,
124          0.9970665574073792,  0.6984410881996155, 0.5675464272499084,
125          0.8352431654930115,  0.2055988311767578, 0.593172013759613,
126          0.11234724521636963, 0.1534569263458252, 0.24170821905136108,
127          0.7262365221977234,  0.7010802030563354, 0.2038237452507019,
128          0.6510535478591919,  0.7744860053062439, 0.4368913173675537,
129          0.5190907716751099,  0.6158523559570312, 0.8101882934570312,
130          0.9800970554351807,  0.1146882176399231, 0.3167651295661926});
131     Tensor expected = tf.make(
132         {2, 3, 4},
133         {0.41940832138061523,  0.48195880651474,     0.9527381062507629,
134          0.036164820194244385, 0.1852310299873352,   0.5675464272499084,
135          0.8352431654930115,   0.9320003986358643,   0.593172013759613,
136          0.2698335647583008,   0.15067976713180542,  0.24170821905136108,
137          0.7262365221977234,   0.7010802030563354,   0.2038237452507019,
138          0.6510535478591919,   0.7744860053062439,   0.4368913173675537,
139          0.584592342376709,    0.033152639865875244, 0.8101882934570312,
140          0.9800970554351807,   0.815468966960907,    0.793160617351532});
141     Tensor out = tf.zeros(out_shape, dynamism);
142 
143     op_where_self_out(condition, input, other, out);
144     EXPECT_TENSOR_EQ(out, expected);
145   }
146 
test_where_enumerate_a_types()147   void test_where_enumerate_a_types() {
148 #define ENUMERATE_TEST_ENTRY(ctype, dtype) \
149   test_where_enumerate_b_types<ScalarType::dtype>();
150 
151     ET_FORALL_REALHBBF16_TYPES(ENUMERATE_TEST_ENTRY)
152 
153 #undef ENUMERATE_TEST_ENTRY
154   }
155 
test_where_enumerate_a_types_aten()156   void test_where_enumerate_a_types_aten() {
157 #define ENUMERATE_TEST_ENTRY(ctype, dtype) \
158   test_where<ScalarType::dtype, ScalarType::dtype, ScalarType::dtype>();
159 
160     ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY)
161 
162 #undef ENUMERATE_TEST_ENTRY
163   }
164 };
165 
166 //
167 // Correctness Test
168 //
169 
// Sweeps dtypes where the condition is Bool or Byte and a, b and out all share
// one dtype from ET_FORALL_REALHBF16_TYPES. Byte and Char outputs are skipped
// by the helper.
TEST_F(OpWhereOutTest, AllRealDtypesSupported) {
  this->test_where_enumerate_a_types_aten();
}
173 
174 // Condition is true, all items will be from x
TEST_F(OpWhereOutTest,AllTrueTest)175 TEST_F(OpWhereOutTest, AllTrueTest) {
176   TensorFactory<ScalarType::Bool> tf_condition;
177   TensorFactory<ScalarType::Float> tf_x;
178   TensorFactory<ScalarType::Float> tf_y;
179   TensorFactory<ScalarType::Float> tf_out;
180 
181   const std::vector<int32_t> condition_sizes = {1};
182   const std::vector<int32_t> sizes = {1, 12};
183 
184   Tensor out = tf_out.zeros(sizes);
185 
186   // clang-format off
187   op_where_self_out(
188       tf_condition.make(condition_sizes, /*data=*/{true}),
189       tf_x.make(sizes, /*data=*/{ 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
190                                   6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 100.0f}),
191       tf_y.make(sizes, /*data=*/{ 0.1f, 1.1f,  2.1f,  3.1f, 4.1f,  5.1f,
192                                    6.1f, 7.1f, 8.1f, 9.1f, 10.1f, 100.1f}),
193       out);
194 
195   // Check that it matches (or close to) the expected output.
196   EXPECT_TENSOR_CLOSE(
197       out,
198       tf_out.make(
199           sizes, /*data=*/{ 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
200                             6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 100.0f}));
201   // clang-format on
202 }
203 
204 // Condition is false, all items will be from y
TEST_F(OpWhereOutTest,AllFalseTest)205 TEST_F(OpWhereOutTest, AllFalseTest) {
206   TensorFactory<ScalarType::Bool> tf_condition;
207   TensorFactory<ScalarType::Float> tf_x;
208   TensorFactory<ScalarType::Float> tf_y;
209   TensorFactory<ScalarType::Float> tf_out;
210 
211   const std::vector<int32_t> condition_sizes = {1};
212   const std::vector<int32_t> sizes = {1, 12};
213 
214   // Destination for the where operator.
215   Tensor out = tf_out.zeros(sizes);
216 
217   // clang-format off
218   op_where_self_out(
219       tf_condition.make(condition_sizes, /*data=*/{false}),
220       tf_x.make(sizes, /*data=*/{ 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
221                                   6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 100.0f}),
222       tf_y.make(sizes, /*data=*/{ 0.1f, 1.1f, 2.1f, 3.1f, 4.1f, 5.1f,
223                                   6.1f, 7.1f, 8.1f, 9.1f, 10.1f, 100.1f}),
224       out);
225 
226   // Check that it matches the expected output.
227   EXPECT_TENSOR_CLOSE(
228       out,
229       tf_out.make(
230           sizes, /*data=*/{ 0.1f, 1.1f, 2.1f, 3.1f, 4.1f, 5.1f,
231                             6.1f, 7.1f, 8.1f, 9.1f, 10.1f, 100.1f}));
232   // clang-format on
233 }
234 
235 // Choosing based on condition[i] ? x[i] : y[i]
TEST_F(OpWhereOutTest,MixedTrueFalseTest)236 TEST_F(OpWhereOutTest, MixedTrueFalseTest) {
237   TensorFactory<ScalarType::Bool> tf_condition;
238   TensorFactory<ScalarType::Float> tf_x;
239   TensorFactory<ScalarType::Float> tf_y;
240   TensorFactory<ScalarType::Float> tf_out;
241 
242   const std::vector<int32_t> condition_sizes = {12};
243   const std::vector<int32_t> sizes = {1, 12};
244 
245   // Destination for the where operator.
246   Tensor out = tf_out.zeros(sizes);
247 
248   // clang-format off
249   op_where_self_out(
250       tf_condition.make(condition_sizes, /*data=*/{false, true, false ,true, true, false,
251                                                     false, true, false ,true, true, false}),
252       tf_x.make(sizes, /*data=*/{ 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
253                                   6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 100.0f}),
254       tf_y.make(sizes, /*data=*/{ 0.1f, 1.1f,  2.1f,  3.1f, 4.1f,  5.1f,
255                                   6.1f, 7.1f, 8.1f, 9.1f, 10.1f, 100.1f}),
256       out);
257 
258   // Check that it matches the expected output.
259   EXPECT_TENSOR_CLOSE(
260       out,
261       tf_out.make(
262           sizes, /*data=*/{ 0.1f, 1.0f, 2.1f, 3.0f, 4.0f, 5.1f,
263                             6.1f, 7.0f, 8.1f, 9.0f, 10.0f, 100.1f}));
264   // clang-format on
265 }
266 
267 // Choosing based on condition[i] ? x[i] : y[i]
TEST_F(OpWhereOutTest,BroadcastConditionTest)268 TEST_F(OpWhereOutTest, BroadcastConditionTest) {
269   TensorFactory<ScalarType::Bool> tf_condition;
270   TensorFactory<ScalarType::Float> tf_x;
271   TensorFactory<ScalarType::Float> tf_y;
272   TensorFactory<ScalarType::Float> tf_out;
273 
274   const std::vector<int32_t> condition_sizes = {3, 1};
275   const std::vector<int32_t> x_sizes = {3, 4};
276   const std::vector<int32_t> y_sizes = {3, 4};
277 
278   // Destination for the where operator.
279   Tensor out = tf_out.zeros(x_sizes);
280 
281   // clang-format off
282   op_where_self_out(
283       tf_condition.make(condition_sizes, /*data=*/{
284                                   false,
285                                   true,
286                                   false}),
287       tf_x.make(x_sizes, /*data=*/{
288                                   0.0f, 1.0f, 2.0f, 3.0f,
289                                   4.0f, 5.0f, 6.0f, 7.0f,
290                                   8.0f,  9.0f, 10.0f, 100.0f}),
291       tf_y.make(y_sizes, /*data=*/
292                                   {0.1f, 1.1f, 2.1f, 3.1f,
293                                   4.1f,  5.1f, 6.1f, 7.1f,
294                                   8.1f,  9.1f, 10.1f, 100.1f}),
295       out);
296 
297   // Check that it matches the expected output.
298   EXPECT_TENSOR_CLOSE(
299       out,
300       tf_out.make(
301           x_sizes, /*data=*/{ 0.1f, 1.1f, 2.1f, 3.1f,
302                               4.0f, 5.0f, 6.0f, 7.0f,
303                               8.1f,  9.1f, 10.1f, 100.1f}));
304   // clang-format on
305 }
306 
307 // Choosing based on condition[i] ? x[i] : y[i]
TEST_F(OpWhereOutTest,BroadcastConditionAndBroadCastYTest)308 TEST_F(OpWhereOutTest, BroadcastConditionAndBroadCastYTest) {
309   TensorFactory<ScalarType::Bool> tf_condition;
310   TensorFactory<ScalarType::Float> tf_x;
311   TensorFactory<ScalarType::Float> tf_y;
312   TensorFactory<ScalarType::Float> tf_out;
313 
314   const std::vector<int32_t> condition_sizes = {3, 1};
315   const std::vector<int32_t> x_sizes = {3, 4};
316   const std::vector<int32_t> y_sizes = {3, 1};
317 
318   // Destination for the where operator.
319   Tensor out = tf_out.zeros(x_sizes);
320 
321   // clang-format off
322   op_where_self_out(
323       tf_condition.make(condition_sizes, /*data=*/{
324                                   false,
325                                   true,
326                                   false}),
327       tf_x.make(x_sizes, /*data=*/{
328                                   0.0f, 1.0f, 2.0f, 3.0f,
329                                   4.0f, 5.0f, 6.0f, 7.0f,
330                                   8.0f,  9.0f, 10.0f, 100.0f}),
331       tf_y.make(y_sizes, /*data=*/{
332                                   0.1f,
333                                   4.1f,
334                                   8.1f}),
335       out);
336 
337   // Check that it matches the expected output.
338   EXPECT_TENSOR_CLOSE(
339       out,
340       tf_out.make(
341           x_sizes, /*data=*/{
342                           0.1f, 0.1f, 0.1f, 0.1f,
343                           4.0f, 5.0f, 6.0f, 7.0f,
344                           8.1f, 8.1f, 8.1f, 8.1f}));
345   // clang-format on
346 }
347 
348 // Choosing based on condition[i] ? x[i] : y[i]
TEST_F(OpWhereOutTest,DoubleTypeTest)349 TEST_F(OpWhereOutTest, DoubleTypeTest) {
350   TensorFactory<ScalarType::Bool> tf_condition;
351   TensorFactory<ScalarType::Double> tf_x;
352   TensorFactory<ScalarType::Double> tf_y;
353   TensorFactory<ScalarType::Double> tf_out;
354 
355   const std::vector<int32_t> condition_sizes = {3, 1};
356   const std::vector<int32_t> x_sizes = {3, 4};
357   const std::vector<int32_t> y_sizes = {3, 1};
358 
359   // Destination for the where operator.
360   Tensor out = tf_out.zeros(x_sizes);
361 
362   // clang-format off
363   op_where_self_out(
364       tf_condition.make(condition_sizes, /*data=*/{
365                                   false,
366                                   true,
367                                   false}),
368       tf_x.make(x_sizes, /*data=*/{
369                                   0.0, 1.0, 2.0, 3.0,
370                                   4.0, 5.0, 6.0, 7.0,
371                                   8.0, 9.0, 10.0, 100.0}),
372       tf_y.make(y_sizes, /*data=*/{
373                                   0.1,
374                                   4.1,
375                                   8.1}),
376       out);
377 
378   // Check that it matches the expected output.
379   EXPECT_TENSOR_CLOSE(
380       out,
381       tf_out.make(
382           x_sizes, /*data=*/{
383                           0.1, 0.1, 0.1, 0.1,
384                           4.0, 5.0, 6.0, 7.0,
385                           8.1, 8.1, 8.1, 8.1}));
386   // clang-format on
387 }
388 
389 // Choosing based on condition[i] ? x[i] : y[i]
TEST_F(OpWhereOutTest,MismatchedShapeTest)390 TEST_F(OpWhereOutTest, MismatchedShapeTest) {
391   TensorFactory<ScalarType::Bool> tf_condition;
392   TensorFactory<ScalarType::Float> tf_x;
393   TensorFactory<ScalarType::Double> tf_y;
394   TensorFactory<ScalarType::Double> tf_out;
395 
396   const std::vector<int32_t> condition_sizes = {3, 1};
397   const std::vector<int32_t> x_sizes = {3, 4};
398   const std::vector<int32_t> y_sizes = {4, 1};
399 
400   // Destination for the where operator.
401   Tensor out = tf_out.zeros(x_sizes);
402 
403   // clang-format off
404   ET_EXPECT_KERNEL_FAILURE(context_, op_where_self_out(
405       tf_condition.make(condition_sizes, /*data=*/{
406                                   false,
407                                   true,
408                                   false}),
409       tf_x.make(x_sizes, /*data=*/{
410                                   0.0f, 1.0f, 2.0f, 3.0f,
411                                   4.0f, 5.0f, 6.0f, 7.0f,
412                                   8.0f,  9.0f, 10.0f, 100.0f}),
413       tf_y.make(y_sizes, /*data=*/{
414                                   0.1,
415                                   4.1,
416                                   8.1,
417                                   11.1}),
418       out));
419   // clang-format on
420 }
421 
422 /* %python
423 import torch
424 torch.manual_seed(0)
425 input_shape = (2, 3, 4)
426 condition = torch.randint(10, input_shape) < 5
427 input = torch.rand(input_shape)
428 other = torch.rand(input_shape)
429 expected = torch.where(condition, input, other)
430 
431 where_template = f"""
432   {declare_tensor_factory("ScalarType::Bool", "tfBool")}
433   {declare_tensor_factory("ScalarType::Float", "tf")}
434 
435   {declare_tensor_make_t("condition", "tfBool")}
436   {declare_tensor_make_t("input", "tf")}
437   {declare_tensor_make_t("other", "tf")}
438   {declare_tensor_make_t("expected", "tf")}
439   {declare_tensor_zeros("out_shape, dynamism", "tf", "out")}
440 
441   op_where_self_out(condition, input, other, out);
442   EXPECT_TENSOR_EQ(out, expected);""" */
443 
// The bounded-dynamic output is allocated with exactly the {2, 3, 4} result
// shape.
TEST_F(OpWhereOutTest, DynamicShapeUpperBoundSameAsExpected) {
  const std::vector<int32_t> out_shape = {2, 3, 4};
  test_dynamic_shape(
      out_shape, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
}
448 
// The bounded-dynamic output starts at {10, 10, 10} and must end up with the
// smaller {2, 3, 4} result shape. Skipped when output resizing is not a
// supported feature.
TEST_F(OpWhereOutTest, DynamicShapeUpperBoundLargerThanExpected) {
  const bool can_resize_output =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!can_resize_output) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }
  const std::vector<int32_t> out_shape = {10, 10, 10};
  test_dynamic_shape(
      out_shape, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
}
456 
// The unbounded-dynamic output starts at {1, 1, 1} and must end up with the
// larger {2, 3, 4} result shape. Skipped when output resizing is not a
// supported feature.
TEST_F(OpWhereOutTest, DynamicShapeUnbound) {
  const bool can_resize_output =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!can_resize_output) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }
  const std::vector<int32_t> out_shape = {1, 1, 1};
  test_dynamic_shape(
      out_shape, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
}
464 
// where() on Half tensors with an alternating mask. Even positions take 1.5
// from `a`; odd positions take 2.5 from `b`.
TEST_F(OpWhereOutTest, HalfSupport) {
  TensorFactory<ScalarType::Bool> tf_bool;
  TensorFactory<ScalarType::Half> tf_half;
  const std::vector<int32_t> shape = {2, 3};

  Tensor cond = tf_bool.make(shape, {true, false, true, false, true, false});
  Tensor a = tf_half.full(shape, 1.5);
  Tensor b = tf_half.full(shape, 2.5);
  Tensor out = tf_half.zeros(shape);

  op_where_self_out(cond, a, b, out);

  Tensor expected = tf_half.make(shape, {1.5, 2.5, 1.5, 2.5, 1.5, 2.5});
  EXPECT_TENSOR_CLOSE(out, expected);
}
476