/*
 * xref: /aosp_15_r20/external/libaom/aom_dsp/arm/highbd_avg_pred_neon.c
 * (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
 */
/*
 * Copyright (c) 2023 The WebM project authors. All rights reserved.
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
12 
13 #include <arm_neon.h>
14 #include <assert.h>
15 
16 #include "config/aom_config.h"
17 #include "config/aom_dsp_rtcd.h"
18 
19 #include "aom_dsp/arm/blend_neon.h"
20 #include "aom_dsp/arm/dist_wtd_avg_neon.h"
21 #include "aom_dsp/arm/mem_neon.h"
22 #include "aom_dsp/blend.h"
23 
aom_highbd_comp_avg_pred_neon(uint8_t * comp_pred8,const uint8_t * pred8,int width,int height,const uint8_t * ref8,int ref_stride)24 void aom_highbd_comp_avg_pred_neon(uint8_t *comp_pred8, const uint8_t *pred8,
25                                    int width, int height, const uint8_t *ref8,
26                                    int ref_stride) {
27   const uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
28   const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
29   uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
30 
31   int i = height;
32   if (width > 8) {
33     do {
34       int j = 0;
35       do {
36         const uint16x8_t p = vld1q_u16(pred + j);
37         const uint16x8_t r = vld1q_u16(ref + j);
38 
39         uint16x8_t avg = vrhaddq_u16(p, r);
40         vst1q_u16(comp_pred + j, avg);
41 
42         j += 8;
43       } while (j < width);
44 
45       comp_pred += width;
46       pred += width;
47       ref += ref_stride;
48     } while (--i != 0);
49   } else if (width == 8) {
50     do {
51       const uint16x8_t p = vld1q_u16(pred);
52       const uint16x8_t r = vld1q_u16(ref);
53 
54       uint16x8_t avg = vrhaddq_u16(p, r);
55       vst1q_u16(comp_pred, avg);
56 
57       comp_pred += width;
58       pred += width;
59       ref += ref_stride;
60     } while (--i != 0);
61   } else {
62     assert(width == 4);
63     do {
64       const uint16x4_t p = vld1_u16(pred);
65       const uint16x4_t r = vld1_u16(ref);
66 
67       uint16x4_t avg = vrhadd_u16(p, r);
68       vst1_u16(comp_pred, avg);
69 
70       comp_pred += width;
71       pred += width;
72       ref += ref_stride;
73     } while (--i != 0);
74   }
75 }
76 
aom_highbd_comp_mask_pred_neon(uint8_t * comp_pred8,const uint8_t * pred8,int width,int height,const uint8_t * ref8,int ref_stride,const uint8_t * mask,int mask_stride,int invert_mask)77 void aom_highbd_comp_mask_pred_neon(uint8_t *comp_pred8, const uint8_t *pred8,
78                                     int width, int height, const uint8_t *ref8,
79                                     int ref_stride, const uint8_t *mask,
80                                     int mask_stride, int invert_mask) {
81   uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
82   uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
83   uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
84 
85   const uint16_t *src0 = invert_mask ? pred : ref;
86   const uint16_t *src1 = invert_mask ? ref : pred;
87   const int src_stride0 = invert_mask ? width : ref_stride;
88   const int src_stride1 = invert_mask ? ref_stride : width;
89 
90   if (width >= 8) {
91     do {
92       int j = 0;
93 
94       do {
95         const uint16x8_t s0 = vld1q_u16(src0 + j);
96         const uint16x8_t s1 = vld1q_u16(src1 + j);
97         const uint16x8_t m0 = vmovl_u8(vld1_u8(mask + j));
98 
99         uint16x8_t blend_u16 = alpha_blend_a64_u16x8(m0, s0, s1);
100 
101         vst1q_u16(comp_pred + j, blend_u16);
102 
103         j += 8;
104       } while (j < width);
105 
106       src0 += src_stride0;
107       src1 += src_stride1;
108       mask += mask_stride;
109       comp_pred += width;
110     } while (--height != 0);
111   } else {
112     assert(width == 4);
113 
114     do {
115       const uint16x4_t s0 = vld1_u16(src0);
116       const uint16x4_t s1 = vld1_u16(src1);
117       const uint16x4_t m0 = vget_low_u16(vmovl_u8(load_unaligned_u8_4x1(mask)));
118 
119       uint16x4_t blend_u16 = alpha_blend_a64_u16x4(m0, s0, s1);
120 
121       vst1_u16(comp_pred, blend_u16);
122 
123       src0 += src_stride0;
124       src1 += src_stride1;
125       mask += mask_stride;
126       comp_pred += 4;
127     } while (--height != 0);
128   }
129 }
130 
aom_highbd_dist_wtd_comp_avg_pred_neon(uint8_t * comp_pred8,const uint8_t * pred8,int width,int height,const uint8_t * ref8,int ref_stride,const DIST_WTD_COMP_PARAMS * jcp_param)131 void aom_highbd_dist_wtd_comp_avg_pred_neon(
132     uint8_t *comp_pred8, const uint8_t *pred8, int width, int height,
133     const uint8_t *ref8, int ref_stride,
134     const DIST_WTD_COMP_PARAMS *jcp_param) {
135   const uint16x8_t fwd_offset_u16 = vdupq_n_u16(jcp_param->fwd_offset);
136   const uint16x8_t bck_offset_u16 = vdupq_n_u16(jcp_param->bck_offset);
137   const uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
138   const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
139   uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8);
140 
141   if (width > 8) {
142     do {
143       int j = 0;
144       do {
145         const uint16x8_t p = vld1q_u16(pred + j);
146         const uint16x8_t r = vld1q_u16(ref + j);
147 
148         const uint16x8_t avg =
149             dist_wtd_avg_u16x8(r, p, fwd_offset_u16, bck_offset_u16);
150 
151         vst1q_u16(comp_pred + j, avg);
152 
153         j += 8;
154       } while (j < width);
155 
156       comp_pred += width;
157       pred += width;
158       ref += ref_stride;
159     } while (--height != 0);
160   } else if (width == 8) {
161     do {
162       const uint16x8_t p = vld1q_u16(pred);
163       const uint16x8_t r = vld1q_u16(ref);
164 
165       const uint16x8_t avg =
166           dist_wtd_avg_u16x8(r, p, fwd_offset_u16, bck_offset_u16);
167 
168       vst1q_u16(comp_pred, avg);
169 
170       comp_pred += width;
171       pred += width;
172       ref += ref_stride;
173     } while (--height != 0);
174   } else {
175     assert(width == 4);
176     do {
177       const uint16x4_t p = vld1_u16(pred);
178       const uint16x4_t r = vld1_u16(ref);
179 
180       const uint16x4_t avg = dist_wtd_avg_u16x4(
181           r, p, vget_low_u16(fwd_offset_u16), vget_low_u16(bck_offset_u16));
182 
183       vst1_u16(comp_pred, avg);
184 
185       comp_pred += width;
186       pred += width;
187       ref += ref_stride;
188     } while (--height != 0);
189   }
190 }
191