xref: /aosp_15_r20/external/ComputeLibrary/src/cpu/kernels/CpuMulKernel.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2016-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/cpu/kernels/CpuMulKernel.h"
25 
26 #include "arm_compute/core/ITensor.h"
27 #include "arm_compute/core/TensorInfo.h"
28 #include "src/core/CPP/Validate.h"
29 #include "src/core/NEON/NEAsymm.h"
30 #include "src/core/NEON/NESymm.h"
31 #include "src/core/NEON/wrapper/wrapper.h"
32 #include "src/core/helpers/AutoConfiguration.h"
33 #include "src/core/helpers/WindowHelpers.h"
34 
35 #include <arm_neon.h>
36 
37 namespace
38 {
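// Editorial note: these values are the default minimum workload sizes (mws) used to decide how finely the
// kernel's work is split across threads; the N1/V1 values are tuned for Neoverse N1 and V1 cores.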
39 #if defined(ENABLE_FP32_KERNELS)
40     static constexpr size_t default_mws_N1_fp32_neon = 22447;
41     static constexpr size_t default_mws_V1_fp32_neon = 38982;
42 #endif /* ENABLE_FP32_KERNELS */
43     static constexpr size_t default_mws_other_platforms_1d_tensor = 10240;
44 }
45 namespace arm_compute
46 {
47 namespace cpu
48 {
49 namespace kernels
50 {
51 namespace
52 {
53 const float       scale255_constant      = 1.f / 255.f;
54 const float32x4_t scale255_constant_f32q = vdupq_n_f32(scale255_constant);
55 const float32x4_t positive_round_f32q    = vdupq_n_f32(0.5f);
56 
57 inline Status validate_arguments(const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
58 {
59     ARM_COMPUTE_UNUSED(overflow_policy);
60     ARM_COMPUTE_UNUSED(rounding_policy);
61 
62     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src1);
63     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::S32, DataType::QSYMM16, DataType::F16,
64                                                          DataType::F32);
65     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::S32, DataType::QSYMM16, DataType::F16,
66                                                          DataType::F32);
67     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
68                                                          DataType::S16, DataType::QSYMM16,
69                                                          DataType::S32, DataType::F16, DataType::F32);
70     if(is_data_type_quantized(src1->data_type()) || is_data_type_quantized(src2->data_type()))
71     {
72         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src1, src2);
73         ARM_COMPUTE_RETURN_ERROR_ON_MSG(overflow_policy == ConvertPolicy::WRAP, "ConvertPolicy cannot be WRAP if datatype is quantized");
74     }
75 
76     if(dst->total_size() > 0)
77     {
78         const TensorShape &out_shape = TensorShape::broadcast_shape(src1->tensor_shape(), src2->tensor_shape());
79         ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst->tensor_shape(), 0), "Wrong shape for dst");
80         ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
81         // clang-format off
82         ARM_COMPUTE_RETURN_ERROR_ON_MSG(
83             !(src1->data_type() == src2->data_type() && src2->data_type() == dst->data_type()) &&
84             !(src1->data_type() == DataType::U8 && src2->data_type() == DataType::U8 && dst->data_type() == DataType::S16) &&
85             !(src1->data_type() == DataType::U8 && src2->data_type() == DataType::S16 && dst->data_type() == DataType::S16) &&
86             !(src1->data_type() == DataType::S16 && src2->data_type() == DataType::U8 && dst->data_type() == DataType::S16) &&
87             !(src1->data_type() == DataType::S16 && src2->data_type() == DataType::U8 && dst->data_type() == DataType::S16) &&
88             !(src1->data_type() == DataType::QSYMM16 && src2->data_type() == DataType::QSYMM16 && dst->data_type() == DataType::S32)
89             , "Invalid data type combination");
90         // clang-format on
91         ARM_COMPUTE_RETURN_ERROR_ON_MSG(src1->data_type() == DataType::QSYMM16 && dst->data_type() == DataType::S32 && scale != 1.f, "Unsupported scale for QSYMM16 inputs and S32 dst");
92     }
93 
94     if(std::abs(scale - scale255_constant) < 0.00001f)
95     {
96         ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_NEAREST_UP && rounding_policy != RoundingPolicy::TO_NEAREST_EVEN);
97         ARM_COMPUTE_RETURN_ERROR_ON_MSG(src1->data_type() == DataType::S32 && src2->data_type() == DataType::S32 && dst->data_type() == DataType::S32,
98                                         "Scale == 1/255 is not supported if input and dst are of data type S32");
99     }
100     else
101     {
102         ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_ZERO);
103 
104         int         exponent            = 0;
105         const float normalized_mantissa = std::frexp(scale, &exponent);
106 
107         // Use int scaling if factor is equal to 1/2^n for 0 <= n <= 15
108         // frexp returns 0.5 as mantissa which means that the exponent will be in the range of -14 <= e <= 1
109         // Moreover, it will be non-positive for n >= 1 as we deal with 1/2^n
110         ARM_COMPUTE_RETURN_ERROR_ON_MSG(!((normalized_mantissa == 0.5f) && (-14 <= exponent) && (exponent <= 1)), "Scale value not supported (Should be 1/(2^n) or 1/255)");
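        // Illustrative examples (editorial, not in the original source): scale = 1/8 = 2^-3 gives a frexp
        // mantissa of 0.5 and exponent -2; scale = 1 gives exponent 1; scale = 1/32768 = 2^-15 gives
        // exponent -14. Any scale outside this family (mantissa != 0.5 or exponent outside [-14, 1]) is rejected.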
111     }
112 
113     return Status{};
114 }
115 
116 /** Scales a given vector by 1/255.
117  *
118  * @note This does not work for all cases, e.g. for a float value of 0.49999999999999994 or for large floats.
119  *
120  * @param in Input vector to scale.
121  * @return   Scaled dst rounded to nearest (round half up).
122  */
123 inline int32x4_t scale255_S32_S32(int32x4_t in)
124 {
125     // Scale
126     const float32x4_t tmp = vmulq_f32(vcvtq_f32_s32(in), scale255_constant_f32q);
127     // Round to nearest (round half up)
128     // Add +0.5 for all values
129     // Afterwards vcvt rounds toward zero
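    // Illustrative example (editorial): in = 382 -> 382/255 ~= 1.498 -> +0.5 = 1.998 -> truncates to 1;
    // in = 383 -> ~1.502 -> +0.5 ~= 2.002 -> truncates to 2, i.e. halves and above round up.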
130     return vcvtq_s32_f32(vaddq_f32(tmp, positive_round_f32q));
131 }
132 
133 inline uint16x8_t scale255_U16_U16(uint16x8_t in)
134 {
135     const int32x4_t tmp_s1 = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(in))));
136     const int32x4_t tmp_s2 = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(in))));
137     return vreinterpretq_u16_s16(vcombine_s16(vmovn_s32(tmp_s2), vmovn_s32(tmp_s1)));
138 }
139 
140 template <typename T>
141 inline typename std::enable_if<std::is_same<T, int8_t>::value, int8x16_t>::type
142 vquantize(float32x4x4_t val, const UniformQuantizationInfo &info)
143 {
144     return vquantize_signed(val, info);
145 }
146 
147 template <typename T>
148 inline typename std::enable_if<std::is_same<T, uint8_t>::value, uint8x16_t>::type
149 vquantize(float32x4x4_t val, const UniformQuantizationInfo &info)
150 {
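    // Editorial note: T cannot be deduced from the arguments here, so the unqualified call below resolves to
    // the non-template vquantize() overload for QASYMM8 declared in NEAsymm.h rather than recursing into this template.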
151     return vquantize(val, info);
152 }
153 
154 template <typename T>
155 void mul_saturate_quantized_8(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, float scale)
156 {
157     // Create input windows
158     Window win        = window;
159     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
160     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
161 
162     // Clear X Dimension on execution window as we handle manually
163     win.set(Window::DimX, Window::Dimension(0, 1, 1));
164 
165     const int  window_step_x         = 16 / sizeof(T);
166     const auto window_start_x        = static_cast<int>(window.x().start());
167     const auto window_end_x          = static_cast<int>(window.x().end());
168     const bool is_broadcast_across_x = src1->info()->tensor_shape().x() != src2->info()->tensor_shape().x();
169 
170     const UniformQuantizationInfo output_qua_info = out->info()->quantization_info().uniform();
171     const UniformQuantizationInfo tmp_qua_info    = { output_qua_info.scale / scale, output_qua_info.offset };
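    // Editorial note: dividing the output scale by 'scale' folds the extra multiplication factor into the
    // requantization step, so the loops below only need to dequantize, multiply and requantize.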
172 
173     if(is_broadcast_across_x)
174     {
175         const bool                    is_broadcast_input_2 = input2_win.x().step() == 0;
176         Window                        broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
177         Window                        non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
178         const ITensor                *broadcast_tensor     = is_broadcast_input_2 ? src2 : src1;
179         const ITensor                *non_broadcast_tensor = !is_broadcast_input_2 ? src2 : src1;
180         const UniformQuantizationInfo broadcast_qinfo      = broadcast_tensor->info()->quantization_info().uniform();
181         const UniformQuantizationInfo non_broadcast_qinfo  = non_broadcast_tensor->info()->quantization_info().uniform();
182 
183         // Clear X Dimension on execution window as we handle manually
184         non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
185 
186         Iterator broadcast_input(broadcast_tensor, broadcast_win);
187         Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
188         Iterator dst(out, win);
189 
190         using ExactTagType = typename wrapper::traits::neon_vector<T, window_step_x>::tag_type;
191 
192         execute_window_loop(
193             win, [&](const Coordinates &)
194         {
195             const auto non_broadcast_input_ptr = reinterpret_cast<const T *>(non_broadcast_input.ptr());
196             const auto output_ptr              = reinterpret_cast<T *>(dst.ptr());
197 
198             const auto broadcast_value     = *reinterpret_cast<const T *>(broadcast_input.ptr());
199             const auto broadcast_value_vec = wrapper::vdup_n(broadcast_value, ExactTagType{});
200 
201             // Compute window_step_x elements per iteration
202             int x = window_start_x;
203             for(; x <= (window_end_x - window_step_x); x += window_step_x)
204             {
205                 const auto non_broadcast_v = wrapper::vloadq(non_broadcast_input_ptr + x);
206 
207                 // Dequantize inputs
208                 const float32x4x4_t in1_f32x4x4 = vdequantize(non_broadcast_v, non_broadcast_qinfo);
209                 const float32x4x4_t in2_f32x4x4 = vdequantize(broadcast_value_vec, broadcast_qinfo);
210 
211                 const float32x4x4_t out_f32x4x4 =
212                 {
213                     vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
214                     vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
215                     vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
216                     vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
217                 };
218 
219                 // Quantize dst
220                 const auto result = vquantize<T>(out_f32x4x4, tmp_qua_info);
221                 wrapper::vstore(output_ptr + x, result);
222             }
223 
224             // Compute left-over elements
225             for(; x < window_end_x; ++x)
226             {
227                 // Dequantize inputs
228                 const T     src1    = *(non_broadcast_input_ptr + x);
229                 const float tmp_in1 = Qasymm8QuantizationHelper<T>::dequantize(src1, non_broadcast_qinfo);
230                 const float tmp_in2 = Qasymm8QuantizationHelper<T>::dequantize(broadcast_value, broadcast_qinfo);
231                 const float tmp_f   = tmp_in1 * tmp_in2;
232 
233                 // Quantize dst
234                 const auto tmp_qua = Qasymm8QuantizationHelper<T>::quantize(tmp_f, tmp_qua_info);
235                 *(output_ptr + x)  = tmp_qua;
236             }
237         },
238         broadcast_input, non_broadcast_input, dst);
239     }
240     else
241     {
242         const UniformQuantizationInfo input1_qua_info = src1->info()->quantization_info().uniform();
243         const UniformQuantizationInfo input2_qua_info = src2->info()->quantization_info().uniform();
244 
245         // Clear X Dimension on execution window as we handle manually
246         input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
247         input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
248 
249         Iterator input1(src1, input1_win);
250         Iterator input2(src2, input2_win);
251         Iterator dst(out, win);
252 
253         execute_window_loop(
254             win, [&](const Coordinates &)
255         {
256             const auto input1_ptr = reinterpret_cast<const T *>(input1.ptr());
257             const auto input2_ptr = reinterpret_cast<const T *>(input2.ptr());
258             const auto output_ptr = reinterpret_cast<T *>(dst.ptr());
259 
260             // Compute window_step_x elements per iteration
261             int x = window_start_x;
262             for(; x <= (window_end_x - window_step_x); x += window_step_x)
263             {
264                 const auto input1_q = wrapper::vloadq(input1_ptr + x);
265                 const auto input2_q = wrapper::vloadq(input2_ptr + x);
266 
267                 // Dequantize inputs
268                 const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
269                 const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);
270 
271                 const float32x4x4_t out_f32x4x4 =
272                 {
273                     vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
274                     vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
275                     vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
276                     vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
277                 };
278 
279                 // Quantize dst
280                 const auto result = vquantize<T>(out_f32x4x4, tmp_qua_info);
281                 wrapper::vstore(output_ptr + x, result);
282             }
283 
284             // Compute left-over elements
285             for(; x < window_end_x; ++x)
286             {
287                 // Dequantize inputs
288                 const T     src1    = *(input1_ptr + x);
289                 const T     src2    = *(input2_ptr + x);
290                 const float tmp_in1 = Qasymm8QuantizationHelper<T>::dequantize(src1, input1_qua_info);
291                 const float tmp_in2 = Qasymm8QuantizationHelper<T>::dequantize(src2, input2_qua_info);
292                 const float tmp_f   = tmp_in1 * tmp_in2;
293 
294                 // Quantize dst
295                 const auto tmp_qua = Qasymm8QuantizationHelper<T>::quantize(tmp_f, tmp_qua_info);
296                 *(output_ptr + x)  = tmp_qua;
297             }
298         },
299         input1, input2, dst);
300     }
301 }
302 
303 bool mul_q8_neon_fixedpoint_possible(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst, float scale)
304 {
305     const auto iq0 = src0->quantization_info().uniform();
306     const auto iq1 = src1->quantization_info().uniform();
307     const auto oq  = dst->quantization_info().uniform();
308 
309     const auto multiplier = ((iq0.scale * iq1.scale) / oq.scale) * scale;
310 
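    // Editorial note: a signed Q14.18 value in 32 bits keeps 18 fractional bits, so its magnitude must stay
    // below 2^13 = 8192; 8191 is used as the bound in the checks below.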
311     if(multiplier < -8191.f || multiplier > 8191.f)
312     {
313         //The multiplier cannot be stored as a 14.18 signed fixed-point number
314         return false;
315     }
316 
317     const auto offset_out = float(oq.offset);
318 
319     const auto max_result = multiplier * (256) * (256) + offset_out;
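    // Editorial note: 256 * 256 is an upper bound on |in0 - offset0| * |in1 - offset1| for 8-bit inputs, so
    // max_result approximates the largest intermediate value that still has to fit in the Q14.18 format.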
320 
321     if(max_result > 8191.f)
322     {
323         //It might not be possible to store the result as a 14.18 signed fixed-point number.
324         return false;
325     }
326 
327     return true;
328 }
329 
330 template <typename ScalarType>
331 void mul_q8_neon_fixedpoint(const ITensor *src0, const ITensor *src1, ITensor *dst, const Window &window, float scale)
332 {
333     const auto in0_info = src0->info();
334     const auto in1_info = src1->info();
335 
336     const auto &in0_shape = in0_info->tensor_shape();
337     const auto &in1_shape = in1_info->tensor_shape();
338 
339     // Create input windows.
340     Window in0_win = window.broadcast_if_dimension_le_one(in0_shape);
341     Window in1_win = window.broadcast_if_dimension_le_one(in1_shape);
342 
343     // Clear the x dimension on the execution window as we process the whole row each iteration.
344     Window win = window;
345     win.set(Window::DimX, Window::Dimension(0, 1, 1));
346 
347     constexpr int window_step_x         = 16;
348     const auto    window_start_x        = window.x().start();
349     const auto    window_end_x          = window.x().end();
350     const auto    is_broadcast_across_x = in0_shape.x() != in1_shape.x();
351 
352     const auto iq0_info = in0_info->quantization_info().uniform();
353     const auto iq1_info = in1_info->quantization_info().uniform();
354     const auto oq_info  = dst->info()->quantization_info().uniform();
355 
356     const auto in0_offset = iq0_info.offset;
357     const auto in1_offset = iq1_info.offset;
358     const auto out_offset = oq_info.offset;
359     const auto multiplier = ((iq0_info.scale * iq1_info.scale) / oq_info.scale) * scale;
360 
361     constexpr int32_t two_pwr18i = 262144;
362     constexpr float   two_pwr18f = 262144.f;
363 
364     const auto in0_offset_16p0  = static_cast<int16_t>(in0_offset);
365     const auto in1_offset_16p0  = static_cast<int16_t>(in1_offset);
366     const auto out_offset_14p18 = static_cast<int32_t>(out_offset * two_pwr18i);
367     const auto multiplier_14p18 = static_cast<int32_t>(multiplier * two_pwr18f);
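    // Editorial note on naming: a suffix such as "_14p18" denotes a fixed-point value with 14 integer bits and
    // 18 fractional bits, "_16p0" a plain 16-bit integer and "_8p0" a plain 8-bit integer. The result computed
    // below is multiplier * (in0 - in0_offset) * (in1 - in1_offset) + out_offset, rounded and saturated to 8 bits.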
368 
369     if(is_broadcast_across_x)
370     {
371         // Prefix: a = non-broadcast, b = broadcast.
372 
373         const auto is_broadcast_input_1 = in1_win.x().step() == 0;
374         auto       a_win                = is_broadcast_input_1 ? in0_win : in1_win;
375         auto       b_win                = is_broadcast_input_1 ? in1_win : in0_win;
376         const auto a_tensor             = is_broadcast_input_1 ? src0 : src1;
377         const auto b_tensor             = is_broadcast_input_1 ? src1 : src0;
378 
379         const auto a_offset_16p0 = is_broadcast_input_1 ? in0_offset_16p0 : in1_offset_16p0;
380         const auto b_offset_16p0 = is_broadcast_input_1 ? in1_offset : in0_offset;
381 #ifndef __aarch64__
382         const auto a_offset = is_broadcast_input_1 ? in0_offset : in1_offset;
383         const auto b_offset = is_broadcast_input_1 ? in1_offset : in0_offset;
384 #endif //__aarch64__
385         const auto a_voffset_16p0 = wrapper::vdup_n(a_offset_16p0, wrapper::traits::vector_64_tag());
386 
387         // Clear the x dimension on the execution window as we process the whole row each iteration.
388         a_win.set(Window::DimX, Window::Dimension(0, 1, 1));
389 
390         Iterator a_input_it(a_tensor, a_win);
391         Iterator b_input_it(b_tensor, b_win);
392         Iterator out_it(dst, win);
393 
394         execute_window_loop(
395             win, [&](const Coordinates &)
396         {
397             const auto a_ptr   = reinterpret_cast<const ScalarType *>(a_input_it.ptr());
398             const auto b_ptr   = reinterpret_cast<const ScalarType *>(b_input_it.ptr());
399             const auto out_ptr = reinterpret_cast<ScalarType *>(out_it.ptr());
400 
401             const auto b_val            = *b_ptr;
402             const auto b_offseted_32p0  = static_cast<int32_t>(b_val - b_offset_16p0);
403             const auto b_voffseted_32p0 = wrapper::vdup_n(b_offseted_32p0, wrapper::traits::vector_128_tag());
404 
405             const auto vmultiplier_14p18 = wrapper::vdup_n(multiplier_14p18, wrapper::traits::vector_128_tag());
406             const auto voffsetout_14p18  = wrapper::vdup_n(out_offset_14p18, wrapper::traits::vector_128_tag());
407 
408             int x = window_start_x;
409 
410             for(; x <= (window_end_x - window_step_x); x += window_step_x)
411             {
412                 // Load the inputs.
413                 const auto a_vin_8p0 = wrapper::vloadq(a_ptr + x);
414 
415                 // Widen the non-broadcast elements to signed 16-bit regardless of the input signedness.
416                 const auto a_vin_16p0_0 = wrapper::vreinterpret(wrapper::vmovl(wrapper::vgetlow(a_vin_8p0)));
417                 const auto a_vin_16p0_1 = wrapper::vreinterpret(wrapper::vmovl(wrapper::vgethigh(a_vin_8p0)));
418 
419                 const auto voffseted_32p0_00 = wrapper::vsubl(wrapper::vgetlow(a_vin_16p0_0), a_voffset_16p0);
420                 const auto voffseted_32p0_01 = wrapper::vsubl(wrapper::vgethigh(a_vin_16p0_0), a_voffset_16p0);
421                 const auto voffseted_32p0_10 = wrapper::vsubl(wrapper::vgetlow(a_vin_16p0_1), a_voffset_16p0);
422                 const auto voffseted_32p0_11 = wrapper::vsubl(wrapper::vgethigh(a_vin_16p0_1), a_voffset_16p0);
423 
424                 const auto vinnermul_32p0_00 = wrapper::vmul(voffseted_32p0_00, b_voffseted_32p0);
425                 const auto vinnermul_32p0_01 = wrapper::vmul(voffseted_32p0_01, b_voffseted_32p0);
426                 const auto vinnermul_32p0_10 = wrapper::vmul(voffseted_32p0_10, b_voffseted_32p0);
427                 const auto vinnermul_32p0_11 = wrapper::vmul(voffseted_32p0_11, b_voffseted_32p0);
428 
429                 const auto vout_14p18_00 = wrapper::vmla(voffsetout_14p18, vinnermul_32p0_00, vmultiplier_14p18);
430                 const auto vout_14p18_01 = wrapper::vmla(voffsetout_14p18, vinnermul_32p0_01, vmultiplier_14p18);
431                 const auto vout_14p18_10 = wrapper::vmla(voffsetout_14p18, vinnermul_32p0_10, vmultiplier_14p18);
432                 const auto vout_14p18_11 = wrapper::vmla(voffsetout_14p18, vinnermul_32p0_11, vmultiplier_14p18);
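                // Editorial note: wrapper::vmla(a, b, c) computes a + b * c, so each lane now holds
                // out_offset_14p18 + (non-broadcast - offset) * (broadcast - offset) * multiplier in Q14.18.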
433 
434                 // These right shifts revert the earlier multiplication by two_pwr18. The hard limit of a maximum shift of 8 means multiple shift instructions are needed to achieve this.
435                 const auto vout_15p1_00 = wrapper::vqrshrn_ex<8, ScalarType>(wrapper::vshrq_n<8>(vout_14p18_00));
436                 const auto vout_15p1_01 = wrapper::vqrshrn_ex<8, ScalarType>(wrapper::vshrq_n<8>(vout_14p18_01));
437                 const auto vout_15p1_10 = wrapper::vqrshrn_ex<8, ScalarType>(wrapper::vshrq_n<8>(vout_14p18_10));
438                 const auto vout_15p1_11 = wrapper::vqrshrn_ex<8, ScalarType>(wrapper::vshrq_n<8>(vout_14p18_11));
439 
440                 const auto vout_15p1_0 = wrapper::vcombine(
441                                              vout_15p1_00,
442                                              vout_15p1_01);
443 
444                 const auto vout_15p1_1 = wrapper::vcombine(
445                                              vout_15p1_10,
446                                              vout_15p1_11);
447                 const auto out_ptr = reinterpret_cast<ScalarType *>(out_it.ptr());
448 
449                 const auto vout_8p0 = wrapper::vcombine(
450                                           wrapper::vqrshrn<2>(vout_15p1_0),
451                                           wrapper::vqrshrn<2>(vout_15p1_1));
452                 wrapper::vstore(out_ptr + x, vout_8p0);
453             }
454 
455             //Process the left-over elements.
456             for(; x < window_end_x; ++x)
457             {
458 #ifdef __aarch64__
459                 out_ptr[x] = wrapper::vqrshrn<2>(wrapper::vqrshrn_ex<8, ScalarType>(wrapper::vshrq_n<8>((multiplier_14p18 * (int32_t(a_ptr[x]) - a_offset_16p0) * (int32_t(
460                                                                                                              b_val) - b_offset_16p0)) + out_offset_14p18)));
461 #else  //__aarch64__
462                 out_ptr[x] = utility::clamp<int32_t, ScalarType>(support::cpp11::lround(multiplier * ((float(a_ptr[x]) - a_offset) * (float(b_val) - b_offset)) + float(out_offset)));
463 #endif //__aarch64__
464             }
465         },
466         a_input_it, b_input_it, out_it);
467     }
468     else
469     {
470         const auto voffset0_16p0     = wrapper::vdup_n(in0_offset_16p0, wrapper::traits::vector_64_tag());
471         const auto voffset1_16p0     = wrapper::vdup_n(in1_offset_16p0, wrapper::traits::vector_64_tag());
472         const auto voffsetout_14p18  = wrapper::vdup_n(out_offset_14p18, wrapper::traits::vector_128_tag());
473         const auto vmultiplier_14p18 = wrapper::vdup_n(multiplier_14p18, wrapper::traits::vector_128_tag());
474 
475         // Clear the x dimension on the execution window as we process the whole row each iteration.
476         in0_win.set(Window::DimX, Window::Dimension(0, 1, 1));
477         in1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
478 
479         Iterator in0_it(src0, in0_win);
480         Iterator in1_it(src1, in1_win);
481         Iterator out_it(dst, win);
482 
483         execute_window_loop(
484             win, [&](const Coordinates &)
485         {
486             const auto in0_ptr = reinterpret_cast<const ScalarType *>(in0_it.ptr());
487             const auto in1_ptr = reinterpret_cast<const ScalarType *>(in1_it.ptr());
488             const auto out_ptr = reinterpret_cast<ScalarType *>(out_it.ptr());
489 
490             int x = window_start_x;
491 
492             for(; x <= (window_end_x - window_step_x); x += window_step_x)
493             {
494                 // Load the inputs.
495                 const auto vin0_8p0 = wrapper::vloadq(in0_ptr + x);
496                 const auto vin1_8p0 = wrapper::vloadq(in1_ptr + x);
497 
498                 // Widen the input elements to signed 16-bit regardless of the input signedness.
499                 const auto vin0_16p0_0 = wrapper::vreinterpret(wrapper::vmovl(wrapper::vgetlow(vin0_8p0)));
500                 const auto vin0_16p0_1 = wrapper::vreinterpret(wrapper::vmovl(wrapper::vgethigh(vin0_8p0)));
501                 const auto vin1_16p0_0 = wrapper::vreinterpret(wrapper::vmovl(wrapper::vgetlow(vin1_8p0)));
502                 const auto vin1_16p0_1 = wrapper::vreinterpret(wrapper::vmovl(wrapper::vgethigh(vin1_8p0)));
503 
504                 const auto voffseted0_32p0_00 = wrapper::vsubl(wrapper::vgetlow(vin0_16p0_0), voffset0_16p0);
505                 const auto voffseted0_32p0_01 = wrapper::vsubl(wrapper::vgethigh(vin0_16p0_0), voffset0_16p0);
506                 const auto voffseted0_32p0_10 = wrapper::vsubl(wrapper::vgetlow(vin0_16p0_1), voffset0_16p0);
507                 const auto voffseted0_32p0_11 = wrapper::vsubl(wrapper::vgethigh(vin0_16p0_1), voffset0_16p0);
508 
509                 const auto voffseted1_32p0_00 = wrapper::vsubl(wrapper::vgetlow(vin1_16p0_0), voffset1_16p0);
510                 const auto voffseted1_32p0_01 = wrapper::vsubl(wrapper::vgethigh(vin1_16p0_0), voffset1_16p0);
511                 const auto voffseted1_32p0_10 = wrapper::vsubl(wrapper::vgetlow(vin1_16p0_1), voffset1_16p0);
512                 const auto voffseted1_32p0_11 = wrapper::vsubl(wrapper::vgethigh(vin1_16p0_1), voffset1_16p0);
513 
514                 const auto vinnermul_32p0_00 = wrapper::vmul(voffseted0_32p0_00, voffseted1_32p0_00);
515                 const auto vinnermul_32p0_01 = wrapper::vmul(voffseted0_32p0_01, voffseted1_32p0_01);
516                 const auto vinnermul_32p0_10 = wrapper::vmul(voffseted0_32p0_10, voffseted1_32p0_10);
517                 const auto vinnermul_32p0_11 = wrapper::vmul(voffseted0_32p0_11, voffseted1_32p0_11);
518 
519                 const auto vout_14p18_00 = wrapper::vmla(voffsetout_14p18, vinnermul_32p0_00, vmultiplier_14p18);
520                 const auto vout_14p18_01 = wrapper::vmla(voffsetout_14p18, vinnermul_32p0_01, vmultiplier_14p18);
521                 const auto vout_14p18_10 = wrapper::vmla(voffsetout_14p18, vinnermul_32p0_10, vmultiplier_14p18);
522                 const auto vout_14p18_11 = wrapper::vmla(voffsetout_14p18, vinnermul_32p0_11, vmultiplier_14p18);
523 
524                 // These right shifts revert the earlier multiplication by two_pwr18. The hard limit of a maximum shift of 8 means multiple shift instructions are needed to achieve this.
525                 const auto vout_14p2_00 = wrapper::vqrshrn_ex<8, ScalarType>(wrapper::vshrq_n<8>(vout_14p18_00));
526                 const auto vout_14p2_01 = wrapper::vqrshrn_ex<8, ScalarType>(wrapper::vshrq_n<8>(vout_14p18_01));
527                 const auto vout_14p2_10 = wrapper::vqrshrn_ex<8, ScalarType>(wrapper::vshrq_n<8>(vout_14p18_10));
528                 const auto vout_14p2_11 = wrapper::vqrshrn_ex<8, ScalarType>(wrapper::vshrq_n<8>(vout_14p18_11));
529 
530                 const auto vout_14p2_0 = wrapper::vcombine(
531                                              vout_14p2_00,
532                                              vout_14p2_01);
533 
534                 const auto vout_14p2_1 = wrapper::vcombine(
535                                              vout_14p2_10,
536                                              vout_14p2_11);
537 
538                 const auto vout_8p0 = wrapper::vcombine(
539                                           wrapper::vqrshrn<2>(vout_14p2_0),
540                                           wrapper::vqrshrn<2>(vout_14p2_1));
541                 wrapper::vstore(out_ptr + x, vout_8p0);
542             }
543 
544             //Process the left-over elements.
545             for(; x < window_end_x; ++x)
546             {
547 #ifdef __aarch64__
548                 out_ptr[x] = wrapper::vqrshrn<2>(wrapper::vqrshrn_ex<8, ScalarType>(wrapper::vshrq_n<8>((multiplier_14p18 * (int32_t(in0_ptr[x]) - in0_offset_16p0) * (int32_t(
549                                                                                                              in1_ptr[x]) - in1_offset_16p0)) + out_offset_14p18)));
550 #else  //__aarch64__
551                 out_ptr[x] = utility::clamp<int32_t, ScalarType>(support::cpp11::lround(multiplier * ((float(in0_ptr[x]) - in0_offset) * (float(in1_ptr[x]) - in1_offset)) + float(out_offset)));
552 #endif //__aarch64__
553             }
554         },
555         in0_it, in1_it, out_it);
556     }
557 }
558 
559 void mul_saturate_QSYMM16_QSYMM16_QSYMM16(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, float scale)
560 {
561     const UniformQuantizationInfo input1_qua_info = src1->info()->quantization_info().uniform();
562     const UniformQuantizationInfo input2_qua_info = src2->info()->quantization_info().uniform();
563     const UniformQuantizationInfo output_qua_info = out->info()->quantization_info().uniform();
564 
565     // Create input windows
566     Window win        = window;
567     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
568     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
569 
570     // Clear X Dimension on execution window as we handle manually
571     win.set(Window::DimX, Window::Dimension(0, 1, 1));
572     input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
573     input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
574 
575     Iterator input1(src1, input1_win);
576     Iterator input2(src2, input2_win);
577     Iterator dst(out, win);
578 
579     const int  window_step_x  = 16;
580     const auto window_start_x = static_cast<int>(window.x().start());
581     const auto window_end_x   = static_cast<int>(window.x().end());
582 
583     const UniformQuantizationInfo tmp_qua_info = { output_qua_info.scale / scale, output_qua_info.offset };
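    // Editorial note: as in mul_saturate_quantized_8() above, the extra 'scale' factor is folded into the
    // output quantization scale, so the loop below only needs to dequantize, multiply and requantize.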
584 
585     execute_window_loop(
586         win, [&](const Coordinates &)
587     {
588         const auto input1_ptr = reinterpret_cast<const qsymm16_t *>(input1.ptr());
589         const auto input2_ptr = reinterpret_cast<const qsymm16_t *>(input2.ptr());
590         const auto output_ptr = reinterpret_cast<qsymm16_t *>(dst.ptr());
591 
592         // Compute window_step_x elements per iteration
593         int x = window_start_x;
594         for(; x <= (window_end_x - window_step_x); x += window_step_x)
595         {
596             const qsymm16x8x2_t input1_q =
597             {
598                 {
599                     vld1q_s16(input1_ptr + x),
600                     vld1q_s16(input1_ptr + x + 8),
601                 }
602             };
603             const qsymm16x8x2_t input2_q =
604             {
605                 {
606                     vld1q_s16(input2_ptr + x),
607                     vld1q_s16(input2_ptr + x + 8),
608                 }
609             };
610 
611             // Dequantize inputs
612             const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
613             const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);
614 
615             const float32x4x4_t out_f32x4x4 =
616             {
617                 vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
618                 vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
619                 vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
620                 vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
621             };
622 
623             const qsymm16x8x2_t result = vquantize_qsymm16(out_f32x4x4, tmp_qua_info);
624             vst1q_s16(output_ptr + x, result.val[0]);
625             vst1q_s16(output_ptr + x + 8, result.val[1]);
626         }
627 
628         // Compute left-over elements
629         for(; x < window_end_x; ++x)
630         {
631             // Dequantize inputs
632             float tmp_in1 = static_cast<float>(*(input1_ptr + x)) * input1_qua_info.scale;
633             float tmp_in2 = static_cast<float>(*(input2_ptr + x)) * input2_qua_info.scale;
634             float tmp_f   = tmp_in1 * tmp_in2;
635 
636             // Quantize dst, lrintf() uses the same rounding mode as the vectorized quantization above
637             int32_t   tmp     = lrintf(tmp_f / tmp_qua_info.scale);
638             qsymm16_t tmp_qua = static_cast<qsymm16_t>((tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp));
639             *(output_ptr + x) = tmp_qua;
640         }
641     },
642     input1, input2, dst);
643 }
644 
645 void mul_QSYMM16_QSYMM16_S32(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, int scale)
646 {
647     ARM_COMPUTE_UNUSED(scale);
648 
649     // Create input windows
650     Window win        = window;
651     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
652     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
653 
654     // Clear X Dimension on execution window as we handle manually
655     win.set(Window::DimX, Window::Dimension(0, 1, 1));
656     input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
657     input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
658 
659     Iterator input1(src1, input1_win);
660     Iterator input2(src2, input2_win);
661     Iterator dst(out, win);
662 
663     const int  window_step_x  = 16;
664     const auto window_start_x = static_cast<int>(window.x().start());
665     const auto window_end_x   = static_cast<int>(window.x().end());
666 
667     execute_window_loop(
668         win, [&](const Coordinates &)
669     {
670         const auto input1_ptr = reinterpret_cast<const qsymm16_t *>(input1.ptr());
671         const auto input2_ptr = reinterpret_cast<const qsymm16_t *>(input2.ptr());
672         const auto output_ptr = reinterpret_cast<int32_t *>(dst.ptr());
673 
674         // Compute window_step_x elements per iteration
675         int x = window_start_x;
676         for(; x <= (window_end_x - window_step_x); x += window_step_x)
677         {
678             const qsymm16x8x2_t input1_q =
679             {
680                 {
681                     vld1q_s16(input1_ptr + x),
682                     vld1q_s16(input1_ptr + x + 8),
683                 }
684             };
685             const qsymm16x8x2_t input2_q =
686             {
687                 {
688                     vld1q_s16(input2_ptr + x),
689                     vld1q_s16(input2_ptr + x + 8),
690                 }
691             };
692 
693             const int32x4x4_t in1_s32 =
694             {
695                 {
696                     vmovl_s16(vget_low_s16(input1_q.val[0])),
697                     vmovl_s16(vget_high_s16(input1_q.val[0])),
698                     vmovl_s16(vget_low_s16(input1_q.val[1])),
699                     vmovl_s16(vget_high_s16(input1_q.val[1])),
700                 }
701             };
702             const int32x4x4_t in2_s32 =
703             {
704                 {
705                     vmovl_s16(vget_low_s16(input2_q.val[0])),
706                     vmovl_s16(vget_high_s16(input2_q.val[0])),
707                     vmovl_s16(vget_low_s16(input2_q.val[1])),
708                     vmovl_s16(vget_high_s16(input2_q.val[1])),
709                 }
710             };
711 
712             const int32x4x4_t result =
713             {
714                 {
715                     vmulq_s32(in1_s32.val[0], in2_s32.val[0]),
716                     vmulq_s32(in1_s32.val[1], in2_s32.val[1]),
717                     vmulq_s32(in1_s32.val[2], in2_s32.val[2]),
718                     vmulq_s32(in1_s32.val[3], in2_s32.val[3]),
719                 }
720             };
721 
722             vst1q_s32(output_ptr + x, result.val[0]);
723             vst1q_s32(output_ptr + x + 4, result.val[1]);
724             vst1q_s32(output_ptr + x + 8, result.val[2]);
725             vst1q_s32(output_ptr + x + 12, result.val[3]);
726         }
727 
728         // Compute left-over elements
729         for(; x < window_end_x; ++x)
730         {
731             int32_t tmp       = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));
732             *(output_ptr + x) = tmp;
733         }
734     },
735     input1, input2, dst);
736 }
737 
738 template <bool is_scale255, bool is_sat>
739 void mul_U8_U8_U8(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, int n)
740 {
741     // Create input windows
742     Window win        = window;
743     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
744     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
745 
746     // Clear X Dimension on execution window as we handle manually
747     win.set(Window::DimX, Window::Dimension(0, 1, 1));
748     input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
749     input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
750 
751     Iterator input1(src1, input1_win);
752     Iterator input2(src2, input2_win);
753     Iterator dst(out, win);
754 
755     const int  window_step_x  = 16 / sizeof(uint8_t);
756     const auto window_start_x = static_cast<int>(window.x().start());
757     const auto window_end_x   = static_cast<int>(window.x().end());
758 
759     execute_window_loop(
760         win, [&](const Coordinates &)
761     {
762         const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
763         const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
764         const auto output_ptr = reinterpret_cast<uint8_t *>(dst.ptr());
765 
766         // Compute window_step_x elements per iteration
767         int x = window_start_x;
768         for(; x <= (window_end_x - window_step_x); x += window_step_x)
769         {
770             const uint8x16_t ta1 = wrapper::vloadq(input1_ptr + x);
771             const uint8x16_t ta2 = wrapper::vloadq(input2_ptr + x);
772 
773             uint16x8_t       tmp1_high = vmovl_u8(vget_high_u8(ta1));
774             const uint16x8_t tmp2_high = vmovl_u8(vget_high_u8(ta2));
775             uint16x8_t       tmp1_low  = vmovl_u8(vget_low_u8(ta1));
776             const uint16x8_t tmp2_low  = vmovl_u8(vget_low_u8(ta2));
777 
778             tmp1_high = vmulq_u16(tmp1_high, tmp2_high);
779             tmp1_low  = vmulq_u16(tmp1_low, tmp2_low);
780 
781             if(is_scale255)
782             {
783                 tmp1_high = scale255_U16_U16(tmp1_high);
784                 tmp1_low  = scale255_U16_U16(tmp1_low);
785             }
786             else
787             {
788                 const int16x8_t vn = vdupq_n_s16(-n);
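                // Editorial note: a negative shift amount makes vshlq/vqshlq perform a right shift by n.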
789 
790                 if(is_sat)
791                 {
792                     tmp1_high = vqshlq_u16(tmp1_high, vn);
793                     tmp1_low  = vqshlq_u16(tmp1_low, vn);
794                 }
795                 else
796                 {
797                     tmp1_high = vshlq_u16(tmp1_high, vn);
798                     tmp1_low  = vshlq_u16(tmp1_low, vn);
799                 }
800             }
801             if(is_sat)
802             {
803                 vst1q_u8(output_ptr + x, vcombine_u8(vqmovn_u16(tmp1_low), vqmovn_u16(tmp1_high)));
804             }
805             else
806             {
807                 vst1q_u8(output_ptr + x, vcombine_u8(vmovn_u16(tmp1_low), vmovn_u16(tmp1_high)));
808             }
809         }
810 
811         // Compute left-over elements
812         for(; x < window_end_x; ++x)
813         {
814             uint16_t tmp = static_cast<uint16_t>(*(input1_ptr + x)) * static_cast<uint16_t>(*(input2_ptr + x));
815 
816             if(is_scale255)
817             {
818                 float tmp_f = static_cast<float>(tmp) * scale255_constant;
819                 tmp         = static_cast<uint16_t>(tmp_f + 0.5f);
820             }
821             else
822             {
823                 tmp >>= n;
824             }
825             if(is_sat && tmp > 255)
826             {
827                 tmp = 255;
828             }
829             *(output_ptr + x) = static_cast<uint8_t>(tmp);
830         }
831     },
832     input1, input2, dst);
833 }
834 
835 template <bool is_scale255, bool is_sat>
836 inline int16x8_t mul_S16_S16_S16_n_loop(const int16x8_t &src1, const int16x8_t &src2, int n)
837 {
838     int32x4_t       tmp1_high = vmovl_s16(vget_high_s16(src1));
839     const int32x4_t tmp2_high = vmovl_s16(vget_high_s16(src2));
840     int32x4_t       tmp1_low  = vmovl_s16(vget_low_s16(src1));
841     const int32x4_t tmp2_low  = vmovl_s16(vget_low_s16(src2));
842 
843     tmp1_high = vmulq_s32(tmp1_high, tmp2_high);
844     tmp1_low  = vmulq_s32(tmp1_low, tmp2_low);
845 
846     if(is_scale255)
847     {
848         tmp1_high = scale255_S32_S32(tmp1_high);
849         tmp1_low  = scale255_S32_S32(tmp1_low);
850     }
851     else
852     {
853         // Right shift amount
854         const int32x4_t vn = vdupq_n_s32(-n);
855         // Left shift amount
856         const int32x4_t vnl = vdupq_n_s32(n);
857         // Calculate conversion bit
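        // Editorial note: an arithmetic right shift rounds toward negative infinity; adding
        // sign * (2^n - 1), computed below as (sign << n) - sign, to negative products first makes the
        // final shift round toward zero instead.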
858         const uint32x4_t tmp1_high_u  = vreinterpretq_u32_s32(tmp1_high);
859         const uint32x4_t tmp1_low_u   = vreinterpretq_u32_s32(tmp1_low);
860         const uint32x4_t sign_high    = vshrq_n_u32(tmp1_high_u, 31);
861         const uint32x4_t sign_low     = vshrq_n_u32(tmp1_low_u, 31);
862         const int32x4_t  sign_high_s  = vreinterpretq_s32_u32(sign_high);
863         const int32x4_t  sign_low_s   = vreinterpretq_s32_u32(sign_low);
864         const int32x4_t  convert_high = vsubq_s32(vshlq_s32(sign_high_s, vnl), sign_high_s);
865         const int32x4_t  convert_low  = vsubq_s32(vshlq_s32(sign_low_s, vnl), sign_low_s);
866         if(is_sat)
867         {
868             tmp1_high = vqshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
869             tmp1_low  = vqshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
870         }
871         else
872         {
873             tmp1_high = vshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
874             tmp1_low  = vshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
875         }
876     }
877 
878     if(is_sat)
879     {
880         return vcombine_s16(vqmovn_s32(tmp1_low), vqmovn_s32(tmp1_high));
881     }
882     else
883     {
884         return vcombine_s16(vmovn_s32(tmp1_low), vmovn_s32(tmp1_high));
885     }
886 }
887 
888 template <bool is_scale255, bool is_sat>
889 inline int16x8x2_t mul_S16_S16_S16_n_k(const int16x8x2_t &src1, const int16x8x2_t &src2, int n)
890 {
891     const int16x8x2_t result =
892     {
893         {
894             // First 8 elements
895             mul_S16_S16_S16_n_loop<is_scale255, is_sat>(src1.val[0], src2.val[0], n),
896             // Second 8 elements
897             mul_S16_S16_S16_n_loop<is_scale255, is_sat>(src1.val[1], src2.val[1], n)
898         }
899     };
900 
901     return result;
902 }
903 
904 template <bool is_scale255, bool is_sat>
905 void mul_S16_S16_S16(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, int n)
906 {
907     // Create input windows
908     Window win        = window;
909     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
910     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
911 
912     // Clear X Dimension on execution window as we handle manually
913     win.set(Window::DimX, Window::Dimension(0, 1, 1));
914     input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
915     input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
916 
917     Iterator input1(src1, input1_win);
918     Iterator input2(src2, input2_win);
919     Iterator dst(out, win);
920 
921     const int  window_step_x  = 16;
922     const auto window_start_x = static_cast<int>(window.x().start());
923     const auto window_end_x   = static_cast<int>(window.x().end());
924 
925     execute_window_loop(
926         win, [&](const Coordinates &)
927     {
928         const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
929         const auto input2_ptr = reinterpret_cast<const int16_t *>(input2.ptr());
930         const auto output_ptr = reinterpret_cast<int16_t *>(dst.ptr());
931 
932         // Compute window_step_x elements per iteration
933         int x = window_start_x;
934         for(; x <= (window_end_x - window_step_x); x += window_step_x)
935         {
936             const int16x8x2_t ta1 =
937             {
938                 {
939                     vld1q_s16(input1_ptr + x),
940                     vld1q_s16(input1_ptr + x + 8),
941                 }
942             };
943             const int16x8x2_t ta2 =
944             {
945                 {
946                     vld1q_s16(input2_ptr + x),
947                     vld1q_s16(input2_ptr + x + 8),
948                 }
949             };
950             const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);
951 
952             vst1q_s16(output_ptr + x, result.val[0]);
953             vst1q_s16(output_ptr + x + 8, result.val[1]);
954         }
955 
956         // Compute left-over elements
957         for(; x < window_end_x; ++x)
958         {
959             int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));
960 
961             if(is_scale255)
962             {
963                 float tmp_f = static_cast<float>(tmp) * scale255_constant;
964 
965                 tmp = static_cast<int32_t>(tmp_f + 0.5f);
966             }
967             else
968             {
969                 if(tmp >= 0)
970                 {
971                     tmp >>= n;
972                 }
973                 else
974                 {
975                     uint32_t mask = (1u << n) - 1;
976                     tmp           = (tmp + static_cast<int32_t>(mask)) >> n;
977                 }
978             }
979             if(is_sat)
980             {
981                 tmp = (tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp);
982             }
983             *(output_ptr + x) = static_cast<int16_t>(tmp);
984         }
985     },
986     input1, input2, dst);
987 }
988 
989 template <bool is_sat>
990 inline int32x4_t mul_S32_S32_S32_n_loop(const int32x4_t &src1, const int32x4_t &src2, int n)
991 {
992     const int32x2_t input1_1 = vget_low_s32(src1);
993     const int32x2_t input2_1 = vget_low_s32(src2);
994     const int32x2_t input1_2 = vget_high_s32(src1);
995     const int32x2_t input2_2 = vget_high_s32(src2);
996 
997     int64x2_t tmp_1 = vmull_s32(input1_1, input2_1);
998     int64x2_t tmp_2 = vmull_s32(input1_2, input2_2);
999 
1000     // Apply scaling, conversion and rounding (round to zero)
1001     // Right shift amount
1002     const int64x2_t vn = vdupq_n_s64(-n);
1003     // Left shift amount
1004     const int64x2_t vnl = vdupq_n_s64(n);
1005     // Calculate conversion bit
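    // Editorial note: as in the S16 path, adding sign * (2^n - 1) before the arithmetic right shift turns
    // round-toward-negative-infinity into round-toward-zero.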
1006     const uint64x2_t tmp_1_u   = vreinterpretq_u64_s64(tmp_1);
1007     const uint64x2_t sign_1    = vshrq_n_u64(tmp_1_u, 63);
1008     const int64x2_t  sign_1_s  = vreinterpretq_s64_u64(sign_1);
1009     const int64x2_t  convert_1 = vsubq_s64(vshlq_s64(sign_1_s, vnl), sign_1_s);
1010 
1011     const uint64x2_t tmp_2_u   = vreinterpretq_u64_s64(tmp_2);
1012     const uint64x2_t sign_2    = vshrq_n_u64(tmp_2_u, 63);
1013     const int64x2_t  sign_2_s  = vreinterpretq_s64_u64(sign_2);
1014     const int64x2_t  convert_2 = vsubq_s64(vshlq_s64(sign_2_s, vnl), sign_2_s);
1015     if(is_sat)
1016     {
1017         tmp_1 = vqshlq_s64(vaddq_s64(tmp_1, convert_1), vn);
1018         tmp_2 = vqshlq_s64(vaddq_s64(tmp_2, convert_2), vn);
1019         return vcombine_s32(vqmovn_s64(tmp_1), vqmovn_s64(tmp_2));
1020     }
1021     else
1022     {
1023         tmp_1 = vshlq_s64(vaddq_s64(tmp_1, convert_1), vn);
1024         tmp_2 = vshlq_s64(vaddq_s64(tmp_2, convert_2), vn);
1025         return vcombine_s32(vmovn_s64(tmp_1), vmovn_s64(tmp_2));
1026     }
1027 }
1028 
1029 template <bool is_sat>
1030 inline int32x4x2_t mul_S32_S32_S32_n_k(const int32x4x2_t &src1, const int32x4x2_t &src2, int n)
1031 {
1032     const int32x4x2_t result =
1033     {
1034         {
1035             // First 4 elements
1036             mul_S32_S32_S32_n_loop<is_sat>(src1.val[0], src2.val[0], n),
1037             // Second 4 elements
1038             mul_S32_S32_S32_n_loop<is_sat>(src1.val[1], src2.val[1], n)
1039         }
1040     };
1041 
1042     return result;
1043 }
1044 
1045 template <bool is_sat>
1046 void mul_S32_S32_S32(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, int n)
1047 {
1048     // Create input windows
1049     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
1050     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
1051 
1052     // Clear X Dimension on execution window as we handle manually
1053     Window win = window;
1054     win.set(Window::DimX, Window::Dimension(0, 1, 1));
1055 
1056     const int  window_step_x         = 8;
1057     const auto window_start_x        = static_cast<int>(window.x().start());
1058     const auto window_end_x          = static_cast<int>(window.x().end());
1059     const bool is_broadcast_across_x = src1->info()->tensor_shape().x() != src2->info()->tensor_shape().x();
1060 
1061     if(is_broadcast_across_x)
1062     {
1063         const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
1064         Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
1065         Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
1066         const ITensor *broadcast_tensor     = is_broadcast_input_2 ? src2 : src1;
1067         const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? src2 : src1;
1068 
1069         // Clear X Dimension on execution window as we handle manually
1070         non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1071 
1072         Iterator broadcast_input(broadcast_tensor, broadcast_win);
1073         Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
1074         Iterator dst(out, win);
1075 
1076         execute_window_loop(
1077             win, [&](const Coordinates &)
1078         {
1079             const auto non_broadcast_input_ptr = reinterpret_cast<const int32_t *>(non_broadcast_input.ptr());
1080             const auto output_ptr              = reinterpret_cast<int32_t *>(dst.ptr());
1081 
1082             const int32_t broadcast_value     = *reinterpret_cast<const int32_t *>(broadcast_input.ptr());
1083             const auto    broadcast_value_vec = vdupq_n_s32(broadcast_value);
1084 
1085             // Compute window_step_x elements per iteration
1086             int x = window_start_x;
1087             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1088             {
1089                 const int32x4x2_t broadcast_v =
1090                 {
1091                     {
1092                         broadcast_value_vec,
1093                         broadcast_value_vec,
1094                     }
1095                 };
1096                 const int32x4x2_t non_broadcast_v =
1097                 {
1098                     {
1099                         vld1q_s32(non_broadcast_input_ptr + x),
1100                         vld1q_s32(non_broadcast_input_ptr + x + 4),
1101                     }
1102                 };
1103                 const int32x4x2_t result = mul_S32_S32_S32_n_k<is_sat>(broadcast_v, non_broadcast_v, n);
1104 
1105                 vst1q_s32(output_ptr + x, result.val[0]);
1106                 vst1q_s32(output_ptr + x + 4, result.val[1]);
1107             }
1108 
1109             // Compute left-over elements
1110             for(; x < window_end_x; ++x)
1111             {
1112                 int64_t tmp = static_cast<int64_t>(broadcast_value) * static_cast<int64_t>(*(non_broadcast_input_ptr + x));
1113 
1114                 if(tmp >= 0)
1115                 {
1116                     tmp >>= n;
1117                 }
1118                 else
1119                 {
1120                     uint64_t mask = ((uint64_t)1u << n) - 1;
1121                     tmp           = (tmp + static_cast<int64_t>(mask)) >> n;
1122                 }
1123                 if(is_sat)
1124                 {
1125                     tmp = utility::clamp<int64_t, int32_t>(tmp);
1126                 }
1127                 *(output_ptr + x) = static_cast<int32_t>(tmp);
1128             }
1129         },
1130         broadcast_input, non_broadcast_input, dst);
1131     }
1132     else
1133     {
1134         // Clear X Dimension on execution window as we handle manually
1135         input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1136         input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1137 
1138         Iterator input1(src1, input1_win);
1139         Iterator input2(src2, input2_win);
1140         Iterator dst(out, win);
1141 
1142         execute_window_loop(
1143             win, [&](const Coordinates &)
1144         {
1145             const auto input1_ptr = reinterpret_cast<const int32_t *>(input1.ptr());
1146             const auto input2_ptr = reinterpret_cast<const int32_t *>(input2.ptr());
1147             const auto output_ptr = reinterpret_cast<int32_t *>(dst.ptr());
1148 
1149             // Compute window_step_x elements per iteration
1150             int x = window_start_x;
1151             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1152             {
1153                 const int32x4x2_t ta1 =
1154                 {
1155                     {
1156                         vld1q_s32(input1_ptr + x),
1157                         vld1q_s32(input1_ptr + x + 4),
1158                     }
1159                 };
1160                 const int32x4x2_t ta2 =
1161                 {
1162                     {
1163                         vld1q_s32(input2_ptr + x),
1164                         vld1q_s32(input2_ptr + x + 4),
1165                     }
1166                 };
1167                 const int32x4x2_t result = mul_S32_S32_S32_n_k<is_sat>(ta1, ta2, n);
1168 
1169                 vst1q_s32(output_ptr + x, result.val[0]);
1170                 vst1q_s32(output_ptr + x + 4, result.val[1]);
1171             }
1172 
1173             // Compute left-over elements
1174             for(; x < window_end_x; ++x)
1175             {
1176                 int64_t tmp = static_cast<int64_t>(*(input1_ptr + x)) * static_cast<int64_t>(*(input2_ptr + x));
1177 
1178                 if(tmp >= 0)
1179                 {
1180                     tmp >>= n;
1181                 }
1182                 else
1183                 {
1184                     uint64_t mask = ((uint64_t)1u << n) - 1;
1185                     tmp           = (tmp + static_cast<int64_t>(mask)) >> n;
1186                 }
1187                 if(is_sat)
1188                 {
1189                     tmp = utility::clamp<int64_t, int32_t>(tmp);
1190                 }
1191                 *(output_ptr + x) = static_cast<int32_t>(tmp);
1192             }
1193         },
1194         input1, input2, dst);
1195     }
1196 }
1197 
1198 void mul_F32_F32_F32(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, float scale)
1199 {
1200     // Create input windows
1201     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
1202     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
1203 
1204     // Clear X Dimension on execution window as we handle manually
1205     Window win = window;
1206     win.set(Window::DimX, Window::Dimension(0, 1, 1));
1207 
1208     constexpr int window_step_x         = 16 / sizeof(float);
1209     const auto    window_start_x        = static_cast<int>(window.x().start());
1210     const auto    window_end_x          = static_cast<int>(window.x().end());
1211     const bool    is_broadcast_across_x = src1->info()->tensor_shape().x() != src2->info()->tensor_shape().x();
1212 
1213     using ExactTagType = typename wrapper::traits::neon_vector<float, window_step_x>::tag_type;
1214 
1215     if(is_broadcast_across_x)
1216     {
1217         const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
1218         Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
1219         Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
1220         const ITensor *broadcast_tensor     = is_broadcast_input_2 ? src2 : src1;
1221         const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? src2 : src1;
1222 
1223         // Clear X Dimension on execution window as we handle manually
1224         non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1225 
1226         Iterator broadcast_input(broadcast_tensor, broadcast_win);
1227         Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
1228         Iterator dst(out, win);
1229 
1230         execute_window_loop(
1231             win, [&](const Coordinates &)
1232         {
1233             const auto non_broadcast_input_ptr = reinterpret_cast<const float *>(non_broadcast_input.ptr());
1234             const auto output_ptr              = reinterpret_cast<float *>(dst.ptr());
1235 
1236             const float broadcast_value     = *reinterpret_cast<const float *>(broadcast_input.ptr());
1237             const auto  broadcast_value_vec = wrapper::vdup_n(broadcast_value, ExactTagType{});
1238             const auto  scale_vec           = wrapper::vdup_n(scale, ExactTagType{});
1239 
1240             // Compute window_step_x elements per iteration
1241             int x = window_start_x;
1242             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1243             {
1244                 const auto non_broadcast_v = wrapper::vloadq(non_broadcast_input_ptr + x);
1245                 auto       res             = wrapper::vmul(wrapper::vmul(broadcast_value_vec, non_broadcast_v), scale_vec);
1246                 wrapper::vstore(output_ptr + x, res);
1247             }
1248 
1249             // Compute left-over elements
1250             for(; x < window_end_x; ++x)
1251             {
1252                 const auto non_broadcast_v = *(non_broadcast_input_ptr + x);
1253                 *(output_ptr + x)          = broadcast_value * non_broadcast_v * scale;
1254             }
1255         },
1256         broadcast_input, non_broadcast_input, dst);
1257     }
1258     else
1259     {
1260         // Clear X Dimension on execution window as we handle manually
1261         input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1262         input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1263 
1264         Iterator input1(src1, input1_win);
1265         Iterator input2(src2, input2_win);
1266         Iterator dst(out, win);
1267 
1268         execute_window_loop(
1269             win, [&](const Coordinates &)
1270         {
1271             const auto input1_ptr = reinterpret_cast<const float *>(input1.ptr());
1272             const auto input2_ptr = reinterpret_cast<const float *>(input2.ptr());
1273             const auto output_ptr = reinterpret_cast<float *>(dst.ptr());
1274 
1275             // Compute window_step_x elements per iteration
1276             int x = window_start_x;
1277             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1278             {
1279                 const auto ta1       = wrapper::vloadq(input1_ptr + x);
1280                 const auto ta2       = wrapper::vloadq(input2_ptr + x);
1281                 const auto scale_vec = wrapper::vdup_n(scale, ExactTagType{});
1282                 const auto res       = wrapper::vmul(wrapper::vmul(ta1, ta2), scale_vec);
1283                 wrapper::vstore(output_ptr + x, res);
1284             }
1285 
1286             // Compute left-over elements
1287             for(; x < window_end_x; ++x)
1288             {
1289                 const auto ta1    = *(input1_ptr + x);
1290                 const auto ta2    = *(input2_ptr + x);
1291                 *(output_ptr + x) = ta1 * ta2 * scale;
1292             }
1293         },
1294         input1, input2, dst);
1295     }
1296 }
1297 
1298 void c_mul_F32_F32_F32_n(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window)
1299 {
1300     // Create input windows
1301     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
1302     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
1303 
1304     // Clear X Dimension on execution window as we handle manually
1305     Window win = window;
1306     win.set(Window::DimX, Window::Dimension(0, 1, 1));
1307 
1308     constexpr int window_step_x         = 8 / sizeof(float);
1309     const auto    window_start_x        = static_cast<int>(window.x().start());
1310     const auto    window_end_x          = static_cast<int>(window.x().end());
1311     const bool    is_broadcast_across_x = src1->info()->tensor_shape().x() != src2->info()->tensor_shape().x();
1312 
1313     using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;
1314 
1315     if(is_broadcast_across_x)
1316     {
1317         const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
1318         Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
1319         Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
1320         const ITensor *broadcast_tensor     = is_broadcast_input_2 ? src2 : src1;
1321         const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? src2 : src1;
1322 
1323         // Clear X Dimension on execution window as we handle manually
1324         non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1325 
1326         Iterator broadcast_input(broadcast_tensor, broadcast_win);
1327         Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
1328         Iterator dst(out, win);
1329 
1330         execute_window_loop(
1331             win, [&](const Coordinates &)
1332         {
1333             const auto non_broadcast_input_ptr = reinterpret_cast<const float *>(non_broadcast_input.ptr());
1334             const auto output_ptr              = reinterpret_cast<float *>(dst.ptr());
1335 
1336             const float broadcast_value = *reinterpret_cast<const float *>(broadcast_input.ptr());
1337 
1338             // Compute window_step_x elements per iteration
1339             int x = window_start_x;
1340             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1341             {
1342                 const auto  a = wrapper::vloadq(non_broadcast_input_ptr + 2 * x);
1343                 float32x4_t b = vdupq_n_f32(broadcast_value);
1344 
1345                 const float32x4_t mask  = { -1.0f, 1.0f, -1.0f, 1.0f };
1346                 const float32x2_t tmp00 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
1347                 const float32x2_t tmp01 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});
1348                 const float32x2_t tmp10 = wrapper::vdup_n(wrapper::vgetlane(a, 2), ExactTagType{});
1349                 const float32x2_t tmp11 = wrapper::vdup_n(wrapper::vgetlane(a, 3), ExactTagType{});
1350 
1351                 const float32x4_t tmp0 = wrapper::vcombine(tmp00, tmp10);
1352                 const float32x4_t tmp1 = wrapper::vcombine(tmp01, tmp11);
1353 
1354                 float32x4_t res = wrapper::vmul(tmp0, b);
1355                 b               = wrapper::vmul(b, mask);
1356 
1357                 res = wrapper::vmla(res, tmp1, b);
1358                 wrapper::vstore(output_ptr + 2 * x, res);
1359             }
1360 
1361             // Compute left-over elements
1362             for(; x < window_end_x; ++x)
1363             {
1364                 const auto non_broadcast_value0 = *(non_broadcast_input_ptr + 2 * x);
1365                 const auto non_broadcast_value1 = *(non_broadcast_input_ptr + 2 * x + 1);
1366                 auto       res1                 = broadcast_value * (non_broadcast_value0 - non_broadcast_value1);
1367                 auto       res2                 = broadcast_value * (non_broadcast_value1 + non_broadcast_value0);
1368                 *(output_ptr + 2 * x)           = res1;
1369                 *(output_ptr + 2 * x + 1)       = res2;
1370             }
1371         },
1372         broadcast_input, non_broadcast_input, dst);
1373     }
1374     else
1375     {
1376         // Clear X Dimension on execution window as we handle manually
1377         input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1378         input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1379 
1380         Iterator input1(src1, input1_win);
1381         Iterator input2(src2, input2_win);
1382         Iterator dst(out, win);
1383 
1384         execute_window_loop(
1385             win, [&](const Coordinates &)
1386         {
1387             const auto input1_ptr = reinterpret_cast<const float *>(input1.ptr());
1388             const auto input2_ptr = reinterpret_cast<const float *>(input2.ptr());
1389             const auto output_ptr = reinterpret_cast<float *>(dst.ptr());
1390 
1391             // Compute window_step_x elements per iteration
1392             int x = window_start_x;
1393             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1394             {
1395                 const float32x4_t a = wrapper::vloadq(input1_ptr + 2 * x);
1396                 float32x4_t       b = wrapper::vloadq(input2_ptr + 2 * x);
1397 
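                     // Complex multiply of interleaved (real, imag) pairs: tmp0 holds the real parts of 'a'
                     // duplicated per pair and tmp1 the imaginary parts, so that
                     //   res = a_r * b + a_i * (rev64(b) * mask) = (a_r*b_r - a_i*b_i, a_r*b_i + a_i*b_r)
                     // for each complex element.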
1398                 const float32x4_t mask  = { -1.0f, 1.0f, -1.0f, 1.0f };
1399                 const float32x2_t tmp00 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
1400                 const float32x2_t tmp01 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});
1401                 const float32x2_t tmp10 = wrapper::vdup_n(wrapper::vgetlane(a, 2), ExactTagType{});
1402                 const float32x2_t tmp11 = wrapper::vdup_n(wrapper::vgetlane(a, 3), ExactTagType{});
1403 
1404                 const float32x4_t tmp0 = wrapper::vcombine(tmp00, tmp10);
1405                 const float32x4_t tmp1 = wrapper::vcombine(tmp01, tmp11);
1406 
1407                 float32x4_t res = wrapper::vmul(tmp0, b);
1408 
1409                 b = wrapper::vrev64(b);
1410                 b = wrapper::vmul(b, mask);
1411 
1412                 res = wrapper::vmla(res, tmp1, b);
1413                 wrapper::vstore(output_ptr + 2 * x, res);
1414             }
1415 
1416             // Compute left-over elements
1417             for(; x < window_end_x; ++x)
1418             {
1419                 const auto a0             = *(input1_ptr + 2 * x);
1420                 const auto a1             = *(input1_ptr + 2 * x + 1);
1421                 const auto b0             = *(input2_ptr + 2 * x);
1422                 const auto b1             = *(input2_ptr + 2 * x + 1);
1423                 auto       res1           = a0 * b0 - a1 * b1;
1424                 auto       res2           = a0 * b1 + a1 * b0;
1425                 *(output_ptr + 2 * x)     = res1;
1426                 *(output_ptr + 2 * x + 1) = res2;
1427             }
1428         },
1429         input1, input2, dst);
1430     }
1431 }
1432 
1433 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1434 void mul_F16_F16_F16(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, float scale)
1435 {
1436     // Create input windows
1437     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
1438     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
1439 
1440     // Clear X Dimension on execution window as we handle manually
1441     Window win = window;
1442     win.set(Window::DimX, Window::Dimension(0, 1, 1));
1443     constexpr int window_step_x         = 16;
1444     const auto    window_start_x        = static_cast<int>(window.x().start());
1445     const auto    window_end_x          = static_cast<int>(window.x().end());
1446     const bool    is_broadcast_across_x = src1->info()->tensor_shape().x() != src2->info()->tensor_shape().x();
1447     if(is_broadcast_across_x)
1448     {
1449         const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
1450         Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
1451         Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
1452         const ITensor *broadcast_tensor     = is_broadcast_input_2 ? src2 : src1;
1453         const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? src2 : src1;
1454         // Clear X Dimension on execution window as we handle manually
1455         non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1456         Iterator broadcast_input(broadcast_tensor, broadcast_win);
1457         Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
1458         Iterator dst(out, win);
1459         execute_window_loop(
1460             win, [&](const Coordinates &)
1461         {
1462             const auto          non_broadcast_input_ptr = reinterpret_cast<const float16_t *>(non_broadcast_input.ptr());
1463             const auto          output_ptr              = reinterpret_cast<float16_t *>(dst.ptr());
1464             const auto          broadcast_value         = *reinterpret_cast<const float16_t *>(broadcast_input.ptr());
1465             const float16x8x2_t broadcast_value_vec     =
1466             {
1467                 {
1468                     vdupq_n_f16(broadcast_value),
1469                     vdupq_n_f16(broadcast_value),
1470                 }
1471             };
1472             const auto scale_vec = vdupq_n_f16(scale);
1473             // Compute window_step_x elements per iteration
1474             int x = window_start_x;
1475             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1476             {
1477                 const float16x8x2_t non_broadcast_v =
1478                 {
1479                     {
1480                         vld1q_f16(non_broadcast_input_ptr + x),
1481                         vld1q_f16(non_broadcast_input_ptr + x + 8),
1482                     }
1483                 };
1484                 const float16x8x2_t result =
1485                 {
1486                     {
1487                         vmulq_f16(vmulq_f16(broadcast_value_vec.val[0], non_broadcast_v.val[0]), scale_vec),
1488                         vmulq_f16(vmulq_f16(broadcast_value_vec.val[1], non_broadcast_v.val[1]), scale_vec),
1489                     }
1490                 };
1491                 vst1q_f16(output_ptr + x, result.val[0]);
1492                 vst1q_f16(output_ptr + x + 8, result.val[1]);
1493             }
1494             // Compute left-over elements
1495             for(; x < window_end_x; ++x)
1496             {
1497                 const auto non_broadcast_v = *(non_broadcast_input_ptr + x);
1498                 *(output_ptr + x)          = broadcast_value * non_broadcast_v * scale;
1499             }
1500         },
1501         broadcast_input, non_broadcast_input, dst);
1502     }
1503     else
1504     {
1505         input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1506         input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1507         Iterator input1(src1, input1_win);
1508         Iterator input2(src2, input2_win);
1509         Iterator dst(out, win);
1510         execute_window_loop(
1511             win, [&](const Coordinates &)
1512         {
1513             const auto input1_ptr = reinterpret_cast<const float16_t *>(input1.ptr());
1514             const auto input2_ptr = reinterpret_cast<const float16_t *>(input2.ptr());
1515             const auto output_ptr = reinterpret_cast<float16_t *>(dst.ptr());
1516             // Compute window_step_x elements per iteration
1517             int x = window_start_x;
1518             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1519             {
1520                 const float16x8x2_t ta1 =
1521                 {
1522                     {
1523                         vld1q_f16(input1_ptr + x),
1524                         vld1q_f16(input1_ptr + x + 8),
1525                     }
1526                 };
1527                 const float16x8x2_t ta2 =
1528                 {
1529                     {
1530                         vld1q_f16(input2_ptr + x),
1531                         vld1q_f16(input2_ptr + x + 8),
1532                     }
1533                 };
1534                 const float16x8_t   scale_vec = vdupq_n_f16(scale);
1535                 const float16x8x2_t result    =
1536                 {
1537                     {
1538                         vmulq_f16(vmulq_f16(ta1.val[0], ta2.val[0]), scale_vec),
1539                         vmulq_f16(vmulq_f16(ta1.val[1], ta2.val[1]), scale_vec),
1540                     }
1541                 };
1542                 vst1q_f16(output_ptr + x, result.val[0]);
1543                 vst1q_f16(output_ptr + x + 8, result.val[1]);
1544             }
1545             // Compute left-over elements
1546             for(; x < window_end_x; ++x)
1547             {
1548                 const auto ta1    = *(input1_ptr + x);
1549                 const auto ta2    = *(input2_ptr + x);
1550                 *(output_ptr + x) = ta1 * ta2 * scale;
1551             }
1552         },
1553         input1, input2, dst);
1554     }
1555 }
1556 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
1557 
1558 template <bool is_scale255, bool is_sat>
1559 void mul_U8_U8_S16(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, int n)
1560 {
1561     // Create input windows
1562     Window win        = window;
1563     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
1564     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
1565 
1566     // Clear X Dimension on execution window as we handle manually
1567     win.set(Window::DimX, Window::Dimension(0, 1, 1));
1568     input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1569     input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1570 
1571     Iterator input1(src1, input1_win);
1572     Iterator input2(src2, input2_win);
1573     Iterator dst(out, win);
1574 
1575     const int  window_step_x  = 16 / sizeof(uint8_t);
1576     const auto window_start_x = static_cast<int>(window.x().start());
1577     const auto window_end_x   = static_cast<int>(window.x().end());
1578 
1579     execute_window_loop(
1580         win, [&](const Coordinates &)
1581     {
1582         const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
1583         const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
1584         const auto output_ptr = reinterpret_cast<int16_t *>(dst.ptr());
1585 
1586         // Compute window_step_x elements per iteration
1587         int x = window_start_x;
1588         for(; x <= (window_end_x - window_step_x); x += window_step_x)
1589         {
1590             const uint8x16_t bv = wrapper::vloadq(input2_ptr + x);
1591             const uint8x16_t av = wrapper::vloadq(input1_ptr + x);
1592 
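                 // Widen both u8 inputs to u16 before multiplying: 255 * 255 = 65025 still fits in a u16 lane,
                 // and since the product is never negative, saturation below only needs to clamp against SHRT_MAX.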
1593             uint16x8_t tmp_low  = vmovl_u8(vget_low_u8(av));
1594             uint16x8_t tmp_high = vmovl_u8(vget_high_u8(av));
1595             tmp_low             = vmulq_u16(tmp_low, vmovl_u8(vget_low_u8(bv)));
1596             tmp_high            = vmulq_u16(tmp_high, vmovl_u8(vget_high_u8(bv)));
1597 
1598             if(is_scale255)
1599             {
1600                 tmp_low  = scale255_U16_U16(tmp_low);
1601                 tmp_high = scale255_U16_U16(tmp_high);
1602             }
1603             else
1604             {
1605                 const int16x8_t vn = vdupq_n_s16(-n);
1606 
1607                 if(is_sat)
1608                 {
1609                     tmp_low  = vqshlq_u16(tmp_low, vn);
1610                     tmp_high = vqshlq_u16(tmp_high, vn);
1611                 }
1612                 else
1613                 {
1614                     tmp_low  = vshlq_u16(tmp_low, vn);
1615                     tmp_high = vshlq_u16(tmp_high, vn);
1616                 }
1617             }
1618 
1619             if(is_sat)
1620             {
1621                 static const uint16x8_t max = vdupq_n_u16(SHRT_MAX);
1622 
1623                 tmp_low  = vminq_u16(tmp_low, max);
1624                 tmp_high = vminq_u16(tmp_high, max);
1625             }
1626 
1627             vst1q_s16(output_ptr + x, vreinterpretq_s16_u16(tmp_low));
1628             vst1q_s16(output_ptr + x + 8, vreinterpretq_s16_u16(tmp_high));
1629         }
1630 
1631         // Compute left-over elements
1632         for(; x < window_end_x; ++x)
1633         {
1634             int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));
1635 
1636             if(is_scale255)
1637             {
1638                 float tmp_f = static_cast<float>(tmp) * scale255_constant;
1639                 tmp         = static_cast<int32_t>(tmp_f + 0.5f);
1640             }
1641             else
1642             {
1643                 tmp >>= n;
1644             }
1645 
1646             if(is_sat)
1647             {
1648                 tmp = (tmp > SHRT_MAX) ? SHRT_MAX : tmp;
1649             }
1650 
1651             *(output_ptr + x) = static_cast<int16_t>(tmp);
1652         }
1653     },
1654     input1, input2, dst);
1655 }
1656 
1657 template <bool is_scale255, bool is_sat>
1658 void mul_S16_U8_S16(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, int n)
1659 {
1660     // Create input windows
1661     Window win        = window;
1662     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
1663     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
1664 
1665     // Clear X Dimension on execution window as we handle manually
1666     win.set(Window::DimX, Window::Dimension(0, 1, 1));
1667     input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1668     input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1669 
1670     Iterator input1(src1, input1_win);
1671     Iterator input2(src2, input2_win);
1672     Iterator dst(out, win);
1673 
1674     const int  window_step_x  = 16;
1675     const auto window_start_x = static_cast<int>(window.x().start());
1676     const auto window_end_x   = static_cast<int>(window.x().end());
1677 
1678     execute_window_loop(
1679         win, [&](const Coordinates &)
1680     {
1681         const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
1682         const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
1683         const auto output_ptr = reinterpret_cast<int16_t *>(dst.ptr());
1684 
1685         // Compute window_step_x elements per iteration
1686         int x = window_start_x;
1687         for(; x <= (window_end_x - window_step_x); x += window_step_x)
1688         {
1689             const int16x8x2_t ta1 =
1690             {
1691                 {
1692                     vld1q_s16(input1_ptr + x),
1693                     vld1q_s16(input1_ptr + x + 8),
1694                 }
1695             };
1696             const uint8x8x2_t ta2u =
1697             {
1698                 {
1699                     vld1_u8(input2_ptr + x),
1700                     vld1_u8(input2_ptr + x + 8),
1701                 }
1702             };
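                 // Widen the u8 values to u16 and reinterpret as s16: the inputs are at most 255,
                 // so the sign bit can never be set by the conversion.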
1703             const int16x8x2_t ta2 =
1704             {
1705                 {
1706                     vreinterpretq_s16_u16(vmovl_u8(ta2u.val[0])),
1707                     vreinterpretq_s16_u16(vmovl_u8(ta2u.val[1]))
1708                 }
1709             };
1710 
1711             const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);
1712 
1713             vst1q_s16(output_ptr + x, result.val[0]);
1714             vst1q_s16(output_ptr + x + 8, result.val[1]);
1715         }
1716 
1717         // Compute left-over elements
1718         for(; x < window_end_x; ++x)
1719         {
1720             int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));
1721 
1722             if(is_scale255)
1723             {
1724                 float tmp_f = static_cast<float>(tmp) * scale255_constant;
1725 
1726                 tmp = static_cast<int32_t>(tmp_f + 0.5f);
1727             }
1728             else
1729             {
1730                 if(tmp >= 0)
1731                 {
1732                     tmp >>= n;
1733                 }
1734                 else
1735                 {
1736                     uint32_t mask = (1u << n) - 1;
1737                     tmp           = (tmp + static_cast<int32_t>(mask)) >> n;
1738                 }
1739             }
1740             if(is_sat)
1741             {
1742                 tmp = (tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp);
1743             }
1744             *(output_ptr + x) = static_cast<int16_t>(tmp);
1745         }
1746     },
1747     input1, input2, dst);
1748 }
1749 
1750 template <bool is_scale255, bool is_sat>
1751 void mul_U8_S16_S16(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, int n)
1752 {
1753     // Simply swap the two input buffers
1754     mul_S16_U8_S16<is_scale255, is_sat>(src2, src1, out, window, n);
1755 }
1756 } // namespace
1757 
1758 void CpuMulKernel::configure(ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
1759 {
1760     ARM_COMPUTE_UNUSED(rounding_policy);
1761     ARM_COMPUTE_ERROR_ON_NULLPTR(src1, src2, dst);
1762 
1763     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src1, src2, dst, scale, overflow_policy, rounding_policy));
1764 
1765     const TensorShape &out_shape = TensorShape::broadcast_shape(src1->tensor_shape(), src2->tensor_shape());
1766 
1767     // Auto initialize dst if not initialized
1768     set_shape_if_empty(*dst, out_shape);
1769 
1770     _scale          = scale;
1771     _scale_exponent = 0;
1772     _func_quantized = nullptr;
1773     _func_int       = nullptr;
1774     _func_float     = nullptr;
1775 
1776     bool is_scale_255 = false;
1777     // Check and validate scaling factor
1778     if(std::abs(scale - scale255_constant) < 0.00001f)
1779     {
1780         is_scale_255 = true;
1781     }
1782     else
1783     {
1784         int exponent = 0;
1785 
1786         std::frexp(scale, &exponent);
1787 
1788         // Store the positive exponent. We know that we compute 1/2^n
1789         // Additionally, subtract 1 to compensate for the fact that frexp() uses a mantissa of 0.5
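             // e.g. scale = 1/8 = 0.5 * 2^-2: frexp() returns exponent = -2, so _scale_exponent = |-2 - 1| = 3 and scale = 1/2^3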
1790         _scale_exponent = std::abs(exponent - 1);
1791     }
1792 
1793     const DataType dt_input1 = src1->data_type();
1794     const DataType dt_input2 = src2->data_type();
1795     const DataType dt_output = dst->data_type();
1796     const bool     is_sat    = (overflow_policy == ConvertPolicy::SATURATE);
1797 
1798     switch(dt_input1)
1799     {
1800         case DataType::QASYMM8:
1801             if(dt_input2 == DataType::QASYMM8 && dt_output == DataType::QASYMM8)
1802             {
1803                 if(mul_q8_neon_fixedpoint_possible(src1, src2, dst, scale))
1804                 {
1805                     _func_quantized = &mul_q8_neon_fixedpoint<uint8_t>;
1806                 }
1807                 else
1808                 {
1809                     _func_quantized = &mul_saturate_quantized_8<uint8_t>;
1810                 }
1811             }
1812             break;
1813         case DataType::QASYMM8_SIGNED:
1814             if(dt_input2 == DataType::QASYMM8_SIGNED)
1815             {
1816                 if(mul_q8_neon_fixedpoint_possible(src1, src2, dst, scale))
1817                 {
1818                     _func_quantized = &mul_q8_neon_fixedpoint<int8_t>;
1819                 }
1820                 else
1821                 {
1822                     _func_quantized = &mul_saturate_quantized_8<int8_t>;
1823                 }
1824             }
1825             break;
1826         case DataType::QSYMM16:
1827             if(dt_input2 == DataType::QSYMM16 && dt_output == DataType::QSYMM16)
1828             {
1829                 _func_quantized = &mul_saturate_QSYMM16_QSYMM16_QSYMM16;
1830             }
1831             else if(dt_input2 == DataType::QSYMM16 && dt_output == DataType::S32)
1832             {
1833                 _func_int = &mul_QSYMM16_QSYMM16_S32;
1834             }
1835             break;
1836         case DataType::S16:
1837             if(DataType::U8 == dt_input2 && DataType::S16 == dt_output)
1838             {
1839                 if(is_scale_255)
1840                 {
1841                     _func_int = is_sat ? &mul_S16_U8_S16<true, true> : &mul_S16_U8_S16<true, false>;
1842                 }
1843                 else
1844                 {
1845                     _func_int = is_sat ? &mul_S16_U8_S16<false, true> : &mul_S16_U8_S16<false, false>;
1846                 }
1847             }
1848             if(DataType::S16 == dt_input2 && DataType::S16 == dt_output)
1849             {
1850                 if(is_scale_255)
1851                 {
1852                     _func_int = is_sat ? &mul_S16_S16_S16<true, true> : &mul_S16_S16_S16<true, false>;
1853                 }
1854                 else
1855                 {
1856                     _func_int = is_sat ? &mul_S16_S16_S16<false, true> : &mul_S16_S16_S16<false, false>;
1857                 }
1858             }
1859             break;
1860         case DataType::S32:
1861             if(DataType::S32 == dt_input2 && DataType::S32 == dt_output)
1862             {
1863                 _func_int = is_sat ? &mul_S32_S32_S32<true> : &mul_S32_S32_S32<false>;
1864             }
1865             break;
1866         case DataType::U8:
1867             if(DataType::U8 == dt_input2 && DataType::U8 == dt_output)
1868             {
1869                 if(is_scale_255)
1870                 {
1871                     _func_int = is_sat ? &mul_U8_U8_U8<true, true> : &mul_U8_U8_U8<true, false>;
1872                 }
1873                 else
1874                 {
1875                     _func_int = is_sat ? &mul_U8_U8_U8<false, true> : &mul_U8_U8_U8<false, false>;
1876                 }
1877             }
1878             else if(DataType::U8 == dt_input2 && DataType::S16 == dt_output)
1879             {
1880                 if(is_scale_255)
1881                 {
1882                     _func_int = is_sat ? &mul_U8_U8_S16<true, true> : &mul_U8_U8_S16<true, false>;
1883                 }
1884                 else
1885                 {
1886                     _func_int = is_sat ? &mul_U8_U8_S16<false, true> : &mul_U8_U8_S16<false, false>;
1887                 }
1888             }
1889             else if(DataType::S16 == dt_input2 && DataType::S16 == dt_output)
1890             {
1891                 if(is_scale_255)
1892                 {
1893                     _func_int = is_sat ? &mul_U8_S16_S16<true, true> : &mul_U8_S16_S16<true, false>;
1894                 }
1895                 else
1896                 {
1897                     _func_int = is_sat ? &mul_U8_S16_S16<false, true> : &mul_U8_S16_S16<false, false>;
1898                 }
1899             }
1900             break;
1901 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1902         case DataType::F16:
1903             _func_float = &mul_F16_F16_F16;
1904             break;
1905 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
1906         case DataType::F32:
1907             _func_float = &mul_F32_F32_F32;
1908             break;
1909         default:
1910             ARM_COMPUTE_ERROR("You called with the wrong img formats");
1911     }
1912 
1913     // Configure kernel window
1914     Window win;
1915     std::tie(win, _split_dimension) = calculate_squashed_or_max_window(*src1, *src2);
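     // calculate_squashed_or_max_window() may collapse contiguous dimensions into a single (1D) window;
     // _split_dimension records the dimension along which the scheduler should split the work (see get_mws() below).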
1916 
1917     ICpuKernel::configure(win);
1918 }
1919 
1920 size_t CpuMulKernel::get_mws(const CPUInfo &platform, size_t thread_count) const
1921 {
1922     ARM_COMPUTE_UNUSED(thread_count);
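     // "mws" is the minimum workload size: the smallest number of window iterations worth handing to a
     // single thread before scheduling overhead outweighs the benefit of splitting further.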
1923 
1924 #if defined(ENABLE_FP32_KERNELS)
1925     if(this->_func_float == &mul_F32_F32_F32)
1926     {
1927         size_t mws = ICPPKernel::default_mws;
1928         if(platform.get_cpu_model() == CPUModel::N1)
1929         {
1930             mws = default_mws_N1_fp32_neon;
1931         }
1932         else if(platform.get_cpu_model() == CPUModel::V1)
1933         {
1934             mws = default_mws_V1_fp32_neon;
1935         }
1936         else
1937         {
1938             if(_split_dimension == Window::DimX)
1939             {
1940                 // Don't split the workload too small if the tensor has been reinterpreted as 1D.
1941                 // This number is loosely chosen, as the threading overhead on each platform varies wildly.
1942                 return default_mws_other_platforms_1d_tensor;
1943             }
1944             return default_mws;
1945         }
1946 
1947         // tensor is 1D or was re-interpreted as 1D
1948         if(this->window().shape().num_dimensions() == 1)
1949         {
1950             return mws;
1951         }
1952         else
1953         {
1954             // scale mws down by the number of elements along all the dimensions (x, z, w, etc) except the one
1955             // that we parallelize along (the y dimension). This allows for parallelization when the Y_SIZE is small
1956             // but the other sizes are large, which boosts performance.
1957             mws = static_cast<size_t>(mws / (this->window().num_iterations_total() / this->window().num_iterations(1)));
1958             return std::max(static_cast<size_t>(1), mws);
1959         }
1960     }
1961 #else /* ENABLE_FP32_KERNELS */
1962     ARM_COMPUTE_UNUSED(platform);
1963 #endif /* ENABLE_FP32_KERNELS */
1964     if(_split_dimension == Window::DimX)
1965     {
1966         // Don't split the workload too small if the tensor has been reinterpreted as 1D.
1967         // This number is loosely chosen, as the threading overhead on each platform varies wildly.
1968         return default_mws_other_platforms_1d_tensor;
1969     }
1970     return default_mws;
1971 }
1972 
1973 Status CpuMulKernel::validate(const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float scale, ConvertPolicy overflow_policy,
1974                               RoundingPolicy rounding_policy)
1975 {
1976     ARM_COMPUTE_ERROR_ON_NULLPTR(src1, src2, dst);
1977     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src1, src2, dst, scale, overflow_policy, rounding_policy));
1978 
1979     return Status{};
1980 }
1981 
1982 void CpuMulKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
1983 {
1984     ARM_COMPUTE_UNUSED(info);
1985     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1986     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
1987 
1988     auto src1 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
1989     auto src2 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
1990     auto dst  = tensors.get_tensor(TensorType::ACL_DST);
1991 
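     // Integer kernels take the precomputed right-shift amount (_scale_exponent); quantized and
     // float kernels take the scale factor itself.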
1992     if(_func_quantized != nullptr)
1993     {
1994         (*_func_quantized)(src1, src2, dst, window, _scale);
1995     }
1996     else if(_func_int != nullptr)
1997     {
1998         (*_func_int)(src1, src2, dst, window, _scale_exponent);
1999     }
2000     else
2001     {
2002         ARM_COMPUTE_ERROR_ON(_func_float == nullptr);
2003         (*_func_float)(src1, src2, dst, window, _scale);
2004     }
2005 }
2006 
2007 const char *CpuMulKernel::name() const
2008 {
2009     return "CpuMulKernel";
2010 }
2011 
2012 namespace
2013 {
2014 Status validate_arguments_complex(const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst)
2015 {
2016     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src1, 2, DataType::F32);
2017     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src2, 2, DataType::F32);
2018 
2019     const TensorShape &out_shape = TensorShape::broadcast_shape(src1->tensor_shape(), src2->tensor_shape());
2020 
2021     ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
2022 
2023     // Validate in case of configured dst
2024     if(dst->total_size() > 0)
2025     {
2026         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 2, DataType::F32);
2027         ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst->tensor_shape(), 0), "Wrong shape for dst");
2028     }
2029 
2030     return Status{};
2031 }
2032 } // namespace
2033 
2034 void CpuComplexMulKernel::configure(ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst)
2035 {
2036     ARM_COMPUTE_ERROR_ON_NULLPTR(src1, src2, dst);
2037     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_complex(src1, src2, dst));
2038 
2039     const TensorShape &out_shape = TensorShape::broadcast_shape(src1->tensor_shape(), src2->tensor_shape());
2040 
2041     // Auto initialize dst if not initialized
2042     const TensorInfo out_info(out_shape, src1->num_channels(), src1->data_type());
2043     auto_init_if_empty(*dst, out_info);
2044 
2045     // Configure kernel window
2046     Window win = calculate_max_window(out_shape);
2047 
2048     ICpuKernel::configure(win);
2049 }
2050 
2051 Status CpuComplexMulKernel::validate(const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst)
2052 {
2053     ARM_COMPUTE_ERROR_ON_NULLPTR(src1, src2, dst);
2054     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_complex(src1, src2, dst));
2055 
2056     return Status{};
2057 }
2058 
2059 void CpuComplexMulKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
2060 {
2061     ARM_COMPUTE_UNUSED(info);
2062     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
2063     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
2064 
2065     auto src1 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
2066     auto src2 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
2067     auto dst  = tensors.get_tensor(TensorType::ACL_DST);
2068 
2069     c_mul_F32_F32_F32_n(src1, src2, dst, window);
2070 }
2071 
2072 const char *CpuComplexMulKernel::name() const
2073 {
2074     return "CpuComplexMulKernel";
2075 }
2076 } // namespace kernels
2077 } // namespace cpu
2078 } // namespace arm_compute
2079