/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/optimized/cpu/binary_ops.h>
#include <executorch/kernels/optimized/vec/functional.h>
#include <executorch/kernels/optimized/vec/vec.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h> // IWYU pragma: export
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;

namespace {

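// MulInner<can_cast, ...>::run computes out = a * b elementwise: inputs are
// loaded as CTYPE_A/CTYPE_B, multiplied in CTYPE_IN, and stored as CTYPE_OUT.
// The leading bool parameter selects this real implementation only when the
// computation type can be cast to the output type; the `false` specialization
// below routes to ReportCanCastBug instead.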
template <
    bool can_cast,
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct MulInner;

template <
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct MulInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
  static void run(const Tensor& a, const Tensor& b, Tensor& out) {
    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
        [](const CTYPE_A val_a, const CTYPE_B val_b) {
          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
          CTYPE_IN value = a_casted * b_casted;

          return static_cast<CTYPE_OUT>(value);
        },
        a,
        b,
        out);
  }
};

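// Stub used by the can_cast == false specialization of MulInner; it can only
// be reached if the caller's canCast() check was skipped or wrong, so it
// fires a debug check.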
struct ReportCanCastBug {
  static void run(const Tensor&, const Tensor&, Tensor&) {
    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
  }
};

template <
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
    : public ReportCanCastBug {};

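// Fast path for broadcasting over the last dimension (one input's last dim
// has size 1). The inputs are reordered so that `lhs` points at the tensor
// whose shape determines the output, and the vectorized
// broadcasting_map_broadcast_last_dim helper performs the multiply.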
Tensor& handle_last_dim_broadcast(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out,
    const ElementwiseOptimizedPath selected_optimized_path) {
  ScalarType out_type = out.scalar_type();
  const Tensor* lhs;
  const Tensor* rhs;
  if (selected_optimized_path ==
      ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) {
    lhs = &b;
    rhs = &a;
  } else {
    lhs = &a;
    rhs = &b;
  }
  auto error = resize_tensor(out, lhs->sizes());
  ET_KERNEL_CHECK_MSG(
      ctx,
      error == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");
  const size_t outer_size = getLeadingDims(out, out.dim() - 1);
  const auto broadcast_size = out.size(out.dim() - 1);
  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
    using Vec = executorch::vec::Vectorized<CTYPE>;
    executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
        [](Vec x, Vec y) { return x * y; },
        out.mutable_data_ptr<CTYPE>(),
        lhs->const_data_ptr<CTYPE>(),
        rhs->const_data_ptr<CTYPE>(),
        outer_size,
        broadcast_size);
  });
  return out;
}

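// Dispatcher for the broadcast fast paths. Last-dim broadcasts are forwarded
// to handle_last_dim_broadcast(); the 2d-by-1d and Nd-by-Nd cases are reduced
// to an (outer_size, broadcast_size, inner_size) view and handled by the
// vectorized broadcasting_map_3d_and_unsqueezed_3d helper. As above, `lhs`
// points at the tensor whose shape determines the output.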
Tensor& handle_broadcast_mul(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out,
    const ElementwiseOptimizedPath selected_optimized_path) {
  if ((selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcastLastDim) ||
      (selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
    return handle_last_dim_broadcast(ctx, a, b, out, selected_optimized_path);
  }

  ScalarType out_type = out.scalar_type();
  const Tensor* lhs;
  const Tensor* rhs;
  if ((selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
      (selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
    lhs = &b;
    rhs = &a;
  } else {
    // Catch failure to update logic when adding new broadcasting possibility.
    ET_DCHECK(
        (selected_optimized_path ==
         ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
        (selected_optimized_path ==
         ElementwiseOptimizedPath::kBroadcastNdByNd));
    lhs = &a;
    rhs = &b;
  }
  auto error = resize_tensor(out, lhs->sizes());
  ET_KERNEL_CHECK_MSG(
      ctx,
      error == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");
  int64_t outer_size = 1;
  int64_t broadcast_size;
  int64_t inner_size;
  if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
      (selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
    int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
    int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
    auto normalized_tensor_size_lhs =
        get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
    outer_size = normalized_tensor_size_lhs[0];
    broadcast_size = normalized_tensor_size_lhs[1];
    inner_size = normalized_tensor_size_lhs[2];
  } else {
    broadcast_size = lhs->sizes()[lhs->dim() - 2];
    inner_size = lhs->sizes()[lhs->dim() - 1];
  }
  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
    using Vec = executorch::vec::Vectorized<CTYPE>;
    executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
        [](Vec x, Vec y) { return x * y; },
        out.mutable_data_ptr<CTYPE>(),
        lhs->const_data_ptr<CTYPE>(),
        rhs->const_data_ptr<CTYPE>(),
        outer_size,
        broadcast_size,
        inner_size);
  });
  return out;
}
} // namespace

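// mul.out entry point: out = a * b elementwise. If either input has a single
// element and the dtypes match (and are not Half/BFloat16), the scalar is
// broadcast with a vectorized map. Otherwise select_optimized_path() picks a
// contiguous 1-d vectorized path, one of the broadcast fast paths above, or
// the portable elementwise fallback with type promotion.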
Tensor& opt_mul_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = b.scalar_type();
  ScalarType out_type = out.scalar_type();

  if (b.numel() == 1) {
    if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
        a_type != ScalarType::BFloat16) {
      ET_KERNEL_CHECK(
          ctx,
          resize_to_broadcast_target_size(a, b, out) == Error::Ok,
          InvalidArgument,
          out);

      ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.out", CTYPE, [&]() {
        ET_SWITCH_REALB_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
          CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
          CTYPE b_casted = static_cast<CTYPE>(b_val);

          using Vec = executorch::vec::Vectorized<CTYPE>;
          executorch::vec::map<CTYPE>(
              [b_casted](Vec x) { return x * Vec(b_casted); },
              out.mutable_data_ptr<CTYPE>(),
              a.const_data_ptr<CTYPE>(),
              out.numel());
        });
      });
      return out;
    }
  } else if (a.numel() == 1) {
    return opt_mul_out(ctx, b, a, out);
  }

  auto selected_optimized_path = select_optimized_path(a, b, out);
  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
    // Resize for dynamic shape
    auto error = resize_tensor(out, a.sizes());
    ET_KERNEL_CHECK_MSG(
        ctx,
        error == Error::Ok,
        InvalidArgument,
        out,
        "Failed to resize output tensor.");

    ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
      using Vec = executorch::vec::Vectorized<CTYPE>;
      executorch::vec::map2<CTYPE>(
          [](Vec x, Vec y) { return x * y; },
          out.mutable_data_ptr<CTYPE>(),
          a.const_data_ptr<CTYPE>(),
          b.const_data_ptr<CTYPE>(),
          out.numel());
    });
  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
    return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path);
  } else {
    ScalarType common_type =
        promoteTypes(a_type, b_type, /*half_to_float*/ true);
    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);

    ET_KERNEL_CHECK(
        ctx,
        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
        InvalidArgument,
        out);

    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
      ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
        using CTYPE_IN = typename torch::executor::
            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
        ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
          apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
              [](const CTYPE_A val_a, const CTYPE_B val_b) {
                CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
                CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
                CTYPE_IN value = a_casted * b_casted;

                return static_cast<CTYPE_OUT>(value);
              },
              a,
              b,
              out);
        });
      });
    });
  }

  return out;
}

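// mul.Scalar_out entry point: out = a * b where b is a Scalar. When no type
// promotion is needed (and the dtype is not Half/BFloat16), the scalar is
// broadcast with a vectorized map; otherwise each element is multiplied in
// the promoted computation type CTYPE_IN and cast to the output type.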
Tensor& opt_mul_scalar_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = utils::get_scalar_dtype(b);
  ScalarType common_type =
      utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
  ScalarType out_type = out.scalar_type();

  ET_CHECK(common_type == out_type);

  if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
    common_type = ScalarType::Float;
  }

  // Resize for dynamic shape
  auto error = resize_tensor(out, a.sizes());
  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");

  if (a_type == common_type && a_type == out_type &&
      a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
    ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE, [&]() {
      ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() {
        CTYPE_B b_val;
        ET_EXTRACT_SCALAR(b, b_val);
        CTYPE b_casted = static_cast<CTYPE>(b_val);

        using Vec = executorch::vec::Vectorized<CTYPE>;
        executorch::vec::map<CTYPE>(
            [b_casted](Vec x) { return x * Vec(b_casted); },
            out.mutable_data_ptr<CTYPE>(),
            a.const_data_ptr<CTYPE>(),
            out.numel());
      });
    });
  } else {
    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() {
      ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() {
        ET_SWITCH_REALB_TYPES(
            common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() {
              ET_SWITCH_REALHBBF16_TYPES(
                  out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() {
                    CTYPE_B b_val;
                    ET_EXTRACT_SCALAR(b, b_val);
                    CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);

                    const size_t n = a.numel();
                    const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
                    CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
                    // Use size_t for the index to match n and avoid a
                    // signed/unsigned comparison.
                    for (size_t i = 0; i < n; ++i) {
                      out_data[i] = static_cast<CTYPE_OUT>(
                          static_cast<CTYPE_IN>(a_data[i]) * b_casted);
                    }
                  });
            });
      });
    });
  }

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch