/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/optimized/cpu/binary_ops.h>
#include <executorch/kernels/optimized/vec/functional.h>
#include <executorch/kernels/optimized/vec/vec.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

namespace torch {
namespace executor {
namespace native {
namespace {

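// Helper dispatched on whether CTYPE_IN can be cast to CTYPE_OUT. The `true`
// specialization performs the actual element-wise a - alpha * b; the `false`
// specialization reports a bug, since canCast() is verified by the caller
// before dispatching here.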
template <
    bool can_cast,
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct SubInner;

template <
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct SubInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
  static void
  run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
          CTYPE_IN value = a_casted - alpha_val * b_casted;

          return static_cast<CTYPE_OUT>(value);
        },
        a,
        b,
        out);
  }
};

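// Fallback target for the `false` specialization below. Reaching this at
// runtime indicates a bug, hence the unconditional debug check.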
template <typename CTYPE_IN>
struct ReportCanCastBug {
  static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
  }
};

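// When the cast is not legal, route to ReportCanCastBug instead of
// instantiating the subtraction lambda for an invalid type pair.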
template <
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct SubInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
    : public ReportCanCastBug<CTYPE_IN> {};

} // namespace

using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;

Tensor& opt_sub_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = b.scalar_type();
  ScalarType out_type = out.scalar_type();

  ET_KERNEL_CHECK(ctx, tensor_is_realh_type(out), InvalidArgument, out);
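  // Fast path: one operand is a single-element tensor and all dtypes match
  // (Half excluded). Treat the 1-element tensor as a scalar and vectorize
  // over the other operand.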
  if (a.numel() == 1 || b.numel() == 1) {
    if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) {
      const Tensor* tensor;
      const Tensor* scalar;
      ScalarType tensor_type;
      ScalarType scalar_type;
      if (a.numel() == 1) {
        tensor = &b;
        tensor_type = b_type;
        scalar = &a;
        scalar_type = a_type;
      } else {
        tensor = &a;
        tensor_type = a_type;
        scalar = &b;
        scalar_type = b_type;
      }
      ET_KERNEL_CHECK(
          ctx,
          resize_to_broadcast_target_size(a, b, out) == Error::Ok,
          InvalidArgument,
          out);
      ET_SWITCH_REAL_TYPES(tensor_type, ctx, "sub.out", CTYPE, [&]() {
        ET_SWITCH_REAL_TYPES(scalar_type, ctx, "sub.out", CTYPE_SCALAR, [&]() {
          CTYPE alpha_val;
          ET_KERNEL_CHECK(
              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
          CTYPE_SCALAR scalar_val = *scalar->const_data_ptr<CTYPE_SCALAR>();
          CTYPE scalar_casted = static_cast<CTYPE>(scalar_val);

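          // Subtraction is not commutative: when `a` is the 1-element
          // operand, compute scalar - alpha * x; otherwise fold
          // alpha * scalar into one constant and compute x - that constant.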
          using Vec = executorch::vec::Vectorized<CTYPE>;
          if (a.numel() == 1) {
            executorch::vec::map<CTYPE>(
                [alpha_val, scalar_casted](Vec x) {
                  return Vec(scalar_casted) - Vec(alpha_val) * x;
                },
                out.mutable_data_ptr<CTYPE>(),
                tensor->const_data_ptr<CTYPE>(),
                out.numel());
          } else {
            executorch::vec::map<CTYPE>(
                [alpha_val, scalar_casted](Vec x) {
                  return x - Vec(alpha_val * scalar_casted);
                },
                out.mutable_data_ptr<CTYPE>(),
                tensor->const_data_ptr<CTYPE>(),
                out.numel());
          }
        });
      });
      return out;
    }
  }

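  // No scalar fast path applied; check whether a vectorized element-wise
  // path (contiguous 1-D treatment or 2-D-by-1-D broadcast) is available.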
  auto selected_optimized_path = select_optimized_path(a, b, out);
  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
    // Resize for dynamic shape
    auto error = resize_tensor(out, a.sizes());
    ET_KERNEL_CHECK_MSG(
        ctx,
        error == Error::Ok,
        InvalidArgument,
        out,
        "Failed to resize output tensor.");

    ET_SWITCH_REAL_TYPES(a_type, ctx, "sub.out", CTYPE, [&]() {
      CTYPE alpha_val;
      ET_KERNEL_CHECK(
          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

      using Vec = executorch::vec::Vectorized<CTYPE>;
      executorch::vec::map2<CTYPE>(
          [alpha_val](Vec x, Vec y) { return x - Vec(alpha_val) * y; },
          out.mutable_data_ptr<CTYPE>(),
          a.const_data_ptr<CTYPE>(),
          b.const_data_ptr<CTYPE>(),
          out.numel());
    });
  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
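    // 2-D by 1-D broadcast path. kBroadcast2dBy1dReverseArguments means `b`
    // is the 2-D operand, so swap lhs/rhs and flip the lambda below so the
    // result is still a - alpha * b.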
    const Tensor* lhs;
    const Tensor* rhs;
    if (selected_optimized_path ==
        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
      lhs = &b;
      rhs = &a;
    } else {
      // Catch failure to update this logic when adding a new broadcasting
      // possibility.
      ET_DCHECK(
          selected_optimized_path ==
          ElementwiseOptimizedPath::kBroadcast2dBy1d);
      lhs = &a;
      rhs = &b;
    }
    auto error = resize_tensor(out, lhs->sizes());
    ET_KERNEL_CHECK_MSG(
        ctx,
        error == Error::Ok,
        InvalidArgument,
        out,
        "Failed to resize output tensor.");
    ET_SWITCH_REAL_TYPES(out_type, ctx, "sub.out", CTYPE, [&]() {
      CTYPE alpha_val;
      ET_KERNEL_CHECK(
          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

      using Vec = executorch::vec::Vectorized<CTYPE>;
      if (selected_optimized_path ==
          ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
            [alpha_val](Vec x, Vec y) { return y - Vec(alpha_val) * x; },
            out.mutable_data_ptr<CTYPE>(),
            lhs->const_data_ptr<CTYPE>(),
            rhs->const_data_ptr<CTYPE>(),
            lhs->sizes()[lhs->dim() - 2],
            lhs->sizes()[lhs->dim() - 1]);
      } else {
        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
            [alpha_val](Vec x, Vec y) { return x - Vec(alpha_val) * y; },
            out.mutable_data_ptr<CTYPE>(),
            lhs->const_data_ptr<CTYPE>(),
            rhs->const_data_ptr<CTYPE>(),
            lhs->sizes()[lhs->dim() - 2],
            lhs->sizes()[lhs->dim() - 1]);
      }
    });
  } else {
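    // General fallback: promote the input dtypes to a common type, verify the
    // result can be cast to the output dtype, then apply the element-wise
    // functor over the broadcasted tensors via SubInner.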
    ScalarType common_type =
        promoteTypes(a_type, b_type, /*half_to_float*/ true);
    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);

    ET_KERNEL_CHECK(
        ctx,
        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
        InvalidArgument,
        out);

    ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.out", CTYPE_A, [&]() {
      ET_SWITCH_REALH_TYPES(b_type, ctx, "sub.out", CTYPE_B, [&]() {
        using CTYPE_IN = typename torch::executor::
            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
        ET_SWITCH_REALH_TYPES(out_type, ctx, "sub.out", CTYPE_OUT, [&]() {
          CTYPE_IN alpha_val;
          ET_KERNEL_CHECK(
              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

          SubInner<
              can_cast<CTYPE_IN, CTYPE_OUT>::value,
              CTYPE_A,
              CTYPE_B,
              CTYPE_IN,
              CTYPE_OUT>::run(a, b, alpha_val, out);
        });
      });
    });
  }

  return out;
}

Tensor& opt_sub_scalar_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    const Scalar& alpha,
    Tensor& out) {
  (void)ctx;

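  // Promote the tensor dtype with the scalar's dtype; the output must already
  // have the promoted type (checked below). Half inputs are computed in Float.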
  ScalarType a_type = a.scalar_type();
  ScalarType b_type = utils::get_scalar_dtype(b);
  ScalarType common_type =
      utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
  ScalarType out_type = out.scalar_type();

  ET_CHECK(common_type == out_type);

  if (common_type == ScalarType::Half) {
    common_type = ScalarType::Float;
  }

  // Resize for dynamic shape
  auto error = resize_tensor(out, a.sizes());
  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");

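  // Vectorized path when no casting is needed (Half excluded); otherwise fall
  // back to a scalar loop with explicit casts through the common type.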
  if (a_type == common_type && a_type == out_type &&
      a_type != ScalarType::Half) {
    ET_SWITCH_REAL_TYPES(a_type, ctx, "sub.Scalar_out", CTYPE, [&]() {
      ET_SWITCH_SCALAR_OBJ_REAL_TYPES(
          b_type, ctx, "sub.Scalar_out", CTYPE_B, [&]() {
            CTYPE_B b_val;
            ET_EXTRACT_SCALAR(b, b_val);
            CTYPE b_casted = static_cast<CTYPE>(b_val);
            CTYPE alpha_val;
            ET_EXTRACT_SCALAR(alpha, alpha_val);

            using Vec = executorch::vec::Vectorized<CTYPE>;
            executorch::vec::map<CTYPE>(
                [alpha_val, b_casted](Vec x) {
                  return x - Vec(alpha_val * b_casted);
                },
                out.mutable_data_ptr<CTYPE>(),
                a.const_data_ptr<CTYPE>(),
                out.numel());
          });
    });
  } else {
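    // Mixed-dtype fallback: cast each element through the promoted common
    // type before writing to the output dtype.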
    ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.Scalar_out", CTYPE_A, [&]() {
      ET_SWITCH_SCALAR_OBJ_REAL_TYPES(
          b_type, ctx, "sub.Scalar_out", CTYPE_B, [&]() {
            ET_SWITCH_REAL_TYPES(
                common_type, ctx, "sub.Scalar_out", CTYPE_IN, [&]() {
                  ET_SWITCH_REALH_TYPES(
                      out_type, ctx, "sub.Scalar_out", CTYPE_OUT, [&]() {
                        CTYPE_B b_val;
                        ET_EXTRACT_SCALAR(b, b_val);
                        CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
                        CTYPE_IN alpha_val;
                        ET_EXTRACT_SCALAR(alpha, alpha_val);

                        const size_t n = a.numel();
                        const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
                        CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
                        for (size_t i = 0; i < n; ++i) {
                          out_data[i] = static_cast<CTYPE_OUT>(
                              static_cast<CTYPE_IN>(a_data[i]) -
                              alpha_val * b_casted);
                        }
                      });
                });
          });
    });
  }

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch