1 //===-- SPIRVPreLegalizer.cpp - prepare IR for legalization -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// The pass prepares IR for legalization: it assigns SPIR-V types to registers
// and removes the intrinsics which held these types during IR translation.
// It also processes constants and registers them in GR to avoid duplication.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "SPIRV.h"
16 #include "SPIRVSubtarget.h"
17 #include "SPIRVUtils.h"
18 #include "llvm/ADT/PostOrderIterator.h"
19 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
20 #include "llvm/IR/Attributes.h"
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/DebugInfoMetadata.h"
23 #include "llvm/IR/IntrinsicsSPIRV.h"
24 #include "llvm/Target/TargetIntrinsicInfo.h"
25
26 #define DEBUG_TYPE "spirv-prelegalizer"
27
28 using namespace llvm;
29
30 namespace {
31 class SPIRVPreLegalizer : public MachineFunctionPass {
32 public:
33 static char ID;
SPIRVPreLegalizer()34 SPIRVPreLegalizer() : MachineFunctionPass(ID) {
35 initializeSPIRVPreLegalizerPass(*PassRegistry::getPassRegistry());
36 }
37 bool runOnMachineFunction(MachineFunction &MF) override;
38 };
39 } // namespace
40
addConstantsToTrack(MachineFunction & MF,SPIRVGlobalRegistry * GR)41 static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
42 MachineRegisterInfo &MRI = MF.getRegInfo();
43 DenseMap<MachineInstr *, Register> RegsAlreadyAddedToDT;
44 SmallVector<MachineInstr *, 10> ToErase, ToEraseComposites;
45 for (MachineBasicBlock &MBB : MF) {
46 for (MachineInstr &MI : MBB) {
47 if (!isSpvIntrinsic(MI, Intrinsic::spv_track_constant))
48 continue;
49 ToErase.push_back(&MI);
50 auto *Const =
51 cast<Constant>(cast<ConstantAsMetadata>(
52 MI.getOperand(3).getMetadata()->getOperand(0))
53 ->getValue());
54 if (auto *GV = dyn_cast<GlobalValue>(Const)) {
55 Register Reg = GR->find(GV, &MF);
56 if (!Reg.isValid())
57 GR->add(GV, &MF, MI.getOperand(2).getReg());
58 else
59 RegsAlreadyAddedToDT[&MI] = Reg;
60 } else {
61 Register Reg = GR->find(Const, &MF);
62 if (!Reg.isValid()) {
63 if (auto *ConstVec = dyn_cast<ConstantDataVector>(Const)) {
64 auto *BuildVec = MRI.getVRegDef(MI.getOperand(2).getReg());
65 assert(BuildVec &&
66 BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
67 for (unsigned i = 0; i < ConstVec->getNumElements(); ++i)
68 GR->add(ConstVec->getElementAsConstant(i), &MF,
69 BuildVec->getOperand(1 + i).getReg());
70 }
71 GR->add(Const, &MF, MI.getOperand(2).getReg());
72 } else {
73 RegsAlreadyAddedToDT[&MI] = Reg;
74 // This MI is unused and will be removed. If the MI uses
75 // const_composite, it will be unused and should be removed too.
76 assert(MI.getOperand(2).isReg() && "Reg operand is expected");
77 MachineInstr *SrcMI = MRI.getVRegDef(MI.getOperand(2).getReg());
78 if (SrcMI && isSpvIntrinsic(*SrcMI, Intrinsic::spv_const_composite))
79 ToEraseComposites.push_back(SrcMI);
80 }
81 }
82 }
83 }
84 for (MachineInstr *MI : ToErase) {
85 Register Reg = MI->getOperand(2).getReg();
86 if (RegsAlreadyAddedToDT.find(MI) != RegsAlreadyAddedToDT.end())
87 Reg = RegsAlreadyAddedToDT[MI];
88 MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg);
89 MI->eraseFromParent();
90 }
91 for (MachineInstr *MI : ToEraseComposites)
92 MI->eraseFromParent();
93 }
94
foldConstantsIntoIntrinsics(MachineFunction & MF)95 static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
96 SmallVector<MachineInstr *, 10> ToErase;
97 MachineRegisterInfo &MRI = MF.getRegInfo();
98 const unsigned AssignNameOperandShift = 2;
99 for (MachineBasicBlock &MBB : MF) {
100 for (MachineInstr &MI : MBB) {
101 if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
102 continue;
103 unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift;
104 while (MI.getOperand(NumOp).isReg()) {
105 MachineOperand &MOp = MI.getOperand(NumOp);
106 MachineInstr *ConstMI = MRI.getVRegDef(MOp.getReg());
107 assert(ConstMI->getOpcode() == TargetOpcode::G_CONSTANT);
108 MI.removeOperand(NumOp);
109 MI.addOperand(MachineOperand::CreateImm(
110 ConstMI->getOperand(1).getCImm()->getZExtValue()));
111 if (MRI.use_empty(ConstMI->getOperand(0).getReg()))
112 ToErase.push_back(ConstMI);
113 }
114 }
115 }
116 for (MachineInstr *MI : ToErase)
117 MI->eraseFromParent();
118 }
119
insertBitcasts(MachineFunction & MF,SPIRVGlobalRegistry * GR,MachineIRBuilder MIB)120 static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
121 MachineIRBuilder MIB) {
122 SmallVector<MachineInstr *, 10> ToErase;
123 for (MachineBasicBlock &MBB : MF) {
124 for (MachineInstr &MI : MBB) {
125 if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast))
126 continue;
127 assert(MI.getOperand(2).isReg());
128 MIB.setInsertPt(*MI.getParent(), MI);
129 MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
130 ToErase.push_back(&MI);
131 }
132 }
133 for (MachineInstr *MI : ToErase)
134 MI->eraseFromParent();
135 }
136
// When translating a GV, IRTranslator sometimes generates the following IR:
138 // %1 = G_GLOBAL_VALUE
139 // %2 = COPY %1
140 // %3 = G_ADDRSPACE_CAST %2
141 // New registers have no SPIRVType and no register class info.
142 //
143 // Set SPIRVType for GV, propagate it from GV to other instructions,
144 // also set register classes.
propagateSPIRVType(MachineInstr * MI,SPIRVGlobalRegistry * GR,MachineRegisterInfo & MRI,MachineIRBuilder & MIB)145 static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
146 MachineRegisterInfo &MRI,
147 MachineIRBuilder &MIB) {
148 SPIRVType *SpirvTy = nullptr;
149 assert(MI && "Machine instr is expected");
150 if (MI->getOperand(0).isReg()) {
151 Register Reg = MI->getOperand(0).getReg();
152 SpirvTy = GR->getSPIRVTypeForVReg(Reg);
153 if (!SpirvTy) {
154 switch (MI->getOpcode()) {
155 case TargetOpcode::G_CONSTANT: {
156 MIB.setInsertPt(*MI->getParent(), MI);
157 Type *Ty = MI->getOperand(1).getCImm()->getType();
158 SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
159 break;
160 }
161 case TargetOpcode::G_GLOBAL_VALUE: {
162 MIB.setInsertPt(*MI->getParent(), MI);
163 Type *Ty = MI->getOperand(1).getGlobal()->getType();
164 SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
165 break;
166 }
167 case TargetOpcode::G_TRUNC:
168 case TargetOpcode::G_ADDRSPACE_CAST:
169 case TargetOpcode::G_PTR_ADD:
170 case TargetOpcode::COPY: {
171 MachineOperand &Op = MI->getOperand(1);
172 MachineInstr *Def = Op.isReg() ? MRI.getVRegDef(Op.getReg()) : nullptr;
173 if (Def)
174 SpirvTy = propagateSPIRVType(Def, GR, MRI, MIB);
175 break;
176 }
177 default:
178 break;
179 }
180 if (SpirvTy)
181 GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
182 if (!MRI.getRegClassOrNull(Reg))
183 MRI.setRegClass(Reg, &SPIRV::IDRegClass);
184 }
185 }
186 return SpirvTy;
187 }
188
// Insert an ASSIGN_TYPE instruction between Reg and its definition, set NewReg as
190 // a dst of the definition, assign SPIRVType to both registers. If SpirvTy is
191 // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
192 // It's used also in SPIRVBuiltins.cpp.
193 // TODO: maybe move to SPIRVUtils.
194 namespace llvm {
insertAssignInstr(Register Reg,Type * Ty,SPIRVType * SpirvTy,SPIRVGlobalRegistry * GR,MachineIRBuilder & MIB,MachineRegisterInfo & MRI)195 Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
196 SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
197 MachineRegisterInfo &MRI) {
198 MachineInstr *Def = MRI.getVRegDef(Reg);
199 assert((Ty || SpirvTy) && "Either LLVM or SPIRV type is expected.");
200 MIB.setInsertPt(*Def->getParent(),
201 (Def->getNextNode() ? Def->getNextNode()->getIterator()
202 : Def->getParent()->end()));
203 Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
204 if (auto *RC = MRI.getRegClassOrNull(Reg))
205 MRI.setRegClass(NewReg, RC);
206 SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB);
207 GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
208 // This is to make it convenient for Legalizer to get the SPIRVType
209 // when processing the actual MI (i.e. not pseudo one).
210 GR->assignSPIRVTypeToVReg(SpirvTy, NewReg, MIB.getMF());
211 // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to keep
212 // the flags after instruction selection.
213 const uint16_t Flags = Def->getFlags();
214 MIB.buildInstr(SPIRV::ASSIGN_TYPE)
215 .addDef(Reg)
216 .addUse(NewReg)
217 .addUse(GR->getSPIRVTypeID(SpirvTy))
218 .setMIFlags(Flags);
219 Def->getOperand(0).setReg(NewReg);
220 MRI.setRegClass(Reg, &SPIRV::ANYIDRegClass);
221 return NewReg;
222 }
223 } // namespace llvm
224
generateAssignInstrs(MachineFunction & MF,SPIRVGlobalRegistry * GR,MachineIRBuilder MIB)225 static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
226 MachineIRBuilder MIB) {
227 MachineRegisterInfo &MRI = MF.getRegInfo();
228 SmallVector<MachineInstr *, 10> ToErase;
229
230 for (MachineBasicBlock *MBB : post_order(&MF)) {
231 if (MBB->empty())
232 continue;
233
234 bool ReachedBegin = false;
235 for (auto MII = std::prev(MBB->end()), Begin = MBB->begin();
236 !ReachedBegin;) {
237 MachineInstr &MI = *MII;
238
239 if (isSpvIntrinsic(MI, Intrinsic::spv_assign_type)) {
240 Register Reg = MI.getOperand(1).getReg();
241 Type *Ty = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
242 MachineInstr *Def = MRI.getVRegDef(Reg);
243 assert(Def && "Expecting an instruction that defines the register");
244 // G_GLOBAL_VALUE already has type info.
245 if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE)
246 insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
247 ToErase.push_back(&MI);
248 } else if (MI.getOpcode() == TargetOpcode::G_CONSTANT ||
249 MI.getOpcode() == TargetOpcode::G_FCONSTANT ||
250 MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
251 // %rc = G_CONSTANT ty Val
252 // ===>
253 // %cty = OpType* ty
254 // %rctmp = G_CONSTANT ty Val
255 // %rc = ASSIGN_TYPE %rctmp, %cty
256 Register Reg = MI.getOperand(0).getReg();
257 if (MRI.hasOneUse(Reg)) {
258 MachineInstr &UseMI = *MRI.use_instr_begin(Reg);
259 if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
260 isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name))
261 continue;
262 }
263 Type *Ty = nullptr;
264 if (MI.getOpcode() == TargetOpcode::G_CONSTANT)
265 Ty = MI.getOperand(1).getCImm()->getType();
266 else if (MI.getOpcode() == TargetOpcode::G_FCONSTANT)
267 Ty = MI.getOperand(1).getFPImm()->getType();
268 else {
269 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
270 Type *ElemTy = nullptr;
271 MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
272 assert(ElemMI);
273
274 if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT)
275 ElemTy = ElemMI->getOperand(1).getCImm()->getType();
276 else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT)
277 ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
278 else
279 llvm_unreachable("Unexpected opcode");
280 unsigned NumElts =
281 MI.getNumExplicitOperands() - MI.getNumExplicitDefs();
282 Ty = VectorType::get(ElemTy, NumElts, false);
283 }
284 insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
285 } else if (MI.getOpcode() == TargetOpcode::G_TRUNC ||
286 MI.getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
287 MI.getOpcode() == TargetOpcode::COPY ||
288 MI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) {
289 propagateSPIRVType(&MI, GR, MRI, MIB);
290 }
291
292 if (MII == Begin)
293 ReachedBegin = true;
294 else
295 --MII;
296 }
297 }
298 for (MachineInstr *MI : ToErase)
299 MI->eraseFromParent();
300 }
301
302 static std::pair<Register, unsigned>
createNewIdReg(Register ValReg,unsigned Opcode,MachineRegisterInfo & MRI,const SPIRVGlobalRegistry & GR)303 createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI,
304 const SPIRVGlobalRegistry &GR) {
305 LLT NewT = LLT::scalar(32);
306 SPIRVType *SpvType = GR.getSPIRVTypeForVReg(ValReg);
307 assert(SpvType && "VReg is expected to have SPIRV type");
308 bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
309 bool IsVectorFloat =
310 SpvType->getOpcode() == SPIRV::OpTypeVector &&
311 GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() ==
312 SPIRV::OpTypeFloat;
313 IsFloat |= IsVectorFloat;
314 auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
315 auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
316 if (MRI.getType(ValReg).isPointer()) {
317 NewT = LLT::pointer(0, 32);
318 GetIdOp = SPIRV::GET_pID;
319 DstClass = &SPIRV::pIDRegClass;
320 } else if (MRI.getType(ValReg).isVector()) {
321 NewT = LLT::fixed_vector(2, NewT);
322 GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID;
323 DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass;
324 }
325 Register IdReg = MRI.createGenericVirtualRegister(NewT);
326 MRI.setRegClass(IdReg, DstClass);
327 return {IdReg, GetIdOp};
328 }
329
processInstr(MachineInstr & MI,MachineIRBuilder & MIB,MachineRegisterInfo & MRI,SPIRVGlobalRegistry * GR)330 static void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
331 MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
332 unsigned Opc = MI.getOpcode();
333 assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
334 MachineInstr &AssignTypeInst =
335 *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
336 auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first;
337 AssignTypeInst.getOperand(1).setReg(NewReg);
338 MI.getOperand(0).setReg(NewReg);
339 MIB.setInsertPt(*MI.getParent(),
340 (MI.getNextNode() ? MI.getNextNode()->getIterator()
341 : MI.getParent()->end()));
342 for (auto &Op : MI.operands()) {
343 if (!Op.isReg() || Op.isDef())
344 continue;
345 auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR);
346 MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
347 Op.setReg(IdOpInfo.first);
348 }
349 }
350
351 // Defined in SPIRVLegalizerInfo.cpp.
352 extern bool isTypeFoldingSupported(unsigned Opcode);
353
processInstrsWithTypeFolding(MachineFunction & MF,SPIRVGlobalRegistry * GR,MachineIRBuilder MIB)354 static void processInstrsWithTypeFolding(MachineFunction &MF,
355 SPIRVGlobalRegistry *GR,
356 MachineIRBuilder MIB) {
357 MachineRegisterInfo &MRI = MF.getRegInfo();
358 for (MachineBasicBlock &MBB : MF) {
359 for (MachineInstr &MI : MBB) {
360 if (isTypeFoldingSupported(MI.getOpcode()))
361 processInstr(MI, MIB, MRI, GR);
362 }
363 }
364 for (MachineBasicBlock &MBB : MF) {
365 for (MachineInstr &MI : MBB) {
366 // We need to rewrite dst types for ASSIGN_TYPE instrs to be able
367 // to perform tblgen'erated selection and we can't do that on Legalizer
368 // as it operates on gMIR only.
369 if (MI.getOpcode() != SPIRV::ASSIGN_TYPE)
370 continue;
371 Register SrcReg = MI.getOperand(1).getReg();
372 unsigned Opcode = MRI.getVRegDef(SrcReg)->getOpcode();
373 if (!isTypeFoldingSupported(Opcode))
374 continue;
375 Register DstReg = MI.getOperand(0).getReg();
376 if (MRI.getType(DstReg).isVector())
377 MRI.setRegClass(DstReg, &SPIRV::IDRegClass);
378 // Don't need to reset type of register holding constant and used in
379 // G_ADDRSPACE_CAST, since it braaks legalizer.
380 if (Opcode == TargetOpcode::G_CONSTANT && MRI.hasOneUse(DstReg)) {
381 MachineInstr &UseMI = *MRI.use_instr_begin(DstReg);
382 if (UseMI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST)
383 continue;
384 }
385 MRI.setType(DstReg, LLT::scalar(32));
386 }
387 }
388 }
389
processSwitches(MachineFunction & MF,SPIRVGlobalRegistry * GR,MachineIRBuilder MIB)390 static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
391 MachineIRBuilder MIB) {
392 // Before IRTranslator pass, calls to spv_switch intrinsic are inserted before
393 // each switch instruction. IRTranslator lowers switches to G_ICMP + G_BRCOND
394 // + G_BR triples. A switch with two cases may be transformed to this MIR
395 // sequence:
396 //
397 // intrinsic(@llvm.spv.switch), %CmpReg, %Const0, %Const1
398 // %Dst0 = G_ICMP intpred(eq), %CmpReg, %Const0
399 // G_BRCOND %Dst0, %bb.2
400 // G_BR %bb.5
401 // bb.5.entry:
402 // %Dst1 = G_ICMP intpred(eq), %CmpReg, %Const1
403 // G_BRCOND %Dst1, %bb.3
404 // G_BR %bb.4
405 // bb.2.sw.bb:
406 // ...
407 // bb.3.sw.bb1:
408 // ...
409 // bb.4.sw.epilog:
410 // ...
411 //
412 // Sometimes (in case of range-compare switches), additional G_SUBs
413 // instructions are inserted before G_ICMPs. Those need to be additionally
414 // processed and require type assignment.
415 //
416 // This function modifies spv_switch call's operands to include destination
417 // MBBs (default and for each constant value).
418 // Note that this function does not remove G_ICMP + G_BRCOND + G_BR sequences,
419 // but they are marked by ModuleAnalysis as skipped and as a result AsmPrinter
420 // does not output them.
421
422 MachineRegisterInfo &MRI = MF.getRegInfo();
423
424 // Collect all MIs relevant to switches across all MBBs in MF.
425 std::vector<MachineInstr *> RelevantInsts;
426
427 // Temporary set of compare registers. G_SUBs and G_ICMPs relating to
428 // spv_switch use these registers.
429 DenseSet<Register> CompareRegs;
430 for (MachineBasicBlock &MBB : MF) {
431 for (MachineInstr &MI : MBB) {
432 // Calls to spv_switch intrinsics representing IR switches.
433 if (isSpvIntrinsic(MI, Intrinsic::spv_switch)) {
434 assert(MI.getOperand(1).isReg());
435 CompareRegs.insert(MI.getOperand(1).getReg());
436 RelevantInsts.push_back(&MI);
437 }
438
439 // G_SUBs coming from range-compare switch lowering. G_SUBs are found
440 // after spv_switch but before G_ICMP.
441 if (MI.getOpcode() == TargetOpcode::G_SUB && MI.getOperand(1).isReg() &&
442 CompareRegs.contains(MI.getOperand(1).getReg())) {
443 assert(MI.getOperand(0).isReg() && MI.getOperand(1).isReg());
444 Register Dst = MI.getOperand(0).getReg();
445 CompareRegs.insert(Dst);
446 SPIRVType *Ty = GR->getSPIRVTypeForVReg(MI.getOperand(1).getReg());
447 insertAssignInstr(Dst, nullptr, Ty, GR, MIB, MRI);
448 }
449
450 // G_ICMPs relating to switches.
451 if (MI.getOpcode() == TargetOpcode::G_ICMP && MI.getOperand(2).isReg() &&
452 CompareRegs.contains(MI.getOperand(2).getReg())) {
453 Register Dst = MI.getOperand(0).getReg();
454 // Set type info for destination register of switch's ICMP instruction.
455 if (GR->getSPIRVTypeForVReg(Dst) == nullptr) {
456 MIB.setInsertPt(*MI.getParent(), MI);
457 Type *LLVMTy = IntegerType::get(MF.getFunction().getContext(), 1);
458 SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, MIB);
459 MRI.setRegClass(Dst, &SPIRV::IDRegClass);
460 GR->assignSPIRVTypeToVReg(SpirvTy, Dst, MIB.getMF());
461 }
462 RelevantInsts.push_back(&MI);
463 }
464 }
465 }
466
467 // Update each spv_switch with destination MBBs.
468 for (auto i = RelevantInsts.begin(); i != RelevantInsts.end(); i++) {
469 if (!isSpvIntrinsic(**i, Intrinsic::spv_switch))
470 continue;
471
472 // Currently considered spv_switch.
473 MachineInstr *Switch = *i;
474 // Set the first successor as default MBB to support empty switches.
475 MachineBasicBlock *DefaultMBB = *Switch->getParent()->succ_begin();
476 // Container for mapping values to MMBs.
477 SmallDenseMap<uint64_t, MachineBasicBlock *> ValuesToMBBs;
478
479 // Walk all G_ICMPs to collect ValuesToMBBs. Start at currently considered
480 // spv_switch (i) and break at any spv_switch with the same compare
481 // register (indicating we are back at the same scope).
482 Register CompareReg = Switch->getOperand(1).getReg();
483 for (auto j = i + 1; j != RelevantInsts.end(); j++) {
484 if (isSpvIntrinsic(**j, Intrinsic::spv_switch) &&
485 (*j)->getOperand(1).getReg() == CompareReg)
486 break;
487
488 if (!((*j)->getOpcode() == TargetOpcode::G_ICMP &&
489 (*j)->getOperand(2).getReg() == CompareReg))
490 continue;
491
492 MachineInstr *ICMP = *j;
493 Register Dst = ICMP->getOperand(0).getReg();
494 MachineOperand &PredOp = ICMP->getOperand(1);
495 const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
496 assert((CC == CmpInst::ICMP_EQ || CC == CmpInst::ICMP_ULE) &&
497 MRI.hasOneUse(Dst) && MRI.hasOneDef(CompareReg));
498 uint64_t Value = getIConstVal(ICMP->getOperand(3).getReg(), &MRI);
499 MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
500 assert(CBr->getOpcode() == SPIRV::G_BRCOND && CBr->getOperand(1).isMBB());
501 MachineBasicBlock *MBB = CBr->getOperand(1).getMBB();
502
503 // Map switch case Value to target MBB.
504 ValuesToMBBs[Value] = MBB;
505
506 // The next MI is always G_BR to either the next case or the default.
507 MachineInstr *NextMI = CBr->getNextNode();
508 assert(NextMI->getOpcode() == SPIRV::G_BR &&
509 NextMI->getOperand(0).isMBB());
510 MachineBasicBlock *NextMBB = NextMI->getOperand(0).getMBB();
511 // Default MBB does not begin with G_ICMP using spv_switch compare
512 // register.
513 if (NextMBB->front().getOpcode() != SPIRV::G_ICMP ||
514 (NextMBB->front().getOperand(2).isReg() &&
515 NextMBB->front().getOperand(2).getReg() != CompareReg))
516 DefaultMBB = NextMBB;
517 }
518
519 // Modify considered spv_switch operands using collected Values and
520 // MBBs.
521 SmallVector<const ConstantInt *, 3> Values;
522 SmallVector<MachineBasicBlock *, 3> MBBs;
523 for (unsigned k = 2; k < Switch->getNumExplicitOperands(); k++) {
524 Register CReg = Switch->getOperand(k).getReg();
525 uint64_t Val = getIConstVal(CReg, &MRI);
526 MachineInstr *ConstInstr = getDefInstrMaybeConstant(CReg, &MRI);
527 if (!ValuesToMBBs[Val])
528 continue;
529
530 Values.push_back(ConstInstr->getOperand(1).getCImm());
531 MBBs.push_back(ValuesToMBBs[Val]);
532 }
533
534 for (unsigned k = Switch->getNumExplicitOperands() - 1; k > 1; k--)
535 Switch->removeOperand(k);
536
537 Switch->addOperand(MachineOperand::CreateMBB(DefaultMBB));
538 for (unsigned k = 0; k < Values.size(); k++) {
539 Switch->addOperand(MachineOperand::CreateCImm(Values[k]));
540 Switch->addOperand(MachineOperand::CreateMBB(MBBs[k]));
541 }
542 }
543 }
544
runOnMachineFunction(MachineFunction & MF)545 bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
546 // Initialize the type registry.
547 const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
548 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
549 GR->setCurrentFunc(MF);
550 MachineIRBuilder MIB(MF);
551 addConstantsToTrack(MF, GR);
552 foldConstantsIntoIntrinsics(MF);
553 insertBitcasts(MF, GR, MIB);
554 generateAssignInstrs(MF, GR, MIB);
555 processSwitches(MF, GR, MIB);
556 processInstrsWithTypeFolding(MF, GR, MIB);
557
558 return true;
559 }
560
561 INITIALIZE_PASS(SPIRVPreLegalizer, DEBUG_TYPE, "SPIRV pre legalizer", false,
562 false)
563
564 char SPIRVPreLegalizer::ID = 0;
565
createSPIRVPreLegalizerPass()566 FunctionPass *llvm::createSPIRVPreLegalizerPass() {
567 return new SPIRVPreLegalizer();
568 }
569