xref: /aosp_15_r20/external/libaom/aom_dsp/x86/sad_avx2.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 #include <immintrin.h>
12 
13 #include "config/aom_dsp_rtcd.h"
14 
15 #include "aom_ports/mem.h"
16 
sad64xh_avx2(const uint8_t * src_ptr,int src_stride,const uint8_t * ref_ptr,int ref_stride,int h)17 static inline unsigned int sad64xh_avx2(const uint8_t *src_ptr, int src_stride,
18                                         const uint8_t *ref_ptr, int ref_stride,
19                                         int h) {
20   int i;
21   __m256i sad1_reg, sad2_reg, ref1_reg, ref2_reg;
22   __m256i sum_sad = _mm256_setzero_si256();
23   __m256i sum_sad_h;
24   __m128i sum_sad128;
25   for (i = 0; i < h; i++) {
26     ref1_reg = _mm256_loadu_si256((__m256i const *)ref_ptr);
27     ref2_reg = _mm256_loadu_si256((__m256i const *)(ref_ptr + 32));
28     sad1_reg =
29         _mm256_sad_epu8(ref1_reg, _mm256_loadu_si256((__m256i const *)src_ptr));
30     sad2_reg = _mm256_sad_epu8(
31         ref2_reg, _mm256_loadu_si256((__m256i const *)(src_ptr + 32)));
32     sum_sad = _mm256_add_epi32(sum_sad, _mm256_add_epi32(sad1_reg, sad2_reg));
33     ref_ptr += ref_stride;
34     src_ptr += src_stride;
35   }
36   sum_sad_h = _mm256_srli_si256(sum_sad, 8);
37   sum_sad = _mm256_add_epi32(sum_sad, sum_sad_h);
38   sum_sad128 = _mm256_extracti128_si256(sum_sad, 1);
39   sum_sad128 = _mm_add_epi32(_mm256_castsi256_si128(sum_sad), sum_sad128);
40   unsigned int res = (unsigned int)_mm_cvtsi128_si32(sum_sad128);
41   _mm256_zeroupper();
42   return res;
43 }
44 
sad32xh_avx2(const uint8_t * src_ptr,int src_stride,const uint8_t * ref_ptr,int ref_stride,int h)45 static inline unsigned int sad32xh_avx2(const uint8_t *src_ptr, int src_stride,
46                                         const uint8_t *ref_ptr, int ref_stride,
47                                         int h) {
48   int i;
49   __m256i sad1_reg, sad2_reg, ref1_reg, ref2_reg;
50   __m256i sum_sad = _mm256_setzero_si256();
51   __m256i sum_sad_h;
52   __m128i sum_sad128;
53   int ref2_stride = ref_stride << 1;
54   int src2_stride = src_stride << 1;
55   int max = h >> 1;
56   for (i = 0; i < max; i++) {
57     ref1_reg = _mm256_loadu_si256((__m256i const *)ref_ptr);
58     ref2_reg = _mm256_loadu_si256((__m256i const *)(ref_ptr + ref_stride));
59     sad1_reg =
60         _mm256_sad_epu8(ref1_reg, _mm256_loadu_si256((__m256i const *)src_ptr));
61     sad2_reg = _mm256_sad_epu8(
62         ref2_reg, _mm256_loadu_si256((__m256i const *)(src_ptr + src_stride)));
63     sum_sad = _mm256_add_epi32(sum_sad, _mm256_add_epi32(sad1_reg, sad2_reg));
64     ref_ptr += ref2_stride;
65     src_ptr += src2_stride;
66   }
67   sum_sad_h = _mm256_srli_si256(sum_sad, 8);
68   sum_sad = _mm256_add_epi32(sum_sad, sum_sad_h);
69   sum_sad128 = _mm256_extracti128_si256(sum_sad, 1);
70   sum_sad128 = _mm_add_epi32(_mm256_castsi256_si128(sum_sad), sum_sad128);
71   unsigned int res = (unsigned int)_mm_cvtsi128_si32(sum_sad128);
72   _mm256_zeroupper();
73   return res;
74 }
75 
// FSAD64_H(h): defines aom_sad64x<h>_avx2(), the full-resolution SAD
// over a 64x<h> block, as a thin wrapper around sad64xh_avx2().
#define FSAD64_H(h)                                                           \
  unsigned int aom_sad64x##h##_avx2(const uint8_t *src_ptr, int src_stride,   \
                                    const uint8_t *ref_ptr, int ref_stride) { \
    return sad64xh_avx2(src_ptr, src_stride, ref_ptr, ref_stride, h);         \
  }
81 
// FSADS64_H(h): defines aom_sad_skip_64x<h>_avx2(), a subsampled SAD:
// it measures only every other row (strides doubled, height halved)
// and scales by 2 to approximate the full-block value.
#define FSADS64_H(h)                                                          \
  unsigned int aom_sad_skip_64x##h##_avx2(                                    \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,         \
      int ref_stride) {                                                       \
    return 2 * sad64xh_avx2(src_ptr, src_stride * 2, ref_ptr, ref_stride * 2, \
                            h / 2);                                           \
  }
89 
// FSAD32_H(h): defines aom_sad32x<h>_avx2(), the full-resolution SAD
// over a 32x<h> block, as a thin wrapper around sad32xh_avx2().
#define FSAD32_H(h)                                                           \
  unsigned int aom_sad32x##h##_avx2(const uint8_t *src_ptr, int src_stride,   \
                                    const uint8_t *ref_ptr, int ref_stride) { \
    return sad32xh_avx2(src_ptr, src_stride, ref_ptr, ref_stride, h);         \
  }
95 
// FSADS32_H(h): defines aom_sad_skip_32x<h>_avx2(), a subsampled SAD:
// it measures only every other row (strides doubled, height halved)
// and scales by 2 to approximate the full-block value.
#define FSADS32_H(h)                                                          \
  unsigned int aom_sad_skip_32x##h##_avx2(                                    \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,         \
      int ref_stride) {                                                       \
    return 2 * sad32xh_avx2(src_ptr, src_stride * 2, ref_ptr, ref_stride * 2, \
                            h / 2);                                           \
  }
103 
104 #define FSAD64  \
105   FSAD64_H(64)  \
106   FSAD64_H(32)  \
107   FSADS64_H(64) \
108   FSADS64_H(32)
109 
110 #define FSAD32  \
111   FSAD32_H(64)  \
112   FSAD32_H(32)  \
113   FSAD32_H(16)  \
114   FSADS32_H(64) \
115   FSADS32_H(32) \
116   FSADS32_H(16)
117 
118 /* clang-format off */
119 FSAD64
120 FSAD32
121 /* clang-format on */
122 
123 #undef FSAD64
124 #undef FSAD32
125 #undef FSAD64_H
126 #undef FSAD32_H
127 
// FSADAVG64_H(h): defines aom_sad64x<h>_avg_avx2(), the SAD between
// src and the rounded average of ref and second_pred (compound
// prediction). second_pred is read as a contiguous 64-wide block, so
// it advances by 64 bytes per row. Reduction mirrors sad64xh_avx2().
#define FSADAVG64_H(h)                                                        \
  unsigned int aom_sad64x##h##_avg_avx2(                                      \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,         \
      int ref_stride, const uint8_t *second_pred) {                           \
    int i;                                                                    \
    __m256i sad1_reg, sad2_reg, ref1_reg, ref2_reg;                           \
    __m256i sum_sad = _mm256_setzero_si256();                                 \
    __m256i sum_sad_h;                                                        \
    __m128i sum_sad128;                                                       \
    for (i = 0; i < h; i++) {                                                 \
      ref1_reg = _mm256_loadu_si256((__m256i const *)ref_ptr);                \
      ref2_reg = _mm256_loadu_si256((__m256i const *)(ref_ptr + 32));         \
      ref1_reg = _mm256_avg_epu8(                                             \
          ref1_reg, _mm256_loadu_si256((__m256i const *)second_pred));        \
      ref2_reg = _mm256_avg_epu8(                                             \
          ref2_reg, _mm256_loadu_si256((__m256i const *)(second_pred + 32))); \
      sad1_reg = _mm256_sad_epu8(                                             \
          ref1_reg, _mm256_loadu_si256((__m256i const *)src_ptr));            \
      sad2_reg = _mm256_sad_epu8(                                             \
          ref2_reg, _mm256_loadu_si256((__m256i const *)(src_ptr + 32)));     \
      sum_sad =                                                               \
          _mm256_add_epi32(sum_sad, _mm256_add_epi32(sad1_reg, sad2_reg));    \
      ref_ptr += ref_stride;                                                  \
      src_ptr += src_stride;                                                  \
      second_pred += 64;                                                      \
    }                                                                         \
    sum_sad_h = _mm256_srli_si256(sum_sad, 8);                                \
    sum_sad = _mm256_add_epi32(sum_sad, sum_sad_h);                           \
    sum_sad128 = _mm256_extracti128_si256(sum_sad, 1);                        \
    sum_sad128 = _mm_add_epi32(_mm256_castsi256_si128(sum_sad), sum_sad128);  \
    unsigned int res = (unsigned int)_mm_cvtsi128_si32(sum_sad128);           \
    _mm256_zeroupper();                                                       \
    return res;                                                               \
  }
162 
// FSADAVG32_H(h): defines aom_sad32x<h>_avg_avx2(), the SAD between
// src and the rounded average of ref and second_pred (compound
// prediction). Two 32-wide rows are handled per iteration, so
// second_pred advances 64 bytes per iteration and h is assumed even.
#define FSADAVG32_H(h)                                                        \
  unsigned int aom_sad32x##h##_avg_avx2(                                      \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,         \
      int ref_stride, const uint8_t *second_pred) {                           \
    int i;                                                                    \
    __m256i sad1_reg, sad2_reg, ref1_reg, ref2_reg;                           \
    __m256i sum_sad = _mm256_setzero_si256();                                 \
    __m256i sum_sad_h;                                                        \
    __m128i sum_sad128;                                                       \
    int ref2_stride = ref_stride << 1;                                        \
    int src2_stride = src_stride << 1;                                        \
    int max = h >> 1;                                                         \
    for (i = 0; i < max; i++) {                                               \
      ref1_reg = _mm256_loadu_si256((__m256i const *)ref_ptr);                \
      ref2_reg = _mm256_loadu_si256((__m256i const *)(ref_ptr + ref_stride)); \
      ref1_reg = _mm256_avg_epu8(                                             \
          ref1_reg, _mm256_loadu_si256((__m256i const *)second_pred));        \
      ref2_reg = _mm256_avg_epu8(                                             \
          ref2_reg, _mm256_loadu_si256((__m256i const *)(second_pred + 32))); \
      sad1_reg = _mm256_sad_epu8(                                             \
          ref1_reg, _mm256_loadu_si256((__m256i const *)src_ptr));            \
      sad2_reg = _mm256_sad_epu8(                                             \
          ref2_reg,                                                           \
          _mm256_loadu_si256((__m256i const *)(src_ptr + src_stride)));       \
      sum_sad =                                                               \
          _mm256_add_epi32(sum_sad, _mm256_add_epi32(sad1_reg, sad2_reg));    \
      ref_ptr += ref2_stride;                                                 \
      src_ptr += src2_stride;                                                 \
      second_pred += 64;                                                      \
    }                                                                         \
    sum_sad_h = _mm256_srli_si256(sum_sad, 8);                                \
    sum_sad = _mm256_add_epi32(sum_sad, sum_sad_h);                           \
    sum_sad128 = _mm256_extracti128_si256(sum_sad, 1);                        \
    sum_sad128 = _mm_add_epi32(_mm256_castsi256_si128(sum_sad), sum_sad128);  \
    unsigned int res = (unsigned int)_mm_cvtsi128_si32(sum_sad128);           \
    _mm256_zeroupper();                                                       \
    return res;                                                               \
  }
201 
// Instantiate the exported compound-prediction (avg) SAD entry points
// for every supported 64-wide and 32-wide block height, then drop all
// helper macros so they do not leak past this section.
#define FSADAVG64 \
  FSADAVG64_H(64) \
  FSADAVG64_H(32)

#define FSADAVG32 \
  FSADAVG32_H(64) \
  FSADAVG32_H(32) \
  FSADAVG32_H(16)

/* clang-format off */
FSADAVG64
FSADAVG32
/* clang-format on */

#undef FSADAVG64
#undef FSADAVG32
#undef FSADAVG64_H
#undef FSADAVG32_H
220