//===- UniformityAnalysis.cpp ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "llvm/Analysis/UniformityAnalysis.h"
10 #include "llvm/ADT/GenericUniformityImpl.h"
11 #include "llvm/Analysis/CycleAnalysis.h"
12 #include "llvm/Analysis/TargetTransformInfo.h"
13 #include "llvm/IR/Constants.h"
14 #include "llvm/IR/Dominators.h"
15 #include "llvm/IR/InstIterator.h"
16 #include "llvm/IR/Instructions.h"
17 #include "llvm/InitializePasses.h"
18
19 using namespace llvm;
20
21 template <>
hasDivergentDefs(const Instruction & I) const22 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
23 const Instruction &I) const {
24 return isDivergent((const Value *)&I);
25 }
26
27 template <>
markDefsDivergent(const Instruction & Instr,bool AllDefsDivergent)28 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
29 const Instruction &Instr, bool AllDefsDivergent) {
30 return markDivergent(&Instr);
31 }
32
initialize()33 template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
34 for (auto &I : instructions(F)) {
35 if (TTI->isSourceOfDivergence(&I)) {
36 assert(!I.isTerminator());
37 markDivergent(I);
38 } else if (TTI->isAlwaysUniform(&I)) {
39 addUniformOverride(I);
40 }
41 }
42 for (auto &Arg : F.args()) {
43 if (TTI->isSourceOfDivergence(&Arg)) {
44 markDivergent(&Arg);
45 }
46 }
47 }
48
49 template <>
pushUsers(const Value * V)50 void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
51 const Value *V) {
52 for (const auto *User : V->users()) {
53 const auto *UserInstr = dyn_cast<const Instruction>(User);
54 if (!UserInstr)
55 continue;
56 if (isAlwaysUniform(*UserInstr))
57 continue;
58 if (markDivergent(*UserInstr)) {
59 Worklist.push_back(UserInstr);
60 }
61 }
62 }
63
64 template <>
pushUsers(const Instruction & Instr)65 void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
66 const Instruction &Instr) {
67 assert(!isAlwaysUniform(Instr));
68 if (Instr.isTerminator())
69 return;
70 pushUsers(cast<Value>(&Instr));
71 }
72
73 template <>
usesValueFromCycle(const Instruction & I,const Cycle & DefCycle) const74 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
75 const Instruction &I, const Cycle &DefCycle) const {
76 if (isAlwaysUniform(I))
77 return false;
78 for (const Use &U : I.operands()) {
79 if (auto *I = dyn_cast<Instruction>(&U)) {
80 if (DefCycle.contains(I->getParent()))
81 return true;
82 }
83 }
84 return false;
85 }
86
// Explicitly instantiate the IR (SSAContext) specialization of the generic
// uniformity info in this translation unit. The second instantiation ensures
// that GenericUniformityAnalysisImpl::ImplDeleter::operator() is emitted here
// as well (presumably it is the deleter for the implementation object held by
// GenericUniformityInfo; confirm against GenericUniformityImpl.h).
template class llvm::GenericUniformityInfo<SSAContext>;
template struct llvm::GenericUniformityAnalysisImplDeleter<
    llvm::GenericUniformityAnalysisImpl<SSAContext>>;
92
93 //===----------------------------------------------------------------------===//
94 // UniformityInfoAnalysis and related pass implementations
95 //===----------------------------------------------------------------------===//
96
run(Function & F,FunctionAnalysisManager & FAM)97 llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
98 FunctionAnalysisManager &FAM) {
99 auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
100 auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
101 auto &CI = FAM.getResult<CycleAnalysis>(F);
102 return UniformityInfo{F, DT, CI, &TTI};
103 }
104
// Out-of-line definition of the key that identifies UniformityInfoAnalysis
// to the new pass manager.
AnalysisKey UniformityInfoAnalysis::Key;
106
UniformityInfoPrinterPass(raw_ostream & OS)107 UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS)
108 : OS(OS) {}
109
run(Function & F,FunctionAnalysisManager & AM)110 PreservedAnalyses UniformityInfoPrinterPass::run(Function &F,
111 FunctionAnalysisManager &AM) {
112 OS << "UniformityInfo for function '" << F.getName() << "':\n";
113 AM.getResult<UniformityInfoAnalysis>(F).print(OS);
114
115 return PreservedAnalyses::all();
116 }
117
118 //===----------------------------------------------------------------------===//
119 // UniformityInfoWrapperPass Implementation
120 //===----------------------------------------------------------------------===//
121
// Legacy pass identifier, handed to FunctionPass by the constructor below.
// (By legacy pass manager convention the address of this member, not its
// value, identifies the pass.)
char UniformityInfoWrapperPass::ID = 0;
123
UniformityInfoWrapperPass()124 UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {
125 initializeUniformityInfoWrapperPassPass(*PassRegistry::getPassRegistry());
126 }
127
128 INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniforminfo",
129 "Uniform Info Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)130 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
131 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
132 INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniforminfo",
133 "Uniform Info Analysis", true, true)
134
135 void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
136 AU.setPreservesAll();
137 AU.addRequired<DominatorTreeWrapperPass>();
138 AU.addRequired<CycleInfoWrapperPass>();
139 AU.addRequired<TargetTransformInfoWrapperPass>();
140 }
141
runOnFunction(Function & F)142 bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
143 auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
144 auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
145 auto &targetTransformInfo =
146 getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
147
148 m_function = &F;
149 m_uniformityInfo =
150 UniformityInfo{F, domTree, cycleInfo, &targetTransformInfo};
151 return false;
152 }
153
print(raw_ostream & OS,const Module *) const154 void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const {
155 OS << "UniformityInfo for function '" << m_function->getName() << "':\n";
156 }
157
releaseMemory()158 void UniformityInfoWrapperPass::releaseMemory() {
159 m_uniformityInfo = UniformityInfo{};
160 m_function = nullptr;
161 }
162