xref: /aosp_15_r20/external/executorch/kernels/optimized/cpu/op_add.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/optimized/cpu/binary_ops.h>
#include <executorch/kernels/optimized/vec/functional.h>
#include <executorch/kernels/optimized/vec/vec.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

namespace torch {
namespace executor {
namespace native {
namespace {

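// AddInner dispatches at compile time on `can_cast`: when the computation
// type is castable to the output type it performs the actual addition;
// otherwise it routes to ReportCanCastBug below, which flags a
// kernel-selection bug in debug builds.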
template <
    bool can_cast,
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct AddInner;

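// Castable case: compute a + alpha * b element by element, promoting both
// inputs to CTYPE_IN and casting the result to CTYPE_OUT.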
template <
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
  static void
  run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
          CTYPE_IN value = a_casted + alpha_val * b_casted;

          return static_cast<CTYPE_OUT>(value);
        },
        a,
        b,
        out);
  }
};

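// Non-castable case: reaching this at runtime means the canCast() check in
// opt_add_out was bypassed, so fail loudly in debug builds.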
template <typename CTYPE_IN>
struct ReportCanCastBug {
  static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
  }
};

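// Route the non-castable specialization to the debug-only trap above.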
template <
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct AddInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
    : public ReportCanCastBug<CTYPE_IN> {};

} // namespace

using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;

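// Computes out = a + alpha * b, picking a vectorized fast path when the
// dtypes and shapes allow it and falling back to a generic elementwise
// loop otherwise.
//
// Hedged usage sketch (the tensors and context here are assumptions, not
// part of this file):
//
//   KernelRuntimeContext ctx;
//   // a, b: input Tensors; out: Tensor resizable to their broadcast shape.
//   opt_add_out(ctx, a, b, /*alpha=*/1, out);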
Tensor& opt_add_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = b.scalar_type();
  ScalarType out_type = out.scalar_type();

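  // Fast path: b holds a single element and all dtypes match (and are
  // neither Half nor BFloat16), so the addition reduces to a vectorized
  // out = a + (alpha * b) with a scalar broadcast.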
  if (b.numel() == 1) {
    if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
        a_type != ScalarType::BFloat16) {
      ET_KERNEL_CHECK(
          ctx,
          resize_to_broadcast_target_size(a, b, out) == Error::Ok,
          InvalidArgument,
          out);

      ET_SWITCH_REALB_TYPES(a_type, ctx, "add.out", CTYPE, [&]() {
        ET_SWITCH_REALB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
          CTYPE alpha_val;
          ET_KERNEL_CHECK(
              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
          CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
          CTYPE b_casted = static_cast<CTYPE>(b_val);

          using Vec = executorch::vec::Vectorized<CTYPE>;
          executorch::vec::map<CTYPE>(
              [alpha_val, b_casted](Vec x) {
                return x + Vec(alpha_val * b_casted);
              },
              out.mutable_data_ptr<CTYPE>(),
              a.const_data_ptr<CTYPE>(),
              out.numel());
        });
      });
      return out;
    }
  } else if (a.numel() == 1) {
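    // Addition commutes, so swap the arguments and retry with the
    // single-element tensor on the right.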
    return opt_add_out(ctx, b, a, alpha, out);
  }

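  // Decide whether the shapes permit treating the tensors as flat 1-D
  // arrays or as a 2-D-by-1-D broadcast; otherwise take the generic path.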
  auto selected_optimized_path = select_optimized_path(a, b, out);
  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
    // Resize for dynamic shape
    auto error = resize_tensor(out, a.sizes());
    ET_KERNEL_CHECK_MSG(
        ctx,
        error == Error::Ok,
        InvalidArgument,
        out,
        "Failed to resize output tensor.");

    ET_SWITCH_REALB_TYPES(a_type, ctx, "add.out", CTYPE, [&]() {
      CTYPE alpha_val;
      ET_KERNEL_CHECK(
          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

      using Vec = executorch::vec::Vectorized<CTYPE>;
      executorch::vec::map2<CTYPE>(
          [alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
          out.mutable_data_ptr<CTYPE>(),
          a.const_data_ptr<CTYPE>(),
          b.const_data_ptr<CTYPE>(),
          out.numel());
    });
  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
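    // 2-D by 1-D broadcast path: orient the operands so `lhs` is the 2-D
    // tensor and `rhs` the 1-D tensor being broadcast across its rows.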
    const Tensor* lhs;
    const Tensor* rhs;
    if (selected_optimized_path ==
        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
      lhs = &b;
      rhs = &a;
    } else {
      // Catch any failure to update this logic when a new broadcasting
      // possibility is added.
      ET_DCHECK(
          selected_optimized_path ==
          ElementwiseOptimizedPath::kBroadcast2dBy1d);
      lhs = &a;
      rhs = &b;
    }
    auto error = resize_tensor(out, lhs->sizes());
    ET_KERNEL_CHECK_MSG(
        ctx,
        error == Error::Ok,
        InvalidArgument,
        out,
        "Failed to resize output tensor.");
    ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
      CTYPE alpha_val;
      ET_KERNEL_CHECK(
          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

      using Vec = executorch::vec::Vectorized<CTYPE>;
      executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
          [alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
          out.mutable_data_ptr<CTYPE>(),
          lhs->const_data_ptr<CTYPE>(),
          rhs->const_data_ptr<CTYPE>(),
          lhs->sizes()[lhs->dim() - 2],
          lhs->sizes()[lhs->dim() - 1]);
    });
  } else {
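    // Generic path: promote the input dtypes to a common computation type,
    // verify it casts to the output dtype, and dispatch to AddInner for an
    // elementwise loop with full broadcasting support.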
    ScalarType common_type =
        promoteTypes(a_type, b_type, /*half_to_float*/ true);
    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);

    ET_KERNEL_CHECK(
        ctx,
        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
        InvalidArgument,
        out);

    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
      ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
        using CTYPE_IN = typename torch::executor::
            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
        ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
          CTYPE_IN alpha_val;
          ET_KERNEL_CHECK(
              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

          AddInner<
              can_cast<CTYPE_IN, CTYPE_OUT>::value,
              CTYPE_A,
              CTYPE_B,
              CTYPE_IN,
              CTYPE_OUT>::run(a, b, alpha_val, out);
        });
      });
    });
  }

  return out;
}

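// Computes out = a + alpha * b where b is a Scalar rather than a Tensor,
// using a vectorized path when all dtypes agree and a plain scalar loop
// otherwise.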
Tensor& opt_add_scalar_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    const Scalar& alpha,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = utils::get_scalar_dtype(b);
  ScalarType common_type =
      utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
  ScalarType out_type = out.scalar_type();

  ET_CHECK(common_type == out_type);

  if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
    common_type = ScalarType::Float;
  }

  // Resize for dynamic shape
  auto error = resize_tensor(out, a.sizes());
  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");

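  // Fast path: the input, computation, and output dtypes all match (and
  // are neither Half nor BFloat16), so fold alpha * b into one constant
  // and add it to each vector lane of a.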
  if (a_type == common_type && a_type == out_type &&
      a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
    ET_SWITCH_REALB_TYPES(a_type, ctx, "add.Scalar_out", CTYPE, [&]() {
      ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
        CTYPE_B b_val;
        ET_EXTRACT_SCALAR(b, b_val);
        CTYPE b_casted = static_cast<CTYPE>(b_val);
        CTYPE alpha_val;
        ET_EXTRACT_SCALAR(alpha, alpha_val);

        using Vec = executorch::vec::Vectorized<CTYPE>;
        executorch::vec::map<CTYPE>(
            [alpha_val, b_casted](Vec x) {
              return x + Vec(alpha_val * b_casted);
            },
            out.mutable_data_ptr<CTYPE>(),
            a.const_data_ptr<CTYPE>(),
            out.numel());
      });
    });
  } else {
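    // Fallback: dtypes differ, so cast each element of a to the common
    // computation type, add alpha * b, and cast the result to the output
    // dtype in a scalar loop.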
    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() {
      ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
        ET_SWITCH_REALB_TYPES(
            common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() {
              ET_SWITCH_REALHBBF16_TYPES(
                  out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() {
                    CTYPE_B b_val;
                    ET_EXTRACT_SCALAR(b, b_val);
                    CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
                    CTYPE_IN alpha_val;
                    ET_EXTRACT_SCALAR(alpha, alpha_val);

                    const size_t n = a.numel();
                    const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
                    CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
                    for (size_t i = 0; i < n; ++i) {
                      out_data[i] = static_cast<CTYPE_OUT>(
                          static_cast<CTYPE_IN>(a_data[i]) +
                          alpha_val * b_casted);
                    }
                  });
            });
      });
    });
  }

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch