xref: /aosp_15_r20/external/libdav1d/tests/checkasm/msac.c (revision c09093415860a1c2373dacd84c4fde00c507cdfd)
1 /*
2  * Copyright © 2019, VideoLAN and dav1d authors
3  * Copyright © 2019, Two Orioles, LLC
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright notice, this
10  *    list of conditions and the following disclaimer.
11  *
12  * 2. Redistributions in binary form must reproduce the above copyright notice,
13  *    this list of conditions and the following disclaimer in the documentation
14  *    and/or other materials provided with the distribution.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  */
27 
28 #include "tests/checkasm/checkasm.h"
29 
30 #include "src/cpu.h"
31 #include "src/msac.h"
32 
33 #include <stdio.h>
34 #include <string.h>
35 
36 #define BUF_SIZE 128
37 
38 /* The normal code doesn't use function pointers */
39 typedef unsigned (*decode_symbol_adapt_fn)(MsacContext *s, uint16_t *cdf,
40                                            size_t n_symbols);
41 typedef unsigned (*decode_adapt_fn)(MsacContext *s, uint16_t *cdf);
42 typedef unsigned (*decode_bool_equi_fn)(MsacContext *s);
43 typedef unsigned (*decode_bool_fn)(MsacContext *s, unsigned f);
44 
45 typedef struct {
46     decode_symbol_adapt_fn decode_symbol_adapt4;
47     decode_symbol_adapt_fn decode_symbol_adapt8;
48     decode_symbol_adapt_fn decode_symbol_adapt16;
49     decode_adapt_fn        decode_bool_adapt;
50     decode_bool_equi_fn    decode_bool_equi;
51     decode_bool_fn         decode_bool;
52     decode_adapt_fn        decode_hi_tok;
53 } MsacDSPContext;
54 
/*
 * Fill a 16-entry CDF array for an n-symbol decoder with random data.
 *
 * Entries above index n are zeroed, and cdf[n] is zeroed as well. Entries
 * cdf[n-1] down to cdf[0] then form a strictly increasing sequence in which
 * each entry is at least one more than the entry after it. The upper bound
 * shrinks with the index, so cdf[0] never exceeds 32767.
 *
 * Callers in this file pass n in [1, 15].
 */
static void randomize_cdf(uint16_t *const cdf, const int n) {
    int pos = 15;
    while (pos > n)
        cdf[pos--] = 0;          /* padding */
    cdf[pos] = 0;                /* count */
    do {
        /* Leave enough headroom below 32768 for the remaining entries. */
        const int room = 32768 - cdf[pos] - pos;
        cdf[pos - 1] = cdf[pos] + rnd() % room + 1;
    } while (--pos > 0);
}
64 
65 /* memcmp() on structs can have weird behavior due to padding etc. */
msac_cmp(const MsacContext * const a,const MsacContext * const b)66 static int msac_cmp(const MsacContext *const a, const MsacContext *const b) {
67     if (a->buf_pos != b->buf_pos || a->buf_end != b->buf_end ||
68         a->rng != b->rng || a->cnt != b->cnt ||
69         a->allow_update_cdf != b->allow_update_cdf)
70     {
71         return 1;
72     }
73 
74     /* Only check valid dif bits, ignoring partial bytes at the end */
75     const ec_win dif_mask = ~((~(ec_win)0) >> (imax(a->cnt, 0) + 16));
76     return !!((a->dif ^ b->dif) & dif_mask);
77 }
78 
msac_dump(unsigned c_res,unsigned a_res,const MsacContext * const a,const MsacContext * const b,const uint16_t * const cdf_a,const uint16_t * const cdf_b,const int num_cdf)79 static void msac_dump(unsigned c_res, unsigned a_res,
80                       const MsacContext *const a, const MsacContext *const b,
81                       const uint16_t *const cdf_a, const uint16_t *const cdf_b,
82                       const int num_cdf)
83 {
84     if (c_res != a_res)
85         fprintf(stderr, "c_res %u a_res %u\n", c_res, a_res);
86     if (a->buf_pos != b->buf_pos)
87         fprintf(stderr, "buf_pos %p vs %p\n", a->buf_pos, b->buf_pos);
88     if (a->buf_end != b->buf_end)
89         fprintf(stderr, "buf_end %p vs %p\n", a->buf_end, b->buf_end);
90     if (a->dif != b->dif)
91         fprintf(stderr, "dif %zx vs %zx\n", a->dif, b->dif);
92     if (a->rng != b->rng)
93         fprintf(stderr, "rng %u vs %u\n", a->rng, b->rng);
94     if (a->cnt != b->cnt)
95         fprintf(stderr, "cnt %d vs %d\n", a->cnt, b->cnt);
96     if (a->allow_update_cdf != b->allow_update_cdf)
97         fprintf(stderr, "allow_update_cdf %d vs %d\n",
98                 a->allow_update_cdf, b->allow_update_cdf);
99     if (num_cdf && memcmp(cdf_a, cdf_b, sizeof(*cdf_a) * (num_cdf + 1))) {
100         fprintf(stderr, "cdf:\n");
101         for (int i = 0; i <= num_cdf; i++)
102             fprintf(stderr, " %5u", cdf_a[i]);
103         fprintf(stderr, "\n");
104         for (int i = 0; i <= num_cdf; i++)
105             fprintf(stderr, " %5u", cdf_b[i]);
106         fprintf(stderr, "\n");
107         for (int i = 0; i <= num_cdf; i++)
108             fprintf(stderr, "     %c", cdf_a[i] != cdf_b[i] ? 'x' : '.');
109         fprintf(stderr, "\n");
110     }
111 }
112 
/*
 * Check msac_decode_symbol_adapt<n> for every symbol count ns in
 * [n_min, n_max]. Each configuration runs once with cdf_update = 0 and once
 * with cdf_update = 1; dav1d_msac_init receives the negation of cdf_update.
 *
 * For each configuration, the C reference (on s_c/cdf[0]) and the optimized
 * version (on s_a/cdf[1]) decode from the same buffer. This continues until
 * s_c.cnt goes negative. After every call, the macro compares the returned
 * symbol, the context state and the first ns + 1 CDF entries. A configuration
 * stops at its first mismatch: once the states have diverged, further calls
 * on s_a compare nothing meaningful and would only repeat the failure report.
 * Only the ns == n - 1 configuration with cdf_update set is benchmarked.
 *
 * Expects c, buf, s_c, s_a, cdf and a matching declare_func() in scope.
 */
#define CHECK_SYMBOL_ADAPT(n, n_min, n_max) do {                           \
    if (check_func(c->decode_symbol_adapt##n,                              \
                   "msac_decode_symbol_adapt%d", n))                       \
    {                                                                      \
        for (int cdf_update = 0; cdf_update <= 1; cdf_update++) {          \
            for (int ns = n_min; ns <= n_max; ns++) {                      \
                dav1d_msac_init(&s_c, buf, BUF_SIZE, !cdf_update);         \
                s_a = s_c;                                                 \
                randomize_cdf(cdf[0], ns);                                 \
                memcpy(cdf[1], cdf[0], sizeof(*cdf));                      \
                while (s_c.cnt >= 0) {                                     \
                    unsigned c_res = call_ref(&s_c, cdf[0], ns);           \
                    unsigned a_res = call_new(&s_a, cdf[1], ns);           \
                    if (c_res != a_res || msac_cmp(&s_c, &s_a) ||          \
                        memcmp(cdf[0], cdf[1], sizeof(**cdf) * (ns + 1)))  \
                    {                                                      \
                        if (fail())                                        \
                            msac_dump(c_res, a_res, &s_c, &s_a,            \
                                      cdf[0], cdf[1], ns);                 \
                        break; /* s_a has diverged; stop this pass */      \
                    }                                                      \
                }                                                          \
                if (cdf_update && ns == n - 1)                             \
                    bench_new(alternate(&s_c, &s_a),                       \
                              alternate(cdf[0], cdf[1]), ns);              \
            }                                                              \
        }                                                                  \
    }                                                                      \
} while (0)
141 
/*
 * Run CHECK_SYMBOL_ADAPT for the 4-, 8- and 16-entry symbol decoders. They
 * cover symbol counts [1, 3], [1, 7] and [3, 15] respectively, and all are
 * reported under "decode_symbol".
 */
static void check_decode_symbol(MsacDSPContext *const c, uint8_t *const buf)
{
    ALIGN_STK_32(uint16_t, cdf, 2, [16]);
    MsacContext s_c;
    MsacContext s_a;

    declare_func(unsigned, MsacContext *s, uint16_t *cdf, size_t n_symbols);

    CHECK_SYMBOL_ADAPT(4, 1, 3);
    CHECK_SYMBOL_ADAPT(8, 1, 7);
    CHECK_SYMBOL_ADAPT(16, 3, 15);

    report("decode_symbol");
}
152 
/*
 * Compare msac_decode_bool_adapt against the C reference. The check runs
 * once with cdf_update = 0 and once with cdf_update = 1; dav1d_msac_init
 * receives the negation of cdf_update.
 *
 * Each pass starts from a 2-entry CDF: cdf[x][0] is random in [1, 32767] and
 * cdf[x][1] is 0. Decoding continues until the reference context's cnt goes
 * negative. After each call, the check compares the result, the context
 * state and both CDF entries, and stops at the first mismatch. Only the
 * cdf_update == 1 pass is benchmarked.
 */
static void check_decode_bool_adapt(MsacDSPContext *const c, uint8_t *const buf) {
    MsacContext s_c, s_a;

    declare_func(unsigned, MsacContext *s, uint16_t *cdf);
    if (check_func(c->decode_bool_adapt, "msac_decode_bool_adapt")) {
        uint16_t cdf[2][2];
        for (int cdf_update = 0; cdf_update <= 1; cdf_update++) {
            dav1d_msac_init(&s_c, buf, BUF_SIZE, !cdf_update);
            s_a = s_c;
            cdf[0][0] = cdf[1][0] = rnd() % 32767 + 1;
            cdf[0][1] = cdf[1][1] = 0;
            while (s_c.cnt >= 0) {
                unsigned c_res = call_ref(&s_c, cdf[0]);
                unsigned a_res = call_new(&s_a, cdf[1]);
                if (c_res != a_res || msac_cmp(&s_c, &s_a) ||
                    memcmp(cdf[0], cdf[1], sizeof(*cdf)))
                {
                    if (fail())
                        msac_dump(c_res, a_res, &s_c, &s_a, cdf[0], cdf[1], 1);
                    /* The optimized state has diverged; stop this pass. */
                    break;
                }
            }
            if (cdf_update)
                bench_new(alternate(&s_c, &s_a), alternate(cdf[0], cdf[1]));
        }
    }
}
179 
/*
 * Compare msac_decode_bool_equi against the C reference. Both decode from
 * the same buffer until the reference context's cnt goes negative. After
 * each call, the check compares the result and the context state, stopping
 * at the first mismatch. The function takes no CDF, so msac_dump() is called
 * with num_cdf == 0.
 */
static void check_decode_bool_equi(MsacDSPContext *const c, uint8_t *const buf) {
    MsacContext s_c, s_a;

    declare_func(unsigned, MsacContext *s);
    if (check_func(c->decode_bool_equi, "msac_decode_bool_equi")) {
        dav1d_msac_init(&s_c, buf, BUF_SIZE, 1);
        s_a = s_c;
        while (s_c.cnt >= 0) {
            unsigned c_res = call_ref(&s_c);
            unsigned a_res = call_new(&s_a);
            if (c_res != a_res || msac_cmp(&s_c, &s_a)) {
                if (fail())
                    msac_dump(c_res, a_res, &s_c, &s_a, NULL, NULL, 0);
                /* The optimized state has diverged; stop decoding. */
                break;
            }
        }
        bench_new(alternate(&s_c, &s_a));
    }
}
198 
/*
 * Compare msac_decode_bool against the C reference. Each call uses a fresh
 * random 15-bit probability f (rnd() & 0x7fff), and both implementations get
 * the same f. Decoding continues until the reference context's cnt goes
 * negative, and stops at the first mismatch in the result or context state.
 * The benchmark uses a fixed f of 16384.
 */
static void check_decode_bool(MsacDSPContext *const c, uint8_t *const buf) {
    MsacContext s_c, s_a;

    declare_func(unsigned, MsacContext *s, unsigned f);
    if (check_func(c->decode_bool, "msac_decode_bool")) {
        dav1d_msac_init(&s_c, buf, BUF_SIZE, 1);
        s_a = s_c;
        while (s_c.cnt >= 0) {
            const unsigned f = rnd() & 0x7fff;
            unsigned c_res = call_ref(&s_c, f);
            unsigned a_res = call_new(&s_a, f);
            if (c_res != a_res || msac_cmp(&s_c, &s_a)) {
                if (fail())
                    msac_dump(c_res, a_res, &s_c, &s_a, NULL, NULL, 0);
                /* The optimized state has diverged; stop decoding. */
                break;
            }
        }
        bench_new(alternate(&s_c, &s_a), 16384);
    }
}
219 
/*
 * Run the adaptive, equiprobable and explicit-probability boolean decoder
 * checks in that order, then report them together under "decode_bool".
 */
static void check_decode_bool_funcs(MsacDSPContext *const c, uint8_t *const buf)
{
    check_decode_bool_adapt(c, buf);
    check_decode_bool_equi(c, buf);
    check_decode_bool(c, buf);

    report("decode_bool");
}
226 
/*
 * Compare msac_decode_hi_tok against the C reference on a random CDF built
 * by randomize_cdf(..., 3). The check runs once with adaptation off and once
 * with it on; dav1d_msac_init receives the negated loop flag.
 *
 * After every call, the whole 16-entry CDF arrays must still match, along
 * with the result and context state. A pass stops at its first mismatch.
 * Only the adapting pass is benchmarked. Results are reported under
 * "decode_hi_tok".
 */
static void check_decode_hi_tok(MsacDSPContext *const c, uint8_t *const buf) {
    ALIGN_STK_16(uint16_t, cdf, 2, [16]);
    MsacContext s_c, s_a;

    declare_func(unsigned, MsacContext *s, uint16_t *cdf);
    if (check_func(c->decode_hi_tok, "msac_decode_hi_tok")) {
        for (int adapt = 0; adapt < 2; adapt++) {
            dav1d_msac_init(&s_c, buf, BUF_SIZE, !adapt);
            s_a = s_c;
            randomize_cdf(cdf[0], 3);
            memcpy(cdf[1], cdf[0], sizeof(*cdf));

            while (s_c.cnt >= 0) {
                const unsigned ref_tok = call_ref(&s_c, cdf[0]);
                const unsigned new_tok = call_new(&s_a, cdf[1]);
                const int mismatch = ref_tok != new_tok ||
                                     msac_cmp(&s_c, &s_a) ||
                                     memcmp(cdf[0], cdf[1], sizeof(*cdf));
                if (!mismatch)
                    continue;
                if (fail())
                    msac_dump(ref_tok, new_tok, &s_c, &s_a, cdf[0], cdf[1], 3);
                break;
            }

            if (adapt)
                bench_new(alternate(&s_c, &s_a), alternate(cdf[0], cdf[1]));
        }
    }
    report("decode_hi_tok");
}
255 
checkasm_check_msac(void)256 void checkasm_check_msac(void) {
257     MsacDSPContext c;
258     c.decode_symbol_adapt4  = dav1d_msac_decode_symbol_adapt_c;
259     c.decode_symbol_adapt8  = dav1d_msac_decode_symbol_adapt_c;
260     c.decode_symbol_adapt16 = dav1d_msac_decode_symbol_adapt_c;
261     c.decode_bool_adapt     = dav1d_msac_decode_bool_adapt_c;
262     c.decode_bool_equi      = dav1d_msac_decode_bool_equi_c;
263     c.decode_bool           = dav1d_msac_decode_bool_c;
264     c.decode_hi_tok         = dav1d_msac_decode_hi_tok_c;
265 
266 #if (ARCH_AARCH64 || ARCH_ARM) && HAVE_ASM
267     if (dav1d_get_cpu_flags() & DAV1D_ARM_CPU_FLAG_NEON) {
268         c.decode_symbol_adapt4  = dav1d_msac_decode_symbol_adapt4_neon;
269         c.decode_symbol_adapt8  = dav1d_msac_decode_symbol_adapt8_neon;
270         c.decode_symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_neon;
271         c.decode_bool_adapt     = dav1d_msac_decode_bool_adapt_neon;
272         c.decode_bool_equi      = dav1d_msac_decode_bool_equi_neon;
273         c.decode_bool           = dav1d_msac_decode_bool_neon;
274         c.decode_hi_tok         = dav1d_msac_decode_hi_tok_neon;
275     }
276 #elif ARCH_LOONGARCH64 && HAVE_ASM
277     if (dav1d_get_cpu_flags() & DAV1D_LOONGARCH_CPU_FLAG_LSX) {
278         c.decode_symbol_adapt4  = dav1d_msac_decode_symbol_adapt4_lsx;
279         c.decode_symbol_adapt8  = dav1d_msac_decode_symbol_adapt8_lsx;
280         c.decode_symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_lsx;
281         c.decode_bool_adapt     = dav1d_msac_decode_bool_adapt_lsx;
282         c.decode_bool           = dav1d_msac_decode_bool_lsx;
283         c.decode_bool_equi      = dav1d_msac_decode_bool_equi_lsx;
284         c.decode_hi_tok         = dav1d_msac_decode_hi_tok_lsx;
285     }
286 #elif ARCH_X86 && HAVE_ASM
287     if (dav1d_get_cpu_flags() & DAV1D_X86_CPU_FLAG_SSE2) {
288         c.decode_symbol_adapt4  = dav1d_msac_decode_symbol_adapt4_sse2;
289         c.decode_symbol_adapt8  = dav1d_msac_decode_symbol_adapt8_sse2;
290         c.decode_symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_sse2;
291         c.decode_bool_adapt     = dav1d_msac_decode_bool_adapt_sse2;
292         c.decode_bool_equi      = dav1d_msac_decode_bool_equi_sse2;
293         c.decode_bool           = dav1d_msac_decode_bool_sse2;
294         c.decode_hi_tok         = dav1d_msac_decode_hi_tok_sse2;
295     }
296 
297 #if ARCH_X86_64
298     if (dav1d_get_cpu_flags() & DAV1D_X86_CPU_FLAG_AVX2) {
299         c.decode_symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_avx2;
300     }
301 #endif
302 #endif
303 
304     uint8_t buf[BUF_SIZE];
305     for (int i = 0; i < BUF_SIZE; i++)
306         buf[i] = rnd();
307 
308     check_decode_symbol(&c, buf);
309     check_decode_bool_funcs(&c, buf);
310     check_decode_hi_tok(&c, buf);
311 }
312