xref: /aosp_15_r20/external/libgav1/src/dsp/arm/obmc_neon.cc (revision 095378508e87ed692bf8dfeb34008b65b3735891)
1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/obmc.h"
16 #include "src/utils/cpu.h"
17 
18 #if LIBGAV1_ENABLE_NEON
19 
20 #include <arm_neon.h>
21 
22 #include <algorithm>
23 #include <cassert>
24 #include <cstddef>
25 #include <cstdint>
26 #include <cstring>
27 
28 #include "src/dsp/arm/common_neon.h"
29 #include "src/dsp/constants.h"
30 #include "src/dsp/dsp.h"
31 #include "src/utils/common.h"
32 
33 namespace libgav1 {
34 namespace dsp {
35 namespace {
36 #include "src/dsp/obmc.inc"
37 
38 }  // namespace
39 
40 namespace low_bitdepth {
41 namespace {
42 
WriteObmcLine4(uint8_t * LIBGAV1_RESTRICT const pred,const uint8_t * LIBGAV1_RESTRICT const obmc_pred,const uint8x8_t pred_mask,const uint8x8_t obmc_pred_mask)43 inline void WriteObmcLine4(uint8_t* LIBGAV1_RESTRICT const pred,
44                            const uint8_t* LIBGAV1_RESTRICT const obmc_pred,
45                            const uint8x8_t pred_mask,
46                            const uint8x8_t obmc_pred_mask) {
47   const uint8x8_t pred_val = Load4(pred);
48   const uint8x8_t obmc_pred_val = Load4(obmc_pred);
49   const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val);
50   const uint8x8_t result =
51       vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6);
52   StoreLo4(pred, result);
53 }
54 
WriteObmcLine8(uint8_t * LIBGAV1_RESTRICT const pred,const uint8x8_t obmc_pred_val,const uint8x8_t pred_mask,const uint8x8_t obmc_pred_mask)55 inline void WriteObmcLine8(uint8_t* LIBGAV1_RESTRICT const pred,
56                            const uint8x8_t obmc_pred_val,
57                            const uint8x8_t pred_mask,
58                            const uint8x8_t obmc_pred_mask) {
59   const uint8x8_t pred_val = vld1_u8(pred);
60   const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val);
61   const uint8x8_t result =
62       vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6);
63   vst1_u8(pred, result);
64 }
65 
OverlapBlendFromLeft2xH_NEON(uint8_t * LIBGAV1_RESTRICT pred,const ptrdiff_t prediction_stride,const int height,const uint8_t * LIBGAV1_RESTRICT obmc_pred,const ptrdiff_t obmc_prediction_stride)66 inline void OverlapBlendFromLeft2xH_NEON(
67     uint8_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
68     const int height, const uint8_t* LIBGAV1_RESTRICT obmc_pred,
69     const ptrdiff_t obmc_prediction_stride) {
70   const uint8x8_t mask_inverter = vdup_n_u8(64);
71   const uint8x8_t pred_mask = Load2(kObmcMask);
72   const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
73   uint8x8_t pred_val = vdup_n_u8(0);
74   uint8x8_t obmc_pred_val = vdup_n_u8(0);
75   int y = 0;
76   do {
77     pred_val = Load2<0>(pred, pred_val);
78     const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val);
79     obmc_pred_val = Load2<0>(obmc_pred, obmc_pred_val);
80     const uint8x8_t result =
81         vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6);
82     Store2<0>(pred, result);
83 
84     pred += prediction_stride;
85     obmc_pred += obmc_prediction_stride;
86   } while (++y != height);
87 }
88 
OverlapBlendFromLeft4xH_NEON(uint8_t * LIBGAV1_RESTRICT pred,const ptrdiff_t prediction_stride,const int height,const uint8_t * LIBGAV1_RESTRICT obmc_pred,const ptrdiff_t obmc_prediction_stride)89 inline void OverlapBlendFromLeft4xH_NEON(
90     uint8_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
91     const int height, const uint8_t* LIBGAV1_RESTRICT obmc_pred,
92     const ptrdiff_t obmc_prediction_stride) {
93   const uint8x8_t mask_inverter = vdup_n_u8(64);
94   const uint8x8_t pred_mask = Load4(kObmcMask + 2);
95   // 64 - mask
96   const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
97   int y = 0;
98   do {
99     WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
100     pred += prediction_stride;
101     obmc_pred += obmc_prediction_stride;
102 
103     WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
104     pred += prediction_stride;
105     obmc_pred += obmc_prediction_stride;
106 
107     y += 2;
108   } while (y != height);
109 }
110 
OverlapBlendFromLeft8xH_NEON(uint8_t * LIBGAV1_RESTRICT pred,const ptrdiff_t prediction_stride,const int height,const uint8_t * LIBGAV1_RESTRICT obmc_pred)111 inline void OverlapBlendFromLeft8xH_NEON(
112     uint8_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
113     const int height, const uint8_t* LIBGAV1_RESTRICT obmc_pred) {
114   const uint8x8_t mask_inverter = vdup_n_u8(64);
115   const uint8x8_t pred_mask = vld1_u8(kObmcMask + 6);
116   constexpr int obmc_prediction_stride = 8;
117   // 64 - mask
118   const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
119   int y = 0;
120   do {
121     const uint8x16_t obmc_pred_val = vld1q_u8(obmc_pred);
122     WriteObmcLine8(pred, vget_low_u8(obmc_pred_val), pred_mask, obmc_pred_mask);
123     pred += prediction_stride;
124 
125     WriteObmcLine8(pred, vget_high_u8(obmc_pred_val), pred_mask,
126                    obmc_pred_mask);
127     pred += prediction_stride;
128 
129     obmc_pred += obmc_prediction_stride << 1;
130     y += 2;
131   } while (y != height);
132 }
133 
OverlapBlendFromLeft_NEON(void * LIBGAV1_RESTRICT const prediction,const ptrdiff_t prediction_stride,const int width,const int height,const void * LIBGAV1_RESTRICT const obmc_prediction,const ptrdiff_t obmc_prediction_stride)134 void OverlapBlendFromLeft_NEON(
135     void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t prediction_stride,
136     const int width, const int height,
137     const void* LIBGAV1_RESTRICT const obmc_prediction,
138     const ptrdiff_t obmc_prediction_stride) {
139   auto* pred = static_cast<uint8_t*>(prediction);
140   const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction);
141   assert(width >= 2);
142   assert(height >= 4);
143 
144   if (width == 2) {
145     OverlapBlendFromLeft2xH_NEON(pred, prediction_stride, height, obmc_pred,
146                                  obmc_prediction_stride);
147     return;
148   }
149   if (width == 4) {
150     OverlapBlendFromLeft4xH_NEON(pred, prediction_stride, height, obmc_pred,
151                                  obmc_prediction_stride);
152     return;
153   }
154   if (width == 8) {
155     OverlapBlendFromLeft8xH_NEON(pred, prediction_stride, height, obmc_pred);
156     return;
157   }
158   const uint8x16_t mask_inverter = vdupq_n_u8(64);
159   const uint8_t* mask = kObmcMask + width - 2;
160   int x = 0;
161   do {
162     pred = static_cast<uint8_t*>(prediction) + x;
163     obmc_pred = static_cast<const uint8_t*>(obmc_prediction) + x;
164     const uint8x16_t pred_mask = vld1q_u8(mask + x);
165     // 64 - mask
166     const uint8x16_t obmc_pred_mask = vsubq_u8(mask_inverter, pred_mask);
167     int y = 0;
168     do {
169       const uint8x16_t pred_val = vld1q_u8(pred);
170       const uint8x16_t obmc_pred_val = vld1q_u8(obmc_pred);
171       const uint16x8_t weighted_pred_lo =
172           vmull_u8(vget_low_u8(pred_mask), vget_low_u8(pred_val));
173       const uint8x8_t result_lo =
174           vrshrn_n_u16(vmlal_u8(weighted_pred_lo, vget_low_u8(obmc_pred_mask),
175                                 vget_low_u8(obmc_pred_val)),
176                        6);
177       const uint16x8_t weighted_pred_hi =
178           vmull_u8(vget_high_u8(pred_mask), vget_high_u8(pred_val));
179       const uint8x8_t result_hi =
180           vrshrn_n_u16(vmlal_u8(weighted_pred_hi, vget_high_u8(obmc_pred_mask),
181                                 vget_high_u8(obmc_pred_val)),
182                        6);
183       vst1q_u8(pred, vcombine_u8(result_lo, result_hi));
184 
185       pred += prediction_stride;
186       obmc_pred += obmc_prediction_stride;
187     } while (++y < height);
188     x += 16;
189   } while (x < width);
190 }
191 
OverlapBlendFromTop4x4_NEON(uint8_t * LIBGAV1_RESTRICT pred,const ptrdiff_t prediction_stride,const uint8_t * LIBGAV1_RESTRICT obmc_pred,const ptrdiff_t obmc_prediction_stride,const int height)192 inline void OverlapBlendFromTop4x4_NEON(
193     uint8_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
194     const uint8_t* LIBGAV1_RESTRICT obmc_pred,
195     const ptrdiff_t obmc_prediction_stride, const int height) {
196   uint8x8_t pred_mask = vdup_n_u8(kObmcMask[height - 2]);
197   const uint8x8_t mask_inverter = vdup_n_u8(64);
198   uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
199   WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
200   pred += prediction_stride;
201   obmc_pred += obmc_prediction_stride;
202 
203   if (height == 2) {
204     return;
205   }
206 
207   pred_mask = vdup_n_u8(kObmcMask[3]);
208   obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
209   WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
210   pred += prediction_stride;
211   obmc_pred += obmc_prediction_stride;
212 
213   pred_mask = vdup_n_u8(kObmcMask[4]);
214   obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
215   WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
216 }
217 
OverlapBlendFromTop4xH_NEON(uint8_t * LIBGAV1_RESTRICT pred,const ptrdiff_t prediction_stride,const int height,const uint8_t * LIBGAV1_RESTRICT obmc_pred,const ptrdiff_t obmc_prediction_stride)218 inline void OverlapBlendFromTop4xH_NEON(
219     uint8_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
220     const int height, const uint8_t* LIBGAV1_RESTRICT obmc_pred,
221     const ptrdiff_t obmc_prediction_stride) {
222   if (height < 8) {
223     OverlapBlendFromTop4x4_NEON(pred, prediction_stride, obmc_pred,
224                                 obmc_prediction_stride, height);
225     return;
226   }
227   const uint8_t* mask = kObmcMask + height - 2;
228   const uint8x8_t mask_inverter = vdup_n_u8(64);
229   int y = 0;
230   // Compute 6 lines for height 8, or 12 lines for height 16. The remaining
231   // lines are unchanged as the corresponding mask value is 64.
232   do {
233     uint8x8_t pred_mask = vdup_n_u8(mask[y]);
234     uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
235     WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
236     pred += prediction_stride;
237     obmc_pred += obmc_prediction_stride;
238 
239     pred_mask = vdup_n_u8(mask[y + 1]);
240     obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
241     WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
242     pred += prediction_stride;
243     obmc_pred += obmc_prediction_stride;
244 
245     pred_mask = vdup_n_u8(mask[y + 2]);
246     obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
247     WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
248     pred += prediction_stride;
249     obmc_pred += obmc_prediction_stride;
250 
251     pred_mask = vdup_n_u8(mask[y + 3]);
252     obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
253     WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
254     pred += prediction_stride;
255     obmc_pred += obmc_prediction_stride;
256 
257     pred_mask = vdup_n_u8(mask[y + 4]);
258     obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
259     WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
260     pred += prediction_stride;
261     obmc_pred += obmc_prediction_stride;
262 
263     pred_mask = vdup_n_u8(mask[y + 5]);
264     obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
265     WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
266     pred += prediction_stride;
267     obmc_pred += obmc_prediction_stride;
268 
269     // Increment for the right mask index.
270     y += 6;
271   } while (y < height - 4);
272 }
273 
OverlapBlendFromTop8xH_NEON(uint8_t * LIBGAV1_RESTRICT pred,const ptrdiff_t prediction_stride,const int height,const uint8_t * LIBGAV1_RESTRICT obmc_pred)274 inline void OverlapBlendFromTop8xH_NEON(
275     uint8_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
276     const int height, const uint8_t* LIBGAV1_RESTRICT obmc_pred) {
277   constexpr int obmc_prediction_stride = 8;
278   const uint8x8_t mask_inverter = vdup_n_u8(64);
279   const uint8_t* mask = kObmcMask + height - 2;
280   const int compute_height = height - (height >> 2);
281   int y = 0;
282   do {
283     const uint8x8_t pred_mask0 = vdup_n_u8(mask[y]);
284     // 64 - mask
285     const uint8x8_t obmc_pred_mask0 = vsub_u8(mask_inverter, pred_mask0);
286     const uint8x16_t obmc_pred_val = vld1q_u8(obmc_pred);
287 
288     WriteObmcLine8(pred, vget_low_u8(obmc_pred_val), pred_mask0,
289                    obmc_pred_mask0);
290     pred += prediction_stride;
291     ++y;
292 
293     const uint8x8_t pred_mask1 = vdup_n_u8(mask[y]);
294     // 64 - mask
295     const uint8x8_t obmc_pred_mask1 = vsub_u8(mask_inverter, pred_mask1);
296     WriteObmcLine8(pred, vget_high_u8(obmc_pred_val), pred_mask1,
297                    obmc_pred_mask1);
298     pred += prediction_stride;
299     obmc_pred += obmc_prediction_stride << 1;
300   } while (++y < compute_height);
301 }
302 
OverlapBlendFromTop_NEON(void * LIBGAV1_RESTRICT const prediction,const ptrdiff_t prediction_stride,const int width,const int height,const void * LIBGAV1_RESTRICT const obmc_prediction,const ptrdiff_t obmc_prediction_stride)303 void OverlapBlendFromTop_NEON(
304     void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t prediction_stride,
305     const int width, const int height,
306     const void* LIBGAV1_RESTRICT const obmc_prediction,
307     const ptrdiff_t obmc_prediction_stride) {
308   auto* pred = static_cast<uint8_t*>(prediction);
309   const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction);
310   assert(width >= 4);
311   assert(height >= 2);
312 
313   if (width == 4) {
314     OverlapBlendFromTop4xH_NEON(pred, prediction_stride, height, obmc_pred,
315                                 obmc_prediction_stride);
316     return;
317   }
318 
319   if (width == 8) {
320     OverlapBlendFromTop8xH_NEON(pred, prediction_stride, height, obmc_pred);
321     return;
322   }
323 
324   const uint8_t* mask = kObmcMask + height - 2;
325   const uint8x8_t mask_inverter = vdup_n_u8(64);
326   // Stop when mask value becomes 64. This is inferred for 4xH.
327   const int compute_height = height - (height >> 2);
328   int y = 0;
329   do {
330     const uint8x8_t pred_mask = vdup_n_u8(mask[y]);
331     // 64 - mask
332     const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
333     int x = 0;
334     do {
335       const uint8x16_t pred_val = vld1q_u8(pred + x);
336       const uint8x16_t obmc_pred_val = vld1q_u8(obmc_pred + x);
337       const uint16x8_t weighted_pred_lo =
338           vmull_u8(pred_mask, vget_low_u8(pred_val));
339       const uint8x8_t result_lo =
340           vrshrn_n_u16(vmlal_u8(weighted_pred_lo, obmc_pred_mask,
341                                 vget_low_u8(obmc_pred_val)),
342                        6);
343       const uint16x8_t weighted_pred_hi =
344           vmull_u8(pred_mask, vget_high_u8(pred_val));
345       const uint8x8_t result_hi =
346           vrshrn_n_u16(vmlal_u8(weighted_pred_hi, obmc_pred_mask,
347                                 vget_high_u8(obmc_pred_val)),
348                        6);
349       vst1q_u8(pred + x, vcombine_u8(result_lo, result_hi));
350 
351       x += 16;
352     } while (x < width);
353     pred += prediction_stride;
354     obmc_pred += obmc_prediction_stride;
355   } while (++y < compute_height);
356 }
357 
Init8bpp()358 void Init8bpp() {
359   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
360   assert(dsp != nullptr);
361   dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendFromTop_NEON;
362   dsp->obmc_blend[kObmcDirectionHorizontal] = OverlapBlendFromLeft_NEON;
363 }
364 
365 }  // namespace
366 }  // namespace low_bitdepth
367 
368 #if LIBGAV1_MAX_BITDEPTH >= 10
369 namespace high_bitdepth {
370 namespace {
371 
// Flat table of 16-bit blend weights for each block dimension from 2 to 32.
// The weights for a dimension of length N begin at index N - 2. A weight of 64
// leaves the result equal to |pred|, so such entries may be skipped when that
// is convenient. Vector loads may overread into the weights of larger
// dimensions, but those values are unused.
constexpr uint16_t kObmcMask[62] = {
    // Dimension 2, indices [0, 2)
    45, 64,
    // Dimension 4, indices [2, 6)
    39, 50, 59, 64,
    // Dimension 8, indices [6, 14)
    36, 42, 48, 53, 57, 61, 64, 64,
    // Dimension 16, indices [14, 30)
    34, 37, 40, 43, 46, 49, 52, 54,
    56, 58, 60, 61, 64, 64, 64, 64,
    // Dimension 32, indices [30, 62)
    33, 35, 36, 38, 40, 41, 43, 44,
    45, 47, 48, 50, 51, 52, 53, 55,
    56, 57, 58, 59, 60, 60, 61, 62,
    64, 64, 64, 64, 64, 64, 64, 64};
388 
// Loads four 16-bit values from |pred| and returns
//   RoundShift(pred_mask * pred + obmc_pred_mask * obmc_pred_val, 6)
// per lane. Nothing is stored; the callers in this file write back either two
// or four lanes of the result.
inline uint16x4_t BlendObmc2Or4(uint16_t* const pred,
                                const uint16x4_t obmc_pred_val,
                                const uint16x4_t pred_mask,
                                const uint16x4_t obmc_pred_mask) {
  const uint16x4_t dst_pixels = vld1_u16(pred);
  // Both products are accumulated in 16-bit lanes.
  uint16x4_t weighted_sum = vmul_u16(pred_mask, dst_pixels);
  weighted_sum = vmla_u16(weighted_sum, obmc_pred_mask, obmc_pred_val);
  return vrshr_n_u16(weighted_sum, 6);
}
399 
BlendObmc8(uint16_t * LIBGAV1_RESTRICT const pred,const uint16_t * LIBGAV1_RESTRICT const obmc_pred,const uint16x8_t pred_mask,const uint16x8_t obmc_pred_mask)400 inline uint16x8_t BlendObmc8(uint16_t* LIBGAV1_RESTRICT const pred,
401                              const uint16_t* LIBGAV1_RESTRICT const obmc_pred,
402                              const uint16x8_t pred_mask,
403                              const uint16x8_t obmc_pred_mask) {
404   const uint16x8_t pred_val = vld1q_u16(pred);
405   const uint16x8_t obmc_pred_val = vld1q_u16(obmc_pred);
406   const uint16x8_t weighted_pred = vmulq_u16(pred_mask, pred_val);
407   const uint16x8_t result =
408       vrshrq_n_u16(vmlaq_u16(weighted_pred, obmc_pred_mask, obmc_pred_val), 6);
409   return result;
410 }
411 
OverlapBlendFromLeft2xH_NEON(uint16_t * LIBGAV1_RESTRICT pred,const ptrdiff_t prediction_stride,const int height,const uint16_t * LIBGAV1_RESTRICT obmc_pred)412 inline void OverlapBlendFromLeft2xH_NEON(
413     uint16_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
414     const int height, const uint16_t* LIBGAV1_RESTRICT obmc_pred) {
415   constexpr int obmc_prediction_stride = 2;
416   const uint16x4_t mask_inverter = vdup_n_u16(64);
417   // Second two lanes unused.
418   const uint16x4_t pred_mask = vld1_u16(kObmcMask);
419   const uint16x4_t obmc_pred_mask = vsub_u16(mask_inverter, pred_mask);
420   int y = 0;
421   do {
422     const uint16x4_t obmc_pred_0 = vld1_u16(obmc_pred);
423     const uint16x4_t result_0 =
424         BlendObmc2Or4(pred, obmc_pred_0, pred_mask, obmc_pred_mask);
425     Store2<0>(pred, result_0);
426 
427     pred = AddByteStride(pred, prediction_stride);
428     obmc_pred += obmc_prediction_stride;
429 
430     const uint16x4_t obmc_pred_1 = vld1_u16(obmc_pred);
431     const uint16x4_t result_1 =
432         BlendObmc2Or4(pred, obmc_pred_1, pred_mask, obmc_pred_mask);
433     Store2<0>(pred, result_1);
434 
435     pred = AddByteStride(pred, prediction_stride);
436     obmc_pred += obmc_prediction_stride;
437 
438     y += 2;
439   } while (y != height);
440 }
441 
OverlapBlendFromLeft4xH_NEON(uint16_t * LIBGAV1_RESTRICT pred,const ptrdiff_t prediction_stride,const int height,const uint16_t * LIBGAV1_RESTRICT obmc_pred)442 inline void OverlapBlendFromLeft4xH_NEON(
443     uint16_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
444     const int height, const uint16_t* LIBGAV1_RESTRICT obmc_pred) {
445   constexpr int obmc_prediction_stride = 4;
446   const uint16x4_t mask_inverter = vdup_n_u16(64);
447   const uint16x4_t pred_mask = vld1_u16(kObmcMask + 2);
448   // 64 - mask
449   const uint16x4_t obmc_pred_mask = vsub_u16(mask_inverter, pred_mask);
450   int y = 0;
451   do {
452     const uint16x8_t obmc_pred_val = vld1q_u16(obmc_pred);
453     const uint16x4_t result_0 = BlendObmc2Or4(pred, vget_low_u16(obmc_pred_val),
454                                               pred_mask, obmc_pred_mask);
455     vst1_u16(pred, result_0);
456     pred = AddByteStride(pred, prediction_stride);
457 
458     const uint16x4_t result_1 = BlendObmc2Or4(
459         pred, vget_high_u16(obmc_pred_val), pred_mask, obmc_pred_mask);
460     vst1_u16(pred, result_1);
461     pred = AddByteStride(pred, prediction_stride);
462     obmc_pred += obmc_prediction_stride << 1;
463 
464     y += 2;
465   } while (y != height);
466 }
467 
OverlapBlendFromLeft_NEON(void * LIBGAV1_RESTRICT const prediction,const ptrdiff_t prediction_stride,const int width,const int height,const void * LIBGAV1_RESTRICT const obmc_prediction,const ptrdiff_t obmc_prediction_stride)468 void OverlapBlendFromLeft_NEON(
469     void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t prediction_stride,
470     const int width, const int height,
471     const void* LIBGAV1_RESTRICT const obmc_prediction,
472     const ptrdiff_t obmc_prediction_stride) {
473   auto* pred = static_cast<uint16_t*>(prediction);
474   const auto* obmc_pred = static_cast<const uint16_t*>(obmc_prediction);
475   assert(width >= 2);
476   assert(height >= 4);
477 
478   if (width == 2) {
479     OverlapBlendFromLeft2xH_NEON(pred, prediction_stride, height, obmc_pred);
480     return;
481   }
482   if (width == 4) {
483     OverlapBlendFromLeft4xH_NEON(pred, prediction_stride, height, obmc_pred);
484     return;
485   }
486   const uint16x8_t mask_inverter = vdupq_n_u16(64);
487   const uint16_t* mask = kObmcMask + width - 2;
488   int x = 0;
489   do {
490     uint16_t* pred_x = pred + x;
491     const uint16_t* obmc_pred_x = obmc_pred + x;
492     const uint16x8_t pred_mask = vld1q_u16(mask + x);
493     // 64 - mask
494     const uint16x8_t obmc_pred_mask = vsubq_u16(mask_inverter, pred_mask);
495     int y = 0;
496     do {
497       const uint16x8_t result =
498           BlendObmc8(pred_x, obmc_pred_x, pred_mask, obmc_pred_mask);
499       vst1q_u16(pred_x, result);
500 
501       pred_x = AddByteStride(pred_x, prediction_stride);
502       obmc_pred_x = AddByteStride(obmc_pred_x, obmc_prediction_stride);
503     } while (++y < height);
504     x += 8;
505   } while (x < width);
506 }
507 
508 template <int lane>
BlendObmcFromTop4(uint16_t * const pred,const uint16x4_t obmc_pred_val,const uint16x8_t pred_mask,const uint16x8_t obmc_pred_mask)509 inline uint16x4_t BlendObmcFromTop4(uint16_t* const pred,
510                                     const uint16x4_t obmc_pred_val,
511                                     const uint16x8_t pred_mask,
512                                     const uint16x8_t obmc_pred_mask) {
513   const uint16x4_t pred_val = vld1_u16(pred);
514   const uint16x4_t weighted_pred = VMulLaneQU16<lane>(pred_val, pred_mask);
515   const uint16x4_t result = vrshr_n_u16(
516       VMlaLaneQU16<lane>(weighted_pred, obmc_pred_val, obmc_pred_mask), 6);
517   return result;
518 }
519 
520 template <int lane>
BlendObmcFromTop8(uint16_t * LIBGAV1_RESTRICT const pred,const uint16_t * LIBGAV1_RESTRICT const obmc_pred,const uint16x8_t pred_mask,const uint16x8_t obmc_pred_mask)521 inline uint16x8_t BlendObmcFromTop8(
522     uint16_t* LIBGAV1_RESTRICT const pred,
523     const uint16_t* LIBGAV1_RESTRICT const obmc_pred,
524     const uint16x8_t pred_mask, const uint16x8_t obmc_pred_mask) {
525   const uint16x8_t pred_val = vld1q_u16(pred);
526   const uint16x8_t obmc_pred_val = vld1q_u16(obmc_pred);
527   const uint16x8_t weighted_pred = VMulQLaneQU16<lane>(pred_val, pred_mask);
528   const uint16x8_t result = vrshrq_n_u16(
529       VMlaQLaneQU16<lane>(weighted_pred, obmc_pred_val, obmc_pred_mask), 6);
530   return result;
531 }
532 
OverlapBlendFromTop4x2Or4_NEON(uint16_t * LIBGAV1_RESTRICT pred,const ptrdiff_t prediction_stride,const uint16_t * LIBGAV1_RESTRICT obmc_pred,const int height)533 inline void OverlapBlendFromTop4x2Or4_NEON(
534     uint16_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
535     const uint16_t* LIBGAV1_RESTRICT obmc_pred, const int height) {
536   constexpr int obmc_prediction_stride = 4;
537   const uint16x8_t pred_mask = vld1q_u16(&kObmcMask[height - 2]);
538   const uint16x8_t mask_inverter = vdupq_n_u16(64);
539   const uint16x8_t obmc_pred_mask = vsubq_u16(mask_inverter, pred_mask);
540   const uint16x8_t obmc_pred_val_0 = vld1q_u16(obmc_pred);
541   uint16x4_t result = BlendObmcFromTop4<0>(pred, vget_low_u16(obmc_pred_val_0),
542                                            pred_mask, obmc_pred_mask);
543   vst1_u16(pred, result);
544   pred = AddByteStride(pred, prediction_stride);
545 
546   if (height == 2) {
547     // Mask value is 64, meaning |pred| is unchanged.
548     return;
549   }
550 
551   result = BlendObmcFromTop4<1>(pred, vget_high_u16(obmc_pred_val_0), pred_mask,
552                                 obmc_pred_mask);
553   vst1_u16(pred, result);
554   pred = AddByteStride(pred, prediction_stride);
555   obmc_pred += obmc_prediction_stride << 1;
556 
557   const uint16x4_t obmc_pred_val_2 = vld1_u16(obmc_pred);
558   result =
559       BlendObmcFromTop4<2>(pred, obmc_pred_val_2, pred_mask, obmc_pred_mask);
560   vst1_u16(pred, result);
561 }
562 
OverlapBlendFromTop4xH_NEON(uint16_t * LIBGAV1_RESTRICT pred,const ptrdiff_t prediction_stride,const int height,const uint16_t * LIBGAV1_RESTRICT obmc_pred)563 inline void OverlapBlendFromTop4xH_NEON(
564     uint16_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
565     const int height, const uint16_t* LIBGAV1_RESTRICT obmc_pred) {
566   if (height < 8) {
567     OverlapBlendFromTop4x2Or4_NEON(pred, prediction_stride, obmc_pred, height);
568     return;
569   }
570   constexpr int obmc_prediction_stride = 4;
571   const uint16_t* mask = kObmcMask + height - 2;
572   const uint16x8_t mask_inverter = vdupq_n_u16(64);
573   int y = 0;
574   // Compute 6 lines for height 8, or 12 lines for height 16. The remaining
575   // lines are unchanged as the corresponding mask value is 64.
576   do {
577     const uint16x8_t pred_mask = vld1q_u16(&mask[y]);
578     const uint16x8_t obmc_pred_mask = vsubq_u16(mask_inverter, pred_mask);
579     // Load obmc row 0, 1.
580     uint16x8_t obmc_pred_val = vld1q_u16(obmc_pred);
581     uint16x4_t result = BlendObmcFromTop4<0>(pred, vget_low_u16(obmc_pred_val),
582                                              pred_mask, obmc_pred_mask);
583     vst1_u16(pred, result);
584     pred = AddByteStride(pred, prediction_stride);
585 
586     result = BlendObmcFromTop4<1>(pred, vget_high_u16(obmc_pred_val), pred_mask,
587                                   obmc_pred_mask);
588     vst1_u16(pred, result);
589     pred = AddByteStride(pred, prediction_stride);
590     obmc_pred += obmc_prediction_stride << 1;
591 
592     // Load obmc row 2, 3.
593     obmc_pred_val = vld1q_u16(obmc_pred);
594     result = BlendObmcFromTop4<2>(pred, vget_low_u16(obmc_pred_val), pred_mask,
595                                   obmc_pred_mask);
596     vst1_u16(pred, result);
597     pred = AddByteStride(pred, prediction_stride);
598 
599     result = BlendObmcFromTop4<3>(pred, vget_high_u16(obmc_pred_val), pred_mask,
600                                   obmc_pred_mask);
601     vst1_u16(pred, result);
602     pred = AddByteStride(pred, prediction_stride);
603     obmc_pred += obmc_prediction_stride << 1;
604 
605     // Load obmc row 4, 5.
606     obmc_pred_val = vld1q_u16(obmc_pred);
607     result = BlendObmcFromTop4<4>(pred, vget_low_u16(obmc_pred_val), pred_mask,
608                                   obmc_pred_mask);
609     vst1_u16(pred, result);
610     pred = AddByteStride(pred, prediction_stride);
611 
612     result = BlendObmcFromTop4<5>(pred, vget_high_u16(obmc_pred_val), pred_mask,
613                                   obmc_pred_mask);
614     vst1_u16(pred, result);
615     pred = AddByteStride(pred, prediction_stride);
616     obmc_pred += obmc_prediction_stride << 1;
617 
618     // Increment for the right mask index.
619     y += 6;
620   } while (y < height - 4);
621 }
622 
OverlapBlendFromTop8xH_NEON(uint16_t * LIBGAV1_RESTRICT pred,const ptrdiff_t prediction_stride,const uint16_t * LIBGAV1_RESTRICT obmc_pred,const int height)623 inline void OverlapBlendFromTop8xH_NEON(
624     uint16_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
625     const uint16_t* LIBGAV1_RESTRICT obmc_pred, const int height) {
626   const uint16_t* mask = kObmcMask + height - 2;
627   const uint16x8_t mask_inverter = vdupq_n_u16(64);
628   uint16x8_t pred_mask = vld1q_u16(mask);
629   uint16x8_t obmc_pred_mask = vsubq_u16(mask_inverter, pred_mask);
630   uint16x8_t result =
631       BlendObmcFromTop8<0>(pred, obmc_pred, pred_mask, obmc_pred_mask);
632   vst1q_u16(pred, result);
633   if (height == 2) return;
634 
635   constexpr int obmc_prediction_stride = 8;
636   pred = AddByteStride(pred, prediction_stride);
637   obmc_pred += obmc_prediction_stride;
638 
639   result = BlendObmcFromTop8<1>(pred, obmc_pred, pred_mask, obmc_pred_mask);
640   vst1q_u16(pred, result);
641   pred = AddByteStride(pred, prediction_stride);
642   obmc_pred += obmc_prediction_stride;
643 
644   result = BlendObmcFromTop8<2>(pred, obmc_pred, pred_mask, obmc_pred_mask);
645   vst1q_u16(pred, result);
646   pred = AddByteStride(pred, prediction_stride);
647   obmc_pred += obmc_prediction_stride;
648 
649   result = BlendObmcFromTop8<3>(pred, obmc_pred, pred_mask, obmc_pred_mask);
650   vst1q_u16(pred, result);
651   if (height == 4) return;
652 
653   pred = AddByteStride(pred, prediction_stride);
654   obmc_pred += obmc_prediction_stride;
655 
656   result = BlendObmcFromTop8<4>(pred, obmc_pred, pred_mask, obmc_pred_mask);
657   vst1q_u16(pred, result);
658   pred = AddByteStride(pred, prediction_stride);
659   obmc_pred += obmc_prediction_stride;
660 
661   result = BlendObmcFromTop8<5>(pred, obmc_pred, pred_mask, obmc_pred_mask);
662   vst1q_u16(pred, result);
663 
664   if (height == 8) return;
665 
666   pred = AddByteStride(pred, prediction_stride);
667   obmc_pred += obmc_prediction_stride;
668 
669   result = BlendObmcFromTop8<6>(pred, obmc_pred, pred_mask, obmc_pred_mask);
670   vst1q_u16(pred, result);
671   pred = AddByteStride(pred, prediction_stride);
672   obmc_pred += obmc_prediction_stride;
673 
674   result = BlendObmcFromTop8<7>(pred, obmc_pred, pred_mask, obmc_pred_mask);
675   vst1q_u16(pred, result);
676   pred = AddByteStride(pred, prediction_stride);
677   obmc_pred += obmc_prediction_stride;
678 
679   pred_mask = vld1q_u16(&mask[8]);
680   obmc_pred_mask = vsubq_u16(mask_inverter, pred_mask);
681 
682   result = BlendObmcFromTop8<0>(pred, obmc_pred, pred_mask, obmc_pred_mask);
683   vst1q_u16(pred, result);
684   pred = AddByteStride(pred, prediction_stride);
685   obmc_pred += obmc_prediction_stride;
686 
687   result = BlendObmcFromTop8<1>(pred, obmc_pred, pred_mask, obmc_pred_mask);
688   vst1q_u16(pred, result);
689   pred = AddByteStride(pred, prediction_stride);
690   obmc_pred += obmc_prediction_stride;
691 
692   result = BlendObmcFromTop8<2>(pred, obmc_pred, pred_mask, obmc_pred_mask);
693   vst1q_u16(pred, result);
694   pred = AddByteStride(pred, prediction_stride);
695   obmc_pred += obmc_prediction_stride;
696 
697   result = BlendObmcFromTop8<3>(pred, obmc_pred, pred_mask, obmc_pred_mask);
698   vst1q_u16(pred, result);
699 
700   if (height == 16) return;
701 
702   pred = AddByteStride(pred, prediction_stride);
703   obmc_pred += obmc_prediction_stride;
704 
705   result = BlendObmcFromTop8<4>(pred, obmc_pred, pred_mask, obmc_pred_mask);
706   vst1q_u16(pred, result);
707   pred = AddByteStride(pred, prediction_stride);
708   obmc_pred += obmc_prediction_stride;
709 
710   result = BlendObmcFromTop8<5>(pred, obmc_pred, pred_mask, obmc_pred_mask);
711   vst1q_u16(pred, result);
712   pred = AddByteStride(pred, prediction_stride);
713   obmc_pred += obmc_prediction_stride;
714 
715   result = BlendObmcFromTop8<6>(pred, obmc_pred, pred_mask, obmc_pred_mask);
716   vst1q_u16(pred, result);
717   pred = AddByteStride(pred, prediction_stride);
718   obmc_pred += obmc_prediction_stride;
719 
720   result = BlendObmcFromTop8<7>(pred, obmc_pred, pred_mask, obmc_pred_mask);
721   vst1q_u16(pred, result);
722   pred = AddByteStride(pred, prediction_stride);
723   obmc_pred += obmc_prediction_stride;
724 
725   pred_mask = vld1q_u16(&mask[16]);
726   obmc_pred_mask = vsubq_u16(mask_inverter, pred_mask);
727 
728   result = BlendObmcFromTop8<0>(pred, obmc_pred, pred_mask, obmc_pred_mask);
729   vst1q_u16(pred, result);
730   pred = AddByteStride(pred, prediction_stride);
731   obmc_pred += obmc_prediction_stride;
732 
733   result = BlendObmcFromTop8<1>(pred, obmc_pred, pred_mask, obmc_pred_mask);
734   vst1q_u16(pred, result);
735   pred = AddByteStride(pred, prediction_stride);
736   obmc_pred += obmc_prediction_stride;
737 
738   result = BlendObmcFromTop8<2>(pred, obmc_pred, pred_mask, obmc_pred_mask);
739   vst1q_u16(pred, result);
740   pred = AddByteStride(pred, prediction_stride);
741   obmc_pred += obmc_prediction_stride;
742 
743   result = BlendObmcFromTop8<3>(pred, obmc_pred, pred_mask, obmc_pred_mask);
744   vst1q_u16(pred, result);
745   pred = AddByteStride(pred, prediction_stride);
746   obmc_pred += obmc_prediction_stride;
747 
748   result = BlendObmcFromTop8<4>(pred, obmc_pred, pred_mask, obmc_pred_mask);
749   vst1q_u16(pred, result);
750   pred = AddByteStride(pred, prediction_stride);
751   obmc_pred += obmc_prediction_stride;
752 
753   result = BlendObmcFromTop8<5>(pred, obmc_pred, pred_mask, obmc_pred_mask);
754   vst1q_u16(pred, result);
755   pred = AddByteStride(pred, prediction_stride);
756   obmc_pred += obmc_prediction_stride;
757 
758   result = BlendObmcFromTop8<6>(pred, obmc_pred, pred_mask, obmc_pred_mask);
759   vst1q_u16(pred, result);
760   pred = AddByteStride(pred, prediction_stride);
761   obmc_pred += obmc_prediction_stride;
762 
763   result = BlendObmcFromTop8<7>(pred, obmc_pred, pred_mask, obmc_pred_mask);
764   vst1q_u16(pred, result);
765 }
766 
OverlapBlendFromTop_NEON(void * LIBGAV1_RESTRICT const prediction,const ptrdiff_t prediction_stride,const int width,const int height,const void * LIBGAV1_RESTRICT const obmc_prediction,const ptrdiff_t obmc_prediction_stride)767 void OverlapBlendFromTop_NEON(
768     void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t prediction_stride,
769     const int width, const int height,
770     const void* LIBGAV1_RESTRICT const obmc_prediction,
771     const ptrdiff_t obmc_prediction_stride) {
772   auto* pred = static_cast<uint16_t*>(prediction);
773   const auto* obmc_pred = static_cast<const uint16_t*>(obmc_prediction);
774   assert(width >= 4);
775   assert(height >= 2);
776 
777   if (width == 4) {
778     OverlapBlendFromTop4xH_NEON(pred, prediction_stride, height, obmc_pred);
779     return;
780   }
781 
782   if (width == 8) {
783     OverlapBlendFromTop8xH_NEON(pred, prediction_stride, obmc_pred, height);
784     return;
785   }
786 
787   const uint16_t* mask = kObmcMask + height - 2;
788   const uint16x8_t mask_inverter = vdupq_n_u16(64);
789   const uint16x8_t pred_mask = vld1q_u16(mask);
790   // 64 - mask
791   const uint16x8_t obmc_pred_mask = vsubq_u16(mask_inverter, pred_mask);
792 #define OBMC_ROW_FROM_TOP(n)                                   \
793   do {                                                         \
794     int x = 0;                                                 \
795     do {                                                       \
796       const uint16x8_t result = BlendObmcFromTop8<n>(          \
797           pred + x, obmc_pred + x, pred_mask, obmc_pred_mask); \
798       vst1q_u16(pred + x, result);                             \
799                                                                \
800       x += 8;                                                  \
801     } while (x < width);                                       \
802   } while (false)
803 
804   // Compute 1 row.
805   if (height == 2) {
806     OBMC_ROW_FROM_TOP(0);
807     return;
808   }
809 
810   // Compute 3 rows.
811   if (height == 4) {
812     OBMC_ROW_FROM_TOP(0);
813     pred = AddByteStride(pred, prediction_stride);
814     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
815     OBMC_ROW_FROM_TOP(1);
816     pred = AddByteStride(pred, prediction_stride);
817     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
818     OBMC_ROW_FROM_TOP(2);
819     return;
820   }
821 
822   // Compute 6 rows.
823   if (height == 8) {
824     OBMC_ROW_FROM_TOP(0);
825     pred = AddByteStride(pred, prediction_stride);
826     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
827     OBMC_ROW_FROM_TOP(1);
828     pred = AddByteStride(pred, prediction_stride);
829     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
830     OBMC_ROW_FROM_TOP(2);
831     pred = AddByteStride(pred, prediction_stride);
832     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
833     OBMC_ROW_FROM_TOP(3);
834     pred = AddByteStride(pred, prediction_stride);
835     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
836     OBMC_ROW_FROM_TOP(4);
837     pred = AddByteStride(pred, prediction_stride);
838     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
839     OBMC_ROW_FROM_TOP(5);
840     return;
841   }
842 
843   // Compute 12 rows.
844   if (height == 16) {
845     OBMC_ROW_FROM_TOP(0);
846     pred = AddByteStride(pred, prediction_stride);
847     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
848     OBMC_ROW_FROM_TOP(1);
849     pred = AddByteStride(pred, prediction_stride);
850     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
851     OBMC_ROW_FROM_TOP(2);
852     pred = AddByteStride(pred, prediction_stride);
853     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
854     OBMC_ROW_FROM_TOP(3);
855     pred = AddByteStride(pred, prediction_stride);
856     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
857     OBMC_ROW_FROM_TOP(4);
858     pred = AddByteStride(pred, prediction_stride);
859     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
860     OBMC_ROW_FROM_TOP(5);
861     pred = AddByteStride(pred, prediction_stride);
862     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
863     OBMC_ROW_FROM_TOP(6);
864     pred = AddByteStride(pred, prediction_stride);
865     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
866     OBMC_ROW_FROM_TOP(7);
867     pred = AddByteStride(pred, prediction_stride);
868     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
869 
870     const uint16x8_t pred_mask = vld1q_u16(&mask[8]);
871     // 64 - mask
872     const uint16x8_t obmc_pred_mask = vsubq_u16(mask_inverter, pred_mask);
873     OBMC_ROW_FROM_TOP(0);
874     pred = AddByteStride(pred, prediction_stride);
875     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
876     OBMC_ROW_FROM_TOP(1);
877     pred = AddByteStride(pred, prediction_stride);
878     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
879     OBMC_ROW_FROM_TOP(2);
880     pred = AddByteStride(pred, prediction_stride);
881     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
882     OBMC_ROW_FROM_TOP(3);
883     return;
884   }
885 
886   // Stop when mask value becomes 64. This is a multiple of 8 for height 32
887   // and 64.
888   const int compute_height = height - (height >> 2);
889   int y = 0;
890   do {
891     const uint16x8_t pred_mask = vld1q_u16(&mask[y]);
892     // 64 - mask
893     const uint16x8_t obmc_pred_mask = vsubq_u16(mask_inverter, pred_mask);
894     OBMC_ROW_FROM_TOP(0);
895     pred = AddByteStride(pred, prediction_stride);
896     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
897     OBMC_ROW_FROM_TOP(1);
898     pred = AddByteStride(pred, prediction_stride);
899     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
900     OBMC_ROW_FROM_TOP(2);
901     pred = AddByteStride(pred, prediction_stride);
902     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
903     OBMC_ROW_FROM_TOP(3);
904     pred = AddByteStride(pred, prediction_stride);
905     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
906     OBMC_ROW_FROM_TOP(4);
907     pred = AddByteStride(pred, prediction_stride);
908     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
909     OBMC_ROW_FROM_TOP(5);
910     pred = AddByteStride(pred, prediction_stride);
911     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
912     OBMC_ROW_FROM_TOP(6);
913     pred = AddByteStride(pred, prediction_stride);
914     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
915     OBMC_ROW_FROM_TOP(7);
916     pred = AddByteStride(pred, prediction_stride);
917     obmc_pred = AddByteStride(obmc_pred, obmc_prediction_stride);
918 
919     y += 8;
920   } while (y < compute_height);
921 }
922 
Init10bpp()923 void Init10bpp() {
924   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
925   assert(dsp != nullptr);
926   dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendFromTop_NEON;
927   dsp->obmc_blend[kObmcDirectionHorizontal] = OverlapBlendFromLeft_NEON;
928 }
929 
930 }  // namespace
931 }  // namespace high_bitdepth
932 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
933 
ObmcInit_NEON()934 void ObmcInit_NEON() {
935   low_bitdepth::Init8bpp();
936 #if LIBGAV1_MAX_BITDEPTH >= 10
937   high_bitdepth::Init10bpp();
938 #endif
939 }
940 
941 }  // namespace dsp
942 }  // namespace libgav1
943 
944 #else   // !LIBGAV1_ENABLE_NEON
945 
946 namespace libgav1 {
947 namespace dsp {
948 
// Built without LIBGAV1_ENABLE_NEON: the symbol is still defined, but it does
// nothing.
void ObmcInit_NEON() {
}
950 
951 }  // namespace dsp
952 }  // namespace libgav1
953 #endif  // LIBGAV1_ENABLE_NEON
954