/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/kernels/portable/cpu/util/math_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
#include <cmath>

namespace torch {
namespace executor {
namespace native {

namespace {

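// Division is "true division": if either input dtype is floating point, the
// usual promotion applies; otherwise (integer / integer, including bool) the
// result dtype falls back to Float.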
ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
  if (isFloatingType(a_type) && isFloatingType(b_type)) {
    return promoteTypes(a_type, b_type);
  } else if (isFloatingType(a_type)) {
    return a_type;
  } else if (isFloatingType(b_type)) {
    return b_type;
  }
  return ScalarType::Float;
}

} // namespace

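// div.out: elementwise true division of two broadcastable tensors. Inputs may
// be any real type (including bool, half, and bfloat16); the out tensor must
// be floating point, since even integer inputs produce a floating result.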
Tensor& div_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  // Common Dtype
  ScalarType common_type = get_common_type(a.scalar_type(), b.scalar_type());

  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

  // Resize
  ET_KERNEL_CHECK(
      ctx,
      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
      InvalidArgument,
      out);

  // Compute Dtype
  ScalarType compute_type = utils::get_compute_type(common_type);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "div.out";

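  // Dispatch on the floating-point compute type; the elementwise helper loads
  // each (possibly broadcast) element pair, converts it to CTYPE_COMPUTE,
  // divides, and stores the result into out.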
  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
          return val_a / val_b;
        },
        ctx,
        a,
        utils::SupportedTensorDtypes::REALHBBF16,
        b,
        utils::SupportedTensorDtypes::REALHBBF16,
        out,
        utils::SupportedTensorDtypes::FLOATHBF16);
  });

  return out;
}

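// div.out_mode: division with optional rounding. mode == "trunc" rounds the
// quotient toward zero; mode == "floor" rounds toward negative infinity. With
// no mode, this degenerates to plain div.out. Unlike true division, the
// result may stay integral, so integer division by zero must be diagnosed.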
Tensor& div_out_mode(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    exec_aten::optional<exec_aten::string_view> mode,
    Tensor& out) {
  if (!mode.has_value()) {
    return div_out(ctx, a, b, out);
  }

  auto mode_val = mode.value();

  // Check mode
  ET_KERNEL_CHECK(
      ctx, mode_val == "trunc" || mode_val == "floor", InvalidArgument, out);

  // Common Dtype
  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());

  // Check Common Dtype
  ET_KERNEL_CHECK(
      ctx,
      (canCast(common_type, out.scalar_type()) &&
       common_type != ScalarType::Bool),
      InvalidArgument,
      out);

  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

  // Resize
  ET_KERNEL_CHECK(
      ctx,
      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
      InvalidArgument,
      out);

  // Compute Dtype
  ScalarType compute_type = utils::get_compute_type(common_type);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "div.out_mode";

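  // The elementwise helper has no early-exit, so integer division by zero is
  // recorded in a flag inside the lambda and reported after the full pass.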
  const bool mode_is_trunc = mode_val == "trunc";
  bool div_by_zero_error = false;

  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
        [mode_is_trunc, &div_by_zero_error](
            const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
          if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
            if (val_b == 0) {
              div_by_zero_error = true;
              return static_cast<CTYPE_COMPUTE>(0);
            }
          }
          CTYPE_COMPUTE value = val_a / val_b;
          if (mode_is_trunc) {
            value = std::trunc(value);
          } else {
            // We established above that the mode is either trunc or floor, so
            // it must be floor.
            value = utils::floor_divide(val_a, val_b);
          }
          return value;
        },
        ctx,
        a,
        utils::SupportedTensorDtypes::REALHBBF16,
        b,
        utils::SupportedTensorDtypes::REALHBBF16,
        out,
        utils::SupportedTensorDtypes::REALHBF16);
  });

  ET_KERNEL_CHECK_MSG(
      ctx,
      !div_by_zero_error,
      InvalidArgument,
      out,
      "Div mode operation encountered integer division by zero");

  return out;
}

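// div.Scalar_out: divide a tensor by a scalar. The out dtype must exactly
// match the common dtype (a's dtype if floating, Float otherwise); no
// implicit casting of the result is permitted here.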
Tensor& div_scalar_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    Tensor& out) {
  // Common Dtype
  ScalarType common_type =
      isFloatingType(a.scalar_type()) ? a.scalar_type() : ScalarType::Float;

  // Check Common Dtype
  ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out);

  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);

  // Resize
  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);

  // Compute Dtype
  ScalarType compute_type = utils::get_compute_type(common_type);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "div.Scalar_out";

  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
        [val_b](const CTYPE_COMPUTE val_a) { return val_a / val_b; },
        ctx,
        a,
        utils::SupportedTensorDtypes::REALHBBF16,
        out,
        utils::SupportedTensorDtypes::SAME_AS_COMMON);
  });

  return out;
}

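// div.Scalar_mode_out: scalar division with "trunc"/"floor" rounding. Because
// the divisor is a scalar, integer division by zero can be rejected up front
// instead of being detected during the elementwise pass.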
Tensor& div_scalar_mode_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    exec_aten::optional<exec_aten::string_view> mode,
    Tensor& out) {
  if (!mode.has_value()) {
    return div_scalar_out(ctx, a, b, out);
  }

  auto mode_val = mode.value();

  // Check mode
  ET_KERNEL_CHECK(
      ctx, mode_val == "trunc" || mode_val == "floor", InvalidArgument, out);

  // Common Dtype
  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);

  // Check Common Dtype
  ET_KERNEL_CHECK(
      ctx,
      (canCast(common_type, out.scalar_type()) &&
       common_type != ScalarType::Bool),
      InvalidArgument,
      out);

  // Check for integral division by zero
  ET_KERNEL_CHECK_MSG(
      ctx,
      !(executorch::runtime::isIntegralType(common_type, true) &&
        utils::scalar_to<double>(b) == 0),
      InvalidArgument,
      out,
      "Div mode operation encountered integer division by zero");

  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);

  // Resize
  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);

  // Compute Dtype
  ScalarType compute_type = utils::get_compute_type(common_type);

  const bool mode_is_trunc = mode_val == "trunc";

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "div.Scalar_mode_out";

  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
        [val_b, mode_is_trunc](const CTYPE_COMPUTE val_a) {
          CTYPE_COMPUTE value = val_a / val_b;
          if (mode_is_trunc) {
            value = std::trunc(value);
          } else {
            value = utils::floor_divide(val_a, val_b);
          }
          return value;
        },
        ctx,
        a,
        utils::SupportedTensorDtypes::REALHBBF16,
        out,
        utils::SupportedTensorDtypes::REALHBF16);
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch