//===-- X86PreTileConfig.cpp - Tile Register Pre-configure ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to pre-config the shapes of AMX registers
/// AMX registers need to be configured before use. The shapes of an AMX
/// register are encoded in the 1st and 2nd machine operands of AMX pseudo
/// instructions.
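///
/// For example, a pseudo such as PTILELOADDV defines a tile register whose
/// row and column operands (operands 1 and 2) carry its shape.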
///
/// The instruction ldtilecfg is used to config the shapes. It must be placed
/// after the definitions of all the variable shapes it configures. ldtilecfg
/// will be inserted more than once if we cannot find a single dominating
/// point for all AMX instructions.
///
/// The config register is caller-saved according to the ABI. We need to
/// insert ldtilecfg again after a call instruction if the callee clobbers any
/// AMX registers.
///
/// This pass calculates all the points where ldtilecfg needs to be inserted
/// and inserts it there. It reports an error if the reachability conditions
/// aren't met.
//
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86MachineFunctionInfo.h"
#include "X86RegisterInfo.h"
#include "X86Subtarget.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/InitializePasses.h"

using namespace llvm;

#define DEBUG_TYPE "tile-pre-config"

static void emitErrorMsg(MachineFunction &MF) {
  LLVMContext &Context = MF.getMMI().getModule()->getContext();
  Context.emitError(MF.getName() +
                    ": Failed to config tile register, please define the "
                    "shape earlier");
}

namespace {

struct MIRef {
  MachineInstr *MI = nullptr;
  MachineBasicBlock *MBB = nullptr;
  // A virtual position for the instruction that will be inserted after MI.
  size_t Pos = 0;
  MIRef() = default;
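  // Position at the end of the leading PHIs of MBB: MI ends up as the last
  // PHI (or stays null if there is none) and Pos counts the skipped PHIs.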
  MIRef(MachineBasicBlock *MBB) : MBB(MBB) {
    for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI();
         ++I, ++Pos)
      MI = &*I;
  }
  MIRef(MachineInstr *MI)
      : MI(MI), MBB(MI->getParent()),
        Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
  MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
      : MI(MI), MBB(MBB),
        Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
  MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos)
      : MI(MI), MBB(MBB), Pos(Pos) {}
  operator bool() const { return MBB != nullptr; }
  bool operator==(const MIRef &RHS) const {
    return MI == RHS.MI && MBB == RHS.MBB;
  }
  bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
  bool operator<(const MIRef &RHS) const {
    // Comparison between different BBs happens when inserting a MIRef into a
    // set. So we compare MBB first to make the insertion happy.
    return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
  }
  bool operator>(const MIRef &RHS) const {
    // Comparison between different BBs happens when inserting a MIRef into a
    // set. So we compare MBB first to make the insertion happy.
    return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos);
  }
};

struct BBInfo {
  MIRef FirstAMX;
  MIRef LastCall;
  bool HasAMXRegLiveIn = false;
  bool TileCfgForbidden = false;
  bool NeedTileCfgLiveIn = false;
};

class X86PreTileConfig : public MachineFunctionPass {
  MachineRegisterInfo *MRI;
  const MachineLoopInfo *MLI;
  SmallSet<MachineInstr *, 8> DefVisited;
  DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
  DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;

  /// Check if the callee will clobber AMX registers.
  bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
    auto Iter = llvm::find_if(
        MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); });
    if (Iter == MI.operands_end())
      return false;
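    // A set bit in a call's regmask means the register is preserved across
    // the call, so clearing those bits leaves exactly the AMX registers the
    // callee may clobber.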
    UsableRegs.clearBitsInMask(Iter->getRegMask());
    return UsableRegs.any();
  }

  /// Check if MI is an AMX pseudo instruction.
  bool isAMXInstruction(MachineInstr &MI) {
    if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3)
      return false;
    MachineOperand &MO = MI.getOperand(0);
    // We can simply check if it is an AMX instruction by its def.
    // But we should exclude the old API which uses physical registers.
    if (MO.isReg() && MO.getReg().isVirtual() &&
        MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) {
      collectShapeInfo(MI);
      return true;
    }
    // PTILESTOREDV is the only exception that doesn't def an AMX register.
    return MI.getOpcode() == X86::PTILESTOREDV;
  }

  /// Check if it is an edge from loop bottom to loop head.
  bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) {
    if (!MLI->isLoopHeader(Header))
      return false;
    auto *ML = MLI->getLoopFor(Header);
    return ML->contains(Bottom) && ML->isLoopLatch(Bottom);
  }

  /// Collect the shape def information for later use.
  void collectShapeInfo(MachineInstr &MI);

  /// Try to hoist shapes defined below AMX instructions.
  bool hoistShapesInBB(MachineBasicBlock *MBB,
                       SmallVectorImpl<MIRef> &Shapes) {
    MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX;
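    // Shapes is kept sorted by position within this BB, so lower_bound finds
    // the first shape def at or below the first AMX instruction.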
    auto FirstShapeBelowAMX = llvm::lower_bound(Shapes, FirstAMX);
    auto InsertPoint = FirstAMX.MI->getIterator();
    for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) {
      // Do not hoist instructions that access memory.
      if (I->MI->mayLoadOrStore())
        return false;
      for (auto &MO : I->MI->operands()) {
        if (MO.isDef())
          continue;
        // Do not hoist the instruction if a source's def is below the first
        // AMX instruction.
        // TODO: We can handle isMoveImmediate MI here.
        if (MO.isReg() && MIRef(MRI->getVRegDef(MO.getReg())) > FirstAMX)
          return false;
        // TODO: Maybe need more checks here.
      }
      MBB->insert(InsertPoint, I->MI->removeFromParent());
    }
    // We only need to mark the last shape in the BB now.
    Shapes.clear();
    Shapes.push_back(MIRef(&*--InsertPoint, MBB));
    return true;
  }

public:
  X86PreTileConfig() : MachineFunctionPass(ID) {}

  /// Return the pass name.
  StringRef getPassName() const override {
    return "Tile Register Pre-configure";
  }

  /// X86PreTileConfig analysis usage.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
    AU.addRequired<MachineLoopInfo>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  /// Clear MF related structures.
  void releaseMemory() override {
    ShapeBBs.clear();
    DefVisited.clear();
    BBVisitedInfo.clear();
  }

  /// Insert ldtilecfg instructions.
  bool runOnMachineFunction(MachineFunction &MF) override;

  static char ID;
};

} // end anonymous namespace

char X86PreTileConfig::ID = 0;

INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
                      "Tile Register Pre-configure", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                    "Tile Register Pre-configure", false, false)

void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
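  // Insert MI into the sorted shape list of its parent MBB, skipping
  // duplicates.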
  auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
    MIRef MIR(MI, MBB);
    auto I = llvm::lower_bound(ShapeBBs[MBB], MIR);
    if (I == ShapeBBs[MBB].end() || *I != MIR)
      ShapeBBs[MBB].insert(I, MIR);
  };

  SmallVector<Register, 8> WorkList(
      {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()});
  while (!WorkList.empty()) {
    Register R = WorkList.pop_back_val();
    MachineInstr *DefMI = MRI->getVRegDef(R);
    assert(DefMI && "R must have one define instruction");
    MachineBasicBlock *DefMBB = DefMI->getParent();
    if (DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second)
      continue;
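    // For a PHI, look through each incoming value. A value carried around a
    // loop back edge cannot be hoisted above the PHI, so the PHI itself acts
    // as the shape def in that case.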
    if (DefMI->isPHI()) {
      for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2)
        if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB()))
          RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def.
        else
          WorkList.push_back(DefMI->getOperand(I).getReg());
    } else {
      RecordShape(DefMI, DefMBB);
    }
  }
}

bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
  const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
  const TargetInstrInfo *TII = ST.getInstrInfo();
  const TargetRegisterInfo *TRI = ST.getRegisterInfo();
  const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);
  X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>();

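  // Collect the physical AMX tile registers (TMM0, TMM1, ...) so call
  // regmasks can be tested against them in isDestructiveCall.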
  BitVector AMXRegs(TRI->getNumRegs());
  for (unsigned I = 0; I < RC->getNumRegs(); I++)
    AMXRegs.set(X86::TMM0 + I);

  // Iterate MF to collect information.
  MRI = &MF.getRegInfo();
  MLI = &getAnalysis<MachineLoopInfo>();
  SmallSet<MIRef, 8> CfgNeedInsert;
  SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs;
  for (auto &MBB : MF) {
    size_t Pos = 0;
    for (auto &MI : MBB) {
      ++Pos;
      if (isAMXInstruction(MI)) {
        // If there's a call before the AMX, we need to reload tile config.
        if (BBVisitedInfo[&MBB].LastCall)
          CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall);
        else // Otherwise, we need tile config to live in this BB.
          BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true;
        // Always record the first AMX in case there's a shape def after it.
        if (!BBVisitedInfo[&MBB].FirstAMX)
          BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos);
      } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) {
        // Record the call only if the callee clobbers any AMX register.
        BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos);
      }
    }
    if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) {
      if (&MBB == &MF.front())
        CfgNeedInsert.insert(MIRef(&MBB));
      else
        CfgLiveInBBs.push_back(&MBB);
    }
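    // Propagate AMX liveness to forward successors (loop back edges are
    // skipped). A shape defined in a block that already has an AMX register
    // live in would come too late to be configured.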
    if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn)
      for (auto *Succ : MBB.successors())
        if (!isLoopBackEdge(Succ, &MBB))
          BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
  }

  // Update NeedTileCfgLiveIn for predecessors.
  while (!CfgLiveInBBs.empty()) {
    MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val();
    for (auto *Pred : MBB->predecessors()) {
      if (BBVisitedInfo[Pred].LastCall) {
        CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall);
      } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) {
        BBVisitedInfo[Pred].NeedTileCfgLiveIn = true;
        if (Pred == &MF.front())
          CfgNeedInsert.insert(MIRef(Pred));
        else
          CfgLiveInBBs.push_back(Pred);
      }
    }
  }

  // There's no AMX instruction if we didn't find a tile config live in point.
  if (CfgNeedInsert.empty())
    return false;
  X86FI->setHasVirtualTileReg(true);

  // Avoid inserting ldtilecfg before any shape defs.
  SmallVector<MachineBasicBlock *, 8> WorkList;
  for (auto &I : ShapeBBs) {
    // TODO: We can hoist shapes across BBs here.
    if (BBVisitedInfo[I.first].HasAMXRegLiveIn) {
      // We are not able to config tile registers since the shape to config
      // is not defined yet. Emit an error message and bail out; the function
      // will not config tile registers.
      emitErrorMsg(MF);
      return false;
    }
    if (BBVisitedInfo[I.first].FirstAMX &&
        BBVisitedInfo[I.first].FirstAMX < I.second.back() &&
        !hoistShapesInBB(I.first, I.second)) {
      emitErrorMsg(MF);
      return false;
    }
    WorkList.push_back(I.first);
  }
  while (!WorkList.empty()) {
    MachineBasicBlock *MBB = WorkList.pop_back_val();
    for (auto *Pred : MBB->predecessors()) {
      if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(MBB, Pred)) {
        BBVisitedInfo[Pred].TileCfgForbidden = true;
        WorkList.push_back(Pred);
      }
    }
  }

  DebugLoc DL;
  SmallSet<MIRef, 8> VisitedOrInserted;
  int SS = MF.getFrameInfo().CreateStackObject(
      ST.getTileConfigSize(), ST.getTileConfigAlignment(), false);
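  // SS is the frame index of the tile config area that every inserted
  // ldtilecfg will read from.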

  // Try to insert for the tile config live in points.
  for (const auto &I : CfgNeedInsert) {
    SmallSet<MIRef, 8> InsertPoints;
    SmallVector<MIRef, 8> WorkList({I});
    while (!WorkList.empty()) {
      MIRef I = WorkList.pop_back_val();
      if (!VisitedOrInserted.count(I)) {
        if (!BBVisitedInfo[I.MBB].TileCfgForbidden) {
          // If every shape is reachable at this BB, stop sinking and try to
          // insert here.
          InsertPoints.insert(I);
        } else {
          // Avoid visiting the BB multiple times.
          VisitedOrInserted.insert(I);
          // Sink the insert point along the chain of BBs with
          // NeedTileCfgLiveIn == true while not all shapes reach MBB.
          for (auto *Succ : I.MBB->successors())
            if (BBVisitedInfo[Succ].NeedTileCfgLiveIn)
              WorkList.push_back(MIRef(Succ));
        }
      }
    }

    // A given point might be forked when the shape conditions are not met.
    for (MIRef I : InsertPoints) {
      // Make sure we insert ldtilecfg after the last shape def in the MBB.
      if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back())
        I = ShapeBBs[I.MBB].back();
      // The same MBB may be sunk to more than once. Record it to avoid
      // duplicate inserts.
      if (VisitedOrInserted.insert(I).second) {
        auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin();
        addFrameReference(BuildMI(*I.MBB, ++II, DL, TII->get(X86::PLDTILECFGV)),
                          SS);
      }
    }
  }

  // Zero stack slot.
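  // ldtilecfg reads its whole configuration from memory, so zero the slot
  // first; any tile we don't explicitly configure then has zero rows and
  // columns.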
  MachineBasicBlock &MBB = MF.front();
  MachineInstr *MI = &*MBB.begin();
  if (ST.hasAVX512()) {
    Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::AVX512_512_SET0), Zmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS)
        .addReg(Zmm);
  } else if (ST.hasAVX2()) {
    Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::AVX_SET0), Ymm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS)
        .addReg(Ymm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32)
        .addReg(Ymm);
  } else {
    assert(ST.hasSSE2() && "AMX should assume SSE2 enabled");
    unsigned StoreOpc = ST.hasAVX() ? X86::VMOVUPSmr : X86::MOVUPSmr;
    Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::V_SET0), Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS).addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS, 16)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS, 32)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS, 48)
        .addReg(Xmm);
  }
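  // Only the palette byte is statically known at this point; the per-tile
  // rows and columns are expected to be filled into this slot after register
  // allocation, once physical tile registers are known.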
  // Fill in the palette first.
  addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1);

  return true;
}

FunctionPass *llvm::createX86PreTileConfigPass() {
  return new X86PreTileConfig();
}