/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/kernels/portable/cpu/util/math_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
#include <cmath>

namespace torch {
namespace executor {
namespace native {

namespace {

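// Returns the dtype used to compute true division. If both inputs are
// floating point, the usual promotion rules apply (e.g. Half and Float
// promote to Float); if only one is floating point, its dtype wins;
// otherwise (both integral or bool) the result is Float, matching the
// true-division semantics of div without a rounding mode.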
ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
  if (isFloatingType(a_type) && isFloatingType(b_type)) {
    return promoteTypes(a_type, b_type);
  } else if (isFloatingType(a_type)) {
    return a_type;
  } else if (isFloatingType(b_type)) {
    return b_type;
  }
  return ScalarType::Float;
}

} // namespace

Tensor& div_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  // Common Dtype
  ScalarType common_type = get_common_type(a.scalar_type(), b.scalar_type());

  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

  // Resize
  ET_KERNEL_CHECK(
      ctx,
      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
      InvalidArgument,
      out);

  // Compute Dtype
  ScalarType compute_type = utils::get_compute_type(common_type);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "div.out";

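  // The division is computed in a floating-point type. The inputs may be any
  // real, bool, Half, or BFloat16 tensor (REALHBBF16), while the output must
  // be a floating-point, Half, or BFloat16 tensor (FLOATHBF16), e.g. two Long
  // inputs with a Float out tensor.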
  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
          return val_a / val_b;
        },
        ctx,
        a,
        utils::SupportedTensorDtypes::REALHBBF16,
        b,
        utils::SupportedTensorDtypes::REALHBBF16,
        out,
        utils::SupportedTensorDtypes::FLOATHBF16);
  });

  return out;
}

Tensor& div_out_mode(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    exec_aten::optional<exec_aten::string_view> mode,
    Tensor& out) {
  if (!mode.has_value()) {
    return div_out(ctx, a, b, out);
  }

  auto mode_val = mode.value();

  // Check mode
  ET_KERNEL_CHECK(
      ctx, mode_val == "trunc" || mode_val == "floor", InvalidArgument, out);

  // Common Dtype
  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());

  // Check Common Dtype
  ET_KERNEL_CHECK(
      ctx,
      (canCast(common_type, out.scalar_type()) &&
       common_type != ScalarType::Bool),
      InvalidArgument,
      out);

  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

  // Resize
  ET_KERNEL_CHECK(
      ctx,
      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
      InvalidArgument,
      out);

  // Compute Dtype
  ScalarType compute_type = utils::get_compute_type(common_type);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "div.out_mode";

  const bool mode_is_trunc = mode_val == "trunc";
  bool div_by_zero_error = false;

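  // "trunc" rounds the quotient toward zero while "floor" rounds toward
  // negative infinity; they differ only for negative quotients, e.g. -7 / 2
  // is -3 under trunc and -4 under floor. Integer division by zero is
  // recorded in div_by_zero_error inside the loop and reported once after it.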
  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
        [mode_is_trunc, &div_by_zero_error](
            const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
          if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
            if (val_b == 0) {
              div_by_zero_error = true;
              return static_cast<CTYPE_COMPUTE>(0);
            }
          }
          CTYPE_COMPUTE value = val_a / val_b;
          if (mode_is_trunc) {
            value = std::trunc(value);
          } else {
            // We established above that the mode is either trunc or floor, so
            // it must be floor.
            value = utils::floor_divide(val_a, val_b);
          }
          return value;
        },
        ctx,
        a,
        utils::SupportedTensorDtypes::REALHBBF16,
        b,
        utils::SupportedTensorDtypes::REALHBBF16,
        out,
        utils::SupportedTensorDtypes::REALHBF16);
  });

  ET_KERNEL_CHECK_MSG(
      ctx,
      !div_by_zero_error,
      InvalidArgument,
      out,
      "Div mode operation encountered integer division by zero");

  return out;
}

Tensor& div_scalar_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    Tensor& out) {
  // Common Dtype
  ScalarType common_type =
      isFloatingType(a.scalar_type()) ? a.scalar_type() : ScalarType::Float;

  // Check Common Dtype
  ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out);

  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);

  // Resize
  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);

  // Compute Dtype
  ScalarType compute_type = utils::get_compute_type(common_type);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "div.Scalar_out";

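  // SAME_AS_COMMON: the output dtype must equal the common dtype checked
  // above, so dividing an integral tensor by a scalar requires a Float out
  // tensor, while a Double input requires a Double out tensor.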
  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
        [val_b](const CTYPE_COMPUTE val_a) { return val_a / val_b; },
        ctx,
        a,
        utils::SupportedTensorDtypes::REALHBBF16,
        out,
        utils::SupportedTensorDtypes::SAME_AS_COMMON);
  });

  return out;
}

Tensor& div_scalar_mode_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    exec_aten::optional<exec_aten::string_view> mode,
    Tensor& out) {
  if (!mode.has_value()) {
    return div_scalar_out(ctx, a, b, out);
  }

  auto mode_val = mode.value();

  // Check mode
  ET_KERNEL_CHECK(
      ctx, mode_val == "trunc" || mode_val == "floor", InvalidArgument, out);

  // Common Dtype
  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);

  // Check Common Dtype
  ET_KERNEL_CHECK(
      ctx,
      (canCast(common_type, out.scalar_type()) &&
       common_type != ScalarType::Bool),
      InvalidArgument,
      out);

  // Check for integral division by zero
  ET_KERNEL_CHECK_MSG(
      ctx,
      !(executorch::runtime::isIntegralType(common_type, true) &&
        utils::scalar_to<double>(b) == 0),
      InvalidArgument,
      out,
      "Div mode operation encountered integer division by zero");

  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);

  // Resize
  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);

  // Compute Dtype
  ScalarType compute_type = utils::get_compute_type(common_type);

  const bool mode_is_trunc = mode_val == "trunc";

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "div.Scalar_mode_out";

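  // Unlike div_out_mode above, the divisor is a Scalar, so the integral
  // division-by-zero case was already rejected before this point and the
  // lambda only applies the trunc/floor rounding to the quotient.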
  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
        [val_b, mode_is_trunc](const CTYPE_COMPUTE val_a) {
          CTYPE_COMPUTE value = val_a / val_b;
          if (mode_is_trunc) {
            value = std::trunc(value);
          } else {
            value = utils::floor_divide(val_a, val_b);
          }
          return value;
        },
        ctx,
        a,
        utils::SupportedTensorDtypes::REALHBBF16,
        out,
        utils::SupportedTensorDtypes::REALHBF16);
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch