1 //===- DAGCombiner.cpp - Implement a DAG node combiner --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass combines dag nodes to form fewer, simpler DAG nodes. It can be run
10 // both before and after the DAG is legalized.
11 //
12 // This pass is not a substitute for the LLVM IR instcombine pass. This pass is
13 // primarily intended to handle simplification opportunities that are implicit
14 // in the LLVM IR and exposed by the various codegen lowering phases.
15 //
16 //===----------------------------------------------------------------------===//
17
18 #include "llvm/ADT/APFloat.h"
19 #include "llvm/ADT/APInt.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/IntervalMap.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/ADT/SmallBitVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/SmallSet.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/Statistic.h"
30 #include "llvm/Analysis/AliasAnalysis.h"
31 #include "llvm/Analysis/MemoryLocation.h"
32 #include "llvm/Analysis/TargetLibraryInfo.h"
33 #include "llvm/Analysis/VectorUtils.h"
34 #include "llvm/CodeGen/DAGCombine.h"
35 #include "llvm/CodeGen/ISDOpcodes.h"
36 #include "llvm/CodeGen/MachineFunction.h"
37 #include "llvm/CodeGen/MachineMemOperand.h"
38 #include "llvm/CodeGen/RuntimeLibcalls.h"
39 #include "llvm/CodeGen/SelectionDAG.h"
40 #include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
41 #include "llvm/CodeGen/SelectionDAGNodes.h"
42 #include "llvm/CodeGen/SelectionDAGTargetInfo.h"
43 #include "llvm/CodeGen/TargetLowering.h"
44 #include "llvm/CodeGen/TargetRegisterInfo.h"
45 #include "llvm/CodeGen/TargetSubtargetInfo.h"
46 #include "llvm/CodeGen/ValueTypes.h"
47 #include "llvm/IR/Attributes.h"
48 #include "llvm/IR/Constant.h"
49 #include "llvm/IR/DataLayout.h"
50 #include "llvm/IR/DerivedTypes.h"
51 #include "llvm/IR/Function.h"
52 #include "llvm/IR/Metadata.h"
53 #include "llvm/Support/Casting.h"
54 #include "llvm/Support/CodeGen.h"
55 #include "llvm/Support/CommandLine.h"
56 #include "llvm/Support/Compiler.h"
57 #include "llvm/Support/Debug.h"
58 #include "llvm/Support/ErrorHandling.h"
59 #include "llvm/Support/KnownBits.h"
60 #include "llvm/Support/MachineValueType.h"
61 #include "llvm/Support/MathExtras.h"
62 #include "llvm/Support/raw_ostream.h"
63 #include "llvm/Target/TargetMachine.h"
64 #include "llvm/Target/TargetOptions.h"
65 #include <algorithm>
66 #include <cassert>
67 #include <cstdint>
68 #include <functional>
69 #include <iterator>
70 #include <optional>
71 #include <string>
72 #include <tuple>
73 #include <utility>
74 #include <variant>
75
76 using namespace llvm;
77
78 #define DEBUG_TYPE "dagcombine"
79
80 STATISTIC(NodesCombined , "Number of dag nodes combined");
81 STATISTIC(PreIndexedNodes , "Number of pre-indexed nodes created");
82 STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created");
83 STATISTIC(OpsNarrowed , "Number of load/op/store narrowed");
84 STATISTIC(LdStFP2Int , "Number of fp load/store pairs transformed to int");
85 STATISTIC(SlicedLoads, "Number of load sliced");
86 STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops");
87
88 static cl::opt<bool>
89 CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
90 cl::desc("Enable DAG combiner's use of IR alias analysis"));
91
92 static cl::opt<bool>
93 UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(true),
94 cl::desc("Enable DAG combiner's use of TBAA"));
95
96 #ifndef NDEBUG
97 static cl::opt<std::string>
98 CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden,
99 cl::desc("Only use DAG-combiner alias analysis in this"
100 " function"));
101 #endif
102
103 /// Hidden option to stress test load slicing, i.e., when this option
104 /// is enabled, load slicing bypasses most of its profitability guards.
105 static cl::opt<bool>
106 StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden,
107 cl::desc("Bypass the profitability model of load slicing"),
108 cl::init(false));
109
110 static cl::opt<bool>
111 MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(true),
112 cl::desc("DAG combiner may split indexing from loads"));
113
114 static cl::opt<bool>
115 EnableStoreMerging("combiner-store-merging", cl::Hidden, cl::init(true),
116 cl::desc("DAG combiner enable merging multiple stores "
117 "into a wider store"));
118
119 static cl::opt<unsigned> TokenFactorInlineLimit(
120 "combiner-tokenfactor-inline-limit", cl::Hidden, cl::init(2048),
121 cl::desc("Limit the number of operands to inline for Token Factors"));
122
123 static cl::opt<unsigned> StoreMergeDependenceLimit(
124 "combiner-store-merge-dependence-limit", cl::Hidden, cl::init(10),
125 cl::desc("Limit the number of times for the same StoreNode and RootNode "
126 "to bail out in store merging dependence check"));
127
128 static cl::opt<bool> EnableReduceLoadOpStoreWidth(
129 "combiner-reduce-load-op-store-width", cl::Hidden, cl::init(true),
130 cl::desc("DAG combiner enable reducing the width of load/op/store "
131 "sequence"));
132
133 static cl::opt<bool> EnableShrinkLoadReplaceStoreWithStore(
134 "combiner-shrink-load-replace-store-with-store", cl::Hidden, cl::init(true),
135 cl::desc("DAG combiner enable load/<replace bytes>/store with "
136 "a narrower store"));
137
138 static cl::opt<bool> EnableVectorFCopySignExtendRound(
139 "combiner-vector-fcopysign-extend-round", cl::Hidden, cl::init(false),
140 cl::desc(
141 "Enable merging extends and rounds into FCOPYSIGN on vector types"));
142
143 namespace {
144
145 class DAGCombiner {
146 SelectionDAG &DAG;
147 const TargetLowering &TLI;
148 const SelectionDAGTargetInfo *STI;
149 CombineLevel Level = BeforeLegalizeTypes;
150 CodeGenOpt::Level OptLevel;
151 bool LegalDAG = false;
152 bool LegalOperations = false;
153 bool LegalTypes = false;
154 bool ForCodeSize;
155 bool DisableGenericCombines;
156
157 /// Worklist of all of the nodes that need to be simplified.
158 ///
159 /// This must behave as a stack -- new nodes to process are pushed onto the
160 /// back and when processing we pop off of the back.
161 ///
162 /// The worklist will not contain duplicates but may contain null entries
163 /// due to nodes being deleted from the underlying DAG.
164 SmallVector<SDNode *, 64> Worklist;
165
166 /// Mapping from an SDNode to its position on the worklist.
167 ///
168 /// This is used to find and remove nodes from the worklist (by nulling
169 /// them) when they are deleted from the underlying DAG. It relies on
170 /// stable indices of nodes within the worklist.
171 DenseMap<SDNode *, unsigned> WorklistMap;
172 /// This records all nodes attempted to be added to the worklist since we
173 /// last considered a new worklist entry. Because we do not add duplicate
174 /// nodes to the worklist, this is different from the tail of the worklist.
175 SmallSetVector<SDNode *, 32> PruningList;
176
177 /// Set of nodes which have been combined (at least once).
178 ///
179 /// This is used to allow us to reliably add any operands of a DAG node
180 /// which have not yet been combined to the worklist.
181 SmallPtrSet<SDNode *, 32> CombinedNodes;
182
183 /// Map from candidate StoreNode to the pair of RootNode and count.
184 /// The count is used to track how many times we have seen the StoreNode
185 /// with the same RootNode bail out in dependence check. If we have seen
186 /// the bail out for the same pair many times over a limit, we won't
187 /// consider the StoreNode with the same RootNode as store merging
188 /// candidate again.
189 DenseMap<SDNode *, std::pair<SDNode *, unsigned>> StoreRootCountMap;
190
191 // AA - Used for DAG load/store alias analysis.
192 AliasAnalysis *AA;
193
194 /// When an instruction is simplified, add all users of the instruction to
195 /// the worklist because they might now be further simplified.
196 void AddUsersToWorklist(SDNode *N) {
197 for (SDNode *Node : N->uses())
198 AddToWorklist(Node);
199 }
200
201 /// Convenient shorthand to add a node and all of its users to the worklist.
202 void AddToWorklistWithUsers(SDNode *N) {
203 AddUsersToWorklist(N);
204 AddToWorklist(N);
205 }
206
207 // Prune potentially dangling nodes. This is called after
208 // any visit to a node, but should also be called during a visit after any
209 // failed combine which may have created a DAG node.
210 void clearAddedDanglingWorklistEntries() {
211 // Check any nodes added to the worklist to see if they are prunable.
212 while (!PruningList.empty()) {
213 auto *N = PruningList.pop_back_val();
214 if (N->use_empty())
215 recursivelyDeleteUnusedNodes(N);
216 }
217 }
218
219 SDNode *getNextWorklistEntry() {
220 // Before we do any work, remove nodes that are not in use.
221 clearAddedDanglingWorklistEntries();
222 SDNode *N = nullptr;
223 // The Worklist holds the SDNodes in order, but it may contain null
224 // entries.
225 while (!N && !Worklist.empty()) {
226 N = Worklist.pop_back_val();
227 }
228
229 if (N) {
230 bool GoodWorklistEntry = WorklistMap.erase(N);
231 (void)GoodWorklistEntry;
232 assert(GoodWorklistEntry &&
233 "Found a worklist entry without a corresponding map entry!");
234 }
235 return N;
236 }
237
238 /// Call the node-specific routine that folds each particular type of node.
239 SDValue visit(SDNode *N);
240
241 public:
242 DAGCombiner(SelectionDAG &D, AliasAnalysis *AA, CodeGenOpt::Level OL)
243 : DAG(D), TLI(D.getTargetLoweringInfo()),
244 STI(D.getSubtarget().getSelectionDAGInfo()), OptLevel(OL), AA(AA) {
245 ForCodeSize = DAG.shouldOptForSize();
246 DisableGenericCombines = STI && STI->disableGenericCombines(OptLevel);
247
248 MaximumLegalStoreInBits = 0;
249 // We use the minimum store size here, since that's all we can guarantee
250 // for the scalable vector types.
251 for (MVT VT : MVT::all_valuetypes())
252 if (EVT(VT).isSimple() && VT != MVT::Other &&
253 TLI.isTypeLegal(EVT(VT)) &&
254 VT.getSizeInBits().getKnownMinValue() >= MaximumLegalStoreInBits)
255 MaximumLegalStoreInBits = VT.getSizeInBits().getKnownMinValue();
256 }
257
258 void ConsiderForPruning(SDNode *N) {
259 // Mark this for potential pruning.
260 PruningList.insert(N);
261 }
262
263 /// Add to the worklist making sure its instance is at the back (next to be
264 /// processed.)
265 void AddToWorklist(SDNode *N) {
266 assert(N->getOpcode() != ISD::DELETED_NODE &&
267 "Deleted Node added to Worklist");
268
269 // Skip handle nodes as they can't usefully be combined and confuse the
270 // zero-use deletion strategy.
271 if (N->getOpcode() == ISD::HANDLENODE)
272 return;
273
274 ConsiderForPruning(N);
275
276 if (WorklistMap.insert(std::make_pair(N, Worklist.size())).second)
277 Worklist.push_back(N);
278 }
279
280 /// Remove all instances of N from the worklist.
281 void removeFromWorklist(SDNode *N) {
282 CombinedNodes.erase(N);
283 PruningList.remove(N);
284 StoreRootCountMap.erase(N);
285
286 auto It = WorklistMap.find(N);
287 if (It == WorklistMap.end())
288 return; // Not in the worklist.
289
290 // Null out the entry rather than erasing it to avoid a linear operation.
291 Worklist[It->second] = nullptr;
292 WorklistMap.erase(It);
293 }
294
295 void deleteAndRecombine(SDNode *N);
296 bool recursivelyDeleteUnusedNodes(SDNode *N);
297
298 /// Replaces all uses of the results of one DAG node with new values.
299 SDValue CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
300 bool AddTo = true);
301
302 /// Replaces all uses of the results of one DAG node with new values.
303 SDValue CombineTo(SDNode *N, SDValue Res, bool AddTo = true) {
304 return CombineTo(N, &Res, 1, AddTo);
305 }
306
307 /// Replaces all uses of the results of one DAG node with new values.
308 SDValue CombineTo(SDNode *N, SDValue Res0, SDValue Res1,
309 bool AddTo = true) {
310 SDValue To[] = { Res0, Res1 };
311 return CombineTo(N, To, 2, AddTo);
312 }
313
314 void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);
315
316 private:
317 unsigned MaximumLegalStoreInBits;
318
319 /// Check the specified integer node value to see if it can be simplified or
320 /// if things it uses can be simplified by bit propagation.
321 /// If so, return true.
322 bool SimplifyDemandedBits(SDValue Op) {
323 unsigned BitWidth = Op.getScalarValueSizeInBits();
324 APInt DemandedBits = APInt::getAllOnes(BitWidth);
325 return SimplifyDemandedBits(Op, DemandedBits);
326 }
327
328 bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits) {
329 TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
330 KnownBits Known;
331 if (!TLI.SimplifyDemandedBits(Op, DemandedBits, Known, TLO, 0, false))
332 return false;
333
334 // Revisit the node.
335 AddToWorklist(Op.getNode());
336
337 CommitTargetLoweringOpt(TLO);
338 return true;
339 }
340
341 /// Check the specified vector node value to see if it can be simplified or
342 /// if things it uses can be simplified as it only uses some of the
343 /// elements. If so, return true.
344 bool SimplifyDemandedVectorElts(SDValue Op) {
345 // TODO: For now just pretend it cannot be simplified.
346 if (Op.getValueType().isScalableVector())
347 return false;
348
349 unsigned NumElts = Op.getValueType().getVectorNumElements();
350 APInt DemandedElts = APInt::getAllOnes(NumElts);
351 return SimplifyDemandedVectorElts(Op, DemandedElts);
352 }
353
354 bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
355 const APInt &DemandedElts,
356 bool AssumeSingleUse = false);
357 bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedElts,
358 bool AssumeSingleUse = false);
359
360 bool CombineToPreIndexedLoadStore(SDNode *N);
361 bool CombineToPostIndexedLoadStore(SDNode *N);
362 SDValue SplitIndexingFromLoad(LoadSDNode *LD);
363 bool SliceUpLoad(SDNode *N);
364
365 // Scalars have size 0 to distinguish from singleton vectors.
366 SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD);
367 bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
368 bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);
369
370 /// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed
371 /// load.
372 ///
373 /// \param EVE ISD::EXTRACT_VECTOR_ELT to be replaced.
374 /// \param InVecVT type of the input vector to EVE with bitcasts resolved.
375 /// \param EltNo index of the vector element to load.
376 /// \param OriginalLoad load that EVE came from to be replaced.
377 /// \returns EVE on success SDValue() on failure.
378 SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
379 SDValue EltNo,
380 LoadSDNode *OriginalLoad);
381 void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
382 SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
383 SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
384 SDValue ZExtPromoteOperand(SDValue Op, EVT PVT);
385 SDValue PromoteIntBinOp(SDValue Op);
386 SDValue PromoteIntShiftOp(SDValue Op);
387 SDValue PromoteExtend(SDValue Op);
388 bool PromoteLoad(SDValue Op);
389
390 SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
391 SDValue RHS, SDValue True, SDValue False,
392 ISD::CondCode CC);
393
394 /// Call the node-specific routine that knows how to fold each
395 /// particular type of node. If that doesn't do anything, try the
396 /// target-specific DAG combines.
397 SDValue combine(SDNode *N);
398
399 // Visitation implementation - Implement dag node combining for different
400 // node types. The semantics are as follows:
401 // Return Value:
402 // SDValue.getNode() == 0 - No change was made
403 // SDValue.getNode() == N - N was replaced, is dead and has been handled.
404 // otherwise - N should be replaced by the returned Operand.
405 //
406 SDValue visitTokenFactor(SDNode *N);
407 SDValue visitMERGE_VALUES(SDNode *N);
408 SDValue visitADD(SDNode *N);
409 SDValue visitADDLike(SDNode *N);
410 SDValue visitADDLikeCommutative(SDValue N0, SDValue N1, SDNode *LocReference);
411 SDValue visitSUB(SDNode *N);
412 SDValue visitADDSAT(SDNode *N);
413 SDValue visitSUBSAT(SDNode *N);
414 SDValue visitADDC(SDNode *N);
415 SDValue visitADDO(SDNode *N);
416 SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N);
417 SDValue visitSUBC(SDNode *N);
418 SDValue visitSUBO(SDNode *N);
419 SDValue visitADDE(SDNode *N);
420 SDValue visitADDCARRY(SDNode *N);
421 SDValue visitSADDO_CARRY(SDNode *N);
422 SDValue visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn, SDNode *N);
423 SDValue visitSUBE(SDNode *N);
424 SDValue visitSUBCARRY(SDNode *N);
425 SDValue visitSSUBO_CARRY(SDNode *N);
426 SDValue visitMUL(SDNode *N);
427 SDValue visitMULFIX(SDNode *N);
428 SDValue useDivRem(SDNode *N);
429 SDValue visitSDIV(SDNode *N);
430 SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N);
431 SDValue visitUDIV(SDNode *N);
432 SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N);
433 SDValue visitREM(SDNode *N);
434 SDValue visitMULHU(SDNode *N);
435 SDValue visitMULHS(SDNode *N);
436 SDValue visitAVG(SDNode *N);
437 SDValue visitSMUL_LOHI(SDNode *N);
438 SDValue visitUMUL_LOHI(SDNode *N);
439 SDValue visitMULO(SDNode *N);
440 SDValue visitIMINMAX(SDNode *N);
441 SDValue visitAND(SDNode *N);
442 SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N);
443 SDValue visitOR(SDNode *N);
444 SDValue visitORLike(SDValue N0, SDValue N1, SDNode *N);
445 SDValue visitXOR(SDNode *N);
446 SDValue SimplifyVCastOp(SDNode *N, const SDLoc &DL);
447 SDValue SimplifyVBinOp(SDNode *N, const SDLoc &DL);
448 SDValue visitSHL(SDNode *N);
449 SDValue visitSRA(SDNode *N);
450 SDValue visitSRL(SDNode *N);
451 SDValue visitFunnelShift(SDNode *N);
452 SDValue visitSHLSAT(SDNode *N);
453 SDValue visitRotate(SDNode *N);
454 SDValue visitABS(SDNode *N);
455 SDValue visitBSWAP(SDNode *N);
456 SDValue visitBITREVERSE(SDNode *N);
457 SDValue visitCTLZ(SDNode *N);
458 SDValue visitCTLZ_ZERO_UNDEF(SDNode *N);
459 SDValue visitCTTZ(SDNode *N);
460 SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
461 SDValue visitCTPOP(SDNode *N);
462 SDValue visitSELECT(SDNode *N);
463 SDValue visitVSELECT(SDNode *N);
464 SDValue visitSELECT_CC(SDNode *N);
465 SDValue visitSETCC(SDNode *N);
466 SDValue visitSETCCCARRY(SDNode *N);
467 SDValue visitSIGN_EXTEND(SDNode *N);
468 SDValue visitZERO_EXTEND(SDNode *N);
469 SDValue visitANY_EXTEND(SDNode *N);
470 SDValue visitAssertExt(SDNode *N);
471 SDValue visitAssertAlign(SDNode *N);
472 SDValue visitSIGN_EXTEND_INREG(SDNode *N);
473 SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
474 SDValue visitTRUNCATE(SDNode *N);
475 SDValue visitBITCAST(SDNode *N);
476 SDValue visitFREEZE(SDNode *N);
477 SDValue visitBUILD_PAIR(SDNode *N);
478 SDValue visitFADD(SDNode *N);
479 SDValue visitSTRICT_FADD(SDNode *N);
480 SDValue visitFSUB(SDNode *N);
481 SDValue visitFMUL(SDNode *N);
482 SDValue visitFMA(SDNode *N);
483 SDValue visitFDIV(SDNode *N);
484 SDValue visitFREM(SDNode *N);
485 SDValue visitFSQRT(SDNode *N);
486 SDValue visitFCOPYSIGN(SDNode *N);
487 SDValue visitFPOW(SDNode *N);
488 SDValue visitSINT_TO_FP(SDNode *N);
489 SDValue visitUINT_TO_FP(SDNode *N);
490 SDValue visitFP_TO_SINT(SDNode *N);
491 SDValue visitFP_TO_UINT(SDNode *N);
492 SDValue visitFP_ROUND(SDNode *N);
493 SDValue visitFP_EXTEND(SDNode *N);
494 SDValue visitFNEG(SDNode *N);
495 SDValue visitFABS(SDNode *N);
496 SDValue visitFCEIL(SDNode *N);
497 SDValue visitFTRUNC(SDNode *N);
498 SDValue visitFFLOOR(SDNode *N);
499 SDValue visitFMinMax(SDNode *N);
500 SDValue visitBRCOND(SDNode *N);
501 SDValue visitBR_CC(SDNode *N);
502 SDValue visitLOAD(SDNode *N);
503
504 SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain);
505 SDValue replaceStoreOfFPConstant(StoreSDNode *ST);
506
507 bool refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(SDNode *N);
508
509 SDValue visitSTORE(SDNode *N);
510 SDValue visitLIFETIME_END(SDNode *N);
511 SDValue visitINSERT_VECTOR_ELT(SDNode *N);
512 SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
513 SDValue visitBUILD_VECTOR(SDNode *N);
514 SDValue visitCONCAT_VECTORS(SDNode *N);
515 SDValue visitEXTRACT_SUBVECTOR(SDNode *N);
516 SDValue visitVECTOR_SHUFFLE(SDNode *N);
517 SDValue visitSCALAR_TO_VECTOR(SDNode *N);
518 SDValue visitINSERT_SUBVECTOR(SDNode *N);
519 SDValue visitMLOAD(SDNode *N);
520 SDValue visitMSTORE(SDNode *N);
521 SDValue visitMGATHER(SDNode *N);
522 SDValue visitMSCATTER(SDNode *N);
523 SDValue visitVPGATHER(SDNode *N);
524 SDValue visitVPSCATTER(SDNode *N);
525 SDValue visitFP_TO_FP16(SDNode *N);
526 SDValue visitFP16_TO_FP(SDNode *N);
527 SDValue visitFP_TO_BF16(SDNode *N);
528 SDValue visitVECREDUCE(SDNode *N);
529 SDValue visitVPOp(SDNode *N);
530
531 SDValue visitFADDForFMACombine(SDNode *N);
532 SDValue visitFSUBForFMACombine(SDNode *N);
533 SDValue visitFMULForFMADistributiveCombine(SDNode *N);
534
535 SDValue XformToShuffleWithZero(SDNode *N);
536 bool reassociationCanBreakAddressingModePattern(unsigned Opc,
537 const SDLoc &DL,
538 SDNode *N,
539 SDValue N0,
540 SDValue N1);
541 SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0,
542 SDValue N1);
543 SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
544 SDValue N1, SDNodeFlags Flags);
545
546 SDValue visitShiftByConstant(SDNode *N);
547
548 SDValue foldSelectOfConstants(SDNode *N);
549 SDValue foldVSelectOfConstants(SDNode *N);
550 SDValue foldBinOpIntoSelect(SDNode *BO);
551 bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
552 SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N);
553 SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
554 SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
555 SDValue N2, SDValue N3, ISD::CondCode CC,
556 bool NotExtCompare = false);
557 SDValue convertSelectOfFPConstantsToLoadOffset(
558 const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
559 ISD::CondCode CC);
560 SDValue foldSignChangeInBitcast(SDNode *N);
561 SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1,
562 SDValue N2, SDValue N3, ISD::CondCode CC);
563 SDValue foldSelectOfBinops(SDNode *N);
564 SDValue foldSextSetcc(SDNode *N);
565 SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
566 const SDLoc &DL);
567 SDValue foldSubToUSubSat(EVT DstVT, SDNode *N);
568 SDValue foldABSToABD(SDNode *N);
569 SDValue unfoldMaskedMerge(SDNode *N);
570 SDValue unfoldExtremeBitClearingToShifts(SDNode *N);
571 SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
572 const SDLoc &DL, bool foldBooleans);
573 SDValue rebuildSetCC(SDValue N);
574
575 bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
576 SDValue &CC, bool MatchStrict = false) const;
577 bool isOneUseSetCC(SDValue N) const;
578
579 SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
580 unsigned HiOp);
581 SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
582 SDValue CombineExtLoad(SDNode *N);
583 SDValue CombineZExtLogicopShiftLoad(SDNode *N);
584 SDValue combineRepeatedFPDivisors(SDNode *N);
585 SDValue mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex);
586 SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
587 SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT);
588 SDValue BuildSDIV(SDNode *N);
589 SDValue BuildSDIVPow2(SDNode *N);
590 SDValue BuildUDIV(SDNode *N);
591 SDValue BuildSREMPow2(SDNode *N);
592 SDValue buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N);
593 SDValue BuildLogBase2(SDValue V, const SDLoc &DL);
594 SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
595 SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
596 SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
597 SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, bool Recip);
598 SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations,
599 SDNodeFlags Flags, bool Reciprocal);
600 SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations,
601 SDNodeFlags Flags, bool Reciprocal);
602 SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
603 bool DemandHighBits = true);
604 SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
605 SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
606 SDValue InnerPos, SDValue InnerNeg, bool HasPos,
607 unsigned PosOpcode, unsigned NegOpcode,
608 const SDLoc &DL);
609 SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
610 SDValue InnerPos, SDValue InnerNeg, bool HasPos,
611 unsigned PosOpcode, unsigned NegOpcode,
612 const SDLoc &DL);
613 SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
614 SDValue MatchLoadCombine(SDNode *N);
615 SDValue mergeTruncStores(StoreSDNode *N);
616 SDValue reduceLoadWidth(SDNode *N);
617 SDValue ReduceLoadOpStoreWidth(SDNode *N);
618 SDValue splitMergedValStore(StoreSDNode *ST);
619 SDValue TransformFPLoadStorePair(SDNode *N);
620 SDValue convertBuildVecZextToZext(SDNode *N);
621 SDValue convertBuildVecZextToBuildVecWithZeros(SDNode *N);
622 SDValue reduceBuildVecExtToExtBuildVec(SDNode *N);
623 SDValue reduceBuildVecTruncToBitCast(SDNode *N);
624 SDValue reduceBuildVecToShuffle(SDNode *N);
625 SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N,
626 ArrayRef<int> VectorMask, SDValue VecIn1,
627 SDValue VecIn2, unsigned LeftIdx,
628 bool DidSplitVec);
629 SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast);
630
631 /// Walk up chain skipping non-aliasing memory nodes,
632 /// looking for aliasing nodes and adding them to the Aliases vector.
633 void GatherAllAliases(SDNode *N, SDValue OriginalChain,
634 SmallVectorImpl<SDValue> &Aliases);
635
636 /// Return true if there is any possibility that the two addresses overlap.
637 bool mayAlias(SDNode *Op0, SDNode *Op1) const;
638
639 /// Walk up chain skipping non-aliasing memory nodes, looking for a better
640 /// chain (aliasing node.)
641 SDValue FindBetterChain(SDNode *N, SDValue Chain);
642
643 /// Try to replace a store and any possibly adjacent stores on
644 /// consecutive chains with better chains. Return true only if St is
645 /// replaced.
646 ///
647 /// Notice that other chains may still be replaced even if the function
648 /// returns false.
649 bool findBetterNeighborChains(StoreSDNode *St);
650
651 // Helper for findBetterNeighborChains. Walk up store chain add additional
652 // chained stores that do not overlap and can be parallelized.
653 bool parallelizeChainedStores(StoreSDNode *St);
654
655 /// Holds a pointer to an LSBaseSDNode as well as information on where it
656 /// is located in a sequence of memory operations connected by a chain.
657 struct MemOpLink {
658 // Ptr to the mem node.
659 LSBaseSDNode *MemNode;
660
661 // Offset from the base ptr.
662 int64_t OffsetFromBase;
663
664 MemOpLink(LSBaseSDNode *N, int64_t Offset)
665 : MemNode(N), OffsetFromBase(Offset) {}
666 };
667
668 // Classify the origin of a stored value.
669 enum class StoreSource { Unknown, Constant, Extract, Load };
670 StoreSource getStoreSource(SDValue StoreVal) {
671 switch (StoreVal.getOpcode()) {
672 case ISD::Constant:
673 case ISD::ConstantFP:
674 return StoreSource::Constant;
675 case ISD::EXTRACT_VECTOR_ELT:
676 case ISD::EXTRACT_SUBVECTOR:
677 return StoreSource::Extract;
678 case ISD::LOAD:
679 return StoreSource::Load;
680 default:
681 return StoreSource::Unknown;
682 }
683 }
684
685 /// This is a helper function for visitMUL to check the profitability
686 /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
687 /// MulNode is the original multiply, AddNode is (add x, c1),
688 /// and ConstNode is c2.
689 bool isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
690 SDValue ConstNode);
691
692 /// This is a helper function for visitAND and visitZERO_EXTEND. Returns
693 /// true if the (and (load x) c) pattern matches an extload. ExtVT returns
694 /// the type of the loaded value to be extended.
695 bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
696 EVT LoadResultTy, EVT &ExtVT);
697
698 /// Helper function to calculate whether the given Load/Store can have its
699 /// width reduced to ExtVT.
700 bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType,
701 EVT &MemVT, unsigned ShAmt = 0);
702
703 /// Used by BackwardsPropagateMask to find suitable loads.
704 bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads,
705 SmallPtrSetImpl<SDNode*> &NodesWithConsts,
706 ConstantSDNode *Mask, SDNode *&NodeToMask);
707 /// Attempt to propagate a given AND node back to load leaves so that they
708 /// can be combined into narrow loads.
709 bool BackwardsPropagateMask(SDNode *N);
710
711 /// Helper function for mergeConsecutiveStores which merges the component
712 /// store chains.
713 SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
714 unsigned NumStores);
715
716 /// This is a helper function for mergeConsecutiveStores. When the source
717 /// elements of the consecutive stores are all constants or all extracted
718 /// vector elements, try to merge them into one larger store introducing
719 /// bitcasts if necessary. \return True if a merged store was created.
720 bool mergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
721 EVT MemVT, unsigned NumStores,
722 bool IsConstantSrc, bool UseVector,
723 bool UseTrunc);
724
725 /// This is a helper function for mergeConsecutiveStores. Stores that
726 /// potentially may be merged with St are placed in StoreNodes. RootNode is
727 /// a chain predecessor to all store candidates.
728 void getStoreMergeCandidates(StoreSDNode *St,
729 SmallVectorImpl<MemOpLink> &StoreNodes,
730 SDNode *&Root);
731
732 /// Helper function for mergeConsecutiveStores. Checks if candidate stores
733 /// have indirect dependency through their operands. RootNode is the
734 /// predecessor to all stores calculated by getStoreMergeCandidates and is
735 /// used to prune the dependency check. \return True if safe to merge.
736 bool checkMergeStoreCandidatesForDependencies(
737 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
738 SDNode *RootNode);
739
740 /// This is a helper function for mergeConsecutiveStores. Given a list of
741 /// store candidates, find the first N that are consecutive in memory.
742 /// Returns 0 if there are not at least 2 consecutive stores to try merging.
743 unsigned getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
744 int64_t ElementSizeBytes) const;
745
746 /// This is a helper function for mergeConsecutiveStores. It is used for
747 /// store chains that are composed entirely of constant values.
748 bool tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> &StoreNodes,
749 unsigned NumConsecutiveStores,
750 EVT MemVT, SDNode *Root, bool AllowVectors);
751
752 /// This is a helper function for mergeConsecutiveStores. It is used for
753 /// store chains that are composed entirely of extracted vector elements.
754 /// When extracting multiple vector elements, try to store them in one
755 /// vector store rather than a sequence of scalar stores.
756 bool tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> &StoreNodes,
757 unsigned NumConsecutiveStores, EVT MemVT,
758 SDNode *Root);
759
760 /// This is a helper function for mergeConsecutiveStores. It is used for
761 /// store chains that are composed entirely of loaded values.
762 bool tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
763 unsigned NumConsecutiveStores, EVT MemVT,
764 SDNode *Root, bool AllowVectors,
765 bool IsNonTemporalStore, bool IsNonTemporalLoad);
766
767 /// Merge consecutive store operations into a wide store.
768 /// This optimization uses wide integers or vectors when possible.
769 /// \return true if stores were merged.
770 bool mergeConsecutiveStores(StoreSDNode *St);
771
772 /// Try to transform a truncation where C is a constant:
773 /// (trunc (and X, C)) -> (and (trunc X), (trunc C))
774 ///
775 /// \p N needs to be a truncation and its first operand an AND. Other
776 /// requirements are checked by the function (e.g. that the trunc is
777 /// single-use); if they are not met, an empty SDValue is returned.
778 SDValue distributeTruncateThroughAnd(SDNode *N);
779
780 /// Helper function to determine whether the target supports the operation
781 /// given by \p Opcode for type \p VT, that is, whether the operation
782 /// is legal or custom before legalizing operations, and whether it is
783 /// legal (but not custom) after legalization.
784 bool hasOperation(unsigned Opcode, EVT VT) {
785 return TLI.isOperationLegalOrCustom(Opcode, VT, LegalOperations);
786 }
787
788 public:
789 /// Runs the dag combiner on all nodes in the work list
790 void Run(CombineLevel AtLevel);
791
792 SelectionDAG &getDAG() const { return DAG; }
793
794 /// Returns a type large enough to hold any valid shift amount - before type
795 /// legalization these can be huge.
796 EVT getShiftAmountTy(EVT LHSTy) {
797 assert(LHSTy.isInteger() && "Shift amount is not an integer type!");
798 return TLI.getShiftAmountTy(LHSTy, DAG.getDataLayout(), LegalTypes);
799 }
800
801 /// This method returns true if we are running before type legalization or
802 /// if the specified VT is legal.
803 bool isTypeLegal(const EVT &VT) {
804 if (!LegalTypes) return true;
805 return TLI.isTypeLegal(VT);
806 }
807
808 /// Convenience wrapper around TargetLowering::getSetCCResultType
809 EVT getSetCCResultType(EVT VT) const {
810 return TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
811 }
812
813 void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
814 SDValue OrigLoad, SDValue ExtLoad,
815 ISD::NodeType ExtType);
816 };
817
818 /// This class is a DAGUpdateListener that removes any deleted
819 /// nodes from the worklist.
820 class WorklistRemover : public SelectionDAG::DAGUpdateListener {
821 DAGCombiner &DC;
822
823 public:
824 explicit WorklistRemover(DAGCombiner &dc)
825 : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}
826
827 void NodeDeleted(SDNode *N, SDNode *E) override {
828 DC.removeFromWorklist(N);
829 }
830 };
831
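/// This class is a DAGUpdateListener that adds each newly created node to the
/// pruning list so that nodes which end up unused can be cleaned up later.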
832 class WorklistInserter : public SelectionDAG::DAGUpdateListener {
833 DAGCombiner &DC;
834
835 public:
836 explicit WorklistInserter(DAGCombiner &dc)
837 : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}
838
839 // FIXME: Ideally we could add N to the worklist, but this causes exponential
840 // compile time costs in large DAGs, e.g. Halide.
841 void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
842 };
843
844 } // end anonymous namespace
845
846 //===----------------------------------------------------------------------===//
847 // TargetLowering::DAGCombinerInfo implementation
848 //===----------------------------------------------------------------------===//
849
850 void TargetLowering::DAGCombinerInfo::AddToWorklist(SDNode *N) {
851 ((DAGCombiner*)DC)->AddToWorklist(N);
852 }
853
854 SDValue TargetLowering::DAGCombinerInfo::
855 CombineTo(SDNode *N, ArrayRef<SDValue> To, bool AddTo) {
856 return ((DAGCombiner*)DC)->CombineTo(N, &To[0], To.size(), AddTo);
857 }
858
859 SDValue TargetLowering::DAGCombinerInfo::
860 CombineTo(SDNode *N, SDValue Res, bool AddTo) {
861 return ((DAGCombiner*)DC)->CombineTo(N, Res, AddTo);
862 }
863
864 SDValue TargetLowering::DAGCombinerInfo::
865 CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo) {
866 return ((DAGCombiner*)DC)->CombineTo(N, Res0, Res1, AddTo);
867 }
868
869 bool TargetLowering::DAGCombinerInfo::
870 recursivelyDeleteUnusedNodes(SDNode *N) {
871 return ((DAGCombiner*)DC)->recursivelyDeleteUnusedNodes(N);
872 }
873
874 void TargetLowering::DAGCombinerInfo::
875 CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
876 return ((DAGCombiner*)DC)->CommitTargetLoweringOpt(TLO);
877 }
878
879 //===----------------------------------------------------------------------===//
880 // Helper Functions
881 //===----------------------------------------------------------------------===//
882
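// Remove \p N from the worklist, queue any of its operands that may become
// dead or further simplifiable for revisiting, and delete it from the DAG.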
883 void DAGCombiner::deleteAndRecombine(SDNode *N) {
884 removeFromWorklist(N);
885
886 // If the operands of this node are only used by the node, they will now be
887 // dead. Make sure to re-visit them and recursively delete dead nodes.
888 for (const SDValue &Op : N->ops())
889 // For an operand generating multiple values, one of the values may
890 // become dead allowing further simplification (e.g. split index
891 // arithmetic from an indexed load).
892 if (Op->hasOneUse() || Op->getNumValues() > 1)
893 AddToWorklist(Op.getNode());
894
895 DAG.DeleteNode(N);
896 }
897
898 // APInts must be the same size for most operations; this helper
899 // function zero-extends the shorter of the pair so that they match.
900 // We provide an Offset so that we can create bitwidths that won't overflow.
901 static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
902 unsigned Bits = Offset + std::max(LHS.getBitWidth(), RHS.getBitWidth());
903 LHS = LHS.zext(Bits);
904 RHS = RHS.zext(Bits);
905 }
906
907 // Return true if this node is a setcc, or is a select_cc
908 // that selects between the target values used for true and false, making it
909 // equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
910 // the appropriate nodes based on the type of node we are checking. This
911 // simplifies life a bit for the callers.
912 bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
913 SDValue &CC, bool MatchStrict) const {
914 if (N.getOpcode() == ISD::SETCC) {
915 LHS = N.getOperand(0);
916 RHS = N.getOperand(1);
917 CC = N.getOperand(2);
918 return true;
919 }
920
921 if (MatchStrict &&
922 (N.getOpcode() == ISD::STRICT_FSETCC ||
923 N.getOpcode() == ISD::STRICT_FSETCCS)) {
924 LHS = N.getOperand(1);
925 RHS = N.getOperand(2);
926 CC = N.getOperand(3);
927 return true;
928 }
929
930 if (N.getOpcode() != ISD::SELECT_CC || !TLI.isConstTrueVal(N.getOperand(2)) ||
931 !TLI.isConstFalseVal(N.getOperand(3)))
932 return false;
933
934 if (TLI.getBooleanContents(N.getValueType()) ==
935 TargetLowering::UndefinedBooleanContent)
936 return false;
937
938 LHS = N.getOperand(0);
939 RHS = N.getOperand(1);
940 CC = N.getOperand(4);
941 return true;
942 }
943
944 /// Return true if this is a SetCC-equivalent operation with only one use.
945 /// If this is true, it allows the users to invert the operation for free when
946 /// it is profitable to do so.
947 bool DAGCombiner::isOneUseSetCC(SDValue N) const {
948 SDValue N0, N1, N2;
949 if (isSetCCEquivalent(N, N0, N1, N2) && N->hasOneUse())
950 return true;
951 return false;
952 }
953
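// Return true if \p N is a constant splat vector whose splat value equals the
// all-ones mask for the i8, i16 or i32 element width given by \p ScalarTy.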
954 static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) {
955 if (!ScalarTy.isSimple())
956 return false;
957
958 uint64_t MaskForTy = 0ULL;
959 switch (ScalarTy.getSimpleVT().SimpleTy) {
960 case MVT::i8:
961 MaskForTy = 0xFFULL;
962 break;
963 case MVT::i16:
964 MaskForTy = 0xFFFFULL;
965 break;
966 case MVT::i32:
967 MaskForTy = 0xFFFFFFFFULL;
968 break;
969 default:
970 return false;
971 break;
972 }
973
974 APInt Val;
975 if (ISD::isConstantSplatVector(N, Val))
976 return Val.getLimitedValue() == MaskForTy;
977
978 return false;
979 }
980
981 // Determines if it is a constant integer or a splat/build vector of constant
982 // integers (and undefs).
983 // Do not permit build vector implicit truncation.
984 static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false) {
985 if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N))
986 return !(Const->isOpaque() && NoOpaques);
987 if (N.getOpcode() != ISD::BUILD_VECTOR && N.getOpcode() != ISD::SPLAT_VECTOR)
988 return false;
989 unsigned BitWidth = N.getScalarValueSizeInBits();
990 for (const SDValue &Op : N->op_values()) {
991 if (Op.isUndef())
992 continue;
993 ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Op);
994 if (!Const || Const->getAPIntValue().getBitWidth() != BitWidth ||
995 (Const->isOpaque() && NoOpaques))
996 return false;
997 }
998 return true;
999 }
1000
1001 // Determines if a BUILD_VECTOR is composed of all-constants possibly mixed with
1002 // undef's.
1003 static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) {
1004 if (V.getOpcode() != ISD::BUILD_VECTOR)
1005 return false;
1006 return isConstantOrConstantVector(V, NoOpaques) ||
1007 ISD::isBuildVectorOfConstantFPSDNodes(V.getNode());
1008 }
1009
1010 // Determine whether the index of this indexed load can be split off, i.e. index splitting is enabled and the index is not an opaque target constant.
1011 static bool canSplitIdx(LoadSDNode *LD) {
1012 return MaySplitLoadIndex &&
1013 (LD->getOperand(2).getOpcode() != ISD::TargetConstant ||
1014 !cast<ConstantSDNode>(LD->getOperand(2))->isOpaque());
1015 }
1016
1017 bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
1018 const SDLoc &DL,
1019 SDNode *N,
1020 SDValue N0,
1021 SDValue N1) {
1022 // Currently this only tries to ensure we don't undo the GEP splits done by
1023 // CodeGenPrepare when shouldConsiderGEPOffsetSplit is true. To ensure this,
1024 // we check if the following transformation would be problematic:
1025 // (load/store (add, (add, x, offset1), offset2)) ->
1026 // (load/store (add, x, offset1+offset2)).
1027
1028 // (load/store (add, (add, x, y), offset2)) ->
1029 // (load/store (add, (add, x, offset2), y)).
1030
1031 if (Opc != ISD::ADD || N0.getOpcode() != ISD::ADD)
1032 return false;
1033
1034 auto *C2 = dyn_cast<ConstantSDNode>(N1);
1035 if (!C2)
1036 return false;
1037
1038 const APInt &C2APIntVal = C2->getAPIntValue();
1039 if (C2APIntVal.getSignificantBits() > 64)
1040 return false;
1041
1042 if (auto *C1 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
1043 if (N0.hasOneUse())
1044 return false;
1045
1046 const APInt &C1APIntVal = C1->getAPIntValue();
1047 const APInt CombinedValueIntVal = C1APIntVal + C2APIntVal;
1048 if (CombinedValueIntVal.getSignificantBits() > 64)
1049 return false;
1050 const int64_t CombinedValue = CombinedValueIntVal.getSExtValue();
1051
1052 for (SDNode *Node : N->uses()) {
1053 if (auto *LoadStore = dyn_cast<MemSDNode>(Node)) {
1054 // Is x[offset2] already not a legal addressing mode? If so then
1055 // reassociating the constants breaks nothing (we test offset2 because
1056 // that's the one we hope to fold into the load or store).
1057 TargetLoweringBase::AddrMode AM;
1058 AM.HasBaseReg = true;
1059 AM.BaseOffs = C2APIntVal.getSExtValue();
1060 EVT VT = LoadStore->getMemoryVT();
1061 unsigned AS = LoadStore->getAddressSpace();
1062 Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
1063 if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
1064 continue;
1065
1066 // Would x[offset1+offset2] still be a legal addressing mode?
1067 AM.BaseOffs = CombinedValue;
1068 if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
1069 return true;
1070 }
1071 }
1072 } else {
1073 if (auto *GA = dyn_cast<GlobalAddressSDNode>(N0.getOperand(1)))
1074 if (GA->getOpcode() == ISD::GlobalAddress && TLI.isOffsetFoldingLegal(GA))
1075 return false;
1076
1077 for (SDNode *Node : N->uses()) {
1078 auto *LoadStore = dyn_cast<MemSDNode>(Node);
1079 if (!LoadStore)
1080 return false;
1081
1082 // Is x[offset2] a legal addressing mode? If so, then reassociating
1083 // the constants breaks the address pattern.
1084 TargetLoweringBase::AddrMode AM;
1085 AM.HasBaseReg = true;
1086 AM.BaseOffs = C2APIntVal.getSExtValue();
1087 EVT VT = LoadStore->getMemoryVT();
1088 unsigned AS = LoadStore->getAddressSpace();
1089 Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
1090 if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
1091 return false;
1092 }
1093 return true;
1094 }
1095
1096 return false;
1097 }
1098
1099 // Helper for DAGCombiner::reassociateOps. Try to reassociate an expression
1100 // such as (Opc N0, N1), if \p N0 is the same kind of operation as \p Opc.
1101 SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL,
1102 SDValue N0, SDValue N1) {
1103 EVT VT = N0.getValueType();
1104
1105 if (N0.getOpcode() != Opc)
1106 return SDValue();
1107
1108 SDValue N00 = N0.getOperand(0);
1109 SDValue N01 = N0.getOperand(1);
1110
1111 if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(N01))) {
1112 if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(N1))) {
1113 // Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2))
1114 if (SDValue OpNode = DAG.FoldConstantArithmetic(Opc, DL, VT, {N01, N1}))
1115 return DAG.getNode(Opc, DL, VT, N00, OpNode);
1116 return SDValue();
1117 }
1118 if (TLI.isReassocProfitable(DAG, N0, N1)) {
1119 // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
1120 // iff (op x, c1) has one use
1121 SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N00, N1);
1122 return DAG.getNode(Opc, DL, VT, OpNode, N01);
1123 }
1124 }
1125
1126 // Check for repeated operand logic simplifications.
1127 if (Opc == ISD::AND || Opc == ISD::OR) {
1128 // (N00 & N01) & N00 --> N00 & N01
1129 // (N00 & N01) & N01 --> N00 & N01
1130 // (N00 | N01) | N00 --> N00 | N01
1131 // (N00 | N01) | N01 --> N00 | N01
1132 if (N1 == N00 || N1 == N01)
1133 return N0;
1134 }
1135 if (Opc == ISD::XOR) {
1136 // (N00 ^ N01) ^ N00 --> N01
1137 if (N1 == N00)
1138 return N01;
1139 // (N00 ^ N01) ^ N01 --> N00
1140 if (N1 == N01)
1141 return N00;
1142 }
1143
1144 if (TLI.isReassocProfitable(DAG, N0, N1)) {
1145 if (N1 != N01) {
1146 // Reassociate if (op N00, N1) already exists
1147 if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N00, N1})) {
1148 // If (Op (Op N00, N1), N01) already exists, stop reassociating
1149 // to avoid an infinite loop.
1150 if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N01}))
1151 return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N01);
1152 }
1153 }
1154
1155 if (N1 != N00) {
1156 // Reassociate if (op N01, N1) already exists
1157 if (SDNode *NE = DAG.getNodeIfExists(Opc, DAG.getVTList(VT), {N01, N1})) {
1158 // If (Op (Op N01, N1), N00) already exists, stop reassociating
1159 // to avoid an infinite loop.
1160 if (!DAG.doesNodeExist(Opc, DAG.getVTList(VT), {SDValue(NE, 0), N00}))
1161 return DAG.getNode(Opc, DL, VT, SDValue(NE, 0), N00);
1162 }
1163 }
1164 }
1165
1166 return SDValue();
1167 }
1168
1169 // Try to reassociate commutative binops.
1170 SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
1171 SDValue N1, SDNodeFlags Flags) {
1172 assert(TLI.isCommutativeBinOp(Opc) && "Operation not commutative.");
1173
1174 // Floating-point reassociation is not allowed without loose FP math.
1175 if (N0.getValueType().isFloatingPoint() ||
1176 N1.getValueType().isFloatingPoint())
1177 if (!Flags.hasAllowReassociation() || !Flags.hasNoSignedZeros())
1178 return SDValue();
1179
1180 if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0, N1))
1181 return Combined;
1182 if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N1, N0))
1183 return Combined;
1184 return SDValue();
1185 }
1186
1187 SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
1188 bool AddTo) {
1189 assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
1190 ++NodesCombined;
1191 LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: ";
1192 To[0].dump(&DAG);
1193 dbgs() << " and " << NumTo - 1 << " other values\n");
1194 for (unsigned i = 0, e = NumTo; i != e; ++i)
1195 assert((!To[i].getNode() ||
1196 N->getValueType(i) == To[i].getValueType()) &&
1197 "Cannot combine value to value of different type!");
1198
1199 WorklistRemover DeadNodes(*this);
1200 DAG.ReplaceAllUsesWith(N, To);
1201 if (AddTo) {
1202 // Push the new nodes and any users onto the worklist
1203 for (unsigned i = 0, e = NumTo; i != e; ++i) {
1204 if (To[i].getNode())
1205 AddToWorklistWithUsers(To[i].getNode());
1206 }
1207 }
1208
1209 // Finally, if the node is now dead, remove it from the graph. The node
1210 // may not be dead if the replacement process recursively simplified to
1211 // something else needing this node.
1212 if (N->use_empty())
1213 deleteAndRecombine(N);
1214 return SDValue(N, 0);
1215 }
1216
1217 void DAGCombiner::
1218 CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
1219 // Replace the old value with the new one.
1220 ++NodesCombined;
1221 LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.dump(&DAG);
1222 dbgs() << "\nWith: "; TLO.New.dump(&DAG); dbgs() << '\n');
1223
1224 // Replace all uses.
1225 DAG.ReplaceAllUsesOfValueWith(TLO.Old, TLO.New);
1226
1227 // Push the new node and any (possibly new) users onto the worklist.
1228 AddToWorklistWithUsers(TLO.New.getNode());
1229
1230 // Finally, if the node is now dead, remove it from the graph.
1231 recursivelyDeleteUnusedNodes(TLO.Old.getNode());
1232 }
1233
1234 /// Check the specified integer node value to see if it can be simplified or if
1235 /// things it uses can be simplified by bit propagation. If so, return true.
1236 bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
1237 const APInt &DemandedElts,
1238 bool AssumeSingleUse) {
1239 TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1240 KnownBits Known;
1241 if (!TLI.SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, 0,
1242 AssumeSingleUse))
1243 return false;
1244
1245 // Revisit the node.
1246 AddToWorklist(Op.getNode());
1247
1248 CommitTargetLoweringOpt(TLO);
1249 return true;
1250 }
1251
1252 /// Check the specified vector node value to see if it can be simplified or
1253 /// if things it uses can be simplified as it only uses some of the elements.
1254 /// If so, return true.
1255 bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
1256 const APInt &DemandedElts,
1257 bool AssumeSingleUse) {
1258 TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1259 APInt KnownUndef, KnownZero;
1260 if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero,
1261 TLO, 0, AssumeSingleUse))
1262 return false;
1263
1264 // Revisit the node.
1265 AddToWorklist(Op.getNode());
1266
1267 CommitTargetLoweringOpt(TLO);
1268 return true;
1269 }
1270
1271 void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) {
1272 SDLoc DL(Load);
1273 EVT VT = Load->getValueType(0);
1274 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, SDValue(ExtLoad, 0));
1275
1276 LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: ";
1277 Trunc.dump(&DAG); dbgs() << '\n');
1278
1279 DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), Trunc);
1280 DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), SDValue(ExtLoad, 1));
1281
1282 AddToWorklist(Trunc.getNode());
1283 recursivelyDeleteUnusedNodes(Load);
1284 }
1285
1286 SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
1287 Replace = false;
1288 SDLoc DL(Op);
1289 if (ISD::isUNINDEXEDLoad(Op.getNode())) {
1290 LoadSDNode *LD = cast<LoadSDNode>(Op);
1291 EVT MemVT = LD->getMemoryVT();
1292 ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
1293 : LD->getExtensionType();
1294 Replace = true;
1295 return DAG.getExtLoad(ExtType, DL, PVT,
1296 LD->getChain(), LD->getBasePtr(),
1297 MemVT, LD->getMemOperand());
1298 }
1299
1300 unsigned Opc = Op.getOpcode();
1301 switch (Opc) {
1302 default: break;
1303 case ISD::AssertSext:
1304 if (SDValue Op0 = SExtPromoteOperand(Op.getOperand(0), PVT))
1305 return DAG.getNode(ISD::AssertSext, DL, PVT, Op0, Op.getOperand(1));
1306 break;
1307 case ISD::AssertZext:
1308 if (SDValue Op0 = ZExtPromoteOperand(Op.getOperand(0), PVT))
1309 return DAG.getNode(ISD::AssertZext, DL, PVT, Op0, Op.getOperand(1));
1310 break;
1311 case ISD::Constant: {
1312 unsigned ExtOpc =
1313 Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1314 return DAG.getNode(ExtOpc, DL, PVT, Op);
1315 }
1316 }
1317
1318 if (!TLI.isOperationLegal(ISD::ANY_EXTEND, PVT))
1319 return SDValue();
1320 return DAG.getNode(ISD::ANY_EXTEND, DL, PVT, Op);
1321 }
1322
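// Promote \p Op to type \p PVT and sign-extend the result back in from the
// original width using SIGN_EXTEND_INREG.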
1323 SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) {
1324 if (!TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG, PVT))
1325 return SDValue();
1326 EVT OldVT = Op.getValueType();
1327 SDLoc DL(Op);
1328 bool Replace = false;
1329 SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1330 if (!NewOp.getNode())
1331 return SDValue();
1332 AddToWorklist(NewOp.getNode());
1333
1334 if (Replace)
1335 ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1336 return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, NewOp.getValueType(), NewOp,
1337 DAG.getValueType(OldVT));
1338 }
1339
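// Promote \p Op to type \p PVT and clear the bits above the original width
// with a zero-extend-in-register of the old type.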
1340 SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) {
1341 EVT OldVT = Op.getValueType();
1342 SDLoc DL(Op);
1343 bool Replace = false;
1344 SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1345 if (!NewOp.getNode())
1346 return SDValue();
1347 AddToWorklist(NewOp.getNode());
1348
1349 if (Replace)
1350 ReplaceLoadWithPromotedLoad(Op.getNode(), NewOp.getNode());
1351 return DAG.getZeroExtendInReg(NewOp, DL, OldVT);
1352 }
1353
1354 /// Promote the specified integer binary operation if the target indicates it is
1355 /// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1356 /// i32 since i16 instructions are longer.
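/// For example, (i16 add x, y) may be rewritten as
/// (i16 (trunc (i32 add (any_extend x), (any_extend y)))).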
1357 SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
1358 if (!LegalOperations)
1359 return SDValue();
1360
1361 EVT VT = Op.getValueType();
1362 if (VT.isVector() || !VT.isInteger())
1363 return SDValue();
1364
1365 // If operation type is 'undesirable', e.g. i16 on x86, consider
1366 // promoting it.
1367 unsigned Opc = Op.getOpcode();
1368 if (TLI.isTypeDesirableForOp(Opc, VT))
1369 return SDValue();
1370
1371 EVT PVT = VT;
1372 // Consult target whether it is a good idea to promote this operation and
1373 // what's the right type to promote it to.
1374 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1375 assert(PVT != VT && "Don't know what type to promote to!");
1376
1377 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1378
1379 bool Replace0 = false;
1380 SDValue N0 = Op.getOperand(0);
1381 SDValue NN0 = PromoteOperand(N0, PVT, Replace0);
1382
1383 bool Replace1 = false;
1384 SDValue N1 = Op.getOperand(1);
1385 SDValue NN1 = PromoteOperand(N1, PVT, Replace1);
1386 SDLoc DL(Op);
1387
1388 SDValue RV =
1389 DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, NN0, NN1));
1390
1391 // We are always replacing N0/N1's use in N and only need additional
1392 // replacements if there are additional uses.
1393 // Note: We are checking uses of the *nodes* (SDNode) rather than values
1394 // (SDValue) here because the node may reference multiple values
1395 // (for example, the chain value of a load node).
1396 Replace0 &= !N0->hasOneUse();
1397 Replace1 &= (N0 != N1) && !N1->hasOneUse();
1398
1399 // Combine Op here so it is preserved past replacements.
1400 CombineTo(Op.getNode(), RV);
1401
1402 // If operands have a use ordering, make sure we deal with
1403 // predecessor first.
1404 if (Replace0 && Replace1 && N0->isPredecessorOf(N1.getNode())) {
1405 std::swap(N0, N1);
1406 std::swap(NN0, NN1);
1407 }
1408
1409 if (Replace0) {
1410 AddToWorklist(NN0.getNode());
1411 ReplaceLoadWithPromotedLoad(N0.getNode(), NN0.getNode());
1412 }
1413 if (Replace1) {
1414 AddToWorklist(NN1.getNode());
1415 ReplaceLoadWithPromotedLoad(N1.getNode(), NN1.getNode());
1416 }
1417 return Op;
1418 }
1419 return SDValue();
1420 }
1421
1422 /// Promote the specified integer shift operation if the target indicates it is
1423 /// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1424 /// i32 since i16 instructions are longer.
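/// The value operand is sign-extended for SRA, zero-extended for SRL, and
/// any-extended otherwise, so e.g. (i16 srl x, c) becomes roughly
///   (trunc i16 (srl i32 (zext x), c)).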
1425 SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) {
1426 if (!LegalOperations)
1427 return SDValue();
1428
1429 EVT VT = Op.getValueType();
1430 if (VT.isVector() || !VT.isInteger())
1431 return SDValue();
1432
1433 // If operation type is 'undesirable', e.g. i16 on x86, consider
1434 // promoting it.
1435 unsigned Opc = Op.getOpcode();
1436 if (TLI.isTypeDesirableForOp(Opc, VT))
1437 return SDValue();
1438
1439 EVT PVT = VT;
1440 // Consult target whether it is a good idea to promote this operation and
1441 // what's the right type to promote it to.
1442 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1443 assert(PVT != VT && "Don't know what type to promote to!");
1444
1445 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1446
1447 bool Replace = false;
1448 SDValue N0 = Op.getOperand(0);
1449 if (Opc == ISD::SRA)
1450 N0 = SExtPromoteOperand(N0, PVT);
1451 else if (Opc == ISD::SRL)
1452 N0 = ZExtPromoteOperand(N0, PVT);
1453 else
1454 N0 = PromoteOperand(N0, PVT, Replace);
1455
1456 if (!N0.getNode())
1457 return SDValue();
1458
1459 SDLoc DL(Op);
1460 SDValue N1 = Op.getOperand(1);
1461 SDValue RV =
1462 DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getNode(Opc, DL, PVT, N0, N1));
1463
1464 if (Replace)
1465 ReplaceLoadWithPromotedLoad(Op.getOperand(0).getNode(), N0.getNode());
1466
1467 // Deal with Op being deleted.
1468 if (Op && Op.getOpcode() != ISD::DELETED_NODE)
1469 return RV;
1470 }
1471 return SDValue();
1472 }
1473
1474 SDValue DAGCombiner::PromoteExtend(SDValue Op) {
1475 if (!LegalOperations)
1476 return SDValue();
1477
1478 EVT VT = Op.getValueType();
1479 if (VT.isVector() || !VT.isInteger())
1480 return SDValue();
1481
1482 // If operation type is 'undesirable', e.g. i16 on x86, consider
1483 // promoting it.
1484 unsigned Opc = Op.getOpcode();
1485 if (TLI.isTypeDesirableForOp(Opc, VT))
1486 return SDValue();
1487
1488 EVT PVT = VT;
1489 // Consult target whether it is a good idea to promote this operation and
1490 // what's the right type to promote it to.
1491 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1492 assert(PVT != VT && "Don't know what type to promote to!");
1493 // fold (aext (aext x)) -> (aext x)
1494 // fold (aext (zext x)) -> (zext x)
1495 // fold (aext (sext x)) -> (sext x)
1496 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1497 return DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, Op.getOperand(0));
1498 }
1499 return SDValue();
1500 }
1501
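/// Promote an unindexed integer load to a wider type if the target prefers
/// it, e.g. replacing an i16 load with roughly
///   (trunc i16 (extload i32 ptr))
/// and rewiring the chain users to the new load.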
1502 bool DAGCombiner::PromoteLoad(SDValue Op) {
1503 if (!LegalOperations)
1504 return false;
1505
1506 if (!ISD::isUNINDEXEDLoad(Op.getNode()))
1507 return false;
1508
1509 EVT VT = Op.getValueType();
1510 if (VT.isVector() || !VT.isInteger())
1511 return false;
1512
1513 // If operation type is 'undesirable', e.g. i16 on x86, consider
1514 // promoting it.
1515 unsigned Opc = Op.getOpcode();
1516 if (TLI.isTypeDesirableForOp(Opc, VT))
1517 return false;
1518
1519 EVT PVT = VT;
1520 // Consult target whether it is a good idea to promote this operation and
1521 // what's the right type to promote it to.
1522 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1523 assert(PVT != VT && "Don't know what type to promote to!");
1524
1525 SDLoc DL(Op);
1526 SDNode *N = Op.getNode();
1527 LoadSDNode *LD = cast<LoadSDNode>(N);
1528 EVT MemVT = LD->getMemoryVT();
1529 ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(LD) ? ISD::EXTLOAD
1530 : LD->getExtensionType();
1531 SDValue NewLD = DAG.getExtLoad(ExtType, DL, PVT,
1532 LD->getChain(), LD->getBasePtr(),
1533 MemVT, LD->getMemOperand());
1534 SDValue Result = DAG.getNode(ISD::TRUNCATE, DL, VT, NewLD);
1535
1536 LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: ";
1537 Result.dump(&DAG); dbgs() << '\n');
1538
1539 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
1540 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), NewLD.getValue(1));
1541
1542 AddToWorklist(Result.getNode());
1543 recursivelyDeleteUnusedNodes(N);
1544 return true;
1545 }
1546
1547 return false;
1548 }
1549
1550 /// Recursively delete a node which has no uses and any operands for
1551 /// which it is the only use.
1552 ///
1553 /// Note that this both deletes the nodes and removes them from the worklist.
1554 /// It also adds any nodes that have had a user deleted to the worklist, as
1555 /// they may now have only one use and be subject to other combines.
1556 bool DAGCombiner::recursivelyDeleteUnusedNodes(SDNode *N) {
1557 if (!N->use_empty())
1558 return false;
1559
1560 SmallSetVector<SDNode *, 16> Nodes;
1561 Nodes.insert(N);
1562 do {
1563 N = Nodes.pop_back_val();
1564 if (!N)
1565 continue;
1566
1567 if (N->use_empty()) {
1568 for (const SDValue &ChildN : N->op_values())
1569 Nodes.insert(ChildN.getNode());
1570
1571 removeFromWorklist(N);
1572 DAG.DeleteNode(N);
1573 } else {
1574 AddToWorklist(N);
1575 }
1576 } while (!Nodes.empty());
1577 return true;
1578 }
1579
1580 //===----------------------------------------------------------------------===//
1581 // Main DAG Combiner implementation
1582 //===----------------------------------------------------------------------===//
1583
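// Drive the combiner: seed the worklist with every node in the DAG, then
// repeatedly pull nodes off, re-legalize them if needed, combine them, and
// replace their uses until the worklist is empty.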
1584 void DAGCombiner::Run(CombineLevel AtLevel) {
1585   // Set the instance variables so that the various visit routines may use them.
1586 Level = AtLevel;
1587 LegalDAG = Level >= AfterLegalizeDAG;
1588 LegalOperations = Level >= AfterLegalizeVectorOps;
1589 LegalTypes = Level >= AfterLegalizeTypes;
1590
1591 WorklistInserter AddNodes(*this);
1592
1593 // Add all the dag nodes to the worklist.
1594 for (SDNode &Node : DAG.allnodes())
1595 AddToWorklist(&Node);
1596
1597   // Create a dummy node (which is not added to allnodes) that adds a reference
1598 // to the root node, preventing it from being deleted, and tracking any
1599 // changes of the root.
1600 HandleSDNode Dummy(DAG.getRoot());
1601
1602 // While we have a valid worklist entry node, try to combine it.
1603 while (SDNode *N = getNextWorklistEntry()) {
1604 // If N has no uses, it is dead. Make sure to revisit all N's operands once
1605 // N is deleted from the DAG, since they too may now be dead or may have a
1606 // reduced number of uses, allowing other xforms.
1607 if (recursivelyDeleteUnusedNodes(N))
1608 continue;
1609
1610 WorklistRemover DeadNodes(*this);
1611
1612 // If this combine is running after legalizing the DAG, re-legalize any
1613 // nodes pulled off the worklist.
1614 if (LegalDAG) {
1615 SmallSetVector<SDNode *, 16> UpdatedNodes;
1616 bool NIsValid = DAG.LegalizeOp(N, UpdatedNodes);
1617
1618 for (SDNode *LN : UpdatedNodes)
1619 AddToWorklistWithUsers(LN);
1620
1621 if (!NIsValid)
1622 continue;
1623 }
1624
1625 LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG));
1626
1627 // Add any operands of the new node which have not yet been combined to the
1628 // worklist as well. Because the worklist uniques things already, this
1629 // won't repeatedly process the same operand.
1630 CombinedNodes.insert(N);
1631 for (const SDValue &ChildN : N->op_values())
1632 if (!CombinedNodes.count(ChildN.getNode()))
1633 AddToWorklist(ChildN.getNode());
1634
1635 SDValue RV = combine(N);
1636
1637 if (!RV.getNode())
1638 continue;
1639
1640 ++NodesCombined;
1641
1642 // If we get back the same node we passed in, rather than a new node or
1643 // zero, we know that the node must have defined multiple values and
1644 // CombineTo was used. Since CombineTo takes care of the worklist
1645 // mechanics for us, we have no work to do in this case.
1646 if (RV.getNode() == N)
1647 continue;
1648
1649 assert(N->getOpcode() != ISD::DELETED_NODE &&
1650 RV.getOpcode() != ISD::DELETED_NODE &&
1651 "Node was deleted but visit returned new node!");
1652
1653 LLVM_DEBUG(dbgs() << " ... into: "; RV.dump(&DAG));
1654
1655 if (N->getNumValues() == RV->getNumValues())
1656 DAG.ReplaceAllUsesWith(N, RV.getNode());
1657 else {
1658 assert(N->getValueType(0) == RV.getValueType() &&
1659 N->getNumValues() == 1 && "Type mismatch");
1660 DAG.ReplaceAllUsesWith(N, &RV);
1661 }
1662
1663 // Push the new node and any users onto the worklist. Omit this if the
1664 // new node is the EntryToken (e.g. if a store managed to get optimized
1665 // out), because re-visiting the EntryToken and its users will not uncover
1666 // any additional opportunities, but there may be a large number of such
1667 // users, potentially causing compile time explosion.
1668 if (RV.getOpcode() != ISD::EntryToken) {
1669 AddToWorklist(RV.getNode());
1670 AddUsersToWorklist(RV.getNode());
1671 }
1672
1673 // Finally, if the node is now dead, remove it from the graph. The node
1674 // may not be dead if the replacement process recursively simplified to
1675 // something else needing this node. This will also take care of adding any
1676 // operands which have lost a user to the worklist.
1677 recursivelyDeleteUnusedNodes(N);
1678 }
1679
1680   // If the root changed (e.g. it was a dead load), update the root.
1681 DAG.setRoot(Dummy.getValue());
1682 DAG.RemoveDeadNodes();
1683 }
1684
1685 SDValue DAGCombiner::visit(SDNode *N) {
1686 switch (N->getOpcode()) {
1687 default: break;
1688 case ISD::TokenFactor: return visitTokenFactor(N);
1689 case ISD::MERGE_VALUES: return visitMERGE_VALUES(N);
1690 case ISD::ADD: return visitADD(N);
1691 case ISD::SUB: return visitSUB(N);
1692 case ISD::SADDSAT:
1693 case ISD::UADDSAT: return visitADDSAT(N);
1694 case ISD::SSUBSAT:
1695 case ISD::USUBSAT: return visitSUBSAT(N);
1696 case ISD::ADDC: return visitADDC(N);
1697 case ISD::SADDO:
1698 case ISD::UADDO: return visitADDO(N);
1699 case ISD::SUBC: return visitSUBC(N);
1700 case ISD::SSUBO:
1701 case ISD::USUBO: return visitSUBO(N);
1702 case ISD::ADDE: return visitADDE(N);
1703 case ISD::ADDCARRY: return visitADDCARRY(N);
1704 case ISD::SADDO_CARRY: return visitSADDO_CARRY(N);
1705 case ISD::SUBE: return visitSUBE(N);
1706 case ISD::SUBCARRY: return visitSUBCARRY(N);
1707 case ISD::SSUBO_CARRY: return visitSSUBO_CARRY(N);
1708 case ISD::SMULFIX:
1709 case ISD::SMULFIXSAT:
1710 case ISD::UMULFIX:
1711 case ISD::UMULFIXSAT: return visitMULFIX(N);
1712 case ISD::MUL: return visitMUL(N);
1713 case ISD::SDIV: return visitSDIV(N);
1714 case ISD::UDIV: return visitUDIV(N);
1715 case ISD::SREM:
1716 case ISD::UREM: return visitREM(N);
1717 case ISD::MULHU: return visitMULHU(N);
1718 case ISD::MULHS: return visitMULHS(N);
1719 case ISD::AVGFLOORS:
1720 case ISD::AVGFLOORU:
1721 case ISD::AVGCEILS:
1722 case ISD::AVGCEILU: return visitAVG(N);
1723 case ISD::SMUL_LOHI: return visitSMUL_LOHI(N);
1724 case ISD::UMUL_LOHI: return visitUMUL_LOHI(N);
1725 case ISD::SMULO:
1726 case ISD::UMULO: return visitMULO(N);
1727 case ISD::SMIN:
1728 case ISD::SMAX:
1729 case ISD::UMIN:
1730 case ISD::UMAX: return visitIMINMAX(N);
1731 case ISD::AND: return visitAND(N);
1732 case ISD::OR: return visitOR(N);
1733 case ISD::XOR: return visitXOR(N);
1734 case ISD::SHL: return visitSHL(N);
1735 case ISD::SRA: return visitSRA(N);
1736 case ISD::SRL: return visitSRL(N);
1737 case ISD::ROTR:
1738 case ISD::ROTL: return visitRotate(N);
1739 case ISD::FSHL:
1740 case ISD::FSHR: return visitFunnelShift(N);
1741 case ISD::SSHLSAT:
1742 case ISD::USHLSAT: return visitSHLSAT(N);
1743 case ISD::ABS: return visitABS(N);
1744 case ISD::BSWAP: return visitBSWAP(N);
1745 case ISD::BITREVERSE: return visitBITREVERSE(N);
1746 case ISD::CTLZ: return visitCTLZ(N);
1747 case ISD::CTLZ_ZERO_UNDEF: return visitCTLZ_ZERO_UNDEF(N);
1748 case ISD::CTTZ: return visitCTTZ(N);
1749 case ISD::CTTZ_ZERO_UNDEF: return visitCTTZ_ZERO_UNDEF(N);
1750 case ISD::CTPOP: return visitCTPOP(N);
1751 case ISD::SELECT: return visitSELECT(N);
1752 case ISD::VSELECT: return visitVSELECT(N);
1753 case ISD::SELECT_CC: return visitSELECT_CC(N);
1754 case ISD::SETCC: return visitSETCC(N);
1755 case ISD::SETCCCARRY: return visitSETCCCARRY(N);
1756 case ISD::SIGN_EXTEND: return visitSIGN_EXTEND(N);
1757 case ISD::ZERO_EXTEND: return visitZERO_EXTEND(N);
1758 case ISD::ANY_EXTEND: return visitANY_EXTEND(N);
1759 case ISD::AssertSext:
1760 case ISD::AssertZext: return visitAssertExt(N);
1761 case ISD::AssertAlign: return visitAssertAlign(N);
1762 case ISD::SIGN_EXTEND_INREG: return visitSIGN_EXTEND_INREG(N);
1763 case ISD::SIGN_EXTEND_VECTOR_INREG:
1764 case ISD::ZERO_EXTEND_VECTOR_INREG:
1765 case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
1766 case ISD::TRUNCATE: return visitTRUNCATE(N);
1767 case ISD::BITCAST: return visitBITCAST(N);
1768 case ISD::BUILD_PAIR: return visitBUILD_PAIR(N);
1769 case ISD::FADD: return visitFADD(N);
1770 case ISD::STRICT_FADD: return visitSTRICT_FADD(N);
1771 case ISD::FSUB: return visitFSUB(N);
1772 case ISD::FMUL: return visitFMUL(N);
1773 case ISD::FMA: return visitFMA(N);
1774 case ISD::FDIV: return visitFDIV(N);
1775 case ISD::FREM: return visitFREM(N);
1776 case ISD::FSQRT: return visitFSQRT(N);
1777 case ISD::FCOPYSIGN: return visitFCOPYSIGN(N);
1778 case ISD::FPOW: return visitFPOW(N);
1779 case ISD::SINT_TO_FP: return visitSINT_TO_FP(N);
1780 case ISD::UINT_TO_FP: return visitUINT_TO_FP(N);
1781 case ISD::FP_TO_SINT: return visitFP_TO_SINT(N);
1782 case ISD::FP_TO_UINT: return visitFP_TO_UINT(N);
1783 case ISD::FP_ROUND: return visitFP_ROUND(N);
1784 case ISD::FP_EXTEND: return visitFP_EXTEND(N);
1785 case ISD::FNEG: return visitFNEG(N);
1786 case ISD::FABS: return visitFABS(N);
1787 case ISD::FFLOOR: return visitFFLOOR(N);
1788 case ISD::FMINNUM:
1789 case ISD::FMAXNUM:
1790 case ISD::FMINIMUM:
1791 case ISD::FMAXIMUM: return visitFMinMax(N);
1792 case ISD::FCEIL: return visitFCEIL(N);
1793 case ISD::FTRUNC: return visitFTRUNC(N);
1794 case ISD::BRCOND: return visitBRCOND(N);
1795 case ISD::BR_CC: return visitBR_CC(N);
1796 case ISD::LOAD: return visitLOAD(N);
1797 case ISD::STORE: return visitSTORE(N);
1798 case ISD::INSERT_VECTOR_ELT: return visitINSERT_VECTOR_ELT(N);
1799 case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N);
1800 case ISD::BUILD_VECTOR: return visitBUILD_VECTOR(N);
1801 case ISD::CONCAT_VECTORS: return visitCONCAT_VECTORS(N);
1802 case ISD::EXTRACT_SUBVECTOR: return visitEXTRACT_SUBVECTOR(N);
1803 case ISD::VECTOR_SHUFFLE: return visitVECTOR_SHUFFLE(N);
1804 case ISD::SCALAR_TO_VECTOR: return visitSCALAR_TO_VECTOR(N);
1805 case ISD::INSERT_SUBVECTOR: return visitINSERT_SUBVECTOR(N);
1806 case ISD::MGATHER: return visitMGATHER(N);
1807 case ISD::MLOAD: return visitMLOAD(N);
1808 case ISD::MSCATTER: return visitMSCATTER(N);
1809 case ISD::MSTORE: return visitMSTORE(N);
1810 case ISD::LIFETIME_END: return visitLIFETIME_END(N);
1811 case ISD::FP_TO_FP16: return visitFP_TO_FP16(N);
1812 case ISD::FP16_TO_FP: return visitFP16_TO_FP(N);
1813 case ISD::FP_TO_BF16: return visitFP_TO_BF16(N);
1814 case ISD::FREEZE: return visitFREEZE(N);
1815 case ISD::VECREDUCE_FADD:
1816 case ISD::VECREDUCE_FMUL:
1817 case ISD::VECREDUCE_ADD:
1818 case ISD::VECREDUCE_MUL:
1819 case ISD::VECREDUCE_AND:
1820 case ISD::VECREDUCE_OR:
1821 case ISD::VECREDUCE_XOR:
1822 case ISD::VECREDUCE_SMAX:
1823 case ISD::VECREDUCE_SMIN:
1824 case ISD::VECREDUCE_UMAX:
1825 case ISD::VECREDUCE_UMIN:
1826 case ISD::VECREDUCE_FMAX:
1827 case ISD::VECREDUCE_FMIN: return visitVECREDUCE(N);
1828 #define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...) case ISD::SDOPC:
1829 #include "llvm/IR/VPIntrinsics.def"
1830 return visitVPOp(N);
1831 }
1832 return SDValue();
1833 }
1834
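/// Run the full set of combines on \p N: the generic visit routines first,
/// then target-specific combines, then integer type promotion, and finally
/// CSE against the commuted form of a commutative binop. Returns the
/// replacement value, or a null SDValue if nothing changed.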
1835 SDValue DAGCombiner::combine(SDNode *N) {
1836 SDValue RV;
1837 if (!DisableGenericCombines)
1838 RV = visit(N);
1839
1840 // If nothing happened, try a target-specific DAG combine.
1841 if (!RV.getNode()) {
1842 assert(N->getOpcode() != ISD::DELETED_NODE &&
1843 "Node was deleted but visit returned NULL!");
1844
1845 if (N->getOpcode() >= ISD::BUILTIN_OP_END ||
1846 TLI.hasTargetDAGCombine((ISD::NodeType)N->getOpcode())) {
1847
1848 // Expose the DAG combiner to the target combiner impls.
1849 TargetLowering::DAGCombinerInfo
1850 DagCombineInfo(DAG, Level, false, this);
1851
1852 RV = TLI.PerformDAGCombine(N, DagCombineInfo);
1853 }
1854 }
1855
1856 // If nothing happened still, try promoting the operation.
1857 if (!RV.getNode()) {
1858 switch (N->getOpcode()) {
1859 default: break;
1860 case ISD::ADD:
1861 case ISD::SUB:
1862 case ISD::MUL:
1863 case ISD::AND:
1864 case ISD::OR:
1865 case ISD::XOR:
1866 RV = PromoteIntBinOp(SDValue(N, 0));
1867 break;
1868 case ISD::SHL:
1869 case ISD::SRA:
1870 case ISD::SRL:
1871 RV = PromoteIntShiftOp(SDValue(N, 0));
1872 break;
1873 case ISD::SIGN_EXTEND:
1874 case ISD::ZERO_EXTEND:
1875 case ISD::ANY_EXTEND:
1876 RV = PromoteExtend(SDValue(N, 0));
1877 break;
1878 case ISD::LOAD:
1879 if (PromoteLoad(SDValue(N, 0)))
1880 RV = SDValue(N, 0);
1881 break;
1882 }
1883 }
1884
1885 // If N is a commutative binary node, try to eliminate it if the commuted
1886 // version is already present in the DAG.
1887 if (!RV.getNode() && TLI.isCommutativeBinOp(N->getOpcode())) {
1888 SDValue N0 = N->getOperand(0);
1889 SDValue N1 = N->getOperand(1);
1890
1891 // Constant operands are canonicalized to RHS.
1892 if (N0 != N1 && (isa<ConstantSDNode>(N0) || !isa<ConstantSDNode>(N1))) {
1893 SDValue Ops[] = {N1, N0};
1894 SDNode *CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops,
1895 N->getFlags());
1896 if (CSENode)
1897 return SDValue(CSENode, 0);
1898 }
1899 }
1900
1901 return RV;
1902 }
1903
1904 /// Given a node, return its input chain if it has one, otherwise return a null
1905 /// sd operand.
1906 static SDValue getInputChainForNode(SDNode *N) {
1907 if (unsigned NumOps = N->getNumOperands()) {
1908 if (N->getOperand(0).getValueType() == MVT::Other)
1909 return N->getOperand(0);
1910 if (N->getOperand(NumOps-1).getValueType() == MVT::Other)
1911 return N->getOperand(NumOps-1);
1912 for (unsigned i = 1; i < NumOps-1; ++i)
1913 if (N->getOperand(i).getValueType() == MVT::Other)
1914 return N->getOperand(i);
1915 }
1916 return SDValue();
1917 }
1918
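/// Simplify a TokenFactor by merging nested TokenFactors into it and pruning
/// chain operands that are already reachable through another operand.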
1919 SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
1920 // If N has two operands, where one has an input chain equal to the other,
1921 // the 'other' chain is redundant.
1922 if (N->getNumOperands() == 2) {
1923 if (getInputChainForNode(N->getOperand(0).getNode()) == N->getOperand(1))
1924 return N->getOperand(0);
1925 if (getInputChainForNode(N->getOperand(1).getNode()) == N->getOperand(0))
1926 return N->getOperand(1);
1927 }
1928
1929 // Don't simplify token factors if optnone.
1930 if (OptLevel == CodeGenOpt::None)
1931 return SDValue();
1932
1933 // Don't simplify the token factor if the node itself has too many operands.
1934 if (N->getNumOperands() > TokenFactorInlineLimit)
1935 return SDValue();
1936
1937 // If the sole user is a token factor, we should make sure we have a
1938 // chance to merge them together. This prevents TF chains from inhibiting
1939 // optimizations.
1940 if (N->hasOneUse() && N->use_begin()->getOpcode() == ISD::TokenFactor)
1941 AddToWorklist(*(N->use_begin()));
1942
1943 SmallVector<SDNode *, 8> TFs; // List of token factors to visit.
1944 SmallVector<SDValue, 8> Ops; // Ops for replacing token factor.
1945 SmallPtrSet<SDNode*, 16> SeenOps;
1946 bool Changed = false; // If we should replace this token factor.
1947
1948 // Start out with this token factor.
1949 TFs.push_back(N);
1950
1951   // Iterate through token factors. The TFs list grows when new token factors
1952   // are encountered.
1953 for (unsigned i = 0; i < TFs.size(); ++i) {
1954 // Limit number of nodes to inline, to avoid quadratic compile times.
1955 // We have to add the outstanding Token Factors to Ops, otherwise we might
1956 // drop Ops from the resulting Token Factors.
1957 if (Ops.size() > TokenFactorInlineLimit) {
1958 for (unsigned j = i; j < TFs.size(); j++)
1959 Ops.emplace_back(TFs[j], 0);
1960 // Drop unprocessed Token Factors from TFs, so we do not add them to the
1961 // combiner worklist later.
1962 TFs.resize(i);
1963 break;
1964 }
1965
1966 SDNode *TF = TFs[i];
1967 // Check each of the operands.
1968 for (const SDValue &Op : TF->op_values()) {
1969 switch (Op.getOpcode()) {
1970 case ISD::EntryToken:
1971 // Entry tokens don't need to be added to the list. They are
1972 // redundant.
1973 Changed = true;
1974 break;
1975
1976 case ISD::TokenFactor:
1977 if (Op.hasOneUse() && !is_contained(TFs, Op.getNode())) {
1978 // Queue up for processing.
1979 TFs.push_back(Op.getNode());
1980 Changed = true;
1981 break;
1982 }
1983 [[fallthrough]];
1984
1985 default:
1986 // Only add if it isn't already in the list.
1987 if (SeenOps.insert(Op.getNode()).second)
1988 Ops.push_back(Op);
1989 else
1990 Changed = true;
1991 break;
1992 }
1993 }
1994 }
1995
1996 // Re-visit inlined Token Factors, to clean them up in case they have been
1997 // removed. Skip the first Token Factor, as this is the current node.
1998 for (unsigned i = 1, e = TFs.size(); i < e; i++)
1999 AddToWorklist(TFs[i]);
2000
2001 // Remove Nodes that are chained to another node in the list. Do so
2002   // by walking up chains breadth-first, stopping when we've seen
2003 // another operand. In general we must climb to the EntryNode, but we can exit
2004 // early if we find all remaining work is associated with just one operand as
2005 // no further pruning is possible.
2006
2007 // List of nodes to search through and original Ops from which they originate.
2008 SmallVector<std::pair<SDNode *, unsigned>, 8> Worklist;
2009 SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op.
2010 SmallPtrSet<SDNode *, 16> SeenChains;
2011 bool DidPruneOps = false;
2012
2013 unsigned NumLeftToConsider = 0;
2014 for (const SDValue &Op : Ops) {
2015 Worklist.push_back(std::make_pair(Op.getNode(), NumLeftToConsider++));
2016 OpWorkCount.push_back(1);
2017 }
2018
2019 auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) {
2020     // If this is an Op, we can remove the op from the list. Re-mark any
2021 // search associated with it as from the current OpNumber.
2022 if (SeenOps.contains(Op)) {
2023 Changed = true;
2024 DidPruneOps = true;
2025 unsigned OrigOpNumber = 0;
2026 while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op)
2027 OrigOpNumber++;
2028 assert((OrigOpNumber != Ops.size()) &&
2029 "expected to find TokenFactor Operand");
2030 // Re-mark worklist from OrigOpNumber to OpNumber
2031 for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) {
2032 if (Worklist[i].second == OrigOpNumber) {
2033 Worklist[i].second = OpNumber;
2034 }
2035 }
2036 OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber];
2037 OpWorkCount[OrigOpNumber] = 0;
2038 NumLeftToConsider--;
2039 }
2040 // Add if it's a new chain
2041 if (SeenChains.insert(Op).second) {
2042 OpWorkCount[OpNumber]++;
2043 Worklist.push_back(std::make_pair(Op, OpNumber));
2044 }
2045 };
2046
2047 for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) {
2048     // We need to consider at least 2 Ops to prune.
2049 if (NumLeftToConsider <= 1)
2050 break;
2051 auto CurNode = Worklist[i].first;
2052 auto CurOpNumber = Worklist[i].second;
2053 assert((OpWorkCount[CurOpNumber] > 0) &&
2054 "Node should not appear in worklist");
2055 switch (CurNode->getOpcode()) {
2056 case ISD::EntryToken:
2057 // Hitting EntryToken is the only way for the search to terminate without
2058       // hitting another operand's search. Prevent us from marking this
2059       // operand considered.
2061 NumLeftToConsider++;
2062 break;
2063 case ISD::TokenFactor:
2064 for (const SDValue &Op : CurNode->op_values())
2065 AddToWorklist(i, Op.getNode(), CurOpNumber);
2066 break;
2067 case ISD::LIFETIME_START:
2068 case ISD::LIFETIME_END:
2069 case ISD::CopyFromReg:
2070 case ISD::CopyToReg:
2071 AddToWorklist(i, CurNode->getOperand(0).getNode(), CurOpNumber);
2072 break;
2073 default:
2074 if (auto *MemNode = dyn_cast<MemSDNode>(CurNode))
2075 AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber);
2076 break;
2077 }
2078 OpWorkCount[CurOpNumber]--;
2079 if (OpWorkCount[CurOpNumber] == 0)
2080 NumLeftToConsider--;
2081 }
2082
2083 // If we've changed things around then replace token factor.
2084 if (Changed) {
2085 SDValue Result;
2086 if (Ops.empty()) {
2087 // The entry token is the only possible outcome.
2088 Result = DAG.getEntryNode();
2089 } else {
2090 if (DidPruneOps) {
2091 SmallVector<SDValue, 8> PrunedOps;
2093 for (const SDValue &Op : Ops) {
2094 if (SeenChains.count(Op.getNode()) == 0)
2095 PrunedOps.push_back(Op);
2096 }
2097 Result = DAG.getTokenFactor(SDLoc(N), PrunedOps);
2098 } else {
2099 Result = DAG.getTokenFactor(SDLoc(N), Ops);
2100 }
2101 }
2102 return Result;
2103 }
2104 return SDValue();
2105 }
2106
2107 /// MERGE_VALUES can always be eliminated.
2108 SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
2109 WorklistRemover DeadNodes(*this);
2110 // Replacing results may cause a different MERGE_VALUES to suddenly
2111 // be CSE'd with N, and carry its uses with it. Iterate until no
2112 // uses remain, to ensure that the node can be safely deleted.
2113 // First add the users of this node to the work list so that they
2114 // can be tried again once they have new operands.
2115 AddUsersToWorklist(N);
2116 do {
2117 // Do as a single replacement to avoid rewalking use lists.
2118 SmallVector<SDValue, 8> Ops;
2119 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2120 Ops.push_back(N->getOperand(i));
2121 DAG.ReplaceAllUsesWith(N, Ops.data());
2122 } while (!N->use_empty());
2123 deleteAndRecombine(N);
2124 return SDValue(N, 0); // Return N so it doesn't get rechecked!
2125 }
2126
2127 /// If \p N is a ConstantSDNode with isOpaque() == false, return it cast to a
2128 /// ConstantSDNode pointer; otherwise return nullptr.
2129 static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
2130 ConstantSDNode *Const = dyn_cast<ConstantSDNode>(N);
2131 return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
2132 }
2133
2134 /// Return true if 'Use' is a load or a store that uses N as its base pointer
2135 /// and that N may be folded in the load / store addressing mode.
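/// For example, (add BasePtr, 16) used as the address of a load can be folded
/// away when the target supports a [reg + imm] addressing mode for that
/// memory type and address space.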
2136 static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
2137 const TargetLowering &TLI) {
2138 EVT VT;
2139 unsigned AS;
2140
2141 if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Use)) {
2142 if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2143 return false;
2144 VT = LD->getMemoryVT();
2145 AS = LD->getAddressSpace();
2146 } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Use)) {
2147 if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2148 return false;
2149 VT = ST->getMemoryVT();
2150 AS = ST->getAddressSpace();
2151 } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Use)) {
2152 if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2153 return false;
2154 VT = LD->getMemoryVT();
2155 AS = LD->getAddressSpace();
2156 } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Use)) {
2157 if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2158 return false;
2159 VT = ST->getMemoryVT();
2160 AS = ST->getAddressSpace();
2161 } else {
2162 return false;
2163 }
2164
2165 TargetLowering::AddrMode AM;
2166 if (N->getOpcode() == ISD::ADD) {
2167 AM.HasBaseReg = true;
2168 ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
2169 if (Offset)
2170 // [reg +/- imm]
2171 AM.BaseOffs = Offset->getSExtValue();
2172 else
2173 // [reg +/- reg]
2174 AM.Scale = 1;
2175 } else if (N->getOpcode() == ISD::SUB) {
2176 AM.HasBaseReg = true;
2177 ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
2178 if (Offset)
2179 // [reg +/- imm]
2180 AM.BaseOffs = -Offset->getSExtValue();
2181 else
2182 // [reg +/- reg]
2183 AM.Scale = 1;
2184 } else {
2185 return false;
2186 }
2187
2188 return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM,
2189 VT.getTypeForEVT(*DAG.getContext()), AS);
2190 }
2191
2192 /// This inverts a canonicalization in IR that replaces a variable select arm
2193 /// with an identity constant. Codegen improves if we re-use the variable
2194 /// operand rather than load a constant. This can also be converted into a
2195 /// masked vector operation if the target supports it.
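/// For example: add X, (vselect Cond, 0, V) --> vselect Cond, X, (add X, V),
/// reusing X for the true arm instead of materializing the identity constant.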
2196 static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
2197 bool ShouldCommuteOperands) {
2198 // Match a select as operand 1. The identity constant that we are looking for
2199 // is only valid as operand 1 of a non-commutative binop.
2200 SDValue N0 = N->getOperand(0);
2201 SDValue N1 = N->getOperand(1);
2202 if (ShouldCommuteOperands)
2203 std::swap(N0, N1);
2204
2205 // TODO: Should this apply to scalar select too?
2206 if (N1.getOpcode() != ISD::VSELECT || !N1.hasOneUse())
2207 return SDValue();
2208
2209 // We can't hoist div/rem because of immediate UB (not speculatable).
2210 unsigned Opcode = N->getOpcode();
2211 if (!DAG.isSafeToSpeculativelyExecute(Opcode))
2212 return SDValue();
2213
2214 EVT VT = N->getValueType(0);
2215 SDValue Cond = N1.getOperand(0);
2216 SDValue TVal = N1.getOperand(1);
2217 SDValue FVal = N1.getOperand(2);
2218
2219 // This transform increases uses of N0, so freeze it to be safe.
2220 // binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
2221 unsigned OpNo = ShouldCommuteOperands ? 0 : 1;
2222 if (isNeutralConstant(Opcode, N->getFlags(), TVal, OpNo)) {
2223 SDValue F0 = DAG.getFreeze(N0);
2224 SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags());
2225 return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO);
2226 }
2227 // binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
2228 if (isNeutralConstant(Opcode, N->getFlags(), FVal, OpNo)) {
2229 SDValue F0 = DAG.getFreeze(N0);
2230 SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags());
2231 return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0);
2232 }
2233
2234 return SDValue();
2235 }
2236
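/// If a binary operator has a one-use select-of-constants operand and a
/// constant operand, pull the constant math into the select, e.g.
///   add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO.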
2237 SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
2238 assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
2239 "Unexpected binary operator");
2240
2241 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2242 auto BinOpcode = BO->getOpcode();
2243 EVT VT = BO->getValueType(0);
2244 if (TLI.shouldFoldSelectWithIdentityConstant(BinOpcode, VT)) {
2245 if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false))
2246 return Sel;
2247
2248 if (TLI.isCommutativeBinOp(BO->getOpcode()))
2249 if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true))
2250 return Sel;
2251 }
2252
2253 // Don't do this unless the old select is going away. We want to eliminate the
2254 // binary operator, not replace a binop with a select.
2255 // TODO: Handle ISD::SELECT_CC.
2256 unsigned SelOpNo = 0;
2257 SDValue Sel = BO->getOperand(0);
2258 if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
2259 SelOpNo = 1;
2260 Sel = BO->getOperand(1);
2261 }
2262
2263 if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
2264 return SDValue();
2265
2266 SDValue CT = Sel.getOperand(1);
2267 if (!isConstantOrConstantVector(CT, true) &&
2268 !DAG.isConstantFPBuildVectorOrConstantFP(CT))
2269 return SDValue();
2270
2271 SDValue CF = Sel.getOperand(2);
2272 if (!isConstantOrConstantVector(CF, true) &&
2273 !DAG.isConstantFPBuildVectorOrConstantFP(CF))
2274 return SDValue();
2275
2276 // Bail out if any constants are opaque because we can't constant fold those.
2277   // The exception is "and" and "or" with either 0 or -1, in which case we can
2278   // propagate the non-constant operands into the select. I.e.:
2279 // and (select Cond, 0, -1), X --> select Cond, 0, X
2280 // or X, (select Cond, -1, 0) --> select Cond, -1, X
2281 bool CanFoldNonConst =
2282 (BinOpcode == ISD::AND || BinOpcode == ISD::OR) &&
2283 ((isNullOrNullSplat(CT) && isAllOnesOrAllOnesSplat(CF)) ||
2284 (isNullOrNullSplat(CF) && isAllOnesOrAllOnesSplat(CT)));
2285
2286 SDValue CBO = BO->getOperand(SelOpNo ^ 1);
2287 if (!CanFoldNonConst &&
2288 !isConstantOrConstantVector(CBO, true) &&
2289 !DAG.isConstantFPBuildVectorOrConstantFP(CBO))
2290 return SDValue();
2291
2292 SDLoc DL(Sel);
2293 SDValue NewCT, NewCF;
2294
2295 if (CanFoldNonConst) {
2296 // If CBO is an opaque constant, we can't rely on getNode to constant fold.
2297 if ((BinOpcode == ISD::AND && isNullOrNullSplat(CT)) ||
2298 (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(CT)))
2299 NewCT = CT;
2300 else
2301 NewCT = CBO;
2302
2303 if ((BinOpcode == ISD::AND && isNullOrNullSplat(CF)) ||
2304 (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(CF)))
2305 NewCF = CF;
2306 else
2307 NewCF = CBO;
2308 } else {
2309     // We have a select-of-constants followed by a binary operator with a
2310     // constant. Eliminate the binop by pulling the constant math into the
2311     // select. Example:
2312     //   add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
2313 NewCT = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CT)
2314 : DAG.getNode(BinOpcode, DL, VT, CT, CBO);
2315 if (!CanFoldNonConst && !NewCT.isUndef() &&
2316 !isConstantOrConstantVector(NewCT, true) &&
2317 !DAG.isConstantFPBuildVectorOrConstantFP(NewCT))
2318 return SDValue();
2319
2320 NewCF = SelOpNo ? DAG.getNode(BinOpcode, DL, VT, CBO, CF)
2321 : DAG.getNode(BinOpcode, DL, VT, CF, CBO);
2322 if (!CanFoldNonConst && !NewCF.isUndef() &&
2323 !isConstantOrConstantVector(NewCF, true) &&
2324 !DAG.isConstantFPBuildVectorOrConstantFP(NewCF))
2325 return SDValue();
2326 }
2327
2328 SDValue SelectOp = DAG.getSelect(DL, VT, Sel.getOperand(0), NewCT, NewCF);
2329 SelectOp->setFlags(BO->getFlags());
2330 return SelectOp;
2331 }
2332
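/// Fold an add/sub of a constant with a zero-extended (seteq (X & 1), 0),
/// i.e. an inverted low bit, into a sub/add of the low bit itself with an
/// adjusted constant (see the patterns spelled out in the body below).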
2333 static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, SelectionDAG &DAG) {
2334 assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2335 "Expecting add or sub");
2336
2337 // Match a constant operand and a zext operand for the math instruction:
2338 // add Z, C
2339 // sub C, Z
2340 bool IsAdd = N->getOpcode() == ISD::ADD;
2341 SDValue C = IsAdd ? N->getOperand(1) : N->getOperand(0);
2342 SDValue Z = IsAdd ? N->getOperand(0) : N->getOperand(1);
2343 auto *CN = dyn_cast<ConstantSDNode>(C);
2344 if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND)
2345 return SDValue();
2346
2347 // Match the zext operand as a setcc of a boolean.
2348 if (Z.getOperand(0).getOpcode() != ISD::SETCC ||
2349 Z.getOperand(0).getValueType() != MVT::i1)
2350 return SDValue();
2351
2352 // Match the compare as: setcc (X & 1), 0, eq.
2353 SDValue SetCC = Z.getOperand(0);
2354 ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
2355 if (CC != ISD::SETEQ || !isNullConstant(SetCC.getOperand(1)) ||
2356 SetCC.getOperand(0).getOpcode() != ISD::AND ||
2357 !isOneConstant(SetCC.getOperand(0).getOperand(1)))
2358 return SDValue();
2359
2360 // We are adding/subtracting a constant and an inverted low bit. Turn that
2361 // into a subtract/add of the low bit with incremented/decremented constant:
2362 // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
2363 // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
2364 EVT VT = C.getValueType();
2365 SDLoc DL(N);
2366 SDValue LowBit = DAG.getZExtOrTrunc(SetCC.getOperand(0), DL, VT);
2367 SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) :
2368 DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
2369 return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
2370 }
2371
2372 /// Try to fold a 'not' shifted sign-bit with add/sub with constant operand into
2373 /// a shift and add with a different constant.
2374 static SDValue foldAddSubOfSignBit(SDNode *N, SelectionDAG &DAG) {
2375 assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2376 "Expecting add or sub");
2377
2378 // We need a constant operand for the add/sub, and the other operand is a
2379 // logical shift right: add (srl), C or sub C, (srl).
2380 bool IsAdd = N->getOpcode() == ISD::ADD;
2381 SDValue ConstantOp = IsAdd ? N->getOperand(1) : N->getOperand(0);
2382 SDValue ShiftOp = IsAdd ? N->getOperand(0) : N->getOperand(1);
2383 if (!DAG.isConstantIntBuildVectorOrConstantInt(ConstantOp) ||
2384 ShiftOp.getOpcode() != ISD::SRL)
2385 return SDValue();
2386
2387 // The shift must be of a 'not' value.
2388 SDValue Not = ShiftOp.getOperand(0);
2389 if (!Not.hasOneUse() || !isBitwiseNot(Not))
2390 return SDValue();
2391
2392 // The shift must be moving the sign bit to the least-significant-bit.
2393 EVT VT = ShiftOp.getValueType();
2394 SDValue ShAmt = ShiftOp.getOperand(1);
2395 ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
2396 if (!ShAmtC || ShAmtC->getAPIntValue() != (VT.getScalarSizeInBits() - 1))
2397 return SDValue();
2398
2399 // Eliminate the 'not' by adjusting the shift and add/sub constant:
2400 // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
2401 // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
2402 SDLoc DL(N);
2403 if (SDValue NewC = DAG.FoldConstantArithmetic(
2404 IsAdd ? ISD::ADD : ISD::SUB, DL, VT,
2405 {ConstantOp, DAG.getConstant(1, DL, VT)})) {
2406 SDValue NewShift = DAG.getNode(IsAdd ? ISD::SRA : ISD::SRL, DL, VT,
2407 Not.getOperand(0), ShAmt);
2408 return DAG.getNode(ISD::ADD, DL, VT, NewShift, NewC);
2409 }
2410
2411 return SDValue();
2412 }
2413
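/// Return true if \p V behaves like an ADD of its operands: an OR whose
/// operands share no common bits, or an XOR with the minimum signed constant
/// (which only flips the sign bit and is therefore equivalent to adding it).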
2414 static bool isADDLike(SDValue V, const SelectionDAG &DAG) {
2415 unsigned Opcode = V.getOpcode();
2416 if (Opcode == ISD::OR)
2417 return DAG.haveNoCommonBitsSet(V.getOperand(0), V.getOperand(1));
2418 if (Opcode == ISD::XOR)
2419 return isMinSignedConstant(V.getOperand(1));
2420 return false;
2421 }
2422
2423 /// Try to fold a node that behaves like an ADD (note that N isn't necessarily
2424 /// an ISD::ADD here, it could for example be an ISD::OR if we know that there
2425 /// are no common bits set in the operands).
2426 SDValue DAGCombiner::visitADDLike(SDNode *N) {
2427 SDValue N0 = N->getOperand(0);
2428 SDValue N1 = N->getOperand(1);
2429 EVT VT = N0.getValueType();
2430 SDLoc DL(N);
2431
2432 // fold (add x, undef) -> undef
2433 if (N0.isUndef())
2434 return N0;
2435 if (N1.isUndef())
2436 return N1;
2437
2438 // fold (add c1, c2) -> c1+c2
2439 if (SDValue C = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N0, N1}))
2440 return C;
2441
2442 // canonicalize constant to RHS
2443 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
2444 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
2445 return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
2446
2447 // fold vector ops
2448 if (VT.isVector()) {
2449 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
2450 return FoldedVOp;
2451
2452 // fold (add x, 0) -> x, vector edition
2453 if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
2454 return N0;
2455 }
2456
2457 // fold (add x, 0) -> x
2458 if (isNullConstant(N1))
2459 return N0;
2460
2461 if (N0.getOpcode() == ISD::SUB) {
2462 SDValue N00 = N0.getOperand(0);
2463 SDValue N01 = N0.getOperand(1);
2464
2465 // fold ((A-c1)+c2) -> (A+(c2-c1))
2466 if (SDValue Sub = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N1, N01}))
2467 return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Sub);
2468
2469 // fold ((c1-A)+c2) -> (c1+c2)-A
2470 if (SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N00}))
2471 return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
2472 }
2473
2474 // add (sext i1 X), 1 -> zext (not i1 X)
2475 // We don't transform this pattern:
2476 // add (zext i1 X), -1 -> sext (not i1 X)
2477 // because most (?) targets generate better code for the zext form.
2478 if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
2479 isOneOrOneSplat(N1)) {
2480 SDValue X = N0.getOperand(0);
2481 if ((!LegalOperations ||
2482 (TLI.isOperationLegal(ISD::XOR, X.getValueType()) &&
2483 TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) &&
2484 X.getScalarValueSizeInBits() == 1) {
2485 SDValue Not = DAG.getNOT(DL, X, X.getValueType());
2486 return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Not);
2487 }
2488 }
2489
2490 // Fold (add (or x, c0), c1) -> (add x, (c0 + c1))
2491 // iff (or x, c0) is equivalent to (add x, c0).
2492 // Fold (add (xor x, c0), c1) -> (add x, (c0 + c1))
2493 // iff (xor x, c0) is equivalent to (add x, c0).
2494 if (isADDLike(N0, DAG)) {
2495 SDValue N01 = N0.getOperand(1);
2496 if (SDValue Add = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N1, N01}))
2497 return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), Add);
2498 }
2499
2500 if (SDValue NewSel = foldBinOpIntoSelect(N))
2501 return NewSel;
2502
2503 // reassociate add
2504 if (!reassociationCanBreakAddressingModePattern(ISD::ADD, DL, N, N0, N1)) {
2505 if (SDValue RADD = reassociateOps(ISD::ADD, DL, N0, N1, N->getFlags()))
2506 return RADD;
2507
2508     // Reassociate (add (or x, c), y) -> (add (add x, y), c) if (or x, c) is
2509     // equivalent to (add x, c).
2510     // Reassociate (add (xor x, c), y) -> (add (add x, y), c) if (xor x, c) is
2511     // equivalent to (add x, c).
2512 auto ReassociateAddOr = [&](SDValue N0, SDValue N1) {
2513 if (isADDLike(N0, DAG) && N0.hasOneUse() &&
2514 isConstantOrConstantVector(N0.getOperand(1), /* NoOpaque */ true)) {
2515 return DAG.getNode(ISD::ADD, DL, VT,
2516 DAG.getNode(ISD::ADD, DL, VT, N1, N0.getOperand(0)),
2517 N0.getOperand(1));
2518 }
2519 return SDValue();
2520 };
2521 if (SDValue Add = ReassociateAddOr(N0, N1))
2522 return Add;
2523 if (SDValue Add = ReassociateAddOr(N1, N0))
2524 return Add;
2525 }
2526 // fold ((0-A) + B) -> B-A
2527 if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(N0.getOperand(0)))
2528 return DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
2529
2530 // fold (A + (0-B)) -> A-B
2531 if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
2532 return DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(1));
2533
2534 // fold (A+(B-A)) -> B
2535 if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(1))
2536 return N1.getOperand(0);
2537
2538 // fold ((B-A)+A) -> B
2539 if (N0.getOpcode() == ISD::SUB && N1 == N0.getOperand(1))
2540 return N0.getOperand(0);
2541
2542 // fold ((A-B)+(C-A)) -> (C-B)
2543 if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
2544 N0.getOperand(0) == N1.getOperand(1))
2545 return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
2546 N0.getOperand(1));
2547
2548 // fold ((A-B)+(B-C)) -> (A-C)
2549 if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
2550 N0.getOperand(1) == N1.getOperand(0))
2551 return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
2552 N1.getOperand(1));
2553
2554 // fold (A+(B-(A+C))) to (B-C)
2555 if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
2556 N0 == N1.getOperand(1).getOperand(0))
2557 return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
2558 N1.getOperand(1).getOperand(1));
2559
2560 // fold (A+(B-(C+A))) to (B-C)
2561 if (N1.getOpcode() == ISD::SUB && N1.getOperand(1).getOpcode() == ISD::ADD &&
2562 N0 == N1.getOperand(1).getOperand(1))
2563 return DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(0),
2564 N1.getOperand(1).getOperand(0));
2565
2566 // fold (A+((B-A)+or-C)) to (B+or-C)
2567 if ((N1.getOpcode() == ISD::SUB || N1.getOpcode() == ISD::ADD) &&
2568 N1.getOperand(0).getOpcode() == ISD::SUB &&
2569 N0 == N1.getOperand(0).getOperand(1))
2570 return DAG.getNode(N1.getOpcode(), DL, VT, N1.getOperand(0).getOperand(0),
2571 N1.getOperand(1));
2572
2573 // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
2574 if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
2575 N0->hasOneUse() && N1->hasOneUse()) {
2576 SDValue N00 = N0.getOperand(0);
2577 SDValue N01 = N0.getOperand(1);
2578 SDValue N10 = N1.getOperand(0);
2579 SDValue N11 = N1.getOperand(1);
2580
2581 if (isConstantOrConstantVector(N00) || isConstantOrConstantVector(N10))
2582 return DAG.getNode(ISD::SUB, DL, VT,
2583 DAG.getNode(ISD::ADD, SDLoc(N0), VT, N00, N10),
2584 DAG.getNode(ISD::ADD, SDLoc(N1), VT, N01, N11));
2585 }
2586
2587 // fold (add (umax X, C), -C) --> (usubsat X, C)
2588 if (N0.getOpcode() == ISD::UMAX && hasOperation(ISD::USUBSAT, VT)) {
2589 auto MatchUSUBSAT = [](ConstantSDNode *Max, ConstantSDNode *Op) {
2590 return (!Max && !Op) ||
2591 (Max && Op && Max->getAPIntValue() == (-Op->getAPIntValue()));
2592 };
2593 if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchUSUBSAT,
2594 /*AllowUndefs*/ true))
2595 return DAG.getNode(ISD::USUBSAT, DL, VT, N0.getOperand(0),
2596 N0.getOperand(1));
2597 }
2598
2599 if (SimplifyDemandedBits(SDValue(N, 0)))
2600 return SDValue(N, 0);
2601
2602 if (isOneOrOneSplat(N1)) {
2603 // fold (add (xor a, -1), 1) -> (sub 0, a)
2604 if (isBitwiseNot(N0))
2605 return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
2606 N0.getOperand(0));
2607
2608 // fold (add (add (xor a, -1), b), 1) -> (sub b, a)
2609 if (N0.getOpcode() == ISD::ADD) {
2610 SDValue A, Xor;
2611
2612 if (isBitwiseNot(N0.getOperand(0))) {
2613 A = N0.getOperand(1);
2614 Xor = N0.getOperand(0);
2615 } else if (isBitwiseNot(N0.getOperand(1))) {
2616 A = N0.getOperand(0);
2617 Xor = N0.getOperand(1);
2618 }
2619
2620 if (Xor)
2621 return DAG.getNode(ISD::SUB, DL, VT, A, Xor.getOperand(0));
2622 }
2623
2624 // Look for:
2625 // add (add x, y), 1
2626 // And if the target does not like this form then turn into:
2627 // sub y, (xor x, -1)
2628 if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
2629 N0.hasOneUse()) {
2630 SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0),
2631 DAG.getAllOnesConstant(DL, VT));
2632 return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(1), Not);
2633 }
2634 }
2635
2636 // (x - y) + -1 -> add (xor y, -1), x
2637 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
2638 isAllOnesOrAllOnesSplat(N1)) {
2639 SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1), N1);
2640 return DAG.getNode(ISD::ADD, DL, VT, Xor, N0.getOperand(0));
2641 }
2642
2643 if (SDValue Combined = visitADDLikeCommutative(N0, N1, N))
2644 return Combined;
2645
2646 if (SDValue Combined = visitADDLikeCommutative(N1, N0, N))
2647 return Combined;
2648
2649 return SDValue();
2650 }
2651
2652 SDValue DAGCombiner::visitADD(SDNode *N) {
2653 SDValue N0 = N->getOperand(0);
2654 SDValue N1 = N->getOperand(1);
2655 EVT VT = N0.getValueType();
2656 SDLoc DL(N);
2657
2658 if (SDValue Combined = visitADDLike(N))
2659 return Combined;
2660
2661 if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
2662 return V;
2663
2664 if (SDValue V = foldAddSubOfSignBit(N, DAG))
2665 return V;
2666
2667 // fold (a+b) -> (a|b) iff a and b share no bits.
2668 if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
2669 DAG.haveNoCommonBitsSet(N0, N1))
2670 return DAG.getNode(ISD::OR, DL, VT, N0, N1);
2671
2672 // Fold (add (vscale * C0), (vscale * C1)) to (vscale * (C0 + C1)).
2673 if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) {
2674 const APInt &C0 = N0->getConstantOperandAPInt(0);
2675 const APInt &C1 = N1->getConstantOperandAPInt(0);
2676 return DAG.getVScale(DL, VT, C0 + C1);
2677 }
2678
2679 // fold a+vscale(c1)+vscale(c2) -> a+vscale(c1+c2)
2680 if (N0.getOpcode() == ISD::ADD &&
2681 N0.getOperand(1).getOpcode() == ISD::VSCALE &&
2682 N1.getOpcode() == ISD::VSCALE) {
2683 const APInt &VS0 = N0.getOperand(1)->getConstantOperandAPInt(0);
2684 const APInt &VS1 = N1->getConstantOperandAPInt(0);
2685 SDValue VS = DAG.getVScale(DL, VT, VS0 + VS1);
2686 return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), VS);
2687 }
2688
2689   // Fold (add step_vector(c1), step_vector(c2)) to step_vector(c1+c2)
2690 if (N0.getOpcode() == ISD::STEP_VECTOR &&
2691 N1.getOpcode() == ISD::STEP_VECTOR) {
2692 const APInt &C0 = N0->getConstantOperandAPInt(0);
2693 const APInt &C1 = N1->getConstantOperandAPInt(0);
2694 APInt NewStep = C0 + C1;
2695 return DAG.getStepVector(DL, VT, NewStep);
2696 }
2697
2698 // Fold a + step_vector(c1) + step_vector(c2) to a + step_vector(c1+c2)
2699 if (N0.getOpcode() == ISD::ADD &&
2700 N0.getOperand(1).getOpcode() == ISD::STEP_VECTOR &&
2701 N1.getOpcode() == ISD::STEP_VECTOR) {
2702 const APInt &SV0 = N0.getOperand(1)->getConstantOperandAPInt(0);
2703 const APInt &SV1 = N1->getConstantOperandAPInt(0);
2704 APInt NewStep = SV0 + SV1;
2705 SDValue SV = DAG.getStepVector(DL, VT, NewStep);
2706 return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), SV);
2707 }
2708
2709 return SDValue();
2710 }
2711
2712 SDValue DAGCombiner::visitADDSAT(SDNode *N) {
2713 unsigned Opcode = N->getOpcode();
2714 SDValue N0 = N->getOperand(0);
2715 SDValue N1 = N->getOperand(1);
2716 EVT VT = N0.getValueType();
2717 SDLoc DL(N);
2718
2719 // fold (add_sat x, undef) -> -1
2720 if (N0.isUndef() || N1.isUndef())
2721 return DAG.getAllOnesConstant(DL, VT);
2722
2723 // fold (add_sat c1, c2) -> c3
2724 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
2725 return C;
2726
2727 // canonicalize constant to RHS
2728 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
2729 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
2730 return DAG.getNode(Opcode, DL, VT, N1, N0);
2731
2732 // fold vector ops
2733 if (VT.isVector()) {
2734 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
2735 return FoldedVOp;
2736
2737 // fold (add_sat x, 0) -> x, vector edition
2738 if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
2739 return N0;
2740 }
2741
2742 // fold (add_sat x, 0) -> x
2743 if (isNullConstant(N1))
2744 return N0;
2745
2746 // If it cannot overflow, transform into an add.
2747 if (Opcode == ISD::UADDSAT)
2748 if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
2749 return DAG.getNode(ISD::ADD, DL, VT, N0, N1);
2750
2751 return SDValue();
2752 }
2753
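/// Look through TRUNCATE/ZERO_EXTEND and an AND with 1 to find a usable carry
/// value, i.e. result 1 of an ADDCARRY/SUBCARRY/UADDO/USUBO node whose boolean
/// contents are known to be 0 or 1 (or have been masked down to that).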
2754 static SDValue getAsCarry(const TargetLowering &TLI, SDValue V) {
2755 bool Masked = false;
2756
2757 // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization.
2758 while (true) {
2759 if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) {
2760 V = V.getOperand(0);
2761 continue;
2762 }
2763
2764 if (V.getOpcode() == ISD::AND && isOneConstant(V.getOperand(1))) {
2765 Masked = true;
2766 V = V.getOperand(0);
2767 continue;
2768 }
2769
2770 break;
2771 }
2772
2773 // If this is not a carry, return.
2774 if (V.getResNo() != 1)
2775 return SDValue();
2776
2777 if (V.getOpcode() != ISD::ADDCARRY && V.getOpcode() != ISD::SUBCARRY &&
2778 V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO)
2779 return SDValue();
2780
2781 EVT VT = V->getValueType(0);
2782 if (!TLI.isOperationLegalOrCustom(V.getOpcode(), VT))
2783 return SDValue();
2784
2785   // If the result is masked, then no matter what kind of bool it is we can
2786   // return it. If it isn't, then we need to make sure the bool type is either
2787   // 0 or 1 and not other values.
2788 if (Masked ||
2789 TLI.getBooleanContents(V.getValueType()) ==
2790 TargetLoweringBase::ZeroOrOneBooleanContent)
2791 return V;
2792
2793 return SDValue();
2794 }
2795
2796 /// Given the operands of an add/sub operation, see if the 2nd operand is a
2797 /// masked 0/1 whose source operand is actually known to be 0/-1. If so, invert
2798 /// the opcode and bypass the mask operation.
2799 static SDValue foldAddSubMasked1(bool IsAdd, SDValue N0, SDValue N1,
2800 SelectionDAG &DAG, const SDLoc &DL) {
2801 if (N1.getOpcode() == ISD::ZERO_EXTEND)
2802 N1 = N1.getOperand(0);
2803
2804 if (N1.getOpcode() != ISD::AND || !isOneOrOneSplat(N1->getOperand(1)))
2805 return SDValue();
2806
2807 EVT VT = N0.getValueType();
2808 SDValue N10 = N1.getOperand(0);
2809 if (N10.getValueType() != VT && N10.getOpcode() == ISD::TRUNCATE)
2810 N10 = N10.getOperand(0);
2811
2812 if (N10.getValueType() != VT)
2813 return SDValue();
2814
2815 if (DAG.ComputeNumSignBits(N10) != VT.getScalarSizeInBits())
2816 return SDValue();
2817
2818 // add N0, (and (AssertSext X, i1), 1) --> sub N0, X
2819 // sub N0, (and (AssertSext X, i1), 1) --> add N0, X
2820 return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N0, N10);
2821 }
2822
2823 /// Helper for doing combines based on N0 and N1 being added to each other.
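/// The caller invokes this twice, once per operand order, so the patterns
/// below only need to match one orientation.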
2824 SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1,
2825 SDNode *LocReference) {
2826 EVT VT = N0.getValueType();
2827 SDLoc DL(LocReference);
2828
2829 // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
2830 if (N1.getOpcode() == ISD::SHL && N1.getOperand(0).getOpcode() == ISD::SUB &&
2831 isNullOrNullSplat(N1.getOperand(0).getOperand(0)))
2832 return DAG.getNode(ISD::SUB, DL, VT, N0,
2833 DAG.getNode(ISD::SHL, DL, VT,
2834 N1.getOperand(0).getOperand(1),
2835 N1.getOperand(1)));
2836
2837 if (SDValue V = foldAddSubMasked1(true, N0, N1, DAG, DL))
2838 return V;
2839
2840 // Look for:
2841 // add (add x, 1), y
2842 // And if the target does not like this form then turn into:
2843 // sub y, (xor x, -1)
2844 if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
2845 N0.hasOneUse() && isOneOrOneSplat(N0.getOperand(1))) {
2846 SDValue Not = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(0),
2847 DAG.getAllOnesConstant(DL, VT));
2848 return DAG.getNode(ISD::SUB, DL, VT, N1, Not);
2849 }
2850
2851 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse()) {
2852 // Hoist one-use subtraction by non-opaque constant:
2853 // (x - C) + y -> (x + y) - C
2854 // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
2855 if (isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
2856 SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), N1);
2857 return DAG.getNode(ISD::SUB, DL, VT, Add, N0.getOperand(1));
2858 }
2859 // Hoist one-use subtraction from non-opaque constant:
2860 // (C - x) + y -> (y - x) + C
2861 if (isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
2862 SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N1, N0.getOperand(1));
2863 return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(0));
2864 }
2865 }
2866
2867 // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1'
2868 // rather than 'add 0/-1' (the zext should get folded).
2869 // add (sext i1 Y), X --> sub X, (zext i1 Y)
2870 if (N0.getOpcode() == ISD::SIGN_EXTEND &&
2871 N0.getOperand(0).getScalarValueSizeInBits() == 1 &&
2872 TLI.getBooleanContents(VT) == TargetLowering::ZeroOrOneBooleanContent) {
2873 SDValue ZExt = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
2874 return DAG.getNode(ISD::SUB, DL, VT, N1, ZExt);
2875 }
2876
2877 // add X, (sextinreg Y i1) -> sub X, (and Y 1)
2878 if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
2879 VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
2880 if (TN->getVT() == MVT::i1) {
2881 SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
2882 DAG.getConstant(1, DL, VT));
2883 return DAG.getNode(ISD::SUB, DL, VT, N0, ZExt);
2884 }
2885 }
2886
2887 // (add X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry)
2888 if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1)) &&
2889 N1.getResNo() == 0)
2890 return DAG.getNode(ISD::ADDCARRY, DL, N1->getVTList(),
2891 N0, N1.getOperand(0), N1.getOperand(2));
2892
2893 // (add X, Carry) -> (addcarry X, 0, Carry)
2894 if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT))
2895 if (SDValue Carry = getAsCarry(TLI, N1))
2896 return DAG.getNode(ISD::ADDCARRY, DL,
2897 DAG.getVTList(VT, Carry.getValueType()), N0,
2898 DAG.getConstant(0, DL, VT), Carry);
2899
2900 return SDValue();
2901 }
2902
2903 SDValue DAGCombiner::visitADDC(SDNode *N) {
2904 SDValue N0 = N->getOperand(0);
2905 SDValue N1 = N->getOperand(1);
2906 EVT VT = N0.getValueType();
2907 SDLoc DL(N);
2908
2909 // If the flag result is dead, turn this into an ADD.
2910 if (!N->hasAnyUseOfValue(1))
2911 return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
2912 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
2913
2914 // canonicalize constant to RHS.
2915 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
2916 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
2917 if (N0C && !N1C)
2918 return DAG.getNode(ISD::ADDC, DL, N->getVTList(), N1, N0);
2919
2920 // fold (addc x, 0) -> x + no carry out
2921 if (isNullConstant(N1))
2922 return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE,
2923 DL, MVT::Glue));
2924
2925 // If it cannot overflow, transform into an add.
2926 if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
2927 return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
2928 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
2929
2930 return SDValue();
2931 }
2932
2933 /**
2934 * Flips a boolean if it is cheaper to compute. If the Force parameter is set,
2935 * then the flip also occurs if computing the inverse is the same cost.
2936 * This function returns an empty SDValue in case it cannot flip the boolean
2937 * without increasing the cost of the computation. If you want to flip a boolean
2938 * no matter what, use DAG.getLogicalNOT.
2939 */
2940 static SDValue extractBooleanFlip(SDValue V, SelectionDAG &DAG,
2941 const TargetLowering &TLI,
2942 bool Force) {
2943 if (Force && isa<ConstantSDNode>(V))
2944 return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());
2945
2946 if (V.getOpcode() != ISD::XOR)
2947 return SDValue();
2948
2949 ConstantSDNode *Const = isConstOrConstSplat(V.getOperand(1), false);
2950 if (!Const)
2951 return SDValue();
2952
2953 EVT VT = V.getValueType();
2954
2955 bool IsFlip = false;
2956 switch(TLI.getBooleanContents(VT)) {
2957 case TargetLowering::ZeroOrOneBooleanContent:
2958 IsFlip = Const->isOne();
2959 break;
2960 case TargetLowering::ZeroOrNegativeOneBooleanContent:
2961 IsFlip = Const->isAllOnes();
2962 break;
2963 case TargetLowering::UndefinedBooleanContent:
2964 IsFlip = (Const->getAPIntValue() & 0x01) == 1;
2965 break;
2966 }
2967
2968 if (IsFlip)
2969 return V.getOperand(0);
2970 if (Force)
2971 return DAG.getLogicalNOT(SDLoc(V), V, V.getValueType());
2972 return SDValue();
2973 }
2974
2975 SDValue DAGCombiner::visitADDO(SDNode *N) {
2976 SDValue N0 = N->getOperand(0);
2977 SDValue N1 = N->getOperand(1);
2978 EVT VT = N0.getValueType();
2979 bool IsSigned = (ISD::SADDO == N->getOpcode());
2980
2981 EVT CarryVT = N->getValueType(1);
2982 SDLoc DL(N);
2983
2984 // If the flag result is dead, turn this into an ADD.
2985 if (!N->hasAnyUseOfValue(1))
2986 return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
2987 DAG.getUNDEF(CarryVT));
2988
2989 // canonicalize constant to RHS.
2990 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
2991 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
2992 return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);
2993
2994 // fold (addo x, 0) -> x + no carry out
2995 if (isNullOrNullSplat(N1))
2996 return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
2997
2998 if (!IsSigned) {
2999 // If it cannot overflow, transform into an add.
3000 if (DAG.computeOverflowKind(N0, N1) == SelectionDAG::OFK_Never)
3001 return CombineTo(N, DAG.getNode(ISD::ADD, DL, VT, N0, N1),
3002 DAG.getConstant(0, DL, CarryVT));
3003
3004 // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry.
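// ((xor a, -1) + 1 == -a, i.e. the value computed by (usubo 0, a). The uaddo
//  carries exactly when a == 0, which is precisely when the usubo does not
//  borrow, hence the flipped carry.)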
3005 if (isBitwiseNot(N0) && isOneOrOneSplat(N1)) {
3006 SDValue Sub = DAG.getNode(ISD::USUBO, DL, N->getVTList(),
3007 DAG.getConstant(0, DL, VT), N0.getOperand(0));
3008 return CombineTo(
3009 N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
3010 }
3011
3012 if (SDValue Combined = visitUADDOLike(N0, N1, N))
3013 return Combined;
3014
3015 if (SDValue Combined = visitUADDOLike(N1, N0, N))
3016 return Combined;
3017 }
3018
3019 return SDValue();
3020 }
3021
3022 SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) {
3023 EVT VT = N0.getValueType();
3024 if (VT.isVector())
3025 return SDValue();
3026
3027 // (uaddo X, (addcarry Y, 0, Carry)) -> (addcarry X, Y, Carry)
3028 // If Y + 1 cannot overflow.
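// (Carry is a single bit, so if Y + 1 cannot overflow then Y + Carry cannot
//  overflow either; the inner addition never produces a carry of its own and
//  the two additions can be merged into a single addcarry.)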
3029 if (N1.getOpcode() == ISD::ADDCARRY && isNullConstant(N1.getOperand(1))) {
3030 SDValue Y = N1.getOperand(0);
3031 SDValue One = DAG.getConstant(1, SDLoc(N), Y.getValueType());
3032 if (DAG.computeOverflowKind(Y, One) == SelectionDAG::OFK_Never)
3033 return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0, Y,
3034 N1.getOperand(2));
3035 }
3036
3037 // (uaddo X, Carry) -> (addcarry X, 0, Carry)
3038 if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT))
3039 if (SDValue Carry = getAsCarry(TLI, N1))
3040 return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(), N0,
3041 DAG.getConstant(0, SDLoc(N), VT), Carry);
3042
3043 return SDValue();
3044 }
3045
3046 SDValue DAGCombiner::visitADDE(SDNode *N) {
3047 SDValue N0 = N->getOperand(0);
3048 SDValue N1 = N->getOperand(1);
3049 SDValue CarryIn = N->getOperand(2);
3050
3051 // canonicalize constant to RHS
3052 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3053 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3054 if (N0C && !N1C)
3055 return DAG.getNode(ISD::ADDE, SDLoc(N), N->getVTList(),
3056 N1, N0, CarryIn);
3057
3058 // fold (adde x, y, false) -> (addc x, y)
3059 if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
3060 return DAG.getNode(ISD::ADDC, SDLoc(N), N->getVTList(), N0, N1);
3061
3062 return SDValue();
3063 }
3064
3065 SDValue DAGCombiner::visitADDCARRY(SDNode *N) {
3066 SDValue N0 = N->getOperand(0);
3067 SDValue N1 = N->getOperand(1);
3068 SDValue CarryIn = N->getOperand(2);
3069 SDLoc DL(N);
3070
3071 // canonicalize constant to RHS
3072 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3073 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3074 if (N0C && !N1C)
3075 return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), N1, N0, CarryIn);
3076
3077 // fold (addcarry x, y, false) -> (uaddo x, y)
3078 if (isNullConstant(CarryIn)) {
3079 if (!LegalOperations ||
3080 TLI.isOperationLegalOrCustom(ISD::UADDO, N->getValueType(0)))
3081 return DAG.getNode(ISD::UADDO, DL, N->getVTList(), N0, N1);
3082 }
3083
3084 // fold (addcarry 0, 0, X) -> (and (ext/trunc X), 1) and no carry.
3085 if (isNullConstant(N0) && isNullConstant(N1)) {
3086 EVT VT = N0.getValueType();
3087 EVT CarryVT = CarryIn.getValueType();
3088 SDValue CarryExt = DAG.getBoolExtOrTrunc(CarryIn, DL, VT, CarryVT);
3089 AddToWorklist(CarryExt.getNode());
3090 return CombineTo(N, DAG.getNode(ISD::AND, DL, VT, CarryExt,
3091 DAG.getConstant(1, DL, VT)),
3092 DAG.getConstant(0, DL, CarryVT));
3093 }
3094
3095 if (SDValue Combined = visitADDCARRYLike(N0, N1, CarryIn, N))
3096 return Combined;
3097
3098 if (SDValue Combined = visitADDCARRYLike(N1, N0, CarryIn, N))
3099 return Combined;
3100
3101 // We want to avoid useless duplication.
3102 // TODO: This is done automatically for binary operations. As ADDCARRY is
3103 // not a binary operation, it is not really possible to leverage this
3104 // existing mechanism for it. However, if more operations require the same
3105 // deduplication logic, then it may be worth generalizing.
3106 SDValue Ops[] = {N1, N0, CarryIn};
3107 SDNode *CSENode =
3108 DAG.getNodeIfExists(ISD::ADDCARRY, N->getVTList(), Ops, N->getFlags());
3109 if (CSENode)
3110 return SDValue(CSENode, 0);
3111
3112 return SDValue();
3113 }
3114
3115 SDValue DAGCombiner::visitSADDO_CARRY(SDNode *N) {
3116 SDValue N0 = N->getOperand(0);
3117 SDValue N1 = N->getOperand(1);
3118 SDValue CarryIn = N->getOperand(2);
3119 SDLoc DL(N);
3120
3121 // canonicalize constant to RHS
3122 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
3123 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
3124 if (N0C && !N1C)
3125 return DAG.getNode(ISD::SADDO_CARRY, DL, N->getVTList(), N1, N0, CarryIn);
3126
3127 // fold (saddo_carry x, y, false) -> (saddo x, y)
3128 if (isNullConstant(CarryIn)) {
3129 if (!LegalOperations ||
3130 TLI.isOperationLegalOrCustom(ISD::SADDO, N->getValueType(0)))
3131 return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0, N1);
3132 }
3133
3134 return SDValue();
3135 }
3136
3137 /**
3138 * If we are facing some sort of diamond carry propagation pattern, try to
3139 * break it up to generate something like:
3140 * (addcarry X, 0, (addcarry A, B, Z):Carry)
3141 *
3142 * The end result is usually an increase in the number of operations required,
3143 * but because the carry is now linearized, other transforms can kick in and optimize the DAG.
3144 *
3145 * Patterns typically look something like
3146 * (uaddo A, B)
3147 * / \
3148 * Carry Sum
3149 * | \
3150 * | (addcarry *, 0, Z)
3151 * | /
3152 * \ Carry
3153 * | /
3154 * (addcarry X, *, *)
3155 *
3156 * But numerous variations exist. Our goal is to identify A, B, X and Z and
3157 * produce a combine with a single path for carry propagation.
3158 */
3159 static SDValue combineADDCARRYDiamond(DAGCombiner &Combiner, SelectionDAG &DAG,
3160 SDValue X, SDValue Carry0, SDValue Carry1,
3161 SDNode *N) {
3162 if (Carry1.getResNo() != 1 || Carry0.getResNo() != 1)
3163 return SDValue();
3164 if (Carry1.getOpcode() != ISD::UADDO)
3165 return SDValue();
3166
3167 SDValue Z;
3168
3169 /**
3170 * First look for a suitable Z. It will present itself in the form of
3171 * (addcarry Y, 0, Z) or its equivalent (uaddo Y, 1) for Z=true
3172 */
3173 if (Carry0.getOpcode() == ISD::ADDCARRY &&
3174 isNullConstant(Carry0.getOperand(1))) {
3175 Z = Carry0.getOperand(2);
3176 } else if (Carry0.getOpcode() == ISD::UADDO &&
3177 isOneConstant(Carry0.getOperand(1))) {
3178 EVT VT = Combiner.getSetCCResultType(Carry0.getValueType());
3179 Z = DAG.getConstant(1, SDLoc(Carry0.getOperand(1)), VT);
3180 } else {
3181 // We couldn't find a suitable Z.
3182 return SDValue();
3183 }
3184
3185
3186 auto cancelDiamond = [&](SDValue A,SDValue B) {
3187 SDLoc DL(N);
3188 SDValue NewY = DAG.getNode(ISD::ADDCARRY, DL, Carry0->getVTList(), A, B, Z);
3189 Combiner.AddToWorklist(NewY.getNode());
3190 return DAG.getNode(ISD::ADDCARRY, DL, N->getVTList(), X,
3191 DAG.getConstant(0, DL, X.getValueType()),
3192 NewY.getValue(1));
3193 };
3194
3195 /**
3196 * (uaddo A, B)
3197 * |
3198 * Sum
3199 * |
3200 * (addcarry *, 0, Z)
3201 */
3202 if (Carry0.getOperand(0) == Carry1.getValue(0)) {
3203 return cancelDiamond(Carry1.getOperand(0), Carry1.getOperand(1));
3204 }
3205
3206 /**
3207 * (addcarry A, 0, Z)
3208 * |
3209 * Sum
3210 * |
3211 * (uaddo *, B)
3212 */
3213 if (Carry1.getOperand(0) == Carry0.getValue(0)) {
3214 return cancelDiamond(Carry0.getOperand(0), Carry1.getOperand(1));
3215 }
3216
3217 if (Carry1.getOperand(1) == Carry0.getValue(0)) {
3218 return cancelDiamond(Carry1.getOperand(0), Carry0.getOperand(0));
3219 }
3220
3221 return SDValue();
3222 }
3223
3224 // If we are facing some sort of diamond carry/borrow in/out pattern, try to
3225 // match patterns like:
3226 //
3227 // (uaddo A, B) CarryIn
3228 // | \ |
3229 // | \ |
3230 // PartialSum PartialCarryOutX /
3231 // | | /
3232 // | ____|____________/
3233 // | / |
3234 // (uaddo *, *) \________
3235 // | \ \
3236 // | \ |
3237 // | PartialCarryOutY |
3238 // | \ |
3239 // | \ /
3240 // AddCarrySum | ______/
3241 // | /
3242 // CarryOut = (or *, *)
3243 //
3244 // And generate ADDCARRY (or SUBCARRY) with two result values:
3245 //
3246 // {AddCarrySum, CarryOut} = (addcarry A, B, CarryIn)
3247 //
3248 // Our goal is to identify A, B, and CarryIn and produce ADDCARRY/SUBCARRY with
3249 // a single path for carry/borrow out propagation:
3250 static SDValue combineCarryDiamond(SelectionDAG &DAG, const TargetLowering &TLI,
3251 SDValue N0, SDValue N1, SDNode *N) {
3252 SDValue Carry0 = getAsCarry(TLI, N0);
3253 if (!Carry0)
3254 return SDValue();
3255 SDValue Carry1 = getAsCarry(TLI, N1);
3256 if (!Carry1)
3257 return SDValue();
3258
3259 unsigned Opcode = Carry0.getOpcode();
3260 if (Opcode != Carry1.getOpcode())
3261 return SDValue();
3262 if (Opcode != ISD::UADDO && Opcode != ISD::USUBO)
3263 return SDValue();
3264
3265 // Canonicalize the add/sub of A and B (the top node in the above ASCII art)
3266 // as Carry0 and the add/sub of the carry in as Carry1 (the middle node).
3267 if (Carry1.getNode()->isOperandOf(Carry0.getNode()))
3268 std::swap(Carry0, Carry1);
3269
3270 // Check if nodes are connected in expected way.
3271 if (Carry1.getOperand(0) != Carry0.getValue(0) &&
3272 Carry1.getOperand(1) != Carry0.getValue(0))
3273 return SDValue();
3274
3275 // The carry in value must be on the righthand side for subtraction.
3276 unsigned CarryInOperandNum =
3277 Carry1.getOperand(0) == Carry0.getValue(0) ? 1 : 0;
3278 if (Opcode == ISD::USUBO && CarryInOperandNum != 1)
3279 return SDValue();
3280 SDValue CarryIn = Carry1.getOperand(CarryInOperandNum);
3281
3282 unsigned NewOp = Opcode == ISD::UADDO ? ISD::ADDCARRY : ISD::SUBCARRY;
3283 if (!TLI.isOperationLegalOrCustom(NewOp, Carry0.getValue(0).getValueType()))
3284 return SDValue();
3285
3286 // Verify that the carry/borrow in is plausibly a carry/borrow bit.
3287 // TODO: make getAsCarry() aware of how partial carries are merged.
3288 if (CarryIn.getOpcode() != ISD::ZERO_EXTEND)
3289 return SDValue();
3290 CarryIn = CarryIn.getOperand(0);
3291 if (CarryIn.getValueType() != MVT::i1)
3292 return SDValue();
3293
3294 SDLoc DL(N);
3295 SDValue Merged =
3296 DAG.getNode(NewOp, DL, Carry1->getVTList(), Carry0.getOperand(0),
3297 Carry0.getOperand(1), CarryIn);
3298
3299 // Note that because we have proven that the result of the UADDO/USUBO
3300 // of A and B feeds into the UADDO/USUBO that does the carry/borrow in, we know
3301 // that if the first UADDO/USUBO overflows, the second
3302 // UADDO/USUBO cannot. For example, consider 8-bit numbers where 0xFF is the
3303 // maximum value.
3304 //
3305 // 0xFF + 0xFF == 0xFE with carry but 0xFE + 1 does not carry
3306 // 0x00 - 0xFF == 1 with a carry/borrow but 1 - 1 == 0 (no carry/borrow)
3307 //
3308 // This is important because it means that OR and XOR can be used to merge
3309 // carry flags; and that AND can return a constant zero.
3310 //
3311 // TODO: match other operations that can merge flags (ADD, etc)
3312 DAG.ReplaceAllUsesOfValueWith(Carry1.getValue(0), Merged.getValue(0));
3313 if (N->getOpcode() == ISD::AND)
3314 return DAG.getConstant(0, DL, MVT::i1);
3315 return Merged.getValue(1);
3316 }
3317
3318 SDValue DAGCombiner::visitADDCARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
3319 SDNode *N) {
3320 // fold (addcarry (xor a, -1), b, c) -> (subcarry b, a, !c) and flip carry.
3321 if (isBitwiseNot(N0))
3322 if (SDValue NotC = extractBooleanFlip(CarryIn, DAG, TLI, true)) {
3323 SDLoc DL(N);
3324 SDValue Sub = DAG.getNode(ISD::SUBCARRY, DL, N->getVTList(), N1,
3325 N0.getOperand(0), NotC);
3326 return CombineTo(
3327 N, Sub, DAG.getLogicalNOT(DL, Sub.getValue(1), Sub->getValueType(1)));
3328 }
3329
3330 // Iff the flag result is dead:
3331 // (addcarry (add|uaddo X, Y), 0, Carry) -> (addcarry X, Y, Carry)
3332 // Don't do this if the Carry comes from the uaddo. It won't remove the uaddo
3333 // or the dependency between the instructions.
3334 if ((N0.getOpcode() == ISD::ADD ||
3335 (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0 &&
3336 N0.getValue(1) != CarryIn)) &&
3337 isNullConstant(N1) && !N->hasAnyUseOfValue(1))
3338 return DAG.getNode(ISD::ADDCARRY, SDLoc(N), N->getVTList(),
3339 N0.getOperand(0), N0.getOperand(1), CarryIn);
3340
3341 /**
3342 * When one of the addcarry arguments is itself a carry, we may be facing
3343 * a diamond carry propagation, in which case we try to transform the DAG
3344 * to ensure linear carry propagation if that is possible.
3345 */
3346 if (auto Y = getAsCarry(TLI, N1)) {
3347 // Because both are carries, Y and Z can be swapped.
3348 if (auto R = combineADDCARRYDiamond(*this, DAG, N0, Y, CarryIn, N))
3349 return R;
3350 if (auto R = combineADDCARRYDiamond(*this, DAG, N0, CarryIn, Y, N))
3351 return R;
3352 }
3353
3354 return SDValue();
3355 }
3356
3357 // Attempt to create a USUBSAT(LHS, RHS) node with DstVT, performing a
3358 // clamp/truncation if necessary.
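// (If LHS already fits in DstVT, clamping RHS to DstVT's unsigned maximum first
//  keeps the saturating subtraction's result within DstVT, so the operation can
//  be performed in the narrower type.)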
3359 static SDValue getTruncatedUSUBSAT(EVT DstVT, EVT SrcVT, SDValue LHS,
3360 SDValue RHS, SelectionDAG &DAG,
3361 const SDLoc &DL) {
3362 assert(DstVT.getScalarSizeInBits() <= SrcVT.getScalarSizeInBits() &&
3363 "Illegal truncation");
3364
3365 if (DstVT == SrcVT)
3366 return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
3367
3368 // If the LHS is zero-extended then we can perform the USUBSAT as DstVT by
3369 // clamping RHS.
3370 APInt UpperBits = APInt::getBitsSetFrom(SrcVT.getScalarSizeInBits(),
3371 DstVT.getScalarSizeInBits());
3372 if (!DAG.MaskedValueIsZero(LHS, UpperBits))
3373 return SDValue();
3374
3375 SDValue SatLimit =
3376 DAG.getConstant(APInt::getLowBitsSet(SrcVT.getScalarSizeInBits(),
3377 DstVT.getScalarSizeInBits()),
3378 DL, SrcVT);
3379 RHS = DAG.getNode(ISD::UMIN, DL, SrcVT, RHS, SatLimit);
3380 RHS = DAG.getNode(ISD::TRUNCATE, DL, DstVT, RHS);
3381 LHS = DAG.getNode(ISD::TRUNCATE, DL, DstVT, LHS);
3382 return DAG.getNode(ISD::USUBSAT, DL, DstVT, LHS, RHS);
3383 }
3384
3385 // Try to find umax(a,b) - b or a - umin(a,b) patterns that may be converted to
3386 // usubsat(a,b), optionally as a truncated type.
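// (umax(a,b) - b == (a >= b ? a - b : 0) == usubsat(a,b), and likewise
//  a - umin(a,b) == (a >= b ? a - b : 0).)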
3387 SDValue DAGCombiner::foldSubToUSubSat(EVT DstVT, SDNode *N) {
3388 if (N->getOpcode() != ISD::SUB ||
3389 !(!LegalOperations || hasOperation(ISD::USUBSAT, DstVT)))
3390 return SDValue();
3391
3392 EVT SubVT = N->getValueType(0);
3393 SDValue Op0 = N->getOperand(0);
3394 SDValue Op1 = N->getOperand(1);
3395
3396 // Try to find umax(a,b) - b or a - umin(a,b) patterns that
3397 // may be converted to usubsat(a,b).
3398 if (Op0.getOpcode() == ISD::UMAX && Op0.hasOneUse()) {
3399 SDValue MaxLHS = Op0.getOperand(0);
3400 SDValue MaxRHS = Op0.getOperand(1);
3401 if (MaxLHS == Op1)
3402 return getTruncatedUSUBSAT(DstVT, SubVT, MaxRHS, Op1, DAG, SDLoc(N));
3403 if (MaxRHS == Op1)
3404 return getTruncatedUSUBSAT(DstVT, SubVT, MaxLHS, Op1, DAG, SDLoc(N));
3405 }
3406
3407 if (Op1.getOpcode() == ISD::UMIN && Op1.hasOneUse()) {
3408 SDValue MinLHS = Op1.getOperand(0);
3409 SDValue MinRHS = Op1.getOperand(1);
3410 if (MinLHS == Op0)
3411 return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinRHS, DAG, SDLoc(N));
3412 if (MinRHS == Op0)
3413 return getTruncatedUSUBSAT(DstVT, SubVT, Op0, MinLHS, DAG, SDLoc(N));
3414 }
3415
3416 // sub(a,trunc(umin(zext(a),b))) -> usubsat(a,trunc(umin(b,SatLimit)))
3417 if (Op1.getOpcode() == ISD::TRUNCATE &&
3418 Op1.getOperand(0).getOpcode() == ISD::UMIN &&
3419 Op1.getOperand(0).hasOneUse()) {
3420 SDValue MinLHS = Op1.getOperand(0).getOperand(0);
3421 SDValue MinRHS = Op1.getOperand(0).getOperand(1);
3422 if (MinLHS.getOpcode() == ISD::ZERO_EXTEND && MinLHS.getOperand(0) == Op0)
3423 return getTruncatedUSUBSAT(DstVT, MinLHS.getValueType(), MinLHS, MinRHS,
3424 DAG, SDLoc(N));
3425 if (MinRHS.getOpcode() == ISD::ZERO_EXTEND && MinRHS.getOperand(0) == Op0)
3426 return getTruncatedUSUBSAT(DstVT, MinLHS.getValueType(), MinRHS, MinLHS,
3427 DAG, SDLoc(N));
3428 }
3429
3430 return SDValue();
3431 }
3432
3433 // Since it may not be valid to emit a fold to zero for vector initializers,
3434 // check if we can before folding.
3435 static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
3436 SelectionDAG &DAG, bool LegalOperations) {
3437 if (!VT.isVector())
3438 return DAG.getConstant(0, DL, VT);
3439 if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT))
3440 return DAG.getConstant(0, DL, VT);
3441 return SDValue();
3442 }
3443
3444 SDValue DAGCombiner::visitSUB(SDNode *N) {
3445 SDValue N0 = N->getOperand(0);
3446 SDValue N1 = N->getOperand(1);
3447 EVT VT = N0.getValueType();
3448 SDLoc DL(N);
3449
3450 auto PeekThroughFreeze = [](SDValue N) {
3451 if (N->getOpcode() == ISD::FREEZE && N.hasOneUse())
3452 return N->getOperand(0);
3453 return N;
3454 };
3455
3456 // fold (sub x, x) -> 0
3457 // FIXME: Refactor this and xor and other similar operations together.
3458 if (PeekThroughFreeze(N0) == PeekThroughFreeze(N1))
3459 return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
3460
3461 // fold (sub c1, c2) -> c3
3462 if (SDValue C = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N1}))
3463 return C;
3464
3465 // fold vector ops
3466 if (VT.isVector()) {
3467 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
3468 return FoldedVOp;
3469
3470 // fold (sub x, 0) -> x, vector edition
3471 if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
3472 return N0;
3473 }
3474
3475 if (SDValue NewSel = foldBinOpIntoSelect(N))
3476 return NewSel;
3477
3478 ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
3479
3480 // fold (sub x, c) -> (add x, -c)
3481 if (N1C) {
3482 return DAG.getNode(ISD::ADD, DL, VT, N0,
3483 DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
3484 }
3485
3486 if (isNullOrNullSplat(N0)) {
3487 unsigned BitWidth = VT.getScalarSizeInBits();
3488 // Right-shifting everything out but the sign bit followed by negation is
3489 // the same as flipping arithmetic/logical shift type without the negation:
3490 // -(X >>u 31) -> (X >>s 31)
3491 // -(X >>s 31) -> (X >>u 31)
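// (Both shifts isolate the sign bit: the logical shift yields 0 or 1 and the
//  arithmetic shift yields 0 or -1, and negation converts one into the other.)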
3492 if (N1->getOpcode() == ISD::SRA || N1->getOpcode() == ISD::SRL) {
3493 ConstantSDNode *ShiftAmt = isConstOrConstSplat(N1.getOperand(1));
3494 if (ShiftAmt && ShiftAmt->getAPIntValue() == (BitWidth - 1)) {
3495 auto NewSh = N1->getOpcode() == ISD::SRA ? ISD::SRL : ISD::SRA;
3496 if (!LegalOperations || TLI.isOperationLegal(NewSh, VT))
3497 return DAG.getNode(NewSh, DL, VT, N1.getOperand(0), N1.getOperand(1));
3498 }
3499 }
3500
3501 // 0 - X --> 0 if the sub is NUW.
3502 if (N->getFlags().hasNoUnsignedWrap())
3503 return N0;
3504
3505 if (DAG.MaskedValueIsZero(N1, ~APInt::getSignMask(BitWidth))) {
3506 // N1 is either 0 or the minimum signed value. If the sub is NSW, then
3507 // N1 must be 0 because negating the minimum signed value is undefined.
3508 if (N->getFlags().hasNoSignedWrap())
3509 return N0;
3510
3511 // 0 - X --> X if X is 0 or the minimum signed value.
3512 return N1;
3513 }
3514
3515 // Convert 0 - abs(x).
3516 if (N1.getOpcode() == ISD::ABS && N1.hasOneUse() &&
3517 !TLI.isOperationLegalOrCustom(ISD::ABS, VT))
3518 if (SDValue Result = TLI.expandABS(N1.getNode(), DAG, true))
3519 return Result;
3520
3521 // Fold neg(splat(neg(x))) -> splat(x)
3522 if (VT.isVector()) {
3523 SDValue N1S = DAG.getSplatValue(N1, true);
3524 if (N1S && N1S.getOpcode() == ISD::SUB &&
3525 isNullConstant(N1S.getOperand(0)))
3526 return DAG.getSplat(VT, DL, N1S.getOperand(1));
3527 }
3528 }
3529
3530 // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1)
3531 if (isAllOnesOrAllOnesSplat(N0))
3532 return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
3533
3534 // fold (A - (0-B)) -> A+B
3535 if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(N1.getOperand(0)))
3536 return DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(1));
3537
3538 // fold A-(A-B) -> B
3539 if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(0))
3540 return N1.getOperand(1);
3541
3542 // fold (A+B)-A -> B
3543 if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N1)
3544 return N0.getOperand(1);
3545
3546 // fold (A+B)-B -> A
3547 if (N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1)
3548 return N0.getOperand(0);
3549
3550 // fold (A+C1)-C2 -> A+(C1-C2)
3551 if (N0.getOpcode() == ISD::ADD) {
3552 SDValue N01 = N0.getOperand(1);
3553 if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N01, N1}))
3554 return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), NewC);
3555 }
3556
3557 // fold C2-(A+C1) -> (C2-C1)-A
3558 if (N1.getOpcode() == ISD::ADD) {
3559 SDValue N11 = N1.getOperand(1);
3560 if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N11}))
3561 return DAG.getNode(ISD::SUB, DL, VT, NewC, N1.getOperand(0));
3562 }
3563
3564 // fold (A-C1)-C2 -> A-(C1+C2)
3565 if (N0.getOpcode() == ISD::SUB) {
3566 SDValue N01 = N0.getOperand(1);
3567 if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N01, N1}))
3568 return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), NewC);
3569 }
3570
3571 // fold (c1-A)-c2 -> (c1-c2)-A
3572 if (N0.getOpcode() == ISD::SUB) {
3573 SDValue N00 = N0.getOperand(0);
3574 if (SDValue NewC = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N00, N1}))
3575 return DAG.getNode(ISD::SUB, DL, VT, NewC, N0.getOperand(1));
3576 }
3577
3578 // fold ((A+(B+or-C))-B) -> A+or-C
3579 if (N0.getOpcode() == ISD::ADD &&
3580 (N0.getOperand(1).getOpcode() == ISD::SUB ||
3581 N0.getOperand(1).getOpcode() == ISD::ADD) &&
3582 N0.getOperand(1).getOperand(0) == N1)
3583 return DAG.getNode(N0.getOperand(1).getOpcode(), DL, VT, N0.getOperand(0),
3584 N0.getOperand(1).getOperand(1));
3585
3586 // fold ((A+(C+B))-B) -> A+C
3587 if (N0.getOpcode() == ISD::ADD && N0.getOperand(1).getOpcode() == ISD::ADD &&
3588 N0.getOperand(1).getOperand(1) == N1)
3589 return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0),
3590 N0.getOperand(1).getOperand(0));
3591
3592 // fold ((A-(B-C))-C) -> A-B
3593 if (N0.getOpcode() == ISD::SUB && N0.getOperand(1).getOpcode() == ISD::SUB &&
3594 N0.getOperand(1).getOperand(1) == N1)
3595 return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0),
3596 N0.getOperand(1).getOperand(0));
3597
3598 // fold (A-(B-C)) -> A+(C-B)
3599 if (N1.getOpcode() == ISD::SUB && N1.hasOneUse())
3600 return DAG.getNode(ISD::ADD, DL, VT, N0,
3601 DAG.getNode(ISD::SUB, DL, VT, N1.getOperand(1),
3602 N1.getOperand(0)));
3603
3604 // A - (A & B) -> A & (~B)
3605 if (N1.getOpcode() == ISD::AND) {
3606 SDValue A = N1.getOperand(0);
3607 SDValue B = N1.getOperand(1);
3608 if (A != N0)
3609 std::swap(A, B);
3610 if (A == N0 &&
3611 (N1.hasOneUse() || isConstantOrConstantVector(B, /*NoOpaques=*/true))) {
3612 SDValue InvB =
3613 DAG.getNode(ISD::XOR, DL, VT, B, DAG.getAllOnesConstant(DL, VT));
3614 return DAG.getNode(ISD::AND, DL, VT, A, InvB);
3615 }
3616 }
3617
3618 // fold (X - (-Y * Z)) -> (X + (Y * Z))
3619 if (N1.getOpcode() == ISD::MUL && N1.hasOneUse()) {
3620 if (N1.getOperand(0).getOpcode() == ISD::SUB &&
3621 isNullOrNullSplat(N1.getOperand(0).getOperand(0))) {
3622 SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
3623 N1.getOperand(0).getOperand(1),
3624 N1.getOperand(1));
3625 return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
3626 }
3627 if (N1.getOperand(1).getOpcode() == ISD::SUB &&
3628 isNullOrNullSplat(N1.getOperand(1).getOperand(0))) {
3629 SDValue Mul = DAG.getNode(ISD::MUL, DL, VT,
3630 N1.getOperand(0),
3631 N1.getOperand(1).getOperand(1));
3632 return DAG.getNode(ISD::ADD, DL, VT, N0, Mul);
3633 }
3634 }
3635
3636 // If either operand of a sub is undef, the result is undef
3637 if (N0.isUndef())
3638 return N0;
3639 if (N1.isUndef())
3640 return N1;
3641
3642 if (SDValue V = foldAddSubBoolOfMaskedVal(N, DAG))
3643 return V;
3644
3645 if (SDValue V = foldAddSubOfSignBit(N, DAG))
3646 return V;
3647
3648 if (SDValue V = foldAddSubMasked1(false, N0, N1, DAG, SDLoc(N)))
3649 return V;
3650
3651 if (SDValue V = foldSubToUSubSat(VT, N))
3652 return V;
3653
3654 // (x - y) - 1 -> add (xor y, -1), x
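// ((xor y, -1) == -y - 1, so adding it to x yields x - y - 1.)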
3655 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() && isOneOrOneSplat(N1)) {
3656 SDValue Xor = DAG.getNode(ISD::XOR, DL, VT, N0.getOperand(1),
3657 DAG.getAllOnesConstant(DL, VT));
3658 return DAG.getNode(ISD::ADD, DL, VT, Xor, N0.getOperand(0));
3659 }
3660
3661 // Look for:
3662 // sub y, (xor x, -1)
3663 // And if the target does not like this form then turn into:
3664 // add (add x, y), 1
3665 if (TLI.preferIncOfAddToSubOfNot(VT) && N1.hasOneUse() && isBitwiseNot(N1)) {
3666 SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, N1.getOperand(0));
3667 return DAG.getNode(ISD::ADD, DL, VT, Add, DAG.getConstant(1, DL, VT));
3668 }
3669
3670 // Hoist one-use addition by non-opaque constant:
3671 // (x + C) - y -> (x - y) + C
3672 if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
3673 isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
3674 SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
3675 return DAG.getNode(ISD::ADD, DL, VT, Sub, N0.getOperand(1));
3676 }
3677 // y - (x + C) -> (y - x) - C
3678 if (N1.getOpcode() == ISD::ADD && N1.hasOneUse() &&
3679 isConstantOrConstantVector(N1.getOperand(1), /*NoOpaques=*/true)) {
3680 SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, N1.getOperand(0));
3681 return DAG.getNode(ISD::SUB, DL, VT, Sub, N1.getOperand(1));
3682 }
3683 // (x - C) - y -> (x - y) - C
3684 // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
3685 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
3686 isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true)) {
3687 SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), N1);
3688 return DAG.getNode(ISD::SUB, DL, VT, Sub, N0.getOperand(1));
3689 }
3690 // (C - x) - y -> C - (x + y)
3691 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
3692 isConstantOrConstantVector(N0.getOperand(0), /*NoOpaques=*/true)) {
3693 SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1), N1);
3694 return DAG.getNode(ISD::SUB, DL, VT, N0.getOperand(0), Add);
3695 }
3696
3697 // If the target's bool is represented as 0/-1, prefer to make this 'add 0/-1'
3698 // rather than 'sub 0/1' (the sext should get folded).
3699 // sub X, (zext i1 Y) --> add X, (sext i1 Y)
3700 if (N1.getOpcode() == ISD::ZERO_EXTEND &&
3701 N1.getOperand(0).getScalarValueSizeInBits() == 1 &&
3702 TLI.getBooleanContents(VT) ==
3703 TargetLowering::ZeroOrNegativeOneBooleanContent) {
3704 SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N1.getOperand(0));
3705 return DAG.getNode(ISD::ADD, DL, VT, N0, SExt);
3706 }
3707
3708 // fold Y = sra (X, size(X)-1); sub (xor (X, Y), Y) -> (abs X)
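// (Y is the sign mask of X: 0 when X >= 0 and -1 when X < 0. Then (xor X, Y) - Y
//  equals X when Y == 0 and (~X) + 1 == -X when Y == -1, which is abs(X).)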
3709 if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
3710 if (N0.getOpcode() == ISD::XOR && N1.getOpcode() == ISD::SRA) {
3711 SDValue X0 = N0.getOperand(0), X1 = N0.getOperand(1);
3712 SDValue S0 = N1.getOperand(0);
3713 if ((X0 == S0 && X1 == N1) || (X0 == N1 && X1 == S0))
3714 if (ConstantSDNode *C = isConstOrConstSplat(N1.getOperand(1)))
3715 if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
3716 return DAG.getNode(ISD::ABS, SDLoc(N), VT, S0);
3717 }
3718 }
3719
3720 // If the relocation model supports it, consider symbol offsets.
3721 if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N0))
3722 if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
3723 // fold (sub Sym, c) -> Sym-c
3724 if (N1C && GA->getOpcode() == ISD::GlobalAddress)
3725 return DAG.getGlobalAddress(GA->getGlobal(), SDLoc(N1C), VT,
3726 GA->getOffset() -
3727 (uint64_t)N1C->getSExtValue());
3728 // fold (sub Sym+c1, Sym+c2) -> c1-c2
3729 if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(N1))
3730 if (GA->getGlobal() == GB->getGlobal())
3731 return DAG.getConstant((uint64_t)GA->getOffset() - GB->getOffset(),
3732 DL, VT);
3733 }
3734
3735 // sub X, (sextinreg Y i1) -> add X, (and Y 1)
3736 if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
3737 VTSDNode *TN = cast<VTSDNode>(N1.getOperand(1));
3738 if (TN->getVT() == MVT::i1) {
3739 SDValue ZExt = DAG.getNode(ISD::AND, DL, VT, N1.getOperand(0),
3740 DAG.getConstant(1, DL, VT));
3741 return DAG.getNode(ISD::ADD, DL, VT, N0, ZExt);
3742 }
3743 }
3744
3745 // canonicalize (sub X, (vscale * C)) to (add X, (vscale * -C))
3746 if (N1.getOpcode() == ISD::VSCALE && N1.hasOneUse()) {
3747 const APInt &IntVal = N1.getConstantOperandAPInt(0);
3748 return DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getVScale(DL, VT, -IntVal));
3749 }
3750
3751 // canonicalize (sub X, step_vector(C)) to (add X, step_vector(-C))
3752 if (N1.getOpcode() == ISD::STEP_VECTOR && N1.hasOneUse()) {
3753 APInt NewStep = -N1.getConstantOperandAPInt(0);
3754 return DAG.getNode(ISD::ADD, DL, VT, N0,
3755 DAG.getStepVector(DL, VT, NewStep));
3756 }
3757
3758 // Prefer an add for more folding potential and possibly better codegen:
3759 // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1)
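// (lshr by width-1 yields 0 or 1 while ashr yields 0 or -1 for the same input,
//  so subtracting one is the same as adding the other.)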
3760 if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) {
3761 SDValue ShAmt = N1.getOperand(1);
3762 ConstantSDNode *ShAmtC = isConstOrConstSplat(ShAmt);
3763 if (ShAmtC &&
3764 ShAmtC->getAPIntValue() == (N1.getScalarValueSizeInBits() - 1)) {
3765 SDValue SRA = DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0), ShAmt);
3766 return DAG.getNode(ISD::ADD, DL, VT, N0, SRA);
3767 }
3768 }
3769
3770 // As with the previous fold, prefer add for more folding potential.
3771 // Subtracting SMIN/0 is the same as adding SMIN/0:
3772 // N0 - (X << BW-1) --> N0 + (X << BW-1)
3773 if (N1.getOpcode() == ISD::SHL) {
3774 ConstantSDNode *ShlC = isConstOrConstSplat(N1.getOperand(1));
3775 if (ShlC && ShlC->getAPIntValue() == VT.getScalarSizeInBits() - 1)
3776 return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
3777 }
3778
3779 // (sub (subcarry X, 0, Carry), Y) -> (subcarry X, Y, Carry)
3780 if (N0.getOpcode() == ISD::SUBCARRY && isNullConstant(N0.getOperand(1)) &&
3781 N0.getResNo() == 0 && N0.hasOneUse())
3782 return DAG.getNode(ISD::SUBCARRY, DL, N0->getVTList(),
3783 N0.getOperand(0), N1, N0.getOperand(2));
3784
3785 if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, VT)) {
3786 // (sub Carry, X) -> (addcarry (sub 0, X), 0, Carry)
3787 if (SDValue Carry = getAsCarry(TLI, N0)) {
3788 SDValue X = N1;
3789 SDValue Zero = DAG.getConstant(0, DL, VT);
3790 SDValue NegX = DAG.getNode(ISD::SUB, DL, VT, Zero, X);
3791 return DAG.getNode(ISD::ADDCARRY, DL,
3792 DAG.getVTList(VT, Carry.getValueType()), NegX, Zero,
3793 Carry);
3794 }
3795 }
3796
3797 // If there's no chance of borrowing from adjacent bits, then sub is xor:
3798 // sub C0, X --> xor X, C0
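// (For example, with i8: 0x55 - X == 0x55 ^ X whenever X can only have bits set
//  that are also set in 0x55, because no bit position ever needs to borrow.)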
3799 if (ConstantSDNode *C0 = isConstOrConstSplat(N0)) {
3800 if (!C0->isOpaque()) {
3801 const APInt &C0Val = C0->getAPIntValue();
3802 const APInt &MaybeOnes = ~DAG.computeKnownBits(N1).Zero;
3803 if ((C0Val - MaybeOnes) == (C0Val ^ MaybeOnes))
3804 return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
3805 }
3806 }
3807
3808 // max(a,b) - min(a,b) --> abd(a,b)
3809 auto MatchSubMaxMin = [&](unsigned Max, unsigned Min, unsigned Abd) {
3810 if (N0.getOpcode() != Max || N1.getOpcode() != Min)
3811 return SDValue();
3812 if ((N0.getOperand(0) != N1.getOperand(0) ||
3813 N0.getOperand(1) != N1.getOperand(1)) &&
3814 (N0.getOperand(0) != N1.getOperand(1) ||
3815 N0.getOperand(1) != N1.getOperand(0)))
3816 return SDValue();
3817 if (!TLI.isOperationLegalOrCustom(Abd, VT))
3818 return SDValue();
3819 return DAG.getNode(Abd, DL, VT, N0.getOperand(0), N0.getOperand(1));
3820 };
3821 if (SDValue R = MatchSubMaxMin(ISD::SMAX, ISD::SMIN, ISD::ABDS))
3822 return R;
3823 if (SDValue R = MatchSubMaxMin(ISD::UMAX, ISD::UMIN, ISD::ABDU))
3824 return R;
3825
3826 return SDValue();
3827 }
3828
3829 SDValue DAGCombiner::visitSUBSAT(SDNode *N) {
3830 SDValue N0 = N->getOperand(0);
3831 SDValue N1 = N->getOperand(1);
3832 EVT VT = N0.getValueType();
3833 SDLoc DL(N);
3834
3835 // fold (sub_sat x, undef) -> 0
3836 if (N0.isUndef() || N1.isUndef())
3837 return DAG.getConstant(0, DL, VT);
3838
3839 // fold (sub_sat x, x) -> 0
3840 if (N0 == N1)
3841 return DAG.getConstant(0, DL, VT);
3842
3843 // fold (sub_sat c1, c2) -> c3
3844 if (SDValue C = DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, {N0, N1}))
3845 return C;
3846
3847 // fold vector ops
3848 if (VT.isVector()) {
3849 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
3850 return FoldedVOp;
3851
3852 // fold (sub_sat x, 0) -> x, vector edition
3853 if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
3854 return N0;
3855 }
3856
3857 // fold (sub_sat x, 0) -> x
3858 if (isNullConstant(N1))
3859 return N0;
3860
3861 return SDValue();
3862 }
3863
3864 SDValue DAGCombiner::visitSUBC(SDNode *N) {
3865 SDValue N0 = N->getOperand(0);
3866 SDValue N1 = N->getOperand(1);
3867 EVT VT = N0.getValueType();
3868 SDLoc DL(N);
3869
3870 // If the flag result is dead, turn this into an SUB.
3871 if (!N->hasAnyUseOfValue(1))
3872 return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
3873 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3874
3875 // fold (subc x, x) -> 0 + no borrow
3876 if (N0 == N1)
3877 return CombineTo(N, DAG.getConstant(0, DL, VT),
3878 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3879
3880 // fold (subc x, 0) -> x + no borrow
3881 if (isNullConstant(N1))
3882 return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3883
3884 // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) + no borrow
3885 if (isAllOnesConstant(N0))
3886 return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
3887 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3888
3889 return SDValue();
3890 }
3891
3892 SDValue DAGCombiner::visitSUBO(SDNode *N) {
3893 SDValue N0 = N->getOperand(0);
3894 SDValue N1 = N->getOperand(1);
3895 EVT VT = N0.getValueType();
3896 bool IsSigned = (ISD::SSUBO == N->getOpcode());
3897
3898 EVT CarryVT = N->getValueType(1);
3899 SDLoc DL(N);
3900
3901 // If the flag result is dead, turn this into an SUB.
3902 if (!N->hasAnyUseOfValue(1))
3903 return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
3904 DAG.getUNDEF(CarryVT));
3905
3906 // fold (subo x, x) -> 0 + no borrow
3907 if (N0 == N1)
3908 return CombineTo(N, DAG.getConstant(0, DL, VT),
3909 DAG.getConstant(0, DL, CarryVT));
3910
3911 ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
3912
3913 // fold (subo x, c) -> (addo x, -c)
3914 if (IsSigned && N1C && !N1C->getAPIntValue().isMinSignedValue()) {
3915 return DAG.getNode(ISD::SADDO, DL, N->getVTList(), N0,
3916 DAG.getConstant(-N1C->getAPIntValue(), DL, VT));
3917 }
3918
3919 // fold (subo x, 0) -> x + no borrow
3920 if (isNullOrNullSplat(N1))
3921 return CombineTo(N, N0, DAG.getConstant(0, DL, CarryVT));
3922
3923 // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow
3924 if (!IsSigned && isAllOnesOrAllOnesSplat(N0))
3925 return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
3926 DAG.getConstant(0, DL, CarryVT));
3927
3928 return SDValue();
3929 }
3930
3931 SDValue DAGCombiner::visitSUBE(SDNode *N) {
3932 SDValue N0 = N->getOperand(0);
3933 SDValue N1 = N->getOperand(1);
3934 SDValue CarryIn = N->getOperand(2);
3935
3936 // fold (sube x, y, false) -> (subc x, y)
3937 if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
3938 return DAG.getNode(ISD::SUBC, SDLoc(N), N->getVTList(), N0, N1);
3939
3940 return SDValue();
3941 }
3942
3943 SDValue DAGCombiner::visitSUBCARRY(SDNode *N) {
3944 SDValue N0 = N->getOperand(0);
3945 SDValue N1 = N->getOperand(1);
3946 SDValue CarryIn = N->getOperand(2);
3947
3948 // fold (subcarry x, y, false) -> (usubo x, y)
3949 if (isNullConstant(CarryIn)) {
3950 if (!LegalOperations ||
3951 TLI.isOperationLegalOrCustom(ISD::USUBO, N->getValueType(0)))
3952 return DAG.getNode(ISD::USUBO, SDLoc(N), N->getVTList(), N0, N1);
3953 }
3954
3955 return SDValue();
3956 }
3957
3958 SDValue DAGCombiner::visitSSUBO_CARRY(SDNode *N) {
3959 SDValue N0 = N->getOperand(0);
3960 SDValue N1 = N->getOperand(1);
3961 SDValue CarryIn = N->getOperand(2);
3962
3963 // fold (ssubo_carry x, y, false) -> (ssubo x, y)
3964 if (isNullConstant(CarryIn)) {
3965 if (!LegalOperations ||
3966 TLI.isOperationLegalOrCustom(ISD::SSUBO, N->getValueType(0)))
3967 return DAG.getNode(ISD::SSUBO, SDLoc(N), N->getVTList(), N0, N1);
3968 }
3969
3970 return SDValue();
3971 }
3972
3973 // Notice that "mulfix" can be any of SMULFIX, SMULFIXSAT, UMULFIX and
3974 // UMULFIXSAT here.
3975 SDValue DAGCombiner::visitMULFIX(SDNode *N) {
3976 SDValue N0 = N->getOperand(0);
3977 SDValue N1 = N->getOperand(1);
3978 SDValue Scale = N->getOperand(2);
3979 EVT VT = N0.getValueType();
3980
3981 // fold (mulfix x, undef, scale) -> 0
3982 if (N0.isUndef() || N1.isUndef())
3983 return DAG.getConstant(0, SDLoc(N), VT);
3984
3985 // Canonicalize constant to RHS (vector doesn't have to splat)
3986 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
3987 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
3988 return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0, Scale);
3989
3990 // fold (mulfix x, 0, scale) -> 0
3991 if (isNullConstant(N1))
3992 return DAG.getConstant(0, SDLoc(N), VT);
3993
3994 return SDValue();
3995 }
3996
3997 SDValue DAGCombiner::visitMUL(SDNode *N) {
3998 SDValue N0 = N->getOperand(0);
3999 SDValue N1 = N->getOperand(1);
4000 EVT VT = N0.getValueType();
4001 SDLoc DL(N);
4002
4003 // fold (mul x, undef) -> 0
4004 if (N0.isUndef() || N1.isUndef())
4005 return DAG.getConstant(0, DL, VT);
4006
4007 // fold (mul c1, c2) -> c1*c2
4008 if (SDValue C = DAG.FoldConstantArithmetic(ISD::MUL, DL, VT, {N0, N1}))
4009 return C;
4010
4011 // canonicalize constant to RHS (vector doesn't have to splat)
4012 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4013 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4014 return DAG.getNode(ISD::MUL, DL, VT, N1, N0);
4015
4016 bool N1IsConst = false;
4017 bool N1IsOpaqueConst = false;
4018 APInt ConstValue1;
4019
4020 // fold vector ops
4021 if (VT.isVector()) {
4022 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4023 return FoldedVOp;
4024
4025 N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
4026 assert((!N1IsConst ||
4027 ConstValue1.getBitWidth() == VT.getScalarSizeInBits()) &&
4028 "Splat APInt should be element width");
4029 } else {
4030 N1IsConst = isa<ConstantSDNode>(N1);
4031 if (N1IsConst) {
4032 ConstValue1 = cast<ConstantSDNode>(N1)->getAPIntValue();
4033 N1IsOpaqueConst = cast<ConstantSDNode>(N1)->isOpaque();
4034 }
4035 }
4036
4037 // fold (mul x, 0) -> 0
4038 if (N1IsConst && ConstValue1.isZero())
4039 return N1;
4040
4041 // fold (mul x, 1) -> x
4042 if (N1IsConst && ConstValue1.isOne())
4043 return N0;
4044
4045 if (SDValue NewSel = foldBinOpIntoSelect(N))
4046 return NewSel;
4047
4048 // fold (mul x, -1) -> 0-x
4049 if (N1IsConst && ConstValue1.isAllOnes())
4050 return DAG.getNegative(N0, DL, VT);
4051
4052 // fold (mul x, (1 << c)) -> x << c
4053 if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
4054 DAG.isKnownToBeAPowerOfTwo(N1) &&
4055 (!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
4056 SDValue LogBase2 = BuildLogBase2(N1, DL);
4057 EVT ShiftVT = getShiftAmountTy(N0.getValueType());
4058 SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
4059 return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc);
4060 }
4061
4062 // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
4063 if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isNegatedPowerOf2()) {
4064 unsigned Log2Val = (-ConstValue1).logBase2();
4065 // FIXME: If the input is something that is easily negated (e.g. a
4066 // single-use add), we should put the negate there.
4067 return DAG.getNode(ISD::SUB, DL, VT,
4068 DAG.getConstant(0, DL, VT),
4069 DAG.getNode(ISD::SHL, DL, VT, N0,
4070 DAG.getConstant(Log2Val, DL,
4071 getShiftAmountTy(N0.getValueType()))));
4072 }
4073
4074 // Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
4075 // hi result is in use in case we hit this mid-legalization.
4076 for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
4077 if (!LegalOperations || TLI.isOperationLegalOrCustom(LoHiOpc, VT)) {
4078 SDVTList LoHiVT = DAG.getVTList(VT, VT);
4079 // TODO: Can we match commutable operands with getNodeIfExists?
4080 if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N0, N1}))
4081 if (LoHi->hasAnyUseOfValue(1))
4082 return SDValue(LoHi, 0);
4083 if (SDNode *LoHi = DAG.getNodeIfExists(LoHiOpc, LoHiVT, {N1, N0}))
4084 if (LoHi->hasAnyUseOfValue(1))
4085 return SDValue(LoHi, 0);
4086 }
4087 }
4088
4089 // Try to transform:
4090 // (1) multiply-by-(power-of-2 +/- 1) into shift and add/sub.
4091 // mul x, (2^N + 1) --> add (shl x, N), x
4092 // mul x, (2^N - 1) --> sub (shl x, N), x
4093 // Examples: x * 33 --> (x << 5) + x
4094 // x * 15 --> (x << 4) - x
4095 // x * -33 --> -((x << 5) + x)
4096 // x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4)
4097 // (2) multiply-by-(power-of-2 +/- power-of-2) into shifts and add/sub.
4098 // mul x, (2^N + 2^M) --> (add (shl x, N), (shl x, M))
4099 // mul x, (2^N - 2^M) --> (sub (shl x, N), (shl x, M))
4100 // Examples: x * 0x8800 --> (x << 15) + (x << 11)
4101 // x * 0xf800 --> (x << 16) - (x << 11)
4102 // x * -0x8800 --> -((x << 15) + (x << 11))
4103 // x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
4104 if (N1IsConst && TLI.decomposeMulByConstant(*DAG.getContext(), VT, N1)) {
4105 // TODO: We could handle more general decomposition of any constant by
4106 // having the target set a limit on number of ops and making a
4107 // callback to determine that sequence (similar to sqrt expansion).
4108 unsigned MathOp = ISD::DELETED_NODE;
4109 APInt MulC = ConstValue1.abs();
4110 // The constant `2` should be treated as (2^0 + 1).
4111 unsigned TZeros = MulC == 2 ? 0 : MulC.countTrailingZeros();
4112 MulC.lshrInPlace(TZeros);
4113 if ((MulC - 1).isPowerOf2())
4114 MathOp = ISD::ADD;
4115 else if ((MulC + 1).isPowerOf2())
4116 MathOp = ISD::SUB;
4117
4118 if (MathOp != ISD::DELETED_NODE) {
4119 unsigned ShAmt =
4120 MathOp == ISD::ADD ? (MulC - 1).logBase2() : (MulC + 1).logBase2();
4121 ShAmt += TZeros;
4122 assert(ShAmt < VT.getScalarSizeInBits() &&
4123 "multiply-by-constant generated out of bounds shift");
4124 SDValue Shl =
4125 DAG.getNode(ISD::SHL, DL, VT, N0, DAG.getConstant(ShAmt, DL, VT));
4126 SDValue R =
4127 TZeros ? DAG.getNode(MathOp, DL, VT, Shl,
4128 DAG.getNode(ISD::SHL, DL, VT, N0,
4129 DAG.getConstant(TZeros, DL, VT)))
4130 : DAG.getNode(MathOp, DL, VT, Shl, N0);
4131 if (ConstValue1.isNegative())
4132 R = DAG.getNegative(R, DL, VT);
4133 return R;
4134 }
4135 }
4136
4137 // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
4138 if (N0.getOpcode() == ISD::SHL) {
4139 SDValue N01 = N0.getOperand(1);
4140 if (SDValue C3 = DAG.FoldConstantArithmetic(ISD::SHL, DL, VT, {N1, N01}))
4141 return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), C3);
4142 }
4143
4144 // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one
4145 // use.
4146 {
4147 SDValue Sh, Y;
4148
4149 // Check for both (mul (shl X, C), Y) and (mul Y, (shl X, C)).
4150 if (N0.getOpcode() == ISD::SHL &&
4151 isConstantOrConstantVector(N0.getOperand(1)) && N0->hasOneUse()) {
4152 Sh = N0; Y = N1;
4153 } else if (N1.getOpcode() == ISD::SHL &&
4154 isConstantOrConstantVector(N1.getOperand(1)) &&
4155 N1->hasOneUse()) {
4156 Sh = N1; Y = N0;
4157 }
4158
4159 if (Sh.getNode()) {
4160 SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, Sh.getOperand(0), Y);
4161 return DAG.getNode(ISD::SHL, DL, VT, Mul, Sh.getOperand(1));
4162 }
4163 }
4164
4165 // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
4166 if (DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
4167 N0.getOpcode() == ISD::ADD &&
4168 DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
4169 isMulAddWithConstProfitable(N, N0, N1))
4170 return DAG.getNode(
4171 ISD::ADD, DL, VT,
4172 DAG.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
4173 DAG.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));
4174
4175 // Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
4176 ConstantSDNode *NC1 = isConstOrConstSplat(N1);
4177 if (N0.getOpcode() == ISD::VSCALE && NC1) {
4178 const APInt &C0 = N0.getConstantOperandAPInt(0);
4179 const APInt &C1 = NC1->getAPIntValue();
4180 return DAG.getVScale(DL, VT, C0 * C1);
4181 }
4182
4183 // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
4184 APInt MulVal;
4185 if (N0.getOpcode() == ISD::STEP_VECTOR &&
4186 ISD::isConstantSplatVector(N1.getNode(), MulVal)) {
4187 const APInt &C0 = N0.getConstantOperandAPInt(0);
4188 APInt NewStep = C0 * MulVal;
4189 return DAG.getStepVector(DL, VT, NewStep);
4190 }
4191
4192 // Fold (mul x, 0/undef) -> 0 and
4193 // (mul x, 1) -> x
4194 // into and(x, mask).
4195 // We can replace vectors with '0' and '1' factors with a clearing mask.
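// (For example, mul X, <1, 0, 1, 0> becomes and X, <-1, 0, -1, 0>.)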
4196 if (VT.isFixedLengthVector()) {
4197 unsigned NumElts = VT.getVectorNumElements();
4198 SmallBitVector ClearMask;
4199 ClearMask.reserve(NumElts);
4200 auto IsClearMask = [&ClearMask](ConstantSDNode *V) {
4201 if (!V || V->isZero()) {
4202 ClearMask.push_back(true);
4203 return true;
4204 }
4205 ClearMask.push_back(false);
4206 return V->isOne();
4207 };
4208 if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::AND, VT)) &&
4209 ISD::matchUnaryPredicate(N1, IsClearMask, /*AllowUndefs*/ true)) {
4210 assert(N1.getOpcode() == ISD::BUILD_VECTOR && "Unknown constant vector");
4211 EVT LegalSVT = N1.getOperand(0).getValueType();
4212 SDValue Zero = DAG.getConstant(0, DL, LegalSVT);
4213 SDValue AllOnes = DAG.getAllOnesConstant(DL, LegalSVT);
4214 SmallVector<SDValue, 16> Mask(NumElts, AllOnes);
4215 for (unsigned I = 0; I != NumElts; ++I)
4216 if (ClearMask[I])
4217 Mask[I] = Zero;
4218 return DAG.getNode(ISD::AND, DL, VT, N0, DAG.getBuildVector(VT, DL, Mask));
4219 }
4220 }
4221
4222 // reassociate mul
4223 if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
4224 return RMUL;
4225
4226 // Simplify the operands using demanded-bits information.
4227 if (SimplifyDemandedBits(SDValue(N, 0)))
4228 return SDValue(N, 0);
4229
4230 return SDValue();
4231 }
4232
4233 /// Return true if divmod libcall is available.
4234 static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned,
4235 const TargetLowering &TLI) {
4236 RTLIB::Libcall LC;
4237 EVT NodeType = Node->getValueType(0);
4238 if (!NodeType.isSimple())
4239 return false;
4240 switch (NodeType.getSimpleVT().SimpleTy) {
4241 default: return false; // No libcall for vector types.
4242 case MVT::i8: LC= isSigned ? RTLIB::SDIVREM_I8 : RTLIB::UDIVREM_I8; break;
4243 case MVT::i16: LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
4244 case MVT::i32: LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break;
4245 case MVT::i64: LC= isSigned ? RTLIB::SDIVREM_I64 : RTLIB::UDIVREM_I64; break;
4246 case MVT::i128: LC= isSigned ? RTLIB::SDIVREM_I128:RTLIB::UDIVREM_I128; break;
4247 }
4248
4249 return TLI.getLibcallName(LC) != nullptr;
4250 }
4251
4252 /// Issue divrem if both quotient and remainder are needed.
4253 SDValue DAGCombiner::useDivRem(SDNode *Node) {
4254 if (Node->use_empty())
4255 return SDValue(); // This is a dead node, leave it alone.
4256
4257 unsigned Opcode = Node->getOpcode();
4258 bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
4259 unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
4260
4261 // DivMod libcalls can still work on non-legal types.
4262 EVT VT = Node->getValueType(0);
4263 if (VT.isVector() || !VT.isInteger())
4264 return SDValue();
4265
4266 if (!TLI.isTypeLegal(VT) && !TLI.isOperationCustom(DivRemOpc, VT))
4267 return SDValue();
4268
4269 // If DIVREM is going to get expanded into a libcall,
4270 // but there is no libcall available, then don't combine.
4271 if (!TLI.isOperationLegalOrCustom(DivRemOpc, VT) &&
4272 !isDivRemLibcallAvailable(Node, isSigned, TLI))
4273 return SDValue();
4274
4275 // If div is legal, it's better to do the normal expansion
4276 unsigned OtherOpcode = 0;
4277 if ((Opcode == ISD::SDIV) || (Opcode == ISD::UDIV)) {
4278 OtherOpcode = isSigned ? ISD::SREM : ISD::UREM;
4279 if (TLI.isOperationLegalOrCustom(Opcode, VT))
4280 return SDValue();
4281 } else {
4282 OtherOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
4283 if (TLI.isOperationLegalOrCustom(OtherOpcode, VT))
4284 return SDValue();
4285 }
4286
4287 SDValue Op0 = Node->getOperand(0);
4288 SDValue Op1 = Node->getOperand(1);
4289 SDValue combined;
4290 for (SDNode *User : Op0->uses()) {
4291 if (User == Node || User->getOpcode() == ISD::DELETED_NODE ||
4292 User->use_empty())
4293 continue;
4294 // Convert the other matching node(s), too;
4295 // otherwise, the DIVREM may get target-legalized into something
4296 // target-specific that we won't be able to recognize.
4297 unsigned UserOpc = User->getOpcode();
4298 if ((UserOpc == Opcode || UserOpc == OtherOpcode || UserOpc == DivRemOpc) &&
4299 User->getOperand(0) == Op0 &&
4300 User->getOperand(1) == Op1) {
4301 if (!combined) {
4302 if (UserOpc == OtherOpcode) {
4303 SDVTList VTs = DAG.getVTList(VT, VT);
4304 combined = DAG.getNode(DivRemOpc, SDLoc(Node), VTs, Op0, Op1);
4305 } else if (UserOpc == DivRemOpc) {
4306 combined = SDValue(User, 0);
4307 } else {
4308 assert(UserOpc == Opcode);
4309 continue;
4310 }
4311 }
4312 if (UserOpc == ISD::SDIV || UserOpc == ISD::UDIV)
4313 CombineTo(User, combined);
4314 else if (UserOpc == ISD::SREM || UserOpc == ISD::UREM)
4315 CombineTo(User, combined.getValue(1));
4316 }
4317 }
4318 return combined;
4319 }
4320
4321 static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) {
4322 SDValue N0 = N->getOperand(0);
4323 SDValue N1 = N->getOperand(1);
4324 EVT VT = N->getValueType(0);
4325 SDLoc DL(N);
4326
4327 unsigned Opc = N->getOpcode();
4328 bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc);
4329 ConstantSDNode *N1C = isConstOrConstSplat(N1);
4330
4331 // X / undef -> undef
4332 // X % undef -> undef
4333 // X / 0 -> undef
4334 // X % 0 -> undef
4335 // NOTE: This includes vectors where any divisor element is zero/undef.
4336 if (DAG.isUndef(Opc, {N0, N1}))
4337 return DAG.getUNDEF(VT);
4338
4339 // undef / X -> 0
4340 // undef % X -> 0
4341 if (N0.isUndef())
4342 return DAG.getConstant(0, DL, VT);
4343
4344 // 0 / X -> 0
4345 // 0 % X -> 0
4346 ConstantSDNode *N0C = isConstOrConstSplat(N0);
4347 if (N0C && N0C->isZero())
4348 return N0;
4349
4350 // X / X -> 1
4351 // X % X -> 0
4352 if (N0 == N1)
4353 return DAG.getConstant(IsDiv ? 1 : 0, DL, VT);
4354
4355 // X / 1 -> X
4356 // X % 1 -> 0
4357 // If this is a boolean op (single-bit element type), we can't have
4358 // division-by-zero or remainder-by-zero, so assume the divisor is 1.
4359 // TODO: Similarly, if we're zero-extending a boolean divisor, then assume
4360 // it's a 1.
4361 if ((N1C && N1C->isOne()) || (VT.getScalarType() == MVT::i1))
4362 return IsDiv ? N0 : DAG.getConstant(0, DL, VT);
4363
4364 return SDValue();
4365 }
4366
4367 SDValue DAGCombiner::visitSDIV(SDNode *N) {
4368 SDValue N0 = N->getOperand(0);
4369 SDValue N1 = N->getOperand(1);
4370 EVT VT = N->getValueType(0);
4371 EVT CCVT = getSetCCResultType(VT);
4372 SDLoc DL(N);
4373
4374 // fold (sdiv c1, c2) -> c1/c2
4375 if (SDValue C = DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, {N0, N1}))
4376 return C;
4377
4378 // fold vector ops
4379 if (VT.isVector())
4380 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4381 return FoldedVOp;
4382
4383 // fold (sdiv X, -1) -> 0-X
4384 ConstantSDNode *N1C = isConstOrConstSplat(N1);
4385 if (N1C && N1C->isAllOnes())
4386 return DAG.getNegative(N0, DL, VT);
4387
4388 // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
4389 if (N1C && N1C->getAPIntValue().isMinSignedValue())
4390 return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
4391 DAG.getConstant(1, DL, VT),
4392 DAG.getConstant(0, DL, VT));
4393
4394 if (SDValue V = simplifyDivRem(N, DAG))
4395 return V;
4396
4397 if (SDValue NewSel = foldBinOpIntoSelect(N))
4398 return NewSel;
4399
4400 // If we know the sign bits of both operands are zero, strength reduce to a
4401 // udiv instead. Handles (X&15) /s 4 -> X&15 >> 2
4402 if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
4403 return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1);
4404
4405 if (SDValue V = visitSDIVLike(N0, N1, N)) {
4406 // If the corresponding remainder node exists, update its users with
4407 // (Dividend - (Quotient * Divisor)).
4408 if (SDNode *RemNode = DAG.getNodeIfExists(ISD::SREM, N->getVTList(),
4409 { N0, N1 })) {
4410 SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
4411 SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
4412 AddToWorklist(Mul.getNode());
4413 AddToWorklist(Sub.getNode());
4414 CombineTo(RemNode, Sub);
4415 }
4416 return V;
4417 }
4418
4419 // sdiv, srem -> sdivrem
4420 // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
4421 // true. Otherwise, we break the simplification logic in visitREM().
4422 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4423 if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
4424 if (SDValue DivRem = useDivRem(N))
4425 return DivRem;
4426
4427 return SDValue();
4428 }
4429
4430 static bool isDivisorPowerOfTwo(SDValue Divisor) {
4431 // Helper for determining whether a value is a power-2 constant scalar or a
4432 // vector of such elements.
4433 auto IsPowerOfTwo = [](ConstantSDNode *C) {
4434 if (C->isZero() || C->isOpaque())
4435 return false;
4436 if (C->getAPIntValue().isPowerOf2())
4437 return true;
4438 if (C->getAPIntValue().isNegatedPowerOf2())
4439 return true;
4440 return false;
4441 };
4442
4443 return ISD::matchUnaryPredicate(Divisor, IsPowerOfTwo);
4444 }
4445
4446 SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
4447 SDLoc DL(N);
4448 EVT VT = N->getValueType(0);
4449 EVT CCVT = getSetCCResultType(VT);
4450 unsigned BitWidth = VT.getScalarSizeInBits();
4451
4452 // fold (sdiv X, pow2) -> simple ops after legalize
4453 // FIXME: We check for the exact bit here because the generic lowering gives
4454 // better results in that case. The target-specific lowering should learn how
4455 // to handle exact sdivs efficiently.
4456 if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1)) {
4457 // Target-specific implementation of sdiv x, pow2.
4458 if (SDValue Res = BuildSDIVPow2(N))
4459 return Res;
4460
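    // The sequence below computes (N0 + (N0 < 0 ? |N1| - 1 : 0)) >> log2(|N1|),
    // i.e. a signed division by a power of two that rounds toward zero. The
    // select below handles divisors of 1 and -1, and the final select negates
    // the result for negative divisors.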
4461 // Create constants that are functions of the shift amount value.
4462 EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
4463 SDValue Bits = DAG.getConstant(BitWidth, DL, ShiftAmtTy);
4464 SDValue C1 = DAG.getNode(ISD::CTTZ, DL, VT, N1);
4465 C1 = DAG.getZExtOrTrunc(C1, DL, ShiftAmtTy);
4466 SDValue Inexact = DAG.getNode(ISD::SUB, DL, ShiftAmtTy, Bits, C1);
4467 if (!isConstantOrConstantVector(Inexact))
4468 return SDValue();
4469
4470 // Splat the sign bit into the register
4471 SDValue Sign = DAG.getNode(ISD::SRA, DL, VT, N0,
4472 DAG.getConstant(BitWidth - 1, DL, ShiftAmtTy));
4473 AddToWorklist(Sign.getNode());
4474
4475     // Add (N0 < 0) ? |divisor| - 1 : 0;
4476 SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, Sign, Inexact);
4477 AddToWorklist(Srl.getNode());
4478 SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, Srl);
4479 AddToWorklist(Add.getNode());
4480 SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Add, C1);
4481 AddToWorklist(Sra.getNode());
4482
4483 // Special case: (sdiv X, 1) -> X
4484 // Special Case: (sdiv X, -1) -> 0-X
4485 SDValue One = DAG.getConstant(1, DL, VT);
4486 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
4487 SDValue IsOne = DAG.getSetCC(DL, CCVT, N1, One, ISD::SETEQ);
4488 SDValue IsAllOnes = DAG.getSetCC(DL, CCVT, N1, AllOnes, ISD::SETEQ);
4489 SDValue IsOneOrAllOnes = DAG.getNode(ISD::OR, DL, CCVT, IsOne, IsAllOnes);
4490 Sra = DAG.getSelect(DL, VT, IsOneOrAllOnes, N0, Sra);
4491
4492 // If dividing by a positive value, we're done. Otherwise, the result must
4493 // be negated.
4494 SDValue Zero = DAG.getConstant(0, DL, VT);
4495 SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, Zero, Sra);
4496
4497 // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
4498 SDValue IsNeg = DAG.getSetCC(DL, CCVT, N1, Zero, ISD::SETLT);
4499 SDValue Res = DAG.getSelect(DL, VT, IsNeg, Sub, Sra);
4500 return Res;
4501 }
4502
4503 // If integer divide is expensive and we satisfy the requirements, emit an
4504 // alternate sequence. Targets may check function attributes for size/speed
4505 // trade-offs.
4506 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4507 if (isConstantOrConstantVector(N1) &&
4508 !TLI.isIntDivCheap(N->getValueType(0), Attr))
4509 if (SDValue Op = BuildSDIV(N))
4510 return Op;
4511
4512 return SDValue();
4513 }
4514
4515 SDValue DAGCombiner::visitUDIV(SDNode *N) {
4516 SDValue N0 = N->getOperand(0);
4517 SDValue N1 = N->getOperand(1);
4518 EVT VT = N->getValueType(0);
4519 EVT CCVT = getSetCCResultType(VT);
4520 SDLoc DL(N);
4521
4522 // fold (udiv c1, c2) -> c1/c2
4523 if (SDValue C = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT, {N0, N1}))
4524 return C;
4525
4526 // fold vector ops
4527 if (VT.isVector())
4528 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4529 return FoldedVOp;
4530
4531 // fold (udiv X, -1) -> select(X == -1, 1, 0)
4532 ConstantSDNode *N1C = isConstOrConstSplat(N1);
4533 if (N1C && N1C->isAllOnes() && CCVT.isVector() == VT.isVector()) {
4534 return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
4535 DAG.getConstant(1, DL, VT),
4536 DAG.getConstant(0, DL, VT));
4537 }
4538
4539 if (SDValue V = simplifyDivRem(N, DAG))
4540 return V;
4541
4542 if (SDValue NewSel = foldBinOpIntoSelect(N))
4543 return NewSel;
4544
4545 if (SDValue V = visitUDIVLike(N0, N1, N)) {
4546 // If the corresponding remainder node exists, update its users with
4547     // (Dividend - (Quotient * Divisor)).
4548 if (SDNode *RemNode = DAG.getNodeIfExists(ISD::UREM, N->getVTList(),
4549 { N0, N1 })) {
4550 SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, V, N1);
4551 SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
4552 AddToWorklist(Mul.getNode());
4553 AddToWorklist(Sub.getNode());
4554 CombineTo(RemNode, Sub);
4555 }
4556 return V;
4557 }
4558
4559   // udiv, urem -> udivrem
4560 // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
4561 // true. Otherwise, we break the simplification logic in visitREM().
4562 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4563 if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
4564 if (SDValue DivRem = useDivRem(N))
4565 return DivRem;
4566
4567 return SDValue();
4568 }
4569
4570 SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
4571 SDLoc DL(N);
4572 EVT VT = N->getValueType(0);
4573
4574 // fold (udiv x, (1 << c)) -> x >>u c
4575 if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
4576 DAG.isKnownToBeAPowerOfTwo(N1)) {
4577 SDValue LogBase2 = BuildLogBase2(N1, DL);
4578 AddToWorklist(LogBase2.getNode());
4579
4580 EVT ShiftVT = getShiftAmountTy(N0.getValueType());
4581 SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
4582 AddToWorklist(Trunc.getNode());
4583 return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
4584 }
4585
4586 // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
4587 if (N1.getOpcode() == ISD::SHL) {
4588 SDValue N10 = N1.getOperand(0);
4589 if (isConstantOrConstantVector(N10, /*NoOpaques*/ true) &&
4590 DAG.isKnownToBeAPowerOfTwo(N10)) {
4591 SDValue LogBase2 = BuildLogBase2(N10, DL);
4592 AddToWorklist(LogBase2.getNode());
4593
4594 EVT ADDVT = N1.getOperand(1).getValueType();
4595 SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
4596 AddToWorklist(Trunc.getNode());
4597 SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc);
4598 AddToWorklist(Add.getNode());
4599 return DAG.getNode(ISD::SRL, DL, VT, N0, Add);
4600 }
4601 }
4602
4603 // fold (udiv x, c) -> alternate
4604 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4605 if (isConstantOrConstantVector(N1) &&
4606 !TLI.isIntDivCheap(N->getValueType(0), Attr))
4607 if (SDValue Op = BuildUDIV(N))
4608 return Op;
4609
4610 return SDValue();
4611 }
4612
4613 SDValue DAGCombiner::buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N) {
4614 if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1) &&
4615 !DAG.doesNodeExist(ISD::SDIV, N->getVTList(), {N0, N1})) {
4616 // Target-specific implementation of srem x, pow2.
4617 if (SDValue Res = BuildSREMPow2(N))
4618 return Res;
4619 }
4620 return SDValue();
4621 }
4622
4623 // handles ISD::SREM and ISD::UREM
4624 SDValue DAGCombiner::visitREM(SDNode *N) {
4625 unsigned Opcode = N->getOpcode();
4626 SDValue N0 = N->getOperand(0);
4627 SDValue N1 = N->getOperand(1);
4628 EVT VT = N->getValueType(0);
4629 EVT CCVT = getSetCCResultType(VT);
4630
4631 bool isSigned = (Opcode == ISD::SREM);
4632 SDLoc DL(N);
4633
4634 // fold (rem c1, c2) -> c1%c2
4635 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
4636 return C;
4637
4638 // fold (urem X, -1) -> select(FX == -1, 0, FX)
4639 // Freeze the numerator to avoid a miscompile with an undefined value.
4640 if (!isSigned && llvm::isAllOnesOrAllOnesSplat(N1, /*AllowUndefs*/ false) &&
4641 CCVT.isVector() == VT.isVector()) {
4642 SDValue F0 = DAG.getFreeze(N0);
4643 SDValue EqualsNeg1 = DAG.getSetCC(DL, CCVT, F0, N1, ISD::SETEQ);
4644 return DAG.getSelect(DL, VT, EqualsNeg1, DAG.getConstant(0, DL, VT), F0);
4645 }
4646
4647 if (SDValue V = simplifyDivRem(N, DAG))
4648 return V;
4649
4650 if (SDValue NewSel = foldBinOpIntoSelect(N))
4651 return NewSel;
4652
4653 if (isSigned) {
4654 // If we know the sign bits of both operands are zero, strength reduce to a
4655 // urem instead. Handles (X & 0x0FFFFFFF) %s 16 -> X&15
4656 if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
4657 return DAG.getNode(ISD::UREM, DL, VT, N0, N1);
4658 } else {
4659 if (DAG.isKnownToBeAPowerOfTwo(N1)) {
4660 // fold (urem x, pow2) -> (and x, pow2-1)
4661 SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
4662 SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
4663 AddToWorklist(Add.getNode());
4664 return DAG.getNode(ISD::AND, DL, VT, N0, Add);
4665 }
4666 // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
4667 // fold (urem x, (lshr pow2, y)) -> (and x, (add (lshr pow2, y), -1))
4668 // TODO: We should sink the following into isKnownToBePowerOfTwo
4669 // using a OrZero parameter analogous to our handling in ValueTracking.
4670 if ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) &&
4671 DAG.isKnownToBeAPowerOfTwo(N1.getOperand(0))) {
4672 SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
4673 SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, NegOne);
4674 AddToWorklist(Add.getNode());
4675 return DAG.getNode(ISD::AND, DL, VT, N0, Add);
4676 }
4677 }
4678
4679 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4680
4681 // If X/C can be simplified by the division-by-constant logic, lower
4682 // X%C to the equivalent of X-X/C*C.
4683 // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
4684 // speculative DIV must not cause a DIVREM conversion. We guard against this
4685 // by skipping the simplification if isIntDivCheap(). When div is not cheap,
4686 // combine will not return a DIVREM. Regardless, checking cheapness here
4687 // makes sense since the simplification results in fatter code.
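  // For example, (urem X, 7) is rewritten as X - (X /u 7) * 7, where X /u 7
  // comes from the cheaper expansion produced by visitUDIVLike (e.g. a
  // multiply-by-magic-constant sequence).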
4688 if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) {
4689 if (isSigned) {
4690 // check if we can build faster implementation for srem
4691 if (SDValue OptimizedRem = buildOptimizedSREM(N0, N1, N))
4692 return OptimizedRem;
4693 }
4694
4695 SDValue OptimizedDiv =
4696 isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
4697 if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) {
4698 // If the equivalent Div node also exists, update its users.
4699 unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
4700 if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(),
4701 { N0, N1 }))
4702 CombineTo(DivNode, OptimizedDiv);
4703 SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDiv, N1);
4704 SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
4705 AddToWorklist(OptimizedDiv.getNode());
4706 AddToWorklist(Mul.getNode());
4707 return Sub;
4708 }
4709 }
4710
4711   // sdiv, srem -> sdivrem / udiv, urem -> udivrem
4712 if (SDValue DivRem = useDivRem(N))
4713 return DivRem.getValue(1);
4714
4715 return SDValue();
4716 }
4717
4718 SDValue DAGCombiner::visitMULHS(SDNode *N) {
4719 SDValue N0 = N->getOperand(0);
4720 SDValue N1 = N->getOperand(1);
4721 EVT VT = N->getValueType(0);
4722 SDLoc DL(N);
4723
4724 // fold (mulhs c1, c2)
4725 if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHS, DL, VT, {N0, N1}))
4726 return C;
4727
4728 // canonicalize constant to RHS.
4729 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4730 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4731 return DAG.getNode(ISD::MULHS, DL, N->getVTList(), N1, N0);
4732
4733 if (VT.isVector()) {
4734 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4735 return FoldedVOp;
4736
4737 // fold (mulhs x, 0) -> 0
4738 // do not return N1, because undef node may exist.
4739 if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
4740 return DAG.getConstant(0, DL, VT);
4741 }
4742
4743 // fold (mulhs x, 0) -> 0
4744 if (isNullConstant(N1))
4745 return N1;
4746
4747 // fold (mulhs x, 1) -> (sra x, size(x)-1)
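  // The high half of the sign-extended product x * 1 is just the sign bit of
  // x replicated, i.e. an arithmetic shift right by size(x)-1.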
4748 if (isOneConstant(N1))
4749 return DAG.getNode(ISD::SRA, DL, N0.getValueType(), N0,
4750 DAG.getConstant(N0.getScalarValueSizeInBits() - 1, DL,
4751 getShiftAmountTy(N0.getValueType())));
4752
4753 // fold (mulhs x, undef) -> 0
4754 if (N0.isUndef() || N1.isUndef())
4755 return DAG.getConstant(0, DL, VT);
4756
4757 // If the type twice as wide is legal, transform the mulhs to a wider multiply
4758 // plus a shift.
4759 if (!TLI.isOperationLegalOrCustom(ISD::MULHS, VT) && VT.isSimple() &&
4760 !VT.isVector()) {
4761 MVT Simple = VT.getSimpleVT();
4762 unsigned SimpleSize = Simple.getSizeInBits();
4763 EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4764 if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4765 N0 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
4766 N1 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
4767 N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
4768 N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
4769 DAG.getConstant(SimpleSize, DL,
4770 getShiftAmountTy(N1.getValueType())));
4771 return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
4772 }
4773 }
4774
4775 return SDValue();
4776 }
4777
4778 SDValue DAGCombiner::visitMULHU(SDNode *N) {
4779 SDValue N0 = N->getOperand(0);
4780 SDValue N1 = N->getOperand(1);
4781 EVT VT = N->getValueType(0);
4782 SDLoc DL(N);
4783
4784 // fold (mulhu c1, c2)
4785 if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHU, DL, VT, {N0, N1}))
4786 return C;
4787
4788 // canonicalize constant to RHS.
4789 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4790 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4791 return DAG.getNode(ISD::MULHU, DL, N->getVTList(), N1, N0);
4792
4793 if (VT.isVector()) {
4794 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4795 return FoldedVOp;
4796
4797 // fold (mulhu x, 0) -> 0
4798 // do not return N1, because undef node may exist.
4799 if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
4800 return DAG.getConstant(0, DL, VT);
4801 }
4802
4803 // fold (mulhu x, 0) -> 0
4804 if (isNullConstant(N1))
4805 return N1;
4806
4807 // fold (mulhu x, 1) -> 0
4808 if (isOneConstant(N1))
4809 return DAG.getConstant(0, DL, N0.getValueType());
4810
4811 // fold (mulhu x, undef) -> 0
4812 if (N0.isUndef() || N1.isUndef())
4813 return DAG.getConstant(0, DL, VT);
4814
4815 // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
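  // mulhu is the high half of the zero-extended double-width product, and
  // (x * 2^c) >> bitwidth == x >> (bitwidth - c).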
4816 if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
4817 DAG.isKnownToBeAPowerOfTwo(N1) && hasOperation(ISD::SRL, VT)) {
4818 unsigned NumEltBits = VT.getScalarSizeInBits();
4819 SDValue LogBase2 = BuildLogBase2(N1, DL);
4820 SDValue SRLAmt = DAG.getNode(
4821 ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2);
4822 EVT ShiftVT = getShiftAmountTy(N0.getValueType());
4823 SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT);
4824 return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc);
4825 }
4826
4827 // If the type twice as wide is legal, transform the mulhu to a wider multiply
4828 // plus a shift.
4829 if (!TLI.isOperationLegalOrCustom(ISD::MULHU, VT) && VT.isSimple() &&
4830 !VT.isVector()) {
4831 MVT Simple = VT.getSimpleVT();
4832 unsigned SimpleSize = Simple.getSizeInBits();
4833 EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4834 if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4835 N0 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
4836 N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
4837 N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
4838 N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
4839 DAG.getConstant(SimpleSize, DL,
4840 getShiftAmountTy(N1.getValueType())));
4841 return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
4842 }
4843 }
4844
4845 // Simplify the operands using demanded-bits information.
4846 // We don't have demanded bits support for MULHU so this just enables constant
4847 // folding based on known bits.
4848 if (SimplifyDemandedBits(SDValue(N, 0)))
4849 return SDValue(N, 0);
4850
4851 return SDValue();
4852 }
4853
4854 SDValue DAGCombiner::visitAVG(SDNode *N) {
4855 unsigned Opcode = N->getOpcode();
4856 SDValue N0 = N->getOperand(0);
4857 SDValue N1 = N->getOperand(1);
4858 EVT VT = N->getValueType(0);
4859 SDLoc DL(N);
4860
4861 // fold (avg c1, c2)
4862 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
4863 return C;
4864
4865 // canonicalize constant to RHS.
4866 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4867 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4868 return DAG.getNode(Opcode, DL, N->getVTList(), N1, N0);
4869
4870 if (VT.isVector()) {
4871 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4872 return FoldedVOp;
4873
4874 // fold (avgfloor x, 0) -> x >> 1
4875 if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) {
4876 if (Opcode == ISD::AVGFLOORS)
4877 return DAG.getNode(ISD::SRA, DL, VT, N0, DAG.getConstant(1, DL, VT));
4878 if (Opcode == ISD::AVGFLOORU)
4879 return DAG.getNode(ISD::SRL, DL, VT, N0, DAG.getConstant(1, DL, VT));
4880 }
4881 }
4882
4883 // fold (avg x, undef) -> x
4884 if (N0.isUndef())
4885 return N1;
4886 if (N1.isUndef())
4887 return N0;
4888
4889   // TODO: If we use avg for scalars anywhere, we can add (avgfloor x, 0) -> x >> 1
4890
4891 return SDValue();
4892 }
4893
4894 /// Perform optimizations common to nodes that compute two values. LoOp and HiOp
4895 /// give the opcodes for the two computations that are being performed. Return
4896 /// the combined result if a simplification was made, or an empty SDValue otherwise.
4897 SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
4898 unsigned HiOp) {
4899 // If the high half is not needed, just compute the low half.
4900 bool HiExists = N->hasAnyUseOfValue(1);
4901 if (!HiExists && (!LegalOperations ||
4902 TLI.isOperationLegalOrCustom(LoOp, N->getValueType(0)))) {
4903 SDValue Res = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
4904 return CombineTo(N, Res, Res);
4905 }
4906
4907 // If the low half is not needed, just compute the high half.
4908 bool LoExists = N->hasAnyUseOfValue(0);
4909 if (!LoExists && (!LegalOperations ||
4910 TLI.isOperationLegalOrCustom(HiOp, N->getValueType(1)))) {
4911 SDValue Res = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
4912 return CombineTo(N, Res, Res);
4913 }
4914
4915 // If both halves are used, return as it is.
4916 if (LoExists && HiExists)
4917 return SDValue();
4918
4919 // If the two computed results can be simplified separately, separate them.
4920 if (LoExists) {
4921 SDValue Lo = DAG.getNode(LoOp, SDLoc(N), N->getValueType(0), N->ops());
4922 AddToWorklist(Lo.getNode());
4923 SDValue LoOpt = combine(Lo.getNode());
4924 if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() &&
4925 (!LegalOperations ||
4926 TLI.isOperationLegalOrCustom(LoOpt.getOpcode(), LoOpt.getValueType())))
4927 return CombineTo(N, LoOpt, LoOpt);
4928 }
4929
4930 if (HiExists) {
4931 SDValue Hi = DAG.getNode(HiOp, SDLoc(N), N->getValueType(1), N->ops());
4932 AddToWorklist(Hi.getNode());
4933 SDValue HiOpt = combine(Hi.getNode());
4934 if (HiOpt.getNode() && HiOpt != Hi &&
4935 (!LegalOperations ||
4936 TLI.isOperationLegalOrCustom(HiOpt.getOpcode(), HiOpt.getValueType())))
4937 return CombineTo(N, HiOpt, HiOpt);
4938 }
4939
4940 return SDValue();
4941 }
4942
4943 SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
4944 if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHS))
4945 return Res;
4946
4947 SDValue N0 = N->getOperand(0);
4948 SDValue N1 = N->getOperand(1);
4949 EVT VT = N->getValueType(0);
4950 SDLoc DL(N);
4951
4952 // canonicalize constant to RHS (vector doesn't have to splat)
4953 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4954 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4955 return DAG.getNode(ISD::SMUL_LOHI, DL, N->getVTList(), N1, N0);
4956
4957   // If the type twice as wide is legal, transform the mul_lohi to a wider
4958 // multiply plus a shift.
4959 if (VT.isSimple() && !VT.isVector()) {
4960 MVT Simple = VT.getSimpleVT();
4961 unsigned SimpleSize = Simple.getSizeInBits();
4962 EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
4963 if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
4964 SDValue Lo = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N0);
4965 SDValue Hi = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
4966 Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
4967 // Compute the high part as N1.
4968 Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
4969 DAG.getConstant(SimpleSize, DL,
4970 getShiftAmountTy(Lo.getValueType())));
4971 Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
4972 // Compute the low part as N0.
4973 Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
4974 return CombineTo(N, Lo, Hi);
4975 }
4976 }
4977
4978 return SDValue();
4979 }
4980
4981 SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
4982 if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHU))
4983 return Res;
4984
4985 SDValue N0 = N->getOperand(0);
4986 SDValue N1 = N->getOperand(1);
4987 EVT VT = N->getValueType(0);
4988 SDLoc DL(N);
4989
4990 // canonicalize constant to RHS (vector doesn't have to splat)
4991 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
4992 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
4993 return DAG.getNode(ISD::UMUL_LOHI, DL, N->getVTList(), N1, N0);
4994
4995 // (umul_lohi N0, 0) -> (0, 0)
4996 if (isNullConstant(N1)) {
4997 SDValue Zero = DAG.getConstant(0, DL, VT);
4998 return CombineTo(N, Zero, Zero);
4999 }
5000
5001 // (umul_lohi N0, 1) -> (N0, 0)
5002 if (isOneConstant(N1)) {
5003 SDValue Zero = DAG.getConstant(0, DL, VT);
5004 return CombineTo(N, N0, Zero);
5005 }
5006
5007   // If the type twice as wide is legal, transform the mul_lohi to a wider
5008 // multiply plus a shift.
5009 if (VT.isSimple() && !VT.isVector()) {
5010 MVT Simple = VT.getSimpleVT();
5011 unsigned SimpleSize = Simple.getSizeInBits();
5012 EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), SimpleSize*2);
5013 if (TLI.isOperationLegal(ISD::MUL, NewVT)) {
5014 SDValue Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N0);
5015 SDValue Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
5016 Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
5017 // Compute the high part as N1.
5018 Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
5019 DAG.getConstant(SimpleSize, DL,
5020 getShiftAmountTy(Lo.getValueType())));
5021 Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
5022 // Compute the low part as N0.
5023 Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
5024 return CombineTo(N, Lo, Hi);
5025 }
5026 }
5027
5028 return SDValue();
5029 }
5030
5031 SDValue DAGCombiner::visitMULO(SDNode *N) {
5032 SDValue N0 = N->getOperand(0);
5033 SDValue N1 = N->getOperand(1);
5034 EVT VT = N0.getValueType();
5035 bool IsSigned = (ISD::SMULO == N->getOpcode());
5036
5037 EVT CarryVT = N->getValueType(1);
5038 SDLoc DL(N);
5039
5040 ConstantSDNode *N0C = isConstOrConstSplat(N0);
5041 ConstantSDNode *N1C = isConstOrConstSplat(N1);
5042
5043 // fold operation with constant operands.
5044 // TODO: Move this to FoldConstantArithmetic when it supports nodes with
5045 // multiple results.
5046 if (N0C && N1C) {
5047 bool Overflow;
5048 APInt Result =
5049 IsSigned ? N0C->getAPIntValue().smul_ov(N1C->getAPIntValue(), Overflow)
5050 : N0C->getAPIntValue().umul_ov(N1C->getAPIntValue(), Overflow);
5051 return CombineTo(N, DAG.getConstant(Result, DL, VT),
5052 DAG.getBoolConstant(Overflow, DL, CarryVT, CarryVT));
5053 }
5054
5055 // canonicalize constant to RHS.
5056 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5057 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5058 return DAG.getNode(N->getOpcode(), DL, N->getVTList(), N1, N0);
5059
5060 // fold (mulo x, 0) -> 0 + no carry out
5061 if (isNullOrNullSplat(N1))
5062 return CombineTo(N, DAG.getConstant(0, DL, VT),
5063 DAG.getConstant(0, DL, CarryVT));
5064
5065 // (mulo x, 2) -> (addo x, x)
5066 // FIXME: This needs a freeze.
5067 if (N1C && N1C->getAPIntValue() == 2 &&
5068 (!IsSigned || VT.getScalarSizeInBits() > 2))
5069 return DAG.getNode(IsSigned ? ISD::SADDO : ISD::UADDO, DL,
5070 N->getVTList(), N0, N0);
5071
5072 if (IsSigned) {
5073 // A 1 bit SMULO overflows if both inputs are 1.
5074 if (VT.getScalarSizeInBits() == 1) {
5075 SDValue And = DAG.getNode(ISD::AND, DL, VT, N0, N1);
5076 return CombineTo(N, And,
5077 DAG.getSetCC(DL, CarryVT, And,
5078 DAG.getConstant(0, DL, VT), ISD::SETNE));
5079 }
5080
5081 // Multiplying n * m significant bits yields a result of n + m significant
5082 // bits. If the total number of significant bits does not exceed the
5083 // result bit width (minus 1), there is no overflow.
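    // For example, with i32 operands that each have at least 17 known sign
    // bits, SignBits >= 34 > 33, so the multiply cannot overflow.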
5084 unsigned SignBits = DAG.ComputeNumSignBits(N0);
5085 if (SignBits > 1)
5086 SignBits += DAG.ComputeNumSignBits(N1);
5087 if (SignBits > VT.getScalarSizeInBits() + 1)
5088 return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
5089 DAG.getConstant(0, DL, CarryVT));
5090 } else {
5091 KnownBits N1Known = DAG.computeKnownBits(N1);
5092 KnownBits N0Known = DAG.computeKnownBits(N0);
5093 bool Overflow;
5094 (void)N0Known.getMaxValue().umul_ov(N1Known.getMaxValue(), Overflow);
5095 if (!Overflow)
5096 return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
5097 DAG.getConstant(0, DL, CarryVT));
5098 }
5099
5100 return SDValue();
5101 }
5102
5103 // Function to calculate whether the Min/Max pair of SDNodes (potentially
5104 // swapped around) makes a signed saturate pattern, clamping to between a signed
5105 // saturate of -2^(BW-1) and 2^(BW-1)-1, or an unsigned saturate of 0 and 2^BW-1.
5106 // Returns the node being clamped and the bitwidth of the clamp in BW. Should
5107 // work with both SMIN/SMAX nodes and setcc/select combo. The operands are the
5108 // same as SimplifySelectCC. N0<N1 ? N2 : N3.
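// For example, smax(smin(X, 127), -128) (or the equivalent setcc/select form)
// is a signed saturate with BW = 8, while smax(smin(X, 255), 0) is an
// unsigned saturate with BW = 8.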
5109 static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2,
5110 SDValue N3, ISD::CondCode CC, unsigned &BW,
5111 bool &Unsigned) {
5112 auto isSignedMinMax = [&](SDValue N0, SDValue N1, SDValue N2, SDValue N3,
5113 ISD::CondCode CC) {
5114 // The compare and select operand should be the same or the select operands
5115 // should be truncated versions of the comparison.
5116 if (N0 != N2 && (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(0)))
5117 return 0;
5118 // The constants need to be the same or a truncated version of each other.
5119 ConstantSDNode *N1C = isConstOrConstSplat(N1);
5120 ConstantSDNode *N3C = isConstOrConstSplat(N3);
5121 if (!N1C || !N3C)
5122 return 0;
5123 const APInt &C1 = N1C->getAPIntValue();
5124 const APInt &C2 = N3C->getAPIntValue();
5125 if (C1.getBitWidth() < C2.getBitWidth() || C1 != C2.sext(C1.getBitWidth()))
5126 return 0;
5127 return CC == ISD::SETLT ? ISD::SMIN : (CC == ISD::SETGT ? ISD::SMAX : 0);
5128 };
5129
5130 // Check the initial value is a SMIN/SMAX equivalent.
5131 unsigned Opcode0 = isSignedMinMax(N0, N1, N2, N3, CC);
5132 if (!Opcode0)
5133 return SDValue();
5134
5135 SDValue N00, N01, N02, N03;
5136 ISD::CondCode N0CC;
5137 switch (N0.getOpcode()) {
5138 case ISD::SMIN:
5139 case ISD::SMAX:
5140 N00 = N02 = N0.getOperand(0);
5141 N01 = N03 = N0.getOperand(1);
5142 N0CC = N0.getOpcode() == ISD::SMIN ? ISD::SETLT : ISD::SETGT;
5143 break;
5144 case ISD::SELECT_CC:
5145 N00 = N0.getOperand(0);
5146 N01 = N0.getOperand(1);
5147 N02 = N0.getOperand(2);
5148 N03 = N0.getOperand(3);
5149 N0CC = cast<CondCodeSDNode>(N0.getOperand(4))->get();
5150 break;
5151 case ISD::SELECT:
5152 case ISD::VSELECT:
5153 if (N0.getOperand(0).getOpcode() != ISD::SETCC)
5154 return SDValue();
5155 N00 = N0.getOperand(0).getOperand(0);
5156 N01 = N0.getOperand(0).getOperand(1);
5157 N02 = N0.getOperand(1);
5158 N03 = N0.getOperand(2);
5159 N0CC = cast<CondCodeSDNode>(N0.getOperand(0).getOperand(2))->get();
5160 break;
5161 default:
5162 return SDValue();
5163 }
5164
5165 unsigned Opcode1 = isSignedMinMax(N00, N01, N02, N03, N0CC);
5166 if (!Opcode1 || Opcode0 == Opcode1)
5167 return SDValue();
5168
5169 ConstantSDNode *MinCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N1 : N01);
5170 ConstantSDNode *MaxCOp = isConstOrConstSplat(Opcode0 == ISD::SMIN ? N01 : N1);
5171 if (!MinCOp || !MaxCOp || MinCOp->getValueType(0) != MaxCOp->getValueType(0))
5172 return SDValue();
5173
5174 const APInt &MinC = MinCOp->getAPIntValue();
5175 const APInt &MaxC = MaxCOp->getAPIntValue();
5176 APInt MinCPlus1 = MinC + 1;
5177 if (-MaxC == MinCPlus1 && MinCPlus1.isPowerOf2()) {
5178 BW = MinCPlus1.exactLogBase2() + 1;
5179 Unsigned = false;
5180 return N02;
5181 }
5182
5183 if (MaxC == 0 && MinCPlus1.isPowerOf2()) {
5184 BW = MinCPlus1.exactLogBase2();
5185 Unsigned = true;
5186 return N02;
5187 }
5188
5189 return SDValue();
5190 }
5191
5192 static SDValue PerformMinMaxFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
5193 SDValue N3, ISD::CondCode CC,
5194 SelectionDAG &DAG) {
5195 unsigned BW;
5196 bool Unsigned;
5197 SDValue Fp = isSaturatingMinMax(N0, N1, N2, N3, CC, BW, Unsigned);
5198 if (!Fp || Fp.getOpcode() != ISD::FP_TO_SINT)
5199 return SDValue();
5200 EVT FPVT = Fp.getOperand(0).getValueType();
5201 EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), BW);
5202 if (FPVT.isVector())
5203 NewVT = EVT::getVectorVT(*DAG.getContext(), NewVT,
5204 FPVT.getVectorElementCount());
5205 unsigned NewOpc = Unsigned ? ISD::FP_TO_UINT_SAT : ISD::FP_TO_SINT_SAT;
5206 if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(NewOpc, FPVT, NewVT))
5207 return SDValue();
5208 SDLoc DL(Fp);
5209 SDValue Sat = DAG.getNode(NewOpc, DL, NewVT, Fp.getOperand(0),
5210 DAG.getValueType(NewVT.getScalarType()));
5211 return Unsigned ? DAG.getZExtOrTrunc(Sat, DL, N2->getValueType(0))
5212 : DAG.getSExtOrTrunc(Sat, DL, N2->getValueType(0));
5213 }
5214
5215 static SDValue PerformUMinFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
5216 SDValue N3, ISD::CondCode CC,
5217 SelectionDAG &DAG) {
5218 // We are looking for UMIN(FPTOUI(X), (2^n)-1), which may have come via a
5219   // select/vselect/select_cc. The two operand pairs for the select (N2/N3) may
5220   // be truncated versions of the setcc operands (N0/N1).
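  // For example, umin (fp_to_uint X), 255 becomes (zext (fp_to_uint_sat X to i8)).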
5221 if ((N0 != N2 &&
5222 (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(0))) ||
5223 N0.getOpcode() != ISD::FP_TO_UINT || CC != ISD::SETULT)
5224 return SDValue();
5225 ConstantSDNode *N1C = isConstOrConstSplat(N1);
5226 ConstantSDNode *N3C = isConstOrConstSplat(N3);
5227 if (!N1C || !N3C)
5228 return SDValue();
5229 const APInt &C1 = N1C->getAPIntValue();
5230 const APInt &C3 = N3C->getAPIntValue();
5231 if (!(C1 + 1).isPowerOf2() || C1.getBitWidth() < C3.getBitWidth() ||
5232 C1 != C3.zext(C1.getBitWidth()))
5233 return SDValue();
5234
5235 unsigned BW = (C1 + 1).exactLogBase2();
5236 EVT FPVT = N0.getOperand(0).getValueType();
5237 EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), BW);
5238 if (FPVT.isVector())
5239 NewVT = EVT::getVectorVT(*DAG.getContext(), NewVT,
5240 FPVT.getVectorElementCount());
5241 if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(ISD::FP_TO_UINT_SAT,
5242 FPVT, NewVT))
5243 return SDValue();
5244
5245 SDValue Sat =
5246 DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(N0), NewVT, N0.getOperand(0),
5247 DAG.getValueType(NewVT.getScalarType()));
5248 return DAG.getZExtOrTrunc(Sat, SDLoc(N0), N3.getValueType());
5249 }
5250
5251 SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
5252 SDValue N0 = N->getOperand(0);
5253 SDValue N1 = N->getOperand(1);
5254 EVT VT = N0.getValueType();
5255 unsigned Opcode = N->getOpcode();
5256 SDLoc DL(N);
5257
5258 // fold operation with constant operands.
5259 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
5260 return C;
5261
5262 // If the operands are the same, this is a no-op.
5263 if (N0 == N1)
5264 return N0;
5265
5266 // canonicalize constant to RHS
5267 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
5268 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
5269 return DAG.getNode(Opcode, DL, VT, N1, N0);
5270
5271 // fold vector ops
5272 if (VT.isVector())
5273 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5274 return FoldedVOp;
5275
5276   // If the sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX.
5277 // Only do this if the current op isn't legal and the flipped is.
5278 if (!TLI.isOperationLegal(Opcode, VT) &&
5279 (N0.isUndef() || DAG.SignBitIsZero(N0)) &&
5280 (N1.isUndef() || DAG.SignBitIsZero(N1))) {
5281 unsigned AltOpcode;
5282 switch (Opcode) {
5283 case ISD::SMIN: AltOpcode = ISD::UMIN; break;
5284 case ISD::SMAX: AltOpcode = ISD::UMAX; break;
5285 case ISD::UMIN: AltOpcode = ISD::SMIN; break;
5286 case ISD::UMAX: AltOpcode = ISD::SMAX; break;
5287 default: llvm_unreachable("Unknown MINMAX opcode");
5288 }
5289 if (TLI.isOperationLegal(AltOpcode, VT))
5290 return DAG.getNode(AltOpcode, DL, VT, N0, N1);
5291 }
5292
5293 if (Opcode == ISD::SMIN || Opcode == ISD::SMAX)
5294 if (SDValue S = PerformMinMaxFpToSatCombine(
5295 N0, N1, N0, N1, Opcode == ISD::SMIN ? ISD::SETLT : ISD::SETGT, DAG))
5296 return S;
5297 if (Opcode == ISD::UMIN)
5298 if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N0, N1, ISD::SETULT, DAG))
5299 return S;
5300
5301 // Simplify the operands using demanded-bits information.
5302 if (SimplifyDemandedBits(SDValue(N, 0)))
5303 return SDValue(N, 0);
5304
5305 return SDValue();
5306 }
5307
5308 /// If this is a bitwise logic instruction and both operands have the same
5309 /// opcode, try to sink the other opcode after the logic instruction.
5310 SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
5311 SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
5312 EVT VT = N0.getValueType();
5313 unsigned LogicOpcode = N->getOpcode();
5314 unsigned HandOpcode = N0.getOpcode();
5315 assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR ||
5316 LogicOpcode == ISD::XOR) && "Expected logic opcode");
5317 assert(HandOpcode == N1.getOpcode() && "Bad input!");
5318
5319 // Bail early if none of these transforms apply.
5320 if (N0.getNumOperands() == 0)
5321 return SDValue();
5322
5323 // FIXME: We should check number of uses of the operands to not increase
5324 // the instruction count for all transforms.
5325
5326 // Handle size-changing casts.
5327 SDValue X = N0.getOperand(0);
5328 SDValue Y = N1.getOperand(0);
5329 EVT XVT = X.getValueType();
5330 SDLoc DL(N);
5331 if (HandOpcode == ISD::ANY_EXTEND || HandOpcode == ISD::ZERO_EXTEND ||
5332 HandOpcode == ISD::SIGN_EXTEND) {
5333 // If both operands have other uses, this transform would create extra
5334 // instructions without eliminating anything.
5335 if (!N0.hasOneUse() && !N1.hasOneUse())
5336 return SDValue();
5337 // We need matching integer source types.
5338 if (XVT != Y.getValueType())
5339 return SDValue();
5340 // Don't create an illegal op during or after legalization. Don't ever
5341 // create an unsupported vector op.
5342 if ((VT.isVector() || LegalOperations) &&
5343 !TLI.isOperationLegalOrCustom(LogicOpcode, XVT))
5344 return SDValue();
5345 // Avoid infinite looping with PromoteIntBinOp.
5346 // TODO: Should we apply desirable/legal constraints to all opcodes?
5347 if (HandOpcode == ISD::ANY_EXTEND && LegalTypes &&
5348 !TLI.isTypeDesirableForOp(LogicOpcode, XVT))
5349 return SDValue();
5350 // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
5351 SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
5352 return DAG.getNode(HandOpcode, DL, VT, Logic);
5353 }
5354
5355 // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y)
5356 if (HandOpcode == ISD::TRUNCATE) {
5357 // If both operands have other uses, this transform would create extra
5358 // instructions without eliminating anything.
5359 if (!N0.hasOneUse() && !N1.hasOneUse())
5360 return SDValue();
5361 // We need matching source types.
5362 if (XVT != Y.getValueType())
5363 return SDValue();
5364 // Don't create an illegal op during or after legalization.
5365 if (LegalOperations && !TLI.isOperationLegal(LogicOpcode, XVT))
5366 return SDValue();
5367 // Be extra careful sinking truncate. If it's free, there's no benefit in
5368 // widening a binop. Also, don't create a logic op on an illegal type.
5369 if (TLI.isZExtFree(VT, XVT) && TLI.isTruncateFree(XVT, VT))
5370 return SDValue();
5371 if (!TLI.isTypeLegal(XVT))
5372 return SDValue();
5373 SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
5374 return DAG.getNode(HandOpcode, DL, VT, Logic);
5375 }
5376
5377 // For binops SHL/SRL/SRA/AND:
5378 // logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z
5379 if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL ||
5380 HandOpcode == ISD::SRA || HandOpcode == ISD::AND) &&
5381 N0.getOperand(1) == N1.getOperand(1)) {
5382 // If either operand has other uses, this transform is not an improvement.
5383 if (!N0.hasOneUse() || !N1.hasOneUse())
5384 return SDValue();
5385 SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
5386 return DAG.getNode(HandOpcode, DL, VT, Logic, N0.getOperand(1));
5387 }
5388
5389 // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y)
5390 if (HandOpcode == ISD::BSWAP) {
5391 // If either operand has other uses, this transform is not an improvement.
5392 if (!N0.hasOneUse() || !N1.hasOneUse())
5393 return SDValue();
5394 SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
5395 return DAG.getNode(HandOpcode, DL, VT, Logic);
5396 }
5397
5398 // For funnel shifts FSHL/FSHR:
5399 // logic_op (OP x, x1, s), (OP y, y1, s) -->
5400   // --> OP (logic_op x, y), (logic_op x1, y1), s
5401 if ((HandOpcode == ISD::FSHL || HandOpcode == ISD::FSHR) &&
5402 N0.getOperand(2) == N1.getOperand(2)) {
5403 if (!N0.hasOneUse() || !N1.hasOneUse())
5404 return SDValue();
5405 SDValue X1 = N0.getOperand(1);
5406 SDValue Y1 = N1.getOperand(1);
5407 SDValue S = N0.getOperand(2);
5408 SDValue Logic0 = DAG.getNode(LogicOpcode, DL, VT, X, Y);
5409 SDValue Logic1 = DAG.getNode(LogicOpcode, DL, VT, X1, Y1);
5410 return DAG.getNode(HandOpcode, DL, VT, Logic0, Logic1, S);
5411 }
5412
5413 // Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B))
5414 // Only perform this optimization up until type legalization, before
5415   // LegalizeVectorOps. LegalizeVectorOps promotes vector operations by
5416 // adding bitcasts. For example (xor v4i32) is promoted to (v2i64), and
5417 // we don't want to undo this promotion.
5418 // We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper
5419 // on scalars.
5420 if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) &&
5421 Level <= AfterLegalizeTypes) {
5422 // Input types must be integer and the same.
5423 if (XVT.isInteger() && XVT == Y.getValueType() &&
5424 !(VT.isVector() && TLI.isTypeLegal(VT) &&
5425 !XVT.isVector() && !TLI.isTypeLegal(XVT))) {
5426 SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
5427 return DAG.getNode(HandOpcode, DL, VT, Logic);
5428 }
5429 }
5430
5431 // Xor/and/or are indifferent to the swizzle operation (shuffle of one value).
5432 // Simplify xor/and/or (shuff(A), shuff(B)) -> shuff(op (A,B))
5433 // If both shuffles use the same mask, and both shuffle within a single
5434 // vector, then it is worthwhile to move the swizzle after the operation.
5435 // The type-legalizer generates this pattern when loading illegal
5436 // vector types from memory. In many cases this allows additional shuffle
5437 // optimizations.
5438 // There are other cases where moving the shuffle after the xor/and/or
5439 // is profitable even if shuffles don't perform a swizzle.
5440 // If both shuffles use the same mask, and both shuffles have the same first
5441 // or second operand, then it might still be profitable to move the shuffle
5442 // after the xor/and/or operation.
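  // For example, (and (shuffle A, C, Mask), (shuffle B, C, Mask))
  // becomes (shuffle (and A, B), C, Mask) when both shuffles have one use.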
5443 if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) {
5444 auto *SVN0 = cast<ShuffleVectorSDNode>(N0);
5445 auto *SVN1 = cast<ShuffleVectorSDNode>(N1);
5446 assert(X.getValueType() == Y.getValueType() &&
5447 "Inputs to shuffles are not the same type");
5448
5449 // Check that both shuffles use the same mask. The masks are known to be of
5450 // the same length because the result vector type is the same.
5451 // Check also that shuffles have only one use to avoid introducing extra
5452 // instructions.
5453 if (!SVN0->hasOneUse() || !SVN1->hasOneUse() ||
5454 !SVN0->getMask().equals(SVN1->getMask()))
5455 return SDValue();
5456
5457 // Don't try to fold this node if it requires introducing a
5458 // build vector of all zeros that might be illegal at this stage.
5459 SDValue ShOp = N0.getOperand(1);
5460 if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
5461 ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
5462
5463 // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C)
5464 if (N0.getOperand(1) == N1.getOperand(1) && ShOp.getNode()) {
5465 SDValue Logic = DAG.getNode(LogicOpcode, DL, VT,
5466 N0.getOperand(0), N1.getOperand(0));
5467 return DAG.getVectorShuffle(VT, DL, Logic, ShOp, SVN0->getMask());
5468 }
5469
5470 // Don't try to fold this node if it requires introducing a
5471 // build vector of all zeros that might be illegal at this stage.
5472 ShOp = N0.getOperand(0);
5473 if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
5474 ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
5475
5476 // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B))
5477 if (N0.getOperand(0) == N1.getOperand(0) && ShOp.getNode()) {
5478 SDValue Logic = DAG.getNode(LogicOpcode, DL, VT, N0.getOperand(1),
5479 N1.getOperand(1));
5480 return DAG.getVectorShuffle(VT, DL, ShOp, Logic, SVN0->getMask());
5481 }
5482 }
5483
5484 return SDValue();
5485 }
5486
5487 /// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient.
5488 SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
5489 const SDLoc &DL) {
5490 SDValue LL, LR, RL, RR, N0CC, N1CC;
5491 if (!isSetCCEquivalent(N0, LL, LR, N0CC) ||
5492 !isSetCCEquivalent(N1, RL, RR, N1CC))
5493 return SDValue();
5494
5495 assert(N0.getValueType() == N1.getValueType() &&
5496 "Unexpected operand types for bitwise logic op");
5497 assert(LL.getValueType() == LR.getValueType() &&
5498 RL.getValueType() == RR.getValueType() &&
5499 "Unexpected operand types for setcc");
5500
5501 // If we're here post-legalization or the logic op type is not i1, the logic
5502 // op type must match a setcc result type. Also, all folds require new
5503 // operations on the left and right operands, so those types must match.
5504 EVT VT = N0.getValueType();
5505 EVT OpVT = LL.getValueType();
5506 if (LegalOperations || VT.getScalarType() != MVT::i1)
5507 if (VT != getSetCCResultType(OpVT))
5508 return SDValue();
5509 if (OpVT != RL.getValueType())
5510 return SDValue();
5511
5512 ISD::CondCode CC0 = cast<CondCodeSDNode>(N0CC)->get();
5513 ISD::CondCode CC1 = cast<CondCodeSDNode>(N1CC)->get();
5514 bool IsInteger = OpVT.isInteger();
5515 if (LR == RR && CC0 == CC1 && IsInteger) {
5516 bool IsZero = isNullOrNullSplat(LR);
5517 bool IsNeg1 = isAllOnesOrAllOnesSplat(LR);
5518
5519 // All bits clear?
5520 bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero;
5521 // All sign bits clear?
5522 bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1;
5523 // Any bits set?
5524 bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero;
5525 // Any sign bits set?
5526 bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero;
5527
5528 // (and (seteq X, 0), (seteq Y, 0)) --> (seteq (or X, Y), 0)
5529 // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1)
5530 // (or (setne X, 0), (setne Y, 0)) --> (setne (or X, Y), 0)
5531 // (or (setlt X, 0), (setlt Y, 0)) --> (setlt (or X, Y), 0)
5532 if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) {
5533 SDValue Or = DAG.getNode(ISD::OR, SDLoc(N0), OpVT, LL, RL);
5534 AddToWorklist(Or.getNode());
5535 return DAG.getSetCC(DL, VT, Or, LR, CC1);
5536 }
5537
5538 // All bits set?
5539 bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1;
5540 // All sign bits set?
5541 bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero;
5542 // Any bits clear?
5543 bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1;
5544 // Any sign bits clear?
5545 bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1;
5546
5547 // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1)
5548 // (and (setlt X, 0), (setlt Y, 0)) --> (setlt (and X, Y), 0)
5549 // (or (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1)
5550 // (or (setgt X, -1), (setgt Y -1)) --> (setgt (and X, Y), -1)
5551 if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) {
5552 SDValue And = DAG.getNode(ISD::AND, SDLoc(N0), OpVT, LL, RL);
5553 AddToWorklist(And.getNode());
5554 return DAG.getSetCC(DL, VT, And, LR, CC1);
5555 }
5556 }
5557
5558 // TODO: What is the 'or' equivalent of this fold?
5559 // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2)
5560 if (IsAnd && LL == RL && CC0 == CC1 && OpVT.getScalarSizeInBits() > 1 &&
5561 IsInteger && CC0 == ISD::SETNE &&
5562 ((isNullConstant(LR) && isAllOnesConstant(RR)) ||
5563 (isAllOnesConstant(LR) && isNullConstant(RR)))) {
5564 SDValue One = DAG.getConstant(1, DL, OpVT);
5565 SDValue Two = DAG.getConstant(2, DL, OpVT);
5566 SDValue Add = DAG.getNode(ISD::ADD, SDLoc(N0), OpVT, LL, One);
5567 AddToWorklist(Add.getNode());
5568 return DAG.getSetCC(DL, VT, Add, Two, ISD::SETUGE);
5569 }
5570
5571 // Try more general transforms if the predicates match and the only user of
5572 // the compares is the 'and' or 'or'.
5573 if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(OpVT) && CC0 == CC1 &&
5574 N0.hasOneUse() && N1.hasOneUse()) {
5575 // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
5576 // or (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0
5577 if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) {
5578 SDValue XorL = DAG.getNode(ISD::XOR, SDLoc(N0), OpVT, LL, LR);
5579 SDValue XorR = DAG.getNode(ISD::XOR, SDLoc(N1), OpVT, RL, RR);
5580 SDValue Or = DAG.getNode(ISD::OR, DL, OpVT, XorL, XorR);
5581 SDValue Zero = DAG.getConstant(0, DL, OpVT);
5582 return DAG.getSetCC(DL, VT, Or, Zero, CC1);
5583 }
5584
5585 // Turn compare of constants whose difference is 1 bit into add+and+setcc.
5586 if ((IsAnd && CC1 == ISD::SETNE) || (!IsAnd && CC1 == ISD::SETEQ)) {
5587 // Match a shared variable operand and 2 non-opaque constant operands.
5588 auto MatchDiffPow2 = [&](ConstantSDNode *C0, ConstantSDNode *C1) {
5589 // The difference of the constants must be a single bit.
5590 const APInt &CMax =
5591 APIntOps::umax(C0->getAPIntValue(), C1->getAPIntValue());
5592 const APInt &CMin =
5593 APIntOps::umin(C0->getAPIntValue(), C1->getAPIntValue());
5594 return !C0->isOpaque() && !C1->isOpaque() && (CMax - CMin).isPowerOf2();
5595 };
5596 if (LL == RL && ISD::matchBinaryPredicate(LR, RR, MatchDiffPow2)) {
5597 // and/or (setcc X, CMax, ne), (setcc X, CMin, ne/eq) -->
5598 // setcc ((sub X, CMin), ~(CMax - CMin)), 0, ne/eq
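      // For example, and (setcc X, 6, ne), (setcc X, 4, ne)
      // becomes setcc (and (sub X, 4), ~2), 0, ne.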
5599 SDValue Max = DAG.getNode(ISD::UMAX, DL, OpVT, LR, RR);
5600 SDValue Min = DAG.getNode(ISD::UMIN, DL, OpVT, LR, RR);
5601 SDValue Offset = DAG.getNode(ISD::SUB, DL, OpVT, LL, Min);
5602 SDValue Diff = DAG.getNode(ISD::SUB, DL, OpVT, Max, Min);
5603 SDValue Mask = DAG.getNOT(DL, Diff, OpVT);
5604 SDValue And = DAG.getNode(ISD::AND, DL, OpVT, Offset, Mask);
5605 SDValue Zero = DAG.getConstant(0, DL, OpVT);
5606 return DAG.getSetCC(DL, VT, And, Zero, CC0);
5607 }
5608 }
5609 }
5610
5611 // Canonicalize equivalent operands to LL == RL.
5612 if (LL == RR && LR == RL) {
5613 CC1 = ISD::getSetCCSwappedOperands(CC1);
5614 std::swap(RL, RR);
5615 }
5616
5617 // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
5618 // (or (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
5619 if (LL == RL && LR == RR) {
5620 ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(CC0, CC1, OpVT)
5621 : ISD::getSetCCOrOperation(CC0, CC1, OpVT);
5622 if (NewCC != ISD::SETCC_INVALID &&
5623 (!LegalOperations ||
5624 (TLI.isCondCodeLegal(NewCC, LL.getSimpleValueType()) &&
5625 TLI.isOperationLegal(ISD::SETCC, OpVT))))
5626 return DAG.getSetCC(DL, VT, LL, LR, NewCC);
5627 }
5628
5629 return SDValue();
5630 }
5631
5632 /// This contains all DAGCombine rules which reduce two values combined by
5633 /// an And operation to a single value. This makes them reusable in the context
5634 /// of visitSELECT(). Rules involving constants are not included as
5635 /// visitSELECT() already handles those cases.
5636 SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) {
5637 EVT VT = N1.getValueType();
5638 SDLoc DL(N);
5639
5640 // fold (and x, undef) -> 0
5641 if (N0.isUndef() || N1.isUndef())
5642 return DAG.getConstant(0, DL, VT);
5643
5644 if (SDValue V = foldLogicOfSetCCs(true, N0, N1, DL))
5645 return V;
5646
5647 // TODO: Rewrite this to return a new 'AND' instead of using CombineTo.
5648 if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL &&
5649 VT.getSizeInBits() <= 64 && N0->hasOneUse()) {
5650 if (ConstantSDNode *ADDI = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
5651 if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(N1.getOperand(1))) {
5652 // Look for (and (add x, c1), (lshr y, c2)). If C1 wasn't a legal
5653 // immediate for an add, but it is legal if its top c2 bits are set,
5654 // transform the ADD so the immediate doesn't need to be materialized
5655 // in a register.
5656 APInt ADDC = ADDI->getAPIntValue();
5657 APInt SRLC = SRLI->getAPIntValue();
5658 if (ADDC.getMinSignedBits() <= 64 &&
5659 SRLC.ult(VT.getSizeInBits()) &&
5660 !TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
5661 APInt Mask = APInt::getHighBitsSet(VT.getSizeInBits(),
5662 SRLC.getZExtValue());
5663 if (DAG.MaskedValueIsZero(N0.getOperand(1), Mask)) {
5664 ADDC |= Mask;
5665 if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
5666 SDLoc DL0(N0);
5667 SDValue NewAdd =
5668 DAG.getNode(ISD::ADD, DL0, VT,
5669 N0.getOperand(0), DAG.getConstant(ADDC, DL, VT));
5670 CombineTo(N0.getNode(), NewAdd);
5671 // Return N so it doesn't get rechecked!
5672 return SDValue(N, 0);
5673 }
5674 }
5675 }
5676 }
5677 }
5678 }
5679
5680 // Reduce bit extract of low half of an integer to the narrower type.
5681 // (and (srl i64:x, K), KMask) ->
5682   //   (i64 zero_extend (and (srl (i32 (trunc i64:x)), K), KMask))
5683 if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
5684 if (ConstantSDNode *CAnd = dyn_cast<ConstantSDNode>(N1)) {
5685 if (ConstantSDNode *CShift = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
5686 unsigned Size = VT.getSizeInBits();
5687 const APInt &AndMask = CAnd->getAPIntValue();
5688 unsigned ShiftBits = CShift->getZExtValue();
5689
5690 // Bail out, this node will probably disappear anyway.
5691 if (ShiftBits == 0)
5692 return SDValue();
5693
5694 unsigned MaskBits = AndMask.countTrailingOnes();
5695 EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), Size / 2);
5696
5697 if (AndMask.isMask() &&
5698 // Required bits must not span the two halves of the integer and
5699 // must fit in the half size type.
5700 (ShiftBits + MaskBits <= Size / 2) &&
5701 TLI.isNarrowingProfitable(VT, HalfVT) &&
5702 TLI.isTypeDesirableForOp(ISD::AND, HalfVT) &&
5703 TLI.isTypeDesirableForOp(ISD::SRL, HalfVT) &&
5704 TLI.isTruncateFree(VT, HalfVT) &&
5705 TLI.isZExtFree(HalfVT, VT)) {
5706 // The isNarrowingProfitable is to avoid regressions on PPC and
5707 // AArch64 which match a few 64-bit bit insert / bit extract patterns
5708 // on downstream users of this. Those patterns could probably be
5709 // extended to handle extensions mixed in.
5710
5711         SDLoc SL(N0);
5712 assert(MaskBits <= Size);
5713
5714 // Extracting the highest bit of the low half.
5715 EVT ShiftVT = TLI.getShiftAmountTy(HalfVT, DAG.getDataLayout());
5716 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, HalfVT,
5717 N0.getOperand(0));
5718
5719 SDValue NewMask = DAG.getConstant(AndMask.trunc(Size / 2), SL, HalfVT);
5720 SDValue ShiftK = DAG.getConstant(ShiftBits, SL, ShiftVT);
5721 SDValue Shift = DAG.getNode(ISD::SRL, SL, HalfVT, Trunc, ShiftK);
5722 SDValue And = DAG.getNode(ISD::AND, SL, HalfVT, Shift, NewMask);
5723 return DAG.getNode(ISD::ZERO_EXTEND, SL, VT, And);
5724 }
5725 }
5726 }
5727 }
5728
5729 return SDValue();
5730 }
5731
5732 bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
5733 EVT LoadResultTy, EVT &ExtVT) {
5734 if (!AndC->getAPIntValue().isMask())
5735 return false;
5736
5737 unsigned ActiveBits = AndC->getAPIntValue().countTrailingOnes();
5738
5739 ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
5740 EVT LoadedVT = LoadN->getMemoryVT();
5741
5742 if (ExtVT == LoadedVT &&
5743 (!LegalOperations ||
5744 TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))) {
5745 // ZEXTLOAD will match without needing to change the size of the value being
5746 // loaded.
5747 return true;
5748 }
5749
5750   // Do not change the width of volatile or atomic loads.
5751 if (!LoadN->isSimple())
5752 return false;
5753
5754 // Do not generate loads of non-round integer types since these can
5755 // be expensive (and would be wrong if the type is not byte sized).
5756 if (!LoadedVT.bitsGT(ExtVT) || !ExtVT.isRound())
5757 return false;
5758
5759 if (LegalOperations &&
5760 !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))
5761 return false;
5762
5763 if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT))
5764 return false;
5765
5766 return true;
5767 }
5768
5769 bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
5770 ISD::LoadExtType ExtType, EVT &MemVT,
5771 unsigned ShAmt) {
5772 if (!LDST)
5773 return false;
5774 // Only allow byte offsets.
5775 if (ShAmt % 8)
5776 return false;
5777
5778 // Do not generate loads of non-round integer types since these can
5779 // be expensive (and would be wrong if the type is not byte sized).
5780 if (!MemVT.isRound())
5781 return false;
5782
5783   // Don't change the width of volatile or atomic loads.
5784 if (!LDST->isSimple())
5785 return false;
5786
5787 EVT LdStMemVT = LDST->getMemoryVT();
5788
5789 // Bail out when changing the scalable property, since we can't be sure that
5790 // we're actually narrowing here.
5791 if (LdStMemVT.isScalableVector() != MemVT.isScalableVector())
5792 return false;
5793
5794 // Verify that we are actually reducing a load width here.
5795 if (LdStMemVT.bitsLT(MemVT))
5796 return false;
5797
5798 // Ensure that this isn't going to produce an unsupported memory access.
5799 if (ShAmt) {
5800 assert(ShAmt % 8 == 0 && "ShAmt is byte offset");
5801 const unsigned ByteShAmt = ShAmt / 8;
5802 const Align LDSTAlign = LDST->getAlign();
5803 const Align NarrowAlign = commonAlignment(LDSTAlign, ByteShAmt);
5804 if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
5805 LDST->getAddressSpace(), NarrowAlign,
5806 LDST->getMemOperand()->getFlags()))
5807 return false;
5808 }
5809
5810 // It's not possible to generate a constant of extended or untyped type.
5811 EVT PtrType = LDST->getBasePtr().getValueType();
5812 if (PtrType == MVT::Untyped || PtrType.isExtended())
5813 return false;
5814
5815 if (isa<LoadSDNode>(LDST)) {
5816 LoadSDNode *Load = cast<LoadSDNode>(LDST);
5817 // Don't transform one with multiple uses, this would require adding a new
5818 // load.
5819 if (!SDValue(Load, 0).hasOneUse())
5820 return false;
5821
5822 if (LegalOperations &&
5823 !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT))
5824 return false;
5825
5826 // For the transform to be legal, the load must produce only two values
5827 // (the value loaded and the chain). Don't transform a pre-increment
5828 // load, for example, which produces an extra value. Otherwise the
5829 // transformation is not equivalent, and the downstream logic to replace
5830 // uses gets things wrong.
5831 if (Load->getNumValues() > 2)
5832 return false;
5833
5834 // If the load that we're shrinking is an extload and we're not just
5835 // discarding the extension we can't simply shrink the load. Bail.
5836 // TODO: It would be possible to merge the extensions in some cases.
5837 if (Load->getExtensionType() != ISD::NON_EXTLOAD &&
5838 Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
5839 return false;
5840
5841 if (!TLI.shouldReduceLoadWidth(Load, ExtType, MemVT))
5842 return false;
5843 } else {
5844 assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode");
5845 StoreSDNode *Store = cast<StoreSDNode>(LDST);
5846 // Can't write outside the original store
5847 if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
5848 return false;
5849
5850 if (LegalOperations &&
5851 !TLI.isTruncStoreLegal(Store->getValue().getValueType(), MemVT))
5852 return false;
5853 }
5854 return true;
5855 }
5856
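/// Recursively walk the operands of an AND, collecting loads that can be
/// narrowed to the width implied by the constant mask. OR/XOR nodes with
/// over-wide constant operands are noted in NodesWithConsts so the constants
/// can be re-masked later; at most one other (non-load, non-logic) node is
/// tolerated and returned in NodeToMask for explicit masking by the caller.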
5857 bool DAGCombiner::SearchForAndLoads(SDNode *N,
5858 SmallVectorImpl<LoadSDNode*> &Loads,
5859 SmallPtrSetImpl<SDNode*> &NodesWithConsts,
5860 ConstantSDNode *Mask,
5861 SDNode *&NodeToMask) {
5862 // Recursively search for the operands, looking for loads which can be
5863 // narrowed.
5864 for (SDValue Op : N->op_values()) {
5865 if (Op.getValueType().isVector())
5866 return false;
5867
5868 // Some constants may need fixing up later if they are too large.
5869 if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
5870 if ((N->getOpcode() == ISD::OR || N->getOpcode() == ISD::XOR) &&
5871 (Mask->getAPIntValue() & C->getAPIntValue()) != C->getAPIntValue())
5872 NodesWithConsts.insert(N);
5873 continue;
5874 }
5875
5876 if (!Op.hasOneUse())
5877 return false;
5878
5879 switch (Op.getOpcode()) {
5880 case ISD::LOAD: {
5881 auto *Load = cast<LoadSDNode>(Op);
5882 EVT ExtVT;
5883 if (isAndLoadExtLoad(Mask, Load, Load->getValueType(0), ExtVT) &&
5884 isLegalNarrowLdSt(Load, ISD::ZEXTLOAD, ExtVT)) {
5885
5886 // ZEXTLOAD is already small enough.
5887 if (Load->getExtensionType() == ISD::ZEXTLOAD &&
5888 ExtVT.bitsGE(Load->getMemoryVT()))
5889 continue;
5890
5891 // Use LE to convert equal sized loads to zext.
5892 if (ExtVT.bitsLE(Load->getMemoryVT()))
5893 Loads.push_back(Load);
5894
5895 continue;
5896 }
5897 return false;
5898 }
5899 case ISD::ZERO_EXTEND:
5900 case ISD::AssertZext: {
5901 unsigned ActiveBits = Mask->getAPIntValue().countTrailingOnes();
5902 EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
5903 EVT VT = Op.getOpcode() == ISD::AssertZext ?
5904 cast<VTSDNode>(Op.getOperand(1))->getVT() :
5905 Op.getOperand(0).getValueType();
5906
5907 // We can accept extending nodes if the mask is wider than or equal in
5908 // width to the original type.
5909 if (ExtVT.bitsGE(VT))
5910 continue;
5911 break;
5912 }
5913 case ISD::OR:
5914 case ISD::XOR:
5915 case ISD::AND:
5916 if (!SearchForAndLoads(Op.getNode(), Loads, NodesWithConsts, Mask,
5917 NodeToMask))
5918 return false;
5919 continue;
5920 }
5921
5922 // Allow one node which will be masked along with any loads found.
5923 if (NodeToMask)
5924 return false;
5925
5926 // Also ensure that the node to be masked only produces one data result.
5927 NodeToMask = Op.getNode();
5928 if (NodeToMask->getNumValues() > 1) {
5929 bool HasValue = false;
5930 for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) {
5931 MVT VT = SDValue(NodeToMask, i).getSimpleValueType();
5932 if (VT != MVT::Glue && VT != MVT::Other) {
5933 if (HasValue) {
5934 NodeToMask = nullptr;
5935 return false;
5936 }
5937 HasValue = true;
5938 }
5939 }
5940 assert(HasValue && "Node to be masked has no data result?");
5941 }
5942 }
5943 return true;
5944 }
5945
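/// Try to push the constant mask of an AND back to the loads that feed it.
/// For example, with a 0xFF mask, (and (or (load a), (load b)), 255) can be
/// rewritten (subject to the legality checks above) so that both loads become
/// i8 zero-extending loads and the AND itself is removed.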
5946 bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
5947 auto *Mask = dyn_cast<ConstantSDNode>(N->getOperand(1));
5948 if (!Mask)
5949 return false;
5950
5951 if (!Mask->getAPIntValue().isMask())
5952 return false;
5953
5954 // No need to do anything if the and directly uses a load.
5955 if (isa<LoadSDNode>(N->getOperand(0)))
5956 return false;
5957
5958 SmallVector<LoadSDNode*, 8> Loads;
5959 SmallPtrSet<SDNode*, 2> NodesWithConsts;
5960 SDNode *FixupNode = nullptr;
5961 if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, FixupNode)) {
5962 if (Loads.size() == 0)
5963 return false;
5964
5965 LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump());
5966 SDValue MaskOp = N->getOperand(1);
5967
5968 // If it exists, fixup the single node we allow in the tree that needs
5969 // masking.
5970 if (FixupNode) {
5971 LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump());
5972 SDValue And = DAG.getNode(ISD::AND, SDLoc(FixupNode),
5973 FixupNode->getValueType(0),
5974 SDValue(FixupNode, 0), MaskOp);
5975 DAG.ReplaceAllUsesOfValueWith(SDValue(FixupNode, 0), And);
5976 if (And.getOpcode() == ISD::AND)
5977 DAG.UpdateNodeOperands(And.getNode(), SDValue(FixupNode, 0), MaskOp);
5978 }
5979
5980 // Narrow any constants that need it.
5981 for (auto *LogicN : NodesWithConsts) {
5982 SDValue Op0 = LogicN->getOperand(0);
5983 SDValue Op1 = LogicN->getOperand(1);
5984
5985 if (isa<ConstantSDNode>(Op0))
5986 std::swap(Op0, Op1);
5987
5988 SDValue And = DAG.getNode(ISD::AND, SDLoc(Op1), Op1.getValueType(),
5989 Op1, MaskOp);
5990
5991 DAG.UpdateNodeOperands(LogicN, Op0, And);
5992 }
5993
5994 // Create narrow loads.
5995 for (auto *Load : Loads) {
5996 LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump());
5997 SDValue And = DAG.getNode(ISD::AND, SDLoc(Load), Load->getValueType(0),
5998 SDValue(Load, 0), MaskOp);
5999 DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 0), And);
6000 if (And.getOpcode() == ISD::AND)
6001 And = SDValue(
6002 DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0);
6003 SDValue NewLoad = reduceLoadWidth(And.getNode());
6004 assert(NewLoad &&
6005 "Shouldn't be masking the load if it can't be narrowed");
6006 CombineTo(Load, NewLoad, NewLoad.getValue(1));
6007 }
6008 DAG.ReplaceAllUsesWith(N, N->getOperand(0).getNode());
6009 return true;
6010 }
6011 return false;
6012 }
6013
6014 // Unfold
6015 // x & (-1 'logical shift' y)
6016 // To
6017 // (x 'opposite logical shift' y) 'logical shift' y
6018 // if it is better for performance.
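// For example, x & (-1 << y) becomes (x >> y) << y, and x & (-1 >> y)
// becomes (x << y) >> y, when the target reports that a pair of variable
// shifts is preferable to materializing the mask.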
6019 SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
6020 assert(N->getOpcode() == ISD::AND);
6021
6022 SDValue N0 = N->getOperand(0);
6023 SDValue N1 = N->getOperand(1);
6024
6025 // Do we actually prefer shifts over a mask?
6026 if (!TLI.shouldFoldMaskToVariableShiftPair(N0))
6027 return SDValue();
6028
6029 // Try to match (-1 '[outer] logical shift' y)
6030 unsigned OuterShift;
6031 unsigned InnerShift; // The opposite direction to the OuterShift.
6032 SDValue Y; // Shift amount.
6033 auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool {
6034 if (!M.hasOneUse())
6035 return false;
6036 OuterShift = M->getOpcode();
6037 if (OuterShift == ISD::SHL)
6038 InnerShift = ISD::SRL;
6039 else if (OuterShift == ISD::SRL)
6040 InnerShift = ISD::SHL;
6041 else
6042 return false;
6043 if (!isAllOnesConstant(M->getOperand(0)))
6044 return false;
6045 Y = M->getOperand(1);
6046 return true;
6047 };
6048
6049 SDValue X;
6050 if (matchMask(N1))
6051 X = N0;
6052 else if (matchMask(N0))
6053 X = N1;
6054 else
6055 return SDValue();
6056
6057 SDLoc DL(N);
6058 EVT VT = N->getValueType(0);
6059
6060 // tmp = x 'opposite logical shift' y
6061 SDValue T0 = DAG.getNode(InnerShift, DL, VT, X, Y);
6062 // ret = tmp 'logical shift' y
6063 SDValue T1 = DAG.getNode(OuterShift, DL, VT, T0, Y);
6064
6065 return T1;
6066 }
6067
6068 /// Try to replace shift/logic that tests if a bit is clear with mask + setcc.
6069 /// For a target with a bit test, this is expected to become test + set and save
6070 /// at least 1 instruction.
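/// For example, (and (not (srl X, 3)), 1) tests whether bit 3 of X is clear;
/// with a target bit-test instruction this can become setcc (and X, 8), 0, eq,
/// zero-extended back to the original type.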
6071 static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) {
6072 assert(And->getOpcode() == ISD::AND && "Expected an 'and' op");
6073
6074 // This is probably not worthwhile without a supported type.
6075 EVT VT = And->getValueType(0);
6076 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6077 if (!TLI.isTypeLegal(VT))
6078 return SDValue();
6079
6080 // Look through an optional extension.
6081 SDValue And0 = And->getOperand(0), And1 = And->getOperand(1);
6082 if (And0.getOpcode() == ISD::ANY_EXTEND && And0.hasOneUse())
6083 And0 = And0.getOperand(0);
6084 if (!isOneConstant(And1) || !And0.hasOneUse())
6085 return SDValue();
6086
6087 SDValue Src = And0;
6088
6089 // Attempt to find a 'not' op.
6090 // TODO: Should we favor test+set even without the 'not' op?
6091 bool FoundNot = false;
6092 if (isBitwiseNot(Src)) {
6093 FoundNot = true;
6094 Src = Src.getOperand(0);
6095
6096 // Look through an optional truncation. The source operand may not be the
6097 // same type as the original 'and', but that is ok because we are masking
6098 // off everything but the low bit.
6099 if (Src.getOpcode() == ISD::TRUNCATE && Src.hasOneUse())
6100 Src = Src.getOperand(0);
6101 }
6102
6103 // Match a shift-right by constant.
6104 if (Src.getOpcode() != ISD::SRL || !Src.hasOneUse())
6105 return SDValue();
6106
6107 // We might have looked through casts that make this transform invalid.
6108 // TODO: If the source type is wider than the result type, do the mask and
6109 // compare in the source type.
6110 unsigned VTBitWidth = VT.getScalarSizeInBits();
6111 SDValue ShiftAmt = Src.getOperand(1);
6112 auto *ShiftAmtC = dyn_cast<ConstantSDNode>(ShiftAmt);
6113 if (!ShiftAmtC || !ShiftAmtC->getAPIntValue().ult(VTBitWidth))
6114 return SDValue();
6115
6116 // Set source to shift source.
6117 Src = Src.getOperand(0);
6118
6119 // Try again to find a 'not' op.
6120 // TODO: Should we favor test+set even with two 'not' ops?
6121 if (!FoundNot) {
6122 if (!isBitwiseNot(Src))
6123 return SDValue();
6124 Src = Src.getOperand(0);
6125 }
6126
6127 if (!TLI.hasBitTest(Src, ShiftAmt))
6128 return SDValue();
6129
6130 // Turn this into a bit-test pattern using mask op + setcc:
6131 // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0
6132 // and (srl (not X), C), 1 --> (and X, 1<<C) == 0
6133 SDLoc DL(And);
6134 SDValue X = DAG.getZExtOrTrunc(Src, DL, VT);
6135 EVT CCVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
6136 SDValue Mask = DAG.getConstant(
6137 APInt::getOneBitSet(VTBitWidth, ShiftAmtC->getZExtValue()), DL, VT);
6138 SDValue NewAnd = DAG.getNode(ISD::AND, DL, VT, X, Mask);
6139 SDValue Zero = DAG.getConstant(0, DL, VT);
6140 SDValue Setcc = DAG.getSetCC(DL, CCVT, NewAnd, Zero, ISD::SETEQ);
6141 return DAG.getZExtOrTrunc(Setcc, DL, VT);
6142 }
6143
6144 /// For targets that support usubsat, match a bit-hack form of that operation
6145 /// that ends in 'and' and convert it.
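/// The bit-hack works because X s>> (BW-1) is all-ones exactly when the sign
/// bit of X is set, i.e. when X >= SignMask as an unsigned value, and in that
/// case (X ^ SignMask) and (X + SignMask) both equal X - SignMask; otherwise
/// the AND yields 0, matching the clamping behaviour of usubsat.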
6146 static SDValue foldAndToUsubsat(SDNode *N, SelectionDAG &DAG) {
6147 SDValue N0 = N->getOperand(0);
6148 SDValue N1 = N->getOperand(1);
6149 EVT VT = N1.getValueType();
6150
6151 // Canonicalize SRA as operand 1.
6152 if (N0.getOpcode() == ISD::SRA)
6153 std::swap(N0, N1);
6154
6155 // xor/add with SMIN (signmask) are logically equivalent.
6156 if (N0.getOpcode() != ISD::XOR && N0.getOpcode() != ISD::ADD)
6157 return SDValue();
6158
6159 if (N1.getOpcode() != ISD::SRA || !N0.hasOneUse() || !N1.hasOneUse() ||
6160 N0.getOperand(0) != N1.getOperand(0))
6161 return SDValue();
6162
6163 unsigned BitWidth = VT.getScalarSizeInBits();
6164 ConstantSDNode *XorC = isConstOrConstSplat(N0.getOperand(1), true);
6165 ConstantSDNode *SraC = isConstOrConstSplat(N1.getOperand(1), true);
6166 if (!XorC || !XorC->getAPIntValue().isSignMask() ||
6167 !SraC || SraC->getAPIntValue() != BitWidth - 1)
6168 return SDValue();
6169
6170 // (i8 X ^ 128) & (i8 X s>> 7) --> usubsat X, 128
6171 // (i8 X + 128) & (i8 X s>> 7) --> usubsat X, 128
6172 SDLoc DL(N);
6173 SDValue SignMask = DAG.getConstant(XorC->getAPIntValue(), DL, VT);
6174 return DAG.getNode(ISD::USUBSAT, DL, VT, N0.getOperand(0), SignMask);
6175 }
6176
6177 /// Given a bitwise logic operation N with a matching bitwise logic operand,
6178 /// fold a pattern where 2 of the source operands are identically shifted
6179 /// values. For example:
6180 /// ((X0 << Y) | Z) | (X1 << Y) --> ((X0 | X1) << Y) | Z
6181 static SDValue foldLogicOfShifts(SDNode *N, SDValue LogicOp, SDValue ShiftOp,
6182 SelectionDAG &DAG) {
6183 unsigned LogicOpcode = N->getOpcode();
6184 assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR ||
6185 LogicOpcode == ISD::XOR)
6186 && "Expected bitwise logic operation");
6187
6188 if (!LogicOp.hasOneUse() || !ShiftOp.hasOneUse())
6189 return SDValue();
6190
6191 // Match another bitwise logic op and a shift.
6192 unsigned ShiftOpcode = ShiftOp.getOpcode();
6193 if (LogicOp.getOpcode() != LogicOpcode ||
6194 !(ShiftOpcode == ISD::SHL || ShiftOpcode == ISD::SRL ||
6195 ShiftOpcode == ISD::SRA))
6196 return SDValue();
6197
6198 // Match another shift op inside the first logic operand. Handle both commuted
6199 // possibilities.
6200 // LOGIC (LOGIC (SH X0, Y), Z), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
6201 // LOGIC (LOGIC Z, (SH X0, Y)), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
6202 SDValue X1 = ShiftOp.getOperand(0);
6203 SDValue Y = ShiftOp.getOperand(1);
6204 SDValue X0, Z;
6205 if (LogicOp.getOperand(0).getOpcode() == ShiftOpcode &&
6206 LogicOp.getOperand(0).getOperand(1) == Y) {
6207 X0 = LogicOp.getOperand(0).getOperand(0);
6208 Z = LogicOp.getOperand(1);
6209 } else if (LogicOp.getOperand(1).getOpcode() == ShiftOpcode &&
6210 LogicOp.getOperand(1).getOperand(1) == Y) {
6211 X0 = LogicOp.getOperand(1).getOperand(0);
6212 Z = LogicOp.getOperand(0);
6213 } else {
6214 return SDValue();
6215 }
6216
6217 EVT VT = N->getValueType(0);
6218 SDLoc DL(N);
6219 SDValue LogicX = DAG.getNode(LogicOpcode, DL, VT, X0, X1);
6220 SDValue NewShift = DAG.getNode(ShiftOpcode, DL, VT, LogicX, Y);
6221 return DAG.getNode(LogicOpcode, DL, VT, NewShift, Z);
6222 }
6223
6224 /// Given a tree of logic operations with shape like
6225 /// (LOGIC (LOGIC (X, Y), LOGIC (Z, Y)))
6226 /// try to match and fold shift operations with the same shift amount.
6227 /// For example:
6228 /// LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W) -->
6229 /// --> LOGIC (SH (LOGIC X0, X1), Y), (LOGIC Z, W)
6230 static SDValue foldLogicTreeOfShifts(SDNode *N, SDValue LeftHand,
6231 SDValue RightHand, SelectionDAG &DAG) {
6232 unsigned LogicOpcode = N->getOpcode();
6233 assert((LogicOpcode == ISD::AND || LogicOpcode == ISD::OR ||
6234 LogicOpcode == ISD::XOR));
6235 if (LeftHand.getOpcode() != LogicOpcode ||
6236 RightHand.getOpcode() != LogicOpcode)
6237 return SDValue();
6238 if (!LeftHand.hasOneUse() || !RightHand.hasOneUse())
6239 return SDValue();
6240
6241 // Try to match one of following patterns:
6242 // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W)
6243 // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC W, (SH X1, Y))
6244 // Note that foldLogicOfShifts will handle commuted versions of the left hand
6245 // itself.
6246 SDValue CombinedShifts, W;
6247 SDValue R0 = RightHand.getOperand(0);
6248 SDValue R1 = RightHand.getOperand(1);
6249 if ((CombinedShifts = foldLogicOfShifts(N, LeftHand, R0, DAG)))
6250 W = R1;
6251 else if ((CombinedShifts = foldLogicOfShifts(N, LeftHand, R1, DAG)))
6252 W = R0;
6253 else
6254 return SDValue();
6255
6256 EVT VT = N->getValueType(0);
6257 SDLoc DL(N);
6258 return DAG.getNode(LogicOpcode, DL, VT, CombinedShifts, W);
6259 }
6260
6261 SDValue DAGCombiner::visitAND(SDNode *N) {
6262 SDValue N0 = N->getOperand(0);
6263 SDValue N1 = N->getOperand(1);
6264 EVT VT = N1.getValueType();
6265
6266 // x & x --> x
6267 if (N0 == N1)
6268 return N0;
6269
6270 // fold (and c1, c2) -> c1&c2
6271 if (SDValue C = DAG.FoldConstantArithmetic(ISD::AND, SDLoc(N), VT, {N0, N1}))
6272 return C;
6273
6274 // canonicalize constant to RHS
6275 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
6276 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
6277 return DAG.getNode(ISD::AND, SDLoc(N), VT, N1, N0);
6278
6279 // fold vector ops
6280 if (VT.isVector()) {
6281 if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
6282 return FoldedVOp;
6283
6284 // fold (and x, 0) -> 0, vector edition
6285 if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
6286 // do not return N1, because an undef node may exist in N1
6287 return DAG.getConstant(APInt::getZero(N1.getScalarValueSizeInBits()),
6288 SDLoc(N), N1.getValueType());
6289
6290 // fold (and x, -1) -> x, vector edition
6291 if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
6292 return N0;
6293
6294 // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
6295 auto *MLoad = dyn_cast<MaskedLoadSDNode>(N0);
6296 ConstantSDNode *Splat = isConstOrConstSplat(N1, true, true);
6297 if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat &&
6298 N1.hasOneUse()) {
6299 EVT LoadVT = MLoad->getMemoryVT();
6300 EVT ExtVT = VT;
6301 if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) {
6302 // For this AND to be a zero extension of the masked load the elements
6303 // of the BuildVec must mask the bottom bits of the extended element
6304 // type
6305 uint64_t ElementSize =
6306 LoadVT.getVectorElementType().getScalarSizeInBits();
6307 if (Splat->getAPIntValue().isMask(ElementSize)) {
6308 auto NewLoad = DAG.getMaskedLoad(
6309 ExtVT, SDLoc(N), MLoad->getChain(), MLoad->getBasePtr(),
6310 MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(),
6311 LoadVT, MLoad->getMemOperand(), MLoad->getAddressingMode(),
6312 ISD::ZEXTLOAD, MLoad->isExpandingLoad());
6313 bool LoadHasOtherUsers = !N0.hasOneUse();
6314 CombineTo(N, NewLoad);
6315 if (LoadHasOtherUsers)
6316 CombineTo(MLoad, NewLoad.getValue(0), NewLoad.getValue(1));
6317 return SDValue(N, 0);
6318 }
6319 }
6320 }
6321 }
6322
6323 // fold (and x, -1) -> x
6324 if (isAllOnesConstant(N1))
6325 return N0;
6326
6327 // if (and x, c) is known to be zero, return 0
6328 unsigned BitWidth = VT.getScalarSizeInBits();
6329 ConstantSDNode *N1C = isConstOrConstSplat(N1);
6330 if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(BitWidth)))
6331 return DAG.getConstant(0, SDLoc(N), VT);
6332
6333 if (SDValue NewSel = foldBinOpIntoSelect(N))
6334 return NewSel;
6335
6336 // reassociate and
6337 if (SDValue RAND = reassociateOps(ISD::AND, SDLoc(N), N0, N1, N->getFlags()))
6338 return RAND;
6339
6340 // fold (and (or x, C), D) -> D if (C & D) == D
6341 auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
6342 return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue());
6343 };
6344 if (N0.getOpcode() == ISD::OR &&
6345 ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset))
6346 return N1;
6347
6348 // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
6349 if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
6350 SDValue N0Op0 = N0.getOperand(0);
6351 APInt Mask = ~N1C->getAPIntValue();
6352 Mask = Mask.trunc(N0Op0.getScalarValueSizeInBits());
6353 if (DAG.MaskedValueIsZero(N0Op0, Mask))
6354 return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N0.getValueType(), N0Op0);
6355 }
6356
6357 // fold (and (ext (and V, c1)), c2) -> (and (ext V), (and c1, (ext c2)))
6358 if (ISD::isExtOpcode(N0.getOpcode())) {
6359 unsigned ExtOpc = N0.getOpcode();
6360 SDValue N0Op0 = N0.getOperand(0);
6361 if (N0Op0.getOpcode() == ISD::AND &&
6362 (ExtOpc != ISD::ZERO_EXTEND || !TLI.isZExtFree(N0Op0, VT)) &&
6363 DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
6364 DAG.isConstantIntBuildVectorOrConstantInt(N0Op0.getOperand(1)) &&
6365 N0->hasOneUse() && N0Op0->hasOneUse()) {
6366 SDLoc DL(N);
6367 SDValue NewMask =
6368 DAG.getNode(ISD::AND, DL, VT, N1,
6369 DAG.getNode(ExtOpc, DL, VT, N0Op0.getOperand(1)));
6370 return DAG.getNode(ISD::AND, DL, VT,
6371 DAG.getNode(ExtOpc, DL, VT, N0Op0.getOperand(0)),
6372 NewMask);
6373 }
6374 }
6375
6376 // similarly fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
6377 // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
6378 // already be zero by virtue of the width of the base type of the load.
6379 //
6380 // The 'X' node here can either be nothing or an extract_vector_elt to catch
6381 // more cases.
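// For example, (and (any-extload i8 -> i32, p), 0xFF) can drop the AND and
// use a zero-extending load instead, since an i8 ZEXTLOAD guarantees that
// bits 8-31 are zero (provided the target supports that ZEXTLOAD).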
6382 if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
6383 N0.getValueSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits() &&
6384 N0.getOperand(0).getOpcode() == ISD::LOAD &&
6385 N0.getOperand(0).getResNo() == 0) ||
6386 (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
6387 LoadSDNode *Load = cast<LoadSDNode>( (N0.getOpcode() == ISD::LOAD) ?
6388 N0 : N0.getOperand(0) );
6389
6390 // Get the constant (if applicable) the zero'th operand is being ANDed with.
6391 // This can be a pure constant or a vector splat, in which case we treat the
6392 // vector as a scalar and use the splat value.
6393 APInt Constant = APInt::getZero(1);
6394 if (const ConstantSDNode *C = isConstOrConstSplat(
6395 N1, /*AllowUndef=*/false, /*AllowTruncation=*/true)) {
6396 Constant = C->getAPIntValue();
6397 } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) {
6398 APInt SplatValue, SplatUndef;
6399 unsigned SplatBitSize;
6400 bool HasAnyUndefs;
6401 bool IsSplat = Vector->isConstantSplat(SplatValue, SplatUndef,
6402 SplatBitSize, HasAnyUndefs);
6403 if (IsSplat) {
6404 // Undef bits can contribute to a possible optimisation if set, so
6405 // set them.
6406 SplatValue |= SplatUndef;
6407
6408 // The splat value may be something like "0x00FFFFFF", which means 0 for
6409 // the first vector value and FF for the rest, repeating. We need a mask
6410 // that will apply equally to all members of the vector, so AND all the
6411 // lanes of the constant together.
6412 unsigned EltBitWidth = Vector->getValueType(0).getScalarSizeInBits();
6413
6414 // If the splat value has been compressed to a bitlength lower
6415 // than the size of the vector lane, we need to re-expand it to
6416 // the lane size.
6417 if (EltBitWidth > SplatBitSize)
6418 for (SplatValue = SplatValue.zextOrTrunc(EltBitWidth);
6419 SplatBitSize < EltBitWidth; SplatBitSize = SplatBitSize * 2)
6420 SplatValue |= SplatValue.shl(SplatBitSize);
6421
6422 // Make sure that variable 'Constant' is only set if 'SplatBitSize' is a
6423 // multiple of 'EltBitWidth'. Otherwise, we could propagate a wrong value.
6424 if ((SplatBitSize % EltBitWidth) == 0) {
6425 Constant = APInt::getAllOnes(EltBitWidth);
6426 for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
6427 Constant &= SplatValue.extractBits(EltBitWidth, i * EltBitWidth);
6428 }
6429 }
6430 }
6431
6432 // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
6433 // actually legal and isn't going to get expanded, else this is a false
6434 // optimisation.
6435 bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD,
6436 Load->getValueType(0),
6437 Load->getMemoryVT());
6438
6439 // Resize the constant to the same size as the original memory access before
6440 // extension. If it is still the AllOnesValue then this AND is completely
6441 // unneeded.
6442 Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits());
6443
6444 bool B;
6445 switch (Load->getExtensionType()) {
6446 default: B = false; break;
6447 case ISD::EXTLOAD: B = CanZextLoadProfitably; break;
6448 case ISD::ZEXTLOAD:
6449 case ISD::NON_EXTLOAD: B = true; break;
6450 }
6451
6452 if (B && Constant.isAllOnes()) {
6453 // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
6454 // preserve semantics once we get rid of the AND.
6455 SDValue NewLoad(Load, 0);
6456
6457 // Fold the AND away. NewLoad may get replaced immediately.
6458 CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);
6459
6460 if (Load->getExtensionType() == ISD::EXTLOAD) {
6461 NewLoad = DAG.getLoad(Load->getAddressingMode(), ISD::ZEXTLOAD,
6462 Load->getValueType(0), SDLoc(Load),
6463 Load->getChain(), Load->getBasePtr(),
6464 Load->getOffset(), Load->getMemoryVT(),
6465 Load->getMemOperand());
6466 // Replace uses of the EXTLOAD with the new ZEXTLOAD.
6467 if (Load->getNumValues() == 3) {
6468 // PRE/POST_INC loads have 3 values.
6469 SDValue To[] = { NewLoad.getValue(0), NewLoad.getValue(1),
6470 NewLoad.getValue(2) };
6471 CombineTo(Load, To, 3, true);
6472 } else {
6473 CombineTo(Load, NewLoad.getValue(0), NewLoad.getValue(1));
6474 }
6475 }
6476
6477 return SDValue(N, 0); // Return N so it doesn't get rechecked!
6478 }
6479 }
6480
6481 // Try to convert a constant mask AND into a shuffle clear mask.
6482 if (VT.isVector())
6483 if (SDValue Shuffle = XformToShuffleWithZero(N))
6484 return Shuffle;
6485
6486 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
6487 return Combined;
6488
6489 if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() && N1C &&
6490 ISD::isExtOpcode(N0.getOperand(0).getOpcode())) {
6491 SDValue Ext = N0.getOperand(0);
6492 EVT ExtVT = Ext->getValueType(0);
6493 SDValue Extendee = Ext->getOperand(0);
6494
6495 unsigned ScalarWidth = Extendee.getValueType().getScalarSizeInBits();
6496 if (N1C->getAPIntValue().isMask(ScalarWidth) &&
6497 (!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, ExtVT))) {
6498 // (and (extract_subvector (zext|anyext|sext v) _) iN_mask)
6499 // => (extract_subvector (iN_zeroext v))
6500 SDValue ZeroExtExtendee =
6501 DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), ExtVT, Extendee);
6502
6503 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), VT, ZeroExtExtendee,
6504 N0.getOperand(1));
6505 }
6506 }
6507
6508 // fold (and (masked_gather x)) -> (zext_masked_gather x)
6509 if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
6510 EVT MemVT = GN0->getMemoryVT();
6511 EVT ScalarVT = MemVT.getScalarType();
6512
6513 if (SDValue(GN0, 0).hasOneUse() &&
6514 isConstantSplatVectorMaskForType(N1.getNode(), ScalarVT) &&
6515 TLI.isVectorLoadExtDesirable(SDValue(SDValue(GN0, 0)))) {
6516 SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
6517 GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};
6518
6519 SDValue ZExtLoad = DAG.getMaskedGather(
6520 DAG.getVTList(VT, MVT::Other), MemVT, SDLoc(N), Ops,
6521 GN0->getMemOperand(), GN0->getIndexType(), ISD::ZEXTLOAD);
6522
6523 CombineTo(N, ZExtLoad);
6524 AddToWorklist(ZExtLoad.getNode());
6525 // Avoid recheck of N.
6526 return SDValue(N, 0);
6527 }
6528 }
6529
6530 // fold (and (load x), 255) -> (zextload x, i8)
6531 // fold (and (extload x, i16), 255) -> (zextload x, i8)
6532 if (N1C && N0.getOpcode() == ISD::LOAD && !VT.isVector())
6533 if (SDValue Res = reduceLoadWidth(N))
6534 return Res;
6535
6536 if (LegalTypes) {
6537 // Attempt to propagate the AND back up to the leaves which, if they're
6538 // loads, can be combined to narrow loads and the AND node can be removed.
6539 // Perform after legalization so that extend nodes will already be
6540 // combined into the loads.
6541 if (BackwardsPropagateMask(N))
6542 return SDValue(N, 0);
6543 }
6544
6545 if (SDValue Combined = visitANDLike(N0, N1, N))
6546 return Combined;
6547
6548 // Simplify: (and (op x...), (op y...)) -> (op (and x, y))
6549 if (N0.getOpcode() == N1.getOpcode())
6550 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
6551 return V;
6552
6553 if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
6554 return R;
6555 if (SDValue R = foldLogicOfShifts(N, N1, N0, DAG))
6556 return R;
6557
6558 // Masking the negated extension of a boolean is just the zero-extended
6559 // boolean:
6560 // and (sub 0, zext(bool X)), 1 --> zext(bool X)
6561 // and (sub 0, sext(bool X)), 1 --> zext(bool X)
6562 //
6563 // Note: the SimplifyDemandedBits fold below can make an information-losing
6564 // transform, and then we have no way to find this better fold.
6565 if (N1C && N1C->isOne() && N0.getOpcode() == ISD::SUB) {
6566 if (isNullOrNullSplat(N0.getOperand(0))) {
6567 SDValue SubRHS = N0.getOperand(1);
6568 if (SubRHS.getOpcode() == ISD::ZERO_EXTEND &&
6569 SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
6570 return SubRHS;
6571 if (SubRHS.getOpcode() == ISD::SIGN_EXTEND &&
6572 SubRHS.getOperand(0).getScalarValueSizeInBits() == 1)
6573 return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, SubRHS.getOperand(0));
6574 }
6575 }
6576
6577 // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
6578 // fold (and (sra)) -> (and (srl)) when possible.
6579 if (SimplifyDemandedBits(SDValue(N, 0)))
6580 return SDValue(N, 0);
6581
6582 // fold (zext_inreg (extload x)) -> (zextload x)
6583 // fold (zext_inreg (sextload x)) -> (zextload x) iff load has one use
6584 if (ISD::isUNINDEXEDLoad(N0.getNode()) &&
6585 (ISD::isEXTLoad(N0.getNode()) ||
6586 (ISD::isSEXTLoad(N0.getNode()) && N0.hasOneUse()))) {
6587 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
6588 EVT MemVT = LN0->getMemoryVT();
6589 // If we zero all the possible extended bits, then we can turn this into
6590 // a zextload if we are running before legalize or the operation is legal.
6591 unsigned ExtBitSize = N1.getScalarValueSizeInBits();
6592 unsigned MemBitSize = MemVT.getScalarSizeInBits();
6593 APInt ExtBits = APInt::getHighBitsSet(ExtBitSize, ExtBitSize - MemBitSize);
6594 if (DAG.MaskedValueIsZero(N1, ExtBits) &&
6595 ((!LegalOperations && LN0->isSimple()) ||
6596 TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT))) {
6597 SDValue ExtLoad =
6598 DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT, LN0->getChain(),
6599 LN0->getBasePtr(), MemVT, LN0->getMemOperand());
6600 AddToWorklist(N);
6601 CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
6602 return SDValue(N, 0); // Return N so it doesn't get rechecked!
6603 }
6604 }
6605
6606 // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
6607 if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
6608 if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
6609 N0.getOperand(1), false))
6610 return BSwap;
6611 }
6612
6613 if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N))
6614 return Shifts;
6615
6616 if (SDValue V = combineShiftAnd1ToBitTest(N, DAG))
6617 return V;
6618
6619 // Recognize the following pattern:
6620 //
6621 // AndVT = (and (sign_extend NarrowVT to AndVT) #bitmask)
6622 //
6623 // where bitmask is a mask that clears the upper bits of AndVT. The number
6624 // of set bits in bitmask must equal the bit width of the narrow source type.
6625 auto IsAndZeroExtMask = [](SDValue LHS, SDValue RHS) {
6626 if (LHS->getOpcode() != ISD::SIGN_EXTEND)
6627 return false;
6628
6629 auto *C = dyn_cast<ConstantSDNode>(RHS);
6630 if (!C)
6631 return false;
6632
6633 if (!C->getAPIntValue().isMask(
6634 LHS.getOperand(0).getValueType().getFixedSizeInBits()))
6635 return false;
6636
6637 return true;
6638 };
6639
6640 // Replace (and (sign_extend ...) #bitmask) with (zero_extend ...).
6641 if (IsAndZeroExtMask(N0, N1))
6642 return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, N0.getOperand(0));
6643
6644 if (hasOperation(ISD::USUBSAT, VT))
6645 if (SDValue V = foldAndToUsubsat(N, DAG))
6646 return V;
6647
6648 // Postpone until legalization has completed to avoid interference with
6649 // bswap folding.
6650 if (LegalOperations || VT.isVector())
6651 if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
6652 return R;
6653
6654 return SDValue();
6655 }
6656
6657 /// Match (a >> 8) | (a << 8) as (bswap a) >> 16.
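/// The shifts may be wrapped in AND masks such as 0xff/0xff00 on either side;
/// for types wider than i16 the high bits must be known to be masked or zero,
/// since the final (bswap a) >> 16 leaves only the swapped low halfword.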
6658 SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
6659 bool DemandHighBits) {
6660 if (!LegalOperations)
6661 return SDValue();
6662
6663 EVT VT = N->getValueType(0);
6664 if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16)
6665 return SDValue();
6666 if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
6667 return SDValue();
6668
6669 // Recognize (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff)
6670 bool LookPassAnd0 = false;
6671 bool LookPassAnd1 = false;
6672 if (N0.getOpcode() == ISD::AND && N0.getOperand(0).getOpcode() == ISD::SRL)
6673 std::swap(N0, N1);
6674 if (N1.getOpcode() == ISD::AND && N1.getOperand(0).getOpcode() == ISD::SHL)
6675 std::swap(N0, N1);
6676 if (N0.getOpcode() == ISD::AND) {
6677 if (!N0->hasOneUse())
6678 return SDValue();
6679 ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
6680 // Also handle 0xffff since the LHS is guaranteed to have zeros there.
6681 // This is needed for X86.
6682 if (!N01C || (N01C->getZExtValue() != 0xFF00 &&
6683 N01C->getZExtValue() != 0xFFFF))
6684 return SDValue();
6685 N0 = N0.getOperand(0);
6686 LookPassAnd0 = true;
6687 }
6688
6689 if (N1.getOpcode() == ISD::AND) {
6690 if (!N1->hasOneUse())
6691 return SDValue();
6692 ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
6693 if (!N11C || N11C->getZExtValue() != 0xFF)
6694 return SDValue();
6695 N1 = N1.getOperand(0);
6696 LookPassAnd1 = true;
6697 }
6698
6699 if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
6700 std::swap(N0, N1);
6701 if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL)
6702 return SDValue();
6703 if (!N0->hasOneUse() || !N1->hasOneUse())
6704 return SDValue();
6705
6706 ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
6707 ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
6708 if (!N01C || !N11C)
6709 return SDValue();
6710 if (N01C->getZExtValue() != 8 || N11C->getZExtValue() != 8)
6711 return SDValue();
6712
6713 // Look for (shl (and a, 0xff), 8), (srl (and a, 0xff00), 8)
6714 SDValue N00 = N0->getOperand(0);
6715 if (!LookPassAnd0 && N00.getOpcode() == ISD::AND) {
6716 if (!N00->hasOneUse())
6717 return SDValue();
6718 ConstantSDNode *N001C = dyn_cast<ConstantSDNode>(N00.getOperand(1));
6719 if (!N001C || N001C->getZExtValue() != 0xFF)
6720 return SDValue();
6721 N00 = N00.getOperand(0);
6722 LookPassAnd0 = true;
6723 }
6724
6725 SDValue N10 = N1->getOperand(0);
6726 if (!LookPassAnd1 && N10.getOpcode() == ISD::AND) {
6727 if (!N10->hasOneUse())
6728 return SDValue();
6729 ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(N10.getOperand(1));
6730 // Also allow 0xFFFF since the bits will be shifted out. This is needed
6731 // for X86.
6732 if (!N101C || (N101C->getZExtValue() != 0xFF00 &&
6733 N101C->getZExtValue() != 0xFFFF))
6734 return SDValue();
6735 N10 = N10.getOperand(0);
6736 LookPassAnd1 = true;
6737 }
6738
6739 if (N00 != N10)
6740 return SDValue();
6741
6742 // Make sure everything beyond the low halfword gets set to zero since the SRL
6743 // 16 will clear the top bits.
6744 unsigned OpSizeInBits = VT.getSizeInBits();
6745 if (OpSizeInBits > 16) {
6746 // If the left-shift isn't masked out then the only way this is a bswap is
6747 // if all bits beyond the low 8 are 0. In that case the entire pattern
6748 // reduces to a left shift anyway: leave it for other parts of the combiner.
6749 if (DemandHighBits && !LookPassAnd0)
6750 return SDValue();
6751
6752 // However, if the right shift isn't masked out then it might be because
6753 // it's not needed. See if we can spot that too. If the high bits aren't
6754 // demanded, we only need bits 23:16 to be zero. Otherwise, we need all
6755 // upper bits to be zero.
6756 if (!LookPassAnd1) {
6757 unsigned HighBit = DemandHighBits ? OpSizeInBits : 24;
6758 if (!DAG.MaskedValueIsZero(N10,
6759 APInt::getBitsSet(OpSizeInBits, 16, HighBit)))
6760 return SDValue();
6761 }
6762 }
6763
6764 SDValue Res = DAG.getNode(ISD::BSWAP, SDLoc(N), VT, N00);
6765 if (OpSizeInBits > 16) {
6766 SDLoc DL(N);
6767 Res = DAG.getNode(ISD::SRL, DL, VT, Res,
6768 DAG.getConstant(OpSizeInBits - 16, DL,
6769 getShiftAmountTy(VT)));
6770 }
6771 return Res;
6772 }
6773
6774 /// Return true if the specified node is an element that makes up a 32-bit
6775 /// packed halfword byteswap.
6776 /// ((x & 0x000000ff) << 8) |
6777 /// ((x & 0x0000ff00) >> 8) |
6778 /// ((x & 0x00ff0000) << 8) |
6779 /// ((x & 0xff000000) >> 8)
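/// Each element is an AND/SHL/SRL that isolates one byte and moves it within
/// its 16-bit half; the source node for the byte at offset MaskByteOffset is
/// recorded in Parts so the caller can check all four come from the same value.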
6780 static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
6781 if (!N->hasOneUse())
6782 return false;
6783
6784 unsigned Opc = N.getOpcode();
6785 if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL)
6786 return false;
6787
6788 SDValue N0 = N.getOperand(0);
6789 unsigned Opc0 = N0.getOpcode();
6790 if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL)
6791 return false;
6792
6793 ConstantSDNode *N1C = nullptr;
6794 // SHL or SRL: look upstream for AND mask operand
6795 if (Opc == ISD::AND)
6796 N1C = dyn_cast<ConstantSDNode>(N.getOperand(1));
6797 else if (Opc0 == ISD::AND)
6798 N1C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
6799 if (!N1C)
6800 return false;
6801
6802 unsigned MaskByteOffset;
6803 switch (N1C->getZExtValue()) {
6804 default:
6805 return false;
6806 case 0xFF: MaskByteOffset = 0; break;
6807 case 0xFF00: MaskByteOffset = 1; break;
6808 case 0xFFFF:
6809 // In case demanded bits didn't clear the bits that will be shifted out.
6810 // This is needed for X86.
6811 if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) {
6812 MaskByteOffset = 1;
6813 break;
6814 }
6815 return false;
6816 case 0xFF0000: MaskByteOffset = 2; break;
6817 case 0xFF000000: MaskByteOffset = 3; break;
6818 }
6819
6820 // Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00).
6821 if (Opc == ISD::AND) {
6822 if (MaskByteOffset == 0 || MaskByteOffset == 2) {
6823 // (x >> 8) & 0xff
6824 // (x >> 8) & 0xff0000
6825 if (Opc0 != ISD::SRL)
6826 return false;
6827 ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
6828 if (!C || C->getZExtValue() != 8)
6829 return false;
6830 } else {
6831 // (x << 8) & 0xff00
6832 // (x << 8) & 0xff000000
6833 if (Opc0 != ISD::SHL)
6834 return false;
6835 ConstantSDNode *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
6836 if (!C || C->getZExtValue() != 8)
6837 return false;
6838 }
6839 } else if (Opc == ISD::SHL) {
6840 // (x & 0xff) << 8
6841 // (x & 0xff0000) << 8
6842 if (MaskByteOffset != 0 && MaskByteOffset != 2)
6843 return false;
6844 ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
6845 if (!C || C->getZExtValue() != 8)
6846 return false;
6847 } else { // Opc == ISD::SRL
6848 // (x & 0xff00) >> 8
6849 // (x & 0xff000000) >> 8
6850 if (MaskByteOffset != 1 && MaskByteOffset != 3)
6851 return false;
6852 ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(1));
6853 if (!C || C->getZExtValue() != 8)
6854 return false;
6855 }
6856
6857 if (Parts[MaskByteOffset])
6858 return false;
6859
6860 Parts[MaskByteOffset] = N0.getOperand(0).getNode();
6861 return true;
6862 }
6863
6864 // Match 2 elements of a packed halfword bswap.
6865 static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
6866 if (N.getOpcode() == ISD::OR)
6867 return isBSwapHWordElement(N.getOperand(0), Parts) &&
6868 isBSwapHWordElement(N.getOperand(1), Parts);
6869
6870 if (N.getOpcode() == ISD::SRL && N.getOperand(0).getOpcode() == ISD::BSWAP) {
6871 ConstantSDNode *C = isConstOrConstSplat(N.getOperand(1));
6872 if (!C || C->getAPIntValue() != 16)
6873 return false;
6874 Parts[0] = Parts[1] = N.getOperand(0).getOperand(0).getNode();
6875 return true;
6876 }
6877
6878 return false;
6879 }
6880
6881 // Match this pattern:
6882 // (or (and (shl (A, 8)), 0xff00ff00), (and (srl (A, 8)), 0x00ff00ff))
6883 // And rewrite this to:
6884 // (rotr (bswap A), 16)
6885 static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
6886 SelectionDAG &DAG, SDNode *N, SDValue N0,
6887 SDValue N1, EVT VT, EVT ShiftAmountTy) {
6888 assert(N->getOpcode() == ISD::OR && VT == MVT::i32 &&
6889 "MatchBSwapHWordOrAndAnd: expecting i32");
6890 if (!TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
6891 return SDValue();
6892 if (N0.getOpcode() != ISD::AND || N1.getOpcode() != ISD::AND)
6893 return SDValue();
6894 // TODO: this is too restrictive; lifting this restriction requires more tests
6895 if (!N0->hasOneUse() || !N1->hasOneUse())
6896 return SDValue();
6897 ConstantSDNode *Mask0 = isConstOrConstSplat(N0.getOperand(1));
6898 ConstantSDNode *Mask1 = isConstOrConstSplat(N1.getOperand(1));
6899 if (!Mask0 || !Mask1)
6900 return SDValue();
6901 if (Mask0->getAPIntValue() != 0xff00ff00 ||
6902 Mask1->getAPIntValue() != 0x00ff00ff)
6903 return SDValue();
6904 SDValue Shift0 = N0.getOperand(0);
6905 SDValue Shift1 = N1.getOperand(0);
6906 if (Shift0.getOpcode() != ISD::SHL || Shift1.getOpcode() != ISD::SRL)
6907 return SDValue();
6908 ConstantSDNode *ShiftAmt0 = isConstOrConstSplat(Shift0.getOperand(1));
6909 ConstantSDNode *ShiftAmt1 = isConstOrConstSplat(Shift1.getOperand(1));
6910 if (!ShiftAmt0 || !ShiftAmt1)
6911 return SDValue();
6912 if (ShiftAmt0->getAPIntValue() != 8 || ShiftAmt1->getAPIntValue() != 8)
6913 return SDValue();
6914 if (Shift0.getOperand(0) != Shift1.getOperand(0))
6915 return SDValue();
6916
6917 SDLoc DL(N);
6918 SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, Shift0.getOperand(0));
6919 SDValue ShAmt = DAG.getConstant(16, DL, ShiftAmountTy);
6920 return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
6921 }
6922
6923 /// Match a 32-bit packed halfword bswap. That is
6924 /// ((x & 0x000000ff) << 8) |
6925 /// ((x & 0x0000ff00) >> 8) |
6926 /// ((x & 0x00ff0000) << 8) |
6927 /// ((x & 0xff000000) >> 8)
6928 /// => (rotl (bswap x), 16)
6929 SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
6930 if (!LegalOperations)
6931 return SDValue();
6932
6933 EVT VT = N->getValueType(0);
6934 if (VT != MVT::i32)
6935 return SDValue();
6936 if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
6937 return SDValue();
6938
6939 if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT,
6940 getShiftAmountTy(VT)))
6941 return BSwap;
6942
6943 // Try again with commuted operands.
6944 if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N1, N0, VT,
6945 getShiftAmountTy(VT)))
6946 return BSwap;
6947
6948
6949 // Look for either
6950 // (or (bswaphpair), (bswaphpair))
6951 // (or (or (bswaphpair), (and)), (and))
6952 // (or (or (and), (bswaphpair)), (and))
6953 SDNode *Parts[4] = {};
6954
6955 if (isBSwapHWordPair(N0, Parts)) {
6956 // (or (or (and), (and)), (or (and), (and)))
6957 if (!isBSwapHWordPair(N1, Parts))
6958 return SDValue();
6959 } else if (N0.getOpcode() == ISD::OR) {
6960 // (or (or (or (and), (and)), (and)), (and))
6961 if (!isBSwapHWordElement(N1, Parts))
6962 return SDValue();
6963 SDValue N00 = N0.getOperand(0);
6964 SDValue N01 = N0.getOperand(1);
6965 if (!(isBSwapHWordElement(N01, Parts) && isBSwapHWordPair(N00, Parts)) &&
6966 !(isBSwapHWordElement(N00, Parts) && isBSwapHWordPair(N01, Parts)))
6967 return SDValue();
6968 } else {
6969 return SDValue();
6970 }
6971
6972 // Make sure the parts are all coming from the same node.
6973 if (Parts[0] != Parts[1] || Parts[0] != Parts[2] || Parts[0] != Parts[3])
6974 return SDValue();
6975
6976 SDLoc DL(N);
6977 SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT,
6978 SDValue(Parts[0], 0));
6979
6980 // Result of the bswap should be rotated by 16. If it's not legal, then
6981 // do (x << 16) | (x >> 16).
6982 SDValue ShAmt = DAG.getConstant(16, DL, getShiftAmountTy(VT));
6983 if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT))
6984 return DAG.getNode(ISD::ROTL, DL, VT, BSwap, ShAmt);
6985 if (TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
6986 return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
6987 return DAG.getNode(ISD::OR, DL, VT,
6988 DAG.getNode(ISD::SHL, DL, VT, BSwap, ShAmt),
6989 DAG.getNode(ISD::SRL, DL, VT, BSwap, ShAmt));
6990 }
6991
6992 /// This contains all DAGCombine rules which reduce two values combined by
6993 /// an Or operation to a single value \see visitANDLike().
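/// For example, (or (and X, 0xF0), (and Y, 0x0F)) becomes (and (or X, Y), 0xFF)
/// when the low nibble of X and the high nibble of Y are already known zero.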
6994 SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *N) {
6995 EVT VT = N1.getValueType();
6996 SDLoc DL(N);
6997
6998 // fold (or x, undef) -> -1
6999 if (!LegalOperations && (N0.isUndef() || N1.isUndef()))
7000 return DAG.getAllOnesConstant(DL, VT);
7001
7002 if (SDValue V = foldLogicOfSetCCs(false, N0, N1, DL))
7003 return V;
7004
7005 // (or (and X, C1), (and Y, C2)) -> (and (or X, Y), C3) if possible.
7006 if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
7007 // Don't increase # computations.
7008 (N0->hasOneUse() || N1->hasOneUse())) {
7009 // We can only do this xform if we know that bits from X that are set in C2
7010 // but not in C1 are already zero. Likewise for Y.
7011 if (const ConstantSDNode *N0O1C =
7012 getAsNonOpaqueConstant(N0.getOperand(1))) {
7013 if (const ConstantSDNode *N1O1C =
7014 getAsNonOpaqueConstant(N1.getOperand(1))) {
7017 const APInt &LHSMask = N0O1C->getAPIntValue();
7018 const APInt &RHSMask = N1O1C->getAPIntValue();
7019
7020 if (DAG.MaskedValueIsZero(N0.getOperand(0), RHSMask&~LHSMask) &&
7021 DAG.MaskedValueIsZero(N1.getOperand(0), LHSMask&~RHSMask)) {
7022 SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
7023 N0.getOperand(0), N1.getOperand(0));
7024 return DAG.getNode(ISD::AND, DL, VT, X,
7025 DAG.getConstant(LHSMask | RHSMask, DL, VT));
7026 }
7027 }
7028 }
7029 }
7030
7031 // (or (and X, M), (and X, N)) -> (and X, (or M, N))
7032 if (N0.getOpcode() == ISD::AND &&
7033 N1.getOpcode() == ISD::AND &&
7034 N0.getOperand(0) == N1.getOperand(0) &&
7035 // Don't increase # computations.
7036 (N0->hasOneUse() || N1->hasOneUse())) {
7037 SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT,
7038 N0.getOperand(1), N1.getOperand(1));
7039 return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), X);
7040 }
7041
7042 return SDValue();
7043 }
7044
7045 /// OR combines for which the commuted variant will be tried as well.
7046 static SDValue visitORCommutative(SelectionDAG &DAG, SDValue N0, SDValue N1,
7047 SDNode *N) {
7048 EVT VT = N0.getValueType();
7049 if (N0.getOpcode() == ISD::AND) {
7050 SDValue N00 = N0.getOperand(0);
7051 SDValue N01 = N0.getOperand(1);
7052
7053 // fold or (and x, y), x --> x
7054 if (N00 == N1 || N01 == N1)
7055 return N1;
7056
7057 // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y)
7058 // TODO: Set AllowUndefs = true.
7059 if (getBitwiseNotOperand(N01, N00,
7060 /* AllowUndefs */ false) == N1)
7061 return DAG.getNode(ISD::OR, SDLoc(N), VT, N00, N1);
7062
7063 // fold (or (and (xor Y, -1), X), Y) -> (or X, Y)
7064 if (getBitwiseNotOperand(N00, N01,
7065 /* AllowUndefs */ false) == N1)
7066 return DAG.getNode(ISD::OR, SDLoc(N), VT, N01, N1);
7067 }
7068
7069 if (N0.getOpcode() == ISD::XOR) {
7070 // fold or (xor x, y), x --> or x, y
7071 // or (xor x, y), (x and/or y) --> or x, y
7072 SDValue N00 = N0.getOperand(0);
7073 SDValue N01 = N0.getOperand(1);
7074 if (N00 == N1)
7075 return DAG.getNode(ISD::OR, SDLoc(N), VT, N01, N1);
7076 if (N01 == N1)
7077 return DAG.getNode(ISD::OR, SDLoc(N), VT, N00, N1);
7078
7079 if (N1.getOpcode() == ISD::AND || N1.getOpcode() == ISD::OR) {
7080 SDValue N10 = N1.getOperand(0);
7081 SDValue N11 = N1.getOperand(1);
7082 if ((N00 == N10 && N01 == N11) || (N00 == N11 && N01 == N10))
7083 return DAG.getNode(ISD::OR, SDLoc(N), VT, N00, N01);
7084 }
7085 }
7086
7087 if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
7088 return R;
7089
7090 auto peekThroughZext = [](SDValue V) {
7091 if (V->getOpcode() == ISD::ZERO_EXTEND)
7092 return V->getOperand(0);
7093 return V;
7094 };
7095
7096 // (fshl X, ?, Y) | (shl X, Y) --> fshl X, ?, Y
7097 if (N0.getOpcode() == ISD::FSHL && N1.getOpcode() == ISD::SHL &&
7098 N0.getOperand(0) == N1.getOperand(0) &&
7099 peekThroughZext(N0.getOperand(2)) == peekThroughZext(N1.getOperand(1)))
7100 return N0;
7101
7102 // (fshr ?, X, Y) | (srl X, Y) --> fshr ?, X, Y
7103 if (N0.getOpcode() == ISD::FSHR && N1.getOpcode() == ISD::SRL &&
7104 N0.getOperand(1) == N1.getOperand(0) &&
7105 peekThroughZext(N0.getOperand(2)) == peekThroughZext(N1.getOperand(1)))
7106 return N0;
7107
7108 return SDValue();
7109 }
7110
7111 SDValue DAGCombiner::visitOR(SDNode *N) {
7112 SDValue N0 = N->getOperand(0);
7113 SDValue N1 = N->getOperand(1);
7114 EVT VT = N1.getValueType();
7115
7116 // x | x --> x
7117 if (N0 == N1)
7118 return N0;
7119
7120 // fold (or c1, c2) -> c1|c2
7121 if (SDValue C = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N), VT, {N0, N1}))
7122 return C;
7123
7124 // canonicalize constant to RHS
7125 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
7126 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
7127 return DAG.getNode(ISD::OR, SDLoc(N), VT, N1, N0);
7128
7129 // fold vector ops
7130 if (VT.isVector()) {
7131 if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
7132 return FoldedVOp;
7133
7134 // fold (or x, 0) -> x, vector edition
7135 if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
7136 return N0;
7137
7138 // fold (or x, -1) -> -1, vector edition
7139 if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
7140 // do not return N1, because an undef node may exist in N1
7141 return DAG.getAllOnesConstant(SDLoc(N), N1.getValueType());
7142
7143 // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
7144 // Do this only if the resulting type / shuffle is legal.
7145 auto *SV0 = dyn_cast<ShuffleVectorSDNode>(N0);
7146 auto *SV1 = dyn_cast<ShuffleVectorSDNode>(N1);
7147 if (SV0 && SV1 && TLI.isTypeLegal(VT)) {
7148 bool ZeroN00 = ISD::isBuildVectorAllZeros(N0.getOperand(0).getNode());
7149 bool ZeroN01 = ISD::isBuildVectorAllZeros(N0.getOperand(1).getNode());
7150 bool ZeroN10 = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
7151 bool ZeroN11 = ISD::isBuildVectorAllZeros(N1.getOperand(1).getNode());
7152 // Ensure both shuffles have a zero input.
7153 if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
7154 assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
7155 assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
7156 bool CanFold = true;
7157 int NumElts = VT.getVectorNumElements();
7158 SmallVector<int, 4> Mask(NumElts, -1);
7159
7160 for (int i = 0; i != NumElts; ++i) {
7161 int M0 = SV0->getMaskElt(i);
7162 int M1 = SV1->getMaskElt(i);
7163
7164 // Determine if either index is pointing to a zero vector.
7165 bool M0Zero = M0 < 0 || (ZeroN00 == (M0 < NumElts));
7166 bool M1Zero = M1 < 0 || (ZeroN10 == (M1 < NumElts));
7167
7168 // If one element is zero and the other side is undef, keep undef.
7169 // This also handles the case that both are undef.
7170 if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0))
7171 continue;
7172
7173 // Make sure only one of the elements is zero.
7174 if (M0Zero == M1Zero) {
7175 CanFold = false;
7176 break;
7177 }
7178
7179 assert((M0 >= 0 || M1 >= 0) && "Undef index!");
7180
7181 // We have a zero and non-zero element. If the non-zero came from
7182 // SV0 make the index a LHS index. If it came from SV1, make it
7183 // a RHS index. We need to mod by NumElts because we don't care
7184 // which operand it came from in the original shuffles.
7185 Mask[i] = M1Zero ? M0 % NumElts : (M1 % NumElts) + NumElts;
7186 }
7187
7188 if (CanFold) {
7189 SDValue NewLHS = ZeroN00 ? N0.getOperand(1) : N0.getOperand(0);
7190 SDValue NewRHS = ZeroN10 ? N1.getOperand(1) : N1.getOperand(0);
7191
7192 SDValue LegalShuffle =
7193 TLI.buildLegalVectorShuffle(VT, SDLoc(N), NewLHS, NewRHS,
7194 Mask, DAG);
7195 if (LegalShuffle)
7196 return LegalShuffle;
7197 }
7198 }
7199 }
7200 }
7201
7202 // fold (or x, 0) -> x
7203 if (isNullConstant(N1))
7204 return N0;
7205
7206 // fold (or x, -1) -> -1
7207 if (isAllOnesConstant(N1))
7208 return N1;
7209
7210 if (SDValue NewSel = foldBinOpIntoSelect(N))
7211 return NewSel;
7212
7213 // fold (or x, c) -> c iff (x & ~c) == 0
7214 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
7215 if (N1C && DAG.MaskedValueIsZero(N0, ~N1C->getAPIntValue()))
7216 return N1;
7217
7218 if (SDValue Combined = visitORLike(N0, N1, N))
7219 return Combined;
7220
7221 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
7222 return Combined;
7223
7224 // Recognize halfword bswaps as (bswap + rotl 16) or (bswap + shl 16)
7225 if (SDValue BSwap = MatchBSwapHWord(N, N0, N1))
7226 return BSwap;
7227 if (SDValue BSwap = MatchBSwapHWordLow(N, N0, N1))
7228 return BSwap;
7229
7230 // reassociate or
7231 if (SDValue ROR = reassociateOps(ISD::OR, SDLoc(N), N0, N1, N->getFlags()))
7232 return ROR;
7233
7234 // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
7235 // iff (c1 & c2) != 0 or c1/c2 are undef.
7236 auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
7237 return !C1 || !C2 || C1->getAPIntValue().intersects(C2->getAPIntValue());
7238 };
7239 if (N0.getOpcode() == ISD::AND && N0->hasOneUse() &&
7240 ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect, true)) {
7241 if (SDValue COR = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N1), VT,
7242 {N1, N0.getOperand(1)})) {
7243 SDValue IOR = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1);
7244 AddToWorklist(IOR.getNode());
7245 return DAG.getNode(ISD::AND, SDLoc(N), VT, COR, IOR);
7246 }
7247 }
7248
7249 if (SDValue Combined = visitORCommutative(DAG, N0, N1, N))
7250 return Combined;
7251 if (SDValue Combined = visitORCommutative(DAG, N1, N0, N))
7252 return Combined;
7253
7254 // Simplify: (or (op x...), (op y...)) -> (op (or x, y))
7255 if (N0.getOpcode() == N1.getOpcode())
7256 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
7257 return V;
7258
7259 // See if this is some rotate idiom.
7260 if (SDValue Rot = MatchRotate(N0, N1, SDLoc(N)))
7261 return Rot;
7262
7263 if (SDValue Load = MatchLoadCombine(N))
7264 return Load;
7265
7266 // Simplify the operands using demanded-bits information.
7267 if (SimplifyDemandedBits(SDValue(N, 0)))
7268 return SDValue(N, 0);
7269
7270 // If OR can be rewritten into ADD, try combines based on ADD.
7271 if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
7272 DAG.haveNoCommonBitsSet(N0, N1))
7273 if (SDValue Combined = visitADDLike(N))
7274 return Combined;
7275
7276 // Postpone until legalization has completed to avoid interference with
7277 // bswap folding.
7278 if (LegalOperations || VT.isVector())
7279 if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
7280 return R;
7281
7282 return SDValue();
7283 }
7284
7285 static SDValue stripConstantMask(const SelectionDAG &DAG, SDValue Op,
7286 SDValue &Mask) {
7287 if (Op.getOpcode() == ISD::AND &&
7288 DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) {
7289 Mask = Op.getOperand(1);
7290 return Op.getOperand(0);
7291 }
7292 return Op;
7293 }
7294
7295 /// Match "(X shl/srl V1) & V2" where V2 may not be present.
7296 static bool matchRotateHalf(const SelectionDAG &DAG, SDValue Op, SDValue &Shift,
7297 SDValue &Mask) {
7298 Op = stripConstantMask(DAG, Op, Mask);
7299 if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
7300 Shift = Op;
7301 return true;
7302 }
7303 return false;
7304 }
7305
7306 /// Helper function for visitOR to extract the needed side of a rotate idiom
7307 /// from a shl/srl/mul/udiv. This is meant to handle cases where
7308 /// InstCombine merged some outside op with one of the shifts from
7309 /// the rotate pattern.
7310 /// \returns An empty \c SDValue if the needed shift couldn't be extracted.
7311 /// Otherwise, returns an expansion of \p ExtractFrom based on the following
7312 /// patterns:
7313 ///
7314 /// (or (add v v) (shrl v bitwidth-1)):
7315 /// expands (add v v) -> (shl v 1)
7316 ///
7317 /// (or (mul v c0) (shrl (mul v c1) c2)):
7318 /// expands (mul v c0) -> (shl (mul v c1) c3)
7319 ///
7320 /// (or (udiv v c0) (shl (udiv v c1) c2)):
7321 /// expands (udiv v c0) -> (shrl (udiv v c1) c3)
7322 ///
7323 /// (or (shl v c0) (shrl (shl v c1) c2)):
7324 /// expands (shl v c0) -> (shl (shl v c1) c3)
7325 ///
7326 /// (or (shrl v c0) (shl (shrl v c1) c2)):
7327 /// expands (shrl v c0) -> (shrl (shrl v c1) c3)
7328 ///
7329 /// Such that in all cases, c3+c2==bitwidth(op v c1).
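/// For illustration (i32, hypothetical constants): in
///   (or (mul v 48) (shrl (mul v 3) 28))
/// we have c3 = 32 - 28 = 4 and 48 == 3 << 4, so (mul v 48) expands to
///   (shl (mul v 3) 4)
/// which lets the caller match (rotl (mul v 3) 4).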
7330 static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
7331 SDValue ExtractFrom, SDValue &Mask,
7332 const SDLoc &DL) {
7333 assert(OppShift && ExtractFrom && "Empty SDValue");
7334 if (OppShift.getOpcode() != ISD::SHL && OppShift.getOpcode() != ISD::SRL)
7335 return SDValue();
7336
7337 ExtractFrom = stripConstantMask(DAG, ExtractFrom, Mask);
7338
7339 // Value and Type of the shift.
7340 SDValue OppShiftLHS = OppShift.getOperand(0);
7341 EVT ShiftedVT = OppShiftLHS.getValueType();
7342
7343 // Amount of the existing shift.
7344 ConstantSDNode *OppShiftCst = isConstOrConstSplat(OppShift.getOperand(1));
7345
7346 // (add v v) -> (shl v 1)
7347 // TODO: Should this be a general DAG canonicalization?
7348 if (OppShift.getOpcode() == ISD::SRL && OppShiftCst &&
7349 ExtractFrom.getOpcode() == ISD::ADD &&
7350 ExtractFrom.getOperand(0) == ExtractFrom.getOperand(1) &&
7351 ExtractFrom.getOperand(0) == OppShiftLHS &&
7352 OppShiftCst->getAPIntValue() == ShiftedVT.getScalarSizeInBits() - 1)
7353 return DAG.getNode(ISD::SHL, DL, ShiftedVT, OppShiftLHS,
7354 DAG.getShiftAmountConstant(1, ShiftedVT, DL));
7355
7356 // Preconditions:
7357 // (or (op0 v c0) (shiftl/r (op0 v c1) c2))
7358 //
7359 // Find opcode of the needed shift to be extracted from (op0 v c0).
7360 unsigned Opcode = ISD::DELETED_NODE;
7361 bool IsMulOrDiv = false;
7362 // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift
7363 // opcode or its arithmetic (mul or udiv) variant.
7364 auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) {
7365 IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant;
7366 if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift)
7367 return false;
7368 Opcode = NeededShift;
7369 return true;
7370 };
7371 // op0 must be either the needed shift opcode or the mul/udiv equivalent
7372 // that the needed shift can be extracted from.
7373 if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) &&
7374 (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV)))
7375 return SDValue();
7376
7377 // op0 must be the same opcode on both sides, have the same LHS argument,
7378 // and produce the same value type.
7379 if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() ||
7380 OppShiftLHS.getOperand(0) != ExtractFrom.getOperand(0) ||
7381 ShiftedVT != ExtractFrom.getValueType())
7382 return SDValue();
7383
7384 // Constant mul/udiv/shift amount from the RHS of the shift's LHS op.
7385 ConstantSDNode *OppLHSCst = isConstOrConstSplat(OppShiftLHS.getOperand(1));
7386 // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op.
7387 ConstantSDNode *ExtractFromCst =
7388 isConstOrConstSplat(ExtractFrom.getOperand(1));
7389 // TODO: We should be able to handle non-uniform constant vectors for these values
7390 // Check that we have constant values.
7391 if (!OppShiftCst || !OppShiftCst->getAPIntValue() ||
7392 !OppLHSCst || !OppLHSCst->getAPIntValue() ||
7393 !ExtractFromCst || !ExtractFromCst->getAPIntValue())
7394 return SDValue();
7395
7396 // Compute the shift amount we need to extract to complete the rotate.
7397 const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
7398 if (OppShiftCst->getAPIntValue().ugt(VTWidth))
7399 return SDValue();
7400 APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue();
7401 // Normalize the bitwidth of the two mul/udiv/shift constant operands.
7402 APInt ExtractFromAmt = ExtractFromCst->getAPIntValue();
7403 APInt OppLHSAmt = OppLHSCst->getAPIntValue();
7404 zeroExtendToMatch(ExtractFromAmt, OppLHSAmt);
7405
7406 // Now try extract the needed shift from the ExtractFrom op and see if the
7407 // result matches up with the existing shift's LHS op.
7408 if (IsMulOrDiv) {
7409 // Op to extract from is a mul or udiv by a constant.
7410 // Check:
7411 // c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0
7412 // c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0
7413 const APInt ExtractDiv = APInt::getOneBitSet(ExtractFromAmt.getBitWidth(),
7414 NeededShiftAmt.getZExtValue());
7415 APInt ResultAmt;
7416 APInt Rem;
7417 APInt::udivrem(ExtractFromAmt, ExtractDiv, ResultAmt, Rem);
7418 if (Rem != 0 || ResultAmt != OppLHSAmt)
7419 return SDValue();
7420 } else {
7421 // Op to extract from is a shift by a constant.
7422 // Check:
7423 // c2 - (bitwidth(op0 v c0) - c1) == c0
7424 if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc(
7425 ExtractFromAmt.getBitWidth()))
7426 return SDValue();
7427 }
7428
7429 // Return the expanded shift op that should allow a rotate to be formed.
7430 EVT ShiftVT = OppShift.getOperand(1).getValueType();
7431 EVT ResVT = ExtractFrom.getValueType();
7432 SDValue NewShiftNode = DAG.getConstant(NeededShiftAmt, DL, ShiftVT);
7433 return DAG.getNode(Opcode, DL, ResVT, OppShiftLHS, NewShiftNode);
7434 }
7435
7436 // Return true if we can prove that, whenever Neg and Pos are both in the
7437 // range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos). This means that
7438 // for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
7439 //
7440 // (or (shift1 X, Neg), (shift2 X, Pos))
7441 //
7442 // reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
7443 // in direction shift1 by Neg. The range [0, EltSize) means that we only need
7444 // to consider shift amounts with defined behavior.
7445 //
7446 // The IsRotate flag should be set when the LHS of both shifts is the same.
7447 // Otherwise if matching a general funnel shift, it should be clear.
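// For illustration (EltSize == 32, IsRotate): Pos == y with Neg == (sub 32, y)
// satisfies the requirement, and so does Pos == (and y, 31) with
// Neg == (and (sub 0, y), 31), i.e. the form left after the shift amounts have
// been masked to the element width, once the redundant masks are looked
// through below.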
7448 static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
7449 SelectionDAG &DAG, bool IsRotate) {
7450 const auto &TLI = DAG.getTargetLoweringInfo();
7451 // If EltSize is a power of 2 then:
7452 //
7453 // (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
7454 // (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize).
7455 //
7456 // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check
7457 // for the stronger condition:
7458 //
7459 // Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1) [A]
7460 //
7461 // for all Neg and Pos. Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
7462 // we can just replace Neg with Neg' for the rest of the function.
7463 //
7464 // In other cases we check for the even stronger condition:
7465 //
7466 // Neg == EltSize - Pos [B]
7467 //
7468 // for all Neg and Pos. Note that the (or ...) then invokes undefined
7469 // behavior if Pos == 0 (and consequently Neg == EltSize).
7470 //
7471 // We could actually use [A] whenever EltSize is a power of 2, but the
7472 // only extra cases that it would match are those uninteresting ones
7473 // where Neg and Pos are never in range at the same time. E.g. for
7474 // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
7475 // as well as (sub 32, Pos), but:
7476 //
7477 // (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
7478 //
7479 // always invokes undefined behavior for 32-bit X.
7480 //
7481 // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
7482 // This allows us to peek through any operations that only affect Mask's
7483 // un-demanded bits.
7484 //
7485 // NOTE: We can only do this when matching operations which won't modify the
7486 // least Log2(EltSize) significant bits and not a general funnel shift.
7487 unsigned MaskLoBits = 0;
7488 if (IsRotate && isPowerOf2_64(EltSize)) {
7489 unsigned Bits = Log2_64(EltSize);
7490 unsigned NegBits = Neg.getScalarValueSizeInBits();
7491 if (NegBits >= Bits) {
7492 APInt DemandedBits = APInt::getLowBitsSet(NegBits, Bits);
7493 if (SDValue Inner =
7494 TLI.SimplifyMultipleUseDemandedBits(Neg, DemandedBits, DAG)) {
7495 Neg = Inner;
7496 MaskLoBits = Bits;
7497 }
7498 }
7499 }
7500
7501 // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
7502 if (Neg.getOpcode() != ISD::SUB)
7503 return false;
7504 ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(0));
7505 if (!NegC)
7506 return false;
7507 SDValue NegOp1 = Neg.getOperand(1);
7508
7509 // On the RHS of [A], if Pos is the result of an operation on Pos' that won't
7510 // affect Mask's demanded bits, just replace Pos with Pos'. These operations
7511 // are redundant for the purpose of the equality.
7512 if (MaskLoBits) {
7513 unsigned PosBits = Pos.getScalarValueSizeInBits();
7514 if (PosBits >= MaskLoBits) {
7515 APInt DemandedBits = APInt::getLowBitsSet(PosBits, MaskLoBits);
7516 if (SDValue Inner =
7517 TLI.SimplifyMultipleUseDemandedBits(Pos, DemandedBits, DAG)) {
7518 Pos = Inner;
7519 }
7520 }
7521 }
7522
7523 // The condition we need is now:
7524 //
7525 // (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
7526 //
7527 // If NegOp1 == Pos then we need:
7528 //
7529 // EltSize & Mask == NegC & Mask
7530 //
7531 // (because "x & Mask" is a truncation and distributes through subtraction).
7532 //
7533 // We also need to account for a potential truncation of NegOp1 if the amount
7534 // has already been legalized to a shift amount type.
7535 APInt Width;
7536 if ((Pos == NegOp1) ||
7537 (NegOp1.getOpcode() == ISD::TRUNCATE && Pos == NegOp1.getOperand(0)))
7538 Width = NegC->getAPIntValue();
7539
7540 // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
7541 // Then the condition we want to prove becomes:
7542 //
7543 // (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
7544 //
7545 // which, again because "x & Mask" is a truncation, becomes:
7546 //
7547 // NegC & Mask == (EltSize - PosC) & Mask
7548 // EltSize & Mask == (NegC + PosC) & Mask
7549 else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(0) == NegOp1) {
7550 if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1)))
7551 Width = PosC->getAPIntValue() + NegC->getAPIntValue();
7552 else
7553 return false;
7554 } else
7555 return false;
7556
7557 // Now we just need to check that EltSize & Mask == Width & Mask.
7558 if (MaskLoBits)
7559 // EltSize & Mask is 0 since Mask is EltSize - 1.
7560 return Width.getLoBits(MaskLoBits) == 0;
7561 return Width == EltSize;
7562 }
7563
7564 // A subroutine of MatchRotate used once we have found an OR of two opposite
7565 // shifts of Shifted. If Neg == <operand size> - Pos then the OR reduces
7566 // to both (PosOpcode Shifted, Pos) and (NegOpcode Shifted, Neg), with the
7567 // former being preferred if supported. InnerPos and InnerNeg are Pos and
7568 // Neg with outer conversions stripped away.
7569 SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
7570 SDValue Neg, SDValue InnerPos,
7571 SDValue InnerNeg, bool HasPos,
7572 unsigned PosOpcode, unsigned NegOpcode,
7573 const SDLoc &DL) {
7574 // fold (or (shl x, (*ext y)),
7575 // (srl x, (*ext (sub 32, y)))) ->
7576 // (rotl x, y) or (rotr x, (sub 32, y))
7577 //
7578 // fold (or (shl x, (*ext (sub 32, y))),
7579 // (srl x, (*ext y))) ->
7580 // (rotr x, y) or (rotl x, (sub 32, y))
7581 EVT VT = Shifted.getValueType();
7582 if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG,
7583 /*IsRotate*/ true)) {
7584 return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
7585 HasPos ? Pos : Neg);
7586 }
7587
7588 return SDValue();
7589 }
7590
7591 // A subroutine of MatchRotate used once we have found an OR of two opposite
7592 // shifts of N0 + N1. If Neg == <operand size> - Pos then the OR reduces
7593 // to both (PosOpcode N0, N1, Pos) and (NegOpcode N0, N1, Neg), with the
7594 // former being preferred if supported. InnerPos and InnerNeg are Pos and
7595 // Neg with outer conversions stripped away.
7596 // TODO: Merge with MatchRotatePosNeg.
7597 SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
7598 SDValue Neg, SDValue InnerPos,
7599 SDValue InnerNeg, bool HasPos,
7600 unsigned PosOpcode, unsigned NegOpcode,
7601 const SDLoc &DL) {
7602 EVT VT = N0.getValueType();
7603 unsigned EltBits = VT.getScalarSizeInBits();
7604
7605 // fold (or (shl x0, (*ext y)),
7606 // (srl x1, (*ext (sub 32, y)))) ->
7607 // (fshl x0, x1, y) or (fshr x0, x1, (sub 32, y))
7608 //
7609 // fold (or (shl x0, (*ext (sub 32, y))),
7610 // (srl x1, (*ext y))) ->
7611 // (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
7612 if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1)) {
7613 return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, N0, N1,
7614 HasPos ? Pos : Neg);
7615 }
7616
7617 // Matching the shift+xor cases, we can't easily use the xor'd shift amount
7618 // so for now just use the PosOpcode case if its legal.
7619 // TODO: When can we use the NegOpcode case?
7620 if (PosOpcode == ISD::FSHL && isPowerOf2_32(EltBits)) {
7621 auto IsBinOpImm = [](SDValue Op, unsigned BinOpc, unsigned Imm) {
7622 if (Op.getOpcode() != BinOpc)
7623 return false;
7624 ConstantSDNode *Cst = isConstOrConstSplat(Op.getOperand(1));
7625 return Cst && (Cst->getAPIntValue() == Imm);
7626 };
7627
7628 // fold (or (shl x0, y), (srl (srl x1, 1), (xor y, 31)))
7629 // -> (fshl x0, x1, y)
7630 if (IsBinOpImm(N1, ISD::SRL, 1) &&
7631 IsBinOpImm(InnerNeg, ISD::XOR, EltBits - 1) &&
7632 InnerPos == InnerNeg.getOperand(0) &&
7633 TLI.isOperationLegalOrCustom(ISD::FSHL, VT)) {
7634 return DAG.getNode(ISD::FSHL, DL, VT, N0, N1.getOperand(0), Pos);
7635 }
7636
7637 // fold (or (shl (shl x0, 1), (xor y, 31)), (srl x1, y))
7638 // -> (fshr x0, x1, y)
7639 if (IsBinOpImm(N0, ISD::SHL, 1) &&
7640 IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
7641 InnerNeg == InnerPos.getOperand(0) &&
7642 TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) {
7643 return DAG.getNode(ISD::FSHR, DL, VT, N0.getOperand(0), N1, Neg);
7644 }
7645
7646 // fold (or (shl (add x0, x0), (xor y, 31)), (srl x1, y))
7647 // -> (fshr x0, x1, y)
7648 // TODO: Should add(x,x) -> shl(x,1) be a general DAG canonicalization?
7649 if (N0.getOpcode() == ISD::ADD && N0.getOperand(0) == N0.getOperand(1) &&
7650 IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
7651 InnerNeg == InnerPos.getOperand(0) &&
7652 TLI.isOperationLegalOrCustom(ISD::FSHR, VT)) {
7653 return DAG.getNode(ISD::FSHR, DL, VT, N0.getOperand(0), N1, Neg);
7654 }
7655 }
7656
7657 return SDValue();
7658 }
7659
7660 // MatchRotate - Handle an 'or' of two operands. If this is one of the many
7661 // idioms for rotate, and if the target supports rotation instructions, generate
7662 // a rot[lr]. This also matches funnel shift patterns, similar to rotation but
7663 // with different shifted sources.
7664 SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
7665 EVT VT = LHS.getValueType();
7666
7667 // The target must have at least one rotate/funnel flavor.
7668 // We still try to match rotate by constant pre-legalization.
7669 // TODO: Support pre-legalization funnel-shift by constant.
7670 bool HasROTL = hasOperation(ISD::ROTL, VT);
7671 bool HasROTR = hasOperation(ISD::ROTR, VT);
7672 bool HasFSHL = hasOperation(ISD::FSHL, VT);
7673 bool HasFSHR = hasOperation(ISD::FSHR, VT);
7674
7675 // If the type is going to be promoted and the target has enabled custom
7676 // lowering for rotate, allow matching rotate by non-constants. Only allow
7677 // this for scalar types.
7678 if (VT.isScalarInteger() && TLI.getTypeAction(*DAG.getContext(), VT) ==
7679 TargetLowering::TypePromoteInteger) {
7680 HasROTL |= TLI.getOperationAction(ISD::ROTL, VT) == TargetLowering::Custom;
7681 HasROTR |= TLI.getOperationAction(ISD::ROTR, VT) == TargetLowering::Custom;
7682 }
7683
7684 if (LegalOperations && !HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
7685 return SDValue();
7686
7687 // Check for truncated rotate.
7688 if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
7689 LHS.getOperand(0).getValueType() == RHS.getOperand(0).getValueType()) {
7690 assert(LHS.getValueType() == RHS.getValueType());
7691 if (SDValue Rot = MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL)) {
7692 return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(), Rot);
7693 }
7694 }
7695
7696 // Match "(X shl/srl V1) & V2" where V2 may not be present.
7697 SDValue LHSShift; // The shift.
7698 SDValue LHSMask; // AND value if any.
7699 matchRotateHalf(DAG, LHS, LHSShift, LHSMask);
7700
7701 SDValue RHSShift; // The shift.
7702 SDValue RHSMask; // AND value if any.
7703 matchRotateHalf(DAG, RHS, RHSShift, RHSMask);
7704
7705 // If neither side matched a rotate half, bail
7706 if (!LHSShift && !RHSShift)
7707 return SDValue();
7708
7709 // InstCombine may have combined a constant shl, srl, mul, or udiv with one
7710 // side of the rotate, so try to handle that here. In all cases we need to
7711 // pass the matched shift from the opposite side to compute the opcode and
7712 // needed shift amount to extract. We still want to do this if both sides
7713 // matched a rotate half because one half may be a potential overshift that
7714 // can be broken down (i.e. if InstCombine merged two shl or srl ops into a
7715 // single one).
7716
7717 // Have LHS side of the rotate, try to extract the needed shift from the RHS.
7718 if (LHSShift)
7719 if (SDValue NewRHSShift =
7720 extractShiftForRotate(DAG, LHSShift, RHS, RHSMask, DL))
7721 RHSShift = NewRHSShift;
7722 // Have RHS side of the rotate, try to extract the needed shift from the LHS.
7723 if (RHSShift)
7724 if (SDValue NewLHSShift =
7725 extractShiftForRotate(DAG, RHSShift, LHS, LHSMask, DL))
7726 LHSShift = NewLHSShift;
7727
7728 // If a side is still missing, nothing else we can do.
7729 if (!RHSShift || !LHSShift)
7730 return SDValue();
7731
7732 // At this point we've matched or extracted a shift op on each side.
7733
7734 if (LHSShift.getOpcode() == RHSShift.getOpcode())
7735 return SDValue(); // Shifts must disagree.
7736
7737 // Canonicalize shl to left side in a shl/srl pair.
7738 if (RHSShift.getOpcode() == ISD::SHL) {
7739 std::swap(LHS, RHS);
7740 std::swap(LHSShift, RHSShift);
7741 std::swap(LHSMask, RHSMask);
7742 }
7743
7744 // Something has gone wrong - we've lost the shl/srl pair - bail.
7745 if (LHSShift.getOpcode() != ISD::SHL || RHSShift.getOpcode() != ISD::SRL)
7746 return SDValue();
7747
7748 unsigned EltSizeInBits = VT.getScalarSizeInBits();
7749 SDValue LHSShiftArg = LHSShift.getOperand(0);
7750 SDValue LHSShiftAmt = LHSShift.getOperand(1);
7751 SDValue RHSShiftArg = RHSShift.getOperand(0);
7752 SDValue RHSShiftAmt = RHSShift.getOperand(1);
7753
7754 auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
7755 ConstantSDNode *RHS) {
7756 return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
7757 };
7758
7759 auto ApplyMasks = [&](SDValue Res) {
7760 // If there is an AND of either shifted operand, apply it to the result.
7761 if (LHSMask.getNode() || RHSMask.getNode()) {
7762 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
7763 SDValue Mask = AllOnes;
7764
7765 if (LHSMask.getNode()) {
7766 SDValue RHSBits = DAG.getNode(ISD::SRL, DL, VT, AllOnes, RHSShiftAmt);
7767 Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
7768 DAG.getNode(ISD::OR, DL, VT, LHSMask, RHSBits));
7769 }
7770 if (RHSMask.getNode()) {
7771 SDValue LHSBits = DAG.getNode(ISD::SHL, DL, VT, AllOnes, LHSShiftAmt);
7772 Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
7773 DAG.getNode(ISD::OR, DL, VT, RHSMask, LHSBits));
7774 }
7775
7776 Res = DAG.getNode(ISD::AND, DL, VT, Res, Mask);
7777 }
7778
7779 return Res;
7780 };
7781
7782 // TODO: Support pre-legalization funnel-shift by constant.
7783 bool IsRotate = LHSShiftArg == RHSShiftArg;
7784 if (!IsRotate && !(HasFSHL || HasFSHR)) {
7785 if (TLI.isTypeLegal(VT) && LHS.hasOneUse() && RHS.hasOneUse() &&
7786 ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
7787 // Look for a disguised rotate by constant.
7788 // The common shifted operand X may be hidden inside another 'or'.
7789 SDValue X, Y;
7790 auto matchOr = [&X, &Y](SDValue Or, SDValue CommonOp) {
7791 if (!Or.hasOneUse() || Or.getOpcode() != ISD::OR)
7792 return false;
7793 if (CommonOp == Or.getOperand(0)) {
7794 X = CommonOp;
7795 Y = Or.getOperand(1);
7796 return true;
7797 }
7798 if (CommonOp == Or.getOperand(1)) {
7799 X = CommonOp;
7800 Y = Or.getOperand(0);
7801 return true;
7802 }
7803 return false;
7804 };
7805
7806 SDValue Res;
7807 if (matchOr(LHSShiftArg, RHSShiftArg)) {
7808 // (shl (X | Y), C1) | (srl X, C2) --> (rotl X, C1) | (shl Y, C1)
7809 SDValue RotX = DAG.getNode(ISD::ROTL, DL, VT, X, LHSShiftAmt);
7810 SDValue ShlY = DAG.getNode(ISD::SHL, DL, VT, Y, LHSShiftAmt);
7811 Res = DAG.getNode(ISD::OR, DL, VT, RotX, ShlY);
7812 } else if (matchOr(RHSShiftArg, LHSShiftArg)) {
7813 // (shl X, C1) | (srl (X | Y), C2) --> (rotl X, C1) | (srl Y, C2)
7814 SDValue RotX = DAG.getNode(ISD::ROTL, DL, VT, X, LHSShiftAmt);
7815 SDValue SrlY = DAG.getNode(ISD::SRL, DL, VT, Y, RHSShiftAmt);
7816 Res = DAG.getNode(ISD::OR, DL, VT, RotX, SrlY);
7817 } else {
7818 return SDValue();
7819 }
7820
7821 return ApplyMasks(Res);
7822 }
7823
7824 return SDValue(); // Requires funnel shift support.
7825 }
7826
7827 // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
7828 // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
7829 // fold (or (shl x, C1), (srl y, C2)) -> (fshl x, y, C1)
7830 // fold (or (shl x, C1), (srl y, C2)) -> (fshr x, y, C2)
7831 // iff C1+C2 == EltSizeInBits
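  // For illustration (i32, hypothetical constants):
  //   (or (shl x, 8), (srl x, 24)) -> (rotl x, 8) or equivalently (rotr x, 24)
  // since 8 + 24 == 32.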
7832 if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
7833 SDValue Res;
7834 if (IsRotate && (HasROTL || HasROTR || !(HasFSHL || HasFSHR))) {
7835 bool UseROTL = !LegalOperations || HasROTL;
7836 Res = DAG.getNode(UseROTL ? ISD::ROTL : ISD::ROTR, DL, VT, LHSShiftArg,
7837 UseROTL ? LHSShiftAmt : RHSShiftAmt);
7838 } else {
7839 bool UseFSHL = !LegalOperations || HasFSHL;
7840 Res = DAG.getNode(UseFSHL ? ISD::FSHL : ISD::FSHR, DL, VT, LHSShiftArg,
7841 RHSShiftArg, UseFSHL ? LHSShiftAmt : RHSShiftAmt);
7842 }
7843
7844 return ApplyMasks(Res);
7845 }
7846
7847 // Even pre-legalization, we can't easily rotate/funnel-shift by a variable
7848 // shift.
7849 if (!HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
7850 return SDValue();
7851
7852 // If there is a mask here, and we have a variable shift, we can't be sure
7853 // that we're masking out the right stuff.
7854 if (LHSMask.getNode() || RHSMask.getNode())
7855 return SDValue();
7856
7857 // If the shift amount is sign/zext/any-extended just peel it off.
7858 SDValue LExtOp0 = LHSShiftAmt;
7859 SDValue RExtOp0 = RHSShiftAmt;
7860 if ((LHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
7861 LHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
7862 LHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
7863 LHSShiftAmt.getOpcode() == ISD::TRUNCATE) &&
7864 (RHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
7865 RHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
7866 RHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
7867 RHSShiftAmt.getOpcode() == ISD::TRUNCATE)) {
7868 LExtOp0 = LHSShiftAmt.getOperand(0);
7869 RExtOp0 = RHSShiftAmt.getOperand(0);
7870 }
7871
7872 if (IsRotate && (HasROTL || HasROTR)) {
7873 SDValue TryL =
7874 MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt, LExtOp0,
7875 RExtOp0, HasROTL, ISD::ROTL, ISD::ROTR, DL);
7876 if (TryL)
7877 return TryL;
7878
7879 SDValue TryR =
7880 MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt, RExtOp0,
7881 LExtOp0, HasROTR, ISD::ROTR, ISD::ROTL, DL);
7882 if (TryR)
7883 return TryR;
7884 }
7885
7886 SDValue TryL =
7887 MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt, RHSShiftAmt,
7888 LExtOp0, RExtOp0, HasFSHL, ISD::FSHL, ISD::FSHR, DL);
7889 if (TryL)
7890 return TryL;
7891
7892 SDValue TryR =
7893 MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
7894 RExtOp0, LExtOp0, HasFSHR, ISD::FSHR, ISD::FSHL, DL);
7895 if (TryR)
7896 return TryR;
7897
7898 return SDValue();
7899 }
7900
7901 namespace {
7902
7903 /// Represents known origin of an individual byte in load combine pattern. The
7904 /// value of the byte is either constant zero or comes from memory.
7905 struct ByteProvider {
7906 // For constant zero providers Load is set to nullptr. For memory providers
7907 // Load represents the node which loads the byte from memory.
7908 // ByteOffset is the offset of the byte in the value produced by the load.
7909 LoadSDNode *Load = nullptr;
7910 unsigned ByteOffset = 0;
7911 unsigned VectorOffset = 0;
7912
7913 ByteProvider() = default;
7914
7915 static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset,
7916 unsigned VectorOffset) {
7917 return ByteProvider(Load, ByteOffset, VectorOffset);
7918 }
7919
7920 static ByteProvider getConstantZero() { return ByteProvider(nullptr, 0, 0); }
7921
7922 bool isConstantZero() const { return !Load; }
7923 bool isMemory() const { return Load; }
7924
7925 bool operator==(const ByteProvider &Other) const {
7926 return Other.Load == Load && Other.ByteOffset == ByteOffset &&
7927 Other.VectorOffset == VectorOffset;
7928 }
7929
7930 private:
7931 ByteProvider(LoadSDNode *Load, unsigned ByteOffset, unsigned VectorOffset)
7932 : Load(Load), ByteOffset(ByteOffset), VectorOffset(VectorOffset) {}
7933 };
7934
7935 } // end anonymous namespace
7936
7937 /// Recursively traverses the expression calculating the origin of the requested
7938 /// byte of the given value. Returns std::nullopt if the provider can't be
7939 /// calculated.
7940 ///
7941 /// For all the values except the root of the expression, we verify that the
7942 /// value has exactly one use and if not then return std::nullopt. This way if
7943 /// the origin of the byte is returned it's guaranteed that the values which
7944 /// contribute to the byte are not used outside of this expression.
7945 ///
7946 /// However, there is a special case when dealing with vector loads -- we allow
7947 /// more than one use if the load is a vector type. Since the values that
7948 /// contribute to the byte ultimately come from the ExtractVectorElements of the
7949 /// Load, we don't care if the Load has uses other than ExtractVectorElements,
7950 /// because those operations are independent from the pattern to be combined.
7951 /// For vector loads, we simply care that the ByteProviders are adjacent
7952 /// positions of the same vector, and their index matches the byte that is being
7953 /// provided. This is captured by the \p VectorIndex algorithm. \p VectorIndex
7954 /// is the index used in an ExtractVectorElement, and \p StartingIndex is the
7955 /// byte position we are trying to provide for the LoadCombine. If these do
7956 /// not match, then we can not combine the vector loads. \p Index uses the
7957 /// byte position we are trying to provide for and is matched against the
7958 /// shl and load size. The \p Index algorithm ensures the requested byte is
7959 /// provided for by the pattern, and the pattern does not over provide bytes.
7960 ///
7961 ///
7962 /// The supported LoadCombine pattern for vector loads is as follows
7963 /// or
7964 /// / \
7965 /// or shl
7966 /// / \ |
7967 /// or shl zext
7968 /// / \ | |
7969 /// shl zext zext EVE*
7970 /// | | | |
7971 /// zext EVE* EVE* LOAD
7972 /// | | |
7973 /// EVE* LOAD LOAD
7974 /// |
7975 /// LOAD
7976 ///
7977 /// *ExtractVectorElement
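///
/// For illustration, a minimal scalar case (hypothetical i16 pattern):
///   (or (zext (load i8 p)), (shl (zext (load i8 p+1)), 8))
/// byte 0 is provided by the load of p (the request passes through the zext),
/// and byte 1 is provided by the load of p+1 (the shl redirects the request
/// for byte 1 to byte 0 of its operand).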
7978 static const std::optional<ByteProvider>
7979 calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
7980 std::optional<uint64_t> VectorIndex,
7981 unsigned StartingIndex = 0) {
7982
7983 // A typical i64-by-i8 pattern requires recursion up to a depth of 8 calls.
7984 if (Depth == 10)
7985 return std::nullopt;
7986
7987 // Only allow multiple uses if the instruction is a vector load (in which
7988 // case we will use the load for every ExtractVectorElement)
7989 if (Depth && !Op.hasOneUse() &&
7990 (Op.getOpcode() != ISD::LOAD || !Op.getValueType().isVector()))
7991 return std::nullopt;
7992
7993 // Fail to combine if we have encountered anything but a LOAD after handling
7994 // an ExtractVectorElement.
7995 if (Op.getOpcode() != ISD::LOAD && VectorIndex.has_value())
7996 return std::nullopt;
7997
7998 unsigned BitWidth = Op.getValueSizeInBits();
7999 if (BitWidth % 8 != 0)
8000 return std::nullopt;
8001 unsigned ByteWidth = BitWidth / 8;
8002 assert(Index < ByteWidth && "invalid index requested");
8003 (void) ByteWidth;
8004
8005 switch (Op.getOpcode()) {
8006 case ISD::OR: {
8007 auto LHS =
8008 calculateByteProvider(Op->getOperand(0), Index, Depth + 1, VectorIndex);
8009 if (!LHS)
8010 return std::nullopt;
8011 auto RHS =
8012 calculateByteProvider(Op->getOperand(1), Index, Depth + 1, VectorIndex);
8013 if (!RHS)
8014 return std::nullopt;
8015
8016 if (LHS->isConstantZero())
8017 return RHS;
8018 if (RHS->isConstantZero())
8019 return LHS;
8020 return std::nullopt;
8021 }
8022 case ISD::SHL: {
8023 auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
8024 if (!ShiftOp)
8025 return std::nullopt;
8026
8027 uint64_t BitShift = ShiftOp->getZExtValue();
8028
8029 if (BitShift % 8 != 0)
8030 return std::nullopt;
8031 uint64_t ByteShift = BitShift / 8;
8032
8033 // If we are shifting by an amount greater than the index we are trying to
8034 // provide, then do not provide anything. Otherwise, subtract the index by
8035 // the amount we shifted by.
8036 return Index < ByteShift
8037 ? ByteProvider::getConstantZero()
8038 : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
8039 Depth + 1, VectorIndex, Index);
8040 }
8041 case ISD::ANY_EXTEND:
8042 case ISD::SIGN_EXTEND:
8043 case ISD::ZERO_EXTEND: {
8044 SDValue NarrowOp = Op->getOperand(0);
8045 unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
8046 if (NarrowBitWidth % 8 != 0)
8047 return std::nullopt;
8048 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
8049
8050 if (Index >= NarrowByteWidth)
8051 return Op.getOpcode() == ISD::ZERO_EXTEND
8052 ? std::optional<ByteProvider>(ByteProvider::getConstantZero())
8053 : std::nullopt;
8054 return calculateByteProvider(NarrowOp, Index, Depth + 1, VectorIndex,
8055 StartingIndex);
8056 }
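  // For illustration (i32): byte 0 of (bswap x) is byte 3 of x, so the bswap
  // case below simply redirects the request to the mirrored index.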
8057 case ISD::BSWAP:
8058 return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
8059 Depth + 1, VectorIndex, StartingIndex);
8060 case ISD::EXTRACT_VECTOR_ELT: {
8061 auto OffsetOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
8062 if (!OffsetOp)
8063 return std::nullopt;
8064
8065 VectorIndex = OffsetOp->getZExtValue();
8066
8067 SDValue NarrowOp = Op->getOperand(0);
8068 unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
8069 if (NarrowBitWidth % 8 != 0)
8070 return std::nullopt;
8071 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
8072
8073 // Check to see if the position of the element in the vector corresponds
8074 // with the byte we are trying to provide for. In the case of a vector of
8075 // i8, this simply means the VectorIndex == StartingIndex. For non i8 cases,
8076 // the element will provide a range of bytes. For example, if we have a
8077 // vector of i16s, each element provides two bytes (V[1] provides byte 2 and
8078 // 3).
8079 if (*VectorIndex * NarrowByteWidth > StartingIndex)
8080 return std::nullopt;
8081 if ((*VectorIndex + 1) * NarrowByteWidth <= StartingIndex)
8082 return std::nullopt;
8083
8084 return calculateByteProvider(Op->getOperand(0), Index, Depth + 1,
8085 VectorIndex, StartingIndex);
8086 }
8087 case ISD::LOAD: {
8088 auto L = cast<LoadSDNode>(Op.getNode());
8089 if (!L->isSimple() || L->isIndexed())
8090 return std::nullopt;
8091
8092 unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
8093 if (NarrowBitWidth % 8 != 0)
8094 return std::nullopt;
8095 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
8096
8097 // If the width of the load does not cover the byte we are trying to provide
8098 // for and it is not a ZEXTLOAD, then the load does not provide for the byte
8099 // in question.
8100 if (Index >= NarrowByteWidth)
8101 return L->getExtensionType() == ISD::ZEXTLOAD
8102 ? std::optional<ByteProvider>(ByteProvider::getConstantZero())
8103 : std::nullopt;
8104
8105 unsigned BPVectorIndex = VectorIndex.value_or(0U);
8106 return ByteProvider::getMemory(L, Index, BPVectorIndex);
8107 }
8108 }
8109
8110 return std::nullopt;
8111 }
8112
8113 static unsigned littleEndianByteAt(unsigned BW, unsigned i) {
8114 return i;
8115 }
8116
8117 static unsigned bigEndianByteAt(unsigned BW, unsigned i) {
8118 return BW - i - 1;
8119 }
8120
8121 // Check if the bytes offsets we are looking at match with either big or
8122 // little endian value loaded. Return true for big endian, false for little
8123 // endian, and std::nullopt if match failed.
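// For illustration (4 bytes): offsets {0, 1, 2, 3} relative to FirstOffset
// match a little-endian layout (returns false), {3, 2, 1, 0} match a
// big-endian layout (returns true), and anything else fails to match.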
8124 static std::optional<bool> isBigEndian(const ArrayRef<int64_t> ByteOffsets,
8125 int64_t FirstOffset) {
8126 // The endian can be decided only when it is 2 bytes at least.
8127 unsigned Width = ByteOffsets.size();
8128 if (Width < 2)
8129 return std::nullopt;
8130
8131 bool BigEndian = true, LittleEndian = true;
8132 for (unsigned i = 0; i < Width; i++) {
8133 int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset;
8134 LittleEndian &= CurrentByteOffset == littleEndianByteAt(Width, i);
8135 BigEndian &= CurrentByteOffset == bigEndianByteAt(Width, i);
8136 if (!BigEndian && !LittleEndian)
8137 return std::nullopt;
8138 }
8139
8140 assert((BigEndian != LittleEndian) && "It should be either big endian or "
8141 "little endian");
8142 return BigEndian;
8143 }
8144
8145 static SDValue stripTruncAndExt(SDValue Value) {
8146 switch (Value.getOpcode()) {
8147 case ISD::TRUNCATE:
8148 case ISD::ZERO_EXTEND:
8149 case ISD::SIGN_EXTEND:
8150 case ISD::ANY_EXTEND:
8151 return stripTruncAndExt(Value.getOperand(0));
8152 }
8153 return Value;
8154 }
8155
8156 /// Match a pattern where a wide type scalar value is stored by several narrow
8157 /// stores. Fold it into a single store or a BSWAP and a store if the targets
8158 /// supports it.
8159 ///
8160 /// Assuming little endian target:
8161 /// i8 *p = ...
8162 /// i32 val = ...
8163 /// p[0] = (val >> 0) & 0xFF;
8164 /// p[1] = (val >> 8) & 0xFF;
8165 /// p[2] = (val >> 16) & 0xFF;
8166 /// p[3] = (val >> 24) & 0xFF;
8167 /// =>
8168 /// *((i32)p) = val;
8169 ///
8170 /// i8 *p = ...
8171 /// i32 val = ...
8172 /// p[0] = (val >> 24) & 0xFF;
8173 /// p[1] = (val >> 16) & 0xFF;
8174 /// p[2] = (val >> 8) & 0xFF;
8175 /// p[3] = (val >> 0) & 0xFF;
8176 /// =>
8177 /// *((i32)p) = BSWAP(val);
8178 SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) {
8179 // The matching looks for "store (trunc x)" patterns that appear early but are
8180 // likely to be replaced by truncating store nodes during combining.
8181 // TODO: If there is evidence that running this later would help, this
8182 // limitation could be removed. Legality checks may need to be added
8183 // for the created store and optional bswap/rotate.
8184 if (LegalOperations || OptLevel == CodeGenOpt::None)
8185 return SDValue();
8186
8187 // We only handle merging simple stores of 1-4 bytes.
8188 // TODO: Allow unordered atomics when wider type is legal (see D66309)
8189 EVT MemVT = N->getMemoryVT();
8190 if (!(MemVT == MVT::i8 || MemVT == MVT::i16 || MemVT == MVT::i32) ||
8191 !N->isSimple() || N->isIndexed())
8192 return SDValue();
8193
8194 // Collect all of the stores in the chain.
8195 SDValue Chain = N->getChain();
8196 SmallVector<StoreSDNode *, 8> Stores = {N};
8197 while (auto *Store = dyn_cast<StoreSDNode>(Chain)) {
8198 // All stores must be the same size to ensure that we are writing all of the
8199 // bytes in the wide value.
8200 // This store should have exactly one use as a chain operand for another
8201 // store in the merging set. If there are other chain uses, then the
8202 // transform may not be safe because order of loads/stores outside of this
8203 // set may not be preserved.
8204 // TODO: We could allow multiple sizes by tracking each stored byte.
8205 if (Store->getMemoryVT() != MemVT || !Store->isSimple() ||
8206 Store->isIndexed() || !Store->hasOneUse())
8207 return SDValue();
8208 Stores.push_back(Store);
8209 Chain = Store->getChain();
8210 }
8211 // There is no reason to continue if we do not have at least a pair of stores.
8212 if (Stores.size() < 2)
8213 return SDValue();
8214
8215 // Handle simple types only.
8216 LLVMContext &Context = *DAG.getContext();
8217 unsigned NumStores = Stores.size();
8218 unsigned NarrowNumBits = N->getMemoryVT().getScalarSizeInBits();
8219 unsigned WideNumBits = NumStores * NarrowNumBits;
8220 EVT WideVT = EVT::getIntegerVT(Context, WideNumBits);
8221 if (WideVT != MVT::i16 && WideVT != MVT::i32 && WideVT != MVT::i64)
8222 return SDValue();
8223
8224 // Check if all bytes of the source value that we are looking at are stored
8225 // to the same base address. Collect offsets from Base address into OffsetMap.
8226 SDValue SourceValue;
8227 SmallVector<int64_t, 8> OffsetMap(NumStores, INT64_MAX);
8228 int64_t FirstOffset = INT64_MAX;
8229 StoreSDNode *FirstStore = nullptr;
8230 std::optional<BaseIndexOffset> Base;
8231 for (auto *Store : Stores) {
8232 // All the stores store different parts of the CombinedValue. A truncate is
8233 // required to get the partial value.
8234 SDValue Trunc = Store->getValue();
8235 if (Trunc.getOpcode() != ISD::TRUNCATE)
8236 return SDValue();
8237 // Other than the first/last part, a shift operation is required to get the
8238 // offset.
8239 int64_t Offset = 0;
8240 SDValue WideVal = Trunc.getOperand(0);
8241 if ((WideVal.getOpcode() == ISD::SRL || WideVal.getOpcode() == ISD::SRA) &&
8242 isa<ConstantSDNode>(WideVal.getOperand(1))) {
8243 // The shift amount must be a constant multiple of the narrow type.
8244 // It is translated to the offset address in the wide source value "y".
8245 //
8246 // x = srl y, ShiftAmtC
8247 // i8 z = trunc x
8248 // store z, ...
8249 uint64_t ShiftAmtC = WideVal.getConstantOperandVal(1);
8250 if (ShiftAmtC % NarrowNumBits != 0)
8251 return SDValue();
8252
8253 Offset = ShiftAmtC / NarrowNumBits;
8254 WideVal = WideVal.getOperand(0);
8255 }
8256
8257 // Stores must share the same source value with different offsets.
8258 // Truncate and extends should be stripped to get the single source value.
8259 if (!SourceValue)
8260 SourceValue = WideVal;
8261 else if (stripTruncAndExt(SourceValue) != stripTruncAndExt(WideVal))
8262 return SDValue();
8263 else if (SourceValue.getValueType() != WideVT) {
8264 if (WideVal.getValueType() == WideVT ||
8265 WideVal.getScalarValueSizeInBits() >
8266 SourceValue.getScalarValueSizeInBits())
8267 SourceValue = WideVal;
8268 // Give up if the source value type is smaller than the store size.
8269 if (SourceValue.getScalarValueSizeInBits() < WideVT.getScalarSizeInBits())
8270 return SDValue();
8271 }
8272
8273 // Stores must share the same base address.
8274 BaseIndexOffset Ptr = BaseIndexOffset::match(Store, DAG);
8275 int64_t ByteOffsetFromBase = 0;
8276 if (!Base)
8277 Base = Ptr;
8278 else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
8279 return SDValue();
8280
8281 // Remember the first store.
8282 if (ByteOffsetFromBase < FirstOffset) {
8283 FirstStore = Store;
8284 FirstOffset = ByteOffsetFromBase;
8285 }
8286 // Map the offset in the store and the offset in the combined value, and
8287 // early return if it has been set before.
8288 if (Offset < 0 || Offset >= NumStores || OffsetMap[Offset] != INT64_MAX)
8289 return SDValue();
8290 OffsetMap[Offset] = ByteOffsetFromBase;
8291 }
8292
8293 assert(FirstOffset != INT64_MAX && "First byte offset must be set");
8294 assert(FirstStore && "First store must be set");
8295
8296 // Check that a store of the wide type is both allowed and fast on the target
8297 const DataLayout &Layout = DAG.getDataLayout();
8298 unsigned Fast = 0;
8299 bool Allowed = TLI.allowsMemoryAccess(Context, Layout, WideVT,
8300 *FirstStore->getMemOperand(), &Fast);
8301 if (!Allowed || !Fast)
8302 return SDValue();
8303
8304 // Check if the pieces of the value are going to the expected places in memory
8305 // to merge the stores.
8306 auto checkOffsets = [&](bool MatchLittleEndian) {
8307 if (MatchLittleEndian) {
8308 for (unsigned i = 0; i != NumStores; ++i)
8309 if (OffsetMap[i] != i * (NarrowNumBits / 8) + FirstOffset)
8310 return false;
8311 } else { // MatchBigEndian by reversing loop counter.
8312 for (unsigned i = 0, j = NumStores - 1; i != NumStores; ++i, --j)
8313 if (OffsetMap[j] != i * (NarrowNumBits / 8) + FirstOffset)
8314 return false;
8315 }
8316 return true;
8317 };
8318
8319 // Check if the offsets line up for the native data layout of this target.
8320 bool NeedBswap = false;
8321 bool NeedRotate = false;
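  // For illustration: four i8 pieces stored in fully reversed order need a
  // BSWAP, while two i16 halves of an i32 stored in swapped order only need a
  // rotate by half the width.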
8322 if (!checkOffsets(Layout.isLittleEndian())) {
8323 // Special-case: check if byte offsets line up for the opposite endian.
8324 if (NarrowNumBits == 8 && checkOffsets(Layout.isBigEndian()))
8325 NeedBswap = true;
8326 else if (NumStores == 2 && checkOffsets(Layout.isBigEndian()))
8327 NeedRotate = true;
8328 else
8329 return SDValue();
8330 }
8331
8332 SDLoc DL(N);
8333 if (WideVT != SourceValue.getValueType()) {
8334 assert(SourceValue.getValueType().getScalarSizeInBits() > WideNumBits &&
8335 "Unexpected store value to merge");
8336 SourceValue = DAG.getNode(ISD::TRUNCATE, DL, WideVT, SourceValue);
8337 }
8338
8339 // Before legalize we can introduce illegal bswaps/rotates which will be later
8340 // converted to an explicit bswap sequence. This way we end up with a single
8341 // store and byte shuffling instead of several stores and byte shuffling.
8342 if (NeedBswap) {
8343 SourceValue = DAG.getNode(ISD::BSWAP, DL, WideVT, SourceValue);
8344 } else if (NeedRotate) {
8345 assert(WideNumBits % 2 == 0 && "Unexpected type for rotate");
8346 SDValue RotAmt = DAG.getConstant(WideNumBits / 2, DL, WideVT);
8347 SourceValue = DAG.getNode(ISD::ROTR, DL, WideVT, SourceValue, RotAmt);
8348 }
8349
8350 SDValue NewStore =
8351 DAG.getStore(Chain, DL, SourceValue, FirstStore->getBasePtr(),
8352 FirstStore->getPointerInfo(), FirstStore->getAlign());
8353
8354 // Rely on other DAG combine rules to remove the other individual stores.
8355 DAG.ReplaceAllUsesWith(N, NewStore.getNode());
8356 return NewStore;
8357 }
8358
8359 /// Match a pattern where a wide type scalar value is loaded by several narrow
8360 /// loads and combined by shifts and ors. Fold it into a single load or a load
8361 /// and a BSWAP if the targets supports it.
8362 ///
8363 /// Assuming little endian target:
8364 /// i8 *a = ...
8365 /// i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
8366 /// =>
8367 /// i32 val = *((i32)a)
8368 ///
8369 /// i8 *a = ...
8370 /// i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
8371 /// =>
8372 /// i32 val = BSWAP(*((i32)a))
8373 ///
8374 /// TODO: This rule matches complex patterns with OR node roots and doesn't
8375 /// interact well with the worklist mechanism. When a part of the pattern is
8376 /// updated (e.g. one of the loads) its direct users are put into the worklist,
8377 /// but the root node of the pattern which triggers the load combine is not
8378 /// necessarily a direct user of the changed node. For example, once the address
8379 /// of t28 load is reassociated load combine won't be triggered:
8380 /// t25: i32 = add t4, Constant:i32<2>
8381 /// t26: i64 = sign_extend t25
8382 /// t27: i64 = add t2, t26
8383 /// t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
8384 /// t29: i32 = zero_extend t28
8385 /// t32: i32 = shl t29, Constant:i8<8>
8386 /// t33: i32 = or t23, t32
8387 /// As a possible fix visitLoad can check if the load can be a part of a load
8388 /// combine pattern and add corresponding OR roots to the worklist.
8389 SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
8390 assert(N->getOpcode() == ISD::OR &&
8391 "Can only match load combining against OR nodes");
8392
8393 // Handles simple types only
8394 EVT VT = N->getValueType(0);
8395 if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
8396 return SDValue();
8397 unsigned ByteWidth = VT.getSizeInBits() / 8;
8398
8399 bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
8400 auto MemoryByteOffset = [&] (ByteProvider P) {
8401 assert(P.isMemory() && "Must be a memory byte provider");
8402 unsigned LoadBitWidth = P.Load->getMemoryVT().getScalarSizeInBits();
8403
8404 assert(LoadBitWidth % 8 == 0 &&
8405            "can only analyze providers for individual bytes, not bits");
8406 unsigned LoadByteWidth = LoadBitWidth / 8;
8407 return IsBigEndianTarget
8408 ? bigEndianByteAt(LoadByteWidth, P.ByteOffset)
8409 : littleEndianByteAt(LoadByteWidth, P.ByteOffset);
8410 };
8411
8412 std::optional<BaseIndexOffset> Base;
8413 SDValue Chain;
8414
8415 SmallPtrSet<LoadSDNode *, 8> Loads;
8416 std::optional<ByteProvider> FirstByteProvider;
8417 int64_t FirstOffset = INT64_MAX;
8418
8419 // Check if all the bytes of the OR we are looking at are loaded from the same
8420 // base address. Collect bytes offsets from Base address in ByteOffsets.
8421 SmallVector<int64_t, 8> ByteOffsets(ByteWidth);
8422 unsigned ZeroExtendedBytes = 0;
8423 for (int i = ByteWidth - 1; i >= 0; --i) {
8424 auto P =
8425 calculateByteProvider(SDValue(N, 0), i, 0, /*VectorIndex*/ std::nullopt,
8426 /*StartingIndex*/ i);
8427 if (!P)
8428 return SDValue();
8429
8430 if (P->isConstantZero()) {
8431 // It's OK for the N most significant bytes to be 0, we can just
8432 // zero-extend the load.
8433 if (++ZeroExtendedBytes != (ByteWidth - static_cast<unsigned>(i)))
8434 return SDValue();
8435 continue;
8436 }
8437 assert(P->isMemory() && "provenance should either be memory or zero");
8438
8439 LoadSDNode *L = P->Load;
8440
8441 // All loads must share the same chain
8442 SDValue LChain = L->getChain();
8443 if (!Chain)
8444 Chain = LChain;
8445 else if (Chain != LChain)
8446 return SDValue();
8447
8448 // Loads must share the same base address
8449 BaseIndexOffset Ptr = BaseIndexOffset::match(L, DAG);
8450 int64_t ByteOffsetFromBase = 0;
8451
8452 // For vector loads, the expected load combine pattern will have an
8453 // ExtractElement for each index in the vector. While each of these
8454 // ExtractElements will be accessing the same base address as determined
8455 // by the load instruction, the actual bytes they interact with will differ
8456 // due to different ExtractElement indices. To accurately determine the
8457 // byte position of an ExtractElement, we offset the base load ptr with
8458 // the index multiplied by the byte size of each element in the vector.
8459 if (L->getMemoryVT().isVector()) {
8460 unsigned LoadWidthInBit = L->getMemoryVT().getScalarSizeInBits();
8461 if (LoadWidthInBit % 8 != 0)
8462 return SDValue();
8463 unsigned ByteOffsetFromVector = P->VectorOffset * LoadWidthInBit / 8;
8464 Ptr.addToOffset(ByteOffsetFromVector);
8465 }
8466
8467 if (!Base)
8468 Base = Ptr;
8469
8470 else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
8471 return SDValue();
8472
8473 // Calculate the offset of the current byte from the base address
8474 ByteOffsetFromBase += MemoryByteOffset(*P);
8475 ByteOffsets[i] = ByteOffsetFromBase;
8476
8477 // Remember the first byte load
8478 if (ByteOffsetFromBase < FirstOffset) {
8479 FirstByteProvider = P;
8480 FirstOffset = ByteOffsetFromBase;
8481 }
8482
8483 Loads.insert(L);
8484 }
8485
8486 assert(!Loads.empty() && "All the bytes of the value must be loaded from "
8487 "memory, so there must be at least one load which produces the value");
8488 assert(Base && "Base address of the accessed memory location must be set");
8489 assert(FirstOffset != INT64_MAX && "First byte offset must be set");
8490
8491 bool NeedsZext = ZeroExtendedBytes > 0;
8492
8493 EVT MemVT =
8494 EVT::getIntegerVT(*DAG.getContext(), (ByteWidth - ZeroExtendedBytes) * 8);
8495
8496 if (!MemVT.isSimple())
8497 return SDValue();
8498
8499 // Before legalize we can introduce too wide illegal loads which will be later
8500 // split into legal sized loads. This enables us to combine i64 load by i8
8501 // patterns to a couple of i32 loads on 32 bit targets.
8502 if (LegalOperations &&
8503 !TLI.isOperationLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD,
8504 MemVT))
8505 return SDValue();
8506
8507 // Check if the bytes of the OR we are looking at match with either big or
8508 // little endian value load
8509 std::optional<bool> IsBigEndian = isBigEndian(
8510 ArrayRef(ByteOffsets).drop_back(ZeroExtendedBytes), FirstOffset);
8511 if (!IsBigEndian)
8512 return SDValue();
8513
8514 assert(FirstByteProvider && "must be set");
8515
8516 // Ensure that the first byte is loaded from zero offset of the first load.
8517 // So the combined value can be loaded from the first load address.
8518 if (MemoryByteOffset(*FirstByteProvider) != 0)
8519 return SDValue();
8520 LoadSDNode *FirstLoad = FirstByteProvider->Load;
8521
8522 // The node we are looking at matches with the pattern, check if we can
8523 // replace it with a single (possibly zero-extended) load and bswap + shift if
8524 // needed.
8525
8526 // If the load needs byte swap check if the target supports it
8527 bool NeedsBswap = IsBigEndianTarget != *IsBigEndian;
8528
8529 // Before legalize we can introduce illegal bswaps which will be later
8530 // converted to an explicit bswap sequence. This way we end up with a single
8531 // load and byte shuffling instead of several loads and byte shuffling.
8532 // We do not introduce illegal bswaps when zero-extending as this tends to
8533 // introduce too many arithmetic instructions.
8534 if (NeedsBswap && (LegalOperations || NeedsZext) &&
8535 !TLI.isOperationLegal(ISD::BSWAP, VT))
8536 return SDValue();
8537
8538 // If we need to bswap and zero extend, we have to insert a shift. Check that
8539 // it is legal.
8540 if (NeedsBswap && NeedsZext && LegalOperations &&
8541 !TLI.isOperationLegal(ISD::SHL, VT))
8542 return SDValue();
8543
8544 // Check that a load of the wide type is both allowed and fast on the target
8545 unsigned Fast = 0;
8546 bool Allowed =
8547 TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
8548 *FirstLoad->getMemOperand(), &Fast);
8549 if (!Allowed || !Fast)
8550 return SDValue();
8551
8552 SDValue NewLoad =
8553 DAG.getExtLoad(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(N), VT,
8554 Chain, FirstLoad->getBasePtr(),
8555 FirstLoad->getPointerInfo(), MemVT, FirstLoad->getAlign());
8556
8557 // Transfer chain users from old loads to the new load.
8558 for (LoadSDNode *L : Loads)
8559 DAG.ReplaceAllUsesOfValueWith(SDValue(L, 1), SDValue(NewLoad.getNode(), 1));
8560
8561 if (!NeedsBswap)
8562 return NewLoad;
8563
8564 SDValue ShiftedLoad =
8565 NeedsZext
8566 ? DAG.getNode(ISD::SHL, SDLoc(N), VT, NewLoad,
8567 DAG.getShiftAmountConstant(ZeroExtendedBytes * 8, VT,
8568 SDLoc(N), LegalOperations))
8569 : NewLoad;
8570 return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, ShiftedLoad);
8571 }
8572
8573 // If the target has andn, bsl, or a similar bit-select instruction,
8574 // we want to unfold masked merge, with canonical pattern of:
8575 // | A | |B|
8576 // ((x ^ y) & m) ^ y
8577 // | D |
8578 // Into:
8579 // (x & m) | (y & ~m)
8580 // If y is a constant, m is not a 'not', and the 'andn' does not work with
8581 // immediates, we unfold into a different pattern:
8582 // ~(~x & m) & (m | y)
8583 // If x is a constant, m is a 'not', and the 'andn' does not work with
8584 // immediates, we unfold into a different pattern:
8585 // (x | ~m) & ~(~m & ~y)
8586 // NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at
8587 // the very least that breaks andnpd / andnps patterns, and because those
8588 // patterns are simplified in IR and shouldn't be created in the DAG
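// For illustration (hypothetical i8 values): with x = 0b1100, y = 0b0110 and
// m = 0b1010, both ((x ^ y) & m) ^ y and (x & m) | (y & ~m) evaluate to
// 0b1100 -- m selects bits from x where set and from y where clear.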
8589 SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) {
8590 assert(N->getOpcode() == ISD::XOR);
8591
8592 // Don't touch 'not' (i.e. where y = -1).
8593 if (isAllOnesOrAllOnesSplat(N->getOperand(1)))
8594 return SDValue();
8595
8596 EVT VT = N->getValueType(0);
8597
8598 // There are 3 commutable operators in the pattern,
8599 // so we have to deal with 8 possible variants of the basic pattern.
8600 SDValue X, Y, M;
8601 auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) {
8602 if (And.getOpcode() != ISD::AND || !And.hasOneUse())
8603 return false;
8604 SDValue Xor = And.getOperand(XorIdx);
8605 if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
8606 return false;
8607 SDValue Xor0 = Xor.getOperand(0);
8608 SDValue Xor1 = Xor.getOperand(1);
8609 // Don't touch 'not' (i.e. where y = -1).
8610 if (isAllOnesOrAllOnesSplat(Xor1))
8611 return false;
8612 if (Other == Xor0)
8613 std::swap(Xor0, Xor1);
8614 if (Other != Xor1)
8615 return false;
8616 X = Xor0;
8617 Y = Xor1;
8618 M = And.getOperand(XorIdx ? 0 : 1);
8619 return true;
8620 };
8621
8622 SDValue N0 = N->getOperand(0);
8623 SDValue N1 = N->getOperand(1);
8624 if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) &&
8625 !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0))
8626 return SDValue();
8627
8628 // Don't do anything if the mask is constant. This should not be reachable.
8629 // InstCombine should have already unfolded this pattern, and DAGCombiner
8630 // probably shouldn't produce it, too.
8631 if (isa<ConstantSDNode>(M.getNode()))
8632 return SDValue();
8633
8634 // We can transform if the target has AndNot
8635 if (!TLI.hasAndNot(M))
8636 return SDValue();
8637
8638 SDLoc DL(N);
8639
8640 // If Y is a constant, check that 'andn' works with immediates. Unless M is
8641 // a bitwise not that would already allow ANDN to be used.
8642 if (!TLI.hasAndNot(Y) && !isBitwiseNot(M)) {
8643 assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable.");
8644 // If not, we need to do a bit more work to make sure andn is still used.
8645 SDValue NotX = DAG.getNOT(DL, X, VT);
8646 SDValue LHS = DAG.getNode(ISD::AND, DL, VT, NotX, M);
8647 SDValue NotLHS = DAG.getNOT(DL, LHS, VT);
8648 SDValue RHS = DAG.getNode(ISD::OR, DL, VT, M, Y);
8649 return DAG.getNode(ISD::AND, DL, VT, NotLHS, RHS);
8650 }
8651
8652 // If X is a constant and M is a bitwise not, check that 'andn' works with
8653 // immediates.
8654 if (!TLI.hasAndNot(X) && isBitwiseNot(M)) {
8655 assert(TLI.hasAndNot(Y) && "Only mask is a variable? Unreachable.");
8656 // If not, we need to do a bit more work to make sure andn is still used.
8657 SDValue NotM = M.getOperand(0);
8658 SDValue LHS = DAG.getNode(ISD::OR, DL, VT, X, NotM);
8659 SDValue NotY = DAG.getNOT(DL, Y, VT);
8660 SDValue RHS = DAG.getNode(ISD::AND, DL, VT, NotM, NotY);
8661 SDValue NotRHS = DAG.getNOT(DL, RHS, VT);
8662 return DAG.getNode(ISD::AND, DL, VT, LHS, NotRHS);
8663 }
8664
8665 SDValue LHS = DAG.getNode(ISD::AND, DL, VT, X, M);
8666 SDValue NotM = DAG.getNOT(DL, M, VT);
8667 SDValue RHS = DAG.getNode(ISD::AND, DL, VT, Y, NotM);
8668
8669 return DAG.getNode(ISD::OR, DL, VT, LHS, RHS);
8670 }
8671
8672 SDValue DAGCombiner::visitXOR(SDNode *N) {
8673 SDValue N0 = N->getOperand(0);
8674 SDValue N1 = N->getOperand(1);
8675 EVT VT = N0.getValueType();
8676 SDLoc DL(N);
8677
8678 // fold (xor undef, undef) -> 0. This is a common idiom (misuse).
8679 if (N0.isUndef() && N1.isUndef())
8680 return DAG.getConstant(0, DL, VT);
8681
8682 // fold (xor x, undef) -> undef
8683 if (N0.isUndef())
8684 return N0;
8685 if (N1.isUndef())
8686 return N1;
8687
8688 // fold (xor c1, c2) -> c1^c2
8689 if (SDValue C = DAG.FoldConstantArithmetic(ISD::XOR, DL, VT, {N0, N1}))
8690 return C;
8691
8692 // canonicalize constant to RHS
8693 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
8694 !DAG.isConstantIntBuildVectorOrConstantInt(N1))
8695 return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
8696
8697 // fold vector ops
8698 if (VT.isVector()) {
8699 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
8700 return FoldedVOp;
8701
8702 // fold (xor x, 0) -> x, vector edition
8703 if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
8704 return N0;
8705 }
8706
8707 // fold (xor x, 0) -> x
8708 if (isNullConstant(N1))
8709 return N0;
8710
8711 if (SDValue NewSel = foldBinOpIntoSelect(N))
8712 return NewSel;
8713
8714 // reassociate xor
8715 if (SDValue RXOR = reassociateOps(ISD::XOR, DL, N0, N1, N->getFlags()))
8716 return RXOR;
8717
8718 // fold (a^b) -> (a|b) iff a and b share no bits.
8719 if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
8720 DAG.haveNoCommonBitsSet(N0, N1))
8721 return DAG.getNode(ISD::OR, DL, VT, N0, N1);
8722
8723 // look for 'add-like' folds:
8724 // XOR(N0,MIN_SIGNED_VALUE) == ADD(N0,MIN_SIGNED_VALUE)
8725 if ((!LegalOperations || TLI.isOperationLegal(ISD::ADD, VT)) &&
8726 isMinSignedConstant(N1))
8727 if (SDValue Combined = visitADDLike(N))
8728 return Combined;
8729
8730 // fold !(x cc y) -> (x !cc y)
8731 unsigned N0Opcode = N0.getOpcode();
8732 SDValue LHS, RHS, CC;
8733 if (TLI.isConstTrueVal(N1) &&
8734 isSetCCEquivalent(N0, LHS, RHS, CC, /*MatchStrict*/ true)) {
8735 ISD::CondCode NotCC = ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(),
8736 LHS.getValueType());
8737 if (!LegalOperations ||
8738 TLI.isCondCodeLegal(NotCC, LHS.getSimpleValueType())) {
8739 switch (N0Opcode) {
8740 default:
8741 llvm_unreachable("Unhandled SetCC Equivalent!");
8742 case ISD::SETCC:
8743 return DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC);
8744 case ISD::SELECT_CC:
8745 return DAG.getSelectCC(SDLoc(N0), LHS, RHS, N0.getOperand(2),
8746 N0.getOperand(3), NotCC);
8747 case ISD::STRICT_FSETCC:
8748 case ISD::STRICT_FSETCCS: {
8749 if (N0.hasOneUse()) {
8750 // FIXME Can we handle multiple uses? Could we token factor the chain
8751 // results from the new/old setcc?
8752 SDValue SetCC =
8753 DAG.getSetCC(SDLoc(N0), VT, LHS, RHS, NotCC,
8754 N0.getOperand(0), N0Opcode == ISD::STRICT_FSETCCS);
8755 CombineTo(N, SetCC);
8756 DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), SetCC.getValue(1));
8757 recursivelyDeleteUnusedNodes(N0.getNode());
8758 return SDValue(N, 0); // Return N so it doesn't get rechecked!
8759 }
8760 break;
8761 }
8762 }
8763 }
8764 }
8765
8766 // fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y)))
8767 if (isOneConstant(N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() &&
8768 isSetCCEquivalent(N0.getOperand(0), LHS, RHS, CC)){
8769 SDValue V = N0.getOperand(0);
8770 SDLoc DL0(N0);
8771 V = DAG.getNode(ISD::XOR, DL0, V.getValueType(), V,
8772 DAG.getConstant(1, DL0, V.getValueType()));
8773 AddToWorklist(V.getNode());
8774 return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, V);
8775 }
8776
8777 // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc
8778 if (isOneConstant(N1) && VT == MVT::i1 && N0.hasOneUse() &&
8779 (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
8780 SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
8781 if (isOneUseSetCC(N01) || isOneUseSetCC(N00)) {
8782 unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
8783 N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
8784 N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
8785 AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
8786 return DAG.getNode(NewOpcode, DL, VT, N00, N01);
8787 }
8788 }
8789 // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants
8790 if (isAllOnesConstant(N1) && N0.hasOneUse() &&
8791 (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
8792 SDValue N00 = N0.getOperand(0), N01 = N0.getOperand(1);
8793 if (isa<ConstantSDNode>(N01) || isa<ConstantSDNode>(N00)) {
8794 unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
8795 N00 = DAG.getNode(ISD::XOR, SDLoc(N00), VT, N00, N1); // N00 = ~N00
8796 N01 = DAG.getNode(ISD::XOR, SDLoc(N01), VT, N01, N1); // N01 = ~N01
8797 AddToWorklist(N00.getNode()); AddToWorklist(N01.getNode());
8798 return DAG.getNode(NewOpcode, DL, VT, N00, N01);
8799 }
8800 }
8801
8802 // fold (not (neg x)) -> (add X, -1)
8803 // FIXME: This can be generalized to (not (sub Y, X)) -> (add X, ~Y) if
8804 // Y is a constant or the subtract has a single use.
8805 if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::SUB &&
8806 isNullConstant(N0.getOperand(0))) {
8807 return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1),
8808 DAG.getAllOnesConstant(DL, VT));
8809 }
8810
8811 // fold (not (add X, -1)) -> (neg X)
8812 if (isAllOnesConstant(N1) && N0.getOpcode() == ISD::ADD &&
8813 isAllOnesOrAllOnesSplat(N0.getOperand(1))) {
8814 return DAG.getNegative(N0.getOperand(0), DL, VT);
8815 }
8816
8817 // fold (xor (and x, y), y) -> (and (not x), y)
8818 if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(1) == N1) {
8819 SDValue X = N0.getOperand(0);
8820 SDValue NotX = DAG.getNOT(SDLoc(X), X, VT);
8821 AddToWorklist(NotX.getNode());
8822 return DAG.getNode(ISD::AND, DL, VT, NotX, N1);
8823 }
8824
8825 // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
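  // This is the classic branchless abs idiom: Y is 0 when X is non-negative and
  // all-ones when X is negative, so (X + Y) ^ Y yields X or ((X - 1) ^ -1) == -X.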
8826 if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
8827 SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
8828 SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
8829 if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
8830 SDValue A0 = A.getOperand(0), A1 = A.getOperand(1);
8831 SDValue S0 = S.getOperand(0);
8832 if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0))
8833 if (ConstantSDNode *C = isConstOrConstSplat(S.getOperand(1)))
8834 if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
8835 return DAG.getNode(ISD::ABS, DL, VT, S0);
8836 }
8837 }
8838
8839 // fold (xor x, x) -> 0
8840 if (N0 == N1)
8841 return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
8842
8843 // fold (xor (shl 1, x), -1) -> (rotl ~1, x)
8844 // Here is a concrete example of this equivalence:
8845 // i16 x == 14
8846 // i16 shl == 1 << 14 == 16384 == 0b0100000000000000
8847 // i16 xor == ~(1 << 14) == 49151 == 0b1011111111111111
8848 //
8849 // =>
8850 //
8851 // i16 ~1 == 0b1111111111111110
8852 // i16 rol(~1, 14) == 0b1011111111111111
8853 //
8854 // Some additional tips to help conceptualize this transform:
8855 // - Try to see the operation as placing a single zero in a value of all ones.
8856 // - There exists no value for x which would allow the result to contain zero.
8857 // - Values of x larger than the bitwidth are undefined and do not require a
8858 // consistent result.
8859 // - Pushing the zero left requires shifting one bits in from the right.
8860 // A rotate left of ~1 is a nice way of achieving the desired result.
8861 if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT) && N0Opcode == ISD::SHL &&
8862 isAllOnesConstant(N1) && isOneConstant(N0.getOperand(0))) {
8863 return DAG.getNode(ISD::ROTL, DL, VT, DAG.getConstant(~1, DL, VT),
8864 N0.getOperand(1));
8865 }
8866
8867 // Simplify: xor (op x...), (op y...) -> (op (xor x, y))
8868 if (N0Opcode == N1.getOpcode())
8869 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
8870 return V;
8871
8872 if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
8873 return R;
8874 if (SDValue R = foldLogicOfShifts(N, N1, N0, DAG))
8875 return R;
8876 if (SDValue R = foldLogicTreeOfShifts(N, N0, N1, DAG))
8877 return R;
8878
8879 // Unfold ((x ^ y) & m) ^ y into (x & m) | (y & ~m) if profitable
8880 if (SDValue MM = unfoldMaskedMerge(N))
8881 return MM;
8882
8883 // Simplify the expression using non-local knowledge.
8884 if (SimplifyDemandedBits(SDValue(N, 0)))
8885 return SDValue(N, 0);
8886
8887 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
8888 return Combined;
8889
8890 return SDValue();
8891 }
8892
8893 /// If we have a shift-by-constant of a bitwise logic op that itself has a
8894 /// shift-by-constant operand with identical opcode, we may be able to convert
8895 /// that into 2 independent shifts followed by the logic op. This is a
8896 /// throughput improvement.
8897 static SDValue combineShiftOfShiftedLogic(SDNode *Shift, SelectionDAG &DAG) {
8898 // Match a one-use bitwise logic op.
8899 SDValue LogicOp = Shift->getOperand(0);
8900 if (!LogicOp.hasOneUse())
8901 return SDValue();
8902
8903 unsigned LogicOpcode = LogicOp.getOpcode();
8904 if (LogicOpcode != ISD::AND && LogicOpcode != ISD::OR &&
8905 LogicOpcode != ISD::XOR)
8906 return SDValue();
8907
8908 // Find a matching one-use shift by constant.
8909 unsigned ShiftOpcode = Shift->getOpcode();
8910 SDValue C1 = Shift->getOperand(1);
8911 ConstantSDNode *C1Node = isConstOrConstSplat(C1);
8912 assert(C1Node && "Expected a shift with constant operand");
8913 const APInt &C1Val = C1Node->getAPIntValue();
8914 auto matchFirstShift = [&](SDValue V, SDValue &ShiftOp,
8915 const APInt *&ShiftAmtVal) {
8916 if (V.getOpcode() != ShiftOpcode || !V.hasOneUse())
8917 return false;
8918
8919 ConstantSDNode *ShiftCNode = isConstOrConstSplat(V.getOperand(1));
8920 if (!ShiftCNode)
8921 return false;
8922
8923 // Capture the shifted operand and shift amount value.
8924 ShiftOp = V.getOperand(0);
8925 ShiftAmtVal = &ShiftCNode->getAPIntValue();
8926
8927 // Shift amount types do not have to match their operand type, so check that
8928 // the constants are the same width.
8929 if (ShiftAmtVal->getBitWidth() != C1Val.getBitWidth())
8930 return false;
8931
8932 // The fold is not valid if the sum of the shift values exceeds bitwidth.
8933 if ((*ShiftAmtVal + C1Val).uge(V.getScalarValueSizeInBits()))
8934 return false;
8935
8936 return true;
8937 };
8938
8939 // Logic ops are commutative, so check each operand for a match.
8940 SDValue X, Y;
8941 const APInt *C0Val;
8942 if (matchFirstShift(LogicOp.getOperand(0), X, C0Val))
8943 Y = LogicOp.getOperand(1);
8944 else if (matchFirstShift(LogicOp.getOperand(1), X, C0Val))
8945 Y = LogicOp.getOperand(0);
8946 else
8947 return SDValue();
8948
8949 // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
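  // e.g. srl (xor (srl X, 3), Y), 2 --> xor (srl X, 5), (srl Y, 2)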
8950 SDLoc DL(Shift);
8951 EVT VT = Shift->getValueType(0);
8952 EVT ShiftAmtVT = Shift->getOperand(1).getValueType();
8953 SDValue ShiftSumC = DAG.getConstant(*C0Val + C1Val, DL, ShiftAmtVT);
8954 SDValue NewShift1 = DAG.getNode(ShiftOpcode, DL, VT, X, ShiftSumC);
8955 SDValue NewShift2 = DAG.getNode(ShiftOpcode, DL, VT, Y, C1);
8956 return DAG.getNode(LogicOpcode, DL, VT, NewShift1, NewShift2);
8957 }
8958
8959 /// Handle transforms common to the three shifts, when the shift amount is a
8960 /// constant.
8961 /// We are looking for: (shift being one of shl/sra/srl)
8962 /// shift (binop X, C0), C1
8963 /// And want to transform into:
8964 /// binop (shift X, C1), (shift C0, C1)
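/// For example: shl (or X, 0xF0), 8 --> or (shl X, 8), 0xF000.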
8965 SDValue DAGCombiner::visitShiftByConstant(SDNode *N) {
8966 assert(isConstOrConstSplat(N->getOperand(1)) && "Expected constant operand");
8967
8968 // Do not turn a 'not' into a regular xor.
8969 if (isBitwiseNot(N->getOperand(0)))
8970 return SDValue();
8971
8972 // The inner binop must be one-use, since we want to replace it.
8973 SDValue LHS = N->getOperand(0);
8974 if (!LHS.hasOneUse() || !TLI.isDesirableToCommuteWithShift(N, Level))
8975 return SDValue();
8976
8977 // Fold shift(bitop(shift(x,c1),y), c2) -> bitop(shift(x,c1+c2),shift(y,c2)).
8978 if (SDValue R = combineShiftOfShiftedLogic(N, DAG))
8979 return R;
8980
8981 // We want to pull some binops through shifts, so that we have (and (shift))
8982 // instead of (shift (and)), likewise for add, or, xor, etc. This sort of
8983 // thing happens with address calculations, so it's important to canonicalize
8984 // it.
8985 switch (LHS.getOpcode()) {
8986 default:
8987 return SDValue();
8988 case ISD::OR:
8989 case ISD::XOR:
8990 case ISD::AND:
8991 break;
8992 case ISD::ADD:
8993 if (N->getOpcode() != ISD::SHL)
8994 return SDValue(); // only shl(add) not sr[al](add).
8995 break;
8996 }
8997
8998 // FIXME: disable this unless the input to the binop is a shift by a constant
8999   // or is copy/select. Enable this in other cases when we figure out it is
9000   // exactly profitable.
9001 SDValue BinOpLHSVal = LHS.getOperand(0);
9002 bool IsShiftByConstant = (BinOpLHSVal.getOpcode() == ISD::SHL ||
9003 BinOpLHSVal.getOpcode() == ISD::SRA ||
9004 BinOpLHSVal.getOpcode() == ISD::SRL) &&
9005 isa<ConstantSDNode>(BinOpLHSVal.getOperand(1));
9006 bool IsCopyOrSelect = BinOpLHSVal.getOpcode() == ISD::CopyFromReg ||
9007 BinOpLHSVal.getOpcode() == ISD::SELECT;
9008
9009 if (!IsShiftByConstant && !IsCopyOrSelect)
9010 return SDValue();
9011
9012 if (IsCopyOrSelect && N->hasOneUse())
9013 return SDValue();
9014
9015 // Attempt to fold the constants, shifting the binop RHS by the shift amount.
9016 SDLoc DL(N);
9017 EVT VT = N->getValueType(0);
9018 if (SDValue NewRHS = DAG.FoldConstantArithmetic(
9019 N->getOpcode(), DL, VT, {LHS.getOperand(1), N->getOperand(1)})) {
9020 SDValue NewShift = DAG.getNode(N->getOpcode(), DL, VT, LHS.getOperand(0),
9021 N->getOperand(1));
9022 return DAG.getNode(LHS.getOpcode(), DL, VT, NewShift, NewRHS);
9023 }
9024
9025 return SDValue();
9026 }
9027
9028 SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
9029 assert(N->getOpcode() == ISD::TRUNCATE);
9030 assert(N->getOperand(0).getOpcode() == ISD::AND);
9031
9032 // (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC)
9033 EVT TruncVT = N->getValueType(0);
9034 if (N->hasOneUse() && N->getOperand(0).hasOneUse() &&
9035 TLI.isTypeDesirableForOp(ISD::AND, TruncVT)) {
9036 SDValue N01 = N->getOperand(0).getOperand(1);
9037 if (isConstantOrConstantVector(N01, /* NoOpaques */ true)) {
9038 SDLoc DL(N);
9039 SDValue N00 = N->getOperand(0).getOperand(0);
9040 SDValue Trunc00 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N00);
9041 SDValue Trunc01 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, N01);
9042 AddToWorklist(Trunc00.getNode());
9043 AddToWorklist(Trunc01.getNode());
9044 return DAG.getNode(ISD::AND, DL, TruncVT, Trunc00, Trunc01);
9045 }
9046 }
9047
9048 return SDValue();
9049 }
9050
9051 SDValue DAGCombiner::visitRotate(SDNode *N) {
9052 SDLoc dl(N);
9053 SDValue N0 = N->getOperand(0);
9054 SDValue N1 = N->getOperand(1);
9055 EVT VT = N->getValueType(0);
9056 unsigned Bitsize = VT.getScalarSizeInBits();
9057
9058 // fold (rot x, 0) -> x
9059 if (isNullOrNullSplat(N1))
9060 return N0;
9061
9062 // fold (rot x, c) -> x iff (c % BitSize) == 0
9063 if (isPowerOf2_32(Bitsize) && Bitsize > 1) {
9064 APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1);
9065 if (DAG.MaskedValueIsZero(N1, ModuloMask))
9066 return N0;
9067 }
9068
9069 // fold (rot x, c) -> (rot x, c % BitSize)
9070 bool OutOfRange = false;
9071 auto MatchOutOfRange = [Bitsize, &OutOfRange](ConstantSDNode *C) {
9072 OutOfRange |= C->getAPIntValue().uge(Bitsize);
9073 return true;
9074 };
9075 if (ISD::matchUnaryPredicate(N1, MatchOutOfRange) && OutOfRange) {
9076 EVT AmtVT = N1.getValueType();
9077 SDValue Bits = DAG.getConstant(Bitsize, dl, AmtVT);
9078 if (SDValue Amt =
9079 DAG.FoldConstantArithmetic(ISD::UREM, dl, AmtVT, {N1, Bits}))
9080 return DAG.getNode(N->getOpcode(), dl, VT, N0, Amt);
9081 }
9082
9083 // rot i16 X, 8 --> bswap X
9084 auto *RotAmtC = isConstOrConstSplat(N1);
9085 if (RotAmtC && RotAmtC->getAPIntValue() == 8 &&
9086 VT.getScalarSizeInBits() == 16 && hasOperation(ISD::BSWAP, VT))
9087 return DAG.getNode(ISD::BSWAP, dl, VT, N0);
9088
9089 // Simplify the operands using demanded-bits information.
9090 if (SimplifyDemandedBits(SDValue(N, 0)))
9091 return SDValue(N, 0);
9092
9093 // fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))).
9094 if (N1.getOpcode() == ISD::TRUNCATE &&
9095 N1.getOperand(0).getOpcode() == ISD::AND) {
9096 if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
9097 return DAG.getNode(N->getOpcode(), dl, VT, N0, NewOp1);
9098 }
9099
9100 unsigned NextOp = N0.getOpcode();
9101
9102 // fold (rot* (rot* x, c2), c1)
9103 // -> (rot* x, ((c1 % bitsize) +- (c2 % bitsize) + bitsize) % bitsize)
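  // e.g. for i8: (rotl (rotr x, 3), 7) --> (rotl x, (7 - 3 + 8) % 8) = (rotl x, 4)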
9104 if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) {
9105 SDNode *C1 = DAG.isConstantIntBuildVectorOrConstantInt(N1);
9106 SDNode *C2 = DAG.isConstantIntBuildVectorOrConstantInt(N0.getOperand(1));
9107 if (C1 && C2 && C1->getValueType(0) == C2->getValueType(0)) {
9108 EVT ShiftVT = C1->getValueType(0);
9109 bool SameSide = (N->getOpcode() == NextOp);
9110 unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB;
9111 SDValue BitsizeC = DAG.getConstant(Bitsize, dl, ShiftVT);
9112 SDValue Norm1 = DAG.FoldConstantArithmetic(ISD::UREM, dl, ShiftVT,
9113 {N1, BitsizeC});
9114 SDValue Norm2 = DAG.FoldConstantArithmetic(ISD::UREM, dl, ShiftVT,
9115 {N0.getOperand(1), BitsizeC});
9116 if (Norm1 && Norm2)
9117 if (SDValue CombinedShift = DAG.FoldConstantArithmetic(
9118 CombineOp, dl, ShiftVT, {Norm1, Norm2})) {
9119 CombinedShift = DAG.FoldConstantArithmetic(ISD::ADD, dl, ShiftVT,
9120 {CombinedShift, BitsizeC});
9121 SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic(
9122 ISD::UREM, dl, ShiftVT, {CombinedShift, BitsizeC});
9123 return DAG.getNode(N->getOpcode(), dl, VT, N0->getOperand(0),
9124 CombinedShiftNorm);
9125 }
9126 }
9127 }
9128 return SDValue();
9129 }
9130
9131 SDValue DAGCombiner::visitSHL(SDNode *N) {
9132 SDValue N0 = N->getOperand(0);
9133 SDValue N1 = N->getOperand(1);
9134 if (SDValue V = DAG.simplifyShift(N0, N1))
9135 return V;
9136
9137 EVT VT = N0.getValueType();
9138 EVT ShiftVT = N1.getValueType();
9139 unsigned OpSizeInBits = VT.getScalarSizeInBits();
9140
9141 // fold (shl c1, c2) -> c1<<c2
9142 if (SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, {N0, N1}))
9143 return C;
9144
9145 // fold vector ops
9146 if (VT.isVector()) {
9147 if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
9148 return FoldedVOp;
9149
9150 BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(N1);
9151 // If setcc produces all-one true value then:
9152 // (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV)
9153 if (N1CV && N1CV->isConstant()) {
9154 if (N0.getOpcode() == ISD::AND) {
9155 SDValue N00 = N0->getOperand(0);
9156 SDValue N01 = N0->getOperand(1);
9157 BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(N01);
9158
9159 if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC &&
9160 TLI.getBooleanContents(N00.getOperand(0).getValueType()) ==
9161 TargetLowering::ZeroOrNegativeOneBooleanContent) {
9162 if (SDValue C =
9163 DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, {N01, N1}))
9164 return DAG.getNode(ISD::AND, SDLoc(N), VT, N00, C);
9165 }
9166 }
9167 }
9168 }
9169
9170 if (SDValue NewSel = foldBinOpIntoSelect(N))
9171 return NewSel;
9172
9173 // if (shl x, c) is known to be zero, return 0
9174 if (DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(OpSizeInBits)))
9175 return DAG.getConstant(0, SDLoc(N), VT);
9176
9177 // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))).
9178 if (N1.getOpcode() == ISD::TRUNCATE &&
9179 N1.getOperand(0).getOpcode() == ISD::AND) {
9180 if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
9181 return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, NewOp1);
9182 }
9183
9184 if (SimplifyDemandedBits(SDValue(N, 0)))
9185 return SDValue(N, 0);
9186
9187 // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
9188 if (N0.getOpcode() == ISD::SHL) {
9189 auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
9190 ConstantSDNode *RHS) {
9191 APInt c1 = LHS->getAPIntValue();
9192 APInt c2 = RHS->getAPIntValue();
9193 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9194 return (c1 + c2).uge(OpSizeInBits);
9195 };
9196 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
9197 return DAG.getConstant(0, SDLoc(N), VT);
9198
9199 auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
9200 ConstantSDNode *RHS) {
9201 APInt c1 = LHS->getAPIntValue();
9202 APInt c2 = RHS->getAPIntValue();
9203 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9204 return (c1 + c2).ult(OpSizeInBits);
9205 };
9206 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
9207 SDLoc DL(N);
9208 SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
9209 return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Sum);
9210 }
9211 }
9212
9213 // fold (shl (ext (shl x, c1)), c2) -> (shl (ext x), (add c1, c2))
9214 // For this to be valid, the second form must not preserve any of the bits
9215 // that are shifted out by the inner shift in the first form. This means
9216 // the outer shift size must be >= the number of bits added by the ext.
9217 // As a corollary, we don't care what kind of ext it is.
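  // e.g. with an i8 source zero-extended to i32:
  //   shl (zext (shl x, 3)), 24 --> shl (zext x), 27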
9218 if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
9219 N0.getOpcode() == ISD::ANY_EXTEND ||
9220 N0.getOpcode() == ISD::SIGN_EXTEND) &&
9221 N0.getOperand(0).getOpcode() == ISD::SHL) {
9222 SDValue N0Op0 = N0.getOperand(0);
9223 SDValue InnerShiftAmt = N0Op0.getOperand(1);
9224 EVT InnerVT = N0Op0.getValueType();
9225 uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits();
9226
9227 auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
9228 ConstantSDNode *RHS) {
9229 APInt c1 = LHS->getAPIntValue();
9230 APInt c2 = RHS->getAPIntValue();
9231 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9232 return c2.uge(OpSizeInBits - InnerBitwidth) &&
9233 (c1 + c2).uge(OpSizeInBits);
9234 };
9235 if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchOutOfRange,
9236 /*AllowUndefs*/ false,
9237 /*AllowTypeMismatch*/ true))
9238 return DAG.getConstant(0, SDLoc(N), VT);
9239
9240 auto MatchInRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
9241 ConstantSDNode *RHS) {
9242 APInt c1 = LHS->getAPIntValue();
9243 APInt c2 = RHS->getAPIntValue();
9244 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9245 return c2.uge(OpSizeInBits - InnerBitwidth) &&
9246 (c1 + c2).ult(OpSizeInBits);
9247 };
9248 if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchInRange,
9249 /*AllowUndefs*/ false,
9250 /*AllowTypeMismatch*/ true)) {
9251 SDLoc DL(N);
9252 SDValue Ext = DAG.getNode(N0.getOpcode(), DL, VT, N0Op0.getOperand(0));
9253 SDValue Sum = DAG.getZExtOrTrunc(InnerShiftAmt, DL, ShiftVT);
9254 Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, Sum, N1);
9255 return DAG.getNode(ISD::SHL, DL, VT, Ext, Sum);
9256 }
9257 }
9258
9259 // fold (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C))
9260 // Only fold this if the inner zext has no other uses to avoid increasing
9261 // the total number of instructions.
9262 if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
9263 N0.getOperand(0).getOpcode() == ISD::SRL) {
9264 SDValue N0Op0 = N0.getOperand(0);
9265 SDValue InnerShiftAmt = N0Op0.getOperand(1);
9266
9267 auto MatchEqual = [VT](ConstantSDNode *LHS, ConstantSDNode *RHS) {
9268 APInt c1 = LHS->getAPIntValue();
9269 APInt c2 = RHS->getAPIntValue();
9270 zeroExtendToMatch(c1, c2);
9271 return c1.ult(VT.getScalarSizeInBits()) && (c1 == c2);
9272 };
9273 if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchEqual,
9274 /*AllowUndefs*/ false,
9275 /*AllowTypeMismatch*/ true)) {
9276 SDLoc DL(N);
9277 EVT InnerShiftAmtVT = N0Op0.getOperand(1).getValueType();
9278 SDValue NewSHL = DAG.getZExtOrTrunc(N1, DL, InnerShiftAmtVT);
9279 NewSHL = DAG.getNode(ISD::SHL, DL, N0Op0.getValueType(), N0Op0, NewSHL);
9280 AddToWorklist(NewSHL.getNode());
9281 return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL);
9282 }
9283 }
9284
9285 if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) {
9286 auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
9287 ConstantSDNode *RHS) {
9288 const APInt &LHSC = LHS->getAPIntValue();
9289 const APInt &RHSC = RHS->getAPIntValue();
9290 return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
9291 LHSC.getZExtValue() <= RHSC.getZExtValue();
9292 };
9293
9294 SDLoc DL(N);
9295
9296 // fold (shl (sr[la] exact X, C1), C2) -> (shl X, (C2-C1)) if C1 <= C2
9297 // fold (shl (sr[la] exact X, C1), C2) -> (sr[la] X, (C2-C1)) if C1 >= C2
9298 if (N0->getFlags().hasExact()) {
9299 if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
9300 /*AllowUndefs*/ false,
9301 /*AllowTypeMismatch*/ true)) {
9302 SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
9303 SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
9304 return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
9305 }
9306 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
9307 /*AllowUndefs*/ false,
9308 /*AllowTypeMismatch*/ true)) {
9309 SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
9310 SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
9311 return DAG.getNode(N0.getOpcode(), DL, VT, N0.getOperand(0), Diff);
9312 }
9313 }
9314
9315     // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1)), MASK) or
9316     //                               (and (srl x, (sub c1, c2)), MASK)
9317 // Only fold this if the inner shift has no other uses -- if it does,
9318 // folding this will increase the total number of instructions.
9319 if (N0.getOpcode() == ISD::SRL &&
9320 (N0.getOperand(1) == N1 || N0.hasOneUse()) &&
9321 TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
9322 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
9323 /*AllowUndefs*/ false,
9324 /*AllowTypeMismatch*/ true)) {
9325 SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
9326 SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
9327 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
9328 Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, N01);
9329 Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, Diff);
9330 SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff);
9331 return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
9332 }
9333 if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
9334 /*AllowUndefs*/ false,
9335 /*AllowTypeMismatch*/ true)) {
9336 SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
9337 SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
9338 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
9339 Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, N1);
9340 SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
9341 return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
9342 }
9343 }
9344 }
9345
9346 // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
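  // (shifting right then left by c1 just clears the low c1 bits, which is
  // exactly what ANDing with (-1 << c1) does)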
9347 if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1) &&
9348 isConstantOrConstantVector(N1, /* No Opaques */ true)) {
9349 SDLoc DL(N);
9350 SDValue AllBits = DAG.getAllOnesConstant(DL, VT);
9351 SDValue HiBitsMask = DAG.getNode(ISD::SHL, DL, VT, AllBits, N1);
9352 return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), HiBitsMask);
9353 }
9354
9355 // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
9356 // fold (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
9357 // Variant of version done on multiply, except mul by a power of 2 is turned
9358 // into a shift.
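  // e.g. shl (add x, 5), 3 --> add (shl x, 3), 40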
9359 if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
9360 N0->hasOneUse() &&
9361 isConstantOrConstantVector(N1, /* No Opaques */ true) &&
9362 isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true) &&
9363 TLI.isDesirableToCommuteWithShift(N, Level)) {
9364 SDValue Shl0 = DAG.getNode(ISD::SHL, SDLoc(N0), VT, N0.getOperand(0), N1);
9365 SDValue Shl1 = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1);
9366 AddToWorklist(Shl0.getNode());
9367 AddToWorklist(Shl1.getNode());
9368 return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, Shl0, Shl1);
9369 }
9370
9371 // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2)
9372 if (N0.getOpcode() == ISD::MUL && N0->hasOneUse()) {
9373 SDValue N01 = N0.getOperand(1);
9374 if (SDValue Shl =
9375 DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, {N01, N1}))
9376 return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), Shl);
9377 }
9378
9379 ConstantSDNode *N1C = isConstOrConstSplat(N1);
9380 if (N1C && !N1C->isOpaque())
9381 if (SDValue NewSHL = visitShiftByConstant(N))
9382 return NewSHL;
9383
9384 // Fold (shl (vscale * C0), C1) to (vscale * (C0 << C1)).
9385 if (N0.getOpcode() == ISD::VSCALE && N1C) {
9386 const APInt &C0 = N0.getConstantOperandAPInt(0);
9387 const APInt &C1 = N1C->getAPIntValue();
9388 return DAG.getVScale(SDLoc(N), VT, C0 << C1);
9389 }
9390
9391 // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
9392 APInt ShlVal;
9393 if (N0.getOpcode() == ISD::STEP_VECTOR &&
9394 ISD::isConstantSplatVector(N1.getNode(), ShlVal)) {
9395 const APInt &C0 = N0.getConstantOperandAPInt(0);
9396 if (ShlVal.ult(C0.getBitWidth())) {
9397 APInt NewStep = C0 << ShlVal;
9398 return DAG.getStepVector(SDLoc(N), VT, NewStep);
9399 }
9400 }
9401
9402 return SDValue();
9403 }
9404
9405 // Transform a right shift of a multiply into a multiply-high.
9406 // Examples:
9407 // (srl (mul (zext i32:$a to i64), (zext i32:$b to i64)), 32) -> (mulhu $a, $b)
9408 // (sra (mul (sext i32:$a to i64), (sext i32:$b to i64)), 32) -> (mulhs $a, $b)
9409 static SDValue combineShiftToMULH(SDNode *N, SelectionDAG &DAG,
9410 const TargetLowering &TLI) {
9411 assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
9412 "SRL or SRA node is required here!");
9413
9414 // Check the shift amount. Proceed with the transformation if the shift
9415 // amount is constant.
9416 ConstantSDNode *ShiftAmtSrc = isConstOrConstSplat(N->getOperand(1));
9417 if (!ShiftAmtSrc)
9418 return SDValue();
9419
9420 SDLoc DL(N);
9421
9422 // The operation feeding into the shift must be a multiply.
9423 SDValue ShiftOperand = N->getOperand(0);
9424 if (ShiftOperand.getOpcode() != ISD::MUL)
9425 return SDValue();
9426
9427 // Both operands must be equivalent extend nodes.
9428 SDValue LeftOp = ShiftOperand.getOperand(0);
9429 SDValue RightOp = ShiftOperand.getOperand(1);
9430
9431 bool IsSignExt = LeftOp.getOpcode() == ISD::SIGN_EXTEND;
9432 bool IsZeroExt = LeftOp.getOpcode() == ISD::ZERO_EXTEND;
9433
9434 if (!IsSignExt && !IsZeroExt)
9435 return SDValue();
9436
9437 EVT NarrowVT = LeftOp.getOperand(0).getValueType();
9438 unsigned NarrowVTSize = NarrowVT.getScalarSizeInBits();
9439
9440 // return true if U may use the lower bits of its operands
9441 auto UserOfLowerBits = [NarrowVTSize](SDNode *U) {
9442 if (U->getOpcode() != ISD::SRL && U->getOpcode() != ISD::SRA) {
9443 return true;
9444 }
9445 ConstantSDNode *UShiftAmtSrc = isConstOrConstSplat(U->getOperand(1));
9446 if (!UShiftAmtSrc) {
9447 return true;
9448 }
9449 unsigned UShiftAmt = UShiftAmtSrc->getZExtValue();
9450 return UShiftAmt < NarrowVTSize;
9451 };
9452
9453   // If the lower part of the MUL is also used and MUL_LOHI is supported, do
9454   // not introduce the MULH in favor of MUL_LOHI.
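  // (a single MUL_LOHI provides both the low and high halves, which is better
  // than keeping the original multiply around and adding a separate MULH)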
9455 unsigned MulLoHiOp = IsSignExt ? ISD::SMUL_LOHI : ISD::UMUL_LOHI;
9456 if (!ShiftOperand.hasOneUse() &&
9457 TLI.isOperationLegalOrCustom(MulLoHiOp, NarrowVT) &&
9458 llvm::any_of(ShiftOperand->uses(), UserOfLowerBits)) {
9459 return SDValue();
9460 }
9461
9462 SDValue MulhRightOp;
9463 if (ConstantSDNode *Constant = isConstOrConstSplat(RightOp)) {
9464 unsigned ActiveBits = IsSignExt
9465 ? Constant->getAPIntValue().getMinSignedBits()
9466 : Constant->getAPIntValue().getActiveBits();
9467 if (ActiveBits > NarrowVTSize)
9468 return SDValue();
9469 MulhRightOp = DAG.getConstant(
9470 Constant->getAPIntValue().trunc(NarrowVT.getScalarSizeInBits()), DL,
9471 NarrowVT);
9472 } else {
9473 if (LeftOp.getOpcode() != RightOp.getOpcode())
9474 return SDValue();
9475 // Check that the two extend nodes are the same type.
9476 if (NarrowVT != RightOp.getOperand(0).getValueType())
9477 return SDValue();
9478 MulhRightOp = RightOp.getOperand(0);
9479 }
9480
9481 EVT WideVT = LeftOp.getValueType();
9482 // Proceed with the transformation if the wide types match.
9483 assert((WideVT == RightOp.getValueType()) &&
9484 "Cannot have a multiply node with two different operand types.");
9485
9486 // Proceed with the transformation if the wide type is twice as large
9487 // as the narrow type.
9488 if (WideVT.getScalarSizeInBits() != 2 * NarrowVTSize)
9489 return SDValue();
9490
9491 // Check the shift amount with the narrow type size.
9492 // Proceed with the transformation if the shift amount is the width
9493 // of the narrow type.
9494 unsigned ShiftAmt = ShiftAmtSrc->getZExtValue();
9495 if (ShiftAmt != NarrowVTSize)
9496 return SDValue();
9497
9498 // If the operation feeding into the MUL is a sign extend (sext),
9499   // we use mulhs. Otherwise, zero extends (zext) use mulhu.
9500 unsigned MulhOpcode = IsSignExt ? ISD::MULHS : ISD::MULHU;
9501
9502 // Combine to mulh if mulh is legal/custom for the narrow type on the target.
9503 if (!TLI.isOperationLegalOrCustom(MulhOpcode, NarrowVT))
9504 return SDValue();
9505
9506 SDValue Result =
9507 DAG.getNode(MulhOpcode, DL, NarrowVT, LeftOp.getOperand(0), MulhRightOp);
9508 return (N->getOpcode() == ISD::SRA ? DAG.getSExtOrTrunc(Result, DL, WideVT)
9509 : DAG.getZExtOrTrunc(Result, DL, WideVT));
9510 }
9511
9512 SDValue DAGCombiner::visitSRA(SDNode *N) {
9513 SDValue N0 = N->getOperand(0);
9514 SDValue N1 = N->getOperand(1);
9515 if (SDValue V = DAG.simplifyShift(N0, N1))
9516 return V;
9517
9518 EVT VT = N0.getValueType();
9519 unsigned OpSizeInBits = VT.getScalarSizeInBits();
9520
9521   // fold (sra c1, c2) -> c1 >>s c2
9522 if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRA, SDLoc(N), VT, {N0, N1}))
9523 return C;
9524
9525 // Arithmetic shifting an all-sign-bit value is a no-op.
9526 // fold (sra 0, x) -> 0
9527 // fold (sra -1, x) -> -1
9528 if (DAG.ComputeNumSignBits(N0) == OpSizeInBits)
9529 return N0;
9530
9531 // fold vector ops
9532 if (VT.isVector())
9533 if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
9534 return FoldedVOp;
9535
9536 if (SDValue NewSel = foldBinOpIntoSelect(N))
9537 return NewSel;
9538
9539   // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 if the target supports
9540 // sext_inreg.
9541 ConstantSDNode *N1C = isConstOrConstSplat(N1);
9542 if (N1C && N0.getOpcode() == ISD::SHL && N1 == N0.getOperand(1)) {
9543 unsigned LowBits = OpSizeInBits - (unsigned)N1C->getZExtValue();
9544 EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), LowBits);
9545 if (VT.isVector())
9546 ExtVT = EVT::getVectorVT(*DAG.getContext(), ExtVT,
9547 VT.getVectorElementCount());
9548 if (!LegalOperations ||
9549 TLI.getOperationAction(ISD::SIGN_EXTEND_INREG, ExtVT) ==
9550 TargetLowering::Legal)
9551 return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT,
9552 N0.getOperand(0), DAG.getValueType(ExtVT));
9553 // Even if we can't convert to sext_inreg, we might be able to remove
9554 // this shift pair if the input is already sign extended.
9555 if (DAG.ComputeNumSignBits(N0.getOperand(0)) > N1C->getZExtValue())
9556 return N0.getOperand(0);
9557 }
9558
9559 // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
9560 // clamp (add c1, c2) to max shift.
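  // (clamping is safe: once the total shift reaches OpSizeInBits - 1, every bit
  // is already a copy of the sign bit, so shifting any further changes nothing)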
9561 if (N0.getOpcode() == ISD::SRA) {
9562 SDLoc DL(N);
9563 EVT ShiftVT = N1.getValueType();
9564 EVT ShiftSVT = ShiftVT.getScalarType();
9565 SmallVector<SDValue, 16> ShiftValues;
9566
9567 auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
9568 APInt c1 = LHS->getAPIntValue();
9569 APInt c2 = RHS->getAPIntValue();
9570 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9571 APInt Sum = c1 + c2;
9572 unsigned ShiftSum =
9573 Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
9574 ShiftValues.push_back(DAG.getConstant(ShiftSum, DL, ShiftSVT));
9575 return true;
9576 };
9577 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), SumOfShifts)) {
9578 SDValue ShiftValue;
9579 if (N1.getOpcode() == ISD::BUILD_VECTOR)
9580 ShiftValue = DAG.getBuildVector(ShiftVT, DL, ShiftValues);
9581 else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
9582 assert(ShiftValues.size() == 1 &&
9583 "Expected matchBinaryPredicate to return one element for "
9584 "SPLAT_VECTORs");
9585 ShiftValue = DAG.getSplatVector(ShiftVT, DL, ShiftValues[0]);
9586 } else
9587 ShiftValue = ShiftValues[0];
9588 return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), ShiftValue);
9589 }
9590 }
9591
9592 // fold (sra (shl X, m), (sub result_size, n))
9593   //  -> (sign_extend (trunc (srl X, (sub (sub result_size, n), m)))) for
9594 // result_size - n != m.
9595   // If truncate is free for the target, sext(shl) is likely to result in better
9596 // code.
9597 if (N0.getOpcode() == ISD::SHL && N1C) {
9598     // Get the two constants of the shifts, CN0 = m, CN = n.
9599 const ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1));
9600 if (N01C) {
9601 LLVMContext &Ctx = *DAG.getContext();
9602 // Determine what the truncate's result bitsize and type would be.
9603 EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - N1C->getZExtValue());
9604
9605 if (VT.isVector())
9606 TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorElementCount());
9607
9608 // Determine the residual right-shift amount.
9609 int ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue();
9610
9611 // If the shift is not a no-op (in which case this should be just a sign
9612       // extend already), the truncated-to type is legal, sign_extend is legal
9613 // on that type, and the truncate to that type is both legal and free,
9614 // perform the transform.
9615 if ((ShiftAmt > 0) &&
9616 TLI.isOperationLegalOrCustom(ISD::SIGN_EXTEND, TruncVT) &&
9617 TLI.isOperationLegalOrCustom(ISD::TRUNCATE, VT) &&
9618 TLI.isTruncateFree(VT, TruncVT)) {
9619 SDLoc DL(N);
9620 SDValue Amt = DAG.getConstant(ShiftAmt, DL,
9621 getShiftAmountTy(N0.getOperand(0).getValueType()));
9622 SDValue Shift = DAG.getNode(ISD::SRL, DL, VT,
9623 N0.getOperand(0), Amt);
9624 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT,
9625 Shift);
9626 return DAG.getNode(ISD::SIGN_EXTEND, DL,
9627 N->getValueType(0), Trunc);
9628 }
9629 }
9630 }
9631
9632 // We convert trunc/ext to opposing shifts in IR, but casts may be cheaper.
9633 // sra (add (shl X, N1C), AddC), N1C -->
9634 // sext (add (trunc X to (width - N1C)), AddC')
9635 // sra (sub AddC, (shl X, N1C)), N1C -->
9636 // sext (sub AddC1',(trunc X to (width - N1C)))
9637 if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB) && N1C &&
9638 N0.hasOneUse()) {
9639 bool IsAdd = N0.getOpcode() == ISD::ADD;
9640 SDValue Shl = N0.getOperand(IsAdd ? 0 : 1);
9641 if (Shl.getOpcode() == ISD::SHL && Shl.getOperand(1) == N1 &&
9642 Shl.hasOneUse()) {
9643 // TODO: AddC does not need to be a splat.
9644 if (ConstantSDNode *AddC =
9645 isConstOrConstSplat(N0.getOperand(IsAdd ? 1 : 0))) {
9646 // Determine what the truncate's type would be and ask the target if
9647 // that is a free operation.
9648 LLVMContext &Ctx = *DAG.getContext();
9649 unsigned ShiftAmt = N1C->getZExtValue();
9650 EVT TruncVT = EVT::getIntegerVT(Ctx, OpSizeInBits - ShiftAmt);
9651 if (VT.isVector())
9652 TruncVT = EVT::getVectorVT(Ctx, TruncVT, VT.getVectorElementCount());
9653
9654 // TODO: The simple type check probably belongs in the default hook
9655 // implementation and/or target-specific overrides (because
9656 // non-simple types likely require masking when legalized), but
9657 // that restriction may conflict with other transforms.
9658 if (TruncVT.isSimple() && isTypeLegal(TruncVT) &&
9659 TLI.isTruncateFree(VT, TruncVT)) {
9660 SDLoc DL(N);
9661 SDValue Trunc = DAG.getZExtOrTrunc(Shl.getOperand(0), DL, TruncVT);
9662 SDValue ShiftC =
9663 DAG.getConstant(AddC->getAPIntValue().lshr(ShiftAmt).trunc(
9664 TruncVT.getScalarSizeInBits()),
9665 DL, TruncVT);
9666 SDValue Add;
9667 if (IsAdd)
9668 Add = DAG.getNode(ISD::ADD, DL, TruncVT, Trunc, ShiftC);
9669 else
9670 Add = DAG.getNode(ISD::SUB, DL, TruncVT, ShiftC, Trunc);
9671 return DAG.getSExtOrTrunc(Add, DL, VT);
9672 }
9673 }
9674 }
9675 }
9676
9677 // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))).
9678 if (N1.getOpcode() == ISD::TRUNCATE &&
9679 N1.getOperand(0).getOpcode() == ISD::AND) {
9680 if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
9681 return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0, NewOp1);
9682 }
9683
9684 // fold (sra (trunc (sra x, c1)), c2) -> (trunc (sra x, c1 + c2))
9685 // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2))
9686 // if c1 is equal to the number of bits the trunc removes
9687 // TODO - support non-uniform vector shift amounts.
9688 if (N0.getOpcode() == ISD::TRUNCATE &&
9689 (N0.getOperand(0).getOpcode() == ISD::SRL ||
9690 N0.getOperand(0).getOpcode() == ISD::SRA) &&
9691 N0.getOperand(0).hasOneUse() &&
9692 N0.getOperand(0).getOperand(1).hasOneUse() && N1C) {
9693 SDValue N0Op0 = N0.getOperand(0);
9694 if (ConstantSDNode *LargeShift = isConstOrConstSplat(N0Op0.getOperand(1))) {
9695 EVT LargeVT = N0Op0.getValueType();
9696 unsigned TruncBits = LargeVT.getScalarSizeInBits() - OpSizeInBits;
9697 if (LargeShift->getAPIntValue() == TruncBits) {
9698 SDLoc DL(N);
9699 EVT LargeShiftVT = getShiftAmountTy(LargeVT);
9700 SDValue Amt = DAG.getZExtOrTrunc(N1, DL, LargeShiftVT);
9701 Amt = DAG.getNode(ISD::ADD, DL, LargeShiftVT, Amt,
9702 DAG.getConstant(TruncBits, DL, LargeShiftVT));
9703 SDValue SRA =
9704 DAG.getNode(ISD::SRA, DL, LargeVT, N0Op0.getOperand(0), Amt);
9705 return DAG.getNode(ISD::TRUNCATE, DL, VT, SRA);
9706 }
9707 }
9708 }
9709
9710 // Simplify, based on bits shifted out of the LHS.
9711 if (SimplifyDemandedBits(SDValue(N, 0)))
9712 return SDValue(N, 0);
9713
9714 // If the sign bit is known to be zero, switch this to a SRL.
9715 if (DAG.SignBitIsZero(N0))
9716 return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, N1);
9717
9718 if (N1C && !N1C->isOpaque())
9719 if (SDValue NewSRA = visitShiftByConstant(N))
9720 return NewSRA;
9721
9722 // Try to transform this shift into a multiply-high if
9723 // it matches the appropriate pattern detected in combineShiftToMULH.
9724 if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
9725 return MULH;
9726
9727 // Attempt to convert a sra of a load into a narrower sign-extending load.
9728 if (SDValue NarrowLoad = reduceLoadWidth(N))
9729 return NarrowLoad;
9730
9731 return SDValue();
9732 }
9733
9734 SDValue DAGCombiner::visitSRL(SDNode *N) {
9735 SDValue N0 = N->getOperand(0);
9736 SDValue N1 = N->getOperand(1);
9737 if (SDValue V = DAG.simplifyShift(N0, N1))
9738 return V;
9739
9740 EVT VT = N0.getValueType();
9741 EVT ShiftVT = N1.getValueType();
9742 unsigned OpSizeInBits = VT.getScalarSizeInBits();
9743
9744 // fold (srl c1, c2) -> c1 >>u c2
9745 if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRL, SDLoc(N), VT, {N0, N1}))
9746 return C;
9747
9748 // fold vector ops
9749 if (VT.isVector())
9750 if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
9751 return FoldedVOp;
9752
9753 if (SDValue NewSel = foldBinOpIntoSelect(N))
9754 return NewSel;
9755
9756 // if (srl x, c) is known to be zero, return 0
9757 ConstantSDNode *N1C = isConstOrConstSplat(N1);
9758 if (N1C &&
9759 DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(OpSizeInBits)))
9760 return DAG.getConstant(0, SDLoc(N), VT);
9761
9762 // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
9763 if (N0.getOpcode() == ISD::SRL) {
9764 auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
9765 ConstantSDNode *RHS) {
9766 APInt c1 = LHS->getAPIntValue();
9767 APInt c2 = RHS->getAPIntValue();
9768 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9769 return (c1 + c2).uge(OpSizeInBits);
9770 };
9771 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
9772 return DAG.getConstant(0, SDLoc(N), VT);
9773
9774 auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
9775 ConstantSDNode *RHS) {
9776 APInt c1 = LHS->getAPIntValue();
9777 APInt c2 = RHS->getAPIntValue();
9778 zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
9779 return (c1 + c2).ult(OpSizeInBits);
9780 };
9781 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
9782 SDLoc DL(N);
9783 SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
9784 return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Sum);
9785 }
9786 }
9787
9788 if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
9789 N0.getOperand(0).getOpcode() == ISD::SRL) {
9790 SDValue InnerShift = N0.getOperand(0);
9791 // TODO - support non-uniform vector shift amounts.
9792 if (auto *N001C = isConstOrConstSplat(InnerShift.getOperand(1))) {
9793 uint64_t c1 = N001C->getZExtValue();
9794 uint64_t c2 = N1C->getZExtValue();
9795 EVT InnerShiftVT = InnerShift.getValueType();
9796 EVT ShiftAmtVT = InnerShift.getOperand(1).getValueType();
9797 uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
9798 // srl (trunc (srl x, c1)), c2 --> 0 or (trunc (srl x, (add c1, c2)))
9799 // This is only valid if the OpSizeInBits + c1 = size of inner shift.
9800 if (c1 + OpSizeInBits == InnerShiftSize) {
9801 SDLoc DL(N);
9802 if (c1 + c2 >= InnerShiftSize)
9803 return DAG.getConstant(0, DL, VT);
9804 SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
9805 SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
9806 InnerShift.getOperand(0), NewShiftAmt);
9807 return DAG.getNode(ISD::TRUNCATE, DL, VT, NewShift);
9808 }
9809 // In the more general case, we can clear the high bits after the shift:
9810 // srl (trunc (srl x, c1)), c2 --> trunc (and (srl x, (c1+c2)), Mask)
9811 if (N0.hasOneUse() && InnerShift.hasOneUse() &&
9812 c1 + c2 < InnerShiftSize) {
9813 SDLoc DL(N);
9814 SDValue NewShiftAmt = DAG.getConstant(c1 + c2, DL, ShiftAmtVT);
9815 SDValue NewShift = DAG.getNode(ISD::SRL, DL, InnerShiftVT,
9816 InnerShift.getOperand(0), NewShiftAmt);
9817 SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(InnerShiftSize,
9818 OpSizeInBits - c2),
9819 DL, InnerShiftVT);
9820 SDValue And = DAG.getNode(ISD::AND, DL, InnerShiftVT, NewShift, Mask);
9821 return DAG.getNode(ISD::TRUNCATE, DL, VT, And);
9822 }
9823 }
9824 }
9825
9826   // fold (srl (shl x, c1), c2) -> (and (shl x, (sub c1, c2)), MASK) or
9827   //                               (and (srl x, (sub c2, c1)), MASK)
9828 if (N0.getOpcode() == ISD::SHL &&
9829 (N0.getOperand(1) == N1 || N0->hasOneUse()) &&
9830 TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
9831 auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
9832 ConstantSDNode *RHS) {
9833 const APInt &LHSC = LHS->getAPIntValue();
9834 const APInt &RHSC = RHS->getAPIntValue();
9835 return LHSC.ult(OpSizeInBits) && RHSC.ult(OpSizeInBits) &&
9836 LHSC.getZExtValue() <= RHSC.getZExtValue();
9837 };
9838 if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
9839 /*AllowUndefs*/ false,
9840 /*AllowTypeMismatch*/ true)) {
9841 SDLoc DL(N);
9842 SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
9843 SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N01, N1);
9844 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
9845 Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N01);
9846 Mask = DAG.getNode(ISD::SHL, DL, VT, Mask, Diff);
9847 SDValue Shift = DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Diff);
9848 return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
9849 }
9850 if (ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchShiftAmount,
9851 /*AllowUndefs*/ false,
9852 /*AllowTypeMismatch*/ true)) {
9853 SDLoc DL(N);
9854 SDValue N01 = DAG.getZExtOrTrunc(N0.getOperand(1), DL, ShiftVT);
9855 SDValue Diff = DAG.getNode(ISD::SUB, DL, ShiftVT, N1, N01);
9856 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
9857 Mask = DAG.getNode(ISD::SRL, DL, VT, Mask, N1);
9858 SDValue Shift = DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0), Diff);
9859 return DAG.getNode(ISD::AND, DL, VT, Shift, Mask);
9860 }
9861 }
9862
9863 // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)
9864 // TODO - support non-uniform vector shift amounts.
9865 if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
9866 // Shifting in all undef bits?
9867 EVT SmallVT = N0.getOperand(0).getValueType();
9868 unsigned BitSize = SmallVT.getScalarSizeInBits();
9869 if (N1C->getAPIntValue().uge(BitSize))
9870 return DAG.getUNDEF(VT);
9871
9872 if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) {
9873 uint64_t ShiftAmt = N1C->getZExtValue();
9874 SDLoc DL0(N0);
9875 SDValue SmallShift = DAG.getNode(ISD::SRL, DL0, SmallVT,
9876 N0.getOperand(0),
9877 DAG.getConstant(ShiftAmt, DL0,
9878 getShiftAmountTy(SmallVT)));
9879 AddToWorklist(SmallShift.getNode());
9880 APInt Mask = APInt::getLowBitsSet(OpSizeInBits, OpSizeInBits - ShiftAmt);
9881 SDLoc DL(N);
9882 return DAG.getNode(ISD::AND, DL, VT,
9883 DAG.getNode(ISD::ANY_EXTEND, DL, VT, SmallShift),
9884 DAG.getConstant(Mask, DL, VT));
9885 }
9886 }
9887
9888 // fold (srl (sra X, Y), 31) -> (srl X, 31). This srl only looks at the sign
9889 // bit, which is unmodified by sra.
9890 if (N1C && N1C->getAPIntValue() == (OpSizeInBits - 1)) {
9891 if (N0.getOpcode() == ISD::SRA)
9892 return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0.getOperand(0), N1);
9893 }
9894
9895 // fold (srl (ctlz x), "5") -> x iff x has one bit set (the low bit).
9896 if (N1C && N0.getOpcode() == ISD::CTLZ &&
9897 N1C->getAPIntValue() == Log2_32(OpSizeInBits)) {
9898 KnownBits Known = DAG.computeKnownBits(N0.getOperand(0));
9899
9900 // If any of the input bits are KnownOne, then the input couldn't be all
9901 // zeros, thus the result of the srl will always be zero.
9902 if (Known.One.getBoolValue()) return DAG.getConstant(0, SDLoc(N0), VT);
9903
9904     // If all of the bits input to the ctlz node are known to be zero, then
9905 // the result of the ctlz is "32" and the result of the shift is one.
9906 APInt UnknownBits = ~Known.Zero;
9907 if (UnknownBits == 0) return DAG.getConstant(1, SDLoc(N0), VT);
9908
9909 // Otherwise, check to see if there is exactly one bit input to the ctlz.
9910 if (UnknownBits.isPowerOf2()) {
9911       // Okay, we know that only the single bit specified by UnknownBits
9912 // could be set on input to the CTLZ node. If this bit is set, the SRL
9913 // will return 0, if it is clear, it returns 1. Change the CTLZ/SRL pair
9914 // to an SRL/XOR pair, which is likely to simplify more.
9915 unsigned ShAmt = UnknownBits.countTrailingZeros();
9916 SDValue Op = N0.getOperand(0);
9917
9918 if (ShAmt) {
9919 SDLoc DL(N0);
9920 Op = DAG.getNode(ISD::SRL, DL, VT, Op,
9921 DAG.getConstant(ShAmt, DL,
9922 getShiftAmountTy(Op.getValueType())));
9923 AddToWorklist(Op.getNode());
9924 }
9925
9926 SDLoc DL(N);
9927 return DAG.getNode(ISD::XOR, DL, VT,
9928 Op, DAG.getConstant(1, DL, VT));
9929 }
9930 }
9931
9932 // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), (trunc c))).
9933 if (N1.getOpcode() == ISD::TRUNCATE &&
9934 N1.getOperand(0).getOpcode() == ISD::AND) {
9935 if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
9936 return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, NewOp1);
9937 }
9938
9939 // fold operands of srl based on knowledge that the low bits are not
9940 // demanded.
9941 if (SimplifyDemandedBits(SDValue(N, 0)))
9942 return SDValue(N, 0);
9943
9944 if (N1C && !N1C->isOpaque())
9945 if (SDValue NewSRL = visitShiftByConstant(N))
9946 return NewSRL;
9947
9948 // Attempt to convert a srl of a load into a narrower zero-extending load.
9949 if (SDValue NarrowLoad = reduceLoadWidth(N))
9950 return NarrowLoad;
9951
9952 // Here is a common situation. We want to optimize:
9953 //
9954 // %a = ...
9955 // %b = and i32 %a, 2
9956 // %c = srl i32 %b, 1
9957 // brcond i32 %c ...
9958 //
9959 // into
9960 //
9961 // %a = ...
9962 // %b = and %a, 2
9963 // %c = setcc eq %b, 0
9964 // brcond %c ...
9965 //
9966   // However, after the source operand of SRL is optimized into AND, the SRL
9967 // itself may not be optimized further. Look for it and add the BRCOND into
9968 // the worklist.
9969 //
9970   // This also tends to happen for binary operations when SimplifyDemandedBits
9971 // is involved.
9972 //
9973   // FIXME: This is unnecessary if we process the DAG in topological order,
9974 // which we plan to do. This workaround can be removed once the DAG is
9975 // processed in topological order.
9976 if (N->hasOneUse()) {
9977 SDNode *Use = *N->use_begin();
9978
9979     // Look past the truncate.
9980 if (Use->getOpcode() == ISD::TRUNCATE && Use->hasOneUse())
9981 Use = *Use->use_begin();
9982
9983 if (Use->getOpcode() == ISD::BRCOND || Use->getOpcode() == ISD::AND ||
9984 Use->getOpcode() == ISD::OR || Use->getOpcode() == ISD::XOR)
9985 AddToWorklist(Use);
9986 }
9987
9988 // Try to transform this shift into a multiply-high if
9989 // it matches the appropriate pattern detected in combineShiftToMULH.
9990 if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
9991 return MULH;
9992
9993 return SDValue();
9994 }
9995
9996 SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
9997 EVT VT = N->getValueType(0);
9998 SDValue N0 = N->getOperand(0);
9999 SDValue N1 = N->getOperand(1);
10000 SDValue N2 = N->getOperand(2);
10001 bool IsFSHL = N->getOpcode() == ISD::FSHL;
10002 unsigned BitWidth = VT.getScalarSizeInBits();
10003
10004 // fold (fshl N0, N1, 0) -> N0
10005 // fold (fshr N0, N1, 0) -> N1
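  // (the funnel-shift amount is interpreted modulo the bitwidth, so it suffices
  // to know that the low log2(BitWidth) bits of N2 are zero)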
10006 if (isPowerOf2_32(BitWidth))
10007 if (DAG.MaskedValueIsZero(
10008 N2, APInt(N2.getScalarValueSizeInBits(), BitWidth - 1)))
10009 return IsFSHL ? N0 : N1;
10010
10011 auto IsUndefOrZero = [](SDValue V) {
10012 return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
10013 };
10014
10015 // TODO - support non-uniform vector shift amounts.
10016 if (ConstantSDNode *Cst = isConstOrConstSplat(N2)) {
10017 EVT ShAmtTy = N2.getValueType();
10018
10019 // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth)
10020 if (Cst->getAPIntValue().uge(BitWidth)) {
10021 uint64_t RotAmt = Cst->getAPIntValue().urem(BitWidth);
10022 return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0, N1,
10023 DAG.getConstant(RotAmt, SDLoc(N), ShAmtTy));
10024 }
10025
10026 unsigned ShAmt = Cst->getZExtValue();
10027 if (ShAmt == 0)
10028 return IsFSHL ? N0 : N1;
10029
10030 // fold fshl(undef_or_zero, N1, C) -> lshr(N1, BW-C)
10031 // fold fshr(undef_or_zero, N1, C) -> lshr(N1, C)
10032 // fold fshl(N0, undef_or_zero, C) -> shl(N0, C)
10033 // fold fshr(N0, undef_or_zero, C) -> shl(N0, BW-C)
10034 if (IsUndefOrZero(N0))
10035 return DAG.getNode(ISD::SRL, SDLoc(N), VT, N1,
10036 DAG.getConstant(IsFSHL ? BitWidth - ShAmt : ShAmt,
10037 SDLoc(N), ShAmtTy));
10038 if (IsUndefOrZero(N1))
10039 return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0,
10040 DAG.getConstant(IsFSHL ? ShAmt : BitWidth - ShAmt,
10041 SDLoc(N), ShAmtTy));
10042
10043 // fold (fshl ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
10044 // fold (fshr ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
10045 // TODO - bigendian support once we have test coverage.
10046     // TODO - can we merge this with CombineConsecutiveLoads/MatchLoadCombine?
10047 // TODO - permit LHS EXTLOAD if extensions are shifted out.
10048 if ((BitWidth % 8) == 0 && (ShAmt % 8) == 0 && !VT.isVector() &&
10049 !DAG.getDataLayout().isBigEndian()) {
10050 auto *LHS = dyn_cast<LoadSDNode>(N0);
10051 auto *RHS = dyn_cast<LoadSDNode>(N1);
10052 if (LHS && RHS && LHS->isSimple() && RHS->isSimple() &&
10053 LHS->getAddressSpace() == RHS->getAddressSpace() &&
10054 (LHS->hasOneUse() || RHS->hasOneUse()) && ISD::isNON_EXTLoad(RHS) &&
10055 ISD::isNON_EXTLoad(LHS)) {
10056 if (DAG.areNonVolatileConsecutiveLoads(LHS, RHS, BitWidth / 8, 1)) {
10057 SDLoc DL(RHS);
10058 uint64_t PtrOff =
10059 IsFSHL ? (((BitWidth - ShAmt) % BitWidth) / 8) : (ShAmt / 8);
10060 Align NewAlign = commonAlignment(RHS->getAlign(), PtrOff);
10061 unsigned Fast = 0;
10062 if (TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
10063 RHS->getAddressSpace(), NewAlign,
10064 RHS->getMemOperand()->getFlags(), &Fast) &&
10065 Fast) {
10066 SDValue NewPtr = DAG.getMemBasePlusOffset(
10067 RHS->getBasePtr(), TypeSize::Fixed(PtrOff), DL);
10068 AddToWorklist(NewPtr.getNode());
10069 SDValue Load = DAG.getLoad(
10070 VT, DL, RHS->getChain(), NewPtr,
10071 RHS->getPointerInfo().getWithOffset(PtrOff), NewAlign,
10072 RHS->getMemOperand()->getFlags(), RHS->getAAInfo());
10073 // Replace the old load's chain with the new load's chain.
10074 WorklistRemover DeadNodes(*this);
10075 DAG.ReplaceAllUsesOfValueWith(N1.getValue(1), Load.getValue(1));
10076 return Load;
10077 }
10078 }
10079 }
10080 }
10081 }
10082
10083 // fold fshr(undef_or_zero, N1, N2) -> lshr(N1, N2)
10084 // fold fshl(N0, undef_or_zero, N2) -> shl(N0, N2)
10085   // iff we know the shift amount is in range.
10086 // TODO: when is it worth doing SUB(BW, N2) as well?
10087 if (isPowerOf2_32(BitWidth)) {
10088 APInt ModuloBits(N2.getScalarValueSizeInBits(), BitWidth - 1);
10089 if (IsUndefOrZero(N0) && !IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
10090 return DAG.getNode(ISD::SRL, SDLoc(N), VT, N1, N2);
10091 if (IsUndefOrZero(N1) && IsFSHL && DAG.MaskedValueIsZero(N2, ~ModuloBits))
10092 return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N2);
10093 }
10094
10095 // fold (fshl N0, N0, N2) -> (rotl N0, N2)
10096 // fold (fshr N0, N0, N2) -> (rotr N0, N2)
10097   // TODO: Investigate flipping this rotate if only one is legal; if funnel shift
10098   // is legal as well, we might be better off avoiding non-constant (BW - N2).
10099 unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR;
10100 if (N0 == N1 && hasOperation(RotOpc, VT))
10101 return DAG.getNode(RotOpc, SDLoc(N), VT, N0, N2);
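// E.g. (fshl X, X, N2) concatenates X with itself and shifts, which is exactly
// a left-rotate of X by N2; likewise (fshr X, X, N2) is a right-rotate.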
10102
10103 // Simplify, based on bits shifted out of N0/N1.
10104 if (SimplifyDemandedBits(SDValue(N, 0)))
10105 return SDValue(N, 0);
10106
10107 return SDValue();
10108 }
10109
10110 SDValue DAGCombiner::visitSHLSAT(SDNode *N) {
10111 SDValue N0 = N->getOperand(0);
10112 SDValue N1 = N->getOperand(1);
10113 if (SDValue V = DAG.simplifyShift(N0, N1))
10114 return V;
10115
10116 EVT VT = N0.getValueType();
10117
10118 // fold (*shlsat c1, c2) -> c1<<c2
10119 if (SDValue C =
10120 DAG.FoldConstantArithmetic(N->getOpcode(), SDLoc(N), VT, {N0, N1}))
10121 return C;
10122
10123 ConstantSDNode *N1C = isConstOrConstSplat(N1);
10124
10125 if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::SHL, VT)) {
10126 // fold (sshlsat x, c) -> (shl x, c)
10127 if (N->getOpcode() == ISD::SSHLSAT && N1C &&
10128 N1C->getAPIntValue().ult(DAG.ComputeNumSignBits(N0)))
10129 return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N1);
10130
10131 // fold (ushlsat x, c) -> (shl x, c)
10132 if (N->getOpcode() == ISD::USHLSAT && N1C &&
10133 N1C->getAPIntValue().ule(
10134 DAG.computeKnownBits(N0).countMinLeadingZeros()))
10135 return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N1);
10136 }
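// Illustration of the saturating-shift relaxations above: if an i16 value X is
// known to have at least 4 sign bits, (sshlsat X, 3) can never saturate, so it
// is the same as (shl X, 3); similarly (ushlsat X, 3) is safe when X has at
// least 3 known leading zero bits.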
10137
10138 return SDValue();
10139 }
10140
10141 // Given an ABS node, detect the following pattern:
10142 // (ABS (SUB (EXTEND a), (EXTEND b))).
10143 // Generates a UABD/SABD instruction.
10144 SDValue DAGCombiner::foldABSToABD(SDNode *N) {
10145 EVT VT = N->getValueType(0);
10146 SDValue AbsOp1 = N->getOperand(0);
10147 SDValue Op0, Op1;
10148
10149 if (AbsOp1.getOpcode() != ISD::SUB)
10150 return SDValue();
10151
10152 Op0 = AbsOp1.getOperand(0);
10153 Op1 = AbsOp1.getOperand(1);
10154
10155 unsigned Opc0 = Op0.getOpcode();
10156 // Check if the operands of the sub are (zero|sign)-extended.
10157 if (Opc0 != Op1.getOpcode() ||
10158 (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND)) {
10159 // fold (abs (sub nsw x, y)) -> abds(x, y)
10160 if (AbsOp1->getFlags().hasNoSignedWrap() &&
10161 TLI.isOperationLegalOrCustom(ISD::ABDS, VT))
10162 return DAG.getNode(ISD::ABDS, SDLoc(N), VT, Op0, Op1);
10163 return SDValue();
10164 }
10165
10166 EVT VT1 = Op0.getOperand(0).getValueType();
10167 EVT VT2 = Op1.getOperand(0).getValueType();
10168 unsigned ABDOpcode = (Opc0 == ISD::SIGN_EXTEND) ? ISD::ABDS : ISD::ABDU;
10169
10170 // fold abs(sext(x) - sext(y)) -> zext(abds(x, y))
10171 // fold abs(zext(x) - zext(y)) -> zext(abdu(x, y))
10172 // NOTE: Extensions must be equivalent.
10173 if (VT1 == VT2 && TLI.isOperationLegalOrCustom(ABDOpcode, VT1)) {
10174 Op0 = Op0.getOperand(0);
10175 Op1 = Op1.getOperand(0);
10176 SDValue ABD = DAG.getNode(ABDOpcode, SDLoc(N), VT1, Op0, Op1);
10177 return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, ABD);
10178 }
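// For example, abs(sub(sext i8 %a to i32, sext i8 %b to i32)) becomes
// zext(abds(%a, %b)) when ABDS is legal for i8; the absolute difference of two
// i8 values fits in 8 unsigned bits, which is why the result is zero-extended.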
10179
10180 // fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y))
10181 // fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y))
10182 if (TLI.isOperationLegalOrCustom(ABDOpcode, VT))
10183 return DAG.getNode(ABDOpcode, SDLoc(N), VT, Op0, Op1);
10184
10185 return SDValue();
10186 }
10187
10188 SDValue DAGCombiner::visitABS(SDNode *N) {
10189 SDValue N0 = N->getOperand(0);
10190 EVT VT = N->getValueType(0);
10191
10192 // fold (abs c1) -> c2
10193 if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10194 return DAG.getNode(ISD::ABS, SDLoc(N), VT, N0);
10195 // fold (abs (abs x)) -> (abs x)
10196 if (N0.getOpcode() == ISD::ABS)
10197 return N0;
10198 // fold (abs x) -> x iff not-negative
10199 if (DAG.SignBitIsZero(N0))
10200 return N0;
10201
10202 if (SDValue ABD = foldABSToABD(N))
10203 return ABD;
10204
10205 // fold (abs (sign_extend_inreg x)) -> (zero_extend (abs (truncate x)))
10206 // iff zero_extend/truncate are free.
10207 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
10208 EVT ExtVT = cast<VTSDNode>(N0.getOperand(1))->getVT();
10209 if (TLI.isTruncateFree(VT, ExtVT) && TLI.isZExtFree(ExtVT, VT) &&
10210 TLI.isTypeDesirableForOp(ISD::ABS, ExtVT) &&
10211 hasOperation(ISD::ABS, ExtVT)) {
10212 SDLoc DL(N);
10213 return DAG.getNode(
10214 ISD::ZERO_EXTEND, DL, VT,
10215 DAG.getNode(ISD::ABS, DL, ExtVT,
10216 DAG.getNode(ISD::TRUNCATE, DL, ExtVT, N0.getOperand(0))));
10217 }
10218 }
10219
10220 return SDValue();
10221 }
10222
10223 SDValue DAGCombiner::visitBSWAP(SDNode *N) {
10224 SDValue N0 = N->getOperand(0);
10225 EVT VT = N->getValueType(0);
10226 SDLoc DL(N);
10227
10228 // fold (bswap c1) -> c2
10229 if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10230 return DAG.getNode(ISD::BSWAP, DL, VT, N0);
10231 // fold (bswap (bswap x)) -> x
10232 if (N0.getOpcode() == ISD::BSWAP)
10233 return N0.getOperand(0);
10234
10235 // Canonicalize bswap(bitreverse(x)) -> bitreverse(bswap(x)). If bitreverse
10236 // isn't supported, it will be expanded to bswap followed by a manual reversal
10237 // of bits in each byte. By placing bswaps before bitreverse, we can remove
10238 // the two bswaps if the bitreverse gets expanded.
10239 if (N0.getOpcode() == ISD::BITREVERSE && N0.hasOneUse()) {
10240 SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, N0.getOperand(0));
10241 return DAG.getNode(ISD::BITREVERSE, DL, VT, BSwap);
10242 }
10243
10244 // fold (bswap shl(x,c)) -> (zext(bswap(trunc(shl(x,sub(c,bw/2))))))
10245 // iff c >= bw/2 (i.e. the lower half is known zero)
10246 unsigned BW = VT.getScalarSizeInBits();
10247 if (BW >= 32 && N0.getOpcode() == ISD::SHL && N0.hasOneUse()) {
10248 auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1));
10249 EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), BW / 2);
10250 if (ShAmt && ShAmt->getAPIntValue().ult(BW) &&
10251 ShAmt->getZExtValue() >= (BW / 2) &&
10252 (ShAmt->getZExtValue() % 16) == 0 && TLI.isTypeLegal(HalfVT) &&
10253 TLI.isTruncateFree(VT, HalfVT) &&
10254 (!LegalOperations || hasOperation(ISD::BSWAP, HalfVT))) {
10255 SDValue Res = N0.getOperand(0);
10256 if (uint64_t NewShAmt = (ShAmt->getZExtValue() - (BW / 2)))
10257 Res = DAG.getNode(ISD::SHL, DL, VT, Res,
10258 DAG.getConstant(NewShAmt, DL, getShiftAmountTy(VT)));
10259 Res = DAG.getZExtOrTrunc(Res, DL, HalfVT);
10260 Res = DAG.getNode(ISD::BSWAP, DL, HalfVT, Res);
10261 return DAG.getZExtOrTrunc(Res, DL, VT);
10262 }
10263 }
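// Worked example of the fold above, for BW == 32: bswap(shl X, 16) has its
// upper half holding the low 16 bits of X, so it is rebuilt as
// zext(bswap(trunc X to i16)) -- a half-width byte swap of X's low half.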
10264
10265 // Try to canonicalize bswap-of-logical-shift-by-8-bit-multiple as
10266 // inverse-shift-of-bswap:
10267 // bswap (X u<< C) --> (bswap X) u>> C
10268 // bswap (X u>> C) --> (bswap X) u<< C
10269 if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
10270 N0.hasOneUse()) {
10271 auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1));
10272 if (ShAmt && ShAmt->getAPIntValue().ult(BW) &&
10273 ShAmt->getZExtValue() % 8 == 0) {
10274 SDValue NewSwap = DAG.getNode(ISD::BSWAP, DL, VT, N0.getOperand(0));
10275 unsigned InverseShift = N0.getOpcode() == ISD::SHL ? ISD::SRL : ISD::SHL;
10276 return DAG.getNode(InverseShift, DL, VT, NewSwap, N0.getOperand(1));
10277 }
10278 }
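// E.g. for i32 and a one-byte shift: bswap(shl X, 8) equals srl(bswap X, 8);
// both clear the most-significant byte and byte-reverse X's low three bytes,
// so commuting the bswap with an inverted shift is safe whenever the shift
// amount is a whole number of bytes.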
10279
10280 return SDValue();
10281 }
10282
10283 SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
10284 SDValue N0 = N->getOperand(0);
10285 EVT VT = N->getValueType(0);
10286
10287 // fold (bitreverse c1) -> c2
10288 if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10289 return DAG.getNode(ISD::BITREVERSE, SDLoc(N), VT, N0);
10290 // fold (bitreverse (bitreverse x)) -> x
10291 if (N0.getOpcode() == ISD::BITREVERSE)
10292 return N0.getOperand(0);
10293 return SDValue();
10294 }
10295
10296 SDValue DAGCombiner::visitCTLZ(SDNode *N) {
10297 SDValue N0 = N->getOperand(0);
10298 EVT VT = N->getValueType(0);
10299
10300 // fold (ctlz c1) -> c2
10301 if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10302 return DAG.getNode(ISD::CTLZ, SDLoc(N), VT, N0);
10303
10304 // If the value is known never to be zero, switch to the undef version.
10305 if (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ_ZERO_UNDEF, VT)) {
10306 if (DAG.isKnownNeverZero(N0))
10307 return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0);
10308 }
10309
10310 return SDValue();
10311 }
10312
10313 SDValue DAGCombiner::visitCTLZ_ZERO_UNDEF(SDNode *N) {
10314 SDValue N0 = N->getOperand(0);
10315 EVT VT = N->getValueType(0);
10316
10317 // fold (ctlz_zero_undef c1) -> c2
10318 if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10319 return DAG.getNode(ISD::CTLZ_ZERO_UNDEF, SDLoc(N), VT, N0);
10320 return SDValue();
10321 }
10322
10323 SDValue DAGCombiner::visitCTTZ(SDNode *N) {
10324 SDValue N0 = N->getOperand(0);
10325 EVT VT = N->getValueType(0);
10326
10327 // fold (cttz c1) -> c2
10328 if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10329 return DAG.getNode(ISD::CTTZ, SDLoc(N), VT, N0);
10330
10331 // If the value is known never to be zero, switch to the undef version.
10332 if (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ_ZERO_UNDEF, VT)) {
10333 if (DAG.isKnownNeverZero(N0))
10334 return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0);
10335 }
10336
10337 return SDValue();
10338 }
10339
10340 SDValue DAGCombiner::visitCTTZ_ZERO_UNDEF(SDNode *N) {
10341 SDValue N0 = N->getOperand(0);
10342 EVT VT = N->getValueType(0);
10343
10344 // fold (cttz_zero_undef c1) -> c2
10345 if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10346 return DAG.getNode(ISD::CTTZ_ZERO_UNDEF, SDLoc(N), VT, N0);
10347 return SDValue();
10348 }
10349
10350 SDValue DAGCombiner::visitCTPOP(SDNode *N) {
10351 SDValue N0 = N->getOperand(0);
10352 EVT VT = N->getValueType(0);
10353
10354 // fold (ctpop c1) -> c2
10355 if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
10356 return DAG.getNode(ISD::CTPOP, SDLoc(N), VT, N0);
10357 return SDValue();
10358 }
10359
10360 // FIXME: This should be checking for no signed zeros on individual operands, as
10361 // well as no NaNs.
10362 static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS,
10363 SDValue RHS,
10364 const TargetLowering &TLI) {
10365 const TargetOptions &Options = DAG.getTarget().Options;
10366 EVT VT = LHS.getValueType();
10367
10368 return Options.NoSignedZerosFPMath && VT.isFloatingPoint() &&
10369 TLI.isProfitableToCombineMinNumMaxNum(VT) &&
10370 DAG.isKnownNeverNaN(LHS) && DAG.isKnownNeverNaN(RHS);
10371 }
10372
10373 static SDValue combineMinNumMaxNumImpl(const SDLoc &DL, EVT VT, SDValue LHS,
10374 SDValue RHS, SDValue True, SDValue False,
10375 ISD::CondCode CC,
10376 const TargetLowering &TLI,
10377 SelectionDAG &DAG) {
10378 EVT TransformVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
10379 switch (CC) {
10380 case ISD::SETOLT:
10381 case ISD::SETOLE:
10382 case ISD::SETLT:
10383 case ISD::SETLE:
10384 case ISD::SETULT:
10385 case ISD::SETULE: {
10386 // Since the operands are already known to be never-NaN here, either fminnum
10387 // or fminnum_ieee is OK. Try the IEEE version first, since fminnum is
10388 // expanded in terms of it.
10389 unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
10390 if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
10391 return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
10392
10393 unsigned Opcode = (LHS == True) ? ISD::FMINNUM : ISD::FMAXNUM;
10394 if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
10395 return DAG.getNode(Opcode, DL, VT, LHS, RHS);
10396 return SDValue();
10397 }
10398 case ISD::SETOGT:
10399 case ISD::SETOGE:
10400 case ISD::SETGT:
10401 case ISD::SETGE:
10402 case ISD::SETUGT:
10403 case ISD::SETUGE: {
10404 unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE;
10405 if (TLI.isOperationLegalOrCustom(IEEEOpcode, VT))
10406 return DAG.getNode(IEEEOpcode, DL, VT, LHS, RHS);
10407
10408 unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM;
10409 if (TLI.isOperationLegalOrCustom(Opcode, TransformVT))
10410 return DAG.getNode(Opcode, DL, VT, LHS, RHS);
10411 return SDValue();
10412 }
10413 default:
10414 return SDValue();
10415 }
10416 }
10417
10418 /// Generate Min/Max node
10419 SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
10420 SDValue RHS, SDValue True,
10421 SDValue False, ISD::CondCode CC) {
10422 if ((LHS == True && RHS == False) || (LHS == False && RHS == True))
10423 return combineMinNumMaxNumImpl(DL, VT, LHS, RHS, True, False, CC, TLI, DAG);
10424
10425 // If we can't directly match this, try to see if we can pull an fneg out of
10426 // the select.
10427 SDValue NegTrue = TLI.getCheaperOrNeutralNegatedExpression(
10428 True, DAG, LegalOperations, ForCodeSize);
10429 if (!NegTrue)
10430 return SDValue();
10431
10432 HandleSDNode NegTrueHandle(NegTrue);
10433
10434 // Try to unfold an fneg from the select if we are comparing the negated
10435 // constant.
10436 //
10437 // select (setcc x, K) (fneg x), -K -> fneg(minnum(x, K))
10438 //
10439 // TODO: Handle fabs
10440 if (LHS == NegTrue) {
10441 // The true operand is the negated compare LHS; check whether the false
10442 // operand is the negated compare RHS as well.
10443 SDValue NegRHS = TLI.getCheaperOrNeutralNegatedExpression(
10444 RHS, DAG, LegalOperations, ForCodeSize);
10445 if (NegRHS) {
10446 HandleSDNode NegRHSHandle(NegRHS);
10447 if (NegRHS == False) {
10448 SDValue Combined = combineMinNumMaxNumImpl(DL, VT, LHS, RHS, NegTrue,
10449 False, CC, TLI, DAG);
10450 if (Combined)
10451 return DAG.getNode(ISD::FNEG, DL, VT, Combined);
10452 }
10453 }
10454 }
10455
10456 return SDValue();
10457 }
10458
10459 /// If a (v)select has a condition value that is a sign-bit test, try to smear
10460 /// the condition operand sign-bit across the value width and use it as a mask.
10461 static SDValue foldSelectOfConstantsUsingSra(SDNode *N, SelectionDAG &DAG) {
10462 SDValue Cond = N->getOperand(0);
10463 SDValue C1 = N->getOperand(1);
10464 SDValue C2 = N->getOperand(2);
10465 if (!isConstantOrConstantVector(C1) || !isConstantOrConstantVector(C2))
10466 return SDValue();
10467
10468 EVT VT = N->getValueType(0);
10469 if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse() ||
10470 VT != Cond.getOperand(0).getValueType())
10471 return SDValue();
10472
10473 // The inverted-condition + commuted-select variants of these patterns are
10474 // canonicalized to these forms in IR.
10475 SDValue X = Cond.getOperand(0);
10476 SDValue CondC = Cond.getOperand(1);
10477 ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
10478 if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(CondC) &&
10479 isAllOnesOrAllOnesSplat(C2)) {
10480 // i32 X > -1 ? C1 : -1 --> (X >>s 31) | C1
10481 SDLoc DL(N);
10482 SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
10483 SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
10484 return DAG.getNode(ISD::OR, DL, VT, Sra, C1);
10485 }
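// E.g. for i32: (X > -1 ? 5 : -1) becomes ((X >>s 31) | 5) -- the arithmetic
// shift smears the sign bit into 0 or -1, and OR'ing with 5 yields 5 for
// non-negative X and -1 for negative X.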
10486 if (CC == ISD::SETLT && isNullOrNullSplat(CondC) && isNullOrNullSplat(C2)) {
10487 // i8 X < 0 ? C1 : 0 --> (X >>s 7) & C1
10488 SDLoc DL(N);
10489 SDValue ShAmtC = DAG.getConstant(X.getScalarValueSizeInBits() - 1, DL, VT);
10490 SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, X, ShAmtC);
10491 return DAG.getNode(ISD::AND, DL, VT, Sra, C1);
10492 }
10493 return SDValue();
10494 }
10495
10496 static bool shouldConvertSelectOfConstantsToMath(const SDValue &Cond, EVT VT,
10497 const TargetLowering &TLI) {
10498 if (!TLI.convertSelectOfConstantsToMath(VT))
10499 return false;
10500
10501 if (Cond.getOpcode() != ISD::SETCC || !Cond->hasOneUse())
10502 return true;
10503 if (!TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))
10504 return true;
10505
10506 ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
10507 if (CC == ISD::SETLT && isNullOrNullSplat(Cond.getOperand(1)))
10508 return true;
10509 if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(Cond.getOperand(1)))
10510 return true;
10511
10512 return false;
10513 }
10514
10515 SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
10516 SDValue Cond = N->getOperand(0);
10517 SDValue N1 = N->getOperand(1);
10518 SDValue N2 = N->getOperand(2);
10519 EVT VT = N->getValueType(0);
10520 EVT CondVT = Cond.getValueType();
10521 SDLoc DL(N);
10522
10523 if (!VT.isInteger())
10524 return SDValue();
10525
10526 auto *C1 = dyn_cast<ConstantSDNode>(N1);
10527 auto *C2 = dyn_cast<ConstantSDNode>(N2);
10528 if (!C1 || !C2)
10529 return SDValue();
10530
10531 if (CondVT != MVT::i1 || LegalOperations) {
10532 // fold (select Cond, 0, 1) -> (xor Cond, 1)
10533 // We can't do this reliably if integer-based booleans have different contents
10534 // from floating-point-based booleans. This is because we can't tell whether we
10535 // have an integer-based boolean or a floating-point-based boolean unless we
10536 // can find the SETCC that produced it and inspect its operands. This is
10537 // fairly easy if C is the SETCC node, but it can potentially be
10538 // undiscoverable (or not reasonably discoverable). For example, it could be
10539 // in another basic block or it could require searching a complicated
10540 // expression.
10541 if (CondVT.isInteger() &&
10542 TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) ==
10543 TargetLowering::ZeroOrOneBooleanContent &&
10544 TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) ==
10545 TargetLowering::ZeroOrOneBooleanContent &&
10546 C1->isZero() && C2->isOne()) {
10547 SDValue NotCond =
10548 DAG.getNode(ISD::XOR, DL, CondVT, Cond, DAG.getConstant(1, DL, CondVT));
10549 if (VT.bitsEq(CondVT))
10550 return NotCond;
10551 return DAG.getZExtOrTrunc(NotCond, DL, VT);
10552 }
10553
10554 return SDValue();
10555 }
10556
10557 // Only do this before legalization to avoid conflicting with target-specific
10558 // transforms in the other direction (create a select from a zext/sext). There
10559 // is also a target-independent combine here in DAGCombiner in the other
10560 // direction for (select Cond, -1, 0) when the condition is not i1.
10561 assert(CondVT == MVT::i1 && !LegalOperations);
10562
10563 // select Cond, 1, 0 --> zext (Cond)
10564 if (C1->isOne() && C2->isZero())
10565 return DAG.getZExtOrTrunc(Cond, DL, VT);
10566
10567 // select Cond, -1, 0 --> sext (Cond)
10568 if (C1->isAllOnes() && C2->isZero())
10569 return DAG.getSExtOrTrunc(Cond, DL, VT);
10570
10571 // select Cond, 0, 1 --> zext (!Cond)
10572 if (C1->isZero() && C2->isOne()) {
10573 SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
10574 NotCond = DAG.getZExtOrTrunc(NotCond, DL, VT);
10575 return NotCond;
10576 }
10577
10578 // select Cond, 0, -1 --> sext (!Cond)
10579 if (C1->isZero() && C2->isAllOnes()) {
10580 SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
10581 NotCond = DAG.getSExtOrTrunc(NotCond, DL, VT);
10582 return NotCond;
10583 }
10584
10585 // Use a target hook because some targets may prefer to transform in the
10586 // other direction.
10587 if (!shouldConvertSelectOfConstantsToMath(Cond, VT, TLI))
10588 return SDValue();
10589
10590 // For any constants that differ by 1, we can transform the select into
10591 // an extend and add.
10592 const APInt &C1Val = C1->getAPIntValue();
10593 const APInt &C2Val = C2->getAPIntValue();
10594
10595 // select Cond, C1, C1-1 --> add (zext Cond), C1-1
10596 if (C1Val - 1 == C2Val) {
10597 Cond = DAG.getZExtOrTrunc(Cond, DL, VT);
10598 return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
10599 }
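// E.g. (select Cond, 6, 5) --> (add (zext Cond), 5): the zero-extended i1
// contributes exactly the +1 difference when the condition is true.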
10600
10601 // select Cond, C1, C1+1 --> add (sext Cond), C1+1
10602 if (C1Val + 1 == C2Val) {
10603 Cond = DAG.getSExtOrTrunc(Cond, DL, VT);
10604 return DAG.getNode(ISD::ADD, DL, VT, Cond, N2);
10605 }
10606
10607 // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
10608 if (C1Val.isPowerOf2() && C2Val.isZero()) {
10609 Cond = DAG.getZExtOrTrunc(Cond, DL, VT);
10610 SDValue ShAmtC =
10611 DAG.getShiftAmountConstant(C1Val.exactLogBase2(), VT, DL);
10612 return DAG.getNode(ISD::SHL, DL, VT, Cond, ShAmtC);
10613 }
10614
10615 // select Cond, -1, C --> or (sext Cond), C
10616 if (C1->isAllOnes()) {
10617 Cond = DAG.getSExtOrTrunc(Cond, DL, VT);
10618 return DAG.getNode(ISD::OR, DL, VT, Cond, N2);
10619 }
10620
10621 // select Cond, C, -1 --> or (sext (not Cond)), C
10622 if (C2->isAllOnes()) {
10623 SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
10624 NotCond = DAG.getSExtOrTrunc(NotCond, DL, VT);
10625 return DAG.getNode(ISD::OR, DL, VT, NotCond, N1);
10626 }
10627
10628 if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
10629 return V;
10630
10631 return SDValue();
10632 }
10633
10634 static SDValue foldBoolSelectToLogic(SDNode *N, SelectionDAG &DAG) {
10635 assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT) &&
10636 "Expected a (v)select");
10637 SDValue Cond = N->getOperand(0);
10638 SDValue T = N->getOperand(1), F = N->getOperand(2);
10639 EVT VT = N->getValueType(0);
10640 if (VT != Cond.getValueType() || VT.getScalarSizeInBits() != 1)
10641 return SDValue();
10642
10643 // select Cond, Cond, F --> or Cond, F
10644 // select Cond, 1, F --> or Cond, F
10645 if (Cond == T || isOneOrOneSplat(T, /* AllowUndefs */ true))
10646 return DAG.getNode(ISD::OR, SDLoc(N), VT, Cond, F);
10647
10648 // select Cond, T, Cond --> and Cond, T
10649 // select Cond, T, 0 --> and Cond, T
10650 if (Cond == F || isNullOrNullSplat(F, /* AllowUndefs */ true))
10651 return DAG.getNode(ISD::AND, SDLoc(N), VT, Cond, T);
10652
10653 // select Cond, T, 1 --> or (not Cond), T
10654 if (isOneOrOneSplat(F, /* AllowUndefs */ true)) {
10655 SDValue NotCond = DAG.getNOT(SDLoc(N), Cond, VT);
10656 return DAG.getNode(ISD::OR, SDLoc(N), VT, NotCond, T);
10657 }
10658
10659 // select Cond, 0, F --> and (not Cond), F
10660 if (isNullOrNullSplat(T, /* AllowUndefs */ true)) {
10661 SDValue NotCond = DAG.getNOT(SDLoc(N), Cond, VT);
10662 return DAG.getNode(ISD::AND, SDLoc(N), VT, NotCond, F);
10663 }
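// All of the folds above rely on the operands being i1 (or an i1 vector);
// e.g. (select Cond, T, 1) and (or (not Cond), T) agree on both values of
// Cond: T when Cond is true, and 1 when Cond is false.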
10664
10665 return SDValue();
10666 }
10667
10668 static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
10669 SDValue N0 = N->getOperand(0);
10670 SDValue N1 = N->getOperand(1);
10671 SDValue N2 = N->getOperand(2);
10672 EVT VT = N->getValueType(0);
10673 if (N0.getOpcode() != ISD::SETCC || !N0.hasOneUse())
10674 return SDValue();
10675
10676 SDValue Cond0 = N0.getOperand(0);
10677 SDValue Cond1 = N0.getOperand(1);
10678 ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
10679 if (VT != Cond0.getValueType())
10680 return SDValue();
10681
10682 // Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the
10683 // compare is inverted from that pattern ("Cond0 s> -1").
10684 if (CC == ISD::SETLT && isNullOrNullSplat(Cond1))
10685 ; // This is the pattern we are looking for.
10686 else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(Cond1))
10687 std::swap(N1, N2);
10688 else
10689 return SDValue();
10690
10691 // (Cond0 s< 0) ? N1 : 0 --> (Cond0 s>> BW-1) & N1
10692 if (isNullOrNullSplat(N2)) {
10693 SDLoc DL(N);
10694 SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
10695 SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
10696 return DAG.getNode(ISD::AND, DL, VT, Sra, N1);
10697 }
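// E.g. for v4i32: (Cond0 s< 0) ? N1 : 0 becomes ((Cond0 s>> 31) & N1); the
// per-lane arithmetic shift produces an all-ones or all-zeros mask that
// selects N1 or 0 without a vector select.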
10698
10699 // (Cond0 s< 0) ? -1 : N2 --> (Cond0 s>> BW-1) | N2
10700 if (isAllOnesOrAllOnesSplat(N1)) {
10701 SDLoc DL(N);
10702 SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
10703 SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
10704 return DAG.getNode(ISD::OR, DL, VT, Sra, N2);
10705 }
10706
10707 // If we have to invert the sign bit mask, only do that transform if the
10708 // target has a bitwise 'and not' instruction (the invert is free).
10709 // (Cond0 s< -0) ? 0 : N2 --> ~(Cond0 s>> BW-1) & N2
10710 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
10711 if (isNullOrNullSplat(N1) && TLI.hasAndNot(N1)) {
10712 SDLoc DL(N);
10713 SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
10714 SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
10715 SDValue Not = DAG.getNOT(DL, Sra, VT);
10716 return DAG.getNode(ISD::AND, DL, VT, Not, N2);
10717 }
10718
10719 // TODO: There's another pattern in this family, but it may require
10720 // implementing hasOrNot() to check for profitability:
10721 // (Cond0 s> -1) ? -1 : N2 --> ~(Cond0 s>> BW-1) | N2
10722
10723 return SDValue();
10724 }
10725
10726 SDValue DAGCombiner::visitSELECT(SDNode *N) {
10727 SDValue N0 = N->getOperand(0);
10728 SDValue N1 = N->getOperand(1);
10729 SDValue N2 = N->getOperand(2);
10730 EVT VT = N->getValueType(0);
10731 EVT VT0 = N0.getValueType();
10732 SDLoc DL(N);
10733 SDNodeFlags Flags = N->getFlags();
10734
10735 if (SDValue V = DAG.simplifySelect(N0, N1, N2))
10736 return V;
10737
10738 if (SDValue V = foldBoolSelectToLogic(N, DAG))
10739 return V;
10740
10741 // select (not Cond), N1, N2 -> select Cond, N2, N1
10742 if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false)) {
10743 SDValue SelectOp = DAG.getSelect(DL, VT, F, N2, N1);
10744 SelectOp->setFlags(Flags);
10745 return SelectOp;
10746 }
10747
10748 if (SDValue V = foldSelectOfConstants(N))
10749 return V;
10750
10751 // If we can fold this based on the true/false value, do so.
10752 if (SimplifySelectOps(N, N1, N2))
10753 return SDValue(N, 0); // Don't revisit N.
10754
10755 if (VT0 == MVT::i1) {
10756 // The code in this block deals with the following 2 equivalences:
10757 // select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
10758 // select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
10759 // The target can specify its preferred form with the
10760 // shouldNormalizeToSelectSequence() callback. However, we always transform
10761 // to the right-hand form if the inner select already exists in the DAG, and
10762 // we always transform to the left-hand form if we know that we can further
10763 // optimize the combination of the conditions.
10764 bool normalizeToSequence =
10765 TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
10766 // select (and Cond0, Cond1), X, Y
10767 // -> select Cond0, (select Cond1, X, Y), Y
10768 if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
10769 SDValue Cond0 = N0->getOperand(0);
10770 SDValue Cond1 = N0->getOperand(1);
10771 SDValue InnerSelect =
10772 DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond1, N1, N2, Flags);
10773 if (normalizeToSequence || !InnerSelect.use_empty())
10774 return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0,
10775 InnerSelect, N2, Flags);
10776 // Cleanup on failure.
10777 if (InnerSelect.use_empty())
10778 recursivelyDeleteUnusedNodes(InnerSelect.getNode());
10779 }
10780 // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
10781 if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
10782 SDValue Cond0 = N0->getOperand(0);
10783 SDValue Cond1 = N0->getOperand(1);
10784 SDValue InnerSelect = DAG.getNode(ISD::SELECT, DL, N1.getValueType(),
10785 Cond1, N1, N2, Flags);
10786 if (normalizeToSequence || !InnerSelect.use_empty())
10787 return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Cond0, N1,
10788 InnerSelect, Flags);
10789 // Cleanup on failure.
10790 if (InnerSelect.use_empty())
10791 recursivelyDeleteUnusedNodes(InnerSelect.getNode());
10792 }
10793
10794 // select Cond0, (select Cond1, X, Y), Y -> select (and Cond0, Cond1), X, Y
10795 if (N1->getOpcode() == ISD::SELECT && N1->hasOneUse()) {
10796 SDValue N1_0 = N1->getOperand(0);
10797 SDValue N1_1 = N1->getOperand(1);
10798 SDValue N1_2 = N1->getOperand(2);
10799 if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
10800 // Create the actual and node if we can generate good code for it.
10801 if (!normalizeToSequence) {
10802 SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0);
10803 return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), And, N1_1,
10804 N2, Flags);
10805 }
10806 // Otherwise see if we can optimize the "and" to a better pattern.
10807 if (SDValue Combined = visitANDLike(N0, N1_0, N)) {
10808 return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1_1,
10809 N2, Flags);
10810 }
10811 }
10812 }
10813 // select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
10814 if (N2->getOpcode() == ISD::SELECT && N2->hasOneUse()) {
10815 SDValue N2_0 = N2->getOperand(0);
10816 SDValue N2_1 = N2->getOperand(1);
10817 SDValue N2_2 = N2->getOperand(2);
10818 if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
10819 // Create the actual or node if we can generate good code for it.
10820 if (!normalizeToSequence) {
10821 SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0);
10822 return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Or, N1,
10823 N2_2, Flags);
10824 }
10825 // Otherwise see if we can optimize to a better pattern.
10826 if (SDValue Combined = visitORLike(N0, N2_0, N))
10827 return DAG.getNode(ISD::SELECT, DL, N1.getValueType(), Combined, N1,
10828 N2_2, Flags);
10829 }
10830 }
10831 }
10832
10833 // Fold selects based on a setcc into other things, such as min/max/abs.
10834 if (N0.getOpcode() == ISD::SETCC) {
10835 SDValue Cond0 = N0.getOperand(0), Cond1 = N0.getOperand(1);
10836 ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
10837
10838 // select (fcmp lt x, y), x, y -> fminnum x, y
10839 // select (fcmp gt x, y), x, y -> fmaxnum x, y
10840 //
10841 // This is OK if we don't care what happens if either operand is a NaN.
10842 if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, N1, N2, TLI))
10843 if (SDValue FMinMax =
10844 combineMinNumMaxNum(DL, VT, Cond0, Cond1, N1, N2, CC))
10845 return FMinMax;
10846
10847 // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
10848 // This is conservatively limited to pre-legal-operations to give targets
10849 // a chance to reverse the transform if they want to do that. Also, it is
10850 // unlikely that the pattern would be formed late, so it's probably not
10851 // worth going through the other checks.
10852 if (!LegalOperations && TLI.isOperationLegalOrCustom(ISD::UADDO, VT) &&
10853 CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(N1) &&
10854 N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(0)) {
10855 auto *C = dyn_cast<ConstantSDNode>(N2.getOperand(1));
10856 auto *NotC = dyn_cast<ConstantSDNode>(Cond1);
10857 if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) {
10858 // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) -->
10859 // uaddo Cond0, C; select uaddo.1, -1, uaddo.0
10860 //
10861 // The IR equivalent of this transform would have this form:
10862 // %a = add %x, C
10863 // %c = icmp ugt %x, ~C
10864 // %r = select %c, -1, %a
10865 // =>
10866 // %u = call {iN,i1} llvm.uadd.with.overflow(%x, C)
10867 // %u0 = extractvalue %u, 0
10868 // %u1 = extractvalue %u, 1
10869 // %r = select %u1, -1, %u0
10870 SDVTList VTs = DAG.getVTList(VT, VT0);
10871 SDValue UAO = DAG.getNode(ISD::UADDO, DL, VTs, Cond0, N2.getOperand(1));
10872 return DAG.getSelect(DL, VT, UAO.getValue(1), N1, UAO.getValue(0));
10873 }
10874 }
10875
10876 if (TLI.isOperationLegal(ISD::SELECT_CC, VT) ||
10877 (!LegalOperations &&
10878 TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT))) {
10879 // Any flags available in a select/setcc fold will be on the setcc as they
10880 // migrated from fcmp
10881 Flags = N0->getFlags();
10882 SDValue SelectNode = DAG.getNode(ISD::SELECT_CC, DL, VT, Cond0, Cond1, N1,
10883 N2, N0.getOperand(2));
10884 SelectNode->setFlags(Flags);
10885 return SelectNode;
10886 }
10887
10888 if (SDValue NewSel = SimplifySelect(DL, N0, N1, N2))
10889 return NewSel;
10890 }
10891
10892 if (!VT.isVector())
10893 if (SDValue BinOp = foldSelectOfBinops(N))
10894 return BinOp;
10895
10896 return SDValue();
10897 }
10898
10899 // This function assumes all the vselect's arguments are CONCAT_VECTOR
10900 // nodes and that the condition is a BV of ConstantSDNodes (or undefs).
10901 static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
10902 SDLoc DL(N);
10903 SDValue Cond = N->getOperand(0);
10904 SDValue LHS = N->getOperand(1);
10905 SDValue RHS = N->getOperand(2);
10906 EVT VT = N->getValueType(0);
10907 int NumElems = VT.getVectorNumElements();
10908 assert(LHS.getOpcode() == ISD::CONCAT_VECTORS &&
10909 RHS.getOpcode() == ISD::CONCAT_VECTORS &&
10910 Cond.getOpcode() == ISD::BUILD_VECTOR);
10911
10912 // CONCAT_VECTOR can take an arbitrary number of arguments. We only care about
10913 // binary ones here.
10914 if (LHS->getNumOperands() != 2 || RHS->getNumOperands() != 2)
10915 return SDValue();
10916
10917 // We're sure we have an even number of elements due to the
10918 // concat_vectors we have as arguments to vselect.
10919 // Skip BV elements until we find one that's not an UNDEF.
10920 // After we find a non-UNDEF element, keep looping until we reach half the
10921 // length of the BV and check that all the non-undef nodes are the same.
10922 ConstantSDNode *BottomHalf = nullptr;
10923 for (int i = 0; i < NumElems / 2; ++i) {
10924 if (Cond->getOperand(i)->isUndef())
10925 continue;
10926
10927 if (BottomHalf == nullptr)
10928 BottomHalf = cast<ConstantSDNode>(Cond.getOperand(i));
10929 else if (Cond->getOperand(i).getNode() != BottomHalf)
10930 return SDValue();
10931 }
10932
10933 // Do the same for the second half of the BuildVector
10934 ConstantSDNode *TopHalf = nullptr;
10935 for (int i = NumElems / 2; i < NumElems; ++i) {
10936 if (Cond->getOperand(i)->isUndef())
10937 continue;
10938
10939 if (TopHalf == nullptr)
10940 TopHalf = cast<ConstantSDNode>(Cond.getOperand(i));
10941 else if (Cond->getOperand(i).getNode() != TopHalf)
10942 return SDValue();
10943 }
10944
10945 assert(TopHalf && BottomHalf &&
10946 "One half of the selector was all UNDEFs and the other was all the "
10947 "same value. This should have been addressed before this function.");
10948 return DAG.getNode(
10949 ISD::CONCAT_VECTORS, DL, VT,
10950 BottomHalf->isZero() ? RHS->getOperand(0) : LHS->getOperand(0),
10951 TopHalf->isZero() ? RHS->getOperand(1) : LHS->getOperand(1));
10952 }
10953
10954 bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
10955 SelectionDAG &DAG, const SDLoc &DL) {
10956 if (Index.getOpcode() != ISD::ADD)
10957 return false;
10958
10959 // Only perform the transformation when existing operands can be reused.
10960 if (IndexIsScaled)
10961 return false;
10962
10963 if (!isNullConstant(BasePtr) && !Index.hasOneUse())
10964 return false;
10965
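// E.g. an index of the form (add (splat X), V) used with a null base pointer
// can be refolded so that X becomes the uniform base and V the per-lane
// index, letting targets use a scalar base register for the access.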
10966 EVT VT = BasePtr.getValueType();
10967 if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(0));
10968 SplatVal && SplatVal.getValueType() == VT) {
10969 if (isNullConstant(BasePtr))
10970 BasePtr = SplatVal;
10971 else
10972 BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
10973 Index = Index.getOperand(1);
10974 return true;
10975 }
10976 if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(1));
10977 SplatVal && SplatVal.getValueType() == VT) {
10978 if (isNullConstant(BasePtr))
10979 BasePtr = SplatVal;
10980 else
10981 BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
10982 Index = Index.getOperand(0);
10983 return true;
10984 }
10985 return false;
10986 }
10987
10988 // Fold sext/zext of index into index type.
10989 bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType, EVT DataVT,
10990 SelectionDAG &DAG) {
10991 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
10992
10993 // It's always safe to look through zero extends.
10994 if (Index.getOpcode() == ISD::ZERO_EXTEND) {
10995 SDValue Op = Index.getOperand(0);
10996 if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType(), DataVT)) {
10997 IndexType = ISD::UNSIGNED_SCALED;
10998 Index = Op;
10999 return true;
11000 }
11001 if (ISD::isIndexTypeSigned(IndexType)) {
11002 IndexType = ISD::UNSIGNED_SCALED;
11003 return true;
11004 }
11005 }
11006
11007 // It's only safe to look through sign extends when Index is signed.
11008 if (Index.getOpcode() == ISD::SIGN_EXTEND &&
11009 ISD::isIndexTypeSigned(IndexType)) {
11010 SDValue Op = Index.getOperand(0);
11011 if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType(), DataVT)) {
11012 Index = Op;
11013 return true;
11014 }
11015 }
11016
11017 return false;
11018 }
11019
11020 SDValue DAGCombiner::visitVPSCATTER(SDNode *N) {
11021 VPScatterSDNode *MSC = cast<VPScatterSDNode>(N);
11022 SDValue Mask = MSC->getMask();
11023 SDValue Chain = MSC->getChain();
11024 SDValue Index = MSC->getIndex();
11025 SDValue Scale = MSC->getScale();
11026 SDValue StoreVal = MSC->getValue();
11027 SDValue BasePtr = MSC->getBasePtr();
11028 SDValue VL = MSC->getVectorLength();
11029 ISD::MemIndexType IndexType = MSC->getIndexType();
11030 SDLoc DL(N);
11031
11032 // Zap scatters with a zero mask.
11033 if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
11034 return Chain;
11035
11036 if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
11037 SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
11038 return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11039 DL, Ops, MSC->getMemOperand(), IndexType);
11040 }
11041
11042 if (refineIndexType(Index, IndexType, StoreVal.getValueType(), DAG)) {
11043 SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
11044 return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11045 DL, Ops, MSC->getMemOperand(), IndexType);
11046 }
11047
11048 return SDValue();
11049 }
11050
11051 SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
11052 MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
11053 SDValue Mask = MSC->getMask();
11054 SDValue Chain = MSC->getChain();
11055 SDValue Index = MSC->getIndex();
11056 SDValue Scale = MSC->getScale();
11057 SDValue StoreVal = MSC->getValue();
11058 SDValue BasePtr = MSC->getBasePtr();
11059 ISD::MemIndexType IndexType = MSC->getIndexType();
11060 SDLoc DL(N);
11061
11062 // Zap scatters with a zero mask.
11063 if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
11064 return Chain;
11065
11066 if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
11067 SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
11068 return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11069 DL, Ops, MSC->getMemOperand(), IndexType,
11070 MSC->isTruncatingStore());
11071 }
11072
11073 if (refineIndexType(Index, IndexType, StoreVal.getValueType(), DAG)) {
11074 SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
11075 return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11076 DL, Ops, MSC->getMemOperand(), IndexType,
11077 MSC->isTruncatingStore());
11078 }
11079
11080 return SDValue();
11081 }
11082
11083 SDValue DAGCombiner::visitMSTORE(SDNode *N) {
11084 MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
11085 SDValue Mask = MST->getMask();
11086 SDValue Chain = MST->getChain();
11087 SDValue Value = MST->getValue();
11088 SDValue Ptr = MST->getBasePtr();
11089 SDLoc DL(N);
11090
11091 // Zap masked stores with a zero mask.
11092 if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
11093 return Chain;
11094
11095 // If this is a masked store with an all-ones mask, we can use an unmasked store.
11096 // FIXME: Can we do this for indexed, compressing, or truncating stores?
11097 if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MST->isUnindexed() &&
11098 !MST->isCompressingStore() && !MST->isTruncatingStore())
11099 return DAG.getStore(MST->getChain(), SDLoc(N), MST->getValue(),
11100 MST->getBasePtr(), MST->getPointerInfo(),
11101 MST->getOriginalAlign(), MachineMemOperand::MOStore,
11102 MST->getAAInfo());
11103
11104 // Try transforming N to an indexed store.
11105 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
11106 return SDValue(N, 0);
11107
11108 if (MST->isTruncatingStore() && MST->isUnindexed() &&
11109 Value.getValueType().isInteger() &&
11110 (!isa<ConstantSDNode>(Value) ||
11111 !cast<ConstantSDNode>(Value)->isOpaque())) {
11112 APInt TruncDemandedBits =
11113 APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
11114 MST->getMemoryVT().getScalarSizeInBits());
11115
11116 // See if we can simplify the operation with
11117 // SimplifyDemandedBits, which only works if the value has a single use.
11118 if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
11119 // Re-visit the store if anything changed and the store hasn't been merged
11120 // with another node (N is deleted) SimplifyDemandedBits will add Value's
11121 // node back to the worklist if necessary, but we also need to re-visit
11122 // the Store node itself.
11123 if (N->getOpcode() != ISD::DELETED_NODE)
11124 AddToWorklist(N);
11125 return SDValue(N, 0);
11126 }
11127 }
11128
11129 // If this is a TRUNC followed by a masked store, fold this into a masked
11130 // truncating store. We can do this even if this is already a masked
11131 // truncstore.
11132 // TODO: Try combining to a masked compress store if possible.
11133 if ((Value.getOpcode() == ISD::TRUNCATE) && Value->hasOneUse() &&
11134 MST->isUnindexed() && !MST->isCompressingStore() &&
11135 TLI.canCombineTruncStore(Value.getOperand(0).getValueType(),
11136 MST->getMemoryVT(), LegalOperations)) {
11137 auto Mask = TLI.promoteTargetBoolean(DAG, MST->getMask(),
11138 Value.getOperand(0).getValueType());
11139 return DAG.getMaskedStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
11140 MST->getOffset(), Mask, MST->getMemoryVT(),
11141 MST->getMemOperand(), MST->getAddressingMode(),
11142 /*IsTruncating=*/true);
11143 }
11144
11145 return SDValue();
11146 }
11147
11148 SDValue DAGCombiner::visitVPGATHER(SDNode *N) {
11149 VPGatherSDNode *MGT = cast<VPGatherSDNode>(N);
11150 SDValue Mask = MGT->getMask();
11151 SDValue Chain = MGT->getChain();
11152 SDValue Index = MGT->getIndex();
11153 SDValue Scale = MGT->getScale();
11154 SDValue BasePtr = MGT->getBasePtr();
11155 SDValue VL = MGT->getVectorLength();
11156 ISD::MemIndexType IndexType = MGT->getIndexType();
11157 SDLoc DL(N);
11158
11159 if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
11160 SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
11161 return DAG.getGatherVP(
11162 DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
11163 Ops, MGT->getMemOperand(), IndexType);
11164 }
11165
11166 if (refineIndexType(Index, IndexType, N->getValueType(0), DAG)) {
11167 SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
11168 return DAG.getGatherVP(
11169 DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
11170 Ops, MGT->getMemOperand(), IndexType);
11171 }
11172
11173 return SDValue();
11174 }
11175
11176 SDValue DAGCombiner::visitMGATHER(SDNode *N) {
11177 MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(N);
11178 SDValue Mask = MGT->getMask();
11179 SDValue Chain = MGT->getChain();
11180 SDValue Index = MGT->getIndex();
11181 SDValue Scale = MGT->getScale();
11182 SDValue PassThru = MGT->getPassThru();
11183 SDValue BasePtr = MGT->getBasePtr();
11184 ISD::MemIndexType IndexType = MGT->getIndexType();
11185 SDLoc DL(N);
11186
11187 // Zap gathers with a zero mask.
11188 if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
11189 return CombineTo(N, PassThru, MGT->getChain());
11190
11191 if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
11192 SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
11193 return DAG.getMaskedGather(
11194 DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
11195 Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
11196 }
11197
11198 if (refineIndexType(Index, IndexType, N->getValueType(0), DAG)) {
11199 SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
11200 return DAG.getMaskedGather(
11201 DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
11202 Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
11203 }
11204
11205 return SDValue();
11206 }
11207
11208 SDValue DAGCombiner::visitMLOAD(SDNode *N) {
11209 MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(N);
11210 SDValue Mask = MLD->getMask();
11211 SDLoc DL(N);
11212
11213 // Zap masked loads with a zero mask.
11214 if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
11215 return CombineTo(N, MLD->getPassThru(), MLD->getChain());
11216
11217 // If this is a masked load with an all-ones mask, we can use an unmasked load.
11218 // FIXME: Can we do this for indexed, expanding, or extending loads?
11219 if (ISD::isConstantSplatVectorAllOnes(Mask.getNode()) && MLD->isUnindexed() &&
11220 !MLD->isExpandingLoad() && MLD->getExtensionType() == ISD::NON_EXTLOAD) {
11221 SDValue NewLd = DAG.getLoad(
11222 N->getValueType(0), SDLoc(N), MLD->getChain(), MLD->getBasePtr(),
11223 MLD->getPointerInfo(), MLD->getOriginalAlign(),
11224 MachineMemOperand::MOLoad, MLD->getAAInfo(), MLD->getRanges());
11225 return CombineTo(N, NewLd, NewLd.getValue(1));
11226 }
11227
11228 // Try transforming N to an indexed load.
11229 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
11230 return SDValue(N, 0);
11231
11232 return SDValue();
11233 }
11234
11235 /// A vector select of 2 constant vectors can be simplified to math/logic to
11236 /// avoid a variable select instruction and possibly avoid constant loads.
11237 SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
11238 SDValue Cond = N->getOperand(0);
11239 SDValue N1 = N->getOperand(1);
11240 SDValue N2 = N->getOperand(2);
11241 EVT VT = N->getValueType(0);
11242 if (!Cond.hasOneUse() || Cond.getScalarValueSizeInBits() != 1 ||
11243 !shouldConvertSelectOfConstantsToMath(Cond, VT, TLI) ||
11244 !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()) ||
11245 !ISD::isBuildVectorOfConstantSDNodes(N2.getNode()))
11246 return SDValue();
11247
11248 // Check if we can use the condition value to increment/decrement a single
11249 // constant value. This simplifies a select to an add and removes a constant
11250 // load/materialization from the general case.
11251 bool AllAddOne = true;
11252 bool AllSubOne = true;
11253 unsigned Elts = VT.getVectorNumElements();
11254 for (unsigned i = 0; i != Elts; ++i) {
11255 SDValue N1Elt = N1.getOperand(i);
11256 SDValue N2Elt = N2.getOperand(i);
11257 if (N1Elt.isUndef() || N2Elt.isUndef())
11258 continue;
11259 if (N1Elt.getValueType() != N2Elt.getValueType())
11260 continue;
11261
11262 const APInt &C1 = cast<ConstantSDNode>(N1Elt)->getAPIntValue();
11263 const APInt &C2 = cast<ConstantSDNode>(N2Elt)->getAPIntValue();
11264 if (C1 != C2 + 1)
11265 AllAddOne = false;
11266 if (C1 != C2 - 1)
11267 AllSubOne = false;
11268 }
11269
11270 // Further simplifications for the extra-special cases where the constants are
11271 // all 0 or all -1 should be implemented as folds of these patterns.
11272 SDLoc DL(N);
11273 if (AllAddOne || AllSubOne) {
11274 // vselect <N x i1> Cond, C+1, C --> add (zext Cond), C
11275 // vselect <N x i1> Cond, C-1, C --> add (sext Cond), C
11276 auto ExtendOpcode = AllAddOne ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
11277 SDValue ExtendedCond = DAG.getNode(ExtendOpcode, DL, VT, Cond);
11278 return DAG.getNode(ISD::ADD, DL, VT, ExtendedCond, N2);
11279 }
11280
11281 // select Cond, Pow2C, 0 --> (zext Cond) << log2(Pow2C)
11282 APInt Pow2C;
11283 if (ISD::isConstantSplatVector(N1.getNode(), Pow2C) && Pow2C.isPowerOf2() &&
11284 isNullOrNullSplat(N2)) {
11285 SDValue ZextCond = DAG.getZExtOrTrunc(Cond, DL, VT);
11286 SDValue ShAmtC = DAG.getConstant(Pow2C.exactLogBase2(), DL, VT);
11287 return DAG.getNode(ISD::SHL, DL, VT, ZextCond, ShAmtC);
11288 }
11289
11290 if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
11291 return V;
11292
11293 // The general case for select-of-constants:
11294 // vselect <N x i1> Cond, C1, C2 --> xor (and (sext Cond), (C1^C2)), C2
11295 // ...but that only makes sense if a vselect is slower than 2 logic ops, so
11296 // leave that to a machine-specific pass.
11297 return SDValue();
11298 }
11299
11300 SDValue DAGCombiner::visitVSELECT(SDNode *N) {
11301 SDValue N0 = N->getOperand(0);
11302 SDValue N1 = N->getOperand(1);
11303 SDValue N2 = N->getOperand(2);
11304 EVT VT = N->getValueType(0);
11305 SDLoc DL(N);
11306
11307 if (SDValue V = DAG.simplifySelect(N0, N1, N2))
11308 return V;
11309
11310 if (SDValue V = foldBoolSelectToLogic(N, DAG))
11311 return V;
11312
11313 // vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
11314 if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
11315 return DAG.getSelect(DL, VT, F, N2, N1);
11316
11317 // Canonicalize integer abs.
11318 // vselect (setg[te] X, 0), X, -X ->
11319 // vselect (setgt X, -1), X, -X ->
11320 // vselect (setl[te] X, 0), -X, X ->
11321 // Y = sra (X, size(X)-1); xor (add (X, Y), Y)
11322 if (N0.getOpcode() == ISD::SETCC) {
11323 SDValue LHS = N0.getOperand(0), RHS = N0.getOperand(1);
11324 ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
11325 bool isAbs = false;
11326 bool RHSIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode());
11327
11328 if (((RHSIsAllZeros && (CC == ISD::SETGT || CC == ISD::SETGE)) ||
11329 (ISD::isBuildVectorAllOnes(RHS.getNode()) && CC == ISD::SETGT)) &&
11330 N1 == LHS && N2.getOpcode() == ISD::SUB && N1 == N2.getOperand(1))
11331 isAbs = ISD::isBuildVectorAllZeros(N2.getOperand(0).getNode());
11332 else if ((RHSIsAllZeros && (CC == ISD::SETLT || CC == ISD::SETLE)) &&
11333 N2 == LHS && N1.getOpcode() == ISD::SUB && N2 == N1.getOperand(1))
11334 isAbs = ISD::isBuildVectorAllZeros(N1.getOperand(0).getNode());
11335
11336 if (isAbs) {
11337 if (TLI.isOperationLegalOrCustom(ISD::ABS, VT))
11338 return DAG.getNode(ISD::ABS, DL, VT, LHS);
11339
11340 SDValue Shift = DAG.getNode(ISD::SRA, DL, VT, LHS,
11341 DAG.getConstant(VT.getScalarSizeInBits() - 1,
11342 DL, getShiftAmountTy(VT)));
11343 SDValue Add = DAG.getNode(ISD::ADD, DL, VT, LHS, Shift);
11344 AddToWorklist(Shift.getNode());
11345 AddToWorklist(Add.getNode());
11346 return DAG.getNode(ISD::XOR, DL, VT, Add, Shift);
11347 }
11348
11349 // vselect x, y (fcmp lt x, y) -> fminnum x, y
11350 // vselect x, y (fcmp gt x, y) -> fmaxnum x, y
11351 //
11352 // This is OK if we don't care about what happens if either operand is a
11353 // NaN.
11354 //
11355 if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, LHS, RHS, TLI)) {
11356 if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, LHS, RHS, N1, N2, CC))
11357 return FMinMax;
11358 }
11359
11360 if (SDValue S = PerformMinMaxFpToSatCombine(LHS, RHS, N1, N2, CC, DAG))
11361 return S;
11362 if (SDValue S = PerformUMinFpToSatCombine(LHS, RHS, N1, N2, CC, DAG))
11363 return S;
11364
11365 // If this select has a condition (setcc) with narrower operands than the
11366 // select, try to widen the compare to match the select width.
11367 // TODO: This should be extended to handle any constant.
11368 // TODO: This could be extended to handle non-loading patterns, but that
11369 // requires thorough testing to avoid regressions.
11370 if (isNullOrNullSplat(RHS)) {
11371 EVT NarrowVT = LHS.getValueType();
11372 EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger();
11373 EVT SetCCVT = getSetCCResultType(LHS.getValueType());
11374 unsigned SetCCWidth = SetCCVT.getScalarSizeInBits();
11375 unsigned WideWidth = WideVT.getScalarSizeInBits();
11376 bool IsSigned = isSignedIntSetCC(CC);
11377 auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
11378 if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
11379 SetCCWidth != 1 && SetCCWidth < WideWidth &&
11380 TLI.isLoadExtLegalOrCustom(LoadExtOpcode, WideVT, NarrowVT) &&
11381 TLI.isOperationLegalOrCustom(ISD::SETCC, WideVT)) {
11382 // Both compare operands can be widened for free. The LHS can use an
11383 // extended load, and the RHS is a constant:
11384 // vselect (ext (setcc load(X), C)), N1, N2 -->
11385 // vselect (setcc extload(X), C'), N1, N2
11386 auto ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
11387 SDValue WideLHS = DAG.getNode(ExtOpcode, DL, WideVT, LHS);
11388 SDValue WideRHS = DAG.getNode(ExtOpcode, DL, WideVT, RHS);
11389 EVT WideSetCCVT = getSetCCResultType(WideVT);
11390 SDValue WideSetCC = DAG.getSetCC(DL, WideSetCCVT, WideLHS, WideRHS, CC);
11391 return DAG.getSelect(DL, N1.getValueType(), WideSetCC, N1, N2);
11392 }
11393 }
11394
11395 // Match VSELECTs into add with unsigned saturation.
11396 if (hasOperation(ISD::UADDSAT, VT)) {
11397 // Check if one of the arms of the VSELECT is vector with all bits set.
11398 // If it's on the left side invert the predicate to simplify logic below.
11399 SDValue Other;
11400 ISD::CondCode SatCC = CC;
11401 if (ISD::isConstantSplatVectorAllOnes(N1.getNode())) {
11402 Other = N2;
11403 SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
11404 } else if (ISD::isConstantSplatVectorAllOnes(N2.getNode())) {
11405 Other = N1;
11406 }
11407
11408 if (Other && Other.getOpcode() == ISD::ADD) {
11409 SDValue CondLHS = LHS, CondRHS = RHS;
11410 SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);
11411
11412 // Canonicalize condition operands.
11413 if (SatCC == ISD::SETUGE) {
11414 std::swap(CondLHS, CondRHS);
11415 SatCC = ISD::SETULE;
11416 }
11417
11418 // We can test against either of the addition operands.
11419 // x <= x+y ? x+y : ~0 --> uaddsat x, y
11420 // x+y >= x ? x+y : ~0 --> uaddsat x, y
11421 if (SatCC == ISD::SETULE && Other == CondRHS &&
11422 (OpLHS == CondLHS || OpRHS == CondLHS))
11423 return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
11424
11425 if (OpRHS.getOpcode() == CondRHS.getOpcode() &&
11426 (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
11427 OpRHS.getOpcode() == ISD::SPLAT_VECTOR) &&
11428 CondLHS == OpLHS) {
11429 // If the RHS is a constant we have to reverse the const
11430 // canonicalization.
11431 // x >= ~C ? x+C : ~0 --> uaddsat x, C
11432 auto MatchUADDSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
11433 return Cond->getAPIntValue() == ~Op->getAPIntValue();
11434 };
11435 if (SatCC == ISD::SETULE &&
11436 ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUADDSAT))
11437 return DAG.getNode(ISD::UADDSAT, DL, VT, OpLHS, OpRHS);
11438 }
11439 }
11440 }
11441
11442 // Match VSELECTs into sub with unsigned saturation.
11443 if (hasOperation(ISD::USUBSAT, VT)) {
11444 // Check if one of the arms of the VSELECT is a zero vector. If it's on
11445 // the left side invert the predicate to simplify logic below.
11446 SDValue Other;
11447 ISD::CondCode SatCC = CC;
11448 if (ISD::isConstantSplatVectorAllZeros(N1.getNode())) {
11449 Other = N2;
11450 SatCC = ISD::getSetCCInverse(SatCC, VT.getScalarType());
11451 } else if (ISD::isConstantSplatVectorAllZeros(N2.getNode())) {
11452 Other = N1;
11453 }
11454
11455 // zext(x) >= y ? trunc(zext(x) - y) : 0
11456 // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
11457 // zext(x) > y ? trunc(zext(x) - y) : 0
11458 // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
11459 if (Other && Other.getOpcode() == ISD::TRUNCATE &&
11460 Other.getOperand(0).getOpcode() == ISD::SUB &&
11461 (SatCC == ISD::SETUGE || SatCC == ISD::SETUGT)) {
11462 SDValue OpLHS = Other.getOperand(0).getOperand(0);
11463 SDValue OpRHS = Other.getOperand(0).getOperand(1);
11464 if (LHS == OpLHS && RHS == OpRHS && LHS.getOpcode() == ISD::ZERO_EXTEND)
11465 if (SDValue R = getTruncatedUSUBSAT(VT, LHS.getValueType(), LHS, RHS,
11466 DAG, DL))
11467 return R;
11468 }
11469
11470 if (Other && Other.getNumOperands() == 2) {
11471 SDValue CondRHS = RHS;
11472 SDValue OpLHS = Other.getOperand(0), OpRHS = Other.getOperand(1);
11473
11474 if (OpLHS == LHS) {
11475 // Look for a general sub with unsigned saturation first.
11476 // x >= y ? x-y : 0 --> usubsat x, y
11477 // x > y ? x-y : 0 --> usubsat x, y
11478 if ((SatCC == ISD::SETUGE || SatCC == ISD::SETUGT) &&
11479 Other.getOpcode() == ISD::SUB && OpRHS == CondRHS)
11480 return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
11481
11482 if (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
11483 OpRHS.getOpcode() == ISD::SPLAT_VECTOR) {
11484 if (CondRHS.getOpcode() == ISD::BUILD_VECTOR ||
11485 CondRHS.getOpcode() == ISD::SPLAT_VECTOR) {
11486 // If the RHS is a constant we have to reverse the const
11487 // canonicalization.
11488 // x > C-1 ? x+-C : 0 --> usubsat x, C
11489 auto MatchUSUBSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
11490 return (!Op && !Cond) ||
11491 (Op && Cond &&
11492 Cond->getAPIntValue() == (-Op->getAPIntValue() - 1));
11493 };
11494 if (SatCC == ISD::SETUGT && Other.getOpcode() == ISD::ADD &&
11495 ISD::matchBinaryPredicate(OpRHS, CondRHS, MatchUSUBSAT,
11496 /*AllowUndefs*/ true)) {
11497 OpRHS = DAG.getNegative(OpRHS, DL, VT);
11498 return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
11499 }
11500
11501 // Another special case: If C was a sign bit, the sub has been
11502 // canonicalized into a xor.
11503 // FIXME: Would it be better to use computeKnownBits to
11504 // determine whether it's safe to decanonicalize the xor?
11505 // x s< 0 ? x^C : 0 --> usubsat x, C
11506 APInt SplatValue;
11507 if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
11508 ISD::isConstantSplatVector(OpRHS.getNode(), SplatValue) &&
11509 ISD::isConstantSplatVectorAllZeros(CondRHS.getNode()) &&
11510 SplatValue.isSignMask()) {
11511 // Note that we have to rebuild the RHS constant here to
11512 // ensure we don't rely on particular values of undef lanes.
11513 OpRHS = DAG.getConstant(SplatValue, DL, VT);
11514 return DAG.getNode(ISD::USUBSAT, DL, VT, OpLHS, OpRHS);
11515 }
11516 }
11517 }
11518 }
11519 }
11520 }
11521 }
11522
11523 if (SimplifySelectOps(N, N1, N2))
11524 return SDValue(N, 0); // Don't revisit N.
11525
11526 // Fold (vselect all_ones, N1, N2) -> N1
11527 if (ISD::isConstantSplatVectorAllOnes(N0.getNode()))
11528 return N1;
11529 // Fold (vselect all_zeros, N1, N2) -> N2
11530 if (ISD::isConstantSplatVectorAllZeros(N0.getNode()))
11531 return N2;
11532
11533 // The ConvertSelectToConcatVector function assumes both of the above
11534 // checks for (vselect (build_vector all{ones,zeros}) ...) have already
11535 // been made and addressed.
11536 if (N1.getOpcode() == ISD::CONCAT_VECTORS &&
11537 N2.getOpcode() == ISD::CONCAT_VECTORS &&
11538 ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) {
11539 if (SDValue CV = ConvertSelectToConcatVector(N, DAG))
11540 return CV;
11541 }
11542
11543 if (SDValue V = foldVSelectOfConstants(N))
11544 return V;
11545
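// If vector arithmetic shifts are available, a vselect whose condition is a
// sign-bit test can often be rewritten as a splatted sign mask plus logic ops.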
11546 if (hasOperation(ISD::SRA, VT))
11547 if (SDValue V = foldVSelectToSignBitSplatMask(N, DAG))
11548 return V;
11549
11550 if (SimplifyDemandedVectorElts(SDValue(N, 0)))
11551 return SDValue(N, 0);
11552
11553 return SDValue();
11554 }
11555
11556 SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
11557 SDValue N0 = N->getOperand(0);
11558 SDValue N1 = N->getOperand(1);
11559 SDValue N2 = N->getOperand(2);
11560 SDValue N3 = N->getOperand(3);
11561 SDValue N4 = N->getOperand(4);
11562 ISD::CondCode CC = cast<CondCodeSDNode>(N4)->get();
11563
11564 // fold select_cc lhs, rhs, x, x, cc -> x
11565 if (N2 == N3)
11566 return N2;
11567
11568 // select_cc bool, 0, x, y, seteq -> select bool, y, x
11569 if (CC == ISD::SETEQ && !LegalTypes && N0.getValueType() == MVT::i1 &&
11570 isNullConstant(N1))
11571 return DAG.getSelect(SDLoc(N), N2.getValueType(), N0, N3, N2);
11572
11573 // Determine if the condition we're dealing with is constant
11574 if (SDValue SCC = SimplifySetCC(getSetCCResultType(N0.getValueType()), N0, N1,
11575 CC, SDLoc(N), false)) {
11576 AddToWorklist(SCC.getNode());
11577
11578 // cond always true -> true val
11579 // cond always false -> false val
11580 if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC.getNode()))
11581 return SCCC->isZero() ? N3 : N2;
11582
11583 // When the condition is UNDEF, just return the first operand. This is
11584 // consistent with how the DAG is created: no setcc node is created in this case.
11585 if (SCC->isUndef())
11586 return N2;
11587
11588 // Fold to a simpler select_cc
11589 if (SCC.getOpcode() == ISD::SETCC) {
11590 SDValue SelectOp = DAG.getNode(
11591 ISD::SELECT_CC, SDLoc(N), N2.getValueType(), SCC.getOperand(0),
11592 SCC.getOperand(1), N2, N3, SCC.getOperand(2));
11593 SelectOp->setFlags(SCC->getFlags());
11594 return SelectOp;
11595 }
11596 }
11597
11598 // If we can fold this based on the true/false value, do so.
11599 if (SimplifySelectOps(N, N2, N3))
11600 return SDValue(N, 0); // Don't revisit N.
11601
11602 // fold select_cc into other things, such as min/max/abs
11603 return SimplifySelectCC(SDLoc(N), N0, N1, N2, N3, CC);
11604 }
11605
11606 SDValue DAGCombiner::visitSETCC(SDNode *N) {
11607 // setcc is very commonly used as an argument to brcond. This pattern
11608 // also lends itself to numerous combines and, as a result, it is desirable
11609 // to keep the argument to a brcond as a setcc as much as possible.
11610 bool PreferSetCC =
11611 N->hasOneUse() && N->use_begin()->getOpcode() == ISD::BRCOND;
11612
11613 ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
11614 EVT VT = N->getValueType(0);
11615
11616 // SETCC(FREEZE(X), CONST, Cond)
11617 // =>
11618 // FREEZE(SETCC(X, CONST, Cond))
11619 // This is correct if FREEZE(X) has one use and SETCC(FREEZE(X), CONST, Cond)
11620 // isn't equivalent to true or false.
11621 // For example, SETCC(FREEZE(X), -128, SETULT) cannot be folded to
11622 // FREEZE(SETCC(X, -128, SETULT)) because X can be poison.
11623 //
11624 // This transformation is beneficial because visitBRCOND can fold
11625 // BRCOND(FREEZE(X)) to BRCOND(X).
11626
11627 // Conservatively optimize integer comparisons only.
11628 if (PreferSetCC) {
11629 // Do this only when SETCC is going to be used by BRCOND.
11630
11631 SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
11632 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
11633 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
11634 bool Updated = false;
11635
11636 // Is 'X Cond C' always true or false?
11637 auto IsAlwaysTrueOrFalse = [](ISD::CondCode Cond, ConstantSDNode *C) {
11638 bool False = (Cond == ISD::SETULT && C->isZero()) ||
11639 (Cond == ISD::SETLT && C->isMinSignedValue()) ||
11640 (Cond == ISD::SETUGT && C->isAllOnes()) ||
11641 (Cond == ISD::SETGT && C->isMaxSignedValue());
11642 bool True = (Cond == ISD::SETULE && C->isAllOnes()) ||
11643 (Cond == ISD::SETLE && C->isMaxSignedValue()) ||
11644 (Cond == ISD::SETUGE && C->isZero()) ||
11645 (Cond == ISD::SETGE && C->isMinSignedValue());
11646 return True || False;
11647 };
11648
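// Strip the FREEZE from an operand that is compared against a constant, as
// long as the comparison is not a tautology; the rebuilt setcc is frozen
// again below.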
11649 if (N0->getOpcode() == ISD::FREEZE && N0.hasOneUse() && N1C) {
11650 if (!IsAlwaysTrueOrFalse(Cond, N1C)) {
11651 N0 = N0->getOperand(0);
11652 Updated = true;
11653 }
11654 }
11655 if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse() && N0C) {
11656 if (!IsAlwaysTrueOrFalse(ISD::getSetCCSwappedOperands(Cond),
11657 N0C)) {
11658 N1 = N1->getOperand(0);
11659 Updated = true;
11660 }
11661 }
11662
11663 if (Updated)
11664 return DAG.getFreeze(DAG.getSetCC(SDLoc(N), VT, N0, N1, Cond));
11665 }
11666
11667 SDValue Combined = SimplifySetCC(VT, N->getOperand(0), N->getOperand(1), Cond,
11668 SDLoc(N), !PreferSetCC);
11669
11670 if (!Combined)
11671 return SDValue();
11672
11673 // If we prefer to have a setcc, and we don't, we'll try our best to
11674 // recreate one using rebuildSetCC.
11675 if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
11676 SDValue NewSetCC = rebuildSetCC(Combined);
11677
11678 // We don't have anything interesting to combine to.
11679 if (NewSetCC.getNode() == N)
11680 return SDValue();
11681
11682 if (NewSetCC)
11683 return NewSetCC;
11684 }
11685
11686 return Combined;
11687 }
11688
11689 SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
11690 SDValue LHS = N->getOperand(0);
11691 SDValue RHS = N->getOperand(1);
11692 SDValue Carry = N->getOperand(2);
11693 SDValue Cond = N->getOperand(3);
11694
11695 // If Carry is false, fold to a regular SETCC.
11696 if (isNullConstant(Carry))
11697 return DAG.getNode(ISD::SETCC, SDLoc(N), N->getVTList(), LHS, RHS, Cond);
11698
11699 return SDValue();
11700 }
11701
11702 /// Check if N satisfies:
11703 /// N is used once.
11704 /// N is a Load.
11705 /// The load is compatible with ExtOpcode. It means:
11706 /// If the load has an explicit zero/sign extension, ExtOpcode must have the
11707 /// same extension.
11708 /// Otherwise, any extension is accepted.
11709 static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) {
11710 if (!N.hasOneUse())
11711 return false;
11712
11713 if (!isa<LoadSDNode>(N))
11714 return false;
11715
11716 LoadSDNode *Load = cast<LoadSDNode>(N);
11717 ISD::LoadExtType LoadExt = Load->getExtensionType();
11718 if (LoadExt == ISD::NON_EXTLOAD || LoadExt == ISD::EXTLOAD)
11719 return true;
11720
11721 // Now LoadExt is either SEXTLOAD or ZEXTLOAD, ExtOpcode must have the same
11722 // extension.
11723 if ((LoadExt == ISD::SEXTLOAD && ExtOpcode != ISD::SIGN_EXTEND) ||
11724 (LoadExt == ISD::ZEXTLOAD && ExtOpcode != ISD::ZERO_EXTEND))
11725 return false;
11726
11727 return true;
11728 }
11729
11730 /// Fold
11731 /// (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
11732 /// (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
11733 /// (aext (select c, load x, load y)) -> (select c, extload x, extload y)
11734 /// This function is called by the DAGCombiner when visiting sext/zext/aext
11735 /// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
11736 static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
11737 SelectionDAG &DAG) {
11738 unsigned Opcode = N->getOpcode();
11739 SDValue N0 = N->getOperand(0);
11740 EVT VT = N->getValueType(0);
11741 SDLoc DL(N);
11742
11743 assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
11744 Opcode == ISD::ANY_EXTEND) &&
11745 "Expected EXTEND dag node in input!");
11746
11747 if (!(N0->getOpcode() == ISD::SELECT || N0->getOpcode() == ISD::VSELECT) ||
11748 !N0.hasOneUse())
11749 return SDValue();
11750
11751 SDValue Op1 = N0->getOperand(1);
11752 SDValue Op2 = N0->getOperand(2);
11753 if (!isCompatibleLoad(Op1, Opcode) || !isCompatibleLoad(Op2, Opcode))
11754 return SDValue();
11755
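// Pick the extending-load flavour that matches the extend being folded;
// ANY_EXTEND maps to the generic EXTLOAD.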
11756 auto ExtLoadOpcode = ISD::EXTLOAD;
11757 if (Opcode == ISD::SIGN_EXTEND)
11758 ExtLoadOpcode = ISD::SEXTLOAD;
11759 else if (Opcode == ISD::ZERO_EXTEND)
11760 ExtLoadOpcode = ISD::ZEXTLOAD;
11761
11762 LoadSDNode *Load1 = cast<LoadSDNode>(Op1);
11763 LoadSDNode *Load2 = cast<LoadSDNode>(Op2);
11764 if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) ||
11765 !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT()))
11766 return SDValue();
11767
11768 SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1);
11769 SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Op2);
11770 return DAG.getSelect(DL, VT, N0->getOperand(0), Ext1, Ext2);
11771 }
11772
11773 /// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
11774 /// a build_vector of constants.
11775 /// This function is called by the DAGCombiner when visiting sext/zext/aext
11776 /// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
11777 /// Vector extends are not folded if operations are legal; this is to
11778 /// avoid introducing illegal build_vector dag nodes.
11779 static SDValue tryToFoldExtendOfConstant(SDNode *N, const TargetLowering &TLI,
11780 SelectionDAG &DAG, bool LegalTypes) {
11781 unsigned Opcode = N->getOpcode();
11782 SDValue N0 = N->getOperand(0);
11783 EVT VT = N->getValueType(0);
11784 SDLoc DL(N);
11785
11786 assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
11787 Opcode == ISD::ANY_EXTEND ||
11788 Opcode == ISD::SIGN_EXTEND_VECTOR_INREG ||
11789 Opcode == ISD::ZERO_EXTEND_VECTOR_INREG ||
11790 Opcode == ISD::ANY_EXTEND_VECTOR_INREG) &&
11791 "Expected EXTEND dag node in input!");
11792
11793 // fold (sext c1) -> c1
11794 // fold (zext c1) -> c1
11795 // fold (aext c1) -> c1
11796 if (isa<ConstantSDNode>(N0))
11797 return DAG.getNode(Opcode, DL, VT, N0);
11798
11799 // fold (sext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
11800 // fold (zext (select cond, c1, c2)) -> (select cond, zext c1, zext c2)
11801 // fold (aext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
11802 if (N0->getOpcode() == ISD::SELECT) {
11803 SDValue Op1 = N0->getOperand(1);
11804 SDValue Op2 = N0->getOperand(2);
11805 if (isa<ConstantSDNode>(Op1) && isa<ConstantSDNode>(Op2) &&
11806 (Opcode != ISD::ZERO_EXTEND || !TLI.isZExtFree(N0.getValueType(), VT))) {
11807 // For any_extend, choose sign extension of the constants to allow a
11808 // possible further transform to sign_extend_inreg.i.e.
11809 //
11810 // t1: i8 = select t0, Constant:i8<-1>, Constant:i8<0>
11811 // t2: i64 = any_extend t1
11812 // -->
11813 // t3: i64 = select t0, Constant:i64<-1>, Constant:i64<0>
11814 // -->
11815 // t4: i64 = sign_extend_inreg t3
11816 unsigned FoldOpc = Opcode;
11817 if (FoldOpc == ISD::ANY_EXTEND)
11818 FoldOpc = ISD::SIGN_EXTEND;
11819 return DAG.getSelect(DL, VT, N0->getOperand(0),
11820 DAG.getNode(FoldOpc, DL, VT, Op1),
11821 DAG.getNode(FoldOpc, DL, VT, Op2));
11822 }
11823 }
11824
11825 // fold (sext (build_vector AllConstants) -> (build_vector AllConstants)
11826 // fold (zext (build_vector AllConstants) -> (build_vector AllConstants)
11827 // fold (aext (build_vector AllConstants) -> (build_vector AllConstants)
11828 EVT SVT = VT.getScalarType();
11829 if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(SVT)) &&
11830 ISD::isBuildVectorOfConstantSDNodes(N0.getNode())))
11831 return SDValue();
11832
11833 // We can fold this node into a build_vector.
11834 unsigned VTBits = SVT.getSizeInBits();
11835 unsigned EVTBits = N0->getValueType(0).getScalarSizeInBits();
11836 SmallVector<SDValue, 8> Elts;
11837 unsigned NumElts = VT.getVectorNumElements();
11838
11839 for (unsigned i = 0; i != NumElts; ++i) {
11840 SDValue Op = N0.getOperand(i);
11841 if (Op.isUndef()) {
11842 if (Opcode == ISD::ANY_EXTEND || Opcode == ISD::ANY_EXTEND_VECTOR_INREG)
11843 Elts.push_back(DAG.getUNDEF(SVT));
11844 else
11845 Elts.push_back(DAG.getConstant(0, DL, SVT));
11846 continue;
11847 }
11848
11849 SDLoc DL(Op);
11850 // Get the constant value and if needed trunc it to the size of the type.
11851 // Nodes like build_vector might have constants wider than the scalar type.
11852 APInt C = cast<ConstantSDNode>(Op)->getAPIntValue().zextOrTrunc(EVTBits);
11853 if (Opcode == ISD::SIGN_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
11854 Elts.push_back(DAG.getConstant(C.sext(VTBits), DL, SVT));
11855 else
11856 Elts.push_back(DAG.getConstant(C.zext(VTBits), DL, SVT));
11857 }
11858
11859 return DAG.getBuildVector(VT, DL, Elts);
11860 }
11861
11862 // ExtendUsesToFormExtLoad - Try to extend the uses of a load to enable this:
11863 // "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))"
11864 // transformation. Returns true if the extensions are possible and the
11865 // above-mentioned transformation is profitable.
11866 static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0,
11867 unsigned ExtOpc,
11868 SmallVectorImpl<SDNode *> &ExtendNodes,
11869 const TargetLowering &TLI) {
11870 bool HasCopyToRegUses = false;
11871 bool isTruncFree = TLI.isTruncateFree(VT, N0.getValueType());
11872 for (SDNode::use_iterator UI = N0->use_begin(), UE = N0->use_end(); UI != UE;
11873 ++UI) {
11874 SDNode *User = *UI;
11875 if (User == N)
11876 continue;
11877 if (UI.getUse().getResNo() != N0.getResNo())
11878 continue;
11879 // FIXME: Only extend SETCC N, N and SETCC N, c for now.
11880 if (ExtOpc != ISD::ANY_EXTEND && User->getOpcode() == ISD::SETCC) {
11881 ISD::CondCode CC = cast<CondCodeSDNode>(User->getOperand(2))->get();
11882 if (ExtOpc == ISD::ZERO_EXTEND && ISD::isSignedIntSetCC(CC))
11883 // Sign bits will be lost after a zext.
11884 return false;
11885 bool Add = false;
11886 for (unsigned i = 0; i != 2; ++i) {
11887 SDValue UseOp = User->getOperand(i);
11888 if (UseOp == N0)
11889 continue;
11890 if (!isa<ConstantSDNode>(UseOp))
11891 return false;
11892 Add = true;
11893 }
11894 if (Add)
11895 ExtendNodes.push_back(User);
11896 continue;
11897 }
11898 // If truncates aren't free and there are users we can't
11899 // extend, it isn't worthwhile.
11900 if (!isTruncFree)
11901 return false;
11902 // Remember if this value is live-out.
11903 if (User->getOpcode() == ISD::CopyToReg)
11904 HasCopyToRegUses = true;
11905 }
11906
11907 if (HasCopyToRegUses) {
11908 bool BothLiveOut = false;
11909 for (SDNode::use_iterator UI = N->use_begin(), UE = N->use_end();
11910 UI != UE; ++UI) {
11911 SDUse &Use = UI.getUse();
11912 if (Use.getResNo() == 0 && Use.getUser()->getOpcode() == ISD::CopyToReg) {
11913 BothLiveOut = true;
11914 break;
11915 }
11916 }
11917 if (BothLiveOut)
11918 // Both unextended and extended values are live out. There had better be
11919 // a good reason for the transformation.
11920 return ExtendNodes.size();
11921 }
11922 return true;
11923 }
11924
11925 void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
11926 SDValue OrigLoad, SDValue ExtLoad,
11927 ISD::NodeType ExtType) {
11928 // Extend SetCC uses if necessary.
11929 SDLoc DL(ExtLoad);
11930 for (SDNode *SetCC : SetCCs) {
11931 SmallVector<SDValue, 4> Ops;
11932
11933 for (unsigned j = 0; j != 2; ++j) {
11934 SDValue SOp = SetCC->getOperand(j);
11935 if (SOp == OrigLoad)
11936 Ops.push_back(ExtLoad);
11937 else
11938 Ops.push_back(DAG.getNode(ExtType, DL, ExtLoad->getValueType(0), SOp));
11939 }
11940
11941 Ops.push_back(SetCC->getOperand(2));
11942 CombineTo(SetCC, DAG.getNode(ISD::SETCC, DL, SetCC->getValueType(0), Ops));
11943 }
11944 }
11945
11946 // FIXME: Bring more similar combines here, common to sext/zext (maybe aext?).
11947 SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
11948 SDValue N0 = N->getOperand(0);
11949 EVT DstVT = N->getValueType(0);
11950 EVT SrcVT = N0.getValueType();
11951
11952 assert((N->getOpcode() == ISD::SIGN_EXTEND ||
11953 N->getOpcode() == ISD::ZERO_EXTEND) &&
11954 "Unexpected node type (not an extend)!");
11955
11956 // fold (sext (load x)) to multiple smaller sextloads; same for zext.
11957 // For example, on a target with legal v4i32, but illegal v8i32, turn:
11958 // (v8i32 (sext (v8i16 (load x))))
11959 // into:
11960 // (v8i32 (concat_vectors (v4i32 (sextload x)),
11961 // (v4i32 (sextload (x + 16)))))
11962 // Where uses of the original load, i.e.:
11963 // (v8i16 (load x))
11964 // are replaced with:
11965 // (v8i16 (truncate
11966 // (v8i32 (concat_vectors (v4i32 (sextload x)),
11967 // (v4i32 (sextload (x + 16)))))))
11968 //
11969 // This combine is only applicable to illegal, but splittable, vectors.
11970 // All legal types, and illegal non-vector types, are handled elsewhere.
11971 // This combine is controlled by TargetLowering::isVectorLoadExtDesirable.
11972 //
11973 if (N0->getOpcode() != ISD::LOAD)
11974 return SDValue();
11975
11976 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
11977
11978 if (!ISD::isNON_EXTLoad(LN0) || !ISD::isUNINDEXEDLoad(LN0) ||
11979 !N0.hasOneUse() || !LN0->isSimple() ||
11980 !DstVT.isVector() || !DstVT.isPow2VectorType() ||
11981 !TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
11982 return SDValue();
11983
11984 SmallVector<SDNode *, 4> SetCCs;
11985 if (!ExtendUsesToFormExtLoad(DstVT, N, N0, N->getOpcode(), SetCCs, TLI))
11986 return SDValue();
11987
11988 ISD::LoadExtType ExtType =
11989 N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
11990
11991 // Try to split the vector types to get down to legal types.
11992 EVT SplitSrcVT = SrcVT;
11993 EVT SplitDstVT = DstVT;
11994 while (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT) &&
11995 SplitSrcVT.getVectorNumElements() > 1) {
11996 SplitDstVT = DAG.GetSplitDestVTs(SplitDstVT).first;
11997 SplitSrcVT = DAG.GetSplitDestVTs(SplitSrcVT).first;
11998 }
11999
12000 if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT))
12001 return SDValue();
12002
12003 assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");
12004
12005 SDLoc DL(N);
12006 const unsigned NumSplits =
12007 DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
12008 const unsigned Stride = SplitSrcVT.getStoreSize();
12009 SmallVector<SDValue, 4> Loads;
12010 SmallVector<SDValue, 4> Chains;
12011
12012 SDValue BasePtr = LN0->getBasePtr();
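// Emit one extending load per split part, advancing the base pointer by the
// part's store size after each load.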
12013 for (unsigned Idx = 0; Idx < NumSplits; Idx++) {
12014 const unsigned Offset = Idx * Stride;
12015 const Align Align = commonAlignment(LN0->getAlign(), Offset);
12016
12017 SDValue SplitLoad = DAG.getExtLoad(
12018 ExtType, SDLoc(LN0), SplitDstVT, LN0->getChain(), BasePtr,
12019 LN0->getPointerInfo().getWithOffset(Offset), SplitSrcVT, Align,
12020 LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
12021
12022 BasePtr = DAG.getMemBasePlusOffset(BasePtr, TypeSize::Fixed(Stride), DL);
12023
12024 Loads.push_back(SplitLoad.getValue(0));
12025 Chains.push_back(SplitLoad.getValue(1));
12026 }
12027
12028 SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
12029 SDValue NewValue = DAG.getNode(ISD::CONCAT_VECTORS, DL, DstVT, Loads);
12030
12031 // Simplify TF.
12032 AddToWorklist(NewChain.getNode());
12033
12034 CombineTo(N, NewValue);
12035
12036 // Replace uses of the original load (before extension)
12037 // with a truncate of the concatenated sextloaded vectors.
12038 SDValue Trunc =
12039 DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), NewValue);
12040 ExtendSetCCUses(SetCCs, N0, NewValue, (ISD::NodeType)N->getOpcode());
12041 CombineTo(N0.getNode(), Trunc, NewChain);
12042 return SDValue(N, 0); // Return N so it doesn't get rechecked!
12043 }
12044
12045 // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
12046 // (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
12047 SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
12048 assert(N->getOpcode() == ISD::ZERO_EXTEND);
12049 EVT VT = N->getValueType(0);
12050 EVT OrigVT = N->getOperand(0).getValueType();
12051 if (TLI.isZExtFree(OrigVT, VT))
12052 return SDValue();
12053
12054 // and/or/xor
12055 SDValue N0 = N->getOperand(0);
12056 if (!(N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
12057 N0.getOpcode() == ISD::XOR) ||
12058 N0.getOperand(1).getOpcode() != ISD::Constant ||
12059 (LegalOperations && !TLI.isOperationLegal(N0.getOpcode(), VT)))
12060 return SDValue();
12061
12062 // shl/shr
12063 SDValue N1 = N0->getOperand(0);
12064 if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) ||
12065 N1.getOperand(1).getOpcode() != ISD::Constant ||
12066 (LegalOperations && !TLI.isOperationLegal(N1.getOpcode(), VT)))
12067 return SDValue();
12068
12069 // load
12070 if (!isa<LoadSDNode>(N1.getOperand(0)))
12071 return SDValue();
12072 LoadSDNode *Load = cast<LoadSDNode>(N1.getOperand(0));
12073 EVT MemVT = Load->getMemoryVT();
12074 if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) ||
12075 Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
12076 return SDValue();
12077
12078
12079 // If the shift op is SHL, the logic op must be AND, otherwise the result
12080 // will be wrong.
12081 if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND)
12082 return SDValue();
12083
12084 if (!N0.hasOneUse() || !N1.hasOneUse())
12085 return SDValue();
12086
12087 SmallVector<SDNode*, 4> SetCCs;
12088 if (!ExtendUsesToFormExtLoad(VT, N1.getNode(), N1.getOperand(0),
12089 ISD::ZERO_EXTEND, SetCCs, TLI))
12090 return SDValue();
12091
12092 // Actually do the transformation.
12093 SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Load), VT,
12094 Load->getChain(), Load->getBasePtr(),
12095 Load->getMemoryVT(), Load->getMemOperand());
12096
12097 SDLoc DL1(N1);
12098 SDValue Shift = DAG.getNode(N1.getOpcode(), DL1, VT, ExtLoad,
12099 N1.getOperand(1));
12100
12101 APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
12102 SDLoc DL0(N0);
12103 SDValue And = DAG.getNode(N0.getOpcode(), DL0, VT, Shift,
12104 DAG.getConstant(Mask, DL0, VT));
12105
12106 ExtendSetCCUses(SetCCs, N1.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
12107 CombineTo(N, And);
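// If the shift was the only user of the loaded value, only the chain needs
// rewiring; otherwise give the remaining users a truncate of the extending
// load.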
12108 if (SDValue(Load, 0).hasOneUse()) {
12109 DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), ExtLoad.getValue(1));
12110 } else {
12111 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(Load),
12112 Load->getValueType(0), ExtLoad);
12113 CombineTo(Load, Trunc, ExtLoad.getValue(1));
12114 }
12115
12116 // N0 is dead at this point.
12117 recursivelyDeleteUnusedNodes(N0.getNode());
12118
12119 return SDValue(N,0); // Return N so it doesn't get rechecked!
12120 }
12121
12122 /// If we're narrowing or widening the result of a vector select and the final
12123 /// size is the same size as a setcc (compare) feeding the select, then try to
12124 /// apply the cast operation to the select's operands because matching vector
12125 /// sizes for a select condition and other operands should be more efficient.
12126 SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
12127 unsigned CastOpcode = Cast->getOpcode();
12128 assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
12129 CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
12130 CastOpcode == ISD::FP_ROUND) &&
12131 "Unexpected opcode for vector select narrowing/widening");
12132
12133 // We only do this transform before legal ops because the pattern may be
12134 // obfuscated by target-specific operations after legalization. Do not create
12135 // an illegal select op, however, because that may be difficult to lower.
12136 EVT VT = Cast->getValueType(0);
12137 if (LegalOperations || !TLI.isOperationLegalOrCustom(ISD::VSELECT, VT))
12138 return SDValue();
12139
12140 SDValue VSel = Cast->getOperand(0);
12141 if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() ||
12142 VSel.getOperand(0).getOpcode() != ISD::SETCC)
12143 return SDValue();
12144
12145 // Does the setcc have the same vector size as the casted select?
12146 SDValue SetCC = VSel.getOperand(0);
12147 EVT SetCCVT = getSetCCResultType(SetCC.getOperand(0).getValueType());
12148 if (SetCCVT.getSizeInBits() != VT.getSizeInBits())
12149 return SDValue();
12150
12151 // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B)
12152 SDValue A = VSel.getOperand(1);
12153 SDValue B = VSel.getOperand(2);
12154 SDValue CastA, CastB;
12155 SDLoc DL(Cast);
12156 if (CastOpcode == ISD::FP_ROUND) {
12157 // FP_ROUND (fptrunc) has an extra flag operand to pass along.
12158 CastA = DAG.getNode(CastOpcode, DL, VT, A, Cast->getOperand(1));
12159 CastB = DAG.getNode(CastOpcode, DL, VT, B, Cast->getOperand(1));
12160 } else {
12161 CastA = DAG.getNode(CastOpcode, DL, VT, A);
12162 CastB = DAG.getNode(CastOpcode, DL, VT, B);
12163 }
12164 return DAG.getNode(ISD::VSELECT, DL, VT, SetCC, CastA, CastB);
12165 }
12166
12167 // fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
12168 // fold ([s|z]ext ( extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
12169 static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
12170 const TargetLowering &TLI, EVT VT,
12171 bool LegalOperations, SDNode *N,
12172 SDValue N0, ISD::LoadExtType ExtLoadType) {
12173 SDNode *N0Node = N0.getNode();
12174 bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD) ? ISD::isSEXTLoad(N0Node)
12175 : ISD::isZEXTLoad(N0Node);
12176 if ((!isAExtLoad && !ISD::isEXTLoad(N0Node)) ||
12177 !ISD::isUNINDEXEDLoad(N0Node) || !N0.hasOneUse())
12178 return SDValue();
12179
12180 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
12181 EVT MemVT = LN0->getMemoryVT();
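// A simple scalar extload may be created speculatively before legalization;
// in every other case the target must report the extending load as legal.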
12182 if ((LegalOperations || !LN0->isSimple() ||
12183 VT.isVector()) &&
12184 !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT))
12185 return SDValue();
12186
12187 SDValue ExtLoad =
12188 DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
12189 LN0->getBasePtr(), MemVT, LN0->getMemOperand());
12190 Combiner.CombineTo(N, ExtLoad);
12191 DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
12192 if (LN0->use_empty())
12193 Combiner.recursivelyDeleteUnusedNodes(LN0);
12194 return SDValue(N, 0); // Return N so it doesn't get rechecked!
12195 }
12196
12197 // fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x)))
12198 // Only generate vector extloads when 1) they're legal, and 2) they are
12199 // deemed desirable by the target.
12200 static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
12201 const TargetLowering &TLI, EVT VT,
12202 bool LegalOperations, SDNode *N, SDValue N0,
12203 ISD::LoadExtType ExtLoadType,
12204 ISD::NodeType ExtOpc) {
12205 // TODO: The isFixedLengthVector() check should be removed; any negative
12206 // effects on code generation should instead be handled by the target's
12207 // implementation of isVectorLoadExtDesirable().
12208 if (!ISD::isNON_EXTLoad(N0.getNode()) ||
12209 !ISD::isUNINDEXEDLoad(N0.getNode()) ||
12210 ((LegalOperations || VT.isFixedLengthVector() ||
12211 !cast<LoadSDNode>(N0)->isSimple()) &&
12212 !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType())))
12213 return {};
12214
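// If the load has additional users, only fold when those users are setccs
// that can be extended as well; for vectors, also let the target veto the
// extending load via isVectorLoadExtDesirable().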
12215 bool DoXform = true;
12216 SmallVector<SDNode *, 4> SetCCs;
12217 if (!N0.hasOneUse())
12218 DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc, SetCCs, TLI);
12219 if (VT.isVector())
12220 DoXform &= TLI.isVectorLoadExtDesirable(SDValue(N, 0));
12221 if (!DoXform)
12222 return {};
12223
12224 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
12225 SDValue ExtLoad = DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
12226 LN0->getBasePtr(), N0.getValueType(),
12227 LN0->getMemOperand());
12228 Combiner.ExtendSetCCUses(SetCCs, N0, ExtLoad, ExtOpc);
12229 // If the load value is used only by N, replace it via CombineTo N.
12230 bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
12231 Combiner.CombineTo(N, ExtLoad);
12232 if (NoReplaceTrunc) {
12233 DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
12234 Combiner.recursivelyDeleteUnusedNodes(LN0);
12235 } else {
12236 SDValue Trunc =
12237 DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
12238 Combiner.CombineTo(LN0, Trunc, ExtLoad.getValue(1));
12239 }
12240 return SDValue(N, 0); // Return N so it doesn't get rechecked!
12241 }
12242
12243 static SDValue tryToFoldExtOfMaskedLoad(SelectionDAG &DAG,
12244 const TargetLowering &TLI, EVT VT,
12245 SDNode *N, SDValue N0,
12246 ISD::LoadExtType ExtLoadType,
12247 ISD::NodeType ExtOpc) {
12248 if (!N0.hasOneUse())
12249 return SDValue();
12250
12251 MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0);
12252 if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
12253 return SDValue();
12254
12255 if (!TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0)))
12256 return SDValue();
12257
12258 if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
12259 return SDValue();
12260
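// Build the extending masked load, extending the pass-through operand to the
// wider result type as well, then splice the new chain into the old one's
// users.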
12261 SDLoc dl(Ld);
12262 SDValue PassThru = DAG.getNode(ExtOpc, dl, VT, Ld->getPassThru());
12263 SDValue NewLoad = DAG.getMaskedLoad(
12264 VT, dl, Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(), Ld->getMask(),
12265 PassThru, Ld->getMemoryVT(), Ld->getMemOperand(), Ld->getAddressingMode(),
12266 ExtLoadType, Ld->isExpandingLoad());
12267 DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), SDValue(NewLoad.getNode(), 1));
12268 return NewLoad;
12269 }
12270
12271 static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG,
12272 bool LegalOperations) {
12273 assert((N->getOpcode() == ISD::SIGN_EXTEND ||
12274 N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext");
12275
12276 SDValue SetCC = N->getOperand(0);
12277 if (LegalOperations || SetCC.getOpcode() != ISD::SETCC ||
12278 !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1)
12279 return SDValue();
12280
12281 SDValue X = SetCC.getOperand(0);
12282 SDValue Ones = SetCC.getOperand(1);
12283 ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
12284 EVT VT = N->getValueType(0);
12285 EVT XVT = X.getValueType();
12286 // setge X, C is canonicalized to setgt, so we do not need to match that
12287 // pattern. The setlt sibling is folded in SimplifySelectCC() because it does
12288 // not require the 'not' op.
12289 if (CC == ISD::SETGT && isAllOnesConstant(Ones) && VT == XVT) {
12290 // Invert and smear/shift the sign bit:
12291 // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1)
12292 // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1)
12293 SDLoc DL(N);
12294 unsigned ShCt = VT.getSizeInBits() - 1;
12295 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
12296 if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
12297 SDValue NotX = DAG.getNOT(DL, X, VT);
12298 SDValue ShiftAmount = DAG.getConstant(ShCt, DL, VT);
12299 auto ShiftOpcode =
12300 N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL;
12301 return DAG.getNode(ShiftOpcode, DL, VT, NotX, ShiftAmount);
12302 }
12303 }
12304 return SDValue();
12305 }
12306
12307 SDValue DAGCombiner::foldSextSetcc(SDNode *N) {
12308 SDValue N0 = N->getOperand(0);
12309 if (N0.getOpcode() != ISD::SETCC)
12310 return SDValue();
12311
12312 SDValue N00 = N0.getOperand(0);
12313 SDValue N01 = N0.getOperand(1);
12314 ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
12315 EVT VT = N->getValueType(0);
12316 EVT N00VT = N00.getValueType();
12317 SDLoc DL(N);
12318
12319 // Propagate fast-math-flags.
12320 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
12321
12322 // On some architectures (such as SSE/NEON/etc) the SETCC result type is
12323 // the same size as the compared operands. Try to optimize sext(setcc())
12324 // if this is the case.
12325 if (VT.isVector() && !LegalOperations &&
12326 TLI.getBooleanContents(N00VT) ==
12327 TargetLowering::ZeroOrNegativeOneBooleanContent) {
12328 EVT SVT = getSetCCResultType(N00VT);
12329
12330 // If we already have the desired type, don't change it.
12331 if (SVT != N0.getValueType()) {
12332 // We know that the # elements of the result is the same as the
12333 // # elements of the compare (and the # elements of the compare result
12334 // for that matter). Check to see that they are the same size. If so,
12335 // we know that the element size of the sext'd result matches the
12336 // element size of the compare operands.
12337 if (VT.getSizeInBits() == SVT.getSizeInBits())
12338 return DAG.getSetCC(DL, VT, N00, N01, CC);
12339
12340 // If the desired elements are smaller or larger than the source
12341 // elements, we can use a matching integer vector type and then
12342 // truncate/sign extend.
12343 EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
12344 if (SVT == MatchingVecType) {
12345 SDValue VsetCC = DAG.getSetCC(DL, MatchingVecType, N00, N01, CC);
12346 return DAG.getSExtOrTrunc(VsetCC, DL, VT);
12347 }
12348 }
12349
12350 // Try to eliminate the sext of a setcc by zexting the compare operands.
12351 if (N0.hasOneUse() && TLI.isOperationLegalOrCustom(ISD::SETCC, VT) &&
12352 !TLI.isOperationLegalOrCustom(ISD::SETCC, SVT)) {
12353 bool IsSignedCmp = ISD::isSignedIntSetCC(CC);
12354 unsigned LoadOpcode = IsSignedCmp ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
12355 unsigned ExtOpcode = IsSignedCmp ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
12356
12357 // We have an unsupported narrow vector compare op that would be legal
12358 // if extended to the destination type. See if the compare operands
12359 // can be freely extended to the destination type.
12360 auto IsFreeToExtend = [&](SDValue V) {
12361 if (isConstantOrConstantVector(V, /*NoOpaques*/ true))
12362 return true;
12363 // Match a simple, non-extended load that can be converted to a
12364 // legal {z/s}ext-load.
12365 // TODO: Allow widening of an existing {z/s}ext-load?
12366 if (!(ISD::isNON_EXTLoad(V.getNode()) &&
12367 ISD::isUNINDEXEDLoad(V.getNode()) &&
12368 cast<LoadSDNode>(V)->isSimple() &&
12369 TLI.isLoadExtLegal(LoadOpcode, VT, V.getValueType())))
12370 return false;
12371
12372 // Non-chain users of this value must either be the setcc in this
12373 // sequence or extends that can be folded into the new {z/s}ext-load.
12374 for (SDNode::use_iterator UI = V->use_begin(), UE = V->use_end();
12375 UI != UE; ++UI) {
12376 // Skip uses of the chain and the setcc.
12377 SDNode *User = *UI;
12378 if (UI.getUse().getResNo() != 0 || User == N0.getNode())
12379 continue;
12380 // Extra users must have exactly the same cast we are about to create.
12381 // TODO: This restriction could be eased if ExtendUsesToFormExtLoad()
12382 // is enhanced similarly.
12383 if (User->getOpcode() != ExtOpcode || User->getValueType(0) != VT)
12384 return false;
12385 }
12386 return true;
12387 };
12388
12389 if (IsFreeToExtend(N00) && IsFreeToExtend(N01)) {
12390 SDValue Ext0 = DAG.getNode(ExtOpcode, DL, VT, N00);
12391 SDValue Ext1 = DAG.getNode(ExtOpcode, DL, VT, N01);
12392 return DAG.getSetCC(DL, VT, Ext0, Ext1, CC);
12393 }
12394 }
12395 }
12396
12397 // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0)
12398 // Here, T can be 1 or -1, depending on the type of the setcc and
12399 // getBooleanContents().
12400 unsigned SetCCWidth = N0.getScalarValueSizeInBits();
12401
12402 // To determine the "true" side of the select, we need to know the high bit
12403 // of the value returned by the setcc if it evaluates to true.
12404 // If the type of the setcc is i1, then the true case of the select is just
12405 // sext(i1 1), that is, -1.
12406 // If the type of the setcc is larger (say, i8) then the value of the high
12407 // bit depends on getBooleanContents(), so ask TLI for a real "true" value
12408 // of the appropriate width.
12409 SDValue ExtTrueVal = (SetCCWidth == 1)
12410 ? DAG.getAllOnesConstant(DL, VT)
12411 : DAG.getBoolConstant(true, DL, VT, N00VT);
12412 SDValue Zero = DAG.getConstant(0, DL, VT);
12413 if (SDValue SCC = SimplifySelectCC(DL, N00, N01, ExtTrueVal, Zero, CC, true))
12414 return SCC;
12415
12416 if (!VT.isVector() && !shouldConvertSelectOfConstantsToMath(N0, VT, TLI)) {
12417 EVT SetCCVT = getSetCCResultType(N00VT);
12418 // Don't do this transform for i1 because there's a select transform
12419 // that would reverse it.
12420 // TODO: We should not do this transform at all without a target hook
12421 // because a sext is likely cheaper than a select?
12422 if (SetCCVT.getScalarSizeInBits() != 1 &&
12423 (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, N00VT))) {
12424 SDValue SetCC = DAG.getSetCC(DL, SetCCVT, N00, N01, CC);
12425 return DAG.getSelect(DL, VT, SetCC, ExtTrueVal, Zero);
12426 }
12427 }
12428
12429 return SDValue();
12430 }
12431
12432 SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
12433 SDValue N0 = N->getOperand(0);
12434 EVT VT = N->getValueType(0);
12435 SDLoc DL(N);
12436
12437 if (VT.isVector())
12438 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
12439 return FoldedVOp;
12440
12441 // sext(undef) = 0 because the top bits will all be the same.
12442 if (N0.isUndef())
12443 return DAG.getConstant(0, DL, VT);
12444
12445 if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
12446 return Res;
12447
12448 // fold (sext (sext x)) -> (sext x)
12449 // fold (sext (aext x)) -> (sext x)
12450 if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
12451 return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, N0.getOperand(0));
12452
12453 // fold (sext (sext_inreg x)) -> (sext (trunc x))
12454 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
12455 SDValue N00 = N0.getOperand(0);
12456 EVT ExtVT = cast<VTSDNode>(N0->getOperand(1))->getVT();
12457 if (N00.getOpcode() == ISD::TRUNCATE &&
12458 (!LegalTypes || TLI.isTypeLegal(ExtVT))) {
12459 SDValue T = DAG.getNode(ISD::TRUNCATE, DL, ExtVT, N00.getOperand(0));
12460 return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, T);
12461 }
12462 }
12463
12464 if (N0.getOpcode() == ISD::TRUNCATE) {
12465 // fold (sext (truncate (load x))) -> (sext (smaller load x))
12466 // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
12467 if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
12468 SDNode *oye = N0.getOperand(0).getNode();
12469 if (NarrowLoad.getNode() != N0.getNode()) {
12470 CombineTo(N0.getNode(), NarrowLoad);
12471 // CombineTo deleted the truncate, if needed, but not what's under it.
12472 AddToWorklist(oye);
12473 }
12474 return SDValue(N, 0); // Return N so it doesn't get rechecked!
12475 }
12476
12477 // See if the value being truncated is already sign extended. If so, just
12478 // eliminate the trunc/sext pair.
12479 SDValue Op = N0.getOperand(0);
12480 unsigned OpBits = Op.getScalarValueSizeInBits();
12481 unsigned MidBits = N0.getScalarValueSizeInBits();
12482 unsigned DestBits = VT.getScalarSizeInBits();
12483 unsigned NumSignBits = DAG.ComputeNumSignBits(Op);
12484
12485 if (OpBits == DestBits) {
12486 // Op is i32, Mid is i8, and Dest is i32. If Op has more than 24 sign
12487 // bits, it is already in the right form and can be used directly.
12488 if (NumSignBits > DestBits-MidBits)
12489 return Op;
12490 } else if (OpBits < DestBits) {
12491 // Op is i32, Mid is i8, and Dest is i64. If Op has more than 24 sign
12492 // bits, just sext from i32.
12493 if (NumSignBits > OpBits-MidBits)
12494 return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, Op);
12495 } else {
12496 // Op is i64, Mid is i8, and Dest is i32. If Op has more than 56 sign
12497 // bits, just truncate to i32.
12498 if (NumSignBits > OpBits-MidBits)
12499 return DAG.getNode(ISD::TRUNCATE, DL, VT, Op);
12500 }
12501
12502 // fold (sext (truncate x)) -> (sextinreg x).
12503 if (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG,
12504 N0.getValueType())) {
12505 if (OpBits < DestBits)
12506 Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N0), VT, Op);
12507 else if (OpBits > DestBits)
12508 Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N0), VT, Op);
12509 return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Op,
12510 DAG.getValueType(N0.getValueType()));
12511 }
12512 }
12513
12514 // Try to simplify (sext (load x)).
12515 if (SDValue foldedExt =
12516 tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
12517 ISD::SEXTLOAD, ISD::SIGN_EXTEND))
12518 return foldedExt;
12519
12520 if (SDValue foldedExt =
12521 tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::SEXTLOAD,
12522 ISD::SIGN_EXTEND))
12523 return foldedExt;
12524
12525 // fold (sext (load x)) to multiple smaller sextloads.
12526 // Only on illegal but splittable vectors.
12527 if (SDValue ExtLoad = CombineExtLoad(N))
12528 return ExtLoad;
12529
12530 // Try to simplify (sext (sextload x)).
12531 if (SDValue foldedExt = tryToFoldExtOfExtload(
12532 DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::SEXTLOAD))
12533 return foldedExt;
12534
12535 // fold (sext (and/or/xor (load x), cst)) ->
12536 // (and/or/xor (sextload x), (sext cst))
12537 if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
12538 N0.getOpcode() == ISD::XOR) &&
12539 isa<LoadSDNode>(N0.getOperand(0)) &&
12540 N0.getOperand(1).getOpcode() == ISD::Constant &&
12541 (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
12542 LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
12543 EVT MemVT = LN00->getMemoryVT();
12544 if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT) &&
12545 LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
12546 SmallVector<SDNode*, 4> SetCCs;
12547 bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
12548 ISD::SIGN_EXTEND, SetCCs, TLI);
12549 if (DoXform) {
12550 SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(LN00), VT,
12551 LN00->getChain(), LN00->getBasePtr(),
12552 LN00->getMemoryVT(),
12553 LN00->getMemOperand());
12554 APInt Mask = N0.getConstantOperandAPInt(1).sext(VT.getSizeInBits());
12555 SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
12556 ExtLoad, DAG.getConstant(Mask, DL, VT));
12557 ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::SIGN_EXTEND);
12558 bool NoReplaceTruncAnd = !N0.hasOneUse();
12559 bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
12560 CombineTo(N, And);
12561 // If N0 has multiple uses, change other uses as well.
12562 if (NoReplaceTruncAnd) {
12563 SDValue TruncAnd =
12564 DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
12565 CombineTo(N0.getNode(), TruncAnd);
12566 }
12567 if (NoReplaceTrunc) {
12568 DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
12569 } else {
12570 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
12571 LN00->getValueType(0), ExtLoad);
12572 CombineTo(LN00, Trunc, ExtLoad.getValue(1));
12573 }
12574 return SDValue(N,0); // Return N so it doesn't get rechecked!
12575 }
12576 }
12577 }
12578
12579 if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
12580 return V;
12581
12582 if (SDValue V = foldSextSetcc(N))
12583 return V;
12584
12585 // fold (sext x) -> (zext x) if the sign bit is known zero.
12586 if ((!LegalOperations || TLI.isOperationLegal(ISD::ZERO_EXTEND, VT)) &&
12587 DAG.SignBitIsZero(N0))
12588 return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0);
12589
12590 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
12591 return NewVSel;
12592
12593 // Eliminate this sign extend by doing a negation in the destination type:
12594 // sext i32 (0 - (zext i8 X to i32)) to i64 --> 0 - (zext i8 X to i64)
12595 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
12596 isNullOrNullSplat(N0.getOperand(0)) &&
12597 N0.getOperand(1).getOpcode() == ISD::ZERO_EXTEND &&
12598 TLI.isOperationLegalOrCustom(ISD::SUB, VT)) {
12599 SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(1).getOperand(0), DL, VT);
12600 return DAG.getNegative(Zext, DL, VT);
12601 }
12602 // Eliminate this sign extend by doing a decrement in the destination type:
12603 // sext i32 ((zext i8 X to i32) + (-1)) to i64 --> (zext i8 X to i64) + (-1)
12604 if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
12605 isAllOnesOrAllOnesSplat(N0.getOperand(1)) &&
12606 N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
12607 TLI.isOperationLegalOrCustom(ISD::ADD, VT)) {
12608 SDValue Zext = DAG.getZExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
12609 return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
12610 }
12611
12612 // fold sext (not i1 X) -> add (zext i1 X), -1
12613 // TODO: This could be extended to handle bool vectors.
12614 if (N0.getValueType() == MVT::i1 && isBitwiseNot(N0) && N0.hasOneUse() &&
12615 (!LegalOperations || (TLI.isOperationLegal(ISD::ZERO_EXTEND, VT) &&
12616 TLI.isOperationLegal(ISD::ADD, VT)))) {
12617 // If we can eliminate the 'not', the sext form should be better
12618 if (SDValue NewXor = visitXOR(N0.getNode())) {
12619 // Returning N0 is a form of in-visit replacement that may have
12620 // invalidated N0.
12621 if (NewXor.getNode() == N0.getNode()) {
12622 // Return SDValue here as the xor should have already been replaced in
12623 // this sext.
12624 return SDValue();
12625 }
12626
12627 // Return a new sext with the new xor.
12628 return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NewXor);
12629 }
12630
12631 SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0));
12632 return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
12633 }
12634
12635 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
12636 return Res;
12637
12638 return SDValue();
12639 }
12640
12641 // isTruncateOf - If N is a truncate of some other value, return true, record
12642 // the value being truncated in Op and which of Op's bits are zero/one in Known.
12643 // This function computes KnownBits to avoid a duplicated call to
12644 // computeKnownBits in the caller.
12645 static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
12646 KnownBits &Known) {
12647 if (N->getOpcode() == ISD::TRUNCATE) {
12648 Op = N->getOperand(0);
12649 Known = DAG.computeKnownBits(Op);
12650 return true;
12651 }
12652
12653 if (N.getOpcode() != ISD::SETCC ||
12654 N.getValueType().getScalarType() != MVT::i1 ||
12655 cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE)
12656 return false;
12657
12658 SDValue Op0 = N->getOperand(0);
12659 SDValue Op1 = N->getOperand(1);
12660 assert(Op0.getValueType() == Op1.getValueType());
12661
12662 if (isNullOrNullSplat(Op0))
12663 Op = Op1;
12664 else if (isNullOrNullSplat(Op1))
12665 Op = Op0;
12666 else
12667 return false;
12668
12669 Known = DAG.computeKnownBits(Op);
12670
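// The setcc only behaves like a truncate to i1 if every bit of Op other than
// bit 0 is known to be zero.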
12671 return (Known.Zero | 1).isAllOnes();
12672 }
12673
12674 /// Given an extending node with a pop-count operand, if the target does not
12675 /// support a pop-count in the narrow source type but does support it in the
12676 /// destination type, widen the pop-count to the destination type.
12677 static SDValue widenCtPop(SDNode *Extend, SelectionDAG &DAG) {
12678 assert((Extend->getOpcode() == ISD::ZERO_EXTEND ||
12679 Extend->getOpcode() == ISD::ANY_EXTEND) && "Expected extend op");
12680
12681 SDValue CtPop = Extend->getOperand(0);
12682 if (CtPop.getOpcode() != ISD::CTPOP || !CtPop.hasOneUse())
12683 return SDValue();
12684
12685 EVT VT = Extend->getValueType(0);
12686 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
12687 if (TLI.isOperationLegalOrCustom(ISD::CTPOP, CtPop.getValueType()) ||
12688 !TLI.isOperationLegalOrCustom(ISD::CTPOP, VT))
12689 return SDValue();
12690
12691 // zext (ctpop X) --> ctpop (zext X)
12692 SDLoc DL(Extend);
12693 SDValue NewZext = DAG.getZExtOrTrunc(CtPop.getOperand(0), DL, VT);
12694 return DAG.getNode(ISD::CTPOP, DL, VT, NewZext);
12695 }
12696
12697 // If we have (zext (abs X)) where X is a type that will be promoted by type
12698 // legalization, convert to (abs (sext X)). But don't extend past a legal type.
12699 static SDValue widenAbs(SDNode *Extend, SelectionDAG &DAG) {
12700 assert(Extend->getOpcode() == ISD::ZERO_EXTEND && "Expected zero extend.");
12701
12702 EVT VT = Extend->getValueType(0);
12703 if (VT.isVector())
12704 return SDValue();
12705
12706 SDValue Abs = Extend->getOperand(0);
12707 if (Abs.getOpcode() != ISD::ABS || !Abs.hasOneUse())
12708 return SDValue();
12709
12710 EVT AbsVT = Abs.getValueType();
12711 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
12712 if (TLI.getTypeAction(*DAG.getContext(), AbsVT) !=
12713 TargetLowering::TypePromoteInteger)
12714 return SDValue();
12715
12716 EVT LegalVT = TLI.getTypeToTransformTo(*DAG.getContext(), AbsVT);
12717
12718 SDValue SExt =
12719 DAG.getNode(ISD::SIGN_EXTEND, SDLoc(Abs), LegalVT, Abs.getOperand(0));
12720 SDValue NewAbs = DAG.getNode(ISD::ABS, SDLoc(Abs), LegalVT, SExt);
12721 return DAG.getZExtOrTrunc(NewAbs, SDLoc(Extend), VT);
12722 }
12723
12724 SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
12725 SDValue N0 = N->getOperand(0);
12726 EVT VT = N->getValueType(0);
12727
12728 if (VT.isVector())
12729 if (SDValue FoldedVOp = SimplifyVCastOp(N, SDLoc(N)))
12730 return FoldedVOp;
12731
12732 // zext(undef) = 0
12733 if (N0.isUndef())
12734 return DAG.getConstant(0, SDLoc(N), VT);
12735
12736 if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
12737 return Res;
12738
12739 // fold (zext (zext x)) -> (zext x)
12740 // fold (zext (aext x)) -> (zext x)
12741 if (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
12742 return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT,
12743 N0.getOperand(0));
12744
12745 // fold (zext (truncate x)) -> (zext x) or
12746 // (zext (truncate x)) -> (truncate x)
12747 // This is valid when the truncated bits of x are already zero.
12748 SDValue Op;
12749 KnownBits Known;
12750 if (isTruncateOf(DAG, N0, Op, Known)) {
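// The bits that the truncate discarded (up to the width being zero-extended
// to) must be known zero for Op to be reused directly.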
12751 APInt TruncatedBits =
12752 (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ?
12753 APInt(Op.getScalarValueSizeInBits(), 0) :
12754 APInt::getBitsSet(Op.getScalarValueSizeInBits(),
12755 N0.getScalarValueSizeInBits(),
12756 std::min(Op.getScalarValueSizeInBits(),
12757 VT.getScalarSizeInBits()));
12758 if (TruncatedBits.isSubsetOf(Known.Zero))
12759 return DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
12760 }
12761
12762 // fold (zext (truncate x)) -> (and x, mask)
12763 if (N0.getOpcode() == ISD::TRUNCATE) {
12764 // fold (zext (truncate (load x))) -> (zext (smaller load x))
12765 // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
12766 if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
12767 SDNode *oye = N0.getOperand(0).getNode();
12768 if (NarrowLoad.getNode() != N0.getNode()) {
12769 CombineTo(N0.getNode(), NarrowLoad);
12770 // CombineTo deleted the truncate, if needed, but not what's under it.
12771 AddToWorklist(oye);
12772 }
12773 return SDValue(N, 0); // Return N so it doesn't get rechecked!
12774 }
12775
12776 EVT SrcVT = N0.getOperand(0).getValueType();
12777 EVT MinVT = N0.getValueType();
12778
12779 // Try to mask before the extension to avoid having to generate a larger mask,
12780 // possibly over several sub-vectors.
12781 if (SrcVT.bitsLT(VT) && VT.isVector()) {
12782 if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) &&
12783 TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) {
12784 SDValue Op = N0.getOperand(0);
12785 Op = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT);
12786 AddToWorklist(Op.getNode());
12787 SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
12788 // Transfer the debug info; the new node is equivalent to N0.
12789 DAG.transferDbgValues(N0, ZExtOrTrunc);
12790 return ZExtOrTrunc;
12791 }
12792 }
12793
12794 if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) {
12795 SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);
12796 AddToWorklist(Op.getNode());
12797 SDValue And = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT);
12798 // We may safely transfer the debug info describing the truncate node over
12799 // to the equivalent and operation.
12800 DAG.transferDbgValues(N0, And);
12801 return And;
12802 }
12803 }
12804
12805 // Fold (zext (and (trunc x), cst)) -> (and x, cst),
12806 // if either of the casts is not free.
12807 if (N0.getOpcode() == ISD::AND &&
12808 N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
12809 N0.getOperand(1).getOpcode() == ISD::Constant &&
12810 (!TLI.isTruncateFree(N0.getOperand(0).getOperand(0).getValueType(),
12811 N0.getValueType()) ||
12812 !TLI.isZExtFree(N0.getValueType(), VT))) {
12813 SDValue X = N0.getOperand(0).getOperand(0);
12814 X = DAG.getAnyExtOrTrunc(X, SDLoc(X), VT);
12815 APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
12816 SDLoc DL(N);
12817 return DAG.getNode(ISD::AND, DL, VT,
12818 X, DAG.getConstant(Mask, DL, VT));
12819 }
12820
12821 // Try to simplify (zext (load x)).
12822 if (SDValue foldedExt =
12823 tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
12824 ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
12825 return foldedExt;
12826
12827 if (SDValue foldedExt =
12828 tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::ZEXTLOAD,
12829 ISD::ZERO_EXTEND))
12830 return foldedExt;
12831
12832 // fold (zext (load x)) to multiple smaller zextloads.
12833 // Only on illegal but splittable vectors.
12834 if (SDValue ExtLoad = CombineExtLoad(N))
12835 return ExtLoad;
12836
12837 // fold (zext (and/or/xor (load x), cst)) ->
12838 // (and/or/xor (zextload x), (zext cst))
12839 // Unless (and (load x) cst) will match as a zextload already and has
12840 // additional users.
12841 if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR ||
12842 N0.getOpcode() == ISD::XOR) &&
12843 isa<LoadSDNode>(N0.getOperand(0)) &&
12844 N0.getOperand(1).getOpcode() == ISD::Constant &&
12845 (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
12846 LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
12847 EVT MemVT = LN00->getMemoryVT();
12848 if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) &&
12849 LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
12850 bool DoXform = true;
12851 SmallVector<SDNode*, 4> SetCCs;
12852 if (!N0.hasOneUse()) {
12853 if (N0.getOpcode() == ISD::AND) {
12854 auto *AndC = cast<ConstantSDNode>(N0.getOperand(1));
12855 EVT LoadResultTy = AndC->getValueType(0);
12856 EVT ExtVT;
12857 if (isAndLoadExtLoad(AndC, LN00, LoadResultTy, ExtVT))
12858 DoXform = false;
12859 }
12860 }
12861 if (DoXform)
12862 DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
12863 ISD::ZERO_EXTEND, SetCCs, TLI);
12864 if (DoXform) {
12865 SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN00), VT,
12866 LN00->getChain(), LN00->getBasePtr(),
12867 LN00->getMemoryVT(),
12868 LN00->getMemOperand());
12869 APInt Mask = N0.getConstantOperandAPInt(1).zext(VT.getSizeInBits());
12870 SDLoc DL(N);
12871 SDValue And = DAG.getNode(N0.getOpcode(), DL, VT,
12872 ExtLoad, DAG.getConstant(Mask, DL, VT));
12873 ExtendSetCCUses(SetCCs, N0.getOperand(0), ExtLoad, ISD::ZERO_EXTEND);
12874 bool NoReplaceTruncAnd = !N0.hasOneUse();
12875 bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
12876 CombineTo(N, And);
12877 // If N0 has multiple uses, change other uses as well.
12878 if (NoReplaceTruncAnd) {
12879 SDValue TruncAnd =
12880 DAG.getNode(ISD::TRUNCATE, DL, N0.getValueType(), And);
12881 CombineTo(N0.getNode(), TruncAnd);
12882 }
12883 if (NoReplaceTrunc) {
12884 DAG.ReplaceAllUsesOfValueWith(SDValue(LN00, 1), ExtLoad.getValue(1));
12885 } else {
12886 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SDLoc(LN00),
12887 LN00->getValueType(0), ExtLoad);
12888 CombineTo(LN00, Trunc, ExtLoad.getValue(1));
12889 }
12890 return SDValue(N,0); // Return N so it doesn't get rechecked!
12891 }
12892 }
12893 }
12894
12895 // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
12896 // (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
12897 if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N))
12898 return ZExtLoad;
12899
12900 // Try to simplify (zext (zextload x)).
12901 if (SDValue foldedExt = tryToFoldExtOfExtload(
12902 DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::ZEXTLOAD))
12903 return foldedExt;
12904
12905 if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
12906 return V;
12907
12908 if (N0.getOpcode() == ISD::SETCC) {
12909 // Propagate fast-math-flags.
12910 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
12911
12912 // Only do this before legalize for now.
12913 if (!LegalOperations && VT.isVector() &&
12914 N0.getValueType().getVectorElementType() == MVT::i1) {
12915 EVT N00VT = N0.getOperand(0).getValueType();
12916 if (getSetCCResultType(N00VT) == N0.getValueType())
12917 return SDValue();
12918
12919       // We know that the # elements of the result is the same as the #
12920       // elements of the compare (and the # elements of the compare result,
12921       // for that matter). Check to see that they are the same size. If so, we
12922       // know that the element size of the zext'd result matches the element
12923       // size of the compare operands.
12924 SDLoc DL(N);
12925 if (VT.getSizeInBits() == N00VT.getSizeInBits()) {
12926 // zext(setcc) -> zext_in_reg(vsetcc) for vectors.
12927 SDValue VSetCC = DAG.getNode(ISD::SETCC, DL, VT, N0.getOperand(0),
12928 N0.getOperand(1), N0.getOperand(2));
12929 return DAG.getZeroExtendInReg(VSetCC, DL, N0.getValueType());
12930 }
12931
12932 // If the desired elements are smaller or larger than the source
12933 // elements we can use a matching integer vector type and then
12934 // truncate/any extend followed by zext_in_reg.
12935 EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
12936 SDValue VsetCC =
12937 DAG.getNode(ISD::SETCC, DL, MatchingVectorType, N0.getOperand(0),
12938 N0.getOperand(1), N0.getOperand(2));
12939 return DAG.getZeroExtendInReg(DAG.getAnyExtOrTrunc(VsetCC, DL, VT), DL,
12940 N0.getValueType());
12941 }
12942
12943 // zext(setcc x,y,cc) -> zext(select x, y, true, false, cc)
12944 SDLoc DL(N);
12945 EVT N0VT = N0.getValueType();
12946 EVT N00VT = N0.getOperand(0).getValueType();
12947 if (SDValue SCC = SimplifySelectCC(
12948 DL, N0.getOperand(0), N0.getOperand(1),
12949 DAG.getBoolConstant(true, DL, N0VT, N00VT),
12950 DAG.getBoolConstant(false, DL, N0VT, N00VT),
12951 cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
12952 return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, SCC);
12953 }
12954
12955 // (zext (shl (zext x), cst)) -> (shl (zext x), cst)
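  // E.g. (i32 zext (shl (i16 zext i8:x), 4)) -> (shl (i32 zext x), 4); this is
  // safe because shifting by 4 cannot move any of x's 8 low bits out of the
  // i16 value.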
12956 if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
12957 isa<ConstantSDNode>(N0.getOperand(1)) &&
12958 N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
12959 N0.hasOneUse()) {
12960 SDValue ShAmt = N0.getOperand(1);
12961 if (N0.getOpcode() == ISD::SHL) {
12962 SDValue InnerZExt = N0.getOperand(0);
12963 // If the original shl may be shifting out bits, do not perform this
12964 // transformation.
12965 unsigned KnownZeroBits = InnerZExt.getValueSizeInBits() -
12966 InnerZExt.getOperand(0).getValueSizeInBits();
12967 if (cast<ConstantSDNode>(ShAmt)->getAPIntValue().ugt(KnownZeroBits))
12968 return SDValue();
12969 }
12970
12971 SDLoc DL(N);
12972
12973 // Ensure that the shift amount is wide enough for the shifted value.
12974 if (Log2_32_Ceil(VT.getSizeInBits()) > ShAmt.getValueSizeInBits())
12975 ShAmt = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, ShAmt);
12976
12977 return DAG.getNode(N0.getOpcode(), DL, VT,
12978 DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N0.getOperand(0)),
12979 ShAmt);
12980 }
12981
12982 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
12983 return NewVSel;
12984
12985 if (SDValue NewCtPop = widenCtPop(N, DAG))
12986 return NewCtPop;
12987
12988 if (SDValue V = widenAbs(N, DAG))
12989 return V;
12990
12991 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
12992 return Res;
12993
12994 return SDValue();
12995 }
12996
12997 SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
12998 SDValue N0 = N->getOperand(0);
12999 EVT VT = N->getValueType(0);
13000
13001 // aext(undef) = undef
13002 if (N0.isUndef())
13003 return DAG.getUNDEF(VT);
13004
13005 if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
13006 return Res;
13007
13008 // fold (aext (aext x)) -> (aext x)
13009 // fold (aext (zext x)) -> (zext x)
13010 // fold (aext (sext x)) -> (sext x)
13011 if (N0.getOpcode() == ISD::ANY_EXTEND ||
13012 N0.getOpcode() == ISD::ZERO_EXTEND ||
13013 N0.getOpcode() == ISD::SIGN_EXTEND)
13014 return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0));
13015
13016   // fold (aext (truncate (load x))) -> (aext (smaller load x))
13017   // fold (aext (truncate (srl (load x), c))) -> (aext (smaller load (x+c/n)))
13018 if (N0.getOpcode() == ISD::TRUNCATE) {
13019 if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
13020 SDNode *oye = N0.getOperand(0).getNode();
13021 if (NarrowLoad.getNode() != N0.getNode()) {
13022 CombineTo(N0.getNode(), NarrowLoad);
13023 // CombineTo deleted the truncate, if needed, but not what's under it.
13024 AddToWorklist(oye);
13025 }
13026 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13027 }
13028 }
13029
13030 // fold (aext (truncate x))
13031 if (N0.getOpcode() == ISD::TRUNCATE)
13032 return DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);
13033
13034 // Fold (aext (and (trunc x), cst)) -> (and x, cst)
13035 // if the trunc is not free.
13036 if (N0.getOpcode() == ISD::AND &&
13037 N0.getOperand(0).getOpcode() == ISD::TRUNCATE &&
13038 N0.getOperand(1).getOpcode() == ISD::Constant &&
13039 !TLI.isTruncateFree(N0.getOperand(0).getOperand(0).getValueType(),
13040 N0.getValueType())) {
13041 SDLoc DL(N);
13042 SDValue X = DAG.getAnyExtOrTrunc(N0.getOperand(0).getOperand(0), DL, VT);
13043 SDValue Y = DAG.getNode(ISD::ANY_EXTEND, DL, VT, N0.getOperand(1));
13044 assert(isa<ConstantSDNode>(Y) && "Expected constant to be folded!");
13045 return DAG.getNode(ISD::AND, DL, VT, X, Y);
13046 }
13047
13048 // fold (aext (load x)) -> (aext (truncate (extload x)))
13049 // None of the supported targets knows how to perform load and any_ext
13050 // on vectors in one instruction, so attempt to fold to zext instead.
13051 if (VT.isVector()) {
13052 // Try to simplify (zext (load x)).
13053 if (SDValue foldedExt =
13054 tryToFoldExtOfLoad(DAG, *this, TLI, VT, LegalOperations, N, N0,
13055 ISD::ZEXTLOAD, ISD::ZERO_EXTEND))
13056 return foldedExt;
13057 } else if (ISD::isNON_EXTLoad(N0.getNode()) &&
13058 ISD::isUNINDEXEDLoad(N0.getNode()) &&
13059 TLI.isLoadExtLegal(ISD::EXTLOAD, VT, N0.getValueType())) {
13060 bool DoXform = true;
13061 SmallVector<SDNode *, 4> SetCCs;
13062 if (!N0.hasOneUse())
13063 DoXform =
13064 ExtendUsesToFormExtLoad(VT, N, N0, ISD::ANY_EXTEND, SetCCs, TLI);
13065 if (DoXform) {
13066 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13067 SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
13068 LN0->getChain(), LN0->getBasePtr(),
13069 N0.getValueType(), LN0->getMemOperand());
13070 ExtendSetCCUses(SetCCs, N0, ExtLoad, ISD::ANY_EXTEND);
13071 // If the load value is used only by N, replace it via CombineTo N.
13072 bool NoReplaceTrunc = N0.hasOneUse();
13073 CombineTo(N, ExtLoad);
13074 if (NoReplaceTrunc) {
13075 DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
13076 recursivelyDeleteUnusedNodes(LN0);
13077 } else {
13078 SDValue Trunc =
13079 DAG.getNode(ISD::TRUNCATE, SDLoc(N0), N0.getValueType(), ExtLoad);
13080 CombineTo(LN0, Trunc, ExtLoad.getValue(1));
13081 }
13082 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13083 }
13084 }
13085
13086 // fold (aext (zextload x)) -> (aext (truncate (zextload x)))
13087 // fold (aext (sextload x)) -> (aext (truncate (sextload x)))
13088 // fold (aext ( extload x)) -> (aext (truncate (extload x)))
13089 if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N0.getNode()) &&
13090 ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse()) {
13091 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13092 ISD::LoadExtType ExtType = LN0->getExtensionType();
13093 EVT MemVT = LN0->getMemoryVT();
13094 if (!LegalOperations || TLI.isLoadExtLegal(ExtType, VT, MemVT)) {
13095 SDValue ExtLoad = DAG.getExtLoad(ExtType, SDLoc(N),
13096 VT, LN0->getChain(), LN0->getBasePtr(),
13097 MemVT, LN0->getMemOperand());
13098 CombineTo(N, ExtLoad);
13099 DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), ExtLoad.getValue(1));
13100 recursivelyDeleteUnusedNodes(LN0);
13101 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13102 }
13103 }
13104
13105 if (N0.getOpcode() == ISD::SETCC) {
13106 // Propagate fast-math-flags.
13107 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
13108
13109 // For vectors:
13110 // aext(setcc) -> vsetcc
13111 // aext(setcc) -> truncate(vsetcc)
13112 // aext(setcc) -> aext(vsetcc)
13113 // Only do this before legalize for now.
13114 if (VT.isVector() && !LegalOperations) {
13115 EVT N00VT = N0.getOperand(0).getValueType();
13116 if (getSetCCResultType(N00VT) == N0.getValueType())
13117 return SDValue();
13118
13119       // We know that the # elements of the result is the same as the
13120       // # elements of the compare (and the # elements of the compare result
13121       // for that matter). Check to see that they are the same size. If so,
13122       // we know that the element size of the extended result matches the
13123       // element size of the compare operands.
13124 if (VT.getSizeInBits() == N00VT.getSizeInBits())
13125 return DAG.getSetCC(SDLoc(N), VT, N0.getOperand(0),
13126 N0.getOperand(1),
13127 cast<CondCodeSDNode>(N0.getOperand(2))->get());
13128
13129 // If the desired elements are smaller or larger than the source
13130 // elements we can use a matching integer vector type and then
13131 // truncate/any extend
13132 EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
13133 SDValue VsetCC =
13134 DAG.getSetCC(SDLoc(N), MatchingVectorType, N0.getOperand(0),
13135 N0.getOperand(1),
13136 cast<CondCodeSDNode>(N0.getOperand(2))->get());
13137 return DAG.getAnyExtOrTrunc(VsetCC, SDLoc(N), VT);
13138 }
13139
13140 // aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
13141 SDLoc DL(N);
13142 if (SDValue SCC = SimplifySelectCC(
13143 DL, N0.getOperand(0), N0.getOperand(1), DAG.getConstant(1, DL, VT),
13144 DAG.getConstant(0, DL, VT),
13145 cast<CondCodeSDNode>(N0.getOperand(2))->get(), true))
13146 return SCC;
13147 }
13148
13149 if (SDValue NewCtPop = widenCtPop(N, DAG))
13150 return NewCtPop;
13151
13152 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
13153 return Res;
13154
13155 return SDValue();
13156 }
13157
13158 SDValue DAGCombiner::visitAssertExt(SDNode *N) {
13159 unsigned Opcode = N->getOpcode();
13160 SDValue N0 = N->getOperand(0);
13161 SDValue N1 = N->getOperand(1);
13162 EVT AssertVT = cast<VTSDNode>(N1)->getVT();
13163
13164 // fold (assert?ext (assert?ext x, vt), vt) -> (assert?ext x, vt)
13165 if (N0.getOpcode() == Opcode &&
13166 AssertVT == cast<VTSDNode>(N0.getOperand(1))->getVT())
13167 return N0;
13168
13169 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
13170 N0.getOperand(0).getOpcode() == Opcode) {
13171     // We have an assert, truncate, assert sandwich. Make one stronger assert
13172     // by applying the smaller of the two asserted types to the larger source
13173     // value. This eliminates the later assert:
13174 // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN
13175 // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN
13176 SDLoc DL(N);
13177 SDValue BigA = N0.getOperand(0);
13178 EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
13179 EVT MinAssertVT = AssertVT.bitsLT(BigA_AssertVT) ? AssertVT : BigA_AssertVT;
13180 SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT);
13181 SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
13182 BigA.getOperand(0), MinAssertVTVal);
13183 return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
13184 }
13185
13186   // If we have (AssertZext (truncate (AssertSext X, iX)), iY) and Y is smaller
13187   // than X, just move the AssertZext in front of the truncate and drop the
13188   // AssertSext.
13189 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
13190 N0.getOperand(0).getOpcode() == ISD::AssertSext &&
13191 Opcode == ISD::AssertZext) {
13192 SDValue BigA = N0.getOperand(0);
13193 EVT BigA_AssertVT = cast<VTSDNode>(BigA.getOperand(1))->getVT();
13194 if (AssertVT.bitsLT(BigA_AssertVT)) {
13195 SDLoc DL(N);
13196 SDValue NewAssert = DAG.getNode(Opcode, DL, BigA.getValueType(),
13197 BigA.getOperand(0), N1);
13198 return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewAssert);
13199 }
13200 }
13201
13202 return SDValue();
13203 }
13204
13205 SDValue DAGCombiner::visitAssertAlign(SDNode *N) {
13206 SDLoc DL(N);
13207
13208 Align AL = cast<AssertAlignSDNode>(N)->getAlign();
13209 SDValue N0 = N->getOperand(0);
13210
13211 // Fold (assertalign (assertalign x, AL0), AL1) ->
13212 // (assertalign x, max(AL0, AL1))
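  // E.g. (assertalign (assertalign x, 4), 16) -> (assertalign x, 16).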
13213 if (auto *AAN = dyn_cast<AssertAlignSDNode>(N0))
13214 return DAG.getAssertAlign(DL, N0.getOperand(0),
13215 std::max(AL, AAN->getAlign()));
13216
13217   // In rare cases, there are trivial arithmetic ops in source operands. Sink
13218   // this assert down to source operands so that those arithmetic ops can be
13219   // exposed to DAG combining.
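  // E.g. (assertalign (add x, 32), 16): the constant 32 is already 16-byte
  // aligned, so this may become (add (assertalign x, 16), 32), exposing the
  // add to further combines.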
13220 switch (N0.getOpcode()) {
13221 default:
13222 break;
13223 case ISD::ADD:
13224 case ISD::SUB: {
13225 unsigned AlignShift = Log2(AL);
13226 SDValue LHS = N0.getOperand(0);
13227 SDValue RHS = N0.getOperand(1);
13228 unsigned LHSAlignShift = DAG.computeKnownBits(LHS).countMinTrailingZeros();
13229 unsigned RHSAlignShift = DAG.computeKnownBits(RHS).countMinTrailingZeros();
13230 if (LHSAlignShift >= AlignShift || RHSAlignShift >= AlignShift) {
13231 if (LHSAlignShift < AlignShift)
13232 LHS = DAG.getAssertAlign(DL, LHS, AL);
13233 if (RHSAlignShift < AlignShift)
13234 RHS = DAG.getAssertAlign(DL, RHS, AL);
13235 return DAG.getNode(N0.getOpcode(), DL, N0.getValueType(), LHS, RHS);
13236 }
13237 break;
13238 }
13239 }
13240
13241 return SDValue();
13242 }
13243
13244 /// If the result of a load is shifted/masked/truncated to an effectively
13245 /// narrower type, try to transform the load to a narrower type and/or
13246 /// use an extending load.
13247 SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
13248 unsigned Opc = N->getOpcode();
13249
13250 ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
13251 SDValue N0 = N->getOperand(0);
13252 EVT VT = N->getValueType(0);
13253 EVT ExtVT = VT;
13254
13255 // This transformation isn't valid for vector loads.
13256 if (VT.isVector())
13257 return SDValue();
13258
13259   // The ShAmt variable is used to indicate that we've consumed a right shift,
13260   // i.e. we want to narrow the width of the load by skipping the ShAmt least
13261   // significant bits.
13262 unsigned ShAmt = 0;
13263 // A special case is when the least significant bits from the load are masked
13264 // away, but using an AND rather than a right shift. HasShiftedOffset is used
13265 // to indicate that the narrowed load should be left-shifted ShAmt bits to get
13266 // the result.
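  // E.g. (and (load i32 p), 0x00ffff00) has ShAmt == 8 and a 16-bit narrowed
  // load; the loaded value must then be shifted left by 8 bits to put it back
  // into its original bit positions.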
13267 bool HasShiftedOffset = false;
13268   // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT and then
13269   // extending to VT.
13270 if (Opc == ISD::SIGN_EXTEND_INREG) {
13271 ExtType = ISD::SEXTLOAD;
13272 ExtVT = cast<VTSDNode>(N->getOperand(1))->getVT();
13273 } else if (Opc == ISD::SRL || Opc == ISD::SRA) {
13274 // Another special-case: SRL/SRA is basically zero/sign-extending a narrower
13275 // value, or it may be shifting a higher subword, half or byte into the
13276 // lowest bits.
13277
13278 // Only handle shift with constant shift amount, and the shiftee must be a
13279 // load.
13280 auto *LN = dyn_cast<LoadSDNode>(N0);
13281 auto *N1C = dyn_cast<ConstantSDNode>(N->getOperand(1));
13282 if (!N1C || !LN)
13283 return SDValue();
13284 // If the shift amount is larger than the memory type then we're not
13285 // accessing any of the loaded bytes.
13286 ShAmt = N1C->getZExtValue();
13287 uint64_t MemoryWidth = LN->getMemoryVT().getScalarSizeInBits();
13288 if (MemoryWidth <= ShAmt)
13289 return SDValue();
13290 // Attempt to fold away the SRL by using ZEXTLOAD and SRA by using SEXTLOAD.
13291 ExtType = Opc == ISD::SRL ? ISD::ZEXTLOAD : ISD::SEXTLOAD;
13292 ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
13293     // If the original load is a SEXTLOAD then we can't simply replace it with a
13294     // ZEXTLOAD (we could potentially replace it with a narrower SEXTLOAD
13295     // followed by a ZEXT, but that is not handled at the moment). The same
13296     // applies if the original load is a ZEXTLOAD and we want to use a SEXTLOAD.
13297 if ((LN->getExtensionType() == ISD::SEXTLOAD ||
13298 LN->getExtensionType() == ISD::ZEXTLOAD) &&
13299 LN->getExtensionType() != ExtType)
13300 return SDValue();
13301 } else if (Opc == ISD::AND) {
13302 // An AND with a constant mask is the same as a truncate + zero-extend.
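    // E.g. (and (load i32 p), 0xffff) behaves like (zext (i16 trunc (load p)))
    // and may be narrowed to a 16-bit zextload.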
13303 auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1));
13304 if (!AndC)
13305 return SDValue();
13306
13307 const APInt &Mask = AndC->getAPIntValue();
13308 unsigned ActiveBits = 0;
13309 if (Mask.isMask()) {
13310 ActiveBits = Mask.countTrailingOnes();
13311 } else if (Mask.isShiftedMask(ShAmt, ActiveBits)) {
13312 HasShiftedOffset = true;
13313 } else {
13314 return SDValue();
13315 }
13316
13317 ExtType = ISD::ZEXTLOAD;
13318 ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
13319 }
13320
13321 // In case Opc==SRL we've already prepared ExtVT/ExtType/ShAmt based on doing
13322 // a right shift. Here we redo some of those checks, to possibly adjust the
13323 // ExtVT even further based on "a masking AND". We could also end up here for
13324 // other reasons (e.g. based on Opc==TRUNCATE) and that is why some checks
13325 // need to be done here as well.
13326 if (Opc == ISD::SRL || N0.getOpcode() == ISD::SRL) {
13327 SDValue SRL = Opc == ISD::SRL ? SDValue(N, 0) : N0;
13328     // Bail out when the SRL has more than one use. This is done for historical
13329     // (undocumented) reasons. Maybe the intent was to guard the AND-masking
13330     // check below? And maybe it could be unprofitable to do the transform when
13331     // the SRL has multiple uses and we get here with Opc != ISD::SRL?
13332     // FIXME: Can't we just skip this check for the Opc==ISD::SRL case?
13333 if (!SRL.hasOneUse())
13334 return SDValue();
13335
13336 // Only handle shift with constant shift amount, and the shiftee must be a
13337 // load.
13338 auto *LN = dyn_cast<LoadSDNode>(SRL.getOperand(0));
13339 auto *SRL1C = dyn_cast<ConstantSDNode>(SRL.getOperand(1));
13340 if (!SRL1C || !LN)
13341 return SDValue();
13342
13343 // If the shift amount is larger than the input type then we're not
13344 // accessing any of the loaded bytes. If the load was a zextload/extload
13345 // then the result of the shift+trunc is zero/undef (handled elsewhere).
13346 ShAmt = SRL1C->getZExtValue();
13347 uint64_t MemoryWidth = LN->getMemoryVT().getSizeInBits();
13348 if (ShAmt >= MemoryWidth)
13349 return SDValue();
13350
13351 // Because a SRL must be assumed to *need* to zero-extend the high bits
13352 // (as opposed to anyext the high bits), we can't combine the zextload
13353 // lowering of SRL and an sextload.
13354 if (LN->getExtensionType() == ISD::SEXTLOAD)
13355 return SDValue();
13356
13357     // Avoid reading outside the memory accessed by the original load (which
13358     // could happen if we only adjusted the load base pointer by ShAmt).
13359     // Instead we try to narrow the load even further. The typical scenario here is:
13360 // (i64 (truncate (i96 (srl (load x), 64)))) ->
13361 // (i64 (truncate (i96 (zextload (load i32 + offset) from i32))))
13362 if (ExtVT.getScalarSizeInBits() > MemoryWidth - ShAmt) {
13363 // Don't replace sextload by zextload.
13364 if (ExtType == ISD::SEXTLOAD)
13365 return SDValue();
13366 // Narrow the load.
13367 ExtType = ISD::ZEXTLOAD;
13368 ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
13369 }
13370
13371 // If the SRL is only used by a masking AND, we may be able to adjust
13372 // the ExtVT to make the AND redundant.
13373 SDNode *Mask = *(SRL->use_begin());
13374 if (SRL.hasOneUse() && Mask->getOpcode() == ISD::AND &&
13375 isa<ConstantSDNode>(Mask->getOperand(1))) {
13376 const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
13377 if (ShiftMask.isMask()) {
13378 EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(),
13379 ShiftMask.countTrailingOnes());
13380 // If the mask is smaller, recompute the type.
13381 if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
13382 TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT))
13383 ExtVT = MaskedVT;
13384 }
13385 }
13386
13387 N0 = SRL.getOperand(0);
13388 }
13389
13390 // If the load is shifted left (and the result isn't shifted back right), we
13391 // can fold a truncate through the shift. The typical scenario is that N
13392 // points at a TRUNCATE here so the attempted fold is:
13393 // (truncate (shl (load x), c))) -> (shl (narrow load x), c)
13394 // ShLeftAmt will indicate how much a narrowed load should be shifted left.
13395 unsigned ShLeftAmt = 0;
13396 if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
13397 ExtVT == VT && TLI.isNarrowingProfitable(N0.getValueType(), VT)) {
13398 if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
13399 ShLeftAmt = N01->getZExtValue();
13400 N0 = N0.getOperand(0);
13401 }
13402 }
13403
13404 // If we haven't found a load, we can't narrow it.
13405 if (!isa<LoadSDNode>(N0))
13406 return SDValue();
13407
13408 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13409 // Reducing the width of a volatile load is illegal. For atomics, we may be
13410 // able to reduce the width provided we never widen again. (see D66309)
13411 if (!LN0->isSimple() ||
13412 !isLegalNarrowLdSt(LN0, ExtType, ExtVT, ShAmt))
13413 return SDValue();
13414
13415 auto AdjustBigEndianShift = [&](unsigned ShAmt) {
13416 unsigned LVTStoreBits =
13417 LN0->getMemoryVT().getStoreSizeInBits().getFixedValue();
13418 unsigned EVTStoreBits = ExtVT.getStoreSizeInBits().getFixedValue();
13419 return LVTStoreBits - EVTStoreBits - ShAmt;
13420 };
13421
13422 // We need to adjust the pointer to the load by ShAmt bits in order to load
13423 // the correct bytes.
13424 unsigned PtrAdjustmentInBits =
13425 DAG.getDataLayout().isBigEndian() ? AdjustBigEndianShift(ShAmt) : ShAmt;
13426
13427 uint64_t PtrOff = PtrAdjustmentInBits / 8;
13428 Align NewAlign = commonAlignment(LN0->getAlign(), PtrOff);
13429 SDLoc DL(LN0);
13430 // The original load itself didn't wrap, so an offset within it doesn't.
13431 SDNodeFlags Flags;
13432 Flags.setNoUnsignedWrap(true);
13433 SDValue NewPtr = DAG.getMemBasePlusOffset(LN0->getBasePtr(),
13434 TypeSize::Fixed(PtrOff), DL, Flags);
13435 AddToWorklist(NewPtr.getNode());
13436
13437 SDValue Load;
13438 if (ExtType == ISD::NON_EXTLOAD)
13439 Load = DAG.getLoad(VT, DL, LN0->getChain(), NewPtr,
13440 LN0->getPointerInfo().getWithOffset(PtrOff), NewAlign,
13441 LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
13442 else
13443 Load = DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(), NewPtr,
13444 LN0->getPointerInfo().getWithOffset(PtrOff), ExtVT,
13445 NewAlign, LN0->getMemOperand()->getFlags(),
13446 LN0->getAAInfo());
13447
13448 // Replace the old load's chain with the new load's chain.
13449 WorklistRemover DeadNodes(*this);
13450 DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
13451
13452 // Shift the result left, if we've swallowed a left shift.
13453 SDValue Result = Load;
13454 if (ShLeftAmt != 0) {
13455 EVT ShImmTy = getShiftAmountTy(Result.getValueType());
13456 if (!isUIntN(ShImmTy.getScalarSizeInBits(), ShLeftAmt))
13457 ShImmTy = VT;
13458 // If the shift amount is as large as the result size (but, presumably,
13459 // no larger than the source) then the useful bits of the result are
13460 // zero; we can't simply return the shortened shift, because the result
13461 // of that operation is undefined.
13462 if (ShLeftAmt >= VT.getScalarSizeInBits())
13463 Result = DAG.getConstant(0, DL, VT);
13464 else
13465 Result = DAG.getNode(ISD::SHL, DL, VT,
13466 Result, DAG.getConstant(ShLeftAmt, DL, ShImmTy));
13467 }
13468
13469 if (HasShiftedOffset) {
13470     // We're using a shifted mask, so the load now has an offset. This means
13471     // the data has been loaded into lower bits of the register than it
13472     // otherwise would have been, so we need to shl the loaded data back into
13473     // the correct position.
13474 SDValue ShiftC = DAG.getConstant(ShAmt, DL, VT);
13475 Result = DAG.getNode(ISD::SHL, DL, VT, Result, ShiftC);
13476 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result);
13477 }
13478
13479 // Return the new loaded value.
13480 return Result;
13481 }
13482
13483 SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
13484 SDValue N0 = N->getOperand(0);
13485 SDValue N1 = N->getOperand(1);
13486 EVT VT = N->getValueType(0);
13487 EVT ExtVT = cast<VTSDNode>(N1)->getVT();
13488 unsigned VTBits = VT.getScalarSizeInBits();
13489 unsigned ExtVTBits = ExtVT.getScalarSizeInBits();
13490
13491   // sext_in_reg(undef) = 0 because the top bits will all be the same.
13492 if (N0.isUndef())
13493 return DAG.getConstant(0, SDLoc(N), VT);
13494
13495 // fold (sext_in_reg c1) -> c1
13496 if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
13497 return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0, N1);
13498
13499 // If the input is already sign extended, just drop the extension.
13500 if (ExtVTBits >= DAG.ComputeMaxSignificantBits(N0))
13501 return N0;
13502
13503 // fold (sext_in_reg (sext_in_reg x, VT2), VT1) -> (sext_in_reg x, minVT) pt2
13504 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
13505 ExtVT.bitsLT(cast<VTSDNode>(N0.getOperand(1))->getVT()))
13506 return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0.getOperand(0),
13507 N1);
13508
13509 // fold (sext_in_reg (sext x)) -> (sext x)
13510 // fold (sext_in_reg (aext x)) -> (sext x)
13511 // if x is small enough or if we know that x has more than 1 sign bit and the
13512 // sign_extend_inreg is extending from one of them.
13513 if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
13514 SDValue N00 = N0.getOperand(0);
13515 unsigned N00Bits = N00.getScalarValueSizeInBits();
13516 if ((N00Bits <= ExtVTBits ||
13517 DAG.ComputeMaxSignificantBits(N00) <= ExtVTBits) &&
13518 (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
13519 return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00);
13520 }
13521
13522 // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
13523 // if x is small enough or if we know that x has more than 1 sign bit and the
13524 // sign_extend_inreg is extending from one of them.
13525 if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
13526 N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG ||
13527 N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG) {
13528 SDValue N00 = N0.getOperand(0);
13529 unsigned N00Bits = N00.getScalarValueSizeInBits();
13530 unsigned DstElts = N0.getValueType().getVectorMinNumElements();
13531 unsigned SrcElts = N00.getValueType().getVectorMinNumElements();
13532 bool IsZext = N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
13533 APInt DemandedSrcElts = APInt::getLowBitsSet(SrcElts, DstElts);
13534 if ((N00Bits == ExtVTBits ||
13535 (!IsZext && (N00Bits < ExtVTBits ||
13536 DAG.ComputeMaxSignificantBits(N00) <= ExtVTBits))) &&
13537 (!LegalOperations ||
13538 TLI.isOperationLegal(ISD::SIGN_EXTEND_VECTOR_INREG, VT)))
13539 return DAG.getNode(ISD::SIGN_EXTEND_VECTOR_INREG, SDLoc(N), VT, N00);
13540 }
13541
13542 // fold (sext_in_reg (zext x)) -> (sext x)
13543 // iff we are extending the source sign bit.
13544 if (N0.getOpcode() == ISD::ZERO_EXTEND) {
13545 SDValue N00 = N0.getOperand(0);
13546 if (N00.getScalarValueSizeInBits() == ExtVTBits &&
13547 (!LegalOperations || TLI.isOperationLegal(ISD::SIGN_EXTEND, VT)))
13548 return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, N00, N1);
13549 }
13550
13551 // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
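  // E.g. (sext_in_reg x, i8) with bit 7 of x known zero is equivalent to
  // (and x, 0xff).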
13552 if (DAG.MaskedValueIsZero(N0, APInt::getOneBitSet(VTBits, ExtVTBits - 1)))
13553 return DAG.getZeroExtendInReg(N0, SDLoc(N), ExtVT);
13554
13555 // fold operands of sext_in_reg based on knowledge that the top bits are not
13556 // demanded.
13557 if (SimplifyDemandedBits(SDValue(N, 0)))
13558 return SDValue(N, 0);
13559
13560 // fold (sext_in_reg (load x)) -> (smaller sextload x)
13561 // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
13562 if (SDValue NarrowLoad = reduceLoadWidth(N))
13563 return NarrowLoad;
13564
13565 // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
13566 // fold (sext_in_reg (srl X, 23), i8) -> (sra X, 23) iff possible.
13567 // We already fold "(sext_in_reg (srl X, 25), i8) -> srl X, 25" above.
13568 if (N0.getOpcode() == ISD::SRL) {
13569 if (auto *ShAmt = dyn_cast<ConstantSDNode>(N0.getOperand(1)))
13570 if (ShAmt->getAPIntValue().ule(VTBits - ExtVTBits)) {
13571 // We can turn this into an SRA iff the input to the SRL is already sign
13572 // extended enough.
13573 unsigned InSignBits = DAG.ComputeNumSignBits(N0.getOperand(0));
13574 if (((VTBits - ExtVTBits) - ShAmt->getZExtValue()) < InSignBits)
13575 return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0.getOperand(0),
13576 N0.getOperand(1));
13577 }
13578 }
13579
13580 // fold (sext_inreg (extload x)) -> (sextload x)
13581 // If sextload is not supported by target, we can only do the combine when
13582 // load has one use. Doing otherwise can block folding the extload with other
13583 // extends that the target does support.
13584 if (ISD::isEXTLoad(N0.getNode()) &&
13585 ISD::isUNINDEXEDLoad(N0.getNode()) &&
13586 ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
13587 ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple() &&
13588 N0.hasOneUse()) ||
13589 TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
13590 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13591 SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
13592 LN0->getChain(),
13593 LN0->getBasePtr(), ExtVT,
13594 LN0->getMemOperand());
13595 CombineTo(N, ExtLoad);
13596 CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
13597 AddToWorklist(ExtLoad.getNode());
13598 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13599 }
13600
13601 // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
13602 if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
13603 N0.hasOneUse() &&
13604 ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
13605 ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) &&
13606 TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
13607 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13608 SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, SDLoc(N), VT,
13609 LN0->getChain(),
13610 LN0->getBasePtr(), ExtVT,
13611 LN0->getMemOperand());
13612 CombineTo(N, ExtLoad);
13613 CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
13614 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13615 }
13616
13617 // fold (sext_inreg (masked_load x)) -> (sext_masked_load x)
13618 // ignore it if the masked load is already sign extended
13619 if (MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0)) {
13620 if (ExtVT == Ld->getMemoryVT() && N0.hasOneUse() &&
13621 Ld->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD &&
13622 TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) {
13623 SDValue ExtMaskedLoad = DAG.getMaskedLoad(
13624 VT, SDLoc(N), Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(),
13625 Ld->getMask(), Ld->getPassThru(), ExtVT, Ld->getMemOperand(),
13626 Ld->getAddressingMode(), ISD::SEXTLOAD, Ld->isExpandingLoad());
13627 CombineTo(N, ExtMaskedLoad);
13628 CombineTo(N0.getNode(), ExtMaskedLoad, ExtMaskedLoad.getValue(1));
13629 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13630 }
13631 }
13632
13633 // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x)
13634 if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
13635 if (SDValue(GN0, 0).hasOneUse() &&
13636 ExtVT == GN0->getMemoryVT() &&
13637 TLI.isVectorLoadExtDesirable(SDValue(SDValue(GN0, 0)))) {
13638 SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
13639 GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};
13640
13641 SDValue ExtLoad = DAG.getMaskedGather(
13642 DAG.getVTList(VT, MVT::Other), ExtVT, SDLoc(N), Ops,
13643 GN0->getMemOperand(), GN0->getIndexType(), ISD::SEXTLOAD);
13644
13645 CombineTo(N, ExtLoad);
13646 CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
13647 AddToWorklist(ExtLoad.getNode());
13648 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13649 }
13650 }
13651
13652 // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
13653 if (ExtVTBits <= 16 && N0.getOpcode() == ISD::OR) {
13654 if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
13655 N0.getOperand(1), false))
13656 return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, BSwap, N1);
13657 }
13658
13659 // Fold (iM_signext_inreg
13660 // (extract_subvector (zext|anyext|sext iN_v to _) _)
13661 // from iN)
13662 // -> (extract_subvector (signext iN_v to iM))
13663 if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() &&
13664 ISD::isExtOpcode(N0.getOperand(0).getOpcode())) {
13665 SDValue InnerExt = N0.getOperand(0);
13666 EVT InnerExtVT = InnerExt->getValueType(0);
13667 SDValue Extendee = InnerExt->getOperand(0);
13668
13669 if (ExtVTBits == Extendee.getValueType().getScalarSizeInBits() &&
13670 (!LegalOperations ||
13671 TLI.isOperationLegal(ISD::SIGN_EXTEND, InnerExtVT))) {
13672 SDValue SignExtExtendee =
13673 DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), InnerExtVT, Extendee);
13674 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), VT, SignExtExtendee,
13675 N0.getOperand(1));
13676 }
13677 }
13678
13679 return SDValue();
13680 }
13681
13682 static SDValue
13683 foldExtendVectorInregToExtendOfSubvector(SDNode *N, const TargetLowering &TLI,
13684 SelectionDAG &DAG,
13685 bool LegalOperations) {
13686 unsigned InregOpcode = N->getOpcode();
13687 unsigned Opcode = DAG.getOpcode_EXTEND(InregOpcode);
13688
13689 SDValue Src = N->getOperand(0);
13690 EVT VT = N->getValueType(0);
13691 EVT SrcVT = EVT::getVectorVT(*DAG.getContext(),
13692 Src.getValueType().getVectorElementType(),
13693 VT.getVectorElementCount());
13694
13695 assert((InregOpcode == ISD::SIGN_EXTEND_VECTOR_INREG ||
13696 InregOpcode == ISD::ZERO_EXTEND_VECTOR_INREG ||
13697 InregOpcode == ISD::ANY_EXTEND_VECTOR_INREG) &&
13698 "Expected EXTEND_VECTOR_INREG dag node in input!");
13699
13700   // Profitability check: our operand must be a one-use CONCAT_VECTORS.
13701 // FIXME: one-use check may be overly restrictive
13702 if (!Src.hasOneUse() || Src.getOpcode() != ISD::CONCAT_VECTORS)
13703 return SDValue();
13704
13705   // Profitability check: we must be extending exactly one of its operands.
13706 // FIXME: this is probably overly restrictive.
13707 Src = Src.getOperand(0);
13708 if (Src.getValueType() != SrcVT)
13709 return SDValue();
13710
13711 if (LegalOperations && !TLI.isOperationLegal(Opcode, VT))
13712 return SDValue();
13713
13714 return DAG.getNode(Opcode, SDLoc(N), VT, Src);
13715 }
13716
13717 SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
13718 SDValue N0 = N->getOperand(0);
13719 EVT VT = N->getValueType(0);
13720
13721 if (N0.isUndef()) {
13722 // aext_vector_inreg(undef) = undef because the top bits are undefined.
13723 // {s/z}ext_vector_inreg(undef) = 0 because the top bits must be the same.
13724 return N->getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG
13725 ? DAG.getUNDEF(VT)
13726 : DAG.getConstant(0, SDLoc(N), VT);
13727 }
13728
13729 if (SDValue Res = tryToFoldExtendOfConstant(N, TLI, DAG, LegalTypes))
13730 return Res;
13731
13732 if (SimplifyDemandedVectorElts(SDValue(N, 0)))
13733 return SDValue(N, 0);
13734
13735 if (SDValue R = foldExtendVectorInregToExtendOfSubvector(N, TLI, DAG,
13736 LegalOperations))
13737 return R;
13738
13739 return SDValue();
13740 }
13741
13742 SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
13743 SDValue N0 = N->getOperand(0);
13744 EVT VT = N->getValueType(0);
13745 EVT SrcVT = N0.getValueType();
13746 bool isLE = DAG.getDataLayout().isLittleEndian();
13747
13748 // noop truncate
13749 if (SrcVT == VT)
13750 return N0;
13751
13752 // fold (truncate (truncate x)) -> (truncate x)
13753 if (N0.getOpcode() == ISD::TRUNCATE)
13754 return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0));
13755
13756 // fold (truncate c1) -> c1
13757 if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
13758 SDValue C = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0);
13759 if (C.getNode() != N)
13760 return C;
13761 }
13762
13763 // fold (truncate (ext x)) -> (ext x) or (truncate x) or x
13764 if (N0.getOpcode() == ISD::ZERO_EXTEND ||
13765 N0.getOpcode() == ISD::SIGN_EXTEND ||
13766 N0.getOpcode() == ISD::ANY_EXTEND) {
13767 // if the source is smaller than the dest, we still need an extend.
13768 if (N0.getOperand(0).getValueType().bitsLT(VT))
13769 return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, N0.getOperand(0));
13770     // if the source is larger than the dest, then we just need the truncate.
13771 if (N0.getOperand(0).getValueType().bitsGT(VT))
13772 return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0.getOperand(0));
13773 // if the source and dest are the same type, we can drop both the extend
13774 // and the truncate.
13775 return N0.getOperand(0);
13776 }
13777
13778 // Try to narrow a truncate-of-sext_in_reg to the destination type:
13779 // trunc (sign_ext_inreg X, iM) to iN --> sign_ext_inreg (trunc X to iN), iM
13780 if (!LegalTypes && N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
13781 N0.hasOneUse()) {
13782 SDValue X = N0.getOperand(0);
13783 SDValue ExtVal = N0.getOperand(1);
13784 EVT ExtVT = cast<VTSDNode>(ExtVal)->getVT();
13785 if (ExtVT.bitsLT(VT)) {
13786 SDValue TrX = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, X);
13787 return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, TrX, ExtVal);
13788 }
13789 }
13790
13791 // If this is anyext(trunc), don't fold it, allow ourselves to be folded.
13792 if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ANY_EXTEND))
13793 return SDValue();
13794
13795 // Fold extract-and-trunc into a narrow extract. For example:
13796 // i64 x = EXTRACT_VECTOR_ELT(v2i64 val, i32 1)
13797 // i32 y = TRUNCATE(i64 x)
13798 // -- becomes --
13799 // v16i8 b = BITCAST (v2i64 val)
13800 // i8 x = EXTRACT_VECTOR_ELT(v16i8 b, i32 8)
13801 //
13802 // Note: We only run this optimization after type legalization (which often
13803 // creates this pattern) and before operation legalization after which
13804 // we need to be more careful about the vector instructions that we generate.
13805 if (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
13806 LegalTypes && !LegalOperations && N0->hasOneUse() && VT != MVT::i1) {
13807 EVT VecTy = N0.getOperand(0).getValueType();
13808 EVT ExTy = N0.getValueType();
13809 EVT TrTy = N->getValueType(0);
13810
13811 auto EltCnt = VecTy.getVectorElementCount();
13812 unsigned SizeRatio = ExTy.getSizeInBits()/TrTy.getSizeInBits();
13813 auto NewEltCnt = EltCnt * SizeRatio;
13814
13815 EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, NewEltCnt);
13816 assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");
13817
13818 SDValue EltNo = N0->getOperand(1);
13819 if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
13820 int Elt = cast<ConstantSDNode>(EltNo)->getZExtValue();
13821 int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1));
13822
13823 SDLoc DL(N);
13824 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
13825 DAG.getBitcast(NVT, N0.getOperand(0)),
13826 DAG.getVectorIdxConstant(Index, DL));
13827 }
13828 }
13829
13830 // trunc (select c, a, b) -> select c, (trunc a), (trunc b)
13831 if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse()) {
13832 if ((!LegalOperations || TLI.isOperationLegal(ISD::SELECT, SrcVT)) &&
13833 TLI.isTruncateFree(SrcVT, VT)) {
13834 SDLoc SL(N0);
13835 SDValue Cond = N0.getOperand(0);
13836 SDValue TruncOp0 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
13837 SDValue TruncOp1 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(2));
13838 return DAG.getNode(ISD::SELECT, SDLoc(N), VT, Cond, TruncOp0, TruncOp1);
13839 }
13840 }
13841
13842   // trunc (shl x, K) -> shl (trunc x), K, provided K < VT.getScalarSizeInBits()
13843 if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
13844 (!LegalOperations || TLI.isOperationLegal(ISD::SHL, VT)) &&
13845 TLI.isTypeDesirableForOp(ISD::SHL, VT)) {
13846 SDValue Amt = N0.getOperand(1);
13847 KnownBits Known = DAG.computeKnownBits(Amt);
13848 unsigned Size = VT.getScalarSizeInBits();
13849 if (Known.countMaxActiveBits() <= Log2_32(Size)) {
13850 SDLoc SL(N);
13851 EVT AmtVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());
13852
13853 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0));
13854 if (AmtVT != Amt.getValueType()) {
13855 Amt = DAG.getZExtOrTrunc(Amt, SL, AmtVT);
13856 AddToWorklist(Amt.getNode());
13857 }
13858 return DAG.getNode(ISD::SHL, SL, VT, Trunc, Amt);
13859 }
13860 }
13861
13862 if (SDValue V = foldSubToUSubSat(VT, N0.getNode()))
13863 return V;
13864
13865 // Attempt to pre-truncate BUILD_VECTOR sources.
13866 if (N0.getOpcode() == ISD::BUILD_VECTOR && !LegalOperations &&
13867 TLI.isTruncateFree(SrcVT.getScalarType(), VT.getScalarType()) &&
13868 // Avoid creating illegal types if running after type legalizer.
13869 (!LegalTypes || TLI.isTypeLegal(VT.getScalarType()))) {
13870 SDLoc DL(N);
13871 EVT SVT = VT.getScalarType();
13872 SmallVector<SDValue, 8> TruncOps;
13873 for (const SDValue &Op : N0->op_values()) {
13874 SDValue TruncOp = DAG.getNode(ISD::TRUNCATE, DL, SVT, Op);
13875 TruncOps.push_back(TruncOp);
13876 }
13877 return DAG.getBuildVector(VT, DL, TruncOps);
13878 }
13879
13880 // Fold a series of buildvector, bitcast, and truncate if possible.
13881 // For example fold
13882 // (2xi32 trunc (bitcast ((4xi32)buildvector x, x, y, y) 2xi64)) to
13883 // (2xi32 (buildvector x, y)).
13884 if (Level == AfterLegalizeVectorOps && VT.isVector() &&
13885 N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
13886 N0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR &&
13887 N0.getOperand(0).hasOneUse()) {
13888 SDValue BuildVect = N0.getOperand(0);
13889 EVT BuildVectEltTy = BuildVect.getValueType().getVectorElementType();
13890 EVT TruncVecEltTy = VT.getVectorElementType();
13891
13892 // Check that the element types match.
13893 if (BuildVectEltTy == TruncVecEltTy) {
13894 // Now we only need to compute the offset of the truncated elements.
13895 unsigned BuildVecNumElts = BuildVect.getNumOperands();
13896 unsigned TruncVecNumElts = VT.getVectorNumElements();
13897 unsigned TruncEltOffset = BuildVecNumElts / TruncVecNumElts;
13898
13899 assert((BuildVecNumElts % TruncVecNumElts) == 0 &&
13900 "Invalid number of elements");
13901
13902 SmallVector<SDValue, 8> Opnds;
13903 for (unsigned i = 0, e = BuildVecNumElts; i != e; i += TruncEltOffset)
13904 Opnds.push_back(BuildVect.getOperand(i));
13905
13906 return DAG.getBuildVector(VT, SDLoc(N), Opnds);
13907 }
13908 }
13909
13910 // fold (truncate (load x)) -> (smaller load x)
13911 // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
13912 if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
13913 if (SDValue Reduced = reduceLoadWidth(N))
13914 return Reduced;
13915
13916 // Handle the case where the load remains an extending load even
13917 // after truncation.
13918 if (N0.hasOneUse() && ISD::isUNINDEXEDLoad(N0.getNode())) {
13919 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
13920 if (LN0->isSimple() && LN0->getMemoryVT().bitsLT(VT)) {
13921 SDValue NewLoad = DAG.getExtLoad(LN0->getExtensionType(), SDLoc(LN0),
13922 VT, LN0->getChain(), LN0->getBasePtr(),
13923 LN0->getMemoryVT(),
13924 LN0->getMemOperand());
13925 DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLoad.getValue(1));
13926 return NewLoad;
13927 }
13928 }
13929 }
13930
13931   // fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...),
13932   // where ... are all 'undef'.
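  // E.g. (v4i16 trunc (v4i32 concat v2i32:undef, v2i32:x)) ->
  //      (v4i16 concat v2i16:undef, (v2i16 trunc x))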
13933 if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) {
13934 SmallVector<EVT, 8> VTs;
13935 SDValue V;
13936 unsigned Idx = 0;
13937 unsigned NumDefs = 0;
13938
13939 for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) {
13940 SDValue X = N0.getOperand(i);
13941 if (!X.isUndef()) {
13942 V = X;
13943 Idx = i;
13944 NumDefs++;
13945 }
13946       // Stop if more than one member is non-undef.
13947 if (NumDefs > 1)
13948 break;
13949
13950 VTs.push_back(EVT::getVectorVT(*DAG.getContext(),
13951 VT.getVectorElementType(),
13952 X.getValueType().getVectorElementCount()));
13953 }
13954
13955 if (NumDefs == 0)
13956 return DAG.getUNDEF(VT);
13957
13958 if (NumDefs == 1) {
13959 assert(V.getNode() && "The single defined operand is empty!");
13960 SmallVector<SDValue, 8> Opnds;
13961 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
13962 if (i != Idx) {
13963 Opnds.push_back(DAG.getUNDEF(VTs[i]));
13964 continue;
13965 }
13966 SDValue NV = DAG.getNode(ISD::TRUNCATE, SDLoc(V), VTs[i], V);
13967 AddToWorklist(NV.getNode());
13968 Opnds.push_back(NV);
13969 }
13970 return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Opnds);
13971 }
13972 }
13973
13974 // Fold truncate of a bitcast of a vector to an extract of the low vector
13975 // element.
13976 //
13977 // e.g. trunc (i64 (bitcast v2i32:x)) -> extract_vector_elt v2i32:x, idx
13978 if (N0.getOpcode() == ISD::BITCAST && !VT.isVector()) {
13979 SDValue VecSrc = N0.getOperand(0);
13980 EVT VecSrcVT = VecSrc.getValueType();
13981 if (VecSrcVT.isVector() && VecSrcVT.getScalarType() == VT &&
13982 (!LegalOperations ||
13983 TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecSrcVT))) {
13984 SDLoc SL(N);
13985
13986 unsigned Idx = isLE ? 0 : VecSrcVT.getVectorNumElements() - 1;
13987 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, VT, VecSrc,
13988 DAG.getVectorIdxConstant(Idx, SL));
13989 }
13990 }
13991
13992 // Simplify the operands using demanded-bits information.
13993 if (SimplifyDemandedBits(SDValue(N, 0)))
13994 return SDValue(N, 0);
13995
13996 // fold (truncate (extract_subvector(ext x))) ->
13997 // (extract_subvector x)
13998 // TODO: This can be generalized to cover cases where the truncate and extract
13999 // do not fully cancel each other out.
14000 if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
14001 SDValue N00 = N0.getOperand(0);
14002 if (N00.getOpcode() == ISD::SIGN_EXTEND ||
14003 N00.getOpcode() == ISD::ZERO_EXTEND ||
14004 N00.getOpcode() == ISD::ANY_EXTEND) {
14005 if (N00.getOperand(0)->getValueType(0).getVectorElementType() ==
14006 VT.getVectorElementType())
14007 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N0->getOperand(0)), VT,
14008 N00.getOperand(0), N0.getOperand(1));
14009 }
14010 }
14011
14012 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
14013 return NewVSel;
14014
14015 // Narrow a suitable binary operation with a non-opaque constant operand by
14016 // moving it ahead of the truncate. This is limited to pre-legalization
14017 // because targets may prefer a wider type during later combines and invert
14018 // this transform.
14019 switch (N0.getOpcode()) {
14020 case ISD::ADD:
14021 case ISD::SUB:
14022 case ISD::MUL:
14023 case ISD::AND:
14024 case ISD::OR:
14025 case ISD::XOR:
14026 if (!LegalOperations && N0.hasOneUse() &&
14027 (isConstantOrConstantVector(N0.getOperand(0), true) ||
14028 isConstantOrConstantVector(N0.getOperand(1), true))) {
14029 // TODO: We already restricted this to pre-legalization, but for vectors
14030 // we are extra cautious to not create an unsupported operation.
14031 // Target-specific changes are likely needed to avoid regressions here.
14032 if (VT.isScalarInteger() || TLI.isOperationLegal(N0.getOpcode(), VT)) {
14033 SDLoc DL(N);
14034 SDValue NarrowL = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
14035 SDValue NarrowR = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
14036 return DAG.getNode(N0.getOpcode(), DL, VT, NarrowL, NarrowR);
14037 }
14038 }
14039 break;
14040 case ISD::ADDE:
14041 case ISD::ADDCARRY:
14042 // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry)
14043 // (trunc addcarry(X, Y, Carry)) -> (addcarry trunc(X), trunc(Y), Carry)
14044     // When the adde's carry is not used.
14045     // We only do this for addcarry before operation legalization.
14046 if (((!LegalOperations && N0.getOpcode() == ISD::ADDCARRY) ||
14047 TLI.isOperationLegal(N0.getOpcode(), VT)) &&
14048 N0.hasOneUse() && !N0->hasAnyUseOfValue(1)) {
14049 SDLoc DL(N);
14050 SDValue X = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
14051 SDValue Y = DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(1));
14052 SDVTList VTs = DAG.getVTList(VT, N0->getValueType(1));
14053 return DAG.getNode(N0.getOpcode(), DL, VTs, X, Y, N0.getOperand(2));
14054 }
14055 break;
14056 case ISD::USUBSAT:
14057     // Truncate the USUBSAT only if LHS is a known zero-extension; it's not
14058     // enough to know that the upper bits are zero, we must also ensure that
14059     // we don't introduce an extra truncate.
14060 if (!LegalOperations && N0.hasOneUse() &&
14061 N0.getOperand(0).getOpcode() == ISD::ZERO_EXTEND &&
14062 N0.getOperand(0).getOperand(0).getScalarValueSizeInBits() <=
14063 VT.getScalarSizeInBits() &&
14064 hasOperation(N0.getOpcode(), VT)) {
14065 return getTruncatedUSUBSAT(VT, SrcVT, N0.getOperand(0), N0.getOperand(1),
14066 DAG, SDLoc(N));
14067 }
14068 break;
14069 }
14070
14071 return SDValue();
14072 }
14073
14074 static SDNode *getBuildPairElt(SDNode *N, unsigned i) {
14075 SDValue Elt = N->getOperand(i);
14076 if (Elt.getOpcode() != ISD::MERGE_VALUES)
14077 return Elt.getNode();
14078 return Elt.getOperand(Elt.getResNo()).getNode();
14079 }
14080
14081 /// build_pair (load, load) -> load
14082 /// if load locations are consecutive.
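/// E.g. on a little-endian target, (i64 build_pair (i32 load p), (i32 load p+4))
/// may become (i64 load p) when the wide load is legal and fast.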
14083 SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) {
14084 assert(N->getOpcode() == ISD::BUILD_PAIR);
14085
14086 auto *LD1 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 0));
14087 auto *LD2 = dyn_cast<LoadSDNode>(getBuildPairElt(N, 1));
14088
14089   // A BUILD_PAIR always has the least significant part in elt 0 and the most
14090   // significant part in elt 1, so when combining into one large load we need
14091   // to consider the endianness.
14092 if (DAG.getDataLayout().isBigEndian())
14093 std::swap(LD1, LD2);
14094
14095 if (!LD1 || !LD2 || !ISD::isNON_EXTLoad(LD1) || !ISD::isNON_EXTLoad(LD2) ||
14096 !LD1->hasOneUse() || !LD2->hasOneUse() ||
14097 LD1->getAddressSpace() != LD2->getAddressSpace())
14098 return SDValue();
14099
14100 unsigned LD1Fast = 0;
14101 EVT LD1VT = LD1->getValueType(0);
14102 unsigned LD1Bytes = LD1VT.getStoreSize();
14103 if ((!LegalOperations || TLI.isOperationLegal(ISD::LOAD, VT)) &&
14104 DAG.areNonVolatileConsecutiveLoads(LD2, LD1, LD1Bytes, 1) &&
14105 TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
14106 *LD1->getMemOperand(), &LD1Fast) && LD1Fast)
14107 return DAG.getLoad(VT, SDLoc(N), LD1->getChain(), LD1->getBasePtr(),
14108 LD1->getPointerInfo(), LD1->getAlign());
14109
14110 return SDValue();
14111 }
14112
14113 static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) {
14114 // On little-endian machines, bitcasting from ppcf128 to i128 does swap the Hi
14115 // and Lo parts; on big-endian machines it doesn't.
14116 return DAG.getDataLayout().isBigEndian() ? 1 : 0;
14117 }
14118
14119 static SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
14120 const TargetLowering &TLI) {
14121 // If this is not a bitcast to an FP type or if the target doesn't have
14122 // IEEE754-compliant FP logic, we're done.
14123 EVT VT = N->getValueType(0);
14124 if (!VT.isFloatingPoint() || !TLI.hasBitPreservingFPLogic(VT))
14125 return SDValue();
14126
14127   // TODO: Handle cases where the integer constant has a different scalar
14128   // bitwidth than the FP type.
14129 SDValue N0 = N->getOperand(0);
14130 EVT SourceVT = N0.getValueType();
14131 if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits())
14132 return SDValue();
14133
14134 unsigned FPOpcode;
14135 APInt SignMask;
14136 switch (N0.getOpcode()) {
14137 case ISD::AND:
14138 FPOpcode = ISD::FABS;
14139 SignMask = ~APInt::getSignMask(SourceVT.getScalarSizeInBits());
14140 break;
14141 case ISD::XOR:
14142 FPOpcode = ISD::FNEG;
14143 SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
14144 break;
14145 case ISD::OR:
14146 FPOpcode = ISD::FABS;
14147 SignMask = APInt::getSignMask(SourceVT.getScalarSizeInBits());
14148 break;
14149 default:
14150 return SDValue();
14151 }
14152
14153 // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X
14154 // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X
14155 // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) ->
14156 // fneg (fabs X)
14157 SDValue LogicOp0 = N0.getOperand(0);
14158 ConstantSDNode *LogicOp1 = isConstOrConstSplat(N0.getOperand(1), true);
14159 if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask &&
14160 LogicOp0.getOpcode() == ISD::BITCAST &&
14161 LogicOp0.getOperand(0).getValueType() == VT) {
14162 SDValue FPOp = DAG.getNode(FPOpcode, SDLoc(N), VT, LogicOp0.getOperand(0));
14163 NumFPLogicOpsConv++;
14164 if (N0.getOpcode() == ISD::OR)
14165 return DAG.getNode(ISD::FNEG, SDLoc(N), VT, FPOp);
14166 return FPOp;
14167 }
14168
14169 return SDValue();
14170 }
14171
14172 SDValue DAGCombiner::visitBITCAST(SDNode *N) {
14173 SDValue N0 = N->getOperand(0);
14174 EVT VT = N->getValueType(0);
14175
14176 if (N0.isUndef())
14177 return DAG.getUNDEF(VT);
14178
14179 // If the input is a BUILD_VECTOR with all constant elements, fold this now.
14180 // Only do this before legalize types, unless both types are integer and the
14181 // scalar type is legal. Only do this before legalize ops, since the target
14182 // may depend on the bitcast.
14183 // First check to see if this is all constant.
14184 // TODO: Support FP bitcasts after legalize types.
14185 if (VT.isVector() &&
14186 (!LegalTypes ||
14187 (!LegalOperations && VT.isInteger() && N0.getValueType().isInteger() &&
14188 TLI.isTypeLegal(VT.getVectorElementType()))) &&
14189 N0.getOpcode() == ISD::BUILD_VECTOR && N0->hasOneUse() &&
14190 cast<BuildVectorSDNode>(N0)->isConstant())
14191 return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(),
14192 VT.getVectorElementType());
14193
14194 // If the input is a constant, let getNode fold it.
14195 if (isIntOrFPConstant(N0)) {
14196 // If we can't allow illegal operations, we need to check that this is just
14197 // an fp -> int or int -> fp conversion and that the resulting operation will
14198 // be legal.
14199 if (!LegalOperations ||
14200 (isa<ConstantSDNode>(N0) && VT.isFloatingPoint() && !VT.isVector() &&
14201 TLI.isOperationLegal(ISD::ConstantFP, VT)) ||
14202 (isa<ConstantFPSDNode>(N0) && VT.isInteger() && !VT.isVector() &&
14203 TLI.isOperationLegal(ISD::Constant, VT))) {
14204 SDValue C = DAG.getBitcast(VT, N0);
14205 if (C.getNode() != N)
14206 return C;
14207 }
14208 }
14209
14210 // (conv (conv x, t1), t2) -> (conv x, t2)
14211 if (N0.getOpcode() == ISD::BITCAST)
14212 return DAG.getBitcast(VT, N0.getOperand(0));
14213
14214 // fold (conv (load x)) -> (load (conv*)x)
14215 // If the resultant load doesn't need a higher alignment than the original!
14216 if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
14217 // Do not remove the cast if the types differ in endian layout.
14218 TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) ==
14219 TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()) &&
14220 // If the load is volatile, we only want to change the load type if the
14221 // resulting load is legal. Otherwise we might increase the number of
14222 // memory accesses. We don't care if the original type was legal or not
14223 // as we assume software couldn't rely on the number of accesses of an
14224 // illegal type.
14225 ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) ||
14226 TLI.isOperationLegal(ISD::LOAD, VT))) {
14227 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
14228
14229 if (TLI.isLoadBitCastBeneficial(N0.getValueType(), VT, DAG,
14230 *LN0->getMemOperand())) {
14231 SDValue Load =
14232 DAG.getLoad(VT, SDLoc(N), LN0->getChain(), LN0->getBasePtr(),
14233 LN0->getPointerInfo(), LN0->getAlign(),
14234 LN0->getMemOperand()->getFlags(), LN0->getAAInfo());
14235 DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
14236 return Load;
14237 }
14238 }
14239
14240 if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
14241 return V;
14242
14243 // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
14244 // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
14245 //
14246 // For ppc_fp128:
14247 // fold (bitcast (fneg x)) ->
14248 // flipbit = signbit
14249 // (xor (bitcast x) (build_pair flipbit, flipbit))
14250 //
14251 // fold (bitcast (fabs x)) ->
14252 // flipbit = (and (extract_element (bitcast x), 0), signbit)
14253 // (xor (bitcast x) (build_pair flipbit, flipbit))
14254 // This often reduces constant pool loads.
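// Illustrative: on targets that lower fneg/fabs via a constant-pool sign
// mask, (bitcast (fneg x:f64) to i64) becomes an integer XOR with the
// immediate 0x8000000000000000, so no FP constant needs to be loaded.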
14255 if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(N0.getValueType())) ||
14256 (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(N0.getValueType()))) &&
14257 N0->hasOneUse() && VT.isInteger() && !VT.isVector() &&
14258 !N0.getValueType().isVector()) {
14259 SDValue NewConv = DAG.getBitcast(VT, N0.getOperand(0));
14260 AddToWorklist(NewConv.getNode());
14261
14262 SDLoc DL(N);
14263 if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
14264 assert(VT.getSizeInBits() == 128);
14265 SDValue SignBit = DAG.getConstant(
14266 APInt::getSignMask(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64);
14267 SDValue FlipBit;
14268 if (N0.getOpcode() == ISD::FNEG) {
14269 FlipBit = SignBit;
14270 AddToWorklist(FlipBit.getNode());
14271 } else {
14272 assert(N0.getOpcode() == ISD::FABS);
14273 SDValue Hi =
14274 DAG.getNode(ISD::EXTRACT_ELEMENT, SDLoc(NewConv), MVT::i64, NewConv,
14275 DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
14276 SDLoc(NewConv)));
14277 AddToWorklist(Hi.getNode());
14278 FlipBit = DAG.getNode(ISD::AND, SDLoc(N0), MVT::i64, Hi, SignBit);
14279 AddToWorklist(FlipBit.getNode());
14280 }
14281 SDValue FlipBits =
14282 DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
14283 AddToWorklist(FlipBits.getNode());
14284 return DAG.getNode(ISD::XOR, DL, VT, NewConv, FlipBits);
14285 }
14286 APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
14287 if (N0.getOpcode() == ISD::FNEG)
14288 return DAG.getNode(ISD::XOR, DL, VT,
14289 NewConv, DAG.getConstant(SignBit, DL, VT));
14290 assert(N0.getOpcode() == ISD::FABS);
14291 return DAG.getNode(ISD::AND, DL, VT,
14292 NewConv, DAG.getConstant(~SignBit, DL, VT));
14293 }
14294
14295 // fold (bitconvert (fcopysign cst, x)) ->
14296 // (or (and (bitconvert x), sign), (and cst, (not sign)))
14297 // Note that we don't handle (copysign x, cst) because this can always be
14298 // folded to an fneg or fabs.
14299 //
14300 // For ppc_fp128:
14301 // fold (bitcast (fcopysign cst, x)) ->
14302 // flipbit = (and (extract_element
14303 // (xor (bitcast cst), (bitcast x)), 0),
14304 // signbit)
14305 // (xor (bitcast cst) (build_pair flipbit, flipbit))
14306 if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
14307 isa<ConstantFPSDNode>(N0.getOperand(0)) && VT.isInteger() &&
14308 !VT.isVector()) {
14309 unsigned OrigXWidth = N0.getOperand(1).getValueSizeInBits();
14310 EVT IntXVT = EVT::getIntegerVT(*DAG.getContext(), OrigXWidth);
14311 if (isTypeLegal(IntXVT)) {
14312 SDValue X = DAG.getBitcast(IntXVT, N0.getOperand(1));
14313 AddToWorklist(X.getNode());
14314
14315 // If X has a different width than the result/lhs, sext it or truncate it.
14316 unsigned VTWidth = VT.getSizeInBits();
14317 if (OrigXWidth < VTWidth) {
14318 X = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT, X);
14319 AddToWorklist(X.getNode());
14320 } else if (OrigXWidth > VTWidth) {
14321 // To get the sign bit in the right place, we have to shift it right
14322 // before truncating.
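// Illustrative: copying the sign of an f64 (64-bit X) into an f32-sized
// result shifts right by 64 - 32 = 32 so the sign bit survives the
// truncate to i32.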
14323 SDLoc DL(X);
14324 X = DAG.getNode(ISD::SRL, DL,
14325 X.getValueType(), X,
14326 DAG.getConstant(OrigXWidth-VTWidth, DL,
14327 X.getValueType()));
14328 AddToWorklist(X.getNode());
14329 X = DAG.getNode(ISD::TRUNCATE, SDLoc(X), VT, X);
14330 AddToWorklist(X.getNode());
14331 }
14332
14333 if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
14334 APInt SignBit = APInt::getSignMask(VT.getSizeInBits() / 2);
14335 SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
14336 AddToWorklist(Cst.getNode());
14337 SDValue X = DAG.getBitcast(VT, N0.getOperand(1));
14338 AddToWorklist(X.getNode());
14339 SDValue XorResult = DAG.getNode(ISD::XOR, SDLoc(N0), VT, Cst, X);
14340 AddToWorklist(XorResult.getNode());
14341 SDValue XorResult64 = DAG.getNode(
14342 ISD::EXTRACT_ELEMENT, SDLoc(XorResult), MVT::i64, XorResult,
14343 DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
14344 SDLoc(XorResult)));
14345 AddToWorklist(XorResult64.getNode());
14346 SDValue FlipBit =
14347 DAG.getNode(ISD::AND, SDLoc(XorResult64), MVT::i64, XorResult64,
14348 DAG.getConstant(SignBit, SDLoc(XorResult64), MVT::i64));
14349 AddToWorklist(FlipBit.getNode());
14350 SDValue FlipBits =
14351 DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit);
14352 AddToWorklist(FlipBits.getNode());
14353 return DAG.getNode(ISD::XOR, SDLoc(N), VT, Cst, FlipBits);
14354 }
14355 APInt SignBit = APInt::getSignMask(VT.getSizeInBits());
14356 X = DAG.getNode(ISD::AND, SDLoc(X), VT,
14357 X, DAG.getConstant(SignBit, SDLoc(X), VT));
14358 AddToWorklist(X.getNode());
14359
14360 SDValue Cst = DAG.getBitcast(VT, N0.getOperand(0));
14361 Cst = DAG.getNode(ISD::AND, SDLoc(Cst), VT,
14362 Cst, DAG.getConstant(~SignBit, SDLoc(Cst), VT));
14363 AddToWorklist(Cst.getNode());
14364
14365 return DAG.getNode(ISD::OR, SDLoc(N), VT, X, Cst);
14366 }
14367 }
14368
14369 // bitconvert(build_pair(ld, ld)) -> ld iff load locations are consecutive.
14370 if (N0.getOpcode() == ISD::BUILD_PAIR)
14371 if (SDValue CombineLD = CombineConsecutiveLoads(N0.getNode(), VT))
14372 return CombineLD;
14373
14374 // Remove double bitcasts from shuffles - this is often a legacy of
14375 // XformToShuffleWithZero being used to combine bitmaskings (of
14376 // float vectors bitcast to integer vectors) into shuffles.
14377 // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1)
14378 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() &&
14379 N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() &&
14380 VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() &&
14381 !(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) {
14382 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N0);
14383
14384 // If operands are a bitcast, peek through if it casts the original VT.
14385 // If operands are a constant, just bitcast back to original VT.
14386 auto PeekThroughBitcast = [&](SDValue Op) {
14387 if (Op.getOpcode() == ISD::BITCAST &&
14388 Op.getOperand(0).getValueType() == VT)
14389 return SDValue(Op.getOperand(0));
14390 if (Op.isUndef() || isAnyConstantBuildVector(Op))
14391 return DAG.getBitcast(VT, Op);
14392 return SDValue();
14393 };
14394
14395 // FIXME: If either input vector is bitcast, try to convert the shuffle to
14396 // the result type of this bitcast. This would eliminate at least one
14397 // bitcast. See the transform in InstCombine.
14398 SDValue SV0 = PeekThroughBitcast(N0->getOperand(0));
14399 SDValue SV1 = PeekThroughBitcast(N0->getOperand(1));
14400 if (!(SV0 && SV1))
14401 return SDValue();
14402
14403 int MaskScale =
14404 VT.getVectorNumElements() / N0.getValueType().getVectorNumElements();
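// Illustrative: bitcasting a v2i64 shuffle with mask <1,0> to v4i32 gives
// MaskScale = 2 and NewMask <2,3,0,1>.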
14405 SmallVector<int, 8> NewMask;
14406 for (int M : SVN->getMask())
14407 for (int i = 0; i != MaskScale; ++i)
14408 NewMask.push_back(M < 0 ? -1 : M * MaskScale + i);
14409
14410 SDValue LegalShuffle =
14411 TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, NewMask, DAG);
14412 if (LegalShuffle)
14413 return LegalShuffle;
14414 }
14415
14416 return SDValue();
14417 }
14418
14419 SDValue DAGCombiner::visitBUILD_PAIR(SDNode *N) {
14420 EVT VT = N->getValueType(0);
14421 return CombineConsecutiveLoads(N, VT);
14422 }
14423
14424 SDValue DAGCombiner::visitFREEZE(SDNode *N) {
14425 SDValue N0 = N->getOperand(0);
14426
14427 if (DAG.isGuaranteedNotToBeUndefOrPoison(N0, /*PoisonOnly*/ false))
14428 return N0;
14429
14430 // Fold freeze(op(x, ...)) -> op(freeze(x), ...).
14431 // Try to push freeze through instructions that propagate but don't produce
14432 // poison as far as possible. If an operand of the freeze 1) has one use,
14433 // 2) does not produce poison, and 3) has all but one operand guaranteed
14434 // non-poison (or is a BUILD_VECTOR or similar), then push the freeze through
14435 // to the operands that are not guaranteed non-poison.
14436 // NOTE: we will strip poison-generating flags, so ignore them here.
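// Illustrative: freeze (add x, y) with y known non-poison becomes
// (add (freeze x), y), with the add's poison-generating flags (e.g.
// nuw/nsw) dropped.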
14437 if (DAG.canCreateUndefOrPoison(N0, /*PoisonOnly*/ false,
14438 /*ConsiderFlags*/ false) ||
14439 N0->getNumValues() != 1 || !N0->hasOneUse())
14440 return SDValue();
14441
14442 bool AllowMultipleMaybePoisonOperands = N0.getOpcode() == ISD::BUILD_VECTOR;
14443
14444 SmallSetVector<SDValue, 8> MaybePoisonOperands;
14445 for (SDValue Op : N0->ops()) {
14446 if (DAG.isGuaranteedNotToBeUndefOrPoison(Op, /*PoisonOnly*/ false,
14447 /*Depth*/ 1))
14448 continue;
14449 bool HadMaybePoisonOperands = !MaybePoisonOperands.empty();
14450 bool IsNewMaybePoisonOperand = MaybePoisonOperands.insert(Op);
14451 if (!HadMaybePoisonOperands)
14452 continue;
14453 if (IsNewMaybePoisonOperand && !AllowMultipleMaybePoisonOperands) {
14454 // Multiple maybe-poison ops when not allowed - bail out.
14455 return SDValue();
14456 }
14457 }
14458 // NOTE: the whole op may not be guaranteed not to be undef or poison because
14459 // it could create undef or poison due to its poison-generating flags.
14460 // So not finding any maybe-poison operands is fine.
14461
14462 for (SDValue MaybePoisonOperand : MaybePoisonOperands) {
14463 // Don't replace every single UNDEF everywhere with frozen UNDEF, though.
14464 if (MaybePoisonOperand.getOpcode() == ISD::UNDEF)
14465 continue;
14466 // First, freeze each offending operand.
14467 SDValue FrozenMaybePoisonOperand = DAG.getFreeze(MaybePoisonOperand);
14468 // Then, change all other uses of unfrozen operand to use frozen operand.
14469 DAG.ReplaceAllUsesOfValueWith(MaybePoisonOperand, FrozenMaybePoisonOperand);
14470 if (FrozenMaybePoisonOperand.getOpcode() == ISD::FREEZE &&
14471 FrozenMaybePoisonOperand.getOperand(0) == FrozenMaybePoisonOperand) {
14472 // But, that also updated the use in the freeze we just created, thus
14473 // creating a cycle in a DAG. Let's undo that by mutating the freeze.
14474 DAG.UpdateNodeOperands(FrozenMaybePoisonOperand.getNode(),
14475 MaybePoisonOperand);
14476 }
14477 }
14478
14479 // The whole node may have been updated, so the value we were holding
14480 // may no longer be valid. Re-fetch the operand we're `freeze`ing.
14481 N0 = N->getOperand(0);
14482
14483 // Finally, recreate the node; its operands were updated to use
14484 // frozen operands, so we just need to use its "original" operands.
14485 SmallVector<SDValue> Ops(N0->op_begin(), N0->op_end());
14486 // Special-handle ISD::UNDEF, each single one of them can be its own thing.
14487 for (SDValue &Op : Ops) {
14488 if (Op.getOpcode() == ISD::UNDEF)
14489 Op = DAG.getFreeze(Op);
14490 }
14491 // NOTE: this strips poison generating flags.
14492 SDValue R = DAG.getNode(N0.getOpcode(), SDLoc(N0), N0->getVTList(), Ops);
14493 assert(DAG.isGuaranteedNotToBeUndefOrPoison(R, /*PoisonOnly*/ false) &&
14494 "Can't create node that may be undef/poison!");
14495 return R;
14496 }
14497
14498 /// We know that BV is a build_vector node with Constant, ConstantFP or Undef
14499 /// operands. DstEltVT indicates the destination element value type.
14500 SDValue DAGCombiner::
14501 ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
14502 EVT SrcEltVT = BV->getValueType(0).getVectorElementType();
14503
14504 // If this is already the right type, we're done.
14505 if (SrcEltVT == DstEltVT) return SDValue(BV, 0);
14506
14507 unsigned SrcBitSize = SrcEltVT.getSizeInBits();
14508 unsigned DstBitSize = DstEltVT.getSizeInBits();
14509
14510 // If this is a conversion of N elements of one type to N elements of another
14511 // type, convert each element. This handles FP<->INT cases.
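// Illustrative: a constant v4f32 BUILD_VECTOR bitcast to v4i32 simply
// bitcasts each f32 element to an i32 element.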
14512 if (SrcBitSize == DstBitSize) {
14513 SmallVector<SDValue, 8> Ops;
14514 for (SDValue Op : BV->op_values()) {
14515 // If the vector element type is not legal, the BUILD_VECTOR operands
14516 // are promoted and implicitly truncated. Make that explicit here.
14517 if (Op.getValueType() != SrcEltVT)
14518 Op = DAG.getNode(ISD::TRUNCATE, SDLoc(BV), SrcEltVT, Op);
14519 Ops.push_back(DAG.getBitcast(DstEltVT, Op));
14520 AddToWorklist(Ops.back().getNode());
14521 }
14522 EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT,
14523 BV->getValueType(0).getVectorNumElements());
14524 return DAG.getBuildVector(VT, SDLoc(BV), Ops);
14525 }
14526
14527 // Otherwise, we're growing or shrinking the elements. To avoid having to
14528 // handle annoying details of growing/shrinking FP values, we convert them to
14529 // int first.
14530 if (SrcEltVT.isFloatingPoint()) {
14531 // Convert the input float vector to an int vector where the elements are
14532 // the same size.
14533 EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), SrcEltVT.getSizeInBits());
14534 BV = ConstantFoldBITCASTofBUILD_VECTOR(BV, IntVT).getNode();
14535 SrcEltVT = IntVT;
14536 }
14537
14538 // Now we know the input is an integer vector. If the output is an FP type,
14539 // convert to integer first, then to FP of the right size.
14540 if (DstEltVT.isFloatingPoint()) {
14541 EVT TmpVT = EVT::getIntegerVT(*DAG.getContext(), DstEltVT.getSizeInBits());
14542 SDNode *Tmp = ConstantFoldBITCASTofBUILD_VECTOR(BV, TmpVT).getNode();
14543
14544 // Next, convert to FP elements of the same size.
14545 return ConstantFoldBITCASTofBUILD_VECTOR(Tmp, DstEltVT);
14546 }
14547
14548 // Okay, we know the src/dst types are both integers of differing types.
14549 assert(SrcEltVT.isInteger() && DstEltVT.isInteger());
14550
14551 // TODO: Should ConstantFoldBITCASTofBUILD_VECTOR always take a
14552 // BuildVectorSDNode?
14553 auto *BVN = cast<BuildVectorSDNode>(BV);
14554
14555 // Extract the constant raw bit data.
14556 BitVector UndefElements;
14557 SmallVector<APInt> RawBits;
14558 bool IsLE = DAG.getDataLayout().isLittleEndian();
14559 if (!BVN->getConstantRawBits(IsLE, DstBitSize, RawBits, UndefElements))
14560 return SDValue();
14561
14562 SDLoc DL(BV);
14563 SmallVector<SDValue, 8> Ops;
14564 for (unsigned I = 0, E = RawBits.size(); I != E; ++I) {
14565 if (UndefElements[I])
14566 Ops.push_back(DAG.getUNDEF(DstEltVT));
14567 else
14568 Ops.push_back(DAG.getConstant(RawBits[I], DL, DstEltVT));
14569 }
14570
14571 EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, Ops.size());
14572 return DAG.getBuildVector(VT, DL, Ops);
14573 }
14574
14575 // Returns true if floating-point contraction is allowed on the FMUL-SDValue
14576 // `N`.
14577 static bool isContractableFMUL(const TargetOptions &Options, SDValue N) {
14578 assert(N.getOpcode() == ISD::FMUL);
14579
14580 return Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath ||
14581 N->getFlags().hasAllowContract();
14582 }
14583
14584 // Returns true if `N` may assume that no infinities are involved in its
// computation.
14585 static bool hasNoInfs(const TargetOptions &Options, SDValue N) {
14586 return Options.NoInfsFPMath || N->getFlags().hasNoInfs();
14587 }
14588
14589 /// Try to perform FMA combining on a given FADD node.
14590 SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
14591 SDValue N0 = N->getOperand(0);
14592 SDValue N1 = N->getOperand(1);
14593 EVT VT = N->getValueType(0);
14594 SDLoc SL(N);
14595
14596 const TargetOptions &Options = DAG.getTarget().Options;
14597
14598 // Floating-point multiply-add with intermediate rounding.
14599 bool HasFMAD = (LegalOperations && TLI.isFMADLegal(DAG, N));
14600
14601 // Floating-point multiply-add without intermediate rounding.
14602 bool HasFMA =
14603 TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
14604 (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
14605
14606 // No valid opcode, do not combine.
14607 if (!HasFMAD && !HasFMA)
14608 return SDValue();
14609
14610 bool CanReassociate =
14611 Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
14612 bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
14613 Options.UnsafeFPMath || HasFMAD);
14614 // If the addition is not contractable, do not combine.
14615 if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
14616 return SDValue();
14617
14618 if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
14619 return SDValue();
14620
14621 // Always prefer FMAD to FMA for precision.
14622 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
14623 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
14624
14625 auto isFusedOp = [&](SDValue N) {
14626 unsigned Opcode = N.getOpcode();
14627 return Opcode == ISD::FMA || Opcode == ISD::FMAD;
14628 };
14629
14630 // Is the node an FMUL and contractable either due to global flags or
14631 // SDNodeFlags.
14632 auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
14633 if (N.getOpcode() != ISD::FMUL)
14634 return false;
14635 return AllowFusionGlobally || N->getFlags().hasAllowContract();
14636 };
14637 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
14638 // prefer to fold the multiply with fewer uses.
14639 if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
14640 if (N0->use_size() > N1->use_size())
14641 std::swap(N0, N1);
14642 }
14643
14644 // fold (fadd (fmul x, y), z) -> (fma x, y, z)
14645 if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
14646 return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
14647 N0.getOperand(1), N1);
14648 }
14649
14650 // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
14651 // Note: Commutes FADD operands.
14652 if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
14653 return DAG.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0),
14654 N1.getOperand(1), N0);
14655 }
14656
14657 // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
14658 // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
14659 // This also works with nested fma instructions:
14660 // fadd (fma A, B, (fma C, D, (fmul E, F))), G -->
14661 // fma A, B, (fma C, D, (fma E, F, G))
14662 // fadd G, (fma A, B, (fma C, D, (fmul E, F))) -->
14663 // fma A, B, (fma C, D, (fma E, F, G)).
14664 // This requires reassociation because it changes the order of operations.
14665 if (CanReassociate) {
14666 SDValue FMA, E;
14667 if (isFusedOp(N0) && N0.hasOneUse()) {
14668 FMA = N0;
14669 E = N1;
14670 } else if (isFusedOp(N1) && N1.hasOneUse()) {
14671 FMA = N1;
14672 E = N0;
14673 }
14674
14675 SDValue TmpFMA = FMA;
14676 while (E && isFusedOp(TmpFMA) && TmpFMA.hasOneUse()) {
14677 SDValue FMul = TmpFMA->getOperand(2);
14678 if (FMul.getOpcode() == ISD::FMUL && FMul.hasOneUse()) {
14679 SDValue C = FMul.getOperand(0);
14680 SDValue D = FMul.getOperand(1);
14681 SDValue CDE = DAG.getNode(PreferredFusedOpcode, SL, VT, C, D, E);
14682 DAG.ReplaceAllUsesOfValueWith(FMul, CDE);
14683 // Replacing the inner FMul could cause the outer FMA to be simplified
14684 // away.
14685 return FMA.getOpcode() == ISD::DELETED_NODE ? SDValue() : FMA;
14686 }
14687
14688 TmpFMA = TmpFMA->getOperand(2);
14689 }
14690 }
14691
14692 // Look through FP_EXTEND nodes to do more combining.
14693
14694 // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
14695 if (N0.getOpcode() == ISD::FP_EXTEND) {
14696 SDValue N00 = N0.getOperand(0);
14697 if (isContractableFMUL(N00) &&
14698 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14699 N00.getValueType())) {
14700 return DAG.getNode(PreferredFusedOpcode, SL, VT,
14701 DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
14702 DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
14703 N1);
14704 }
14705 }
14706
14707 // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
14708 // Note: Commutes FADD operands.
14709 if (N1.getOpcode() == ISD::FP_EXTEND) {
14710 SDValue N10 = N1.getOperand(0);
14711 if (isContractableFMUL(N10) &&
14712 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14713 N10.getValueType())) {
14714 return DAG.getNode(PreferredFusedOpcode, SL, VT,
14715 DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)),
14716 DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)),
14717 N0);
14718 }
14719 }
14720
14721 // More folding opportunities when target permits.
14722 if (Aggressive) {
14723 // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
14724 // -> (fma x, y, (fma (fpext u), (fpext v), z))
14725 auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
14726 SDValue Z) {
14727 return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y,
14728 DAG.getNode(PreferredFusedOpcode, SL, VT,
14729 DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
14730 DAG.getNode(ISD::FP_EXTEND, SL, VT, V),
14731 Z));
14732 };
14733 if (isFusedOp(N0)) {
14734 SDValue N02 = N0.getOperand(2);
14735 if (N02.getOpcode() == ISD::FP_EXTEND) {
14736 SDValue N020 = N02.getOperand(0);
14737 if (isContractableFMUL(N020) &&
14738 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14739 N020.getValueType())) {
14740 return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1),
14741 N020.getOperand(0), N020.getOperand(1),
14742 N1);
14743 }
14744 }
14745 }
14746
14747 // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
14748 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
14749 // FIXME: This turns two single-precision and one double-precision
14750 // operation into two double-precision operations, which might not be
14751 // interesting for all targets, especially GPUs.
14752 auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
14753 SDValue Z) {
14754 return DAG.getNode(
14755 PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FP_EXTEND, SL, VT, X),
14756 DAG.getNode(ISD::FP_EXTEND, SL, VT, Y),
14757 DAG.getNode(PreferredFusedOpcode, SL, VT,
14758 DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
14759 DAG.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
14760 };
14761 if (N0.getOpcode() == ISD::FP_EXTEND) {
14762 SDValue N00 = N0.getOperand(0);
14763 if (isFusedOp(N00)) {
14764 SDValue N002 = N00.getOperand(2);
14765 if (isContractableFMUL(N002) &&
14766 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14767 N00.getValueType())) {
14768 return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1),
14769 N002.getOperand(0), N002.getOperand(1),
14770 N1);
14771 }
14772 }
14773 }
14774
14775 // fold (fadd x, (fma y, z, (fpext (fmul u, v))))
14776 // -> (fma y, z, (fma (fpext u), (fpext v), x))
14777 if (isFusedOp(N1)) {
14778 SDValue N12 = N1.getOperand(2);
14779 if (N12.getOpcode() == ISD::FP_EXTEND) {
14780 SDValue N120 = N12.getOperand(0);
14781 if (isContractableFMUL(N120) &&
14782 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14783 N120.getValueType())) {
14784 return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1),
14785 N120.getOperand(0), N120.getOperand(1),
14786 N0);
14787 }
14788 }
14789 }
14790
14791 // fold (fadd x, (fpext (fma y, z, (fmul u, v))))
14792 // -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
14793 // FIXME: This turns two single-precision and one double-precision
14794 // operation into two double-precision operations, which might not be
14795 // interesting for all targets, especially GPUs.
14796 if (N1.getOpcode() == ISD::FP_EXTEND) {
14797 SDValue N10 = N1.getOperand(0);
14798 if (isFusedOp(N10)) {
14799 SDValue N102 = N10.getOperand(2);
14800 if (isContractableFMUL(N102) &&
14801 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14802 N10.getValueType())) {
14803 return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1),
14804 N102.getOperand(0), N102.getOperand(1),
14805 N0);
14806 }
14807 }
14808 }
14809 }
14810
14811 return SDValue();
14812 }
14813
14814 /// Try to perform FMA combining on a given FSUB node.
14815 SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
14816 SDValue N0 = N->getOperand(0);
14817 SDValue N1 = N->getOperand(1);
14818 EVT VT = N->getValueType(0);
14819 SDLoc SL(N);
14820
14821 const TargetOptions &Options = DAG.getTarget().Options;
14822 // Floating-point multiply-add with intermediate rounding.
14823 bool HasFMAD = (LegalOperations && TLI.isFMADLegal(DAG, N));
14824
14825 // Floating-point multiply-add without intermediate rounding.
14826 bool HasFMA =
14827 TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
14828 (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
14829
14830 // No valid opcode, do not combine.
14831 if (!HasFMAD && !HasFMA)
14832 return SDValue();
14833
14834 const SDNodeFlags Flags = N->getFlags();
14835 bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
14836 Options.UnsafeFPMath || HasFMAD);
14837
14838 // If the subtraction is not contractable, do not combine.
14839 if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
14840 return SDValue();
14841
14842 if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
14843 return SDValue();
14844
14845 // Always prefer FMAD to FMA for precision.
14846 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
14847 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
14848 bool NoSignedZero = Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros();
14849
14850 // Is the node an FMUL and contractable either due to global flags or
14851 // SDNodeFlags.
14852 auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
14853 if (N.getOpcode() != ISD::FMUL)
14854 return false;
14855 return AllowFusionGlobally || N->getFlags().hasAllowContract();
14856 };
14857
14858 // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
14859 auto tryToFoldXYSubZ = [&](SDValue XY, SDValue Z) {
14860 if (isContractableFMUL(XY) && (Aggressive || XY->hasOneUse())) {
14861 return DAG.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(0),
14862 XY.getOperand(1), DAG.getNode(ISD::FNEG, SL, VT, Z));
14863 }
14864 return SDValue();
14865 };
14866
14867 // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
14868 // Note: Commutes FSUB operands.
14869 auto tryToFoldXSubYZ = [&](SDValue X, SDValue YZ) {
14870 if (isContractableFMUL(YZ) && (Aggressive || YZ->hasOneUse())) {
14871 return DAG.getNode(PreferredFusedOpcode, SL, VT,
14872 DAG.getNode(ISD::FNEG, SL, VT, YZ.getOperand(0)),
14873 YZ.getOperand(1), X);
14874 }
14875 return SDValue();
14876 };
14877
14878 // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
14879 // prefer to fold the multiply with fewer uses.
14880 if (isContractableFMUL(N0) && isContractableFMUL(N1) &&
14881 (N0->use_size() > N1->use_size())) {
14882 // fold (fsub (fmul a, b), (fmul c, d)) -> (fma (fneg c), d, (fmul a, b))
14883 if (SDValue V = tryToFoldXSubYZ(N0, N1))
14884 return V;
14885 // fold (fsub (fmul a, b), (fmul c, d)) -> (fma a, b, (fneg (fmul c, d)))
14886 if (SDValue V = tryToFoldXYSubZ(N0, N1))
14887 return V;
14888 } else {
14889 // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
14890 if (SDValue V = tryToFoldXYSubZ(N0, N1))
14891 return V;
14892 // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
14893 if (SDValue V = tryToFoldXSubYZ(N0, N1))
14894 return V;
14895 }
14896
14897 // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
14898 if (N0.getOpcode() == ISD::FNEG && isContractableFMUL(N0.getOperand(0)) &&
14899 (Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) {
14900 SDValue N00 = N0.getOperand(0).getOperand(0);
14901 SDValue N01 = N0.getOperand(0).getOperand(1);
14902 return DAG.getNode(PreferredFusedOpcode, SL, VT,
14903 DAG.getNode(ISD::FNEG, SL, VT, N00), N01,
14904 DAG.getNode(ISD::FNEG, SL, VT, N1));
14905 }
14906
14907 // Look through FP_EXTEND nodes to do more combining.
14908
14909 // fold (fsub (fpext (fmul x, y)), z)
14910 // -> (fma (fpext x), (fpext y), (fneg z))
14911 if (N0.getOpcode() == ISD::FP_EXTEND) {
14912 SDValue N00 = N0.getOperand(0);
14913 if (isContractableFMUL(N00) &&
14914 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14915 N00.getValueType())) {
14916 return DAG.getNode(PreferredFusedOpcode, SL, VT,
14917 DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
14918 DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
14919 DAG.getNode(ISD::FNEG, SL, VT, N1));
14920 }
14921 }
14922
14923 // fold (fsub x, (fpext (fmul y, z)))
14924 // -> (fma (fneg (fpext y)), (fpext z), x)
14925 // Note: Commutes FSUB operands.
14926 if (N1.getOpcode() == ISD::FP_EXTEND) {
14927 SDValue N10 = N1.getOperand(0);
14928 if (isContractableFMUL(N10) &&
14929 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14930 N10.getValueType())) {
14931 return DAG.getNode(
14932 PreferredFusedOpcode, SL, VT,
14933 DAG.getNode(ISD::FNEG, SL, VT,
14934 DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0))),
14935 DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0);
14936 }
14937 }
14938
14939 // fold (fsub (fpext (fneg (fmul x, y))), z)
14940 // -> (fneg (fma (fpext x), (fpext y), z))
14941 // Note: This could be removed with appropriate canonicalization of the
14942 // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
14943 // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent us
14944 // from implementing the canonicalization in visitFSUB.
14945 if (N0.getOpcode() == ISD::FP_EXTEND) {
14946 SDValue N00 = N0.getOperand(0);
14947 if (N00.getOpcode() == ISD::FNEG) {
14948 SDValue N000 = N00.getOperand(0);
14949 if (isContractableFMUL(N000) &&
14950 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14951 N00.getValueType())) {
14952 return DAG.getNode(
14953 ISD::FNEG, SL, VT,
14954 DAG.getNode(PreferredFusedOpcode, SL, VT,
14955 DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
14956 DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
14957 N1));
14958 }
14959 }
14960 }
14961
14962 // fold (fsub (fneg (fpext (fmul x, y))), z)
14963 // -> (fneg (fma (fpext x), (fpext y), z))
14964 // Note: This could be removed with appropriate canonicalization of the
14965 // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
14966 // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent us
14967 // from implementing the canonicalization in visitFSUB.
14968 if (N0.getOpcode() == ISD::FNEG) {
14969 SDValue N00 = N0.getOperand(0);
14970 if (N00.getOpcode() == ISD::FP_EXTEND) {
14971 SDValue N000 = N00.getOperand(0);
14972 if (isContractableFMUL(N000) &&
14973 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
14974 N000.getValueType())) {
14975 return DAG.getNode(
14976 ISD::FNEG, SL, VT,
14977 DAG.getNode(PreferredFusedOpcode, SL, VT,
14978 DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
14979 DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
14980 N1));
14981 }
14982 }
14983 }
14984
14985 auto isReassociable = [Options](SDNode *N) {
14986 return Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
14987 };
14988
14989 auto isContractableAndReassociableFMUL = [&isContractableFMUL,
14990 &isReassociable](SDValue N) {
14991 return isContractableFMUL(N) && isReassociable(N.getNode());
14992 };
14993
14994 auto isFusedOp = [&](SDValue N) {
14995 unsigned Opcode = N.getOpcode();
14996 return Opcode == ISD::FMA || Opcode == ISD::FMAD;
14997 };
14998
14999 // More folding opportunities when target permits.
15000 if (Aggressive && isReassociable(N)) {
15001 bool CanFuse = Options.UnsafeFPMath || N->getFlags().hasAllowContract();
15002 // fold (fsub (fma x, y, (fmul u, v)), z)
15003 // -> (fma x, y, (fma u, v, (fneg z)))
15004 if (CanFuse && isFusedOp(N0) &&
15005 isContractableAndReassociableFMUL(N0.getOperand(2)) &&
15006 N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) {
15007 return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
15008 N0.getOperand(1),
15009 DAG.getNode(PreferredFusedOpcode, SL, VT,
15010 N0.getOperand(2).getOperand(0),
15011 N0.getOperand(2).getOperand(1),
15012 DAG.getNode(ISD::FNEG, SL, VT, N1)));
15013 }
15014
15015 // fold (fsub x, (fma y, z, (fmul u, v)))
15016 // -> (fma (fneg y), z, (fma (fneg u), v, x))
15017 if (CanFuse && isFusedOp(N1) &&
15018 isContractableAndReassociableFMUL(N1.getOperand(2)) &&
15019 N1->hasOneUse() && NoSignedZero) {
15020 SDValue N20 = N1.getOperand(2).getOperand(0);
15021 SDValue N21 = N1.getOperand(2).getOperand(1);
15022 return DAG.getNode(
15023 PreferredFusedOpcode, SL, VT,
15024 DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1),
15025 DAG.getNode(PreferredFusedOpcode, SL, VT,
15026 DAG.getNode(ISD::FNEG, SL, VT, N20), N21, N0));
15027 }
15028
15029 // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
15030 // -> (fma x, y, (fma (fpext u), (fpext v), (fneg z)))
15031 if (isFusedOp(N0) && N0->hasOneUse()) {
15032 SDValue N02 = N0.getOperand(2);
15033 if (N02.getOpcode() == ISD::FP_EXTEND) {
15034 SDValue N020 = N02.getOperand(0);
15035 if (isContractableAndReassociableFMUL(N020) &&
15036 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
15037 N020.getValueType())) {
15038 return DAG.getNode(
15039 PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
15040 DAG.getNode(
15041 PreferredFusedOpcode, SL, VT,
15042 DAG.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(0)),
15043 DAG.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(1)),
15044 DAG.getNode(ISD::FNEG, SL, VT, N1)));
15045 }
15046 }
15047 }
15048
15049 // fold (fsub (fpext (fma x, y, (fmul u, v))), z)
15050 // -> (fma (fpext x), (fpext y),
15051 // (fma (fpext u), (fpext v), (fneg z)))
15052 // FIXME: This turns two single-precision and one double-precision
15053 // operation into two double-precision operations, which might not be
15054 // interesting for all targets, especially GPUs.
15055 if (N0.getOpcode() == ISD::FP_EXTEND) {
15056 SDValue N00 = N0.getOperand(0);
15057 if (isFusedOp(N00)) {
15058 SDValue N002 = N00.getOperand(2);
15059 if (isContractableAndReassociableFMUL(N002) &&
15060 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
15061 N00.getValueType())) {
15062 return DAG.getNode(
15063 PreferredFusedOpcode, SL, VT,
15064 DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
15065 DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
15066 DAG.getNode(
15067 PreferredFusedOpcode, SL, VT,
15068 DAG.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(0)),
15069 DAG.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(1)),
15070 DAG.getNode(ISD::FNEG, SL, VT, N1)));
15071 }
15072 }
15073 }
15074
15075 // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
15076 // -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
15077 if (isFusedOp(N1) && N1.getOperand(2).getOpcode() == ISD::FP_EXTEND &&
15078 N1->hasOneUse()) {
15079 SDValue N120 = N1.getOperand(2).getOperand(0);
15080 if (isContractableAndReassociableFMUL(N120) &&
15081 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
15082 N120.getValueType())) {
15083 SDValue N1200 = N120.getOperand(0);
15084 SDValue N1201 = N120.getOperand(1);
15085 return DAG.getNode(
15086 PreferredFusedOpcode, SL, VT,
15087 DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1),
15088 DAG.getNode(PreferredFusedOpcode, SL, VT,
15089 DAG.getNode(ISD::FNEG, SL, VT,
15090 DAG.getNode(ISD::FP_EXTEND, SL, VT, N1200)),
15091 DAG.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0));
15092 }
15093 }
15094
15095 // fold (fsub x, (fpext (fma y, z, (fmul u, v))))
15096 // -> (fma (fneg (fpext y)), (fpext z),
15097 // (fma (fneg (fpext u)), (fpext v), x))
15098 // FIXME: This turns two single-precision and one double-precision
15099 // operation into two double-precision operations, which might not be
15100 // interesting for all targets, especially GPUs.
15101 if (N1.getOpcode() == ISD::FP_EXTEND && isFusedOp(N1.getOperand(0))) {
15102 SDValue CvtSrc = N1.getOperand(0);
15103 SDValue N100 = CvtSrc.getOperand(0);
15104 SDValue N101 = CvtSrc.getOperand(1);
15105 SDValue N102 = CvtSrc.getOperand(2);
15106 if (isContractableAndReassociableFMUL(N102) &&
15107 TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
15108 CvtSrc.getValueType())) {
15109 SDValue N1020 = N102.getOperand(0);
15110 SDValue N1021 = N102.getOperand(1);
15111 return DAG.getNode(
15112 PreferredFusedOpcode, SL, VT,
15113 DAG.getNode(ISD::FNEG, SL, VT,
15114 DAG.getNode(ISD::FP_EXTEND, SL, VT, N100)),
15115 DAG.getNode(ISD::FP_EXTEND, SL, VT, N101),
15116 DAG.getNode(PreferredFusedOpcode, SL, VT,
15117 DAG.getNode(ISD::FNEG, SL, VT,
15118 DAG.getNode(ISD::FP_EXTEND, SL, VT, N1020)),
15119 DAG.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0));
15120 }
15121 }
15122 }
15123
15124 return SDValue();
15125 }
15126
15127 /// Try to perform FMA combining on a given FMUL node based on the distributive
15128 /// law x * (y + 1) = x * y + x and variants thereof (commuted versions,
15129 /// subtraction instead of addition).
15130 SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
15131 SDValue N0 = N->getOperand(0);
15132 SDValue N1 = N->getOperand(1);
15133 EVT VT = N->getValueType(0);
15134 SDLoc SL(N);
15135
15136 assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");
15137
15138 const TargetOptions &Options = DAG.getTarget().Options;
15139
15140 // The transforms below are incorrect when x == 0 and y == inf, because the
15141 // intermediate multiplication produces a nan.
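// Illustrative: with x == 0.0 and y == inf, (x + 1.0) * y == inf, but the
// fused form fma(x, y, y) first evaluates x * y == 0.0 * inf == NaN.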
15142 SDValue FAdd = N0.getOpcode() == ISD::FADD ? N0 : N1;
15143 if (!hasNoInfs(Options, FAdd))
15144 return SDValue();
15145
15146 // Floating-point multiply-add without intermediate rounding.
15147 bool HasFMA =
15148 isContractableFMUL(Options, SDValue(N, 0)) &&
15149 TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
15150 (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
15151
15152 // Floating-point multiply-add with intermediate rounding. This can result
15153 // in a less precise result due to the changed rounding order.
15154 bool HasFMAD = Options.UnsafeFPMath &&
15155 (LegalOperations && TLI.isFMADLegal(DAG, N));
15156
15157 // No valid opcode, do not combine.
15158 if (!HasFMAD && !HasFMA)
15159 return SDValue();
15160
15161 // Always prefer FMAD to FMA for precision.
15162 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
15163 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
15164
15165 // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y)
15166 // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y))
15167 auto FuseFADD = [&](SDValue X, SDValue Y) {
15168 if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
15169 if (auto *C = isConstOrConstSplatFP(X.getOperand(1), true)) {
15170 if (C->isExactlyValue(+1.0))
15171 return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
15172 Y);
15173 if (C->isExactlyValue(-1.0))
15174 return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
15175 DAG.getNode(ISD::FNEG, SL, VT, Y));
15176 }
15177 }
15178 return SDValue();
15179 };
15180
15181 if (SDValue FMA = FuseFADD(N0, N1))
15182 return FMA;
15183 if (SDValue FMA = FuseFADD(N1, N0))
15184 return FMA;
15185
15186 // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y)
15187 // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y))
15188 // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y))
15189 // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y)
15190 auto FuseFSUB = [&](SDValue X, SDValue Y) {
15191 if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
15192 if (auto *C0 = isConstOrConstSplatFP(X.getOperand(0), true)) {
15193 if (C0->isExactlyValue(+1.0))
15194 return DAG.getNode(PreferredFusedOpcode, SL, VT,
15195 DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
15196 Y);
15197 if (C0->isExactlyValue(-1.0))
15198 return DAG.getNode(PreferredFusedOpcode, SL, VT,
15199 DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y,
15200 DAG.getNode(ISD::FNEG, SL, VT, Y));
15201 }
15202 if (auto *C1 = isConstOrConstSplatFP(X.getOperand(1), true)) {
15203 if (C1->isExactlyValue(+1.0))
15204 return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
15205 DAG.getNode(ISD::FNEG, SL, VT, Y));
15206 if (C1->isExactlyValue(-1.0))
15207 return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y,
15208 Y);
15209 }
15210 }
15211 return SDValue();
15212 };
15213
15214 if (SDValue FMA = FuseFSUB(N0, N1))
15215 return FMA;
15216 if (SDValue FMA = FuseFSUB(N1, N0))
15217 return FMA;
15218
15219 return SDValue();
15220 }
15221
15222 SDValue DAGCombiner::visitFADD(SDNode *N) {
15223 SDValue N0 = N->getOperand(0);
15224 SDValue N1 = N->getOperand(1);
15225 SDNode *N0CFP = DAG.isConstantFPBuildVectorOrConstantFP(N0);
15226 SDNode *N1CFP = DAG.isConstantFPBuildVectorOrConstantFP(N1);
15227 EVT VT = N->getValueType(0);
15228 SDLoc DL(N);
15229 const TargetOptions &Options = DAG.getTarget().Options;
15230 SDNodeFlags Flags = N->getFlags();
15231 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
15232
15233 if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
15234 return R;
15235
15236 // fold (fadd c1, c2) -> c1 + c2
15237 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FADD, DL, VT, {N0, N1}))
15238 return C;
15239
15240 // canonicalize constant to RHS
15241 if (N0CFP && !N1CFP)
15242 return DAG.getNode(ISD::FADD, DL, VT, N1, N0);
15243
15244 // fold vector ops
15245 if (VT.isVector())
15246 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
15247 return FoldedVOp;
15248
15249 // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math)
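// (+0.0 is only an identity under nsz because -0.0 + +0.0 == +0.0.)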
15250 ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1, true);
15251 if (N1C && N1C->isZero())
15252 if (N1C->isNegative() || Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())
15253 return N0;
15254
15255 if (SDValue NewSel = foldBinOpIntoSelect(N))
15256 return NewSel;
15257
15258 // fold (fadd A, (fneg B)) -> (fsub A, B)
15259 if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
15260 if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
15261 N1, DAG, LegalOperations, ForCodeSize))
15262 return DAG.getNode(ISD::FSUB, DL, VT, N0, NegN1);
15263
15264 // fold (fadd (fneg A), B) -> (fsub B, A)
15265 if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT))
15266 if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
15267 N0, DAG, LegalOperations, ForCodeSize))
15268 return DAG.getNode(ISD::FSUB, DL, VT, N1, NegN0);
15269
15270 auto isFMulNegTwo = [](SDValue FMul) {
15271 if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL)
15272 return false;
15273 auto *C = isConstOrConstSplatFP(FMul.getOperand(1), true);
15274 return C && C->isExactlyValue(-2.0);
15275 };
15276
15277 // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B)
15278 if (isFMulNegTwo(N0)) {
15279 SDValue B = N0.getOperand(0);
15280 SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
15281 return DAG.getNode(ISD::FSUB, DL, VT, N1, Add);
15282 }
15283 // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B)
15284 if (isFMulNegTwo(N1)) {
15285 SDValue B = N1.getOperand(0);
15286 SDValue Add = DAG.getNode(ISD::FADD, DL, VT, B, B);
15287 return DAG.getNode(ISD::FSUB, DL, VT, N0, Add);
15288 }
15289
15290 // No FP constant should be created after legalization, as the Instruction
15291 // Selection pass has a hard time dealing with FP constants.
15292 bool AllowNewConst = (Level < AfterLegalizeDAG);
15293
15294 // If nnan is enabled, fold lots of things.
15295 if ((Options.NoNaNsFPMath || Flags.hasNoNaNs()) && AllowNewConst) {
15296 // If allowed, fold (fadd (fneg x), x) -> 0.0
15297 if (N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1)
15298 return DAG.getConstantFP(0.0, DL, VT);
15299
15300 // If allowed, fold (fadd x, (fneg x)) -> 0.0
15301 if (N1.getOpcode() == ISD::FNEG && N1.getOperand(0) == N0)
15302 return DAG.getConstantFP(0.0, DL, VT);
15303 }
15304
15305 // If 'unsafe math' or reassoc and nsz, fold lots of things.
15306 // TODO: break out portions of the transformations below for which Unsafe is
15307 // considered and which do not require both nsz and reassoc
15308 if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
15309 (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
15310 AllowNewConst) {
15311 // fadd (fadd x, c1), c2 -> fadd x, c1 + c2
15312 if (N1CFP && N0.getOpcode() == ISD::FADD &&
15313 DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
15314 SDValue NewC = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1);
15315 return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0), NewC);
15316 }
15317
15318 // We can fold chains of FADD's of the same value into multiplications.
15319 // This transform is not safe in general because we are reducing the number
15320 // of rounding steps.
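// Illustrative: (fadd (fmul x, c), x) below becomes (fmul x, c+1.0),
// replacing a rounded multiply plus a rounded add with a single rounded
// multiply.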
15321 if (TLI.isOperationLegalOrCustom(ISD::FMUL, VT) && !N0CFP && !N1CFP) {
15322 if (N0.getOpcode() == ISD::FMUL) {
15323 SDNode *CFP00 =
15324 DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
15325 SDNode *CFP01 =
15326 DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1));
15327
15328 // (fadd (fmul x, c), x) -> (fmul x, c+1)
15329 if (CFP01 && !CFP00 && N0.getOperand(0) == N1) {
15330 SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
15331 DAG.getConstantFP(1.0, DL, VT));
15332 return DAG.getNode(ISD::FMUL, DL, VT, N1, NewCFP);
15333 }
15334
15335 // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2)
15336 if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
15337 N1.getOperand(0) == N1.getOperand(1) &&
15338 N0.getOperand(0) == N1.getOperand(0)) {
15339 SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1),
15340 DAG.getConstantFP(2.0, DL, VT));
15341 return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), NewCFP);
15342 }
15343 }
15344
15345 if (N1.getOpcode() == ISD::FMUL) {
15346 SDNode *CFP10 =
15347 DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
15348 SDNode *CFP11 =
15349 DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(1));
15350
15351 // (fadd x, (fmul x, c)) -> (fmul x, c+1)
15352 if (CFP11 && !CFP10 && N1.getOperand(0) == N0) {
15353 SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
15354 DAG.getConstantFP(1.0, DL, VT));
15355 return DAG.getNode(ISD::FMUL, DL, VT, N0, NewCFP);
15356 }
15357
15358 // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2)
15359 if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD &&
15360 N0.getOperand(0) == N0.getOperand(1) &&
15361 N1.getOperand(0) == N0.getOperand(0)) {
15362 SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1),
15363 DAG.getConstantFP(2.0, DL, VT));
15364 return DAG.getNode(ISD::FMUL, DL, VT, N1.getOperand(0), NewCFP);
15365 }
15366 }
15367
15368 if (N0.getOpcode() == ISD::FADD) {
15369 SDNode *CFP00 =
15370 DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(0));
15371 // (fadd (fadd x, x), x) -> (fmul x, 3.0)
15372 if (!CFP00 && N0.getOperand(0) == N0.getOperand(1) &&
15373 (N0.getOperand(0) == N1)) {
15374 return DAG.getNode(ISD::FMUL, DL, VT, N1,
15375 DAG.getConstantFP(3.0, DL, VT));
15376 }
15377 }
15378
15379 if (N1.getOpcode() == ISD::FADD) {
15380 SDNode *CFP10 =
15381 DAG.isConstantFPBuildVectorOrConstantFP(N1.getOperand(0));
15382 // (fadd x, (fadd x, x)) -> (fmul x, 3.0)
15383 if (!CFP10 && N1.getOperand(0) == N1.getOperand(1) &&
15384 N1.getOperand(0) == N0) {
15385 return DAG.getNode(ISD::FMUL, DL, VT, N0,
15386 DAG.getConstantFP(3.0, DL, VT));
15387 }
15388 }
15389
15390 // (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0)
15391 if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
15392 N0.getOperand(0) == N0.getOperand(1) &&
15393 N1.getOperand(0) == N1.getOperand(1) &&
15394 N0.getOperand(0) == N1.getOperand(0)) {
15395 return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0),
15396 DAG.getConstantFP(4.0, DL, VT));
15397 }
15398 }
15399 } // enable-unsafe-fp-math
15400
15401 // FADD -> FMA combines:
15402 if (SDValue Fused = visitFADDForFMACombine(N)) {
15403 AddToWorklist(Fused.getNode());
15404 return Fused;
15405 }
15406 return SDValue();
15407 }
15408
15409 SDValue DAGCombiner::visitSTRICT_FADD(SDNode *N) {
15410 SDValue Chain = N->getOperand(0);
15411 SDValue N0 = N->getOperand(1);
15412 SDValue N1 = N->getOperand(2);
15413 EVT VT = N->getValueType(0);
15414 EVT ChainVT = N->getValueType(1);
15415 SDLoc DL(N);
15416 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
15417
15418 // fold (strict_fadd A, (fneg B)) -> (strict_fsub A, B)
15419 if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
15420 if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
15421 N1, DAG, LegalOperations, ForCodeSize)) {
15422 return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
15423 {Chain, N0, NegN1});
15424 }
15425
15426 // fold (strict_fadd (fneg A), B) -> (strict_fsub B, A)
15427 if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::STRICT_FSUB, VT))
15428 if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
15429 N0, DAG, LegalOperations, ForCodeSize)) {
15430 return DAG.getNode(ISD::STRICT_FSUB, DL, DAG.getVTList(VT, ChainVT),
15431 {Chain, N1, NegN0});
15432 }
15433 return SDValue();
15434 }
15435
15436 SDValue DAGCombiner::visitFSUB(SDNode *N) {
15437 SDValue N0 = N->getOperand(0);
15438 SDValue N1 = N->getOperand(1);
15439 ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, true);
15440 ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
15441 EVT VT = N->getValueType(0);
15442 SDLoc DL(N);
15443 const TargetOptions &Options = DAG.getTarget().Options;
15444 const SDNodeFlags Flags = N->getFlags();
15445 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
15446
15447 if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
15448 return R;
15449
15450 // fold (fsub c1, c2) -> c1-c2
15451 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FSUB, DL, VT, {N0, N1}))
15452 return C;
15453
15454 // fold vector ops
15455 if (VT.isVector())
15456 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
15457 return FoldedVOp;
15458
15459 if (SDValue NewSel = foldBinOpIntoSelect(N))
15460 return NewSel;
15461
15462 // (fsub A, 0) -> A
15463 if (N1CFP && N1CFP->isZero()) {
15464 if (!N1CFP->isNegative() || Options.NoSignedZerosFPMath ||
15465 Flags.hasNoSignedZeros()) {
15466 return N0;
15467 }
15468 }
15469
15470 if (N0 == N1) {
15471 // (fsub x, x) -> 0.0
15472 if (Options.NoNaNsFPMath || Flags.hasNoNaNs())
15473 return DAG.getConstantFP(0.0f, DL, VT);
15474 }
15475
15476 // (fsub -0.0, N1) -> -N1
15477 if (N0CFP && N0CFP->isZero()) {
15478 if (N0CFP->isNegative() ||
15479 (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) {
15480 // We cannot replace an FSUB(+-0.0,X) with FNEG(X) when denormals are
15481 // flushed to zero, unless all users treat denorms as zero (DAZ).
15482 // FIXME: This transform will change the sign of a NaN and the behavior
15483 // of a signaling NaN. It is only valid when a NoNaN flag is present.
15484 DenormalMode DenormMode = DAG.getDenormalMode(VT);
15485 if (DenormMode == DenormalMode::getIEEE()) {
15486 if (SDValue NegN1 =
15487 TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
15488 return NegN1;
15489 if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
15490 return DAG.getNode(ISD::FNEG, DL, VT, N1);
15491 }
15492 }
15493 }
15494
15495 if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
15496 (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
15497 N1.getOpcode() == ISD::FADD) {
15498 // X - (X + Y) -> -Y
15499 if (N0 == N1->getOperand(0))
15500 return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(1));
15501 // X - (Y + X) -> -Y
15502 if (N0 == N1->getOperand(1))
15503 return DAG.getNode(ISD::FNEG, DL, VT, N1->getOperand(0));
15504 }
15505
15506 // fold (fsub A, (fneg B)) -> (fadd A, B)
15507 if (SDValue NegN1 =
15508 TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize))
15509 return DAG.getNode(ISD::FADD, DL, VT, N0, NegN1);
15510
15511 // FSUB -> FMA combines:
15512 if (SDValue Fused = visitFSUBForFMACombine(N)) {
15513 AddToWorklist(Fused.getNode());
15514 return Fused;
15515 }
15516
15517 return SDValue();
15518 }
15519
15520 SDValue DAGCombiner::visitFMUL(SDNode *N) {
15521 SDValue N0 = N->getOperand(0);
15522 SDValue N1 = N->getOperand(1);
15523 ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1, true);
15524 EVT VT = N->getValueType(0);
15525 SDLoc DL(N);
15526 const TargetOptions &Options = DAG.getTarget().Options;
15527 const SDNodeFlags Flags = N->getFlags();
15528 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
15529
15530 if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
15531 return R;
15532
15533 // fold (fmul c1, c2) -> c1*c2
15534 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FMUL, DL, VT, {N0, N1}))
15535 return C;
15536
15537 // canonicalize constant to RHS
15538 if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
15539 !DAG.isConstantFPBuildVectorOrConstantFP(N1))
15540 return DAG.getNode(ISD::FMUL, DL, VT, N1, N0);
15541
15542 // fold vector ops
15543 if (VT.isVector())
15544 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
15545 return FoldedVOp;
15546
15547 if (SDValue NewSel = foldBinOpIntoSelect(N))
15548 return NewSel;
15549
15550 if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) {
15551 // fmul (fmul X, C1), C2 -> fmul X, C1 * C2
15552 if (DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
15553 N0.getOpcode() == ISD::FMUL) {
15554 SDValue N00 = N0.getOperand(0);
15555 SDValue N01 = N0.getOperand(1);
15556 // Avoid an infinite loop by making sure that N00 is not a constant
15557 // (the inner multiply has not been constant folded yet).
15558 if (DAG.isConstantFPBuildVectorOrConstantFP(N01) &&
15559 !DAG.isConstantFPBuildVectorOrConstantFP(N00)) {
15560 SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1);
15561 return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts);
15562 }
15563 }
15564
15565 // Match a special-case: we convert X * 2.0 into fadd.
15566 // fmul (fadd X, X), C -> fmul X, 2.0 * C
15567 if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() &&
15568 N0.getOperand(0) == N0.getOperand(1)) {
15569 const SDValue Two = DAG.getConstantFP(2.0, DL, VT);
15570 SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1);
15571 return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts);
15572 }
15573 }
15574
15575 // fold (fmul X, 2.0) -> (fadd X, X)
15576 if (N1CFP && N1CFP->isExactlyValue(+2.0))
15577 return DAG.getNode(ISD::FADD, DL, VT, N0, N0);
15578
15579 // fold (fmul X, -1.0) -> (fsub -0.0, X)
15580 if (N1CFP && N1CFP->isExactlyValue(-1.0)) {
15581 if (!LegalOperations || TLI.isOperationLegal(ISD::FSUB, VT)) {
15582 return DAG.getNode(ISD::FSUB, DL, VT,
15583 DAG.getConstantFP(-0.0, DL, VT), N0, Flags);
15584 }
15585 }
15586
15587 // -N0 * -N1 --> N0 * N1
15588 TargetLowering::NegatibleCost CostN0 =
15589 TargetLowering::NegatibleCost::Expensive;
15590 TargetLowering::NegatibleCost CostN1 =
15591 TargetLowering::NegatibleCost::Expensive;
15592 SDValue NegN0 =
15593 TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
15594 if (NegN0) {
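// The handle below keeps the speculatively built NegN0 alive while
// getNegatedExpression runs again for N1; without it, that call could delete
// NegN0 as an unused node.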
15595 HandleSDNode NegN0Handle(NegN0);
15596 SDValue NegN1 =
15597 TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
15598 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
15599 CostN1 == TargetLowering::NegatibleCost::Cheaper))
15600 return DAG.getNode(ISD::FMUL, DL, VT, NegN0, NegN1);
15601 }
15602
15603 // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X))
15604 // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X)
15605 if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() &&
15606 (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) &&
15607 TLI.isOperationLegal(ISD::FABS, VT)) {
15608 SDValue Select = N0, X = N1;
15609 if (Select.getOpcode() != ISD::SELECT)
15610 std::swap(Select, X);
15611
15612 SDValue Cond = Select.getOperand(0);
15613 auto TrueOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(1));
15614 auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Select.getOperand(2));
15615
15616 if (TrueOpnd && FalseOpnd &&
15617 Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == X &&
15618 isa<ConstantFPSDNode>(Cond.getOperand(1)) &&
15619 cast<ConstantFPSDNode>(Cond.getOperand(1))->isExactlyValue(0.0)) {
15620 ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
15621 switch (CC) {
15622 default: break;
15623 case ISD::SETOLT:
15624 case ISD::SETULT:
15625 case ISD::SETOLE:
15626 case ISD::SETULE:
15627 case ISD::SETLT:
15628 case ISD::SETLE:
15629 std::swap(TrueOpnd, FalseOpnd);
15630 [[fallthrough]];
15631 case ISD::SETOGT:
15632 case ISD::SETUGT:
15633 case ISD::SETOGE:
15634 case ISD::SETUGE:
15635 case ISD::SETGT:
15636 case ISD::SETGE:
15637 if (TrueOpnd->isExactlyValue(-1.0) && FalseOpnd->isExactlyValue(1.0) &&
15638 TLI.isOperationLegal(ISD::FNEG, VT))
15639 return DAG.getNode(ISD::FNEG, DL, VT,
15640 DAG.getNode(ISD::FABS, DL, VT, X));
15641 if (TrueOpnd->isExactlyValue(1.0) && FalseOpnd->isExactlyValue(-1.0))
15642 return DAG.getNode(ISD::FABS, DL, VT, X);
15643
15644 break;
15645 }
15646 }
15647 }
15648
15649 // FMUL -> FMA combines:
15650 if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) {
15651 AddToWorklist(Fused.getNode());
15652 return Fused;
15653 }
15654
15655 return SDValue();
15656 }
15657
15658 SDValue DAGCombiner::visitFMA(SDNode *N) {
15659 SDValue N0 = N->getOperand(0);
15660 SDValue N1 = N->getOperand(1);
15661 SDValue N2 = N->getOperand(2);
15662 ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
15663 ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
15664 EVT VT = N->getValueType(0);
15665 SDLoc DL(N);
15666 const TargetOptions &Options = DAG.getTarget().Options;
15667 // FMA nodes have flags that propagate to the created nodes.
15668 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
15669
15670 bool CanReassociate =
15671 Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
15672
15673 // Constant fold FMA.
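// (The folding itself is expected to happen inside getNode once all three
// operands are seen to be ConstantFP.)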
15674 if (isa<ConstantFPSDNode>(N0) &&
15675 isa<ConstantFPSDNode>(N1) &&
15676 isa<ConstantFPSDNode>(N2)) {
15677 return DAG.getNode(ISD::FMA, DL, VT, N0, N1, N2);
15678 }
15679
15680 // (-N0 * -N1) + N2 --> (N0 * N1) + N2
15681 TargetLowering::NegatibleCost CostN0 =
15682 TargetLowering::NegatibleCost::Expensive;
15683 TargetLowering::NegatibleCost CostN1 =
15684 TargetLowering::NegatibleCost::Expensive;
15685 SDValue NegN0 =
15686 TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
15687 if (NegN0) {
15688 HandleSDNode NegN0Handle(NegN0);
15689 SDValue NegN1 =
15690 TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
15691 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
15692 CostN1 == TargetLowering::NegatibleCost::Cheaper))
15693 return DAG.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);
15694 }
15695
15696 // FIXME: use fast math flags instead of Options.UnsafeFPMath
15697 if (Options.UnsafeFPMath) {
15698 if (N0CFP && N0CFP->isZero())
15699 return N2;
15700 if (N1CFP && N1CFP->isZero())
15701 return N2;
15702 }
15703
15704 if (N0CFP && N0CFP->isExactlyValue(1.0))
15705 return DAG.getNode(ISD::FADD, SDLoc(N), VT, N1, N2);
15706 if (N1CFP && N1CFP->isExactlyValue(1.0))
15707 return DAG.getNode(ISD::FADD, SDLoc(N), VT, N0, N2);
15708
15709 // Canonicalize (fma c, x, y) -> (fma x, c, y)
15710 if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
15711 !DAG.isConstantFPBuildVectorOrConstantFP(N1))
15712 return DAG.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);
15713
15714 if (CanReassociate) {
15715 // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
15716 if (N2.getOpcode() == ISD::FMUL && N0 == N2.getOperand(0) &&
15717 DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
15718 DAG.isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) {
15719 return DAG.getNode(ISD::FMUL, DL, VT, N0,
15720 DAG.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1)));
15721 }
15722
15723 // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
15724 if (N0.getOpcode() == ISD::FMUL &&
15725 DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
15726 DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
15727 return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
15728 DAG.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)),
15729 N2);
15730 }
15731 }
15732
15733 // (fma x, -1, y) -> (fadd (fneg x), y)
15734 if (N1CFP) {
15735 if (N1CFP->isExactlyValue(1.0))
15736 return DAG.getNode(ISD::FADD, DL, VT, N0, N2);
15737
15738 if (N1CFP->isExactlyValue(-1.0) &&
15739 (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) {
15740 SDValue RHSNeg = DAG.getNode(ISD::FNEG, DL, VT, N0);
15741 AddToWorklist(RHSNeg.getNode());
15742 return DAG.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
15743 }
15744
15745 // fma (fneg x), K, y -> fma x, -K, y
15746 if (N0.getOpcode() == ISD::FNEG &&
15747 (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
15748 (N1.hasOneUse() && !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT,
15749 ForCodeSize)))) {
15750 return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
15751 DAG.getNode(ISD::FNEG, DL, VT, N1), N2);
15752 }
15753 }
15754
15755 if (CanReassociate) {
15756 // (fma x, c, x) -> (fmul x, (c+1))
15757 if (N1CFP && N0 == N2) {
15758 return DAG.getNode(
15759 ISD::FMUL, DL, VT, N0,
15760 DAG.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(1.0, DL, VT)));
15761 }
15762
15763 // (fma x, c, (fneg x)) -> (fmul x, (c-1))
15764 if (N1CFP && N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) {
15765 return DAG.getNode(
15766 ISD::FMUL, DL, VT, N0,
15767 DAG.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(-1.0, DL, VT)));
15768 }
15769 }
15770
15771 // fold ((fma (fneg X), Y, (fneg Z)) -> fneg (fma X, Y, Z))
15772 // fold ((fma X, (fneg Y), (fneg Z)) -> fneg (fma X, Y, Z))
15773 if (!TLI.isFNegFree(VT))
15774 if (SDValue Neg = TLI.getCheaperNegatedExpression(
15775 SDValue(N, 0), DAG, LegalOperations, ForCodeSize))
15776 return DAG.getNode(ISD::FNEG, DL, VT, Neg);
15777 return SDValue();
15778 }
15779
15780 // Combine multiple FDIVs with the same divisor into multiple FMULs by the
15781 // reciprocal.
15782 // E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
15783 // Notice that this is not always beneficial. One reason is different targets
15784 // may have different costs for FDIV and FMUL, so sometimes the cost of two
15785 // FDIVs may be lower than the cost of one FDIV and two FMULs. Another reason
15786 // is the critical path is increased from "one FDIV" to "one FDIV + one FMUL".
15787 SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
15788 // TODO: Limit this transform based on optsize/minsize - it always creates at
15789 // least 1 extra instruction. But the perf win may be substantial enough
15790 // that only minsize should restrict this.
15791 bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath;
15792 const SDNodeFlags Flags = N->getFlags();
15793 if (LegalDAG || (!UnsafeMath && !Flags.hasAllowReciprocal()))
15794 return SDValue();
15795
15796 // Skip if current node is a reciprocal/fneg-reciprocal.
15797 SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
15798 ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0, /* AllowUndefs */ true);
15799 if (N0CFP && (N0CFP->isExactlyValue(1.0) || N0CFP->isExactlyValue(-1.0)))
15800 return SDValue();
15801
15802 // Exit early if the target does not want this transform or if there can't
15803 // possibly be enough uses of the divisor to make the transform worthwhile.
15804 unsigned MinUses = TLI.combineRepeatedFPDivisors();
15805
15806 // For splat vectors, scale the number of uses by the splat factor. If we can
15807 // convert the division into a scalar op, that will likely be much faster.
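// For example, a single FDIV whose divisor is a splatted <4 x float> value
// counts as 4 uses toward the threshold below.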
15808 unsigned NumElts = 1;
15809 EVT VT = N->getValueType(0);
15810 if (VT.isVector() && DAG.isSplatValue(N1))
15811 NumElts = VT.getVectorMinNumElements();
15812
15813 if (!MinUses || (N1->use_size() * NumElts) < MinUses)
15814 return SDValue();
15815
15816 // Find all FDIV users of the same divisor.
15817 // Use a set because duplicates may be present in the user list.
15818 SetVector<SDNode *> Users;
15819 for (auto *U : N1->uses()) {
15820 if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1) {
15821 // Skip X/sqrt(X) that has not been simplified to sqrt(X) yet.
15822 if (U->getOperand(1).getOpcode() == ISD::FSQRT &&
15823 U->getOperand(0) == U->getOperand(1).getOperand(0) &&
15824 U->getFlags().hasAllowReassociation() &&
15825 U->getFlags().hasNoSignedZeros())
15826 continue;
15827
15828 // This division is eligible for optimization only if global unsafe math
15829 // is enabled or if this division allows reciprocal formation.
15830 if (UnsafeMath || U->getFlags().hasAllowReciprocal())
15831 Users.insert(U);
15832 }
15833 }
15834
15835 // Now that we have the actual number of divisor uses, make sure it meets
15836 // the minimum threshold specified by the target.
15837 if ((Users.size() * NumElts) < MinUses)
15838 return SDValue();
15839
15840 SDLoc DL(N);
15841 SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
15842 SDValue Reciprocal = DAG.getNode(ISD::FDIV, DL, VT, FPOne, N1, Flags);
15843
15844 // Dividend / Divisor -> Dividend * Reciprocal
15845 for (auto *U : Users) {
15846 SDValue Dividend = U->getOperand(0);
15847 if (Dividend != FPOne) {
15848 SDValue NewNode = DAG.getNode(ISD::FMUL, SDLoc(U), VT, Dividend,
15849 Reciprocal, Flags);
15850 CombineTo(U, NewNode);
15851 } else if (U != Reciprocal.getNode()) {
15852 // In the absence of fast-math-flags, this user node is always the
15853 // same node as Reciprocal, but with FMF they may be different nodes.
15854 CombineTo(U, Reciprocal);
15855 }
15856 }
15857 return SDValue(N, 0); // N was replaced.
15858 }
15859
15860 SDValue DAGCombiner::visitFDIV(SDNode *N) {
15861 SDValue N0 = N->getOperand(0);
15862 SDValue N1 = N->getOperand(1);
15863 EVT VT = N->getValueType(0);
15864 SDLoc DL(N);
15865 const TargetOptions &Options = DAG.getTarget().Options;
15866 SDNodeFlags Flags = N->getFlags();
15867 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
15868
15869 if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
15870 return R;
15871
15872 // fold (fdiv c1, c2) -> c1/c2
15873 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FDIV, DL, VT, {N0, N1}))
15874 return C;
15875
15876 // fold vector ops
15877 if (VT.isVector())
15878 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
15879 return FoldedVOp;
15880
15881 if (SDValue NewSel = foldBinOpIntoSelect(N))
15882 return NewSel;
15883
15884 if (SDValue V = combineRepeatedFPDivisors(N))
15885 return V;
15886
15887 if (Options.UnsafeFPMath || Flags.hasAllowReciprocal()) {
15888 // fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable.
15889 if (auto *N1CFP = dyn_cast<ConstantFPSDNode>(N1)) {
15890 // Compute the reciprocal 1.0 / c2.
15891 const APFloat &N1APF = N1CFP->getValueAPF();
15892 APFloat Recip(N1APF.getSemantics(), 1); // 1.0
15893 APFloat::opStatus st = Recip.divide(N1APF, APFloat::rmNearestTiesToEven);
15894 // Only do the transform if the reciprocal is a legal fp immediate that
15895 // isn't too nasty (eg NaN, denormal, ...).
15896 if ((st == APFloat::opOK || st == APFloat::opInexact) && // Not too nasty
15897 (!LegalOperations ||
15898 // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
15899 // backend)... we should handle this gracefully after Legalize.
15900 // TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT) ||
15901 TLI.isOperationLegal(ISD::ConstantFP, VT) ||
15902 TLI.isFPImmLegal(Recip, VT, ForCodeSize)))
15903 return DAG.getNode(ISD::FMUL, DL, VT, N0,
15904 DAG.getConstantFP(Recip, DL, VT));
15905 }
15906
15907 // If this FDIV is part of a reciprocal square root, it may be folded
15908 // into a target-specific square root estimate instruction.
15909 if (N1.getOpcode() == ISD::FSQRT) {
15910 if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0), Flags))
15911 return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
15912 } else if (N1.getOpcode() == ISD::FP_EXTEND &&
15913 N1.getOperand(0).getOpcode() == ISD::FSQRT) {
15914 if (SDValue RV =
15915 buildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) {
15916 RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
15917 AddToWorklist(RV.getNode());
15918 return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
15919 }
15920 } else if (N1.getOpcode() == ISD::FP_ROUND &&
15921 N1.getOperand(0).getOpcode() == ISD::FSQRT) {
15922 if (SDValue RV =
15923 buildRsqrtEstimate(N1.getOperand(0).getOperand(0), Flags)) {
15924 RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
15925 AddToWorklist(RV.getNode());
15926 return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
15927 }
15928 } else if (N1.getOpcode() == ISD::FMUL) {
15929 // Look through an FMUL. Even though this won't remove the FDIV directly,
15930 // it's still worthwhile to get rid of the FSQRT if possible.
15931 SDValue Sqrt, Y;
15932 if (N1.getOperand(0).getOpcode() == ISD::FSQRT) {
15933 Sqrt = N1.getOperand(0);
15934 Y = N1.getOperand(1);
15935 } else if (N1.getOperand(1).getOpcode() == ISD::FSQRT) {
15936 Sqrt = N1.getOperand(1);
15937 Y = N1.getOperand(0);
15938 }
15939 if (Sqrt.getNode()) {
15940 // If the other multiply operand is known positive, pull it into the
15941 // sqrt. That will eliminate the division if we convert to an estimate.
15942 if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
15943 N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
15944 SDValue A;
15945 if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
15946 A = Y.getOperand(0);
15947 else if (Y == Sqrt.getOperand(0))
15948 A = Y;
15949 if (A) {
15950 // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
15951 // X / (A * sqrt(A)) --> X / sqrt(A*A*A) --> X * rsqrt(A*A*A)
15952 SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, A, A);
15953 SDValue AAZ =
15954 DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0));
15955 if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
15956 return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt);
15957
15958 // Estimate creation failed. Clean up speculatively created nodes.
15959 recursivelyDeleteUnusedNodes(AAZ.getNode());
15960 }
15961 }
15962
15963 // We found a FSQRT, so try to make this fold:
15964 // X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
15965 if (SDValue Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0), Flags)) {
15966 SDValue Div = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, Rsqrt, Y);
15967 AddToWorklist(Div.getNode());
15968 return DAG.getNode(ISD::FMUL, DL, VT, N0, Div);
15969 }
15970 }
15971 }
15972
15973 // Fold into a reciprocal estimate and multiply instead of a real divide.
15974 if (Options.NoInfsFPMath || Flags.hasNoInfs())
15975 if (SDValue RV = BuildDivEstimate(N0, N1, Flags))
15976 return RV;
15977 }
15978
15979 // Fold X/Sqrt(X) -> Sqrt(X)
15980 if ((Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) &&
15981 (Options.UnsafeFPMath || Flags.hasAllowReassociation()))
15982 if (N1.getOpcode() == ISD::FSQRT && N0 == N1.getOperand(0))
15983 return N1;
15984
15985 // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
15986 TargetLowering::NegatibleCost CostN0 =
15987 TargetLowering::NegatibleCost::Expensive;
15988 TargetLowering::NegatibleCost CostN1 =
15989 TargetLowering::NegatibleCost::Expensive;
15990 SDValue NegN0 =
15991 TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize, CostN0);
15992 if (NegN0) {
15993 HandleSDNode NegN0Handle(NegN0);
15994 SDValue NegN1 =
15995 TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
15996 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
15997 CostN1 == TargetLowering::NegatibleCost::Cheaper))
15998 return DAG.getNode(ISD::FDIV, SDLoc(N), VT, NegN0, NegN1);
15999 }
16000
16001 return SDValue();
16002 }
16003
16004 SDValue DAGCombiner::visitFREM(SDNode *N) {
16005 SDValue N0 = N->getOperand(0);
16006 SDValue N1 = N->getOperand(1);
16007 EVT VT = N->getValueType(0);
16008 SDNodeFlags Flags = N->getFlags();
16009 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16010
16011 if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
16012 return R;
16013
16014 // fold (frem c1, c2) -> fmod(c1,c2)
16015 if (SDValue C = DAG.FoldConstantArithmetic(ISD::FREM, SDLoc(N), VT, {N0, N1}))
16016 return C;
16017
16018 if (SDValue NewSel = foldBinOpIntoSelect(N))
16019 return NewSel;
16020
16021 return SDValue();
16022 }
16023
16024 SDValue DAGCombiner::visitFSQRT(SDNode *N) {
16025 SDNodeFlags Flags = N->getFlags();
16026 const TargetOptions &Options = DAG.getTarget().Options;
16027
16028 // Require 'ninf' flag since sqrt(+Inf) = +Inf, but the estimation goes as:
16029 // sqrt(+Inf) == rsqrt(+Inf) * +Inf = 0 * +Inf = NaN
16030 if (!Flags.hasApproximateFuncs() ||
16031 (!Options.NoInfsFPMath && !Flags.hasNoInfs()))
16032 return SDValue();
16033
16034 SDValue N0 = N->getOperand(0);
16035 if (TLI.isFsqrtCheap(N0, DAG))
16036 return SDValue();
16037
16038 // FSQRT nodes have flags that propagate to the created nodes.
16039 // TODO: If this is N0/sqrt(N0), and we reach this node before trying to
16040 // transform the fdiv, we may produce a sub-optimal estimate sequence
16041 // because the reciprocal calculation may not have to filter out a
16042 // 0.0 input.
16043 return buildSqrtEstimate(N0, Flags);
16044 }
16045
16046 /// copysign(x, fp_extend(y)) -> copysign(x, y)
16047 /// copysign(x, fp_round(y)) -> copysign(x, y)
16048 static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) {
16049 SDValue N1 = N->getOperand(1);
16050 if ((N1.getOpcode() == ISD::FP_EXTEND ||
16051 N1.getOpcode() == ISD::FP_ROUND)) {
16052 EVT N1VT = N1->getValueType(0);
16053 EVT N1Op0VT = N1->getOperand(0).getValueType();
16054
16055 // Always fold no-op FP casts.
16056 if (N1VT == N1Op0VT)
16057 return true;
16058
16059 // Do not optimize out type conversion of f128 type yet.
16060 // For some targets like x86_64, configuration is changed to keep one f128
16061 // value in one SSE register, but instruction selection cannot handle
16062 // FCOPYSIGN on SSE registers yet.
16063 if (N1Op0VT == MVT::f128)
16064 return false;
16065
16066 return !N1Op0VT.isVector() || EnableVectorFCopySignExtendRound;
16067 }
16068 return false;
16069 }
16070
16071 SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
16072 SDValue N0 = N->getOperand(0);
16073 SDValue N1 = N->getOperand(1);
16074 EVT VT = N->getValueType(0);
16075
16076 // fold (fcopysign c1, c2) -> fcopysign(c1,c2)
16077 if (SDValue C =
16078 DAG.FoldConstantArithmetic(ISD::FCOPYSIGN, SDLoc(N), VT, {N0, N1}))
16079 return C;
16080
16081 if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N->getOperand(1))) {
16082 const APFloat &V = N1C->getValueAPF();
16083 // copysign(x, c1) -> fabs(x) iff ispos(c1)
16084 // copysign(x, c1) -> fneg(fabs(x)) iff isneg(c1)
16085 if (!V.isNegative()) {
16086 if (!LegalOperations || TLI.isOperationLegal(ISD::FABS, VT))
16087 return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
16088 } else {
16089 if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))
16090 return DAG.getNode(ISD::FNEG, SDLoc(N), VT,
16091 DAG.getNode(ISD::FABS, SDLoc(N0), VT, N0));
16092 }
16093 }
16094
16095 // copysign(fabs(x), y) -> copysign(x, y)
16096 // copysign(fneg(x), y) -> copysign(x, y)
16097 // copysign(copysign(x,z), y) -> copysign(x, y)
16098 if (N0.getOpcode() == ISD::FABS || N0.getOpcode() == ISD::FNEG ||
16099 N0.getOpcode() == ISD::FCOPYSIGN)
16100 return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0.getOperand(0), N1);
16101
16102 // copysign(x, abs(y)) -> abs(x)
16103 if (N1.getOpcode() == ISD::FABS)
16104 return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
16105
16106 // copysign(x, copysign(y,z)) -> copysign(x, z)
16107 if (N1.getOpcode() == ISD::FCOPYSIGN)
16108 return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(1));
16109
16110 // copysign(x, fp_extend(y)) -> copysign(x, y)
16111 // copysign(x, fp_round(y)) -> copysign(x, y)
16112 if (CanCombineFCOPYSIGN_EXTEND_ROUND(N))
16113 return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(0));
16114
16115 return SDValue();
16116 }
16117
16118 SDValue DAGCombiner::visitFPOW(SDNode *N) {
16119 ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N->getOperand(1));
16120 if (!ExponentC)
16121 return SDValue();
16122 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16123
16124 // Try to convert x ** (1/3) into cube root.
16125 // TODO: Handle the various flavors of long double.
16126 // TODO: Since we're approximating, we don't need an exact 1/3 exponent.
16127 // Some range near 1/3 should be fine.
16128 EVT VT = N->getValueType(0);
16129 if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(1.0f/3.0f)) ||
16130 (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(1.0/3.0))) {
16131 // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
16132 // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
16133 // pow(-val, 1/3) = nan; cbrt(-val) = -num.
16134 // For regular numbers, rounding may cause the results to differ.
16135 // Therefore, we require { nsz ninf nnan afn } for this transform.
16136 // TODO: We could select out the special cases if we don't have nsz/ninf.
16137 SDNodeFlags Flags = N->getFlags();
16138 if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() ||
16139 !Flags.hasApproximateFuncs())
16140 return SDValue();
16141
16142 // Do not create a cbrt() libcall if the target does not have it, and do not
16143 // turn a pow that has lowering support into a cbrt() libcall.
16144 if (!DAG.getLibInfo().has(LibFunc_cbrt) ||
16145 (!DAG.getTargetLoweringInfo().isOperationExpand(ISD::FPOW, VT) &&
16146 DAG.getTargetLoweringInfo().isOperationExpand(ISD::FCBRT, VT)))
16147 return SDValue();
16148
16149 return DAG.getNode(ISD::FCBRT, SDLoc(N), VT, N->getOperand(0));
16150 }
16151
16152 // Try to convert x ** (1/4) and x ** (3/4) into square roots.
16153 // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case.
16154 // TODO: This could be extended (using a target hook) to handle smaller
16155 // power-of-2 fractional exponents.
16156 bool ExponentIs025 = ExponentC->getValueAPF().isExactlyValue(0.25);
16157 bool ExponentIs075 = ExponentC->getValueAPF().isExactlyValue(0.75);
16158 if (ExponentIs025 || ExponentIs075) {
16159 // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0.
16160 // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) = NaN.
16161 // pow(-0.0, 0.75) = +0.0; sqrt(-0.0) * sqrt(sqrt(-0.0)) = +0.0.
16162 // pow(-inf, 0.75) = +inf; sqrt(-inf) * sqrt(sqrt(-inf)) = NaN.
16163 // For regular numbers, rounding may cause the results to differ.
16164 // Therefore, we require { nsz ninf afn } for this transform.
16165 // TODO: We could select out the special cases if we don't have nsz/ninf.
16166 SDNodeFlags Flags = N->getFlags();
16167
16168 // We only need no signed zeros for the 0.25 case.
16169 if ((!Flags.hasNoSignedZeros() && ExponentIs025) || !Flags.hasNoInfs() ||
16170 !Flags.hasApproximateFuncs())
16171 return SDValue();
16172
16173 // Don't double the number of libcalls. We are trying to inline fast code.
16174 if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(ISD::FSQRT, VT))
16175 return SDValue();
16176
16177 // Assume that libcalls are the smallest code.
16178 // TODO: This restriction should probably be lifted for vectors.
16179 if (ForCodeSize)
16180 return SDValue();
16181
16182 // pow(X, 0.25) --> sqrt(sqrt(X))
16183 SDLoc DL(N);
16184 SDValue Sqrt = DAG.getNode(ISD::FSQRT, DL, VT, N->getOperand(0));
16185 SDValue SqrtSqrt = DAG.getNode(ISD::FSQRT, DL, VT, Sqrt);
16186 if (ExponentIs025)
16187 return SqrtSqrt;
16188 // pow(X, 0.75) --> sqrt(X) * sqrt(sqrt(X))
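// (X ** 0.75 == X ** 0.5 * X ** 0.25, and sqrt(sqrt(X)) == X ** 0.25.)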
16189 return DAG.getNode(ISD::FMUL, DL, VT, Sqrt, SqrtSqrt);
16190 }
16191
16192 return SDValue();
16193 }
16194
16195 static SDValue foldFPToIntToFP(SDNode *N, SelectionDAG &DAG,
16196 const TargetLowering &TLI) {
16197 // We only do this if the target has legal ftrunc. Otherwise, we'd likely be
16198 // replacing casts with a libcall. We also must be allowed to ignore -0.0
16199 // because FTRUNC will return -0.0 for (-1.0, -0.0), but using integer
16200 // conversions would return +0.0.
16201 // FIXME: We should be able to use node-level FMF here.
16202 // TODO: If strict math, should we use FABS (+ range check for signed cast)?
16203 EVT VT = N->getValueType(0);
16204 if (!TLI.isOperationLegal(ISD::FTRUNC, VT) ||
16205 !DAG.getTarget().Options.NoSignedZerosFPMath)
16206 return SDValue();
16207
16208 // fptosi/fptoui round towards zero, so converting from FP to integer and
16209 // back is the same as an 'ftrunc': [us]itofp (fpto[us]i X) --> ftrunc X
16210 SDValue N0 = N->getOperand(0);
16211 if (N->getOpcode() == ISD::SINT_TO_FP && N0.getOpcode() == ISD::FP_TO_SINT &&
16212 N0.getOperand(0).getValueType() == VT)
16213 return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));
16214
16215 if (N->getOpcode() == ISD::UINT_TO_FP && N0.getOpcode() == ISD::FP_TO_UINT &&
16216 N0.getOperand(0).getValueType() == VT)
16217 return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0.getOperand(0));
16218
16219 return SDValue();
16220 }
16221
16222 SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
16223 SDValue N0 = N->getOperand(0);
16224 EVT VT = N->getValueType(0);
16225 EVT OpVT = N0.getValueType();
16226
16227 // [us]itofp(undef) = 0, because the result value is bounded.
16228 if (N0.isUndef())
16229 return DAG.getConstantFP(0.0, SDLoc(N), VT);
16230
16231 // fold (sint_to_fp c1) -> c1fp
16232 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
16233 // ...but only if the target supports immediate floating-point values
16234 (!LegalOperations ||
16235 TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
16236 return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
16237
16238 // If the input is a legal type, and SINT_TO_FP is not legal on this target,
16239 // but UINT_TO_FP is legal on this target, try to convert.
16240 if (!hasOperation(ISD::SINT_TO_FP, OpVT) &&
16241 hasOperation(ISD::UINT_TO_FP, OpVT)) {
16242 // If the sign bit is known to be zero, we can change this to UINT_TO_FP.
16243 if (DAG.SignBitIsZero(N0))
16244 return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
16245 }
16246
16247 // The next optimizations are desirable only if SELECT_CC can be lowered.
16248 // fold (sint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), -1.0, 0.0)
16249 if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 &&
16250 !VT.isVector() &&
16251 (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
16252 SDLoc DL(N);
16253 return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(-1.0, DL, VT),
16254 DAG.getConstantFP(0.0, DL, VT));
16255 }
16256
16257 // fold (sint_to_fp (zext (setcc x, y, cc))) ->
16258 // (select (setcc x, y, cc), 1.0, 0.0)
16259 if (N0.getOpcode() == ISD::ZERO_EXTEND &&
16260 N0.getOperand(0).getOpcode() == ISD::SETCC && !VT.isVector() &&
16261 (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
16262 SDLoc DL(N);
16263 return DAG.getSelect(DL, VT, N0.getOperand(0),
16264 DAG.getConstantFP(1.0, DL, VT),
16265 DAG.getConstantFP(0.0, DL, VT));
16266 }
16267
16268 if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
16269 return FTrunc;
16270
16271 return SDValue();
16272 }
16273
16274 SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
16275 SDValue N0 = N->getOperand(0);
16276 EVT VT = N->getValueType(0);
16277 EVT OpVT = N0.getValueType();
16278
16279 // [us]itofp(undef) = 0, because the result value is bounded.
16280 if (N0.isUndef())
16281 return DAG.getConstantFP(0.0, SDLoc(N), VT);
16282
16283 // fold (uint_to_fp c1) -> c1fp
16284 if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
16285 // ...but only if the target supports immediate floating-point values
16286 (!LegalOperations ||
16287 TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
16288 return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
16289
16290 // If the input is a legal type, and UINT_TO_FP is not legal on this target,
16291 // but SINT_TO_FP is legal on this target, try to convert.
16292 if (!hasOperation(ISD::UINT_TO_FP, OpVT) &&
16293 hasOperation(ISD::SINT_TO_FP, OpVT)) {
16294 // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
16295 if (DAG.SignBitIsZero(N0))
16296 return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
16297 }
16298
16299 // fold (uint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), 1.0, 0.0)
16300 if (N0.getOpcode() == ISD::SETCC && !VT.isVector() &&
16301 (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
16302 SDLoc DL(N);
16303 return DAG.getSelect(DL, VT, N0, DAG.getConstantFP(1.0, DL, VT),
16304 DAG.getConstantFP(0.0, DL, VT));
16305 }
16306
16307 if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
16308 return FTrunc;
16309
16310 return SDValue();
16311 }
16312
16313 // Fold (fp_to_{s/u}int ({s/u}int_to_fpx)) -> zext x, sext x, trunc x, or x
16314 static SDValue FoldIntToFPToInt(SDNode *N, SelectionDAG &DAG) {
16315 SDValue N0 = N->getOperand(0);
16316 EVT VT = N->getValueType(0);
16317
16318 if (N0.getOpcode() != ISD::UINT_TO_FP && N0.getOpcode() != ISD::SINT_TO_FP)
16319 return SDValue();
16320
16321 SDValue Src = N0.getOperand(0);
16322 EVT SrcVT = Src.getValueType();
16323 bool IsInputSigned = N0.getOpcode() == ISD::SINT_TO_FP;
16324 bool IsOutputSigned = N->getOpcode() == ISD::FP_TO_SINT;
16325
16326 // We can safely assume the conversion won't overflow the output range,
16327 // because (for example) (uint8_t)18293.f is undefined behavior.
16328
16329 // Since we can assume the conversion won't overflow, our decision as to
16330 // whether the input will fit in the float should depend on the minimum
16331 // of the input range and output range.
16332
16333 // This means this is also safe for a signed input and unsigned output, since
16334 // a negative input would lead to undefined behavior.
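// For example, with an f32 intermediate (24-bit precision) a signed i32
// source needs up to 31 magnitude bits and is rejected, while an i16 source
// (15 bits) can be folded.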
16335 unsigned InputSize = (int)SrcVT.getScalarSizeInBits() - IsInputSigned;
16336 unsigned OutputSize = (int)VT.getScalarSizeInBits();
16337 unsigned ActualSize = std::min(InputSize, OutputSize);
16338 const fltSemantics &sem = DAG.EVTToAPFloatSemantics(N0.getValueType());
16339
16340 // We can only fold away the float conversion if the input range can be
16341 // represented exactly in the float range.
16342 if (APFloat::semanticsPrecision(sem) >= ActualSize) {
16343 if (VT.getScalarSizeInBits() > SrcVT.getScalarSizeInBits()) {
16344 unsigned ExtOp = IsInputSigned && IsOutputSigned ? ISD::SIGN_EXTEND
16345 : ISD::ZERO_EXTEND;
16346 return DAG.getNode(ExtOp, SDLoc(N), VT, Src);
16347 }
16348 if (VT.getScalarSizeInBits() < SrcVT.getScalarSizeInBits())
16349 return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Src);
16350 return DAG.getBitcast(VT, Src);
16351 }
16352 return SDValue();
16353 }
16354
16355 SDValue DAGCombiner::visitFP_TO_SINT(SDNode *N) {
16356 SDValue N0 = N->getOperand(0);
16357 EVT VT = N->getValueType(0);
16358
16359 // fold (fp_to_sint undef) -> undef
16360 if (N0.isUndef())
16361 return DAG.getUNDEF(VT);
16362
16363 // fold (fp_to_sint c1fp) -> c1
16364 if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
16365 return DAG.getNode(ISD::FP_TO_SINT, SDLoc(N), VT, N0);
16366
16367 return FoldIntToFPToInt(N, DAG);
16368 }
16369
16370 SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
16371 SDValue N0 = N->getOperand(0);
16372 EVT VT = N->getValueType(0);
16373
16374 // fold (fp_to_uint undef) -> undef
16375 if (N0.isUndef())
16376 return DAG.getUNDEF(VT);
16377
16378 // fold (fp_to_uint c1fp) -> c1
16379 if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
16380 return DAG.getNode(ISD::FP_TO_UINT, SDLoc(N), VT, N0);
16381
16382 return FoldIntToFPToInt(N, DAG);
16383 }
16384
16385 SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
16386 SDValue N0 = N->getOperand(0);
16387 SDValue N1 = N->getOperand(1);
16388 EVT VT = N->getValueType(0);
16389
16390 // fold (fp_round c1fp) -> c1fp
16391 if (SDValue C =
16392 DAG.FoldConstantArithmetic(ISD::FP_ROUND, SDLoc(N), VT, {N0, N1}))
16393 return C;
16394
16395 // fold (fp_round (fp_extend x)) -> x
16396 if (N0.getOpcode() == ISD::FP_EXTEND && VT == N0.getOperand(0).getValueType())
16397 return N0.getOperand(0);
16398
16399 // fold (fp_round (fp_round x)) -> (fp_round x)
16400 if (N0.getOpcode() == ISD::FP_ROUND) {
16401 const bool NIsTrunc = N->getConstantOperandVal(1) == 1;
16402 const bool N0IsTrunc = N0.getConstantOperandVal(1) == 1;
16403
16404 // Skip this folding if it results in an fp_round from f80 to f16.
16405 //
16406 // f80 to f16 always generates an expensive (and as yet, unimplemented)
16407 // libcall to __truncxfhf2 instead of selecting native f16 conversion
16408 // instructions from f32 or f64. Moreover, the first (value-preserving)
16409 // fp_round from f80 to either f32 or f64 may become a NOP in platforms like
16410 // x86.
16411 if (N0.getOperand(0).getValueType() == MVT::f80 && VT == MVT::f16)
16412 return SDValue();
16413
16414 // If the first fp_round isn't a value preserving truncation, it might
16415 // introduce a tie in the second fp_round, that wouldn't occur in the
16416 // single-step fp_round we want to fold to.
16417 // In other words, double rounding isn't the same as rounding.
16418 // Also, this is a value preserving truncation iff both fp_round's are.
16419 if (DAG.getTarget().Options.UnsafeFPMath || N0IsTrunc) {
16420 SDLoc DL(N);
16421 return DAG.getNode(
16422 ISD::FP_ROUND, DL, VT, N0.getOperand(0),
16423 DAG.getIntPtrConstant(NIsTrunc && N0IsTrunc, DL, /*isTarget=*/true));
16424 }
16425 }
16426
16427 // fold (fp_round (copysign X, Y)) -> (copysign (fp_round X), Y)
16428 if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse()) {
16429 SDValue Tmp = DAG.getNode(ISD::FP_ROUND, SDLoc(N0), VT,
16430 N0.getOperand(0), N1);
16431 AddToWorklist(Tmp.getNode());
16432 return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT,
16433 Tmp, N0.getOperand(1));
16434 }
16435
16436 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
16437 return NewVSel;
16438
16439 return SDValue();
16440 }
16441
16442 SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
16443 SDValue N0 = N->getOperand(0);
16444 EVT VT = N->getValueType(0);
16445
16446 if (VT.isVector())
16447 if (SDValue FoldedVOp = SimplifyVCastOp(N, SDLoc(N)))
16448 return FoldedVOp;
16449
16450 // If this is fp_round(fpextend), don't fold it, allow ourselves to be folded.
16451 if (N->hasOneUse() &&
16452 N->use_begin()->getOpcode() == ISD::FP_ROUND)
16453 return SDValue();
16454
16455 // fold (fp_extend c1fp) -> c1fp
16456 if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
16457 return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, N0);
16458
16459 // fold (fp_extend (fp16_to_fp op)) -> (fp16_to_fp op)
16460 if (N0.getOpcode() == ISD::FP16_TO_FP &&
16461 TLI.getOperationAction(ISD::FP16_TO_FP, VT) == TargetLowering::Legal)
16462 return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), VT, N0.getOperand(0));
16463
16464 // Turn fp_extend(fp_round(X, 1)) -> x since the fp_round doesn't affect the
16465 // value of X.
16466 if (N0.getOpcode() == ISD::FP_ROUND
16467 && N0.getConstantOperandVal(1) == 1) {
16468 SDValue In = N0.getOperand(0);
16469 if (In.getValueType() == VT) return In;
16470 if (VT.bitsLT(In.getValueType()))
16471 return DAG.getNode(ISD::FP_ROUND, SDLoc(N), VT,
16472 In, N0.getOperand(1));
16473 return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, In);
16474 }
16475
16476 // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
16477 if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
16478 TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, VT, N0.getValueType())) {
16479 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
16480 SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
16481 LN0->getChain(),
16482 LN0->getBasePtr(), N0.getValueType(),
16483 LN0->getMemOperand());
16484 CombineTo(N, ExtLoad);
16485 CombineTo(
16486 N0.getNode(),
16487 DAG.getNode(ISD::FP_ROUND, SDLoc(N0), N0.getValueType(), ExtLoad,
16488 DAG.getIntPtrConstant(1, SDLoc(N0), /*isTarget=*/true)),
16489 ExtLoad.getValue(1));
16490 return SDValue(N, 0); // Return N so it doesn't get rechecked!
16491 }
16492
16493 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
16494 return NewVSel;
16495
16496 return SDValue();
16497 }
16498
16499 SDValue DAGCombiner::visitFCEIL(SDNode *N) {
16500 SDValue N0 = N->getOperand(0);
16501 EVT VT = N->getValueType(0);
16502
16503 // fold (fceil c1) -> fceil(c1)
16504 if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
16505 return DAG.getNode(ISD::FCEIL, SDLoc(N), VT, N0);
16506
16507 return SDValue();
16508 }
16509
16510 SDValue DAGCombiner::visitFTRUNC(SDNode *N) {
16511 SDValue N0 = N->getOperand(0);
16512 EVT VT = N->getValueType(0);
16513
16514 // fold (ftrunc c1) -> ftrunc(c1)
16515 if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
16516 return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0);
16517
16518 // fold ftrunc (known rounded int x) -> x
16519 // ftrunc is a part of fptosi/fptoui expansion on some targets, so this is
16520 // likely to be generated to extract integer from a rounded floating value.
16521 switch (N0.getOpcode()) {
16522 default: break;
16523 case ISD::FRINT:
16524 case ISD::FTRUNC:
16525 case ISD::FNEARBYINT:
16526 case ISD::FFLOOR:
16527 case ISD::FCEIL:
16528 return N0;
16529 }
16530
16531 return SDValue();
16532 }
16533
16534 SDValue DAGCombiner::visitFFLOOR(SDNode *N) {
16535 SDValue N0 = N->getOperand(0);
16536 EVT VT = N->getValueType(0);
16537
16538 // fold (ffloor c1) -> ffloor(c1)
16539 if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
16540 return DAG.getNode(ISD::FFLOOR, SDLoc(N), VT, N0);
16541
16542 return SDValue();
16543 }
16544
16545 SDValue DAGCombiner::visitFNEG(SDNode *N) {
16546 SDValue N0 = N->getOperand(0);
16547 EVT VT = N->getValueType(0);
16548 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16549
16550 // Constant fold FNEG.
16551 if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
16552 return DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0);
16553
16554 if (SDValue NegN0 =
16555 TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize))
16556 return NegN0;
16557
16558 // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0
16559 // FIXME: This is duplicated in getNegatibleCost, but getNegatibleCost doesn't
16560 // know it was called from a context with a nsz flag if the input fsub does
16561 // not.
16562 if (N0.getOpcode() == ISD::FSUB &&
16563 (DAG.getTarget().Options.NoSignedZerosFPMath ||
16564 N->getFlags().hasNoSignedZeros()) && N0.hasOneUse()) {
16565 return DAG.getNode(ISD::FSUB, SDLoc(N), VT, N0.getOperand(1),
16566 N0.getOperand(0));
16567 }
16568
16569 if (SDValue Cast = foldSignChangeInBitcast(N))
16570 return Cast;
16571
16572 return SDValue();
16573 }
16574
16575 SDValue DAGCombiner::visitFMinMax(SDNode *N) {
16576 SDValue N0 = N->getOperand(0);
16577 SDValue N1 = N->getOperand(1);
16578 EVT VT = N->getValueType(0);
16579 const SDNodeFlags Flags = N->getFlags();
16580 unsigned Opc = N->getOpcode();
16581 bool PropagatesNaN = Opc == ISD::FMINIMUM || Opc == ISD::FMAXIMUM;
16582 bool IsMin = Opc == ISD::FMINNUM || Opc == ISD::FMINIMUM;
16583 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16584
16585 // Constant fold.
16586 if (SDValue C = DAG.FoldConstantArithmetic(Opc, SDLoc(N), VT, {N0, N1}))
16587 return C;
16588
16589 // Canonicalize to constant on RHS.
16590 if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
16591 !DAG.isConstantFPBuildVectorOrConstantFP(N1))
16592 return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);
16593
16594 if (const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1)) {
16595 const APFloat &AF = N1CFP->getValueAPF();
16596
16597 // minnum(X, nan) -> X
16598 // maxnum(X, nan) -> X
16599 // minimum(X, nan) -> nan
16600 // maximum(X, nan) -> nan
16601 if (AF.isNaN())
16602 return PropagatesNaN ? N->getOperand(1) : N->getOperand(0);
16603
16604 // In the following folds, inf can be replaced with the largest finite
16605 // float, if the ninf flag is set.
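// (Under ninf, no operand can compare beyond the largest finite value, so it
// acts as an effective +/-inf for these folds.)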
16606 if (AF.isInfinity() || (Flags.hasNoInfs() && AF.isLargest())) {
16607 // minnum(X, -inf) -> -inf
16608 // maxnum(X, +inf) -> +inf
16609 // minimum(X, -inf) -> -inf if nnan
16610 // maximum(X, +inf) -> +inf if nnan
16611 if (IsMin == AF.isNegative() && (!PropagatesNaN || Flags.hasNoNaNs()))
16612 return N->getOperand(1);
16613
16614 // minnum(X, +inf) -> X if nnan
16615 // maxnum(X, -inf) -> X if nnan
16616 // minimum(X, +inf) -> X
16617 // maximum(X, -inf) -> X
16618 if (IsMin != AF.isNegative() && (PropagatesNaN || Flags.hasNoNaNs()))
16619 return N->getOperand(0);
16620 }
16621 }
16622
16623 return SDValue();
16624 }
16625
16626 SDValue DAGCombiner::visitFABS(SDNode *N) {
16627 SDValue N0 = N->getOperand(0);
16628 EVT VT = N->getValueType(0);
16629
16630 // fold (fabs c1) -> fabs(c1)
16631 if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
16632 return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
16633
16634 // fold (fabs (fabs x)) -> (fabs x)
16635 if (N0.getOpcode() == ISD::FABS)
16636 return N->getOperand(0);
16637
16638 // fold (fabs (fneg x)) -> (fabs x)
16639 // fold (fabs (fcopysign x, y)) -> (fabs x)
16640 if (N0.getOpcode() == ISD::FNEG || N0.getOpcode() == ISD::FCOPYSIGN)
16641 return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0.getOperand(0));
16642
16643 if (SDValue Cast = foldSignChangeInBitcast(N))
16644 return Cast;
16645
16646 return SDValue();
16647 }
16648
16649 SDValue DAGCombiner::visitBRCOND(SDNode *N) {
16650 SDValue Chain = N->getOperand(0);
16651 SDValue N1 = N->getOperand(1);
16652 SDValue N2 = N->getOperand(2);
16653
16654 // BRCOND(FREEZE(cond)) is equivalent to BRCOND(cond) (both are
16655 // nondeterministic jumps).
16656 if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse()) {
16657 return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
16658 N1->getOperand(0), N2);
16659 }
16660
16661 // If N is a constant we could fold this into a fallthrough or unconditional
16662 // branch. However that doesn't happen very often in normal code, because
16663 // Instcombine/SimplifyCFG should have handled the available opportunities.
16664 // If we did this folding here, it would be necessary to update the
16665 // MachineBasicBlock CFG, which is awkward.
16666
16667 // fold a brcond with a setcc condition into a BR_CC node if BR_CC is legal
16668 // on the target.
16669 if (N1.getOpcode() == ISD::SETCC &&
16670 TLI.isOperationLegalOrCustom(ISD::BR_CC,
16671 N1.getOperand(0).getValueType())) {
16672 return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
16673 Chain, N1.getOperand(2),
16674 N1.getOperand(0), N1.getOperand(1), N2);
16675 }
16676
16677 if (N1.hasOneUse()) {
16678 // rebuildSetCC calls visitXor which may change the Chain when there is a
16679 // STRICT_FSETCC/STRICT_FSETCCS involved. Use a handle to track changes.
16680 HandleSDNode ChainHandle(Chain);
16681 if (SDValue NewN1 = rebuildSetCC(N1))
16682 return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other,
16683 ChainHandle.getValue(), NewN1, N2);
16684 }
16685
16686 return SDValue();
16687 }
16688
16689 SDValue DAGCombiner::rebuildSetCC(SDValue N) {
16690 if (N.getOpcode() == ISD::SRL ||
16691 (N.getOpcode() == ISD::TRUNCATE &&
16692 (N.getOperand(0).hasOneUse() &&
16693 N.getOperand(0).getOpcode() == ISD::SRL))) {
16694 // Look past the truncate.
16695 if (N.getOpcode() == ISD::TRUNCATE)
16696 N = N.getOperand(0);
16697
16698 // Match this pattern so that we can generate simpler code:
16699 //
16700 // %a = ...
16701 // %b = and i32 %a, 2
16702 // %c = srl i32 %b, 1
16703 // brcond i32 %c ...
16704 //
16705 // into
16706 //
16707 // %a = ...
16708 // %b = and i32 %a, 2
16709 // %c = setcc eq %b, 0
16710 // brcond %c ...
16711 //
16712 // This applies only when the AND constant value has one bit set and the
16713 // SRL constant is equal to the log2 of the AND constant. The back-end is
16714 // smart enough to convert the result into a TEST/JMP sequence.
16715 SDValue Op0 = N.getOperand(0);
16716 SDValue Op1 = N.getOperand(1);
16717
16718 if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::Constant) {
16719 SDValue AndOp1 = Op0.getOperand(1);
16720
16721 if (AndOp1.getOpcode() == ISD::Constant) {
16722 const APInt &AndConst = cast<ConstantSDNode>(AndOp1)->getAPIntValue();
16723
16724 if (AndConst.isPowerOf2() &&
16725 cast<ConstantSDNode>(Op1)->getAPIntValue() == AndConst.logBase2()) {
16726 SDLoc DL(N);
16727 return DAG.getSetCC(DL, getSetCCResultType(Op0.getValueType()),
16728 Op0, DAG.getConstant(0, DL, Op0.getValueType()),
16729 ISD::SETNE);
16730 }
16731 }
16732 }
16733 }
16734
16735 // Transform (brcond (xor x, y)) -> (brcond (setcc, x, y, ne))
16736 // Transform (brcond (xor (xor x, y), -1)) -> (brcond (setcc, x, y, eq))
16737 if (N.getOpcode() == ISD::XOR) {
16738 // Because we may call this on a speculatively constructed
16739 // SimplifiedSetCC Node, we need to simplify this node first.
16740 // Ideally this should be folded into SimplifySetCC and not
16741 // here. For now, grab a handle to N so we don't lose it from
16742 // replacements internal to the visit.
16743 HandleSDNode XORHandle(N);
16744 while (N.getOpcode() == ISD::XOR) {
16745 SDValue Tmp = visitXOR(N.getNode());
16746 // No simplification done.
16747 if (!Tmp.getNode())
16748 break;
16749 // Returning N is a form of in-visit replacement that may invalidate
16750 // N. Grab the value from the handle.
16751 if (Tmp.getNode() == N.getNode())
16752 N = XORHandle.getValue();
16753 else // Node simplified. Try simplifying again.
16754 N = Tmp;
16755 }
16756
16757 if (N.getOpcode() != ISD::XOR)
16758 return N;
16759
16760 SDValue Op0 = N->getOperand(0);
16761 SDValue Op1 = N->getOperand(1);
16762
16763 if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) {
16764 bool Equal = false;
16765 // (brcond (xor (xor x, y), -1)) -> (brcond (setcc x, y, eq))
16766 if (isBitwiseNot(N) && Op0.hasOneUse() && Op0.getOpcode() == ISD::XOR &&
16767 Op0.getValueType() == MVT::i1) {
16768 N = Op0;
16769 Op0 = N->getOperand(0);
16770 Op1 = N->getOperand(1);
16771 Equal = true;
16772 }
16773
16774 EVT SetCCVT = N.getValueType();
16775 if (LegalTypes)
16776 SetCCVT = getSetCCResultType(SetCCVT);
16777 // Replace the uses of XOR with SETCC
16778 return DAG.getSetCC(SDLoc(N), SetCCVT, Op0, Op1,
16779 Equal ? ISD::SETEQ : ISD::SETNE);
16780 }
16781 }
16782
16783 return SDValue();
16784 }
16785
16786 // Operand List for BR_CC: Chain, CondCC, CondLHS, CondRHS, DestBB.
16787 //
16788 SDValue DAGCombiner::visitBR_CC(SDNode *N) {
16789 CondCodeSDNode *CC = cast<CondCodeSDNode>(N->getOperand(1));
16790 SDValue CondLHS = N->getOperand(2), CondRHS = N->getOperand(3);
16791
16792 // If N is a constant we could fold this into a fallthrough or unconditional
16793 // branch. However that doesn't happen very often in normal code, because
16794 // Instcombine/SimplifyCFG should have handled the available opportunities.
16795 // If we did this folding here, it would be necessary to update the
16796 // MachineBasicBlock CFG, which is awkward.
16797
16798 // Use SimplifySetCC to simplify SETCC's.
16799 SDValue Simp = SimplifySetCC(getSetCCResultType(CondLHS.getValueType()),
16800 CondLHS, CondRHS, CC->get(), SDLoc(N),
16801 false);
16802 if (Simp.getNode()) AddToWorklist(Simp.getNode());
16803
16804 // fold to a simpler setcc
16805 if (Simp.getNode() && Simp.getOpcode() == ISD::SETCC)
16806 return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
16807 N->getOperand(0), Simp.getOperand(2),
16808 Simp.getOperand(0), Simp.getOperand(1),
16809 N->getOperand(4));
16810
16811 return SDValue();
16812 }
16813
16814 static bool getCombineLoadStoreParts(SDNode *N, unsigned Inc, unsigned Dec,
16815 bool &IsLoad, bool &IsMasked, SDValue &Ptr,
16816 const TargetLowering &TLI) {
16817 if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) {
16818 if (LD->isIndexed())
16819 return false;
16820 EVT VT = LD->getMemoryVT();
16821 if (!TLI.isIndexedLoadLegal(Inc, VT) && !TLI.isIndexedLoadLegal(Dec, VT))
16822 return false;
16823 Ptr = LD->getBasePtr();
16824 } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) {
16825 if (ST->isIndexed())
16826 return false;
16827 EVT VT = ST->getMemoryVT();
16828 if (!TLI.isIndexedStoreLegal(Inc, VT) && !TLI.isIndexedStoreLegal(Dec, VT))
16829 return false;
16830 Ptr = ST->getBasePtr();
16831 IsLoad = false;
16832 } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(N)) {
16833 if (LD->isIndexed())
16834 return false;
16835 EVT VT = LD->getMemoryVT();
16836 if (!TLI.isIndexedMaskedLoadLegal(Inc, VT) &&
16837 !TLI.isIndexedMaskedLoadLegal(Dec, VT))
16838 return false;
16839 Ptr = LD->getBasePtr();
16840 IsMasked = true;
16841 } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(N)) {
16842 if (ST->isIndexed())
16843 return false;
16844 EVT VT = ST->getMemoryVT();
16845 if (!TLI.isIndexedMaskedStoreLegal(Inc, VT) &&
16846 !TLI.isIndexedMaskedStoreLegal(Dec, VT))
16847 return false;
16848 Ptr = ST->getBasePtr();
16849 IsLoad = false;
16850 IsMasked = true;
16851 } else {
16852 return false;
16853 }
16854 return true;
16855 }
16856
16857 /// Try turning a load/store into a pre-indexed load/store when the base
16858 /// pointer is an add or subtract and it has other uses besides the load/store.
16859 /// After the transformation, the new indexed load/store has effectively folded
16860 /// the add/subtract in and all of its other uses are redirected to the
16861 /// new load/store.
16862 bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
16863 if (Level < AfterLegalizeDAG)
16864 return false;
16865
16866 bool IsLoad = true;
16867 bool IsMasked = false;
16868 SDValue Ptr;
16869 if (!getCombineLoadStoreParts(N, ISD::PRE_INC, ISD::PRE_DEC, IsLoad, IsMasked,
16870 Ptr, TLI))
16871 return false;
16872
16873 // If the pointer is not an add/sub, or if it doesn't have multiple uses, bail
16874 // out. There is no reason to make this a preinc/predec.
16875 if ((Ptr.getOpcode() != ISD::ADD && Ptr.getOpcode() != ISD::SUB) ||
16876 Ptr->hasOneUse())
16877 return false;
16878
16879 // Ask the target to do addressing mode selection.
16880 SDValue BasePtr;
16881 SDValue Offset;
16882 ISD::MemIndexedMode AM = ISD::UNINDEXED;
16883 if (!TLI.getPreIndexedAddressParts(N, BasePtr, Offset, AM, DAG))
16884 return false;
16885
16886 // Backends without true r+i pre-indexed forms may need to pass a
16887 // constant base with a variable offset so that constant coercion
16888 // will work with the patterns in canonical form.
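// The swap is undone below once the BasePtr/Offset checks are done, so the
// indexed node is still built with the operands as the target returned them.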
16889 bool Swapped = false;
16890 if (isa<ConstantSDNode>(BasePtr)) {
16891 std::swap(BasePtr, Offset);
16892 Swapped = true;
16893 }
16894
16895 // Don't create an indexed load / store with zero offset.
16896 if (isNullConstant(Offset))
16897 return false;
16898
16899 // Try turning it into a pre-indexed load / store except when:
16900 // 1) The new base ptr is a frame index.
16901 // 2) If N is a store and the new base ptr is either the same as or is a
16902 // predecessor of the value being stored.
16903 // 3) Another use of old base ptr is a predecessor of N. If ptr is folded
16904 // that would create a cycle.
16905 // 4) All uses are load / store ops that use it as old base ptr.
16906
16907 // Check #1. Preinc'ing a frame index would require copying the stack pointer
16908 // (plus the implicit offset) to a register to preinc anyway.
16909 if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
16910 return false;
16911
16912 // Check #2.
16913 if (!IsLoad) {
16914 SDValue Val = IsMasked ? cast<MaskedStoreSDNode>(N)->getValue()
16915 : cast<StoreSDNode>(N)->getValue();
16916
16917 // Would require a copy.
16918 if (Val == BasePtr)
16919 return false;
16920
16921 // Would create a cycle.
16922 if (Val == Ptr || Ptr->isPredecessorOf(Val.getNode()))
16923 return false;
16924 }
16925
16926 // Caches for hasPredecessorHelper.
16927 SmallPtrSet<const SDNode *, 32> Visited;
16928 SmallVector<const SDNode *, 16> Worklist;
16929 Worklist.push_back(N);
16930
16931 // If the offset is a constant, there may be other adds of constants that
16932 // can be folded with this one. We should do this to avoid having to keep
16933 // a copy of the original base pointer.
16934 SmallVector<SDNode *, 16> OtherUses;
16935 if (isa<ConstantSDNode>(Offset))
16936 for (SDNode::use_iterator UI = BasePtr->use_begin(),
16937 UE = BasePtr->use_end();
16938 UI != UE; ++UI) {
16939 SDUse &Use = UI.getUse();
16940 // Skip the use that is Ptr and uses of other results from BasePtr's
16941 // node (important for nodes that return multiple results).
16942 if (Use.getUser() == Ptr.getNode() || Use != BasePtr)
16943 continue;
16944
16945 if (SDNode::hasPredecessorHelper(Use.getUser(), Visited, Worklist))
16946 continue;
16947
16948 if (Use.getUser()->getOpcode() != ISD::ADD &&
16949 Use.getUser()->getOpcode() != ISD::SUB) {
16950 OtherUses.clear();
16951 break;
16952 }
16953
16954 SDValue Op1 = Use.getUser()->getOperand((UI.getOperandNo() + 1) & 1);
16955 if (!isa<ConstantSDNode>(Op1)) {
16956 OtherUses.clear();
16957 break;
16958 }
16959
16960 // FIXME: In some cases, we can be smarter about this.
16961 if (Op1.getValueType() != Offset.getValueType()) {
16962 OtherUses.clear();
16963 break;
16964 }
16965
16966 OtherUses.push_back(Use.getUser());
16967 }
16968
16969 if (Swapped)
16970 std::swap(BasePtr, Offset);
16971
16972 // Now check for #3 and #4.
16973 bool RealUse = false;
16974
16975 for (SDNode *Use : Ptr->uses()) {
16976 if (Use == N)
16977 continue;
16978 if (SDNode::hasPredecessorHelper(Use, Visited, Worklist))
16979 return false;
16980
16981 // If Ptr may be folded in addressing mode of other use, then it's
16982 // not profitable to do this transformation.
16983 if (!canFoldInAddressingMode(Ptr.getNode(), Use, DAG, TLI))
16984 RealUse = true;
16985 }
16986
16987 if (!RealUse)
16988 return false;
16989
16990 SDValue Result;
16991 if (!IsMasked) {
16992 if (IsLoad)
16993 Result = DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
16994 else
16995 Result =
16996 DAG.getIndexedStore(SDValue(N, 0), SDLoc(N), BasePtr, Offset, AM);
16997 } else {
16998 if (IsLoad)
16999 Result = DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
17000 Offset, AM);
17001 else
17002 Result = DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N), BasePtr,
17003 Offset, AM);
17004 }
17005 ++PreIndexedNodes;
17006 ++NodesCombined;
17007 LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: ";
17008 Result.dump(&DAG); dbgs() << '\n');
17009 WorklistRemover DeadNodes(*this);
17010 if (IsLoad) {
17011 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
17012 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
17013 } else {
17014 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
17015 }
17016
17017 // Finally, since the node is now dead, remove it from the graph.
17018 deleteAndRecombine(N);
17019
17020 if (Swapped)
17021 std::swap(BasePtr, Offset);
17022
17023 // Replace other uses of BasePtr that can be updated to use Ptr
17024 for (unsigned i = 0, e = OtherUses.size(); i != e; ++i) {
17025 unsigned OffsetIdx = 1;
17026 if (OtherUses[i]->getOperand(OffsetIdx).getNode() == BasePtr.getNode())
17027 OffsetIdx = 0;
17028 assert(OtherUses[i]->getOperand(!OffsetIdx).getNode() ==
17029 BasePtr.getNode() && "Expected BasePtr operand");
17030
17031 // We need to replace ptr0 in the following expression:
17032 // x0 * offset0 + y0 * ptr0 = t0
17033 // knowing that
17034 // x1 * offset1 + y1 * ptr0 = t1 (the indexed load/store)
17035 //
17036 // where x0, x1, y0 and y1 in {-1, 1} are given by the types of the
17037 // indexed load/store and the expression that needs to be re-written.
17038 //
17039 // Therefore, we have:
17040 // t0 = (x0 * offset0 - x1 * y0 * y1 *offset1) + (y0 * y1) * t1
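    // Worked example (PRE_INC, not swapped, so x1 = y0 = y1 = 1):
    //   t1 = ptr0 + 4     (the indexed load/store, offset1 = 4)
    //   t0 = ptr0 + 16    (the other use, x0 = 1, offset0 = 16)
    // gives Opcode = ADD and CNV = 16 - 4 = 12, so t0 is rebuilt as t1 + 12,
    // which still equals ptr0 + 16 but no longer needs the original ptr0.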
17041
17042 auto *CN = cast<ConstantSDNode>(OtherUses[i]->getOperand(OffsetIdx));
17043 const APInt &Offset0 = CN->getAPIntValue();
17044 const APInt &Offset1 = cast<ConstantSDNode>(Offset)->getAPIntValue();
17045 int X0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 1) ? -1 : 1;
17046 int Y0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 0) ? -1 : 1;
17047 int X1 = (AM == ISD::PRE_DEC && !Swapped) ? -1 : 1;
17048 int Y1 = (AM == ISD::PRE_DEC && Swapped) ? -1 : 1;
17049
17050 unsigned Opcode = (Y0 * Y1 < 0) ? ISD::SUB : ISD::ADD;
17051
17052 APInt CNV = Offset0;
17053 if (X0 < 0) CNV = -CNV;
17054 if (X1 * Y0 * Y1 < 0) CNV = CNV + Offset1;
17055 else CNV = CNV - Offset1;
17056
17057 SDLoc DL(OtherUses[i]);
17058
17059 // We can now generate the new expression.
17060 SDValue NewOp1 = DAG.getConstant(CNV, DL, CN->getValueType(0));
17061 SDValue NewOp2 = Result.getValue(IsLoad ? 1 : 0);
17062
17063 SDValue NewUse = DAG.getNode(Opcode,
17064 DL,
17065 OtherUses[i]->getValueType(0), NewOp1, NewOp2);
17066 DAG.ReplaceAllUsesOfValueWith(SDValue(OtherUses[i], 0), NewUse);
17067 deleteAndRecombine(OtherUses[i]);
17068 }
17069
17070 // Replace the uses of Ptr with uses of the updated base value.
17071 DAG.ReplaceAllUsesOfValueWith(Ptr, Result.getValue(IsLoad ? 1 : 0));
17072 deleteAndRecombine(Ptr.getNode());
17073 AddToWorklist(Result.getNode());
17074
17075 return true;
17076 }
17077
17078 static bool shouldCombineToPostInc(SDNode *N, SDValue Ptr, SDNode *PtrUse,
17079 SDValue &BasePtr, SDValue &Offset,
17080 ISD::MemIndexedMode &AM,
17081 SelectionDAG &DAG,
17082 const TargetLowering &TLI) {
17083 if (PtrUse == N ||
17084 (PtrUse->getOpcode() != ISD::ADD && PtrUse->getOpcode() != ISD::SUB))
17085 return false;
17086
17087 if (!TLI.getPostIndexedAddressParts(N, PtrUse, BasePtr, Offset, AM, DAG))
17088 return false;
17089
17090   // Don't create an indexed load / store with zero offset.
17091 if (isNullConstant(Offset))
17092 return false;
17093
17094 if (isa<FrameIndexSDNode>(BasePtr) || isa<RegisterSDNode>(BasePtr))
17095 return false;
17096
17097 SmallPtrSet<const SDNode *, 32> Visited;
17098 for (SDNode *Use : BasePtr->uses()) {
17099 if (Use == Ptr.getNode())
17100 continue;
17101
17102     // Don't combine if there's a later user which could perform the index instead.
17103 if (isa<MemSDNode>(Use)) {
17104 bool IsLoad = true;
17105 bool IsMasked = false;
17106 SDValue OtherPtr;
17107 if (getCombineLoadStoreParts(Use, ISD::POST_INC, ISD::POST_DEC, IsLoad,
17108 IsMasked, OtherPtr, TLI)) {
17109 SmallVector<const SDNode *, 2> Worklist;
17110 Worklist.push_back(Use);
17111 if (SDNode::hasPredecessorHelper(N, Visited, Worklist))
17112 return false;
17113 }
17114 }
17115
17116 // If all the uses are load / store addresses, then don't do the
17117 // transformation.
17118 if (Use->getOpcode() == ISD::ADD || Use->getOpcode() == ISD::SUB) {
17119 for (SDNode *UseUse : Use->uses())
17120 if (canFoldInAddressingMode(Use, UseUse, DAG, TLI))
17121 return false;
17122 }
17123 }
17124 return true;
17125 }
17126
17127 static SDNode *getPostIndexedLoadStoreOp(SDNode *N, bool &IsLoad,
17128 bool &IsMasked, SDValue &Ptr,
17129 SDValue &BasePtr, SDValue &Offset,
17130 ISD::MemIndexedMode &AM,
17131 SelectionDAG &DAG,
17132 const TargetLowering &TLI) {
17133 if (!getCombineLoadStoreParts(N, ISD::POST_INC, ISD::POST_DEC, IsLoad,
17134 IsMasked, Ptr, TLI) ||
17135 Ptr->hasOneUse())
17136 return nullptr;
17137
17138 // Try turning it into a post-indexed load / store except when
17139 // 1) All uses are load / store ops that use it as base ptr (and
17140   //    it may be folded as the addressing mode).
17141 // 2) Op must be independent of N, i.e. Op is neither a predecessor
17142 // nor a successor of N. Otherwise, if Op is folded that would
17143 // create a cycle.
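  // For example (illustrative, on a target with post-indexed addressing):
  //   x    = load ptr
  //   ptr2 = add ptr, 4
  // can become a post-indexed load that produces both x and the updated
  // pointer ptr2, provided the add is independent of the load.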
17144 for (SDNode *Op : Ptr->uses()) {
17145 // Check for #1.
17146 if (!shouldCombineToPostInc(N, Ptr, Op, BasePtr, Offset, AM, DAG, TLI))
17147 continue;
17148
17149 // Check for #2.
17150 SmallPtrSet<const SDNode *, 32> Visited;
17151 SmallVector<const SDNode *, 8> Worklist;
17152 // Ptr is predecessor to both N and Op.
17153 Visited.insert(Ptr.getNode());
17154 Worklist.push_back(N);
17155 Worklist.push_back(Op);
17156 if (!SDNode::hasPredecessorHelper(N, Visited, Worklist) &&
17157 !SDNode::hasPredecessorHelper(Op, Visited, Worklist))
17158 return Op;
17159 }
17160 return nullptr;
17161 }
17162
17163 /// Try to combine a load/store with an add/sub of the base pointer node into a
17164 /// post-indexed load/store. The transformation effectively folds the add/subtract
17165 /// into the new indexed load/store, and all uses of the add/subtract are
17166 /// redirected to the new load/store.
17167 bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) {
17168 if (Level < AfterLegalizeDAG)
17169 return false;
17170
17171 bool IsLoad = true;
17172 bool IsMasked = false;
17173 SDValue Ptr;
17174 SDValue BasePtr;
17175 SDValue Offset;
17176 ISD::MemIndexedMode AM = ISD::UNINDEXED;
17177 SDNode *Op = getPostIndexedLoadStoreOp(N, IsLoad, IsMasked, Ptr, BasePtr,
17178 Offset, AM, DAG, TLI);
17179 if (!Op)
17180 return false;
17181
17182 SDValue Result;
17183 if (!IsMasked)
17184 Result = IsLoad ? DAG.getIndexedLoad(SDValue(N, 0), SDLoc(N), BasePtr,
17185 Offset, AM)
17186 : DAG.getIndexedStore(SDValue(N, 0), SDLoc(N),
17187 BasePtr, Offset, AM);
17188 else
17189 Result = IsLoad ? DAG.getIndexedMaskedLoad(SDValue(N, 0), SDLoc(N),
17190 BasePtr, Offset, AM)
17191 : DAG.getIndexedMaskedStore(SDValue(N, 0), SDLoc(N),
17192 BasePtr, Offset, AM);
17193 ++PostIndexedNodes;
17194 ++NodesCombined;
17195 LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG); dbgs() << "\nWith: ";
17196 Result.dump(&DAG); dbgs() << '\n');
17197 WorklistRemover DeadNodes(*this);
17198 if (IsLoad) {
17199 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(0));
17200 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Result.getValue(2));
17201 } else {
17202 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Result.getValue(1));
17203 }
17204
17205 // Finally, since the node is now dead, remove it from the graph.
17206 deleteAndRecombine(N);
17207
17208   // Replace the uses of Op with uses of the updated base value.
17209 DAG.ReplaceAllUsesOfValueWith(SDValue(Op, 0),
17210 Result.getValue(IsLoad ? 1 : 0));
17211 deleteAndRecombine(Op);
17212 return true;
17213 }
17214
17215 /// Return the base-pointer arithmetic from an indexed \p LD.
17216 SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) {
17217 ISD::MemIndexedMode AM = LD->getAddressingMode();
17218 assert(AM != ISD::UNINDEXED);
17219 SDValue BP = LD->getOperand(1);
17220 SDValue Inc = LD->getOperand(2);
17221
17222 // Some backends use TargetConstants for load offsets, but don't expect
17223 // TargetConstants in general ADD nodes. We can convert these constants into
17224 // regular Constants (if the constant is not opaque).
17225 assert((Inc.getOpcode() != ISD::TargetConstant ||
17226 !cast<ConstantSDNode>(Inc)->isOpaque()) &&
17227 "Cannot split out indexing using opaque target constants");
17228 if (Inc.getOpcode() == ISD::TargetConstant) {
17229 ConstantSDNode *ConstInc = cast<ConstantSDNode>(Inc);
17230 Inc = DAG.getConstant(*ConstInc->getConstantIntValue(), SDLoc(Inc),
17231 ConstInc->getValueType(0));
17232 }
17233
17234 unsigned Opc =
17235 (AM == ISD::PRE_INC || AM == ISD::POST_INC ? ISD::ADD : ISD::SUB);
17236 return DAG.getNode(Opc, SDLoc(LD), BP.getSimpleValueType(), BP, Inc);
17237 }
17238
17239 static inline ElementCount numVectorEltsOrZero(EVT T) {
17240 return T.isVector() ? T.getVectorElementCount() : ElementCount::getFixed(0);
17241 }
17242
17243 bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) {
17244 EVT STType = Val.getValueType();
17245 EVT STMemType = ST->getMemoryVT();
17246 if (STType == STMemType)
17247 return true;
17248 if (isTypeLegal(STMemType))
17249 return false; // fail.
17250 if (STType.isFloatingPoint() && STMemType.isFloatingPoint() &&
17251 TLI.isOperationLegal(ISD::FTRUNC, STMemType)) {
17252 Val = DAG.getNode(ISD::FTRUNC, SDLoc(ST), STMemType, Val);
17253 return true;
17254 }
17255 if (numVectorEltsOrZero(STType) == numVectorEltsOrZero(STMemType) &&
17256 STType.isInteger() && STMemType.isInteger()) {
17257 Val = DAG.getNode(ISD::TRUNCATE, SDLoc(ST), STMemType, Val);
17258 return true;
17259 }
17260 if (STType.getSizeInBits() == STMemType.getSizeInBits()) {
17261 Val = DAG.getBitcast(STMemType, Val);
17262 return true;
17263 }
17264 return false; // fail.
17265 }
17266
17267 bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) {
17268 EVT LDMemType = LD->getMemoryVT();
17269 EVT LDType = LD->getValueType(0);
17270 assert(Val.getValueType() == LDMemType &&
17271 "Attempting to extend value of non-matching type");
17272 if (LDType == LDMemType)
17273 return true;
17274 if (LDMemType.isInteger() && LDType.isInteger()) {
17275 switch (LD->getExtensionType()) {
17276 case ISD::NON_EXTLOAD:
17277 Val = DAG.getBitcast(LDType, Val);
17278 return true;
17279 case ISD::EXTLOAD:
17280 Val = DAG.getNode(ISD::ANY_EXTEND, SDLoc(LD), LDType, Val);
17281 return true;
17282 case ISD::SEXTLOAD:
17283 Val = DAG.getNode(ISD::SIGN_EXTEND, SDLoc(LD), LDType, Val);
17284 return true;
17285 case ISD::ZEXTLOAD:
17286 Val = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(LD), LDType, Val);
17287 return true;
17288 }
17289 }
17290 return false;
17291 }
17292
17293 SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
17294 if (OptLevel == CodeGenOpt::None || !LD->isSimple())
17295 return SDValue();
17296 SDValue Chain = LD->getOperand(0);
17297 StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain.getNode());
17298 // TODO: Relax this restriction for unordered atomics (see D66309)
17299 if (!ST || !ST->isSimple() || ST->getAddressSpace() != LD->getAddressSpace())
17300 return SDValue();
17301
17302 EVT LDType = LD->getValueType(0);
17303 EVT LDMemType = LD->getMemoryVT();
17304 EVT STMemType = ST->getMemoryVT();
17305 EVT STType = ST->getValue().getValueType();
17306
17307 // There are two cases to consider here:
17308 // 1. The store is fixed width and the load is scalable. In this case we
17309 // don't know at compile time if the store completely envelops the load
17310 // so we abandon the optimisation.
17311 // 2. The store is scalable and the load is fixed width. We could
17312 // potentially support a limited number of cases here, but there has been
17313 // no cost-benefit analysis to prove it's worth it.
17314 bool LdStScalable = LDMemType.isScalableVector();
17315 if (LdStScalable != STMemType.isScalableVector())
17316 return SDValue();
17317
17318 // If we are dealing with scalable vectors on a big endian platform the
17319 // calculation of offsets below becomes trickier, since we do not know at
17320 // compile time the absolute size of the vector. Until we've done more
17321 // analysis on big-endian platforms it seems better to bail out for now.
17322 if (LdStScalable && DAG.getDataLayout().isBigEndian())
17323 return SDValue();
17324
17325 BaseIndexOffset BasePtrLD = BaseIndexOffset::match(LD, DAG);
17326 BaseIndexOffset BasePtrST = BaseIndexOffset::match(ST, DAG);
17327 int64_t Offset;
17328 if (!BasePtrST.equalBaseIndex(BasePtrLD, DAG, Offset))
17329 return SDValue();
17330
17331   // Normalize for endianness. After this, Offset=0 will denote that the least
17332   // significant bit in the loaded value maps to the least significant bit in
17333   // the stored value. With Offset=n (for n > 0) the loaded value starts at the
17334   // n-th least significant byte of the stored value.
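  // E.g., an i32 store feeding an i8 load at the same address on a big-endian
  // target has a raw offset of 0, but the low byte of the load lives in the
  // last byte of the store, so Offset becomes (32 - 8) / 8 = 3.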
17335 int64_t OrigOffset = Offset;
17336 if (DAG.getDataLayout().isBigEndian())
17337 Offset = ((int64_t)STMemType.getStoreSizeInBits().getFixedValue() -
17338 (int64_t)LDMemType.getStoreSizeInBits().getFixedValue()) /
17339 8 -
17340 Offset;
17341
17342   // Check that the stored value covers all bits that are loaded.
17343 bool STCoversLD;
17344
17345 TypeSize LdMemSize = LDMemType.getSizeInBits();
17346 TypeSize StMemSize = STMemType.getSizeInBits();
17347 if (LdStScalable)
17348 STCoversLD = (Offset == 0) && LdMemSize == StMemSize;
17349 else
17350 STCoversLD = (Offset >= 0) && (Offset * 8 + LdMemSize.getFixedValue() <=
17351 StMemSize.getFixedValue());
17352
17353 auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue {
17354 if (LD->isIndexed()) {
17355 // Cannot handle opaque target constants and we must respect the user's
17356 // request not to split indexes from loads.
17357 if (!canSplitIdx(LD))
17358 return SDValue();
17359 SDValue Idx = SplitIndexingFromLoad(LD);
17360 SDValue Ops[] = {Val, Idx, Chain};
17361 return CombineTo(LD, Ops, 3);
17362 }
17363 return CombineTo(LD, Val, Chain);
17364 };
17365
17366 if (!STCoversLD)
17367 return SDValue();
17368
17369 // Memory as copy space (potentially masked).
17370 if (Offset == 0 && LDType == STType && STMemType == LDMemType) {
17371 // Simple case: Direct non-truncating forwarding
17372 if (LDType.getSizeInBits() == LdMemSize)
17373 return ReplaceLd(LD, ST->getValue(), Chain);
17374 // Can we model the truncate and extension with an and mask?
17375 if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() &&
17376 !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) {
17377 // Mask to size of LDMemType
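      // E.g., an i8 truncating store of an i32 value forwarded to an i8->i32
      // zextload of the same address becomes an AND of the stored i32 value
      // with 0xFF.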
17378 auto Mask =
17379 DAG.getConstant(APInt::getLowBitsSet(STType.getFixedSizeInBits(),
17380 StMemSize.getFixedValue()),
17381 SDLoc(ST), STType);
17382 auto Val = DAG.getNode(ISD::AND, SDLoc(LD), LDType, ST->getValue(), Mask);
17383 return ReplaceLd(LD, Val, Chain);
17384 }
17385 }
17386
17387   // Handle some big-endian cases that would have Offset 0 (and thus be handled
17388   // above) on a little-endian target.
17389 SDValue Val = ST->getValue();
17390 if (DAG.getDataLayout().isBigEndian() && Offset > 0 && OrigOffset == 0) {
17391 if (STType.isInteger() && !STType.isVector() && LDType.isInteger() &&
17392 !LDType.isVector() && isTypeLegal(STType) &&
17393 TLI.isOperationLegal(ISD::SRL, STType)) {
17394 Val = DAG.getNode(ISD::SRL, SDLoc(LD), STType, Val,
17395 DAG.getConstant(Offset * 8, SDLoc(LD), STType));
17396 Offset = 0;
17397 }
17398 }
17399
17400 // TODO: Deal with nonzero offset.
17401 if (LD->getBasePtr().isUndef() || Offset != 0)
17402 return SDValue();
17403   // Model necessary truncations / extensions.
17404   // Truncate the value to the stored memory size.
17405 do {
17406 if (!getTruncatedStoreValue(ST, Val))
17407 continue;
17408 if (!isTypeLegal(LDMemType))
17409 continue;
17410 if (STMemType != LDMemType) {
17411 // TODO: Support vectors? This requires extract_subvector/bitcast.
17412 if (!STMemType.isVector() && !LDMemType.isVector() &&
17413 STMemType.isInteger() && LDMemType.isInteger())
17414 Val = DAG.getNode(ISD::TRUNCATE, SDLoc(LD), LDMemType, Val);
17415 else
17416 continue;
17417 }
17418 if (!extendLoadedValueToExtension(LD, Val))
17419 continue;
17420 return ReplaceLd(LD, Val, Chain);
17421 } while (false);
17422
17423 // On failure, cleanup dead nodes we may have created.
17424 if (Val->use_empty())
17425 deleteAndRecombine(Val.getNode());
17426 return SDValue();
17427 }
17428
17429 SDValue DAGCombiner::visitLOAD(SDNode *N) {
17430 LoadSDNode *LD = cast<LoadSDNode>(N);
17431 SDValue Chain = LD->getChain();
17432 SDValue Ptr = LD->getBasePtr();
17433
17434 // If load is not volatile and there are no uses of the loaded value (and
17435 // the updated indexed value in case of indexed loads), change uses of the
17436 // chain value into uses of the chain input (i.e. delete the dead load).
17437 // TODO: Allow this for unordered atomics (see D66309)
17438 if (LD->isSimple()) {
17439 if (N->getValueType(1) == MVT::Other) {
17440 // Unindexed loads.
17441 if (!N->hasAnyUseOfValue(0)) {
17442 // It's not safe to use the two value CombineTo variant here. e.g.
17443 // v1, chain2 = load chain1, loc
17444 // v2, chain3 = load chain2, loc
17445 // v3 = add v2, c
17446 // Now we replace use of chain2 with chain1. This makes the second load
17447 // isomorphic to the one we are deleting, and thus makes this load live.
17448 LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG);
17449 dbgs() << "\nWith chain: "; Chain.dump(&DAG);
17450 dbgs() << "\n");
17451 WorklistRemover DeadNodes(*this);
17452 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
17453 AddUsersToWorklist(Chain.getNode());
17454 if (N->use_empty())
17455 deleteAndRecombine(N);
17456
17457 return SDValue(N, 0); // Return N so it doesn't get rechecked!
17458 }
17459 } else {
17460 // Indexed loads.
17461 assert(N->getValueType(2) == MVT::Other && "Malformed indexed loads?");
17462
17463 // If this load has an opaque TargetConstant offset, then we cannot split
17464 // the indexing into an add/sub directly (that TargetConstant may not be
17465 // valid for a different type of node, and we cannot convert an opaque
17466 // target constant into a regular constant).
17467 bool CanSplitIdx = canSplitIdx(LD);
17468
17469 if (!N->hasAnyUseOfValue(0) && (CanSplitIdx || !N->hasAnyUseOfValue(1))) {
17470 SDValue Undef = DAG.getUNDEF(N->getValueType(0));
17471 SDValue Index;
17472 if (N->hasAnyUseOfValue(1) && CanSplitIdx) {
17473 Index = SplitIndexingFromLoad(LD);
17474 // Try to fold the base pointer arithmetic into subsequent loads and
17475 // stores.
17476 AddUsersToWorklist(N);
17477 } else
17478 Index = DAG.getUNDEF(N->getValueType(1));
17479 LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG);
17480 dbgs() << "\nWith: "; Undef.dump(&DAG);
17481 dbgs() << " and 2 other values\n");
17482 WorklistRemover DeadNodes(*this);
17483 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Undef);
17484 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Index);
17485 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 2), Chain);
17486 deleteAndRecombine(N);
17487 return SDValue(N, 0); // Return N so it doesn't get rechecked!
17488 }
17489 }
17490 }
17491
17492 // If this load is directly stored, replace the load value with the stored
17493 // value.
17494 if (auto V = ForwardStoreValueToDirectLoad(LD))
17495 return V;
17496
17497 // Try to infer better alignment information than the load already has.
17498 if (OptLevel != CodeGenOpt::None && LD->isUnindexed() && !LD->isAtomic()) {
17499 if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
17500 if (*Alignment > LD->getAlign() &&
17501 isAligned(*Alignment, LD->getSrcValueOffset())) {
17502 SDValue NewLoad = DAG.getExtLoad(
17503 LD->getExtensionType(), SDLoc(N), LD->getValueType(0), Chain, Ptr,
17504 LD->getPointerInfo(), LD->getMemoryVT(), *Alignment,
17505 LD->getMemOperand()->getFlags(), LD->getAAInfo());
17506 // NewLoad will always be N as we are only refining the alignment
17507 assert(NewLoad.getNode() == N);
17508 (void)NewLoad;
17509 }
17510 }
17511 }
17512
17513 if (LD->isUnindexed()) {
17514 // Walk up chain skipping non-aliasing memory nodes.
17515 SDValue BetterChain = FindBetterChain(LD, Chain);
17516
17517 // If there is a better chain.
17518 if (Chain != BetterChain) {
17519 SDValue ReplLoad;
17520
17521       // Replace the chain to avoid the dependency.
17522 if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
17523 ReplLoad = DAG.getLoad(N->getValueType(0), SDLoc(LD),
17524 BetterChain, Ptr, LD->getMemOperand());
17525 } else {
17526 ReplLoad = DAG.getExtLoad(LD->getExtensionType(), SDLoc(LD),
17527 LD->getValueType(0),
17528 BetterChain, Ptr, LD->getMemoryVT(),
17529 LD->getMemOperand());
17530 }
17531
17532 // Create token factor to keep old chain connected.
17533 SDValue Token = DAG.getNode(ISD::TokenFactor, SDLoc(N),
17534 MVT::Other, Chain, ReplLoad.getValue(1));
17535
17536 // Replace uses with load result and token factor
17537 return CombineTo(N, ReplLoad.getValue(0), Token);
17538 }
17539 }
17540
17541 // Try transforming N to an indexed load.
17542 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
17543 return SDValue(N, 0);
17544
17545 // Try to slice up N to more direct loads if the slices are mapped to
17546 // different register banks or pairing can take place.
17547 if (SliceUpLoad(N))
17548 return SDValue(N, 0);
17549
17550 return SDValue();
17551 }
17552
17553 namespace {
17554
17555 /// Helper structure used to slice a load in smaller loads.
17556 /// Basically a slice is obtained from the following sequence:
17557 /// Origin = load Ty1, Base
17558 /// Shift = srl Ty1 Origin, CstTy Amount
17559 /// Inst = trunc Shift to Ty2
17560 ///
17561 /// Then, it will be rewritten into:
17562 /// Slice = load SliceTy, Base + SliceOffset
17563 /// [Inst = zext Slice to Ty2], only if SliceTy <> Ty2
17564 ///
17565 /// SliceTy is deduced from the number of bits that are actually used to
17566 /// build Inst.
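/// For example, on a little-endian target:
///   Origin = load i32, Base
///   Inst   = trunc (srl Origin, 16) to i16
/// becomes a single i16 load from Base + 2, with no shift or truncate left.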
17567 struct LoadedSlice {
17568 /// Helper structure used to compute the cost of a slice.
17569 struct Cost {
17570 /// Are we optimizing for code size.
17571 bool ForCodeSize = false;
17572
17573     /// Various cost counters.
17574 unsigned Loads = 0;
17575 unsigned Truncates = 0;
17576 unsigned CrossRegisterBanksCopies = 0;
17577 unsigned ZExts = 0;
17578 unsigned Shift = 0;
17579
17580     explicit Cost(bool ForCodeSize) : ForCodeSize(ForCodeSize) {}
17581
17582 /// Get the cost of one isolated slice.
17583     Cost(const LoadedSlice &LS, bool ForCodeSize)
17584 : ForCodeSize(ForCodeSize), Loads(1) {
17585 EVT TruncType = LS.Inst->getValueType(0);
17586 EVT LoadedType = LS.getLoadedType();
17587 if (TruncType != LoadedType &&
17588 !LS.DAG->getTargetLoweringInfo().isZExtFree(LoadedType, TruncType))
17589 ZExts = 1;
17590 }
17591
17592 /// Account for slicing gain in the current cost.
17593     /// Slicing provides a few gains, like removing a shift or a
17594     /// truncate. This method allows growing the cost of the original
17595     /// load with the gain from this slice.
17596     void addSliceGain(const LoadedSlice &LS) {
17597 // Each slice saves a truncate.
17598 const TargetLowering &TLI = LS.DAG->getTargetLoweringInfo();
17599 if (!TLI.isTruncateFree(LS.Inst->getOperand(0).getValueType(),
17600 LS.Inst->getValueType(0)))
17601 ++Truncates;
17602 // If there is a shift amount, this slice gets rid of it.
17603 if (LS.Shift)
17604 ++Shift;
17605 // If this slice can merge a cross register bank copy, account for it.
17606 if (LS.canMergeExpensiveCrossRegisterBankCopy())
17607 ++CrossRegisterBanksCopies;
17608 }
17609
17610     Cost &operator+=(const Cost &RHS) {
17611 Loads += RHS.Loads;
17612 Truncates += RHS.Truncates;
17613 CrossRegisterBanksCopies += RHS.CrossRegisterBanksCopies;
17614 ZExts += RHS.ZExts;
17615 Shift += RHS.Shift;
17616 return *this;
17617 }
17618
17619     bool operator==(const Cost &RHS) const {
17620 return Loads == RHS.Loads && Truncates == RHS.Truncates &&
17621 CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
17622 ZExts == RHS.ZExts && Shift == RHS.Shift;
17623 }
17624
17625     bool operator!=(const Cost &RHS) const { return !(*this == RHS); }
17626
17627     bool operator<(const Cost &RHS) const {
17628 // Assume cross register banks copies are as expensive as loads.
17629 // FIXME: Do we want some more target hooks?
17630 unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
17631 unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
17632 // Unless we are optimizing for code size, consider the
17633 // expensive operation first.
17634 if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
17635 return ExpensiveOpsLHS < ExpensiveOpsRHS;
17636 return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
17637 (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
17638 }
17639
17640     bool operator>(const Cost &RHS) const { return RHS < *this; }
17641
17642     bool operator<=(const Cost &RHS) const { return !(RHS < *this); }
17643
17644     bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
17645 };
17646
17647   // The last instruction that represents the slice. This should be a
17648 // truncate instruction.
17649 SDNode *Inst;
17650
17651 // The original load instruction.
17652 LoadSDNode *Origin;
17653
17654 // The right shift amount in bits from the original load.
17655 unsigned Shift;
17656
17657   // The DAG from which Origin comes.
17658 // This is used to get some contextual information about legal types, etc.
17659 SelectionDAG *DAG;
17660
17661   LoadedSlice(SDNode *Inst = nullptr, LoadSDNode *Origin = nullptr,
17662 unsigned Shift = 0, SelectionDAG *DAG = nullptr)
17663 : Inst(Inst), Origin(Origin), Shift(Shift), DAG(DAG) {}
17664
17665   /// Get the bits used in a chunk of bits \p BitWidth large.
17666   /// \return Result is \p BitWidth bits wide, with used bits set to 1 and
17667   /// unused bits set to 0.
17668   APInt getUsedBits() const {
17669 // Reproduce the trunc(lshr) sequence:
17670 // - Start from the truncated value.
17671 // - Zero extend to the desired bit width.
17672 // - Shift left.
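    // E.g., an i16 slice of an i32 load with Shift == 16 yields
    // UsedBits == 0xFFFF0000.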
17673 assert(Origin && "No original load to compare against.");
17674 unsigned BitWidth = Origin->getValueSizeInBits(0);
17675 assert(Inst && "This slice is not bound to an instruction");
17676 assert(Inst->getValueSizeInBits(0) <= BitWidth &&
17677 "Extracted slice is bigger than the whole type!");
17678 APInt UsedBits(Inst->getValueSizeInBits(0), 0);
17679 UsedBits.setAllBits();
17680 UsedBits = UsedBits.zext(BitWidth);
17681 UsedBits <<= Shift;
17682 return UsedBits;
17683 }
17684
17685 /// Get the size of the slice to be loaded in bytes.
17686   unsigned getLoadedSize() const {
17687 unsigned SliceSize = getUsedBits().countPopulation();
17688 assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte.");
17689 return SliceSize / 8;
17690 }
17691
17692 /// Get the type that will be loaded for this slice.
17693 /// Note: This may not be the final type for the slice.
17694   EVT getLoadedType() const {
17695 assert(DAG && "Missing context");
17696 LLVMContext &Ctxt = *DAG->getContext();
17697 return EVT::getIntegerVT(Ctxt, getLoadedSize() * 8);
17698 }
17699
17700 /// Get the alignment of the load used for this slice.
17701   Align getAlign() const {
17702 Align Alignment = Origin->getAlign();
17703 uint64_t Offset = getOffsetFromBase();
17704 if (Offset != 0)
17705 Alignment = commonAlignment(Alignment, Alignment.value() + Offset);
17706 return Alignment;
17707 }
17708
17709 /// Check if this slice can be rewritten with legal operations.
17710   bool isLegal() const {
17711 // An invalid slice is not legal.
17712 if (!Origin || !Inst || !DAG)
17713 return false;
17714
17715     // Offsets are for indexed loads only; we do not handle that.
17716 if (!Origin->getOffset().isUndef())
17717 return false;
17718
17719 const TargetLowering &TLI = DAG->getTargetLoweringInfo();
17720
17721 // Check that the type is legal.
17722 EVT SliceType = getLoadedType();
17723 if (!TLI.isTypeLegal(SliceType))
17724 return false;
17725
17726 // Check that the load is legal for this type.
17727 if (!TLI.isOperationLegal(ISD::LOAD, SliceType))
17728 return false;
17729
17730 // Check that the offset can be computed.
17731 // 1. Check its type.
17732 EVT PtrType = Origin->getBasePtr().getValueType();
17733 if (PtrType == MVT::Untyped || PtrType.isExtended())
17734 return false;
17735
17736 // 2. Check that it fits in the immediate.
17737 if (!TLI.isLegalAddImmediate(getOffsetFromBase()))
17738 return false;
17739
17740 // 3. Check that the computation is legal.
17741 if (!TLI.isOperationLegal(ISD::ADD, PtrType))
17742 return false;
17743
17744 // Check that the zext is legal if it needs one.
17745 EVT TruncateType = Inst->getValueType(0);
17746 if (TruncateType != SliceType &&
17747 !TLI.isOperationLegal(ISD::ZERO_EXTEND, TruncateType))
17748 return false;
17749
17750 return true;
17751 }
17752
17753 /// Get the offset in bytes of this slice in the original chunk of
17754 /// bits.
17755 /// \pre DAG != nullptr.
17756   uint64_t getOffsetFromBase() const {
17757 assert(DAG && "Missing context.");
17758 bool IsBigEndian = DAG->getDataLayout().isBigEndian();
17759 assert(!(Shift & 0x7) && "Shifts not aligned on Bytes are not supported.");
17760 uint64_t Offset = Shift / 8;
17761 unsigned TySizeInBytes = Origin->getValueSizeInBits(0) / 8;
17762 assert(!(Origin->getValueSizeInBits(0) & 0x7) &&
17763 "The size of the original loaded type is not a multiple of a"
17764 " byte.");
17765 // If Offset is bigger than TySizeInBytes, it means we are loading all
17766 // zeros. This should have been optimized before in the process.
17767 assert(TySizeInBytes > Offset &&
17768 "Invalid shift amount for given loaded size");
17769 if (IsBigEndian)
17770 Offset = TySizeInBytes - Offset - getLoadedSize();
17771 return Offset;
17772 }
17773
17774 /// Generate the sequence of instructions to load the slice
17775 /// represented by this object and redirect the uses of this slice to
17776 /// this new sequence of instructions.
17777 /// \pre this->Inst && this->Origin are valid Instructions and this
17778 /// object passed the legal check: LoadedSlice::isLegal returned true.
17779 /// \return The last instruction of the sequence used to load the slice.
17780   SDValue loadSlice() const {
17781 assert(Inst && Origin && "Unable to replace a non-existing slice.");
17782 const SDValue &OldBaseAddr = Origin->getBasePtr();
17783 SDValue BaseAddr = OldBaseAddr;
17784 // Get the offset in that chunk of bytes w.r.t. the endianness.
17785 int64_t Offset = static_cast<int64_t>(getOffsetFromBase());
17786 assert(Offset >= 0 && "Offset too big to fit in int64_t!");
17787 if (Offset) {
17788 // BaseAddr = BaseAddr + Offset.
17789 EVT ArithType = BaseAddr.getValueType();
17790 SDLoc DL(Origin);
17791 BaseAddr = DAG->getNode(ISD::ADD, DL, ArithType, BaseAddr,
17792 DAG->getConstant(Offset, DL, ArithType));
17793 }
17794
17795 // Create the type of the loaded slice according to its size.
17796 EVT SliceType = getLoadedType();
17797
17798 // Create the load for the slice.
17799 SDValue LastInst =
17800 DAG->getLoad(SliceType, SDLoc(Origin), Origin->getChain(), BaseAddr,
17801 Origin->getPointerInfo().getWithOffset(Offset), getAlign(),
17802 Origin->getMemOperand()->getFlags());
17803 // If the final type is not the same as the loaded type, this means that
17804 // we have to pad with zero. Create a zero extend for that.
17805 EVT FinalType = Inst->getValueType(0);
17806 if (SliceType != FinalType)
17807 LastInst =
17808 DAG->getNode(ISD::ZERO_EXTEND, SDLoc(LastInst), FinalType, LastInst);
17809 return LastInst;
17810 }
17811
17812 /// Check if this slice can be merged with an expensive cross register
17813 /// bank copy. E.g.,
17814 /// i = load i32
17815 /// f = bitcast i32 i to float
17816   bool canMergeExpensiveCrossRegisterBankCopy() const {
17817 if (!Inst || !Inst->hasOneUse())
17818 return false;
17819 SDNode *Use = *Inst->use_begin();
17820 if (Use->getOpcode() != ISD::BITCAST)
17821 return false;
17822 assert(DAG && "Missing context");
17823 const TargetLowering &TLI = DAG->getTargetLoweringInfo();
17824 EVT ResVT = Use->getValueType(0);
17825 const TargetRegisterClass *ResRC =
17826 TLI.getRegClassFor(ResVT.getSimpleVT(), Use->isDivergent());
17827 const TargetRegisterClass *ArgRC =
17828 TLI.getRegClassFor(Use->getOperand(0).getValueType().getSimpleVT(),
17829 Use->getOperand(0)->isDivergent());
17830 if (ArgRC == ResRC || !TLI.isOperationLegal(ISD::LOAD, ResVT))
17831 return false;
17832
17833 // At this point, we know that we perform a cross-register-bank copy.
17834 // Check if it is expensive.
17835 const TargetRegisterInfo *TRI = DAG->getSubtarget().getRegisterInfo();
17836 // Assume bitcasts are cheap, unless both register classes do not
17837 // explicitly share a common sub class.
17838 if (!TRI || TRI->getCommonSubClass(ArgRC, ResRC))
17839 return false;
17840
17841 // Check if it will be merged with the load.
17842 // 1. Check the alignment / fast memory access constraint.
17843 unsigned IsFast = 0;
17844 if (!TLI.allowsMemoryAccess(*DAG->getContext(), DAG->getDataLayout(), ResVT,
17845 Origin->getAddressSpace(), getAlign(),
17846 Origin->getMemOperand()->getFlags(), &IsFast) ||
17847 !IsFast)
17848 return false;
17849
17850 // 2. Check that the load is a legal operation for that type.
17851 if (!TLI.isOperationLegal(ISD::LOAD, ResVT))
17852 return false;
17853
17854 // 3. Check that we do not have a zext in the way.
17855 if (Inst->getValueType(0) != getLoadedType())
17856 return false;
17857
17858 return true;
17859 }
17860 };
17861
17862 } // end anonymous namespace
17863
17864 /// Check that all bits set in \p UsedBits form a dense region, i.e.,
17865 /// \p UsedBits looks like 0..0 1..1 0..0.
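/// E.g., 0x00FFFF00 is dense, while 0x00FF00FF is not.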
17866 static bool areUsedBitsDense(const APInt &UsedBits) {
17867 // If all the bits are one, this is dense!
17868 if (UsedBits.isAllOnes())
17869 return true;
17870
17871 // Get rid of the unused bits on the right.
17872 APInt NarrowedUsedBits = UsedBits.lshr(UsedBits.countTrailingZeros());
17873 // Get rid of the unused bits on the left.
17874 if (NarrowedUsedBits.countLeadingZeros())
17875 NarrowedUsedBits = NarrowedUsedBits.trunc(NarrowedUsedBits.getActiveBits());
17876 // Check that the chunk of bits is completely used.
17877 return NarrowedUsedBits.isAllOnes();
17878 }
17879
17880 /// Check whether or not \p First and \p Second are next to each other
17881 /// in memory. This means that there is no hole between the bits loaded
17882 /// by \p First and the bits loaded by \p Second.
17883 static bool areSlicesNextToEachOther(const LoadedSlice &First,
17884 const LoadedSlice &Second) {
17885 assert(First.Origin == Second.Origin && First.Origin &&
17886 "Unable to match different memory origins.");
17887 APInt UsedBits = First.getUsedBits();
17888 assert((UsedBits & Second.getUsedBits()) == 0 &&
17889 "Slices are not supposed to overlap.");
17890 UsedBits |= Second.getUsedBits();
17891 return areUsedBitsDense(UsedBits);
17892 }
17893
17894 /// Adjust the \p GlobalLSCost according to the target
17895 /// pairing capabilities and the layout of the slices.
17896 /// \pre \p GlobalLSCost should account for at least as many loads as
17897 /// there are slices in \p LoadedSlices.
17898 static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
17899 LoadedSlice::Cost &GlobalLSCost) {
17900 unsigned NumberOfSlices = LoadedSlices.size();
17901   // If there are fewer than 2 elements, no pairing is possible.
17902 if (NumberOfSlices < 2)
17903 return;
17904
17905 // Sort the slices so that elements that are likely to be next to each
17906 // other in memory are next to each other in the list.
17907 llvm::sort(LoadedSlices, [](const LoadedSlice &LHS, const LoadedSlice &RHS) {
17908 assert(LHS.Origin == RHS.Origin && "Different bases not implemented.");
17909 return LHS.getOffsetFromBase() < RHS.getOffsetFromBase();
17910 });
17911 const TargetLowering &TLI = LoadedSlices[0].DAG->getTargetLoweringInfo();
17912   // First (resp. Second) is the first (resp. second) potential candidate
17913   // to be placed in a paired load.
17914 const LoadedSlice *First = nullptr;
17915 const LoadedSlice *Second = nullptr;
17916 for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice,
17917 // Set the beginning of the pair.
17918 First = Second) {
17919 Second = &LoadedSlices[CurrSlice];
17920
17921 // If First is NULL, it means we start a new pair.
17922 // Get to the next slice.
17923 if (!First)
17924 continue;
17925
17926 EVT LoadedType = First->getLoadedType();
17927
17928 // If the types of the slices are different, we cannot pair them.
17929 if (LoadedType != Second->getLoadedType())
17930 continue;
17931
17932 // Check if the target supplies paired loads for this type.
17933 Align RequiredAlignment;
17934 if (!TLI.hasPairedLoad(LoadedType, RequiredAlignment)) {
17935 // move to the next pair, this type is hopeless.
17936 Second = nullptr;
17937 continue;
17938 }
17939 // Check if we meet the alignment requirement.
17940 if (First->getAlign() < RequiredAlignment)
17941 continue;
17942
17943 // Check that both loads are next to each other in memory.
17944 if (!areSlicesNextToEachOther(*First, *Second))
17945 continue;
17946
17947 assert(GlobalLSCost.Loads > 0 && "We save more loads than we created!");
17948 --GlobalLSCost.Loads;
17949 // Move to the next pair.
17950 Second = nullptr;
17951 }
17952 }
17953
17954 /// Check the profitability of all involved LoadedSlice.
17955 /// Currently, it is considered profitable if there are exactly two
17956 /// involved slices (1) which are (2) next to each other in memory, and
17957 /// whose cost (\see LoadedSlice::Cost) is smaller than the original load (3).
17958 ///
17959 /// Note: The order of the elements in \p LoadedSlices may be modified, but not
17960 /// the elements themselves.
17961 ///
17962 /// FIXME: When the cost model is mature enough, we can relax
17963 /// constraints (1) and (2).
17964 static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
17965 const APInt &UsedBits, bool ForCodeSize) {
17966 unsigned NumberOfSlices = LoadedSlices.size();
17967 if (StressLoadSlicing)
17968 return NumberOfSlices > 1;
17969
17970 // Check (1).
17971 if (NumberOfSlices != 2)
17972 return false;
17973
17974 // Check (2).
17975 if (!areUsedBitsDense(UsedBits))
17976 return false;
17977
17978 // Check (3).
17979 LoadedSlice::Cost OrigCost(ForCodeSize), GlobalSlicingCost(ForCodeSize);
17980 // The original code has one big load.
17981 OrigCost.Loads = 1;
17982 for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice) {
17983 const LoadedSlice &LS = LoadedSlices[CurrSlice];
17984 // Accumulate the cost of all the slices.
17985 LoadedSlice::Cost SliceCost(LS, ForCodeSize);
17986 GlobalSlicingCost += SliceCost;
17987
17988 // Account as cost in the original configuration the gain obtained
17989 // with the current slices.
17990 OrigCost.addSliceGain(LS);
17991 }
17992
17993 // If the target supports paired load, adjust the cost accordingly.
17994 adjustCostForPairing(LoadedSlices, GlobalSlicingCost);
17995 return OrigCost > GlobalSlicingCost;
17996 }
17997
17998 /// If the given load, \p N, is used only by trunc or trunc(lshr)
17999 /// operations, split it into the various pieces being extracted.
18000 ///
18001 /// This sort of thing is introduced by SROA.
18002 /// This slicing takes care not to insert overlapping loads.
18003 /// \pre \p N is a simple load (i.e., not an atomic or volatile load).
18004 bool DAGCombiner::SliceUpLoad(SDNode *N) {
18005 if (Level < AfterLegalizeDAG)
18006 return false;
18007
18008 LoadSDNode *LD = cast<LoadSDNode>(N);
18009 if (!LD->isSimple() || !ISD::isNormalLoad(LD) ||
18010 !LD->getValueType(0).isInteger())
18011 return false;
18012
18013 // The algorithm to split up a load of a scalable vector into individual
18014 // elements currently requires knowing the length of the loaded type,
18015 // so will need adjusting to work on scalable vectors.
18016 if (LD->getValueType(0).isScalableVector())
18017 return false;
18018
18019 // Keep track of already used bits to detect overlapping values.
18020 // In that case, we will just abort the transformation.
18021 APInt UsedBits(LD->getValueSizeInBits(0), 0);
18022
18023 SmallVector<LoadedSlice, 4> LoadedSlices;
18024
18025 // Check if this load is used as several smaller chunks of bits.
18026 // Basically, look for uses in trunc or trunc(lshr) and record a new chain
18027 // of computation for each trunc.
18028 for (SDNode::use_iterator UI = LD->use_begin(), UIEnd = LD->use_end();
18029 UI != UIEnd; ++UI) {
18030 // Skip the uses of the chain.
18031 if (UI.getUse().getResNo() != 0)
18032 continue;
18033
18034 SDNode *User = *UI;
18035 unsigned Shift = 0;
18036
18037 // Check if this is a trunc(lshr).
18038 if (User->getOpcode() == ISD::SRL && User->hasOneUse() &&
18039 isa<ConstantSDNode>(User->getOperand(1))) {
18040 Shift = User->getConstantOperandVal(1);
18041 User = *User->use_begin();
18042 }
18043
18044     // At this point, User is a TRUNCATE, iff we encountered trunc or
18045     // trunc(lshr).
18046 if (User->getOpcode() != ISD::TRUNCATE)
18047 return false;
18048
18049 // The width of the type must be a power of 2 and greater than 8-bits.
18050 // Otherwise the load cannot be represented in LLVM IR.
18051 // Moreover, if we shifted with a non-8-bits multiple, the slice
18052 // will be across several bytes. We do not support that.
18053 unsigned Width = User->getValueSizeInBits(0);
18054 if (Width < 8 || !isPowerOf2_32(Width) || (Shift & 0x7))
18055 return false;
18056
18057 // Build the slice for this chain of computations.
18058 LoadedSlice LS(User, LD, Shift, &DAG);
18059 APInt CurrentUsedBits = LS.getUsedBits();
18060
18061 // Check if this slice overlaps with another.
18062 if ((CurrentUsedBits & UsedBits) != 0)
18063 return false;
18064 // Update the bits used globally.
18065 UsedBits |= CurrentUsedBits;
18066
18067 // Check if the new slice would be legal.
18068 if (!LS.isLegal())
18069 return false;
18070
18071 // Record the slice.
18072 LoadedSlices.push_back(LS);
18073 }
18074
18075 // Abort slicing if it does not seem to be profitable.
18076 if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
18077 return false;
18078
18079 ++SlicedLoads;
18080
18081 // Rewrite each chain to use an independent load.
18082 // By construction, each chain can be represented by a unique load.
18083
18084 // Prepare the argument for the new token factor for all the slices.
18085 SmallVector<SDValue, 8> ArgChains;
18086 for (const LoadedSlice &LS : LoadedSlices) {
18087 SDValue SliceInst = LS.loadSlice();
18088 CombineTo(LS.Inst, SliceInst, true);
18089 if (SliceInst.getOpcode() != ISD::LOAD)
18090 SliceInst = SliceInst.getOperand(0);
18091 assert(SliceInst->getOpcode() == ISD::LOAD &&
18092 "It takes more than a zext to get to the loaded slice!!");
18093 ArgChains.push_back(SliceInst.getValue(1));
18094 }
18095
18096 SDValue Chain = DAG.getNode(ISD::TokenFactor, SDLoc(LD), MVT::Other,
18097 ArgChains);
18098 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
18099 AddToWorklist(Chain.getNode());
18100 return true;
18101 }
18102
18103 /// Check to see if V is (and (load ptr), imm), where the load is having
18104 /// specific bytes cleared out. If so, return the byte size being masked out
18105 /// and the shift amount.
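/// E.g., for V = (and (load i32 Ptr), 0xFFFF00FF), exactly byte 1 is cleared,
/// so this returns {1 /*bytes masked*/, 1 /*byte shift*/}, assuming the load
/// immediately precedes the store in the chain.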
18106 static std::pair<unsigned, unsigned>
18107 CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
18108 std::pair<unsigned, unsigned> Result(0, 0);
18109
18110 // Check for the structure we're looking for.
18111 if (V->getOpcode() != ISD::AND ||
18112 !isa<ConstantSDNode>(V->getOperand(1)) ||
18113 !ISD::isNormalLoad(V->getOperand(0).getNode()))
18114 return Result;
18115
18116 // Check the chain and pointer.
18117 LoadSDNode *LD = cast<LoadSDNode>(V->getOperand(0));
18118 if (LD->getBasePtr() != Ptr) return Result; // Not from same pointer.
18119
18120 // This only handles simple types.
18121 if (V.getValueType() != MVT::i16 &&
18122 V.getValueType() != MVT::i32 &&
18123 V.getValueType() != MVT::i64)
18124 return Result;
18125
18126 // Check the constant mask. Invert it so that the bits being masked out are
18127 // 0 and the bits being kept are 1. Use getSExtValue so that leading bits
18128 // follow the sign bit for uniformity.
18129 uint64_t NotMask = ~cast<ConstantSDNode>(V->getOperand(1))->getSExtValue();
18130 unsigned NotMaskLZ = countLeadingZeros(NotMask);
18131 if (NotMaskLZ & 7) return Result; // Must be multiple of a byte.
18132 unsigned NotMaskTZ = countTrailingZeros(NotMask);
18133 if (NotMaskTZ & 7) return Result; // Must be multiple of a byte.
18134 if (NotMaskLZ == 64) return Result; // All zero mask.
18135
18136 // See if we have a continuous run of bits. If so, we have 0*1+0*
18137 if (countTrailingOnes(NotMask >> NotMaskTZ) + NotMaskTZ + NotMaskLZ != 64)
18138 return Result;
18139
18140 // Adjust NotMaskLZ down to be from the actual size of the int instead of i64.
18141 if (V.getValueType() != MVT::i64 && NotMaskLZ)
18142 NotMaskLZ -= 64-V.getValueSizeInBits();
18143
18144 unsigned MaskedBytes = (V.getValueSizeInBits()-NotMaskLZ-NotMaskTZ)/8;
18145 switch (MaskedBytes) {
18146 case 1:
18147 case 2:
18148 case 4: break;
18149 default: return Result; // All one mask, or 5-byte mask.
18150 }
18151
18152   // Verify that the masked bytes start at a multiple of the access width so
18153   // that the narrow access is aligned the same as the access width.
18154 if (NotMaskTZ && NotMaskTZ/8 % MaskedBytes) return Result;
18155
18156   // For narrowing to be valid, it must be the case that the load is the
18157   // immediately preceding memory operation before the store.
18158 if (LD == Chain.getNode())
18159 ; // ok.
18160 else if (Chain->getOpcode() == ISD::TokenFactor &&
18161 SDValue(LD, 1).hasOneUse()) {
18162     // LD has only 1 chain use so there are no indirect dependencies.
18163 if (!LD->isOperandOf(Chain.getNode()))
18164 return Result;
18165 } else
18166 return Result; // Fail.
18167
18168 Result.first = MaskedBytes;
18169 Result.second = NotMaskTZ/8;
18170 return Result;
18171 }
18172
18173 /// Check to see if IVal is something that provides a value as specified by
18174 /// MaskInfo. If so, replace the specified store with a narrower store of
18175 /// truncated IVal.
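/// E.g., with MaskInfo = {1 byte, byte shift 1} on a little-endian target, the
/// wide store becomes an i8 store of (trunc (srl IVal, 8)) at Ptr + 1, provided
/// IVal is known to be zero outside that byte.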
18176 static SDValue
18177 ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
18178 SDValue IVal, StoreSDNode *St,
18179 DAGCombiner *DC) {
18180 unsigned NumBytes = MaskInfo.first;
18181 unsigned ByteShift = MaskInfo.second;
18182 SelectionDAG &DAG = DC->getDAG();
18183
18184 // Check to see if IVal is all zeros in the part being masked in by the 'or'
18185 // that uses this. If not, this is not a replacement.
18186 APInt Mask = ~APInt::getBitsSet(IVal.getValueSizeInBits(),
18187 ByteShift*8, (ByteShift+NumBytes)*8);
18188 if (!DAG.MaskedValueIsZero(IVal, Mask)) return SDValue();
18189
18190 // Check that it is legal on the target to do this. It is legal if the new
18191 // VT we're shrinking to (i8/i16/i32) is legal or we're still before type
18192 // legalization. If the source type is legal, but the store type isn't, see
18193 // if we can use a truncating store.
18194 MVT VT = MVT::getIntegerVT(NumBytes * 8);
18195 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
18196 bool UseTruncStore;
18197 if (DC->isTypeLegal(VT))
18198 UseTruncStore = false;
18199 else if (TLI.isTypeLegal(IVal.getValueType()) &&
18200 TLI.isTruncStoreLegal(IVal.getValueType(), VT))
18201 UseTruncStore = true;
18202 else
18203 return SDValue();
18204 // Check that the target doesn't think this is a bad idea.
18205 if (St->getMemOperand() &&
18206 !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
18207 *St->getMemOperand()))
18208 return SDValue();
18209
18210 // Okay, we can do this! Replace the 'St' store with a store of IVal that is
18211 // shifted by ByteShift and truncated down to NumBytes.
18212 if (ByteShift) {
18213 SDLoc DL(IVal);
18214 IVal = DAG.getNode(ISD::SRL, DL, IVal.getValueType(), IVal,
18215 DAG.getConstant(ByteShift*8, DL,
18216 DC->getShiftAmountTy(IVal.getValueType())));
18217 }
18218
18219 // Figure out the offset for the store and the alignment of the access.
18220 unsigned StOffset;
18221 if (DAG.getDataLayout().isLittleEndian())
18222 StOffset = ByteShift;
18223 else
18224 StOffset = IVal.getValueType().getStoreSize() - ByteShift - NumBytes;
18225
18226 SDValue Ptr = St->getBasePtr();
18227 if (StOffset) {
18228 SDLoc DL(IVal);
18229 Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(StOffset), DL);
18230 }
18231
18232 ++OpsNarrowed;
18233 if (UseTruncStore)
18234 return DAG.getTruncStore(St->getChain(), SDLoc(St), IVal, Ptr,
18235 St->getPointerInfo().getWithOffset(StOffset),
18236 VT, St->getOriginalAlign());
18237
18238 // Truncate down to the new size.
18239 IVal = DAG.getNode(ISD::TRUNCATE, SDLoc(IVal), VT, IVal);
18240
18241 return DAG
18242 .getStore(St->getChain(), SDLoc(St), IVal, Ptr,
18243 St->getPointerInfo().getWithOffset(StOffset),
18244 St->getOriginalAlign());
18245 }
18246
18247 /// Look for sequence of load / op / store where op is one of 'or', 'xor', and
18248 /// 'and' of immediates. If 'op' is only touching some of the loaded bits, try
18249 /// narrowing the load and store if it would end up being a win for performance
18250 /// or code size.
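/// E.g., on a little-endian target where the narrow operations are legal and
/// profitable:
///   store (or (load i32 p), 0x005A0000), p
/// can be narrowed to an i8 load from p + 2, an OR with 0x5A, and an i8 store
/// back to p + 2.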
18251 SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
18252 StoreSDNode *ST = cast<StoreSDNode>(N);
18253 if (!ST->isSimple())
18254 return SDValue();
18255
18256 SDValue Chain = ST->getChain();
18257 SDValue Value = ST->getValue();
18258 SDValue Ptr = ST->getBasePtr();
18259 EVT VT = Value.getValueType();
18260
18261 if (ST->isTruncatingStore() || VT.isVector())
18262 return SDValue();
18263
18264 unsigned Opc = Value.getOpcode();
18265
18266 if ((Opc != ISD::OR && Opc != ISD::XOR && Opc != ISD::AND) ||
18267 !Value.hasOneUse())
18268 return SDValue();
18269
18270 // If this is "store (or X, Y), P" and X is "(and (load P), cst)", where cst
18271 // is a byte mask indicating a consecutive number of bytes, check to see if
18272 // Y is known to provide just those bytes. If so, we try to replace the
18273 // load + replace + store sequence with a single (narrower) store, which makes
18274 // the load dead.
18275 if (Opc == ISD::OR && EnableShrinkLoadReplaceStoreWithStore) {
18276 std::pair<unsigned, unsigned> MaskedLoad;
18277 MaskedLoad = CheckForMaskedLoad(Value.getOperand(0), Ptr, Chain);
18278 if (MaskedLoad.first)
18279 if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
18280 Value.getOperand(1), ST,this))
18281 return NewST;
18282
18283 // Or is commutative, so try swapping X and Y.
18284 MaskedLoad = CheckForMaskedLoad(Value.getOperand(1), Ptr, Chain);
18285 if (MaskedLoad.first)
18286 if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
18287 Value.getOperand(0), ST,this))
18288 return NewST;
18289 }
18290
18291 if (!EnableReduceLoadOpStoreWidth)
18292 return SDValue();
18293
18294 if (Value.getOperand(1).getOpcode() != ISD::Constant)
18295 return SDValue();
18296
18297 SDValue N0 = Value.getOperand(0);
18298 if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
18299 Chain == SDValue(N0.getNode(), 1)) {
18300 LoadSDNode *LD = cast<LoadSDNode>(N0);
18301 if (LD->getBasePtr() != Ptr ||
18302 LD->getPointerInfo().getAddrSpace() !=
18303 ST->getPointerInfo().getAddrSpace())
18304 return SDValue();
18305
18306     // Find the type to narrow the load / op / store to.
18307 SDValue N1 = Value.getOperand(1);
18308 unsigned BitWidth = N1.getValueSizeInBits();
18309 APInt Imm = cast<ConstantSDNode>(N1)->getAPIntValue();
18310 if (Opc == ISD::AND)
18311 Imm ^= APInt::getAllOnes(BitWidth);
18312 if (Imm == 0 || Imm.isAllOnes())
18313 return SDValue();
18314 unsigned ShAmt = Imm.countTrailingZeros();
18315 unsigned MSB = BitWidth - Imm.countLeadingZeros() - 1;
18316 unsigned NewBW = NextPowerOf2(MSB - ShAmt);
18317 EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
18318 // The narrowing should be profitable, the load/store operation should be
18319 // legal (or custom) and the store size should be equal to the NewVT width.
18320 while (NewBW < BitWidth &&
18321 (NewVT.getStoreSizeInBits() != NewBW ||
18322 !TLI.isOperationLegalOrCustom(Opc, NewVT) ||
18323 !TLI.isNarrowingProfitable(VT, NewVT))) {
18324 NewBW = NextPowerOf2(NewBW);
18325 NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
18326 }
18327 if (NewBW >= BitWidth)
18328 return SDValue();
18329
18330     // If the changed lsb does not start at a boundary of the new bit width,
18331     // start at the previous one.
18332 if (ShAmt % NewBW)
18333 ShAmt = (((ShAmt + NewBW - 1) / NewBW) * NewBW) - NewBW;
18334 APInt Mask = APInt::getBitsSet(BitWidth, ShAmt,
18335 std::min(BitWidth, ShAmt + NewBW));
18336 if ((Imm & Mask) == Imm) {
18337 APInt NewImm = (Imm & Mask).lshr(ShAmt).trunc(NewBW);
18338 if (Opc == ISD::AND)
18339 NewImm ^= APInt::getAllOnes(NewBW);
18340 uint64_t PtrOff = ShAmt / 8;
18341 // For big endian targets, we need to adjust the offset to the pointer to
18342 // load the correct bytes.
18343 if (DAG.getDataLayout().isBigEndian())
18344 PtrOff = (BitWidth + 7 - NewBW) / 8 - PtrOff;
18345
18346 unsigned IsFast = 0;
18347 Align NewAlign = commonAlignment(LD->getAlign(), PtrOff);
18348 if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), NewVT,
18349 LD->getAddressSpace(), NewAlign,
18350 LD->getMemOperand()->getFlags(), &IsFast) ||
18351 !IsFast)
18352 return SDValue();
18353
18354 SDValue NewPtr =
18355 DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(PtrOff), SDLoc(LD));
18356 SDValue NewLD =
18357 DAG.getLoad(NewVT, SDLoc(N0), LD->getChain(), NewPtr,
18358 LD->getPointerInfo().getWithOffset(PtrOff), NewAlign,
18359 LD->getMemOperand()->getFlags(), LD->getAAInfo());
18360 SDValue NewVal = DAG.getNode(Opc, SDLoc(Value), NewVT, NewLD,
18361 DAG.getConstant(NewImm, SDLoc(Value),
18362 NewVT));
18363 SDValue NewST =
18364 DAG.getStore(Chain, SDLoc(N), NewVal, NewPtr,
18365 ST->getPointerInfo().getWithOffset(PtrOff), NewAlign);
18366
18367 AddToWorklist(NewPtr.getNode());
18368 AddToWorklist(NewLD.getNode());
18369 AddToWorklist(NewVal.getNode());
18370 WorklistRemover DeadNodes(*this);
18371 DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLD.getValue(1));
18372 ++OpsNarrowed;
18373 return NewST;
18374 }
18375 }
18376
18377 return SDValue();
18378 }
18379
18380 /// For a given floating point load / store pair, if the load value isn't used
18381 /// by any other operations, then consider transforming the pair to integer
18382 /// load / store operations if the target deems the transformation profitable.
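// For example, (store f64 (load f64 p1), p2), where the loaded value has no
// other use, may become (store i64 (load i64 p1), p2), avoiding a round trip
// through the FP register file. Whether that is actually a win is decided by
// the target hooks queried below; this is only an illustrative sketch.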
18383 SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) {
18384 StoreSDNode *ST = cast<StoreSDNode>(N);
18385 SDValue Value = ST->getValue();
18386 if (ISD::isNormalStore(ST) && ISD::isNormalLoad(Value.getNode()) &&
18387 Value.hasOneUse()) {
18388 LoadSDNode *LD = cast<LoadSDNode>(Value);
18389 EVT VT = LD->getMemoryVT();
18390 if (!VT.isFloatingPoint() ||
18391 VT != ST->getMemoryVT() ||
18392 LD->isNonTemporal() ||
18393 ST->isNonTemporal() ||
18394 LD->getPointerInfo().getAddrSpace() != 0 ||
18395 ST->getPointerInfo().getAddrSpace() != 0)
18396 return SDValue();
18397
18398 TypeSize VTSize = VT.getSizeInBits();
18399
18400 // We don't know the size of scalable types at compile time so we cannot
18401 // create an integer of the equivalent size.
18402 if (VTSize.isScalable())
18403 return SDValue();
18404
18405 unsigned FastLD = 0, FastST = 0;
18406 EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), VTSize.getFixedValue());
18407 if (!TLI.isOperationLegal(ISD::LOAD, IntVT) ||
18408 !TLI.isOperationLegal(ISD::STORE, IntVT) ||
18409 !TLI.isDesirableToTransformToIntegerOp(ISD::LOAD, VT) ||
18410 !TLI.isDesirableToTransformToIntegerOp(ISD::STORE, VT) ||
18411 !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), IntVT,
18412 *LD->getMemOperand(), &FastLD) ||
18413 !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), IntVT,
18414 *ST->getMemOperand(), &FastST) ||
18415 !FastLD || !FastST)
18416 return SDValue();
18417
18418 SDValue NewLD =
18419 DAG.getLoad(IntVT, SDLoc(Value), LD->getChain(), LD->getBasePtr(),
18420 LD->getPointerInfo(), LD->getAlign());
18421
18422 SDValue NewST =
18423 DAG.getStore(ST->getChain(), SDLoc(N), NewLD, ST->getBasePtr(),
18424 ST->getPointerInfo(), ST->getAlign());
18425
18426 AddToWorklist(NewLD.getNode());
18427 AddToWorklist(NewST.getNode());
18428 WorklistRemover DeadNodes(*this);
18429 DAG.ReplaceAllUsesOfValueWith(Value.getValue(1), NewLD.getValue(1));
18430 ++LdStFP2Int;
18431 return NewST;
18432 }
18433
18434 return SDValue();
18435 }
18436
18437 // This is a helper function for visitMUL to check the profitability
18438 // of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
18439 // MulNode is the original multiply, AddNode is (add x, c1),
18440 // and ConstNode is c2.
18441 //
18442 // If the (add x, c1) has multiple uses, we could increase
18443 // the number of adds if we make this transformation.
18444 // It would only be worth doing this if we can remove a
18445 // multiply in the process. Check for that here.
18446 // To illustrate:
18447 // (A + c1) * c3
18448 // (A + c2) * c3
18449 // We're checking for cases where we have common "c3 * A" expressions.
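// For example, if both (A + 5) * 3 and (A + 7) * 3 appear, rewriting them as
// (A * 3) + 15 and (A * 3) + 21 exposes a common (A * 3) that can be shared,
// so the extra adds are paid for by removing a multiply. (Illustrative only.)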
18450 bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
18451 SDValue ConstNode) {
18452 APInt Val;
18453
18454 // If the add only has one use, and the target thinks the folding is
18455 // profitable or does not lead to worse code, this would be OK to do.
18456 if (AddNode->hasOneUse() &&
18457 TLI.isMulAddWithConstProfitable(AddNode, ConstNode))
18458 return true;
18459
18460 // Walk all the users of the constant with which we're multiplying.
18461 for (SDNode *Use : ConstNode->uses()) {
18462 if (Use == MulNode) // This use is the one we're on right now. Skip it.
18463 continue;
18464
18465 if (Use->getOpcode() == ISD::MUL) { // We have another multiply use.
18466 SDNode *OtherOp;
18467 SDNode *MulVar = AddNode.getOperand(0).getNode();
18468
18469 // OtherOp is what we're multiplying against the constant.
18470 if (Use->getOperand(0) == ConstNode)
18471 OtherOp = Use->getOperand(1).getNode();
18472 else
18473 OtherOp = Use->getOperand(0).getNode();
18474
18475 // Check to see if multiply is with the same operand of our "add".
18476 //
18477 // ConstNode = CONST
18478 // Use = ConstNode * A <-- visiting Use. OtherOp is A.
18479 // ...
18480 // AddNode = (A + c1) <-- MulVar is A.
18481 // = AddNode * ConstNode <-- current visiting instruction.
18482 //
18483 // If we make this transformation, we will have a common
18484 // multiply (ConstNode * A) that we can save.
18485 if (OtherOp == MulVar)
18486 return true;
18487
18488 // Now check to see if a future expansion will give us a common
18489 // multiply.
18490 //
18491 // ConstNode = CONST
18492 // AddNode = (A + c1)
18493 // ... = AddNode * ConstNode <-- current visiting instruction.
18494 // ...
18495 // OtherOp = (A + c2)
18496 // Use = OtherOp * ConstNode <-- visiting Use.
18497 //
18498 // If we make this transformation, we will have a common
18499       // multiply (CONST * A) after we also do the same transformation
18500       // to the "Use" instruction.
18501 if (OtherOp->getOpcode() == ISD::ADD &&
18502 DAG.isConstantIntBuildVectorOrConstantInt(OtherOp->getOperand(1)) &&
18503 OtherOp->getOperand(0).getNode() == MulVar)
18504 return true;
18505 }
18506 }
18507
18508 // Didn't find a case where this would be profitable.
18509 return false;
18510 }
18511
18512 SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
18513 unsigned NumStores) {
18514 SmallVector<SDValue, 8> Chains;
18515 SmallPtrSet<const SDNode *, 8> Visited;
18516 SDLoc StoreDL(StoreNodes[0].MemNode);
18517
18518 for (unsigned i = 0; i < NumStores; ++i) {
18519 Visited.insert(StoreNodes[i].MemNode);
18520 }
18521
18522   // Don't include nodes that are children or repeated nodes.
18523 for (unsigned i = 0; i < NumStores; ++i) {
18524 if (Visited.insert(StoreNodes[i].MemNode->getChain().getNode()).second)
18525 Chains.push_back(StoreNodes[i].MemNode->getChain());
18526 }
18527
18528 assert(Chains.size() > 0 && "Chain should have generated a chain");
18529 return DAG.getTokenFactor(StoreDL, Chains);
18530 }
18531
18532 bool DAGCombiner::mergeStoresOfConstantsOrVecElts(
18533 SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores,
18534 bool IsConstantSrc, bool UseVector, bool UseTrunc) {
18535 // Make sure we have something to merge.
18536 if (NumStores < 2)
18537 return false;
18538
18539 assert((!UseTrunc || !UseVector) &&
18540 "This optimization cannot emit a vector truncating store");
18541
18542 // The latest Node in the DAG.
18543 SDLoc DL(StoreNodes[0].MemNode);
18544
18545 TypeSize ElementSizeBits = MemVT.getStoreSizeInBits();
18546 unsigned SizeInBits = NumStores * ElementSizeBits;
18547 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
18548
18549 std::optional<MachineMemOperand::Flags> Flags;
18550 AAMDNodes AAInfo;
18551 for (unsigned I = 0; I != NumStores; ++I) {
18552 StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
18553 if (!Flags) {
18554 Flags = St->getMemOperand()->getFlags();
18555 AAInfo = St->getAAInfo();
18556 continue;
18557 }
18558 // Skip merging if there's an inconsistent flag.
18559 if (Flags != St->getMemOperand()->getFlags())
18560 return false;
18561 // Concatenate AA metadata.
18562 AAInfo = AAInfo.concat(St->getAAInfo());
18563 }
18564
18565 EVT StoreTy;
18566 if (UseVector) {
18567 unsigned Elts = NumStores * NumMemElts;
18568 // Get the type for the merged vector store.
18569 StoreTy = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
18570 } else
18571 StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits);
18572
18573 SDValue StoredVal;
18574 if (UseVector) {
18575 if (IsConstantSrc) {
18576 SmallVector<SDValue, 8> BuildVector;
18577 for (unsigned I = 0; I != NumStores; ++I) {
18578 StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
18579 SDValue Val = St->getValue();
18580 // If constant is of the wrong type, convert it now.
18581 if (MemVT != Val.getValueType()) {
18582 Val = peekThroughBitcasts(Val);
18583 // Deal with constants of wrong size.
18584 if (ElementSizeBits != Val.getValueSizeInBits()) {
18585 EVT IntMemVT =
18586 EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits());
18587 if (isa<ConstantFPSDNode>(Val)) {
18588 // Not clear how to truncate FP values.
18589 return false;
18590 }
18591
18592 if (auto *C = dyn_cast<ConstantSDNode>(Val))
18593 Val = DAG.getConstant(C->getAPIntValue()
18594 .zextOrTrunc(Val.getValueSizeInBits())
18595 .zextOrTrunc(ElementSizeBits),
18596 SDLoc(C), IntMemVT);
18597 }
18598           // Make sure the correctly sized value ends up with the correct type.
18599 Val = DAG.getBitcast(MemVT, Val);
18600 }
18601 BuildVector.push_back(Val);
18602 }
18603 StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
18604 : ISD::BUILD_VECTOR,
18605 DL, StoreTy, BuildVector);
18606 } else {
18607 SmallVector<SDValue, 8> Ops;
18608 for (unsigned i = 0; i < NumStores; ++i) {
18609 StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
18610 SDValue Val = peekThroughBitcasts(St->getValue());
18611 // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of
18612 // type MemVT. If the underlying value is not the correct
18613 // type, but it is an extraction of an appropriate vector we
18614 // can recast Val to be of the correct type. This may require
18615 // converting between EXTRACT_VECTOR_ELT and
18616 // EXTRACT_SUBVECTOR.
18617 if ((MemVT != Val.getValueType()) &&
18618 (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
18619 Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) {
18620 EVT MemVTScalarTy = MemVT.getScalarType();
18621 // We may need to add a bitcast here to get types to line up.
18622 if (MemVTScalarTy != Val.getValueType().getScalarType()) {
18623 Val = DAG.getBitcast(MemVT, Val);
18624 } else if (MemVT.isVector() &&
18625 Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
18626 Val = DAG.getNode(ISD::BUILD_VECTOR, DL, MemVT, Val);
18627 } else {
18628 unsigned OpC = MemVT.isVector() ? ISD::EXTRACT_SUBVECTOR
18629 : ISD::EXTRACT_VECTOR_ELT;
18630 SDValue Vec = Val.getOperand(0);
18631 SDValue Idx = Val.getOperand(1);
18632 Val = DAG.getNode(OpC, SDLoc(Val), MemVT, Vec, Idx);
18633 }
18634 }
18635 Ops.push_back(Val);
18636 }
18637
18638 // Build the extracted vector elements back into a vector.
18639 StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
18640 : ISD::BUILD_VECTOR,
18641 DL, StoreTy, Ops);
18642 }
18643 } else {
18644 // We should always use a vector store when merging extracted vector
18645 // elements, so this path implies a store of constants.
18646 assert(IsConstantSrc && "Merged vector elements should use vector store");
18647
18648 APInt StoreInt(SizeInBits, 0);
18649
18650 // Construct a single integer constant which is made of the smaller
18651 // constant inputs.
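    // For example, merging four i8 stores of 0x11, 0x22, 0x33 and 0x44 (in
    // increasing address order) on a little-endian target produces the single
    // i32 constant 0x44332211, which writes the same bytes to memory.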
18652 bool IsLE = DAG.getDataLayout().isLittleEndian();
18653 for (unsigned i = 0; i < NumStores; ++i) {
18654 unsigned Idx = IsLE ? (NumStores - 1 - i) : i;
18655 StoreSDNode *St = cast<StoreSDNode>(StoreNodes[Idx].MemNode);
18656
18657 SDValue Val = St->getValue();
18658 Val = peekThroughBitcasts(Val);
18659 StoreInt <<= ElementSizeBits;
18660 if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) {
18661 StoreInt |= C->getAPIntValue()
18662 .zextOrTrunc(ElementSizeBits)
18663 .zextOrTrunc(SizeInBits);
18664 } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val)) {
18665 StoreInt |= C->getValueAPF()
18666 .bitcastToAPInt()
18667 .zextOrTrunc(ElementSizeBits)
18668 .zextOrTrunc(SizeInBits);
18669 // If fp truncation is necessary give up for now.
18670 if (MemVT.getSizeInBits() != ElementSizeBits)
18671 return false;
18672 } else {
18673 llvm_unreachable("Invalid constant element type");
18674 }
18675 }
18676
18677 // Create the new Load and Store operations.
18678 StoredVal = DAG.getConstant(StoreInt, DL, StoreTy);
18679 }
18680
18681 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
18682 SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores);
18683
18684   // Make sure we use a trunc store if it's necessary to be legal.
18685 SDValue NewStore;
18686 if (!UseTrunc) {
18687 NewStore = DAG.getStore(NewChain, DL, StoredVal, FirstInChain->getBasePtr(),
18688 FirstInChain->getPointerInfo(),
18689 FirstInChain->getAlign(), *Flags, AAInfo);
18690 } else { // Must be realized as a trunc store
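    // For example, a merged i48 constant on a target that promotes i48 to i64
    // is emitted as an i64 constant and written back with a truncating store
    // of the original 48 bits. (Illustrative; the exact types depend on the
    // target's type legalization.)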
18691 EVT LegalizedStoredValTy =
18692 TLI.getTypeToTransformTo(*DAG.getContext(), StoredVal.getValueType());
18693 unsigned LegalizedStoreSize = LegalizedStoredValTy.getSizeInBits();
18694 ConstantSDNode *C = cast<ConstantSDNode>(StoredVal);
18695 SDValue ExtendedStoreVal =
18696 DAG.getConstant(C->getAPIntValue().zextOrTrunc(LegalizedStoreSize), DL,
18697 LegalizedStoredValTy);
18698 NewStore = DAG.getTruncStore(
18699 NewChain, DL, ExtendedStoreVal, FirstInChain->getBasePtr(),
18700 FirstInChain->getPointerInfo(), StoredVal.getValueType() /*TVT*/,
18701 FirstInChain->getAlign(), *Flags, AAInfo);
18702 }
18703
18704 // Replace all merged stores with the new store.
18705 for (unsigned i = 0; i < NumStores; ++i)
18706 CombineTo(StoreNodes[i].MemNode, NewStore);
18707
18708 AddToWorklist(NewChain.getNode());
18709 return true;
18710 }
18711
18712 void DAGCombiner::getStoreMergeCandidates(
18713 StoreSDNode *St, SmallVectorImpl<MemOpLink> &StoreNodes,
18714 SDNode *&RootNode) {
18715 // This holds the base pointer, index, and the offset in bytes from the base
18716 // pointer. We must have a base and an offset. Do not handle stores to undef
18717 // base pointers.
18718 BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
18719 if (!BasePtr.getBase().getNode() || BasePtr.getBase().isUndef())
18720 return;
18721
18722 SDValue Val = peekThroughBitcasts(St->getValue());
18723 StoreSource StoreSrc = getStoreSource(Val);
18724 assert(StoreSrc != StoreSource::Unknown && "Expected known source for store");
18725
18726   // Match on the load's base pointer if relevant.
18727 EVT MemVT = St->getMemoryVT();
18728 BaseIndexOffset LBasePtr;
18729 EVT LoadVT;
18730 if (StoreSrc == StoreSource::Load) {
18731 auto *Ld = cast<LoadSDNode>(Val);
18732 LBasePtr = BaseIndexOffset::match(Ld, DAG);
18733 LoadVT = Ld->getMemoryVT();
18734 // Load and store should be the same type.
18735 if (MemVT != LoadVT)
18736 return;
18737 // Loads must only have one use.
18738 if (!Ld->hasNUsesOfValue(1, 0))
18739 return;
18740 // The memory operands must not be volatile/indexed/atomic.
18741 // TODO: May be able to relax for unordered atomics (see D66309)
18742 if (!Ld->isSimple() || Ld->isIndexed())
18743 return;
18744 }
18745 auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr,
18746 int64_t &Offset) -> bool {
18747 // The memory operands must not be volatile/indexed/atomic.
18748 // TODO: May be able to relax for unordered atomics (see D66309)
18749 if (!Other->isSimple() || Other->isIndexed())
18750 return false;
18751 // Don't mix temporal stores with non-temporal stores.
18752 if (St->isNonTemporal() != Other->isNonTemporal())
18753 return false;
18754 SDValue OtherBC = peekThroughBitcasts(Other->getValue());
18755 // Allow merging constants of different types as integers.
18756 bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT())
18757 : Other->getMemoryVT() != MemVT;
18758 switch (StoreSrc) {
18759 case StoreSource::Load: {
18760 if (NoTypeMatch)
18761 return false;
18762 // The Load's Base Ptr must also match.
18763 auto *OtherLd = dyn_cast<LoadSDNode>(OtherBC);
18764 if (!OtherLd)
18765 return false;
18766 BaseIndexOffset LPtr = BaseIndexOffset::match(OtherLd, DAG);
18767 if (LoadVT != OtherLd->getMemoryVT())
18768 return false;
18769 // Loads must only have one use.
18770 if (!OtherLd->hasNUsesOfValue(1, 0))
18771 return false;
18772 // The memory operands must not be volatile/indexed/atomic.
18773 // TODO: May be able to relax for unordered atomics (see D66309)
18774 if (!OtherLd->isSimple() || OtherLd->isIndexed())
18775 return false;
18776 // Don't mix temporal loads with non-temporal loads.
18777 if (cast<LoadSDNode>(Val)->isNonTemporal() != OtherLd->isNonTemporal())
18778 return false;
18779 if (!(LBasePtr.equalBaseIndex(LPtr, DAG)))
18780 return false;
18781 break;
18782 }
18783 case StoreSource::Constant:
18784 if (NoTypeMatch)
18785 return false;
18786 if (!isIntOrFPConstant(OtherBC))
18787 return false;
18788 break;
18789 case StoreSource::Extract:
18790 // Do not merge truncated stores here.
18791 if (Other->isTruncatingStore())
18792 return false;
18793 if (!MemVT.bitsEq(OtherBC.getValueType()))
18794 return false;
18795 if (OtherBC.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
18796 OtherBC.getOpcode() != ISD::EXTRACT_SUBVECTOR)
18797 return false;
18798 break;
18799 default:
18800 llvm_unreachable("Unhandled store source for merging");
18801 }
18802 Ptr = BaseIndexOffset::match(Other, DAG);
18803 return (BasePtr.equalBaseIndex(Ptr, DAG, Offset));
18804 };
18805
18806   // Check if the pair of StoreNode and RootNode has already bailed out of the
18807   // dependence check more times than the limit allows.
18808 auto OverLimitInDependenceCheck = [&](SDNode *StoreNode,
18809 SDNode *RootNode) -> bool {
18810 auto RootCount = StoreRootCountMap.find(StoreNode);
18811 return RootCount != StoreRootCountMap.end() &&
18812 RootCount->second.first == RootNode &&
18813 RootCount->second.second > StoreMergeDependenceLimit;
18814 };
18815
18816 auto TryToAddCandidate = [&](SDNode::use_iterator UseIter) {
18817 // This must be a chain use.
18818 if (UseIter.getOperandNo() != 0)
18819 return;
18820 if (auto *OtherStore = dyn_cast<StoreSDNode>(*UseIter)) {
18821 BaseIndexOffset Ptr;
18822 int64_t PtrDiff;
18823 if (CandidateMatch(OtherStore, Ptr, PtrDiff) &&
18824 !OverLimitInDependenceCheck(OtherStore, RootNode))
18825 StoreNodes.push_back(MemOpLink(OtherStore, PtrDiff));
18826 }
18827 };
18828
18829   // We are looking for a root node which is an ancestor to all mergeable
18830   // stores. We search up through a load, to our root and then down
18831   // through all children. For instance we will find Store{1,2,3} if
18832   // St is Store1, Store2, or Store3 where the root is not a load,
18833   // which is always true for non-volatile ops. TODO: Expand
18834   // the search to find all valid candidates through multiple layers of loads.
18835 //
18836 // Root
18837 // |-------|-------|
18838 // Load Load Store3
18839 // | |
18840 // Store1 Store2
18841 //
18842 // FIXME: We should be able to climb and
18843 // descend TokenFactors to find candidates as well.
18844
18845 RootNode = St->getChain().getNode();
18846
18847 unsigned NumNodesExplored = 0;
18848 const unsigned MaxSearchNodes = 1024;
18849 if (auto *Ldn = dyn_cast<LoadSDNode>(RootNode)) {
18850 RootNode = Ldn->getChain().getNode();
18851 for (auto I = RootNode->use_begin(), E = RootNode->use_end();
18852 I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored) {
18853 if (I.getOperandNo() == 0 && isa<LoadSDNode>(*I)) { // walk down chain
18854 for (auto I2 = (*I)->use_begin(), E2 = (*I)->use_end(); I2 != E2; ++I2)
18855 TryToAddCandidate(I2);
18856 }
18857 // Check stores that depend on the root (e.g. Store 3 in the chart above).
18858 if (I.getOperandNo() == 0 && isa<StoreSDNode>(*I)) {
18859 TryToAddCandidate(I);
18860 }
18861 }
18862 } else {
18863 for (auto I = RootNode->use_begin(), E = RootNode->use_end();
18864 I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored)
18865 TryToAddCandidate(I);
18866 }
18867 }
18868
18869 // We need to check that merging these stores does not cause a loop in the
18870 // DAG. Any store candidate may depend on another candidate indirectly through
18871 // its operands. Check in parallel by searching up from operands of candidates.
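// For example, one candidate store's value operand may be a load that,
// possibly through several intermediate nodes, depends on another candidate
// store; folding both stores into a single merged node would then close a
// cycle through that load.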
18872 bool DAGCombiner::checkMergeStoreCandidatesForDependencies(
18873 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
18874 SDNode *RootNode) {
18875   // FIXME: We should be able to truncate a full search of
18876   // predecessors by doing a BFS and keeping tabs on the originating
18877   // stores from which worklist nodes come, in a similar way to
18878   // TokenFactor simplification.
18879
18880 SmallPtrSet<const SDNode *, 32> Visited;
18881 SmallVector<const SDNode *, 8> Worklist;
18882
18883 // RootNode is a predecessor to all candidates so we need not search
18884 // past it. Add RootNode (peeking through TokenFactors). Do not count
18885 // these towards size check.
18886
18887 Worklist.push_back(RootNode);
18888 while (!Worklist.empty()) {
18889 auto N = Worklist.pop_back_val();
18890 if (!Visited.insert(N).second)
18891 continue; // Already present in Visited.
18892 if (N->getOpcode() == ISD::TokenFactor) {
18893 for (SDValue Op : N->ops())
18894 Worklist.push_back(Op.getNode());
18895 }
18896 }
18897
18898 // Don't count pruning nodes towards max.
18899 unsigned int Max = 1024 + Visited.size();
18900 // Search Ops of store candidates.
18901 for (unsigned i = 0; i < NumStores; ++i) {
18902 SDNode *N = StoreNodes[i].MemNode;
18903 // Of the 4 Store Operands:
18904 // * Chain (Op 0) -> We have already considered these
18905 // in candidate selection, but only by following the
18906 // chain dependencies. We could still have a chain
18907 // dependency to a load, that has a non-chain dep to
18908 // another load, that depends on a store, etc. So it is
18909 // possible to have dependencies that consist of a mix
18910 // of chain and non-chain deps, and we need to include
18911     //                   chain operands in the analysis here.
18912 // * Value (Op 1) -> Cycles may happen (e.g. through load chains)
18913 // * Address (Op 2) -> Merged addresses may only vary by a fixed constant,
18914     //                     but aren't necessarily from the same base node, so
18915     //                     cycles are possible (e.g. via indexed store).
18916 // * (Op 3) -> Represents the pre or post-indexing offset (or undef for
18917 // non-indexed stores). Not constant on all targets (e.g. ARM)
18918 // and so can participate in a cycle.
18919 for (unsigned j = 0; j < N->getNumOperands(); ++j)
18920 Worklist.push_back(N->getOperand(j).getNode());
18921 }
18922 // Search through DAG. We can stop early if we find a store node.
18923 for (unsigned i = 0; i < NumStores; ++i)
18924 if (SDNode::hasPredecessorHelper(StoreNodes[i].MemNode, Visited, Worklist,
18925 Max)) {
18926       // If the search bails out, record the StoreNode and RootNode in the
18927       // StoreRootCountMap. Once we have seen the pair more times than the
18928       // limit, we won't add the StoreNode into the StoreNodes set again.
18929 if (Visited.size() >= Max) {
18930 auto &RootCount = StoreRootCountMap[StoreNodes[i].MemNode];
18931 if (RootCount.first == RootNode)
18932 RootCount.second++;
18933 else
18934 RootCount = {RootNode, 1};
18935 }
18936 return false;
18937 }
18938 return true;
18939 }
18940
18941 unsigned
18942 DAGCombiner::getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
18943 int64_t ElementSizeBytes) const {
18944 while (true) {
18945 // Find a store past the width of the first store.
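    // E.g. with ElementSizeBytes = 4 and sorted offsets {0, 2, 4, 8}, the
    // stores at offsets 0 and 2 cannot start a consecutive run and are trimmed
    // below, and the run {4, 8} is returned as two consecutive stores.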
18946 size_t StartIdx = 0;
18947 while ((StartIdx + 1 < StoreNodes.size()) &&
18948 StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes !=
18949 StoreNodes[StartIdx + 1].OffsetFromBase)
18950 ++StartIdx;
18951
18952 // Bail if we don't have enough candidates to merge.
18953 if (StartIdx + 1 >= StoreNodes.size())
18954 return 0;
18955
18956 // Trim stores that overlapped with the first store.
18957 if (StartIdx)
18958 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + StartIdx);
18959
18960 // Scan the memory operations on the chain and find the first
18961 // non-consecutive store memory address.
18962 unsigned NumConsecutiveStores = 1;
18963 int64_t StartAddress = StoreNodes[0].OffsetFromBase;
18964 // Check that the addresses are consecutive starting from the second
18965 // element in the list of stores.
18966 for (unsigned i = 1, e = StoreNodes.size(); i < e; ++i) {
18967 int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
18968 if (CurrAddress - StartAddress != (ElementSizeBytes * i))
18969 break;
18970 NumConsecutiveStores = i + 1;
18971 }
18972 if (NumConsecutiveStores > 1)
18973 return NumConsecutiveStores;
18974
18975 // There are no consecutive stores at the start of the list.
18976 // Remove the first store and try again.
18977 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 1);
18978 }
18979 }
18980
18981 bool DAGCombiner::tryStoreMergeOfConstants(
18982 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
18983 EVT MemVT, SDNode *RootNode, bool AllowVectors) {
18984 LLVMContext &Context = *DAG.getContext();
18985 const DataLayout &DL = DAG.getDataLayout();
18986 int64_t ElementSizeBytes = MemVT.getStoreSize();
18987 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
18988 bool MadeChange = false;
18989
18990 // Store the constants into memory as one consecutive store.
18991 while (NumConsecutiveStores >= 2) {
18992 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
18993 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
18994 Align FirstStoreAlign = FirstInChain->getAlign();
18995 unsigned LastLegalType = 1;
18996 unsigned LastLegalVectorType = 1;
18997 bool LastIntegerTrunc = false;
18998 bool NonZero = false;
18999 unsigned FirstZeroAfterNonZero = NumConsecutiveStores;
19000 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
19001 StoreSDNode *ST = cast<StoreSDNode>(StoreNodes[i].MemNode);
19002 SDValue StoredVal = ST->getValue();
19003 bool IsElementZero = false;
19004 if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(StoredVal))
19005 IsElementZero = C->isZero();
19006 else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(StoredVal))
19007 IsElementZero = C->getConstantFPValue()->isNullValue();
19008 if (IsElementZero) {
19009 if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores)
19010 FirstZeroAfterNonZero = i;
19011 }
19012 NonZero |= !IsElementZero;
19013
19014 // Find a legal type for the constant store.
19015 unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
19016 EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits);
19017 unsigned IsFast = 0;
19018
19019 // Break early when size is too large to be legal.
19020 if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
19021 break;
19022
19023 if (TLI.isTypeLegal(StoreTy) &&
19024 TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
19025 DAG.getMachineFunction()) &&
19026 TLI.allowsMemoryAccess(Context, DL, StoreTy,
19027 *FirstInChain->getMemOperand(), &IsFast) &&
19028 IsFast) {
19029 LastIntegerTrunc = false;
19030 LastLegalType = i + 1;
19031 // Or check whether a truncstore is legal.
19032 } else if (TLI.getTypeAction(Context, StoreTy) ==
19033 TargetLowering::TypePromoteInteger) {
19034 EVT LegalizedStoredValTy =
19035 TLI.getTypeToTransformTo(Context, StoredVal.getValueType());
19036 if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
19037 TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy,
19038 DAG.getMachineFunction()) &&
19039 TLI.allowsMemoryAccess(Context, DL, StoreTy,
19040 *FirstInChain->getMemOperand(), &IsFast) &&
19041 IsFast) {
19042 LastIntegerTrunc = true;
19043 LastLegalType = i + 1;
19044 }
19045 }
19046
19047 // We only use vectors if the constant is known to be zero or the
19048 // target allows it and the function is not marked with the
19049 // noimplicitfloat attribute.
19050 if ((!NonZero ||
19051 TLI.storeOfVectorConstantIsCheap(MemVT, i + 1, FirstStoreAS)) &&
19052 AllowVectors) {
19053 // Find a legal type for the vector store.
19054 unsigned Elts = (i + 1) * NumMemElts;
19055 EVT Ty = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
19056 if (TLI.isTypeLegal(Ty) && TLI.isTypeLegal(MemVT) &&
19057 TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG.getMachineFunction()) &&
19058 TLI.allowsMemoryAccess(Context, DL, Ty,
19059 *FirstInChain->getMemOperand(), &IsFast) &&
19060 IsFast)
19061 LastLegalVectorType = i + 1;
19062 }
19063 }
19064
19065 bool UseVector = (LastLegalVectorType > LastLegalType) && AllowVectors;
19066 unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType;
19067 bool UseTrunc = LastIntegerTrunc && !UseVector;
19068
19069 // Check if we found a legal integer type that creates a meaningful
19070 // merge.
19071 if (NumElem < 2) {
19072 // We know that candidate stores are in order and of correct
19073 // shape. While there is no mergeable sequence from the
19074 // beginning one may start later in the sequence. The only
19075 // reason a merge of size N could have failed where another of
19076 // the same size would not have, is if the alignment has
19077 // improved or we've dropped a non-zero value. Drop as many
19078 // candidates as we can here.
19079 unsigned NumSkip = 1;
19080 while ((NumSkip < NumConsecutiveStores) &&
19081 (NumSkip < FirstZeroAfterNonZero) &&
19082 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
19083 NumSkip++;
19084
19085 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
19086 NumConsecutiveStores -= NumSkip;
19087 continue;
19088 }
19089
19090 // Check that we can merge these candidates without causing a cycle.
19091 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
19092 RootNode)) {
19093 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
19094 NumConsecutiveStores -= NumElem;
19095 continue;
19096 }
19097
19098 MadeChange |= mergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumElem,
19099 /*IsConstantSrc*/ true,
19100 UseVector, UseTrunc);
19101
19102 // Remove merged stores for next iteration.
19103 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
19104 NumConsecutiveStores -= NumElem;
19105 }
19106 return MadeChange;
19107 }
19108
19109 bool DAGCombiner::tryStoreMergeOfExtracts(
19110 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
19111 EVT MemVT, SDNode *RootNode) {
19112 LLVMContext &Context = *DAG.getContext();
19113 const DataLayout &DL = DAG.getDataLayout();
19114 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
19115 bool MadeChange = false;
19116
19117 // Loop on Consecutive Stores on success.
19118 while (NumConsecutiveStores >= 2) {
19119 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
19120 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
19121 Align FirstStoreAlign = FirstInChain->getAlign();
19122 unsigned NumStoresToMerge = 1;
19123 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
19124 // Find a legal type for the vector store.
19125 unsigned Elts = (i + 1) * NumMemElts;
19126 EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
19127 unsigned IsFast = 0;
19128
19129 // Break early when size is too large to be legal.
19130 if (Ty.getSizeInBits() > MaximumLegalStoreInBits)
19131 break;
19132
19133 if (TLI.isTypeLegal(Ty) &&
19134 TLI.canMergeStoresTo(FirstStoreAS, Ty, DAG.getMachineFunction()) &&
19135 TLI.allowsMemoryAccess(Context, DL, Ty,
19136 *FirstInChain->getMemOperand(), &IsFast) &&
19137 IsFast)
19138 NumStoresToMerge = i + 1;
19139 }
19140
19141     // Check if we found a legal vector type that creates a meaningful
19142     // merge.
19143 if (NumStoresToMerge < 2) {
19144 // We know that candidate stores are in order and of correct
19145 // shape. While there is no mergeable sequence from the
19146 // beginning one may start later in the sequence. The only
19147 // reason a merge of size N could have failed where another of
19148 // the same size would not have, is if the alignment has
19149 // improved. Drop as many candidates as we can here.
19150 unsigned NumSkip = 1;
19151 while ((NumSkip < NumConsecutiveStores) &&
19152 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
19153 NumSkip++;
19154
19155 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
19156 NumConsecutiveStores -= NumSkip;
19157 continue;
19158 }
19159
19160 // Check that we can merge these candidates without causing a cycle.
19161 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStoresToMerge,
19162 RootNode)) {
19163 StoreNodes.erase(StoreNodes.begin(),
19164 StoreNodes.begin() + NumStoresToMerge);
19165 NumConsecutiveStores -= NumStoresToMerge;
19166 continue;
19167 }
19168
19169 MadeChange |= mergeStoresOfConstantsOrVecElts(
19170 StoreNodes, MemVT, NumStoresToMerge, /*IsConstantSrc*/ false,
19171 /*UseVector*/ true, /*UseTrunc*/ false);
19172
19173 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumStoresToMerge);
19174 NumConsecutiveStores -= NumStoresToMerge;
19175 }
19176 return MadeChange;
19177 }
19178
19179 bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
19180 unsigned NumConsecutiveStores, EVT MemVT,
19181 SDNode *RootNode, bool AllowVectors,
19182 bool IsNonTemporalStore,
19183 bool IsNonTemporalLoad) {
19184 LLVMContext &Context = *DAG.getContext();
19185 const DataLayout &DL = DAG.getDataLayout();
19186 int64_t ElementSizeBytes = MemVT.getStoreSize();
19187 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
19188 bool MadeChange = false;
19189
19190 // Look for load nodes which are used by the stored values.
19191 SmallVector<MemOpLink, 8> LoadNodes;
19192
19193 // Find acceptable loads. Loads need to have the same chain (token factor),
19194 // must not be zext, volatile, indexed, and they must be consecutive.
19195 BaseIndexOffset LdBasePtr;
19196
19197 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
19198 StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
19199 SDValue Val = peekThroughBitcasts(St->getValue());
19200 LoadSDNode *Ld = cast<LoadSDNode>(Val);
19201
19202 BaseIndexOffset LdPtr = BaseIndexOffset::match(Ld, DAG);
19203 // If this is not the first ptr that we check.
19204 int64_t LdOffset = 0;
19205 if (LdBasePtr.getBase().getNode()) {
19206 // The base ptr must be the same.
19207 if (!LdBasePtr.equalBaseIndex(LdPtr, DAG, LdOffset))
19208 break;
19209 } else {
19210 // Check that all other base pointers are the same as this one.
19211 LdBasePtr = LdPtr;
19212 }
19213
19214 // We found a potential memory operand to merge.
19215 LoadNodes.push_back(MemOpLink(Ld, LdOffset));
19216 }
19217
19218 while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
19219 Align RequiredAlignment;
19220 bool NeedRotate = false;
19221 if (LoadNodes.size() == 2) {
19222 // If we have load/store pair instructions and we only have two values,
19223 // don't bother merging.
19224 if (TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
19225 StoreNodes[0].MemNode->getAlign() >= RequiredAlignment) {
19226 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 2);
19227 LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + 2);
19228 break;
19229 }
19230 // If the loads are reversed, see if we can rotate the halves into place.
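      // For example, {st p[0] = ld q[1]; st p[1] = ld q[0]} can still be
      // merged as one wide load from q, a rotate by half the width and one
      // wide store to p, provided ROTL or ROTR is available for that type.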
19231 int64_t Offset0 = LoadNodes[0].OffsetFromBase;
19232 int64_t Offset1 = LoadNodes[1].OffsetFromBase;
19233 EVT PairVT = EVT::getIntegerVT(Context, ElementSizeBytes * 8 * 2);
19234 if (Offset0 - Offset1 == ElementSizeBytes &&
19235 (hasOperation(ISD::ROTL, PairVT) ||
19236 hasOperation(ISD::ROTR, PairVT))) {
19237 std::swap(LoadNodes[0], LoadNodes[1]);
19238 NeedRotate = true;
19239 }
19240 }
19241 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
19242 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
19243 Align FirstStoreAlign = FirstInChain->getAlign();
19244 LoadSDNode *FirstLoad = cast<LoadSDNode>(LoadNodes[0].MemNode);
19245
19246 // Scan the memory operations on the chain and find the first
19247 // non-consecutive load memory address. These variables hold the index in
19248 // the store node array.
19249
19250 unsigned LastConsecutiveLoad = 1;
19251
19252 // This variable refers to the size and not index in the array.
19253 unsigned LastLegalVectorType = 1;
19254 unsigned LastLegalIntegerType = 1;
19255 bool isDereferenceable = true;
19256 bool DoIntegerTruncate = false;
19257 int64_t StartAddress = LoadNodes[0].OffsetFromBase;
19258 SDValue LoadChain = FirstLoad->getChain();
19259 for (unsigned i = 1; i < LoadNodes.size(); ++i) {
19260 // All loads must share the same chain.
19261 if (LoadNodes[i].MemNode->getChain() != LoadChain)
19262 break;
19263
19264 int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
19265 if (CurrAddress - StartAddress != (ElementSizeBytes * i))
19266 break;
19267 LastConsecutiveLoad = i;
19268
19269 if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
19270 isDereferenceable = false;
19271
19272 // Find a legal type for the vector store.
19273 unsigned Elts = (i + 1) * NumMemElts;
19274 EVT StoreTy = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
19275
19276 // Break early when size is too large to be legal.
19277 if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
19278 break;
19279
19280 unsigned IsFastSt = 0;
19281 unsigned IsFastLd = 0;
19282 // Don't try vector types if we need a rotate. We may still fail the
19283 // legality checks for the integer type, but we can't handle the rotate
19284 // case with vectors.
19285 // FIXME: We could use a shuffle in place of the rotate.
19286 if (!NeedRotate && TLI.isTypeLegal(StoreTy) &&
19287 TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
19288 DAG.getMachineFunction()) &&
19289 TLI.allowsMemoryAccess(Context, DL, StoreTy,
19290 *FirstInChain->getMemOperand(), &IsFastSt) &&
19291 IsFastSt &&
19292 TLI.allowsMemoryAccess(Context, DL, StoreTy,
19293 *FirstLoad->getMemOperand(), &IsFastLd) &&
19294 IsFastLd) {
19295 LastLegalVectorType = i + 1;
19296 }
19297
19298 // Find a legal type for the integer store.
19299 unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
19300 StoreTy = EVT::getIntegerVT(Context, SizeInBits);
19301 if (TLI.isTypeLegal(StoreTy) &&
19302 TLI.canMergeStoresTo(FirstStoreAS, StoreTy,
19303 DAG.getMachineFunction()) &&
19304 TLI.allowsMemoryAccess(Context, DL, StoreTy,
19305 *FirstInChain->getMemOperand(), &IsFastSt) &&
19306 IsFastSt &&
19307 TLI.allowsMemoryAccess(Context, DL, StoreTy,
19308 *FirstLoad->getMemOperand(), &IsFastLd) &&
19309 IsFastLd) {
19310 LastLegalIntegerType = i + 1;
19311 DoIntegerTruncate = false;
19312 // Or check whether a truncstore and extload is legal.
19313 } else if (TLI.getTypeAction(Context, StoreTy) ==
19314 TargetLowering::TypePromoteInteger) {
19315 EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy);
19316 if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
19317 TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy,
19318 DAG.getMachineFunction()) &&
19319 TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy, StoreTy) &&
19320 TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy, StoreTy) &&
19321 TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) &&
19322 TLI.allowsMemoryAccess(Context, DL, StoreTy,
19323 *FirstInChain->getMemOperand(), &IsFastSt) &&
19324 IsFastSt &&
19325 TLI.allowsMemoryAccess(Context, DL, StoreTy,
19326 *FirstLoad->getMemOperand(), &IsFastLd) &&
19327 IsFastLd) {
19328 LastLegalIntegerType = i + 1;
19329 DoIntegerTruncate = true;
19330 }
19331 }
19332 }
19333
19334 // Only use vector types if the vector type is larger than the integer
19335 // type. If they are the same, use integers.
19336 bool UseVectorTy =
19337 LastLegalVectorType > LastLegalIntegerType && AllowVectors;
19338 unsigned LastLegalType =
19339 std::max(LastLegalVectorType, LastLegalIntegerType);
19340
19341     // We add +1 here because the LastXXX variables refer to a location
19342     // (index) while NumElem refers to a count (array size).
19343 unsigned NumElem = std::min(NumConsecutiveStores, LastConsecutiveLoad + 1);
19344 NumElem = std::min(LastLegalType, NumElem);
19345 Align FirstLoadAlign = FirstLoad->getAlign();
19346
19347 if (NumElem < 2) {
19348 // We know that candidate stores are in order and of correct
19349 // shape. While there is no mergeable sequence from the
19350 // beginning one may start later in the sequence. The only
19351 // reason a merge of size N could have failed where another of
19352       // the same size would not have is if the alignment of either
19353 // the load or store has improved. Drop as many candidates as we
19354 // can here.
19355 unsigned NumSkip = 1;
19356 while ((NumSkip < LoadNodes.size()) &&
19357 (LoadNodes[NumSkip].MemNode->getAlign() <= FirstLoadAlign) &&
19358 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
19359 NumSkip++;
19360 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumSkip);
19361 LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumSkip);
19362 NumConsecutiveStores -= NumSkip;
19363 continue;
19364 }
19365
19366 // Check that we can merge these candidates without causing a cycle.
19367 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumElem,
19368 RootNode)) {
19369 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
19370 LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
19371 NumConsecutiveStores -= NumElem;
19372 continue;
19373 }
19374
19375 // Find if it is better to use vectors or integers to load and store
19376 // to memory.
19377 EVT JointMemOpVT;
19378 if (UseVectorTy) {
19379 // Find a legal type for the vector store.
19380 unsigned Elts = NumElem * NumMemElts;
19381 JointMemOpVT = EVT::getVectorVT(Context, MemVT.getScalarType(), Elts);
19382 } else {
19383 unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
19384 JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits);
19385 }
19386
19387 SDLoc LoadDL(LoadNodes[0].MemNode);
19388 SDLoc StoreDL(StoreNodes[0].MemNode);
19389
19390 // The merged loads are required to have the same incoming chain, so
19391 // using the first's chain is acceptable.
19392
19393 SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumElem);
19394 AddToWorklist(NewStoreChain.getNode());
19395
19396 MachineMemOperand::Flags LdMMOFlags =
19397 isDereferenceable ? MachineMemOperand::MODereferenceable
19398 : MachineMemOperand::MONone;
19399 if (IsNonTemporalLoad)
19400 LdMMOFlags |= MachineMemOperand::MONonTemporal;
19401
19402 MachineMemOperand::Flags StMMOFlags = IsNonTemporalStore
19403 ? MachineMemOperand::MONonTemporal
19404 : MachineMemOperand::MONone;
19405
19406 SDValue NewLoad, NewStore;
19407 if (UseVectorTy || !DoIntegerTruncate) {
19408 NewLoad = DAG.getLoad(
19409 JointMemOpVT, LoadDL, FirstLoad->getChain(), FirstLoad->getBasePtr(),
19410 FirstLoad->getPointerInfo(), FirstLoadAlign, LdMMOFlags);
19411 SDValue StoreOp = NewLoad;
19412 if (NeedRotate) {
19413 unsigned LoadWidth = ElementSizeBytes * 8 * 2;
19414 assert(JointMemOpVT == EVT::getIntegerVT(Context, LoadWidth) &&
19415 "Unexpected type for rotate-able load pair");
19416 SDValue RotAmt =
19417 DAG.getShiftAmountConstant(LoadWidth / 2, JointMemOpVT, LoadDL);
19418 // Target can convert to the identical ROTR if it does not have ROTL.
19419 StoreOp = DAG.getNode(ISD::ROTL, LoadDL, JointMemOpVT, NewLoad, RotAmt);
19420 }
19421 NewStore = DAG.getStore(
19422 NewStoreChain, StoreDL, StoreOp, FirstInChain->getBasePtr(),
19423 FirstInChain->getPointerInfo(), FirstStoreAlign, StMMOFlags);
19424 } else { // This must be the truncstore/extload case
19425 EVT ExtendedTy =
19426 TLI.getTypeToTransformTo(*DAG.getContext(), JointMemOpVT);
19427 NewLoad = DAG.getExtLoad(ISD::EXTLOAD, LoadDL, ExtendedTy,
19428 FirstLoad->getChain(), FirstLoad->getBasePtr(),
19429 FirstLoad->getPointerInfo(), JointMemOpVT,
19430 FirstLoadAlign, LdMMOFlags);
19431 NewStore = DAG.getTruncStore(
19432 NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(),
19433 FirstInChain->getPointerInfo(), JointMemOpVT,
19434 FirstInChain->getAlign(), FirstInChain->getMemOperand()->getFlags());
19435 }
19436
19437 // Transfer chain users from old loads to the new load.
19438 for (unsigned i = 0; i < NumElem; ++i) {
19439 LoadSDNode *Ld = cast<LoadSDNode>(LoadNodes[i].MemNode);
19440 DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1),
19441 SDValue(NewLoad.getNode(), 1));
19442 }
19443
19444 // Replace all stores with the new store. Recursively remove corresponding
19445 // values if they are no longer used.
19446 for (unsigned i = 0; i < NumElem; ++i) {
19447 SDValue Val = StoreNodes[i].MemNode->getOperand(1);
19448 CombineTo(StoreNodes[i].MemNode, NewStore);
19449 if (Val->use_empty())
19450 recursivelyDeleteUnusedNodes(Val.getNode());
19451 }
19452
19453 MadeChange = true;
19454 StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + NumElem);
19455 LoadNodes.erase(LoadNodes.begin(), LoadNodes.begin() + NumElem);
19456 NumConsecutiveStores -= NumElem;
19457 }
19458 return MadeChange;
19459 }
19460
19461 bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) {
19462 if (OptLevel == CodeGenOpt::None || !EnableStoreMerging)
19463 return false;
19464
19465 // TODO: Extend this function to merge stores of scalable vectors.
19466 // (i.e. two <vscale x 8 x i8> stores can be merged to one <vscale x 16 x i8>
19467 // store since we know <vscale x 16 x i8> is exactly twice as large as
19468 // <vscale x 8 x i8>). Until then, bail out for scalable vectors.
19469 EVT MemVT = St->getMemoryVT();
19470 if (MemVT.isScalableVector())
19471 return false;
19472 if (!MemVT.isSimple() || MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits)
19473 return false;
19474
19475 // This function cannot currently deal with non-byte-sized memory sizes.
19476 int64_t ElementSizeBytes = MemVT.getStoreSize();
19477 if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits())
19478 return false;
19479
19480 // Do not bother looking at stored values that are not constants, loads, or
19481 // extracted vector elements.
19482 SDValue StoredVal = peekThroughBitcasts(St->getValue());
19483 const StoreSource StoreSrc = getStoreSource(StoredVal);
19484 if (StoreSrc == StoreSource::Unknown)
19485 return false;
19486
19487 SmallVector<MemOpLink, 8> StoreNodes;
19488 SDNode *RootNode;
19489   // Find potential store merge candidates by searching through the chain sub-DAG.
19490 getStoreMergeCandidates(St, StoreNodes, RootNode);
19491
19492 // Check if there is anything to merge.
19493 if (StoreNodes.size() < 2)
19494 return false;
19495
19496 // Sort the memory operands according to their distance from the
19497 // base pointer.
19498 llvm::sort(StoreNodes, [](MemOpLink LHS, MemOpLink RHS) {
19499 return LHS.OffsetFromBase < RHS.OffsetFromBase;
19500 });
19501
19502 bool AllowVectors = !DAG.getMachineFunction().getFunction().hasFnAttribute(
19503 Attribute::NoImplicitFloat);
19504 bool IsNonTemporalStore = St->isNonTemporal();
19505 bool IsNonTemporalLoad = StoreSrc == StoreSource::Load &&
19506 cast<LoadSDNode>(StoredVal)->isNonTemporal();
19507
19508 // Store Merge attempts to merge the lowest stores. This generally
19509 // works out as if successful, as the remaining stores are checked
19510 // after the first collection of stores is merged. However, in the
19511 // case that a non-mergeable store is found first, e.g., {p[-2],
19512 // p[0], p[1], p[2], p[3]}, we would fail and miss the subsequent
19513 // mergeable cases. To prevent this, we prune such stores from the
19514 // front of StoreNodes here.
19515 bool MadeChange = false;
19516 while (StoreNodes.size() > 1) {
19517 unsigned NumConsecutiveStores =
19518 getConsecutiveStores(StoreNodes, ElementSizeBytes);
19519 // There are no more stores in the list to examine.
19520 if (NumConsecutiveStores == 0)
19521 return MadeChange;
19522
19523 // We have at least 2 consecutive stores. Try to merge them.
19524 assert(NumConsecutiveStores >= 2 && "Expected at least 2 stores");
19525 switch (StoreSrc) {
19526 case StoreSource::Constant:
19527 MadeChange |= tryStoreMergeOfConstants(StoreNodes, NumConsecutiveStores,
19528 MemVT, RootNode, AllowVectors);
19529 break;
19530
19531 case StoreSource::Extract:
19532 MadeChange |= tryStoreMergeOfExtracts(StoreNodes, NumConsecutiveStores,
19533 MemVT, RootNode);
19534 break;
19535
19536 case StoreSource::Load:
19537 MadeChange |= tryStoreMergeOfLoads(StoreNodes, NumConsecutiveStores,
19538 MemVT, RootNode, AllowVectors,
19539 IsNonTemporalStore, IsNonTemporalLoad);
19540 break;
19541
19542 default:
19543 llvm_unreachable("Unhandled store source type");
19544 }
19545 }
19546 return MadeChange;
19547 }
19548
19549 SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
19550 SDLoc SL(ST);
19551 SDValue ReplStore;
19552
19553 // Replace the chain to avoid dependency.
19554 if (ST->isTruncatingStore()) {
19555 ReplStore = DAG.getTruncStore(BetterChain, SL, ST->getValue(),
19556 ST->getBasePtr(), ST->getMemoryVT(),
19557 ST->getMemOperand());
19558 } else {
19559 ReplStore = DAG.getStore(BetterChain, SL, ST->getValue(), ST->getBasePtr(),
19560 ST->getMemOperand());
19561 }
19562
19563 // Create token to keep both nodes around.
19564 SDValue Token = DAG.getNode(ISD::TokenFactor, SL,
19565 MVT::Other, ST->getChain(), ReplStore);
19566
19567 // Make sure the new and old chains are cleaned up.
19568 AddToWorklist(Token.getNode());
19569
19570 // Don't add users to work list.
19571 return CombineTo(ST, Token, false);
19572 }
19573
19574 SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
19575 SDValue Value = ST->getValue();
19576 if (Value.getOpcode() == ISD::TargetConstantFP)
19577 return SDValue();
19578
19579 if (!ISD::isNormalStore(ST))
19580 return SDValue();
19581
19582 SDLoc DL(ST);
19583
19584 SDValue Chain = ST->getChain();
19585 SDValue Ptr = ST->getBasePtr();
19586
19587 const ConstantFPSDNode *CFP = cast<ConstantFPSDNode>(Value);
19588
19589 // NOTE: If the original store is volatile, this transform must not increase
19590 // the number of stores. For example, on x86-32 an f64 can be stored in one
19591 // processor operation but an i64 (which is not legal) requires two. So the
19592 // transform should not be done in this case.
19593
19594 SDValue Tmp;
19595 switch (CFP->getSimpleValueType(0).SimpleTy) {
19596 default:
19597 llvm_unreachable("Unknown FP type");
19598 case MVT::f16: // We don't do this for these yet.
19599 case MVT::bf16:
19600 case MVT::f80:
19601 case MVT::f128:
19602 case MVT::ppcf128:
19603 return SDValue();
19604 case MVT::f32:
19605 if ((isTypeLegal(MVT::i32) && !LegalOperations && ST->isSimple()) ||
19606 TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
19607 Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF().
19608 bitcastToAPInt().getZExtValue(), SDLoc(CFP),
19609 MVT::i32);
19610 return DAG.getStore(Chain, DL, Tmp, Ptr, ST->getMemOperand());
19611 }
19612
19613 return SDValue();
19614 case MVT::f64:
19615 if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations &&
19616 ST->isSimple()) ||
19617 TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) {
19618 Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt().
19619 getZExtValue(), SDLoc(CFP), MVT::i64);
19620 return DAG.getStore(Chain, DL, Tmp,
19621 Ptr, ST->getMemOperand());
19622 }
19623
19624 if (ST->isSimple() &&
19625 TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
19626 // Many FP stores are not made apparent until after legalize, e.g. for
19627 // argument passing. Since this is so common, custom legalize the
19628 // 64-bit integer store into two 32-bit stores.
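      // For example, storing the f64 constant 1.0 becomes two i32 stores of
      // 0x00000000 and 0x3FF00000, with the low word stored first on
      // little-endian targets.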
19629 uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue();
19630 SDValue Lo = DAG.getConstant(Val & 0xFFFFFFFF, SDLoc(CFP), MVT::i32);
19631 SDValue Hi = DAG.getConstant(Val >> 32, SDLoc(CFP), MVT::i32);
19632 if (DAG.getDataLayout().isBigEndian())
19633 std::swap(Lo, Hi);
19634
19635 MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
19636 AAMDNodes AAInfo = ST->getAAInfo();
19637
19638 SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
19639 ST->getOriginalAlign(), MMOFlags, AAInfo);
19640 Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(4), DL);
19641 SDValue St1 = DAG.getStore(Chain, DL, Hi, Ptr,
19642 ST->getPointerInfo().getWithOffset(4),
19643 ST->getOriginalAlign(), MMOFlags, AAInfo);
19644 return DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
19645 St0, St1);
19646 }
19647
19648 return SDValue();
19649 }
19650 }
19651
19652 SDValue DAGCombiner::visitSTORE(SDNode *N) {
19653 StoreSDNode *ST = cast<StoreSDNode>(N);
19654 SDValue Chain = ST->getChain();
19655 SDValue Value = ST->getValue();
19656 SDValue Ptr = ST->getBasePtr();
19657
19658 // If this is a store of a bit convert, store the input value if the
19659 // resultant store does not need a higher alignment than the original.
19660 if (Value.getOpcode() == ISD::BITCAST && !ST->isTruncatingStore() &&
19661 ST->isUnindexed()) {
19662 EVT SVT = Value.getOperand(0).getValueType();
19663 // If the store is volatile, we only want to change the store type if the
19664 // resulting store is legal. Otherwise we might increase the number of
19665 // memory accesses. We don't care if the original type was legal or not
19666 // as we assume software couldn't rely on the number of accesses of an
19667 // illegal type.
19668 // TODO: May be able to relax for unordered atomics (see D66309)
19669 if (((!LegalOperations && ST->isSimple()) ||
19670 TLI.isOperationLegal(ISD::STORE, SVT)) &&
19671 TLI.isStoreBitCastBeneficial(Value.getValueType(), SVT,
19672 DAG, *ST->getMemOperand())) {
19673 return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
19674 ST->getMemOperand());
19675 }
19676 }
19677
19678 // Turn 'store undef, Ptr' -> nothing.
19679 if (Value.isUndef() && ST->isUnindexed())
19680 return Chain;
19681
19682 // Try to infer better alignment information than the store already has.
19683 if (OptLevel != CodeGenOpt::None && ST->isUnindexed() && !ST->isAtomic()) {
19684 if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
19685 if (*Alignment > ST->getAlign() &&
19686 isAligned(*Alignment, ST->getSrcValueOffset())) {
19687 SDValue NewStore =
19688 DAG.getTruncStore(Chain, SDLoc(N), Value, Ptr, ST->getPointerInfo(),
19689 ST->getMemoryVT(), *Alignment,
19690 ST->getMemOperand()->getFlags(), ST->getAAInfo());
19691 // NewStore will always be N as we are only refining the alignment
19692 assert(NewStore.getNode() == N);
19693 (void)NewStore;
19694 }
19695 }
19696 }
19697
19698   // Try transforming a pair of floating point load / store ops to integer
19699   // load / store ops.
19700 if (SDValue NewST = TransformFPLoadStorePair(N))
19701 return NewST;
19702
19703 // Try transforming several stores into STORE (BSWAP).
19704 if (SDValue Store = mergeTruncStores(ST))
19705 return Store;
19706
19707 if (ST->isUnindexed()) {
19708 // Walk up chain skipping non-aliasing memory nodes, on this store and any
19709 // adjacent stores.
19710 if (findBetterNeighborChains(ST)) {
19711 // replaceStoreChain uses CombineTo, which handled all of the worklist
19712 // manipulation. Return the original node to not do anything else.
19713 return SDValue(ST, 0);
19714 }
19715 Chain = ST->getChain();
19716 }
19717
19718 // FIXME: is there such a thing as a truncating indexed store?
19719 if (ST->isTruncatingStore() && ST->isUnindexed() &&
19720 Value.getValueType().isInteger() &&
19721 (!isa<ConstantSDNode>(Value) ||
19722 !cast<ConstantSDNode>(Value)->isOpaque())) {
19723 // Convert a truncating store of a extension into a standard store.
19724 if ((Value.getOpcode() == ISD::ZERO_EXTEND ||
19725 Value.getOpcode() == ISD::SIGN_EXTEND ||
19726 Value.getOpcode() == ISD::ANY_EXTEND) &&
19727 Value.getOperand(0).getValueType() == ST->getMemoryVT() &&
19728 TLI.isOperationLegalOrCustom(ISD::STORE, ST->getMemoryVT()))
19729 return DAG.getStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
19730 ST->getMemOperand());
19731
19732 APInt TruncDemandedBits =
19733 APInt::getLowBitsSet(Value.getScalarValueSizeInBits(),
19734 ST->getMemoryVT().getScalarSizeInBits());
19735
19736 // See if we can simplify the operation with SimplifyDemandedBits, which
19737 // only works if the value has a single use.
19738 AddToWorklist(Value.getNode());
19739 if (SimplifyDemandedBits(Value, TruncDemandedBits)) {
19740       // Re-visit the store if anything changed and the store hasn't been merged
19741       // with another node (N is deleted). SimplifyDemandedBits will add Value's
19742       // node back to the worklist if necessary, but we also need to re-visit
19743       // the Store node itself.
19744 if (N->getOpcode() != ISD::DELETED_NODE)
19745 AddToWorklist(N);
19746 return SDValue(N, 0);
19747 }
19748
19749 // Otherwise, see if we can simplify the input to this truncstore with
19750 // knowledge that only the low bits are being used. For example:
19751 // "truncstore (or (shl x, 8), y), i8" -> "truncstore y, i8"
19752 if (SDValue Shorter =
19753 TLI.SimplifyMultipleUseDemandedBits(Value, TruncDemandedBits, DAG))
19754 return DAG.getTruncStore(Chain, SDLoc(N), Shorter, Ptr, ST->getMemoryVT(),
19755 ST->getMemOperand());
19756
19757 // If we're storing a truncated constant, see if we can simplify it.
19758 // TODO: Move this to targetShrinkDemandedConstant?
19759 if (auto *Cst = dyn_cast<ConstantSDNode>(Value))
19760 if (!Cst->isOpaque()) {
19761 const APInt &CValue = Cst->getAPIntValue();
19762 APInt NewVal = CValue & TruncDemandedBits;
19763 if (NewVal != CValue) {
19764 SDValue Shorter =
19765 DAG.getConstant(NewVal, SDLoc(N), Value.getValueType());
19766 return DAG.getTruncStore(Chain, SDLoc(N), Shorter, Ptr,
19767 ST->getMemoryVT(), ST->getMemOperand());
19768 }
19769 }
19770 }
19771
19772 // If this is a load followed by a store to the same location, then the store
19773 // is dead/noop.
19774 // TODO: Can relax for unordered atomics (see D66309)
19775 if (LoadSDNode *Ld = dyn_cast<LoadSDNode>(Value)) {
19776 if (Ld->getBasePtr() == Ptr && ST->getMemoryVT() == Ld->getMemoryVT() &&
19777 ST->isUnindexed() && ST->isSimple() &&
19778 Ld->getAddressSpace() == ST->getAddressSpace() &&
19779 // There can't be any side effects between the load and store, such as
19780 // a call or store.
19781 Chain.reachesChainWithoutSideEffects(SDValue(Ld, 1))) {
19782 // The store is dead, remove it.
19783 return Chain;
19784 }
19785 }
19786
19787 // TODO: Can relax for unordered atomics (see D66309)
19788 if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Chain)) {
19789 if (ST->isUnindexed() && ST->isSimple() &&
19790 ST1->isUnindexed() && ST1->isSimple()) {
19791 if (OptLevel != CodeGenOpt::None && ST1->getBasePtr() == Ptr &&
19792 ST1->getValue() == Value && ST->getMemoryVT() == ST1->getMemoryVT() &&
19793 ST->getAddressSpace() == ST1->getAddressSpace()) {
19794 // If this is a store followed by a store with the same value to the
19795 // same location, then the store is dead/noop.
19796 return Chain;
19797 }
19798
19799 if (OptLevel != CodeGenOpt::None && ST1->hasOneUse() &&
19800 !ST1->getBasePtr().isUndef() &&
19801           // BaseIndexOffset and the code below require knowing the size
19802           // of a vector, so bail out if MemoryVT is scalable.
19803 !ST->getMemoryVT().isScalableVector() &&
19804 !ST1->getMemoryVT().isScalableVector() &&
19805 ST->getAddressSpace() == ST1->getAddressSpace()) {
19806 const BaseIndexOffset STBase = BaseIndexOffset::match(ST, DAG);
19807 const BaseIndexOffset ChainBase = BaseIndexOffset::match(ST1, DAG);
19808 unsigned STBitSize = ST->getMemoryVT().getFixedSizeInBits();
19809 unsigned ChainBitSize = ST1->getMemoryVT().getFixedSizeInBits();
19810         // If the preceding store writes to a subset of the current store's
19811         // location and no other node is chained to that store, we can
19812         // effectively drop the preceding store. Do not remove stores to undef
19813         // as they may be used as data sinks.
19814 if (STBase.contains(DAG, STBitSize, ChainBase, ChainBitSize)) {
19815 CombineTo(ST1, ST1->getChain());
19816 return SDValue();
19817 }
19818 }
19819 }
19820 }
19821
19822 // If this is an FP_ROUND or TRUNC followed by a store, fold this into a
19823 // truncating store. We can do this even if this is already a truncstore.
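  // Illustrative sketch (hypothetical types, not in the original source):
  //   (store (fp_round f64:X to f32), p) --> (truncstore f64:X, p, f32)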
19824 if ((Value.getOpcode() == ISD::FP_ROUND ||
19825 Value.getOpcode() == ISD::TRUNCATE) &&
19826 Value->hasOneUse() && ST->isUnindexed() &&
19827 TLI.canCombineTruncStore(Value.getOperand(0).getValueType(),
19828 ST->getMemoryVT(), LegalOperations)) {
19829 return DAG.getTruncStore(Chain, SDLoc(N), Value.getOperand(0),
19830 Ptr, ST->getMemoryVT(), ST->getMemOperand());
19831 }
19832
19833 // Always perform this optimization before types are legal. If the target
19834 // prefers, also try this after legalization to catch stores that were created
19835 // by intrinsics or other nodes.
19836 if (!LegalTypes || (TLI.mergeStoresAfterLegalization(ST->getMemoryVT()))) {
19837 while (true) {
19838 // There can be multiple store sequences on the same chain.
19839 // Keep trying to merge store sequences until we are unable to do so
19840 // or until we merge the last store on the chain.
19841 bool Changed = mergeConsecutiveStores(ST);
19842 if (!Changed) break;
19843 // Return N as merge only uses CombineTo and no worklist clean
19844 // up is necessary.
19845 if (N->getOpcode() == ISD::DELETED_NODE || !isa<StoreSDNode>(N))
19846 return SDValue(N, 0);
19847 }
19848 }
19849
19850 // Try transforming N to an indexed store.
19851 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
19852 return SDValue(N, 0);
19853
19854 // Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr'
19855 //
19856 // Make sure to do this only after attempting to merge stores in order to
19857 // avoid changing the types of some subset of stores due to visit order,
19858 // preventing their merging.
19859 if (isa<ConstantFPSDNode>(ST->getValue())) {
19860 if (SDValue NewSt = replaceStoreOfFPConstant(ST))
19861 return NewSt;
19862 }
19863
19864 if (SDValue NewSt = splitMergedValStore(ST))
19865 return NewSt;
19866
19867 return ReduceLoadOpStoreWidth(N);
19868 }
19869
19870 SDValue DAGCombiner::visitLIFETIME_END(SDNode *N) {
19871 const auto *LifetimeEnd = cast<LifetimeSDNode>(N);
19872 if (!LifetimeEnd->hasOffset())
19873 return SDValue();
19874
19875 const BaseIndexOffset LifetimeEndBase(N->getOperand(1), SDValue(),
19876 LifetimeEnd->getOffset(), false);
19877
19878 // We walk up the chains to find stores.
19879 SmallVector<SDValue, 8> Chains = {N->getOperand(0)};
19880 while (!Chains.empty()) {
19881 SDValue Chain = Chains.pop_back_val();
19882 if (!Chain.hasOneUse())
19883 continue;
19884 switch (Chain.getOpcode()) {
19885 case ISD::TokenFactor:
19886 for (unsigned Nops = Chain.getNumOperands(); Nops;)
19887 Chains.push_back(Chain.getOperand(--Nops));
19888 break;
19889 case ISD::LIFETIME_START:
19890 case ISD::LIFETIME_END:
19891 // We can forward past any lifetime start/end that can be proven not to
19892 // alias the node.
19893 if (!mayAlias(Chain.getNode(), N))
19894 Chains.push_back(Chain.getOperand(0));
19895 break;
19896 case ISD::STORE: {
19897 StoreSDNode *ST = dyn_cast<StoreSDNode>(Chain);
19898 // TODO: Can relax for unordered atomics (see D66309)
19899 if (!ST->isSimple() || ST->isIndexed())
19900 continue;
19901 const TypeSize StoreSize = ST->getMemoryVT().getStoreSize();
19902 // The bounds of a scalable store are not known until runtime, so this
19903 // store cannot be elided.
19904 if (StoreSize.isScalable())
19905 continue;
19906 const BaseIndexOffset StoreBase = BaseIndexOffset::match(ST, DAG);
19907 // If we store purely within object bounds just before its lifetime ends,
19908 // we can remove the store.
19909 if (LifetimeEndBase.contains(DAG, LifetimeEnd->getSize() * 8, StoreBase,
19910 StoreSize.getFixedValue() * 8)) {
19911 LLVM_DEBUG(dbgs() << "\nRemoving store:"; StoreBase.dump();
19912 dbgs() << "\nwithin LIFETIME_END of : ";
19913 LifetimeEndBase.dump(); dbgs() << "\n");
19914 CombineTo(ST, ST->getChain());
19915 return SDValue(N, 0);
19916 }
19917 }
19918 }
19919 }
19920 return SDValue();
19921 }
19922
19923 /// For the store instruction sequence below, the F and I values
19924 /// are bundled together as an i64 value before being stored into memory.
19925 /// Sometimes it is more efficient to generate separate stores for F and I,
19926 /// which can remove the bitwise instructions or sink them to colder places.
19927 ///
19928 /// (store (or (zext (bitcast F to i32) to i64),
19929 /// (shl (zext I to i64), 32)), addr) -->
19930 /// (store F, addr) and (store I, addr+4)
19931 ///
19932 /// Similarly, splitting for other merged store can also be beneficial, like:
19933 /// For pair of {i32, i32}, i64 store --> two i32 stores.
19934 /// For pair of {i32, i16}, i64 store --> two i32 stores.
19935 /// For pair of {i16, i16}, i32 store --> two i16 stores.
19936 /// For pair of {i16, i8}, i32 store --> two i16 stores.
19937 /// For pair of {i8, i8}, i16 store --> two i8 stores.
19938 ///
19939 /// We allow each target to determine specifically which kind of splitting is
19940 /// supported.
19941 ///
19942 /// The store patterns are commonly seen in the simple code snippet below
19943 /// when only std::make_pair(...) is SROA-transformed before being inlined into hoo.
19944 /// void goo(const std::pair<int, float> &);
19945 /// hoo() {
19946 /// ...
19947 /// goo(std::make_pair(tmp, ftmp));
19948 /// ...
19949 /// }
19950 ///
19951 SDValue DAGCombiner::splitMergedValStore(StoreSDNode *ST) {
19952 if (OptLevel == CodeGenOpt::None)
19953 return SDValue();
19954
19955 // Can't change the number of memory accesses for a volatile store or break
19956 // atomicity for an atomic one.
19957 if (!ST->isSimple())
19958 return SDValue();
19959
19960 SDValue Val = ST->getValue();
19961 SDLoc DL(ST);
19962
19963 // Match OR operand.
19964 if (!Val.getValueType().isScalarInteger() || Val.getOpcode() != ISD::OR)
19965 return SDValue();
19966
19967 // Match SHL operand and get Lower and Higher parts of Val.
19968 SDValue Op1 = Val.getOperand(0);
19969 SDValue Op2 = Val.getOperand(1);
19970 SDValue Lo, Hi;
19971 if (Op1.getOpcode() != ISD::SHL) {
19972 std::swap(Op1, Op2);
19973 if (Op1.getOpcode() != ISD::SHL)
19974 return SDValue();
19975 }
19976 Lo = Op2;
19977 Hi = Op1.getOperand(0);
19978 if (!Op1.hasOneUse())
19979 return SDValue();
19980
19981 // Match shift amount to HalfValBitSize.
19982 unsigned HalfValBitSize = Val.getValueSizeInBits() / 2;
19983 ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Op1.getOperand(1));
19984 if (!ShAmt || ShAmt->getAPIntValue() != HalfValBitSize)
19985 return SDValue();
19986
19987 // Lo and Hi are zero-extended from int with size less equal than 32
19988 // to i64.
19989 if (Lo.getOpcode() != ISD::ZERO_EXTEND || !Lo.hasOneUse() ||
19990 !Lo.getOperand(0).getValueType().isScalarInteger() ||
19991 Lo.getOperand(0).getValueSizeInBits() > HalfValBitSize ||
19992 Hi.getOpcode() != ISD::ZERO_EXTEND || !Hi.hasOneUse() ||
19993 !Hi.getOperand(0).getValueType().isScalarInteger() ||
19994 Hi.getOperand(0).getValueSizeInBits() > HalfValBitSize)
19995 return SDValue();
19996
19997 // Use the EVT of low and high parts before bitcast as the input
19998 // of target query.
19999 EVT LowTy = (Lo.getOperand(0).getOpcode() == ISD::BITCAST)
20000 ? Lo.getOperand(0).getValueType()
20001 : Lo.getValueType();
20002 EVT HighTy = (Hi.getOperand(0).getOpcode() == ISD::BITCAST)
20003 ? Hi.getOperand(0).getValueType()
20004 : Hi.getValueType();
20005 if (!TLI.isMultiStoresCheaperThanBitsMerge(LowTy, HighTy))
20006 return SDValue();
20007
20008 // Start to split store.
20009 MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
20010 AAMDNodes AAInfo = ST->getAAInfo();
20011
20012 // Change the sizes of Lo and Hi's value types to HalfValBitSize.
20013 EVT VT = EVT::getIntegerVT(*DAG.getContext(), HalfValBitSize);
20014 Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Lo.getOperand(0));
20015 Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Hi.getOperand(0));
20016
20017 SDValue Chain = ST->getChain();
20018 SDValue Ptr = ST->getBasePtr();
20019 // Lower value store.
20020 SDValue St0 = DAG.getStore(Chain, DL, Lo, Ptr, ST->getPointerInfo(),
20021 ST->getOriginalAlign(), MMOFlags, AAInfo);
20022 Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::Fixed(HalfValBitSize / 8), DL);
20023 // Higher value store.
20024 SDValue St1 = DAG.getStore(
20025 St0, DL, Hi, Ptr, ST->getPointerInfo().getWithOffset(HalfValBitSize / 8),
20026 ST->getOriginalAlign(), MMOFlags, AAInfo);
20027 return St1;
20028 }
20029
20030 // Merge an insertion into an existing shuffle:
20031 // (insert_vector_elt (vector_shuffle X, Y, Mask),
20032 //                      (extract_vector_elt X, N), InsIndex)
20033 // --> (vector_shuffle X, Y, NewMask)
20034 // and variations where shuffle operands may be CONCAT_VECTORS.
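// Illustrative sketch (hypothetical v4i32 operands, not in the original
// source): inserting (extract_vector_elt X, 2) at index 3 of
// (vector_shuffle X, Y, <0,4,1,5>) can be folded to
// (vector_shuffle X, Y, <0,4,1,2>).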
20035 static bool mergeEltWithShuffle(SDValue &X, SDValue &Y, ArrayRef<int> Mask,
20036 SmallVectorImpl<int> &NewMask, SDValue Elt,
20037 unsigned InsIndex) {
20038 if (Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
20039 !isa<ConstantSDNode>(Elt.getOperand(1)))
20040 return false;
20041
20042 // Vec's operand 0 is using indices from 0 to N-1 and
20043 // operand 1 from N to 2N - 1, where N is the number of
20044 // elements in the vectors.
20045 SDValue InsertVal0 = Elt.getOperand(0);
20046 int ElementOffset = -1;
20047
20048 // We explore the inputs of the shuffle in order to see if we find the
20049 // source of the extract_vector_elt. If so, we can use it to modify the
20050 // shuffle rather than perform an insert_vector_elt.
20051 SmallVector<std::pair<int, SDValue>, 8> ArgWorkList;
20052 ArgWorkList.emplace_back(Mask.size(), Y);
20053 ArgWorkList.emplace_back(0, X);
20054
20055 while (!ArgWorkList.empty()) {
20056 int ArgOffset;
20057 SDValue ArgVal;
20058 std::tie(ArgOffset, ArgVal) = ArgWorkList.pop_back_val();
20059
20060 if (ArgVal == InsertVal0) {
20061 ElementOffset = ArgOffset;
20062 break;
20063 }
20064
20065 // Peek through concat_vector.
20066 if (ArgVal.getOpcode() == ISD::CONCAT_VECTORS) {
20067 int CurrentArgOffset =
20068 ArgOffset + ArgVal.getValueType().getVectorNumElements();
20069 int Step = ArgVal.getOperand(0).getValueType().getVectorNumElements();
20070 for (SDValue Op : reverse(ArgVal->ops())) {
20071 CurrentArgOffset -= Step;
20072 ArgWorkList.emplace_back(CurrentArgOffset, Op);
20073 }
20074
20075 // Make sure we went through all the elements and did not screw up index
20076 // computation.
20077 assert(CurrentArgOffset == ArgOffset);
20078 }
20079 }
20080
20081 // If we failed to find a match, see if we can replace an UNDEF shuffle
20082 // operand.
20083 if (ElementOffset == -1) {
20084 if (!Y.isUndef() || InsertVal0.getValueType() != Y.getValueType())
20085 return false;
20086 ElementOffset = Mask.size();
20087 Y = InsertVal0;
20088 }
20089
20090 NewMask.assign(Mask.begin(), Mask.end());
20091 NewMask[InsIndex] = ElementOffset + Elt.getConstantOperandVal(1);
20092 assert(NewMask[InsIndex] < (int)(2 * Mask.size()) && NewMask[InsIndex] >= 0 &&
20093 "NewMask[InsIndex] is out of bound");
20094 return true;
20095 }
20096
20097 // Merge an insertion into an existing shuffle:
20098 // (insert_vector_elt (vector_shuffle X, Y), (extract_vector_elt X, N),
20099 // InsIndex)
20100 // --> (vector_shuffle X, Y) and variations where shuffle operands may be
20101 // CONCAT_VECTORS.
20102 SDValue DAGCombiner::mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex) {
20103 assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
20104          "Expected insert_vector_elt");
20105 SDValue InsertVal = N->getOperand(1);
20106 SDValue Vec = N->getOperand(0);
20107
20108 auto *SVN = dyn_cast<ShuffleVectorSDNode>(Vec);
20109 if (!SVN || !Vec.hasOneUse())
20110 return SDValue();
20111
20112 ArrayRef<int> Mask = SVN->getMask();
20113 SDValue X = Vec.getOperand(0);
20114 SDValue Y = Vec.getOperand(1);
20115
20116 SmallVector<int, 16> NewMask(Mask);
20117 if (mergeEltWithShuffle(X, Y, Mask, NewMask, InsertVal, InsIndex)) {
20118 SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
20119 Vec.getValueType(), SDLoc(N), X, Y, NewMask, DAG);
20120 if (LegalShuffle)
20121 return LegalShuffle;
20122 }
20123
20124 return SDValue();
20125 }
20126
20127 // Convert a disguised subvector insertion into a shuffle:
20128 // insert_vector_elt V, (bitcast X from vector type), IdxC -->
20129 // bitcast(shuffle (bitcast V), (extended X), Mask)
20130 // Note: We do not use an insert_subvector node because that requires a
20131 // legal subvector type.
20132 SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) {
20133 assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
20134          "Expected insert_vector_elt");
20135 SDValue InsertVal = N->getOperand(1);
20136
20137 if (InsertVal.getOpcode() != ISD::BITCAST || !InsertVal.hasOneUse() ||
20138 !InsertVal.getOperand(0).getValueType().isVector())
20139 return SDValue();
20140
20141 SDValue SubVec = InsertVal.getOperand(0);
20142 SDValue DestVec = N->getOperand(0);
20143 EVT SubVecVT = SubVec.getValueType();
20144 EVT VT = DestVec.getValueType();
20145 unsigned NumSrcElts = SubVecVT.getVectorNumElements();
20146   // If the source only has a single vector element, the cost of creating and
20147   // adding it to a vector is likely to exceed the cost of an insert_vector_elt.
20148 if (NumSrcElts == 1)
20149 return SDValue();
20150 unsigned ExtendRatio = VT.getSizeInBits() / SubVecVT.getSizeInBits();
20151 unsigned NumMaskVals = ExtendRatio * NumSrcElts;
20152
20153 // Step 1: Create a shuffle mask that implements this insert operation. The
20154 // vector that we are inserting into will be operand 0 of the shuffle, so
20155 // those elements are just 'i'. The inserted subvector is in the first
20156 // positions of operand 1 of the shuffle. Example:
20157 // insert v4i32 V, (v2i16 X), 2 --> shuffle v8i16 V', X', {0,1,2,3,8,9,6,7}
20158 SmallVector<int, 16> Mask(NumMaskVals);
20159 for (unsigned i = 0; i != NumMaskVals; ++i) {
20160 if (i / NumSrcElts == InsIndex)
20161 Mask[i] = (i % NumSrcElts) + NumMaskVals;
20162 else
20163 Mask[i] = i;
20164 }
20165
20166 // Bail out if the target can not handle the shuffle we want to create.
20167 EVT SubVecEltVT = SubVecVT.getVectorElementType();
20168 EVT ShufVT = EVT::getVectorVT(*DAG.getContext(), SubVecEltVT, NumMaskVals);
20169 if (!TLI.isShuffleMaskLegal(Mask, ShufVT))
20170 return SDValue();
20171
20172 // Step 2: Create a wide vector from the inserted source vector by appending
20173 // undefined elements. This is the same size as our destination vector.
20174 SDLoc DL(N);
20175 SmallVector<SDValue, 8> ConcatOps(ExtendRatio, DAG.getUNDEF(SubVecVT));
20176 ConcatOps[0] = SubVec;
20177 SDValue PaddedSubV = DAG.getNode(ISD::CONCAT_VECTORS, DL, ShufVT, ConcatOps);
20178
20179 // Step 3: Shuffle in the padded subvector.
20180 SDValue DestVecBC = DAG.getBitcast(ShufVT, DestVec);
20181 SDValue Shuf = DAG.getVectorShuffle(ShufVT, DL, DestVecBC, PaddedSubV, Mask);
20182 AddToWorklist(PaddedSubV.getNode());
20183 AddToWorklist(DestVecBC.getNode());
20184 AddToWorklist(Shuf.getNode());
20185 return DAG.getBitcast(VT, Shuf);
20186 }
20187
20188 SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
20189 SDValue InVec = N->getOperand(0);
20190 SDValue InVal = N->getOperand(1);
20191 SDValue EltNo = N->getOperand(2);
20192 SDLoc DL(N);
20193
20194 EVT VT = InVec.getValueType();
20195 auto *IndexC = dyn_cast<ConstantSDNode>(EltNo);
20196
20197 // Insert into out-of-bounds element is undefined.
20198 if (IndexC && VT.isFixedLengthVector() &&
20199 IndexC->getZExtValue() >= VT.getVectorNumElements())
20200 return DAG.getUNDEF(VT);
20201
20202 // Remove redundant insertions:
20203 // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x
20204 if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20205 InVec == InVal.getOperand(0) && EltNo == InVal.getOperand(1))
20206 return InVec;
20207
20208 if (!IndexC) {
20209     // If this is a variable insert into an undef vector, it might be better to splat:
20210 // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
20211 if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))
20212 return DAG.getSplat(VT, DL, InVal);
20213 return SDValue();
20214 }
20215
20216 if (VT.isScalableVector())
20217 return SDValue();
20218
20219 unsigned NumElts = VT.getVectorNumElements();
20220
20221 // We must know which element is being inserted for folds below here.
20222 unsigned Elt = IndexC->getZExtValue();
20223
20224 // Handle <1 x ???> vector insertion special cases.
20225 if (NumElts == 1) {
20226 // insert_vector_elt(x, extract_vector_elt(y, 0), 0) -> y
20227 if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20228 InVal.getOperand(0).getValueType() == VT &&
20229 isNullConstant(InVal.getOperand(1)))
20230 return InVal.getOperand(0);
20231 }
20232
20233 // Canonicalize insert_vector_elt dag nodes.
20234 // Example:
20235 // (insert_vector_elt (insert_vector_elt A, Idx0), Idx1)
20236 // -> (insert_vector_elt (insert_vector_elt A, Idx1), Idx0)
20237 //
20238 // Do this only if the child insert_vector node has one use; also
20239 // do this only if indices are both constants and Idx1 < Idx0.
20240 if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && InVec.hasOneUse()
20241 && isa<ConstantSDNode>(InVec.getOperand(2))) {
20242 unsigned OtherElt = InVec.getConstantOperandVal(2);
20243 if (Elt < OtherElt) {
20244 // Swap nodes.
20245 SDValue NewOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT,
20246 InVec.getOperand(0), InVal, EltNo);
20247 AddToWorklist(NewOp.getNode());
20248 return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(InVec.getNode()),
20249 VT, NewOp, InVec.getOperand(1), InVec.getOperand(2));
20250 }
20251 }
20252
20253 if (SDValue Shuf = mergeInsertEltWithShuffle(N, Elt))
20254 return Shuf;
20255
20256 if (SDValue Shuf = combineInsertEltToShuffle(N, Elt))
20257 return Shuf;
20258
20259 // Attempt to convert an insert_vector_elt chain into a legal build_vector.
20260 if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) {
20261 // vXi1 vector - we don't need to recurse.
20262 if (NumElts == 1)
20263 return DAG.getBuildVector(VT, DL, {InVal});
20264
20265 // If we haven't already collected the element, insert into the op list.
20266 EVT MaxEltVT = InVal.getValueType();
20267 auto AddBuildVectorOp = [&](SmallVectorImpl<SDValue> &Ops, SDValue Elt,
20268 unsigned Idx) {
20269 if (!Ops[Idx]) {
20270 Ops[Idx] = Elt;
20271 if (VT.isInteger()) {
20272 EVT EltVT = Elt.getValueType();
20273 MaxEltVT = MaxEltVT.bitsGE(EltVT) ? MaxEltVT : EltVT;
20274 }
20275 }
20276 };
20277
20278 // Ensure all the operands are the same value type, fill any missing
20279 // operands with UNDEF and create the BUILD_VECTOR.
20280 auto CanonicalizeBuildVector = [&](SmallVectorImpl<SDValue> &Ops) {
20281 assert(Ops.size() == NumElts && "Unexpected vector size");
20282 for (SDValue &Op : Ops) {
20283 if (Op)
20284 Op = VT.isInteger() ? DAG.getAnyExtOrTrunc(Op, DL, MaxEltVT) : Op;
20285 else
20286 Op = DAG.getUNDEF(MaxEltVT);
20287 }
20288 return DAG.getBuildVector(VT, DL, Ops);
20289 };
20290
20291 SmallVector<SDValue, 8> Ops(NumElts, SDValue());
20292 Ops[Elt] = InVal;
20293
20294 // Recurse up a INSERT_VECTOR_ELT chain to build a BUILD_VECTOR.
20295 for (SDValue CurVec = InVec; CurVec;) {
20296 // UNDEF - build new BUILD_VECTOR from already inserted operands.
20297 if (CurVec.isUndef())
20298 return CanonicalizeBuildVector(Ops);
20299
20300 // BUILD_VECTOR - insert unused operands and build new BUILD_VECTOR.
20301 if (CurVec.getOpcode() == ISD::BUILD_VECTOR && CurVec.hasOneUse()) {
20302 for (unsigned I = 0; I != NumElts; ++I)
20303 AddBuildVectorOp(Ops, CurVec.getOperand(I), I);
20304 return CanonicalizeBuildVector(Ops);
20305 }
20306
20307 // SCALAR_TO_VECTOR - insert unused scalar and build new BUILD_VECTOR.
20308 if (CurVec.getOpcode() == ISD::SCALAR_TO_VECTOR && CurVec.hasOneUse()) {
20309 AddBuildVectorOp(Ops, CurVec.getOperand(0), 0);
20310 return CanonicalizeBuildVector(Ops);
20311 }
20312
20313 // INSERT_VECTOR_ELT - insert operand and continue up the chain.
20314 if (CurVec.getOpcode() == ISD::INSERT_VECTOR_ELT && CurVec.hasOneUse())
20315 if (auto *CurIdx = dyn_cast<ConstantSDNode>(CurVec.getOperand(2)))
20316 if (CurIdx->getAPIntValue().ult(NumElts)) {
20317 unsigned Idx = CurIdx->getZExtValue();
20318 AddBuildVectorOp(Ops, CurVec.getOperand(1), Idx);
20319
20320 // Found entire BUILD_VECTOR.
20321 if (all_of(Ops, [](SDValue Op) { return !!Op; }))
20322 return CanonicalizeBuildVector(Ops);
20323
20324 CurVec = CurVec->getOperand(0);
20325 continue;
20326 }
20327
20328 // VECTOR_SHUFFLE - if all the operands match the shuffle's sources,
20329 // update the shuffle mask (and second operand if we started with unary
20330 // shuffle) and create a new legal shuffle.
20331 if (CurVec.getOpcode() == ISD::VECTOR_SHUFFLE && CurVec.hasOneUse()) {
20332 auto *SVN = cast<ShuffleVectorSDNode>(CurVec);
20333 SDValue LHS = SVN->getOperand(0);
20334 SDValue RHS = SVN->getOperand(1);
20335 SmallVector<int, 16> Mask(SVN->getMask());
20336 bool Merged = true;
20337 for (auto I : enumerate(Ops)) {
20338 SDValue &Op = I.value();
20339 if (Op) {
20340 SmallVector<int, 16> NewMask;
20341 if (!mergeEltWithShuffle(LHS, RHS, Mask, NewMask, Op, I.index())) {
20342 Merged = false;
20343 break;
20344 }
20345 Mask = std::move(NewMask);
20346 }
20347 }
20348 if (Merged)
20349 if (SDValue NewShuffle =
20350 TLI.buildLegalVectorShuffle(VT, DL, LHS, RHS, Mask, DAG))
20351 return NewShuffle;
20352 }
20353
20354 // Failed to find a match in the chain - bail.
20355 break;
20356 }
20357
20358 // See if we can fill in the missing constant elements as zeros.
20359 // TODO: Should we do this for any constant?
20360 APInt DemandedZeroElts = APInt::getZero(NumElts);
20361 for (unsigned I = 0; I != NumElts; ++I)
20362 if (!Ops[I])
20363 DemandedZeroElts.setBit(I);
20364
20365 if (DAG.MaskedVectorIsZero(InVec, DemandedZeroElts)) {
20366 SDValue Zero = VT.isInteger() ? DAG.getConstant(0, DL, MaxEltVT)
20367 : DAG.getConstantFP(0, DL, MaxEltVT);
20368 for (unsigned I = 0; I != NumElts; ++I)
20369 if (!Ops[I])
20370 Ops[I] = Zero;
20371
20372 return CanonicalizeBuildVector(Ops);
20373 }
20374 }
20375
20376 return SDValue();
20377 }
20378
20379 SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
20380 SDValue EltNo,
20381 LoadSDNode *OriginalLoad) {
20382 assert(OriginalLoad->isSimple());
20383
20384 EVT ResultVT = EVE->getValueType(0);
20385 EVT VecEltVT = InVecVT.getVectorElementType();
20386
20387 // If the vector element type is not a multiple of a byte then we are unable
20388 // to correctly compute an address to load only the extracted element as a
20389 // scalar.
20390 if (!VecEltVT.isByteSized())
20391 return SDValue();
20392
20393 ISD::LoadExtType ExtTy =
20394 ResultVT.bitsGT(VecEltVT) ? ISD::NON_EXTLOAD : ISD::EXTLOAD;
20395 if (!TLI.isOperationLegalOrCustom(ISD::LOAD, VecEltVT) ||
20396 !TLI.shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
20397 return SDValue();
20398
20399 Align Alignment = OriginalLoad->getAlign();
20400 MachinePointerInfo MPI;
20401 SDLoc DL(EVE);
20402 if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
20403 int Elt = ConstEltNo->getZExtValue();
20404 unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
20405 MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
20406 Alignment = commonAlignment(Alignment, PtrOff);
20407 } else {
20408 // Discard the pointer info except the address space because the memory
20409 // operand can't represent this new access since the offset is variable.
20410 MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
20411 Alignment = commonAlignment(Alignment, VecEltVT.getSizeInBits() / 8);
20412 }
20413
20414 unsigned IsFast = 0;
20415 if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VecEltVT,
20416 OriginalLoad->getAddressSpace(), Alignment,
20417 OriginalLoad->getMemOperand()->getFlags(),
20418 &IsFast) ||
20419 !IsFast)
20420 return SDValue();
20421
20422 SDValue NewPtr = TLI.getVectorElementPointer(DAG, OriginalLoad->getBasePtr(),
20423 InVecVT, EltNo);
20424
20425 // We are replacing a vector load with a scalar load. The new load must have
20426 // identical memory op ordering to the original.
20427 SDValue Load;
20428 if (ResultVT.bitsGT(VecEltVT)) {
20429 // If the result type of vextract is wider than the load, then issue an
20430 // extending load instead.
20431 ISD::LoadExtType ExtType =
20432 TLI.isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT) ? ISD::ZEXTLOAD
20433 : ISD::EXTLOAD;
20434 Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(),
20435 NewPtr, MPI, VecEltVT, Alignment,
20436 OriginalLoad->getMemOperand()->getFlags(),
20437 OriginalLoad->getAAInfo());
20438 DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
20439 } else {
20440 // The result type is narrower or the same width as the vector element
20441 Load = DAG.getLoad(VecEltVT, DL, OriginalLoad->getChain(), NewPtr, MPI,
20442 Alignment, OriginalLoad->getMemOperand()->getFlags(),
20443 OriginalLoad->getAAInfo());
20444 DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
20445 if (ResultVT.bitsLT(VecEltVT))
20446 Load = DAG.getNode(ISD::TRUNCATE, DL, ResultVT, Load);
20447 else
20448 Load = DAG.getBitcast(ResultVT, Load);
20449 }
20450 ++OpsNarrowed;
20451 return Load;
20452 }
20453
20454 /// Transform a vector binary operation into a scalar binary operation by moving
20455 /// the math/logic after an extract element of a vector.
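/// Illustrative sketch (hypothetical constants, not in the original source):
///   (extract_vector_elt (add X, <1,2,3,4>), 2)
///     --> (add (extract_vector_elt X, 2), 3)
/// since extracting element 2 of the constant vector folds to 3.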
20456 static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
20457 bool LegalOperations) {
20458 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
20459 SDValue Vec = ExtElt->getOperand(0);
20460 SDValue Index = ExtElt->getOperand(1);
20461 auto *IndexC = dyn_cast<ConstantSDNode>(Index);
20462 if (!IndexC || !TLI.isBinOp(Vec.getOpcode()) || !Vec.hasOneUse() ||
20463 Vec->getNumValues() != 1)
20464 return SDValue();
20465
20466 // Targets may want to avoid this to prevent an expensive register transfer.
20467 if (!TLI.shouldScalarizeBinop(Vec))
20468 return SDValue();
20469
20470 // Extracting an element of a vector constant is constant-folded, so this
20471 // transform is just replacing a vector op with a scalar op while moving the
20472 // extract.
20473 SDValue Op0 = Vec.getOperand(0);
20474 SDValue Op1 = Vec.getOperand(1);
20475 APInt SplatVal;
20476 if (isAnyConstantBuildVector(Op0, true) ||
20477 ISD::isConstantSplatVector(Op0.getNode(), SplatVal) ||
20478 isAnyConstantBuildVector(Op1, true) ||
20479 ISD::isConstantSplatVector(Op1.getNode(), SplatVal)) {
20480 // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
20481 // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
20482 SDLoc DL(ExtElt);
20483 EVT VT = ExtElt->getValueType(0);
20484 SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
20485 SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
20486 return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1);
20487 }
20488
20489 return SDValue();
20490 }
20491
20492 // Given an ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
20493 // recursively analyse all of its users and try to model them as
20494 // bit sequence extractions. If all of them agree on the new, narrower element
20495 // type, and all of them can be modelled as ISD::EXTRACT_VECTOR_ELT's of that
20496 // new element type, do so now.
20497 // This is mainly useful to recover from legalization that scalarized
20498 // the vector as wide elements, but tries to rebuild it with narrower elements.
20499 //
20500 // Some more nodes could be modelled if that helps cover interesting patterns.
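// Illustrative sketch (hypothetical types, little-endian, not in the original
// source): if (extract_vector_elt v2i64:X, 0) is only used as
// (trunc ... to i32) and (trunc (srl ..., 32) to i32), both uses can be
// rewritten as (extract_vector_elt (bitcast X to v4i32), 0) and
// (extract_vector_elt (bitcast X to v4i32), 1).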
20501 bool DAGCombiner::refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(
20502 SDNode *N) {
20503   // We perform this optimization post type-legalization because
20504   // the type-legalizer often scalarizes integer-promoted vectors.
20505   // Performing this optimization earlier may cause legalization cycles.
20506 if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
20507 return false;
20508
20509 // TODO: Add support for big-endian.
20510 if (DAG.getDataLayout().isBigEndian())
20511 return false;
20512
20513 SDValue VecOp = N->getOperand(0);
20514 EVT VecVT = VecOp.getValueType();
20515 assert(!VecVT.isScalableVector() && "Only for fixed vectors.");
20516
20517 // We must start with a constant extraction index.
20518 auto *IndexC = dyn_cast<ConstantSDNode>(N->getOperand(1));
20519 if (!IndexC)
20520 return false;
20521
20522 assert(IndexC->getZExtValue() < VecVT.getVectorNumElements() &&
20523          "Original ISD::EXTRACT_VECTOR_ELT is undefined?");
20524
20525 // TODO: deal with the case of implicit anyext of the extraction.
20526 unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
20527 EVT ScalarVT = N->getValueType(0);
20528 if (VecVT.getScalarType() != ScalarVT)
20529 return false;
20530
20531 // TODO: deal with the cases other than everything being integer-typed.
20532 if (!ScalarVT.isScalarInteger())
20533 return false;
20534
20535 struct Entry {
20536 SDNode *Producer;
20537
20538 // Which bits of VecOp does it contain?
20539 unsigned BitPos;
20540 int NumBits;
20541 // NOTE: the actual width of \p Producer may be wider than NumBits!
20542
20543 Entry(Entry &&) = default;
20544 Entry(SDNode *Producer_, unsigned BitPos_, int NumBits_)
20545 : Producer(Producer_), BitPos(BitPos_), NumBits(NumBits_) {}
20546
20547 Entry() = delete;
20548 Entry(const Entry &) = delete;
20549 Entry &operator=(const Entry &) = delete;
20550 Entry &operator=(Entry &&) = delete;
20551 };
20552 SmallVector<Entry, 32> Worklist;
20553 SmallVector<Entry, 32> Leafs;
20554
20555 // We start at the "root" ISD::EXTRACT_VECTOR_ELT.
20556 Worklist.emplace_back(N, /*BitPos=*/VecEltBitWidth * IndexC->getZExtValue(),
20557 /*NumBits=*/VecEltBitWidth);
20558
20559 while (!Worklist.empty()) {
20560 Entry E = Worklist.pop_back_val();
20561 // Does the node not even use any of the VecOp bits?
20562 if (!(E.NumBits > 0 && E.BitPos < VecVT.getSizeInBits() &&
20563 E.BitPos + E.NumBits <= VecVT.getSizeInBits()))
20564       return false; // Let the other combines clean this up first.
20565 // Did we fail to model any of the users of the Producer?
20566 bool ProducerIsLeaf = false;
20567 // Look at each user of this Producer.
20568 for (SDNode *User : E.Producer->uses()) {
20569 switch (User->getOpcode()) {
20570 // TODO: support ISD::BITCAST
20571 // TODO: support ISD::ANY_EXTEND
20572 // TODO: support ISD::ZERO_EXTEND
20573 // TODO: support ISD::SIGN_EXTEND
20574 case ISD::TRUNCATE:
20575 // Truncation simply means we keep position, but extract less bits.
20576 Worklist.emplace_back(User, E.BitPos,
20577 /*NumBits=*/User->getValueSizeInBits(0));
20578 break;
20579 // TODO: support ISD::SRA
20580 // TODO: support ISD::SHL
20581 case ISD::SRL:
20582 // We should be shifting the Producer by a constant amount.
20583 if (auto *ShAmtC = dyn_cast<ConstantSDNode>(User->getOperand(1));
20584 User->getOperand(0).getNode() == E.Producer && ShAmtC) {
20585 // Logical right-shift means that we start extraction later,
20586 // but stop it at the same position we did previously.
20587 unsigned ShAmt = ShAmtC->getZExtValue();
20588 Worklist.emplace_back(User, E.BitPos + ShAmt, E.NumBits - ShAmt);
20589 break;
20590 }
20591 [[fallthrough]];
20592 default:
20593 // We can not model this user of the Producer.
20594 // Which means the current Producer will be a ISD::EXTRACT_VECTOR_ELT.
20595 ProducerIsLeaf = true;
20596 // Profitability check: all users that we can not model
20597 // must be ISD::BUILD_VECTOR's.
20598 if (User->getOpcode() != ISD::BUILD_VECTOR)
20599 return false;
20600 break;
20601 }
20602 }
20603 if (ProducerIsLeaf)
20604 Leafs.emplace_back(std::move(E));
20605 }
20606
20607 unsigned NewVecEltBitWidth = Leafs.front().NumBits;
20608
20609   // If we are still at the same element granularity, give up.
20610 if (NewVecEltBitWidth == VecEltBitWidth)
20611 return false;
20612
20613 // The vector width must be a multiple of the new element width.
20614 if (VecVT.getSizeInBits() % NewVecEltBitWidth != 0)
20615 return false;
20616
20617   // All leafs must agree on the new element width.
20618   // All leafs must not expect any "padding" bits on top of that width.
20619   // All leafs must start extraction from a multiple of that width.
20620 if (!all_of(Leafs, [NewVecEltBitWidth](const Entry &E) {
20621 return (unsigned)E.NumBits == NewVecEltBitWidth &&
20622 E.Producer->getValueSizeInBits(0) == NewVecEltBitWidth &&
20623 E.BitPos % NewVecEltBitWidth == 0;
20624 }))
20625 return false;
20626
20627 EVT NewScalarVT = EVT::getIntegerVT(*DAG.getContext(), NewVecEltBitWidth);
20628 EVT NewVecVT = EVT::getVectorVT(*DAG.getContext(), NewScalarVT,
20629 VecVT.getSizeInBits() / NewVecEltBitWidth);
20630
20631 if (LegalTypes &&
20632 !(TLI.isTypeLegal(NewScalarVT) && TLI.isTypeLegal(NewVecVT)))
20633 return false;
20634
20635 if (LegalOperations &&
20636 !(TLI.isOperationLegalOrCustom(ISD::BITCAST, NewVecVT) &&
20637 TLI.isOperationLegalOrCustom(ISD::EXTRACT_VECTOR_ELT, NewVecVT)))
20638 return false;
20639
20640 SDValue NewVecOp = DAG.getBitcast(NewVecVT, VecOp);
20641 for (const Entry &E : Leafs) {
20642 SDLoc DL(E.Producer);
20643 unsigned NewIndex = E.BitPos / NewVecEltBitWidth;
20644 assert(NewIndex < NewVecVT.getVectorNumElements() &&
20645 "Creating out-of-bounds ISD::EXTRACT_VECTOR_ELT?");
20646 SDValue V = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, NewScalarVT, NewVecOp,
20647 DAG.getVectorIdxConstant(NewIndex, DL));
20648 CombineTo(E.Producer, V);
20649 }
20650
20651 return true;
20652 }
20653
20654 SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
20655 SDValue VecOp = N->getOperand(0);
20656 SDValue Index = N->getOperand(1);
20657 EVT ScalarVT = N->getValueType(0);
20658 EVT VecVT = VecOp.getValueType();
20659 if (VecOp.isUndef())
20660 return DAG.getUNDEF(ScalarVT);
20661
20662 // extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val
20663 //
20664 // This only really matters if the index is non-constant since other combines
20665 // on the constant elements already work.
20666 SDLoc DL(N);
20667 if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT &&
20668 Index == VecOp.getOperand(2)) {
20669 SDValue Elt = VecOp.getOperand(1);
20670 return VecVT.isInteger() ? DAG.getAnyExtOrTrunc(Elt, DL, ScalarVT) : Elt;
20671 }
20672
20673 // (vextract (scalar_to_vector val, 0) -> val
20674 if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) {
20675 // Only 0'th element of SCALAR_TO_VECTOR is defined.
20676 if (DAG.isKnownNeverZero(Index))
20677 return DAG.getUNDEF(ScalarVT);
20678
20679 // Check if the result type doesn't match the inserted element type. A
20680 // SCALAR_TO_VECTOR may truncate the inserted element and the
20681 // EXTRACT_VECTOR_ELT may widen the extracted vector.
20682 SDValue InOp = VecOp.getOperand(0);
20683 if (InOp.getValueType() != ScalarVT) {
20684 assert(InOp.getValueType().isInteger() && ScalarVT.isInteger() &&
20685 InOp.getValueType().bitsGT(ScalarVT));
20686 return DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, InOp);
20687 }
20688 return InOp;
20689 }
20690
20691 // extract_vector_elt of out-of-bounds element -> UNDEF
20692 auto *IndexC = dyn_cast<ConstantSDNode>(Index);
20693 if (IndexC && VecVT.isFixedLengthVector() &&
20694 IndexC->getAPIntValue().uge(VecVT.getVectorNumElements()))
20695 return DAG.getUNDEF(ScalarVT);
20696
20697 // extract_vector_elt(freeze(x)), idx -> freeze(extract_vector_elt(x)), idx
20698 if (VecOp.hasOneUse() && VecOp.getOpcode() == ISD::FREEZE) {
20699 return DAG.getFreeze(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT,
20700 VecOp.getOperand(0), Index));
20701 }
20702
20703 // extract_vector_elt (build_vector x, y), 1 -> y
20704 if (((IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR) ||
20705 VecOp.getOpcode() == ISD::SPLAT_VECTOR) &&
20706 TLI.isTypeLegal(VecVT) &&
20707 (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT))) {
20708 assert((VecOp.getOpcode() != ISD::BUILD_VECTOR ||
20709 VecVT.isFixedLengthVector()) &&
20710 "BUILD_VECTOR used for scalable vectors");
20711 unsigned IndexVal =
20712 VecOp.getOpcode() == ISD::BUILD_VECTOR ? IndexC->getZExtValue() : 0;
20713 SDValue Elt = VecOp.getOperand(IndexVal);
20714 EVT InEltVT = Elt.getValueType();
20715
20716 // Sometimes build_vector's scalar input types do not match result type.
20717 if (ScalarVT == InEltVT)
20718 return Elt;
20719
20720 // TODO: It may be useful to truncate if free if the build_vector implicitly
20721 // converts.
20722 }
20723
20724 if (SDValue BO = scalarizeExtractedBinop(N, DAG, LegalOperations))
20725 return BO;
20726
20727 if (VecVT.isScalableVector())
20728 return SDValue();
20729
20730 // All the code from this point onwards assumes fixed width vectors, but it's
20731 // possible that some of the combinations could be made to work for scalable
20732 // vectors too.
20733 unsigned NumElts = VecVT.getVectorNumElements();
20734 unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
20735
20736 // TODO: These transforms should not require the 'hasOneUse' restriction, but
20737 // there are regressions on multiple targets without it. We can end up with a
20738 // mess of scalar and vector code if we reduce only part of the DAG to scalar.
20739 if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() &&
20740 VecOp.hasOneUse()) {
20741     // The vector index of the LSBs of the source depends on the endianness.
20742 bool IsLE = DAG.getDataLayout().isLittleEndian();
20743 unsigned ExtractIndex = IndexC->getZExtValue();
20744 // extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x)
20745 unsigned BCTruncElt = IsLE ? 0 : NumElts - 1;
20746 SDValue BCSrc = VecOp.getOperand(0);
20747 if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger())
20748 return DAG.getAnyExtOrTrunc(BCSrc, DL, ScalarVT);
20749
20750 if (LegalTypes && BCSrc.getValueType().isInteger() &&
20751 BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR) {
20752 // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt -->
20753 // trunc i64 X to i32
20754 SDValue X = BCSrc.getOperand(0);
20755 assert(X.getValueType().isScalarInteger() && ScalarVT.isScalarInteger() &&
20756 "Extract element and scalar to vector can't change element type "
20757 "from FP to integer.");
20758 unsigned XBitWidth = X.getValueSizeInBits();
20759 BCTruncElt = IsLE ? 0 : XBitWidth / VecEltBitWidth - 1;
20760
20761 // An extract element return value type can be wider than its vector
20762 // operand element type. In that case, the high bits are undefined, so
20763 // it's possible that we may need to extend rather than truncate.
20764 if (ExtractIndex == BCTruncElt && XBitWidth > VecEltBitWidth) {
20765 assert(XBitWidth % VecEltBitWidth == 0 &&
20766 "Scalar bitwidth must be a multiple of vector element bitwidth");
20767 return DAG.getAnyExtOrTrunc(X, DL, ScalarVT);
20768 }
20769 }
20770 }
20771
20772 // Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT.
20773 // We only perform this optimization before the op legalization phase because
20774 // we may introduce new vector instructions which are not backed by TD
20775 // patterns. For example on AVX, extracting elements from a wide vector
20776 // without using extract_subvector. However, if we can find an underlying
20777 // scalar value, then we can always use that.
20778 if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
20779 auto *Shuf = cast<ShuffleVectorSDNode>(VecOp);
20780 // Find the new index to extract from.
20781 int OrigElt = Shuf->getMaskElt(IndexC->getZExtValue());
20782
20783 // Extracting an undef index is undef.
20784 if (OrigElt == -1)
20785 return DAG.getUNDEF(ScalarVT);
20786
20787 // Select the right vector half to extract from.
20788 SDValue SVInVec;
20789 if (OrigElt < (int)NumElts) {
20790 SVInVec = VecOp.getOperand(0);
20791 } else {
20792 SVInVec = VecOp.getOperand(1);
20793 OrigElt -= NumElts;
20794 }
20795
20796 if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) {
20797 SDValue InOp = SVInVec.getOperand(OrigElt);
20798 if (InOp.getValueType() != ScalarVT) {
20799 assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
20800 InOp = DAG.getSExtOrTrunc(InOp, DL, ScalarVT);
20801 }
20802
20803 return InOp;
20804 }
20805
20806 // FIXME: We should handle recursing on other vector shuffles and
20807 // scalar_to_vector here as well.
20808
20809 if (!LegalOperations ||
20810 // FIXME: Should really be just isOperationLegalOrCustom.
20811 TLI.isOperationLegal(ISD::EXTRACT_VECTOR_ELT, VecVT) ||
20812 TLI.isOperationExpand(ISD::VECTOR_SHUFFLE, VecVT)) {
20813 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, SVInVec,
20814 DAG.getVectorIdxConstant(OrigElt, DL));
20815 }
20816 }
20817
20818 // If only EXTRACT_VECTOR_ELT nodes use the source vector we can
20819 // simplify it based on the (valid) extraction indices.
20820 if (llvm::all_of(VecOp->uses(), [&](SDNode *Use) {
20821 return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20822 Use->getOperand(0) == VecOp &&
20823 isa<ConstantSDNode>(Use->getOperand(1));
20824 })) {
20825 APInt DemandedElts = APInt::getZero(NumElts);
20826 for (SDNode *Use : VecOp->uses()) {
20827 auto *CstElt = cast<ConstantSDNode>(Use->getOperand(1));
20828 if (CstElt->getAPIntValue().ult(NumElts))
20829 DemandedElts.setBit(CstElt->getZExtValue());
20830 }
20831 if (SimplifyDemandedVectorElts(VecOp, DemandedElts, true)) {
20832 // We simplified the vector operand of this extract element. If this
20833 // extract is not dead, visit it again so it is folded properly.
20834 if (N->getOpcode() != ISD::DELETED_NODE)
20835 AddToWorklist(N);
20836 return SDValue(N, 0);
20837 }
20838 APInt DemandedBits = APInt::getAllOnes(VecEltBitWidth);
20839 if (SimplifyDemandedBits(VecOp, DemandedBits, DemandedElts, true)) {
20840 // We simplified the vector operand of this extract element. If this
20841 // extract is not dead, visit it again so it is folded properly.
20842 if (N->getOpcode() != ISD::DELETED_NODE)
20843 AddToWorklist(N);
20844 return SDValue(N, 0);
20845 }
20846 }
20847
20848 if (refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(N))
20849 return SDValue(N, 0);
20850
20851 // Everything under here is trying to match an extract of a loaded value.
20852   // If the result of the load has to be truncated, then it's not necessarily
20853 // profitable.
20854 bool BCNumEltsChanged = false;
20855 EVT ExtVT = VecVT.getVectorElementType();
20856 EVT LVT = ExtVT;
20857 if (ScalarVT.bitsLT(LVT) && !TLI.isTruncateFree(LVT, ScalarVT))
20858 return SDValue();
20859
20860 if (VecOp.getOpcode() == ISD::BITCAST) {
20861 // Don't duplicate a load with other uses.
20862 if (!VecOp.hasOneUse())
20863 return SDValue();
20864
20865 EVT BCVT = VecOp.getOperand(0).getValueType();
20866 if (!BCVT.isVector() || ExtVT.bitsGT(BCVT.getVectorElementType()))
20867 return SDValue();
20868 if (NumElts != BCVT.getVectorNumElements())
20869 BCNumEltsChanged = true;
20870 VecOp = VecOp.getOperand(0);
20871 ExtVT = BCVT.getVectorElementType();
20872 }
20873
20874 // extract (vector load $addr), i --> load $addr + i * size
20875 if (!LegalOperations && !IndexC && VecOp.hasOneUse() &&
20876 ISD::isNormalLoad(VecOp.getNode()) &&
20877 !Index->hasPredecessor(VecOp.getNode())) {
20878 auto *VecLoad = dyn_cast<LoadSDNode>(VecOp);
20879 if (VecLoad && VecLoad->isSimple())
20880 return scalarizeExtractedVectorLoad(N, VecVT, Index, VecLoad);
20881 }
20882
20883 // Perform only after legalization to ensure build_vector / vector_shuffle
20884 // optimizations have already been done.
20885 if (!LegalOperations || !IndexC)
20886 return SDValue();
20887
20888 // (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size)
20889 // (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size)
20890 // (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr)
20891 int Elt = IndexC->getZExtValue();
20892 LoadSDNode *LN0 = nullptr;
20893 if (ISD::isNormalLoad(VecOp.getNode())) {
20894 LN0 = cast<LoadSDNode>(VecOp);
20895 } else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR &&
20896 VecOp.getOperand(0).getValueType() == ExtVT &&
20897 ISD::isNormalLoad(VecOp.getOperand(0).getNode())) {
20898 // Don't duplicate a load with other uses.
20899 if (!VecOp.hasOneUse())
20900 return SDValue();
20901
20902 LN0 = cast<LoadSDNode>(VecOp.getOperand(0));
20903 }
20904 if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(VecOp)) {
20905 // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
20906 // =>
20907 // (load $addr+1*size)
20908
20909 // Don't duplicate a load with other uses.
20910 if (!VecOp.hasOneUse())
20911 return SDValue();
20912
20913 // If the bit convert changed the number of elements, it is unsafe
20914 // to examine the mask.
20915 if (BCNumEltsChanged)
20916 return SDValue();
20917
20918 // Select the input vector, guarding against out of range extract vector.
20919 int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Elt);
20920 VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(0) : VecOp.getOperand(1);
20921
20922 if (VecOp.getOpcode() == ISD::BITCAST) {
20923 // Don't duplicate a load with other uses.
20924 if (!VecOp.hasOneUse())
20925 return SDValue();
20926
20927 VecOp = VecOp.getOperand(0);
20928 }
20929 if (ISD::isNormalLoad(VecOp.getNode())) {
20930 LN0 = cast<LoadSDNode>(VecOp);
20931 Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts;
20932 Index = DAG.getConstant(Elt, DL, Index.getValueType());
20933 }
20934 } else if (VecOp.getOpcode() == ISD::CONCAT_VECTORS && !BCNumEltsChanged &&
20935 VecVT.getVectorElementType() == ScalarVT &&
20936 (!LegalTypes ||
20937 TLI.isTypeLegal(
20938 VecOp.getOperand(0).getValueType().getVectorElementType()))) {
20939 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 0
20940 // -> extract_vector_elt a, 0
20941 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 1
20942 // -> extract_vector_elt a, 1
20943 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 2
20944 // -> extract_vector_elt b, 0
20945 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 3
20946 // -> extract_vector_elt b, 1
20947 SDLoc SL(N);
20948 EVT ConcatVT = VecOp.getOperand(0).getValueType();
20949 unsigned ConcatNumElts = ConcatVT.getVectorNumElements();
20950 SDValue NewIdx = DAG.getConstant(Elt % ConcatNumElts, SL,
20951 Index.getValueType());
20952
20953 SDValue ConcatOp = VecOp.getOperand(Elt / ConcatNumElts);
20954 SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL,
20955 ConcatVT.getVectorElementType(),
20956 ConcatOp, NewIdx);
20957 return DAG.getNode(ISD::BITCAST, SL, ScalarVT, Elt);
20958 }
20959
20960 // Make sure we found a non-volatile load and the extractelement is
20961 // the only use.
20962 if (!LN0 || !LN0->hasNUsesOfValue(1,0) || !LN0->isSimple())
20963 return SDValue();
20964
20965 // If Idx was -1 above, Elt is going to be -1, so just return undef.
20966 if (Elt == -1)
20967 return DAG.getUNDEF(LVT);
20968
20969 return scalarizeExtractedVectorLoad(N, VecVT, Index, LN0);
20970 }
20971
20972 // Simplify (build_vec (ext )) to (bitcast (build_vec ))
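// Illustrative sketch (hypothetical types, little-endian, not in the original
// source):
//   (v2i32 build_vector (zext i16:a), (zext i16:b))
//     --> (bitcast (v4i16 build_vector a, 0, b, 0) to v2i32)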
20973 SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
20974 // We perform this optimization post type-legalization because
20975 // the type-legalizer often scalarizes integer-promoted vectors.
20976 // Performing this optimization before may create bit-casts which
20977 // will be type-legalized to complex code sequences.
20978 // We perform this optimization only before the operation legalizer because we
20979 // may introduce illegal operations.
20980 if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
20981 return SDValue();
20982
20983 unsigned NumInScalars = N->getNumOperands();
20984 SDLoc DL(N);
20985 EVT VT = N->getValueType(0);
20986
20987 // Check to see if this is a BUILD_VECTOR of a bunch of values
20988 // which come from any_extend or zero_extend nodes. If so, we can create
20989 // a new BUILD_VECTOR using bit-casts which may enable other BUILD_VECTOR
20990 // optimizations. We do not handle sign-extend because we can't fill the sign
20991 // using shuffles.
20992 EVT SourceType = MVT::Other;
20993 bool AllAnyExt = true;
20994
20995 for (unsigned i = 0; i != NumInScalars; ++i) {
20996 SDValue In = N->getOperand(i);
20997 // Ignore undef inputs.
20998 if (In.isUndef()) continue;
20999
21000 bool AnyExt = In.getOpcode() == ISD::ANY_EXTEND;
21001 bool ZeroExt = In.getOpcode() == ISD::ZERO_EXTEND;
21002
21003 // Abort if the element is not an extension.
21004 if (!ZeroExt && !AnyExt) {
21005 SourceType = MVT::Other;
21006 break;
21007 }
21008
21009 // The input is a ZeroExt or AnyExt. Check the original type.
21010 EVT InTy = In.getOperand(0).getValueType();
21011
21012 // Check that all of the widened source types are the same.
21013 if (SourceType == MVT::Other)
21014 // First time.
21015 SourceType = InTy;
21016 else if (InTy != SourceType) {
21017       // Multiple incoming types. Abort.
21018 SourceType = MVT::Other;
21019 break;
21020 }
21021
21022 // Check if all of the extends are ANY_EXTENDs.
21023 AllAnyExt &= AnyExt;
21024 }
21025
21026 // In order to have valid types, all of the inputs must be extended from the
21027 // same source type and all of the inputs must be any or zero extend.
21028 // Scalar sizes must be a power of two.
21029 EVT OutScalarTy = VT.getScalarType();
21030 bool ValidTypes = SourceType != MVT::Other &&
21031 isPowerOf2_32(OutScalarTy.getSizeInBits()) &&
21032 isPowerOf2_32(SourceType.getSizeInBits());
21033
21034 // Create a new simpler BUILD_VECTOR sequence which other optimizations can
21035 // turn into a single shuffle instruction.
21036 if (!ValidTypes)
21037 return SDValue();
21038
21039 // If we already have a splat buildvector, then don't fold it if it means
21040 // introducing zeros.
21041 if (!AllAnyExt && DAG.isSplatValue(SDValue(N, 0), /*AllowUndefs*/ true))
21042 return SDValue();
21043
21044 bool isLE = DAG.getDataLayout().isLittleEndian();
21045 unsigned ElemRatio = OutScalarTy.getSizeInBits()/SourceType.getSizeInBits();
21046 assert(ElemRatio > 1 && "Invalid element size ratio");
21047 SDValue Filler = AllAnyExt ? DAG.getUNDEF(SourceType):
21048 DAG.getConstant(0, DL, SourceType);
21049
21050 unsigned NewBVElems = ElemRatio * VT.getVectorNumElements();
21051 SmallVector<SDValue, 8> Ops(NewBVElems, Filler);
21052
21053 // Populate the new build_vector
21054 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
21055 SDValue Cast = N->getOperand(i);
21056 assert((Cast.getOpcode() == ISD::ANY_EXTEND ||
21057 Cast.getOpcode() == ISD::ZERO_EXTEND ||
21058 Cast.isUndef()) && "Invalid cast opcode");
21059 SDValue In;
21060 if (Cast.isUndef())
21061 In = DAG.getUNDEF(SourceType);
21062 else
21063 In = Cast->getOperand(0);
21064 unsigned Index = isLE ? (i * ElemRatio) :
21065 (i * ElemRatio + (ElemRatio - 1));
21066
21067 assert(Index < Ops.size() && "Invalid index");
21068 Ops[Index] = In;
21069 }
21070
21071 // The type of the new BUILD_VECTOR node.
21072 EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SourceType, NewBVElems);
21073 assert(VecVT.getSizeInBits() == VT.getSizeInBits() &&
21074 "Invalid vector size");
21075 // Check if the new vector type is legal.
21076 if (!isTypeLegal(VecVT) ||
21077 (!TLI.isOperationLegal(ISD::BUILD_VECTOR, VecVT) &&
21078 TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)))
21079 return SDValue();
21080
21081 // Make the new BUILD_VECTOR.
21082 SDValue BV = DAG.getBuildVector(VecVT, DL, Ops);
21083
21084 // The new BUILD_VECTOR node has the potential to be further optimized.
21085 AddToWorklist(BV.getNode());
21086 // Bitcast to the desired type.
21087 return DAG.getBitcast(VT, BV);
21088 }
21089
21090 // Simplify (build_vec (trunc $1)
21091 // (trunc (srl $1 half-width))
21092 // (trunc (srl $1 (2 * half-width))))
21093 // to (bitcast $1)
21094 SDValue DAGCombiner::reduceBuildVecTruncToBitCast(SDNode *N) {
21095 assert(N->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
21096
21097 // Only for little endian
21098 if (!DAG.getDataLayout().isLittleEndian())
21099 return SDValue();
21100
21101 SDLoc DL(N);
21102 EVT VT = N->getValueType(0);
21103 EVT OutScalarTy = VT.getScalarType();
21104 uint64_t ScalarTypeBitsize = OutScalarTy.getSizeInBits();
21105
21106   // Only for power-of-two types, to be sure that the bitcast works well
21107 if (!isPowerOf2_64(ScalarTypeBitsize))
21108 return SDValue();
21109
21110 unsigned NumInScalars = N->getNumOperands();
21111
21112 // Look through bitcasts
21113 auto PeekThroughBitcast = [](SDValue Op) {
21114 if (Op.getOpcode() == ISD::BITCAST)
21115 return Op.getOperand(0);
21116 return Op;
21117 };
21118
21119 // The source value where all the parts are extracted.
21120 SDValue Src;
21121 for (unsigned i = 0; i != NumInScalars; ++i) {
21122 SDValue In = PeekThroughBitcast(N->getOperand(i));
21123 // Ignore undef inputs.
21124 if (In.isUndef()) continue;
21125
21126 if (In.getOpcode() != ISD::TRUNCATE)
21127 return SDValue();
21128
21129 In = PeekThroughBitcast(In.getOperand(0));
21130
21131 if (In.getOpcode() != ISD::SRL) {
21132       // For now, only handle build_vec without shuffling; handle shifts here
21133       // in the future.
21134 if (i != 0)
21135 return SDValue();
21136
21137 Src = In;
21138 } else {
21139 // In is SRL
21140 SDValue part = PeekThroughBitcast(In.getOperand(0));
21141
21142 if (!Src) {
21143 Src = part;
21144 } else if (Src != part) {
21145 // Vector parts do not stem from the same variable
21146 return SDValue();
21147 }
21148
21149 SDValue ShiftAmtVal = In.getOperand(1);
21150 if (!isa<ConstantSDNode>(ShiftAmtVal))
21151 return SDValue();
21152
21153 uint64_t ShiftAmt = In.getConstantOperandVal(1);
21154
21155 // The extracted value is not extracted at the right position
21156 if (ShiftAmt != i * ScalarTypeBitsize)
21157 return SDValue();
21158 }
21159 }
21160
21161 // Only cast if the size is the same
21162 if (Src.getValueType().getSizeInBits() != VT.getSizeInBits())
21163 return SDValue();
21164
21165 return DAG.getBitcast(VT, Src);
21166 }
21167
21168 SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N,
21169 ArrayRef<int> VectorMask,
21170 SDValue VecIn1, SDValue VecIn2,
21171 unsigned LeftIdx, bool DidSplitVec) {
21172 SDValue ZeroIdx = DAG.getVectorIdxConstant(0, DL);
21173
21174 EVT VT = N->getValueType(0);
21175 EVT InVT1 = VecIn1.getValueType();
21176 EVT InVT2 = VecIn2.getNode() ? VecIn2.getValueType() : InVT1;
21177
21178 unsigned NumElems = VT.getVectorNumElements();
21179 unsigned ShuffleNumElems = NumElems;
21180
21181 // If we artificially split a vector in two already, then the offsets in the
21182 // operands will all be based off of VecIn1, even those in VecIn2.
21183 unsigned Vec2Offset = DidSplitVec ? 0 : InVT1.getVectorNumElements();
21184
21185 uint64_t VTSize = VT.getFixedSizeInBits();
21186 uint64_t InVT1Size = InVT1.getFixedSizeInBits();
21187 uint64_t InVT2Size = InVT2.getFixedSizeInBits();
21188
21189 assert(InVT2Size <= InVT1Size &&
21190 "Inputs must be sorted to be in non-increasing vector size order.");
21191
21192 // We can't generate a shuffle node with mismatched input and output types.
21193 // Try to make the types match the type of the output.
21194 if (InVT1 != VT || InVT2 != VT) {
21195 if ((VTSize % InVT1Size == 0) && InVT1 == InVT2) {
21196 // If the output vector length is a multiple of both input lengths,
21197 // we can concatenate them and pad the rest with undefs.
21198 unsigned NumConcats = VTSize / InVT1Size;
21199 assert(NumConcats >= 2 && "Concat needs at least two inputs!");
21200 SmallVector<SDValue, 2> ConcatOps(NumConcats, DAG.getUNDEF(InVT1));
21201 ConcatOps[0] = VecIn1;
21202 ConcatOps[1] = VecIn2 ? VecIn2 : DAG.getUNDEF(InVT1);
21203 VecIn1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
21204 VecIn2 = SDValue();
21205 } else if (InVT1Size == VTSize * 2) {
21206 if (!TLI.isExtractSubvectorCheap(VT, InVT1, NumElems))
21207 return SDValue();
21208
21209 if (!VecIn2.getNode()) {
21210 // If we only have one input vector, and it's twice the size of the
21211 // output, split it in two.
21212 VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1,
21213 DAG.getVectorIdxConstant(NumElems, DL));
21214 VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, VecIn1, ZeroIdx);
21215 // Since we now have shorter input vectors, adjust the offset of the
21216 // second vector's start.
21217 Vec2Offset = NumElems;
21218 } else {
21219 assert(InVT2Size <= InVT1Size &&
21220 "Second input is not going to be larger than the first one.");
21221
21222 // VecIn1 is wider than the output, and we have another, possibly
21223 // smaller input. Pad the smaller input with undefs, shuffle at the
21224 // input vector width, and extract the output.
21225 // The shuffle type is different than VT, so check legality again.
21226 if (LegalOperations &&
21227 !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, InVT1))
21228 return SDValue();
21229
21230 // Legalizing INSERT_SUBVECTOR is tricky - you basically have to
21231 // lower it back into a BUILD_VECTOR. So if the inserted type is
21232 // illegal, don't even try.
21233 if (InVT1 != InVT2) {
21234 if (!TLI.isTypeLegal(InVT2))
21235 return SDValue();
21236 VecIn2 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InVT1,
21237 DAG.getUNDEF(InVT1), VecIn2, ZeroIdx);
21238 }
21239 ShuffleNumElems = NumElems * 2;
21240 }
21241 } else if (InVT2Size * 2 == VTSize && InVT1Size == VTSize) {
21242 SmallVector<SDValue, 2> ConcatOps(2, DAG.getUNDEF(InVT2));
21243 ConcatOps[0] = VecIn2;
21244 VecIn2 = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
21245 } else if (InVT1Size / VTSize > 1 && InVT1Size % VTSize == 0) {
21246 if (!TLI.isExtractSubvectorCheap(VT, InVT1, NumElems) ||
21247 !TLI.isTypeLegal(InVT1) || !TLI.isTypeLegal(InVT2))
21248 return SDValue();
21249       // If the destination vector has fewer than two elements, then using a
21250       // shuffle and extract from larger registers will cost even more.
21251 if (VT.getVectorNumElements() <= 2 || !VecIn2.getNode())
21252 return SDValue();
21253 assert(InVT2Size <= InVT1Size &&
21254 "Second input is not going to be larger than the first one.");
21255
21256 // VecIn1 is wider than the output, and we have another, possibly
21257 // smaller input. Pad the smaller input with undefs, shuffle at the
21258 // input vector width, and extract the output.
21259 // The shuffle type is different than VT, so check legality again.
21260 if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, InVT1))
21261 return SDValue();
21262
21263 if (InVT1 != InVT2) {
21264 VecIn2 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, InVT1,
21265 DAG.getUNDEF(InVT1), VecIn2, ZeroIdx);
21266 }
21267 ShuffleNumElems = InVT1Size / VTSize * NumElems;
21268 } else {
21269 // TODO: Support cases where the length mismatch isn't exactly by a
21270 // factor of 2.
21271 // TODO: Move this check upwards, so that if we have bad type
21272 // mismatches, we don't create any DAG nodes.
21273 return SDValue();
21274 }
21275 }
21276
21277 // Initialize mask to undef.
21278 SmallVector<int, 8> Mask(ShuffleNumElems, -1);
21279
21280 // Only need to run up to the number of elements actually used, not the
21281 // total number of elements in the shuffle - if we are shuffling a wider
21282 // vector, the high lanes should be set to undef.
21283 for (unsigned i = 0; i != NumElems; ++i) {
21284 if (VectorMask[i] <= 0)
21285 continue;
21286
21287 unsigned ExtIndex = N->getOperand(i).getConstantOperandVal(1);
21288 if (VectorMask[i] == (int)LeftIdx) {
21289 Mask[i] = ExtIndex;
21290 } else if (VectorMask[i] == (int)LeftIdx + 1) {
21291 Mask[i] = Vec2Offset + ExtIndex;
21292 }
21293 }
21294
21295   // The types of the input vectors may have changed above.
21296 InVT1 = VecIn1.getValueType();
21297
21298 // If we already have a VecIn2, it should have the same type as VecIn1.
21299 // If we don't, get an undef/zero vector of the appropriate type.
21300 VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getUNDEF(InVT1);
21301 assert(InVT1 == VecIn2.getValueType() && "Unexpected second input type.");
21302
21303 SDValue Shuffle = DAG.getVectorShuffle(InVT1, DL, VecIn1, VecIn2, Mask);
21304 if (ShuffleNumElems > NumElems)
21305 Shuffle = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Shuffle, ZeroIdx);
21306
21307 return Shuffle;
21308 }
21309
21310 static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) {
21311 assert(BV->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
21312
21313 // First, determine where the build vector is not undef.
21314 // TODO: We could extend this to handle zero elements as well as undefs.
21315 int NumBVOps = BV->getNumOperands();
21316 int ZextElt = -1;
21317 for (int i = 0; i != NumBVOps; ++i) {
21318 SDValue Op = BV->getOperand(i);
21319 if (Op.isUndef())
21320 continue;
21321 if (ZextElt == -1)
21322 ZextElt = i;
21323 else
21324 return SDValue();
21325 }
21326 // Bail out if there's no non-undef element.
21327 if (ZextElt == -1)
21328 return SDValue();
21329
21330 // The build vector contains some number of undef elements and exactly
21331 // one other element. That other element must be a zero-extended scalar
21332 // extracted from a vector at a constant index to turn this into a shuffle.
21333 // Also, require that the build vector does not implicitly truncate/extend
21334 // its elements.
21335 // TODO: This could be enhanced to allow ANY_EXTEND as well as ZERO_EXTEND.
21336 EVT VT = BV->getValueType(0);
21337 SDValue Zext = BV->getOperand(ZextElt);
21338 if (Zext.getOpcode() != ISD::ZERO_EXTEND || !Zext.hasOneUse() ||
21339 Zext.getOperand(0).getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
21340 !isa<ConstantSDNode>(Zext.getOperand(0).getOperand(1)) ||
21341 Zext.getValueSizeInBits() != VT.getScalarSizeInBits())
21342 return SDValue();
21343
21344 // The zero-extend must be a multiple of the source size, and we must be
21345 // building a vector of the same size as the source of the extract element.
21346 SDValue Extract = Zext.getOperand(0);
21347 unsigned DestSize = Zext.getValueSizeInBits();
21348 unsigned SrcSize = Extract.getValueSizeInBits();
21349 if (DestSize % SrcSize != 0 ||
21350 Extract.getOperand(0).getValueSizeInBits() != VT.getSizeInBits())
21351 return SDValue();
21352
21353 // Create a shuffle mask that will combine the extracted element with zeros
21354 // and undefs.
21355 int ZextRatio = DestSize / SrcSize;
21356 int NumMaskElts = NumBVOps * ZextRatio;
21357 SmallVector<int, 32> ShufMask(NumMaskElts, -1);
21358 for (int i = 0; i != NumMaskElts; ++i) {
21359 if (i / ZextRatio == ZextElt) {
21360 // The low bits of the (potentially translated) extracted element map to
21361 // the source vector. The high bits map to zero. We will use a zero vector
21362 // as the 2nd source operand of the shuffle, so use the 1st element of
21363 // that vector (mask value is number-of-elements) for the high bits.
21364 if (i % ZextRatio == 0)
21365 ShufMask[i] = Extract.getConstantOperandVal(1);
21366 else
21367 ShufMask[i] = NumMaskElts;
21368 }
21369
21370 // Undef elements of the build vector remain undef because we initialize
21371 // the shuffle mask with -1.
21372 }
21373
21374 // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... -->
21375 // bitcast (shuffle V, ZeroVec, VectorMask)
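  // For example (illustrative, assumed types and index):
  //   (v2i64 build_vector (zext (extractelt V:v4i32, 2) to i64), undef)
  // --> (v2i64 bitcast (v4i32 vector_shuffle<2,4,u,u> V, zero))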
21376 SDLoc DL(BV);
21377 EVT VecVT = Extract.getOperand(0).getValueType();
21378 SDValue ZeroVec = DAG.getConstant(0, DL, VecVT);
21379 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
21380 SDValue Shuf = TLI.buildLegalVectorShuffle(VecVT, DL, Extract.getOperand(0),
21381 ZeroVec, ShufMask, DAG);
21382 if (!Shuf)
21383 return SDValue();
21384 return DAG.getBitcast(VT, Shuf);
21385 }
21386
21387 // FIXME: promote to STLExtras.
21388 template <typename R, typename T>
21389 static auto getFirstIndexOf(R &&Range, const T &Val) {
21390 auto I = find(Range, Val);
21391 if (I == Range.end())
21392 return static_cast<decltype(std::distance(Range.begin(), I))>(-1);
21393 return std::distance(Range.begin(), I);
21394 }
21395
21396 // Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT
21397 // operations. If the types of the vectors we're extracting from allow it,
21398 // turn this into a vector_shuffle node.
21399 SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
21400 SDLoc DL(N);
21401 EVT VT = N->getValueType(0);
21402
21403 // Only type-legal BUILD_VECTOR nodes are converted to shuffle nodes.
21404 if (!isTypeLegal(VT))
21405 return SDValue();
21406
21407 if (SDValue V = reduceBuildVecToShuffleWithZero(N, DAG))
21408 return V;
21409
21410 // May only combine to shuffle after legalize if shuffle is legal.
21411 if (LegalOperations && !TLI.isOperationLegal(ISD::VECTOR_SHUFFLE, VT))
21412 return SDValue();
21413
21414 bool UsesZeroVector = false;
21415 unsigned NumElems = N->getNumOperands();
21416
21417 // Record, for each element of the newly built vector, which input vector
21418 // that element comes from. -1 stands for undef, 0 for the zero vector,
21419 // and positive values for the input vectors.
21420 // VectorMask maps each element to its vector number, and VecIn maps vector
21421 // numbers to their initial SDValues.
21422
21423 SmallVector<int, 8> VectorMask(NumElems, -1);
21424 SmallVector<SDValue, 8> VecIn;
21425 VecIn.push_back(SDValue());
21426
21427 for (unsigned i = 0; i != NumElems; ++i) {
21428 SDValue Op = N->getOperand(i);
21429
21430 if (Op.isUndef())
21431 continue;
21432
21433 // See if we can use a blend with a zero vector.
21434 // TODO: Should we generalize this to a blend with an arbitrary constant
21435 // vector?
21436 if (isNullConstant(Op) || isNullFPConstant(Op)) {
21437 UsesZeroVector = true;
21438 VectorMask[i] = 0;
21439 continue;
21440 }
21441
21442 // Not an undef or zero. If the input is something other than an
21443 // EXTRACT_VECTOR_ELT with an in-range constant index, bail out.
21444 if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
21445 !isa<ConstantSDNode>(Op.getOperand(1)))
21446 return SDValue();
21447 SDValue ExtractedFromVec = Op.getOperand(0);
21448
21449 if (ExtractedFromVec.getValueType().isScalableVector())
21450 return SDValue();
21451
21452 const APInt &ExtractIdx = Op.getConstantOperandAPInt(1);
21453 if (ExtractIdx.uge(ExtractedFromVec.getValueType().getVectorNumElements()))
21454 return SDValue();
21455
21456 // All inputs must have the same element type as the output.
21457 if (VT.getVectorElementType() !=
21458 ExtractedFromVec.getValueType().getVectorElementType())
21459 return SDValue();
21460
21461 // Have we seen this input vector before?
21462 // The vectors are expected to be tiny (usually 1 or 2 elements), so using
21463 // a map back from SDValues to numbers isn't worth it.
21464 int Idx = getFirstIndexOf(VecIn, ExtractedFromVec);
21465 if (Idx == -1) { // A new source vector?
21466 Idx = VecIn.size();
21467 VecIn.push_back(ExtractedFromVec);
21468 }
21469
21470 VectorMask[i] = Idx;
21471 }
21472
21473 // If we didn't find at least one input vector, bail out.
21474 if (VecIn.size() < 2)
21475 return SDValue();
21476
21477   // If all the operands of the BUILD_VECTOR extract from the same
21478   // vector, then split that vector efficiently based on the maximum
21479   // vector access index, and adjust the VectorMask and
21480   // VecIn accordingly.
21481 bool DidSplitVec = false;
21482 if (VecIn.size() == 2) {
21483 unsigned MaxIndex = 0;
21484 unsigned NearestPow2 = 0;
21485 SDValue Vec = VecIn.back();
21486 EVT InVT = Vec.getValueType();
21487 SmallVector<unsigned, 8> IndexVec(NumElems, 0);
21488
21489 for (unsigned i = 0; i < NumElems; i++) {
21490 if (VectorMask[i] <= 0)
21491 continue;
21492 unsigned Index = N->getOperand(i).getConstantOperandVal(1);
21493 IndexVec[i] = Index;
21494 MaxIndex = std::max(MaxIndex, Index);
21495 }
21496
21497 NearestPow2 = PowerOf2Ceil(MaxIndex);
21498 if (InVT.isSimple() && NearestPow2 > 2 && MaxIndex < NearestPow2 &&
21499 NumElems * 2 < NearestPow2) {
21500 unsigned SplitSize = NearestPow2 / 2;
21501 EVT SplitVT = EVT::getVectorVT(*DAG.getContext(),
21502 InVT.getVectorElementType(), SplitSize);
21503 if (TLI.isTypeLegal(SplitVT) &&
21504 SplitSize + SplitVT.getVectorNumElements() <=
21505 InVT.getVectorNumElements()) {
21506 SDValue VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
21507 DAG.getVectorIdxConstant(SplitSize, DL));
21508 SDValue VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SplitVT, Vec,
21509 DAG.getVectorIdxConstant(0, DL));
21510 VecIn.pop_back();
21511 VecIn.push_back(VecIn1);
21512 VecIn.push_back(VecIn2);
21513 DidSplitVec = true;
21514
21515 for (unsigned i = 0; i < NumElems; i++) {
21516 if (VectorMask[i] <= 0)
21517 continue;
21518 VectorMask[i] = (IndexVec[i] < SplitSize) ? 1 : 2;
21519 }
21520 }
21521 }
21522 }
21523
21524 // Sort input vectors by decreasing vector element count,
21525 // while preserving the relative order of equally-sized vectors.
21526   // Note that we keep the first "implicit" zero vector as-is.
21527 SmallVector<SDValue, 8> SortedVecIn(VecIn);
21528 llvm::stable_sort(MutableArrayRef<SDValue>(SortedVecIn).drop_front(),
21529 [](const SDValue &a, const SDValue &b) {
21530 return a.getValueType().getVectorNumElements() >
21531 b.getValueType().getVectorNumElements();
21532 });
21533
21534 // We now also need to rebuild the VectorMask, because it referenced element
21535 // order in VecIn, and we just sorted them.
21536 for (int &SourceVectorIndex : VectorMask) {
21537 if (SourceVectorIndex <= 0)
21538 continue;
21539 unsigned Idx = getFirstIndexOf(SortedVecIn, VecIn[SourceVectorIndex]);
21540 assert(Idx > 0 && Idx < SortedVecIn.size() &&
21541 VecIn[SourceVectorIndex] == SortedVecIn[Idx] && "Remapping failure");
21542 SourceVectorIndex = Idx;
21543 }
21544
21545 VecIn = std::move(SortedVecIn);
21546
21547   // TODO: Should this fire if some of the input vectors have an illegal type
21548   // (as it does now), or should we let legalization run its course first?
21549
21550 // Shuffle phase:
21551 // Take pairs of vectors, and shuffle them so that the result has elements
21552 // from these vectors in the correct places.
21553 // For example, given:
21554 // t10: i32 = extract_vector_elt t1, Constant:i64<0>
21555 // t11: i32 = extract_vector_elt t2, Constant:i64<0>
21556 // t12: i32 = extract_vector_elt t3, Constant:i64<0>
21557 // t13: i32 = extract_vector_elt t1, Constant:i64<1>
21558 // t14: v4i32 = BUILD_VECTOR t10, t11, t12, t13
21559 // We will generate:
21560 // t20: v4i32 = vector_shuffle<0,4,u,1> t1, t2
21561 // t21: v4i32 = vector_shuffle<u,u,0,u> t3, undef
21562 SmallVector<SDValue, 4> Shuffles;
21563 for (unsigned In = 0, Len = (VecIn.size() / 2); In < Len; ++In) {
21564 unsigned LeftIdx = 2 * In + 1;
21565 SDValue VecLeft = VecIn[LeftIdx];
21566 SDValue VecRight =
21567 (LeftIdx + 1) < VecIn.size() ? VecIn[LeftIdx + 1] : SDValue();
21568
21569 if (SDValue Shuffle = createBuildVecShuffle(DL, N, VectorMask, VecLeft,
21570 VecRight, LeftIdx, DidSplitVec))
21571 Shuffles.push_back(Shuffle);
21572 else
21573 return SDValue();
21574 }
21575
21576 // If we need the zero vector as an "ingredient" in the blend tree, add it
21577 // to the list of shuffles.
21578 if (UsesZeroVector)
21579 Shuffles.push_back(VT.isInteger() ? DAG.getConstant(0, DL, VT)
21580 : DAG.getConstantFP(0.0, DL, VT));
21581
21582 // If we only have one shuffle, we're done.
21583 if (Shuffles.size() == 1)
21584 return Shuffles[0];
21585
21586 // Update the vector mask to point to the post-shuffle vectors.
21587 for (int &Vec : VectorMask)
21588 if (Vec == 0)
21589 Vec = Shuffles.size() - 1;
21590 else
21591 Vec = (Vec - 1) / 2;
21592
21593 // More than one shuffle. Generate a binary tree of blends, e.g. if from
21594 // the previous step we got the set of shuffles t10, t11, t12, t13, we will
21595 // generate:
21596 // t10: v8i32 = vector_shuffle<0,8,u,u,u,u,u,u> t1, t2
21597 // t11: v8i32 = vector_shuffle<u,u,0,8,u,u,u,u> t3, t4
21598 // t12: v8i32 = vector_shuffle<u,u,u,u,0,8,u,u> t5, t6
21599 // t13: v8i32 = vector_shuffle<u,u,u,u,u,u,0,8> t7, t8
21600 // t20: v8i32 = vector_shuffle<0,1,10,11,u,u,u,u> t10, t11
21601 // t21: v8i32 = vector_shuffle<u,u,u,u,4,5,14,15> t12, t13
21602 // t30: v8i32 = vector_shuffle<0,1,2,3,12,13,14,15> t20, t21
21603
21604 // Make sure the initial size of the shuffle list is even.
21605 if (Shuffles.size() % 2)
21606 Shuffles.push_back(DAG.getUNDEF(VT));
21607
21608 for (unsigned CurSize = Shuffles.size(); CurSize > 1; CurSize /= 2) {
21609 if (CurSize % 2) {
21610 Shuffles[CurSize] = DAG.getUNDEF(VT);
21611 CurSize++;
21612 }
21613 for (unsigned In = 0, Len = CurSize / 2; In < Len; ++In) {
21614 int Left = 2 * In;
21615 int Right = 2 * In + 1;
21616 SmallVector<int, 8> Mask(NumElems, -1);
21617 SDValue L = Shuffles[Left];
21618 ArrayRef<int> LMask;
21619 bool IsLeftShuffle = L.getOpcode() == ISD::VECTOR_SHUFFLE &&
21620 L.use_empty() && L.getOperand(1).isUndef() &&
21621 L.getOperand(0).getValueType() == L.getValueType();
21622 if (IsLeftShuffle) {
21623 LMask = cast<ShuffleVectorSDNode>(L.getNode())->getMask();
21624 L = L.getOperand(0);
21625 }
21626 SDValue R = Shuffles[Right];
21627 ArrayRef<int> RMask;
21628 bool IsRightShuffle = R.getOpcode() == ISD::VECTOR_SHUFFLE &&
21629 R.use_empty() && R.getOperand(1).isUndef() &&
21630 R.getOperand(0).getValueType() == R.getValueType();
21631 if (IsRightShuffle) {
21632 RMask = cast<ShuffleVectorSDNode>(R.getNode())->getMask();
21633 R = R.getOperand(0);
21634 }
21635 for (unsigned I = 0; I != NumElems; ++I) {
21636 if (VectorMask[I] == Left) {
21637 Mask[I] = I;
21638 if (IsLeftShuffle)
21639 Mask[I] = LMask[I];
21640 VectorMask[I] = In;
21641 } else if (VectorMask[I] == Right) {
21642 Mask[I] = I + NumElems;
21643 if (IsRightShuffle)
21644 Mask[I] = RMask[I] + NumElems;
21645 VectorMask[I] = In;
21646 }
21647 }
21648
21649 Shuffles[In] = DAG.getVectorShuffle(VT, DL, L, R, Mask);
21650 }
21651 }
21652 return Shuffles[0];
21653 }
21654
21655 // Try to turn a build vector of zero extends of extract vector elts into a
21656 // vector zero extend and possibly an extract subvector.
21657 // TODO: Support sign extend?
21658 // TODO: Allow undef elements?
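// For example (illustrative, assumed types):
//   (v2i64 build_vector (zext (extractelt X:v8i32, 2) to i64),
//                       (zext (extractelt X, 3) to i64))
// --> (v2i64 zero_extend (v2i32 extract_subvector X, 2))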
21659 SDValue DAGCombiner::convertBuildVecZextToZext(SDNode *N) {
21660 if (LegalOperations)
21661 return SDValue();
21662
21663 EVT VT = N->getValueType(0);
21664
21665 bool FoundZeroExtend = false;
21666 SDValue Op0 = N->getOperand(0);
21667 auto checkElem = [&](SDValue Op) -> int64_t {
21668 unsigned Opc = Op.getOpcode();
21669 FoundZeroExtend |= (Opc == ISD::ZERO_EXTEND);
21670 if ((Opc == ISD::ZERO_EXTEND || Opc == ISD::ANY_EXTEND) &&
21671 Op.getOperand(0).getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
21672 Op0.getOperand(0).getOperand(0) == Op.getOperand(0).getOperand(0))
21673 if (auto *C = dyn_cast<ConstantSDNode>(Op.getOperand(0).getOperand(1)))
21674 return C->getZExtValue();
21675 return -1;
21676 };
21677
21678 // Make sure the first element matches
21679 // (zext (extract_vector_elt X, C))
21680 // Offset must be a constant multiple of the
21681 // known-minimum vector length of the result type.
21682 int64_t Offset = checkElem(Op0);
21683 if (Offset < 0 || (Offset % VT.getVectorNumElements()) != 0)
21684 return SDValue();
21685
21686 unsigned NumElems = N->getNumOperands();
21687 SDValue In = Op0.getOperand(0).getOperand(0);
21688 EVT InSVT = In.getValueType().getScalarType();
21689 EVT InVT = EVT::getVectorVT(*DAG.getContext(), InSVT, NumElems);
21690
21691 // Don't create an illegal input type after type legalization.
21692 if (LegalTypes && !TLI.isTypeLegal(InVT))
21693 return SDValue();
21694
21695 // Ensure all the elements come from the same vector and are adjacent.
21696 for (unsigned i = 1; i != NumElems; ++i) {
21697 if ((Offset + i) != checkElem(N->getOperand(i)))
21698 return SDValue();
21699 }
21700
21701 SDLoc DL(N);
21702 In = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InVT, In,
21703 Op0.getOperand(0).getOperand(1));
21704 return DAG.getNode(FoundZeroExtend ? ISD::ZERO_EXTEND : ISD::ANY_EXTEND, DL,
21705 VT, In);
21706 }
21707
21708 // If this is a very simple BUILD_VECTOR whose first element is a ZERO_EXTEND,
21709 // and all other elements are constant zeros, granularize the BUILD_VECTOR's
21710 // element width, absorbing the ZERO_EXTEND and turning it into a constant zero op.
21711 // This pattern can appear during legalization.
21712 //
21713 // NOTE: This can be generalized to allow more than a single
21714 // non-constant-zero op, to allow UNDEFs, and to be KnownBits-based.
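//
// For example (illustrative, little-endian, X assumed to be i32):
//   (v2i64 build_vector (zext X to i64), 0)
// --> (v2i64 bitcast (v4i32 build_vector X, 0, 0, 0))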
21715 SDValue DAGCombiner::convertBuildVecZextToBuildVecWithZeros(SDNode *N) {
21716 // Don't run this after legalization. Targets may have other preferences.
21717 if (Level >= AfterLegalizeDAG)
21718 return SDValue();
21719
21720 // FIXME: support big-endian.
21721 if (DAG.getDataLayout().isBigEndian())
21722 return SDValue();
21723
21724 EVT VT = N->getValueType(0);
21725 EVT OpVT = N->getOperand(0).getValueType();
21726 assert(!VT.isScalableVector() && "Encountered scalable BUILD_VECTOR?");
21727
21728 EVT OpIntVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
21729
21730 if (!TLI.isTypeLegal(OpIntVT) ||
21731 (LegalOperations && !TLI.isOperationLegalOrCustom(ISD::BITCAST, OpIntVT)))
21732 return SDValue();
21733
21734 unsigned EltBitwidth = VT.getScalarSizeInBits();
21735 // NOTE: the actual width of operands may be wider than that!
21736
21737 // Analyze all operands of this BUILD_VECTOR. What is the largest number of
21738 // active bits they all have? We'll want to truncate them all to that width.
21739 unsigned ActiveBits = 0;
21740 APInt KnownZeroOps(VT.getVectorNumElements(), 0);
21741 for (auto I : enumerate(N->ops())) {
21742 SDValue Op = I.value();
21743 // FIXME: support UNDEF elements?
21744 if (auto *Cst = dyn_cast<ConstantSDNode>(Op)) {
21745 unsigned OpActiveBits =
21746 Cst->getAPIntValue().trunc(EltBitwidth).getActiveBits();
21747 if (OpActiveBits == 0) {
21748 KnownZeroOps.setBit(I.index());
21749 continue;
21750 }
21751 // Profitability check: don't allow non-zero constant operands.
21752 return SDValue();
21753 }
21754 // Profitability check: there must only be a single non-zero operand,
21755 // and it must be the first operand of the BUILD_VECTOR.
21756 if (I.index() != 0)
21757 return SDValue();
21758 // The operand must be a zero-extension itself.
21759 // FIXME: this could be generalized to known leading zeros check.
21760 if (Op.getOpcode() != ISD::ZERO_EXTEND)
21761 return SDValue();
21762 unsigned CurrActiveBits =
21763 Op.getOperand(0).getValueSizeInBits().getFixedValue();
21764 assert(!ActiveBits && "Already encountered non-constant-zero operand?");
21765 ActiveBits = CurrActiveBits;
21766 // We want to at least halve the element size.
21767 if (2 * ActiveBits > EltBitwidth)
21768 return SDValue();
21769 }
21770
21771 // This BUILD_VECTOR must have at least one non-constant-zero operand.
21772 if (ActiveBits == 0)
21773 return SDValue();
21774
21775   // We have EltBitwidth bits; the *minimal* chunk size is ActiveBits. Into
21776   // how many chunks can we split our element width?
21777 EVT NewScalarIntVT, NewIntVT;
21778 std::optional<unsigned> Factor;
21779 // We can split the element into at least two chunks, but not into more
21780   // than |_ EltBitwidth / ActiveBits _| chunks. Find the largest split factor
21781   // such that the element width is a multiple of it,
21782 // and the resulting types/operations on that chunk width are legal.
21783 assert(2 * ActiveBits <= EltBitwidth &&
21784 "We know that half or less bits of the element are active.");
21785 for (unsigned Scale = EltBitwidth / ActiveBits; Scale >= 2; --Scale) {
21786 if (EltBitwidth % Scale != 0)
21787 continue;
21788 unsigned ChunkBitwidth = EltBitwidth / Scale;
21789 assert(ChunkBitwidth >= ActiveBits && "As per starting point.");
21790 NewScalarIntVT = EVT::getIntegerVT(*DAG.getContext(), ChunkBitwidth);
21791 NewIntVT = EVT::getVectorVT(*DAG.getContext(), NewScalarIntVT,
21792 Scale * N->getNumOperands());
21793 if (!TLI.isTypeLegal(NewScalarIntVT) || !TLI.isTypeLegal(NewIntVT) ||
21794 (LegalOperations &&
21795 !(TLI.isOperationLegalOrCustom(ISD::TRUNCATE, NewScalarIntVT) &&
21796 TLI.isOperationLegalOrCustom(ISD::BUILD_VECTOR, NewIntVT))))
21797 continue;
21798 Factor = Scale;
21799 break;
21800 }
21801 if (!Factor)
21802 return SDValue();
21803
21804 SDLoc DL(N);
21805 SDValue ZeroOp = DAG.getConstant(0, DL, NewScalarIntVT);
21806
21807 // Recreate the BUILD_VECTOR, with elements now being Factor times smaller.
21808 SmallVector<SDValue, 16> NewOps;
21809 NewOps.reserve(NewIntVT.getVectorNumElements());
21810 for (auto I : enumerate(N->ops())) {
21811 SDValue Op = I.value();
21812 assert(!Op.isUndef() && "FIXME: after allowing UNDEF's, handle them here.");
21813 unsigned SrcOpIdx = I.index();
21814 if (KnownZeroOps[SrcOpIdx]) {
21815 NewOps.append(*Factor, ZeroOp);
21816 continue;
21817 }
21818 Op = DAG.getBitcast(OpIntVT, Op);
21819 Op = DAG.getNode(ISD::TRUNCATE, DL, NewScalarIntVT, Op);
21820 NewOps.emplace_back(Op);
21821 NewOps.append(*Factor - 1, ZeroOp);
21822 }
21823 assert(NewOps.size() == NewIntVT.getVectorNumElements());
21824 SDValue NewBV = DAG.getBuildVector(NewIntVT, DL, NewOps);
21825 NewBV = DAG.getBitcast(VT, NewBV);
21826 return NewBV;
21827 }
21828
21829 SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
21830 EVT VT = N->getValueType(0);
21831
21832 // A vector built entirely of undefs is undef.
21833 if (ISD::allOperandsUndef(N))
21834 return DAG.getUNDEF(VT);
21835
21836 // If this is a splat of a bitcast from another vector, change to a
21837 // concat_vector.
21838 // For example:
21839 // (build_vector (i64 (bitcast (v2i32 X))), (i64 (bitcast (v2i32 X)))) ->
21840 // (v2i64 (bitcast (concat_vectors (v2i32 X), (v2i32 X))))
21841 //
21842 // If X is a build_vector itself, the concat can become a larger build_vector.
21843 // TODO: Maybe this is useful for non-splat too?
21844 if (!LegalOperations) {
21845 if (SDValue Splat = cast<BuildVectorSDNode>(N)->getSplatValue()) {
21846 Splat = peekThroughBitcasts(Splat);
21847 EVT SrcVT = Splat.getValueType();
21848 if (SrcVT.isVector()) {
21849 unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements();
21850 EVT NewVT = EVT::getVectorVT(*DAG.getContext(),
21851 SrcVT.getVectorElementType(), NumElts);
21852 if (!LegalTypes || TLI.isTypeLegal(NewVT)) {
21853 SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat);
21854 SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N),
21855 NewVT, Ops);
21856 return DAG.getBitcast(VT, Concat);
21857 }
21858 }
21859 }
21860 }
21861
21862   // Check if we can express the BUILD_VECTOR via a subvector extract.
21863 if (!LegalTypes && (N->getNumOperands() > 1)) {
21864 SDValue Op0 = N->getOperand(0);
21865 auto checkElem = [&](SDValue Op) -> uint64_t {
21866 if ((Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) &&
21867 (Op0.getOperand(0) == Op.getOperand(0)))
21868 if (auto CNode = dyn_cast<ConstantSDNode>(Op.getOperand(1)))
21869 return CNode->getZExtValue();
21870 return -1;
21871 };
21872
21873 int Offset = checkElem(Op0);
21874 for (unsigned i = 0; i < N->getNumOperands(); ++i) {
21875 if (Offset + i != checkElem(N->getOperand(i))) {
21876 Offset = -1;
21877 break;
21878 }
21879 }
21880
21881 if ((Offset == 0) &&
21882 (Op0.getOperand(0).getValueType() == N->getValueType(0)))
21883 return Op0.getOperand(0);
21884 if ((Offset != -1) &&
21885 ((Offset % N->getValueType(0).getVectorNumElements()) ==
21886 0)) // IDX must be multiple of output size.
21887 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), N->getValueType(0),
21888 Op0.getOperand(0), Op0.getOperand(1));
21889 }
21890
21891 if (SDValue V = convertBuildVecZextToZext(N))
21892 return V;
21893
21894 if (SDValue V = convertBuildVecZextToBuildVecWithZeros(N))
21895 return V;
21896
21897 if (SDValue V = reduceBuildVecExtToExtBuildVec(N))
21898 return V;
21899
21900 if (SDValue V = reduceBuildVecTruncToBitCast(N))
21901 return V;
21902
21903 if (SDValue V = reduceBuildVecToShuffle(N))
21904 return V;
21905
21906 // A splat of a single element is a SPLAT_VECTOR if supported on the target.
21907 // Do this late as some of the above may replace the splat.
21908 if (TLI.getOperationAction(ISD::SPLAT_VECTOR, VT) != TargetLowering::Expand)
21909 if (SDValue V = cast<BuildVectorSDNode>(N)->getSplatValue()) {
21910 assert(!V.isUndef() && "Splat of undef should have been handled earlier");
21911 return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, V);
21912 }
21913
21914 return SDValue();
21915 }
21916
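// Fold a CONCAT_VECTORS whose operands are scalars bitcast to an (illegal)
// vector type, or undef, into a single wide BUILD_VECTOR plus a bitcast.
// For example (illustrative, assuming v2i32 is not a legal type here):
//   (v4i32 concat_vectors (v2i32 bitcast i64 X), (v2i32 bitcast i64 Y))
// --> (v4i32 bitcast (v2i64 build_vector X, Y))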
21917 static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
21918 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
21919 EVT OpVT = N->getOperand(0).getValueType();
21920
21921 // If the operands are legal vectors, leave them alone.
21922 if (TLI.isTypeLegal(OpVT))
21923 return SDValue();
21924
21925 SDLoc DL(N);
21926 EVT VT = N->getValueType(0);
21927 SmallVector<SDValue, 8> Ops;
21928
21929 EVT SVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits());
21930 SDValue ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT);
21931
21932 // Keep track of what we encounter.
21933 bool AnyInteger = false;
21934 bool AnyFP = false;
21935 for (const SDValue &Op : N->ops()) {
21936 if (ISD::BITCAST == Op.getOpcode() &&
21937 !Op.getOperand(0).getValueType().isVector())
21938 Ops.push_back(Op.getOperand(0));
21939 else if (ISD::UNDEF == Op.getOpcode())
21940 Ops.push_back(ScalarUndef);
21941 else
21942 return SDValue();
21943
21944 // Note whether we encounter an integer or floating point scalar.
21945 // If it's neither, bail out, it could be something weird like x86mmx.
21946 EVT LastOpVT = Ops.back().getValueType();
21947 if (LastOpVT.isFloatingPoint())
21948 AnyFP = true;
21949 else if (LastOpVT.isInteger())
21950 AnyInteger = true;
21951 else
21952 return SDValue();
21953 }
21954
21955 // If any of the operands is a floating point scalar bitcast to a vector,
21956 // use floating point types throughout, and bitcast everything.
21957 // Replace UNDEFs by another scalar UNDEF node, of the final desired type.
21958 if (AnyFP) {
21959 SVT = EVT::getFloatingPointVT(OpVT.getSizeInBits());
21960 ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT);
21961 if (AnyInteger) {
21962 for (SDValue &Op : Ops) {
21963 if (Op.getValueType() == SVT)
21964 continue;
21965 if (Op.isUndef())
21966 Op = ScalarUndef;
21967 else
21968 Op = DAG.getBitcast(SVT, Op);
21969 }
21970 }
21971 }
21972
21973 EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SVT,
21974 VT.getSizeInBits() / SVT.getSizeInBits());
21975 return DAG.getBitcast(VT, DAG.getBuildVector(VecVT, DL, Ops));
21976 }
21977
21978 // Attempt to merge nested concat_vectors/undefs.
21979 // Fold concat_vectors(concat_vectors(x,y,z,w),u,u,concat_vectors(a,b,c,d))
21980 // --> concat_vectors(x,y,z,w,u,u,u,u,u,u,u,u,a,b,c,d)
21981 static SDValue combineConcatVectorOfConcatVectors(SDNode *N,
21982 SelectionDAG &DAG) {
21983 EVT VT = N->getValueType(0);
21984
21985 // Ensure we're concatenating UNDEF and CONCAT_VECTORS nodes of similar types.
21986 EVT SubVT;
21987 SDValue FirstConcat;
21988 for (const SDValue &Op : N->ops()) {
21989 if (Op.isUndef())
21990 continue;
21991 if (Op.getOpcode() != ISD::CONCAT_VECTORS)
21992 return SDValue();
21993 if (!FirstConcat) {
21994 SubVT = Op.getOperand(0).getValueType();
21995 if (!DAG.getTargetLoweringInfo().isTypeLegal(SubVT))
21996 return SDValue();
21997 FirstConcat = Op;
21998 continue;
21999 }
22000 if (SubVT != Op.getOperand(0).getValueType())
22001 return SDValue();
22002 }
22003 assert(FirstConcat && "Concat of all-undefs found");
22004
22005 SmallVector<SDValue> ConcatOps;
22006 for (const SDValue &Op : N->ops()) {
22007 if (Op.isUndef()) {
22008 ConcatOps.append(FirstConcat->getNumOperands(), DAG.getUNDEF(SubVT));
22009 continue;
22010 }
22011 ConcatOps.append(Op->op_begin(), Op->op_end());
22012 }
22013 return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, ConcatOps);
22014 }
22015
22016 // Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
22017 // operations. If so, and if the EXTRACT_SUBVECTOR vector inputs come from at
22018 // most two distinct vectors the same size as the result, attempt to turn this
22019 // into a legal shuffle.
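// For example (illustrative, assumed types):
//   (v8i32 concat_vectors (v4i32 extract_subvector V1:v8i32, 0),
//                         (v4i32 extract_subvector V2:v8i32, 4))
// --> (v8i32 vector_shuffle<0,1,2,3,12,13,14,15> V1, V2)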
22020 static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
22021 EVT VT = N->getValueType(0);
22022 EVT OpVT = N->getOperand(0).getValueType();
22023
22024 // We currently can't generate an appropriate shuffle for a scalable vector.
22025 if (VT.isScalableVector())
22026 return SDValue();
22027
22028 int NumElts = VT.getVectorNumElements();
22029 int NumOpElts = OpVT.getVectorNumElements();
22030
22031 SDValue SV0 = DAG.getUNDEF(VT), SV1 = DAG.getUNDEF(VT);
22032 SmallVector<int, 8> Mask;
22033
22034 for (SDValue Op : N->ops()) {
22035 Op = peekThroughBitcasts(Op);
22036
22037 // UNDEF nodes convert to UNDEF shuffle mask values.
22038 if (Op.isUndef()) {
22039 Mask.append((unsigned)NumOpElts, -1);
22040 continue;
22041 }
22042
22043 if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
22044 return SDValue();
22045
22046 // What vector are we extracting the subvector from and at what index?
22047 SDValue ExtVec = Op.getOperand(0);
22048 int ExtIdx = Op.getConstantOperandVal(1);
22049
22050 // We want the EVT of the original extraction to correctly scale the
22051 // extraction index.
22052 EVT ExtVT = ExtVec.getValueType();
22053 ExtVec = peekThroughBitcasts(ExtVec);
22054
22055 // UNDEF nodes convert to UNDEF shuffle mask values.
22056 if (ExtVec.isUndef()) {
22057 Mask.append((unsigned)NumOpElts, -1);
22058 continue;
22059 }
22060
22061 // Ensure that we are extracting a subvector from a vector the same
22062 // size as the result.
22063 if (ExtVT.getSizeInBits() != VT.getSizeInBits())
22064 return SDValue();
22065
22066 // Scale the subvector index to account for any bitcast.
22067 int NumExtElts = ExtVT.getVectorNumElements();
22068 if (0 == (NumExtElts % NumElts))
22069 ExtIdx /= (NumExtElts / NumElts);
22070 else if (0 == (NumElts % NumExtElts))
22071 ExtIdx *= (NumElts / NumExtElts);
22072 else
22073 return SDValue();
22074
22075 // At most we can reference 2 inputs in the final shuffle.
22076 if (SV0.isUndef() || SV0 == ExtVec) {
22077 SV0 = ExtVec;
22078 for (int i = 0; i != NumOpElts; ++i)
22079 Mask.push_back(i + ExtIdx);
22080 } else if (SV1.isUndef() || SV1 == ExtVec) {
22081 SV1 = ExtVec;
22082 for (int i = 0; i != NumOpElts; ++i)
22083 Mask.push_back(i + ExtIdx + NumElts);
22084 } else {
22085 return SDValue();
22086 }
22087 }
22088
22089 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22090 return TLI.buildLegalVectorShuffle(VT, SDLoc(N), DAG.getBitcast(VT, SV0),
22091 DAG.getBitcast(VT, SV1), Mask, DAG);
22092 }
22093
22094 static SDValue combineConcatVectorOfCasts(SDNode *N, SelectionDAG &DAG) {
22095 unsigned CastOpcode = N->getOperand(0).getOpcode();
22096 switch (CastOpcode) {
22097 case ISD::SINT_TO_FP:
22098 case ISD::UINT_TO_FP:
22099 case ISD::FP_TO_SINT:
22100 case ISD::FP_TO_UINT:
22101 // TODO: Allow more opcodes?
22102 // case ISD::BITCAST:
22103 // case ISD::TRUNCATE:
22104 // case ISD::ZERO_EXTEND:
22105 // case ISD::SIGN_EXTEND:
22106 // case ISD::FP_EXTEND:
22107 break;
22108 default:
22109 return SDValue();
22110 }
22111
22112 EVT SrcVT = N->getOperand(0).getOperand(0).getValueType();
22113 if (!SrcVT.isVector())
22114 return SDValue();
22115
22116 // All operands of the concat must be the same kind of cast from the same
22117 // source type.
22118 SmallVector<SDValue, 4> SrcOps;
22119 for (SDValue Op : N->ops()) {
22120 if (Op.getOpcode() != CastOpcode || !Op.hasOneUse() ||
22121 Op.getOperand(0).getValueType() != SrcVT)
22122 return SDValue();
22123 SrcOps.push_back(Op.getOperand(0));
22124 }
22125
22126 // The wider cast must be supported by the target. This is unusual because
22127 // the operation support type parameter depends on the opcode. In addition,
22128 // check the other type in the cast to make sure this is really legal.
22129 EVT VT = N->getValueType(0);
22130 EVT SrcEltVT = SrcVT.getVectorElementType();
22131 ElementCount NumElts = SrcVT.getVectorElementCount() * N->getNumOperands();
22132 EVT ConcatSrcVT = EVT::getVectorVT(*DAG.getContext(), SrcEltVT, NumElts);
22133 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22134 switch (CastOpcode) {
22135 case ISD::SINT_TO_FP:
22136 case ISD::UINT_TO_FP:
22137 if (!TLI.isOperationLegalOrCustom(CastOpcode, ConcatSrcVT) ||
22138 !TLI.isTypeLegal(VT))
22139 return SDValue();
22140 break;
22141 case ISD::FP_TO_SINT:
22142 case ISD::FP_TO_UINT:
22143 if (!TLI.isOperationLegalOrCustom(CastOpcode, VT) ||
22144 !TLI.isTypeLegal(ConcatSrcVT))
22145 return SDValue();
22146 break;
22147 default:
22148 llvm_unreachable("Unexpected cast opcode");
22149 }
22150
22151 // concat (cast X), (cast Y)... -> cast (concat X, Y...)
22152 SDLoc DL(N);
22153 SDValue NewConcat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatSrcVT, SrcOps);
22154 return DAG.getNode(CastOpcode, DL, VT, NewConcat);
22155 }
22156
22157 // See if this is a simple CONCAT_VECTORS with no UNDEF operands, and if one of
22158 // the operands is a SHUFFLE_VECTOR, and all other operands are also operands
22159 // of that SHUFFLE_VECTOR, create a wider SHUFFLE_VECTOR.
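// For example (illustrative, assuming a unary v4i32 shuffle):
//   (v8i32 concat_vectors (v4i32 vector_shuffle<3,2,1,0> X, undef), X)
// --> (v8i32 vector_shuffle<3,2,1,0,0,1,2,3> (concat_vectors X, undef), undef)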
22160 static SDValue combineConcatVectorOfShuffleAndItsOperands(
22161 SDNode *N, SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
22162 bool LegalOperations) {
22163 EVT VT = N->getValueType(0);
22164 EVT OpVT = N->getOperand(0).getValueType();
22165 if (VT.isScalableVector())
22166 return SDValue();
22167
22168 // For now, only allow simple 2-operand concatenations.
22169 if (N->getNumOperands() != 2)
22170 return SDValue();
22171
22172 // Don't create illegal types/shuffles when not allowed to.
22173 if ((LegalTypes && !TLI.isTypeLegal(VT)) ||
22174 (LegalOperations &&
22175 !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, VT)))
22176 return SDValue();
22177
22178 // Analyze all of the operands of the CONCAT_VECTORS. Out of all of them,
22179 // we want to find one that is: (1) a SHUFFLE_VECTOR (2) only used by us,
22180 // and (3) all operands of CONCAT_VECTORS must be either that SHUFFLE_VECTOR,
22181 // or one of the operands of that SHUFFLE_VECTOR (but not UNDEF!).
22182 // (4) and for now, the SHUFFLE_VECTOR must be unary.
22183 ShuffleVectorSDNode *SVN = nullptr;
22184 for (SDValue Op : N->ops()) {
22185 if (auto *CurSVN = dyn_cast<ShuffleVectorSDNode>(Op);
22186 CurSVN && CurSVN->getOperand(1).isUndef() && N->isOnlyUserOf(CurSVN) &&
22187 all_of(N->ops(), [CurSVN](SDValue Op) {
22188 // FIXME: can we allow UNDEF operands?
22189 return !Op.isUndef() &&
22190 (Op.getNode() == CurSVN || is_contained(CurSVN->ops(), Op));
22191 })) {
22192 SVN = CurSVN;
22193 break;
22194 }
22195 }
22196 if (!SVN)
22197 return SDValue();
22198
22199   // We are going to pad the shuffle operands, so any index that was picking
22200   // from the second operand must be adjusted.
22201 SmallVector<int, 16> AdjustedMask;
22202 AdjustedMask.reserve(SVN->getMask().size());
22203 assert(SVN->getOperand(1).isUndef() && "Expected unary shuffle!");
22204 append_range(AdjustedMask, SVN->getMask());
22205
22206 // Identity masks for the operands of the (padded) shuffle.
22207 SmallVector<int, 32> IdentityMask(2 * OpVT.getVectorNumElements());
22208 MutableArrayRef<int> FirstShufOpIdentityMask =
22209 MutableArrayRef<int>(IdentityMask)
22210 .take_front(OpVT.getVectorNumElements());
22211 MutableArrayRef<int> SecondShufOpIdentityMask =
22212 MutableArrayRef<int>(IdentityMask).take_back(OpVT.getVectorNumElements());
22213 std::iota(FirstShufOpIdentityMask.begin(), FirstShufOpIdentityMask.end(), 0);
22214 std::iota(SecondShufOpIdentityMask.begin(), SecondShufOpIdentityMask.end(),
22215 VT.getVectorNumElements());
22216
22217 // New combined shuffle mask.
22218 SmallVector<int, 32> Mask;
22219 Mask.reserve(VT.getVectorNumElements());
22220 for (SDValue Op : N->ops()) {
22221 assert(!Op.isUndef() && "Not expecting to concatenate UNDEF.");
22222 if (Op.getNode() == SVN) {
22223 append_range(Mask, AdjustedMask);
22224 continue;
22225 }
22226 if (Op == SVN->getOperand(0)) {
22227 append_range(Mask, FirstShufOpIdentityMask);
22228 continue;
22229 }
22230 if (Op == SVN->getOperand(1)) {
22231 append_range(Mask, SecondShufOpIdentityMask);
22232 continue;
22233 }
22234 llvm_unreachable("Unexpected operand!");
22235 }
22236
22237 // Don't create illegal shuffle masks.
22238 if (!TLI.isShuffleMaskLegal(Mask, VT))
22239 return SDValue();
22240
22241 // Pad the shuffle operands with UNDEF.
22242 SDLoc dl(N);
22243 std::array<SDValue, 2> ShufOps;
22244 for (auto I : zip(SVN->ops(), ShufOps)) {
22245 SDValue ShufOp = std::get<0>(I);
22246 SDValue &NewShufOp = std::get<1>(I);
22247 if (ShufOp.isUndef())
22248 NewShufOp = DAG.getUNDEF(VT);
22249 else {
22250 SmallVector<SDValue, 2> ShufOpParts(N->getNumOperands(),
22251 DAG.getUNDEF(OpVT));
22252 ShufOpParts[0] = ShufOp;
22253 NewShufOp = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, ShufOpParts);
22254 }
22255 }
22256 // Finally, create the new wide shuffle.
22257 return DAG.getVectorShuffle(VT, dl, ShufOps[0], ShufOps[1], Mask);
22258 }
22259
22260 SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
22261 // If we only have one input vector, we don't need to do any concatenation.
22262 if (N->getNumOperands() == 1)
22263 return N->getOperand(0);
22264
22265 // Check if all of the operands are undefs.
22266 EVT VT = N->getValueType(0);
22267 if (ISD::allOperandsUndef(N))
22268 return DAG.getUNDEF(VT);
22269
22270 // Optimize concat_vectors where all but the first of the vectors are undef.
22271 if (all_of(drop_begin(N->ops()),
22272 [](const SDValue &Op) { return Op.isUndef(); })) {
22273 SDValue In = N->getOperand(0);
22274 assert(In.getValueType().isVector() && "Must concat vectors");
22275
22276 // If the input is a concat_vectors, just make a larger concat by padding
22277 // with smaller undefs.
22278 if (In.getOpcode() == ISD::CONCAT_VECTORS && In.hasOneUse()) {
22279 unsigned NumOps = N->getNumOperands() * In.getNumOperands();
22280 SmallVector<SDValue, 4> Ops(In->op_begin(), In->op_end());
22281 Ops.resize(NumOps, DAG.getUNDEF(Ops[0].getValueType()));
22282 return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
22283 }
22284
22285 SDValue Scalar = peekThroughOneUseBitcasts(In);
22286
22287 // concat_vectors(scalar_to_vector(scalar), undef) ->
22288 // scalar_to_vector(scalar)
22289 if (!LegalOperations && Scalar.getOpcode() == ISD::SCALAR_TO_VECTOR &&
22290 Scalar.hasOneUse()) {
22291 EVT SVT = Scalar.getValueType().getVectorElementType();
22292 if (SVT == Scalar.getOperand(0).getValueType())
22293 Scalar = Scalar.getOperand(0);
22294 }
22295
22296 // concat_vectors(scalar, undef) -> scalar_to_vector(scalar)
22297 if (!Scalar.getValueType().isVector()) {
22298 // If the bitcast type isn't legal, it might be a trunc of a legal type;
22299 // look through the trunc so we can still do the transform:
22300 // concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar)
22301 if (Scalar->getOpcode() == ISD::TRUNCATE &&
22302 !TLI.isTypeLegal(Scalar.getValueType()) &&
22303 TLI.isTypeLegal(Scalar->getOperand(0).getValueType()))
22304 Scalar = Scalar->getOperand(0);
22305
22306 EVT SclTy = Scalar.getValueType();
22307
22308 if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
22309 return SDValue();
22310
22311 // Bail out if the vector size is not a multiple of the scalar size.
22312 if (VT.getSizeInBits() % SclTy.getSizeInBits())
22313 return SDValue();
22314
22315 unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
22316 if (VNTNumElms < 2)
22317 return SDValue();
22318
22319 EVT NVT = EVT::getVectorVT(*DAG.getContext(), SclTy, VNTNumElms);
22320 if (!TLI.isTypeLegal(NVT) || !TLI.isTypeLegal(Scalar.getValueType()))
22321 return SDValue();
22322
22323 SDValue Res = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), NVT, Scalar);
22324 return DAG.getBitcast(VT, Res);
22325 }
22326 }
22327
22328 // Fold any combination of BUILD_VECTOR or UNDEF nodes into one BUILD_VECTOR.
22329 // We have already tested above for an UNDEF only concatenation.
22330 // fold (concat_vectors (BUILD_VECTOR A, B, ...), (BUILD_VECTOR C, D, ...))
22331 // -> (BUILD_VECTOR A, B, ..., C, D, ...)
22332 auto IsBuildVectorOrUndef = [](const SDValue &Op) {
22333 return ISD::UNDEF == Op.getOpcode() || ISD::BUILD_VECTOR == Op.getOpcode();
22334 };
22335 if (llvm::all_of(N->ops(), IsBuildVectorOrUndef)) {
22336 SmallVector<SDValue, 8> Opnds;
22337 EVT SVT = VT.getScalarType();
22338
22339 EVT MinVT = SVT;
22340 if (!SVT.isFloatingPoint()) {
22341       // If the BUILD_VECTORs are built from integers, they may have different
22342 // operand types. Get the smallest type and truncate all operands to it.
22343 bool FoundMinVT = false;
22344 for (const SDValue &Op : N->ops())
22345 if (ISD::BUILD_VECTOR == Op.getOpcode()) {
22346 EVT OpSVT = Op.getOperand(0).getValueType();
22347 MinVT = (!FoundMinVT || OpSVT.bitsLE(MinVT)) ? OpSVT : MinVT;
22348 FoundMinVT = true;
22349 }
22350 assert(FoundMinVT && "Concat vector type mismatch");
22351 }
22352
22353 for (const SDValue &Op : N->ops()) {
22354 EVT OpVT = Op.getValueType();
22355 unsigned NumElts = OpVT.getVectorNumElements();
22356
22357 if (ISD::UNDEF == Op.getOpcode())
22358 Opnds.append(NumElts, DAG.getUNDEF(MinVT));
22359
22360 if (ISD::BUILD_VECTOR == Op.getOpcode()) {
22361 if (SVT.isFloatingPoint()) {
22362 assert(SVT == OpVT.getScalarType() && "Concat vector type mismatch");
22363 Opnds.append(Op->op_begin(), Op->op_begin() + NumElts);
22364 } else {
22365 for (unsigned i = 0; i != NumElts; ++i)
22366 Opnds.push_back(
22367 DAG.getNode(ISD::TRUNCATE, SDLoc(N), MinVT, Op.getOperand(i)));
22368 }
22369 }
22370 }
22371
22372 assert(VT.getVectorNumElements() == Opnds.size() &&
22373 "Concat vector type mismatch");
22374 return DAG.getBuildVector(VT, SDLoc(N), Opnds);
22375 }
22376
22377 // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
22378 // FIXME: Add support for concat_vectors(bitcast(vec0),bitcast(vec1),...).
22379 if (SDValue V = combineConcatVectorOfScalars(N, DAG))
22380 return V;
22381
22382 if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) {
22383 // Fold CONCAT_VECTORS of CONCAT_VECTORS (or undef) to VECTOR_SHUFFLE.
22384 if (SDValue V = combineConcatVectorOfConcatVectors(N, DAG))
22385 return V;
22386
22387 // Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
22388 if (SDValue V = combineConcatVectorOfExtracts(N, DAG))
22389 return V;
22390 }
22391
22392 if (SDValue V = combineConcatVectorOfCasts(N, DAG))
22393 return V;
22394
22395 if (SDValue V = combineConcatVectorOfShuffleAndItsOperands(
22396 N, DAG, TLI, LegalTypes, LegalOperations))
22397 return V;
22398
22399 // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR
22400 // nodes often generate nop CONCAT_VECTOR nodes. Scan the CONCAT_VECTOR
22401   // operands and look for CONCAT operations that place the incoming vectors
22402 // at the exact same location.
22403 //
22404 // For scalable vectors, EXTRACT_SUBVECTOR indexes are implicitly scaled.
22405 SDValue SingleSource = SDValue();
22406 unsigned PartNumElem =
22407 N->getOperand(0).getValueType().getVectorMinNumElements();
22408
22409 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
22410 SDValue Op = N->getOperand(i);
22411
22412 if (Op.isUndef())
22413 continue;
22414
22415 // Check if this is the identity extract:
22416 if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
22417 return SDValue();
22418
22419 // Find the single incoming vector for the extract_subvector.
22420 if (SingleSource.getNode()) {
22421 if (Op.getOperand(0) != SingleSource)
22422 return SDValue();
22423 } else {
22424 SingleSource = Op.getOperand(0);
22425
22426 // Check the source type is the same as the type of the result.
22427       // If not, this concat may extend the vector, so we cannot
22428       // optimize it away.
22429 if (SingleSource.getValueType() != N->getValueType(0))
22430 return SDValue();
22431 }
22432
22433 // Check that we are reading from the identity index.
22434 unsigned IdentityIndex = i * PartNumElem;
22435 if (Op.getConstantOperandAPInt(1) != IdentityIndex)
22436 return SDValue();
22437 }
22438
22439 if (SingleSource.getNode())
22440 return SingleSource;
22441
22442 return SDValue();
22443 }
22444
22445 // Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
22446 // if the subvector can be sourced for free.
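// For example (illustrative), with SubVT == v4i32 and Index == 4:
//   (v8i32 insert_subvector ?, X:v4i32, 4)  --> X
//   (v8i32 concat_vectors A:v4i32, B:v4i32) --> B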
22447 static SDValue getSubVectorSrc(SDValue V, SDValue Index, EVT SubVT) {
22448 if (V.getOpcode() == ISD::INSERT_SUBVECTOR &&
22449 V.getOperand(1).getValueType() == SubVT && V.getOperand(2) == Index) {
22450 return V.getOperand(1);
22451 }
22452 auto *IndexC = dyn_cast<ConstantSDNode>(Index);
22453 if (IndexC && V.getOpcode() == ISD::CONCAT_VECTORS &&
22454 V.getOperand(0).getValueType() == SubVT &&
22455 (IndexC->getZExtValue() % SubVT.getVectorMinNumElements()) == 0) {
22456 uint64_t SubIdx = IndexC->getZExtValue() / SubVT.getVectorMinNumElements();
22457 return V.getOperand(SubIdx);
22458 }
22459 return SDValue();
22460 }
22461
22462 static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
22463 SelectionDAG &DAG,
22464 bool LegalOperations) {
22465 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22466 SDValue BinOp = Extract->getOperand(0);
22467 unsigned BinOpcode = BinOp.getOpcode();
22468 if (!TLI.isBinOp(BinOpcode) || BinOp->getNumValues() != 1)
22469 return SDValue();
22470
22471 EVT VecVT = BinOp.getValueType();
22472 SDValue Bop0 = BinOp.getOperand(0), Bop1 = BinOp.getOperand(1);
22473 if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType())
22474 return SDValue();
22475
22476 SDValue Index = Extract->getOperand(1);
22477 EVT SubVT = Extract->getValueType(0);
22478 if (!TLI.isOperationLegalOrCustom(BinOpcode, SubVT, LegalOperations))
22479 return SDValue();
22480
22481 SDValue Sub0 = getSubVectorSrc(Bop0, Index, SubVT);
22482 SDValue Sub1 = getSubVectorSrc(Bop1, Index, SubVT);
22483
22484 // TODO: We could handle the case where only 1 operand is being inserted by
22485 // creating an extract of the other operand, but that requires checking
22486 // number of uses and/or costs.
22487 if (!Sub0 || !Sub1)
22488 return SDValue();
22489
22490 // We are inserting both operands of the wide binop only to extract back
22491 // to the narrow vector size. Eliminate all of the insert/extract:
22492 // ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y
22493 return DAG.getNode(BinOpcode, SDLoc(Extract), SubVT, Sub0, Sub1,
22494 BinOp->getFlags());
22495 }
22496
22497 /// If we are extracting a subvector produced by a wide binary operator try
22498 /// to use a narrow binary operator and/or avoid concatenation and extraction.
22499 static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
22500 bool LegalOperations) {
22501 // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
22502 // some of these bailouts with other transforms.
22503
22504 if (SDValue V = narrowInsertExtractVectorBinOp(Extract, DAG, LegalOperations))
22505 return V;
22506
22507 // The extract index must be a constant, so we can map it to a concat operand.
22508 auto *ExtractIndexC = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
22509 if (!ExtractIndexC)
22510 return SDValue();
22511
22512 // We are looking for an optionally bitcasted wide vector binary operator
22513 // feeding an extract subvector.
22514 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22515 SDValue BinOp = peekThroughBitcasts(Extract->getOperand(0));
22516 unsigned BOpcode = BinOp.getOpcode();
22517 if (!TLI.isBinOp(BOpcode) || BinOp->getNumValues() != 1)
22518 return SDValue();
22519
22520 // Exclude the fake form of fneg (fsub -0.0, x) because that is likely to be
22521 // reduced to the unary fneg when it is visited, and we probably want to deal
22522 // with fneg in a target-specific way.
22523 if (BOpcode == ISD::FSUB) {
22524 auto *C = isConstOrConstSplatFP(BinOp.getOperand(0), /*AllowUndefs*/ true);
22525 if (C && C->getValueAPF().isNegZero())
22526 return SDValue();
22527 }
22528
22529 // The binop must be a vector type, so we can extract some fraction of it.
22530 EVT WideBVT = BinOp.getValueType();
22531 // The optimisations below currently assume we are dealing with fixed length
22532 // vectors. It is possible to add support for scalable vectors, but at the
22533 // moment we've done no analysis to prove whether they are profitable or not.
22534 if (!WideBVT.isFixedLengthVector())
22535 return SDValue();
22536
22537 EVT VT = Extract->getValueType(0);
22538 unsigned ExtractIndex = ExtractIndexC->getZExtValue();
22539 assert(ExtractIndex % VT.getVectorNumElements() == 0 &&
22540 "Extract index is not a multiple of the vector length.");
22541
22542 // Bail out if this is not a proper multiple width extraction.
22543 unsigned WideWidth = WideBVT.getSizeInBits();
22544 unsigned NarrowWidth = VT.getSizeInBits();
22545 if (WideWidth % NarrowWidth != 0)
22546 return SDValue();
22547
22548 // Bail out if we are extracting a fraction of a single operation. This can
22549 // occur because we potentially looked through a bitcast of the binop.
22550 unsigned NarrowingRatio = WideWidth / NarrowWidth;
22551 unsigned WideNumElts = WideBVT.getVectorNumElements();
22552 if (WideNumElts % NarrowingRatio != 0)
22553 return SDValue();
22554
22555 // Bail out if the target does not support a narrower version of the binop.
22556 EVT NarrowBVT = EVT::getVectorVT(*DAG.getContext(), WideBVT.getScalarType(),
22557 WideNumElts / NarrowingRatio);
22558 if (!TLI.isOperationLegalOrCustomOrPromote(BOpcode, NarrowBVT))
22559 return SDValue();
22560
22561 // If extraction is cheap, we don't need to look at the binop operands
22562 // for concat ops. The narrow binop alone makes this transform profitable.
22563 // We can't just reuse the original extract index operand because we may have
22564 // bitcasted.
22565 unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements();
22566 unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
22567 if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) &&
22568 BinOp.hasOneUse() && Extract->getOperand(0)->hasOneUse()) {
22569 // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
22570 SDLoc DL(Extract);
22571 SDValue NewExtIndex = DAG.getVectorIdxConstant(ExtBOIdx, DL);
22572 SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
22573 BinOp.getOperand(0), NewExtIndex);
22574 SDValue Y = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
22575 BinOp.getOperand(1), NewExtIndex);
22576 SDValue NarrowBinOp =
22577 DAG.getNode(BOpcode, DL, NarrowBVT, X, Y, BinOp->getFlags());
22578 return DAG.getBitcast(VT, NarrowBinOp);
22579 }
22580
22581 // Only handle the case where we are doubling and then halving. A larger ratio
22582 // may require more than two narrow binops to replace the wide binop.
22583 if (NarrowingRatio != 2)
22584 return SDValue();
22585
22586 // TODO: The motivating case for this transform is an x86 AVX1 target. That
22587 // target has temptingly almost legal versions of bitwise logic ops in 256-bit
22588 // flavors, but no other 256-bit integer support. This could be extended to
22589 // handle any binop, but that may require fixing/adding other folds to avoid
22590 // codegen regressions.
22591 if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR)
22592 return SDValue();
22593
22594 // We need at least one concatenation operation of a binop operand to make
22595 // this transform worthwhile. The concat must double the input vector sizes.
22596 auto GetSubVector = [ConcatOpNum](SDValue V) -> SDValue {
22597 if (V.getOpcode() == ISD::CONCAT_VECTORS && V.getNumOperands() == 2)
22598 return V.getOperand(ConcatOpNum);
22599 return SDValue();
22600 };
22601 SDValue SubVecL = GetSubVector(peekThroughBitcasts(BinOp.getOperand(0)));
22602 SDValue SubVecR = GetSubVector(peekThroughBitcasts(BinOp.getOperand(1)));
22603
22604 if (SubVecL || SubVecR) {
22605 // If a binop operand was not the result of a concat, we must extract a
22606 // half-sized operand for our new narrow binop:
22607 // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
22608 // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC)
22609 // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN
22610 SDLoc DL(Extract);
22611 SDValue IndexC = DAG.getVectorIdxConstant(ExtBOIdx, DL);
22612 SDValue X = SubVecL ? DAG.getBitcast(NarrowBVT, SubVecL)
22613 : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
22614 BinOp.getOperand(0), IndexC);
22615
22616 SDValue Y = SubVecR ? DAG.getBitcast(NarrowBVT, SubVecR)
22617 : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
22618 BinOp.getOperand(1), IndexC);
22619
22620 SDValue NarrowBinOp = DAG.getNode(BOpcode, DL, NarrowBVT, X, Y);
22621 return DAG.getBitcast(VT, NarrowBinOp);
22622 }
22623
22624 return SDValue();
22625 }
22626
22627 /// If we are extracting a subvector from a wide vector load, convert to a
22628 /// narrow load to eliminate the extraction:
22629 /// (extract_subvector (load wide vector)) --> (load narrow vector)
static SDValue narrowExtractedVectorLoad(SDNode *Extract, SelectionDAG &DAG) {
22631 // TODO: Add support for big-endian. The offset calculation must be adjusted.
22632 if (DAG.getDataLayout().isBigEndian())
22633 return SDValue();
22634
22635 auto *Ld = dyn_cast<LoadSDNode>(Extract->getOperand(0));
22636 if (!Ld || Ld->getExtensionType() || !Ld->isSimple())
22637 return SDValue();
22638
22639 // Allow targets to opt-out.
22640 EVT VT = Extract->getValueType(0);
22641
22642 // We can only create byte sized loads.
22643 if (!VT.isByteSized())
22644 return SDValue();
22645
22646 unsigned Index = Extract->getConstantOperandVal(1);
22647 unsigned NumElts = VT.getVectorMinNumElements();
22648
22649 // The definition of EXTRACT_SUBVECTOR states that the index must be a
22650 // multiple of the minimum number of elements in the result type.
22651 assert(Index % NumElts == 0 && "The extract subvector index is not a "
22652 "multiple of the result's element count");
22653
22654 // It's fine to use TypeSize here as we know the offset will not be negative.
22655 TypeSize Offset = VT.getStoreSize() * (Index / NumElts);
22656
22657 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22658 if (!TLI.shouldReduceLoadWidth(Ld, Ld->getExtensionType(), VT))
22659 return SDValue();
22660
22661 // The narrow load will be offset from the base address of the old load if
22662 // we are extracting from something besides index 0 (little-endian).
22663 SDLoc DL(Extract);
22664
22665 // TODO: Use "BaseIndexOffset" to make this more effective.
22666 SDValue NewAddr = DAG.getMemBasePlusOffset(Ld->getBasePtr(), Offset, DL);
22667
22668 uint64_t StoreSize = MemoryLocation::getSizeOrUnknown(VT.getStoreSize());
22669 MachineFunction &MF = DAG.getMachineFunction();
22670 MachineMemOperand *MMO;
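  // A scalable offset cannot be folded into the original pointer info, so use
  // pointer info that only carries the address space.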
22671 if (Offset.isScalable()) {
22672 MachinePointerInfo MPI =
22673 MachinePointerInfo(Ld->getPointerInfo().getAddrSpace());
22674 MMO = MF.getMachineMemOperand(Ld->getMemOperand(), MPI, StoreSize);
22675 } else
22676 MMO = MF.getMachineMemOperand(Ld->getMemOperand(), Offset.getFixedValue(),
22677 StoreSize);
22678
22679 SDValue NewLd = DAG.getLoad(VT, DL, Ld->getChain(), NewAddr, MMO);
22680 DAG.makeEquivalentMemoryOrdering(Ld, NewLd);
22681 return NewLd;
22682 }
22683
22684 /// Given EXTRACT_SUBVECTOR(VECTOR_SHUFFLE(Op0, Op1, Mask)),
22685 /// try to produce VECTOR_SHUFFLE(EXTRACT_SUBVECTOR(Op?, ?),
22686 /// EXTRACT_SUBVECTOR(Op?, ?),
///                               Mask')
22688 /// iff it is legal and profitable to do so. Notably, the trimmed mask
22689 /// (containing only the elements that are extracted)
22690 /// must reference at most two subvectors.
static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
                                                     SelectionDAG &DAG,
                                                     const TargetLowering &TLI,
                                                     bool LegalOperations) {
22695 assert(N->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
22696 "Must only be called on EXTRACT_SUBVECTOR's");
22697
22698 SDValue N0 = N->getOperand(0);
22699
22700 // Only deal with non-scalable vectors.
22701 EVT NarrowVT = N->getValueType(0);
22702 EVT WideVT = N0.getValueType();
22703 if (!NarrowVT.isFixedLengthVector() || !WideVT.isFixedLengthVector())
22704 return SDValue();
22705
22706 // The operand must be a shufflevector.
22707 auto *WideShuffleVector = dyn_cast<ShuffleVectorSDNode>(N0);
22708 if (!WideShuffleVector)
22709 return SDValue();
22710
  // The old shuffle needs to go away.
22712 if (!WideShuffleVector->hasOneUse())
22713 return SDValue();
22714
22715 // And the narrow shufflevector that we'll form must be legal.
22716 if (LegalOperations &&
22717 !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, NarrowVT))
22718 return SDValue();
22719
22720 uint64_t FirstExtractedEltIdx = N->getConstantOperandVal(1);
22721 int NumEltsExtracted = NarrowVT.getVectorNumElements();
22722 assert((FirstExtractedEltIdx % NumEltsExtracted) == 0 &&
22723 "Extract index is not a multiple of the output vector length.");
22724
22725 int WideNumElts = WideVT.getVectorNumElements();
22726
22727 SmallVector<int, 16> NewMask;
22728 NewMask.reserve(NumEltsExtracted);
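  // Which (operand, subvector index) pairs will the new narrow shuffle read
  // from? We can handle at most two.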
22729 SmallSetVector<std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>, 2>
22730 DemandedSubvectors;
22731
22732 // Try to decode the wide mask into narrow mask from at most two subvectors.
22733 for (int M : WideShuffleVector->getMask().slice(FirstExtractedEltIdx,
22734 NumEltsExtracted)) {
22735 assert((M >= -1) && (M < (2 * WideNumElts)) &&
22736 "Out-of-bounds shuffle mask?");
22737
22738 if (M < 0) {
22739 // Does not depend on operands, does not require adjustment.
22740 NewMask.emplace_back(M);
22741 continue;
22742 }
22743
22744 // From which operand of the shuffle does this shuffle mask element pick?
22745 int WideShufOpIdx = M / WideNumElts;
22746 // Which element of that operand is picked?
22747 int OpEltIdx = M % WideNumElts;
22748
22749 assert((OpEltIdx + WideShufOpIdx * WideNumElts) == M &&
22750 "Shuffle mask vector decomposition failure.");
22751
22752 // And which NumEltsExtracted-sized subvector of that operand is that?
22753 int OpSubvecIdx = OpEltIdx / NumEltsExtracted;
22754 // And which element within that subvector of that operand is that?
22755 int OpEltIdxInSubvec = OpEltIdx % NumEltsExtracted;
22756
22757 assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted) == OpEltIdx &&
22758 "Shuffle mask subvector decomposition failure.");
22759
22760 assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted +
22761 WideShufOpIdx * WideNumElts) == M &&
22762 "Shuffle mask full decomposition failure.");
22763
22764 SDValue Op = WideShuffleVector->getOperand(WideShufOpIdx);
22765
22766 if (Op.isUndef()) {
22767 // Picking from an undef operand. Let's adjust mask instead.
22768 NewMask.emplace_back(-1);
22769 continue;
22770 }
22771
22772 // Profitability check: only deal with extractions from the first subvector.
22773 if (OpSubvecIdx != 0)
22774 return SDValue();
22775
22776 const std::pair<SDValue, int> DemandedSubvector =
22777 std::make_pair(Op, OpSubvecIdx);
22778
22779 if (DemandedSubvectors.insert(DemandedSubvector)) {
22780 if (DemandedSubvectors.size() > 2)
22781 return SDValue(); // We can't handle more than two subvectors.
22782 // How many elements into the WideVT does this subvector start?
22783 int Index = NumEltsExtracted * OpSubvecIdx;
22784 // Bail out if the extraction isn't going to be cheap.
22785 if (!TLI.isExtractSubvectorCheap(NarrowVT, WideVT, Index))
22786 return SDValue();
22787 }
22788
22789 // Ok, but from which operand of the new shuffle will this element pick?
22790 int NewOpIdx =
22791 getFirstIndexOf(DemandedSubvectors.getArrayRef(), DemandedSubvector);
22792 assert((NewOpIdx == 0 || NewOpIdx == 1) && "Unexpected operand index.");
22793
22794 int AdjM = OpEltIdxInSubvec + NewOpIdx * NumEltsExtracted;
22795 NewMask.emplace_back(AdjM);
22796 }
22797 assert(NewMask.size() == (unsigned)NumEltsExtracted && "Produced bad mask.");
22798 assert(DemandedSubvectors.size() <= 2 &&
22799 "Should have ended up demanding at most two subvectors.");
22800
22801 // Did we discover that the shuffle does not actually depend on operands?
22802 if (DemandedSubvectors.empty())
22803 return DAG.getUNDEF(NarrowVT);
22804
  // We still perform the exact same EXTRACT_SUBVECTOR, just on different
  // operand[s]/index[es], so there is no point in checking its legality.
22807
22808 // Do not turn a legal shuffle into an illegal one.
22809 if (TLI.isShuffleMaskLegal(WideShuffleVector->getMask(), WideVT) &&
22810 !TLI.isShuffleMaskLegal(NewMask, NarrowVT))
22811 return SDValue();
22812
22813 SDLoc DL(N);
22814
22815 SmallVector<SDValue, 2> NewOps;
22816 for (const std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>
22817 &DemandedSubvector : DemandedSubvectors) {
22818 // How many elements into the WideVT does this subvector start?
22819 int Index = NumEltsExtracted * DemandedSubvector.second;
22820 SDValue IndexC = DAG.getVectorIdxConstant(Index, DL);
22821 NewOps.emplace_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowVT,
22822 DemandedSubvector.first, IndexC));
22823 }
22824 assert((NewOps.size() == 1 || NewOps.size() == 2) &&
22825 "Should end up with either one or two ops");
22826
22827 // If we ended up with only one operand, pad with an undef.
22828 if (NewOps.size() == 1)
22829 NewOps.emplace_back(DAG.getUNDEF(NarrowVT));
22830
22831 return DAG.getVectorShuffle(NarrowVT, DL, NewOps[0], NewOps[1], NewMask);
22832 }
22833
SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
22835 EVT NVT = N->getValueType(0);
22836 SDValue V = N->getOperand(0);
22837 uint64_t ExtIdx = N->getConstantOperandVal(1);
22838
22839 // Extract from UNDEF is UNDEF.
22840 if (V.isUndef())
22841 return DAG.getUNDEF(NVT);
22842
22843 if (TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, NVT))
22844 if (SDValue NarrowLoad = narrowExtractedVectorLoad(N, DAG))
22845 return NarrowLoad;
22846
22847 // Combine an extract of an extract into a single extract_subvector.
22848 // ext (ext X, C), 0 --> ext X, C
22849 if (ExtIdx == 0 && V.getOpcode() == ISD::EXTRACT_SUBVECTOR && V.hasOneUse()) {
22850 if (TLI.isExtractSubvectorCheap(NVT, V.getOperand(0).getValueType(),
22851 V.getConstantOperandVal(1)) &&
22852 TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NVT)) {
22853 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT, V.getOperand(0),
22854 V.getOperand(1));
22855 }
22856 }
22857
22858 // ty1 extract_vector(ty2 splat(V))) -> ty1 splat(V)
22859 if (V.getOpcode() == ISD::SPLAT_VECTOR)
22860 if (DAG.isConstantValueOfAnyType(V.getOperand(0)) || V.hasOneUse())
22861 if (!LegalOperations || TLI.isOperationLegal(ISD::SPLAT_VECTOR, NVT))
22862 return DAG.getSplatVector(NVT, SDLoc(N), V.getOperand(0));
22863
22864 // Try to move vector bitcast after extract_subv by scaling extraction index:
22865 // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index')
22866 if (V.getOpcode() == ISD::BITCAST &&
22867 V.getOperand(0).getValueType().isVector() &&
22868 (!LegalOperations || TLI.isOperationLegal(ISD::BITCAST, NVT))) {
22869 SDValue SrcOp = V.getOperand(0);
22870 EVT SrcVT = SrcOp.getValueType();
22871 unsigned SrcNumElts = SrcVT.getVectorMinNumElements();
22872 unsigned DestNumElts = V.getValueType().getVectorMinNumElements();
22873 if ((SrcNumElts % DestNumElts) == 0) {
22874 unsigned SrcDestRatio = SrcNumElts / DestNumElts;
22875 ElementCount NewExtEC = NVT.getVectorElementCount() * SrcDestRatio;
22876 EVT NewExtVT = EVT::getVectorVT(*DAG.getContext(), SrcVT.getScalarType(),
22877 NewExtEC);
22878 if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
22879 SDLoc DL(N);
22880 SDValue NewIndex = DAG.getVectorIdxConstant(ExtIdx * SrcDestRatio, DL);
22881 SDValue NewExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
22882 V.getOperand(0), NewIndex);
22883 return DAG.getBitcast(NVT, NewExtract);
22884 }
22885 }
22886 if ((DestNumElts % SrcNumElts) == 0) {
22887 unsigned DestSrcRatio = DestNumElts / SrcNumElts;
22888 if (NVT.getVectorElementCount().isKnownMultipleOf(DestSrcRatio)) {
22889 ElementCount NewExtEC =
22890 NVT.getVectorElementCount().divideCoefficientBy(DestSrcRatio);
22891 EVT ScalarVT = SrcVT.getScalarType();
22892 if ((ExtIdx % DestSrcRatio) == 0) {
22893 SDLoc DL(N);
22894 unsigned IndexValScaled = ExtIdx / DestSrcRatio;
22895 EVT NewExtVT =
22896 EVT::getVectorVT(*DAG.getContext(), ScalarVT, NewExtEC);
22897 if (TLI.isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, NewExtVT)) {
22898 SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL);
22899 SDValue NewExtract =
22900 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewExtVT,
22901 V.getOperand(0), NewIndex);
22902 return DAG.getBitcast(NVT, NewExtract);
22903 }
22904 if (NewExtEC.isScalar() &&
22905 TLI.isOperationLegalOrCustom(ISD::EXTRACT_VECTOR_ELT, ScalarVT)) {
22906 SDValue NewIndex = DAG.getVectorIdxConstant(IndexValScaled, DL);
22907 SDValue NewExtract =
22908 DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT,
22909 V.getOperand(0), NewIndex);
22910 return DAG.getBitcast(NVT, NewExtract);
22911 }
22912 }
22913 }
22914 }
22915 }
22916
22917 if (V.getOpcode() == ISD::CONCAT_VECTORS) {
22918 unsigned ExtNumElts = NVT.getVectorMinNumElements();
22919 EVT ConcatSrcVT = V.getOperand(0).getValueType();
22920 assert(ConcatSrcVT.getVectorElementType() == NVT.getVectorElementType() &&
22921 "Concat and extract subvector do not change element type");
22922 assert((ExtIdx % ExtNumElts) == 0 &&
22923 "Extract index is not a multiple of the input vector length.");
22924
22925 unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorMinNumElements();
22926 unsigned ConcatOpIdx = ExtIdx / ConcatSrcNumElts;
22927
22928 // If the concatenated source types match this extract, it's a direct
22929 // simplification:
22930 // extract_subvec (concat V1, V2, ...), i --> Vi
22931 if (NVT.getVectorElementCount() == ConcatSrcVT.getVectorElementCount())
22932 return V.getOperand(ConcatOpIdx);
22933
22934 // If the concatenated source vectors are a multiple length of this extract,
22935 // then extract a fraction of one of those source vectors directly from a
22936 // concat operand. Example:
    // v2i8 extract_subvec (v16i8 concat (v8i8 X), (v8i8 Y)), 14 -->
22938 // v2i8 extract_subvec v8i8 Y, 6
22939 if (NVT.isFixedLengthVector() && ConcatSrcVT.isFixedLengthVector() &&
22940 ConcatSrcNumElts % ExtNumElts == 0) {
22941 SDLoc DL(N);
22942 unsigned NewExtIdx = ExtIdx - ConcatOpIdx * ConcatSrcNumElts;
22943 assert(NewExtIdx + ExtNumElts <= ConcatSrcNumElts &&
22944 "Trying to extract from >1 concat operand?");
22945 assert(NewExtIdx % ExtNumElts == 0 &&
22946 "Extract index is not a multiple of the input vector length.");
22947 SDValue NewIndexC = DAG.getVectorIdxConstant(NewExtIdx, DL);
22948 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT,
22949 V.getOperand(ConcatOpIdx), NewIndexC);
22950 }
22951 }
22952
22953 if (SDValue V =
22954 foldExtractSubvectorFromShuffleVector(N, DAG, TLI, LegalOperations))
22955 return V;
22956
22957 V = peekThroughBitcasts(V);
22958
  // If the input is a build vector, try to make a smaller build vector.
22960 if (V.getOpcode() == ISD::BUILD_VECTOR) {
22961 EVT InVT = V.getValueType();
22962 unsigned ExtractSize = NVT.getSizeInBits();
22963 unsigned EltSize = InVT.getScalarSizeInBits();
22964 // Only do this if we won't split any elements.
22965 if (ExtractSize % EltSize == 0) {
22966 unsigned NumElems = ExtractSize / EltSize;
22967 EVT EltVT = InVT.getVectorElementType();
22968 EVT ExtractVT =
22969 NumElems == 1 ? EltVT
22970 : EVT::getVectorVT(*DAG.getContext(), EltVT, NumElems);
22971 if ((Level < AfterLegalizeDAG ||
22972 (NumElems == 1 ||
22973 TLI.isOperationLegal(ISD::BUILD_VECTOR, ExtractVT))) &&
22974 (!LegalTypes || TLI.isTypeLegal(ExtractVT))) {
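        // Convert the extract index into an index of the wide vector's
        // elements (the build_vector operands).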
22975 unsigned IdxVal = (ExtIdx * NVT.getScalarSizeInBits()) / EltSize;
22976
22977 if (NumElems == 1) {
22978 SDValue Src = V->getOperand(IdxVal);
22979 if (EltVT != Src.getValueType())
            Src = DAG.getNode(ISD::TRUNCATE, SDLoc(N), EltVT, Src);
22981 return DAG.getBitcast(NVT, Src);
22982 }
22983
22984 // Extract the pieces from the original build_vector.
22985 SDValue BuildVec = DAG.getBuildVector(ExtractVT, SDLoc(N),
22986 V->ops().slice(IdxVal, NumElems));
22987 return DAG.getBitcast(NVT, BuildVec);
22988 }
22989 }
22990 }
22991
22992 if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
22993 // Handle only simple case where vector being inserted and vector
22994 // being extracted are of same size.
22995 EVT SmallVT = V.getOperand(1).getValueType();
22996 if (!NVT.bitsEq(SmallVT))
22997 return SDValue();
22998
22999 // Combine:
23000 // (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx)
23001 // Into:
    //    indices are equal or bit offsets are equal => V2
23003 // otherwise => (extract_subvec V1, ExtIdx)
23004 uint64_t InsIdx = V.getConstantOperandVal(2);
23005 if (InsIdx * SmallVT.getScalarSizeInBits() ==
23006 ExtIdx * NVT.getScalarSizeInBits()) {
23007 if (LegalOperations && !TLI.isOperationLegal(ISD::BITCAST, NVT))
23008 return SDValue();
23009
23010 return DAG.getBitcast(NVT, V.getOperand(1));
23011 }
23012 return DAG.getNode(
23013 ISD::EXTRACT_SUBVECTOR, SDLoc(N), NVT,
23014 DAG.getBitcast(N->getOperand(0).getValueType(), V.getOperand(0)),
23015 N->getOperand(1));
23016 }
23017
23018 if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG, LegalOperations))
23019 return NarrowBOp;
23020
23021 if (SimplifyDemandedVectorElts(SDValue(N, 0)))
23022 return SDValue(N, 0);
23023
23024 return SDValue();
23025 }
23026
23027 /// Try to convert a wide shuffle of concatenated vectors into 2 narrow shuffles
23028 /// followed by concatenation. Narrow vector ops may have better performance
23029 /// than wide ops, and this can unlock further narrowing of other vector ops.
23030 /// Targets can invert this transform later if it is not profitable.
static SDValue foldShuffleOfConcatUndefs(ShuffleVectorSDNode *Shuf,
                                         SelectionDAG &DAG) {
23033 SDValue N0 = Shuf->getOperand(0), N1 = Shuf->getOperand(1);
23034 if (N0.getOpcode() != ISD::CONCAT_VECTORS || N0.getNumOperands() != 2 ||
23035 N1.getOpcode() != ISD::CONCAT_VECTORS || N1.getNumOperands() != 2 ||
23036 !N0.getOperand(1).isUndef() || !N1.getOperand(1).isUndef())
23037 return SDValue();
23038
23039 // Split the wide shuffle mask into halves. Any mask element that is accessing
23040 // operand 1 is offset down to account for narrowing of the vectors.
23041 ArrayRef<int> Mask = Shuf->getMask();
23042 EVT VT = Shuf->getValueType(0);
23043 unsigned NumElts = VT.getVectorNumElements();
23044 unsigned HalfNumElts = NumElts / 2;
23045 SmallVector<int, 16> Mask0(HalfNumElts, -1);
23046 SmallVector<int, 16> Mask1(HalfNumElts, -1);
23047 for (unsigned i = 0; i != NumElts; ++i) {
23048 if (Mask[i] == -1)
23049 continue;
23050 // If we reference the upper (undef) subvector then the element is undef.
23051 if ((Mask[i] % NumElts) >= HalfNumElts)
23052 continue;
23053 int M = Mask[i] < (int)NumElts ? Mask[i] : Mask[i] - (int)HalfNumElts;
23054 if (i < HalfNumElts)
23055 Mask0[i] = M;
23056 else
23057 Mask1[i - HalfNumElts] = M;
23058 }
23059
23060 // Ask the target if this is a valid transform.
23061 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
23062 EVT HalfVT = EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(),
23063 HalfNumElts);
23064 if (!TLI.isShuffleMaskLegal(Mask0, HalfVT) ||
23065 !TLI.isShuffleMaskLegal(Mask1, HalfVT))
23066 return SDValue();
23067
23068 // shuffle (concat X, undef), (concat Y, undef), Mask -->
23069 // concat (shuffle X, Y, Mask0), (shuffle X, Y, Mask1)
23070 SDValue X = N0.getOperand(0), Y = N1.getOperand(0);
23071 SDLoc DL(Shuf);
23072 SDValue Shuf0 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask0);
23073 SDValue Shuf1 = DAG.getVectorShuffle(HalfVT, DL, X, Y, Mask1);
23074 return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Shuf0, Shuf1);
23075 }
23076
// Tries to turn a shuffle of two CONCAT_VECTORS into a single concat,
// or turn a shuffle of a single concat into a simpler shuffle followed by a
// concat.
static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
23080 EVT VT = N->getValueType(0);
23081 unsigned NumElts = VT.getVectorNumElements();
23082
23083 SDValue N0 = N->getOperand(0);
23084 SDValue N1 = N->getOperand(1);
23085 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
23086 ArrayRef<int> Mask = SVN->getMask();
23087
23088 SmallVector<SDValue, 4> Ops;
23089 EVT ConcatVT = N0.getOperand(0).getValueType();
23090 unsigned NumElemsPerConcat = ConcatVT.getVectorNumElements();
23091 unsigned NumConcats = NumElts / NumElemsPerConcat;
23092
23093 auto IsUndefMaskElt = [](int i) { return i == -1; };
23094
23095 // Special case: shuffle(concat(A,B)) can be more efficiently represented
23096 // as concat(shuffle(A,B),UNDEF) if the shuffle doesn't set any of the high
23097 // half vector elements.
23098 if (NumElemsPerConcat * 2 == NumElts && N1.isUndef() &&
23099 llvm::all_of(Mask.slice(NumElemsPerConcat, NumElemsPerConcat),
23100 IsUndefMaskElt)) {
23101 N0 = DAG.getVectorShuffle(ConcatVT, SDLoc(N), N0.getOperand(0),
23102 N0.getOperand(1),
23103 Mask.slice(0, NumElemsPerConcat));
23104 N1 = DAG.getUNDEF(ConcatVT);
23105 return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, N0, N1);
23106 }
23107
23108 // Look at every vector that's inserted. We're looking for exact
  // subvector-sized copies from a concatenated vector.
23110 for (unsigned I = 0; I != NumConcats; ++I) {
23111 unsigned Begin = I * NumElemsPerConcat;
23112 ArrayRef<int> SubMask = Mask.slice(Begin, NumElemsPerConcat);
23113
23114 // Make sure we're dealing with a copy.
23115 if (llvm::all_of(SubMask, IsUndefMaskElt)) {
23116 Ops.push_back(DAG.getUNDEF(ConcatVT));
23117 continue;
23118 }
23119
23120 int OpIdx = -1;
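    // Every defined element of this sub-mask must be a position-preserving
    // copy from a single concat operand.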
23121 for (int i = 0; i != (int)NumElemsPerConcat; ++i) {
23122 if (IsUndefMaskElt(SubMask[i]))
23123 continue;
23124 if ((SubMask[i] % (int)NumElemsPerConcat) != i)
23125 return SDValue();
23126 int EltOpIdx = SubMask[i] / NumElemsPerConcat;
23127 if (0 <= OpIdx && EltOpIdx != OpIdx)
23128 return SDValue();
23129 OpIdx = EltOpIdx;
23130 }
23131 assert(0 <= OpIdx && "Unknown concat_vectors op");
23132
23133 if (OpIdx < (int)N0.getNumOperands())
23134 Ops.push_back(N0.getOperand(OpIdx));
23135 else
23136 Ops.push_back(N1.getOperand(OpIdx - N0.getNumOperands()));
23137 }
23138
23139 return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
23140 }
23141
23142 // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
23143 // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
23144 //
23145 // SHUFFLE(BUILD_VECTOR(), BUILD_VECTOR()) -> BUILD_VECTOR() is always
23146 // a simplification in some sense, but it isn't appropriate in general: some
23147 // BUILD_VECTORs are substantially cheaper than others. The general case
23148 // of a BUILD_VECTOR requires inserting each element individually (or
23149 // performing the equivalent in a temporary stack variable). A BUILD_VECTOR of
23150 // all constants is a single constant pool load. A BUILD_VECTOR where each
23151 // element is identical is a splat. A BUILD_VECTOR where most of the operands
23152 // are undef lowers to a small number of element insertions.
23153 //
23154 // To deal with this, we currently use a bunch of mostly arbitrary heuristics.
23155 // We don't fold shuffles where one side is a non-zero constant, and we don't
23156 // fold shuffles if the resulting (non-splat) BUILD_VECTOR would have duplicate
23157 // non-constant operands. This seems to work out reasonably well in practice.
static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN,
                                       SelectionDAG &DAG,
                                       const TargetLowering &TLI) {
23161 EVT VT = SVN->getValueType(0);
23162 unsigned NumElts = VT.getVectorNumElements();
23163 SDValue N0 = SVN->getOperand(0);
23164 SDValue N1 = SVN->getOperand(1);
23165
23166 if (!N0->hasOneUse())
23167 return SDValue();
23168
  // If only one of N0,N1 is constant, bail out if it is not ALL_ZEROS as
23170 // discussed above.
23171 if (!N1.isUndef()) {
23172 if (!N1->hasOneUse())
23173 return SDValue();
23174
23175 bool N0AnyConst = isAnyConstantBuildVector(N0);
23176 bool N1AnyConst = isAnyConstantBuildVector(N1);
23177 if (N0AnyConst && !N1AnyConst && !ISD::isBuildVectorAllZeros(N0.getNode()))
23178 return SDValue();
23179 if (!N0AnyConst && N1AnyConst && !ISD::isBuildVectorAllZeros(N1.getNode()))
23180 return SDValue();
23181 }
23182
23183 // If both inputs are splats of the same value then we can safely merge this
23184 // to a single BUILD_VECTOR with undef elements based on the shuffle mask.
23185 bool IsSplat = false;
23186 auto *BV0 = dyn_cast<BuildVectorSDNode>(N0);
23187 auto *BV1 = dyn_cast<BuildVectorSDNode>(N1);
23188 if (BV0 && BV1)
23189 if (SDValue Splat0 = BV0->getSplatValue())
23190 IsSplat = (Splat0 == BV1->getSplatValue());
23191
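  // Walk the mask and collect the scalar operand chosen for each output lane.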
23192 SmallVector<SDValue, 8> Ops;
23193 SmallSet<SDValue, 16> DuplicateOps;
23194 for (int M : SVN->getMask()) {
23195 SDValue Op = DAG.getUNDEF(VT.getScalarType());
23196 if (M >= 0) {
23197 int Idx = M < (int)NumElts ? M : M - NumElts;
23198 SDValue &S = (M < (int)NumElts ? N0 : N1);
23199 if (S.getOpcode() == ISD::BUILD_VECTOR) {
23200 Op = S.getOperand(Idx);
23201 } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR) {
23202 SDValue Op0 = S.getOperand(0);
23203 Op = Idx == 0 ? Op0 : DAG.getUNDEF(Op0.getValueType());
23204 } else {
23205 // Operand can't be combined - bail out.
23206 return SDValue();
23207 }
23208 }
23209
23210 // Don't duplicate a non-constant BUILD_VECTOR operand unless we're
23211 // generating a splat; semantically, this is fine, but it's likely to
23212 // generate low-quality code if the target can't reconstruct an appropriate
23213 // shuffle.
23214 if (!Op.isUndef() && !isIntOrFPConstant(Op))
23215 if (!IsSplat && !DuplicateOps.insert(Op).second)
23216 return SDValue();
23217
23218 Ops.push_back(Op);
23219 }
23220
23221 // BUILD_VECTOR requires all inputs to be of the same type, find the
23222 // maximum type and extend them all.
23223 EVT SVT = VT.getScalarType();
23224 if (SVT.isInteger())
23225 for (SDValue &Op : Ops)
23226 SVT = (SVT.bitsLT(Op.getValueType()) ? Op.getValueType() : SVT);
23227 if (SVT != VT.getScalarType())
23228 for (SDValue &Op : Ops)
23229 Op = Op.isUndef() ? DAG.getUNDEF(SVT)
23230 : (TLI.isZExtFree(Op.getValueType(), SVT)
23231 ? DAG.getZExtOrTrunc(Op, SDLoc(SVN), SVT)
23232 : DAG.getSExtOrTrunc(Op, SDLoc(SVN), SVT));
23233 return DAG.getBuildVector(VT, SDLoc(SVN), Ops);
23234 }
23235
23236 // Match shuffles that can be converted to *_vector_extend_in_reg.
23237 // This is often generated during legalization.
23238 // e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src)),
23239 // and returns the EVT to which the extension should be performed.
23240 // NOTE: this assumes that the src is the first operand of the shuffle.
static std::optional<EVT> canCombineShuffleToExtendVectorInreg(
    unsigned Opcode, EVT VT, std::function<bool(unsigned)> Match,
    SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
    bool LegalOperations) {
23245 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
23246
23247 // TODO Add support for big-endian when we have a test case.
23248 if (!VT.isInteger() || IsBigEndian)
23249 return std::nullopt;
23250
23251 unsigned NumElts = VT.getVectorNumElements();
23252 unsigned EltSizeInBits = VT.getScalarSizeInBits();
23253
23254 // Attempt to match a '*_extend_vector_inreg' shuffle, we just search for
23255 // power-of-2 extensions as they are the most likely.
23256 // FIXME: should try Scale == NumElts case too,
23257 for (unsigned Scale = 2; Scale < NumElts; Scale *= 2) {
23258 // The vector width must be a multiple of Scale.
23259 if (NumElts % Scale != 0)
23260 continue;
23261
23262 EVT OutSVT = EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits * Scale);
23263 EVT OutVT = EVT::getVectorVT(*DAG.getContext(), OutSVT, NumElts / Scale);
23264
23265 if ((LegalTypes && !TLI.isTypeLegal(OutVT)) ||
23266 (LegalOperations && !TLI.isOperationLegalOrCustom(Opcode, OutVT)))
23267 continue;
23268
23269 if (Match(Scale))
23270 return OutVT;
23271 }
23272
23273 return std::nullopt;
23274 }
23275
23276 // Match shuffles that can be converted to any_vector_extend_in_reg.
23277 // This is often generated during legalization.
23278 // e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src))
static SDValue combineShuffleToAnyExtendVectorInreg(ShuffleVectorSDNode *SVN,
                                                    SelectionDAG &DAG,
                                                    const TargetLowering &TLI,
                                                    bool LegalOperations) {
23283 EVT VT = SVN->getValueType(0);
23284 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
23285
23286 // TODO Add support for big-endian when we have a test case.
23287 if (!VT.isInteger() || IsBigEndian)
23288 return SDValue();
23289
23290 // shuffle<0,-1,1,-1> == (v2i64 anyextend_vector_inreg(v4i32))
23291 auto isAnyExtend = [NumElts = VT.getVectorNumElements(),
23292 Mask = SVN->getMask()](unsigned Scale) {
23293 for (unsigned i = 0; i != NumElts; ++i) {
23294 if (Mask[i] < 0)
23295 continue;
23296 if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
23297 continue;
23298 return false;
23299 }
23300 return true;
23301 };
23302
23303 unsigned Opcode = ISD::ANY_EXTEND_VECTOR_INREG;
23304 SDValue N0 = SVN->getOperand(0);
23305 // Never create an illegal type. Only create unsupported operations if we
23306 // are pre-legalization.
23307 std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
23308 Opcode, VT, isAnyExtend, DAG, TLI, /*LegalTypes=*/true, LegalOperations);
23309 if (!OutVT)
23310 return SDValue();
23311 return DAG.getBitcast(VT, DAG.getNode(Opcode, SDLoc(SVN), *OutVT, N0));
23312 }
23313
23314 // Match shuffles that can be converted to zero_extend_vector_inreg.
23315 // This is often generated during legalization.
23316 // e.g. v4i32 <0,z,1,u> -> (v2i64 zero_extend_vector_inreg(v4i32 src))
static SDValue combineShuffleToZeroExtendVectorInReg(ShuffleVectorSDNode *SVN,
                                                     SelectionDAG &DAG,
                                                     const TargetLowering &TLI,
                                                     bool LegalOperations) {
23321 bool LegalTypes = true;
23322 EVT VT = SVN->getValueType(0);
23323 assert(!VT.isScalableVector() && "Encountered scalable shuffle?");
23324 unsigned NumElts = VT.getVectorNumElements();
23325 unsigned EltSizeInBits = VT.getScalarSizeInBits();
23326
23327 // TODO: add support for big-endian when we have a test case.
23328 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
23329 if (!VT.isInteger() || IsBigEndian)
23330 return SDValue();
23331
23332 SmallVector<int, 16> Mask(SVN->getMask().begin(), SVN->getMask().end());
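  // Visit each defined mask element together with the operand it selects from
  // and the element index within that operand.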
23333 auto ForEachDecomposedIndice = [NumElts, &Mask](auto Fn) {
23334 for (int &Indice : Mask) {
23335 if (Indice < 0)
23336 continue;
23337 int OpIdx = (unsigned)Indice < NumElts ? 0 : 1;
23338 int OpEltIdx = (unsigned)Indice < NumElts ? Indice : Indice - NumElts;
23339 Fn(Indice, OpIdx, OpEltIdx);
23340 }
23341 };
23342
23343 // Which elements of which operand does this shuffle demand?
23344 std::array<APInt, 2> OpsDemandedElts;
23345 for (APInt &OpDemandedElts : OpsDemandedElts)
23346 OpDemandedElts = APInt::getZero(NumElts);
23347 ForEachDecomposedIndice(
23348 [&OpsDemandedElts](int &Indice, int OpIdx, int OpEltIdx) {
23349 OpsDemandedElts[OpIdx].setBit(OpEltIdx);
23350 });
23351
  // Element-wise(!), which of these demanded elements are known to be zero?
23353 std::array<APInt, 2> OpsKnownZeroElts;
23354 for (auto I : zip(SVN->ops(), OpsDemandedElts, OpsKnownZeroElts))
23355 std::get<2>(I) =
23356 DAG.computeVectorKnownZeroElements(std::get<0>(I), std::get<1>(I));
23357
23358 // Manifest zeroable element knowledge in the shuffle mask.
  // NOTE: we don't have a 'zeroable' sentinel value in the generic DAG,
  // this is a local invention, but it won't leak into the DAG.
23361 // FIXME: should we not manifest them, but just check when matching?
23362 bool HadZeroableElts = false;
23363 ForEachDecomposedIndice([&OpsKnownZeroElts, &HadZeroableElts](
23364 int &Indice, int OpIdx, int OpEltIdx) {
23365 if (OpsKnownZeroElts[OpIdx][OpEltIdx]) {
23366 Indice = -2; // Zeroable element.
23367 HadZeroableElts = true;
23368 }
23369 });
23370
  // Don't proceed unless we've refined at least one zeroable mask index.
23372 // If we didn't, then we are still trying to match the same shuffle mask
23373 // we previously tried to match as ISD::ANY_EXTEND_VECTOR_INREG,
23374 // and evidently failed. Proceeding will lead to endless combine loops.
23375 if (!HadZeroableElts)
23376 return SDValue();
23377
23378 // The shuffle may be more fine-grained than we want. Widen elements first.
23379 // FIXME: should we do this before manifesting zeroable shuffle mask indices?
23380 SmallVector<int, 16> ScaledMask;
23381 getShuffleMaskWithWidestElts(Mask, ScaledMask);
23382 assert(Mask.size() >= ScaledMask.size() &&
23383 Mask.size() % ScaledMask.size() == 0 && "Unexpected mask widening.");
23384 int Prescale = Mask.size() / ScaledMask.size();
23385
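  // From here on, operate on the widened mask and the correspondingly wider
  // element type.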
23386 NumElts = ScaledMask.size();
23387 EltSizeInBits *= Prescale;
23388
23389 EVT PrescaledVT = EVT::getVectorVT(
23390 *DAG.getContext(), EVT::getIntegerVT(*DAG.getContext(), EltSizeInBits),
23391 NumElts);
23392
23393 if (LegalTypes && !TLI.isTypeLegal(PrescaledVT) && TLI.isTypeLegal(VT))
23394 return SDValue();
23395
23396 // For example,
23397 // shuffle<0,z,1,-1> == (v2i64 zero_extend_vector_inreg(v4i32))
23398 // But not shuffle<z,z,1,-1> and not shuffle<0,z,z,-1> ! (for same types)
23399 auto isZeroExtend = [NumElts, &ScaledMask](unsigned Scale) {
23400 assert(Scale >= 2 && Scale <= NumElts && NumElts % Scale == 0 &&
23401 "Unexpected mask scaling factor.");
23402 ArrayRef<int> Mask = ScaledMask;
23403 for (unsigned SrcElt = 0, NumSrcElts = NumElts / Scale;
23404 SrcElt != NumSrcElts; ++SrcElt) {
23405 // Analyze the shuffle mask in Scale-sized chunks.
23406 ArrayRef<int> MaskChunk = Mask.take_front(Scale);
23407 assert(MaskChunk.size() == Scale && "Unexpected mask size.");
23408 Mask = Mask.drop_front(MaskChunk.size());
      // The first index in this chunk must be SrcElt, but not zero!
      // FIXME: undef should be fine, but that results in a more-defined result.
      if (int FirstIndice = MaskChunk[0]; (unsigned)FirstIndice != SrcElt)
        return false;
      // The rest of the indices in this chunk must be zeros.
      // FIXME: undef should be fine, but that results in a more-defined result.
23415 if (!all_of(MaskChunk.drop_front(1),
23416 [](int Indice) { return Indice == -2; }))
23417 return false;
23418 }
23419 assert(Mask.empty() && "Did not process the whole mask?");
23420 return true;
23421 };
23422
23423 unsigned Opcode = ISD::ZERO_EXTEND_VECTOR_INREG;
23424 for (bool Commuted : {false, true}) {
23425 SDValue Op = SVN->getOperand(!Commuted ? 0 : 1);
23426 if (Commuted)
23427 ShuffleVectorSDNode::commuteMask(ScaledMask);
23428 std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
23429 Opcode, PrescaledVT, isZeroExtend, DAG, TLI, LegalTypes,
23430 LegalOperations);
23431 if (OutVT)
23432 return DAG.getBitcast(VT, DAG.getNode(Opcode, SDLoc(SVN), *OutVT,
23433 DAG.getBitcast(PrescaledVT, Op)));
23434 }
23435 return SDValue();
23436 }
23437
23438 // Detect 'truncate_vector_inreg' style shuffles that pack the lower parts of
23439 // each source element of a large type into the lowest elements of a smaller
23440 // destination type. This is often generated during legalization.
23441 // If the source node itself was a '*_extend_vector_inreg' node then we should
23442 // then be able to remove it.
static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN,
                                        SelectionDAG &DAG) {
23445 EVT VT = SVN->getValueType(0);
23446 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
23447
23448 // TODO Add support for big-endian when we have a test case.
23449 if (!VT.isInteger() || IsBigEndian)
23450 return SDValue();
23451
23452 SDValue N0 = peekThroughBitcasts(SVN->getOperand(0));
23453
23454 unsigned Opcode = N0.getOpcode();
23455 if (Opcode != ISD::ANY_EXTEND_VECTOR_INREG &&
23456 Opcode != ISD::SIGN_EXTEND_VECTOR_INREG &&
23457 Opcode != ISD::ZERO_EXTEND_VECTOR_INREG)
23458 return SDValue();
23459
23460 SDValue N00 = N0.getOperand(0);
23461 ArrayRef<int> Mask = SVN->getMask();
23462 unsigned NumElts = VT.getVectorNumElements();
23463 unsigned EltSizeInBits = VT.getScalarSizeInBits();
23464 unsigned ExtSrcSizeInBits = N00.getScalarValueSizeInBits();
23465 unsigned ExtDstSizeInBits = N0.getScalarValueSizeInBits();
23466
23467 if (ExtDstSizeInBits % ExtSrcSizeInBits != 0)
23468 return SDValue();
23469 unsigned ExtScale = ExtDstSizeInBits / ExtSrcSizeInBits;
23470
  // (v4i32 truncate_vector_inreg(v2i64)) == shuffle<0,2,-1,-1>
23472 // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1>
23473 // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1>
23474 auto isTruncate = [&Mask, &NumElts](unsigned Scale) {
23475 for (unsigned i = 0; i != NumElts; ++i) {
23476 if (Mask[i] < 0)
23477 continue;
23478 if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
23479 continue;
23480 return false;
23481 }
23482 return true;
23483 };
23484
23485 // At the moment we just handle the case where we've truncated back to the
23486 // same size as before the extension.
23487 // TODO: handle more extension/truncation cases as cases arise.
23488 if (EltSizeInBits != ExtSrcSizeInBits)
23489 return SDValue();
23490
23491 // We can remove *extend_vector_inreg only if the truncation happens at
23492 // the same scale as the extension.
23493 if (isTruncate(ExtScale))
23494 return DAG.getBitcast(VT, N00);
23495
23496 return SDValue();
23497 }
23498
23499 // Combine shuffles of splat-shuffles of the form:
23500 // shuffle (shuffle V, undef, splat-mask), undef, M
23501 // If splat-mask contains undef elements, we need to be careful about
23502 // introducing undef's in the folded mask which are not the result of composing
23503 // the masks of the shuffles.
static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf,
                                        SelectionDAG &DAG) {
23506 EVT VT = Shuf->getValueType(0);
23507 unsigned NumElts = VT.getVectorNumElements();
23508
23509 if (!Shuf->getOperand(1).isUndef())
23510 return SDValue();
23511
23512 // See if this unary non-splat shuffle actually *is* a splat shuffle,
23513 // in disguise, with all demanded elements being identical.
23514 // FIXME: this can be done per-operand.
23515 if (!Shuf->isSplat()) {
23516 APInt DemandedElts(NumElts, 0);
23517 for (int Idx : Shuf->getMask()) {
23518 if (Idx < 0)
23519 continue; // Ignore sentinel indices.
23520 assert((unsigned)Idx < NumElts && "Out-of-bounds shuffle indice?");
23521 DemandedElts.setBit(Idx);
23522 }
23523 assert(DemandedElts.countPopulation() > 1 && "Is a splat shuffle already?");
23524 APInt UndefElts;
23525 if (DAG.isSplatValue(Shuf->getOperand(0), DemandedElts, UndefElts)) {
23526 // Even if all demanded elements are splat, some of them could be undef.
23527 // Which lowest demanded element is *not* known-undef?
23528 std::optional<unsigned> MinNonUndefIdx;
23529 for (int Idx : Shuf->getMask()) {
23530 if (Idx < 0 || UndefElts[Idx])
23531 continue; // Ignore sentinel indices, and undef elements.
23532 MinNonUndefIdx = std::min<unsigned>(Idx, MinNonUndefIdx.value_or(~0U));
23533 }
23534 if (!MinNonUndefIdx)
23535 return DAG.getUNDEF(VT); // All undef - result is undef.
23536 assert(*MinNonUndefIdx < NumElts && "Expected valid element index.");
23537 SmallVector<int, 8> SplatMask(Shuf->getMask().begin(),
23538 Shuf->getMask().end());
23539 for (int &Idx : SplatMask) {
23540 if (Idx < 0)
23541 continue; // Passthrough sentinel indices.
23542 // Otherwise, just pick the lowest demanded non-undef element.
23543 // Or sentinel undef, if we know we'd pick a known-undef element.
23544 Idx = UndefElts[Idx] ? -1 : *MinNonUndefIdx;
23545 }
23546 assert(SplatMask != Shuf->getMask() && "Expected mask to change!");
23547 return DAG.getVectorShuffle(VT, SDLoc(Shuf), Shuf->getOperand(0),
23548 Shuf->getOperand(1), SplatMask);
23549 }
23550 }
23551
  // If the inner operand is a known splat with no undefs, just return that
  // directly.
23553 // TODO: Create DemandedElts mask from Shuf's mask.
23554 // TODO: Allow undef elements and merge with the shuffle code below.
23555 if (DAG.isSplatValue(Shuf->getOperand(0), /*AllowUndefs*/ false))
23556 return Shuf->getOperand(0);
23557
23558 auto *Splat = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
23559 if (!Splat || !Splat->isSplat())
23560 return SDValue();
23561
23562 ArrayRef<int> ShufMask = Shuf->getMask();
23563 ArrayRef<int> SplatMask = Splat->getMask();
23564 assert(ShufMask.size() == SplatMask.size() && "Mask length mismatch");
23565
23566 // Prefer simplifying to the splat-shuffle, if possible. This is legal if
23567 // every undef mask element in the splat-shuffle has a corresponding undef
23568 // element in the user-shuffle's mask or if the composition of mask elements
23569 // would result in undef.
23570 // Examples for (shuffle (shuffle v, undef, SplatMask), undef, UserMask):
23571 // * UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]
23572 // In this case it is not legal to simplify to the splat-shuffle because we
  // may be exposing to the users of the shuffle an undef element at index 1
23574 // which was not there before the combine.
23575 // * UserMask=[0,u,2,u], SplatMask=[2,u,2,u] -> [2,u,2,u]
23576 // In this case the composition of masks yields SplatMask, so it's ok to
23577 // simplify to the splat-shuffle.
23578 // * UserMask=[3,u,2,u], SplatMask=[2,u,2,u] -> [u,u,2,u]
23579 // In this case the composed mask includes all undef elements of SplatMask
23580 // and in addition sets element zero to undef. It is safe to simplify to
23581 // the splat-shuffle.
23582 auto CanSimplifyToExistingSplat = [](ArrayRef<int> UserMask,
23583 ArrayRef<int> SplatMask) {
23584 for (unsigned i = 0, e = UserMask.size(); i != e; ++i)
23585 if (UserMask[i] != -1 && SplatMask[i] == -1 &&
23586 SplatMask[UserMask[i]] != -1)
23587 return false;
23588 return true;
23589 };
23590 if (CanSimplifyToExistingSplat(ShufMask, SplatMask))
23591 return Shuf->getOperand(0);
23592
23593 // Create a new shuffle with a mask that is composed of the two shuffles'
23594 // masks.
23595 SmallVector<int, 32> NewMask;
23596 for (int Idx : ShufMask)
23597 NewMask.push_back(Idx == -1 ? -1 : SplatMask[Idx]);
23598
23599 return DAG.getVectorShuffle(Splat->getValueType(0), SDLoc(Splat),
23600 Splat->getOperand(0), Splat->getOperand(1),
23601 NewMask);
23602 }
23603
23604 // Combine shuffles of bitcasts into a shuffle of the bitcast type, providing
23605 // the mask can be treated as a larger type.
static SDValue combineShuffleOfBitcast(ShuffleVectorSDNode *SVN,
                                       SelectionDAG &DAG,
                                       const TargetLowering &TLI,
                                       bool LegalOperations) {
23610 SDValue Op0 = SVN->getOperand(0);
23611 SDValue Op1 = SVN->getOperand(1);
23612 EVT VT = SVN->getValueType(0);
23613 if (Op0.getOpcode() != ISD::BITCAST)
23614 return SDValue();
23615 EVT InVT = Op0.getOperand(0).getValueType();
23616 if (!InVT.isVector() ||
23617 (!Op1.isUndef() && (Op1.getOpcode() != ISD::BITCAST ||
23618 Op1.getOperand(0).getValueType() != InVT)))
23619 return SDValue();
23620 if (isAnyConstantBuildVector(Op0.getOperand(0)) &&
23621 (Op1.isUndef() || isAnyConstantBuildVector(Op1.getOperand(0))))
23622 return SDValue();
23623
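  // The shuffle type must have more (narrower) lanes than the bitcast source;
  // Factor is the number of narrow lanes that make up one wide source lane.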
23624 int VTLanes = VT.getVectorNumElements();
23625 int InLanes = InVT.getVectorNumElements();
23626 if (VTLanes <= InLanes || VTLanes % InLanes != 0 ||
23627 (LegalOperations &&
23628 !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, InVT)))
23629 return SDValue();
23630 int Factor = VTLanes / InLanes;
23631
23632 // Check that each group of lanes in the mask are either undef or make a valid
23633 // mask for the wider lane type.
23634 ArrayRef<int> Mask = SVN->getMask();
23635 SmallVector<int> NewMask;
23636 if (!widenShuffleMaskElts(Factor, Mask, NewMask))
23637 return SDValue();
23638
23639 if (!TLI.isShuffleMaskLegal(NewMask, InVT))
23640 return SDValue();
23641
23642 // Create the new shuffle with the new mask and bitcast it back to the
23643 // original type.
23644 SDLoc DL(SVN);
23645 Op0 = Op0.getOperand(0);
23646 Op1 = Op1.isUndef() ? DAG.getUNDEF(InVT) : Op1.getOperand(0);
23647 SDValue NewShuf = DAG.getVectorShuffle(InVT, DL, Op0, Op1, NewMask);
23648 return DAG.getBitcast(VT, NewShuf);
23649 }
23650
23651 /// Combine shuffle of shuffle of the form:
23652 /// shuf (shuf X, undef, InnerMask), undef, OuterMask --> splat X
static SDValue formSplatFromShuffles(ShuffleVectorSDNode *OuterShuf,
                                     SelectionDAG &DAG) {
23655 if (!OuterShuf->getOperand(1).isUndef())
23656 return SDValue();
23657 auto *InnerShuf = dyn_cast<ShuffleVectorSDNode>(OuterShuf->getOperand(0));
23658 if (!InnerShuf || !InnerShuf->getOperand(1).isUndef())
23659 return SDValue();
23660
23661 ArrayRef<int> OuterMask = OuterShuf->getMask();
23662 ArrayRef<int> InnerMask = InnerShuf->getMask();
23663 unsigned NumElts = OuterMask.size();
23664 assert(NumElts == InnerMask.size() && "Mask length mismatch");
23665 SmallVector<int, 32> CombinedMask(NumElts, -1);
23666 int SplatIndex = -1;
23667 for (unsigned i = 0; i != NumElts; ++i) {
23668 // Undef lanes remain undef.
23669 int OuterMaskElt = OuterMask[i];
23670 if (OuterMaskElt == -1)
23671 continue;
23672
23673 // Peek through the shuffle masks to get the underlying source element.
23674 int InnerMaskElt = InnerMask[OuterMaskElt];
23675 if (InnerMaskElt == -1)
23676 continue;
23677
23678 // Initialize the splatted element.
23679 if (SplatIndex == -1)
23680 SplatIndex = InnerMaskElt;
23681
23682 // Non-matching index - this is not a splat.
23683 if (SplatIndex != InnerMaskElt)
23684 return SDValue();
23685
23686 CombinedMask[i] = InnerMaskElt;
23687 }
23688 assert((all_of(CombinedMask, [](int M) { return M == -1; }) ||
23689 getSplatIndex(CombinedMask) != -1) &&
23690 "Expected a splat mask");
23691
23692 // TODO: The transform may be a win even if the mask is not legal.
23693 EVT VT = OuterShuf->getValueType(0);
23694 assert(VT == InnerShuf->getValueType(0) && "Expected matching shuffle types");
23695 if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(CombinedMask, VT))
23696 return SDValue();
23697
23698 return DAG.getVectorShuffle(VT, SDLoc(OuterShuf), InnerShuf->getOperand(0),
23699 InnerShuf->getOperand(1), CombinedMask);
23700 }
23701
23702 /// If the shuffle mask is taking exactly one element from the first vector
23703 /// operand and passing through all other elements from the second vector
23704 /// operand, return the index of the mask element that is choosing an element
23705 /// from the first operand. Otherwise, return -1.
static int getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask) {
23707 int MaskSize = Mask.size();
23708 int EltFromOp0 = -1;
23709 // TODO: This does not match if there are undef elements in the shuffle mask.
23710 // Should we ignore undefs in the shuffle mask instead? The trade-off is
23711 // removing an instruction (a shuffle), but losing the knowledge that some
23712 // vector lanes are not needed.
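  // For example, with 4 elements, Mask = <4,5,2,7> keeps lanes 0, 1 and 3 of
  // operand 1 and takes lane 2 from operand 0, so we return 2.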
23713 for (int i = 0; i != MaskSize; ++i) {
23714 if (Mask[i] >= 0 && Mask[i] < MaskSize) {
23715 // We're looking for a shuffle of exactly one element from operand 0.
23716 if (EltFromOp0 != -1)
23717 return -1;
23718 EltFromOp0 = i;
23719 } else if (Mask[i] != i + MaskSize) {
23720 // Nothing from operand 1 can change lanes.
23721 return -1;
23722 }
23723 }
23724 return EltFromOp0;
23725 }
23726
23727 /// If a shuffle inserts exactly one element from a source vector operand into
23728 /// another vector operand and we can access the specified element as a scalar,
23729 /// then we can eliminate the shuffle.
static SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf,
                                      SelectionDAG &DAG) {
23732 // First, check if we are taking one element of a vector and shuffling that
23733 // element into another vector.
23734 ArrayRef<int> Mask = Shuf->getMask();
23735 SmallVector<int, 16> CommutedMask(Mask);
23736 SDValue Op0 = Shuf->getOperand(0);
23737 SDValue Op1 = Shuf->getOperand(1);
23738 int ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask);
23739 if (ShufOp0Index == -1) {
23740 // Commute mask and check again.
23741 ShuffleVectorSDNode::commuteMask(CommutedMask);
23742 ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(CommutedMask);
23743 if (ShufOp0Index == -1)
23744 return SDValue();
23745 // Commute operands to match the commuted shuffle mask.
23746 std::swap(Op0, Op1);
23747 Mask = CommutedMask;
23748 }
23749
23750 // The shuffle inserts exactly one element from operand 0 into operand 1.
23751 // Now see if we can access that element as a scalar via a real insert element
23752 // instruction.
23753 // TODO: We can try harder to locate the element as a scalar. Examples: it
23754 // could be an operand of SCALAR_TO_VECTOR, BUILD_VECTOR, or a constant.
23755 assert(Mask[ShufOp0Index] >= 0 && Mask[ShufOp0Index] < (int)Mask.size() &&
23756 "Shuffle mask value must be from operand 0");
23757 if (Op0.getOpcode() != ISD::INSERT_VECTOR_ELT)
23758 return SDValue();
23759
23760 auto *InsIndexC = dyn_cast<ConstantSDNode>(Op0.getOperand(2));
23761 if (!InsIndexC || InsIndexC->getSExtValue() != Mask[ShufOp0Index])
23762 return SDValue();
23763
23764 // There's an existing insertelement with constant insertion index, so we
23765 // don't need to check the legality/profitability of a replacement operation
23766 // that differs at most in the constant value. The target should be able to
23767 // lower any of those in a similar way. If not, legalization will expand this
23768 // to a scalar-to-vector plus shuffle.
23769 //
23770 // Note that the shuffle may move the scalar from the position that the insert
23771 // element used. Therefore, our new insert element occurs at the shuffle's
23772 // mask index value, not the insert's index value.
23773 // shuffle (insertelt v1, x, C), v2, mask --> insertelt v2, x, C'
23774 SDValue NewInsIndex = DAG.getVectorIdxConstant(ShufOp0Index, SDLoc(Shuf));
23775 return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Shuf), Op0.getValueType(),
23776 Op1, Op0.getOperand(1), NewInsIndex);
23777 }
23778
23779 /// If we have a unary shuffle of a shuffle, see if it can be folded away
23780 /// completely. This has the potential to lose undef knowledge because the first
23781 /// shuffle may not have an undef mask element where the second one does. So
23782 /// only call this after doing simplifications based on demanded elements.
static SDValue simplifyShuffleOfShuffle(ShuffleVectorSDNode *Shuf) {
23784 // shuf (shuf0 X, Y, Mask0), undef, Mask
23785 auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Shuf->getOperand(0));
23786 if (!Shuf0 || !Shuf->getOperand(1).isUndef())
23787 return SDValue();
23788
23789 ArrayRef<int> Mask = Shuf->getMask();
23790 ArrayRef<int> Mask0 = Shuf0->getMask();
23791 for (int i = 0, e = (int)Mask.size(); i != e; ++i) {
23792 // Ignore undef elements.
23793 if (Mask[i] == -1)
23794 continue;
23795 assert(Mask[i] >= 0 && Mask[i] < e && "Unexpected shuffle mask value");
23796
23797 // Is the element of the shuffle operand chosen by this shuffle the same as
23798 // the element chosen by the shuffle operand itself?
23799 if (Mask0[Mask[i]] != Mask0[i])
23800 return SDValue();
23801 }
23802 // Every element of this shuffle is identical to the result of the previous
23803 // shuffle, so we can replace this value.
23804 return Shuf->getOperand(0);
23805 }
23806
23807 SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
23808 EVT VT = N->getValueType(0);
23809 unsigned NumElts = VT.getVectorNumElements();
23810
23811 SDValue N0 = N->getOperand(0);
23812 SDValue N1 = N->getOperand(1);
23813
23814 assert(N0.getValueType() == VT && "Vector shuffle must be normalized in DAG");
23815
23816 // Canonicalize shuffle undef, undef -> undef
23817 if (N0.isUndef() && N1.isUndef())
23818 return DAG.getUNDEF(VT);
23819
23820 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N);
23821
23822 // Canonicalize shuffle v, v -> v, undef
23823 if (N0 == N1)
23824 return DAG.getVectorShuffle(VT, SDLoc(N), N0, DAG.getUNDEF(VT),
23825 createUnaryMask(SVN->getMask(), NumElts));
23826
23827 // Canonicalize shuffle undef, v -> v, undef. Commute the shuffle mask.
23828 if (N0.isUndef())
23829 return DAG.getCommutedVectorShuffle(*SVN);
23830
23831 // Remove references to rhs if it is undef
23832 if (N1.isUndef()) {
23833 bool Changed = false;
23834 SmallVector<int, 8> NewMask;
23835 for (unsigned i = 0; i != NumElts; ++i) {
23836 int Idx = SVN->getMaskElt(i);
23837 if (Idx >= (int)NumElts) {
23838 Idx = -1;
23839 Changed = true;
23840 }
23841 NewMask.push_back(Idx);
23842 }
23843 if (Changed)
23844 return DAG.getVectorShuffle(VT, SDLoc(N), N0, N1, NewMask);
23845 }
23846
23847 if (SDValue InsElt = replaceShuffleOfInsert(SVN, DAG))
23848 return InsElt;
23849
23850 // A shuffle of a single vector that is a splatted value can always be folded.
23851 if (SDValue V = combineShuffleOfSplatVal(SVN, DAG))
23852 return V;
23853
23854 if (SDValue V = formSplatFromShuffles(SVN, DAG))
23855 return V;
23856
23857 // If it is a splat, check if the argument vector is another splat or a
23858 // build_vector.
23859 if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {
23860 int SplatIndex = SVN->getSplatIndex();
23861 if (N0.hasOneUse() && TLI.isExtractVecEltCheap(VT, SplatIndex) &&
23862 TLI.isBinOp(N0.getOpcode()) && N0->getNumValues() == 1) {
23863 // splat (vector_bo L, R), Index -->
23864 // splat (scalar_bo (extelt L, Index), (extelt R, Index))
23865 SDValue L = N0.getOperand(0), R = N0.getOperand(1);
23866 SDLoc DL(N);
23867 EVT EltVT = VT.getScalarType();
23868 SDValue Index = DAG.getVectorIdxConstant(SplatIndex, DL);
23869 SDValue ExtL = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, L, Index);
23870 SDValue ExtR = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, R, Index);
23871 SDValue NewBO =
23872 DAG.getNode(N0.getOpcode(), DL, EltVT, ExtL, ExtR, N0->getFlags());
23873 SDValue Insert = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, NewBO);
23874 SmallVector<int, 16> ZeroMask(VT.getVectorNumElements(), 0);
23875 return DAG.getVectorShuffle(VT, DL, Insert, DAG.getUNDEF(VT), ZeroMask);
23876 }
23877
23878 // splat(scalar_to_vector(x), 0) -> build_vector(x,...,x)
23879 // splat(insert_vector_elt(v, x, c), c) -> build_vector(x,...,x)
23880 if ((!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) &&
23881 N0.hasOneUse()) {
23882 if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && SplatIndex == 0)
23883 return DAG.getSplatBuildVector(VT, SDLoc(N), N0.getOperand(0));
23884
23885 if (N0.getOpcode() == ISD::INSERT_VECTOR_ELT)
23886 if (auto *Idx = dyn_cast<ConstantSDNode>(N0.getOperand(2)))
23887 if (Idx->getAPIntValue() == SplatIndex)
23888 return DAG.getSplatBuildVector(VT, SDLoc(N), N0.getOperand(1));
23889
23890 // Look through a bitcast if little-endian and splatting lane 0, through to a
23891 // scalar_to_vector or a build_vector.
23892 if (N0.getOpcode() == ISD::BITCAST && N0.getOperand(0).hasOneUse() &&
23893 SplatIndex == 0 && DAG.getDataLayout().isLittleEndian() &&
23894 (N0.getOperand(0).getOpcode() == ISD::SCALAR_TO_VECTOR ||
23895 N0.getOperand(0).getOpcode() == ISD::BUILD_VECTOR)) {
23896 EVT N00VT = N0.getOperand(0).getValueType();
23897 if (VT.getScalarSizeInBits() <= N00VT.getScalarSizeInBits() &&
23898 VT.isInteger() && N00VT.isInteger()) {
23899 EVT InVT =
23900 TLI.getTypeToTransformTo(*DAG.getContext(), VT.getScalarType());
23901 SDValue Op = DAG.getZExtOrTrunc(N0.getOperand(0).getOperand(0),
23902 SDLoc(N), InVT);
23903 return DAG.getSplatBuildVector(VT, SDLoc(N), Op);
23904 }
23905 }
23906 }
23907
23908 // If this is a bit convert that changes the element type of the vector but
23909 // not the number of vector elements, look through it. Be careful not to
23910 // look through conversions that change things like v4f32 to v2f64.
23911 SDNode *V = N0.getNode();
23912 if (V->getOpcode() == ISD::BITCAST) {
23913 SDValue ConvInput = V->getOperand(0);
23914 if (ConvInput.getValueType().isVector() &&
23915 ConvInput.getValueType().getVectorNumElements() == NumElts)
23916 V = ConvInput.getNode();
23917 }
23918
23919 if (V->getOpcode() == ISD::BUILD_VECTOR) {
23920 assert(V->getNumOperands() == NumElts &&
23921 "BUILD_VECTOR has wrong number of operands");
23922 SDValue Base;
23923 bool AllSame = true;
23924 for (unsigned i = 0; i != NumElts; ++i) {
23925 if (!V->getOperand(i).isUndef()) {
23926 Base = V->getOperand(i);
23927 break;
23928 }
23929 }
23930 // Splat of <u, u, u, u>, return <u, u, u, u>
23931 if (!Base.getNode())
23932 return N0;
23933 for (unsigned i = 0; i != NumElts; ++i) {
23934 if (V->getOperand(i) != Base) {
23935 AllSame = false;
23936 break;
23937 }
23938 }
23939 // Splat of <x, x, x, x>, return <x, x, x, x>
23940 if (AllSame)
23941 return N0;
23942
23943 // Canonicalize any other splat as a build_vector.
23944 SDValue Splatted = V->getOperand(SplatIndex);
23945 SmallVector<SDValue, 8> Ops(NumElts, Splatted);
23946 SDValue NewBV = DAG.getBuildVector(V->getValueType(0), SDLoc(N), Ops);
23947
23948 // We may have jumped through bitcasts, so the type of the
23949 // BUILD_VECTOR may not match the type of the shuffle.
23950 if (V->getValueType(0) != VT)
23951 NewBV = DAG.getBitcast(VT, NewBV);
23952 return NewBV;
23953 }
23954 }
23955
23956 // Simplify source operands based on shuffle mask.
23957 if (SimplifyDemandedVectorElts(SDValue(N, 0)))
23958 return SDValue(N, 0);
23959
23960 // This is intentionally placed after demanded elements simplification because
23961 // it could eliminate knowledge of undef elements created by this shuffle.
23962 if (SDValue ShufOp = simplifyShuffleOfShuffle(SVN))
23963 return ShufOp;
23964
23965 // Match shuffles that can be converted to any_vector_extend_in_reg.
23966 if (SDValue V =
23967 combineShuffleToAnyExtendVectorInreg(SVN, DAG, TLI, LegalOperations))
23968 return V;
23969
23970 // Combine "truncate_vector_in_reg" style shuffles.
23971 if (SDValue V = combineTruncationShuffle(SVN, DAG))
23972 return V;
23973
23974 if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
23975 Level < AfterLegalizeVectorOps &&
23976 (N1.isUndef() ||
23977 (N1.getOpcode() == ISD::CONCAT_VECTORS &&
23978 N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType()))) {
23979 if (SDValue V = partitionShuffleOfConcats(N, DAG))
23980 return V;
23981 }
23982
23983 // A shuffle of a concat of the same narrow vector can be reduced to use
23984 // only low-half elements of a concat with undef:
23985 // shuf (concat X, X), undef, Mask --> shuf (concat X, undef), undef, Mask'
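// Illustrative example (X is v2i32, so NumElts == 4):
//   shuf (concat X, X), undef, <2,3,0,1> --> shuf (concat X, undef), undef, <0,1,0,1>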
23986 if (N0.getOpcode() == ISD::CONCAT_VECTORS && N1.isUndef() &&
23987 N0.getNumOperands() == 2 &&
23988 N0.getOperand(0) == N0.getOperand(1)) {
23989 int HalfNumElts = (int)NumElts / 2;
23990 SmallVector<int, 8> NewMask;
23991 for (unsigned i = 0; i != NumElts; ++i) {
23992 int Idx = SVN->getMaskElt(i);
23993 if (Idx >= HalfNumElts) {
23994 assert(Idx < (int)NumElts && "Shuffle mask chooses undef op");
23995 Idx -= HalfNumElts;
23996 }
23997 NewMask.push_back(Idx);
23998 }
23999 if (TLI.isShuffleMaskLegal(NewMask, VT)) {
24000 SDValue UndefVec = DAG.getUNDEF(N0.getOperand(0).getValueType());
24001 SDValue NewCat = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT,
24002 N0.getOperand(0), UndefVec);
24003 return DAG.getVectorShuffle(VT, SDLoc(N), NewCat, N1, NewMask);
24004 }
24005 }
24006
24007 // See if we can replace a shuffle with an insert_subvector.
24008 // e.g. v2i32 into v8i32:
24009 // shuffle(lhs,concat(rhs0,rhs1,rhs2,rhs3),0,1,2,3,10,11,6,7).
24010 // --> insert_subvector(lhs,rhs1,4).
24011 if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT) &&
24012 TLI.isOperationLegalOrCustom(ISD::INSERT_SUBVECTOR, VT)) {
24013 auto ShuffleToInsert = [&](SDValue LHS, SDValue RHS, ArrayRef<int> Mask) {
24014 // Ensure RHS subvectors are legal.
24015 assert(RHS.getOpcode() == ISD::CONCAT_VECTORS && "Can't find subvectors");
24016 EVT SubVT = RHS.getOperand(0).getValueType();
24017 int NumSubVecs = RHS.getNumOperands();
24018 int NumSubElts = SubVT.getVectorNumElements();
24019 assert((NumElts % NumSubElts) == 0 && "Subvector mismatch");
24020 if (!TLI.isTypeLegal(SubVT))
24021 return SDValue();
24022
24023 // Don't bother if we have a unary shuffle (matches undef + LHS elts).
24024 if (all_of(Mask, [NumElts](int M) { return M < (int)NumElts; }))
24025 return SDValue();
24026
24027 // Search [NumSubElts] spans for RHS sequence.
24028 // TODO: Can we avoid nested loops to increase performance?
24029 SmallVector<int> InsertionMask(NumElts);
24030 for (int SubVec = 0; SubVec != NumSubVecs; ++SubVec) {
24031 for (int SubIdx = 0; SubIdx != (int)NumElts; SubIdx += NumSubElts) {
24032 // Reset mask to identity.
24033 std::iota(InsertionMask.begin(), InsertionMask.end(), 0);
24034
24035 // Add subvector insertion.
24036 std::iota(InsertionMask.begin() + SubIdx,
24037 InsertionMask.begin() + SubIdx + NumSubElts,
24038 NumElts + (SubVec * NumSubElts));
24039
24040 // See if the shuffle mask matches the reference insertion mask.
24041 bool MatchingShuffle = true;
24042 for (int i = 0; i != (int)NumElts; ++i) {
24043 int ExpectIdx = InsertionMask[i];
24044 int ActualIdx = Mask[i];
24045 if (0 <= ActualIdx && ExpectIdx != ActualIdx) {
24046 MatchingShuffle = false;
24047 break;
24048 }
24049 }
24050
24051 if (MatchingShuffle)
24052 return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, LHS,
24053 RHS.getOperand(SubVec),
24054 DAG.getVectorIdxConstant(SubIdx, SDLoc(N)));
24055 }
24056 }
24057 return SDValue();
24058 };
24059 ArrayRef<int> Mask = SVN->getMask();
24060 if (N1.getOpcode() == ISD::CONCAT_VECTORS)
24061 if (SDValue InsertN1 = ShuffleToInsert(N0, N1, Mask))
24062 return InsertN1;
24063 if (N0.getOpcode() == ISD::CONCAT_VECTORS) {
24064 SmallVector<int> CommuteMask(Mask);
24065 ShuffleVectorSDNode::commuteMask(CommuteMask);
24066 if (SDValue InsertN0 = ShuffleToInsert(N1, N0, CommuteMask))
24067 return InsertN0;
24068 }
24069 }
24070
24071 // If we're not performing a select/blend shuffle, see if we can convert the
24072 // shuffle into an AND node, with all the out-of-lane elements known zero.
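// Illustrative example (4-element vectors, lanes 2 and 3 of Y known zero):
//   shuffle X, Y, <0,6,2,7> --> and X, <-1,0,-1,0>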
24073 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
24074 bool IsInLaneMask = true;
24075 ArrayRef<int> Mask = SVN->getMask();
24076 SmallVector<int, 16> ClearMask(NumElts, -1);
24077 APInt DemandedLHS = APInt::getNullValue(NumElts);
24078 APInt DemandedRHS = APInt::getNullValue(NumElts);
24079 for (int I = 0; I != (int)NumElts; ++I) {
24080 int M = Mask[I];
24081 if (M < 0)
24082 continue;
24083 ClearMask[I] = M == I ? I : (I + NumElts);
24084 IsInLaneMask &= (M == I) || (M == (int)(I + NumElts));
24085 if (M != I) {
24086 APInt &Demanded = M < (int)NumElts ? DemandedLHS : DemandedRHS;
24087 Demanded.setBit(M % NumElts);
24088 }
24089 }
24090 // TODO: Should we try to mask with N1 as well?
24091 if (!IsInLaneMask &&
24092 (!DemandedLHS.isNullValue() || !DemandedRHS.isNullValue()) &&
24093 (DemandedLHS.isNullValue() ||
24094 DAG.MaskedVectorIsZero(N0, DemandedLHS)) &&
24095 (DemandedRHS.isNullValue() ||
24096 DAG.MaskedVectorIsZero(N1, DemandedRHS))) {
24097 SDLoc DL(N);
24098 EVT IntVT = VT.changeVectorElementTypeToInteger();
24099 EVT IntSVT = VT.getVectorElementType().changeTypeToInteger();
24100 // Transform the type to a legal type so that the buildvector constant
24101 // elements are not illegal. Make sure that the result is larger than the
24102 // original type, in case the value is split into two (e.g. i64->i32).
24103 if (!TLI.isTypeLegal(IntSVT) && LegalTypes)
24104 IntSVT = TLI.getTypeToTransformTo(*DAG.getContext(), IntSVT);
24105 if (IntSVT.getSizeInBits() >= IntVT.getScalarSizeInBits()) {
24106 SDValue ZeroElt = DAG.getConstant(0, DL, IntSVT);
24107 SDValue AllOnesElt = DAG.getAllOnesConstant(DL, IntSVT);
24108 SmallVector<SDValue, 16> AndMask(NumElts, DAG.getUNDEF(IntSVT));
24109 for (int I = 0; I != (int)NumElts; ++I)
24110 if (0 <= Mask[I])
24111 AndMask[I] = Mask[I] == I ? AllOnesElt : ZeroElt;
24112
24113 // See if a clear mask is legal instead of going via
24114 // XformToShuffleWithZero which loses UNDEF mask elements.
24115 if (TLI.isVectorClearMaskLegal(ClearMask, IntVT))
24116 return DAG.getBitcast(
24117 VT, DAG.getVectorShuffle(IntVT, DL, DAG.getBitcast(IntVT, N0),
24118 DAG.getConstant(0, DL, IntVT), ClearMask));
24119
24120 if (TLI.isOperationLegalOrCustom(ISD::AND, IntVT))
24121 return DAG.getBitcast(
24122 VT, DAG.getNode(ISD::AND, DL, IntVT, DAG.getBitcast(IntVT, N0),
24123 DAG.getBuildVector(IntVT, DL, AndMask)));
24124 }
24125 }
24126 }
24127
24128 // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
24129 // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
24130 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT))
24131 if (SDValue Res = combineShuffleOfScalars(SVN, DAG, TLI))
24132 return Res;
24133
24134 // If this shuffle only has a single input that is a bitcasted shuffle,
24135 // attempt to merge the 2 shuffles and suitably bitcast the inputs/output
24136 // back to their original types.
24137 if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
24138 N1.isUndef() && Level < AfterLegalizeVectorOps &&
24139 TLI.isTypeLegal(VT)) {
24140
24141 SDValue BC0 = peekThroughOneUseBitcasts(N0);
24142 if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) {
24143 EVT SVT = VT.getScalarType();
24144 EVT InnerVT = BC0->getValueType(0);
24145 EVT InnerSVT = InnerVT.getScalarType();
24146
24147 // Determine which shuffle works with the smaller scalar type.
24148 EVT ScaleVT = SVT.bitsLT(InnerSVT) ? VT : InnerVT;
24149 EVT ScaleSVT = ScaleVT.getScalarType();
24150
24151 if (TLI.isTypeLegal(ScaleVT) &&
24152 0 == (InnerSVT.getSizeInBits() % ScaleSVT.getSizeInBits()) &&
24153 0 == (SVT.getSizeInBits() % ScaleSVT.getSizeInBits())) {
24154 int InnerScale = InnerSVT.getSizeInBits() / ScaleSVT.getSizeInBits();
24155 int OuterScale = SVT.getSizeInBits() / ScaleSVT.getSizeInBits();
24156
24157 // Scale the shuffle masks to the smaller scalar type.
24158 ShuffleVectorSDNode *InnerSVN = cast<ShuffleVectorSDNode>(BC0);
24159 SmallVector<int, 8> InnerMask;
24160 SmallVector<int, 8> OuterMask;
24161 narrowShuffleMaskElts(InnerScale, InnerSVN->getMask(), InnerMask);
24162 narrowShuffleMaskElts(OuterScale, SVN->getMask(), OuterMask);
24163
24164 // Merge the shuffle masks.
24165 SmallVector<int, 8> NewMask;
24166 for (int M : OuterMask)
24167 NewMask.push_back(M < 0 ? -1 : InnerMask[M]);
24168
24169 // Test for shuffle mask legality over both commutations.
24170 SDValue SV0 = BC0->getOperand(0);
24171 SDValue SV1 = BC0->getOperand(1);
24172 bool LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
24173 if (!LegalMask) {
24174 std::swap(SV0, SV1);
24175 ShuffleVectorSDNode::commuteMask(NewMask);
24176 LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
24177 }
24178
24179 if (LegalMask) {
24180 SV0 = DAG.getBitcast(ScaleVT, SV0);
24181 SV1 = DAG.getBitcast(ScaleVT, SV1);
24182 return DAG.getBitcast(
24183 VT, DAG.getVectorShuffle(ScaleVT, SDLoc(N), SV0, SV1, NewMask));
24184 }
24185 }
24186 }
24187 }
24188
24189 // Match shuffles of bitcasts, so long as the mask can be treated as the
24190 // larger type.
24191 if (SDValue V = combineShuffleOfBitcast(SVN, DAG, TLI, LegalOperations))
24192 return V;
24193
24194 // Compute the combined shuffle mask for a shuffle with SV0 as the first
24195 // operand, and SV1 as the second operand.
24196 // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask) iff Commute = false
24197 // Merge SVN(N1, OtherSVN) -> shuffle(SV0, SV1, Mask') iff Commute = true
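// Illustrative example (assumes 4-element vectors; the outer RHS reuses B):
//   shuffle (shuffle A, B, <0,4,1,5>), B, <0,1,4,5> --> shuffle A, B, <0,4,4,5>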
24198 auto MergeInnerShuffle =
24199 [NumElts, &VT](bool Commute, ShuffleVectorSDNode *SVN,
24200 ShuffleVectorSDNode *OtherSVN, SDValue N1,
24201 const TargetLowering &TLI, SDValue &SV0, SDValue &SV1,
24202 SmallVectorImpl<int> &Mask) -> bool {
24203 // Don't try to fold splats; they're likely to simplify somehow, or they
24204 // might be free.
24205 if (OtherSVN->isSplat())
24206 return false;
24207
24208 SV0 = SV1 = SDValue();
24209 Mask.clear();
24210
24211 for (unsigned i = 0; i != NumElts; ++i) {
24212 int Idx = SVN->getMaskElt(i);
24213 if (Idx < 0) {
24214 // Propagate Undef.
24215 Mask.push_back(Idx);
24216 continue;
24217 }
24218
24219 if (Commute)
24220 Idx = (Idx < (int)NumElts) ? (Idx + NumElts) : (Idx - NumElts);
24221
24222 SDValue CurrentVec;
24223 if (Idx < (int)NumElts) {
24224 // This shuffle index refers to the inner shuffle N0. Lookup the inner
24225 // shuffle mask to identify which vector is actually referenced.
24226 Idx = OtherSVN->getMaskElt(Idx);
24227 if (Idx < 0) {
24228 // Propagate Undef.
24229 Mask.push_back(Idx);
24230 continue;
24231 }
24232 CurrentVec = (Idx < (int)NumElts) ? OtherSVN->getOperand(0)
24233 : OtherSVN->getOperand(1);
24234 } else {
24235 // This shuffle index references an element within N1.
24236 CurrentVec = N1;
24237 }
24238
24239 // Simple case where 'CurrentVec' is UNDEF.
24240 if (CurrentVec.isUndef()) {
24241 Mask.push_back(-1);
24242 continue;
24243 }
24244
24245 // Canonicalize the shuffle index. We don't know yet if CurrentVec
24246 // will be the first or second operand of the combined shuffle.
24247 Idx = Idx % NumElts;
24248 if (!SV0.getNode() || SV0 == CurrentVec) {
24249 // Ok. CurrentVec is the left hand side.
24250 // Update the mask accordingly.
24251 SV0 = CurrentVec;
24252 Mask.push_back(Idx);
24253 continue;
24254 }
24255 if (!SV1.getNode() || SV1 == CurrentVec) {
24256 // Ok. CurrentVec is the right hand side.
24257 // Update the mask accordingly.
24258 SV1 = CurrentVec;
24259 Mask.push_back(Idx + NumElts);
24260 continue;
24261 }
24262
24263 // Last chance - see if the vector is another shuffle and if it
24264 // uses one of the existing candidate shuffle ops.
24265 if (auto *CurrentSVN = dyn_cast<ShuffleVectorSDNode>(CurrentVec)) {
24266 int InnerIdx = CurrentSVN->getMaskElt(Idx);
24267 if (InnerIdx < 0) {
24268 Mask.push_back(-1);
24269 continue;
24270 }
24271 SDValue InnerVec = (InnerIdx < (int)NumElts)
24272 ? CurrentSVN->getOperand(0)
24273 : CurrentSVN->getOperand(1);
24274 if (InnerVec.isUndef()) {
24275 Mask.push_back(-1);
24276 continue;
24277 }
24278 InnerIdx %= NumElts;
24279 if (InnerVec == SV0) {
24280 Mask.push_back(InnerIdx);
24281 continue;
24282 }
24283 if (InnerVec == SV1) {
24284 Mask.push_back(InnerIdx + NumElts);
24285 continue;
24286 }
24287 }
24288
24289 // Bail out if we cannot convert the shuffle pair into a single shuffle.
24290 return false;
24291 }
24292
24293 if (llvm::all_of(Mask, [](int M) { return M < 0; }))
24294 return true;
24295
24296 // Avoid introducing shuffles with illegal mask.
24297 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
24298 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
24299 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
24300 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
24301 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
24302 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
24303 if (TLI.isShuffleMaskLegal(Mask, VT))
24304 return true;
24305
24306 std::swap(SV0, SV1);
24307 ShuffleVectorSDNode::commuteMask(Mask);
24308 return TLI.isShuffleMaskLegal(Mask, VT);
24309 };
24310
24311 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
24312 // Canonicalize shuffles according to rules:
24313 // shuffle(A, shuffle(A, B)) -> shuffle(shuffle(A,B), A)
24314 // shuffle(B, shuffle(A, B)) -> shuffle(shuffle(A,B), B)
24315 // shuffle(B, shuffle(A, Undef)) -> shuffle(shuffle(A, Undef), B)
24316 if (N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
24317 N0.getOpcode() != ISD::VECTOR_SHUFFLE) {
24318 // The incoming shuffle must be of the same type as the result of the
24319 // current shuffle.
24320 assert(N1->getOperand(0).getValueType() == VT &&
24321 "Shuffle types don't match");
24322
24323 SDValue SV0 = N1->getOperand(0);
24324 SDValue SV1 = N1->getOperand(1);
24325 bool HasSameOp0 = N0 == SV0;
24326 bool IsSV1Undef = SV1.isUndef();
24327 if (HasSameOp0 || IsSV1Undef || N0 == SV1)
24328 // Commute the operands of this shuffle so merging below will trigger.
24329 return DAG.getCommutedVectorShuffle(*SVN);
24330 }
24331
24332 // Canonicalize splat shuffles to the RHS to improve merging below.
24333 // shuffle(splat(A,u), shuffle(C,D)) -> shuffle'(shuffle(C,D), splat(A,u))
24334 if (N0.getOpcode() == ISD::VECTOR_SHUFFLE &&
24335 N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
24336 cast<ShuffleVectorSDNode>(N0)->isSplat() &&
24337 !cast<ShuffleVectorSDNode>(N1)->isSplat()) {
24338 return DAG.getCommutedVectorShuffle(*SVN);
24339 }
24340
24341 // Try to fold according to rules:
24342 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
24343 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
24344 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
24345 // Don't try to fold shuffles with illegal type.
24346 // Only fold if this shuffle is the only user of the other shuffle.
24347 // Try matching shuffle(C,shuffle(A,B)) commuted patterns as well.
24348 for (int i = 0; i != 2; ++i) {
24349 if (N->getOperand(i).getOpcode() == ISD::VECTOR_SHUFFLE &&
24350 N->isOnlyUserOf(N->getOperand(i).getNode())) {
24351 // The incoming shuffle must be of the same type as the result of the
24352 // current shuffle.
24353 auto *OtherSV = cast<ShuffleVectorSDNode>(N->getOperand(i));
24354 assert(OtherSV->getOperand(0).getValueType() == VT &&
24355 "Shuffle types don't match");
24356
24357 SDValue SV0, SV1;
24358 SmallVector<int, 4> Mask;
24359 if (MergeInnerShuffle(i != 0, SVN, OtherSV, N->getOperand(1 - i), TLI,
24360 SV0, SV1, Mask)) {
24361 // Check if all indices in Mask are Undef. In case, propagate Undef.
24362 if (llvm::all_of(Mask, [](int M) { return M < 0; }))
24363 return DAG.getUNDEF(VT);
24364
24365 return DAG.getVectorShuffle(VT, SDLoc(N),
24366 SV0 ? SV0 : DAG.getUNDEF(VT),
24367 SV1 ? SV1 : DAG.getUNDEF(VT), Mask);
24368 }
24369 }
24370 }
24371
24372 // Merge shuffles through binops if we are able to merge them with at least
24373 // one other shuffle.
24374 // shuffle(bop(shuffle(x,y),shuffle(z,w)),undef)
24375 // shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
24376 unsigned SrcOpcode = N0.getOpcode();
24377 if (TLI.isBinOp(SrcOpcode) && N->isOnlyUserOf(N0.getNode()) &&
24378 (N1.isUndef() ||
24379 (SrcOpcode == N1.getOpcode() && N->isOnlyUserOf(N1.getNode())))) {
24380 // Get binop source ops, or just pass on the undef.
24381 SDValue Op00 = N0.getOperand(0);
24382 SDValue Op01 = N0.getOperand(1);
24383 SDValue Op10 = N1.isUndef() ? N1 : N1.getOperand(0);
24384 SDValue Op11 = N1.isUndef() ? N1 : N1.getOperand(1);
24385 // TODO: We might be able to relax the VT check but we don't currently
24386 // have any isBinOp() that has different result/ops VTs so play safe until
24387 // we have test coverage.
24388 if (Op00.getValueType() == VT && Op10.getValueType() == VT &&
24389 Op01.getValueType() == VT && Op11.getValueType() == VT &&
24390 (Op00.getOpcode() == ISD::VECTOR_SHUFFLE ||
24391 Op10.getOpcode() == ISD::VECTOR_SHUFFLE ||
24392 Op01.getOpcode() == ISD::VECTOR_SHUFFLE ||
24393 Op11.getOpcode() == ISD::VECTOR_SHUFFLE)) {
24394 auto CanMergeInnerShuffle = [&](SDValue &SV0, SDValue &SV1,
24395 SmallVectorImpl<int> &Mask, bool LeftOp,
24396 bool Commute) {
24397 SDValue InnerN = Commute ? N1 : N0;
24398 SDValue Op0 = LeftOp ? Op00 : Op01;
24399 SDValue Op1 = LeftOp ? Op10 : Op11;
24400 if (Commute)
24401 std::swap(Op0, Op1);
24402 // Only accept the merged shuffle if we don't introduce undef elements,
24403 // or the inner shuffle already contained undef elements.
24404 auto *SVN0 = dyn_cast<ShuffleVectorSDNode>(Op0);
24405 return SVN0 && InnerN->isOnlyUserOf(SVN0) &&
24406 MergeInnerShuffle(Commute, SVN, SVN0, Op1, TLI, SV0, SV1,
24407 Mask) &&
24408 (llvm::any_of(SVN0->getMask(), [](int M) { return M < 0; }) ||
24409 llvm::none_of(Mask, [](int M) { return M < 0; }));
24410 };
24411
24412 // Ensure we don't increase the number of shuffles - we must merge a
24413 // shuffle from at least one of the LHS and RHS ops.
24414 bool MergedLeft = false;
24415 SDValue LeftSV0, LeftSV1;
24416 SmallVector<int, 4> LeftMask;
24417 if (CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, false) ||
24418 CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, true)) {
24419 MergedLeft = true;
24420 } else {
24421 LeftMask.assign(SVN->getMask().begin(), SVN->getMask().end());
24422 LeftSV0 = Op00, LeftSV1 = Op10;
24423 }
24424
24425 bool MergedRight = false;
24426 SDValue RightSV0, RightSV1;
24427 SmallVector<int, 4> RightMask;
24428 if (CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, false) ||
24429 CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, true)) {
24430 MergedRight = true;
24431 } else {
24432 RightMask.assign(SVN->getMask().begin(), SVN->getMask().end());
24433 RightSV0 = Op01, RightSV1 = Op11;
24434 }
24435
24436 if (MergedLeft || MergedRight) {
24437 SDLoc DL(N);
24438 SDValue LHS = DAG.getVectorShuffle(
24439 VT, DL, LeftSV0 ? LeftSV0 : DAG.getUNDEF(VT),
24440 LeftSV1 ? LeftSV1 : DAG.getUNDEF(VT), LeftMask);
24441 SDValue RHS = DAG.getVectorShuffle(
24442 VT, DL, RightSV0 ? RightSV0 : DAG.getUNDEF(VT),
24443 RightSV1 ? RightSV1 : DAG.getUNDEF(VT), RightMask);
24444 return DAG.getNode(SrcOpcode, DL, VT, LHS, RHS);
24445 }
24446 }
24447 }
24448 }
24449
24450 if (SDValue V = foldShuffleOfConcatUndefs(SVN, DAG))
24451 return V;
24452
24453 // Match shuffles that can be converted to ISD::ZERO_EXTEND_VECTOR_INREG.
24454 // Perform this really late, because it could eliminate knowledge
24455 // of undef elements created by this shuffle.
24456 if (Level < AfterLegalizeTypes)
24457 if (SDValue V = combineShuffleToZeroExtendVectorInReg(SVN, DAG, TLI,
24458 LegalOperations))
24459 return V;
24460
24461 return SDValue();
24462 }
24463
24464 SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
24465 EVT VT = N->getValueType(0);
24466 if (!VT.isFixedLengthVector())
24467 return SDValue();
24468
24469 // Try to convert a scalar binop with an extracted vector element to a vector
24470 // binop. This is intended to reduce potentially expensive register moves.
24471 // TODO: Check if both operands are extracted.
24472 // TODO: Generalize this, so it can be called from visitINSERT_VECTOR_ELT().
24473 SDValue Scalar = N->getOperand(0);
24474 unsigned Opcode = Scalar.getOpcode();
24475 EVT VecEltVT = VT.getScalarType();
24476 if (Scalar.hasOneUse() && Scalar->getNumValues() == 1 &&
24477 TLI.isBinOp(Opcode) && Scalar.getValueType() == VecEltVT &&
24478 Scalar.getOperand(0).getValueType() == VecEltVT &&
24479 Scalar.getOperand(1).getValueType() == VecEltVT &&
24480 DAG.isSafeToSpeculativelyExecute(Opcode) && hasOperation(Opcode, VT)) {
24481 // Match an extract element and get a shuffle mask equivalent.
24482 SmallVector<int, 8> ShufMask(VT.getVectorNumElements(), -1);
24483
24484 for (int i : {0, 1}) {
24485 // s2v (bo (extelt V, Idx), C) --> shuffle (bo V, C'), {Idx, -1, -1...}
24486 // s2v (bo C, (extelt V, Idx)) --> shuffle (bo C', V), {Idx, -1, -1...}
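// Illustrative example (4-element vectors):
//   s2v (add (extelt V, 2), 42) --> shuffle (add V, <42,42,42,42>), undef, {2,-1,-1,-1}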
24487 SDValue EE = Scalar.getOperand(i);
24488 auto *C = dyn_cast<ConstantSDNode>(Scalar.getOperand(i ? 0 : 1));
24489 if (C && EE.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
24490 EE.getOperand(0).getValueType() == VT &&
24491 isa<ConstantSDNode>(EE.getOperand(1))) {
24492 // Mask = {ExtractIndex, undef, undef....}
24493 ShufMask[0] = EE.getConstantOperandVal(1);
24494 // Make sure the shuffle is legal if we are crossing lanes.
24495 if (TLI.isShuffleMaskLegal(ShufMask, VT)) {
24496 SDLoc DL(N);
24497 SDValue V[] = {EE.getOperand(0),
24498 DAG.getConstant(C->getAPIntValue(), DL, VT)};
24499 SDValue VecBO = DAG.getNode(Opcode, DL, VT, V[i], V[1 - i]);
24500 return DAG.getVectorShuffle(VT, DL, VecBO, DAG.getUNDEF(VT),
24501 ShufMask);
24502 }
24503 }
24504 }
24505 }
24506
24507 // Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern
24508 // with a VECTOR_SHUFFLE and possible truncate.
24509 if (Opcode != ISD::EXTRACT_VECTOR_ELT ||
24510 !Scalar.getOperand(0).getValueType().isFixedLengthVector())
24511 return SDValue();
24512
24513 // If we have an implicit truncate, truncate here if it is legal.
24514 if (VecEltVT != Scalar.getValueType() &&
24515 Scalar.getValueType().isScalarInteger() && isTypeLegal(VecEltVT)) {
24516 SDValue Val = DAG.getNode(ISD::TRUNCATE, SDLoc(Scalar), VecEltVT, Scalar);
24517 return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), VT, Val);
24518 }
24519
24520 auto *ExtIndexC = dyn_cast<ConstantSDNode>(Scalar.getOperand(1));
24521 if (!ExtIndexC)
24522 return SDValue();
24523
24524 SDValue SrcVec = Scalar.getOperand(0);
24525 EVT SrcVT = SrcVec.getValueType();
24526 unsigned SrcNumElts = SrcVT.getVectorNumElements();
24527 unsigned VTNumElts = VT.getVectorNumElements();
24528 if (VecEltVT == SrcVT.getScalarType() && VTNumElts <= SrcNumElts) {
24529 // Create a shuffle equivalent for scalar-to-vector: {ExtIndex, -1, -1, ...}
24530 SmallVector<int, 8> Mask(SrcNumElts, -1);
24531 Mask[0] = ExtIndexC->getZExtValue();
24532 SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
24533 SrcVT, SDLoc(N), SrcVec, DAG.getUNDEF(SrcVT), Mask, DAG);
24534 if (!LegalShuffle)
24535 return SDValue();
24536
24537 // If the initial vector is the same size, the shuffle is the result.
24538 if (VT == SrcVT)
24539 return LegalShuffle;
24540
24541 // If not, shorten the shuffled vector.
24542 if (VTNumElts != SrcNumElts) {
24543 SDValue ZeroIdx = DAG.getVectorIdxConstant(0, SDLoc(N));
24544 EVT SubVT = EVT::getVectorVT(*DAG.getContext(),
24545 SrcVT.getVectorElementType(), VTNumElts);
24546 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), SubVT, LegalShuffle,
24547 ZeroIdx);
24548 }
24549 }
24550
24551 return SDValue();
24552 }
24553
24554 SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
24555 EVT VT = N->getValueType(0);
24556 SDValue N0 = N->getOperand(0);
24557 SDValue N1 = N->getOperand(1);
24558 SDValue N2 = N->getOperand(2);
24559 uint64_t InsIdx = N->getConstantOperandVal(2);
24560
24561 // If inserting an UNDEF, just return the original vector.
24562 if (N1.isUndef())
24563 return N0;
24564
24565 // If this is an insert of an extracted vector into an undef vector, we can
24566 // just use the input to the extract.
24567 if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
24568 N1.getOperand(1) == N2 && N1.getOperand(0).getValueType() == VT)
24569 return N1.getOperand(0);
24570
24571 // Simplify scalar inserts into an undef vector:
24572 // insert_subvector undef, (splat X), N2 -> splat X
24573 if (N0.isUndef() && N1.getOpcode() == ISD::SPLAT_VECTOR)
24574 return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, N1.getOperand(0));
24575
24576 // If we are inserting a bitcast value into an undef, with the same
24577 // number of elements, just use the bitcast input of the extract.
24578 // i.e. INSERT_SUBVECTOR UNDEF (BITCAST N1) N2 ->
24579 // BITCAST (INSERT_SUBVECTOR UNDEF N1 N2)
24580 if (N0.isUndef() && N1.getOpcode() == ISD::BITCAST &&
24581 N1.getOperand(0).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
24582 N1.getOperand(0).getOperand(1) == N2 &&
24583 N1.getOperand(0).getOperand(0).getValueType().getVectorElementCount() ==
24584 VT.getVectorElementCount() &&
24585 N1.getOperand(0).getOperand(0).getValueType().getSizeInBits() ==
24586 VT.getSizeInBits()) {
24587 return DAG.getBitcast(VT, N1.getOperand(0).getOperand(0));
24588 }
24589
24590 // If both N1 and N2 are bitcast values on which insert_subvector
24591 // would make sense, pull the bitcast through.
24592 // i.e. INSERT_SUBVECTOR (BITCAST N0) (BITCAST N1) N2 ->
24593 // BITCAST (INSERT_SUBVECTOR N0 N1 N2)
24594 if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) {
24595 SDValue CN0 = N0.getOperand(0);
24596 SDValue CN1 = N1.getOperand(0);
24597 EVT CN0VT = CN0.getValueType();
24598 EVT CN1VT = CN1.getValueType();
24599 if (CN0VT.isVector() && CN1VT.isVector() &&
24600 CN0VT.getVectorElementType() == CN1VT.getVectorElementType() &&
24601 CN0VT.getVectorElementCount() == VT.getVectorElementCount()) {
24602 SDValue NewINSERT = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N),
24603 CN0.getValueType(), CN0, CN1, N2);
24604 return DAG.getBitcast(VT, NewINSERT);
24605 }
24606 }
24607
24608 // Combine INSERT_SUBVECTORs where we are inserting to the same index.
24609 // INSERT_SUBVECTOR( INSERT_SUBVECTOR( Vec, SubOld, Idx ), SubNew, Idx )
24610 // --> INSERT_SUBVECTOR( Vec, SubNew, Idx )
24611 if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
24612 N0.getOperand(1).getValueType() == N1.getValueType() &&
24613 N0.getOperand(2) == N2)
24614 return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0.getOperand(0),
24615 N1, N2);
24616
24617 // Eliminate an intermediate insert into an undef vector:
24618 // insert_subvector undef, (insert_subvector undef, X, 0), N2 -->
24619 // insert_subvector undef, X, N2
24620 if (N0.isUndef() && N1.getOpcode() == ISD::INSERT_SUBVECTOR &&
24621 N1.getOperand(0).isUndef() && isNullConstant(N1.getOperand(2)))
24622 return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0,
24623 N1.getOperand(1), N2);
24624
24625 // Push subvector bitcasts to the output, adjusting the index as we go.
24626 // insert_subvector(bitcast(v), bitcast(s), c1)
24627 // -> bitcast(insert_subvector(v, s, c2))
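// Illustrative example (assumes the v4i32 insert_subvector is supported):
//   insert_subvector (bitcast v4i32 V to v2i64), (bitcast v2i32 S to v1i64), 1
//   --> bitcast (insert_subvector V, S, 2)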
24628 if ((N0.isUndef() || N0.getOpcode() == ISD::BITCAST) &&
24629 N1.getOpcode() == ISD::BITCAST) {
24630 SDValue N0Src = peekThroughBitcasts(N0);
24631 SDValue N1Src = peekThroughBitcasts(N1);
24632 EVT N0SrcSVT = N0Src.getValueType().getScalarType();
24633 EVT N1SrcSVT = N1Src.getValueType().getScalarType();
24634 if ((N0.isUndef() || N0SrcSVT == N1SrcSVT) &&
24635 N0Src.getValueType().isVector() && N1Src.getValueType().isVector()) {
24636 EVT NewVT;
24637 SDLoc DL(N);
24638 SDValue NewIdx;
24639 LLVMContext &Ctx = *DAG.getContext();
24640 ElementCount NumElts = VT.getVectorElementCount();
24641 unsigned EltSizeInBits = VT.getScalarSizeInBits();
24642 if ((EltSizeInBits % N1SrcSVT.getSizeInBits()) == 0) {
24643 unsigned Scale = EltSizeInBits / N1SrcSVT.getSizeInBits();
24644 NewVT = EVT::getVectorVT(Ctx, N1SrcSVT, NumElts * Scale);
24645 NewIdx = DAG.getVectorIdxConstant(InsIdx * Scale, DL);
24646 } else if ((N1SrcSVT.getSizeInBits() % EltSizeInBits) == 0) {
24647 unsigned Scale = N1SrcSVT.getSizeInBits() / EltSizeInBits;
24648 if (NumElts.isKnownMultipleOf(Scale) && (InsIdx % Scale) == 0) {
24649 NewVT = EVT::getVectorVT(Ctx, N1SrcSVT,
24650 NumElts.divideCoefficientBy(Scale));
24651 NewIdx = DAG.getVectorIdxConstant(InsIdx / Scale, DL);
24652 }
24653 }
24654 if (NewIdx && hasOperation(ISD::INSERT_SUBVECTOR, NewVT)) {
24655 SDValue Res = DAG.getBitcast(NewVT, N0Src);
24656 Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewVT, Res, N1Src, NewIdx);
24657 return DAG.getBitcast(VT, Res);
24658 }
24659 }
24660 }
24661
24662 // Canonicalize insert_subvector dag nodes.
24663 // Example:
24664 // (insert_subvector (insert_subvector A, Idx0), Idx1)
24665 // -> (insert_subvector (insert_subvector A, Idx1), Idx0)
24666 if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() &&
24667 N1.getValueType() == N0.getOperand(1).getValueType()) {
24668 unsigned OtherIdx = N0.getConstantOperandVal(2);
24669 if (InsIdx < OtherIdx) {
24670 // Swap nodes.
24671 SDValue NewOp = DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT,
24672 N0.getOperand(0), N1, N2);
24673 AddToWorklist(NewOp.getNode());
24674 return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N0.getNode()),
24675 VT, NewOp, N0.getOperand(1), N0.getOperand(2));
24676 }
24677 }
24678
24679 // If the input vector is a concatenation, and the insert replaces
24680 // one of the pieces, we can optimize into a single concat_vectors.
24681 if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() &&
24682 N0.getOperand(0).getValueType() == N1.getValueType() &&
24683 N0.getOperand(0).getValueType().isScalableVector() ==
24684 N1.getValueType().isScalableVector()) {
24685 unsigned Factor = N1.getValueType().getVectorMinNumElements();
24686 SmallVector<SDValue, 8> Ops(N0->op_begin(), N0->op_end());
24687 Ops[InsIdx / Factor] = N1;
24688 return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, Ops);
24689 }
24690
24691 // Simplify source operands based on insertion.
24692 if (SimplifyDemandedVectorElts(SDValue(N, 0)))
24693 return SDValue(N, 0);
24694
24695 return SDValue();
24696 }
24697
24698 SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
24699 SDValue N0 = N->getOperand(0);
24700
24701 // fold (fp_to_fp16 (fp16_to_fp op)) -> op
24702 if (N0->getOpcode() == ISD::FP16_TO_FP)
24703 return N0->getOperand(0);
24704
24705 return SDValue();
24706 }
24707
24708 SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
24709 SDValue N0 = N->getOperand(0);
24710
24711 // fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op)
24712 if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
24713 ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1));
24714 if (AndConst && AndConst->getAPIntValue() == 0xffff) {
24715 return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), N->getValueType(0),
24716 N0.getOperand(0));
24717 }
24718 }
24719
24720 return SDValue();
24721 }
24722
24723 SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
24724 SDValue N0 = N->getOperand(0);
24725
24726 // fold (fp_to_bf16 (bf16_to_fp op)) -> op
24727 if (N0->getOpcode() == ISD::BF16_TO_FP)
24728 return N0->getOperand(0);
24729
24730 return SDValue();
24731 }
24732
24733 SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
24734 SDValue N0 = N->getOperand(0);
24735 EVT VT = N0.getValueType();
24736 unsigned Opcode = N->getOpcode();
24737
24738 // VECREDUCE over a 1-element vector is just an extract.
24739 if (VT.getVectorElementCount().isScalar()) {
24740 SDLoc dl(N);
24741 SDValue Res =
24742 DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(), N0,
24743 DAG.getVectorIdxConstant(0, dl));
24744 if (Res.getValueType() != N->getValueType(0))
24745 Res = DAG.getNode(ISD::ANY_EXTEND, dl, N->getValueType(0), Res);
24746 return Res;
24747 }
24748
24749 // On a boolean vector an and/or reduction is the same as a umin/umax
24750 // reduction. Convert them if the latter is legal while the former isn't.
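// Illustrative example: over <-1,0,-1> (every element 0 or all-ones),
// vecreduce_and == vecreduce_umin == 0 and vecreduce_or == vecreduce_umax == -1.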
24751 if (Opcode == ISD::VECREDUCE_AND || Opcode == ISD::VECREDUCE_OR) {
24752 unsigned NewOpcode = Opcode == ISD::VECREDUCE_AND
24753 ? ISD::VECREDUCE_UMIN : ISD::VECREDUCE_UMAX;
24754 if (!TLI.isOperationLegalOrCustom(Opcode, VT) &&
24755 TLI.isOperationLegalOrCustom(NewOpcode, VT) &&
24756 DAG.ComputeNumSignBits(N0) == VT.getScalarSizeInBits())
24757 return DAG.getNode(NewOpcode, SDLoc(N), N->getValueType(0), N0);
24758 }
24759
24760 // vecreduce_or(insert_subvector(zero or undef, val)) -> vecreduce_or(val)
24761 // vecreduce_and(insert_subvector(ones or undef, val)) -> vecreduce_and(val)
24762 if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
24763 TLI.isTypeLegal(N0.getOperand(1).getValueType())) {
24764 SDValue Vec = N0.getOperand(0);
24765 SDValue Subvec = N0.getOperand(1);
24766 if ((Opcode == ISD::VECREDUCE_OR &&
24767 (N0.getOperand(0).isUndef() || isNullOrNullSplat(Vec))) ||
24768 (Opcode == ISD::VECREDUCE_AND &&
24769 (N0.getOperand(0).isUndef() || isAllOnesOrAllOnesSplat(Vec))))
24770 return DAG.getNode(Opcode, SDLoc(N), N->getValueType(0), Subvec);
24771 }
24772
24773 return SDValue();
24774 }
24775
24776 SDValue DAGCombiner::visitVPOp(SDNode *N) {
24777
24778 if (N->getOpcode() == ISD::VP_GATHER)
24779 if (SDValue SD = visitVPGATHER(N))
24780 return SD;
24781
24782 if (N->getOpcode() == ISD::VP_SCATTER)
24783 if (SDValue SD = visitVPSCATTER(N))
24784 return SD;
24785
24786 // VP operations in which all vector elements are disabled - either by
24787 // determining that the mask is all false or that the EVL is 0 - can be
24788 // eliminated.
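// e.g. (illustrative) a VP add whose EVL operand is the constant 0, or whose
// mask operand is an all-zeros splat, has no active lanes and can be folded.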
24789 bool AreAllEltsDisabled = false;
24790 if (auto EVLIdx = ISD::getVPExplicitVectorLengthIdx(N->getOpcode()))
24791 AreAllEltsDisabled |= isNullConstant(N->getOperand(*EVLIdx));
24792 if (auto MaskIdx = ISD::getVPMaskIdx(N->getOpcode()))
24793 AreAllEltsDisabled |=
24794 ISD::isConstantSplatVectorAllZeros(N->getOperand(*MaskIdx).getNode());
24795
24796 // This is the only generic VP combine we support for now.
24797 if (!AreAllEltsDisabled)
24798 return SDValue();
24799
24800 // Binary operations can be replaced by UNDEF.
24801 if (ISD::isVPBinaryOp(N->getOpcode()))
24802 return DAG.getUNDEF(N->getValueType(0));
24803
24804 // VP Memory operations can be replaced by either the chain (stores) or the
24805 // chain + undef (loads).
24806 if (const auto *MemSD = dyn_cast<MemSDNode>(N)) {
24807 if (MemSD->writeMem())
24808 return MemSD->getChain();
24809 return CombineTo(N, DAG.getUNDEF(N->getValueType(0)), MemSD->getChain());
24810 }
24811
24812 // Reduction operations return the start operand when no elements are active.
24813 if (ISD::isVPReduction(N->getOpcode()))
24814 return N->getOperand(0);
24815
24816 return SDValue();
24817 }
24818
24819 /// Returns a vector_shuffle if it is able to transform an AND to a vector_shuffle
24820 /// with the destination vector and a zero vector.
24821 /// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==>
24822 /// vector_shuffle V, Zero, <0, 4, 2, 4>
24823 SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
24824 assert(N->getOpcode() == ISD::AND && "Unexpected opcode!");
24825
24826 EVT VT = N->getValueType(0);
24827 SDValue LHS = N->getOperand(0);
24828 SDValue RHS = peekThroughBitcasts(N->getOperand(1));
24829 SDLoc DL(N);
24830
24831 // Make sure we're not running after operation legalization where it
24832 // may have custom lowered the vector shuffles.
24833 if (LegalOperations)
24834 return SDValue();
24835
24836 if (RHS.getOpcode() != ISD::BUILD_VECTOR)
24837 return SDValue();
24838
24839 EVT RVT = RHS.getValueType();
24840 unsigned NumElts = RHS.getNumOperands();
24841
24842 // Attempt to create a valid clear mask, splitting the mask into
24843 // sub elements and checking to see if each is
24844 // all zeros or all ones - suitable for shuffle masking.
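// Illustrative example (little-endian, v2i64 operand, Split == 2):
//   and V, <0x00000000FFFFFFFF, 0xFFFFFFFF00000000>
//   --> vector_shuffle (bitcast V to v4i32), zero, <0,5,6,3>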
24845 auto BuildClearMask = [&](int Split) {
24846 int NumSubElts = NumElts * Split;
24847 int NumSubBits = RVT.getScalarSizeInBits() / Split;
24848
24849 SmallVector<int, 8> Indices;
24850 for (int i = 0; i != NumSubElts; ++i) {
24851 int EltIdx = i / Split;
24852 int SubIdx = i % Split;
24853 SDValue Elt = RHS.getOperand(EltIdx);
24854 // X & undef --> 0 (not undef). So this lane must be converted to choose
24855 // from the zero constant vector (same as if the element had all 0-bits).
24856 if (Elt.isUndef()) {
24857 Indices.push_back(i + NumSubElts);
24858 continue;
24859 }
24860
24861 APInt Bits;
24862 if (isa<ConstantSDNode>(Elt))
24863 Bits = cast<ConstantSDNode>(Elt)->getAPIntValue();
24864 else if (isa<ConstantFPSDNode>(Elt))
24865 Bits = cast<ConstantFPSDNode>(Elt)->getValueAPF().bitcastToAPInt();
24866 else
24867 return SDValue();
24868
24869 // Extract the sub element from the constant bit mask.
24870 if (DAG.getDataLayout().isBigEndian())
24871 Bits = Bits.extractBits(NumSubBits, (Split - SubIdx - 1) * NumSubBits);
24872 else
24873 Bits = Bits.extractBits(NumSubBits, SubIdx * NumSubBits);
24874
24875 if (Bits.isAllOnes())
24876 Indices.push_back(i);
24877 else if (Bits == 0)
24878 Indices.push_back(i + NumSubElts);
24879 else
24880 return SDValue();
24881 }
24882
24883 // Let's see if the target supports this vector_shuffle.
24884 EVT ClearSVT = EVT::getIntegerVT(*DAG.getContext(), NumSubBits);
24885 EVT ClearVT = EVT::getVectorVT(*DAG.getContext(), ClearSVT, NumSubElts);
24886 if (!TLI.isVectorClearMaskLegal(Indices, ClearVT))
24887 return SDValue();
24888
24889 SDValue Zero = DAG.getConstant(0, DL, ClearVT);
24890 return DAG.getBitcast(VT, DAG.getVectorShuffle(ClearVT, DL,
24891 DAG.getBitcast(ClearVT, LHS),
24892 Zero, Indices));
24893 };
24894
24895 // Determine maximum split level (byte level masking).
24896 int MaxSplit = 1;
24897 if (RVT.getScalarSizeInBits() % 8 == 0)
24898 MaxSplit = RVT.getScalarSizeInBits() / 8;
24899
24900 for (int Split = 1; Split <= MaxSplit; ++Split)
24901 if (RVT.getScalarSizeInBits() % Split == 0)
24902 if (SDValue S = BuildClearMask(Split))
24903 return S;
24904
24905 return SDValue();
24906 }
24907
24908 /// If a vector binop is performed on splat values, it may be profitable to
24909 /// extract, scalarize, and insert/splat.
24910 static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG,
24911 const SDLoc &DL) {
24912 SDValue N0 = N->getOperand(0);
24913 SDValue N1 = N->getOperand(1);
24914 unsigned Opcode = N->getOpcode();
24915 EVT VT = N->getValueType(0);
24916 EVT EltVT = VT.getVectorElementType();
24917 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24918
24919 // TODO: Remove/replace the extract cost check? If the elements are available
24920 // as scalars, then there may be no extract cost. Should we ask if
24921 // inserting a scalar back into a vector is cheap instead?
24922 int Index0, Index1;
24923 SDValue Src0 = DAG.getSplatSourceVector(N0, Index0);
24924 SDValue Src1 = DAG.getSplatSourceVector(N1, Index1);
24925 // Extract element from splat_vector should be free.
24926 // TODO: use DAG.isSplatValue instead?
24927 bool IsBothSplatVector = N0.getOpcode() == ISD::SPLAT_VECTOR &&
24928 N1.getOpcode() == ISD::SPLAT_VECTOR;
24929 if (!Src0 || !Src1 || Index0 != Index1 ||
24930 Src0.getValueType().getVectorElementType() != EltVT ||
24931 Src1.getValueType().getVectorElementType() != EltVT ||
24932 !(IsBothSplatVector || TLI.isExtractVecEltCheap(VT, Index0)) ||
24933 !TLI.isOperationLegalOrCustom(Opcode, EltVT))
24934 return SDValue();
24935
24936 SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL);
24937 SDValue X = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src0, IndexC);
24938 SDValue Y = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Src1, IndexC);
24939 SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, X, Y, N->getFlags());
24940
24941 // If all lanes but 1 are undefined, no need to splat the scalar result.
24942 // TODO: Keep track of undefs and use that info in the general case.
24943 if (N0.getOpcode() == ISD::BUILD_VECTOR && N0.getOpcode() == N1.getOpcode() &&
24944 count_if(N0->ops(), [](SDValue V) { return !V.isUndef(); }) == 1 &&
24945 count_if(N1->ops(), [](SDValue V) { return !V.isUndef(); }) == 1) {
24946 // bo (build_vec ..undef, X, undef...), (build_vec ..undef, Y, undef...) -->
24947 // build_vec ..undef, (bo X, Y), undef...
24948 SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), DAG.getUNDEF(EltVT));
24949 Ops[Index0] = ScalarBO;
24950 return DAG.getBuildVector(VT, DL, Ops);
24951 }
24952
24953 // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
24954 return DAG.getSplat(VT, DL, ScalarBO);
24955 }
24956
24957 /// Visit a vector cast operation, like FP_EXTEND.
24958 SDValue DAGCombiner::SimplifyVCastOp(SDNode *N, const SDLoc &DL) {
24959 EVT VT = N->getValueType(0);
24960 assert(VT.isVector() && "SimplifyVCastOp only works on vectors!");
24961 EVT EltVT = VT.getVectorElementType();
24962 unsigned Opcode = N->getOpcode();
24963
24964 SDValue N0 = N->getOperand(0);
24965 EVT SrcVT = N0->getValueType(0);
24966 EVT SrcEltVT = SrcVT.getVectorElementType();
24967 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24968
24969 // TODO: promote operation might also be good here?
24970 int Index0;
24971 SDValue Src0 = DAG.getSplatSourceVector(N0, Index0);
24972 if (Src0 &&
24973 (N0.getOpcode() == ISD::SPLAT_VECTOR ||
24974 TLI.isExtractVecEltCheap(VT, Index0)) &&
24975 TLI.isOperationLegalOrCustom(Opcode, EltVT) &&
24976 TLI.preferScalarizeSplat(Opcode)) {
24977 SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL);
24978 SDValue Elt =
24979 DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, SrcEltVT, Src0, IndexC);
24980 SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, Elt, N->getFlags());
24981 if (VT.isScalableVector())
24982 return DAG.getSplatVector(VT, DL, ScalarBO);
24983 SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO);
24984 return DAG.getBuildVector(VT, DL, Ops);
24985 }
24986
24987 return SDValue();
24988 }
24989
24990 /// Visit a binary vector operation, like ADD.
24991 SDValue DAGCombiner::SimplifyVBinOp(SDNode *N, const SDLoc &DL) {
24992 EVT VT = N->getValueType(0);
24993 assert(VT.isVector() && "SimplifyVBinOp only works on vectors!");
24994
24995 SDValue LHS = N->getOperand(0);
24996 SDValue RHS = N->getOperand(1);
24997 unsigned Opcode = N->getOpcode();
24998 SDNodeFlags Flags = N->getFlags();
24999
25000 // Move unary shuffles with identical masks after a vector binop:
25001 // VBinOp (shuffle A, Undef, Mask), (shuffle B, Undef, Mask))
25002 // --> shuffle (VBinOp A, B), Undef, Mask
25003 // This does not require type legality checks because we are creating the
25004 // same types of operations that are in the original sequence. We do have to
25005 // restrict ops like integer div that have immediate UB (eg, div-by-zero)
25006 // though. This code is adapted from the identical transform in instcombine.
25007 if (DAG.isSafeToSpeculativelyExecute(Opcode)) {
25008 auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(LHS);
25009 auto *Shuf1 = dyn_cast<ShuffleVectorSDNode>(RHS);
25010 if (Shuf0 && Shuf1 && Shuf0->getMask().equals(Shuf1->getMask()) &&
25011 LHS.getOperand(1).isUndef() && RHS.getOperand(1).isUndef() &&
25012 (LHS.hasOneUse() || RHS.hasOneUse() || LHS == RHS)) {
25013 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS.getOperand(0),
25014 RHS.getOperand(0), Flags);
25015 SDValue UndefV = LHS.getOperand(1);
25016 return DAG.getVectorShuffle(VT, DL, NewBinOp, UndefV, Shuf0->getMask());
25017 }
25018
25019 // Try to sink a splat shuffle after a binop with a uniform constant.
25020 // This is limited to cases where neither the shuffle nor the constant have
25021 // undefined elements because that could be poison-unsafe or inhibit
25022 // demanded elements analysis. It is further limited to not change a splat
25023 // of an inserted scalar because that may be optimized better by
25024 // load-folding or other target-specific behaviors.
25025 if (isConstOrConstSplat(RHS) && Shuf0 && all_equal(Shuf0->getMask()) &&
25026 Shuf0->hasOneUse() && Shuf0->getOperand(1).isUndef() &&
25027 Shuf0->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
25028 // binop (splat X), (splat C) --> splat (binop X, C)
25029 SDValue X = Shuf0->getOperand(0);
25030 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, X, RHS, Flags);
25031 return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
25032 Shuf0->getMask());
25033 }
25034 if (isConstOrConstSplat(LHS) && Shuf1 && all_equal(Shuf1->getMask()) &&
25035 Shuf1->hasOneUse() && Shuf1->getOperand(1).isUndef() &&
25036 Shuf1->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
25037 // binop (splat C), (splat X) --> splat (binop C, X)
25038 SDValue X = Shuf1->getOperand(0);
25039 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, LHS, X, Flags);
25040 return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT),
25041 Shuf1->getMask());
25042 }
25043 }
25044
25045 // The following pattern is likely to emerge with vector reduction ops. Moving
25046 // the binary operation ahead of insertion may allow using a narrower vector
25047 // instruction that has better performance than the wide version of the op:
25048 // VBinOp (ins undef, X, Z), (ins undef, Y, Z) --> ins VecC, (VBinOp X, Y), Z
25049 if (LHS.getOpcode() == ISD::INSERT_SUBVECTOR && LHS.getOperand(0).isUndef() &&
25050 RHS.getOpcode() == ISD::INSERT_SUBVECTOR && RHS.getOperand(0).isUndef() &&
25051 LHS.getOperand(2) == RHS.getOperand(2) &&
25052 (LHS.hasOneUse() || RHS.hasOneUse())) {
25053 SDValue X = LHS.getOperand(1);
25054 SDValue Y = RHS.getOperand(1);
25055 SDValue Z = LHS.getOperand(2);
25056 EVT NarrowVT = X.getValueType();
25057 if (NarrowVT == Y.getValueType() &&
25058 TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT,
25059 LegalOperations)) {
25060 // (binop undef, undef) may not return undef, so compute that result.
25061 SDValue VecC =
25062 DAG.getNode(Opcode, DL, VT, DAG.getUNDEF(VT), DAG.getUNDEF(VT));
25063 SDValue NarrowBO = DAG.getNode(Opcode, DL, NarrowVT, X, Y);
25064 return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, VecC, NarrowBO, Z);
25065 }
25066 }
25067
25068 // Make sure all but the first op are undef or constant.
25069 auto ConcatWithConstantOrUndef = [](SDValue Concat) {
25070 return Concat.getOpcode() == ISD::CONCAT_VECTORS &&
25071 all_of(drop_begin(Concat->ops()), [](const SDValue &Op) {
25072 return Op.isUndef() ||
25073 ISD::isBuildVectorOfConstantSDNodes(Op.getNode());
25074 });
25075 };
25076
25077 // The following pattern is likely to emerge with vector reduction ops. Moving
25078 // the binary operation ahead of the concat may allow using a narrower vector
25079 // instruction that has better performance than the wide version of the op:
25080 // VBinOp (concat X, undef/constant), (concat Y, undef/constant) -->
25081 // concat (VBinOp X, Y), VecC
25082 if (ConcatWithConstantOrUndef(LHS) && ConcatWithConstantOrUndef(RHS) &&
25083 (LHS.hasOneUse() || RHS.hasOneUse())) {
25084 EVT NarrowVT = LHS.getOperand(0).getValueType();
25085 if (NarrowVT == RHS.getOperand(0).getValueType() &&
25086 TLI.isOperationLegalOrCustomOrPromote(Opcode, NarrowVT)) {
25087 unsigned NumOperands = LHS.getNumOperands();
25088 SmallVector<SDValue, 4> ConcatOps;
25089 for (unsigned i = 0; i != NumOperands; ++i) {
25090 // This will constant fold for operands 1 and up.
25091 ConcatOps.push_back(DAG.getNode(Opcode, DL, NarrowVT, LHS.getOperand(i),
25092 RHS.getOperand(i)));
25093 }
25094
25095 return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
25096 }
25097 }
25098
25099 if (SDValue V = scalarizeBinOpOfSplats(N, DAG, DL))
25100 return V;
25101
25102 return SDValue();
25103 }
25104
25105 SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1,
25106 SDValue N2) {
25107 assert(N0.getOpcode() == ISD::SETCC &&
25108 "First argument must be a SetCC node!");
25109
25110 SDValue SCC = SimplifySelectCC(DL, N0.getOperand(0), N0.getOperand(1), N1, N2,
25111 cast<CondCodeSDNode>(N0.getOperand(2))->get());
25112
25113 // If we got a simplified select_cc node back from SimplifySelectCC, then
25114 // break it down into a new SETCC node, and a new SELECT node, and then return
25115 // the SELECT node, since we were called with a SELECT node.
25116 if (SCC.getNode()) {
25117 // Check to see if we got a select_cc back (to turn into setcc/select).
25118 // Otherwise, just return whatever node we got back, like fabs.
25119 if (SCC.getOpcode() == ISD::SELECT_CC) {
25120 const SDNodeFlags Flags = N0->getFlags();
25121 SDValue SETCC = DAG.getNode(ISD::SETCC, SDLoc(N0),
25122 N0.getValueType(),
25123 SCC.getOperand(0), SCC.getOperand(1),
25124 SCC.getOperand(4), Flags);
25125 AddToWorklist(SETCC.getNode());
25126 SDValue SelectNode = DAG.getSelect(SDLoc(SCC), SCC.getValueType(), SETCC,
25127 SCC.getOperand(2), SCC.getOperand(3));
25128 SelectNode->setFlags(Flags);
25129 return SelectNode;
25130 }
25131
25132 return SCC;
25133 }
25134 return SDValue();
25135 }
25136
25137 /// Given a SELECT or a SELECT_CC node, where LHS and RHS are the two values
25138 /// being selected between, see if we can simplify the select. Callers of this
25139 /// should assume that TheSelect is deleted if this returns true. As such, they
25140 /// should return the appropriate thing (e.g. the node) back to the top-level of
25141 /// the DAG combiner loop to avoid it being looked at.
25142 bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
25143 SDValue RHS) {
25144 // fold (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
25145 // The select + setcc is redundant, because fsqrt returns NaN for X < 0.
25146 if (const ConstantFPSDNode *NaN = isConstOrConstSplatFP(LHS)) {
25147 if (NaN->isNaN() && RHS.getOpcode() == ISD::FSQRT) {
25148 // We have: (select (setcc ?, ?, ?), NaN, (fsqrt ?))
25149 SDValue Sqrt = RHS;
25150 ISD::CondCode CC;
25151 SDValue CmpLHS;
25152 const ConstantFPSDNode *Zero = nullptr;
25153
25154 if (TheSelect->getOpcode() == ISD::SELECT_CC) {
25155 CC = cast<CondCodeSDNode>(TheSelect->getOperand(4))->get();
25156 CmpLHS = TheSelect->getOperand(0);
25157 Zero = isConstOrConstSplatFP(TheSelect->getOperand(1));
25158 } else {
25159 // SELECT or VSELECT
25160 SDValue Cmp = TheSelect->getOperand(0);
25161 if (Cmp.getOpcode() == ISD::SETCC) {
25162 CC = cast<CondCodeSDNode>(Cmp.getOperand(2))->get();
25163 CmpLHS = Cmp.getOperand(0);
25164 Zero = isConstOrConstSplatFP(Cmp.getOperand(1));
25165 }
25166 }
25167 if (Zero && Zero->isZero() &&
25168 Sqrt.getOperand(0) == CmpLHS && (CC == ISD::SETOLT ||
25169 CC == ISD::SETULT || CC == ISD::SETLT)) {
25170 // We have: (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
25171 CombineTo(TheSelect, Sqrt);
25172 return true;
25173 }
25174 }
25175 }
25176 // Cannot simplify select with vector condition
25177 if (TheSelect->getOperand(0).getValueType().isVector()) return false;
25178
25179 // If this is a select from two identical things, try to pull the operation
25180 // through the select.
25181 if (LHS.getOpcode() != RHS.getOpcode() ||
25182 !LHS.hasOneUse() || !RHS.hasOneUse())
25183 return false;
25184
25185 // If this is a load and the token chain is identical, replace the select
25186 // of two loads with a load through a select of the address to load from.
25187 // This triggers in things like "select bool X, 10.0, 123.0" after the FP
25188 // constants have been dropped into the constant pool.
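// In DAG terms, the transformation below is roughly
//   (select Cond, (load Chain, PtrA), (load Chain, PtrB))
//     --> (load Chain, (select Cond, PtrA, PtrB))
// where PtrA and PtrB stand for the two base pointers.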
25189 if (LHS.getOpcode() == ISD::LOAD) {
25190 LoadSDNode *LLD = cast<LoadSDNode>(LHS);
25191 LoadSDNode *RLD = cast<LoadSDNode>(RHS);
25192
25193 // Token chains must be identical.
25194 if (LHS.getOperand(0) != RHS.getOperand(0) ||
25195 // Do not let this transformation reduce the number of volatile loads.
25196 // Be conservative for atomics for the moment
25197 // TODO: This does appear to be legal for unordered atomics (see D66309)
25198 !LLD->isSimple() || !RLD->isSimple() ||
25199 // FIXME: If either is a pre/post inc/dec load,
25200 // we'd need to split out the address adjustment.
25201 LLD->isIndexed() || RLD->isIndexed() ||
25202 // If this is an EXTLOAD, the VT's must match.
25203 LLD->getMemoryVT() != RLD->getMemoryVT() ||
25204 // If this is an EXTLOAD, the kind of extension must match.
25205 (LLD->getExtensionType() != RLD->getExtensionType() &&
25206 // The only exception is if one of the extensions is anyext.
25207 LLD->getExtensionType() != ISD::EXTLOAD &&
25208 RLD->getExtensionType() != ISD::EXTLOAD) ||
25209 // FIXME: this discards src value information. This is
25210 // over-conservative. It would be beneficial to be able to remember
25211 // both potential memory locations. Since we are discarding
25212 // src value info, don't do the transformation if the memory
25213 // locations are not in the default address space.
25214 LLD->getPointerInfo().getAddrSpace() != 0 ||
25215 RLD->getPointerInfo().getAddrSpace() != 0 ||
25216 // We can't produce a CMOV of a TargetFrameIndex since we won't
25217 // generate the address generation required.
25218 LLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
25219 RLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
25220 !TLI.isOperationLegalOrCustom(TheSelect->getOpcode(),
25221 LLD->getBasePtr().getValueType()))
25222 return false;
25223
25224 // The loads must not depend on one another.
25225 if (LLD->isPredecessorOf(RLD) || RLD->isPredecessorOf(LLD))
25226 return false;
25227
25228 // Check that the select condition doesn't reach either load. If so,
25229 // folding this will induce a cycle into the DAG. If not, this is safe to
25230 // xform, so create a select of the addresses.
25231
25232 SmallPtrSet<const SDNode *, 32> Visited;
25233 SmallVector<const SDNode *, 16> Worklist;
25234
25235 // Always fail if LLD and RLD are not independent. TheSelect is a
25236 // predecessor to all Nodes in question so we need not search past it.
25237
25238 Visited.insert(TheSelect);
25239 Worklist.push_back(LLD);
25240 Worklist.push_back(RLD);
25241
25242 if (SDNode::hasPredecessorHelper(LLD, Visited, Worklist) ||
25243 SDNode::hasPredecessorHelper(RLD, Visited, Worklist))
25244 return false;
25245
25246 SDValue Addr;
25247 if (TheSelect->getOpcode() == ISD::SELECT) {
25248 // We cannot do this optimization if any pair of {RLD, LLD} is a
25249 // predecessor to {RLD, LLD, CondNode}. As we've already compared the
25250 // Loads, we only need to check if CondNode is a successor to one of the
25251 // loads. We can further avoid this if there's no use of their chain
25252 // value.
25253 SDNode *CondNode = TheSelect->getOperand(0).getNode();
25254 Worklist.push_back(CondNode);
25255
25256 if ((LLD->hasAnyUseOfValue(1) &&
25257 SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
25258 (RLD->hasAnyUseOfValue(1) &&
25259 SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
25260 return false;
25261
25262 Addr = DAG.getSelect(SDLoc(TheSelect),
25263 LLD->getBasePtr().getValueType(),
25264 TheSelect->getOperand(0), LLD->getBasePtr(),
25265 RLD->getBasePtr());
25266 } else { // Otherwise SELECT_CC
25267 // We cannot do this optimization if any pair of {RLD, LLD} is a
25268 // predecessor to {RLD, LLD, CondLHS, CondRHS}. As we've already compared
25269 // the Loads, we only need to check if CondLHS/CondRHS is a successor to
25270 // one of the loads. We can further avoid this if there's no use of their
25271 // chain value.
25272
25273 SDNode *CondLHS = TheSelect->getOperand(0).getNode();
25274 SDNode *CondRHS = TheSelect->getOperand(1).getNode();
25275 Worklist.push_back(CondLHS);
25276 Worklist.push_back(CondRHS);
25277
25278 if ((LLD->hasAnyUseOfValue(1) &&
25279 SDNode::hasPredecessorHelper(LLD, Visited, Worklist)) ||
25280 (RLD->hasAnyUseOfValue(1) &&
25281 SDNode::hasPredecessorHelper(RLD, Visited, Worklist)))
25282 return false;
25283
25284 Addr = DAG.getNode(ISD::SELECT_CC, SDLoc(TheSelect),
25285 LLD->getBasePtr().getValueType(),
25286 TheSelect->getOperand(0),
25287 TheSelect->getOperand(1),
25288 LLD->getBasePtr(), RLD->getBasePtr(),
25289 TheSelect->getOperand(4));
25290 }
25291
25292 SDValue Load;
25293 // It is safe to replace the two loads if they have different alignments,
25294 // but the new load must be the minimum (most restrictive) alignment of the
25295 // inputs.
25296 Align Alignment = std::min(LLD->getAlign(), RLD->getAlign());
25297 MachineMemOperand::Flags MMOFlags = LLD->getMemOperand()->getFlags();
25298 if (!RLD->isInvariant())
25299 MMOFlags &= ~MachineMemOperand::MOInvariant;
25300 if (!RLD->isDereferenceable())
25301 MMOFlags &= ~MachineMemOperand::MODereferenceable;
25302 if (LLD->getExtensionType() == ISD::NON_EXTLOAD) {
25303 // FIXME: Discards pointer and AA info.
25304 Load = DAG.getLoad(TheSelect->getValueType(0), SDLoc(TheSelect),
25305 LLD->getChain(), Addr, MachinePointerInfo(), Alignment,
25306 MMOFlags);
25307 } else {
25308 // FIXME: Discards pointer and AA info.
25309 Load = DAG.getExtLoad(
25310 LLD->getExtensionType() == ISD::EXTLOAD ? RLD->getExtensionType()
25311 : LLD->getExtensionType(),
25312 SDLoc(TheSelect), TheSelect->getValueType(0), LLD->getChain(), Addr,
25313 MachinePointerInfo(), LLD->getMemoryVT(), Alignment, MMOFlags);
25314 }
25315
25316 // Users of the select now use the result of the load.
25317 CombineTo(TheSelect, Load);
25318
25319 // Users of the old loads now use the new load's chain. We know the
25320 // old-load value is dead now.
25321 CombineTo(LHS.getNode(), Load.getValue(0), Load.getValue(1));
25322 CombineTo(RHS.getNode(), Load.getValue(0), Load.getValue(1));
25323 return true;
25324 }
25325
25326 return false;
25327 }
25328
25329 /// Try to fold an expression of the form (N0 cond N1) ? N2 : N3 to a shift and
25330 /// bitwise 'and'.
25331 SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
25332 SDValue N1, SDValue N2, SDValue N3,
25333 ISD::CondCode CC) {
25334 // If this is a select where the false operand is zero and the compare is a
25335 // check of the sign bit, see if we can perform the "gzip trick":
25336 // select_cc setlt X, 0, A, 0 -> and (sra X, size(X)-1), A
25337 // select_cc setgt X, 0, A, 0 -> and (not (sra X, size(X)-1)), A
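// As a worked i32 example with A == 4 (handled by the single-bit case below):
//   select_cc setlt X, 0, 4, 0 --> and (srl X, 29), 4
// since shifting right by 32 - log2(4) - 1 == 29 moves the sign bit into bit
// 2, so the result is 4 exactly when X is negative and 0 otherwise.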
25338 EVT XType = N0.getValueType();
25339 EVT AType = N2.getValueType();
25340 if (!isNullConstant(N3) || !XType.bitsGE(AType))
25341 return SDValue();
25342
25343 // If the comparison is testing for a positive value, we have to invert
25344 // the sign bit mask, so only do that transform if the target has a bitwise
25345 // 'and not' instruction (the invert is free).
25346 if (CC == ISD::SETGT && TLI.hasAndNot(N2)) {
25347 // (X > -1) ? A : 0
25348 // (X > 0) ? X : 0 <-- This is canonical signed max.
25349 if (!(isAllOnesConstant(N1) || (isNullConstant(N1) && N0 == N2)))
25350 return SDValue();
25351 } else if (CC == ISD::SETLT) {
25352 // (X < 0) ? A : 0
25353 // (X < 1) ? X : 0 <-- This is un-canonicalized signed min.
25354 if (!(isNullConstant(N1) || (isOneConstant(N1) && N0 == N2)))
25355 return SDValue();
25356 } else {
25357 return SDValue();
25358 }
25359
25360 // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
25361 // constant.
25362 EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
25363 auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
25364 if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
25365 unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
25366 if (!TLI.shouldAvoidTransformToShift(XType, ShCt)) {
25367 SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
25368 SDValue Shift = DAG.getNode(ISD::SRL, DL, XType, N0, ShiftAmt);
25369 AddToWorklist(Shift.getNode());
25370
25371 if (XType.bitsGT(AType)) {
25372 Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
25373 AddToWorklist(Shift.getNode());
25374 }
25375
25376 if (CC == ISD::SETGT)
25377 Shift = DAG.getNOT(DL, Shift, AType);
25378
25379 return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
25380 }
25381 }
25382
25383 unsigned ShCt = XType.getSizeInBits() - 1;
25384 if (TLI.shouldAvoidTransformToShift(XType, ShCt))
25385 return SDValue();
25386
25387 SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
25388 SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, N0, ShiftAmt);
25389 AddToWorklist(Shift.getNode());
25390
25391 if (XType.bitsGT(AType)) {
25392 Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
25393 AddToWorklist(Shift.getNode());
25394 }
25395
25396 if (CC == ISD::SETGT)
25397 Shift = DAG.getNOT(DL, Shift, AType);
25398
25399 return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
25400 }
25401
25402 // Fold select(cc, binop(), binop()) -> binop(select(), select()) etc.
25403 SDValue DAGCombiner::foldSelectOfBinops(SDNode *N) {
25404 SDValue N0 = N->getOperand(0);
25405 SDValue N1 = N->getOperand(1);
25406 SDValue N2 = N->getOperand(2);
25407 EVT VT = N->getValueType(0);
25408 SDLoc DL(N);
25409
25410 unsigned BinOpc = N1.getOpcode();
25411 if (!TLI.isBinOp(BinOpc) || (N2.getOpcode() != BinOpc))
25412 return SDValue();
25413
25414 // The use checks are intentionally on SDNode because we may be dealing
25415 // with opcodes that produce more than one SDValue.
25416 // TODO: Do we really need to check N0 (the condition operand of the select)?
25417 // But removing that clause could cause an infinite loop...
25418 if (!N0->hasOneUse() || !N1->hasOneUse() || !N2->hasOneUse())
25419 return SDValue();
25420
25421 // Binops may include opcodes that return multiple values, so all values
25422 // must be created/propagated from the newly created binops below.
25423 SDVTList OpVTs = N1->getVTList();
25424
25425 // Fold select(cond, binop(x, y), binop(z, y))
25426 // --> binop(select(cond, x, z), y)
25427 if (N1.getOperand(1) == N2.getOperand(1)) {
25428 SDValue NewSel =
25429 DAG.getSelect(DL, VT, N0, N1.getOperand(0), N2.getOperand(0));
25430 SDValue NewBinOp = DAG.getNode(BinOpc, DL, OpVTs, NewSel, N1.getOperand(1));
25431 NewBinOp->setFlags(N1->getFlags());
25432 NewBinOp->intersectFlagsWith(N2->getFlags());
25433 return NewBinOp;
25434 }
25435
25436 // Fold select(cond, binop(x, y), binop(x, z))
25437 // --> binop(x, select(cond, y, z))
25438 // Second op VT might be different (e.g. shift amount type)
25439 if (N1.getOperand(0) == N2.getOperand(0) &&
25440 VT == N1.getOperand(1).getValueType() &&
25441 VT == N2.getOperand(1).getValueType()) {
25442 SDValue NewSel =
25443 DAG.getSelect(DL, VT, N0, N1.getOperand(1), N2.getOperand(1));
25444 SDValue NewBinOp = DAG.getNode(BinOpc, DL, OpVTs, N1.getOperand(0), NewSel);
25445 NewBinOp->setFlags(N1->getFlags());
25446 NewBinOp->intersectFlagsWith(N2->getFlags());
25447 return NewBinOp;
25448 }
25449
25450 // TODO: Handle isCommutativeBinOp patterns as well?
25451 return SDValue();
25452 }
25453
25454 // Transform (fneg/fabs (bitconvert x)) to avoid loading constant pool values.
25455 SDValue DAGCombiner::foldSignChangeInBitcast(SDNode *N) {
25456 SDValue N0 = N->getOperand(0);
25457 EVT VT = N->getValueType(0);
25458 bool IsFabs = N->getOpcode() == ISD::FABS;
25459 bool IsFree = IsFabs ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
25460
25461 if (IsFree || N0.getOpcode() != ISD::BITCAST || !N0.hasOneUse())
25462 return SDValue();
25463
25464 SDValue Int = N0.getOperand(0);
25465 EVT IntVT = Int.getValueType();
25466
25467 // The operand to cast should be integer.
25468 if (!IntVT.isInteger() || IntVT.isVector())
25469 return SDValue();
25470
25471 // (fneg (bitconvert x)) -> (bitconvert (xor x sign))
25472 // (fabs (bitconvert x)) -> (bitconvert (and x ~sign))
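// E.g. for f32 bitcast to/from i32 this is simply:
//   fneg --> xor x, 0x80000000   (flip the sign bit)
//   fabs --> and x, 0x7fffffff   (clear the sign bit)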
25473 APInt SignMask;
25474 if (N0.getValueType().isVector()) {
25475 // For vector, create a sign mask (0x80...) or its inverse (for fabs,
25476 // 0x7f...) per element and splat it.
25477 SignMask = APInt::getSignMask(N0.getScalarValueSizeInBits());
25478 if (IsFabs)
25479 SignMask = ~SignMask;
25480 SignMask = APInt::getSplat(IntVT.getSizeInBits(), SignMask);
25481 } else {
25482 // For scalar, just use the sign mask (0x80... or the inverse, 0x7f...)
25483 SignMask = APInt::getSignMask(IntVT.getSizeInBits());
25484 if (IsFabs)
25485 SignMask = ~SignMask;
25486 }
25487 SDLoc DL(N0);
25488 Int = DAG.getNode(IsFabs ? ISD::AND : ISD::XOR, DL, IntVT, Int,
25489 DAG.getConstant(SignMask, DL, IntVT));
25490 AddToWorklist(Int.getNode());
25491 return DAG.getBitcast(VT, Int);
25492 }
25493
25494 /// Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4)"
25495 /// where "tmp" is a constant pool entry containing an array with 1.0 and 2.0
25496 /// in it. This may be a win when the constant is not otherwise available
25497 /// because it replaces two constant pool loads with one.
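/// Roughly, for f32 constants the emitted sequence looks like (element size 4
/// assumed here):
///   CP     = constant pool entry holding { FV, TV }
///   Offset = (N0 cond N1) ? 4 : 0
///   Result = load (CP + Offset)
/// so a true condition loads TV from offset 4 and a false one loads FV from 0.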
25498 SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset(
25499 const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
25500 ISD::CondCode CC) {
25501 if (!TLI.reduceSelectOfFPConstantLoads(N0.getValueType()))
25502 return SDValue();
25503
25504 // If we are before legalize types, we want the other legalization to happen
25505 // first (for example, to avoid messing with soft float).
25506 auto *TV = dyn_cast<ConstantFPSDNode>(N2);
25507 auto *FV = dyn_cast<ConstantFPSDNode>(N3);
25508 EVT VT = N2.getValueType();
25509 if (!TV || !FV || !TLI.isTypeLegal(VT))
25510 return SDValue();
25511
25512 // If a constant can be materialized without loads, this does not make sense.
25513 if (TLI.getOperationAction(ISD::ConstantFP, VT) == TargetLowering::Legal ||
25514 TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(0), ForCodeSize) ||
25515 TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(0), ForCodeSize))
25516 return SDValue();
25517
25518 // If both constants have multiple uses, then we won't need to do an extra
25519 // load. The values are likely around in registers for other users.
25520 if (!TV->hasOneUse() && !FV->hasOneUse())
25521 return SDValue();
25522
25523 Constant *Elts[] = { const_cast<ConstantFP*>(FV->getConstantFPValue()),
25524 const_cast<ConstantFP*>(TV->getConstantFPValue()) };
25525 Type *FPTy = Elts[0]->getType();
25526 const DataLayout &TD = DAG.getDataLayout();
25527
25528 // Create a ConstantArray of the two constants.
25529 Constant *CA = ConstantArray::get(ArrayType::get(FPTy, 2), Elts);
25530 SDValue CPIdx = DAG.getConstantPool(CA, TLI.getPointerTy(DAG.getDataLayout()),
25531 TD.getPrefTypeAlign(FPTy));
25532 Align Alignment = cast<ConstantPoolSDNode>(CPIdx)->getAlign();
25533
25534 // Get offsets to the 0 and 1 elements of the array, so we can select between
25535 // them.
25536 SDValue Zero = DAG.getIntPtrConstant(0, DL);
25537 unsigned EltSize = (unsigned)TD.getTypeAllocSize(Elts[0]->getType());
25538 SDValue One = DAG.getIntPtrConstant(EltSize, SDLoc(FV));
25539 SDValue Cond =
25540 DAG.getSetCC(DL, getSetCCResultType(N0.getValueType()), N0, N1, CC);
25541 AddToWorklist(Cond.getNode());
25542 SDValue CstOffset = DAG.getSelect(DL, Zero.getValueType(), Cond, One, Zero);
25543 AddToWorklist(CstOffset.getNode());
25544 CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx, CstOffset);
25545 AddToWorklist(CPIdx.getNode());
25546 return DAG.getLoad(TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx,
25547 MachinePointerInfo::getConstantPool(
25548 DAG.getMachineFunction()), Alignment);
25549 }
25550
25551 /// Simplify an expression of the form (N0 cond N1) ? N2 : N3
25552 /// where 'cond' is the comparison specified by CC.
25553 SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
25554 SDValue N2, SDValue N3, ISD::CondCode CC,
25555 bool NotExtCompare) {
25556 // (x ? y : y) -> y.
25557 if (N2 == N3) return N2;
25558
25559 EVT CmpOpVT = N0.getValueType();
25560 EVT CmpResVT = getSetCCResultType(CmpOpVT);
25561 EVT VT = N2.getValueType();
25562 auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode());
25563 auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
25564 auto *N3C = dyn_cast<ConstantSDNode>(N3.getNode());
25565
25566 // Determine if the condition we're dealing with is constant.
25567 if (SDValue SCC = DAG.FoldSetCC(CmpResVT, N0, N1, CC, DL)) {
25568 AddToWorklist(SCC.getNode());
25569 if (auto *SCCC = dyn_cast<ConstantSDNode>(SCC)) {
25570 // fold select_cc true, x, y -> x
25571 // fold select_cc false, x, y -> y
25572 return !(SCCC->isZero()) ? N2 : N3;
25573 }
25574 }
25575
25576 if (SDValue V =
25577 convertSelectOfFPConstantsToLoadOffset(DL, N0, N1, N2, N3, CC))
25578 return V;
25579
25580 if (SDValue V = foldSelectCCToShiftAnd(DL, N0, N1, N2, N3, CC))
25581 return V;
25582
25583 // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (sra (shl x)) A)
25584 // where y has a single bit set.
25585 // In plain terms, we can turn the SELECT_CC into an AND
25586 // when the condition can be materialized as an all-ones register. Any
25587 // single bit-test can be materialized as an all-ones register with
25588 // shift-left and shift-right-arith.
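// Worked i32 example with y == 0x10 (bit 4): the mask has 27 leading zeros,
// so the code below emits
//   select_cc seteq (and x, 0x10), 0, 0, A --> and (sra (shl x, 27), 31), A
// where shl moves bit 4 into the sign bit, sra smears it to all-ones or zero,
// and the final AND produces A exactly when bit 4 of x was set.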
25589 if (CC == ISD::SETEQ && N0->getOpcode() == ISD::AND &&
25590 N0->getValueType(0) == VT && isNullConstant(N1) && isNullConstant(N2)) {
25591 SDValue AndLHS = N0->getOperand(0);
25592 auto *ConstAndRHS = dyn_cast<ConstantSDNode>(N0->getOperand(1));
25593 if (ConstAndRHS && ConstAndRHS->getAPIntValue().countPopulation() == 1) {
25594 // Shift the tested bit over the sign bit.
25595 const APInt &AndMask = ConstAndRHS->getAPIntValue();
25596 unsigned ShCt = AndMask.getBitWidth() - 1;
25597 if (!TLI.shouldAvoidTransformToShift(VT, ShCt)) {
25598 SDValue ShlAmt =
25599 DAG.getConstant(AndMask.countLeadingZeros(), SDLoc(AndLHS),
25600 getShiftAmountTy(AndLHS.getValueType()));
25601 SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N0), VT, AndLHS, ShlAmt);
25602
25603 // Now arithmetic right shift it all the way over, so the result is
25604 // either all-ones, or zero.
25605 SDValue ShrAmt =
25606 DAG.getConstant(ShCt, SDLoc(Shl),
25607 getShiftAmountTy(Shl.getValueType()));
25608 SDValue Shr = DAG.getNode(ISD::SRA, SDLoc(N0), VT, Shl, ShrAmt);
25609
25610 return DAG.getNode(ISD::AND, DL, VT, Shr, N3);
25611 }
25612 }
25613 }
25614
25615 // fold select C, 16, 0 -> shl C, 4
25616 bool Fold = N2C && isNullConstant(N3) && N2C->getAPIntValue().isPowerOf2();
25617 bool Swap = N3C && isNullConstant(N2) && N3C->getAPIntValue().isPowerOf2();
25618
25619 if ((Fold || Swap) &&
25620 TLI.getBooleanContents(CmpOpVT) ==
25621 TargetLowering::ZeroOrOneBooleanContent &&
25622 (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, CmpOpVT))) {
25623
25624 if (Swap) {
25625 CC = ISD::getSetCCInverse(CC, CmpOpVT);
25626 std::swap(N2C, N3C);
25627 }
25628
25629 // If the caller doesn't want us to simplify this into a zext of a compare,
25630 // don't do it.
25631 if (NotExtCompare && N2C->isOne())
25632 return SDValue();
25633
25634 SDValue Temp, SCC;
25635 // zext (setcc n0, n1)
25636 if (LegalTypes) {
25637 SCC = DAG.getSetCC(DL, CmpResVT, N0, N1, CC);
25638 if (VT.bitsLT(SCC.getValueType()))
25639 Temp = DAG.getZeroExtendInReg(SCC, SDLoc(N2), VT);
25640 else
25641 Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
25642 } else {
25643 SCC = DAG.getSetCC(SDLoc(N0), MVT::i1, N0, N1, CC);
25644 Temp = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N2), VT, SCC);
25645 }
25646
25647 AddToWorklist(SCC.getNode());
25648 AddToWorklist(Temp.getNode());
25649
25650 if (N2C->isOne())
25651 return Temp;
25652
25653 unsigned ShCt = N2C->getAPIntValue().logBase2();
25654 if (TLI.shouldAvoidTransformToShift(VT, ShCt))
25655 return SDValue();
25656
25657 // shl setcc result by log2 n2c
25658 return DAG.getNode(ISD::SHL, DL, N2.getValueType(), Temp,
25659 DAG.getConstant(ShCt, SDLoc(Temp),
25660 getShiftAmountTy(Temp.getValueType())));
25661 }
25662
25663 // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
25664 // select_cc seteq X, 0, sizeof(X), ctlz_zero_undef(X) -> ctlz(X)
25665 // select_cc seteq X, 0, sizeof(X), cttz(X) -> cttz(X)
25666 // select_cc seteq X, 0, sizeof(X), cttz_zero_undef(X) -> cttz(X)
25667 // select_cc setne X, 0, ctlz(X), sizeof(X) -> ctlz(X)
25668 // select_cc setne X, 0, ctlz_zero_undef(X), sizeof(X) -> ctlz(X)
25669 // select_cc setne X, 0, cttz(X), sizeof(X) -> cttz(X)
25670 // select_cc setne X, 0, cttz_zero_undef(X), sizeof(X) -> cttz(X)
25671 if (N1C && N1C->isZero() && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
25672 SDValue ValueOnZero = N2;
25673 SDValue Count = N3;
25674 // If the condition is NE instead of EQ, swap the operands.
25675 if (CC == ISD::SETNE)
25676 std::swap(ValueOnZero, Count);
25677 // Check if the value on zero is a constant equal to the bits in the type.
25678 if (auto *ValueOnZeroC = dyn_cast<ConstantSDNode>(ValueOnZero)) {
25679 if (ValueOnZeroC->getAPIntValue() == VT.getSizeInBits()) {
25680 // If the other operand is cttz/cttz_zero_undef of N0, and cttz is
25681 // legal, combine to just cttz.
25682 if ((Count.getOpcode() == ISD::CTTZ ||
25683 Count.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
25684 N0 == Count.getOperand(0) &&
25685 (!LegalOperations || TLI.isOperationLegal(ISD::CTTZ, VT)))
25686 return DAG.getNode(ISD::CTTZ, DL, VT, N0);
25687 // If the other operand is ctlz/ctlz_zero_undef of N0, and ctlz is
25688 // legal, combine to just ctlz.
25689 if ((Count.getOpcode() == ISD::CTLZ ||
25690 Count.getOpcode() == ISD::CTLZ_ZERO_UNDEF) &&
25691 N0 == Count.getOperand(0) &&
25692 (!LegalOperations || TLI.isOperationLegal(ISD::CTLZ, VT)))
25693 return DAG.getNode(ISD::CTLZ, DL, VT, N0);
25694 }
25695 }
25696 }
25697
25698 // Fold select_cc setgt X, -1, C, ~C -> xor (ashr X, BW-1), C
25699 // Fold select_cc setlt X, 0, C, ~C -> xor (ashr X, BW-1), ~C
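// For instance, with i32 X and C == 5:
//   select_cc setgt X, -1, 5, ~5 --> xor (sra X, 31), 5
// since (sra X, 31) is 0 for non-negative X (giving 5) and all-ones for
// negative X (giving ~5).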
25700 if (!NotExtCompare && N1C && N2C && N3C &&
25701 N2C->getAPIntValue() == ~N3C->getAPIntValue() &&
25702 ((N1C->isAllOnes() && CC == ISD::SETGT) ||
25703 (N1C->isZero() && CC == ISD::SETLT)) &&
25704 !TLI.shouldAvoidTransformToShift(VT, CmpOpVT.getScalarSizeInBits() - 1)) {
25705 SDValue ASR = DAG.getNode(
25706 ISD::SRA, DL, CmpOpVT, N0,
25707 DAG.getConstant(CmpOpVT.getScalarSizeInBits() - 1, DL, CmpOpVT));
25708 return DAG.getNode(ISD::XOR, DL, VT, DAG.getSExtOrTrunc(ASR, DL, VT),
25709 DAG.getSExtOrTrunc(CC == ISD::SETLT ? N3 : N2, DL, VT));
25710 }
25711
25712 if (SDValue S = PerformMinMaxFpToSatCombine(N0, N1, N2, N3, CC, DAG))
25713 return S;
25714 if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N2, N3, CC, DAG))
25715 return S;
25716
25717 return SDValue();
25718 }
25719
25720 /// This is a stub for TargetLowering::SimplifySetCC.
25721 SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
25722 ISD::CondCode Cond, const SDLoc &DL,
25723 bool foldBooleans) {
25724 TargetLowering::DAGCombinerInfo
25725 DagCombineInfo(DAG, Level, false, this);
25726 return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL);
25727 }
25728
25729 /// Given an ISD::SDIV node expressing a divide by constant, return
25730 /// a DAG expression to select that will generate the same value by multiplying
25731 /// by a magic number.
25732 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
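/// As a rough i32 sketch, a divide by 3 typically becomes something like
/// (the constants are the usual Hacker's Delight values; the exact sequence
/// is whatever TLI.BuildSDIV emits for the target):
///   q = mulhs(x, 0x55555556);   // high 32 bits of the sign-extended product
///   q = q + (q >>u 31);         // add 1 for negative q to round toward zero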
25733 SDValue DAGCombiner::BuildSDIV(SDNode *N) {
25734 // when optimising for minimum size, we don't want to expand a div to a mul
25735 // and a shift.
25736 if (DAG.getMachineFunction().getFunction().hasMinSize())
25737 return SDValue();
25738
25739 SmallVector<SDNode *, 8> Built;
25740 if (SDValue S = TLI.BuildSDIV(N, DAG, LegalOperations, Built)) {
25741 for (SDNode *N : Built)
25742 AddToWorklist(N);
25743 return S;
25744 }
25745
25746 return SDValue();
25747 }
25748
25749 /// Given an ISD::SDIV node expressing a divide by constant power of 2, return a
25750 /// DAG expression that will generate the same value by right shifting.
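/// For example, an i32 "x sdiv 8" is typically expanded as (the exact form is
/// up to TLI.BuildSDIVPow2):
///   t = srl (sra x, 31), 29     // 7 when x is negative, 0 otherwise
///   q = sra (add x, t), 3       // bias negatives so the shift rounds to zero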
25751 SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) {
25752 ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
25753 if (!C)
25754 return SDValue();
25755
25756 // Avoid division by zero.
25757 if (C->isZero())
25758 return SDValue();
25759
25760 SmallVector<SDNode *, 8> Built;
25761 if (SDValue S = TLI.BuildSDIVPow2(N, C->getAPIntValue(), DAG, Built)) {
25762 for (SDNode *N : Built)
25763 AddToWorklist(N);
25764 return S;
25765 }
25766
25767 return SDValue();
25768 }
25769
25770 /// Given an ISD::UDIV node expressing a divide by constant, return a DAG
25771 /// expression that will generate the same value by multiplying by a magic
25772 /// number.
25773 /// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
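/// As a rough i32 sketch, an unsigned divide by 3 typically becomes
/// (standard magic-number values; the target hook may emit an equivalent but
/// different sequence):
///   q = mulhu(x, 0xAAAAAAAB) >> 1   // equals x udiv 3 for all 32-bit x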
25774 SDValue DAGCombiner::BuildUDIV(SDNode *N) {
25775 // when optimising for minimum size, we don't want to expand a div to a mul
25776 // and a shift.
25777 if (DAG.getMachineFunction().getFunction().hasMinSize())
25778 return SDValue();
25779
25780 SmallVector<SDNode *, 8> Built;
25781 if (SDValue S = TLI.BuildUDIV(N, DAG, LegalOperations, Built)) {
25782 for (SDNode *N : Built)
25783 AddToWorklist(N);
25784 return S;
25785 }
25786
25787 return SDValue();
25788 }
25789
25790 /// Given an ISD::SREM node expressing a remainder by constant power of 2,
25791 /// return a DAG expression that will generate the same value.
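/// For example, an i32 "x srem 8" typically becomes (exact form chosen by
/// TLI.BuildSREMPow2):
///   t = srl (sra x, 31), 29          // 7 when x is negative, 0 otherwise
///   r = sub (and (add x, t), 7), t   // mask the low bits, then undo the bias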
25792 SDValue DAGCombiner::BuildSREMPow2(SDNode *N) {
25793 ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
25794 if (!C)
25795 return SDValue();
25796
25797 // Avoid division by zero.
25798 if (C->isZero())
25799 return SDValue();
25800
25801 SmallVector<SDNode *, 8> Built;
25802 if (SDValue S = TLI.BuildSREMPow2(N, C->getAPIntValue(), DAG, Built)) {
25803 for (SDNode *N : Built)
25804 AddToWorklist(N);
25805 return S;
25806 }
25807
25808 return SDValue();
25809 }
25810
25811 /// Determines the LogBase2 value for a non-null input value using the
25812 /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
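/// E.g. for a power-of-2 splat value of 16 in i32: ctlz(16) == 27, so
/// LogBase2 == 31 - 27 == 4.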
25813 SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL) {
25814 EVT VT = V.getValueType();
25815 SDValue Ctlz = DAG.getNode(ISD::CTLZ, DL, VT, V);
25816 SDValue Base = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
25817 SDValue LogBase2 = DAG.getNode(ISD::SUB, DL, VT, Base, Ctlz);
25818 return LogBase2;
25819 }
25820
25821 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
25822 /// For the reciprocal, we need to find the zero of the function:
25823 /// F(X) = 1/X - A [which has a zero at X = 1/A]
25824 /// =>
25825 /// X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
25826 /// does not require additional intermediate precision]
25827 /// For the last iteration, put numerator N into it to gain more precision:
25828 /// Result = N X_i + X_i (N - N A X_i)
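/// As a numeric illustration for A = 3 with a crude initial estimate
/// X_0 = 0.3:
///   X_1 = 0.3  * (2 - 3 * 0.3)  = 0.33
///   X_2 = 0.33 * (2 - 3 * 0.33) = 0.3333
/// converging quadratically toward 1/3.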
25829 SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
25830 SDNodeFlags Flags) {
25831 if (LegalDAG)
25832 return SDValue();
25833
25834 // TODO: Handle extended types?
25835 EVT VT = Op.getValueType();
25836 if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
25837 VT.getScalarType() != MVT::f64)
25838 return SDValue();
25839
25840 // If estimates are explicitly disabled for this function, we're done.
25841 MachineFunction &MF = DAG.getMachineFunction();
25842 int Enabled = TLI.getRecipEstimateDivEnabled(VT, MF);
25843 if (Enabled == TLI.ReciprocalEstimate::Disabled)
25844 return SDValue();
25845
25846 // Estimates may be explicitly enabled for this type with a custom number of
25847 // refinement steps.
25848 int Iterations = TLI.getDivRefinementSteps(VT, MF);
25849 if (SDValue Est = TLI.getRecipEstimate(Op, DAG, Enabled, Iterations)) {
25850 AddToWorklist(Est.getNode());
25851
25852 SDLoc DL(Op);
25853 if (Iterations) {
25854 SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
25855
25856 // Newton iterations: Est = Est + Est (N - Arg * Est)
25857 // If this is the last iteration, also multiply by the numerator.
25858 for (int i = 0; i < Iterations; ++i) {
25859 SDValue MulEst = Est;
25860
25861 if (i == Iterations - 1) {
25862 MulEst = DAG.getNode(ISD::FMUL, DL, VT, N, Est, Flags);
25863 AddToWorklist(MulEst.getNode());
25864 }
25865
25866 SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, MulEst, Flags);
25867 AddToWorklist(NewEst.getNode());
25868
25869 NewEst = DAG.getNode(ISD::FSUB, DL, VT,
25870 (i == Iterations - 1 ? N : FPOne), NewEst, Flags);
25871 AddToWorklist(NewEst.getNode());
25872
25873 NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
25874 AddToWorklist(NewEst.getNode());
25875
25876 Est = DAG.getNode(ISD::FADD, DL, VT, MulEst, NewEst, Flags);
25877 AddToWorklist(Est.getNode());
25878 }
25879 } else {
25880 // If no iterations are available, multiply with N.
25881 Est = DAG.getNode(ISD::FMUL, DL, VT, Est, N, Flags);
25882 AddToWorklist(Est.getNode());
25883 }
25884
25885 return Est;
25886 }
25887
25888 return SDValue();
25889 }
25890
25891 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
25892 /// For the reciprocal sqrt, we need to find the zero of the function:
25893 /// F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
25894 /// =>
25895 /// X_{i+1} = X_i (1.5 - A X_i^2 / 2)
25896 /// As a result, we precompute A/2 prior to the iteration loop.
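/// Numerically, for A = 4 (so A/2 = 2) and a rough estimate X_0 = 0.4:
///   X_1 = 0.4   * (1.5 - 2 * 0.4^2)   = 0.472
///   X_2 = 0.472 * (1.5 - 2 * 0.472^2) ~= 0.4977
/// approaching 1/sqrt(4) = 0.5.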
25897 SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
25898 unsigned Iterations,
25899 SDNodeFlags Flags, bool Reciprocal) {
25900 EVT VT = Arg.getValueType();
25901 SDLoc DL(Arg);
25902 SDValue ThreeHalves = DAG.getConstantFP(1.5, DL, VT);
25903
25904 // We now need 0.5 * Arg which we can write as (1.5 * Arg - Arg) so that
25905 // this entire sequence requires only one FP constant.
25906 SDValue HalfArg = DAG.getNode(ISD::FMUL, DL, VT, ThreeHalves, Arg, Flags);
25907 HalfArg = DAG.getNode(ISD::FSUB, DL, VT, HalfArg, Arg, Flags);
25908
25909 // Newton iterations: Est = Est * (1.5 - HalfArg * Est * Est)
25910 for (unsigned i = 0; i < Iterations; ++i) {
25911 SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags);
25912 NewEst = DAG.getNode(ISD::FMUL, DL, VT, HalfArg, NewEst, Flags);
25913 NewEst = DAG.getNode(ISD::FSUB, DL, VT, ThreeHalves, NewEst, Flags);
25914 Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
25915 }
25916
25917 // If non-reciprocal square root is requested, multiply the result by Arg.
25918 if (!Reciprocal)
25919 Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);
25920
25921 return Est;
25922 }
25923
25924 /// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
25925 /// For the reciprocal sqrt, we need to find the zero of the function:
25926 /// F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
25927 /// =>
25928 /// X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
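/// This is algebraically the same update as the one-constant recurrence above:
///   (-0.5 * X_i) * (A * X_i^2 - 3.0) == X_i * (1.5 - 0.5 * A * X_i^2)
/// e.g. for A = 4 and X_0 = 0.4 it likewise yields X_1 = 0.472.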
25929 SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
25930 unsigned Iterations,
25931 SDNodeFlags Flags, bool Reciprocal) {
25932 EVT VT = Arg.getValueType();
25933 SDLoc DL(Arg);
25934 SDValue MinusThree = DAG.getConstantFP(-3.0, DL, VT);
25935 SDValue MinusHalf = DAG.getConstantFP(-0.5, DL, VT);
25936
25937 // This routine must enter the loop below to work correctly
25938 // when (Reciprocal == false).
25939 assert(Iterations > 0);
25940
25941 // Newton iterations for reciprocal square root:
25942 // E = (E * -0.5) * ((A * E) * E + -3.0)
25943 for (unsigned i = 0; i < Iterations; ++i) {
25944 SDValue AE = DAG.getNode(ISD::FMUL, DL, VT, Arg, Est, Flags);
25945 SDValue AEE = DAG.getNode(ISD::FMUL, DL, VT, AE, Est, Flags);
25946 SDValue RHS = DAG.getNode(ISD::FADD, DL, VT, AEE, MinusThree, Flags);
25947
25948 // When calculating a square root at the last iteration build:
25949 // S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
25950 // (notice a common subexpression)
25951 SDValue LHS;
25952 if (Reciprocal || (i + 1) < Iterations) {
25953 // RSQRT: LHS = (E * -0.5)
25954 LHS = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
25955 } else {
25956 // SQRT: LHS = (A * E) * -0.5
25957 LHS = DAG.getNode(ISD::FMUL, DL, VT, AE, MinusHalf, Flags);
25958 }
25959
25960 Est = DAG.getNode(ISD::FMUL, DL, VT, LHS, RHS, Flags);
25961 }
25962
25963 return Est;
25964 }
25965
25966 /// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
25967 /// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
25968 /// Op can be zero.
25969 SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
25970 bool Reciprocal) {
25971 if (LegalDAG)
25972 return SDValue();
25973
25974 // TODO: Handle extended types?
25975 EVT VT = Op.getValueType();
25976 if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
25977 VT.getScalarType() != MVT::f64)
25978 return SDValue();
25979
25980 // If estimates are explicitly disabled for this function, we're done.
25981 MachineFunction &MF = DAG.getMachineFunction();
25982 int Enabled = TLI.getRecipEstimateSqrtEnabled(VT, MF);
25983 if (Enabled == TLI.ReciprocalEstimate::Disabled)
25984 return SDValue();
25985
25986 // Estimates may be explicitly enabled for this type with a custom number of
25987 // refinement steps.
25988 int Iterations = TLI.getSqrtRefinementSteps(VT, MF);
25989
25990 bool UseOneConstNR = false;
25991 if (SDValue Est =
25992 TLI.getSqrtEstimate(Op, DAG, Enabled, Iterations, UseOneConstNR,
25993 Reciprocal)) {
25994 AddToWorklist(Est.getNode());
25995
25996 if (Iterations)
25997 Est = UseOneConstNR
25998 ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
25999 : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
26000 if (!Reciprocal) {
26001 SDLoc DL(Op);
26002 // Try the target specific test first.
26003 SDValue Test = TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT));
26004
26005 // The estimate is now completely wrong if the input was exactly 0.0 or
26006 // possibly a denormal. Force the answer to 0.0 or value provided by
26007 // target for those cases.
26008 Est = DAG.getNode(
26009 Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
26010 Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est);
26011 }
26012 return Est;
26013 }
26014
26015 return SDValue();
26016 }
26017
26018 SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags) {
26019 return buildSqrtEstimateImpl(Op, Flags, true);
26020 }
26021
26022 SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags Flags) {
26023 return buildSqrtEstimateImpl(Op, Flags, false);
26024 }
26025
26026 /// Return true if there is any possibility that the two addresses overlap.
26027 bool DAGCombiner::mayAlias(SDNode *Op0, SDNode *Op1) const {
26028
26029 struct MemUseCharacteristics {
26030 bool IsVolatile;
26031 bool IsAtomic;
26032 SDValue BasePtr;
26033 int64_t Offset;
26034 std::optional<int64_t> NumBytes;
26035 MachineMemOperand *MMO;
26036 };
26037
26038 auto getCharacteristics = [](SDNode *N) -> MemUseCharacteristics {
26039 if (const auto *LSN = dyn_cast<LSBaseSDNode>(N)) {
26040 int64_t Offset = 0;
26041 if (auto *C = dyn_cast<ConstantSDNode>(LSN->getOffset()))
26042 Offset = (LSN->getAddressingMode() == ISD::PRE_INC)
26043 ? C->getSExtValue()
26044 : (LSN->getAddressingMode() == ISD::PRE_DEC)
26045 ? -1 * C->getSExtValue()
26046 : 0;
26047 uint64_t Size =
26048 MemoryLocation::getSizeOrUnknown(LSN->getMemoryVT().getStoreSize());
26049 return {LSN->isVolatile(),
26050 LSN->isAtomic(),
26051 LSN->getBasePtr(),
26052 Offset /*base offset*/,
26053 std::optional<int64_t>(Size),
26054 LSN->getMemOperand()};
26055 }
26056 if (const auto *LN = cast<LifetimeSDNode>(N))
26057 return {false /*isVolatile*/,
26058 /*isAtomic*/ false,
26059 LN->getOperand(1),
26060 (LN->hasOffset()) ? LN->getOffset() : 0,
26061 (LN->hasOffset()) ? std::optional<int64_t>(LN->getSize())
26062 : std::optional<int64_t>(),
26063 (MachineMemOperand *)nullptr};
26064 // Default.
26065 return {false /*isvolatile*/,
26066 /*isAtomic*/ false, SDValue(),
26067 (int64_t)0 /*offset*/, std::optional<int64_t>() /*size*/,
26068 (MachineMemOperand *)nullptr};
26069 };
26070
26071 MemUseCharacteristics MUC0 = getCharacteristics(Op0),
26072 MUC1 = getCharacteristics(Op1);
26073
26074 // If they are to the same address, then they must be aliases.
26075 if (MUC0.BasePtr.getNode() && MUC0.BasePtr == MUC1.BasePtr &&
26076 MUC0.Offset == MUC1.Offset)
26077 return true;
26078
26079 // If they are both volatile then they cannot be reordered.
26080 if (MUC0.IsVolatile && MUC1.IsVolatile)
26081 return true;
26082
26083 // Be conservative about atomics for the moment
26084 // TODO: This is way overconservative for unordered atomics (see D66309)
26085 if (MUC0.IsAtomic && MUC1.IsAtomic)
26086 return true;
26087
26088 if (MUC0.MMO && MUC1.MMO) {
26089 if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
26090 (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
26091 return false;
26092 }
26093
26094 // Try to prove that there is aliasing, or that there is no aliasing. Either
26095 // way, we can return now. If nothing can be proved, proceed with more tests.
26096 bool IsAlias;
26097 if (BaseIndexOffset::computeAliasing(Op0, MUC0.NumBytes, Op1, MUC1.NumBytes,
26098 DAG, IsAlias))
26099 return IsAlias;
26100
26101 // The following all rely on MMO0 and MMO1 being valid. Fail conservatively if
26102 // either are not known.
26103 if (!MUC0.MMO || !MUC1.MMO)
26104 return true;
26105
26106 // If one operation reads from invariant memory, and the other may store, they
26107 // cannot alias. These should really be checking the equivalent of mayWrite,
26108 // but it only matters for memory nodes other than load/store.
26109 if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
26110 (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
26111 return false;
26112
26113 // If we know required SrcValue1 and SrcValue2 have relatively large
26114 // alignment compared to the size and offset of the access, we may be able
26115 // to prove they do not alias. This check is conservative for now to catch
26116 // cases created by splitting vector types; it only works when the offsets are
26117 // multiples of the size of the data.
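// For example, two 4-byte accesses with 8-byte base alignment at source
// offsets 0 and 4 give OffAlign0 == 0 and OffAlign1 == 4; the ranges [0,4)
// and [4,8) do not overlap, so we can safely report no alias.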
26118 int64_t SrcValOffset0 = MUC0.MMO->getOffset();
26119 int64_t SrcValOffset1 = MUC1.MMO->getOffset();
26120 Align OrigAlignment0 = MUC0.MMO->getBaseAlign();
26121 Align OrigAlignment1 = MUC1.MMO->getBaseAlign();
26122 auto &Size0 = MUC0.NumBytes;
26123 auto &Size1 = MUC1.NumBytes;
26124 if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 &&
26125 Size0.has_value() && Size1.has_value() && *Size0 == *Size1 &&
26126 OrigAlignment0 > *Size0 && SrcValOffset0 % *Size0 == 0 &&
26127 SrcValOffset1 % *Size1 == 0) {
26128 int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0.value();
26129 int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1.value();
26130
26131 // There is no overlap between these relatively aligned accesses of
26132 // similar size. Return no alias.
26133 if ((OffAlign0 + *Size0) <= OffAlign1 || (OffAlign1 + *Size1) <= OffAlign0)
26134 return false;
26135 }
26136
26137 bool UseAA = CombinerGlobalAA.getNumOccurrences() > 0
26138 ? CombinerGlobalAA
26139 : DAG.getSubtarget().useAA();
26140 #ifndef NDEBUG
26141 if (CombinerAAOnlyFunc.getNumOccurrences() &&
26142 CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
26143 UseAA = false;
26144 #endif
26145
26146 if (UseAA && AA && MUC0.MMO->getValue() && MUC1.MMO->getValue() && Size0 &&
26147 Size1) {
26148 // Use alias analysis information.
26149 int64_t MinOffset = std::min(SrcValOffset0, SrcValOffset1);
26150 int64_t Overlap0 = *Size0 + SrcValOffset0 - MinOffset;
26151 int64_t Overlap1 = *Size1 + SrcValOffset1 - MinOffset;
26152 if (AA->isNoAlias(
26153 MemoryLocation(MUC0.MMO->getValue(), Overlap0,
26154 UseTBAA ? MUC0.MMO->getAAInfo() : AAMDNodes()),
26155 MemoryLocation(MUC1.MMO->getValue(), Overlap1,
26156 UseTBAA ? MUC1.MMO->getAAInfo() : AAMDNodes())))
26157 return false;
26158 }
26159
26160 // Otherwise we have to assume they alias.
26161 return true;
26162 }
26163
26164 /// Walk up chain skipping non-aliasing memory nodes,
26165 /// looking for aliasing nodes and adding them to the Aliases vector.
26166 void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
26167 SmallVectorImpl<SDValue> &Aliases) {
26168 SmallVector<SDValue, 8> Chains; // List of chains to visit.
26169 SmallPtrSet<SDNode *, 16> Visited; // Visited node set.
26170
26171 // Get alias information for node.
26172 // TODO: relax aliasing for unordered atomics (see D66309)
26173 const bool IsLoad = isa<LoadSDNode>(N) && cast<LoadSDNode>(N)->isSimple();
26174
26175 // Starting off.
26176 Chains.push_back(OriginalChain);
26177 unsigned Depth = 0;
26178
26179 // Attempt to improve chain by a single step
26180 auto ImproveChain = [&](SDValue &C) -> bool {
26181 switch (C.getOpcode()) {
26182 case ISD::EntryToken:
26183 // No need to mark EntryToken.
26184 C = SDValue();
26185 return true;
26186 case ISD::LOAD:
26187 case ISD::STORE: {
26188 // Get alias information for C.
26189 // TODO: Relax aliasing for unordered atomics (see D66309)
26190 bool IsOpLoad = isa<LoadSDNode>(C.getNode()) &&
26191 cast<LSBaseSDNode>(C.getNode())->isSimple();
26192 if ((IsLoad && IsOpLoad) || !mayAlias(N, C.getNode())) {
26193 // Look further up the chain.
26194 C = C.getOperand(0);
26195 return true;
26196 }
26197 // Alias, so stop here.
26198 return false;
26199 }
26200
26201 case ISD::CopyFromReg:
26202 // Always forward past CopyFromReg.
26203 C = C.getOperand(0);
26204 return true;
26205
26206 case ISD::LIFETIME_START:
26207 case ISD::LIFETIME_END: {
26208 // We can forward past any lifetime start/end that can be proven not to
26209 // alias the memory access.
26210 if (!mayAlias(N, C.getNode())) {
26211 // Look further up the chain.
26212 C = C.getOperand(0);
26213 return true;
26214 }
26215 return false;
26216 }
26217 default:
26218 return false;
26219 }
26220 };
26221
26222 // Look at each chain and determine if it is an alias. If so, add it to the
26223 // aliases list. If not, then continue up the chain looking for the next
26224 // candidate.
26225 while (!Chains.empty()) {
26226 SDValue Chain = Chains.pop_back_val();
26227
26228 // Don't bother if we've seen Chain before.
26229 if (!Visited.insert(Chain.getNode()).second)
26230 continue;
26231
26232 // For TokenFactor nodes, look at each operand and only continue up the
26233 // chain until we reach the depth limit.
26234 //
26235 // FIXME: The depth check could be made to return the last non-aliasing
26236 // chain we found before we hit a tokenfactor rather than the original
26237 // chain.
26238 if (Depth > TLI.getGatherAllAliasesMaxDepth()) {
26239 Aliases.clear();
26240 Aliases.push_back(OriginalChain);
26241 return;
26242 }
26243
26244 if (Chain.getOpcode() == ISD::TokenFactor) {
26245 // We have to check each of the operands of the token factor for "small"
26246 // token factors, so we queue them up. Adding the operands to the queue
26247 // (stack) in reverse order maintains the original order and increases the
26248 // likelihood that getNode will find a matching token factor (CSE.)
26249 if (Chain.getNumOperands() > 16) {
26250 Aliases.push_back(Chain);
26251 continue;
26252 }
26253 for (unsigned n = Chain.getNumOperands(); n;)
26254 Chains.push_back(Chain.getOperand(--n));
26255 ++Depth;
26256 continue;
26257 }
26258 // Everything else
26259 if (ImproveChain(Chain)) {
26260 // Updated Chain Found, Consider new chain if one exists.
26261 if (Chain.getNode())
26262 Chains.push_back(Chain);
26263 ++Depth;
26264 continue;
26265 }
26266 // No Improved Chain Possible, treat as Alias.
26267 Aliases.push_back(Chain);
26268 }
26269 }
26270
26271 /// Walk up chain skipping non-aliasing memory nodes, looking for a better chain
26272 /// (aliasing node.)
26273 SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
26274 if (OptLevel == CodeGenOpt::None)
26275 return OldChain;
26276
26277 // Ops for replacing token factor.
26278 SmallVector<SDValue, 8> Aliases;
26279
26280 // Accumulate all the aliases to this node.
26281 GatherAllAliases(N, OldChain, Aliases);
26282
26283 // If no operands then chain to entry token.
26284 if (Aliases.size() == 0)
26285 return DAG.getEntryNode();
26286
26287 // If a single operand then chain to it. We don't need to revisit it.
26288 if (Aliases.size() == 1)
26289 return Aliases[0];
26290
26291 // Construct a custom tailored token factor.
26292 return DAG.getTokenFactor(SDLoc(N), Aliases);
26293 }
26294
26295 // This function tries to collect a bunch of potentially interesting
26296 // nodes to improve the chains of, all at once. This might seem
26297 // redundant, as this function gets called when visiting every store
26298 // node, so why not let the work be done on each store as it's visited?
26299 //
26300 // I believe this is mainly important because mergeConsecutiveStores
26301 // is unable to deal with merging stores of different sizes, so unless
26302 // we improve the chains of all the potential candidates up-front
26303 // before running mergeConsecutiveStores, it might only see some of
26304 // the nodes that will eventually be candidates, and then not be able
26305 // to go from a partially-merged state to the desired final
26306 // fully-merged state.
26307
26308 bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
26309 SmallVector<StoreSDNode *, 8> ChainedStores;
26310 StoreSDNode *STChain = St;
26311 // Intervals records which offsets from BaseIndex have been covered. In
26312 // the common case, every store writes to an address range adjacent to the
26313 // previous one and is thus merged with the previous interval at insertion time.
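// For instance, a 4-byte store at offset 0 followed by chained 4-byte stores
// at offsets 4 and 8 grows a single interval [0, 12); a chained store that
// would overlap an existing interval ends the walk below.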
26314
26315 using IMap = llvm::IntervalMap<int64_t, std::monostate, 8,
26316 IntervalMapHalfOpenInfo<int64_t>>;
26317 IMap::Allocator A;
26318 IMap Intervals(A);
26319
26320 // This holds the base pointer, index, and the offset in bytes from the base
26321 // pointer.
26322 const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
26323
26324 // We must have a base and an offset.
26325 if (!BasePtr.getBase().getNode())
26326 return false;
26327
26328 // Do not handle stores to undef base pointers.
26329 if (BasePtr.getBase().isUndef())
26330 return false;
26331
26332 // Do not handle stores to opaque types
26333 if (St->getMemoryVT().isZeroSized())
26334 return false;
26335
26336 // BaseIndexOffset assumes that offsets are fixed-size, which
26337 // is not valid for scalable vectors where the offsets are
26338 // scaled by `vscale`, so bail out early.
26339 if (St->getMemoryVT().isScalableVector())
26340 return false;
26341
26342 // Add ST's interval.
26343 Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8,
26344 std::monostate{});
26345
26346 while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(STChain->getChain())) {
26347 if (Chain->getMemoryVT().isScalableVector())
26348 return false;
26349
26350 // If the chain has more than one use, then we can't reorder the mem ops.
26351 if (!SDValue(Chain, 0)->hasOneUse())
26352 break;
26353 // TODO: Relax for unordered atomics (see D66309)
26354 if (!Chain->isSimple() || Chain->isIndexed())
26355 break;
26356
26357 // Find the base pointer and offset for this memory node.
26358 const BaseIndexOffset Ptr = BaseIndexOffset::match(Chain, DAG);
26359 // Check that the base pointer is the same as the original one.
26360 int64_t Offset;
26361 if (!BasePtr.equalBaseIndex(Ptr, DAG, Offset))
26362 break;
26363 int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8;
26364 // Make sure we don't overlap with other intervals by checking the ones to
26365 // the left or right before inserting.
26366 auto I = Intervals.find(Offset);
26367 // If there's a next interval, we should end before it.
26368 if (I != Intervals.end() && I.start() < (Offset + Length))
26369 break;
26370 // If there's a previous interval, we should start after it.
26371 if (I != Intervals.begin() && (--I).stop() <= Offset)
26372 break;
26373 Intervals.insert(Offset, Offset + Length, std::monostate{});
26374
26375 ChainedStores.push_back(Chain);
26376 STChain = Chain;
26377 }
26378
26379 // If we didn't find a chained store, exit.
26380 if (ChainedStores.size() == 0)
26381 return false;
26382
26383 // Improve all chained stores (St and ChainedStores members) starting from
26384 // where the store chain ended and return single TokenFactor.
26385 SDValue NewChain = STChain->getChain();
26386 SmallVector<SDValue, 8> TFOps;
26387 for (unsigned I = ChainedStores.size(); I;) {
26388 StoreSDNode *S = ChainedStores[--I];
26389 SDValue BetterChain = FindBetterChain(S, NewChain);
26390 S = cast<StoreSDNode>(DAG.UpdateNodeOperands(
26391 S, BetterChain, S->getOperand(1), S->getOperand(2), S->getOperand(3)));
26392 TFOps.push_back(SDValue(S, 0));
26393 ChainedStores[I] = S;
26394 }
26395
26396 // Improve St's chain. Use a new node to avoid creating a loop from CombineTo.
26397 SDValue BetterChain = FindBetterChain(St, NewChain);
26398 SDValue NewST;
26399 if (St->isTruncatingStore())
26400 NewST = DAG.getTruncStore(BetterChain, SDLoc(St), St->getValue(),
26401 St->getBasePtr(), St->getMemoryVT(),
26402 St->getMemOperand());
26403 else
26404 NewST = DAG.getStore(BetterChain, SDLoc(St), St->getValue(),
26405 St->getBasePtr(), St->getMemOperand());
26406
26407 TFOps.push_back(NewST);
26408
26409 // If we improved every element of TFOps, then we've lost the dependence on
26410 // NewChain to successors of St and we need to add it back to TFOps. Do so at
26411 // the beginning to keep relative order consistent with FindBetterChains.
26412 auto hasImprovedChain = [&](SDValue ST) -> bool {
26413 return ST->getOperand(0) != NewChain;
26414 };
26415 bool AddNewChain = llvm::all_of(TFOps, hasImprovedChain);
26416 if (AddNewChain)
26417 TFOps.insert(TFOps.begin(), NewChain);
26418
26419 SDValue TF = DAG.getTokenFactor(SDLoc(STChain), TFOps);
26420 CombineTo(St, TF);
26421
26422 // Add TF and its operands to the worklist.
26423 AddToWorklist(TF.getNode());
26424 for (const SDValue &Op : TF->ops())
26425 AddToWorklist(Op.getNode());
26426 AddToWorklist(STChain);
26427 return true;
26428 }
26429
26430 bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
26431 if (OptLevel == CodeGenOpt::None)
26432 return false;
26433
26434 const BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
26435
26436 // We must have a base and an offset.
26437 if (!BasePtr.getBase().getNode())
26438 return false;
26439
26440 // Do not handle stores to undef base pointers.
26441 if (BasePtr.getBase().isUndef())
26442 return false;
26443
26444 // Directly improve a chain of disjoint stores starting at St.
26445 if (parallelizeChainedStores(St))
26446 return true;
26447
26448 // Improve St's chain.
26449 SDValue BetterChain = FindBetterChain(St, St->getChain());
26450 if (St->getChain() != BetterChain) {
26451 replaceStoreChain(St, BetterChain);
26452 return true;
26453 }
26454 return false;
26455 }
26456
26457 /// This is the entry point for the file.
26458 void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis *AA,
26459 CodeGenOpt::Level OptLevel) {
26460 /// This is the main entry point to this class.
26461 DAGCombiner(*this, AA, OptLevel).Run(Level);
26462 }
26463