/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstdlib> // malloc/free for the condition-broadcast scratch buffer

#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/functional_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

using exec_aten::ScalarType;
using exec_aten::Tensor;
using executorch::aten::RuntimeContext;
using torch::executor::Error;

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {

Tensor& where_out(
    RuntimeContext& ctx,
    const Tensor& cond,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
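  // Promote the two value inputs' dtypes; the promoted type must match
  // the output dtype.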
  ScalarType cond_type = cond.scalar_type();
  ScalarType a_type = a.scalar_type();
  ScalarType b_type = b.scalar_type();
  ScalarType common_type = executorch::runtime::promoteTypes(a_type, b_type);
  ScalarType out_type = out.scalar_type();

  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);

  // Determine output size and resize for dynamic shapes.
  ET_KERNEL_CHECK(
      ctx,
      torch::executor::resize_to_broadcast_target_size(a, b, cond, out) ==
          Error::Ok,
      InvalidArgument,
      out);

  // NNLib kernels handle at most 4-D shapes; broadcast cases with higher
  // rank fall back to the portable implementation below.
  constexpr int kNnlibMaxDim = 4;
  constexpr auto name = "where.self_out";

  ET_CHECK_MSG(
      cond_type == ScalarType::Bool || cond_type == ScalarType::Byte,
      "Unhandled dtype %s for where.self_out",
      torch::executor::toString(cond_type));
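
  // The optimized NNLib path requires float value inputs, non-scalar
  // (rank >= 1) tensors, and, when broadcasting, a rank of at most
  // kNnlibMaxDim.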
  int a_dim = a.dim(), b_dim = b.dim(), con_dim = cond.dim(),
      out_dim = out.dim();
  bool optimized = true;

  // Detect whether any input must be broadcast to the output shape.
  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
  const bool cond_is_broadcasted = !out.sizes().equals(cond.sizes());
  const bool broadcast =
      (a_is_broadcasted || b_is_broadcasted || cond_is_broadcasted);

  int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
  max_dim = cond.dim() > max_dim ? cond.dim() : max_dim;
  max_dim = out.dim() > max_dim ? out.dim() : max_dim;

  if ((a_type != ScalarType::Float) || (b_type != ScalarType::Float))
    optimized = false;

  if ((a_dim == 0) || (b_dim == 0) || (con_dim == 0))
    optimized = false;

  if (broadcast && (max_dim > kNnlibMaxDim))
    optimized = false;

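  // Fast path: hand the float data directly to the vendor NNLib kernels.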
  if (optimized) {
    const float* a_data = a.const_data_ptr<float>();
    const float* b_data = b.const_data_ptr<float>();
    float* out_data = out.mutable_data_ptr<float>();
    const unsigned char* con = cond.const_data_ptr<uint8_t>();

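    // Broadcast case: NNLib's 4-D where kernel expects every shape padded
    // to exactly kNnlibMaxDim dimensions.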
    if (broadcast) {
      int out_shape[kNnlibMaxDim];
      int inp1_shape[kNnlibMaxDim];
      int inp2_shape[kNnlibMaxDim];
      int con_shape[kNnlibMaxDim];

      for (int i = 0; i < kNnlibMaxDim; i++) {
        con_shape[i] = 1;
        out_shape[i] = 1;
        inp1_shape[i] = 1;
        inp2_shape[i] = 1;
      }

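      // Right-align each tensor's sizes within the padded 4-D shapes,
      // leaving the leading dimensions as 1.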
      int off_o = kNnlibMaxDim - out.dim();
      int off_a = kNnlibMaxDim - a.dim();
      int off_b = kNnlibMaxDim - b.dim();
      int off_c = kNnlibMaxDim - cond.dim();

      for (int i = 0; i < out.dim(); i++)
        out_shape[i + off_o] = out.size(i);
      for (int i = 0; i < a.dim(); i++)
        inp1_shape[i + off_a] = a.size(i);
      for (int i = 0; i < b.dim(); i++)
        inp2_shape[i + off_b] = b.size(i);
      for (int i = 0; i < cond.dim(); i++)
        con_shape[i + off_c] = cond.size(i);

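      // The NNLib where kernel broadcasts only the value inputs, so a
      // condition tensor that is itself broadcast must first be expanded
      // to the full output shape in a scratch buffer.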
      if (con_shape[0] != out_shape[0] || con_shape[1] != out_shape[1] ||
          con_shape[2] != out_shape[2] || con_shape[3] != out_shape[3]) {
        // One byte per condition element at the output shape.
        void* p_scratch =
            malloc(out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3]);
        // Defensive: bail out if the scratch allocation failed.
        ET_KERNEL_CHECK(
            ctx, p_scratch != nullptr, MemoryAllocationFailed, out);
        const unsigned char* p_brd_cond = (const unsigned char*)p_scratch;
        xa_nn_broadcast_8_8(
            (WORD8* __restrict__)p_brd_cond,
            out_shape,
            (const WORD8* __restrict__)con,
            con_shape,
            kNnlibMaxDim);

        for (int i = 0; i < kNnlibMaxDim; i++) {
          con_shape[i] = out_shape[i];
        }
        xa_nn_elm_where_broadcast_4D_f32xf32_f32(
            out_data,
            out_shape,
            a_data,
            inp1_shape,
            b_data,
            inp2_shape,
            p_brd_cond,
            con_shape);
        free(p_scratch);
      } else {
        xa_nn_elm_where_broadcast_4D_f32xf32_f32(
            out_data,
            out_shape,
            a_data,
            inp1_shape,
            b_data,
            inp2_shape,
            con,
            con_shape);
      }
    } else {
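      // No broadcasting: a flat elementwise select over out.numel() elements.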
      xa_nn_elm_where_f32xf32_f32(out_data, a_data, b_data, con, out.numel());
    }
    return out;
  }
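
  // Portable fallback: promote the input dtypes and apply the select
  // elementwise; apply_ternary_elementwise_fn handles any broadcasting.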
  ET_SWITCH_REALHB_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
    ET_SWITCH_REALHB_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
      using CTYPE_OUT =
          typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
      torch::executor::
          apply_ternary_elementwise_fn<CTYPE_A, CTYPE_B, uint8_t, CTYPE_OUT>(
              [](const CTYPE_A val_a,
                 const CTYPE_B val_b,
                 const uint8_t val_c) {
                CTYPE_OUT a_casted = static_cast<CTYPE_OUT>(val_a);
                CTYPE_OUT b_casted = static_cast<CTYPE_OUT>(val_b);
                return val_c ? a_casted : b_casted;
              },
              a,
              b,
              cond,
              out);
    });
  });
  return out;
}

} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence