1 /*
2  * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include "av1/common/pred_common.h"
13 #include "av1/encoder/compound_type.h"
14 #include "av1/encoder/encoder_alloc.h"
15 #include "av1/encoder/model_rd.h"
16 #include "av1/encoder/motion_search_facade.h"
17 #include "av1/encoder/rdopt_utils.h"
18 #include "av1/encoder/reconinter_enc.h"
19 #include "av1/encoder/tx_search.h"
20 
21 typedef int64_t (*pick_interinter_mask_type)(
22     const AV1_COMP *const cpi, MACROBLOCK *x, const BLOCK_SIZE bsize,
23     const uint8_t *const p0, const uint8_t *const p1,
24     const int16_t *const residual1, const int16_t *const diff10,
25     uint64_t *best_sse);
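// The masked compound search (see masked_compound_type_rd() below) indexes a
// two-entry table of these callbacks by (compound_type - COMPOUND_WEDGE):
//   { pick_interinter_wedge, pick_interinter_seg }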
26 
27 // Checks if characteristics of search match
28 static inline int is_comp_rd_match(const AV1_COMP *const cpi,
29                                    const MACROBLOCK *const x,
30                                    const COMP_RD_STATS *st,
31                                    const MB_MODE_INFO *const mi,
32                                    int32_t *comp_rate, int64_t *comp_dist,
33                                    int32_t *comp_model_rate,
34                                    int64_t *comp_model_dist, int *comp_rs2) {
35   // TODO(ranjit): Ensure that the compound type search always uses the regular
36   // filter, and check if the following check can be removed
37   // Check if the interp filter matches the previous case
38   if (st->filter.as_int != mi->interp_filters.as_int) return 0;
39 
40   const MACROBLOCKD *const xd = &x->e_mbd;
41   // Match MV and reference indices
42   for (int i = 0; i < 2; ++i) {
43     if ((st->ref_frames[i] != mi->ref_frame[i]) ||
44         (st->mv[i].as_int != mi->mv[i].as_int)) {
45       return 0;
46     }
47     const WarpedMotionParams *const wm = &xd->global_motion[mi->ref_frame[i]];
48     if (is_global_mv_block(mi, wm->wmtype) != st->is_global[i]) return 0;
49   }
50 
51   int reuse_data[COMPOUND_TYPES] = { 1, 1, 0, 0 };
52   // For compound wedge, reuse data if newmv search is disabled when NEWMV is
53   // present or if NEWMV is not present in either of the directions
54   if ((!have_newmv_in_inter_mode(mi->mode) &&
55        !have_newmv_in_inter_mode(st->mode)) ||
56       (cpi->sf.inter_sf.disable_interinter_wedge_newmv_search))
57     reuse_data[COMPOUND_WEDGE] = 1;
58   // For compound diffwtd, reuse data if fast search is enabled (no newmv search
59   // when NEWMV is present) or if NEWMV is not present in either of the
60   // directions
61   if (cpi->sf.inter_sf.enable_fast_compound_mode_search ||
62       (!have_newmv_in_inter_mode(mi->mode) &&
63        !have_newmv_in_inter_mode(st->mode)))
64     reuse_data[COMPOUND_DIFFWTD] = 1;
65 
66   // Store the stats for the different compound types
67   for (int comp_type = COMPOUND_AVERAGE; comp_type < COMPOUND_TYPES;
68        comp_type++) {
69     if (reuse_data[comp_type]) {
70       comp_rate[comp_type] = st->rate[comp_type];
71       comp_dist[comp_type] = st->dist[comp_type];
72       comp_model_rate[comp_type] = st->model_rate[comp_type];
73       comp_model_dist[comp_type] = st->model_dist[comp_type];
74       comp_rs2[comp_type] = st->comp_rs2[comp_type];
75     }
76   }
77   return 1;
78 }
79 
80 // Checks if a similar compound type search case was accounted for earlier
81 // If found, returns relevant rd data
82 static inline int find_comp_rd_in_stats(const AV1_COMP *const cpi,
83                                         const MACROBLOCK *x,
84                                         const MB_MODE_INFO *const mbmi,
85                                         int32_t *comp_rate, int64_t *comp_dist,
86                                         int32_t *comp_model_rate,
87                                         int64_t *comp_model_dist, int *comp_rs2,
88                                         int *match_index) {
89   for (int j = 0; j < x->comp_rd_stats_idx; ++j) {
90     if (is_comp_rd_match(cpi, x, &x->comp_rd_stats[j], mbmi, comp_rate,
91                          comp_dist, comp_model_rate, comp_model_dist,
92                          comp_rs2)) {
93       *match_index = j;
94       return 1;
95     }
96   }
97   return 0;  // no match result found
98 }
99 
100 static inline bool enable_wedge_search(
101     MACROBLOCK *const x, const unsigned int disable_wedge_var_thresh) {
102   // Enable wedge search if the source variance is above
103   // the threshold.
104   return x->source_variance > disable_wedge_var_thresh;
105 }
106 
107 static inline bool enable_wedge_interinter_search(MACROBLOCK *const x,
108                                                   const AV1_COMP *const cpi) {
109   return enable_wedge_search(
110              x, cpi->sf.inter_sf.disable_interinter_wedge_var_thresh) &&
111          cpi->oxcf.comp_type_cfg.enable_interinter_wedge;
112 }
113 
114 static inline bool enable_wedge_interintra_search(MACROBLOCK *const x,
115                                                   const AV1_COMP *const cpi) {
116   return enable_wedge_search(
117              x, cpi->sf.inter_sf.disable_interintra_wedge_var_thresh) &&
118          cpi->oxcf.comp_type_cfg.enable_interintra_wedge;
119 }
120 
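// Estimates the wedge sign by comparing the prediction errors of pred0 and
// pred1 over the top-left and bottom-right quadrants of the block.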
121 static int8_t estimate_wedge_sign(const AV1_COMP *cpi, const MACROBLOCK *x,
122                                   const BLOCK_SIZE bsize, const uint8_t *pred0,
123                                   int stride0, const uint8_t *pred1,
124                                   int stride1) {
125   static const BLOCK_SIZE split_qtr[BLOCK_SIZES_ALL] = {
126     //                            4X4
127     BLOCK_INVALID,
128     // 4X8,        8X4,           8X8
129     BLOCK_INVALID, BLOCK_INVALID, BLOCK_4X4,
130     // 8X16,       16X8,          16X16
131     BLOCK_4X8, BLOCK_8X4, BLOCK_8X8,
132     // 16X32,      32X16,         32X32
133     BLOCK_8X16, BLOCK_16X8, BLOCK_16X16,
134     // 32X64,      64X32,         64X64
135     BLOCK_16X32, BLOCK_32X16, BLOCK_32X32,
136     // 64x128,     128x64,        128x128
137     BLOCK_32X64, BLOCK_64X32, BLOCK_64X64,
138     // 4X16,       16X4,          8X32
139     BLOCK_INVALID, BLOCK_INVALID, BLOCK_4X16,
140     // 32X8,       16X64,         64X16
141     BLOCK_16X4, BLOCK_8X32, BLOCK_32X8
142   };
143   const struct macroblock_plane *const p = &x->plane[0];
144   const uint8_t *src = p->src.buf;
145   int src_stride = p->src.stride;
146   const int bw = block_size_wide[bsize];
147   const int bh = block_size_high[bsize];
148   const int bw_by2 = bw >> 1;
149   const int bh_by2 = bh >> 1;
150   uint32_t esq[2][2];
151   int64_t tl, br;
152 
153   const BLOCK_SIZE f_index = split_qtr[bsize];
154   assert(f_index != BLOCK_INVALID);
155 
156   if (is_cur_buf_hbd(&x->e_mbd)) {
157     pred0 = CONVERT_TO_BYTEPTR(pred0);
158     pred1 = CONVERT_TO_BYTEPTR(pred1);
159   }
160 
161   // Residual variance computation over relevant quadrants in order to
162   // find TL + BR, TL = sum(1st,2nd,3rd) quadrants of (pred0 - pred1),
163   // BR = sum(2nd,3rd,4th) quadrants of (pred1 - pred0)
164   // The 2nd and 3rd quadrants cancel out in TL + BR
165   // Hence TL + BR = 1st quadrant of (pred0-pred1) + 4th of (pred1-pred0)
166   // TODO(nithya): Sign estimation assumes 45 degrees (1st and 4th quadrants)
167   // for all codebooks; experiment with other quadrant combinations for
168   // 0, 90 and 135 degrees also.
169   cpi->ppi->fn_ptr[f_index].vf(src, src_stride, pred0, stride0, &esq[0][0]);
170   cpi->ppi->fn_ptr[f_index].vf(src + bh_by2 * src_stride + bw_by2, src_stride,
171                                pred0 + bh_by2 * stride0 + bw_by2, stride0,
172                                &esq[0][1]);
173   cpi->ppi->fn_ptr[f_index].vf(src, src_stride, pred1, stride1, &esq[1][0]);
174   cpi->ppi->fn_ptr[f_index].vf(src + bh_by2 * src_stride + bw_by2, src_stride,
175                                pred1 + bh_by2 * stride1 + bw_by2, stride1,
176                                &esq[1][1]);
177 
178   tl = ((int64_t)esq[0][0]) - ((int64_t)esq[1][0]);
179   br = ((int64_t)esq[1][1]) - ((int64_t)esq[0][1]);
180   return (tl + br > 0);
181 }
182 
183 // Choose the best wedge index and sign
184 static int64_t pick_wedge(const AV1_COMP *const cpi, const MACROBLOCK *const x,
185                           const BLOCK_SIZE bsize, const uint8_t *const p0,
186                           const int16_t *const residual1,
187                           const int16_t *const diff10,
188                           int8_t *const best_wedge_sign,
189                           int8_t *const best_wedge_index, uint64_t *best_sse) {
190   const MACROBLOCKD *const xd = &x->e_mbd;
191   const struct buf_2d *const src = &x->plane[0].src;
192   const int bw = block_size_wide[bsize];
193   const int bh = block_size_high[bsize];
194   const int N = bw * bh;
195   assert(N >= 64);
196   int rate;
197   int64_t dist;
198   int64_t rd, best_rd = INT64_MAX;
199   int8_t wedge_index;
200   int8_t wedge_sign;
201   const int8_t wedge_types = get_wedge_types_lookup(bsize);
202   const uint8_t *mask;
203   uint64_t sse;
204   const int hbd = is_cur_buf_hbd(xd);
205   const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
206 
207   DECLARE_ALIGNED(32, int16_t, residual0[MAX_SB_SQUARE]);  // src - pred0
208 #if CONFIG_AV1_HIGHBITDEPTH
209   if (hbd) {
210     aom_highbd_subtract_block(bh, bw, residual0, bw, src->buf, src->stride,
211                               CONVERT_TO_BYTEPTR(p0), bw);
212   } else {
213     aom_subtract_block(bh, bw, residual0, bw, src->buf, src->stride, p0, bw);
214   }
215 #else
216   (void)hbd;
217   aom_subtract_block(bh, bw, residual0, bw, src->buf, src->stride, p0, bw);
218 #endif
219 
220   int64_t sign_limit = ((int64_t)aom_sum_squares_i16(residual0, N) -
221                         (int64_t)aom_sum_squares_i16(residual1, N)) *
222                        (1 << WEDGE_WEIGHT_BITS) / 2;
223   int16_t *ds = residual0;
224 
225   av1_wedge_compute_delta_squares(ds, residual0, residual1, N);
226 
227   for (wedge_index = 0; wedge_index < wedge_types; ++wedge_index) {
228     mask = av1_get_contiguous_soft_mask(wedge_index, 0, bsize);
229 
230     wedge_sign = av1_wedge_sign_from_residuals(ds, mask, N, sign_limit);
231 
232     mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, bsize);
233     sse = av1_wedge_sse_from_residuals(residual1, diff10, mask, N);
234     sse = ROUND_POWER_OF_TWO(sse, bd_round);
235 
236     model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
237                                                   &rate, &dist);
238     // int rate2;
239     // int64_t dist2;
240     // model_rd_with_curvfit(cpi, x, bsize, 0, sse, N, &rate2, &dist2);
241     // printf("sse %"PRId64": legacy: %d %"PRId64", curvfit %d %"PRId64"\n",
242     // sse, rate, dist, rate2, dist2); dist = dist2;
243     // rate = rate2;
244 
245     rate += x->mode_costs.wedge_idx_cost[bsize][wedge_index];
246     rd = RDCOST(x->rdmult, rate, dist);
247 
248     if (rd < best_rd) {
249       *best_wedge_index = wedge_index;
250       *best_wedge_sign = wedge_sign;
251       best_rd = rd;
252       *best_sse = sse;
253     }
254   }
255 
256   return best_rd -
257          RDCOST(x->rdmult,
258                 x->mode_costs.wedge_idx_cost[bsize][*best_wedge_index], 0);
259 }
260 
261 // Choose the best wedge index for the specified sign
262 static int64_t pick_wedge_fixed_sign(
263     const AV1_COMP *const cpi, const MACROBLOCK *const x,
264     const BLOCK_SIZE bsize, const int16_t *const residual1,
265     const int16_t *const diff10, const int8_t wedge_sign,
266     int8_t *const best_wedge_index, uint64_t *best_sse) {
267   const MACROBLOCKD *const xd = &x->e_mbd;
268 
269   const int bw = block_size_wide[bsize];
270   const int bh = block_size_high[bsize];
271   const int N = bw * bh;
272   assert(N >= 64);
273   int rate;
274   int64_t dist;
275   int64_t rd, best_rd = INT64_MAX;
276   int8_t wedge_index;
277   const int8_t wedge_types = get_wedge_types_lookup(bsize);
278   const uint8_t *mask;
279   uint64_t sse;
280   const int hbd = is_cur_buf_hbd(xd);
281   const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
282   for (wedge_index = 0; wedge_index < wedge_types; ++wedge_index) {
283     mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, bsize);
284     sse = av1_wedge_sse_from_residuals(residual1, diff10, mask, N);
285     sse = ROUND_POWER_OF_TWO(sse, bd_round);
286 
287     model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
288                                                   &rate, &dist);
289     rate += x->mode_costs.wedge_idx_cost[bsize][wedge_index];
290     rd = RDCOST(x->rdmult, rate, dist);
291 
292     if (rd < best_rd) {
293       *best_wedge_index = wedge_index;
294       best_rd = rd;
295       *best_sse = sse;
296     }
297   }
298   return best_rd -
299          RDCOST(x->rdmult,
300                 x->mode_costs.wedge_idx_cost[bsize][*best_wedge_index], 0);
301 }
302 
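// Picks the best wedge index and sign for inter-inter compound prediction and
// stores them in mbmi->interinter_comp.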
303 static int64_t pick_interinter_wedge(
304     const AV1_COMP *const cpi, MACROBLOCK *const x, const BLOCK_SIZE bsize,
305     const uint8_t *const p0, const uint8_t *const p1,
306     const int16_t *const residual1, const int16_t *const diff10,
307     uint64_t *best_sse) {
308   MACROBLOCKD *const xd = &x->e_mbd;
309   MB_MODE_INFO *const mbmi = xd->mi[0];
310   const int bw = block_size_wide[bsize];
311 
312   int64_t rd;
313   int8_t wedge_index = -1;
314   int8_t wedge_sign = 0;
315 
316   assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize));
317   assert(cpi->common.seq_params->enable_masked_compound);
318 
319   if (cpi->sf.inter_sf.fast_wedge_sign_estimate) {
320     wedge_sign = estimate_wedge_sign(cpi, x, bsize, p0, bw, p1, bw);
321     rd = pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, wedge_sign,
322                                &wedge_index, best_sse);
323   } else {
324     rd = pick_wedge(cpi, x, bsize, p0, residual1, diff10, &wedge_sign,
325                     &wedge_index, best_sse);
326   }
327 
328   mbmi->interinter_comp.wedge_sign = wedge_sign;
329   mbmi->interinter_comp.wedge_index = wedge_index;
330   return rd;
331 }
332 
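// Picks the best DIFFWTD mask type for inter-inter compound prediction; the
// winning mask is left in xd->seg_mask and its type in
// mbmi->interinter_comp.mask_type.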
333 static int64_t pick_interinter_seg(const AV1_COMP *const cpi,
334                                    MACROBLOCK *const x, const BLOCK_SIZE bsize,
335                                    const uint8_t *const p0,
336                                    const uint8_t *const p1,
337                                    const int16_t *const residual1,
338                                    const int16_t *const diff10,
339                                    uint64_t *best_sse) {
340   MACROBLOCKD *const xd = &x->e_mbd;
341   MB_MODE_INFO *const mbmi = xd->mi[0];
342   const int bw = block_size_wide[bsize];
343   const int bh = block_size_high[bsize];
344   const int N = 1 << num_pels_log2_lookup[bsize];
345   int rate;
346   int64_t dist;
347   DIFFWTD_MASK_TYPE cur_mask_type;
348   int64_t best_rd = INT64_MAX;
349   DIFFWTD_MASK_TYPE best_mask_type = 0;
350   const int hbd = is_cur_buf_hbd(xd);
351   const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
352   DECLARE_ALIGNED(16, uint8_t, seg_mask[2 * MAX_SB_SQUARE]);
353   uint8_t *tmp_mask[2] = { xd->seg_mask, seg_mask };
354   // try each mask type and its inverse
355   for (cur_mask_type = 0; cur_mask_type < DIFFWTD_MASK_TYPES; cur_mask_type++) {
356     // build mask and inverse
357 #if CONFIG_AV1_HIGHBITDEPTH
358     if (hbd)
359       av1_build_compound_diffwtd_mask_highbd(
360           tmp_mask[cur_mask_type], cur_mask_type, CONVERT_TO_BYTEPTR(p0), bw,
361           CONVERT_TO_BYTEPTR(p1), bw, bh, bw, xd->bd);
362     else
363       av1_build_compound_diffwtd_mask(tmp_mask[cur_mask_type], cur_mask_type,
364                                       p0, bw, p1, bw, bh, bw);
365 #else
366     (void)hbd;
367     av1_build_compound_diffwtd_mask(tmp_mask[cur_mask_type], cur_mask_type, p0,
368                                     bw, p1, bw, bh, bw);
369 #endif  // CONFIG_AV1_HIGHBITDEPTH
370 
371     // compute rd for mask
372     uint64_t sse = av1_wedge_sse_from_residuals(residual1, diff10,
373                                                 tmp_mask[cur_mask_type], N);
374     sse = ROUND_POWER_OF_TWO(sse, bd_round);
375 
376     model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
377                                                   &rate, &dist);
378     const int64_t rd0 = RDCOST(x->rdmult, rate, dist);
379 
380     if (rd0 < best_rd) {
381       best_mask_type = cur_mask_type;
382       best_rd = rd0;
383       *best_sse = sse;
384     }
385   }
386   mbmi->interinter_comp.mask_type = best_mask_type;
387   if (best_mask_type == DIFFWTD_38_INV) {
388     memcpy(xd->seg_mask, seg_mask, N * 2);
389   }
390   return best_rd;
391 }
392 
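// Picks the best wedge index (with fixed sign 0) for inter-intra prediction
// and stores it in mbmi->interintra_wedge_index.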
393 static int64_t pick_interintra_wedge(const AV1_COMP *const cpi,
394                                      const MACROBLOCK *const x,
395                                      const BLOCK_SIZE bsize,
396                                      const uint8_t *const p0,
397                                      const uint8_t *const p1) {
398   const MACROBLOCKD *const xd = &x->e_mbd;
399   MB_MODE_INFO *const mbmi = xd->mi[0];
400   assert(av1_is_wedge_used(bsize));
401   assert(cpi->common.seq_params->enable_interintra_compound);
402 
403   const struct buf_2d *const src = &x->plane[0].src;
404   const int bw = block_size_wide[bsize];
405   const int bh = block_size_high[bsize];
406   DECLARE_ALIGNED(32, int16_t, residual1[MAX_SB_SQUARE]);  // src - pred1
407   DECLARE_ALIGNED(32, int16_t, diff10[MAX_SB_SQUARE]);     // pred1 - pred0
408 #if CONFIG_AV1_HIGHBITDEPTH
409   if (is_cur_buf_hbd(xd)) {
410     aom_highbd_subtract_block(bh, bw, residual1, bw, src->buf, src->stride,
411                               CONVERT_TO_BYTEPTR(p1), bw);
412     aom_highbd_subtract_block(bh, bw, diff10, bw, CONVERT_TO_BYTEPTR(p1), bw,
413                               CONVERT_TO_BYTEPTR(p0), bw);
414   } else {
415     aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, p1, bw);
416     aom_subtract_block(bh, bw, diff10, bw, p1, bw, p0, bw);
417   }
418 #else
419   aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, p1, bw);
420   aom_subtract_block(bh, bw, diff10, bw, p1, bw, p0, bw);
421 #endif
422   int8_t wedge_index = -1;
423   uint64_t sse;
424   int64_t rd = pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, 0,
425                                      &wedge_index, &sse);
426 
427   mbmi->interintra_wedge_index = wedge_index;
428   return rd;
429 }
430 
431 static inline void get_inter_predictors_masked_compound(
432     MACROBLOCK *x, const BLOCK_SIZE bsize, uint8_t **preds0, uint8_t **preds1,
433     int16_t *residual1, int16_t *diff10, int *strides) {
434   MACROBLOCKD *xd = &x->e_mbd;
435   const int bw = block_size_wide[bsize];
436   const int bh = block_size_high[bsize];
437   // get inter predictors to use for masked compound modes
438   av1_build_inter_predictors_for_planes_single_buf(xd, bsize, 0, 0, 0, preds0,
439                                                    strides);
440   av1_build_inter_predictors_for_planes_single_buf(xd, bsize, 0, 0, 1, preds1,
441                                                    strides);
442   const struct buf_2d *const src = &x->plane[0].src;
443 #if CONFIG_AV1_HIGHBITDEPTH
444   if (is_cur_buf_hbd(xd)) {
445     aom_highbd_subtract_block(bh, bw, residual1, bw, src->buf, src->stride,
446                               CONVERT_TO_BYTEPTR(*preds1), bw);
447     aom_highbd_subtract_block(bh, bw, diff10, bw, CONVERT_TO_BYTEPTR(*preds1),
448                               bw, CONVERT_TO_BYTEPTR(*preds0), bw);
449   } else {
450     aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, *preds1,
451                        bw);
452     aom_subtract_block(bh, bw, diff10, bw, *preds1, bw, *preds0, bw);
453   }
454 #else
455   aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, *preds1, bw);
456   aom_subtract_block(bh, bw, diff10, bw, *preds1, bw, *preds0, bw);
457 #endif
458 }
459 
460 // Computes the rd cost for the given interintra mode; updates the best mode and rd.
461 static inline void compute_best_interintra_mode(
462     const AV1_COMP *const cpi, MB_MODE_INFO *mbmi, MACROBLOCKD *xd,
463     MACROBLOCK *const x, const int *const interintra_mode_cost,
464     const BUFFER_SET *orig_dst, uint8_t *intrapred, const uint8_t *tmp_buf,
465     INTERINTRA_MODE *best_interintra_mode, int64_t *best_interintra_rd,
466     INTERINTRA_MODE interintra_mode, BLOCK_SIZE bsize) {
467   const AV1_COMMON *const cm = &cpi->common;
468   int rate;
469   uint8_t skip_txfm_sb;
470   int64_t dist, skip_sse_sb;
471   const int bw = block_size_wide[bsize];
472   mbmi->interintra_mode = interintra_mode;
473   int rmode = interintra_mode_cost[interintra_mode];
474   av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
475                                             intrapred, bw);
476   av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
477   model_rd_sb_fn[MODELRD_TYPE_INTERINTRA](cpi, bsize, x, xd, 0, 0, &rate, &dist,
478                                           &skip_txfm_sb, &skip_sse_sb, NULL,
479                                           NULL, NULL);
480   int64_t rd = RDCOST(x->rdmult, rate + rmode, dist);
481   if (rd < *best_interintra_rd) {
482     *best_interintra_rd = rd;
483     *best_interintra_mode = mbmi->interintra_mode;
484   }
485 }
486 
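// Estimates the luma rd of the block using a fast transform-size search and
// adds the cost of signaling the skip_txfm flag.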
487 static int64_t estimate_yrd_for_sb(const AV1_COMP *const cpi, BLOCK_SIZE bs,
488                                    MACROBLOCK *x, int64_t ref_best_rd,
489                                    RD_STATS *rd_stats) {
490   MACROBLOCKD *const xd = &x->e_mbd;
491   if (ref_best_rd < 0) return INT64_MAX;
492   av1_subtract_plane(x, bs, 0);
493   const int64_t rd = av1_estimate_txfm_yrd(cpi, x, rd_stats, ref_best_rd, bs,
494                                            max_txsize_rect_lookup[bs]);
495   if (rd != INT64_MAX) {
496     const int skip_ctx = av1_get_skip_txfm_context(xd);
497     if (rd_stats->skip_txfm) {
498       const int s1 = x->mode_costs.skip_txfm_cost[skip_ctx][1];
499       rd_stats->rate = s1;
500     } else {
501       const int s0 = x->mode_costs.skip_txfm_cost[skip_ctx][0];
502       rd_stats->rate += s0;
503     }
504   }
505   return rd;
506 }
507 
508 // Computes the rd_threshold for smooth interintra rd search.
509 static inline int64_t compute_rd_thresh(MACROBLOCK *const x,
510                                         int total_mode_rate,
511                                         int64_t ref_best_rd) {
512   const int64_t rd_thresh = get_rd_thresh_from_best_rd(
513       ref_best_rd, (1 << INTER_INTRA_RD_THRESH_SHIFT),
514       INTER_INTRA_RD_THRESH_SCALE);
515   const int64_t mode_rd = RDCOST(x->rdmult, total_mode_rate, 0);
516   return (rd_thresh - mode_rd);
517 }
518 
519 // Computes the best wedge interintra mode
520 static inline int64_t compute_best_wedge_interintra(
521     const AV1_COMP *const cpi, MB_MODE_INFO *mbmi, MACROBLOCKD *xd,
522     MACROBLOCK *const x, const int *const interintra_mode_cost,
523     const BUFFER_SET *orig_dst, uint8_t *intrapred_, uint8_t *tmp_buf_,
524     int *best_mode, int *best_wedge_index, BLOCK_SIZE bsize) {
525   const AV1_COMMON *const cm = &cpi->common;
526   const int bw = block_size_wide[bsize];
527   int64_t best_interintra_rd_wedge = INT64_MAX;
528   int64_t best_total_rd = INT64_MAX;
529   uint8_t *intrapred = get_buf_by_bd(xd, intrapred_);
530   for (INTERINTRA_MODE mode = 0; mode < INTERINTRA_MODES; ++mode) {
531     mbmi->interintra_mode = mode;
532     av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
533                                               intrapred, bw);
534     int64_t rd = pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
535     const int rate_overhead =
536         interintra_mode_cost[mode] +
537         x->mode_costs.wedge_idx_cost[bsize][mbmi->interintra_wedge_index];
538     const int64_t total_rd = rd + RDCOST(x->rdmult, rate_overhead, 0);
539     if (total_rd < best_total_rd) {
540       best_total_rd = total_rd;
541       best_interintra_rd_wedge = rd;
542       *best_mode = mbmi->interintra_mode;
543       *best_wedge_index = mbmi->interintra_wedge_index;
544     }
545   }
546   return best_interintra_rd_wedge;
547 }
548 
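// Evaluates the non-wedge (smooth) inter-intra modes, updates the best mode,
// rd and mode rate, and returns IGNORE_MODE when the result is not promising.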
549 static int handle_smooth_inter_intra_mode(
550     const AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize,
551     MB_MODE_INFO *mbmi, int64_t ref_best_rd, int *rate_mv,
552     INTERINTRA_MODE *best_interintra_mode, int64_t *best_rd,
553     int *best_mode_rate, const BUFFER_SET *orig_dst, uint8_t *tmp_buf,
554     uint8_t *intrapred, HandleInterModeArgs *args) {
555   MACROBLOCKD *xd = &x->e_mbd;
556   const ModeCosts *mode_costs = &x->mode_costs;
557   const int *const interintra_mode_cost =
558       mode_costs->interintra_mode_cost[size_group_lookup[bsize]];
559   const AV1_COMMON *const cm = &cpi->common;
560   const int bw = block_size_wide[bsize];
561 
562   mbmi->use_wedge_interintra = 0;
563 
564   if (cpi->sf.inter_sf.reuse_inter_intra_mode == 0 ||
565       *best_interintra_mode == INTERINTRA_MODES) {
566     int64_t best_interintra_rd = INT64_MAX;
567     for (INTERINTRA_MODE cur_mode = 0; cur_mode < INTERINTRA_MODES;
568          ++cur_mode) {
569       if ((!cpi->oxcf.intra_mode_cfg.enable_smooth_intra ||
570            cpi->sf.intra_sf.disable_smooth_intra) &&
571           cur_mode == II_SMOOTH_PRED)
572         continue;
573       compute_best_interintra_mode(
574           cpi, mbmi, xd, x, interintra_mode_cost, orig_dst, intrapred, tmp_buf,
575           best_interintra_mode, &best_interintra_rd, cur_mode, bsize);
576     }
577     args->inter_intra_mode[mbmi->ref_frame[0]] = *best_interintra_mode;
578   }
579   assert(IMPLIES(!cpi->oxcf.comp_type_cfg.enable_smooth_interintra,
580                  *best_interintra_mode != II_SMOOTH_PRED));
581   // Recompute prediction if required
582   bool interintra_mode_reuse = cpi->sf.inter_sf.reuse_inter_intra_mode ||
583                                *best_interintra_mode != INTERINTRA_MODES;
584   if (interintra_mode_reuse || *best_interintra_mode != INTERINTRA_MODES - 1) {
585     mbmi->interintra_mode = *best_interintra_mode;
586     av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
587                                               intrapred, bw);
588     av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
589   }
590 
591   // Compute rd cost for best smooth_interintra
592   RD_STATS rd_stats;
593   const int is_wedge_used = av1_is_wedge_used(bsize);
594   const int rmode =
595       interintra_mode_cost[*best_interintra_mode] +
596       (is_wedge_used ? mode_costs->wedge_interintra_cost[bsize][0] : 0);
597   const int total_mode_rate = rmode + *rate_mv;
598   const int64_t rd_thresh = compute_rd_thresh(x, total_mode_rate, ref_best_rd);
599   int64_t rd = estimate_yrd_for_sb(cpi, bsize, x, rd_thresh, &rd_stats);
600   if (rd != INT64_MAX) {
601     rd = RDCOST(x->rdmult, total_mode_rate + rd_stats.rate, rd_stats.dist);
602   } else {
603     return IGNORE_MODE;
604   }
605   *best_rd = rd;
606   *best_mode_rate = rmode;
607   // Return early if best rd not good enough
608   if (ref_best_rd < INT64_MAX &&
609       (*best_rd >> INTER_INTRA_RD_THRESH_SHIFT) * INTER_INTRA_RD_THRESH_SCALE >
610           ref_best_rd) {
611     return IGNORE_MODE;
612   }
613   return 0;
614 }
615 
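// Evaluates wedge inter-intra: picks the wedge mask and mode, optionally
// refines the MV for NEWMV modes, and returns IGNORE_MODE when no usable rd
// estimate is found.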
616 static int handle_wedge_inter_intra_mode(
617     const AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize,
618     MB_MODE_INFO *mbmi, int *rate_mv, INTERINTRA_MODE *best_interintra_mode,
619     int64_t *best_rd, const BUFFER_SET *orig_dst, uint8_t *tmp_buf_,
620     uint8_t *tmp_buf, uint8_t *intrapred_, uint8_t *intrapred,
621     HandleInterModeArgs *args, int *tmp_rate_mv, int *rate_overhead,
622     int_mv *tmp_mv, int64_t best_rd_no_wedge) {
623   MACROBLOCKD *xd = &x->e_mbd;
624   const ModeCosts *mode_costs = &x->mode_costs;
625   const int *const interintra_mode_cost =
626       mode_costs->interintra_mode_cost[size_group_lookup[bsize]];
627   const AV1_COMMON *const cm = &cpi->common;
628   const int bw = block_size_wide[bsize];
629   const int try_smooth_interintra =
630       cpi->oxcf.comp_type_cfg.enable_smooth_interintra;
631 
632   mbmi->use_wedge_interintra = 1;
633 
634   if (!cpi->sf.inter_sf.fast_interintra_wedge_search) {
635     // Exhaustive search of all wedge and mode combinations.
636     int best_mode = 0;
637     int best_wedge_index = 0;
638     *best_rd = compute_best_wedge_interintra(
639         cpi, mbmi, xd, x, interintra_mode_cost, orig_dst, intrapred_, tmp_buf_,
640         &best_mode, &best_wedge_index, bsize);
641     mbmi->interintra_mode = best_mode;
642     mbmi->interintra_wedge_index = best_wedge_index;
643     if (best_mode != INTERINTRA_MODES - 1) {
644       av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
645                                                 intrapred, bw);
646     }
647   } else if (!try_smooth_interintra) {
648     if (*best_interintra_mode == INTERINTRA_MODES) {
649       mbmi->interintra_mode = INTERINTRA_MODES - 1;
650       *best_interintra_mode = INTERINTRA_MODES - 1;
651       av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
652                                                 intrapred, bw);
653       // Pick wedge mask based on INTERINTRA_MODES - 1
654       *best_rd = pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
655       // Find the best interintra mode for the chosen wedge mask
656       for (INTERINTRA_MODE cur_mode = 0; cur_mode < INTERINTRA_MODES;
657            ++cur_mode) {
658         compute_best_interintra_mode(
659             cpi, mbmi, xd, x, interintra_mode_cost, orig_dst, intrapred,
660             tmp_buf, best_interintra_mode, best_rd, cur_mode, bsize);
661       }
662       args->inter_intra_mode[mbmi->ref_frame[0]] = *best_interintra_mode;
663       mbmi->interintra_mode = *best_interintra_mode;
664 
665       // Recompute prediction if required
666       if (*best_interintra_mode != INTERINTRA_MODES - 1) {
667         av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
668                                                   intrapred, bw);
669       }
670     } else {
671       // Pick wedge mask for the best interintra mode (reused)
672       mbmi->interintra_mode = *best_interintra_mode;
673       av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
674                                                 intrapred, bw);
675       *best_rd = pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
676     }
677   } else {
678     // Pick wedge mask for the best interintra mode from smooth_interintra
679     *best_rd = pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
680   }
681 
682   *rate_overhead =
683       interintra_mode_cost[mbmi->interintra_mode] +
684       mode_costs->wedge_idx_cost[bsize][mbmi->interintra_wedge_index] +
685       mode_costs->wedge_interintra_cost[bsize][1];
686   *best_rd += RDCOST(x->rdmult, *rate_overhead + *rate_mv, 0);
687 
688   int64_t rd = INT64_MAX;
689   const int_mv mv0 = mbmi->mv[0];
690   // Refine motion vector for NEWMV case.
691   if (have_newmv_in_inter_mode(mbmi->mode)) {
692     int rate_sum;
693     uint8_t skip_txfm_sb;
694     int64_t dist_sum, skip_sse_sb;
695     // get negative of mask
696     const uint8_t *mask =
697         av1_get_contiguous_soft_mask(mbmi->interintra_wedge_index, 1, bsize);
698     av1_compound_single_motion_search(cpi, x, bsize, &tmp_mv->as_mv, intrapred,
699                                       mask, bw, tmp_rate_mv, 0);
700     if (mbmi->mv[0].as_int != tmp_mv->as_int) {
701       mbmi->mv[0].as_int = tmp_mv->as_int;
702       // Set ref_frame[1] to NONE_FRAME temporarily so that the intra
703       // predictor is not calculated again in av1_enc_build_inter_predictor().
704       mbmi->ref_frame[1] = NONE_FRAME;
705       const int mi_row = xd->mi_row;
706       const int mi_col = xd->mi_col;
707       av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
708                                     AOM_PLANE_Y, AOM_PLANE_Y);
709       mbmi->ref_frame[1] = INTRA_FRAME;
710       av1_combine_interintra(xd, bsize, 0, xd->plane[AOM_PLANE_Y].dst.buf,
711                              xd->plane[AOM_PLANE_Y].dst.stride, intrapred, bw);
712       model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
713           cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum, &skip_txfm_sb,
714           &skip_sse_sb, NULL, NULL, NULL);
715       rd =
716           RDCOST(x->rdmult, *tmp_rate_mv + *rate_overhead + rate_sum, dist_sum);
717     }
718   }
719   if (rd >= *best_rd) {
720     tmp_mv->as_int = mv0.as_int;
721     *tmp_rate_mv = *rate_mv;
722     av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
723   }
724   // Evaluate closer to true rd
725   RD_STATS rd_stats;
726   const int64_t mode_rd = RDCOST(x->rdmult, *rate_overhead + *tmp_rate_mv, 0);
727   const int64_t tmp_rd_thresh = best_rd_no_wedge - mode_rd;
728   rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &rd_stats);
729   if (rd != INT64_MAX) {
730     rd = RDCOST(x->rdmult, *rate_overhead + *tmp_rate_mv + rd_stats.rate,
731                 rd_stats.dist);
732   } else {
733     if (*best_rd == INT64_MAX) return IGNORE_MODE;
734   }
735   *best_rd = rd;
736   return 0;
737 }
738 
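// Top-level inter-intra search: evaluates smooth and wedge inter-intra,
// updates mbmi, *rate_mv and *tmp_rate2 with the winner, and returns
// IGNORE_MODE when the search does not yield a usable mode.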
739 int av1_handle_inter_intra_mode(const AV1_COMP *const cpi, MACROBLOCK *const x,
740                                 BLOCK_SIZE bsize, MB_MODE_INFO *mbmi,
741                                 HandleInterModeArgs *args, int64_t ref_best_rd,
742                                 int *rate_mv, int *tmp_rate2,
743                                 const BUFFER_SET *orig_dst) {
744   const int try_smooth_interintra =
745       cpi->oxcf.comp_type_cfg.enable_smooth_interintra;
746 
747   const int is_wedge_used = av1_is_wedge_used(bsize);
748   const int try_wedge_interintra =
749       is_wedge_used && enable_wedge_interintra_search(x, cpi);
750 
751   const AV1_COMMON *const cm = &cpi->common;
752   MACROBLOCKD *xd = &x->e_mbd;
753   const int bw = block_size_wide[bsize];
754   DECLARE_ALIGNED(16, uint8_t, tmp_buf_[2 * MAX_INTERINTRA_SB_SQUARE]);
755   DECLARE_ALIGNED(16, uint8_t, intrapred_[2 * MAX_INTERINTRA_SB_SQUARE]);
756   uint8_t *tmp_buf = get_buf_by_bd(xd, tmp_buf_);
757   uint8_t *intrapred = get_buf_by_bd(xd, intrapred_);
758   const int mi_row = xd->mi_row;
759   const int mi_col = xd->mi_col;
760 
761   // Single reference inter prediction
762   mbmi->ref_frame[1] = NONE_FRAME;
763   xd->plane[0].dst.buf = tmp_buf;
764   xd->plane[0].dst.stride = bw;
765   av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, NULL, bsize,
766                                 AOM_PLANE_Y, AOM_PLANE_Y);
767   const int num_planes = av1_num_planes(cm);
768 
769   // Restore the buffers for intra prediction
770   restore_dst_buf(xd, *orig_dst, num_planes);
771   mbmi->ref_frame[1] = INTRA_FRAME;
772   INTERINTRA_MODE best_interintra_mode =
773       args->inter_intra_mode[mbmi->ref_frame[0]];
774 
775   // Compute smooth_interintra
776   int64_t best_interintra_rd_nowedge = INT64_MAX;
777   int best_mode_rate = INT_MAX;
778   if (try_smooth_interintra) {
779     int ret = handle_smooth_inter_intra_mode(
780         cpi, x, bsize, mbmi, ref_best_rd, rate_mv, &best_interintra_mode,
781         &best_interintra_rd_nowedge, &best_mode_rate, orig_dst, tmp_buf,
782         intrapred, args);
783     if (ret == IGNORE_MODE) {
784       return IGNORE_MODE;
785     }
786   }
787 
788   // Compute wedge interintra
789   int64_t best_interintra_rd_wedge = INT64_MAX;
790   const int_mv mv0 = mbmi->mv[0];
791   int_mv tmp_mv = mv0;
792   int tmp_rate_mv = 0;
793   int rate_overhead = 0;
794   if (try_wedge_interintra) {
795     int ret = handle_wedge_inter_intra_mode(
796         cpi, x, bsize, mbmi, rate_mv, &best_interintra_mode,
797         &best_interintra_rd_wedge, orig_dst, tmp_buf_, tmp_buf, intrapred_,
798         intrapred, args, &tmp_rate_mv, &rate_overhead, &tmp_mv,
799         best_interintra_rd_nowedge);
800     if (ret == IGNORE_MODE) {
801       return IGNORE_MODE;
802     }
803   }
804 
805   if (best_interintra_rd_nowedge == INT64_MAX &&
806       best_interintra_rd_wedge == INT64_MAX) {
807     return IGNORE_MODE;
808   }
809   if (best_interintra_rd_wedge < best_interintra_rd_nowedge) {
810     mbmi->mv[0].as_int = tmp_mv.as_int;
811     *tmp_rate2 += tmp_rate_mv - *rate_mv;
812     *rate_mv = tmp_rate_mv;
813     best_mode_rate = rate_overhead;
814   } else if (try_smooth_interintra && try_wedge_interintra) {
815     // If smooth was best, but we overwrote the values when evaluating the
816     // wedge mode, we need to recompute the smooth values.
817     mbmi->use_wedge_interintra = 0;
818     mbmi->interintra_mode = best_interintra_mode;
819     mbmi->mv[0].as_int = mv0.as_int;
820     av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
821                                   AOM_PLANE_Y, AOM_PLANE_Y);
822   }
823   *tmp_rate2 += best_mode_rate;
824 
825   if (num_planes > 1) {
826     av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
827                                   AOM_PLANE_U, num_planes - 1);
828   }
829   return 0;
830 }
831 
832 // Computes the valid compound_types to be evaluated
833 static inline int compute_valid_comp_types(MACROBLOCK *x,
834                                            const AV1_COMP *const cpi,
835                                            BLOCK_SIZE bsize,
836                                            int masked_compound_used,
837                                            int mode_search_mask,
838                                            COMPOUND_TYPE *valid_comp_types) {
839   const AV1_COMMON *cm = &cpi->common;
840   int valid_type_count = 0;
841   int comp_type, valid_check;
842   int8_t enable_masked_type[MASKED_COMPOUND_TYPES] = { 0, 0 };
843 
844   const int try_average_comp = (mode_search_mask & (1 << COMPOUND_AVERAGE));
845   const int try_distwtd_comp =
846       ((mode_search_mask & (1 << COMPOUND_DISTWTD)) &&
847        cm->seq_params->order_hint_info.enable_dist_wtd_comp == 1 &&
848        cpi->sf.inter_sf.use_dist_wtd_comp_flag != DIST_WTD_COMP_DISABLED);
849 
850   // Check if COMPOUND_AVERAGE and COMPOUND_DISTWTD are valid cases
851   for (comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
852        comp_type++) {
853     valid_check =
854         (comp_type == COMPOUND_AVERAGE) ? try_average_comp : try_distwtd_comp;
855     if (valid_check && is_interinter_compound_used(comp_type, bsize))
856       valid_comp_types[valid_type_count++] = comp_type;
857   }
858   // Check if COMPOUND_WEDGE and COMPOUND_DIFFWTD are valid cases
859   if (masked_compound_used) {
860     // enable_masked_type[0] corresponds to COMPOUND_WEDGE
861     // enable_masked_type[1] corresponds to COMPOUND_DIFFWTD
862     enable_masked_type[0] = enable_wedge_interinter_search(x, cpi);
863     enable_masked_type[1] = cpi->oxcf.comp_type_cfg.enable_diff_wtd_comp;
864     for (comp_type = COMPOUND_WEDGE; comp_type <= COMPOUND_DIFFWTD;
865          comp_type++) {
866       if ((mode_search_mask & (1 << comp_type)) &&
867           is_interinter_compound_used(comp_type, bsize) &&
868           enable_masked_type[comp_type - COMPOUND_WEDGE])
869         valid_comp_types[valid_type_count++] = comp_type;
870     }
871   }
872   return valid_type_count;
873 }
874 
875 // Calculates the cost for compound type mask
876 static inline void calc_masked_type_cost(
877     const ModeCosts *mode_costs, BLOCK_SIZE bsize, int comp_group_idx_ctx,
878     int comp_index_ctx, int masked_compound_used, int *masked_type_cost) {
879   av1_zero_array(masked_type_cost, COMPOUND_TYPES);
880   // Account for group index cost when wedge and/or diffwtd prediction are
881   // enabled
882   if (masked_compound_used) {
883     // Compound group index of average and distwtd is 0
884     // Compound group index of wedge and diffwtd is 1
885     masked_type_cost[COMPOUND_AVERAGE] +=
886         mode_costs->comp_group_idx_cost[comp_group_idx_ctx][0];
887     masked_type_cost[COMPOUND_DISTWTD] += masked_type_cost[COMPOUND_AVERAGE];
888     masked_type_cost[COMPOUND_WEDGE] +=
889         mode_costs->comp_group_idx_cost[comp_group_idx_ctx][1];
890     masked_type_cost[COMPOUND_DIFFWTD] += masked_type_cost[COMPOUND_WEDGE];
891   }
892 
893   // Compute the cost to signal compound index/type
894   masked_type_cost[COMPOUND_AVERAGE] +=
895       mode_costs->comp_idx_cost[comp_index_ctx][1];
896   masked_type_cost[COMPOUND_DISTWTD] +=
897       mode_costs->comp_idx_cost[comp_index_ctx][0];
898   masked_type_cost[COMPOUND_WEDGE] += mode_costs->compound_type_cost[bsize][0];
899   masked_type_cost[COMPOUND_DIFFWTD] +=
900       mode_costs->compound_type_cost[bsize][1];
901 }
902 
903 // Updates mbmi structure with the relevant compound type info
904 static inline void update_mbmi_for_compound_type(MB_MODE_INFO *mbmi,
905                                                  COMPOUND_TYPE cur_type) {
906   mbmi->interinter_comp.type = cur_type;
907   mbmi->comp_group_idx = (cur_type >= COMPOUND_WEDGE);
908   mbmi->compound_idx = (cur_type != COMPOUND_DISTWTD);
909 }
910 
911 // When a match is found, populates the compound type data,
912 // calculates the rd cost using the stored stats, and
913 // updates the mbmi appropriately.
914 static inline int populate_reuse_comp_type_data(
915     const MACROBLOCK *x, MB_MODE_INFO *mbmi,
916     BEST_COMP_TYPE_STATS *best_type_stats, int_mv *cur_mv, int32_t *comp_rate,
917     int64_t *comp_dist, int *comp_rs2, int *rate_mv, int64_t *rd,
918     int match_index) {
919   const int winner_comp_type =
920       x->comp_rd_stats[match_index].interinter_comp.type;
921   if (comp_rate[winner_comp_type] == INT_MAX)
922     return best_type_stats->best_compmode_interinter_cost;
923   update_mbmi_for_compound_type(mbmi, winner_comp_type);
924   mbmi->interinter_comp = x->comp_rd_stats[match_index].interinter_comp;
925   *rd = RDCOST(
926       x->rdmult,
927       comp_rs2[winner_comp_type] + *rate_mv + comp_rate[winner_comp_type],
928       comp_dist[winner_comp_type]);
929   mbmi->mv[0].as_int = cur_mv[0].as_int;
930   mbmi->mv[1].as_int = cur_mv[1].as_int;
931   return comp_rs2[winner_comp_type];
932 }
933 
934 // Updates rd cost and relevant compound type data for the best compound type
935 static inline void update_best_info(const MB_MODE_INFO *const mbmi, int64_t *rd,
936                                     BEST_COMP_TYPE_STATS *best_type_stats,
937                                     int64_t best_rd_cur,
938                                     int64_t comp_model_rd_cur, int rs2) {
939   *rd = best_rd_cur;
940   best_type_stats->comp_best_model_rd = comp_model_rd_cur;
941   best_type_stats->best_compound_data = mbmi->interinter_comp;
942   best_type_stats->best_compmode_interinter_cost = rs2;
943 }
944 
945 // Updates best_mv for masked compound types
946 static inline void update_mask_best_mv(const MB_MODE_INFO *const mbmi,
947                                        int_mv *best_mv, int *best_tmp_rate_mv,
948                                        int tmp_rate_mv) {
949   *best_tmp_rate_mv = tmp_rate_mv;
950   best_mv[0].as_int = mbmi->mv[0].as_int;
951   best_mv[1].as_int = mbmi->mv[1].as_int;
952 }
953 
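// Saves the stats of the current compound type search into x->comp_rd_stats so
// that later searches with matching characteristics can reuse them (see
// is_comp_rd_match()).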
954 static inline void save_comp_rd_search_stat(
955     MACROBLOCK *x, const MB_MODE_INFO *const mbmi, const int32_t *comp_rate,
956     const int64_t *comp_dist, const int32_t *comp_model_rate,
957     const int64_t *comp_model_dist, const int_mv *cur_mv, const int *comp_rs2) {
958   const int offset = x->comp_rd_stats_idx;
959   if (offset < MAX_COMP_RD_STATS) {
960     COMP_RD_STATS *const rd_stats = x->comp_rd_stats + offset;
961     memcpy(rd_stats->rate, comp_rate, sizeof(rd_stats->rate));
962     memcpy(rd_stats->dist, comp_dist, sizeof(rd_stats->dist));
963     memcpy(rd_stats->model_rate, comp_model_rate, sizeof(rd_stats->model_rate));
964     memcpy(rd_stats->model_dist, comp_model_dist, sizeof(rd_stats->model_dist));
965     memcpy(rd_stats->comp_rs2, comp_rs2, sizeof(rd_stats->comp_rs2));
966     memcpy(rd_stats->mv, cur_mv, sizeof(rd_stats->mv));
967     memcpy(rd_stats->ref_frames, mbmi->ref_frame, sizeof(rd_stats->ref_frames));
968     rd_stats->mode = mbmi->mode;
969     rd_stats->filter = mbmi->interp_filters;
970     rd_stats->ref_mv_idx = mbmi->ref_mv_idx;
971     const MACROBLOCKD *const xd = &x->e_mbd;
972     for (int i = 0; i < 2; ++i) {
973       const WarpedMotionParams *const wm =
974           &xd->global_motion[mbmi->ref_frame[i]];
975       rd_stats->is_global[i] = is_global_mv_block(mbmi, wm->wmtype);
976     }
977     memcpy(&rd_stats->interinter_comp, &mbmi->interinter_comp,
978            sizeof(rd_stats->interinter_comp));
979     ++x->comp_rd_stats_idx;
980   }
981 }
982 
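// Returns the rate needed to signal the compound mask: wedge sign/index for
// COMPOUND_WEDGE, mask type for COMPOUND_DIFFWTD.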
983 static inline int get_interinter_compound_mask_rate(
984     const ModeCosts *const mode_costs, const MB_MODE_INFO *const mbmi) {
985   const COMPOUND_TYPE compound_type = mbmi->interinter_comp.type;
986   // This function will be called only for COMPOUND_WEDGE and COMPOUND_DIFFWTD
987   if (compound_type == COMPOUND_WEDGE) {
988     return av1_is_wedge_used(mbmi->bsize)
989                ? av1_cost_literal(1) +
990                      mode_costs
991                          ->wedge_idx_cost[mbmi->bsize]
992                                          [mbmi->interinter_comp.wedge_index]
993                : 0;
994   } else {
995     assert(compound_type == COMPOUND_DIFFWTD);
996     return av1_cost_literal(1);
997   }
998 }
999 
1000 // Takes a backup of rate, distortion and model_rd for future reuse
1001 static inline void backup_stats(COMPOUND_TYPE cur_type, int32_t *comp_rate,
1002                                 int64_t *comp_dist, int32_t *comp_model_rate,
1003                                 int64_t *comp_model_dist, int rate_sum,
1004                                 int64_t dist_sum, RD_STATS *rd_stats,
1005                                 int *comp_rs2, int rs2) {
1006   comp_rate[cur_type] = rd_stats->rate;
1007   comp_dist[cur_type] = rd_stats->dist;
1008   comp_model_rate[cur_type] = rate_sum;
1009   comp_model_dist[cur_type] = dist_sum;
1010   comp_rs2[cur_type] = rs2;
1011 }
1012 
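// Returns 1 if the mask search results should be saved for reuse: always for
// NEW_NEWMV, otherwise only when reuse_level is nonzero.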
1013 static inline int save_mask_search_results(const PREDICTION_MODE this_mode,
1014                                            const int reuse_level) {
1015   if (reuse_level || (this_mode == NEW_NEWMV))
1016     return 1;
1017   else
1018     return 0;
1019 }
1020 
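// Decides, based on the skip rd of the Y plane, whether the transform search
// is worth evaluating for this mode when txfm rd gating is enabled.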
1021 static inline int prune_mode_by_skip_rd(const AV1_COMP *const cpi,
1022                                         MACROBLOCK *x, MACROBLOCKD *xd,
1023                                         const BLOCK_SIZE bsize,
1024                                         int64_t ref_skip_rd, int mode_rate) {
1025   int eval_txfm = 1;
1026   const int txfm_rd_gate_level =
1027       get_txfm_rd_gate_level(cpi->common.seq_params->enable_masked_compound,
1028                              cpi->sf.inter_sf.txfm_rd_gate_level, bsize,
1029                              TX_SEARCH_COMP_TYPE_MODE, /*eval_motion_mode=*/0);
1030   // Check if the mode is good enough based on skip rd
1031   if (txfm_rd_gate_level) {
1032     int64_t sse_y = compute_sse_plane(x, xd, PLANE_TYPE_Y, bsize);
1033     int64_t skip_rd = RDCOST(x->rdmult, mode_rate, (sse_y << 4));
1034     eval_txfm =
1035         check_txfm_eval(x, bsize, ref_skip_rd, skip_rd, txfm_rd_gate_level, 1);
1036   }
1037   return eval_txfm;
1038 }
1039 
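// Computes the rd cost of a masked compound type (COMPOUND_WEDGE or
// COMPOUND_DIFFWTD): selects the mask, optionally re-searches the MV for
// wedge with NEWMV modes, and prunes using model rd and skip rd thresholds.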
1040 static int64_t masked_compound_type_rd(
1041     const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
1042     const BLOCK_SIZE bsize, const PREDICTION_MODE this_mode, int *rs2,
1043     int rate_mv, const BUFFER_SET *ctx, int *out_rate_mv, uint8_t **preds0,
1044     uint8_t **preds1, int16_t *residual1, int16_t *diff10, int *strides,
1045     int mode_rate, int64_t rd_thresh, int *calc_pred_masked_compound,
1046     int32_t *comp_rate, int64_t *comp_dist, int32_t *comp_model_rate,
1047     int64_t *comp_model_dist, const int64_t comp_best_model_rd,
1048     int64_t *const comp_model_rd_cur, int *comp_rs2, int64_t ref_skip_rd) {
1049   const AV1_COMMON *const cm = &cpi->common;
1050   MACROBLOCKD *xd = &x->e_mbd;
1051   MB_MODE_INFO *const mbmi = xd->mi[0];
1052   int64_t best_rd_cur = INT64_MAX;
1053   int64_t rd = INT64_MAX;
1054   const COMPOUND_TYPE compound_type = mbmi->interinter_comp.type;
1055   // This function will be called only for COMPOUND_WEDGE and COMPOUND_DIFFWTD
1056   assert(compound_type == COMPOUND_WEDGE || compound_type == COMPOUND_DIFFWTD);
1057   int rate_sum;
1058   uint8_t tmp_skip_txfm_sb;
1059   int64_t dist_sum, tmp_skip_sse_sb;
1060   pick_interinter_mask_type pick_interinter_mask[2] = { pick_interinter_wedge,
1061                                                         pick_interinter_seg };
1062 
1063   // TODO(any): Save pred and mask calculation as well into records. However
1064   // this may increase memory requirements as compound segment mask needs to be
1065   // stored in each record.
1066   if (*calc_pred_masked_compound) {
1067     get_inter_predictors_masked_compound(x, bsize, preds0, preds1, residual1,
1068                                          diff10, strides);
1069     *calc_pred_masked_compound = 0;
1070   }
1071   if (compound_type == COMPOUND_WEDGE) {
1072     unsigned int sse;
1073     if (is_cur_buf_hbd(xd))
1074       (void)cpi->ppi->fn_ptr[bsize].vf(CONVERT_TO_BYTEPTR(*preds0), *strides,
1075                                        CONVERT_TO_BYTEPTR(*preds1), *strides,
1076                                        &sse);
1077     else
1078       (void)cpi->ppi->fn_ptr[bsize].vf(*preds0, *strides, *preds1, *strides,
1079                                        &sse);
1080     const unsigned int mse =
1081         ROUND_POWER_OF_TWO(sse, num_pels_log2_lookup[bsize]);
1082     // If two predictors are very similar, skip wedge compound mode search
1083     if (mse < 8 || (!have_newmv_in_inter_mode(this_mode) && mse < 64)) {
1084       *comp_model_rd_cur = INT64_MAX;
1085       return INT64_MAX;
1086     }
1087   }
1088   // Function pointer to pick the appropriate mask
1089   // compound_type == COMPOUND_WEDGE, calls pick_interinter_wedge()
1090   // compound_type == COMPOUND_DIFFWTD, calls pick_interinter_seg()
1091   uint64_t cur_sse = UINT64_MAX;
1092   best_rd_cur = pick_interinter_mask[compound_type - COMPOUND_WEDGE](
1093       cpi, x, bsize, *preds0, *preds1, residual1, diff10, &cur_sse);
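  // Add the compound-mask signaling rate and the MV rate on top of the mask
  // picker's rd estimate; skip_rd_cur uses the SSE reported by the mask
  // picker as a stand-in for the distortion with residual coding skipped.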
1094   *rs2 += get_interinter_compound_mask_rate(&x->mode_costs, mbmi);
1095   best_rd_cur += RDCOST(x->rdmult, *rs2 + rate_mv, 0);
1096   assert(cur_sse != UINT64_MAX);
1097   int64_t skip_rd_cur = RDCOST(x->rdmult, *rs2 + rate_mv, (cur_sse << 4));
1098 
1099   // Although the true rate_mv might differ after motion search, this mode is
1100   // unlikely to be the best once the transform rd cost and other mode
1101   // overhead costs are taken into account.
1102   int64_t mode_rd = RDCOST(x->rdmult, *rs2 + mode_rate, 0);
1103   if (mode_rd > rd_thresh) {
1104     *comp_model_rd_cur = INT64_MAX;
1105     return INT64_MAX;
1106   }
1107 
1108   // Check if the mode is good enough based on skip rd
1109   // TODO(nithya): Handle wedge_newmv_search if extending for lower speed
1110   // setting
1111   const int txfm_rd_gate_level =
1112       get_txfm_rd_gate_level(cm->seq_params->enable_masked_compound,
1113                              cpi->sf.inter_sf.txfm_rd_gate_level, bsize,
1114                              TX_SEARCH_COMP_TYPE_MODE, /*eval_motion_mode=*/0);
1115   if (txfm_rd_gate_level) {
1116     int eval_txfm = check_txfm_eval(x, bsize, ref_skip_rd, skip_rd_cur,
1117                                     txfm_rd_gate_level, 1);
1118     if (!eval_txfm) {
1119       *comp_model_rd_cur = INT64_MAX;
1120       return INT64_MAX;
1121     }
1122   }
1123 
1124   // Compute the cost if no matching record is found; otherwise reuse data
1125   if (comp_rate[compound_type] == INT_MAX) {
1126     // Check whether new MV search for wedge is to be done
1127     int wedge_newmv_search =
1128         have_newmv_in_inter_mode(this_mode) &&
1129         (compound_type == COMPOUND_WEDGE) &&
1130         (!cpi->sf.inter_sf.disable_interinter_wedge_newmv_search);
1131 
1132     // Search for new MV if needed and build predictor
1133     if (wedge_newmv_search) {
1134       *out_rate_mv = av1_interinter_compound_motion_search(cpi, x, cur_mv,
1135                                                            bsize, this_mode);
1136       const int mi_row = xd->mi_row;
1137       const int mi_col = xd->mi_col;
1138       av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, ctx, bsize,
1139                                     AOM_PLANE_Y, AOM_PLANE_Y);
1140     } else {
1141       *out_rate_mv = rate_mv;
1142       av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides,
1143                                                preds1, strides);
1144     }
1145     // Get the RD cost from model RD
1146     model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
1147         cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum, &tmp_skip_txfm_sb,
1148         &tmp_skip_sse_sb, NULL, NULL, NULL);
1149     rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + rate_sum, dist_sum);
1150     *comp_model_rd_cur = rd;
1151     // Revert to the original MVs/predictor if the new-MV result is worse.
1152     if (wedge_newmv_search) {
1153       if (rd >= best_rd_cur) {
1154         mbmi->mv[0].as_int = cur_mv[0].as_int;
1155         mbmi->mv[1].as_int = cur_mv[1].as_int;
1156         *out_rate_mv = rate_mv;
1157         av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0,
1158                                                  strides, preds1, strides);
1159         *comp_model_rd_cur = best_rd_cur;
1160       }
1161     }
1162     if (cpi->sf.inter_sf.prune_comp_type_by_model_rd &&
1163         (*comp_model_rd_cur > comp_best_model_rd) &&
1164         comp_best_model_rd != INT64_MAX) {
1165       *comp_model_rd_cur = INT64_MAX;
1166       return INT64_MAX;
1167     }
1168     // Compute RD cost for the current type
1169     RD_STATS rd_stats;
1170     const int64_t tmp_mode_rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv, 0);
1171     const int64_t tmp_rd_thresh = rd_thresh - tmp_mode_rd;
1172     rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &rd_stats);
1173     if (rd != INT64_MAX) {
1174       rd =
1175           RDCOST(x->rdmult, *rs2 + *out_rate_mv + rd_stats.rate, rd_stats.dist);
1176       // Backup rate and distortion for future reuse
1177       backup_stats(compound_type, comp_rate, comp_dist, comp_model_rate,
1178                    comp_model_dist, rate_sum, dist_sum, &rd_stats, comp_rs2,
1179                    *rs2);
1180     }
1181   } else {
1182     // Reuse data as matching record is found
1183     assert(comp_dist[compound_type] != INT64_MAX);
1184     // When disable_interinter_wedge_newmv_search is set, motion refinement is
1185     // disabled. Hence rate and distortion can be reused in this case as well
1186     assert(IMPLIES((have_newmv_in_inter_mode(this_mode) &&
1187                     (compound_type == COMPOUND_WEDGE)),
1188                    cpi->sf.inter_sf.disable_interinter_wedge_newmv_search));
1189     assert(mbmi->mv[0].as_int == cur_mv[0].as_int);
1190     assert(mbmi->mv[1].as_int == cur_mv[1].as_int);
1191     *out_rate_mv = rate_mv;
1192     // Calculate RD cost based on stored stats
1193     rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + comp_rate[compound_type],
1194                 comp_dist[compound_type]);
1195     // Recalculate model rdcost with the updated rate
1196     *comp_model_rd_cur =
1197         RDCOST(x->rdmult, *rs2 + *out_rate_mv + comp_model_rate[compound_type],
1198                comp_model_dist[compound_type]);
1199   }
1200   return rd;
1201 }
1202 
1203 // Scaling values used to gate the wedge/compound segment (diffwtd) search
1204 // based on the best approximate rd so far.
1205 static int comp_type_rd_threshold_mul[3] = { 1, 11, 12 };
1206 static int comp_type_rd_threshold_div[3] = { 3, 16, 16 };
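// For prune_comp_type_by_comp_avg levels 0/1/2 these pairs scale the best rd
// so far by 1/3, 11/16 and 12/16 respectively; the masked types are skipped
// once the scaled rd reaches ref_best_rd, so higher levels prune more often.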
1207 
1208 int av1_compound_type_rd(const AV1_COMP *const cpi, MACROBLOCK *x,
1209                          HandleInterModeArgs *args, BLOCK_SIZE bsize,
1210                          int_mv *cur_mv, int mode_search_mask,
1211                          int masked_compound_used, const BUFFER_SET *orig_dst,
1212                          const BUFFER_SET *tmp_dst,
1213                          const CompoundTypeRdBuffers *buffers, int *rate_mv,
1214                          int64_t *rd, RD_STATS *rd_stats, int64_t ref_best_rd,
1215                          int64_t ref_skip_rd, int *is_luma_interp_done,
1216                          int64_t rd_thresh) {
1217   const AV1_COMMON *cm = &cpi->common;
1218   MACROBLOCKD *xd = &x->e_mbd;
1219   MB_MODE_INFO *mbmi = xd->mi[0];
1220   const PREDICTION_MODE this_mode = mbmi->mode;
1221   int ref_frame = av1_ref_frame_type(mbmi->ref_frame);
1222   const int bw = block_size_wide[bsize];
1223   int rs2;
1224   int_mv best_mv[2];
1225   int best_tmp_rate_mv = *rate_mv;
1226   BEST_COMP_TYPE_STATS best_type_stats;
1227   // Initializing BEST_COMP_TYPE_STATS
1228   best_type_stats.best_compound_data.type = COMPOUND_AVERAGE;
1229   best_type_stats.best_compmode_interinter_cost = 0;
1230   best_type_stats.comp_best_model_rd = INT64_MAX;
1231 
1232   uint8_t *preds0[1] = { buffers->pred0 };
1233   uint8_t *preds1[1] = { buffers->pred1 };
1234   int strides[1] = { bw };
1235   int tmp_rate_mv;
1236   COMPOUND_TYPE cur_type;
1237   // Local array to store the mask cost for different compound types
1238   int masked_type_cost[COMPOUND_TYPES];
1239 
1240   int calc_pred_masked_compound = 1;
1241   int64_t comp_dist[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX,
1242                                         INT64_MAX };
1243   int32_t comp_rate[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX, INT_MAX };
1244   int comp_rs2[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX, INT_MAX };
1245   int32_t comp_model_rate[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX,
1246                                               INT_MAX };
1247   int64_t comp_model_dist[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX,
1248                                               INT64_MAX };
1249   int match_index = 0;
1250   const int match_found =
1251       find_comp_rd_in_stats(cpi, x, mbmi, comp_rate, comp_dist, comp_model_rate,
1252                             comp_model_dist, comp_rs2, &match_index);
1253   best_mv[0].as_int = cur_mv[0].as_int;
1254   best_mv[1].as_int = cur_mv[1].as_int;
1255   *rd = INT64_MAX;
1256 
1257   // Local array to store the valid compound types to be evaluated in the core
1258   // loop
1259   COMPOUND_TYPE valid_comp_types[COMPOUND_TYPES] = {
1260     COMPOUND_AVERAGE, COMPOUND_DISTWTD, COMPOUND_WEDGE, COMPOUND_DIFFWTD
1261   };
1262   int valid_type_count = 0;
1263   // compute_valid_comp_types() returns the number of valid compound types to be
1264   // evaluated and populates them in the local array valid_comp_types[].
1266   valid_type_count = compute_valid_comp_types(
1267       x, cpi, bsize, masked_compound_used, mode_search_mask, valid_comp_types);
1268 
1269   // The following context indices are independent of compound type
1270   const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
1271   const int comp_index_ctx = get_comp_index_context(cm, xd);
1272 
1273   // Populates masked_type_cost local array for the 4 compound types
1274   calc_masked_type_cost(&x->mode_costs, bsize, comp_group_idx_ctx,
1275                         comp_index_ctx, masked_compound_used, masked_type_cost);
1276 
1277   int64_t comp_model_rd_cur = INT64_MAX;
1278   int64_t best_rd_cur = ref_best_rd;
1279   const int mi_row = xd->mi_row;
1280   const int mi_col = xd->mi_col;
1281 
1282   // If the match is found, calculate the rd cost using the
1283   // stored stats and update the mbmi appropriately.
1284   if (match_found && cpi->sf.inter_sf.reuse_compound_type_decision) {
1285     return populate_reuse_comp_type_data(x, mbmi, &best_type_stats, cur_mv,
1286                                          comp_rate, comp_dist, comp_rs2,
1287                                          rate_mv, rd, match_index);
1288   }
1289 
1290   // If COMPOUND_AVERAGE is not valid, use the spare buffer
1291   if (valid_comp_types[0] != COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
1292 
1293   // Loop over valid compound types
1294   for (int i = 0; i < valid_type_count; i++) {
1295     cur_type = valid_comp_types[i];
1296 
1297     if (args->cmp_mode[ref_frame] == COMPOUND_AVERAGE) {
1298       if (cur_type == COMPOUND_WEDGE) continue;
1299     }
1300 
1301     comp_model_rd_cur = INT64_MAX;
1302     tmp_rate_mv = *rate_mv;
1303     best_rd_cur = INT64_MAX;
1304     ref_best_rd = AOMMIN(ref_best_rd, *rd);
1305     update_mbmi_for_compound_type(mbmi, cur_type);
1306     rs2 = masked_type_cost[cur_type];
1307 
1308     int64_t mode_rd = RDCOST(x->rdmult, rs2 + rd_stats->rate, 0);
1309     if (mode_rd >= ref_best_rd) continue;
1310 
1311     // Derive flags that control skipping of the MV refinement process.
1312     const int enable_fast_compound_mode_search =
1313         cpi->sf.inter_sf.enable_fast_compound_mode_search;
1314     const bool skip_mv_refinement_for_avg_distwtd =
1315         enable_fast_compound_mode_search == 3 ||
1316         (enable_fast_compound_mode_search == 2 && (this_mode != NEW_NEWMV));
1317     const bool skip_mv_refinement_for_diffwtd =
1318         (!enable_fast_compound_mode_search && cur_type == COMPOUND_DIFFWTD);
1319 
1320     // Case COMPOUND_AVERAGE and COMPOUND_DISTWTD
1321     if (cur_type < COMPOUND_WEDGE) {
1322       if (skip_mv_refinement_for_avg_distwtd) {
1323         int rate_sum;
1324         uint8_t tmp_skip_txfm_sb;
1325         int64_t dist_sum, tmp_skip_sse_sb;
1326 
1327         // Compute cost if no matching record is found, else reuse stored data
1328         if (comp_rate[cur_type] == INT_MAX) {
1329           av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
1330                                         AOM_PLANE_Y, AOM_PLANE_Y);
1331           if (cur_type == COMPOUND_AVERAGE) *is_luma_interp_done = 1;
1332           // Compute RD cost for the current type
1333           RD_STATS est_rd_stats;
1334           const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
1335           int64_t est_rd = INT64_MAX;
1336           int eval_txfm = prune_mode_by_skip_rd(cpi, x, xd, bsize, ref_skip_rd,
1337                                                 rs2 + *rate_mv);
1338           // Evaluate further if skip rd is low enough
1339           if (eval_txfm) {
1340             est_rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh,
1341                                          &est_rd_stats);
1342           }
1343           if (est_rd != INT64_MAX) {
1344             best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
1345                                  est_rd_stats.dist);
1346             model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
1347                 cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum,
1348                 &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
1349             comp_model_rd_cur =
1350                 RDCOST(x->rdmult, rs2 + *rate_mv + rate_sum, dist_sum);
1351             // Backup rate and distortion for future reuse
1352             backup_stats(cur_type, comp_rate, comp_dist, comp_model_rate,
1353                          comp_model_dist, rate_sum, dist_sum, &est_rd_stats,
1354                          comp_rs2, rs2);
1355           }
1356         } else {
1357           // Calculate RD cost based on stored stats
1358           assert(comp_dist[cur_type] != INT64_MAX);
1359           best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + comp_rate[cur_type],
1360                                comp_dist[cur_type]);
1361           // Recalculate model rdcost with the updated rate
1362           comp_model_rd_cur =
1363               RDCOST(x->rdmult, rs2 + *rate_mv + comp_model_rate[cur_type],
1364                      comp_model_dist[cur_type]);
1365         }
1366       } else {
1367         tmp_rate_mv = *rate_mv;
1368         if (have_newmv_in_inter_mode(this_mode)) {
1369           InterPredParams inter_pred_params;
1370           av1_dist_wtd_comp_weight_assign(
1371               &cpi->common, mbmi, &inter_pred_params.conv_params.fwd_offset,
1372               &inter_pred_params.conv_params.bck_offset,
1373               &inter_pred_params.conv_params.use_dist_wtd_comp_avg, 1);
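          // The distance-weighted forward offset is on a 1/16 scale; scaling
          // it by 4 maps it onto the 0..64 blend-mask range of seg_mask,
          // giving a flat mask approximation for the joint motion search.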
1374           int mask_value = inter_pred_params.conv_params.fwd_offset * 4;
1375           memset(xd->seg_mask, mask_value,
1376                  sizeof(xd->seg_mask[0]) * 2 * MAX_SB_SQUARE);
1377           tmp_rate_mv = av1_interinter_compound_motion_search(cpi, x, cur_mv,
1378                                                               bsize, this_mode);
1379         }
1380         av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
1381                                       AOM_PLANE_Y, AOM_PLANE_Y);
1382         if (cur_type == COMPOUND_AVERAGE) *is_luma_interp_done = 1;
1383 
1384         int eval_txfm = prune_mode_by_skip_rd(cpi, x, xd, bsize, ref_skip_rd,
1385                                               rs2 + *rate_mv);
1386         if (eval_txfm) {
1387           RD_STATS est_rd_stats;
1388           estimate_yrd_for_sb(cpi, bsize, x, INT64_MAX, &est_rd_stats);
1389 
1390           best_rd_cur = RDCOST(x->rdmult, rs2 + tmp_rate_mv + est_rd_stats.rate,
1391                                est_rd_stats.dist);
1392         }
1393       }
1394 
1395       // Use the spare buffer for the remaining compound type evaluations
1396       if (cur_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
1397     } else if (cur_type == COMPOUND_WEDGE) {
1398       int best_mask_index = 0;
1399       int best_wedge_sign = 0;
1400       int_mv tmp_mv[2] = { mbmi->mv[0], mbmi->mv[1] };
1401       int best_rs2 = 0;
1402       int best_rate_mv = *rate_mv;
1403       int wedge_mask_size = get_wedge_types_lookup(bsize);
1404       int need_mask_search = args->wedge_index == -1;
1405       int wedge_newmv_search =
1406           have_newmv_in_inter_mode(this_mode) &&
1407           !cpi->sf.inter_sf.disable_interinter_wedge_newmv_search;
1408 
1409       if (need_mask_search && !wedge_newmv_search) {
1410         // Shortcut: build the single-reference predictors once for reuse below
1411         av1_build_inter_predictors_for_planes_single_buf(xd, bsize, 0, 0, 0,
1412                                                          preds0, strides);
1413         av1_build_inter_predictors_for_planes_single_buf(xd, bsize, 0, 0, 1,
1414                                                          preds1, strides);
1415       }
1416 
1417       for (int wedge_mask = 0; wedge_mask < wedge_mask_size && need_mask_search;
1418            ++wedge_mask) {
1419         for (int wedge_sign = 0; wedge_sign < 2; ++wedge_sign) {
1420           tmp_rate_mv = *rate_mv;
1421           mbmi->interinter_comp.wedge_index = wedge_mask;
1422           mbmi->interinter_comp.wedge_sign = wedge_sign;
1423           rs2 = masked_type_cost[cur_type];
1424           rs2 += get_interinter_compound_mask_rate(&x->mode_costs, mbmi);
1425 
1426           mode_rd = RDCOST(x->rdmult, rs2 + rd_stats->rate, 0);
1427           if (mode_rd >= ref_best_rd / 2) continue;
1428 
1429           if (wedge_newmv_search) {
1430             tmp_rate_mv = av1_interinter_compound_motion_search(
1431                 cpi, x, cur_mv, bsize, this_mode);
1432             av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst,
1433                                           bsize, AOM_PLANE_Y, AOM_PLANE_Y);
1434           } else {
1435             av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0,
1436                                                      strides, preds1, strides);
1437           }
1438 
1439           RD_STATS est_rd_stats;
1440           int64_t this_rd_cur = INT64_MAX;
1441           int eval_txfm = prune_mode_by_skip_rd(cpi, x, xd, bsize, ref_skip_rd,
1442                                                 rs2 + *rate_mv);
1443           if (eval_txfm) {
1444             this_rd_cur = estimate_yrd_for_sb(
1445                 cpi, bsize, x, AOMMIN(best_rd_cur, ref_best_rd), &est_rd_stats);
1446           }
1447           if (this_rd_cur < INT64_MAX) {
1448             this_rd_cur =
1449                 RDCOST(x->rdmult, rs2 + tmp_rate_mv + est_rd_stats.rate,
1450                        est_rd_stats.dist);
1451           }
1452           if (this_rd_cur < best_rd_cur) {
1453             best_mask_index = wedge_mask;
1454             best_wedge_sign = wedge_sign;
1455             best_rd_cur = this_rd_cur;
1456             tmp_mv[0] = mbmi->mv[0];
1457             tmp_mv[1] = mbmi->mv[1];
1458             best_rate_mv = tmp_rate_mv;
1459             best_rs2 = rs2;
1460           }
1461         }
1462         // Consider the asymmetric partitions for oblique angle only if the
1463         // corresponding symmetric partition is the best so far.
1464         // Note: For horizontal and vertical types, both symmetric and
1465         // asymmetric partitions are always considered.
1466         if (cpi->sf.inter_sf.enable_fast_wedge_mask_search) {
1467           // The first 4 entries in wedge_codebook_16_heqw/hltw/hgtw[16]
1468           // correspond to symmetric partitions of the 4 oblique angles, the
1469           // next 4 entries correspond to the vertical/horizontal
1470           // symmetric/asymmetric partitions and the last 8 entries correspond
1471           // to the asymmetric partitions of oblique types.
1472           const int idx_before_asym_oblique = 7;
1473           const int last_oblique_sym_idx = 3;
1474           if (wedge_mask == idx_before_asym_oblique) {
1475             if (best_mask_index > last_oblique_sym_idx) {
1476               break;
1477             } else {
1478               // Asymmetric (Index-1) map for the corresponding oblique masks.
1479               // WEDGE_OBLIQUE27: sym - 0, asym - 8, 9
1480               // WEDGE_OBLIQUE63: sym - 1, asym - 12, 13
1481               // WEDGE_OBLIQUE117: sym - 2, asym - 14, 15
1482               // WEDGE_OBLIQUE153: sym - 3, asym - 10, 11
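              // E.g. if WEDGE_OBLIQUE63 (index 1) is the best, wedge_mask is
              // set to 11 and wedge_mask_size to 14, so the loop increment
              // visits only its asymmetric variants 12 and 13 before exiting.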
1483               const int asym_mask_idx[4] = { 7, 11, 13, 9 };
1484               wedge_mask = asym_mask_idx[best_mask_index];
1485               wedge_mask_size = wedge_mask + 3;
1486             }
1487           }
1488         }
1489       }
1490 
1491       if (need_mask_search) {
1492         if (save_mask_search_results(
1493                 this_mode, cpi->sf.inter_sf.reuse_mask_search_results)) {
1494           args->wedge_index = best_mask_index;
1495           args->wedge_sign = best_wedge_sign;
1496         }
1497       } else {
1498         mbmi->interinter_comp.wedge_index = args->wedge_index;
1499         mbmi->interinter_comp.wedge_sign = args->wedge_sign;
1500         rs2 = masked_type_cost[cur_type];
1501         rs2 += get_interinter_compound_mask_rate(&x->mode_costs, mbmi);
1502 
1503         if (wedge_newmv_search) {
1504           tmp_rate_mv = av1_interinter_compound_motion_search(cpi, x, cur_mv,
1505                                                               bsize, this_mode);
1506         }
1507 
1508         best_mask_index = args->wedge_index;
1509         best_wedge_sign = args->wedge_sign;
1510         tmp_mv[0] = mbmi->mv[0];
1511         tmp_mv[1] = mbmi->mv[1];
1512         best_rate_mv = tmp_rate_mv;
1513         best_rs2 = masked_type_cost[cur_type];
1514         best_rs2 += get_interinter_compound_mask_rate(&x->mode_costs, mbmi);
1515         av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
1516                                       AOM_PLANE_Y, AOM_PLANE_Y);
1517         int eval_txfm = prune_mode_by_skip_rd(cpi, x, xd, bsize, ref_skip_rd,
1518                                               best_rs2 + *rate_mv);
1519         if (eval_txfm) {
1520           RD_STATS est_rd_stats;
1521           estimate_yrd_for_sb(cpi, bsize, x, INT64_MAX, &est_rd_stats);
1522           best_rd_cur =
1523               RDCOST(x->rdmult, best_rs2 + tmp_rate_mv + est_rd_stats.rate,
1524                      est_rd_stats.dist);
1525         }
1526       }
1527 
1528       mbmi->interinter_comp.wedge_index = best_mask_index;
1529       mbmi->interinter_comp.wedge_sign = best_wedge_sign;
1530       mbmi->mv[0] = tmp_mv[0];
1531       mbmi->mv[1] = tmp_mv[1];
1532       tmp_rate_mv = best_rate_mv;
1533       rs2 = best_rs2;
1534     } else if (skip_mv_refinement_for_diffwtd) {
1535       int_mv tmp_mv[2];
1536       int best_mask_index = 0;
1537       rs2 += get_interinter_compound_mask_rate(&x->mode_costs, mbmi);
1538 
1539       int need_mask_search = args->diffwtd_index == -1;
1540 
1541       for (int mask_index = 0; mask_index < 2 && need_mask_search;
1542            ++mask_index) {
1543         tmp_rate_mv = *rate_mv;
1544         mbmi->interinter_comp.mask_type = mask_index;
1545         if (have_newmv_in_inter_mode(this_mode)) {
1546           // Flat mask approximations: 38 is the DIFFWTD_38 base, 26 = 64 - 38
1547           int mask_value = mask_index == 0 ? 38 : 26;
1548           memset(xd->seg_mask, mask_value,
1549                  sizeof(xd->seg_mask[0]) * 2 * MAX_SB_SQUARE);
1550           tmp_rate_mv = av1_interinter_compound_motion_search(cpi, x, cur_mv,
1551                                                               bsize, this_mode);
1552         }
1553         av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
1554                                       AOM_PLANE_Y, AOM_PLANE_Y);
1555         RD_STATS est_rd_stats;
1556         int64_t this_rd_cur = INT64_MAX;
1557         int eval_txfm = prune_mode_by_skip_rd(cpi, x, xd, bsize, ref_skip_rd,
1558                                               rs2 + *rate_mv);
1559         if (eval_txfm) {
1560           this_rd_cur =
1561               estimate_yrd_for_sb(cpi, bsize, x, ref_best_rd, &est_rd_stats);
1562         }
1563         if (this_rd_cur < INT64_MAX) {
1564           this_rd_cur = RDCOST(x->rdmult, rs2 + tmp_rate_mv + est_rd_stats.rate,
1565                                est_rd_stats.dist);
1566         }
1567 
1568         if (this_rd_cur < best_rd_cur) {
1569           best_rd_cur = this_rd_cur;
1570           best_mask_index = mbmi->interinter_comp.mask_type;
1571           tmp_mv[0] = mbmi->mv[0];
1572           tmp_mv[1] = mbmi->mv[1];
1573         }
1574       }
1575 
1576       if (need_mask_search) {
1577         if (save_mask_search_results(this_mode, 0))
1578           args->diffwtd_index = best_mask_index;
1579       } else {
1580         mbmi->interinter_comp.mask_type = args->diffwtd_index;
1581         rs2 = masked_type_cost[cur_type];
1582         rs2 += get_interinter_compound_mask_rate(&x->mode_costs, mbmi);
1583 
1584         int mask_value = mbmi->interinter_comp.mask_type == 0 ? 38 : 26;
1585         memset(xd->seg_mask, mask_value,
1586                sizeof(xd->seg_mask[0]) * 2 * MAX_SB_SQUARE);
1587 
1588         if (have_newmv_in_inter_mode(this_mode)) {
1589           tmp_rate_mv = av1_interinter_compound_motion_search(cpi, x, cur_mv,
1590                                                               bsize, this_mode);
1591         }
1592         best_mask_index = mbmi->interinter_comp.mask_type;
1593         tmp_mv[0] = mbmi->mv[0];
1594         tmp_mv[1] = mbmi->mv[1];
1595         av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
1596                                       AOM_PLANE_Y, AOM_PLANE_Y);
1597         RD_STATS est_rd_stats;
1598         int64_t this_rd_cur = INT64_MAX;
1599         int eval_txfm = prune_mode_by_skip_rd(cpi, x, xd, bsize, ref_skip_rd,
1600                                               rs2 + *rate_mv);
1601         if (eval_txfm) {
1602           this_rd_cur =
1603               estimate_yrd_for_sb(cpi, bsize, x, ref_best_rd, &est_rd_stats);
1604         }
1605         if (this_rd_cur < INT64_MAX) {
1606           best_rd_cur = RDCOST(x->rdmult, rs2 + tmp_rate_mv + est_rd_stats.rate,
1607                                est_rd_stats.dist);
1608         }
1609       }
1610 
1611       mbmi->interinter_comp.mask_type = best_mask_index;
1612       mbmi->mv[0] = tmp_mv[0];
1613       mbmi->mv[1] = tmp_mv[1];
1614     } else {
1615       // Handle masked compound types
1616       bool eval_masked_comp_type = true;
1617       if (*rd != INT64_MAX) {
1618         // Factors to control gating of compound type selection based on best
1619         // approximate rd so far
1620         const int max_comp_type_rd_threshold_mul =
1621             comp_type_rd_threshold_mul[cpi->sf.inter_sf
1622                                            .prune_comp_type_by_comp_avg];
1623         const int max_comp_type_rd_threshold_div =
1624             comp_type_rd_threshold_div[cpi->sf.inter_sf
1625                                            .prune_comp_type_by_comp_avg];
1626         // Evaluate COMPOUND_WEDGE / COMPOUND_DIFFWTD only if the approximate
1627         // cost is within the threshold
1628         const int64_t approx_rd = ((*rd / max_comp_type_rd_threshold_div) *
1629                                    max_comp_type_rd_threshold_mul);
1630         if (approx_rd >= ref_best_rd) eval_masked_comp_type = false;
1631       }
1632 
1633       if (eval_masked_comp_type) {
1634         const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh);
1635         best_rd_cur = masked_compound_type_rd(
1636             cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst,
1637             &tmp_rate_mv, preds0, preds1, buffers->residual1, buffers->diff10,
1638             strides, rd_stats->rate, tmp_rd_thresh, &calc_pred_masked_compound,
1639             comp_rate, comp_dist, comp_model_rate, comp_model_dist,
1640             best_type_stats.comp_best_model_rd, &comp_model_rd_cur, comp_rs2,
1641             ref_skip_rd);
1642       }
1643     }
1644 
1645     // Update stats for best compound type
1646     if (best_rd_cur < *rd) {
1647       update_best_info(mbmi, rd, &best_type_stats, best_rd_cur,
1648                        comp_model_rd_cur, rs2);
1649       if (have_newmv_in_inter_mode(this_mode))
1650         update_mask_best_mv(mbmi, best_mv, &best_tmp_rate_mv, tmp_rate_mv);
1651     }
1652     // reset to original mvs for next iteration
1653     mbmi->mv[0].as_int = cur_mv[0].as_int;
1654     mbmi->mv[1].as_int = cur_mv[1].as_int;
1655   }
1656 
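  // comp_group_idx is 0 for the averaging types (AVERAGE/DISTWTD) and 1 for
  // the masked types (WEDGE/DIFFWTD); compound_idx is 0 only for DISTWTD.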
1657   mbmi->comp_group_idx =
1658       (best_type_stats.best_compound_data.type < COMPOUND_WEDGE) ? 0 : 1;
1659   mbmi->compound_idx =
1660       !(best_type_stats.best_compound_data.type == COMPOUND_DISTWTD);
1661   mbmi->interinter_comp = best_type_stats.best_compound_data;
1662 
1663   if (have_newmv_in_inter_mode(this_mode)) {
1664     mbmi->mv[0].as_int = best_mv[0].as_int;
1665     mbmi->mv[1].as_int = best_mv[1].as_int;
1666     rd_stats->rate += best_tmp_rate_mv - *rate_mv;
1667     *rate_mv = best_tmp_rate_mv;
1668   }
1669 
1670   if (this_mode == NEW_NEWMV)
1671     args->cmp_mode[ref_frame] = mbmi->interinter_comp.type;
1672 
1673   restore_dst_buf(xd, *orig_dst, 1);
1674   if (!match_found)
1675     save_comp_rd_search_stat(x, mbmi, comp_rate, comp_dist, comp_model_rate,
1676                              comp_model_dist, cur_mv, comp_rs2);
1677   return best_type_stats.best_compmode_interinter_cost;
1678 }
1679