xref: /aosp_15_r20/external/libaom/av1/encoder/arm/av1_k_means_neon.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <arm_neon.h>
13 
14 #include "aom_dsp/arm/sum_neon.h"
15 #include "config/aom_config.h"
16 #include "config/av1_rtcd.h"
17 
k_means_multiply_add_neon(const int16x8_t a)18 static int32x4_t k_means_multiply_add_neon(const int16x8_t a) {
19   const int32x4_t l = vmull_s16(vget_low_s16(a), vget_low_s16(a));
20   const int32x4_t h = vmull_s16(vget_high_s16(a), vget_high_s16(a));
21 #if AOM_ARCH_AARCH64
22   return vpaddq_s32(l, h);
23 #else
24   const int32x2_t dl = vpadd_s32(vget_low_s32(l), vget_high_s32(l));
25   const int32x2_t dh = vpadd_s32(vget_low_s32(h), vget_high_s32(h));
26   return vcombine_s32(dl, dh);
27 #endif
28 }
29 
av1_calc_indices_dim1_neon(const int16_t * data,const int16_t * centroids,uint8_t * indices,int64_t * total_dist,int n,int k)30 void av1_calc_indices_dim1_neon(const int16_t *data, const int16_t *centroids,
31                                 uint8_t *indices, int64_t *total_dist, int n,
32                                 int k) {
33   int64x2_t sum = vdupq_n_s64(0);
34   int16x8_t cents[PALETTE_MAX_SIZE];
35   for (int j = 0; j < k; ++j) {
36     cents[j] = vdupq_n_s16(centroids[j]);
37   }
38 
39   for (int i = 0; i < n; i += 8) {
40     const int16x8_t in = vld1q_s16(data);
41     uint16x8_t ind = vdupq_n_u16(0);
42     // Compute the distance to the first centroid.
43     int16x8_t dist_min = vabdq_s16(in, cents[0]);
44 
45     for (int j = 1; j < k; ++j) {
46       // Compute the distance to the centroid.
47       const int16x8_t dist = vabdq_s16(in, cents[j]);
48       // Compare to the minimal one.
49       const uint16x8_t cmp = vcgtq_s16(dist_min, dist);
50       dist_min = vminq_s16(dist_min, dist);
51       const uint16x8_t ind1 = vdupq_n_u16(j);
52       ind = vbslq_u16(cmp, ind1, ind);
53     }
54     if (total_dist) {
55       // Square, convert to 32 bit and add together.
56       const int32x4_t l =
57           vmull_s16(vget_low_s16(dist_min), vget_low_s16(dist_min));
58       const int32x4_t sum32_tmp =
59           vmlal_s16(l, vget_high_s16(dist_min), vget_high_s16(dist_min));
60       // Pairwise sum, convert to 64 bit and add to sum.
61       sum = vpadalq_s32(sum, sum32_tmp);
62     }
63     vst1_u8(indices, vmovn_u16(ind));
64     indices += 8;
65     data += 8;
66   }
67   if (total_dist) {
68     *total_dist = horizontal_add_s64x2(sum);
69   }
70 }
71 
av1_calc_indices_dim2_neon(const int16_t * data,const int16_t * centroids,uint8_t * indices,int64_t * total_dist,int n,int k)72 void av1_calc_indices_dim2_neon(const int16_t *data, const int16_t *centroids,
73                                 uint8_t *indices, int64_t *total_dist, int n,
74                                 int k) {
75   int64x2_t sum = vdupq_n_s64(0);
76   uint32x4_t ind[2];
77   int16x8_t cents[PALETTE_MAX_SIZE];
78   for (int j = 0; j < k; ++j) {
79     const int16_t cx = centroids[2 * j], cy = centroids[2 * j + 1];
80     const int16_t cxcy[8] = { cx, cy, cx, cy, cx, cy, cx, cy };
81     cents[j] = vld1q_s16(cxcy);
82   }
83 
84   for (int i = 0; i < n; i += 8) {
85     for (int l = 0; l < 2; ++l) {
86       const int16x8_t in = vld1q_s16(data);
87       ind[l] = vdupq_n_u32(0);
88       // Compute the distance to the first centroid.
89       int16x8_t d1 = vsubq_s16(in, cents[0]);
90       int32x4_t dist_min = k_means_multiply_add_neon(d1);
91 
92       for (int j = 1; j < k; ++j) {
93         // Compute the distance to the centroid.
94         d1 = vsubq_s16(in, cents[j]);
95         const int32x4_t dist = k_means_multiply_add_neon(d1);
96         // Compare to the minimal one.
97         const uint32x4_t cmp = vcgtq_s32(dist_min, dist);
98         dist_min = vminq_s32(dist_min, dist);
99         const uint32x4_t ind1 = vdupq_n_u32(j);
100         ind[l] = vbslq_u32(cmp, ind1, ind[l]);
101       }
102       if (total_dist) {
103         // Pairwise sum, convert to 64 bit and add to sum.
104         sum = vpadalq_s32(sum, dist_min);
105       }
106       data += 8;
107     }
108     // Cast to 8 bit and store.
109     vst1_u8(indices,
110             vmovn_u16(vcombine_u16(vmovn_u32(ind[0]), vmovn_u32(ind[1]))));
111     indices += 8;
112   }
113   if (total_dist) {
114     *total_dist = horizontal_add_s64x2(sum);
115   }
116 }
117