xref: /aosp_15_r20/external/libaom/aom_dsp/flow_estimation/arm/disflow_sve.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include "aom_dsp/flow_estimation/disflow.h"
13 
14 #include <arm_neon.h>
15 #include <arm_sve.h>
16 #include <math.h>
17 
18 #include "aom_dsp/arm/aom_neon_sve_bridge.h"
19 #include "aom_dsp/arm/mem_neon.h"
20 #include "aom_dsp/arm/sum_neon.h"
21 #include "aom_dsp/flow_estimation/arm/disflow_neon.h"
22 #include "config/aom_config.h"
23 #include "config/aom_dsp_rtcd.h"
24 
25 DECLARE_ALIGNED(16, static const uint16_t, kDeinterleaveTbl[8]) = {
26   0, 2, 4, 6, 1, 3, 5, 7,
27 };
28 
29 // Compare two regions of width x height pixels, one rooted at position
30 // (x, y) in src and the other at (x + u, y + v) in ref.
31 // This function returns the sum of squared pixel differences between
32 // the two regions.
compute_flow_error(const uint8_t * src,const uint8_t * ref,int width,int height,int stride,int x,int y,double u,double v,int16_t * dt)33 static inline void compute_flow_error(const uint8_t *src, const uint8_t *ref,
34                                       int width, int height, int stride, int x,
35                                       int y, double u, double v, int16_t *dt) {
36   // Split offset into integer and fractional parts, and compute cubic
37   // interpolation kernels
38   const int u_int = (int)floor(u);
39   const int v_int = (int)floor(v);
40   const double u_frac = u - floor(u);
41   const double v_frac = v - floor(v);
42 
43   int h_kernel[4];
44   int v_kernel[4];
45   get_cubic_kernel_int(u_frac, h_kernel);
46   get_cubic_kernel_int(v_frac, v_kernel);
47 
48   int16_t tmp_[DISFLOW_PATCH_SIZE * (DISFLOW_PATCH_SIZE + 3)];
49 
50   // Clamp coordinates so that all pixels we fetch will remain within the
51   // allocated border region, but allow them to go far enough out that
52   // the border pixels' values do not change.
53   // Since we are calculating an 8x8 block, the bottom-right pixel
54   // in the block has coordinates (x0 + 7, y0 + 7). Then, the cubic
55   // interpolation has 4 taps, meaning that the output of pixel
56   // (x_w, y_w) depends on the pixels in the range
57   // ([x_w - 1, x_w + 2], [y_w - 1, y_w + 2]).
58   //
59   // Thus the most extreme coordinates which will be fetched are
60   // (x0 - 1, y0 - 1) and (x0 + 9, y0 + 9).
61   const int x0 = clamp(x + u_int, -9, width);
62   const int y0 = clamp(y + v_int, -9, height);
63 
64   // Horizontal convolution.
65   const uint8_t *ref_start = ref + (y0 - 1) * stride + (x0 - 1);
66   const int16x4_t h_kernel_s16 = vmovn_s32(vld1q_s32(h_kernel));
67   const int16x8_t h_filter = vcombine_s16(h_kernel_s16, vdup_n_s16(0));
68   const uint16x8_t idx = vld1q_u16(kDeinterleaveTbl);
69 
70   for (int i = 0; i < DISFLOW_PATCH_SIZE + 3; ++i) {
71     svuint16_t r0 = svld1ub_u16(svptrue_b16(), ref_start + i * stride + 0);
72     svuint16_t r1 = svld1ub_u16(svptrue_b16(), ref_start + i * stride + 1);
73     svuint16_t r2 = svld1ub_u16(svptrue_b16(), ref_start + i * stride + 2);
74     svuint16_t r3 = svld1ub_u16(svptrue_b16(), ref_start + i * stride + 3);
75 
76     int16x8_t s0 = vreinterpretq_s16_u16(svget_neonq_u16(r0));
77     int16x8_t s1 = vreinterpretq_s16_u16(svget_neonq_u16(r1));
78     int16x8_t s2 = vreinterpretq_s16_u16(svget_neonq_u16(r2));
79     int16x8_t s3 = vreinterpretq_s16_u16(svget_neonq_u16(r3));
80 
81     int64x2_t sum04 = aom_svdot_lane_s16(vdupq_n_s64(0), s0, h_filter, 0);
82     int64x2_t sum15 = aom_svdot_lane_s16(vdupq_n_s64(0), s1, h_filter, 0);
83     int64x2_t sum26 = aom_svdot_lane_s16(vdupq_n_s64(0), s2, h_filter, 0);
84     int64x2_t sum37 = aom_svdot_lane_s16(vdupq_n_s64(0), s3, h_filter, 0);
85 
86     int32x4_t res0 = vcombine_s32(vmovn_s64(sum04), vmovn_s64(sum15));
87     int32x4_t res1 = vcombine_s32(vmovn_s64(sum26), vmovn_s64(sum37));
88 
89     // 6 is the maximum allowable number of extra bits which will avoid
90     // the intermediate values overflowing an int16_t. The most extreme
91     // intermediate value occurs when:
92     // * The input pixels are [0, 255, 255, 0]
93     // * u_frac = 0.5
94     // In this case, the un-scaled output is 255 * 1.125 = 286.875.
95     // As an integer with 6 fractional bits, that is 18360, which fits
96     // in an int16_t. But with 7 fractional bits it would be 36720,
97     // which is too large.
98     int16x8_t res = vcombine_s16(vrshrn_n_s32(res0, DISFLOW_INTERP_BITS - 6),
99                                  vrshrn_n_s32(res1, DISFLOW_INTERP_BITS - 6));
100 
101     res = aom_tbl_s16(res, idx);
102 
103     vst1q_s16(tmp_ + i * DISFLOW_PATCH_SIZE, res);
104   }
105 
106   // Vertical convolution.
107   int16x4_t v_filter = vmovn_s32(vld1q_s32(v_kernel));
108   int16_t *tmp_start = tmp_ + DISFLOW_PATCH_SIZE;
109 
110   for (int i = 0; i < DISFLOW_PATCH_SIZE; ++i) {
111     int16x8_t t0 = vld1q_s16(tmp_start + (i - 1) * DISFLOW_PATCH_SIZE);
112     int16x8_t t1 = vld1q_s16(tmp_start + i * DISFLOW_PATCH_SIZE);
113     int16x8_t t2 = vld1q_s16(tmp_start + (i + 1) * DISFLOW_PATCH_SIZE);
114     int16x8_t t3 = vld1q_s16(tmp_start + (i + 2) * DISFLOW_PATCH_SIZE);
115 
116     int32x4_t sum_lo = vmull_lane_s16(vget_low_s16(t0), v_filter, 0);
117     sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(t1), v_filter, 1);
118     sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(t2), v_filter, 2);
119     sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(t3), v_filter, 3);
120 
121     int32x4_t sum_hi = vmull_lane_s16(vget_high_s16(t0), v_filter, 0);
122     sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(t1), v_filter, 1);
123     sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(t2), v_filter, 2);
124     sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(t3), v_filter, 3);
125 
126     uint8x8_t s = vld1_u8(src + (i + y) * stride + x);
127     int16x8_t s_s16 = vreinterpretq_s16_u16(vshll_n_u8(s, 3));
128 
129     // This time, we have to round off the 6 extra bits which were kept
130     // earlier, but we also want to keep DISFLOW_DERIV_SCALE_LOG2 extra bits
131     // of precision to match the scale of the dx and dy arrays.
132     sum_lo = vrshrq_n_s32(sum_lo,
133                           DISFLOW_INTERP_BITS + 6 - DISFLOW_DERIV_SCALE_LOG2);
134     sum_hi = vrshrq_n_s32(sum_hi,
135                           DISFLOW_INTERP_BITS + 6 - DISFLOW_DERIV_SCALE_LOG2);
136     int32x4_t err_lo = vsubw_s16(sum_lo, vget_low_s16(s_s16));
137     int32x4_t err_hi = vsubw_s16(sum_hi, vget_high_s16(s_s16));
138     vst1q_s16(dt + i * DISFLOW_PATCH_SIZE,
139               vcombine_s16(vmovn_s32(err_lo), vmovn_s32(err_hi)));
140   }
141 }
142 
143 // Computes the components of the system of equations used to solve for
144 // a flow vector.
145 //
146 // The flow equations are a least-squares system, derived as follows:
147 //
148 // For each pixel in the patch, we calculate the current error `dt`,
149 // and the x and y gradients `dx` and `dy` of the source patch.
150 // This means that, to first order, the squared error for this pixel is
151 //
152 //    (dt + u * dx + v * dy)^2
153 //
154 // where (u, v) are the incremental changes to the flow vector.
155 //
156 // We then want to find the values of u and v which minimize the sum
157 // of the squared error across all pixels. Conveniently, this fits exactly
158 // into the form of a least squares problem, with one equation
159 //
160 //   u * dx + v * dy = -dt
161 //
162 // for each pixel.
163 //
164 // Summing across all pixels in a square window of size DISFLOW_PATCH_SIZE,
165 // and absorbing the - sign elsewhere, this results in the least squares system
166 //
167 //   M = |sum(dx * dx)  sum(dx * dy)|
168 //       |sum(dx * dy)  sum(dy * dy)|
169 //
170 //   b = |sum(dx * dt)|
171 //       |sum(dy * dt)|
compute_flow_matrix(const int16_t * dx,int dx_stride,const int16_t * dy,int dy_stride,double * M_inv)172 static inline void compute_flow_matrix(const int16_t *dx, int dx_stride,
173                                        const int16_t *dy, int dy_stride,
174                                        double *M_inv) {
175   int64x2_t sum[3] = { vdupq_n_s64(0), vdupq_n_s64(0), vdupq_n_s64(0) };
176 
177   for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) {
178     int16x8_t x = vld1q_s16(dx + i * dx_stride);
179     int16x8_t y = vld1q_s16(dy + i * dy_stride);
180 
181     sum[0] = aom_sdotq_s16(sum[0], x, x);
182     sum[1] = aom_sdotq_s16(sum[1], x, y);
183     sum[2] = aom_sdotq_s16(sum[2], y, y);
184   }
185 
186   sum[0] = vpaddq_s64(sum[0], sum[1]);
187   sum[2] = vpaddq_s64(sum[1], sum[2]);
188   int32x4_t res = vcombine_s32(vmovn_s64(sum[0]), vmovn_s64(sum[2]));
189 
190   // Apply regularization
191   // We follow the standard regularization method of adding `k * I` before
192   // inverting. This ensures that the matrix will be invertible.
193   //
194   // Setting the regularization strength k to 1 seems to work well here, as
195   // typical values coming from the other equations are very large (1e5 to
196   // 1e6, with an upper limit of around 6e7, at the time of writing).
197   // It also preserves the property that all matrix values are whole numbers,
198   // which is convenient for integerized SIMD implementation.
199 
200   double M0 = (double)vgetq_lane_s32(res, 0) + 1;
201   double M1 = (double)vgetq_lane_s32(res, 1);
202   double M2 = (double)vgetq_lane_s32(res, 2);
203   double M3 = (double)vgetq_lane_s32(res, 3) + 1;
204 
205   // Invert matrix M.
206   double det = (M0 * M3) - (M1 * M2);
207   assert(det >= 1);
208   const double det_inv = 1 / det;
209 
210   M_inv[0] = M3 * det_inv;
211   M_inv[1] = -M1 * det_inv;
212   M_inv[2] = -M2 * det_inv;
213   M_inv[3] = M0 * det_inv;
214 }
215 
compute_flow_vector(const int16_t * dx,int dx_stride,const int16_t * dy,int dy_stride,const int16_t * dt,int dt_stride,int * b)216 static inline void compute_flow_vector(const int16_t *dx, int dx_stride,
217                                        const int16_t *dy, int dy_stride,
218                                        const int16_t *dt, int dt_stride,
219                                        int *b) {
220   int64x2_t b_s64[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };
221 
222   for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) {
223     int16x8_t dx16 = vld1q_s16(dx + i * dx_stride);
224     int16x8_t dy16 = vld1q_s16(dy + i * dy_stride);
225     int16x8_t dt16 = vld1q_s16(dt + i * dt_stride);
226 
227     b_s64[0] = aom_sdotq_s16(b_s64[0], dx16, dt16);
228     b_s64[1] = aom_sdotq_s16(b_s64[1], dy16, dt16);
229   }
230 
231   b_s64[0] = vpaddq_s64(b_s64[0], b_s64[1]);
232   vst1_s32(b, vmovn_s64(b_s64[0]));
233 }
234 
aom_compute_flow_at_point_sve(const uint8_t * src,const uint8_t * ref,int x,int y,int width,int height,int stride,double * u,double * v)235 void aom_compute_flow_at_point_sve(const uint8_t *src, const uint8_t *ref,
236                                    int x, int y, int width, int height,
237                                    int stride, double *u, double *v) {
238   double M_inv[4];
239   int b[2];
240   int16_t dt[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE];
241   int16_t dx[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE];
242   int16_t dy[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE];
243 
244   // Compute gradients within this patch
245   const uint8_t *src_patch = &src[y * stride + x];
246   sobel_filter_x(src_patch, stride, dx, DISFLOW_PATCH_SIZE);
247   sobel_filter_y(src_patch, stride, dy, DISFLOW_PATCH_SIZE);
248 
249   compute_flow_matrix(dx, DISFLOW_PATCH_SIZE, dy, DISFLOW_PATCH_SIZE, M_inv);
250 
251   for (int itr = 0; itr < DISFLOW_MAX_ITR; itr++) {
252     compute_flow_error(src, ref, width, height, stride, x, y, *u, *v, dt);
253     compute_flow_vector(dx, DISFLOW_PATCH_SIZE, dy, DISFLOW_PATCH_SIZE, dt,
254                         DISFLOW_PATCH_SIZE, b);
255 
256     // Solve flow equations to find a better estimate for the flow vector
257     // at this point
258     const double step_u = M_inv[0] * b[0] + M_inv[1] * b[1];
259     const double step_v = M_inv[2] * b[0] + M_inv[3] * b[1];
260     *u += fclamp(step_u * DISFLOW_STEP_SIZE, -2, 2);
261     *v += fclamp(step_v * DISFLOW_STEP_SIZE, -2, 2);
262 
263     if (fabs(step_u) + fabs(step_v) < DISFLOW_STEP_SIZE_THRESOLD) {
264       // Stop iteration when we're close to convergence
265       break;
266     }
267   }
268 }
269