xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/spirv/EmulateYUVBuiltIns.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2022 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // EmulateYUVBuiltIns: Adds functions that emulate yuv_2_rgb and rgb_2_yuv built-ins.
7 //
8 
9 #include "compiler/translator/tree_ops/spirv/EmulateYUVBuiltIns.h"
10 
11 #include "compiler/translator/StaticType.h"
12 #include "compiler/translator/SymbolTable.h"
13 #include "compiler/translator/tree_util/IntermNode_util.h"
14 #include "compiler/translator/tree_util/IntermTraverse.h"
15 
16 namespace sh
17 {
18 namespace
19 {
20 // A traverser that replaces the yuv built-ins with a function call that emulates it.
21 class EmulateYUVBuiltInsTraverser : public TIntermTraverser
22 {
23   public:
EmulateYUVBuiltInsTraverser(TSymbolTable * symbolTable)24     EmulateYUVBuiltInsTraverser(TSymbolTable *symbolTable)
25         : TIntermTraverser(true, false, false, symbolTable)
26     {}
27 
28     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
29 
30     bool update(TCompiler *compiler, TIntermBlock *root);
31 
32   private:
33     const TFunction *getYUV2RGBFunc(TPrecision precision);
34     const TFunction *getRGB2YUVFunc(TPrecision precision);
35     const TFunction *getYUVFunc(TPrecision precision,
36                                 const char *name,
37                                 TIntermTyped *itu601Matrix,
38                                 TIntermTyped *itu601WideMatrix,
39                                 TIntermTyped *itu709Matrix,
40                                 TIntermFunctionDefinition **funcDefOut);
41 
42     TIntermTyped *replaceYUVFuncCall(TIntermTyped *node);
43 
44     // One emulation function for each sampler precision
45     std::array<TIntermFunctionDefinition *, EbpLast> mYUV2RGBFuncDefs = {};
46     std::array<TIntermFunctionDefinition *, EbpLast> mRGB2YUVFuncDefs = {};
47 };
48 
visitAggregate(Visit visit,TIntermAggregate * node)49 bool EmulateYUVBuiltInsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
50 {
51     TIntermTyped *replacement = replaceYUVFuncCall(node);
52 
53     if (replacement != nullptr)
54     {
55         queueReplacement(replacement, OriginalNode::IS_DROPPED);
56         return false;
57     }
58 
59     return true;
60 }
61 
replaceYUVFuncCall(TIntermTyped * node)62 TIntermTyped *EmulateYUVBuiltInsTraverser::replaceYUVFuncCall(TIntermTyped *node)
63 {
64     TIntermAggregate *asAggregate = node->getAsAggregate();
65     if (asAggregate == nullptr)
66     {
67         return nullptr;
68     }
69 
70     TOperator op = asAggregate->getOp();
71     if (op != EOpYuv_2_rgb && op != EOpRgb_2_yuv)
72     {
73         return nullptr;
74     }
75 
76     ASSERT(asAggregate->getChildCount() == 2);
77 
78     TIntermTyped *param0 = asAggregate->getChildNode(0)->getAsTyped();
79     TPrecision precision = param0->getPrecision();
80     if (precision == EbpUndefined)
81     {
82         precision = EbpMedium;
83     }
84 
85     const TFunction *emulatedFunction =
86         op == EOpYuv_2_rgb ? getYUV2RGBFunc(precision) : getRGB2YUVFunc(precision);
87 
88     // The first parameter of the built-ins (|color|) may itself contain a built-in call.  With
89     // TIntermTraverser, if the direct children also needs to be replaced that needs to be done
90     // while constructing this node as replacement doesn't work.
91     TIntermTyped *param0Replacement = replaceYUVFuncCall(param0);
92 
93     if (param0Replacement == nullptr)
94     {
95         // If param0 is not directly a YUV built-in call, visit it recursively so YIV built-in call
96         // sub expressions are replaced.
97         param0->traverse(this);
98         param0Replacement = param0;
99     }
100 
101     // Create the function call
102     TIntermSequence args = {
103         param0Replacement,
104         asAggregate->getChildNode(1),
105     };
106     return TIntermAggregate::CreateFunctionCall(*emulatedFunction, &args);
107 }
108 
MakeMatrix(const std::array<float,12> & elements)109 TIntermTyped *MakeMatrix(const std::array<float, 12> &elements)
110 {
111     TIntermSequence matrix;
112     for (float element : elements)
113     {
114         matrix.push_back(CreateFloatNode(element, EbpMedium));
115     }
116 
117     const TType *matType = StaticType::GetBasic<EbtFloat, EbpMedium, 4, 3>();
118     return TIntermAggregate::CreateConstructor(*matType, &matrix);
119 }
120 
getYUV2RGBFunc(TPrecision precision)121 const TFunction *EmulateYUVBuiltInsTraverser::getYUV2RGBFunc(TPrecision precision)
122 {
123     const char *name = "ANGLE_yuv_2_rgb";
124     switch (precision)
125     {
126         case EbpLow:
127             name = "ANGLE_yuv_2_rgb_lowp";
128             break;
129         case EbpMedium:
130             name = "ANGLE_yuv_2_rgb_mediump";
131             break;
132         case EbpHigh:
133             name = "ANGLE_yuv_2_rgb_highp";
134             break;
135         default:
136             UNREACHABLE();
137     }
138 
139     // Matrix is combination of the "pure" colorspace conversion matrix for each standard,
140     // the appropriate range expansion (in the case of narrow range) and shifting down of chroma
141     // components to be centered on zero. These arrays are interpreted as mat4x3
142     //
143     // Pure conversion used for itu601:
144     //   1.0, 1.0, 1.0, 0.0, -0.3441, 1.7720, 1.4020, -0.7141, 0.0
145     // Pure conversion used for itu709:
146     //   1.0, 1.0, 1.0, 0.0, -0.1873, 1.8556, 1.5748, -0.4681, 0.0
147     //
148     // For narrow range, Y is rescaled from [16/255, 235/255] to [0,1]
149     // and Cb/Cr are rescaled from [16/255, 240/255] to [0,1] and shifted by -128/255
150     // to center on zero. For wide range, only the Cb/Cr shifting by -128/255 is performed.
151 
152     constexpr std::array<float, 12> itu601Matrix = {1.164384,  1.164384,  1.164384, 0.0,
153                                                     -0.391721, 2.017232,  1.596027, -0.812926,
154                                                     0.0,       -0.874202, 0.531626, -1.085631};
155 
156     constexpr std::array<float, 12> itu601WideMatrix = {1.000000,  1.000000,  1.000000, 0.000000,
157                                                         -0.344100, 1.772000,  1.402000, -0.714100,
158                                                         0.000000,  -0.703749, 0.531175, -0.889475};
159 
160     constexpr std::array<float, 12> itu709Matrix = {1.164384,  1.164384,  1.164384, 0.000000,
161                                                     -0.213221, 2.112402,  1.792741, -0.532882,
162                                                     0.000000,  -0.972945, 0.301455, -1.133402};
163 
164     return getYUVFunc(precision, name, MakeMatrix(itu601Matrix), MakeMatrix(itu601WideMatrix),
165                       MakeMatrix(itu709Matrix), &mYUV2RGBFuncDefs[precision]);
166 }
167 
getRGB2YUVFunc(TPrecision precision)168 const TFunction *EmulateYUVBuiltInsTraverser::getRGB2YUVFunc(TPrecision precision)
169 {
170     const char *name = "ANGLE_rgb_2_yuv";
171     switch (precision)
172     {
173         case EbpLow:
174             name = "ANGLE_rgb_2_yuv_lowp";
175             break;
176         case EbpMedium:
177             name = "ANGLE_rgb_2_yuv_mediump";
178             break;
179         case EbpHigh:
180             name = "ANGLE_rgb_2_yuv_highp";
181             break;
182         default:
183             UNREACHABLE();
184     }
185 
186     // Inverse of yuv_2_rgb transforms above
187     const std::array<float, 12> itu601Matrix = {0.256782,  -0.148219, 0.439220, 0.504143,
188                                                 -0.291001, -0.367798, 0.097898, 0.439220,
189                                                 -0.071422, 0.062745,  0.501961, 0.501961};
190 
191     const std::array<float, 12> itu601WideMatrix = {0.298993,  -0.168732, 0.500005, 0.587016,
192                                                     -0.331273, -0.418699, 0.113991, 0.500005,
193                                                     -0.081306, 0.000000,  0.501961, 0.501961};
194 
195     const std::array<float, 12> itu709Matrix = {0.182580,  -0.100641, 0.439219, 0.614243,
196                                                 -0.338579, -0.398950, 0.062000, 0.439219,
197                                                 -0.040269, 0.062745,  0.501961, 0.501961};
198 
199     return getYUVFunc(precision, name, MakeMatrix(itu601Matrix), MakeMatrix(itu601WideMatrix),
200                       MakeMatrix(itu709Matrix), &mRGB2YUVFuncDefs[precision]);
201 }
202 
getYUVFunc(TPrecision precision,const char * name,TIntermTyped * itu601Matrix,TIntermTyped * itu601WideMatrix,TIntermTyped * itu709Matrix,TIntermFunctionDefinition ** funcDefOut)203 const TFunction *EmulateYUVBuiltInsTraverser::getYUVFunc(TPrecision precision,
204                                                          const char *name,
205                                                          TIntermTyped *itu601Matrix,
206                                                          TIntermTyped *itu601WideMatrix,
207                                                          TIntermTyped *itu709Matrix,
208                                                          TIntermFunctionDefinition **funcDefOut)
209 {
210     if (*funcDefOut != nullptr)
211     {
212         return (*funcDefOut)->getFunction();
213     }
214 
215     // The function prototype is vec3 name(vec3 color, yuvCscStandardEXT conv_standard)
216     TType *vec3Type = new TType(*StaticType::GetBasic<EbtFloat, EbpMedium, 3>());
217     vec3Type->setPrecision(precision);
218     const TType *yuvCscType = StaticType::GetBasic<EbtYuvCscStandardEXT, EbpUndefined>();
219 
220     TType *colorType = new TType(*vec3Type);
221     TType *convType  = new TType(*yuvCscType);
222     colorType->setQualifier(EvqParamIn);
223     convType->setQualifier(EvqParamIn);
224 
225     TVariable *colorParam =
226         new TVariable(mSymbolTable, ImmutableString("color"), colorType, SymbolType::AngleInternal);
227     TVariable *convParam = new TVariable(mSymbolTable, ImmutableString("conv_standard"), convType,
228                                          SymbolType::AngleInternal);
229 
230     TFunction *function = new TFunction(mSymbolTable, ImmutableString(name),
231                                         SymbolType::AngleInternal, vec3Type, true);
232     function->addParameter(colorParam);
233     function->addParameter(convParam);
234 
235     TType *vec4Type = new TType(*StaticType::GetBasic<EbtFloat, EbpMedium, 4>());
236     vec4Type->setPrecision(precision);
237 
238     TIntermSequence components;
239     components.push_back(new TIntermSymbol(colorParam));
240     components.push_back(CreateFloatNode(1.0f, EbpMedium));
241     // vec4(color, 1)
242     TIntermTyped *extendedColor = TIntermAggregate::CreateConstructor(*vec4Type, &components);
243 
244     // The function body is as such:
245     //
246     //     switch (conv_standard)
247     //     {
248     //       case itu_601:
249     //         return itu601Matrix * color;
250     //       case itu_601_full_range:
251     //         return itu601WideMatrix * color;
252     //       case itu_709:
253     //         return itu709Matrix * color;
254     //     }
255     //
256     //     // error
257     //     return vec3(0.0);
258 
259     // Matrix * color
260     TIntermTyped *itu601Mul = new TIntermBinary(EOpMatrixTimesVector, itu601Matrix, extendedColor);
261     TIntermTyped *itu601FullRangeMul =
262         new TIntermBinary(EOpMatrixTimesVector, itu601WideMatrix, extendedColor->deepCopy());
263     TIntermTyped *itu709Mul =
264         new TIntermBinary(EOpMatrixTimesVector, itu709Matrix, extendedColor->deepCopy());
265 
266     // return Matrix * color
267     TIntermBranch *returnItu601Mul          = new TIntermBranch(EOpReturn, itu601Mul);
268     TIntermBranch *returnItu601FullRangeMul = new TIntermBranch(EOpReturn, itu601FullRangeMul);
269     TIntermBranch *returnItu709Mul          = new TIntermBranch(EOpReturn, itu709Mul);
270 
271     // itu_* constants
272     TConstantUnion *ituConstants = new TConstantUnion[3];
273     ituConstants[0].setYuvCscStandardEXTConst(EycsItu601);
274     ituConstants[1].setYuvCscStandardEXTConst(EycsItu601FullRange);
275     ituConstants[2].setYuvCscStandardEXTConst(EycsItu709);
276 
277     TIntermConstantUnion *itu601          = new TIntermConstantUnion(&ituConstants[0], *yuvCscType);
278     TIntermConstantUnion *itu601FullRange = new TIntermConstantUnion(&ituConstants[1], *yuvCscType);
279     TIntermConstantUnion *itu709          = new TIntermConstantUnion(&ituConstants[2], *yuvCscType);
280 
281     // case ...: return ...
282     TIntermBlock *switchBody = new TIntermBlock;
283 
284     switchBody->appendStatement(new TIntermCase(itu601));
285     switchBody->appendStatement(returnItu601Mul);
286     switchBody->appendStatement(new TIntermCase(itu601FullRange));
287     switchBody->appendStatement(returnItu601FullRangeMul);
288     switchBody->appendStatement(new TIntermCase(itu709));
289     switchBody->appendStatement(returnItu709Mul);
290 
291     // switch (conv_standard) ...
292     TIntermSwitch *switchStatement = new TIntermSwitch(new TIntermSymbol(convParam), switchBody);
293 
294     TIntermBlock *body = new TIntermBlock;
295 
296     body->appendStatement(switchStatement);
297     body->appendStatement(new TIntermBranch(EOpReturn, CreateZeroNode(*vec3Type)));
298 
299     *funcDefOut = new TIntermFunctionDefinition(new TIntermFunctionPrototype(function), body);
300 
301     return function;
302 }
303 
update(TCompiler * compiler,TIntermBlock * root)304 bool EmulateYUVBuiltInsTraverser::update(TCompiler *compiler, TIntermBlock *root)
305 {
306     // Insert any added function definitions before the first function.
307     const size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(root);
308     TIntermSequence funcDefs;
309 
310     for (TIntermFunctionDefinition *funcDef : mYUV2RGBFuncDefs)
311     {
312         if (funcDef != nullptr)
313         {
314             funcDefs.push_back(funcDef);
315         }
316     }
317 
318     for (TIntermFunctionDefinition *funcDef : mRGB2YUVFuncDefs)
319     {
320         if (funcDef != nullptr)
321         {
322             funcDefs.push_back(funcDef);
323         }
324     }
325 
326     root->insertChildNodes(firstFunctionIndex, funcDefs);
327 
328     return updateTree(compiler, root);
329 }
330 }  // anonymous namespace
331 
EmulateYUVBuiltIns(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)332 bool EmulateYUVBuiltIns(TCompiler *compiler, TIntermBlock *root, TSymbolTable *symbolTable)
333 {
334     EmulateYUVBuiltInsTraverser traverser(symbolTable);
335     root->traverse(&traverser);
336     return traverser.update(compiler, root);
337 }
338 }  // namespace sh
339