1 /*
2 * Copyright (c) 2021-2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24 #ifndef SRC_CORE_NEON_KERNELS_ELEMENTWISE_IMPL_H
25 #define SRC_CORE_NEON_KERNELS_ELEMENTWISE_IMPL_H
26
27 #include "src/core/NEON/NEAsymm.h"
28
29 namespace arm_compute
30 {
31 namespace cpu
32 {
33 template <ArithmeticOperation op, typename VectorType>
elementwise_arithm_op(const typename VectorType::type & a,const typename VectorType::type & b)34 typename VectorType::type elementwise_arithm_op(const typename VectorType::type &a, const typename VectorType::type &b)
35 {
36 using vec_type = typename VectorType::type;
37 using scalar_type = typename VectorType::scalar_type;
38 using tag_type = typename VectorType::tag_type;
39
40 vec_type res = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});
41
42 switch(op)
43 {
44 case ArithmeticOperation::MAX:
45 res = wrapper::vmax(a, b);
46 break;
47 case ArithmeticOperation::MIN:
48 res = wrapper::vmin(a, b);
49 break;
50 case ArithmeticOperation::SQUARED_DIFF:
51 {
52 const vec_type tmp = wrapper::vsub(a, b);
53 res = wrapper::vmul(tmp, tmp);
54 break;
55 }
56 case ArithmeticOperation::PRELU:
57 {
58 const vec_type zero = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});
59 const vec_type tmp = wrapper::vmul(a, b);
60 const auto gt = wrapper::vcgt(a, zero);
61
62 res = wrapper::vbsl(gt, a, tmp);
63 break;
64 }
65
66 default:
67 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
68 }
69
70 return res;
71 }
72
73 template <ArithmeticOperation op, typename ScalarType, typename VectorType>
elementwise_arithm_op_broadcast(const typename VectorType::type & a,const ScalarType & broadcast_value,const bool reorder)74 typename VectorType::type elementwise_arithm_op_broadcast(const typename VectorType::type &a, const ScalarType &broadcast_value, const bool reorder)
75 {
76 using tag_type = typename VectorType::tag_type;
77 using vec_type = typename VectorType::type;
78
79 vec_type broadcast_vector = wrapper::vdup_n(broadcast_value, tag_type{});
80 return elementwise_arithm_op<op, VectorType>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
81 }
82
83 template <typename InputScalarType, typename OutputScalarType, typename InputVectorType>
elementwise_op(const ITensor * in1,const ITensor * in2,ITensor * out,const Window & window,OutputScalarType (* scalar_func)(const InputScalarType &,const InputScalarType &),int (* broadcast_func)(int,int,int,const InputScalarType *,const InputScalarType &,OutputScalarType *,const bool),int (* neon_func)(int,int,int,const InputScalarType *,const InputScalarType *,OutputScalarType *))84 void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
85 OutputScalarType (*scalar_func)(const InputScalarType &, const InputScalarType &),
86 int (*broadcast_func)(int, int, int, const InputScalarType *, const InputScalarType &, OutputScalarType *, const bool),
87 int (*neon_func)(int, int, int, const InputScalarType *, const InputScalarType *, OutputScalarType *))
88 {
89 // Create input windows
90 Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
91 Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
92
93 // Clear X Dimension on execution window as we handle manually
94 Window win = window;
95 win.set(Window::DimX, Window::Dimension(0, 1, 1));
96
97 const int window_step_x = std::min(16 / static_cast<int>(sizeof(OutputScalarType)), 8);
98 const auto window_start_x = static_cast<int>(window.x().start());
99 const auto window_end_x = static_cast<int>(window.x().end());
100 const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();
101
102 if(is_broadcast_across_x)
103 {
104 const bool is_broadcast_input_2 = input2_win.x().step() == 0;
105 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
106 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
107 const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
108 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
109
110 // Clear X Dimension on execution window as we handle manually
111 non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
112
113 Iterator broadcast_input(broadcast_tensor, broadcast_win);
114 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
115 Iterator output(out, win);
116
117 execute_window_loop(win, [&](const Coordinates &)
118 {
119 auto output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
120 const auto non_broadcast_input_ptr = reinterpret_cast<const InputScalarType *>(non_broadcast_input.ptr());
121 const InputScalarType broadcast_value = *reinterpret_cast<const InputScalarType *>(broadcast_input.ptr());
122
123 int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_value, output_ptr, !is_broadcast_input_2);
124 for(; x < window_end_x; ++x)
125 {
126 const auto a = *(non_broadcast_input_ptr + x);
127 *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? broadcast_value : a, !is_broadcast_input_2 ? a : broadcast_value);
128 }
129 },
130 broadcast_input, non_broadcast_input, output);
131 }
132 else
133 {
134 // Clear X Dimension on execution window as we handle manually
135 input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
136 input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
137
138 Iterator input1(in1, input1_win);
139 Iterator input2(in2, input2_win);
140 Iterator output(out, win);
141
142 execute_window_loop(win, [&](const Coordinates &)
143 {
144 auto output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
145 const auto input1_ptr = reinterpret_cast<const InputScalarType *>(input1.ptr());
146 const auto input2_ptr = reinterpret_cast<const InputScalarType *>(input2.ptr());
147
148 int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr);
149 for(; x < window_end_x; ++x)
150 {
151 const auto a = *(input1_ptr + x);
152 const auto b = *(input2_ptr + x);
153 *(output_ptr + x) = (*scalar_func)(a, b);
154 }
155 },
156 input1, input2, output);
157 }
158 }
159
160 template <ArithmeticOperation op, typename ScalarType>
elementwise_arithm_op_scalar(const ScalarType & a,const ScalarType & b)161 inline ScalarType elementwise_arithm_op_scalar(const ScalarType &a, const ScalarType &b)
162 {
163 auto res = ScalarType(0);
164
165 switch(op)
166 {
167 case ArithmeticOperation::MAX:
168 res = std::max(a, b);
169 break;
170 case ArithmeticOperation::MIN:
171 res = std::min(a, b);
172 break;
173 case ArithmeticOperation::SQUARED_DIFF:
174 {
175 res = (a - b) * (a - b);
176 break;
177 }
178 case ArithmeticOperation::PRELU:
179 {
180 res = (a > 0 ? a : a * b);
181 break;
182 }
183 case ArithmeticOperation::DIV:
184 {
185 res = a / b;
186 if(std::is_integral<ScalarType>::value)
187 {
188 res = (b == 0) ? 0 : res;
189 if(static_cast<int32_t>(a) % static_cast<int32_t>(b) != 0 && ((a < 0) != (b < 0)))
190 {
191 --res;
192 }
193 }
194 break;
195 }
196 case ArithmeticOperation::POWER:
197 {
198 res = std::pow(a, b);
199 break;
200 }
201 default:
202 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
203 }
204 return res;
205 }
206
207 template <>
208 inline int32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<int32_t, 4>>(const int32x4_t &a, const int32x4_t &b)
209 {
210 return vcvtq_s32_f32(vfloorq_f32(wrapper::vdiv(vcvtq_f32_s32(a), vcvtq_f32_s32(b))));
211 }
212
213 template <>
214 inline float32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
215 {
216 return wrapper::vdiv(a, b);
217 }
218
219 template <>
220 inline float32x4_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
221 {
222 return wrapper::vpow(a, b);
223 }
224
225 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
226 template <>
227 inline float16x8_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float16_t, 8>>(const float16x8_t &a, const float16x8_t &b)
228 {
229 return wrapper::vdiv(a, b);
230 }
231
232 template <>
233 inline float16x8_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float16_t, 8>>(const float16x8_t &a, const float16x8_t &b)
234 {
235 return wrapper::vpow(a, b);
236 }
237 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
238
239 template <ArithmeticOperation op, typename ScalarType, typename VectorType>
elementwise_arithm_op_loop(int window_start_x,int window_end_x,int window_step_x,const ScalarType * input1_ptr,const ScalarType * input2_ptr,ScalarType * output_ptr)240 inline int elementwise_arithm_op_loop(int window_start_x, int window_end_x, int window_step_x,
241 const ScalarType *input1_ptr, const ScalarType *input2_ptr, ScalarType *output_ptr)
242 {
243 int x = window_start_x;
244 for(; x <= (window_end_x - window_step_x); x += window_step_x)
245 {
246 const auto a = wrapper::vloadq(input1_ptr + x);
247 const auto b = wrapper::vloadq(input2_ptr + x);
248 wrapper::vstore(output_ptr + x, elementwise_arithm_op<op, VectorType>(a, b));
249 }
250 return x;
251 }
252
253 template <ArithmeticOperation op, typename ScalarType, typename VectorType>
elementwise_arithm_op_broadcast_loop(int window_start_x,int window_end_x,int window_step_x,const ScalarType * non_broadcast_input_ptr,const ScalarType & broadcast_value,ScalarType * output_ptr,const bool reorder)254 inline int elementwise_arithm_op_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
255 const ScalarType *non_broadcast_input_ptr, const ScalarType &broadcast_value, ScalarType *output_ptr, const bool reorder)
256 {
257 int x = window_start_x;
258 for(; x <= (window_end_x - window_step_x); x += window_step_x)
259 {
260 const auto a = wrapper::vloadq((non_broadcast_input_ptr + x));
261 wrapper::vstore(output_ptr + x, elementwise_arithm_op_broadcast<op, ScalarType, VectorType>(a, broadcast_value, reorder));
262 }
263 return x;
264 }
265
266 template <ArithmeticOperation op, typename VectorType>
elementwise_arithm_op(const ITensor * in1,const ITensor * in2,ITensor * out,const Window & window)267 void elementwise_arithm_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
268 {
269 using scalar_type = typename VectorType::scalar_type;
270
271 elementwise_op<scalar_type, scalar_type, VectorType>(in1, in2, out, window,
272 &elementwise_arithm_op_scalar<op, scalar_type>,
273 &elementwise_arithm_op_broadcast_loop<op, scalar_type, VectorType>,
274 &elementwise_arithm_op_loop<op, scalar_type, VectorType>);
275 }
276
277 template <ComparisonOperation op, typename InputScalarType>
elementwise_comp_op_scalar(const InputScalarType & a,const InputScalarType & b)278 inline uint8_t elementwise_comp_op_scalar(const InputScalarType &a, const InputScalarType &b)
279 {
280 bool res = false;
281
282 switch(op)
283 {
284 case ComparisonOperation::Equal:
285 res = (a == b);
286 break;
287 case ComparisonOperation::NotEqual:
288 res = (a != b);
289 break;
290 case ComparisonOperation::Greater:
291 res = (a > b);
292 break;
293 case ComparisonOperation::GreaterEqual:
294 res = (a >= b);
295 break;
296 case ComparisonOperation::Less:
297 res = (a < b);
298 break;
299 case ComparisonOperation::LessEqual:
300 res = (a <= b);
301 break;
302 default:
303 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
304 }
305 return res ? ~static_cast<uint8_t>(0) : static_cast<uint8_t>(0);
306 }
307
308 template <ComparisonOperation op, typename InputVectorType, typename OutputVectorType>
elementwise_comp_op(const InputVectorType & a,const InputVectorType & b)309 inline OutputVectorType elementwise_comp_op(const InputVectorType &a, const InputVectorType &b)
310 {
311 OutputVectorType res = { 0, 0, 0, 0 };
312
313 switch(op)
314 {
315 case ComparisonOperation::Equal:
316 res = wrapper::vceq(a, b);
317 break;
318 case ComparisonOperation::NotEqual:
319 res = wrapper::vnot(wrapper::vceq(a, b));
320 break;
321 case ComparisonOperation::Greater:
322 res = wrapper::vcgt(a, b);
323 break;
324 case ComparisonOperation::GreaterEqual:
325 res = wrapper::vcge(a, b);
326 break;
327 case ComparisonOperation::Less:
328 res = wrapper::vcgt(b, a);
329 break;
330 case ComparisonOperation::LessEqual:
331 res = wrapper::vcge(b, a);
332 break;
333 default:
334 ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
335 }
336
337 return res;
338 }
339
340 template <ComparisonOperation op, typename InputScalarType, typename InputVectorType, typename OutputVectorType>
elementwise_comp_op_broadcast(const InputVectorType & a,const InputScalarType & broadcast_value,const bool reorder)341 inline OutputVectorType elementwise_comp_op_broadcast(const InputVectorType &a, const InputScalarType &broadcast_value, const bool reorder)
342 {
343 InputVectorType broadcast_vector = wrapper::vdup_n(broadcast_value, wrapper::traits::vector_128_tag());
344 return elementwise_comp_op<op, InputVectorType, OutputVectorType>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
345 }
346
347 template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
elementwise_comp_op_broadcast_8_loop(int window_start_x,int window_end_x,int window_step_x,const InputScalarType * non_broadcast_input_ptr,const InputScalarType & broadcast_value,uint8_t * output_ptr,const bool reorder)348 inline int elementwise_comp_op_broadcast_8_loop(int window_start_x, int window_end_x, int window_step_x,
349 const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
350 {
351 int x = window_start_x;
352 for(; x <= (window_end_x - window_step_x); x += window_step_x)
353 {
354 const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint8x16_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
355 wrapper::vstore(output_ptr + x, a);
356 }
357 return x;
358 }
359
360 template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
elementwise_comp_op_broadcast_16_loop(int window_start_x,int window_end_x,int window_step_x,const InputScalarType * non_broadcast_input_ptr,const InputScalarType & broadcast_value,uint8_t * output_ptr,const bool reorder)361 inline int elementwise_comp_op_broadcast_16_loop(int window_start_x, int window_end_x, int window_step_x,
362 const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
363 {
364 int x = window_start_x;
365 for(; x <= (window_end_x - window_step_x); x += window_step_x)
366 {
367 const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint16x8_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
368 wrapper::vstore(output_ptr + x, wrapper::vmovn(a));
369 }
370 return x;
371 }
372
373 template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
elementwise_comp_op_broadcast_32_loop(int window_start_x,int window_end_x,int window_step_x,const InputScalarType * non_broadcast_input_ptr,const InputScalarType & broadcast_value,uint8_t * output_ptr,const bool reorder)374 inline int elementwise_comp_op_broadcast_32_loop(int window_start_x, int window_end_x, int window_step_x,
375 const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
376 {
377 int x = window_start_x;
378 for(; x <= (window_end_x - window_step_x); x += window_step_x)
379 {
380 const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq(non_broadcast_input_ptr + x), broadcast_value, reorder);
381 const auto b = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq(non_broadcast_input_ptr + x + 4), broadcast_value, reorder);
382 wrapper::vstore(output_ptr + x, wrapper::vmovn(wrapper::vcombine(wrapper::vmovn(a), wrapper::vmovn(b))));
383 }
384 if(x <= window_end_x - 4)
385 {
386 const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
387 for(int i = 0; i < 4; i++)
388 {
389 *(output_ptr + x + i) = wrapper::vgetlane(a, i);
390 }
391 x = +4;
392 }
393 return x;
394 }
395
396 template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
elementwise_comp_op_8_loop(int window_start_x,int window_end_x,int window_step_x,const InputScalarType * input1_ptr,const InputScalarType * input2_ptr,uint8_t * output_ptr)397 inline int elementwise_comp_op_8_loop(int window_start_x, int window_end_x, int window_step_x,
398 const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
399 {
400 int x = window_start_x;
401 for(; x <= (window_end_x - window_step_x); x += window_step_x)
402 {
403 const auto a = wrapper::vloadq(input1_ptr + x);
404 const auto b = wrapper::vloadq(input2_ptr + x);
405 const auto res = elementwise_comp_op<op, InputVectorType, uint8x16_t>(a, b);
406 wrapper::vstore(output_ptr + x, res);
407 }
408 return x;
409 }
410
411 template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
elementwise_comp_op_16_loop(int window_start_x,int window_end_x,int window_step_x,const InputScalarType * input1_ptr,const InputScalarType * input2_ptr,uint8_t * output_ptr)412 inline int elementwise_comp_op_16_loop(int window_start_x, int window_end_x, int window_step_x,
413 const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
414 {
415 int x = window_start_x;
416 for(; x <= (window_end_x - window_step_x); x += window_step_x)
417 {
418 const auto a = wrapper::vloadq(input1_ptr + x);
419 const auto b = wrapper::vloadq(input2_ptr + x);
420 const auto res = elementwise_comp_op<op, InputVectorType, uint16x8_t>(a, b);
421 wrapper::vstore(output_ptr + x, wrapper::vmovn(res));
422 }
423 return x;
424 }
425
426 template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
elementwise_comp_op_32_loop(int window_start_x,int window_end_x,int window_step_x,const InputScalarType * input1_ptr,const InputScalarType * input2_ptr,uint8_t * output_ptr)427 inline int elementwise_comp_op_32_loop(int window_start_x, int window_end_x, int window_step_x,
428 const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
429 {
430 int x = window_start_x;
431 for(; x <= (window_end_x - window_step_x); x += window_step_x)
432 {
433 auto a = wrapper::vloadq(input1_ptr + x);
434 auto b = wrapper::vloadq(input2_ptr + x);
435 const auto res = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
436 a = wrapper::vloadq(input1_ptr + x + 4);
437 b = wrapper::vloadq(input2_ptr + x + 4);
438 const auto res2 = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
439 wrapper::vstore(output_ptr + x, wrapper::vmovn(wrapper::vcombine(wrapper::vmovn(res), wrapper::vmovn(res2))));
440 }
441 if(x <= window_end_x - 4)
442 {
443 const auto a = wrapper::vloadq(input1_ptr + x);
444 const auto b = wrapper::vloadq(input2_ptr + x);
445 const auto res = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
446 for(int i = 0; i < 4; i++)
447 {
448 *(output_ptr + x + i) = wrapper::vgetlane(res, i);
449 }
450 x = +4;
451 }
452 return x;
453 }
454
455 template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
elementwise_comp_op_8(const ITensor * in1,const ITensor * in2,ITensor * out,const Window & window)456 void elementwise_comp_op_8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
457 {
458 elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
459 &elementwise_comp_op_scalar<op, InputScalarType>,
460 &elementwise_comp_op_broadcast_8_loop<op, InputScalarType, InputVectorType>,
461 &elementwise_comp_op_8_loop<op, InputScalarType, InputVectorType>);
462 }
463
464 template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
elementwise_comp_op_16(const ITensor * in1,const ITensor * in2,ITensor * out,const Window & window)465 void elementwise_comp_op_16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
466 {
467 elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
468 &elementwise_comp_op_scalar<op, InputScalarType>,
469 &elementwise_comp_op_broadcast_16_loop<op, InputScalarType, InputVectorType>,
470 &elementwise_comp_op_16_loop<op, InputScalarType, InputVectorType>);
471 }
472
473 template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
elementwise_comp_op_32(const ITensor * in1,const ITensor * in2,ITensor * out,const Window & window)474 void elementwise_comp_op_32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
475 {
476 elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
477 &elementwise_comp_op_scalar<op, InputScalarType>,
478 &elementwise_comp_op_broadcast_32_loop<op, InputScalarType, InputVectorType>,
479 &elementwise_comp_op_32_loop<op, InputScalarType, InputVectorType>);
480 }
481
load_quantized(const uint8_t * input1_ptr,const int32x4_t & offset,const float32x4_t & scale)482 inline float32x4x4_t load_quantized(const uint8_t *input1_ptr, const int32x4_t &offset, const float32x4_t &scale)
483 {
484 qasymm8x16_t x = vld1q_u8(input1_ptr);
485 const float32x4x4_t out =
486 {
487 {
488 vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(x))))), offset)), scale),
489 vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(x))))), offset)), scale),
490 vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(x))))), offset)), scale),
491 vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(x))))), offset)), scale),
492 }
493 };
494 return out;
495 }
496
load_quantized_signed(const int8_t * input1_ptr,const int32x4_t & offset,const float32x4_t & scale)497 inline float32x4x4_t load_quantized_signed(const int8_t *input1_ptr, const int32x4_t &offset, const float32x4_t &scale)
498 {
499 qasymm8x16_signed_t x = vld1q_s8(input1_ptr);
500 const float32x4x4_t out =
501 {
502 {
503 vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_low_s8(x)))), offset)), scale),
504 vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_low_s8(x)))), offset)), scale),
505 vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_high_s8(x)))), offset)), scale),
506 vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_high_s8(x)))), offset)), scale),
507 }
508 };
509 return out;
510 }
511
store_quantized(uint8_t * output_ptr,const uint32x4x4_t & out)512 inline void store_quantized(uint8_t *output_ptr, const uint32x4x4_t &out)
513 {
514 const uint8x8_t pa = vqmovn_u16(vcombine_u16(vqmovn_u32(out.val[0]), vqmovn_u32(out.val[1])));
515 const uint8x8_t pb = vqmovn_u16(vcombine_u16(vqmovn_u32(out.val[2]), vqmovn_u32(out.val[3])));
516 vst1q_u8(output_ptr, vcombine_u8(pa, pb));
517 }
518
store_quantized(uint8_t * output_ptr,const int32x4x4_t & out)519 inline void store_quantized(uint8_t *output_ptr, const int32x4x4_t &out)
520 {
521 const uint8x8_t pa = vqmovun_s16(vcombine_s16(vqmovn_s32(out.val[0]), vqmovn_s32(out.val[1])));
522 const uint8x8_t pb = vqmovun_s16(vcombine_s16(vqmovn_s32(out.val[2]), vqmovn_s32(out.val[3])));
523 vst1q_u8(output_ptr, vcombine_u8(pa, pb));
524 }
525
store_quantized(uint8_t * output_ptr,const float32x4x4_t & rf,const float32x4_t & offset,const float32x4_t & invscale)526 inline void store_quantized(uint8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale)
527 {
528 int32x4x4_t out =
529 {
530 {
531 vcvtq_s32_f32(vmlaq_f32(offset, rf.val[0], invscale)),
532 vcvtq_s32_f32(vmlaq_f32(offset, rf.val[1], invscale)),
533 vcvtq_s32_f32(vmlaq_f32(offset, rf.val[2], invscale)),
534 vcvtq_s32_f32(vmlaq_f32(offset, rf.val[3], invscale)),
535 }
536 };
537 store_quantized(output_ptr, out);
538 }
539
store_quantized_signed(int8_t * output_ptr,const int32x4x4_t & out)540 inline void store_quantized_signed(int8_t *output_ptr, const int32x4x4_t &out)
541 {
542 const int8x8_t pa = vqmovn_s16(vcombine_s16(vqmovn_s32(out.val[0]), vqmovn_s32(out.val[1])));
543 const int8x8_t pb = vqmovn_s16(vcombine_s16(vqmovn_s32(out.val[2]), vqmovn_s32(out.val[3])));
544 vst1q_s8(output_ptr, vcombine_s8(pa, pb));
545 }
546
store_quantized_signed(int8_t * output_ptr,const float32x4x4_t & rf,const float32x4_t & offset,const float32x4_t & invscale)547 inline void store_quantized_signed(int8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale)
548 {
549 int32x4x4_t out =
550 {
551 {
552 vcvtq_s32_f32(vmlaq_f32(offset, rf.val[0], invscale)),
553 vcvtq_s32_f32(vmlaq_f32(offset, rf.val[1], invscale)),
554 vcvtq_s32_f32(vmlaq_f32(offset, rf.val[2], invscale)),
555 vcvtq_s32_f32(vmlaq_f32(offset, rf.val[3], invscale)),
556 }
557 };
558 store_quantized_signed(output_ptr, out);
559 }
560
561 template <ArithmeticOperation op>
elementwise_arithm_op_quantized_scalar(const float & a,const float & b,UniformQuantizationInfo qinfo)562 inline uint8_t elementwise_arithm_op_quantized_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
563 {
564 return quantize_qasymm8(elementwise_arithm_op_scalar<op>(a, b), qinfo);
565 }
566
567 template <ArithmeticOperation op>
elementwise_arithm_op_quantized_signed_scalar(const float & a,const float & b,UniformQuantizationInfo qinfo)568 inline int8_t elementwise_arithm_op_quantized_signed_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
569 {
570 return quantize_qasymm8_signed(elementwise_arithm_op_scalar<op>(a, b), qinfo);
571 }
572
573 template <ArithmeticOperation op>
elementwise_arithm_op(const float32x4x4_t & a,const float32x4x4_t & b)574 float32x4x4_t elementwise_arithm_op(const float32x4x4_t &a, const float32x4x4_t &b)
575 {
576 using neon_vector_float = wrapper::traits::neon_vector<float, 4>;
577 float32x4x4_t out =
578 {
579 {
580 elementwise_arithm_op<op, neon_vector_float>(a.val[0], b.val[0]),
581 elementwise_arithm_op<op, neon_vector_float>(a.val[1], b.val[1]),
582 elementwise_arithm_op<op, neon_vector_float>(a.val[2], b.val[2]),
583 elementwise_arithm_op<op, neon_vector_float>(a.val[3], b.val[3]),
584 }
585 };
586 return out;
587 }
588
589 template <ComparisonOperation op>
elementwise_comp_op_quantized_scalar(const float & a,const float & b,UniformQuantizationInfo qinfo)590 inline uint8_t elementwise_comp_op_quantized_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
591 {
592 ARM_COMPUTE_UNUSED(qinfo);
593 return elementwise_comp_op_scalar<op>(a, b);
594 }
595
596 template <ComparisonOperation op>
elementwise_comp_op(const float32x4x4_t & a,const float32x4x4_t & b)597 inline uint32x4x4_t elementwise_comp_op(const float32x4x4_t &a, const float32x4x4_t &b)
598 {
599 uint32x4x4_t out =
600 {
601 {
602 elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[0], b.val[0]),
603 elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[1], b.val[1]),
604 elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[2], b.val[2]),
605 elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[3], b.val[3])
606 }
607 };
608 return out;
609 }
610
611 template <ArithmeticOperation op>
elementwise_arithm_op_quantized_loop(int window_start_x,int window_end_x,int window_step_x,const uint8_t * input1_ptr,const uint8_t * input2_ptr,uint8_t * output_ptr,int32x4_t voffset1,int32x4_t voffset2,float32x4_t vscale1,float32x4_t vscale2,float32x4_t voffseto,float32x4_t invvscaleo)612 inline int elementwise_arithm_op_quantized_loop(int window_start_x, int window_end_x, int window_step_x,
613 const uint8_t *input1_ptr, const uint8_t *input2_ptr, uint8_t *output_ptr,
614 int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
615 float32x4_t voffseto, float32x4_t invvscaleo)
616 {
617 int x = window_start_x;
618 for(; x <= (window_end_x - window_step_x); x += window_step_x)
619 {
620 // Get inputs and compute output
621 const float32x4x4_t af = load_quantized(input1_ptr + x, voffset1, vscale1);
622 const float32x4x4_t bf = load_quantized(input2_ptr + x, voffset2, vscale2);
623 const float32x4x4_t rf = elementwise_arithm_op<op>(af, bf);
624 store_quantized(output_ptr + x, rf, voffseto, invvscaleo);
625 }
626 return x;
627 }
628
629 template <ArithmeticOperation op>
elementwise_arithm_op_quantized_singed_loop(int window_start_x,int window_end_x,int window_step_x,const int8_t * input1_ptr,const int8_t * input2_ptr,int8_t * output_ptr,int32x4_t voffset1,int32x4_t voffset2,float32x4_t vscale1,float32x4_t vscale2,float32x4_t voffseto,float32x4_t invvscaleo)630 inline int elementwise_arithm_op_quantized_singed_loop(int window_start_x, int window_end_x, int window_step_x,
631 const int8_t *input1_ptr, const int8_t *input2_ptr, int8_t *output_ptr,
632 int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
633 float32x4_t voffseto, float32x4_t invvscaleo)
634 {
635 int x = window_start_x;
636 for(; x <= (window_end_x - window_step_x); x += window_step_x)
637 {
638 // Get inputs and compute output
639 const float32x4x4_t af = load_quantized_signed(input1_ptr + x, voffset1, vscale1);
640 const float32x4x4_t bf = load_quantized_signed(input2_ptr + x, voffset2, vscale2);
641 const float32x4x4_t rf = elementwise_arithm_op<op>(af, bf);
642 store_quantized_signed(output_ptr + x, rf, voffseto, invvscaleo);
643 }
644 return x;
645 }
646
647 template <ArithmeticOperation op>
elementwise_arithm_op_quantized_broadcast_loop(int window_start_x,int window_end_x,int window_step_x,const uint8_t * non_broadcast_input_ptr,float32x4x4_t broadcast_vector,uint8_t * output_ptr,int32x4_t voffset_non_broadcast,float32x4_t vscale_non_broadcast,float32x4_t voffseto,float32x4_t invvscaleo,bool reorder)648 inline int elementwise_arithm_op_quantized_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
649 const uint8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
650 int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
651 float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
652 {
653 int x = window_start_x;
654 for(; x <= (window_end_x - window_step_x); x += window_step_x)
655 {
656 const float32x4x4_t af = load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
657 const float32x4x4_t rf = elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
658 store_quantized(output_ptr + x, rf, voffseto, invvscaleo);
659 }
660 return x;
661 }
662 template <ArithmeticOperation op>
elementwise_arithm_op_quantized_signed_broadcast_loop(int window_start_x,int window_end_x,int window_step_x,const int8_t * non_broadcast_input_ptr,float32x4x4_t broadcast_vector,int8_t * output_ptr,int32x4_t voffset_non_broadcast,float32x4_t vscale_non_broadcast,float32x4_t voffseto,float32x4_t invvscaleo,bool reorder)663 inline int elementwise_arithm_op_quantized_signed_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
664 const int8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, int8_t *output_ptr,
665 int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
666 float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
667 {
668 int x = window_start_x;
669 for(; x <= (window_end_x - window_step_x); x += window_step_x)
670 {
671 const float32x4x4_t af = load_quantized_signed(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
672 const float32x4x4_t rf = elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
673 store_quantized_signed(output_ptr + x, rf, voffseto, invvscaleo);
674 }
675 return x;
676 }
677
678 template <ComparisonOperation op>
elementwise_comp_op_quantized_loop(int window_start_x,int window_end_x,int window_step_x,const uint8_t * input1_ptr,const uint8_t * input2_ptr,uint8_t * output_ptr,int32x4_t voffset1,int32x4_t voffset2,float32x4_t vscale1,float32x4_t vscale2,float32x4_t voffseto,float32x4_t invvscaleo)679 inline int elementwise_comp_op_quantized_loop(int window_start_x, int window_end_x, int window_step_x,
680 const uint8_t *input1_ptr, const uint8_t *input2_ptr, uint8_t *output_ptr,
681 int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
682 float32x4_t voffseto, float32x4_t invvscaleo)
683 {
684 ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
685 int x = window_start_x;
686 for(; x <= (window_end_x - window_step_x); x += window_step_x)
687 {
688 const float32x4x4_t af = load_quantized(input1_ptr + x, voffset1, vscale1);
689 const float32x4x4_t bf = load_quantized(input2_ptr + x, voffset2, vscale2);
690 const uint32x4x4_t rf = elementwise_comp_op<op>(af, bf);
691 store_quantized(output_ptr + x, rf);
692 }
693 return x;
694 }
695
696 template <ComparisonOperation op>
elementwise_comp_op_quantized_signed_loop(int window_start_x,int window_end_x,int window_step_x,const int8_t * input1_ptr,const int8_t * input2_ptr,uint8_t * output_ptr,int32x4_t voffset1,int32x4_t voffset2,float32x4_t vscale1,float32x4_t vscale2,float32x4_t voffseto,float32x4_t invvscaleo)697 inline int elementwise_comp_op_quantized_signed_loop(int window_start_x, int window_end_x, int window_step_x,
698 const int8_t *input1_ptr, const int8_t *input2_ptr, uint8_t *output_ptr,
699 int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
700 float32x4_t voffseto, float32x4_t invvscaleo)
701 {
702 ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
703 int x = window_start_x;
704 for(; x <= (window_end_x - window_step_x); x += window_step_x)
705 {
706 const float32x4x4_t af = load_quantized_signed(input1_ptr + x, voffset1, vscale1);
707 const float32x4x4_t bf = load_quantized_signed(input2_ptr + x, voffset2, vscale2);
708 const uint32x4x4_t rf = elementwise_comp_op<op>(af, bf);
709 store_quantized(output_ptr + x, rf);
710 }
711 return x;
712 }
713
714 template <ComparisonOperation op>
elementwise_comp_op_quantized_broadcast_loop(int window_start_x,int window_end_x,int window_step_x,const uint8_t * non_broadcast_input_ptr,float32x4x4_t broadcast_vector,uint8_t * output_ptr,int32x4_t voffset_non_broadcast,float32x4_t vscale_non_broadcast,float32x4_t voffseto,float32x4_t invvscaleo,bool reorder)715 inline int elementwise_comp_op_quantized_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
716 const uint8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
717 int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
718 float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
719 {
720 ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
721 int x = window_start_x;
722 for(; x <= (window_end_x - window_step_x); x += window_step_x)
723 {
724 const float32x4x4_t af = load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
725 const uint32x4x4_t rf = elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
726 store_quantized(output_ptr + x, rf);
727 }
728 return x;
729 }
730
731 template <ComparisonOperation op>
elementwise_comp_op_quantized_signed_broadcast_loop(int window_start_x,int window_end_x,int window_step_x,const int8_t * non_broadcast_input_ptr,float32x4x4_t broadcast_vector,uint8_t * output_ptr,int32x4_t voffset_non_broadcast,float32x4_t vscale_non_broadcast,float32x4_t voffseto,float32x4_t invvscaleo,bool reorder)732 inline int elementwise_comp_op_quantized_signed_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
733 const int8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
734 int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
735 float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
736 {
737 ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
738 int x = window_start_x;
739 for(; x <= (window_end_x - window_step_x); x += window_step_x)
740 {
741 const float32x4x4_t af = load_quantized_signed(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
742 const uint32x4x4_t rf = elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
743 store_quantized(output_ptr + x, rf);
744 }
745 return x;
746 }
747
elementwise_op_quantized(const ITensor * in1,const ITensor * in2,ITensor * out,const Window & window,uint8_t (* scalar_func)(const float &,const float &,UniformQuantizationInfo),int (* broadcast_func)(int,int,int,const uint8_t *,float32x4x4_t,uint8_t *,int32x4_t,float32x4_t,float32x4_t,float32x4_t,const bool),int (* neon_func)(int,int,int,const uint8_t *,const uint8_t *,uint8_t *,int32x4_t,int32x4_t,float32x4_t,float32x4_t,float32x4_t,float32x4_t))748 inline void elementwise_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
749 uint8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
750 int (*broadcast_func)(int, int, int, const uint8_t *, float32x4x4_t, uint8_t *, int32x4_t, float32x4_t,
751 float32x4_t, float32x4_t, const bool),
752 int (*neon_func)(int, int, int, const uint8_t *, const uint8_t *, uint8_t *,
753 int32x4_t, int32x4_t, float32x4_t, float32x4_t,
754 float32x4_t, float32x4_t))
755 {
756 // Create input windows
757 Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
758 Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
759
760 // Clear X Dimension on execution window as we handle manually
761 Window win = window;
762 win.set(Window::DimX, Window::Dimension(0, 1, 1));
763
764 const int window_step_x = 16;
765 const auto window_start_x = static_cast<int>(window.x().start());
766 const auto window_end_x = static_cast<int>(window.x().end());
767 const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();
768
769 const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();
770
771 // Output quantization info (add 0.5 to round toward the nearest integer - 0.5 rounds away from zero)
772 const float32x4_t voffseto = vdupq_n_f32(output_qinfo.offset + 0.5f);
773 const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);
774
775 if(is_broadcast_across_x)
776 {
777 // Select the broadcast input on the X axis
778 const bool is_broadcast_input_2 = input2_win.x().step() == 0;
779 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
780 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
781 const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
782 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
783
784 const UniformQuantizationInfo broadcast_qinfo = broadcast_tensor->info()->quantization_info().uniform();
785 const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();
786
787 const int32x4_t voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
788 const float32x4_t vscale_non_broadcast = vdupq_n_f32(non_broadcast_qinfo.scale);
789
790 // Clear X Dimension on execution window as we handle manually
791 non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
792
793 Iterator broadcast_input(broadcast_tensor, broadcast_win);
794 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
795 Iterator output(out, win);
796
797 execute_window_loop(win, [&](const Coordinates &)
798 {
799 const auto non_broadcast_input_ptr = reinterpret_cast<const uint8_t *>(non_broadcast_input.ptr());
800 const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());
801
802 const uint8_t broadcast_value = *reinterpret_cast<const uint8_t *>(broadcast_input.ptr());
803 const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_u8(broadcast_value), broadcast_qinfo);
804
805 int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
806 voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
807 for(; x < window_end_x; ++x)
808 {
809 const float afs = dequantize_qasymm8(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
810 const float bfs = dequantize_qasymm8(broadcast_value, broadcast_qinfo);
811 *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
812 }
813 },
814 broadcast_input, non_broadcast_input, output);
815 }
816 else
817 {
818 const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
819 const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();
820
821 // Input1 quantization info
822 const int32x4_t voffset1 = vdupq_n_s32(input1_qinfo.offset);
823 const float32x4_t vscale1 = vdupq_n_f32(input1_qinfo.scale);
824
825 // Input2 quantization info
826 const int32x4_t voffset2 = vdupq_n_s32(input2_qinfo.offset);
827 const float32x4_t vscale2 = vdupq_n_f32(input2_qinfo.scale);
828
829 // Clear X Dimension on execution window as we handle manually
830 input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
831 input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
832
833 Iterator input1(in1, input1_win);
834 Iterator input2(in2, input2_win);
835 Iterator output(out, win);
836
837 execute_window_loop(win, [&](const Coordinates &)
838 {
839 const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
840 const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
841 const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());
842
843 int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
844 vscale1, vscale2, voffseto, invvscaleo);
845 for(; x < window_end_x; ++x)
846 {
847 const float afs = dequantize_qasymm8(*(input1_ptr + x), input1_qinfo);
848 const float bfs = dequantize_qasymm8(*(input2_ptr + x), input2_qinfo);
849 *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
850 }
851 },
852 input1, input2, output);
853 }
854 }
855
elementwise_comp_quantized_signed(const ITensor * in1,const ITensor * in2,ITensor * out,const Window & window,uint8_t (* scalar_func)(const float &,const float &,UniformQuantizationInfo),int (* broadcast_func)(int,int,int,const int8_t *,float32x4x4_t,uint8_t *,int32x4_t,float32x4_t,float32x4_t,float32x4_t,const bool),int (* neon_func)(int,int,int,const int8_t *,const int8_t *,uint8_t *,int32x4_t,int32x4_t,float32x4_t,float32x4_t,float32x4_t,float32x4_t))856 inline void elementwise_comp_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
857 uint8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
858 int (*broadcast_func)(int, int, int, const int8_t *, float32x4x4_t, uint8_t *, int32x4_t, float32x4_t,
859 float32x4_t, float32x4_t, const bool),
860 int (*neon_func)(int, int, int, const int8_t *, const int8_t *, uint8_t *,
861 int32x4_t, int32x4_t, float32x4_t, float32x4_t,
862 float32x4_t, float32x4_t))
863 {
864 // Create input windows
865 Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
866 Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
867
868 // Clear X Dimension on execution window as we handle manually
869 Window win = window;
870 win.set(Window::DimX, Window::Dimension(0, 1, 1));
871
872 const int window_step_x = 16;
873 const auto window_start_x = static_cast<int>(window.x().start());
874 const auto window_end_x = static_cast<int>(window.x().end());
875 const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();
876
877 const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();
878
879 const float32x4_t voffseto = vdupq_n_f32(output_qinfo.offset);
880 const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);
881
882 if(is_broadcast_across_x)
883 {
884 // Select the broadcast input on the X axis
885 const bool is_broadcast_input_2 = input2_win.x().step() == 0;
886 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
887 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
888 const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
889 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
890
891 const UniformQuantizationInfo broadcast_qinfo = broadcast_tensor->info()->quantization_info().uniform();
892 const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();
893
894 const int32x4_t voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
895 const float32x4_t vscale_non_broadcast = vdupq_n_f32(non_broadcast_qinfo.scale);
896
897 // Clear X Dimension on execution window as we handle manually
898 non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
899
900 Iterator broadcast_input(broadcast_tensor, broadcast_win);
901 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
902 Iterator output(out, win);
903
904 execute_window_loop(win, [&](const Coordinates &)
905 {
906 const auto non_broadcast_input_ptr = reinterpret_cast<const int8_t *>(non_broadcast_input.ptr());
907 const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());
908
909 const int8_t broadcast_value = *reinterpret_cast<const int8_t *>(broadcast_input.ptr());
910 const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_s8(broadcast_value), broadcast_qinfo);
911
912 int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
913 voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
914 for(; x < window_end_x; ++x)
915 {
916 const float afs = dequantize_qasymm8_signed(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
917 const float bfs = dequantize_qasymm8_signed(broadcast_value, broadcast_qinfo);
918 *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
919 }
920 },
921 broadcast_input, non_broadcast_input, output);
922 }
923 else
924 {
925 const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
926 const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();
927
928 // Input1 quantization info
929 const int32x4_t voffset1 = vdupq_n_s32(input1_qinfo.offset);
930 const float32x4_t vscale1 = vdupq_n_f32(input1_qinfo.scale);
931
932 // Input2 quantization info
933 const int32x4_t voffset2 = vdupq_n_s32(input2_qinfo.offset);
934 const float32x4_t vscale2 = vdupq_n_f32(input2_qinfo.scale);
935
936 // Clear X Dimension on execution window as we handle manually
937 input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
938 input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
939
940 Iterator input1(in1, input1_win);
941 Iterator input2(in2, input2_win);
942 Iterator output(out, win);
943
944 execute_window_loop(win, [&](const Coordinates &)
945 {
946 const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr());
947 const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr());
948 const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());
949
950 int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
951 vscale1, vscale2, voffseto, invvscaleo);
952 for(; x < window_end_x; ++x)
953 {
954 const float afs = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo);
955 const float bfs = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo);
956 *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
957 }
958 },
959 input1, input2, output);
960 }
961 }
962
elementwise_op_quantized_signed(const ITensor * in1,const ITensor * in2,ITensor * out,const Window & window,int8_t (* scalar_func)(const float &,const float &,UniformQuantizationInfo),int (* broadcast_func)(int,int,int,const int8_t *,float32x4x4_t,int8_t *,int32x4_t,float32x4_t,float32x4_t,float32x4_t,const bool),int (* neon_func)(int,int,int,const int8_t *,const int8_t *,int8_t *,int32x4_t,int32x4_t,float32x4_t,float32x4_t,float32x4_t,float32x4_t))963 inline void elementwise_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
964 int8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
965 int (*broadcast_func)(int, int, int, const int8_t *, float32x4x4_t, int8_t *, int32x4_t, float32x4_t,
966 float32x4_t, float32x4_t, const bool),
967 int (*neon_func)(int, int, int, const int8_t *, const int8_t *, int8_t *,
968 int32x4_t, int32x4_t, float32x4_t, float32x4_t,
969 float32x4_t, float32x4_t))
970 {
971 // Create input windows
972 Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
973 Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());
974
975 // Clear X Dimension on execution window as we handle manually
976 Window win = window;
977 win.set(Window::DimX, Window::Dimension(0, 1, 1));
978
979 const int window_step_x = 16;
980 const auto window_start_x = static_cast<int>(window.x().start());
981 const auto window_end_x = static_cast<int>(window.x().end());
982 const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();
983
984 const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();
985
986 const float32x4_t voffseto = vdupq_n_f32(output_qinfo.offset);
987 const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);
988
989 if(is_broadcast_across_x)
990 {
991 // Select the broadcast input on the X axis
992 const bool is_broadcast_input_2 = input2_win.x().step() == 0;
993 Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win;
994 Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win;
995 const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1;
996 const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;
997
998 const UniformQuantizationInfo broadcast_qinfo = broadcast_tensor->info()->quantization_info().uniform();
999 const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();
1000
1001 const int32x4_t voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
1002 const float32x4_t vscale_non_broadcast = vdupq_n_f32(non_broadcast_qinfo.scale);
1003
1004 // Clear X Dimension on execution window as we handle manually
1005 non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1006
1007 Iterator broadcast_input(broadcast_tensor, broadcast_win);
1008 Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
1009 Iterator output(out, win);
1010
1011 execute_window_loop(win, [&](const Coordinates &)
1012 {
1013 const auto non_broadcast_input_ptr = reinterpret_cast<const int8_t *>(non_broadcast_input.ptr());
1014 const auto output_ptr = reinterpret_cast<int8_t *>(output.ptr());
1015
1016 const int8_t broadcast_value = *reinterpret_cast<const int8_t *>(broadcast_input.ptr());
1017 const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_s8(broadcast_value), broadcast_qinfo);
1018
1019 int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
1020 voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
1021 for(; x < window_end_x; ++x)
1022 {
1023 const float afs = dequantize_qasymm8_signed(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
1024 const float bfs = dequantize_qasymm8_signed(broadcast_value, broadcast_qinfo);
1025 *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
1026 }
1027 },
1028 broadcast_input, non_broadcast_input, output);
1029 }
1030 else
1031 {
1032 const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
1033 const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();
1034
1035 // Input1 quantization info
1036 const int32x4_t voffset1 = vdupq_n_s32(input1_qinfo.offset);
1037 const float32x4_t vscale1 = vdupq_n_f32(input1_qinfo.scale);
1038
1039 // Input2 quantization info
1040 const int32x4_t voffset2 = vdupq_n_s32(input2_qinfo.offset);
1041 const float32x4_t vscale2 = vdupq_n_f32(input2_qinfo.scale);
1042
1043 // Clear X Dimension on execution window as we handle manually
1044 input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1045 input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1046
1047 Iterator input1(in1, input1_win);
1048 Iterator input2(in2, input2_win);
1049 Iterator output(out, win);
1050
1051 execute_window_loop(win, [&](const Coordinates &)
1052 {
1053 const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr());
1054 const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr());
1055 const auto output_ptr = reinterpret_cast<int8_t *>(output.ptr());
1056
1057 int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
1058 vscale1, vscale2, voffseto, invvscaleo);
1059 for(; x < window_end_x; ++x)
1060 {
1061 const float afs = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo);
1062 const float bfs = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo);
1063 *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
1064 }
1065 },
1066 input1, input2, output);
1067 }
1068 }
1069
1070 template <ArithmeticOperation op>
elementwise_arithm_op_quantized(const ITensor * in1,const ITensor * in2,ITensor * out,const Window & window)1071 void elementwise_arithm_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1072 {
1073 elementwise_op_quantized(in1, in2, out, window, &elementwise_arithm_op_quantized_scalar<op>,
1074 &elementwise_arithm_op_quantized_broadcast_loop<op>,
1075 &elementwise_arithm_op_quantized_loop<op>);
1076 }
1077
1078 template <ArithmeticOperation op>
elementwise_arithm_op_quantized_signed(const ITensor * in1,const ITensor * in2,ITensor * out,const Window & window)1079 void elementwise_arithm_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1080 {
1081 elementwise_op_quantized_signed(in1, in2, out, window, &elementwise_arithm_op_quantized_signed_scalar<op>,
1082 &elementwise_arithm_op_quantized_signed_broadcast_loop<op>,
1083 &elementwise_arithm_op_quantized_singed_loop<op>);
1084 }
1085
1086 template <ComparisonOperation op>
elementwise_comp_op_quantized(const ITensor * in1,const ITensor * in2,ITensor * out,const Window & window)1087 void elementwise_comp_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1088 {
1089 elementwise_op_quantized(in1, in2, out, window, &elementwise_comp_op_quantized_scalar<op>,
1090 &elementwise_comp_op_quantized_broadcast_loop<op>,
1091 &elementwise_comp_op_quantized_loop<op>);
1092 }
1093
1094 template <ComparisonOperation op>
elementwise_comp_op_quantized_signed(const ITensor * in1,const ITensor * in2,ITensor * out,const Window & window)1095 void elementwise_comp_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1096 {
1097 elementwise_comp_quantized_signed(in1, in2, out, window, &elementwise_comp_op_quantized_scalar<op>,
1098 &elementwise_comp_op_quantized_signed_broadcast_loop<op>,
1099 &elementwise_comp_op_quantized_signed_loop<op>);
1100 }
1101 } // namespace cpu
1102 } // namespace arm_compute
1103
1104 #endif /* SRC_CORE_NEON_KERNELS_ELEMENTWISE_IMPL_H */
1105