xref: /aosp_15_r20/external/libgav1/src/dsp/arm/average_blend_neon.cc (revision 095378508e87ed692bf8dfeb34008b65b3735891)
1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/average_blend.h"
16 #include "src/utils/cpu.h"
17 
18 #if LIBGAV1_ENABLE_NEON
19 
20 #include <arm_neon.h>
21 
22 #include <cassert>
23 #include <cstddef>
24 #include <cstdint>
25 
26 #include "src/dsp/arm/common_neon.h"
27 #include "src/dsp/constants.h"
28 #include "src/dsp/dsp.h"
29 #include "src/utils/common.h"
30 
31 namespace libgav1 {
32 namespace dsp {
33 namespace {
34 
35 constexpr int kInterPostRoundBit =
36     kInterRoundBitsVertical - kInterRoundBitsCompoundVertical;
37 
38 }  // namespace
39 
40 namespace low_bitdepth {
41 namespace {
42 
AverageBlend8Row(const int16_t * LIBGAV1_RESTRICT prediction_0,const int16_t * LIBGAV1_RESTRICT prediction_1)43 inline uint8x8_t AverageBlend8Row(const int16_t* LIBGAV1_RESTRICT prediction_0,
44                                   const int16_t* LIBGAV1_RESTRICT
45                                       prediction_1) {
46   const int16x8_t pred0 = vld1q_s16(prediction_0);
47   const int16x8_t pred1 = vld1q_s16(prediction_1);
48   const int16x8_t res = vaddq_s16(pred0, pred1);
49   return vqrshrun_n_s16(res, kInterPostRoundBit + 1);
50 }
51 
AverageBlendLargeRow(const int16_t * LIBGAV1_RESTRICT prediction_0,const int16_t * LIBGAV1_RESTRICT prediction_1,const int width,uint8_t * LIBGAV1_RESTRICT dest)52 inline void AverageBlendLargeRow(const int16_t* LIBGAV1_RESTRICT prediction_0,
53                                  const int16_t* LIBGAV1_RESTRICT prediction_1,
54                                  const int width,
55                                  uint8_t* LIBGAV1_RESTRICT dest) {
56   int x = width;
57   do {
58     const int16x8_t pred_00 = vld1q_s16(prediction_0);
59     const int16x8_t pred_01 = vld1q_s16(prediction_1);
60     prediction_0 += 8;
61     prediction_1 += 8;
62     const int16x8_t res0 = vaddq_s16(pred_00, pred_01);
63     const uint8x8_t res_out0 = vqrshrun_n_s16(res0, kInterPostRoundBit + 1);
64     const int16x8_t pred_10 = vld1q_s16(prediction_0);
65     const int16x8_t pred_11 = vld1q_s16(prediction_1);
66     prediction_0 += 8;
67     prediction_1 += 8;
68     const int16x8_t res1 = vaddq_s16(pred_10, pred_11);
69     const uint8x8_t res_out1 = vqrshrun_n_s16(res1, kInterPostRoundBit + 1);
70     vst1q_u8(dest, vcombine_u8(res_out0, res_out1));
71     dest += 16;
72     x -= 16;
73   } while (x != 0);
74 }
75 
AverageBlend_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,const int width,const int height,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)76 void AverageBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
77                        const void* LIBGAV1_RESTRICT prediction_1,
78                        const int width, const int height,
79                        void* LIBGAV1_RESTRICT const dest,
80                        const ptrdiff_t dest_stride) {
81   auto* dst = static_cast<uint8_t*>(dest);
82   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
83   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
84   int y = height;
85 
86   if (width == 4) {
87     do {
88       const uint8x8_t result = AverageBlend8Row(pred_0, pred_1);
89       pred_0 += 8;
90       pred_1 += 8;
91 
92       StoreLo4(dst, result);
93       dst += dest_stride;
94       StoreHi4(dst, result);
95       dst += dest_stride;
96       y -= 2;
97     } while (y != 0);
98     return;
99   }
100 
101   if (width == 8) {
102     do {
103       vst1_u8(dst, AverageBlend8Row(pred_0, pred_1));
104       dst += dest_stride;
105       pred_0 += 8;
106       pred_1 += 8;
107 
108       vst1_u8(dst, AverageBlend8Row(pred_0, pred_1));
109       dst += dest_stride;
110       pred_0 += 8;
111       pred_1 += 8;
112 
113       y -= 2;
114     } while (y != 0);
115     return;
116   }
117 
118   do {
119     AverageBlendLargeRow(pred_0, pred_1, width, dst);
120     dst += dest_stride;
121     pred_0 += width;
122     pred_1 += width;
123 
124     AverageBlendLargeRow(pred_0, pred_1, width, dst);
125     dst += dest_stride;
126     pred_0 += width;
127     pred_1 += width;
128 
129     y -= 2;
130   } while (y != 0);
131 }
132 
Init8bpp()133 void Init8bpp() {
134   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
135   assert(dsp != nullptr);
136   dsp->average_blend = AverageBlend_NEON;
137 }
138 
139 }  // namespace
140 }  // namespace low_bitdepth
141 
142 #if LIBGAV1_MAX_BITDEPTH >= 10
143 namespace high_bitdepth {
144 namespace {
145 
AverageBlend8Row(const uint16_t * LIBGAV1_RESTRICT prediction_0,const uint16_t * LIBGAV1_RESTRICT prediction_1,const int32x4_t compound_offset,const uint16x8_t v_bitdepth)146 inline uint16x8_t AverageBlend8Row(
147     const uint16_t* LIBGAV1_RESTRICT prediction_0,
148     const uint16_t* LIBGAV1_RESTRICT prediction_1,
149     const int32x4_t compound_offset, const uint16x8_t v_bitdepth) {
150   const uint16x8_t pred0 = vld1q_u16(prediction_0);
151   const uint16x8_t pred1 = vld1q_u16(prediction_1);
152   const uint32x4_t pred_lo =
153       vaddl_u16(vget_low_u16(pred0), vget_low_u16(pred1));
154   const uint32x4_t pred_hi =
155       vaddl_u16(vget_high_u16(pred0), vget_high_u16(pred1));
156   const int32x4_t offset_lo =
157       vsubq_s32(vreinterpretq_s32_u32(pred_lo), compound_offset);
158   const int32x4_t offset_hi =
159       vsubq_s32(vreinterpretq_s32_u32(pred_hi), compound_offset);
160   const uint16x4_t res_lo = vqrshrun_n_s32(offset_lo, kInterPostRoundBit + 1);
161   const uint16x4_t res_hi = vqrshrun_n_s32(offset_hi, kInterPostRoundBit + 1);
162   return vminq_u16(vcombine_u16(res_lo, res_hi), v_bitdepth);
163 }
164 
AverageBlendLargeRow(const uint16_t * LIBGAV1_RESTRICT prediction_0,const uint16_t * LIBGAV1_RESTRICT prediction_1,const int width,uint16_t * LIBGAV1_RESTRICT dest,const int32x4_t compound_offset,const uint16x8_t v_bitdepth)165 inline void AverageBlendLargeRow(const uint16_t* LIBGAV1_RESTRICT prediction_0,
166                                  const uint16_t* LIBGAV1_RESTRICT prediction_1,
167                                  const int width,
168                                  uint16_t* LIBGAV1_RESTRICT dest,
169                                  const int32x4_t compound_offset,
170                                  const uint16x8_t v_bitdepth) {
171   int x = width;
172   do {
173     vst1q_u16(dest, AverageBlend8Row(prediction_0, prediction_1,
174                                      compound_offset, v_bitdepth));
175     prediction_0 += 8;
176     prediction_1 += 8;
177     dest += 8;
178 
179     vst1q_u16(dest, AverageBlend8Row(prediction_0, prediction_1,
180                                      compound_offset, v_bitdepth));
181     prediction_0 += 8;
182     prediction_1 += 8;
183     dest += 8;
184 
185     x -= 16;
186   } while (x != 0);
187 }
188 
AverageBlend_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,const int width,const int height,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)189 void AverageBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
190                        const void* LIBGAV1_RESTRICT prediction_1,
191                        const int width, const int height,
192                        void* LIBGAV1_RESTRICT const dest,
193                        const ptrdiff_t dest_stride) {
194   auto* dst = static_cast<uint16_t*>(dest);
195   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
196   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
197   int y = height;
198 
199   const ptrdiff_t dst_stride = dest_stride >> 1;
200   const int32x4_t compound_offset =
201       vdupq_n_s32(static_cast<int32_t>(kCompoundOffset + kCompoundOffset));
202   const uint16x8_t v_bitdepth = vdupq_n_u16((1 << kBitdepth10) - 1);
203   if (width == 4) {
204     do {
205       const uint16x8_t result =
206           AverageBlend8Row(pred_0, pred_1, compound_offset, v_bitdepth);
207       pred_0 += 8;
208       pred_1 += 8;
209 
210       vst1_u16(dst, vget_low_u16(result));
211       dst += dst_stride;
212       vst1_u16(dst, vget_high_u16(result));
213       dst += dst_stride;
214       y -= 2;
215     } while (y != 0);
216     return;
217   }
218 
219   if (width == 8) {
220     do {
221       vst1q_u16(dst,
222                 AverageBlend8Row(pred_0, pred_1, compound_offset, v_bitdepth));
223       dst += dst_stride;
224       pred_0 += 8;
225       pred_1 += 8;
226 
227       vst1q_u16(dst,
228                 AverageBlend8Row(pred_0, pred_1, compound_offset, v_bitdepth));
229       dst += dst_stride;
230       pred_0 += 8;
231       pred_1 += 8;
232 
233       y -= 2;
234     } while (y != 0);
235     return;
236   }
237 
238   do {
239     AverageBlendLargeRow(pred_0, pred_1, width, dst, compound_offset,
240                          v_bitdepth);
241     dst += dst_stride;
242     pred_0 += width;
243     pred_1 += width;
244 
245     AverageBlendLargeRow(pred_0, pred_1, width, dst, compound_offset,
246                          v_bitdepth);
247     dst += dst_stride;
248     pred_0 += width;
249     pred_1 += width;
250 
251     y -= 2;
252   } while (y != 0);
253 }
254 
Init10bpp()255 void Init10bpp() {
256   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
257   assert(dsp != nullptr);
258   dsp->average_blend = AverageBlend_NEON;
259 }
260 
261 }  // namespace
262 }  // namespace high_bitdepth
263 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
264 
AverageBlendInit_NEON()265 void AverageBlendInit_NEON() {
266   low_bitdepth::Init8bpp();
267 #if LIBGAV1_MAX_BITDEPTH >= 10
268   high_bitdepth::Init10bpp();
269 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
270 }
271 
272 }  // namespace dsp
273 }  // namespace libgav1
274 
275 #else   // !LIBGAV1_ENABLE_NEON
276 
277 namespace libgav1 {
278 namespace dsp {
279 
// NEON support is compiled out (LIBGAV1_ENABLE_NEON is false), so there are
// no kernels to register.
void AverageBlendInit_NEON() {
  // Intentionally empty.
}
281 
282 }  // namespace dsp
283 }  // namespace libgav1
284 #endif  // LIBGAV1_ENABLE_NEON
285