xref: /aosp_15_r20/external/libgav1/src/dsp/arm/intrapred_cfl_neon.cc (revision 095378508e87ed692bf8dfeb34008b65b3735891)
1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/intrapred_cfl.h"
16 #include "src/utils/cpu.h"
17 
18 #if LIBGAV1_ENABLE_NEON
19 
20 #include <arm_neon.h>
21 
22 #include <cassert>
23 #include <cstddef>
24 #include <cstdint>
25 
26 #include "src/dsp/arm/common_neon.h"
27 #include "src/dsp/constants.h"
28 #include "src/dsp/dsp.h"
29 #include "src/utils/common.h"
30 #include "src/utils/constants.h"
31 
32 namespace libgav1 {
33 namespace dsp {
34 
35 // Divide by the number of elements.
Average(const uint32_t sum,const int width,const int height)36 inline uint32_t Average(const uint32_t sum, const int width, const int height) {
37   return RightShiftWithRounding(sum, FloorLog2(width) + FloorLog2(height));
38 }
39 
// Subtract |val| from every element of the |width| x |height| region of |a|.
inline void BlockSubtract(const uint32_t val,
                          int16_t a[kCflLumaBufferStride][kCflLumaBufferStride],
                          const int width, const int height) {
  assert(val <= INT16_MAX);
  const int16x8_t val_v = vdupq_n_s16(static_cast<int16_t>(val));

  for (int y = 0; y < height; ++y) {
    switch (width) {
      case 4: {
        const int16x4_t row = vld1_s16(a[y]);
        vst1_s16(a[y], vsub_s16(row, vget_low_s16(val_v)));
        break;
      }
      case 8: {
        const int16x8_t row = vld1q_s16(a[y]);
        vst1q_s16(a[y], vsubq_s16(row, val_v));
        break;
      }
      case 16: {
        const int16x8_t row_lo = vld1q_s16(a[y]);
        const int16x8_t row_hi = vld1q_s16(a[y] + 8);
        vst1q_s16(a[y], vsubq_s16(row_lo, val_v));
        vst1q_s16(a[y] + 8, vsubq_s16(row_hi, val_v));
        break;
      }
      default:  // width == 32
        for (int x = 0; x < 32; x += 8) {
          const int16x8_t row = vld1q_s16(a[y] + x);
          vst1q_s16(a[y] + x, vsubq_s16(row, val_v));
        }
        break;
    }
  }
}
71 
72 namespace low_bitdepth {
73 namespace {
74 
// 4:2:0 CfL subsampler. Each output is the sum of a 2x2 luma block shifted
// left by 1 (i.e. 2x the 2x2 average). After the whole block is produced the
// block mean is subtracted from every element. Columns beyond
// |max_luma_width| replicate the 2x2 sum at the last visible column pair;
// rows beyond |max_luma_height| replicate the last visible row pair.
template <int block_width, int block_height>
void CflSubsampler420_NEON(
    int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
    const int max_luma_width, const int max_luma_height,
    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride) {
  const auto* src = static_cast<const uint8_t*>(source);
  uint32_t sum;
  if (block_width == 4) {
    // An 8-wide load covers the 8 luma columns feeding 4 chroma outputs.
    assert(max_luma_width >= 8);
    uint32x2_t running_sum = vdup_n_u32(0);

    for (int y = 0; y < block_height; ++y) {
      const uint8x8_t row0 = vld1_u8(src);
      const uint8x8_t row1 = vld1_u8(src + stride);

      // Horizontal pairwise add of row0, then accumulate row1 pairwise:
      // four 2x2 sums.
      uint16x4_t sum_row = vpadal_u8(vpaddl_u8(row0), row1);
      sum_row = vshl_n_u16(sum_row, 1);
      running_sum = vpadal_u16(running_sum, sum_row);
      vst1_s16(luma[y], vreinterpret_s16_u16(sum_row));

      // Stop advancing at the last visible row pair so it is replicated for
      // the remaining output rows.
      if (y << 1 < max_luma_height - 2) {
        src += stride << 1;
      }
    }

    sum = SumVector(running_sum);
  } else if (block_width == 8) {
    const uint16x8_t x_index = {0, 2, 4, 6, 8, 10, 12, 14};
    // Lanes at or beyond this luma column take the replicated boundary sum.
    const uint16x8_t x_max_index =
        vdupq_n_u16(max_luma_width == 8 ? max_luma_width - 2 : 16);
    const uint16x8_t x_mask = vcltq_u16(x_index, x_max_index);

    uint32x4_t running_sum = vdupq_n_u32(0);

    for (int y = 0; y < block_height; ++y) {
      const uint8x16_t row0 = vld1q_u8(src);
      const uint8x16_t row1 = vld1q_u8(src + stride);
      const uint16x8_t sum_row = vpadalq_u8(vpaddlq_u8(row0), row1);
      const uint16x8_t sum_row_shifted = vshlq_n_u16(sum_row, 1);

      // Broadcast the 2x2 sum at the max luma offset across the vector.
      const uint16x8_t max_luma_sum =
          vdupq_lane_u16(vget_low_u16(sum_row_shifted), 3);
      // Visible lanes keep their own sums; out-of-range lanes replicate.
      const uint16x8_t final_sum_row =
          vbslq_u16(x_mask, sum_row_shifted, max_luma_sum);
      vst1q_s16(luma[y], vreinterpretq_s16_u16(final_sum_row));

      running_sum = vpadalq_u16(running_sum, final_sum_row);

      if (y << 1 < max_luma_height - 2) {
        src += stride << 1;
      }
    }

    sum = SumVector(running_sum);
  } else /* block_width >= 16 */ {
    const uint16x8_t x_max_index = vdupq_n_u16(max_luma_width - 2);
    uint32x4_t running_sum = vdupq_n_u32(0);

    for (int y = 0; y < block_height; ++y) {
      // Compute the 2x2 sum at the max_luma offset with scalar loads.
      const uint8_t a00 = src[max_luma_width - 2];
      const uint8_t a01 = src[max_luma_width - 1];
      const uint8_t a10 = src[max_luma_width - 2 + stride];
      const uint8_t a11 = src[max_luma_width - 1 + stride];
      // Broadcast that sum for lanes beyond the visible width.
      const uint16x8_t max_luma_sum =
          vdupq_n_u16(static_cast<uint16_t>((a00 + a01 + a10 + a11) << 1));
      uint16x8_t x_index = {0, 2, 4, 6, 8, 10, 12, 14};

      ptrdiff_t src_x_offset = 0;
      for (int x = 0; x < block_width; x += 8, src_x_offset += 16) {
        const uint16x8_t x_mask = vcltq_u16(x_index, x_max_index);
        const uint8x16_t row0 = vld1q_u8(src + src_x_offset);
        const uint8x16_t row1 = vld1q_u8(src + src_x_offset + stride);
        const uint16x8_t sum_row = vpadalq_u8(vpaddlq_u8(row0), row1);
        const uint16x8_t sum_row_shifted = vshlq_n_u16(sum_row, 1);
        const uint16x8_t final_sum_row =
            vbslq_u16(x_mask, sum_row_shifted, max_luma_sum);
        vst1q_s16(luma[y] + x, vreinterpretq_s16_u16(final_sum_row));

        running_sum = vpadalq_u16(running_sum, final_sum_row);
        // Each 8-output step consumes 16 luma columns.
        x_index = vaddq_u16(x_index, vdupq_n_u16(16));
      }

      if (y << 1 < max_luma_height - 2) {
        src += stride << 1;
      }
    }
    sum = SumVector(running_sum);
  }

  const uint32_t average = Average(sum, block_width, block_height);
  BlockSubtract(average, luma, block_width, block_height);
}
171 
// 4:4:4 CfL subsampler. Each output is the luma sample shifted left by 3 (to
// gain precision); the block mean is then subtracted. Columns beyond
// |max_luma_width| replicate the last visible column; rows beyond
// |max_luma_height| replicate the last visible row (or row pair for width 4).
template <int block_width, int block_height>
void CflSubsampler444_NEON(
    int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
    const int max_luma_width, const int max_luma_height,
    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride) {
  const auto* src = static_cast<const uint8_t*>(source);
  uint32_t sum;
  if (block_width == 4) {
    assert(max_luma_width >= 4);
    assert(max_luma_height <= block_height);
    assert((max_luma_height % 2) == 0);
    uint32x4_t running_sum = vdupq_n_u32(0);
    uint8x8_t row = vdup_n_u8(0);

    uint16x8_t row_shifted;
    int y = 0;
    do {
      // Pack two 4-wide rows into one 8-lane vector.
      row = Load4<0>(src, row);
      row = Load4<1>(src + stride, row);
      // Stop advancing at the last visible row pair.
      if (y < (max_luma_height - 1)) {
        src += stride << 1;
      }

      row_shifted = vshll_n_u8(row, 3);
      running_sum = vpadalq_u16(running_sum, row_shifted);
      vst1_s16(luma[y], vreinterpret_s16_u16(vget_low_u16(row_shifted)));
      vst1_s16(luma[y + 1], vreinterpret_s16_u16(vget_high_u16(row_shifted)));
      y += 2;
    } while (y < max_luma_height);

    // Replicate the last visible row for the remaining output rows.
    row_shifted =
        vcombine_u16(vget_high_u16(row_shifted), vget_high_u16(row_shifted));
    for (; y < block_height; y += 2) {
      running_sum = vpadalq_u16(running_sum, row_shifted);
      vst1_s16(luma[y], vreinterpret_s16_u16(vget_low_u16(row_shifted)));
      vst1_s16(luma[y + 1], vreinterpret_s16_u16(vget_high_u16(row_shifted)));
    }

    sum = SumVector(running_sum);
  } else if (block_width == 8) {
    const uint8x8_t x_index = {0, 1, 2, 3, 4, 5, 6, 7};
    const uint8x8_t x_max_index = vdup_n_u8(max_luma_width - 1);
    const uint8x8_t x_mask = vclt_u8(x_index, x_max_index);

    uint32x4_t running_sum = vdupq_n_u32(0);

    for (int y = 0; y < block_height; ++y) {
      // Out-of-range lanes take the last visible sample of the row.
      const uint8x8_t x_max = vdup_n_u8(src[max_luma_width - 1]);
      const uint8x8_t row = vbsl_u8(x_mask, vld1_u8(src), x_max);

      const uint16x8_t row_shifted = vshll_n_u8(row, 3);
      running_sum = vpadalq_u16(running_sum, row_shifted);
      vst1q_s16(luma[y], vreinterpretq_s16_u16(row_shifted));

      if (y < max_luma_height - 1) {
        src += stride;
      }
    }

    sum = SumVector(running_sum);
  } else /* block_width >= 16 */ {
    const uint8x16_t x_max_index = vdupq_n_u8(max_luma_width - 1);
    uint32x4_t running_sum = vdupq_n_u32(0);

    for (int y = 0; y < block_height; ++y) {
      uint8x16_t x_index = {0, 1, 2,  3,  4,  5,  6,  7,
                            8, 9, 10, 11, 12, 13, 14, 15};
      const uint8x16_t x_max = vdupq_n_u8(src[max_luma_width - 1]);
      for (int x = 0; x < block_width; x += 16) {
        const uint8x16_t x_mask = vcltq_u8(x_index, x_max_index);
        const uint8x16_t row = vbslq_u8(x_mask, vld1q_u8(src + x), x_max);

        const uint16x8_t row_shifted_low = vshll_n_u8(vget_low_u8(row), 3);
        const uint16x8_t row_shifted_high = vshll_n_u8(vget_high_u8(row), 3);
        running_sum = vpadalq_u16(running_sum, row_shifted_low);
        running_sum = vpadalq_u16(running_sum, row_shifted_high);
        vst1q_s16(luma[y] + x, vreinterpretq_s16_u16(row_shifted_low));
        vst1q_s16(luma[y] + x + 8, vreinterpretq_s16_u16(row_shifted_high));

        x_index = vaddq_u8(x_index, vdupq_n_u8(16));
      }
      if (y < max_luma_height - 1) {
        src += stride;
      }
    }
    sum = SumVector(running_sum);
  }

  const uint32_t average = Average(sum, block_width, block_height);
  BlockSubtract(average, luma, block_width, block_height);
}
263 
// Saturate |dc + ((alpha * luma) >> 6)| to uint8_t, rounding the shift
// toward zero.
inline uint8x8_t Combine8(const int16x8_t luma, const int alpha,
                          const int16x8_t dc) {
  const int16x8_t scaled = vmulq_n_s16(luma, alpha);
  // Add the (arithmetic) sign bit, i.e. subtract 1 from negative values, so
  // the rounding shift below rounds toward zero instead of -infinity.
  const int16x8_t sub_sign = vsraq_n_s16(scaled, scaled, 15);
  // Rounding shift right by 6 and accumulate onto the DC value.
  const int16x8_t pred = vrsraq_n_s16(dc, sub_sign, 6);
  return vqmovun_s16(pred);
}
274 
// The exact ranges of luma/alpha do not matter here: the result saturates to
// uint8_t, and a saturated int16_t >> 6 already spans the uint8_t range.
template <int block_height>
inline void CflIntraPredictor4xN_NEON(
    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
    const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
    const int alpha) {
  auto* dst = static_cast<uint8_t*>(dest);
  // dst[0] holds the DC prediction value on entry; broadcast it.
  const int16x8_t dc = vdupq_n_s16(dst[0]);
  for (int y = 0; y < block_height; y += 2) {
    // Two 4-wide rows are processed together in one 8-lane vector.
    const int16x4_t row0 = vld1_s16(luma[y]);
    const int16x4_t row1 = vld1_s16(luma[y + 1]);
    const uint8x8_t pred = Combine8(vcombine_s16(row0, row1), alpha, dc);
    StoreLo4(dst, pred);
    dst += stride;
    StoreHi4(dst, pred);
    dst += stride;
  }
}
295 
// CfL prediction for 8-wide blocks: dst = clip(dc + ((alpha * luma) >> 6)).
template <int block_height>
inline void CflIntraPredictor8xN_NEON(
    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
    const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
    const int alpha) {
  auto* dst = static_cast<uint8_t*>(dest);
  // dst[0] holds the DC prediction value on entry; broadcast it.
  const int16x8_t dc = vdupq_n_s16(dst[0]);
  for (int y = 0; y < block_height; ++y) {
    const int16x8_t row = vld1q_s16(luma[y]);
    vst1_u8(dst, Combine8(row, alpha, dc));
    dst += stride;
  }
}
310 
// CfL prediction for 16-wide blocks: two 8-lane vectors per row.
template <int block_height>
inline void CflIntraPredictor16xN_NEON(
    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
    const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
    const int alpha) {
  auto* dst = static_cast<uint8_t*>(dest);
  // dst[0] holds the DC prediction value on entry; broadcast it.
  const int16x8_t dc = vdupq_n_s16(dst[0]);
  for (int y = 0; y < block_height; ++y) {
    const int16x8_t row_lo = vld1q_s16(luma[y]);
    const int16x8_t row_hi = vld1q_s16(luma[y] + 8);
    vst1_u8(dst, Combine8(row_lo, alpha, dc));
    vst1_u8(dst + 8, Combine8(row_hi, alpha, dc));
    dst += stride;
  }
}
328 
// CfL prediction for 32-wide blocks: four 8-lane vectors per row.
template <int block_height>
inline void CflIntraPredictor32xN_NEON(
    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
    const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
    const int alpha) {
  auto* dst = static_cast<uint8_t*>(dest);
  // dst[0] holds the DC prediction value on entry; broadcast it.
  const int16x8_t dc = vdupq_n_s16(dst[0]);
  for (int y = 0; y < block_height; ++y) {
    const int16x8_t row0 = vld1q_s16(luma[y]);
    const int16x8_t row1 = vld1q_s16(luma[y] + 8);
    const int16x8_t row2 = vld1q_s16(luma[y] + 16);
    const int16x8_t row3 = vld1q_s16(luma[y] + 24);
    vst1_u8(dst, Combine8(row0, alpha, dc));
    vst1_u8(dst + 8, Combine8(row1, alpha, dc));
    vst1_u8(dst + 16, Combine8(row2, alpha, dc));
    vst1_u8(dst + 24, Combine8(row3, alpha, dc));
    dst += stride;
  }
}
352 
Init8bpp()353 void Init8bpp() {
354   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
355   assert(dsp != nullptr);
356 
357   dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType420] =
358       CflSubsampler420_NEON<4, 4>;
359   dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType420] =
360       CflSubsampler420_NEON<4, 8>;
361   dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType420] =
362       CflSubsampler420_NEON<4, 16>;
363 
364   dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType420] =
365       CflSubsampler420_NEON<8, 4>;
366   dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType420] =
367       CflSubsampler420_NEON<8, 8>;
368   dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType420] =
369       CflSubsampler420_NEON<8, 16>;
370   dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType420] =
371       CflSubsampler420_NEON<8, 32>;
372 
373   dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType420] =
374       CflSubsampler420_NEON<16, 4>;
375   dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType420] =
376       CflSubsampler420_NEON<16, 8>;
377   dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType420] =
378       CflSubsampler420_NEON<16, 16>;
379   dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType420] =
380       CflSubsampler420_NEON<16, 32>;
381 
382   dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType420] =
383       CflSubsampler420_NEON<32, 8>;
384   dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType420] =
385       CflSubsampler420_NEON<32, 16>;
386   dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType420] =
387       CflSubsampler420_NEON<32, 32>;
388 
389   dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType444] =
390       CflSubsampler444_NEON<4, 4>;
391   dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType444] =
392       CflSubsampler444_NEON<4, 8>;
393   dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType444] =
394       CflSubsampler444_NEON<4, 16>;
395 
396   dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType444] =
397       CflSubsampler444_NEON<8, 4>;
398   dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType444] =
399       CflSubsampler444_NEON<8, 8>;
400   dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType444] =
401       CflSubsampler444_NEON<8, 16>;
402   dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType444] =
403       CflSubsampler444_NEON<8, 32>;
404 
405   dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType444] =
406       CflSubsampler444_NEON<16, 4>;
407   dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType444] =
408       CflSubsampler444_NEON<16, 8>;
409   dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType444] =
410       CflSubsampler444_NEON<16, 16>;
411   dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType444] =
412       CflSubsampler444_NEON<16, 32>;
413 
414   dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType444] =
415       CflSubsampler444_NEON<32, 8>;
416   dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType444] =
417       CflSubsampler444_NEON<32, 16>;
418   dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType444] =
419       CflSubsampler444_NEON<32, 32>;
420 
421   dsp->cfl_intra_predictors[kTransformSize4x4] = CflIntraPredictor4xN_NEON<4>;
422   dsp->cfl_intra_predictors[kTransformSize4x8] = CflIntraPredictor4xN_NEON<8>;
423   dsp->cfl_intra_predictors[kTransformSize4x16] = CflIntraPredictor4xN_NEON<16>;
424 
425   dsp->cfl_intra_predictors[kTransformSize8x4] = CflIntraPredictor8xN_NEON<4>;
426   dsp->cfl_intra_predictors[kTransformSize8x8] = CflIntraPredictor8xN_NEON<8>;
427   dsp->cfl_intra_predictors[kTransformSize8x16] = CflIntraPredictor8xN_NEON<16>;
428   dsp->cfl_intra_predictors[kTransformSize8x32] = CflIntraPredictor8xN_NEON<32>;
429 
430   dsp->cfl_intra_predictors[kTransformSize16x4] = CflIntraPredictor16xN_NEON<4>;
431   dsp->cfl_intra_predictors[kTransformSize16x8] = CflIntraPredictor16xN_NEON<8>;
432   dsp->cfl_intra_predictors[kTransformSize16x16] =
433       CflIntraPredictor16xN_NEON<16>;
434   dsp->cfl_intra_predictors[kTransformSize16x32] =
435       CflIntraPredictor16xN_NEON<32>;
436 
437   dsp->cfl_intra_predictors[kTransformSize32x8] = CflIntraPredictor32xN_NEON<8>;
438   dsp->cfl_intra_predictors[kTransformSize32x16] =
439       CflIntraPredictor32xN_NEON<16>;
440   dsp->cfl_intra_predictors[kTransformSize32x32] =
441       CflIntraPredictor32xN_NEON<32>;
442   // Max Cfl predictor size is 32x32.
443 }
444 
445 }  // namespace
446 }  // namespace low_bitdepth
447 
448 //------------------------------------------------------------------------------
449 #if LIBGAV1_MAX_BITDEPTH >= 10
450 namespace high_bitdepth {
451 namespace {
452 
453 //------------------------------------------------------------------------------
454 // CflSubsampler
#ifndef __aarch64__
// vpaddq_u16 is an A64-only intrinsic; emulate it on 32-bit ARM with two
// 64-bit pairwise adds.
uint16x8_t vpaddq_u16(uint16x8_t a, uint16x8_t b) {
  return vcombine_u16(vpadd_u16(vget_low_u16(a), vget_high_u16(a)),
                      vpadd_u16(vget_low_u16(b), vget_high_u16(b)));
}
#endif
461 
// Returns |row| with its last two 16-bit values duplicated across all lanes.
inline uint16x8_t LastRowSamples(const uint16x8_t row) {
  // Treat the last two u16 lanes as one u32 lane and broadcast it.
  const uint32x2_t last_pair = vget_high_u32(vreinterpretq_u32_u16(row));
  const uint32x4_t dup = vdupq_lane_u32(last_pair, 1);
  return vreinterpretq_u16_u32(dup);
}
468 
// Returns |row| with its last unsigned 16-bit value duplicated in all lanes.
inline uint16x8_t LastRowResult(const uint16x8_t row) {
  const uint16x4_t high = vget_high_u16(row);
  return vdupq_lane_u16(high, 0x3);
}
475 
// Returns |row| with its last signed 16-bit value duplicated in all lanes.
inline int16x8_t LastRowResult(const int16x8_t row) {
  const int16x4_t high = vget_high_s16(row);
  return vdupq_lane_s16(high, 0x3);
}
482 
// Takes in two sums of input row pairs and completes the computation for two
// 4-wide output rows, storing them at |luma_ptr| and the row below it.
// Returns the doubled pairwise sums for further accumulation.
inline uint16x8_t StoreLumaResults4_420(const uint16x8_t vertical_sum0,
                                        const uint16x8_t vertical_sum1,
                                        int16_t* luma_ptr) {
  // Horizontal pairwise add finishes the 2x2 sums for both rows at once.
  const uint16x8_t pair_sums = vpaddq_u16(vertical_sum0, vertical_sum1);
  const uint16x8_t doubled = vshlq_n_u16(pair_sums, 1);
  vst1_s16(luma_ptr, vreinterpret_s16_u16(vget_low_u16(doubled)));
  vst1_s16(luma_ptr + kCflLumaBufferStride,
           vreinterpret_s16_u16(vget_high_u16(doubled)));
  return doubled;
}
495 
// Takes the two halves of a vertically added pair of rows and completes the
// computation for one 8-wide output row stored at |luma_ptr|. Returns the
// doubled pairwise sums for further accumulation.
inline uint16x8_t StoreLumaResults8_420(const uint16x8_t vertical_sum0,
                                        const uint16x8_t vertical_sum1,
                                        int16_t* luma_ptr) {
  const uint16x8_t pair_sums = vpaddq_u16(vertical_sum0, vertical_sum1);
  const uint16x8_t doubled = vshlq_n_u16(pair_sums, 1);
  vst1q_s16(luma_ptr, vreinterpretq_s16_u16(doubled));
  return doubled;
}
506 
// High-bitdepth 4:4:4 subsampler for 4-wide blocks. When |is_inside| is false
// the bottom rows beyond |max_luma_height| are synthesized by replicating the
// last visible row.
template <int block_height_log2, bool is_inside>
void CflSubsampler444_4xH_NEON(
    int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
    const int max_luma_height, const void* LIBGAV1_RESTRICT const source,
    ptrdiff_t stride) {
  static_assert(block_height_log2 <= 4, "");
  const int block_height = 1 << block_height_log2;
  const int visible_height = max_luma_height;
  const auto* src = static_cast<const uint16_t*>(source);
  const ptrdiff_t src_stride = stride / sizeof(src[0]);
  int16_t* luma_ptr = luma[0];
  uint16x4_t sum = vdup_n_u16(0);
  uint16x4_t samples[2];
  int y = visible_height;

  // First pass: accumulate the sum over all visible rows, two at a time.
  do {
    samples[0] = vld1_u16(src);
    samples[1] = vld1_u16(src + src_stride);
    src += src_stride << 1;
    sum = vadd_u16(sum, samples[0]);
    sum = vadd_u16(sum, samples[1]);
    y -= 2;
  } while (y != 0);

  if (!is_inside) {
    y = visible_height;
    // Each padded iteration stands for two replicated rows, so double the
    // last row before accumulating.
    samples[1] = vshl_n_u16(samples[1], 1);
    do {
      sum = vadd_u16(sum, samples[1]);
      y += 2;
    } while (y < block_height);
  }

  // The left shift by 3 (added precision) cancels against the right shift by
  // ((log2 of width 4) + 1), leaving block_height_log2 - 1.
  const uint32_t average_sum =
      RightShiftWithRounding(SumVector(vpaddl_u16(sum)), block_height_log2 - 1);
  const int16x4_t averages = vdup_n_s16(static_cast<int16_t>(average_sum));

  // Second pass: write (sample << 3) - average for every visible row.
  const auto* ssrc = static_cast<const int16_t*>(source);
  int16x4_t ssample;
  luma_ptr = luma[0];
  y = visible_height;
  do {
    ssample = vld1_s16(ssrc);
    ssample = vshl_n_s16(ssample, 3);
    vst1_s16(luma_ptr, vsub_s16(ssample, averages));
    ssrc += src_stride;
    luma_ptr += kCflLumaBufferStride;
  } while (--y != 0);

  if (!is_inside) {
    y = visible_height;
    // Replicate the last visible line into the remaining rows.
    do {
      vst1_s16(luma_ptr, vsub_s16(ssample, averages));
      luma_ptr += kCflLumaBufferStride;
    } while (++y < block_height);
  }
}
567 
568 template <int block_height_log2>
CflSubsampler444_4xH_NEON(int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],const int max_luma_width,const int max_luma_height,const void * LIBGAV1_RESTRICT const source,ptrdiff_t stride)569 void CflSubsampler444_4xH_NEON(
570     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
571     const int max_luma_width, const int max_luma_height,
572     const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
573   static_cast<void>(max_luma_width);
574   static_cast<void>(max_luma_height);
575   static_assert(block_height_log2 <= 4, "");
576   assert(max_luma_width >= 4);
577   assert(max_luma_height >= 4);
578   const int block_height = 1 << block_height_log2;
579 
580   if (block_height <= max_luma_height) {
581     CflSubsampler444_4xH_NEON<block_height_log2, true>(luma, max_luma_height,
582                                                        source, stride);
583   } else {
584     CflSubsampler444_4xH_NEON<block_height_log2, false>(luma, max_luma_height,
585                                                         source, stride);
586   }
587 }
588 
// High-bitdepth 4:4:4 subsampler for 8-wide blocks. When |is_inside| is false
// the bottom rows beyond |max_luma_height| replicate the last visible row.
template <int block_height_log2, bool is_inside>
void CflSubsampler444_8xH_NEON(
    int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
    const int max_luma_height, const void* LIBGAV1_RESTRICT const source,
    ptrdiff_t stride) {
  const int block_height = 1 << block_height_log2;
  const int visible_height = max_luma_height;
  const auto* src = static_cast<const uint16_t*>(source);
  const ptrdiff_t src_stride = stride / sizeof(src[0]);
  int16_t* luma_ptr = luma[0];
  uint32x4_t sum = vdupq_n_u32(0);
  uint16x8_t samples;
  int y = visible_height;

  // First pass: accumulate the sum over all visible rows.
  do {
    samples = vld1q_u16(src);
    src += src_stride;
    sum = vpadalq_u16(sum, samples);
  } while (--y != 0);

  if (!is_inside) {
    // Account for the padded rows by re-adding the last visible row.
    y = visible_height;
    do {
      sum = vpadalq_u16(sum, samples);
    } while (++y < block_height);
  }

  // The left shift by 3 (added precision) cancels against the right shift by
  // (log2 of width 8), leaving block_height_log2.
  const uint32_t average_sum =
      RightShiftWithRounding(SumVector(sum), block_height_log2);
  const int16x8_t averages = vdupq_n_s16(static_cast<int16_t>(average_sum));

  // Second pass: write (sample << 3) - average for every visible row.
  const auto* ssrc = static_cast<const int16_t*>(source);
  int16x8_t ssample;
  luma_ptr = luma[0];
  y = visible_height;
  do {
    ssample = vld1q_s16(ssrc);
    ssample = vshlq_n_s16(ssample, 3);
    vst1q_s16(luma_ptr, vsubq_s16(ssample, averages));
    ssrc += src_stride;
    luma_ptr += kCflLumaBufferStride;
  } while (--y != 0);

  if (!is_inside) {
    y = visible_height;
    // Replicate the last visible line into the remaining rows.
    do {
      vst1q_s16(luma_ptr, vsubq_s16(ssample, averages));
      luma_ptr += kCflLumaBufferStride;
    } while (++y < block_height);
  }
}
643 
644 template <int block_height_log2>
CflSubsampler444_8xH_NEON(int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],const int max_luma_width,const int max_luma_height,const void * LIBGAV1_RESTRICT const source,ptrdiff_t stride)645 void CflSubsampler444_8xH_NEON(
646     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
647     const int max_luma_width, const int max_luma_height,
648     const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
649   static_cast<void>(max_luma_width);
650   static_cast<void>(max_luma_height);
651   static_assert(block_height_log2 <= 5, "");
652   assert(max_luma_width >= 4);
653   assert(max_luma_height >= 4);
654   const int block_height = 1 << block_height_log2;
655   const int block_width = 8;
656 
657   const int horz_inside = block_width <= max_luma_width;
658   const int vert_inside = block_height <= max_luma_height;
659   if (horz_inside && vert_inside) {
660     CflSubsampler444_8xH_NEON<block_height_log2, true>(luma, max_luma_height,
661                                                        source, stride);
662   } else {
663     CflSubsampler444_8xH_NEON<block_height_log2, false>(luma, max_luma_height,
664                                                         source, stride);
665   }
666 }
667 
// High-bitdepth 4:4:4 subsampler for 16- and 32-wide blocks. Columns beyond
// |max_luma_width| replicate the last visible sample; when |is_inside| is
// false, rows beyond |max_luma_height| replicate the last visible row.
template <int block_width_log2, int block_height_log2, bool is_inside>
void CflSubsampler444_WxH_NEON(
    int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
    const int max_luma_width, const int max_luma_height,
    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
  const int block_height = 1 << block_height_log2;
  const int visible_height = max_luma_height;
  const int block_width = 1 << block_width_log2;
  const auto* src = static_cast<const uint16_t*>(source);
  const ptrdiff_t src_stride = stride / sizeof(src[0]);
  int16_t* luma_ptr = luma[0];
  uint32x4_t sum = vdupq_n_u32(0);
  uint16x8_t samples[4];
  int y = visible_height;

  // First pass: accumulate the row sums, padding columns past the visible
  // width with the last visible sample.
  do {
    samples[0] = vld1q_u16(src);
    samples[1] =
        (max_luma_width >= 16) ? vld1q_u16(src + 8) : LastRowResult(samples[0]);
    uint16x8_t inner_sum = vaddq_u16(samples[0], samples[1]);
    if (block_width == 32) {
      samples[2] = (max_luma_width >= 24) ? vld1q_u16(src + 16)
                                          : LastRowResult(samples[1]);
      samples[3] = (max_luma_width == 32) ? vld1q_u16(src + 24)
                                          : LastRowResult(samples[2]);
      inner_sum = vaddq_u16(samples[2], inner_sum);
      inner_sum = vaddq_u16(samples[3], inner_sum);
    }
    sum = vpadalq_u16(sum, inner_sum);
    src += src_stride;
  } while (--y != 0);

  if (!is_inside) {
    // Account for the padded bottom rows by re-adding the last visible row.
    y = visible_height;
    uint16x8_t inner_sum = vaddq_u16(samples[0], samples[1]);
    if (block_width == 32) {
      inner_sum = vaddq_u16(samples[2], inner_sum);
      inner_sum = vaddq_u16(samples[3], inner_sum);
    }
    do {
      sum = vpadalq_u16(sum, inner_sum);
    } while (++y < block_height);
  }

  // The left shift by 3 (added precision) is folded into the right shift
  // factor (block_width_log2 + block_height_log2 - 3).
  const uint32_t average_sum = RightShiftWithRounding(
      SumVector(sum), block_width_log2 + block_height_log2 - 3);
  const int16x8_t averages = vdupq_n_s16(static_cast<int16_t>(average_sum));

  // Second pass: write (sample << 3) - average, replicating the last visible
  // column into out-of-range vectors.
  const auto* ssrc = static_cast<const int16_t*>(source);
  int16x8_t ssamples_ext = vdupq_n_s16(0);
  int16x8_t ssamples[4];
  luma_ptr = luma[0];
  y = visible_height;
  do {
    int idx = 0;
    for (int x = 0; x < block_width; x += 8) {
      if (max_luma_width > x) {
        ssamples[idx] = vld1q_s16(&ssrc[x]);
        ssamples[idx] = vshlq_n_s16(ssamples[idx], 3);
        // Remember the last in-range vector for column replication.
        ssamples_ext = ssamples[idx];
      } else {
        ssamples[idx] = LastRowResult(ssamples_ext);
      }
      vst1q_s16(&luma_ptr[x], vsubq_s16(ssamples[idx++], averages));
    }
    ssrc += src_stride;
    luma_ptr += kCflLumaBufferStride;
  } while (--y != 0);

  if (!is_inside) {
    y = visible_height;
    // Replicate the last visible line into the remaining rows.
    do {
      int idx = 0;
      for (int x = 0; x < block_width; x += 8) {
        vst1q_s16(&luma_ptr[x], vsubq_s16(ssamples[idx++], averages));
      }
      luma_ptr += kCflLumaBufferStride;
    } while (++y < block_height);
  }
}
751 
752 template <int block_width_log2, int block_height_log2>
CflSubsampler444_WxH_NEON(int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],const int max_luma_width,const int max_luma_height,const void * LIBGAV1_RESTRICT const source,ptrdiff_t stride)753 void CflSubsampler444_WxH_NEON(
754     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
755     const int max_luma_width, const int max_luma_height,
756     const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
757   static_assert(block_width_log2 == 4 || block_width_log2 == 5,
758                 "This function will only work for block_width 16 and 32.");
759   static_assert(block_height_log2 <= 5, "");
760   assert(max_luma_width >= 4);
761   assert(max_luma_height >= 4);
762 
763   const int block_height = 1 << block_height_log2;
764   const int vert_inside = block_height <= max_luma_height;
765   if (vert_inside) {
766     CflSubsampler444_WxH_NEON<block_width_log2, block_height_log2, true>(
767         luma, max_luma_width, max_luma_height, source, stride);
768   } else {
769     CflSubsampler444_WxH_NEON<block_width_log2, block_height_log2, false>(
770         luma, max_luma_width, max_luma_height, source, stride);
771   }
772 }
773 
// 420 subsampler for 4-wide chroma blocks. Each loop iteration consumes
// eight source luma rows and produces four output rows: row pairs are summed
// vertically, then StoreLumaResults4_420 (defined earlier in this file)
// combines two such sums into the stored 2x2-block results.
template <int block_height_log2>
void CflSubsampler420_4xH_NEON(
    int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
    const int /*max_luma_width*/, const int max_luma_height,
    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
  const int block_height = 1 << block_height_log2;
  const auto* src = static_cast<const uint16_t*>(source);
  const ptrdiff_t src_stride = stride / sizeof(src[0]);
  int16_t* luma_ptr = luma[0];
  // Output rows backed by visible luma (two source rows per output row).
  // NOTE(review): the loop below steps |y| by 4, so this relies on
  // |luma_height| being a multiple of 4 — presumably guaranteed by caller
  // block-size constraints; verify if this function is reused elsewhere.
  const int luma_height = std::min(block_height, max_luma_height >> 1);
  int y = luma_height;

  uint32x4_t final_sum = vdupq_n_u32(0);
  do {
    const uint16x8_t samples_row0 = vld1q_u16(src);
    src += src_stride;
    const uint16x8_t samples_row1 = vld1q_u16(src);
    src += src_stride;
    const uint16x8_t luma_sum01 = vaddq_u16(samples_row0, samples_row1);

    const uint16x8_t samples_row2 = vld1q_u16(src);
    src += src_stride;
    const uint16x8_t samples_row3 = vld1q_u16(src);
    src += src_stride;
    const uint16x8_t luma_sum23 = vaddq_u16(samples_row2, samples_row3);
    // Stores two output rows; the returned vector is kept for the average.
    uint16x8_t sum = StoreLumaResults4_420(luma_sum01, luma_sum23, luma_ptr);
    luma_ptr += kCflLumaBufferStride << 1;

    const uint16x8_t samples_row4 = vld1q_u16(src);
    src += src_stride;
    const uint16x8_t samples_row5 = vld1q_u16(src);
    src += src_stride;
    const uint16x8_t luma_sum45 = vaddq_u16(samples_row4, samples_row5);

    const uint16x8_t samples_row6 = vld1q_u16(src);
    src += src_stride;
    const uint16x8_t samples_row7 = vld1q_u16(src);
    src += src_stride;
    const uint16x8_t luma_sum67 = vaddq_u16(samples_row6, samples_row7);
    sum =
        vaddq_u16(sum, StoreLumaResults4_420(luma_sum45, luma_sum67, luma_ptr));
    luma_ptr += kCflLumaBufferStride << 1;

    final_sum = vpadalq_u16(final_sum, sum);
    y -= 4;
  } while (y != 0);

  // Duplicate the last written output row downward to fill the block, and
  // include each duplicated row in the running sum.
  const uint16x4_t final_fill =
      vreinterpret_u16_s16(vld1_s16(luma_ptr - kCflLumaBufferStride));
  const uint32x4_t final_fill_to_sum = vmovl_u16(final_fill);
  for (y = luma_height; y < block_height; ++y) {
    vst1_s16(luma_ptr, vreinterpret_s16_u16(final_fill));
    luma_ptr += kCflLumaBufferStride;
    final_sum = vaddq_u32(final_sum, final_fill_to_sum);
  }
  // Average over the block_height * 4 entries, then subtract it from every
  // stored value so the buffer holds the zero-mean "AC" luma contribution.
  const uint32_t average_sum = RightShiftWithRounding(
      SumVector(final_sum), block_height_log2 + 2 /*log2 of width 4*/);
  const int16x4_t averages = vdup_n_s16(static_cast<int16_t>(average_sum));
  luma_ptr = luma[0];
  y = block_height;
  do {
    const int16x4_t samples = vld1_s16(luma_ptr);
    vst1_s16(luma_ptr, vsub_s16(samples, averages));
    luma_ptr += kCflLumaBufferStride;
  } while (--y != 0);
}
840 
// 420 subsampler for 8-wide chroma blocks. |max_luma_width| is a template
// parameter (8 or 16): when it is 8, the right half of each 16-wide luma row
// is synthesized from the left half with LastRowSamples (sample
// replication). Each loop iteration consumes eight source rows and produces
// four output rows via StoreLumaResults8_420.
template <int block_height_log2, int max_luma_width>
inline void CflSubsampler420Impl_8xH_NEON(
    int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
    const int max_luma_height, const void* LIBGAV1_RESTRICT const source,
    ptrdiff_t stride) {
  const int block_height = 1 << block_height_log2;
  const auto* src = static_cast<const uint16_t*>(source);
  const ptrdiff_t src_stride = stride / sizeof(src[0]);
  int16_t* luma_ptr = luma[0];
  // Output rows backed by visible luma (two source rows per output row).
  const int luma_height = std::min(block_height, max_luma_height >> 1);
  int y = luma_height;

  uint32x4_t final_sum = vdupq_n_u32(0);
  do {
    const uint16x8_t samples_row00 = vld1q_u16(src);
    const uint16x8_t samples_row01 = (max_luma_width == 16)
                                         ? vld1q_u16(src + 8)
                                         : LastRowSamples(samples_row00);
    src += src_stride;
    const uint16x8_t samples_row10 = vld1q_u16(src);
    const uint16x8_t samples_row11 = (max_luma_width == 16)
                                         ? vld1q_u16(src + 8)
                                         : LastRowSamples(samples_row10);
    src += src_stride;
    const uint16x8_t luma_sum00 = vaddq_u16(samples_row00, samples_row10);
    const uint16x8_t luma_sum01 = vaddq_u16(samples_row01, samples_row11);
    // Stores one output row; the returned vector feeds the running sum.
    uint16x8_t sum = StoreLumaResults8_420(luma_sum00, luma_sum01, luma_ptr);
    luma_ptr += kCflLumaBufferStride;

    const uint16x8_t samples_row20 = vld1q_u16(src);
    const uint16x8_t samples_row21 = (max_luma_width == 16)
                                         ? vld1q_u16(src + 8)
                                         : LastRowSamples(samples_row20);
    src += src_stride;
    const uint16x8_t samples_row30 = vld1q_u16(src);
    const uint16x8_t samples_row31 = (max_luma_width == 16)
                                         ? vld1q_u16(src + 8)
                                         : LastRowSamples(samples_row30);
    src += src_stride;
    const uint16x8_t luma_sum10 = vaddq_u16(samples_row20, samples_row30);
    const uint16x8_t luma_sum11 = vaddq_u16(samples_row21, samples_row31);
    sum =
        vaddq_u16(sum, StoreLumaResults8_420(luma_sum10, luma_sum11, luma_ptr));
    luma_ptr += kCflLumaBufferStride;

    const uint16x8_t samples_row40 = vld1q_u16(src);
    const uint16x8_t samples_row41 = (max_luma_width == 16)
                                         ? vld1q_u16(src + 8)
                                         : LastRowSamples(samples_row40);
    src += src_stride;
    const uint16x8_t samples_row50 = vld1q_u16(src);
    const uint16x8_t samples_row51 = (max_luma_width == 16)
                                         ? vld1q_u16(src + 8)
                                         : LastRowSamples(samples_row50);
    src += src_stride;
    const uint16x8_t luma_sum20 = vaddq_u16(samples_row40, samples_row50);
    const uint16x8_t luma_sum21 = vaddq_u16(samples_row41, samples_row51);
    sum =
        vaddq_u16(sum, StoreLumaResults8_420(luma_sum20, luma_sum21, luma_ptr));
    luma_ptr += kCflLumaBufferStride;

    const uint16x8_t samples_row60 = vld1q_u16(src);
    const uint16x8_t samples_row61 = (max_luma_width == 16)
                                         ? vld1q_u16(src + 8)
                                         : LastRowSamples(samples_row60);
    src += src_stride;
    const uint16x8_t samples_row70 = vld1q_u16(src);
    const uint16x8_t samples_row71 = (max_luma_width == 16)
                                         ? vld1q_u16(src + 8)
                                         : LastRowSamples(samples_row70);
    src += src_stride;
    const uint16x8_t luma_sum30 = vaddq_u16(samples_row60, samples_row70);
    const uint16x8_t luma_sum31 = vaddq_u16(samples_row61, samples_row71);
    sum =
        vaddq_u16(sum, StoreLumaResults8_420(luma_sum30, luma_sum31, luma_ptr));
    luma_ptr += kCflLumaBufferStride;

    final_sum = vpadalq_u16(final_sum, sum);
    y -= 4;
  } while (y != 0);

  // Duplicate the final row downward to the end after max_luma_height.
  const uint16x8_t final_fill =
      vreinterpretq_u16_s16(vld1q_s16(luma_ptr - kCflLumaBufferStride));
  // Widen and fold the fill row once so it can be re-added per filled row.
  const uint32x4_t final_fill_to_sum =
      vaddl_u16(vget_low_u16(final_fill), vget_high_u16(final_fill));

  for (y = luma_height; y < block_height; ++y) {
    vst1q_s16(luma_ptr, vreinterpretq_s16_u16(final_fill));
    luma_ptr += kCflLumaBufferStride;
    final_sum = vaddq_u32(final_sum, final_fill_to_sum);
  }

  // Average over block_height * 8 entries, then subtract it from every
  // stored value so the buffer holds the zero-mean "AC" luma contribution.
  const uint32_t average_sum = RightShiftWithRounding(
      SumVector(final_sum), block_height_log2 + 3 /*log2 of width 8*/);
  const int16x8_t averages = vdupq_n_s16(static_cast<int16_t>(average_sum));
  luma_ptr = luma[0];
  y = block_height;
  do {
    const int16x8_t samples = vld1q_s16(luma_ptr);
    vst1q_s16(luma_ptr, vsubq_s16(samples, averages));
    luma_ptr += kCflLumaBufferStride;
  } while (--y != 0);
}
945 
946 template <int block_height_log2>
CflSubsampler420_8xH_NEON(int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],const int max_luma_width,const int max_luma_height,const void * LIBGAV1_RESTRICT const source,ptrdiff_t stride)947 void CflSubsampler420_8xH_NEON(
948     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
949     const int max_luma_width, const int max_luma_height,
950     const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
951   if (max_luma_width == 8) {
952     CflSubsampler420Impl_8xH_NEON<block_height_log2, 8>(luma, max_luma_height,
953                                                         source, stride);
954   } else {
955     CflSubsampler420Impl_8xH_NEON<block_height_log2, 16>(luma, max_luma_height,
956                                                          source, stride);
957   }
958 }
959 
// 420 subsampler for 16- and 32-wide chroma blocks. |max_luma_width| is a
// compile-time parameter (8/16/24/32); luma vectors past it are synthesized
// with LastRowSamples. Each loop iteration consumes two source rows and
// produces one 16-wide output row; for 32-wide blocks the right half is
// filled by replication (LastRowResult) in the final subtraction pass.
template <int block_width_log2, int block_height_log2, int max_luma_width>
inline void CflSubsampler420Impl_WxH_NEON(
    int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
    const int max_luma_height, const void* LIBGAV1_RESTRICT const source,
    ptrdiff_t stride) {
  const auto* src = static_cast<const uint16_t*>(source);
  const ptrdiff_t src_stride = stride / sizeof(src[0]);
  const int block_height = 1 << block_height_log2;
  // Output rows backed by visible luma (two source rows per output row).
  const int luma_height = std::min(block_height, max_luma_height >> 1);
  int16_t* luma_ptr = luma[0];
  // Begin first y section, covering width up to 32.
  int y = luma_height;

  uint16x8_t final_fill0, final_fill1;
  uint32x4_t final_sum = vdupq_n_u32(0);
  do {
    const uint16_t* src_next = src + src_stride;
    const uint16x8_t samples_row00 = vld1q_u16(src);
    const uint16x8_t samples_row01 = (max_luma_width >= 16)
                                         ? vld1q_u16(src + 8)
                                         : LastRowSamples(samples_row00);
    const uint16x8_t samples_row02 = (max_luma_width >= 24)
                                         ? vld1q_u16(src + 16)
                                         : LastRowSamples(samples_row01);
    const uint16x8_t samples_row03 = (max_luma_width == 32)
                                         ? vld1q_u16(src + 24)
                                         : LastRowSamples(samples_row02);
    const uint16x8_t samples_row10 = vld1q_u16(src_next);
    const uint16x8_t samples_row11 = (max_luma_width >= 16)
                                         ? vld1q_u16(src_next + 8)
                                         : LastRowSamples(samples_row10);
    const uint16x8_t samples_row12 = (max_luma_width >= 24)
                                         ? vld1q_u16(src_next + 16)
                                         : LastRowSamples(samples_row11);
    const uint16x8_t samples_row13 = (max_luma_width == 32)
                                         ? vld1q_u16(src_next + 24)
                                         : LastRowSamples(samples_row12);
    const uint16x8_t luma_sum0 = vaddq_u16(samples_row00, samples_row10);
    const uint16x8_t luma_sum1 = vaddq_u16(samples_row01, samples_row11);
    const uint16x8_t luma_sum2 = vaddq_u16(samples_row02, samples_row12);
    const uint16x8_t luma_sum3 = vaddq_u16(samples_row03, samples_row13);
    // |final_fill0/1| keep the last stored row for the bottom-fill section.
    final_fill0 = StoreLumaResults8_420(luma_sum0, luma_sum1, luma_ptr);
    final_fill1 = StoreLumaResults8_420(luma_sum2, luma_sum3, luma_ptr + 8);
    const uint16x8_t sum = vaddq_u16(final_fill0, final_fill1);

    final_sum = vpadalq_u16(final_sum, sum);

    // Because max_luma_width is at most 32, any values beyond x=16 will
    // necessarily be duplicated.
    if (block_width_log2 == 5) {
      const uint16x8_t wide_fill = LastRowResult(final_fill1);
      // The duplicated half contributes twice (columns 16..31 of the sum).
      final_sum = vpadalq_u16(final_sum, vshlq_n_u16(wide_fill, 1));
    }
    src += src_stride << 1;
    luma_ptr += kCflLumaBufferStride;
  } while (--y != 0);

  // Begin second y section.
  y = luma_height;
  if (y < block_height) {
    uint32x4_t wide_fill;
    if (block_width_log2 == 5) {
      // There are 16 16-bit fill values per row, shifting by 2 accounts for
      // the widening to 32-bit.  (a << 2) = (a + a) << 1.
      wide_fill = vshll_n_u16(vget_low_u16(LastRowResult(final_fill1)), 2);
    }
    const uint16x8_t final_inner_sum = vaddq_u16(final_fill0, final_fill1);
    const uint32x4_t final_fill_to_sum = vaddl_u16(
        vget_low_u16(final_inner_sum), vget_high_u16(final_inner_sum));

    // Replicate the last visible output row downward, accumulating its
    // contribution to the sum for each filled row.
    do {
      vst1q_s16(luma_ptr, vreinterpretq_s16_u16(final_fill0));
      vst1q_s16(luma_ptr + 8, vreinterpretq_s16_u16(final_fill1));
      if (block_width_log2 == 5) {
        final_sum = vaddq_u32(final_sum, wide_fill);
      }
      luma_ptr += kCflLumaBufferStride;
      final_sum = vaddq_u32(final_sum, final_fill_to_sum);
    } while (++y < block_height);
  }  // End second y section.

  const uint32_t average_sum = RightShiftWithRounding(
      SumVector(final_sum), block_width_log2 + block_height_log2);
  const int16x8_t averages = vdupq_n_s16(static_cast<int16_t>(average_sum));

  // Final pass: subtract the average; for 32-wide blocks replicate the
  // mean-subtracted right half across columns 16..31.
  luma_ptr = luma[0];
  y = block_height;
  do {
    const int16x8_t samples0 = vld1q_s16(luma_ptr);
    vst1q_s16(luma_ptr, vsubq_s16(samples0, averages));
    const int16x8_t samples1 = vld1q_s16(luma_ptr + 8);
    const int16x8_t final_row_result = vsubq_s16(samples1, averages);
    vst1q_s16(luma_ptr + 8, final_row_result);

    if (block_width_log2 == 5) {
      const int16x8_t wide_fill = LastRowResult(final_row_result);
      vst1q_s16(luma_ptr + 16, wide_fill);
      vst1q_s16(luma_ptr + 24, wide_fill);
    }
    luma_ptr += kCflLumaBufferStride;
  } while (--y != 0);
}
1062 
1063 //------------------------------------------------------------------------------
1064 // Choose subsampler based on max_luma_width
1065 template <int block_width_log2, int block_height_log2>
CflSubsampler420_WxH_NEON(int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],const int max_luma_width,const int max_luma_height,const void * LIBGAV1_RESTRICT const source,ptrdiff_t stride)1066 void CflSubsampler420_WxH_NEON(
1067     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
1068     const int max_luma_width, const int max_luma_height,
1069     const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
1070   switch (max_luma_width) {
1071     case 8:
1072       CflSubsampler420Impl_WxH_NEON<block_width_log2, block_height_log2, 8>(
1073           luma, max_luma_height, source, stride);
1074       return;
1075     case 16:
1076       CflSubsampler420Impl_WxH_NEON<block_width_log2, block_height_log2, 16>(
1077           luma, max_luma_height, source, stride);
1078       return;
1079     case 24:
1080       CflSubsampler420Impl_WxH_NEON<block_width_log2, block_height_log2, 24>(
1081           luma, max_luma_height, source, stride);
1082       return;
1083     default:
1084       assert(max_luma_width == 32);
1085       CflSubsampler420Impl_WxH_NEON<block_width_log2, block_height_log2, 32>(
1086           luma, max_luma_height, source, stride);
1087       return;
1088   }
1089 }
1090 
1091 //------------------------------------------------------------------------------
1092 // CflIntraPredictor
1093 
// |luma| can be within +/-(((1 << bitdepth) - 1) << 3), inclusive.
// |alpha| can be -16 to 16 (inclusive).
// Clip |dc + ((alpha * luma) >> 6))| to 0, (1 << bitdepth) - 1.
inline uint16x8_t Combine8(const int16x8_t luma, const int16x8_t alpha_abs,
                           const int16x8_t alpha_signed, const int16x8_t dc,
                           const uint16x8_t max_value) {
  // The sign of alpha * luma is the xor of the operand signs; an arithmetic
  // shift of the xor by 15 yields 0 or -1 in each lane.
  const int16x8_t product_sign =
      vshrq_n_s16(veorq_s16(luma, alpha_signed), 15);
  // Multiply magnitudes. vqrdmulh computes (a * b * 2 + (1 << 15)) >> 16;
  // with |alpha_abs| pre-scaled by the caller (alpha << 9) this realizes the
  // rounded (alpha * luma) >> 6.
  const int16x8_t scaled_abs = vqrdmulhq_s16(vabsq_s16(luma), alpha_abs);
  // Re-apply the sign: (x ^ s) - s negates x exactly when s == -1.
  const int16x8_t scaled =
      vsubq_s16(veorq_s16(scaled_abs, product_sign), product_sign);
  const int16x8_t prediction = vaddq_s16(scaled, dc);
  // Clamp the result to [0, max_value].
  return vminq_u16(vreinterpretq_u16_s16(vmaxq_s16(prediction, vdupq_n_s16(0))),
                   max_value);
}
1113 
// CfL predictor for 4-wide blocks: processes two rows per iteration by
// packing them into a single 8-lane vector.
template <int block_height, int bitdepth = 10>
inline void CflIntraPredictor4xN_NEON(
    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
    const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
    const int alpha) {
  auto* dst = static_cast<uint16_t*>(dest);
  const ptrdiff_t dst_stride = stride >> 1;
  const uint16x8_t max_value = vdupq_n_u16((1 << bitdepth) - 1);
  // Pre-scale |alpha| so Combine8's vqrdmulh realizes (alpha * luma) >> 6.
  const int16x8_t alpha_signed = vdupq_n_s16(alpha << 9);
  const int16x8_t alpha_abs = vabsq_s16(alpha_signed);
  // The DC value is broadcast from dst[0].
  const int16x8_t dc = vdupq_n_s16(dst[0]);
  int y = 0;
  do {
    const int16x8_t luma_pair =
        vcombine_s16(vld1_s16(luma[y]), vld1_s16(luma[y + 1]));
    const uint16x8_t pred =
        Combine8(luma_pair, alpha_abs, alpha_signed, dc, max_value);
    vst1_u16(dst, vget_low_u16(pred));
    dst += dst_stride;
    vst1_u16(dst, vget_high_u16(pred));
    dst += dst_stride;
    y += 2;
  } while (y < block_height);
}
1137 
// CfL predictor for 8-wide blocks: one 8-lane vector per row.
template <int block_height, int bitdepth = 10>
inline void CflIntraPredictor8xN_NEON(
    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
    const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
    const int alpha) {
  auto* dst = static_cast<uint16_t*>(dest);
  const ptrdiff_t dst_stride = stride >> 1;
  const uint16x8_t max_value = vdupq_n_u16((1 << bitdepth) - 1);
  // Pre-scale |alpha| so Combine8's vqrdmulh realizes (alpha * luma) >> 6.
  const int16x8_t alpha_signed = vdupq_n_s16(alpha << 9);
  const int16x8_t alpha_abs = vabsq_s16(alpha_signed);
  // The DC value is broadcast from dst[0].
  const int16x8_t dc = vdupq_n_s16(dst[0]);
  const int16_t* luma_row = luma[0];
  int y = block_height;
  do {
    const uint16x8_t pred =
        Combine8(vld1q_s16(luma_row), alpha_abs, alpha_signed, dc, max_value);
    vst1q_u16(dst, pred);
    dst += dst_stride;
    luma_row += kCflLumaBufferStride;
  } while (--y != 0);
}
1157 
// CfL predictor for 16-wide blocks: two 8-lane vectors per row.
template <int block_height, int bitdepth = 10>
inline void CflIntraPredictor16xN_NEON(
    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
    const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
    const int alpha) {
  auto* dst = static_cast<uint16_t*>(dest);
  const ptrdiff_t dst_stride = stride >> 1;
  const uint16x8_t max_value = vdupq_n_u16((1 << bitdepth) - 1);
  // Pre-scale |alpha| so Combine8's vqrdmulh realizes (alpha * luma) >> 6.
  const int16x8_t alpha_signed = vdupq_n_s16(alpha << 9);
  const int16x8_t alpha_abs = vabsq_s16(alpha_signed);
  // The DC value is broadcast from dst[0].
  const int16x8_t dc = vdupq_n_s16(dst[0]);
  int y = 0;
  do {
    for (int x = 0; x < 16; x += 8) {
      const uint16x8_t pred = Combine8(vld1q_s16(luma[y] + x), alpha_abs,
                                       alpha_signed, dc, max_value);
      vst1q_u16(dst + x, pred);
    }
    dst += dst_stride;
  } while (++y < block_height);
}
1181 
// CfL predictor for 32-wide blocks: four 8-lane vectors per row.
template <int block_height, int bitdepth = 10>
inline void CflIntraPredictor32xN_NEON(
    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
    const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
    const int alpha) {
  auto* dst = static_cast<uint16_t*>(dest);
  const ptrdiff_t dst_stride = stride >> 1;
  const uint16x8_t max_value = vdupq_n_u16((1 << bitdepth) - 1);
  // Pre-scale |alpha| so Combine8's vqrdmulh realizes (alpha * luma) >> 6.
  const int16x8_t alpha_signed = vdupq_n_s16(alpha << 9);
  const int16x8_t alpha_abs = vabsq_s16(alpha_signed);
  // The DC value is broadcast from dst[0].
  const int16x8_t dc = vdupq_n_s16(dst[0]);
  int y = 0;
  do {
    for (int x = 0; x < 32; x += 8) {
      const uint16x8_t pred = Combine8(vld1q_s16(luma[y] + x), alpha_abs,
                                       alpha_signed, dc, max_value);
      vst1q_u16(dst + x, pred);
    }
    dst += dst_stride;
  } while (++y < block_height);
}
1213 
// Registers the 10bpp NEON CfL subsamplers and intra predictors in the dsp
// function table. Template arguments are log2 of the block dimensions for
// subsamplers, and the literal block height for predictors.
void Init10bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
  assert(dsp != nullptr);

  // 4:2:0 subsamplers, grouped by block width.
  dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType420] =
      CflSubsampler420_4xH_NEON<2>;
  dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType420] =
      CflSubsampler420_4xH_NEON<3>;
  dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType420] =
      CflSubsampler420_4xH_NEON<4>;

  dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType420] =
      CflSubsampler420_8xH_NEON<2>;
  dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType420] =
      CflSubsampler420_8xH_NEON<3>;
  dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType420] =
      CflSubsampler420_8xH_NEON<4>;
  dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType420] =
      CflSubsampler420_8xH_NEON<5>;

  dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType420] =
      CflSubsampler420_WxH_NEON<4, 2>;
  dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType420] =
      CflSubsampler420_WxH_NEON<4, 3>;
  dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType420] =
      CflSubsampler420_WxH_NEON<4, 4>;
  dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType420] =
      CflSubsampler420_WxH_NEON<4, 5>;

  dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType420] =
      CflSubsampler420_WxH_NEON<5, 3>;
  dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType420] =
      CflSubsampler420_WxH_NEON<5, 4>;
  dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType420] =
      CflSubsampler420_WxH_NEON<5, 5>;

  // 4:4:4 subsamplers, grouped by block width.
  dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType444] =
      CflSubsampler444_4xH_NEON<2>;
  dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType444] =
      CflSubsampler444_4xH_NEON<3>;
  dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType444] =
      CflSubsampler444_4xH_NEON<4>;

  dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType444] =
      CflSubsampler444_8xH_NEON<2>;
  dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType444] =
      CflSubsampler444_8xH_NEON<3>;
  dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType444] =
      CflSubsampler444_8xH_NEON<4>;
  dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType444] =
      CflSubsampler444_8xH_NEON<5>;

  dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType444] =
      CflSubsampler444_WxH_NEON<4, 2>;
  dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType444] =
      CflSubsampler444_WxH_NEON<4, 3>;
  dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType444] =
      CflSubsampler444_WxH_NEON<4, 4>;
  dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType444] =
      CflSubsampler444_WxH_NEON<4, 5>;

  dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType444] =
      CflSubsampler444_WxH_NEON<5, 3>;
  dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType444] =
      CflSubsampler444_WxH_NEON<5, 4>;
  dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType444] =
      CflSubsampler444_WxH_NEON<5, 5>;

  // CfL intra predictors, grouped by block width.
  dsp->cfl_intra_predictors[kTransformSize4x4] = CflIntraPredictor4xN_NEON<4>;
  dsp->cfl_intra_predictors[kTransformSize4x8] = CflIntraPredictor4xN_NEON<8>;
  dsp->cfl_intra_predictors[kTransformSize4x16] = CflIntraPredictor4xN_NEON<16>;

  dsp->cfl_intra_predictors[kTransformSize8x4] = CflIntraPredictor8xN_NEON<4>;
  dsp->cfl_intra_predictors[kTransformSize8x8] = CflIntraPredictor8xN_NEON<8>;
  dsp->cfl_intra_predictors[kTransformSize8x16] = CflIntraPredictor8xN_NEON<16>;
  dsp->cfl_intra_predictors[kTransformSize8x32] = CflIntraPredictor8xN_NEON<32>;

  dsp->cfl_intra_predictors[kTransformSize16x4] = CflIntraPredictor16xN_NEON<4>;
  dsp->cfl_intra_predictors[kTransformSize16x8] = CflIntraPredictor16xN_NEON<8>;
  dsp->cfl_intra_predictors[kTransformSize16x16] =
      CflIntraPredictor16xN_NEON<16>;
  dsp->cfl_intra_predictors[kTransformSize16x32] =
      CflIntraPredictor16xN_NEON<32>;
  dsp->cfl_intra_predictors[kTransformSize32x8] = CflIntraPredictor32xN_NEON<8>;
  dsp->cfl_intra_predictors[kTransformSize32x16] =
      CflIntraPredictor32xN_NEON<16>;
  dsp->cfl_intra_predictors[kTransformSize32x32] =
      CflIntraPredictor32xN_NEON<32>;
  // Max Cfl predictor size is 32x32.
}
1304 
1305 }  // namespace
1306 }  // namespace high_bitdepth
1307 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
1308 
// Public entry point: installs the NEON CfL implementations for all
// supported bitdepths into the dsp tables.
void IntraPredCflInit_NEON() {
  low_bitdepth::Init8bpp();
#if LIBGAV1_MAX_BITDEPTH >= 10
  high_bitdepth::Init10bpp();
#endif
}
1315 
1316 }  // namespace dsp
1317 }  // namespace libgav1
1318 
1319 #else   // !LIBGAV1_ENABLE_NEON
1320 namespace libgav1 {
1321 namespace dsp {
1322 
IntraPredCflInit_NEON()1323 void IntraPredCflInit_NEON() {}
1324 
1325 }  // namespace dsp
1326 }  // namespace libgav1
1327 #endif  // LIBGAV1_ENABLE_NEON
1328