xref: /aosp_15_r20/external/llvm/lib/Target/NVPTX/NVVMReflect.cpp (revision 9880d6810fe72a1726cb53787c6711e909410d58)
1*9880d681SAndroid Build Coastguard Worker //===- NVVMReflect.cpp - NVVM Emulate conditional compilation -------------===//
2*9880d681SAndroid Build Coastguard Worker //
3*9880d681SAndroid Build Coastguard Worker //                     The LLVM Compiler Infrastructure
4*9880d681SAndroid Build Coastguard Worker //
5*9880d681SAndroid Build Coastguard Worker // This file is distributed under the University of Illinois Open Source
6*9880d681SAndroid Build Coastguard Worker // License. See LICENSE.TXT for details.
7*9880d681SAndroid Build Coastguard Worker //
8*9880d681SAndroid Build Coastguard Worker //===----------------------------------------------------------------------===//
9*9880d681SAndroid Build Coastguard Worker //
10*9880d681SAndroid Build Coastguard Worker // This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect
11*9880d681SAndroid Build Coastguard Worker // with an integer.
12*9880d681SAndroid Build Coastguard Worker //
13*9880d681SAndroid Build Coastguard Worker // We choose the value we use by looking, in this order, at:
14*9880d681SAndroid Build Coastguard Worker //
15*9880d681SAndroid Build Coastguard Worker //  * the -nvvm-reflect-list flag, which has the format "foo=1,bar=42",
16*9880d681SAndroid Build Coastguard Worker //  * the StringMap passed to the pass's constructor, and
//  * the "nvvm-reflect-ftz" module flag, consulted only for __CUDA_FTZ.
18*9880d681SAndroid Build Coastguard Worker //
19*9880d681SAndroid Build Coastguard Worker // If we see an unknown string, we replace its call with 0.
20*9880d681SAndroid Build Coastguard Worker //
21*9880d681SAndroid Build Coastguard Worker //===----------------------------------------------------------------------===//
22*9880d681SAndroid Build Coastguard Worker 
#include "NVPTX.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/Twine.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include <sstream>
#include <string>
42*9880d681SAndroid Build Coastguard Worker #define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
43*9880d681SAndroid Build Coastguard Worker 
44*9880d681SAndroid Build Coastguard Worker using namespace llvm;
45*9880d681SAndroid Build Coastguard Worker 
46*9880d681SAndroid Build Coastguard Worker #define DEBUG_TYPE "nvptx-reflect"
47*9880d681SAndroid Build Coastguard Worker 
48*9880d681SAndroid Build Coastguard Worker namespace llvm { void initializeNVVMReflectPass(PassRegistry &); }
49*9880d681SAndroid Build Coastguard Worker 
50*9880d681SAndroid Build Coastguard Worker namespace {
51*9880d681SAndroid Build Coastguard Worker class NVVMReflect : public FunctionPass {
52*9880d681SAndroid Build Coastguard Worker private:
53*9880d681SAndroid Build Coastguard Worker   StringMap<int> VarMap;
54*9880d681SAndroid Build Coastguard Worker 
55*9880d681SAndroid Build Coastguard Worker public:
56*9880d681SAndroid Build Coastguard Worker   static char ID;
NVVMReflect()57*9880d681SAndroid Build Coastguard Worker   NVVMReflect() : NVVMReflect(StringMap<int>()) {}
58*9880d681SAndroid Build Coastguard Worker 
NVVMReflect(const StringMap<int> & Mapping)59*9880d681SAndroid Build Coastguard Worker   NVVMReflect(const StringMap<int> &Mapping)
60*9880d681SAndroid Build Coastguard Worker       : FunctionPass(ID), VarMap(Mapping) {
61*9880d681SAndroid Build Coastguard Worker     initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
62*9880d681SAndroid Build Coastguard Worker     setVarMap();
63*9880d681SAndroid Build Coastguard Worker   }
64*9880d681SAndroid Build Coastguard Worker 
65*9880d681SAndroid Build Coastguard Worker   bool runOnFunction(Function &) override;
66*9880d681SAndroid Build Coastguard Worker 
67*9880d681SAndroid Build Coastguard Worker private:
68*9880d681SAndroid Build Coastguard Worker   bool handleFunction(Function *ReflectFunction);
69*9880d681SAndroid Build Coastguard Worker   void setVarMap();
70*9880d681SAndroid Build Coastguard Worker };
71*9880d681SAndroid Build Coastguard Worker }
72*9880d681SAndroid Build Coastguard Worker 
createNVVMReflectPass()73*9880d681SAndroid Build Coastguard Worker FunctionPass *llvm::createNVVMReflectPass() { return new NVVMReflect(); }
createNVVMReflectPass(const StringMap<int> & Mapping)74*9880d681SAndroid Build Coastguard Worker FunctionPass *llvm::createNVVMReflectPass(const StringMap<int> &Mapping) {
75*9880d681SAndroid Build Coastguard Worker   return new NVVMReflect(Mapping);
76*9880d681SAndroid Build Coastguard Worker }
77*9880d681SAndroid Build Coastguard Worker 
78*9880d681SAndroid Build Coastguard Worker static cl::opt<bool>
79*9880d681SAndroid Build Coastguard Worker NVVMReflectEnabled("nvvm-reflect-enable", cl::init(true), cl::Hidden,
80*9880d681SAndroid Build Coastguard Worker                    cl::desc("NVVM reflection, enabled by default"));
81*9880d681SAndroid Build Coastguard Worker 
82*9880d681SAndroid Build Coastguard Worker char NVVMReflect::ID = 0;
83*9880d681SAndroid Build Coastguard Worker INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
84*9880d681SAndroid Build Coastguard Worker                 "Replace occurrences of __nvvm_reflect() calls with 0/1", false,
85*9880d681SAndroid Build Coastguard Worker                 false)
86*9880d681SAndroid Build Coastguard Worker 
87*9880d681SAndroid Build Coastguard Worker static cl::list<std::string>
88*9880d681SAndroid Build Coastguard Worker ReflectList("nvvm-reflect-list", cl::value_desc("name=<int>"), cl::Hidden,
89*9880d681SAndroid Build Coastguard Worker             cl::desc("A list of string=num assignments"),
90*9880d681SAndroid Build Coastguard Worker             cl::ValueRequired);
91*9880d681SAndroid Build Coastguard Worker 
92*9880d681SAndroid Build Coastguard Worker /// The command line can look as follows :
93*9880d681SAndroid Build Coastguard Worker /// -nvvm-reflect-list a=1,b=2 -nvvm-reflect-list c=3,d=0 -R e=2
94*9880d681SAndroid Build Coastguard Worker /// The strings "a=1,b=2", "c=3,d=0", "e=2" are available in the
95*9880d681SAndroid Build Coastguard Worker /// ReflectList vector. First, each of ReflectList[i] is 'split'
96*9880d681SAndroid Build Coastguard Worker /// using "," as the delimiter. Then each of this part is split
97*9880d681SAndroid Build Coastguard Worker /// using "=" as the delimiter.
setVarMap()98*9880d681SAndroid Build Coastguard Worker void NVVMReflect::setVarMap() {
99*9880d681SAndroid Build Coastguard Worker   for (unsigned i = 0, e = ReflectList.size(); i != e; ++i) {
100*9880d681SAndroid Build Coastguard Worker     DEBUG(dbgs() << "Option : "  << ReflectList[i] << "\n");
101*9880d681SAndroid Build Coastguard Worker     SmallVector<StringRef, 4> NameValList;
102*9880d681SAndroid Build Coastguard Worker     StringRef(ReflectList[i]).split(NameValList, ',');
103*9880d681SAndroid Build Coastguard Worker     for (unsigned j = 0, ej = NameValList.size(); j != ej; ++j) {
104*9880d681SAndroid Build Coastguard Worker       SmallVector<StringRef, 2> NameValPair;
105*9880d681SAndroid Build Coastguard Worker       NameValList[j].split(NameValPair, '=');
106*9880d681SAndroid Build Coastguard Worker       assert(NameValPair.size() == 2 && "name=val expected");
107*9880d681SAndroid Build Coastguard Worker       std::stringstream ValStream(NameValPair[1]);
108*9880d681SAndroid Build Coastguard Worker       int Val;
109*9880d681SAndroid Build Coastguard Worker       ValStream >> Val;
110*9880d681SAndroid Build Coastguard Worker       assert((!(ValStream.fail())) && "integer value expected");
111*9880d681SAndroid Build Coastguard Worker       VarMap[NameValPair[0]] = Val;
112*9880d681SAndroid Build Coastguard Worker     }
113*9880d681SAndroid Build Coastguard Worker   }
114*9880d681SAndroid Build Coastguard Worker }
115*9880d681SAndroid Build Coastguard Worker 
runOnFunction(Function & F)116*9880d681SAndroid Build Coastguard Worker bool NVVMReflect::runOnFunction(Function &F) {
117*9880d681SAndroid Build Coastguard Worker   if (!NVVMReflectEnabled)
118*9880d681SAndroid Build Coastguard Worker     return false;
119*9880d681SAndroid Build Coastguard Worker 
120*9880d681SAndroid Build Coastguard Worker   if (F.getName() == NVVM_REFLECT_FUNCTION) {
121*9880d681SAndroid Build Coastguard Worker     assert(F.isDeclaration() && "_reflect function should not have a body");
122*9880d681SAndroid Build Coastguard Worker     assert(F.getReturnType()->isIntegerTy() &&
123*9880d681SAndroid Build Coastguard Worker            "_reflect's return type should be integer");
124*9880d681SAndroid Build Coastguard Worker     return false;
125*9880d681SAndroid Build Coastguard Worker   }
126*9880d681SAndroid Build Coastguard Worker 
127*9880d681SAndroid Build Coastguard Worker   SmallVector<Instruction *, 4> ToRemove;
128*9880d681SAndroid Build Coastguard Worker 
129*9880d681SAndroid Build Coastguard Worker   // Go through the calls in this function.  Each call to __nvvm_reflect or
130*9880d681SAndroid Build Coastguard Worker   // llvm.nvvm.reflect should be a CallInst with a ConstantArray argument.
131*9880d681SAndroid Build Coastguard Worker   // First validate that. If the c-string corresponding to the ConstantArray can
132*9880d681SAndroid Build Coastguard Worker   // be found successfully, see if it can be found in VarMap. If so, replace the
133*9880d681SAndroid Build Coastguard Worker   // uses of CallInst with the value found in VarMap. If not, replace the use
134*9880d681SAndroid Build Coastguard Worker   // with value 0.
135*9880d681SAndroid Build Coastguard Worker 
136*9880d681SAndroid Build Coastguard Worker   // The IR for __nvvm_reflect calls differs between CUDA versions.
137*9880d681SAndroid Build Coastguard Worker   //
138*9880d681SAndroid Build Coastguard Worker   // CUDA 6.5 and earlier uses this sequence:
139*9880d681SAndroid Build Coastguard Worker   //    %ptr = tail call i8* @llvm.nvvm.ptr.constant.to.gen.p0i8.p4i8
140*9880d681SAndroid Build Coastguard Worker   //        (i8 addrspace(4)* getelementptr inbounds
141*9880d681SAndroid Build Coastguard Worker   //           ([8 x i8], [8 x i8] addrspace(4)* @str, i32 0, i32 0))
142*9880d681SAndroid Build Coastguard Worker   //    %reflect = tail call i32 @__nvvm_reflect(i8* %ptr)
143*9880d681SAndroid Build Coastguard Worker   //
144*9880d681SAndroid Build Coastguard Worker   // The value returned by Sym->getOperand(0) is a Constant with a
145*9880d681SAndroid Build Coastguard Worker   // ConstantDataSequential operand which can be converted to string and used
146*9880d681SAndroid Build Coastguard Worker   // for lookup.
147*9880d681SAndroid Build Coastguard Worker   //
148*9880d681SAndroid Build Coastguard Worker   // CUDA 7.0 does it slightly differently:
149*9880d681SAndroid Build Coastguard Worker   //   %reflect = call i32 @__nvvm_reflect(i8* addrspacecast
150*9880d681SAndroid Build Coastguard Worker   //        (i8 addrspace(1)* getelementptr inbounds
151*9880d681SAndroid Build Coastguard Worker   //           ([8 x i8], [8 x i8] addrspace(1)* @str, i32 0, i32 0) to i8*))
152*9880d681SAndroid Build Coastguard Worker   //
153*9880d681SAndroid Build Coastguard Worker   // In this case, we get a Constant with a GlobalVariable operand and we need
154*9880d681SAndroid Build Coastguard Worker   // to dig deeper to find its initializer with the string we'll use for lookup.
155*9880d681SAndroid Build Coastguard Worker   for (Instruction &I : instructions(F)) {
156*9880d681SAndroid Build Coastguard Worker     CallInst *Call = dyn_cast<CallInst>(&I);
157*9880d681SAndroid Build Coastguard Worker     if (!Call)
158*9880d681SAndroid Build Coastguard Worker       continue;
159*9880d681SAndroid Build Coastguard Worker     Function *Callee = Call->getCalledFunction();
160*9880d681SAndroid Build Coastguard Worker     if (!Callee || (Callee->getName() != NVVM_REFLECT_FUNCTION &&
161*9880d681SAndroid Build Coastguard Worker                     Callee->getIntrinsicID() != Intrinsic::nvvm_reflect))
162*9880d681SAndroid Build Coastguard Worker       continue;
163*9880d681SAndroid Build Coastguard Worker 
164*9880d681SAndroid Build Coastguard Worker     // FIXME: Improve error handling here and elsewhere in this pass.
165*9880d681SAndroid Build Coastguard Worker     assert(Call->getNumOperands() == 2 &&
166*9880d681SAndroid Build Coastguard Worker            "Wrong number of operands to __nvvm_reflect function");
167*9880d681SAndroid Build Coastguard Worker 
168*9880d681SAndroid Build Coastguard Worker     // In cuda 6.5 and earlier, we will have an extra constant-to-generic
169*9880d681SAndroid Build Coastguard Worker     // conversion of the string.
170*9880d681SAndroid Build Coastguard Worker     const Value *Str = Call->getArgOperand(0);
171*9880d681SAndroid Build Coastguard Worker     if (const CallInst *ConvCall = dyn_cast<CallInst>(Str)) {
172*9880d681SAndroid Build Coastguard Worker       // FIXME: Add assertions about ConvCall.
173*9880d681SAndroid Build Coastguard Worker       Str = ConvCall->getArgOperand(0);
174*9880d681SAndroid Build Coastguard Worker     }
175*9880d681SAndroid Build Coastguard Worker     assert(isa<ConstantExpr>(Str) &&
176*9880d681SAndroid Build Coastguard Worker            "Format of __nvvm__reflect function not recognized");
177*9880d681SAndroid Build Coastguard Worker     const ConstantExpr *GEP = cast<ConstantExpr>(Str);
178*9880d681SAndroid Build Coastguard Worker 
179*9880d681SAndroid Build Coastguard Worker     const Value *Sym = GEP->getOperand(0);
180*9880d681SAndroid Build Coastguard Worker     assert(isa<Constant>(Sym) &&
181*9880d681SAndroid Build Coastguard Worker            "Format of __nvvm_reflect function not recognized");
182*9880d681SAndroid Build Coastguard Worker 
183*9880d681SAndroid Build Coastguard Worker     const Value *Operand = cast<Constant>(Sym)->getOperand(0);
184*9880d681SAndroid Build Coastguard Worker     if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Operand)) {
185*9880d681SAndroid Build Coastguard Worker       // For CUDA-7.0 style __nvvm_reflect calls, we need to find the operand's
186*9880d681SAndroid Build Coastguard Worker       // initializer.
187*9880d681SAndroid Build Coastguard Worker       assert(GV->hasInitializer() &&
188*9880d681SAndroid Build Coastguard Worker              "Format of _reflect function not recognized");
189*9880d681SAndroid Build Coastguard Worker       const Constant *Initializer = GV->getInitializer();
190*9880d681SAndroid Build Coastguard Worker       Operand = Initializer;
191*9880d681SAndroid Build Coastguard Worker     }
192*9880d681SAndroid Build Coastguard Worker 
193*9880d681SAndroid Build Coastguard Worker     assert(isa<ConstantDataSequential>(Operand) &&
194*9880d681SAndroid Build Coastguard Worker            "Format of _reflect function not recognized");
195*9880d681SAndroid Build Coastguard Worker     assert(cast<ConstantDataSequential>(Operand)->isCString() &&
196*9880d681SAndroid Build Coastguard Worker            "Format of _reflect function not recognized");
197*9880d681SAndroid Build Coastguard Worker 
198*9880d681SAndroid Build Coastguard Worker     StringRef ReflectArg = cast<ConstantDataSequential>(Operand)->getAsString();
199*9880d681SAndroid Build Coastguard Worker     ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1);
200*9880d681SAndroid Build Coastguard Worker     DEBUG(dbgs() << "Arg of _reflect : " << ReflectArg << "\n");
201*9880d681SAndroid Build Coastguard Worker 
202*9880d681SAndroid Build Coastguard Worker     int ReflectVal = 0; // The default value is 0
203*9880d681SAndroid Build Coastguard Worker     auto Iter = VarMap.find(ReflectArg);
204*9880d681SAndroid Build Coastguard Worker     if (Iter != VarMap.end())
205*9880d681SAndroid Build Coastguard Worker       ReflectVal = Iter->second;
206*9880d681SAndroid Build Coastguard Worker     else if (ReflectArg == "__CUDA_FTZ") {
207*9880d681SAndroid Build Coastguard Worker       // Try to pull __CUDA_FTZ from the nvvm-reflect-ftz module flag.
208*9880d681SAndroid Build Coastguard Worker       if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
209*9880d681SAndroid Build Coastguard Worker               F.getParent()->getModuleFlag("nvvm-reflect-ftz")))
210*9880d681SAndroid Build Coastguard Worker         ReflectVal = Flag->getSExtValue();
211*9880d681SAndroid Build Coastguard Worker     }
212*9880d681SAndroid Build Coastguard Worker     Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
213*9880d681SAndroid Build Coastguard Worker     ToRemove.push_back(Call);
214*9880d681SAndroid Build Coastguard Worker   }
215*9880d681SAndroid Build Coastguard Worker 
216*9880d681SAndroid Build Coastguard Worker   for (Instruction *I : ToRemove)
217*9880d681SAndroid Build Coastguard Worker     I->eraseFromParent();
218*9880d681SAndroid Build Coastguard Worker 
219*9880d681SAndroid Build Coastguard Worker   return ToRemove.size() > 0;
220*9880d681SAndroid Build Coastguard Worker }
221