xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_remainder.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cmath>
10 
11 #include <executorch/kernels/portable/cpu/scalar_utils.h>
12 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
13 #include <executorch/kernels/portable/cpu/util/math_util.h>
14 #include <executorch/runtime/kernel/kernel_includes.h>
15 
16 namespace torch {
17 namespace executor {
18 namespace native {
19 
remainder_Tensor_out(KernelRuntimeContext & ctx,const Tensor & a,const Tensor & b,Tensor & out)20 Tensor& remainder_Tensor_out(
21     KernelRuntimeContext& ctx,
22     const Tensor& a,
23     const Tensor& b,
24     Tensor& out) {
25   // Common Dtype
26   ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
27 
28   // Check Common Dtype
29   ET_KERNEL_CHECK(
30       ctx,
31       (canCast(common_type, out.scalar_type()) &&
32        common_type != ScalarType::Bool),
33       InvalidArgument,
34       out);
35 
36   // Check Dim Order
37   ET_KERNEL_CHECK(
38       ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
39 
40   // Resize
41   ET_KERNEL_CHECK(
42       ctx,
43       resize_to_broadcast_target_size(a, b, out) == Error::Ok,
44       InvalidArgument,
45       out);
46 
47   // Compute Dtype
48   ScalarType compute_type = utils::get_compute_type(common_type);
49 
50   // @lint-ignore CLANGTIDY facebook-hte-CArray
51   static constexpr const char op_name[] = "remainder.Tensor_out";
52 
53   bool div_by_zero_error = false;
54 
55   ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
56     utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
57         [&div_by_zero_error](
58             const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
59           CTYPE_COMPUTE value = 0;
60           if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
61             if (val_b == 0) {
62               div_by_zero_error = true;
63               return value;
64             }
65           }
66           value = utils::remainder_override(val_a, val_b);
67           return value;
68         },
69         ctx,
70         a,
71         utils::SupportedTensorDtypes::REALHBBF16,
72         b,
73         utils::SupportedTensorDtypes::REALHBBF16,
74         out,
75         utils::SupportedTensorDtypes::REALHBF16);
76   });
77 
78   ET_KERNEL_CHECK_MSG(
79       ctx,
80       !div_by_zero_error,
81       InvalidArgument,
82       out,
83       "Remainder operation encountered integer division by zero");
84 
85   return out;
86 }
87 
remainder_Scalar_out(KernelRuntimeContext & ctx,const Tensor & a,const Scalar & b,Tensor & out)88 Tensor& remainder_Scalar_out(
89     KernelRuntimeContext& ctx,
90     const Tensor& a,
91     const Scalar& b,
92     Tensor& out) {
93   // Common Dtype
94   ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
95 
96   // Check Common Dtype
97   ET_KERNEL_CHECK(
98       ctx,
99       (canCast(common_type, out.scalar_type()) &&
100        common_type != ScalarType::Bool),
101       InvalidArgument,
102       out);
103 
104   // Check for intergral division by zero
105   ET_KERNEL_CHECK_MSG(
106       ctx,
107       !(executorch::runtime::isIntegralType(common_type, true) &&
108         utils::scalar_to<double>(b) == 0),
109       InvalidArgument,
110       out,
111       "Remainder operation encountered integer division by zero");
112 
113   // Check Dim Order
114   ET_KERNEL_CHECK(
115       ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
116 
117   // Resize
118   ET_KERNEL_CHECK(
119       ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
120 
121   // Compute Dtype
122   ScalarType compute_type = utils::get_compute_type(common_type);
123 
124   // @lint-ignore CLANGTIDY facebook-hte-CArray
125   static constexpr const char op_name[] = "remainder.Scalar_out";
126 
127   ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
128     const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
129     utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
130         [val_b](const CTYPE_COMPUTE val_a) {
131           return utils::remainder_override(val_a, val_b);
132         },
133         ctx,
134         a,
135         utils::SupportedTensorDtypes::REALHBBF16,
136         out,
137         utils::SupportedTensorDtypes::REALHBF16);
138   });
139 
140   return out;
141 }
142 
143 } // namespace native
144 } // namespace executor
145 } // namespace torch
146