// xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/pattern/comparison_op.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #pragma once
10 
#include <string_view>

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
14 
15 namespace torch {
16 namespace executor {
17 namespace native {
18 namespace internal {
19 
// Generates `template <typename T> T fn_name(const T lhs, const T rhs)`,
// which evaluates `lhs binop rhs`. The bool produced by a comparison operator
// is converted back to T on return, so each generated function has the
// `comparison_fn<T>` shape used by the lookup helpers in this header.
#define DEFINE_BINARY_OPERATOR_TEMPLATE(fn_name, binop) \
  template <typename T>                                 \
  T fn_name(const T lhs, const T rhs) {                 \
    return lhs binop rhs;                               \
  }

// Equality / inequality.
DEFINE_BINARY_OPERATOR_TEMPLATE(eq, ==)
DEFINE_BINARY_OPERATOR_TEMPLATE(ne, !=)
// Inclusive orderings.
DEFINE_BINARY_OPERATOR_TEMPLATE(ge, >=)
DEFINE_BINARY_OPERATOR_TEMPLATE(le, <=)
// Strict orderings.
DEFINE_BINARY_OPERATOR_TEMPLATE(gt, >)
DEFINE_BINARY_OPERATOR_TEMPLATE(lt, <)
32 
33 template <typename T>
34 using comparison_fn = T (*)(const T, const T);
35 
36 template <typename T, const char* op_name>
get_comparison_fn()37 constexpr comparison_fn<T> get_comparison_fn() {
38   std::string_view op = op_name;
39   if (op == "eq.Tensor_out" || op == "eq.Scalar_out") {
40     return eq;
41   }
42   if (op == "ne.Tensor_out" || op == "ne.Scalar_out") {
43     return ne;
44   }
45   if (op == "ge.Tensor_out" || op == "ge.Scalar_out") {
46     return ge;
47   }
48   if (op == "le.Tensor_out" || op == "le.Scalar_out") {
49     return le;
50   }
51   if (op == "gt.Tensor_out" || op == "gt.Scalar_out") {
52     return gt;
53   }
54   if (op == "lt.Tensor_out" || op == "lt.Scalar_out") {
55     return lt;
56   }
57   return nullptr;
58 };
59 
60 template <typename T, const char* op_name>
61 struct ComparisonFnForOp {
62   static constexpr auto value = get_comparison_fn<T, op_name>();
63   static_assert(value != nullptr, "unknown op_name!");
64 };
65 
66 template <const char* op_name>
comparison_tensor_out(KernelRuntimeContext & ctx,const Tensor & a,const Tensor & b,Tensor & out)67 Tensor& comparison_tensor_out(
68     KernelRuntimeContext& ctx,
69     const Tensor& a,
70     const Tensor& b,
71     Tensor& out) {
72   // Common Dtype
73   ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
74   if (executorch::runtime::isFloatingType(common_type) &&
75       a.scalar_type() != b.scalar_type()) {
76     common_type = ScalarType::Float;
77   }
78 
79   // Check Dim Order
80   ET_KERNEL_CHECK(
81       ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
82 
83   // Resize
84   ET_KERNEL_CHECK(
85       ctx,
86       resize_to_broadcast_target_size(a, b, out) == Error::Ok,
87       InvalidArgument,
88       out);
89 
90   // Compute Dtype
91   ScalarType compute_type = utils::get_compute_type(common_type);
92 
93   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
94     utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
95         ComparisonFnForOp<CTYPE_COMPUTE, op_name>::value,
96         ctx,
97         a,
98         utils::SupportedTensorDtypes::REALHBBF16,
99         b,
100         utils::SupportedTensorDtypes::REALHBBF16,
101         out,
102         utils::SupportedTensorDtypes::REALHBBF16);
103   });
104 
105   return out;
106 }
107 
108 template <const char* op_name>
comparison_scalar_out(KernelRuntimeContext & ctx,const Tensor & a,const Scalar & b,Tensor & out)109 Tensor& comparison_scalar_out(
110     KernelRuntimeContext& ctx,
111     const Tensor& a,
112     const Scalar& b,
113     Tensor& out) {
114   // Common Dtype
115   ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
116 
117   // Check Dim Order
118   ET_KERNEL_CHECK(
119       ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
120 
121   // Resize
122   ET_KERNEL_CHECK(
123       ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
124 
125   // Compute Dtype
126   ScalarType compute_type = utils::get_compute_type(common_type);
127 
128   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
129     const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
130     utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
131         [val_b](const CTYPE_COMPUTE val_a) {
132           return ComparisonFnForOp<CTYPE_COMPUTE, op_name>::value(val_a, val_b);
133         },
134         ctx,
135         a,
136         utils::SupportedTensorDtypes::REALHBBF16,
137         out,
138         utils::SupportedTensorDtypes::REALHBBF16);
139   });
140 
141   return out;
142 }
143 
144 } // namespace internal
145 } // namespace native
146 } // namespace executor
147 } // namespace torch
148