xref: /aosp_15_r20/external/ComputeLibrary/src/cpu/kernels/elementwise_binary/generic/neon/impl.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2021-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef SRC_CORE_NEON_KERNELS_ELEMENTWISE_IMPL_H
25 #define SRC_CORE_NEON_KERNELS_ELEMENTWISE_IMPL_H
26 
27 #include "src/core/NEON/NEAsymm.h"
28 
29 namespace arm_compute
30 {
31 namespace cpu
32 {
33 template <ArithmeticOperation op, typename VectorType>
elementwise_arithm_op(const typename VectorType::type & a,const typename VectorType::type & b)34 typename VectorType::type elementwise_arithm_op(const typename VectorType::type &a, const typename VectorType::type &b)
35 {
36     using vec_type    = typename VectorType::type;
37     using scalar_type = typename VectorType::scalar_type;
38     using tag_type    = typename VectorType::tag_type;
39 
40     vec_type res = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});
41 
42     switch(op)
43     {
44         case ArithmeticOperation::MAX:
45             res = wrapper::vmax(a, b);
46             break;
47         case ArithmeticOperation::MIN:
48             res = wrapper::vmin(a, b);
49             break;
50         case ArithmeticOperation::SQUARED_DIFF:
51         {
52             const vec_type tmp = wrapper::vsub(a, b);
53             res                = wrapper::vmul(tmp, tmp);
54             break;
55         }
56         case ArithmeticOperation::PRELU:
57         {
58             const vec_type zero = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});
59             const vec_type tmp  = wrapper::vmul(a, b);
60             const auto     gt   = wrapper::vcgt(a, zero);
61 
62             res = wrapper::vbsl(gt, a, tmp);
63             break;
64         }
65 
66         default:
67             ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
68     }
69 
70     return res;
71 }
72 
73 template <ArithmeticOperation op, typename ScalarType, typename VectorType>
elementwise_arithm_op_broadcast(const typename VectorType::type & a,const ScalarType & broadcast_value,const bool reorder)74 typename VectorType::type elementwise_arithm_op_broadcast(const typename VectorType::type &a, const ScalarType &broadcast_value, const bool reorder)
75 {
76     using tag_type = typename VectorType::tag_type;
77     using vec_type = typename VectorType::type;
78 
79     vec_type broadcast_vector = wrapper::vdup_n(broadcast_value, tag_type{});
80     return elementwise_arithm_op<op, VectorType>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
81 }
82 
/** Generic element-wise driver: walks the execution window and hands each row
 * to a vectorized loop, finishing leftover elements with the scalar function.
 *
 * @param scalar_func    Scalar fallback applied to tail elements the vector loop did not cover.
 * @param broadcast_func Vectorized loop used when one input is broadcast along X;
 *                       returns the first x index it did not process.
 * @param neon_func      Vectorized loop for the element-by-element case;
 *                       returns the first x index it did not process.
 */
template <typename InputScalarType, typename OutputScalarType, typename InputVectorType>
void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                    OutputScalarType (*scalar_func)(const InputScalarType &, const InputScalarType &),
                    int (*broadcast_func)(int, int, int, const InputScalarType *, const InputScalarType &, OutputScalarType *, const bool),
                    int (*neon_func)(int, int, int, const InputScalarType *, const InputScalarType *, OutputScalarType *))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    // Up to 16 bytes per vector iteration, capped at 8 lanes.
    const int  window_step_x         = std::min(16 / static_cast<int>(sizeof(OutputScalarType)), 8);
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();

    if(is_broadcast_across_x)
    {
        // An X step of 0 identifies which input is the broadcast one.
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            auto                  output_ptr              = reinterpret_cast<OutputScalarType *>(output.ptr());
            const auto            non_broadcast_input_ptr = reinterpret_cast<const InputScalarType *>(non_broadcast_input.ptr());
            const InputScalarType broadcast_value         = *reinterpret_cast<const InputScalarType *>(broadcast_input.ptr());

            // Vector body first; x comes back pointing at the first unprocessed element.
            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_value, output_ptr, !is_broadcast_input_2);
            // Scalar tail: preserve operand order when input 1 is the broadcast input.
            for(; x < window_end_x; ++x)
            {
                const auto a      = *(non_broadcast_input_ptr + x);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? broadcast_value : a, !is_broadcast_input_2 ? a : broadcast_value);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            auto       output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr());
            const auto input1_ptr = reinterpret_cast<const InputScalarType *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const InputScalarType *>(input2.ptr());

            // Vector body first, then a scalar tail for the remainder.
            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr);
            for(; x < window_end_x; ++x)
            {
                const auto a      = *(input1_ptr + x);
                const auto b      = *(input2_ptr + x);
                *(output_ptr + x) = (*scalar_func)(a, b);
            }
        },
        input1, input2, output);
    }
}
159 
160 template <ArithmeticOperation op, typename ScalarType>
elementwise_arithm_op_scalar(const ScalarType & a,const ScalarType & b)161 inline ScalarType elementwise_arithm_op_scalar(const ScalarType &a, const ScalarType &b)
162 {
163     auto res = ScalarType(0);
164 
165     switch(op)
166     {
167         case ArithmeticOperation::MAX:
168             res = std::max(a, b);
169             break;
170         case ArithmeticOperation::MIN:
171             res = std::min(a, b);
172             break;
173         case ArithmeticOperation::SQUARED_DIFF:
174         {
175             res = (a - b) * (a - b);
176             break;
177         }
178         case ArithmeticOperation::PRELU:
179         {
180             res = (a > 0 ? a : a * b);
181             break;
182         }
183         case ArithmeticOperation::DIV:
184         {
185             res = a / b;
186             if(std::is_integral<ScalarType>::value)
187             {
188                 res = (b == 0) ? 0 : res;
189                 if(static_cast<int32_t>(a) % static_cast<int32_t>(b) != 0 && ((a < 0) != (b < 0)))
190                 {
191                     --res;
192                 }
193             }
194             break;
195         }
196         case ArithmeticOperation::POWER:
197         {
198             res = std::pow(a, b);
199             break;
200         }
201         default:
202             ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
203     }
204     return res;
205 }
206 
/** DIV specialization for int32x4: convert to float, divide, then floor so the
 * quotient rounds towards negative infinity, matching the scalar integer DIV. */
template <>
inline int32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<int32_t, 4>>(const int32x4_t &a, const int32x4_t &b)
{
    return vcvtq_s32_f32(vfloorq_f32(wrapper::vdiv(vcvtq_f32_s32(a), vcvtq_f32_s32(b))));
}
212 
/** DIV specialization for float32x4: plain lane-wise division. */
template <>
inline float32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
{
    return wrapper::vdiv(a, b);
}
218 
/** POWER specialization for float32x4: lane-wise a^b. */
template <>
inline float32x4_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
{
    return wrapper::vpow(a, b);
}
224 
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** DIV specialization for float16x8: plain lane-wise division. */
template <>
inline float16x8_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float16_t, 8>>(const float16x8_t &a, const float16x8_t &b)
{
    return wrapper::vdiv(a, b);
}

/** POWER specialization for float16x8: lane-wise a^b. */
template <>
inline float16x8_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float16_t, 8>>(const float16x8_t &a, const float16x8_t &b)
{
    return wrapper::vpow(a, b);
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
238 
239 template <ArithmeticOperation op, typename ScalarType, typename VectorType>
elementwise_arithm_op_loop(int window_start_x,int window_end_x,int window_step_x,const ScalarType * input1_ptr,const ScalarType * input2_ptr,ScalarType * output_ptr)240 inline int elementwise_arithm_op_loop(int window_start_x, int window_end_x, int window_step_x,
241                                       const ScalarType *input1_ptr, const ScalarType *input2_ptr, ScalarType *output_ptr)
242 {
243     int x = window_start_x;
244     for(; x <= (window_end_x - window_step_x); x += window_step_x)
245     {
246         const auto a = wrapper::vloadq(input1_ptr + x);
247         const auto b = wrapper::vloadq(input2_ptr + x);
248         wrapper::vstore(output_ptr + x, elementwise_arithm_op<op, VectorType>(a, b));
249     }
250     return x;
251 }
252 
253 template <ArithmeticOperation op, typename ScalarType, typename VectorType>
elementwise_arithm_op_broadcast_loop(int window_start_x,int window_end_x,int window_step_x,const ScalarType * non_broadcast_input_ptr,const ScalarType & broadcast_value,ScalarType * output_ptr,const bool reorder)254 inline int elementwise_arithm_op_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
255                                                 const ScalarType *non_broadcast_input_ptr, const ScalarType &broadcast_value, ScalarType *output_ptr, const bool reorder)
256 {
257     int x = window_start_x;
258     for(; x <= (window_end_x - window_step_x); x += window_step_x)
259     {
260         const auto a = wrapper::vloadq((non_broadcast_input_ptr + x));
261         wrapper::vstore(output_ptr + x, elementwise_arithm_op_broadcast<op, ScalarType, VectorType>(a, broadcast_value, reorder));
262     }
263     return x;
264 }
265 
/** Dispatch an arithmetic element-wise operation over a window, wiring the
 * scalar, broadcast-loop and vector-loop implementations for @p VectorType. */
template <ArithmeticOperation op, typename VectorType>
void elementwise_arithm_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    using scalar_type = typename VectorType::scalar_type;

    elementwise_op<scalar_type, scalar_type, VectorType>(in1, in2, out, window,
                                                         &elementwise_arithm_op_scalar<op, scalar_type>,
                                                         &elementwise_arithm_op_broadcast_loop<op, scalar_type, VectorType>,
                                                         &elementwise_arithm_op_loop<op, scalar_type, VectorType>);
}
276 
277 template <ComparisonOperation op, typename InputScalarType>
elementwise_comp_op_scalar(const InputScalarType & a,const InputScalarType & b)278 inline uint8_t elementwise_comp_op_scalar(const InputScalarType &a, const InputScalarType &b)
279 {
280     bool res = false;
281 
282     switch(op)
283     {
284         case ComparisonOperation::Equal:
285             res = (a == b);
286             break;
287         case ComparisonOperation::NotEqual:
288             res = (a != b);
289             break;
290         case ComparisonOperation::Greater:
291             res = (a > b);
292             break;
293         case ComparisonOperation::GreaterEqual:
294             res = (a >= b);
295             break;
296         case ComparisonOperation::Less:
297             res = (a < b);
298             break;
299         case ComparisonOperation::LessEqual:
300             res = (a <= b);
301             break;
302         default:
303             ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
304     }
305     return res ? ~static_cast<uint8_t>(0) : static_cast<uint8_t>(0);
306 }
307 
/** Lane-wise comparison of two NEON vectors; each output lane is all-ones when
 * the predicate holds and zero otherwise. */
template <ComparisonOperation op, typename InputVectorType, typename OutputVectorType>
inline OutputVectorType elementwise_comp_op(const InputVectorType &a, const InputVectorType &b)
{
    OutputVectorType res = { 0, 0, 0, 0 };

    switch(op)
    {
        case ComparisonOperation::Equal:
            res = wrapper::vceq(a, b);
            break;
        case ComparisonOperation::NotEqual:
            // No vcne: negate the equality mask instead.
            res = wrapper::vnot(wrapper::vceq(a, b));
            break;
        case ComparisonOperation::Greater:
            res = wrapper::vcgt(a, b);
            break;
        case ComparisonOperation::GreaterEqual:
            res = wrapper::vcge(a, b);
            break;
        case ComparisonOperation::Less:
            // a < b implemented as b > a with swapped operands.
            res = wrapper::vcgt(b, a);
            break;
        case ComparisonOperation::LessEqual:
            // a <= b implemented as b >= a with swapped operands.
            res = wrapper::vcge(b, a);
            break;
        default:
            ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
    }

    return res;
}
339 
340 template <ComparisonOperation op, typename InputScalarType, typename InputVectorType, typename OutputVectorType>
elementwise_comp_op_broadcast(const InputVectorType & a,const InputScalarType & broadcast_value,const bool reorder)341 inline OutputVectorType elementwise_comp_op_broadcast(const InputVectorType &a, const InputScalarType &broadcast_value, const bool reorder)
342 {
343     InputVectorType broadcast_vector = wrapper::vdup_n(broadcast_value, wrapper::traits::vector_128_tag());
344     return elementwise_comp_op<op, InputVectorType, OutputVectorType>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
345 }
346 
347 template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
elementwise_comp_op_broadcast_8_loop(int window_start_x,int window_end_x,int window_step_x,const InputScalarType * non_broadcast_input_ptr,const InputScalarType & broadcast_value,uint8_t * output_ptr,const bool reorder)348 inline int elementwise_comp_op_broadcast_8_loop(int window_start_x, int window_end_x, int window_step_x,
349                                                 const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
350 {
351     int x = window_start_x;
352     for(; x <= (window_end_x - window_step_x); x += window_step_x)
353     {
354         const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint8x16_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
355         wrapper::vstore(output_ptr + x, a);
356     }
357     return x;
358 }
359 
360 template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
elementwise_comp_op_broadcast_16_loop(int window_start_x,int window_end_x,int window_step_x,const InputScalarType * non_broadcast_input_ptr,const InputScalarType & broadcast_value,uint8_t * output_ptr,const bool reorder)361 inline int elementwise_comp_op_broadcast_16_loop(int window_start_x, int window_end_x, int window_step_x,
362                                                  const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
363 {
364     int x = window_start_x;
365     for(; x <= (window_end_x - window_step_x); x += window_step_x)
366     {
367         const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint16x8_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
368         wrapper::vstore(output_ptr + x, wrapper::vmovn(a));
369     }
370     return x;
371 }
372 
373 template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
elementwise_comp_op_broadcast_32_loop(int window_start_x,int window_end_x,int window_step_x,const InputScalarType * non_broadcast_input_ptr,const InputScalarType & broadcast_value,uint8_t * output_ptr,const bool reorder)374 inline int elementwise_comp_op_broadcast_32_loop(int window_start_x, int window_end_x, int window_step_x,
375                                                  const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder)
376 {
377     int x = window_start_x;
378     for(; x <= (window_end_x - window_step_x); x += window_step_x)
379     {
380         const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq(non_broadcast_input_ptr + x), broadcast_value, reorder);
381         const auto b = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq(non_broadcast_input_ptr + x + 4), broadcast_value, reorder);
382         wrapper::vstore(output_ptr + x, wrapper::vmovn(wrapper::vcombine(wrapper::vmovn(a), wrapper::vmovn(b))));
383     }
384     if(x <= window_end_x - 4)
385     {
386         const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder);
387         for(int i = 0; i < 4; i++)
388         {
389             *(output_ptr + x + i) = wrapper::vgetlane(a, i);
390         }
391         x = +4;
392     }
393     return x;
394 }
395 
396 template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
elementwise_comp_op_8_loop(int window_start_x,int window_end_x,int window_step_x,const InputScalarType * input1_ptr,const InputScalarType * input2_ptr,uint8_t * output_ptr)397 inline int elementwise_comp_op_8_loop(int window_start_x, int window_end_x, int window_step_x,
398                                       const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
399 {
400     int x = window_start_x;
401     for(; x <= (window_end_x - window_step_x); x += window_step_x)
402     {
403         const auto a   = wrapper::vloadq(input1_ptr + x);
404         const auto b   = wrapper::vloadq(input2_ptr + x);
405         const auto res = elementwise_comp_op<op, InputVectorType, uint8x16_t>(a, b);
406         wrapper::vstore(output_ptr + x, res);
407     }
408     return x;
409 }
410 
411 template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
elementwise_comp_op_16_loop(int window_start_x,int window_end_x,int window_step_x,const InputScalarType * input1_ptr,const InputScalarType * input2_ptr,uint8_t * output_ptr)412 inline int elementwise_comp_op_16_loop(int window_start_x, int window_end_x, int window_step_x,
413                                        const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
414 {
415     int x = window_start_x;
416     for(; x <= (window_end_x - window_step_x); x += window_step_x)
417     {
418         const auto a   = wrapper::vloadq(input1_ptr + x);
419         const auto b   = wrapper::vloadq(input2_ptr + x);
420         const auto res = elementwise_comp_op<op, InputVectorType, uint16x8_t>(a, b);
421         wrapper::vstore(output_ptr + x, wrapper::vmovn(res));
422     }
423     return x;
424 }
425 
426 template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
elementwise_comp_op_32_loop(int window_start_x,int window_end_x,int window_step_x,const InputScalarType * input1_ptr,const InputScalarType * input2_ptr,uint8_t * output_ptr)427 inline int elementwise_comp_op_32_loop(int window_start_x, int window_end_x, int window_step_x,
428                                        const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr)
429 {
430     int x = window_start_x;
431     for(; x <= (window_end_x - window_step_x); x += window_step_x)
432     {
433         auto       a    = wrapper::vloadq(input1_ptr + x);
434         auto       b    = wrapper::vloadq(input2_ptr + x);
435         const auto res  = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
436         a               = wrapper::vloadq(input1_ptr + x + 4);
437         b               = wrapper::vloadq(input2_ptr + x + 4);
438         const auto res2 = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
439         wrapper::vstore(output_ptr + x, wrapper::vmovn(wrapper::vcombine(wrapper::vmovn(res), wrapper::vmovn(res2))));
440     }
441     if(x <= window_end_x - 4)
442     {
443         const auto a   = wrapper::vloadq(input1_ptr + x);
444         const auto b   = wrapper::vloadq(input2_ptr + x);
445         const auto res = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b);
446         for(int i = 0; i < 4; i++)
447         {
448             *(output_ptr + x + i) = wrapper::vgetlane(res, i);
449         }
450         x = +4;
451     }
452     return x;
453 }
454 
/** Dispatch a comparison over 8-bit inputs, producing a uint8 mask output. */
template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
void elementwise_comp_op_8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
                                                              &elementwise_comp_op_scalar<op, InputScalarType>,
                                                              &elementwise_comp_op_broadcast_8_loop<op, InputScalarType, InputVectorType>,
                                                              &elementwise_comp_op_8_loop<op, InputScalarType, InputVectorType>);
}
463 
/** Dispatch a comparison over 16-bit inputs, producing a uint8 mask output. */
template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
void elementwise_comp_op_16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
                                                              &elementwise_comp_op_scalar<op, InputScalarType>,
                                                              &elementwise_comp_op_broadcast_16_loop<op, InputScalarType, InputVectorType>,
                                                              &elementwise_comp_op_16_loop<op, InputScalarType, InputVectorType>);
}
472 
/** Dispatch a comparison over 32-bit inputs, producing a uint8 mask output. */
template <ComparisonOperation op, typename InputScalarType, typename InputVectorType>
void elementwise_comp_op_32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
    elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window,
                                                              &elementwise_comp_op_scalar<op, InputScalarType>,
                                                              &elementwise_comp_op_broadcast_32_loop<op, InputScalarType, InputVectorType>,
                                                              &elementwise_comp_op_32_loop<op, InputScalarType, InputVectorType>);
}
481 
/** Load 16 QASYMM8 values and dequantize them into four float32x4 vectors:
 * each byte is widened u8 -> u16 -> s32, then (value - offset) * scale. */
inline float32x4x4_t load_quantized(const uint8_t *input1_ptr, const int32x4_t &offset, const float32x4_t &scale)
{
    qasymm8x16_t        x = vld1q_u8(input1_ptr);
    const float32x4x4_t out =
    {
        {
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(x))))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(x))))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(x))))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(x))))), offset)), scale),
        }
    };
    return out;
}
496 
/** Load 16 QASYMM8_SIGNED values and dequantize them into four float32x4
 * vectors: each byte is widened s8 -> s16 -> s32, then (value - offset) * scale. */
inline float32x4x4_t load_quantized_signed(const int8_t *input1_ptr, const int32x4_t &offset, const float32x4_t &scale)
{
    qasymm8x16_signed_t x = vld1q_s8(input1_ptr);
    const float32x4x4_t out =
    {
        {
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_low_s8(x)))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_low_s8(x)))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_high_s8(x)))), offset)), scale),
            vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_high_s8(x)))), offset)), scale),
        }
    };
    return out;
}
511 
/** Saturating-narrow four uint32x4 vectors down to 16 bytes and store them. */
inline void store_quantized(uint8_t *output_ptr, const uint32x4x4_t &out)
{
    const uint16x8_t lo = vcombine_u16(vqmovn_u32(out.val[0]), vqmovn_u32(out.val[1]));
    const uint16x8_t hi = vcombine_u16(vqmovn_u32(out.val[2]), vqmovn_u32(out.val[3]));
    vst1q_u8(output_ptr, vcombine_u8(vqmovn_u16(lo), vqmovn_u16(hi)));
}
518 
/** Saturating-narrow four int32x4 vectors to 16 unsigned bytes and store them
 * (negative values clamp to 0 via the signed->unsigned saturating narrow). */
inline void store_quantized(uint8_t *output_ptr, const int32x4x4_t &out)
{
    const int16x8_t lo = vcombine_s16(vqmovn_s32(out.val[0]), vqmovn_s32(out.val[1]));
    const int16x8_t hi = vcombine_s16(vqmovn_s32(out.val[2]), vqmovn_s32(out.val[3]));
    vst1q_u8(output_ptr, vcombine_u8(vqmovun_s16(lo), vqmovun_s16(hi)));
}
525 
/** Quantize four float32x4 vectors (value * invscale + offset) to int32 and
 * store them as 16 saturated unsigned bytes. */
inline void store_quantized(uint8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale)
{
    int32x4x4_t quantized;
    for(int i = 0; i < 4; ++i)
    {
        quantized.val[i] = vcvtq_s32_f32(vmlaq_f32(offset, rf.val[i], invscale));
    }
    store_quantized(output_ptr, quantized);
}
539 
/** Saturating-narrow four int32x4 vectors down to 16 signed bytes and store them. */
inline void store_quantized_signed(int8_t *output_ptr, const int32x4x4_t &out)
{
    const int16x8_t lo = vcombine_s16(vqmovn_s32(out.val[0]), vqmovn_s32(out.val[1]));
    const int16x8_t hi = vcombine_s16(vqmovn_s32(out.val[2]), vqmovn_s32(out.val[3]));
    vst1q_s8(output_ptr, vcombine_s8(vqmovn_s16(lo), vqmovn_s16(hi)));
}
546 
/** Quantize four float32x4 vectors (value * invscale + offset) to int32 and
 * store them as 16 saturated signed bytes. */
inline void store_quantized_signed(int8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale)
{
    int32x4x4_t quantized;
    for(int i = 0; i < 4; ++i)
    {
        quantized.val[i] = vcvtq_s32_f32(vmlaq_f32(offset, rf.val[i], invscale));
    }
    store_quantized_signed(output_ptr, quantized);
}
560 
/** Scalar quantized arithmetic: compute in float, then quantize to QASYMM8. */
template <ArithmeticOperation op>
inline uint8_t elementwise_arithm_op_quantized_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
{
    return quantize_qasymm8(elementwise_arithm_op_scalar<op>(a, b), qinfo);
}
566 
/** Scalar quantized arithmetic: compute in float, then quantize to QASYMM8_SIGNED. */
template <ArithmeticOperation op>
inline int8_t elementwise_arithm_op_quantized_signed_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
{
    return quantize_qasymm8_signed(elementwise_arithm_op_scalar<op>(a, b), qinfo);
}
572 
/** Apply @p op to each of the four float32x4 sub-vectors of the operands. */
template <ArithmeticOperation op>
float32x4x4_t elementwise_arithm_op(const float32x4x4_t &a, const float32x4x4_t &b)
{
    using neon_vector_float = wrapper::traits::neon_vector<float, 4>;

    float32x4x4_t result;
    for(int i = 0; i < 4; ++i)
    {
        result.val[i] = elementwise_arithm_op<op, neon_vector_float>(a.val[i], b.val[i]);
    }
    return result;
}
588 
// Scalar tail path for quantized comparisons: the comparison runs on the
// already-dequantized float inputs and yields a mask byte, so the output
// quantization info is not needed (parameter kept only to match the
// scalar_func signature used by elementwise_op_quantized).
template <ComparisonOperation op>
inline uint8_t elementwise_comp_op_quantized_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo)
{
    ARM_COMPUTE_UNUSED(qinfo);
    return elementwise_comp_op_scalar<op>(a, b);
}
595 
596 template <ComparisonOperation op>
elementwise_comp_op(const float32x4x4_t & a,const float32x4x4_t & b)597 inline uint32x4x4_t elementwise_comp_op(const float32x4x4_t &a, const float32x4x4_t &b)
598 {
599     uint32x4x4_t out =
600     {
601         {
602             elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[0], b.val[0]),
603             elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[1], b.val[1]),
604             elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[2], b.val[2]),
605             elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[3], b.val[3])
606         }
607     };
608     return out;
609 }
610 
611 template <ArithmeticOperation op>
elementwise_arithm_op_quantized_loop(int window_start_x,int window_end_x,int window_step_x,const uint8_t * input1_ptr,const uint8_t * input2_ptr,uint8_t * output_ptr,int32x4_t voffset1,int32x4_t voffset2,float32x4_t vscale1,float32x4_t vscale2,float32x4_t voffseto,float32x4_t invvscaleo)612 inline int elementwise_arithm_op_quantized_loop(int window_start_x, int window_end_x, int window_step_x,
613                                                 const uint8_t *input1_ptr, const uint8_t *input2_ptr, uint8_t *output_ptr,
614                                                 int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
615                                                 float32x4_t voffseto, float32x4_t invvscaleo)
616 {
617     int x = window_start_x;
618     for(; x <= (window_end_x - window_step_x); x += window_step_x)
619     {
620         // Get inputs and compute output
621         const float32x4x4_t af = load_quantized(input1_ptr + x, voffset1, vscale1);
622         const float32x4x4_t bf = load_quantized(input2_ptr + x, voffset2, vscale2);
623         const float32x4x4_t rf = elementwise_arithm_op<op>(af, bf);
624         store_quantized(output_ptr + x, rf, voffseto, invvscaleo);
625     }
626     return x;
627 }
628 
629 template <ArithmeticOperation op>
elementwise_arithm_op_quantized_singed_loop(int window_start_x,int window_end_x,int window_step_x,const int8_t * input1_ptr,const int8_t * input2_ptr,int8_t * output_ptr,int32x4_t voffset1,int32x4_t voffset2,float32x4_t vscale1,float32x4_t vscale2,float32x4_t voffseto,float32x4_t invvscaleo)630 inline int elementwise_arithm_op_quantized_singed_loop(int window_start_x, int window_end_x, int window_step_x,
631                                                        const int8_t *input1_ptr, const int8_t *input2_ptr, int8_t *output_ptr,
632                                                        int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
633                                                        float32x4_t voffseto, float32x4_t invvscaleo)
634 {
635     int x = window_start_x;
636     for(; x <= (window_end_x - window_step_x); x += window_step_x)
637     {
638         // Get inputs and compute output
639         const float32x4x4_t af = load_quantized_signed(input1_ptr + x, voffset1, vscale1);
640         const float32x4x4_t bf = load_quantized_signed(input2_ptr + x, voffset2, vscale2);
641         const float32x4x4_t rf = elementwise_arithm_op<op>(af, bf);
642         store_quantized_signed(output_ptr + x, rf, voffseto, invvscaleo);
643     }
644     return x;
645 }
646 
647 template <ArithmeticOperation op>
elementwise_arithm_op_quantized_broadcast_loop(int window_start_x,int window_end_x,int window_step_x,const uint8_t * non_broadcast_input_ptr,float32x4x4_t broadcast_vector,uint8_t * output_ptr,int32x4_t voffset_non_broadcast,float32x4_t vscale_non_broadcast,float32x4_t voffseto,float32x4_t invvscaleo,bool reorder)648 inline int elementwise_arithm_op_quantized_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
649                                                           const uint8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
650                                                           int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
651                                                           float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
652 {
653     int x = window_start_x;
654     for(; x <= (window_end_x - window_step_x); x += window_step_x)
655     {
656         const float32x4x4_t af = load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
657         const float32x4x4_t rf = elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
658         store_quantized(output_ptr + x, rf, voffseto, invvscaleo);
659     }
660     return x;
661 }
662 template <ArithmeticOperation op>
elementwise_arithm_op_quantized_signed_broadcast_loop(int window_start_x,int window_end_x,int window_step_x,const int8_t * non_broadcast_input_ptr,float32x4x4_t broadcast_vector,int8_t * output_ptr,int32x4_t voffset_non_broadcast,float32x4_t vscale_non_broadcast,float32x4_t voffseto,float32x4_t invvscaleo,bool reorder)663 inline int elementwise_arithm_op_quantized_signed_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
664                                                                  const int8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, int8_t *output_ptr,
665                                                                  int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
666                                                                  float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
667 {
668     int x = window_start_x;
669     for(; x <= (window_end_x - window_step_x); x += window_step_x)
670     {
671         const float32x4x4_t af = load_quantized_signed(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
672         const float32x4x4_t rf = elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
673         store_quantized_signed(output_ptr + x, rf, voffseto, invvscaleo);
674     }
675     return x;
676 }
677 
678 template <ComparisonOperation op>
elementwise_comp_op_quantized_loop(int window_start_x,int window_end_x,int window_step_x,const uint8_t * input1_ptr,const uint8_t * input2_ptr,uint8_t * output_ptr,int32x4_t voffset1,int32x4_t voffset2,float32x4_t vscale1,float32x4_t vscale2,float32x4_t voffseto,float32x4_t invvscaleo)679 inline int elementwise_comp_op_quantized_loop(int window_start_x, int window_end_x, int window_step_x,
680                                               const uint8_t *input1_ptr, const uint8_t *input2_ptr, uint8_t *output_ptr,
681                                               int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
682                                               float32x4_t voffseto, float32x4_t invvscaleo)
683 {
684     ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
685     int x = window_start_x;
686     for(; x <= (window_end_x - window_step_x); x += window_step_x)
687     {
688         const float32x4x4_t af = load_quantized(input1_ptr + x, voffset1, vscale1);
689         const float32x4x4_t bf = load_quantized(input2_ptr + x, voffset2, vscale2);
690         const uint32x4x4_t  rf = elementwise_comp_op<op>(af, bf);
691         store_quantized(output_ptr + x, rf);
692     }
693     return x;
694 }
695 
696 template <ComparisonOperation op>
elementwise_comp_op_quantized_signed_loop(int window_start_x,int window_end_x,int window_step_x,const int8_t * input1_ptr,const int8_t * input2_ptr,uint8_t * output_ptr,int32x4_t voffset1,int32x4_t voffset2,float32x4_t vscale1,float32x4_t vscale2,float32x4_t voffseto,float32x4_t invvscaleo)697 inline int elementwise_comp_op_quantized_signed_loop(int window_start_x, int window_end_x, int window_step_x,
698                                                      const int8_t *input1_ptr, const int8_t *input2_ptr, uint8_t *output_ptr,
699                                                      int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2,
700                                                      float32x4_t voffseto, float32x4_t invvscaleo)
701 {
702     ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
703     int x = window_start_x;
704     for(; x <= (window_end_x - window_step_x); x += window_step_x)
705     {
706         const float32x4x4_t af = load_quantized_signed(input1_ptr + x, voffset1, vscale1);
707         const float32x4x4_t bf = load_quantized_signed(input2_ptr + x, voffset2, vscale2);
708         const uint32x4x4_t  rf = elementwise_comp_op<op>(af, bf);
709         store_quantized(output_ptr + x, rf);
710     }
711     return x;
712 }
713 
714 template <ComparisonOperation op>
elementwise_comp_op_quantized_broadcast_loop(int window_start_x,int window_end_x,int window_step_x,const uint8_t * non_broadcast_input_ptr,float32x4x4_t broadcast_vector,uint8_t * output_ptr,int32x4_t voffset_non_broadcast,float32x4_t vscale_non_broadcast,float32x4_t voffseto,float32x4_t invvscaleo,bool reorder)715 inline int elementwise_comp_op_quantized_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
716                                                         const uint8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
717                                                         int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
718                                                         float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
719 {
720     ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
721     int x = window_start_x;
722     for(; x <= (window_end_x - window_step_x); x += window_step_x)
723     {
724         const float32x4x4_t af = load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
725         const uint32x4x4_t  rf = elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
726         store_quantized(output_ptr + x, rf);
727     }
728     return x;
729 }
730 
731 template <ComparisonOperation op>
elementwise_comp_op_quantized_signed_broadcast_loop(int window_start_x,int window_end_x,int window_step_x,const int8_t * non_broadcast_input_ptr,float32x4x4_t broadcast_vector,uint8_t * output_ptr,int32x4_t voffset_non_broadcast,float32x4_t vscale_non_broadcast,float32x4_t voffseto,float32x4_t invvscaleo,bool reorder)732 inline int elementwise_comp_op_quantized_signed_broadcast_loop(int window_start_x, int window_end_x, int window_step_x,
733                                                                const int8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr,
734                                                                int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast,
735                                                                float32x4_t voffseto, float32x4_t invvscaleo, bool reorder)
736 {
737     ARM_COMPUTE_UNUSED(voffseto, invvscaleo);
738     int x = window_start_x;
739     for(; x <= (window_end_x - window_step_x); x += window_step_x)
740     {
741         const float32x4x4_t af = load_quantized_signed(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast);
742         const uint32x4x4_t  rf = elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector);
743         store_quantized(output_ptr + x, rf);
744     }
745     return x;
746 }
747 
/** Generic executor for element-wise ops on QASYMM8 tensors.
 *
 * Handles window setup, X-axis broadcasting, and the scalar tail. The bulk of
 * each row is processed by @p neon_func (or @p broadcast_func when one input
 * broadcasts along X); leftover elements are handled one by one via
 * @p scalar_func on dequantized values.
 *
 * @param in1            First input tensor (QASYMM8).
 * @param in2            Second input tensor (QASYMM8).
 * @param out            Output tensor.
 * @param window         Execution window.
 * @param scalar_func    Scalar fallback: op on two dequantized floats, requantized result.
 * @param broadcast_func Vector loop used when one input broadcasts along X.
 * @param neon_func      Vector loop used when both inputs have full X extent.
 */
inline void elementwise_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                                     uint8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
                                     int (*broadcast_func)(int, int, int, const uint8_t *, float32x4x4_t, uint8_t *, int32x4_t, float32x4_t,
                                                           float32x4_t, float32x4_t, const bool),
                                     int (*neon_func)(int, int, int, const uint8_t *, const uint8_t *, uint8_t *,
                                                      int32x4_t, int32x4_t, float32x4_t, float32x4_t,
                                                      float32x4_t, float32x4_t))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 16; // elements per vectorized iteration
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();

    const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();

    // Output quantization info (add 0.5 to round toward the nearest integer - 0.5 rounds away from zero)
    const float32x4_t voffseto   = vdupq_n_f32(output_qinfo.offset + 0.5f);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);

    if(is_broadcast_across_x)
    {
        // Select the broadcast input on the X axis
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        const UniformQuantizationInfo broadcast_qinfo     = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();

        const int32x4_t   voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
        const float32x4_t vscale_non_broadcast  = vdupq_n_f32(non_broadcast_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const uint8_t *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<uint8_t *>(output.ptr());

            // Dequantize the single broadcast value once per row
            const uint8_t       broadcast_value  = *reinterpret_cast<const uint8_t *>(broadcast_input.ptr());
            const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_u8(broadcast_value), broadcast_qinfo);

            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
                                      voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
            // Scalar tail for the remaining (< window_step_x) elements
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
                const float bfs   = dequantize_qasymm8(broadcast_value, broadcast_qinfo);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();

        // Input1 quantization info
        const int32x4_t   voffset1 = vdupq_n_s32(input1_qinfo.offset);
        const float32x4_t vscale1  = vdupq_n_f32(input1_qinfo.scale);

        // Input2 quantization info
        const int32x4_t   voffset2 = vdupq_n_s32(input2_qinfo.offset);
        const float32x4_t vscale2  = vdupq_n_f32(input2_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());

            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
                                 vscale1, vscale2, voffseto, invvscaleo);
            // Scalar tail for the remaining (< window_step_x) elements
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8(*(input1_ptr + x), input1_qinfo);
                const float bfs   = dequantize_qasymm8(*(input2_ptr + x), input2_qinfo);
                *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
            }
        },
        input1, input2, output);
    }
}
855 
/** Generic executor for element-wise comparisons on QASYMM8_SIGNED inputs.
 *
 * Inputs are signed quantized; the comparison result is written as unsigned
 * bytes (mask output). Handles X-axis broadcasting and the scalar tail.
 *
 * NOTE(review): unlike elementwise_op_quantized, voffseto here has no +0.5
 * rounding bias. For comparisons voffseto/invvscaleo are unused by the vector
 * loops anyway (see ARM_COMPUTE_UNUSED in the comp loops) — confirm intended.
 *
 * @param in1            First input tensor (QASYMM8_SIGNED).
 * @param in2            Second input tensor (QASYMM8_SIGNED).
 * @param out            Output tensor (mask bytes).
 * @param window         Execution window.
 * @param scalar_func    Scalar fallback comparison on dequantized floats.
 * @param broadcast_func Vector loop used when one input broadcasts along X.
 * @param neon_func      Vector loop used when both inputs have full X extent.
 */
inline void elementwise_comp_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                                              uint8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
                                              int (*broadcast_func)(int, int, int, const int8_t *, float32x4x4_t, uint8_t *, int32x4_t, float32x4_t,
                                                                    float32x4_t, float32x4_t, const bool),
                                              int (*neon_func)(int, int, int, const int8_t *, const int8_t *, uint8_t *,
                                                               int32x4_t, int32x4_t, float32x4_t, float32x4_t,
                                                               float32x4_t, float32x4_t))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 16; // elements per vectorized iteration
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();

    const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();

    const float32x4_t voffseto   = vdupq_n_f32(output_qinfo.offset);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);

    if(is_broadcast_across_x)
    {
        // Select the broadcast input on the X axis
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        const UniformQuantizationInfo broadcast_qinfo     = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();

        const int32x4_t   voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
        const float32x4_t vscale_non_broadcast  = vdupq_n_f32(non_broadcast_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const int8_t *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<uint8_t *>(output.ptr());

            // Dequantize the single broadcast value once per row
            const int8_t        broadcast_value  = *reinterpret_cast<const int8_t *>(broadcast_input.ptr());
            const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_s8(broadcast_value), broadcast_qinfo);

            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
                                      voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
            // Scalar tail for the remaining (< window_step_x) elements
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8_signed(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
                const float bfs   = dequantize_qasymm8_signed(broadcast_value, broadcast_qinfo);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();

        // Input1 quantization info
        const int32x4_t   voffset1 = vdupq_n_s32(input1_qinfo.offset);
        const float32x4_t vscale1  = vdupq_n_f32(input1_qinfo.scale);

        // Input2 quantization info
        const int32x4_t   voffset2 = vdupq_n_s32(input2_qinfo.offset);
        const float32x4_t vscale2  = vdupq_n_f32(input2_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());

            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
                                 vscale1, vscale2, voffseto, invvscaleo);
            // Scalar tail for the remaining (< window_step_x) elements
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo);
                const float bfs   = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo);
                *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
            }
        },
        input1, input2, output);
    }
}
962 
/** Generic executor for element-wise ops on QASYMM8_SIGNED tensors (signed
 * input and signed output).
 *
 * Handles window setup, X-axis broadcasting, and the scalar tail, mirroring
 * elementwise_op_quantized but for the signed data type.
 *
 * NOTE(review): unlike the unsigned variant, voffseto here has no +0.5
 * rounding bias before the truncating float->int conversion in
 * store_quantized_signed — confirm this rounding behavior is intended.
 *
 * @param in1            First input tensor (QASYMM8_SIGNED).
 * @param in2            Second input tensor (QASYMM8_SIGNED).
 * @param out            Output tensor (QASYMM8_SIGNED).
 * @param window         Execution window.
 * @param scalar_func    Scalar fallback: op on two dequantized floats, requantized result.
 * @param broadcast_func Vector loop used when one input broadcasts along X.
 * @param neon_func      Vector loop used when both inputs have full X extent.
 */
inline void elementwise_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window,
                                            int8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo),
                                            int (*broadcast_func)(int, int, int, const int8_t *, float32x4x4_t, int8_t *, int32x4_t, float32x4_t,
                                                                  float32x4_t, float32x4_t, const bool),
                                            int (*neon_func)(int, int, int, const int8_t *, const int8_t *, int8_t *,
                                                             int32x4_t, int32x4_t, float32x4_t, float32x4_t,
                                                             float32x4_t, float32x4_t))
{
    // Create input windows
    Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape());
    Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape());

    // Clear X Dimension on execution window as we handle manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    const int  window_step_x         = 16; // elements per vectorized iteration
    const auto window_start_x        = static_cast<int>(window.x().start());
    const auto window_end_x          = static_cast<int>(window.x().end());
    const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x();

    const UniformQuantizationInfo output_qinfo = out->info()->quantization_info().uniform();

    const float32x4_t voffseto   = vdupq_n_f32(output_qinfo.offset);
    const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale);

    if(is_broadcast_across_x)
    {
        // Select the broadcast input on the X axis
        const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
        Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
        Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
        const ITensor *broadcast_tensor     = is_broadcast_input_2 ? in2 : in1;
        const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1;

        const UniformQuantizationInfo broadcast_qinfo     = broadcast_tensor->info()->quantization_info().uniform();
        const UniformQuantizationInfo non_broadcast_qinfo = non_broadcast_tensor->info()->quantization_info().uniform();

        const int32x4_t   voffset_non_broadcast = vdupq_n_s32(non_broadcast_qinfo.offset);
        const float32x4_t vscale_non_broadcast  = vdupq_n_f32(non_broadcast_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator broadcast_input(broadcast_tensor, broadcast_win);
        Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto non_broadcast_input_ptr = reinterpret_cast<const int8_t *>(non_broadcast_input.ptr());
            const auto output_ptr              = reinterpret_cast<int8_t *>(output.ptr());

            // Dequantize the single broadcast value once per row
            const int8_t        broadcast_value  = *reinterpret_cast<const int8_t *>(broadcast_input.ptr());
            const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_s8(broadcast_value), broadcast_qinfo);

            int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr,
                                      voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2);
            // Scalar tail for the remaining (< window_step_x) elements
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8_signed(*(non_broadcast_input_ptr + x), non_broadcast_qinfo);
                const float bfs   = dequantize_qasymm8_signed(broadcast_value, broadcast_qinfo);
                *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo);
            }
        },
        broadcast_input, non_broadcast_input, output);
    }
    else
    {
        const UniformQuantizationInfo input1_qinfo = in1->info()->quantization_info().uniform();
        const UniformQuantizationInfo input2_qinfo = in2->info()->quantization_info().uniform();

        // Input1 quantization info
        const int32x4_t   voffset1 = vdupq_n_s32(input1_qinfo.offset);
        const float32x4_t vscale1  = vdupq_n_f32(input1_qinfo.scale);

        // Input2 quantization info
        const int32x4_t   voffset2 = vdupq_n_s32(input2_qinfo.offset);
        const float32x4_t vscale2  = vdupq_n_f32(input2_qinfo.scale);

        // Clear X Dimension on execution window as we handle manually
        input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
        input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input1(in1, input1_win);
        Iterator input2(in2, input2_win);
        Iterator output(out, win);

        execute_window_loop(win, [&](const Coordinates &)
        {
            const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr());
            const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr());
            const auto output_ptr = reinterpret_cast<int8_t *>(output.ptr());

            int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2,
                                 vscale1, vscale2, voffseto, invvscaleo);
            // Scalar tail for the remaining (< window_step_x) elements
            for(; x < window_end_x; ++x)
            {
                const float afs   = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo);
                const float bfs   = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo);
                *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo);
            }
        },
        input1, input2, output);
    }
}
1069 
1070 template <ArithmeticOperation op>
elementwise_arithm_op_quantized(const ITensor * in1,const ITensor * in2,ITensor * out,const Window & window)1071 void elementwise_arithm_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1072 {
1073     elementwise_op_quantized(in1, in2, out, window, &elementwise_arithm_op_quantized_scalar<op>,
1074                              &elementwise_arithm_op_quantized_broadcast_loop<op>,
1075                              &elementwise_arithm_op_quantized_loop<op>);
1076 }
1077 
1078 template <ArithmeticOperation op>
elementwise_arithm_op_quantized_signed(const ITensor * in1,const ITensor * in2,ITensor * out,const Window & window)1079 void elementwise_arithm_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1080 {
1081     elementwise_op_quantized_signed(in1, in2, out, window, &elementwise_arithm_op_quantized_signed_scalar<op>,
1082                                     &elementwise_arithm_op_quantized_signed_broadcast_loop<op>,
1083                                     &elementwise_arithm_op_quantized_singed_loop<op>);
1084 }
1085 
1086 template <ComparisonOperation op>
elementwise_comp_op_quantized(const ITensor * in1,const ITensor * in2,ITensor * out,const Window & window)1087 void elementwise_comp_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1088 {
1089     elementwise_op_quantized(in1, in2, out, window, &elementwise_comp_op_quantized_scalar<op>,
1090                              &elementwise_comp_op_quantized_broadcast_loop<op>,
1091                              &elementwise_comp_op_quantized_loop<op>);
1092 }
1093 
1094 template <ComparisonOperation op>
elementwise_comp_op_quantized_signed(const ITensor * in1,const ITensor * in2,ITensor * out,const Window & window)1095 void elementwise_comp_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
1096 {
1097     elementwise_comp_quantized_signed(in1, in2, out, window, &elementwise_comp_op_quantized_scalar<op>,
1098                                       &elementwise_comp_op_quantized_signed_broadcast_loop<op>,
1099                                       &elementwise_comp_op_quantized_signed_loop<op>);
1100 }
1101 } // namespace cpu
1102 } // namespace arm_compute
1103 
1104 #endif /* SRC_CORE_NEON_KERNELS_ELEMENTWISE_IMPL_H */
1105