/* Copyright (c) 2014-2015 Xiph.Org Foundation
   Written by Viswanath Puttagunta */
/**
   @file celt_neon_intr.c
   @brief ARM Neon Intrinsic optimizations for celt
 */

/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include <arm_neon.h>
#include "../pitch.h"

#if defined(FIXED_POINT)
#include <string.h>

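/*
 * Function: xcorr_kernel_neon_fixed
 * ---------------------------------
 * Accumulates four correlation values at once,
 *    sum[k] += x[0]*y[k] + x[1]*y[k+1] + ... + x[len-1]*y[k+len-1],  k = 0..3.
 * The main loop consumes 8 samples of x per iteration; the blocks that follow
 * handle the remaining 4, 2 and 1 samples, with the last sample processed
 * separately so that no y value beyond the len+3 the kernel requires is read.
 */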
void xcorr_kernel_neon_fixed(const opus_val16 * x, const opus_val16 * y, opus_val32 sum[4], int len)
{
   int j;
   int32x4_t a = vld1q_s32(sum);
   /* Load y[0...3] */
   /* This requires len>0 to always be valid (which we assert in the C code). */
   int16x4_t y0 = vld1_s16(y);
   y += 4;

   /* This loop loads one y value more than we actually need.
      Therefore we have to stop as soon as there are 8 or fewer samples left
       (instead of 7), to avoid reading past the end of the array. */
   for (j = 0; j + 8 < len; j += 8)
   {
      /* Load x[0...7] */
      int16x8_t xx = vld1q_s16(x);
      int16x4_t x0 = vget_low_s16(xx);
      int16x4_t x4 = vget_high_s16(xx);
      /* Load y[4...11] */
      int16x8_t yy = vld1q_s16(y);
      int16x4_t y4 = vget_low_s16(yy);
      int16x4_t y8 = vget_high_s16(yy);
      int32x4_t a0 = vmlal_lane_s16(a, y0, x0, 0);
      int32x4_t a1 = vmlal_lane_s16(a0, y4, x4, 0);

      int16x4_t y1 = vext_s16(y0, y4, 1);
      int16x4_t y5 = vext_s16(y4, y8, 1);
      int32x4_t a2 = vmlal_lane_s16(a1, y1, x0, 1);
      int32x4_t a3 = vmlal_lane_s16(a2, y5, x4, 1);

      int16x4_t y2 = vext_s16(y0, y4, 2);
      int16x4_t y6 = vext_s16(y4, y8, 2);
      int32x4_t a4 = vmlal_lane_s16(a3, y2, x0, 2);
      int32x4_t a5 = vmlal_lane_s16(a4, y6, x4, 2);

      int16x4_t y3 = vext_s16(y0, y4, 3);
      int16x4_t y7 = vext_s16(y4, y8, 3);
      int32x4_t a6 = vmlal_lane_s16(a5, y3, x0, 3);
      int32x4_t a7 = vmlal_lane_s16(a6, y7, x4, 3);

      y0 = y8;
      a = a7;
      x += 8;
      y += 8;
   }
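   /* Handle 4 more samples if at least 5 remain; like the main loop, this
       block loads one y value beyond the 4 it consumes. */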
   if (j + 4 < len) {
      /* Load x[0...3] */
      int16x4_t x0 = vld1_s16(x);
      /* Load y[4...7] */
      int16x4_t y4 = vld1_s16(y);
      int32x4_t a0 = vmlal_lane_s16(a, y0, x0, 0);
      int16x4_t y1 = vext_s16(y0, y4, 1);
      int32x4_t a1 = vmlal_lane_s16(a0, y1, x0, 1);
      int16x4_t y2 = vext_s16(y0, y4, 2);
      int32x4_t a2 = vmlal_lane_s16(a1, y2, x0, 2);
      int16x4_t y3 = vext_s16(y0, y4, 3);
      int32x4_t a3 = vmlal_lane_s16(a2, y3, x0, 3);
      y0 = y4;
      a = a3;
      x += 4;
      y += 4;
      j += 4;
   }
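   /* Handle 2 more samples if at least 3 remain; x[0] and x[1] are broadcast
       across all four lanes and only two new y values are loaded. */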
   if (j + 2 < len) {
      /* Load x[0...1] */
      int16x4x2_t xx = vld2_dup_s16(x);
      int16x4_t x0 = xx.val[0];
      int16x4_t x1 = xx.val[1];
      /* Load y[4...5].
         We would like to use vld1_dup_s32(), but casting the pointer would
          break strict aliasing rules and potentially have alignment issues.
         Fortunately the compiler seems capable of translating this memcpy()
          and vdup_n_s32() into the equivalent vld1_dup_s32().*/
      int32_t yy;
      memcpy(&yy, y, sizeof(yy));
      int16x4_t y4 = vreinterpret_s16_s32(vdup_n_s32(yy));
      int32x4_t a0 = vmlal_s16(a, y0, x0);
      int16x4_t y1 = vext_s16(y0, y4, 1);
      /* Replace bottom copy of {y[5], y[4]} in y4 with {y[3], y[2]} from y0,
          using VSRI instead of VEXT, since it's a data-processing
          instruction. */
      y0 = vreinterpret_s16_s64(vsri_n_s64(vreinterpret_s64_s16(y4),
       vreinterpret_s64_s16(y0), 32));
      int32x4_t a1 = vmlal_s16(a0, y1, x1);
      a = a1;
      x += 2;
      y += 2;
      j += 2;
   }
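   /* Handle one more sample if at least 2 remain; shift the newly loaded y
       value into the top lane of y0 for the final multiply below. */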
   if (j + 1 < len) {
      /* Load next x. */
      int16x4_t x0 = vld1_dup_s16(x);
      int32x4_t a0 = vmlal_s16(a, y0, x0);
      /* Load last y. */
      int16x4_t y4 = vld1_dup_s16(y);
      y0 = vreinterpret_s16_s64(vsri_n_s64(vreinterpret_s64_s16(y4),
       vreinterpret_s64_s16(y0), 16));
      a = a0;
      x++;
   }
   /* Load last x. */
   int16x4_t x0 = vld1_dup_s16(x);
   int32x4_t a0 = vmlal_s16(a, y0, x0);
   vst1q_s32(sum, a0);
}

#else

#if defined(__ARM_FEATURE_FMA) && defined(__ARM_ARCH_ISA_A64)
/* If we can, force the compiler to use an FMA instruction rather than break
 *    vmlaq_f32() into fmul/fadd. */
#ifdef vmlaq_lane_f32
#undef vmlaq_lane_f32
#endif
#define vmlaq_lane_f32(a,b,c,lane) vfmaq_lane_f32(a,b,c,lane)
#endif


/*
 * Function: xcorr_kernel_neon_float
 * ---------------------------------
 * Computes four correlation values at once,
 *    sum[k] = x[0]*y[k] + x[1]*y[k+1] + ... + x[len-1]*y[k+len-1],  k = 0..3,
 * and stores them in sum[4].
 */
static void xcorr_kernel_neon_float(const float32_t *x, const float32_t *y,
      float32_t sum[4], int len) {
   float32x4_t YY[3];
   float32x4_t YEXT[3];
   float32x4_t XX[2];
   float32x2_t XX_2;
   float32x4_t SUMM;
   const float32_t *xi = x;
   const float32_t *yi = y;

   celt_assert(len>0);

   YY[0] = vld1q_f32(yi);
   SUMM = vdupq_n_f32(0);

   /* Each iteration consumes 8 elements of x and a 12-element window of y,
    * but the 12th y element is only carried over to the next iteration, never
    * used here. If len == 8 we may only access y[0] to y[10]; y[11] must not
    * be accessed, hence the loop condition is len > 8 rather than len >= 8.
    */
   while (len > 8) {
      yi += 4;
      YY[1] = vld1q_f32(yi);
      yi += 4;
      YY[2] = vld1q_f32(yi);

      XX[0] = vld1q_f32(xi);
      xi += 4;
      XX[1] = vld1q_f32(xi);
      xi += 4;

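      /* SUMM[i] accumulates x[k]*y[i + k] for k = 0..7; the shifted y windows
          are built with vextq_f32 from the three vectors already loaded. */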
      SUMM = vmlaq_lane_f32(SUMM, YY[0], vget_low_f32(XX[0]), 0);
      YEXT[0] = vextq_f32(YY[0], YY[1], 1);
      SUMM = vmlaq_lane_f32(SUMM, YEXT[0], vget_low_f32(XX[0]), 1);
      YEXT[1] = vextq_f32(YY[0], YY[1], 2);
      SUMM = vmlaq_lane_f32(SUMM, YEXT[1], vget_high_f32(XX[0]), 0);
      YEXT[2] = vextq_f32(YY[0], YY[1], 3);
      SUMM = vmlaq_lane_f32(SUMM, YEXT[2], vget_high_f32(XX[0]), 1);

      SUMM = vmlaq_lane_f32(SUMM, YY[1], vget_low_f32(XX[1]), 0);
      YEXT[0] = vextq_f32(YY[1], YY[2], 1);
      SUMM = vmlaq_lane_f32(SUMM, YEXT[0], vget_low_f32(XX[1]), 1);
      YEXT[1] = vextq_f32(YY[1], YY[2], 2);
      SUMM = vmlaq_lane_f32(SUMM, YEXT[1], vget_high_f32(XX[1]), 0);
      YEXT[2] = vextq_f32(YY[1], YY[2], 3);
      SUMM = vmlaq_lane_f32(SUMM, YEXT[2], vget_high_f32(XX[1]), 1);

      YY[0] = YY[2];
      len -= 8;
   }

   /* This block consumes 4 elements of x and an 8-element window of y, but
    * the 8th y element is only carried over, never used here. If len == 4 we
    * may only access y[0] to y[6]; y[7] must not be accessed, hence the
    * condition is len > 4 rather than len >= 4.
    */
   if (len > 4) {
      yi += 4;
      YY[1] = vld1q_f32(yi);

      XX[0] = vld1q_f32(xi);
      xi += 4;

      SUMM = vmlaq_lane_f32(SUMM, YY[0], vget_low_f32(XX[0]), 0);
      YEXT[0] = vextq_f32(YY[0], YY[1], 1);
      SUMM = vmlaq_lane_f32(SUMM, YEXT[0], vget_low_f32(XX[0]), 1);
      YEXT[1] = vextq_f32(YY[0], YY[1], 2);
      SUMM = vmlaq_lane_f32(SUMM, YEXT[1], vget_high_f32(XX[0]), 0);
      YEXT[2] = vextq_f32(YY[0], YY[1], 3);
      SUMM = vmlaq_lane_f32(SUMM, YEXT[2], vget_high_f32(XX[0]), 1);

      YY[0] = YY[1];
      len -= 4;
   }

   while (--len > 0) {
      XX_2 = vld1_dup_f32(xi++);
      SUMM = vmlaq_lane_f32(SUMM, YY[0], XX_2, 0);
      YY[0]= vld1q_f32(++yi);
   }

   XX_2 = vld1_dup_f32(xi);
   SUMM = vmlaq_lane_f32(SUMM, YY[0], XX_2, 0);

   vst1q_f32(sum, SUMM);
}

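/*
 * Function: celt_pitch_xcorr_float_neon
 * -------------------------------------
 * Computes xcorr[i] = inner product of x with y + i for 0 <= i < max_pitch,
 * four lags at a time with xcorr_kernel_neon_float(); any lags left over when
 * max_pitch is not a multiple of 4 use celt_inner_prod_neon() instead.
 */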
void celt_pitch_xcorr_float_neon(const opus_val16 *_x, const opus_val16 *_y,
                        opus_val32 *xcorr, int len, int max_pitch, int arch) {
   int i;
   (void)arch;
   celt_assert(max_pitch > 0);
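   /* Check that _x is 4-byte aligned; the subtraction from NULL yields the
       address as a ptrdiff_t without a pointer-to-integer cast. */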
   celt_sig_assert((((unsigned char *)_x-(unsigned char *)NULL)&3)==0);

   for (i = 0; i < (max_pitch-3); i += 4) {
      xcorr_kernel_neon_float((const float32_t *)_x, (const float32_t *)_y+i,
            (float32_t *)xcorr+i, len);
   }

   /* In case max_pitch isn't a multiple of 4, do non-unrolled version. */
   for (; i < max_pitch; i++) {
      xcorr[i] = celt_inner_prod_neon(_x, _y+i, len);
   }
}
#endif