xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/RemoveUnreferencedVariables.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2017 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RemoveUnreferencedVariables.cpp:
7 //  Drop variables that are declared but never referenced in the AST. This avoids adding unnecessary
8 //  initialization code for them. Also removes unreferenced struct types.
9 //
10 
11 #include "compiler/translator/tree_ops/RemoveUnreferencedVariables.h"
12 
13 #include "common/hash_containers.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/IntermTraverse.h"
16 
17 namespace sh
18 {
19 
20 namespace
21 {
22 
23 class CollectVariableRefCountsTraverser : public TIntermTraverser
24 {
25   public:
26     CollectVariableRefCountsTraverser();
27 
28     using RefCountMap = angle::HashMap<int, unsigned int>;
getSymbolIdRefCounts()29     RefCountMap &getSymbolIdRefCounts() { return mSymbolIdRefCounts; }
getStructIdRefCounts()30     RefCountMap &getStructIdRefCounts() { return mStructIdRefCounts; }
31 
32     void visitSymbol(TIntermSymbol *node) override;
33     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
34     void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
35 
36   private:
37     void incrementStructTypeRefCount(const TType &type);
38 
39     RefCountMap mSymbolIdRefCounts;
40 
41     // Structure reference counts are counted from symbols, constructors, function calls, function
42     // return values and from interface block and structure fields. We need to track both function
43     // calls and function return values since there's a compiler option not to prune unused
44     // functions. The type of a constant union may also be a struct, but statements that are just a
45     // constant union are always pruned, and if the constant union is used somehow it will get
46     // counted by something else.
47     RefCountMap mStructIdRefCounts;
48 };
49 
CollectVariableRefCountsTraverser()50 CollectVariableRefCountsTraverser::CollectVariableRefCountsTraverser()
51     : TIntermTraverser(true, false, false)
52 {}
53 
incrementStructTypeRefCount(const TType & type)54 void CollectVariableRefCountsTraverser::incrementStructTypeRefCount(const TType &type)
55 {
56     if (type.isInterfaceBlock())
57     {
58         const auto *block = type.getInterfaceBlock();
59         ASSERT(block);
60 
61         // We can end up incrementing ref counts of struct types referenced from an interface block
62         // multiple times for the same block. This doesn't matter, because interface blocks can't be
63         // pruned so we'll never do the reverse operation.
64         for (const auto &field : block->fields())
65         {
66             ASSERT(!field->type()->isInterfaceBlock());
67             incrementStructTypeRefCount(*field->type());
68         }
69         return;
70     }
71 
72     const auto *structure = type.getStruct();
73     if (structure != nullptr)
74     {
75         auto structIter = mStructIdRefCounts.find(structure->uniqueId().get());
76         if (structIter == mStructIdRefCounts.end())
77         {
78             mStructIdRefCounts[structure->uniqueId().get()] = 1u;
79 
80             for (const auto &field : structure->fields())
81             {
82                 incrementStructTypeRefCount(*field->type());
83             }
84 
85             return;
86         }
87         ++(structIter->second);
88     }
89 }
90 
visitSymbol(TIntermSymbol * node)91 void CollectVariableRefCountsTraverser::visitSymbol(TIntermSymbol *node)
92 {
93     incrementStructTypeRefCount(node->getType());
94 
95     auto iter = mSymbolIdRefCounts.find(node->uniqueId().get());
96     if (iter == mSymbolIdRefCounts.end())
97     {
98         mSymbolIdRefCounts[node->uniqueId().get()] = 1u;
99         return;
100     }
101     ++(iter->second);
102 }
103 
visitAggregate(Visit visit,TIntermAggregate * node)104 bool CollectVariableRefCountsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
105 {
106     // This tracks struct references in both function calls and constructors.
107     incrementStructTypeRefCount(node->getType());
108     return true;
109 }
110 
visitFunctionPrototype(TIntermFunctionPrototype * node)111 void CollectVariableRefCountsTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
112 {
113     incrementStructTypeRefCount(node->getType());
114     size_t paramCount = node->getFunction()->getParamCount();
115     for (size_t i = 0; i < paramCount; ++i)
116     {
117         incrementStructTypeRefCount(node->getFunction()->getParam(i)->getType());
118     }
119 }
120 
121 // Traverser that removes all unreferenced variables on one traversal.
122 class RemoveUnreferencedVariablesTraverser : public TIntermTraverser
123 {
124   public:
125     RemoveUnreferencedVariablesTraverser(
126         CollectVariableRefCountsTraverser::RefCountMap *symbolIdRefCounts,
127         CollectVariableRefCountsTraverser::RefCountMap *structIdRefCounts,
128         TSymbolTable *symbolTable);
129 
130     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
131     void visitSymbol(TIntermSymbol *node) override;
132     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
133 
134     // Traverse loop and block nodes in reverse order. Note that this traverser does not track
135     // parent block positions, so insertStatementInParentBlock is unusable!
136     void traverseBlock(TIntermBlock *block) override;
137     void traverseLoop(TIntermLoop *loop) override;
138 
139   private:
140     void removeVariableDeclaration(TIntermDeclaration *node, TIntermTyped *declarator);
141     void decrementStructTypeRefCount(const TType &type);
142 
143     CollectVariableRefCountsTraverser::RefCountMap *mSymbolIdRefCounts;
144     CollectVariableRefCountsTraverser::RefCountMap *mStructIdRefCounts;
145     bool mRemoveReferences;
146 };
147 
RemoveUnreferencedVariablesTraverser(CollectVariableRefCountsTraverser::RefCountMap * symbolIdRefCounts,CollectVariableRefCountsTraverser::RefCountMap * structIdRefCounts,TSymbolTable * symbolTable)148 RemoveUnreferencedVariablesTraverser::RemoveUnreferencedVariablesTraverser(
149     CollectVariableRefCountsTraverser::RefCountMap *symbolIdRefCounts,
150     CollectVariableRefCountsTraverser::RefCountMap *structIdRefCounts,
151     TSymbolTable *symbolTable)
152     : TIntermTraverser(true, false, true, symbolTable),
153       mSymbolIdRefCounts(symbolIdRefCounts),
154       mStructIdRefCounts(structIdRefCounts),
155       mRemoveReferences(false)
156 {}
157 
decrementStructTypeRefCount(const TType & type)158 void RemoveUnreferencedVariablesTraverser::decrementStructTypeRefCount(const TType &type)
159 {
160     auto *structure = type.getStruct();
161     if (structure != nullptr)
162     {
163         ASSERT(mStructIdRefCounts->find(structure->uniqueId().get()) != mStructIdRefCounts->end());
164         unsigned int structRefCount = --(*mStructIdRefCounts)[structure->uniqueId().get()];
165 
166         if (structRefCount == 0)
167         {
168             for (const auto &field : structure->fields())
169             {
170                 decrementStructTypeRefCount(*field->type());
171             }
172         }
173     }
174 }
175 
removeVariableDeclaration(TIntermDeclaration * node,TIntermTyped * declarator)176 void RemoveUnreferencedVariablesTraverser::removeVariableDeclaration(TIntermDeclaration *node,
177                                                                      TIntermTyped *declarator)
178 {
179     if (declarator->getType().isStructSpecifier() && !declarator->getType().isNamelessStruct())
180     {
181         unsigned int structId = declarator->getType().getStruct()->uniqueId().get();
182         unsigned int structRefCountInThisDeclarator = 1u;
183         if (declarator->getAsBinaryNode() &&
184             declarator->getAsBinaryNode()->getRight()->getAsAggregate())
185         {
186             ASSERT(declarator->getAsBinaryNode()->getLeft()->getType().getStruct() ==
187                    declarator->getType().getStruct());
188             ASSERT(declarator->getAsBinaryNode()->getRight()->getType().getStruct() ==
189                    declarator->getType().getStruct());
190             structRefCountInThisDeclarator = 2u;
191         }
192         if ((*mStructIdRefCounts)[structId] > structRefCountInThisDeclarator)
193         {
194             // If this declaration declares a named struct type that is used elsewhere, we need to
195             // keep it. We can still change the declarator though so that it doesn't declare an
196             // unreferenced variable.
197 
198             // Note that since we're not removing the entire declaration, the struct's reference
199             // count will end up being one less than the correct refcount. But since the struct
200             // declaration is kept, the incorrect refcount can't cause any other problems.
201 
202             if (declarator->getAsSymbolNode() &&
203                 declarator->getAsSymbolNode()->variable().symbolType() == SymbolType::Empty)
204             {
205                 // Already an empty declaration - nothing to do.
206                 return;
207             }
208             TVariable *emptyVariable =
209                 new TVariable(mSymbolTable, kEmptyImmutableString, new TType(declarator->getType()),
210                               SymbolType::Empty);
211             queueReplacementWithParent(node, declarator, new TIntermSymbol(emptyVariable),
212                                        OriginalNode::IS_DROPPED);
213             return;
214         }
215     }
216 
217     if (getParentNode()->getAsBlock())
218     {
219         TIntermSequence emptyReplacement;
220         mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
221                                         std::move(emptyReplacement));
222     }
223     else
224     {
225         ASSERT(getParentNode()->getAsLoopNode());
226         queueReplacement(nullptr, OriginalNode::IS_DROPPED);
227     }
228 }
229 
visitDeclaration(Visit visit,TIntermDeclaration * node)230 bool RemoveUnreferencedVariablesTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
231 {
232     if (visit == PreVisit)
233     {
234         // SeparateDeclarations should have already been run.
235         ASSERT(node->getSequence()->size() == 1u);
236 
237         TIntermTyped *declarator = node->getSequence()->back()->getAsTyped();
238         ASSERT(declarator);
239 
240         // We can only remove variables that are not a part of the shader interface.
241         TQualifier qualifier = declarator->getQualifier();
242         if (qualifier != EvqTemporary && qualifier != EvqGlobal && qualifier != EvqConst)
243         {
244             return true;
245         }
246 
247         bool canRemoveVariable    = false;
248         TIntermSymbol *symbolNode = declarator->getAsSymbolNode();
249         if (symbolNode != nullptr)
250         {
251             canRemoveVariable = (*mSymbolIdRefCounts)[symbolNode->uniqueId().get()] == 1u ||
252                                 symbolNode->variable().symbolType() == SymbolType::Empty;
253         }
254         TIntermBinary *initNode = declarator->getAsBinaryNode();
255         if (initNode != nullptr)
256         {
257             ASSERT(initNode->getLeft()->getAsSymbolNode());
258             int symbolId = initNode->getLeft()->getAsSymbolNode()->uniqueId().get();
259             canRemoveVariable =
260                 (*mSymbolIdRefCounts)[symbolId] == 1u && !initNode->getRight()->hasSideEffects();
261         }
262 
263         if (canRemoveVariable)
264         {
265             removeVariableDeclaration(node, declarator);
266             mRemoveReferences = true;
267         }
268         return true;
269     }
270     ASSERT(visit == PostVisit);
271     mRemoveReferences = false;
272     return true;
273 }
274 
visitSymbol(TIntermSymbol * node)275 void RemoveUnreferencedVariablesTraverser::visitSymbol(TIntermSymbol *node)
276 {
277     if (mRemoveReferences)
278     {
279         ASSERT(mSymbolIdRefCounts->find(node->uniqueId().get()) != mSymbolIdRefCounts->end());
280         --(*mSymbolIdRefCounts)[node->uniqueId().get()];
281 
282         decrementStructTypeRefCount(node->getType());
283     }
284 }
285 
visitAggregate(Visit visit,TIntermAggregate * node)286 bool RemoveUnreferencedVariablesTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
287 {
288     if (visit == PreVisit && mRemoveReferences)
289     {
290         decrementStructTypeRefCount(node->getType());
291     }
292     return true;
293 }
294 
traverseBlock(TIntermBlock * node)295 void RemoveUnreferencedVariablesTraverser::traverseBlock(TIntermBlock *node)
296 {
297     // We traverse blocks in reverse order.  This way reference counts can be decremented when
298     // removing initializers, and variables that become unused when initializers are removed can be
299     // removed on the same traversal.
300 
301     ScopedNodeInTraversalPath addToPath(this, node);
302 
303     bool visit = true;
304 
305     TIntermSequence *sequence = node->getSequence();
306 
307     if (preVisit)
308         visit = visitBlock(PreVisit, node);
309 
310     if (visit)
311     {
312         for (auto iter = sequence->rbegin(); iter != sequence->rend(); ++iter)
313         {
314             (*iter)->traverse(this);
315             if (visit && inVisit)
316             {
317                 if ((iter + 1) != sequence->rend())
318                     visit = visitBlock(InVisit, node);
319             }
320         }
321     }
322 
323     if (visit && postVisit)
324         visitBlock(PostVisit, node);
325 }
326 
traverseLoop(TIntermLoop * node)327 void RemoveUnreferencedVariablesTraverser::traverseLoop(TIntermLoop *node)
328 {
329     // We traverse loops in reverse order as well. The loop body gets traversed before the init
330     // node.
331 
332     ScopedNodeInTraversalPath addToPath(this, node);
333 
334     bool visit = true;
335 
336     if (preVisit)
337         visit = visitLoop(PreVisit, node);
338 
339     if (visit)
340     {
341         // We don't need to traverse loop expressions or conditions since they can't be declarations
342         // in the AST (loops which have a declaration in their condition get transformed in the
343         // parsing stage).
344         ASSERT(node->getExpression() == nullptr ||
345                node->getExpression()->getAsDeclarationNode() == nullptr);
346         ASSERT(node->getCondition() == nullptr ||
347                node->getCondition()->getAsDeclarationNode() == nullptr);
348 
349         node->getBody()->traverse(this);
350 
351         if (node->getInit())
352             node->getInit()->traverse(this);
353     }
354 
355     if (visit && postVisit)
356         visitLoop(PostVisit, node);
357 }
358 
359 }  // namespace
360 
RemoveUnreferencedVariables(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)361 bool RemoveUnreferencedVariables(TCompiler *compiler, TIntermBlock *root, TSymbolTable *symbolTable)
362 {
363     CollectVariableRefCountsTraverser collector;
364     root->traverse(&collector);
365     RemoveUnreferencedVariablesTraverser traverser(&collector.getSymbolIdRefCounts(),
366                                                    &collector.getStructIdRefCounts(), symbolTable);
367     root->traverse(&traverser);
368     return traverser.updateTree(compiler, root);
369 }
370 
371 }  // namespace sh
372