1 //
2 // Copyright 2018 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RewriteStructSamplers: Extract samplers from structs.
7 //
8
9 #include "compiler/translator/tree_ops/RewriteStructSamplers.h"
10
11 #include "common/hash_containers.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/SymbolTable.h"
14 #include "compiler/translator/tree_util/IntermNode_util.h"
15 #include "compiler/translator/tree_util/IntermTraverse.h"
16
17 namespace sh
18 {
19 namespace
20 {
21
22 // Used to map one structure type to another (one where the samplers are removed).
23 struct StructureData
24 {
25 // The structure this was replaced with. If nullptr, it means the structure is removed (because
26 // it had all samplers).
27 //
28 // ParseContext reorders the samplers to the end of the struct, so the EOpIndexDirectStruct
29 // expressions that select non-sampler members don't have to change when they are moved out of
30 // the struct.
31 const TStructure *modified;
32 };
33
34 using StructureMap = angle::HashMap<const TStructure *, StructureData>;
35 using StructureUniformMap = angle::HashMap<const TVariable *, const TVariable *>;
36 using ExtractedSamplerMap = angle::HashMap<std::string, const TVariable *>;
37
38 TIntermTyped *RewriteModifiedStructFieldSelectionExpression(
39 TCompiler *compiler,
40 TIntermBinary *node,
41 const StructureMap &structureMap,
42 const StructureUniformMap &structureUniformMap,
43 const ExtractedSamplerMap &extractedSamplers);
44
RewriteExpressionVisitBinaryHelper(TCompiler * compiler,TIntermBinary * node,const StructureMap & structureMap,const StructureUniformMap & structureUniformMap,const ExtractedSamplerMap & extractedSamplers)45 TIntermTyped *RewriteExpressionVisitBinaryHelper(TCompiler *compiler,
46 TIntermBinary *node,
47 const StructureMap &structureMap,
48 const StructureUniformMap &structureUniformMap,
49 const ExtractedSamplerMap &extractedSamplers)
50 {
51 // Only interested in EOpIndexDirectStruct binary nodes.
52 if (node->getOp() != EOpIndexDirectStruct)
53 {
54 return nullptr;
55 }
56
57 const TStructure *structure = node->getLeft()->getType().getStruct();
58 ASSERT(structure);
59
60 // If the result of the index is not a sampler and the struct is not replaced, there's nothing
61 // to do.
62 if (!node->getType().isSampler() && structureMap.find(structure) == structureMap.end())
63 {
64 return nullptr;
65 }
66
67 // Otherwise, replace the whole expression such that:
68 //
69 // - if sampler, it's indexed with whatever indices the parent structs were indexed with,
70 // - otherwise, the chain of field selections is rewritten by modifying the base uniform so all
71 // the intermediate nodes would have the correct type (and therefore fields).
72 ASSERT(structureMap.find(structure) != structureMap.end());
73
74 return RewriteModifiedStructFieldSelectionExpression(compiler, node, structureMap,
75 structureUniformMap, extractedSamplers);
76 }
77
78 // Given an expression, this traverser calculates a new expression where sampler-in-structs are
79 // replaced with their extracted ones, and field indices are adjusted for the rest of the fields.
80 // In particular, this is run on the right node of EOpIndexIndirect binary nodes, so that the
81 // expression in the index gets a chance to go through this transformation.
82 class RewriteExpressionTraverser final : public TIntermTraverser
83 {
84 public:
RewriteExpressionTraverser(TCompiler * compiler,const StructureMap & structureMap,const StructureUniformMap & structureUniformMap,const ExtractedSamplerMap & extractedSamplers)85 explicit RewriteExpressionTraverser(TCompiler *compiler,
86 const StructureMap &structureMap,
87 const StructureUniformMap &structureUniformMap,
88 const ExtractedSamplerMap &extractedSamplers)
89 : TIntermTraverser(true, false, false),
90 mCompiler(compiler),
91 mStructureMap(structureMap),
92 mStructureUniformMap(structureUniformMap),
93 mExtractedSamplers(extractedSamplers)
94 {}
95
visitBinary(Visit visit,TIntermBinary * node)96 bool visitBinary(Visit visit, TIntermBinary *node) override
97 {
98 TIntermTyped *rewritten = RewriteExpressionVisitBinaryHelper(
99 mCompiler, node, mStructureMap, mStructureUniformMap, mExtractedSamplers);
100
101 if (rewritten == nullptr)
102 {
103 return true;
104 }
105
106 queueReplacement(rewritten, OriginalNode::IS_DROPPED);
107
108 // Don't iterate as the expression is rewritten.
109 return false;
110 }
111
visitSymbol(TIntermSymbol * node)112 void visitSymbol(TIntermSymbol *node) override
113 {
114 // It's impossible to reach here with a symbol that needs replacement.
115 // MonomorphizeUnsupportedFunctions makes sure that whole structs containing
116 // samplers are not passed to functions, so any instance of the struct uniform is
117 // necessarily indexed right away. visitBinary should have already taken care of it.
118 ASSERT(mStructureUniformMap.find(&node->variable()) == mStructureUniformMap.end());
119 }
120
121 private:
122 TCompiler *mCompiler;
123
124 // See RewriteStructSamplersTraverser.
125 const StructureMap &mStructureMap;
126 const StructureUniformMap &mStructureUniformMap;
127 const ExtractedSamplerMap &mExtractedSamplers;
128 };
129
130 // Rewrite the index of an EOpIndexIndirect expression. The root can never need replacing, because
131 // it cannot be a sampler itself or of a struct type.
RewriteIndexExpression(TCompiler * compiler,TIntermTyped * expression,const StructureMap & structureMap,const StructureUniformMap & structureUniformMap,const ExtractedSamplerMap & extractedSamplers)132 void RewriteIndexExpression(TCompiler *compiler,
133 TIntermTyped *expression,
134 const StructureMap &structureMap,
135 const StructureUniformMap &structureUniformMap,
136 const ExtractedSamplerMap &extractedSamplers)
137 {
138 RewriteExpressionTraverser traverser(compiler, structureMap, structureUniformMap,
139 extractedSamplers);
140 expression->traverse(&traverser);
141 bool valid = traverser.updateTree(compiler, expression);
142 ASSERT(valid);
143 }
144
145 // Given an expression such as the following:
146 //
147 // EOpIndexDirectStruct (sampler)
148 // / \
149 // EOpIndex* field index
150 // / \
151 // EOpIndexDirectStruct index 2
152 // / \
153 // EOpIndex* field index
154 // / \
155 // EOpIndexDirectStruct index 1
156 // / \
157 // Uniform Struct field index
158 //
159 // produces:
160 //
161 // EOpIndex*
162 // / \
163 // EOpIndex* index 2
164 // / \
165 // sampler index 1
166 //
167 // If the expression is not a sampler, it only replaces the struct with the modified one, while
168 // still processing the EOpIndexIndirect expressions (which may contain more structs to map).
RewriteModifiedStructFieldSelectionExpression(TCompiler * compiler,TIntermBinary * node,const StructureMap & structureMap,const StructureUniformMap & structureUniformMap,const ExtractedSamplerMap & extractedSamplers)169 TIntermTyped *RewriteModifiedStructFieldSelectionExpression(
170 TCompiler *compiler,
171 TIntermBinary *node,
172 const StructureMap &structureMap,
173 const StructureUniformMap &structureUniformMap,
174 const ExtractedSamplerMap &extractedSamplers)
175 {
176 ASSERT(node->getOp() == EOpIndexDirectStruct);
177
178 const bool isSampler = node->getType().isSampler();
179
180 TIntermSymbol *baseUniform = nullptr;
181 std::string samplerName;
182
183 TVector<TIntermBinary *> indexNodeStack;
184
185 // Iterate once and build the name of the sampler.
186 TIntermBinary *iter = node;
187 while (baseUniform == nullptr)
188 {
189 indexNodeStack.push_back(iter);
190 baseUniform = iter->getLeft()->getAsSymbolNode();
191
192 if (isSampler)
193 {
194 if (iter->getOp() == EOpIndexDirectStruct)
195 {
196 // When indexed into a struct, get the field name instead and construct the sampler
197 // name.
198 samplerName.insert(0, iter->getIndexStructFieldName().data());
199 samplerName.insert(0, "_");
200 }
201
202 if (baseUniform)
203 {
204 // If left is a symbol, we have reached the end of the chain. Use the struct name
205 // to finish building the name of the sampler.
206 samplerName.insert(0, baseUniform->variable().name().data());
207 }
208 }
209
210 iter = iter->getLeft()->getAsBinaryNode();
211 }
212
213 TIntermTyped *rewritten = nullptr;
214
215 if (isSampler)
216 {
217 ASSERT(extractedSamplers.find(samplerName) != extractedSamplers.end());
218 rewritten = new TIntermSymbol(extractedSamplers.at(samplerName));
219 }
220 else
221 {
222 const TVariable *baseUniformVar = &baseUniform->variable();
223 ASSERT(structureUniformMap.find(baseUniformVar) != structureUniformMap.end());
224 rewritten = new TIntermSymbol(structureUniformMap.at(baseUniformVar));
225 }
226
227 // Iterate again and build the expression from bottom up.
228 for (auto it = indexNodeStack.rbegin(); it != indexNodeStack.rend(); ++it)
229 {
230 TIntermBinary *indexNode = *it;
231
232 switch (indexNode->getOp())
233 {
234 case EOpIndexDirectStruct:
235 if (!isSampler)
236 {
237 rewritten =
238 new TIntermBinary(EOpIndexDirectStruct, rewritten, indexNode->getRight());
239 }
240 break;
241
242 case EOpIndexDirect:
243 rewritten = new TIntermBinary(EOpIndexDirect, rewritten, indexNode->getRight());
244 break;
245
246 case EOpIndexIndirect:
247 {
248 // Run RewriteExpressionTraverser on the right node. It may itself be an expression
249 // with a sampler inside that needs to be rewritten, or simply use a field of a
250 // struct that's remapped.
251 TIntermTyped *indexExpression = indexNode->getRight();
252 RewriteIndexExpression(compiler, indexExpression, structureMap, structureUniformMap,
253 extractedSamplers);
254 rewritten = new TIntermBinary(EOpIndexIndirect, rewritten, indexExpression);
255 break;
256 }
257
258 default:
259 UNREACHABLE();
260 break;
261 }
262 }
263
264 return rewritten;
265 }
266
267 class RewriteStructSamplersTraverser final : public TIntermTraverser
268 {
269 public:
RewriteStructSamplersTraverser(TCompiler * compiler,TSymbolTable * symbolTable)270 explicit RewriteStructSamplersTraverser(TCompiler *compiler, TSymbolTable *symbolTable)
271 : TIntermTraverser(true, false, false, symbolTable),
272 mCompiler(compiler),
273 mRemovedUniformsCount(0)
274 {}
275
removedUniformsCount() const276 int removedUniformsCount() const { return mRemovedUniformsCount; }
277
278 // Each struct sampler declaration is stripped of its samplers. New uniforms are added for each
279 // stripped struct sampler.
visitDeclaration(Visit visit,TIntermDeclaration * decl)280 bool visitDeclaration(Visit visit, TIntermDeclaration *decl) override
281 {
282 if (!mInGlobalScope)
283 {
284 return true;
285 }
286
287 const TIntermSequence &sequence = *(decl->getSequence());
288 TIntermTyped *declarator = sequence.front()->getAsTyped();
289 const TType &type = declarator->getType();
290
291 if (!type.isStructureContainingSamplers())
292 {
293 return false;
294 }
295
296 TIntermSequence newSequence;
297
298 if (type.isStructSpecifier())
299 {
300 // If this is just a struct definition (not a uniform variable declaration of a
301 // struct type), just remove the samplers. They are not instantiated yet.
302 const TStructure *structure = type.getStruct();
303 ASSERT(structure && mStructureMap.find(structure) == mStructureMap.end());
304
305 stripStructSpecifierSamplers(structure, &newSequence);
306 }
307 else
308 {
309 const TStructure *structure = type.getStruct();
310
311 // If the structure is defined at the same time, create the mapping to the stripped
312 // version first.
313 if (mStructureMap.find(structure) == mStructureMap.end())
314 {
315 stripStructSpecifierSamplers(structure, &newSequence);
316 }
317
318 // Then, extract the samplers from the struct and create global-scope variables instead.
319 TIntermSymbol *asSymbol = declarator->getAsSymbolNode();
320 ASSERT(asSymbol);
321 const TVariable &variable = asSymbol->variable();
322 ASSERT(variable.symbolType() != SymbolType::Empty);
323
324 extractStructSamplerUniforms(variable, structure, &newSequence);
325 }
326
327 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), decl,
328 std::move(newSequence));
329
330 return false;
331 }
332
333 // Same implementation as in RewriteExpressionTraverser. That traverser cannot replace root.
visitBinary(Visit visit,TIntermBinary * node)334 bool visitBinary(Visit visit, TIntermBinary *node) override
335 {
336 TIntermTyped *rewritten = RewriteExpressionVisitBinaryHelper(
337 mCompiler, node, mStructureMap, mStructureUniformMap, mExtractedSamplers);
338
339 if (rewritten == nullptr)
340 {
341 return true;
342 }
343
344 queueReplacement(rewritten, OriginalNode::IS_DROPPED);
345
346 // Don't iterate as the expression is rewritten.
347 return false;
348 }
349
350 // Same implementation as in RewriteExpressionTraverser. That traverser cannot replace root.
visitSymbol(TIntermSymbol * node)351 void visitSymbol(TIntermSymbol *node) override
352 {
353 ASSERT(mStructureUniformMap.find(&node->variable()) == mStructureUniformMap.end());
354 }
355
356 private:
357 // Removes all samplers from a struct specifier.
stripStructSpecifierSamplers(const TStructure * structure,TIntermSequence * newSequence)358 void stripStructSpecifierSamplers(const TStructure *structure, TIntermSequence *newSequence)
359 {
360 TFieldList *newFieldList = new TFieldList;
361 ASSERT(structure->containsSamplers());
362
363 // Add this struct to the struct map
364 ASSERT(mStructureMap.find(structure) == mStructureMap.end());
365 StructureData *modifiedData = &mStructureMap[structure];
366
367 modifiedData->modified = nullptr;
368
369 for (size_t fieldIndex = 0; fieldIndex < structure->fields().size(); ++fieldIndex)
370 {
371 const TField *field = structure->fields()[fieldIndex];
372 const TType &fieldType = *field->type();
373
374 // If the field is a sampler, or a struct that's entirely removed, skip it.
375 if (!fieldType.isSampler() && !isRemovedStructType(fieldType))
376 {
377 TType *newType = nullptr;
378
379 // Otherwise, if it's a struct that's replaced, create a new field of the replaced
380 // type.
381 if (fieldType.isStructureContainingSamplers())
382 {
383 const TStructure *fieldStruct = fieldType.getStruct();
384 ASSERT(mStructureMap.find(fieldStruct) != mStructureMap.end());
385
386 const TStructure *modifiedStruct = mStructureMap[fieldStruct].modified;
387 ASSERT(modifiedStruct);
388
389 newType = new TType(modifiedStruct, true);
390 if (fieldType.isArray())
391 {
392 newType->makeArrays(fieldType.getArraySizes());
393 }
394 }
395 else
396 {
397 // If not, duplicate the field as is.
398 newType = new TType(fieldType);
399 }
400
401 TField *newField =
402 new TField(newType, field->name(), field->line(), field->symbolType());
403 newFieldList->push_back(newField);
404 }
405 }
406
407 // Prune empty structs.
408 if (newFieldList->empty())
409 {
410 return;
411 }
412
413 // Declare a new struct with the same name and the new fields.
414 modifiedData->modified =
415 new TStructure(mSymbolTable,
416 structure->symbolType() == SymbolType::Empty ? kEmptyImmutableString
417 : structure->name(),
418 newFieldList, structure->symbolType());
419 TType *newStructType = new TType(modifiedData->modified, true);
420 TVariable *newStructVar =
421 new TVariable(mSymbolTable, kEmptyImmutableString, newStructType, SymbolType::Empty);
422 TIntermSymbol *newStructRef = new TIntermSymbol(newStructVar);
423
424 TIntermDeclaration *structDecl = new TIntermDeclaration;
425 structDecl->appendDeclarator(newStructRef);
426
427 newSequence->push_back(structDecl);
428 }
429
430 // Returns true if the type is a struct that was removed because we extracted all the members.
isRemovedStructType(const TType & type) const431 bool isRemovedStructType(const TType &type) const
432 {
433 const TStructure *structure = type.getStruct();
434 if (structure == nullptr)
435 {
436 // Not a struct
437 return false;
438 }
439
440 // A struct is removed if it is in the map, but doesn't have a replacement struct.
441 auto iter = mStructureMap.find(structure);
442 return iter != mStructureMap.end() && iter->second.modified == nullptr;
443 }
444
445 // Removes samplers from struct uniforms. For each sampler removed also adds a new globally
446 // defined sampler uniform.
extractStructSamplerUniforms(const TVariable & variable,const TStructure * structure,TIntermSequence * newSequence)447 void extractStructSamplerUniforms(const TVariable &variable,
448 const TStructure *structure,
449 TIntermSequence *newSequence)
450 {
451 ASSERT(structure->containsSamplers());
452 ASSERT(mStructureMap.find(structure) != mStructureMap.end());
453
454 const TType &type = variable.getType();
455 enterArray(type);
456
457 for (const TField *field : structure->fields())
458 {
459 extractFieldSamplers(variable.name().data(), field, newSequence);
460 }
461
462 // If there's a replacement structure (because there are non-sampler fields in the struct),
463 // add a declaration with that type.
464 const TStructure *modified = mStructureMap[structure].modified;
465 if (modified != nullptr)
466 {
467 TType *newType = new TType(modified, false);
468 if (type.isArray())
469 {
470 newType->makeArrays(type.getArraySizes());
471 }
472 newType->setQualifier(EvqUniform);
473 const TVariable *newVariable =
474 new TVariable(mSymbolTable, variable.name(), newType, variable.symbolType());
475
476 TIntermDeclaration *newDecl = new TIntermDeclaration();
477 newDecl->appendDeclarator(new TIntermSymbol(newVariable));
478
479 newSequence->push_back(newDecl);
480
481 ASSERT(mStructureUniformMap.find(&variable) == mStructureUniformMap.end());
482 mStructureUniformMap[&variable] = newVariable;
483 }
484 else
485 {
486 mRemovedUniformsCount++;
487 }
488
489 exitArray(type);
490 }
491
492 // Extracts samplers from a field of a struct. Works with nested structs and arrays.
extractFieldSamplers(const std::string & prefix,const TField * field,TIntermSequence * newSequence)493 void extractFieldSamplers(const std::string &prefix,
494 const TField *field,
495 TIntermSequence *newSequence)
496 {
497 const TType &fieldType = *field->type();
498 if (fieldType.isSampler() || fieldType.isStructureContainingSamplers())
499 {
500 std::string newPrefix = prefix + "_" + field->name().data();
501
502 if (fieldType.isSampler())
503 {
504 extractSampler(newPrefix, fieldType, newSequence);
505 }
506 else
507 {
508 enterArray(fieldType);
509 const TStructure *structure = fieldType.getStruct();
510 for (const TField *nestedField : structure->fields())
511 {
512 extractFieldSamplers(newPrefix, nestedField, newSequence);
513 }
514 exitArray(fieldType);
515 }
516 }
517 }
518
GenerateArraySizesFromStack(TVector<unsigned int> * sizesOut)519 void GenerateArraySizesFromStack(TVector<unsigned int> *sizesOut)
520 {
521 sizesOut->reserve(mArraySizeStack.size());
522
523 for (auto it = mArraySizeStack.rbegin(); it != mArraySizeStack.rend(); ++it)
524 {
525 sizesOut->push_back(*it);
526 }
527 }
528
529 // Extracts a sampler from a struct. Declares the new extracted sampler.
extractSampler(const std::string & newName,const TType & fieldType,TIntermSequence * newSequence)530 void extractSampler(const std::string &newName,
531 const TType &fieldType,
532 TIntermSequence *newSequence)
533 {
534 ASSERT(fieldType.isSampler());
535
536 TType *newType = new TType(fieldType);
537
538 // Add array dimensions accumulated so far due to struct arrays. Note that to support
539 // nested arrays, mArraySizeStack has the outermost size in the front. |makeArrays| thus
540 // expects this in reverse order.
541 TVector<unsigned int> parentArraySizes;
542 GenerateArraySizesFromStack(&parentArraySizes);
543 newType->makeArrays(parentArraySizes);
544
545 ImmutableStringBuilder nameBuilder(newName.size() + 1);
546 nameBuilder << newName;
547
548 newType->setQualifier(EvqUniform);
549 TVariable *newVariable =
550 new TVariable(mSymbolTable, nameBuilder, newType, SymbolType::AngleInternal);
551 TIntermSymbol *newSymbol = new TIntermSymbol(newVariable);
552
553 TIntermDeclaration *samplerDecl = new TIntermDeclaration;
554 samplerDecl->appendDeclarator(newSymbol);
555
556 newSequence->push_back(samplerDecl);
557
558 // TODO: Use a temp name instead of generating a name as currently done. There is no
559 // guarantee that these generated names cannot clash. Create a mapping from the previous
560 // name to the name assigned to the temp variable so ShaderVariable::mappedName can be
561 // updated post-transformation. http://anglebug.com/42262930
562 ASSERT(mExtractedSamplers.find(newName) == mExtractedSamplers.end());
563 mExtractedSamplers[newName] = newVariable;
564 }
565
enterArray(const TType & arrayType)566 void enterArray(const TType &arrayType)
567 {
568 const TSpan<const unsigned int> &arraySizes = arrayType.getArraySizes();
569 for (auto it = arraySizes.rbegin(); it != arraySizes.rend(); ++it)
570 {
571 unsigned int arraySize = *it;
572 mArraySizeStack.push_back(arraySize);
573 }
574 }
575
exitArray(const TType & arrayType)576 void exitArray(const TType &arrayType)
577 {
578 mArraySizeStack.resize(mArraySizeStack.size() - arrayType.getNumArraySizes());
579 }
580
581 TCompiler *mCompiler;
582 int mRemovedUniformsCount;
583
584 // Map structures with samplers to ones that have their samplers removed.
585 StructureMap mStructureMap;
586
587 // Map uniform variables of structure type that are replaced with another variable.
588 StructureUniformMap mStructureUniformMap;
589
590 // Map a constructed sampler name to its variable. Used to replace an expression that uses this
591 // sampler with the extracted one.
592 ExtractedSamplerMap mExtractedSamplers;
593
594 // A stack of array sizes. Used to figure out the array dimensions of the extracted sampler,
595 // for example when it's nested in an array of structs in an array of structs.
596 TVector<unsigned int> mArraySizeStack;
597 };
598 } // anonymous namespace
599
RewriteStructSamplers(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,int * removedUniformsCountOut)600 bool RewriteStructSamplers(TCompiler *compiler,
601 TIntermBlock *root,
602 TSymbolTable *symbolTable,
603 int *removedUniformsCountOut)
604 {
605 RewriteStructSamplersTraverser traverser(compiler, symbolTable);
606 root->traverse(&traverser);
607 *removedUniformsCountOut = traverser.removedUniformsCount();
608 return traverser.updateTree(compiler, root);
609 }
610 } // namespace sh
611