1 //
2 // Copyright 2022 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // EmulateYUVBuiltIns: Adds functions that emulate yuv_2_rgb and rgb_2_yuv built-ins.
7 //
8
9 #include "compiler/translator/tree_ops/spirv/EmulateYUVBuiltIns.h"
10
11 #include "compiler/translator/StaticType.h"
12 #include "compiler/translator/SymbolTable.h"
13 #include "compiler/translator/tree_util/IntermNode_util.h"
14 #include "compiler/translator/tree_util/IntermTraverse.h"
15
16 namespace sh
17 {
18 namespace
19 {
20 // A traverser that replaces the yuv built-ins with a function call that emulates it.
21 class EmulateYUVBuiltInsTraverser : public TIntermTraverser
22 {
23 public:
EmulateYUVBuiltInsTraverser(TSymbolTable * symbolTable)24 EmulateYUVBuiltInsTraverser(TSymbolTable *symbolTable)
25 : TIntermTraverser(true, false, false, symbolTable)
26 {}
27
28 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
29
30 bool update(TCompiler *compiler, TIntermBlock *root);
31
32 private:
33 const TFunction *getYUV2RGBFunc(TPrecision precision);
34 const TFunction *getRGB2YUVFunc(TPrecision precision);
35 const TFunction *getYUVFunc(TPrecision precision,
36 const char *name,
37 TIntermTyped *itu601Matrix,
38 TIntermTyped *itu601WideMatrix,
39 TIntermTyped *itu709Matrix,
40 TIntermFunctionDefinition **funcDefOut);
41
42 TIntermTyped *replaceYUVFuncCall(TIntermTyped *node);
43
44 // One emulation function for each sampler precision
45 std::array<TIntermFunctionDefinition *, EbpLast> mYUV2RGBFuncDefs = {};
46 std::array<TIntermFunctionDefinition *, EbpLast> mRGB2YUVFuncDefs = {};
47 };
48
visitAggregate(Visit visit,TIntermAggregate * node)49 bool EmulateYUVBuiltInsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
50 {
51 TIntermTyped *replacement = replaceYUVFuncCall(node);
52
53 if (replacement != nullptr)
54 {
55 queueReplacement(replacement, OriginalNode::IS_DROPPED);
56 return false;
57 }
58
59 return true;
60 }
61
replaceYUVFuncCall(TIntermTyped * node)62 TIntermTyped *EmulateYUVBuiltInsTraverser::replaceYUVFuncCall(TIntermTyped *node)
63 {
64 TIntermAggregate *asAggregate = node->getAsAggregate();
65 if (asAggregate == nullptr)
66 {
67 return nullptr;
68 }
69
70 TOperator op = asAggregate->getOp();
71 if (op != EOpYuv_2_rgb && op != EOpRgb_2_yuv)
72 {
73 return nullptr;
74 }
75
76 ASSERT(asAggregate->getChildCount() == 2);
77
78 TIntermTyped *param0 = asAggregate->getChildNode(0)->getAsTyped();
79 TPrecision precision = param0->getPrecision();
80 if (precision == EbpUndefined)
81 {
82 precision = EbpMedium;
83 }
84
85 const TFunction *emulatedFunction =
86 op == EOpYuv_2_rgb ? getYUV2RGBFunc(precision) : getRGB2YUVFunc(precision);
87
88 // The first parameter of the built-ins (|color|) may itself contain a built-in call. With
89 // TIntermTraverser, if the direct children also needs to be replaced that needs to be done
90 // while constructing this node as replacement doesn't work.
91 TIntermTyped *param0Replacement = replaceYUVFuncCall(param0);
92
93 if (param0Replacement == nullptr)
94 {
95 // If param0 is not directly a YUV built-in call, visit it recursively so YIV built-in call
96 // sub expressions are replaced.
97 param0->traverse(this);
98 param0Replacement = param0;
99 }
100
101 // Create the function call
102 TIntermSequence args = {
103 param0Replacement,
104 asAggregate->getChildNode(1),
105 };
106 return TIntermAggregate::CreateFunctionCall(*emulatedFunction, &args);
107 }
108
MakeMatrix(const std::array<float,12> & elements)109 TIntermTyped *MakeMatrix(const std::array<float, 12> &elements)
110 {
111 TIntermSequence matrix;
112 for (float element : elements)
113 {
114 matrix.push_back(CreateFloatNode(element, EbpMedium));
115 }
116
117 const TType *matType = StaticType::GetBasic<EbtFloat, EbpMedium, 4, 3>();
118 return TIntermAggregate::CreateConstructor(*matType, &matrix);
119 }
120
getYUV2RGBFunc(TPrecision precision)121 const TFunction *EmulateYUVBuiltInsTraverser::getYUV2RGBFunc(TPrecision precision)
122 {
123 const char *name = "ANGLE_yuv_2_rgb";
124 switch (precision)
125 {
126 case EbpLow:
127 name = "ANGLE_yuv_2_rgb_lowp";
128 break;
129 case EbpMedium:
130 name = "ANGLE_yuv_2_rgb_mediump";
131 break;
132 case EbpHigh:
133 name = "ANGLE_yuv_2_rgb_highp";
134 break;
135 default:
136 UNREACHABLE();
137 }
138
139 // Matrix is combination of the "pure" colorspace conversion matrix for each standard,
140 // the appropriate range expansion (in the case of narrow range) and shifting down of chroma
141 // components to be centered on zero. These arrays are interpreted as mat4x3
142 //
143 // Pure conversion used for itu601:
144 // 1.0, 1.0, 1.0, 0.0, -0.3441, 1.7720, 1.4020, -0.7141, 0.0
145 // Pure conversion used for itu709:
146 // 1.0, 1.0, 1.0, 0.0, -0.1873, 1.8556, 1.5748, -0.4681, 0.0
147 //
148 // For narrow range, Y is rescaled from [16/255, 235/255] to [0,1]
149 // and Cb/Cr are rescaled from [16/255, 240/255] to [0,1] and shifted by -128/255
150 // to center on zero. For wide range, only the Cb/Cr shifting by -128/255 is performed.
151
152 constexpr std::array<float, 12> itu601Matrix = {1.164384, 1.164384, 1.164384, 0.0,
153 -0.391721, 2.017232, 1.596027, -0.812926,
154 0.0, -0.874202, 0.531626, -1.085631};
155
156 constexpr std::array<float, 12> itu601WideMatrix = {1.000000, 1.000000, 1.000000, 0.000000,
157 -0.344100, 1.772000, 1.402000, -0.714100,
158 0.000000, -0.703749, 0.531175, -0.889475};
159
160 constexpr std::array<float, 12> itu709Matrix = {1.164384, 1.164384, 1.164384, 0.000000,
161 -0.213221, 2.112402, 1.792741, -0.532882,
162 0.000000, -0.972945, 0.301455, -1.133402};
163
164 return getYUVFunc(precision, name, MakeMatrix(itu601Matrix), MakeMatrix(itu601WideMatrix),
165 MakeMatrix(itu709Matrix), &mYUV2RGBFuncDefs[precision]);
166 }
167
getRGB2YUVFunc(TPrecision precision)168 const TFunction *EmulateYUVBuiltInsTraverser::getRGB2YUVFunc(TPrecision precision)
169 {
170 const char *name = "ANGLE_rgb_2_yuv";
171 switch (precision)
172 {
173 case EbpLow:
174 name = "ANGLE_rgb_2_yuv_lowp";
175 break;
176 case EbpMedium:
177 name = "ANGLE_rgb_2_yuv_mediump";
178 break;
179 case EbpHigh:
180 name = "ANGLE_rgb_2_yuv_highp";
181 break;
182 default:
183 UNREACHABLE();
184 }
185
186 // Inverse of yuv_2_rgb transforms above
187 const std::array<float, 12> itu601Matrix = {0.256782, -0.148219, 0.439220, 0.504143,
188 -0.291001, -0.367798, 0.097898, 0.439220,
189 -0.071422, 0.062745, 0.501961, 0.501961};
190
191 const std::array<float, 12> itu601WideMatrix = {0.298993, -0.168732, 0.500005, 0.587016,
192 -0.331273, -0.418699, 0.113991, 0.500005,
193 -0.081306, 0.000000, 0.501961, 0.501961};
194
195 const std::array<float, 12> itu709Matrix = {0.182580, -0.100641, 0.439219, 0.614243,
196 -0.338579, -0.398950, 0.062000, 0.439219,
197 -0.040269, 0.062745, 0.501961, 0.501961};
198
199 return getYUVFunc(precision, name, MakeMatrix(itu601Matrix), MakeMatrix(itu601WideMatrix),
200 MakeMatrix(itu709Matrix), &mRGB2YUVFuncDefs[precision]);
201 }
202
getYUVFunc(TPrecision precision,const char * name,TIntermTyped * itu601Matrix,TIntermTyped * itu601WideMatrix,TIntermTyped * itu709Matrix,TIntermFunctionDefinition ** funcDefOut)203 const TFunction *EmulateYUVBuiltInsTraverser::getYUVFunc(TPrecision precision,
204 const char *name,
205 TIntermTyped *itu601Matrix,
206 TIntermTyped *itu601WideMatrix,
207 TIntermTyped *itu709Matrix,
208 TIntermFunctionDefinition **funcDefOut)
209 {
210 if (*funcDefOut != nullptr)
211 {
212 return (*funcDefOut)->getFunction();
213 }
214
215 // The function prototype is vec3 name(vec3 color, yuvCscStandardEXT conv_standard)
216 TType *vec3Type = new TType(*StaticType::GetBasic<EbtFloat, EbpMedium, 3>());
217 vec3Type->setPrecision(precision);
218 const TType *yuvCscType = StaticType::GetBasic<EbtYuvCscStandardEXT, EbpUndefined>();
219
220 TType *colorType = new TType(*vec3Type);
221 TType *convType = new TType(*yuvCscType);
222 colorType->setQualifier(EvqParamIn);
223 convType->setQualifier(EvqParamIn);
224
225 TVariable *colorParam =
226 new TVariable(mSymbolTable, ImmutableString("color"), colorType, SymbolType::AngleInternal);
227 TVariable *convParam = new TVariable(mSymbolTable, ImmutableString("conv_standard"), convType,
228 SymbolType::AngleInternal);
229
230 TFunction *function = new TFunction(mSymbolTable, ImmutableString(name),
231 SymbolType::AngleInternal, vec3Type, true);
232 function->addParameter(colorParam);
233 function->addParameter(convParam);
234
235 TType *vec4Type = new TType(*StaticType::GetBasic<EbtFloat, EbpMedium, 4>());
236 vec4Type->setPrecision(precision);
237
238 TIntermSequence components;
239 components.push_back(new TIntermSymbol(colorParam));
240 components.push_back(CreateFloatNode(1.0f, EbpMedium));
241 // vec4(color, 1)
242 TIntermTyped *extendedColor = TIntermAggregate::CreateConstructor(*vec4Type, &components);
243
244 // The function body is as such:
245 //
246 // switch (conv_standard)
247 // {
248 // case itu_601:
249 // return itu601Matrix * color;
250 // case itu_601_full_range:
251 // return itu601WideMatrix * color;
252 // case itu_709:
253 // return itu709Matrix * color;
254 // }
255 //
256 // // error
257 // return vec3(0.0);
258
259 // Matrix * color
260 TIntermTyped *itu601Mul = new TIntermBinary(EOpMatrixTimesVector, itu601Matrix, extendedColor);
261 TIntermTyped *itu601FullRangeMul =
262 new TIntermBinary(EOpMatrixTimesVector, itu601WideMatrix, extendedColor->deepCopy());
263 TIntermTyped *itu709Mul =
264 new TIntermBinary(EOpMatrixTimesVector, itu709Matrix, extendedColor->deepCopy());
265
266 // return Matrix * color
267 TIntermBranch *returnItu601Mul = new TIntermBranch(EOpReturn, itu601Mul);
268 TIntermBranch *returnItu601FullRangeMul = new TIntermBranch(EOpReturn, itu601FullRangeMul);
269 TIntermBranch *returnItu709Mul = new TIntermBranch(EOpReturn, itu709Mul);
270
271 // itu_* constants
272 TConstantUnion *ituConstants = new TConstantUnion[3];
273 ituConstants[0].setYuvCscStandardEXTConst(EycsItu601);
274 ituConstants[1].setYuvCscStandardEXTConst(EycsItu601FullRange);
275 ituConstants[2].setYuvCscStandardEXTConst(EycsItu709);
276
277 TIntermConstantUnion *itu601 = new TIntermConstantUnion(&ituConstants[0], *yuvCscType);
278 TIntermConstantUnion *itu601FullRange = new TIntermConstantUnion(&ituConstants[1], *yuvCscType);
279 TIntermConstantUnion *itu709 = new TIntermConstantUnion(&ituConstants[2], *yuvCscType);
280
281 // case ...: return ...
282 TIntermBlock *switchBody = new TIntermBlock;
283
284 switchBody->appendStatement(new TIntermCase(itu601));
285 switchBody->appendStatement(returnItu601Mul);
286 switchBody->appendStatement(new TIntermCase(itu601FullRange));
287 switchBody->appendStatement(returnItu601FullRangeMul);
288 switchBody->appendStatement(new TIntermCase(itu709));
289 switchBody->appendStatement(returnItu709Mul);
290
291 // switch (conv_standard) ...
292 TIntermSwitch *switchStatement = new TIntermSwitch(new TIntermSymbol(convParam), switchBody);
293
294 TIntermBlock *body = new TIntermBlock;
295
296 body->appendStatement(switchStatement);
297 body->appendStatement(new TIntermBranch(EOpReturn, CreateZeroNode(*vec3Type)));
298
299 *funcDefOut = new TIntermFunctionDefinition(new TIntermFunctionPrototype(function), body);
300
301 return function;
302 }
303
update(TCompiler * compiler,TIntermBlock * root)304 bool EmulateYUVBuiltInsTraverser::update(TCompiler *compiler, TIntermBlock *root)
305 {
306 // Insert any added function definitions before the first function.
307 const size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(root);
308 TIntermSequence funcDefs;
309
310 for (TIntermFunctionDefinition *funcDef : mYUV2RGBFuncDefs)
311 {
312 if (funcDef != nullptr)
313 {
314 funcDefs.push_back(funcDef);
315 }
316 }
317
318 for (TIntermFunctionDefinition *funcDef : mRGB2YUVFuncDefs)
319 {
320 if (funcDef != nullptr)
321 {
322 funcDefs.push_back(funcDef);
323 }
324 }
325
326 root->insertChildNodes(firstFunctionIndex, funcDefs);
327
328 return updateTree(compiler, root);
329 }
330 } // anonymous namespace
331
EmulateYUVBuiltIns(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)332 bool EmulateYUVBuiltIns(TCompiler *compiler, TIntermBlock *root, TSymbolTable *symbolTable)
333 {
334 EmulateYUVBuiltInsTraverser traverser(symbolTable);
335 root->traverse(&traverser);
336 return traverser.update(compiler, root);
337 }
338 } // namespace sh
339