xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/InitializeVariables.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/tree_ops/InitializeVariables.h"
8 
9 #include "angle_gl.h"
10 #include "common/debug.h"
11 #include "common/hash_containers.h"
12 #include "compiler/translator/Compiler.h"
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/FindMain.h"
16 #include "compiler/translator/tree_util/FindSymbolNode.h"
17 #include "compiler/translator/tree_util/IntermNode_util.h"
18 #include "compiler/translator/tree_util/IntermTraverse.h"
19 #include "compiler/translator/util.h"
20 
21 namespace sh
22 {
23 
24 namespace
25 {
26 
27 void AddArrayZeroInitSequence(const TIntermTyped *initializedNode,
28                               bool canUseLoopsToInitialize,
29                               bool highPrecisionSupported,
30                               TIntermSequence *initSequenceOut,
31                               TSymbolTable *symbolTable);
32 
33 void AddStructZeroInitSequence(const TIntermTyped *initializedNode,
34                                bool canUseLoopsToInitialize,
35                                bool highPrecisionSupported,
36                                TIntermSequence *initSequenceOut,
37                                TSymbolTable *symbolTable);
38 
CreateZeroInitAssignment(const TIntermTyped * initializedNode)39 TIntermBinary *CreateZeroInitAssignment(const TIntermTyped *initializedNode)
40 {
41     TIntermTyped *zero = CreateZeroNode(initializedNode->getType());
42     return new TIntermBinary(EOpAssign, initializedNode->deepCopy(), zero);
43 }
44 
AddZeroInitSequence(const TIntermTyped * initializedNode,bool canUseLoopsToInitialize,bool highPrecisionSupported,TIntermSequence * initSequenceOut,TSymbolTable * symbolTable)45 void AddZeroInitSequence(const TIntermTyped *initializedNode,
46                          bool canUseLoopsToInitialize,
47                          bool highPrecisionSupported,
48                          TIntermSequence *initSequenceOut,
49                          TSymbolTable *symbolTable)
50 {
51     if (initializedNode->isArray())
52     {
53         AddArrayZeroInitSequence(initializedNode, canUseLoopsToInitialize, highPrecisionSupported,
54                                  initSequenceOut, symbolTable);
55     }
56     else if (initializedNode->getType().isStructureContainingArrays() ||
57              initializedNode->getType().isNamelessStruct())
58     {
59         AddStructZeroInitSequence(initializedNode, canUseLoopsToInitialize, highPrecisionSupported,
60                                   initSequenceOut, symbolTable);
61     }
62     else if (initializedNode->getType().isInterfaceBlock())
63     {
64         const TType &type                     = initializedNode->getType();
65         const TInterfaceBlock &interfaceBlock = *type.getInterfaceBlock();
66         const TFieldList &fieldList           = interfaceBlock.fields();
67         for (size_t fieldIndex = 0; fieldIndex < fieldList.size(); ++fieldIndex)
68         {
69             const TField &field         = *fieldList[fieldIndex];
70             TIntermTyped *fieldIndexRef = CreateIndexNode(static_cast<int>(fieldIndex));
71             TIntermTyped *fieldReference =
72                 new TIntermBinary(TOperator::EOpIndexDirectInterfaceBlock,
73                                   initializedNode->deepCopy(), fieldIndexRef);
74             TIntermTyped *fieldZero = CreateZeroNode(*field.type());
75             TIntermTyped *assignment =
76                 new TIntermBinary(TOperator::EOpAssign, fieldReference, fieldZero);
77             initSequenceOut->push_back(assignment);
78         }
79     }
80     else
81     {
82         initSequenceOut->push_back(CreateZeroInitAssignment(initializedNode));
83     }
84 }
85 
AddStructZeroInitSequence(const TIntermTyped * initializedNode,bool canUseLoopsToInitialize,bool highPrecisionSupported,TIntermSequence * initSequenceOut,TSymbolTable * symbolTable)86 void AddStructZeroInitSequence(const TIntermTyped *initializedNode,
87                                bool canUseLoopsToInitialize,
88                                bool highPrecisionSupported,
89                                TIntermSequence *initSequenceOut,
90                                TSymbolTable *symbolTable)
91 {
92     ASSERT(initializedNode->getBasicType() == EbtStruct);
93     const TStructure *structType = initializedNode->getType().getStruct();
94     for (int i = 0; i < static_cast<int>(structType->fields().size()); ++i)
95     {
96         TIntermBinary *element = new TIntermBinary(EOpIndexDirectStruct,
97                                                    initializedNode->deepCopy(), CreateIndexNode(i));
98         // Structs can't be defined inside structs, so the type of a struct field can't be a
99         // nameless struct.
100         ASSERT(!element->getType().isNamelessStruct());
101         AddZeroInitSequence(element, canUseLoopsToInitialize, highPrecisionSupported,
102                             initSequenceOut, symbolTable);
103     }
104 }
105 
AddArrayZeroInitStatementList(const TIntermTyped * initializedNode,bool canUseLoopsToInitialize,bool highPrecisionSupported,TIntermSequence * initSequenceOut,TSymbolTable * symbolTable)106 void AddArrayZeroInitStatementList(const TIntermTyped *initializedNode,
107                                    bool canUseLoopsToInitialize,
108                                    bool highPrecisionSupported,
109                                    TIntermSequence *initSequenceOut,
110                                    TSymbolTable *symbolTable)
111 {
112     for (unsigned int i = 0; i < initializedNode->getOutermostArraySize(); ++i)
113     {
114         TIntermBinary *element =
115             new TIntermBinary(EOpIndexDirect, initializedNode->deepCopy(), CreateIndexNode(i));
116         AddZeroInitSequence(element, canUseLoopsToInitialize, highPrecisionSupported,
117                             initSequenceOut, symbolTable);
118     }
119 }
120 
AddArrayZeroInitForLoop(const TIntermTyped * initializedNode,bool highPrecisionSupported,TIntermSequence * initSequenceOut,TSymbolTable * symbolTable)121 void AddArrayZeroInitForLoop(const TIntermTyped *initializedNode,
122                              bool highPrecisionSupported,
123                              TIntermSequence *initSequenceOut,
124                              TSymbolTable *symbolTable)
125 {
126     ASSERT(initializedNode->isArray());
127     const TType *mediumpIndexType = StaticType::Get<EbtInt, EbpMedium, EvqTemporary, 1, 1>();
128     const TType *highpIndexType   = StaticType::Get<EbtInt, EbpHigh, EvqTemporary, 1, 1>();
129     TVariable *indexVariable =
130         CreateTempVariable(symbolTable, highPrecisionSupported ? highpIndexType : mediumpIndexType);
131 
132     TIntermSymbol *indexSymbolNode = CreateTempSymbolNode(indexVariable);
133     TIntermDeclaration *indexInit =
134         CreateTempInitDeclarationNode(indexVariable, CreateZeroNode(indexVariable->getType()));
135     TIntermConstantUnion *arraySizeNode = CreateIndexNode(initializedNode->getOutermostArraySize());
136     TIntermBinary *indexSmallerThanSize =
137         new TIntermBinary(EOpLessThan, indexSymbolNode->deepCopy(), arraySizeNode);
138     TIntermUnary *indexIncrement =
139         new TIntermUnary(EOpPreIncrement, indexSymbolNode->deepCopy(), nullptr);
140 
141     TIntermBlock *forLoopBody       = new TIntermBlock();
142     TIntermSequence *forLoopBodySeq = forLoopBody->getSequence();
143 
144     TIntermBinary *element = new TIntermBinary(EOpIndexIndirect, initializedNode->deepCopy(),
145                                                indexSymbolNode->deepCopy());
146     AddZeroInitSequence(element, true, highPrecisionSupported, forLoopBodySeq, symbolTable);
147 
148     TIntermLoop *forLoop =
149         new TIntermLoop(ELoopFor, indexInit, indexSmallerThanSize, indexIncrement, forLoopBody);
150     initSequenceOut->push_back(forLoop);
151 }
152 
AddArrayZeroInitSequence(const TIntermTyped * initializedNode,bool canUseLoopsToInitialize,bool highPrecisionSupported,TIntermSequence * initSequenceOut,TSymbolTable * symbolTable)153 void AddArrayZeroInitSequence(const TIntermTyped *initializedNode,
154                               bool canUseLoopsToInitialize,
155                               bool highPrecisionSupported,
156                               TIntermSequence *initSequenceOut,
157                               TSymbolTable *symbolTable)
158 {
159     // The array elements are assigned one by one to keep the AST compatible with ESSL 1.00 which
160     // doesn't have array assignment. We'll do this either with a for loop or just a list of
161     // statements assigning to each array index. Note that it is important to have the array init in
162     // the right order to workaround http://crbug.com/709317
163     bool isSmallArray = initializedNode->getOutermostArraySize() <= 1u ||
164                         (initializedNode->getBasicType() != EbtStruct &&
165                          !initializedNode->getType().isArrayOfArrays() &&
166                          initializedNode->getOutermostArraySize() <= 3u);
167     if (initializedNode->getQualifier() == EvqFragData ||
168         initializedNode->getQualifier() == EvqFragmentOut || isSmallArray ||
169         !canUseLoopsToInitialize)
170     {
171         // Fragment outputs should not be indexed by non-constant indices.
172         // Also it doesn't make sense to use loops to initialize very small arrays.
173         AddArrayZeroInitStatementList(initializedNode, canUseLoopsToInitialize,
174                                       highPrecisionSupported, initSequenceOut, symbolTable);
175     }
176     else
177     {
178         AddArrayZeroInitForLoop(initializedNode, highPrecisionSupported, initSequenceOut,
179                                 symbolTable);
180     }
181 }
182 
InsertInitCode(TCompiler * compiler,TIntermBlock * root,const InitVariableList & variables,TSymbolTable * symbolTable,int shaderVersion,const TExtensionBehavior & extensionBehavior,bool canUseLoopsToInitialize,bool highPrecisionSupported)183 void InsertInitCode(TCompiler *compiler,
184                     TIntermBlock *root,
185                     const InitVariableList &variables,
186                     TSymbolTable *symbolTable,
187                     int shaderVersion,
188                     const TExtensionBehavior &extensionBehavior,
189                     bool canUseLoopsToInitialize,
190                     bool highPrecisionSupported)
191 {
192     TIntermSequence *mainBody = FindMainBody(root)->getSequence();
193     for (const ShaderVariable &var : variables)
194     {
195         // Note that tempVariableName will reference a short-lived char array here - that's fine
196         // since we're only using it to find symbols.
197         ImmutableString tempVariableName(var.name.c_str(), var.name.length());
198 
199         TIntermTyped *initializedSymbol = nullptr;
200         if (var.isBuiltIn() && !symbolTable->findUserDefined(tempVariableName))
201         {
202             initializedSymbol =
203                 ReferenceBuiltInVariable(tempVariableName, *symbolTable, shaderVersion);
204             if (initializedSymbol->getQualifier() == EvqFragData &&
205                 !IsExtensionEnabled(extensionBehavior, TExtension::EXT_draw_buffers))
206             {
207                 // If GL_EXT_draw_buffers is disabled, only the 0th index of gl_FragData can be
208                 // written to.
209                 // TODO(oetuaho): This is a bit hacky and would be better to remove, if we came up
210                 // with a good way to do it. Right now "gl_FragData" in symbol table is initialized
211                 // to have the array size of MaxDrawBuffers, and the initialization happens before
212                 // the shader sets the extensions it is using.
213                 initializedSymbol =
214                     new TIntermBinary(EOpIndexDirect, initializedSymbol, CreateIndexNode(0));
215             }
216             else if (initializedSymbol->getQualifier() == EvqClipDistance ||
217                      initializedSymbol->getQualifier() == EvqCullDistance)
218             {
219                 // The built-in may have been implicitly resized.
220                 initializedSymbol =
221                     new TIntermSymbol(&FindSymbolNode(root, tempVariableName)->variable());
222             }
223         }
224         else
225         {
226             if (tempVariableName != "")
227             {
228                 initializedSymbol =
229                     new TIntermSymbol(&FindSymbolNode(root, tempVariableName)->variable());
230             }
231             else
232             {
233                 // Must be a nameless interface block.
234                 ASSERT(var.structOrBlockName != "");
235                 const TSymbol *symbol = symbolTable->findGlobal(var.structOrBlockName);
236                 ASSERT(symbol && symbol->isInterfaceBlock());
237                 const TInterfaceBlock *block = static_cast<const TInterfaceBlock *>(symbol);
238 
239                 for (const TField *field : block->fields())
240                 {
241                     initializedSymbol = ReferenceGlobalVariable(field->name(), *symbolTable);
242 
243                     TIntermSequence initCode;
244                     CreateInitCode(initializedSymbol, canUseLoopsToInitialize,
245                                    highPrecisionSupported, &initCode, symbolTable);
246                     mainBody->insert(mainBody->begin(), initCode.begin(), initCode.end());
247                 }
248                 // Already inserted init code in this case
249                 continue;
250             }
251         }
252         ASSERT(initializedSymbol != nullptr);
253 
254         TIntermSequence initCode;
255         CreateInitCode(initializedSymbol, canUseLoopsToInitialize, highPrecisionSupported,
256                        &initCode, symbolTable);
257         mainBody->insert(mainBody->begin(), initCode.begin(), initCode.end());
258     }
259 }
260 
CloneFunctionHeader(TSymbolTable * symbolTable,const TFunction * function)261 TFunction *CloneFunctionHeader(TSymbolTable *symbolTable, const TFunction *function)
262 {
263     TFunction *newFunction =
264         new TFunction(symbolTable, function->name(), function->symbolType(),
265                       &function->getReturnType(), function->isKnownToNotHaveSideEffects());
266 
267     if (function->isDefined())
268     {
269         newFunction->setDefined();
270     }
271     if (function->hasPrototypeDeclaration())
272     {
273         newFunction->setHasPrototypeDeclaration();
274     }
275     return newFunction;
276 }
277 
278 class InitializeLocalsTraverser final : public TIntermTraverser
279 {
280   public:
InitializeLocalsTraverser(int shaderVersion,TSymbolTable * symbolTable,bool canUseLoopsToInitialize,bool highPrecisionSupported)281     InitializeLocalsTraverser(int shaderVersion,
282                               TSymbolTable *symbolTable,
283                               bool canUseLoopsToInitialize,
284                               bool highPrecisionSupported)
285         : TIntermTraverser(true, false, false, symbolTable),
286           mShaderVersion(shaderVersion),
287           mCanUseLoopsToInitialize(canUseLoopsToInitialize),
288           mHighPrecisionSupported(highPrecisionSupported)
289     {}
290 
collectUnnamedOutFunctions(TIntermBlock & root)291     void collectUnnamedOutFunctions(TIntermBlock &root)
292     {
293         TIntermSequence &sequence = *root.getSequence();
294         const size_t count        = sequence.size();
295         for (size_t i = 0; i < count; ++i)
296         {
297             const TIntermFunctionDefinition *functionDefinition =
298                 sequence[i]->getAsFunctionDefinition();
299             if (!functionDefinition)
300             {
301                 continue;
302             }
303             const TFunction *function = functionDefinition->getFunction();
304             TFunction *newFunction    = nullptr;
305             for (size_t p = 0; p < function->getParamCount(); ++p)
306             {
307                 const TVariable *param = function->getParam(p);
308                 const TType &type      = param->getType();
309                 if (param->symbolType() == SymbolType::Empty)
310                 {
311                     if (!newFunction)
312                     {
313                         newFunction                   = CloneFunctionHeader(mSymbolTable, function);
314                         mFunctionsToReplace[function] = newFunction;
315                         for (size_t z = 0; z < p; ++z)
316                         {
317                             newFunction->addParameter(function->getParam(z));
318                         }
319                     }
320                     param = new TVariable(mSymbolTable, kEmptyImmutableString, &type,
321                                           SymbolType::AngleInternal, param->extensions());
322                 }
323                 if (newFunction)
324                 {
325                     newFunction->addParameter(param);
326                 }
327             }
328         }
329     }
330 
331   protected:
visitDeclaration(Visit visit,TIntermDeclaration * node)332     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
333     {
334         for (TIntermNode *declarator : *node->getSequence())
335         {
336             if (!mInGlobalScope && !declarator->getAsBinaryNode())
337             {
338                 TIntermSymbol *symbol = declarator->getAsSymbolNode();
339                 ASSERT(symbol);
340                 if (symbol->variable().symbolType() == SymbolType::Empty)
341                 {
342                     continue;
343                 }
344 
345                 // Arrays may need to be initialized one element at a time, since ESSL 1.00 does not
346                 // support array constructors or assigning arrays.
347                 bool arrayConstructorUnavailable =
348                     (symbol->isArray() || symbol->getType().isStructureContainingArrays()) &&
349                     mShaderVersion == 100;
350                 // Nameless struct constructors can't be referred to, so they also need to be
351                 // initialized one element at a time.
352                 // TODO(oetuaho): Check if it makes sense to initialize using a loop, even if we
353                 // could use an initializer. It could at least reduce code size for very large
354                 // arrays, but could hurt runtime performance.
355                 if (arrayConstructorUnavailable || symbol->getType().isNamelessStruct())
356                 {
357                     // SimplifyLoopConditions should have been run so the parent node of this node
358                     // should not be a loop.
359                     ASSERT(getParentNode()->getAsLoopNode() == nullptr);
360                     // SeparateDeclarations should have already been run, so we don't need to worry
361                     // about further declarators in this declaration depending on the effects of
362                     // this declarator.
363                     ASSERT(node->getSequence()->size() == 1);
364                     TIntermSequence initCode;
365                     CreateInitCode(symbol, mCanUseLoopsToInitialize, mHighPrecisionSupported,
366                                    &initCode, mSymbolTable);
367                     insertStatementsInParentBlock(TIntermSequence(), initCode);
368                 }
369                 else
370                 {
371                     TIntermBinary *init =
372                         new TIntermBinary(EOpInitialize, symbol, CreateZeroNode(symbol->getType()));
373                     queueReplacementWithParent(node, symbol, init, OriginalNode::BECOMES_CHILD);
374                 }
375             }
376         }
377         // Must recurse in the cases which had initializers, because the initializiers might
378         // call the function that was rewritten.
379         return true;
380     }
381 
visitFunctionPrototype(TIntermFunctionPrototype * node)382     void visitFunctionPrototype(TIntermFunctionPrototype *node) override
383     {
384         if (getParentNode()->getAsFunctionDefinition() != nullptr)
385         {
386             return;
387         }
388         auto it = mFunctionsToReplace.find(node->getFunction());
389         if (it != mFunctionsToReplace.end())
390         {
391             queueReplacement(new TIntermFunctionPrototype(it->second), OriginalNode::IS_DROPPED);
392         }
393     }
394 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)395     bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
396     {
397         // Initialize output function arguments as well, the parameter passed in at call time may be
398         // clobbered if the function doesn't fully write to the argument.
399 
400         TIntermSequence initCode;
401 
402         const TFunction *function = node->getFunction();
403         auto it                   = mFunctionsToReplace.find(function);
404         if (it != mFunctionsToReplace.end())
405         {
406             function                                   = it->second;
407             TIntermFunctionPrototype *newPrototypeNode = new TIntermFunctionPrototype(function);
408             TIntermFunctionDefinition *newNode =
409                 new TIntermFunctionDefinition(newPrototypeNode, node->getBody());
410             queueReplacement(newNode, OriginalNode::IS_DROPPED);
411         }
412 
413         for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
414         {
415             const TVariable *paramVariable = function->getParam(paramIndex);
416             const TType &paramType         = paramVariable->getType();
417 
418             if (paramType.getQualifier() != EvqParamOut)
419             {
420                 continue;
421             }
422 
423             CreateInitCode(new TIntermSymbol(paramVariable), mCanUseLoopsToInitialize,
424                            mHighPrecisionSupported, &initCode, mSymbolTable);
425         }
426 
427         if (!initCode.empty())
428         {
429             TIntermSequence *body = node->getBody()->getSequence();
430             body->insert(body->begin(), initCode.begin(), initCode.end());
431         }
432 
433         return true;
434     }
435 
visitAggregate(Visit visit,TIntermAggregate * node)436     bool visitAggregate(Visit visit, TIntermAggregate *node) override
437     {
438         const TFunction *function = node->getFunction();
439         if (function != nullptr)
440         {
441             auto it = mFunctionsToReplace.find(function);
442             if (it != mFunctionsToReplace.end())
443             {
444                 const TFunction *target = it->second;
445                 TIntermAggregate *newNode =
446                     TIntermAggregate::CreateFunctionCall(*target, node->getSequence());
447                 queueReplacement(newNode, OriginalNode::IS_DROPPED);
448             }
449         }
450         return true;
451     }
452 
453   private:
454     int mShaderVersion;
455     bool mCanUseLoopsToInitialize;
456     bool mHighPrecisionSupported;
457     angle::HashMap<const TFunction *, TFunction *> mFunctionsToReplace;
458 };
459 
460 }  // namespace
461 
CreateInitCode(const TIntermTyped * initializedSymbol,bool canUseLoopsToInitialize,bool highPrecisionSupported,TIntermSequence * initCode,TSymbolTable * symbolTable)462 void CreateInitCode(const TIntermTyped *initializedSymbol,
463                     bool canUseLoopsToInitialize,
464                     bool highPrecisionSupported,
465                     TIntermSequence *initCode,
466                     TSymbolTable *symbolTable)
467 {
468     AddZeroInitSequence(initializedSymbol, canUseLoopsToInitialize, highPrecisionSupported,
469                         initCode, symbolTable);
470 }
471 
InitializeUninitializedLocals(TCompiler * compiler,TIntermBlock * root,int shaderVersion,bool canUseLoopsToInitialize,bool highPrecisionSupported,TSymbolTable * symbolTable)472 bool InitializeUninitializedLocals(TCompiler *compiler,
473                                    TIntermBlock *root,
474                                    int shaderVersion,
475                                    bool canUseLoopsToInitialize,
476                                    bool highPrecisionSupported,
477                                    TSymbolTable *symbolTable)
478 {
479     InitializeLocalsTraverser traverser(shaderVersion, symbolTable, canUseLoopsToInitialize,
480                                         highPrecisionSupported);
481     traverser.collectUnnamedOutFunctions(*root);
482     root->traverse(&traverser);
483     return traverser.updateTree(compiler, root);
484 }
485 
InitializeVariables(TCompiler * compiler,TIntermBlock * root,const InitVariableList & vars,TSymbolTable * symbolTable,int shaderVersion,const TExtensionBehavior & extensionBehavior,bool canUseLoopsToInitialize,bool highPrecisionSupported)486 bool InitializeVariables(TCompiler *compiler,
487                          TIntermBlock *root,
488                          const InitVariableList &vars,
489                          TSymbolTable *symbolTable,
490                          int shaderVersion,
491                          const TExtensionBehavior &extensionBehavior,
492                          bool canUseLoopsToInitialize,
493                          bool highPrecisionSupported)
494 {
495     InsertInitCode(compiler, root, vars, symbolTable, shaderVersion, extensionBehavior,
496                    canUseLoopsToInitialize, highPrecisionSupported);
497 
498     return compiler->validateAST(root);
499 }
500 
501 }  // namespace sh
502