/*
 * Copyright (c) 2019, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"

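// Returns the sum of squared differences between `coeff` and `dqcoeff` and
// writes the sum of squared `coeff` values to `ssz`. `block_size` must be a
// positive multiple of 16.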
int64_t av1_block_error_neon(const tran_low_t *coeff, const tran_low_t *dqcoeff,
                             intptr_t block_size, int64_t *ssz) {
  uint64x2_t err_u64 = vdupq_n_u64(0);
  int64x2_t ssz_s64 = vdupq_n_s64(0);

  assert(block_size >= 16);
  assert((block_size % 16) == 0);

  do {
    const int16x8_t c0 = load_tran_low_to_s16q(coeff);
    const int16x8_t c1 = load_tran_low_to_s16q(coeff + 8);
    const int16x8_t d0 = load_tran_low_to_s16q(dqcoeff);
    const int16x8_t d1 = load_tran_low_to_s16q(dqcoeff + 8);

    const uint16x8_t diff0 = vreinterpretq_u16_s16(vabdq_s16(c0, d0));
    const uint16x8_t diff1 = vreinterpretq_u16_s16(vabdq_s16(c1, d1));

    // By operating on unsigned integers we can store up to 4 squared diffs in
    // a 32-bit element before having to widen to 64 bits.
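    // (The vpadalq_u32 below folds the 32-bit lanes into the 64-bit
    // accumulator on every iteration, so at most the four squares from this
    // group of 16 coefficients ever share a 32-bit lane; with |diff| below
    // 2^15, 4 * (2^15 - 1)^2 < 2^32 holds.)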
    uint32x4_t err = vmull_u16(vget_low_u16(diff0), vget_low_u16(diff0));
    err = vmlal_u16(err, vget_high_u16(diff0), vget_high_u16(diff0));
    err = vmlal_u16(err, vget_low_u16(diff1), vget_low_u16(diff1));
    err = vmlal_u16(err, vget_high_u16(diff1), vget_high_u16(diff1));
    err_u64 = vpadalq_u32(err_u64, err);

    // We can't do the same here as we're operating on signed integers, so we
    // can only accumulate 2 squares.
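    // (An int16 square is at most 2^30, so a third square could push a signed
    // 32-bit lane past INT32_MAX; vpadalq_s32 therefore widens into the
    // 64-bit accumulator after every two.)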
    int32x4_t ssz0 = vmull_s16(vget_low_s16(c0), vget_low_s16(c0));
    ssz0 = vmlal_s16(ssz0, vget_high_s16(c0), vget_high_s16(c0));
    ssz_s64 = vpadalq_s32(ssz_s64, ssz0);

    int32x4_t ssz1 = vmull_s16(vget_low_s16(c1), vget_low_s16(c1));
    ssz1 = vmlal_s16(ssz1, vget_high_s16(c1), vget_high_s16(c1));
    ssz_s64 = vpadalq_s32(ssz_s64, ssz1);

    coeff += 16;
    dqcoeff += 16;
    block_size -= 16;
  } while (block_size != 0);

  *ssz = horizontal_add_s64x2(ssz_s64);
  return (int64_t)horizontal_add_u64x2(err_u64);
}

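// Low-precision variant: same squared-error computation as above, but on
// plain int16_t coefficients and without the `ssz` output.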
int64_t av1_block_error_lp_neon(const int16_t *coeff, const int16_t *dqcoeff,
                                intptr_t block_size) {
  uint64x2_t err_u64 = vdupq_n_u64(0);

  assert(block_size >= 16);
  assert((block_size % 16) == 0);

  do {
    const int16x8_t c0 = vld1q_s16(coeff);
    const int16x8_t c1 = vld1q_s16(coeff + 8);
    const int16x8_t d0 = vld1q_s16(dqcoeff);
    const int16x8_t d1 = vld1q_s16(dqcoeff + 8);

    const uint16x8_t diff0 = vreinterpretq_u16_s16(vabdq_s16(c0, d0));
    const uint16x8_t diff1 = vreinterpretq_u16_s16(vabdq_s16(c1, d1));

    // By operating on unsigned integers we can store up to 4 squared diffs in
    // a 32-bit element before having to widen to 64 bits.
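    // (As in av1_block_error_neon above, vpadalq_u32 folds the four squares
    // per 32-bit lane into 64 bits on every iteration.)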
    uint32x4_t err = vmull_u16(vget_low_u16(diff0), vget_low_u16(diff0));
    err = vmlal_u16(err, vget_high_u16(diff0), vget_high_u16(diff0));
    err = vmlal_u16(err, vget_low_u16(diff1), vget_low_u16(diff1));
    err = vmlal_u16(err, vget_high_u16(diff1), vget_high_u16(diff1));
    err_u64 = vpadalq_u32(err_u64, err);

    coeff += 16;
    dqcoeff += 16;
    block_size -= 16;
  } while (block_size != 0);

  return (int64_t)horizontal_add_u64x2(err_u64);
}
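
/* Usage sketch (illustrative only, not part of the upstream file): exercising
 * both entry points on a hypothetical 16x16 block. In production these
 * functions are reached through the av1_rtcd.h dispatch (av1_block_error /
 * av1_block_error_lp) rather than called directly; DECLARE_ALIGNED is
 * libaom's alignment macro from aom_ports/mem.h.
 *
 *   DECLARE_ALIGNED(16, tran_low_t, coeff[16 * 16]);
 *   DECLARE_ALIGNED(16, tran_low_t, dqcoeff[16 * 16]);
 *   int64_t ssz;
 *   // ... fill coeff/dqcoeff from the transform and dequantizer ...
 *   int64_t err = av1_block_error_neon(coeff, dqcoeff, 16 * 16, &ssz);
 *
 *   DECLARE_ALIGNED(16, int16_t, coeff_lp[16 * 16]);
 *   DECLARE_ALIGNED(16, int16_t, dqcoeff_lp[16 * 16]);
 *   int64_t err_lp = av1_block_error_lp_neon(coeff_lp, dqcoeff_lp, 16 * 16);
 */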