1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include "compiler/translator/tree_ops/InitializeVariables.h"
8
9 #include "angle_gl.h"
10 #include "common/debug.h"
11 #include "common/hash_containers.h"
12 #include "compiler/translator/Compiler.h"
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/FindMain.h"
16 #include "compiler/translator/tree_util/FindSymbolNode.h"
17 #include "compiler/translator/tree_util/IntermNode_util.h"
18 #include "compiler/translator/tree_util/IntermTraverse.h"
19 #include "compiler/translator/util.h"
20
21 namespace sh
22 {
23
24 namespace
25 {
26
27 void AddArrayZeroInitSequence(const TIntermTyped *initializedNode,
28 bool canUseLoopsToInitialize,
29 bool highPrecisionSupported,
30 TIntermSequence *initSequenceOut,
31 TSymbolTable *symbolTable);
32
33 void AddStructZeroInitSequence(const TIntermTyped *initializedNode,
34 bool canUseLoopsToInitialize,
35 bool highPrecisionSupported,
36 TIntermSequence *initSequenceOut,
37 TSymbolTable *symbolTable);
38
CreateZeroInitAssignment(const TIntermTyped * initializedNode)39 TIntermBinary *CreateZeroInitAssignment(const TIntermTyped *initializedNode)
40 {
41 TIntermTyped *zero = CreateZeroNode(initializedNode->getType());
42 return new TIntermBinary(EOpAssign, initializedNode->deepCopy(), zero);
43 }
44
AddZeroInitSequence(const TIntermTyped * initializedNode,bool canUseLoopsToInitialize,bool highPrecisionSupported,TIntermSequence * initSequenceOut,TSymbolTable * symbolTable)45 void AddZeroInitSequence(const TIntermTyped *initializedNode,
46 bool canUseLoopsToInitialize,
47 bool highPrecisionSupported,
48 TIntermSequence *initSequenceOut,
49 TSymbolTable *symbolTable)
50 {
51 if (initializedNode->isArray())
52 {
53 AddArrayZeroInitSequence(initializedNode, canUseLoopsToInitialize, highPrecisionSupported,
54 initSequenceOut, symbolTable);
55 }
56 else if (initializedNode->getType().isStructureContainingArrays() ||
57 initializedNode->getType().isNamelessStruct())
58 {
59 AddStructZeroInitSequence(initializedNode, canUseLoopsToInitialize, highPrecisionSupported,
60 initSequenceOut, symbolTable);
61 }
62 else if (initializedNode->getType().isInterfaceBlock())
63 {
64 const TType &type = initializedNode->getType();
65 const TInterfaceBlock &interfaceBlock = *type.getInterfaceBlock();
66 const TFieldList &fieldList = interfaceBlock.fields();
67 for (size_t fieldIndex = 0; fieldIndex < fieldList.size(); ++fieldIndex)
68 {
69 const TField &field = *fieldList[fieldIndex];
70 TIntermTyped *fieldIndexRef = CreateIndexNode(static_cast<int>(fieldIndex));
71 TIntermTyped *fieldReference =
72 new TIntermBinary(TOperator::EOpIndexDirectInterfaceBlock,
73 initializedNode->deepCopy(), fieldIndexRef);
74 TIntermTyped *fieldZero = CreateZeroNode(*field.type());
75 TIntermTyped *assignment =
76 new TIntermBinary(TOperator::EOpAssign, fieldReference, fieldZero);
77 initSequenceOut->push_back(assignment);
78 }
79 }
80 else
81 {
82 initSequenceOut->push_back(CreateZeroInitAssignment(initializedNode));
83 }
84 }
85
AddStructZeroInitSequence(const TIntermTyped * initializedNode,bool canUseLoopsToInitialize,bool highPrecisionSupported,TIntermSequence * initSequenceOut,TSymbolTable * symbolTable)86 void AddStructZeroInitSequence(const TIntermTyped *initializedNode,
87 bool canUseLoopsToInitialize,
88 bool highPrecisionSupported,
89 TIntermSequence *initSequenceOut,
90 TSymbolTable *symbolTable)
91 {
92 ASSERT(initializedNode->getBasicType() == EbtStruct);
93 const TStructure *structType = initializedNode->getType().getStruct();
94 for (int i = 0; i < static_cast<int>(structType->fields().size()); ++i)
95 {
96 TIntermBinary *element = new TIntermBinary(EOpIndexDirectStruct,
97 initializedNode->deepCopy(), CreateIndexNode(i));
98 // Structs can't be defined inside structs, so the type of a struct field can't be a
99 // nameless struct.
100 ASSERT(!element->getType().isNamelessStruct());
101 AddZeroInitSequence(element, canUseLoopsToInitialize, highPrecisionSupported,
102 initSequenceOut, symbolTable);
103 }
104 }
105
AddArrayZeroInitStatementList(const TIntermTyped * initializedNode,bool canUseLoopsToInitialize,bool highPrecisionSupported,TIntermSequence * initSequenceOut,TSymbolTable * symbolTable)106 void AddArrayZeroInitStatementList(const TIntermTyped *initializedNode,
107 bool canUseLoopsToInitialize,
108 bool highPrecisionSupported,
109 TIntermSequence *initSequenceOut,
110 TSymbolTable *symbolTable)
111 {
112 for (unsigned int i = 0; i < initializedNode->getOutermostArraySize(); ++i)
113 {
114 TIntermBinary *element =
115 new TIntermBinary(EOpIndexDirect, initializedNode->deepCopy(), CreateIndexNode(i));
116 AddZeroInitSequence(element, canUseLoopsToInitialize, highPrecisionSupported,
117 initSequenceOut, symbolTable);
118 }
119 }
120
AddArrayZeroInitForLoop(const TIntermTyped * initializedNode,bool highPrecisionSupported,TIntermSequence * initSequenceOut,TSymbolTable * symbolTable)121 void AddArrayZeroInitForLoop(const TIntermTyped *initializedNode,
122 bool highPrecisionSupported,
123 TIntermSequence *initSequenceOut,
124 TSymbolTable *symbolTable)
125 {
126 ASSERT(initializedNode->isArray());
127 const TType *mediumpIndexType = StaticType::Get<EbtInt, EbpMedium, EvqTemporary, 1, 1>();
128 const TType *highpIndexType = StaticType::Get<EbtInt, EbpHigh, EvqTemporary, 1, 1>();
129 TVariable *indexVariable =
130 CreateTempVariable(symbolTable, highPrecisionSupported ? highpIndexType : mediumpIndexType);
131
132 TIntermSymbol *indexSymbolNode = CreateTempSymbolNode(indexVariable);
133 TIntermDeclaration *indexInit =
134 CreateTempInitDeclarationNode(indexVariable, CreateZeroNode(indexVariable->getType()));
135 TIntermConstantUnion *arraySizeNode = CreateIndexNode(initializedNode->getOutermostArraySize());
136 TIntermBinary *indexSmallerThanSize =
137 new TIntermBinary(EOpLessThan, indexSymbolNode->deepCopy(), arraySizeNode);
138 TIntermUnary *indexIncrement =
139 new TIntermUnary(EOpPreIncrement, indexSymbolNode->deepCopy(), nullptr);
140
141 TIntermBlock *forLoopBody = new TIntermBlock();
142 TIntermSequence *forLoopBodySeq = forLoopBody->getSequence();
143
144 TIntermBinary *element = new TIntermBinary(EOpIndexIndirect, initializedNode->deepCopy(),
145 indexSymbolNode->deepCopy());
146 AddZeroInitSequence(element, true, highPrecisionSupported, forLoopBodySeq, symbolTable);
147
148 TIntermLoop *forLoop =
149 new TIntermLoop(ELoopFor, indexInit, indexSmallerThanSize, indexIncrement, forLoopBody);
150 initSequenceOut->push_back(forLoop);
151 }
152
AddArrayZeroInitSequence(const TIntermTyped * initializedNode,bool canUseLoopsToInitialize,bool highPrecisionSupported,TIntermSequence * initSequenceOut,TSymbolTable * symbolTable)153 void AddArrayZeroInitSequence(const TIntermTyped *initializedNode,
154 bool canUseLoopsToInitialize,
155 bool highPrecisionSupported,
156 TIntermSequence *initSequenceOut,
157 TSymbolTable *symbolTable)
158 {
159 // The array elements are assigned one by one to keep the AST compatible with ESSL 1.00 which
160 // doesn't have array assignment. We'll do this either with a for loop or just a list of
161 // statements assigning to each array index. Note that it is important to have the array init in
162 // the right order to workaround http://crbug.com/709317
163 bool isSmallArray = initializedNode->getOutermostArraySize() <= 1u ||
164 (initializedNode->getBasicType() != EbtStruct &&
165 !initializedNode->getType().isArrayOfArrays() &&
166 initializedNode->getOutermostArraySize() <= 3u);
167 if (initializedNode->getQualifier() == EvqFragData ||
168 initializedNode->getQualifier() == EvqFragmentOut || isSmallArray ||
169 !canUseLoopsToInitialize)
170 {
171 // Fragment outputs should not be indexed by non-constant indices.
172 // Also it doesn't make sense to use loops to initialize very small arrays.
173 AddArrayZeroInitStatementList(initializedNode, canUseLoopsToInitialize,
174 highPrecisionSupported, initSequenceOut, symbolTable);
175 }
176 else
177 {
178 AddArrayZeroInitForLoop(initializedNode, highPrecisionSupported, initSequenceOut,
179 symbolTable);
180 }
181 }
182
InsertInitCode(TCompiler * compiler,TIntermBlock * root,const InitVariableList & variables,TSymbolTable * symbolTable,int shaderVersion,const TExtensionBehavior & extensionBehavior,bool canUseLoopsToInitialize,bool highPrecisionSupported)183 void InsertInitCode(TCompiler *compiler,
184 TIntermBlock *root,
185 const InitVariableList &variables,
186 TSymbolTable *symbolTable,
187 int shaderVersion,
188 const TExtensionBehavior &extensionBehavior,
189 bool canUseLoopsToInitialize,
190 bool highPrecisionSupported)
191 {
192 TIntermSequence *mainBody = FindMainBody(root)->getSequence();
193 for (const ShaderVariable &var : variables)
194 {
195 // Note that tempVariableName will reference a short-lived char array here - that's fine
196 // since we're only using it to find symbols.
197 ImmutableString tempVariableName(var.name.c_str(), var.name.length());
198
199 TIntermTyped *initializedSymbol = nullptr;
200 if (var.isBuiltIn() && !symbolTable->findUserDefined(tempVariableName))
201 {
202 initializedSymbol =
203 ReferenceBuiltInVariable(tempVariableName, *symbolTable, shaderVersion);
204 if (initializedSymbol->getQualifier() == EvqFragData &&
205 !IsExtensionEnabled(extensionBehavior, TExtension::EXT_draw_buffers))
206 {
207 // If GL_EXT_draw_buffers is disabled, only the 0th index of gl_FragData can be
208 // written to.
209 // TODO(oetuaho): This is a bit hacky and would be better to remove, if we came up
210 // with a good way to do it. Right now "gl_FragData" in symbol table is initialized
211 // to have the array size of MaxDrawBuffers, and the initialization happens before
212 // the shader sets the extensions it is using.
213 initializedSymbol =
214 new TIntermBinary(EOpIndexDirect, initializedSymbol, CreateIndexNode(0));
215 }
216 else if (initializedSymbol->getQualifier() == EvqClipDistance ||
217 initializedSymbol->getQualifier() == EvqCullDistance)
218 {
219 // The built-in may have been implicitly resized.
220 initializedSymbol =
221 new TIntermSymbol(&FindSymbolNode(root, tempVariableName)->variable());
222 }
223 }
224 else
225 {
226 if (tempVariableName != "")
227 {
228 initializedSymbol =
229 new TIntermSymbol(&FindSymbolNode(root, tempVariableName)->variable());
230 }
231 else
232 {
233 // Must be a nameless interface block.
234 ASSERT(var.structOrBlockName != "");
235 const TSymbol *symbol = symbolTable->findGlobal(var.structOrBlockName);
236 ASSERT(symbol && symbol->isInterfaceBlock());
237 const TInterfaceBlock *block = static_cast<const TInterfaceBlock *>(symbol);
238
239 for (const TField *field : block->fields())
240 {
241 initializedSymbol = ReferenceGlobalVariable(field->name(), *symbolTable);
242
243 TIntermSequence initCode;
244 CreateInitCode(initializedSymbol, canUseLoopsToInitialize,
245 highPrecisionSupported, &initCode, symbolTable);
246 mainBody->insert(mainBody->begin(), initCode.begin(), initCode.end());
247 }
248 // Already inserted init code in this case
249 continue;
250 }
251 }
252 ASSERT(initializedSymbol != nullptr);
253
254 TIntermSequence initCode;
255 CreateInitCode(initializedSymbol, canUseLoopsToInitialize, highPrecisionSupported,
256 &initCode, symbolTable);
257 mainBody->insert(mainBody->begin(), initCode.begin(), initCode.end());
258 }
259 }
260
CloneFunctionHeader(TSymbolTable * symbolTable,const TFunction * function)261 TFunction *CloneFunctionHeader(TSymbolTable *symbolTable, const TFunction *function)
262 {
263 TFunction *newFunction =
264 new TFunction(symbolTable, function->name(), function->symbolType(),
265 &function->getReturnType(), function->isKnownToNotHaveSideEffects());
266
267 if (function->isDefined())
268 {
269 newFunction->setDefined();
270 }
271 if (function->hasPrototypeDeclaration())
272 {
273 newFunction->setHasPrototypeDeclaration();
274 }
275 return newFunction;
276 }
277
278 class InitializeLocalsTraverser final : public TIntermTraverser
279 {
280 public:
InitializeLocalsTraverser(int shaderVersion,TSymbolTable * symbolTable,bool canUseLoopsToInitialize,bool highPrecisionSupported)281 InitializeLocalsTraverser(int shaderVersion,
282 TSymbolTable *symbolTable,
283 bool canUseLoopsToInitialize,
284 bool highPrecisionSupported)
285 : TIntermTraverser(true, false, false, symbolTable),
286 mShaderVersion(shaderVersion),
287 mCanUseLoopsToInitialize(canUseLoopsToInitialize),
288 mHighPrecisionSupported(highPrecisionSupported)
289 {}
290
collectUnnamedOutFunctions(TIntermBlock & root)291 void collectUnnamedOutFunctions(TIntermBlock &root)
292 {
293 TIntermSequence &sequence = *root.getSequence();
294 const size_t count = sequence.size();
295 for (size_t i = 0; i < count; ++i)
296 {
297 const TIntermFunctionDefinition *functionDefinition =
298 sequence[i]->getAsFunctionDefinition();
299 if (!functionDefinition)
300 {
301 continue;
302 }
303 const TFunction *function = functionDefinition->getFunction();
304 TFunction *newFunction = nullptr;
305 for (size_t p = 0; p < function->getParamCount(); ++p)
306 {
307 const TVariable *param = function->getParam(p);
308 const TType &type = param->getType();
309 if (param->symbolType() == SymbolType::Empty)
310 {
311 if (!newFunction)
312 {
313 newFunction = CloneFunctionHeader(mSymbolTable, function);
314 mFunctionsToReplace[function] = newFunction;
315 for (size_t z = 0; z < p; ++z)
316 {
317 newFunction->addParameter(function->getParam(z));
318 }
319 }
320 param = new TVariable(mSymbolTable, kEmptyImmutableString, &type,
321 SymbolType::AngleInternal, param->extensions());
322 }
323 if (newFunction)
324 {
325 newFunction->addParameter(param);
326 }
327 }
328 }
329 }
330
331 protected:
visitDeclaration(Visit visit,TIntermDeclaration * node)332 bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
333 {
334 for (TIntermNode *declarator : *node->getSequence())
335 {
336 if (!mInGlobalScope && !declarator->getAsBinaryNode())
337 {
338 TIntermSymbol *symbol = declarator->getAsSymbolNode();
339 ASSERT(symbol);
340 if (symbol->variable().symbolType() == SymbolType::Empty)
341 {
342 continue;
343 }
344
345 // Arrays may need to be initialized one element at a time, since ESSL 1.00 does not
346 // support array constructors or assigning arrays.
347 bool arrayConstructorUnavailable =
348 (symbol->isArray() || symbol->getType().isStructureContainingArrays()) &&
349 mShaderVersion == 100;
350 // Nameless struct constructors can't be referred to, so they also need to be
351 // initialized one element at a time.
352 // TODO(oetuaho): Check if it makes sense to initialize using a loop, even if we
353 // could use an initializer. It could at least reduce code size for very large
354 // arrays, but could hurt runtime performance.
355 if (arrayConstructorUnavailable || symbol->getType().isNamelessStruct())
356 {
357 // SimplifyLoopConditions should have been run so the parent node of this node
358 // should not be a loop.
359 ASSERT(getParentNode()->getAsLoopNode() == nullptr);
360 // SeparateDeclarations should have already been run, so we don't need to worry
361 // about further declarators in this declaration depending on the effects of
362 // this declarator.
363 ASSERT(node->getSequence()->size() == 1);
364 TIntermSequence initCode;
365 CreateInitCode(symbol, mCanUseLoopsToInitialize, mHighPrecisionSupported,
366 &initCode, mSymbolTable);
367 insertStatementsInParentBlock(TIntermSequence(), initCode);
368 }
369 else
370 {
371 TIntermBinary *init =
372 new TIntermBinary(EOpInitialize, symbol, CreateZeroNode(symbol->getType()));
373 queueReplacementWithParent(node, symbol, init, OriginalNode::BECOMES_CHILD);
374 }
375 }
376 }
377 // Must recurse in the cases which had initializers, because the initializiers might
378 // call the function that was rewritten.
379 return true;
380 }
381
visitFunctionPrototype(TIntermFunctionPrototype * node)382 void visitFunctionPrototype(TIntermFunctionPrototype *node) override
383 {
384 if (getParentNode()->getAsFunctionDefinition() != nullptr)
385 {
386 return;
387 }
388 auto it = mFunctionsToReplace.find(node->getFunction());
389 if (it != mFunctionsToReplace.end())
390 {
391 queueReplacement(new TIntermFunctionPrototype(it->second), OriginalNode::IS_DROPPED);
392 }
393 }
394
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)395 bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
396 {
397 // Initialize output function arguments as well, the parameter passed in at call time may be
398 // clobbered if the function doesn't fully write to the argument.
399
400 TIntermSequence initCode;
401
402 const TFunction *function = node->getFunction();
403 auto it = mFunctionsToReplace.find(function);
404 if (it != mFunctionsToReplace.end())
405 {
406 function = it->second;
407 TIntermFunctionPrototype *newPrototypeNode = new TIntermFunctionPrototype(function);
408 TIntermFunctionDefinition *newNode =
409 new TIntermFunctionDefinition(newPrototypeNode, node->getBody());
410 queueReplacement(newNode, OriginalNode::IS_DROPPED);
411 }
412
413 for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
414 {
415 const TVariable *paramVariable = function->getParam(paramIndex);
416 const TType ¶mType = paramVariable->getType();
417
418 if (paramType.getQualifier() != EvqParamOut)
419 {
420 continue;
421 }
422
423 CreateInitCode(new TIntermSymbol(paramVariable), mCanUseLoopsToInitialize,
424 mHighPrecisionSupported, &initCode, mSymbolTable);
425 }
426
427 if (!initCode.empty())
428 {
429 TIntermSequence *body = node->getBody()->getSequence();
430 body->insert(body->begin(), initCode.begin(), initCode.end());
431 }
432
433 return true;
434 }
435
visitAggregate(Visit visit,TIntermAggregate * node)436 bool visitAggregate(Visit visit, TIntermAggregate *node) override
437 {
438 const TFunction *function = node->getFunction();
439 if (function != nullptr)
440 {
441 auto it = mFunctionsToReplace.find(function);
442 if (it != mFunctionsToReplace.end())
443 {
444 const TFunction *target = it->second;
445 TIntermAggregate *newNode =
446 TIntermAggregate::CreateFunctionCall(*target, node->getSequence());
447 queueReplacement(newNode, OriginalNode::IS_DROPPED);
448 }
449 }
450 return true;
451 }
452
453 private:
454 int mShaderVersion;
455 bool mCanUseLoopsToInitialize;
456 bool mHighPrecisionSupported;
457 angle::HashMap<const TFunction *, TFunction *> mFunctionsToReplace;
458 };
459
460 } // namespace
461
CreateInitCode(const TIntermTyped * initializedSymbol,bool canUseLoopsToInitialize,bool highPrecisionSupported,TIntermSequence * initCode,TSymbolTable * symbolTable)462 void CreateInitCode(const TIntermTyped *initializedSymbol,
463 bool canUseLoopsToInitialize,
464 bool highPrecisionSupported,
465 TIntermSequence *initCode,
466 TSymbolTable *symbolTable)
467 {
468 AddZeroInitSequence(initializedSymbol, canUseLoopsToInitialize, highPrecisionSupported,
469 initCode, symbolTable);
470 }
471
InitializeUninitializedLocals(TCompiler * compiler,TIntermBlock * root,int shaderVersion,bool canUseLoopsToInitialize,bool highPrecisionSupported,TSymbolTable * symbolTable)472 bool InitializeUninitializedLocals(TCompiler *compiler,
473 TIntermBlock *root,
474 int shaderVersion,
475 bool canUseLoopsToInitialize,
476 bool highPrecisionSupported,
477 TSymbolTable *symbolTable)
478 {
479 InitializeLocalsTraverser traverser(shaderVersion, symbolTable, canUseLoopsToInitialize,
480 highPrecisionSupported);
481 traverser.collectUnnamedOutFunctions(*root);
482 root->traverse(&traverser);
483 return traverser.updateTree(compiler, root);
484 }
485
InitializeVariables(TCompiler * compiler,TIntermBlock * root,const InitVariableList & vars,TSymbolTable * symbolTable,int shaderVersion,const TExtensionBehavior & extensionBehavior,bool canUseLoopsToInitialize,bool highPrecisionSupported)486 bool InitializeVariables(TCompiler *compiler,
487 TIntermBlock *root,
488 const InitVariableList &vars,
489 TSymbolTable *symbolTable,
490 int shaderVersion,
491 const TExtensionBehavior &extensionBehavior,
492 bool canUseLoopsToInitialize,
493 bool highPrecisionSupported)
494 {
495 InsertInitCode(compiler, root, vars, symbolTable, shaderVersion, extensionBehavior,
496 canUseLoopsToInitialize, highPrecisionSupported);
497
498 return compiler->validateAST(root);
499 }
500
501 } // namespace sh
502