/*
 * Copyright (c) 2021-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24 #ifndef SRC_CORE_SVE_KERNELS_ELEMENTWISE_LIST_H
25 #define SRC_CORE_SVE_KERNELS_ELEMENTWISE_LIST_H
26
27 #include "arm_compute/core/Helpers.h"
28 #include "src/core/NEON/wrapper/intrinsics/intrinsics.h"
29 #include "src/core/NEON/wrapper/svtraits.h"
30
31 namespace arm_compute
32 {
33 namespace cpu
34 {
35 using namespace arm_compute::wrapper;
36
37 template <typename VectorType>
elementwise_pow(svbool_t & pg,const VectorType & a,const VectorType & b)38 VectorType elementwise_pow(svbool_t &pg, const VectorType &a, const VectorType &b)
39 {
40 return svpow_z(pg, a, b);
41 }
42
43 template <typename VectorType>
elementwise_div(svbool_t & pg,const VectorType & a,const VectorType & b)44 VectorType elementwise_div(svbool_t &pg, const VectorType &a, const VectorType &b)
45 {
46 return svdiv_z(pg, a, b);
47 }
48
49 template <uint32_t bytewidth>
narrow_to_byte_predicate(svbool_t pg)50 svbool_t narrow_to_byte_predicate(svbool_t pg)
51 {
52 const auto all_false = svpfalse();
53
54 switch(bytewidth)
55 {
56 case 8:
57 pg = svuzp1_b32(pg, all_false);
58 /* fall through */
59 case 4:
60 pg = svuzp1_b16(pg, all_false);
61 /* fall through */
62 case 2:
63 pg = svuzp1_b8(pg, all_false);
64 /* fall through */
65 default:
66 break;
67 }
68 return pg;
69 }
70
71 template <typename VectorType>
elementwise_arithmetic_op(svbool_t & pg,const VectorType & a,const VectorType & b,ArithmeticOperation op)72 VectorType elementwise_arithmetic_op(svbool_t &pg, const VectorType &a, const VectorType &b, ArithmeticOperation op)
73 {
74 using ScalarType = typename wrapper::sve_scalar<VectorType>::type;
75 VectorType res{};
76
77 switch(op)
78 {
79 case ArithmeticOperation::MAX:
80 res = svmax_z(pg, a, b);
81 break;
82 case ArithmeticOperation::MIN:
83 res = svmin_z(pg, a, b);
84 break;
85 case ArithmeticOperation::SQUARED_DIFF:
86 {
87 const auto tmp = svsub_z(pg, a, b);
88 res = svmul_z(pg, tmp, tmp);
89 break;
90 }
91 case ArithmeticOperation::PRELU:
92 {
93 const auto zero = svdup_n(ScalarType(0));
94 const auto tmp = svmul_z(pg, a, b);
95 const auto gt = svcmpgt(pg, a, zero);
96 res = svsel(gt, a, tmp);
97 break;
98 }
99 case ArithmeticOperation::DIV:
100 {
101 res = elementwise_div(pg, a, b);
102 break;
103 }
104 case ArithmeticOperation::POWER:
105 {
106 res = elementwise_pow(pg, a, b);
107 break;
108 }
109 default:
110 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
111 }
112
113 return res;
114 }
115
116 template <typename InputVectorType, typename OutputVectorType>
elementwise_comparison_op(svbool_t & pg,const InputVectorType & a,const InputVectorType & b,ComparisonOperation op)117 OutputVectorType elementwise_comparison_op(svbool_t &pg, const InputVectorType &a, const InputVectorType &b, ComparisonOperation op)
118 {
119 svbool_t selection_vector{};
120
121 switch(op)
122 {
123 case ComparisonOperation::Equal:
124 selection_vector = svcmpeq(pg, a, b);
125 break;
126 case ComparisonOperation::NotEqual:
127 selection_vector = svcmpne(pg, a, b);
128 break;
129 case ComparisonOperation::Greater:
130 selection_vector = svcmpgt(pg, a, b);
131 break;
132 case ComparisonOperation::GreaterEqual:
133 selection_vector = svcmpge(pg, a, b);
134 break;
135 case ComparisonOperation::Less:
136 selection_vector = svcmplt(pg, a, b);
137 break;
138 case ComparisonOperation::LessEqual:
139 selection_vector = svcmple(pg, a, b);
140 break;
141 default:
142 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
143 }
144
145 using InputScalarType = typename wrapper::sve_scalar<InputVectorType>::type;
146 selection_vector = narrow_to_byte_predicate<sizeof(InputScalarType)>(selection_vector);
147
148 using OutputScalarType = typename wrapper::sve_scalar<OutputVectorType>::type;
149 const auto false_vector = svdup_n(static_cast<OutputScalarType>((uint32_t)0));
150 const auto true_vector = svdup_n(static_cast<OutputScalarType>(~(uint32_t)0));
151 auto ret = svsel(selection_vector, true_vector, false_vector);
152
153 return ret;
154 }
155
/** Tensor-level element-wise arithmetic routine (declaration only; the definition is not visible in this header).
 *
 * @tparam ScalarType Element type the routine is instantiated for.
 *
 * @param[in]  in1    First input tensor (read-only through this pointer).
 * @param[in]  in2    Second input tensor (read-only through this pointer).
 * @param[out] out    Destination tensor - presumably receives the result; confirm against the definition.
 * @param[in]  op     Arithmetic operation to apply.
 * @param[in]  window Execution window - presumably the region to process; confirm against the definition.
 */
156 template <typename ScalarType>
157 void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, ArithmeticOperation op, const Window &window);
158
/** Tensor-level element-wise comparison routine (declaration only; the definition is not visible in this header).
 *
 * @tparam ScalarType       Element type of the inputs.
 * @tparam OutputScalarType Element type of the output; defaults to uint8_t.
 *
 * @param[in]  in1    First input tensor (read-only through this pointer).
 * @param[in]  in2    Second input tensor (read-only through this pointer).
 * @param[out] out    Destination tensor - presumably receives the comparison result; confirm against the definition.
 * @param[in]  op     Comparison operation to apply.
 * @param[in]  window Execution window - presumably the region to process; confirm against the definition.
 */
159 template <typename ScalarType, typename OutputScalarType = uint8_t>
160 void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, ComparisonOperation op, const Window &window);
161 } // namespace cpu
162 } // namespace arm_compute
163 #endif /* SRC_CORE_SVE_KERNELS_ELEMENTWISE_LIST_H */
164