/*
 * Copyright (c) 2023 The WebM project authors. All rights reserved.
 * Copyright (c) 2022, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

15 #include "config/aom_config.h"
16 #include "config/aom_dsp_rtcd.h"
17
18 #include "aom_dsp/aom_filter.h"
19 #include "aom_dsp/arm/mem_neon.h"
20 #include "aom_dsp/arm/sum_neon.h"
21 #include "aom_dsp/variance.h"
22
23 // Process a block of width 4 two rows at a time.
static inline void highbd_variance_4xh_neon(const uint16_t *src_ptr,
                                            int src_stride,
                                            const uint16_t *ref_ptr,
                                            int ref_stride, int h,
                                            uint64_t *sse, int64_t *sum) {
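  // A 16-bit sum accumulator is sufficient here: the only width-4 blocks
  // instantiated below are 4x4, 4x8 and 4x16, so each lane accumulates at
  // most h/2 = 8 differences of magnitude <= 4095 (12-bit), i.e.
  // |sum| <= 32760 < INT16_MAX.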
  int16x8_t sum_s16 = vdupq_n_s16(0);
  int32x4_t sse_s32 = vdupq_n_s32(0);

  int i = h;
  do {
    const uint16x8_t s = load_unaligned_u16_4x2(src_ptr, src_stride);
    const uint16x8_t r = load_unaligned_u16_4x2(ref_ptr, ref_stride);

    int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(s, r));
    sum_s16 = vaddq_s16(sum_s16, diff);

    sse_s32 = vmlal_s16(sse_s32, vget_low_s16(diff), vget_low_s16(diff));
    sse_s32 = vmlal_s16(sse_s32, vget_high_s16(diff), vget_high_s16(diff));

    src_ptr += 2 * src_stride;
    ref_ptr += 2 * ref_stride;
    i -= 2;
  } while (i != 0);

  *sum = horizontal_add_s16x8(sum_s16);
  *sse = horizontal_add_s32x4(sse_s32);
}

// For 8-bit and 10-bit data, since we're using two int32x4 accumulators, all
// block sizes can be processed in 32-bit elements (1023*1023*128*32 =
// 4286582784 for a 128x128 block).
static inline void highbd_variance_large_neon(const uint16_t *src_ptr,
                                              int src_stride,
                                              const uint16_t *ref_ptr,
                                              int ref_stride, int w, int h,
                                              uint64_t *sse, int64_t *sum) {
  int32x4_t sum_s32 = vdupq_n_s32(0);
  int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };

  int i = h;
  do {
    int j = 0;
    do {
      const uint16x8_t s = vld1q_u16(src_ptr + j);
      const uint16x8_t r = vld1q_u16(ref_ptr + j);

      const int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(s, r));
      sum_s32 = vpadalq_s16(sum_s32, diff);

      sse_s32[0] =
          vmlal_s16(sse_s32[0], vget_low_s16(diff), vget_low_s16(diff));
      sse_s32[1] =
          vmlal_s16(sse_s32[1], vget_high_s16(diff), vget_high_s16(diff));

      j += 8;
    } while (j < w);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--i != 0);

  *sum = horizontal_add_s32x4(sum_s32);
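  // Squared differences are non-negative, so reinterpret the accumulators as
  // uint32 lanes; their pairwise sum may exceed INT32_MAX but not UINT32_MAX
  // (see the comment above this function).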
  *sse = horizontal_long_add_u32x4(vaddq_u32(
      vreinterpretq_u32_s32(sse_s32[0]), vreinterpretq_u32_s32(sse_s32[1])));
}

static inline void highbd_variance_8xh_neon(const uint16_t *src,
                                            int src_stride,
                                            const uint16_t *ref,
                                            int ref_stride, int h,
                                            uint64_t *sse, int64_t *sum) {
  highbd_variance_large_neon(src, src_stride, ref, ref_stride, 8, h, sse, sum);
}

static inline void highbd_variance_16xh_neon(const uint16_t *src,
                                             int src_stride,
                                             const uint16_t *ref,
                                             int ref_stride, int h,
                                             uint64_t *sse, int64_t *sum) {
  highbd_variance_large_neon(src, src_stride, ref, ref_stride, 16, h, sse, sum);
}

static inline void highbd_variance_32xh_neon(const uint16_t *src,
                                             int src_stride,
                                             const uint16_t *ref,
                                             int ref_stride, int h,
                                             uint64_t *sse, int64_t *sum) {
  highbd_variance_large_neon(src, src_stride, ref, ref_stride, 32, h, sse, sum);
}

static inline void highbd_variance_64xh_neon(const uint16_t *src,
                                             int src_stride,
                                             const uint16_t *ref,
                                             int ref_stride, int h,
                                             uint64_t *sse, int64_t *sum) {
  highbd_variance_large_neon(src, src_stride, ref, ref_stride, 64, h, sse, sum);
}

static inline void highbd_variance_128xh_neon(const uint16_t *src,
                                              int src_stride,
                                              const uint16_t *ref,
                                              int ref_stride, int h,
                                              uint64_t *sse, int64_t *sum) {
  highbd_variance_large_neon(src, src_stride, ref, ref_stride, 128, h, sse,
                             sum);
}

// For 12-bit data, we can only accumulate up to 128 elements in the sum of
// squares (4095*4095*128 = 2146435200), and because we're using two int32x4
// accumulators, we can only process up to 32 32-element rows (32*32/8 = 128),
// 16 64-element rows or 8 128-element rows before we have to accumulate into
// 64-bit elements. Therefore blocks of size 32x64, 64x32, 64x64, 64x128,
// 128x64 and 128x128 are processed in a different helper function.

// Process a block of any size where the width is divisible by 8, with
// accumulation into 64-bit elements.
static inline void highbd_variance_xlarge_neon(
    const uint16_t *src_ptr, int src_stride, const uint16_t *ref_ptr,
    int ref_stride, int w, int h, int h_limit, uint64_t *sse, int64_t *sum) {
  int32x4_t sum_s32 = vdupq_n_s32(0);
  int64x2_t sse_s64 = vdupq_n_s64(0);

  // 'h_limit' is the number of 'w'-width rows we can process before our 32-bit
  // accumulator overflows. After hitting this limit we accumulate into 64-bit
  // elements.
  int h_tmp = h > h_limit ? h_limit : h;

  int i = 0;
  do {
    int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
    do {
      int j = 0;
      do {
        const uint16x8_t s0 = vld1q_u16(src_ptr + j);
        const uint16x8_t r0 = vld1q_u16(ref_ptr + j);

        const int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(s0, r0));
        sum_s32 = vpadalq_s16(sum_s32, diff);

        sse_s32[0] =
            vmlal_s16(sse_s32[0], vget_low_s16(diff), vget_low_s16(diff));
        sse_s32[1] =
            vmlal_s16(sse_s32[1], vget_high_s16(diff), vget_high_s16(diff));

        j += 8;
      } while (j < w);

      src_ptr += src_stride;
      ref_ptr += ref_stride;
      i++;
    } while (i < h_tmp);

    sse_s64 = vpadalq_s32(sse_s64, sse_s32[0]);
    sse_s64 = vpadalq_s32(sse_s64, sse_s32[1]);
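    // Allow another h_limit rows before the next spill into 64 bits.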
    h_tmp += h_limit;
  } while (i < h);

  *sum = horizontal_add_s32x4(sum_s32);
  *sse = (uint64_t)horizontal_add_s64x2(sse_s64);
}

static inline void highbd_variance_32xh_xlarge_neon(
    const uint16_t *src, int src_stride, const uint16_t *ref, int ref_stride,
    int h, uint64_t *sse, int64_t *sum) {
  highbd_variance_xlarge_neon(src, src_stride, ref, ref_stride, 32, h, 32, sse,
                              sum);
}

static inline void highbd_variance_64xh_xlarge_neon(
    const uint16_t *src, int src_stride, const uint16_t *ref, int ref_stride,
    int h, uint64_t *sse, int64_t *sum) {
  highbd_variance_xlarge_neon(src, src_stride, ref, ref_stride, 64, h, 16, sse,
                              sum);
}

static inline void highbd_variance_128xh_xlarge_neon(
    const uint16_t *src, int src_stride, const uint16_t *ref, int ref_stride,
    int h, uint64_t *sse, int64_t *sum) {
  highbd_variance_xlarge_neon(src, src_stride, ref, ref_stride, 128, h, 8, sse,
                              sum);
}

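// Variance is computed as SSE - sum^2 / (w * h). For 8-bit input the raw sum
// and SSE are used directly; for 10-bit and 12-bit input both are first
// rounded down to the 8-bit scale (sum by 2 or 4 bits, SSE by 4 or 8 bits)
// and the result is clamped at zero.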
#define HBD_VARIANCE_WXH_8_NEON(w, h)                                 \
  uint32_t aom_highbd_8_variance##w##x##h##_neon(                     \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    int sum;                                                          \
    uint64_t sse_long = 0;                                            \
    int64_t sum_long = 0;                                             \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_variance_##w##xh_neon(src, src_stride, ref, ref_stride, h, \
                                 &sse_long, &sum_long);               \
    *sse = (uint32_t)sse_long;                                        \
    sum = (int)sum_long;                                              \
    return *sse - (uint32_t)(((int64_t)sum * sum) / (w * h));         \
  }

#define HBD_VARIANCE_WXH_10_NEON(w, h)                                \
  uint32_t aom_highbd_10_variance##w##x##h##_neon(                    \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    int sum;                                                          \
    int64_t var;                                                      \
    uint64_t sse_long = 0;                                            \
    int64_t sum_long = 0;                                             \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_variance_##w##xh_neon(src, src_stride, ref, ref_stride, h, \
                                 &sse_long, &sum_long);               \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 4);                 \
    sum = (int)ROUND_POWER_OF_TWO(sum_long, 2);                       \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (w * h));         \
    return (var >= 0) ? (uint32_t)var : 0;                            \
  }

#define HBD_VARIANCE_WXH_12_NEON(w, h)                                \
  uint32_t aom_highbd_12_variance##w##x##h##_neon(                    \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    int sum;                                                          \
    int64_t var;                                                      \
    uint64_t sse_long = 0;                                            \
    int64_t sum_long = 0;                                             \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_variance_##w##xh_neon(src, src_stride, ref, ref_stride, h, \
                                 &sse_long, &sum_long);               \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 8);                 \
    sum = (int)ROUND_POWER_OF_TWO(sum_long, 4);                       \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (w * h));         \
    return (var >= 0) ? (uint32_t)var : 0;                            \
  }

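// Identical to HBD_VARIANCE_WXH_12_NEON except that it calls the helpers that
// spill into 64-bit accumulators, for the oversized 12-bit blocks (32x64 and
// larger) noted in the overflow comment above.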
#define HBD_VARIANCE_WXH_12_XLARGE_NEON(w, h)                                 \
  uint32_t aom_highbd_12_variance##w##x##h##_neon(                            \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,         \
      int ref_stride, uint32_t *sse) {                                        \
    int sum;                                                                  \
    int64_t var;                                                              \
    uint64_t sse_long = 0;                                                    \
    int64_t sum_long = 0;                                                     \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                             \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                             \
    highbd_variance_##w##xh_xlarge_neon(src, src_stride, ref, ref_stride, h,  \
                                        &sse_long, &sum_long);                \
    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 8);                         \
    sum = (int)ROUND_POWER_OF_TWO(sum_long, 4);                               \
    var = (int64_t)(*sse) - (((int64_t)sum * sum) / (w * h));                 \
    return (var >= 0) ? (uint32_t)var : 0;                                    \
  }

// 8-bit
HBD_VARIANCE_WXH_8_NEON(4, 4)
HBD_VARIANCE_WXH_8_NEON(4, 8)

HBD_VARIANCE_WXH_8_NEON(8, 4)
HBD_VARIANCE_WXH_8_NEON(8, 8)
HBD_VARIANCE_WXH_8_NEON(8, 16)

HBD_VARIANCE_WXH_8_NEON(16, 8)
HBD_VARIANCE_WXH_8_NEON(16, 16)
HBD_VARIANCE_WXH_8_NEON(16, 32)

HBD_VARIANCE_WXH_8_NEON(32, 16)
HBD_VARIANCE_WXH_8_NEON(32, 32)
HBD_VARIANCE_WXH_8_NEON(32, 64)

HBD_VARIANCE_WXH_8_NEON(64, 32)
HBD_VARIANCE_WXH_8_NEON(64, 64)
HBD_VARIANCE_WXH_8_NEON(64, 128)

HBD_VARIANCE_WXH_8_NEON(128, 64)
HBD_VARIANCE_WXH_8_NEON(128, 128)

// 10-bit
HBD_VARIANCE_WXH_10_NEON(4, 4)
HBD_VARIANCE_WXH_10_NEON(4, 8)

HBD_VARIANCE_WXH_10_NEON(8, 4)
HBD_VARIANCE_WXH_10_NEON(8, 8)
HBD_VARIANCE_WXH_10_NEON(8, 16)

HBD_VARIANCE_WXH_10_NEON(16, 8)
HBD_VARIANCE_WXH_10_NEON(16, 16)
HBD_VARIANCE_WXH_10_NEON(16, 32)

HBD_VARIANCE_WXH_10_NEON(32, 16)
HBD_VARIANCE_WXH_10_NEON(32, 32)
HBD_VARIANCE_WXH_10_NEON(32, 64)

HBD_VARIANCE_WXH_10_NEON(64, 32)
HBD_VARIANCE_WXH_10_NEON(64, 64)
HBD_VARIANCE_WXH_10_NEON(64, 128)

HBD_VARIANCE_WXH_10_NEON(128, 64)
HBD_VARIANCE_WXH_10_NEON(128, 128)

// 12-bit
HBD_VARIANCE_WXH_12_NEON(4, 4)
HBD_VARIANCE_WXH_12_NEON(4, 8)

HBD_VARIANCE_WXH_12_NEON(8, 4)
HBD_VARIANCE_WXH_12_NEON(8, 8)
HBD_VARIANCE_WXH_12_NEON(8, 16)

HBD_VARIANCE_WXH_12_NEON(16, 8)
HBD_VARIANCE_WXH_12_NEON(16, 16)
HBD_VARIANCE_WXH_12_NEON(16, 32)

HBD_VARIANCE_WXH_12_NEON(32, 16)
HBD_VARIANCE_WXH_12_NEON(32, 32)
HBD_VARIANCE_WXH_12_XLARGE_NEON(32, 64)

HBD_VARIANCE_WXH_12_XLARGE_NEON(64, 32)
HBD_VARIANCE_WXH_12_XLARGE_NEON(64, 64)
HBD_VARIANCE_WXH_12_XLARGE_NEON(64, 128)

HBD_VARIANCE_WXH_12_XLARGE_NEON(128, 64)
HBD_VARIANCE_WXH_12_XLARGE_NEON(128, 128)

#if !CONFIG_REALTIME_ONLY
// 8-bit
HBD_VARIANCE_WXH_8_NEON(4, 16)

HBD_VARIANCE_WXH_8_NEON(8, 32)

HBD_VARIANCE_WXH_8_NEON(16, 4)
HBD_VARIANCE_WXH_8_NEON(16, 64)

HBD_VARIANCE_WXH_8_NEON(32, 8)

HBD_VARIANCE_WXH_8_NEON(64, 16)

// 10-bit
HBD_VARIANCE_WXH_10_NEON(4, 16)

HBD_VARIANCE_WXH_10_NEON(8, 32)

HBD_VARIANCE_WXH_10_NEON(16, 4)
HBD_VARIANCE_WXH_10_NEON(16, 64)

HBD_VARIANCE_WXH_10_NEON(32, 8)

HBD_VARIANCE_WXH_10_NEON(64, 16)

// 12-bit
HBD_VARIANCE_WXH_12_NEON(4, 16)

HBD_VARIANCE_WXH_12_NEON(8, 32)

HBD_VARIANCE_WXH_12_NEON(16, 4)
HBD_VARIANCE_WXH_12_NEON(16, 64)

HBD_VARIANCE_WXH_12_NEON(32, 8)

HBD_VARIANCE_WXH_12_NEON(64, 16)

#endif  // !CONFIG_REALTIME_ONLY

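// Compute the sum of squared errors for a block whose width is divisible by
// 8, returning it via both *sse and the return value. Two uint32 accumulators
// are sufficient for the 8x8..16x16 sizes instantiated below even at 12 bits:
// 4095*4095*(16*16/4) = 1073217600 < UINT32_MAX.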
static inline uint32_t highbd_mse_wxh_neon(const uint16_t *src_ptr,
                                           int src_stride,
                                           const uint16_t *ref_ptr,
                                           int ref_stride, int w, int h,
                                           unsigned int *sse) {
  uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  int i = h;
  do {
    int j = 0;
    do {
      uint16x8_t s = vld1q_u16(src_ptr + j);
      uint16x8_t r = vld1q_u16(ref_ptr + j);

      uint16x8_t diff = vabdq_u16(s, r);

      sse_u32[0] =
          vmlal_u16(sse_u32[0], vget_low_u16(diff), vget_low_u16(diff));
      sse_u32[1] =
          vmlal_u16(sse_u32[1], vget_high_u16(diff), vget_high_u16(diff));

      j += 8;
    } while (j < w);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--i != 0);

  *sse = horizontal_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1]));
  return *sse;
}

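// MSE kernels for 8-, 10- and 12-bit input. As with variance, the 10- and
// 12-bit sums of squared errors are rounded down to the 8-bit scale (by 4
// and 8 bits respectively).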
#define HIGHBD_MSE_WXH_NEON(w, h)                                     \
  uint32_t aom_highbd_8_mse##w##x##h##_neon(                          \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_mse_wxh_neon(src, src_stride, ref, ref_stride, w, h, sse); \
    return *sse;                                                      \
  }                                                                   \
                                                                      \
  uint32_t aom_highbd_10_mse##w##x##h##_neon(                         \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_mse_wxh_neon(src, src_stride, ref, ref_stride, w, h, sse); \
    *sse = ROUND_POWER_OF_TWO(*sse, 4);                               \
    return *sse;                                                      \
  }                                                                   \
                                                                      \
  uint32_t aom_highbd_12_mse##w##x##h##_neon(                         \
      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
      int ref_stride, uint32_t *sse) {                                \
    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
    highbd_mse_wxh_neon(src, src_stride, ref, ref_stride, w, h, sse); \
    *sse = ROUND_POWER_OF_TWO(*sse, 8);                               \
    return *sse;                                                      \
  }

HIGHBD_MSE_WXH_NEON(16, 16)
HIGHBD_MSE_WXH_NEON(16, 8)
HIGHBD_MSE_WXH_NEON(8, 16)
HIGHBD_MSE_WXH_NEON(8, 8)

#undef HIGHBD_MSE_WXH_NEON

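// Accumulate the squared absolute differences of two 8-element rows into the
// 64-bit lanes of 'sum'.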
static inline uint64x2_t mse_accumulate_u16_8x2(uint64x2_t sum, uint16x8_t s0,
                                                uint16x8_t s1, uint16x8_t d0,
                                                uint16x8_t d1) {
  uint16x8_t e0 = vabdq_u16(s0, d0);
  uint16x8_t e1 = vabdq_u16(s1, d1);

  uint32x4_t mse = vmull_u16(vget_low_u16(e0), vget_low_u16(e0));
  mse = vmlal_u16(mse, vget_high_u16(e0), vget_high_u16(e0));
  mse = vmlal_u16(mse, vget_low_u16(e1), vget_low_u16(e1));
  mse = vmlal_u16(mse, vget_high_u16(e1), vget_high_u16(e1));

  return vpadalq_u32(sum, mse);
}

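// Sum of squared errors over a w x h block of 16-bit samples, accumulated in
// 64 bits. Only 4x4, 4x8, 8x4 and 8x8 blocks are supported (see the assert).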
uint64_t aom_mse_wxh_16bit_highbd_neon(uint16_t *dst, int dstride,
                                       uint16_t *src, int sstride, int w,
                                       int h) {
  assert((w == 8 || w == 4) && (h == 8 || h == 4));

  uint64x2_t sum = vdupq_n_u64(0);

  if (w == 8) {
    do {
      uint16x8_t d0 = vld1q_u16(dst + 0 * dstride);
      uint16x8_t d1 = vld1q_u16(dst + 1 * dstride);
      uint16x8_t s0 = vld1q_u16(src + 0 * sstride);
      uint16x8_t s1 = vld1q_u16(src + 1 * sstride);

      sum = mse_accumulate_u16_8x2(sum, s0, s1, d0, d1);

      dst += 2 * dstride;
      src += 2 * sstride;
      h -= 2;
    } while (h != 0);
  } else {  // w == 4
    do {
      uint16x8_t d0 = load_unaligned_u16_4x2(dst + 0 * dstride, dstride);
      uint16x8_t d1 = load_unaligned_u16_4x2(dst + 2 * dstride, dstride);
      uint16x8_t s0 = load_unaligned_u16_4x2(src + 0 * sstride, sstride);
      uint16x8_t s1 = load_unaligned_u16_4x2(src + 2 * sstride, sstride);

      sum = mse_accumulate_u16_8x2(sum, s0, s1, d0, d1);

      dst += 4 * dstride;
      src += 4 * sstride;
      h -= 4;
    } while (h != 0);
  }

  return horizontal_add_u64x2(sum);
}