xref: /aosp_15_r20/external/libaom/av1/encoder/tune_vmaf.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2019, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include "av1/encoder/tune_vmaf.h"
13 
14 #include "aom_dsp/psnr.h"
15 #include "av1/encoder/extend.h"
16 #include "av1/encoder/rdopt.h"
17 #include "config/aom_scale_rtcd.h"
18 
// Reference VMAF score. cal_approx_vmaf() subtracts it from the measured
// score, so only the difference from this baseline enters the approximation.
static const double kBaselineVmaf = 97.42773;
20 
// Looks up the value recorded for `layer` in `array`.
//
// Negative entries are skipped: the lookup steps down towards layer 0 until
// it meets an entry that is not negative. The result is never negative; if
// the entry finally selected is still below zero, 0.0 is returned instead.
static double get_layer_value(const double *array, int layer) {
  for (; layer > 0; --layer) {
    if (!(array[layer] < 0.0)) break;
  }
  const double value = array[layer];
  return value > 0.0 ? value : 0.0;
}
25 
// Runs a full-pixel NSTEP motion search for one luma block of `src` against
// the co-located area of `ref`.
//
//   block_size      size of the block being matched.
//   mb_row, mb_col  block position, counted in whole blocks.
//   ref_mv          in: starting vector handed to the search;
//                   out: vector written back by av1_full_pixel_search().
//
// The encoder's macroblock state in cpi->td.mb is borrowed for the duration of
// the search: its plane-0 source and prediction buffers are redirected to the
// two frames and put back before returning.
static void motion_search(AV1_COMP *cpi, const YV12_BUFFER_CONFIG *src,
                          const YV12_BUFFER_CONFIG *ref,
                          const BLOCK_SIZE block_size, const int mb_row,
                          const int mb_col, FULLPEL_MV *ref_mv) {
  // Only the Y plane takes part in the search.
  const int stride = src->y_stride;
  assert(stride == ref->y_stride);
  const int blk_h = block_size_high[block_size];
  const int blk_w = block_size_wide[block_size];
  const int blk_offset = mb_row * blk_h * stride + mb_col * blk_w;

  MACROBLOCK *const x = &cpi->td.mb;
  MACROBLOCKD *const xd = &x->e_mbd;

  // Keep copies of the buffer descriptors that are about to be overwritten.
  const struct buf_2d saved_src = x->plane[0].src;
  const struct buf_2d saved_pre = xd->plane[0].pre[0];

  // Search configuration.
  const search_site_config *const site_cfg =
      cpi->mv_search_params.search_site_cfg[SS_CFG_FPF];
  const int step_param =
      av1_init_search_range(AOMMAX(src->y_crop_width, src->y_crop_height));
  // Zero vector used as the baseline position (for rate distortion
  // comparison) when building the search parameters.
  const MV zero_mv = kZeroMv;
  const FULLPEL_MV start_mv = *ref_mv;

  // Point the macroblock at the block under test in both frames.
  x->plane[0].src.buf = src->y_buffer + blk_offset;
  x->plane[0].src.stride = stride;
  xd->plane[0].pre[0].buf = ref->y_buffer + blk_offset;
  xd->plane[0].pre[0].stride = stride;

  // Full search over the whole block only; no finer search interval.
  FULLPEL_MOTION_SEARCH_PARAMS ms_params;
  av1_make_default_fullpel_ms_params(&ms_params, cpi, x, block_size, &zero_mv,
                                     start_mv, site_cfg, NSTEP,
                                     /*fine_search_interval=*/0);

  // Outputs of the search other than the vector itself are not used here.
  int cost_list[5];
  FULLPEL_MV_STATS mv_stats;
  av1_full_pixel_search(start_mv, &ms_params, step_param,
                        cond_cost_list(cpi, cost_list), ref_mv, &mv_stats,
                        NULL);

  // Hand the macroblock back in the state it was received.
  x->plane[0].src = saved_src;
  xd->plane[0].pre[0] = saved_pre;
}
78 
// Returns what the encoder's variance function for `block_size` reports for
// one luma block of `src` compared with the block of `ref` displaced by
// `ref_mv`.
//
//   mb_row, mb_col  block position, counted in whole blocks.
//   ref_mv          full-pixel displacement applied to the `ref` block.
//   sse             forwarded to the variance function as its SSE output.
static unsigned int residual_variance(const AV1_COMP *cpi,
                                      const YV12_BUFFER_CONFIG *src,
                                      const YV12_BUFFER_CONFIG *ref,
                                      const BLOCK_SIZE block_size,
                                      const int mb_row, const int mb_col,
                                      FULLPEL_MV ref_mv, unsigned int *sse) {
  const int stride = src->y_stride;
  assert(stride == ref->y_stride);
  const int blk_h = block_size_high[block_size];
  const int blk_w = block_size_wide[block_size];
  const int blk_offset = mb_row * blk_h * stride + mb_col * blk_w;
  const int mv_offset = ref_mv.row * stride + ref_mv.col;

  const uint8_t *const ref_blk = ref->y_buffer + blk_offset + mv_offset;
  const uint8_t *const src_blk = src->y_buffer + blk_offset;
  return cpi->ppi->fn_ptr[block_size].vf(ref_blk, stride, src_blk, stride,
                                         sse);
}
96 
frame_average_variance(const AV1_COMP * const cpi,const YV12_BUFFER_CONFIG * const frame)97 static double frame_average_variance(const AV1_COMP *const cpi,
98                                      const YV12_BUFFER_CONFIG *const frame) {
99   const MACROBLOCKD *const xd = &cpi->td.mb.e_mbd;
100   const uint8_t *const y_buffer = frame->y_buffer;
101   const int y_stride = frame->y_stride;
102   const BLOCK_SIZE block_size = BLOCK_64X64;
103 
104   const int block_w = mi_size_wide[block_size] * 4;
105   const int block_h = mi_size_high[block_size] * 4;
106   int row, col;
107   double var = 0.0, var_count = 0.0;
108   const int use_hbd = frame->flags & YV12_FLAG_HIGHBITDEPTH;
109 
110   // Loop through each block.
111   for (row = 0; row < frame->y_height / block_h; ++row) {
112     for (col = 0; col < frame->y_width / block_w; ++col) {
113       struct buf_2d buf;
114       const int row_offset_y = row * block_h;
115       const int col_offset_y = col * block_w;
116 
117       buf.buf = (uint8_t *)y_buffer + row_offset_y * y_stride + col_offset_y;
118       buf.stride = y_stride;
119 
120       var += av1_get_perpixel_variance(cpi, xd, &buf, block_size, AOM_PLANE_Y,
121                                        use_hbd);
122       var_count += 1.0;
123     }
124   }
125   var /= var_count;
126   return var;
127 }
128 
// Returns the average per-block variance of the luma residual between `src`
// and `ref`, using BLOCK_16X16 blocks.
//
//   ref  reference frame. When NULL there is nothing to predict from and the
//        plain frame_average_variance() of `src` is returned instead.
//   mvs  optional array of mb_rows * mb_cols full-pixel vectors, one per
//        block in row-major order. When non-NULL the vectors are used as
//        given and no motion search is run. When NULL a zeroed scratch array
//        is allocated, motion_search() fills it block by block, and it is
//        released again before returning.
//
// The block grid is rounded up, so blocks on the right/bottom edge may extend
// past the frame dimensions.
static double residual_frame_average_variance(AV1_COMP *cpi,
                                              const YV12_BUFFER_CONFIG *src,
                                              const YV12_BUFFER_CONFIG *ref,
                                              FULLPEL_MV *mvs) {
  if (ref == NULL) return frame_average_variance(cpi, src);
  const BLOCK_SIZE block_size = BLOCK_16X16;
  const int frame_height = src->y_height;
  const int frame_width = src->y_width;
  const int mb_height = block_size_high[block_size];
  const int mb_width = block_size_wide[block_size];
  const int mb_rows = (frame_height + mb_height - 1) / mb_height;
  const int mb_cols = (frame_width + mb_width - 1) / mb_width;
  const int num_planes = av1_num_planes(&cpi->common);
  const int mi_h = mi_size_high_log2[block_size];
  const int mi_w = mi_size_wide_log2[block_size];
  assert(num_planes >= 1 && num_planes <= MAX_MB_PLANE);

  // Save input state.
  MACROBLOCK *const mb = &cpi->td.mb;
  MACROBLOCKD *const mbd = &mb->e_mbd;
  uint8_t *input_buffer[MAX_MB_PLANE];
  for (int i = 0; i < num_planes; i++) {
    input_buffer[i] = mbd->plane[i].pre[0].buf;
  }
  MB_MODE_INFO **input_mb_mode_info = mbd->mi;

  // `mvs` is received by value, so an array allocated here is invisible to
  // the caller; this function owns it and must release it. do_motion_search
  // doubles as the ownership flag.
  bool do_motion_search = false;
  if (mvs == NULL) {
    do_motion_search = true;
    CHECK_MEM_ERROR(&cpi->common, mvs,
                    (FULLPEL_MV *)aom_calloc(mb_rows * mb_cols, sizeof(*mvs)));
  }

  // 64-bit accumulator: the sum of per-block variances over a large frame can
  // exceed the range of unsigned int.
  uint64_t variance = 0;
  // Accumulate the residual variance block by block.
  for (int mb_row = 0; mb_row < mb_rows; mb_row++) {
    av1_set_mv_row_limits(&cpi->common.mi_params, &mb->mv_limits,
                          (mb_row << mi_h), (mb_height >> MI_SIZE_LOG2),
                          cpi->oxcf.border_in_pixels);
    for (int mb_col = 0; mb_col < mb_cols; mb_col++) {
      av1_set_mv_col_limits(&cpi->common.mi_params, &mb->mv_limits,
                            (mb_col << mi_w), (mb_width >> MI_SIZE_LOG2),
                            cpi->oxcf.border_in_pixels);
      FULLPEL_MV *ref_mv = &mvs[mb_col + mb_row * mb_cols];
      if (do_motion_search) {
        motion_search(cpi, src, ref, block_size, mb_row, mb_col, ref_mv);
      }
      unsigned int mv_sse;  // Required by the variance function; not used.
      const unsigned int blk_var = residual_variance(
          cpi, src, ref, block_size, mb_row, mb_col, *ref_mv, &mv_sse);
      variance += blk_var;
    }
  }

  // Restore input state
  for (int i = 0; i < num_planes; i++) {
    mbd->plane[i].pre[0].buf = input_buffer[i];
  }
  mbd->mi = input_mb_mode_info;

  // Release the scratch vectors if they were allocated above.
  if (do_motion_search) aom_free(mvs);

  return (double)variance / (double)(mb_rows * mb_cols);
}
190 
191 // TODO(sdeng): Add the SIMD implementation.
// Applies an unsharp mask to a w x h rectangle of 16-bit samples:
//
//   dst = source + amount * (source - blurred)
//
// The result has 0.5 added, is truncated to int, and is then limited to
// [0, 2^bit_depth - 1]. Each buffer advances by its own stride (in samples)
// from one row to the next. Every sample is read before the matching output
// sample is written.
static inline void highbd_unsharp_rect(const uint16_t *source,
                                       int source_stride,
                                       const uint16_t *blurred,
                                       int blurred_stride, uint16_t *dst,
                                       int dst_stride, int w, int h,
                                       double amount, int bit_depth) {
  const int max_value = (1 << bit_depth) - 1;
  for (int row = 0; row < h; ++row) {
    for (int col = 0; col < w; ++col) {
      const double s = (double)source[col];
      const double sharpened = s + amount * (s - (double)blurred[col]);
      int sample = (int)(sharpened + 0.5);
      if (sample < 0) {
        sample = 0;
      } else if (sample > max_value) {
        sample = max_value;
      }
      dst[col] = (uint16_t)sample;
    }
    source += source_stride;
    blurred += blurred_stride;
    dst += dst_stride;
  }
}
210 
// Applies an unsharp mask to a w x h rectangle of 8-bit samples:
//
//   dst = source + amount * (source - blurred)
//
// The result has 0.5 added, is truncated to int, and is then limited to
// [0, 255]. Each buffer advances by its own stride from one row to the next.
// Every sample is read before the matching output sample is written.
static inline void unsharp_rect(const uint8_t *source, int source_stride,
                                const uint8_t *blurred, int blurred_stride,
                                uint8_t *dst, int dst_stride, int w, int h,
                                double amount) {
  for (int row = 0; row < h; ++row) {
    for (int col = 0; col < w; ++col) {
      const double s = (double)source[col];
      const double sharpened = s + amount * (s - (double)blurred[col]);
      int sample = (int)(sharpened + 0.5);
      if (sample < 0) {
        sample = 0;
      } else if (sample > 255) {
        sample = 255;
      }
      dst[col] = (uint8_t)sample;
    }
    source += source_stride;
    blurred += blurred_stride;
    dst += dst_stride;
  }
}
226 
unsharp(const AV1_COMP * const cpi,const YV12_BUFFER_CONFIG * source,const YV12_BUFFER_CONFIG * blurred,const YV12_BUFFER_CONFIG * dst,double amount)227 static inline void unsharp(const AV1_COMP *const cpi,
228                            const YV12_BUFFER_CONFIG *source,
229                            const YV12_BUFFER_CONFIG *blurred,
230                            const YV12_BUFFER_CONFIG *dst, double amount) {
231   const int bit_depth = cpi->td.mb.e_mbd.bd;
232   if (cpi->common.seq_params->use_highbitdepth) {
233     assert(source->flags & YV12_FLAG_HIGHBITDEPTH);
234     assert(blurred->flags & YV12_FLAG_HIGHBITDEPTH);
235     assert(dst->flags & YV12_FLAG_HIGHBITDEPTH);
236     highbd_unsharp_rect(CONVERT_TO_SHORTPTR(source->y_buffer), source->y_stride,
237                         CONVERT_TO_SHORTPTR(blurred->y_buffer),
238                         blurred->y_stride, CONVERT_TO_SHORTPTR(dst->y_buffer),
239                         dst->y_stride, source->y_width, source->y_height,
240                         amount, bit_depth);
241   } else {
242     unsharp_rect(source->y_buffer, source->y_stride, blurred->y_buffer,
243                  blurred->y_stride, dst->y_buffer, dst->y_stride,
244                  source->y_width, source->y_height, amount);
245   }
246 }
247 
248 // 8-tap Gaussian convolution filter with sigma = 1.0, sums to 128,
249 // all co-efficients must be even.
250 // The array is of size 9 to allow passing gauss_filter + 1 to
251 // _mm_loadu_si128() in prepare_coeffs_6t().
252 DECLARE_ALIGNED(16, static const int16_t, gauss_filter[9]) = { 0,  8, 30, 52,
253                                                                30, 8, 0,  0 };
gaussian_blur(const int bit_depth,const YV12_BUFFER_CONFIG * source,const YV12_BUFFER_CONFIG * dst)254 static inline void gaussian_blur(const int bit_depth,
255                                  const YV12_BUFFER_CONFIG *source,
256                                  const YV12_BUFFER_CONFIG *dst) {
257   const int block_size = BLOCK_128X128;
258   const int block_w = mi_size_wide[block_size] * 4;
259   const int block_h = mi_size_high[block_size] * 4;
260   const int num_cols = (source->y_width + block_w - 1) / block_w;
261   const int num_rows = (source->y_height + block_h - 1) / block_h;
262   int row, col;
263 
264   ConvolveParams conv_params = get_conv_params(0, 0, bit_depth);
265   InterpFilterParams filter = { .filter_ptr = gauss_filter,
266                                 .taps = 8,
267                                 .interp_filter = EIGHTTAP_REGULAR };
268 
269   for (row = 0; row < num_rows; ++row) {
270     for (col = 0; col < num_cols; ++col) {
271       const int row_offset_y = row * block_h;
272       const int col_offset_y = col * block_w;
273 
274       uint8_t *src_buf =
275           source->y_buffer + row_offset_y * source->y_stride + col_offset_y;
276       uint8_t *dst_buf =
277           dst->y_buffer + row_offset_y * dst->y_stride + col_offset_y;
278 
279       if (source->flags & YV12_FLAG_HIGHBITDEPTH) {
280         av1_highbd_convolve_2d_sr(
281             CONVERT_TO_SHORTPTR(src_buf), source->y_stride,
282             CONVERT_TO_SHORTPTR(dst_buf), dst->y_stride, block_w, block_h,
283             &filter, &filter, 0, 0, &conv_params, bit_depth);
284       } else {
285         av1_convolve_2d_sr(src_buf, source->y_stride, dst_buf, dst->y_stride,
286                            block_w, block_h, &filter, &filter, 0, 0,
287                            &conv_params);
288       }
289     }
290   }
291 }
292 
cal_approx_vmaf(const AV1_COMP * const cpi,double source_variance,const YV12_BUFFER_CONFIG * const source,const YV12_BUFFER_CONFIG * const sharpened)293 static inline double cal_approx_vmaf(
294     const AV1_COMP *const cpi, double source_variance,
295     const YV12_BUFFER_CONFIG *const source,
296     const YV12_BUFFER_CONFIG *const sharpened) {
297   const int bit_depth = cpi->td.mb.e_mbd.bd;
298   const bool cal_vmaf_neg =
299       cpi->oxcf.tune_cfg.tuning == AOM_TUNE_VMAF_NEG_MAX_GAIN;
300   double new_vmaf;
301 
302   aom_calc_vmaf(cpi->vmaf_info.vmaf_model, source, sharpened, bit_depth,
303                 cal_vmaf_neg, &new_vmaf);
304 
305   const double sharpened_var = frame_average_variance(cpi, sharpened);
306   return source_variance / sharpened_var * (new_vmaf - kBaselineVmaf);
307 }
308 
// Searches for the best unsharp amount by stepping from a starting amount
// until the approximate score stops improving.
//
// Starting at `unsharp_amount_start`, the amount is moved by `step_size`
// (which may be negative) once per iteration. After each move `sharpened` is
// regenerated and scored with cal_approx_vmaf(). The walk stops when
//   - the next amount would leave [0.0, max_amount] (not evaluated),
//   - the new score is not greater than the previous one, or
//   - max_loop_count evaluations have been made.
//
//   best_vmaf          score belonging to the starting amount.
//   baseline_variance  source variance forwarded to cal_approx_vmaf().
//   sharpened          scratch frame, overwritten on every evaluation.
//
// Returns the last amount whose score improved on its predecessor, limited to
// [0.0, max_amount].
static double find_best_frame_unsharp_amount_loop(
    const AV1_COMP *const cpi, const YV12_BUFFER_CONFIG *const source,
    const YV12_BUFFER_CONFIG *const blurred,
    const YV12_BUFFER_CONFIG *const sharpened, double best_vmaf,
    const double baseline_variance, const double unsharp_amount_start,
    const double step_size, const int max_loop_count, const double max_amount) {
  const double min_amount = 0.0;
  double amount = unsharp_amount_start;
  double score = best_vmaf;
  int evaluations = 0;

  for (;;) {
    best_vmaf = score;
    amount += step_size;
    // Out of range: stop without scoring this amount. score == best_vmaf
    // here, so the step is undone below.
    if (amount < min_amount || amount > max_amount) break;

    unsharp(cpi, source, blurred, sharpened, amount);
    score = cal_approx_vmaf(cpi, baseline_variance, source, sharpened);
    ++evaluations;

    if (!(score > best_vmaf) || evaluations >= max_loop_count) break;
  }

  // Keep the final step only if it was an improvement.
  if (!(score > best_vmaf)) amount -= step_size;

  return AOMMIN(max_amount, AOMMAX(amount, min_amount));
}
332 
find_best_frame_unsharp_amount(const AV1_COMP * const cpi,const YV12_BUFFER_CONFIG * const source,const YV12_BUFFER_CONFIG * const blurred,const double unsharp_amount_start,const double step_size,const int max_loop_count,const double max_filter_amount)333 static double find_best_frame_unsharp_amount(
334     const AV1_COMP *const cpi, const YV12_BUFFER_CONFIG *const source,
335     const YV12_BUFFER_CONFIG *const blurred, const double unsharp_amount_start,
336     const double step_size, const int max_loop_count,
337     const double max_filter_amount) {
338   const AV1_COMMON *const cm = &cpi->common;
339   const int width = source->y_width;
340   const int height = source->y_height;
341   YV12_BUFFER_CONFIG sharpened;
342   memset(&sharpened, 0, sizeof(sharpened));
343   aom_alloc_frame_buffer(
344       &sharpened, width, height, source->subsampling_x, source->subsampling_y,
345       cm->seq_params->use_highbitdepth, cpi->oxcf.border_in_pixels,
346       cm->features.byte_alignment, false, 0);
347 
348   const double baseline_variance = frame_average_variance(cpi, source);
349   double unsharp_amount;
350   if (unsharp_amount_start <= step_size) {
351     unsharp_amount = find_best_frame_unsharp_amount_loop(
352         cpi, source, blurred, &sharpened, 0.0, baseline_variance, 0.0,
353         step_size, max_loop_count, max_filter_amount);
354   } else {
355     double a0 = unsharp_amount_start - step_size, a1 = unsharp_amount_start;
356     double v0, v1;
357     unsharp(cpi, source, blurred, &sharpened, a0);
358     v0 = cal_approx_vmaf(cpi, baseline_variance, source, &sharpened);
359     unsharp(cpi, source, blurred, &sharpened, a1);
360     v1 = cal_approx_vmaf(cpi, baseline_variance, source, &sharpened);
361     if (fabs(v0 - v1) < 0.01) {
362       unsharp_amount = a0;
363     } else if (v0 > v1) {
364       unsharp_amount = find_best_frame_unsharp_amount_loop(
365           cpi, source, blurred, &sharpened, v0, baseline_variance, a0,
366           -step_size, max_loop_count, max_filter_amount);
367     } else {
368       unsharp_amount = find_best_frame_unsharp_amount_loop(
369           cpi, source, blurred, &sharpened, v1, baseline_variance, a1,
370           step_size, max_loop_count, max_filter_amount);
371     }
372   }
373 
374   aom_free_frame_buffer(&sharpened);
375   return unsharp_amount;
376 }
377 
// Sharpens the luma plane of `source` in place using the unsharp amount
// already stored in cpi->vmaf_info.last_frame_unsharp_amount for the current
// frame's layer depth (capped at MAX_ARF_LAYERS - 1); no search is performed.
// Does nothing when that stored amount is not positive.
//
// A temporary frame for the blurred copy is allocated and released here.
// Allocation failure is reported through aom_internal_error().
void av1_vmaf_neg_preprocessing(AV1_COMP *const cpi,
                                const YV12_BUFFER_CONFIG *const source) {
  const AV1_COMMON *const cm = &cpi->common;
  const int bit_depth = cpi->td.mb.e_mbd.bd;
  const int width = source->y_width;
  const int height = source->y_height;

  const GF_GROUP *const gf_group = &cpi->ppi->gf_group;
  const int layer_depth =
      AOMMIN(gf_group->layer_depth[cpi->gf_frame_index], MAX_ARF_LAYERS - 1);
  const double best_frame_unsharp_amount =
      get_layer_value(cpi->vmaf_info.last_frame_unsharp_amount, layer_depth);

  if (best_frame_unsharp_amount <= 0.0) return;

  YV12_BUFFER_CONFIG blurred;
  memset(&blurred, 0, sizeof(blurred));
  // aom_alloc_frame_buffer() returns non-zero on failure; without this check
  // the blur below would write through an unallocated buffer.
  if (aom_alloc_frame_buffer(
          &blurred, width, height, source->subsampling_x,
          source->subsampling_y, cm->seq_params->use_highbitdepth,
          cpi->oxcf.border_in_pixels, cm->features.byte_alignment, false, 0)) {
    aom_internal_error(cm->error, AOM_CODEC_MEM_ERROR,
                       "Error allocating vmaf data");
  }

  gaussian_blur(bit_depth, source, &blurred);
  // `source` is both input and output: the frame is sharpened in place.
  unsharp(cpi, source, &blurred, source, best_frame_unsharp_amount);
  aom_free_frame_buffer(&blurred);
}
404 
// Sharpens the luma plane of `source` in place with a single, frame-wide
// unsharp amount chosen by find_best_frame_unsharp_amount().
//
// The search starts from the amount stored for the current frame's layer
// depth (capped at MAX_ARF_LAYERS - 1), and the amount found is stored back
// into cpi->vmaf_info.last_frame_unsharp_amount for that layer.
//
// Two temporary frames (a border-extended copy of the source and its blurred
// version) are allocated and released here. Allocation failure is reported
// through aom_internal_error().
void av1_vmaf_frame_preprocessing(AV1_COMP *const cpi,
                                  const YV12_BUFFER_CONFIG *const source) {
  const AV1_COMMON *const cm = &cpi->common;
  const int bit_depth = cpi->td.mb.e_mbd.bd;
  const int width = source->y_width;
  const int height = source->y_height;

  YV12_BUFFER_CONFIG source_extended, blurred;
  memset(&source_extended, 0, sizeof(source_extended));
  memset(&blurred, 0, sizeof(blurred));
  // aom_alloc_frame_buffer() returns non-zero on failure; the buffers must not
  // be used in that case.
  if (aom_alloc_frame_buffer(
          &source_extended, width, height, source->subsampling_x,
          source->subsampling_y, cm->seq_params->use_highbitdepth,
          cpi->oxcf.border_in_pixels, cm->features.byte_alignment, false, 0)) {
    aom_internal_error(cm->error, AOM_CODEC_MEM_ERROR,
                       "Error allocating vmaf data");
  }
  if (aom_alloc_frame_buffer(
          &blurred, width, height, source->subsampling_x,
          source->subsampling_y, cm->seq_params->use_highbitdepth,
          cpi->oxcf.border_in_pixels, cm->features.byte_alignment, false, 0)) {
    aom_internal_error(cm->error, AOM_CODEC_MEM_ERROR,
                       "Error allocating vmaf data");
  }

  // Blur an extended copy of the source; the copy is needed only for that.
  av1_copy_and_extend_frame(source, &source_extended);
  gaussian_blur(bit_depth, &source_extended, &blurred);
  aom_free_frame_buffer(&source_extended);

  const GF_GROUP *const gf_group = &cpi->ppi->gf_group;
  const int layer_depth =
      AOMMIN(gf_group->layer_depth[cpi->gf_frame_index], MAX_ARF_LAYERS - 1);
  const double last_frame_unsharp_amount =
      get_layer_value(cpi->vmaf_info.last_frame_unsharp_amount, layer_depth);

  const double best_frame_unsharp_amount = find_best_frame_unsharp_amount(
      cpi, source, &blurred, last_frame_unsharp_amount, /*step_size=*/0.05,
      /*max_loop_count=*/20, /*max_filter_amount=*/1.01);

  cpi->vmaf_info.last_frame_unsharp_amount[layer_depth] =
      best_frame_unsharp_amount;

  // `source` is both input and output: the frame is sharpened in place.
  unsharp(cpi, source, &blurred, source, best_frame_unsharp_amount);
  aom_free_frame_buffer(&blurred);
}
443 
// Sharpens the luma plane of `source` in place, choosing the unsharp amount
// separately for every BLOCK_64X64 block.
//
// Steps:
//   1. Blur a border-extended copy of the source.
//   2. Find a frame-wide amount with find_best_frame_unsharp_amount(),
//      starting from the amount stored for the current layer depth, and store
//      the result back into cpi->vmaf_info.last_frame_unsharp_amount.
//   3. For each block, copy the source and blurred samples into block-sized
//      scratch frames (zero-filled outside the visible frame) and refine the
//      amount for that block, starting from the frame-wide amount.
//   4. Apply each block's amount to the visible part of that block.
//
// All temporary frames and the per-block amount table are allocated and
// released here. Allocation failure is reported through aom_internal_error().
void av1_vmaf_blk_preprocessing(AV1_COMP *const cpi,
                                const YV12_BUFFER_CONFIG *const source) {
  const AV1_COMMON *const cm = &cpi->common;
  const int width = source->y_width;
  const int height = source->y_height;
  const int bit_depth = cpi->td.mb.e_mbd.bd;
  const int ss_x = source->subsampling_x;
  const int ss_y = source->subsampling_y;

  YV12_BUFFER_CONFIG source_extended, blurred;
  memset(&blurred, 0, sizeof(blurred));
  memset(&source_extended, 0, sizeof(source_extended));
  // aom_alloc_frame_buffer() returns non-zero on failure; the buffers must not
  // be used in that case.
  if (aom_alloc_frame_buffer(
          &blurred, width, height, ss_x, ss_y, cm->seq_params->use_highbitdepth,
          cpi->oxcf.border_in_pixels, cm->features.byte_alignment, false, 0)) {
    aom_internal_error(cm->error, AOM_CODEC_MEM_ERROR,
                       "Error allocating vmaf data");
  }
  if (aom_alloc_frame_buffer(&source_extended, width, height, ss_x, ss_y,
                             cm->seq_params->use_highbitdepth,
                             cpi->oxcf.border_in_pixels,
                             cm->features.byte_alignment, false, 0)) {
    aom_internal_error(cm->error, AOM_CODEC_MEM_ERROR,
                       "Error allocating vmaf data");
  }

  // Blur an extended copy of the source; the copy is needed only for that.
  av1_copy_and_extend_frame(source, &source_extended);
  gaussian_blur(bit_depth, &source_extended, &blurred);
  aom_free_frame_buffer(&source_extended);

  const GF_GROUP *const gf_group = &cpi->ppi->gf_group;
  const int layer_depth =
      AOMMIN(gf_group->layer_depth[cpi->gf_frame_index], MAX_ARF_LAYERS - 1);
  const double last_frame_unsharp_amount =
      get_layer_value(cpi->vmaf_info.last_frame_unsharp_amount, layer_depth);

  const double best_frame_unsharp_amount = find_best_frame_unsharp_amount(
      cpi, source, &blurred, last_frame_unsharp_amount, /*step_size=*/0.05,
      /*max_loop_count=*/20, /*max_filter_amount=*/1.01);

  cpi->vmaf_info.last_frame_unsharp_amount[layer_depth] =
      best_frame_unsharp_amount;

  const int block_size = BLOCK_64X64;
  const int block_w = mi_size_wide[block_size] * 4;
  const int block_h = mi_size_high[block_size] * 4;
  const int num_cols = (source->y_width + block_w - 1) / block_w;
  const int num_rows = (source->y_height + block_h - 1) / block_h;
  double *best_unsharp_amounts =
      aom_calloc(num_cols * num_rows, sizeof(*best_unsharp_amounts));
  if (!best_unsharp_amounts) {
    aom_internal_error(cm->error, AOM_CODEC_MEM_ERROR,
                       "Error allocating vmaf data");
  }

  // Block-sized scratch frames handed to the per-block search.
  YV12_BUFFER_CONFIG source_block, blurred_block;
  memset(&source_block, 0, sizeof(source_block));
  memset(&blurred_block, 0, sizeof(blurred_block));
  if (aom_alloc_frame_buffer(&source_block, block_w, block_h, ss_x, ss_y,
                             cm->seq_params->use_highbitdepth,
                             cpi->oxcf.border_in_pixels,
                             cm->features.byte_alignment, false, 0)) {
    aom_internal_error(cm->error, AOM_CODEC_MEM_ERROR,
                       "Error allocating vmaf data");
  }
  if (aom_alloc_frame_buffer(&blurred_block, block_w, block_h, ss_x, ss_y,
                             cm->seq_params->use_highbitdepth,
                             cpi->oxcf.border_in_pixels,
                             cm->features.byte_alignment, false, 0)) {
    aom_internal_error(cm->error, AOM_CODEC_MEM_ERROR,
                       "Error allocating vmaf data");
  }

  // Find the best amount for every block.
  for (int row = 0; row < num_rows; ++row) {
    for (int col = 0; col < num_cols; ++col) {
      const int row_offset_y = row * block_h;
      const int col_offset_y = col * block_w;
      // Visible part of the block; smaller than the block on the right and
      // bottom edges of the frame.
      const int block_width = AOMMIN(width - col_offset_y, block_w);
      const int block_height = AOMMIN(height - row_offset_y, block_h);
      const int index = col + row * num_cols;

      if (cm->seq_params->use_highbitdepth) {
        assert(source->flags & YV12_FLAG_HIGHBITDEPTH);
        assert(blurred.flags & YV12_FLAG_HIGHBITDEPTH);
        uint16_t *frame_src_buf = CONVERT_TO_SHORTPTR(source->y_buffer) +
                                  row_offset_y * source->y_stride +
                                  col_offset_y;
        uint16_t *frame_blurred_buf = CONVERT_TO_SHORTPTR(blurred.y_buffer) +
                                      row_offset_y * blurred.y_stride +
                                      col_offset_y;
        uint16_t *blurred_dst = CONVERT_TO_SHORTPTR(blurred_block.y_buffer);
        uint16_t *src_dst = CONVERT_TO_SHORTPTR(source_block.y_buffer);

        // Copy block from source frame; samples outside the frame become 0.
        for (int i = 0; i < block_h; ++i) {
          for (int j = 0; j < block_w; ++j) {
            if (i >= block_height || j >= block_width) {
              src_dst[j] = 0;
              blurred_dst[j] = 0;
            } else {
              src_dst[j] = frame_src_buf[j];
              blurred_dst[j] = frame_blurred_buf[j];
            }
          }
          frame_src_buf += source->y_stride;
          frame_blurred_buf += blurred.y_stride;
          src_dst += source_block.y_stride;
          blurred_dst += blurred_block.y_stride;
        }
      } else {
        uint8_t *frame_src_buf =
            source->y_buffer + row_offset_y * source->y_stride + col_offset_y;
        uint8_t *frame_blurred_buf =
            blurred.y_buffer + row_offset_y * blurred.y_stride + col_offset_y;
        uint8_t *blurred_dst = blurred_block.y_buffer;
        uint8_t *src_dst = source_block.y_buffer;

        // Copy block from source frame; samples outside the frame become 0.
        for (int i = 0; i < block_h; ++i) {
          for (int j = 0; j < block_w; ++j) {
            if (i >= block_height || j >= block_width) {
              src_dst[j] = 0;
              blurred_dst[j] = 0;
            } else {
              src_dst[j] = frame_src_buf[j];
              blurred_dst[j] = frame_blurred_buf[j];
            }
          }
          frame_src_buf += source->y_stride;
          frame_blurred_buf += blurred.y_stride;
          src_dst += source_block.y_stride;
          blurred_dst += blurred_block.y_stride;
        }
      }

      // Refine around the frame-wide amount with a short, coarser search.
      best_unsharp_amounts[index] = find_best_frame_unsharp_amount(
          cpi, &source_block, &blurred_block, best_frame_unsharp_amount,
          /*step_size=*/0.1, /*max_loop_count=*/3, /*max_filter_amount=*/1.5);
    }
  }

  // Apply the per-block amounts to the source frame, in place.
  for (int row = 0; row < num_rows; ++row) {
    for (int col = 0; col < num_cols; ++col) {
      const int row_offset_y = row * block_h;
      const int col_offset_y = col * block_w;
      const int block_width = AOMMIN(source->y_width - col_offset_y, block_w);
      const int block_height = AOMMIN(source->y_height - row_offset_y, block_h);
      const int index = col + row * num_cols;

      if (cm->seq_params->use_highbitdepth) {
        assert(source->flags & YV12_FLAG_HIGHBITDEPTH);
        assert(blurred.flags & YV12_FLAG_HIGHBITDEPTH);
        uint16_t *src_buf = CONVERT_TO_SHORTPTR(source->y_buffer) +
                            row_offset_y * source->y_stride + col_offset_y;
        uint16_t *blurred_buf = CONVERT_TO_SHORTPTR(blurred.y_buffer) +
                                row_offset_y * blurred.y_stride + col_offset_y;
        highbd_unsharp_rect(src_buf, source->y_stride, blurred_buf,
                            blurred.y_stride, src_buf, source->y_stride,
                            block_width, block_height,
                            best_unsharp_amounts[index], bit_depth);
      } else {
        uint8_t *src_buf =
            source->y_buffer + row_offset_y * source->y_stride + col_offset_y;
        uint8_t *blurred_buf =
            blurred.y_buffer + row_offset_y * blurred.y_stride + col_offset_y;
        unsharp_rect(src_buf, source->y_stride, blurred_buf, blurred.y_stride,
                     src_buf, source->y_stride, block_width, block_height,
                     best_unsharp_amounts[index]);
      }
    }
  }

  aom_free_frame_buffer(&source_block);
  aom_free_frame_buffer(&blurred_block);
  aom_free_frame_buffer(&blurred);
  aom_free(best_unsharp_amounts);
}
609 
// Fills cpi->vmaf_info.rdmult_scaling_factors with one weight per block of the
// current source frame.
//
// The source is downscaled by a factor of 2 and walked in BLOCK_32X32 units.
// For every block, the luma of a working copy ('recon') is overwritten from
// the blurred frame, the (source, recon) pair is queued in the VMAF context
// under the block's index, and the block is then restored from the source.
// After the context is flushed, each block's VMAF drop from kBaselineVmaf is
// combined with the value the block's variance-function call stored in 'sses'
// to form a weight, which is finally mapped through a fitted curve.
void av1_set_mb_vmaf_rdmult_scaling(AV1_COMP *cpi) {
  AV1_COMMON *cm = &cpi->common;
  const int y_width = cpi->source->y_width;
  const int y_height = cpi->source->y_height;
  const int resized_block_size = BLOCK_32X32;
  const int resize_factor = 2;
  const int bit_depth = cpi->td.mb.e_mbd.bd;
  const int ss_x = cpi->source->subsampling_x;
  const int ss_y = cpi->source->subsampling_y;

  // Downscaled copy of the source frame.
  YV12_BUFFER_CONFIG resized_source;
  memset(&resized_source, 0, sizeof(resized_source));
  aom_alloc_frame_buffer(
      &resized_source, y_width / resize_factor, y_height / resize_factor, ss_x,
      ss_y, cm->seq_params->use_highbitdepth, cpi->oxcf.border_in_pixels,
      cm->features.byte_alignment, false, 0);
  if (!av1_resize_and_extend_frame_nonnormative(
          cpi->source, &resized_source, bit_depth, av1_num_planes(cm))) {
    aom_internal_error(cm->error, AOM_CODEC_MEM_ERROR,
                       "Error allocating buffers during resize");
  }

  // Block grid over the downscaled frame; partial blocks at the right and
  // bottom edges are counted as full blocks.
  const int resized_y_width = resized_source.y_width;
  const int resized_y_height = resized_source.y_height;
  const int resized_block_w = mi_size_wide[resized_block_size] * 4;
  const int resized_block_h = mi_size_high[resized_block_size] * 4;
  const int num_cols =
      (resized_y_width + resized_block_w - 1) / resized_block_w;
  const int num_rows =
      (resized_y_height + resized_block_h - 1) / resized_block_h;

  // Blurred version of the downscaled source.
  YV12_BUFFER_CONFIG blurred;
  memset(&blurred, 0, sizeof(blurred));
  aom_alloc_frame_buffer(&blurred, resized_y_width, resized_y_height, ss_x,
                         ss_y, cm->seq_params->use_highbitdepth,
                         cpi->oxcf.border_in_pixels,
                         cm->features.byte_alignment, false, 0);
  gaussian_blur(bit_depth, &resized_source, &blurred);

  // Working copy of the downscaled source; one block at a time is overwritten
  // from 'blurred' and then restored.
  YV12_BUFFER_CONFIG recon;
  memset(&recon, 0, sizeof(recon));
  aom_alloc_frame_buffer(&recon, resized_y_width, resized_y_height, ss_x, ss_y,
                         cm->seq_params->use_highbitdepth,
                         cpi->oxcf.border_in_pixels,
                         cm->features.byte_alignment, false, 0);
  aom_yv12_copy_frame(&resized_source, &recon, 1);

  VmafContext *vmaf_context;
  const bool cal_vmaf_neg =
      cpi->oxcf.tune_cfg.tuning == AOM_TUNE_VMAF_NEG_MAX_GAIN;
  aom_init_vmaf_context(&vmaf_context, cpi->vmaf_info.vmaf_model, cal_vmaf_neg);
  unsigned int *sses = aom_calloc(num_rows * num_cols, sizeof(*sses));
  if (!sses) {
    aom_internal_error(cm->error, AOM_CODEC_MEM_ERROR,
                       "Error allocating vmaf data");
  }

  // Loop through each 'block_size' block.
  for (int row = 0; row < num_rows; ++row) {
    for (int col = 0; col < num_cols; ++col) {
      const int index = row * num_cols + col;
      const int row_offset_y = row * resized_block_h;
      const int col_offset_y = col * resized_block_w;

      uint8_t *const orig_buf = resized_source.y_buffer +
                                row_offset_y * resized_source.y_stride +
                                col_offset_y;
      uint8_t *const blurred_buf =
          blurred.y_buffer + row_offset_y * blurred.y_stride + col_offset_y;

      // Only the value written through the last argument is kept; the
      // function's return value is discarded.
      cpi->ppi->fn_ptr[resized_block_size].vf(orig_buf, resized_source.y_stride,
                                              blurred_buf, blurred.y_stride,
                                              &sses[index]);

      uint8_t *const recon_buf =
          recon.y_buffer + row_offset_y * recon.y_stride + col_offset_y;
      // Set recon buf: write this block of 'recon' from the blurred frame
      // (the blurred block is passed as both inputs with an amount of 0.0).
      if (cpi->common.seq_params->use_highbitdepth) {
        highbd_unsharp_rect(CONVERT_TO_SHORTPTR(blurred_buf), blurred.y_stride,
                            CONVERT_TO_SHORTPTR(blurred_buf), blurred.y_stride,
                            CONVERT_TO_SHORTPTR(recon_buf), recon.y_stride,
                            resized_block_w, resized_block_h, 0.0, bit_depth);
      } else {
        unsharp_rect(blurred_buf, blurred.y_stride, blurred_buf,
                     blurred.y_stride, recon_buf, recon.y_stride,
                     resized_block_w, resized_block_h, 0.0);
      }

      // Queue the frame pair under this block's index; scores are read back
      // after the flush below.
      aom_read_vmaf_image(vmaf_context, &resized_source, &recon, bit_depth,
                          index);

      // Restore recon buf: write the block back from the downscaled source so
      // the next iteration starts from an unmodified copy.
      if (cpi->common.seq_params->use_highbitdepth) {
        highbd_unsharp_rect(
            CONVERT_TO_SHORTPTR(orig_buf), resized_source.y_stride,
            CONVERT_TO_SHORTPTR(orig_buf), resized_source.y_stride,
            CONVERT_TO_SHORTPTR(recon_buf), recon.y_stride, resized_block_w,
            resized_block_h, 0.0, bit_depth);
      } else {
        unsharp_rect(orig_buf, resized_source.y_stride, orig_buf,
                     resized_source.y_stride, recon_buf, recon.y_stride,
                     resized_block_w, resized_block_h, 0.0);
      }
    }
  }
  aom_flush_vmaf_context(vmaf_context);

  // Loop-invariant terms of the weight computation.
  const double frame_area = (double)(resized_y_width * resized_y_height);
  const double eps = 0.01 / (num_rows * num_cols);
  for (int row = 0; row < num_rows; ++row) {
    for (int col = 0; col < num_cols; ++col) {
      const int index = row * num_cols + col;
      const double vmaf = aom_calc_vmaf_at_index(
          vmaf_context, cpi->vmaf_info.vmaf_model, index);
      const double dvmaf = kBaselineVmaf - vmaf;

      const double mse = (double)sses[index] / frame_area;
      double weight;
      // Fall back to a neutral weight when either term is too small for the
      // ratio to be meaningful.
      if (dvmaf < eps || mse < eps) {
        weight = 1.0;
      } else {
        weight = mse / dvmaf;
      }

      // Normalize it with a data fitted model. For weight >= 0 the result
      // lies in [0.8, 6.8).
      weight = 6.0 * (1.0 - exp(-0.05 * weight)) + 0.8;
      cpi->vmaf_info.rdmult_scaling_factors[index] = weight;
    }
  }

  // Release every frame buffer allocated above, including 'recon'.
  aom_free_frame_buffer(&resized_source);
  aom_free_frame_buffer(&blurred);
  aom_free_frame_buffer(&recon);
  aom_close_vmaf_context(vmaf_context);
  aom_free(sses);
}
744 
// Scales *rdmult by the geometric mean of the rdmult scaling factors of the
// BLOCK_64X64 units covered by the block of size 'bsize' located at
// (mi_row, mi_col), clamps the result to be non-negative, and updates
// x->errorperbit from the new value.
//
// If the block covers no unit inside the frame grid, *rdmult is left unscaled.
void av1_set_vmaf_rdmult(const AV1_COMP *const cpi, MACROBLOCK *const x,
                         const BLOCK_SIZE bsize, const int mi_row,
                         const int mi_col, int *const rdmult) {
  const AV1_COMMON *const cm = &cpi->common;

  const int bsize_base = BLOCK_64X64;
  const int num_mi_w = mi_size_wide[bsize_base];
  const int num_mi_h = mi_size_high[bsize_base];
  const int num_cols = (cm->mi_params.mi_cols + num_mi_w - 1) / num_mi_w;
  const int num_rows = (cm->mi_params.mi_rows + num_mi_h - 1) / num_mi_h;
  const int num_bcols = (mi_size_wide[bsize] + num_mi_w - 1) / num_mi_w;
  const int num_brows = (mi_size_high[bsize] + num_mi_h - 1) / num_mi_h;

  // Rows are measured in units of the base block height and columns in units
  // of its width.
  const int first_row = mi_row / num_mi_h;
  const int first_col = mi_col / num_mi_w;

  double log_sum = 0.0;
  int num_units = 0;
  for (int row = first_row; row < num_rows && row < first_row + num_brows;
       ++row) {
    for (int col = first_col; col < num_cols && col < first_col + num_bcols;
         ++col) {
      const int index = row * num_cols + col;
      log_sum += log(cpi->vmaf_info.rdmult_scaling_factors[index]);
      ++num_units;
    }
  }

  // With no unit visited the mean would be 0/0 (NaN), and converting NaN to
  // int is undefined; skip the scaling in that case.
  if (num_units > 0) {
    const double geom_mean_of_scale = exp(log_sum / (double)num_units);
    *rdmult = (int)((double)(*rdmult) * geom_mean_of_scale + 0.5);
  }
  *rdmult = AOMMAX(*rdmult, 0);
  av1_set_error_per_bit(&x->errorperbit, *rdmult);
}
776 
777 // TODO(sdeng): replace them with the SIMD versions.
// Returns the mean absolute difference between two 16-bit images of size
// w x h. 'src_stride' and 'ref_stride' are the distances, in samples, between
// the starts of consecutive rows.
static inline double highbd_image_sad_c(const uint16_t *src, int src_stride,
                                        const uint16_t *ref, int ref_stride,
                                        int w, int h) {
  double abs_diff_sum = 0.0;

  for (int y = 0; y < h; ++y) {
    const int src_row_start = y * src_stride;
    const int ref_row_start = y * ref_stride;
    for (int x = 0; x < w; ++x) {
      const double src_px = src[src_row_start + x];
      const double ref_px = ref[ref_row_start + x];
      abs_diff_sum += fabs(src_px - ref_px);
    }
  }

  return abs_diff_sum / (double)(h * w);
}
795 
// Returns the mean absolute difference between two 8-bit images of size
// w x h. 'src_stride' and 'ref_stride' are the distances, in samples, between
// the starts of consecutive rows.
static inline double image_sad_c(const uint8_t *src, int src_stride,
                                 const uint8_t *ref, int ref_stride, int w,
                                 int h) {
  double abs_diff_sum = 0.0;

  for (int y = 0; y < h; ++y) {
    const int src_row_start = y * src_stride;
    const int ref_row_start = y * ref_stride;
    for (int x = 0; x < w; ++x) {
      const double src_px = src[src_row_start + x];
      const double ref_px = ref[ref_row_start + x];
      abs_diff_sum += fabs(src_px - ref_px);
    }
  }

  return abs_diff_sum / (double)(h * w);
}
813 
calc_vmaf_motion_score(const AV1_COMP * const cpi,const AV1_COMMON * const cm,const YV12_BUFFER_CONFIG * const cur,const YV12_BUFFER_CONFIG * const last,const YV12_BUFFER_CONFIG * const next)814 static double calc_vmaf_motion_score(const AV1_COMP *const cpi,
815                                      const AV1_COMMON *const cm,
816                                      const YV12_BUFFER_CONFIG *const cur,
817                                      const YV12_BUFFER_CONFIG *const last,
818                                      const YV12_BUFFER_CONFIG *const next) {
819   const int y_width = cur->y_width;
820   const int y_height = cur->y_height;
821   YV12_BUFFER_CONFIG blurred_cur, blurred_last, blurred_next;
822   const int bit_depth = cpi->td.mb.e_mbd.bd;
823   const int ss_x = cur->subsampling_x;
824   const int ss_y = cur->subsampling_y;
825 
826   memset(&blurred_cur, 0, sizeof(blurred_cur));
827   memset(&blurred_last, 0, sizeof(blurred_last));
828   memset(&blurred_next, 0, sizeof(blurred_next));
829 
830   aom_alloc_frame_buffer(&blurred_cur, y_width, y_height, ss_x, ss_y,
831                          cm->seq_params->use_highbitdepth,
832                          cpi->oxcf.border_in_pixels,
833                          cm->features.byte_alignment, false, 0);
834   aom_alloc_frame_buffer(&blurred_last, y_width, y_height, ss_x, ss_y,
835                          cm->seq_params->use_highbitdepth,
836                          cpi->oxcf.border_in_pixels,
837                          cm->features.byte_alignment, false, 0);
838   aom_alloc_frame_buffer(&blurred_next, y_width, y_height, ss_x, ss_y,
839                          cm->seq_params->use_highbitdepth,
840                          cpi->oxcf.border_in_pixels,
841                          cm->features.byte_alignment, false, 0);
842 
843   gaussian_blur(bit_depth, cur, &blurred_cur);
844   gaussian_blur(bit_depth, last, &blurred_last);
845   if (next) gaussian_blur(bit_depth, next, &blurred_next);
846 
847   double motion1, motion2 = 65536.0;
848   if (cm->seq_params->use_highbitdepth) {
849     assert(blurred_cur.flags & YV12_FLAG_HIGHBITDEPTH);
850     assert(blurred_last.flags & YV12_FLAG_HIGHBITDEPTH);
851     const float scale_factor = 1.0f / (float)(1 << (bit_depth - 8));
852     motion1 = highbd_image_sad_c(CONVERT_TO_SHORTPTR(blurred_cur.y_buffer),
853                                  blurred_cur.y_stride,
854                                  CONVERT_TO_SHORTPTR(blurred_last.y_buffer),
855                                  blurred_last.y_stride, y_width, y_height) *
856               scale_factor;
857     if (next) {
858       assert(blurred_next.flags & YV12_FLAG_HIGHBITDEPTH);
859       motion2 = highbd_image_sad_c(CONVERT_TO_SHORTPTR(blurred_cur.y_buffer),
860                                    blurred_cur.y_stride,
861                                    CONVERT_TO_SHORTPTR(blurred_next.y_buffer),
862                                    blurred_next.y_stride, y_width, y_height) *
863                 scale_factor;
864     }
865   } else {
866     motion1 = image_sad_c(blurred_cur.y_buffer, blurred_cur.y_stride,
867                           blurred_last.y_buffer, blurred_last.y_stride, y_width,
868                           y_height);
869     if (next) {
870       motion2 = image_sad_c(blurred_cur.y_buffer, blurred_cur.y_stride,
871                             blurred_next.y_buffer, blurred_next.y_stride,
872                             y_width, y_height);
873     }
874   }
875 
876   aom_free_frame_buffer(&blurred_cur);
877   aom_free_frame_buffer(&blurred_last);
878   aom_free_frame_buffer(&blurred_next);
879 
880   return AOMMIN(motion1, motion2);
881 }
882 
get_neighbor_frames(const AV1_COMP * const cpi,const YV12_BUFFER_CONFIG ** last,const YV12_BUFFER_CONFIG ** next)883 static inline void get_neighbor_frames(const AV1_COMP *const cpi,
884                                        const YV12_BUFFER_CONFIG **last,
885                                        const YV12_BUFFER_CONFIG **next) {
886   const AV1_COMMON *const cm = &cpi->common;
887   const GF_GROUP *gf_group = &cpi->ppi->gf_group;
888   const int src_index =
889       cm->show_frame != 0 ? 0 : gf_group->arf_src_offset[cpi->gf_frame_index];
890   struct lookahead_entry *last_entry = av1_lookahead_peek(
891       cpi->ppi->lookahead, src_index - 1, cpi->compressor_stage);
892   struct lookahead_entry *next_entry = av1_lookahead_peek(
893       cpi->ppi->lookahead, src_index + 1, cpi->compressor_stage);
894   *next = &next_entry->img;
895   *last = cm->show_frame ? cpi->last_source : &last_entry->img;
896 }
897 
898 // Calculates the new qindex from the VMAF motion score. This is based on the
899 // observation: when the motion score becomes higher, the VMAF score of the
900 // same source and distorted frames would become higher.
av1_get_vmaf_base_qindex(const AV1_COMP * const cpi,int current_qindex)901 int av1_get_vmaf_base_qindex(const AV1_COMP *const cpi, int current_qindex) {
902   const AV1_COMMON *const cm = &cpi->common;
903   if (cm->current_frame.frame_number == 0 || cpi->oxcf.pass == 1) {
904     return current_qindex;
905   }
906   const GF_GROUP *const gf_group = &cpi->ppi->gf_group;
907   const int layer_depth =
908       AOMMIN(gf_group->layer_depth[cpi->gf_frame_index], MAX_ARF_LAYERS - 1);
909   const double last_frame_ysse =
910       get_layer_value(cpi->vmaf_info.last_frame_ysse, layer_depth);
911   const double last_frame_vmaf =
912       get_layer_value(cpi->vmaf_info.last_frame_vmaf, layer_depth);
913   const int bit_depth = cpi->td.mb.e_mbd.bd;
914   const double approx_sse = last_frame_ysse / (double)((1 << (bit_depth - 8)) *
915                                                        (1 << (bit_depth - 8)));
916   const double approx_dvmaf = kBaselineVmaf - last_frame_vmaf;
917   const double sse_threshold =
918       0.01 * cpi->source->y_width * cpi->source->y_height;
919   const double vmaf_threshold = 0.01;
920   if (approx_sse < sse_threshold || approx_dvmaf < vmaf_threshold) {
921     return current_qindex;
922   }
923   const YV12_BUFFER_CONFIG *cur_buf = cpi->source;
924   if (cm->show_frame == 0) {
925     const int src_index = gf_group->arf_src_offset[cpi->gf_frame_index];
926     struct lookahead_entry *cur_entry = av1_lookahead_peek(
927         cpi->ppi->lookahead, src_index, cpi->compressor_stage);
928     cur_buf = &cur_entry->img;
929   }
930   assert(cur_buf);
931 
932   const YV12_BUFFER_CONFIG *next_buf, *last_buf;
933   get_neighbor_frames(cpi, &last_buf, &next_buf);
934   assert(last_buf);
935 
936   const double motion =
937       calc_vmaf_motion_score(cpi, cm, cur_buf, last_buf, next_buf);
938 
939   // Get dVMAF through a data fitted model.
940   const double dvmaf = 26.11 * (1.0 - exp(-0.06 * motion));
941   const double dsse = dvmaf * approx_sse / approx_dvmaf;
942 
943   // Clamping beta to address VQ issue (aomedia:3170).
944   const double beta = AOMMAX(approx_sse / (dsse + approx_sse), 0.5);
945   const int offset =
946       av1_get_deltaq_offset(cm->seq_params->bit_depth, current_qindex, beta);
947   int qindex = current_qindex + offset;
948 
949   qindex = AOMMIN(qindex, MAXQ);
950   qindex = AOMMAX(qindex, MINQ);
951 
952   return qindex;
953 }
954 
// Returns the VMAF gain of 'recon_sharpened' over 'src_score', weighted by
// src_variance / new_variance. The VMAF score is computed between 'src' and
// 'recon_sharpened' with the encoder's VMAF model.
static inline double cal_approx_score(
    AV1_COMP *const cpi, double src_variance, double new_variance,
    double src_score, const YV12_BUFFER_CONFIG *const src,
    const YV12_BUFFER_CONFIG *const recon_sharpened) {
  const bool cal_vmaf_neg =
      cpi->oxcf.tune_cfg.tuning == AOM_TUNE_VMAF_NEG_MAX_GAIN;
  const uint32_t bit_depth = cpi->td.mb.e_mbd.bd;

  double sharpened_score;
  aom_calc_vmaf(cpi->vmaf_info.vmaf_model, src, recon_sharpened, bit_depth,
                cal_vmaf_neg, &sharpened_score);

  return src_variance / new_variance * (sharpened_score - src_score);
}
967 
// Steps the unsharp amount away from 'unsharp_amount_start' in increments of
// 'step_size' (which may be negative) for as long as the approximate score
// keeps improving on the previous one, for at most 'max_loop_count' scored
// steps and without leaving [0.0, max_amount].
//
// 'best_score' is the score associated with the starting amount. Each step
// re-sharpens both 'recon' and 'src' into the provided scratch frames before
// scoring. Returns the last amount whose score improved on its predecessor,
// clamped to [0.0, max_amount].
static double find_best_frame_unsharp_amount_loop_neg(
    AV1_COMP *const cpi, double src_variance, double base_score,
    const YV12_BUFFER_CONFIG *const src, const YV12_BUFFER_CONFIG *const recon,
    const YV12_BUFFER_CONFIG *const ref,
    const YV12_BUFFER_CONFIG *const src_blurred,
    const YV12_BUFFER_CONFIG *const recon_blurred,
    const YV12_BUFFER_CONFIG *const src_sharpened,
    const YV12_BUFFER_CONFIG *const recon_sharpened, FULLPEL_MV *mvs,
    double best_score, const double unsharp_amount_start,
    const double step_size, const int max_loop_count, const double max_amount) {
  const double min_amount = 0.0;
  double amount = unsharp_amount_start;
  double current_score = best_score;
  int steps_scored = 0;

  for (;;) {
    // The score to beat is the one obtained at the previous amount.
    best_score = current_score;
    amount += step_size;
    if (amount > max_amount || amount < min_amount) break;

    unsharp(cpi, recon, recon_blurred, recon_sharpened, amount);
    unsharp(cpi, src, src_blurred, src_sharpened, amount);
    const double new_variance =
        residual_frame_average_variance(cpi, src_sharpened, ref, mvs);
    current_score = cal_approx_score(cpi, src_variance, new_variance,
                                     base_score, src, recon_sharpened);
    ++steps_scored;

    if (!(current_score > best_score)) break;
    if (!(steps_scored < max_loop_count)) break;
  }

  // The last step did not help (or went out of range): undo it.
  if (!(current_score > best_score)) {
    amount = amount - step_size;
  }

  return AOMMIN(max_amount, AOMMAX(amount, min_amount));
}
1001 
// Picks an unsharp amount for the frame by scoring 'unsharp_amount_start' and
// 'unsharp_amount_start + step_size'. If the second scores higher, the search
// continues upward from it; otherwise it continues downward from the start.
// The stepping itself, bounded by 'max_loop_count' and 'max_filter_amount',
// is done by find_best_frame_unsharp_amount_loop_neg, whose result is
// returned.
static double find_best_frame_unsharp_amount_neg(
    AV1_COMP *const cpi, const YV12_BUFFER_CONFIG *const src,
    const YV12_BUFFER_CONFIG *const recon, const YV12_BUFFER_CONFIG *const ref,
    double base_score, const double unsharp_amount_start,
    const double step_size, const int max_loop_count,
    const double max_filter_amount) {
  // No motion vector array is supplied; NULL is passed to every call below.
  FULLPEL_MV *mvs = NULL;
  const double src_variance =
      residual_frame_average_variance(cpi, src, ref, mvs);

  const AV1_COMMON *const cm = &cpi->common;
  const int bit_depth = cpi->td.mb.e_mbd.bd;
  const int width = recon->y_width;
  const int height = recon->y_height;
  const int ss_x = recon->subsampling_x;
  const int ss_y = recon->subsampling_y;

  // Scratch frames, all with the geometry of 'recon'.
  YV12_BUFFER_CONFIG src_blurred;
  YV12_BUFFER_CONFIG recon_blurred;
  YV12_BUFFER_CONFIG src_sharpened;
  YV12_BUFFER_CONFIG recon_sharpened;
  memset(&recon_sharpened, 0, sizeof(recon_sharpened));
  memset(&src_sharpened, 0, sizeof(src_sharpened));
  memset(&recon_blurred, 0, sizeof(recon_blurred));
  memset(&src_blurred, 0, sizeof(src_blurred));
  aom_alloc_frame_buffer(&recon_sharpened, width, height, ss_x, ss_y,
                         cm->seq_params->use_highbitdepth,
                         cpi->oxcf.border_in_pixels,
                         cm->features.byte_alignment, false, 0);
  aom_alloc_frame_buffer(&src_sharpened, width, height, ss_x, ss_y,
                         cm->seq_params->use_highbitdepth,
                         cpi->oxcf.border_in_pixels,
                         cm->features.byte_alignment, false, 0);
  aom_alloc_frame_buffer(&recon_blurred, width, height, ss_x, ss_y,
                         cm->seq_params->use_highbitdepth,
                         cpi->oxcf.border_in_pixels,
                         cm->features.byte_alignment, false, 0);
  aom_alloc_frame_buffer(&src_blurred, width, height, ss_x, ss_y,
                         cm->seq_params->use_highbitdepth,
                         cpi->oxcf.border_in_pixels,
                         cm->features.byte_alignment, false, 0);

  gaussian_blur(bit_depth, recon, &recon_blurred);
  gaussian_blur(bit_depth, src, &src_blurred);

  // Score the starting amount.
  unsharp(cpi, recon, &recon_blurred, &recon_sharpened, unsharp_amount_start);
  unsharp(cpi, src, &src_blurred, &src_sharpened, unsharp_amount_start);
  const double variance_start =
      residual_frame_average_variance(cpi, &src_sharpened, ref, mvs);
  const double score_start = cal_approx_score(
      cpi, src_variance, variance_start, base_score, src, &recon_sharpened);

  // Score one step above the starting amount.
  const double unsharp_amount_next = unsharp_amount_start + step_size;
  unsharp(cpi, recon, &recon_blurred, &recon_sharpened, unsharp_amount_next);
  unsharp(cpi, src, &src_blurred, &src_sharpened, unsharp_amount_next);
  const double variance_next =
      residual_frame_average_variance(cpi, &src_sharpened, ref, mvs);
  const double score_next = cal_approx_score(cpi, src_variance, variance_next,
                                             base_score, src, &recon_sharpened);

  // Choose the search direction from the two probes and hand over to the
  // stepping loop.
  const bool search_upward = score_next > score_start;
  const double seed_score = search_upward ? score_next : score_start;
  const double seed_amount =
      search_upward ? unsharp_amount_next : unsharp_amount_start;
  const double signed_step = search_upward ? step_size : -step_size;
  const double unsharp_amount = find_best_frame_unsharp_amount_loop_neg(
      cpi, src_variance, base_score, src, recon, ref, &src_blurred,
      &recon_blurred, &src_sharpened, &recon_sharpened, mvs, seed_score,
      seed_amount, signed_step, max_loop_count, max_filter_amount);

  aom_free_frame_buffer(&recon_sharpened);
  aom_free_frame_buffer(&src_sharpened);
  aom_free_frame_buffer(&recon_blurred);
  aom_free_frame_buffer(&src_blurred);
  aom_free(mvs);
  return unsharp_amount;
}
1078 
// Records per-layer statistics of the frame that was just coded: the VMAF
// score and the luma SSE between cpi->source and the reconstructed frame are
// stored in cpi->vmaf_info at the frame's (capped) layer depth. When tuning
// for AOM_TUNE_VMAF_NEG_MAX_GAIN, the layer's unsharp amount is additionally
// refreshed by a search seeded with the most recent recorded amount.
void av1_update_vmaf_curve(AV1_COMP *cpi) {
  const YV12_BUFFER_CONFIG *source = cpi->source;
  const YV12_BUFFER_CONFIG *recon = &cpi->common.cur_frame->buf;
  const int bit_depth = cpi->td.mb.e_mbd.bd;
  const GF_GROUP *const gf_group = &cpi->ppi->gf_group;
  const int layer_depth =
      AOMMIN(gf_group->layer_depth[cpi->gf_frame_index], MAX_ARF_LAYERS - 1);
  const bool cal_vmaf_neg =
      cpi->oxcf.tune_cfg.tuning == AOM_TUNE_VMAF_NEG_MAX_GAIN;

  double base_score;
  aom_calc_vmaf(cpi->vmaf_info.vmaf_model, source, recon, bit_depth,
                cal_vmaf_neg, &base_score);
  cpi->vmaf_info.last_frame_vmaf[layer_depth] = base_score;

  double y_sse;
  if (!cpi->common.seq_params->use_highbitdepth) {
    y_sse = (double)aom_get_y_sse(source, recon);
  } else {
    assert(source->flags & YV12_FLAG_HIGHBITDEPTH);
    assert(recon->flags & YV12_FLAG_HIGHBITDEPTH);
    y_sse = (double)aom_highbd_get_y_sse(source, recon);
  }
  cpi->vmaf_info.last_frame_ysse[layer_depth] = y_sse;

  // The unsharp amount is only tracked in NEG max-gain tuning.
  if (!cal_vmaf_neg) return;

  const YV12_BUFFER_CONFIG *last, *next;
  get_neighbor_frames(cpi, &last, &next);
  const double best_unsharp_amount_start =
      get_layer_value(cpi->vmaf_info.last_frame_unsharp_amount, layer_depth);
  const int max_loop_count = 5;
  cpi->vmaf_info.last_frame_unsharp_amount[layer_depth] =
      find_best_frame_unsharp_amount_neg(cpi, source, recon, last, base_score,
                                         best_unsharp_amount_start, 0.025,
                                         max_loop_count, 1.01);
}
1114