xref: /aosp_15_r20/external/libgav1/src/dsp/mask_blend.cc (revision 095378508e87ed692bf8dfeb34008b65b3735891)
1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include "src/dsp/mask_blend.h"

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <type_traits>

#include "src/dsp/dsp.h"
#include "src/utils/common.h"
23 
24 namespace libgav1 {
25 namespace dsp {
26 namespace {
27 
GetMaskValue(const uint8_t * LIBGAV1_RESTRICT mask,const uint8_t * LIBGAV1_RESTRICT mask_next_row,int x,int subsampling_x,int subsampling_y)28 uint8_t GetMaskValue(const uint8_t* LIBGAV1_RESTRICT mask,
29                      const uint8_t* LIBGAV1_RESTRICT mask_next_row, int x,
30                      int subsampling_x, int subsampling_y) {
31   if ((subsampling_x | subsampling_y) == 0) {
32     return mask[x];
33   }
34   if (subsampling_x == 1 && subsampling_y == 0) {
35     return static_cast<uint8_t>(RightShiftWithRounding(
36         mask[MultiplyBy2(x)] + mask[MultiplyBy2(x) + 1], 1));
37   }
38   assert(subsampling_x == 1 && subsampling_y == 1);
39   return static_cast<uint8_t>(RightShiftWithRounding(
40       mask[MultiplyBy2(x)] + mask[MultiplyBy2(x) + 1] +
41           mask_next_row[MultiplyBy2(x)] + mask_next_row[MultiplyBy2(x) + 1],
42       2));
43 }
44 
45 template <int bitdepth, typename Pixel, bool is_inter_intra, int subsampling_x,
46           int subsampling_y>
MaskBlend_C(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,const ptrdiff_t prediction_stride_1,const uint8_t * LIBGAV1_RESTRICT mask,const ptrdiff_t mask_stride,const int width,const int height,void * LIBGAV1_RESTRICT dest,const ptrdiff_t dest_stride)47 void MaskBlend_C(const void* LIBGAV1_RESTRICT prediction_0,
48                  const void* LIBGAV1_RESTRICT prediction_1,
49                  const ptrdiff_t prediction_stride_1,
50                  const uint8_t* LIBGAV1_RESTRICT mask,
51                  const ptrdiff_t mask_stride, const int width, const int height,
52                  void* LIBGAV1_RESTRICT dest, const ptrdiff_t dest_stride) {
53   static_assert(!(bitdepth == 8 && is_inter_intra), "");
54   assert(mask != nullptr);
55   using PredType =
56       typename std::conditional<bitdepth == 8, int16_t, uint16_t>::type;
57   const auto* pred_0 = static_cast<const PredType*>(prediction_0);
58   const auto* pred_1 = static_cast<const PredType*>(prediction_1);
59   auto* dst = static_cast<Pixel*>(dest);
60   const ptrdiff_t dst_stride = dest_stride / sizeof(Pixel);
61   constexpr int step_y = subsampling_y ? 2 : 1;
62   const uint8_t* mask_next_row = mask + mask_stride;
63   // 7.11.3.2 Rounding variables derivation process
64   //   2 * FILTER_BITS(7) - (InterRound0(3|5) + InterRound1(7))
65   constexpr int inter_post_round_bits = (bitdepth == 12) ? 2 : 4;
66   for (int y = 0; y < height; ++y) {
67     for (int x = 0; x < width; ++x) {
68       const uint8_t mask_value =
69           GetMaskValue(mask, mask_next_row, x, subsampling_x, subsampling_y);
70       if (is_inter_intra) {
71         dst[x] = static_cast<Pixel>(RightShiftWithRounding(
72             mask_value * pred_1[x] + (64 - mask_value) * pred_0[x], 6));
73       } else {
74         assert(prediction_stride_1 == width);
75         int res = (mask_value * pred_0[x] + (64 - mask_value) * pred_1[x]) >> 6;
76         res -= (bitdepth == 8) ? 0 : kCompoundOffset;
77         dst[x] = static_cast<Pixel>(
78             Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
79                   (1 << bitdepth) - 1));
80       }
81     }
82     dst += dst_stride;
83     mask += mask_stride * step_y;
84     mask_next_row += mask_stride * step_y;
85     pred_0 += width;
86     pred_1 += prediction_stride_1;
87   }
88 }
89 
90 template <int subsampling_x, int subsampling_y>
InterIntraMaskBlend8bpp_C(const uint8_t * LIBGAV1_RESTRICT prediction_0,uint8_t * LIBGAV1_RESTRICT prediction_1,const ptrdiff_t prediction_stride_1,const uint8_t * LIBGAV1_RESTRICT mask,const ptrdiff_t mask_stride,const int width,const int height)91 void InterIntraMaskBlend8bpp_C(const uint8_t* LIBGAV1_RESTRICT prediction_0,
92                                uint8_t* LIBGAV1_RESTRICT prediction_1,
93                                const ptrdiff_t prediction_stride_1,
94                                const uint8_t* LIBGAV1_RESTRICT mask,
95                                const ptrdiff_t mask_stride, const int width,
96                                const int height) {
97   assert(mask != nullptr);
98   constexpr int step_y = subsampling_y ? 2 : 1;
99   const uint8_t* mask_next_row = mask + mask_stride;
100   for (int y = 0; y < height; ++y) {
101     for (int x = 0; x < width; ++x) {
102       const uint8_t mask_value =
103           GetMaskValue(mask, mask_next_row, x, subsampling_x, subsampling_y);
104       prediction_1[x] = static_cast<uint8_t>(RightShiftWithRounding(
105           mask_value * prediction_1[x] + (64 - mask_value) * prediction_0[x],
106           6));
107     }
108     mask += mask_stride * step_y;
109     mask_next_row += mask_stride * step_y;
110     prediction_0 += width;
111     prediction_1 += prediction_stride_1;
112   }
113 }
114 
Init8bpp()115 void Init8bpp() {
116   Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
117   assert(dsp != nullptr);
118 #if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
119   dsp->mask_blend[0][0] = MaskBlend_C<8, uint8_t, false, 0, 0>;
120   dsp->mask_blend[1][0] = MaskBlend_C<8, uint8_t, false, 1, 0>;
121   dsp->mask_blend[2][0] = MaskBlend_C<8, uint8_t, false, 1, 1>;
122   // The is_inter_intra index of mask_blend[][] is replaced by
123   // inter_intra_mask_blend_8bpp[] in 8-bit.
124   dsp->mask_blend[0][1] = nullptr;
125   dsp->mask_blend[1][1] = nullptr;
126   dsp->mask_blend[2][1] = nullptr;
127   dsp->inter_intra_mask_blend_8bpp[0] = InterIntraMaskBlend8bpp_C<0, 0>;
128   dsp->inter_intra_mask_blend_8bpp[1] = InterIntraMaskBlend8bpp_C<1, 0>;
129   dsp->inter_intra_mask_blend_8bpp[2] = InterIntraMaskBlend8bpp_C<1, 1>;
130 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
131   static_cast<void>(dsp);
132 #ifndef LIBGAV1_Dsp8bpp_MaskBlend444
133   dsp->mask_blend[0][0] = MaskBlend_C<8, uint8_t, false, 0, 0>;
134 #endif
135 #ifndef LIBGAV1_Dsp8bpp_MaskBlend422
136   dsp->mask_blend[1][0] = MaskBlend_C<8, uint8_t, false, 1, 0>;
137 #endif
138 #ifndef LIBGAV1_Dsp8bpp_MaskBlend420
139   dsp->mask_blend[2][0] = MaskBlend_C<8, uint8_t, false, 1, 1>;
140 #endif
141   // The is_inter_intra index of mask_blend[][] is replaced by
142   // inter_intra_mask_blend_8bpp[] in 8-bit.
143   dsp->mask_blend[0][1] = nullptr;
144   dsp->mask_blend[1][1] = nullptr;
145   dsp->mask_blend[2][1] = nullptr;
146 #ifndef LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp444
147   dsp->inter_intra_mask_blend_8bpp[0] = InterIntraMaskBlend8bpp_C<0, 0>;
148 #endif
149 #ifndef LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp422
150   dsp->inter_intra_mask_blend_8bpp[1] = InterIntraMaskBlend8bpp_C<1, 0>;
151 #endif
152 #ifndef LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp420
153   dsp->inter_intra_mask_blend_8bpp[2] = InterIntraMaskBlend8bpp_C<1, 1>;
154 #endif
155   static_cast<void>(GetMaskValue);
156 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
157 }
158 
#if LIBGAV1_MAX_BITDEPTH >= 10
// Installs the C mask-blend functions into the 10-bit DSP table.
// When LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS is set, every entry is written.
// Otherwise an entry is skipped when its LIBGAV1_Dsp10bpp_* macro is defined.
void Init10bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
  assert(dsp != nullptr);
#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
    !defined(LIBGAV1_Dsp10bpp_MaskBlend444)
  dsp->mask_blend[0][0] = MaskBlend_C<10, uint16_t, false, 0, 0>;
#endif
#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
    !defined(LIBGAV1_Dsp10bpp_MaskBlend422)
  dsp->mask_blend[1][0] = MaskBlend_C<10, uint16_t, false, 1, 0>;
#endif
#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
    !defined(LIBGAV1_Dsp10bpp_MaskBlend420)
  dsp->mask_blend[2][0] = MaskBlend_C<10, uint16_t, false, 1, 1>;
#endif
#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
    !defined(LIBGAV1_Dsp10bpp_MaskBlendInterIntra444)
  dsp->mask_blend[0][1] = MaskBlend_C<10, uint16_t, true, 0, 0>;
#endif
#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
    !defined(LIBGAV1_Dsp10bpp_MaskBlendInterIntra422)
  dsp->mask_blend[1][1] = MaskBlend_C<10, uint16_t, true, 1, 0>;
#endif
#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
    !defined(LIBGAV1_Dsp10bpp_MaskBlendInterIntra420)
  dsp->mask_blend[2][1] = MaskBlend_C<10, uint16_t, true, 1, 1>;
#endif
  // The dedicated inter-intra table is 8-bit only.
  dsp->inter_intra_mask_blend_8bpp[0] = nullptr;
  dsp->inter_intra_mask_blend_8bpp[1] = nullptr;
  dsp->inter_intra_mask_blend_8bpp[2] = nullptr;
}
#endif  // LIBGAV1_MAX_BITDEPTH >= 10
201 
#if LIBGAV1_MAX_BITDEPTH == 12
// Installs the C mask-blend functions into the 12-bit DSP table.
// When LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS is set, every entry is written.
// Otherwise an entry is skipped when its LIBGAV1_Dsp12bpp_* macro is defined.
void Init12bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(12);
  assert(dsp != nullptr);
#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
    !defined(LIBGAV1_Dsp12bpp_MaskBlend444)
  dsp->mask_blend[0][0] = MaskBlend_C<12, uint16_t, false, 0, 0>;
#endif
#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
    !defined(LIBGAV1_Dsp12bpp_MaskBlend422)
  dsp->mask_blend[1][0] = MaskBlend_C<12, uint16_t, false, 1, 0>;
#endif
#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
    !defined(LIBGAV1_Dsp12bpp_MaskBlend420)
  dsp->mask_blend[2][0] = MaskBlend_C<12, uint16_t, false, 1, 1>;
#endif
#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
    !defined(LIBGAV1_Dsp12bpp_MaskBlendInterIntra444)
  dsp->mask_blend[0][1] = MaskBlend_C<12, uint16_t, true, 0, 0>;
#endif
#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
    !defined(LIBGAV1_Dsp12bpp_MaskBlendInterIntra422)
  dsp->mask_blend[1][1] = MaskBlend_C<12, uint16_t, true, 1, 0>;
#endif
#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
    !defined(LIBGAV1_Dsp12bpp_MaskBlendInterIntra420)
  dsp->mask_blend[2][1] = MaskBlend_C<12, uint16_t, true, 1, 1>;
#endif
  // The dedicated inter-intra table is 8-bit only.
  dsp->inter_intra_mask_blend_8bpp[0] = nullptr;
  dsp->inter_intra_mask_blend_8bpp[1] = nullptr;
  dsp->inter_intra_mask_blend_8bpp[2] = nullptr;
}
#endif  // LIBGAV1_MAX_BITDEPTH == 12
244 
245 }  // namespace
246 
MaskBlendInit_C()247 void MaskBlendInit_C() {
248   Init8bpp();
249 #if LIBGAV1_MAX_BITDEPTH >= 10
250   Init10bpp();
251 #endif
252 #if LIBGAV1_MAX_BITDEPTH == 12
253   Init12bpp();
254 #endif
255 }
256 
257 }  // namespace dsp
258 }  // namespace libgav1
259