1*9880d681SAndroid Build Coastguard Worker //===- NVVMReflect.cpp - NVVM Emulate conditional compilation -------------===//
2*9880d681SAndroid Build Coastguard Worker //
3*9880d681SAndroid Build Coastguard Worker // The LLVM Compiler Infrastructure
4*9880d681SAndroid Build Coastguard Worker //
5*9880d681SAndroid Build Coastguard Worker // This file is distributed under the University of Illinois Open Source
6*9880d681SAndroid Build Coastguard Worker // License. See LICENSE.TXT for details.
7*9880d681SAndroid Build Coastguard Worker //
8*9880d681SAndroid Build Coastguard Worker //===----------------------------------------------------------------------===//
9*9880d681SAndroid Build Coastguard Worker //
10*9880d681SAndroid Build Coastguard Worker // This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect
11*9880d681SAndroid Build Coastguard Worker // with an integer.
12*9880d681SAndroid Build Coastguard Worker //
13*9880d681SAndroid Build Coastguard Worker // We choose the value we use by looking, in this order, at:
14*9880d681SAndroid Build Coastguard Worker //
15*9880d681SAndroid Build Coastguard Worker // * the -nvvm-reflect-list flag, which has the format "foo=1,bar=42",
16*9880d681SAndroid Build Coastguard Worker // * the StringMap passed to the pass's constructor, and
17*9880d681SAndroid Build Coastguard Worker // * metadata in the module itself.
18*9880d681SAndroid Build Coastguard Worker //
19*9880d681SAndroid Build Coastguard Worker // If we see an unknown string, we replace its call with 0.
20*9880d681SAndroid Build Coastguard Worker //
21*9880d681SAndroid Build Coastguard Worker //===----------------------------------------------------------------------===//
22*9880d681SAndroid Build Coastguard Worker
#include "NVPTX.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include <sstream>
#include <string>
// Name of the reflection function whose calls this pass folds into integer
// constants (the llvm.nvvm.reflect intrinsic is matched separately by ID).
#define NVVM_REFLECT_FUNCTION "__nvvm_reflect"

using namespace llvm;

// Debug tag for the DEBUG(...) output in this file.
#define DEBUG_TYPE "nvptx-reflect"

// Forward declaration so the pass constructor can register itself with the
// PassRegistry; the definition is produced by INITIALIZE_PASS further down.
namespace llvm { void initializeNVVMReflectPass(PassRegistry &); }
49*9880d681SAndroid Build Coastguard Worker
50*9880d681SAndroid Build Coastguard Worker namespace {
51*9880d681SAndroid Build Coastguard Worker class NVVMReflect : public FunctionPass {
52*9880d681SAndroid Build Coastguard Worker private:
53*9880d681SAndroid Build Coastguard Worker StringMap<int> VarMap;
54*9880d681SAndroid Build Coastguard Worker
55*9880d681SAndroid Build Coastguard Worker public:
56*9880d681SAndroid Build Coastguard Worker static char ID;
NVVMReflect()57*9880d681SAndroid Build Coastguard Worker NVVMReflect() : NVVMReflect(StringMap<int>()) {}
58*9880d681SAndroid Build Coastguard Worker
NVVMReflect(const StringMap<int> & Mapping)59*9880d681SAndroid Build Coastguard Worker NVVMReflect(const StringMap<int> &Mapping)
60*9880d681SAndroid Build Coastguard Worker : FunctionPass(ID), VarMap(Mapping) {
61*9880d681SAndroid Build Coastguard Worker initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
62*9880d681SAndroid Build Coastguard Worker setVarMap();
63*9880d681SAndroid Build Coastguard Worker }
64*9880d681SAndroid Build Coastguard Worker
65*9880d681SAndroid Build Coastguard Worker bool runOnFunction(Function &) override;
66*9880d681SAndroid Build Coastguard Worker
67*9880d681SAndroid Build Coastguard Worker private:
68*9880d681SAndroid Build Coastguard Worker bool handleFunction(Function *ReflectFunction);
69*9880d681SAndroid Build Coastguard Worker void setVarMap();
70*9880d681SAndroid Build Coastguard Worker };
71*9880d681SAndroid Build Coastguard Worker }
72*9880d681SAndroid Build Coastguard Worker
createNVVMReflectPass()73*9880d681SAndroid Build Coastguard Worker FunctionPass *llvm::createNVVMReflectPass() { return new NVVMReflect(); }
createNVVMReflectPass(const StringMap<int> & Mapping)74*9880d681SAndroid Build Coastguard Worker FunctionPass *llvm::createNVVMReflectPass(const StringMap<int> &Mapping) {
75*9880d681SAndroid Build Coastguard Worker return new NVVMReflect(Mapping);
76*9880d681SAndroid Build Coastguard Worker }
77*9880d681SAndroid Build Coastguard Worker
78*9880d681SAndroid Build Coastguard Worker static cl::opt<bool>
79*9880d681SAndroid Build Coastguard Worker NVVMReflectEnabled("nvvm-reflect-enable", cl::init(true), cl::Hidden,
80*9880d681SAndroid Build Coastguard Worker cl::desc("NVVM reflection, enabled by default"));
81*9880d681SAndroid Build Coastguard Worker
// The address of ID identifies the pass; its value is irrelevant.
char NVVMReflect::ID = 0;
// Register the pass as "nvvm-reflect" (not CFG-only, not an analysis).
// NOTE(review): the description says "0/1", but runOnFunction substitutes any
// integer found in VarMap or in the nvvm-reflect-ftz module flag.
INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
                "Replace occurrences of __nvvm_reflect() calls with 0/1", false,
                false)
86*9880d681SAndroid Build Coastguard Worker
87*9880d681SAndroid Build Coastguard Worker static cl::list<std::string>
88*9880d681SAndroid Build Coastguard Worker ReflectList("nvvm-reflect-list", cl::value_desc("name=<int>"), cl::Hidden,
89*9880d681SAndroid Build Coastguard Worker cl::desc("A list of string=num assignments"),
90*9880d681SAndroid Build Coastguard Worker cl::ValueRequired);
91*9880d681SAndroid Build Coastguard Worker
92*9880d681SAndroid Build Coastguard Worker /// The command line can look as follows :
93*9880d681SAndroid Build Coastguard Worker /// -nvvm-reflect-list a=1,b=2 -nvvm-reflect-list c=3,d=0 -R e=2
94*9880d681SAndroid Build Coastguard Worker /// The strings "a=1,b=2", "c=3,d=0", "e=2" are available in the
95*9880d681SAndroid Build Coastguard Worker /// ReflectList vector. First, each of ReflectList[i] is 'split'
96*9880d681SAndroid Build Coastguard Worker /// using "," as the delimiter. Then each of this part is split
97*9880d681SAndroid Build Coastguard Worker /// using "=" as the delimiter.
setVarMap()98*9880d681SAndroid Build Coastguard Worker void NVVMReflect::setVarMap() {
99*9880d681SAndroid Build Coastguard Worker for (unsigned i = 0, e = ReflectList.size(); i != e; ++i) {
100*9880d681SAndroid Build Coastguard Worker DEBUG(dbgs() << "Option : " << ReflectList[i] << "\n");
101*9880d681SAndroid Build Coastguard Worker SmallVector<StringRef, 4> NameValList;
102*9880d681SAndroid Build Coastguard Worker StringRef(ReflectList[i]).split(NameValList, ',');
103*9880d681SAndroid Build Coastguard Worker for (unsigned j = 0, ej = NameValList.size(); j != ej; ++j) {
104*9880d681SAndroid Build Coastguard Worker SmallVector<StringRef, 2> NameValPair;
105*9880d681SAndroid Build Coastguard Worker NameValList[j].split(NameValPair, '=');
106*9880d681SAndroid Build Coastguard Worker assert(NameValPair.size() == 2 && "name=val expected");
107*9880d681SAndroid Build Coastguard Worker std::stringstream ValStream(NameValPair[1]);
108*9880d681SAndroid Build Coastguard Worker int Val;
109*9880d681SAndroid Build Coastguard Worker ValStream >> Val;
110*9880d681SAndroid Build Coastguard Worker assert((!(ValStream.fail())) && "integer value expected");
111*9880d681SAndroid Build Coastguard Worker VarMap[NameValPair[0]] = Val;
112*9880d681SAndroid Build Coastguard Worker }
113*9880d681SAndroid Build Coastguard Worker }
114*9880d681SAndroid Build Coastguard Worker }
115*9880d681SAndroid Build Coastguard Worker
runOnFunction(Function & F)116*9880d681SAndroid Build Coastguard Worker bool NVVMReflect::runOnFunction(Function &F) {
117*9880d681SAndroid Build Coastguard Worker if (!NVVMReflectEnabled)
118*9880d681SAndroid Build Coastguard Worker return false;
119*9880d681SAndroid Build Coastguard Worker
120*9880d681SAndroid Build Coastguard Worker if (F.getName() == NVVM_REFLECT_FUNCTION) {
121*9880d681SAndroid Build Coastguard Worker assert(F.isDeclaration() && "_reflect function should not have a body");
122*9880d681SAndroid Build Coastguard Worker assert(F.getReturnType()->isIntegerTy() &&
123*9880d681SAndroid Build Coastguard Worker "_reflect's return type should be integer");
124*9880d681SAndroid Build Coastguard Worker return false;
125*9880d681SAndroid Build Coastguard Worker }
126*9880d681SAndroid Build Coastguard Worker
127*9880d681SAndroid Build Coastguard Worker SmallVector<Instruction *, 4> ToRemove;
128*9880d681SAndroid Build Coastguard Worker
129*9880d681SAndroid Build Coastguard Worker // Go through the calls in this function. Each call to __nvvm_reflect or
130*9880d681SAndroid Build Coastguard Worker // llvm.nvvm.reflect should be a CallInst with a ConstantArray argument.
131*9880d681SAndroid Build Coastguard Worker // First validate that. If the c-string corresponding to the ConstantArray can
132*9880d681SAndroid Build Coastguard Worker // be found successfully, see if it can be found in VarMap. If so, replace the
133*9880d681SAndroid Build Coastguard Worker // uses of CallInst with the value found in VarMap. If not, replace the use
134*9880d681SAndroid Build Coastguard Worker // with value 0.
135*9880d681SAndroid Build Coastguard Worker
136*9880d681SAndroid Build Coastguard Worker // The IR for __nvvm_reflect calls differs between CUDA versions.
137*9880d681SAndroid Build Coastguard Worker //
138*9880d681SAndroid Build Coastguard Worker // CUDA 6.5 and earlier uses this sequence:
139*9880d681SAndroid Build Coastguard Worker // %ptr = tail call i8* @llvm.nvvm.ptr.constant.to.gen.p0i8.p4i8
140*9880d681SAndroid Build Coastguard Worker // (i8 addrspace(4)* getelementptr inbounds
141*9880d681SAndroid Build Coastguard Worker // ([8 x i8], [8 x i8] addrspace(4)* @str, i32 0, i32 0))
142*9880d681SAndroid Build Coastguard Worker // %reflect = tail call i32 @__nvvm_reflect(i8* %ptr)
143*9880d681SAndroid Build Coastguard Worker //
144*9880d681SAndroid Build Coastguard Worker // The value returned by Sym->getOperand(0) is a Constant with a
145*9880d681SAndroid Build Coastguard Worker // ConstantDataSequential operand which can be converted to string and used
146*9880d681SAndroid Build Coastguard Worker // for lookup.
147*9880d681SAndroid Build Coastguard Worker //
148*9880d681SAndroid Build Coastguard Worker // CUDA 7.0 does it slightly differently:
149*9880d681SAndroid Build Coastguard Worker // %reflect = call i32 @__nvvm_reflect(i8* addrspacecast
150*9880d681SAndroid Build Coastguard Worker // (i8 addrspace(1)* getelementptr inbounds
151*9880d681SAndroid Build Coastguard Worker // ([8 x i8], [8 x i8] addrspace(1)* @str, i32 0, i32 0) to i8*))
152*9880d681SAndroid Build Coastguard Worker //
153*9880d681SAndroid Build Coastguard Worker // In this case, we get a Constant with a GlobalVariable operand and we need
154*9880d681SAndroid Build Coastguard Worker // to dig deeper to find its initializer with the string we'll use for lookup.
155*9880d681SAndroid Build Coastguard Worker for (Instruction &I : instructions(F)) {
156*9880d681SAndroid Build Coastguard Worker CallInst *Call = dyn_cast<CallInst>(&I);
157*9880d681SAndroid Build Coastguard Worker if (!Call)
158*9880d681SAndroid Build Coastguard Worker continue;
159*9880d681SAndroid Build Coastguard Worker Function *Callee = Call->getCalledFunction();
160*9880d681SAndroid Build Coastguard Worker if (!Callee || (Callee->getName() != NVVM_REFLECT_FUNCTION &&
161*9880d681SAndroid Build Coastguard Worker Callee->getIntrinsicID() != Intrinsic::nvvm_reflect))
162*9880d681SAndroid Build Coastguard Worker continue;
163*9880d681SAndroid Build Coastguard Worker
164*9880d681SAndroid Build Coastguard Worker // FIXME: Improve error handling here and elsewhere in this pass.
165*9880d681SAndroid Build Coastguard Worker assert(Call->getNumOperands() == 2 &&
166*9880d681SAndroid Build Coastguard Worker "Wrong number of operands to __nvvm_reflect function");
167*9880d681SAndroid Build Coastguard Worker
168*9880d681SAndroid Build Coastguard Worker // In cuda 6.5 and earlier, we will have an extra constant-to-generic
169*9880d681SAndroid Build Coastguard Worker // conversion of the string.
170*9880d681SAndroid Build Coastguard Worker const Value *Str = Call->getArgOperand(0);
171*9880d681SAndroid Build Coastguard Worker if (const CallInst *ConvCall = dyn_cast<CallInst>(Str)) {
172*9880d681SAndroid Build Coastguard Worker // FIXME: Add assertions about ConvCall.
173*9880d681SAndroid Build Coastguard Worker Str = ConvCall->getArgOperand(0);
174*9880d681SAndroid Build Coastguard Worker }
175*9880d681SAndroid Build Coastguard Worker assert(isa<ConstantExpr>(Str) &&
176*9880d681SAndroid Build Coastguard Worker "Format of __nvvm__reflect function not recognized");
177*9880d681SAndroid Build Coastguard Worker const ConstantExpr *GEP = cast<ConstantExpr>(Str);
178*9880d681SAndroid Build Coastguard Worker
179*9880d681SAndroid Build Coastguard Worker const Value *Sym = GEP->getOperand(0);
180*9880d681SAndroid Build Coastguard Worker assert(isa<Constant>(Sym) &&
181*9880d681SAndroid Build Coastguard Worker "Format of __nvvm_reflect function not recognized");
182*9880d681SAndroid Build Coastguard Worker
183*9880d681SAndroid Build Coastguard Worker const Value *Operand = cast<Constant>(Sym)->getOperand(0);
184*9880d681SAndroid Build Coastguard Worker if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Operand)) {
185*9880d681SAndroid Build Coastguard Worker // For CUDA-7.0 style __nvvm_reflect calls, we need to find the operand's
186*9880d681SAndroid Build Coastguard Worker // initializer.
187*9880d681SAndroid Build Coastguard Worker assert(GV->hasInitializer() &&
188*9880d681SAndroid Build Coastguard Worker "Format of _reflect function not recognized");
189*9880d681SAndroid Build Coastguard Worker const Constant *Initializer = GV->getInitializer();
190*9880d681SAndroid Build Coastguard Worker Operand = Initializer;
191*9880d681SAndroid Build Coastguard Worker }
192*9880d681SAndroid Build Coastguard Worker
193*9880d681SAndroid Build Coastguard Worker assert(isa<ConstantDataSequential>(Operand) &&
194*9880d681SAndroid Build Coastguard Worker "Format of _reflect function not recognized");
195*9880d681SAndroid Build Coastguard Worker assert(cast<ConstantDataSequential>(Operand)->isCString() &&
196*9880d681SAndroid Build Coastguard Worker "Format of _reflect function not recognized");
197*9880d681SAndroid Build Coastguard Worker
198*9880d681SAndroid Build Coastguard Worker StringRef ReflectArg = cast<ConstantDataSequential>(Operand)->getAsString();
199*9880d681SAndroid Build Coastguard Worker ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1);
200*9880d681SAndroid Build Coastguard Worker DEBUG(dbgs() << "Arg of _reflect : " << ReflectArg << "\n");
201*9880d681SAndroid Build Coastguard Worker
202*9880d681SAndroid Build Coastguard Worker int ReflectVal = 0; // The default value is 0
203*9880d681SAndroid Build Coastguard Worker auto Iter = VarMap.find(ReflectArg);
204*9880d681SAndroid Build Coastguard Worker if (Iter != VarMap.end())
205*9880d681SAndroid Build Coastguard Worker ReflectVal = Iter->second;
206*9880d681SAndroid Build Coastguard Worker else if (ReflectArg == "__CUDA_FTZ") {
207*9880d681SAndroid Build Coastguard Worker // Try to pull __CUDA_FTZ from the nvvm-reflect-ftz module flag.
208*9880d681SAndroid Build Coastguard Worker if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
209*9880d681SAndroid Build Coastguard Worker F.getParent()->getModuleFlag("nvvm-reflect-ftz")))
210*9880d681SAndroid Build Coastguard Worker ReflectVal = Flag->getSExtValue();
211*9880d681SAndroid Build Coastguard Worker }
212*9880d681SAndroid Build Coastguard Worker Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
213*9880d681SAndroid Build Coastguard Worker ToRemove.push_back(Call);
214*9880d681SAndroid Build Coastguard Worker }
215*9880d681SAndroid Build Coastguard Worker
216*9880d681SAndroid Build Coastguard Worker for (Instruction *I : ToRemove)
217*9880d681SAndroid Build Coastguard Worker I->eraseFromParent();
218*9880d681SAndroid Build Coastguard Worker
219*9880d681SAndroid Build Coastguard Worker return ToRemove.size() > 0;
220*9880d681SAndroid Build Coastguard Worker }
221