1 //===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a printer that converts from our internal representation
10 // of machine-dependent LLVM code to NVPTX assembly language.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "NVPTXAsmPrinter.h"
15 #include "MCTargetDesc/NVPTXBaseInfo.h"
16 #include "MCTargetDesc/NVPTXInstPrinter.h"
17 #include "MCTargetDesc/NVPTXMCAsmInfo.h"
18 #include "MCTargetDesc/NVPTXTargetStreamer.h"
19 #include "NVPTX.h"
20 #include "NVPTXMCExpr.h"
21 #include "NVPTXMachineFunctionInfo.h"
22 #include "NVPTXRegisterInfo.h"
23 #include "NVPTXSubtarget.h"
24 #include "NVPTXTargetMachine.h"
25 #include "NVPTXUtilities.h"
26 #include "TargetInfo/NVPTXTargetInfo.h"
27 #include "cl_common_defines.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/DenseSet.h"
32 #include "llvm/ADT/SmallString.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/ADT/Triple.h"
37 #include "llvm/ADT/Twine.h"
38 #include "llvm/Analysis/ConstantFolding.h"
39 #include "llvm/CodeGen/Analysis.h"
40 #include "llvm/CodeGen/MachineBasicBlock.h"
41 #include "llvm/CodeGen/MachineFrameInfo.h"
42 #include "llvm/CodeGen/MachineFunction.h"
43 #include "llvm/CodeGen/MachineInstr.h"
44 #include "llvm/CodeGen/MachineLoopInfo.h"
45 #include "llvm/CodeGen/MachineModuleInfo.h"
46 #include "llvm/CodeGen/MachineOperand.h"
47 #include "llvm/CodeGen/MachineRegisterInfo.h"
48 #include "llvm/CodeGen/TargetRegisterInfo.h"
49 #include "llvm/CodeGen/ValueTypes.h"
50 #include "llvm/IR/Attributes.h"
51 #include "llvm/IR/BasicBlock.h"
52 #include "llvm/IR/Constant.h"
53 #include "llvm/IR/Constants.h"
54 #include "llvm/IR/DataLayout.h"
55 #include "llvm/IR/DebugInfo.h"
56 #include "llvm/IR/DebugInfoMetadata.h"
57 #include "llvm/IR/DebugLoc.h"
58 #include "llvm/IR/DerivedTypes.h"
59 #include "llvm/IR/Function.h"
60 #include "llvm/IR/GlobalValue.h"
61 #include "llvm/IR/GlobalVariable.h"
62 #include "llvm/IR/Instruction.h"
63 #include "llvm/IR/LLVMContext.h"
64 #include "llvm/IR/Module.h"
65 #include "llvm/IR/Operator.h"
66 #include "llvm/IR/Type.h"
67 #include "llvm/IR/User.h"
68 #include "llvm/MC/MCExpr.h"
69 #include "llvm/MC/MCInst.h"
70 #include "llvm/MC/MCInstrDesc.h"
71 #include "llvm/MC/MCStreamer.h"
72 #include "llvm/MC/MCSymbol.h"
73 #include "llvm/MC/TargetRegistry.h"
74 #include "llvm/Support/Casting.h"
75 #include "llvm/Support/CommandLine.h"
76 #include "llvm/Support/Endian.h"
77 #include "llvm/Support/ErrorHandling.h"
78 #include "llvm/Support/MachineValueType.h"
79 #include "llvm/Support/NativeFormatting.h"
80 #include "llvm/Support/Path.h"
81 #include "llvm/Support/raw_ostream.h"
82 #include "llvm/Target/TargetLoweringObjectFile.h"
83 #include "llvm/Target/TargetMachine.h"
84 #include "llvm/Transforms/Utils/UnrollLoop.h"
85 #include <cassert>
86 #include <cstdint>
87 #include <cstring>
88 #include <new>
89 #include <string>
90 #include <utility>
91 #include <vector>
92
93 using namespace llvm;
94
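// DEPOTNAME is the prefix of the per-function local "depot" byte array that
// backs the PTX stack frame; getFunctionFrameSymbol() and
// setAndEmitFunctionVirtualRegisters() append the function number to it.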
95 #define DEPOTNAME "__local_depot"
96
97 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
98 /// depends.
99 static void
100 DiscoverDependentGlobals(const Value *V,
101 DenseSet<const GlobalVariable *> &Globals) {
102 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
103 Globals.insert(GV);
104 else {
105 if (const User *U = dyn_cast<User>(V)) {
106 for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
107 DiscoverDependentGlobals(U->getOperand(i), Globals);
108 }
109 }
110 }
111 }
112
113 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
114 /// instances to be emitted, but only after any dependents have been added
115 /// first.
116 static void
117 VisitGlobalVariableForEmission(const GlobalVariable *GV,
118 SmallVectorImpl<const GlobalVariable *> &Order,
119 DenseSet<const GlobalVariable *> &Visited,
120 DenseSet<const GlobalVariable *> &Visiting) {
121 // Have we already visited this one?
122 if (Visited.count(GV))
123 return;
124
125 // Do we have a circular dependency?
126 if (!Visiting.insert(GV).second)
127 report_fatal_error("Circular dependency found in global variable set");
128
129 // Make sure we visit all dependents first
130 DenseSet<const GlobalVariable *> Others;
131 for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
132 DiscoverDependentGlobals(GV->getOperand(i), Others);
133
134 for (const GlobalVariable *GV : Others)
135 VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
136
137 // Now we can visit ourself
138 Order.push_back(GV);
139 Visited.insert(GV);
140 Visiting.erase(GV);
141 }
142
143 void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
144 NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(),
145 getSubtargetInfo().getFeatureBits());
146
147 MCInst Inst;
148 lowerToMCInst(MI, Inst);
149 EmitToStreamer(*OutStreamer, Inst);
150 }
151
152 // Handle symbol backtracking for targets that do not support image handles
153 bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
154 unsigned OpNo, MCOperand &MCOp) {
155 const MachineOperand &MO = MI->getOperand(OpNo);
156 const MCInstrDesc &MCID = MI->getDesc();
157
158 if (MCID.TSFlags & NVPTXII::IsTexFlag) {
159 // This is a texture fetch, so operand 4 is a texref and operand 5 is
160 // a samplerref
161 if (OpNo == 4 && MO.isImm()) {
162 lowerImageHandleSymbol(MO.getImm(), MCOp);
163 return true;
164 }
165 if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
166 lowerImageHandleSymbol(MO.getImm(), MCOp);
167 return true;
168 }
169
170 return false;
171 } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
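// The IsSuld bit-field stores log2(vector size) + 1, so the expression below
// recovers an element count of 1, 2, or 4.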
172 unsigned VecSize =
173 1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
174
175 // For a surface load of vector size N, the Nth operand will be the surfref
176 if (OpNo == VecSize && MO.isImm()) {
177 lowerImageHandleSymbol(MO.getImm(), MCOp);
178 return true;
179 }
180
181 return false;
182 } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
183 // This is a surface store, so operand 0 is a surfref
184 if (OpNo == 0 && MO.isImm()) {
185 lowerImageHandleSymbol(MO.getImm(), MCOp);
186 return true;
187 }
188
189 return false;
190 } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
191 // This is a query, so operand 1 is a surfref/texref
192 if (OpNo == 1 && MO.isImm()) {
193 lowerImageHandleSymbol(MO.getImm(), MCOp);
194 return true;
195 }
196
197 return false;
198 }
199
200 return false;
201 }
202
203 void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
204 // Ewwww
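// Map the image handle index back to the texture/surface symbol name recorded
// in NVPTXMachineFunctionInfo, and intern it in the target machine's string
// pool so the name outlives this MachineFunction.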
205 LLVMTargetMachine &TM = const_cast<LLVMTargetMachine&>(MF->getTarget());
206 NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine&>(TM);
207 const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
208 const char *Sym = MFI->getImageHandleSymbol(Index);
209 StringRef SymName = nvTM.getStrPool().save(Sym);
210 MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
211 }
212
213 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
214 OutMI.setOpcode(MI->getOpcode());
215 // Special: Do not mangle symbol operand of CALL_PROTOTYPE
216 if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
217 const MachineOperand &MO = MI->getOperand(0);
218 OutMI.addOperand(GetSymbolRef(
219 OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));
220 return;
221 }
222
223 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
224 for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
225 const MachineOperand &MO = MI->getOperand(i);
226
227 MCOperand MCOp;
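// Targets without image handle support must emit the texref/surfref symbol
// directly instead of an immediate handle index.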
228 if (!STI.hasImageHandles()) {
229 if (lowerImageHandleOperand(MI, i, MCOp)) {
230 OutMI.addOperand(MCOp);
231 continue;
232 }
233 }
234
235 if (lowerOperand(MO, MCOp))
236 OutMI.addOperand(MCOp);
237 }
238 }
239
240 bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
241 MCOperand &MCOp) {
242 switch (MO.getType()) {
243 default: llvm_unreachable("unknown operand type");
244 case MachineOperand::MO_Register:
245 MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
246 break;
247 case MachineOperand::MO_Immediate:
248 MCOp = MCOperand::createImm(MO.getImm());
249 break;
250 case MachineOperand::MO_MachineBasicBlock:
251 MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
252 MO.getMBB()->getSymbol(), OutContext));
253 break;
254 case MachineOperand::MO_ExternalSymbol:
255 MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
256 break;
257 case MachineOperand::MO_GlobalAddress:
258 MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
259 break;
260 case MachineOperand::MO_FPImmediate: {
261 const ConstantFP *Cnt = MO.getFPImm();
262 const APFloat &Val = Cnt->getValueAPF();
263
264 switch (Cnt->getType()->getTypeID()) {
265 default: report_fatal_error("Unsupported FP type"); break;
266 case Type::HalfTyID:
267 MCOp = MCOperand::createExpr(
268 NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
269 break;
270 case Type::FloatTyID:
271 MCOp = MCOperand::createExpr(
272 NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
273 break;
274 case Type::DoubleTyID:
275 MCOp = MCOperand::createExpr(
276 NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
277 break;
278 }
279 break;
280 }
281 }
282 return true;
283 }
284
285 unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
286 if (Register::isVirtualRegister(Reg)) {
287 const TargetRegisterClass *RC = MRI->getRegClass(Reg);
288
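// RegNum is the per-register-class sequence number assigned to this vreg in
// setAndEmitFunctionVirtualRegisters().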
289 DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
290 unsigned RegNum = RegMap[Reg];
291
292 // Encode the register class in the upper 4 bits
293 // Must be kept in sync with NVPTXInstPrinter::printRegName
294 unsigned Ret = 0;
295 if (RC == &NVPTX::Int1RegsRegClass) {
296 Ret = (1 << 28);
297 } else if (RC == &NVPTX::Int16RegsRegClass) {
298 Ret = (2 << 28);
299 } else if (RC == &NVPTX::Int32RegsRegClass) {
300 Ret = (3 << 28);
301 } else if (RC == &NVPTX::Int64RegsRegClass) {
302 Ret = (4 << 28);
303 } else if (RC == &NVPTX::Float32RegsRegClass) {
304 Ret = (5 << 28);
305 } else if (RC == &NVPTX::Float64RegsRegClass) {
306 Ret = (6 << 28);
307 } else if (RC == &NVPTX::Float16RegsRegClass) {
308 Ret = (7 << 28);
309 } else if (RC == &NVPTX::Float16x2RegsRegClass) {
310 Ret = (8 << 28);
311 } else {
312 report_fatal_error("Bad register class");
313 }
314
315 // Insert the vreg number
316 Ret |= (RegNum & 0x0FFFFFFF);
317 return Ret;
318 } else {
319 // Some special-use registers are actually physical registers.
320 // Encode this as the register class ID of 0 and the real register ID.
321 return Reg & 0x0FFFFFFF;
322 }
323 }
324
325 MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
326 const MCExpr *Expr;
327 Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None,
328 OutContext);
329 return MCOperand::createExpr(Expr);
330 }
331
332 void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
333 const DataLayout &DL = getDataLayout();
334 const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
335 const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
336
337 Type *Ty = F->getReturnType();
338
339 bool isABI = (STI.getSmVersion() >= 20);
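// sm_20 and newer use the PTX ABI: the return value is described with a
// .param declaration below; older targets fall back to .reg declarations.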
340
341 if (Ty->getTypeID() == Type::VoidTyID)
342 return;
343
344 O << " (";
345
346 if (isABI) {
347 if (Ty->isFloatingPointTy() || (Ty->isIntegerTy() && !Ty->isIntegerTy(128))) {
348 unsigned size = 0;
349 if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
350 size = ITy->getBitWidth();
351 } else {
352 assert(Ty->isFloatingPointTy() && "Floating point type expected here");
353 size = Ty->getPrimitiveSizeInBits();
354 }
355 // PTX ABI requires all scalar return values to be at least 32
356 // bits in size. fp16 normally uses .b16 as its storage type in
357 // PTX, so its size must be adjusted here, too.
358 size = promoteScalarArgumentSize(size);
359
360 O << ".param .b" << size << " func_retval0";
361 } else if (isa<PointerType>(Ty)) {
362 O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
363 << " func_retval0";
364 } else if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
365 unsigned totalsz = DL.getTypeAllocSize(Ty);
366 unsigned retAlignment = 0;
367 if (!getAlign(*F, 0, retAlignment))
368 retAlignment = TLI->getFunctionParamOptimizedAlign(F, Ty, DL).value();
369 O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz
370 << "]";
371 } else
372 llvm_unreachable("Unknown return type");
373 } else {
374 SmallVector<EVT, 16> vtparts;
375 ComputeValueVTs(*TLI, DL, Ty, vtparts);
376 unsigned idx = 0;
377 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
378 unsigned elems = 1;
379 EVT elemtype = vtparts[i];
380 if (vtparts[i].isVector()) {
381 elems = vtparts[i].getVectorNumElements();
382 elemtype = vtparts[i].getVectorElementType();
383 }
384
385 for (unsigned j = 0, je = elems; j != je; ++j) {
386 unsigned sz = elemtype.getSizeInBits();
387 if (elemtype.isInteger())
388 sz = promoteScalarArgumentSize(sz);
389 O << ".reg .b" << sz << " func_retval" << idx;
390 if (j < je - 1)
391 O << ", ";
392 ++idx;
393 }
394 if (i < e - 1)
395 O << ", ";
396 }
397 }
398 O << ") ";
399 }
400
401 void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
402 raw_ostream &O) {
403 const Function &F = MF.getFunction();
404 printReturnValStr(&F, O);
405 }
406
407 // Return true if MBB is the header of a loop marked with
408 // llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
409 bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
410 const MachineBasicBlock &MBB) const {
411 MachineLoopInfo &LI = getAnalysis<MachineLoopInfo>();
412 // We only insert the .pragma "nounroll" at the loop header.
413 if (!LI.isLoopHeader(&MBB))
414 return false;
415
416 // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
417 // we iterate through each back edge of the loop with header MBB, and check
418 // whether its metadata contains llvm.loop.unroll.disable.
419 for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
420 if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) {
421 // Edges from other loops to MBB are not back edges.
422 continue;
423 }
424 if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
425 if (MDNode *LoopID =
426 PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) {
427 if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable"))
428 return true;
429 if (MDNode *UnrollCountMD =
430 GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) {
431 if (mdconst::extract<ConstantInt>(UnrollCountMD->getOperand(1))
432 ->getZExtValue() == 1)
433 return true;
434 }
435 }
436 }
437 }
438 return false;
439 }
440
441 void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
442 AsmPrinter::emitBasicBlockStart(MBB);
443 if (isLoopHeaderOfNoUnroll(MBB))
444 OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
445 }
446
447 void NVPTXAsmPrinter::emitFunctionEntryLabel() {
448 SmallString<128> Str;
449 raw_svector_ostream O(Str);
450
451 if (!GlobalsEmitted) {
452 emitGlobals(*MF->getFunction().getParent());
453 GlobalsEmitted = true;
454 }
455
456 // Set up
457 MRI = &MF->getRegInfo();
458 F = &MF->getFunction();
459 emitLinkageDirective(F, O);
460 if (isKernelFunction(*F))
461 O << ".entry ";
462 else {
463 O << ".func ";
464 printReturnValStr(*MF, O);
465 }
466
467 CurrentFnSym->print(O, MAI);
468
469 emitFunctionParamList(*MF, O);
470
471 if (isKernelFunction(*F))
472 emitKernelFunctionDirectives(*F, O);
473
474 if (shouldEmitPTXNoReturn(F, TM))
475 O << ".noreturn";
476
477 OutStreamer->emitRawText(O.str());
478
479 VRegMapping.clear();
480 // Emit open brace for function body.
481 OutStreamer->emitRawText(StringRef("{\n"));
482 setAndEmitFunctionVirtualRegisters(*MF);
483 // Emit initial .loc debug directive for correct relocation symbol data.
484 if (MMI && MMI->hasDebugInfo())
485 emitInitialRawDwarfLocDirective(*MF);
486 }
487
488 bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
489 bool Result = AsmPrinter::runOnMachineFunction(F);
490 // Emit closing brace for the body of function F.
491 // The closing brace must be emitted here because additional debug
492 // labels/data may need to be emitted after the last basic block, and there
493 // is no other hook that runs once emission of the function body has
494 // finished.
495 OutStreamer->emitRawText(StringRef("}\n"));
496 return Result;
497 }
498
499 void NVPTXAsmPrinter::emitFunctionBodyStart() {
500 SmallString<128> Str;
501 raw_svector_ostream O(Str);
502 emitDemotedVars(&MF->getFunction(), O);
503 OutStreamer->emitRawText(O.str());
504 }
505
506 void NVPTXAsmPrinter::emitFunctionBodyEnd() {
507 VRegMapping.clear();
508 }
509
510 const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
511 SmallString<128> Str;
512 raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
513 return OutContext.getOrCreateSymbol(Str);
514 }
515
516 void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
517 Register RegNo = MI->getOperand(0).getReg();
518 if (RegNo.isVirtual()) {
519 OutStreamer->AddComment(Twine("implicit-def: ") +
520 getVirtualRegisterName(RegNo));
521 } else {
522 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
523 OutStreamer->AddComment(Twine("implicit-def: ") +
524 STI.getRegisterInfo()->getName(RegNo));
525 }
526 OutStreamer->addBlankLine();
527 }
528
529 void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
530 raw_ostream &O) const {
531 // If the NVVM IR has any of the reqntid* annotations specified, output the
532 // .reqntid directive and default the unspecified dimensions to 1. If none of
533 // them is specified, don't output the directive at all.
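// For example, a kernel whose reqntid* annotations specify 16, 16, and 1
// produces ".reqntid 16, 16, 1".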
534 unsigned reqntidx, reqntidy, reqntidz;
535 bool specified = false;
536 if (!getReqNTIDx(F, reqntidx))
537 reqntidx = 1;
538 else
539 specified = true;
540 if (!getReqNTIDy(F, reqntidy))
541 reqntidy = 1;
542 else
543 specified = true;
544 if (!getReqNTIDz(F, reqntidz))
545 reqntidz = 1;
546 else
547 specified = true;
548
549 if (specified)
550 O << ".reqntid " << reqntidx << ", " << reqntidy << ", " << reqntidz
551 << "\n";
552
553 // If the NVVM IR has any of the maxntid* annotations specified, output the
554 // .maxntid directive and default the unspecified dimensions to 1. If none of
555 // them is specified, don't output the directive at all.
556 unsigned maxntidx, maxntidy, maxntidz;
557 specified = false;
558 if (!getMaxNTIDx(F, maxntidx))
559 maxntidx = 1;
560 else
561 specified = true;
562 if (!getMaxNTIDy(F, maxntidy))
563 maxntidy = 1;
564 else
565 specified = true;
566 if (!getMaxNTIDz(F, maxntidz))
567 maxntidz = 1;
568 else
569 specified = true;
570
571 if (specified)
572 O << ".maxntid " << maxntidx << ", " << maxntidy << ", " << maxntidz
573 << "\n";
574
575 unsigned mincta;
576 if (getMinCTASm(F, mincta))
577 O << ".minnctapersm " << mincta << "\n";
578
579 unsigned maxnreg;
580 if (getMaxNReg(F, maxnreg))
581 O << ".maxnreg " << maxnreg << "\n";
582 }
583
584 std::string
585 NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
586 const TargetRegisterClass *RC = MRI->getRegClass(Reg);
587
588 std::string Name;
589 raw_string_ostream NameStr(Name);
590
591 VRegRCMap::const_iterator I = VRegMapping.find(RC);
592 assert(I != VRegMapping.end() && "Bad register class");
593 const DenseMap<unsigned, unsigned> &RegMap = I->second;
594
595 VRegMap::const_iterator VI = RegMap.find(Reg);
596 assert(VI != RegMap.end() && "Bad virtual register");
597 unsigned MappedVR = VI->second;
598
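// The printed name is the class prefix followed by the per-class number,
// e.g. "%r3" for the third virtual register of NVPTX::Int32RegsRegClass
// (assuming "%r" is that class's prefix from getNVPTXRegClassStr()).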
599 NameStr << getNVPTXRegClassStr(RC) << MappedVR;
600
601 NameStr.flush();
602 return Name;
603 }
604
605 void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
606 raw_ostream &O) {
607 O << getVirtualRegisterName(vr);
608 }
609
610 void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
611 emitLinkageDirective(F, O);
612 if (isKernelFunction(*F))
613 O << ".entry ";
614 else
615 O << ".func ";
616 printReturnValStr(F, O);
617 getSymbol(F)->print(O, MAI);
618 O << "\n";
619 emitFunctionParamList(F, O);
620 if (shouldEmitPTXNoReturn(F, TM))
621 O << ".noreturn";
622 O << ";\n";
623 }
624
625 static bool usedInGlobalVarDef(const Constant *C) {
626 if (!C)
627 return false;
628
629 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
630 return GV->getName() != "llvm.used";
631 }
632
633 for (const User *U : C->users())
634 if (const Constant *C = dyn_cast<Constant>(U))
635 if (usedInGlobalVarDef(C))
636 return true;
637
638 return false;
639 }
640
641 static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
642 if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
643 if (othergv->getName() == "llvm.used")
644 return true;
645 }
646
647 if (const Instruction *instr = dyn_cast<Instruction>(U)) {
648 if (instr->getParent() && instr->getParent()->getParent()) {
649 const Function *curFunc = instr->getParent()->getParent();
650 if (oneFunc && (curFunc != oneFunc))
651 return false;
652 oneFunc = curFunc;
653 return true;
654 } else
655 return false;
656 }
657
658 for (const User *UU : U->users())
659 if (!usedInOneFunc(UU, oneFunc))
660 return false;
661
662 return true;
663 }
664
665 /* Find out if a global variable can be demoted to local scope.
666 * Currently, this is valid for CUDA shared variables, which have local
667 * scope and global lifetime. So the conditions to check are:
668 * 1. Is the global variable in shared address space?
669 * 2. Does it have internal linkage?
670 * 3. Is the global variable referenced only in one function?
671 */
672 static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
673 if (!gv->hasInternalLinkage())
674 return false;
675 PointerType *Pty = gv->getType();
676 if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
677 return false;
678
679 const Function *oneFunc = nullptr;
680
681 bool flag = usedInOneFunc(gv, oneFunc);
682 if (!flag)
683 return false;
684 if (!oneFunc)
685 return false;
686 f = oneFunc;
687 return true;
688 }
689
690 static bool useFuncSeen(const Constant *C,
691 DenseMap<const Function *, bool> &seenMap) {
692 for (const User *U : C->users()) {
693 if (const Constant *cu = dyn_cast<Constant>(U)) {
694 if (useFuncSeen(cu, seenMap))
695 return true;
696 } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
697 const BasicBlock *bb = I->getParent();
698 if (!bb)
699 continue;
700 const Function *caller = bb->getParent();
701 if (!caller)
702 continue;
703 if (seenMap.find(caller) != seenMap.end())
704 return true;
705 }
706 }
707 return false;
708 }
709
710 void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
711 DenseMap<const Function *, bool> seenMap;
712 for (const Function &F : M) {
713 if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
714 emitDeclaration(&F, O);
715 continue;
716 }
717
718 if (F.isDeclaration()) {
719 if (F.use_empty())
720 continue;
721 if (F.getIntrinsicID())
722 continue;
723 emitDeclaration(&F, O);
724 continue;
725 }
726 for (const User *U : F.users()) {
727 if (const Constant *C = dyn_cast<Constant>(U)) {
728 if (usedInGlobalVarDef(C)) {
729 // The use is in the initialization of a global variable
730 // that is a function pointer, so print a declaration
731 // for the original function
732 emitDeclaration(&F, O);
733 break;
734 }
735 // Emit a declaration of this function if the function that
736 // uses this constant expr has already been seen.
737 if (useFuncSeen(C, seenMap)) {
738 emitDeclaration(&F, O);
739 break;
740 }
741 }
742
743 if (!isa<Instruction>(U))
744 continue;
745 const Instruction *instr = cast<Instruction>(U);
746 const BasicBlock *bb = instr->getParent();
747 if (!bb)
748 continue;
749 const Function *caller = bb->getParent();
750 if (!caller)
751 continue;
752
753 // If a caller has already been seen, then the caller is
754 // appearing in the module before the callee, so print out
755 // a declaration for the callee.
756 if (seenMap.find(caller) != seenMap.end()) {
757 emitDeclaration(&F, O);
758 break;
759 }
760 }
761 seenMap[&F] = true;
762 }
763 }
764
765 static bool isEmptyXXStructor(GlobalVariable *GV) {
766 if (!GV) return true;
767 const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer());
768 if (!InitList) return true; // Not an array; we don't know how to parse.
769 return InitList->getNumOperands() == 0;
770 }
771
772 void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
773 // Construct a default subtarget off of the TargetMachine defaults. The
774 // rest of NVPTX isn't set up to change subtargets per function, so the
775 // default TargetMachine will have all of the options.
776 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
777 const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl());
778 SmallString<128> Str1;
779 raw_svector_ostream OS1(Str1);
780
781 // Emit header before any dwarf directives are emitted below.
782 emitHeader(M, OS1, *STI);
783 OutStreamer->emitRawText(OS1.str());
784 }
785
786 bool NVPTXAsmPrinter::doInitialization(Module &M) {
787 if (M.alias_size()) {
788 report_fatal_error("Module has aliases, which NVPTX does not support.");
789 return true; // error
790 }
791 if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors"))) {
792 report_fatal_error(
793 "Module has a nontrivial global ctor, which NVPTX does not support.");
794 return true; // error
795 }
796 if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors"))) {
797 report_fatal_error(
798 "Module has a nontrivial global dtor, which NVPTX does not support.");
799 return true; // error
800 }
801
802 // We need to call the parent's implementation explicitly.
803 bool Result = AsmPrinter::doInitialization(M);
804
805 GlobalsEmitted = false;
806
807 return Result;
808 }
809
810 void NVPTXAsmPrinter::emitGlobals(const Module &M) {
811 SmallString<128> Str2;
812 raw_svector_ostream OS2(Str2);
813
814 emitDeclarations(M, OS2);
815
816 // As ptxas does not support forward references of globals, we need to first
817 // sort the list of module-level globals in def-use order. We visit each
818 // global variable in order, and ensure that we emit it *after* its dependent
819 // globals. We use a little extra memory maintaining both a set and a list to
820 // have fast searches while maintaining a strict ordering.
821 SmallVector<const GlobalVariable *, 8> Globals;
822 DenseSet<const GlobalVariable *> GVVisited;
823 DenseSet<const GlobalVariable *> GVVisiting;
824
825 // Visit each global variable, in order
826 for (const GlobalVariable &I : M.globals())
827 VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting);
828
829 assert(GVVisited.size() == M.getGlobalList().size() &&
830 "Missed a global variable");
831 assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
832
833 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
834 const NVPTXSubtarget &STI =
835 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
836
837 // Print out module-level global variables in proper order
838 for (unsigned i = 0, e = Globals.size(); i != e; ++i)
839 printModuleLevelGV(Globals[i], OS2, /*processDemoted=*/false, STI);
840
841 OS2 << '\n';
842
843 OutStreamer->emitRawText(OS2.str());
844 }
845
846 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
847 const NVPTXSubtarget &STI) {
848 O << "//\n";
849 O << "// Generated by LLVM NVPTX Back-End\n";
850 O << "//\n";
851 O << "\n";
852
853 unsigned PTXVersion = STI.getPTXVersion();
854 O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
855
856 O << ".target ";
857 O << STI.getTargetName();
858
859 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
860 if (NTM.getDrvInterface() == NVPTX::NVCL)
861 O << ", texmode_independent";
862
863 bool HasFullDebugInfo = false;
864 for (DICompileUnit *CU : M.debug_compile_units()) {
865 switch(CU->getEmissionKind()) {
866 case DICompileUnit::NoDebug:
867 case DICompileUnit::DebugDirectivesOnly:
868 break;
869 case DICompileUnit::LineTablesOnly:
870 case DICompileUnit::FullDebug:
871 HasFullDebugInfo = true;
872 break;
873 }
874 if (HasFullDebugInfo)
875 break;
876 }
877 if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo)
878 O << ", debug";
879
880 O << "\n";
881
882 O << ".address_size ";
883 if (NTM.is64Bit())
884 O << "64";
885 else
886 O << "32";
887 O << "\n";
888
889 O << "\n";
890 }
891
892 bool NVPTXAsmPrinter::doFinalization(Module &M) {
893 bool HasDebugInfo = MMI && MMI->hasDebugInfo();
894
895 // If we did not emit any functions, then the global declarations have not
896 // yet been emitted.
897 if (!GlobalsEmitted) {
898 emitGlobals(M);
899 GlobalsEmitted = true;
900 }
901
902 // call doFinalization
903 bool ret = AsmPrinter::doFinalization(M);
904
905 clearAnnotationCache(&M);
906
907 auto *TS =
908 static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
909 // Close the last emitted section
910 if (HasDebugInfo) {
911 TS->closeLastSection();
912 // Emit an empty .debug_loc section to better support empty files.
913 OutStreamer->emitRawText("\t.section\t.debug_loc\t{\t}");
914 }
915
916 // Output last DWARF .file directives, if any.
917 TS->outputDwarfFileDirectives();
918
919 return ret;
920 }
921
922 // This function emits appropriate linkage directives for
923 // functions and global variables.
924 //
925 // extern function declaration -> .extern
926 // extern function definition -> .visible
927 // external global variable with init -> .visible
928 // external without init -> .extern
929 // appending -> not allowed, assert.
930 // for any linkage other than
931 // internal, private, linker_private,
932 // linker_private_weak, linker_private_weak_def_auto,
933 // we emit -> .weak.
934
935 void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
936 raw_ostream &O) {
937 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
938 if (V->hasExternalLinkage()) {
939 if (isa<GlobalVariable>(V)) {
940 const GlobalVariable *GVar = cast<GlobalVariable>(V);
941 if (GVar) {
942 if (GVar->hasInitializer())
943 O << ".visible ";
944 else
945 O << ".extern ";
946 }
947 } else if (V->isDeclaration())
948 O << ".extern ";
949 else
950 O << ".visible ";
951 } else if (V->hasAppendingLinkage()) {
952 std::string msg;
953 msg.append("Error: ");
954 msg.append("Symbol ");
955 if (V->hasName())
956 msg.append(std::string(V->getName()));
957 msg.append("has unsupported appending linkage type");
958 llvm_unreachable(msg.c_str());
959 } else if (!V->hasInternalLinkage() &&
960 !V->hasPrivateLinkage()) {
961 O << ".weak ";
962 }
963 }
964 }
965
966 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
967 raw_ostream &O, bool processDemoted,
968 const NVPTXSubtarget &STI) {
969 // Skip metadata
970 if (GVar->hasSection()) {
971 if (GVar->getSection() == "llvm.metadata")
972 return;
973 }
974
975 // Skip LLVM intrinsic global variables
976 if (GVar->getName().startswith("llvm.") ||
977 GVar->getName().startswith("nvvm."))
978 return;
979
980 const DataLayout &DL = getDataLayout();
981
982 // GlobalVariables are always constant pointers themselves.
983 PointerType *PTy = GVar->getType();
984 Type *ETy = GVar->getValueType();
985
986 if (GVar->hasExternalLinkage()) {
987 if (GVar->hasInitializer())
988 O << ".visible ";
989 else
990 O << ".extern ";
991 } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
992 GVar->hasAvailableExternallyLinkage() ||
993 GVar->hasCommonLinkage()) {
994 O << ".weak ";
995 }
996
997 if (isTexture(*GVar)) {
998 O << ".global .texref " << getTextureName(*GVar) << ";\n";
999 return;
1000 }
1001
1002 if (isSurface(*GVar)) {
1003 O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
1004 return;
1005 }
1006
1007 if (GVar->isDeclaration()) {
1008 // (extern) declarations, no definition or initializer
1009 // Currently the only known declaration is for an automatic __local
1010 // (.shared) promoted to global.
1011 emitPTXGlobalVariable(GVar, O, STI);
1012 O << ";\n";
1013 return;
1014 }
1015
1016 if (isSampler(*GVar)) {
1017 O << ".global .samplerref " << getSamplerName(*GVar);
1018
1019 const Constant *Initializer = nullptr;
1020 if (GVar->hasInitializer())
1021 Initializer = GVar->getInitializer();
1022 const ConstantInt *CI = nullptr;
1023 if (Initializer)
1024 CI = dyn_cast<ConstantInt>(Initializer);
1025 if (CI) {
1026 unsigned sample = CI->getZExtValue();
1027
1028 O << " = { ";
1029
1030 for (int i = 0,
1031 addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
1032 i < 3; i++) {
1033 O << "addr_mode_" << i << " = ";
1034 switch (addr) {
1035 case 0:
1036 O << "wrap";
1037 break;
1038 case 1:
1039 O << "clamp_to_border";
1040 break;
1041 case 2:
1042 O << "clamp_to_edge";
1043 break;
1044 case 3:
1045 O << "wrap";
1046 break;
1047 case 4:
1048 O << "mirror";
1049 break;
1050 }
1051 O << ", ";
1052 }
1053 O << "filter_mode = ";
1054 switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
1055 case 0:
1056 O << "nearest";
1057 break;
1058 case 1:
1059 O << "linear";
1060 break;
1061 case 2:
1062 llvm_unreachable("Anisotropic filtering is not supported");
1063 default:
1064 O << "nearest";
1065 break;
1066 }
1067 if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
1068 O << ", force_unnormalized_coords = 1";
1069 }
1070 O << " }";
1071 }
1072
1073 O << ";\n";
1074 return;
1075 }
1076
1077 if (GVar->hasPrivateLinkage()) {
1078 if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
1079 return;
1080
1081 // FIXME - need better way (e.g. Metadata) to avoid generating this global
1082 if (strncmp(GVar->getName().data(), "filename", 8) == 0)
1083 return;
1084 if (GVar->use_empty())
1085 return;
1086 }
1087
1088 const Function *demotedFunc = nullptr;
1089 if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
1090 O << "// " << GVar->getName() << " has been demoted\n";
1091 if (localDecls.find(demotedFunc) != localDecls.end())
1092 localDecls[demotedFunc].push_back(GVar);
1093 else {
1094 std::vector<const GlobalVariable *> temp;
1095 temp.push_back(GVar);
1096 localDecls[demotedFunc] = temp;
1097 }
1098 return;
1099 }
1100
1101 O << ".";
1102 emitPTXAddressSpace(PTy->getAddressSpace(), O);
1103
1104 if (isManaged(*GVar)) {
1105 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1106 report_fatal_error(
1107 ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1108 }
1109 O << " .attribute(.managed)";
1110 }
1111
1112 if (MaybeAlign A = GVar->getAlign())
1113 O << " .align " << A->value();
1114 else
1115 O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
1116
1117 if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
1118 (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
1119 O << " .";
1120 // Special case: ABI requires that we use .u8 for predicates
1121 if (ETy->isIntegerTy(1))
1122 O << "u8";
1123 else
1124 O << getPTXFundamentalTypeStr(ETy, false);
1125 O << " ";
1126 getSymbol(GVar)->print(O, MAI);
1127
1128 // PTX allows variable initialization only for the constant and global state
1129 // spaces.
1130 if (GVar->hasInitializer()) {
1131 if ((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1132 (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) {
1133 const Constant *Initializer = GVar->getInitializer();
1134 // 'undef' is treated as if no value was specified.
1135 if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {
1136 O << " = ";
1137 printScalarConstant(Initializer, O);
1138 }
1139 } else {
1140 // The frontend adds zero-initializer to device and constant variables
1141 // that don't have an initial value, and UndefValue to shared
1142 // variables, so skip warning for this case.
1143 if (!GVar->getInitializer()->isNullValue() &&
1144 !isa<UndefValue>(GVar->getInitializer())) {
1145 report_fatal_error("initial value of '" + GVar->getName() +
1146 "' is not allowed in addrspace(" +
1147 Twine(PTy->getAddressSpace()) + ")");
1148 }
1149 }
1150 }
1151 } else {
1152 unsigned int ElementSize = 0;
1153
1154 // Although PTX has direct support for struct and array types, and LLVM IR
1155 // is very similar to PTX in that respect, LLVM CodeGen does not preserve
1156 // these high-level field accesses for this target. Structs, arrays
1157 // and vectors are lowered into arrays of bytes.
1158 switch (ETy->getTypeID()) {
1159 case Type::IntegerTyID: // Integers larger than 64 bits
1160 case Type::StructTyID:
1161 case Type::ArrayTyID:
1162 case Type::FixedVectorTyID:
1163 ElementSize = DL.getTypeStoreSize(ETy);
1164 // PTX allows variable initialization only for the constant and
1165 // global state spaces.
1166 if (((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1167 (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
1168 GVar->hasInitializer()) {
1169 const Constant *Initializer = GVar->getInitializer();
1170 if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
1171 AggBuffer aggBuffer(ElementSize, *this);
1172 bufferAggregateConstant(Initializer, &aggBuffer);
1173 if (aggBuffer.numSymbols()) {
1174 unsigned int ptrSize = MAI->getCodePointerSize();
1175 if (ElementSize % ptrSize ||
1176 !aggBuffer.allSymbolsAligned(ptrSize)) {
1177 // Print in bytes and use the mask() operator for pointers.
1178 if (!STI.hasMaskOperator())
1179 report_fatal_error(
1180 "initialized packed aggregate with pointers '" +
1181 GVar->getName() +
1182 "' requires at least PTX ISA version 7.1");
1183 O << " .u8 ";
1184 getSymbol(GVar)->print(O, MAI);
1185 O << "[" << ElementSize << "] = {";
1186 aggBuffer.printBytes(O);
1187 O << "}";
1188 } else {
1189 O << " .u" << ptrSize * 8 << " ";
1190 getSymbol(GVar)->print(O, MAI);
1191 O << "[" << ElementSize / ptrSize << "] = {";
1192 aggBuffer.printWords(O);
1193 O << "}";
1194 }
1195 } else {
1196 O << " .b8 ";
1197 getSymbol(GVar)->print(O, MAI);
1198 O << "[" << ElementSize << "] = {";
1199 aggBuffer.printBytes(O);
1200 O << "}";
1201 }
1202 } else {
1203 O << " .b8 ";
1204 getSymbol(GVar)->print(O, MAI);
1205 if (ElementSize) {
1206 O << "[";
1207 O << ElementSize;
1208 O << "]";
1209 }
1210 }
1211 } else {
1212 O << " .b8 ";
1213 getSymbol(GVar)->print(O, MAI);
1214 if (ElementSize) {
1215 O << "[";
1216 O << ElementSize;
1217 O << "]";
1218 }
1219 }
1220 break;
1221 default:
1222 llvm_unreachable("type not supported yet");
1223 }
1224 }
1225 O << ";\n";
1226 }
1227
1228 void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
1229 const Value *v = Symbols[nSym];
1230 const Value *v0 = SymbolsBeforeStripping[nSym];
1231 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
1232 MCSymbol *Name = AP.getSymbol(GVar);
1233 PointerType *PTy = dyn_cast<PointerType>(v0->getType());
1234 // Is v0 a generic pointer?
1235 bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
1236 if (EmitGeneric && isGenericPointer && !isa<Function>(v)) {
1237 os << "generic(";
1238 Name->print(os, AP.MAI);
1239 os << ")";
1240 } else {
1241 Name->print(os, AP.MAI);
1242 }
1243 } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
1244 const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);
1245 AP.printMCExpr(*Expr, os);
1246 } else
1247 llvm_unreachable("symbol type unknown");
1248 }
1249
1250 void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
1251 unsigned int ptrSize = AP.MAI->getCodePointerSize();
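// Append a sentinel position equal to the buffer size so the loop below
// always has a valid "next symbol position" to compare against, even after
// the last real symbol.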
1252 symbolPosInBuffer.push_back(size);
1253 unsigned int nSym = 0;
1254 unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1255 for (unsigned int pos = 0; pos < size;) {
1256 if (pos)
1257 os << ", ";
1258 if (pos != nextSymbolPos) {
1259 os << (unsigned int)buffer[pos];
1260 ++pos;
1261 continue;
1262 }
1263 // Generate a per-byte mask() operator for the symbol, which looks like:
1264 // .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
1265 // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
1266 std::string symText;
1267 llvm::raw_string_ostream oss(symText);
1268 printSymbol(nSym, oss);
1269 for (unsigned i = 0; i < ptrSize; ++i) {
1270 if (i)
1271 os << ", ";
1272 llvm::write_hex(os, 0xFFULL << i * 8, HexPrintStyle::PrefixUpper);
1273 os << "(" << symText << ")";
1274 }
1275 pos += ptrSize;
1276 nextSymbolPos = symbolPosInBuffer[++nSym];
1277 assert(nextSymbolPos >= pos);
1278 }
1279 }
1280
1281 void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
1282 unsigned int ptrSize = AP.MAI->getCodePointerSize();
1283 symbolPosInBuffer.push_back(size);
1284 unsigned int nSym = 0;
1285 unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1286 assert(nextSymbolPos % ptrSize == 0);
1287 for (unsigned int pos = 0; pos < size; pos += ptrSize) {
1288 if (pos)
1289 os << ", ";
1290 if (pos == nextSymbolPos) {
1291 printSymbol(nSym, os);
1292 nextSymbolPos = symbolPosInBuffer[++nSym];
1293 assert(nextSymbolPos % ptrSize == 0);
1294 assert(nextSymbolPos >= pos + ptrSize);
1295 } else if (ptrSize == 4)
1296 os << support::endian::read32le(&buffer[pos]);
1297 else
1298 os << support::endian::read64le(&buffer[pos]);
1299 }
1300 }
1301
1302 void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
1303 if (localDecls.find(f) == localDecls.end())
1304 return;
1305
1306 std::vector<const GlobalVariable *> &gvars = localDecls[f];
1307
1308 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
1309 const NVPTXSubtarget &STI =
1310 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
1311
1312 for (const GlobalVariable *GV : gvars) {
1313 O << "\t// demoted variable\n\t";
1314 printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
1315 }
1316 }
1317
1318 void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1319 raw_ostream &O) const {
1320 switch (AddressSpace) {
1321 case ADDRESS_SPACE_LOCAL:
1322 O << "local";
1323 break;
1324 case ADDRESS_SPACE_GLOBAL:
1325 O << "global";
1326 break;
1327 case ADDRESS_SPACE_CONST:
1328 O << "const";
1329 break;
1330 case ADDRESS_SPACE_SHARED:
1331 O << "shared";
1332 break;
1333 default:
1334 report_fatal_error("Bad address space found while emitting PTX: " +
1335 llvm::Twine(AddressSpace));
1336 break;
1337 }
1338 }
1339
1340 std::string
1341 NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1342 switch (Ty->getTypeID()) {
1343 case Type::IntegerTyID: {
1344 unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
1345 if (NumBits == 1)
1346 return "pred";
1347 else if (NumBits <= 64) {
1348 std::string name = "u";
1349 return name + utostr(NumBits);
1350 } else {
1351 llvm_unreachable("Integer too large");
1352 break;
1353 }
1354 break;
1355 }
1356 case Type::HalfTyID:
1357 // fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly.
1358 return "b16";
1359 case Type::FloatTyID:
1360 return "f32";
1361 case Type::DoubleTyID:
1362 return "f64";
1363 case Type::PointerTyID: {
1364 unsigned PtrSize = TM.getPointerSizeInBits(Ty->getPointerAddressSpace());
1365 assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size");
1366
1367 if (PtrSize == 64)
1368 if (useB4PTR)
1369 return "b64";
1370 else
1371 return "u64";
1372 else if (useB4PTR)
1373 return "b32";
1374 else
1375 return "u32";
1376 }
1377 default:
1378 break;
1379 }
1380 llvm_unreachable("unexpected type");
1381 }
1382
1383 void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
1384 raw_ostream &O,
1385 const NVPTXSubtarget &STI) {
1386 const DataLayout &DL = getDataLayout();
1387
1388 // GlobalVariables are always constant pointers themselves.
1389 Type *ETy = GVar->getValueType();
1390
1391 O << ".";
1392 emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
1393 if (isManaged(*GVar)) {
1394 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1395 report_fatal_error(
1396 ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1397 }
1398 O << " .attribute(.managed)";
1399 }
1400 if (MaybeAlign A = GVar->getAlign())
1401 O << " .align " << A->value();
1402 else
1403 O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
1404
1405 // Special case for i128
1406 if (ETy->isIntegerTy(128)) {
1407 O << " .b8 ";
1408 getSymbol(GVar)->print(O, MAI);
1409 O << "[16]";
1410 return;
1411 }
1412
1413 if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
1414 O << " .";
1415 O << getPTXFundamentalTypeStr(ETy);
1416 O << " ";
1417 getSymbol(GVar)->print(O, MAI);
1418 return;
1419 }
1420
1421 int64_t ElementSize = 0;
1422
1423 // Although PTX has direct support for struct and array types, and LLVM IR is
1424 // very similar to PTX in that respect, LLVM CodeGen does not preserve these
1425 // high-level field accesses for this target. Structs and arrays are lowered
1426 // into arrays of bytes.
1427 switch (ETy->getTypeID()) {
1428 case Type::StructTyID:
1429 case Type::ArrayTyID:
1430 case Type::FixedVectorTyID:
1431 ElementSize = DL.getTypeStoreSize(ETy);
1432 O << " .b8 ";
1433 getSymbol(GVar)->print(O, MAI);
1434 O << "[";
1435 if (ElementSize) {
1436 O << ElementSize;
1437 }
1438 O << "]";
1439 break;
1440 default:
1441 llvm_unreachable("type not supported yet");
1442 }
1443 }
1444
1445 void NVPTXAsmPrinter::printParamName(Function::const_arg_iterator I,
1446 int paramIndex, raw_ostream &O) {
1447 getSymbol(I->getParent())->print(O, MAI);
1448 O << "_param_" << paramIndex;
1449 }
1450
1451 void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
1452 const DataLayout &DL = getDataLayout();
1453 const AttributeList &PAL = F->getAttributes();
1454 const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
1455 const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
1456
1457 Function::const_arg_iterator I, E;
1458 unsigned paramIndex = 0;
1459 bool first = true;
1460 bool isKernelFunc = isKernelFunction(*F);
1461 bool isABI = (STI.getSmVersion() >= 20);
1462 bool hasImageHandles = STI.hasImageHandles();
1463
1464 if (F->arg_empty() && !F->isVarArg()) {
1465 O << "()\n";
1466 return;
1467 }
1468
1469 O << "(\n";
1470
1471 for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
1472 Type *Ty = I->getType();
1473
1474 if (!first)
1475 O << ",\n";
1476
1477 first = false;
1478
1479 // Handle image/sampler parameters
1480 if (isKernelFunction(*F)) {
1481 if (isSampler(*I) || isImage(*I)) {
1482 if (isImage(*I)) {
1483 std::string sname = std::string(I->getName());
1484 if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
1485 if (hasImageHandles)
1486 O << "\t.param .u64 .ptr .surfref ";
1487 else
1488 O << "\t.param .surfref ";
1489 CurrentFnSym->print(O, MAI);
1490 O << "_param_" << paramIndex;
1491 }
1492 else { // Default image is read_only
1493 if (hasImageHandles)
1494 O << "\t.param .u64 .ptr .texref ";
1495 else
1496 O << "\t.param .texref ";
1497 CurrentFnSym->print(O, MAI);
1498 O << "_param_" << paramIndex;
1499 }
1500 } else {
1501 if (hasImageHandles)
1502 O << "\t.param .u64 .ptr .samplerref ";
1503 else
1504 O << "\t.param .samplerref ";
1505 CurrentFnSym->print(O, MAI);
1506 O << "_param_" << paramIndex;
1507 }
1508 continue;
1509 }
1510 }
1511
1512 auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
1513 paramIndex](Type *Ty) -> Align {
1514 Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
1515 MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
1516 return std::max(TypeAlign, ParamAlign.valueOrOne());
1517 };
1518
1519 if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
1520 if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
1521 // Just print .param .align <a> .b8 .param[size];
1522 // <a> = optimal alignment for the element type; always multiple of
1523 // PAL.getParamAlignment
1524 // size = typeallocsize of element type
1525 Align OptimalAlign = getOptimalAlignForParam(Ty);
1526
1527 O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1528 printParamName(I, paramIndex, O);
1529 O << "[" << DL.getTypeAllocSize(Ty) << "]";
1530
1531 continue;
1532 }
1533 // Just a scalar
1534 auto *PTy = dyn_cast<PointerType>(Ty);
1535 unsigned PTySizeInBits = 0;
1536 if (PTy) {
1537 PTySizeInBits =
1538 TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();
1539 assert(PTySizeInBits && "Invalid pointer size");
1540 }
1541
1542 if (isKernelFunc) {
1543 if (PTy) {
1544 // Special handling for pointer arguments to kernel
1545 O << "\t.param .u" << PTySizeInBits << " ";
1546
1547 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
1548 NVPTX::CUDA) {
1549 int addrSpace = PTy->getAddressSpace();
1550 switch (addrSpace) {
1551 default:
1552 O << ".ptr ";
1553 break;
1554 case ADDRESS_SPACE_CONST:
1555 O << ".ptr .const ";
1556 break;
1557 case ADDRESS_SPACE_SHARED:
1558 O << ".ptr .shared ";
1559 break;
1560 case ADDRESS_SPACE_GLOBAL:
1561 O << ".ptr .global ";
1562 break;
1563 }
1564 Align ParamAlign = I->getParamAlign().valueOrOne();
1565 O << ".align " << ParamAlign.value() << " ";
1566 }
1567 printParamName(I, paramIndex, O);
1568 continue;
1569 }
1570
1571 // non-pointer scalar to kernel func
1572 O << "\t.param .";
1573 // Special case: predicate operands become .u8 types
1574 if (Ty->isIntegerTy(1))
1575 O << "u8";
1576 else
1577 O << getPTXFundamentalTypeStr(Ty);
1578 O << " ";
1579 printParamName(I, paramIndex, O);
1580 continue;
1581 }
1582 // Non-kernel function, just print .param .b<size> for ABI
1583 // and .reg .b<size> for non-ABI
1584 unsigned sz = 0;
1585 if (isa<IntegerType>(Ty)) {
1586 sz = cast<IntegerType>(Ty)->getBitWidth();
1587 sz = promoteScalarArgumentSize(sz);
1588 } else if (PTy) {
1589 assert(PTySizeInBits && "Invalid pointer size");
1590 sz = PTySizeInBits;
1591 } else if (Ty->isHalfTy())
1592 // PTX ABI requires all scalar parameters to be at least 32
1593 // bits in size. fp16 normally uses .b16 as its storage type
1594 // in PTX, so its size must be adjusted here, too.
1595 sz = 32;
1596 else
1597 sz = Ty->getPrimitiveSizeInBits();
1598 if (isABI)
1599 O << "\t.param .b" << sz << " ";
1600 else
1601 O << "\t.reg .b" << sz << " ";
1602 printParamName(I, paramIndex, O);
1603 continue;
1604 }
1605
1606 // param has byVal attribute.
1607 Type *ETy = PAL.getParamByValType(paramIndex);
1608 assert(ETy && "Param should have byval type");
1609
1610 if (isABI || isKernelFunc) {
1611 // Just print .param .align <a> .b8 .param[size];
1612 // <a> = optimal alignment for the element type; always multiple of
1613 // PAL.getParamAlignment
1614 // size = typeallocsize of element type
1615 Align OptimalAlign =
1616 isKernelFunc
1617 ? getOptimalAlignForParam(ETy)
1618 : TLI->getFunctionByValParamAlign(
1619 F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL);
1620
1621 unsigned sz = DL.getTypeAllocSize(ETy);
1622 O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1623 printParamName(I, paramIndex, O);
1624 O << "[" << sz << "]";
1625 continue;
1626 } else {
1627 // Split the ETy into constituent parts and
1628 // print .param .b<size> <name> for each part.
1629 // Further, if a part is vector, print the above for
1630 // each vector element.
1631 SmallVector<EVT, 16> vtparts;
1632 ComputeValueVTs(*TLI, DL, ETy, vtparts);
1633 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
1634 unsigned elems = 1;
1635 EVT elemtype = vtparts[i];
1636 if (vtparts[i].isVector()) {
1637 elems = vtparts[i].getVectorNumElements();
1638 elemtype = vtparts[i].getVectorElementType();
1639 }
1640
1641 for (unsigned j = 0, je = elems; j != je; ++j) {
1642 unsigned sz = elemtype.getSizeInBits();
1643 if (elemtype.isInteger())
1644 sz = promoteScalarArgumentSize(sz);
1645 O << "\t.reg .b" << sz << " ";
1646 printParamName(I, paramIndex, O);
1647 if (j < je - 1)
1648 O << ",\n";
1649 ++paramIndex;
1650 }
1651 if (i < e - 1)
1652 O << ",\n";
1653 }
1654 --paramIndex;
1655 continue;
1656 }
1657 }
1658
1659 if (F->isVarArg()) {
1660 if (!first)
1661 O << ",\n";
1662 O << "\t.param .align " << STI.getMaxRequiredAlignment();
1663 O << " .b8 ";
1664 getSymbol(F)->print(O, MAI);
1665 O << "_vararg[]";
1666 }
1667
1668 O << "\n)\n";
1669 }
1670
1671 void NVPTXAsmPrinter::emitFunctionParamList(const MachineFunction &MF,
1672 raw_ostream &O) {
1673 const Function &F = MF.getFunction();
1674 emitFunctionParamList(&F, O);
1675 }
1676
1677 void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1678 const MachineFunction &MF) {
1679 SmallString<128> Str;
1680 raw_svector_ostream O(Str);
1681
1682 // Map the global virtual register number to a register class specific
1683 // virtual register number starting from 1 with that class.
1684 const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
1685 //unsigned numRegClasses = TRI->getNumRegClasses();
1686
1687 // Emit the Fake Stack Object
1688 const MachineFrameInfo &MFI = MF.getFrameInfo();
1689 int NumBytes = (int) MFI.getStackSize();
1690 if (NumBytes) {
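// Reserve the byte array that backs the local stack frame and declare the
// %SP / %SPL stack-pointer registers used by frame lowering.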
1691 O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
1692 << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
1693 if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
1694 O << "\t.reg .b64 \t%SP;\n";
1695 O << "\t.reg .b64 \t%SPL;\n";
1696 } else {
1697 O << "\t.reg .b32 \t%SP;\n";
1698 O << "\t.reg .b32 \t%SPL;\n";
1699 }
1700 }
1701
1702 // Go through all virtual registers to establish the mapping between the
1703 // global virtual register number and the per-class virtual register
1704 // number. We use the per-class virtual register number in the PTX
1705 // output.
1706 unsigned int numVRs = MRI->getNumVirtRegs();
1707 for (unsigned i = 0; i < numVRs; i++) {
1708 Register vr = Register::index2VirtReg(i);
1709 const TargetRegisterClass *RC = MRI->getRegClass(vr);
1710 DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1711 int n = regmap.size();
1712 regmap.insert(std::make_pair(vr, n + 1));
1713 }
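// At this point every virtual register has a per-class index starting at 1;
// e.g. the first three vregs of a 32-bit class print as %r1, %r2, %r3 while
// the first vreg of a 64-bit class prints as %rd1 (prefixes come from
// getNVPTXRegClassStr below).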
1714
1715 // Emit register declarations
1716 // @TODO: Extract out the real register usage
1717 // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
1718 // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
1719 // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
1720 // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
1721 // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
1722 // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
1723 // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
1724
1725 // Emit declaration of the virtual registers or 'physical' registers for
1726 // each register class
1727 for (unsigned i = 0; i < TRI->getNumRegClasses(); i++) {
1728 const TargetRegisterClass *RC = TRI->getRegClass(i);
1729 DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1730 std::string rcname = getNVPTXRegClassName(RC);
1731 std::string rcStr = getNVPTXRegClassStr(RC);
1732 int n = regmap.size();
1733
1734 // Only declare those registers that may be used.
1735 if (n) {
1736 O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
1737 << ">;\n";
1738 }
1739 }
1740
1741 OutStreamer->emitRawText(O.str());
1742 }
1743
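// Floating-point constants print as PTX hex literals: e.g. float 1.0 becomes
// 0f3F800000 and double 1.0 becomes 0d3FF0000000000000.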
1744 void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
1745 APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1746 bool ignored;
1747 unsigned int numHex;
1748 const char *lead;
1749
1750 if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1751 numHex = 8;
1752 lead = "0f";
1753 APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
1754 } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1755 numHex = 16;
1756 lead = "0d";
1757 APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);
1758 } else
1759 llvm_unreachable("unsupported fp type");
1760
1761 APInt API = APF.bitcastToAPInt();
1762 O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);
1763 }
1764
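// Scalar initializers print as the decimal value for integers, a hex literal
// for FP constants, 0 for null pointers, and the symbol name for global
// addresses (wrapped in generic(...) when generic addressing applies).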
1765 void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1766 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1767 O << CI->getValue();
1768 return;
1769 }
1770 if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
1771 printFPConstant(CFP, O);
1772 return;
1773 }
1774 if (isa<ConstantPointerNull>(CPV)) {
1775 O << "0";
1776 return;
1777 }
1778 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1779 bool IsNonGenericPointer = false;
1780 if (GVar->getType()->getAddressSpace() != 0) {
1781 IsNonGenericPointer = true;
1782 }
1783 if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
1784 O << "generic(";
1785 getSymbol(GVar)->print(O, MAI);
1786 O << ")";
1787 } else {
1788 getSymbol(GVar)->print(O, MAI);
1789 }
1790 return;
1791 }
1792 if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1793 const MCExpr *E = lowerConstantForGV(cast<Constant>(Cexpr), false);
1794 printMCExpr(*E, O);
1795 return;
1796 }
1797 llvm_unreachable("Not scalar type found in printScalarConstant()");
1798 }
1799
1800 void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
1801 AggBuffer *AggBuffer) {
1802 const DataLayout &DL = getDataLayout();
1803 int AllocSize = DL.getTypeAllocSize(CPV->getType());
1804 if (isa<UndefValue>(CPV) || CPV->isNullValue()) {
1805 // If Bytes is non-zero, zero-fill that many bytes; otherwise zero-fill only
1806 // the space allocated for CPV's type.
1807 AggBuffer->addZeros(Bytes ? Bytes : AllocSize);
1808 return;
1809 }
1810
1811 // Helper for filling AggBuffer with APInts.
1812 auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
1813 size_t NumBytes = (Val.getBitWidth() + 7) / 8;
1814 SmallVector<unsigned char, 16> Buf(NumBytes);
1815 for (unsigned I = 0; I < NumBytes; ++I) {
1816 Buf[I] = Val.extractBitsAsZExtValue(8, I * 8);
1817 }
1818 AggBuffer->addBytes(Buf.data(), NumBytes, Bytes);
1819 };
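// For example, an i32 constant 0x01020304 is buffered little-endian as the
// byte sequence 04 03 02 01.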
1820
1821 switch (CPV->getType()->getTypeID()) {
1822 case Type::IntegerTyID:
1823 if (const auto *CI = dyn_cast<ConstantInt>(CPV)) {
1824 AddIntToBuffer(CI->getValue());
1825 break;
1826 }
1827 if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1828 if (const auto *CI =
1829 dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) {
1830 AddIntToBuffer(CI->getValue());
1831 break;
1832 }
1833 if (Cexpr->getOpcode() == Instruction::PtrToInt) {
1834 Value *V = Cexpr->getOperand(0)->stripPointerCasts();
1835 AggBuffer->addSymbol(V, Cexpr->getOperand(0));
1836 AggBuffer->addZeros(AllocSize);
1837 break;
1838 }
1839 }
1840 llvm_unreachable("unsupported integer const type");
1841 break;
1842
1843 case Type::HalfTyID:
1844 case Type::BFloatTyID:
1845 case Type::FloatTyID:
1846 case Type::DoubleTyID:
1847 AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt());
1848 break;
1849
1850 case Type::PointerTyID: {
1851 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1852 AggBuffer->addSymbol(GVar, GVar);
1853 } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1854 const Value *v = Cexpr->stripPointerCasts();
1855 AggBuffer->addSymbol(v, Cexpr);
1856 }
1857 AggBuffer->addZeros(AllocSize);
1858 break;
1859 }
1860
1861 case Type::ArrayTyID:
1862 case Type::FixedVectorTyID:
1863 case Type::StructTyID: {
1864 if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) {
1865 bufferAggregateConstant(CPV, AggBuffer);
1866 if (Bytes > AllocSize)
1867 AggBuffer->addZeros(Bytes - AllocSize);
1868 } else if (isa<ConstantAggregateZero>(CPV))
1869 AggBuffer->addZeros(Bytes);
1870 else
1871 llvm_unreachable("Unexpected Constant type");
1872 break;
1873 }
1874
1875 default:
1876 llvm_unreachable("unsupported type");
1877 }
1878 }
1879
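// Lays out an aggregate initializer byte by byte. Struct fields are emitted
// with enough trailing bytes to reach the next field's offset, so e.g. a
// naturally aligned { i8, i32 } gets three zero bytes of padding after the
// i8 element.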
1880 void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
1881 AggBuffer *aggBuffer) {
1882 const DataLayout &DL = getDataLayout();
1883 int Bytes;
1884
1885 // Integers of arbitrary width
1886 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1887 APInt Val = CI->getValue();
1888 for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
1889 uint8_t Byte = Val.getLoBits(8).getZExtValue();
1890 aggBuffer->addBytes(&Byte, 1, 1);
1891 Val.lshrInPlace(8);
1892 }
1893 return;
1894 }
1895
1896 // Old constants
1897 if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
1898 if (CPV->getNumOperands())
1899 for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
1900 bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
1901 return;
1902 }
1903
1904 if (const ConstantDataSequential *CDS =
1905 dyn_cast<ConstantDataSequential>(CPV)) {
1906 if (CDS->getNumElements())
1907 for (unsigned i = 0; i < CDS->getNumElements(); ++i)
1908 bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
1909 aggBuffer);
1910 return;
1911 }
1912
1913 if (isa<ConstantStruct>(CPV)) {
1914 if (CPV->getNumOperands()) {
1915 StructType *ST = cast<StructType>(CPV->getType());
1916 for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
1917 if (i == (e - 1))
1918 Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
1919 DL.getTypeAllocSize(ST) -
1920 DL.getStructLayout(ST)->getElementOffset(i);
1921 else
1922 Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
1923 DL.getStructLayout(ST)->getElementOffset(i);
1924 bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
1925 }
1926 }
1927 return;
1928 }
1929 llvm_unreachable("unsupported constant type in bufferAggregateConstant()");
1930 }
1931
1932 /// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly
1933 /// a copy from AsmPrinter::lowerConstant, except customized to only handle
1934 /// expressions that are representable in PTX and create
1935 /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
1936 const MCExpr *
1937 NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
1938 MCContext &Ctx = OutContext;
1939
1940 if (CV->isNullValue() || isa<UndefValue>(CV))
1941 return MCConstantExpr::create(0, Ctx);
1942
1943 if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
1944 return MCConstantExpr::create(CI->getZExtValue(), Ctx);
1945
1946 if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
1947 const MCSymbolRefExpr *Expr =
1948 MCSymbolRefExpr::create(getSymbol(GV), Ctx);
1949 if (ProcessingGeneric) {
1950 return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);
1951 } else {
1952 return Expr;
1953 }
1954 }
1955
1956 const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
1957 if (!CE) {
1958 llvm_unreachable("Unknown constant value to lower!");
1959 }
1960
1961 switch (CE->getOpcode()) {
1962 default: {
1963 // If the code isn't optimized, there may be outstanding folding
1964 // opportunities. Attempt to fold the expression using DataLayout as a
1965 // last resort before giving up.
1966 Constant *C = ConstantFoldConstant(CE, getDataLayout());
1967 if (C != CE)
1968 return lowerConstantForGV(C, ProcessingGeneric);
1969
1970 // Otherwise report the problem to the user.
1971 std::string S;
1972 raw_string_ostream OS(S);
1973 OS << "Unsupported expression in static initializer: ";
1974 CE->printAsOperand(OS, /*PrintType=*/false,
1975 !MF ? nullptr : MF->getFunction().getParent());
1976 report_fatal_error(Twine(OS.str()));
1977 }
1978
1979 case Instruction::AddrSpaceCast: {
1980 // Strip the addrspacecast and pass along the operand
1981 PointerType *DstTy = cast<PointerType>(CE->getType());
1982 if (DstTy->getAddressSpace() == 0) {
1983 return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);
1984 }
1985 std::string S;
1986 raw_string_ostream OS(S);
1987 OS << "Unsupported expression in static initializer: ";
1988 CE->printAsOperand(OS, /*PrintType=*/ false,
1989 !MF ? nullptr : MF->getFunction().getParent());
1990 report_fatal_error(Twine(OS.str()));
1991 }
1992
1993 case Instruction::GetElementPtr: {
1994 const DataLayout &DL = getDataLayout();
1995
1996 // Generate a symbolic expression for the byte address
1997 APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
1998 cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);
1999
2000 const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),
2001 ProcessingGeneric);
2002 if (!OffsetAI)
2003 return Base;
2004
2005 int64_t Offset = OffsetAI.getSExtValue();
2006 return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx),
2007 Ctx);
2008 }
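// For example, a constant GEP selecting element 2 of a [4 x i32] global @g
// lowers to the expression g+8.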
2009
2010 case Instruction::Trunc:
2011 // We emit the value and depend on the assembler to truncate the generated
2012 // expression properly. This is important for differences between
2013 // blockaddress labels. Since the two labels are in the same function, it
2014 // is reasonable to treat their delta as a 32-bit value.
2015 [[fallthrough]];
2016 case Instruction::BitCast:
2017 return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2018
2019 case Instruction::IntToPtr: {
2020 const DataLayout &DL = getDataLayout();
2021
2022 // Handle casts to pointers by changing them into casts to the appropriate
2023 // integer type. This promotes constant folding and simplifies this code.
2024 Constant *Op = CE->getOperand(0);
2025 Op = ConstantExpr::getIntegerCast(Op, DL.getIntPtrType(CV->getType()),
2026 /*isSigned=*/false);
2027 return lowerConstantForGV(Op, ProcessingGeneric);
2028 }
2029
2030 case Instruction::PtrToInt: {
2031 const DataLayout &DL = getDataLayout();
2032
2033 // Support only foldable casts to/from pointers that can be eliminated by
2034 // changing the pointer to the appropriately sized integer type.
2035 Constant *Op = CE->getOperand(0);
2036 Type *Ty = CE->getType();
2037
2038 const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);
2039
2040 // We can emit the pointer value into this slot if the slot is an
2041 // integer slot equal to the size of the pointer.
2042 if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
2043 return OpExpr;
2044
2045 // Otherwise, the pointer is smaller than the resultant integer; mask off
2046 // the high bits so we are sure to get a proper truncation if the input is
2047 // a constant expr.
2048 unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
2049 const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx);
2050 return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx);
2051 }
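// For example, a ptrtoint of a 32-bit pointer into an i64 slot is emitted as
// (expr & 0xffffffff) so the assembler sees a properly zero-extended value;
// equal-sized casts pass the pointer expression through unchanged.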
2052
2053 // The MC library also has a right-shift operator, but it isn't consistently
2054 // signed or unsigned between different targets.
2055 case Instruction::Add: {
2056 const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2057 const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
2058 switch (CE->getOpcode()) {
2059 default: llvm_unreachable("Unknown binary operator constant cast expr");
2060 case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
2061 }
2062 }
2063 }
2064 }
2065
2066 // Copy of MCExpr::print customized for NVPTX
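// Expressions print in a PTX-friendly infix form: e.g. a symbol plus a
// constant prints as sym+8, and a negative addend folds into sym-8 rather
// than sym+-8.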
2067 void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
2068 switch (Expr.getKind()) {
2069 case MCExpr::Target:
2070 return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);
2071 case MCExpr::Constant:
2072 OS << cast<MCConstantExpr>(Expr).getValue();
2073 return;
2074
2075 case MCExpr::SymbolRef: {
2076 const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
2077 const MCSymbol &Sym = SRE.getSymbol();
2078 Sym.print(OS, MAI);
2079 return;
2080 }
2081
2082 case MCExpr::Unary: {
2083 const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
2084 switch (UE.getOpcode()) {
2085 case MCUnaryExpr::LNot: OS << '!'; break;
2086 case MCUnaryExpr::Minus: OS << '-'; break;
2087 case MCUnaryExpr::Not: OS << '~'; break;
2088 case MCUnaryExpr::Plus: OS << '+'; break;
2089 }
2090 printMCExpr(*UE.getSubExpr(), OS);
2091 return;
2092 }
2093
2094 case MCExpr::Binary: {
2095 const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);
2096
2097 // Only print parens around the LHS if it is non-trivial.
2098 if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
2099 isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
2100 printMCExpr(*BE.getLHS(), OS);
2101 } else {
2102 OS << '(';
2103 printMCExpr(*BE.getLHS(), OS);
2104 OS << ')';
2105 }
2106
2107 switch (BE.getOpcode()) {
2108 case MCBinaryExpr::Add:
2109 // Print "X-42" instead of "X+-42".
2110 if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
2111 if (RHSC->getValue() < 0) {
2112 OS << RHSC->getValue();
2113 return;
2114 }
2115 }
2116
2117 OS << '+';
2118 break;
2119 default: llvm_unreachable("Unhandled binary operator");
2120 }
2121
2122 // Only print parens around the RHS if it is non-trivial.
2123 if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
2124 printMCExpr(*BE.getRHS(), OS);
2125 } else {
2126 OS << '(';
2127 printMCExpr(*BE.getRHS(), OS);
2128 OS << ')';
2129 }
2130 return;
2131 }
2132 }
2133
2134 llvm_unreachable("Invalid expression kind!");
2135 }
2136
2137 /// PrintAsmOperand - Print out an operand for an inline asm expression.
2138 ///
2139 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
2140 const char *ExtraCode, raw_ostream &O) {
2141 if (ExtraCode && ExtraCode[0]) {
2142 if (ExtraCode[1] != 0)
2143 return true; // Unknown modifier.
2144
2145 switch (ExtraCode[0]) {
2146 default:
2147 // See if this is a generic print operand
2148 return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O);
2149 case 'r':
2150 break;
2151 }
2152 }
2153
2154 printOperand(MI, OpNo, O);
2155
2156 return false;
2157 }
2158
2159 bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
2160 unsigned OpNo,
2161 const char *ExtraCode,
2162 raw_ostream &O) {
2163 if (ExtraCode && ExtraCode[0])
2164 return true; // Unknown modifier
2165
2166 O << '[';
2167 printMemOperand(MI, OpNo, O);
2168 O << ']';
2169
2170 return false;
2171 }
2172
2173 void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum,
2174 raw_ostream &O) {
2175 const MachineOperand &MO = MI->getOperand(opNum);
2176 switch (MO.getType()) {
2177 case MachineOperand::MO_Register:
2178 if (MO.getReg().isPhysical()) {
2179 if (MO.getReg() == NVPTX::VRDepot)
2180 O << DEPOTNAME << getFunctionNumber();
2181 else
2182 O << NVPTXInstPrinter::getRegisterName(MO.getReg());
2183 } else {
2184 emitVirtualRegister(MO.getReg(), O);
2185 }
2186 break;
2187
2188 case MachineOperand::MO_Immediate:
2189 O << MO.getImm();
2190 break;
2191
2192 case MachineOperand::MO_FPImmediate:
2193 printFPConstant(MO.getFPImm(), O);
2194 break;
2195
2196 case MachineOperand::MO_GlobalAddress:
2197 PrintSymbolOperand(MO, O);
2198 break;
2199
2200 case MachineOperand::MO_MachineBasicBlock:
2201 MO.getMBB()->getSymbol()->print(O, MAI);
2202 break;
2203
2204 default:
2205 llvm_unreachable("Operand type not supported.");
2206 }
2207 }
2208
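// Memory operands print as a base plus an optional offset, e.g. %rd1+4; a
// zero immediate offset is omitted, and the "add" modifier prints
// "base, offset" instead of "base+offset".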
2209 void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum,
2210 raw_ostream &O, const char *Modifier) {
2211 printOperand(MI, opNum, O);
2212
2213 if (Modifier && strcmp(Modifier, "add") == 0) {
2214 O << ", ";
2215 printOperand(MI, opNum + 1, O);
2216 } else {
2217 if (MI->getOperand(opNum + 1).isImm() &&
2218 MI->getOperand(opNum + 1).getImm() == 0)
2219 return; // don't print ',0' or '+0'
2220 O << "+";
2221 printOperand(MI, opNum + 1, O);
2222 }
2223 }
2224
2225 // Force static initialization.
2226 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() {
2227 RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
2228 RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
2229 }
2230