xref: /aosp_15_r20/external/libaom/test/transform_test_base.h (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
11 
12 #ifndef AOM_TEST_TRANSFORM_TEST_BASE_H_
13 #define AOM_TEST_TRANSFORM_TEST_BASE_H_
14 
#include <memory>

#include "gtest/gtest.h"

#include "aom/aom_codec.h"
#include "aom_dsp/txfm_common.h"
#include "aom_mem/aom_mem.h"
#include "test/acm_random.h"
21 
22 namespace libaom_test {
23 
//  Note:
//   The same constant is defined in av1/common/av1_entropy.h and
//   av1/common/entropy.h. The goal is to make this base class usable
//   for future codec transform testing, but including either of those
//   headers would lead to compile errors when unit-testing another
//   codec. Suggest moving the definition to an aom header file.
// Coefficient magnitude bound used by TransformTestBase::RunMemCheck(), which
// expects |coeff| <= (N * kDctMaxValue) << (bit_depth - 8) for an NxN
// transform.
constexpr int kDctMaxValue = 16384;
32 
33 template <typename OutputType>
34 using FhtFunc = void (*)(const int16_t *in, OutputType *out, int stride,
35                          TxfmParam *txfm_param);
36 
37 template <typename OutputType>
38 using IhtFunc = void (*)(const tran_low_t *in, uint8_t *out, int stride,
39                          const TxfmParam *txfm_param);
40 
41 template <typename OutType>
42 class TransformTestBase {
43  public:
44   virtual ~TransformTestBase() = default;
45 
46  protected:
47   virtual void RunFwdTxfm(const int16_t *in, OutType *out, int stride) = 0;
48 
49   virtual void RunInvTxfm(const OutType *out, uint8_t *dst, int stride) = 0;
50 
RunAccuracyCheck(uint32_t ref_max_error,double ref_avg_error)51   void RunAccuracyCheck(uint32_t ref_max_error, double ref_avg_error) {
52     ACMRandom rnd(ACMRandom::DeterministicSeed());
53     uint32_t max_error = 0;
54     int64_t total_error = 0;
55     const int count_test_block = 10000;
56 
57     int16_t *test_input_block = reinterpret_cast<int16_t *>(
58         aom_memalign(16, sizeof(int16_t) * num_coeffs_));
59     ASSERT_NE(test_input_block, nullptr);
60     OutType *test_temp_block = reinterpret_cast<OutType *>(
61         aom_memalign(16, sizeof(test_temp_block[0]) * num_coeffs_));
62     ASSERT_NE(test_temp_block, nullptr);
63     uint8_t *dst = reinterpret_cast<uint8_t *>(
64         aom_memalign(16, sizeof(uint8_t) * num_coeffs_));
65     ASSERT_NE(dst, nullptr);
66     uint8_t *src = reinterpret_cast<uint8_t *>(
67         aom_memalign(16, sizeof(uint8_t) * num_coeffs_));
68     ASSERT_NE(src, nullptr);
69     uint16_t *dst16 = reinterpret_cast<uint16_t *>(
70         aom_memalign(16, sizeof(uint16_t) * num_coeffs_));
71     ASSERT_NE(dst16, nullptr);
72     uint16_t *src16 = reinterpret_cast<uint16_t *>(
73         aom_memalign(16, sizeof(uint16_t) * num_coeffs_));
74     ASSERT_NE(src16, nullptr);
75 
76     for (int i = 0; i < count_test_block; ++i) {
77       // Initialize a test block with input range [-255, 255].
78       for (int j = 0; j < num_coeffs_; ++j) {
79         if (bit_depth_ == AOM_BITS_8) {
80           src[j] = rnd.Rand8();
81           dst[j] = rnd.Rand8();
82           test_input_block[j] = src[j] - dst[j];
83         } else {
84           src16[j] = rnd.Rand16() & mask_;
85           dst16[j] = rnd.Rand16() & mask_;
86           test_input_block[j] = src16[j] - dst16[j];
87         }
88       }
89 
90       API_REGISTER_STATE_CHECK(
91           RunFwdTxfm(test_input_block, test_temp_block, pitch_));
92       if (bit_depth_ == AOM_BITS_8) {
93         API_REGISTER_STATE_CHECK(RunInvTxfm(test_temp_block, dst, pitch_));
94       } else {
95         API_REGISTER_STATE_CHECK(
96             RunInvTxfm(test_temp_block, CONVERT_TO_BYTEPTR(dst16), pitch_));
97       }
98 
99       for (int j = 0; j < num_coeffs_; ++j) {
100         const int diff =
101             bit_depth_ == AOM_BITS_8 ? dst[j] - src[j] : dst16[j] - src16[j];
102         const uint32_t error = diff * diff;
103         if (max_error < error) max_error = error;
104         total_error += error;
105       }
106     }
107 
108     double avg_error = total_error * 1. / count_test_block / num_coeffs_;
109 
110     EXPECT_GE(ref_max_error, max_error)
111         << "Error: FHT/IHT has an individual round trip error > "
112         << ref_max_error;
113 
114     EXPECT_GE(ref_avg_error, avg_error)
115         << "Error: FHT/IHT has average round trip error > " << ref_avg_error
116         << " per block";
117 
118     aom_free(test_input_block);
119     aom_free(test_temp_block);
120     aom_free(dst);
121     aom_free(src);
122     aom_free(dst16);
123     aom_free(src16);
124   }
125 
RunCoeffCheck()126   void RunCoeffCheck() {
127     ACMRandom rnd(ACMRandom::DeterministicSeed());
128     const int count_test_block = 5000;
129 
130     // Use a stride value which is not the width of any transform, to catch
131     // cases where the transforms use the stride incorrectly.
132     int stride = 96;
133 
134     int16_t *input_block = reinterpret_cast<int16_t *>(
135         aom_memalign(16, sizeof(int16_t) * stride * height_));
136     ASSERT_NE(input_block, nullptr);
137     OutType *output_ref_block = reinterpret_cast<OutType *>(
138         aom_memalign(16, sizeof(output_ref_block[0]) * num_coeffs_));
139     ASSERT_NE(output_ref_block, nullptr);
140     OutType *output_block = reinterpret_cast<OutType *>(
141         aom_memalign(16, sizeof(output_block[0]) * num_coeffs_));
142     ASSERT_NE(output_block, nullptr);
143 
144     for (int i = 0; i < count_test_block; ++i) {
145       int j, k;
146       for (j = 0; j < height_; ++j) {
147         for (k = 0; k < pitch_; ++k) {
148           int in_idx = j * stride + k;
149           int out_idx = j * pitch_ + k;
150           input_block[in_idx] = (rnd.Rand16() & mask_) - (rnd.Rand16() & mask_);
151           if (bit_depth_ == AOM_BITS_8) {
152             output_block[out_idx] = output_ref_block[out_idx] = rnd.Rand8();
153           } else {
154             output_block[out_idx] = output_ref_block[out_idx] =
155                 rnd.Rand16() & mask_;
156           }
157         }
158       }
159 
160       fwd_txfm_ref(input_block, output_ref_block, stride, &txfm_param_);
161       API_REGISTER_STATE_CHECK(RunFwdTxfm(input_block, output_block, stride));
162 
163       // The minimum quant value is 4.
164       for (j = 0; j < height_; ++j) {
165         for (k = 0; k < pitch_; ++k) {
166           int out_idx = j * pitch_ + k;
167           ASSERT_EQ(output_block[out_idx], output_ref_block[out_idx])
168               << "Error: not bit-exact result at index: " << out_idx
169               << " at test block: " << i;
170         }
171       }
172     }
173     aom_free(input_block);
174     aom_free(output_ref_block);
175     aom_free(output_block);
176   }
177 
RunInvCoeffCheck()178   void RunInvCoeffCheck() {
179     ACMRandom rnd(ACMRandom::DeterministicSeed());
180     const int count_test_block = 5000;
181 
182     // Use a stride value which is not the width of any transform, to catch
183     // cases where the transforms use the stride incorrectly.
184     int stride = 96;
185 
186     int16_t *input_block = reinterpret_cast<int16_t *>(
187         aom_memalign(16, sizeof(int16_t) * num_coeffs_));
188     ASSERT_NE(input_block, nullptr);
189     OutType *trans_block = reinterpret_cast<OutType *>(
190         aom_memalign(16, sizeof(trans_block[0]) * num_coeffs_));
191     ASSERT_NE(trans_block, nullptr);
192     uint8_t *output_block = reinterpret_cast<uint8_t *>(
193         aom_memalign(16, sizeof(uint8_t) * stride * height_));
194     ASSERT_NE(output_block, nullptr);
195     uint8_t *output_ref_block = reinterpret_cast<uint8_t *>(
196         aom_memalign(16, sizeof(uint8_t) * stride * height_));
197     ASSERT_NE(output_ref_block, nullptr);
198 
199     for (int i = 0; i < count_test_block; ++i) {
200       // Initialize a test block with input range [-mask_, mask_].
201       int j, k;
202       for (j = 0; j < height_; ++j) {
203         for (k = 0; k < pitch_; ++k) {
204           int in_idx = j * pitch_ + k;
205           int out_idx = j * stride + k;
206           input_block[in_idx] = (rnd.Rand16() & mask_) - (rnd.Rand16() & mask_);
207           output_ref_block[out_idx] = rnd.Rand16() & mask_;
208           output_block[out_idx] = output_ref_block[out_idx];
209         }
210       }
211 
212       fwd_txfm_ref(input_block, trans_block, pitch_, &txfm_param_);
213 
214       inv_txfm_ref(trans_block, output_ref_block, stride, &txfm_param_);
215       API_REGISTER_STATE_CHECK(RunInvTxfm(trans_block, output_block, stride));
216 
217       for (j = 0; j < height_; ++j) {
218         for (k = 0; k < pitch_; ++k) {
219           int out_idx = j * stride + k;
220           ASSERT_EQ(output_block[out_idx], output_ref_block[out_idx])
221               << "Error: not bit-exact result at index: " << out_idx
222               << " j = " << j << " k = " << k << " at test block: " << i;
223         }
224       }
225     }
226     aom_free(input_block);
227     aom_free(trans_block);
228     aom_free(output_ref_block);
229     aom_free(output_block);
230   }
231 
RunMemCheck()232   void RunMemCheck() {
233     ACMRandom rnd(ACMRandom::DeterministicSeed());
234     const int count_test_block = 5000;
235 
236     int16_t *input_extreme_block = reinterpret_cast<int16_t *>(
237         aom_memalign(16, sizeof(int16_t) * num_coeffs_));
238     ASSERT_NE(input_extreme_block, nullptr);
239     OutType *output_ref_block = reinterpret_cast<OutType *>(
240         aom_memalign(16, sizeof(output_ref_block[0]) * num_coeffs_));
241     ASSERT_NE(output_ref_block, nullptr);
242     OutType *output_block = reinterpret_cast<OutType *>(
243         aom_memalign(16, sizeof(output_block[0]) * num_coeffs_));
244     ASSERT_NE(output_block, nullptr);
245 
246     for (int i = 0; i < count_test_block; ++i) {
247       // Initialize a test block with input range [-mask_, mask_].
248       for (int j = 0; j < num_coeffs_; ++j) {
249         input_extreme_block[j] = rnd.Rand8() % 2 ? mask_ : -mask_;
250       }
251       if (i == 0) {
252         for (int j = 0; j < num_coeffs_; ++j) input_extreme_block[j] = mask_;
253       } else if (i == 1) {
254         for (int j = 0; j < num_coeffs_; ++j) input_extreme_block[j] = -mask_;
255       }
256 
257       fwd_txfm_ref(input_extreme_block, output_ref_block, pitch_, &txfm_param_);
258       API_REGISTER_STATE_CHECK(
259           RunFwdTxfm(input_extreme_block, output_block, pitch_));
260 
261       int row_length = FindRowLength();
262       // The minimum quant value is 4.
263       for (int j = 0; j < num_coeffs_; ++j) {
264         ASSERT_EQ(output_block[j], output_ref_block[j])
265             << "Not bit-exact at test index: " << i << ", "
266             << "j = " << j << std::endl;
267         EXPECT_GE(row_length * kDctMaxValue << (bit_depth_ - 8),
268                   abs(output_block[j]))
269             << "Error: NxN FDCT has coefficient larger than N*DCT_MAX_VALUE";
270       }
271     }
272     aom_free(input_extreme_block);
273     aom_free(output_ref_block);
274     aom_free(output_block);
275   }
276 
RunInvAccuracyCheck(int limit)277   void RunInvAccuracyCheck(int limit) {
278     ACMRandom rnd(ACMRandom::DeterministicSeed());
279     const int count_test_block = 1000;
280 
281     int16_t *in = reinterpret_cast<int16_t *>(
282         aom_memalign(16, sizeof(int16_t) * num_coeffs_));
283     ASSERT_NE(in, nullptr);
284     OutType *coeff = reinterpret_cast<OutType *>(
285         aom_memalign(16, sizeof(coeff[0]) * num_coeffs_));
286     ASSERT_NE(coeff, nullptr);
287     uint8_t *dst = reinterpret_cast<uint8_t *>(
288         aom_memalign(16, sizeof(uint8_t) * num_coeffs_));
289     ASSERT_NE(dst, nullptr);
290     uint8_t *src = reinterpret_cast<uint8_t *>(
291         aom_memalign(16, sizeof(uint8_t) * num_coeffs_));
292     ASSERT_NE(src, nullptr);
293 
294     uint16_t *dst16 = reinterpret_cast<uint16_t *>(
295         aom_memalign(16, sizeof(uint16_t) * num_coeffs_));
296     ASSERT_NE(dst16, nullptr);
297     uint16_t *src16 = reinterpret_cast<uint16_t *>(
298         aom_memalign(16, sizeof(uint16_t) * num_coeffs_));
299     ASSERT_NE(src16, nullptr);
300 
301     for (int i = 0; i < count_test_block; ++i) {
302       // Initialize a test block with input range [-mask_, mask_].
303       for (int j = 0; j < num_coeffs_; ++j) {
304         if (bit_depth_ == AOM_BITS_8) {
305           src[j] = rnd.Rand8();
306           dst[j] = rnd.Rand8();
307           in[j] = src[j] - dst[j];
308         } else {
309           src16[j] = rnd.Rand16() & mask_;
310           dst16[j] = rnd.Rand16() & mask_;
311           in[j] = src16[j] - dst16[j];
312         }
313       }
314 
315       fwd_txfm_ref(in, coeff, pitch_, &txfm_param_);
316 
317       if (bit_depth_ == AOM_BITS_8) {
318         API_REGISTER_STATE_CHECK(RunInvTxfm(coeff, dst, pitch_));
319       } else {
320         API_REGISTER_STATE_CHECK(
321             RunInvTxfm(coeff, CONVERT_TO_BYTEPTR(dst16), pitch_));
322       }
323 
324       for (int j = 0; j < num_coeffs_; ++j) {
325         const int diff =
326             bit_depth_ == AOM_BITS_8 ? dst[j] - src[j] : dst16[j] - src16[j];
327         const uint32_t error = diff * diff;
328         ASSERT_GE(static_cast<uint32_t>(limit), error)
329             << "Error: 4x4 IDCT has error " << error << " at index " << j;
330       }
331     }
332     aom_free(in);
333     aom_free(coeff);
334     aom_free(dst);
335     aom_free(src);
336     aom_free(src16);
337     aom_free(dst16);
338   }
339 
340   int pitch_;
341   int height_;
342   FhtFunc<OutType> fwd_txfm_ref;
343   IhtFunc<OutType> inv_txfm_ref;
344   aom_bit_depth_t bit_depth_;
345   int mask_;
346   int num_coeffs_;
347   TxfmParam txfm_param_;
348 
349  private:
350   //  Assume transform size is 4x4, 8x8, 16x16,...
FindRowLength()351   int FindRowLength() const {
352     int row = 4;
353     if (16 == num_coeffs_) {
354       row = 4;
355     } else if (64 == num_coeffs_) {
356       row = 8;
357     } else if (256 == num_coeffs_) {
358       row = 16;
359     } else if (1024 == num_coeffs_) {
360       row = 32;
361     }
362     return row;
363   }
364 };
365 
366 }  // namespace libaom_test
367 
368 #endif  // AOM_TEST_TRANSFORM_TEST_BASE_H_
369