// xref: /aosp_15_r20/external/executorch/kernels/optimized/cpu/op_mul.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/optimized/cpu/binary_ops.h>
#include <executorch/kernels/optimized/vec/functional.h>
#include <executorch/kernels/optimized/vec/vec.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h> // IWYU pragma: export
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;

namespace {

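// MulInner selects, at compile time, between the real element-wise multiply
// (the `true` specialization below, which promotes both inputs to CTYPE_IN,
// multiplies, and casts to CTYPE_OUT) and a bug trap for type combinations
// that cannot be cast to the output (the `false` specialization further
// down, which inherits from ReportCanCastBug).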
template <
    bool can_cast,
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct MulInner;

template <
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct MulInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
  static void run(const Tensor& a, const Tensor& b, Tensor& out) {
    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
        [](const CTYPE_A val_a, const CTYPE_B val_b) {
          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
          CTYPE_IN value = a_casted * b_casted;

          return static_cast<CTYPE_OUT>(value);
        },
        a,
        b,
        out);
  }
};

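// Trap for the `can_cast == false` specialization: reaching it means the
// caller dispatched without first validating canCast(), which is treated as
// an internal bug rather than a user error.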
struct ReportCanCastBug {
  static void run(const Tensor&, const Tensor&, Tensor&) {
    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
  }
};

template <
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
    : public ReportCanCastBug {};

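// Vectorized multiply for the case where one operand is broadcast only along
// the last dimension (e.g., a [M, N] tensor multiplied by an [M, 1] tensor).
// Which operand is the full-size one depends on whether the "reverse
// arguments" variant of the path was selected.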
Tensor& handle_last_dim_broadcast(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out,
    const ElementwiseOptimizedPath selected_optimized_path) {
  ScalarType out_type = out.scalar_type();
  const Tensor* lhs;
  const Tensor* rhs;
  if (selected_optimized_path ==
      ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) {
    lhs = &b;
    rhs = &a;
  } else {
    lhs = &a;
    rhs = &b;
  }
  auto error = resize_tensor(out, lhs->sizes());
  ET_KERNEL_CHECK_MSG(
      ctx,
      error == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");
  const size_t outer_size = getLeadingDims(out, out.dim() - 1);
  const auto broadcast_size = out.size(out.dim() - 1);
  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
    using Vec = executorch::vec::Vectorized<CTYPE>;
    executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
        [](Vec x, Vec y) { return x * y; },
        out.mutable_data_ptr<CTYPE>(),
        lhs->const_data_ptr<CTYPE>(),
        rhs->const_data_ptr<CTYPE>(),
        outer_size,
        broadcast_size);
  });
  return out;
}

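// Vectorized multiply for the remaining broadcast patterns reported by
// select_optimized_path(): 2-D by 1-D and Nd by Nd (plus their
// reverse-argument variants). Last-dimension-only broadcasts are delegated
// to handle_last_dim_broadcast() above.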
Tensor& handle_broadcast_mul(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out,
    const ElementwiseOptimizedPath selected_optimized_path) {
  if ((selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcastLastDim) ||
      (selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
    return handle_last_dim_broadcast(ctx, a, b, out, selected_optimized_path);
  }

  ScalarType out_type = out.scalar_type();
  const Tensor* lhs;
  const Tensor* rhs;
  if ((selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
      (selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
    lhs = &b;
    rhs = &a;
  } else {
    // Catch failure to update logic when adding new broadcasting possibility.
    ET_DCHECK(
        (selected_optimized_path ==
         ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
        (selected_optimized_path ==
         ElementwiseOptimizedPath::kBroadcastNdByNd));
    lhs = &a;
    rhs = &b;
  }
  auto error = resize_tensor(out, lhs->sizes());
  ET_KERNEL_CHECK_MSG(
      ctx,
      error == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");
  int64_t outer_size = 1;
  int64_t broadcast_size;
  int64_t inner_size;
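  // For the Nd-by-Nd paths, collapse the full-size operand's shape into an
  // equivalent (outer, broadcast, inner) 3-D view around the single broadcast
  // dimension. For the 2-D-by-1-D paths, the last two dimensions are used
  // directly and outer_size stays 1.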
  if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
      (selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
    int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
    int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
    auto normalized_tensor_size_lhs =
        get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
    outer_size = normalized_tensor_size_lhs[0];
    broadcast_size = normalized_tensor_size_lhs[1];
    inner_size = normalized_tensor_size_lhs[2];
  } else {
    broadcast_size = lhs->sizes()[lhs->dim() - 2];
    inner_size = lhs->sizes()[lhs->dim() - 1];
  }
  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
    using Vec = executorch::vec::Vectorized<CTYPE>;
    executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
        [](Vec x, Vec y) { return x * y; },
        out.mutable_data_ptr<CTYPE>(),
        lhs->const_data_ptr<CTYPE>(),
        rhs->const_data_ptr<CTYPE>(),
        outer_size,
        broadcast_size,
        inner_size);
  });
  return out;
}
} // namespace

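// Entry point for the mul.out kernel. Fast paths, in order: (1) `b` (or `a`,
// via a swapped recursive call) is a single-element tensor of the same dtype
// as the other operand and the output; (2) matching shapes handled as a flat
// 1-D vectorized multiply; (3) a supported broadcast pattern handled by
// handle_broadcast_mul(); (4) a portable element-wise fallback with dtype
// promotion for everything else.
//
// Minimal call sketch (illustrative only; it assumes the runtime has already
// prepared `ctx`, the input tensors, and an `out` tensor that can be resized
// to the broadcasted shape):
//
//   torch::executor::native::opt_mul_out(ctx, a, b, out);  // out = a * b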
Tensor& opt_mul_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = b.scalar_type();
  ScalarType out_type = out.scalar_type();

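  // Fast path: `b` is a single-element tensor whose dtype matches `a` and
  // `out` (and is not Half/BFloat16), so the multiply reduces to one
  // vectorized map over `a` with `b` splatted into a vector.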
  if (b.numel() == 1) {
    if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
        a_type != ScalarType::BFloat16) {
      ET_KERNEL_CHECK(
          ctx,
          resize_to_broadcast_target_size(a, b, out) == Error::Ok,
          InvalidArgument,
          out);

      ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.out", CTYPE, [&]() {
        ET_SWITCH_REALB_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
          CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
          CTYPE b_casted = static_cast<CTYPE>(b_val);

          using Vec = executorch::vec::Vectorized<CTYPE>;
          executorch::vec::map<CTYPE>(
              [b_casted](Vec x) { return x * Vec(b_casted); },
              out.mutable_data_ptr<CTYPE>(),
              a.const_data_ptr<CTYPE>(),
              out.numel());
        });
      });
      return out;
    }
  } else if (a.numel() == 1) {
    return opt_mul_out(ctx, b, a, out);
  }

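  // Neither operand is handled by the scalar-tensor fast path; ask the shared
  // binary-op helper which vectorized layout, if any, fits this combination
  // of shapes and dtypes.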
  auto selected_optimized_path = select_optimized_path(a, b, out);
  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
    // Resize for dynamic shape
    auto error = resize_tensor(out, a.sizes());
    ET_KERNEL_CHECK_MSG(
        ctx,
        error == Error::Ok,
        InvalidArgument,
        out,
        "Failed to resize output tensor.");

    ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
      using Vec = executorch::vec::Vectorized<CTYPE>;
      executorch::vec::map2<CTYPE>(
          [](Vec x, Vec y) { return x * y; },
          out.mutable_data_ptr<CTYPE>(),
          a.const_data_ptr<CTYPE>(),
          b.const_data_ptr<CTYPE>(),
          out.numel());
    });
  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
    return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path);
  } else {
    ScalarType common_type =
        promoteTypes(a_type, b_type, /*half_to_float*/ true);
    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);

    ET_KERNEL_CHECK(
        ctx,
        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
        InvalidArgument,
        out);

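    // Portable fallback: promote each element pair to the common type,
    // multiply, then cast to the output dtype. This covers combinations
    // (including Half/BFloat16) that the vectorized paths above do not.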
    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
      ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
        using CTYPE_IN = typename torch::executor::
            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
        ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
          apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
              [](const CTYPE_A val_a, const CTYPE_B val_b) {
                CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
                CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
                CTYPE_IN value = a_casted * b_casted;

                return static_cast<CTYPE_OUT>(value);
              },
              a,
              b,
              out);
        });
      });
    });
  }

  return out;
}

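// Entry point for the mul.Scalar_out kernel (tensor multiplied by a Scalar).
// Uses a vectorized map when the tensor, the promoted scalar type, and the
// output all share a suitable dtype, and an element-wise loop with promotion
// to the common type otherwise.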
Tensor& opt_mul_scalar_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = utils::get_scalar_dtype(b);
  ScalarType common_type =
      utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
  ScalarType out_type = out.scalar_type();

  ET_CHECK(common_type == out_type);

  if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
    common_type = ScalarType::Float;
  }

  // Resize for dynamic shape
  auto error = resize_tensor(out, a.sizes());
  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");

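  // Vectorized path: extract the scalar once, cast it to the tensor dtype,
  // and multiply with a splatted vector. Otherwise fall through to the
  // element-wise loop below, which promotes to the common type per element.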
  if (a_type == common_type && a_type == out_type &&
      a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
    ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE, [&]() {
      ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() {
        CTYPE_B b_val;
        ET_EXTRACT_SCALAR(b, b_val);
        CTYPE b_casted = static_cast<CTYPE>(b_val);

        using Vec = executorch::vec::Vectorized<CTYPE>;
        executorch::vec::map<CTYPE>(
            [b_casted](Vec x) { return x * Vec(b_casted); },
            out.mutable_data_ptr<CTYPE>(),
            a.const_data_ptr<CTYPE>(),
            out.numel());
      });
    });
  } else {
    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() {
      ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() {
        ET_SWITCH_REALB_TYPES(
            common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() {
              ET_SWITCH_REALHBBF16_TYPES(
                  out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() {
                    CTYPE_B b_val;
                    ET_EXTRACT_SCALAR(b, b_val);
                    CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);

                    const size_t n = a.numel();
                    const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
                    CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
                    for (size_t i = 0; i < n; ++i) {
                      out_data[i] = static_cast<CTYPE_OUT>(
                          static_cast<CTYPE_IN>(a_data[i]) * b_casted);
                    }
                  });
            });
      });
    });
  }

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch