xref: /aosp_15_r20/external/libaom/av1/encoder/x86/temporal_filter_avx2.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2019, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <assert.h>
13 #include <immintrin.h>
14 
15 #include "config/av1_rtcd.h"
16 #include "av1/encoder/encoder.h"
17 #include "av1/encoder/temporal_filter.h"
18 
// Row pitch, in uint16_t elements, of the frame_sse scratch buffer: BW
// entries for one block row followed by 2 spare entries.
#define SSE_STRIDE ((BW) + 2)
20 
21 DECLARE_ALIGNED(32, static const uint32_t, sse_bytemask[4][8]) = {
22   { 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0 },
23   { 0, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0 },
24   { 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0 },
25   { 0, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF }
26 };
27 
28 DECLARE_ALIGNED(32, static const uint8_t, shufflemask_16b[2][16]) = {
29   { 0, 1, 0, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 },
30   { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 10, 11, 10, 11 }
31 };
32 
// Sobel gradient helpers on 16-bit lanes. Each expands to a complete
// assignment statement INCLUDING the trailing ';', so call sites invoke them
// without a semicolon.
//
// CALC_X_GRADIENT: out = |AC + GI + 2 * DF|
//   (the caller passes A - C, G - I and D - F)
#define CALC_X_GRADIENT(AC, GI, DF, out)                         \
  (out) = _mm256_abs_epi16(_mm256_add_epi16(                     \
      _mm256_add_epi16((AC), (GI)), _mm256_slli_epi16((DF), 1)));

// CALC_Y_GRADIENT: out = |(AC - GI) + 2 * BH|
//   (the caller passes A + C, G + I and B - H)
#define CALC_Y_GRADIENT(AC, GI, BH, out)                         \
  (out) = _mm256_abs_epi16(_mm256_add_epi16(                     \
      _mm256_sub_epi16((AC), (GI)), _mm256_slli_epi16((BH), 1)));
40 
av1_estimate_noise_from_single_plane_avx2(const uint8_t * src,int height,int width,int stride,int edge_thresh)41 double av1_estimate_noise_from_single_plane_avx2(const uint8_t *src, int height,
42                                                  int width, int stride,
43                                                  int edge_thresh) {
44   int count = 0;
45   int64_t accum = 0;
46   // w32 stores width multiple of 32.
47   const int w32 = (width - 1) & ~0x1f;
48   const __m256i zero = _mm256_setzero_si256();
49   const __m256i edge_threshold = _mm256_set1_epi16(edge_thresh);
50   __m256i num_accumulator = zero;
51   __m256i sum_accumulator = zero;
52 
53   //  A | B | C
54   //  D | E | F
55   //  G | H | I
56   // g_x = (A - C) + (G - I) + 2*(D - F)
57   // g_y = (A + C) - (G + I) + 2*(B - H)
58   // v   = 4*E - 2*(D+F+B+H) + (A+C+G+I)
59 
60   // Process the width multiple of 32 here.
61   for (int w = 1; w < w32; w += 32) {
62     int h = 1;
63     const int start_idx = h * stride + w;
64     const int stride_0 = start_idx - stride;
65 
66     __m256i num_accum_row_lvl = zero;
67     const __m256i A = _mm256_loadu_si256((__m256i *)(&src[stride_0 - 1]));
68     const __m256i C = _mm256_loadu_si256((__m256i *)(&src[stride_0 + 1]));
69     const __m256i D = _mm256_loadu_si256((__m256i *)(&src[start_idx - 1]));
70     const __m256i F = _mm256_loadu_si256((__m256i *)(&src[start_idx + 1]));
71     __m256i B = _mm256_loadu_si256((__m256i *)(&src[stride_0]));
72     __m256i E = _mm256_loadu_si256((__m256i *)(&src[start_idx]));
73 
74     const __m256i A_lo = _mm256_unpacklo_epi8(A, zero);
75     const __m256i A_hi = _mm256_unpackhi_epi8(A, zero);
76     const __m256i C_lo = _mm256_unpacklo_epi8(C, zero);
77     const __m256i C_hi = _mm256_unpackhi_epi8(C, zero);
78     const __m256i D_lo = _mm256_unpacklo_epi8(D, zero);
79     const __m256i D_hi = _mm256_unpackhi_epi8(D, zero);
80     const __m256i F_lo = _mm256_unpacklo_epi8(F, zero);
81     const __m256i F_hi = _mm256_unpackhi_epi8(F, zero);
82 
83     __m256i sub_AC_lo = _mm256_sub_epi16(A_lo, C_lo);
84     __m256i sub_AC_hi = _mm256_sub_epi16(A_hi, C_hi);
85     __m256i sum_AC_lo = _mm256_add_epi16(A_lo, C_lo);
86     __m256i sum_AC_hi = _mm256_add_epi16(A_hi, C_hi);
87     __m256i sub_DF_lo = _mm256_sub_epi16(D_lo, F_lo);
88     __m256i sub_DF_hi = _mm256_sub_epi16(D_hi, F_hi);
89     __m256i sum_DF_lo = _mm256_add_epi16(D_lo, F_lo);
90     __m256i sum_DF_hi = _mm256_add_epi16(D_hi, F_hi);
91 
92     for (; h < height - 1; h++) {
93       __m256i sum_GI_lo, sub_GI_lo, sum_GI_hi, sub_GI_hi, gx_lo, gy_lo, gx_hi,
94           gy_hi;
95       const int k = h * stride + w;
96       const __m256i G = _mm256_loadu_si256((__m256i *)(&src[k + stride - 1]));
97       const __m256i H = _mm256_loadu_si256((__m256i *)(&src[k + stride]));
98       const __m256i I = _mm256_loadu_si256((__m256i *)(&src[k + stride + 1]));
99 
100       const __m256i B_lo = _mm256_unpacklo_epi8(B, zero);
101       const __m256i B_hi = _mm256_unpackhi_epi8(B, zero);
102       const __m256i G_lo = _mm256_unpacklo_epi8(G, zero);
103       const __m256i G_hi = _mm256_unpackhi_epi8(G, zero);
104       const __m256i I_lo = _mm256_unpacklo_epi8(I, zero);
105       const __m256i I_hi = _mm256_unpackhi_epi8(I, zero);
106       const __m256i H_lo = _mm256_unpacklo_epi8(H, zero);
107       const __m256i H_hi = _mm256_unpackhi_epi8(H, zero);
108 
109       sub_GI_lo = _mm256_sub_epi16(G_lo, I_lo);
110       sub_GI_hi = _mm256_sub_epi16(G_hi, I_hi);
111       sum_GI_lo = _mm256_add_epi16(G_lo, I_lo);
112       sum_GI_hi = _mm256_add_epi16(G_hi, I_hi);
113       const __m256i sub_BH_lo = _mm256_sub_epi16(B_lo, H_lo);
114       const __m256i sub_BH_hi = _mm256_sub_epi16(B_hi, H_hi);
115 
116       CALC_X_GRADIENT(sub_AC_lo, sub_GI_lo, sub_DF_lo, gx_lo)
117       CALC_Y_GRADIENT(sum_AC_lo, sum_GI_lo, sub_BH_lo, gy_lo)
118 
119       const __m256i ga_lo = _mm256_add_epi16(gx_lo, gy_lo);
120 
121       CALC_X_GRADIENT(sub_AC_hi, sub_GI_hi, sub_DF_hi, gx_hi)
122       CALC_Y_GRADIENT(sum_AC_hi, sum_GI_hi, sub_BH_hi, gy_hi)
123 
124       const __m256i ga_hi = _mm256_add_epi16(gx_hi, gy_hi);
125 
126       __m256i cmp_lo = _mm256_cmpgt_epi16(edge_threshold, ga_lo);
127       __m256i cmp_hi = _mm256_cmpgt_epi16(edge_threshold, ga_hi);
128       const __m256i comp_reg = _mm256_add_epi16(cmp_lo, cmp_hi);
129 
130       // v = 4*E -2*(D+F+B+H) + (A+C+G+I)
131       if (_mm256_movemask_epi8(comp_reg) != 0) {
132         const __m256i sum_BH_lo = _mm256_add_epi16(B_lo, H_lo);
133         const __m256i sum_BH_hi = _mm256_add_epi16(B_hi, H_hi);
134 
135         // 2*(D+F+B+H)
136         const __m256i sum_DFBH_lo =
137             _mm256_slli_epi16(_mm256_add_epi16(sum_DF_lo, sum_BH_lo), 1);
138         // (A+C+G+I)
139         const __m256i sum_ACGI_lo = _mm256_add_epi16(sum_AC_lo, sum_GI_lo);
140         const __m256i sum_DFBH_hi =
141             _mm256_slli_epi16(_mm256_add_epi16(sum_DF_hi, sum_BH_hi), 1);
142         const __m256i sum_ACGI_hi = _mm256_add_epi16(sum_AC_hi, sum_GI_hi);
143 
144         // Convert E register values from 8bit to 16bit
145         const __m256i E_lo = _mm256_unpacklo_epi8(E, zero);
146         const __m256i E_hi = _mm256_unpackhi_epi8(E, zero);
147 
148         // 4*E - 2*(D+F+B+H)+ (A+C+G+I)
149         const __m256i var_lo_0 = _mm256_abs_epi16(_mm256_add_epi16(
150             _mm256_sub_epi16(_mm256_slli_epi16(E_lo, 2), sum_DFBH_lo),
151             sum_ACGI_lo));
152         const __m256i var_hi_0 = _mm256_abs_epi16(_mm256_add_epi16(
153             _mm256_sub_epi16(_mm256_slli_epi16(E_hi, 2), sum_DFBH_hi),
154             sum_ACGI_hi));
155         cmp_lo = _mm256_srli_epi16(cmp_lo, 15);
156         cmp_hi = _mm256_srli_epi16(cmp_hi, 15);
157         const __m256i var_lo = _mm256_mullo_epi16(var_lo_0, cmp_lo);
158         const __m256i var_hi = _mm256_mullo_epi16(var_hi_0, cmp_hi);
159 
160         num_accum_row_lvl = _mm256_add_epi16(num_accum_row_lvl, cmp_lo);
161         num_accum_row_lvl = _mm256_add_epi16(num_accum_row_lvl, cmp_hi);
162 
163         sum_accumulator = _mm256_add_epi32(sum_accumulator,
164                                            _mm256_unpacklo_epi16(var_lo, zero));
165         sum_accumulator = _mm256_add_epi32(sum_accumulator,
166                                            _mm256_unpackhi_epi16(var_lo, zero));
167         sum_accumulator = _mm256_add_epi32(sum_accumulator,
168                                            _mm256_unpacklo_epi16(var_hi, zero));
169         sum_accumulator = _mm256_add_epi32(sum_accumulator,
170                                            _mm256_unpackhi_epi16(var_hi, zero));
171       }
172       sub_AC_lo = sub_DF_lo;
173       sub_AC_hi = sub_DF_hi;
174       sub_DF_lo = sub_GI_lo;
175       sub_DF_hi = sub_GI_hi;
176       sum_AC_lo = sum_DF_lo;
177       sum_AC_hi = sum_DF_hi;
178       sum_DF_lo = sum_GI_lo;
179       sum_DF_hi = sum_GI_hi;
180       B = E;
181       E = H;
182     }
183     const __m256i num_0 = _mm256_unpacklo_epi16(num_accum_row_lvl, zero);
184     const __m256i num_1 = _mm256_unpackhi_epi16(num_accum_row_lvl, zero);
185     num_accumulator =
186         _mm256_add_epi32(num_accumulator, _mm256_add_epi32(num_0, num_1));
187   }
188 
189   // Process the remaining width here.
190   for (int h = 1; h < height - 1; ++h) {
191     for (int w = w32 + 1; w < width - 1; ++w) {
192       const int k = h * stride + w;
193 
194       // Compute sobel gradients
195       const int g_x = (src[k - stride - 1] - src[k - stride + 1]) +
196                       (src[k + stride - 1] - src[k + stride + 1]) +
197                       2 * (src[k - 1] - src[k + 1]);
198       const int g_y = (src[k - stride - 1] - src[k + stride - 1]) +
199                       (src[k - stride + 1] - src[k + stride + 1]) +
200                       2 * (src[k - stride] - src[k + stride]);
201       const int ga = abs(g_x) + abs(g_y);
202 
203       if (ga < edge_thresh) {
204         // Find Laplacian
205         const int v =
206             4 * src[k] -
207             2 * (src[k - 1] + src[k + 1] + src[k - stride] + src[k + stride]) +
208             (src[k - stride - 1] + src[k - stride + 1] + src[k + stride - 1] +
209              src[k + stride + 1]);
210         accum += abs(v);
211         ++count;
212       }
213     }
214   }
215 
216   // s0 s1 n0 n1 s2 s3 n2 n3
217   __m256i sum_avx = _mm256_hadd_epi32(sum_accumulator, num_accumulator);
218   __m128i sum_avx_lo = _mm256_castsi256_si128(sum_avx);
219   __m128i sum_avx_hi = _mm256_extractf128_si256(sum_avx, 1);
220   // s0+s2 s1+s3 n0+n2 n1+n3
221   __m128i sum_avx_1 = _mm_add_epi32(sum_avx_lo, sum_avx_hi);
222   // s0+s2+s1+s3 n0+n2+n1+n3
223   __m128i result = _mm_add_epi32(_mm_srli_si128(sum_avx_1, 4), sum_avx_1);
224 
225   accum += _mm_cvtsi128_si32(result);
226   count += _mm_extract_epi32(result, 2);
227 
228   // If very few smooth pels, return -1 since the estimate is unreliable.
229   return (count < 16) ? -1.0 : (double)accum / (6 * count) * SQRT_PI_BY_2;
230 }
231 
get_squared_error_16x16_avx2(const uint8_t * frame1,const unsigned int stride,const uint8_t * frame2,const unsigned int stride2,const int block_width,const int block_height,uint16_t * frame_sse,const unsigned int sse_stride)232 static AOM_FORCE_INLINE void get_squared_error_16x16_avx2(
233     const uint8_t *frame1, const unsigned int stride, const uint8_t *frame2,
234     const unsigned int stride2, const int block_width, const int block_height,
235     uint16_t *frame_sse, const unsigned int sse_stride) {
236   (void)block_width;
237   const uint8_t *src1 = frame1;
238   const uint8_t *src2 = frame2;
239   uint16_t *dst = frame_sse;
240   for (int i = 0; i < block_height; i++) {
241     __m128i vf1_128, vf2_128;
242     __m256i vf1, vf2, vdiff1, vsqdiff1;
243 
244     vf1_128 = _mm_loadu_si128((__m128i *)(src1));
245     vf2_128 = _mm_loadu_si128((__m128i *)(src2));
246     vf1 = _mm256_cvtepu8_epi16(vf1_128);
247     vf2 = _mm256_cvtepu8_epi16(vf2_128);
248     vdiff1 = _mm256_sub_epi16(vf1, vf2);
249     vsqdiff1 = _mm256_mullo_epi16(vdiff1, vdiff1);
250 
251     _mm256_storeu_si256((__m256i *)(dst), vsqdiff1);
252     // Set zero to uninitialized memory to avoid uninitialized loads later
253     *(int *)(dst + 16) = _mm_cvtsi128_si32(_mm_setzero_si128());
254 
255     src1 += stride, src2 += stride2;
256     dst += sse_stride;
257   }
258 }
259 
get_squared_error_32x32_avx2(const uint8_t * frame1,const unsigned int stride,const uint8_t * frame2,const unsigned int stride2,const int block_width,const int block_height,uint16_t * frame_sse,const unsigned int sse_stride)260 static AOM_FORCE_INLINE void get_squared_error_32x32_avx2(
261     const uint8_t *frame1, const unsigned int stride, const uint8_t *frame2,
262     const unsigned int stride2, const int block_width, const int block_height,
263     uint16_t *frame_sse, const unsigned int sse_stride) {
264   (void)block_width;
265   const uint8_t *src1 = frame1;
266   const uint8_t *src2 = frame2;
267   uint16_t *dst = frame_sse;
268   for (int i = 0; i < block_height; i++) {
269     __m256i vsrc1, vsrc2, vmin, vmax, vdiff, vdiff1, vdiff2, vres1, vres2;
270 
271     vsrc1 = _mm256_loadu_si256((__m256i *)src1);
272     vsrc2 = _mm256_loadu_si256((__m256i *)src2);
273     vmax = _mm256_max_epu8(vsrc1, vsrc2);
274     vmin = _mm256_min_epu8(vsrc1, vsrc2);
275     vdiff = _mm256_subs_epu8(vmax, vmin);
276 
277     __m128i vtmp1 = _mm256_castsi256_si128(vdiff);
278     __m128i vtmp2 = _mm256_extracti128_si256(vdiff, 1);
279     vdiff1 = _mm256_cvtepu8_epi16(vtmp1);
280     vdiff2 = _mm256_cvtepu8_epi16(vtmp2);
281 
282     vres1 = _mm256_mullo_epi16(vdiff1, vdiff1);
283     vres2 = _mm256_mullo_epi16(vdiff2, vdiff2);
284     _mm256_storeu_si256((__m256i *)(dst), vres1);
285     _mm256_storeu_si256((__m256i *)(dst + 16), vres2);
286     // Set zero to uninitialized memory to avoid uninitialized loads later
287     *(int *)(dst + 32) = _mm_cvtsi128_si32(_mm_setzero_si128());
288 
289     src1 += stride;
290     src2 += stride2;
291     dst += sse_stride;
292   }
293 }
294 
xx_load_and_pad(uint16_t * src,int col,int block_width)295 static AOM_FORCE_INLINE __m256i xx_load_and_pad(uint16_t *src, int col,
296                                                 int block_width) {
297   __m128i v128tmp = _mm_loadu_si128((__m128i *)(src));
298   if (col == 0) {
299     // For the first column, replicate the first element twice to the left
300     v128tmp = _mm_shuffle_epi8(v128tmp, *(__m128i *)shufflemask_16b[0]);
301   }
302   if (col == block_width - 4) {
303     // For the last column, replicate the last element twice to the right
304     v128tmp = _mm_shuffle_epi8(v128tmp, *(__m128i *)shufflemask_16b[1]);
305   }
306   return _mm256_cvtepu16_epi32(v128tmp);
307 }
308 
xx_mask_and_hadd(__m256i vsum,int i)309 static AOM_FORCE_INLINE int32_t xx_mask_and_hadd(__m256i vsum, int i) {
310   // Mask the required 5 values inside the vector
311   __m256i vtmp = _mm256_and_si256(vsum, *(__m256i *)sse_bytemask[i]);
312   __m128i v128a, v128b;
313   // Extract 256b as two 128b registers A and B
314   v128a = _mm256_castsi256_si128(vtmp);
315   v128b = _mm256_extracti128_si256(vtmp, 1);
316   // A = [A0+B0, A1+B1, A2+B2, A3+B3]
317   v128a = _mm_add_epi32(v128a, v128b);
318   // B = [A2+B2, A3+B3, 0, 0]
319   v128b = _mm_srli_si128(v128a, 8);
320   // A = [A0+B0+A2+B2, A1+B1+A3+B3, X, X]
321   v128a = _mm_add_epi32(v128a, v128b);
322   // B = [A1+B1+A3+B3, 0, 0, 0]
323   v128b = _mm_srli_si128(v128a, 4);
324   // A = [A0+B0+A2+B2+A1+B1+A3+B3, X, X, X]
325   v128a = _mm_add_epi32(v128a, v128b);
326   return _mm_extract_epi32(v128a, 0);
327 }
328 
// AVX2 implementation of approx_exp(): approximates e^y for 8 floats at once
// by building the IEEE-754 bit pattern of the result directly from y.
static inline __m256 approx_exp_avx2(__m256 y) {
  // (1 << 23) / ln(2): scales y so its integer part lands in the exponent
  // field of a single-precision float.
  const float kLog2eMantissaScale = (1 << 23) / 0.69314718056f;
  // 127 is the IEEE-754 single-precision exponent bias; 60801 is a magic
  // number that controls the accuracy of the approximation.
  const int kExponentBias = 127 * (1 << 23) - 60801;

  const __m256 scaled = _mm256_mul_ps(y, _mm256_set1_ps(kLog2eMantissaScale));
  const __m256i bits = _mm256_add_epi32(_mm256_cvttps_epi32(scaled),
                                        _mm256_set1_epi32(kExponentBias));
  return _mm256_castsi256_ps(bits);
}
345 
apply_temporal_filter(const uint8_t * frame1,const unsigned int stride,const uint8_t * frame2,const unsigned int stride2,const int block_width,const int block_height,const int * subblock_mses,unsigned int * accumulator,uint16_t * count,uint16_t * frame_sse,uint32_t * luma_sse_sum,const double inv_num_ref_pixels,const double decay_factor,const double inv_factor,const double weight_factor,double * d_factor,int tf_wgt_calc_lvl)346 static void apply_temporal_filter(
347     const uint8_t *frame1, const unsigned int stride, const uint8_t *frame2,
348     const unsigned int stride2, const int block_width, const int block_height,
349     const int *subblock_mses, unsigned int *accumulator, uint16_t *count,
350     uint16_t *frame_sse, uint32_t *luma_sse_sum,
351     const double inv_num_ref_pixels, const double decay_factor,
352     const double inv_factor, const double weight_factor, double *d_factor,
353     int tf_wgt_calc_lvl) {
354   assert(((block_width == 16) || (block_width == 32)) &&
355          ((block_height == 16) || (block_height == 32)));
356 
357   uint32_t acc_5x5_sse[BH][BW];
358 
359   if (block_width == 32) {
360     get_squared_error_32x32_avx2(frame1, stride, frame2, stride2, block_width,
361                                  block_height, frame_sse, SSE_STRIDE);
362   } else {
363     get_squared_error_16x16_avx2(frame1, stride, frame2, stride2, block_width,
364                                  block_height, frame_sse, SSE_STRIDE);
365   }
366 
367   __m256i vsrc[5];
368 
369   // Traverse 4 columns at a time
370   // First and last columns will require padding
371   for (int col = 0; col < block_width; col += 4) {
372     uint16_t *src = (col) ? frame_sse + col - 2 : frame_sse;
373 
374     // Load and pad(for first and last col) 3 rows from the top
375     for (int i = 2; i < 5; i++) {
376       vsrc[i] = xx_load_and_pad(src, col, block_width);
377       src += SSE_STRIDE;
378     }
379 
380     // Copy first row to first 2 vectors
381     vsrc[0] = vsrc[2];
382     vsrc[1] = vsrc[2];
383 
384     for (int row = 0; row < block_height; row++) {
385       __m256i vsum = _mm256_setzero_si256();
386 
387       // Add 5 consecutive rows
388       for (int i = 0; i < 5; i++) {
389         vsum = _mm256_add_epi32(vsum, vsrc[i]);
390       }
391 
392       // Push all elements by one element to the top
393       for (int i = 0; i < 4; i++) {
394         vsrc[i] = vsrc[i + 1];
395       }
396 
397       // Load next row to the last element
398       if (row <= block_height - 4) {
399         vsrc[4] = xx_load_and_pad(src, col, block_width);
400         src += SSE_STRIDE;
401       } else {
402         vsrc[4] = vsrc[3];
403       }
404 
405       // Accumulate the sum horizontally
406       for (int i = 0; i < 4; i++) {
407         acc_5x5_sse[row][col + i] = xx_mask_and_hadd(vsum, i);
408       }
409     }
410   }
411 
412   double subblock_mses_scaled[4];
413   double d_factor_decayed[4];
414   for (int idx = 0; idx < 4; idx++) {
415     subblock_mses_scaled[idx] = subblock_mses[idx] * inv_factor;
416     d_factor_decayed[idx] = d_factor[idx] * decay_factor;
417   }
418   if (tf_wgt_calc_lvl == 0) {
419     for (int i = 0, k = 0; i < block_height; i++) {
420       const int y_blk_raster_offset = (i >= block_height / 2) * 2;
421       for (int j = 0; j < block_width; j++, k++) {
422         const int pixel_value = frame2[i * stride2 + j];
423         uint32_t diff_sse = acc_5x5_sse[i][j] + luma_sse_sum[i * BW + j];
424 
425         const double window_error = diff_sse * inv_num_ref_pixels;
426         const int subblock_idx = y_blk_raster_offset + (j >= block_width / 2);
427         const double combined_error =
428             weight_factor * window_error + subblock_mses_scaled[subblock_idx];
429 
430         double scaled_error = combined_error * d_factor_decayed[subblock_idx];
431         scaled_error = AOMMIN(scaled_error, 7);
432         const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);
433 
434         count[k] += weight;
435         accumulator[k] += weight * pixel_value;
436       }
437     }
438   } else {
439     __m256d subblock_mses_reg[4];
440     __m256d d_factor_mul_n_decay_qr_invs[4];
441     const __m256 zero = _mm256_set1_ps(0.0f);
442     const __m256 point_five = _mm256_set1_ps(0.5f);
443     const __m256 seven = _mm256_set1_ps(7.0f);
444     const __m256d inv_num_ref_pixel_256bit = _mm256_set1_pd(inv_num_ref_pixels);
445     const __m256d weight_factor_256bit = _mm256_set1_pd(weight_factor);
446     const __m256 tf_weight_scale = _mm256_set1_ps((float)TF_WEIGHT_SCALE);
447     // Maintain registers to hold mse and d_factor at subblock level.
448     subblock_mses_reg[0] = _mm256_set1_pd(subblock_mses_scaled[0]);
449     subblock_mses_reg[1] = _mm256_set1_pd(subblock_mses_scaled[1]);
450     subblock_mses_reg[2] = _mm256_set1_pd(subblock_mses_scaled[2]);
451     subblock_mses_reg[3] = _mm256_set1_pd(subblock_mses_scaled[3]);
452     d_factor_mul_n_decay_qr_invs[0] = _mm256_set1_pd(d_factor_decayed[0]);
453     d_factor_mul_n_decay_qr_invs[1] = _mm256_set1_pd(d_factor_decayed[1]);
454     d_factor_mul_n_decay_qr_invs[2] = _mm256_set1_pd(d_factor_decayed[2]);
455     d_factor_mul_n_decay_qr_invs[3] = _mm256_set1_pd(d_factor_decayed[3]);
456 
457     for (int i = 0; i < block_height; i++) {
458       const int y_blk_raster_offset = (i >= block_height / 2) * 2;
459       uint32_t *luma_sse_sum_temp = luma_sse_sum + i * BW;
460       for (int j = 0; j < block_width; j += 8) {
461         const __m256i acc_sse =
462             _mm256_lddqu_si256((__m256i *)(acc_5x5_sse[i] + j));
463         const __m256i luma_sse =
464             _mm256_lddqu_si256((__m256i *)((luma_sse_sum_temp + j)));
465 
466         // uint32_t diff_sse = acc_5x5_sse[i][j] + luma_sse_sum[i * BW + j];
467         const __m256i diff_sse = _mm256_add_epi32(acc_sse, luma_sse);
468 
469         const __m256d diff_sse_pd_1 =
470             _mm256_cvtepi32_pd(_mm256_castsi256_si128(diff_sse));
471         const __m256d diff_sse_pd_2 =
472             _mm256_cvtepi32_pd(_mm256_extracti128_si256(diff_sse, 1));
473 
474         // const double window_error = diff_sse * inv_num_ref_pixels;
475         const __m256d window_error_1 =
476             _mm256_mul_pd(diff_sse_pd_1, inv_num_ref_pixel_256bit);
477         const __m256d window_error_2 =
478             _mm256_mul_pd(diff_sse_pd_2, inv_num_ref_pixel_256bit);
479 
480         // const int subblock_idx = y_blk_raster_offset + (j >= block_width /
481         // 2);
482         const int subblock_idx = y_blk_raster_offset + (j >= block_width / 2);
483         const __m256d blk_error = subblock_mses_reg[subblock_idx];
484 
485         // const double combined_error =
486         // weight_factor *window_error + subblock_mses_scaled[subblock_idx];
487         const __m256d combined_error_1 = _mm256_add_pd(
488             _mm256_mul_pd(window_error_1, weight_factor_256bit), blk_error);
489 
490         const __m256d combined_error_2 = _mm256_add_pd(
491             _mm256_mul_pd(window_error_2, weight_factor_256bit), blk_error);
492 
493         // d_factor_decayed[subblock_idx]
494         const __m256d d_fact_mul_n_decay =
495             d_factor_mul_n_decay_qr_invs[subblock_idx];
496 
497         // double scaled_error = combined_error *
498         // d_factor_decayed[subblock_idx];
499         const __m256d scaled_error_1 =
500             _mm256_mul_pd(combined_error_1, d_fact_mul_n_decay);
501         const __m256d scaled_error_2 =
502             _mm256_mul_pd(combined_error_2, d_fact_mul_n_decay);
503 
504         const __m128 scaled_error_ps_1 = _mm256_cvtpd_ps(scaled_error_1);
505         const __m128 scaled_error_ps_2 = _mm256_cvtpd_ps(scaled_error_2);
506 
507         const __m256 scaled_error_ps = _mm256_insertf128_ps(
508             _mm256_castps128_ps256(scaled_error_ps_1), scaled_error_ps_2, 0x1);
509 
510         // scaled_error = AOMMIN(scaled_error, 7);
511         const __m256 scaled_diff_ps = _mm256_min_ps(scaled_error_ps, seven);
512         const __m256 minus_scaled_diff_ps = _mm256_sub_ps(zero, scaled_diff_ps);
513         // const int weight =
514         //(int)(approx_exp((float)-scaled_error) * TF_WEIGHT_SCALE + 0.5f);
515         const __m256 exp_result = approx_exp_avx2(minus_scaled_diff_ps);
516         const __m256 scale_weight_exp_result =
517             _mm256_mul_ps(exp_result, tf_weight_scale);
518         const __m256 round_result =
519             _mm256_add_ps(scale_weight_exp_result, point_five);
520         __m256i weights_in_32bit = _mm256_cvttps_epi32(round_result);
521 
522         __m128i weights_in_16bit =
523             _mm_packus_epi32(_mm256_castsi256_si128(weights_in_32bit),
524                              _mm256_extractf128_si256(weights_in_32bit, 0x1));
525 
526         // count[k] += weight;
527         // accumulator[k] += weight * pixel_value;
528         const int stride_idx = i * stride2 + j;
529         const __m128i count_array =
530             _mm_loadu_si128((__m128i *)(count + stride_idx));
531         _mm_storeu_si128((__m128i *)(count + stride_idx),
532                          _mm_add_epi16(count_array, weights_in_16bit));
533 
534         const __m256i accumulator_array =
535             _mm256_loadu_si256((__m256i *)(accumulator + stride_idx));
536         const __m128i pred_values =
537             _mm_loadl_epi64((__m128i *)(frame2 + stride_idx));
538 
539         const __m256i pred_values_u32 = _mm256_cvtepu8_epi32(pred_values);
540         const __m256i mull_frame2_weight_u32 =
541             _mm256_mullo_epi32(pred_values_u32, weights_in_32bit);
542         _mm256_storeu_si256(
543             (__m256i *)(accumulator + stride_idx),
544             _mm256_add_epi32(accumulator_array, mull_frame2_weight_u32));
545       }
546     }
547   }
548 }
549 
av1_apply_temporal_filter_avx2(const YV12_BUFFER_CONFIG * frame_to_filter,const MACROBLOCKD * mbd,const BLOCK_SIZE block_size,const int mb_row,const int mb_col,const int num_planes,const double * noise_levels,const MV * subblock_mvs,const int * subblock_mses,const int q_factor,const int filter_strength,int tf_wgt_calc_lvl,const uint8_t * pred,uint32_t * accum,uint16_t * count)550 void av1_apply_temporal_filter_avx2(
551     const YV12_BUFFER_CONFIG *frame_to_filter, const MACROBLOCKD *mbd,
552     const BLOCK_SIZE block_size, const int mb_row, const int mb_col,
553     const int num_planes, const double *noise_levels, const MV *subblock_mvs,
554     const int *subblock_mses, const int q_factor, const int filter_strength,
555     int tf_wgt_calc_lvl, const uint8_t *pred, uint32_t *accum,
556     uint16_t *count) {
557   const int is_high_bitdepth = frame_to_filter->flags & YV12_FLAG_HIGHBITDEPTH;
558   assert(block_size == BLOCK_32X32 && "Only support 32x32 block with avx2!");
559   assert(TF_WINDOW_LENGTH == 5 && "Only support window length 5 with avx2!");
560   assert(!is_high_bitdepth && "Only support low bit-depth with avx2!");
561   assert(num_planes >= 1 && num_planes <= MAX_MB_PLANE);
562   (void)is_high_bitdepth;
563 
564   const int mb_height = block_size_high[block_size];
565   const int mb_width = block_size_wide[block_size];
566   const int frame_height = frame_to_filter->y_crop_height;
567   const int frame_width = frame_to_filter->y_crop_width;
568   const int min_frame_size = AOMMIN(frame_height, frame_width);
569   // Variables to simplify combined error calculation.
570   const double inv_factor = 1.0 / ((TF_WINDOW_BLOCK_BALANCE_WEIGHT + 1) *
571                                    TF_SEARCH_ERROR_NORM_WEIGHT);
572   const double weight_factor =
573       (double)TF_WINDOW_BLOCK_BALANCE_WEIGHT * inv_factor;
574   // Adjust filtering based on q.
575   // Larger q -> stronger filtering -> larger weight.
576   // Smaller q -> weaker filtering -> smaller weight.
577   double q_decay = pow((double)q_factor / TF_Q_DECAY_THRESHOLD, 2);
578   q_decay = CLIP(q_decay, 1e-5, 1);
579   if (q_factor >= TF_QINDEX_CUTOFF) {
580     // Max q_factor is 255, therefore the upper bound of q_decay is 8.
581     // We do not need a clip here.
582     q_decay = 0.5 * pow((double)q_factor / 64, 2);
583   }
584   // Smaller strength -> smaller filtering weight.
585   double s_decay = pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2);
586   s_decay = CLIP(s_decay, 1e-5, 1);
587   double d_factor[4] = { 0 };
588   uint16_t frame_sse[SSE_STRIDE * BH] = { 0 };
589   uint32_t luma_sse_sum[BW * BH] = { 0 };
590 
591   for (int subblock_idx = 0; subblock_idx < 4; subblock_idx++) {
592     // Larger motion vector -> smaller filtering weight.
593     const MV mv = subblock_mvs[subblock_idx];
594     const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2));
595     double distance_threshold = min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD;
596     distance_threshold = AOMMAX(distance_threshold, 1);
597     d_factor[subblock_idx] = distance / distance_threshold;
598     d_factor[subblock_idx] = AOMMAX(d_factor[subblock_idx], 1);
599   }
600 
601   // Handle planes in sequence.
602   int plane_offset = 0;
603   for (int plane = 0; plane < num_planes; ++plane) {
604     const uint32_t plane_h = mb_height >> mbd->plane[plane].subsampling_y;
605     const uint32_t plane_w = mb_width >> mbd->plane[plane].subsampling_x;
606     const uint32_t frame_stride = frame_to_filter->strides[plane == 0 ? 0 : 1];
607     const int frame_offset = mb_row * plane_h * frame_stride + mb_col * plane_w;
608 
609     const uint8_t *ref = frame_to_filter->buffers[plane] + frame_offset;
610     const int ss_x_shift =
611         mbd->plane[plane].subsampling_x - mbd->plane[AOM_PLANE_Y].subsampling_x;
612     const int ss_y_shift =
613         mbd->plane[plane].subsampling_y - mbd->plane[AOM_PLANE_Y].subsampling_y;
614     const int num_ref_pixels = TF_WINDOW_LENGTH * TF_WINDOW_LENGTH +
615                                ((plane) ? (1 << (ss_x_shift + ss_y_shift)) : 0);
616     const double inv_num_ref_pixels = 1.0 / num_ref_pixels;
617     // Larger noise -> larger filtering weight.
618     const double n_decay = 0.5 + log(2 * noise_levels[plane] + 5.0);
619     // Decay factors for non-local mean approach.
620     const double decay_factor = 1 / (n_decay * q_decay * s_decay);
621 
622     // Filter U-plane and V-plane using Y-plane. This is because motion
623     // search is only done on Y-plane, so the information from Y-plane
624     // will be more accurate. The luma sse sum is reused in both chroma
625     // planes.
626     if (plane == AOM_PLANE_U) {
627       for (unsigned int i = 0, k = 0; i < plane_h; i++) {
628         for (unsigned int j = 0; j < plane_w; j++, k++) {
629           for (int ii = 0; ii < (1 << ss_y_shift); ++ii) {
630             for (int jj = 0; jj < (1 << ss_x_shift); ++jj) {
631               const int yy = (i << ss_y_shift) + ii;  // Y-coord on Y-plane.
632               const int xx = (j << ss_x_shift) + jj;  // X-coord on Y-plane.
633               luma_sse_sum[i * BW + j] += frame_sse[yy * SSE_STRIDE + xx];
634             }
635           }
636         }
637       }
638     }
639 
640     apply_temporal_filter(ref, frame_stride, pred + plane_offset, plane_w,
641                           plane_w, plane_h, subblock_mses, accum + plane_offset,
642                           count + plane_offset, frame_sse, luma_sse_sum,
643                           inv_num_ref_pixels, decay_factor, inv_factor,
644                           weight_factor, d_factor, tf_wgt_calc_lvl);
645     plane_offset += plane_h * plane_w;
646   }
647 }
648