xref: /aosp_15_r20/external/libaom/av1/encoder/x86/rdopt_sse4.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <assert.h>
13 #include <smmintrin.h>
14 #include "aom_dsp/x86/synonyms.h"
15 
16 #include "config/av1_rtcd.h"
17 #include "av1/encoder/rdopt.h"
18 
19 // Process horizontal and vertical correlations in a 4x4 block of pixels.
20 // We actually use the 4x4 pixels to calculate correlations corresponding to
21 // the top-left 3x3 pixels, so this function must be called with 1x1 overlap,
22 // moving the window along/down by 3 pixels at a time.
horver_correlation_4x4(const int16_t * diff,int stride,__m128i * xy_sum_32,__m128i * xz_sum_32,__m128i * x_sum_32,__m128i * x2_sum_32)23 static inline void horver_correlation_4x4(const int16_t *diff, int stride,
24                                           __m128i *xy_sum_32,
25                                           __m128i *xz_sum_32, __m128i *x_sum_32,
26                                           __m128i *x2_sum_32) {
27   // Pixels in this 4x4   [ a b c d ]
28   // are referred to as:  [ e f g h ]
29   //                      [ i j k l ]
30   //                      [ m n o p ]
31 
32   const __m128i pixelsa = xx_loadu_2x64(&diff[0 * stride], &diff[2 * stride]);
33   const __m128i pixelsb = xx_loadu_2x64(&diff[1 * stride], &diff[3 * stride]);
34   // pixelsa = [d c b a l k j i] as i16
35   // pixelsb = [h g f e p o n m] as i16
36 
37   const __m128i slli_a = _mm_slli_epi64(pixelsa, 16);
38   const __m128i slli_b = _mm_slli_epi64(pixelsb, 16);
39   // slli_a = [c b a 0 k j i 0] as i16
40   // slli_b = [g f e 0 o n m 0] as i16
41 
42   const __m128i xy_madd_a = _mm_madd_epi16(pixelsa, slli_a);
43   const __m128i xy_madd_b = _mm_madd_epi16(pixelsb, slli_b);
44   // xy_madd_a = [bc+cd ab jk+kl ij] as i32
45   // xy_madd_b = [fg+gh ef no+op mn] as i32
46 
47   const __m128i xy32 = _mm_hadd_epi32(xy_madd_b, xy_madd_a);
48   // xy32 = [ab+bc+cd ij+jk+kl ef+fg+gh mn+no+op] as i32
49   *xy_sum_32 = _mm_add_epi32(*xy_sum_32, xy32);
50 
51   const __m128i xz_madd_a = _mm_madd_epi16(slli_a, slli_b);
52   // xz_madd_a = [bf+cg ae jn+ko im] i32
53 
54   const __m128i swap_b = _mm_srli_si128(slli_b, 8);
55   // swap_b = [0 0 0 0 g f e 0] as i16
56   const __m128i xz_madd_b = _mm_madd_epi16(slli_a, swap_b);
57   // xz_madd_b = [0 0 gk+fj ei] i32
58 
59   const __m128i xz32 = _mm_hadd_epi32(xz_madd_b, xz_madd_a);
60   // xz32 = [ae+bf+cg im+jn+ko 0 ei+fj+gk] i32
61   *xz_sum_32 = _mm_add_epi32(*xz_sum_32, xz32);
62 
63   // Now calculate the straight sums, x_sum += a+b+c+e+f+g+i+j+k
64   // (sum up every element in slli_a and swap_b)
65   const __m128i sum_slli_a = _mm_hadd_epi16(slli_a, slli_a);
66   const __m128i sum_slli_a32 = _mm_cvtepi16_epi32(sum_slli_a);
67   // sum_slli_a32 = [c+b a k+j i] as i32
68   const __m128i swap_b32 = _mm_cvtepi16_epi32(swap_b);
69   // swap_b32 = [g f e 0] as i32
70   *x_sum_32 = _mm_add_epi32(*x_sum_32, sum_slli_a32);
71   *x_sum_32 = _mm_add_epi32(*x_sum_32, swap_b32);
72   // sum = [c+b+g a+f k+j+e i] as i32
73 
74   // Also sum their squares
75   const __m128i slli_a_2 = _mm_madd_epi16(slli_a, slli_a);
76   const __m128i swap_b_2 = _mm_madd_epi16(swap_b, swap_b);
77   // slli_a_2 = [c2+b2 a2 k2+j2 i2]
78   // swap_b_2 = [0 0 g2+f2 e2]
79   const __m128i sum2 = _mm_hadd_epi32(slli_a_2, swap_b_2);
80   // sum2 = [0 g2+f2+e2 c2+b2+a2 k2+j2+i2]
81   *x2_sum_32 = _mm_add_epi32(*x2_sum_32, sum2);
82 }
83 
// Computes the horizontal and vertical correlation coefficients of a block of
// residuals (SSE4.1 implementation).
//
// *hcorr correlates each pixel with its right neighbour over the
// height x (width - 1) such pairs. *vcorr correlates each pixel with its
// lower neighbour over the (height - 1) x width such pairs. Each is computed
// as cov / sqrt(var_x * var_y), using sums normalised by the pair count.
// Negative results are clamped to 0. If either variance is not positive, the
// result is set to 1.0.
//
// Parameters:
//   diff   - residual samples, row-major, `stride` elements between rows.
//   stride - row pitch of `diff`, in int16 elements.
//   width, height - block dimensions. Each must be >= 4 and not a multiple
//            of 3 (any power of two >= 4 qualifies). The body is tiled in
//            steps of 3, and the tail code handles exactly 1 or 2 leftover
//            rows/columns.
//   hcorr  - output: horizontal correlation.
//   vcorr  - output: vertical correlation.
void av1_get_horver_correlation_full_sse4_1(const int16_t *diff, int stride,
                                            int width, int height, float *hcorr,
                                            float *vcorr) {
  // Notation used below:
  //   x - current pixel
  //   y - right neighbour pixel
  //   z - below neighbour pixel
  //   w - down-right neighbour pixel

  // The tiling/tail logic below is only correct under these conditions.
  // Violating them previously produced silently wrong statistics:
  //   - a multiple of 3 leaves three trailing lines, but only two are handled;
  //   - a dimension < 4 double-counts the first row/column.
  assert(width >= 4 && height >= 4);
  assert(width % 3 != 0 && height % 3 != 0);

  int64_t xy_sum = 0, xz_sum = 0;
  int64_t x_sum = 0, x2_sum = 0;

  // Body: process 4x4 tiles stepped by 3, so each tile contributes its
  // top-left 3x3 pixels. This excludes the final 1 or 2 rows and columns,
  // which are handled by the scalar tails below.
  for (int i = 0; i <= height - 4; i += 3) {
    // 32-bit lane accumulators, restarted for every row of tiles and folded
    // into the 64-bit totals after it, which bounds their growth.
    __m128i xy_sum_32 = _mm_setzero_si128();
    __m128i xz_sum_32 = _mm_setzero_si128();
    __m128i x_sum_32 = _mm_setzero_si128();
    __m128i x2_sum_32 = _mm_setzero_si128();
    for (int j = 0; j <= width - 4; j += 3) {
      horver_correlation_4x4(&diff[i * stride + j], stride, &xy_sum_32,
                             &xz_sum_32, &x_sum_32, &x2_sum_32);
    }

    int32_t xy_tmp[4], xz_tmp[4], x_tmp[4], x2_tmp[4];
    xx_storeu_128(xy_tmp, xy_sum_32);
    xx_storeu_128(xz_tmp, xz_sum_32);
    xx_storeu_128(x_tmp, x_sum_32);
    xx_storeu_128(x2_tmp, x2_sum_32);

    // Lane selection:
    //   - xy lane 0 holds products from each tile's 4th row. That row is the
    //     next tile row's first row, so it is skipped here.
    //   - xz lane 1 and x2 lane 3 are always zero.
    xy_sum += (int64_t)xy_tmp[3] + xy_tmp[2] + xy_tmp[1];
    xz_sum += (int64_t)xz_tmp[3] + xz_tmp[2] + xz_tmp[0];
    x_sum += (int64_t)x_tmp[3] + x_tmp[2] + x_tmp[1] + x_tmp[0];
    x2_sum += (int64_t)x2_tmp[2] + x2_tmp[1] + x2_tmp[0];
  }

  // x_sum now covers every pixel except the final 1-2 rows and 1-2 cols.
  int64_t x_finalrow = 0, x_finalcol = 0, x2_finalrow = 0, x2_finalcol = 0;

  // Remaining rows: with height % 3 != 0, exactly 1 or 2 rows are left.
  if (height % 3 == 1) {  // One row left: horizontal pairs only.
    const int16_t x0 = diff[(height - 1) * stride];
    x_sum += x0;
    x_finalrow += x0;
    x2_sum += x0 * x0;
    x2_finalrow += x0 * x0;
    for (int j = 0; j < width - 1; ++j) {
      const int16_t x = diff[(height - 1) * stride + j];
      const int16_t y = diff[(height - 1) * stride + j + 1];
      xy_sum += x * y;
      x_sum += y;
      x2_sum += y * y;
      x_finalrow += y;
      x2_finalrow += y * y;
    }
  } else {  // Two rows left.
    const int16_t x0 = diff[(height - 2) * stride];
    const int16_t z0 = diff[(height - 1) * stride];
    x_sum += x0 + z0;
    x2_sum += x0 * x0 + z0 * z0;
    x_finalrow += z0;
    x2_finalrow += z0 * z0;
    for (int j = 0; j < width - 1; ++j) {
      const int16_t x = diff[(height - 2) * stride + j];
      const int16_t y = diff[(height - 2) * stride + j + 1];
      const int16_t z = diff[(height - 1) * stride + j];
      const int16_t w = diff[(height - 1) * stride + j + 1];

      // Horizontal and vertical correlations for the penultimate row.
      xy_sum += x * y;
      xz_sum += x * z;

      // Only horizontal correlations for the final row.
      xy_sum += z * w;

      x_sum += y + w;
      x2_sum += y * y + w * w;
      x_finalrow += w;
      x2_finalrow += w * w;
    }
  }

  // Rows that the tail above already added to x_sum/x2_sum. The column tails
  // must not add those pixels again.
  const int rows_left = (height % 3 == 1) ? 1 : 2;

  // Remaining columns: exactly 1 or 2 are left.
  if (width % 3 == 1) {  // One column left: vertical pairs only.
    const int16_t x0 = diff[width - 1];
    x_sum += x0;
    x_finalcol += x0;
    x2_sum += x0 * x0;
    x2_finalcol += x0 * x0;
    for (int i = 0; i < height - 1; ++i) {
      const int16_t x = diff[i * stride + width - 1];
      const int16_t z = diff[(i + 1) * stride + width - 1];
      xz_sum += x * z;
      x_finalcol += z;
      x2_finalcol += z * z;
      // Row i + 1 is new to x_sum only if the row tail did not cover it.
      if (i < height - 1 - rows_left) {
        x_sum += z;
        x2_sum += z * z;
      }
    }
  } else {  // Two columns left.
    const int16_t x0 = diff[width - 2];
    const int16_t y0 = diff[width - 1];
    x_sum += x0 + y0;
    x2_sum += x0 * x0 + y0 * y0;
    x_finalcol += y0;
    x2_finalcol += y0 * y0;
    for (int i = 0; i < height - 1; ++i) {
      const int16_t x = diff[i * stride + width - 2];
      const int16_t y = diff[i * stride + width - 1];
      const int16_t z = diff[(i + 1) * stride + width - 2];
      const int16_t w = diff[(i + 1) * stride + width - 1];

      // Horizontal and vertical pairs for the penultimate column. When two
      // rows were left, the row tail already counted these for row
      // height - 2, so skip that iteration to avoid double counting.
      if (i < height - 2 || rows_left == 1) {
        xy_sum += x * y;
        xz_sum += x * z;
      }

      x_finalcol += w;
      x2_finalcol += w * w;
      // Row i + 1 is new to x_sum only if the row tail did not cover it.
      if (i < height - 1 - rows_left) {
        x_sum += z + w;
        x2_sum += z * z + w * w;
      }

      // Only vertical pairs for the final column.
      xz_sum += y * w;
    }
  }

  // Sums over the first row and first column. These are subtracted below to
  // get neighbour-only sums.
  int64_t x_firstrow = 0, x_firstcol = 0;
  int64_t x2_firstrow = 0, x2_firstcol = 0;
  for (int j = 0; j < width; ++j) {
    x_firstrow += diff[j];
    x2_firstrow += diff[j] * diff[j];
  }
  for (int i = 0; i < height; ++i) {
    x_firstcol += diff[i * stride];
    x2_firstcol += diff[i * stride] * diff[i * stride];
  }

  // x_sum/x2_sum now cover the whole block. Derive the sums over:
  //   - "current" pixels for horizontal pairs (excluding the last column);
  //   - "current" pixels for vertical pairs (excluding the last row);
  //   - right neighbours (excluding the first column);
  //   - lower neighbours (excluding the first row).
  const int64_t xhor_sum = x_sum - x_finalcol;
  const int64_t xver_sum = x_sum - x_finalrow;
  const int64_t y_sum = x_sum - x_firstcol;
  const int64_t z_sum = x_sum - x_firstrow;
  const int64_t x2hor_sum = x2_sum - x2_finalcol;
  const int64_t x2ver_sum = x2_sum - x2_finalrow;
  const int64_t y2_sum = x2_sum - x2_firstcol;
  const int64_t z2_sum = x2_sum - x2_firstrow;

  const float num_hor = (float)(height * (width - 1));
  const float num_ver = (float)((height - 1) * width);

  // Variances and covariances, each scaled by the pair count.
  const float xhor_var_n = x2hor_sum - (xhor_sum * xhor_sum) / num_hor;
  const float xver_var_n = x2ver_sum - (xver_sum * xver_sum) / num_ver;

  const float y_var_n = y2_sum - (y_sum * y_sum) / num_hor;
  const float z_var_n = z2_sum - (z_sum * z_sum) / num_ver;

  const float xy_var_n = xy_sum - (xhor_sum * y_sum) / num_hor;
  const float xz_var_n = xz_sum - (xver_sum * z_sum) / num_ver;

  if (xhor_var_n > 0 && y_var_n > 0) {
    *hcorr = xy_var_n / sqrtf(xhor_var_n * y_var_n);
    *hcorr = *hcorr < 0 ? 0 : *hcorr;
  } else {
    *hcorr = 1.0f;
  }
  if (xver_var_n > 0 && z_var_n > 0) {
    *vcorr = xz_var_n / sqrtf(xver_var_n * z_var_n);
    *vcorr = *vcorr < 0 ? 0 : *vcorr;
  } else {
    *vcorr = 1.0f;
  }
}
271