1 //===-- lib/CodeGen/GlobalISel/CombinerHelper.cpp -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
9 #include "llvm/ADT/SetVector.h"
10 #include "llvm/ADT/SmallBitVector.h"
11 #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
12 #include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
13 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
14 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
15 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
16 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
17 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
18 #include "llvm/CodeGen/GlobalISel/Utils.h"
19 #include "llvm/CodeGen/LowLevelType.h"
20 #include "llvm/CodeGen/MachineBasicBlock.h"
21 #include "llvm/CodeGen/MachineDominators.h"
22 #include "llvm/CodeGen/MachineInstr.h"
23 #include "llvm/CodeGen/MachineMemOperand.h"
24 #include "llvm/CodeGen/MachineRegisterInfo.h"
25 #include "llvm/CodeGen/RegisterBankInfo.h"
26 #include "llvm/CodeGen/TargetInstrInfo.h"
27 #include "llvm/CodeGen/TargetLowering.h"
28 #include "llvm/CodeGen/TargetOpcodes.h"
29 #include "llvm/IR/DataLayout.h"
30 #include "llvm/IR/InstrTypes.h"
31 #include "llvm/Support/Casting.h"
32 #include "llvm/Support/DivisionByConstantInfo.h"
33 #include "llvm/Support/MathExtras.h"
34 #include "llvm/Target/TargetMachine.h"
35 #include <cmath>
36 #include <optional>
37 #include <tuple>
38
39 #define DEBUG_TYPE "gi-combiner"
40
41 using namespace llvm;
42 using namespace MIPatternMatch;
43
44 // Option to allow testing of the combiner while no targets know about indexed
45 // addressing.
46 static cl::opt<bool>
47 ForceLegalIndexing("force-legal-indexing", cl::Hidden, cl::init(false),
48 cl::desc("Force all indexed operations to be "
49 "legal for the GlobalISel combiner"));
50
51 CombinerHelper::CombinerHelper(GISelChangeObserver &Observer,
52 MachineIRBuilder &B, bool IsPreLegalize,
53 GISelKnownBits *KB, MachineDominatorTree *MDT,
54 const LegalizerInfo *LI)
55 : Builder(B), MRI(Builder.getMF().getRegInfo()), Observer(Observer), KB(KB),
56 MDT(MDT), IsPreLegalize(IsPreLegalize), LI(LI),
57 RBI(Builder.getMF().getSubtarget().getRegBankInfo()),
58 TRI(Builder.getMF().getSubtarget().getRegisterInfo()) {
59 (void)this->KB;
60 }
61
62 const TargetLowering &CombinerHelper::getTargetLowering() const {
63 return *Builder.getMF().getSubtarget().getTargetLowering();
64 }
65
66 /// \returns The little endian in-memory byte position of byte \p I in a
67 /// \p ByteWidth bytes wide type.
68 ///
69 /// E.g. Given a 4-byte type x, x[0] -> byte 0
70 static unsigned littleEndianByteAt(const unsigned ByteWidth, const unsigned I) {
71 assert(I < ByteWidth && "I must be in [0, ByteWidth)");
72 return I;
73 }
74
75 /// Determines the LogBase2 value for a non-null input value using the
76 /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
77 static Register buildLogBase2(Register V, MachineIRBuilder &MIB) {
78 auto &MRI = *MIB.getMRI();
79 LLT Ty = MRI.getType(V);
80 auto Ctlz = MIB.buildCTLZ(Ty, V);
81 auto Base = MIB.buildConstant(Ty, Ty.getScalarSizeInBits() - 1);
82 return MIB.buildSub(Ty, Base, Ctlz).getReg(0);
83 }
84
85 /// \returns The big endian in-memory byte position of byte \p I in a
86 /// \p ByteWidth bytes wide type.
87 ///
88 /// E.g. Given a 4-byte type x, x[0] -> byte 3
89 static unsigned bigEndianByteAt(const unsigned ByteWidth, const unsigned I) {
90 assert(I < ByteWidth && "I must be in [0, ByteWidth)");
91 return ByteWidth - I - 1;
92 }
93
94 /// Given a map from byte offsets in memory to indices in a load/store,
95 /// determine if that map corresponds to a little or big endian byte pattern.
96 ///
97 /// \param MemOffset2Idx maps memory offsets to address offsets.
98 /// \param LowestIdx is the lowest index in \p MemOffset2Idx.
99 ///
100 /// \returns true if the map corresponds to a big endian byte pattern, false if
101 /// it corresponds to a little endian byte pattern, and std::nullopt otherwise.
102 ///
103 /// E.g. given a 32-bit type x, and x[AddrOffset], the in-memory byte patterns
104 /// are as follows:
105 ///
106 /// AddrOffset Little endian Big endian
107 /// 0 0 3
108 /// 1 1 2
109 /// 2 2 1
110 /// 3 3 0
111 static std::optional<bool>
112 isBigEndian(const SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx,
113 int64_t LowestIdx) {
114 // Need at least two byte positions to decide on endianness.
115 unsigned Width = MemOffset2Idx.size();
116 if (Width < 2)
117 return std::nullopt;
118 bool BigEndian = true, LittleEndian = true;
119 for (unsigned MemOffset = 0; MemOffset < Width; ++MemOffset) {
120 auto MemOffsetAndIdx = MemOffset2Idx.find(MemOffset);
121 if (MemOffsetAndIdx == MemOffset2Idx.end())
122 return std::nullopt;
123 const int64_t Idx = MemOffsetAndIdx->second - LowestIdx;
124 assert(Idx >= 0 && "Expected non-negative byte offset?");
125 LittleEndian &= Idx == littleEndianByteAt(Width, MemOffset);
126 BigEndian &= Idx == bigEndianByteAt(Width, MemOffset);
127 if (!BigEndian && !LittleEndian)
128 return std::nullopt;
129 }
130
131 assert((BigEndian != LittleEndian) &&
132 "Pattern cannot be both big and little endian!");
133 return BigEndian;
134 }
135
136 bool CombinerHelper::isPreLegalize() const { return IsPreLegalize; }
137
138 bool CombinerHelper::isLegal(const LegalityQuery &Query) const {
139 assert(LI && "Must have LegalizerInfo to query isLegal!");
140 return LI->getAction(Query).Action == LegalizeActions::Legal;
141 }
142
143 bool CombinerHelper::isLegalOrBeforeLegalizer(
144 const LegalityQuery &Query) const {
145 return isPreLegalize() || isLegal(Query);
146 }
147
148 bool CombinerHelper::isConstantLegalOrBeforeLegalizer(const LLT Ty) const {
149 if (!Ty.isVector())
150 return isLegalOrBeforeLegalizer({TargetOpcode::G_CONSTANT, {Ty}});
151 // Vector constants are represented as a G_BUILD_VECTOR of scalar G_CONSTANTs.
152 if (isPreLegalize())
153 return true;
154 LLT EltTy = Ty.getElementType();
155 return isLegal({TargetOpcode::G_BUILD_VECTOR, {Ty, EltTy}}) &&
156 isLegal({TargetOpcode::G_CONSTANT, {EltTy}});
157 }
158
159 void CombinerHelper::replaceRegWith(MachineRegisterInfo &MRI, Register FromReg,
160 Register ToReg) const {
161 Observer.changingAllUsesOfReg(MRI, FromReg);
162
163 if (MRI.constrainRegAttrs(ToReg, FromReg))
164 MRI.replaceRegWith(FromReg, ToReg);
165 else
166 Builder.buildCopy(ToReg, FromReg);
167
168 Observer.finishedChangingAllUsesOfReg();
169 }
170
171 void CombinerHelper::replaceRegOpWith(MachineRegisterInfo &MRI,
172 MachineOperand &FromRegOp,
173 Register ToReg) const {
174 assert(FromRegOp.getParent() && "Expected an operand in an MI");
175 Observer.changingInstr(*FromRegOp.getParent());
176
177 FromRegOp.setReg(ToReg);
178
179 Observer.changedInstr(*FromRegOp.getParent());
180 }
181
182 void CombinerHelper::replaceOpcodeWith(MachineInstr &FromMI,
183 unsigned ToOpcode) const {
184 Observer.changingInstr(FromMI);
185
186 FromMI.setDesc(Builder.getTII().get(ToOpcode));
187
188 Observer.changedInstr(FromMI);
189 }
190
191 const RegisterBank *CombinerHelper::getRegBank(Register Reg) const {
192 return RBI->getRegBank(Reg, MRI, *TRI);
193 }
194
195 void CombinerHelper::setRegBank(Register Reg, const RegisterBank *RegBank) {
196 if (RegBank)
197 MRI.setRegBank(Reg, *RegBank);
198 }
199
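// Fold a COPY by replacing all uses of its destination register with its
// source register when canReplaceReg says the two registers are compatible.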
200 bool CombinerHelper::tryCombineCopy(MachineInstr &MI) {
201 if (matchCombineCopy(MI)) {
202 applyCombineCopy(MI);
203 return true;
204 }
205 return false;
206 }
207 bool CombinerHelper::matchCombineCopy(MachineInstr &MI) {
208 if (MI.getOpcode() != TargetOpcode::COPY)
209 return false;
210 Register DstReg = MI.getOperand(0).getReg();
211 Register SrcReg = MI.getOperand(1).getReg();
212 return canReplaceReg(DstReg, SrcReg, MRI);
213 }
214 void CombinerHelper::applyCombineCopy(MachineInstr &MI) {
215 Register DstReg = MI.getOperand(0).getReg();
216 Register SrcReg = MI.getOperand(1).getReg();
217 MI.eraseFromParent();
218 replaceRegWith(MRI, DstReg, SrcReg);
219 }
220
221 bool CombinerHelper::tryCombineConcatVectors(MachineInstr &MI) {
222 bool IsUndef = false;
223 SmallVector<Register, 4> Ops;
224 if (matchCombineConcatVectors(MI, IsUndef, Ops)) {
225 applyCombineConcatVectors(MI, IsUndef, Ops);
226 return true;
227 }
228 return false;
229 }
230
231 bool CombinerHelper::matchCombineConcatVectors(MachineInstr &MI, bool &IsUndef,
232 SmallVectorImpl<Register> &Ops) {
233 assert(MI.getOpcode() == TargetOpcode::G_CONCAT_VECTORS &&
234 "Invalid instruction");
235 IsUndef = true;
236 MachineInstr *Undef = nullptr;
237
238 // Walk over all the operands of concat vectors and check if they are
239 // build_vector themselves or undef.
240 // Then collect their operands in Ops.
241 for (const MachineOperand &MO : MI.uses()) {
242 Register Reg = MO.getReg();
243 MachineInstr *Def = MRI.getVRegDef(Reg);
244 assert(Def && "Operand not defined");
245 switch (Def->getOpcode()) {
246 case TargetOpcode::G_BUILD_VECTOR:
247 IsUndef = false;
248 // Remember the operands of the build_vector to fold
249 // them into the yet-to-build flattened concat vectors.
250 for (const MachineOperand &BuildVecMO : Def->uses())
251 Ops.push_back(BuildVecMO.getReg());
252 break;
253 case TargetOpcode::G_IMPLICIT_DEF: {
254 LLT OpType = MRI.getType(Reg);
255 // Keep one undef value for all the undef operands.
256 if (!Undef) {
257 Builder.setInsertPt(*MI.getParent(), MI);
258 Undef = Builder.buildUndef(OpType.getScalarType());
259 }
260 assert(MRI.getType(Undef->getOperand(0).getReg()) ==
261 OpType.getScalarType() &&
262 "All undefs should have the same type");
263 // Break the undef vector in as many scalar elements as needed
264 // for the flattening.
265 for (unsigned EltIdx = 0, EltEnd = OpType.getNumElements();
266 EltIdx != EltEnd; ++EltIdx)
267 Ops.push_back(Undef->getOperand(0).getReg());
268 break;
269 }
270 default:
271 return false;
272 }
273 }
274 return true;
275 }
276 void CombinerHelper::applyCombineConcatVectors(
277 MachineInstr &MI, bool IsUndef, const ArrayRef<Register> Ops) {
278 // We determined that the concat_vectors can be flattened.
279 // Generate the flattened build_vector.
280 Register DstReg = MI.getOperand(0).getReg();
281 Builder.setInsertPt(*MI.getParent(), MI);
282 Register NewDstReg = MRI.cloneVirtualRegister(DstReg);
283
284 // Note: IsUndef is sort of redundant. We could have determined it by
285 // checking that all Ops are undef. Alternatively, we could have
286 // generated a build_vector of undefs and relied on another combine to
287 // clean that up. For now, given we already gather this information
288 // in tryCombineConcatVectors, just save compile time and issue the
289 // right thing.
290 if (IsUndef)
291 Builder.buildUndef(NewDstReg);
292 else
293 Builder.buildBuildVector(NewDstReg, Ops);
294 MI.eraseFromParent();
295 replaceRegWith(MRI, DstReg, NewDstReg);
296 }
297
298 bool CombinerHelper::tryCombineShuffleVector(MachineInstr &MI) {
299 SmallVector<Register, 4> Ops;
300 if (matchCombineShuffleVector(MI, Ops)) {
301 applyCombineShuffleVector(MI, Ops);
302 return true;
303 }
304 return false;
305 }
306
307 bool CombinerHelper::matchCombineShuffleVector(MachineInstr &MI,
308 SmallVectorImpl<Register> &Ops) {
309 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
310 "Invalid instruction kind");
311 LLT DstType = MRI.getType(MI.getOperand(0).getReg());
312 Register Src1 = MI.getOperand(1).getReg();
313 LLT SrcType = MRI.getType(Src1);
314 // As bizarre as it may look, shuffle vector can actually produce
315 // scalar! This is because at the IR level a <1 x ty> shuffle
316 // vector is perfectly valid.
317 unsigned DstNumElts = DstType.isVector() ? DstType.getNumElements() : 1;
318 unsigned SrcNumElts = SrcType.isVector() ? SrcType.getNumElements() : 1;
319
320 // If the resulting vector is smaller than the size of the source
321 // vectors being concatenated, we won't be able to replace the
322 // shuffle vector with a concat_vectors.
323 //
324 // Note: We may still be able to produce a concat_vectors fed by
325 // extract_vector_elt and so on. It is less clear that would
326 // be better though, so don't bother for now.
327 //
328 // If the destination is a scalar, the size of the sources doesn't
329 // matter. We will lower the shuffle to a plain copy. This will
330 // work only if the source and destination have the same size. But
331 // that's covered by the next condition.
332 //
333 // TODO: If the sizes of the source and destination don't match
334 // we could still emit an extract vector element in that case.
335 if (DstNumElts < 2 * SrcNumElts && DstNumElts != 1)
336 return false;
337
338 // Check that the shuffle mask can be broken evenly between the
339 // different sources.
340 if (DstNumElts % SrcNumElts != 0)
341 return false;
342
343 // Mask length is a multiple of the source vector length.
344 // Check if the shuffle is some kind of concatenation of the input
345 // vectors.
346 unsigned NumConcat = DstNumElts / SrcNumElts;
347 SmallVector<int, 8> ConcatSrcs(NumConcat, -1);
348 ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
349 for (unsigned i = 0; i != DstNumElts; ++i) {
350 int Idx = Mask[i];
351 // Undef value.
352 if (Idx < 0)
353 continue;
354 // Ensure the indices in each SrcType sized piece are sequential and that
355 // the same source is used for the whole piece.
356 if ((Idx % SrcNumElts != (i % SrcNumElts)) ||
357 (ConcatSrcs[i / SrcNumElts] >= 0 &&
358 ConcatSrcs[i / SrcNumElts] != (int)(Idx / SrcNumElts)))
359 return false;
360 // Remember which source this index came from.
361 ConcatSrcs[i / SrcNumElts] = Idx / SrcNumElts;
362 }
363
364 // The shuffle is concatenating multiple vectors together.
365 // Collect the different operands for that.
366 Register UndefReg;
367 Register Src2 = MI.getOperand(2).getReg();
368 for (auto Src : ConcatSrcs) {
369 if (Src < 0) {
370 if (!UndefReg) {
371 Builder.setInsertPt(*MI.getParent(), MI);
372 UndefReg = Builder.buildUndef(SrcType).getReg(0);
373 }
374 Ops.push_back(UndefReg);
375 } else if (Src == 0)
376 Ops.push_back(Src1);
377 else
378 Ops.push_back(Src2);
379 }
380 return true;
381 }
382
383 void CombinerHelper::applyCombineShuffleVector(MachineInstr &MI,
384 const ArrayRef<Register> Ops) {
385 Register DstReg = MI.getOperand(0).getReg();
386 Builder.setInsertPt(*MI.getParent(), MI);
387 Register NewDstReg = MRI.cloneVirtualRegister(DstReg);
388
389 if (Ops.size() == 1)
390 Builder.buildCopy(NewDstReg, Ops[0]);
391 else
392 Builder.buildMergeLikeInstr(NewDstReg, Ops);
393
394 MI.eraseFromParent();
395 replaceRegWith(MRI, DstReg, NewDstReg);
396 }
397
398 namespace {
399
400 /// Select a preference between two uses. CurrentUse is the current preference
401 /// while *ForCandidate is attributes of the candidate under consideration.
402 PreferredTuple ChoosePreferredUse(PreferredTuple &CurrentUse,
403 const LLT TyForCandidate,
404 unsigned OpcodeForCandidate,
405 MachineInstr *MIForCandidate) {
406 if (!CurrentUse.Ty.isValid()) {
407 if (CurrentUse.ExtendOpcode == OpcodeForCandidate ||
408 CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT)
409 return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
410 return CurrentUse;
411 }
412
413 // We permit the extend to hoist through basic blocks but this is only
414 // sensible if the target has extending loads. If you end up lowering back
415 // into a load and extend during the legalizer then the end result is
416 // hoisting the extend up to the load.
417
418 // Prefer defined extensions to undefined extensions as these are more
419 // likely to reduce the number of instructions.
420 if (OpcodeForCandidate == TargetOpcode::G_ANYEXT &&
421 CurrentUse.ExtendOpcode != TargetOpcode::G_ANYEXT)
422 return CurrentUse;
423 else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT &&
424 OpcodeForCandidate != TargetOpcode::G_ANYEXT)
425 return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
426
427 // Prefer sign extensions to zero extensions as sign-extensions tend to be
428 // more expensive.
429 if (CurrentUse.Ty == TyForCandidate) {
430 if (CurrentUse.ExtendOpcode == TargetOpcode::G_SEXT &&
431 OpcodeForCandidate == TargetOpcode::G_ZEXT)
432 return CurrentUse;
433 else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ZEXT &&
434 OpcodeForCandidate == TargetOpcode::G_SEXT)
435 return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
436 }
437
438 // This is potentially target specific. We've chosen the largest type
439 // because G_TRUNC is usually free. One potential catch with this is that
440 // some targets have a reduced number of larger registers than smaller
441 // registers and this choice potentially increases the live-range for the
442 // larger value.
443 if (TyForCandidate.getSizeInBits() > CurrentUse.Ty.getSizeInBits()) {
444 return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
445 }
446 return CurrentUse;
447 }
448
449 /// Find a suitable place to insert some instructions and insert them. This
450 /// function accounts for special cases like inserting before a PHI node.
451 /// The current strategy for inserting before PHI's is to duplicate the
452 /// instructions for each predecessor. However, while that's ok for G_TRUNC
453 /// on most targets since it generally requires no code, other targets/cases may
454 /// want to try harder to find a dominating block.
455 static void InsertInsnsWithoutSideEffectsBeforeUse(
456 MachineIRBuilder &Builder, MachineInstr &DefMI, MachineOperand &UseMO,
457 std::function<void(MachineBasicBlock *, MachineBasicBlock::iterator,
458 MachineOperand &UseMO)>
459 Inserter) {
460 MachineInstr &UseMI = *UseMO.getParent();
461
462 MachineBasicBlock *InsertBB = UseMI.getParent();
463
464 // If the use is a PHI then we want the predecessor block instead.
465 if (UseMI.isPHI()) {
466 MachineOperand *PredBB = std::next(&UseMO);
467 InsertBB = PredBB->getMBB();
468 }
469
470 // If the block is the same block as the def then we want to insert just after
471 // the def instead of at the start of the block.
472 if (InsertBB == DefMI.getParent()) {
473 MachineBasicBlock::iterator InsertPt = &DefMI;
474 Inserter(InsertBB, std::next(InsertPt), UseMO);
475 return;
476 }
477
478 // Otherwise we want the start of the BB
479 Inserter(InsertBB, InsertBB->getFirstNonPHI(), UseMO);
480 }
481 } // end anonymous namespace
482
483 bool CombinerHelper::tryCombineExtendingLoads(MachineInstr &MI) {
484 PreferredTuple Preferred;
485 if (matchCombineExtendingLoads(MI, Preferred)) {
486 applyCombineExtendingLoads(MI, Preferred);
487 return true;
488 }
489 return false;
490 }
491
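/// Map an extend opcode to the corresponding extending load opcode:
/// G_ANYEXT -> G_LOAD, G_SEXT -> G_SEXTLOAD, G_ZEXT -> G_ZEXTLOAD.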
492 static unsigned getExtLoadOpcForExtend(unsigned ExtOpc) {
493 unsigned CandidateLoadOpc;
494 switch (ExtOpc) {
495 case TargetOpcode::G_ANYEXT:
496 CandidateLoadOpc = TargetOpcode::G_LOAD;
497 break;
498 case TargetOpcode::G_SEXT:
499 CandidateLoadOpc = TargetOpcode::G_SEXTLOAD;
500 break;
501 case TargetOpcode::G_ZEXT:
502 CandidateLoadOpc = TargetOpcode::G_ZEXTLOAD;
503 break;
504 default:
505 llvm_unreachable("Unexpected extend opc");
506 }
507 return CandidateLoadOpc;
508 }
509
510 bool CombinerHelper::matchCombineExtendingLoads(MachineInstr &MI,
511 PreferredTuple &Preferred) {
512 // We match the loads and follow the uses to the extend instead of matching
513 // the extends and following the def to the load. This is because the load
514 // must remain in the same position for correctness (unless we also add code
515 // to find a safe place to sink it) whereas the extend is freely movable.
516 // It also prevents us from duplicating the load for the volatile case or just
517 // for performance.
518 GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(&MI);
519 if (!LoadMI)
520 return false;
521
522 Register LoadReg = LoadMI->getDstReg();
523
524 LLT LoadValueTy = MRI.getType(LoadReg);
525 if (!LoadValueTy.isScalar())
526 return false;
527
528 // Most architectures are going to legalize <s8 loads into at least a 1 byte
529 // load, and the MMOs can only describe memory accesses in multiples of bytes.
530 // If we try to perform extload combining on those, we can end up with
531 // %a(s8) = extload %ptr (load 1 byte from %ptr)
532 // ... which is an illegal extload instruction.
533 if (LoadValueTy.getSizeInBits() < 8)
534 return false;
535
536 // For non power-of-2 types, they will very likely be legalized into multiple
537 // loads. Don't bother trying to match them into extending loads.
538 if (!isPowerOf2_32(LoadValueTy.getSizeInBits()))
539 return false;
540
541 // Find the preferred type aside from the any-extends (unless it's the only
542 // one) and non-extending ops. We'll emit an extending load to that type and
543 // emit a variant of (extend (trunc X)) for the others according to the
544 // relative type sizes. At the same time, pick an extend to use based on the
545 // extend involved in the chosen type.
546 unsigned PreferredOpcode =
547 isa<GLoad>(&MI)
548 ? TargetOpcode::G_ANYEXT
549 : isa<GSExtLoad>(&MI) ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
550 Preferred = {LLT(), PreferredOpcode, nullptr};
551 for (auto &UseMI : MRI.use_nodbg_instructions(LoadReg)) {
552 if (UseMI.getOpcode() == TargetOpcode::G_SEXT ||
553 UseMI.getOpcode() == TargetOpcode::G_ZEXT ||
554 (UseMI.getOpcode() == TargetOpcode::G_ANYEXT)) {
555 const auto &MMO = LoadMI->getMMO();
556 // For atomics, only form anyextending loads.
557 if (MMO.isAtomic() && UseMI.getOpcode() != TargetOpcode::G_ANYEXT)
558 continue;
559 // Check for legality.
560 if (!isPreLegalize()) {
561 LegalityQuery::MemDesc MMDesc(MMO);
562 unsigned CandidateLoadOpc = getExtLoadOpcForExtend(UseMI.getOpcode());
563 LLT UseTy = MRI.getType(UseMI.getOperand(0).getReg());
564 LLT SrcTy = MRI.getType(LoadMI->getPointerReg());
565 if (LI->getAction({CandidateLoadOpc, {UseTy, SrcTy}, {MMDesc}})
566 .Action != LegalizeActions::Legal)
567 continue;
568 }
569 Preferred = ChoosePreferredUse(Preferred,
570 MRI.getType(UseMI.getOperand(0).getReg()),
571 UseMI.getOpcode(), &UseMI);
572 }
573 }
574
575 // There were no extends
576 if (!Preferred.MI)
577 return false;
578 // It should be impossible to choose an extend without selecting a different
579 // type since by definition the result of an extend is larger.
580 assert(Preferred.Ty != LoadValueTy && "Extending to same type?");
581
582 LLVM_DEBUG(dbgs() << "Preferred use is: " << *Preferred.MI);
583 return true;
584 }
585
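/// Rewrite the load as the chosen extending load and fix up every user of the
/// old result: extends of the preferred type are merged into the new load,
/// wider extends are fed from its result, and any remaining uses go through a
/// truncate back to the originally loaded type (CSE'd per basic block).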
586 void CombinerHelper::applyCombineExtendingLoads(MachineInstr &MI,
587 PreferredTuple &Preferred) {
588 // Rewrite the load to the chosen extending load.
589 Register ChosenDstReg = Preferred.MI->getOperand(0).getReg();
590
591 // Inserter to insert a truncate back to the original type at a given point
592 // with some basic CSE to limit truncate duplication to one per BB.
593 DenseMap<MachineBasicBlock *, MachineInstr *> EmittedInsns;
594 auto InsertTruncAt = [&](MachineBasicBlock *InsertIntoBB,
595 MachineBasicBlock::iterator InsertBefore,
596 MachineOperand &UseMO) {
597 MachineInstr *PreviouslyEmitted = EmittedInsns.lookup(InsertIntoBB);
598 if (PreviouslyEmitted) {
599 Observer.changingInstr(*UseMO.getParent());
600 UseMO.setReg(PreviouslyEmitted->getOperand(0).getReg());
601 Observer.changedInstr(*UseMO.getParent());
602 return;
603 }
604
605 Builder.setInsertPt(*InsertIntoBB, InsertBefore);
606 Register NewDstReg = MRI.cloneVirtualRegister(MI.getOperand(0).getReg());
607 MachineInstr *NewMI = Builder.buildTrunc(NewDstReg, ChosenDstReg);
608 EmittedInsns[InsertIntoBB] = NewMI;
609 replaceRegOpWith(MRI, UseMO, NewDstReg);
610 };
611
612 Observer.changingInstr(MI);
613 unsigned LoadOpc = getExtLoadOpcForExtend(Preferred.ExtendOpcode);
614 MI.setDesc(Builder.getTII().get(LoadOpc));
615
616 // Rewrite all the uses to fix up the types.
617 auto &LoadValue = MI.getOperand(0);
618 SmallVector<MachineOperand *, 4> Uses;
619 for (auto &UseMO : MRI.use_operands(LoadValue.getReg()))
620 Uses.push_back(&UseMO);
621
622 for (auto *UseMO : Uses) {
623 MachineInstr *UseMI = UseMO->getParent();
624
625 // If the extend is compatible with the preferred extend then we should fix
626 // up the type and extend so that it uses the preferred use.
627 if (UseMI->getOpcode() == Preferred.ExtendOpcode ||
628 UseMI->getOpcode() == TargetOpcode::G_ANYEXT) {
629 Register UseDstReg = UseMI->getOperand(0).getReg();
630 MachineOperand &UseSrcMO = UseMI->getOperand(1);
631 const LLT UseDstTy = MRI.getType(UseDstReg);
632 if (UseDstReg != ChosenDstReg) {
633 if (Preferred.Ty == UseDstTy) {
634 // If the use has the same type as the preferred use, then merge
635 // the vregs and erase the extend. For example:
636 // %1:_(s8) = G_LOAD ...
637 // %2:_(s32) = G_SEXT %1(s8)
638 // %3:_(s32) = G_ANYEXT %1(s8)
639 // ... = ... %3(s32)
640 // rewrites to:
641 // %2:_(s32) = G_SEXTLOAD ...
642 // ... = ... %2(s32)
643 replaceRegWith(MRI, UseDstReg, ChosenDstReg);
644 Observer.erasingInstr(*UseMO->getParent());
645 UseMO->getParent()->eraseFromParent();
646 } else if (Preferred.Ty.getSizeInBits() < UseDstTy.getSizeInBits()) {
647 // If the preferred size is smaller, then keep the extend but extend
648 // from the result of the extending load. For example:
649 // %1:_(s8) = G_LOAD ...
650 // %2:_(s32) = G_SEXT %1(s8)
651 // %3:_(s64) = G_ANYEXT %1(s8)
652 // ... = ... %3(s64)
653 // rewrites to:
654 // %2:_(s32) = G_SEXTLOAD ...
655 // %3:_(s64) = G_ANYEXT %2:_(s32)
656 // ... = ... %3(s64)
657 replaceRegOpWith(MRI, UseSrcMO, ChosenDstReg);
658 } else {
659 // If the preferred size is large, then insert a truncate. For
660 // example:
661 // %1:_(s8) = G_LOAD ...
662 // %2:_(s64) = G_SEXT %1(s8)
663 // %3:_(s32) = G_ZEXT %1(s8)
664 // ... = ... %3(s32)
665 // rewrites to:
666 // %2:_(s64) = G_SEXTLOAD ...
667 // %4:_(s8) = G_TRUNC %2:_(s64)
668 // %3:_(s32) = G_ZEXT %4:_(s8)
669 // ... = ... %3(s32)
670 InsertInsnsWithoutSideEffectsBeforeUse(Builder, MI, *UseMO,
671 InsertTruncAt);
672 }
673 continue;
674 }
675 // The use is (one of) the uses of the preferred use we chose earlier.
676 // We're going to update the load to def this value later so just erase
677 // the old extend.
678 Observer.erasingInstr(*UseMO->getParent());
679 UseMO->getParent()->eraseFromParent();
680 continue;
681 }
682
683 // The use isn't an extend. Truncate back to the type we originally loaded.
684 // This is free on many targets.
685 InsertInsnsWithoutSideEffectsBeforeUse(Builder, MI, *UseMO, InsertTruncAt);
686 }
687
688 MI.getOperand(0).setReg(ChosenDstReg);
689 Observer.changedInstr(MI);
690 }
691
692 bool CombinerHelper::matchCombineLoadWithAndMask(MachineInstr &MI,
693 BuildFnTy &MatchInfo) {
694 assert(MI.getOpcode() == TargetOpcode::G_AND);
695
696 // If we have the following code:
697 // %mask = G_CONSTANT 255
698 // %ld = G_LOAD %ptr, (load s16)
699 // %and = G_AND %ld, %mask
700 //
701 // Try to fold it into
702 // %ld = G_ZEXTLOAD %ptr, (load s8)
703
704 Register Dst = MI.getOperand(0).getReg();
705 if (MRI.getType(Dst).isVector())
706 return false;
707
708 auto MaybeMask =
709 getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
710 if (!MaybeMask)
711 return false;
712
713 APInt MaskVal = MaybeMask->Value;
714
715 if (!MaskVal.isMask())
716 return false;
717
718 Register SrcReg = MI.getOperand(1).getReg();
719 // Don't use getOpcodeDef() here since intermediate instructions may have
720 // multiple users.
721 GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(MRI.getVRegDef(SrcReg));
722 if (!LoadMI || !MRI.hasOneNonDBGUse(LoadMI->getDstReg()))
723 return false;
724
725 Register LoadReg = LoadMI->getDstReg();
726 LLT RegTy = MRI.getType(LoadReg);
727 Register PtrReg = LoadMI->getPointerReg();
728 unsigned RegSize = RegTy.getSizeInBits();
729 uint64_t LoadSizeBits = LoadMI->getMemSizeInBits();
730 unsigned MaskSizeBits = MaskVal.countTrailingOnes();
731
732 // The mask may not be larger than the in-memory type, as it might cover sign
733 // extended bits
734 if (MaskSizeBits > LoadSizeBits)
735 return false;
736
737 // If the mask covers the whole destination register, there's nothing to
738 // extend
739 if (MaskSizeBits >= RegSize)
740 return false;
741
742 // Most targets cannot deal with loads of size < 8 and need to re-legalize to
743 // at least byte loads. Avoid creating such loads here
744 if (MaskSizeBits < 8 || !isPowerOf2_32(MaskSizeBits))
745 return false;
746
747 const MachineMemOperand &MMO = LoadMI->getMMO();
748 LegalityQuery::MemDesc MemDesc(MMO);
749
750 // Don't modify the memory access size if this is atomic/volatile, but we can
751 // still adjust the opcode to indicate the high bit behavior.
752 if (LoadMI->isSimple())
753 MemDesc.MemoryTy = LLT::scalar(MaskSizeBits);
754 else if (LoadSizeBits > MaskSizeBits || LoadSizeBits == RegSize)
755 return false;
756
757 // TODO: Could check if it's legal with the reduced or original memory size.
758 if (!isLegalOrBeforeLegalizer(
759 {TargetOpcode::G_ZEXTLOAD, {RegTy, MRI.getType(PtrReg)}, {MemDesc}}))
760 return false;
761
762 MatchInfo = [=](MachineIRBuilder &B) {
763 B.setInstrAndDebugLoc(*LoadMI);
764 auto &MF = B.getMF();
765 auto PtrInfo = MMO.getPointerInfo();
766 auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, MemDesc.MemoryTy);
767 B.buildLoadInstr(TargetOpcode::G_ZEXTLOAD, Dst, PtrReg, *NewMMO);
768 LoadMI->eraseFromParent();
769 };
770 return true;
771 }
772
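/// Return true if \p DefMI is \p UseMI or appears before it in their shared
/// basic block.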
773 bool CombinerHelper::isPredecessor(const MachineInstr &DefMI,
774 const MachineInstr &UseMI) {
775 assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() &&
776 "shouldn't consider debug uses");
777 assert(DefMI.getParent() == UseMI.getParent());
778 if (&DefMI == &UseMI)
779 return true;
780 const MachineBasicBlock &MBB = *DefMI.getParent();
781 auto DefOrUse = find_if(MBB, [&DefMI, &UseMI](const MachineInstr &MI) {
782 return &MI == &DefMI || &MI == &UseMI;
783 });
784 if (DefOrUse == MBB.end())
785 llvm_unreachable("Block must contain both DefMI and UseMI!");
786 return &*DefOrUse == &DefMI;
787 }
788
789 bool CombinerHelper::dominates(const MachineInstr &DefMI,
790 const MachineInstr &UseMI) {
791 assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() &&
792 "shouldn't consider debug uses");
793 if (MDT)
794 return MDT->dominates(&DefMI, &UseMI);
795 else if (DefMI.getParent() != UseMI.getParent())
796 return false;
797
798 return isPredecessor(DefMI, UseMI);
799 }
800
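/// Match a G_SEXT_INREG whose source (looking through an optional G_TRUNC) is
/// a G_SEXTLOAD that already sign-extended from the same bit width; the
/// G_SEXT_INREG is then redundant and can be replaced by a copy of its source.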
801 bool CombinerHelper::matchSextTruncSextLoad(MachineInstr &MI) {
802 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
803 Register SrcReg = MI.getOperand(1).getReg();
804 Register LoadUser = SrcReg;
805
806 if (MRI.getType(SrcReg).isVector())
807 return false;
808
809 Register TruncSrc;
810 if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc))))
811 LoadUser = TruncSrc;
812
813 uint64_t SizeInBits = MI.getOperand(2).getImm();
814 // If the source is a G_SEXTLOAD from the same bit width, then we don't
815 // need any extend at all, just a truncate.
816 if (auto *LoadMI = getOpcodeDef<GSExtLoad>(LoadUser, MRI)) {
817 // If truncating more than the original extended value, abort.
818 auto LoadSizeBits = LoadMI->getMemSizeInBits();
819 if (TruncSrc && MRI.getType(TruncSrc).getSizeInBits() < LoadSizeBits)
820 return false;
821 if (LoadSizeBits == SizeInBits)
822 return true;
823 }
824 return false;
825 }
826
827 void CombinerHelper::applySextTruncSextLoad(MachineInstr &MI) {
828 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
829 Builder.setInstrAndDebugLoc(MI);
830 Builder.buildCopy(MI.getOperand(0).getReg(), MI.getOperand(1).getReg());
831 MI.eraseFromParent();
832 }
833
834 bool CombinerHelper::matchSextInRegOfLoad(
835 MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) {
836 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
837
838 Register DstReg = MI.getOperand(0).getReg();
839 LLT RegTy = MRI.getType(DstReg);
840
841 // Only supports scalars for now.
842 if (RegTy.isVector())
843 return false;
844
845 Register SrcReg = MI.getOperand(1).getReg();
846 auto *LoadDef = getOpcodeDef<GLoad>(SrcReg, MRI);
847 if (!LoadDef || !MRI.hasOneNonDBGUse(DstReg))
848 return false;
849
850 uint64_t MemBits = LoadDef->getMemSizeInBits();
851
852 // If the sign extend extends from a narrower width than the load's width,
853 // then we can narrow the load width when we combine to a G_SEXTLOAD.
854 // Avoid widening the load at all.
855 unsigned NewSizeBits = std::min((uint64_t)MI.getOperand(2).getImm(), MemBits);
856
857 // Don't generate G_SEXTLOADs with a < 1 byte width.
858 if (NewSizeBits < 8)
859 return false;
860 // Don't bother creating a non-power-2 sextload, it will likely be broken up
861 // anyway for most targets.
862 if (!isPowerOf2_32(NewSizeBits))
863 return false;
864
865 const MachineMemOperand &MMO = LoadDef->getMMO();
866 LegalityQuery::MemDesc MMDesc(MMO);
867
868 // Don't modify the memory access size if this is atomic/volatile, but we can
869 // still adjust the opcode to indicate the high bit behavior.
870 if (LoadDef->isSimple())
871 MMDesc.MemoryTy = LLT::scalar(NewSizeBits);
872 else if (MemBits > NewSizeBits || MemBits == RegTy.getSizeInBits())
873 return false;
874
875 // TODO: Could check if it's legal with the reduced or original memory size.
876 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SEXTLOAD,
877 {MRI.getType(LoadDef->getDstReg()),
878 MRI.getType(LoadDef->getPointerReg())},
879 {MMDesc}}))
880 return false;
881
882 MatchInfo = std::make_tuple(LoadDef->getDstReg(), NewSizeBits);
883 return true;
884 }
885
886 void CombinerHelper::applySextInRegOfLoad(
887 MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) {
888 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
889 Register LoadReg;
890 unsigned ScalarSizeBits;
891 std::tie(LoadReg, ScalarSizeBits) = MatchInfo;
892 GLoad *LoadDef = cast<GLoad>(MRI.getVRegDef(LoadReg));
893
894 // If we have the following:
895 // %ld = G_LOAD %ptr, (load 2)
896 // %ext = G_SEXT_INREG %ld, 8
897 // ==>
898 // %ld = G_SEXTLOAD %ptr (load 1)
899
900 auto &MMO = LoadDef->getMMO();
901 Builder.setInstrAndDebugLoc(*LoadDef);
902 auto &MF = Builder.getMF();
903 auto PtrInfo = MMO.getPointerInfo();
904 auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, ScalarSizeBits / 8);
905 Builder.buildLoadInstr(TargetOpcode::G_SEXTLOAD, MI.getOperand(0).getReg(),
906 LoadDef->getPointerReg(), *NewMMO);
907 MI.eraseFromParent();
908 }
909
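/// Look for a G_PTR_ADD user of \p MI's base pointer that could serve as the
/// writeback of a post-indexed access. On success, \p Addr is the G_PTR_ADD's
/// result and \p Base / \p Offset are its inputs.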
910 bool CombinerHelper::findPostIndexCandidate(MachineInstr &MI, Register &Addr,
911 Register &Base, Register &Offset) {
912 auto &MF = *MI.getParent()->getParent();
913 const auto &TLI = *MF.getSubtarget().getTargetLowering();
914
915 #ifndef NDEBUG
916 unsigned Opcode = MI.getOpcode();
917 assert(Opcode == TargetOpcode::G_LOAD || Opcode == TargetOpcode::G_SEXTLOAD ||
918 Opcode == TargetOpcode::G_ZEXTLOAD || Opcode == TargetOpcode::G_STORE);
919 #endif
920
921 Base = MI.getOperand(1).getReg();
922 MachineInstr *BaseDef = MRI.getUniqueVRegDef(Base);
923 if (BaseDef && BaseDef->getOpcode() == TargetOpcode::G_FRAME_INDEX)
924 return false;
925
926 LLVM_DEBUG(dbgs() << "Searching for post-indexing opportunity for: " << MI);
927 // FIXME: The following use traversal needs a bail-out for pathological cases.
928 for (auto &Use : MRI.use_nodbg_instructions(Base)) {
929 if (Use.getOpcode() != TargetOpcode::G_PTR_ADD)
930 continue;
931
932 Offset = Use.getOperand(2).getReg();
933 if (!ForceLegalIndexing &&
934 !TLI.isIndexingLegal(MI, Base, Offset, /*IsPre*/ false, MRI)) {
935 LLVM_DEBUG(dbgs() << " Ignoring candidate with illegal addrmode: "
936 << Use);
937 continue;
938 }
939
940 // Make sure the offset calculation is before the potentially indexed op.
941 // FIXME: we really care about dependency here. The offset calculation might
942 // be movable.
943 MachineInstr *OffsetDef = MRI.getUniqueVRegDef(Offset);
944 if (!OffsetDef || !dominates(*OffsetDef, MI)) {
945 LLVM_DEBUG(dbgs() << " Ignoring candidate with offset after mem-op: "
946 << Use);
947 continue;
948 }
949
950 // FIXME: check whether all uses of Base are load/store with foldable
951 // addressing modes. If so, using the normal addr-modes is better than
952 // forming an indexed one.
953
954 bool MemOpDominatesAddrUses = true;
955 for (auto &PtrAddUse :
956 MRI.use_nodbg_instructions(Use.getOperand(0).getReg())) {
957 if (!dominates(MI, PtrAddUse)) {
958 MemOpDominatesAddrUses = false;
959 break;
960 }
961 }
962
963 if (!MemOpDominatesAddrUses) {
964 LLVM_DEBUG(
965 dbgs() << " Ignoring candidate as memop does not dominate uses: "
966 << Use);
967 continue;
968 }
969
970 LLVM_DEBUG(dbgs() << " Found match: " << Use);
971 Addr = Use.getOperand(0).getReg();
972 return true;
973 }
974
975 return false;
976 }
977
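/// Check whether \p MI's address is a G_PTR_ADD that could be folded into a
/// pre-indexed access, i.e. the combined instruction would dominate all other
/// users of that address.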
978 bool CombinerHelper::findPreIndexCandidate(MachineInstr &MI, Register &Addr,
979 Register &Base, Register &Offset) {
980 auto &MF = *MI.getParent()->getParent();
981 const auto &TLI = *MF.getSubtarget().getTargetLowering();
982
983 #ifndef NDEBUG
984 unsigned Opcode = MI.getOpcode();
985 assert(Opcode == TargetOpcode::G_LOAD || Opcode == TargetOpcode::G_SEXTLOAD ||
986 Opcode == TargetOpcode::G_ZEXTLOAD || Opcode == TargetOpcode::G_STORE);
987 #endif
988
989 Addr = MI.getOperand(1).getReg();
990 MachineInstr *AddrDef = getOpcodeDef(TargetOpcode::G_PTR_ADD, Addr, MRI);
991 if (!AddrDef || MRI.hasOneNonDBGUse(Addr))
992 return false;
993
994 Base = AddrDef->getOperand(1).getReg();
995 Offset = AddrDef->getOperand(2).getReg();
996
997 LLVM_DEBUG(dbgs() << "Found potential pre-indexed load_store: " << MI);
998
999 if (!ForceLegalIndexing &&
1000 !TLI.isIndexingLegal(MI, Base, Offset, /*IsPre*/ true, MRI)) {
1001 LLVM_DEBUG(dbgs() << " Skipping, not legal for target");
1002 return false;
1003 }
1004
1005 MachineInstr *BaseDef = getDefIgnoringCopies(Base, MRI);
1006 if (BaseDef->getOpcode() == TargetOpcode::G_FRAME_INDEX) {
1007 LLVM_DEBUG(dbgs() << " Skipping, frame index would need copy anyway.");
1008 return false;
1009 }
1010
1011 if (MI.getOpcode() == TargetOpcode::G_STORE) {
1012 // Would require a copy.
1013 if (Base == MI.getOperand(0).getReg()) {
1014 LLVM_DEBUG(dbgs() << " Skipping, storing base so need copy anyway.");
1015 return false;
1016 }
1017
1018 // We're expecting one use of Addr in MI, but it could also be the
1019 // value stored, which isn't actually dominated by the instruction.
1020 if (MI.getOperand(0).getReg() == Addr) {
1021 LLVM_DEBUG(dbgs() << " Skipping, does not dominate all addr uses");
1022 return false;
1023 }
1024 }
1025
1026 // FIXME: check whether all uses of the base pointer are constant PtrAdds.
1027 // That might allow us to end base's liveness here by adjusting the constant.
1028
1029 for (auto &UseMI : MRI.use_nodbg_instructions(Addr)) {
1030 if (!dominates(MI, UseMI)) {
1031 LLVM_DEBUG(dbgs() << " Skipping, does not dominate all addr uses.");
1032 return false;
1033 }
1034 }
1035
1036 return true;
1037 }
1038
1039 bool CombinerHelper::tryCombineIndexedLoadStore(MachineInstr &MI) {
1040 IndexedLoadStoreMatchInfo MatchInfo;
1041 if (matchCombineIndexedLoadStore(MI, MatchInfo)) {
1042 applyCombineIndexedLoadStore(MI, MatchInfo);
1043 return true;
1044 }
1045 return false;
1046 }
1047
1048 bool CombinerHelper::matchCombineIndexedLoadStore(MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) {
1049 unsigned Opcode = MI.getOpcode();
1050 if (Opcode != TargetOpcode::G_LOAD && Opcode != TargetOpcode::G_SEXTLOAD &&
1051 Opcode != TargetOpcode::G_ZEXTLOAD && Opcode != TargetOpcode::G_STORE)
1052 return false;
1053
1054 // For now, no targets actually support these opcodes so don't waste time
1055 // running these unless we're forced to for testing.
1056 if (!ForceLegalIndexing)
1057 return false;
1058
1059 MatchInfo.IsPre = findPreIndexCandidate(MI, MatchInfo.Addr, MatchInfo.Base,
1060 MatchInfo.Offset);
1061 if (!MatchInfo.IsPre &&
1062 !findPostIndexCandidate(MI, MatchInfo.Addr, MatchInfo.Base,
1063 MatchInfo.Offset))
1064 return false;
1065
1066 return true;
1067 }
1068
1069 void CombinerHelper::applyCombineIndexedLoadStore(
1070 MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) {
1071 MachineInstr &AddrDef = *MRI.getUniqueVRegDef(MatchInfo.Addr);
1072 MachineIRBuilder MIRBuilder(MI);
1073 unsigned Opcode = MI.getOpcode();
1074 bool IsStore = Opcode == TargetOpcode::G_STORE;
1075 unsigned NewOpcode;
1076 switch (Opcode) {
1077 case TargetOpcode::G_LOAD:
1078 NewOpcode = TargetOpcode::G_INDEXED_LOAD;
1079 break;
1080 case TargetOpcode::G_SEXTLOAD:
1081 NewOpcode = TargetOpcode::G_INDEXED_SEXTLOAD;
1082 break;
1083 case TargetOpcode::G_ZEXTLOAD:
1084 NewOpcode = TargetOpcode::G_INDEXED_ZEXTLOAD;
1085 break;
1086 case TargetOpcode::G_STORE:
1087 NewOpcode = TargetOpcode::G_INDEXED_STORE;
1088 break;
1089 default:
1090 llvm_unreachable("Unknown load/store opcode");
1091 }
1092
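// Indexed ops define the written-back address as an extra result: stores are
// (new_addr) = op (value, base, offset, is_pre), loads are
// (dst, new_addr) = op (base, offset, is_pre).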
1093 auto MIB = MIRBuilder.buildInstr(NewOpcode);
1094 if (IsStore) {
1095 MIB.addDef(MatchInfo.Addr);
1096 MIB.addUse(MI.getOperand(0).getReg());
1097 } else {
1098 MIB.addDef(MI.getOperand(0).getReg());
1099 MIB.addDef(MatchInfo.Addr);
1100 }
1101
1102 MIB.addUse(MatchInfo.Base);
1103 MIB.addUse(MatchInfo.Offset);
1104 MIB.addImm(MatchInfo.IsPre);
1105 MI.eraseFromParent();
1106 AddrDef.eraseFromParent();
1107
1108 LLVM_DEBUG(dbgs() << " Combined to indexed operation");
1109 }
1110
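/// Match a G_[SU]DIV and a G_[SU]REM of the same operands in the same block
/// so the pair can be merged into a single G_[SU]DIVREM.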
1111 bool CombinerHelper::matchCombineDivRem(MachineInstr &MI,
1112 MachineInstr *&OtherMI) {
1113 unsigned Opcode = MI.getOpcode();
1114 bool IsDiv, IsSigned;
1115
1116 switch (Opcode) {
1117 default:
1118 llvm_unreachable("Unexpected opcode!");
1119 case TargetOpcode::G_SDIV:
1120 case TargetOpcode::G_UDIV: {
1121 IsDiv = true;
1122 IsSigned = Opcode == TargetOpcode::G_SDIV;
1123 break;
1124 }
1125 case TargetOpcode::G_SREM:
1126 case TargetOpcode::G_UREM: {
1127 IsDiv = false;
1128 IsSigned = Opcode == TargetOpcode::G_SREM;
1129 break;
1130 }
1131 }
1132
1133 Register Src1 = MI.getOperand(1).getReg();
1134 unsigned DivOpcode, RemOpcode, DivremOpcode;
1135 if (IsSigned) {
1136 DivOpcode = TargetOpcode::G_SDIV;
1137 RemOpcode = TargetOpcode::G_SREM;
1138 DivremOpcode = TargetOpcode::G_SDIVREM;
1139 } else {
1140 DivOpcode = TargetOpcode::G_UDIV;
1141 RemOpcode = TargetOpcode::G_UREM;
1142 DivremOpcode = TargetOpcode::G_UDIVREM;
1143 }
1144
1145 if (!isLegalOrBeforeLegalizer({DivremOpcode, {MRI.getType(Src1)}}))
1146 return false;
1147
1148 // Combine:
1149 // %div:_ = G_[SU]DIV %src1:_, %src2:_
1150 // %rem:_ = G_[SU]REM %src1:_, %src2:_
1151 // into:
1152 // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_
1153
1154 // Combine:
1155 // %rem:_ = G_[SU]REM %src1:_, %src2:_
1156 // %div:_ = G_[SU]DIV %src1:_, %src2:_
1157 // into:
1158 // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_
1159
1160 for (auto &UseMI : MRI.use_nodbg_instructions(Src1)) {
1161 if (MI.getParent() == UseMI.getParent() &&
1162 ((IsDiv && UseMI.getOpcode() == RemOpcode) ||
1163 (!IsDiv && UseMI.getOpcode() == DivOpcode)) &&
1164 matchEqualDefs(MI.getOperand(2), UseMI.getOperand(2)) &&
1165 matchEqualDefs(MI.getOperand(1), UseMI.getOperand(1))) {
1166 OtherMI = &UseMI;
1167 return true;
1168 }
1169 }
1170
1171 return false;
1172 }
1173
1174 void CombinerHelper::applyCombineDivRem(MachineInstr &MI,
1175 MachineInstr *&OtherMI) {
1176 unsigned Opcode = MI.getOpcode();
1177 assert(OtherMI && "OtherMI shouldn't be empty.");
1178
1179 Register DestDivReg, DestRemReg;
1180 if (Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_UDIV) {
1181 DestDivReg = MI.getOperand(0).getReg();
1182 DestRemReg = OtherMI->getOperand(0).getReg();
1183 } else {
1184 DestDivReg = OtherMI->getOperand(0).getReg();
1185 DestRemReg = MI.getOperand(0).getReg();
1186 }
1187
1188 bool IsSigned =
1189 Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_SREM;
1190
1191 // Check which instruction is first in the block so we don't break def-use
1192 // deps by "moving" the instruction incorrectly.
1193 if (dominates(MI, *OtherMI))
1194 Builder.setInstrAndDebugLoc(MI);
1195 else
1196 Builder.setInstrAndDebugLoc(*OtherMI);
1197
1198 Builder.buildInstr(IsSigned ? TargetOpcode::G_SDIVREM
1199 : TargetOpcode::G_UDIVREM,
1200 {DestDivReg, DestRemReg},
1201 {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()});
1202 MI.eraseFromParent();
1203 OtherMI->eraseFromParent();
1204 }
1205
1206 bool CombinerHelper::matchOptBrCondByInvertingCond(MachineInstr &MI,
1207 MachineInstr *&BrCond) {
1208 assert(MI.getOpcode() == TargetOpcode::G_BR);
1209
1210 // Try to match the following:
1211 // bb1:
1212 // G_BRCOND %c1, %bb2
1213 // G_BR %bb3
1214 // bb2:
1215 // ...
1216 // bb3:
1217
1218 // The above pattern does not have a fall through to the successor bb2, always
1219 // resulting in a branch no matter which path is taken. Here we try to find
1220 // and replace that pattern with a conditional branch to bb3 and otherwise
1221 // fallthrough to bb2. This is generally better for branch predictors.
1222
1223 MachineBasicBlock *MBB = MI.getParent();
1224 MachineBasicBlock::iterator BrIt(MI);
1225 if (BrIt == MBB->begin())
1226 return false;
1227 assert(std::next(BrIt) == MBB->end() && "expected G_BR to be a terminator");
1228
1229 BrCond = &*std::prev(BrIt);
1230 if (BrCond->getOpcode() != TargetOpcode::G_BRCOND)
1231 return false;
1232
1233 // Check that the next block is the conditional branch target. Also make sure
1234 // that it isn't the same as the G_BR's target (otherwise, this will loop.)
1235 MachineBasicBlock *BrCondTarget = BrCond->getOperand(1).getMBB();
1236 return BrCondTarget != MI.getOperand(0).getMBB() &&
1237 MBB->isLayoutSuccessor(BrCondTarget);
1238 }
1239
1240 void CombinerHelper::applyOptBrCondByInvertingCond(MachineInstr &MI,
1241 MachineInstr *&BrCond) {
1242 MachineBasicBlock *BrTarget = MI.getOperand(0).getMBB();
1243 Builder.setInstrAndDebugLoc(*BrCond);
1244 LLT Ty = MRI.getType(BrCond->getOperand(0).getReg());
1245 // FIXME: Does int/fp matter for this? If so, we might need to restrict
1246 // this to i1 only since we might not know for sure what kind of
1247 // compare generated the condition value.
1248 auto True = Builder.buildConstant(
1249 Ty, getICmpTrueVal(getTargetLowering(), false, false));
1250 auto Xor = Builder.buildXor(Ty, BrCond->getOperand(0), True);
1251
1252 auto *FallthroughBB = BrCond->getOperand(1).getMBB();
1253 Observer.changingInstr(MI);
1254 MI.getOperand(0).setMBB(FallthroughBB);
1255 Observer.changedInstr(MI);
1256
1257 // Change the conditional branch to use the inverted condition and
1258 // new target block.
1259 Observer.changingInstr(*BrCond);
1260 BrCond->getOperand(0).setReg(Xor.getReg(0));
1261 BrCond->getOperand(1).setMBB(BrTarget);
1262 Observer.changedInstr(*BrCond);
1263 }
1264
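// Build an IR integer (or fixed-vector-of-integer) type with the same shape
// and bit width as the given LLT.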
1265 static Type *getTypeForLLT(LLT Ty, LLVMContext &C) {
1266 if (Ty.isVector())
1267 return FixedVectorType::get(IntegerType::get(C, Ty.getScalarSizeInBits()),
1268 Ty.getNumElements());
1269 return IntegerType::get(C, Ty.getSizeInBits());
1270 }
1271
1272 bool CombinerHelper::tryEmitMemcpyInline(MachineInstr &MI) {
1273 MachineIRBuilder HelperBuilder(MI);
1274 GISelObserverWrapper DummyObserver;
1275 LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder);
1276 return Helper.lowerMemcpyInline(MI) ==
1277 LegalizerHelper::LegalizeResult::Legalized;
1278 }
1279
1280 bool CombinerHelper::tryCombineMemCpyFamily(MachineInstr &MI, unsigned MaxLen) {
1281 MachineIRBuilder HelperBuilder(MI);
1282 GISelObserverWrapper DummyObserver;
1283 LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder);
1284 return Helper.lowerMemCpyFamily(MI, MaxLen) ==
1285 LegalizerHelper::LegalizeResult::Legalized;
1286 }
1287
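/// Try to constant fold a unary FP opcode (G_FNEG, G_FABS, G_FPTRUNC,
/// G_FSQRT, G_FLOG2) whose operand is a floating point constant; returns
/// std::nullopt if the operand is not constant.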
1288 static std::optional<APFloat>
1289 constantFoldFpUnary(unsigned Opcode, LLT DstTy, const Register Op,
1290 const MachineRegisterInfo &MRI) {
1291 const ConstantFP *MaybeCst = getConstantFPVRegVal(Op, MRI);
1292 if (!MaybeCst)
1293 return std::nullopt;
1294
1295 APFloat V = MaybeCst->getValueAPF();
1296 switch (Opcode) {
1297 default:
1298 llvm_unreachable("Unexpected opcode!");
1299 case TargetOpcode::G_FNEG: {
1300 V.changeSign();
1301 return V;
1302 }
1303 case TargetOpcode::G_FABS: {
1304 V.clearSign();
1305 return V;
1306 }
1307 case TargetOpcode::G_FPTRUNC:
1308 break;
1309 case TargetOpcode::G_FSQRT: {
1310 bool Unused;
1311 V.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &Unused);
1312 V = APFloat(sqrt(V.convertToDouble()));
1313 break;
1314 }
1315 case TargetOpcode::G_FLOG2: {
1316 bool Unused;
1317 V.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &Unused);
1318 V = APFloat(log2(V.convertToDouble()));
1319 break;
1320 }
1321 }
1322 // Convert `APFloat` to appropriate IEEE type depending on `DstTy`. Otherwise,
1323 // `buildFConstant` will assert on size mismatch. Only `G_FPTRUNC`, `G_FSQRT`,
1324 // and `G_FLOG2` reach here.
1325 bool Unused;
1326 V.convert(getFltSemanticForLLT(DstTy), APFloat::rmNearestTiesToEven, &Unused);
1327 return V;
1328 }
1329
1330 bool CombinerHelper::matchCombineConstantFoldFpUnary(
1331 MachineInstr &MI, std::optional<APFloat> &Cst) {
1332 Register DstReg = MI.getOperand(0).getReg();
1333 Register SrcReg = MI.getOperand(1).getReg();
1334 LLT DstTy = MRI.getType(DstReg);
1335 Cst = constantFoldFpUnary(MI.getOpcode(), DstTy, SrcReg, MRI);
1336 return Cst.has_value();
1337 }
1338
1339 void CombinerHelper::applyCombineConstantFoldFpUnary(
1340 MachineInstr &MI, std::optional<APFloat> &Cst) {
1341 assert(Cst && "Optional is unexpectedly empty!");
1342 Builder.setInstrAndDebugLoc(MI);
1343 MachineFunction &MF = Builder.getMF();
1344 auto *FPVal = ConstantFP::get(MF.getFunction().getContext(), *Cst);
1345 Register DstReg = MI.getOperand(0).getReg();
1346 Builder.buildFConstant(DstReg, *FPVal);
1347 MI.eraseFromParent();
1348 }
1349
1350 bool CombinerHelper::matchPtrAddImmedChain(MachineInstr &MI,
1351 PtrAddChain &MatchInfo) {
1352 // We're trying to match the following pattern:
1353 // %t1 = G_PTR_ADD %base, G_CONSTANT imm1
1354 // %root = G_PTR_ADD %t1, G_CONSTANT imm2
1355 // -->
1356 // %root = G_PTR_ADD %base, G_CONSTANT (imm1 + imm2)
1357
1358 if (MI.getOpcode() != TargetOpcode::G_PTR_ADD)
1359 return false;
1360
1361 Register Add2 = MI.getOperand(1).getReg();
1362 Register Imm1 = MI.getOperand(2).getReg();
1363 auto MaybeImmVal = getIConstantVRegValWithLookThrough(Imm1, MRI);
1364 if (!MaybeImmVal)
1365 return false;
1366
1367 MachineInstr *Add2Def = MRI.getVRegDef(Add2);
1368 if (!Add2Def || Add2Def->getOpcode() != TargetOpcode::G_PTR_ADD)
1369 return false;
1370
1371 Register Base = Add2Def->getOperand(1).getReg();
1372 Register Imm2 = Add2Def->getOperand(2).getReg();
1373 auto MaybeImm2Val = getIConstantVRegValWithLookThrough(Imm2, MRI);
1374 if (!MaybeImm2Val)
1375 return false;
1376
1377 // Check if the new combined immediate forms an illegal addressing mode.
1378 // Do not combine if it was legal before but would get illegal.
1379 // To do so, we need to find a load/store user of the pointer to get
1380 // the access type.
1381 Type *AccessTy = nullptr;
1382 auto &MF = *MI.getMF();
1383 for (auto &UseMI : MRI.use_nodbg_instructions(MI.getOperand(0).getReg())) {
1384 if (auto *LdSt = dyn_cast<GLoadStore>(&UseMI)) {
1385 AccessTy = getTypeForLLT(MRI.getType(LdSt->getReg(0)),
1386 MF.getFunction().getContext());
1387 break;
1388 }
1389 }
1390 TargetLoweringBase::AddrMode AMNew;
1391 APInt CombinedImm = MaybeImmVal->Value + MaybeImm2Val->Value;
1392 AMNew.BaseOffs = CombinedImm.getSExtValue();
1393 if (AccessTy) {
1394 AMNew.HasBaseReg = true;
1395 TargetLoweringBase::AddrMode AMOld;
1396 AMOld.BaseOffs = MaybeImm2Val->Value.getSExtValue();
1397 AMOld.HasBaseReg = true;
1398 unsigned AS = MRI.getType(Add2).getAddressSpace();
1399 const auto &TLI = *MF.getSubtarget().getTargetLowering();
1400 if (TLI.isLegalAddressingMode(MF.getDataLayout(), AMOld, AccessTy, AS) &&
1401 !TLI.isLegalAddressingMode(MF.getDataLayout(), AMNew, AccessTy, AS))
1402 return false;
1403 }
1404
1405 // Pass the combined immediate to the apply function.
1406 MatchInfo.Imm = AMNew.BaseOffs;
1407 MatchInfo.Base = Base;
1408 MatchInfo.Bank = getRegBank(Imm2);
1409 return true;
1410 }
1411
1412 void CombinerHelper::applyPtrAddImmedChain(MachineInstr &MI,
1413 PtrAddChain &MatchInfo) {
1414 assert(MI.getOpcode() == TargetOpcode::G_PTR_ADD && "Expected G_PTR_ADD");
1415 MachineIRBuilder MIB(MI);
1416 LLT OffsetTy = MRI.getType(MI.getOperand(2).getReg());
1417 auto NewOffset = MIB.buildConstant(OffsetTy, MatchInfo.Imm);
1418 setRegBank(NewOffset.getReg(0), MatchInfo.Bank);
1419 Observer.changingInstr(MI);
1420 MI.getOperand(1).setReg(MatchInfo.Base);
1421 MI.getOperand(2).setReg(NewOffset.getReg(0));
1422 Observer.changedInstr(MI);
1423 }
1424
1425 bool CombinerHelper::matchShiftImmedChain(MachineInstr &MI,
1426 RegisterImmPair &MatchInfo) {
1427 // We're trying to match the following pattern with any of
1428 // G_SHL/G_ASHR/G_LSHR/G_SSHLSAT/G_USHLSAT shift instructions:
1429 // %t1 = SHIFT %base, G_CONSTANT imm1
1430 // %root = SHIFT %t1, G_CONSTANT imm2
1431 // -->
1432 // %root = SHIFT %base, G_CONSTANT (imm1 + imm2)
1433
1434 unsigned Opcode = MI.getOpcode();
1435 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1436 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT ||
1437 Opcode == TargetOpcode::G_USHLSAT) &&
1438 "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT");
1439
1440 Register Shl2 = MI.getOperand(1).getReg();
1441 Register Imm1 = MI.getOperand(2).getReg();
1442 auto MaybeImmVal = getIConstantVRegValWithLookThrough(Imm1, MRI);
1443 if (!MaybeImmVal)
1444 return false;
1445
1446 MachineInstr *Shl2Def = MRI.getUniqueVRegDef(Shl2);
1447 if (Shl2Def->getOpcode() != Opcode)
1448 return false;
1449
1450 Register Base = Shl2Def->getOperand(1).getReg();
1451 Register Imm2 = Shl2Def->getOperand(2).getReg();
1452 auto MaybeImm2Val = getIConstantVRegValWithLookThrough(Imm2, MRI);
1453 if (!MaybeImm2Val)
1454 return false;
1455
1456 // Pass the combined immediate to the apply function.
1457 MatchInfo.Imm =
1458 (MaybeImmVal->Value.getSExtValue() + MaybeImm2Val->Value).getSExtValue();
1459 MatchInfo.Reg = Base;
1460
1461 // There is no simple replacement for a saturating unsigned left shift that
1462 // exceeds the scalar size.
1463 if (Opcode == TargetOpcode::G_USHLSAT &&
1464 MatchInfo.Imm >= MRI.getType(Shl2).getScalarSizeInBits())
1465 return false;
1466
1467 return true;
1468 }
1469
1470 void CombinerHelper::applyShiftImmedChain(MachineInstr &MI,
1471 RegisterImmPair &MatchInfo) {
1472 unsigned Opcode = MI.getOpcode();
1473 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1474 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT ||
1475 Opcode == TargetOpcode::G_USHLSAT) &&
1476 "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT");
1477
1478 Builder.setInstrAndDebugLoc(MI);
1479 LLT Ty = MRI.getType(MI.getOperand(1).getReg());
1480 unsigned const ScalarSizeInBits = Ty.getScalarSizeInBits();
1481 auto Imm = MatchInfo.Imm;
1482
1483 if (Imm >= ScalarSizeInBits) {
1484 // Any logical shift that exceeds scalar size will produce zero.
1485 if (Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR) {
1486 Builder.buildConstant(MI.getOperand(0), 0);
1487 MI.eraseFromParent();
1488 return;
1489 }
1490 // Arithmetic shift and saturating signed left shift have no effect beyond
1491 // scalar size.
1492 Imm = ScalarSizeInBits - 1;
1493 }
1494
1495 LLT ImmTy = MRI.getType(MI.getOperand(2).getReg());
1496 Register NewImm = Builder.buildConstant(ImmTy, Imm).getReg(0);
1497 Observer.changingInstr(MI);
1498 MI.getOperand(1).setReg(MatchInfo.Reg);
1499 MI.getOperand(2).setReg(NewImm);
1500 Observer.changedInstr(MI);
1501 }
1502
1503 bool CombinerHelper::matchShiftOfShiftedLogic(MachineInstr &MI,
1504 ShiftOfShiftedLogic &MatchInfo) {
1505 // We're trying to match the following pattern with any of
1506 // G_SHL/G_ASHR/G_LSHR/G_USHLSAT/G_SSHLSAT shift instructions in combination
1507 // with any of G_AND/G_OR/G_XOR logic instructions.
1508 // %t1 = SHIFT %X, G_CONSTANT C0
1509 // %t2 = LOGIC %t1, %Y
1510 // %root = SHIFT %t2, G_CONSTANT C1
1511 // -->
1512 // %t3 = SHIFT %X, G_CONSTANT (C0+C1)
1513 // %t4 = SHIFT %Y, G_CONSTANT C1
1514 // %root = LOGIC %t3, %t4
1515 unsigned ShiftOpcode = MI.getOpcode();
1516 assert((ShiftOpcode == TargetOpcode::G_SHL ||
1517 ShiftOpcode == TargetOpcode::G_ASHR ||
1518 ShiftOpcode == TargetOpcode::G_LSHR ||
1519 ShiftOpcode == TargetOpcode::G_USHLSAT ||
1520 ShiftOpcode == TargetOpcode::G_SSHLSAT) &&
1521 "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT");
1522
1523 // Match a one-use bitwise logic op.
1524 Register LogicDest = MI.getOperand(1).getReg();
1525 if (!MRI.hasOneNonDBGUse(LogicDest))
1526 return false;
1527
1528 MachineInstr *LogicMI = MRI.getUniqueVRegDef(LogicDest);
1529 unsigned LogicOpcode = LogicMI->getOpcode();
1530 if (LogicOpcode != TargetOpcode::G_AND && LogicOpcode != TargetOpcode::G_OR &&
1531 LogicOpcode != TargetOpcode::G_XOR)
1532 return false;
1533
1534 // Find a matching one-use shift by constant.
1535 const Register C1 = MI.getOperand(2).getReg();
1536 auto MaybeImmVal = getIConstantVRegValWithLookThrough(C1, MRI);
1537 if (!MaybeImmVal)
1538 return false;
1539
1540 const uint64_t C1Val = MaybeImmVal->Value.getZExtValue();
1541
1542 auto matchFirstShift = [&](const MachineInstr *MI, uint64_t &ShiftVal) {
1543 // The shift must use the same opcode as the root shift and have only one use.
1544 if (MI->getOpcode() != ShiftOpcode ||
1545 !MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
1546 return false;
1547
1548 // Must be a constant.
1549 auto MaybeImmVal =
1550 getIConstantVRegValWithLookThrough(MI->getOperand(2).getReg(), MRI);
1551 if (!MaybeImmVal)
1552 return false;
1553
1554 ShiftVal = MaybeImmVal->Value.getSExtValue();
1555 return true;
1556 };
1557
1558 // Logic ops are commutative, so check each operand for a match.
1559 Register LogicMIReg1 = LogicMI->getOperand(1).getReg();
1560 MachineInstr *LogicMIOp1 = MRI.getUniqueVRegDef(LogicMIReg1);
1561 Register LogicMIReg2 = LogicMI->getOperand(2).getReg();
1562 MachineInstr *LogicMIOp2 = MRI.getUniqueVRegDef(LogicMIReg2);
1563 uint64_t C0Val;
1564
1565 if (matchFirstShift(LogicMIOp1, C0Val)) {
1566 MatchInfo.LogicNonShiftReg = LogicMIReg2;
1567 MatchInfo.Shift2 = LogicMIOp1;
1568 } else if (matchFirstShift(LogicMIOp2, C0Val)) {
1569 MatchInfo.LogicNonShiftReg = LogicMIReg1;
1570 MatchInfo.Shift2 = LogicMIOp2;
1571 } else
1572 return false;
1573
1574 MatchInfo.ValSum = C0Val + C1Val;
1575
1576 // The fold is not valid if the sum of the shift values exceeds bitwidth.
1577 if (MatchInfo.ValSum >= MRI.getType(LogicDest).getScalarSizeInBits())
1578 return false;
1579
1580 MatchInfo.Logic = LogicMI;
1581 return true;
1582 }
1583
1584 void CombinerHelper::applyShiftOfShiftedLogic(MachineInstr &MI,
1585 ShiftOfShiftedLogic &MatchInfo) {
1586 unsigned Opcode = MI.getOpcode();
1587 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1588 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_USHLSAT ||
1589 Opcode == TargetOpcode::G_SSHLSAT) &&
1590 "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT");
1591
1592 LLT ShlType = MRI.getType(MI.getOperand(2).getReg());
1593 LLT DestType = MRI.getType(MI.getOperand(0).getReg());
1594 Builder.setInstrAndDebugLoc(MI);
1595
1596 Register Const = Builder.buildConstant(ShlType, MatchInfo.ValSum).getReg(0);
1597
1598 Register Shift1Base = MatchInfo.Shift2->getOperand(1).getReg();
1599 Register Shift1 =
1600 Builder.buildInstr(Opcode, {DestType}, {Shift1Base, Const}).getReg(0);
1601
1602 // If LogicNonShiftReg is the same as Shift1Base and the shift1 constant is
1603 // the same as the MatchInfo.Shift2 constant, CSEMIRBuilder will reuse the
1604 // old shift1 when building shift2. In that case, erasing MatchInfo.Shift2 at
1605 // the end would actually remove the old shift1 and cause a crash later.
1606 // Erase it earlier to avoid the crash.
1607 MatchInfo.Shift2->eraseFromParent();
1608
1609 Register Shift2Const = MI.getOperand(2).getReg();
1610 Register Shift2 = Builder
1611 .buildInstr(Opcode, {DestType},
1612 {MatchInfo.LogicNonShiftReg, Shift2Const})
1613 .getReg(0);
1614
1615 Register Dest = MI.getOperand(0).getReg();
1616 Builder.buildInstr(MatchInfo.Logic->getOpcode(), {Dest}, {Shift1, Shift2});
1617
1618 // The logic op had only one use, so it's safe to remove it.
1619 MatchInfo.Logic->eraseFromParent();
1620
1621 MI.eraseFromParent();
1622 }
1623
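// Match a G_MUL by a power-of-two constant so it can be rewritten as a shift:
//   %root = G_MUL %x, G_CONSTANT 8  -->  %root = G_SHL %x, G_CONSTANT 3
// ShiftVal receives log2 of the constant; exactLogBase2() returns -1 when the
// constant is not a power of two, which rejects the match.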
1624 bool CombinerHelper::matchCombineMulToShl(MachineInstr &MI,
1625 unsigned &ShiftVal) {
1626 assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
1627 auto MaybeImmVal =
1628 getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
1629 if (!MaybeImmVal)
1630 return false;
1631
1632 ShiftVal = MaybeImmVal->Value.exactLogBase2();
1633 return (static_cast<int32_t>(ShiftVal) != -1);
1634 }
1635
1636 void CombinerHelper::applyCombineMulToShl(MachineInstr &MI,
1637 unsigned &ShiftVal) {
1638 assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
1639 MachineIRBuilder MIB(MI);
1640 LLT ShiftTy = MRI.getType(MI.getOperand(0).getReg());
1641 auto ShiftCst = MIB.buildConstant(ShiftTy, ShiftVal);
1642 Observer.changingInstr(MI);
1643 MI.setDesc(MIB.getTII().get(TargetOpcode::G_SHL));
1644 MI.getOperand(2).setReg(ShiftCst.getReg(0));
1645 Observer.changedInstr(MI);
1646 }
1647
1648 // shl ([sza]ext x), y => zext (shl x, y), if shift does not overflow source
1649 bool CombinerHelper::matchCombineShlOfExtend(MachineInstr &MI,
1650 RegisterImmPair &MatchData) {
1651 assert(MI.getOpcode() == TargetOpcode::G_SHL && KB);
1652
1653 Register LHS = MI.getOperand(1).getReg();
1654
1655 Register ExtSrc;
1656 if (!mi_match(LHS, MRI, m_GAnyExt(m_Reg(ExtSrc))) &&
1657 !mi_match(LHS, MRI, m_GZExt(m_Reg(ExtSrc))) &&
1658 !mi_match(LHS, MRI, m_GSExt(m_Reg(ExtSrc))))
1659 return false;
1660
1661 // TODO: Should handle vector splat.
1662 Register RHS = MI.getOperand(2).getReg();
1663 auto MaybeShiftAmtVal = getIConstantVRegValWithLookThrough(RHS, MRI);
1664 if (!MaybeShiftAmtVal)
1665 return false;
1666
1667 if (LI) {
1668 LLT SrcTy = MRI.getType(ExtSrc);
1669
1670 // We only really care about the legality of the shifted value. We can
1671 // pick any type for the constant shift amount, so ask the target what to
1672 // use. Otherwise we would have to guess and hope it is reported as legal.
1673 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(SrcTy);
1674 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SHL, {SrcTy, ShiftAmtTy}}))
1675 return false;
1676 }
1677
1678 int64_t ShiftAmt = MaybeShiftAmtVal->Value.getSExtValue();
1679 MatchData.Reg = ExtSrc;
1680 MatchData.Imm = ShiftAmt;
1681
1682 unsigned MinLeadingZeros = KB->getKnownZeroes(ExtSrc).countLeadingOnes();
1683 return MinLeadingZeros >= ShiftAmt;
1684 }
1685
1686 void CombinerHelper::applyCombineShlOfExtend(MachineInstr &MI,
1687 const RegisterImmPair &MatchData) {
1688 Register ExtSrcReg = MatchData.Reg;
1689 int64_t ShiftAmtVal = MatchData.Imm;
1690
1691 LLT ExtSrcTy = MRI.getType(ExtSrcReg);
1692 Builder.setInstrAndDebugLoc(MI);
1693 auto ShiftAmt = Builder.buildConstant(ExtSrcTy, ShiftAmtVal);
1694 auto NarrowShift =
1695 Builder.buildShl(ExtSrcTy, ExtSrcReg, ShiftAmt, MI.getFlags());
1696 Builder.buildZExt(MI.getOperand(0), NarrowShift);
1697 MI.eraseFromParent();
1698 }
1699
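// Fold a G_MERGE_VALUES whose sources are exactly the results of a single
// G_UNMERGE_VALUES, in the same order, back into the unmerge source:
//   %a, %b = G_UNMERGE_VALUES %x
//   %root = G_MERGE_VALUES %a, %b
//   -->
//   replace %root with %x (returned through MatchInfo)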
1700 bool CombinerHelper::matchCombineMergeUnmerge(MachineInstr &MI,
1701 Register &MatchInfo) {
1702 GMerge &Merge = cast<GMerge>(MI);
1703 SmallVector<Register, 16> MergedValues;
1704 for (unsigned I = 0; I < Merge.getNumSources(); ++I)
1705 MergedValues.emplace_back(Merge.getSourceReg(I));
1706
1707 auto *Unmerge = getOpcodeDef<GUnmerge>(MergedValues[0], MRI);
1708 if (!Unmerge || Unmerge->getNumDefs() != Merge.getNumSources())
1709 return false;
1710
1711 for (unsigned I = 0; I < MergedValues.size(); ++I)
1712 if (MergedValues[I] != Unmerge->getReg(I))
1713 return false;
1714
1715 MatchInfo = Unmerge->getSourceReg();
1716 return true;
1717 }
1718
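/// Look through a (possibly empty) chain of G_BITCASTs and return the
/// underlying source register.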
1719 static Register peekThroughBitcast(Register Reg,
1720 const MachineRegisterInfo &MRI) {
1721 while (mi_match(Reg, MRI, m_GBitcast(m_Reg(Reg))))
1722 ;
1723
1724 return Reg;
1725 }
1726
1727 bool CombinerHelper::matchCombineUnmergeMergeToPlainValues(
1728 MachineInstr &MI, SmallVectorImpl<Register> &Operands) {
1729 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
1730 "Expected an unmerge");
1731 auto &Unmerge = cast<GUnmerge>(MI);
1732 Register SrcReg = peekThroughBitcast(Unmerge.getSourceReg(), MRI);
1733
1734 auto *SrcInstr = getOpcodeDef<GMergeLikeInstr>(SrcReg, MRI);
1735 if (!SrcInstr)
1736 return false;
1737
1738 // Check the source type of the merge.
1739 LLT SrcMergeTy = MRI.getType(SrcInstr->getSourceReg(0));
1740 LLT Dst0Ty = MRI.getType(Unmerge.getReg(0));
1741 bool SameSize = Dst0Ty.getSizeInBits() == SrcMergeTy.getSizeInBits();
1742 if (SrcMergeTy != Dst0Ty && !SameSize)
1743 return false;
1744 // They are the same now (modulo a bitcast).
1745 // We can collect all the src registers.
1746 for (unsigned Idx = 0; Idx < SrcInstr->getNumSources(); ++Idx)
1747 Operands.push_back(SrcInstr->getSourceReg(Idx));
1748 return true;
1749 }
1750
1751 void CombinerHelper::applyCombineUnmergeMergeToPlainValues(
1752 MachineInstr &MI, SmallVectorImpl<Register> &Operands) {
1753 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
1754 "Expected an unmerge");
1755 assert((MI.getNumOperands() - 1 == Operands.size()) &&
1756 "Not enough operands to replace all defs");
1757 unsigned NumElems = MI.getNumOperands() - 1;
1758
1759 LLT SrcTy = MRI.getType(Operands[0]);
1760 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
1761 bool CanReuseInputDirectly = DstTy == SrcTy;
1762 Builder.setInstrAndDebugLoc(MI);
1763 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
1764 Register DstReg = MI.getOperand(Idx).getReg();
1765 Register SrcReg = Operands[Idx];
1766 if (CanReuseInputDirectly)
1767 replaceRegWith(MRI, DstReg, SrcReg);
1768 else
1769 Builder.buildCast(DstReg, SrcReg);
1770 }
1771 MI.eraseFromParent();
1772 }
1773
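// Fold a G_UNMERGE_VALUES of a G_CONSTANT/G_FCONSTANT by splitting the wide
// constant into one piece per destination, low bits first:
//   %c:_(s16) = G_CONSTANT i16 0x0B0A
//   %a:_(s8), %b:_(s8) = G_UNMERGE_VALUES %c
//   -->
//   %a = G_CONSTANT i8 0x0A, %b = G_CONSTANT i8 0x0B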
1774 bool CombinerHelper::matchCombineUnmergeConstant(MachineInstr &MI,
1775 SmallVectorImpl<APInt> &Csts) {
1776 unsigned SrcIdx = MI.getNumOperands() - 1;
1777 Register SrcReg = MI.getOperand(SrcIdx).getReg();
1778 MachineInstr *SrcInstr = MRI.getVRegDef(SrcReg);
1779 if (SrcInstr->getOpcode() != TargetOpcode::G_CONSTANT &&
1780 SrcInstr->getOpcode() != TargetOpcode::G_FCONSTANT)
1781 return false;
1782 // Break down the big constant into smaller ones.
1783 const MachineOperand &CstVal = SrcInstr->getOperand(1);
1784 APInt Val = SrcInstr->getOpcode() == TargetOpcode::G_CONSTANT
1785 ? CstVal.getCImm()->getValue()
1786 : CstVal.getFPImm()->getValueAPF().bitcastToAPInt();
1787
1788 LLT Dst0Ty = MRI.getType(MI.getOperand(0).getReg());
1789 unsigned ShiftAmt = Dst0Ty.getSizeInBits();
1790 // Unmerge a constant.
1791 for (unsigned Idx = 0; Idx != SrcIdx; ++Idx) {
1792 Csts.emplace_back(Val.trunc(ShiftAmt));
1793 Val = Val.lshr(ShiftAmt);
1794 }
1795
1796 return true;
1797 }
1798
1799 void CombinerHelper::applyCombineUnmergeConstant(MachineInstr &MI,
1800 SmallVectorImpl<APInt> &Csts) {
1801 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
1802 "Expected an unmerge");
1803 assert((MI.getNumOperands() - 1 == Csts.size()) &&
1804 "Not enough operands to replace all defs");
1805 unsigned NumElems = MI.getNumOperands() - 1;
1806 Builder.setInstrAndDebugLoc(MI);
1807 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
1808 Register DstReg = MI.getOperand(Idx).getReg();
1809 Builder.buildConstant(DstReg, Csts[Idx]);
1810 }
1811
1812 MI.eraseFromParent();
1813 }
1814
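// An unmerge of a G_IMPLICIT_DEF source can be replaced by a G_IMPLICIT_DEF
// for each destination register.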
1815 bool CombinerHelper::matchCombineUnmergeUndef(
1816 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
1817 unsigned SrcIdx = MI.getNumOperands() - 1;
1818 Register SrcReg = MI.getOperand(SrcIdx).getReg();
1819 MatchInfo = [&MI](MachineIRBuilder &B) {
1820 unsigned NumElems = MI.getNumOperands() - 1;
1821 for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
1822 Register DstReg = MI.getOperand(Idx).getReg();
1823 B.buildUndef(DstReg);
1824 }
1825 };
1826 return isa<GImplicitDef>(MRI.getVRegDef(SrcReg));
1827 }
1828
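// If only the first result of a G_UNMERGE_VALUES is live, the unmerge can be
// replaced by a truncation of the source to the first destination type.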
1829 bool CombinerHelper::matchCombineUnmergeWithDeadLanesToTrunc(MachineInstr &MI) {
1830 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
1831 "Expected an unmerge");
1832 // Check that all the lanes are dead except the first one.
1833 for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) {
1834 if (!MRI.use_nodbg_empty(MI.getOperand(Idx).getReg()))
1835 return false;
1836 }
1837 return true;
1838 }
1839
1840 void CombinerHelper::applyCombineUnmergeWithDeadLanesToTrunc(MachineInstr &MI) {
1841 Builder.setInstrAndDebugLoc(MI);
1842 Register SrcReg = MI.getOperand(MI.getNumDefs()).getReg();
1843 // Truncating a vector is going to truncate every single lane,
1844 // whereas we want the full low bits of the source.
1845 // Do the operation on a scalar instead.
1846 LLT SrcTy = MRI.getType(SrcReg);
1847 if (SrcTy.isVector())
1848 SrcReg =
1849 Builder.buildCast(LLT::scalar(SrcTy.getSizeInBits()), SrcReg).getReg(0);
1850
1851 Register Dst0Reg = MI.getOperand(0).getReg();
1852 LLT Dst0Ty = MRI.getType(Dst0Reg);
1853 if (Dst0Ty.isVector()) {
1854 auto MIB = Builder.buildTrunc(LLT::scalar(Dst0Ty.getSizeInBits()), SrcReg);
1855 Builder.buildCast(Dst0Reg, MIB);
1856 } else
1857 Builder.buildTrunc(Dst0Reg, SrcReg);
1858 MI.eraseFromParent();
1859 }
1860
1861 bool CombinerHelper::matchCombineUnmergeZExtToZExt(MachineInstr &MI) {
1862 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
1863 "Expected an unmerge");
1864 Register Dst0Reg = MI.getOperand(0).getReg();
1865 LLT Dst0Ty = MRI.getType(Dst0Reg);
1866 // G_ZEXT on vector applies to each lane, so it will
1867 // affect all destinations. Therefore we won't be able
1868 // to simplify the unmerge to just the first definition.
1869 if (Dst0Ty.isVector())
1870 return false;
1871 Register SrcReg = MI.getOperand(MI.getNumDefs()).getReg();
1872 LLT SrcTy = MRI.getType(SrcReg);
1873 if (SrcTy.isVector())
1874 return false;
1875
1876 Register ZExtSrcReg;
1877 if (!mi_match(SrcReg, MRI, m_GZExt(m_Reg(ZExtSrcReg))))
1878 return false;
1879
1880 // Finally we can replace the first definition with
1881 // a zext of the source if the definition is big enough to hold
1882 // all of ZExtSrc bits.
1883 LLT ZExtSrcTy = MRI.getType(ZExtSrcReg);
1884 return ZExtSrcTy.getSizeInBits() <= Dst0Ty.getSizeInBits();
1885 }
1886
1887 void CombinerHelper::applyCombineUnmergeZExtToZExt(MachineInstr &MI) {
1888 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES &&
1889 "Expected an unmerge");
1890
1891 Register Dst0Reg = MI.getOperand(0).getReg();
1892
1893 MachineInstr *ZExtInstr =
1894 MRI.getVRegDef(MI.getOperand(MI.getNumDefs()).getReg());
1895 assert(ZExtInstr && ZExtInstr->getOpcode() == TargetOpcode::G_ZEXT &&
1896 "Expecting a G_ZEXT");
1897
1898 Register ZExtSrcReg = ZExtInstr->getOperand(1).getReg();
1899 LLT Dst0Ty = MRI.getType(Dst0Reg);
1900 LLT ZExtSrcTy = MRI.getType(ZExtSrcReg);
1901
1902 Builder.setInstrAndDebugLoc(MI);
1903
1904 if (Dst0Ty.getSizeInBits() > ZExtSrcTy.getSizeInBits()) {
1905 Builder.buildZExt(Dst0Reg, ZExtSrcReg);
1906 } else {
1907 assert(Dst0Ty.getSizeInBits() == ZExtSrcTy.getSizeInBits() &&
1908 "ZExt src doesn't fit in destination");
1909 replaceRegWith(MRI, Dst0Reg, ZExtSrcReg);
1910 }
1911
1912 Register ZeroReg;
1913 for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) {
1914 if (!ZeroReg)
1915 ZeroReg = Builder.buildConstant(Dst0Ty, 0).getReg(0);
1916 replaceRegWith(MRI, MI.getOperand(Idx).getReg(), ZeroReg);
1917 }
1918 MI.eraseFromParent();
1919 }
1920
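// Match a wide scalar shift by a constant of at least half the bit width;
// applyCombineShiftToUnmerge performs the shift on the unmerged halves of the
// source instead.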
1921 bool CombinerHelper::matchCombineShiftToUnmerge(MachineInstr &MI,
1922 unsigned TargetShiftSize,
1923 unsigned &ShiftVal) {
1924 assert((MI.getOpcode() == TargetOpcode::G_SHL ||
1925 MI.getOpcode() == TargetOpcode::G_LSHR ||
1926 MI.getOpcode() == TargetOpcode::G_ASHR) && "Expected a shift");
1927
1928 LLT Ty = MRI.getType(MI.getOperand(0).getReg());
1929 if (Ty.isVector()) // TODO: Handle vector types.
1930 return false;
1931
1932 // Don't narrow further than the requested size.
1933 unsigned Size = Ty.getSizeInBits();
1934 if (Size <= TargetShiftSize)
1935 return false;
1936
1937 auto MaybeImmVal =
1938 getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
1939 if (!MaybeImmVal)
1940 return false;
1941
1942 ShiftVal = MaybeImmVal->Value.getSExtValue();
1943 return ShiftVal >= Size / 2 && ShiftVal < Size;
1944 }
1945
1946 void CombinerHelper::applyCombineShiftToUnmerge(MachineInstr &MI,
1947 const unsigned &ShiftVal) {
1948 Register DstReg = MI.getOperand(0).getReg();
1949 Register SrcReg = MI.getOperand(1).getReg();
1950 LLT Ty = MRI.getType(SrcReg);
1951 unsigned Size = Ty.getSizeInBits();
1952 unsigned HalfSize = Size / 2;
1953 assert(ShiftVal >= HalfSize);
1954
1955 LLT HalfTy = LLT::scalar(HalfSize);
1956
1957 Builder.setInstr(MI);
1958 auto Unmerge = Builder.buildUnmerge(HalfTy, SrcReg);
1959 unsigned NarrowShiftAmt = ShiftVal - HalfSize;
1960
1961 if (MI.getOpcode() == TargetOpcode::G_LSHR) {
1962 Register Narrowed = Unmerge.getReg(1);
1963
1964 // dst = G_LSHR s64:x, C for C >= 32
1965 // =>
1966 // lo, hi = G_UNMERGE_VALUES x
1967 // dst = G_MERGE_VALUES (G_LSHR hi, C - 32), 0
1968
1969 if (NarrowShiftAmt != 0) {
1970 Narrowed = Builder.buildLShr(HalfTy, Narrowed,
1971 Builder.buildConstant(HalfTy, NarrowShiftAmt)).getReg(0);
1972 }
1973
1974 auto Zero = Builder.buildConstant(HalfTy, 0);
1975 Builder.buildMergeLikeInstr(DstReg, {Narrowed, Zero});
1976 } else if (MI.getOpcode() == TargetOpcode::G_SHL) {
1977 Register Narrowed = Unmerge.getReg(0);
1978 // dst = G_SHL s64:x, C for C >= 32
1979 // =>
1980 // lo, hi = G_UNMERGE_VALUES x
1981 // dst = G_MERGE_VALUES 0, (G_SHL hi, C - 32)
1982 if (NarrowShiftAmt != 0) {
1983 Narrowed = Builder.buildShl(HalfTy, Narrowed,
1984 Builder.buildConstant(HalfTy, NarrowShiftAmt)).getReg(0);
1985 }
1986
1987 auto Zero = Builder.buildConstant(HalfTy, 0);
1988 Builder.buildMergeLikeInstr(DstReg, {Zero, Narrowed});
1989 } else {
1990 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
1991 auto Hi = Builder.buildAShr(
1992 HalfTy, Unmerge.getReg(1),
1993 Builder.buildConstant(HalfTy, HalfSize - 1));
1994
1995 if (ShiftVal == HalfSize) {
1996 // (G_ASHR i64:x, 32) ->
1997 // G_MERGE_VALUES hi_32(x), (G_ASHR hi_32(x), 31)
1998 Builder.buildMergeLikeInstr(DstReg, {Unmerge.getReg(1), Hi});
1999 } else if (ShiftVal == Size - 1) {
2000 // Don't need a second shift.
2001 // (G_ASHR i64:x, 63) ->
2002 // %narrowed = (G_ASHR hi_32(x), 31)
2003 // G_MERGE_VALUES %narrowed, %narrowed
2004 Builder.buildMergeLikeInstr(DstReg, {Hi, Hi});
2005 } else {
2006 auto Lo = Builder.buildAShr(
2007 HalfTy, Unmerge.getReg(1),
2008 Builder.buildConstant(HalfTy, ShiftVal - HalfSize));
2009
2010 // (G_ASHR i64:x, C) ->, for C >= 32
2011 // G_MERGE_VALUES (G_ASHR hi_32(x), C - 32), (G_ASHR hi_32(x), 31)
2012 Builder.buildMergeLikeInstr(DstReg, {Lo, Hi});
2013 }
2014 }
2015
2016 MI.eraseFromParent();
2017 }
2018
2019 bool CombinerHelper::tryCombineShiftToUnmerge(MachineInstr &MI,
2020 unsigned TargetShiftAmount) {
2021 unsigned ShiftAmt;
2022 if (matchCombineShiftToUnmerge(MI, TargetShiftAmount, ShiftAmt)) {
2023 applyCombineShiftToUnmerge(MI, ShiftAmt);
2024 return true;
2025 }
2026
2027 return false;
2028 }
2029
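// Fold G_INTTOPTR (G_PTRTOINT %x) -> %x when the type of %x matches the
// destination type of the G_INTTOPTR, so no bits can change.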
2030 bool CombinerHelper::matchCombineI2PToP2I(MachineInstr &MI, Register &Reg) {
2031 assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR");
2032 Register DstReg = MI.getOperand(0).getReg();
2033 LLT DstTy = MRI.getType(DstReg);
2034 Register SrcReg = MI.getOperand(1).getReg();
2035 return mi_match(SrcReg, MRI,
2036 m_GPtrToInt(m_all_of(m_SpecificType(DstTy), m_Reg(Reg))));
2037 }
2038
2039 void CombinerHelper::applyCombineI2PToP2I(MachineInstr &MI, Register &Reg) {
2040 assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR");
2041 Register DstReg = MI.getOperand(0).getReg();
2042 Builder.setInstr(MI);
2043 Builder.buildCopy(DstReg, Reg);
2044 MI.eraseFromParent();
2045 }
2046
2047 void CombinerHelper::applyCombineP2IToI2P(MachineInstr &MI, Register &Reg) {
2048 assert(MI.getOpcode() == TargetOpcode::G_PTRTOINT && "Expected a G_PTRTOINT");
2049 Register DstReg = MI.getOperand(0).getReg();
2050 Builder.setInstr(MI);
2051 Builder.buildZExtOrTrunc(DstReg, Reg);
2052 MI.eraseFromParent();
2053 }
2054
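// Rewrite (G_ADD (G_PTRTOINT %ptr), %int) as G_PTRTOINT (G_PTR_ADD %ptr,
// %int). PtrReg records the pointer operand and whether the G_ADD operands
// need to be commuted first.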
2055 bool CombinerHelper::matchCombineAddP2IToPtrAdd(
2056 MachineInstr &MI, std::pair<Register, bool> &PtrReg) {
2057 assert(MI.getOpcode() == TargetOpcode::G_ADD);
2058 Register LHS = MI.getOperand(1).getReg();
2059 Register RHS = MI.getOperand(2).getReg();
2060 LLT IntTy = MRI.getType(LHS);
2061
2062 // G_PTR_ADD always has the pointer in the LHS, so we may need to commute the
2063 // instruction.
2064 PtrReg.second = false;
2065 for (Register SrcReg : {LHS, RHS}) {
2066 if (mi_match(SrcReg, MRI, m_GPtrToInt(m_Reg(PtrReg.first)))) {
2067 // Don't handle cases where the integer is implicitly converted to the
2068 // pointer width.
2069 LLT PtrTy = MRI.getType(PtrReg.first);
2070 if (PtrTy.getScalarSizeInBits() == IntTy.getScalarSizeInBits())
2071 return true;
2072 }
2073
2074 PtrReg.second = true;
2075 }
2076
2077 return false;
2078 }
2079
2080 void CombinerHelper::applyCombineAddP2IToPtrAdd(
2081 MachineInstr &MI, std::pair<Register, bool> &PtrReg) {
2082 Register Dst = MI.getOperand(0).getReg();
2083 Register LHS = MI.getOperand(1).getReg();
2084 Register RHS = MI.getOperand(2).getReg();
2085
2086 const bool DoCommute = PtrReg.second;
2087 if (DoCommute)
2088 std::swap(LHS, RHS);
2089 LHS = PtrReg.first;
2090
2091 LLT PtrTy = MRI.getType(LHS);
2092
2093 Builder.setInstrAndDebugLoc(MI);
2094 auto PtrAdd = Builder.buildPtrAdd(PtrTy, LHS, RHS);
2095 Builder.buildPtrToInt(Dst, PtrAdd);
2096 MI.eraseFromParent();
2097 }
2098
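// Fold G_PTR_ADD (G_INTTOPTR C1), C2 into a single constant: C1 is
// zero-extended or truncated to the pointer width (G_INTTOPTR semantics) and
// C2 is sign-extended or truncated (G_PTR_ADD offsets are signed).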
2099 bool CombinerHelper::matchCombineConstPtrAddToI2P(MachineInstr &MI,
2100 APInt &NewCst) {
2101 auto &PtrAdd = cast<GPtrAdd>(MI);
2102 Register LHS = PtrAdd.getBaseReg();
2103 Register RHS = PtrAdd.getOffsetReg();
2104 MachineRegisterInfo &MRI = Builder.getMF().getRegInfo();
2105
2106 if (auto RHSCst = getIConstantVRegVal(RHS, MRI)) {
2107 APInt Cst;
2108 if (mi_match(LHS, MRI, m_GIntToPtr(m_ICst(Cst)))) {
2109 auto DstTy = MRI.getType(PtrAdd.getReg(0));
2110 // G_INTTOPTR uses zero-extension
2111 NewCst = Cst.zextOrTrunc(DstTy.getSizeInBits());
2112 NewCst += RHSCst->sextOrTrunc(DstTy.getSizeInBits());
2113 return true;
2114 }
2115 }
2116
2117 return false;
2118 }
2119
2120 void CombinerHelper::applyCombineConstPtrAddToI2P(MachineInstr &MI,
2121 APInt &NewCst) {
2122 auto &PtrAdd = cast<GPtrAdd>(MI);
2123 Register Dst = PtrAdd.getReg(0);
2124
2125 Builder.setInstrAndDebugLoc(MI);
2126 Builder.buildConstant(Dst, NewCst);
2127 PtrAdd.eraseFromParent();
2128 }
2129
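// Fold G_ANYEXT (G_TRUNC %x) -> %x when %x already has the destination type
// of the G_ANYEXT.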
2130 bool CombinerHelper::matchCombineAnyExtTrunc(MachineInstr &MI, Register &Reg) {
2131 assert(MI.getOpcode() == TargetOpcode::G_ANYEXT && "Expected a G_ANYEXT");
2132 Register DstReg = MI.getOperand(0).getReg();
2133 Register SrcReg = MI.getOperand(1).getReg();
2134 LLT DstTy = MRI.getType(DstReg);
2135 return mi_match(SrcReg, MRI,
2136 m_GTrunc(m_all_of(m_Reg(Reg), m_SpecificType(DstTy))));
2137 }
2138
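// Fold G_ZEXT (G_TRUNC %x) -> %x when %x has the destination type and the
// high bits that the zext would clear are already known to be zero.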
2139 bool CombinerHelper::matchCombineZextTrunc(MachineInstr &MI, Register &Reg) {
2140 assert(MI.getOpcode() == TargetOpcode::G_ZEXT && "Expected a G_ZEXT");
2141 Register DstReg = MI.getOperand(0).getReg();
2142 Register SrcReg = MI.getOperand(1).getReg();
2143 LLT DstTy = MRI.getType(DstReg);
2144 if (mi_match(SrcReg, MRI,
2145 m_GTrunc(m_all_of(m_Reg(Reg), m_SpecificType(DstTy))))) {
2146 unsigned DstSize = DstTy.getScalarSizeInBits();
2147 unsigned SrcSize = MRI.getType(SrcReg).getScalarSizeInBits();
2148 return KB->getKnownBits(Reg).countMinLeadingZeros() >= DstSize - SrcSize;
2149 }
2150 return false;
2151 }
2152
2153 bool CombinerHelper::matchCombineExtOfExt(
2154 MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) {
2155 assert((MI.getOpcode() == TargetOpcode::G_ANYEXT ||
2156 MI.getOpcode() == TargetOpcode::G_SEXT ||
2157 MI.getOpcode() == TargetOpcode::G_ZEXT) &&
2158 "Expected a G_[ASZ]EXT");
2159 Register SrcReg = MI.getOperand(1).getReg();
2160 MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
2161 // Match exts with the same opcode, anyext([sz]ext) and sext(zext).
2162 unsigned Opc = MI.getOpcode();
2163 unsigned SrcOpc = SrcMI->getOpcode();
2164 if (Opc == SrcOpc ||
2165 (Opc == TargetOpcode::G_ANYEXT &&
2166 (SrcOpc == TargetOpcode::G_SEXT || SrcOpc == TargetOpcode::G_ZEXT)) ||
2167 (Opc == TargetOpcode::G_SEXT && SrcOpc == TargetOpcode::G_ZEXT)) {
2168 MatchInfo = std::make_tuple(SrcMI->getOperand(1).getReg(), SrcOpc);
2169 return true;
2170 }
2171 return false;
2172 }
2173
2174 void CombinerHelper::applyCombineExtOfExt(
2175 MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) {
2176 assert((MI.getOpcode() == TargetOpcode::G_ANYEXT ||
2177 MI.getOpcode() == TargetOpcode::G_SEXT ||
2178 MI.getOpcode() == TargetOpcode::G_ZEXT) &&
2179 "Expected a G_[ASZ]EXT");
2180
2181 Register Reg = std::get<0>(MatchInfo);
2182 unsigned SrcExtOp = std::get<1>(MatchInfo);
2183
2184 // Combine exts with the same opcode.
2185 if (MI.getOpcode() == SrcExtOp) {
2186 Observer.changingInstr(MI);
2187 MI.getOperand(1).setReg(Reg);
2188 Observer.changedInstr(MI);
2189 return;
2190 }
2191
2192 // Combine:
2193 // - anyext([sz]ext x) to [sz]ext x
2194 // - sext(zext x) to zext x
2195 if (MI.getOpcode() == TargetOpcode::G_ANYEXT ||
2196 (MI.getOpcode() == TargetOpcode::G_SEXT &&
2197 SrcExtOp == TargetOpcode::G_ZEXT)) {
2198 Register DstReg = MI.getOperand(0).getReg();
2199 Builder.setInstrAndDebugLoc(MI);
2200 Builder.buildInstr(SrcExtOp, {DstReg}, {Reg});
2201 MI.eraseFromParent();
2202 }
2203 }
2204
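// Rewrite G_MUL %x, -1 as G_SUB 0, %x.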
2205 void CombinerHelper::applyCombineMulByNegativeOne(MachineInstr &MI) {
2206 assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
2207 Register DstReg = MI.getOperand(0).getReg();
2208 Register SrcReg = MI.getOperand(1).getReg();
2209 LLT DstTy = MRI.getType(DstReg);
2210
2211 Builder.setInstrAndDebugLoc(MI);
2212 Builder.buildSub(DstReg, Builder.buildConstant(DstTy, 0), SrcReg,
2213 MI.getFlags());
2214 MI.eraseFromParent();
2215 }
2216
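// Fold G_FABS (G_FNEG %x) -> G_FABS %x; negating the input cannot change the
// absolute value.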
2217 bool CombinerHelper::matchCombineFAbsOfFNeg(MachineInstr &MI,
2218 BuildFnTy &MatchInfo) {
2219 assert(MI.getOpcode() == TargetOpcode::G_FABS && "Expected a G_FABS");
2220 Register Src = MI.getOperand(1).getReg();
2221 Register NegSrc;
2222
2223 if (!mi_match(Src, MRI, m_GFNeg(m_Reg(NegSrc))))
2224 return false;
2225
2226 MatchInfo = [=, &MI](MachineIRBuilder &B) {
2227 Observer.changingInstr(MI);
2228 MI.getOperand(1).setReg(NegSrc);
2229 Observer.changedInstr(MI);
2230 };
2231 return true;
2232 }
2233
2234 bool CombinerHelper::matchCombineTruncOfExt(
2235 MachineInstr &MI, std::pair<Register, unsigned> &MatchInfo) {
2236 assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC");
2237 Register SrcReg = MI.getOperand(1).getReg();
2238 MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
2239 unsigned SrcOpc = SrcMI->getOpcode();
2240 if (SrcOpc == TargetOpcode::G_ANYEXT || SrcOpc == TargetOpcode::G_SEXT ||
2241 SrcOpc == TargetOpcode::G_ZEXT) {
2242 MatchInfo = std::make_pair(SrcMI->getOperand(1).getReg(), SrcOpc);
2243 return true;
2244 }
2245 return false;
2246 }
2247
2248 void CombinerHelper::applyCombineTruncOfExt(
2249 MachineInstr &MI, std::pair<Register, unsigned> &MatchInfo) {
2250 assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC");
2251 Register SrcReg = MatchInfo.first;
2252 unsigned SrcExtOp = MatchInfo.second;
2253 Register DstReg = MI.getOperand(0).getReg();
2254 LLT SrcTy = MRI.getType(SrcReg);
2255 LLT DstTy = MRI.getType(DstReg);
2256 if (SrcTy == DstTy) {
2257 MI.eraseFromParent();
2258 replaceRegWith(MRI, DstReg, SrcReg);
2259 return;
2260 }
2261 Builder.setInstrAndDebugLoc(MI);
2262 if (SrcTy.getSizeInBits() < DstTy.getSizeInBits())
2263 Builder.buildInstr(SrcExtOp, {DstReg}, {SrcReg});
2264 else
2265 Builder.buildTrunc(DstReg, SrcReg);
2266 MI.eraseFromParent();
2267 }
2268
2269 static LLT getMidVTForTruncRightShiftCombine(LLT ShiftTy, LLT TruncTy) {
2270 const unsigned ShiftSize = ShiftTy.getScalarSizeInBits();
2271 const unsigned TruncSize = TruncTy.getScalarSizeInBits();
2272
2273 // ShiftTy > 32 > TruncTy -> 32
2274 if (ShiftSize > 32 && TruncSize < 32)
2275 return ShiftTy.changeElementSize(32);
2276
2277 // TODO: We could also reduce to 16 bits, but that's more target-dependent.
2278 // Some targets like it, some don't, some only like it under certain
2279 // conditions/processor versions, etc.
2280 // A TL hook might be needed for this.
2281
2282 // Don't combine
2283 return ShiftTy;
2284 }
2285
2286 bool CombinerHelper::matchCombineTruncOfShift(
2287 MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) {
2288 assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC");
2289 Register DstReg = MI.getOperand(0).getReg();
2290 Register SrcReg = MI.getOperand(1).getReg();
2291
2292 if (!MRI.hasOneNonDBGUse(SrcReg))
2293 return false;
2294
2295 LLT SrcTy = MRI.getType(SrcReg);
2296 LLT DstTy = MRI.getType(DstReg);
2297
2298 MachineInstr *SrcMI = getDefIgnoringCopies(SrcReg, MRI);
2299 const auto &TL = getTargetLowering();
2300
2301 LLT NewShiftTy;
2302 switch (SrcMI->getOpcode()) {
2303 default:
2304 return false;
2305 case TargetOpcode::G_SHL: {
2306 NewShiftTy = DstTy;
2307
2308 // Make sure new shift amount is legal.
2309 KnownBits Known = KB->getKnownBits(SrcMI->getOperand(2).getReg());
2310 if (Known.getMaxValue().uge(NewShiftTy.getScalarSizeInBits()))
2311 return false;
2312 break;
2313 }
2314 case TargetOpcode::G_LSHR:
2315 case TargetOpcode::G_ASHR: {
2316 // For right shifts, we conservatively do not do the transform if the TRUNC
2317 // has any STORE users. The reason is that if we change the type of the
2318 // shift, we may break the truncstore combine.
2319 //
2320 // TODO: Fix truncstore combine to handle (trunc(lshr (trunc x), k)).
2321 for (auto &User : MRI.use_instructions(DstReg))
2322 if (User.getOpcode() == TargetOpcode::G_STORE)
2323 return false;
2324
2325 NewShiftTy = getMidVTForTruncRightShiftCombine(SrcTy, DstTy);
2326 if (NewShiftTy == SrcTy)
2327 return false;
2328
2329 // Make sure we won't lose information by truncating the high bits.
2330 KnownBits Known = KB->getKnownBits(SrcMI->getOperand(2).getReg());
2331 if (Known.getMaxValue().ugt(NewShiftTy.getScalarSizeInBits() -
2332 DstTy.getScalarSizeInBits()))
2333 return false;
2334 break;
2335 }
2336 }
2337
2338 if (!isLegalOrBeforeLegalizer(
2339 {SrcMI->getOpcode(),
2340 {NewShiftTy, TL.getPreferredShiftAmountTy(NewShiftTy)}}))
2341 return false;
2342
2343 MatchInfo = std::make_pair(SrcMI, NewShiftTy);
2344 return true;
2345 }
2346
2347 void CombinerHelper::applyCombineTruncOfShift(
2348 MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) {
2349 Builder.setInstrAndDebugLoc(MI);
2350
2351 MachineInstr *ShiftMI = MatchInfo.first;
2352 LLT NewShiftTy = MatchInfo.second;
2353
2354 Register Dst = MI.getOperand(0).getReg();
2355 LLT DstTy = MRI.getType(Dst);
2356
2357 Register ShiftAmt = ShiftMI->getOperand(2).getReg();
2358 Register ShiftSrc = ShiftMI->getOperand(1).getReg();
2359 ShiftSrc = Builder.buildTrunc(NewShiftTy, ShiftSrc).getReg(0);
2360
2361 Register NewShift =
2362 Builder
2363 .buildInstr(ShiftMI->getOpcode(), {NewShiftTy}, {ShiftSrc, ShiftAmt})
2364 .getReg(0);
2365
2366 if (NewShiftTy == DstTy)
2367 replaceRegWith(MRI, Dst, NewShift);
2368 else
2369 Builder.buildTrunc(Dst, NewShift);
2370
2371 eraseInst(MI);
2372 }
2373
2374 bool CombinerHelper::matchAnyExplicitUseIsUndef(MachineInstr &MI) {
2375 return any_of(MI.explicit_uses(), [this](const MachineOperand &MO) {
2376 return MO.isReg() &&
2377 getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MO.getReg(), MRI);
2378 });
2379 }
2380
2381 bool CombinerHelper::matchAllExplicitUsesAreUndef(MachineInstr &MI) {
2382 return all_of(MI.explicit_uses(), [this](const MachineOperand &MO) {
2383 return !MO.isReg() ||
2384 getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MO.getReg(), MRI);
2385 });
2386 }
2387
2388 bool CombinerHelper::matchUndefShuffleVectorMask(MachineInstr &MI) {
2389 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
2390 ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
2391 return all_of(Mask, [](int Elt) { return Elt < 0; });
2392 }
2393
2394 bool CombinerHelper::matchUndefStore(MachineInstr &MI) {
2395 assert(MI.getOpcode() == TargetOpcode::G_STORE);
2396 return getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MI.getOperand(0).getReg(),
2397 MRI);
2398 }
2399
2400 bool CombinerHelper::matchUndefSelectCmp(MachineInstr &MI) {
2401 assert(MI.getOpcode() == TargetOpcode::G_SELECT);
2402 return getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MI.getOperand(1).getReg(),
2403 MRI);
2404 }
2405
2406 bool CombinerHelper::matchInsertExtractVecEltOutOfBounds(MachineInstr &MI) {
2407 assert((MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT ||
2408 MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) &&
2409 "Expected an insert/extract element op");
2410 LLT VecTy = MRI.getType(MI.getOperand(1).getReg());
2411 unsigned IdxIdx =
2412 MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT ? 2 : 3;
2413 auto Idx = getIConstantVRegVal(MI.getOperand(IdxIdx).getReg(), MRI);
2414 if (!Idx)
2415 return false;
2416 return Idx->getZExtValue() >= VecTy.getNumElements();
2417 }
2418
2419 bool CombinerHelper::matchConstantSelectCmp(MachineInstr &MI, unsigned &OpIdx) {
2420 GSelect &SelMI = cast<GSelect>(MI);
2421 auto Cst =
2422 isConstantOrConstantSplatVector(*MRI.getVRegDef(SelMI.getCondReg()), MRI);
2423 if (!Cst)
2424 return false;
2425 OpIdx = Cst->isZero() ? 3 : 2;
2426 return true;
2427 }
2428
2429 bool CombinerHelper::eraseInst(MachineInstr &MI) {
2430 MI.eraseFromParent();
2431 return true;
2432 }
2433
2434 bool CombinerHelper::matchEqualDefs(const MachineOperand &MOP1,
2435 const MachineOperand &MOP2) {
2436 if (!MOP1.isReg() || !MOP2.isReg())
2437 return false;
2438 auto InstAndDef1 = getDefSrcRegIgnoringCopies(MOP1.getReg(), MRI);
2439 if (!InstAndDef1)
2440 return false;
2441 auto InstAndDef2 = getDefSrcRegIgnoringCopies(MOP2.getReg(), MRI);
2442 if (!InstAndDef2)
2443 return false;
2444 MachineInstr *I1 = InstAndDef1->MI;
2445 MachineInstr *I2 = InstAndDef2->MI;
2446
2447 // Handle a case like this:
2448 //
2449 // %0:_(s64), %1:_(s64) = G_UNMERGE_VALUES %2:_(<2 x s64>)
2450 //
2451 // Even though %0 and %1 are produced by the same instruction they are not
2452 // the same values.
2453 if (I1 == I2)
2454 return MOP1.getReg() == MOP2.getReg();
2455
2456 // If we have an instruction which loads or stores, we can't guarantee that
2457 // it is identical.
2458 //
2459 // For example, we may have
2460 //
2461 // %x1 = G_LOAD %addr (load N from @somewhere)
2462 // ...
2463 // call @foo
2464 // ...
2465 // %x2 = G_LOAD %addr (load N from @somewhere)
2466 // ...
2467 // %or = G_OR %x1, %x2
2468 //
2469 // It's possible that @foo will modify whatever lives at the address we're
2470 // loading from. To be safe, let's just assume that all loads and stores
2471 // are different (unless we have something which is guaranteed not to
2472 // change).
2473 if (I1->mayLoadOrStore() && !I1->isDereferenceableInvariantLoad())
2474 return false;
2475
2476 // If both instructions are loads or stores, they are equal only if both
2477 // are dereferenceable invariant loads with the same number of bits.
2478 if (I1->mayLoadOrStore() && I2->mayLoadOrStore()) {
2479 GLoadStore *LS1 = dyn_cast<GLoadStore>(I1);
2480 GLoadStore *LS2 = dyn_cast<GLoadStore>(I2);
2481 if (!LS1 || !LS2)
2482 return false;
2483
2484 if (!I2->isDereferenceableInvariantLoad() ||
2485 (LS1->getMemSizeInBits() != LS2->getMemSizeInBits()))
2486 return false;
2487 }
2488
2489 // Check for physical registers on the instructions first to avoid cases
2490 // like this:
2491 //
2492 // %a = COPY $physreg
2493 // ...
2494 // SOMETHING implicit-def $physreg
2495 // ...
2496 // %b = COPY $physreg
2497 //
2498 // These copies are not equivalent.
2499 if (any_of(I1->uses(), [](const MachineOperand &MO) {
2500 return MO.isReg() && MO.getReg().isPhysical();
2501 })) {
2502 // Check if we have a case like this:
2503 //
2504 // %a = COPY $physreg
2505 // %b = COPY %a
2506 //
2507 // In this case, I1 and I2 will both be equal to %a = COPY $physreg.
2508 // From that, we know that they must have the same value, since they must
2509 // have come from the same COPY.
2510 return I1->isIdenticalTo(*I2);
2511 }
2512
2513 // We don't have any physical registers, so we don't necessarily need the
2514 // same vreg defs.
2515 //
2516 // On the off-chance that there's some target instruction feeding into the
2517 // instruction, let's use produceSameValue instead of isIdenticalTo.
2518 if (Builder.getTII().produceSameValue(*I1, *I2, &MRI)) {
2519 // Handle instructions with multiple defs that produce the same values. The
2520 // values are the same for operands with the same index.
2521 // %0:_(s8), %1:_(s8), %2:_(s8), %3:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>)
2522 // %5:_(s8), %6:_(s8), %7:_(s8), %8:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>)
2523 // I1 and I2 are different instructions but produce same values,
2524 // %1 and %6 are same, %1 and %7 are not the same value.
2525 return I1->findRegisterDefOperandIdx(InstAndDef1->Reg) ==
2526 I2->findRegisterDefOperandIdx(InstAndDef2->Reg);
2527 }
2528 return false;
2529 }
2530
2531 bool CombinerHelper::matchConstantOp(const MachineOperand &MOP, int64_t C) {
2532 if (!MOP.isReg())
2533 return false;
2534 auto *MI = MRI.getVRegDef(MOP.getReg());
2535 auto MaybeCst = isConstantOrConstantSplatVector(*MI, MRI);
2536 return MaybeCst && MaybeCst->getBitWidth() <= 64 &&
2537 MaybeCst->getSExtValue() == C;
2538 }
2539
2540 bool CombinerHelper::replaceSingleDefInstWithOperand(MachineInstr &MI,
2541 unsigned OpIdx) {
2542 assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?");
2543 Register OldReg = MI.getOperand(0).getReg();
2544 Register Replacement = MI.getOperand(OpIdx).getReg();
2545 assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?");
2546 MI.eraseFromParent();
2547 replaceRegWith(MRI, OldReg, Replacement);
2548 return true;
2549 }
2550
2551 bool CombinerHelper::replaceSingleDefInstWithReg(MachineInstr &MI,
2552 Register Replacement) {
2553 assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?");
2554 Register OldReg = MI.getOperand(0).getReg();
2555 assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?");
2556 MI.eraseFromParent();
2557 replaceRegWith(MRI, OldReg, Replacement);
2558 return true;
2559 }
2560
2561 bool CombinerHelper::matchSelectSameVal(MachineInstr &MI) {
2562 assert(MI.getOpcode() == TargetOpcode::G_SELECT);
2563 // Match (cond ? x : x)
2564 return matchEqualDefs(MI.getOperand(2), MI.getOperand(3)) &&
2565 canReplaceReg(MI.getOperand(0).getReg(), MI.getOperand(2).getReg(),
2566 MRI);
2567 }
2568
2569 bool CombinerHelper::matchBinOpSameVal(MachineInstr &MI) {
2570 return matchEqualDefs(MI.getOperand(1), MI.getOperand(2)) &&
2571 canReplaceReg(MI.getOperand(0).getReg(), MI.getOperand(1).getReg(),
2572 MRI);
2573 }
2574
2575 bool CombinerHelper::matchOperandIsZero(MachineInstr &MI, unsigned OpIdx) {
2576 return matchConstantOp(MI.getOperand(OpIdx), 0) &&
2577 canReplaceReg(MI.getOperand(0).getReg(), MI.getOperand(OpIdx).getReg(),
2578 MRI);
2579 }
2580
2581 bool CombinerHelper::matchOperandIsUndef(MachineInstr &MI, unsigned OpIdx) {
2582 MachineOperand &MO = MI.getOperand(OpIdx);
2583 return MO.isReg() &&
2584 getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MO.getReg(), MRI);
2585 }
2586
2587 bool CombinerHelper::matchOperandIsKnownToBeAPowerOfTwo(MachineInstr &MI,
2588 unsigned OpIdx) {
2589 MachineOperand &MO = MI.getOperand(OpIdx);
2590 return isKnownToBeAPowerOfTwo(MO.getReg(), MRI, KB);
2591 }
2592
2593 bool CombinerHelper::replaceInstWithFConstant(MachineInstr &MI, double C) {
2594 assert(MI.getNumDefs() == 1 && "Expected only one def?");
2595 Builder.setInstr(MI);
2596 Builder.buildFConstant(MI.getOperand(0), C);
2597 MI.eraseFromParent();
2598 return true;
2599 }
2600
2601 bool CombinerHelper::replaceInstWithConstant(MachineInstr &MI, int64_t C) {
2602 assert(MI.getNumDefs() == 1 && "Expected only one def?");
2603 Builder.setInstr(MI);
2604 Builder.buildConstant(MI.getOperand(0), C);
2605 MI.eraseFromParent();
2606 return true;
2607 }
2608
2609 bool CombinerHelper::replaceInstWithConstant(MachineInstr &MI, APInt C) {
2610 assert(MI.getNumDefs() == 1 && "Expected only one def?");
2611 Builder.setInstr(MI);
2612 Builder.buildConstant(MI.getOperand(0), C);
2613 MI.eraseFromParent();
2614 return true;
2615 }
2616
2617 bool CombinerHelper::replaceInstWithUndef(MachineInstr &MI) {
2618 assert(MI.getNumDefs() == 1 && "Expected only one def?");
2619 Builder.setInstr(MI);
2620 Builder.buildUndef(MI.getOperand(0));
2621 MI.eraseFromParent();
2622 return true;
2623 }
2624
2625 bool CombinerHelper::matchSimplifyAddToSub(
2626 MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) {
2627 Register LHS = MI.getOperand(1).getReg();
2628 Register RHS = MI.getOperand(2).getReg();
2629 Register &NewLHS = std::get<0>(MatchInfo);
2630 Register &NewRHS = std::get<1>(MatchInfo);
2631
2632 // Helper lambda to check for opportunities for
2633 // ((0-A) + B) -> B - A
2634 // (A + (0-B)) -> A - B
2635 auto CheckFold = [&](Register &MaybeSub, Register &MaybeNewLHS) {
2636 if (!mi_match(MaybeSub, MRI, m_Neg(m_Reg(NewRHS))))
2637 return false;
2638 NewLHS = MaybeNewLHS;
2639 return true;
2640 };
2641
2642 return CheckFold(LHS, RHS) || CheckFold(RHS, LHS);
2643 }
2644
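// Walk a chain of G_INSERT_VECTOR_ELT with constant indices rooted at MI and
// record one source register per lane so the chain can be rebuilt as a single
// G_BUILD_VECTOR. Lanes still unset after the walk are filled with undef in
// the apply step.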
2645 bool CombinerHelper::matchCombineInsertVecElts(
2646 MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) {
2647 assert(MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT &&
2648 "Invalid opcode");
2649 Register DstReg = MI.getOperand(0).getReg();
2650 LLT DstTy = MRI.getType(DstReg);
2651 assert(DstTy.isVector() && "Invalid G_INSERT_VECTOR_ELT?");
2652 unsigned NumElts = DstTy.getNumElements();
2653 // If this MI is part of a sequence of insert_vec_elts, then
2654 // don't do the combine in the middle of the sequence.
2655 if (MRI.hasOneUse(DstReg) && MRI.use_instr_begin(DstReg)->getOpcode() ==
2656 TargetOpcode::G_INSERT_VECTOR_ELT)
2657 return false;
2658 MachineInstr *CurrInst = &MI;
2659 MachineInstr *TmpInst;
2660 int64_t IntImm;
2661 Register TmpReg;
2662 MatchInfo.resize(NumElts);
2663 while (mi_match(
2664 CurrInst->getOperand(0).getReg(), MRI,
2665 m_GInsertVecElt(m_MInstr(TmpInst), m_Reg(TmpReg), m_ICst(IntImm)))) {
2666 if (IntImm >= NumElts || IntImm < 0)
2667 return false;
2668 if (!MatchInfo[IntImm])
2669 MatchInfo[IntImm] = TmpReg;
2670 CurrInst = TmpInst;
2671 }
2672 // Variable index.
2673 if (CurrInst->getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT)
2674 return false;
2675 if (TmpInst->getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
2676 for (unsigned I = 1; I < TmpInst->getNumOperands(); ++I) {
2677 if (!MatchInfo[I - 1].isValid())
2678 MatchInfo[I - 1] = TmpInst->getOperand(I).getReg();
2679 }
2680 return true;
2681 }
2682 // If we didn't end in a G_IMPLICIT_DEF, bail out.
2683 return TmpInst->getOpcode() == TargetOpcode::G_IMPLICIT_DEF;
2684 }
2685
2686 void CombinerHelper::applyCombineInsertVecElts(
2687 MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) {
2688 Builder.setInstr(MI);
2689 Register UndefReg;
2690 auto GetUndef = [&]() {
2691 if (UndefReg)
2692 return UndefReg;
2693 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
2694 UndefReg = Builder.buildUndef(DstTy.getScalarType()).getReg(0);
2695 return UndefReg;
2696 };
2697 for (unsigned I = 0; I < MatchInfo.size(); ++I) {
2698 if (!MatchInfo[I])
2699 MatchInfo[I] = GetUndef();
2700 }
2701 Builder.buildBuildVector(MI.getOperand(0).getReg(), MatchInfo);
2702 MI.eraseFromParent();
2703 }
2704
2705 void CombinerHelper::applySimplifyAddToSub(
2706 MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) {
2707 Builder.setInstr(MI);
2708 Register SubLHS, SubRHS;
2709 std::tie(SubLHS, SubRHS) = MatchInfo;
2710 Builder.buildSub(MI.getOperand(0).getReg(), SubLHS, SubRHS);
2711 MI.eraseFromParent();
2712 }
2713
2714 bool CombinerHelper::matchHoistLogicOpWithSameOpcodeHands(
2715 MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) {
2716 // Matches: logic (hand x, ...), (hand y, ...) -> hand (logic x, y), ...
2717 //
2718 // Creates the new hand + logic instruction (but does not insert them.)
2719 //
2720 // On success, MatchInfo is populated with the new instructions. These are
2721 // inserted in applyHoistLogicOpWithSameOpcodeHands.
2722 unsigned LogicOpcode = MI.getOpcode();
2723 assert(LogicOpcode == TargetOpcode::G_AND ||
2724 LogicOpcode == TargetOpcode::G_OR ||
2725 LogicOpcode == TargetOpcode::G_XOR);
2726 MachineIRBuilder MIB(MI);
2727 Register Dst = MI.getOperand(0).getReg();
2728 Register LHSReg = MI.getOperand(1).getReg();
2729 Register RHSReg = MI.getOperand(2).getReg();
2730
2731 // Don't recompute anything.
2732 if (!MRI.hasOneNonDBGUse(LHSReg) || !MRI.hasOneNonDBGUse(RHSReg))
2733 return false;
2734
2735 // Make sure we have (hand x, ...), (hand y, ...)
2736 MachineInstr *LeftHandInst = getDefIgnoringCopies(LHSReg, MRI);
2737 MachineInstr *RightHandInst = getDefIgnoringCopies(RHSReg, MRI);
2738 if (!LeftHandInst || !RightHandInst)
2739 return false;
2740 unsigned HandOpcode = LeftHandInst->getOpcode();
2741 if (HandOpcode != RightHandInst->getOpcode())
2742 return false;
2743 if (!LeftHandInst->getOperand(1).isReg() ||
2744 !RightHandInst->getOperand(1).isReg())
2745 return false;
2746
2747 // Make sure the types match up, and if we're doing this post-legalization,
2748 // we end up with legal types.
2749 Register X = LeftHandInst->getOperand(1).getReg();
2750 Register Y = RightHandInst->getOperand(1).getReg();
2751 LLT XTy = MRI.getType(X);
2752 LLT YTy = MRI.getType(Y);
2753 if (XTy != YTy)
2754 return false;
2755 if (!isLegalOrBeforeLegalizer({LogicOpcode, {XTy, YTy}}))
2756 return false;
2757
2758 // Optional extra source register.
2759 Register ExtraHandOpSrcReg;
2760 switch (HandOpcode) {
2761 default:
2762 return false;
2763 case TargetOpcode::G_ANYEXT:
2764 case TargetOpcode::G_SEXT:
2765 case TargetOpcode::G_ZEXT: {
2766 // Match: logic (ext X), (ext Y) --> ext (logic X, Y)
2767 break;
2768 }
2769 case TargetOpcode::G_AND:
2770 case TargetOpcode::G_ASHR:
2771 case TargetOpcode::G_LSHR:
2772 case TargetOpcode::G_SHL: {
2773 // Match: logic (binop x, z), (binop y, z) -> binop (logic x, y), z
2774 MachineOperand &ZOp = LeftHandInst->getOperand(2);
2775 if (!matchEqualDefs(ZOp, RightHandInst->getOperand(2)))
2776 return false;
2777 ExtraHandOpSrcReg = ZOp.getReg();
2778 break;
2779 }
2780 }
2781
2782 // Record the steps to build the new instructions.
2783 //
2784 // Steps to build (logic x, y)
2785 auto NewLogicDst = MRI.createGenericVirtualRegister(XTy);
2786 OperandBuildSteps LogicBuildSteps = {
2787 [=](MachineInstrBuilder &MIB) { MIB.addDef(NewLogicDst); },
2788 [=](MachineInstrBuilder &MIB) { MIB.addReg(X); },
2789 [=](MachineInstrBuilder &MIB) { MIB.addReg(Y); }};
2790 InstructionBuildSteps LogicSteps(LogicOpcode, LogicBuildSteps);
2791
2792 // Steps to build hand (logic x, y), ...z
2793 OperandBuildSteps HandBuildSteps = {
2794 [=](MachineInstrBuilder &MIB) { MIB.addDef(Dst); },
2795 [=](MachineInstrBuilder &MIB) { MIB.addReg(NewLogicDst); }};
2796 if (ExtraHandOpSrcReg.isValid())
2797 HandBuildSteps.push_back(
2798 [=](MachineInstrBuilder &MIB) { MIB.addReg(ExtraHandOpSrcReg); });
2799 InstructionBuildSteps HandSteps(HandOpcode, HandBuildSteps);
2800
2801 MatchInfo = InstructionStepsMatchInfo({LogicSteps, HandSteps});
2802 return true;
2803 }
2804
2805 void CombinerHelper::applyBuildInstructionSteps(
2806 MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) {
2807 assert(MatchInfo.InstrsToBuild.size() &&
2808 "Expected at least one instr to build?");
2809 Builder.setInstr(MI);
2810 for (auto &InstrToBuild : MatchInfo.InstrsToBuild) {
2811 assert(InstrToBuild.Opcode && "Expected a valid opcode?");
2812 assert(InstrToBuild.OperandFns.size() && "Expected at least one operand?");
2813 MachineInstrBuilder Instr = Builder.buildInstr(InstrToBuild.Opcode);
2814 for (auto &OperandFn : InstrToBuild.OperandFns)
2815 OperandFn(Instr);
2816 }
2817 MI.eraseFromParent();
2818 }
2819
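// Fold (G_ASHR (G_SHL %x, C), C) into G_SEXT_INREG %x, (bitwidth - C).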
2820 bool CombinerHelper::matchAshrShlToSextInreg(
2821 MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) {
2822 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
2823 int64_t ShlCst, AshrCst;
2824 Register Src;
2825 if (!mi_match(MI.getOperand(0).getReg(), MRI,
2826 m_GAShr(m_GShl(m_Reg(Src), m_ICstOrSplat(ShlCst)),
2827 m_ICstOrSplat(AshrCst))))
2828 return false;
2829 if (ShlCst != AshrCst)
2830 return false;
2831 if (!isLegalOrBeforeLegalizer(
2832 {TargetOpcode::G_SEXT_INREG, {MRI.getType(Src)}}))
2833 return false;
2834 MatchInfo = std::make_tuple(Src, ShlCst);
2835 return true;
2836 }
2837
2838 void CombinerHelper::applyAshShlToSextInreg(
2839 MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) {
2840 assert(MI.getOpcode() == TargetOpcode::G_ASHR);
2841 Register Src;
2842 int64_t ShiftAmt;
2843 std::tie(Src, ShiftAmt) = MatchInfo;
2844 unsigned Size = MRI.getType(Src).getScalarSizeInBits();
2845 Builder.setInstrAndDebugLoc(MI);
2846 Builder.buildSExtInReg(MI.getOperand(0).getReg(), Src, Size - ShiftAmt);
2847 MI.eraseFromParent();
2848 }
2849
2850 /// and(and(x, C1), C2) -> C1&C2 ? and(x, C1&C2) : 0
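/// E.g. (illustrative): and(and(x, 0x0F), 0x3C) -> and(x, 0x0C), while
/// and(and(x, 0xF0), 0x0F) folds to the constant 0.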
2851 bool CombinerHelper::matchOverlappingAnd(
2852 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
2853 assert(MI.getOpcode() == TargetOpcode::G_AND);
2854
2855 Register Dst = MI.getOperand(0).getReg();
2856 LLT Ty = MRI.getType(Dst);
2857
2858 Register R;
2859 int64_t C1;
2860 int64_t C2;
2861 if (!mi_match(
2862 Dst, MRI,
2863 m_GAnd(m_GAnd(m_Reg(R), m_ICst(C1)), m_ICst(C2))))
2864 return false;
2865
2866 MatchInfo = [=](MachineIRBuilder &B) {
2867 if (C1 & C2) {
2868 B.buildAnd(Dst, R, B.buildConstant(Ty, C1 & C2));
2869 return;
2870 }
2871 auto Zero = B.buildConstant(Ty, 0);
2872 replaceRegWith(MRI, Dst, Zero->getOperand(0).getReg());
2873 };
2874 return true;
2875 }
2876
2877 bool CombinerHelper::matchRedundantAnd(MachineInstr &MI,
2878 Register &Replacement) {
2879 // Given
2880 //
2881 // %y:_(sN) = G_SOMETHING
2882 // %x:_(sN) = G_SOMETHING
2883 // %res:_(sN) = G_AND %x, %y
2884 //
2885 // Eliminate the G_AND when it is known that x & y == x or x & y == y.
2886 //
2887 // Patterns like this can appear as a result of legalization. E.g.
2888 //
2889 // %cmp:_(s32) = G_ICMP intpred(pred), %x(s32), %y
2890 // %one:_(s32) = G_CONSTANT i32 1
2891 // %and:_(s32) = G_AND %cmp, %one
2892 //
2893 // In this case, G_ICMP only produces a single bit, so x & 1 == x.
2894 assert(MI.getOpcode() == TargetOpcode::G_AND);
2895 if (!KB)
2896 return false;
2897
2898 Register AndDst = MI.getOperand(0).getReg();
2899 Register LHS = MI.getOperand(1).getReg();
2900 Register RHS = MI.getOperand(2).getReg();
2901 KnownBits LHSBits = KB->getKnownBits(LHS);
2902 KnownBits RHSBits = KB->getKnownBits(RHS);
2903
2904 // Check that x & Mask == x.
2905 // x & 1 == x, always
2906 // x & 0 == x, only if x is also 0
2907 // Meaning Mask has no effect if every bit is either one in Mask or zero in x.
2908 //
2909 // Check if we can replace AndDst with the LHS of the G_AND
2910 if (canReplaceReg(AndDst, LHS, MRI) &&
2911 (LHSBits.Zero | RHSBits.One).isAllOnes()) {
2912 Replacement = LHS;
2913 return true;
2914 }
2915
2916 // Check if we can replace AndDst with the RHS of the G_AND
2917 if (canReplaceReg(AndDst, RHS, MRI) &&
2918 (LHSBits.One | RHSBits.Zero).isAllOnes()) {
2919 Replacement = RHS;
2920 return true;
2921 }
2922
2923 return false;
2924 }
2925
2926 bool CombinerHelper::matchRedundantOr(MachineInstr &MI, Register &Replacement) {
2927 // Given
2928 //
2929 // %y:_(sN) = G_SOMETHING
2930 // %x:_(sN) = G_SOMETHING
2931 // %res:_(sN) = G_OR %x, %y
2932 //
2933 // Eliminate the G_OR when it is known that x | y == x or x | y == y.
2934 assert(MI.getOpcode() == TargetOpcode::G_OR);
2935 if (!KB)
2936 return false;
2937
2938 Register OrDst = MI.getOperand(0).getReg();
2939 Register LHS = MI.getOperand(1).getReg();
2940 Register RHS = MI.getOperand(2).getReg();
2941 KnownBits LHSBits = KB->getKnownBits(LHS);
2942 KnownBits RHSBits = KB->getKnownBits(RHS);
2943
2944 // Check that x | Mask == x.
2945 // x | 0 == x, always
2946 // x | 1 == x, only if x is also 1
2947 // Meaning Mask has no effect if every bit is either zero in Mask or one in x.
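// E.g. (illustrative): if %x has its low 8 bits known to be all ones and %y
// is zero-extended from s8 (upper bits known zero), then x | y == x and the
// G_OR can be replaced by %x.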
2948 //
2949 // Check if we can replace OrDst with the LHS of the G_OR
2950 if (canReplaceReg(OrDst, LHS, MRI) &&
2951 (LHSBits.One | RHSBits.Zero).isAllOnes()) {
2952 Replacement = LHS;
2953 return true;
2954 }
2955
2956 // Check if we can replace OrDst with the RHS of the G_OR
2957 if (canReplaceReg(OrDst, RHS, MRI) &&
2958 (LHSBits.Zero | RHSBits.One).isAllOnes()) {
2959 Replacement = RHS;
2960 return true;
2961 }
2962
2963 return false;
2964 }
2965
2966 bool CombinerHelper::matchRedundantSExtInReg(MachineInstr &MI) {
2967 // If the input is already sign extended, just drop the extension.
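// E.g. (illustrative): if %src was produced by sign-extending an s8 value to
// s32, it has at least 25 sign bits, so G_SEXT_INREG %src, 8 (or any wider
// width) changes nothing.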
2968 Register Src = MI.getOperand(1).getReg();
2969 unsigned ExtBits = MI.getOperand(2).getImm();
2970 unsigned TypeSize = MRI.getType(Src).getScalarSizeInBits();
2971 return KB->computeNumSignBits(Src) >= (TypeSize - ExtBits + 1);
2972 }
2973
2974 static bool isConstValidTrue(const TargetLowering &TLI, unsigned ScalarSizeBits,
2975 int64_t Cst, bool IsVector, bool IsFP) {
2976 // For i1, Cst will always be -1 regardless of boolean contents.
2977 return (ScalarSizeBits == 1 && Cst == -1) ||
2978 isConstTrueVal(TLI, Cst, IsVector, IsFP);
2979 }
2980
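// Illustrative sketch of the not(cmp) fold handled below (the xor constant is
// always -1 for s1):
//   %c:_(s1) = G_ICMP intpred(eq), %a, %b
//   %n:_(s1) = G_XOR %c, -1
// is rewritten by inverting the predicate:
//   %n:_(s1) = G_ICMP intpred(ne), %a, %b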
2981 bool CombinerHelper::matchNotCmp(MachineInstr &MI,
2982 SmallVectorImpl<Register> &RegsToNegate) {
2983 assert(MI.getOpcode() == TargetOpcode::G_XOR);
2984 LLT Ty = MRI.getType(MI.getOperand(0).getReg());
2985 const auto &TLI = *Builder.getMF().getSubtarget().getTargetLowering();
2986 Register XorSrc;
2987 Register CstReg;
2988 // We match xor(src, true) here.
2989 if (!mi_match(MI.getOperand(0).getReg(), MRI,
2990 m_GXor(m_Reg(XorSrc), m_Reg(CstReg))))
2991 return false;
2992
2993 if (!MRI.hasOneNonDBGUse(XorSrc))
2994 return false;
2995
2996 // Check that XorSrc is the root of a tree of comparisons combined with ANDs
2997 // and ORs. The suffix of RegsToNegate starting from index I is used as a
2998 // work list of tree nodes to visit.
2999 RegsToNegate.push_back(XorSrc);
3000 // Remember whether the comparisons are all integer or all floating point.
3001 bool IsInt = false;
3002 bool IsFP = false;
3003 for (unsigned I = 0; I < RegsToNegate.size(); ++I) {
3004 Register Reg = RegsToNegate[I];
3005 if (!MRI.hasOneNonDBGUse(Reg))
3006 return false;
3007 MachineInstr *Def = MRI.getVRegDef(Reg);
3008 switch (Def->getOpcode()) {
3009 default:
3010 // Don't match if the tree contains anything other than ANDs, ORs and
3011 // comparisons.
3012 return false;
3013 case TargetOpcode::G_ICMP:
3014 if (IsFP)
3015 return false;
3016 IsInt = true;
3017 // When we apply the combine we will invert the predicate.
3018 break;
3019 case TargetOpcode::G_FCMP:
3020 if (IsInt)
3021 return false;
3022 IsFP = true;
3023 // When we apply the combine we will invert the predicate.
3024 break;
3025 case TargetOpcode::G_AND:
3026 case TargetOpcode::G_OR:
3027 // Implement De Morgan's laws:
3028 // ~(x & y) -> ~x | ~y
3029 // ~(x | y) -> ~x & ~y
3030 // When we apply the combine we will change the opcode and recursively
3031 // negate the operands.
3032 RegsToNegate.push_back(Def->getOperand(1).getReg());
3033 RegsToNegate.push_back(Def->getOperand(2).getReg());
3034 break;
3035 }
3036 }
3037
3038 // Now we know whether the comparisons are integer or floating point, check
3039 // the constant in the xor.
3040 int64_t Cst;
3041 if (Ty.isVector()) {
3042 MachineInstr *CstDef = MRI.getVRegDef(CstReg);
3043 auto MaybeCst = getIConstantSplatSExtVal(*CstDef, MRI);
3044 if (!MaybeCst)
3045 return false;
3046 if (!isConstValidTrue(TLI, Ty.getScalarSizeInBits(), *MaybeCst, true, IsFP))
3047 return false;
3048 } else {
3049 if (!mi_match(CstReg, MRI, m_ICst(Cst)))
3050 return false;
3051 if (!isConstValidTrue(TLI, Ty.getSizeInBits(), Cst, false, IsFP))
3052 return false;
3053 }
3054
3055 return true;
3056 }
3057
3058 void CombinerHelper::applyNotCmp(MachineInstr &MI,
3059 SmallVectorImpl<Register> &RegsToNegate) {
3060 for (Register Reg : RegsToNegate) {
3061 MachineInstr *Def = MRI.getVRegDef(Reg);
3062 Observer.changingInstr(*Def);
3063 // For each comparison, invert the opcode. For each AND and OR, change the
3064 // opcode.
3065 switch (Def->getOpcode()) {
3066 default:
3067 llvm_unreachable("Unexpected opcode");
3068 case TargetOpcode::G_ICMP:
3069 case TargetOpcode::G_FCMP: {
3070 MachineOperand &PredOp = Def->getOperand(1);
3071 CmpInst::Predicate NewP = CmpInst::getInversePredicate(
3072 (CmpInst::Predicate)PredOp.getPredicate());
3073 PredOp.setPredicate(NewP);
3074 break;
3075 }
3076 case TargetOpcode::G_AND:
3077 Def->setDesc(Builder.getTII().get(TargetOpcode::G_OR));
3078 break;
3079 case TargetOpcode::G_OR:
3080 Def->setDesc(Builder.getTII().get(TargetOpcode::G_AND));
3081 break;
3082 }
3083 Observer.changedInstr(*Def);
3084 }
3085
3086 replaceRegWith(MRI, MI.getOperand(0).getReg(), MI.getOperand(1).getReg());
3087 MI.eraseFromParent();
3088 }
3089
3090 bool CombinerHelper::matchXorOfAndWithSameReg(
3091 MachineInstr &MI, std::pair<Register, Register> &MatchInfo) {
3092 // Match (xor (and x, y), y) (or any of its commuted cases)
3093 assert(MI.getOpcode() == TargetOpcode::G_XOR);
3094 Register &X = MatchInfo.first;
3095 Register &Y = MatchInfo.second;
3096 Register AndReg = MI.getOperand(1).getReg();
3097 Register SharedReg = MI.getOperand(2).getReg();
3098
3099 // Find a G_AND on either side of the G_XOR.
3100 // Look for one of
3101 //
3102 // (xor (and x, y), SharedReg)
3103 // (xor SharedReg, (and x, y))
3104 if (!mi_match(AndReg, MRI, m_GAnd(m_Reg(X), m_Reg(Y)))) {
3105 std::swap(AndReg, SharedReg);
3106 if (!mi_match(AndReg, MRI, m_GAnd(m_Reg(X), m_Reg(Y))))
3107 return false;
3108 }
3109
3110 // Only do this if we'll eliminate the G_AND.
3111 if (!MRI.hasOneNonDBGUse(AndReg))
3112 return false;
3113
3114 // We can combine if SharedReg is the same as either the LHS or RHS of the
3115 // G_AND.
3116 if (Y != SharedReg)
3117 std::swap(X, Y);
3118 return Y == SharedReg;
3119 }
3120
3121 void CombinerHelper::applyXorOfAndWithSameReg(
3122 MachineInstr &MI, std::pair<Register, Register> &MatchInfo) {
3123 // Fold (xor (and x, y), y) -> (and (not x), y)
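// This holds because (x & y) ^ y clears exactly the bits of y that are also
// set in x, leaving (~x) & y.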
3124 Builder.setInstrAndDebugLoc(MI);
3125 Register X, Y;
3126 std::tie(X, Y) = MatchInfo;
3127 auto Not = Builder.buildNot(MRI.getType(X), X);
3128 Observer.changingInstr(MI);
3129 MI.setDesc(Builder.getTII().get(TargetOpcode::G_AND));
3130 MI.getOperand(1).setReg(Not->getOperand(0).getReg());
3131 MI.getOperand(2).setReg(Y);
3132 Observer.changedInstr(MI);
3133 }
3134
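// Sketch of the fold handled below (illustrative): a G_PTR_ADD whose base
// pointer is null (constant zero, or an all-zeros build_vector for vectors of
// pointers) is equivalent to casting the integer offset to a pointer, i.e.
// G_PTR_ADD null, %off --> G_INTTOPTR %off.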
3135 bool CombinerHelper::matchPtrAddZero(MachineInstr &MI) {
3136 auto &PtrAdd = cast<GPtrAdd>(MI);
3137 Register DstReg = PtrAdd.getReg(0);
3138 LLT Ty = MRI.getType(DstReg);
3139 const DataLayout &DL = Builder.getMF().getDataLayout();
3140
3141 if (DL.isNonIntegralAddressSpace(Ty.getScalarType().getAddressSpace()))
3142 return false;
3143
3144 if (Ty.isPointer()) {
3145 auto ConstVal = getIConstantVRegVal(PtrAdd.getBaseReg(), MRI);
3146 return ConstVal && *ConstVal == 0;
3147 }
3148
3149 assert(Ty.isVector() && "Expecting a vector type");
3150 const MachineInstr *VecMI = MRI.getVRegDef(PtrAdd.getBaseReg());
3151 return isBuildVectorAllZeros(*VecMI, MRI);
3152 }
3153
3154 void CombinerHelper::applyPtrAddZero(MachineInstr &MI) {
3155 auto &PtrAdd = cast<GPtrAdd>(MI);
3156 Builder.setInstrAndDebugLoc(PtrAdd);
3157 Builder.buildIntToPtr(PtrAdd.getReg(0), PtrAdd.getOffsetReg());
3158 PtrAdd.eraseFromParent();
3159 }
3160
3161 /// The second source operand is known to be a power of 2.
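/// E.g. (illustrative): (urem x, 8) --> (and x, 7).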
3162 void CombinerHelper::applySimplifyURemByPow2(MachineInstr &MI) {
3163 Register DstReg = MI.getOperand(0).getReg();
3164 Register Src0 = MI.getOperand(1).getReg();
3165 Register Pow2Src1 = MI.getOperand(2).getReg();
3166 LLT Ty = MRI.getType(DstReg);
3167 Builder.setInstrAndDebugLoc(MI);
3168
3169 // Fold (urem x, pow2) -> (and x, pow2-1)
3170 auto NegOne = Builder.buildConstant(Ty, -1);
3171 auto Add = Builder.buildAdd(Ty, Pow2Src1, NegOne);
3172 Builder.buildAnd(DstReg, Src0, Add);
3173 MI.eraseFromParent();
3174 }
3175
3176 bool CombinerHelper::matchFoldBinOpIntoSelect(MachineInstr &MI,
3177 unsigned &SelectOpNo) {
3178 Register LHS = MI.getOperand(1).getReg();
3179 Register RHS = MI.getOperand(2).getReg();
3180
3181 Register OtherOperandReg = RHS;
3182 SelectOpNo = 1;
3183 MachineInstr *Select = MRI.getVRegDef(LHS);
3184
3185 // Don't do this unless the old select is going away. We want to eliminate the
3186 // binary operator, not replace a binop with a select.
3187 if (Select->getOpcode() != TargetOpcode::G_SELECT ||
3188 !MRI.hasOneNonDBGUse(LHS)) {
3189 OtherOperandReg = LHS;
3190 SelectOpNo = 2;
3191 Select = MRI.getVRegDef(RHS);
3192 if (Select->getOpcode() != TargetOpcode::G_SELECT ||
3193 !MRI.hasOneNonDBGUse(RHS))
3194 return false;
3195 }
3196
3197 MachineInstr *SelectLHS = MRI.getVRegDef(Select->getOperand(2).getReg());
3198 MachineInstr *SelectRHS = MRI.getVRegDef(Select->getOperand(3).getReg());
3199
3200 if (!isConstantOrConstantVector(*SelectLHS, MRI,
3201 /*AllowFP*/ true,
3202 /*AllowOpaqueConstants*/ false))
3203 return false;
3204 if (!isConstantOrConstantVector(*SelectRHS, MRI,
3205 /*AllowFP*/ true,
3206 /*AllowOpaqueConstants*/ false))
3207 return false;
3208
3209 unsigned BinOpcode = MI.getOpcode();
3210
3211 // We now know one of the operands is a select of constants. Now verify that
3212 // the other binary operator operand is either a constant, or we can handle a
3213 // variable.
3214 bool CanFoldNonConst =
3215 (BinOpcode == TargetOpcode::G_AND || BinOpcode == TargetOpcode::G_OR) &&
3216 (isNullOrNullSplat(*SelectLHS, MRI) ||
3217 isAllOnesOrAllOnesSplat(*SelectLHS, MRI)) &&
3218 (isNullOrNullSplat(*SelectRHS, MRI) ||
3219 isAllOnesOrAllOnesSplat(*SelectRHS, MRI));
3220 if (CanFoldNonConst)
3221 return true;
3222
3223 return isConstantOrConstantVector(*MRI.getVRegDef(OtherOperandReg), MRI,
3224 /*AllowFP*/ true,
3225 /*AllowOpaqueConstants*/ false);
3226 }
3227
3228 /// \p SelectOperand is the operand in binary operator \p MI that is the select
3229 /// to fold.
3230 bool CombinerHelper::applyFoldBinOpIntoSelect(MachineInstr &MI,
3231 const unsigned &SelectOperand) {
3232 Builder.setInstrAndDebugLoc(MI);
3233
3234 Register Dst = MI.getOperand(0).getReg();
3235 Register LHS = MI.getOperand(1).getReg();
3236 Register RHS = MI.getOperand(2).getReg();
3237 MachineInstr *Select = MRI.getVRegDef(MI.getOperand(SelectOperand).getReg());
3238
3239 Register SelectCond = Select->getOperand(1).getReg();
3240 Register SelectTrue = Select->getOperand(2).getReg();
3241 Register SelectFalse = Select->getOperand(3).getReg();
3242
3243 LLT Ty = MRI.getType(Dst);
3244 unsigned BinOpcode = MI.getOpcode();
3245
3246 Register FoldTrue, FoldFalse;
3247
3248 // We have a select-of-constants followed by a binary operator with a
3249 // constant. Eliminate the binop by pulling the constant math into the select.
3250 // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
3251 if (SelectOperand == 1) {
3252 // TODO: SelectionDAG verifies this actually constant folds before
3253 // committing to the combine.
3254
3255 FoldTrue = Builder.buildInstr(BinOpcode, {Ty}, {SelectTrue, RHS}).getReg(0);
3256 FoldFalse =
3257 Builder.buildInstr(BinOpcode, {Ty}, {SelectFalse, RHS}).getReg(0);
3258 } else {
3259 FoldTrue = Builder.buildInstr(BinOpcode, {Ty}, {LHS, SelectTrue}).getReg(0);
3260 FoldFalse =
3261 Builder.buildInstr(BinOpcode, {Ty}, {LHS, SelectFalse}).getReg(0);
3262 }
3263
3264 Builder.buildSelect(Dst, SelectCond, FoldTrue, FoldFalse, MI.getFlags());
3265 MI.eraseFromParent();
3266
3267 return true;
3268 }
3269
3270 std::optional<SmallVector<Register, 8>>
3271 CombinerHelper::findCandidatesForLoadOrCombine(const MachineInstr *Root) const {
3272 assert(Root->getOpcode() == TargetOpcode::G_OR && "Expected G_OR only!");
3273 // We want to detect if Root is part of a tree which represents a bunch
3274 // of loads being merged into a larger load. We'll try to recognize patterns
3275 // like, for example:
3276 //
3277 // Reg Reg
3278 // \ /
3279 // OR_1 Reg
3280 // \ /
3281 // OR_2
3282 // \ Reg
3283 // .. /
3284 // Root
3285 //
3286 // Reg Reg Reg Reg
3287 // \ / \ /
3288 // OR_1 OR_2
3289 // \ /
3290 // \ /
3291 // ...
3292 // Root
3293 //
3294 // Each "Reg" may have been produced by a load + some arithmetic. This
3295 // function will save each of them.
3296 SmallVector<Register, 8> RegsToVisit;
3297 SmallVector<const MachineInstr *, 7> Ors = {Root};
3298
3299 // In the "worst" case, we're dealing with a load for each byte. So, there
3300 // are at most #bytes - 1 ORs.
3301 const unsigned MaxIter =
3302 MRI.getType(Root->getOperand(0).getReg()).getSizeInBytes() - 1;
3303 for (unsigned Iter = 0; Iter < MaxIter; ++Iter) {
3304 if (Ors.empty())
3305 break;
3306 const MachineInstr *Curr = Ors.pop_back_val();
3307 Register OrLHS = Curr->getOperand(1).getReg();
3308 Register OrRHS = Curr->getOperand(2).getReg();
3309
3310 // In the combine, we want to eliminate the entire tree.
3311 if (!MRI.hasOneNonDBGUse(OrLHS) || !MRI.hasOneNonDBGUse(OrRHS))
3312 return std::nullopt;
3313
3314 // If it's a G_OR, save it and continue to walk. If it's not, then it's
3315 // something that may be a load + arithmetic.
3316 if (const MachineInstr *Or = getOpcodeDef(TargetOpcode::G_OR, OrLHS, MRI))
3317 Ors.push_back(Or);
3318 else
3319 RegsToVisit.push_back(OrLHS);
3320 if (const MachineInstr *Or = getOpcodeDef(TargetOpcode::G_OR, OrRHS, MRI))
3321 Ors.push_back(Or);
3322 else
3323 RegsToVisit.push_back(OrRHS);
3324 }
3325
3326 // We're going to try and merge each register into a wider power-of-2 type,
3327 // so we ought to have an even number of registers.
3328 if (RegsToVisit.empty() || RegsToVisit.size() % 2 != 0)
3329 return std::nullopt;
3330 return RegsToVisit;
3331 }
3332
3333 /// Helper function for findLoadOffsetsForLoadOrCombine.
3334 ///
3335 /// Check if \p Reg is the result of loading a \p MemSizeInBits wide value,
3336 /// and then moving that value into a specific byte offset.
3337 ///
3338 /// e.g. x[i] << 24
3339 ///
3340 /// \returns The load instruction and the byte offset it is moved into.
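/// E.g. (illustrative), for MemSizeInBits == 8:
///   %wide:_(s32) = G_ZEXTLOAD %ptr(p0) :: (load (s8))
///   %reg:_(s32) = G_SHL %wide, 24
/// returns the load and byte position 3 (24 / 8).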
3341 static std::optional<std::pair<GZExtLoad *, int64_t>>
3342 matchLoadAndBytePosition(Register Reg, unsigned MemSizeInBits,
3343 const MachineRegisterInfo &MRI) {
3344 assert(MRI.hasOneNonDBGUse(Reg) &&
3345 "Expected Reg to only have one non-debug use?");
3346 Register MaybeLoad;
3347 int64_t Shift;
3348 if (!mi_match(Reg, MRI,
3349 m_OneNonDBGUse(m_GShl(m_Reg(MaybeLoad), m_ICst(Shift))))) {
3350 Shift = 0;
3351 MaybeLoad = Reg;
3352 }
3353
3354 if (Shift % MemSizeInBits != 0)
3355 return std::nullopt;
3356
3357 // TODO: Handle other types of loads.
3358 auto *Load = getOpcodeDef<GZExtLoad>(MaybeLoad, MRI);
3359 if (!Load)
3360 return std::nullopt;
3361
3362 if (!Load->isUnordered() || Load->getMemSizeInBits() != MemSizeInBits)
3363 return std::nullopt;
3364
3365 return std::make_pair(Load, Shift / MemSizeInBits);
3366 }
3367
3368 std::optional<std::tuple<GZExtLoad *, int64_t, GZExtLoad *>>
3369 CombinerHelper::findLoadOffsetsForLoadOrCombine(
3370 SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx,
3371 const SmallVector<Register, 8> &RegsToVisit, const unsigned MemSizeInBits) {
3372
3373 // Each load found for the pattern. There should be one for each RegsToVisit.
3374 SmallSetVector<const MachineInstr *, 8> Loads;
3375
3376 // The lowest index used in any load. (The lowest "i" for each x[i].)
3377 int64_t LowestIdx = INT64_MAX;
3378
3379 // The load which uses the lowest index.
3380 GZExtLoad *LowestIdxLoad = nullptr;
3381
3382 // Keeps track of the load indices we see. We shouldn't see any indices twice.
3383 SmallSet<int64_t, 8> SeenIdx;
3384
3385 // Ensure each load is in the same MBB.
3386 // TODO: Support multiple MachineBasicBlocks.
3387 MachineBasicBlock *MBB = nullptr;
3388 const MachineMemOperand *MMO = nullptr;
3389
3390 // Earliest instruction-order load in the pattern.
3391 GZExtLoad *EarliestLoad = nullptr;
3392
3393 // Latest instruction-order load in the pattern.
3394 GZExtLoad *LatestLoad = nullptr;
3395
3396 // Base pointer which every load should share.
3397 Register BasePtr;
3398
3399 // We want to find a load for each register. Each load should have some
3400 // appropriate bit twiddling arithmetic. During this loop, we will also keep
3401 // track of the load which uses the lowest index. Later, we will check if we
3402 // can use its pointer in the final, combined load.
3403 for (auto Reg : RegsToVisit) {
3404 // Find the load, and find the position that it will end up in (e.g. a
3405 // shifted value).
3406 auto LoadAndPos = matchLoadAndBytePosition(Reg, MemSizeInBits, MRI);
3407 if (!LoadAndPos)
3408 return std::nullopt;
3409 GZExtLoad *Load;
3410 int64_t DstPos;
3411 std::tie(Load, DstPos) = *LoadAndPos;
3412
3413 // TODO: Handle multiple MachineBasicBlocks. Currently not handled because
3414 // it is difficult to check for stores/calls/etc between loads.
3415 MachineBasicBlock *LoadMBB = Load->getParent();
3416 if (!MBB)
3417 MBB = LoadMBB;
3418 if (LoadMBB != MBB)
3419 return std::nullopt;
3420
3421 // Make sure that the MachineMemOperands of every seen load are compatible.
3422 auto &LoadMMO = Load->getMMO();
3423 if (!MMO)
3424 MMO = &LoadMMO;
3425 if (MMO->getAddrSpace() != LoadMMO.getAddrSpace())
3426 return std::nullopt;
3427
3428 // Find out what the base pointer and index for the load is.
3429 Register LoadPtr;
3430 int64_t Idx;
3431 if (!mi_match(Load->getOperand(1).getReg(), MRI,
3432 m_GPtrAdd(m_Reg(LoadPtr), m_ICst(Idx)))) {
3433 LoadPtr = Load->getOperand(1).getReg();
3434 Idx = 0;
3435 }
3436
3437 // Don't combine things like a[i], a[i] -> a bigger load.
3438 if (!SeenIdx.insert(Idx).second)
3439 return std::nullopt;
3440
3441 // Every load must share the same base pointer; don't combine things like:
3442 //
3443 // a[i], b[i + 1] -> a bigger load.
3444 if (!BasePtr.isValid())
3445 BasePtr = LoadPtr;
3446 if (BasePtr != LoadPtr)
3447 return std::nullopt;
3448
3449 if (Idx < LowestIdx) {
3450 LowestIdx = Idx;
3451 LowestIdxLoad = Load;
3452 }
3453
3454 // Keep track of the byte offset that this load ends up at. If we have seen
3455 // the byte offset, then stop here. We do not want to combine:
3456 //
3457 // a[i] << 16, a[i + k] << 16 -> a bigger load.
3458 if (!MemOffset2Idx.try_emplace(DstPos, Idx).second)
3459 return std::nullopt;
3460 Loads.insert(Load);
3461
3462 // Keep track of the position of the earliest/latest loads in the pattern.
3463 // We will check that there are no load fold barriers between them later
3464 // on.
3465 //
3466 // FIXME: Is there a better way to check for load fold barriers?
3467 if (!EarliestLoad || dominates(*Load, *EarliestLoad))
3468 EarliestLoad = Load;
3469 if (!LatestLoad || dominates(*LatestLoad, *Load))
3470 LatestLoad = Load;
3471 }
3472
3473 // We found a load for each register. Let's check if each load satisfies the
3474 // pattern.
3475 assert(Loads.size() == RegsToVisit.size() &&
3476 "Expected to find a load for each register?");
3477 assert(EarliestLoad != LatestLoad && EarliestLoad &&
3478 LatestLoad && "Expected at least two loads?");
3479
3480 // Check if there are any stores, calls, etc. between any of the loads. If
3481 // there are, then we can't safely perform the combine.
3482 //
3483 // MaxIter is chosen based off the (worst case) number of iterations it
3484 // typically takes to succeed in the LLVM test suite plus some padding.
3485 //
3486 // FIXME: Is there a better way to check for load fold barriers?
3487 const unsigned MaxIter = 20;
3488 unsigned Iter = 0;
3489 for (const auto &MI : instructionsWithoutDebug(EarliestLoad->getIterator(),
3490 LatestLoad->getIterator())) {
3491 if (Loads.count(&MI))
3492 continue;
3493 if (MI.isLoadFoldBarrier())
3494 return std::nullopt;
3495 if (Iter++ == MaxIter)
3496 return std::nullopt;
3497 }
3498
3499 return std::make_tuple(LowestIdxLoad, LowestIdx, LatestLoad);
3500 }
3501
3502 bool CombinerHelper::matchLoadOrCombine(
3503 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
3504 assert(MI.getOpcode() == TargetOpcode::G_OR);
3505 MachineFunction &MF = *MI.getMF();
3506 // Assuming a little-endian target, transform:
3507 // s8 *a = ...
3508 // s32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
3509 // =>
3510 // s32 val = *((i32)a)
3511 //
3512 // s8 *a = ...
3513 // s32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
3514 // =>
3515 // s32 val = BSWAP(*((s32)a))
3516 Register Dst = MI.getOperand(0).getReg();
3517 LLT Ty = MRI.getType(Dst);
3518 if (Ty.isVector())
3519 return false;
3520
3521 // We need to combine at least two loads into this type. Since the smallest
3522 // possible load is into a byte, we need at least a 16-bit wide type.
3523 const unsigned WideMemSizeInBits = Ty.getSizeInBits();
3524 if (WideMemSizeInBits < 16 || WideMemSizeInBits % 8 != 0)
3525 return false;
3526
3527 // Match a collection of non-OR instructions in the pattern.
3528 auto RegsToVisit = findCandidatesForLoadOrCombine(&MI);
3529 if (!RegsToVisit)
3530 return false;
3531
3532 // We have a collection of non-OR instructions. Figure out how wide each of
3533 // the small loads should be based off of the number of potential loads we
3534 // found.
3535 const unsigned NarrowMemSizeInBits = WideMemSizeInBits / RegsToVisit->size();
3536 if (NarrowMemSizeInBits % 8 != 0)
3537 return false;
3538
3539 // Check if each register feeding into each OR is a load from the same
3540 // base pointer + some arithmetic.
3541 //
3542 // e.g. a[0], a[1] << 8, a[2] << 16, etc.
3543 //
3544 // Also verify that each of these ends up putting a[i] into the same memory
3545 // offset as a load into a wide type would.
3546 SmallDenseMap<int64_t, int64_t, 8> MemOffset2Idx;
3547 GZExtLoad *LowestIdxLoad, *LatestLoad;
3548 int64_t LowestIdx;
3549 auto MaybeLoadInfo = findLoadOffsetsForLoadOrCombine(
3550 MemOffset2Idx, *RegsToVisit, NarrowMemSizeInBits);
3551 if (!MaybeLoadInfo)
3552 return false;
3553 std::tie(LowestIdxLoad, LowestIdx, LatestLoad) = *MaybeLoadInfo;
3554
3555 // We have a bunch of loads being OR'd together. Using the addresses + offsets
3556 // we found before, check if this corresponds to a big or little endian byte
3557 // pattern. If it does, then we can represent it using a load + possibly a
3558 // BSWAP.
3559 bool IsBigEndianTarget = MF.getDataLayout().isBigEndian();
3560 std::optional<bool> IsBigEndian = isBigEndian(MemOffset2Idx, LowestIdx);
3561 if (!IsBigEndian)
3562 return false;
3563 bool NeedsBSwap = IsBigEndianTarget != *IsBigEndian;
3564 if (NeedsBSwap && !isLegalOrBeforeLegalizer({TargetOpcode::G_BSWAP, {Ty}}))
3565 return false;
3566
3567 // Make sure that the load from the lowest index produces offset 0 in the
3568 // final value.
3569 //
3570 // This ensures that we won't combine something like this:
3571 //
3572 // load x[i] -> byte 2
3573 // load x[i+1] -> byte 0 ---> wide_load x[i]
3574 // load x[i+2] -> byte 1
3575 const unsigned NumLoadsInTy = WideMemSizeInBits / NarrowMemSizeInBits;
3576 const unsigned ZeroByteOffset =
3577 *IsBigEndian
3578 ? bigEndianByteAt(NumLoadsInTy, 0)
3579 : littleEndianByteAt(NumLoadsInTy, 0);
3580 auto ZeroOffsetIdx = MemOffset2Idx.find(ZeroByteOffset);
3581 if (ZeroOffsetIdx == MemOffset2Idx.end() ||
3582 ZeroOffsetIdx->second != LowestIdx)
3583 return false;
3584
3585 // We will reuse the pointer from the load which ends up at byte offset 0. It
3586 // may not use index 0.
3587 Register Ptr = LowestIdxLoad->getPointerReg();
3588 const MachineMemOperand &MMO = LowestIdxLoad->getMMO();
3589 LegalityQuery::MemDesc MMDesc(MMO);
3590 MMDesc.MemoryTy = Ty;
3591 if (!isLegalOrBeforeLegalizer(
3592 {TargetOpcode::G_LOAD, {Ty, MRI.getType(Ptr)}, {MMDesc}}))
3593 return false;
3594 auto PtrInfo = MMO.getPointerInfo();
3595 auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, WideMemSizeInBits / 8);
3596
3597 // Load must be allowed and fast on the target.
3598 LLVMContext &C = MF.getFunction().getContext();
3599 auto &DL = MF.getDataLayout();
3600 unsigned Fast = 0;
3601 if (!getTargetLowering().allowsMemoryAccess(C, DL, Ty, *NewMMO, &Fast) ||
3602 !Fast)
3603 return false;
3604
3605 MatchInfo = [=](MachineIRBuilder &MIB) {
3606 MIB.setInstrAndDebugLoc(*LatestLoad);
3607 Register LoadDst = NeedsBSwap ? MRI.cloneVirtualRegister(Dst) : Dst;
3608 MIB.buildLoad(LoadDst, Ptr, *NewMMO);
3609 if (NeedsBSwap)
3610 MIB.buildBSwap(Dst, LoadDst);
3611 };
3612 return true;
3613 }
3614
3615 /// Check if the store \p Store is a truncstore that can be merged. That is,
3616 /// it's a store of a shifted value of \p SrcVal. If \p SrcVal is an empty
3617 /// Register then it does not need to match and SrcVal is set to the source
3618 /// value found.
3619 /// On match, returns the start byte offset of the \p SrcVal that is being
3620 /// stored.
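/// E.g. (illustrative): for an s32 source value stored through s8 truncstores,
/// a store of trunc(lshr(y, 16)) yields byte offset 2 (16 / 8).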
3621 static std::optional<int64_t>
3622 getTruncStoreByteOffset(GStore &Store, Register &SrcVal,
3623 MachineRegisterInfo &MRI) {
3624 Register TruncVal;
3625 if (!mi_match(Store.getValueReg(), MRI, m_GTrunc(m_Reg(TruncVal))))
3626 return std::nullopt;
3627
3628 // The shift amount must be a constant multiple of the narrow type.
3629 // It is translated to the offset address in the wide source value "y".
3630 //
3631 // x = G_LSHR y, ShiftAmtC
3632 // s8 z = G_TRUNC x
3633 // store z, ...
3634 Register FoundSrcVal;
3635 int64_t ShiftAmt;
3636 if (!mi_match(TruncVal, MRI,
3637 m_any_of(m_GLShr(m_Reg(FoundSrcVal), m_ICst(ShiftAmt)),
3638 m_GAShr(m_Reg(FoundSrcVal), m_ICst(ShiftAmt))))) {
3639 if (!SrcVal.isValid() || TruncVal == SrcVal) {
3640 if (!SrcVal.isValid())
3641 SrcVal = TruncVal;
3642 return 0; // If it's the lowest index store.
3643 }
3644 return std::nullopt;
3645 }
3646
3647 unsigned NarrowBits = Store.getMMO().getMemoryType().getScalarSizeInBits();
3648 if (ShiftAmt % NarrowBits != 0)
3649 return std::nullopt;
3650 const unsigned Offset = ShiftAmt / NarrowBits;
3651
3652 if (SrcVal.isValid() && FoundSrcVal != SrcVal)
3653 return std::nullopt;
3654
3655 if (!SrcVal.isValid())
3656 SrcVal = FoundSrcVal;
3657 else if (MRI.getType(SrcVal) != MRI.getType(FoundSrcVal))
3658 return std::nullopt;
3659 return Offset;
3660 }
3661
3662 /// Match a pattern where a wide type scalar value is stored by several narrow
3663 /// stores. Fold it into a single store or a BSWAP and a store if the target
3664 /// supports it.
3665 ///
3666 /// Assuming little endian target:
3667 /// i8 *p = ...
3668 /// i32 val = ...
3669 /// p[0] = (val >> 0) & 0xFF;
3670 /// p[1] = (val >> 8) & 0xFF;
3671 /// p[2] = (val >> 16) & 0xFF;
3672 /// p[3] = (val >> 24) & 0xFF;
3673 /// =>
3674 /// *((i32)p) = val;
3675 ///
3676 /// i8 *p = ...
3677 /// i32 val = ...
3678 /// p[0] = (val >> 24) & 0xFF;
3679 /// p[1] = (val >> 16) & 0xFF;
3680 /// p[2] = (val >> 8) & 0xFF;
3681 /// p[3] = (val >> 0) & 0xFF;
3682 /// =>
3683 /// *((i32)p) = BSWAP(val);
3684 bool CombinerHelper::matchTruncStoreMerge(MachineInstr &MI,
3685 MergeTruncStoresInfo &MatchInfo) {
3686 auto &StoreMI = cast<GStore>(MI);
3687 LLT MemTy = StoreMI.getMMO().getMemoryType();
3688
3689 // We only handle merging simple stores of 1-4 bytes.
3690 if (!MemTy.isScalar())
3691 return false;
3692 switch (MemTy.getSizeInBits()) {
3693 case 8:
3694 case 16:
3695 case 32:
3696 break;
3697 default:
3698 return false;
3699 }
3700 if (!StoreMI.isSimple())
3701 return false;
3702
3703 // We do a simple search for mergeable stores prior to this one.
3704 // Any potential alias hazard along the way terminates the search.
3705 SmallVector<GStore *> FoundStores;
3706
3707 // We're looking for:
3708 // 1) a (store(trunc(...)))
3709 // 2) of an LSHR/ASHR of a single wide value, by the appropriate shift to get
3710 // the partial value stored.
3711 // 3) where the offsets form either a little or big-endian sequence.
3712
3713 auto &LastStore = StoreMI;
3714
3715 // The single base pointer that all stores must use.
3716 Register BaseReg;
3717 int64_t LastOffset;
3718 if (!mi_match(LastStore.getPointerReg(), MRI,
3719 m_GPtrAdd(m_Reg(BaseReg), m_ICst(LastOffset)))) {
3720 BaseReg = LastStore.getPointerReg();
3721 LastOffset = 0;
3722 }
3723
3724 GStore *LowestIdxStore = &LastStore;
3725 int64_t LowestIdxOffset = LastOffset;
3726
3727 Register WideSrcVal;
3728 auto LowestShiftAmt = getTruncStoreByteOffset(LastStore, WideSrcVal, MRI);
3729 if (!LowestShiftAmt)
3730 return false; // Didn't match a trunc.
3731 assert(WideSrcVal.isValid());
3732
3733 LLT WideStoreTy = MRI.getType(WideSrcVal);
3734 // The wide type might not be a multiple of the memory type, e.g. s48 and s32.
3735 if (WideStoreTy.getSizeInBits() % MemTy.getSizeInBits() != 0)
3736 return false;
3737 const unsigned NumStoresRequired =
3738 WideStoreTy.getSizeInBits() / MemTy.getSizeInBits();
3739
3740 SmallVector<int64_t, 8> OffsetMap(NumStoresRequired, INT64_MAX);
3741 OffsetMap[*LowestShiftAmt] = LastOffset;
3742 FoundStores.emplace_back(&LastStore);
3743
3744 // Search the block up for more stores.
3745 // We use a search threshold of 10 instructions here because the combiner
3746 // works top-down within a block, and we don't want to search an unbounded
3747 // number of predecessor instructions trying to find matching stores.
3748 // If we moved this optimization into a separate pass then we could probably
3749 // use a more efficient search without having a hard-coded threshold.
3750 const int MaxInstsToCheck = 10;
3751 int NumInstsChecked = 0;
3752 for (auto II = ++LastStore.getReverseIterator();
3753 II != LastStore.getParent()->rend() && NumInstsChecked < MaxInstsToCheck;
3754 ++II) {
3755 NumInstsChecked++;
3756 GStore *NewStore;
3757 if ((NewStore = dyn_cast<GStore>(&*II))) {
3758 if (NewStore->getMMO().getMemoryType() != MemTy || !NewStore->isSimple())
3759 break;
3760 } else if (II->isLoadFoldBarrier() || II->mayLoad()) {
3761 break;
3762 } else {
3763 continue; // This is a safe instruction we can look past.
3764 }
3765
3766 Register NewBaseReg;
3767 int64_t MemOffset;
3768 // Check we're storing to the same base + some offset.
3769 if (!mi_match(NewStore->getPointerReg(), MRI,
3770 m_GPtrAdd(m_Reg(NewBaseReg), m_ICst(MemOffset)))) {
3771 NewBaseReg = NewStore->getPointerReg();
3772 MemOffset = 0;
3773 }
3774 if (BaseReg != NewBaseReg)
3775 break;
3776
3777 auto ShiftByteOffset = getTruncStoreByteOffset(*NewStore, WideSrcVal, MRI);
3778 if (!ShiftByteOffset)
3779 break;
3780 if (MemOffset < LowestIdxOffset) {
3781 LowestIdxOffset = MemOffset;
3782 LowestIdxStore = NewStore;
3783 }
3784
3785 // Map the offset in the store and the offset in the combined value, and
3786 // early return if it has been set before.
3787 if (*ShiftByteOffset < 0 || *ShiftByteOffset >= NumStoresRequired ||
3788 OffsetMap[*ShiftByteOffset] != INT64_MAX)
3789 break;
3790 OffsetMap[*ShiftByteOffset] = MemOffset;
3791
3792 FoundStores.emplace_back(NewStore);
3793 // Reset counter since we've found a matching inst.
3794 NumInstsChecked = 0;
3795 if (FoundStores.size() == NumStoresRequired)
3796 break;
3797 }
3798
3799 if (FoundStores.size() != NumStoresRequired) {
3800 return false;
3801 }
3802
3803 const auto &DL = LastStore.getMF()->getDataLayout();
3804 auto &C = LastStore.getMF()->getFunction().getContext();
3805 // Check that a store of the wide type is both allowed and fast on the target
3806 unsigned Fast = 0;
3807 bool Allowed = getTargetLowering().allowsMemoryAccess(
3808 C, DL, WideStoreTy, LowestIdxStore->getMMO(), &Fast);
3809 if (!Allowed || !Fast)
3810 return false;
3811
3812 // Check if the pieces of the value are going to the expected places in memory
3813 // to merge the stores.
3814 unsigned NarrowBits = MemTy.getScalarSizeInBits();
3815 auto checkOffsets = [&](bool MatchLittleEndian) {
3816 if (MatchLittleEndian) {
3817 for (unsigned i = 0; i != NumStoresRequired; ++i)
3818 if (OffsetMap[i] != i * (NarrowBits / 8) + LowestIdxOffset)
3819 return false;
3820 } else { // MatchBigEndian by reversing loop counter.
3821 for (unsigned i = 0, j = NumStoresRequired - 1; i != NumStoresRequired;
3822 ++i, --j)
3823 if (OffsetMap[j] != i * (NarrowBits / 8) + LowestIdxOffset)
3824 return false;
3825 }
3826 return true;
3827 };
3828
3829 // Check if the offsets line up for the native data layout of this target.
3830 bool NeedBswap = false;
3831 bool NeedRotate = false;
3832 if (!checkOffsets(DL.isLittleEndian())) {
3833 // Special-case: check if byte offsets line up for the opposite endian.
3834 if (NarrowBits == 8 && checkOffsets(DL.isBigEndian()))
3835 NeedBswap = true;
3836 else if (NumStoresRequired == 2 && checkOffsets(DL.isBigEndian()))
3837 NeedRotate = true;
3838 else
3839 return false;
3840 }
3841
3842 if (NeedBswap &&
3843 !isLegalOrBeforeLegalizer({TargetOpcode::G_BSWAP, {WideStoreTy}}))
3844 return false;
3845 if (NeedRotate &&
3846 !isLegalOrBeforeLegalizer({TargetOpcode::G_ROTR, {WideStoreTy}}))
3847 return false;
3848
3849 MatchInfo.NeedBSwap = NeedBswap;
3850 MatchInfo.NeedRotate = NeedRotate;
3851 MatchInfo.LowestIdxStore = LowestIdxStore;
3852 MatchInfo.WideSrcVal = WideSrcVal;
3853 MatchInfo.FoundStores = std::move(FoundStores);
3854 return true;
3855 }
3856
3857 void CombinerHelper::applyTruncStoreMerge(MachineInstr &MI,
3858 MergeTruncStoresInfo &MatchInfo) {
3859
3860 Builder.setInstrAndDebugLoc(MI);
3861 Register WideSrcVal = MatchInfo.WideSrcVal;
3862 LLT WideStoreTy = MRI.getType(WideSrcVal);
3863
3864 if (MatchInfo.NeedBSwap) {
3865 WideSrcVal = Builder.buildBSwap(WideStoreTy, WideSrcVal).getReg(0);
3866 } else if (MatchInfo.NeedRotate) {
3867 assert(WideStoreTy.getSizeInBits() % 2 == 0 &&
3868 "Unexpected type for rotate");
3869 auto RotAmt =
3870 Builder.buildConstant(WideStoreTy, WideStoreTy.getSizeInBits() / 2);
3871 WideSrcVal =
3872 Builder.buildRotateRight(WideStoreTy, WideSrcVal, RotAmt).getReg(0);
3873 }
3874
3875 Builder.buildStore(WideSrcVal, MatchInfo.LowestIdxStore->getPointerReg(),
3876 MatchInfo.LowestIdxStore->getMMO().getPointerInfo(),
3877 MatchInfo.LowestIdxStore->getMMO().getAlign());
3878
3879 // Erase the old stores.
3880 for (auto *ST : MatchInfo.FoundStores)
3881 ST->eraseFromParent();
3882 }
3883
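// Illustrative sketch of the extend-through-phi fold implemented below:
//   %phi:_(s16) = G_PHI %a(s16), %bb0, %b(s16), %bb1
//   %ext:_(s32) = G_SEXT %phi
// becomes a wide phi whose inputs are extended next to their defs:
//   %phi:_(s32) = G_PHI %aext(s32), %bb0, %bext(s32), %bb1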
3884 bool CombinerHelper::matchExtendThroughPhis(MachineInstr &MI,
3885 MachineInstr *&ExtMI) {
3886 assert(MI.getOpcode() == TargetOpcode::G_PHI);
3887
3888 Register DstReg = MI.getOperand(0).getReg();
3889
3890 // TODO: Extending a vector may be expensive, don't do this until heuristics
3891 // are better.
3892 if (MRI.getType(DstReg).isVector())
3893 return false;
3894
3895 // Try to match a phi, whose only use is an extend.
3896 if (!MRI.hasOneNonDBGUse(DstReg))
3897 return false;
3898 ExtMI = &*MRI.use_instr_nodbg_begin(DstReg);
3899 switch (ExtMI->getOpcode()) {
3900 case TargetOpcode::G_ANYEXT:
3901 return true; // G_ANYEXT is usually free.
3902 case TargetOpcode::G_ZEXT:
3903 case TargetOpcode::G_SEXT:
3904 break;
3905 default:
3906 return false;
3907 }
3908
3909 // If the target is likely to fold this extend away, don't propagate.
3910 if (Builder.getTII().isExtendLikelyToBeFolded(*ExtMI, MRI))
3911 return false;
3912
3913 // We don't want to propagate the extends unless there's a good chance that
3914 // they'll be optimized in some way.
3915 // Collect the unique incoming values.
3916 SmallPtrSet<MachineInstr *, 4> InSrcs;
3917 for (unsigned Idx = 1; Idx < MI.getNumOperands(); Idx += 2) {
3918 auto *DefMI = getDefIgnoringCopies(MI.getOperand(Idx).getReg(), MRI);
3919 switch (DefMI->getOpcode()) {
3920 case TargetOpcode::G_LOAD:
3921 case TargetOpcode::G_TRUNC:
3922 case TargetOpcode::G_SEXT:
3923 case TargetOpcode::G_ZEXT:
3924 case TargetOpcode::G_ANYEXT:
3925 case TargetOpcode::G_CONSTANT:
3926 InSrcs.insert(getDefIgnoringCopies(MI.getOperand(Idx).getReg(), MRI));
3927 // Don't try to propagate if there are too many places to create new
3928 // extends, chances are it'll increase code size.
3929 if (InSrcs.size() > 2)
3930 return false;
3931 break;
3932 default:
3933 return false;
3934 }
3935 }
3936 return true;
3937 }
3938
3939 void CombinerHelper::applyExtendThroughPhis(MachineInstr &MI,
3940 MachineInstr *&ExtMI) {
3941 assert(MI.getOpcode() == TargetOpcode::G_PHI);
3942 Register DstReg = ExtMI->getOperand(0).getReg();
3943 LLT ExtTy = MRI.getType(DstReg);
3944
3945 // Propagate the extension into the block of each incoming reg's def.
3946 // Use a SetVector here because PHIs can have duplicate edges, and we want
3947 // deterministic iteration order.
3948 SmallSetVector<MachineInstr *, 8> SrcMIs;
3949 SmallDenseMap<MachineInstr *, MachineInstr *, 8> OldToNewSrcMap;
3950 for (unsigned SrcIdx = 1; SrcIdx < MI.getNumOperands(); SrcIdx += 2) {
3951 auto *SrcMI = MRI.getVRegDef(MI.getOperand(SrcIdx).getReg());
3952 if (!SrcMIs.insert(SrcMI))
3953 continue;
3954
3955 // Build an extend after each src inst.
3956 auto *MBB = SrcMI->getParent();
3957 MachineBasicBlock::iterator InsertPt = ++SrcMI->getIterator();
3958 if (InsertPt != MBB->end() && InsertPt->isPHI())
3959 InsertPt = MBB->getFirstNonPHI();
3960
3961 Builder.setInsertPt(*SrcMI->getParent(), InsertPt);
3962 Builder.setDebugLoc(MI.getDebugLoc());
3963 auto NewExt = Builder.buildExtOrTrunc(ExtMI->getOpcode(), ExtTy,
3964 SrcMI->getOperand(0).getReg());
3965 OldToNewSrcMap[SrcMI] = NewExt;
3966 }
3967
3968 // Create a new phi with the extended inputs.
3969 Builder.setInstrAndDebugLoc(MI);
3970 auto NewPhi = Builder.buildInstrNoInsert(TargetOpcode::G_PHI);
3971 NewPhi.addDef(DstReg);
3972 for (const MachineOperand &MO : llvm::drop_begin(MI.operands())) {
3973 if (!MO.isReg()) {
3974 NewPhi.addMBB(MO.getMBB());
3975 continue;
3976 }
3977 auto *NewSrc = OldToNewSrcMap[MRI.getVRegDef(MO.getReg())];
3978 NewPhi.addUse(NewSrc->getOperand(0).getReg());
3979 }
3980 Builder.insertInstr(NewPhi);
3981 ExtMI->eraseFromParent();
3982 }
3983
3984 bool CombinerHelper::matchExtractVecEltBuildVec(MachineInstr &MI,
3985 Register &Reg) {
3986 assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT);
3987 // If we have a constant index, look for a G_BUILD_VECTOR source
3988 // and find the source register that the index maps to.
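// E.g. (illustrative):
//   %v:_(<4 x s32>) = G_BUILD_VECTOR %a, %b, %c, %d
//   %e:_(s32) = G_EXTRACT_VECTOR_ELT %v, 2
// lets %e be replaced by %c.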
3989 Register SrcVec = MI.getOperand(1).getReg();
3990 LLT SrcTy = MRI.getType(SrcVec);
3991
3992 auto Cst = getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
3993 if (!Cst || Cst->Value.getZExtValue() >= SrcTy.getNumElements())
3994 return false;
3995
3996 unsigned VecIdx = Cst->Value.getZExtValue();
3997
3998 // Check if we have a build_vector or build_vector_trunc with an optional
3999 // trunc in front.
4000 MachineInstr *SrcVecMI = MRI.getVRegDef(SrcVec);
4001 if (SrcVecMI->getOpcode() == TargetOpcode::G_TRUNC) {
4002 SrcVecMI = MRI.getVRegDef(SrcVecMI->getOperand(1).getReg());
4003 }
4004
4005 if (SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR &&
4006 SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR_TRUNC)
4007 return false;
4008
4009 EVT Ty(getMVTForLLT(SrcTy));
4010 if (!MRI.hasOneNonDBGUse(SrcVec) &&
4011 !getTargetLowering().aggressivelyPreferBuildVectorSources(Ty))
4012 return false;
4013
4014 Reg = SrcVecMI->getOperand(VecIdx + 1).getReg();
4015 return true;
4016 }
4017
4018 void CombinerHelper::applyExtractVecEltBuildVec(MachineInstr &MI,
4019 Register &Reg) {
4020 // Check the type of the register, since it may have come from a
4021 // G_BUILD_VECTOR_TRUNC.
4022 LLT ScalarTy = MRI.getType(Reg);
4023 Register DstReg = MI.getOperand(0).getReg();
4024 LLT DstTy = MRI.getType(DstReg);
4025
4026 Builder.setInstrAndDebugLoc(MI);
4027 if (ScalarTy != DstTy) {
4028 assert(ScalarTy.getSizeInBits() > DstTy.getSizeInBits());
4029 Builder.buildTrunc(DstReg, Reg);
4030 MI.eraseFromParent();
4031 return;
4032 }
4033 replaceSingleDefInstWithReg(MI, Reg);
4034 }
4035
4036 bool CombinerHelper::matchExtractAllEltsFromBuildVector(
4037 MachineInstr &MI,
4038 SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) {
4039 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
4040 // This combine tries to find build_vector's which have every source element
4041 // extracted using G_EXTRACT_VECTOR_ELT. This can happen when transforms like
4042 // the masked load scalarization is run late in the pipeline. There's already
4043 // a combine for a similar pattern starting from the extract, but that
4044 // doesn't attempt to do it if there are multiple uses of the build_vector,
4045 // which in this case is true. Starting the combine from the build_vector
4046 // feels more natural than trying to find sibling nodes of extracts.
4047 // E.g.
4048 // %vec(<4 x s32>) = G_BUILD_VECTOR %s1(s32), %s2, %s3, %s4
4049 // %ext1 = G_EXTRACT_VECTOR_ELT %vec, 0
4050 // %ext2 = G_EXTRACT_VECTOR_ELT %vec, 1
4051 // %ext3 = G_EXTRACT_VECTOR_ELT %vec, 2
4052 // %ext4 = G_EXTRACT_VECTOR_ELT %vec, 3
4053 // ==>
4054 // replace ext{1,2,3,4} with %s{1,2,3,4}
4055
4056 Register DstReg = MI.getOperand(0).getReg();
4057 LLT DstTy = MRI.getType(DstReg);
4058 unsigned NumElts = DstTy.getNumElements();
4059
4060 SmallBitVector ExtractedElts(NumElts);
4061 for (MachineInstr &II : MRI.use_nodbg_instructions(DstReg)) {
4062 if (II.getOpcode() != TargetOpcode::G_EXTRACT_VECTOR_ELT)
4063 return false;
4064 auto Cst = getIConstantVRegVal(II.getOperand(2).getReg(), MRI);
4065 if (!Cst)
4066 return false;
4067 unsigned Idx = Cst->getZExtValue();
4068 if (Idx >= NumElts)
4069 return false; // Out of range.
4070 ExtractedElts.set(Idx);
4071 SrcDstPairs.emplace_back(
4072 std::make_pair(MI.getOperand(Idx + 1).getReg(), &II));
4073 }
4074 // Match if every element was extracted.
4075 return ExtractedElts.all();
4076 }
4077
4078 void CombinerHelper::applyExtractAllEltsFromBuildVector(
4079 MachineInstr &MI,
4080 SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) {
4081 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
4082 for (auto &Pair : SrcDstPairs) {
4083 auto *ExtMI = Pair.second;
4084 replaceRegWith(MRI, ExtMI->getOperand(0).getReg(), Pair.first);
4085 ExtMI->eraseFromParent();
4086 }
4087 MI.eraseFromParent();
4088 }
4089
4090 void CombinerHelper::applyBuildFn(
4091 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4092 Builder.setInstrAndDebugLoc(MI);
4093 MatchInfo(Builder);
4094 MI.eraseFromParent();
4095 }
4096
4097 void CombinerHelper::applyBuildFnNoErase(
4098 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4099 Builder.setInstrAndDebugLoc(MI);
4100 MatchInfo(Builder);
4101 }
4102
4103 bool CombinerHelper::matchOrShiftToFunnelShift(MachineInstr &MI,
4104 BuildFnTy &MatchInfo) {
4105 assert(MI.getOpcode() == TargetOpcode::G_OR);
4106
4107 Register Dst = MI.getOperand(0).getReg();
4108 LLT Ty = MRI.getType(Dst);
4109 unsigned BitWidth = Ty.getScalarSizeInBits();
4110
4111 Register ShlSrc, ShlAmt, LShrSrc, LShrAmt, Amt;
4112 unsigned FshOpc = 0;
4113
4114 // Match (or (shl ...), (lshr ...)).
4115 if (!mi_match(Dst, MRI,
4116 // m_GOr() handles the commuted version as well.
4117 m_GOr(m_GShl(m_Reg(ShlSrc), m_Reg(ShlAmt)),
4118 m_GLShr(m_Reg(LShrSrc), m_Reg(LShrAmt)))))
4119 return false;
4120
4121 // Given constants C0 and C1 such that C0 + C1 is bit-width:
4122 // (or (shl x, C0), (lshr y, C1)) -> (fshl x, y, C0) or (fshr x, y, C1)
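// E.g. (illustrative, 32-bit): (or (shl x, 8), (lshr y, 24))
//   --> (fshr x, y, 24), which equals (fshl x, y, 8).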
4123 int64_t CstShlAmt, CstLShrAmt;
4124 if (mi_match(ShlAmt, MRI, m_ICstOrSplat(CstShlAmt)) &&
4125 mi_match(LShrAmt, MRI, m_ICstOrSplat(CstLShrAmt)) &&
4126 CstShlAmt + CstLShrAmt == BitWidth) {
4127 FshOpc = TargetOpcode::G_FSHR;
4128 Amt = LShrAmt;
4129
4130 } else if (mi_match(LShrAmt, MRI,
4131 m_GSub(m_SpecificICstOrSplat(BitWidth), m_Reg(Amt))) &&
4132 ShlAmt == Amt) {
4133 // (or (shl x, amt), (lshr y, (sub bw, amt))) -> (fshl x, y, amt)
4134 FshOpc = TargetOpcode::G_FSHL;
4135
4136 } else if (mi_match(ShlAmt, MRI,
4137 m_GSub(m_SpecificICstOrSplat(BitWidth), m_Reg(Amt))) &&
4138 LShrAmt == Amt) {
4139 // (or (shl x, (sub bw, amt)), (lshr y, amt)) -> (fshr x, y, amt)
4140 FshOpc = TargetOpcode::G_FSHR;
4141
4142 } else {
4143 return false;
4144 }
4145
4146 LLT AmtTy = MRI.getType(Amt);
4147 if (!isLegalOrBeforeLegalizer({FshOpc, {Ty, AmtTy}}))
4148 return false;
4149
4150 MatchInfo = [=](MachineIRBuilder &B) {
4151 B.buildInstr(FshOpc, {Dst}, {ShlSrc, LShrSrc, Amt});
4152 };
4153 return true;
4154 }
4155
4156 /// Match an FSHL or FSHR that can be combined to a ROTR or ROTL rotate.
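/// E.g. (illustrative): G_FSHL %x, %x, %amt --> G_ROTL %x, %amt.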
4157 bool CombinerHelper::matchFunnelShiftToRotate(MachineInstr &MI) {
4158 unsigned Opc = MI.getOpcode();
4159 assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR);
4160 Register X = MI.getOperand(1).getReg();
4161 Register Y = MI.getOperand(2).getReg();
4162 if (X != Y)
4163 return false;
4164 unsigned RotateOpc =
4165 Opc == TargetOpcode::G_FSHL ? TargetOpcode::G_ROTL : TargetOpcode::G_ROTR;
4166 return isLegalOrBeforeLegalizer({RotateOpc, {MRI.getType(X), MRI.getType(Y)}});
4167 }
4168
4169 void CombinerHelper::applyFunnelShiftToRotate(MachineInstr &MI) {
4170 unsigned Opc = MI.getOpcode();
4171 assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR);
4172 bool IsFSHL = Opc == TargetOpcode::G_FSHL;
4173 Observer.changingInstr(MI);
4174 MI.setDesc(Builder.getTII().get(IsFSHL ? TargetOpcode::G_ROTL
4175 : TargetOpcode::G_ROTR));
4176 MI.removeOperand(2);
4177 Observer.changedInstr(MI);
4178 }
4179
4180 // Fold (rot x, c) -> (rot x, c % BitSize)
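// E.g. (illustrative, s32): G_ROTL %x, 36 --> G_ROTL %x, 4.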
4181 bool CombinerHelper::matchRotateOutOfRange(MachineInstr &MI) {
4182 assert(MI.getOpcode() == TargetOpcode::G_ROTL ||
4183 MI.getOpcode() == TargetOpcode::G_ROTR);
4184 unsigned Bitsize =
4185 MRI.getType(MI.getOperand(0).getReg()).getScalarSizeInBits();
4186 Register AmtReg = MI.getOperand(2).getReg();
4187 bool OutOfRange = false;
4188 auto MatchOutOfRange = [Bitsize, &OutOfRange](const Constant *C) {
4189 if (auto *CI = dyn_cast<ConstantInt>(C))
4190 OutOfRange |= CI->getValue().uge(Bitsize);
4191 return true;
4192 };
4193 return matchUnaryPredicate(MRI, AmtReg, MatchOutOfRange) && OutOfRange;
4194 }
4195
4196 void CombinerHelper::applyRotateOutOfRange(MachineInstr &MI) {
4197 assert(MI.getOpcode() == TargetOpcode::G_ROTL ||
4198 MI.getOpcode() == TargetOpcode::G_ROTR);
4199 unsigned Bitsize =
4200 MRI.getType(MI.getOperand(0).getReg()).getScalarSizeInBits();
4201 Builder.setInstrAndDebugLoc(MI);
4202 Register Amt = MI.getOperand(2).getReg();
4203 LLT AmtTy = MRI.getType(Amt);
4204 auto Bits = Builder.buildConstant(AmtTy, Bitsize);
4205 Amt = Builder.buildURem(AmtTy, MI.getOperand(2).getReg(), Bits).getReg(0);
4206 Observer.changingInstr(MI);
4207 MI.getOperand(2).setReg(Amt);
4208 Observer.changedInstr(MI);
4209 }
4210
4211 bool CombinerHelper::matchICmpToTrueFalseKnownBits(MachineInstr &MI,
4212 int64_t &MatchInfo) {
4213 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
4214 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
4215 auto KnownLHS = KB->getKnownBits(MI.getOperand(2).getReg());
4216 auto KnownRHS = KB->getKnownBits(MI.getOperand(3).getReg());
4217 std::optional<bool> KnownVal;
4218 switch (Pred) {
4219 default:
4220 llvm_unreachable("Unexpected G_ICMP predicate?");
4221 case CmpInst::ICMP_EQ:
4222 KnownVal = KnownBits::eq(KnownLHS, KnownRHS);
4223 break;
4224 case CmpInst::ICMP_NE:
4225 KnownVal = KnownBits::ne(KnownLHS, KnownRHS);
4226 break;
4227 case CmpInst::ICMP_SGE:
4228 KnownVal = KnownBits::sge(KnownLHS, KnownRHS);
4229 break;
4230 case CmpInst::ICMP_SGT:
4231 KnownVal = KnownBits::sgt(KnownLHS, KnownRHS);
4232 break;
4233 case CmpInst::ICMP_SLE:
4234 KnownVal = KnownBits::sle(KnownLHS, KnownRHS);
4235 break;
4236 case CmpInst::ICMP_SLT:
4237 KnownVal = KnownBits::slt(KnownLHS, KnownRHS);
4238 break;
4239 case CmpInst::ICMP_UGE:
4240 KnownVal = KnownBits::uge(KnownLHS, KnownRHS);
4241 break;
4242 case CmpInst::ICMP_UGT:
4243 KnownVal = KnownBits::ugt(KnownLHS, KnownRHS);
4244 break;
4245 case CmpInst::ICMP_ULE:
4246 KnownVal = KnownBits::ule(KnownLHS, KnownRHS);
4247 break;
4248 case CmpInst::ICMP_ULT:
4249 KnownVal = KnownBits::ult(KnownLHS, KnownRHS);
4250 break;
4251 }
4252 if (!KnownVal)
4253 return false;
4254 MatchInfo =
4255 *KnownVal
4256 ? getICmpTrueVal(getTargetLowering(),
4257 /*IsVector = */
4258 MRI.getType(MI.getOperand(0).getReg()).isVector(),
4259 /* IsFP = */ false)
4260 : 0;
4261 return true;
4262 }
4263
4264 bool CombinerHelper::matchICmpToLHSKnownBits(
4265 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4266 assert(MI.getOpcode() == TargetOpcode::G_ICMP);
4267 // Given:
4268 //
4269 // %x = G_WHATEVER (... x is known to be 0 or 1 ...)
4270 // %cmp = G_ICMP ne %x, 0
4271 //
4272 // Or:
4273 //
4274 // %x = G_WHATEVER (... x is known to be 0 or 1 ...)
4275 // %cmp = G_ICMP eq %x, 1
4276 //
4277 // We can replace %cmp with %x assuming true is 1 on the target.
4278 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
4279 if (!CmpInst::isEquality(Pred))
4280 return false;
4281 Register Dst = MI.getOperand(0).getReg();
4282 LLT DstTy = MRI.getType(Dst);
4283 if (getICmpTrueVal(getTargetLowering(), DstTy.isVector(),
4284 /* IsFP = */ false) != 1)
4285 return false;
4286 int64_t OneOrZero = Pred == CmpInst::ICMP_EQ;
4287 if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICst(OneOrZero)))
4288 return false;
4289 Register LHS = MI.getOperand(2).getReg();
4290 auto KnownLHS = KB->getKnownBits(LHS);
4291 if (KnownLHS.getMinValue() != 0 || KnownLHS.getMaxValue() != 1)
4292 return false;
4293 // Make sure replacing Dst with the LHS is a legal operation.
4294 LLT LHSTy = MRI.getType(LHS);
4295 unsigned LHSSize = LHSTy.getSizeInBits();
4296 unsigned DstSize = DstTy.getSizeInBits();
4297 unsigned Op = TargetOpcode::COPY;
4298 if (DstSize != LHSSize)
4299 Op = DstSize < LHSSize ? TargetOpcode::G_TRUNC : TargetOpcode::G_ZEXT;
4300 if (!isLegalOrBeforeLegalizer({Op, {DstTy, LHSTy}}))
4301 return false;
4302 MatchInfo = [=](MachineIRBuilder &B) { B.buildInstr(Op, {Dst}, {LHS}); };
4303 return true;
4304 }
4305
4306 // Replace (and (or x, c1), c2) with (and x, c2) iff c1 & c2 == 0
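// E.g. (illustrative): (and (or x, 0xF0), 0x0F) --> (and x, 0x0F).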
4307 bool CombinerHelper::matchAndOrDisjointMask(
4308 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4309 assert(MI.getOpcode() == TargetOpcode::G_AND);
4310
4311 // Ignore vector types to simplify matching the two constants.
4312 // TODO: do this for vectors and scalars via a demanded bits analysis.
4313 LLT Ty = MRI.getType(MI.getOperand(0).getReg());
4314 if (Ty.isVector())
4315 return false;
4316
4317 Register Src;
4318 Register AndMaskReg;
4319 int64_t AndMaskBits;
4320 int64_t OrMaskBits;
4321 if (!mi_match(MI, MRI,
4322 m_GAnd(m_GOr(m_Reg(Src), m_ICst(OrMaskBits)),
4323 m_all_of(m_ICst(AndMaskBits), m_Reg(AndMaskReg)))))
4324 return false;
4325
4326 // Check if OrMask could turn on any bits in Src.
4327 if (AndMaskBits & OrMaskBits)
4328 return false;
4329
4330 MatchInfo = [=, &MI](MachineIRBuilder &B) {
4331 Observer.changingInstr(MI);
4332 // Canonicalize the result to have the constant on the RHS.
4333 if (MI.getOperand(1).getReg() == AndMaskReg)
4334 MI.getOperand(2).setReg(AndMaskReg);
4335 MI.getOperand(1).setReg(Src);
4336 Observer.changedInstr(MI);
4337 };
4338 return true;
4339 }
4340
4341 /// Form a G_SBFX from a G_SEXT_INREG fed by a right shift.
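/// For example, on a 32-bit scalar, (G_SEXT_INREG (G_LSHR x, 8), 16) reads
/// the 16 bits of x starting at bit 8 and sign-extends them, which is exactly
/// G_SBFX x, 8, 16.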
4342 bool CombinerHelper::matchBitfieldExtractFromSExtInReg(
4343 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4344 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
4345 Register Dst = MI.getOperand(0).getReg();
4346 Register Src = MI.getOperand(1).getReg();
4347 LLT Ty = MRI.getType(Src);
4348 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
4349 if (!LI || !LI->isLegalOrCustom({TargetOpcode::G_SBFX, {Ty, ExtractTy}}))
4350 return false;
4351 int64_t Width = MI.getOperand(2).getImm();
4352 Register ShiftSrc;
4353 int64_t ShiftImm;
4354 if (!mi_match(
4355 Src, MRI,
4356 m_OneNonDBGUse(m_any_of(m_GAShr(m_Reg(ShiftSrc), m_ICst(ShiftImm)),
4357 m_GLShr(m_Reg(ShiftSrc), m_ICst(ShiftImm))))))
4358 return false;
4359 if (ShiftImm < 0 || ShiftImm + Width > Ty.getScalarSizeInBits())
4360 return false;
4361
4362 MatchInfo = [=](MachineIRBuilder &B) {
4363 auto Cst1 = B.buildConstant(ExtractTy, ShiftImm);
4364 auto Cst2 = B.buildConstant(ExtractTy, Width);
4365 B.buildSbfx(Dst, ShiftSrc, Cst1, Cst2);
4366 };
4367 return true;
4368 }
4369
4370 /// Form a G_UBFX from "(a srl b) & mask", where b and mask are constants.
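/// For example, on a 32-bit scalar, ((x lshr 4) & 0xFF) keeps the 8 bits of x
/// starting at bit 4 and is equivalent to G_UBFX x, 4, 8 (the mask 0xFF has
/// 8 trailing ones).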
4371 bool CombinerHelper::matchBitfieldExtractFromAnd(
4372 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4373 assert(MI.getOpcode() == TargetOpcode::G_AND);
4374 Register Dst = MI.getOperand(0).getReg();
4375 LLT Ty = MRI.getType(Dst);
4376 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
4377 if (!getTargetLowering().isConstantUnsignedBitfieldExtractLegal(
4378 TargetOpcode::G_UBFX, Ty, ExtractTy))
4379 return false;
4380
4381 int64_t AndImm, LSBImm;
4382 Register ShiftSrc;
4383 const unsigned Size = Ty.getScalarSizeInBits();
4384 if (!mi_match(MI.getOperand(0).getReg(), MRI,
4385 m_GAnd(m_OneNonDBGUse(m_GLShr(m_Reg(ShiftSrc), m_ICst(LSBImm))),
4386 m_ICst(AndImm))))
4387 return false;
4388
4389 // The mask is a mask of the low bits iff imm & (imm+1) == 0.
4390 auto MaybeMask = static_cast<uint64_t>(AndImm);
4391 if (MaybeMask & (MaybeMask + 1))
4392 return false;
4393
4394 // LSB must fit within the register.
4395 if (static_cast<uint64_t>(LSBImm) >= Size)
4396 return false;
4397
4398 uint64_t Width = APInt(Size, AndImm).countTrailingOnes();
4399 MatchInfo = [=](MachineIRBuilder &B) {
4400 auto WidthCst = B.buildConstant(ExtractTy, Width);
4401 auto LSBCst = B.buildConstant(ExtractTy, LSBImm);
4402 B.buildInstr(TargetOpcode::G_UBFX, {Dst}, {ShiftSrc, LSBCst, WidthCst});
4403 };
4404 return true;
4405 }
4406
4407 bool CombinerHelper::matchBitfieldExtractFromShr(
4408 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4409 const unsigned Opcode = MI.getOpcode();
4410 assert(Opcode == TargetOpcode::G_ASHR || Opcode == TargetOpcode::G_LSHR);
4411
4412 const Register Dst = MI.getOperand(0).getReg();
4413
4414 const unsigned ExtrOpcode = Opcode == TargetOpcode::G_ASHR
4415 ? TargetOpcode::G_SBFX
4416 : TargetOpcode::G_UBFX;
4417
4418 // Check if the type we would use for the extract is legal
4419 LLT Ty = MRI.getType(Dst);
4420 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
4421 if (!LI || !LI->isLegalOrCustom({ExtrOpcode, {Ty, ExtractTy}}))
4422 return false;
4423
4424 Register ShlSrc;
4425 int64_t ShrAmt;
4426 int64_t ShlAmt;
4427 const unsigned Size = Ty.getScalarSizeInBits();
4428
4429 // Try to match shr (shl x, c1), c2
4430 if (!mi_match(Dst, MRI,
4431 m_BinOp(Opcode,
4432 m_OneNonDBGUse(m_GShl(m_Reg(ShlSrc), m_ICst(ShlAmt))),
4433 m_ICst(ShrAmt))))
4434 return false;
4435
4436 // Make sure that the shift sizes can fit a bitfield extract
4437 if (ShlAmt < 0 || ShlAmt > ShrAmt || ShrAmt >= Size)
4438 return false;
4439
4440 // Skip this combine if the G_SEXT_INREG combine could handle it
4441 if (Opcode == TargetOpcode::G_ASHR && ShlAmt == ShrAmt)
4442 return false;
4443
4444 // Calculate start position and width of the extract
4445 const int64_t Pos = ShrAmt - ShlAmt;
4446 const int64_t Width = Size - ShrAmt;
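// For example, on a 32-bit scalar, (lshr (shl x, 8), 20) leaves bits
// [12, 23] of x in the low 12 bits of the result, i.e. G_UBFX x, 12, 12
// (Pos = 20 - 8, Width = 32 - 20).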
4447
4448 MatchInfo = [=](MachineIRBuilder &B) {
4449 auto WidthCst = B.buildConstant(ExtractTy, Width);
4450 auto PosCst = B.buildConstant(ExtractTy, Pos);
4451 B.buildInstr(ExtrOpcode, {Dst}, {ShlSrc, PosCst, WidthCst});
4452 };
4453 return true;
4454 }
4455
4456 bool CombinerHelper::matchBitfieldExtractFromShrAnd(
4457 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4458 const unsigned Opcode = MI.getOpcode();
4459 assert(Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_ASHR);
4460
4461 const Register Dst = MI.getOperand(0).getReg();
4462 LLT Ty = MRI.getType(Dst);
4463 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
4464 if (!getTargetLowering().isConstantUnsignedBitfieldExtractLegal(
4465 TargetOpcode::G_UBFX, Ty, ExtractTy))
4466 return false;
4467
4468 // Try to match shr (and x, c1), c2
4469 Register AndSrc;
4470 int64_t ShrAmt;
4471 int64_t SMask;
4472 if (!mi_match(Dst, MRI,
4473 m_BinOp(Opcode,
4474 m_OneNonDBGUse(m_GAnd(m_Reg(AndSrc), m_ICst(SMask))),
4475 m_ICst(ShrAmt))))
4476 return false;
4477
4478 const unsigned Size = Ty.getScalarSizeInBits();
4479 if (ShrAmt < 0 || ShrAmt >= Size)
4480 return false;
4481
4482 // If the shift subsumes the mask, emit the 0 directly.
4483 if (0 == (SMask >> ShrAmt)) {
4484 MatchInfo = [=](MachineIRBuilder &B) {
4485 B.buildConstant(Dst, 0);
4486 };
4487 return true;
4488 }
4489
4490 // Check that ubfx can do the extraction, with no holes in the mask.
4491 uint64_t UMask = SMask;
4492 UMask |= maskTrailingOnes<uint64_t>(ShrAmt);
4493 UMask &= maskTrailingOnes<uint64_t>(Size);
4494 if (!isMask_64(UMask))
4495 return false;
4496
4497 // Calculate start position and width of the extract.
4498 const int64_t Pos = ShrAmt;
4499 const int64_t Width = countTrailingOnes(UMask) - ShrAmt;
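// For example, on a 32-bit scalar, (lshr (and x, 0x0FF0), 4) has
// UMask = 0x0FFF, so it becomes G_UBFX x, 4, 8 (Pos = 4, Width = 12 - 4).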
4500
4501 // It's preferable to keep the shift, rather than form G_SBFX.
4502 // TODO: remove the G_AND via demanded bits analysis.
4503 if (Opcode == TargetOpcode::G_ASHR && Width + ShrAmt == Size)
4504 return false;
4505
4506 MatchInfo = [=](MachineIRBuilder &B) {
4507 auto WidthCst = B.buildConstant(ExtractTy, Width);
4508 auto PosCst = B.buildConstant(ExtractTy, Pos);
4509 B.buildInstr(TargetOpcode::G_UBFX, {Dst}, {AndSrc, PosCst, WidthCst});
4510 };
4511 return true;
4512 }
4513
4514 bool CombinerHelper::reassociationCanBreakAddressingModePattern(
4515 MachineInstr &PtrAdd) {
4516 assert(PtrAdd.getOpcode() == TargetOpcode::G_PTR_ADD);
4517
4518 Register Src1Reg = PtrAdd.getOperand(1).getReg();
4519 MachineInstr *Src1Def = getOpcodeDef(TargetOpcode::G_PTR_ADD, Src1Reg, MRI);
4520 if (!Src1Def)
4521 return false;
4522
4523 Register Src2Reg = PtrAdd.getOperand(2).getReg();
4524
4525 if (MRI.hasOneNonDBGUse(Src1Reg))
4526 return false;
4527
4528 auto C1 = getIConstantVRegVal(Src1Def->getOperand(2).getReg(), MRI);
4529 if (!C1)
4530 return false;
4531 auto C2 = getIConstantVRegVal(Src2Reg, MRI);
4532 if (!C2)
4533 return false;
4534
4535 const APInt &C1APIntVal = *C1;
4536 const APInt &C2APIntVal = *C2;
4537 const int64_t CombinedValue = (C1APIntVal + C2APIntVal).getSExtValue();
4538
4539 for (auto &UseMI : MRI.use_nodbg_instructions(Src1Reg)) {
4540 // This combine may end up running before ptrtoint/inttoptr combines
4541 // manage to eliminate redundant conversions, so try to look through them.
4542 MachineInstr *ConvUseMI = &UseMI;
4543 unsigned ConvUseOpc = ConvUseMI->getOpcode();
4544 while (ConvUseOpc == TargetOpcode::G_INTTOPTR ||
4545 ConvUseOpc == TargetOpcode::G_PTRTOINT) {
4546 Register DefReg = ConvUseMI->getOperand(0).getReg();
4547 if (!MRI.hasOneNonDBGUse(DefReg))
4548 break;
4549 ConvUseMI = &*MRI.use_instr_nodbg_begin(DefReg);
4550 ConvUseOpc = ConvUseMI->getOpcode();
4551 }
4552 auto LoadStore = ConvUseOpc == TargetOpcode::G_LOAD ||
4553 ConvUseOpc == TargetOpcode::G_STORE;
4554 if (!LoadStore)
4555 continue;
4556 // Is x[offset2] already not a legal addressing mode? If so then
4557 // reassociating the constants breaks nothing (we test offset2 because
4558 // that's the one we hope to fold into the load or store).
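// For example, on a hypothetical target that only folds a signed 9-bit
// immediate offset into loads and stores, offset2 = 4 folds fine, but a
// combined offset1 + offset2 = 4096 + 4 = 4100 would not, so reassociating
// would be harmful and this function returns true.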
4559 TargetLoweringBase::AddrMode AM;
4560 AM.HasBaseReg = true;
4561 AM.BaseOffs = C2APIntVal.getSExtValue();
4562 unsigned AS =
4563 MRI.getType(ConvUseMI->getOperand(1).getReg()).getAddressSpace();
4564 Type *AccessTy =
4565 getTypeForLLT(MRI.getType(ConvUseMI->getOperand(0).getReg()),
4566 PtrAdd.getMF()->getFunction().getContext());
4567 const auto &TLI = *PtrAdd.getMF()->getSubtarget().getTargetLowering();
4568 if (!TLI.isLegalAddressingMode(PtrAdd.getMF()->getDataLayout(), AM,
4569 AccessTy, AS))
4570 continue;
4571
4572 // Would x[offset1+offset2] still be a legal addressing mode?
4573 AM.BaseOffs = CombinedValue;
4574 if (!TLI.isLegalAddressingMode(PtrAdd.getMF()->getDataLayout(), AM,
4575 AccessTy, AS))
4576 return true;
4577 }
4578
4579 return false;
4580 }
4581
4582 bool CombinerHelper::matchReassocConstantInnerRHS(GPtrAdd &MI,
4583 MachineInstr *RHS,
4584 BuildFnTy &MatchInfo) {
4585 // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C)
4586 Register Src1Reg = MI.getOperand(1).getReg();
4587 if (RHS->getOpcode() != TargetOpcode::G_ADD)
4588 return false;
4589 auto C2 = getIConstantVRegVal(RHS->getOperand(2).getReg(), MRI);
4590 if (!C2)
4591 return false;
4592
4593 MatchInfo = [=, &MI](MachineIRBuilder &B) {
4594 LLT PtrTy = MRI.getType(MI.getOperand(0).getReg());
4595
4596 auto NewBase =
4597 Builder.buildPtrAdd(PtrTy, Src1Reg, RHS->getOperand(1).getReg());
4598 Observer.changingInstr(MI);
4599 MI.getOperand(1).setReg(NewBase.getReg(0));
4600 MI.getOperand(2).setReg(RHS->getOperand(2).getReg());
4601 Observer.changedInstr(MI);
4602 };
4603 return !reassociationCanBreakAddressingModePattern(MI);
4604 }
4605
4606 bool CombinerHelper::matchReassocConstantInnerLHS(GPtrAdd &MI,
4607 MachineInstr *LHS,
4608 MachineInstr *RHS,
4609 BuildFnTy &MatchInfo) {
4610 // G_PTR_ADD(G_PTR_ADD(X, C), Y) -> G_PTR_ADD(G_PTR_ADD(X, Y), C)
4611 // if and only if (G_PTR_ADD X, C) has one use.
4612 Register LHSBase;
4613 std::optional<ValueAndVReg> LHSCstOff;
4614 if (!mi_match(MI.getBaseReg(), MRI,
4615 m_OneNonDBGUse(m_GPtrAdd(m_Reg(LHSBase), m_GCst(LHSCstOff)))))
4616 return false;
4617
4618 auto *LHSPtrAdd = cast<GPtrAdd>(LHS);
4619 MatchInfo = [=, &MI](MachineIRBuilder &B) {
4620 // When we change LHSPtrAdd's offset register we might cause it to use a reg
4621 // before its def. Sink the instruction before the outer PTR_ADD to ensure
4622 // this doesn't happen.
4623 LHSPtrAdd->moveBefore(&MI);
4624 Register RHSReg = MI.getOffsetReg();
4625 // Reusing the register directly could cause a type mismatch if it comes from an extend/trunc; build a fresh constant of the RHS's type instead.
4626 auto NewCst = B.buildConstant(MRI.getType(RHSReg), LHSCstOff->Value);
4627 Observer.changingInstr(MI);
4628 MI.getOperand(2).setReg(NewCst.getReg(0));
4629 Observer.changedInstr(MI);
4630 Observer.changingInstr(*LHSPtrAdd);
4631 LHSPtrAdd->getOperand(2).setReg(RHSReg);
4632 Observer.changedInstr(*LHSPtrAdd);
4633 };
4634 return !reassociationCanBreakAddressingModePattern(MI);
4635 }
4636
4637 bool CombinerHelper::matchReassocFoldConstantsInSubTree(GPtrAdd &MI,
4638 MachineInstr *LHS,
4639 MachineInstr *RHS,
4640 BuildFnTy &MatchInfo) {
4641 // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2)
4642 auto *LHSPtrAdd = dyn_cast<GPtrAdd>(LHS);
4643 if (!LHSPtrAdd)
4644 return false;
4645
4646 Register Src2Reg = MI.getOperand(2).getReg();
4647 Register LHSSrc1 = LHSPtrAdd->getBaseReg();
4648 Register LHSSrc2 = LHSPtrAdd->getOffsetReg();
4649 auto C1 = getIConstantVRegVal(LHSSrc2, MRI);
4650 if (!C1)
4651 return false;
4652 auto C2 = getIConstantVRegVal(Src2Reg, MRI);
4653 if (!C2)
4654 return false;
4655
4656 MatchInfo = [=, &MI](MachineIRBuilder &B) {
4657 auto NewCst = B.buildConstant(MRI.getType(Src2Reg), *C1 + *C2);
4658 Observer.changingInstr(MI);
4659 MI.getOperand(1).setReg(LHSSrc1);
4660 MI.getOperand(2).setReg(NewCst.getReg(0));
4661 Observer.changedInstr(MI);
4662 };
4663 return !reassociationCanBreakAddressingModePattern(MI);
4664 }
4665
4666 bool CombinerHelper::matchReassocPtrAdd(MachineInstr &MI,
4667 BuildFnTy &MatchInfo) {
4668 auto &PtrAdd = cast<GPtrAdd>(MI);
4669 // We're trying to match a few pointer computation patterns here for
4670 // re-association opportunities.
4671 // 1) Isolating a constant operand to be on the RHS, e.g.:
4672 // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C)
4673 //
4674 // 2) Folding two constants in each sub-tree as long as such folding
4675 // doesn't break a legal addressing mode.
4676 // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2)
4677 //
4678 // 3) Move a constant from the LHS of an inner op to the RHS of the outer.
4679 // G_PTR_ADD(G_PTR_ADD(X, C), Y) -> G_PTR_ADD(G_PTR_ADD(X, Y), C)
4680 // iff (G_PTR_ADD X, C) has one use.
4681 MachineInstr *LHS = MRI.getVRegDef(PtrAdd.getBaseReg());
4682 MachineInstr *RHS = MRI.getVRegDef(PtrAdd.getOffsetReg());
4683
4684 // Try to match example 2.
4685 if (matchReassocFoldConstantsInSubTree(PtrAdd, LHS, RHS, MatchInfo))
4686 return true;
4687
4688 // Try to match example 3.
4689 if (matchReassocConstantInnerLHS(PtrAdd, LHS, RHS, MatchInfo))
4690 return true;
4691
4692 // Try to match example 1.
4693 if (matchReassocConstantInnerRHS(PtrAdd, RHS, MatchInfo))
4694 return true;
4695
4696 return false;
4697 }
4698
4699 bool CombinerHelper::matchConstantFold(MachineInstr &MI, APInt &MatchInfo) {
4700 Register Op1 = MI.getOperand(1).getReg();
4701 Register Op2 = MI.getOperand(2).getReg();
4702 auto MaybeCst = ConstantFoldBinOp(MI.getOpcode(), Op1, Op2, MRI);
4703 if (!MaybeCst)
4704 return false;
4705 MatchInfo = *MaybeCst;
4706 return true;
4707 }
4708
4709 bool CombinerHelper::matchNarrowBinopFeedingAnd(
4710 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
4711 // Look for a binop feeding into an AND with a mask:
4712 //
4713 // %add = G_ADD %lhs, %rhs
4714 // %and = G_AND %add, 000...11111111
4715 //
4716 // Check if it's possible to perform the binop at a narrower width and zext
4717 // back to the original width like so:
4718 //
4719 // %narrow_lhs = G_TRUNC %lhs
4720 // %narrow_rhs = G_TRUNC %rhs
4721 // %narrow_add = G_ADD %narrow_lhs, %narrow_rhs
4722 // %new_add = G_ZEXT %narrow_add
4723 // %and = G_AND %new_add, 000...11111111
4724 //
4725 // This can allow later combines to eliminate the G_AND if it turns out
4726 // that the mask is irrelevant.
4727 assert(MI.getOpcode() == TargetOpcode::G_AND);
4728 Register Dst = MI.getOperand(0).getReg();
4729 Register AndLHS = MI.getOperand(1).getReg();
4730 Register AndRHS = MI.getOperand(2).getReg();
4731 LLT WideTy = MRI.getType(Dst);
4732
4733 // If the potential binop has more than one use, then it's possible that one
4734 // of those uses will need its full width.
4735 if (!WideTy.isScalar() || !MRI.hasOneNonDBGUse(AndLHS))
4736 return false;
4737
4738 // Check if the LHS feeding the AND is impacted by the high bits that we're
4739 // masking out.
4740 //
4741 // e.g. for 64-bit x, y:
4742 //
4743 // add_64(x, y) & 65535 == zext(add_16(trunc(x), trunc(y))) & 65535
4744 MachineInstr *LHSInst = getDefIgnoringCopies(AndLHS, MRI);
4745 if (!LHSInst)
4746 return false;
4747 unsigned LHSOpc = LHSInst->getOpcode();
4748 switch (LHSOpc) {
4749 default:
4750 return false;
4751 case TargetOpcode::G_ADD:
4752 case TargetOpcode::G_SUB:
4753 case TargetOpcode::G_MUL:
4754 case TargetOpcode::G_AND:
4755 case TargetOpcode::G_OR:
4756 case TargetOpcode::G_XOR:
4757 break;
4758 }
4759
4760 // Find the mask on the RHS.
4761 auto Cst = getIConstantVRegValWithLookThrough(AndRHS, MRI);
4762 if (!Cst)
4763 return false;
4764 auto Mask = Cst->Value;
4765 if (!Mask.isMask())
4766 return false;
4767
4768 // No point in combining if there's nothing to truncate.
4769 unsigned NarrowWidth = Mask.countTrailingOnes();
4770 if (NarrowWidth == WideTy.getSizeInBits())
4771 return false;
4772 LLT NarrowTy = LLT::scalar(NarrowWidth);
4773
4774 // Check if adding the zext + truncates could be harmful.
4775 auto &MF = *MI.getMF();
4776 const auto &TLI = getTargetLowering();
4777 LLVMContext &Ctx = MF.getFunction().getContext();
4778 auto &DL = MF.getDataLayout();
4779 if (!TLI.isTruncateFree(WideTy, NarrowTy, DL, Ctx) ||
4780 !TLI.isZExtFree(NarrowTy, WideTy, DL, Ctx))
4781 return false;
4782 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {NarrowTy, WideTy}}) ||
4783 !isLegalOrBeforeLegalizer({TargetOpcode::G_ZEXT, {WideTy, NarrowTy}}))
4784 return false;
4785 Register BinOpLHS = LHSInst->getOperand(1).getReg();
4786 Register BinOpRHS = LHSInst->getOperand(2).getReg();
4787 MatchInfo = [=, &MI](MachineIRBuilder &B) {
4788 auto NarrowLHS = Builder.buildTrunc(NarrowTy, BinOpLHS);
4789 auto NarrowRHS = Builder.buildTrunc(NarrowTy, BinOpRHS);
4790 auto NarrowBinOp =
4791 Builder.buildInstr(LHSOpc, {NarrowTy}, {NarrowLHS, NarrowRHS});
4792 auto Ext = Builder.buildZExt(WideTy, NarrowBinOp);
4793 Observer.changingInstr(MI);
4794 MI.getOperand(1).setReg(Ext.getReg(0));
4795 Observer.changedInstr(MI);
4796 };
4797 return true;
4798 }
4799
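// (G_UMULO x, 2) -> (G_UADDO x, x) and (G_SMULO x, 2) -> (G_SADDO x, x):
// multiplying by 2 and adding a value to itself produce the same result and
// the same overflow flag.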
4800 bool CombinerHelper::matchMulOBy2(MachineInstr &MI, BuildFnTy &MatchInfo) {
4801 unsigned Opc = MI.getOpcode();
4802 assert(Opc == TargetOpcode::G_UMULO || Opc == TargetOpcode::G_SMULO);
4803
4804 if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICstOrSplat(2)))
4805 return false;
4806
4807 MatchInfo = [=, &MI](MachineIRBuilder &B) {
4808 Observer.changingInstr(MI);
4809 unsigned NewOpc = Opc == TargetOpcode::G_UMULO ? TargetOpcode::G_UADDO
4810 : TargetOpcode::G_SADDO;
4811 MI.setDesc(Builder.getTII().get(NewOpc));
4812 MI.getOperand(3).setReg(MI.getOperand(2).getReg());
4813 Observer.changedInstr(MI);
4814 };
4815 return true;
4816 }
4817
4818 bool CombinerHelper::matchMulOBy0(MachineInstr &MI, BuildFnTy &MatchInfo) {
4819 // (G_*MULO x, 0) -> 0 + no carry out
4820 assert(MI.getOpcode() == TargetOpcode::G_UMULO ||
4821 MI.getOpcode() == TargetOpcode::G_SMULO);
4822 if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICstOrSplat(0)))
4823 return false;
4824 Register Dst = MI.getOperand(0).getReg();
4825 Register Carry = MI.getOperand(1).getReg();
4826 if (!isConstantLegalOrBeforeLegalizer(MRI.getType(Dst)) ||
4827 !isConstantLegalOrBeforeLegalizer(MRI.getType(Carry)))
4828 return false;
4829 MatchInfo = [=](MachineIRBuilder &B) {
4830 B.buildConstant(Dst, 0);
4831 B.buildConstant(Carry, 0);
4832 };
4833 return true;
4834 }
4835
4836 bool CombinerHelper::matchAddOBy0(MachineInstr &MI, BuildFnTy &MatchInfo) {
4837 // (G_*ADDO x, 0) -> x + no carry out
4838 assert(MI.getOpcode() == TargetOpcode::G_UADDO ||
4839 MI.getOpcode() == TargetOpcode::G_SADDO);
4840 if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICstOrSplat(0)))
4841 return false;
4842 Register Carry = MI.getOperand(1).getReg();
4843 if (!isConstantLegalOrBeforeLegalizer(MRI.getType(Carry)))
4844 return false;
4845 Register Dst = MI.getOperand(0).getReg();
4846 Register LHS = MI.getOperand(2).getReg();
4847 MatchInfo = [=](MachineIRBuilder &B) {
4848 B.buildCopy(Dst, LHS);
4849 B.buildConstant(Carry, 0);
4850 };
4851 return true;
4852 }
4853
4854 bool CombinerHelper::matchAddEToAddO(MachineInstr &MI, BuildFnTy &MatchInfo) {
4855 // (G_*ADDE x, y, 0) -> (G_*ADDO x, y)
4856 // (G_*SUBE x, y, 0) -> (G_*SUBO x, y)
4857 assert(MI.getOpcode() == TargetOpcode::G_UADDE ||
4858 MI.getOpcode() == TargetOpcode::G_SADDE ||
4859 MI.getOpcode() == TargetOpcode::G_USUBE ||
4860 MI.getOpcode() == TargetOpcode::G_SSUBE);
4861 if (!mi_match(MI.getOperand(4).getReg(), MRI, m_SpecificICstOrSplat(0)))
4862 return false;
4863 MatchInfo = [&](MachineIRBuilder &B) {
4864 unsigned NewOpcode;
4865 switch (MI.getOpcode()) {
4866 case TargetOpcode::G_UADDE:
4867 NewOpcode = TargetOpcode::G_UADDO;
4868 break;
4869 case TargetOpcode::G_SADDE:
4870 NewOpcode = TargetOpcode::G_SADDO;
4871 break;
4872 case TargetOpcode::G_USUBE:
4873 NewOpcode = TargetOpcode::G_USUBO;
4874 break;
4875 case TargetOpcode::G_SSUBE:
4876 NewOpcode = TargetOpcode::G_SSUBO;
4877 break;
4878 }
4879 Observer.changingInstr(MI);
4880 MI.setDesc(B.getTII().get(NewOpcode));
4881 MI.removeOperand(4);
4882 Observer.changedInstr(MI);
4883 };
4884 return true;
4885 }
4886
4887 bool CombinerHelper::matchSubAddSameReg(MachineInstr &MI,
4888 BuildFnTy &MatchInfo) {
4889 assert(MI.getOpcode() == TargetOpcode::G_SUB);
4890 Register Dst = MI.getOperand(0).getReg();
4891 // (x + y) - z -> x (if y == z)
4892 // (x + y) - z -> y (if x == z)
4893 Register X, Y, Z;
4894 if (mi_match(Dst, MRI, m_GSub(m_GAdd(m_Reg(X), m_Reg(Y)), m_Reg(Z)))) {
4895 Register ReplaceReg;
4896 int64_t CstX, CstY;
4897 if (Y == Z || (mi_match(Y, MRI, m_ICstOrSplat(CstY)) &&
4898 mi_match(Z, MRI, m_SpecificICstOrSplat(CstY))))
4899 ReplaceReg = X;
4900 else if (X == Z || (mi_match(X, MRI, m_ICstOrSplat(CstX)) &&
4901 mi_match(Z, MRI, m_SpecificICstOrSplat(CstX))))
4902 ReplaceReg = Y;
4903 if (ReplaceReg) {
4904 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, ReplaceReg); };
4905 return true;
4906 }
4907 }
4908
4909 // x - (y + z) -> 0 - y (if x == z)
4910 // x - (y + z) -> 0 - z (if x == y)
4911 if (mi_match(Dst, MRI, m_GSub(m_Reg(X), m_GAdd(m_Reg(Y), m_Reg(Z))))) {
4912 Register ReplaceReg;
4913 int64_t CstX;
4914 if (X == Z || (mi_match(X, MRI, m_ICstOrSplat(CstX)) &&
4915 mi_match(Z, MRI, m_SpecificICstOrSplat(CstX))))
4916 ReplaceReg = Y;
4917 else if (X == Y || (mi_match(X, MRI, m_ICstOrSplat(CstX)) &&
4918 mi_match(Y, MRI, m_SpecificICstOrSplat(CstX))))
4919 ReplaceReg = Z;
4920 if (ReplaceReg) {
4921 MatchInfo = [=](MachineIRBuilder &B) {
4922 auto Zero = B.buildConstant(MRI.getType(Dst), 0);
4923 B.buildSub(Dst, Zero, ReplaceReg);
4924 };
4925 return true;
4926 }
4927 }
4928 return false;
4929 }
4930
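/// Expand G_UDIV by a constant into a multiply by a "magic" fixed-point
/// reciprocal followed by shifts (the classic magic-number technique from
/// Hacker's Delight). For example, for a 32-bit x, x / 3 can be computed as
/// (G_UMULH x, 0xAAAAAAAB) >> 1, since ceil(2^33 / 3) == 0xAAAAAAAB.
/// A final select handles a divisor of 1, which the algorithm cannot encode.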
4931 MachineInstr *CombinerHelper::buildUDivUsingMul(MachineInstr &MI) {
4932 assert(MI.getOpcode() == TargetOpcode::G_UDIV);
4933 auto &UDiv = cast<GenericMachineInstr>(MI);
4934 Register Dst = UDiv.getReg(0);
4935 Register LHS = UDiv.getReg(1);
4936 Register RHS = UDiv.getReg(2);
4937 LLT Ty = MRI.getType(Dst);
4938 LLT ScalarTy = Ty.getScalarType();
4939 const unsigned EltBits = ScalarTy.getScalarSizeInBits();
4940 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
4941 LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();
4942 auto &MIB = Builder;
4943 MIB.setInstrAndDebugLoc(MI);
4944
4945 bool UseNPQ = false;
4946 SmallVector<Register, 16> PreShifts, PostShifts, MagicFactors, NPQFactors;
4947
4948 auto BuildUDIVPattern = [&](const Constant *C) {
4949 auto *CI = cast<ConstantInt>(C);
4950 const APInt &Divisor = CI->getValue();
4951
4952 bool SelNPQ = false;
4953 APInt Magic(Divisor.getBitWidth(), 0);
4954 unsigned PreShift = 0, PostShift = 0;
4955
4956 // Magic algorithm doesn't work for division by 1. We need to emit a select
4957 // at the end.
4958 // TODO: Use undef values for divisor of 1.
4959 if (!Divisor.isOneValue()) {
4960 UnsignedDivisionByConstantInfo magics =
4961 UnsignedDivisionByConstantInfo::get(Divisor);
4962
4963 Magic = std::move(magics.Magic);
4964
4965 assert(magics.PreShift < Divisor.getBitWidth() &&
4966 "We shouldn't generate an undefined shift!");
4967 assert(magics.PostShift < Divisor.getBitWidth() &&
4968 "We shouldn't generate an undefined shift!");
4969 assert((!magics.IsAdd || magics.PreShift == 0) && "Unexpected pre-shift");
4970 PreShift = magics.PreShift;
4971 PostShift = magics.PostShift;
4972 SelNPQ = magics.IsAdd;
4973 }
4974
4975 PreShifts.push_back(
4976 MIB.buildConstant(ScalarShiftAmtTy, PreShift).getReg(0));
4977 MagicFactors.push_back(MIB.buildConstant(ScalarTy, Magic).getReg(0));
4978 NPQFactors.push_back(
4979 MIB.buildConstant(ScalarTy,
4980 SelNPQ ? APInt::getOneBitSet(EltBits, EltBits - 1)
4981 : APInt::getZero(EltBits))
4982 .getReg(0));
4983 PostShifts.push_back(
4984 MIB.buildConstant(ScalarShiftAmtTy, PostShift).getReg(0));
4985 UseNPQ |= SelNPQ;
4986 return true;
4987 };
4988
4989 // Collect the shifts/magic values from each element.
4990 bool Matched = matchUnaryPredicate(MRI, RHS, BuildUDIVPattern);
4991 (void)Matched;
4992 assert(Matched && "Expected unary predicate match to succeed");
4993
4994 Register PreShift, PostShift, MagicFactor, NPQFactor;
4995 auto *RHSDef = getOpcodeDef<GBuildVector>(RHS, MRI);
4996 if (RHSDef) {
4997 PreShift = MIB.buildBuildVector(ShiftAmtTy, PreShifts).getReg(0);
4998 MagicFactor = MIB.buildBuildVector(Ty, MagicFactors).getReg(0);
4999 NPQFactor = MIB.buildBuildVector(Ty, NPQFactors).getReg(0);
5000 PostShift = MIB.buildBuildVector(ShiftAmtTy, PostShifts).getReg(0);
5001 } else {
5002 assert(MRI.getType(RHS).isScalar() &&
5003 "Non-build_vector operation should have been a scalar");
5004 PreShift = PreShifts[0];
5005 MagicFactor = MagicFactors[0];
5006 PostShift = PostShifts[0];
5007 }
5008
5009 Register Q = LHS;
5010 Q = MIB.buildLShr(Ty, Q, PreShift).getReg(0);
5011
5012 // Multiply the numerator (operand 1 of the G_UDIV) by the magic value.
5013 Q = MIB.buildUMulH(Ty, Q, MagicFactor).getReg(0);
5014
5015 if (UseNPQ) {
5016 Register NPQ = MIB.buildSub(Ty, LHS, Q).getReg(0);
5017
5018 // For vectors we might have a mix of non-NPQ/NPQ paths, so use
5019 // G_UMULH to act as a SRL-by-1 for NPQ, else multiply by zero.
5020 if (Ty.isVector())
5021 NPQ = MIB.buildUMulH(Ty, NPQ, NPQFactor).getReg(0);
5022 else
5023 NPQ = MIB.buildLShr(Ty, NPQ, MIB.buildConstant(ShiftAmtTy, 1)).getReg(0);
5024
5025 Q = MIB.buildAdd(Ty, NPQ, Q).getReg(0);
5026 }
5027
5028 Q = MIB.buildLShr(Ty, Q, PostShift).getReg(0);
5029 auto One = MIB.buildConstant(Ty, 1);
5030 auto IsOne = MIB.buildICmp(
5031 CmpInst::Predicate::ICMP_EQ,
5032 Ty.isScalar() ? LLT::scalar(1) : Ty.changeElementSize(1), RHS, One);
5033 return MIB.buildSelect(Ty, IsOne, LHS, Q);
5034 }
5035
5036 bool CombinerHelper::matchUDivByConst(MachineInstr &MI) {
5037 assert(MI.getOpcode() == TargetOpcode::G_UDIV);
5038 Register Dst = MI.getOperand(0).getReg();
5039 Register RHS = MI.getOperand(2).getReg();
5040 LLT DstTy = MRI.getType(Dst);
5041 auto *RHSDef = MRI.getVRegDef(RHS);
5042 if (!isConstantOrConstantVector(*RHSDef, MRI))
5043 return false;
5044
5045 auto &MF = *MI.getMF();
5046 AttributeList Attr = MF.getFunction().getAttributes();
5047 const auto &TLI = getTargetLowering();
5048 LLVMContext &Ctx = MF.getFunction().getContext();
5049 auto &DL = MF.getDataLayout();
5050 if (TLI.isIntDivCheap(getApproximateEVTForLLT(DstTy, DL, Ctx), Attr))
5051 return false;
5052
5053 // Don't do this for minsize because the instruction sequence is usually
5054 // larger.
5055 if (MF.getFunction().hasMinSize())
5056 return false;
5057
5058 // Don't do this if the types are not going to be legal.
5059 if (LI) {
5060 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_MUL, {DstTy, DstTy}}))
5061 return false;
5062 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_UMULH, {DstTy}}))
5063 return false;
5064 if (!isLegalOrBeforeLegalizer(
5065 {TargetOpcode::G_ICMP,
5066 {DstTy.isVector() ? DstTy.changeElementSize(1) : LLT::scalar(1),
5067 DstTy}}))
5068 return false;
5069 }
5070
5071 auto CheckEltValue = [&](const Constant *C) {
5072 if (auto *CI = dyn_cast_or_null<ConstantInt>(C))
5073 return !CI->isZero();
5074 return false;
5075 };
5076 return matchUnaryPredicate(MRI, RHS, CheckEltValue);
5077 }
5078
5079 void CombinerHelper::applyUDivByConst(MachineInstr &MI) {
5080 auto *NewMI = buildUDivUsingMul(MI);
5081 replaceSingleDefInstWithReg(MI, NewMI->getOperand(0).getReg());
5082 }
5083
5084 bool CombinerHelper::matchSDivByConst(MachineInstr &MI) {
5085 assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
5086 Register Dst = MI.getOperand(0).getReg();
5087 Register RHS = MI.getOperand(2).getReg();
5088 LLT DstTy = MRI.getType(Dst);
5089
5090 auto &MF = *MI.getMF();
5091 AttributeList Attr = MF.getFunction().getAttributes();
5092 const auto &TLI = getTargetLowering();
5093 LLVMContext &Ctx = MF.getFunction().getContext();
5094 auto &DL = MF.getDataLayout();
5095 if (TLI.isIntDivCheap(getApproximateEVTForLLT(DstTy, DL, Ctx), Attr))
5096 return false;
5097
5098 // Don't do this for minsize because the instruction sequence is usually
5099 // larger.
5100 if (MF.getFunction().hasMinSize())
5101 return false;
5102
5103 // If the sdiv has an 'exact' flag we can use a simpler lowering.
5104 if (MI.getFlag(MachineInstr::MIFlag::IsExact)) {
5105 return matchUnaryPredicate(
5106 MRI, RHS, [](const Constant *C) { return C && !C->isZeroValue(); });
5107 }
5108
5109 // Don't support the general case for now.
5110 return false;
5111 }
5112
5113 void CombinerHelper::applySDivByConst(MachineInstr &MI) {
5114 auto *NewMI = buildSDivUsingMul(MI);
5115 replaceSingleDefInstWithReg(MI, NewMI->getOperand(0).getReg());
5116 }
5117
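/// Expand an exact G_SDIV by a constant into an exact arithmetic shift (to
/// strip the divisor's trailing zero bits) followed by a multiply with the
/// multiplicative inverse of the remaining odd divisor modulo 2^W.
/// For example, for a 32-bit x that is a multiple of 3, x / 3 equals
/// x * 0xAAAAAAAB, because 3 * 0xAAAAAAAB == 2^33 + 1 == 1 (mod 2^32).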
5118 MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) {
5119 assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
5120 auto &SDiv = cast<GenericMachineInstr>(MI);
5121 Register Dst = SDiv.getReg(0);
5122 Register LHS = SDiv.getReg(1);
5123 Register RHS = SDiv.getReg(2);
5124 LLT Ty = MRI.getType(Dst);
5125 LLT ScalarTy = Ty.getScalarType();
5126 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
5127 LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();
5128 auto &MIB = Builder;
5129 MIB.setInstrAndDebugLoc(MI);
5130
5131 bool UseSRA = false;
5132 SmallVector<Register, 16> Shifts, Factors;
5133
5134 auto *RHSDef = cast<GenericMachineInstr>(getDefIgnoringCopies(RHS, MRI));
5135 bool IsSplat = getIConstantSplatVal(*RHSDef, MRI).has_value();
5136
5137 auto BuildSDIVPattern = [&](const Constant *C) {
5138 // Don't recompute inverses for each splat element.
5139 if (IsSplat && !Factors.empty()) {
5140 Shifts.push_back(Shifts[0]);
5141 Factors.push_back(Factors[0]);
5142 return true;
5143 }
5144
5145 auto *CI = cast<ConstantInt>(C);
5146 APInt Divisor = CI->getValue();
5147 unsigned Shift = Divisor.countTrailingZeros();
5148 if (Shift) {
5149 Divisor.ashrInPlace(Shift);
5150 UseSRA = true;
5151 }
5152
5153 // Calculate the multiplicative inverse modulo 2^W.
5154 // 2^W requires W + 1 bits, so we have to extend and then truncate.
5155 unsigned W = Divisor.getBitWidth();
5156 APInt Factor = Divisor.zext(W + 1)
5157 .multiplicativeInverse(APInt::getSignedMinValue(W + 1))
5158 .trunc(W);
5159 Shifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0));
5160 Factors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0));
5161 return true;
5162 };
5163
5164 // Collect all magic values from the build vector.
5165 bool Matched = matchUnaryPredicate(MRI, RHS, BuildSDIVPattern);
5166 (void)Matched;
5167 assert(Matched && "Expected unary predicate match to succeed");
5168
5169 Register Shift, Factor;
5170 if (Ty.isVector()) {
5171 Shift = MIB.buildBuildVector(ShiftAmtTy, Shifts).getReg(0);
5172 Factor = MIB.buildBuildVector(Ty, Factors).getReg(0);
5173 } else {
5174 Shift = Shifts[0];
5175 Factor = Factors[0];
5176 }
5177
5178 Register Res = LHS;
5179
5180 if (UseSRA)
5181 Res = MIB.buildAShr(Ty, Res, Shift, MachineInstr::IsExact).getReg(0);
5182
5183 return MIB.buildMul(Ty, Res, Factor);
5184 }
5185
5186 bool CombinerHelper::matchUMulHToLShr(MachineInstr &MI) {
5187 assert(MI.getOpcode() == TargetOpcode::G_UMULH);
5188 Register RHS = MI.getOperand(2).getReg();
5189 Register Dst = MI.getOperand(0).getReg();
5190 LLT Ty = MRI.getType(Dst);
5191 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
5192 auto MatchPow2ExceptOne = [&](const Constant *C) {
5193 if (auto *CI = dyn_cast<ConstantInt>(C))
5194 return CI->getValue().isPowerOf2() && !CI->getValue().isOne();
5195 return false;
5196 };
5197 if (!matchUnaryPredicate(MRI, RHS, MatchPow2ExceptOne, false))
5198 return false;
5199 return isLegalOrBeforeLegalizer({TargetOpcode::G_LSHR, {Ty, ShiftAmtTy}});
5200 }
5201
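// For example, on a 32-bit scalar, G_UMULH x, 8 is the high half of the
// 64-bit product x * 2^3, i.e. x >> (32 - 3) = x >> 29.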
5202 void CombinerHelper::applyUMulHToLShr(MachineInstr &MI) {
5203 Register LHS = MI.getOperand(1).getReg();
5204 Register RHS = MI.getOperand(2).getReg();
5205 Register Dst = MI.getOperand(0).getReg();
5206 LLT Ty = MRI.getType(Dst);
5207 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
5208 unsigned NumEltBits = Ty.getScalarSizeInBits();
5209
5210 Builder.setInstrAndDebugLoc(MI);
5211 auto LogBase2 = buildLogBase2(RHS, Builder);
5212 auto ShiftAmt =
5213 Builder.buildSub(Ty, Builder.buildConstant(Ty, NumEltBits), LogBase2);
5214 auto Trunc = Builder.buildZExtOrTrunc(ShiftAmtTy, ShiftAmt);
5215 Builder.buildLShr(Dst, LHS, Trunc);
5216 MI.eraseFromParent();
5217 }
5218
5219 bool CombinerHelper::matchRedundantNegOperands(MachineInstr &MI,
5220 BuildFnTy &MatchInfo) {
5221 unsigned Opc = MI.getOpcode();
5222 assert(Opc == TargetOpcode::G_FADD || Opc == TargetOpcode::G_FSUB ||
5223 Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV ||
5224 Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA);
5225
5226 Register Dst = MI.getOperand(0).getReg();
5227 Register X = MI.getOperand(1).getReg();
5228 Register Y = MI.getOperand(2).getReg();
5229 LLT Type = MRI.getType(Dst);
5230
5231 // fold (fadd x, fneg(y)) -> (fsub x, y)
5232 // fold (fadd fneg(y), x) -> (fsub x, y)
5233 // G_FADD is commutative, so both cases are checked by m_GFAdd.
5234 if (mi_match(Dst, MRI, m_GFAdd(m_Reg(X), m_GFNeg(m_Reg(Y)))) &&
5235 isLegalOrBeforeLegalizer({TargetOpcode::G_FSUB, {Type}})) {
5236 Opc = TargetOpcode::G_FSUB;
5237 }
5238 // fold (fsub x, fneg(y)) -> (fadd x, y)
5239 else if (mi_match(Dst, MRI, m_GFSub(m_Reg(X), m_GFNeg(m_Reg(Y)))) &&
5240 isLegalOrBeforeLegalizer({TargetOpcode::G_FADD, {Type}})) {
5241 Opc = TargetOpcode::G_FADD;
5242 }
5243 // fold (fmul fneg(x), fneg(y)) -> (fmul x, y)
5244 // fold (fdiv fneg(x), fneg(y)) -> (fdiv x, y)
5245 // fold (fmad fneg(x), fneg(y), z) -> (fmad x, y, z)
5246 // fold (fma fneg(x), fneg(y), z) -> (fma x, y, z)
5247 else if ((Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV ||
5248 Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA) &&
5249 mi_match(X, MRI, m_GFNeg(m_Reg(X))) &&
5250 mi_match(Y, MRI, m_GFNeg(m_Reg(Y)))) {
5251 // no opcode change
5252 } else
5253 return false;
5254
5255 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5256 Observer.changingInstr(MI);
5257 MI.setDesc(B.getTII().get(Opc));
5258 MI.getOperand(1).setReg(X);
5259 MI.getOperand(2).setReg(Y);
5260 Observer.changedInstr(MI);
5261 };
5262 return true;
5263 }
5264
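// (G_FSUB -0.0, x) -> (G_FNEG x). A +0.0 LHS is only safe with the nsz flag,
// since 0.0 - (+0.0) is +0.0 while -(+0.0) is -0.0.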
5265 bool CombinerHelper::matchFsubToFneg(MachineInstr &MI, Register &MatchInfo) {
5266 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
5267
5268 Register LHS = MI.getOperand(1).getReg();
5269 MatchInfo = MI.getOperand(2).getReg();
5270 LLT Ty = MRI.getType(MI.getOperand(0).getReg());
5271
5272 const auto LHSCst = Ty.isVector()
5273 ? getFConstantSplat(LHS, MRI, /* allowUndef */ true)
5274 : getFConstantVRegValWithLookThrough(LHS, MRI);
5275 if (!LHSCst)
5276 return false;
5277
5278 // -0.0 is always allowed
5279 if (LHSCst->Value.isNegZero())
5280 return true;
5281
5282 // +0.0 is only allowed if nsz is set.
5283 if (LHSCst->Value.isPosZero())
5284 return MI.getFlag(MachineInstr::FmNsz);
5285
5286 return false;
5287 }
5288
5289 void CombinerHelper::applyFsubToFneg(MachineInstr &MI, Register &MatchInfo) {
5290 Builder.setInstrAndDebugLoc(MI);
5291 Register Dst = MI.getOperand(0).getReg();
5292 Builder.buildFNeg(
5293 Dst, Builder.buildFCanonicalize(MRI.getType(Dst), MatchInfo).getReg(0));
5294 eraseInst(MI);
5295 }
5296
5297 /// Checks if \p MI is TargetOpcode::G_FMUL and contractable either
5298 /// due to global flags or MachineInstr flags.
5299 static bool isContractableFMul(MachineInstr &MI, bool AllowFusionGlobally) {
5300 if (MI.getOpcode() != TargetOpcode::G_FMUL)
5301 return false;
5302 return AllowFusionGlobally || MI.getFlag(MachineInstr::MIFlag::FmContract);
5303 }
5304
5305 static bool hasMoreUses(const MachineInstr &MI0, const MachineInstr &MI1,
5306 const MachineRegisterInfo &MRI) {
5307 return std::distance(MRI.use_instr_nodbg_begin(MI0.getOperand(0).getReg()),
5308 MRI.use_instr_nodbg_end()) >
5309 std::distance(MRI.use_instr_nodbg_begin(MI1.getOperand(0).getReg()),
5310 MRI.use_instr_nodbg_end());
5311 }
5312
5313 bool CombinerHelper::canCombineFMadOrFMA(MachineInstr &MI,
5314 bool &AllowFusionGlobally,
5315 bool &HasFMAD, bool &Aggressive,
5316 bool CanReassociate) {
5317
5318 auto *MF = MI.getMF();
5319 const auto &TLI = *MF->getSubtarget().getTargetLowering();
5320 const TargetOptions &Options = MF->getTarget().Options;
5321 LLT DstType = MRI.getType(MI.getOperand(0).getReg());
5322
5323 if (CanReassociate &&
5324 !(Options.UnsafeFPMath || MI.getFlag(MachineInstr::MIFlag::FmReassoc)))
5325 return false;
5326
5327 // Floating-point multiply-add with intermediate rounding.
5328 HasFMAD = (!isPreLegalize() && TLI.isFMADLegal(MI, DstType));
5329 // Floating-point multiply-add without intermediate rounding.
5330 bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(*MF, DstType) &&
5331 isLegalOrBeforeLegalizer({TargetOpcode::G_FMA, {DstType}});
5332 // No valid opcode, do not combine.
5333 if (!HasFMAD && !HasFMA)
5334 return false;
5335
5336 AllowFusionGlobally = Options.AllowFPOpFusion == FPOpFusion::Fast ||
5337 Options.UnsafeFPMath || HasFMAD;
5338 // If the addition is not contractable, do not combine.
5339 if (!AllowFusionGlobally && !MI.getFlag(MachineInstr::MIFlag::FmContract))
5340 return false;
5341
5342 Aggressive = TLI.enableAggressiveFMAFusion(DstType);
5343 return true;
5344 }
5345
5346 bool CombinerHelper::matchCombineFAddFMulToFMadOrFMA(
5347 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
5348 assert(MI.getOpcode() == TargetOpcode::G_FADD);
5349
5350 bool AllowFusionGlobally, HasFMAD, Aggressive;
5351 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
5352 return false;
5353
5354 Register Op1 = MI.getOperand(1).getReg();
5355 Register Op2 = MI.getOperand(2).getReg();
5356 DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1};
5357 DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2};
5358 unsigned PreferredFusedOpcode =
5359 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5360
5361 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
5362 // prefer to fold the multiply with fewer uses.
5363 if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
5364 isContractableFMul(*RHS.MI, AllowFusionGlobally)) {
5365 if (hasMoreUses(*LHS.MI, *RHS.MI, MRI))
5366 std::swap(LHS, RHS);
5367 }
5368
5369 // fold (fadd (fmul x, y), z) -> (fma x, y, z)
5370 if (isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
5371 (Aggressive || MRI.hasOneNonDBGUse(LHS.Reg))) {
5372 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5373 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5374 {LHS.MI->getOperand(1).getReg(),
5375 LHS.MI->getOperand(2).getReg(), RHS.Reg});
5376 };
5377 return true;
5378 }
5379
5380 // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
5381 if (isContractableFMul(*RHS.MI, AllowFusionGlobally) &&
5382 (Aggressive || MRI.hasOneNonDBGUse(RHS.Reg))) {
5383 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5384 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5385 {RHS.MI->getOperand(1).getReg(),
5386 RHS.MI->getOperand(2).getReg(), LHS.Reg});
5387 };
5388 return true;
5389 }
5390
5391 return false;
5392 }
5393
5394 bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMA(
5395 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
5396 assert(MI.getOpcode() == TargetOpcode::G_FADD);
5397
5398 bool AllowFusionGlobally, HasFMAD, Aggressive;
5399 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
5400 return false;
5401
5402 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
5403 Register Op1 = MI.getOperand(1).getReg();
5404 Register Op2 = MI.getOperand(2).getReg();
5405 DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1};
5406 DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2};
5407 LLT DstType = MRI.getType(MI.getOperand(0).getReg());
5408
5409 unsigned PreferredFusedOpcode =
5410 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5411
5412 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
5413 // prefer to fold the multiply with fewer uses.
5414 if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
5415 isContractableFMul(*RHS.MI, AllowFusionGlobally)) {
5416 if (hasMoreUses(*LHS.MI, *RHS.MI, MRI))
5417 std::swap(LHS, RHS);
5418 }
5419
5420 // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
5421 MachineInstr *FpExtSrc;
5422 if (mi_match(LHS.Reg, MRI, m_GFPExt(m_MInstr(FpExtSrc))) &&
5423 isContractableFMul(*FpExtSrc, AllowFusionGlobally) &&
5424 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
5425 MRI.getType(FpExtSrc->getOperand(1).getReg()))) {
5426 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5427 auto FpExtX = B.buildFPExt(DstType, FpExtSrc->getOperand(1).getReg());
5428 auto FpExtY = B.buildFPExt(DstType, FpExtSrc->getOperand(2).getReg());
5429 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5430 {FpExtX.getReg(0), FpExtY.getReg(0), RHS.Reg});
5431 };
5432 return true;
5433 }
5434
5435 // fold (fadd z, (fpext (fmul x, y))) -> (fma (fpext x), (fpext y), z)
5436 // Note: Commutes FADD operands.
5437 if (mi_match(RHS.Reg, MRI, m_GFPExt(m_MInstr(FpExtSrc))) &&
5438 isContractableFMul(*FpExtSrc, AllowFusionGlobally) &&
5439 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
5440 MRI.getType(FpExtSrc->getOperand(1).getReg()))) {
5441 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5442 auto FpExtX = B.buildFPExt(DstType, FpExtSrc->getOperand(1).getReg());
5443 auto FpExtY = B.buildFPExt(DstType, FpExtSrc->getOperand(2).getReg());
5444 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5445 {FpExtX.getReg(0), FpExtY.getReg(0), LHS.Reg});
5446 };
5447 return true;
5448 }
5449
5450 return false;
5451 }
5452
5453 bool CombinerHelper::matchCombineFAddFMAFMulToFMadOrFMA(
5454 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
5455 assert(MI.getOpcode() == TargetOpcode::G_FADD);
5456
5457 bool AllowFusionGlobally, HasFMAD, Aggressive;
5458 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive, true))
5459 return false;
5460
5461 Register Op1 = MI.getOperand(1).getReg();
5462 Register Op2 = MI.getOperand(2).getReg();
5463 DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1};
5464 DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2};
5465 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
5466
5467 unsigned PreferredFusedOpcode =
5468 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5469
5470 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
5471 // prefer to fold the multiply with fewer uses.
5472 if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
5473 isContractableFMul(*RHS.MI, AllowFusionGlobally)) {
5474 if (hasMoreUses(*LHS.MI, *RHS.MI, MRI))
5475 std::swap(LHS, RHS);
5476 }
5477
5478 MachineInstr *FMA = nullptr;
5479 Register Z;
5480 // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y, (fma u, v, z))
5481 if (LHS.MI->getOpcode() == PreferredFusedOpcode &&
5482 (MRI.getVRegDef(LHS.MI->getOperand(3).getReg())->getOpcode() ==
5483 TargetOpcode::G_FMUL) &&
5484 MRI.hasOneNonDBGUse(LHS.MI->getOperand(0).getReg()) &&
5485 MRI.hasOneNonDBGUse(LHS.MI->getOperand(3).getReg())) {
5486 FMA = LHS.MI;
5487 Z = RHS.Reg;
5488 }
5489 // fold (fadd z, (fma x, y, (fmul u, v))) -> (fma x, y, (fma u, v, z))
5490 else if (RHS.MI->getOpcode() == PreferredFusedOpcode &&
5491 (MRI.getVRegDef(RHS.MI->getOperand(3).getReg())->getOpcode() ==
5492 TargetOpcode::G_FMUL) &&
5493 MRI.hasOneNonDBGUse(RHS.MI->getOperand(0).getReg()) &&
5494 MRI.hasOneNonDBGUse(RHS.MI->getOperand(3).getReg())) {
5495 Z = LHS.Reg;
5496 FMA = RHS.MI;
5497 }
5498
5499 if (FMA) {
5500 MachineInstr *FMulMI = MRI.getVRegDef(FMA->getOperand(3).getReg());
5501 Register X = FMA->getOperand(1).getReg();
5502 Register Y = FMA->getOperand(2).getReg();
5503 Register U = FMulMI->getOperand(1).getReg();
5504 Register V = FMulMI->getOperand(2).getReg();
5505
5506 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5507 Register InnerFMA = MRI.createGenericVirtualRegister(DstTy);
5508 B.buildInstr(PreferredFusedOpcode, {InnerFMA}, {U, V, Z});
5509 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5510 {X, Y, InnerFMA});
5511 };
5512 return true;
5513 }
5514
5515 return false;
5516 }
5517
5518 bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMAAggressive(
5519 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
5520 assert(MI.getOpcode() == TargetOpcode::G_FADD);
5521
5522 bool AllowFusionGlobally, HasFMAD, Aggressive;
5523 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
5524 return false;
5525
5526 if (!Aggressive)
5527 return false;
5528
5529 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
5530 LLT DstType = MRI.getType(MI.getOperand(0).getReg());
5531 Register Op1 = MI.getOperand(1).getReg();
5532 Register Op2 = MI.getOperand(2).getReg();
5533 DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1};
5534 DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2};
5535
5536 unsigned PreferredFusedOpcode =
5537 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5538
5539 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
5540 // prefer to fold the multiply with fewer uses.
5541 if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
5542 isContractableFMul(*RHS.MI, AllowFusionGlobally)) {
5543 if (hasMoreUses(*LHS.MI, *RHS.MI, MRI))
5544 std::swap(LHS, RHS);
5545 }
5546
5547 // Builds: (fma x, y, (fma (fpext u), (fpext v), z))
5548 auto buildMatchInfo = [=, &MI](Register U, Register V, Register Z, Register X,
5549 Register Y, MachineIRBuilder &B) {
5550 Register FpExtU = B.buildFPExt(DstType, U).getReg(0);
5551 Register FpExtV = B.buildFPExt(DstType, V).getReg(0);
5552 Register InnerFMA =
5553 B.buildInstr(PreferredFusedOpcode, {DstType}, {FpExtU, FpExtV, Z})
5554 .getReg(0);
5555 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5556 {X, Y, InnerFMA});
5557 };
5558
5559 MachineInstr *FMulMI, *FMAMI;
5560 // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
5561 // -> (fma x, y, (fma (fpext u), (fpext v), z))
5562 if (LHS.MI->getOpcode() == PreferredFusedOpcode &&
5563 mi_match(LHS.MI->getOperand(3).getReg(), MRI,
5564 m_GFPExt(m_MInstr(FMulMI))) &&
5565 isContractableFMul(*FMulMI, AllowFusionGlobally) &&
5566 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
5567 MRI.getType(FMulMI->getOperand(0).getReg()))) {
5568 MatchInfo = [=](MachineIRBuilder &B) {
5569 buildMatchInfo(FMulMI->getOperand(1).getReg(),
5570 FMulMI->getOperand(2).getReg(), RHS.Reg,
5571 LHS.MI->getOperand(1).getReg(),
5572 LHS.MI->getOperand(2).getReg(), B);
5573 };
5574 return true;
5575 }
5576
5577 // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
5578 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
5579 // FIXME: This turns two single-precision and one double-precision
5580 // operation into two double-precision operations, which might not be
5581 // interesting for all targets, especially GPUs.
5582 if (mi_match(LHS.Reg, MRI, m_GFPExt(m_MInstr(FMAMI))) &&
5583 FMAMI->getOpcode() == PreferredFusedOpcode) {
5584 MachineInstr *FMulMI = MRI.getVRegDef(FMAMI->getOperand(3).getReg());
5585 if (isContractableFMul(*FMulMI, AllowFusionGlobally) &&
5586 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
5587 MRI.getType(FMAMI->getOperand(0).getReg()))) {
5588 MatchInfo = [=](MachineIRBuilder &B) {
5589 Register X = FMAMI->getOperand(1).getReg();
5590 Register Y = FMAMI->getOperand(2).getReg();
5591 X = B.buildFPExt(DstType, X).getReg(0);
5592 Y = B.buildFPExt(DstType, Y).getReg(0);
5593 buildMatchInfo(FMulMI->getOperand(1).getReg(),
5594 FMulMI->getOperand(2).getReg(), RHS.Reg, X, Y, B);
5595 };
5596
5597 return true;
5598 }
5599 }
5600
5601 // fold (fadd z, (fma x, y, (fpext (fmul u, v))))
5602 // -> (fma x, y, (fma (fpext u), (fpext v), z))
5603 if (RHS.MI->getOpcode() == PreferredFusedOpcode &&
5604 mi_match(RHS.MI->getOperand(3).getReg(), MRI,
5605 m_GFPExt(m_MInstr(FMulMI))) &&
5606 isContractableFMul(*FMulMI, AllowFusionGlobally) &&
5607 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
5608 MRI.getType(FMulMI->getOperand(0).getReg()))) {
5609 MatchInfo = [=](MachineIRBuilder &B) {
5610 buildMatchInfo(FMulMI->getOperand(1).getReg(),
5611 FMulMI->getOperand(2).getReg(), LHS.Reg,
5612 RHS.MI->getOperand(1).getReg(),
5613 RHS.MI->getOperand(2).getReg(), B);
5614 };
5615 return true;
5616 }
5617
5618 // fold (fadd z, (fpext (fma x, y, (fmul u, v))))
5619 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
5620 // FIXME: This turns two single-precision and one double-precision
5621 // operation into two double-precision operations, which might not be
5622 // interesting for all targets, especially GPUs.
5623 if (mi_match(RHS.Reg, MRI, m_GFPExt(m_MInstr(FMAMI))) &&
5624 FMAMI->getOpcode() == PreferredFusedOpcode) {
5625 MachineInstr *FMulMI = MRI.getVRegDef(FMAMI->getOperand(3).getReg());
5626 if (isContractableFMul(*FMulMI, AllowFusionGlobally) &&
5627 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
5628 MRI.getType(FMAMI->getOperand(0).getReg()))) {
5629 MatchInfo = [=](MachineIRBuilder &B) {
5630 Register X = FMAMI->getOperand(1).getReg();
5631 Register Y = FMAMI->getOperand(2).getReg();
5632 X = B.buildFPExt(DstType, X).getReg(0);
5633 Y = B.buildFPExt(DstType, Y).getReg(0);
5634 buildMatchInfo(FMulMI->getOperand(1).getReg(),
5635 FMulMI->getOperand(2).getReg(), LHS.Reg, X, Y, B);
5636 };
5637 return true;
5638 }
5639 }
5640
5641 return false;
5642 }
5643
5644 bool CombinerHelper::matchCombineFSubFMulToFMadOrFMA(
5645 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
5646 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
5647
5648 bool AllowFusionGlobally, HasFMAD, Aggressive;
5649 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
5650 return false;
5651
5652 Register Op1 = MI.getOperand(1).getReg();
5653 Register Op2 = MI.getOperand(2).getReg();
5654 DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1};
5655 DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2};
5656 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
5657
5658 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
5659 // prefer to fold the multiply with fewer uses.
5660 bool FirstMulHasFewerUses = true;
5661 if (isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
5662 isContractableFMul(*RHS.MI, AllowFusionGlobally) &&
5663 hasMoreUses(*LHS.MI, *RHS.MI, MRI))
5664 FirstMulHasFewerUses = false;
5665
5666 unsigned PreferredFusedOpcode =
5667 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5668
5669 // fold (fsub (fmul x, y), z) -> (fma x, y, -z)
5670 if (FirstMulHasFewerUses &&
5671 (isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
5672 (Aggressive || MRI.hasOneNonDBGUse(LHS.Reg)))) {
5673 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5674 Register NegZ = B.buildFNeg(DstTy, RHS.Reg).getReg(0);
5675 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5676 {LHS.MI->getOperand(1).getReg(),
5677 LHS.MI->getOperand(2).getReg(), NegZ});
5678 };
5679 return true;
5680 }
5681 // fold (fsub x, (fmul y, z)) -> (fma -y, z, x)
5682 else if ((isContractableFMul(*RHS.MI, AllowFusionGlobally) &&
5683 (Aggressive || MRI.hasOneNonDBGUse(RHS.Reg)))) {
5684 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5685 Register NegY =
5686 B.buildFNeg(DstTy, RHS.MI->getOperand(1).getReg()).getReg(0);
5687 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5688 {NegY, RHS.MI->getOperand(2).getReg(), LHS.Reg});
5689 };
5690 return true;
5691 }
5692
5693 return false;
5694 }
5695
5696 bool CombinerHelper::matchCombineFSubFNegFMulToFMadOrFMA(
5697 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
5698 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
5699
5700 bool AllowFusionGlobally, HasFMAD, Aggressive;
5701 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
5702 return false;
5703
5704 Register LHSReg = MI.getOperand(1).getReg();
5705 Register RHSReg = MI.getOperand(2).getReg();
5706 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
5707
5708 unsigned PreferredFusedOpcode =
5709 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5710
5711 MachineInstr *FMulMI;
5712 // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
5713 if (mi_match(LHSReg, MRI, m_GFNeg(m_MInstr(FMulMI))) &&
5714 (Aggressive || (MRI.hasOneNonDBGUse(LHSReg) &&
5715 MRI.hasOneNonDBGUse(FMulMI->getOperand(0).getReg()))) &&
5716 isContractableFMul(*FMulMI, AllowFusionGlobally)) {
5717 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5718 Register NegX =
5719 B.buildFNeg(DstTy, FMulMI->getOperand(1).getReg()).getReg(0);
5720 Register NegZ = B.buildFNeg(DstTy, RHSReg).getReg(0);
5721 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5722 {NegX, FMulMI->getOperand(2).getReg(), NegZ});
5723 };
5724 return true;
5725 }
5726
5727 // fold (fsub x, (fneg (fmul y, z))) -> (fma y, z, x)
5728 if (mi_match(RHSReg, MRI, m_GFNeg(m_MInstr(FMulMI))) &&
5729 (Aggressive || (MRI.hasOneNonDBGUse(RHSReg) &&
5730 MRI.hasOneNonDBGUse(FMulMI->getOperand(0).getReg()))) &&
5731 isContractableFMul(*FMulMI, AllowFusionGlobally)) {
5732 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5733 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5734 {FMulMI->getOperand(1).getReg(),
5735 FMulMI->getOperand(2).getReg(), LHSReg});
5736 };
5737 return true;
5738 }
5739
5740 return false;
5741 }
5742
5743 bool CombinerHelper::matchCombineFSubFpExtFMulToFMadOrFMA(
5744 MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
5745 assert(MI.getOpcode() == TargetOpcode::G_FSUB);
5746
5747 bool AllowFusionGlobally, HasFMAD, Aggressive;
5748 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
5749 return false;
5750
5751 Register LHSReg = MI.getOperand(1).getReg();
5752 Register RHSReg = MI.getOperand(2).getReg();
5753 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
5754
5755 unsigned PreferredFusedOpcode =
5756 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
5757
5758 MachineInstr *FMulMI;
5759 // fold (fsub (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), (fneg z))
5760 if (mi_match(LHSReg, MRI, m_GFPExt(m_MInstr(FMulMI))) &&
5761 isContractableFMul(*FMulMI, AllowFusionGlobally) &&
5762 (Aggressive || MRI.hasOneNonDBGUse(LHSReg))) {
5763 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5764 Register FpExtX =
5765 B.buildFPExt(DstTy, FMulMI->getOperand(1).getReg()).getReg(0);
5766 Register FpExtY =
5767 B.buildFPExt(DstTy, FMulMI->getOperand(2).getReg()).getReg(0);
5768 Register NegZ = B.buildFNeg(DstTy, RHSReg).getReg(0);
5769 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5770 {FpExtX, FpExtY, NegZ});
5771 };
5772 return true;
5773 }
5774
5775 // fold (fsub x, (fpext (fmul y, z))) -> (fma (fneg (fpext y)), (fpext z), x)
5776 if (mi_match(RHSReg, MRI, m_GFPExt(m_MInstr(FMulMI))) &&
5777 isContractableFMul(*FMulMI, AllowFusionGlobally) &&
5778 (Aggressive || MRI.hasOneNonDBGUse(RHSReg))) {
5779 MatchInfo = [=, &MI](MachineIRBuilder &B) {
5780 Register FpExtY =
5781 B.buildFPExt(DstTy, FMulMI->getOperand(1).getReg()).getReg(0);
5782 Register NegY = B.buildFNeg(DstTy, FpExtY).getReg(0);
5783 Register FpExtZ =
5784 B.buildFPExt(DstTy, FMulMI->getOperand(2).getReg()).getReg(0);
5785 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
5786 {NegY, FpExtZ, LHSReg});
5787 };
5788 return true;
5789 }
5790
5791 return false;
5792 }
5793
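/// Handle G_FSUB operands where a G_FPEXT and a G_FNEG wrap an inner G_FMUL
/// (in either order), forming a (possibly negated) fused multiply-add when
/// the target reports the extension as foldable via isFPExtFoldable().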
bool CombinerHelper::matchCombineFSubFpExtFNegFMulToFMadOrFMA(
    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_FSUB);

  bool AllowFusionGlobally, HasFMAD, Aggressive;
  if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
    return false;

  const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
  Register LHSReg = MI.getOperand(1).getReg();
  Register RHSReg = MI.getOperand(2).getReg();

  unsigned PreferredFusedOpcode =
      HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;

  auto buildMatchInfo = [=](Register Dst, Register X, Register Y, Register Z,
                            MachineIRBuilder &B) {
    Register FpExtX = B.buildFPExt(DstTy, X).getReg(0);
    Register FpExtY = B.buildFPExt(DstTy, Y).getReg(0);
    B.buildInstr(PreferredFusedOpcode, {Dst}, {FpExtX, FpExtY, Z});
  };

  MachineInstr *FMulMI;
  // fold (fsub (fpext (fneg (fmul x, y))), z) ->
  //      (fneg (fma (fpext x), (fpext y), z))
  // fold (fsub (fneg (fpext (fmul x, y))), z) ->
  //      (fneg (fma (fpext x), (fpext y), z))
  if ((mi_match(LHSReg, MRI, m_GFPExt(m_GFNeg(m_MInstr(FMulMI)))) ||
       mi_match(LHSReg, MRI, m_GFNeg(m_GFPExt(m_MInstr(FMulMI))))) &&
      isContractableFMul(*FMulMI, AllowFusionGlobally) &&
      TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstTy,
                          MRI.getType(FMulMI->getOperand(0).getReg()))) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      Register FMAReg = MRI.createGenericVirtualRegister(DstTy);
      buildMatchInfo(FMAReg, FMulMI->getOperand(1).getReg(),
                     FMulMI->getOperand(2).getReg(), RHSReg, B);
      B.buildFNeg(MI.getOperand(0).getReg(), FMAReg);
    };
    return true;
  }

  // fold (fsub x, (fpext (fneg (fmul y, z)))) -> (fma (fpext y), (fpext z), x)
  // fold (fsub x, (fneg (fpext (fmul y, z)))) -> (fma (fpext y), (fpext z), x)
  if ((mi_match(RHSReg, MRI, m_GFPExt(m_GFNeg(m_MInstr(FMulMI)))) ||
       mi_match(RHSReg, MRI, m_GFNeg(m_GFPExt(m_MInstr(FMulMI))))) &&
      isContractableFMul(*FMulMI, AllowFusionGlobally) &&
      TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstTy,
                          MRI.getType(FMulMI->getOperand(0).getReg()))) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      buildMatchInfo(MI.getOperand(0).getReg(), FMulMI->getOperand(1).getReg(),
                     FMulMI->getOperand(2).getReg(), LHSReg, B);
    };
    return true;
  }

  return false;
}

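/// Match boolean selects (1-bit condition and operands of matching type)
/// whose true/false operands are the condition itself or the constants 0/1,
/// and rewrite them as G_OR/G_AND, negating the condition where needed.
/// Illustrative example (not taken from a real test):
///   %d:_(s1) = G_SELECT %c(s1), %c, %f   -->   %d:_(s1) = G_OR %c, %f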
bool CombinerHelper::matchSelectToLogical(MachineInstr &MI,
                                          BuildFnTy &MatchInfo) {
  GSelect &Sel = cast<GSelect>(MI);
  Register DstReg = Sel.getReg(0);
  Register Cond = Sel.getCondReg();
  Register TrueReg = Sel.getTrueReg();
  Register FalseReg = Sel.getFalseReg();

  auto *TrueDef = getDefIgnoringCopies(TrueReg, MRI);
  auto *FalseDef = getDefIgnoringCopies(FalseReg, MRI);

  const LLT CondTy = MRI.getType(Cond);
  const LLT OpTy = MRI.getType(TrueReg);
  if (CondTy != OpTy || OpTy.getScalarSizeInBits() != 1)
    return false;

  // We have a boolean select.

  // select Cond, Cond, F --> or Cond, F
  // select Cond, 1, F --> or Cond, F
  auto MaybeCstTrue = isConstantOrConstantSplatVector(*TrueDef, MRI);
  if (Cond == TrueReg || (MaybeCstTrue && MaybeCstTrue->isOne())) {
    MatchInfo = [=](MachineIRBuilder &MIB) {
      MIB.buildOr(DstReg, Cond, FalseReg);
    };
    return true;
  }

  // select Cond, T, Cond --> and Cond, T
  // select Cond, T, 0 --> and Cond, T
  auto MaybeCstFalse = isConstantOrConstantSplatVector(*FalseDef, MRI);
  if (Cond == FalseReg || (MaybeCstFalse && MaybeCstFalse->isZero())) {
    MatchInfo = [=](MachineIRBuilder &MIB) {
      MIB.buildAnd(DstReg, Cond, TrueReg);
    };
    return true;
  }

  // select Cond, T, 1 --> or (not Cond), T
  if (MaybeCstFalse && MaybeCstFalse->isOne()) {
    MatchInfo = [=](MachineIRBuilder &MIB) {
      MIB.buildOr(DstReg, MIB.buildNot(OpTy, Cond), TrueReg);
    };
    return true;
  }

  // select Cond, 0, F --> and (not Cond), F
  if (MaybeCstTrue && MaybeCstTrue->isZero()) {
    MatchInfo = [=](MachineIRBuilder &MIB) {
      MIB.buildAnd(DstReg, MIB.buildNot(OpTy, Cond), FalseReg);
    };
    return true;
  }
  return false;
}

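/// If one operand of a G_FMINNUM/G_FMAXNUM/G_FMINIMUM/G_FMAXIMUM is a
/// constant NaN, report which operand the result can be replaced with:
/// the non-NaN operand for G_FMINNUM/G_FMAXNUM and the NaN operand for
/// G_FMINIMUM/G_FMAXIMUM. \p IdxToPropagate receives that operand index.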
bool CombinerHelper::matchCombineFMinMaxNaN(MachineInstr &MI,
                                            unsigned &IdxToPropagate) {
  bool PropagateNaN;
  switch (MI.getOpcode()) {
  default:
    return false;
  case TargetOpcode::G_FMINNUM:
  case TargetOpcode::G_FMAXNUM:
    PropagateNaN = false;
    break;
  case TargetOpcode::G_FMINIMUM:
  case TargetOpcode::G_FMAXIMUM:
    PropagateNaN = true;
    break;
  }

  auto MatchNaN = [&](unsigned Idx) {
    Register MaybeNaNReg = MI.getOperand(Idx).getReg();
    const ConstantFP *MaybeCst = getConstantFPVRegVal(MaybeNaNReg, MRI);
    if (!MaybeCst || !MaybeCst->getValueAPF().isNaN())
      return false;
    IdxToPropagate = PropagateNaN ? Idx : (Idx == 1 ? 2 : 1);
    return true;
  };

  return MatchNaN(1) || MatchNaN(2);
}

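/// Match A + (B - A) or (B - A) + A and return B in \p Src so the G_ADD can
/// be replaced by a copy of B.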
bool CombinerHelper::matchAddSubSameReg(MachineInstr &MI, Register &Src) {
  assert(MI.getOpcode() == TargetOpcode::G_ADD && "Expected a G_ADD");
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();

  // Helper lambda to check for opportunities for
  // A + (B - A) -> B
  // (B - A) + A -> B
  auto CheckFold = [&](Register MaybeSub, Register MaybeSameReg) {
    Register Reg;
    return mi_match(MaybeSub, MRI, m_GSub(m_Reg(Src), m_Reg(Reg))) &&
           Reg == MaybeSameReg;
  };
  return CheckFold(LHS, RHS) || CheckFold(RHS, LHS);
}

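/// Recognize build-vector patterns that merely reassemble the low and high
/// halves of a bitcast value (see the patterns listed below) and fold them
/// back to the original register.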
bool CombinerHelper::matchBuildVectorIdentityFold(MachineInstr &MI,
                                                  Register &MatchInfo) {
  // This combine folds the following patterns:
  //
  //  G_BUILD_VECTOR_TRUNC (G_BITCAST(x), G_LSHR(G_BITCAST(x), k))
  //  G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), G_TRUNC(G_LSHR(G_BITCAST(x), k)))
  //   into
  //  x
  //   if
  //    k == sizeof(VecEltTy)/2
  //    type(x) == type(dst)
  //
  //  G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), undef)
  //   into
  //  x
  //   if
  //    type(x) == type(dst)

  LLT DstVecTy = MRI.getType(MI.getOperand(0).getReg());
  LLT DstEltTy = DstVecTy.getElementType();

  Register Lo, Hi;

  if (mi_match(
          MI, MRI,
          m_GBuildVector(m_GTrunc(m_GBitcast(m_Reg(Lo))), m_GImplicitDef()))) {
    MatchInfo = Lo;
    return MRI.getType(MatchInfo) == DstVecTy;
  }

  std::optional<ValueAndVReg> ShiftAmount;
  const auto LoPattern = m_GBitcast(m_Reg(Lo));
  const auto HiPattern = m_GLShr(m_GBitcast(m_Reg(Hi)), m_GCst(ShiftAmount));
  if (mi_match(
          MI, MRI,
          m_any_of(m_GBuildVectorTrunc(LoPattern, HiPattern),
                   m_GBuildVector(m_GTrunc(LoPattern), m_GTrunc(HiPattern))))) {
    if (Lo == Hi && ShiftAmount->Value == DstEltTy.getSizeInBits()) {
      MatchInfo = Lo;
      return MRI.getType(MatchInfo) == DstVecTy;
    }
  }

  return false;
}

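/// Fold (G_TRUNC (G_BITCAST (G_BUILD_VECTOR x, y))) to x when the truncated
/// type matches type(x).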
bool CombinerHelper::matchTruncBuildVectorFold(MachineInstr &MI,
                                               Register &MatchInfo) {
  // Replace (G_TRUNC (G_BITCAST (G_BUILD_VECTOR x, y))) with just x
  // if type(x) == type(G_TRUNC)
  if (!mi_match(MI.getOperand(1).getReg(), MRI,
                m_GBitcast(m_GBuildVector(m_Reg(MatchInfo), m_Reg()))))
    return false;

  return MRI.getType(MatchInfo) == MRI.getType(MI.getOperand(0).getReg());
}

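/// Fold (G_TRUNC (G_LSHR (G_BITCAST (G_BUILD_VECTOR x, y)), K)) to y when K
/// equals the bit width of y.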
bool CombinerHelper::matchTruncLshrBuildVectorFold(MachineInstr &MI,
                                                   Register &MatchInfo) {
  // Replace (G_TRUNC (G_LSHR (G_BITCAST (G_BUILD_VECTOR x, y)), K)) with
  // y if K == size of vector element type
  std::optional<ValueAndVReg> ShiftAmt;
  if (!mi_match(MI.getOperand(1).getReg(), MRI,
                m_GLShr(m_GBitcast(m_GBuildVector(m_Reg(), m_Reg(MatchInfo))),
                        m_GCst(ShiftAmt))))
    return false;

  LLT MatchTy = MRI.getType(MatchInfo);
  return ShiftAmt->Value.getZExtValue() == MatchTy.getSizeInBits() &&
         MatchTy == MRI.getType(MI.getOperand(0).getReg());
}

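/// Return the floating-point min/max opcode (G_FMINNUM/G_FMAXNUM or
/// G_FMINIMUM/G_FMAXIMUM) that can implement a select with the given
/// comparison predicate, taking NaN behaviour and legality into account.
/// Returns 0 if no suitable opcode exists.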
unsigned CombinerHelper::getFPMinMaxOpcForSelect(
    CmpInst::Predicate Pred, LLT DstTy,
    SelectPatternNaNBehaviour VsNaNRetVal) const {
  assert(VsNaNRetVal != SelectPatternNaNBehaviour::NOT_APPLICABLE &&
         "Expected a NaN behaviour?");
  // Choose an opcode based on legality or the behaviour when one of the
  // LHS/RHS may be NaN.
  switch (Pred) {
  default:
    return 0;
  case CmpInst::FCMP_UGT:
  case CmpInst::FCMP_UGE:
  case CmpInst::FCMP_OGT:
  case CmpInst::FCMP_OGE:
    if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER)
      return TargetOpcode::G_FMAXNUM;
    if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN)
      return TargetOpcode::G_FMAXIMUM;
    if (isLegal({TargetOpcode::G_FMAXNUM, {DstTy}}))
      return TargetOpcode::G_FMAXNUM;
    if (isLegal({TargetOpcode::G_FMAXIMUM, {DstTy}}))
      return TargetOpcode::G_FMAXIMUM;
    return 0;
  case CmpInst::FCMP_ULT:
  case CmpInst::FCMP_ULE:
  case CmpInst::FCMP_OLT:
  case CmpInst::FCMP_OLE:
    if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER)
      return TargetOpcode::G_FMINNUM;
    if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN)
      return TargetOpcode::G_FMINIMUM;
    if (isLegal({TargetOpcode::G_FMINNUM, {DstTy}}))
      return TargetOpcode::G_FMINNUM;
    if (isLegal({TargetOpcode::G_FMINIMUM, {DstTy}}))
      return TargetOpcode::G_FMINIMUM;
    return 0;
  }
}

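/// Determine what a select of the two compared values would return when one
/// of \p LHS / \p RHS is NaN, based on which operands are known never to be
/// NaN and whether the comparison is ordered.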
CombinerHelper::SelectPatternNaNBehaviour
CombinerHelper::computeRetValAgainstNaN(Register LHS, Register RHS,
                                        bool IsOrderedComparison) const {
  bool LHSSafe = isKnownNeverNaN(LHS, MRI);
  bool RHSSafe = isKnownNeverNaN(RHS, MRI);
  // Completely unsafe.
  if (!LHSSafe && !RHSSafe)
    return SelectPatternNaNBehaviour::NOT_APPLICABLE;
  if (LHSSafe && RHSSafe)
    return SelectPatternNaNBehaviour::RETURNS_ANY;
  // An ordered comparison will return false when given a NaN, so the select
  // returns the RHS (false) operand.
  if (IsOrderedComparison)
    return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_NAN
                   : SelectPatternNaNBehaviour::RETURNS_OTHER;
  // An unordered comparison will return true when given a NaN, so the select
  // returns the LHS (true) operand.
  return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_OTHER
                 : SelectPatternNaNBehaviour::RETURNS_NAN;
}

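/// Try to turn a select of the two operands of a floating-point compare into
/// the corresponding min/max instruction.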
bool CombinerHelper::matchFPSelectToMinMax(Register Dst, Register Cond,
                                           Register TrueVal, Register FalseVal,
                                           BuildFnTy &MatchInfo) {
  // Match: select (fcmp cond x, y) x, y
  //        select (fcmp cond x, y) y, x
  // And turn it into fminnum/fmaxnum or fminimum/fmaximum based on the
  // condition.
  LLT DstTy = MRI.getType(Dst);
  // Bail out early on pointers, since we'll never want to fold to a min/max.
  if (DstTy.isPointer())
    return false;
  // Match a floating point compare with a less-than/greater-than predicate.
  // TODO: Allow multiple users of the compare if they are all selects.
  CmpInst::Predicate Pred;
  Register CmpLHS, CmpRHS;
  if (!mi_match(Cond, MRI,
                m_OneNonDBGUse(
                    m_GFCmp(m_Pred(Pred), m_Reg(CmpLHS), m_Reg(CmpRHS)))) ||
      CmpInst::isEquality(Pred))
    return false;
  SelectPatternNaNBehaviour ResWithKnownNaNInfo =
      computeRetValAgainstNaN(CmpLHS, CmpRHS, CmpInst::isOrdered(Pred));
  if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::NOT_APPLICABLE)
    return false;
  if (TrueVal == CmpRHS && FalseVal == CmpLHS) {
    std::swap(CmpLHS, CmpRHS);
    Pred = CmpInst::getSwappedPredicate(Pred);
    if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_NAN)
      ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_OTHER;
    else if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_OTHER)
      ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_NAN;
  }
  if (TrueVal != CmpLHS || FalseVal != CmpRHS)
    return false;
  // Decide what type of max/min this should be based on the predicate.
  unsigned Opc = getFPMinMaxOpcForSelect(Pred, DstTy, ResWithKnownNaNInfo);
  if (!Opc || !isLegal({Opc, {DstTy}}))
    return false;
  // Comparisons between -0.0 and +0.0 may give inconsistent results, unless
  // we have fmaximum/fminimum, which define -0.0 < +0.0.
  if (Opc != TargetOpcode::G_FMAXIMUM && Opc != TargetOpcode::G_FMINIMUM) {
    // We don't know if a comparison between two 0s will give us a consistent
    // result. Be conservative and only proceed if at least one side is
    // non-zero.
    auto KnownNonZeroSide = getFConstantVRegValWithLookThrough(CmpLHS, MRI);
    if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) {
      KnownNonZeroSide = getFConstantVRegValWithLookThrough(CmpRHS, MRI);
      if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero())
        return false;
    }
  }
  MatchInfo = [=](MachineIRBuilder &B) {
    B.buildInstr(Opc, {Dst}, {CmpLHS, CmpRHS});
  };
  return true;
}

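/// Entry point for the select -> floating-point min/max combine: looks
/// through a one-use truncated condition and defers to
/// matchFPSelectToMinMax().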
bool CombinerHelper::matchSimplifySelectToMinMax(MachineInstr &MI,
                                                 BuildFnTy &MatchInfo) {
  // TODO: Handle integer cases.
  assert(MI.getOpcode() == TargetOpcode::G_SELECT);
  // Condition may be fed by a truncated compare.
  Register Cond = MI.getOperand(1).getReg();
  Register MaybeTrunc;
  if (mi_match(Cond, MRI, m_OneNonDBGUse(m_GTrunc(m_Reg(MaybeTrunc)))))
    Cond = MaybeTrunc;
  Register Dst = MI.getOperand(0).getReg();
  Register TrueVal = MI.getOperand(2).getReg();
  Register FalseVal = MI.getOperand(3).getReg();
  return matchFPSelectToMinMax(Dst, Cond, TrueVal, FalseVal, MatchInfo);
}

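/// Simplify an equality compare of the form (X op Y) ==/!= X, where op is
/// G_ADD, G_SUB or G_XOR, into Y ==/!= 0. Illustrative example (not from a
/// real test):
///   %t = G_XOR %x, %y
///   %c = G_ICMP intpred(eq), %t, %x   -->   %c = G_ICMP intpred(eq), %y, 0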
bool CombinerHelper::matchRedundantBinOpInEquality(MachineInstr &MI,
                                                   BuildFnTy &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_ICMP);
  // (X + Y) == X --> Y == 0
  // (X + Y) != X --> Y != 0
  // (X - Y) == X --> Y == 0
  // (X - Y) != X --> Y != 0
  // (X ^ Y) == X --> Y == 0
  // (X ^ Y) != X --> Y != 0
  Register Dst = MI.getOperand(0).getReg();
  CmpInst::Predicate Pred;
  Register X, Y, OpLHS, OpRHS;
  bool MatchedSub = mi_match(
      Dst, MRI,
      m_c_GICmp(m_Pred(Pred), m_Reg(X), m_GSub(m_Reg(OpLHS), m_Reg(Y))));
  if (MatchedSub && X != OpLHS)
    return false;
  if (!MatchedSub) {
    if (!mi_match(Dst, MRI,
                  m_c_GICmp(m_Pred(Pred), m_Reg(X),
                            m_any_of(m_GAdd(m_Reg(OpLHS), m_Reg(OpRHS)),
                                     m_GXor(m_Reg(OpLHS), m_Reg(OpRHS))))))
      return false;
    Y = X == OpLHS ? OpRHS : X == OpRHS ? OpLHS : Register();
  }
  MatchInfo = [=](MachineIRBuilder &B) {
    auto Zero = B.buildConstant(MRI.getType(Y), 0);
    B.buildICmp(Pred, Dst, Y, Zero);
  };
  return CmpInst::isEquality(Pred) && Y.isValid();
}

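/// Try a small set of generic combines (copy propagation, extending loads,
/// indexed load/store formation) on \p MI and return true if one of them
/// changed the code.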
bool CombinerHelper::tryCombine(MachineInstr &MI) {
  if (tryCombineCopy(MI))
    return true;
  if (tryCombineExtendingLoads(MI))
    return true;
  if (tryCombineIndexedLoadStore(MI))
    return true;
  return false;
}