xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/SimplifyLoopConditions.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2016 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // SimplifyLoopConditions is an AST traverser that converts loop conditions and loop expressions
7 // to regular statements inside the loop. This way further transformations that generate statements
8 // from loop conditions and loop expressions work correctly.
9 //
10 
11 #include "compiler/translator/tree_ops/SimplifyLoopConditions.h"
12 
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/tree_util/IntermNodePatternMatcher.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 
18 namespace sh
19 {
20 
21 namespace
22 {
23 
24 struct LoopInfo
25 {
26     const TVariable *conditionVariable = nullptr;
27     TIntermTyped *condition            = nullptr;
28     TIntermTyped *expression           = nullptr;
29 };
30 
31 class SimplifyLoopConditionsTraverser : public TLValueTrackingTraverser
32 {
33   public:
34     SimplifyLoopConditionsTraverser(const IntermNodePatternMatcher *conditionsToSimplify,
35                                     TSymbolTable *symbolTable);
36 
37     void traverseLoop(TIntermLoop *node) override;
38 
39     bool visitUnary(Visit visit, TIntermUnary *node) override;
40     bool visitBinary(Visit visit, TIntermBinary *node) override;
41     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
42     bool visitTernary(Visit visit, TIntermTernary *node) override;
43     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
44     bool visitBranch(Visit visit, TIntermBranch *node) override;
45 
foundLoopToChange() const46     bool foundLoopToChange() const { return mFoundLoopToChange; }
47 
48   protected:
49     // Marked to true once an operation that needs to be hoisted out of a loop expression has been
50     // found.
51     bool mFoundLoopToChange;
52     bool mInsideLoopInitConditionOrExpression;
53     const IntermNodePatternMatcher *mConditionsToSimplify;
54 
55   private:
56     LoopInfo mLoop;
57 };
58 
SimplifyLoopConditionsTraverser(const IntermNodePatternMatcher * conditionsToSimplify,TSymbolTable * symbolTable)59 SimplifyLoopConditionsTraverser::SimplifyLoopConditionsTraverser(
60     const IntermNodePatternMatcher *conditionsToSimplify,
61     TSymbolTable *symbolTable)
62     : TLValueTrackingTraverser(true, false, false, symbolTable),
63       mFoundLoopToChange(false),
64       mInsideLoopInitConditionOrExpression(false),
65       mConditionsToSimplify(conditionsToSimplify)
66 {}
67 
// While a loop's initialization, condition or expression is being traversed, the visitors below
// test each node against the pattern matcher to find expressions that have to be moved out of the
// loop header. As soon as one is found, the loop is marked for transformation.
// Outside of a loop's initialization, condition or expression, only nodes that may contain loops
// need to be traversed, so these expression and declaration visitors skip their children.
73 
visitUnary(Visit visit,TIntermUnary * node)74 bool SimplifyLoopConditionsTraverser::visitUnary(Visit visit, TIntermUnary *node)
75 {
76     if (!mInsideLoopInitConditionOrExpression)
77         return false;
78 
79     if (mFoundLoopToChange)
80         return false;  // Already decided to change this loop.
81 
82     ASSERT(mConditionsToSimplify);
83     mFoundLoopToChange = mConditionsToSimplify->match(node);
84     return !mFoundLoopToChange;
85 }
86 
visitBinary(Visit visit,TIntermBinary * node)87 bool SimplifyLoopConditionsTraverser::visitBinary(Visit visit, TIntermBinary *node)
88 {
89     if (!mInsideLoopInitConditionOrExpression)
90         return false;
91 
92     if (mFoundLoopToChange)
93         return false;  // Already decided to change this loop.
94 
95     ASSERT(mConditionsToSimplify);
96     mFoundLoopToChange =
97         mConditionsToSimplify->match(node, getParentNode(), isLValueRequiredHere());
98     return !mFoundLoopToChange;
99 }
100 
visitAggregate(Visit visit,TIntermAggregate * node)101 bool SimplifyLoopConditionsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
102 {
103     if (!mInsideLoopInitConditionOrExpression)
104         return false;
105 
106     if (mFoundLoopToChange)
107         return false;  // Already decided to change this loop.
108 
109     ASSERT(mConditionsToSimplify);
110     mFoundLoopToChange = mConditionsToSimplify->match(node, getParentNode());
111     return !mFoundLoopToChange;
112 }
113 
visitTernary(Visit visit,TIntermTernary * node)114 bool SimplifyLoopConditionsTraverser::visitTernary(Visit visit, TIntermTernary *node)
115 {
116     if (!mInsideLoopInitConditionOrExpression)
117         return false;
118 
119     if (mFoundLoopToChange)
120         return false;  // Already decided to change this loop.
121 
122     ASSERT(mConditionsToSimplify);
123     mFoundLoopToChange = mConditionsToSimplify->match(node);
124     return !mFoundLoopToChange;
125 }
126 
visitDeclaration(Visit visit,TIntermDeclaration * node)127 bool SimplifyLoopConditionsTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
128 {
129     if (!mInsideLoopInitConditionOrExpression)
130         return false;
131 
132     if (mFoundLoopToChange)
133         return false;  // Already decided to change this loop.
134 
135     ASSERT(mConditionsToSimplify);
136     mFoundLoopToChange = mConditionsToSimplify->match(node);
137     return !mFoundLoopToChange;
138 }
139 
visitBranch(Visit visit,TIntermBranch * node)140 bool SimplifyLoopConditionsTraverser::visitBranch(Visit visit, TIntermBranch *node)
141 {
142     if (node->getFlowOp() == EOpContinue && (mLoop.condition || mLoop.expression))
143     {
144         TIntermBlock *parent = getParentNode()->getAsBlock();
145         ASSERT(parent);
146         TIntermSequence seq;
147         if (mLoop.expression)
148         {
149             seq.push_back(mLoop.expression->deepCopy());
150         }
151         if (mLoop.condition)
152         {
153             ASSERT(mLoop.conditionVariable);
154             seq.push_back(
155                 CreateTempAssignmentNode(mLoop.conditionVariable, mLoop.condition->deepCopy()));
156         }
157         seq.push_back(node);
158         mMultiReplacements.push_back(NodeReplaceWithMultipleEntry(parent, node, std::move(seq)));
159     }
160 
161     return true;
162 }
163 
CreateFromBody(TIntermLoop * node,bool * bodyEndsInBranchOut)164 static TIntermBlock *CreateFromBody(TIntermLoop *node, bool *bodyEndsInBranchOut)
165 {
166     TIntermBlock *newBody  = new TIntermBlock();
167     TIntermBlock *nodeBody = node->getBody();
168     newBody->getSequence()->push_back(nodeBody);
169     *bodyEndsInBranchOut = EndsInBranch(nodeBody);
170     return newBody;
171 }
172 
traverseLoop(TIntermLoop * node)173 void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node)
174 {
175     // Mark that we're inside a loop condition or expression, and determine if the loop needs to be
176     // transformed.
177 
178     ScopedNodeInTraversalPath addToPath(this, node);
179 
180     mInsideLoopInitConditionOrExpression = true;
181     mFoundLoopToChange                   = !mConditionsToSimplify;
182 
183     if (!mFoundLoopToChange && node->getInit())
184     {
185         node->getInit()->traverse(this);
186     }
187 
188     if (!mFoundLoopToChange && node->getCondition())
189     {
190         node->getCondition()->traverse(this);
191     }
192 
193     if (!mFoundLoopToChange && node->getExpression())
194     {
195         node->getExpression()->traverse(this);
196     }
197 
198     mInsideLoopInitConditionOrExpression = false;
199 
200     const LoopInfo prevLoop = mLoop;
201 
202     if (mFoundLoopToChange)
203     {
204         const TType *boolType   = StaticType::Get<EbtBool, EbpUndefined, EvqTemporary, 1, 1>();
205         mLoop.conditionVariable = CreateTempVariable(mSymbolTable, boolType);
206         mLoop.condition         = node->getCondition();
207         mLoop.expression        = node->getExpression();
208 
209         // Replace the loop condition with a boolean variable that's updated on each iteration.
210         TLoopType loopType = node->getType();
211         if (loopType == ELoopWhile)
212         {
213             ASSERT(!mLoop.expression);
214 
215             if (mLoop.condition->getAsSymbolNode())
216             {
217                 // Mask continue statement condition variable update.
218                 mLoop.condition = nullptr;
219             }
220             else if (mLoop.condition->getAsConstantUnion())
221             {
222                 // Transform:
223                 //   while (expr) { body; }
224                 // into
225                 //   bool s0 = expr;
226                 //   while (s0) { body; }
227                 TIntermDeclaration *tempInitDeclaration =
228                     CreateTempInitDeclarationNode(mLoop.conditionVariable, mLoop.condition);
229                 insertStatementInParentBlock(tempInitDeclaration);
230 
231                 node->setCondition(CreateTempSymbolNode(mLoop.conditionVariable));
232 
233                 // Mask continue statement condition variable update.
234                 mLoop.condition = nullptr;
235             }
236             else
237             {
238                 // Transform:
239                 //   while (expr) { body; }
240                 // into
241                 //   bool s0 = expr;
242                 //   while (s0) { { body; } s0 = expr; }
243                 //
244                 // Local case statements are transformed into:
245                 //   s0 = expr; continue;
246                 TIntermDeclaration *tempInitDeclaration =
247                     CreateTempInitDeclarationNode(mLoop.conditionVariable, mLoop.condition);
248                 insertStatementInParentBlock(tempInitDeclaration);
249 
250                 bool bodyEndsInBranch;
251                 TIntermBlock *newBody = CreateFromBody(node, &bodyEndsInBranch);
252                 if (!bodyEndsInBranch)
253                 {
254                     newBody->getSequence()->push_back(CreateTempAssignmentNode(
255                         mLoop.conditionVariable, mLoop.condition->deepCopy()));
256                 }
257 
258                 // Can't use queueReplacement to replace old body, since it may have been nullptr.
259                 // It's safe to do the replacements in place here - the new body will still be
260                 // traversed, but that won't create any problems.
261                 node->setBody(newBody);
262                 node->setCondition(CreateTempSymbolNode(mLoop.conditionVariable));
263             }
264         }
265         else if (loopType == ELoopDoWhile)
266         {
267             ASSERT(!mLoop.expression);
268 
269             if (mLoop.condition->getAsSymbolNode())
270             {
271                 // Mask continue statement condition variable update.
272                 mLoop.condition = nullptr;
273             }
274             else if (mLoop.condition->getAsConstantUnion())
275             {
276                 // Transform:
277                 //   do {
278                 //     body;
279                 //   } while (expr);
280                 // into
281                 //   bool s0 = expr;
282                 //   do {
283                 //     body;
284                 //   } while (s0);
285                 TIntermDeclaration *tempInitDeclaration =
286                     CreateTempInitDeclarationNode(mLoop.conditionVariable, mLoop.condition);
287                 insertStatementInParentBlock(tempInitDeclaration);
288 
289                 node->setCondition(CreateTempSymbolNode(mLoop.conditionVariable));
290 
291                 // Mask continue statement condition variable update.
292                 mLoop.condition = nullptr;
293             }
294             else
295             {
296                 // Transform:
297                 //   do {
298                 //     body;
299                 //   } while (expr);
300                 // into
301                 //   bool s0;
302                 //   do {
303                 //     { body; }
304                 //     s0 = expr;
305                 //   } while (s0);
306                 // Local case statements are transformed into:
307                 //   s0 = expr; continue;
308                 TIntermDeclaration *tempInitDeclaration =
309                     CreateTempDeclarationNode(mLoop.conditionVariable);
310                 insertStatementInParentBlock(tempInitDeclaration);
311 
312                 bool bodyEndsInBranch;
313                 TIntermBlock *newBody = CreateFromBody(node, &bodyEndsInBranch);
314                 if (!bodyEndsInBranch)
315                 {
316                     newBody->getSequence()->push_back(
317                         CreateTempAssignmentNode(mLoop.conditionVariable, mLoop.condition));
318                 }
319 
320                 // Can't use queueReplacement to replace old body, since it may have been nullptr.
321                 // It's safe to do the replacements in place here - the new body will still be
322                 // traversed, but that won't create any problems.
323                 node->setBody(newBody);
324                 node->setCondition(CreateTempSymbolNode(mLoop.conditionVariable));
325             }
326         }
327         else if (loopType == ELoopFor)
328         {
329             if (!mLoop.condition)
330             {
331                 mLoop.condition = CreateBoolNode(true);
332             }
333 
334             TIntermLoop *whileLoop;
335             TIntermBlock *loopScope            = new TIntermBlock();
336             TIntermSequence *loopScopeSequence = loopScope->getSequence();
337 
338             // Insert "init;"
339             if (node->getInit())
340             {
341                 loopScopeSequence->push_back(node->getInit());
342             }
343 
344             if (mLoop.condition->getAsSymbolNode())
345             {
346                 // Move the loop condition inside the loop.
347                 // Transform:
348                 //   for (init; expr; exprB) { body; }
349                 // into
350                 //   {
351                 //     init;
352                 //     while (expr) {
353                 //       { body; }
354                 //       exprB;
355                 //     }
356                 //   }
357                 //
358                 // Local case statements are transformed into:
359                 //   exprB; continue;
360 
361                 // Insert "{ body; }" in the while loop
362                 bool bodyEndsInBranch;
363                 TIntermBlock *whileLoopBody = CreateFromBody(node, &bodyEndsInBranch);
364                 // Insert "exprB;" in the while loop
365                 if (!bodyEndsInBranch && node->getExpression())
366                 {
367                     whileLoopBody->getSequence()->push_back(node->getExpression());
368                 }
369                 // Create "while(expr) { whileLoopBody }"
370                 whileLoop =
371                     new TIntermLoop(ELoopWhile, nullptr, mLoop.condition, nullptr, whileLoopBody);
372 
373                 // Mask continue statement condition variable update.
374                 mLoop.condition = nullptr;
375             }
376             else if (mLoop.condition->getAsConstantUnion())
377             {
378                 // Move the loop condition inside the loop.
379                 // Transform:
380                 //   for (init; expr; exprB) { body; }
381                 // into
382                 //   {
383                 //     init;
384                 //     bool s0 = expr;
385                 //     while (s0) {
386                 //       { body; }
387                 //       exprB;
388                 //     }
389                 //   }
390                 //
391                 // Local case statements are transformed into:
392                 //   exprB; continue;
393 
394                 // Insert "bool s0 = expr;"
395                 loopScopeSequence->push_back(
396                     CreateTempInitDeclarationNode(mLoop.conditionVariable, mLoop.condition));
397                 // Insert "{ body; }" in the while loop
398                 bool bodyEndsInBranch;
399                 TIntermBlock *whileLoopBody = CreateFromBody(node, &bodyEndsInBranch);
400                 // Insert "exprB;" in the while loop
401                 if (!bodyEndsInBranch && node->getExpression())
402                 {
403                     whileLoopBody->getSequence()->push_back(node->getExpression());
404                 }
405                 // Create "while(s0) { whileLoopBody }"
406                 whileLoop = new TIntermLoop(ELoopWhile, nullptr,
407                                             CreateTempSymbolNode(mLoop.conditionVariable), nullptr,
408                                             whileLoopBody);
409 
410                 // Mask continue statement condition variable update.
411                 mLoop.condition = nullptr;
412             }
413             else
414             {
415                 // Move the loop condition inside the loop.
416                 // Transform:
417                 //   for (init; expr; exprB) { body; }
418                 // into
419                 //   {
420                 //     init;
421                 //     bool s0 = expr;
422                 //     while (s0) {
423                 //       { body; }
424                 //       exprB;
425                 //       s0 = expr;
426                 //     }
427                 //   }
428                 //
429                 // Local case statements are transformed into:
430                 //   exprB; s0 = expr; continue;
431 
432                 // Insert "bool s0 = expr;"
433                 loopScopeSequence->push_back(
434                     CreateTempInitDeclarationNode(mLoop.conditionVariable, mLoop.condition));
435                 // Insert "{ body; }" in the while loop
436                 bool bodyEndsInBranch;
437                 TIntermBlock *whileLoopBody = CreateFromBody(node, &bodyEndsInBranch);
438                 // Insert "exprB;" in the while loop
439                 if (!bodyEndsInBranch && node->getExpression())
440                 {
441                     whileLoopBody->getSequence()->push_back(node->getExpression());
442                 }
443                 // Insert "s0 = expr;" in the while loop
444                 if (!bodyEndsInBranch)
445                 {
446                     whileLoopBody->getSequence()->push_back(CreateTempAssignmentNode(
447                         mLoop.conditionVariable, mLoop.condition->deepCopy()));
448                 }
449                 // Create "while(s0) { whileLoopBody }"
450                 whileLoop = new TIntermLoop(ELoopWhile, nullptr,
451                                             CreateTempSymbolNode(mLoop.conditionVariable), nullptr,
452                                             whileLoopBody);
453             }
454 
455             loopScope->getSequence()->push_back(whileLoop);
456             queueReplacement(loopScope, OriginalNode::IS_DROPPED);
457 
458             // After this the old body node will be traversed and loops inside it may be
459             // transformed. This is fine, since the old body node will still be in the AST after
460             // the transformation that's queued here, and transforming loops inside it doesn't
461             // need to know the exact post-transform path to it.
462         }
463     }
464 
465     mFoundLoopToChange = false;
466 
467     // We traverse the body of the loop even if the loop is transformed.
468     node->getBody()->traverse(this);
469 
470     mLoop = prevLoop;
471 }
472 
473 }  // namespace
474 
SimplifyLoopConditions(TCompiler * compiler,TIntermNode * root,TSymbolTable * symbolTable)475 bool SimplifyLoopConditions(TCompiler *compiler, TIntermNode *root, TSymbolTable *symbolTable)
476 {
477     SimplifyLoopConditionsTraverser traverser(nullptr, symbolTable);
478     root->traverse(&traverser);
479     return traverser.updateTree(compiler, root);
480 }
481 
SimplifyLoopConditions(TCompiler * compiler,TIntermNode * root,unsigned int conditionsToSimplifyMask,TSymbolTable * symbolTable)482 bool SimplifyLoopConditions(TCompiler *compiler,
483                             TIntermNode *root,
484                             unsigned int conditionsToSimplifyMask,
485                             TSymbolTable *symbolTable)
486 {
487     IntermNodePatternMatcher conditionsToSimplify(conditionsToSimplifyMask);
488     SimplifyLoopConditionsTraverser traverser(&conditionsToSimplify, symbolTable);
489     root->traverse(&traverser);
490     return traverser.updateTree(compiler, root);
491 }
492 
493 }  // namespace sh
494