1 /*
2 * Copyright © 2019, VideoLAN and dav1d authors
3 * Copyright © 2019, Two Orioles, LLC
4 * All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions are met:
8 *
9 * 1. Redistributions of source code must retain the above copyright notice, this
10 * list of conditions and the following disclaimer.
11 *
12 * 2. Redistributions in binary form must reproduce the above copyright notice,
13 * this list of conditions and the following disclaimer in the documentation
14 * and/or other materials provided with the distribution.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28 #include "tests/checkasm/checkasm.h"
29
30 #include "src/cpu.h"
31 #include "src/msac.h"
32
33 #include <stdio.h>
34 #include <string.h>
35
/* Size in bytes of the random input buffer that every check below decodes
 * from (filled with rnd() in checkasm_check_msac()). */
#define BUF_SIZE 128
37
/* The normal code doesn't use function pointers, so this file declares its
 * own function-pointer types plus a table of them; checkasm_check_msac()
 * fills the table with the C implementations and then swaps in assembly
 * versions the CPU supports. */
typedef unsigned (*decode_symbol_adapt_fn)(MsacContext *s, uint16_t *cdf,
                                           size_t n_symbols);
typedef unsigned (*decode_adapt_fn)(MsacContext *s, uint16_t *cdf);
typedef unsigned (*decode_bool_equi_fn)(MsacContext *s);
typedef unsigned (*decode_bool_fn)(MsacContext *s, unsigned f);

/* One slot per msac entry point under test. */
typedef struct {
    decode_symbol_adapt_fn decode_symbol_adapt4;
    decode_symbol_adapt_fn decode_symbol_adapt8;
    decode_symbol_adapt_fn decode_symbol_adapt16;
    decode_adapt_fn decode_bool_adapt;
    decode_bool_equi_fn decode_bool_equi;
    decode_bool_fn decode_bool;
    decode_adapt_fn decode_hi_tok; /* shares decode_bool_adapt's signature */
} MsacDSPContext;
54
/* Fill the 16-entry array cdf with random test data:
 *   - cdf[n + 1 .. 15] are zeroed padding,
 *   - cdf[n] (the count) is 0,
 *   - cdf[n - 1] > ... > cdf[0] are random, strictly increasing towards
 *     index 0, each entry at least 1 above its successor and cdf[0] <= 32767.
 * n must be at least 1 (n = 0 would write cdf[-1]); callers use 1..15. */
static void randomize_cdf(uint16_t *const cdf, const int n) {
    int k = 15;

    while (k > n)
        cdf[k--] = 0; // padding
    cdf[k] = 0;       // count

    do {
        const int below = cdf[k];
        /* Keep headroom so every remaining entry can still grow by >= 1. */
        cdf[k - 1] = below + rnd() % (32768 - below - k) + 1;
    } while (--k > 0);
}
64
65 /* memcmp() on structs can have weird behavior due to padding etc. */
msac_cmp(const MsacContext * const a,const MsacContext * const b)66 static int msac_cmp(const MsacContext *const a, const MsacContext *const b) {
67 if (a->buf_pos != b->buf_pos || a->buf_end != b->buf_end ||
68 a->rng != b->rng || a->cnt != b->cnt ||
69 a->allow_update_cdf != b->allow_update_cdf)
70 {
71 return 1;
72 }
73
74 /* Only check valid dif bits, ignoring partial bytes at the end */
75 const ec_win dif_mask = ~((~(ec_win)0) >> (imax(a->cnt, 0) + 16));
76 return !!((a->dif ^ b->dif) & dif_mask);
77 }
78
msac_dump(unsigned c_res,unsigned a_res,const MsacContext * const a,const MsacContext * const b,const uint16_t * const cdf_a,const uint16_t * const cdf_b,const int num_cdf)79 static void msac_dump(unsigned c_res, unsigned a_res,
80 const MsacContext *const a, const MsacContext *const b,
81 const uint16_t *const cdf_a, const uint16_t *const cdf_b,
82 const int num_cdf)
83 {
84 if (c_res != a_res)
85 fprintf(stderr, "c_res %u a_res %u\n", c_res, a_res);
86 if (a->buf_pos != b->buf_pos)
87 fprintf(stderr, "buf_pos %p vs %p\n", a->buf_pos, b->buf_pos);
88 if (a->buf_end != b->buf_end)
89 fprintf(stderr, "buf_end %p vs %p\n", a->buf_end, b->buf_end);
90 if (a->dif != b->dif)
91 fprintf(stderr, "dif %zx vs %zx\n", a->dif, b->dif);
92 if (a->rng != b->rng)
93 fprintf(stderr, "rng %u vs %u\n", a->rng, b->rng);
94 if (a->cnt != b->cnt)
95 fprintf(stderr, "cnt %d vs %d\n", a->cnt, b->cnt);
96 if (a->allow_update_cdf != b->allow_update_cdf)
97 fprintf(stderr, "allow_update_cdf %d vs %d\n",
98 a->allow_update_cdf, b->allow_update_cdf);
99 if (num_cdf && memcmp(cdf_a, cdf_b, sizeof(*cdf_a) * (num_cdf + 1))) {
100 fprintf(stderr, "cdf:\n");
101 for (int i = 0; i <= num_cdf; i++)
102 fprintf(stderr, " %5u", cdf_a[i]);
103 fprintf(stderr, "\n");
104 for (int i = 0; i <= num_cdf; i++)
105 fprintf(stderr, " %5u", cdf_b[i]);
106 fprintf(stderr, "\n");
107 for (int i = 0; i <= num_cdf; i++)
108 fprintf(stderr, " %c", cdf_a[i] != cdf_b[i] ? 'x' : '.');
109 fprintf(stderr, "\n");
110 }
111 }
112
/* Compare the C and asm versions of msac_decode_symbol_adapt<n> for every
 * symbol count ns in [n_min, n_max], once with cdf_update = 0 and once with
 * cdf_update = 1 (dav1d_msac_init() receives !cdf_update). Each run decodes
 * from buf until the reference decoder's cnt goes negative, comparing the
 * result, the decoder state and the ns + 1 CDF entries after every call.
 * The first mismatch is reported and ends that run: the two states have
 * diverged, so every later call would only repeat the failure (and keep
 * driving the asm with a corrupted state). The ns == n - 1 case with
 * cdf_update = 1 is benchmarked.
 * Expects c, buf, cdf, s_c and s_a in scope (see check_decode_symbol()). */
#define CHECK_SYMBOL_ADAPT(n, n_min, n_max) do {                           \
    if (check_func(c->decode_symbol_adapt##n,                              \
                   "msac_decode_symbol_adapt%d", n))                       \
    {                                                                      \
        for (int cdf_update = 0; cdf_update <= 1; cdf_update++) {          \
            for (int ns = n_min; ns <= n_max; ns++) {                      \
                dav1d_msac_init(&s_c, buf, BUF_SIZE, !cdf_update);         \
                s_a = s_c;                                                 \
                randomize_cdf(cdf[0], ns);                                 \
                memcpy(cdf[1], cdf[0], sizeof(*cdf));                      \
                while (s_c.cnt >= 0) {                                     \
                    unsigned c_res = call_ref(&s_c, cdf[0], ns);           \
                    unsigned a_res = call_new(&s_a, cdf[1], ns);           \
                    if (c_res != a_res || msac_cmp(&s_c, &s_a) ||          \
                        memcmp(cdf[0], cdf[1], sizeof(**cdf) * (ns + 1)))  \
                    {                                                      \
                        if (fail())                                        \
                            msac_dump(c_res, a_res, &s_c, &s_a,            \
                                      cdf[0], cdf[1], ns);                 \
                        break;                                             \
                    }                                                      \
                }                                                          \
                if (cdf_update && ns == n - 1)                             \
                    bench_new(alternate(&s_c, &s_a),                       \
                              alternate(cdf[0], cdf[1]), ns);              \
            }                                                              \
        }                                                                  \
    }                                                                      \
} while (0)
141
/* Check msac_decode_symbol_adapt4/8/16, testing symbol counts 1-3, 1-7 and
 * 3-15 respectively, and report them together as "decode_symbol".
 * cdf, s_c and s_a look unused here, but CHECK_SYMBOL_ADAPT() references
 * them (and c, buf) by name. */
static void check_decode_symbol(MsacDSPContext *const c, uint8_t *const buf) {
    /* cdf[0] is used with the C reference, cdf[1] with the asm version. */
    ALIGN_STK_32(uint16_t, cdf, 2, [16]);
    /* s_c: C reference decoder state; s_a: asm decoder state. */
    MsacContext s_c, s_a;

    declare_func(unsigned, MsacContext *s, uint16_t *cdf, size_t n_symbols);
    CHECK_SYMBOL_ADAPT( 4, 1, 3);
    CHECK_SYMBOL_ADAPT( 8, 1, 7);
    CHECK_SYMBOL_ADAPT(16, 3, 15);
    report("decode_symbol");
}
152
/* Compare the C and asm versions of msac_decode_bool_adapt, once with
 * cdf_update = 0 and once with cdf_update = 1 (dav1d_msac_init() receives
 * !cdf_update). Each run decodes from buf with a random two-entry CDF until
 * the reference decoder's cnt goes negative. The first mismatch in result,
 * state or CDF is reported and ends the run, because the states have
 * diverged and later calls would only repeat the failure. The
 * cdf_update = 1 case is benchmarked. The result is reported by the caller. */
static void check_decode_bool_adapt(MsacDSPContext *const c, uint8_t *const buf) {
    MsacContext s_c, s_a;

    declare_func(unsigned, MsacContext *s, uint16_t *cdf);
    if (check_func(c->decode_bool_adapt, "msac_decode_bool_adapt")) {
        /* cdf[0] is used with the C reference, cdf[1] with the asm version. */
        uint16_t cdf[2][2];
        for (int cdf_update = 0; cdf_update <= 1; cdf_update++) {
            dav1d_msac_init(&s_c, buf, BUF_SIZE, !cdf_update);
            s_a = s_c;
            /* First entry: random value in [1, 32767]; second starts at 0. */
            cdf[0][0] = cdf[1][0] = rnd() % 32767 + 1;
            cdf[0][1] = cdf[1][1] = 0;
            while (s_c.cnt >= 0) {
                unsigned c_res = call_ref(&s_c, cdf[0]);
                unsigned a_res = call_new(&s_a, cdf[1]);
                if (c_res != a_res || msac_cmp(&s_c, &s_a) ||
                    memcmp(cdf[0], cdf[1], sizeof(*cdf)))
                {
                    if (fail())
                        msac_dump(c_res, a_res, &s_c, &s_a, cdf[0], cdf[1], 1);
                    break;
                }
            }
            if (cdf_update)
                bench_new(alternate(&s_c, &s_a), alternate(cdf[0], cdf[1]));
        }
    }
}
179
/* Compare the C and asm versions of msac_decode_bool_equi, decoding from buf
 * until the reference decoder's cnt goes negative. The first mismatch in
 * result or state is reported and ends the run, because the states have
 * diverged and later calls would only repeat the failure. The result is
 * reported by the caller. */
static void check_decode_bool_equi(MsacDSPContext *const c, uint8_t *const buf) {
    MsacContext s_c, s_a;

    declare_func(unsigned, MsacContext *s);
    if (check_func(c->decode_bool_equi, "msac_decode_bool_equi")) {
        dav1d_msac_init(&s_c, buf, BUF_SIZE, 1);
        s_a = s_c;
        while (s_c.cnt >= 0) {
            unsigned c_res = call_ref(&s_c);
            unsigned a_res = call_new(&s_a);
            if (c_res != a_res || msac_cmp(&s_c, &s_a)) {
                if (fail())
                    msac_dump(c_res, a_res, &s_c, &s_a, NULL, NULL, 0);
                break;
            }
        }
        bench_new(alternate(&s_c, &s_a));
    }
}
198
/* Compare the C and asm versions of msac_decode_bool, decoding from buf with
 * a fresh random 15-bit f for every call until the reference decoder's cnt
 * goes negative. The first mismatch in result or state is reported and ends
 * the run, because the states have diverged and later calls would only
 * repeat the failure. Benchmarked with f = 16384. The result is reported by
 * the caller. */
static void check_decode_bool(MsacDSPContext *const c, uint8_t *const buf) {
    MsacContext s_c, s_a;

    declare_func(unsigned, MsacContext *s, unsigned f);
    if (check_func(c->decode_bool, "msac_decode_bool")) {
        dav1d_msac_init(&s_c, buf, BUF_SIZE, 1);
        s_a = s_c;
        while (s_c.cnt >= 0) {
            const unsigned f = rnd() & 0x7fff;
            unsigned c_res = call_ref(&s_c, f);
            unsigned a_res = call_new(&s_a, f);
            if (c_res != a_res || msac_cmp(&s_c, &s_a)) {
                if (fail())
                    msac_dump(c_res, a_res, &s_c, &s_a, NULL, NULL, 0);
                break;
            }
        }
        bench_new(alternate(&s_c, &s_a), 16384);
    }
}
219
/* Run the three boolean-decoder checks on the same input buffer and report
 * them together as "decode_bool". */
static void check_decode_bool_funcs(MsacDSPContext *const c, uint8_t *const buf) {
    check_decode_bool_adapt(c, buf);
    check_decode_bool_equi(c, buf);
    check_decode_bool(c, buf);
    report("decode_bool");
}
226
/* Compare the C and asm versions of msac_decode_hi_tok, once with update
 * flag 0 and once with 1 (dav1d_msac_init() receives the negated flag).
 * Each run starts from a CDF filled by randomize_cdf(cdf, 3) and decodes
 * from buf until the reference decoder's cnt goes negative, checking the
 * result, the state and the whole CDF array after every call. It stops at
 * the first mismatch. The flag-1 case is benchmarked. */
static void check_decode_hi_tok(MsacDSPContext *const c, uint8_t *const buf) {
    MsacContext ref_ctx, asm_ctx;
    /* cdf[0] is used with the C reference, cdf[1] with the asm version. */
    ALIGN_STK_16(uint16_t, cdf, 2, [16]);

    declare_func(unsigned, MsacContext *s, uint16_t *cdf);
    if (check_func(c->decode_hi_tok, "msac_decode_hi_tok")) {
        for (int upd = 0; upd < 2; upd++) {
            dav1d_msac_init(&ref_ctx, buf, BUF_SIZE, !upd);
            asm_ctx = ref_ctx;
            randomize_cdf(cdf[0], 3);
            memcpy(cdf[1], cdf[0], sizeof(*cdf));

            while (ref_ctx.cnt >= 0) {
                const unsigned c_res = call_ref(&ref_ctx, cdf[0]);
                const unsigned a_res = call_new(&asm_ctx, cdf[1]);
                const int mismatch = c_res != a_res ||
                                     msac_cmp(&ref_ctx, &asm_ctx) ||
                                     memcmp(cdf[0], cdf[1], sizeof(*cdf));
                if (mismatch) {
                    if (fail())
                        msac_dump(c_res, a_res, &ref_ctx, &asm_ctx,
                                  cdf[0], cdf[1], 3);
                    break;
                }
            }

            if (upd)
                bench_new(alternate(&ref_ctx, &asm_ctx),
                          alternate(cdf[0], cdf[1]));
        }
    }
    report("decode_hi_tok");
}
255
checkasm_check_msac(void)256 void checkasm_check_msac(void) {
257 MsacDSPContext c;
258 c.decode_symbol_adapt4 = dav1d_msac_decode_symbol_adapt_c;
259 c.decode_symbol_adapt8 = dav1d_msac_decode_symbol_adapt_c;
260 c.decode_symbol_adapt16 = dav1d_msac_decode_symbol_adapt_c;
261 c.decode_bool_adapt = dav1d_msac_decode_bool_adapt_c;
262 c.decode_bool_equi = dav1d_msac_decode_bool_equi_c;
263 c.decode_bool = dav1d_msac_decode_bool_c;
264 c.decode_hi_tok = dav1d_msac_decode_hi_tok_c;
265
266 #if (ARCH_AARCH64 || ARCH_ARM) && HAVE_ASM
267 if (dav1d_get_cpu_flags() & DAV1D_ARM_CPU_FLAG_NEON) {
268 c.decode_symbol_adapt4 = dav1d_msac_decode_symbol_adapt4_neon;
269 c.decode_symbol_adapt8 = dav1d_msac_decode_symbol_adapt8_neon;
270 c.decode_symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_neon;
271 c.decode_bool_adapt = dav1d_msac_decode_bool_adapt_neon;
272 c.decode_bool_equi = dav1d_msac_decode_bool_equi_neon;
273 c.decode_bool = dav1d_msac_decode_bool_neon;
274 c.decode_hi_tok = dav1d_msac_decode_hi_tok_neon;
275 }
276 #elif ARCH_LOONGARCH64 && HAVE_ASM
277 if (dav1d_get_cpu_flags() & DAV1D_LOONGARCH_CPU_FLAG_LSX) {
278 c.decode_symbol_adapt4 = dav1d_msac_decode_symbol_adapt4_lsx;
279 c.decode_symbol_adapt8 = dav1d_msac_decode_symbol_adapt8_lsx;
280 c.decode_symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_lsx;
281 c.decode_bool_adapt = dav1d_msac_decode_bool_adapt_lsx;
282 c.decode_bool = dav1d_msac_decode_bool_lsx;
283 c.decode_bool_equi = dav1d_msac_decode_bool_equi_lsx;
284 c.decode_hi_tok = dav1d_msac_decode_hi_tok_lsx;
285 }
286 #elif ARCH_X86 && HAVE_ASM
287 if (dav1d_get_cpu_flags() & DAV1D_X86_CPU_FLAG_SSE2) {
288 c.decode_symbol_adapt4 = dav1d_msac_decode_symbol_adapt4_sse2;
289 c.decode_symbol_adapt8 = dav1d_msac_decode_symbol_adapt8_sse2;
290 c.decode_symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_sse2;
291 c.decode_bool_adapt = dav1d_msac_decode_bool_adapt_sse2;
292 c.decode_bool_equi = dav1d_msac_decode_bool_equi_sse2;
293 c.decode_bool = dav1d_msac_decode_bool_sse2;
294 c.decode_hi_tok = dav1d_msac_decode_hi_tok_sse2;
295 }
296
297 #if ARCH_X86_64
298 if (dav1d_get_cpu_flags() & DAV1D_X86_CPU_FLAG_AVX2) {
299 c.decode_symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_avx2;
300 }
301 #endif
302 #endif
303
304 uint8_t buf[BUF_SIZE];
305 for (int i = 0; i < BUF_SIZE; i++)
306 buf[i] = rnd();
307
308 check_decode_symbol(&c, buf);
309 check_decode_bool_funcs(&c, buf);
310 check_decode_hi_tok(&c, buf);
311 }
312