1 /*
2 * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include <arm_neon.h>
13 #include <arm_sve.h>
14
15 #include <assert.h>
16 #include <stdint.h>
17
18 #include "aom_dsp/arm/aom_neon_sve_bridge.h"
19 #include "aom_dsp/arm/mem_neon.h"
20 #include "aom_dsp/arm/sum_neon.h"
21 #include "aom_dsp/arm/transpose_neon.h"
22 #include "av1/encoder/arm/pickrst_neon.h"
23 #include "av1/encoder/arm/pickrst_sve.h"
24 #include "av1/encoder/pickrst.h"
25
// Compute the mean pixel value of a width x height block of 16-bit samples.
// Each row is summed 8 lanes at a time with a widening dot-product against a
// vector of ones; the final (possibly partial) group of columns is loaded
// with an SVE predicate so out-of-range lanes contribute zero.
static inline uint16_t highbd_find_average_sve(const uint16_t *src,
                                               int src_stride, int width,
                                               int height) {
  uint64x2_t sum_u64 = vdupq_n_u64(0);
  const uint16x8_t one_u16 = vdupq_n_u16(1);

  // Predicate covering the last 1-8 columns of each row. When width is a
  // multiple of 8 the tail is a full vector of 8 lanes.
  const unsigned int tail = (width % 8 == 0) ? 8 : (unsigned int)(width % 8);
  const svbool_t tail_mask = svwhilelt_b16_u32(0, tail);

  for (int i = height; i != 0; i--) {
    const uint16_t *row = src;
    int remaining = width;
    // Full 8-lane groups, leaving the last group for the predicated load.
    while (remaining > 8) {
      sum_u64 = aom_udotq_u16(sum_u64, vld1q_u16(row), one_u16);
      row += 8;
      remaining -= 8;
    }
    // Tail group: inactive lanes read as zero and do not affect the sum.
    const uint16x8_t last = svget_neonq_u16(svld1_u16(tail_mask, row));
    sum_u64 = aom_udotq_u16(sum_u64, last, one_u16);

    src += src_stride;
  }
  return (uint16_t)(vaddvq_u64(sum_u64) / (width * height));
}
53
// Subtract `avg` from every sample of a width x height block, writing the
// signed residuals to buf_avg. The tail columns of each row use a predicated
// load and a zero-masked average vector, so the lanes beyond `width` are
// written as (0 - 0) = 0 padding.
static inline void sub_avg_block_highbd_sve(const uint16_t *buf, int buf_stride,
                                            int16_t avg, int width, int height,
                                            int16_t *buf_avg,
                                            int buf_avg_stride) {
  const uint16x8_t avg_u16 = vdupq_n_u16(avg);

  // Predicate covering the last 1-8 columns of each row.
  const unsigned int tail = (width % 8 == 0) ? 8 : (unsigned int)(width % 8);
  const svbool_t tail_mask = svwhilelt_b16_u32(0, tail);

  // Average vector with inactive tail lanes zeroed, matching the zeroed
  // lanes of the predicated load below.
  const uint16x8_t avg_tail = svget_neonq_u16(svdup_n_u16_z(tail_mask, avg));

  for (int i = height; i > 0; i--) {
    const uint16_t *src_ptr = buf;
    int16_t *dst_ptr = buf_avg;
    int remaining = width;
    // Full 8-lane groups, leaving the last group for the predicated path.
    while (remaining > 8) {
      const uint16x8_t s = vld1q_u16(src_ptr);
      vst1q_s16(dst_ptr, vreinterpretq_s16_u16(vsubq_u16(s, avg_u16)));
      src_ptr += 8;
      dst_ptr += 8;
      remaining -= 8;
    }
    const uint16x8_t s_tail = svget_neonq_u16(svld1_u16(tail_mask, src_ptr));
    vst1q_s16(dst_ptr, vreinterpretq_s16_u16(vsubq_u16(s_tail, avg_tail)));

    buf += buf_stride;
    buf_avg += buf_avg_stride;
  }
}
84
// Compute the Wiener filter statistics M (cross-correlation vector) and H
// (auto-correlation matrix) for the high-bitdepth restoration search, using
// SVE-accelerated helpers. dgd_avg/src_avg are caller-provided scratch
// buffers receiving the mean-subtracted degraded and source blocks.
void av1_compute_stats_highbd_sve(int32_t wiener_win, const uint8_t *dgd8,
                                  const uint8_t *src8, int16_t *dgd_avg,
                                  int16_t *src_avg, int32_t h_start,
                                  int32_t h_end, int32_t v_start, int32_t v_end,
                                  int32_t dgd_stride, int32_t src_stride,
                                  int64_t *M, int64_t *H,
                                  aom_bit_depth_t bit_depth) {
  const int32_t wiener_win2 = wiener_win * wiener_win;
  const int32_t wiener_halfwin = (wiener_win >> 1);
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
  const int32_t width = h_end - h_start;
  const int32_t height = v_end - v_start;
  // Scratch-buffer strides rounded up to a multiple of 16 so vector stores
  // in sub_avg_block_highbd_sve can safely write full vectors past `width`.
  const int32_t d_stride = (width + 2 * wiener_halfwin + 15) & ~15;
  const int32_t s_stride = (width + 15) & ~15;

  const uint16_t *dgd_start = dgd + h_start + v_start * dgd_stride;
  const uint16_t *src_start = src + h_start + v_start * src_stride;
  // The mean of the degraded block is subtracted from both blocks so the
  // correlation accumulators stay within 64-bit range.
  const uint16_t avg =
      highbd_find_average_sve(dgd_start, dgd_stride, width, height);

  sub_avg_block_highbd_sve(src_start, src_stride, avg, width, height, src_avg,
                           s_stride);
  // The degraded block is extended by the filter half-width on all sides to
  // cover every tap position.
  sub_avg_block_highbd_sve(
      dgd + (v_start - wiener_halfwin) * dgd_stride + h_start - wiener_halfwin,
      dgd_stride, avg, width + 2 * wiener_halfwin, height + 2 * wiener_halfwin,
      dgd_avg, d_stride);

  if (wiener_win == WIENER_WIN) {
    compute_stats_win7_sve(dgd_avg, d_stride, src_avg, s_stride, width, height,
                           M, H);
  } else {
    assert(wiener_win == WIENER_WIN_CHROMA);
    compute_stats_win5_sve(dgd_avg, d_stride, src_avg, s_stride, width, height,
                           M, H);
  }

  // H is a symmetric matrix, so we only need to fill out the upper triangle.
  // We can copy it down to the lower triangle outside the (i, j) loops.
  // For 10/12-bit input the statistics are additionally downshifted by
  // dividing by 4 / 16 to keep later fixed-point computations in range.
  if (bit_depth == AOM_BITS_8) {
    diagonal_copy_stats_neon(wiener_win2, H);
  } else if (bit_depth == AOM_BITS_10) {
    const int32_t k4 = wiener_win2 & ~3;

    // Divide M four elements at a time; the scalar H division here handles
    // only the diagonal entries at multiples of 4 — the remaining diagonal
    // entries are presumably divided inside div4_diagonal_copy_stats_neon
    // along with the off-diagonal upper triangle (helper not visible here —
    // TODO confirm against pickrst_neon.h).
    int32_t k = 0;
    do {
      int64x2_t dst = div4_neon(vld1q_s64(M + k));
      vst1q_s64(M + k, dst);
      dst = div4_neon(vld1q_s64(M + k + 2));
      vst1q_s64(M + k + 2, dst);
      H[k * wiener_win2 + k] /= 4;
      k += 4;
    } while (k < k4);

    // wiener_win2 is odd (25 or 49), so one diagonal entry and a short tail
    // of M remain after the vector loop.
    H[k * wiener_win2 + k] /= 4;

    for (; k < wiener_win2; ++k) {
      M[k] /= 4;
    }

    div4_diagonal_copy_stats_neon(wiener_win2, H);
  } else {  // bit_depth == AOM_BITS_12
    const int32_t k4 = wiener_win2 & ~3;

    // Same structure as the 10-bit path, dividing by 16 instead of 4.
    int32_t k = 0;
    do {
      int64x2_t dst = div16_neon(vld1q_s64(M + k));
      vst1q_s64(M + k, dst);
      dst = div16_neon(vld1q_s64(M + k + 2));
      vst1q_s64(M + k + 2, dst);
      H[k * wiener_win2 + k] /= 16;
      k += 4;
    } while (k < k4);

    H[k * wiener_win2 + k] /= 16;

    for (; k < wiener_win2; ++k) {
      M[k] /= 16;
    }

    div16_diagonal_copy_stats_neon(wiener_win2, H);
  }
}
168