/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/optimized/cpu/binary_ops.h>
#include <executorch/kernels/optimized/vec/functional.h>
#include <executorch/kernels/optimized/vec/vec.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

namespace torch {
namespace executor {
namespace native {

namespace {

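// Picks the type the division is carried out in. Division is always computed
// in a floating-point type: if neither input is floating, fall back to Float;
// if only one input is floating, use that type; if both are floating, follow
// the normal promotion rules.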
ScalarType get_compute_type(ScalarType a_type, ScalarType b_type) {
  ET_CHECK(
      !isComplexType(a_type) && !isQIntType(a_type) && !isBitsType(a_type));
  ET_CHECK(
      !isComplexType(b_type) && !isQIntType(b_type) && !isBitsType(b_type));

  if (isFloatingType(a_type) && isFloatingType(b_type)) {
    return promoteTypes(a_type, b_type);
  } else if (isFloatingType(a_type)) {
    return a_type;
  } else if (isFloatingType(b_type)) {
    return b_type;
  }
  return ScalarType::Float;
}

} // namespace

Tensor& opt_div_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = b.scalar_type();
  ScalarType out_type = out.scalar_type();

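  // Fast path for tensor-by-"scalar" division: one operand holds a single
  // element and all three dtypes match (Half excluded), so the work reduces
  // to one vectorized pass with no per-element casting.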
  if (a.numel() == 1 || b.numel() == 1) {
    if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) {
      const Tensor* tensor;
      const Tensor* scalar;
      ScalarType tensor_type;
      ScalarType scalar_type;
      if (a.numel() == 1) {
        tensor = &b;
        tensor_type = b_type;
        scalar = &a;
        scalar_type = a_type;
      } else {
        tensor = &a;
        tensor_type = a_type;
        scalar = &b;
        scalar_type = b_type;
      }
      ET_KERNEL_CHECK(
          ctx,
          resize_to_broadcast_target_size(a, b, out) == Error::Ok,
          InvalidArgument,
          out);
      ET_SWITCH_REALB_TYPES(tensor_type, ctx, "div.out", CTYPE, [&]() {
        ET_SWITCH_REALB_TYPES(scalar_type, ctx, "div.out", CTYPE_SCALAR, [&]() {
          CTYPE_SCALAR scalar_val = *scalar->const_data_ptr<CTYPE_SCALAR>();
          CTYPE scalar_casted = static_cast<CTYPE>(scalar_val);

          using Vec = executorch::vec::Vectorized<CTYPE>;
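          // When a is the single-element operand, each output element is
          // scalar / x. Otherwise divide by the scalar by multiplying with
          // its precomputed reciprocal, which avoids a division per element.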
          if (a.numel() == 1) {
            executorch::vec::map<CTYPE>(
                [scalar_casted](Vec x) { return Vec(scalar_casted) / x; },
                out.mutable_data_ptr<CTYPE>(),
                tensor->const_data_ptr<CTYPE>(),
                out.numel());
          } else {
            Vec inv_scalar_casted_vec(CTYPE(1) / scalar_casted);
            executorch::vec::map<CTYPE>(
                [inv_scalar_casted_vec](Vec x) {
                  return x * inv_scalar_casted_vec;
                },
                out.mutable_data_ptr<CTYPE>(),
                tensor->const_data_ptr<CTYPE>(),
                out.numel());
          }
        });
      });
      return out;
    }
  }

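  // No single-element fast path applied; ask select_optimized_path which
  // vectorized path, if any, fits these shapes. The cases handled below are
  // identically-shaped contiguous inputs (treated as flat 1-d buffers) and
  // 2-d-by-1-d broadcasts, in either argument order.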
  auto selected_optimized_path = select_optimized_path(a, b, out);
  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
    // Resize for dynamic shape
    auto error = resize_tensor(out, a.sizes());
    ET_KERNEL_CHECK_MSG(
        ctx,
        error == Error::Ok,
        InvalidArgument,
        out,
        "Failed to resize output tensor.");

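    // Inputs and output share the same contiguous layout, so one vectorized
    // pass over the flattened data computes a / b.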
    ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "div.out", CTYPE, [&]() {
      using Vec = executorch::vec::Vectorized<CTYPE>;
      executorch::vec::map2<CTYPE>(
          [](Vec x, Vec y) { return x / y; },
          out.mutable_data_ptr<CTYPE>(),
          a.const_data_ptr<CTYPE>(),
          b.const_data_ptr<CTYPE>(),
          out.numel());
    });
  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
    const Tensor* lhs;
    const Tensor* rhs;
    if (selected_optimized_path ==
        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
      lhs = &b;
      rhs = &a;
    } else {
      // Catch failure to update logic when adding a new broadcasting
      // possibility.
      ET_DCHECK(
          selected_optimized_path ==
          ElementwiseOptimizedPath::kBroadcast2dBy1d);
      lhs = &a;
      rhs = &b;
    }
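    // lhs is the operand whose shape the output takes and rhs is the operand
    // broadcast against it. When the arguments arrived in reverse order, the
    // lambda below computes y / x so the result is still a / b.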
    auto error = resize_tensor(out, lhs->sizes());
    ET_KERNEL_CHECK_MSG(
        ctx,
        error == Error::Ok,
        InvalidArgument,
        out,
        "Failed to resize output tensor.");
    ET_SWITCH_REALB_TYPES(out_type, ctx, "div.out", CTYPE, [&]() {
      using Vec = executorch::vec::Vectorized<CTYPE>;
      if (selected_optimized_path ==
          ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
            [](Vec x, Vec y) { return y / x; },
            out.mutable_data_ptr<CTYPE>(),
            lhs->const_data_ptr<CTYPE>(),
            rhs->const_data_ptr<CTYPE>(),
            lhs->sizes()[lhs->dim() - 2],
            lhs->sizes()[lhs->dim() - 1]);
      } else {
        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
            [](Vec x, Vec y) { return x / y; },
            out.mutable_data_ptr<CTYPE>(),
            lhs->const_data_ptr<CTYPE>(),
            rhs->const_data_ptr<CTYPE>(),
            lhs->sizes()[lhs->dim() - 2],
            lhs->sizes()[lhs->dim() - 1]);
      }
    });
  } else {
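    // General fallback: no vectorized path applies, so promote each element
    // pair to the common compute type, divide, and cast to the output dtype.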
    ScalarType common_type = get_compute_type(a_type, b_type);
    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);

    ET_KERNEL_CHECK(
        ctx,
        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
        InvalidArgument,
        out);

    ET_SWITCH_REALB_TYPES(a_type, ctx, "div.out", CTYPE_A, [&]() {
      ET_SWITCH_REALB_TYPES(b_type, ctx, "div.out", CTYPE_B, [&]() {
        ET_SWITCH_REALB_TYPES(common_type, ctx, "div.out", CTYPE_IN, [&]() {
          ET_SWITCH_REALB_TYPES(out_type, ctx, "div.out", CTYPE_OUT, [&]() {
            apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
                [](const CTYPE_A val_a, const CTYPE_B val_b) {
                  CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
                  CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
                  CTYPE_IN value = a_casted / b_casted;

                  return static_cast<CTYPE_OUT>(value);
                },
                a,
                b,
                out);
          });
        });
      });
    });
  }

  return out;
}

Tensor& opt_div_scalar_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = utils::get_scalar_dtype(b);
  ScalarType common_type = isFloatingType(a_type) ? a_type : ScalarType::Float;
  ScalarType out_type = out.scalar_type();

  ET_CHECK(common_type == out_type);

  // Resize for dynamic shape
  auto error = resize_tensor(out, a.sizes());
  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");

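  // Fast path: no casting is needed, so divide by multiplying every element
  // by the precomputed reciprocal of b in one vectorized pass.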
  if (a_type == common_type && a_type == out_type) {
    ET_SWITCH_REAL_TYPES(a_type, ctx, "div.Scalar_out", CTYPE, [&]() {
      ET_SWITCH_REAL_TYPES_AND(
          Bool, b_type, ctx, "div.Scalar_out", CTYPE_B, [&]() {
            CTYPE_B b_val;
            ET_EXTRACT_SCALAR(b, b_val);
            CTYPE b_casted = static_cast<CTYPE>(b_val);

            using Vec = executorch::vec::Vectorized<CTYPE>;
            Vec inv_b_casted_vec(CTYPE(1) / b_casted);
            executorch::vec::map<CTYPE>(
                [inv_b_casted_vec](Vec x) { return x * inv_b_casted_vec; },
                out.mutable_data_ptr<CTYPE>(),
                a.const_data_ptr<CTYPE>(),
                out.numel());
          });
    });
  } else {
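    // Mixed dtypes: cast each element to the common compute type, multiply by
    // the reciprocal of b, and cast the result to the output dtype.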
    ET_SWITCH_REAL_TYPES_AND(
        Bool, a_type, ctx, "div.Scalar_out", CTYPE_A, [&]() {
          ET_SWITCH_REAL_TYPES_AND(
              Bool, b_type, ctx, "div.Scalar_out", CTYPE_B, [&]() {
                ET_SWITCH_REAL_TYPES(
                    common_type, ctx, "div.Scalar_out", CTYPE_IN, [&]() {
                      ET_SWITCH_REAL_TYPES(
                          out_type, ctx, "div.Scalar_out", CTYPE_OUT, [&]() {
                            CTYPE_B b_val;
                            ET_EXTRACT_SCALAR(b, b_val);
                            CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
                            CTYPE_IN inv_b_casted = CTYPE_IN(1) / b_casted;

                            const size_t n = a.numel();
                            const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
                            CTYPE_OUT* out_data =
                                out.mutable_data_ptr<CTYPE_OUT>();
                            for (size_t i = 0; i < n; ++i) {
                              out_data[i] = static_cast<CTYPE_OUT>(
                                  static_cast<CTYPE_IN>(a_data[i]) *
                                  inv_b_casted);
                            }
                          });
                    });
              });
        });
  }

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch