xref: /aosp_15_r20/external/libaom/aom_dsp/x86/jnt_variance_ssse3.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <assert.h>
13 #include <emmintrin.h>  // SSE2
14 #include <tmmintrin.h>
15 
16 #include "config/aom_config.h"
17 #include "config/aom_dsp_rtcd.h"
18 
19 #include "aom_dsp/x86/synonyms.h"
20 #include "aom_dsp/x86/variance_impl_ssse3.h"
21 
compute_dist_wtd_avg(__m128i * p0,__m128i * p1,const __m128i * w,const __m128i * r,void * const result)22 static inline void compute_dist_wtd_avg(__m128i *p0, __m128i *p1,
23                                         const __m128i *w, const __m128i *r,
24                                         void *const result) {
25   __m128i p_lo = _mm_unpacklo_epi8(*p0, *p1);
26   __m128i mult_lo = _mm_maddubs_epi16(p_lo, *w);
27   __m128i round_lo = _mm_add_epi16(mult_lo, *r);
28   __m128i shift_lo = _mm_srai_epi16(round_lo, DIST_PRECISION_BITS);
29 
30   __m128i p_hi = _mm_unpackhi_epi8(*p0, *p1);
31   __m128i mult_hi = _mm_maddubs_epi16(p_hi, *w);
32   __m128i round_hi = _mm_add_epi16(mult_hi, *r);
33   __m128i shift_hi = _mm_srai_epi16(round_hi, DIST_PRECISION_BITS);
34 
35   xx_storeu_128(result, _mm_packus_epi16(shift_lo, shift_hi));
36 }
37 
aom_dist_wtd_comp_avg_pred_ssse3(uint8_t * comp_pred,const uint8_t * pred,int width,int height,const uint8_t * ref,int ref_stride,const DIST_WTD_COMP_PARAMS * jcp_param)38 void aom_dist_wtd_comp_avg_pred_ssse3(uint8_t *comp_pred, const uint8_t *pred,
39                                       int width, int height, const uint8_t *ref,
40                                       int ref_stride,
41                                       const DIST_WTD_COMP_PARAMS *jcp_param) {
42   int i;
43   const int8_t w0 = (int8_t)jcp_param->fwd_offset;
44   const int8_t w1 = (int8_t)jcp_param->bck_offset;
45   const __m128i w = _mm_set_epi8(w1, w0, w1, w0, w1, w0, w1, w0, w1, w0, w1, w0,
46                                  w1, w0, w1, w0);
47   const int16_t round = (int16_t)((1 << DIST_PRECISION_BITS) >> 1);
48   const __m128i r = _mm_set1_epi16(round);
49 
50   if (width >= 16) {
51     // Read 16 pixels one row at a time
52     assert(!(width & 15));
53     for (i = 0; i < height; ++i) {
54       int j;
55       for (j = 0; j < width; j += 16) {
56         __m128i p0 = xx_loadu_128(ref);
57         __m128i p1 = xx_loadu_128(pred);
58 
59         compute_dist_wtd_avg(&p0, &p1, &w, &r, comp_pred);
60 
61         comp_pred += 16;
62         pred += 16;
63         ref += 16;
64       }
65       ref += ref_stride - width;
66     }
67   } else if (width >= 8) {
68     // Read 8 pixels two row at a time
69     assert(!(width & 7));
70     assert(!(width & 1));
71     for (i = 0; i < height; i += 2) {
72       __m128i p0_0 = xx_loadl_64(ref + 0 * ref_stride);
73       __m128i p0_1 = xx_loadl_64(ref + 1 * ref_stride);
74       __m128i p0 = _mm_unpacklo_epi64(p0_0, p0_1);
75       __m128i p1 = xx_loadu_128(pred);
76 
77       compute_dist_wtd_avg(&p0, &p1, &w, &r, comp_pred);
78 
79       comp_pred += 16;
80       pred += 16;
81       ref += 2 * ref_stride;
82     }
83   } else {
84     // Read 4 pixels four row at a time
85     assert(!(width & 3));
86     assert(!(height & 3));
87     for (i = 0; i < height; i += 4) {
88       const int8_t *row0 = (const int8_t *)ref + 0 * ref_stride;
89       const int8_t *row1 = (const int8_t *)ref + 1 * ref_stride;
90       const int8_t *row2 = (const int8_t *)ref + 2 * ref_stride;
91       const int8_t *row3 = (const int8_t *)ref + 3 * ref_stride;
92 
93       __m128i p0 =
94           _mm_setr_epi8(row0[0], row0[1], row0[2], row0[3], row1[0], row1[1],
95                         row1[2], row1[3], row2[0], row2[1], row2[2], row2[3],
96                         row3[0], row3[1], row3[2], row3[3]);
97       __m128i p1 = xx_loadu_128(pred);
98 
99       compute_dist_wtd_avg(&p0, &p1, &w, &r, comp_pred);
100 
101       comp_pred += 16;
102       pred += 16;
103       ref += 4 * ref_stride;
104     }
105   }
106 }
107 
/*
 * DIST_WTD_SUBPIX_AVG_VAR(W, H) defines
 * aom_dist_wtd_sub_pixel_avg_variance<W>x<H>_ssse3(), which:
 *   1. runs the bilinear first pass over `a` (pixel step 1, H + 1 output
 *      rows of 16-bit intermediates) with bilinear_filters_2t[xoffset];
 *   2. runs the second pass (pixel step W, H output rows of 8-bit pixels)
 *      with bilinear_filters_2t[yoffset];
 *   3. blends the filtered block with `second_pred` through
 *      aom_dist_wtd_comp_avg_pred_ssse3() using jcp_param;
 *   4. returns aom_variance<W>x<H>() of that blend against `b`, forwarding
 *      `sse` to it.
 * W and H must be integer literals: they are token-pasted into identifiers.
 */
#define DIST_WTD_SUBPIX_AVG_VAR(W, H)                                         \
  uint32_t aom_dist_wtd_sub_pixel_avg_variance##W##x##H##_ssse3(              \
      const uint8_t *a, int a_stride, int xoffset, int yoffset,               \
      const uint8_t *b, int b_stride, uint32_t *sse,                          \
      const uint8_t *second_pred, const DIST_WTD_COMP_PARAMS *jcp_param) {    \
    DECLARE_ALIGNED(16, uint8_t, blended[(H) * (W)]);                         \
    uint8_t filtered[(H) * (W)];                                              \
    uint16_t first_pass[((H) + 1) * (W)];                                     \
                                                                              \
    /* Pass 1: pixel step 1, H + 1 rows of 16-bit intermediates. */           \
    aom_var_filter_block2d_bil_first_pass_ssse3(a, first_pass, a_stride, 1,   \
                                                (H) + 1, W,                   \
                                                bilinear_filters_2t[xoffset]); \
    /* Pass 2: pixel step W, producing the H x W 8-bit block. */              \
    aom_var_filter_block2d_bil_second_pass_ssse3(                             \
        first_pass, filtered, W, W, H, W, bilinear_filters_2t[yoffset]);      \
                                                                              \
    /* Distance-weighted blend with the second predictor. */                  \
    aom_dist_wtd_comp_avg_pred_ssse3(blended, second_pred, W, H, filtered, W, \
                                     jcp_param);                              \
                                                                              \
    return aom_variance##W##x##H(blended, W, b, b_stride, sse);               \
  }
127 
128 DIST_WTD_SUBPIX_AVG_VAR(128, 128)
129 DIST_WTD_SUBPIX_AVG_VAR(128, 64)
130 DIST_WTD_SUBPIX_AVG_VAR(64, 128)
131 DIST_WTD_SUBPIX_AVG_VAR(64, 64)
132 DIST_WTD_SUBPIX_AVG_VAR(64, 32)
133 DIST_WTD_SUBPIX_AVG_VAR(32, 64)
134 DIST_WTD_SUBPIX_AVG_VAR(32, 32)
135 DIST_WTD_SUBPIX_AVG_VAR(32, 16)
136 DIST_WTD_SUBPIX_AVG_VAR(16, 32)
137 DIST_WTD_SUBPIX_AVG_VAR(16, 16)
138 DIST_WTD_SUBPIX_AVG_VAR(16, 8)
139 DIST_WTD_SUBPIX_AVG_VAR(8, 16)
140 DIST_WTD_SUBPIX_AVG_VAR(8, 8)
141 DIST_WTD_SUBPIX_AVG_VAR(8, 4)
142 DIST_WTD_SUBPIX_AVG_VAR(4, 8)
143 DIST_WTD_SUBPIX_AVG_VAR(4, 4)
144 
145 #if !CONFIG_REALTIME_ONLY
146 DIST_WTD_SUBPIX_AVG_VAR(4, 16)
147 DIST_WTD_SUBPIX_AVG_VAR(16, 4)
148 DIST_WTD_SUBPIX_AVG_VAR(8, 32)
149 DIST_WTD_SUBPIX_AVG_VAR(32, 8)
150 DIST_WTD_SUBPIX_AVG_VAR(16, 64)
151 DIST_WTD_SUBPIX_AVG_VAR(64, 16)
152 #endif
153