//===---- CGOpenMPRuntimeNVPTX.cpp - Interface to OpenMP NVPTX Runtimes ---===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This provides a class for OpenMP runtime code generation specialized to NVPTX
// targets.
//
//===----------------------------------------------------------------------===//
14*67e74705SXin Li
15*67e74705SXin Li #include "CGOpenMPRuntimeNVPTX.h"
16*67e74705SXin Li #include "clang/AST/DeclOpenMP.h"
17*67e74705SXin Li #include "CodeGenFunction.h"
18*67e74705SXin Li #include "clang/AST/StmtOpenMP.h"
19*67e74705SXin Li
20*67e74705SXin Li using namespace clang;
21*67e74705SXin Li using namespace CodeGen;
22*67e74705SXin Li
23*67e74705SXin Li /// \brief Get the GPU warp size.
getNVPTXWarpSize(CodeGenFunction & CGF)24*67e74705SXin Li llvm::Value *CGOpenMPRuntimeNVPTX::getNVPTXWarpSize(CodeGenFunction &CGF) {
25*67e74705SXin Li CGBuilderTy &Bld = CGF.Builder;
26*67e74705SXin Li return Bld.CreateCall(
27*67e74705SXin Li llvm::Intrinsic::getDeclaration(
28*67e74705SXin Li &CGM.getModule(), llvm::Intrinsic::nvvm_read_ptx_sreg_warpsize),
29*67e74705SXin Li llvm::None, "nvptx_warp_size");
30*67e74705SXin Li }
31*67e74705SXin Li
32*67e74705SXin Li /// \brief Get the id of the current thread on the GPU.
getNVPTXThreadID(CodeGenFunction & CGF)33*67e74705SXin Li llvm::Value *CGOpenMPRuntimeNVPTX::getNVPTXThreadID(CodeGenFunction &CGF) {
34*67e74705SXin Li CGBuilderTy &Bld = CGF.Builder;
35*67e74705SXin Li return Bld.CreateCall(
36*67e74705SXin Li llvm::Intrinsic::getDeclaration(
37*67e74705SXin Li &CGM.getModule(), llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x),
38*67e74705SXin Li llvm::None, "nvptx_tid");
39*67e74705SXin Li }
40*67e74705SXin Li
41*67e74705SXin Li // \brief Get the maximum number of threads in a block of the GPU.
getNVPTXNumThreads(CodeGenFunction & CGF)42*67e74705SXin Li llvm::Value *CGOpenMPRuntimeNVPTX::getNVPTXNumThreads(CodeGenFunction &CGF) {
43*67e74705SXin Li CGBuilderTy &Bld = CGF.Builder;
44*67e74705SXin Li return Bld.CreateCall(
45*67e74705SXin Li llvm::Intrinsic::getDeclaration(
46*67e74705SXin Li &CGM.getModule(), llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x),
47*67e74705SXin Li llvm::None, "nvptx_num_threads");
48*67e74705SXin Li }
49*67e74705SXin Li
50*67e74705SXin Li /// \brief Get barrier to synchronize all threads in a block.
getNVPTXCTABarrier(CodeGenFunction & CGF)51*67e74705SXin Li void CGOpenMPRuntimeNVPTX::getNVPTXCTABarrier(CodeGenFunction &CGF) {
52*67e74705SXin Li CGBuilderTy &Bld = CGF.Builder;
53*67e74705SXin Li Bld.CreateCall(llvm::Intrinsic::getDeclaration(
54*67e74705SXin Li &CGM.getModule(), llvm::Intrinsic::nvvm_barrier0));
55*67e74705SXin Li }
56*67e74705SXin Li
57*67e74705SXin Li // \brief Synchronize all GPU threads in a block.
syncCTAThreads(CodeGenFunction & CGF)58*67e74705SXin Li void CGOpenMPRuntimeNVPTX::syncCTAThreads(CodeGenFunction &CGF) {
59*67e74705SXin Li getNVPTXCTABarrier(CGF);
60*67e74705SXin Li }
61*67e74705SXin Li
62*67e74705SXin Li /// \brief Get the thread id of the OMP master thread.
63*67e74705SXin Li /// The master thread id is the first thread (lane) of the last warp in the
64*67e74705SXin Li /// GPU block. Warp size is assumed to be some power of 2.
65*67e74705SXin Li /// Thread id is 0 indexed.
66*67e74705SXin Li /// E.g: If NumThreads is 33, master id is 32.
67*67e74705SXin Li /// If NumThreads is 64, master id is 32.
68*67e74705SXin Li /// If NumThreads is 1024, master id is 992.
getMasterThreadID(CodeGenFunction & CGF)69*67e74705SXin Li llvm::Value *CGOpenMPRuntimeNVPTX::getMasterThreadID(CodeGenFunction &CGF) {
70*67e74705SXin Li CGBuilderTy &Bld = CGF.Builder;
71*67e74705SXin Li llvm::Value *NumThreads = getNVPTXNumThreads(CGF);
72*67e74705SXin Li
73*67e74705SXin Li // We assume that the warp size is a power of 2.
74*67e74705SXin Li llvm::Value *Mask = Bld.CreateSub(getNVPTXWarpSize(CGF), Bld.getInt32(1));
75*67e74705SXin Li
76*67e74705SXin Li return Bld.CreateAnd(Bld.CreateSub(NumThreads, Bld.getInt32(1)),
77*67e74705SXin Li Bld.CreateNot(Mask), "master_tid");
78*67e74705SXin Li }
79*67e74705SXin Li
namespace {

/// Identifiers of the NVPTX-specific OpenMP runtime library entry points this
/// file knows how to declare (see createNVPTXRuntimeFunction).
enum OpenMPRTLFunctionNVPTX {
  /// \brief Call to void __kmpc_kernel_init(kmp_int32 omp_handle,
  /// kmp_int32 thread_limit);
  OMPRTL_NVPTX__kmpc_kernel_init,
};

/// Numeric NVPTX address spaces used when creating global variables.
enum ADDRESS_SPACE {
  /// Address space number used for the shared-memory control variables.
  ADDRESS_SPACE_SHARED = 3,
};

} // namespace
92*67e74705SXin Li
WorkerFunctionState(CodeGenModule & CGM)93*67e74705SXin Li CGOpenMPRuntimeNVPTX::WorkerFunctionState::WorkerFunctionState(
94*67e74705SXin Li CodeGenModule &CGM)
95*67e74705SXin Li : WorkerFn(nullptr), CGFI(nullptr) {
96*67e74705SXin Li createWorkerFunction(CGM);
97*67e74705SXin Li }
98*67e74705SXin Li
createWorkerFunction(CodeGenModule & CGM)99*67e74705SXin Li void CGOpenMPRuntimeNVPTX::WorkerFunctionState::createWorkerFunction(
100*67e74705SXin Li CodeGenModule &CGM) {
101*67e74705SXin Li // Create an worker function with no arguments.
102*67e74705SXin Li CGFI = &CGM.getTypes().arrangeNullaryFunction();
103*67e74705SXin Li
104*67e74705SXin Li WorkerFn = llvm::Function::Create(
105*67e74705SXin Li CGM.getTypes().GetFunctionType(*CGFI), llvm::GlobalValue::InternalLinkage,
106*67e74705SXin Li /* placeholder */ "_worker", &CGM.getModule());
107*67e74705SXin Li CGM.SetInternalFunctionAttributes(/*D=*/nullptr, WorkerFn, *CGFI);
108*67e74705SXin Li WorkerFn->setLinkage(llvm::GlobalValue::InternalLinkage);
109*67e74705SXin Li WorkerFn->addFnAttr(llvm::Attribute::NoInline);
110*67e74705SXin Li }
111*67e74705SXin Li
initializeEnvironment()112*67e74705SXin Li void CGOpenMPRuntimeNVPTX::initializeEnvironment() {
113*67e74705SXin Li //
114*67e74705SXin Li // Initialize master-worker control state in shared memory.
115*67e74705SXin Li //
116*67e74705SXin Li
117*67e74705SXin Li auto DL = CGM.getDataLayout();
118*67e74705SXin Li ActiveWorkers = new llvm::GlobalVariable(
119*67e74705SXin Li CGM.getModule(), CGM.Int32Ty, /*isConstant=*/false,
120*67e74705SXin Li llvm::GlobalValue::CommonLinkage,
121*67e74705SXin Li llvm::Constant::getNullValue(CGM.Int32Ty), "__omp_num_threads", 0,
122*67e74705SXin Li llvm::GlobalVariable::NotThreadLocal, ADDRESS_SPACE_SHARED);
123*67e74705SXin Li ActiveWorkers->setAlignment(DL.getPrefTypeAlignment(CGM.Int32Ty));
124*67e74705SXin Li
125*67e74705SXin Li WorkID = new llvm::GlobalVariable(
126*67e74705SXin Li CGM.getModule(), CGM.Int64Ty, /*isConstant=*/false,
127*67e74705SXin Li llvm::GlobalValue::CommonLinkage,
128*67e74705SXin Li llvm::Constant::getNullValue(CGM.Int64Ty), "__tgt_work_id", 0,
129*67e74705SXin Li llvm::GlobalVariable::NotThreadLocal, ADDRESS_SPACE_SHARED);
130*67e74705SXin Li WorkID->setAlignment(DL.getPrefTypeAlignment(CGM.Int64Ty));
131*67e74705SXin Li }
132*67e74705SXin Li
emitWorkerFunction(WorkerFunctionState & WST)133*67e74705SXin Li void CGOpenMPRuntimeNVPTX::emitWorkerFunction(WorkerFunctionState &WST) {
134*67e74705SXin Li auto &Ctx = CGM.getContext();
135*67e74705SXin Li
136*67e74705SXin Li CodeGenFunction CGF(CGM, /*suppressNewContext=*/true);
137*67e74705SXin Li CGF.StartFunction(GlobalDecl(), Ctx.VoidTy, WST.WorkerFn, *WST.CGFI, {});
138*67e74705SXin Li emitWorkerLoop(CGF, WST);
139*67e74705SXin Li CGF.FinishFunction();
140*67e74705SXin Li }
141*67e74705SXin Li
emitWorkerLoop(CodeGenFunction & CGF,WorkerFunctionState & WST)142*67e74705SXin Li void CGOpenMPRuntimeNVPTX::emitWorkerLoop(CodeGenFunction &CGF,
143*67e74705SXin Li WorkerFunctionState &WST) {
144*67e74705SXin Li //
145*67e74705SXin Li // The workers enter this loop and wait for parallel work from the master.
146*67e74705SXin Li // When the master encounters a parallel region it sets up the work + variable
147*67e74705SXin Li // arguments, and wakes up the workers. The workers first check to see if
148*67e74705SXin Li // they are required for the parallel region, i.e., within the # of requested
149*67e74705SXin Li // parallel threads. The activated workers load the variable arguments and
150*67e74705SXin Li // execute the parallel work.
151*67e74705SXin Li //
152*67e74705SXin Li
153*67e74705SXin Li CGBuilderTy &Bld = CGF.Builder;
154*67e74705SXin Li
155*67e74705SXin Li llvm::BasicBlock *AwaitBB = CGF.createBasicBlock(".await.work");
156*67e74705SXin Li llvm::BasicBlock *SelectWorkersBB = CGF.createBasicBlock(".select.workers");
157*67e74705SXin Li llvm::BasicBlock *ExecuteBB = CGF.createBasicBlock(".execute.parallel");
158*67e74705SXin Li llvm::BasicBlock *TerminateBB = CGF.createBasicBlock(".terminate.parallel");
159*67e74705SXin Li llvm::BasicBlock *BarrierBB = CGF.createBasicBlock(".barrier.parallel");
160*67e74705SXin Li llvm::BasicBlock *ExitBB = CGF.createBasicBlock(".exit");
161*67e74705SXin Li
162*67e74705SXin Li CGF.EmitBranch(AwaitBB);
163*67e74705SXin Li
164*67e74705SXin Li // Workers wait for work from master.
165*67e74705SXin Li CGF.EmitBlock(AwaitBB);
166*67e74705SXin Li // Wait for parallel work
167*67e74705SXin Li syncCTAThreads(CGF);
168*67e74705SXin Li // On termination condition (workid == 0), exit loop.
169*67e74705SXin Li llvm::Value *ShouldTerminate = Bld.CreateICmpEQ(
170*67e74705SXin Li Bld.CreateAlignedLoad(WorkID, WorkID->getAlignment()),
171*67e74705SXin Li llvm::Constant::getNullValue(WorkID->getType()->getElementType()),
172*67e74705SXin Li "should_terminate");
173*67e74705SXin Li Bld.CreateCondBr(ShouldTerminate, ExitBB, SelectWorkersBB);
174*67e74705SXin Li
175*67e74705SXin Li // Activate requested workers.
176*67e74705SXin Li CGF.EmitBlock(SelectWorkersBB);
177*67e74705SXin Li llvm::Value *ThreadID = getNVPTXThreadID(CGF);
178*67e74705SXin Li llvm::Value *ActiveThread = Bld.CreateICmpSLT(
179*67e74705SXin Li ThreadID,
180*67e74705SXin Li Bld.CreateAlignedLoad(ActiveWorkers, ActiveWorkers->getAlignment()),
181*67e74705SXin Li "active_thread");
182*67e74705SXin Li Bld.CreateCondBr(ActiveThread, ExecuteBB, BarrierBB);
183*67e74705SXin Li
184*67e74705SXin Li // Signal start of parallel region.
185*67e74705SXin Li CGF.EmitBlock(ExecuteBB);
186*67e74705SXin Li // TODO: Add parallel work.
187*67e74705SXin Li
188*67e74705SXin Li // Signal end of parallel region.
189*67e74705SXin Li CGF.EmitBlock(TerminateBB);
190*67e74705SXin Li CGF.EmitBranch(BarrierBB);
191*67e74705SXin Li
192*67e74705SXin Li // All active and inactive workers wait at a barrier after parallel region.
193*67e74705SXin Li CGF.EmitBlock(BarrierBB);
194*67e74705SXin Li // Barrier after parallel region.
195*67e74705SXin Li syncCTAThreads(CGF);
196*67e74705SXin Li CGF.EmitBranch(AwaitBB);
197*67e74705SXin Li
198*67e74705SXin Li // Exit target region.
199*67e74705SXin Li CGF.EmitBlock(ExitBB);
200*67e74705SXin Li }
201*67e74705SXin Li
202*67e74705SXin Li // Setup NVPTX threads for master-worker OpenMP scheme.
emitEntryHeader(CodeGenFunction & CGF,EntryFunctionState & EST,WorkerFunctionState & WST)203*67e74705SXin Li void CGOpenMPRuntimeNVPTX::emitEntryHeader(CodeGenFunction &CGF,
204*67e74705SXin Li EntryFunctionState &EST,
205*67e74705SXin Li WorkerFunctionState &WST) {
206*67e74705SXin Li CGBuilderTy &Bld = CGF.Builder;
207*67e74705SXin Li
208*67e74705SXin Li // Get the master thread id.
209*67e74705SXin Li llvm::Value *MasterID = getMasterThreadID(CGF);
210*67e74705SXin Li // Current thread's identifier.
211*67e74705SXin Li llvm::Value *ThreadID = getNVPTXThreadID(CGF);
212*67e74705SXin Li
213*67e74705SXin Li // Setup BBs in entry function.
214*67e74705SXin Li llvm::BasicBlock *WorkerCheckBB = CGF.createBasicBlock(".check.for.worker");
215*67e74705SXin Li llvm::BasicBlock *WorkerBB = CGF.createBasicBlock(".worker");
216*67e74705SXin Li llvm::BasicBlock *MasterBB = CGF.createBasicBlock(".master");
217*67e74705SXin Li EST.ExitBB = CGF.createBasicBlock(".exit");
218*67e74705SXin Li
219*67e74705SXin Li // The head (master thread) marches on while its body of companion threads in
220*67e74705SXin Li // the warp go to sleep.
221*67e74705SXin Li llvm::Value *ShouldDie =
222*67e74705SXin Li Bld.CreateICmpUGT(ThreadID, MasterID, "excess_in_master_warp");
223*67e74705SXin Li Bld.CreateCondBr(ShouldDie, EST.ExitBB, WorkerCheckBB);
224*67e74705SXin Li
225*67e74705SXin Li // Select worker threads...
226*67e74705SXin Li CGF.EmitBlock(WorkerCheckBB);
227*67e74705SXin Li llvm::Value *IsWorker = Bld.CreateICmpULT(ThreadID, MasterID, "is_worker");
228*67e74705SXin Li Bld.CreateCondBr(IsWorker, WorkerBB, MasterBB);
229*67e74705SXin Li
230*67e74705SXin Li // ... and send to worker loop, awaiting parallel invocation.
231*67e74705SXin Li CGF.EmitBlock(WorkerBB);
232*67e74705SXin Li CGF.EmitCallOrInvoke(WST.WorkerFn, llvm::None);
233*67e74705SXin Li CGF.EmitBranch(EST.ExitBB);
234*67e74705SXin Li
235*67e74705SXin Li // Only master thread executes subsequent serial code.
236*67e74705SXin Li CGF.EmitBlock(MasterBB);
237*67e74705SXin Li
238*67e74705SXin Li // First action in sequential region:
239*67e74705SXin Li // Initialize the state of the OpenMP runtime library on the GPU.
240*67e74705SXin Li llvm::Value *Args[] = {Bld.getInt32(/*OmpHandle=*/0), getNVPTXThreadID(CGF)};
241*67e74705SXin Li CGF.EmitRuntimeCall(createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_init),
242*67e74705SXin Li Args);
243*67e74705SXin Li }
244*67e74705SXin Li
emitEntryFooter(CodeGenFunction & CGF,EntryFunctionState & EST)245*67e74705SXin Li void CGOpenMPRuntimeNVPTX::emitEntryFooter(CodeGenFunction &CGF,
246*67e74705SXin Li EntryFunctionState &EST) {
247*67e74705SXin Li CGBuilderTy &Bld = CGF.Builder;
248*67e74705SXin Li llvm::BasicBlock *TerminateBB = CGF.createBasicBlock(".termination.notifier");
249*67e74705SXin Li CGF.EmitBranch(TerminateBB);
250*67e74705SXin Li
251*67e74705SXin Li CGF.EmitBlock(TerminateBB);
252*67e74705SXin Li // Signal termination condition.
253*67e74705SXin Li Bld.CreateAlignedStore(
254*67e74705SXin Li llvm::Constant::getNullValue(WorkID->getType()->getElementType()), WorkID,
255*67e74705SXin Li WorkID->getAlignment());
256*67e74705SXin Li // Barrier to terminate worker threads.
257*67e74705SXin Li syncCTAThreads(CGF);
258*67e74705SXin Li // Master thread jumps to exit point.
259*67e74705SXin Li CGF.EmitBranch(EST.ExitBB);
260*67e74705SXin Li
261*67e74705SXin Li CGF.EmitBlock(EST.ExitBB);
262*67e74705SXin Li }
263*67e74705SXin Li
264*67e74705SXin Li /// \brief Returns specified OpenMP runtime function for the current OpenMP
265*67e74705SXin Li /// implementation. Specialized for the NVPTX device.
266*67e74705SXin Li /// \param Function OpenMP runtime function.
267*67e74705SXin Li /// \return Specified function.
268*67e74705SXin Li llvm::Constant *
createNVPTXRuntimeFunction(unsigned Function)269*67e74705SXin Li CGOpenMPRuntimeNVPTX::createNVPTXRuntimeFunction(unsigned Function) {
270*67e74705SXin Li llvm::Constant *RTLFn = nullptr;
271*67e74705SXin Li switch (static_cast<OpenMPRTLFunctionNVPTX>(Function)) {
272*67e74705SXin Li case OMPRTL_NVPTX__kmpc_kernel_init: {
273*67e74705SXin Li // Build void __kmpc_kernel_init(kmp_int32 omp_handle,
274*67e74705SXin Li // kmp_int32 thread_limit);
275*67e74705SXin Li llvm::Type *TypeParams[] = {CGM.Int32Ty, CGM.Int32Ty};
276*67e74705SXin Li llvm::FunctionType *FnTy =
277*67e74705SXin Li llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
278*67e74705SXin Li RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_kernel_init");
279*67e74705SXin Li break;
280*67e74705SXin Li }
281*67e74705SXin Li }
282*67e74705SXin Li return RTLFn;
283*67e74705SXin Li }
284*67e74705SXin Li
createOffloadEntry(llvm::Constant * ID,llvm::Constant * Addr,uint64_t Size)285*67e74705SXin Li void CGOpenMPRuntimeNVPTX::createOffloadEntry(llvm::Constant *ID,
286*67e74705SXin Li llvm::Constant *Addr,
287*67e74705SXin Li uint64_t Size) {
288*67e74705SXin Li auto *F = dyn_cast<llvm::Function>(Addr);
289*67e74705SXin Li // TODO: Add support for global variables on the device after declare target
290*67e74705SXin Li // support.
291*67e74705SXin Li if (!F)
292*67e74705SXin Li return;
293*67e74705SXin Li llvm::Module *M = F->getParent();
294*67e74705SXin Li llvm::LLVMContext &Ctx = M->getContext();
295*67e74705SXin Li
296*67e74705SXin Li // Get "nvvm.annotations" metadata node
297*67e74705SXin Li llvm::NamedMDNode *MD = M->getOrInsertNamedMetadata("nvvm.annotations");
298*67e74705SXin Li
299*67e74705SXin Li llvm::Metadata *MDVals[] = {
300*67e74705SXin Li llvm::ConstantAsMetadata::get(F), llvm::MDString::get(Ctx, "kernel"),
301*67e74705SXin Li llvm::ConstantAsMetadata::get(
302*67e74705SXin Li llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))};
303*67e74705SXin Li // Append metadata to nvvm.annotations
304*67e74705SXin Li MD->addOperand(llvm::MDNode::get(Ctx, MDVals));
305*67e74705SXin Li }
306*67e74705SXin Li
emitTargetOutlinedFunction(const OMPExecutableDirective & D,StringRef ParentName,llvm::Function * & OutlinedFn,llvm::Constant * & OutlinedFnID,bool IsOffloadEntry,const RegionCodeGenTy & CodeGen)307*67e74705SXin Li void CGOpenMPRuntimeNVPTX::emitTargetOutlinedFunction(
308*67e74705SXin Li const OMPExecutableDirective &D, StringRef ParentName,
309*67e74705SXin Li llvm::Function *&OutlinedFn, llvm::Constant *&OutlinedFnID,
310*67e74705SXin Li bool IsOffloadEntry, const RegionCodeGenTy &CodeGen) {
311*67e74705SXin Li if (!IsOffloadEntry) // Nothing to do.
312*67e74705SXin Li return;
313*67e74705SXin Li
314*67e74705SXin Li assert(!ParentName.empty() && "Invalid target region parent name!");
315*67e74705SXin Li
316*67e74705SXin Li EntryFunctionState EST;
317*67e74705SXin Li WorkerFunctionState WST(CGM);
318*67e74705SXin Li
319*67e74705SXin Li // Emit target region as a standalone region.
320*67e74705SXin Li class NVPTXPrePostActionTy : public PrePostActionTy {
321*67e74705SXin Li CGOpenMPRuntimeNVPTX &RT;
322*67e74705SXin Li CGOpenMPRuntimeNVPTX::EntryFunctionState &EST;
323*67e74705SXin Li CGOpenMPRuntimeNVPTX::WorkerFunctionState &WST;
324*67e74705SXin Li
325*67e74705SXin Li public:
326*67e74705SXin Li NVPTXPrePostActionTy(CGOpenMPRuntimeNVPTX &RT,
327*67e74705SXin Li CGOpenMPRuntimeNVPTX::EntryFunctionState &EST,
328*67e74705SXin Li CGOpenMPRuntimeNVPTX::WorkerFunctionState &WST)
329*67e74705SXin Li : RT(RT), EST(EST), WST(WST) {}
330*67e74705SXin Li void Enter(CodeGenFunction &CGF) override {
331*67e74705SXin Li RT.emitEntryHeader(CGF, EST, WST);
332*67e74705SXin Li }
333*67e74705SXin Li void Exit(CodeGenFunction &CGF) override { RT.emitEntryFooter(CGF, EST); }
334*67e74705SXin Li } Action(*this, EST, WST);
335*67e74705SXin Li CodeGen.setAction(Action);
336*67e74705SXin Li emitTargetOutlinedFunctionHelper(D, ParentName, OutlinedFn, OutlinedFnID,
337*67e74705SXin Li IsOffloadEntry, CodeGen);
338*67e74705SXin Li
339*67e74705SXin Li // Create the worker function
340*67e74705SXin Li emitWorkerFunction(WST);
341*67e74705SXin Li
342*67e74705SXin Li // Now change the name of the worker function to correspond to this target
343*67e74705SXin Li // region's entry function.
344*67e74705SXin Li WST.WorkerFn->setName(OutlinedFn->getName() + "_worker");
345*67e74705SXin Li }
346*67e74705SXin Li
CGOpenMPRuntimeNVPTX(CodeGenModule & CGM)347*67e74705SXin Li CGOpenMPRuntimeNVPTX::CGOpenMPRuntimeNVPTX(CodeGenModule &CGM)
348*67e74705SXin Li : CGOpenMPRuntime(CGM), ActiveWorkers(nullptr), WorkID(nullptr) {
349*67e74705SXin Li if (!CGM.getLangOpts().OpenMPIsDevice)
350*67e74705SXin Li llvm_unreachable("OpenMP NVPTX can only handle device code.");
351*67e74705SXin Li
352*67e74705SXin Li // Called once per module during initialization.
353*67e74705SXin Li initializeEnvironment();
354*67e74705SXin Li }
355*67e74705SXin Li
emitNumTeamsClause(CodeGenFunction & CGF,const Expr * NumTeams,const Expr * ThreadLimit,SourceLocation Loc)356*67e74705SXin Li void CGOpenMPRuntimeNVPTX::emitNumTeamsClause(CodeGenFunction &CGF,
357*67e74705SXin Li const Expr *NumTeams,
358*67e74705SXin Li const Expr *ThreadLimit,
359*67e74705SXin Li SourceLocation Loc) {}
360*67e74705SXin Li
emitParallelOrTeamsOutlinedFunction(const OMPExecutableDirective & D,const VarDecl * ThreadIDVar,OpenMPDirectiveKind InnermostKind,const RegionCodeGenTy & CodeGen)361*67e74705SXin Li llvm::Value *CGOpenMPRuntimeNVPTX::emitParallelOrTeamsOutlinedFunction(
362*67e74705SXin Li const OMPExecutableDirective &D, const VarDecl *ThreadIDVar,
363*67e74705SXin Li OpenMPDirectiveKind InnermostKind, const RegionCodeGenTy &CodeGen) {
364*67e74705SXin Li
365*67e74705SXin Li llvm::Function *OutlinedFun = nullptr;
366*67e74705SXin Li if (isa<OMPTeamsDirective>(D)) {
367*67e74705SXin Li llvm::Value *OutlinedFunVal =
368*67e74705SXin Li CGOpenMPRuntime::emitParallelOrTeamsOutlinedFunction(
369*67e74705SXin Li D, ThreadIDVar, InnermostKind, CodeGen);
370*67e74705SXin Li OutlinedFun = cast<llvm::Function>(OutlinedFunVal);
371*67e74705SXin Li OutlinedFun->addFnAttr(llvm::Attribute::AlwaysInline);
372*67e74705SXin Li } else
373*67e74705SXin Li llvm_unreachable("parallel directive is not yet supported for nvptx "
374*67e74705SXin Li "backend.");
375*67e74705SXin Li
376*67e74705SXin Li return OutlinedFun;
377*67e74705SXin Li }
378*67e74705SXin Li
emitTeamsCall(CodeGenFunction & CGF,const OMPExecutableDirective & D,SourceLocation Loc,llvm::Value * OutlinedFn,ArrayRef<llvm::Value * > CapturedVars)379*67e74705SXin Li void CGOpenMPRuntimeNVPTX::emitTeamsCall(CodeGenFunction &CGF,
380*67e74705SXin Li const OMPExecutableDirective &D,
381*67e74705SXin Li SourceLocation Loc,
382*67e74705SXin Li llvm::Value *OutlinedFn,
383*67e74705SXin Li ArrayRef<llvm::Value *> CapturedVars) {
384*67e74705SXin Li if (!CGF.HaveInsertPoint())
385*67e74705SXin Li return;
386*67e74705SXin Li
387*67e74705SXin Li Address ZeroAddr =
388*67e74705SXin Li CGF.CreateTempAlloca(CGF.Int32Ty, CharUnits::fromQuantity(4),
389*67e74705SXin Li /*Name*/ ".zero.addr");
390*67e74705SXin Li CGF.InitTempAlloca(ZeroAddr, CGF.Builder.getInt32(/*C*/ 0));
391*67e74705SXin Li llvm::SmallVector<llvm::Value *, 16> OutlinedFnArgs;
392*67e74705SXin Li OutlinedFnArgs.push_back(ZeroAddr.getPointer());
393*67e74705SXin Li OutlinedFnArgs.push_back(ZeroAddr.getPointer());
394*67e74705SXin Li OutlinedFnArgs.append(CapturedVars.begin(), CapturedVars.end());
395*67e74705SXin Li CGF.EmitCallOrInvoke(OutlinedFn, OutlinedFnArgs);
396*67e74705SXin Li }
397