/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <algorithm>
#include <cmath>
#include <cstring>
#include <limits>

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/kernels/portable/cpu/util/math_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Scalar = exec_aten::Scalar;
using ScalarType = exec_aten::ScalarType;
using Tensor = exec_aten::Tensor;

namespace {

/** Check if val, when cast to CTYPE_CAST, is not in the range of CTYPE_OUT */
template <typename CTYPE_VAL, typename CTYPE_OUT, typename CTYPE_CAST>
bool is_out_of_bounds(CTYPE_VAL val) {
  const CTYPE_CAST val_cast = static_cast<CTYPE_CAST>(val);
  return val_cast < std::numeric_limits<CTYPE_OUT>::lowest() ||
      val_cast > std::numeric_limits<CTYPE_OUT>::max();
}

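/**
 * Verify that `val_scalar` can be represented in the output dtype. The scalar
 * is extracted as `val_type`, cast to a wide type (long for integral outputs,
 * double for floating-point outputs), and compared against the numeric limits
 * of `out_type`. Returns false and logs an error naming `val_name` if the
 * value is out of range; non-finite floating-point values pass the check.
 */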
ET_NODISCARD bool check_bounds(
    const Scalar& val_scalar,
    const torch::executor::native::ScalarType& val_type,
    const torch::executor::native::ScalarType& out_type,
    const char* val_name) {
  auto is_valid = true;

  ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, "clamp.out", CTYPE_VAL, [&]() {
    CTYPE_VAL val = 0;
    utils::extract_scalar(val_scalar, &val);
    if (isIntegralType(out_type, /*includeBool=*/false)) {
      ET_SWITCH_INT_TYPES(out_type, ctx, "clamp.out", CTYPE_OUT, [&]() {
        if (is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, long>(val)) {
          ET_LOG(Error, "%s value out of bounds", val_name);
          is_valid = false;
        }
      });
    } else if (isFloatingType(out_type)) {
      ET_SWITCH_FLOATH_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
        if (std::isfinite(val) &&
            is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, double>(val)) {
          ET_LOG(Error, "%s value out of bounds", val_name);
          is_valid = false;
        }
      });
    }
  });

  return is_valid;
}

} // namespace

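/**
 * clamp.out: Clamps each element of `in` to the closed range [min, max],
 * where `min` and `max` are optional Scalars (at least one must be present),
 * and writes the result to `out`.
 *
 * Illustrative example (values chosen for illustration only):
 *   in = [-2, 0, 5], min = -1, max = 3  ->  out = [-1, 0, 3]
 */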
Tensor& clamp_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const exec_aten::optional<Scalar>& min_opt,
    const exec_aten::optional<Scalar>& max_opt,
    Tensor& out) {
  bool has_min = min_opt.has_value();
  bool has_max = max_opt.has_value();

  ET_KERNEL_CHECK_MSG(
      ctx,
      has_min || has_max,
      InvalidArgument,
      out,
      "At least one of 'min' or 'max' must not be None");

  // Input Dtypes
  ScalarType in_type = in.scalar_type();
  ScalarType min_type =
      has_min ? utils::get_scalar_dtype(min_opt.value()) : in_type;
  ScalarType max_type =
      has_max ? utils::get_scalar_dtype(max_opt.value()) : in_type;
  ScalarType out_type = out.scalar_type();

  // Common Dtype
  ScalarType common_type = in_type;
  if (has_min) {
    common_type = utils::promote_type_with_scalar(common_type, min_opt.value());
  }
  if (has_max) {
    common_type = utils::promote_type_with_scalar(common_type, max_opt.value());
  }

  // Check Common Dtype
  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);

  // Check Scalar Bounds
  if (has_min) {
    ET_KERNEL_CHECK(
        ctx,
        check_bounds(min_opt.value(), min_type, out_type, "minimum"),
        InvalidArgument,
        out);
  }
  if (has_max) {
    ET_KERNEL_CHECK(
        ctx,
        check_bounds(max_opt.value(), max_type, out_type, "maximum"),
        InvalidArgument,
        out);
  }

  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  // Resize
  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);

  // Compute Dtype
  ScalarType compute_type = utils::get_compute_type(common_type);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "clamp.out";

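  // Dispatch on the compute dtype and apply the clamp lambda elementwise:
  // each input element is loaded as CTYPE_COMPUTE, clamped against the
  // extracted scalar bounds, and stored back in the output dtype.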
  ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
        [has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) {
          CTYPE_COMPUTE val_out = val_in;
          if (has_min) {
            val_out = utils::max_override(
                val_out, utils::scalar_to<CTYPE_COMPUTE>(min_opt.value()));
          }
          if (has_max) {
            val_out = utils::min_override(
                val_out, utils::scalar_to<CTYPE_COMPUTE>(max_opt.value()));
          }
          return val_out;
        },
        ctx,
        in,
        utils::SupportedTensorDtypes::REALHBBF16,
        out,
        utils::SupportedTensorDtypes::SAME_AS_COMMON);
  });

  return out;
}

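/**
 * clamp.Tensor_out: Clamps each element of `in` against optional `min` and
 * `max` tensors (at least one must be present), broadcasting all inputs to a
 * common shape, and writes the result to `out`.
 */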
Tensor& clamp_tensor_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const exec_aten::optional<Tensor>& min_opt,
    const exec_aten::optional<Tensor>& max_opt,
    Tensor& out) {
  bool has_min = min_opt.has_value();
  bool has_max = max_opt.has_value();

  ET_KERNEL_CHECK_MSG(
      ctx,
      has_min || has_max,
      InvalidArgument,
      out,
      "At least one of 'min' or 'max' must not be None");

  const Tensor& min = has_min ? min_opt.value() : in;
  const Tensor& max = has_max ? max_opt.value() : in;

  // Common Dtype
  ScalarType common_type = in.scalar_type();
  if (has_min) {
    common_type = promoteTypes(common_type, min.scalar_type());
  }
  if (has_max) {
    common_type = promoteTypes(common_type, max.scalar_type());
  }

  // Check Common Dtype
  ET_KERNEL_CHECK(
      ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);

  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx,
      tensors_have_same_dim_order(in, min, max, out),
      InvalidArgument,
      out);

  // Resize
  ET_KERNEL_CHECK(
      ctx,
      resize_to_broadcast_target_size(in, min, max, out) == Error::Ok,
      InvalidArgument,
      out);

  // Compute Dtype
  ScalarType compute_type = utils::get_compute_type(common_type);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "clamp.Tensor_out";

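  // Dispatch on the compute dtype and apply the clamp lambda over broadcasted
  // (in, min, max) triples; when a bound is absent its tensor slot aliases
  // `in`, and the corresponding has_min/has_max flag skips that comparison.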
  ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
    utils::apply_tritensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
        [has_min, has_max](
            const CTYPE_COMPUTE val_in,
            const CTYPE_COMPUTE val_min,
            const CTYPE_COMPUTE val_max) {
          CTYPE_COMPUTE val_out = val_in;
          if (has_min) {
            val_out = utils::max_override(val_out, val_min);
          }
          if (has_max) {
            val_out = utils::min_override(val_out, val_max);
          }
          return val_out;
        },
        ctx,
        in,
        utils::SupportedTensorDtypes::REALHBBF16,
        min,
        utils::SupportedTensorDtypes::REALHBBF16,
        max,
        utils::SupportedTensorDtypes::REALHBBF16,
        out,
        utils::SupportedTensorDtypes::REALHBBF16);
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch