1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/optimized/vec/functional.h>
10 #include <executorch/kernels/optimized/vec/vec.h>
11 #include <executorch/kernels/portable/cpu/scalar_utils.h>
12 #include <executorch/runtime/kernel/kernel_includes.h>
13 #include <executorch/runtime/platform/assert.h>
14
15 namespace torch {
16 namespace executor {
17 namespace native {
18
19 using Tensor = exec_aten::Tensor;
20 using ScalarType = exec_aten::ScalarType;
21
opt_le_tensor_out(KernelRuntimeContext & ctx,const Tensor & a,const Tensor & b,Tensor & out)22 Tensor& opt_le_tensor_out(
23 KernelRuntimeContext& ctx,
24 const Tensor& a,
25 const Tensor& b,
26 Tensor& out) {
27 (void)ctx;
28
29 ET_KERNEL_CHECK(ctx, tensors_have_same_shape(a, b), InvalidArgument, out);
30
31 // Resize for dynamic shape
32 auto error = resize_tensor(out, a.sizes());
33 ET_KERNEL_CHECK_MSG(
34 ctx,
35 error == Error::Ok,
36 InvalidArgument,
37 out,
38 "Failed to resize output tensor.");
39
40 ScalarType a_type = a.scalar_type();
41 ScalarType b_type = b.scalar_type();
42 ScalarType out_type = out.scalar_type();
43
44 if (a_type == b_type && a_type == out_type) {
45 ET_SWITCH_REAL_TYPES_AND(
46 Bool, out_type, ctx, "le.Tensor_out", CTYPE, [&]() {
47 using Vec = executorch::vec::Vectorized<CTYPE>;
48 executorch::vec::map2<CTYPE>(
49 [](Vec x, Vec y) { return x.le(y); },
50 out.mutable_data_ptr<CTYPE>(),
51 a.const_data_ptr<CTYPE>(),
52 b.const_data_ptr<CTYPE>(),
53 a.numel());
54 });
55 } else {
56 ET_SWITCH_REAL_TYPES_AND(
57 Bool, a_type, ctx, "le.Tensor_out", CTYPE_A, [&]() {
58 ET_SWITCH_REAL_TYPES_AND(
59 Bool, b_type, ctx, "le.Tensor_out", CTYPE_B, [&]() {
60 using CTYPE_IN = typename torch::executor::
61 promote_types<CTYPE_A, CTYPE_B>::type;
62 ET_DCHECK(
63 CppTypeToScalarType<CTYPE_IN>::value ==
64 promoteTypes(a_type, b_type));
65 ET_SWITCH_REAL_TYPES_AND(
66 Bool, out_type, ctx, "le.Tensor_out", CTYPE_OUT, [&]() {
67 const size_t n = a.numel();
68 const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
69 const CTYPE_B* b_data = b.const_data_ptr<CTYPE_B>();
70 CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
71 for (auto i = 0; i < n; ++i) {
72 out_data[i] = static_cast<CTYPE_OUT>(
73 static_cast<CTYPE_IN>(a_data[i]) <=
74 static_cast<CTYPE_IN>(b_data[i]));
75 }
76 });
77 });
78 });
79 }
80
81 return out;
82 }
83
opt_le_scalar_out(KernelRuntimeContext & ctx,const Tensor & a,const Scalar & b,Tensor & out)84 Tensor& opt_le_scalar_out(
85 KernelRuntimeContext& ctx,
86 const Tensor& a,
87 const Scalar& b,
88 Tensor& out) {
89 (void)ctx;
90
91 // Resize for dynamic shape
92 auto error = resize_tensor(out, a.sizes());
93 ET_KERNEL_CHECK_MSG(
94 ctx,
95 error == Error::Ok,
96 InvalidArgument,
97 out,
98 "Failed to resize output tensor.");
99
100 ScalarType a_type = a.scalar_type();
101 ScalarType b_type = utils::get_scalar_dtype(b);
102 ScalarType common_type = promoteTypes(a_type, b_type);
103 ScalarType out_type = out.scalar_type();
104
105 if (a_type == common_type && a_type == out_type) {
106 ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "le.Scalar_out", CTYPE, [&]() {
107 ET_SWITCH_REAL_TYPES_AND(
108 Bool, b_type, ctx, "le.Scalar_out", CTYPE_B, [&]() {
109 CTYPE_B b_val = 0;
110 ET_EXTRACT_SCALAR(b, b_val);
111 CTYPE b_casted = static_cast<CTYPE>(b_val);
112 using Vec = executorch::vec::Vectorized<CTYPE>;
113 executorch::vec::map<CTYPE>(
114 [b_casted](Vec x) { return x.le(Vec(b_casted)); },
115 out.mutable_data_ptr<CTYPE>(),
116 a.const_data_ptr<CTYPE>(),
117 a.numel());
118 });
119 });
120 } else {
121 ET_SWITCH_REAL_TYPES_AND(
122 Bool, a_type, ctx, "le.Scalar_out", CTYPE_A, [&]() {
123 ET_SWITCH_REAL_TYPES_AND(
124 Bool, b_type, ctx, "le.Scalar_out", CTYPE_B, [&]() {
125 ET_SWITCH_REAL_TYPES_AND(
126 Bool, common_type, ctx, "le.Scalar_out", CTYPE_IN, [&]() {
127 ET_SWITCH_REAL_TYPES_AND(
128 Bool,
129 out_type,
130 ctx,
131 "le.Scalar_out",
132 CTYPE_OUT,
133 [&]() {
134 CTYPE_B b_val = 0;
135 ET_EXTRACT_SCALAR(b, b_val);
136 CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
137 const size_t n = a.numel();
138 const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
139 CTYPE_OUT* out_data =
140 out.mutable_data_ptr<CTYPE_OUT>();
141 for (auto i = 0; i < n; ++i) {
142 out_data[i] = static_cast<CTYPE_OUT>(
143 static_cast<CTYPE_IN>(a_data[i]) <= b_casted);
144 }
145 });
146 });
147 });
148 });
149 }
150
151 return out;
152 }
153
154 } // namespace native
155 } // namespace executor
156 } // namespace torch
157