xref: /aosp_15_r20/external/executorch/kernels/optimized/cpu/op_le.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/optimized/vec/functional.h>
10 #include <executorch/kernels/optimized/vec/vec.h>
11 #include <executorch/kernels/portable/cpu/scalar_utils.h>
12 #include <executorch/runtime/kernel/kernel_includes.h>
13 #include <executorch/runtime/platform/assert.h>
14 
15 namespace torch {
16 namespace executor {
17 namespace native {
18 
19 using Tensor = exec_aten::Tensor;
20 using ScalarType = exec_aten::ScalarType;
21 
opt_le_tensor_out(KernelRuntimeContext & ctx,const Tensor & a,const Tensor & b,Tensor & out)22 Tensor& opt_le_tensor_out(
23     KernelRuntimeContext& ctx,
24     const Tensor& a,
25     const Tensor& b,
26     Tensor& out) {
27   (void)ctx;
28 
29   ET_KERNEL_CHECK(ctx, tensors_have_same_shape(a, b), InvalidArgument, out);
30 
31   // Resize for dynamic shape
32   auto error = resize_tensor(out, a.sizes());
33   ET_KERNEL_CHECK_MSG(
34       ctx,
35       error == Error::Ok,
36       InvalidArgument,
37       out,
38       "Failed to resize output tensor.");
39 
40   ScalarType a_type = a.scalar_type();
41   ScalarType b_type = b.scalar_type();
42   ScalarType out_type = out.scalar_type();
43 
44   if (a_type == b_type && a_type == out_type) {
45     ET_SWITCH_REAL_TYPES_AND(
46         Bool, out_type, ctx, "le.Tensor_out", CTYPE, [&]() {
47           using Vec = executorch::vec::Vectorized<CTYPE>;
48           executorch::vec::map2<CTYPE>(
49               [](Vec x, Vec y) { return x.le(y); },
50               out.mutable_data_ptr<CTYPE>(),
51               a.const_data_ptr<CTYPE>(),
52               b.const_data_ptr<CTYPE>(),
53               a.numel());
54         });
55   } else {
56     ET_SWITCH_REAL_TYPES_AND(
57         Bool, a_type, ctx, "le.Tensor_out", CTYPE_A, [&]() {
58           ET_SWITCH_REAL_TYPES_AND(
59               Bool, b_type, ctx, "le.Tensor_out", CTYPE_B, [&]() {
60                 using CTYPE_IN = typename torch::executor::
61                     promote_types<CTYPE_A, CTYPE_B>::type;
62                 ET_DCHECK(
63                     CppTypeToScalarType<CTYPE_IN>::value ==
64                     promoteTypes(a_type, b_type));
65                 ET_SWITCH_REAL_TYPES_AND(
66                     Bool, out_type, ctx, "le.Tensor_out", CTYPE_OUT, [&]() {
67                       const size_t n = a.numel();
68                       const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
69                       const CTYPE_B* b_data = b.const_data_ptr<CTYPE_B>();
70                       CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
71                       for (auto i = 0; i < n; ++i) {
72                         out_data[i] = static_cast<CTYPE_OUT>(
73                             static_cast<CTYPE_IN>(a_data[i]) <=
74                             static_cast<CTYPE_IN>(b_data[i]));
75                       }
76                     });
77               });
78         });
79   }
80 
81   return out;
82 }
83 
opt_le_scalar_out(KernelRuntimeContext & ctx,const Tensor & a,const Scalar & b,Tensor & out)84 Tensor& opt_le_scalar_out(
85     KernelRuntimeContext& ctx,
86     const Tensor& a,
87     const Scalar& b,
88     Tensor& out) {
89   (void)ctx;
90 
91   // Resize for dynamic shape
92   auto error = resize_tensor(out, a.sizes());
93   ET_KERNEL_CHECK_MSG(
94       ctx,
95       error == Error::Ok,
96       InvalidArgument,
97       out,
98       "Failed to resize output tensor.");
99 
100   ScalarType a_type = a.scalar_type();
101   ScalarType b_type = utils::get_scalar_dtype(b);
102   ScalarType common_type = promoteTypes(a_type, b_type);
103   ScalarType out_type = out.scalar_type();
104 
105   if (a_type == common_type && a_type == out_type) {
106     ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "le.Scalar_out", CTYPE, [&]() {
107       ET_SWITCH_REAL_TYPES_AND(
108           Bool, b_type, ctx, "le.Scalar_out", CTYPE_B, [&]() {
109             CTYPE_B b_val = 0;
110             ET_EXTRACT_SCALAR(b, b_val);
111             CTYPE b_casted = static_cast<CTYPE>(b_val);
112             using Vec = executorch::vec::Vectorized<CTYPE>;
113             executorch::vec::map<CTYPE>(
114                 [b_casted](Vec x) { return x.le(Vec(b_casted)); },
115                 out.mutable_data_ptr<CTYPE>(),
116                 a.const_data_ptr<CTYPE>(),
117                 a.numel());
118           });
119     });
120   } else {
121     ET_SWITCH_REAL_TYPES_AND(
122         Bool, a_type, ctx, "le.Scalar_out", CTYPE_A, [&]() {
123           ET_SWITCH_REAL_TYPES_AND(
124               Bool, b_type, ctx, "le.Scalar_out", CTYPE_B, [&]() {
125                 ET_SWITCH_REAL_TYPES_AND(
126                     Bool, common_type, ctx, "le.Scalar_out", CTYPE_IN, [&]() {
127                       ET_SWITCH_REAL_TYPES_AND(
128                           Bool,
129                           out_type,
130                           ctx,
131                           "le.Scalar_out",
132                           CTYPE_OUT,
133                           [&]() {
134                             CTYPE_B b_val = 0;
135                             ET_EXTRACT_SCALAR(b, b_val);
136                             CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
137                             const size_t n = a.numel();
138                             const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
139                             CTYPE_OUT* out_data =
140                                 out.mutable_data_ptr<CTYPE_OUT>();
141                             for (auto i = 0; i < n; ++i) {
142                               out_data[i] = static_cast<CTYPE_OUT>(
143                                   static_cast<CTYPE_IN>(a_data[i]) <= b_casted);
144                             }
145                           });
146                     });
147               });
148         });
149   }
150 
151   return out;
152 }
153 
154 } // namespace native
155 } // namespace executor
156 } // namespace torch
157