xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/DeclarePerVertexBlocks.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // DeclarePerVertexBlocks: Declare gl_PerVertex blocks if not already.
7 //
8 
9 #include "compiler/translator/tree_ops/DeclarePerVertexBlocks.h"
10 
11 #include "compiler/translator/Compiler.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 #include "compiler/translator/tree_util/ReplaceVariable.h"
18 
19 namespace sh
20 {
21 namespace
22 {
// Number of members in the default gl_PerVertex block: gl_Position, gl_PointSize,
// gl_ClipDistance and gl_CullDistance.
constexpr size_t kPerVertexFieldCount = 4;

// One flag per gl_PerVertex member, indexed the way GetPerVertexFieldIndex() numbers them
// (0 = gl_Position, 1 = gl_PointSize, 2 = gl_ClipDistance, 3 = gl_CullDistance).
using PerVertexMemberFlags = std::array<bool, kPerVertexFieldCount>;
24 
GetPerVertexFieldIndex(const TQualifier qualifier,const ImmutableString & name)25 int GetPerVertexFieldIndex(const TQualifier qualifier, const ImmutableString &name)
26 {
27     switch (qualifier)
28     {
29         case EvqPosition:
30             ASSERT(name == "gl_Position");
31             return 0;
32         case EvqPointSize:
33             ASSERT(name == "gl_PointSize");
34             return 1;
35         case EvqClipDistance:
36             ASSERT(name == "gl_ClipDistance");
37             return 2;
38         case EvqCullDistance:
39             ASSERT(name == "gl_CullDistance");
40             return 3;
41         default:
42             return -1;
43     }
44 }
45 
46 // Traverser that:
47 //
48 // Inspects global qualifier declarations and extracts whether any of the gl_PerVertex built-ins
49 // are invariant or precise. These declarations are then dropped.
50 class InspectPerVertexBuiltInsTraverser : public TIntermTraverser
51 {
52   public:
InspectPerVertexBuiltInsTraverser(TCompiler * compiler,TSymbolTable * symbolTable,PerVertexMemberFlags * invariantFlagsOut,PerVertexMemberFlags * preciseFlagsOut)53     InspectPerVertexBuiltInsTraverser(TCompiler *compiler,
54                                       TSymbolTable *symbolTable,
55                                       PerVertexMemberFlags *invariantFlagsOut,
56                                       PerVertexMemberFlags *preciseFlagsOut)
57         : TIntermTraverser(true, false, false, symbolTable),
58           mInvariantFlagsOut(invariantFlagsOut),
59           mPreciseFlagsOut(preciseFlagsOut)
60     {}
61 
visitGlobalQualifierDeclaration(Visit visit,TIntermGlobalQualifierDeclaration * node)62     bool visitGlobalQualifierDeclaration(Visit visit,
63                                          TIntermGlobalQualifierDeclaration *node) override
64     {
65         TIntermSymbol *symbol = node->getSymbol();
66 
67         const int fieldIndex =
68             GetPerVertexFieldIndex(symbol->getType().getQualifier(), symbol->getName());
69         if (fieldIndex < 0)
70         {
71             return false;
72         }
73 
74         if (node->isInvariant())
75         {
76             (*mInvariantFlagsOut)[fieldIndex] = true;
77         }
78         else if (node->isPrecise())
79         {
80             (*mPreciseFlagsOut)[fieldIndex] = true;
81         }
82 
83         mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node, TIntermSequence());
84 
85         return false;
86     }
87 
visitDeclaration(Visit visit,TIntermDeclaration * node)88     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
89     {
90         const TIntermSequence &sequence = *(node->getSequence());
91 
92         ASSERT(sequence.size() == 1);
93 
94         const TIntermSymbol *symbol = sequence.front()->getAsSymbolNode();
95         if (symbol == nullptr)
96         {
97             return true;
98         }
99 
100         const TType &type = symbol->getType();
101         switch (type.getQualifier())
102         {
103             case EvqClipDistance:
104             case EvqCullDistance:
105                 break;
106             default:
107                 return true;
108         }
109 
110         mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node, TIntermSequence());
111         return true;
112     }
113 
114   private:
115     PerVertexMemberFlags *mInvariantFlagsOut;
116     PerVertexMemberFlags *mPreciseFlagsOut;
117 };
118 
119 // Traverser that:
120 //
121 // 1. Declares the input and output gl_PerVertex types and variables if not already (based on shader
122 //    type).
123 // 2. Turns built-in references into indexes into these variables.
124 class DeclarePerVertexBlocksTraverser : public TIntermTraverser
125 {
126   public:
DeclarePerVertexBlocksTraverser(TCompiler * compiler,TSymbolTable * symbolTable,const PerVertexMemberFlags & invariantFlags,const PerVertexMemberFlags & preciseFlags,uint8_t clipDistanceArraySize,uint8_t cullDistanceArraySize)127     DeclarePerVertexBlocksTraverser(TCompiler *compiler,
128                                     TSymbolTable *symbolTable,
129                                     const PerVertexMemberFlags &invariantFlags,
130                                     const PerVertexMemberFlags &preciseFlags,
131                                     uint8_t clipDistanceArraySize,
132                                     uint8_t cullDistanceArraySize)
133         : TIntermTraverser(true, false, false, symbolTable),
134           mShaderType(compiler->getShaderType()),
135           mShaderVersion(compiler->getShaderVersion()),
136           mResources(compiler->getResources()),
137           mClipDistanceArraySize(clipDistanceArraySize),
138           mCullDistanceArraySize(cullDistanceArraySize),
139           mPerVertexInVar(nullptr),
140           mPerVertexOutVar(nullptr),
141           mPerVertexInVarRedeclared(false),
142           mPerVertexOutVarRedeclared(false),
143           mPositionRedeclaredForSeparateShaderObject(false),
144           mPointSizeRedeclaredForSeparateShaderObject(false),
145           mPerVertexOutInvariantFlags(invariantFlags),
146           mPerVertexOutPreciseFlags(preciseFlags)
147     {}
148 
visitSymbol(TIntermSymbol * symbol)149     void visitSymbol(TIntermSymbol *symbol) override
150     {
151         const TVariable *variable = &symbol->variable();
152         const TType *type         = &variable->getType();
153 
154         // Replace gl_out if necessary.
155         if (mShaderType == GL_TESS_CONTROL_SHADER && type->getQualifier() == EvqPerVertexOut)
156         {
157             ASSERT(variable->name() == "gl_out");
158 
159             // Declare gl_out if not already.
160             if (mPerVertexOutVar == nullptr)
161             {
162                 // Record invariant and precise qualifiers used on the fields so they would be
163                 // applied to the replacement gl_out.
164                 for (const TField *field : type->getInterfaceBlock()->fields())
165                 {
166                     const TType &fieldType = *field->type();
167                     const int fieldIndex =
168                         GetPerVertexFieldIndex(fieldType.getQualifier(), field->name());
169                     ASSERT(fieldIndex >= 0);
170 
171                     if (fieldType.isInvariant())
172                     {
173                         mPerVertexOutInvariantFlags[fieldIndex] = true;
174                     }
175                     if (fieldType.isPrecise())
176                     {
177                         mPerVertexOutPreciseFlags[fieldIndex] = true;
178                     }
179                 }
180 
181                 declareDefaultGlOut();
182             }
183 
184             if (mPerVertexOutVarRedeclared)
185             {
186                 // Traverse the parents and promote the new type.  Replace the root of
187                 // EOpIndex[In]Direct chain.
188                 queueAccessChainReplacement(new TIntermSymbol(mPerVertexOutVar));
189             }
190 
191             return;
192         }
193 
194         // Replace gl_in if necessary.
195         if ((mShaderType == GL_TESS_CONTROL_SHADER || mShaderType == GL_TESS_EVALUATION_SHADER ||
196              mShaderType == GL_GEOMETRY_SHADER) &&
197             type->getQualifier() == EvqPerVertexIn)
198         {
199             ASSERT(variable->name() == "gl_in");
200 
201             // Declare gl_in if not already.
202             if (mPerVertexInVar == nullptr)
203             {
204                 declareDefaultGlIn();
205             }
206 
207             if (mPerVertexInVarRedeclared)
208             {
209                 // Traverse the parents and promote the new type.  Replace the root of
210                 // EOpIndex[In]Direct chain.
211                 queueAccessChainReplacement(new TIntermSymbol(mPerVertexInVar));
212             }
213 
214             return;
215         }
216 
217         // Turn gl_Position, gl_PointSize, gl_ClipDistance and gl_CullDistance into references to
218         // the output gl_PerVertex.  Note that the default gl_PerVertex is declared as follows:
219         //
220         //     out gl_PerVertex
221         //     {
222         //         vec4 gl_Position;
223         //         float gl_PointSize;
224         //         float gl_ClipDistance[];
225         //         float gl_CullDistance[];
226         //     };
227         //
228 
229         if (variable->symbolType() != SymbolType::BuiltIn &&
230             !(variable->name() == "gl_Position" && mPositionRedeclaredForSeparateShaderObject) &&
231             !(variable->name() == "gl_PointSize" && mPointSizeRedeclaredForSeparateShaderObject))
232         {
233             ASSERT(variable->name() != "gl_Position" && variable->name() != "gl_PointSize" &&
234                    variable->name() != "gl_ClipDistance" && variable->name() != "gl_CullDistance" &&
235                    variable->name() != "gl_in" && variable->name() != "gl_out");
236 
237             return;
238         }
239 
240         // If this built-in was already visited, reuse the variable defined for it.
241         auto replacement = mVariableMap.find(variable);
242         if (replacement != mVariableMap.end())
243         {
244             queueReplacement(replacement->second->deepCopy(), OriginalNode::IS_DROPPED);
245             return;
246         }
247 
248         int fieldIndex = GetPerVertexFieldIndex(type->getQualifier(), variable->name());
249 
250         // Not the built-in we are looking for.
251         if (fieldIndex < 0)
252         {
253             return;
254         }
255 
256         // If gl_ClipDistance is not used, it will be skipped and gl_CullDistance will have index 2.
257         if (fieldIndex == 3 && mClipDistanceArraySize == 0)
258         {
259             fieldIndex = 2;
260         }
261 
262         // Declare the output gl_PerVertex if not already.
263         if (mPerVertexOutVar == nullptr)
264         {
265             declareDefaultGlOut();
266         }
267 
268         TType *newType = new TType(*type);
269         newType->setInterfaceBlockField(mPerVertexOutVar->getType().getInterfaceBlock(),
270                                         fieldIndex);
271 
272         TVariable *newVariable = new TVariable(mSymbolTable, variable->name(), newType,
273                                                variable->symbolType(), variable->extensions());
274 
275         TIntermSymbol *newSymbol = new TIntermSymbol(newVariable);
276         mVariableMap[variable]   = newSymbol;
277 
278         queueReplacement(newSymbol, OriginalNode::IS_DROPPED);
279     }
280 
visitDeclaration(Visit visit,TIntermDeclaration * node)281     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
282     {
283         if (!mInGlobalScope)
284         {
285             return true;
286         }
287 
288         // When EXT_separate_shader_objects is enabled, gl_Position and gl_PointSize are required to
289         // be redeclared by the vertex shader.  Make sure that is taken into account.
290         TIntermSequence *sequence = node->getSequence();
291         TIntermSymbol *symbol     = sequence->front()->getAsSymbolNode();
292         if (symbol == nullptr)
293         {
294             return true;
295         }
296 
297         TIntermSequence emptyReplacement;
298         if (symbol->getType().getQualifier() == EvqPosition)
299         {
300             mPositionRedeclaredForSeparateShaderObject = true;
301             mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
302                                             std::move(emptyReplacement));
303             return false;
304         }
305         if (symbol->getType().getQualifier() == EvqPointSize)
306         {
307             mPointSizeRedeclaredForSeparateShaderObject = true;
308             mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
309                                             std::move(emptyReplacement));
310             return false;
311         }
312 
313         return true;
314     }
315 
getRedeclaredPerVertexOutVar()316     const TVariable *getRedeclaredPerVertexOutVar()
317     {
318         return mPerVertexOutVarRedeclared ? mPerVertexOutVar : nullptr;
319     }
320 
getRedeclaredPerVertexInVar()321     const TVariable *getRedeclaredPerVertexInVar()
322     {
323         return mPerVertexInVarRedeclared ? mPerVertexInVar : nullptr;
324     }
325 
326   private:
declarePerVertex(TQualifier qualifier,uint32_t arraySize,ImmutableString & variableName)327     const TVariable *declarePerVertex(TQualifier qualifier,
328                                       uint32_t arraySize,
329                                       ImmutableString &variableName)
330     {
331         TFieldList *fields = new TFieldList;
332 
333         const TType *vec4Type  = StaticType::GetBasic<EbtFloat, EbpHigh, 4>();
334         const TType *floatType = StaticType::GetBasic<EbtFloat, EbpHigh, 1>();
335 
336         TType *positionType     = new TType(*vec4Type);
337         TType *pointSizeType    = new TType(*floatType);
338         TType *clipDistanceType = mClipDistanceArraySize ? new TType(*floatType) : nullptr;
339         TType *cullDistanceType = mCullDistanceArraySize ? new TType(*floatType) : nullptr;
340 
341         positionType->setQualifier(EvqPosition);
342         pointSizeType->setQualifier(EvqPointSize);
343         if (clipDistanceType)
344             clipDistanceType->setQualifier(EvqClipDistance);
345         if (cullDistanceType)
346             cullDistanceType->setQualifier(EvqCullDistance);
347 
348         TPrecision pointSizePrecision = EbpHigh;
349         if (mShaderType == GL_VERTEX_SHADER)
350         {
351             // gl_PointSize is mediump in ES100 and highp in ES300+.
352             const TVariable *glPointSize = static_cast<const TVariable *>(
353                 mSymbolTable->findBuiltIn(ImmutableString("gl_PointSize"), mShaderVersion));
354             ASSERT(glPointSize);
355 
356             pointSizePrecision = glPointSize->getType().getPrecision();
357         }
358         pointSizeType->setPrecision(pointSizePrecision);
359 
360         // TODO: handle interaction with GS and T*S where the two can have different sizes.  These
361         // values are valid for EvqPerVertexOut only.  For EvqPerVertexIn, the size should come from
362         // the declaration of gl_in.  http://anglebug.com/42264006.
363         if (clipDistanceType)
364             clipDistanceType->makeArray(mClipDistanceArraySize);
365         if (cullDistanceType)
366             cullDistanceType->makeArray(mCullDistanceArraySize);
367 
368         if (qualifier == EvqPerVertexOut)
369         {
370             positionType->setInvariant(mPerVertexOutInvariantFlags[0]);
371             pointSizeType->setInvariant(mPerVertexOutInvariantFlags[1]);
372             if (clipDistanceType)
373                 clipDistanceType->setInvariant(mPerVertexOutInvariantFlags[2]);
374             if (cullDistanceType)
375                 cullDistanceType->setInvariant(mPerVertexOutInvariantFlags[3]);
376 
377             positionType->setPrecise(mPerVertexOutPreciseFlags[0]);
378             pointSizeType->setPrecise(mPerVertexOutPreciseFlags[1]);
379             if (clipDistanceType)
380                 clipDistanceType->setPrecise(mPerVertexOutPreciseFlags[2]);
381             if (cullDistanceType)
382                 cullDistanceType->setPrecise(mPerVertexOutPreciseFlags[3]);
383         }
384 
385         fields->push_back(new TField(positionType, ImmutableString("gl_Position"), TSourceLoc(),
386                                      SymbolType::AngleInternal));
387         fields->push_back(new TField(pointSizeType, ImmutableString("gl_PointSize"), TSourceLoc(),
388                                      SymbolType::AngleInternal));
389         if (clipDistanceType)
390             fields->push_back(new TField(clipDistanceType, ImmutableString("gl_ClipDistance"),
391                                          TSourceLoc(), SymbolType::AngleInternal));
392         if (cullDistanceType)
393             fields->push_back(new TField(cullDistanceType, ImmutableString("gl_CullDistance"),
394                                          TSourceLoc(), SymbolType::AngleInternal));
395 
396         TInterfaceBlock *interfaceBlock =
397             new TInterfaceBlock(mSymbolTable, ImmutableString("gl_PerVertex"), fields,
398                                 TLayoutQualifier::Create(), SymbolType::AngleInternal);
399 
400         TType *interfaceBlockType =
401             new TType(interfaceBlock, qualifier, TLayoutQualifier::Create());
402         if (arraySize > 0)
403         {
404             interfaceBlockType->makeArray(arraySize);
405         }
406 
407         TVariable *interfaceBlockVar =
408             new TVariable(mSymbolTable, variableName, interfaceBlockType,
409                           variableName.empty() ? SymbolType::Empty : SymbolType::AngleInternal);
410 
411         return interfaceBlockVar;
412     }
413 
declareDefaultGlOut()414     void declareDefaultGlOut()
415     {
416         ASSERT(!mPerVertexOutVarRedeclared);
417 
418         // For tessellation control shaders, gl_out is an array of MaxPatchVertices
419         // For other shaders, there's no explicit name or array size
420 
421         ImmutableString varName("");
422         uint32_t arraySize = 0;
423         if (mShaderType == GL_TESS_CONTROL_SHADER)
424         {
425             varName   = ImmutableString("gl_out");
426             arraySize = mResources.MaxPatchVertices;
427         }
428 
429         mPerVertexOutVar           = declarePerVertex(EvqPerVertexOut, arraySize, varName);
430         mPerVertexOutVarRedeclared = true;
431     }
432 
declareDefaultGlIn()433     void declareDefaultGlIn()
434     {
435         ASSERT(!mPerVertexInVarRedeclared);
436 
437         // For tessellation shaders, gl_in is an array of MaxPatchVertices.
438         // For geometry shaders, gl_in is sized based on the primitive type.
439 
440         ImmutableString varName("gl_in");
441         uint32_t arraySize = mResources.MaxPatchVertices;
442         if (mShaderType == GL_GEOMETRY_SHADER)
443         {
444             arraySize =
445                 mSymbolTable->getGlInVariableWithArraySize()->getType().getOutermostArraySize();
446         }
447 
448         mPerVertexInVar           = declarePerVertex(EvqPerVertexIn, arraySize, varName);
449         mPerVertexInVarRedeclared = true;
450     }
451 
452     GLenum mShaderType;
453     int mShaderVersion;
454     const ShBuiltInResources &mResources;
455     uint8_t mClipDistanceArraySize;
456     uint8_t mCullDistanceArraySize;
457 
458     const TVariable *mPerVertexInVar;
459     const TVariable *mPerVertexOutVar;
460 
461     bool mPerVertexInVarRedeclared;
462     bool mPerVertexOutVarRedeclared;
463 
464     bool mPositionRedeclaredForSeparateShaderObject;
465     bool mPointSizeRedeclaredForSeparateShaderObject;
466 
467     // A map of already replaced built-in variables.
468     VariableReplacementMap mVariableMap;
469 
470     // Whether each field is invariant or precise.
471     PerVertexMemberFlags mPerVertexOutInvariantFlags;
472     PerVertexMemberFlags mPerVertexOutPreciseFlags;
473 };
474 
AddPerVertexDecl(TIntermBlock * root,const TVariable * variable)475 void AddPerVertexDecl(TIntermBlock *root, const TVariable *variable)
476 {
477     if (variable == nullptr)
478     {
479         return;
480     }
481 
482     TIntermDeclaration *decl = new TIntermDeclaration;
483     TIntermSymbol *symbol    = new TIntermSymbol(variable);
484     decl->appendDeclarator(symbol);
485 
486     // Insert the declaration before the first function.
487     size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(root);
488     root->insertChildNodes(firstFunctionIndex, {decl});
489 }
490 }  // anonymous namespace
491 
DeclarePerVertexBlocks(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const TVariable ** inputPerVertexOut,const TVariable ** outputPerVertexOut)492 bool DeclarePerVertexBlocks(TCompiler *compiler,
493                             TIntermBlock *root,
494                             TSymbolTable *symbolTable,
495                             const TVariable **inputPerVertexOut,
496                             const TVariable **outputPerVertexOut)
497 {
498     if (compiler->getShaderType() == GL_COMPUTE_SHADER ||
499         compiler->getShaderType() == GL_FRAGMENT_SHADER)
500     {
501         return true;
502     }
503 
504     // First, visit all global qualifier declarations and find which built-ins are invariant or
505     // precise. At the same time, remove gl_ClipDistance and gl_CullDistance array redeclarations.
506     PerVertexMemberFlags invariantFlags = {};
507     PerVertexMemberFlags preciseFlags   = {};
508 
509     InspectPerVertexBuiltInsTraverser infoTraverser(compiler, symbolTable, &invariantFlags,
510                                                     &preciseFlags);
511     root->traverse(&infoTraverser);
512     if (!infoTraverser.updateTree(compiler, root))
513     {
514         return false;
515     }
516 
517     // If #pragma STDGL invariant(all) is specified, make all outputs invariant.
518     if (compiler->getPragma().stdgl.invariantAll)
519     {
520         std::fill(invariantFlags.begin(), invariantFlags.end(), true);
521     }
522 
523     // Then declare the in and out gl_PerVertex I/O blocks.
524     DeclarePerVertexBlocksTraverser traverser(compiler, symbolTable, invariantFlags, preciseFlags,
525                                               compiler->getClipDistanceArraySize(),
526                                               compiler->getCullDistanceArraySize());
527     root->traverse(&traverser);
528     if (!traverser.updateTree(compiler, root))
529     {
530         return false;
531     }
532 
533     AddPerVertexDecl(root, traverser.getRedeclaredPerVertexOutVar());
534     AddPerVertexDecl(root, traverser.getRedeclaredPerVertexInVar());
535 
536     if (inputPerVertexOut)
537     {
538         *inputPerVertexOut = traverser.getRedeclaredPerVertexInVar();
539     }
540     if (outputPerVertexOut)
541     {
542         *outputPerVertexOut = traverser.getRedeclaredPerVertexOutVar();
543     }
544 
545     return compiler->validateAST(root);
546 }
547 }  // namespace sh
548