/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/optimized/cpu/binary_ops.h>
#include <executorch/kernels/optimized/vec/functional.h>
#include <executorch/kernels/optimized/vec/vec.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

namespace torch {
namespace executor {
namespace native {
namespace {

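// SubInner dispatches on whether CTYPE_IN can be cast to CTYPE_OUT. The `true`
// specialization performs the element-wise `a - alpha * b` with intermediate
// casts to CTYPE_IN; the `false` specialization falls through to
// ReportCanCastBug, which asserts, because canCast() is checked by the caller
// before this point is reached.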
template <
    bool can_cast,
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct SubInner;

template <
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct SubInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
  static void
  run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
          CTYPE_IN value = a_casted - alpha_val * b_casted;

          return static_cast<CTYPE_OUT>(value);
        },
        a,
        b,
        out);
  }
};

template <typename CTYPE_IN>
struct ReportCanCastBug {
  static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
  }
};

template <
    typename CTYPE_A,
    typename CTYPE_B,
    typename CTYPE_IN,
    typename CTYPE_OUT>
struct SubInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
    : public ReportCanCastBug<CTYPE_IN> {};

} // namespace

using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;

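// Computes out = a - alpha * b, element-wise. Several code paths are tried in
// order: a vectorized tensor-scalar path when one input has a single element,
// the vectorized paths selected by select_optimized_path(), and finally a
// portable element-wise fallback that handles mixed dtypes.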
Tensor& opt_sub_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = b.scalar_type();
  ScalarType out_type = out.scalar_type();

  ET_KERNEL_CHECK(ctx, tensor_is_realh_type(out), InvalidArgument, out);
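  // Fast path: one operand is a single element and all dtypes match (and are
  // not Half). Treat the single-element tensor as a scalar and run a
  // vectorized map over the other tensor. Note the asymmetry of subtraction:
  // if `a` is the scalar, compute scalar - alpha * x; otherwise x - alpha * scalar.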
  if (a.numel() == 1 || b.numel() == 1) {
    if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) {
      const Tensor* tensor;
      const Tensor* scalar;
      ScalarType tensor_type;
      ScalarType scalar_type;
      if (a.numel() == 1) {
        tensor = &b;
        tensor_type = b_type;
        scalar = &a;
        scalar_type = a_type;
      } else {
        tensor = &a;
        tensor_type = a_type;
        scalar = &b;
        scalar_type = b_type;
      }
      ET_KERNEL_CHECK(
          ctx,
          resize_to_broadcast_target_size(a, b, out) == Error::Ok,
          InvalidArgument,
          out);
      ET_SWITCH_REAL_TYPES(tensor_type, ctx, "sub.out", CTYPE, [&]() {
        ET_SWITCH_REAL_TYPES(scalar_type, ctx, "sub.out", CTYPE_SCALAR, [&]() {
          CTYPE alpha_val;
          ET_KERNEL_CHECK(
              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
          CTYPE_SCALAR scalar_val = *scalar->const_data_ptr<CTYPE_SCALAR>();
          CTYPE scalar_casted = static_cast<CTYPE>(scalar_val);

          using Vec = executorch::vec::Vectorized<CTYPE>;
          if (a.numel() == 1) {
            executorch::vec::map<CTYPE>(
                [alpha_val, scalar_casted](Vec x) {
                  return Vec(scalar_casted) - Vec(alpha_val) * x;
                },
                out.mutable_data_ptr<CTYPE>(),
                tensor->const_data_ptr<CTYPE>(),
                out.numel());
          } else {
            executorch::vec::map<CTYPE>(
                [alpha_val, scalar_casted](Vec x) {
                  return x - Vec(alpha_val * scalar_casted);
                },
                out.mutable_data_ptr<CTYPE>(),
                tensor->const_data_ptr<CTYPE>(),
                out.numel());
          }
        });
      });
      return out;
    }
  }

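  // select_optimized_path() picks a vectorized strategy: kTreatAs1d maps the
  // operation over the inputs as flat buffers, the kBroadcast2dBy1d variants
  // handle the 2-D-by-1-D broadcast shapes (with the "ReverseArguments"
  // flavor swapping which operand is broadcast), and kNone falls back to the
  // portable element-wise implementation below.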
  auto selected_optimized_path = select_optimized_path(a, b, out);
  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
    // Resize for dynamic shape
    auto error = resize_tensor(out, a.sizes());
    ET_KERNEL_CHECK_MSG(
        ctx,
        error == Error::Ok,
        InvalidArgument,
        out,
        "Failed to resize output tensor.");

    ET_SWITCH_REAL_TYPES(a_type, ctx, "sub.out", CTYPE, [&]() {
      CTYPE alpha_val;
      ET_KERNEL_CHECK(
          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

      using Vec = executorch::vec::Vectorized<CTYPE>;
      executorch::vec::map2<CTYPE>(
          [alpha_val](Vec x, Vec y) { return x - Vec(alpha_val) * y; },
          out.mutable_data_ptr<CTYPE>(),
          a.const_data_ptr<CTYPE>(),
          b.const_data_ptr<CTYPE>(),
          out.numel());
    });
  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
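    // Broadcast path: `lhs` points at the 2-D operand and `rhs` at the 1-D
    // operand being broadcast. For kBroadcast2dBy1dReverseArguments the roles
    // of `a` and `b` are swapped, so the lambda below also swaps its operands
    // to preserve the `a - alpha * b` ordering.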
    const Tensor* lhs;
    const Tensor* rhs;
    if (selected_optimized_path ==
        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
      lhs = &b;
      rhs = &a;
    } else {
      // Catch failure to update this logic when adding a new broadcasting
      // possibility.
      ET_DCHECK(
          selected_optimized_path ==
          ElementwiseOptimizedPath::kBroadcast2dBy1d);
      lhs = &a;
      rhs = &b;
    }
    auto error = resize_tensor(out, lhs->sizes());
    ET_KERNEL_CHECK_MSG(
        ctx,
        error == Error::Ok,
        InvalidArgument,
        out,
        "Failed to resize output tensor.");
    ET_SWITCH_REAL_TYPES(out_type, ctx, "sub.out", CTYPE, [&]() {
      CTYPE alpha_val;
      ET_KERNEL_CHECK(
          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

      using Vec = executorch::vec::Vectorized<CTYPE>;
      if (selected_optimized_path ==
          ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
            [alpha_val](Vec x, Vec y) { return y - Vec(alpha_val) * x; },
            out.mutable_data_ptr<CTYPE>(),
            lhs->const_data_ptr<CTYPE>(),
            rhs->const_data_ptr<CTYPE>(),
            lhs->sizes()[lhs->dim() - 2],
            lhs->sizes()[lhs->dim() - 1]);
      } else {
        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
            [alpha_val](Vec x, Vec y) { return x - Vec(alpha_val) * y; },
            out.mutable_data_ptr<CTYPE>(),
            lhs->const_data_ptr<CTYPE>(),
            rhs->const_data_ptr<CTYPE>(),
            lhs->sizes()[lhs->dim() - 2],
            lhs->sizes()[lhs->dim() - 1]);
      }
    });
  } else {
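    // Portable fallback for mixed dtypes: promote a_type and b_type to a
    // common compute type, verify it can be cast to out_type, resize `out` to
    // the broadcast shape, and dispatch SubInner over the (A, B, OUT) dtype
    // combination.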
    ScalarType common_type =
        promoteTypes(a_type, b_type, /*half_to_float*/ true);
    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);

    ET_KERNEL_CHECK(
        ctx,
        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
        InvalidArgument,
        out);

    ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.out", CTYPE_A, [&]() {
      ET_SWITCH_REALH_TYPES(b_type, ctx, "sub.out", CTYPE_B, [&]() {
        using CTYPE_IN = typename torch::executor::
            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
        ET_SWITCH_REALH_TYPES(out_type, ctx, "sub.out", CTYPE_OUT, [&]() {
          CTYPE_IN alpha_val;
          ET_KERNEL_CHECK(
              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

          SubInner<
              can_cast<CTYPE_IN, CTYPE_OUT>::value,
              CTYPE_A,
              CTYPE_B,
              CTYPE_IN,
              CTYPE_OUT>::run(a, b, alpha_val, out);
        });
      });
    });
  }

  return out;
}

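// Computes out = a - alpha * b where `b` is a Scalar. A vectorized map handles
// the common case where all dtypes already match (and are not Half);
// otherwise a per-element loop casts through the promoted compute type.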
Tensor& opt_sub_scalar_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    const Scalar& alpha,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = utils::get_scalar_dtype(b);
  ScalarType common_type =
      utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
  ScalarType out_type = out.scalar_type();

  ET_CHECK(common_type == out_type);

  if (common_type == ScalarType::Half) {
    common_type = ScalarType::Float;
  }

  // Resize for dynamic shape
  auto error = resize_tensor(out, a.sizes());
  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");

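  // Fast path: `a`, the compute type, and `out` share a (non-Half) dtype, so
  // fold `alpha * b` into a single constant and subtract it with a vectorized
  // map. Otherwise fall back to a per-element loop that casts a -> CTYPE_IN,
  // computes in CTYPE_IN, and casts the result to CTYPE_OUT.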
  if (a_type == common_type && a_type == out_type &&
      a_type != ScalarType::Half) {
    ET_SWITCH_REAL_TYPES(a_type, ctx, "sub.Scalar_out", CTYPE, [&]() {
      ET_SWITCH_SCALAR_OBJ_REAL_TYPES(
          b_type, ctx, "sub.Scalar_out", CTYPE_B, [&]() {
            CTYPE_B b_val;
            ET_EXTRACT_SCALAR(b, b_val);
            CTYPE b_casted = static_cast<CTYPE>(b_val);
            CTYPE alpha_val;
            ET_EXTRACT_SCALAR(alpha, alpha_val);

            using Vec = executorch::vec::Vectorized<CTYPE>;
            executorch::vec::map<CTYPE>(
                [alpha_val, b_casted](Vec x) {
                  return x - Vec(alpha_val * b_casted);
                },
                out.mutable_data_ptr<CTYPE>(),
                a.const_data_ptr<CTYPE>(),
                out.numel());
          });
    });
  } else {
    ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.Scalar_out", CTYPE_A, [&]() {
      ET_SWITCH_SCALAR_OBJ_REAL_TYPES(
          b_type, ctx, "sub.Scalar_out", CTYPE_B, [&]() {
            ET_SWITCH_REAL_TYPES(
                common_type, ctx, "sub.Scalar_out", CTYPE_IN, [&]() {
                  ET_SWITCH_REALH_TYPES(
                      out_type, ctx, "sub.Scalar_out", CTYPE_OUT, [&]() {
                        CTYPE_B b_val;
                        ET_EXTRACT_SCALAR(b, b_val);
                        CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
                        CTYPE_IN alpha_val;
                        ET_EXTRACT_SCALAR(alpha, alpha_val);

                        const size_t n = a.numel();
                        const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
                        CTYPE_OUT* out_data =
                            out.mutable_data_ptr<CTYPE_OUT>();
                        for (size_t i = 0; i < n; ++i) {
                          out_data[i] = static_cast<CTYPE_OUT>(
                              static_cast<CTYPE_IN>(a_data[i]) -
                              alpha_val * b_casted);
                        }
                      });
                });
          });
    });
  }

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch