1 /*
2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include <arm_neon.h>
13
14 #include "aom_dsp/arm/sum_neon.h"
15 #include "config/aom_config.h"
16 #include "config/av1_rtcd.h"
17
k_means_multiply_add_neon(const int16x8_t a)18 static int32x4_t k_means_multiply_add_neon(const int16x8_t a) {
19 const int32x4_t l = vmull_s16(vget_low_s16(a), vget_low_s16(a));
20 const int32x4_t h = vmull_s16(vget_high_s16(a), vget_high_s16(a));
21 #if AOM_ARCH_AARCH64
22 return vpaddq_s32(l, h);
23 #else
24 const int32x2_t dl = vpadd_s32(vget_low_s32(l), vget_high_s32(l));
25 const int32x2_t dh = vpadd_s32(vget_low_s32(h), vget_high_s32(h));
26 return vcombine_s32(dl, dh);
27 #endif
28 }
29
av1_calc_indices_dim1_neon(const int16_t * data,const int16_t * centroids,uint8_t * indices,int64_t * total_dist,int n,int k)30 void av1_calc_indices_dim1_neon(const int16_t *data, const int16_t *centroids,
31 uint8_t *indices, int64_t *total_dist, int n,
32 int k) {
33 int64x2_t sum = vdupq_n_s64(0);
34 int16x8_t cents[PALETTE_MAX_SIZE];
35 for (int j = 0; j < k; ++j) {
36 cents[j] = vdupq_n_s16(centroids[j]);
37 }
38
39 for (int i = 0; i < n; i += 8) {
40 const int16x8_t in = vld1q_s16(data);
41 uint16x8_t ind = vdupq_n_u16(0);
42 // Compute the distance to the first centroid.
43 int16x8_t dist_min = vabdq_s16(in, cents[0]);
44
45 for (int j = 1; j < k; ++j) {
46 // Compute the distance to the centroid.
47 const int16x8_t dist = vabdq_s16(in, cents[j]);
48 // Compare to the minimal one.
49 const uint16x8_t cmp = vcgtq_s16(dist_min, dist);
50 dist_min = vminq_s16(dist_min, dist);
51 const uint16x8_t ind1 = vdupq_n_u16(j);
52 ind = vbslq_u16(cmp, ind1, ind);
53 }
54 if (total_dist) {
55 // Square, convert to 32 bit and add together.
56 const int32x4_t l =
57 vmull_s16(vget_low_s16(dist_min), vget_low_s16(dist_min));
58 const int32x4_t sum32_tmp =
59 vmlal_s16(l, vget_high_s16(dist_min), vget_high_s16(dist_min));
60 // Pairwise sum, convert to 64 bit and add to sum.
61 sum = vpadalq_s32(sum, sum32_tmp);
62 }
63 vst1_u8(indices, vmovn_u16(ind));
64 indices += 8;
65 data += 8;
66 }
67 if (total_dist) {
68 *total_dist = horizontal_add_s64x2(sum);
69 }
70 }
71
av1_calc_indices_dim2_neon(const int16_t * data,const int16_t * centroids,uint8_t * indices,int64_t * total_dist,int n,int k)72 void av1_calc_indices_dim2_neon(const int16_t *data, const int16_t *centroids,
73 uint8_t *indices, int64_t *total_dist, int n,
74 int k) {
75 int64x2_t sum = vdupq_n_s64(0);
76 uint32x4_t ind[2];
77 int16x8_t cents[PALETTE_MAX_SIZE];
78 for (int j = 0; j < k; ++j) {
79 const int16_t cx = centroids[2 * j], cy = centroids[2 * j + 1];
80 const int16_t cxcy[8] = { cx, cy, cx, cy, cx, cy, cx, cy };
81 cents[j] = vld1q_s16(cxcy);
82 }
83
84 for (int i = 0; i < n; i += 8) {
85 for (int l = 0; l < 2; ++l) {
86 const int16x8_t in = vld1q_s16(data);
87 ind[l] = vdupq_n_u32(0);
88 // Compute the distance to the first centroid.
89 int16x8_t d1 = vsubq_s16(in, cents[0]);
90 int32x4_t dist_min = k_means_multiply_add_neon(d1);
91
92 for (int j = 1; j < k; ++j) {
93 // Compute the distance to the centroid.
94 d1 = vsubq_s16(in, cents[j]);
95 const int32x4_t dist = k_means_multiply_add_neon(d1);
96 // Compare to the minimal one.
97 const uint32x4_t cmp = vcgtq_s32(dist_min, dist);
98 dist_min = vminq_s32(dist_min, dist);
99 const uint32x4_t ind1 = vdupq_n_u32(j);
100 ind[l] = vbslq_u32(cmp, ind1, ind[l]);
101 }
102 if (total_dist) {
103 // Pairwise sum, convert to 64 bit and add to sum.
104 sum = vpadalq_s32(sum, dist_min);
105 }
106 data += 8;
107 }
108 // Cast to 8 bit and store.
109 vst1_u8(indices,
110 vmovn_u16(vcombine_u16(vmovn_u32(ind[0]), vmovn_u32(ind[1]))));
111 indices += 8;
112 }
113 if (total_dist) {
114 *total_dist = horizontal_add_s64x2(sum);
115 }
116 }
117