/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

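// Optimized CPU kernels for the "add.out" and "add.Scalar_out" operators.
// Both compute out = a + alpha * b; the tensor-tensor variant chooses between
// several vectorized fast paths and a portable element-wise fallback based on
// the input dtypes and broadcasting pattern.
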
#include <executorch/kernels/optimized/cpu/binary_ops.h>
#include <executorch/kernels/optimized/vec/functional.h>
#include <executorch/kernels/optimized/vec/vec.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

namespace torch {
namespace executor {
namespace native {
namespace {

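// AddInner dispatches on whether the promoted compute type CTYPE_IN can be
// cast to the output type CTYPE_OUT. The `true` specialization performs the
// element-wise out = a + alpha * b; the `false` specialization only triggers
// if the canCast() check in opt_add_out was skipped, which would be a bug.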
template <
    bool can_cast,
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct AddInner;

template <
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
  static void
  run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
          CTYPE_IN value = a_casted + alpha_val * b_casted;

          return static_cast<CTYPE_OUT>(value);
        },
        a,
        b,
        out);
  }
};

template <typename CTYPE_IN>
struct ReportCanCastBug {
  static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
  }
};

template <
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct AddInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
    : public ReportCanCastBug<CTYPE_IN> {};

} // namespace

using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;

Tensor& opt_add_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = b.scalar_type();
  ScalarType out_type = out.scalar_type();

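  // Fast path: b holds a single element of the same (non-reduced-precision)
  // dtype as a and out, so the whole addition reduces to a vectorized map of
  // x -> x + alpha * b over a's data.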
  if (b.numel() == 1) {
    if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
        a_type != ScalarType::BFloat16) {
      ET_KERNEL_CHECK(
          ctx,
          resize_to_broadcast_target_size(a, b, out) == Error::Ok,
          InvalidArgument,
          out);

      ET_SWITCH_REALB_TYPES(a_type, ctx, "add.out", CTYPE, [&]() {
        ET_SWITCH_REALB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
          CTYPE alpha_val;
          ET_KERNEL_CHECK(
              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
          CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
          CTYPE b_casted = static_cast<CTYPE>(b_val);

          using Vec = executorch::vec::Vectorized<CTYPE>;
          executorch::vec::map<CTYPE>(
              [alpha_val, b_casted](Vec x) {
                return x + Vec(alpha_val * b_casted);
              },
              out.mutable_data_ptr<CTYPE>(),
              a.const_data_ptr<CTYPE>(),
              out.numel());
        });
      });
      return out;
    }
  } else if (a.numel() == 1) {
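    // Addition is commutative: when a is the single-element operand instead,
    // swap the arguments and reuse the fast path above.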
    return opt_add_out(ctx, b, a, alpha, out);
  }

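  // Ask select_optimized_path() which vectorized kernel, if any, fits the
  // shapes: same-shape inputs treated as one flat 1-d buffer, or a 2-d tensor
  // broadcast against a 1-d tensor (in either argument order).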
  auto selected_optimized_path = select_optimized_path(a, b, out);
  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
    // Resize for dynamic shape
    auto error = resize_tensor(out, a.sizes());
    ET_KERNEL_CHECK_MSG(
        ctx,
        error == Error::Ok,
        InvalidArgument,
        out,
        "Failed to resize output tensor.");

    ET_SWITCH_REALB_TYPES(a_type, ctx, "add.out", CTYPE, [&]() {
      CTYPE alpha_val;
      ET_KERNEL_CHECK(
          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

      using Vec = executorch::vec::Vectorized<CTYPE>;
      executorch::vec::map2<CTYPE>(
          [alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
          out.mutable_data_ptr<CTYPE>(),
          a.const_data_ptr<CTYPE>(),
          b.const_data_ptr<CTYPE>(),
          out.numel());
    });
  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
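    // 2-d by 1-d broadcast: pick which operand is the 2-d `lhs` (the
    // "ReverseArguments" variant means b is the 2-d tensor), then add the 1-d
    // `rhs` across its rows with a vectorized broadcasting map.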
    const Tensor* lhs;
    const Tensor* rhs;
    if (selected_optimized_path ==
        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
      lhs = &b;
      rhs = &a;
    } else {
      // Catch failure to update logic when adding new broadcasting possibility.
      ET_DCHECK(
          selected_optimized_path ==
          ElementwiseOptimizedPath::kBroadcast2dBy1d);
      lhs = &a;
      rhs = &b;
    }
    auto error = resize_tensor(out, lhs->sizes());
    ET_KERNEL_CHECK_MSG(
        ctx,
        error == Error::Ok,
        InvalidArgument,
        out,
        "Failed to resize output tensor.");
    ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
      CTYPE alpha_val;
      ET_KERNEL_CHECK(
          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

      using Vec = executorch::vec::Vectorized<CTYPE>;
      executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
          [alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
          out.mutable_data_ptr<CTYPE>(),
          lhs->const_data_ptr<CTYPE>(),
          rhs->const_data_ptr<CTYPE>(),
          lhs->sizes()[lhs->dim() - 2],
          lhs->sizes()[lhs->dim() - 1]);
    });
  } else {
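    // Fallback for all remaining cases: promote a and b to a common compute
    // type, verify that it can be cast to out's dtype, then run the
    // element-wise AddInner loop over the broadcasted inputs.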
    ScalarType common_type =
        promoteTypes(a_type, b_type, /*half_to_float*/ true);
    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);

    ET_KERNEL_CHECK(
        ctx,
        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
        InvalidArgument,
        out);

    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
      ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
        using CTYPE_IN = typename torch::executor::
            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
        ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
          CTYPE_IN alpha_val;
          ET_KERNEL_CHECK(
              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

          AddInner<
              can_cast<CTYPE_IN, CTYPE_OUT>::value,
              CTYPE_A,
              CTYPE_B,
              CTYPE_IN,
              CTYPE_OUT>::run(a, b, alpha_val, out);
        });
      });
    });
  }

  return out;
}

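// Computes out = a + alpha * b, where b is a Scalar rather than a tensor.
// The output dtype must already match the type produced by promoting a's
// dtype with b.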
Tensor& opt_add_scalar_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    const Scalar& alpha,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = utils::get_scalar_dtype(b);
  ScalarType common_type =
      utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
  ScalarType out_type = out.scalar_type();

  ET_CHECK(common_type == out_type);

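  // Reduced-precision float inputs are computed in float.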
  if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
    common_type = ScalarType::Float;
  }

  // Resize for dynamic shape
  auto error = resize_tensor(out, a.sizes());
  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");

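  // Fast path: no type promotion is needed, so alpha * b can be folded into
  // the lambda and the whole tensor handled by a single vectorized map.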
  if (a_type == common_type && a_type == out_type &&
      a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
    ET_SWITCH_REALB_TYPES(a_type, ctx, "add.Scalar_out", CTYPE, [&]() {
      ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
        CTYPE_B b_val;
        ET_EXTRACT_SCALAR(b, b_val);
        CTYPE b_casted = static_cast<CTYPE>(b_val);
        CTYPE alpha_val;
        ET_EXTRACT_SCALAR(alpha, alpha_val);

        using Vec = executorch::vec::Vectorized<CTYPE>;
        executorch::vec::map<CTYPE>(
            [alpha_val, b_casted](Vec x) {
              return x + Vec(alpha_val * b_casted);
            },
            out.mutable_data_ptr<CTYPE>(),
            a.const_data_ptr<CTYPE>(),
            out.numel());
      });
    });
  } else {
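    // Slow path: cast each element of a and both scalars to the promoted
    // compute type, add, and cast the result to the output dtype one element
    // at a time.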
    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() {
      ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
        ET_SWITCH_REALB_TYPES(
            common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() {
              ET_SWITCH_REALHBBF16_TYPES(
                  out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() {
                    CTYPE_B b_val;
                    ET_EXTRACT_SCALAR(b, b_val);
                    CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
                    CTYPE_IN alpha_val;
                    ET_EXTRACT_SCALAR(alpha, alpha_val);

                    const size_t n = a.numel();
                    const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
                    CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
                    for (auto i = 0; i < n; ++i) {
                      out_data[i] = static_cast<CTYPE_OUT>(
                          static_cast<CTYPE_IN>(a_data[i]) +
                          alpha_val * b_casted);
                    }
                  });
            });
      });
    });
  }

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch