// Copyright 2019 The libgav1 Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/dsp/intra_edge.h"
#include "src/utils/cpu.h"

#if LIBGAV1_ENABLE_NEON

#include <arm_neon.h>

#include <algorithm>
#include <cassert>

#include "src/dsp/arm/common_neon.h"
#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/utils/common.h"

namespace libgav1 {
namespace dsp {
namespace {

// Simplified version of intra_edge.cc:kKernels[][]. Only |strength| 1 and 2 are
// required.
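// The full 5 tap kernels are symmetric ({0, 4, 8, 4, 0}, {0, 5, 6, 5, 0} and
// {2, 4, 4, 4, 2}), so only the two distinct nonzero taps of each 3 tap kernel
// are stored; the row for |strength| 3 is never indexed because that case
// takes the shift-based 5 tap path below. For |strength| 1, for example, each
// output is (4 * in[i - 1] + 8 * in[i] + 4 * in[i + 1] + 8) >> 4.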
constexpr int kKernelsNEON[3][2] = {{4, 8}, {5, 6}};

}  // namespace

namespace low_bitdepth {
namespace {

void IntraEdgeFilter_NEON(void* buffer, const int size, const int strength) {
  assert(strength == 1 || strength == 2 || strength == 3);
  const int kernel_index = strength - 1;
  auto* const dst_buffer = static_cast<uint8_t*>(buffer);

  // The first element is not written out (but it is input) so the number of
  // elements written is |size| - 1.
  if (size == 1) return;

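  // |v_index| holds the lane indices {0, 1, ..., 15}; comparing it against the
  // remainder count below yields a blend mask for the final partial store.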
  const uint8x16_t v_index = vcombine_u8(vcreate_u8(0x0706050403020100),
                                         vcreate_u8(0x0f0e0d0c0b0a0908));
  // |strength| 1 and 2 use a 3 tap filter.
  if (strength < 3) {
    // The last value requires extending the buffer (duplicating
    // |dst_buffer[size - 1]). Calculate it here to avoid extra processing in
    // neon.
    const uint8_t last_val = RightShiftWithRounding(
        kKernelsNEON[kernel_index][0] * dst_buffer[size - 2] +
            kKernelsNEON[kernel_index][1] * dst_buffer[size - 1] +
            kKernelsNEON[kernel_index][0] * dst_buffer[size - 1],
        4);

    const uint8x8_t krn1 = vdup_n_u8(kKernelsNEON[kernel_index][1]);

    // The first value we need gets overwritten by the output from the
    // previous iteration.
    uint8x16_t src_0 = vld1q_u8(dst_buffer);
    int i = 1;

    // Process blocks until there are less than 16 values remaining.
    for (; i < size - 15; i += 16) {
      // Loading these at the end of the block with |src_0| will read past the
      // end of |top_row_data[160]|, the source of |buffer|.
      const uint8x16_t src_1 = vld1q_u8(dst_buffer + i);
      const uint8x16_t src_2 = vld1q_u8(dst_buffer + i + 1);
      uint16x8_t sum_lo = vaddl_u8(vget_low_u8(src_0), vget_low_u8(src_2));
      sum_lo = vmulq_n_u16(sum_lo, kKernelsNEON[kernel_index][0]);
      sum_lo = vmlal_u8(sum_lo, vget_low_u8(src_1), krn1);
      uint16x8_t sum_hi = vaddl_u8(vget_high_u8(src_0), vget_high_u8(src_2));
      sum_hi = vmulq_n_u16(sum_hi, kKernelsNEON[kernel_index][0]);
      sum_hi = vmlal_u8(sum_hi, vget_high_u8(src_1), krn1);

      const uint8x16_t result =
          vcombine_u8(vrshrn_n_u16(sum_lo, 4), vrshrn_n_u16(sum_hi, 4));

      // Load the next row before overwriting. This loads an extra 15 values
      // past |size| on the trailing iteration.
      src_0 = vld1q_u8(dst_buffer + i + 15);

      vst1q_u8(dst_buffer + i, result);
    }

    // The last output value |last_val| was already calculated so if
    // |remainder| == 1 then we don't have to do anything.
    const int remainder = (size - 1) & 0xf;
    if (remainder > 1) {
      const uint8x16_t src_1 = vld1q_u8(dst_buffer + i);
      const uint8x16_t src_2 = vld1q_u8(dst_buffer + i + 1);

      uint16x8_t sum_lo = vaddl_u8(vget_low_u8(src_0), vget_low_u8(src_2));
      sum_lo = vmulq_n_u16(sum_lo, kKernelsNEON[kernel_index][0]);
      sum_lo = vmlal_u8(sum_lo, vget_low_u8(src_1), krn1);
      uint16x8_t sum_hi = vaddl_u8(vget_high_u8(src_0), vget_high_u8(src_2));
      sum_hi = vmulq_n_u16(sum_hi, kKernelsNEON[kernel_index][0]);
      sum_hi = vmlal_u8(sum_hi, vget_high_u8(src_1), krn1);

      const uint8x16_t result =
          vcombine_u8(vrshrn_n_u16(sum_lo, 4), vrshrn_n_u16(sum_hi, 4));
      const uint8x16_t v_remainder = vdupq_n_u8(remainder);
      // Create the overwrite mask: lanes at index >= |remainder| keep the
      // original |src_1| values.
      const uint8x16_t mask = vcleq_u8(v_remainder, v_index);
      const uint8x16_t dst_remainder = vbslq_u8(mask, src_1, result);
      vst1q_u8(dst_buffer + i, dst_remainder);
    }

    dst_buffer[size - 1] = last_val;
    return;
  }

  assert(strength == 3);
  // 5 tap filter. The first element requires duplicating |buffer[0]| and the
  // last two elements require duplicating |buffer[size - 1]|.
  uint8_t special_vals[3];
  special_vals[0] = RightShiftWithRounding(
      (dst_buffer[0] << 1) + (dst_buffer[0] << 2) + (dst_buffer[1] << 2) +
          (dst_buffer[2] << 2) + (dst_buffer[3] << 1),
      4);
  // Clamp index for very small |size| values.
  const int first_index_min = std::max(size - 4, 0);
  const int second_index_min = std::max(size - 3, 0);
  const int third_index_min = std::max(size - 2, 0);
  special_vals[1] = RightShiftWithRounding(
      (dst_buffer[first_index_min] << 1) + (dst_buffer[second_index_min] << 2) +
          (dst_buffer[third_index_min] << 2) + (dst_buffer[size - 1] << 2) +
          (dst_buffer[size - 1] << 1),
      4);
  special_vals[2] = RightShiftWithRounding(
      (dst_buffer[second_index_min] << 1) + (dst_buffer[third_index_min] << 2) +
          // (x << 2) + (x << 2) == x << 3
          (dst_buffer[size - 1] << 3) + (dst_buffer[size - 1] << 1),
      4);

  // The first two values we need get overwritten by the output from the
  // previous iteration.
  uint8x16_t src_0 = vld1q_u8(dst_buffer - 1);
  uint8x16_t src_1 = vld1q_u8(dst_buffer);
  int i = 1;

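  // Each iteration computes (2 * (s0 + s4) + 4 * (s1 + s2 + s3) + 8) >> 4,
  // i.e. the {2, 4, 4, 4, 2} kernel expressed with shifts.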
  for (; i < size - 15; i += 16) {
    // Loading these at the end of the block with |src_[01]| will read past
    // the end of |top_row_data[160]|, the source of |buffer|.
    const uint8x16_t src_2 = vld1q_u8(dst_buffer + i);
    const uint8x16_t src_3 = vld1q_u8(dst_buffer + i + 1);
    const uint8x16_t src_4 = vld1q_u8(dst_buffer + i + 2);

    uint16x8_t sum_lo =
        vshlq_n_u16(vaddl_u8(vget_low_u8(src_0), vget_low_u8(src_4)), 1);
    const uint16x8_t sum_123_lo = vaddw_u8(
        vaddl_u8(vget_low_u8(src_1), vget_low_u8(src_2)), vget_low_u8(src_3));
    sum_lo = vaddq_u16(sum_lo, vshlq_n_u16(sum_123_lo, 2));

    uint16x8_t sum_hi =
        vshlq_n_u16(vaddl_u8(vget_high_u8(src_0), vget_high_u8(src_4)), 1);
    const uint16x8_t sum_123_hi =
        vaddw_u8(vaddl_u8(vget_high_u8(src_1), vget_high_u8(src_2)),
                 vget_high_u8(src_3));
    sum_hi = vaddq_u16(sum_hi, vshlq_n_u16(sum_123_hi, 2));

    const uint8x16_t result =
        vcombine_u8(vrshrn_n_u16(sum_lo, 4), vrshrn_n_u16(sum_hi, 4));

    src_0 = vld1q_u8(dst_buffer + i + 14);
    src_1 = vld1q_u8(dst_buffer + i + 15);

    vst1q_u8(dst_buffer + i, result);
  }

  const int remainder = (size - 1) & 0xf;
  // Like the 3 tap case, but if only two values remain they have already been
  // calculated.
  if (remainder > 2) {
    const uint8x16_t src_2 = vld1q_u8(dst_buffer + i);
    const uint8x16_t src_3 = vld1q_u8(dst_buffer + i + 1);
    const uint8x16_t src_4 = vld1q_u8(dst_buffer + i + 2);

    uint16x8_t sum_lo =
        vshlq_n_u16(vaddl_u8(vget_low_u8(src_0), vget_low_u8(src_4)), 1);
    const uint16x8_t sum_123_lo = vaddw_u8(
        vaddl_u8(vget_low_u8(src_1), vget_low_u8(src_2)), vget_low_u8(src_3));
    sum_lo = vaddq_u16(sum_lo, vshlq_n_u16(sum_123_lo, 2));

    uint16x8_t sum_hi =
        vshlq_n_u16(vaddl_u8(vget_high_u8(src_0), vget_high_u8(src_4)), 1);
    const uint16x8_t sum_123_hi =
        vaddw_u8(vaddl_u8(vget_high_u8(src_1), vget_high_u8(src_2)),
                 vget_high_u8(src_3));
    sum_hi = vaddq_u16(sum_hi, vshlq_n_u16(sum_123_hi, 2));

    const uint8x16_t result =
        vcombine_u8(vrshrn_n_u16(sum_lo, 4), vrshrn_n_u16(sum_hi, 4));
    const uint8x16_t v_remainder = vdupq_n_u8(remainder);
    // Create the overwrite mask: lanes at index >= |remainder| keep the
    // original |src_2| values.
    const uint8x16_t mask = vcleq_u8(v_remainder, v_index);
    const uint8x16_t dst_remainder = vbslq_u8(mask, src_2, result);
    vst1q_u8(dst_buffer + i, dst_remainder);
  }

  dst_buffer[1] = special_vals[0];
  // Avoid overwriting |dst_buffer[0]|.
  if (size > 2) dst_buffer[size - 2] = special_vals[1];
  dst_buffer[size - 1] = special_vals[2];
}

// (-|src0| + |src1| * 9 + |src2| * 9 - |src3|) >> 4
uint8x8_t Upsample(const uint8x8_t src0, const uint8x8_t src1,
                   const uint8x8_t src2, const uint8x8_t src3) {
  const uint16x8_t middle = vmulq_n_u16(vaddl_u8(src1, src2), 9);
  const uint16x8_t ends = vaddl_u8(src0, src3);
  const int16x8_t sum =
      vsubq_s16(vreinterpretq_s16_u16(middle), vreinterpretq_s16_u16(ends));
  return vqrshrun_n_s16(sum, 4);
}

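// The upsampler doubles the edge length: new (-1, 9, 9, -1) / 16 filtered
// samples are interleaved with the original samples, which is why the stores
// below interleave via vst2/InterleaveLow8(). vqrshrun_n_s16() rounds and
// saturates the filtered values to [0, 255].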
void IntraEdgeUpsampler_NEON(void* buffer, const int size) {
  assert(size % 4 == 0 && size <= 16);
  auto* const pixel_buffer = static_cast<uint8_t*>(buffer);
  // This is OK because we don't read this value for |size| 4 or 8 but if we
  // write |pixel_buffer[size]| and then vld() it, that seems to introduce
  // some latency.
  pixel_buffer[-2] = pixel_buffer[-1];
  if (size == 4) {
    // This uses one load and two vtbl() which is better than 4x Load{Lo,Hi}4().
    const uint8x8_t src = vld1_u8(pixel_buffer - 1);
    // The outside values are negated so put those in the same vector.
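    // vcreate_u8() constants are little endian: this one selects byte indices
    // {0, 0, 1, 2, 2, 3, 4, 4} and the one below {1, 2, 3, 4, 0, 1, 2, 3}.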
    const uint8x8_t src03 = vtbl1_u8(src, vcreate_u8(0x0404030202010000));
    // Reverse |src1| and |src2| so we can use |src2| for the interleave at the
    // end.
    const uint8x8_t src21 = vtbl1_u8(src, vcreate_u8(0x0302010004030201));

    const uint16x8_t middle = vmull_u8(src21, vdup_n_u8(9));
    const int16x8_t half_sum = vsubq_s16(
        vreinterpretq_s16_u16(middle), vreinterpretq_s16_u16(vmovl_u8(src03)));
    const int16x4_t sum =
        vadd_s16(vget_low_s16(half_sum), vget_high_s16(half_sum));
    const uint8x8_t result = vqrshrun_n_s16(vcombine_s16(sum, sum), 4);

    vst1_u8(pixel_buffer - 1, InterleaveLow8(result, src21));
    return;
  }
  if (size == 8) {
    // Likewise, one load + multiple vtbls seems preferred to multiple loads.
    const uint8x16_t src = vld1q_u8(pixel_buffer - 1);
    const uint8x8_t src0 = VQTbl1U8(src, vcreate_u8(0x0605040302010000));
    const uint8x8_t src1 = vget_low_u8(src);
    const uint8x8_t src2 = VQTbl1U8(src, vcreate_u8(0x0807060504030201));
    const uint8x8_t src3 = VQTbl1U8(src, vcreate_u8(0x0808070605040302));

    const uint8x8x2_t output = {Upsample(src0, src1, src2, src3), src2};
    vst2_u8(pixel_buffer - 1, output);
    return;
  }
  assert(size == 12 || size == 16);
  // Extend the input borders to avoid branching later.
  pixel_buffer[size] = pixel_buffer[size - 1];
  const uint8x16_t src0 = vld1q_u8(pixel_buffer - 2);
  const uint8x16_t src1 = vld1q_u8(pixel_buffer - 1);
  const uint8x16_t src2 = vld1q_u8(pixel_buffer);
  const uint8x16_t src3 = vld1q_u8(pixel_buffer + 1);

  const uint8x8_t result_lo = Upsample(vget_low_u8(src0), vget_low_u8(src1),
                                       vget_low_u8(src2), vget_low_u8(src3));

  const uint8x8x2_t output_lo = {result_lo, vget_low_u8(src2)};
  vst2_u8(pixel_buffer - 1, output_lo);

  const uint8x8_t result_hi = Upsample(vget_high_u8(src0), vget_high_u8(src1),
                                       vget_high_u8(src2), vget_high_u8(src3));

  if (size == 12) {
    vst1_u8(pixel_buffer + 15, InterleaveLow8(result_hi, vget_high_u8(src2)));
  } else /* size == 16 */ {
    const uint8x8x2_t output_hi = {result_hi, vget_high_u8(src2)};
    vst2_u8(pixel_buffer + 15, output_hi);
  }
}

void Init8bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
  assert(dsp != nullptr);
  dsp->intra_edge_filter = IntraEdgeFilter_NEON;
  dsp->intra_edge_upsampler = IntraEdgeUpsampler_NEON;
}

}  // namespace
}  // namespace low_bitdepth

//------------------------------------------------------------------------------
#if LIBGAV1_MAX_BITDEPTH >= 10
namespace high_bitdepth {
namespace {

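// kRemainderMask[r] has the first |r| lanes set. vbslq_u16() then keeps the
// filtered result in those lanes and the original source everywhere else,
// replacing the vcleq_u8()/index-vector trick used in the 8bpp path.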
const uint16_t kRemainderMask[8][8] = {
    {0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000},
    {0xffff, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000},
    {0xffff, 0xffff, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000},
    {0xffff, 0xffff, 0xffff, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000},
    {0xffff, 0xffff, 0xffff, 0xffff, 0x0000, 0x0000, 0x0000, 0x0000},
    {0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0x0000, 0x0000, 0x0000},
    {0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0x0000, 0x0000},
    {0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0x0000},
};

void IntraEdgeFilter_NEON(void* buffer, const int size, const int strength) {
  assert(strength == 1 || strength == 2 || strength == 3);
  const int kernel_index = strength - 1;
  auto* const dst_buffer = static_cast<uint16_t*>(buffer);

  // The first element is not written out (but it is input) so the number of
  // elements written is |size| - 1.
  if (size == 1) return;

  // |strength| 1 and 2 use a 3 tap filter.
  if (strength < 3) {
    // The last value requires extending the buffer (duplicating
    // |dst_buffer[size - 1]). Calculate it here to avoid extra processing in
    // neon.
    const uint16_t last_val = RightShiftWithRounding(
        kKernelsNEON[kernel_index][0] * dst_buffer[size - 2] +
            kKernelsNEON[kernel_index][1] * dst_buffer[size - 1] +
            kKernelsNEON[kernel_index][0] * dst_buffer[size - 1],
        4);

    const uint16_t krn0 = kKernelsNEON[kernel_index][0];
    const uint16_t krn1 = kKernelsNEON[kernel_index][1];

    // The first value we need gets overwritten by the output from the
    // previous iteration.
    uint16x8_t src_0 = vld1q_u16(dst_buffer);
    int i = 1;

    // Process blocks until there are less than 8 values remaining.
    for (; i < size - 7; i += 8) {
      // Loading these at the end of the block with |src_0| will read past the
      // end of |top_row_data[160]|, the source of |buffer|.
      const uint16x8_t src_1 = vld1q_u16(dst_buffer + i);
      const uint16x8_t src_2 = vld1q_u16(dst_buffer + i + 1);
      const uint16x8_t sum_02 = vmulq_n_u16(vaddq_u16(src_0, src_2), krn0);
      const uint16x8_t sum = vmlaq_n_u16(sum_02, src_1, krn1);
      const uint16x8_t result = vrshrq_n_u16(sum, 4);
      // Load the next row before overwriting. This loads an extra 7 values
      // past |size| on the trailing iteration.
      src_0 = vld1q_u16(dst_buffer + i + 7);
      vst1q_u16(dst_buffer + i, result);
    }

    // The last output value |last_val| was already calculated so if
    // |remainder| == 1 then we don't have to do anything.
    const int remainder = (size - 1) & 0x7;
    if (remainder > 1) {
      const uint16x8_t src_1 = vld1q_u16(dst_buffer + i);
      const uint16x8_t src_2 = vld1q_u16(dst_buffer + i + 1);
      const uint16x8_t sum_02 = vmulq_n_u16(vaddq_u16(src_0, src_2), krn0);
      const uint16x8_t sum = vmlaq_n_u16(sum_02, src_1, krn1);
      const uint16x8_t result = vrshrq_n_u16(sum, 4);
      const uint16x8_t mask = vld1q_u16(kRemainderMask[remainder]);
      const uint16x8_t dst_remainder = vbslq_u16(mask, result, src_1);
      vst1q_u16(dst_buffer + i, dst_remainder);
    }

    dst_buffer[size - 1] = last_val;
    return;
  }

  assert(strength == 3);
  // 5 tap filter. The first element requires duplicating |buffer[0]| and the
  // last two elements require duplicating |buffer[size - 1]|.
  uint16_t special_vals[3];
  special_vals[0] = RightShiftWithRounding(
      (dst_buffer[0] << 1) + (dst_buffer[0] << 2) + (dst_buffer[1] << 2) +
          (dst_buffer[2] << 2) + (dst_buffer[3] << 1),
      4);
  // Clamp index for very small |size| values.
  const int first_index_min = std::max(size - 4, 0);
  const int second_index_min = std::max(size - 3, 0);
  const int third_index_min = std::max(size - 2, 0);
  special_vals[1] = RightShiftWithRounding(
      (dst_buffer[first_index_min] << 1) + (dst_buffer[second_index_min] << 2) +
          (dst_buffer[third_index_min] << 2) + (dst_buffer[size - 1] << 2) +
          (dst_buffer[size - 1] << 1),
      4);
  special_vals[2] = RightShiftWithRounding(
      (dst_buffer[second_index_min] << 1) + (dst_buffer[third_index_min] << 2) +
          // (x << 2) + (x << 2) == x << 3
          (dst_buffer[size - 1] << 3) + (dst_buffer[size - 1] << 1),
      4);

  // The first two values we need get overwritten by the output from the
  // previous iteration.
  uint16x8_t src_0 = vld1q_u16(dst_buffer - 1);
  uint16x8_t src_1 = vld1q_u16(dst_buffer);
  int i = 1;

  for (; i < size - 7; i += 8) {
    // Loading these at the end of the block with |src_[01]| will read past
    // the end of |top_row_data[160]|, the source of |buffer|.
    const uint16x8_t src_2 = vld1q_u16(dst_buffer + i);
    const uint16x8_t src_3 = vld1q_u16(dst_buffer + i + 1);
    const uint16x8_t src_4 = vld1q_u16(dst_buffer + i + 2);
    const uint16x8_t sum_04 = vshlq_n_u16(vaddq_u16(src_0, src_4), 1);
    const uint16x8_t sum_123 = vaddq_u16(vaddq_u16(src_1, src_2), src_3);
    const uint16x8_t sum = vaddq_u16(sum_04, vshlq_n_u16(sum_123, 2));
    const uint16x8_t result = vrshrq_n_u16(sum, 4);

    // Load the next values before overwriting.
    src_0 = vld1q_u16(dst_buffer + i + 6);
    src_1 = vld1q_u16(dst_buffer + i + 7);

    vst1q_u16(dst_buffer + i, result);
  }

  const int remainder = (size - 1) & 0x7;
  // Like the 3 tap case, but if only two values remain they have already been
  // calculated.
  if (remainder > 2) {
    const uint16x8_t src_2 = vld1q_u16(dst_buffer + i);
    const uint16x8_t src_3 = vld1q_u16(dst_buffer + i + 1);
    const uint16x8_t src_4 = vld1q_u16(dst_buffer + i + 2);
    const uint16x8_t sum_04 = vshlq_n_u16(vaddq_u16(src_0, src_4), 1);
    const uint16x8_t sum_123 = vaddq_u16(vaddq_u16(src_1, src_2), src_3);
    const uint16x8_t sum = vaddq_u16(sum_04, vshlq_n_u16(sum_123, 2));
    const uint16x8_t result = vrshrq_n_u16(sum, 4);
    const uint16x8_t mask = vld1q_u16(kRemainderMask[remainder]);
    const uint16x8_t dst_remainder = vbslq_u16(mask, result, src_2);
    vst1q_u16(dst_buffer + i, dst_remainder);
  }

  dst_buffer[1] = special_vals[0];
  // Avoid overwriting |dst_buffer[0]|.
  if (size > 2) dst_buffer[size - 2] = special_vals[1];
  dst_buffer[size - 1] = special_vals[2];
}

void IntraEdgeUpsampler_NEON(void* buffer, const int size) {
  assert(size % 4 == 0 && size <= 16);
  auto* const pixel_buffer = static_cast<uint16_t*>(buffer);

  // Extend first/last samples
  pixel_buffer[-2] = pixel_buffer[-1];
  pixel_buffer[size] = pixel_buffer[size - 1];

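  // Same (-|x0| + 9 * |x1| + 9 * |x2| - |x3|) >> 4 filter as the 8bpp path;
  // vextq_s16() forms the shifted source windows from two adjacent registers.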
  const int16x8_t src_lo = vreinterpretq_s16_u16(vld1q_u16(pixel_buffer - 2));
  const int16x8_t src_hi =
      vreinterpretq_s16_u16(vld1q_u16(pixel_buffer - 2 + 8));
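  // 9 * x computed as x + (x << 3).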
  const int16x8_t src9_hi = vaddq_s16(src_hi, vshlq_n_s16(src_hi, 3));
  const int16x8_t src9_lo = vaddq_s16(src_lo, vshlq_n_s16(src_lo, 3));

  int16x8_t sum_lo = vsubq_s16(vextq_s16(src9_lo, src9_hi, 1), src_lo);
  sum_lo = vaddq_s16(sum_lo, vextq_s16(src9_lo, src9_hi, 2));
  sum_lo = vsubq_s16(sum_lo, vextq_s16(src_lo, src_hi, 3));
  sum_lo = vrshrq_n_s16(sum_lo, 4);

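  // Unlike the 8bpp path, the narrowing vqrshrun_n_s16() saturation cannot be
  // used at 10 bits, so clamp manually to [0, (1 << kBitdepth10) - 1].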
  uint16x8x2_t result_lo;
  result_lo.val[0] =
      vminq_u16(vreinterpretq_u16_s16(vmaxq_s16(sum_lo, vdupq_n_s16(0))),
                vdupq_n_u16((1 << kBitdepth10) - 1));
  result_lo.val[1] = vreinterpretq_u16_s16(vextq_s16(src_lo, src_hi, 2));

  if (size > 8) {
    const int16x8_t src_hi_extra =
        vreinterpretq_s16_u16(vld1q_u16(pixel_buffer + 16 - 2));
    const int16x8_t src9_hi_extra =
        vaddq_s16(src_hi_extra, vshlq_n_s16(src_hi_extra, 3));

    int16x8_t sum_hi = vsubq_s16(vextq_s16(src9_hi, src9_hi_extra, 1), src_hi);
    sum_hi = vaddq_s16(sum_hi, vextq_s16(src9_hi, src9_hi_extra, 2));
    sum_hi = vsubq_s16(sum_hi, vextq_s16(src_hi, src_hi_extra, 3));
    sum_hi = vrshrq_n_s16(sum_hi, 4);

    uint16x8x2_t result_hi;
    result_hi.val[0] =
        vminq_u16(vreinterpretq_u16_s16(vmaxq_s16(sum_hi, vdupq_n_s16(0))),
                  vdupq_n_u16((1 << kBitdepth10) - 1));
    result_hi.val[1] =
        vreinterpretq_u16_s16(vextq_s16(src_hi, src_hi_extra, 2));
    vst2q_u16(pixel_buffer - 1, result_lo);
    vst2q_u16(pixel_buffer + 15, result_hi);
  } else {
    vst2q_u16(pixel_buffer - 1, result_lo);
  }
}

void Init10bpp() {
  Dsp* dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
  assert(dsp != nullptr);
  dsp->intra_edge_filter = IntraEdgeFilter_NEON;
  dsp->intra_edge_upsampler = IntraEdgeUpsampler_NEON;
}

}  // namespace
}  // namespace high_bitdepth
#endif  // LIBGAV1_MAX_BITDEPTH >= 10

void IntraEdgeInit_NEON() {
  low_bitdepth::Init8bpp();
#if LIBGAV1_MAX_BITDEPTH >= 10
  high_bitdepth::Init10bpp();
#endif
}

}  // namespace dsp
}  // namespace libgav1

#else   // !LIBGAV1_ENABLE_NEON
namespace libgav1 {
namespace dsp {

void IntraEdgeInit_NEON() {}

}  // namespace dsp
}  // namespace libgav1
#endif  // LIBGAV1_ENABLE_NEON