/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <string_view>

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
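
// Shared patterns for the portable elementwise comparison operators
// (eq, ne, ge, le, gt, lt). The templates in the internal namespace implement
// both the Tensor-Tensor and Tensor-Scalar out variants; individual op files
// instantiate them with their registered op_name string.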
namespace torch {
namespace executor {
namespace native {
namespace internal {
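
// Expands to a small comparison function template, e.g. eq(val_a, val_b)
// returns val_a == val_b. These helpers back both the Tensor and Scalar
// variants of the comparison kernels below.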
#define DEFINE_BINARY_OPERATOR_TEMPLATE(name, op) \
  template <typename T>                           \
  T name(const T val_a, const T val_b) {          \
    return val_a op val_b;                        \
  }

DEFINE_BINARY_OPERATOR_TEMPLATE(eq, ==)
DEFINE_BINARY_OPERATOR_TEMPLATE(ne, !=)
DEFINE_BINARY_OPERATOR_TEMPLATE(ge, >=)
DEFINE_BINARY_OPERATOR_TEMPLATE(le, <=)
DEFINE_BINARY_OPERATOR_TEMPLATE(gt, >)
DEFINE_BINARY_OPERATOR_TEMPLATE(lt, <)
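
// Pointer-to-function signature shared by all of the comparison helpers above.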
template <typename T>
using comparison_fn = T (*)(const T, const T);
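
// Maps a kernel op_name string (e.g. "lt.Tensor_out" or "lt.Scalar_out") to
// the matching comparison function at compile time. Returns nullptr for an
// unrecognized name so that ComparisonFnForOp below can turn a typo into a
// build error.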
template <typename T, const char* op_name>
constexpr comparison_fn<T> get_comparison_fn() {
  std::string_view op = op_name;
  if (op == "eq.Tensor_out" || op == "eq.Scalar_out") {
    return eq;
  }
  if (op == "ne.Tensor_out" || op == "ne.Scalar_out") {
    return ne;
  }
  if (op == "ge.Tensor_out" || op == "ge.Scalar_out") {
    return ge;
  }
  if (op == "le.Tensor_out" || op == "le.Scalar_out") {
    return le;
  }
  if (op == "gt.Tensor_out" || op == "gt.Scalar_out") {
    return gt;
  }
  if (op == "lt.Tensor_out" || op == "lt.Scalar_out") {
    return lt;
  }
  return nullptr;
}
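
// Compile-time lookup of the comparison function for a given (T, op_name)
// pair; the static_assert rejects op_names that get_comparison_fn does not
// recognize.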
template <typename T, const char* op_name>
struct ComparisonFnForOp {
  static constexpr auto value = get_comparison_fn<T, op_name>();
  static_assert(value != nullptr, "unknown op_name!");
};
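
// Shared implementation of the <op>.Tensor_out comparison kernels: promotes
// the input dtypes, checks dim orders, resizes `out` to the broadcast target
// shape, and applies the comparison elementwise.
//
// A minimal usage sketch (names here are illustrative; an actual op file may
// add its own lint annotations or extra checks):
//
//   Tensor& lt_tensor_out(
//       KernelRuntimeContext& ctx,
//       const Tensor& a,
//       const Tensor& b,
//       Tensor& out) {
//     static constexpr const char op_name[] = "lt.Tensor_out";
//     return internal::comparison_tensor_out<op_name>(ctx, a, b, out);
//   }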
template <const char* op_name>
Tensor& comparison_tensor_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  // Common Dtype
  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
  if (executorch::runtime::isFloatingType(common_type) &&
      a.scalar_type() != b.scalar_type()) {
    common_type = ScalarType::Float;
  }

  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

  // Resize
  ET_KERNEL_CHECK(
      ctx,
      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
      InvalidArgument,
      out);

  // Compute Dtype
  ScalarType compute_type = utils::get_compute_type(common_type);

  ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
        ComparisonFnForOp<CTYPE_COMPUTE, op_name>::value,
        ctx,
        a,
        utils::SupportedTensorDtypes::REALHBBF16,
        b,
        utils::SupportedTensorDtypes::REALHBBF16,
        out,
        utils::SupportedTensorDtypes::REALHBBF16);
  });

  return out;
}
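
// Shared implementation of the <op>.Scalar_out comparison kernels: promotes
// the tensor dtype with the scalar, checks dim orders, resizes `out` to match
// `a`, converts the scalar to the compute type once, and applies the
// comparison elementwise. Op files are expected to call it the same way as
// comparison_tensor_out above, e.g. (illustrative)
// internal::comparison_scalar_out<op_name>(ctx, a, b, out) with op_name set
// to "lt.Scalar_out".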
template <const char* op_name>
Tensor& comparison_scalar_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    Tensor& out) {
  // Common Dtype
  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);

  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);

  // Resize
  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);

  // Compute Dtype
  ScalarType compute_type = utils::get_compute_type(common_type);

  ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
        [val_b](const CTYPE_COMPUTE val_a) {
          return ComparisonFnForOp<CTYPE_COMPUTE, op_name>::value(val_a, val_b);
        },
        ctx,
        a,
        utils::SupportedTensorDtypes::REALHBBF16,
        out,
        utils::SupportedTensorDtypes::REALHBBF16);
  });

  return out;
}

} // namespace internal
} // namespace native
} // namespace executor
} // namespace torch