xref: /aosp_15_r20/external/libaom/aom_dsp/x86/sse_sse4.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <assert.h>
13 #include <smmintrin.h>
14 
15 #include "config/aom_config.h"
16 #include "config/aom_dsp_rtcd.h"
17 
18 #include "aom_ports/mem.h"
19 #include "aom/aom_integer.h"
20 #include "aom_dsp/x86/synonyms.h"
21 
// Horizontally reduces the four 32-bit lanes of *sum_all to one 64-bit total.
// Every lane is zero-extended (treated as unsigned 32-bit) before it is added,
// so per-lane partial sums anywhere up to UINT32_MAX are reduced exactly.
static inline int64_t summary_all_sse4(const __m128i *sum_all) {
  const __m128i zero = _mm_setzero_si128();
  const __m128i lanes = *sum_all;
  // Interleaving with zero widens lanes {0,1} and {2,3} to unsigned 64 bits.
  const __m128i lanes01 = _mm_unpacklo_epi32(lanes, zero);
  const __m128i lanes23 = _mm_unpackhi_epi32(lanes, zero);
  __m128i total = _mm_add_epi64(lanes01, lanes23);
  // Fold the upper 64-bit half onto the lower one; the result is in bits 0-63.
  total = _mm_add_epi64(total, _mm_unpackhi_epi64(total, total));
  int64_t result;
  xx_storel_64(&result, total);
  return result;
}
30 }
31 
32 #if CONFIG_AV1_HIGHBITDEPTH
// Widens the four unsigned 32-bit lanes of *sum32 and adds them into the two
// 64-bit lanes of *sum64: lanes 0 and 2 go to the low 64-bit lane, lanes 1 and
// 3 to the high one. *sum32 is read completely before *sum64 is written.
static inline void summary_32_sse4(const __m128i *sum32, __m128i *sum64) {
  const __m128i zero = _mm_setzero_si128();
  // Unpacking against zero zero-extends each 32-bit lane to 64 bits.
  const __m128i widened = _mm_add_epi64(_mm_unpacklo_epi32(*sum32, zero),
                                        _mm_unpackhi_epi32(*sum32, zero));
  *sum64 = _mm_add_epi64(*sum64, widened);
}
39 #endif
40 
sse_w16_sse4_1(__m128i * sum,const uint8_t * a,const uint8_t * b)41 static inline void sse_w16_sse4_1(__m128i *sum, const uint8_t *a,
42                                   const uint8_t *b) {
43   const __m128i v_a0 = xx_loadu_128(a);
44   const __m128i v_b0 = xx_loadu_128(b);
45   const __m128i v_a00_w = _mm_cvtepu8_epi16(v_a0);
46   const __m128i v_a01_w = _mm_cvtepu8_epi16(_mm_srli_si128(v_a0, 8));
47   const __m128i v_b00_w = _mm_cvtepu8_epi16(v_b0);
48   const __m128i v_b01_w = _mm_cvtepu8_epi16(_mm_srli_si128(v_b0, 8));
49   const __m128i v_d00_w = _mm_sub_epi16(v_a00_w, v_b00_w);
50   const __m128i v_d01_w = _mm_sub_epi16(v_a01_w, v_b01_w);
51   *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d00_w, v_d00_w));
52   *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d01_w, v_d01_w));
53 }
54 
sse4x2_sse4_1(const uint8_t * a,int a_stride,const uint8_t * b,int b_stride,__m128i * sum)55 static inline void sse4x2_sse4_1(const uint8_t *a, int a_stride,
56                                  const uint8_t *b, int b_stride, __m128i *sum) {
57   const __m128i v_a0 = xx_loadl_32(a);
58   const __m128i v_a1 = xx_loadl_32(a + a_stride);
59   const __m128i v_b0 = xx_loadl_32(b);
60   const __m128i v_b1 = xx_loadl_32(b + b_stride);
61   const __m128i v_a_w = _mm_cvtepu8_epi16(_mm_unpacklo_epi32(v_a0, v_a1));
62   const __m128i v_b_w = _mm_cvtepu8_epi16(_mm_unpacklo_epi32(v_b0, v_b1));
63   const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w);
64   *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d_w, v_d_w));
65 }
66 
sse8_sse4_1(const uint8_t * a,const uint8_t * b,__m128i * sum)67 static inline void sse8_sse4_1(const uint8_t *a, const uint8_t *b,
68                                __m128i *sum) {
69   const __m128i v_a0 = xx_loadl_64(a);
70   const __m128i v_b0 = xx_loadl_64(b);
71   const __m128i v_a_w = _mm_cvtepu8_epi16(v_a0);
72   const __m128i v_b_w = _mm_cvtepu8_epi16(v_b0);
73   const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w);
74   *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d_w, v_d_w));
75 }
76 
// Returns the sum of squared differences between two width x height blocks of
// 8-bit pixels.
//
// Widths 4, 8, 16, 32, 64 and 128 have dedicated paths. Any other width takes
// the generic path, which handles multiples of 8 and widths of the form
// 8k + 4 (k >= 1).
//
// NOTE(review): widths that are not multiples of 4 are not handled correctly
// by this kernel. The 4-wide paths step two rows at a time, so they presumably
// require an even height. Every row loop is a do-while, so at least one row is
// read even when height <= 0. Confirm these preconditions against callers.
//
// Partial sums stay in four 32-bit lanes and are only widened (as unsigned) by
// summary_all_sse4 at the end. A squared 8-bit difference is at most
// 255^2 = 65025, so each lane can absorb about 66000 of them; a 128x128 block
// puts 4096 in each lane.
int64_t aom_sse_sse4_1(const uint8_t *a, int a_stride, const uint8_t *b,
                       int b_stride, int width, int height) {
  int y = 0;
  __m128i sum = _mm_setzero_si128();
  switch (width) {
    case 4:
      // Two 4-pixel rows per step.
      do {
        sse4x2_sse4_1(a, a_stride, b, b_stride, &sum);
        // Multiply instead of shifting: left-shifting a negative stride is
        // undefined behavior in C.
        a += 2 * a_stride;
        b += 2 * b_stride;
        y += 2;
      } while (y < height);
      break;
    case 8:
      do {
        sse8_sse4_1(a, b, &sum);
        a += a_stride;
        b += b_stride;
        y += 1;
      } while (y < height);
      break;
    case 16:
      do {
        sse_w16_sse4_1(&sum, a, b);
        a += a_stride;
        b += b_stride;
        y += 1;
      } while (y < height);
      break;
    case 32:
      do {
        sse_w16_sse4_1(&sum, a, b);
        sse_w16_sse4_1(&sum, a + 16, b + 16);
        a += a_stride;
        b += b_stride;
        y += 1;
      } while (y < height);
      break;
    case 64:
      do {
        sse_w16_sse4_1(&sum, a, b);
        sse_w16_sse4_1(&sum, a + 16 * 1, b + 16 * 1);
        sse_w16_sse4_1(&sum, a + 16 * 2, b + 16 * 2);
        sse_w16_sse4_1(&sum, a + 16 * 3, b + 16 * 3);
        a += a_stride;
        b += b_stride;
        y += 1;
      } while (y < height);
      break;
    case 128:
      do {
        sse_w16_sse4_1(&sum, a, b);
        sse_w16_sse4_1(&sum, a + 16 * 1, b + 16 * 1);
        sse_w16_sse4_1(&sum, a + 16 * 2, b + 16 * 2);
        sse_w16_sse4_1(&sum, a + 16 * 3, b + 16 * 3);
        sse_w16_sse4_1(&sum, a + 16 * 4, b + 16 * 4);
        sse_w16_sse4_1(&sum, a + 16 * 5, b + 16 * 5);
        sse_w16_sse4_1(&sum, a + 16 * 6, b + 16 * 6);
        sse_w16_sse4_1(&sum, a + 16 * 7, b + 16 * 7);
        a += a_stride;
        b += b_stride;
        y += 1;
      } while (y < height);
      break;
    default:
      if (width & 0x07) {
        // Widths of the form 8k + 4 (k >= 1): 8-pixel columns two rows at a
        // time, then the trailing 4-pixel column of the same two rows.
        do {
          int i = 0;
          do {
            sse8_sse4_1(a + i, b + i, &sum);
            sse8_sse4_1(a + i + a_stride, b + i + b_stride, &sum);
            i += 8;
          } while (i + 4 < width);
          sse4x2_sse4_1(a + i, a_stride, b + i, b_stride, &sum);
          a += 2 * a_stride;
          b += 2 * b_stride;
          y += 2;
        } while (y < height);
      } else {
        // Remaining multiples of 8: one row at a time, 8 pixels per step.
        do {
          int i = 0;
          do {
            sse8_sse4_1(a + i, b + i, &sum);
            i += 8;
          } while (i < width);
          a += a_stride;
          b += b_stride;
          y += 1;
        } while (y < height);
      }
      break;
  }
  // Every path accumulates into the same four 32-bit lanes; reduce them once.
  return summary_all_sse4(&sum);
}
180 
181 #if CONFIG_AV1_HIGHBITDEPTH
highbd_sse_w4x2_sse4_1(__m128i * sum,const uint16_t * a,int a_stride,const uint16_t * b,int b_stride)182 static inline void highbd_sse_w4x2_sse4_1(__m128i *sum, const uint16_t *a,
183                                           int a_stride, const uint16_t *b,
184                                           int b_stride) {
185   const __m128i v_a0 = xx_loadl_64(a);
186   const __m128i v_a1 = xx_loadl_64(a + a_stride);
187   const __m128i v_b0 = xx_loadl_64(b);
188   const __m128i v_b1 = xx_loadl_64(b + b_stride);
189   const __m128i v_a_w = _mm_unpacklo_epi64(v_a0, v_a1);
190   const __m128i v_b_w = _mm_unpacklo_epi64(v_b0, v_b1);
191   const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w);
192   *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d_w, v_d_w));
193 }
194 
highbd_sse_w8_sse4_1(__m128i * sum,const uint16_t * a,const uint16_t * b)195 static inline void highbd_sse_w8_sse4_1(__m128i *sum, const uint16_t *a,
196                                         const uint16_t *b) {
197   const __m128i v_a_w = xx_loadu_128(a);
198   const __m128i v_b_w = xx_loadu_128(b);
199   const __m128i v_d_w = _mm_sub_epi16(v_a_w, v_b_w);
200   *sum = _mm_add_epi32(*sum, _mm_madd_epi16(v_d_w, v_d_w));
201 }
202 
// Returns the sum of squared differences between two width x height blocks of
// high-bit-depth pixels. a8 and b8 are converted with CONVERT_TO_SHORTPTR.
// The strides are counted in uint16_t elements, since they are added to
// uint16_t pointers.
//
// Squared differences are accumulated in 32-bit lanes and periodically widened
// into the 64-bit lanes of `sum` with summary_32_sse4.
// NOTE(review): the widening intervals assume samples of at most 12 bits; this
// is not checked here. Under that assumption one squared difference is at most
// 4095^2 = 16769025, so a lane can hold at most 256 of them before exceeding
// UINT32_MAX. Each highbd_sse_w8_sse4_1 call adds two per lane, i.e.
// width / 4 per lane per row. The wide paths therefore widen after at most
// 1024 / width rows (64, 32, 16 and 8 rows for widths 16, 32, 64 and 128).
//
// NOTE(review): as in the 8-bit version, widths that are not multiples of 4
// are not handled correctly. The 4-wide paths step two rows at a time, so they
// presumably require an even height. Every row loop is a do-while, so at least
// one row is read even when height <= 0.
int64_t aom_highbd_sse_sse4_1(const uint8_t *a8, int a_stride,
                              const uint8_t *b8, int b_stride, int width,
                              int height) {
  int y = 0;
  int64_t sse = 0;
  uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  uint16_t *b = CONVERT_TO_SHORTPTR(b8);
  __m128i sum = _mm_setzero_si128();
  switch (width) {
    case 4:
      // One squared difference per lane per row, never widened mid-block:
      // safe for heights up to 256.
      do {
        highbd_sse_w4x2_sse4_1(&sum, a, a_stride, b, b_stride);
        // Multiply instead of shifting: left-shifting a negative stride is
        // undefined behavior in C.
        a += 2 * a_stride;
        b += 2 * b_stride;
        y += 2;
      } while (y < height);
      sse = summary_all_sse4(&sum);
      break;
    case 8:
      // Two squared differences per lane per row, never widened mid-block:
      // safe for heights up to 128.
      do {
        highbd_sse_w8_sse4_1(&sum, a, b);
        a += a_stride;
        b += b_stride;
        y += 1;
      } while (y < height);
      sse = summary_all_sse4(&sum);
      break;
    case 16:
      do {
        int l = 0;
        __m128i sum32 = _mm_setzero_si128();
        do {
          highbd_sse_w8_sse4_1(&sum32, a, b);
          highbd_sse_w8_sse4_1(&sum32, a + 8, b + 8);
          a += a_stride;
          b += b_stride;
          l += 1;
        } while (l < 64 && l < (height - y));
        summary_32_sse4(&sum32, &sum);
        y += 64;
      } while (y < height);
      xx_storel_64(&sse, _mm_add_epi64(sum, _mm_srli_si128(sum, 8)));
      break;
    case 32:
      do {
        int l = 0;
        __m128i sum32 = _mm_setzero_si128();
        do {
          highbd_sse_w8_sse4_1(&sum32, a, b);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 1, b + 8 * 1);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 2, b + 8 * 2);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 3, b + 8 * 3);
          a += a_stride;
          b += b_stride;
          l += 1;
        } while (l < 32 && l < (height - y));
        summary_32_sse4(&sum32, &sum);
        y += 32;
      } while (y < height);
      xx_storel_64(&sse, _mm_add_epi64(sum, _mm_srli_si128(sum, 8)));
      break;
    case 64:
      do {
        int l = 0;
        __m128i sum32 = _mm_setzero_si128();
        do {
          highbd_sse_w8_sse4_1(&sum32, a, b);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 1, b + 8 * 1);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 2, b + 8 * 2);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 3, b + 8 * 3);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 4, b + 8 * 4);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 5, b + 8 * 5);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 6, b + 8 * 6);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 7, b + 8 * 7);
          a += a_stride;
          b += b_stride;
          l += 1;
        } while (l < 16 && l < (height - y));
        summary_32_sse4(&sum32, &sum);
        y += 16;
      } while (y < height);
      xx_storel_64(&sse, _mm_add_epi64(sum, _mm_srli_si128(sum, 8)));
      break;
    case 128:
      do {
        int l = 0;
        __m128i sum32 = _mm_setzero_si128();
        do {
          highbd_sse_w8_sse4_1(&sum32, a, b);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 1, b + 8 * 1);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 2, b + 8 * 2);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 3, b + 8 * 3);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 4, b + 8 * 4);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 5, b + 8 * 5);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 6, b + 8 * 6);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 7, b + 8 * 7);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 8, b + 8 * 8);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 9, b + 8 * 9);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 10, b + 8 * 10);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 11, b + 8 * 11);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 12, b + 8 * 12);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 13, b + 8 * 13);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 14, b + 8 * 14);
          highbd_sse_w8_sse4_1(&sum32, a + 8 * 15, b + 8 * 15);
          a += a_stride;
          b += b_stride;
          l += 1;
        } while (l < 8 && l < (height - y));
        summary_32_sse4(&sum32, &sum);
        y += 8;
      } while (y < height);
      xx_storel_64(&sse, _mm_add_epi64(sum, _mm_srli_si128(sum, 8)));
      break;
    default:
      if (width & 0x7) {
        // Widths of the form 8k + 4: two rows at a time, widened after every
        // row pair (width / 2 squared differences per lane).
        do {
          __m128i sum32 = _mm_setzero_si128();
          int i = 0;
          do {
            highbd_sse_w8_sse4_1(&sum32, a + i, b + i);
            highbd_sse_w8_sse4_1(&sum32, a + i + a_stride, b + i + b_stride);
            i += 8;
          } while (i + 4 < width);
          highbd_sse_w4x2_sse4_1(&sum32, a + i, a_stride, b + i, b_stride);
          a += 2 * a_stride;
          b += 2 * b_stride;
          y += 2;
          summary_32_sse4(&sum32, &sum);
        } while (y < height);
      } else {
        // Widen every 8 rows, or sooner for rows wider than 128 pixels, so
        // that rows * width / 4 <= 256 squared differences ever share a
        // 32-bit lane. This is the same budget as the explicit wide paths.
        // Rows wider than 1024 pixels can still overflow a lane within one
        // row.
        int rows_per_chunk = 8;
        if (width > 128) {
          rows_per_chunk = 1024 / width;
          if (rows_per_chunk < 1) rows_per_chunk = 1;
        }
        do {
          int l = 0;
          __m128i sum32 = _mm_setzero_si128();
          do {
            int i = 0;
            do {
              highbd_sse_w8_sse4_1(&sum32, a + i, b + i);
              i += 8;
            } while (i < width);
            a += a_stride;
            b += b_stride;
            l += 1;
          } while (l < rows_per_chunk && l < (height - y));
          summary_32_sse4(&sum32, &sum);
          y += rows_per_chunk;
        } while (y < height);
      }
      xx_storel_64(&sse, _mm_add_epi64(sum, _mm_srli_si128(sum, 8)));
      break;
  }
  return sse;
}
355 #endif  // CONFIG_AV1_HIGHBITDEPTH
356