xref: /aosp_15_r20/external/ComputeLibrary/src/core/NEON/NEAsymm.inl (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1/*
2 * Copyright (c) 2017-2020 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24namespace arm_compute
25{
26inline qasymm8x16_t vmlaq_qasymm8(qasymm8x16_t vd, float32x4_t vs, float32x4_t vo)
27{
28    // Convert uint8 vectors to uint16 vectors
29    const uint8x8_t vd_low        = vget_low_u8(vd);
30    const uint8x8_t vd_high       = vget_high_u8(vd);
31    uint16x8_t      vd_low_u16x8  = vmovl_u8(vd_low);
32    uint16x8_t      vd_high_u16x8 = vmovl_u8(vd_high);
33    // Convert uint16 vectors to uint32 vectors
34    uint32x4_t A_u32x4 = vmovl_u16(vget_low_u16(vd_low_u16x8));
35    uint32x4_t B_u32x4 = vmovl_u16(vget_high_u16(vd_low_u16x8));
36    uint32x4_t C_u32x4 = vmovl_u16(vget_low_u16(vd_high_u16x8));
37    uint32x4_t D_u32x4 = vmovl_u16(vget_high_u16(vd_high_u16x8));
38    // Convert uint32 vectors to float32 vectors
39    float32x4_t A_f32x4 = vcvtq_f32_u32(A_u32x4);
40    float32x4_t B_f32x4 = vcvtq_f32_u32(B_u32x4);
41    float32x4_t C_f32x4 = vcvtq_f32_u32(C_u32x4);
42    float32x4_t D_f32x4 = vcvtq_f32_u32(D_u32x4);
43    // vd = vd*vs + vo
44    A_f32x4 = vmlaq_f32(vo, A_f32x4, vs);
45    B_f32x4 = vmlaq_f32(vo, B_f32x4, vs);
46    C_f32x4 = vmlaq_f32(vo, C_f32x4, vs);
47    D_f32x4 = vmlaq_f32(vo, D_f32x4, vs);
48    // Convert float32 vectors to uint32 vectors
49    A_u32x4 = vcvtq_u32_f32(A_f32x4);
50    B_u32x4 = vcvtq_u32_f32(B_f32x4);
51    C_u32x4 = vcvtq_u32_f32(C_f32x4);
52    D_u32x4 = vcvtq_u32_f32(D_f32x4);
53    // Convert uint32 vectors to uint16 vectors (with saturation)
54    vd_low_u16x8  = vcombine_u16(vqmovn_u32(A_u32x4), vqmovn_u32(B_u32x4));
55    vd_high_u16x8 = vcombine_u16(vqmovn_u32(C_u32x4), vqmovn_u32(D_u32x4));
56    // convert uint16 vectors to uint8 vectors (with saturation)
57    return vcombine_u8(vqmovn_u16(vd_low_u16x8), vqmovn_u16(vd_high_u16x8));
58}
59inline qasymm8x16_signed_t vmlaq_qasymm8_signed(qasymm8x16_signed_t vd, float32x4_t vs, float32x4_t vo)
60{
61    // Convert uint8 vectors to int16 vectors
62    const int8x8_t vd_low        = vget_low_s8(vd);
63    const int8x8_t vd_high       = vget_high_s8(vd);
64    int16x8_t      vd_low_s16x8  = vmovl_s8(vd_low);
65    int16x8_t      vd_high_s16x8 = vmovl_s8(vd_high);
66    // Convert int16 vectors to int32 vectors
67    int32x4_t A_s32x4 = vmovl_s16(vget_low_s16(vd_low_s16x8));
68    int32x4_t B_s32x4 = vmovl_s16(vget_high_s16(vd_low_s16x8));
69    int32x4_t C_s32x4 = vmovl_s16(vget_low_s16(vd_high_s16x8));
70    int32x4_t D_s32x4 = vmovl_s16(vget_high_s16(vd_high_s16x8));
71    // Convert int32 vectors to float32 vectors
72    float32x4_t A_f32x4 = vcvtq_f32_s32(A_s32x4);
73    float32x4_t B_f32x4 = vcvtq_f32_s32(B_s32x4);
74    float32x4_t C_f32x4 = vcvtq_f32_s32(C_s32x4);
75    float32x4_t D_f32x4 = vcvtq_f32_s32(D_s32x4);
76    // vd = vd*vs + vo
77    A_f32x4 = vmlaq_f32(vo, A_f32x4, vs);
78    B_f32x4 = vmlaq_f32(vo, B_f32x4, vs);
79    C_f32x4 = vmlaq_f32(vo, C_f32x4, vs);
80    D_f32x4 = vmlaq_f32(vo, D_f32x4, vs);
81    // Convert float32 vectors to int32 vectors
82    A_s32x4 = vcvtq_s32_f32(A_f32x4);
83    B_s32x4 = vcvtq_s32_f32(B_f32x4);
84    C_s32x4 = vcvtq_s32_f32(C_f32x4);
85    D_s32x4 = vcvtq_s32_f32(D_f32x4);
86    // Convert int32 vectors to int16 vectors (with saturation)
87    vd_low_s16x8  = vcombine_s16(vqmovn_s32(A_s32x4), vqmovn_s32(B_s32x4));
88    vd_high_s16x8 = vcombine_s16(vqmovn_s32(C_s32x4), vqmovn_s32(D_s32x4));
89    // convert int16 vectors to int8 vectors (with saturation)
90    return vcombine_s8(vqmovn_s16(vd_low_s16x8), vqmovn_s16(vd_high_s16x8));
91}
92} // namespace arm_compute
93