/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/functional_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

using exec_aten::ScalarType;
using exec_aten::Tensor;
using executorch::aten::RuntimeContext;
using torch::executor::Error;

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {

Tensor& where_out(
    RuntimeContext& ctx,
    const Tensor& cond,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  ScalarType cond_type = cond.scalar_type();
  ScalarType a_type = a.scalar_type();
  ScalarType b_type = b.scalar_type();
  ScalarType common_type = executorch::runtime::promoteTypes(a_type, b_type);
  ScalarType out_type = out.scalar_type();

  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);

  // Determine the output size and resize for dynamic shapes.
  ET_KERNEL_CHECK(
      ctx,
      torch::executor::resize_to_broadcast_target_size(a, b, cond, out) ==
          Error::Ok,
      InvalidArgument,
      out);

  // NNLib kernels handle at most 4 dimensions; broadcasts in higher ranks
  // fall back to the portable implementation below.
  constexpr int kNnlibMaxDim = 4;
  constexpr auto name = "where.self_out";

  ET_CHECK_MSG(
      cond_type == ScalarType::Bool || cond_type == ScalarType::Byte,
      "Unhandled dtype %s for where.self_out",
      torch::executor::toString(cond_type));

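  // Gather tensor ranks; the NNLib fast path below requires every tensor
  // to be non-scalar.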
  int a_dim = a.dim(), b_dim = b.dim(), con_dim = cond.dim();
  bool optimized = true;
  // Detect whether any input must be broadcast to the output shape.
  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
  const bool cond_is_broadcasted = !out.sizes().equals(cond.sizes());
  const bool broadcast =
      (a_is_broadcasted || b_is_broadcasted || cond_is_broadcasted);

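  // The NNLib kernels operate on shapes padded to kNnlibMaxDim, so track
  // the largest rank involved.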
  int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
  max_dim = cond.dim() > max_dim ? cond.dim() : max_dim;
  max_dim = out.dim() > max_dim ? out.dim() : max_dim;

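  // Use the fast path only for float inputs, non-scalar tensors, and
  // broadcasts that fit within kNnlibMaxDim dimensions.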
  if ((a_type != ScalarType::Float) || (b_type != ScalarType::Float))
    optimized = false;

  if ((a_dim == 0) || (b_dim == 0) || (con_dim == 0))
    optimized = false;

  if (broadcast && (max_dim > kNnlibMaxDim))
    optimized = false;

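  // Optimized path: dispatch directly to the HiFi NNLib where kernels on
  // raw float data.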
  if (optimized) {
    const float* a_data = a.const_data_ptr<float>();
    const float* b_data = b.const_data_ptr<float>();
    float* out_data = out.mutable_data_ptr<float>();
    const unsigned char* con = cond.const_data_ptr<uint8_t>();

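    // With broadcasting, pad every shape out to kNnlibMaxDim entries so
    // the fixed-rank 4-D NNLib kernel can be used.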
    if (broadcast) {
      int out_shape[kNnlibMaxDim];
      int inp1_shape[kNnlibMaxDim];
      int inp2_shape[kNnlibMaxDim];
      int con_shape[kNnlibMaxDim];

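      // Initialize every dimension to 1 so that missing leading dimensions
      // broadcast naturally.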
      for (int i = 0; i < kNnlibMaxDim; i++) {
        con_shape[i] = 1;
        out_shape[i] = 1;
        inp1_shape[i] = 1;
        inp2_shape[i] = 1;
      }

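      // Right-align each tensor's sizes within its 4-D shape array.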
      int off_o = kNnlibMaxDim - out.dim();
      int off_a = kNnlibMaxDim - a.dim();
      int off_b = kNnlibMaxDim - b.dim();
      int off_c = kNnlibMaxDim - cond.dim();

      for (int i = 0; i < out.dim(); i++)
        out_shape[i + off_o] = out.size(i);
      for (int i = 0; i < a.dim(); i++)
        inp1_shape[i + off_a] = a.size(i);
      for (int i = 0; i < b.dim(); i++)
        inp2_shape[i + off_b] = b.size(i);
      for (int i = 0; i < cond.dim(); i++)
        con_shape[i + off_c] = cond.size(i);

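      // If the padded condition shape differs from the output shape,
      // broadcast the condition into a scratch buffer first so the 4-D
      // kernel sees matching shapes.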
      if (con_shape[0] != out_shape[0] || con_shape[1] != out_shape[1] ||
          con_shape[2] != out_shape[2] || con_shape[3] != out_shape[3]) {
        void* p_scratch =
            malloc(out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3]);
        // Guard against allocation failure before writing the broadcast
        // condition into the scratch buffer.
        ET_KERNEL_CHECK(
            ctx, p_scratch != nullptr, MemoryAllocationFailed, out);
        unsigned char* p_brd_cond = static_cast<unsigned char*>(p_scratch);
        xa_nn_broadcast_8_8(
            reinterpret_cast<WORD8*>(p_brd_cond),
            out_shape,
            reinterpret_cast<const WORD8*>(con),
            con_shape,
            kNnlibMaxDim);

        for (int i = 0; i < kNnlibMaxDim; i++) {
          con_shape[i] = out_shape[i];
        }
        xa_nn_elm_where_broadcast_4D_f32xf32_f32(
            out_data,
            out_shape,
            a_data,
            inp1_shape,
            b_data,
            inp2_shape,
            p_brd_cond,
            con_shape);
        free(p_scratch);
      } else {
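        // The condition already matches the output shape; use it directly.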
        xa_nn_elm_where_broadcast_4D_f32xf32_f32(
            out_data,
            out_shape,
            a_data,
            inp1_shape,
            b_data,
            inp2_shape,
            con,
            con_shape);
      }
    } else {
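      // No broadcasting anywhere: select elementwise over the flat data.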
      xa_nn_elm_where_f32xf32_f32(out_data, a_data, b_data, con, out.numel());
    }
    return out;
  }
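
  // Portable fallback: promote the input dtypes and apply the ternary
  // select elementwise, with full broadcast support.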
  ET_SWITCH_REALHB_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
    ET_SWITCH_REALHB_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
      using CTYPE_OUT =
          typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
      torch::executor::
          apply_ternary_elementwise_fn<CTYPE_A, CTYPE_B, uint8_t, CTYPE_OUT>(
              [](const CTYPE_A val_a,
                 const CTYPE_B val_b,
                 const uint8_t val_c) {
                CTYPE_OUT a_casted = static_cast<CTYPE_OUT>(val_a);
                CTYPE_OUT b_casted = static_cast<CTYPE_OUT>(val_b);
                return val_c ? a_casted : b_casted;
              },
              a,
              b,
              cond,
              out);
    });
  });
  return out;
}

} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence