xref: /aosp_15_r20/external/libaom/test/av1_txfm_test.cc (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include "test/av1_txfm_test.h"
13 
14 #include <stdio.h>
15 
16 #include <memory>
17 #include <new>
18 
19 namespace libaom_test {
20 
// Printable names of the 16 2-D transform types. Presumably indexed by the
// TX_TYPE enum value, so this order must track that enum (verify on changes).
const char *tx_type_name[] = {
  "DCT_DCT",      "ADST_DCT",     "DCT_ADST",          "ADST_ADST",
  "FLIPADST_DCT", "DCT_FLIPADST", "FLIPADST_FLIPADST", "ADST_FLIPADST",
  "FLIPADST_ADST", "IDTX",        "V_DCT",             "H_DCT",
  "V_ADST",       "H_ADST",       "V_FLIPADST",        "H_FLIPADST",
};
39 
get_txfm1d_size(TX_SIZE tx_size)40 int get_txfm1d_size(TX_SIZE tx_size) { return tx_size_wide[tx_size]; }
41 
get_txfm1d_type(TX_TYPE txfm2d_type,TYPE_TXFM * type0,TYPE_TXFM * type1)42 void get_txfm1d_type(TX_TYPE txfm2d_type, TYPE_TXFM *type0, TYPE_TXFM *type1) {
43   switch (txfm2d_type) {
44     case DCT_DCT:
45       *type0 = TYPE_DCT;
46       *type1 = TYPE_DCT;
47       break;
48     case ADST_DCT:
49       *type0 = TYPE_ADST;
50       *type1 = TYPE_DCT;
51       break;
52     case DCT_ADST:
53       *type0 = TYPE_DCT;
54       *type1 = TYPE_ADST;
55       break;
56     case ADST_ADST:
57       *type0 = TYPE_ADST;
58       *type1 = TYPE_ADST;
59       break;
60     case FLIPADST_DCT:
61       *type0 = TYPE_ADST;
62       *type1 = TYPE_DCT;
63       break;
64     case DCT_FLIPADST:
65       *type0 = TYPE_DCT;
66       *type1 = TYPE_ADST;
67       break;
68     case FLIPADST_FLIPADST:
69       *type0 = TYPE_ADST;
70       *type1 = TYPE_ADST;
71       break;
72     case ADST_FLIPADST:
73       *type0 = TYPE_ADST;
74       *type1 = TYPE_ADST;
75       break;
76     case FLIPADST_ADST:
77       *type0 = TYPE_ADST;
78       *type1 = TYPE_ADST;
79       break;
80     case IDTX:
81       *type0 = TYPE_IDTX;
82       *type1 = TYPE_IDTX;
83       break;
84     case H_DCT:
85       *type0 = TYPE_IDTX;
86       *type1 = TYPE_DCT;
87       break;
88     case V_DCT:
89       *type0 = TYPE_DCT;
90       *type1 = TYPE_IDTX;
91       break;
92     case H_ADST:
93       *type0 = TYPE_IDTX;
94       *type1 = TYPE_ADST;
95       break;
96     case V_ADST:
97       *type0 = TYPE_ADST;
98       *type1 = TYPE_IDTX;
99       break;
100     case H_FLIPADST:
101       *type0 = TYPE_IDTX;
102       *type1 = TYPE_ADST;
103       break;
104     case V_FLIPADST:
105       *type0 = TYPE_ADST;
106       *type1 = TYPE_IDTX;
107       break;
108     default:
109       *type0 = TYPE_DCT;
110       *type1 = TYPE_DCT;
111       assert(0);
112       break;
113   }
114 }
115 
// sqrt(2) and 1/sqrt(2), used as scale factors by the reference transforms.
// Both are computed directly from pow() (not from each other) so each can be
// initialized independently.
double Sqrt2 = pow(2.0, 0.5);
double invSqrt2 = 1.0 / pow(2.0, 0.5);
118 
dct_matrix(double n,double k,int size)119 static double dct_matrix(double n, double k, int size) {
120   return cos(PI * (2 * n + 1) * k / (2 * size));
121 }
122 
reference_dct_1d(const double * in,double * out,int size)123 void reference_dct_1d(const double *in, double *out, int size) {
124   for (int k = 0; k < size; ++k) {
125     out[k] = 0;
126     for (int n = 0; n < size; ++n) {
127       out[k] += in[n] * dct_matrix(n, k, size);
128     }
129     if (k == 0) out[k] = out[k] * invSqrt2;
130   }
131 }
132 
reference_idct_1d(const double * in,double * out,int size)133 void reference_idct_1d(const double *in, double *out, int size) {
134   for (int k = 0; k < size; ++k) {
135     out[k] = 0;
136     for (int n = 0; n < size; ++n) {
137       if (n == 0)
138         out[k] += invSqrt2 * in[n] * dct_matrix(k, n, size);
139       else
140         out[k] += in[n] * dct_matrix(k, n, size);
141     }
142   }
143 }
144 
145 // TODO(any): Copied from the old 'fadst4' (same as the new 'av1_fadst4'
146 // function). Should be replaced by a proper reference function that takes
147 // 'double' input & output.
fadst4_new(const tran_low_t * input,tran_low_t * output)148 static void fadst4_new(const tran_low_t *input, tran_low_t *output) {
149   tran_high_t x0, x1, x2, x3;
150   tran_high_t s0, s1, s2, s3, s4, s5, s6, s7;
151 
152   x0 = input[0];
153   x1 = input[1];
154   x2 = input[2];
155   x3 = input[3];
156 
157   if (!(x0 | x1 | x2 | x3)) {
158     output[0] = output[1] = output[2] = output[3] = 0;
159     return;
160   }
161 
162   s0 = sinpi_1_9 * x0;
163   s1 = sinpi_4_9 * x0;
164   s2 = sinpi_2_9 * x1;
165   s3 = sinpi_1_9 * x1;
166   s4 = sinpi_3_9 * x2;
167   s5 = sinpi_4_9 * x3;
168   s6 = sinpi_2_9 * x3;
169   s7 = x0 + x1 - x3;
170 
171   x0 = s0 + s2 + s5;
172   x1 = sinpi_3_9 * s7;
173   x2 = s1 - s3 + s6;
174   x3 = s4;
175 
176   s0 = x0 + x3;
177   s1 = x1;
178   s2 = x2 - x3;
179   s3 = x2 - x0 + x3;
180 
181   // 1-D transform scaling factor is sqrt(2).
182   output[0] = (tran_low_t)fdct_round_shift(s0);
183   output[1] = (tran_low_t)fdct_round_shift(s1);
184   output[2] = (tran_low_t)fdct_round_shift(s2);
185   output[3] = (tran_low_t)fdct_round_shift(s3);
186 }
187 
reference_adst_1d(const double * in,double * out,int size)188 void reference_adst_1d(const double *in, double *out, int size) {
189   if (size == 4) {  // Special case.
190     tran_low_t int_input[4];
191     for (int i = 0; i < 4; ++i) {
192       int_input[i] = static_cast<tran_low_t>(round(in[i]));
193     }
194     tran_low_t int_output[4];
195     fadst4_new(int_input, int_output);
196     for (int i = 0; i < 4; ++i) {
197       out[i] = int_output[i];
198     }
199     return;
200   }
201 
202   for (int k = 0; k < size; ++k) {
203     out[k] = 0;
204     for (int n = 0; n < size; ++n) {
205       out[k] += in[n] * sin(PI * (2 * n + 1) * (2 * k + 1) / (4 * size));
206     }
207   }
208 }
209 
reference_idtx_1d(const double * in,double * out,int size)210 static void reference_idtx_1d(const double *in, double *out, int size) {
211   double scale = 0;
212   if (size == 4)
213     scale = Sqrt2;
214   else if (size == 8)
215     scale = 2;
216   else if (size == 16)
217     scale = 2 * Sqrt2;
218   else if (size == 32)
219     scale = 4;
220   else if (size == 64)
221     scale = 4 * Sqrt2;
222   for (int k = 0; k < size; ++k) {
223     out[k] = in[k] * scale;
224   }
225 }
226 
reference_hybrid_1d(double * in,double * out,int size,int type)227 void reference_hybrid_1d(double *in, double *out, int size, int type) {
228   if (type == TYPE_DCT)
229     reference_dct_1d(in, out, size);
230   else if (type == TYPE_ADST)
231     reference_adst_1d(in, out, size);
232   else
233     reference_idtx_1d(in, out, size);
234 }
235 
get_amplification_factor(TX_TYPE tx_type,TX_SIZE tx_size)236 double get_amplification_factor(TX_TYPE tx_type, TX_SIZE tx_size) {
237   TXFM_2D_FLIP_CFG fwd_txfm_flip_cfg;
238   av1_get_fwd_txfm_cfg(tx_type, tx_size, &fwd_txfm_flip_cfg);
239   const int tx_width = tx_size_wide[fwd_txfm_flip_cfg.tx_size];
240   const int tx_height = tx_size_high[fwd_txfm_flip_cfg.tx_size];
241   const int8_t *shift = fwd_txfm_flip_cfg.shift;
242   const int amplify_bit = shift[0] + shift[1] + shift[2];
243   double amplify_factor =
244       amplify_bit >= 0 ? (1 << amplify_bit) : (1.0 / (1 << -amplify_bit));
245 
246   // For rectangular transforms, we need to multiply by an extra factor.
247   const int rect_type = get_rect_tx_log_ratio(tx_width, tx_height);
248   if (abs(rect_type) == 1) {
249     amplify_factor *= pow(2, 0.5);
250   }
251   return amplify_factor;
252 }
253 
reference_hybrid_2d(double * in,double * out,TX_TYPE tx_type,TX_SIZE tx_size)254 void reference_hybrid_2d(double *in, double *out, TX_TYPE tx_type,
255                          TX_SIZE tx_size) {
256   // Get transform type and size of each dimension.
257   TYPE_TXFM type0;
258   TYPE_TXFM type1;
259   get_txfm1d_type(tx_type, &type0, &type1);
260   const int tx_width = tx_size_wide[tx_size];
261   const int tx_height = tx_size_high[tx_size];
262 
263   std::unique_ptr<double[]> temp_in(
264       new (std::nothrow) double[AOMMAX(tx_width, tx_height)]);
265   std::unique_ptr<double[]> temp_out(
266       new (std::nothrow) double[AOMMAX(tx_width, tx_height)]);
267   std::unique_ptr<double[]> out_interm(
268       new (std::nothrow) double[tx_width * tx_height]);
269   ASSERT_NE(temp_in, nullptr);
270   ASSERT_NE(temp_out, nullptr);
271   ASSERT_NE(out_interm, nullptr);
272 
273   // Transform columns.
274   for (int c = 0; c < tx_width; ++c) {
275     for (int r = 0; r < tx_height; ++r) {
276       temp_in[r] = in[r * tx_width + c];
277     }
278     reference_hybrid_1d(temp_in.get(), temp_out.get(), tx_height, type0);
279     for (int r = 0; r < tx_height; ++r) {
280       out_interm[r * tx_width + c] = temp_out[r];
281     }
282   }
283 
284   // Transform rows.
285   for (int r = 0; r < tx_height; ++r) {
286     reference_hybrid_1d(out_interm.get() + r * tx_width, temp_out.get(),
287                         tx_width, type1);
288     for (int c = 0; c < tx_width; ++c) {
289       out[c * tx_height + r] = temp_out[c];
290     }
291   }
292 
293   // These transforms use an approximate 2D DCT transform, by only keeping the
294   // top-left quarter of the coefficients, and repacking them in the first
295   // quarter indices.
296   // TODO(urvang): Refactor this code.
297   if (tx_width == 64 && tx_height == 64) {  // tx_size == TX_64X64
298     // Zero out top-right 32x32 area.
299     for (int col = 0; col < 32; ++col) {
300       memset(out + col * 64 + 32, 0, 32 * sizeof(*out));
301     }
302     // Zero out the bottom 64x32 area.
303     memset(out + 32 * 64, 0, 32 * 64 * sizeof(*out));
304     // Re-pack non-zero coeffs in the first 32x32 indices.
305     for (int col = 1; col < 32; ++col) {
306       memcpy(out + col * 32, out + col * 64, 32 * sizeof(*out));
307     }
308   } else if (tx_width == 32 && tx_height == 64) {  // tx_size == TX_32X64
309     // Zero out right 32x32 area.
310     for (int col = 0; col < 32; ++col) {
311       memset(out + col * 64 + 32, 0, 32 * sizeof(*out));
312     }
313     // Re-pack non-zero coeffs in the first 32x32 indices.
314     for (int col = 1; col < 32; ++col) {
315       memcpy(out + col * 32, out + col * 64, 32 * sizeof(*out));
316     }
317   } else if (tx_width == 64 && tx_height == 32) {  // tx_size == TX_64X32
318     // Zero out the bottom 32x32 area.
319     memset(out + 32 * 32, 0, 32 * 32 * sizeof(*out));
320     // Note: no repacking needed here.
321   } else if (tx_width == 16 && tx_height == 64) {  // tx_size == TX_16X64
322     // Note: no repacking needed here.
323     // Zero out right 32x16 area.
324     for (int col = 0; col < 16; ++col) {
325       memset(out + col * 64 + 32, 0, 32 * sizeof(*out));
326     }
327     // Re-pack non-zero coeffs in the first 32x16 indices.
328     for (int col = 1; col < 16; ++col) {
329       memcpy(out + col * 32, out + col * 64, 32 * sizeof(*out));
330     }
331   } else if (tx_width == 64 && tx_height == 16) {  // tx_size == TX_64X16
332     // Zero out the bottom 16x32 area.
333     memset(out + 16 * 32, 0, 16 * 32 * sizeof(*out));
334   }
335 
336   // Apply appropriate scale.
337   const double amplify_factor = get_amplification_factor(tx_type, tx_size);
338   for (int c = 0; c < tx_width; ++c) {
339     for (int r = 0; r < tx_height; ++r) {
340       out[c * tx_height + r] *= amplify_factor;
341     }
342   }
343 }
344 
// Mirrors |dest| left-to-right in place: within each of the |height| rows
// (rows start |stride| elements apart), the first |width| elements are
// reversed.
template <typename Type>
void fliplr(Type *dest, int width, int height, int stride) {
  for (int r = 0; r < height; ++r) {
    Type *const row = dest + r * stride;
    for (int left = 0, right = width - 1; left < right; ++left, --right) {
      const Type saved = row[left];
      row[left] = row[right];
      row[right] = saved;
    }
  }
}
355 
// Mirrors |dest| top-to-bottom in place: for each of the first |width|
// columns, the |height| entries (rows |stride| apart) are reversed.
template <typename Type>
void flipud(Type *dest, int width, int height, int stride) {
  for (int c = 0; c < width; ++c) {
    Type *const col = dest + c;
    for (int top = 0, bottom = height - 1; top < bottom; ++top, --bottom) {
      const Type saved = col[top * stride];
      col[top * stride] = col[bottom * stride];
      col[bottom * stride] = saved;
    }
  }
}
366 
// Rotates the |width| x |height| block at |dest| by 180 degrees in place
// (equivalent to fliplr followed by flipud). Rows start |stride| elements
// apart.
template <typename Type>
void fliplrud(Type *dest, int width, int height, int stride) {
  // Swap each element of the top half with its point-mirrored partner in the
  // bottom half.
  for (int r = 0; r < height / 2; ++r) {
    for (int c = 0; c < width; ++c) {
      const Type tmp = dest[r * stride + c];
      dest[r * stride + c] = dest[(height - 1 - r) * stride + width - 1 - c];
      dest[(height - 1 - r) * stride + width - 1 - c] = tmp;
    }
  }
  // With an odd height, the middle row maps onto itself. The loop above never
  // touches it, but a 180-degree rotation still has to reverse it
  // horizontally.
  if (height > 0 && height % 2 == 1) {
    Type *const mid = dest + (height / 2) * stride;
    for (int c = 0; c < width / 2; ++c) {
      const Type tmp = mid[c];
      mid[c] = mid[width - 1 - c];
      mid[width - 1 - c] = tmp;
    }
  }
}
377 
378 template void fliplr<double>(double *dest, int width, int height, int stride);
379 template void flipud<double>(double *dest, int width, int height, int stride);
380 template void fliplrud<double>(double *dest, int width, int height, int stride);
381 
382 int bd_arr[BD_NUM] = { 8, 10, 12 };
383 
384 int8_t low_range_arr[BD_NUM] = { 18, 32, 32 };
385 int8_t high_range_arr[BD_NUM] = { 32, 32, 32 };
386 
txfm_stage_range_check(const int8_t * stage_range,int stage_num,int8_t cos_bit,int low_range,int high_range)387 void txfm_stage_range_check(const int8_t *stage_range, int stage_num,
388                             int8_t cos_bit, int low_range, int high_range) {
389   for (int i = 0; i < stage_num; ++i) {
390     EXPECT_LE(stage_range[i], low_range);
391     ASSERT_LE(stage_range[i] + cos_bit, high_range) << "stage = " << i;
392   }
393   for (int i = 0; i < stage_num - 1; ++i) {
394     // make sure there is no overflow while doing half_btf()
395     ASSERT_LE(stage_range[i + 1] + cos_bit, high_range) << "stage = " << i;
396   }
397 }
398 }  // namespace libaom_test
399