1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15
16 #include <gtest/gtest.h>
17
18 using namespace ::testing;
19 using exec_aten::ScalarType;
20 using exec_aten::Tensor;
21 using torch::executor::testing::SupportedFeatures;
22 using torch::executor::testing::TensorFactory;
23
24 class OpWhereOutTest : public OperatorTest {
25 protected:
op_where_self_out(const Tensor & condition,const Tensor & self,const Tensor & other,Tensor & out)26 Tensor& op_where_self_out(
27 const Tensor& condition,
28 const Tensor& self,
29 const Tensor& other,
30 Tensor& out) {
31 return torch::executor::aten::where_outf(
32 context_, condition, self, other, out);
33 }
34
35 template <ScalarType DTYPE_A, ScalarType DTYPE_B, ScalarType DTYPE_OUT>
test_where()36 void test_where() {
37 if (DTYPE_OUT == ScalarType::Byte || DTYPE_OUT == ScalarType::Char) {
38 return;
39 }
40 TensorFactory<ScalarType::Bool> tf_condition;
41 TensorFactory<ScalarType::Byte> tf_condition_byte;
42 TensorFactory<DTYPE_A> tf_a;
43 TensorFactory<DTYPE_B> tf_b;
44 TensorFactory<DTYPE_OUT> tf_out;
45
46 const std::vector<int32_t> condition_sizes = {12};
47 const std::vector<int32_t> sizes = {1, 12};
48
49 Tensor out = tf_out.zeros(sizes);
50
51 // clang-format off
52 std::vector<uint8_t> condition_data = {
53 false, true, false, true, true, false,
54 false, true, false, true, true, false
55 };
56 const auto a_tensor = tf_a.make(sizes, /*data=*/{ 1, 2, 3, 4, 5, 6, 6, 5, 4, 3, 2, 1});
57 const auto b_tensor = tf_b.make(sizes, /*data=*/{ 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6});
58 // clang-format on
59 op_where_self_out(
60 tf_condition.make(condition_sizes, /*data=*/condition_data),
61 a_tensor,
62 b_tensor,
63 out);
64
65 auto expectedOut =
66 tf_out.make(sizes, /*data=*/{6, 2, 4, 4, 5, 1, 1, 5, 3, 3, 2, 6});
67 // Check that it matches the expected output.
68 EXPECT_TENSOR_CLOSE(out, expectedOut);
69
70 op_where_self_out(
71 tf_condition_byte.make(condition_sizes, condition_data),
72 a_tensor,
73 b_tensor,
74 out);
75 EXPECT_TENSOR_CLOSE(out, expectedOut);
76 }
77
78 template <ScalarType DTYPE_A, ScalarType DTYPE_B>
test_where_enumerate_out_types()79 void test_where_enumerate_out_types() {
80 #define ENUMERATE_TEST_ENTRY(ctype, dtype) \
81 test_where<DTYPE_A, DTYPE_B, ScalarType::dtype>();
82
83 ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY)
84
85 #undef ENUMERATE_TEST_ENTRY
86 }
87
88 template <ScalarType DTYPE_A>
test_where_enumerate_b_types()89 void test_where_enumerate_b_types() {
90 #define ENUMERATE_TEST_ENTRY(ctype, dtype) \
91 test_where<DTYPE_A, ScalarType::dtype, DTYPE_A>();
92
93 ET_FORALL_REALHBBF16_TYPES(ENUMERATE_TEST_ENTRY)
94
95 #undef ENUMERATE_TEST_ENTRY
96 }
97
test_dynamic_shape(const std::vector<int32_t> & out_shape,enum torch::executor::TensorShapeDynamism dynamism)98 void test_dynamic_shape(
99 const std::vector<int32_t>& out_shape,
100 enum torch::executor::TensorShapeDynamism dynamism) {
101 /* %python
102 %rewrite(where_template) */
103
104 TensorFactory<ScalarType::Bool> tfBool;
105 TensorFactory<ScalarType::Float> tf;
106
107 Tensor condition = tfBool.make(
108 {2, 3, 4}, {true, false, true, true, true, false, false, true,
109 false, true, true, false, false, false, false, false,
110 false, false, true, true, false, false, true, true});
111 Tensor input = tf.make(
112 {2, 3, 4},
113 {0.41940832138061523, 0.5529070496559143, 0.9527381062507629,
114 0.036164820194244385, 0.1852310299873352, 0.37341737747192383,
115 0.3051000237464905, 0.9320003986358643, 0.17591017484664917,
116 0.2698335647583008, 0.15067976713180542, 0.03171950578689575,
117 0.20812976360321045, 0.9297990202903748, 0.7231091856956482,
118 0.7423362731933594, 0.5262957811355591, 0.24365824460983276,
119 0.584592342376709, 0.033152639865875244, 0.13871687650680542,
120 0.242235004901886, 0.815468966960907, 0.793160617351532});
121 Tensor other = tf.make(
122 {2, 3, 4},
123 {0.2782524824142456, 0.48195880651474, 0.8197803497314453,
124 0.9970665574073792, 0.6984410881996155, 0.5675464272499084,
125 0.8352431654930115, 0.2055988311767578, 0.593172013759613,
126 0.11234724521636963, 0.1534569263458252, 0.24170821905136108,
127 0.7262365221977234, 0.7010802030563354, 0.2038237452507019,
128 0.6510535478591919, 0.7744860053062439, 0.4368913173675537,
129 0.5190907716751099, 0.6158523559570312, 0.8101882934570312,
130 0.9800970554351807, 0.1146882176399231, 0.3167651295661926});
131 Tensor expected = tf.make(
132 {2, 3, 4},
133 {0.41940832138061523, 0.48195880651474, 0.9527381062507629,
134 0.036164820194244385, 0.1852310299873352, 0.5675464272499084,
135 0.8352431654930115, 0.9320003986358643, 0.593172013759613,
136 0.2698335647583008, 0.15067976713180542, 0.24170821905136108,
137 0.7262365221977234, 0.7010802030563354, 0.2038237452507019,
138 0.6510535478591919, 0.7744860053062439, 0.4368913173675537,
139 0.584592342376709, 0.033152639865875244, 0.8101882934570312,
140 0.9800970554351807, 0.815468966960907, 0.793160617351532});
141 Tensor out = tf.zeros(out_shape, dynamism);
142
143 op_where_self_out(condition, input, other, out);
144 EXPECT_TENSOR_EQ(out, expected);
145 }
146
test_where_enumerate_a_types()147 void test_where_enumerate_a_types() {
148 #define ENUMERATE_TEST_ENTRY(ctype, dtype) \
149 test_where_enumerate_b_types<ScalarType::dtype>();
150
151 ET_FORALL_REALHBBF16_TYPES(ENUMERATE_TEST_ENTRY)
152
153 #undef ENUMERATE_TEST_ENTRY
154 }
155
test_where_enumerate_a_types_aten()156 void test_where_enumerate_a_types_aten() {
157 #define ENUMERATE_TEST_ENTRY(ctype, dtype) \
158 test_where<ScalarType::dtype, ScalarType::dtype, ScalarType::dtype>();
159
160 ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY)
161
162 #undef ENUMERATE_TEST_ENTRY
163 }
164 };
165
166 //
167 // Correctness Test
168 //
169
TEST_F(OpWhereOutTest, AllRealDtypesSupported) {
  // Sweep the fixture's same-dtype enumeration: for each dtype, both inputs
  // and the output share that dtype.
  this->test_where_enumerate_a_types_aten();
}
173
174 // Condition is true, all items will be from x
TEST_F(OpWhereOutTest, AllTrueTest) {
  TensorFactory<ScalarType::Bool> tf_bool;
  TensorFactory<ScalarType::Float> tf_float;

  // A one-element condition is broadcast across all twelve positions.
  const std::vector<int32_t> cond_shape = {1};
  const std::vector<int32_t> shape = {1, 12};

  // clang-format off
  const Tensor cond = tf_bool.make(cond_shape, /*data=*/{true});
  const Tensor x = tf_float.make(shape, /*data=*/{
      0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
      6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 100.0f});
  const Tensor y = tf_float.make(shape, /*data=*/{
      0.1f, 1.1f, 2.1f, 3.1f, 4.1f, 5.1f,
      6.1f, 7.1f, 8.1f, 9.1f, 10.1f, 100.1f});
  // clang-format on

  Tensor result = tf_float.zeros(shape);
  op_where_self_out(cond, x, y, result);

  // Every element must be taken from x.
  EXPECT_TENSOR_CLOSE(result, x);
}
203
204 // Condition is false, all items will be from y
TEST_F(OpWhereOutTest, AllFalseTest) {
  TensorFactory<ScalarType::Bool> tf_bool;
  TensorFactory<ScalarType::Float> tf_float;

  // A one-element condition is broadcast across all twelve positions.
  const std::vector<int32_t> cond_shape = {1};
  const std::vector<int32_t> shape = {1, 12};

  // clang-format off
  const Tensor cond = tf_bool.make(cond_shape, /*data=*/{false});
  const Tensor x = tf_float.make(shape, /*data=*/{
      0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
      6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 100.0f});
  const Tensor y = tf_float.make(shape, /*data=*/{
      0.1f, 1.1f, 2.1f, 3.1f, 4.1f, 5.1f,
      6.1f, 7.1f, 8.1f, 9.1f, 10.1f, 100.1f});
  // clang-format on

  Tensor result = tf_float.zeros(shape);
  op_where_self_out(cond, x, y, result);

  // Every element must be taken from y.
  EXPECT_TENSOR_CLOSE(result, y);
}
234
235 // Choosing based on condition[i] ? x[i] : y[i]
TEST_F(OpWhereOutTest, MixedTrueFalseTest) {
  TensorFactory<ScalarType::Bool> tf_bool;
  TensorFactory<ScalarType::Float> tf_float;

  // The {12} condition broadcasts against the {1, 12} inputs.
  const std::vector<int32_t> cond_shape = {12};
  const std::vector<int32_t> shape = {1, 12};

  // clang-format off
  const Tensor cond = tf_bool.make(cond_shape, /*data=*/{
      false, true, false, true, true, false,
      false, true, false, true, true, false});
  const Tensor x = tf_float.make(shape, /*data=*/{
      0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
      6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 100.0f});
  const Tensor y = tf_float.make(shape, /*data=*/{
      0.1f, 1.1f, 2.1f, 3.1f, 4.1f, 5.1f,
      6.1f, 7.1f, 8.1f, 9.1f, 10.1f, 100.1f});
  // Whole numbers come from x (condition true), .1 values from y.
  const Tensor expected = tf_float.make(shape, /*data=*/{
      0.1f, 1.0f, 2.1f, 3.0f, 4.0f, 5.1f,
      6.1f, 7.0f, 8.1f, 9.0f, 10.0f, 100.1f});
  // clang-format on

  Tensor result = tf_float.zeros(shape);
  op_where_self_out(cond, x, y, result);
  EXPECT_TENSOR_CLOSE(result, expected);
}
266
// A {3, 1} condition is broadcast across each row: row i comes from x if condition[i] is true, otherwise from y.
TEST_F(OpWhereOutTest, BroadcastConditionTest) {
  TensorFactory<ScalarType::Bool> tf_bool;
  TensorFactory<ScalarType::Float> tf_float;

  const std::vector<int32_t> cond_shape = {3, 1};
  const std::vector<int32_t> x_shape = {3, 4};
  const std::vector<int32_t> y_shape = {3, 4};

  // clang-format off
  // One flag per row; each flag covers all four columns.
  const Tensor cond = tf_bool.make(cond_shape, /*data=*/{false, true, false});
  const Tensor x = tf_float.make(x_shape, /*data=*/{
      0.0f, 1.0f,  2.0f,   3.0f,
      4.0f, 5.0f,  6.0f,   7.0f,
      8.0f, 9.0f, 10.0f, 100.0f});
  const Tensor y = tf_float.make(y_shape, /*data=*/{
      0.1f, 1.1f,  2.1f,   3.1f,
      4.1f, 5.1f,  6.1f,   7.1f,
      8.1f, 9.1f, 10.1f, 100.1f});
  // Rows 0 and 2 come from y; row 1 comes from x.
  const Tensor expected = tf_float.make(x_shape, /*data=*/{
      0.1f, 1.1f,  2.1f,   3.1f,
      4.0f, 5.0f,  6.0f,   7.0f,
      8.1f, 9.1f, 10.1f, 100.1f});
  // clang-format on

  Tensor result = tf_float.zeros(x_shape);
  op_where_self_out(cond, x, y, result);
  EXPECT_TENSOR_CLOSE(result, expected);
}
306
// Both the {3, 1} condition and the {3, 1} y are broadcast across the columns of the {3, 4} x.
TEST_F(OpWhereOutTest, BroadcastConditionAndBroadCastYTest) {
  TensorFactory<ScalarType::Bool> tf_bool;
  TensorFactory<ScalarType::Float> tf_float;

  const std::vector<int32_t> cond_shape = {3, 1};
  const std::vector<int32_t> x_shape = {3, 4};
  const std::vector<int32_t> y_shape = {3, 1};

  // clang-format off
  const Tensor cond = tf_bool.make(cond_shape, /*data=*/{false, true, false});
  const Tensor x = tf_float.make(x_shape, /*data=*/{
      0.0f, 1.0f,  2.0f,   3.0f,
      4.0f, 5.0f,  6.0f,   7.0f,
      8.0f, 9.0f, 10.0f, 100.0f});
  // One y value per row, repeated across the columns.
  const Tensor y = tf_float.make(y_shape, /*data=*/{0.1f, 4.1f, 8.1f});
  const Tensor expected = tf_float.make(x_shape, /*data=*/{
      0.1f, 0.1f, 0.1f, 0.1f,
      4.0f, 5.0f, 6.0f, 7.0f,
      8.1f, 8.1f, 8.1f, 8.1f});
  // clang-format on

  Tensor result = tf_float.zeros(x_shape);
  op_where_self_out(cond, x, y, result);
  EXPECT_TENSOR_CLOSE(result, expected);
}
347
// Same broadcasting scenario as BroadcastConditionAndBroadCastYTest, with Double inputs and output.
TEST_F(OpWhereOutTest, DoubleTypeTest) {
  TensorFactory<ScalarType::Bool> tf_bool;
  TensorFactory<ScalarType::Double> tf_double;

  const std::vector<int32_t> cond_shape = {3, 1};
  const std::vector<int32_t> x_shape = {3, 4};
  const std::vector<int32_t> y_shape = {3, 1};

  // clang-format off
  const Tensor cond = tf_bool.make(cond_shape, /*data=*/{false, true, false});
  const Tensor x = tf_double.make(x_shape, /*data=*/{
      0.0, 1.0,  2.0,   3.0,
      4.0, 5.0,  6.0,   7.0,
      8.0, 9.0, 10.0, 100.0});
  const Tensor y = tf_double.make(y_shape, /*data=*/{0.1, 4.1, 8.1});
  const Tensor expected = tf_double.make(x_shape, /*data=*/{
      0.1, 0.1, 0.1, 0.1,
      4.0, 5.0, 6.0, 7.0,
      8.1, 8.1, 8.1, 8.1});
  // clang-format on

  Tensor result = tf_double.zeros(x_shape);
  op_where_self_out(cond, x, y, result);
  EXPECT_TENSOR_CLOSE(result, expected);
}
388
// Shapes that cannot be broadcast together (condition {3, 1}, x {3, 4}, y {4, 1}) must make the kernel fail.
TEST_F(OpWhereOutTest, MismatchedShapeTest) {
  TensorFactory<ScalarType::Bool> tf_bool;
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Double> tf_double;

  const std::vector<int32_t> cond_shape = {3, 1};
  const std::vector<int32_t> x_shape = {3, 4};
  const std::vector<int32_t> y_shape = {4, 1};

  // clang-format off
  const Tensor cond = tf_bool.make(cond_shape, /*data=*/{false, true, false});
  const Tensor x = tf_float.make(x_shape, /*data=*/{
      0.0f, 1.0f,  2.0f,   3.0f,
      4.0f, 5.0f,  6.0f,   7.0f,
      8.0f, 9.0f, 10.0f, 100.0f});
  // y has 4 rows, which cannot broadcast against the 3 rows of x/cond.
  const Tensor y = tf_double.make(y_shape, /*data=*/{0.1, 4.1, 8.1, 11.1});
  // clang-format on

  Tensor result = tf_double.zeros(x_shape);
  ET_EXPECT_KERNEL_FAILURE(context_, op_where_self_out(cond, x, y, result));
}
421
422 /* %python
423 import torch
424 torch.manual_seed(0)
425 input_shape = (2, 3, 4)
426 condition = torch.randint(10, input_shape) < 5
427 input = torch.rand(input_shape)
428 other = torch.rand(input_shape)
429 expected = torch.where(condition, input, other)
430
431 where_template = f"""
432 {declare_tensor_factory("ScalarType::Bool", "tfBool")}
433 {declare_tensor_factory("ScalarType::Float", "tf")}
434
435 {declare_tensor_make_t("condition", "tfBool")}
436 {declare_tensor_make_t("input", "tf")}
437 {declare_tensor_make_t("other", "tf")}
438 {declare_tensor_make_t("expected", "tf")}
439 {declare_tensor_zeros("out_shape, dynamism", "tf", "out")}
440
441 op_where_self_out(condition, input, other, out);
442 EXPECT_TENSOR_EQ(out, expected);""" */
443
TEST_F(OpWhereOutTest, DynamicShapeUpperBoundSameAsExpected) {
  // The output's bound already equals the {2, 3, 4} result shape.
  const std::vector<int32_t> out_shape = {2, 3, 4};
  const auto dynamism = torch::executor::TensorShapeDynamism::DYNAMIC_BOUND;
  test_dynamic_shape(out_shape, dynamism);
}
448
TEST_F(OpWhereOutTest, DynamicShapeUpperBoundLargerThanExpected) {
  const bool can_resize =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!can_resize) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }
  // The output starts larger than the {2, 3, 4} result; test_dynamic_shape
  // compares against a {2, 3, 4} expected tensor.
  const std::vector<int32_t> out_shape = {10, 10, 10};
  const auto dynamism = torch::executor::TensorShapeDynamism::DYNAMIC_BOUND;
  test_dynamic_shape(out_shape, dynamism);
}
456
TEST_F(OpWhereOutTest, DynamicShapeUnbound) {
  const bool can_resize =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!can_resize) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }
  // The output starts smaller than the {2, 3, 4} result and is unbounded.
  const std::vector<int32_t> out_shape = {1, 1, 1};
  const auto dynamism = torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND;
  test_dynamic_shape(out_shape, dynamism);
}
464
TEST_F(OpWhereOutTest, HalfSupport) {
  TensorFactory<ScalarType::Bool> tf_bool;
  TensorFactory<ScalarType::Half> tf_half;
  const std::vector<int32_t> shape = {2, 3};

  const Tensor cond =
      tf_bool.make(shape, {true, false, true, false, true, false});
  const Tensor when_true = tf_half.full(shape, 1.5);
  const Tensor when_false = tf_half.full(shape, 2.5);

  Tensor picked = tf_half.zeros(shape);
  op_where_self_out(cond, when_true, when_false, picked);

  // Alternating condition selects alternating fill values.
  EXPECT_TENSOR_CLOSE(
      picked, tf_half.make(shape, {1.5, 2.5, 1.5, 2.5, 1.5, 2.5}));
}
476