xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_where.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
10 #include <executorch/runtime/kernel/kernel_includes.h>
11 
12 namespace torch {
13 namespace executor {
14 namespace native {
15 
where_out(KernelRuntimeContext & ctx,const Tensor & cond,const Tensor & a,const Tensor & b,Tensor & out)16 Tensor& where_out(
17     KernelRuntimeContext& ctx,
18     const Tensor& cond,
19     const Tensor& a,
20     const Tensor& b,
21     Tensor& out) {
22   // Common Dtype
23   ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
24 
25   // Check Common Dtype
26   ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out);
27 
28   // Check Dim Order
29   ET_KERNEL_CHECK(
30       ctx, tensors_have_same_dim_order(cond, a, b, out), InvalidArgument, out);
31 
32   // Resize
33   ET_KERNEL_CHECK(
34       ctx,
35       resize_to_broadcast_target_size(a, b, cond, out) == Error::Ok,
36       InvalidArgument,
37       out);
38 
39   // Compute Dtype
40   ScalarType compute_type = utils::get_compute_type(common_type);
41 
42   // @lint-ignore CLANGTIDY facebook-hte-CArray
43   static constexpr const char op_name[] = "where.self_out";
44 
45   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
46     utils::apply_tritensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
47         [](const CTYPE_COMPUTE val_a,
48            const CTYPE_COMPUTE val_b,
49            const CTYPE_COMPUTE val_c) { return val_c ? val_a : val_b; },
50         ctx,
51         a,
52         utils::SupportedTensorDtypes::REALHBBF16,
53         b,
54         utils::SupportedTensorDtypes::REALHBBF16,
55         cond,
56         utils::SupportedTensorDtypes::BOOL_OR_BYTE,
57         out,
58         utils::SupportedTensorDtypes::SAME_AS_COMMON);
59   });
60 
61   return out;
62 }
63 
64 } // namespace native
65 } // namespace executor
66 } // namespace torch
67