1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <cmath>
#include <type_traits>

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/kernels/portable/cpu/util/math_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
15
16 namespace torch {
17 namespace executor {
18 namespace native {
19
remainder_Tensor_out(KernelRuntimeContext & ctx,const Tensor & a,const Tensor & b,Tensor & out)20 Tensor& remainder_Tensor_out(
21 KernelRuntimeContext& ctx,
22 const Tensor& a,
23 const Tensor& b,
24 Tensor& out) {
25 // Common Dtype
26 ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
27
28 // Check Common Dtype
29 ET_KERNEL_CHECK(
30 ctx,
31 (canCast(common_type, out.scalar_type()) &&
32 common_type != ScalarType::Bool),
33 InvalidArgument,
34 out);
35
36 // Check Dim Order
37 ET_KERNEL_CHECK(
38 ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
39
40 // Resize
41 ET_KERNEL_CHECK(
42 ctx,
43 resize_to_broadcast_target_size(a, b, out) == Error::Ok,
44 InvalidArgument,
45 out);
46
47 // Compute Dtype
48 ScalarType compute_type = utils::get_compute_type(common_type);
49
50 // @lint-ignore CLANGTIDY facebook-hte-CArray
51 static constexpr const char op_name[] = "remainder.Tensor_out";
52
53 bool div_by_zero_error = false;
54
55 ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
56 utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
57 [&div_by_zero_error](
58 const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
59 CTYPE_COMPUTE value = 0;
60 if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
61 if (val_b == 0) {
62 div_by_zero_error = true;
63 return value;
64 }
65 }
66 value = utils::remainder_override(val_a, val_b);
67 return value;
68 },
69 ctx,
70 a,
71 utils::SupportedTensorDtypes::REALHBBF16,
72 b,
73 utils::SupportedTensorDtypes::REALHBBF16,
74 out,
75 utils::SupportedTensorDtypes::REALHBF16);
76 });
77
78 ET_KERNEL_CHECK_MSG(
79 ctx,
80 !div_by_zero_error,
81 InvalidArgument,
82 out,
83 "Remainder operation encountered integer division by zero");
84
85 return out;
86 }
87
remainder_Scalar_out(KernelRuntimeContext & ctx,const Tensor & a,const Scalar & b,Tensor & out)88 Tensor& remainder_Scalar_out(
89 KernelRuntimeContext& ctx,
90 const Tensor& a,
91 const Scalar& b,
92 Tensor& out) {
93 // Common Dtype
94 ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
95
96 // Check Common Dtype
97 ET_KERNEL_CHECK(
98 ctx,
99 (canCast(common_type, out.scalar_type()) &&
100 common_type != ScalarType::Bool),
101 InvalidArgument,
102 out);
103
104 // Check for intergral division by zero
105 ET_KERNEL_CHECK_MSG(
106 ctx,
107 !(executorch::runtime::isIntegralType(common_type, true) &&
108 utils::scalar_to<double>(b) == 0),
109 InvalidArgument,
110 out,
111 "Remainder operation encountered integer division by zero");
112
113 // Check Dim Order
114 ET_KERNEL_CHECK(
115 ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
116
117 // Resize
118 ET_KERNEL_CHECK(
119 ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
120
121 // Compute Dtype
122 ScalarType compute_type = utils::get_compute_type(common_type);
123
124 // @lint-ignore CLANGTIDY facebook-hte-CArray
125 static constexpr const char op_name[] = "remainder.Scalar_out";
126
127 ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
128 const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
129 utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
130 [val_b](const CTYPE_COMPUTE val_a) {
131 return utils::remainder_override(val_a, val_b);
132 },
133 ctx,
134 a,
135 utils::SupportedTensorDtypes::REALHBBF16,
136 out,
137 utils::SupportedTensorDtypes::REALHBF16);
138 });
139
140 return out;
141 }
142
143 } // namespace native
144 } // namespace executor
145 } // namespace torch
146