1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
10 #include <executorch/runtime/kernel/kernel_includes.h>
11
12 namespace torch {
13 namespace executor {
14 namespace native {
15
where_out(KernelRuntimeContext & ctx,const Tensor & cond,const Tensor & a,const Tensor & b,Tensor & out)16 Tensor& where_out(
17 KernelRuntimeContext& ctx,
18 const Tensor& cond,
19 const Tensor& a,
20 const Tensor& b,
21 Tensor& out) {
22 // Common Dtype
23 ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
24
25 // Check Common Dtype
26 ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out);
27
28 // Check Dim Order
29 ET_KERNEL_CHECK(
30 ctx, tensors_have_same_dim_order(cond, a, b, out), InvalidArgument, out);
31
32 // Resize
33 ET_KERNEL_CHECK(
34 ctx,
35 resize_to_broadcast_target_size(a, b, cond, out) == Error::Ok,
36 InvalidArgument,
37 out);
38
39 // Compute Dtype
40 ScalarType compute_type = utils::get_compute_type(common_type);
41
42 // @lint-ignore CLANGTIDY facebook-hte-CArray
43 static constexpr const char op_name[] = "where.self_out";
44
45 ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
46 utils::apply_tritensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
47 [](const CTYPE_COMPUTE val_a,
48 const CTYPE_COMPUTE val_b,
49 const CTYPE_COMPUTE val_c) { return val_c ? val_a : val_b; },
50 ctx,
51 a,
52 utils::SupportedTensorDtypes::REALHBBF16,
53 b,
54 utils::SupportedTensorDtypes::REALHBBF16,
55 cond,
56 utils::SupportedTensorDtypes::BOOL_OR_BYTE,
57 out,
58 utils::SupportedTensorDtypes::SAME_AS_COMMON);
59 });
60
61 return out;
62 }
63
64 } // namespace native
65 } // namespace executor
66 } // namespace torch
67