//===---- DivergenceAnalysis.cpp --- Divergence Analysis Implementation ----==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a general divergence analysis for loop vectorization
// and GPU programs. It determines which branches and values in a loop or GPU
// program are divergent. It can help branch optimizations such as jump
// threading and loop unswitching to make better decisions.
//
// GPU programs typically use the SIMD execution model, where multiple threads
// in the same execution group have to execute in lock-step. Therefore, if the
// code contains divergent branches (i.e., threads in a group do not agree on
// which path of the branch to take), the group of threads has to execute all
// the paths from that branch with different subsets of threads enabled until
// they re-converge.
//
// Due to this execution model, some optimizations such as jump
// threading and loop unswitching can interfere with thread re-convergence.
// Therefore, an analysis that computes which branches in a GPU program are
// divergent can help the compiler to selectively run these optimizations.
//
// This implementation is derived from the Vectorization Analysis of the
// Region Vectorizer (RV). The analysis is based on the approach described in
//
//   An abstract interpretation for SPMD divergence
//       on reducible control flow graphs.
//   Julian Rosemann, Simon Moll and Sebastian Hack
//   POPL '21
//
// This implementation is generic in the sense that it does
// not itself identify original sources of divergence.
// Instead, specialized adapter classes (LoopDivergenceAnalysis for loops and
// DivergenceAnalysis for functions) identify the sources of divergence
// (e.g., special variables that hold the thread ID or the iteration variable).
//
// The generic implementation propagates divergence to variables that are data
// or sync dependent on a source of divergence.
//
// While data dependency is a well-known concept, the notion of sync dependency
// is worth more explanation. Sync dependence characterizes the control flow
// aspect of the propagation of branch divergence. For example,
//
//   %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
// The sync dependence detection (which branch induces divergence in which join
// points) is implemented in the SyncDependenceAnalysis.
//
// The current implementation has the following limitations:
// 1. It is intra-procedural. It conservatively considers the arguments of a
//    non-kernel-entry function and the return value of a function call as
//    divergent.
// 2. It treats memory as a black box. It conservatively considers values
//    loaded from generic or local addresses as divergent. This can be improved
//    by leveraging pointer analysis and/or by modelling non-escaping memory
//    objects in SSA as done in RV.
//
//===----------------------------------------------------------------------===//
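
// Usage sketch (illustrative, not authoritative): a client that already has
// the standard function analyses for F can construct the result directly and
// query it, mirroring what DivergenceAnalysis::run further below does. The
// variables DT, PDT, LI and TTI are assumed to be provided by the caller.
//
//   DivergenceInfo DI(F, DT, PDT, LI, TTI, /*KnownReducible=*/false);
//   if (DI.hasDivergence())
//     for (Instruction &I : instructions(F))
//       if (DI.isDivergent(I))
//         ; // e.g. refuse to jump-thread or unswitch across I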

#include "llvm/Analysis/DivergenceAnalysis.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/Analysis/CFG.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

#define DEBUG_TYPE "divergence"

DivergenceAnalysisImpl::DivergenceAnalysisImpl(
    const Function &F, const Loop *RegionLoop, const DominatorTree &DT,
    const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm)
    : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA),
      IsLCSSAForm(IsLCSSAForm) {}

bool DivergenceAnalysisImpl::markDivergent(const Value &DivVal) {
  if (isAlwaysUniform(DivVal))
    return false;
  assert(isa<Instruction>(DivVal) || isa<Argument>(DivVal));
  assert(!isAlwaysUniform(DivVal) && "cannot be divergent");
  return DivergentValues.insert(&DivVal).second;
}

void DivergenceAnalysisImpl::addUniformOverride(const Value &UniVal) {
  UniformOverrides.insert(&UniVal);
}

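// A note on temporal divergence (informal sketch): if a loop carrying a value
// has divergent exits, threads may leave the loop in different iterations, so
// a use of that value outside the loop can observe different per-thread copies
// even if the value is uniform within any single iteration. For example
// (hypothetical IR; %tid is assumed to be a divergent thread ID):
//
//   loop:
//     %i = phi i32 [ 0, %entry ], [ %i.next, %loop ]
//     %i.next = add i32 %i, 1
//     %done = icmp eq i32 %i, %tid      ; divergent exit condition
//     br i1 %done, label %exit, label %loop
//   exit:
//     ; threads reach this block in different iterations, so observing %i
//     ; here is temporally divergent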
bool DivergenceAnalysisImpl::isTemporalDivergent(
    const BasicBlock &ObservingBlock, const Value &Val) const {
  const auto *Inst = dyn_cast<const Instruction>(&Val);
  if (!Inst)
    return false;
  // check whether any divergent loop carrying Val terminates before control
  // proceeds to ObservingBlock
  for (const auto *Loop = LI.getLoopFor(Inst->getParent());
       Loop != RegionLoop && !Loop->contains(&ObservingBlock);
       Loop = Loop->getParentLoop()) {
    if (DivergentLoops.contains(Loop))
      return true;
  }

  return false;
}

bool DivergenceAnalysisImpl::inRegion(const Instruction &I) const {
  return I.getParent() && inRegion(*I.getParent());
}

bool DivergenceAnalysisImpl::inRegion(const BasicBlock &BB) const {
  return RegionLoop ? RegionLoop->contains(&BB) : (BB.getParent() == &F);
}

void DivergenceAnalysisImpl::pushUsers(const Value &V) {
  const auto *I = dyn_cast<const Instruction>(&V);

  if (I && I->isTerminator()) {
    analyzeControlDivergence(*I);
    return;
  }

  for (const auto *User : V.users()) {
    const auto *UserInst = dyn_cast<const Instruction>(User);
    if (!UserInst)
      continue;

    // only propagate divergence inside the region
    if (!inRegion(*UserInst))
      continue;

    // All users of divergent values are immediately divergent.
    if (markDivergent(*UserInst))
      Worklist.push_back(UserInst);
  }
}

static const Instruction *getIfCarriedInstruction(const Use &U,
                                                  const Loop &DivLoop) {
  const auto *I = dyn_cast<const Instruction>(&U);
  if (!I)
    return nullptr;
  if (!DivLoop.contains(I))
    return nullptr;
  return I;
}

void DivergenceAnalysisImpl::analyzeTemporalDivergence(
    const Instruction &I, const Loop &OuterDivLoop) {
  if (isAlwaysUniform(I))
    return;
  if (isDivergent(I))
    return;

  LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << I.getName() << "\n");
  assert((isa<PHINode>(I) || !IsLCSSAForm) &&
         "In LCSSA form all users of loop-exiting defs are Phi nodes.");
  for (const Use &Op : I.operands()) {
    const auto *OpInst = getIfCarriedInstruction(Op, OuterDivLoop);
    if (!OpInst)
      continue;
    if (markDivergent(I))
      pushUsers(I);
    return;
  }
}

// Marks all users of values carried by the loop OuterDivLoop as divergent.
void DivergenceAnalysisImpl::analyzeLoopExitDivergence(
    const BasicBlock &DivExit, const Loop &OuterDivLoop) {
  // All users are in immediate exit blocks
  if (IsLCSSAForm) {
    for (const auto &Phi : DivExit.phis()) {
      analyzeTemporalDivergence(Phi, OuterDivLoop);
    }
    return;
  }

  // For non-LCSSA we have to follow all live out edges wherever they may lead.
  const BasicBlock &LoopHeader = *OuterDivLoop.getHeader();
  SmallVector<const BasicBlock *, 8> TaintStack;
  TaintStack.push_back(&DivExit);

  // Otherwise potential users of loop-carried values could be anywhere in the
  // dominance region of OuterDivLoop (including its fringes for phi nodes).
  DenseSet<const BasicBlock *> Visited;
  Visited.insert(&DivExit);

  do {
    auto *UserBlock = TaintStack.pop_back_val();

    // don't spread divergence beyond the region
    if (!inRegion(*UserBlock))
      continue;

    assert(!OuterDivLoop.contains(UserBlock) &&
           "irreducible control flow detected");

    // phi nodes at the fringes of the dominance region
    if (!DT.dominates(&LoopHeader, UserBlock)) {
      // all PHI nodes of UserBlock become divergent
      for (const auto &Phi : UserBlock->phis()) {
        analyzeTemporalDivergence(Phi, OuterDivLoop);
      }
      continue;
    }

    // Taint outside users of values carried by OuterDivLoop.
    for (const auto &I : *UserBlock) {
      analyzeTemporalDivergence(I, OuterDivLoop);
    }

    // visit all blocks in the dominance region
    for (const auto *SuccBlock : successors(UserBlock)) {
      if (!Visited.insert(SuccBlock).second) {
        continue;
      }
      TaintStack.push_back(SuccBlock);
    }
  } while (!TaintStack.empty());
}

void DivergenceAnalysisImpl::propagateLoopExitDivergence(
    const BasicBlock &DivExit, const Loop &InnerDivLoop) {
  LLVM_DEBUG(dbgs() << "\tpropLoopExitDiv " << DivExit.getName() << "\n");

  // Find outer-most loop that does not contain \p DivExit
  const Loop *DivLoop = &InnerDivLoop;
  const Loop *OuterDivLoop = DivLoop;
  const Loop *ExitLevelLoop = LI.getLoopFor(&DivExit);
  const unsigned LoopExitDepth =
      ExitLevelLoop ? ExitLevelLoop->getLoopDepth() : 0;
  while (DivLoop && DivLoop->getLoopDepth() > LoopExitDepth) {
    DivergentLoops.insert(DivLoop); // all crossed loops are divergent
    OuterDivLoop = DivLoop;
    DivLoop = DivLoop->getParentLoop();
  }
  LLVM_DEBUG(dbgs() << "\tOuter-most left loop: " << OuterDivLoop->getName()
                    << "\n");

  analyzeLoopExitDivergence(DivExit, *OuterDivLoop);
}

// This is a divergent join point: mark all phi nodes as divergent and push
// them onto the worklist.
void DivergenceAnalysisImpl::taintAndPushPhiNodes(const BasicBlock &JoinBlock) {
  LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << JoinBlock.getName()
                    << "\n");

  // ignore divergence outside the region
  if (!inRegion(JoinBlock)) {
    return;
  }

  // push non-divergent phi nodes in JoinBlock to the worklist
  for (const auto &Phi : JoinBlock.phis()) {
    if (isDivergent(Phi))
      continue;
    // FIXME: Theoretically, the 'undef' value could be replaced by any other
    // value, causing spurious divergence.
    if (Phi.hasConstantOrUndefValue())
      continue;
    if (markDivergent(Phi))
      Worklist.push_back(&Phi);
  }
}

void DivergenceAnalysisImpl::analyzeControlDivergence(const Instruction &Term) {
  LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Term.getParent()->getName()
                    << "\n");

  // Don't propagate divergence from unreachable blocks.
  if (!DT.isReachableFromEntry(Term.getParent()))
    return;

  const auto *BranchLoop = LI.getLoopFor(Term.getParent());

  const auto &DivDesc = SDA.getJoinBlocks(Term);

  // Iterate over all blocks now reachable by a disjoint path join
  for (const auto *JoinBlock : DivDesc.JoinDivBlocks) {
    taintAndPushPhiNodes(*JoinBlock);
  }

  assert(DivDesc.LoopDivBlocks.empty() || BranchLoop);
  for (const auto *DivExitBlock : DivDesc.LoopDivBlocks) {
    propagateLoopExitDivergence(*DivExitBlock, *BranchLoop);
  }
}

void DivergenceAnalysisImpl::compute() {
  // Initialize worklist.
  auto DivValuesCopy = DivergentValues;
  for (const auto *DivVal : DivValuesCopy) {
    assert(isDivergent(*DivVal) && "Worklist invariant violated!");
    pushUsers(*DivVal);
  }

  // All values on the Worklist are divergent.
  // Their users may not have been updated yet.
  while (!Worklist.empty()) {
    const Instruction &I = *Worklist.back();
    Worklist.pop_back();

    // propagate value divergence to users
    assert(isDivergent(I) && "Worklist invariant violated!");
    pushUsers(I);
  }
}

bool DivergenceAnalysisImpl::isAlwaysUniform(const Value &V) const {
  return UniformOverrides.contains(&V);
}

bool DivergenceAnalysisImpl::isDivergent(const Value &V) const {
  return DivergentValues.contains(&V);
}

bool DivergenceAnalysisImpl::isDivergentUse(const Use &U) const {
  Value &V = *U.get();
  Instruction &I = *cast<Instruction>(U.getUser());
  return isDivergent(V) || isTemporalDivergent(*I.getParent(), V);
}

DivergenceInfo::DivergenceInfo(Function &F, const DominatorTree &DT,
                               const PostDominatorTree &PDT, const LoopInfo &LI,
                               const TargetTransformInfo &TTI,
                               bool KnownReducible)
    : F(F) {
  if (!KnownReducible) {
    using RPOTraversal = ReversePostOrderTraversal<const Function *>;
    RPOTraversal FuncRPOT(&F);
    if (containsIrreducibleCFG<const BasicBlock *, const RPOTraversal,
                               const LoopInfo>(FuncRPOT, LI)) {
      ContainsIrreducible = true;
      return;
    }
  }
  SDA = std::make_unique<SyncDependenceAnalysis>(DT, PDT, LI);
  DA = std::make_unique<DivergenceAnalysisImpl>(F, nullptr, DT, LI, *SDA,
                                                /* LCSSA */ false);
  for (auto &I : instructions(F)) {
    if (TTI.isSourceOfDivergence(&I)) {
      DA->markDivergent(I);
    } else if (TTI.isAlwaysUniform(&I)) {
      DA->addUniformOverride(I);
    }
  }
  for (auto &Arg : F.args()) {
    if (TTI.isSourceOfDivergence(&Arg)) {
      DA->markDivergent(Arg);
    }
  }

  DA->compute();
}

AnalysisKey DivergenceAnalysis::Key;

DivergenceAnalysis::Result
DivergenceAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
  auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F);
  auto &LI = AM.getResult<LoopAnalysis>(F);
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);

  return DivergenceInfo(F, DT, PDT, LI, TTI, /* KnownReducible = */ false);
}

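// For orientation, the printer below produces a report of roughly this shape
// (sketch for a hypothetical function 'foo' with a divergent argument %tid;
// the exact spacing comes from the string literals in the loops below):
//
//   'Divergence Analysis' for function 'foo':
//   DIVERGENT: i32 %tid
//              i32 %n
//
//              entry:
//   DIVERGENT:     %cond = icmp slt i32 %tid, %n
//                  br label %next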
PreservedAnalyses
DivergenceAnalysisPrinterPass::run(Function &F, FunctionAnalysisManager &FAM) {
  auto &DI = FAM.getResult<DivergenceAnalysis>(F);
  OS << "'Divergence Analysis' for function '" << F.getName() << "':\n";
  if (DI.hasDivergence()) {
    for (auto &Arg : F.args()) {
      OS << (DI.isDivergent(Arg) ? "DIVERGENT: " : "           ");
      OS << Arg << "\n";
    }
    for (const BasicBlock &BB : F) {
      OS << "\n           " << BB.getName() << ":\n";
      for (const auto &I : BB.instructionsWithoutDebug()) {
        OS << (DI.isDivergent(I) ? "DIVERGENT:     " : "               ");
        OS << I << "\n";
      }
    }
  }
  return PreservedAnalyses::all();
}