xref: /aosp_15_r20/external/libaom/test/av1_fwd_txfm2d_test.cc (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <math.h>
13 #include <stdio.h>
14 #include <stdlib.h>
15 #include <tuple>
16 #include <vector>
17 
18 #include "config/av1_rtcd.h"
19 
20 #include "test/acm_random.h"
21 #include "test/util.h"
22 #include "test/av1_txfm_test.h"
23 #include "av1/common/av1_txfm.h"
24 #include "av1/encoder/hybrid_fwd_txfm.h"
25 
26 using libaom_test::ACMRandom;
27 using libaom_test::bd;
28 using libaom_test::compute_avg_abs_error;
29 using libaom_test::input_base;
30 using libaom_test::tx_type_name;
31 using libaom_test::TYPE_TXFM;
32 
33 using std::vector;
34 
35 namespace {
36 // tx_type_, tx_size_, max_error_, max_avg_error_
37 typedef std::tuple<TX_TYPE, TX_SIZE, double, double> AV1FwdTxfm2dParam;
38 
39 class AV1FwdTxfm2d : public ::testing::TestWithParam<AV1FwdTxfm2dParam> {
40  public:
SetUp()41   void SetUp() override {
42     tx_type_ = GET_PARAM(0);
43     tx_size_ = GET_PARAM(1);
44     max_error_ = GET_PARAM(2);
45     max_avg_error_ = GET_PARAM(3);
46     count_ = 500;
47     TXFM_2D_FLIP_CFG fwd_txfm_flip_cfg;
48     av1_get_fwd_txfm_cfg(tx_type_, tx_size_, &fwd_txfm_flip_cfg);
49     amplify_factor_ = libaom_test::get_amplification_factor(tx_type_, tx_size_);
50     tx_width_ = tx_size_wide[fwd_txfm_flip_cfg.tx_size];
51     tx_height_ = tx_size_high[fwd_txfm_flip_cfg.tx_size];
52     ud_flip_ = fwd_txfm_flip_cfg.ud_flip;
53     lr_flip_ = fwd_txfm_flip_cfg.lr_flip;
54 
55     fwd_txfm_ = libaom_test::fwd_txfm_func_ls[tx_size_];
56     txfm2d_size_ = tx_width_ * tx_height_;
57     input_ = reinterpret_cast<int16_t *>(
58         aom_memalign(16, sizeof(input_[0]) * txfm2d_size_));
59     ASSERT_NE(input_, nullptr);
60     output_ = reinterpret_cast<int32_t *>(
61         aom_memalign(16, sizeof(output_[0]) * txfm2d_size_));
62     ASSERT_NE(output_, nullptr);
63     ref_input_ = reinterpret_cast<double *>(
64         aom_memalign(16, sizeof(ref_input_[0]) * txfm2d_size_));
65     ASSERT_NE(ref_input_, nullptr);
66     ref_output_ = reinterpret_cast<double *>(
67         aom_memalign(16, sizeof(ref_output_[0]) * txfm2d_size_));
68     ASSERT_NE(ref_output_, nullptr);
69   }
70 
RunFwdAccuracyCheck()71   void RunFwdAccuracyCheck() {
72     ACMRandom rnd(ACMRandom::DeterministicSeed());
73     double avg_abs_error = 0;
74     for (int ci = 0; ci < count_; ci++) {
75       for (int ni = 0; ni < txfm2d_size_; ++ni) {
76         input_[ni] = rnd.Rand16() % input_base;
77         ref_input_[ni] = static_cast<double>(input_[ni]);
78         output_[ni] = 0;
79         ref_output_[ni] = 0;
80       }
81 
82       fwd_txfm_(input_, output_, tx_width_, tx_type_, bd);
83 
84       if (lr_flip_ && ud_flip_) {
85         libaom_test::fliplrud(ref_input_, tx_width_, tx_height_, tx_width_);
86       } else if (lr_flip_) {
87         libaom_test::fliplr(ref_input_, tx_width_, tx_height_, tx_width_);
88       } else if (ud_flip_) {
89         libaom_test::flipud(ref_input_, tx_width_, tx_height_, tx_width_);
90       }
91 
92       libaom_test::reference_hybrid_2d(ref_input_, ref_output_, tx_type_,
93                                        tx_size_);
94 
95       double actual_max_error = 0;
96       for (int ni = 0; ni < txfm2d_size_; ++ni) {
97         ref_output_[ni] = round(ref_output_[ni]);
98         const double this_error =
99             fabs(output_[ni] - ref_output_[ni]) / amplify_factor_;
100         actual_max_error = AOMMAX(actual_max_error, this_error);
101       }
102       EXPECT_GE(max_error_, actual_max_error)
103           << "tx_w: " << tx_width_ << " tx_h: " << tx_height_
104           << ", tx_type = " << (int)tx_type_;
105       if (actual_max_error > max_error_) {  // exit early.
106         break;
107       }
108 
109       avg_abs_error += compute_avg_abs_error<int32_t, double>(
110           output_, ref_output_, txfm2d_size_);
111     }
112 
113     avg_abs_error /= amplify_factor_;
114     avg_abs_error /= count_;
115     EXPECT_GE(max_avg_error_, avg_abs_error)
116         << "tx_size = " << tx_size_ << ", tx_type = " << tx_type_;
117   }
118 
TearDown()119   void TearDown() override {
120     aom_free(input_);
121     aom_free(output_);
122     aom_free(ref_input_);
123     aom_free(ref_output_);
124   }
125 
126  private:
127   double max_error_;
128   double max_avg_error_;
129   int count_;
130   double amplify_factor_;
131   TX_TYPE tx_type_;
132   TX_SIZE tx_size_;
133   int tx_width_;
134   int tx_height_;
135   int txfm2d_size_;
136   FwdTxfm2dFunc fwd_txfm_;
137   int16_t *input_;
138   int32_t *output_;
139   double *ref_input_;
140   double *ref_output_;
141   int ud_flip_;  // flip upside down
142   int lr_flip_;  // flip left to right
143 };
144 
145 static double avg_error_ls[TX_SIZES_ALL] = {
146   0.5,   // 4x4 transform
147   0.5,   // 8x8 transform
148   1.2,   // 16x16 transform
149   6.1,   // 32x32 transform
150   3.4,   // 64x64 transform
151   0.57,  // 4x8 transform
152   0.68,  // 8x4 transform
153   0.92,  // 8x16 transform
154   1.1,   // 16x8 transform
155   4.1,   // 16x32 transform
156   6,     // 32x16 transform
157   3.5,   // 32x64 transform
158   5.7,   // 64x32 transform
159   0.6,   // 4x16 transform
160   0.9,   // 16x4 transform
161   1.2,   // 8x32 transform
162   1.7,   // 32x8 transform
163   2.0,   // 16x64 transform
164   4.7,   // 64x16 transform
165 };
166 
167 static double max_error_ls[TX_SIZES_ALL] = {
168   3,    // 4x4 transform
169   5,    // 8x8 transform
170   11,   // 16x16 transform
171   70,   // 32x32 transform
172   64,   // 64x64 transform
173   3.9,  // 4x8 transform
174   4.3,  // 8x4 transform
175   12,   // 8x16 transform
176   12,   // 16x8 transform
177   32,   // 16x32 transform
178   46,   // 32x16 transform
179   136,  // 32x64 transform
180   136,  // 64x32 transform
181   5,    // 4x16 transform
182   6,    // 16x4 transform
183   21,   // 8x32 transform
184   13,   // 32x8 transform
185   30,   // 16x64 transform
186   36,   // 64x16 transform
187 };
188 
GetTxfm2dParamList()189 vector<AV1FwdTxfm2dParam> GetTxfm2dParamList() {
190   vector<AV1FwdTxfm2dParam> param_list;
191   for (int s = 0; s < TX_SIZES; ++s) {
192     const double max_error = max_error_ls[s];
193     const double avg_error = avg_error_ls[s];
194     for (int t = 0; t < TX_TYPES; ++t) {
195       const TX_TYPE tx_type = static_cast<TX_TYPE>(t);
196       const TX_SIZE tx_size = static_cast<TX_SIZE>(s);
197       if (libaom_test::IsTxSizeTypeValid(tx_size, tx_type)) {
198         param_list.push_back(
199             AV1FwdTxfm2dParam(tx_type, tx_size, max_error, avg_error));
200       }
201     }
202   }
203   return param_list;
204 }
205 
206 INSTANTIATE_TEST_SUITE_P(C, AV1FwdTxfm2d,
207                          ::testing::ValuesIn(GetTxfm2dParamList()));
208 
TEST_P(AV1FwdTxfm2d,RunFwdAccuracyCheck)209 TEST_P(AV1FwdTxfm2d, RunFwdAccuracyCheck) { RunFwdAccuracyCheck(); }
210 
// For every bit depth in libaom_test::bd_arr and every valid
// (tx_size, tx_type) pair, generates the column/row stage ranges of the
// forward transform configuration and passes them to
// libaom_test::txfm_stage_range_check() together with the per-bit-depth
// low/high range limits.
TEST(AV1FwdTxfm2d, CfgTest) {
  for (int depth_idx = 0; depth_idx < BD_NUM; ++depth_idx) {
    const int bit_depth = libaom_test::bd_arr[depth_idx];
    const int8_t range_lo = libaom_test::low_range_arr[depth_idx];
    const int8_t range_hi = libaom_test::high_range_arr[depth_idx];
    for (int size_idx = 0; size_idx < TX_SIZES_ALL; ++size_idx) {
      const TX_SIZE tx_size = static_cast<TX_SIZE>(size_idx);
      for (int type_idx = 0; type_idx < TX_TYPES; ++type_idx) {
        const TX_TYPE tx_type = static_cast<TX_TYPE>(type_idx);
        if (!libaom_test::IsTxSizeTypeValid(tx_size, tx_type)) {
          continue;
        }

        TXFM_2D_FLIP_CFG cfg;
        av1_get_fwd_txfm_cfg(tx_type, tx_size, &cfg);

        int8_t col_stage_range[MAX_TXFM_STAGE_NUM];
        int8_t row_stage_range[MAX_TXFM_STAGE_NUM];
        av1_gen_fwd_stage_range(col_stage_range, row_stage_range, &cfg,
                                bit_depth);

        libaom_test::txfm_stage_range_check(col_stage_range, cfg.stage_num_col,
                                            cfg.cos_bit_col, range_lo,
                                            range_hi);
        libaom_test::txfm_stage_range_check(row_stage_range, cfg.stage_num_row,
                                            cfg.cos_bit_row, range_lo,
                                            range_hi);
      }
    }
  }
}
239 
240 typedef void (*lowbd_fwd_txfm_func)(const int16_t *src_diff, tran_low_t *coeff,
241                                     int diff_stride, TxfmParam *txfm_param);
242 
AV1FwdTxfm2dMatchTest(TX_SIZE tx_size,lowbd_fwd_txfm_func target_func)243 void AV1FwdTxfm2dMatchTest(TX_SIZE tx_size, lowbd_fwd_txfm_func target_func) {
244   const int bd = 8;
245   TxfmParam param;
246   memset(&param, 0, sizeof(param));
247   const int rows = tx_size_high[tx_size];
248   const int cols = tx_size_wide[tx_size];
249   // printf("%d x %d\n", cols, rows);
250   for (int tx_type = 0; tx_type < TX_TYPES; ++tx_type) {
251     if (libaom_test::IsTxSizeTypeValid(
252             tx_size, static_cast<TX_TYPE>(tx_type)) == false) {
253       continue;
254     }
255 
256     FwdTxfm2dFunc ref_func = libaom_test::fwd_txfm_func_ls[tx_size];
257     if (ref_func != nullptr) {
258       DECLARE_ALIGNED(32, int16_t, input[64 * 64]) = { 0 };
259       DECLARE_ALIGNED(32, int32_t, output[64 * 64]);
260       DECLARE_ALIGNED(32, int32_t, ref_output[64 * 64]);
261       int input_stride = 64;
262       ACMRandom rnd(ACMRandom::DeterministicSeed());
263       for (int cnt = 0; cnt < 500; ++cnt) {
264         if (cnt == 0) {
265           for (int c = 0; c < cols; ++c) {
266             for (int r = 0; r < rows; ++r) {
267               input[r * input_stride + c] = (1 << bd) - 1;
268             }
269           }
270         } else {
271           for (int r = 0; r < rows; ++r) {
272             for (int c = 0; c < cols; ++c) {
273               input[r * input_stride + c] = rnd.Rand16() % (1 << bd);
274             }
275           }
276         }
277         param.tx_type = (TX_TYPE)tx_type;
278         param.tx_size = (TX_SIZE)tx_size;
279         param.tx_set_type = EXT_TX_SET_ALL16;
280         param.bd = bd;
281         ref_func(input, ref_output, input_stride, (TX_TYPE)tx_type, bd);
282         target_func(input, output, input_stride, &param);
283         const int check_cols = AOMMIN(32, cols);
284         const int check_rows = AOMMIN(32, rows * cols / check_cols);
285         for (int r = 0; r < check_rows; ++r) {
286           for (int c = 0; c < check_cols; ++c) {
287             ASSERT_EQ(ref_output[r * check_cols + c],
288                       output[r * check_cols + c])
289                 << "[" << r << "," << c << "] cnt:" << cnt
290                 << " tx_size: " << cols << "x" << rows
291                 << " tx_type: " << tx_type_name[tx_type];
292           }
293         }
294       }
295     }
296   }
297 }
298 
AV1FwdTxfm2dSpeedTest(TX_SIZE tx_size,lowbd_fwd_txfm_func target_func)299 void AV1FwdTxfm2dSpeedTest(TX_SIZE tx_size, lowbd_fwd_txfm_func target_func) {
300   TxfmParam param;
301   memset(&param, 0, sizeof(param));
302   const int rows = tx_size_high[tx_size];
303   const int cols = tx_size_wide[tx_size];
304   const int num_loops = 1000000 / (rows * cols);
305 
306   const int bd = 8;
307   for (int tx_type = 0; tx_type < TX_TYPES; ++tx_type) {
308     if (libaom_test::IsTxSizeTypeValid(
309             tx_size, static_cast<TX_TYPE>(tx_type)) == false) {
310       continue;
311     }
312 
313     FwdTxfm2dFunc ref_func = libaom_test::fwd_txfm_func_ls[tx_size];
314     if (ref_func != nullptr) {
315       DECLARE_ALIGNED(32, int16_t, input[64 * 64]) = { 0 };
316       DECLARE_ALIGNED(32, int32_t, output[64 * 64]);
317       DECLARE_ALIGNED(32, int32_t, ref_output[64 * 64]);
318       int input_stride = 64;
319       ACMRandom rnd(ACMRandom::DeterministicSeed());
320 
321       for (int r = 0; r < rows; ++r) {
322         for (int c = 0; c < cols; ++c) {
323           input[r * input_stride + c] = rnd.Rand16() % (1 << bd);
324         }
325       }
326 
327       param.tx_type = (TX_TYPE)tx_type;
328       param.tx_size = (TX_SIZE)tx_size;
329       param.tx_set_type = EXT_TX_SET_ALL16;
330       param.bd = bd;
331 
332       aom_usec_timer ref_timer, test_timer;
333 
334       aom_usec_timer_start(&ref_timer);
335       for (int i = 0; i < num_loops; ++i) {
336         ref_func(input, ref_output, input_stride, (TX_TYPE)tx_type, bd);
337       }
338       aom_usec_timer_mark(&ref_timer);
339       const int elapsed_time_c =
340           static_cast<int>(aom_usec_timer_elapsed(&ref_timer));
341 
342       aom_usec_timer_start(&test_timer);
343       for (int i = 0; i < num_loops; ++i) {
344         target_func(input, output, input_stride, &param);
345       }
346       aom_usec_timer_mark(&test_timer);
347       const int elapsed_time_simd =
348           static_cast<int>(aom_usec_timer_elapsed(&test_timer));
349 
350       printf(
351           "txfm_size[%2dx%-2d] \t txfm_type[%d] \t c_time=%d \t"
352           "simd_time=%d \t gain=%d \n",
353           rows, cols, tx_type, elapsed_time_c, elapsed_time_simd,
354           (elapsed_time_c / elapsed_time_simd));
355     }
356   }
357 }
358 
359 typedef std::tuple<TX_SIZE, lowbd_fwd_txfm_func> LbdFwdTxfm2dParam;
360 
361 class AV1FwdTxfm2dTest : public ::testing::TestWithParam<LbdFwdTxfm2dParam> {};
362 GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AV1FwdTxfm2dTest);
363 
// Bit-exactness check of the parameter's function against the reference.
TEST_P(AV1FwdTxfm2dTest, match) {
  const TX_SIZE tx_size = GET_PARAM(0);
  const lowbd_fwd_txfm_func func_under_test = GET_PARAM(1);
  AV1FwdTxfm2dMatchTest(tx_size, func_under_test);
}

// Timing comparison; disabled by default through the DISABLED_ prefix.
TEST_P(AV1FwdTxfm2dTest, DISABLED_Speed) {
  const TX_SIZE tx_size = GET_PARAM(0);
  const lowbd_fwd_txfm_func func_under_test = GET_PARAM(1);
  AV1FwdTxfm2dSpeedTest(tx_size, func_under_test);
}
// Checks the energy scaling of av1_quick_txfm() with use_hadamard = 0.
// A constant block (every sample = 8) is transformed for the four square
// sizes 4x4..32x32 and the ratio output_sse / input_sse must lie within 5 of
// the expected scale for that size.
TEST(AV1FwdTxfm2dTest, DCTScaleTest) {
  BitDepthInfo bd_info;
  bd_info.bit_depth = 8;
  bd_info.use_highbitdepth_buf = 0;
  DECLARE_ALIGNED(32, int16_t, src_diff[1024]);
  DECLARE_ALIGNED(32, tran_low_t, coeff[1024]);

  const TX_SIZE tx_size_list[4] = { TX_4X4, TX_8X8, TX_16X16, TX_32X32 };
  const int stride_list[4] = { 4, 8, 16, 32 };
  const int ref_scale_list[4] = { 64, 64, 64, 16 };

  for (int i = 0; i < 4; i++) {
    const int stride = stride_list[i];
    const int num_samples = stride * stride;

    // Constant residual in, cleared coefficients out.
    for (int k = 0; k < num_samples; ++k) {
      src_diff[k] = 8;
      coeff[k] = 0;
    }

    av1_quick_txfm(/*use_hadamard=*/0, tx_size_list[i], bd_info, src_diff,
                   stride, coeff);

    double input_sse = 0;
    double output_sse = 0;
    for (int k = 0; k < num_samples; ++k) {
      input_sse += pow(src_diff[k], 2);
      output_sse += pow(coeff[k], 2);
    }

    const double scale = output_sse / input_sse;
    EXPECT_NEAR(scale, ref_scale_list[i], 5);
  }
}
// Checks the energy scaling of av1_quick_txfm() with use_hadamard = 1.
// A constant block (every sample = 8) is transformed for the four square
// sizes 4x4..32x32 and the ratio output_sse / input_sse must lie within 5 of
// the expected scale for that size.
TEST(AV1FwdTxfm2dTest, HadamardScaleTest) {
  BitDepthInfo bd_info;
  bd_info.bit_depth = 8;
  bd_info.use_highbitdepth_buf = 0;
  DECLARE_ALIGNED(32, int16_t, src_diff[1024]);
  DECLARE_ALIGNED(32, tran_low_t, coeff[1024]);

  const TX_SIZE tx_size_list[4] = { TX_4X4, TX_8X8, TX_16X16, TX_32X32 };
  const int stride_list[4] = { 4, 8, 16, 32 };
  const int ref_scale_list[4] = { 1, 64, 64, 16 };

  for (int i = 0; i < 4; i++) {
    const int stride = stride_list[i];
    const int num_samples = stride * stride;

    // Constant residual in, cleared coefficients out.
    for (int k = 0; k < num_samples; ++k) {
      src_diff[k] = 8;
      coeff[k] = 0;
    }

    av1_quick_txfm(/*use_hadamard=*/1, tx_size_list[i], bd_info, src_diff,
                   stride, coeff);

    double input_sse = 0;
    double output_sse = 0;
    for (int k = 0; k < num_samples; ++k) {
      input_sse += pow(src_diff[k], 2);
      output_sse += pow(coeff[k], 2);
    }

    const double scale = output_sse / input_sse;
    EXPECT_NEAR(scale, ref_scale_list[i], 5);
  }
}
442 using ::testing::Combine;
443 using ::testing::Values;
444 using ::testing::ValuesIn;
445 
#if AOM_ARCH_X86 && HAVE_SSE2
// Transform sizes run against av1_lowbd_fwd_txfm_sse2. TX_64X64, TX_32X64 and
// TX_64X32 are disabled (kept below as comments).
static TX_SIZE fwd_txfm_for_sse2[] = {
  TX_4X4,  TX_8X8,  TX_16X16, TX_32X32,
  // TX_64X64,
  TX_4X8,  TX_8X4,  TX_8X16,  TX_16X8,  TX_16X32, TX_32X16,
  // TX_32X64,
  // TX_64X32,
  TX_4X16, TX_16X4, TX_8X32,  TX_32X8,  TX_16X64, TX_64X16,
};

INSTANTIATE_TEST_SUITE_P(SSE2, AV1FwdTxfm2dTest,
                         Combine(ValuesIn(fwd_txfm_for_sse2),
                                 Values(av1_lowbd_fwd_txfm_sse2)));
#endif  // AOM_ARCH_X86 && HAVE_SSE2
473 
#if HAVE_SSE4_1
// Transform sizes run against av1_lowbd_fwd_txfm_sse4_1.
static TX_SIZE fwd_txfm_for_sse41[] = {
  TX_4X4,   TX_8X8,   TX_16X16, TX_32X32, TX_64X64,
  TX_4X8,   TX_8X4,   TX_8X16,  TX_16X8,  TX_16X32,
  TX_32X16, TX_32X64, TX_64X32, TX_4X16,  TX_16X4,
  TX_8X32,  TX_32X8,  TX_16X64, TX_64X16,
};

INSTANTIATE_TEST_SUITE_P(SSE4_1, AV1FwdTxfm2dTest,
                         Combine(ValuesIn(fwd_txfm_for_sse41),
                                 Values(av1_lowbd_fwd_txfm_sse4_1)));
#endif  // HAVE_SSE4_1
485 
#if HAVE_AVX2
// Transform sizes run against av1_lowbd_fwd_txfm_avx2.
static TX_SIZE fwd_txfm_for_avx2[] = {
  TX_4X4,   TX_8X8,   TX_16X16, TX_32X32, TX_64X64,
  TX_4X8,   TX_8X4,   TX_8X16,  TX_16X8,  TX_16X32,
  TX_32X16, TX_32X64, TX_64X32, TX_4X16,  TX_16X4,
  TX_8X32,  TX_32X8,  TX_16X64, TX_64X16,
};

INSTANTIATE_TEST_SUITE_P(AVX2, AV1FwdTxfm2dTest,
                         Combine(ValuesIn(fwd_txfm_for_avx2),
                                 Values(av1_lowbd_fwd_txfm_avx2)));
#endif  // HAVE_AVX2
497 
#if HAVE_NEON

// Transform sizes run against av1_lowbd_fwd_txfm_neon.
static TX_SIZE fwd_txfm_for_neon[] = {
  TX_4X4,   TX_8X8,   TX_16X16, TX_32X32, TX_64X64,
  TX_4X8,   TX_8X4,   TX_8X16,  TX_16X8,  TX_16X32,
  TX_32X16, TX_32X64, TX_64X32, TX_4X16,  TX_16X4,
  TX_8X32,  TX_32X8,  TX_16X64, TX_64X16
};

INSTANTIATE_TEST_SUITE_P(NEON, AV1FwdTxfm2dTest,
                         Combine(ValuesIn(fwd_txfm_for_neon),
                                 Values(av1_lowbd_fwd_txfm_neon)));

#endif  // HAVE_NEON
511 
512 typedef void (*Highbd_fwd_txfm_func)(const int16_t *src_diff, tran_low_t *coeff,
513                                      int diff_stride, TxfmParam *txfm_param);
514 
AV1HighbdFwdTxfm2dMatchTest(TX_SIZE tx_size,Highbd_fwd_txfm_func target_func)515 void AV1HighbdFwdTxfm2dMatchTest(TX_SIZE tx_size,
516                                  Highbd_fwd_txfm_func target_func) {
517   const int bd_ar[2] = { 10, 12 };
518   TxfmParam param;
519   memset(&param, 0, sizeof(param));
520   const int rows = tx_size_high[tx_size];
521   const int cols = tx_size_wide[tx_size];
522   for (int i = 0; i < 2; ++i) {
523     const int bd = bd_ar[i];
524     for (int tx_type = 0; tx_type < TX_TYPES; ++tx_type) {
525       if (libaom_test::IsTxSizeTypeValid(
526               tx_size, static_cast<TX_TYPE>(tx_type)) == false) {
527         continue;
528       }
529 
530       FwdTxfm2dFunc ref_func = libaom_test::fwd_txfm_func_ls[tx_size];
531       if (ref_func != nullptr) {
532         DECLARE_ALIGNED(32, int16_t, input[64 * 64]) = { 0 };
533         DECLARE_ALIGNED(32, int32_t, output[64 * 64]);
534         DECLARE_ALIGNED(32, int32_t, ref_output[64 * 64]);
535         int input_stride = 64;
536         ACMRandom rnd(ACMRandom::DeterministicSeed());
537         for (int cnt = 0; cnt < 500; ++cnt) {
538           if (cnt == 0) {
539             for (int r = 0; r < rows; ++r) {
540               for (int c = 0; c < cols; ++c) {
541                 input[r * input_stride + c] = (1 << bd) - 1;
542               }
543             }
544           } else {
545             for (int r = 0; r < rows; ++r) {
546               for (int c = 0; c < cols; ++c) {
547                 input[r * input_stride + c] = rnd.Rand16() % (1 << bd);
548               }
549             }
550           }
551           param.tx_type = (TX_TYPE)tx_type;
552           param.tx_size = (TX_SIZE)tx_size;
553           param.tx_set_type = EXT_TX_SET_ALL16;
554           param.bd = bd;
555 
556           ref_func(input, ref_output, input_stride, (TX_TYPE)tx_type, bd);
557           target_func(input, output, input_stride, &param);
558           const int check_cols = AOMMIN(32, cols);
559           const int check_rows = AOMMIN(32, rows * cols / check_cols);
560           for (int r = 0; r < check_rows; ++r) {
561             for (int c = 0; c < check_cols; ++c) {
562               ASSERT_EQ(ref_output[c * check_rows + r],
563                         output[c * check_rows + r])
564                   << "[" << r << "," << c << "] cnt:" << cnt
565                   << " tx_size: " << cols << "x" << rows
566                   << " tx_type: " << tx_type;
567             }
568           }
569         }
570       }
571     }
572   }
573 }
574 
AV1HighbdFwdTxfm2dSpeedTest(TX_SIZE tx_size,Highbd_fwd_txfm_func target_func)575 void AV1HighbdFwdTxfm2dSpeedTest(TX_SIZE tx_size,
576                                  Highbd_fwd_txfm_func target_func) {
577   const int bd_ar[2] = { 10, 12 };
578   TxfmParam param;
579   memset(&param, 0, sizeof(param));
580   const int rows = tx_size_high[tx_size];
581   const int cols = tx_size_wide[tx_size];
582   const int num_loops = 1000000 / (rows * cols);
583 
584   for (int i = 0; i < 2; ++i) {
585     const int bd = bd_ar[i];
586     for (int tx_type = 0; tx_type < TX_TYPES; ++tx_type) {
587       if (libaom_test::IsTxSizeTypeValid(
588               tx_size, static_cast<TX_TYPE>(tx_type)) == false) {
589         continue;
590       }
591 
592       FwdTxfm2dFunc ref_func = libaom_test::fwd_txfm_func_ls[tx_size];
593       if (ref_func != nullptr) {
594         DECLARE_ALIGNED(32, int16_t, input[64 * 64]) = { 0 };
595         DECLARE_ALIGNED(32, int32_t, output[64 * 64]);
596         DECLARE_ALIGNED(32, int32_t, ref_output[64 * 64]);
597         int input_stride = 64;
598         ACMRandom rnd(ACMRandom::DeterministicSeed());
599 
600         for (int r = 0; r < rows; ++r) {
601           for (int c = 0; c < cols; ++c) {
602             input[r * input_stride + c] = rnd.Rand16() % (1 << bd);
603           }
604         }
605 
606         param.tx_type = (TX_TYPE)tx_type;
607         param.tx_size = (TX_SIZE)tx_size;
608         param.tx_set_type = EXT_TX_SET_ALL16;
609         param.bd = bd;
610 
611         aom_usec_timer ref_timer, test_timer;
612 
613         aom_usec_timer_start(&ref_timer);
614         for (int j = 0; j < num_loops; ++j) {
615           ref_func(input, ref_output, input_stride, (TX_TYPE)tx_type, bd);
616         }
617         aom_usec_timer_mark(&ref_timer);
618         const int elapsed_time_c =
619             static_cast<int>(aom_usec_timer_elapsed(&ref_timer));
620 
621         aom_usec_timer_start(&test_timer);
622         for (int j = 0; j < num_loops; ++j) {
623           target_func(input, output, input_stride, &param);
624         }
625         aom_usec_timer_mark(&test_timer);
626         const int elapsed_time_simd =
627             static_cast<int>(aom_usec_timer_elapsed(&test_timer));
628 
629         printf(
630             "txfm_size[%2dx%-2d] \t txfm_type[%d] \t c_time=%d \t"
631             "simd_time=%d \t gain=%d \n",
632             cols, rows, tx_type, elapsed_time_c, elapsed_time_simd,
633             (elapsed_time_c / elapsed_time_simd));
634       }
635     }
636   }
637 }
638 
639 typedef std::tuple<TX_SIZE, Highbd_fwd_txfm_func> HighbdFwdTxfm2dParam;
640 
641 class AV1HighbdFwdTxfm2dTest
642     : public ::testing::TestWithParam<HighbdFwdTxfm2dParam> {};
643 GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AV1HighbdFwdTxfm2dTest);
644 
// Bit-exactness check of the parameter's function against the reference.
TEST_P(AV1HighbdFwdTxfm2dTest, match) {
  const TX_SIZE tx_size = GET_PARAM(0);
  const Highbd_fwd_txfm_func func_under_test = GET_PARAM(1);
  AV1HighbdFwdTxfm2dMatchTest(tx_size, func_under_test);
}

// Timing comparison; disabled by default through the DISABLED_ prefix.
TEST_P(AV1HighbdFwdTxfm2dTest, DISABLED_Speed) {
  const TX_SIZE tx_size = GET_PARAM(0);
  const Highbd_fwd_txfm_func func_under_test = GET_PARAM(1);
  AV1HighbdFwdTxfm2dSpeedTest(tx_size, func_under_test);
}

652 
653 using ::testing::Combine;
654 using ::testing::Values;
655 using ::testing::ValuesIn;
656 
#if HAVE_SSE4_1
// Transform sizes run through av1_highbd_fwd_txfm when SSE4.1 is available.
// The 4:1 / 1:4 sizes are only listed when CONFIG_REALTIME_ONLY is off.
static TX_SIZE Highbd_fwd_txfm_for_sse4_1[] = {
  TX_4X4,   TX_8X8,   TX_16X16, TX_32X32, TX_64X64,
  TX_4X8,   TX_8X4,   TX_8X16,  TX_16X8,  TX_16X32,
  TX_32X16, TX_32X64, TX_64X32,
#if !CONFIG_REALTIME_ONLY
  TX_4X16,  TX_16X4,  TX_8X32,  TX_32X8,  TX_16X64,
  TX_64X16,
#endif  // !CONFIG_REALTIME_ONLY
};

INSTANTIATE_TEST_SUITE_P(SSE4_1, AV1HighbdFwdTxfm2dTest,
                         Combine(ValuesIn(Highbd_fwd_txfm_for_sse4_1),
                                 Values(av1_highbd_fwd_txfm)));
#endif  // HAVE_SSE4_1
#if HAVE_AVX2
// Transform sizes run through av1_highbd_fwd_txfm when AVX2 is available.
static TX_SIZE Highbd_fwd_txfm_for_avx2[] = {
  TX_8X8,  TX_16X16, TX_32X32, TX_64X64,
  TX_8X16, TX_16X8,
};

INSTANTIATE_TEST_SUITE_P(AVX2, AV1HighbdFwdTxfm2dTest,
                         Combine(ValuesIn(Highbd_fwd_txfm_for_avx2),
                                 Values(av1_highbd_fwd_txfm)));
#endif  // HAVE_AVX2
678 
#if HAVE_NEON
// Transform sizes run through av1_highbd_fwd_txfm when NEON is available.
// The 4:1 / 1:4 sizes are only listed when CONFIG_REALTIME_ONLY is off.
static TX_SIZE Highbd_fwd_txfm_for_neon[] = {
  TX_4X4,   TX_8X8,   TX_16X16, TX_32X32, TX_64X64,
  TX_4X8,   TX_8X4,   TX_8X16,  TX_16X8,  TX_16X32,
  TX_32X16, TX_32X64, TX_64X32,
#if !CONFIG_REALTIME_ONLY
  TX_4X16,  TX_16X4,  TX_8X32,  TX_32X8,  TX_16X64,
  TX_64X16
#endif  // !CONFIG_REALTIME_ONLY
};

INSTANTIATE_TEST_SUITE_P(NEON, AV1HighbdFwdTxfm2dTest,
                         Combine(ValuesIn(Highbd_fwd_txfm_for_neon),
                                 Values(av1_highbd_fwd_txfm)));
#endif  // HAVE_NEON
692 
693 }  // namespace
694