1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>
13
14 #include <executorch/kernels/portable/cpu/scalar_utils.h>
15 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
16 #include <executorch/kernels/portable/cpu/util/math_util.h>
17 #include <executorch/runtime/kernel/kernel_includes.h>
18
19 namespace torch {
20 namespace executor {
21 namespace native {
22
23 using Scalar = exec_aten::Scalar;
24 using ScalarType = exec_aten::ScalarType;
25 using Tensor = exec_aten::Tensor;
26
27 namespace {
28
/**
 * Returns true when `val`, once converted to CTYPE_CAST, lies outside the
 * closed interval [numeric_limits<CTYPE_OUT>::lowest(),
 * numeric_limits<CTYPE_OUT>::max()].
 *
 * Both comparisons are performed in CTYPE_CAST. Any comparison with NaN is
 * false, so a NaN value is never reported as out of bounds.
 */
template <typename CTYPE_VAL, typename CTYPE_OUT, typename CTYPE_CAST>
bool is_out_of_bounds(CTYPE_VAL val) {
  using OutLimits = std::numeric_limits<CTYPE_OUT>;
  const CTYPE_CAST converted = static_cast<CTYPE_CAST>(val);
  const bool below_range = converted < OutLimits::lowest();
  const bool above_range = converted > OutLimits::max();
  return below_range || above_range;
}
36
check_bounds(const Scalar & val_scalar,const torch::executor::native::ScalarType & val_type,const torch::executor::native::ScalarType & out_type,const char * val_name)37 ET_NODISCARD bool check_bounds(
38 const Scalar& val_scalar,
39 const torch::executor::native::ScalarType& val_type,
40 const torch::executor::native::ScalarType& out_type,
41 const char* val_name) {
42 auto is_valid = true;
43
44 ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, "clamp.out", CTYPE_VAL, [&]() {
45 CTYPE_VAL val = 0;
46 utils::extract_scalar(val_scalar, &val);
47 if (isIntegralType(out_type, /*includeBool=*/false)) {
48 ET_SWITCH_INT_TYPES(out_type, ctx, "clamp.out", CTYPE_OUT, [&]() {
49 if (is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, long>(val)) {
50 ET_LOG(Error, "%s value out of bounds", val_name);
51 is_valid = false;
52 }
53 });
54 } else if (isFloatingType(out_type)) {
55 ET_SWITCH_FLOATH_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
56 if (std::isfinite(val) &&
57 is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, double>(val)) {
58 ET_LOG(Error, "%s value out of bounds", val_name);
59 is_valid = false;
60 }
61 });
62 }
63 });
64
65 return is_valid;
66 }
67
68 } // namespace
69
clamp_out(KernelRuntimeContext & ctx,const Tensor & in,const exec_aten::optional<Scalar> & min_opt,const exec_aten::optional<Scalar> & max_opt,Tensor & out)70 Tensor& clamp_out(
71 KernelRuntimeContext& ctx,
72 const Tensor& in,
73 const exec_aten::optional<Scalar>& min_opt,
74 const exec_aten::optional<Scalar>& max_opt,
75 Tensor& out) {
76 bool has_min = min_opt.has_value();
77 bool has_max = max_opt.has_value();
78
79 ET_KERNEL_CHECK_MSG(
80 ctx,
81 has_min || has_max,
82 InvalidArgument,
83 out,
84 "At least one of 'min' or 'max' must not be None");
85
86 // Input Dtypes
87 ScalarType in_type = in.scalar_type();
88 ScalarType min_type =
89 has_min ? utils::get_scalar_dtype(min_opt.value()) : in_type;
90 ScalarType max_type =
91 has_max ? utils::get_scalar_dtype(max_opt.value()) : in_type;
92 ScalarType out_type = out.scalar_type();
93
94 // Common Dtype
95 ScalarType common_type = in_type;
96 if (has_min) {
97 common_type = utils::promote_type_with_scalar(common_type, min_opt.value());
98 }
99 if (has_max) {
100 common_type = utils::promote_type_with_scalar(common_type, max_opt.value());
101 }
102
103 // Check Common Dtype
104 ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
105
106 // Check Scalar Bounds
107 if (has_min) {
108 ET_KERNEL_CHECK(
109 ctx,
110 check_bounds(min_opt.value(), min_type, out_type, "minimum"),
111 InvalidArgument,
112 out);
113 }
114 if (has_max) {
115 ET_KERNEL_CHECK(
116 ctx,
117 check_bounds(max_opt.value(), max_type, out_type, "maximum"),
118 InvalidArgument,
119 out);
120 }
121
122 // Check Dim Order
123 ET_KERNEL_CHECK(
124 ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
125
126 // Resize
127 ET_KERNEL_CHECK(
128 ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
129
130 // Compute Dtype
131 ScalarType compute_type = utils::get_compute_type(common_type);
132
133 // @lint-ignore CLANGTIDY facebook-hte-CArray
134 static constexpr const char op_name[] = "clamp.out";
135
136 ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
137 utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
138 [has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) {
139 CTYPE_COMPUTE val_out = val_in;
140 if (has_min) {
141 val_out = utils::max_override(
142 val_out, utils::scalar_to<CTYPE_COMPUTE>(min_opt.value()));
143 }
144 if (has_max) {
145 val_out = utils::min_override(
146 val_out, utils::scalar_to<CTYPE_COMPUTE>(max_opt.value()));
147 }
148 return val_out;
149 },
150 ctx,
151 in,
152 utils::SupportedTensorDtypes::REALHBBF16,
153 out,
154 utils::SupportedTensorDtypes::SAME_AS_COMMON);
155 });
156
157 return out;
158 }
159
clamp_tensor_out(KernelRuntimeContext & ctx,const Tensor & in,const exec_aten::optional<Tensor> & min_opt,const exec_aten::optional<Tensor> & max_opt,Tensor & out)160 Tensor& clamp_tensor_out(
161 KernelRuntimeContext& ctx,
162 const Tensor& in,
163 const exec_aten::optional<Tensor>& min_opt,
164 const exec_aten::optional<Tensor>& max_opt,
165 Tensor& out) {
166 bool has_min = min_opt.has_value();
167 bool has_max = max_opt.has_value();
168
169 ET_KERNEL_CHECK_MSG(
170 ctx,
171 has_min || has_max,
172 InvalidArgument,
173 out,
174 "At least one of 'min' or 'max' must not be None");
175
176 const Tensor& min = has_min ? min_opt.value() : in;
177 const Tensor& max = has_max ? max_opt.value() : in;
178
179 // Common Dtype
180 ScalarType common_type = in.scalar_type();
181 if (has_min) {
182 common_type = promoteTypes(common_type, min.scalar_type());
183 }
184 if (has_max) {
185 common_type = promoteTypes(common_type, max.scalar_type());
186 }
187
188 // Check Common Dtype
189 ET_KERNEL_CHECK(
190 ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);
191
192 // Check Dim Order
193 ET_KERNEL_CHECK(
194 ctx,
195 tensors_have_same_dim_order(in, min, max, out),
196 InvalidArgument,
197 out);
198
199 // Resize
200 ET_KERNEL_CHECK(
201 ctx,
202 resize_to_broadcast_target_size(in, min, max, out) == Error::Ok,
203 InvalidArgument,
204 out);
205
206 // Compute Dtype
207 ScalarType compute_type = utils::get_compute_type(common_type);
208
209 // @lint-ignore CLANGTIDY facebook-hte-CArray
210 static constexpr const char op_name[] = "clamp.Tensor_out";
211
212 ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
213 utils::apply_tritensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
214 [has_min, has_max](
215 const CTYPE_COMPUTE val_in,
216 const CTYPE_COMPUTE val_min,
217 const CTYPE_COMPUTE val_max) {
218 CTYPE_COMPUTE val_out = val_in;
219 if (has_min) {
220 val_out = utils::max_override(val_out, val_min);
221 }
222 if (has_max) {
223 val_out = utils::min_override(val_out, val_max);
224 }
225 return val_out;
226 },
227 ctx,
228 in,
229 utils::SupportedTensorDtypes::REALHBBF16,
230 min,
231 utils::SupportedTensorDtypes::REALHBBF16,
232 max,
233 utils::SupportedTensorDtypes::REALHBBF16,
234 out,
235 utils::SupportedTensorDtypes::REALHBBF16);
236 });
237
238 return out;
239 }
240
241 } // namespace native
242 } // namespace executor
243 } // namespace torch
244